@@ -595,6 +595,11 @@ def wrap_ddp(model):
595595 mask_frame_sampling = (idx_range >= n1 ) & (idx_range < n2 ) # idx in [n1, n2)
596596 mask_collage = idx_range >= n2 # idx in [n2, bs)
597597
598+ # Storage for selected frames per sample
599+ SEQ = 8
600+ FRAMES = 64
601+ selected_frames_all = torch .zeros (bs , SEQ , device = dev , dtype = torch .long )
602+
598603 if mask_residual .any ():
599604 vis_a = visible_indices [mask_residual , :args .total_indices ]
600605 must = vis_a [:, :args .must_num ]
@@ -613,18 +618,47 @@ def wrap_ddp(model):
613618 sel_a = torch .cat ([sel_a , pad ], dim = 1 )
614619 out [mask_residual ] = sel_a
615620
621+ # Extract frame indices from sel_a for residual branch
622+ # sel_a contains patch indices, convert to frame indices
623+ nA = sel_a .size (0 )
624+ # Get frame index for each patch
625+ frame_indices_per_patch = sel_a // args .num_tokens_per_frame # [nA, target_num]
626+
627+ # For each sample, get unique sorted frame indices
628+ # We need exactly SEQ frames
629+ # Sort each row and mark the first occurrence of every value (this part is vectorized); the unique frames are then collected per sample in the Python loop below
630+ sorted_frames , _ = torch .sort (frame_indices_per_patch , dim = 1 ) # [nA, target_num]
631+
632+ # Find where consecutive values differ (mark the first of each unique value)
633+ diff_mask = torch .cat ([
634+ torch .ones (nA , 1 , device = dev , dtype = torch .bool ),
635+ sorted_frames [:, 1 :] != sorted_frames [:, :- 1 ]
636+ ], dim = 1 ) # [nA, target_num]
637+
638+ # For each sample, keep the SEQ smallest unique frame indices; if fewer than SEQ exist, pad the tail by repeating the largest one
639+ residual_frames = torch .zeros (nA , SEQ , device = dev , dtype = torch .long )
640+ for i in range (nA ):
641+ unique_positions = diff_mask [i ].nonzero (as_tuple = True )[0 ]
642+ unique_values = sorted_frames [i , unique_positions ]
643+ num_unique = unique_values .numel ()
644+ if num_unique >= SEQ :
645+ residual_frames [i ] = unique_values [:SEQ ]
646+ else :
647+ residual_frames [i , :num_unique ] = unique_values
648+ if num_unique > 0 :
649+ residual_frames [i , num_unique :] = unique_values [- 1 ]
650+ selected_frames_all [mask_residual ] = residual_frames
651+
616652
617653 if mask_frame_sampling .any ():
618654 nB = visible_indices [mask_frame_sampling ].size (0 )
619- SEQ = 8
620- FRAMES = 64
621655 avg = FRAMES // SEQ
622656 base = torch .arange (SEQ , device = dev ) * avg
623657 offs = torch .randint (avg , (nB , SEQ ), device = dev )
624658 frames = base + offs # [nB, 8]
625659
626- per = torch .arange (args .must_num , device = dev )
627- pos = (frames .unsqueeze (- 1 ) * args .must_num + per ).reshape (nB , - 1 ) # [nB, 8*args.must_num ]
660+ per = torch .arange (args .num_tokens_per_frame , device = dev )
661+ pos = (frames .unsqueeze (- 1 ) * args .num_tokens_per_frame + per ).reshape (nB , - 1 ) # [nB, 8*num_tokens_per_frame ]
628662 sel_b = pos .to (visible_indices .dtype )
629663
630664 if sel_b .size (1 ) == args .target_num :
@@ -635,12 +669,28 @@ def wrap_ddp(model):
635669 pad = sel_b [:, - 1 :].repeat (1 , args .target_num - sel_b .size (1 ))
636670 out [mask_frame_sampling ] = torch .cat ([sel_b , pad ], dim = 1 )
637671
672+ # Store frame indices for frame_sampling branch
673+ selected_frames_all [mask_frame_sampling ] = frames
674+
638675
639676 combined_mask = mask_residual | mask_frame_sampling
640677 if combined_mask .any ():
641678 combined_idx = torch .nonzero (combined_mask , as_tuple = False ).squeeze (1 )
642- combined_head_input = head_input [combined_idx ] # 保持原样(可能为 [n, C, H, W] 或其他)
679+ combined_head_input_full = head_input [combined_idx ] # [n, C, 64, H, W]
643680 combined_out = out [combined_idx ]
681+ combined_frames = selected_frames_all [combined_idx ] # [n, SEQ]
682+
683+ # Select frames from the video based on selected_frames
684+ # combined_head_input_full: [n, C, 64, H, W]
685+ # combined_frames: [n, SEQ]
686+ nComb = combined_head_input_full .size (0 )
687+ Cf = combined_head_input_full .size (1 )
688+ Hf = combined_head_input_full .size (3 )
689+ Wf = combined_head_input_full .size (4 )
690+
691+ # Expand frame indices for gather: [n, 1, SEQ, 1, 1] -> [n, C, SEQ, H, W]
692+ frame_idx_expand = combined_frames .view (nComb , 1 , SEQ , 1 , 1 ).expand (- 1 , Cf , - 1 , Hf , Wf )
693+ combined_head_input = torch .gather (combined_head_input_full , 2 , frame_idx_expand ) # [n, C, SEQ, H, W]
644694
645695 with torch .amp .autocast (dtype = torch .bfloat16 , device_type = "cuda" ):
646696 combined_head_output = backbone_ddp_compiled (combined_head_input , combined_out )
0 commit comments