Commit c4c662e

feat(utils): support pack=True calibration mode for get_dataset_dataloader
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
1 parent d773356 commit c4c662e

3 files changed

Lines changed: 236 additions & 51 deletions

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
@@ -25,6 +25,7 @@ Changelog
 - Add ``--cast_mxfp4_to_nvfp4`` flag to ``examples/llm_ptq/hf_ptq.py`` for closed-form, bit-exact MXFP4 → NVFP4 weight conversion. Supports the GPT-OSS family (``openai/gpt-oss-20b``, ``openai/gpt-oss-120b``). See `examples/llm_ptq/README.md <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/llm_ptq#mxfp4--nvfp4-cast-for-gpt-oss>`__ for usage.
 - DeepSeek PTQ (``examples/deepseek/ptq.py``) now defaults to native top-k calibration with post-hoc per-layer peer-max sync of expert ``input_quantizer.amax``; the all-experts path is preserved behind ``--calib_all_experts``.
 - Add NVFP4 W4A16 weight-only quantization (``w4a16_nvfp4``): FP4 weights with group_size=16, BF16 activations, no calibration forward pass required. Use ``mtq.W4A16_NVFP4_CFG`` or ``--qformat w4a16_nvfp4`` in ``hf_ptq.py``. vLLM deployment support is in progress.
+- Add ``pack: bool`` option to ``modelopt.torch.utils.dataset_utils.get_dataset_dataloader``. When ``True``, raw samples from each source are concatenated into a per-source token stream (separated by ``tokenizer.eos_token_id``) and sliced into uniform ``max_sample_length`` chunks, preserving the requested per-source ratio in ``num_samples``. Eliminates padding-token noise from calibration and keeps long-document context intact. Default ``False`` for backward compatibility; recommended for pruning and amax-based PTQ.
 
 **Bug Fixes**

modelopt/torch/utils/dataset_utils.py

Lines changed: 147 additions & 16 deletions
@@ -29,6 +29,8 @@
 from torch.utils.data import DataLoader
 from tqdm import tqdm
 
+from .logging import warn_rank_0
+
 if TYPE_CHECKING:
     from transformers import PreTrainedTokenizerBase

@@ -512,6 +514,103 @@ def __len__(self):
         return len(next(iter(self.encodings.values())))
 
 
+def _build_packed_input_ids(
+    dataset_name: list[str],
+    num_samples: list[int],
+    max_sample_length: int,
+    tokenizer: "PreTrainedTokenizerBase",
+    apply_chat_template: bool,
+) -> torch.Tensor:
+    """Pack raw samples into a ``(n_chunks, max_sample_length)`` int tensor.
+
+    Each source contributes ``num_sample`` chunks (or fewer if exhausted), so the requested
+    per-source ratio in ``num_samples`` is preserved instead of letting whichever source
+    appears first dominate the budget. Within a source, tokenization runs in batches of
+    ``max(8, num_sample // 4)`` samples so we stop tokenizing once the chunk budget is
+    full, instead of eagerly paying for the entire ``num_sample * 2`` oversample.
+
+    Documents are separated by ``tokenizer.eos_token_id`` when set; ``add_special_tokens=False``
+    avoids injecting a fresh BOS at every sample boundary. Note that packed chunks therefore
+    have no BOS at position 0 — fine for amax / sensitivity calibration where boundary
+    tokens are statistically dominated, less ideal for callers that need BOS-prefixed
+    sequences (use ``pack=False`` for those). When ``apply_chat_template=True``, the rendered
+    samples often already end with the chat EOS marker (e.g. ``<|im_end|>``), which can
+    tokenize to ``eos_token_id`` and produce ``<eos><eos>`` at document boundaries —
+    harmless for calibration statistics but worth noting.
+
+    Sizing note: ``num_sample`` here is the desired chunk count per source. The loader
+    internally fetches ``num_sample * 2`` raw samples. Short-document sources can still
+    under-fill — to recover the target, scale ``num_sample`` itself (which doubles both
+    the target and the internal raw-sample draw). Example: short-row source returning 1
+    chunk for ``num_sample=64`` typically returns 4 chunks for ``num_sample=128`` because
+    the raw draw goes from 128 to 256.
+    """
+    sep_id = tokenizer.eos_token_id
+    if sep_id is None:
+        warn_rank_0(
+            "pack=True: tokenizer has no eos_token_id; raw documents will be concatenated "
+            "without a separator, so calibration activations will span document boundaries. "
+            "Set tokenizer.eos_token_id (or another sentinel) for explicit separators."
+        )
+
+    per_source_chunks: list[list[int]] = []
+    actual_per_source: list[int] = []
+    for ds_name, num_sample in zip(dataset_name, num_samples):
+        # 2x oversample sized for cnn_dailymail-style long docs; short-sample datasets may
+        # still under-fill and trigger the warning below.
+        raw_samples = get_dataset_samples(
+            ds_name,
+            num_sample * 2,
+            apply_chat_template=apply_chat_template,
+            tokenizer=tokenizer,
+        )
+        needed_tokens = num_sample * max_sample_length
+        # max(8, ...) floor keeps the Rust-batched tokenizer happy for small calibrations
+        # (num_sample < 32 → batch is 8); above that, `// 4` grows the batch with the
+        # request while keeping the early-exit check granular enough to actually skip
+        # tokenizing the back half of the 2x oversample on long-doc sources.
+        tokenize_batch_size = max(8, num_sample // 4)
+        stream: list[int] = []
+        for batch_start in range(0, len(raw_samples), tokenize_batch_size):
+            if len(stream) >= needed_tokens:
+                break
+            batch = raw_samples[batch_start : batch_start + tokenize_batch_size]
+            # padding/truncation=False explicit: don't trust subclass __call__ defaults.
+            encoded = tokenizer(batch, add_special_tokens=False, padding=False, truncation=False)[
+                "input_ids"
+            ]
+            for ids in encoded:
+                stream.extend(ids)
+                if sep_id is not None:
+                    stream.append(sep_id)
+                if len(stream) >= needed_tokens:
+                    break
+        available = len(stream) // max_sample_length
+        take = min(num_sample, available)
+        per_source_chunks.extend(
+            stream[i * max_sample_length : (i + 1) * max_sample_length] for i in range(take)
+        )
+        actual_per_source.append(take)
+
+    n_chunks = len(per_source_chunks)
+    total_chunks = sum(num_samples)
+    if n_chunks == 0:
+        raise ValueError(
+            f"pack=True yielded 0 chunks across {len(dataset_name)} source(s); each source "
+            f"needs at least {max_sample_length} tokens after concatenation. Try longer "
+            "samples or a smaller max_sample_length."
+        )
+    if n_chunks < total_chunks:
+        warn_rank_0(
+            f"pack=True produced {n_chunks} chunks (per-source {actual_per_source}) vs "
+            f"requested {total_chunks} (per-source {list(num_samples)}). Some sources "
+            "exhausted before reaching their target. The loader internally fetches "
+            "`num_samples * 2` raw samples per source; for very short-sample sources, "
+            "pass a 2-3x larger `num_samples` so the 2x draw covers the chunk budget."
+        )
+    return torch.tensor(per_source_chunks, dtype=torch.long)
+
+
 def get_dataset_dataloader(
     dataset_name: str | list[str] = "cnn_dailymail",
     tokenizer: "PreTrainedTokenizerBase | None" = None,
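
The per-source packing done by `_build_packed_input_ids` can be illustrated without torch or a real tokenizer. What follows is a minimal sketch, not the code from this commit: `pack_sources` and `toy_encode` are hypothetical names, the tokenizer is a stand-in that emits one id per word, and the batched tokenization is reduced to a per-document loop with the same early exit.

```python
from typing import Callable, Optional


def pack_sources(
    sources: list[list[str]],
    num_chunks: list[int],
    chunk_len: int,
    encode: Callable[[str], list[int]],
    sep_id: Optional[int],
) -> tuple[list[list[int]], list[int]]:
    """Return (chunks, chunks_taken_per_source) for a toy packing pass."""
    chunks: list[list[int]] = []
    taken: list[int] = []
    for docs, want in zip(sources, num_chunks):
        needed = want * chunk_len
        stream: list[int] = []
        for doc in docs:
            if len(stream) >= needed:
                break  # chunk budget for this source is already covered
            stream.extend(encode(doc))
            if sep_id is not None:
                stream.append(sep_id)  # separator between documents
        # Only full chunks are kept; a trailing partial chunk is dropped.
        take = min(want, len(stream) // chunk_len)
        chunks.extend(stream[i * chunk_len : (i + 1) * chunk_len] for i in range(take))
        taken.append(take)
    return chunks, taken


def toy_encode(text: str) -> list[int]:
    # Hypothetical stand-in tokenizer: one id (the word length) per word.
    return [len(word) for word in text.split()]


long_docs = ["a bb ccc dddd eeeee"] * 4  # 5 ids + 1 separator per document
short_docs = ["a bb"] * 2  # 2 ids + 1 separator per document
chunks, taken = pack_sources(
    [long_docs, short_docs], [2, 2], chunk_len=4, encode=toy_encode, sep_id=0
)
print(taken)  # chunks actually produced per source
print(chunks)
```

In this toy run the second source holds only 6 ids, so it yields a single 4-id chunk instead of the 2 requested. That is the under-fill case the rank-0 warning reports.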
@@ -521,6 +620,7 @@ def get_dataset_dataloader(
     device: torch.device | None = None,
     include_labels: bool = False,
     apply_chat_template: bool = False,
+    pack: bool = False,
 ) -> DataLoader:
     """Get a dataloader with the dataset name and tokenizer of the target model.

@@ -531,12 +631,31 @@ def get_dataset_dataloader(
             an ``int`` (applied to a single source) or a list aligned with ``dataset_name``.
         tokenizer: Instance of HuggingFace tokenizer.
         batch_size: Batch size of the returned dataloader.
-        num_samples: Number of samples from the dataset.
+        num_samples: Number of samples from the dataset. Semantics depend on ``pack``:
+            with ``pack=False`` this is the number of raw samples to fetch and tokenize
+            (each becomes one row of ``max_sample_length`` after truncate-and-pad); with
+            ``pack=True`` this is the number of ``max_sample_length``-token chunks to
+            produce per source. Migrating an existing call site to ``pack=True`` may
+            therefore need a different value to hit the same total-token calibration
+            budget.
         max_sample_length: Maximum length of a sample.
         device: Target device for the returned dataloader.
         include_labels: Whether to include labels in the dataloader.
         apply_chat_template: Whether to apply the chat template to the samples
             (if supported by the dataset).
+        pack: If True, raw samples from each source are concatenated into a per-source token
+            stream (separated by ``tokenizer.eos_token_id`` when set) and sliced into
+            uniform-length chunks of ``max_sample_length``; the per-source chunks are then
+            concatenated **contiguously by source** (no cross-source interleaving), preserving
+            the requested per-source ratio in ``num_samples``. Avoids the per-sample
+            truncate-and-pad waste of the default path: long documents stay intact, short
+            ones don't introduce padding noise. Recommended for pruning calibration and
+            amax-based PTQ where activation statistics should reflect natural-length
+            contexts rather than padded fragments. ``attention_mask`` is unconditionally
+            all-ones — attention crosses document boundaries (the ``eos`` separator is a
+            token, not a mask boundary). Raises ``ValueError`` if the dataset doesn't yield
+            enough tokens to form a single chunk; emits a rank-0 warning if it yields
+            fewer chunks than requested.
 
     Returns:
         An instance of dataloader.
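
The ``num_samples`` migration note comes down to simple token arithmetic. The sketch below is illustrative only: the helper names and the document lengths are hypothetical, and it assumes each unpacked row contributes at most ``max_sample_length`` non-pad tokens while every packed chunk is completely full.

```python
def real_tokens_unpacked(doc_lengths: list[int], num_samples: int, max_len: int) -> int:
    """Non-pad tokens when the first num_samples documents are each truncated to max_len."""
    return sum(min(n, max_len) for n in doc_lengths[:num_samples])


def real_tokens_packed(num_chunks: int, max_len: int) -> int:
    """Packed chunks contain no padding, so the count is exact."""
    return num_chunks * max_len


def chunks_for_budget(token_budget: int, max_len: int) -> int:
    """Smallest chunk count whose packed token total covers token_budget."""
    return (token_budget + max_len - 1) // max_len


doc_lengths = [100, 700, 50, 1200]  # hypothetical tokenized document lengths
unpacked = real_tokens_unpacked(doc_lengths, num_samples=4, max_len=512)
packed_same_value = real_tokens_packed(num_chunks=4, max_len=512)
equivalent_chunks = chunks_for_budget(unpacked, max_len=512)
print(unpacked, packed_same_value, equivalent_chunks)
```

With these made-up lengths, keeping ``num_samples=4`` after switching to ``pack=True`` would feed noticeably more real tokens through calibration than before; a smaller value matches the original budget.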
@@ -560,22 +679,30 @@ def get_dataset_dataloader(
         "dataset_name and num_samples must be the same length"
     )
 
-    all_samples = []
-    for ds_name, num_sample in zip(dataset_name, num_samples):
-        samples = get_dataset_samples(
-            ds_name, num_sample, apply_chat_template=apply_chat_template, tokenizer=tokenizer
+    if pack:
+        input_ids = _build_packed_input_ids(
+            dataset_name, num_samples, max_sample_length, tokenizer, apply_chat_template
         )
-        all_samples.extend(samples)
-
-    batch_encoded = tokenizer(
-        all_samples,
-        return_tensors="pt",
-        padding=True,
-        truncation=True,
-        max_length=max_sample_length,
-    )
-    if device:
-        batch_encoded = batch_encoded.to(device)
+        batch_encoded = {"input_ids": input_ids, "attention_mask": torch.ones_like(input_ids)}
+        if device:
+            batch_encoded = {k: v.to(device) for k, v in batch_encoded.items()}
+    else:
+        all_samples = []
+        for ds_name, num_sample in zip(dataset_name, num_samples):
+            samples = get_dataset_samples(
+                ds_name, num_sample, apply_chat_template=apply_chat_template, tokenizer=tokenizer
+            )
+            all_samples.extend(samples)
+
+        batch_encoded = tokenizer(
+            all_samples,
+            return_tensors="pt",
+            padding=True,
+            truncation=True,
+            max_length=max_sample_length,
+        )
+        if device:
+            batch_encoded = batch_encoded.to(device)
 
     if include_labels:
         # Labels are needed when backward is called in the model.
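
The two branches differ mainly in what the attention mask looks like. As a rough pure-Python illustration (hypothetical helper names, right-side padding assumed, toy token ids), the default path produces zeros wherever padding was added, while the packed path has nothing to mask:

```python
def truncate_and_pad(
    rows: list[list[int]], max_len: int, pad_id: int
) -> tuple[list[list[int]], list[list[int]]]:
    """Toy version of truncate-to-max_len plus pad-to-longest, returning (ids, mask)."""
    width = min(max_len, max(len(r) for r in rows))
    ids, mask = [], []
    for row in rows:
        kept = row[:width]
        n_pad = width - len(kept)
        ids.append(kept + [pad_id] * n_pad)
        mask.append([1] * len(kept) + [0] * n_pad)  # 0 marks padding positions
    return ids, mask


def packed_mask(chunks: list[list[int]]) -> list[list[int]]:
    """Packed chunks are full of real tokens, so the mask is all ones."""
    return [[1] * len(chunk) for chunk in chunks]


padded_ids, padded_mask = truncate_and_pad([[5, 6, 7, 8, 9], [5, 6]], max_len=4, pad_id=0)
all_ones = packed_mask([[5, 6, 7, 8], [9, 0, 5, 6]])
print(padded_ids, padded_mask, all_ones)
```

Because the packed mask is all ones, tokens after a separator can attend to the previous document, which matches the boundary behavior described in the ``pack`` docstring.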
@@ -844,6 +971,7 @@ def create_forward_loop(
     include_labels: bool = False,
     dataloader: DataLoader | None = None,
     allowed_non_tensor_keys: set | None = None,
+    pack: bool = False,
 ) -> Callable:
     """Creates and returns a forward loop function configured for a specific model, dataset, and tokenizer.

@@ -865,6 +993,8 @@ def create_forward_loop(
         allowed_non_tensor_keys: Set of key names whose batch values may be non-tensor types.
             Useful when the dataloader yields batches with non-standard fields (e.g., nested
             model outputs).
+        pack: Forwarded to :func:`get_dataset_dataloader`. See its docstring for semantics
+            (including the ``num_samples`` chunk-vs-document distinction).
 
     Example usage for quantization:

@@ -902,6 +1032,7 @@ def create_forward_loop(
             max_sample_length=max_sample_length,
             device=device,
             include_labels=include_labels,
+            pack=pack,
         )
 
     return lambda model: _forward_loop(model, dataloader, allowed_non_tensor_keys)
