Skip to content

Commit 5ca71ad

Browse files
kaix-nvrohansjoshi
authored andcommitted
Fix decode calibration: padded row in decode
Signed-off-by: Kai Xu <kaix@nvidia.com>
1 parent 062231f commit 5ca71ad

2 files changed

Lines changed: 17 additions & 4 deletions

File tree

modelopt/torch/kernels/common/attention/triton_fa.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,8 @@ def _attn_fwd(
366366
skip_tile = _skip_softmax_decision(
367367
scores,
368368
row_max,
369+
q_pos,
370+
seq_len_q,
369371
SKIP_THRESHOLD_LOG2,
370372
Sparsity_total,
371373
Sparsity_skipped,

modelopt/torch/kernels/sparsity/attention/skip_softmax_helpers.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,8 @@ def _apply_sparse_nm_to_qk_tile(
142142
def _skip_softmax_decision(
143143
scores,
144144
row_max,
145+
q_pos,
146+
seq_len_q,
145147
SKIP_THRESHOLD_LOG2: tl.constexpr,
146148
Sparsity_total,
147149
Sparsity_skipped,
@@ -159,16 +161,25 @@ def _skip_softmax_decision(
159161
The threshold is converted to the kernel's scaled log2 score space by the
160162
Python wrapper so it can be compared directly against ``scores``.
161163
164+
``q_pos`` (``[BLOCK_M]`` absolute query positions) and the scalar
165+
``seq_len_q`` identify padding rows. When a tile has fewer than ``BLOCK_M``
166+
valid queries — decode has one valid query plus ``BLOCK_M - 1`` padding
167+
rows, and the last prefill tile is partial when ``seq_q`` is not a multiple
168+
of ``BLOCK_M`` — the padding rows carry zero scores that are never
169+
negligible versus their own running max and would otherwise veto every
170+
skip. They are forced skippable so the decision reflects only valid rows.
171+
162172
Returns:
163-
True when *all* Q rows in the tile satisfy the skip criterion.
173+
True when *all valid* Q rows in the tile satisfy the skip criterion.
164174
165175
When ``MEASURE_SPARSITY`` is set, also records total/skipped tile counts
166176
via atomic adds on ``Sparsity_total`` / ``Sparsity_skipped``.
167177
"""
168178
tile_row_max = tl.max(scores, 1) # [BLOCK_M] — ~m_i^(j) (scaled)
169-
# Per-row: True if row's tile max is negligible vs running max
170-
can_skip = tile_row_max < (row_max + SKIP_THRESHOLD_LOG2)
171-
# Per-tile: skip entire tile only if ALL rows are negligible
179+
# Per-row: True if the row's tile max is negligible vs running max, OR the
180+
# row is padding (q_pos >= seq_len_q) so it must not veto the tile decision.
181+
can_skip = (tile_row_max < (row_max + SKIP_THRESHOLD_LOG2)) | (q_pos >= seq_len_q)
182+
# Per-tile: skip entire tile only if ALL valid rows are negligible
172183
skip_tile = tl.min(can_skip.to(tl.int32)) == 1
173184

174185
if MEASURE_SPARSITY:

0 commit comments

Comments
 (0)