[Issue]: aiter ASM paged-attention with block_size=16 #702

@yhl-amd

Description

aiter ASM paged-attention block_id truncation issue

Summary

Related PR: #631

The aiter precompiled ASM paged-attention kernels (pa_*.co family on gfx950 / gfx942) silently corrupt KV cache reads when the block_id value loaded from the block_tables tensor exceeds 65,535 (= 2^16 − 1). The triton/gluon paged-attention path (paged_attention_triton) does not exhibit the bug.

In ATOM Eagle3 spec decoding, the symptom is a sharp, permanent acceptance-rate collapse from ~80% to ~10–25% the first time the draft KV pool's block_id crosses 65,535. Final completions become garbage tokens; the engine never recovers without a server restart.

Affected binaries

Confirmed via [aiter] LoadKernel: logs with AITER_LOG_LEVEL=INFO:

/root/aiter/hsa/gfx950/pa/pa_bf16_pertokenFp8_gqa8_2tg_4w.co            (qlen=1, non-MTP)
/root/aiter/hsa/gfx950/pa/pa_bf16_pertokenFp8_gqa8_1tg_4w_mtp_msk1.co   (qlen=4, MTP)

The entire same-build pa_*_blkSz=16.co family is likely affected, including the bf16/bf16 variants used with --kv_cache_dtype bf16.

C++ entry point: /root/aiter/csrc/py_itfs_cu/asm_pa.cu:166 pa_fwd.

Trigger conditions (necessary and sufficient)

  1. Inference runs through paged_attention_asm (default ATOM dispatch path before the workaround).
  2. num_kvcache_blocks > 65,535 so block_id can exceed 65,535 (= 2^16 − 1).
  3. Sustained load drives block_id past 65,535.

bf16 KV at --gpu-memory-utilization 0.8 gives num_blocks ≈ 134k and hits the boundary in 17–19 waves of 64-concurrent same-prompt requests. fp8 KV (num_blocks ≈ 267k) hits it later (~wave 35–60+) but eventually collapses identically. --gpu-memory-utilization 0.4 keeps num_blocks under 30k and the bug never fires.

Symptom signature

Watch [MTP Stats Interval] lines in the server log:

healthy:  toks/fwd 2.5–3.0   accept 50–68%   bucket-3 40–50%   bucket-0 25–30%
crashed:  toks/fwd 1.2–1.4   accept  8–16%   bucket-3  ~0%     bucket-0 65–80%

The transition is sharp (single 1000-token interval), permanent until restart, and 100% reproducible across 7+ independent runs at this configuration.
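The signature above (a sharp drop that never recovers) can be checked mechanically once the per-interval acceptance percentages have been pulled out of the log. The following is a minimal sketch, not part of ATOM: the function name find_collapse is hypothetical, the thresholds are taken from the healthy (50–68%) and crashed (8–16%) ranges listed above, and the sample series are made-up illustrative numbers. How the percentages are extracted from the [MTP Stats Interval] lines is left out because the exact log format is not reproduced here.

```python
def find_collapse(accept_pcts, healthy_min=50.0, crashed_max=16.0):
    """Return the index of the first interval at which acceptance drops to
    crashed_max or below and never rises above it again, provided at least
    one earlier interval was healthy. Return None if no such point exists."""
    seen_healthy = False
    for i, pct in enumerate(accept_pcts):
        if pct >= healthy_min:
            seen_healthy = True
        elif seen_healthy and pct <= crashed_max:
            # Permanent collapse: every remaining interval stays low.
            if all(p <= crashed_max for p in accept_pcts[i:]):
                return i
    return None


# Made-up series for illustration only.
collapsed = [62.0, 58.0, 65.0, 12.0, 9.0, 14.0]
transient = [62.0, 58.0, 12.0, 60.0, 61.0]

print(find_collapse(collapsed))  # 3
print(find_collapse(transient))  # None
```

A transient dip that later recovers is deliberately not reported, since the bug described here is permanent until restart.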

Root-cause hypothesis

For Eagle3 draft (single-layer MHA at TP=8):

slot_addr = block_id × (block_size × num_kv_heads × head_dim × elem_size) + offset_within_block
          = block_id × 32768 + offset_within_block

The product must be 64-bit. If any intermediate stores block_id in a 16-bit register or AND-masks it with 0xFFFF before the multiply (e.g. via s_pack_*_b16 / v_pack_*_b16), the result wraps at block_id ≥ 65,536 and reads land in the wrong physical slot. aiter ASM source review is required to localize the truncation.
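To make the arithmetic concrete, here is a small Python model of the slot-address formula. It is a sketch under stated assumptions, not a reading of the aiter ASM source: num_kv_heads = 8 and head_dim = 128 are inferred from the 32768 stride together with block_size = 16 and 2-byte bf16 elements, and all function names are hypothetical. It compares a full-width computation with two ways the value could be narrowed: masking block_id to 16 bits, and keeping the product in a signed 32-bit register.

```python
BLOCK_SIZE = 16      # from the issue title
NUM_KV_HEADS = 8     # assumption: inferred from the 32768 stride
HEAD_DIM = 128       # assumption: inferred from the 32768 stride
ELEM_SIZE = 2        # bytes per bf16 element

BLOCK_STRIDE = BLOCK_SIZE * NUM_KV_HEADS * HEAD_DIM * ELEM_SIZE  # 32768


def slot_addr_64(block_id, offset=0):
    """Reference: full-width multiply and add."""
    return block_id * BLOCK_STRIDE + offset


def slot_addr_block_id_16bit(block_id, offset=0):
    """Model: block_id masked with 0xFFFF before the multiply."""
    return (block_id & 0xFFFF) * BLOCK_STRIDE + offset


def slot_addr_int32(block_id, offset=0):
    """Model: product and add held in a signed 32-bit register."""
    value = (block_id * BLOCK_STRIDE + offset) & 0xFFFFFFFF
    return value - (1 << 32) if value >= (1 << 31) else value


for block_id in (65_535, 65_536, 65_537):
    print(
        block_id,
        slot_addr_64(block_id),
        slot_addr_block_id_16bit(block_id),
        slot_addr_int32(block_id),
    )
```

With these values, block_id = 65,536 is the first index at which both narrowed models diverge from the reference: the 16-bit mask maps it back to block 0, and 65,536 × 32,768 = 2^31 no longer fits in a signed 32-bit integer. The sketch does not establish which narrowing, if either, occurs in the kernel; that still requires the source review described above.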

Workaround

atom/model_ops/attention_mha.py rope_cache:

# MTP MHA must go through triton/gluon; aiter ASM non-persistent path
# may have some unexpected behavior.
use_triton_attn = (
    self.sliding_window != -1
    or self.head_dim != 128
    or self.num_heads == self.num_kv_heads
)

Why this exact condition:

  • paged_attention_triton (gluon kernel) does not have the bug — verified at 95 waves of 64-concurrent load with no collapse.
  • num_heads == num_kv_heads ⇔ pure MHA. In ATOM the only known large-KV-pool path that goes through paged_attention_asm is the Eagle3 draft, which is exactly pure MHA.
  • GQA (num_heads > num_kv_heads) and MLA targets keep their existing kernels: they either don't go through this ASM family or empirically don't reach the trigger in current production configs.
  • The condition is a pure local property of the layer (no global state read), and produces zero behavior change for existing GQA / MLA users.
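The dispatch condition can be exercised in isolation to see which layer shapes it reroutes. The sketch below lifts the predicate into a free function; the function name use_triton_attn and the head counts in the examples are hypothetical and chosen only to illustrate the MHA and GQA cases, not taken from any specific model config.

```python
def use_triton_attn(sliding_window, head_dim, num_heads, num_kv_heads):
    """Mirror of the workaround predicate: True routes the layer to
    paged_attention_triton, False leaves it on the ASM path."""
    return (
        sliding_window != -1
        or head_dim != 128
        or num_heads == num_kv_heads  # pure MHA, e.g. the Eagle3 draft
    )


# Hypothetical per-rank shapes for illustration.
print(use_triton_attn(-1, 128, num_heads=8, num_kv_heads=8))   # True  (MHA)
print(use_triton_attn(-1, 128, num_heads=64, num_kv_heads=8))  # False (GQA)
print(use_triton_attn(-1, 64, num_heads=64, num_kv_heads=8))   # True  (head_dim != 128)
```

Because the predicate depends only on the layer's own attributes, GQA layers with head_dim 128 and no sliding window evaluate exactly as they did before the extra clause was added.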

Reproduction

The script below embeds the exact prompt that triggered the bug during investigation. Run as-is against an ATOM server started with --kv_cache_dtype bf16 --gpu-memory-utilization 0.8 and the workaround reverted (i.e. the or self.num_heads == self.num_kv_heads clause removed) to see the collapse. Run again with the workaround in place to confirm the fix.

#!/bin/bash
# repro_aiter_pa_block_id_truncation.sh
set -e
N_WAVES=${1:-25}
LOG=/tmp/atom_repro.log
PROMPT_FILE=/tmp/repro_prompt.txt
PAYLOAD_FILE=/tmp/repro_payload.json

# 1. Embed the exact 5-shot gsm8k prompt used during investigation.
cat > "$PROMPT_FILE" << 'PROMPT'
Question: Elroy decides to enter a walk-a-thon and wants to make sure he ties last year's winner's cash collection. Last year, walkers earned $4 a mile. This year walkers earn $2.75 a mile. If last year's winner collected $44, how many more miles will Elroy walk than last year's winner to collect the same amount of money?
Answer: Last year's winner walked 11 miles because 44 / 4 = <<44/4=11>>11
Elroy has to walk 16 miles to collect $44 because 44 / 2.75 = <<44/2.75=16>>16
Elroy will walk 5 more miles because 16 - 11 = <<16-11=5>>5
#### 5

Question: Tom decides to make lasagna with all his beef.  It takes twice as many noodles as beef.  He has 10 pounds of beef.  He already has 4 pounds of lasagna noodles and the noodles come in 2-pound packages.  How many packages does he need to buy?
Answer: He needs 10*2=<<10*2=20>>20 pounds of noodles
That means he needs to buy 20-4=<<20-4=16>>16 pounds of noodles
So he needs to buy 16/2=<<16/2=8>>8 packages
#### 8

Question: Rodney, Roger and Ron can lift a combined weight of 239 pounds.  Rodney can lift twice as much as Roger, and Roger can lift 7 pounds less than 4 times the amount that Ron can lift.  How much can Rodney lift?
Answer: Let x represent the amount that Ron can lift.
Roger: 4x-7
Rodney:2(4x-7)=8x-14
Total:x+4x-7+8x-14=239
13x-21=239
13x=260
x=<<20=20>>20
Rodney:8(20)-14=146 pounds
#### 146

Question: Mr. Salazar had seven dozen oranges. He reserved 1/4 of it for his friend and was able to sell 3/7 of the remaining yesterday. Today, he saw four rotten oranges. How many oranges are left to be sold today?
Answer: Mr. Salazar had 7 x 12 = <<7*12=84>>84 oranges.
He reserved 84 x 1/4 = <<84*1/4=21>>21 oranges for his friend.
Ha had 84 - 21 = <<84-21=63>>63 oranges to be sold yesterday.
But only 63 x 3/7 = <<63*3/7=27>>27 oranges were sold yesterday.
So, he had 63 - 27 = <<63-27=36>>36 oranges left.
Since four oranges are rotten, then only 36 - 4 = <<36-4=32>>32 oranges are left to be sold today.
#### 32

Question: Anna ate 4 apples on Tuesday. On Wednesday, she ate double the apples she ate on Tuesday. On Thursday, Anna ate half the apples she ate on Tuesday. How many apples has Anna eaten by the end of these three days?
Answer: On Tuesday, Anna ate 4 apples.
On Wednesday, she ate 4 x 2 = <<4*2=8>>8 apples.
On Thursday, she ate 4 / 2 = <<4/2=2>>2 apples.
In total, Anna ate 4 + 8 + 2 = <<4+8+2=14>>14 apples.
#### 14

Question: Martha has been collecting shells since she turned 5 years old, every month she collects one shell. By her 10th birthday, how many shells will Martha have collected?
Answer:
PROMPT

# 2. Build the JSON payload (max_tokens=256, temperature=0, gsm8k stops).
jq -n --rawfile p "$PROMPT_FILE" '{
  model:"/root/models/Kimi-K2.5-MXFP4",
  prompt:$p, max_tokens:256, temperature:0,
  stop:["Question:","</s>","<|im_end|>"], seed:1234
}' > "$PAYLOAD_FILE"

# 3. Start ATOM server with util=0.8 so num_blocks > 65,535.
pkill -9 -f atom.entrypoints.openai_server 2>/dev/null || true
sleep 5
until [ "$(rocm-smi --showmemuse 2>&1 | grep -E 'VRAM%\): [0-9]' \
                                       | awk '{print $NF}' | sort -n | tail -1)" = "0" ]; do
  sleep 2
done

HSA_NO_SCRATCH_RECLAIM=1 AITER_LOG_LEVEL=WARNING \
nohup python -m atom.entrypoints.openai_server \
  --model /root/models/Kimi-K2.5-MXFP4 --kv_cache_dtype bf16 -tp 8 \
  --port 8000 --trust-remote-code --gpu-memory-utilization 0.8 \
  --method eagle3 --draft-model /root/models/kimi-k2.5-eagle3 \
  --num-speculative-tokens 3 \
  > "$LOG" 2>&1 &

while ! grep -q "Uvicorn running on http" "$LOG" 2>/dev/null; do sleep 5; done

# 4. Drive load: N waves of 64 concurrent same-prompt requests.
for w in $(seq 1 $N_WAVES); do
  T0=$(date +%s.%N)
  for i in $(seq 1 64); do
    curl -s -o /dev/null --connect-timeout 5 --max-time 60 \
      -H "Content-Type: application/json" -d "@$PAYLOAD_FILE" \
      http://localhost:8000/v1/completions &
  done
  wait
  T1=$(date +%s.%N)
  if [ $((w % 5)) -eq 0 ] || [ $w -eq 1 ]; then
    echo "wave $w: $(python3 -c "print(f'{$T1-$T0:.2f}')")s"
  fi
done

# 5. Outcome.
echo
echo "=== Final cumulative [MTP Stats] ==="
grep "MTP Stats         " "$LOG" | tail -1
echo
echo "=== Last 8 [MTP Stats Interval] ==="
grep "MTP Stats Interval" "$LOG" | tail -8

Expected outputs:

| Build | Result |
| --- | --- |
| Workaround reverted (paged_attention_asm for MHA), bf16 KV, util=0.8 | Around wave 17–19 a single Interval drops to 8–16% and stays. Cumulative trends to ~25–47%. |
| Workaround in place (paged_attention_triton for MHA), bf16 KV, util=0.8 | All Intervals remain 50–68% indefinitely. Verified at 95 waves with fp8 KV at the same util=0.8. |
| --gpu-memory-utilization 0.4 (caps num_blocks at ~30k) | All Intervals remain 50–68% indefinitely regardless of dispatch path. |

Permanent fix (out of scope of this workaround)

Audit the aiter ASM source for the affected pa_*.co kernels. For every load from block_tables, trace the value's register width through to the final buffer_load_* slot-address computation. Any 16-bit narrowing (s_pack_*_b16, v_pack_*_b16, AND 0xFFFF) on block_id before the × 32768 multiply must be removed; the product and the final add must both be 64-bit.

Operating System

Ubuntu 24.04.4 LTS (Noble Numbat)

CPU

AMD EPYC 9575F 64-Core Processor (2 sockets, 256 cores total)

GPU

AMD Instinct MI355X × 8 (gfx950, device id 0x75a3)

ROCm Version

rocm-7.2.2
