@@ -578,14 +578,23 @@ def wrap_ddp(model):
578578 list_embedding .append (head_embedding )
579579
580580 elif dataset_config .dali_type in ["decord_residual" ]:
581- head_input = list_data_batch [head_id ]["videos" ]
581+ head_input = list_data_batch [head_id ]["videos" ] # [bs, C, 64, H, W]
582582 list_batch_sizes .append (head_input .size (0 ))
583- visible_indices = list_data_batch [head_id ]["video_visible_indices" ] # [bs, ?],需要至少 args.total_indices 合法列
583+ visible_indices = list_data_batch [head_id ]["video_visible_indices" ] # [bs, ?]
584584 visible_indices = visible_indices .long ()
585585
586586 bs = visible_indices .shape [0 ]
587587 dev = visible_indices .device
588588
589+ # Get patch_size from backbone config
590+ backbone_module = unwrap_module (backbone )
591+ if hasattr (backbone_module , 'config' ):
592+ patch_size = backbone_module .config .patch_size
593+ elif hasattr (backbone_module , 'embeddings' ) and hasattr (backbone_module .embeddings , 'patch_size' ):
594+ patch_size = backbone_module .embeddings .patch_size
595+ else :
596+ patch_size = 16 # default fallback
597+
589598 out = visible_indices [:, :args .target_num ].clone ()
590599 n1 = int (bs * 0.5 )
591600 n2 = int (bs * 0.875 )
@@ -595,61 +604,17 @@ def wrap_ddp(model):
595604 mask_frame_sampling = (idx_range >= n1 ) & (idx_range < n2 ) # idx in [n1, n2)
596605 mask_collage = idx_range >= n2 # idx in [n2, bs)
597606
598- # Storage for selected frames per sample
599- SEQ = 8
600- FRAMES = 64
601- selected_frames_all = torch .zeros (bs , SEQ , device = dev , dtype = torch .long )
602-
607+ # For mask_residual: directly select first args.target_num patches
603608 if mask_residual .any ():
604- vis_a = visible_indices [mask_residual , :args .total_indices ]
605- must = vis_a [:, :args .must_num ]
606- candidates = vis_a [:, args .must_num :args .total_indices ]
607- k = max (0 , args .target_num - args .must_num )
608- k = min (k , candidates .size (1 ))
609- if k > 0 :
610- scores = torch .rand (vis_a .size (0 ), candidates .size (1 ), device = dev )
611- idx = scores .topk (k , dim = 1 ).indices
612- sampled = torch .gather (candidates , 1 , idx )
613- sel_a = torch .cat ([must , sampled ], dim = 1 )
614- else :
615- sel_a = must
609+ sel_a = visible_indices [mask_residual , :args .target_num ]
616610 if sel_a .size (1 ) < args .target_num :
617611 pad = sel_a [:, - 1 :].repeat (1 , args .target_num - sel_a .size (1 ))
618612 sel_a = torch .cat ([sel_a , pad ], dim = 1 )
619613 out [mask_residual ] = sel_a
620614
621- # Extract frame indices from sel_a for residual branch
622- # sel_a contains patch indices, convert to frame indices
623- nA = sel_a .size (0 )
624- # Get frame index for each patch
625- frame_indices_per_patch = sel_a // args .num_tokens_per_frame # [nA, target_num]
626-
627- # For each sample, get unique sorted frame indices
628- # We need exactly SEQ frames
629- # Vectorized approach: sort and take unique by checking consecutive differences
630- sorted_frames , _ = torch .sort (frame_indices_per_patch , dim = 1 ) # [nA, target_num]
631-
632- # Find where consecutive values differ (mark the first of each unique value)
633- diff_mask = torch .cat ([
634- torch .ones (nA , 1 , device = dev , dtype = torch .bool ),
635- sorted_frames [:, 1 :] != sorted_frames [:, :- 1 ]
636- ], dim = 1 ) # [nA, target_num]
637-
638- # For each sample, collect the first SEQ unique frames
639- residual_frames = torch .zeros (nA , SEQ , device = dev , dtype = torch .long )
640- for i in range (nA ):
641- unique_positions = diff_mask [i ].nonzero (as_tuple = True )[0 ]
642- unique_values = sorted_frames [i , unique_positions ]
643- num_unique = unique_values .numel ()
644- if num_unique >= SEQ :
645- residual_frames [i ] = unique_values [:SEQ ]
646- elif num_unique > 0 :
647- residual_frames [i , :num_unique ] = unique_values
648- residual_frames [i , num_unique :] = unique_values [- 1 ]
649- # If num_unique == 0, residual_frames[i] stays at zeros (frame 0)
650- selected_frames_all [mask_residual ] = residual_frames
651-
652-
615+ # For mask_frame_sampling: compute patch indices based on frame sampling
616+ SEQ = 8
617+ FRAMES = 64
653618 if mask_frame_sampling .any ():
654619 nB = visible_indices [mask_frame_sampling ].size (0 )
655620 avg = FRAMES // SEQ
@@ -669,35 +634,69 @@ def wrap_ddp(model):
669634 pad = sel_b [:, - 1 :].repeat (1 , args .target_num - sel_b .size (1 ))
670635 out [mask_frame_sampling ] = torch .cat ([sel_b , pad ], dim = 1 )
671636
672- # Store frame indices for frame_sampling branch
673- selected_frames_all [mask_frame_sampling ] = frames
674-
675-
676637 combined_mask = mask_residual | mask_frame_sampling
677638 if combined_mask .any ():
678639 combined_idx = torch .nonzero (combined_mask , as_tuple = False ).squeeze (1 )
679- combined_head_input_full = head_input [combined_idx ] # [n, C, 64, H, W]
680- combined_out = out [combined_idx ]
681- combined_frames = selected_frames_all [combined_idx ] # [n, SEQ]
682-
683- # Select frames from the video based on selected_frames
684- # combined_head_input_full: [n, C, 64, H, W]
685- # combined_frames: [n, SEQ]
686- nComb = combined_head_input_full .size (0 )
687- Cf = combined_head_input_full .size (1 )
688- Hf = combined_head_input_full .size (3 )
689- Wf = combined_head_input_full .size (4 )
690-
691- # Expand frame indices for gather: [n, 1, SEQ, 1, 1] -> [n, C, SEQ, H, W]
692- frame_idx_expand = combined_frames .view (nComb , 1 , SEQ , 1 , 1 ).expand (- 1 , Cf , - 1 , Hf , Wf )
693- combined_head_input = torch .gather (combined_head_input_full , 2 , frame_idx_expand ) # [n, C, SEQ, H, W]
640+ combined_video = head_input [combined_idx ] # [n, C, 64, H, W]
641+ combined_out = out [combined_idx ] # [n, target_num]
642+
643+ n_comb , C_vid , T_vid , H_vid , W_vid = combined_video .shape
644+ Hp = H_vid // patch_size # patches per row
645+ Wp = W_vid // patch_size # patches per col
646+ patches_per_frame = Hp * Wp
647+ total_patches = T_vid * patches_per_frame
648+
649+ # Convert video to patches: [n, C, T*Hp*Wp, patch_size, patch_size]
650+ # First reshape to [n, C, T, Hp, patch_size, Wp, patch_size]
651+ video_reshaped = combined_video .view (n_comb , C_vid , T_vid , Hp , patch_size , Wp , patch_size )
652+ # Permute to [n, C, T, Hp, Wp, patch_size, patch_size]
653+ video_reshaped = video_reshaped .permute (0 , 1 , 2 , 3 , 5 , 4 , 6 )
654+ # Reshape to [n, C, T*Hp*Wp, patch_size, patch_size]
655+ video_patches = video_reshaped .reshape (n_comb , C_vid , total_patches , patch_size , patch_size )
656+
657+ # Select patches using combined_out (visible_indices): [n, target_num]
658+ # Expand combined_out for gathering: [n, target_num, 1, 1] -> [n, C, target_num, patch_size, patch_size]
659+ idx_expand = combined_out .unsqueeze (1 ).unsqueeze (- 1 ).unsqueeze (- 1 ) # [n, 1, target_num, 1, 1]
660+ idx_expand = idx_expand .expand (- 1 , C_vid , - 1 , patch_size , patch_size ) # [n, C, target_num, patch_size, patch_size]
661+ selected_patches = torch .gather (video_patches , 2 , idx_expand ) # [n, C, target_num, patch_size, patch_size]
662+
663+ # Reshape selected patches back to video format [n, C, T', H', W']
664+ # We have target_num patches, need to figure out T', Hp', Wp'
665+ # For simplicity: T' = target_num // patches_per_frame, and use original Hp, Wp
666+ T_new = args .target_num // patches_per_frame
667+ expected_patches = T_new * patches_per_frame
668+
669+ # Handle case when target_num is not divisible by patches_per_frame
670+ if expected_patches != args .target_num :
671+ T_new = max (1 , T_new )
672+ expected_patches = T_new * patches_per_frame
673+ # Truncate or pad selected_patches to match expected_patches
674+ if args .target_num > expected_patches :
675+ selected_patches = selected_patches [:, :, :expected_patches , :, :]
676+ else :
677+ # Pad with the last patch repeated
678+ pad_size = expected_patches - args .target_num
679+ pad_patches = selected_patches [:, :, - 1 :, :, :].repeat (1 , 1 , pad_size , 1 , 1 )
680+ selected_patches = torch .cat ([selected_patches , pad_patches ], dim = 2 )
681+
682+ # Reshape: [n, C, expected_patches, patch_size, patch_size] -> [n, C, T_new, Hp, Wp, patch_size, patch_size]
683+ # Then -> [n, C, T_new, Hp*patch_size, Wp*patch_size] = [n, C, T_new, H, W]
684+ H_new = Hp * patch_size
685+ W_new = Wp * patch_size
686+
687+ # First reshape to [n, C, T_new, Hp, Wp, patch_size, patch_size]
688+ selected_reshaped = selected_patches .view (n_comb , C_vid , T_new , Hp , Wp , patch_size , patch_size )
689+ # Permute to [n, C, T_new, Hp, patch_size, Wp, patch_size]
690+ selected_reshaped = selected_reshaped .permute (0 , 1 , 2 , 3 , 5 , 4 , 6 )
691+ # Reshape to [n, C, T_new, H_new, W_new]
692+ combined_head_input = selected_reshaped .reshape (n_comb , C_vid , T_new , H_new , W_new )
694693
695694 with torch .amp .autocast (dtype = torch .bfloat16 , device_type = "cuda" ):
696695 combined_head_output = backbone_ddp_compiled (combined_head_input , combined_out )
697696 if hasattr (combined_head_output , "pooler_output" ):
698697 combined_head_output = combined_head_output .pooler_output
699698 else :
700- combined_head_output = combined_head_output ["head_output" ]
699+ combined_head_output = combined_head_output ["head_output" ]
701700
702701 combined_head_output = combined_head_output .float ()
703702
0 commit comments