@@ -580,113 +580,119 @@ def wrap_ddp(model):
             list_embedding.append(head_embedding)

         elif dataset_config.dali_type in ["decord_residual"]:
-            head_input = list_data_batch[head_id]["videos"]  # [bs, C, 64, H, W]
+            # Example: bs=16, target_num=2048 (8*256), num_tokens_per_frame=256, H=W=224, patch_size=16
+            # Hp=Wp=14, patches_per_frame=196, T=64, total_patches=12544
+
+            head_input = list_data_batch[head_id]["videos"]  # [16, 3, 64, 224, 224]
             list_batch_sizes.append(head_input.size(0))
-            visible_indices = list_data_batch[head_id]["video_visible_indices"].long()
+            visible_indices = list_data_batch[head_id]["video_visible_indices"].long()  # [16, >=2048]

-            bs = visible_indices.shape[0]
-            out = visible_indices[:, :args.target_num].clone()
-            n1, n2 = int(bs * 0.5), int(bs * 0.875)
+            bs = visible_indices.shape[0]  # 16
+            out = visible_indices[:, :args.target_num].clone()  # [16, 2048]
+            n1, n2 = int(bs * 0.5), int(bs * 0.875)  # n1=8, n2=14

-            idx_range = torch.arange(bs).cuda()
-            mask_residual = idx_range < n1
-            mask_frame_sampling = (idx_range >= n1) & (idx_range < n2)
-            mask_collage = idx_range >= n2
+            idx_range = torch.arange(bs).cuda()  # [16]
+            mask_residual = idx_range < n1  # [16], first 8 samples are True
+            mask_frame_sampling = (idx_range >= n1) & (idx_range < n2)  # [16], samples 8-13 are True
+            mask_collage = idx_range >= n2  # [16], samples 14-15 are True

             # mask_residual: select first args.target_num patches
             if mask_residual.any():
-                sel_a = visible_indices[mask_residual, :args.target_num]
+                sel_a = visible_indices[mask_residual, :args.target_num]  # [8, 2048]
                 if sel_a.size(1) < args.target_num:
-                    sel_a = torch.cat([sel_a, sel_a[:, -1:].expand(-1, args.target_num - sel_a.size(1))], dim=1)
-                out[mask_residual] = sel_a
+                    sel_a = torch.cat([sel_a, sel_a[:, -1:].expand(-1, args.target_num - sel_a.size(1))], dim=1)  # [8, 2048]
+                out[mask_residual] = sel_a  # out[0:8] = sel_a

             # mask_frame_sampling: sample 8 frames from 64, get all patches per frame
             FRAMES = 64
             if mask_frame_sampling.any():
-                nB = mask_frame_sampling.sum().item()
-                frames = torch.arange(args.actual_num_frames).cuda() * (FRAMES // args.actual_num_frames) + torch.randint(FRAMES // args.actual_num_frames, (nB, args.actual_num_frames)).cuda()
-                sel_b = (frames.unsqueeze(-1) * args.num_tokens_per_frame + torch.arange(args.num_tokens_per_frame).cuda()).reshape(nB, -1)
+                nB = mask_frame_sampling.sum().item()  # 6
+                # frames: sample 1 frame from each of 8 bins (each bin has 8 frames)
+                frames = torch.arange(args.actual_num_frames).cuda() * (FRAMES // args.actual_num_frames) + torch.randint(FRAMES // args.actual_num_frames, (nB, args.actual_num_frames)).cuda()  # [6, 8], values in [0,7], [8,15], ..., [56,63]
+                # sel_b: for each frame, get all 256 patches
+                sel_b = (frames.unsqueeze(-1) * args.num_tokens_per_frame + torch.arange(args.num_tokens_per_frame).cuda()).reshape(nB, -1)  # [6, 8*256] = [6, 2048]
                 if sel_b.size(1) > args.target_num:
-                    sel_b = sel_b[:, :args.target_num]
+                    sel_b = sel_b[:, :args.target_num]  # [6, 2048]
                 elif sel_b.size(1) < args.target_num:
-                    sel_b = torch.cat([sel_b, sel_b[:, -1:].expand(-1, args.target_num - sel_b.size(1))], dim=1)
-                out[mask_frame_sampling] = sel_b
+                    sel_b = torch.cat([sel_b, sel_b[:, -1:].expand(-1, args.target_num - sel_b.size(1))], dim=1)  # [6, 2048]
+                out[mask_frame_sampling] = sel_b  # out[8:14] = sel_b

-            combined_mask = mask_residual | mask_frame_sampling
+            combined_mask = mask_residual | mask_frame_sampling  # [16], first 14 samples are True
             if combined_mask.any():
-                combined_idx = combined_mask.nonzero(as_tuple=False).squeeze(1)
-                video = head_input[combined_idx]  # [n, C, T, H, W]
-                vis_idx = out[combined_idx]  # [n, target_num]
+                combined_idx = combined_mask.nonzero(as_tuple=False).squeeze(1)  # [14]
+                video = head_input[combined_idx]  # [14, 3, 64, 224, 224]
+                vis_idx = out[combined_idx]  # [14, 2048]

-                n, C, T, H, W = video.shape
-                Hp, Wp = H // patch_size, W // patch_size
+                n, C, T, H, W = video.shape  # n=14, C=3, T=64, H=224, W=224
+                Hp, Wp = H // patch_size, W // patch_size  # Hp=14, Wp=14

                 # Patchify: [n, C, T, H, W] -> [n, C, T*Hp*Wp, p, p]
-                patches = video.view(n, C, T, Hp, patch_size, Wp, patch_size).permute(0, 1, 2, 3, 5, 4, 6).reshape(n, C, T * Hp * Wp, patch_size, patch_size)
+                # [14, 3, 64, 224, 224] -> [14, 3, 64, 14, 16, 14, 16] -> [14, 3, 64, 14, 14, 16, 16] -> [14, 3, 12544, 16, 16]
+                patches = video.view(n, C, T, Hp, patch_size, Wp, patch_size).permute(0, 1, 2, 3, 5, 4, 6).reshape(n, C, T * Hp * Wp, patch_size, patch_size)  # [14, 3, 12544, 16, 16]

                 # Select patches by vis_idx
-                idx = vis_idx[:, None, :, None, None].expand(-1, C, -1, patch_size, patch_size)
-                selected = torch.gather(patches, 2, idx)  # [n, C, target_num, p, p]
+                idx = vis_idx[:, None, :, None, None].expand(-1, C, -1, patch_size, patch_size)  # [14, 3, 2048, 16, 16]
+                selected = torch.gather(patches, 2, idx)  # [14, 3, 2048, 16, 16]

                 # Unpatchify: [n, C, target_num, p, p] -> [n, C, T', H, W]
-                T_new = args.target_num // (Hp * Wp)
+                T_new = args.target_num // (Hp * Wp)  # 2048 // 196 = 10
                 if T_new == 0:
                     T_new = 1
-                num_patches = T_new * Hp * Wp
+                num_patches = T_new * Hp * Wp  # 10 * 196 = 1960
                 if selected.size(2) > num_patches:
-                    selected = selected[:, :, :num_patches]
+                    selected = selected[:, :, :num_patches]  # [14, 3, 1960, 16, 16]
                 elif selected.size(2) < num_patches:
-                    selected = torch.cat([selected, selected[:, :, -1:].expand(-1, -1, num_patches - selected.size(2), -1, -1)], dim=2)
-                combined_head_input = selected.view(n, C, T_new, Hp, Wp, patch_size, patch_size).permute(0, 1, 2, 3, 5, 4, 6).reshape(n, C, T_new, H, W)
+                    selected = torch.cat([selected, selected[:, :, -1:].expand(-1, -1, num_patches - selected.size(2), -1, -1)], dim=2)  # [14, 3, 1960, 16, 16]
+                # [14, 3, 1960, 16, 16] -> [14, 3, 10, 14, 14, 16, 16] -> [14, 3, 10, 14, 16, 14, 16] -> [14, 3, 10, 224, 224]
+                combined_head_input = selected.view(n, C, T_new, Hp, Wp, patch_size, patch_size).permute(0, 1, 2, 3, 5, 4, 6).reshape(n, C, T_new, H, W)  # [14, 3, 10, 224, 224]

                 with torch.amp.autocast(dtype=torch.bfloat16, device_type="cuda"):
-                    combined_head_output = backbone_ddp_compiled(combined_head_input, vis_idx)
-                    combined_head_output = (combined_head_output.pooler_output if hasattr(combined_head_output, "pooler_output") else combined_head_output["head_output"]).float()
+                    combined_head_output = backbone_ddp_compiled(combined_head_input, vis_idx)  # input: [14, 3, 10, 224, 224], vis_idx: [14, 2048]
+                    combined_head_output = (combined_head_output.pooler_output if hasattr(combined_head_output, "pooler_output") else combined_head_output["head_output"]).float()  # [14, D]


             if mask_collage.any():
-                coll_idx = torch.nonzero(mask_collage, as_tuple=False).squeeze(1)
-                nC = coll_idx.numel()
-                FRAMES = 64  # assume fixed 64 frames for head_subset
+                coll_idx = torch.nonzero(mask_collage, as_tuple=False).squeeze(1)  # [2]
+                nC = coll_idx.numel()  # 2
+                FRAMES = 64

-                head_subset = head_input[coll_idx]  # [nC, C, 64, H, W] (must hold)
+                head_subset = head_input[coll_idx]  # [2, 3, 64, 224, 224]

-                # check shape
                 if head_subset.dim() != 5 or head_subset.size(2) != FRAMES:
                     raise RuntimeError(
                         f"collage branch expects head_subset shape [nC, C, {FRAMES}, H, W], got {tuple(head_subset.shape)}"
                     )

-                nC = head_subset.size(0)
-                Cf = head_subset.size(1)
-                Hf = head_subset.size(3)
-                Wf = head_subset.size(4)
-                avg = FRAMES // args.actual_num_frames  # 8
-                base = torch.arange(args.actual_num_frames).cuda() * avg
-                offs = torch.randint(avg, (nC, args.actual_num_frames)).cuda()
-                frames_idx = (base.unsqueeze(0) + offs).long().clamp(max=FRAMES - 1)  # [nC, actual_num_frames], values in [0, 63]
-                idx_expand = frames_idx.view(nC, 1, args.actual_num_frames, 1, 1).expand(-1, Cf, -1, Hf, Wf)
-                sel_frames = torch.gather(head_subset, 2, idx_expand)  # [nC, Cf, actual_num_frames, Hf, Wf]
-                sel_frames = sel_frames.permute(0, 2, 1, 3, 4)  # [nC, actual_num_frames, Cf, Hf, Wf]
-                grid_rows = [sel_frames[:, i, :, :, :] for i in range(args.actual_num_frames)]
-                grid = torch.cat(grid_rows, dim=-2)  # [nC, Cf, Hf*args.actual_num_frames, Wf]
+                nC = head_subset.size(0)  # 2
+                Cf = head_subset.size(1)  # 3
+                Hf = head_subset.size(3)  # 224
+                Wf = head_subset.size(4)  # 224
+                avg = FRAMES // args.actual_num_frames  # 64 // 8 = 8
+                base = torch.arange(args.actual_num_frames).cuda() * avg  # [0, 8, 16, 24, 32, 40, 48, 56]
+                offs = torch.randint(avg, (nC, args.actual_num_frames)).cuda()  # [2, 8], values in [0, 7]
+                frames_idx = (base.unsqueeze(0) + offs).long().clamp(max=FRAMES - 1)  # [2, 8], values in [0, 63]
+                idx_expand = frames_idx.view(nC, 1, args.actual_num_frames, 1, 1).expand(-1, Cf, -1, Hf, Wf)  # [2, 3, 8, 224, 224]
+                sel_frames = torch.gather(head_subset, 2, idx_expand)  # [2, 3, 8, 224, 224]
+                sel_frames = sel_frames.permute(0, 2, 1, 3, 4)  # [2, 8, 3, 224, 224]
+                grid_rows = [sel_frames[:, i, :, :, :] for i in range(args.actual_num_frames)]  # 8 x [2, 3, 224, 224]
+                grid = torch.cat(grid_rows, dim=-2)  # [2, 3, 1792, 224] (1792 = 224 * 8)
                 with torch.amp.autocast(dtype=torch.bfloat16, device_type="cuda"):
-                    collage_head_output = backbone_ddp_compiled(grid)
+                    collage_head_output = backbone_ddp_compiled(grid)  # input: [2, 3, 1792, 224]
                 if hasattr(collage_head_output, "pooler_output"):
                     collage_head_output = collage_head_output.pooler_output
                 else:
                     collage_head_output = collage_head_output["head_output"]
-                collage_head_output = collage_head_output.float()
+                collage_head_output = collage_head_output.float()  # [2, D]

-            D = combined_head_output.size(1)
+            D = combined_head_output.size(1)  # embedding dimension

-            head_embedding_full = torch.zeros(bs, D, dtype=torch.float32).cuda()
+            head_embedding_full = torch.zeros(bs, D, dtype=torch.float32).cuda()  # [16, D]
             if combined_mask.any():
-                head_embedding_full[combined_idx] = combined_head_output
+                head_embedding_full[combined_idx] = combined_head_output  # head_embedding_full[0:14] = [14, D]
             if mask_collage.any():
-                head_embedding_full[coll_idx] = collage_head_output
+                head_embedding_full[coll_idx] = collage_head_output  # head_embedding_full[14:16] = [2, D]

-            list_embedding.append(head_embedding_full)
+            list_embedding.append(head_embedding_full)  # [16, D]

         elif dataset_config.dali_type in ["origin", "ocr"]:
             head_input = list_data_batch[head_id]["pixel_values"]
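The core trick in the combined branch of this diff is patchify -> gather-by-index -> unpatchify: cut each 64-frame clip into 16x16 patches, keep only the patches listed in the visible-index tensor, and reassemble them into a shorter clip for the backbone. The standalone sketch below reproduces just that tensor manipulation with toy inputs; the shapes (2 clips of 64 frames at 224x224, 2048 kept patches) and the random visible_indices are illustrative assumptions, not values taken from the training pipeline.

import torch

def select_visible_patches(video, visible_indices, target_num, patch_size=16):
    """Patchify a clip, keep only the patches named by visible_indices,
    and reassemble them into a shorter clip (mirrors the combined branch above)."""
    n, C, T, H, W = video.shape
    Hp, Wp = H // patch_size, W // patch_size

    # Patchify: [n, C, T, H, W] -> [n, C, T*Hp*Wp, p, p]
    patches = (video.view(n, C, T, Hp, patch_size, Wp, patch_size)
                    .permute(0, 1, 2, 3, 5, 4, 6)
                    .reshape(n, C, T * Hp * Wp, patch_size, patch_size))

    # Gather the visible patches along the patch dimension.
    idx = visible_indices[:, None, :, None, None].expand(-1, C, -1, patch_size, patch_size)
    selected = torch.gather(patches, 2, idx)  # [n, C, target_num, p, p]

    # Unpatchify into T_new full frames; truncate or pad so the count divides evenly.
    T_new = max(target_num // (Hp * Wp), 1)
    num_patches = T_new * Hp * Wp
    if selected.size(2) > num_patches:
        selected = selected[:, :, :num_patches]
    elif selected.size(2) < num_patches:
        pad = selected[:, :, -1:].expand(-1, -1, num_patches - selected.size(2), -1, -1)
        selected = torch.cat([selected, pad], dim=2)
    return (selected.view(n, C, T_new, Hp, Wp, patch_size, patch_size)
                    .permute(0, 1, 2, 3, 5, 4, 6)
                    .reshape(n, C, T_new, H, W))

if __name__ == "__main__":
    # Toy inputs (assumed): 2 clips, 64 frames of 224x224, keep 2048 patches per clip.
    video = torch.randn(2, 3, 64, 224, 224)
    total_patches = 64 * (224 // 16) * (224 // 16)        # 12544
    visible_indices = torch.randint(total_patches, (2, 2048))
    clip = select_visible_patches(video, visible_indices, target_num=2048)
    print(clip.shape)  # torch.Size([2, 3, 10, 224, 224])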