Video-Text-to-Text
Transformers
Safetensors
English
qwen2_5_vl
video-scene-graph
scene-graph-generation
video-understanding
trajectory-aware
perceiver-resampler
qwen2.5-vl
text-generation-inference
Instructions to use UWGZQ/TRASER with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use UWGZQ/TRASER with Transformers:
```python
# Load model directly
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration_Insert

processor = AutoProcessor.from_pretrained("UWGZQ/TRASER")
model = Qwen2_5_VLForConditionalGeneration_Insert.from_pretrained("UWGZQ/TRASER")
```
- Notebooks
- Google Colab
- Kaggle
| import torch | |
| import torch.nn.functional as F | |
| from typing import List, Optional, Tuple | |
| import math | |
def rearrange_token(
    model,
    input_ids: torch.LongTensor,
    attention_mask: torch.LongTensor,
    pixel_values: Optional[torch.FloatTensor],
    image_grid_thw: Optional[torch.LongTensor],
    pixel_values_videos: Optional[torch.FloatTensor],
    video_grid_thw: Optional[torch.LongTensor],
    second_per_grid_ts: Optional[torch.Tensor],
    obj_token_indices_per_sample: List[List[torch.Tensor]],
    obj_traj_start_id: Optional[int] = None,
    obj_traj_end_id: Optional[int] = None,
    text_token_ids_per_sample: Optional[List[List[torch.Tensor]]] = None,
    timestamp_token_ids_per_batch=None,
    grids_per_temporal_window_per_batch=None,
    labels: Optional[torch.LongTensor] = None,
    IGNORE_ID: int = -100,
    use_resampler: bool = True,
    use_second_resampler: bool = True,
    add_timestamp_token: bool = True,
) -> Tuple[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    Optional[torch.Tensor],
]:
    """Replace each sample's video-token span with per-object trajectory blocks.

    For every sample whose (attention-masked) ``input_ids`` contain
    ``model.config.video_token_id``, the span from the first to the last video
    token (widened by one token on either side when that token is the
    vision-start / vision-end id) is removed.  In its place one block per
    object is written, in this order::

        [obj_traj_start] text-tokens [vision_start]
        [second-resampler latents]                       (whole object)
        ([timestamp tokens] main-resampler latents) ...   (per non-empty window)
        [vision_end] [obj_traj_end]

    Latent positions hold ``video_token_id`` in the returned ids and are
    overwritten in the returned embeddings with the resampler outputs.
    Samples without video tokens are copied through unchanged.  All inserted
    tokens get label ``IGNORE_ID``.  Position ids are the plain running index
    of each token, replicated over the 3 leading rows.

    Args:
        model: Object providing ``get_input_embeddings()``, ``config``
            (``video_token_id`` and optionally ``vision_start_token_id`` /
            ``vision_end_token_id``), ``model.get_video_features``,
            ``model.visual.dtype`` and, when enabled, ``perceiver_resampler``
            / ``second_perceiver_resampler`` (each with ``n_latents``).
        input_ids: ``(B, L)`` token ids.
        attention_mask: ``(B, L)`` mask; the effective length of a sample is
            the index of its last non-zero entry plus one.
        pixel_values: Not used by this function.
        image_grid_thw: Not used by this function.
        pixel_values_videos: Passed to ``model.model.get_video_features``;
            may be ``None`` only if no latents have to be produced.
        video_grid_thw: ``(B, 3)`` grid sizes; columns 1 and 2 are multiplied
            and divided by 4 to get the tokens per temporal grid.
        second_per_grid_ts: Not used by this function.
        obj_token_indices_per_sample: Per sample, per object, a 1-D tensor of
            indices into that sample's video tokens.  Out-of-range values are
            clamped to ``[0, Nv - 1]``; the caller's tensors are not modified.
        obj_traj_start_id: Optional token id emitted before each object block.
        obj_traj_end_id: Optional token id emitted after each object block.
        text_token_ids_per_sample: Per sample, per object, 1-D text token ids
            (required, length ``B``).
        timestamp_token_ids_per_batch: Per sample, a sequence of 1-D token id
            tensors indexed by temporal window; windows past the end reuse the
            last entry.  Required when ``add_timestamp_token`` is true.
        grids_per_temporal_window_per_batch: Per sample, number of temporal
            grids per window (required, length ``B``).
        labels: Optional ``(B, L)`` labels, rearranged alongside the ids.
        IGNORE_ID: Label value for inserted and padded positions.
        use_resampler: Enable the main (per-window) resampler.
        use_second_resampler: Enable the second (whole-object) resampler.
        add_timestamp_token: Emit timestamp tokens before each window.

    Returns:
        ``(inputs_embeds, position_ids, attention_mask, rope_deltas,
        cache_position, input_ids, labels)`` for the rearranged batch, where
        ``position_ids`` is ``(3, B, Lmax)`` int32, ``attention_mask`` is
        boolean and ``labels`` is ``None`` when no labels were given.

    Raises:
        AssertionError: On missing / mis-sized per-batch arguments, an empty
            timestamp entry, or an internal length-planning mismatch.
        RuntimeError: If a requested resampler is missing on ``model``, if
            latents are needed while ``use_resampler`` is false, or if latents
            are needed but no video features were computed.
    """
    dev = input_ids.device
    B, _ = input_ids.shape
    cpu = torch.device("cpu")

    assert text_token_ids_per_sample is not None and len(text_token_ids_per_sample) == B, \
        "mode3_traj_and_text requires text_token_ids_per_sample with length B."
    if add_timestamp_token:
        assert timestamp_token_ids_per_batch is not None and len(timestamp_token_ids_per_batch) == B, \
            "add_timestamp_token=True requires timestamp_token_ids_per_batch with length B."
        assert grids_per_temporal_window_per_batch is not None and len(grids_per_temporal_window_per_batch) == B, \
            "add_timestamp_token=True requires grids_per_temporal_window_per_batch with length B."
    else:
        assert grids_per_temporal_window_per_batch is not None and len(grids_per_temporal_window_per_batch) == B, \
            "grids_per_temporal_window_per_batch is required."

    tok_embed = model.get_input_embeddings()
    vt_id = int(model.config.video_token_id)
    vs_id = getattr(model.config, "vision_start_token_id", None)
    ve_id = getattr(model.config, "vision_end_token_id", None)
    # Hard-coded padding id for the rebuilt sequences.
    # NOTE(review): presumably the tokenizer's pad token - confirm.
    pad_id = 151643

    # ---- (0+) temporal window meta ----
    assert video_grid_thw is not None, "video_grid_thw is required for temporal windowing"
    assert video_grid_thw.shape[0] == B and video_grid_thw.shape[1] == 3, \
        f"video_grid_thw should be ({B},3), got {video_grid_thw.shape}"
    temporal_window_size_batch = grids_per_temporal_window_per_batch
    use_timestamps = add_timestamp_token and timestamp_token_ids_per_batch is not None

    # ---- (0) Compute visual features (with grad) ----
    video_embeds = None
    if pixel_values_videos is not None:
        _vid = model.model.get_video_features(
            pixel_values_videos.type(model.model.visual.dtype), video_grid_thw
        )
        video_embeds = torch.cat(_vid, dim=0) if isinstance(_vid, (list, tuple)) else _vid
        del pixel_values_videos, _vid

    # ---- (0.1) Resamplers ----
    resampler = None
    resampler_num_latents = None
    second_resampler = None
    second_resampler_num_latents = None
    if use_resampler:
        if not hasattr(model, "perceiver_resampler"):
            raise RuntimeError("use_resampler=True, but model.perceiver_resampler not found.")
        resampler = model.perceiver_resampler
        resampler_num_latents = int(resampler.n_latents)
        if use_second_resampler:
            if not hasattr(model, "second_perceiver_resampler"):
                raise RuntimeError("use_second_resampler=True, but model.second_perceiver_resampler not found.")
            second_resampler = model.second_perceiver_resampler
            second_resampler_num_latents = int(second_resampler.n_latents)
    # Single source of truth for "second-resampler latents are emitted";
    # used identically by planning, filling and pasting.
    has_second = use_second_resampler and second_resampler is not None

    # ---- (2) Move to CPU for sequence planning ----
    attn_cpu = attention_mask.to(cpu, dtype=torch.bool)
    ids_cpu = input_ids.to(cpu)
    lbls_cpu = labels.to(cpu) if labels is not None else None

    grid_area_batch: List[int] = []
    eff_lens: List[int] = []
    vid_idx_list: List[torch.Tensor] = []
    for b in range(B):
        thw_b = video_grid_thw[b]
        # NOTE(review): the // 4 presumably reflects a 2x2 spatial merge of
        # vision patches - confirm against the vision tower config.
        grid_area_batch.append((int(thw_b[1].item()) * int(thw_b[2].item())) // 4)

        nz = torch.nonzero(attn_cpu[b], as_tuple=False).flatten()
        L_eff = int(nz[-1].item()) + 1 if nz.numel() > 0 else 0
        eff_lens.append(L_eff)
        if L_eff > 0:
            vid_idx_list.append(
                torch.nonzero(ids_cpu[b, :L_eff] == vt_id, as_tuple=False).flatten()
            )
        else:
            vid_idx_list.append(torch.empty(0, dtype=torch.long))

    # Row offset of each sample inside the concatenated video_embeds,
    # assuming one feature row per video token, samples in batch order.
    vid_offsets: List[int] = []
    running = 0
    for v in vid_idx_list:
        vid_offsets.append(running)
        running += int(v.numel())

    def _window_index(b: int, rel: torch.Tensor) -> torch.Tensor:
        """Map video-token indices of sample ``b`` to temporal-window indices."""
        tokens_per_window = int(grid_area_batch[b] * int(temporal_window_size_batch[b]))
        if tokens_per_window > 0:
            return rel // tokens_per_window
        return torch.zeros_like(rel)

    def _timestamp_ids(b: int, w: int) -> torch.Tensor:
        """Timestamp token ids for window ``w``; past-the-end reuses the last."""
        ts_list = timestamp_token_ids_per_batch[b]
        return ts_list[w] if w < len(ts_list) else ts_list[-1]

    # ---- (3) Length planning ----
    # spans[b]     : (v_s, v_e) inclusive span to replace, or None.
    # obj_plans[b] : per object (clamped indices, window index per token,
    #                sorted list of non-empty window ids); reused by the fill
    #                pass so that planning and filling can never disagree.
    spans: List[Optional[Tuple[int, int]]] = [None] * B
    obj_plans: List[list] = [[] for _ in range(B)]
    L_new_each: List[int] = []
    for b in range(B):
        L_eff = eff_lens[b]
        vid_idx = vid_idx_list[b]
        if L_eff == 0:
            L_new_each.append(0)
            continue
        if vid_idx.numel() == 0:
            L_new_each.append(L_eff)
            continue

        ids_b = ids_cpu[b, :L_eff]
        v_s = int(vid_idx[0].item())
        v_e = int(vid_idx[-1].item())
        # Swallow the surrounding <VS>/<VE> so each object block can emit its own.
        if vs_id is not None and v_s - 1 >= 0 and ids_b[v_s - 1].item() == vs_id:
            v_s -= 1
        if ve_id is not None and v_e + 1 < L_eff and ids_b[v_e + 1].item() == ve_id:
            v_e += 1
        spans[b] = (v_s, v_e)

        Nv = int(vid_idx.numel())
        cur_total = 0
        for i, rel in enumerate(obj_token_indices_per_sample[b]):
            rel = rel.to(cpu, dtype=torch.long)
            if rel.numel() > 0:
                # Out-of-place clamp: ``.to`` may return the caller's own
                # tensor, which must not be modified.
                rel = rel.clamp(0, Nv - 1)
            win = _window_index(b, rel)
            window_ids = [int(w) for w in win.unique().tolist()]  # sorted ascending
            obj_plans[b].append((rel, win, window_ids))

            block_len = int(text_token_ids_per_sample[b][i].numel())
            block_len += int(obj_traj_start_id is not None)
            block_len += int(vs_id is not None)
            block_len += int(ve_id is not None)
            block_len += int(obj_traj_end_id is not None)
            if window_ids:
                if resampler_num_latents is None:
                    raise RuntimeError(
                        "use_resampler=False is not supported when objects select "
                        "video tokens: model.perceiver_resampler latents are required."
                    )
                block_len += resampler_num_latents * len(window_ids)
                # Second-resampler latents exist only for objects that
                # actually select video tokens (mirrors the fill pass).
                if has_second:
                    block_len += second_resampler_num_latents
                if use_timestamps:
                    block_len += sum(int(_timestamp_ids(b, w).numel()) for w in window_ids)
            cur_total += block_len

        L_new_each.append(v_s + cur_total + (L_eff - (v_e + 1)))

    Lmax = max(L_new_each) if len(L_new_each) > 0 else 0

    # ---- (4) Allocate new sequence tensors on CPU and fill per-sample ----
    new_input_ids_cpu = torch.full((B, Lmax), pad_id, dtype=torch.long, device=cpu)
    new_attention_mask_cpu = torch.zeros((B, Lmax), dtype=torch.bool, device=cpu)
    new_position_ids_cpu = torch.zeros((3, B, Lmax), dtype=torch.int32, device=cpu)
    new_labels_cpu = None
    if labels is not None:
        new_labels_cpu = torch.full((B, Lmax), IGNORE_ID, dtype=torch.long, device=cpu)

    # Work lists for the batched resampler calls in step (5).
    batched_obj_rows: List[torch.Tensor] = []
    batched_obj_pos: List[torch.Tensor] = []
    batched_obj_bids: List[int] = []
    batched_second_rows: List[torch.Tensor] = []
    batched_second_pos: List[torch.Tensor] = []
    batched_second_bids: List[int] = []

    def _text_pos_block(start_scalar: int, length: int, dtype=torch.int32) -> torch.Tensor:
        """Create 1D-linear positions replicated across 3 RoPE dims."""
        if length <= 0:
            return torch.empty(3, 0, dtype=dtype, device=cpu)
        ar = torch.arange(start_scalar, start_scalar + length, device=cpu, dtype=dtype)
        return torch.stack([ar, ar, ar], dim=0)

    def _copy_original(b: int, dst: int, lo: int, hi: int) -> int:
        """Copy source tokens ``[lo, hi)`` of sample ``b`` to ``dst``; return new dst."""
        n = hi - lo
        if n <= 0:
            return dst
        new_input_ids_cpu[b, dst:dst + n] = ids_cpu[b, lo:hi]
        new_attention_mask_cpu[b, dst:dst + n] = attn_cpu[b, lo:hi]
        if new_labels_cpu is not None:
            new_labels_cpu[b, dst:dst + n] = lbls_cpu[b, lo:hi]
        new_position_ids_cpu[:, b, dst:dst + n] = _text_pos_block(dst, n)
        return dst + n

    def _write_inserted(b: int, dst: int, values, n: int) -> int:
        """Write ``n`` inserted tokens (scalar id or 1-D tensor) at ``dst``; return new dst."""
        if n <= 0:
            return dst
        new_input_ids_cpu[b, dst:dst + n] = values
        new_position_ids_cpu[:, b, dst:dst + n] = _text_pos_block(dst, n)
        if new_labels_cpu is not None:
            new_labels_cpu[b, dst:dst + n] = IGNORE_ID
        new_attention_mask_cpu[b, dst:dst + n] = True
        return dst + n

    for b in range(B):
        L_eff = eff_lens[b]
        if L_eff == 0:
            continue
        if spans[b] is None:
            # No video tokens: pass the sample through untouched.
            _copy_original(b, 0, 0, L_eff)
            continue

        v_s, v_e = spans[b]
        dst = _copy_original(b, 0, 0, v_s)  # prefix
        vid_offset = int(vid_offsets[b])

        for i, (rel, win, window_ids) in enumerate(obj_plans[b]):
            # (1) <obj_traj_start> (optional)
            if obj_traj_start_id is not None:
                dst = _write_inserted(b, dst, int(obj_traj_start_id), 1)

            # (2) text tokens (required)
            txt_ids = text_token_ids_per_sample[b][i].to(cpu, dtype=torch.long)
            dst = _write_inserted(b, dst, txt_ids, int(txt_ids.numel()))

            # (3) <VS> (optional)
            if vs_id is not None:
                dst = _write_inserted(b, dst, int(vs_id), 1)

            # (4) video tokens
            if window_ids:
                # Rows of video_embeds selected by this object, grouped by window.
                rows_per_window = [rel[win == w] + vid_offset for w in window_ids]

                # second resampler: global object summary
                if has_second:
                    R2 = int(second_resampler_num_latents)
                    batched_second_rows.append(torch.cat(rows_per_window, dim=0))
                    batched_second_pos.append(
                        torch.arange(dst, dst + R2, dtype=torch.long, device=cpu)
                    )
                    batched_second_bids.append(b)
                    dst = _write_inserted(b, dst, int(vt_id), R2)

                R = int(resampler_num_latents)
                for w, rows_w in zip(window_ids, rows_per_window):
                    # timestamp tokens (text-only; NOT injected into resampler)
                    if use_timestamps:
                        ts_ids = _timestamp_ids(b, w).to(cpu, dtype=torch.long)
                        kt = int(ts_ids.numel())
                        assert kt > 0, "Timestamp token ids should not be empty."
                        dst = _write_inserted(b, dst, ts_ids, kt)

                    batched_obj_rows.append(rows_w)
                    batched_obj_pos.append(
                        torch.arange(dst, dst + R, dtype=torch.long, device=cpu)
                    )
                    batched_obj_bids.append(b)
                    dst = _write_inserted(b, dst, int(vt_id), R)

            # (5) <VE> (optional)
            if ve_id is not None:
                dst = _write_inserted(b, dst, int(ve_id), 1)

            # (6) <obj_traj_end> (optional)
            if obj_traj_end_id is not None:
                dst = _write_inserted(b, dst, int(obj_traj_end_id), 1)

        dst = _copy_original(b, dst, v_e + 1, L_eff)  # suffix
        assert dst == L_new_each[b], f"sample {b}: dst={dst}, L_new={L_new_each[b]}"

    # ---- (5) Move back to device, build inputs_embeds, and paste visual features ----
    new_input_ids = new_input_ids_cpu.to(dev, non_blocking=True)
    new_position_ids = new_position_ids_cpu.to(dev, non_blocking=True)
    new_attention_mask = new_attention_mask_cpu.to(dev, non_blocking=True)
    new_labels = None if new_labels_cpu is None else new_labels_cpu.to(dev, non_blocking=True)

    # Clone so the in-place pastes below never touch the embedding output itself.
    new_inputs_embeds = tok_embed(new_input_ids).clone()

    def _gather_padded(rows_list: List[torch.Tensor]):
        """Gather video_embeds rows per entry, right-pad to a batch, build its mask."""
        dev_emb = video_embeds.device
        seqs = [video_embeds.index_select(0, rows.to(dev_emb)) for rows in rows_list]
        lens = torch.tensor(
            [int(rows.numel()) for rows in rows_list], device=dev_emb, dtype=torch.long
        )
        x = torch.nn.utils.rnn.pad_sequence(seqs, batch_first=True)
        mask = torch.arange(x.size(1), device=dev_emb).unsqueeze(0) < lens.unsqueeze(1)
        return x, mask

    # ---- (5.1) second resampler: object-level global summary ----
    if use_resampler and has_second and len(batched_second_rows) > 0:
        if video_embeds is None:
            raise RuntimeError("use_second_resampler=True but video_embeds is None.")
        x2, mask2 = _gather_padded(batched_second_rows)
        y2 = second_resampler(x2, attention_mask=mask2).to(new_inputs_embeds.dtype)
        for j, b_cur in enumerate(batched_second_bids):
            new_inputs_embeds[b_cur, batched_second_pos[j].to(dev)] = y2[j]

    # ---- (5.2) main resampler: temporal resampler ----
    if use_resampler and len(batched_obj_rows) > 0:
        if video_embeds is None:
            raise RuntimeError("use_resampler=True but video_embeds is None.")
        x, mask = _gather_padded(batched_obj_rows)
        y = resampler(x, attention_mask=mask).to(new_inputs_embeds.dtype)

        # Group windows by sample so each sample is pasted with one indexed write.
        per_b_indices: List[List[int]] = [[] for _ in range(B)]
        for i, b_cur in enumerate(batched_obj_bids):
            per_b_indices[b_cur].append(i)
        for b, idxs in enumerate(per_b_indices):
            if not idxs:
                continue
            pos_b = torch.cat([batched_obj_pos[i].to(dev) for i in idxs], dim=0)
            emb_b = torch.cat([y[i] for i in idxs], dim=0)
            new_inputs_embeds[b, pos_b] = emb_b

    # ---- (6) rope_deltas / cache_position ----
    maxpos = new_position_ids.max(dim=0)[0].max(dim=1, keepdim=True)[0]
    rope_deltas = (maxpos + 1 - new_inputs_embeds.shape[1]).to(dtype=torch.long, device=dev)
    cache_position = torch.arange(new_inputs_embeds.shape[1], device=dev, dtype=torch.int32)

    return (
        new_inputs_embeds,
        new_position_ids,
        new_attention_mask,
        rope_deltas,
        cache_position,
        new_input_ids,
        new_labels,
    )