@@ -589,116 +589,66 @@ def wrap_ddp(model):
589589 elif dataset_config .dali_type in ["decord_residual" ]:
590590 head_input = list_data_batch [head_id ]["videos" ] # [bs, C, 64, H, W]
591591 list_batch_sizes .append (head_input .size (0 ))
592- visible_indices = list_data_batch [head_id ]["video_visible_indices" ] # [bs, ?]
593- visible_indices = visible_indices .long ()
594-
595- bs = visible_indices .shape [0 ]
596- dev = visible_indices .device
592+ visible_indices = list_data_batch [head_id ]["video_visible_indices" ].long ()
597593
594+ bs , dev = visible_indices .shape [0 ], visible_indices .device
598595 out = visible_indices [:, :args .target_num ].clone ()
599- n1 = int (bs * 0.5 )
600- n2 = int (bs * 0.875 )
596+ n1 , n2 = int (bs * 0.5 ), int (bs * 0.875 )
601597
602598 idx_range = torch .arange (bs , device = dev )
603- mask_residual = idx_range < n1 # idx in [0, n1)
604- mask_frame_sampling = (idx_range >= n1 ) & (idx_range < n2 ) # idx in [n1, n2)
605- mask_collage = idx_range >= n2 # idx in [n2, bs)
599+ mask_residual = idx_range < n1
600+ mask_frame_sampling = (idx_range >= n1 ) & (idx_range < n2 )
601+ mask_collage = idx_range >= n2
606602
607- # For mask_residual: directly select first args.target_num patches
603+ # mask_residual: select first args.target_num patches
608604 if mask_residual .any ():
609605 sel_a = visible_indices [mask_residual , :args .target_num ]
610606 if sel_a .size (1 ) < args .target_num :
611- pad = sel_a [:, - 1 :].repeat (1 , args .target_num - sel_a .size (1 ))
612- sel_a = torch .cat ([sel_a , pad ], dim = 1 )
607+ sel_a = torch .cat ([sel_a , sel_a [:, - 1 :].expand (- 1 , args .target_num - sel_a .size (1 ))], dim = 1 )
613608 out [mask_residual ] = sel_a
614609
615- # For mask_frame_sampling: compute patch indices based on frame sampling
616- SEQ = 8
617- FRAMES = 64
610+ # mask_frame_sampling: sample 8 frames from 64, get all patches per frame
611+ SEQ , FRAMES = 8 , 64
618612 if mask_frame_sampling .any ():
619- nB = visible_indices [mask_frame_sampling ].size (0 )
620- avg = FRAMES // SEQ
621- base = torch .arange (SEQ , device = dev ) * avg
622- offs = torch .randint (avg , (nB , SEQ ), device = dev )
623- frames = base + offs # [nB, 8]
624-
625- per = torch .arange (args .num_tokens_per_frame , device = dev )
626- pos = (frames .unsqueeze (- 1 ) * args .num_tokens_per_frame + per ).reshape (nB , - 1 ) # [nB, 8*num_tokens_per_frame]
627- sel_b = pos .to (visible_indices .dtype )
628-
629- if sel_b .size (1 ) == args .target_num :
630- out [mask_frame_sampling ] = sel_b
631- elif sel_b .size (1 ) > args .target_num :
632- out [mask_frame_sampling ] = sel_b [:, :args .target_num ]
633- else :
634- pad = sel_b [:, - 1 :].repeat (1 , args .target_num - sel_b .size (1 ))
635- out [mask_frame_sampling ] = torch .cat ([sel_b , pad ], dim = 1 )
613+ nB = mask_frame_sampling .sum ().item ()
614+ frames = torch .arange (SEQ , device = dev ) * (FRAMES // SEQ ) + torch .randint (FRAMES // SEQ , (nB , SEQ ), device = dev )
615+ sel_b = (frames .unsqueeze (- 1 ) * args .num_tokens_per_frame + torch .arange (args .num_tokens_per_frame , device = dev )).reshape (nB , - 1 )
616+ if sel_b .size (1 ) > args .target_num :
617+ sel_b = sel_b [:, :args .target_num ]
618+ elif sel_b .size (1 ) < args .target_num :
619+ sel_b = torch .cat ([sel_b , sel_b [:, - 1 :].expand (- 1 , args .target_num - sel_b .size (1 ))], dim = 1 )
620+ out [mask_frame_sampling ] = sel_b
636621
637622 combined_mask = mask_residual | mask_frame_sampling
638623 if combined_mask .any ():
639- combined_idx = torch .nonzero (combined_mask , as_tuple = False ).squeeze (1 )
640- combined_video = head_input [combined_idx ] # [n, C, 64, H, W]
641- combined_out = out [combined_idx ] # [n, target_num]
642-
643- n_comb , C_vid , T_vid , H_vid , W_vid = combined_video .shape
644- Hp = H_vid // patch_size # patches per row
645- Wp = W_vid // patch_size # patches per col
646- patches_per_frame = Hp * Wp
647- total_patches = T_vid * patches_per_frame
648-
649- # Convert video to patches: [n, C, T*Hp*Wp, patch_size, patch_size]
650- # First reshape to [n, C, T, Hp, patch_size, Wp, patch_size]
651- video_reshaped = combined_video .view (n_comb , C_vid , T_vid , Hp , patch_size , Wp , patch_size )
652- # Permute to [n, C, T, Hp, Wp, patch_size, patch_size]
653- video_reshaped = video_reshaped .permute (0 , 1 , 2 , 3 , 5 , 4 , 6 )
654- # Reshape to [n, C, T*Hp*Wp, patch_size, patch_size]
655- video_patches = video_reshaped .reshape (n_comb , C_vid , total_patches , patch_size , patch_size )
656-
657- # Select patches using combined_out (visible_indices): [n, target_num]
658- # Expand combined_out for gathering: [n, target_num, 1, 1] -> [n, C, target_num, patch_size, patch_size]
659- idx_expand = combined_out .unsqueeze (1 ).unsqueeze (- 1 ).unsqueeze (- 1 ) # [n, 1, target_num, 1, 1]
660- idx_expand = idx_expand .expand (- 1 , C_vid , - 1 , patch_size , patch_size ) # [n, C, target_num, patch_size, patch_size]
661- selected_patches = torch .gather (video_patches , 2 , idx_expand ) # [n, C, target_num, patch_size, patch_size]
662-
663- # Reshape selected patches back to video format [n, C, T', H', W']
664- # We have target_num patches, need to figure out T', Hp', Wp'
665- # For simplicity: T' = target_num // patches_per_frame, and use original Hp, Wp
666- T_new = args .target_num // patches_per_frame
667- expected_patches = T_new * patches_per_frame
668-
669- # Handle case when target_num is not divisible by patches_per_frame
670- if expected_patches != args .target_num :
671- T_new = max (1 , T_new )
672- expected_patches = T_new * patches_per_frame
673- # Truncate or pad selected_patches to match expected_patches
674- if args .target_num > expected_patches :
675- selected_patches = selected_patches [:, :, :expected_patches , :, :]
676- else :
677- # Pad with the last patch repeated
678- pad_size = expected_patches - args .target_num
679- pad_patches = selected_patches [:, :, - 1 :, :, :].repeat (1 , 1 , pad_size , 1 , 1 )
680- selected_patches = torch .cat ([selected_patches , pad_patches ], dim = 2 )
681-
682- # Reshape: [n, C, expected_patches, patch_size, patch_size] -> [n, C, T_new, Hp, Wp, patch_size, patch_size]
683- # Then -> [n, C, T_new, Hp*patch_size, Wp*patch_size] = [n, C, T_new, H, W]
684- H_new = Hp * patch_size
685- W_new = Wp * patch_size
686-
687- # First reshape to [n, C, T_new, Hp, Wp, patch_size, patch_size]
688- selected_reshaped = selected_patches .view (n_comb , C_vid , T_new , Hp , Wp , patch_size , patch_size )
689- # Permute to [n, C, T_new, Hp, patch_size, Wp, patch_size]
690- selected_reshaped = selected_reshaped .permute (0 , 1 , 2 , 3 , 5 , 4 , 6 )
691- # Reshape to [n, C, T_new, H_new, W_new]
692- combined_head_input = selected_reshaped .reshape (n_comb , C_vid , T_new , H_new , W_new )
624+ combined_idx = combined_mask .nonzero (as_tuple = False ).squeeze (1 )
625+ video = head_input [combined_idx ] # [n, C, T, H, W]
626+ vis_idx = out [combined_idx ] # [n, target_num]
627+
628+ n , C , T , H , W = video .shape
629+ Hp , Wp = H // patch_size , W // patch_size
630+
631+ # Patchify: [n, C, T, H, W] -> [n, C, T*Hp*Wp, p, p]
632+ patches = video .view (n , C , T , Hp , patch_size , Wp , patch_size ).permute (0 , 1 , 2 , 3 , 5 , 4 , 6 ).reshape (n , C , T * Hp * Wp , patch_size , patch_size )
633+
634+ # Select patches by vis_idx
635+ idx = vis_idx [:, None , :, None , None ].expand (- 1 , C , - 1 , patch_size , patch_size )
636+ selected = torch .gather (patches , 2 , idx ) # [n, C, target_num, p, p]
637+
638+ # Unpatchify: [n, C, target_num, p, p] -> [n, C, T', H, W]
639+ T_new = args .target_num // (Hp * Wp )
640+ if T_new == 0 :
641+ T_new = 1
642+ num_patches = T_new * Hp * Wp
643+ if selected .size (2 ) > num_patches :
644+ selected = selected [:, :, :num_patches ]
645+ elif selected .size (2 ) < num_patches :
646+ selected = torch .cat ([selected , selected [:, :, - 1 :].expand (- 1 , - 1 , num_patches - selected .size (2 ), - 1 , - 1 )], dim = 2 )
647+ combined_head_input = selected .view (n , C , T_new , Hp , Wp , patch_size , patch_size ).permute (0 , 1 , 2 , 3 , 5 , 4 , 6 ).reshape (n , C , T_new , H , W )
693648
694649 with torch .amp .autocast (dtype = torch .bfloat16 , device_type = "cuda" ):
695- combined_head_output = backbone_ddp_compiled (combined_head_input , combined_out )
696- if hasattr (combined_head_output , "pooler_output" ):
697- combined_head_output = combined_head_output .pooler_output
698- else :
699- combined_head_output = combined_head_output ["head_output" ]
700-
701- combined_head_output = combined_head_output .float ()
650+ combined_head_output = backbone_ddp_compiled (combined_head_input , vis_idx )
651+ combined_head_output = (combined_head_output .pooler_output if hasattr (combined_head_output , "pooler_output" ) else combined_head_output ["head_output" ]).float ()
702652
703653
704654 if mask_collage .any ():
0 commit comments