import math
from typing import List, Optional, Tuple

import torch
import torch.nn.functional as F

_CPU = torch.device("cpu")

# Token id used to fill the unused tail of samples shorter than the longest one.
# NOTE(review): hard-coded; presumably the tokenizer's pad / end-of-text id — confirm.
_PAD_TOKEN_ID = 151643


def _text_pos_block(start: int, length: int, dtype: torch.dtype = torch.int32) -> torch.Tensor:
    """Return CPU positions ``start .. start + length - 1`` replicated over 3 rows.

    The result has shape ``(3, length)``: the same 1D-linear positions are used for
    each of the three RoPE position rows.
    """
    if length <= 0:
        return torch.empty(3, 0, dtype=dtype, device=_CPU)
    ar = torch.arange(start, start + length, dtype=dtype, device=_CPU)
    return ar.repeat(3, 1)


def _video_span(ids_b: torch.Tensor, vid_idx: torch.Tensor, vs_id, ve_id) -> Tuple[int, int]:
    """Return the inclusive ``(start, end)`` of the video segment of one sample.

    The span runs from the first to the last video token (``vid_idx``) and is widened
    by one position on either side when the neighbouring token is the
    vision-start / vision-end id.
    """
    v_s = int(vid_idx[0].item())
    v_e = int(vid_idx[-1].item())
    if vs_id is not None and v_s >= 1 and int(ids_b[v_s - 1].item()) == vs_id:
        v_s -= 1
    if ve_id is not None and v_e + 1 < ids_b.numel() and int(ids_b[v_e + 1].item()) == ve_id:
        v_e += 1
    return v_s, v_e


def _prepare_rel(rel: torch.Tensor, num_video_tokens: int) -> torch.Tensor:
    """Return a CPU int64 copy of ``rel`` clamped into ``[0, num_video_tokens - 1]``.

    Clamping is done out of place so the caller's index tensors are never modified
    (``Tensor.to`` returns ``self`` when no conversion is needed).
    """
    rel = rel.to(_CPU, dtype=torch.long)
    if rel.numel() > 0 and num_video_tokens > 0:
        rel = rel.clamp(0, num_video_tokens - 1)
    return rel


def _window_index(rel: torch.Tensor, tokens_per_window: int) -> torch.Tensor:
    """Map each selected video-token rank to its temporal-window index."""
    if tokens_per_window > 0:
        return rel // tokens_per_window
    return torch.zeros_like(rel)


def _timestamp_ids(timestamp_token_ids_per_batch, b: int, window: int) -> torch.Tensor:
    """Timestamp token ids of window ``window`` in sample ``b`` (last entry if out of range)."""
    per_sample = timestamp_token_ids_per_batch[b]
    ids = per_sample[window] if window < len(per_sample) else per_sample[-1]
    return ids.to(_CPU, dtype=torch.long)


def _object_latent_len(
    num_windows: int,
    resampler_num_latents: Optional[int],
    second_resampler_num_latents: Optional[int],
) -> int:
    """Number of latent placeholder slots for one object with ``num_windows`` non-empty windows.

    An object with no selected video tokens gets no latent slots at all (matching the
    fill pass, which emits neither first- nor second-resampler tokens for it).

    Raises:
        RuntimeError: when latents are needed but no main resampler is configured.
    """
    if num_windows == 0:
        return 0
    if resampler_num_latents is None:
        raise RuntimeError(
            "use_resampler=False is not supported for objects with selected video tokens; "
            "model.perceiver_resampler is required."
        )
    total = resampler_num_latents * num_windows
    if second_resampler_num_latents is not None:
        total += second_resampler_num_latents
    return total


def _run_resampler(resampler, video_embeds: torch.Tensor, rows_list: List[torch.Tensor]) -> torch.Tensor:
    """Gather ``video_embeds`` rows for every entry of ``rows_list`` and resample them.

    The gathered sequences are right-padded into one batch and passed to
    ``resampler(x, attention_mask=mask)``, where ``mask`` is True on real (non-pad) rows.
    """
    dev_emb = video_embeds.device
    seqs = [video_embeds.index_select(0, rows.to(dev_emb)) for rows in rows_list]
    lens = torch.tensor([int(rows.numel()) for rows in rows_list], device=dev_emb, dtype=torch.long)
    x = torch.nn.utils.rnn.pad_sequence(seqs, batch_first=True)
    mask = torch.arange(x.size(1), device=dev_emb).unsqueeze(0) < lens.unsqueeze(1)
    return resampler(x, attention_mask=mask)


def _paste_latents(
    embeds: torch.Tensor,
    latents: torch.Tensor,
    bids: List[int],
    positions: List[torch.Tensor],
) -> None:
    """Write ``latents[j]`` into ``embeds[bids[j], positions[j]]`` for every ``j`` (in place)."""
    device = embeds.device
    b_idx = torch.cat(
        [torch.full((int(p.numel()),), int(bid), dtype=torch.long, device=_CPU) for bid, p in zip(bids, positions)]
    ).to(device)
    p_idx = torch.cat(positions).to(device)
    embeds[b_idx, p_idx] = latents.to(embeds.dtype).reshape(-1, latents.shape[-1])


def rearrange_token(
    model,
    input_ids: torch.LongTensor,
    attention_mask: torch.LongTensor,
    pixel_values: Optional[torch.FloatTensor],
    image_grid_thw: Optional[torch.LongTensor],
    pixel_values_videos: Optional[torch.FloatTensor],
    video_grid_thw: Optional[torch.LongTensor],
    second_per_grid_ts: Optional[torch.Tensor],
    obj_token_indices_per_sample: List[List[torch.Tensor]],
    obj_traj_start_id: Optional[int] = None,
    obj_traj_end_id: Optional[int] = None,
    text_token_ids_per_sample: Optional[List[List[torch.Tensor]]] = None,
    timestamp_token_ids_per_batch=None,
    grids_per_temporal_window_per_batch=None,
    labels: Optional[torch.LongTensor] = None,
    IGNORE_ID: int = -100,
    use_resampler: bool = True,
    use_second_resampler: bool = True,
    add_timestamp_token: bool = True,
):
    """Replace each sample's video span with per-object blocks of resampled video latents.

    For every sample whose attended region (up to its last non-zero attention position)
    contains video tokens (``model.config.video_token_id``), the span from the first to
    the last video token — widened to an adjacent vision-start / vision-end token when
    present — is removed. For each object ``i`` of ``obj_token_indices_per_sample[b]`` the
    following block is inserted in its place::

        [obj_traj_start_id] text_token_ids_per_sample[b][i] [vision_start]
        [second-resampler latents]                       (only with use_second_resampler)
        ([timestamp ids] [resampler latents])  for each non-empty temporal window
        [vision_end] [obj_traj_end_id]

    Bracketed single ids are emitted only when they are not None; timestamp ids only
    when ``add_timestamp_token``. Latent slots are written as ``video_token_id``
    placeholders whose embeddings are then overwritten by resampler outputs. An object
    whose index list is empty gets no latent slots. Samples without video tokens are
    copied unchanged. Every inserted token gets label ``IGNORE_ID``; positions are
    1D-linear and identical across the 3 RoPE rows.

    Args:
        model: provides ``get_input_embeddings()``, ``config`` (video / vision start /
            vision end token ids), ``model.get_video_features`` and ``model.visual.dtype``,
            and ``perceiver_resampler`` / ``second_perceiver_resampler`` when enabled.
        input_ids, attention_mask: ``(B, L)`` token ids and attention mask.
        pixel_values, image_grid_thw, second_per_grid_ts: accepted for interface
            compatibility; unused here.
        pixel_values_videos: video pixels fed to ``model.model.get_video_features``.
        video_grid_thw: ``(B, 3)`` per-sample video grid (t, h, w).
        obj_token_indices_per_sample: per sample, per object, indices into that sample's
            video tokens (rank among its video tokens). Indices are clamped into range;
            the caller's tensors are not modified.
        obj_traj_start_id, obj_traj_end_id: optional ids wrapping each object block.
        text_token_ids_per_sample: per sample, per object, text ids (required).
        timestamp_token_ids_per_batch: per sample, per window, timestamp ids.
        grids_per_temporal_window_per_batch: per sample, temporal grids per window.
        labels: optional ``(B, L)`` labels; prefix/suffix labels are copied.
        IGNORE_ID: label value for inserted tokens and padding.
        use_resampler, use_second_resampler: enable the per-window / per-object resamplers.
        add_timestamp_token: insert timestamp ids before each window's latents.

    Returns:
        ``(inputs_embeds, position_ids, attention_mask, rope_deltas, cache_position,
        input_ids, labels)`` where ``position_ids`` is ``(3, B, Lmax)`` int32,
        ``attention_mask`` is ``(B, Lmax)`` bool, ``rope_deltas`` is ``(B, 1)`` int64,
        ``cache_position`` is ``(Lmax,)`` int32 and ``labels`` is None when not given.

    Raises:
        AssertionError: on missing/mis-sized inputs or a planning/fill length mismatch.
        RuntimeError: when a required resampler or the video features are missing.
    """
    dev = input_ids.device
    B, _ = input_ids.shape

    assert text_token_ids_per_sample is not None and len(text_token_ids_per_sample) == B, \
        "mode3_traj_and_text requires text_token_ids_per_sample with length B."
    if add_timestamp_token:
        assert timestamp_token_ids_per_batch is not None and len(timestamp_token_ids_per_batch) == B, \
            "add_timestamp_token=True requires timestamp_token_ids_per_batch with length B."
        assert grids_per_temporal_window_per_batch is not None and len(grids_per_temporal_window_per_batch) == B, \
            "add_timestamp_token=True requires grids_per_temporal_window_per_batch with length B."
    else:
        assert grids_per_temporal_window_per_batch is not None and len(grids_per_temporal_window_per_batch) == B, \
            "grids_per_temporal_window_per_batch is required."

    tok_embed = model.get_input_embeddings()
    vt_id = int(model.config.video_token_id)
    vs_id = getattr(model.config, "vision_start_token_id", None)
    ve_id = getattr(model.config, "vision_end_token_id", None)
    pad_id = _PAD_TOKEN_ID

    assert video_grid_thw is not None, "video_grid_thw is required for temporal windowing"
    assert video_grid_thw.shape[0] == B and video_grid_thw.shape[1] == 3, \
        f"video_grid_thw should be ({B},3), got {video_grid_thw.shape}"

    # ---- (0) visual features (with grad) ----
    video_embeds = None
    if pixel_values_videos is not None:
        _vid = model.model.get_video_features(
            pixel_values_videos.type(model.model.visual.dtype), video_grid_thw
        )
        video_embeds = torch.cat(_vid, dim=0) if isinstance(_vid, (list, tuple)) else _vid
        del pixel_values_videos, _vid

    # ---- (0.1) resamplers ----
    resampler = None
    resampler_num_latents = None
    second_resampler = None
    second_resampler_num_latents = None
    if use_resampler:
        if not hasattr(model, "perceiver_resampler"):
            raise RuntimeError("use_resampler=True, but model.perceiver_resampler not found.")
        resampler = model.perceiver_resampler
        resampler_num_latents = int(resampler.n_latents)
    if use_second_resampler:
        if not hasattr(model, "second_perceiver_resampler"):
            raise RuntimeError("use_second_resampler=True, but model.second_perceiver_resampler not found.")
        second_resampler = model.second_perceiver_resampler
        second_resampler_num_latents = int(second_resampler.n_latents)

    # ---- (1) per-sample scan on CPU ----
    attn_cpu = attention_mask.to(_CPU, dtype=torch.bool)
    ids_cpu = input_ids.to(_CPU)
    lbls_cpu = labels.to(_CPU) if labels is not None else None

    grid_area_batch: List[int] = []
    eff_lens: List[int] = []
    vid_idx_list: List[torch.Tensor] = []
    for b in range(B):
        thw = video_grid_thw[b]
        # NOTE(review): the // 4 presumably accounts for a 2x2 spatial token merge — confirm.
        grid_area_batch.append((int(thw[1].item()) * int(thw[2].item())) // 4)
        nz = torch.nonzero(attn_cpu[b], as_tuple=False).flatten()
        L_eff = int(nz[-1].item()) + 1 if nz.numel() > 0 else 0
        eff_lens.append(L_eff)
        vid_idx_list.append(torch.nonzero(ids_cpu[b, :L_eff] == vt_id, as_tuple=False).flatten())

    # Row offset of each sample's first video token inside ``video_embeds``.
    # NOTE(review): assumes video_embeds rows are ordered like the batch's video tokens.
    vid_offsets: List[int] = []
    running = 0
    for vid_idx in vid_idx_list:
        vid_offsets.append(running)
        running += int(vid_idx.numel())

    # ---- (2) length planning ----
    def _object_block_len(b: int, obj_i: int, latent_len: int, windows: List[int]) -> int:
        """Token count of one object block; must mirror the fill pass exactly."""
        n = latent_len + int(text_token_ids_per_sample[b][obj_i].numel())
        n += sum(tok is not None for tok in (obj_traj_start_id, vs_id, ve_id, obj_traj_end_id))
        if add_timestamp_token and timestamp_token_ids_per_batch is not None:
            n += sum(int(_timestamp_ids(timestamp_token_ids_per_batch, b, w).numel()) for w in windows)
        return n

    spans: List[Optional[Tuple[int, int]]] = []
    # Per sample, per object: (clamped indices, window index per index, sorted non-empty windows).
    # Computed once and reused by the fill pass so both passes see identical data.
    obj_plans: List[List[Tuple[torch.Tensor, torch.Tensor, List[int]]]] = []
    L_new_each: List[int] = []
    for b in range(B):
        L_eff = eff_lens[b]
        vid_idx = vid_idx_list[b]
        if L_eff == 0 or vid_idx.numel() == 0:
            spans.append(None)
            obj_plans.append([])
            L_new_each.append(L_eff)
            continue
        v_s, v_e = _video_span(ids_cpu[b, :L_eff], vid_idx, vs_id, ve_id)
        spans.append((v_s, v_e))
        num_video_tokens = int(vid_idx.numel())
        total = v_s + (L_eff - (v_e + 1))  # prefix + suffix
        plans_b = []
        for i, rel in enumerate(obj_token_indices_per_sample[b]):
            rel = _prepare_rel(rel, num_video_tokens)
            tokens_per_window = int(grid_area_batch[b] * int(grids_per_temporal_window_per_batch[b]))
            win_idx = _window_index(rel, tokens_per_window)
            windows = [int(w) for w in win_idx.unique().tolist()]  # sorted ascending
            latent_len = _object_latent_len(len(windows), resampler_num_latents, second_resampler_num_latents)
            total += _object_block_len(b, i, latent_len, windows)
            plans_b.append((rel, win_idx, windows))
        obj_plans.append(plans_b)
        L_new_each.append(total)
    Lmax = max(L_new_each) if L_new_each else 0

    # ---- (3) allocate new sequences on CPU and fill per sample ----
    new_input_ids_cpu = torch.full((B, Lmax), pad_id, dtype=torch.long, device=_CPU)
    new_attention_mask_cpu = torch.zeros((B, Lmax), dtype=torch.bool, device=_CPU)
    new_position_ids_cpu = torch.zeros((3, B, Lmax), dtype=torch.int32, device=_CPU)
    new_labels_cpu = None
    if labels is not None:
        new_labels_cpu = torch.full((B, Lmax), IGNORE_ID, dtype=torch.long, device=_CPU)

    def _write(b: int, dst: int, length: int, token_ids, label_src=None, mask_src=None) -> int:
        """Write ``length`` tokens to row ``b`` at ``dst`` and return the advanced cursor.

        Labels default to ``IGNORE_ID`` and the attention mask to True unless explicit
        source slices are given (used for the copied prefix/suffix).
        """
        if length <= 0:
            return dst
        end = dst + length
        new_input_ids_cpu[b, dst:end] = token_ids
        new_position_ids_cpu[:, b, dst:end] = _text_pos_block(dst, length)
        new_attention_mask_cpu[b, dst:end] = True if mask_src is None else mask_src
        if new_labels_cpu is not None:
            new_labels_cpu[b, dst:end] = IGNORE_ID if label_src is None else label_src
        return end

    batched_obj_rows: List[torch.Tensor] = []
    batched_obj_pos: List[torch.Tensor] = []
    batched_obj_bids: List[int] = []
    batched_second_rows: List[torch.Tensor] = []
    batched_second_pos: List[torch.Tensor] = []
    batched_second_bids: List[int] = []

    for b in range(B):
        L_eff = eff_lens[b]
        if L_eff == 0:
            continue
        ids_b = ids_cpu[b, :L_eff]
        msk_b = attn_cpu[b, :L_eff]
        labs_b = lbls_cpu[b, :L_eff] if lbls_cpu is not None else None

        if spans[b] is None:
            # No video tokens: copy the sample unchanged.
            _write(b, 0, L_eff, ids_b, label_src=labs_b, mask_src=msk_b)
            continue

        v_s, v_e = spans[b]
        dst = _write(
            b, 0, v_s, ids_b[:v_s],
            label_src=None if labs_b is None else labs_b[:v_s], mask_src=msk_b[:v_s],
        )
        vid_offset = vid_offsets[b]

        for i, (rel, win_idx, windows) in enumerate(obj_plans[b]):
            if obj_traj_start_id is not None:
                dst = _write(b, dst, 1, int(obj_traj_start_id))
            txt_ids = text_token_ids_per_sample[b][i].to(_CPU, dtype=torch.long)
            dst = _write(b, dst, int(txt_ids.numel()), txt_ids)
            if vs_id is not None:
                dst = _write(b, dst, 1, int(vs_id))

            if windows:
                # Rows of video_embeds per non-empty window, in ascending window order.
                rows_per_window = [rel[win_idx == w] + vid_offset for w in windows]

                # Second resampler: one global summary over all of the object's rows.
                if second_resampler is not None:
                    R2 = int(second_resampler_num_latents)
                    batched_second_rows.append(torch.cat(rows_per_window, dim=0))
                    batched_second_pos.append(torch.arange(dst, dst + R2, dtype=torch.long, device=_CPU))
                    batched_second_bids.append(b)
                    dst = _write(b, dst, R2, vt_id)

                # Main resampler: one latent group per non-empty temporal window.
                R = int(resampler_num_latents)
                for w, rows_w in zip(windows, rows_per_window):
                    # Timestamp tokens are text only; they are not fed to the resampler.
                    if add_timestamp_token and timestamp_token_ids_per_batch is not None:
                        ts_ids = _timestamp_ids(timestamp_token_ids_per_batch, b, w)
                        assert ts_ids.numel() > 0, "Timestamp token ids should not be empty."
                        dst = _write(b, dst, int(ts_ids.numel()), ts_ids)
                    batched_obj_rows.append(rows_w)
                    batched_obj_pos.append(torch.arange(dst, dst + R, dtype=torch.long, device=_CPU))
                    batched_obj_bids.append(b)
                    dst = _write(b, dst, R, vt_id)

            if ve_id is not None:
                dst = _write(b, dst, 1, int(ve_id))
            if obj_traj_end_id is not None:
                dst = _write(b, dst, 1, int(obj_traj_end_id))

        src_lo = v_e + 1
        dst = _write(
            b, dst, L_eff - src_lo, ids_b[src_lo:],
            label_src=None if labs_b is None else labs_b[src_lo:], mask_src=msk_b[src_lo:],
        )
        assert dst == L_new_each[b], f"sample {b}: dst={dst}, L_new={L_new_each[b]}"

    # ---- (4) move to device, embed, paste resampled visual features ----
    new_input_ids = new_input_ids_cpu.to(dev, non_blocking=True)
    new_position_ids = new_position_ids_cpu.to(dev, non_blocking=True)
    new_attention_mask = new_attention_mask_cpu.to(dev, non_blocking=True)
    new_labels = None if new_labels_cpu is None else new_labels_cpu.to(dev, non_blocking=True)

    new_inputs_embeds = tok_embed(new_input_ids).clone()

    if batched_second_rows:
        if video_embeds is None:
            raise RuntimeError("use_second_resampler=True but video_embeds is None.")
        y2 = _run_resampler(second_resampler, video_embeds, batched_second_rows)
        _paste_latents(new_inputs_embeds, y2, batched_second_bids, batched_second_pos)

    if batched_obj_rows:
        if video_embeds is None:
            raise RuntimeError("use_resampler=True but video_embeds is None.")
        y = _run_resampler(resampler, video_embeds, batched_obj_rows)
        _paste_latents(new_inputs_embeds, y, batched_obj_bids, batched_obj_pos)

    # ---- (5) rope_deltas / cache_position ----
    seq_len = new_inputs_embeds.shape[1]
    if seq_len == 0:
        # Reducing over an empty sequence dimension would raise; no positions exist.
        rope_deltas = torch.zeros((B, 1), dtype=torch.long, device=dev)
    else:
        maxpos = new_position_ids.max(dim=0)[0].max(dim=1, keepdim=True)[0]
        rope_deltas = (maxpos + 1 - seq_len).to(dtype=torch.long, device=dev)
    cache_position = torch.arange(seq_len, device=dev, dtype=torch.int32)

    return new_inputs_embeds, new_position_ids, new_attention_mask, rope_deltas, cache_position, new_input_ids, new_labels