Skip to content

Commit d174a08

Browse files
Create shared megatron calibration forward loop for MLM and MBridge prune/quantize
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
1 parent c4c662e commit d174a08

6 files changed

Lines changed: 153 additions & 174 deletions

File tree

examples/megatron_bridge/prune_minitron.py

Lines changed: 6 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -53,10 +53,8 @@
5353
import modelopt.torch.prune as mtp
5454
import modelopt.torch.utils.distributed as dist
5555
from modelopt.torch.utils import get_supported_datasets, print_rank_0, warn_rank_0
56-
from modelopt.torch.utils.plugins.mbridge import (
57-
get_hf_mbridge_calibration_loop,
58-
load_mbridge_model_from_hf,
59-
)
56+
from modelopt.torch.utils.plugins.mbridge import load_mbridge_model_from_hf
57+
from modelopt.torch.utils.plugins.megatron_calibration import get_megatron_calibration_forward_loop
6058
from modelopt.torch.utils.plugins.megatron_mmlu import megatron_mmlu
6159

6260

@@ -296,16 +294,12 @@ def main(args: argparse.Namespace):
296294
init_model_parallel=True,
297295
moe_grouped_gemm=False,
298296
)
299-
forward_loop = get_hf_mbridge_calibration_loop(
300-
model=model,
301-
provider=provider,
302-
tokenizer=tokenizer,
303-
hf_model_name_or_path=args.hf_model_name_or_path,
304-
trust_remote_code=args.trust_remote_code,
297+
forward_loop = get_megatron_calibration_forward_loop(
298+
tokenizer,
305299
dataset_name=args.calib_dataset_name,
306300
num_samples=args.calib_num_samples,
307-
micro_batch_size=args.calib_mbs,
308-
global_batch_size=args.calib_gbs,
301+
seq_length=args.seq_length,
302+
batch_size=args.calib_gbs,
309303
)
310304

311305
pruning_config = {

examples/pruning/README.md

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,9 @@ Please see example snippets of both modes for Minitron pruning on Megatron-Bridg
5050
```python
5151
import torch
5252
import modelopt.torch.prune as mtp
53-
from modelopt.torch.utils.plugins.mbridge import (
54-
get_hf_mbridge_calibration_loop,
55-
load_mbridge_model_from_hf,
53+
from modelopt.torch.utils.plugins.mbridge import load_mbridge_model_from_hf
54+
from modelopt.torch.utils.plugins.megatron_calibration import (
55+
get_megatron_calibration_forward_loop,
5656
)
5757

5858
# Import the Megatron-Bridge Qwen3-8B model from Hugging Face checkpoint
@@ -67,13 +67,11 @@ bridge, provider, model, unwrapped_model, tokenizer = load_mbridge_model_from_hf
6767
)
6868

6969
# Set up the forward loop to run on 1024 train samples
70-
forward_loop = get_hf_mbridge_calibration_loop(
71-
model=model,
72-
provider=provider,
73-
tokenizer=tokenizer,
74-
hf_model_name_or_path="Qwen/Qwen3-8B",
70+
forward_loop = get_megatron_calibration_forward_loop(
71+
tokenizer,
7572
dataset_name="nemotron-post-training-dataset-v2",
7673
num_samples=1024,
74+
seq_length=4096,
7775
)
7876

7977
# Run pruning on the unwrapped model

modelopt/torch/utils/dataset_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -617,7 +617,7 @@ def get_dataset_dataloader(
617617
batch_size: int = 1,
618618
num_samples: int | list[int] = 512,
619619
max_sample_length: int = 512,
620-
device: torch.device | None = None,
620+
device: torch.device | str | None = None,
621621
include_labels: bool = False,
622622
apply_chat_template: bool = False,
623623
pack: bool = False,

modelopt/torch/utils/plugins/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717

1818
from modelopt.torch.utils import import_plugin
1919

20+
with import_plugin("megatron_calibration"):
21+
from .megatron_calibration import *
22+
2023
with import_plugin("megatron_generate"):
2124
from .megatron_generate import *
2225

modelopt/torch/utils/plugins/mbridge.py

Lines changed: 2 additions & 153 deletions
Original file line numberDiff line numberDiff line change
@@ -14,43 +14,23 @@
1414
# limitations under the License.
1515
"""Megatron-Bridge plugins for using with Model-Optimizer."""
1616

17-
from collections.abc import Callable
1817
from typing import Any
1918

20-
import torch.nn as nn
21-
from datasets import DatasetDict
2219
from megatron.bridge import AutoBridge
23-
from megatron.bridge.data.builders.hf_dataset import HFDatasetConfig
24-
from megatron.bridge.data.loaders import setup_data_iterators
25-
from megatron.bridge.data.utils import get_dataset_provider
2620
from megatron.bridge.models.gpt_provider import GPTModelProvider
2721
from megatron.bridge.models.hf_pretrained.utils import is_safe_repo
2822
from megatron.bridge.models.mamba.mamba_provider import MambaModelProvider
29-
from megatron.bridge.training.config import (
30-
CheckpointConfig,
31-
ConfigContainer,
32-
LoggerConfig,
33-
OptimizerConfig,
34-
SchedulerConfig,
35-
TrainingConfig,
36-
runtime_config_update,
37-
)
38-
from megatron.bridge.training.eval import evaluate_and_print_results
39-
from megatron.bridge.training.gpt_step import forward_step
40-
from megatron.bridge.training.state import GlobalState
41-
from megatron.bridge.training.tokenizers.config import TokenizerConfig
4223
from megatron.core.models.gpt import GPTModel
4324
from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec
4425
from megatron.core.models.mamba import MambaModel
45-
from megatron.core.parallel_state import get_data_parallel_group
4626
from megatron.core.transformer.module import MegatronModule
4727
from megatron.core.utils import unwrap_model
4828
from transformers import AutoTokenizer
4929

5030
from modelopt.torch.nas.plugins.megatron import get_te_mamba_stack_spec
51-
from modelopt.torch.utils import get_dataset_samples, print_rank_0, warn_rank_0
31+
from modelopt.torch.utils import print_rank_0
5232

53-
__all__ = ["get_hf_mbridge_calibration_loop", "load_mbridge_model_from_hf"]
33+
__all__ = ["load_mbridge_model_from_hf"]
5434

5535

5636
def load_mbridge_model_from_hf(
@@ -118,134 +98,3 @@ def load_mbridge_model_from_hf(
11898
)
11999

120100
return bridge, provider, model, unwrapped_model, tokenizer
121-
122-
123-
def _get_dataset_cfg(
124-
dataset_name: str,
125-
num_samples: int,
126-
seq_length: int,
127-
apply_chat_template: bool = True,
128-
tokenizer: AutoTokenizer | None = None,
129-
) -> HFDatasetConfig:
130-
"""Get a dataset config for the dataset."""
131-
dataset = get_dataset_samples(
132-
dataset_name, num_samples, apply_chat_template=apply_chat_template, tokenizer=tokenizer
133-
)
134-
dataset_cfg = HFDatasetConfig(
135-
dataset_name=f"{dataset_name}_{num_samples}",
136-
dataset_dict=DatasetDict({"train": dataset}),
137-
process_example_fn=lambda example, tokenizer: {"input": example, "output": ""},
138-
seq_length=seq_length,
139-
dataloader_type="batch",
140-
num_workers=1,
141-
do_validation=False,
142-
do_test=False,
143-
val_proportion=None,
144-
split_val_from_train=False,
145-
rewrite=True,
146-
)
147-
148-
return dataset_cfg
149-
150-
151-
def get_hf_mbridge_calibration_loop(
152-
*,
153-
model: list[MegatronModule],
154-
provider: GPTModelProvider | MambaModelProvider,
155-
tokenizer: AutoTokenizer,
156-
hf_model_name_or_path: str,
157-
trust_remote_code: bool = False,
158-
dataset_name: str = "nemotron-post-training-dataset-v2",
159-
num_samples: int = 512,
160-
micro_batch_size: int = 1,
161-
global_batch_size: int = 1,
162-
) -> Callable[[nn.Module], None]:
163-
"""Get a modelopt calibration loop for a Megatron-Bridge model.
164-
165-
Args:
166-
model: The model to calibrate.
167-
provider: The provider to use for the model.
168-
tokenizer: The tokenizer to use for the model.
169-
hf_model_name_or_path: The name or path of the HF model.
170-
trust_remote_code: Whether to trust remote code.
171-
dataset_name: The name of the dataset to use for evaluation.
172-
num_samples: The number of samples to use for evaluation.
173-
micro_batch_size: The micro batch size to use for evaluation.
174-
global_batch_size: The global batch size to use for evaluation.
175-
176-
Returns:
177-
A function that can be used to calibrate the model with a modelopt.torch API.
178-
"""
179-
if global_batch_size < micro_batch_size:
180-
warn_rank_0(
181-
f"{global_batch_size=} is smaller than {micro_batch_size=}. Setting gbs to {micro_batch_size}."
182-
)
183-
global_batch_size = micro_batch_size
184-
num_iters = num_samples // global_batch_size
185-
186-
cfg = ConfigContainer(
187-
model=provider,
188-
train=TrainingConfig(
189-
micro_batch_size=micro_batch_size,
190-
global_batch_size=global_batch_size,
191-
train_iters=num_iters,
192-
eval_iters=num_iters,
193-
skip_train=True,
194-
),
195-
# TODO: Replace validation args in train with validation config in nemo:26.04
196-
# validation=ValidationConfig(eval_iters=num_iters, eval_interval=1, skip_train=True),
197-
dataset=_get_dataset_cfg(
198-
dataset_name,
199-
num_samples,
200-
provider.seq_length,
201-
apply_chat_template=True,
202-
tokenizer=tokenizer,
203-
),
204-
tokenizer=TokenizerConfig(
205-
tokenizer_type="HuggingFaceTokenizer",
206-
tokenizer_model=hf_model_name_or_path,
207-
# NOTE: Issue with Nemotron Nano v2 tokenizer returning bool hence using use_fast=True as a WAR
208-
hf_tokenizer_kwargs={
209-
"trust_remote_code": trust_remote_code,
210-
"use_fast": tokenizer.is_fast,
211-
},
212-
),
213-
# Unused
214-
optimizer=OptimizerConfig(optimizer="adam", lr=1e-4, use_distributed_optimizer=False),
215-
scheduler=SchedulerConfig(lr_decay_style="constant"),
216-
logger=LoggerConfig(),
217-
checkpoint=CheckpointConfig(),
218-
)
219-
runtime_config_update(cfg)
220-
221-
state = GlobalState()
222-
state.cfg = cfg
223-
224-
dataset_provider = get_dataset_provider(cfg.dataset)
225-
226-
def _train_valid_test_datasets_provider(
227-
train_val_test_num_samples: tuple, dataset_cfg: HFDatasetConfig
228-
):
229-
return dataset_provider(train_val_test_num_samples, dataset_cfg, tokenizer=state.tokenizer)
230-
231-
train_data_iterator, _, _ = setup_data_iterators(
232-
cfg=cfg,
233-
train_state=state.train_state,
234-
model_length=len(model),
235-
train_valid_test_datasets_provider=_train_valid_test_datasets_provider,
236-
dp_group=get_data_parallel_group(),
237-
)
238-
239-
def forward_loop(m):
240-
evaluate_and_print_results(
241-
state,
242-
prefix="iteration 1",
243-
forward_step_func=forward_step,
244-
data_iterator=train_data_iterator,
245-
model=model,
246-
config=cfg,
247-
verbose=True,
248-
write_to_tensorboard=False,
249-
)
250-
251-
return forward_loop
modelopt/torch/utils/plugins/megatron_calibration.py

Lines changed: 135 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,135 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Shared calibration forward-loop builder for Megatron-Core models.
17+
18+
Drives a logits-free prefill pass through the model over a calibration dataset,
19+
producing the ``forward_loop`` callable that ``mtq.quantize`` / ``mtp.prune`` /
20+
``mtq.calibrate`` expect. Replaces the bespoke calibration loops in
21+
``Megatron-LM/examples/post_training/modelopt/{quantize,prune}.py``,
22+
``Megatron-Bridge/examples/quantization/quantize.py``, and
23+
``examples/megatron_bridge/prune_minitron.py``.
24+
25+
Picks the best primitives from each existing path:
26+
- ``get_dataset_dataloader`` for dataset surface (HF registry + JSONL auto-detection,
27+
multi-source blending, ``pack=True`` for real-token density)
28+
- ``megatron_prefill(skip_return_logits=True)`` for the forward primitive (no
29+
logits compute, just activation flow for hooks)
30+
- ``get_batch_on_this_cp_rank`` for context-parallel correctness
31+
"""
32+
33+
from collections.abc import Callable
34+
from typing import TYPE_CHECKING
35+
36+
import torch
37+
from megatron.core.utils import get_batch_on_this_cp_rank
38+
from tqdm import tqdm
39+
40+
from modelopt.torch.utils import distributed as dist
41+
from modelopt.torch.utils.dataset_utils import get_dataset_dataloader
42+
43+
from .megatron_generate import megatron_prefill
44+
45+
if TYPE_CHECKING:
46+
from transformers import PreTrainedTokenizerBase
47+
48+
__all__ = ["get_megatron_calibration_forward_loop"]
49+
50+
51+
def get_megatron_calibration_forward_loop(
52+
tokenizer: "PreTrainedTokenizerBase",
53+
*,
54+
dataset_name: str | list[str] = "cnn_dailymail",
55+
num_samples: int | list[int] = 512,
56+
seq_length: int = 512,
57+
batch_size: int = 1,
58+
pack: bool = True,
59+
apply_chat_template: bool = False,
60+
device: torch.device | str | None = "cuda",
61+
) -> Callable[[torch.nn.Module], None]:
62+
"""Build a Megatron-Core calibration ``forward_loop(model)`` for PTQ / pruning.
63+
64+
The returned callable iterates a ``get_dataset_dataloader``-produced dataloader,
65+
slices each batch to the local context-parallel (CP) rank, and drives a
66+
logits-free prefill pass through the model so activation hooks fire.
67+
68+
Args:
69+
tokenizer: HuggingFace tokenizer. ``pad_token`` is set to ``eos_token`` if
70+
missing so non-packing tokenization paths don't fail.
71+
dataset_name: Dataset key (see :func:`get_supported_datasets`), a path to a
72+
``.jsonl`` file, or a list mixing the two. Multi-source blends are
73+
supported when ``pack=True``.
74+
num_samples: With ``pack=True``, the number of ``seq_length``-token chunks
75+
per source; with ``pack=False``, the number of raw samples (each
76+
padded/truncated). May be a list aligned with ``dataset_name``.
77+
seq_length: Tokens per row. Under ``pack=True`` (default) every row is
78+
exactly this length; under ``pack=False`` it's the truncation /
79+
padding target. Matches Megatron-Core's ``seq_length`` convention.
80+
batch_size: Calibration micro-batch size. Default ``1`` matches the
81+
historical convention. Under ``pack=True`` it is safe to raise this
82+
for throughput — every position is a real token (no per-sample
83+
padding bias), and attention never mixes rows within a batch, so entries
84+
don't cross-attend; ``batch_size=N`` is forward-equivalent to
85+
``batch_size=1`` repeated ``N`` times. Under ``pack=False``, keep ``batch_size=1``
86+
to avoid pad-token activations contaminating amax / sensitivity
87+
statistics (calibration hooks fire before ``attention_mask`` is applied).
88+
pack: Forwarded to :func:`get_dataset_dataloader`. Default ``True`` here
89+
(vs. ``False`` in the underlying loader) because every Megatron
90+
calibration call site we know of benefits from packing — long
91+
documents stop being truncated and padding stops contaminating
92+
activation statistics.
93+
apply_chat_template: Forwarded to :func:`get_dataset_dataloader`.
94+
device: Forwarded to :func:`get_dataset_dataloader`.
95+
96+
Returns:
97+
A ``forward_loop(model)`` callable to pass into ``mtq.quantize``,
98+
``mtp.prune``, or ``mtq.calibrate``.
99+
100+
Example::
101+
102+
import modelopt.torch.quantization as mtq
103+
from modelopt.torch.utils.plugins.megatron_calibration import (
104+
get_megatron_calibration_forward_loop,
105+
)
106+
107+
forward_loop = get_megatron_calibration_forward_loop(
108+
tokenizer,
109+
dataset_name="cnn_dailymail",
110+
num_samples=1024,
111+
seq_length=512,
112+
)
113+
mtq.quantize(unwrapped_model, mtq_config, forward_loop)
114+
"""
115+
if getattr(tokenizer, "pad_token", None) is None:
116+
tokenizer.pad_token = tokenizer.eos_token
117+
118+
dataloader = get_dataset_dataloader(
119+
dataset_name=dataset_name,
120+
tokenizer=tokenizer,
121+
batch_size=batch_size,
122+
num_samples=num_samples,
123+
max_sample_length=seq_length,
124+
device=device,
125+
apply_chat_template=apply_chat_template,
126+
pack=pack,
127+
)
128+
129+
def _forward_loop(model: torch.nn.Module) -> None:
130+
for sample in tqdm(dataloader, disable=not dist.is_master()):
131+
# CP shard slicing is a no-op under CP=1 and required under CP>1.
132+
sample = get_batch_on_this_cp_rank(sample)
133+
megatron_prefill(model, sample["input_ids"], skip_return_logits=True)
134+
135+
return _forward_loop

0 commit comments

Comments
 (0)