Skip to content

Commit bef1e90

Browse files
Copilotanxiangsir
andcommitted
Move patch_size extraction outside of training loop
Co-authored-by: anxiangsir <31175974+anxiangsir@users.noreply.github.com>
1 parent c0ce246 commit bef1e90

1 file changed

Lines changed: 9 additions & 9 deletions

File tree

training/train.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -413,6 +413,15 @@ def wrap_ddp(model):
413413
backbone_ddp = wrap_ddp(backbone)
414414
backbone_ddp_compiled = torch.compile(backbone_ddp)
415415

416+
# Get patch_size from backbone config (outside of training loop for efficiency)
417+
backbone_module = unwrap_module(backbone)
418+
if hasattr(backbone_module, 'config'):
419+
patch_size = backbone_module.config.patch_size
420+
elif hasattr(backbone_module, 'embeddings') and hasattr(backbone_module.embeddings, 'patch_size'):
421+
patch_size = backbone_module.embeddings.patch_size
422+
else:
423+
patch_size = 16 # default fallback
424+
416425
list_dali_dataloader = []
417426
list_head_names = []
418427
for head_id, dataset_config in enumerate(args.list_datasets):
@@ -586,15 +595,6 @@ def wrap_ddp(model):
586595
bs = visible_indices.shape[0]
587596
dev = visible_indices.device
588597

589-
# Get patch_size from backbone config
590-
backbone_module = unwrap_module(backbone)
591-
if hasattr(backbone_module, 'config'):
592-
patch_size = backbone_module.config.patch_size
593-
elif hasattr(backbone_module, 'embeddings') and hasattr(backbone_module.embeddings, 'patch_size'):
594-
patch_size = backbone_module.embeddings.patch_size
595-
else:
596-
patch_size = 16 # default fallback
597-
598598
out = visible_indices[:, :args.target_num].clone()
599599
n1 = int(bs * 0.5)
600600
n2 = int(bs * 0.875)

0 commit comments

Comments
 (0)