[ROCm][DSv4] Refresh AITER work plan every step, not just on rebuild
`aiter.get_mla_metadata_v1` produces a `work_*`/`reduce_*` plan that is
keyed on the *actual* per-batch kv lengths, not just on shapes. The
persistent ASM `mla_a8w8_qh16_qseqlen1_gqaratio16_lse_ps` kernel reads
out of bounds (causing a GPU memory access fault) if those buffers are
left stale across steps with different kv lengths.
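The failure mode can be modeled in a few lines of plain Python. This is a hedged toy, not AITER's actual plan layout: `plan_for` stands in for the schedule `get_mla_metadata_v1` derives from kv lengths, and an `IndexError` stands in for the GPU memory access fault.

```python
# Toy model: the persistent kernel trusts per-batch extents recorded in
# the plan; if kv shrinks for a batch entry while the plan still records
# the old extent, the read runs past the end of the data.

def plan_for(kv_lens):
    # one "work item" per (batch, token) pair -- a stand-in for the
    # work_*/reduce_* schedule keyed on actual kv lengths
    return [(b, t) for b, n in enumerate(kv_lens) for t in range(n)]

def kernel(plan, kv):
    # IndexError here plays the role of the GPU memory access fault
    return [kv[b][t] for b, t in plan]

kv1 = [[10, 11, 12], [20, 21]]              # lens = [3, 2]
plan = plan_for([3, 2])
assert kernel(plan, kv1) == [10, 11, 12, 20, 21]

kv2 = [[10, 11, 12, 13], [20]]              # lens = [4, 1]
try:
    kernel(plan, kv2)                        # stale plan -> out of bounds
    stale_survived = True
except IndexError:
    stale_survived = False
assert not stale_survived

# Refreshing the plan against the current lengths is safe again:
assert kernel(plan_for([4, 1]), kv2) == [10, 11, 12, 13, 20]
```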
Fix the cudagraph-clean refactor so the metadata is rewritten in-place
on every per-step call against the current `kv_indptr`. The buffer
sizes returned by `get_mla_metadata_info_v1` are determined by shapes
+ `max_split_per_batch` only, so they remain large enough for any kv
length distribution and the data pointers stay stable for graph capture.
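The allocate-once / refresh-in-place contract that graph capture relies on can be sketched as follows. Names and the `bytearray` buffers are hypothetical stand-ins for the real device tensors; the point is only that buffer identity (the data pointer) survives a content refresh.

```python
# Sketch of the contract: rebuild() allocates once at the shape-derived
# maximum size; refresh_metadata() rewrites contents in place so the
# buffer's identity (its data pointer) never changes across steps.

class Scratch:
    def rebuild(self, work_bytes, reduce_bytes):
        # sizes depend only on shapes + max_split_per_batch, so one
        # allocation covers any kv-length distribution
        self.work = bytearray(work_bytes)
        self.reduce = bytearray(reduce_bytes)

    def refresh_metadata(self, kv_lens):
        # rewrite in place; a captured graph replays against the same
        # addresses, only the plan contents differ per step
        plan = bytes(n % 256 for n in kv_lens)
        self.work[: len(plan)] = plan

s = Scratch()
s.rebuild(16, 8)
work_id = id(s.work)          # stands in for tensor.data_ptr()
s.refresh_metadata([3, 2])
s.refresh_metadata([4, 1])    # different kv lengths, same buffer
assert id(s.work) == work_id
assert s.work[:2] == bytes([4, 1])
```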
* `AiterSparseScratch.rebuild()` now only allocates buffers and stores
the static gqa/topk/dtype parameters; it no longer requires a
`kv_indptr_seed` and no longer runs the metadata builder itself.
* New `AiterSparseScratch.refresh_metadata()` reruns
  `get_mla_metadata_v1`, writing the result into the same
  `work_*`/`reduce_*` slots.
* `_aiter_decode_one_scope` writes `valid_mask`/`valid_lens`/
`kv_indptr`/`kv_indices_2d`/`q_fp8` directly into scratch every
step, then calls `refresh_metadata()` and `mla.mla_decode_fwd`.
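The per-step ordering above can be sketched with plain-Python stand-ins (a dict of lists for the scratch tensors, stub callables for `refresh_metadata` and `mla.mla_decode_fwd`); the essential property is that inputs land in preallocated slots and the plan is refreshed before the kernel runs.

```python
# Hedged sketch of the per-step flow: write inputs in place, refresh the
# work_*/reduce_* plan against the current kv_indptr, then decode.

def decode_one_scope(scratch, kv_lens, refresh_metadata, mla_decode_fwd):
    # 1) write the step's inputs into the preallocated slots in place
    indptr = [0]
    for n in kv_lens:
        indptr.append(indptr[-1] + n)
    scratch["kv_indptr"][: len(indptr)] = indptr
    # 2) rebuild the split plan against the current kv_indptr
    refresh_metadata(scratch)
    # 3) only then launch the decode kernel
    return mla_decode_fwd(scratch)

scratch = {"kv_indptr": [0] * 8, "plan": None}
calls = []

def refresh_metadata(s):
    calls.append("refresh")
    s["plan"] = tuple(s["kv_indptr"])

def mla_decode_fwd(s):
    calls.append("decode")
    return s["plan"]

decode_one_scope(scratch, [3, 2], refresh_metadata, mla_decode_fwd)
assert calls == ["refresh", "decode"]       # refresh precedes the kernel
assert scratch["kv_indptr"][:3] == [0, 3, 5]
```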
Validated with the standalone `bench_remote/_unit_test_cudagraph.py`
harness on MI355X:
- Call 1 (lens=[3,2]): success, scratch key set.
- Call 2 (same lens): rebuild skipped, all data_ptrs stable, output
bit-identical to call 1.
- Call 3 (lens=[4,1]): all data_ptrs still stable, output differs as
expected (max abs diff = 2.39 vs identical-input call), no fault.
- Parity check vs the original non-cudagraph implementation:
max abs diff = 0.000000.
Signed-off-by: Chuan Li <chuanli1101@gmail.com>
Co-authored-by: Cursor
Signed-off-by: Li <chuali@amd.com>