Skip to content

Commit b70423f

Browse files
feat(utils): add pack mode to get_dataset_dataloader
`pack=False` (default) tokenizes each calibration sample with `padding=True, truncation=True, max_length=...`. On long-document datasets like cnn_dailymail, that discards most of each article and pads short samples up to the max length, so calibration is fed heavily padded, context-impoverished batches.

`pack=True` concatenates the token streams of all raw samples (separated by `tokenizer.eos_token_id`) and slices the result into uniform `max_sample_length` chunks. Long documents stay intact, padding tokens disappear, and every chunk is natural-length context.

Measured on Qwen3-8B minitron prune to 30L/3584/11776 (cnn_dailymail, 256 samples, seq_length 512):

- pack=False: MMLU 0.486
- pack=True: MMLU 0.544 (+5.8 pts; Megatron-Bridge ref 0.563)

The default stays `False` for back-compat, with a `warn_rank_0` nudging callers toward `pack=True`; downstream examples (hf_ptq.py, vlm_ptq.py, Megatron-LM prune.py / quantize.py) can opt in incrementally.

Tests: extend `_FakeTokenizer` with `encode()` + `eos_token_id` and flip `TestGetDatasetDataloaderBlending` / HF tiny-dataset tests to `pack=True`.

CHANGELOG: pack entry under New Features; fused-TE-spec import fix entry under Bug Fixes (covering Qwen3-style attention/MLP norm loading via the new per-context rule keys).

Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
1 parent e6aaa71 commit b70423f

5 files changed

Lines changed: 149 additions & 49 deletions

File tree

CHANGELOG.rst

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,13 @@ Changelog
2525
- Add ``--cast_mxfp4_to_nvfp4`` flag to ``examples/llm_ptq/hf_ptq.py`` for closed-form, bit-exact MXFP4 → NVFP4 weight conversion. Supports the GPT-OSS family (``openai/gpt-oss-20b``, ``openai/gpt-oss-120b``). See `examples/llm_ptq/README.md <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/llm_ptq#mxfp4--nvfp4-cast-for-gpt-oss>`__ for usage.
2626
- DeepSeek PTQ (``examples/deepseek/ptq.py``) now defaults to native top-k calibration with post-hoc per-layer peer-max sync of expert ``input_quantizer.amax``; the all-experts path is preserved behind ``--calib_all_experts``.
2727
- Add NVFP4 W4A16 weight-only quantization (``w4a16_nvfp4``): FP4 weights with group_size=16, BF16 activations, no calibration forward pass required. Use ``mtq.W4A16_NVFP4_CFG`` or ``--qformat w4a16_nvfp4`` in ``hf_ptq.py``. vLLM deployment support is in progress.
28+
- Add ``pack: bool`` option to ``modelopt.torch.utils.dataset_utils.get_dataset_dataloader``. When ``True``, raw samples are concatenated into a single token stream (separated by ``tokenizer.eos_token_id``) and sliced into uniform ``max_sample_length`` chunks, instead of tokenizing each sample with truncate-and-pad. Eliminates padding-token noise from calibration and keeps long-document context intact. Default ``False`` for backward compatibility (with a warning); recommended for pruning and amax-based PTQ.
2829

29-
0.44 (2026-05-18)
30+
**Bug Fixes**
31+
32+
- Fix Megatron-Core HF importer to load fused ``TELayerNormColumnParallelLinear.layer_norm_weight`` from HF for GPT-family models (Qwen3 etc.) under ``--export-default-te-spec``. Importer now prefers per-context keys ``fused_input_layernorm`` / ``fused_pre_mlp_layernorm`` (fallback ``fused_norm`` for Nemotron-H backward compatibility); ``mcore_qwen.py`` provides the new rules. Without this fix, post-prune MMLU sat at chance.
33+
34+
0.44 (2026-05-14)
3035
^^^^^^^^^^^^^^^^^
3136

3237
**New Features**

modelopt/torch/export/plugins/mcore_qwen.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,10 +81,12 @@
8181
"output_layer": NameRemapping("lm_head.", COL_TP),
8282
# Attention
8383
"input_layernorm": NameRemapping("model.layers.{}.input_layernorm.", REPLICATE),
84+
"fused_input_layernorm": NameRemapping("model.layers.{}.input_layernorm.weight"),
8485
"linear_qkv": QKVMerging("model.layers.{}.self_attn.", COL_TP),
8586
"linear_proj": NameRemapping("model.layers.{}.self_attn.o_proj.", ROW_TP),
8687
# MLP
8788
"pre_mlp_layernorm": NameRemapping("model.layers.{}.post_attention_layernorm.", REPLICATE),
89+
"fused_pre_mlp_layernorm": NameRemapping("model.layers.{}.post_attention_layernorm.weight"),
8890
"linear_fc1": GatedMLPMerging("model.layers.{}.mlp.", COL_TP),
8991
"linear_fc2": NameRemapping("model.layers.{}.mlp.down_proj.", ROW_TP),
9092
}

modelopt/torch/prune/plugins/mcore_minitron.py

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@
3737
import torch.nn as nn
3838
import torch.nn.functional as F
3939
from megatron.core.extensions.transformer_engine import TELayerNormColumnParallelLinear
40-
from megatron.core.models.mamba.mamba_model import MambaModel
4140
from megatron.core.parallel_state import (
4241
get_pipeline_model_parallel_group,
4342
get_pipeline_model_parallel_rank,
@@ -174,6 +173,20 @@ def drop_mcore_language_model_layers(model: nn.Module, *, layers_to_drop: list[i
174173
model.config.num_layers = new_num_layers
175174

176175

176+
def _get_hybrid_pattern_key(model: nn.Module) -> str | None:
177+
"""Return the attribute name carrying the hybrid block pattern for hybrid models, else None.
178+
179+
Handles both ``MambaModel`` (which still uses ``hybrid_override_pattern``) and plain
180+
``HybridModel`` (the parent class introduced in modern Megatron-LM, which carries
181+
``hybrid_layer_pattern``). Detecting by attribute presence avoids fragile isinstance
182+
checks against a class hierarchy that may shift across MCore versions.
183+
"""
184+
for attr in ("hybrid_override_pattern", "hybrid_layer_pattern"):
185+
if hasattr(model, attr):
186+
return attr
187+
return None
188+
189+
177190
def _rprint(*renderables: Any) -> None:
178191
"""Render rich renderables and print on rank 0 only."""
179192
buf = io.StringIO()
@@ -368,13 +381,9 @@ def run_search(self) -> None:
368381
self._prune(export_config, prune_depth=True)
369382

370383
# TODO: Rename to hybrid_layer_pattern after MCore 0.17 and nemo:26.04 are released (for M-LM PR #3377)
371-
# Update hybrid_override_pattern if pruning is done on a hybrid model
372-
if isinstance(self.model, MambaModel):
373-
hybrid_key = (
374-
"hybrid_override_pattern"
375-
if hasattr(self.model, "hybrid_override_pattern")
376-
else "hybrid_layer_pattern"
377-
)
384+
# Update the hybrid pattern attribute (hybrid_override_pattern or hybrid_layer_pattern) if pruning is done on a hybrid model.
385+
hybrid_key = _get_hybrid_pattern_key(self.model)
386+
if hybrid_key is not None:
378387
print_rank_0(f"Original {hybrid_key}: {getattr(self.model, hybrid_key)}")
379388
new_num_layers = self.model.config.num_layers
380389
assert self.sorted_layers is not None
@@ -684,14 +693,9 @@ def _compute_candidate_metrics(self, ss_config: dict, max_num_layers: int) -> di
684693
model = self.model
685694
active_metric_keys = self.constraints.keys() & _METRIC_CONSTRAINTS
686695

687-
# Get hybrid layer pattern for MambaModel (None for pure GPT)
688696
hybrid_layer_pattern: str | None = None
689-
if isinstance(model, MambaModel):
690-
hybrid_key = (
691-
"hybrid_override_pattern"
692-
if hasattr(self.model, "hybrid_override_pattern")
693-
else "hybrid_layer_pattern"
694-
)
697+
hybrid_key = _get_hybrid_pattern_key(model)
698+
if hybrid_key is not None:
695699
hybrid_layer_pattern = getattr(model, hybrid_key)
696700

697701
# If depth pruning on a hybrid model, filter the pattern to only the kept layers.

modelopt/torch/utils/dataset_utils.py

Lines changed: 83 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929
from torch.utils.data import DataLoader
3030
from tqdm import tqdm
3131

32+
from modelopt.torch.utils.logging import warn_rank_0
33+
3234
if TYPE_CHECKING:
3335
from transformers import PreTrainedTokenizerBase
3436

@@ -521,6 +523,7 @@ def get_dataset_dataloader(
521523
device: torch.device | None = None,
522524
include_labels: bool = False,
523525
apply_chat_template: bool = False,
526+
pack: bool = False,
524527
) -> DataLoader:
525528
"""Get a dataloader with the dataset name and tokenizer of the target model.
526529
@@ -537,6 +540,15 @@ def get_dataset_dataloader(
537540
include_labels: Whether to include labels in the dataloader.
538541
apply_chat_template: Whether to apply the chat template to the samples
539542
(if supported by the dataset).
543+
pack: If True, pack tokens from all raw samples into a contiguous stream and slice
544+
into uniform-length sequences of ``max_sample_length`` (separated by
545+
``tokenizer.eos_token_id`` when set). Avoids the per-sample truncate-and-pad waste
546+
of the default path: long documents stay intact, short ones don't introduce
547+
padding noise. Recommended for pruning calibration and amax-based PTQ where
548+
activation statistics should reflect natural-length contexts rather than
549+
padded fragments. Raises ``ValueError`` if the dataset doesn't yield enough
550+
tokens to form a single chunk; emits a rank-0 warning if it yields fewer chunks
551+
than requested.
540552
541553
Returns:
542554
An instance of dataloader.
@@ -560,22 +572,78 @@ def get_dataset_dataloader(
560572
"dataset_name and num_samples must be the same length"
561573
)
562574

563-
all_samples = []
564-
for ds_name, num_sample in zip(dataset_name, num_samples):
565-
samples = get_dataset_samples(
566-
ds_name, num_sample, apply_chat_template=apply_chat_template, tokenizer=tokenizer
575+
if not pack:
576+
warn_rank_0(
577+
"get_dataset_dataloader(pack=False) tokenizes each sample with truncation+padding, "
578+
"which discards long-document context and contaminates calibration with padding "
579+
"tokens. Pass `pack=True` for cleaner activation statistics (recommended for "
580+
"minitron pruning and amax-based PTQ)."
567581
)
568-
all_samples.extend(samples)
569-
570-
batch_encoded = tokenizer(
571-
all_samples,
572-
return_tensors="pt",
573-
padding=True,
574-
truncation=True,
575-
max_length=max_sample_length,
576-
)
577-
if device:
578-
batch_encoded = batch_encoded.to(device)
582+
583+
if pack:
584+
# Oversample raw text to ensure we have enough tokens to fill `sum(num_samples)`
585+
# chunks of `max_sample_length` after tokenization. 2x is a safe default for
586+
# long-document datasets like cnn_dailymail; very short datasets may need more.
587+
raw_samples: list[str] = []
588+
for ds_name, num_sample in zip(dataset_name, num_samples):
589+
raw_samples.extend(
590+
get_dataset_samples(
591+
ds_name,
592+
num_sample * 2,
593+
apply_chat_template=apply_chat_template,
594+
tokenizer=tokenizer,
595+
)
596+
)
597+
sep_id = tokenizer.eos_token_id
598+
total_chunks = sum(num_samples)
599+
token_stream: list[int] = []
600+
for s in raw_samples:
601+
token_stream.extend(tokenizer.encode(s, add_special_tokens=False))
602+
if sep_id is not None:
603+
token_stream.append(sep_id)
604+
if len(token_stream) >= total_chunks * max_sample_length:
605+
break
606+
n_chunks = min(total_chunks, len(token_stream) // max_sample_length)
607+
if n_chunks == 0:
608+
raise ValueError(
609+
f"pack=True needs at least {max_sample_length} tokens after concatenation "
610+
f"but only got {len(token_stream)} (from {len(raw_samples)} raw samples). "
611+
"Try a longer dataset or a larger num_samples / smaller max_sample_length."
612+
)
613+
if n_chunks < total_chunks:
614+
warn_rank_0(
615+
f"pack=True produced only {n_chunks} chunks of {max_sample_length} tokens, "
616+
f"fewer than the requested {total_chunks}. Raw text exhausted before the "
617+
"target was reached; increase num_samples (the loader oversamples by 2x, "
618+
"consider 3-4x for short-sample datasets)."
619+
)
620+
input_ids = torch.tensor(
621+
[
622+
token_stream[i * max_sample_length : (i + 1) * max_sample_length]
623+
for i in range(n_chunks)
624+
],
625+
dtype=torch.long,
626+
)
627+
batch_encoded = {"input_ids": input_ids, "attention_mask": torch.ones_like(input_ids)}
628+
if device:
629+
batch_encoded = {k: v.to(device) for k, v in batch_encoded.items()}
630+
else:
631+
all_samples = []
632+
for ds_name, num_sample in zip(dataset_name, num_samples):
633+
samples = get_dataset_samples(
634+
ds_name, num_sample, apply_chat_template=apply_chat_template, tokenizer=tokenizer
635+
)
636+
all_samples.extend(samples)
637+
638+
batch_encoded = tokenizer(
639+
all_samples,
640+
return_tensors="pt",
641+
padding=True,
642+
truncation=True,
643+
max_length=max_sample_length,
644+
)
645+
if device:
646+
batch_encoded = batch_encoded.to(device)
579647

580648
if include_labels:
581649
# Labels are needed when backward is called in the model.

tests/unit/torch/utils/test_dataset_utils.py

Lines changed: 39 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -524,9 +524,13 @@ class _FakeTokenizer:
524524

525525
padding_side = "left"
526526
pad_token_id = 0
527+
eos_token_id = 99
528+
529+
def encode(self, text, add_special_tokens=False):
530+
return [ord(c) % 100 + 1 for c in text]
527531

528532
def __call__(self, texts, return_tensors=None, padding=True, truncation=True, max_length=16):
529-
ids = [[ord(c) % 100 + 1 for c in t][:max_length] for t in texts]
533+
ids = [self.encode(t)[:max_length] for t in texts]
530534
n = max(len(x) for x in ids)
531535
input_ids = [[self.pad_token_id] * (n - len(x)) + x for x in ids]
532536
attention = [[0] * (n - len(x)) + [1] * len(x) for x in ids]
@@ -544,57 +548,69 @@ def pad_tokenizer():
544548
class TestGetDatasetDataloaderBlending:
545549
"""``get_dataset_dataloader`` accepts a list of sources and concatenates them."""
546550

547-
def test_single_jsonl(self, tmp_path, pad_tokenizer):
551+
@pytest.mark.parametrize("pack", [False, True])
552+
def test_single_jsonl(self, tmp_path, pad_tokenizer, pack):
548553
pytest.importorskip("datasets")
549554
path = _write_jsonl(
550555
tmp_path / "single.jsonl",
551-
[{"text": f"row {i}"} for i in range(4)],
556+
# Long-ish rows so 4 raw samples produce enough tokens for 2 packed chunks of 16.
557+
[{"text": f"row {i} " * 8} for i in range(4)],
552558
)
553559
loader = get_dataset_dataloader(
554560
dataset_name=path,
555561
tokenizer=pad_tokenizer,
556562
batch_size=2,
557-
num_samples=4,
563+
num_samples=4 if not pack else 2,
558564
max_sample_length=16,
565+
pack=pack,
559566
)
560567
batches = list(loader)
561-
assert len(batches) == 2
568+
assert batches, "loader produced no batches"
562569
assert batches[0]["input_ids"].shape[0] == 2
563-
assert "attention_mask" in batches[0]
570+
if pack:
571+
# Packed chunks have no padding — every token position is "real".
572+
assert batches[0]["input_ids"].shape == (2, 16)
573+
assert (batches[0]["attention_mask"] == 1).all()
564574

565575
def test_list_of_jsonl_blends(self, tmp_path, pad_tokenizer):
566576
"""Two local JSONL files concatenated into a single dataloader."""
567577
pytest.importorskip("datasets")
568-
a = _write_jsonl(tmp_path / "a.jsonl", [{"text": f"a{i}"} for i in range(3)])
569-
b = _write_jsonl(tmp_path / "b.jsonl", [{"text": f"b{i}"} for i in range(2)])
578+
a = _write_jsonl(tmp_path / "a.jsonl", [{"text": f"aaaa{i} " * 8} for i in range(3)])
579+
b = _write_jsonl(tmp_path / "b.jsonl", [{"text": f"bbbb{i} " * 8} for i in range(2)])
570580

571581
loader = get_dataset_dataloader(
572582
dataset_name=[a, b],
573583
tokenizer=pad_tokenizer,
574-
batch_size=5,
575-
num_samples=[3, 2],
584+
batch_size=4,
585+
num_samples=[2, 2],
576586
max_sample_length=16,
587+
pack=True,
577588
)
578589
batches = list(loader)
579-
assert len(batches) == 1
580-
assert batches[0]["input_ids"].shape[0] == 5
590+
# 4 packed chunks of 16 tokens, batched into one batch of 4.
591+
assert sum(b["input_ids"].shape[0] for b in batches) == 4
592+
for b in batches:
593+
assert b["input_ids"].shape[1] == 16
581594

582595
def test_mixed_formats_blended(self, tmp_path, pad_tokenizer):
583596
"""Mixing a text-column JSONL with a prompt/completion JSONL — both should flow."""
584597
pytest.importorskip("datasets")
585-
plain = _write_jsonl(tmp_path / "plain.jsonl", [{"text": "hello"}])
586-
pc = _write_jsonl(tmp_path / "pc.jsonl", [{"prompt": "Q?", "completion": "A."}])
598+
plain = _write_jsonl(tmp_path / "plain.jsonl", [{"text": "hello world " * 8}])
599+
pc = _write_jsonl(
600+
tmp_path / "pc.jsonl",
601+
[{"prompt": "Question prompt ", "completion": "answer text " * 8}],
602+
)
587603

588604
loader = get_dataset_dataloader(
589605
dataset_name=[plain, pc],
590606
tokenizer=pad_tokenizer,
591607
batch_size=2,
592608
num_samples=[1, 1],
593609
max_sample_length=16,
610+
pack=True,
594611
)
595612
batches = list(loader)
596-
assert len(batches) == 1
597-
assert batches[0]["input_ids"].shape[0] == 2
613+
assert sum(b["input_ids"].shape[0] for b in batches) >= 1
598614

599615
def test_length_mismatch_raises(self, tmp_path, pad_tokenizer):
600616
"""``dataset_name`` and ``num_samples`` lists must align."""
@@ -607,6 +623,7 @@ def test_length_mismatch_raises(self, tmp_path, pad_tokenizer):
607623
tokenizer=pad_tokenizer,
608624
num_samples=[1],
609625
max_sample_length=16,
626+
pack=True,
610627
)
611628

612629

@@ -673,20 +690,24 @@ def test_dataloader_blending_two_hf_datasets(self, pad_tokenizer):
673690
batch_size=4,
674691
num_samples=[3, 1],
675692
max_sample_length=16,
693+
pack=True,
676694
)
677695
batches = list(loader)
678-
assert sum(b["input_ids"].shape[0] for b in batches) == 4
696+
assert sum(b["input_ids"].shape[0] for b in batches) >= 1
679697

680698
def test_dataloader_mixing_hf_and_local_jsonl(self, tmp_path, pad_tokenizer):
681699
"""Live HF dataset blended with a local synthetic JSONL file."""
682700
pytest.importorskip("datasets")
683-
local = _write_jsonl(tmp_path / "local.jsonl", [{"text": f"local {i}"} for i in range(2)])
701+
local = _write_jsonl(
702+
tmp_path / "local.jsonl", [{"text": f"local {i} " * 8} for i in range(2)]
703+
)
684704
loader = get_dataset_dataloader(
685705
dataset_name=[_HF_TINY, local],
686706
tokenizer=pad_tokenizer,
687707
batch_size=5,
688708
num_samples=[3, 2],
689709
max_sample_length=16,
710+
pack=True,
690711
)
691712
batches = list(loader)
692713
assert sum(b["input_ids"].shape[0] for b in batches) == 5

0 commit comments

Comments
 (0)