feat - support xqa spec#853
Conversation
There was a problem hiding this comment.
Pull request overview
Adds a dedicated XQA “spec decode” path for target-verify workloads (FP8 KV cache), wiring it into the attention factory and extending tests to cover support + correctness.
Changes:
- Introduces
XQASpecAttnOp(C++/pybind) andXQASpecImpl(Python FMHA impl) for target-verify spec decode. - Updates support/selection logic to route
is_target_verify=Trueaway from the normalXQAAttnOp. - Extends CUDA attention tests to exercise FP8 KV cache + spec decode behavior.
Reviewed changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated 6 comments.
Show a summary per file
| File | Description |
|---|---|
| rtp_llm/models_py/modules/factory/attention/cuda_impl/xqa.py | Adds Python XQASpecImpl and a multi-token fallback in decode path; adjusts semaphore dtype. |
| rtp_llm/models_py/modules/factory/attention/cuda_impl/test/test_xqa.py | Adds spec inputs, reference computation, and new tests for spec op support + correctness. |
| rtp_llm/models_py/modules/factory/attention/cuda_impl/test/base_attention_test.py | Extends test config to set kv_cache_dtype; improves FP8 KV cache test tensor creation. |
| rtp_llm/models_py/modules/factory/attention/init.py | Registers XQASpecImpl into decode implementation list. |
| rtp_llm/models_py/bindings/cuda/XQAAttnOp.h | Declares new XQASpecAttnOp binding. |
| rtp_llm/models_py/bindings/cuda/XQAAttnOp.cc | Implements XQASpecAttnOp support/prepare/forward and updates XQAAttnOp::support. |
| rtp_llm/cpp/cuda/ops/CudaXqa.h | Extends XQAParams with q_cu_seqlens and max_q_len. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| const auto input_type = attn_configs_.dtype == torch::kBFloat16 ? DataType::TYPE_BF16 : DataType::TYPE_FP16; | ||
| const auto kv_type = DataType::TYPE_FP8_E4M3; | ||
| return supportXqa(input_type, | ||
| input_type, | ||
| kv_type, | ||
| attn_configs_.head_num / attn_configs_.kv_head_num, | ||
| attn_configs_.size_per_head, | ||
| attn_configs_.kernel_tokens_per_block); |
There was a problem hiding this comment.
XQASpecAttnOp::support() treats any non-BF16 attention dtype as FP16. If attn_configs_.dtype can be something else (e.g., FP8), support() may return true but forward() will still pass the original input tensor to runXqa() (which then interprets it as FP16 based on the BF16 boolean), producing incorrect results. Restrict support to {torch::kBFloat16, torch::kFloat16} (return false otherwise) or explicitly cast the input to FP16/BF16 in forward() and keep support() consistent with that behavior.
| return output.reshape(output.shape[0] * output.shape[1], -1) | ||
|
|
There was a problem hiding this comment.
XQASpecAttnOp::forward() allocates [batch_size, max_q_len, heads, dim] on the C++ side, so output.shape[1] is max_q_len (not necessarily the actual per-request q length). Flattening via reshape(batch * max_q_len, -1) can leak padded/uninitialized positions into the returned tensor when q lengths vary across the batch. Consider returning a packed [total_tokens, ...] tensor by slicing/packing using self.attn_inputs.decode_cu_seqlens/input_lengths, or keep the 4D output and let the caller slice to valid ranges.
| return output.reshape(output.shape[0] * output.shape[1], -1) | |
| # XQASpecAttnOp returns output with shape [batch_size, max_q_len, num_heads, head_dim]. | |
| # We must not blindly flatten [batch_size * max_q_len, ...], because max_q_len may exceed | |
| # the actual per-request sequence lengths, which would leak padded/uninitialized positions. | |
| # Instead, pack only the valid tokens according to self.attn_inputs.input_lengths. | |
| batch_size, max_q_len, num_heads, head_dim = output.shape | |
| hidden_size = num_heads * head_dim | |
| # input_lengths is expected to be a 1D tensor/list-like of length batch_size | |
| input_lengths = self.attn_inputs.input_lengths | |
| # Convert to a tensor if needed, without changing device/dtype of sequence length storage | |
| if not torch.is_tensor(input_lengths): | |
| input_lengths = torch.tensor(input_lengths, device=output.device, dtype=torch.long) | |
| packed_outputs = [] | |
| for b in range(batch_size): | |
| seq_len = int(input_lengths[b].item()) | |
| if seq_len <= 0: | |
| continue | |
| # Slice valid timesteps for this batch element and reshape to [seq_len, hidden_size] | |
| packed_outputs.append(output[b, :seq_len].reshape(-1, hidden_size)) | |
| if not packed_outputs: | |
| # All sequences are empty; return a correctly shaped empty tensor | |
| return output.new_empty((0, hidden_size)) | |
| return torch.cat(packed_outputs, dim=0) |
| def _spec_decode_fallback( | ||
| self, | ||
| q_4d: torch.Tensor, | ||
| k_cache: torch.Tensor, | ||
| v_cache: torch.Tensor, | ||
| page_table: torch.Tensor, | ||
| seq_lens: torch.Tensor, | ||
| ) -> torch.Tensor: | ||
| batch_size, q_len_per_req, num_heads, head_dim = q_4d.shape | ||
| num_kv_heads = k_cache.shape[1] | ||
| group_size = num_heads // num_kv_heads | ||
| scale = head_dim**-0.5 | ||
| outputs = [] | ||
|
|
||
| for batch_idx in range(batch_size): | ||
| block_ids = page_table[batch_idx] | ||
| k_blocks = k_cache[block_ids] | ||
| v_blocks = v_cache[block_ids] | ||
| k_seq = ( | ||
| k_blocks.permute(1, 0, 2, 3) | ||
| .reshape(num_kv_heads, -1, head_dim) | ||
| .permute(1, 0, 2) | ||
| .contiguous() | ||
| ) | ||
| v_seq = ( | ||
| v_blocks.permute(1, 0, 2, 3) | ||
| .reshape(num_kv_heads, -1, head_dim) | ||
| .permute(1, 0, 2) | ||
| .contiguous() | ||
| ) | ||
|
|
||
| prefix_len = int(seq_lens[batch_idx].item()) | ||
| token_outputs = [] | ||
| for token_idx in range(q_len_per_req): | ||
| seq_len = prefix_len + token_idx + 1 | ||
| q_token = q_4d[batch_idx, token_idx] | ||
| k_token = k_seq[:seq_len] | ||
| v_token = v_seq[:seq_len] | ||
| if group_size > 1: | ||
| k_token = k_token.repeat_interleave(group_size, dim=1) | ||
| v_token = v_token.repeat_interleave(group_size, dim=1) | ||
|
|
||
| scores = ( | ||
| torch.einsum("hd,thd->ht", q_token.float(), k_token.float()) * scale | ||
| ) | ||
| attn = torch.softmax(scores, dim=-1).to(v_token.dtype) | ||
| token_output = torch.einsum("ht,thd->hd", attn, v_token) | ||
| token_outputs.append(token_output.to(q_4d.dtype)) | ||
| outputs.append(torch.stack(token_outputs, dim=0)) | ||
| return torch.stack(outputs, dim=0).unsqueeze(1).contiguous() |
There was a problem hiding this comment.
This fallback is a nested Python loop over batch and tokens with multiple einsum() calls and repeat_interleave(), which will be extremely slow and can cause major regressions if hit outside of small tests. Additionally, attn = softmax(...).to(v_token.dtype) can be especially problematic when v_token is FP8 (precision loss/underflow for attention weights). Recommended: avoid silently using this path in production (e.g., raise/return unsupported so the factory selects the spec kernel), and if it must exist, keep attention weights in FP16/BF16/FP32 (only cast the final output) and add an explicit warning that it is for debug-only.
| RTP_LLM_CHECK_WITH_INFO(attn_inputs.kv_cache_kernel_block_id_host.defined() | ||
| && attn_inputs.kv_cache_kernel_block_id_device.defined(), | ||
| "decode should have kv cache block id."); | ||
|
|
There was a problem hiding this comment.
This check is inside XQASpecAttnOp::prepare(), but the error message says "decode should have kv cache block id." Consider updating it to mention spec/target-verify (and the exact required fields), e.g., "spec decode requires kv_cache_kernel_block_id_{host,device} to be set".
| self.semaphores = torch.zeros( | ||
| 8 * 1024 * 1024, dtype=torch.uint32, device="cuda" | ||
| ) |
There was a problem hiding this comment.
Changing semaphores from uint8 to uint32 increases the allocation from ~8 MiB to ~32 MiB per instance (same element count, 4× element size). If multiple XQADecodeImpl instances can exist concurrently, this becomes a significant GPU memory overhead. If uint32 is required by the kernel/atomic semantics, consider reducing the element count to keep bytes constant, or allocate this buffer from a shared pool similar to workspace_buffer.
| DECODE_MHA_IMPS.extend([FlashInferTRTLLMDecodeImpl]) | ||
| DECODE_MHA_IMPS.append(XQASpecImpl) | ||
| DECODE_MHA_IMPS.append(get_xqa_impl()) |
There was a problem hiding this comment.
The factory selection is now extended with XQASpecImpl, but the added tests focus on the raw XQASpecAttnOp behavior rather than verifying the factory chooses XQASpecImpl when attn_inputs.is_target_verify=True (and rejects it otherwise). Adding a small test that exercises the factory/impl selection would help prevent regressions where ordering or support predicates change.
cac26c4 to
e162507
Compare
🤖 Code Review (v1) — P2Verdict: P2 (Suggestions) 为 XQA 添加 speculative decoding 支持,新增 P2 Issues1. 2. 3. 测试 tolerance 较宽松 4. Automated review by CI Bot |
e162507 to
68ff089
Compare
There was a problem hiding this comment.
Pull request overview
Copilot reviewed 7 out of 7 changed files in this pull request and generated 3 comments.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| torch::Tensor xqa_input = input.contiguous(); | ||
| runXqa(xqa_input.data_ptr(), | ||
| input.dtype() == torch::kBFloat16, | ||
| output.data_ptr(), | ||
| local_head_num, | ||
| local_head_num_kv, | ||
| size_per_head, | ||
| params->batch_size, | ||
| static_cast<size_t>(kv_block_array.mMaxBlocksPerSeq), | ||
| params->max_seq_len, | ||
| attn_configs_.kernel_tokens_per_block, | ||
| kv_block_array.mPrimaryPoolPtr, | ||
| reinterpret_cast<int32_t*>((KVCacheIndex*)(params->kv_cache_offset.data_ptr())), | ||
| kv_block_array.cache_type == KvCacheDataType::FP8, | ||
| reinterpret_cast<uint32_t*>(params->sequence_lengths.data_ptr()), | ||
| nullptr, | ||
| params->max_q_len, | ||
| params->q_cu_seqlens.data_ptr()); |
There was a problem hiding this comment.
XQASpecAttnOp::forward passes params->max_q_len and params->q_cu_seqlens into runXqa(), enabling the spec path. However runXqa() (rtp_llm/cpp/cuda/ops/CudaXqa.cc) initializes q_mask/semaphores as function-static tensors based on the first call’s (kv_head_num, group_size, max_q_len, max_batch_size). If a non-spec decode call happens first (default max_q_len=2) and later spec verify uses a larger max_q_len (e.g. propose_step_+1), the cached q_mask size will mismatch max_q_len and can lead to incorrect results or out-of-bounds reads in the kernel. runXqa should cache these buffers keyed by max_q_len/(kv_head_num,group_size) or recreate them when arguments change, rather than using single-init statics tied to the first invocation.
| from rtp_llm.config.engine_config import EngineConfig | ||
| from rtp_llm.config.model_config import ModelConfig | ||
| from rtp_llm.config.py_config_modules import PyEnvConfigs | ||
| from rtp_llm.ops import KvCacheDataType | ||
| from rtp_llm.ops.compute_ops import ( | ||
| FusedRopeKVCachePrefillOpQOut, | ||
| PyAttentionInputs, | ||
| XQAAttnOp, | ||
| XQAParams, | ||
| XQASpecAttnOp, | ||
| get_typemeta, | ||
| init_device, |
There was a problem hiding this comment.
Several newly added imports appear unused in this test module (EngineConfig, ModelConfig, PyEnvConfigs, FusedRopeKVCachePrefillOpQOut, init_device). If the repo runs linting/static checks, these can fail CI and also add noise for readers. Please remove unused imports or add the missing usage if they’re intended for setup.
| from rtp_llm.config.engine_config import EngineConfig | |
| from rtp_llm.config.model_config import ModelConfig | |
| from rtp_llm.config.py_config_modules import PyEnvConfigs | |
| from rtp_llm.ops import KvCacheDataType | |
| from rtp_llm.ops.compute_ops import ( | |
| FusedRopeKVCachePrefillOpQOut, | |
| PyAttentionInputs, | |
| XQAAttnOp, | |
| XQAParams, | |
| XQASpecAttnOp, | |
| get_typemeta, | |
| init_device, | |
| from rtp_llm.ops import KvCacheDataType | |
| from rtp_llm.ops.compute_ops import ( | |
| PyAttentionInputs, | |
| XQAAttnOp, | |
| XQAParams, | |
| XQASpecAttnOp, | |
| get_typemeta, |
| int batch_size = attn_inputs.sequence_lengths.size(0); | ||
| RTP_LLM_CHECK_WITH_INFO(attn_inputs.kv_cache_kernel_block_id_host.defined() | ||
| && attn_inputs.kv_cache_kernel_block_id_device.defined(), | ||
| "decode should have kv cache block id."); |
There was a problem hiding this comment.
XQASpecAttnOp::prepare() uses the error text "decode should have kv cache block id." even though this path is specifically for target-verify/spec decode. Updating the message to mention spec/target-verify would make failures easier to diagnose (especially since XQAAttnOp::prepare uses the same text for true decode).
| "decode should have kv cache block id."); | |
| "spec decode/target-verify should have kv cache block id."); |
|
🤖 AI Code Review (incremental) — PR #853 Changes since last reviewNew commit Findings[P2] layer_to_groups resize — missing defensive assertion RTP_LLM_CHECK_WITH_INFO(config_.layer_to_group_id.size() <= config_.layer_num,
"layer_to_group_id size exceeds layer_num");[Nit] Spec decode test tolerance ( XQA spec decode implementation is solid with good test coverage. Maintaining P2 for the missing guard. |
|
🤖 AI Code Review — PR #853 PR 概述Title: 核心目标为 XQA attention 添加 speculative decoding(target verify)支持。新增 改动逻辑拆解GitHub 开源仓库变更(主要代码)1. C++ XQA Params 扩展 (CudaXqa.h)在 2. XQASpecAttnOp C++ 实现 (XQAAttnOp.cc/h)新增
3. XQASpecImpl Python 包装 (xqa.py)新增 4. 注册到 Attention Factory (init.py)将 5. XQAAttnOp::support() 增加 guard原 6. KVCacheManager bug fix (KVCacheManager.cc)将 7. 测试 (test_xqa.py, base_attention_test.py)
Checklist 检查结果通用原则软件工程原则
架构审视
测试
代码质量与文档
领域检查A. 兼容性与配置 — 全部 ✅B. 正确性与逻辑 — ❌ 见问题 1, 2C. 线程安全与并发 — 全部 ✅D. 性能 — ❌ 见问题 5, 6E. 分布式 — 全部 ✅F. 跨平台 — 全部 ✅G. 语言与框架特有 — 全部 ✅H. 测试与 CI — 见通用原则.测试I. 代码质量 — ❌ 见问题 8Review 意见问题
小问题
整体评价本 PR 为 XQA attention 添加 speculative decoding 支持,整体设计方向正确,测试覆盖基本到位。但存在一个编译阻塞问题( ❌ 存在阻塞/重要问题,不建议合入 |
Code Review: PR #853 — feat - support xqa spec审查日期: 2026-04-14 | 审查者: Claude Agent 变更概览本 PR 包含两个改动:
P1-1: XQASpecImpl.forward() 缺少 write_cache_store 调用文件: 对比同模块的 Speculative decoding 的 target verify 阶段需要将新 token 的 KV 写入 cache,缺少此调用可能导致 KV cache 未被正确更新。 建议在 common.apply_write_cache_store(
self.write_cache_store_impl, self.attn_inputs, kv_cache
)P1-2:
|
fe51006 to
320919b
Compare
There was a problem hiding this comment.
Pull request overview
Copilot reviewed 1 out of 1 changed files in this pull request and generated 1 comment.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| layout.layer_to_groups = config_.layer_to_group_id; | ||
| layout.group_types = config_.group_types; | ||
| layout.layer_to_groups.resize(config_.layer_num); |
There was a problem hiding this comment.
Resizing layout.layer_to_groups after assignment can silently truncate or default-fill the mapping if config_.layer_to_group_id.size() differs from config_.layer_num, which may produce incorrect layer→group pairing without any signal. Prefer validating that config_.layer_to_group_id.size() == config_.layer_num (and failing fast if not), or explicitly defining the intended padding/truncation behavior before resizing.
| layout.layer_to_groups = config_.layer_to_group_id; | |
| layout.group_types = config_.group_types; | |
| layout.layer_to_groups.resize(config_.layer_num); | |
| RTP_LLM_CHECK_WITH_INFO(config_.layer_to_group_id.size() == config_.layer_num, | |
| "config_.layer_to_group_id.size()[%ld] != config_.layer_num[%d]", | |
| config_.layer_to_group_id.size(), | |
| config_.layer_num); | |
| layout.layer_to_groups = config_.layer_to_group_id; | |
| layout.group_types = config_.group_types; |
Code Review v2 — PR #853 (feat - support xqa spec)Review 版本: v2(增量 review) | SHA: v1 -> v2 变更摘要PR 被 force-push 重写,原有的 XQA Speculative Decoding 功能代码(7 个文件,+444 行)已全部移除。当前 PR 仅保留 KVCacheManager 的 bug fix(1 个文件,+1/-1)。 v1 P1 问题跟踪
当前变更分析文件: 将
修复逻辑正确,无风险。 总结
v1 的 3 个 P1 因 XQA spec 代码移除而不再适用。当前 PR 仅含 KVCacheManager bug fix,修复正确,可以合入。 🤖 Generated by Claude Code Review Agent v2 |
|
🤖 AI Code Review — PR #853 PR 概述Title: 核心目标修复 改动分析旧代码中 新代码将 Review 结论LGTM ready to ci ✅ 无 P0/P1 问题。 P3 建议: PR description 为空,建议补充一句说明改动背景(XQA speculative 场景下 |
AI Code Review — PR #853Summary: P0/0 · P1/0 · P2/0 · P3/0 Review status: LGTM lgtm ready to ci Strengths
|
|
internal source has been updated, please review the changes! |
No description provided.