Skip to content

Commit 0781a04

Browse files
MBridge pruning minor fix for saving pruned NemotronH (#887)
## What does this PR do? **Type of change:** Bug fix <!-- Use one of the following: Bug fix, new feature, new example, new tests, documentation. --> ## Testing <!-- Mention how have you tested your change if applicable. --> Verified that a pruned Nemotron Nano v2 model can be saved. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Bug Fixes** * Fixed Hugging Face model loading to properly respect the `trust_remote_code` parameter during model instantiation. * **Improvements** * Enhanced distributed training logging with rank-0 aware warning and logging mechanisms for cleaner, non-redundant output in multi-GPU and multi-node scenarios. <!-- end of auto-generated comment: release notes by coderabbit.ai --> Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
1 parent 2a9f431 commit 0781a04

File tree

4 files changed

+10
-6
lines changed

4 files changed

+10
-6
lines changed

examples/megatron_bridge/distill.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,8 @@ def _build_model_provider(hf_path):
198198
manual_gc=True,
199199
manual_gc_interval=100,
200200
),
201+
# TODO: Replace validation args in train with validation config in nemo:26.04
202+
# validation=ValidationConfig(eval_interval=args.eval_interval, eval_iters=args.eval_iters),
201203
optimizer=optimizer_config,
202204
scheduler=scheduler_config,
203205
ddp=DistributedDataParallelConfig(

examples/megatron_bridge/prune_minitron.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -380,7 +380,9 @@ def score_func_mmlu(m):
380380
AutoModelForCausalLM.from_config(
381381
hf_cfg, trust_remote_code=args.trust_remote_code
382382
).save_pretrained(args.output_hf_path, trust_remote_code=args.trust_remote_code)
383-
pruned_bridge = AutoBridge.from_hf_pretrained(args.output_hf_path)
383+
pruned_bridge = AutoBridge.from_hf_pretrained(
384+
args.output_hf_path, trust_remote_code=args.trust_remote_code
385+
)
384386
pruned_bridge.save_hf_weights(model, args.output_hf_path)
385387
print_rank_0(f"Saved pruned model to {args.output_hf_path} in HF checkpoint format")
386388

modelopt/torch/opt/searcher.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,14 @@
2727
from collections.abc import Callable
2828
from contextlib import nullcontext
2929
from typing import Any, final
30-
from warnings import warn
3130

3231
import numpy as np
3332
import pulp
3433
import torch
3534
import torch.nn as nn
3635

3736
from modelopt.torch.utils import distributed as dist
38-
from modelopt.torch.utils import no_stdout, print_rank_0, run_forward_loop
37+
from modelopt.torch.utils import no_stdout, print_rank_0, run_forward_loop, warn_rank_0
3938

4039
LimitsTuple = tuple[float, float]
4140
ConstraintsDict = dict[str, str | float | dict | None]
@@ -244,12 +243,11 @@ def load_search_checkpoint(self) -> bool:
244243
if checkpoint is None:
245244
return False
246245
if not os.path.exists(checkpoint):
247-
if dist.is_master():
248-
warn(f"Checkpoint {checkpoint} does not exist! Initializing from scratch.")
246+
warn_rank_0(f"Checkpoint {checkpoint} does not exist! Initializing from scratch.")
249247
return False
250248

251249
# iterate through state dict and load keys
252-
print(f"Loading searcher state from {checkpoint}...")
250+
print_rank_0(f"Loading searcher state from {checkpoint}...")
253251
# Security NOTE: weights_only=False is used here on ModelOpt-generated ckpt, not on untrusted user input
254252
state_dict = torch.load(checkpoint, weights_only=False)
255253
assert state_dict.keys() == self.state_dict().keys(), "Keys in checkpoint don't match!"

modelopt/torch/utils/plugins/mbridge.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,8 @@ def get_hf_mbridge_calibration_loop(
191191
eval_iters=num_iters,
192192
skip_train=True,
193193
),
194+
# TODO: Replace validation args in train with validation config in nemo:26.04
195+
# validation=ValidationConfig(eval_iters=num_iters, eval_interval=1, skip_train=True),
194196
dataset=_get_dataset_cfg(
195197
dataset_name,
196198
num_samples,

0 commit comments

Comments (0)