Skip to content

Commit 722817c

Browse files
committed
hf export: download trust_remote_code .py files via snapshot_download
Signed-off-by: jrausch <jrausch@nvidia.com>
1 parent f3d90a3 commit 722817c

File tree

1 file changed

+16
-3
lines changed

1 file changed

+16
-3
lines changed

examples/megatron_bridge/distill.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -338,10 +338,23 @@ def _build_model_provider(hf_path):
338338
show_progress=True,
339339
strict=True,
340340
)
341-
# Copy config.json from student_hf_path (handles both local paths and HF model IDs)
342-
AutoConfig.from_pretrained(
341+
# Save config.json from student_hf_path (handles both local paths and HF model IDs)
342+
config = AutoConfig.from_pretrained(
343343
args.student_hf_path, trust_remote_code=args.trust_remote_code
344-
).save_pretrained(args.hf_export_path)
344+
)
345+
config.save_pretrained(args.hf_export_path)
346+
347+
# Download trust_remote_code .py files from HF hub so the exported
348+
# checkpoint can be loaded with trust_remote_code=True from a local path.
349+
# save_pretrained only writes config.json, not the modeling code.
350+
if hasattr(config, "auto_map") and isinstance(config.auto_map, dict):
351+
from huggingface_hub import snapshot_download
352+
353+
snapshot_download(
354+
repo_id=args.student_hf_model,
355+
local_dir=args.hf_export_path,
356+
allow_patterns=["*.py"],
357+
)
345358

346359

347360
if __name__ == "__main__":

0 commit comments

Comments
 (0)