@@ -142,6 +142,8 @@ def _apply_sparse_nm_to_qk_tile(
142142def _skip_softmax_decision (
143143 scores ,
144144 row_max ,
145+ q_pos ,
146+ seq_len_q ,
145147 SKIP_THRESHOLD_LOG2 : tl .constexpr ,
146148 Sparsity_total ,
147149 Sparsity_skipped ,
@@ -159,16 +161,25 @@ def _skip_softmax_decision(
159161 The threshold is converted to the kernel's scaled log2 score space by the
160162 Python wrapper so it can be compared directly against ``scores``.
161163
164+ ``q_pos`` (``[BLOCK_M]`` absolute query positions) and the scalar
165+ ``seq_len_q`` identify padding rows. When a tile has fewer than ``BLOCK_M``
166+ valid queries — decode has one valid query plus ``BLOCK_M - 1`` padding
167+ rows, and the last prefill tile is partial when ``seq_q`` is not a multiple
168+ of ``BLOCK_M`` — the padding rows carry zero scores that are never
169+ negligible versus their own running max and would otherwise veto every
170+ skip. They are forced skippable so the decision reflects only valid rows.
171+
162172 Returns:
163- True when *all* Q rows in the tile satisfy the skip criterion.
173+ True when *all valid * Q rows in the tile satisfy the skip criterion.
164174
165175 When ``MEASURE_SPARSITY`` is set, also records total/skipped tile counts
166176 via atomic adds on ``Sparsity_total`` / ``Sparsity_skipped``.
167177 """
168178 tile_row_max = tl .max (scores , 1 ) # [BLOCK_M] — ~m_i^(j) (scaled)
169- # Per-row: True if row's tile max is negligible vs running max
170- can_skip = tile_row_max < (row_max + SKIP_THRESHOLD_LOG2 )
171- # Per-tile: skip entire tile only if ALL rows are negligible
179+ # Per-row: True if the row's tile max is negligible vs running max, OR the
180+ # row is padding (q_pos >= seq_len_q) so it must not veto the tile decision.
181+ can_skip = (tile_row_max < (row_max + SKIP_THRESHOLD_LOG2 )) | (q_pos >= seq_len_q )
182+ # Per-tile: skip entire tile only if ALL valid rows are negligible
172183 skip_tile = tl .min (can_skip .to (tl .int32 )) == 1
173184
174185 if MEASURE_SPARSITY :
0 commit comments