Skip to content

Commit 422a5f0

Browse files
committed
Fix decode calibration: full-cache kv_bound + 128x128 block to match PyTorch
Signed-off-by: Kai Xu <kaix@nvidia.com>
1 parent 8490c42 commit 422a5f0

3 files changed

Lines changed: 21 additions & 4 deletions

File tree

modelopt/torch/kernels/common/attention/triton_fa.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,10 @@ def _load_sparsity_helpers() -> None:
8080
_FWD_CONFIGS = [triton.Config({"BLOCK_M": 128, "BLOCK_N": 64}, num_stages=1, num_warps=4)]
8181

8282
_MEASURE_BLOCK_M = 128
83-
_MEASURE_BLOCK_N = 64
83+
# 128 (not 64) so the kernel sparsity-measurement block matches the PyTorch
84+
# flash_skip_softmax calibration block (br = bc = 128) and the Triton
85+
# calibration kernel; otherwise the two measure at different granularities.
86+
_MEASURE_BLOCK_N = 128
8487
_MEASURE_NUM_STAGES = 1
8588
_MEASURE_NUM_WARPS = 4
8689

modelopt/torch/kernels/sparsity/attention/calibrate.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,17 @@ def _attn_fwd_calibrate(
111111
local_skipped = tl.zeros([PADDED_THRESHOLDS], dtype=tl.int32)
112112
num_tiles = 0
113113

114-
kv_bound = seq_len_kv if not IS_CAUSAL else tl.minimum((tile_q + 1) * BLOCK_M, seq_len_kv)
114+
# Causal bound: when Q is a suffix of KV (decode: seq_len_q == 1 against a
115+
# long cache; or chunked prefill), the visible KV extends to
116+
# causal_offset + (tile_q + 1) * BLOCK_M. Without the offset the loop stops
117+
# at the first BLOCK_M KV tokens, so decode would only ever measure the
118+
# start of the cache instead of the whole thing.
119+
causal_offset = seq_len_kv - seq_len_q
120+
kv_bound = (
121+
seq_len_kv
122+
if not IS_CAUSAL
123+
else tl.minimum(causal_offset + (tile_q + 1) * BLOCK_M, seq_len_kv)
124+
)
115125

116126
for kv_start in range(0, kv_bound, BLOCK_N):
117127
kv_start = tl.multiple_of(kv_start, BLOCK_N)
@@ -261,8 +271,10 @@ def attention_calibrate(
261271
sm_scale = 1.0 / (HEAD_DIM**0.5) if softmax_scale is None else softmax_scale
262272
qk_scale = sm_scale * LOG2E
263273
BLOCK_D = triton.next_power_of_2(HEAD_DIM)
274+
# 128x128 to match the PyTorch flash_skip_softmax calibration block (br = bc = 128),
275+
# so Triton-kernel and PyTorch calibration measure sparsity at the same granularity.
264276
BLOCK_M = 128
265-
BLOCK_N = 64
277+
BLOCK_N = 128
266278

267279
if b_seq_len_k is None:
268280
b_seq_len_k = b_seq_len

tests/gpu/torch/kernels/sparsity/attention/test_triton_fa_calibrate.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,9 @@ def test_first_measured_call_has_real_tile_count_with_autotune(self):
319319
assert result.returncode == 0, result.stderr
320320
totals = [line for line in result.stdout.splitlines() if line.startswith("TOTAL=")]
321321
assert totals, result.stdout
322-
assert int(totals[-1].split("=", maxsplit=1)[1]) == 8
322+
# seq_len=256, _MEASURE_BLOCK_M = _MEASURE_BLOCK_N = 128, non-causal:
323+
# Q tiles = ceil(256/128) = 2, KV tiles = ceil(256/128) = 2, total = 4.
324+
assert int(totals[-1].split("=", maxsplit=1)[1]) == 4
323325

324326
def test_measure_sparsity_without_skip_is_noop(self):
325327
"""Without skip-softmax, measure_sparsity doesn't attach counters."""

0 commit comments

Comments
 (0)