Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/megatron_bridge/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ torchrun --nproc_per_node 2 prune_minitron.py \
--hf_model_name_or_path Qwen/Qwen3-8B \
--prune_target_memory_mb 12288 \
--seq_length 4096 \
--calib_mbs 1 \
--calib_batch_size 1 \
--output_hf_path /tmp/Qwen3-8B-Pruned-12GB
```

Expand Down
32 changes: 12 additions & 20 deletions examples/megatron_bridge/prune_minitron.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,8 @@
import modelopt.torch.prune as mtp
import modelopt.torch.utils.distributed as dist
from modelopt.torch.utils import get_supported_datasets, print_rank_0, warn_rank_0
from modelopt.torch.utils.plugins.mbridge import (
get_hf_mbridge_calibration_loop,
load_mbridge_model_from_hf,
)
from modelopt.torch.utils.plugins.mbridge import load_mbridge_model_from_hf
from modelopt.torch.utils.plugins.megatron_calibration import get_megatron_calibration_forward_loop
from modelopt.torch.utils.plugins.megatron_mmlu import megatron_mmlu


Expand Down Expand Up @@ -104,11 +102,7 @@ def get_args() -> argparse.Namespace:
"--calib_num_samples", type=int, default=1024, help="Number of samples for calibration"
)
# TODO: Add support for pre-training dataset (pre-tokenized)
# TODO: only allow mbs>1 for pretraining dataset
parser.add_argument(
"--calib_mbs", type=int, default=1, choices=[1], help="Calibration micro-batch size"
)
parser.add_argument("--calib_gbs", type=int, default=1, help="Calibration global batch size")
parser.add_argument("--calib_batch_size", type=int, default=1, help="Calibration batch size")
parser.add_argument("--seq_length", type=int, default=4096)
# Pruning parameters
parser.add_argument(
Expand Down Expand Up @@ -164,8 +158,8 @@ def get_args() -> argparse.Namespace:
default=None,
help=(
"Batch size used only for KV-cache sizing in --prune_target_memory_mb. "
"Defaults to --calib_mbs when not set. "
"Use this to target an inference batch size that differs from the calibration micro-batch size."
"Defaults to --calib_batch_size when not set. "
"Use this to target an inference batch size that differs from the calibration batch size."
),
)

Expand Down Expand Up @@ -296,16 +290,12 @@ def main(args: argparse.Namespace):
init_model_parallel=True,
moe_grouped_gemm=False,
)
forward_loop = get_hf_mbridge_calibration_loop(
model=model,
provider=provider,
tokenizer=tokenizer,
hf_model_name_or_path=args.hf_model_name_or_path,
trust_remote_code=args.trust_remote_code,
forward_loop = get_megatron_calibration_forward_loop(
tokenizer,
dataset_name=args.calib_dataset_name,
num_samples=args.calib_num_samples,
micro_batch_size=args.calib_mbs,
global_batch_size=args.calib_gbs,
seq_length=args.seq_length,
batch_size=args.calib_batch_size,
)

pruning_config = {
Expand Down Expand Up @@ -385,7 +375,9 @@ def score_func(m):
pruning_config["top_k"] = args.top_k
# memory_mb constraint requires batch_size and seq_length
pruning_config["batch_size"] = (
args.inference_batch_size if args.inference_batch_size is not None else args.calib_mbs
args.inference_batch_size
if args.inference_batch_size is not None
else args.calib_batch_size
)
Comment thread
kevalmorabia97 marked this conversation as resolved.
pruning_config["seq_length"] = args.seq_length
print_rank_0(f"Pruning constraints: {pruning_constraints}")
Expand Down
14 changes: 6 additions & 8 deletions examples/pruning/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,9 @@ Please see example snippets of both modes for Minitron pruning on Megatron-Bridg
```python
import torch
import modelopt.torch.prune as mtp
from modelopt.torch.utils.plugins.mbridge import (
get_hf_mbridge_calibration_loop,
load_mbridge_model_from_hf,
from modelopt.torch.utils.plugins.mbridge import load_mbridge_model_from_hf
from modelopt.torch.utils.plugins.megatron_calibration import (
get_megatron_calibration_forward_loop,
)

# Import the Megatron-Bridge Qwen3-8B model from Hugging Face checkpoint
Expand All @@ -67,13 +67,11 @@ bridge, provider, model, unwrapped_model, tokenizer = load_mbridge_model_from_hf
)

# Set up the forward loop to run on 1024 train samples
forward_loop = get_hf_mbridge_calibration_loop(
model=model,
provider=provider,
tokenizer=tokenizer,
hf_model_name_or_path="Qwen/Qwen3-8B",
forward_loop = get_megatron_calibration_forward_loop(
tokenizer,
dataset_name="nemotron-post-training-dataset-v2",
num_samples=1024,
seq_length=4096,
)

# Run pruning on the unwrapped model
Expand Down
2 changes: 1 addition & 1 deletion modelopt/torch/utils/dataset_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,7 +563,7 @@ def get_dataset_dataloader(
batch_size: int = 1,
num_samples: int | list[int] = 512,
max_sample_length: int = 512,
device: torch.device | None = None,
device: torch.device | str | None = None,
include_labels: bool = False,
apply_chat_template: bool = False,
) -> DataLoader:
Expand Down
3 changes: 3 additions & 0 deletions modelopt/torch/utils/plugins/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@

from modelopt.torch.utils import import_plugin

with import_plugin("megatron_calibration"):
from .megatron_calibration import *

with import_plugin("megatron_generate"):
from .megatron_generate import *

Expand Down
155 changes: 2 additions & 153 deletions modelopt/torch/utils/plugins/mbridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,43 +14,23 @@
# limitations under the License.
"""Megatron-Bridge plugins for using with Model-Optimizer."""

from collections.abc import Callable
from typing import Any

import torch.nn as nn
from datasets import DatasetDict
from megatron.bridge import AutoBridge
from megatron.bridge.data.builders.hf_dataset import HFDatasetConfig
from megatron.bridge.data.loaders import setup_data_iterators
from megatron.bridge.data.utils import get_dataset_provider
from megatron.bridge.models.gpt_provider import GPTModelProvider
from megatron.bridge.models.hf_pretrained.utils import is_safe_repo
from megatron.bridge.models.mamba.mamba_provider import MambaModelProvider
from megatron.bridge.training.config import (
CheckpointConfig,
ConfigContainer,
LoggerConfig,
OptimizerConfig,
SchedulerConfig,
TrainingConfig,
runtime_config_update,
)
from megatron.bridge.training.eval import evaluate_and_print_results
from megatron.bridge.training.gpt_step import forward_step
from megatron.bridge.training.state import GlobalState
from megatron.bridge.training.tokenizers.config import TokenizerConfig
from megatron.core.models.gpt import GPTModel
from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec
from megatron.core.models.mamba import MambaModel
from megatron.core.parallel_state import get_data_parallel_group
from megatron.core.transformer.module import MegatronModule
from megatron.core.utils import unwrap_model
from transformers import AutoTokenizer

from modelopt.torch.nas.plugins.megatron import get_te_mamba_stack_spec
from modelopt.torch.utils import get_dataset_samples, print_rank_0, warn_rank_0
from modelopt.torch.utils import print_rank_0

__all__ = ["get_hf_mbridge_calibration_loop", "load_mbridge_model_from_hf"]
__all__ = ["load_mbridge_model_from_hf"]


def load_mbridge_model_from_hf(
Expand Down Expand Up @@ -118,134 +98,3 @@ def load_mbridge_model_from_hf(
)

return bridge, provider, model, unwrapped_model, tokenizer


def _get_dataset_cfg(
dataset_name: str,
num_samples: int,
seq_length: int,
apply_chat_template: bool = True,
tokenizer: AutoTokenizer | None = None,
) -> HFDatasetConfig:
"""Get a dataset config for the dataset."""
dataset = get_dataset_samples(
dataset_name, num_samples, apply_chat_template=apply_chat_template, tokenizer=tokenizer
)
dataset_cfg = HFDatasetConfig(
dataset_name=f"{dataset_name}_{num_samples}",
dataset_dict=DatasetDict({"train": dataset}),
process_example_fn=lambda example, tokenizer: {"input": example, "output": ""},
seq_length=seq_length,
dataloader_type="batch",
num_workers=1,
do_validation=False,
do_test=False,
val_proportion=None,
split_val_from_train=False,
rewrite=True,
)

return dataset_cfg


def get_hf_mbridge_calibration_loop(
*,
model: list[MegatronModule],
provider: GPTModelProvider | MambaModelProvider,
tokenizer: AutoTokenizer,
hf_model_name_or_path: str,
trust_remote_code: bool = False,
dataset_name: str = "nemotron-post-training-dataset-v2",
num_samples: int = 512,
micro_batch_size: int = 1,
global_batch_size: int = 1,
) -> Callable[[nn.Module], None]:
"""Get a modelopt calibration loop for a Megatron-Bridge model.

Args:
model: The model to calibrate.
provider: The provider to use for the model.
tokenizer: The tokenizer to use for the model.
hf_model_name_or_path: The name or path of the HF model.
trust_remote_code: Whether to trust remote code.
dataset_name: The name of the dataset to use for evaluation.
num_samples: The number of samples to use for evaluation.
micro_batch_size: The micro batch size to use for evaluation.
global_batch_size: The global batch size to use for evaluation.

Returns:
A function that can be used to calibrate the model with a modelopt.torch API.
"""
if global_batch_size < micro_batch_size:
warn_rank_0(
f"{global_batch_size=} is smaller than {micro_batch_size=}. Setting gbs to {micro_batch_size}."
)
global_batch_size = micro_batch_size
num_iters = num_samples // global_batch_size

cfg = ConfigContainer(
model=provider,
train=TrainingConfig(
micro_batch_size=micro_batch_size,
global_batch_size=global_batch_size,
train_iters=num_iters,
eval_iters=num_iters,
skip_train=True,
),
# TODO: Replace validation args in train with validation config in nemo:26.04
# validation=ValidationConfig(eval_iters=num_iters, eval_interval=1, skip_train=True),
dataset=_get_dataset_cfg(
dataset_name,
num_samples,
provider.seq_length,
apply_chat_template=True,
tokenizer=tokenizer,
),
tokenizer=TokenizerConfig(
tokenizer_type="HuggingFaceTokenizer",
tokenizer_model=hf_model_name_or_path,
# NOTE: Issue with Nemotron Nano v2 tokenizer returning bool hence using use_fast=True as a WAR
hf_tokenizer_kwargs={
"trust_remote_code": trust_remote_code,
"use_fast": tokenizer.is_fast,
},
),
# Unused
optimizer=OptimizerConfig(optimizer="adam", lr=1e-4, use_distributed_optimizer=False),
scheduler=SchedulerConfig(lr_decay_style="constant"),
logger=LoggerConfig(),
checkpoint=CheckpointConfig(),
)
runtime_config_update(cfg)

state = GlobalState()
state.cfg = cfg

dataset_provider = get_dataset_provider(cfg.dataset)

def _train_valid_test_datasets_provider(
train_val_test_num_samples: tuple, dataset_cfg: HFDatasetConfig
):
return dataset_provider(train_val_test_num_samples, dataset_cfg, tokenizer=state.tokenizer)

train_data_iterator, _, _ = setup_data_iterators(
cfg=cfg,
train_state=state.train_state,
model_length=len(model),
train_valid_test_datasets_provider=_train_valid_test_datasets_provider,
dp_group=get_data_parallel_group(),
)

def forward_loop(m):
evaluate_and_print_results(
state,
prefix="iteration 1",
forward_step_func=forward_step,
data_iterator=train_data_iterator,
model=model,
config=cfg,
verbose=True,
write_to_tensorboard=False,
)

return forward_loop
Loading
Loading