Commit 952a62b

Authored by cjluo-nv and claude
Fix missing attention_mask in calibration dataloader (#1261)
## Summary

- When `include_labels=False` (the default for PTQ calibration), `get_dataset_dataloader` was discarding the `attention_mask` produced by the tokenizer and only returning `input_ids`.
- Without `attention_mask`, HuggingFace models create a full causal mask, causing padding tokens to participate in attention during calibration and skewing quantization statistics.
- This fix includes `attention_mask` alongside `input_ids` so the model correctly ignores padding tokens during calibration forward passes.

## Details

In `modelopt/torch/utils/dataset_utils.py`, the tokenizer call at line 387 with `padding=True` produces both `input_ids` and `attention_mask`. The `include_labels=True` path (line 406) already preserves the full `batch_encoded` dict, including `attention_mask`. However, the `include_labels=False` path kept only `input_ids` "for backward compatibility."

During the calibration forward loop (`_forward_loop` → `_process_batch`), the batch dict is unpacked as `**kwargs` into `model.forward()`. Without `attention_mask`, HF models default to attending to all positions, including padding, which pollutes calibration statistics.

**Practical impact**: with `batch_size=1` there is no padding, so the bug is invisible. With larger batch sizes and variable-length samples, shorter sequences get padded and the effect grows.

## Test plan

- [x] Existing unit tests pass (`tests/unit/torch/utils/test_dataset_utils.py`)
- [x] Pre-commit hooks pass
- [ ] Verify PTQ accuracy with batch_size > 1 on a padded calibration dataset (GPU required)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
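To make the padding effect concrete, here is a minimal, dependency-free sketch (illustrative only, not ModelOpt or HuggingFace code; `pad_batch` and `PAD_ID` are hypothetical names) of what a tokenizer with `padding=True` produces, and why `batch_size=1` hides the bug while larger batches expose it:

```python
# Hypothetical illustration of tokenizer-style padding. A real tokenizer
# returns the same two keys: input_ids and attention_mask.
PAD_ID = 0  # assumed pad token id for this sketch

def pad_batch(sequences):
    """Pad every sequence to the longest one in the batch, as a tokenizer
    with padding=True would, and mark real tokens in attention_mask."""
    max_len = max(len(s) for s in sequences)
    input_ids = [s + [PAD_ID] * (max_len - len(s)) for s in sequences]
    attention_mask = [[1] * len(s) + [0] * (max_len - len(s)) for s in sequences]
    return {"input_ids": input_ids, "attention_mask": attention_mask}

# batch_size=1: nothing to pad, so dropping attention_mask is harmless.
single = pad_batch([[5, 6, 7]])
assert single["attention_mask"] == [[1, 1, 1]]

# batch_size=2 with variable lengths: the shorter sequence gets PAD_ID
# tokens, and only attention_mask tells the model to ignore them.
batch = pad_batch([[5, 6, 7, 8], [5, 6]])
print(batch["attention_mask"])  # [[1, 1, 1, 1], [1, 1, 0, 0]]
```

This is why the bug only surfaces with `batch_size > 1` and variable-length calibration samples: the zeros in `attention_mask` exist only when padding exists.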
1 parent c9b1155 commit 952a62b

File tree

1 file changed: +9 −2 lines changed

modelopt/torch/utils/dataset_utils.py

Lines changed: 9 additions & 2 deletions
```diff
@@ -405,8 +405,15 @@ def get_dataset_dataloader(
         )
         tokenized_dataset = _CustomDataset(batch_encoded)
     else:
-        # For backward compatibility, if labels are not needed, we only return the input_ids.
-        tokenized_dataset = _CustomDataset({"input_ids": batch_encoded["input_ids"]})
+        # Always include attention_mask so the model correctly ignores padding tokens
+        # during calibration. Without it, HF models create a full causal mask and
+        # padding tokens participate in attention, skewing calibration statistics.
+        tokenized_dataset = _CustomDataset(
+            {
+                "input_ids": batch_encoded["input_ids"],
+                "attention_mask": batch_encoded["attention_mask"],
+            }
+        )
 
     calib_dataloader = DataLoader(tokenized_dataset, batch_size=batch_size, shuffle=False)
```
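The commit message notes that the batch dict is unpacked as `**kwargs` into `model.forward()`. A minimal sketch of that mechanism (all names here are illustrative stand-ins, not the actual `_process_batch` or an HF model) shows why a missing dict key silently changes model behavior rather than raising an error:

```python
# Hypothetical sketch of the calibration loop's batch handling: whatever
# keys the dataset yields become keyword arguments of the forward call.
def process_batch(model_forward, batch):
    return model_forward(**batch)

def fake_forward(input_ids, attention_mask=None):
    # HF-style default: with attention_mask=None, every position is
    # treated as a real token, so padding participates in attention.
    if attention_mask is None:
        attention_mask = [[1] * len(seq) for seq in input_ids]
    return attention_mask

# Before the fix: only input_ids survives, so the two trailing pad
# tokens (id 0) are attended to during calibration.
old_batch = {"input_ids": [[5, 6, 0, 0]]}
assert process_batch(fake_forward, old_batch) == [[1, 1, 1, 1]]

# After the fix: attention_mask rides along in the batch dict and the
# padding positions are masked out.
new_batch = {"input_ids": [[5, 6, 0, 0]], "attention_mask": [[1, 1, 0, 0]]}
assert process_batch(fake_forward, new_batch) == [[1, 1, 0, 0]]
```

Because `attention_mask` is an optional keyword argument with a permissive default, dropping it from the dataset never fails loudly; it only skews the statistics collected during calibration.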
