@@ -304,14 +304,6 @@ def _expand(name, v):
304304
305305 if args .finetune_backbone :
306306 backbone .requires_grad_ (True )
307- else :
308- backbone .requires_grad_ (False )
309- backbone_module = unwrap_module (backbone )
310- if hasattr (backbone_module , "head" ):
311- for p in backbone_module .head .parameters ():
312- p .requires_grad = True
313- else :
314- raise RuntimeError ()
315307
316308 backbone_parameters = filter (lambda p : p .requires_grad , backbone .parameters ())
317309
@@ -347,7 +339,7 @@ def _expand(name, v):
347339 )
348340
349341 partial_fc .train ().cuda ()
350- # list_module_pfc.append(torch.compile(partial_fc))
342+
351343 list_module_pfc .append (partial_fc )
352344 dict_pfc_modules [head_name ] = partial_fc
353345
@@ -481,6 +473,7 @@ def wrap_ddp(model):
481473 shard_id = dataset_config .shard_id ,
482474 num_shards = dataset_config .num_shards
483475 )
476+
484477 elif dataset_config .dali_type == "ocr" :
485478 if args .debug :
486479 from dataloader .data_v2_ocr import SyntheticDataIter
@@ -495,8 +488,8 @@ def wrap_ddp(model):
495488 image_size = args .image_size ,
496489 workers = args .workers ,
497490 shard_id = dataset_config .shard_id ,
498- num_shards = dataset_config .num_shards
499- )
491+ num_shards = dataset_config .num_shards )
492+
500493 else :
501494 raise ValueError (
502495 f"dataset_config.dali_type { dataset_config .dali_type } not support!"
@@ -608,10 +601,10 @@ def wrap_ddp(model):
608601 out [mask_residual ] = sel_a
609602
610603 # mask_frame_sampling: sample 8 frames from 64, get all patches per frame
611- SEQ , FRAMES = 8 , 64
604+ FRAMES = 64
612605 if mask_frame_sampling .any ():
613606 nB = mask_frame_sampling .sum ().item ()
614- frames = torch .arange (SEQ , device = dev ) * (FRAMES // SEQ ) + torch .randint (FRAMES // SEQ , (nB , SEQ ), device = dev )
607+ frames = torch .arange (args . actual_num_frames , device = dev ) * (FRAMES // args . actual_num_frames ) + torch .randint (FRAMES // args . actual_num_frames , (nB , args . actual_num_frames ), device = dev )
615608 sel_b = (frames .unsqueeze (- 1 ) * args .num_tokens_per_frame + torch .arange (args .num_tokens_per_frame , device = dev )).reshape (nB , - 1 )
616609 if sel_b .size (1 ) > args .target_num :
617610 sel_b = sel_b [:, :args .target_num ]
@@ -654,7 +647,6 @@ def wrap_ddp(model):
654647 if mask_collage .any ():
655648 coll_idx = torch .nonzero (mask_collage , as_tuple = False ).squeeze (1 )
656649 nC = coll_idx .numel ()
657- SEQ = 8
658650 FRAMES = 64 # assume fixed 64 frames for head_subset
659651
660652 head_subset = head_input [coll_idx ] # [nC, C, 64, H, W] (must hold)
@@ -669,15 +661,15 @@ def wrap_ddp(model):
669661 Cf = head_subset .size (1 )
670662 Hf = head_subset .size (3 )
671663 Wf = head_subset .size (4 )
672- avg = FRAMES // SEQ # 8
673- base = torch .arange (SEQ , device = dev ) * avg
674- offs = torch .randint (avg , (nC , SEQ ), device = dev )
675- frames_idx = (base .unsqueeze (0 ) + offs ).long ().clamp (max = FRAMES - 1 ) # [nC, SEQ ], 范围在 [0, 63]
676- idx_expand = frames_idx .view (nC , 1 , SEQ , 1 , 1 ).expand (- 1 , Cf , - 1 , Hf , Wf ).to (head_subset .device )
677- sel_frames = torch .gather (head_subset , 2 , idx_expand ) # [nC, Cf, SEQ , Hf, Wf]
678- sel_frames = sel_frames .permute (0 , 2 , 1 , 3 , 4 ) # [nC, SEQ , Cf, Hf, Wf]
679- grid_rows = [sel_frames [:, i , :, :, :] for i in range (SEQ )]
680- grid = torch .cat (grid_rows , dim = - 2 ) # [nC, Cf, Hf*SEQ , Wf]
664+ avg = FRAMES // args . actual_num_frames # 8
665+ base = torch .arange (args . actual_num_frames , device = dev ) * avg
666+ offs = torch .randint (avg , (nC , args . actual_num_frames ), device = dev )
667+ frames_idx = (base .unsqueeze (0 ) + offs ).long ().clamp (max = FRAMES - 1 ) # [nC, actual_num_frames ], 范围在 [0, 63]
668+ idx_expand = frames_idx .view (nC , 1 , args . actual_num_frames , 1 , 1 ).expand (- 1 , Cf , - 1 , Hf , Wf ).to (head_subset .device )
669+ sel_frames = torch .gather (head_subset , 2 , idx_expand ) # [nC, Cf, actual_num_frames , Hf, Wf]
670+ sel_frames = sel_frames .permute (0 , 2 , 1 , 3 , 4 ) # [nC, actual_num_frames , Cf, Hf, Wf]
671+ grid_rows = [sel_frames [:, i , :, :, :] for i in range (args . actual_num_frames )]
672+ grid = torch .cat (grid_rows , dim = - 2 ) # [nC, Cf, Hf*args.actual_num_frames , Wf]
681673 with torch .amp .autocast (dtype = torch .bfloat16 , device_type = "cuda" ):
682674 collage_head_output = backbone_ddp_compiled (grid )
683675 if hasattr (collage_head_output , "pooler_output" ):
0 commit comments