Skip to content

Commit d1597cf

Browse files
Create shared Megatron calibration forward loop for prune / quantize
Replaces the bespoke calibration loops in M-LM and M-Bridge prune/quantize example scripts with a single shared util, ``modelopt.torch.utils.plugins.megatron_calibration.get_megatron_calibration_forward_loop``. The shared loop emits one sample per row (via ``get_dataset_dataloader``), trims each row to its real length using the dataloader's attention mask, and forces EOS at the trimmed last position before forwarding via ``megatron_prefill(skip_return_logits=True)``. Matches legacy ``GPTSFTDataset(add_eos=True)`` semantics exactly. Samples are sorted by real length descending so front batches are mostly full-length (true batched forward); back batches that contain padding fall through to per-row forward to keep padding-token activations out of the calibration hook stream. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com> Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
1 parent a5bc6f8 commit d1597cf

10 files changed

Lines changed: 205 additions & 204 deletions

File tree

examples/megatron_bridge/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ torchrun --nproc_per_node 2 prune_minitron.py \
102102
--hf_model_name_or_path Qwen/Qwen3-8B \
103103
--prune_target_memory_mb 12288 \
104104
--seq_length 4096 \
105-
--calib_mbs 1 \
105+
--calib_batch_size 1 \
106106
--output_hf_path /tmp/Qwen3-8B-Pruned-12GB
107107
```
108108

examples/megatron_bridge/prune_minitron.py

Lines changed: 12 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -53,10 +53,8 @@
5353
import modelopt.torch.prune as mtp
5454
import modelopt.torch.utils.distributed as dist
5555
from modelopt.torch.utils import get_supported_datasets, print_rank_0, warn_rank_0
56-
from modelopt.torch.utils.plugins.mbridge import (
57-
get_hf_mbridge_calibration_loop,
58-
load_mbridge_model_from_hf,
59-
)
56+
from modelopt.torch.utils.plugins.mbridge import load_mbridge_model_from_hf
57+
from modelopt.torch.utils.plugins.megatron_calibration import get_megatron_calibration_forward_loop
6058
from modelopt.torch.utils.plugins.megatron_mmlu import megatron_mmlu
6159

6260

@@ -104,11 +102,7 @@ def get_args() -> argparse.Namespace:
104102
"--calib_num_samples", type=int, default=1024, help="Number of samples for calibration"
105103
)
106104
# TODO: Add support for pre-training dataset (pre-tokenized)
107-
# TODO: only allow mbs>1 for pretraining dataset
108-
parser.add_argument(
109-
"--calib_mbs", type=int, default=1, choices=[1], help="Calibration micro-batch size"
110-
)
111-
parser.add_argument("--calib_gbs", type=int, default=1, help="Calibration global batch size")
105+
parser.add_argument("--calib_batch_size", type=int, default=1, help="Calibration batch size")
112106
parser.add_argument("--seq_length", type=int, default=4096)
113107
# Pruning parameters
114108
parser.add_argument(
@@ -164,8 +158,8 @@ def get_args() -> argparse.Namespace:
164158
default=None,
165159
help=(
166160
"Batch size used only for KV-cache sizing in --prune_target_memory_mb. "
167-
"Defaults to --calib_mbs when not set. "
168-
"Use this to target an inference batch size that differs from the calibration micro-batch size."
161+
"Defaults to --calib_batch_size when not set. "
162+
"Use this to target an inference batch size that differs from the calibration batch size."
169163
),
170164
)
171165

@@ -296,16 +290,12 @@ def main(args: argparse.Namespace):
296290
init_model_parallel=True,
297291
moe_grouped_gemm=False,
298292
)
299-
forward_loop = get_hf_mbridge_calibration_loop(
300-
model=model,
301-
provider=provider,
302-
tokenizer=tokenizer,
303-
hf_model_name_or_path=args.hf_model_name_or_path,
304-
trust_remote_code=args.trust_remote_code,
293+
forward_loop = get_megatron_calibration_forward_loop(
294+
tokenizer,
305295
dataset_name=args.calib_dataset_name,
306296
num_samples=args.calib_num_samples,
307-
micro_batch_size=args.calib_mbs,
308-
global_batch_size=args.calib_gbs,
297+
seq_length=args.seq_length,
298+
batch_size=args.calib_batch_size,
309299
)
310300

311301
pruning_config = {
@@ -385,7 +375,9 @@ def score_func(m):
385375
pruning_config["top_k"] = args.top_k
386376
# memory_mb constraint requires batch_size and seq_length
387377
pruning_config["batch_size"] = (
388-
args.inference_batch_size if args.inference_batch_size is not None else args.calib_mbs
378+
args.inference_batch_size
379+
if args.inference_batch_size is not None
380+
else args.calib_batch_size
389381
)
390382
pruning_config["seq_length"] = args.seq_length
391383
print_rank_0(f"Pruning constraints: {pruning_constraints}")

examples/pruning/README.md

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,9 @@ Please see example snippets of both modes for Minitron pruning on Megatron-Bridg
5050
```python
5151
import torch
5252
import modelopt.torch.prune as mtp
53-
from modelopt.torch.utils.plugins.mbridge import (
54-
get_hf_mbridge_calibration_loop,
55-
load_mbridge_model_from_hf,
53+
from modelopt.torch.utils.plugins.mbridge import load_mbridge_model_from_hf
54+
from modelopt.torch.utils.plugins.megatron_calibration import (
55+
get_megatron_calibration_forward_loop,
5656
)
5757

5858
# Import the Megatron-Bridge Qwen3-8B model from Hugging Face checkpoint
@@ -67,13 +67,11 @@ bridge, provider, model, unwrapped_model, tokenizer = load_mbridge_model_from_hf
6767
)
6868

6969
# Set up the forward loop to run on 1024 train samples
70-
forward_loop = get_hf_mbridge_calibration_loop(
71-
model=model,
72-
provider=provider,
73-
tokenizer=tokenizer,
74-
hf_model_name_or_path="Qwen/Qwen3-8B",
70+
forward_loop = get_megatron_calibration_forward_loop(
71+
tokenizer,
7572
dataset_name="nemotron-post-training-dataset-v2",
7673
num_samples=1024,
74+
seq_length=4096,
7775
)
7876

7977
# Run pruning on the unwrapped model

modelopt/torch/utils/dataset_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -563,7 +563,7 @@ def get_dataset_dataloader(
563563
batch_size: int = 1,
564564
num_samples: int | list[int] = 512,
565565
max_sample_length: int = 512,
566-
device: torch.device | None = None,
566+
device: torch.device | str | None = None,
567567
include_labels: bool = False,
568568
apply_chat_template: bool = False,
569569
) -> DataLoader:

modelopt/torch/utils/plugins/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717

1818
from modelopt.torch.utils import import_plugin
1919

20+
with import_plugin("megatron_calibration"):
21+
from .megatron_calibration import *
22+
2023
with import_plugin("megatron_generate"):
2124
from .megatron_generate import *
2225

modelopt/torch/utils/plugins/mbridge.py

Lines changed: 2 additions & 153 deletions
Original file line numberDiff line numberDiff line change
@@ -14,43 +14,23 @@
1414
# limitations under the License.
1515
"""Megatron-Bridge plugins for using with Model-Optimizer."""
1616

17-
from collections.abc import Callable
1817
from typing import Any
1918

20-
import torch.nn as nn
21-
from datasets import DatasetDict
2219
from megatron.bridge import AutoBridge
23-
from megatron.bridge.data.builders.hf_dataset import HFDatasetConfig
24-
from megatron.bridge.data.loaders import setup_data_iterators
25-
from megatron.bridge.data.utils import get_dataset_provider
2620
from megatron.bridge.models.gpt_provider import GPTModelProvider
2721
from megatron.bridge.models.hf_pretrained.utils import is_safe_repo
2822
from megatron.bridge.models.mamba.mamba_provider import MambaModelProvider
29-
from megatron.bridge.training.config import (
30-
CheckpointConfig,
31-
ConfigContainer,
32-
LoggerConfig,
33-
OptimizerConfig,
34-
SchedulerConfig,
35-
TrainingConfig,
36-
runtime_config_update,
37-
)
38-
from megatron.bridge.training.eval import evaluate_and_print_results
39-
from megatron.bridge.training.gpt_step import forward_step
40-
from megatron.bridge.training.state import GlobalState
41-
from megatron.bridge.training.tokenizers.config import TokenizerConfig
4223
from megatron.core.models.gpt import GPTModel
4324
from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec
4425
from megatron.core.models.mamba import MambaModel
45-
from megatron.core.parallel_state import get_data_parallel_group
4626
from megatron.core.transformer.module import MegatronModule
4727
from megatron.core.utils import unwrap_model
4828
from transformers import AutoTokenizer
4929

5030
from modelopt.torch.nas.plugins.megatron import get_te_mamba_stack_spec
51-
from modelopt.torch.utils import get_dataset_samples, print_rank_0, warn_rank_0
31+
from modelopt.torch.utils import print_rank_0
5232

53-
__all__ = ["get_hf_mbridge_calibration_loop", "load_mbridge_model_from_hf"]
33+
__all__ = ["load_mbridge_model_from_hf"]
5434

5535

5636
def load_mbridge_model_from_hf(
@@ -118,134 +98,3 @@ def load_mbridge_model_from_hf(
11898
)
11999

120100
return bridge, provider, model, unwrapped_model, tokenizer
121-
122-
123-
def _get_dataset_cfg(
124-
dataset_name: str,
125-
num_samples: int,
126-
seq_length: int,
127-
apply_chat_template: bool = True,
128-
tokenizer: AutoTokenizer | None = None,
129-
) -> HFDatasetConfig:
130-
"""Get a dataset config for the dataset."""
131-
dataset = get_dataset_samples(
132-
dataset_name, num_samples, apply_chat_template=apply_chat_template, tokenizer=tokenizer
133-
)
134-
dataset_cfg = HFDatasetConfig(
135-
dataset_name=f"{dataset_name}_{num_samples}",
136-
dataset_dict=DatasetDict({"train": dataset}),
137-
process_example_fn=lambda example, tokenizer: {"input": example, "output": ""},
138-
seq_length=seq_length,
139-
dataloader_type="batch",
140-
num_workers=1,
141-
do_validation=False,
142-
do_test=False,
143-
val_proportion=None,
144-
split_val_from_train=False,
145-
rewrite=True,
146-
)
147-
148-
return dataset_cfg
149-
150-
151-
def get_hf_mbridge_calibration_loop(
152-
*,
153-
model: list[MegatronModule],
154-
provider: GPTModelProvider | MambaModelProvider,
155-
tokenizer: AutoTokenizer,
156-
hf_model_name_or_path: str,
157-
trust_remote_code: bool = False,
158-
dataset_name: str = "nemotron-post-training-dataset-v2",
159-
num_samples: int = 512,
160-
micro_batch_size: int = 1,
161-
global_batch_size: int = 1,
162-
) -> Callable[[nn.Module], None]:
163-
"""Get a modelopt calibration loop for a Megatron-Bridge model.
164-
165-
Args:
166-
model: The model to calibrate.
167-
provider: The provider to use for the model.
168-
tokenizer: The tokenizer to use for the model.
169-
hf_model_name_or_path: The name or path of the HF model.
170-
trust_remote_code: Whether to trust remote code.
171-
dataset_name: The name of the dataset to use for evaluation.
172-
num_samples: The number of samples to use for evaluation.
173-
micro_batch_size: The micro batch size to use for evaluation.
174-
global_batch_size: The global batch size to use for evaluation.
175-
176-
Returns:
177-
A function that can be used to calibrate the model with a modelopt.torch API.
178-
"""
179-
if global_batch_size < micro_batch_size:
180-
warn_rank_0(
181-
f"{global_batch_size=} is smaller than {micro_batch_size=}. Setting gbs to {micro_batch_size}."
182-
)
183-
global_batch_size = micro_batch_size
184-
num_iters = num_samples // global_batch_size
185-
186-
cfg = ConfigContainer(
187-
model=provider,
188-
train=TrainingConfig(
189-
micro_batch_size=micro_batch_size,
190-
global_batch_size=global_batch_size,
191-
train_iters=num_iters,
192-
eval_iters=num_iters,
193-
skip_train=True,
194-
),
195-
# TODO: Replace validation args in train with validation config in nemo:26.04
196-
# validation=ValidationConfig(eval_iters=num_iters, eval_interval=1, skip_train=True),
197-
dataset=_get_dataset_cfg(
198-
dataset_name,
199-
num_samples,
200-
provider.seq_length,
201-
apply_chat_template=True,
202-
tokenizer=tokenizer,
203-
),
204-
tokenizer=TokenizerConfig(
205-
tokenizer_type="HuggingFaceTokenizer",
206-
tokenizer_model=hf_model_name_or_path,
207-
# NOTE: Issue with Nemotron Nano v2 tokenizer returning bool hence using use_fast=True as a WAR
208-
hf_tokenizer_kwargs={
209-
"trust_remote_code": trust_remote_code,
210-
"use_fast": tokenizer.is_fast,
211-
},
212-
),
213-
# Unused
214-
optimizer=OptimizerConfig(optimizer="adam", lr=1e-4, use_distributed_optimizer=False),
215-
scheduler=SchedulerConfig(lr_decay_style="constant"),
216-
logger=LoggerConfig(),
217-
checkpoint=CheckpointConfig(),
218-
)
219-
runtime_config_update(cfg)
220-
221-
state = GlobalState()
222-
state.cfg = cfg
223-
224-
dataset_provider = get_dataset_provider(cfg.dataset)
225-
226-
def _train_valid_test_datasets_provider(
227-
train_val_test_num_samples: tuple, dataset_cfg: HFDatasetConfig
228-
):
229-
return dataset_provider(train_val_test_num_samples, dataset_cfg, tokenizer=state.tokenizer)
230-
231-
train_data_iterator, _, _ = setup_data_iterators(
232-
cfg=cfg,
233-
train_state=state.train_state,
234-
model_length=len(model),
235-
train_valid_test_datasets_provider=_train_valid_test_datasets_provider,
236-
dp_group=get_data_parallel_group(),
237-
)
238-
239-
def forward_loop(m):
240-
evaluate_and_print_results(
241-
state,
242-
prefix="iteration 1",
243-
forward_step_func=forward_step,
244-
data_iterator=train_data_iterator,
245-
model=model,
246-
config=cfg,
247-
verbose=True,
248-
write_to_tensorboard=False,
249-
)
250-
251-
return forward_loop

0 commit comments

Comments
 (0)