Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions examples/megatron_bridge/distill.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,8 @@ def _build_model_provider(hf_path):
manual_gc=True,
manual_gc_interval=100,
),
# TODO: Replace validation args in train with validation config in nemo:26.04
# validation=ValidationConfig(eval_interval=args.eval_interval, eval_iters=args.eval_iters),
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will you implement this in the current PR or a later PR?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Later. It won't be merged into the 26.02 container, so it would raise an error if we did it now — but I wanted to keep it for reference in case we try to mount the latest m-bridge and run the script.

optimizer=optimizer_config,
scheduler=scheduler_config,
ddp=DistributedDataParallelConfig(
Expand Down
4 changes: 3 additions & 1 deletion examples/megatron_bridge/prune_minitron.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,9 @@ def score_func_mmlu(m):
AutoModelForCausalLM.from_config(
hf_cfg, trust_remote_code=args.trust_remote_code
).save_pretrained(args.output_hf_path, trust_remote_code=args.trust_remote_code)
pruned_bridge = AutoBridge.from_hf_pretrained(args.output_hf_path)
pruned_bridge = AutoBridge.from_hf_pretrained(
args.output_hf_path, trust_remote_code=args.trust_remote_code
)
pruned_bridge.save_hf_weights(model, args.output_hf_path)
print_rank_0(f"Saved pruned model to {args.output_hf_path} in HF checkpoint format")

Expand Down
8 changes: 3 additions & 5 deletions modelopt/torch/opt/searcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,14 @@
from collections.abc import Callable
from contextlib import nullcontext
from typing import Any, final
from warnings import warn

import numpy as np
import pulp
import torch
import torch.nn as nn

from modelopt.torch.utils import distributed as dist
from modelopt.torch.utils import no_stdout, print_rank_0, run_forward_loop
from modelopt.torch.utils import no_stdout, print_rank_0, run_forward_loop, warn_rank_0

LimitsTuple = tuple[float, float]
ConstraintsDict = dict[str, str | float | dict | None]
Expand Down Expand Up @@ -244,12 +243,11 @@ def load_search_checkpoint(self) -> bool:
if checkpoint is None:
return False
if not os.path.exists(checkpoint):
if dist.is_master():
warn(f"Checkpoint {checkpoint} does not exist! Initializing from scratch.")
warn_rank_0(f"Checkpoint {checkpoint} does not exist! Initializing from scratch.")
return False

# iterate through state dict and load keys
print(f"Loading searcher state from {checkpoint}...")
print_rank_0(f"Loading searcher state from {checkpoint}...")
# Security NOTE: weights_only=False is used here on ModelOpt-generated ckpt, not on untrusted user input
state_dict = torch.load(checkpoint, weights_only=False)
assert state_dict.keys() == self.state_dict().keys(), "Keys in checkpoint don't match!"
Expand Down
2 changes: 2 additions & 0 deletions modelopt/torch/utils/plugins/mbridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,8 @@ def get_hf_mbridge_calibration_loop(
eval_iters=num_iters,
skip_train=True,
),
# TODO: Replace validation args in train with validation config in nemo:26.04
# validation=ValidationConfig(eval_iters=num_iters, eval_interval=1, skip_train=True),
dataset=_get_dataset_cfg(
dataset_name,
num_samples,
Expand Down
Loading