Fix address overflow extra issue #3124
Merged
valarLip
approved these changes
May 12, 2026
Motivation
PR #2189 ("[FIX] address overflow fix on fmha_bwd of gfx942/gfx950") shipped two unintended regressions on gfx950 that fail in production-shape tests but were not covered by the existing `smoke_test_bwd_v3.sh`:

- `bwd_hd128_*_causal_br_a32_psskddv_group.co` (bf16 + fp16) takes a GPU memory access fault or returns wrong values for any group-mode + bottom-right causal + GQA shape, such as `b=2 h=24 h_k=1 s=2048 s_k=2048 mask=b -bwd_v3=1 -v3_atomic_fp32=1 -mode=1`. Originally surfaced by a downstream FlashAttention test.
- The A16 group kernels (`bwd_hd128_{bf16,fp16}_a16_psskddv_group.co`, `bwd_hd128_{bf16,fp16}_causal_a16_psskddv_group.co`, `bwd_hd128_{bf16,fp16}_causal_br_a16_psskddv_group.co`) return wrong values on every shape, including SMOKE (`b=2 h=3 h_k=3 s=200 s_k=200`). The cause is cross-batch data corruption from a `dq_acc` layout mismatch.

Both regressions trace to a single upstream commit, `f3b06a7` in `poc_kl_merg` ("fix address overflow on fmha bwd"), shipped as PR #2189 on the aiter side. The SP3 source fixes live in `poc_kl_merg` under `scripts/fmha_bwd/prod_kernels/` (3 SP3 files); this PR ships only the regenerated `.co` binaries so they can be picked up at runtime via `AITER_ASM_DIR`.

Technical Details
Kernel changes are in https://github.com/niels-zhang/poc_kl_merg/pull/33

Three SP3 source files (in `poc_kl_merg`, separately committed) and eight regenerated `.co` files (this PR):

Bucket 1 - `FMHA_BWD_D128_..._A32_cas_br_kb_Genl.sp3` (commits a3812ca + 4a2508c)

PR #2189 lifted the per-block `dK`/`dV` byte offset `Seqs_dk * loop_idx_k` from a per-thread vector add on `v_dK_addr` to a scalar add on `s_dK_buf[0:1]` (with `s_addc_u32` for the carry). The change is algebraically equivalent for `idxen:1` reads, but it failed to also adjust `s_dK_buf[2]` (`num_records`, the OOB envelope). With a bumped base but an un-bumped envelope, the buffer-OOB rule `index < num_records` never fires, so writes to the bottom partial block of the per-batch `dK` region spill into the next batch's region. This surfaces as KGrad/VGrad mismatches at the batch-1 boundary or as full GPU memory faults under GQA replication.

Fix: keep PR #2189's 64-bit-safe scalar bump, and additionally shrink/restore `s_dK_buf[2]` and `s_dV_buf[2]` by the same offset in DW around the writes and at `code_exit_mask:`. 4 hunks, +14 lines net.

Bucket 2 - `FMHA_BWD_D128_..._A16_Genl.sp3` + `FMHA_BWD_D128_DQ_SHUFFLE.sp3` (commit 3bf422d)

The host allocates `dq_acc` for A16 group as `(b, h, padded_max_sq, d)` with `a16_dq_acc_seq = (max_seqlen_q + 15) / 16 * 16` (see `op_tests/cpp/mha/benchmark_mha_bwd.cpp:357-360, 535, 550`). PR #2189 rewrote 3 separate MODE==1 indexing sites to use a flat `(h, sum_padded_seqlen, d)` layout and unpadded `seqlen_q` for OOB, mismatching the host allocation:

| Site | PR #2189 | Fix |
|---|---|---|
| `FMHA_BWD_D128_DQ_SHUFFLE.sp3` (~line 292) | `sq_start * Seqs_dq_acc` for dq_acc READ batch_offset | `tg_idz * BAs_dq_acc` |
| `FMHA_BWD_D128_..._A16_Genl.sp3` (~line 2857) | `sq_start * Seqs_dQ` for dq_acc WRITE batch_offset | `tg_idz * BAs_dq` |
| `FMHA_BWD_D128_..._A16_Genl.sp3` (~line 2871) | `seqlen_q * Seqs_dQ` for OOB envelope (unpadded -> rows in `[seqlen, padded)` get garbage) | `padded_seqlen_q * Seqs_dQ` |

[...] `s_mul_hi_u32` + `s_addc_u32` carry into `s_dQ_buf[1]`) is preserved - the fix only drops [...]

Files in this PR (8 .co binaries)
Regenerated from the SP3 source fixes above. Each file's MODE==0 (batch
mode) bytes are functionally equivalent to the prior shipped versions; only
the MODE==1 (group mode) code paths change.
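
The two failure modes described under Technical Details can be modeled on the host side with plain integer arithmetic. The sketch below is Python rather than SP3, and every function name in it is hypothetical. It assumes byte-granular buffer windows for Bucket 1 and contiguous row-major strides for Bucket 2 (batch stride `h * padded_sq * d` elements, row stride `d` elements); neither assumption is confirmed by the kernel sources. The Bucket 1 sizes (a 4096-byte region, a 1024-byte block offset) are made up for illustration. The Bucket 2 numbers use the SMOKE shape with `d=128` and assume `sq_start` for batch 1 equals the unpadded length 200.

```python
# Illustrative host-side model only; this is NOT the SP3 kernel code.
# All function names are hypothetical.


def buffer_window(base, num_records):
    """Byte range [start, end) that a buffer descriptor allows (simplified model)."""
    return base, base + num_records


def bump_base_only(base, num_records, offset):
    # Bucket 1 regression: the base is advanced but the envelope is left
    # unchanged, so the allowed window slides past the original end.
    return base + offset, num_records


def bump_base_and_shrink(base, num_records, offset):
    # Bucket 1 fix: shrink the envelope by the same offset so the window
    # still ends where the original region ends.
    return base + offset, num_records - offset


def padded_seqlen(max_seqlen_q, align=16):
    # Mirrors a16_dq_acc_seq = (max_seqlen_q + 15) / 16 * 16 from the PR text.
    return (max_seqlen_q + align - 1) // align * align


def batch_offset_padded(batch_idx, h, padded_sq, d):
    # Host layout (b, h, padded_max_sq, d), assumed contiguous:
    # one batch spans h * padded_sq * d elements.
    return batch_idx * h * padded_sq * d


def batch_offset_flat(sq_start, d):
    # Flat (h, sum_padded_seqlen, d) indexing, assumed row stride of d elements.
    return sq_start * d


if __name__ == "__main__":
    # Bucket 1: original region is [0, 4096); block offset is 1024 bytes.
    print(buffer_window(*bump_base_only(0, 4096, 1024)))        # (1024, 5120)
    print(buffer_window(*bump_base_and_shrink(0, 4096, 1024)))  # (1024, 4096)

    # Bucket 2: SMOKE shape b=2 h=3 s=200, head dim 128.
    padded = padded_seqlen(200)                      # 208
    print(batch_offset_padded(1, 3, padded, 128))    # 79872
    print(batch_offset_flat(200, 128))               # 25600
```

In this model the base-only bump permits accesses up to byte 5120 although the region ends at 4096, and the two batch-1 offsets (79872 vs 25600 elements) disagree. These correspond to the spill into the next batch's region and the cross-batch `dq_acc` corruption, respectively.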
Test Plan
Test Result
Submission Checklist