
Commit c40a374

cjluo-nv and claude committed
Fix missing attention_mask in calibration dataloader
When include_labels=False (the default for PTQ calibration), get_dataset_dataloader was only returning input_ids and discarding the attention_mask produced by the tokenizer. This caused HF models to create a full causal mask, allowing padding tokens to participate in attention during calibration and skewing quantization statistics.

Include attention_mask alongside input_ids so the model correctly ignores padding tokens during calibration forward passes.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>
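The failure mode described above can be illustrated with plain softmax attention. This is a minimal pure-Python sketch, not modelopt or HF code: the scores and mask values are invented for illustration. Without a mask, the padding position receives nonzero attention weight; with an attention_mask, its score is set to negative infinity and its weight drops to zero.

```python
import math

def softmax(xs):
    # Numerically stable softmax over a list of scores.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    s = sum(exps)
    return [e / s for e in exps]

def masked_attn_weights(scores, mask):
    # mask[j] == 0 marks a padding position; its score becomes -inf,
    # so it contributes exactly zero attention weight after softmax.
    masked = [s if mask[j] else float("-inf") for j, s in enumerate(scores)]
    return softmax(masked)

scores = [2.0, 1.0, 3.0]                 # raw attention scores for one query
full = softmax(scores)                   # no mask: padding (index 2) attends
masked = masked_attn_weights(scores, [1, 1, 0])  # attention_mask zeroes it out
```

Here `full[2]` is nonzero (the padding token pollutes the activation statistics that calibration observes), while `masked[2]` is exactly zero.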
1 parent 5523505 commit c40a374

1 file changed

modelopt/torch/utils/dataset_utils.py

Lines changed: 9 additions & 2 deletions
```diff
@@ -405,8 +405,15 @@ def get_dataset_dataloader(
         )
         tokenized_dataset = _CustomDataset(batch_encoded)
     else:
-        # For backward compatibility, if labels are not needed, we only return the input_ids.
-        tokenized_dataset = _CustomDataset({"input_ids": batch_encoded["input_ids"]})
+        # Always include attention_mask so the model correctly ignores padding tokens
+        # during calibration. Without it, HF models create a full causal mask and
+        # padding tokens participate in attention, skewing calibration statistics.
+        tokenized_dataset = _CustomDataset(
+            {
+                "input_ids": batch_encoded["input_ids"],
+                "attention_mask": batch_encoded["attention_mask"],
+            }
+        )
 
     calib_dataloader = DataLoader(tokenized_dataset, batch_size=batch_size, shuffle=False)
 
```
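To see what the fixed branch preserves, here is a minimal sketch using a stand-in for modelopt's `_CustomDataset`. The class body and the token values are assumptions for illustration, not the library's actual implementation; the dict literal mirrors what a HF tokenizer returns with padding enabled.

```python
class _CustomDataset:
    # Minimal stand-in (assumption) for modelopt's dataset wrapper:
    # holds a dict of parallel lists and yields one dict per sample.
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data["input_ids"])

    def __getitem__(self, i):
        return {k: v[i] for k, v in self.data.items()}

# Illustrative tokenizer output with padding: 0 is a pad token here,
# and attention_mask marks real (1) vs. padded (0) positions.
batch_encoded = {
    "input_ids": [[5, 6, 0], [7, 8, 9]],
    "attention_mask": [[1, 1, 0], [1, 1, 1]],
}

# After the fix, both keys survive into the calibration dataloader,
# so each sample carries its mask. Before, only input_ids was kept.
tokenized_dataset = _CustomDataset(
    {
        "input_ids": batch_encoded["input_ids"],
        "attention_mask": batch_encoded["attention_mask"],
    }
)
sample = tokenized_dataset[0]
```

Each batch the dataloader yields now contains `attention_mask`, which HF models accept in `forward` and use to exclude padding positions from attention.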
