Skip to content

feat - support xqa spec#853

Closed
zerozw wants to merge 1 commit into
mainfrom
feature/prepare_cu_opt
Closed

feat - support xqa spec#853
zerozw wants to merge 1 commit into
mainfrom
feature/prepare_cu_opt

Conversation

@zerozw
Copy link
Copy Markdown
Collaborator

@zerozw zerozw commented Apr 2, 2026

No description provided.

@zerozw zerozw requested a review from LLLLKKKK as a code owner April 2, 2026 03:25
Copilot AI review requested due to automatic review settings April 2, 2026 03:25
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Adds a dedicated XQA “spec decode” path for target-verify workloads (FP8 KV cache), wiring it into the attention factory and extending tests to cover support + correctness.

Changes:

  • Introduces XQASpecAttnOp (C++/pybind) and XQASpecImpl (Python FMHA impl) for target-verify spec decode.
  • Updates support/selection logic to route is_target_verify=True away from the normal XQAAttnOp.
  • Extends CUDA attention tests to exercise FP8 KV cache + spec decode behavior.

Reviewed changes

Copilot reviewed 7 out of 7 changed files in this pull request and generated 6 comments.

Show a summary per file
File Description
rtp_llm/models_py/modules/factory/attention/cuda_impl/xqa.py Adds Python XQASpecImpl and a multi-token fallback in decode path; adjusts semaphore dtype.
rtp_llm/models_py/modules/factory/attention/cuda_impl/test/test_xqa.py Adds spec inputs, reference computation, and new tests for spec op support + correctness.
rtp_llm/models_py/modules/factory/attention/cuda_impl/test/base_attention_test.py Extends test config to set kv_cache_dtype; improves FP8 KV cache test tensor creation.
rtp_llm/models_py/modules/factory/attention/init.py Registers XQASpecImpl into decode implementation list.
rtp_llm/models_py/bindings/cuda/XQAAttnOp.h Declares new XQASpecAttnOp binding.
rtp_llm/models_py/bindings/cuda/XQAAttnOp.cc Implements XQASpecAttnOp support/prepare/forward and updates XQAAttnOp::support.
rtp_llm/cpp/cuda/ops/CudaXqa.h Extends XQAParams with q_cu_seqlens and max_q_len.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines +54 to +61
const auto input_type = attn_configs_.dtype == torch::kBFloat16 ? DataType::TYPE_BF16 : DataType::TYPE_FP16;
const auto kv_type = DataType::TYPE_FP8_E4M3;
return supportXqa(input_type,
input_type,
kv_type,
attn_configs_.head_num / attn_configs_.kv_head_num,
attn_configs_.size_per_head,
attn_configs_.kernel_tokens_per_block);
Copy link

Copilot AI Apr 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

XQASpecAttnOp::support() treats any non-BF16 attention dtype as FP16. If attn_configs_.dtype can be something else (e.g., FP8), support() may return true but forward() will still pass the original input tensor to runXqa() (which then interprets it as FP16 based on the BF16 boolean), producing incorrect results. Restrict support to {torch::kBFloat16, torch::kFloat16} (return false otherwise) or explicitly cast the input to FP16/BF16 in forward() and keep support() consistent with that behavior.

Copilot uses AI. Check for mistakes.
Comment on lines +179 to +180
return output.reshape(output.shape[0] * output.shape[1], -1)

Copy link

Copilot AI Apr 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

XQASpecAttnOp::forward() allocates [batch_size, max_q_len, heads, dim] on the C++ side, so output.shape[1] is max_q_len (not necessarily the actual per-request q length). Flattening via reshape(batch * max_q_len, -1) can leak padded/uninitialized positions into the returned tensor when q lengths vary across the batch. Consider returning a packed [total_tokens, ...] tensor by slicing/packing using self.attn_inputs.decode_cu_seqlens/input_lengths, or keep the 4D output and let the caller slice to valid ranges.

Suggested change
return output.reshape(output.shape[0] * output.shape[1], -1)
# XQASpecAttnOp returns output with shape [batch_size, max_q_len, num_heads, head_dim].
# We must not blindly flatten [batch_size * max_q_len, ...], because max_q_len may exceed
# the actual per-request sequence lengths, which would leak padded/uninitialized positions.
# Instead, pack only the valid tokens according to self.attn_inputs.input_lengths.
batch_size, max_q_len, num_heads, head_dim = output.shape
hidden_size = num_heads * head_dim
# input_lengths is expected to be a 1D tensor/list-like of length batch_size
input_lengths = self.attn_inputs.input_lengths
# Convert to a tensor if needed, without changing device/dtype of sequence length storage
if not torch.is_tensor(input_lengths):
input_lengths = torch.tensor(input_lengths, device=output.device, dtype=torch.long)
packed_outputs = []
for b in range(batch_size):
seq_len = int(input_lengths[b].item())
if seq_len <= 0:
continue
# Slice valid timesteps for this batch element and reshape to [seq_len, hidden_size]
packed_outputs.append(output[b, :seq_len].reshape(-1, hidden_size))
if not packed_outputs:
# All sequences are empty; return a correctly shaped empty tensor
return output.new_empty((0, hidden_size))
return torch.cat(packed_outputs, dim=0)

Copilot uses AI. Check for mistakes.
Comment on lines +368 to +417
def _spec_decode_fallback(
self,
q_4d: torch.Tensor,
k_cache: torch.Tensor,
v_cache: torch.Tensor,
page_table: torch.Tensor,
seq_lens: torch.Tensor,
) -> torch.Tensor:
batch_size, q_len_per_req, num_heads, head_dim = q_4d.shape
num_kv_heads = k_cache.shape[1]
group_size = num_heads // num_kv_heads
scale = head_dim**-0.5
outputs = []

for batch_idx in range(batch_size):
block_ids = page_table[batch_idx]
k_blocks = k_cache[block_ids]
v_blocks = v_cache[block_ids]
k_seq = (
k_blocks.permute(1, 0, 2, 3)
.reshape(num_kv_heads, -1, head_dim)
.permute(1, 0, 2)
.contiguous()
)
v_seq = (
v_blocks.permute(1, 0, 2, 3)
.reshape(num_kv_heads, -1, head_dim)
.permute(1, 0, 2)
.contiguous()
)

prefix_len = int(seq_lens[batch_idx].item())
token_outputs = []
for token_idx in range(q_len_per_req):
seq_len = prefix_len + token_idx + 1
q_token = q_4d[batch_idx, token_idx]
k_token = k_seq[:seq_len]
v_token = v_seq[:seq_len]
if group_size > 1:
k_token = k_token.repeat_interleave(group_size, dim=1)
v_token = v_token.repeat_interleave(group_size, dim=1)

scores = (
torch.einsum("hd,thd->ht", q_token.float(), k_token.float()) * scale
)
attn = torch.softmax(scores, dim=-1).to(v_token.dtype)
token_output = torch.einsum("ht,thd->hd", attn, v_token)
token_outputs.append(token_output.to(q_4d.dtype))
outputs.append(torch.stack(token_outputs, dim=0))
return torch.stack(outputs, dim=0).unsqueeze(1).contiguous()
Copy link

Copilot AI Apr 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This fallback is a nested Python loop over batch and tokens with multiple einsum() calls and repeat_interleave(), which will be extremely slow and can cause major regressions if hit outside of small tests. Additionally, attn = softmax(...).to(v_token.dtype) can be especially problematic when v_token is FP8 (precision loss/underflow for attention weights). Recommended: avoid silently using this path in production (e.g., raise/return unsupported so the factory selects the spec kernel), and if it must exist, keep attention weights in FP16/BF16/FP32 (only cast the final output) and add an explicit warning that it is for debug-only.

Copilot uses AI. Check for mistakes.
Comment on lines +67 to +70
RTP_LLM_CHECK_WITH_INFO(attn_inputs.kv_cache_kernel_block_id_host.defined()
&& attn_inputs.kv_cache_kernel_block_id_device.defined(),
"decode should have kv cache block id.");

Copy link

Copilot AI Apr 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This check is inside XQASpecAttnOp::prepare(), but the error message says "decode should have kv cache block id." Consider updating it to mention spec/target-verify (and the exact required fields), e.g., "spec decode requires kv_cache_kernel_block_id_{host,device} to be set".

Copilot uses AI. Check for mistakes.
Comment on lines +262 to +264
self.semaphores = torch.zeros(
8 * 1024 * 1024, dtype=torch.uint32, device="cuda"
)
Copy link

Copilot AI Apr 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changing semaphores from uint8 to uint32 increases the allocation from ~8 MiB to ~32 MiB per instance (same element count, 4× element size). If multiple XQADecodeImpl instances can exist concurrently, this becomes a significant GPU memory overhead. If uint32 is required by the kernel/atomic semantics, consider reducing the element count to keep bytes constant, or allocate this buffer from a shared pool similar to workspace_buffer.

Copilot uses AI. Check for mistakes.
Comment on lines 80 to 82
DECODE_MHA_IMPS.extend([FlashInferTRTLLMDecodeImpl])
DECODE_MHA_IMPS.append(XQASpecImpl)
DECODE_MHA_IMPS.append(get_xqa_impl())
Copy link

Copilot AI Apr 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The factory selection is now extended with XQASpecImpl, but the added tests focus on the raw XQASpecAttnOp behavior rather than verifying the factory chooses XQASpecImpl when attn_inputs.is_target_verify=True (and rejects it otherwise). Adding a small test that exercises the factory/impl selection would help prevent regressions where ordering or support predicates change.

Copilot uses AI. Check for mistakes.
@LLLLKKKK
Copy link
Copy Markdown
Collaborator

LLLLKKKK commented Apr 2, 2026

🤖 Code Review (v1) — P2

Verdict: P2 (Suggestions)

为 XQA 添加 speculative decoding 支持,新增 XQASpecAttnOp (C++) 和 XQASpecImpl (Python),仅支持 FP8 KV cache + SM90 的 target verify 场景。实现结构清晰,测试包含 reference 对比验证。

P2 Issues

1. XQASpecAttnOp::prepare 中多次 GPU→CPU 同步
item<int32_t>() 触发 3 次 GPU→CPU 同步(max_q_len 1次 + max_seq_len 2次)。prepare 阶段影响有限,但频繁调用时可考虑合并。

2. XQASpecAttnOpXQAAttnOp 代码重复
prepare 和 forward 中有大量重复代码(prepareTrtAttnParams、kv_block_array 设置等),建议提取公共 helper。

3. 测试 tolerance 较宽松
rtol=0.1, atol=0.3 对 attention 输出来说比较宽松,建议添加注释说明原因(FP8 量化损失)。

4. XQASpecImpl.forward 未使用 write_cache_store_impl
构造函数中创建了 self.write_cache_store_impl,但 forward 中没有调用。如果不需要则移除创建,如果需要则补充调用。


Automated review by CI Bot

Copilot AI review requested due to automatic review settings April 7, 2026 06:14
@zerozw zerozw force-pushed the feature/prepare_cu_opt branch from e162507 to 68ff089 Compare April 7, 2026 06:14
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 7 out of 7 changed files in this pull request and generated 3 comments.


💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines +154 to +171
torch::Tensor xqa_input = input.contiguous();
runXqa(xqa_input.data_ptr(),
input.dtype() == torch::kBFloat16,
output.data_ptr(),
local_head_num,
local_head_num_kv,
size_per_head,
params->batch_size,
static_cast<size_t>(kv_block_array.mMaxBlocksPerSeq),
params->max_seq_len,
attn_configs_.kernel_tokens_per_block,
kv_block_array.mPrimaryPoolPtr,
reinterpret_cast<int32_t*>((KVCacheIndex*)(params->kv_cache_offset.data_ptr())),
kv_block_array.cache_type == KvCacheDataType::FP8,
reinterpret_cast<uint32_t*>(params->sequence_lengths.data_ptr()),
nullptr,
params->max_q_len,
params->q_cu_seqlens.data_ptr());
Copy link

Copilot AI Apr 7, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

XQASpecAttnOp::forward passes params->max_q_len and params->q_cu_seqlens into runXqa(), enabling the spec path. However runXqa() (rtp_llm/cpp/cuda/ops/CudaXqa.cc) initializes q_mask/semaphores as function-static tensors based on the first call’s (kv_head_num, group_size, max_q_len, max_batch_size). If a non-spec decode call happens first (default max_q_len=2) and later spec verify uses a larger max_q_len (e.g. propose_step_+1), the cached q_mask size will mismatch max_q_len and can lead to incorrect results or out-of-bounds reads in the kernel. runXqa should cache these buffers keyed by max_q_len/(kv_head_num,group_size) or recreate them when arguments change, rather than using single-init statics tied to the first invocation.

Copilot uses AI. Check for mistakes.
Comment on lines +9 to +20
from rtp_llm.config.engine_config import EngineConfig
from rtp_llm.config.model_config import ModelConfig
from rtp_llm.config.py_config_modules import PyEnvConfigs
from rtp_llm.ops import KvCacheDataType
from rtp_llm.ops.compute_ops import (
FusedRopeKVCachePrefillOpQOut,
PyAttentionInputs,
XQAAttnOp,
XQAParams,
XQASpecAttnOp,
get_typemeta,
init_device,
Copy link

Copilot AI Apr 7, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Several newly added imports appear unused in this test module (EngineConfig, ModelConfig, PyEnvConfigs, FusedRopeKVCachePrefillOpQOut, init_device). If the repo runs linting/static checks, these can fail CI and also add noise for readers. Please remove unused imports or add the missing usage if they’re intended for setup.

Suggested change
from rtp_llm.config.engine_config import EngineConfig
from rtp_llm.config.model_config import ModelConfig
from rtp_llm.config.py_config_modules import PyEnvConfigs
from rtp_llm.ops import KvCacheDataType
from rtp_llm.ops.compute_ops import (
FusedRopeKVCachePrefillOpQOut,
PyAttentionInputs,
XQAAttnOp,
XQAParams,
XQASpecAttnOp,
get_typemeta,
init_device,
from rtp_llm.ops import KvCacheDataType
from rtp_llm.ops.compute_ops import (
PyAttentionInputs,
XQAAttnOp,
XQAParams,
XQASpecAttnOp,
get_typemeta,

Copilot uses AI. Check for mistakes.
int batch_size = attn_inputs.sequence_lengths.size(0);
RTP_LLM_CHECK_WITH_INFO(attn_inputs.kv_cache_kernel_block_id_host.defined()
&& attn_inputs.kv_cache_kernel_block_id_device.defined(),
"decode should have kv cache block id.");
Copy link

Copilot AI Apr 7, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

XQASpecAttnOp::prepare() uses the error text "decode should have kv cache block id." even though this path is specifically for target-verify/spec decode. Updating the message to mention spec/target-verify would make failures easier to diagnose (especially since XQAAttnOp::prepare uses the same text for true decode).

Suggested change
"decode should have kv cache block id.");
"spec decode/target-verify should have kv cache block id.");

Copilot uses AI. Check for mistakes.
@LLLLKKKK
Copy link
Copy Markdown
Collaborator

LLLLKKKK commented Apr 9, 2026

🤖 AI Code Review (incremental) — PR #853
Head SHA: fe51006c3481 | Previous: 68ff0898bca6 | Verdict: P2

Changes since last review

New commit fe51006c3481 ("fix - fix mtp target layer_to_groups size error") fixes a bug in KVCacheManager::getMainModelCacheLayerLayout() — moved resize() after the assignment so it correctly extends/truncates the copied vector.

Findings

[P2] layer_to_groups resize — missing defensive assertion
The fix is correct for the common case (layer_to_group_id.size() <= layer_num), but if layer_to_group_id.size() > layer_num, the resize silently truncates. Consider adding:

RTP_LLM_CHECK_WITH_INFO(config_.layer_to_group_id.size() <= config_.layer_num,
    "layer_to_group_id size exceeds layer_num");

[Nit] Spec decode test tolerance (rtol=1e-1, atol=3e-1) is quite loose — understandable for FP8 KV cache but a comment explaining the choice would help.

XQA spec decode implementation is solid with good test coverage. Maintaining P2 for the missing guard.

@LLLLKKKK
Copy link
Copy Markdown
Collaborator

🤖 AI Code Review — PR #853

PR 概述

Title: feat - support xqa spec
Author: zerozw
规模: 8(GitHub) files, +444/-5

核心目标

为 XQA attention 添加 speculative decoding(target verify)支持。新增 XQASpecAttnOp(C++ binding)和 XQASpecImpl(Python wrapper),使 XQA kernel 能处理 q_len_per_req > 1 的 spec decode 场景。同时修复了 KVCacheManager::getMainModelCacheLayerLayout()layer_to_groups 的 resize 时序 bug。


改动逻辑拆解

GitHub 开源仓库变更(主要代码)

1. C++ XQA Params 扩展 (CudaXqa.h)

XQAParams 中新增 q_cu_seqlens(cumulative query sequence lengths)和 max_q_len(默认值 1),用于传递 spec decode 的 query 长度信息给底层 runXqa kernel。

2. XQASpecAttnOp C++ 实现 (XQAAttnOp.cc/h)

新增 XQASpecAttnOp 类,包含 support()prepare()forward() 三个方法:

  • support(): 仅在 is_target_verify=truedecode_cu_seqlens_d 有效、FP8 KV cache、SM90 时启用
  • prepare(): 构建 XQAParams,计算 sequence_lengths = prefix_lengths + input_lengths,从 decode_cu_seqlens_d 计算 max_q_len
  • forward(): 分配 4D output tensor [batch, max_q_len, heads, head_dim],调用 runXqa 并传入 max_q_lenq_cu_seqlens

3. XQASpecImpl Python 包装 (xqa.py)

新增 XQASpecImpl(FMHAImplBase) 类,使用 FusedRopeKVCachePrefillOpQOut 做 RoPE(而非 decode 版本),调用 XQASpecAttnOp 做 attention,最后 reshape output 为 [total_tokens, hidden_dim]

4. 注册到 Attention Factory (init.py)

XQASpecImpl 添加到 PREFILL_MHA_IMPS 列表,位于 FlashInferTRTLLMSpecDecodeImplFlashInferTRTLLMPrefillImpl 之后。

5. XQAAttnOp::support() 增加 guard

XQAAttnOp::support() 新增 is_target_verify 检查,target verify 请求不走原 decode XQA 路径。对应 Python 侧 XQAImpl.support() 也加了同样的 guard。

6. KVCacheManager bug fix (KVCacheManager.cc)

layout.layer_to_groups.resize(config_.layer_num) 从赋值 layout.layer_to_groups = config_.layer_to_group_id 之前移到之后。修复了当 config_.layer_to_group_id.size() > config_.layer_num 时(MTP 场景),layer_to_groups 包含多余 MTP 层映射的问题。

7. 测试 (test_xqa.py, base_attention_test.py)

  • base_attention_test.py: 新增 kv_cache_dtype 参数支持,_create_kv_cache 支持 FP8 类型
  • test_xqa.py: 新增 _create_spec_attention_inputs_compute_flashinfer_spec_decode_reference_test_spec_decode_correctnesstest_spec_supporttest_spec_decode 等测试方法

Checklist 检查结果

通用原则

软件工程原则

检查项 结果
SRP XQASpecAttnOpXQAAttnOp 大量重复代码,应考虑继承或模板化
OCP ✅ 通过新增类扩展,未修改核心 XQA 逻辑
LSP
ISP
DIP
DRY XQASpecAttnOp::forward()XQAAttnOp::forward() 约 70% 代码重复
KISS
YAGNI

架构审视

检查项 结果
抽象边界
依赖方向
状态完整性
错误语义
可观测性
可演进性
可运维性

测试

检查项 结果
新功能有对应测试
边界 case 覆盖 ❌ 仅测试了 batch_size=2, q_len_per_req=4 单一配置

代码质量与文档

检查项 结果
无关改动分离 ❌ KVCacheManager bug fix 与 XQA spec 功能混在同一 PR
Commit 原子性
PR description ❌ PR body 为空

领域检查

A. 兼容性与配置 — 全部 ✅

B. 正确性与逻辑 — ❌ 见问题 1, 2

C. 线程安全与并发 — 全部 ✅

D. 性能 — ❌ 见问题 5, 6

E. 分布式 — 全部 ✅

F. 跨平台 — 全部 ✅

G. 语言与框架特有 — 全部 ✅

H. 测试与 CI — 见通用原则.测试

I. 代码质量 — ❌ 见问题 8


Review 意见

问题

  1. kernel_tokens_per_block 字段不存在,编译必定失败 [P0]

    XQASpecAttnOp::support()XQASpecAttnOp::forward() 中引用了 attn_configs_.kernel_tokens_per_block,但 AttentionConfigs(定义在 AttentionConfig.h)中只有 tokens_per_block,不存在 kernel_tokens_per_block 字段。全局搜索整个代码库也无此字段。

    // XQASpecAttnOp::support() 中:
    attn_configs_.kernel_tokens_per_block  // 编译错误
    // XQASpecAttnOp::forward() 中:
    attn_configs_.kernel_tokens_per_block  // 编译错误

    对比 XQAAttnOp 使用的是 attn_configs_.tokens_per_block

    建议:将 kernel_tokens_per_block 改为 tokens_per_block

  2. max_seq_len 计算与原 decode 路径不一致,可能导致 attention 越界 [P1]

    XQAAttnOp::forward() 传给 runXqamax_seq_lenparams->max_seq_len + 1,而 XQASpecAttnOp::prepare() 计算 max_seq_len 时:

    params->max_seq_len =
        attn_inputs.input_lengths.max().item<int32_t>() + attn_inputs.prefix_lengths.max().item<int32_t>();

    XQASpecAttnOp::forward() 直接传 params->max_seq_len,没有 +1。

    需要确认 spec decode 场景下是否也需要 +1,否则可能导致最后一个 token 的 KV 被截断。

    另外,input_lengths.max() + prefix_lengths.max() 取的是各自的 max 再相加,而非 per-request 的 (input_lengths[i] + prefix_lengths[i]) 的 max。当不同 request 的分布不同时,这个值可能不准确。

  3. XQASpecImpl.forward() 缺少 apply_write_cache_store 调用 [P1]

    对比同类实现(XQAImplFlashInferTRTLLMSpecDecodeImpl),它们的 forward() 都在 RoPE 之后、FMHA 之前调用了 common.apply_write_cache_store()XQASpecImpl.forward() 完全缺少这一步。

    缺少 common.apply_write_cache_store(self.write_cache_store_impl, self.attn_inputs, kv_cache) 调用。这会导致 prefill 阶段的 cache store 写入被跳过,在需要 cache store 的场景下产生正确性问题。

  4. XQASpecAttnOp::forward() 中 kv_cache 检查在使用之后 [P2]

    if (kv_cache.has_value()) {
        kv_block_array = params->kv_block_array;
        kv_block_array.mPrimaryPoolPtr = kv_cache.value().kv_cache_base.data_ptr();
        // ...
    }
    RTP_LLM_CHECK_WITH_INFO(kv_cache.has_value(), "spec decode should have kv cache.");

    CHECK 放在 if 块之后。建议将 CHECK 移到 if 块之前。

  5. prepare() 中多次 .item() 调用导致 GPU-CPU 同步 [P2]

    XQASpecAttnOp::prepare() 中有 3 次 .item<int32_t>() 调用,每次都会触发一次 CUDA synchronize。在 hot path 上这会显著影响性能。建议合并为一次同步,或在 CPU 端计算(input_lengthsprefix_lengths 已经在 CPU 上 pin_memory)。

  6. prepare_cuda_graph() 未更新 spec 特有参数 [P2]

    XQASpecImpl.prepare_cuda_graph() 调用 common.update_trt_params(),但 spec decode 特有的参数(q_cu_seqlensmax_q_lenmax_seq_len)不会被 update_trt_params 更新。如果 CUDA Graph replay 时这些参数变化了,会使用 capture 时的旧值。

  7. 测试容差过大 [P2]

    compare_tensors(output, ref_output, rtol=1e-1, atol=3e-1, ...)

    rtol=0.1, atol=0.3 对于 attention 输出来说非常宽松,可能掩盖实际的计算错误。建议收紧到 rtol=5e-2, atol=1e-1 并验证是否通过。

  8. XQASpecAttnOpXQAAttnOp 代码重复严重 [P3]

    两个类的 forward() 方法约 70% 代码相同。建议 XQASpecAttnOp 继承 XQAAttnOp,或提取公共辅助方法。

小问题

  • [P3] KVCacheManager bug fix 应作为独立 PR 提交
  • [P3] PR description 为空,建议补充动机和设计说明

整体评价

本 PR 为 XQA attention 添加 speculative decoding 支持,整体设计方向正确,测试覆盖基本到位。但存在一个编译阻塞问题(kernel_tokens_per_block 字段不存在)和两个重要问题(max_seq_len 计算可能不一致、缺少 apply_write_cache_store 调用)。

存在阻塞/重要问题,不建议合入

@LLLLKKKK
Copy link
Copy Markdown
Collaborator

Code Review: PR #853 — feat - support xqa spec

审查日期: 2026-04-14 | 审查者: Claude Agent


变更概览

本 PR 包含两个改动:

  1. XQA Speculative Decoding 支持:新增 XQASpecAttnOp(C++)和 XQASpecImpl(Python),为 target verify 场景提供基于 TRT XQA kernel 的 attention 实现,仅在 SM90 + FP8 KV Cache 下启用。
  2. KVCacheManager bug fix:修复 MTP 场景下 layer_to_groups 的 resize 时序问题。
级别 数量
P0 阻塞 0
P1 重要 3
P2 建议 4
P3 Nit 2

P1-1: XQASpecImpl.forward() 缺少 write_cache_store 调用

文件: rtp_llm/models_py/modules/factory/attention/cuda_impl/xqa.pyXQASpecImpl.forward()

对比同模块的 XQAImpl.forward()FlashInferTRTLLMSpecDecodeImpl.forward() 等所有现有 FMHAImplBase 实现,它们在 forward 中都调用了 common.apply_write_cache_store() 来将 KV 写入 cache。XQASpecImpl 虽然在 __init__ 中创建了 self.write_cache_store_impl,但 forward() 中没有调用它。

Speculative decoding 的 target verify 阶段需要将新 token 的 KV 写入 cache,缺少此调用可能导致 KV cache 未被正确更新。

建议在 fmha_impl.forward() 调用前添加:

common.apply_write_cache_store(
    self.write_cache_store_impl, self.attn_inputs, kv_cache
)

P1-2: kernel_tokens_per_block 字段在 main 分支不存在

文件: XQAAttnOp.ccXQASpecAttnOp::support()forward()

PR 中使用了 attn_configs_.kernel_tokens_per_block,但当前 AttentionConfig.h 中只有 tokens_per_block(现有 XQAAttnOp 也使用 tokens_per_block)。同样 prepare() 使用了 kv_cache_kernel_block_id_device,但 OpDefs.h 中只有 kv_cache_block_id_device

如果这些字段由其他 PR 引入,需确认合入顺序;否则本 PR 无法编译。

P1-3: __init__.py 依赖未合入的 HeadWisePrefillImpl

diff 中 XQASpecImpl 被插入到 HeadWisePrefillImpl 之后,但当前 main 分支不存在该类。直接合入会导致 merge conflict 或 import 错误。


P2-1: max_seq_len 计算语义偏差

文件: XQAAttnOp.ccXQASpecAttnOp::prepare()

input_lengths.max() + prefix_lengths.max() 取的是各自 max 之和,而非 per-request (input_length + prefix_length) 的 max。当分布不均匀时值会偏大。建议改为 (input_lengths + prefix_lengths).max().item<int32_t>()

P2-2: kv_block_array 使用在 CHECK 之前

文件: XQAAttnOp.ccXQASpecAttnOp::forward()

先对 kv_cachehas_value() 使用,再执行 RTP_LLM_CHECK_WITH_INFO(kv_cache.has_value(), ...)。建议将 CHECK 移到使用之前。

P2-3: 测试容差较宽松 (rtol=1e-1, atol=3e-1)

文件: test_xqa.py_test_spec_decode_correctness()

30% 绝对容差可能掩盖计算错误。建议添加注释说明 FP8 量化导致的预期误差范围。

P2-4: XQASpecAttnOp 与 XQAAttnOp 大量代码重复

forward() 约 80% 代码重复,建议后续提取公共方法或使用继承。


P3-1: Commit 原子性

KVCacheManager bug fix 与 XQA spec 功能无关,建议拆分为独立 PR。

P3-2: XQASpecAttnOp::support() 硬编码 SM90

使用 get_sm() != kSM_90(严格等于),而原 XQAAttnOp 使用 >= kSM_90。SM100+ GPU 不会启用此路径。如有意为之建议添加注释。


KVCacheManager Bug Fix

layer_to_groups.resize(config_.layer_num) 移到 layout.layer_to_groups = config_.layer_to_group_id 赋值之后,确保先拷贝完整 group_id 映射再 truncate 到主模型 layer_num。修复正确。


🤖 Generated by Claude Agent Code Review

Copilot AI review requested due to automatic review settings April 15, 2026 08:13
@zerozw zerozw force-pushed the feature/prepare_cu_opt branch from fe51006 to 320919b Compare April 15, 2026 08:13
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 1 out of 1 changed files in this pull request and generated 1 comment.


💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines 260 to +262
layout.layer_to_groups = config_.layer_to_group_id;
layout.group_types = config_.group_types;
layout.layer_to_groups.resize(config_.layer_num);
Copy link

Copilot AI Apr 15, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Resizing layout.layer_to_groups after assignment can silently truncate or default-fill the mapping if config_.layer_to_group_id.size() differs from config_.layer_num, which may produce incorrect layer→group pairing without any signal. Prefer validating that config_.layer_to_group_id.size() == config_.layer_num (and failing fast if not), or explicitly defining the intended padding/truncation behavior before resizing.

Suggested change
layout.layer_to_groups = config_.layer_to_group_id;
layout.group_types = config_.group_types;
layout.layer_to_groups.resize(config_.layer_num);
RTP_LLM_CHECK_WITH_INFO(config_.layer_to_group_id.size() == config_.layer_num,
"config_.layer_to_group_id.size()[%ld] != config_.layer_num[%d]",
config_.layer_to_group_id.size(),
config_.layer_num);
layout.layer_to_groups = config_.layer_to_group_id;
layout.group_types = config_.group_types;

Copilot uses AI. Check for mistakes.
@LLLLKKKK
Copy link
Copy Markdown
Collaborator

Code Review v2 — PR #853 (feat - support xqa spec)

Review 版本: v2(增量 review) | SHA: 320919b43d61 | 日期: 2026-04-15


v1 -> v2 变更摘要

PR 被 force-push 重写,原有的 XQA Speculative Decoding 功能代码(7 个文件,+444 行)已全部移除。当前 PR 仅保留 KVCacheManager 的 bug fix(1 个文件,+1/-1)。


v1 P1 问题跟踪

ID v1 问题 v2 状态
P1-1 XQASpecImpl.forward() 缺少 write_cache_store 调用 不再适用(代码已移除)
P1-2 kernel_tokens_per_block 字段在 main 不存在 不再适用(代码已移除)
P1-3 init.py 依赖未合入的 HeadWisePrefillImpl 不再适用(代码已移除)

当前变更分析

文件: rtp_llm/cpp/cache/KVCacheManager.ccgetMainModelCacheLayerLayout()

layout.layer_to_groups.resize(config_.layer_num)layout.layer_to_groups = config_.layer_to_group_id 赋值之前移到之后,修复 MTP 场景下 layer_to_groups 的 size 不正确问题。

  • 修复前:先 resize 再赋值,resize 效果被赋值覆盖,vector 携带多余的 MTP 层数据
  • 修复后:先赋值完整映射,再 truncate 到主模型层数,size 正确

修复逻辑正确,无风险。


总结

P0 P1 P2 P3
0 0 0 0

v1 的 3 个 P1 因 XQA spec 代码移除而不再适用。当前 PR 仅含 KVCacheManager bug fix,修复正确,可以合入


🤖 Generated by Claude Code Review Agent v2

@LLLLKKKK
Copy link
Copy Markdown
Collaborator

🤖 AI Code Review — PR #853

PR 概述

Title: feat - support xqa spec
Author: zerozw
规模: 1 file, +1/-1

核心目标

修复 KVCacheManager::getMainModelCacheLayerLayout()layer_to_groups 的初始化顺序:将 resize(config_.layer_num) 从赋值前移到赋值后,确保在 config_.layer_to_group_id 拷贝完成后再截断/扩展到正确大小。


改动分析

旧代码中 resize= 之前执行,赋值会完全覆盖 vector 内容和大小,因此 resize 是无效操作。若 config_.layer_to_group_id.size() < config_.layer_num,赋值后 vector 大小不足,后续循环按 layer_id 索引访问存在越界风险。

新代码将 resize 移到赋值之后,语义为"先拷贝 config 中的 group mapping,再确保 vector 恰好有 layer_num 个元素"。逻辑正确。


Review 结论

LGTM ready to ci

无 P0/P1 问题。

P3 建议: PR description 为空,建议补充一句说明改动背景(XQA speculative 场景下 layer_to_group_id.size() 可能与 layer_num 不一致)。

@LLLLKKKK
Copy link
Copy Markdown
Collaborator

AI Code Review — PR #853

Summary: P0/0 · P1/0 · P2/0 · P3/0

Review status: LGTM

lgtm ready to ci

Strengths

  • Correctly fixes a subtle bug where the main model CacheLayerLayout could contain stale MTP layer-to-group mappings, since the old resize was immediately overwritten by the full layer_to_group_id assignment
  • Minimal, focused change — one line moved, no unnecessary refactoring

@wht21
Copy link
Copy Markdown
Collaborator

wht21 commented Apr 28, 2026

internal source has been updated, please review the changes!

@zerozw zerozw closed this Apr 29, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants