2929from torch .utils .data import DataLoader
3030from tqdm import tqdm
3131
32+ from .logging import warn_rank_0
33+
3234if TYPE_CHECKING :
3335 from transformers import PreTrainedTokenizerBase
3436
@@ -512,6 +514,103 @@ def __len__(self):
512514 return len (next (iter (self .encodings .values ())))
513515
514516
517+ def _build_packed_input_ids (
518+ dataset_name : list [str ],
519+ num_samples : list [int ],
520+ max_sample_length : int ,
521+ tokenizer : "PreTrainedTokenizerBase" ,
522+ apply_chat_template : bool ,
523+ ) -> torch .Tensor :
524+ """Pack raw samples into a ``(n_chunks, max_sample_length)`` int tensor.
525+
526+ Each source contributes ``num_sample`` chunks (or fewer if exhausted), so the requested
527+ per-source ratio in ``num_samples`` is preserved instead of letting whichever source
528+ appears first dominate the budget. Within a source, tokenization runs in batches of
529+ ``max(8, num_sample // 4)`` samples so we stop tokenizing once the chunk budget is
530+ full, instead of eagerly paying for the entire ``num_sample * 2`` oversample.
531+
532+ Documents are separated by ``tokenizer.eos_token_id`` when set; ``add_special_tokens=False``
533+ avoids injecting a fresh BOS at every sample boundary. Note that packed chunks therefore
534+ have no BOS at position 0 — fine for amax / sensitivity calibration where boundary
535+ tokens are statistically dominated, less ideal for callers that need BOS-prefixed
536+ sequences (use ``pack=False`` for those). When ``apply_chat_template=True``, the rendered
537+ samples often already end with the chat EOS marker (e.g. ``<|im_end|>``), which can
538+ tokenize to ``eos_token_id`` and produce ``<eos><eos>`` at document boundaries —
539+ harmless for calibration statistics but worth noting.
540+
541+ Sizing note: ``num_sample`` here is the desired chunk count per source. The loader
542+ internally fetches ``num_sample * 2`` raw samples. Short-document sources can still
543+ under-fill — to recover the target, scale ``num_sample`` itself (which doubles both
544+ the target and the internal raw-sample draw). Example: short-row source returning 1
545+ chunk for ``num_sample=64`` typically returns 4 chunks for ``num_sample=128`` because
546+ the raw draw goes from 128 to 256.
547+ """
548+ sep_id = tokenizer .eos_token_id
549+ if sep_id is None :
550+ warn_rank_0 (
551+ "pack=True: tokenizer has no eos_token_id; raw documents will be concatenated "
552+ "without a separator, so calibration activations will span document boundaries. "
553+ "Set tokenizer.eos_token_id (or another sentinel) for explicit separators."
554+ )
555+
556+ per_source_chunks : list [list [int ]] = []
557+ actual_per_source : list [int ] = []
558+ for ds_name , num_sample in zip (dataset_name , num_samples ):
559+ # 2x oversample sized for cnn_dailymail-style long docs; short-sample datasets may
560+ # still under-fill and trigger the warning below.
561+ raw_samples = get_dataset_samples (
562+ ds_name ,
563+ num_sample * 2 ,
564+ apply_chat_template = apply_chat_template ,
565+ tokenizer = tokenizer ,
566+ )
567+ needed_tokens = num_sample * max_sample_length
568+ # max(8, ...) floor keeps the Rust-batched tokenizer happy for small calibrations
569+ # (num_sample < 32 → batch is 8); above that, `// 4` grows the batch with the
570+ # request while keeping the early-exit check granular enough to actually skip
571+ # tokenizing the back half of the 2x oversample on long-doc sources.
572+ tokenize_batch_size = max (8 , num_sample // 4 )
573+ stream : list [int ] = []
574+ for batch_start in range (0 , len (raw_samples ), tokenize_batch_size ):
575+ if len (stream ) >= needed_tokens :
576+ break
577+ batch = raw_samples [batch_start : batch_start + tokenize_batch_size ]
578+ # padding/truncation=False explicit: don't trust subclass __call__ defaults.
579+ encoded = tokenizer (batch , add_special_tokens = False , padding = False , truncation = False )[
580+ "input_ids"
581+ ]
582+ for ids in encoded :
583+ stream .extend (ids )
584+ if sep_id is not None :
585+ stream .append (sep_id )
586+ if len (stream ) >= needed_tokens :
587+ break
588+ available = len (stream ) // max_sample_length
589+ take = min (num_sample , available )
590+ per_source_chunks .extend (
591+ stream [i * max_sample_length : (i + 1 ) * max_sample_length ] for i in range (take )
592+ )
593+ actual_per_source .append (take )
594+
595+ n_chunks = len (per_source_chunks )
596+ total_chunks = sum (num_samples )
597+ if n_chunks == 0 :
598+ raise ValueError (
599+ f"pack=True yielded 0 chunks across { len (dataset_name )} source(s); each source "
600+ f"needs at least { max_sample_length } tokens after concatenation. Try longer "
601+ "samples or a smaller max_sample_length."
602+ )
603+ if n_chunks < total_chunks :
604+ warn_rank_0 (
605+ f"pack=True produced { n_chunks } chunks (per-source { actual_per_source } ) vs "
606+ f"requested { total_chunks } (per-source { list (num_samples )} ). Some sources "
607+ "exhausted before reaching their target. The loader internally fetches "
608+ "`num_samples * 2` raw samples per source; for very short-sample sources, "
609+ "pass a 2-3x larger `num_samples` so the 2x draw covers the chunk budget."
610+ )
611+ return torch .tensor (per_source_chunks , dtype = torch .long )
612+
613+
515614def get_dataset_dataloader (
516615 dataset_name : str | list [str ] = "cnn_dailymail" ,
517616 tokenizer : "PreTrainedTokenizerBase | None" = None ,
@@ -521,6 +620,7 @@ def get_dataset_dataloader(
521620 device : torch .device | None = None ,
522621 include_labels : bool = False ,
523622 apply_chat_template : bool = False ,
623+ pack : bool = False ,
524624) -> DataLoader :
525625 """Get a dataloader with the dataset name and tokenizer of the target model.
526626
@@ -531,12 +631,31 @@ def get_dataset_dataloader(
531631 an ``int`` (applied to a single source) or a list aligned with ``dataset_name``.
532632 tokenizer: Instance of HuggingFace tokenizer.
533633 batch_size: Batch size of the returned dataloader.
534- num_samples: Number of samples from the dataset.
634+ num_samples: Number of samples from the dataset. Semantics depend on ``pack``:
635+ with ``pack=False`` this is the number of raw samples to fetch and tokenize
636+ (each becomes one row of ``max_sample_length`` after truncate-and-pad); with
637+ ``pack=True`` this is the number of ``max_sample_length``-token chunks to
638+ produce per source. Migrating an existing call site to ``pack=True`` may
639+ therefore need a different value to hit the same total-token calibration
640+ budget.
535641 max_sample_length: Maximum length of a sample.
536642 device: Target device for the returned dataloader.
537643 include_labels: Whether to include labels in the dataloader.
538644 apply_chat_template: Whether to apply the chat template to the samples
539645 (if supported by the dataset).
646+ pack: If True, raw samples from each source are concatenated into a per-source token
647+ stream (separated by ``tokenizer.eos_token_id`` when set) and sliced into
648+ uniform-length chunks of ``max_sample_length``; the per-source chunks are then
649+ concatenated **contiguously by source** (no cross-source interleaving), preserving
650+ the requested per-source ratio in ``num_samples``. Avoids the per-sample
651+ truncate-and-pad waste of the default path: long documents stay intact, short
652+ ones don't introduce padding noise. Recommended for pruning calibration and
653+ amax-based PTQ where activation statistics should reflect natural-length
654+ contexts rather than padded fragments. ``attention_mask`` is unconditionally
655+ all-ones — attention crosses document boundaries (the ``eos`` separator is a
656+ token, not a mask boundary). Raises ``ValueError`` if the dataset doesn't yield
657+ enough tokens to form a single chunk; emits a rank-0 warning if it yields
658+ fewer chunks than requested.
540659
541660 Returns:
542661 An instance of dataloader.
@@ -560,22 +679,30 @@ def get_dataset_dataloader(
560679 "dataset_name and num_samples must be the same length"
561680 )
562681
563- all_samples = []
564- for ds_name , num_sample in zip (dataset_name , num_samples ):
565- samples = get_dataset_samples (
566- ds_name , num_sample , apply_chat_template = apply_chat_template , tokenizer = tokenizer
682+ if pack :
683+ input_ids = _build_packed_input_ids (
684+ dataset_name , num_samples , max_sample_length , tokenizer , apply_chat_template
567685 )
568- all_samples .extend (samples )
569-
570- batch_encoded = tokenizer (
571- all_samples ,
572- return_tensors = "pt" ,
573- padding = True ,
574- truncation = True ,
575- max_length = max_sample_length ,
576- )
577- if device :
578- batch_encoded = batch_encoded .to (device )
686+ batch_encoded = {"input_ids" : input_ids , "attention_mask" : torch .ones_like (input_ids )}
687+ if device :
688+ batch_encoded = {k : v .to (device ) for k , v in batch_encoded .items ()}
689+ else :
690+ all_samples = []
691+ for ds_name , num_sample in zip (dataset_name , num_samples ):
692+ samples = get_dataset_samples (
693+ ds_name , num_sample , apply_chat_template = apply_chat_template , tokenizer = tokenizer
694+ )
695+ all_samples .extend (samples )
696+
697+ batch_encoded = tokenizer (
698+ all_samples ,
699+ return_tensors = "pt" ,
700+ padding = True ,
701+ truncation = True ,
702+ max_length = max_sample_length ,
703+ )
704+ if device :
705+ batch_encoded = batch_encoded .to (device )
579706
580707 if include_labels :
581708 # Labels are needed when backward is called in the model.
@@ -844,6 +971,7 @@ def create_forward_loop(
844971 include_labels : bool = False ,
845972 dataloader : DataLoader | None = None ,
846973 allowed_non_tensor_keys : set | None = None ,
974+ pack : bool = False ,
847975) -> Callable :
848976 """Creates and returns a forward loop function configured for a specific model, dataset, and tokenizer.
849977
@@ -865,6 +993,8 @@ def create_forward_loop(
865993 allowed_non_tensor_keys: Set of key names whose batch values may be non-tensor types.
866994 Useful when the dataloader yields batches with non-standard fields (e.g., nested
867995 model outputs).
996+ pack: Forwarded to :func:`get_dataset_dataloader`. See its docstring for semantics
997+ (including the ``num_samples`` chunk-vs-document distinction).
868998
869999 Example usage for quantization:
8701000
@@ -902,6 +1032,7 @@ def create_forward_loop(
9021032 max_sample_length = max_sample_length ,
9031033 device = device ,
9041034 include_labels = include_labels ,
1035+ pack = pack ,
9051036 )
9061037
9071038 return lambda model : _forward_loop (model , dataloader , allowed_non_tensor_keys )
0 commit comments