Skip to content

Commit bf886f9

Browse files
Copilot authored and anxiangsir committed
Fix ViT interface: select frames before passing to backbone for residual and frame_sampling branches
Co-authored-by: anxiangsir <31175974+anxiangsir@users.noreply.github.com>
1 parent 3a34c74 commit bf886f9

1 file changed

Lines changed: 55 additions & 5 deletions

File tree

training/train.py

Lines changed: 55 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -595,6 +595,11 @@ def wrap_ddp(model):
595595
mask_frame_sampling = (idx_range >= n1) & (idx_range < n2) # idx in [n1, n2)
596596
mask_collage = idx_range >= n2 # idx in [n2, bs)
597597

598+
# Storage for selected frames per sample
599+
SEQ = 8
600+
FRAMES = 64
601+
selected_frames_all = torch.zeros(bs, SEQ, device=dev, dtype=torch.long)
602+
598603
if mask_residual.any():
599604
vis_a = visible_indices[mask_residual, :args.total_indices]
600605
must = vis_a[:, :args.must_num]
@@ -613,18 +618,47 @@ def wrap_ddp(model):
613618
sel_a = torch.cat([sel_a, pad], dim=1)
614619
out[mask_residual] = sel_a
615620

621+
# Extract frame indices from sel_a for residual branch
622+
# sel_a contains patch indices, convert to frame indices
623+
nA = sel_a.size(0)
624+
# Get frame index for each patch
625+
frame_indices_per_patch = sel_a // args.num_tokens_per_frame # [nA, target_num]
626+
627+
# For each sample, get unique sorted frame indices
628+
# We need exactly SEQ frames
629+
# Vectorized approach: sort and take unique by checking consecutive differences
630+
sorted_frames, _ = torch.sort(frame_indices_per_patch, dim=1) # [nA, target_num]
631+
632+
# Find where consecutive values differ (mark the first of each unique value)
633+
diff_mask = torch.cat([
634+
torch.ones(nA, 1, device=dev, dtype=torch.bool),
635+
sorted_frames[:, 1:] != sorted_frames[:, :-1]
636+
], dim=1) # [nA, target_num]
637+
638+
# For each sample, collect the first SEQ unique frames
639+
residual_frames = torch.zeros(nA, SEQ, device=dev, dtype=torch.long)
640+
for i in range(nA):
641+
unique_positions = diff_mask[i].nonzero(as_tuple=True)[0]
642+
unique_values = sorted_frames[i, unique_positions]
643+
num_unique = unique_values.numel()
644+
if num_unique >= SEQ:
645+
residual_frames[i] = unique_values[:SEQ]
646+
else:
647+
residual_frames[i, :num_unique] = unique_values
648+
if num_unique > 0:
649+
residual_frames[i, num_unique:] = unique_values[-1]
650+
selected_frames_all[mask_residual] = residual_frames
651+
616652

617653
if mask_frame_sampling.any():
618654
nB = visible_indices[mask_frame_sampling].size(0)
619-
SEQ = 8
620-
FRAMES = 64
621655
avg = FRAMES // SEQ
622656
base = torch.arange(SEQ, device=dev) * avg
623657
offs = torch.randint(avg, (nB, SEQ), device=dev)
624658
frames = base + offs # [nB, 8]
625659

626-
per = torch.arange(args.must_num, device=dev)
627-
pos = (frames.unsqueeze(-1) * args.must_num + per).reshape(nB, -1) # [nB, 8*args.must_num]
660+
per = torch.arange(args.num_tokens_per_frame, device=dev)
661+
pos = (frames.unsqueeze(-1) * args.num_tokens_per_frame + per).reshape(nB, -1) # [nB, 8*num_tokens_per_frame]
628662
sel_b = pos.to(visible_indices.dtype)
629663

630664
if sel_b.size(1) == args.target_num:
@@ -635,12 +669,28 @@ def wrap_ddp(model):
635669
pad = sel_b[:, -1:].repeat(1, args.target_num - sel_b.size(1))
636670
out[mask_frame_sampling] = torch.cat([sel_b, pad], dim=1)
637671

672+
# Store frame indices for frame_sampling branch
673+
selected_frames_all[mask_frame_sampling] = frames
674+
638675

639676
combined_mask = mask_residual | mask_frame_sampling
640677
if combined_mask.any():
641678
combined_idx = torch.nonzero(combined_mask, as_tuple=False).squeeze(1)
642-
combined_head_input = head_input[combined_idx] # keep as-is (may be [n, C, H, W] or another layout)
679+
combined_head_input_full = head_input[combined_idx] # [n, C, 64, H, W]
643680
combined_out = out[combined_idx]
681+
combined_frames = selected_frames_all[combined_idx] # [n, SEQ]
682+
683+
# Select frames from the video based on selected_frames
684+
# combined_head_input_full: [n, C, 64, H, W]
685+
# combined_frames: [n, SEQ]
686+
nComb = combined_head_input_full.size(0)
687+
Cf = combined_head_input_full.size(1)
688+
Hf = combined_head_input_full.size(3)
689+
Wf = combined_head_input_full.size(4)
690+
691+
# Expand frame indices for gather: [n, 1, SEQ, 1, 1] -> [n, C, SEQ, H, W]
692+
frame_idx_expand = combined_frames.view(nComb, 1, SEQ, 1, 1).expand(-1, Cf, -1, Hf, Wf)
693+
combined_head_input = torch.gather(combined_head_input_full, 2, frame_idx_expand) # [n, C, SEQ, H, W]
644694

645695
with torch.amp.autocast(dtype=torch.bfloat16, device_type="cuda"):
646696
combined_head_output = backbone_ddp_compiled(combined_head_input, combined_out)

0 commit comments

Comments (0)