
Commit ea14ac0

[ROCm][DSv4] Refresh AITER work plan every step, not just on rebuild
`aiter.get_mla_metadata_v1` produces a `work_*`/`reduce_*` plan that is keyed on the *actual* per-batch kv lengths, not just on shapes. The persistent ASM `mla_a8w8_qh16_qseqlen1_gqaratio16_lse_ps` kernel reads out of bounds (causing a GPU memory access fault) if those buffers are left stale across steps with different kv lengths.

Fix the cudagraph-clean refactor so the metadata is rewritten in-place on every per-step call against the current `kv_indptr`. The buffer sizes returned by `get_mla_metadata_info_v1` are determined by shapes + `max_split_per_batch` only, so they remain large enough for any kv-length distribution and the data pointers stay stable for graph capture.

* `AiterSparseScratch.rebuild()` now only allocates buffers and stores the static gqa/topk/dtype parameters; it no longer requires a `kv_indptr_seed` and no longer runs the metadata builder itself.
* New `AiterSparseScratch.refresh_metadata()` reruns `get_mla_metadata_v1`, writing into the same `work_*`/`reduce_*` slots.
* `_aiter_decode_one_scope` writes `valid_mask`/`valid_lens`/`kv_indptr`/`kv_indices_2d`/`q_fp8` directly into scratch every step, then calls `refresh_metadata()` and `mla.mla_decode_fwd`.

Validated with the standalone `bench_remote/_unit_test_cudagraph.py` harness on MI355X:

- Call 1 (lens=[3,2]): success, scratch key set.
- Call 2 (same lens): rebuild skipped, all data_ptrs stable, output bit-identical to call 1.
- Call 3 (lens=[4,1]): all data_ptrs still stable, output differs as expected (max abs diff = 2.39 vs identical-input call), no fault.
- Parity check vs the original non-cudagraph implementation: max abs diff = 0.000000.

Signed-off-by: Chuan Li <chuanli1101@gmail.com>
Co-authored-by: Cursor
Signed-off-by: Li <chuali@amd.com>
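The allocate-once / refresh-every-step split described above can be sketched framework-free. `ScratchSketch`, its buffer layout, and the simplified work plan below are hypothetical stand-ins for illustration only, not the real `AiterSparseScratch` or AITER API:

```python
# Hypothetical sketch of the rebuild/refresh split: rebuild() sizes buffers
# from shape-only bounds and never reallocates; refresh_metadata() rewrites
# the work plan in place every step, so buffer identity (the "pointers"
# a captured graph replays against) never changes.

class ScratchSketch:
    def __init__(self, max_batch: int, max_split_per_batch: int):
        # "rebuild": allocate once, sized by shape bounds only.
        self.kv_indptr = [0] * (max_batch + 1)
        self.work_info = [(0, 0)] * (max_batch * max_split_per_batch)

    def refresh_metadata(self, kv_lens: list[int]) -> None:
        # Rewrite the plan against the *current* kv lengths, in place.
        for b, n in enumerate(kv_lens):
            self.kv_indptr[b + 1] = self.kv_indptr[b] + n
            self.work_info[b] = (b, n)


scratch = ScratchSketch(max_batch=2, max_split_per_batch=4)
ptrs_before = (id(scratch.kv_indptr), id(scratch.work_info))

scratch.refresh_metadata([3, 2])   # step 1: lens=[3,2]
plan_1 = list(scratch.kv_indptr)

scratch.refresh_metadata([4, 1])   # step 2: new lens, same buffers
plan_2 = list(scratch.kv_indptr)
ptrs_after = (id(scratch.kv_indptr), id(scratch.work_info))
```

The point mirrored from the commit: skipping `refresh_metadata` on step 2 would leave `kv_indptr`/`work_info` describing lens=[3,2] while the kernel runs against lens=[4,1].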
1 parent 0879185 commit ea14ac0

1 file changed: vllm/v1/attention/ops/rocm_aiter_dsv4_decode.py
Lines changed: 80 additions & 34 deletions
@@ -36,12 +36,19 @@ class AiterSparseScratch:
     and reuse across all 61 DSv4 attention layers in the same decode step.
     Buffers fall into three groups:
 
-    * AITER persistent metadata (`work_*`, `reduce_*`) — written once per
-      rebuild by `aiter.get_mla_metadata_v1`, read every step by the kernel.
+    * AITER persistent metadata (`work_*`, `reduce_*`) — sized at rebuild
+      from `get_mla_metadata_info_v1` (purely shape-determined), but rewritten
+      in-place by `aiter.get_mla_metadata_v1` every step, because the work
+      plan encodes the *actual* kv lengths and the persistent ASM kernel
+      reads out-of-bounds if it is left stale.
     * Per-step indexing/IO buffers (`qo_indptr`, `kv_indptr`, `kv_indices_2d`,
       `kv_last_page_lens`, `valid_mask`, `valid_lens`, `col_arange`, `q_fp8`,
      `out_buf`) — written in-place each step.
     * Constant scale tensors (`q_scale`, `kv_scale`) — initialised once.
+
+    The metadata buffers, the per-step buffers and the scale tensors all keep
+    stable `data_ptr()`s across the entire lifetime of a shape key, so a
+    HIP/CUDA graph captured around the second or later step replays correctly.
     """
 
     __slots__ = (
@@ -66,6 +73,15 @@ class AiterSparseScratch:
         # Constant scale tensors (always 1.0 for our quantization scheme)
         "q_scale",
         "kv_scale",
+        # GQA ratios captured at rebuild time so per-step refresh can call
+        # `get_mla_metadata_v1` with the same parameters every time.
+        "_gqa_ratio",
+        "_nhead_kv",
+        "_page_size",
+        "_topk",
+        "_dtype",
+        "_kvtype",
+        "_max_split_per_batch",
         # Identity key for cache lookups
         "_key",
     )
@@ -102,16 +118,19 @@ def rebuild(
         device: torch.device,
         max_split_per_batch: int = 256,
     ) -> None:
-        """Allocate every persistent buffer for the given shape key and run
-        `aiter.get_mla_metadata_v1` once to populate the AITER metadata.
-
-        Static buffers (`qo_indptr`, `kv_last_page_lens`, `col_arange`,
-        `q_scale`, `kv_scale`) are filled here and never rewritten on the
-        per-step path.
+        """Allocate every persistent buffer for the given shape key.
+
+        Buffer sizes returned by `get_mla_metadata_info_v1` are determined by
+        shapes and `max_split_per_batch` only, so they are large enough for
+        any kv-length distribution. The actual work plan is computed on the
+        per-step path by `refresh_metadata`, which writes these buffers
+        in-place using the freshly populated `qo_indptr`/`kv_indptr`/
+        `kv_last_page_lens` -- those pointers stay stable for the lifetime
+        of this scratch.
         """
         import aiter
 
-        # ---- AITER persistent metadata ---------------------------------- #
+        # ---- AITER persistent metadata buffers (sizes only) ------------- #
         (
             (wmd_size, wmd_type),
             (wi_size, wi_type),
@@ -141,7 +160,6 @@ def rebuild(
         self.qo_indptr = torch.arange(
             total_q + 1, dtype=torch.int32, device=device
         )
-        # kv_indptr is rewritten in-place every step.
         self.kv_indptr = torch.zeros(
             total_q + 1, dtype=torch.int32, device=device
         )
@@ -172,36 +190,53 @@ def rebuild(
         self.q_scale = torch.ones(1, dtype=torch.float32, device=device)
         self.kv_scale = torch.ones(1, dtype=torch.float32, device=device)
 
-        # ---- Run AITER metadata builder once --------------------------- #
-        # The kv_indptr passed here only needs to be a valid placeholder of
-        # the right shape — the metadata buffers describe split points based
-        # on shape, not per-step lengths.
+        # Cache parameters for `refresh_metadata`.
+        self._gqa_ratio = nhead // nhead_kv
+        self._nhead_kv = nhead_kv
+        self._page_size = page_size
+        self._topk = topk
+        self._dtype = dtype
+        self._kvtype = kvtype
+        self._max_split_per_batch = max_split_per_batch
+
+        self._key = (total_q, nhead, topk, d_qk, d_v, dtype, kvtype)
+
+    def refresh_metadata(self) -> None:
+        """Re-run `aiter.get_mla_metadata_v1` against the current
+        `kv_indptr` / `kv_last_page_lens`, writing the new work plan into the
+        same `work_*` / `reduce_*` buffers in-place.
+
+        Must be called every step *after* `kv_indptr` is updated and *before*
+        `aiter.mla.mla_decode_fwd`. The persistent ASM kernel reads
+        out-of-bounds if it is left with a stale work plan, so this call is
+        not optional even when shapes are unchanged.
+        """
+        import aiter
+
         aiter.get_mla_metadata_v1(
             self.qo_indptr,
             self.kv_indptr,
             self.kv_last_page_lens,
-            nhead // nhead_kv,
-            nhead_kv,
+            self._gqa_ratio,
+            self._nhead_kv,
             True,
             self.work_meta_data,
             self.work_info_set,
             self.work_indptr,
             self.reduce_indptr,
             self.reduce_final_map,
             self.reduce_partial_map,
-            page_size=page_size,
-            kv_granularity=max(page_size, 16),
+            page_size=self._page_size,
+            kv_granularity=max(self._page_size, 16),
             max_seqlen_qo=1,
             uni_seqlen_qo=1,
             fast_mode=True,
-            max_split_per_batch=max_split_per_batch,
-            topk=topk,
-            dtype_q=dtype,
-            dtype_kv=kvtype,
+            max_split_per_batch=self._max_split_per_batch,
+            topk=self._topk,
+            dtype_q=self._dtype,
+            dtype_kv=self._kvtype,
         )
 
-        self._key = (total_q, nhead, topk, d_qk, d_v, dtype, kvtype)
-
 
 def aiter_sparse_attn_decode(
     *,
@@ -351,9 +386,9 @@ def _aiter_decode_one_scope(
     topk_max = indices_2d.size(-1)
 
     fp8_dtype = torch.float8_e4m3fn
-
-    # ---- Lazy/keyed scratch (re)build ---------------------------------- #
-    if not scratch.matches(total_q, h_q, topk_max, d_qk, d_v, fp8_dtype, fp8_dtype):
+    if not scratch.matches(
+        total_q, h_q, topk_max, d_qk, d_v, fp8_dtype, fp8_dtype
+    ):
         scratch.rebuild(
             total_q=total_q,
             nhead=h_q,
@@ -367,7 +402,7 @@ def _aiter_decode_one_scope(
             device=device,
         )
 
-    # ---- Build valid_mask + valid_lens in-place ------------------------ #
+    # ---- Build valid_mask + valid_lens directly into scratch ----------- #
    if lens is not None:
         if lens.numel() == b and s_q > 1:
             lens_per_tok = lens.repeat_interleave(s_q)
@@ -383,17 +418,22 @@ def _aiter_decode_one_scope(
     else:
         torch.ge(indices_2d, 0, out=scratch.valid_mask)
 
-    torch.sum(scratch.valid_mask, dim=-1, dtype=torch.int32, out=scratch.valid_lens)
+    torch.sum(
+        scratch.valid_mask, dim=-1, dtype=torch.int32, out=scratch.valid_lens
+    )
+
+    # ---- Compute kv_indptr in place (cumsum(min(valid_lens, topk))) ---- #
+    scratch.kv_indptr[0] = 0
+    torch.cumsum(
+        scratch.valid_lens.clamp(max=topk_max),
+        dim=0,
+        out=scratch.kv_indptr[1:],
+    )
 
     # ---- Fill kv_indices_2d in-place: keep valid, sentinel -1 elsewhere - #
     scratch.kv_indices_2d.copy_(indices_2d)
     scratch.kv_indices_2d.masked_fill_(~scratch.valid_mask, -1)
 
-    # ---- kv_indptr = cumsum(min(valid_lens, topk_max)), prefixed by 0 -- #
-    seq_lens_kv = scratch.valid_lens.clamp(max=topk_max)
-    scratch.kv_indptr[0] = 0
-    torch.cumsum(seq_lens_kv, dim=0, out=scratch.kv_indptr[1:])
-
     # ---- Cast q to FP8 in-place into the preallocated buffer ----------- #
     scratch.q_fp8.copy_(q.reshape(total_q, h_q, d_qk))
 
@@ -403,6 +443,12 @@ def _aiter_decode_one_scope(
     kv_fp8 = blocked_k.to(fp8_dtype)
     kv_view = kv_fp8.view(-1, 1, 1, d_qk)
 
+    # ---- Refresh AITER work plan against the current kv_indptr --------- #
+    # The persistent ASM kernel encodes per-batch lengths into work_*; if we
+    # leave that stale, the kernel reads out of bounds. Rewrite into the same
+    # buffers in place so pointers stay stable for cudagraph capture.
+    scratch.refresh_metadata()
+
     # ---- Persistent-mode FP8 mla_decode_fwd ---------------------------- #
     _, lse = aiter.mla.mla_decode_fwd(
         scratch.q_fp8,
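The `kv_indptr` recomputation in the diff writes a cumsum into a slice of a preallocated buffer, so the buffer's storage address never moves across steps. A NumPy stand-in for the `torch.cumsum(..., out=scratch.kv_indptr[1:])` pattern (values here are illustrative, not from the test harness):

```python
import numpy as np

# kv_indptr[0] stays 0; cumsum of the clamped lengths lands in kv_indptr[1:].
# Writing into the slice view mutates the original buffer in place, so the
# data pointer a captured graph recorded remains valid.
topk_max = 4
kv_indptr = np.zeros(3, dtype=np.int32)          # preallocated once
valid_lens = np.array([3, 6], dtype=np.int32)    # 6 exceeds topk_max

ptr_before = kv_indptr.__array_interface__["data"][0]
kv_indptr[0] = 0
np.cumsum(np.minimum(valid_lens, topk_max), out=kv_indptr[1:])
ptr_after = kv_indptr.__array_interface__["data"][0]
```

A fresh `np.cumsum(...)` without `out=` would allocate a new array each step, which is exactly what breaks graph replay.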