Skip to content

Commit 770ad97

Browse files
Create shared megatron calibration forward loop for MLM and MBridge prune/quantize
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
1 parent 1ce5ea2 commit 770ad97

7 files changed

Lines changed: 134 additions & 178 deletions

File tree

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ Changelog
2222
- Add composable ``$import`` system for recipe YAML configs, enabling reusable config snippets referenced via ``{$import: name}`` markers. All built-in PTQ recipes converted to use imports with shared snippets under ``modelopt_recipes/configs/`` (numeric formats, quant_cfg building blocks, presets). See :ref:`composable-imports`.
2323
- Add offline DFlash speculative decoding training. Train the draft module from pre-computed base-model hidden states dumped by ``examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py``; base-model transformer layers are deleted after conversion to save memory. Controlled by the auto-derived ``dflash_offline`` flag on ``DFlashConfig`` (derived from ``data_args.offline_data_path``). The dump scripts now share ``collect_hidden_states/common.py`` for aux-layer selection (``--aux-layers eagle|dflash|<list>``) and optional assistant-token ``loss_mask`` for answer-only-loss training.
2424
- Add support for ``active_params`` (for MoE models) and ``memory_mb`` constraints in Minitron pruning on top of existing ``params`` constraint. You can also provide multiple constraints. See `examples/pruning/README.md <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/pruning>`_ for more details. The underlying utility functions ``mcore_param_count``, ``mcore_memory_footprint_mb``, and ``print_mcore_model_stats`` in ``modelopt.torch.nas.plugins.megatron_model_stats`` are also available for standalone use to compute parameter counts and memory footprints (weights + KV-cache + Mamba state) for any Megatron-Core model.
25+
- Enable ``--calib_mbs>1`` support in Minitron pruning for faster calibration
2526
- Add ``--cast_mxfp4_to_nvfp4`` flag to ``examples/llm_ptq/hf_ptq.py`` for closed-form, bit-exact MXFP4 → NVFP4 weight conversion. Supports the GPT-OSS family (``openai/gpt-oss-20b``, ``openai/gpt-oss-120b``). See `examples/llm_ptq/README.md <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/llm_ptq#mxfp4--nvfp4-cast-for-gpt-oss>`__ for usage.
2627
- DeepSeek PTQ (``examples/deepseek/ptq.py``) now defaults to native top-k calibration with post-hoc per-layer peer-max sync of expert ``input_quantizer.amax``; the all-experts path is preserved behind ``--calib_all_experts``.
2728
- Add NVFP4 W4A16 weight-only quantization (``w4a16_nvfp4``): FP4 weights with group_size=16, BF16 activations, no calibration forward pass required. Use ``mtq.W4A16_NVFP4_CFG`` or ``--qformat w4a16_nvfp4`` in ``hf_ptq.py``. vLLM deployment support is in progress.

examples/megatron_bridge/prune_minitron.py

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -53,10 +53,8 @@
5353
import modelopt.torch.prune as mtp
5454
import modelopt.torch.utils.distributed as dist
5555
from modelopt.torch.utils import get_supported_datasets, print_rank_0, warn_rank_0
56-
from modelopt.torch.utils.plugins.mbridge import (
57-
get_hf_mbridge_calibration_loop,
58-
load_mbridge_model_from_hf,
59-
)
56+
from modelopt.torch.utils.plugins.mbridge import load_mbridge_model_from_hf
57+
from modelopt.torch.utils.plugins.megatron_calibration import get_megatron_calibration_forward_loop
6058
from modelopt.torch.utils.plugins.megatron_mmlu import megatron_mmlu
6159

6260

@@ -104,10 +102,7 @@ def get_args() -> argparse.Namespace:
104102
"--calib_num_samples", type=int, default=1024, help="Number of samples for calibration"
105103
)
106104
# TODO: Add support for pre-training dataset (pre-tokenized)
107-
# TODO: only allow mbs>1 for pretraining dataset
108-
parser.add_argument(
109-
"--calib_mbs", type=int, default=1, choices=[1], help="Calibration micro-batch size"
110-
)
105+
parser.add_argument("--calib_mbs", type=int, default=1, help="Calibration micro-batch size")
111106
parser.add_argument("--calib_gbs", type=int, default=1, help="Calibration global batch size")
112107
parser.add_argument("--seq_length", type=int, default=4096)
113108
# Pruning parameters
@@ -227,6 +222,12 @@ def get_args() -> argparse.Namespace:
227222
args = parser.parse_args()
228223

229224
# Validate pruning target arguments
225+
if args.calib_mbs > args.calib_gbs:
226+
args.calib_gbs = args.calib_mbs
227+
print_rank_0(
228+
f"{args.calib_gbs=} is less than {args.calib_mbs=}, setting it to {args.calib_mbs}."
229+
)
230+
230231
_nas_targets = [
231232
args.prune_target_params,
232233
args.prune_target_active_params,
@@ -296,16 +297,12 @@ def main(args: argparse.Namespace):
296297
init_model_parallel=True,
297298
moe_grouped_gemm=False,
298299
)
299-
forward_loop = get_hf_mbridge_calibration_loop(
300-
model=model,
301-
provider=provider,
302-
tokenizer=tokenizer,
303-
hf_model_name_or_path=args.hf_model_name_or_path,
304-
trust_remote_code=args.trust_remote_code,
300+
forward_loop = get_megatron_calibration_forward_loop(
301+
tokenizer,
305302
dataset_name=args.calib_dataset_name,
306303
num_samples=args.calib_num_samples,
307-
micro_batch_size=args.calib_mbs,
308-
global_batch_size=args.calib_gbs,
304+
seq_length=args.seq_length,
305+
batch_size=args.calib_gbs,
309306
)
310307

311308
pruning_config = {

examples/pruning/README.md

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,9 @@ Please see example snippets of both modes for Minitron pruning on Megatron-Bridg
5050
```python
5151
import torch
5252
import modelopt.torch.prune as mtp
53-
from modelopt.torch.utils.plugins.mbridge import (
54-
get_hf_mbridge_calibration_loop,
55-
load_mbridge_model_from_hf,
53+
from modelopt.torch.utils.plugins.mbridge import load_mbridge_model_from_hf
54+
from modelopt.torch.utils.plugins.megatron_calibration import (
55+
get_megatron_calibration_forward_loop,
5656
)
5757

5858
# Import the Megatron-Bridge Qwen3-8B model from Hugging Face checkpoint
@@ -67,13 +67,11 @@ bridge, provider, model, unwrapped_model, tokenizer = load_mbridge_model_from_hf
6767
)
6868

6969
# Set up the forward loop to run on 1024 train samples
70-
forward_loop = get_hf_mbridge_calibration_loop(
71-
model=model,
72-
provider=provider,
73-
tokenizer=tokenizer,
74-
hf_model_name_or_path="Qwen/Qwen3-8B",
70+
forward_loop = get_megatron_calibration_forward_loop(
71+
tokenizer,
7572
dataset_name="nemotron-post-training-dataset-v2",
7673
num_samples=1024,
74+
seq_length=4096,
7775
)
7876

7977
# Run pruning on the unwrapped model

modelopt/torch/utils/dataset_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -662,7 +662,7 @@ def get_dataset_dataloader(
662662
batch_size: int = 1,
663663
num_samples: int | list[int] = 512,
664664
max_sample_length: int = 512,
665-
device: torch.device | None = None,
665+
device: torch.device | str | None = None,
666666
include_labels: bool = False,
667667
apply_chat_template: bool = False,
668668
pack: bool = False,

modelopt/torch/utils/plugins/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717

1818
from modelopt.torch.utils import import_plugin
1919

20+
with import_plugin("megatron_calibration"):
21+
from .megatron_calibration import *
22+
2023
with import_plugin("megatron_generate"):
2124
from .megatron_generate import *
2225

modelopt/torch/utils/plugins/mbridge.py

Lines changed: 2 additions & 153 deletions
Original file line numberDiff line numberDiff line change
@@ -14,43 +14,23 @@
1414
# limitations under the License.
1515
"""Megatron-Bridge plugins for using with Model-Optimizer."""
1616

17-
from collections.abc import Callable
1817
from typing import Any
1918

20-
import torch.nn as nn
21-
from datasets import DatasetDict
2219
from megatron.bridge import AutoBridge
23-
from megatron.bridge.data.builders.hf_dataset import HFDatasetConfig
24-
from megatron.bridge.data.loaders import setup_data_iterators
25-
from megatron.bridge.data.utils import get_dataset_provider
2620
from megatron.bridge.models.gpt_provider import GPTModelProvider
2721
from megatron.bridge.models.hf_pretrained.utils import is_safe_repo
2822
from megatron.bridge.models.mamba.mamba_provider import MambaModelProvider
29-
from megatron.bridge.training.config import (
30-
CheckpointConfig,
31-
ConfigContainer,
32-
LoggerConfig,
33-
OptimizerConfig,
34-
SchedulerConfig,
35-
TrainingConfig,
36-
runtime_config_update,
37-
)
38-
from megatron.bridge.training.eval import evaluate_and_print_results
39-
from megatron.bridge.training.gpt_step import forward_step
40-
from megatron.bridge.training.state import GlobalState
41-
from megatron.bridge.training.tokenizers.config import TokenizerConfig
4223
from megatron.core.models.gpt import GPTModel
4324
from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec
4425
from megatron.core.models.mamba import MambaModel
45-
from megatron.core.parallel_state import get_data_parallel_group
4626
from megatron.core.transformer.module import MegatronModule
4727
from megatron.core.utils import unwrap_model
4828
from transformers import AutoTokenizer
4929

5030
from modelopt.torch.nas.plugins.megatron import get_te_mamba_stack_spec
51-
from modelopt.torch.utils import get_dataset_samples, print_rank_0, warn_rank_0
31+
from modelopt.torch.utils import print_rank_0
5232

53-
__all__ = ["get_hf_mbridge_calibration_loop", "load_mbridge_model_from_hf"]
33+
__all__ = ["load_mbridge_model_from_hf"]
5434

5535

5636
def load_mbridge_model_from_hf(
@@ -118,134 +98,3 @@ def load_mbridge_model_from_hf(
11898
)
11999

120100
return bridge, provider, model, unwrapped_model, tokenizer
121-
122-
123-
def _get_dataset_cfg(
124-
dataset_name: str,
125-
num_samples: int,
126-
seq_length: int,
127-
apply_chat_template: bool = True,
128-
tokenizer: AutoTokenizer | None = None,
129-
) -> HFDatasetConfig:
130-
"""Get a dataset config for the dataset."""
131-
dataset = get_dataset_samples(
132-
dataset_name, num_samples, apply_chat_template=apply_chat_template, tokenizer=tokenizer
133-
)
134-
dataset_cfg = HFDatasetConfig(
135-
dataset_name=f"{dataset_name}_{num_samples}",
136-
dataset_dict=DatasetDict({"train": dataset}),
137-
process_example_fn=lambda example, tokenizer: {"input": example, "output": ""},
138-
seq_length=seq_length,
139-
dataloader_type="batch",
140-
num_workers=1,
141-
do_validation=False,
142-
do_test=False,
143-
val_proportion=None,
144-
split_val_from_train=False,
145-
rewrite=True,
146-
)
147-
148-
return dataset_cfg
149-
150-
151-
def get_hf_mbridge_calibration_loop(
152-
*,
153-
model: list[MegatronModule],
154-
provider: GPTModelProvider | MambaModelProvider,
155-
tokenizer: AutoTokenizer,
156-
hf_model_name_or_path: str,
157-
trust_remote_code: bool = False,
158-
dataset_name: str = "nemotron-post-training-dataset-v2",
159-
num_samples: int = 512,
160-
micro_batch_size: int = 1,
161-
global_batch_size: int = 1,
162-
) -> Callable[[nn.Module], None]:
163-
"""Get a modelopt calibration loop for a Megatron-Bridge model.
164-
165-
Args:
166-
model: The model to calibrate.
167-
provider: The provider to use for the model.
168-
tokenizer: The tokenizer to use for the model.
169-
hf_model_name_or_path: The name or path of the HF model.
170-
trust_remote_code: Whether to trust remote code.
171-
dataset_name: The name of the dataset to use for evaluation.
172-
num_samples: The number of samples to use for evaluation.
173-
micro_batch_size: The micro batch size to use for evaluation.
174-
global_batch_size: The global batch size to use for evaluation.
175-
176-
Returns:
177-
A function that can be used to calibrate the model with a modelopt.torch API.
178-
"""
179-
if global_batch_size < micro_batch_size:
180-
warn_rank_0(
181-
f"{global_batch_size=} is smaller than {micro_batch_size=}. Setting gbs to {micro_batch_size}."
182-
)
183-
global_batch_size = micro_batch_size
184-
num_iters = num_samples // global_batch_size
185-
186-
cfg = ConfigContainer(
187-
model=provider,
188-
train=TrainingConfig(
189-
micro_batch_size=micro_batch_size,
190-
global_batch_size=global_batch_size,
191-
train_iters=num_iters,
192-
eval_iters=num_iters,
193-
skip_train=True,
194-
),
195-
# TODO: Replace validation args in train with validation config in nemo:26.04
196-
# validation=ValidationConfig(eval_iters=num_iters, eval_interval=1, skip_train=True),
197-
dataset=_get_dataset_cfg(
198-
dataset_name,
199-
num_samples,
200-
provider.seq_length,
201-
apply_chat_template=True,
202-
tokenizer=tokenizer,
203-
),
204-
tokenizer=TokenizerConfig(
205-
tokenizer_type="HuggingFaceTokenizer",
206-
tokenizer_model=hf_model_name_or_path,
207-
# NOTE: Issue with Nemotron Nano v2 tokenizer returning bool hence using use_fast=True as a WAR
208-
hf_tokenizer_kwargs={
209-
"trust_remote_code": trust_remote_code,
210-
"use_fast": tokenizer.is_fast,
211-
},
212-
),
213-
# Unused
214-
optimizer=OptimizerConfig(optimizer="adam", lr=1e-4, use_distributed_optimizer=False),
215-
scheduler=SchedulerConfig(lr_decay_style="constant"),
216-
logger=LoggerConfig(),
217-
checkpoint=CheckpointConfig(),
218-
)
219-
runtime_config_update(cfg)
220-
221-
state = GlobalState()
222-
state.cfg = cfg
223-
224-
dataset_provider = get_dataset_provider(cfg.dataset)
225-
226-
def _train_valid_test_datasets_provider(
227-
train_val_test_num_samples: tuple, dataset_cfg: HFDatasetConfig
228-
):
229-
return dataset_provider(train_val_test_num_samples, dataset_cfg, tokenizer=state.tokenizer)
230-
231-
train_data_iterator, _, _ = setup_data_iterators(
232-
cfg=cfg,
233-
train_state=state.train_state,
234-
model_length=len(model),
235-
train_valid_test_datasets_provider=_train_valid_test_datasets_provider,
236-
dp_group=get_data_parallel_group(),
237-
)
238-
239-
def forward_loop(m):
240-
evaluate_and_print_results(
241-
state,
242-
prefix="iteration 1",
243-
forward_step_func=forward_step,
244-
data_iterator=train_data_iterator,
245-
model=model,
246-
config=cfg,
247-
verbose=True,
248-
write_to_tensorboard=False,
249-
)
250-
251-
return forward_loop

0 commit comments

Comments
 (0)