Fix missing attention_mask in calibration dataloader#1261

Open
cjluo-nv wants to merge 1 commit into main from fix/include-attention-mask-in-calibration

Conversation

Collaborator

@cjluo-nv cjluo-nv commented Apr 14, 2026

Summary

  • When include_labels=False (the default for PTQ calibration), get_dataset_dataloader was discarding the attention_mask produced by the tokenizer and only returning input_ids.
  • Without attention_mask, HuggingFace models create a full causal mask, causing padding tokens to participate in attention during calibration and skewing quantization statistics.
  • This fix includes attention_mask alongside input_ids so the model correctly ignores padding tokens during calibration forward passes.
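The before/after behavior described above can be sketched as follows. This is a hypothetical, simplified version of the collation logic (function and variable names are illustrative, not the actual modelopt implementation):

```python
# Hypothetical sketch of the fix: keep attention_mask in the returned
# batch instead of only input_ids. Names are illustrative only.
def make_calibration_batch(batch_encoded, include_labels=False):
    if include_labels:
        # This path already preserved the full tokenizer output.
        batch = dict(batch_encoded)
        batch["labels"] = [row[:] for row in batch["input_ids"]]
        return batch
    # Before the fix this returned only {"input_ids": ...}; the fix also
    # forwards attention_mask so padded positions are masked out.
    return {
        "input_ids": batch_encoded["input_ids"],
        "attention_mask": batch_encoded["attention_mask"],
    }

# Toy tokenizer output: second row is full length, first row is padded.
encoded = {
    "input_ids": [[1, 2, 3, 0], [4, 5, 6, 7]],
    "attention_mask": [[1, 1, 1, 0], [1, 1, 1, 1]],
}
batch = make_calibration_batch(encoded)
print(sorted(batch))  # ['attention_mask', 'input_ids']
```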

Details

In modelopt/torch/utils/dataset_utils.py, the tokenizer call at line 387 with padding=True produces both input_ids and attention_mask. The include_labels=True path (line 406) already preserves the full batch_encoded dict including attention_mask. However, the include_labels=False path was only keeping input_ids "for backward compatibility."
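The tokenizer behavior at issue can be demonstrated with a toy stand-in (real HuggingFace tokenizers return the same two keys when called with `padding=True`; the tokenization below is fake and for illustration only):

```python
# Toy stand-in for a tokenizer call with padding=True: every sequence in
# the batch is padded to the longest one, and attention_mask marks which
# positions are real tokens (1) vs padding (0).
def toy_tokenize(texts, pad_id=0):
    ids = [[len(w) for w in t.split()] for t in texts]  # fake token ids
    max_len = max(len(s) for s in ids)
    input_ids = [s + [pad_id] * (max_len - len(s)) for s in ids]
    attention_mask = [[1] * len(s) + [0] * (max_len - len(s)) for s in ids]
    return {"input_ids": input_ids, "attention_mask": attention_mask}

out = toy_tokenize(["a bb ccc", "dddd"])
print(out["attention_mask"])  # [[1, 1, 1], [1, 0, 0]]
```

Dropping the `attention_mask` key from this dict is exactly what the `include_labels=False` path was doing.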

During the calibration forward loop (_forward_loop_process_batch), the batch dict is unpacked as **kwargs into model.forward(). Without attention_mask, HF models default to attending to all positions including padding, which pollutes calibration statistics.
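A minimal sketch of that unpacking (the model class and return value here are hypothetical): whatever keys the batch dict carries become the keyword arguments of `model.forward()`, so a missing key is simply never passed.

```python
# Hypothetical model: reports whether an attention_mask was received.
class FakeModel:
    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        return {"saw_mask": attention_mask is not None}

model = FakeModel()
batch_without = {"input_ids": [[1, 2, 0]]}
batch_with = {"input_ids": [[1, 2, 0]], "attention_mask": [[1, 1, 0]]}

# The calibration loop unpacks the batch dict as **kwargs:
print(model.forward(**batch_without)["saw_mask"])  # False -> HF falls back
print(model.forward(**batch_with)["saw_mask"])     # True  -> padding masked
```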

Practical impact: With batch_size=1 there is no padding so the bug is invisible. With larger batch sizes and variable-length samples, shorter sequences get padded and the effect grows.
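A quick back-of-the-envelope illustration of why the bug is invisible at `batch_size=1` (toy sequence lengths, hypothetical helper):

```python
# Fraction of tokens in the padded batches that are padding. At
# batch_size=1 every batch is exactly its own length, so no padding.
def padding_fraction(lengths, batch_size):
    pads, total = 0, 0
    for i in range(0, len(lengths), batch_size):
        chunk = lengths[i:i + batch_size]
        width = max(chunk)  # batch is padded to its longest sequence
        total += width * len(chunk)
        pads += sum(width - n for n in chunk)
    return pads / total

lengths = [8, 3, 5, 2]  # variable-length calibration samples
print(padding_fraction(lengths, 1))  # 0.0 -- bug has no effect
print(padding_fraction(lengths, 4))  # 0.4375 -- ~44% of attended tokens are padding
```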

Test plan

  • Existing unit tests pass (tests/unit/torch/utils/test_dataset_utils.py)
  • Pre-commit hooks pass
  • [ ] Verify PTQ accuracy with batch_size > 1 on a padded calibration dataset (GPU required)

🤖 Generated with Claude Code

When include_labels=False (the default for PTQ calibration),
get_dataset_dataloader was only returning input_ids and discarding
the attention_mask produced by the tokenizer. This caused HF models
to create a full causal mask, allowing padding tokens to participate
in attention during calibration and skewing quantization statistics.

Include attention_mask alongside input_ids so the model correctly
ignores padding tokens during calibration forward passes.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>
@cjluo-nv cjluo-nv requested a review from a team as a code owner April 14, 2026 19:46
@github-actions
Contributor

PR Preview Action v1.8.1

🚀 View preview at
https://NVIDIA.github.io/Model-Optimizer/pr-preview/pr-1261/

Built to branch gh-pages at 2026-04-14 19:50 UTC.
Preview will be ready when the GitHub Pages deployment is complete.

@codecov

codecov bot commented Apr 14, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 76.90%. Comparing base (5ff1d7b) to head (c40a374).
⚠️ Report is 6 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1261      +/-   ##
==========================================
- Coverage   76.91%   76.90%   -0.02%     
==========================================
  Files         350      350              
  Lines       40481    40859     +378     
==========================================
+ Hits        31137    31423     +286     
- Misses       9344     9436      +92     
Flag       Coverage                 Δ
examples   44.14% <100.00%>   (+1.18%) ⬆️
gpu        57.39% <100.00%>   (-0.14%) ⬇️
unit       55.61% <0.00%>     (+0.09%) ⬆️

Flags with carried forward coverage won't be shown.

☔ View full report in Codecov by Sentry.

@cjluo-nv cjluo-nv requested a review from meenchen April 14, 2026 20:45
