@@ -413,6 +413,15 @@ def wrap_ddp(model):
 backbone_ddp = wrap_ddp(backbone)
 backbone_ddp_compiled = torch.compile(backbone_ddp)

+# Get patch_size from backbone config (outside of training loop for efficiency)
+backbone_module = unwrap_module(backbone)
+if hasattr(backbone_module, 'config'):
+    patch_size = backbone_module.config.patch_size
+elif hasattr(backbone_module, 'embeddings') and hasattr(backbone_module.embeddings, 'patch_size'):
+    patch_size = backbone_module.embeddings.patch_size
+else:
+    patch_size = 16  # default fallback
+
 list_dali_dataloader = []
 list_head_names = []
 for head_id, dataset_config in enumerate(args.list_datasets):
@@ -586,15 +595,6 @@ def wrap_ddp(model):
 bs = visible_indices.shape[0]
 dev = visible_indices.device

-# Get patch_size from backbone config
-backbone_module = unwrap_module(backbone)
-if hasattr(backbone_module, 'config'):
-    patch_size = backbone_module.config.patch_size
-elif hasattr(backbone_module, 'embeddings') and hasattr(backbone_module.embeddings, 'patch_size'):
-    patch_size = backbone_module.embeddings.patch_size
-else:
-    patch_size = 16  # default fallback
-
 out = visible_indices[:, :args.target_num].clone()
 n1 = int(bs * 0.5)
 n2 = int(bs * 0.875)
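For readers skimming the diff: the moved block is equivalent to a small helper that probes the backbone once, before the training loop, instead of on every batch. A minimal sketch of that logic is below. Note that `unwrap_module` itself is not shown in this diff; the version here is an assumption that it peels DDP's `.module` and `torch.compile`'s `._orig_mod` wrappers, and `get_patch_size` is a hypothetical name for the hoisted block, not a function in this PR.

```python
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel


def unwrap_module(model: nn.Module) -> nn.Module:
    # Assumed behavior: strip common wrappers until the raw backbone is reached.
    while True:
        if isinstance(model, DistributedDataParallel):
            model = model.module        # DDP stores the wrapped model here
        elif hasattr(model, '_orig_mod'):
            model = model._orig_mod     # torch.compile's OptimizedModule wrapper
        else:
            return model


def get_patch_size(backbone: nn.Module, default: int = 16) -> int:
    # Mirrors the hoisted block: prefer an explicit config attribute
    # (e.g. Hugging Face ViT models), then the embeddings layer, then a default.
    module = unwrap_module(backbone)
    if hasattr(module, 'config'):
        return module.config.patch_size
    if hasattr(module, 'embeddings') and hasattr(module.embeddings, 'patch_size'):
        return module.embeddings.patch_size
    return default  # fallback when neither attribute is present
```

The design rationale for the move: the patch size is fixed once the model is constructed, so the chain of `hasattr` probes and wrapper unwrapping is pure-Python work that gains nothing from running once per batch inside the hot loop.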