diff --git a/examples/megatron_bridge/distill.py b/examples/megatron_bridge/distill.py index c21bf73121..31f1cfc71c 100644 --- a/examples/megatron_bridge/distill.py +++ b/examples/megatron_bridge/distill.py @@ -198,6 +198,8 @@ def _build_model_provider(hf_path): manual_gc=True, manual_gc_interval=100, ), + # TODO: Replace validation args in train with validation config in nemo:26.04 + # validation=ValidationConfig(eval_interval=args.eval_interval, eval_iters=args.eval_iters), optimizer=optimizer_config, scheduler=scheduler_config, ddp=DistributedDataParallelConfig( diff --git a/examples/megatron_bridge/prune_minitron.py b/examples/megatron_bridge/prune_minitron.py index 44eac3a31a..c4da627f14 100644 --- a/examples/megatron_bridge/prune_minitron.py +++ b/examples/megatron_bridge/prune_minitron.py @@ -380,7 +380,9 @@ def score_func_mmlu(m): AutoModelForCausalLM.from_config( hf_cfg, trust_remote_code=args.trust_remote_code ).save_pretrained(args.output_hf_path, trust_remote_code=args.trust_remote_code) - pruned_bridge = AutoBridge.from_hf_pretrained(args.output_hf_path) + pruned_bridge = AutoBridge.from_hf_pretrained( + args.output_hf_path, trust_remote_code=args.trust_remote_code + ) pruned_bridge.save_hf_weights(model, args.output_hf_path) print_rank_0(f"Saved pruned model to {args.output_hf_path} in HF checkpoint format") diff --git a/modelopt/torch/opt/searcher.py b/modelopt/torch/opt/searcher.py index 9e73b143ce..ab3930c207 100644 --- a/modelopt/torch/opt/searcher.py +++ b/modelopt/torch/opt/searcher.py @@ -27,7 +27,6 @@ from collections.abc import Callable from contextlib import nullcontext from typing import Any, final -from warnings import warn import numpy as np import pulp @@ -35,7 +34,7 @@ import torch.nn as nn from modelopt.torch.utils import distributed as dist -from modelopt.torch.utils import no_stdout, print_rank_0, run_forward_loop +from modelopt.torch.utils import no_stdout, print_rank_0, run_forward_loop, warn_rank_0 LimitsTuple = tuple[float, float] ConstraintsDict = dict[str, str | float | dict | None] @@ -244,12 +243,11 @@ def load_search_checkpoint(self) -> bool: if checkpoint is None: return False if not os.path.exists(checkpoint): - if dist.is_master(): - warn(f"Checkpoint {checkpoint} does not exist! Initializing from scratch.") + warn_rank_0(f"Checkpoint {checkpoint} does not exist! Initializing from scratch.") return False # iterate through state dict and load keys - print(f"Loading searcher state from {checkpoint}...") + print_rank_0(f"Loading searcher state from {checkpoint}...") # Security NOTE: weights_only=False is used here on ModelOpt-generated ckpt, not on untrusted user input state_dict = torch.load(checkpoint, weights_only=False) assert state_dict.keys() == self.state_dict().keys(), "Keys in checkpoint don't match!" diff --git a/modelopt/torch/utils/plugins/mbridge.py b/modelopt/torch/utils/plugins/mbridge.py index ed2551ee15..94cdf87cf5 100644 --- a/modelopt/torch/utils/plugins/mbridge.py +++ b/modelopt/torch/utils/plugins/mbridge.py @@ -191,6 +191,8 @@ def get_hf_mbridge_calibration_loop( eval_iters=num_iters, skip_train=True, ), + # TODO: Replace validation args in train with validation config in nemo:26.04 + # validation=ValidationConfig(eval_iters=num_iters, eval_interval=1, skip_train=True), dataset=_get_dataset_cfg( dataset_name, num_samples,