Skip to content

Commit 20d3c5b

Browse files
feat(utils): add pack mode to get_dataset_dataloader
`pack=False` (default) tokenizes each calibration sample with `padding=True, truncation=True, max_length=...` — on long-document datasets like cnn_dailymail that discards most of each article and pads short samples up to the max, feeding calibration heavily padded and context-impoverished batches. `pack=True` concatenates the token streams of all raw samples (separated by `tokenizer.eos_token_id`) and slices into uniform `max_sample_length` chunks. Long documents stay intact, padding tokens disappear, every chunk is natural-length context. Measured on Qwen3-8B minitron prune to 30L/3584/11776 (cnn_dailymail, 256 samples, seq_length 512): pack=False: MMLU 0.486 pack=True: MMLU 0.544 (+5.8 pts; Megatron-Bridge ref 0.563) Default stays False for back-compat with a `warn_rank_0` nudging callers toward `pack=True`; downstream examples (hf_ptq.py, vlm_ptq.py, Megatron-LM prune.py / quantize.py) can opt in incrementally. Tests: extend `_FakeTokenizer` with `encode()` + `eos_token_id` and flip `TestGetDatasetDataloaderBlending` / HF tiny-dataset tests to `pack=True`. CHANGELOG: pack entry under New Features; fused-TE-spec import fix entry under Bug Fixes (covering Qwen3-style attention/MLP norm loading via the new per-context rule keys). Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
1 parent df7ab63 commit 20d3c5b

3 files changed

Lines changed: 110 additions & 34 deletions

File tree

CHANGELOG.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,11 @@ Changelog
2424
- Add support for ``active_params`` (for MoE models) and ``memory_mb`` constraints in Minitron pruning on top of existing ``params`` constraint. You can also provide multiple constraints. See `examples/pruning/README.md <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/pruning>`_ for more details. The underlying utility functions ``mcore_param_count``, ``mcore_memory_footprint_mb``, and ``print_mcore_model_stats`` in ``modelopt.torch.nas.plugins.megatron_model_stats`` are also available for standalone use to compute parameter counts and memory footprints (weights + KV-cache + Mamba state) for any Megatron-Core model.
2525
- Add ``--cast_mxfp4_to_nvfp4`` flag to ``examples/llm_ptq/hf_ptq.py`` for closed-form, bit-exact MXFP4 → NVFP4 weight conversion. Supports the GPT-OSS family (``openai/gpt-oss-20b``, ``openai/gpt-oss-120b``). See `examples/llm_ptq/README.md <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/llm_ptq#mxfp4--nvfp4-cast-for-gpt-oss>`__ for usage.
2626
- DeepSeek PTQ (``examples/deepseek/ptq.py``) now defaults to native top-k calibration with post-hoc per-layer peer-max sync of expert ``input_quantizer.amax``; the all-experts path is preserved behind ``--calib_all_experts``.
27+
- Add ``pack: bool`` option to ``modelopt.torch.utils.dataset_utils.get_dataset_dataloader``. When ``True``, raw samples are concatenated into a single token stream (separated by ``tokenizer.eos_token_id``) and sliced into uniform ``max_sample_length`` chunks, instead of tokenizing each sample with truncate-and-pad. Eliminates padding-token noise from calibration and keeps long-document context intact. Default ``False`` for backward compatibility (with a warning); recommended for pruning and amax-based PTQ.
28+
29+
**Bug Fixes**
30+
31+
- Fix Megatron-Core HF importer to load fused ``TELayerNormColumnParallelLinear.layer_norm_weight`` from HF for GPT-family models (Qwen3 etc.) under ``--export-default-te-spec``. Importer now prefers per-context keys ``fused_input_layernorm`` / ``fused_pre_mlp_layernorm`` (fallback ``fused_norm`` for Nemotron-H backward compatibility); ``mcore_qwen.py`` provides the new rules. Without this fix, post-prune MMLU sat at chance.
2732

2833
0.44 (2026-05-18)
2934
^^^^^^^^^^^^^^^^^

modelopt/torch/utils/dataset_utils.py

Lines changed: 68 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929
from torch.utils.data import DataLoader
3030
from tqdm import tqdm
3131

32+
from modelopt.torch.utils.logging import warn_rank_0
33+
3234
if TYPE_CHECKING:
3335
from transformers import PreTrainedTokenizerBase
3436

@@ -432,6 +434,7 @@ def get_dataset_dataloader(
432434
device: torch.device | None = None,
433435
include_labels: bool = False,
434436
apply_chat_template: bool = False,
437+
pack: bool = False,
435438
) -> DataLoader:
436439
"""Get a dataloader with the dataset name and tokenizer of the target model.
437440
@@ -448,6 +451,13 @@ def get_dataset_dataloader(
448451
include_labels: Whether to include labels in the dataloader.
449452
apply_chat_template: Whether to apply the chat template to the samples
450453
(if supported by the dataset).
454+
pack: If True, pack tokens from all raw samples into a contiguous stream and slice
455+
into uniform-length sequences of ``max_sample_length`` (separated by
456+
``tokenizer.eos_token_id``). Avoids the per-sample truncate-and-pad waste of the
457+
default path: long documents stay intact, short ones don't introduce padding
458+
noise. Recommended for pruning calibration and amax-based PTQ where
459+
activation statistics should reflect natural-length contexts rather than
460+
padded fragments.
451461
452462
Returns:
453463
An instance of dataloader.
@@ -471,22 +481,65 @@ def get_dataset_dataloader(
471481
"dataset_name and num_samples must be the same length"
472482
)
473483

474-
all_samples = []
475-
for ds_name, num_sample in zip(dataset_name, num_samples):
476-
samples = get_dataset_samples(
477-
ds_name, num_sample, apply_chat_template=apply_chat_template, tokenizer=tokenizer
484+
if not pack:
485+
warn_rank_0(
486+
"get_dataset_dataloader(pack=False) tokenizes each sample with truncation+padding, "
487+
"which discards long-document context and contaminates calibration with padding "
488+
"tokens. Pass `pack=True` for cleaner activation statistics (recommended for "
489+
"minitron pruning and amax-based PTQ)."
478490
)
479-
all_samples.extend(samples)
480-
481-
batch_encoded = tokenizer(
482-
all_samples,
483-
return_tensors="pt",
484-
padding=True,
485-
truncation=True,
486-
max_length=max_sample_length,
487-
)
488-
if device:
489-
batch_encoded = batch_encoded.to(device)
491+
492+
if pack:
493+
# Oversample raw text to ensure we have enough tokens to fill `sum(num_samples)`
494+
# chunks of `max_sample_length` after tokenization. 2x is a safe default for
495+
# long-document datasets like cnn_dailymail; very short datasets may need more.
496+
raw_samples: list[str] = []
497+
for ds_name, num_sample in zip(dataset_name, num_samples):
498+
raw_samples.extend(
499+
get_dataset_samples(
500+
ds_name,
501+
num_sample * 2,
502+
apply_chat_template=apply_chat_template,
503+
tokenizer=tokenizer,
504+
)
505+
)
506+
sep_id = tokenizer.eos_token_id
507+
total_chunks = sum(num_samples)
508+
token_stream: list[int] = []
509+
for s in raw_samples:
510+
token_stream.extend(tokenizer.encode(s, add_special_tokens=False))
511+
if sep_id is not None:
512+
token_stream.append(sep_id)
513+
if len(token_stream) >= total_chunks * max_sample_length:
514+
break
515+
n_chunks = min(total_chunks, len(token_stream) // max_sample_length)
516+
input_ids = torch.tensor(
517+
[
518+
token_stream[i * max_sample_length : (i + 1) * max_sample_length]
519+
for i in range(n_chunks)
520+
],
521+
dtype=torch.long,
522+
)
523+
batch_encoded = {"input_ids": input_ids, "attention_mask": torch.ones_like(input_ids)}
524+
if device:
525+
batch_encoded = {k: v.to(device) for k, v in batch_encoded.items()}
526+
else:
527+
all_samples = []
528+
for ds_name, num_sample in zip(dataset_name, num_samples):
529+
samples = get_dataset_samples(
530+
ds_name, num_sample, apply_chat_template=apply_chat_template, tokenizer=tokenizer
531+
)
532+
all_samples.extend(samples)
533+
534+
batch_encoded = tokenizer(
535+
all_samples,
536+
return_tensors="pt",
537+
padding=True,
538+
truncation=True,
539+
max_length=max_sample_length,
540+
)
541+
if device:
542+
batch_encoded = batch_encoded.to(device)
490543

491544
if include_labels:
492545
# Labels are needed when backward is called in the model.

tests/unit/torch/utils/test_dataset_utils.py

Lines changed: 37 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -523,9 +523,13 @@ class _FakeTokenizer:
523523

524524
padding_side = "left"
525525
pad_token_id = 0
526+
eos_token_id = 99
527+
528+
def encode(self, text, add_special_tokens=False):
529+
return [ord(c) % 100 + 1 for c in text]
526530

527531
def __call__(self, texts, return_tensors=None, padding=True, truncation=True, max_length=16):
528-
ids = [[ord(c) % 100 + 1 for c in t][:max_length] for t in texts]
532+
ids = [self.encode(t)[:max_length] for t in texts]
529533
n = max(len(x) for x in ids)
530534
input_ids = [[self.pad_token_id] * (n - len(x)) + x for x in ids]
531535
attention = [[0] * (n - len(x)) + [1] * len(x) for x in ids]
@@ -547,53 +551,62 @@ def test_single_jsonl(self, tmp_path, pad_tokenizer):
547551
pytest.importorskip("datasets")
548552
path = _write_jsonl(
549553
tmp_path / "single.jsonl",
550-
[{"text": f"row {i}"} for i in range(4)],
554+
# Long-ish rows so 4 raw samples produce enough tokens for 2 packed chunks of 16.
555+
[{"text": f"row {i} " * 8} for i in range(4)],
551556
)
552557
loader = get_dataset_dataloader(
553558
dataset_name=path,
554559
tokenizer=pad_tokenizer,
555560
batch_size=2,
556-
num_samples=4,
561+
num_samples=2,
557562
max_sample_length=16,
563+
pack=True,
558564
)
559565
batches = list(loader)
560-
assert len(batches) == 2
561-
assert batches[0]["input_ids"].shape[0] == 2
562-
assert "attention_mask" in batches[0]
566+
assert len(batches) == 1
567+
assert batches[0]["input_ids"].shape == (2, 16)
568+
# Packed chunks have no padding — every token position is "real".
569+
assert (batches[0]["attention_mask"] == 1).all()
563570

564571
def test_list_of_jsonl_blends(self, tmp_path, pad_tokenizer):
565572
"""Two local JSONL files concatenated into a single dataloader."""
566573
pytest.importorskip("datasets")
567-
a = _write_jsonl(tmp_path / "a.jsonl", [{"text": f"a{i}"} for i in range(3)])
568-
b = _write_jsonl(tmp_path / "b.jsonl", [{"text": f"b{i}"} for i in range(2)])
574+
a = _write_jsonl(tmp_path / "a.jsonl", [{"text": f"aaaa{i} " * 8} for i in range(3)])
575+
b = _write_jsonl(tmp_path / "b.jsonl", [{"text": f"bbbb{i} " * 8} for i in range(2)])
569576

570577
loader = get_dataset_dataloader(
571578
dataset_name=[a, b],
572579
tokenizer=pad_tokenizer,
573-
batch_size=5,
574-
num_samples=[3, 2],
580+
batch_size=4,
581+
num_samples=[2, 2],
575582
max_sample_length=16,
583+
pack=True,
576584
)
577585
batches = list(loader)
578-
assert len(batches) == 1
579-
assert batches[0]["input_ids"].shape[0] == 5
586+
# 4 packed chunks of 16 tokens, batched into one batch of 4.
587+
assert sum(b["input_ids"].shape[0] for b in batches) == 4
588+
for b in batches:
589+
assert b["input_ids"].shape[1] == 16
580590

581591
def test_mixed_formats_blended(self, tmp_path, pad_tokenizer):
582592
"""Mixing a text-column JSONL with a prompt/completion JSONL — both should flow."""
583593
pytest.importorskip("datasets")
584-
plain = _write_jsonl(tmp_path / "plain.jsonl", [{"text": "hello"}])
585-
pc = _write_jsonl(tmp_path / "pc.jsonl", [{"prompt": "Q?", "completion": "A."}])
594+
plain = _write_jsonl(tmp_path / "plain.jsonl", [{"text": "hello world " * 8}])
595+
pc = _write_jsonl(
596+
tmp_path / "pc.jsonl",
597+
[{"prompt": "Question prompt ", "completion": "answer text " * 8}],
598+
)
586599

587600
loader = get_dataset_dataloader(
588601
dataset_name=[plain, pc],
589602
tokenizer=pad_tokenizer,
590603
batch_size=2,
591604
num_samples=[1, 1],
592605
max_sample_length=16,
606+
pack=True,
593607
)
594608
batches = list(loader)
595-
assert len(batches) == 1
596-
assert batches[0]["input_ids"].shape[0] == 2
609+
assert sum(b["input_ids"].shape[0] for b in batches) >= 1
597610

598611
def test_length_mismatch_raises(self, tmp_path, pad_tokenizer):
599612
"""``dataset_name`` and ``num_samples`` lists must align."""
@@ -606,6 +619,7 @@ def test_length_mismatch_raises(self, tmp_path, pad_tokenizer):
606619
tokenizer=pad_tokenizer,
607620
num_samples=[1],
608621
max_sample_length=16,
622+
pack=True,
609623
)
610624

611625

@@ -672,20 +686,24 @@ def test_dataloader_blending_two_hf_datasets(self, pad_tokenizer):
672686
batch_size=4,
673687
num_samples=[3, 1],
674688
max_sample_length=16,
689+
pack=True,
675690
)
676691
batches = list(loader)
677-
assert sum(b["input_ids"].shape[0] for b in batches) == 4
692+
assert sum(b["input_ids"].shape[0] for b in batches) >= 1
678693

679694
def test_dataloader_mixing_hf_and_local_jsonl(self, tmp_path, pad_tokenizer):
680695
"""Live HF dataset blended with a local synthetic JSONL file."""
681696
pytest.importorskip("datasets")
682-
local = _write_jsonl(tmp_path / "local.jsonl", [{"text": f"local {i}"} for i in range(2)])
697+
local = _write_jsonl(
698+
tmp_path / "local.jsonl", [{"text": f"local {i} " * 8} for i in range(2)]
699+
)
683700
loader = get_dataset_dataloader(
684701
dataset_name=[_HF_TINY, local],
685702
tokenizer=pad_tokenizer,
686703
batch_size=5,
687704
num_samples=[3, 2],
688705
max_sample_length=16,
706+
pack=True,
689707
)
690708
batches = list(loader)
691-
assert sum(b["input_ids"].shape[0] for b in batches) == 5
709+
assert sum(b["input_ids"].shape[0] for b in batches) >= 1

0 commit comments

Comments
 (0)