Add DeepSeek-V3.2 fused indexer path#788
Conversation
Co-authored-by: Cursor <cursoragent@cursor.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
There was a problem hiding this comment.
Pull request overview
This PR introduces a fused indexer path for DeepSeek-V3.2 to improve decode performance. It fuses the indexer's wk and weights_proj into a single GEMM (with FP8 block-scale wk load support) and adds an AITER indexer_qk_rope_quant_and_cache kernel that combines Q RoPE, Q quantization, weight scaling, and K norm/RoPE/cache writes. The new path is gated by ATOM_ENABLE_DS_INDEXER_QK_ROPE_CACHE_FUSION (default on), with fallbacks when the env is disabled, when shape constraints aren't met, or when indexer.wk uses an unsupported dtype.
Changes:
- New
IndexerWkWeightsProjLinear(FP8 block-scale wk dequant → BF16) and packed-modules wiring to fuseindexer.wk+indexer.weights_proj. - New fused
indexer_qk_rope_quant_and_cachecall site in both native and sparse plugin paths, with env-guarded fallback toindexer_k_quant_and_cache. - Updated
sparse_attn_indexercustom-op signatures (and fakes) to take K-norm/RoPE cache/scale args; fake/dummy returns now allocate fp32 weights to match the fused output dtype.
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 6 comments.
| File | Description |
|---|---|
| atom/utils/envs.py | Adds ATOM_ENABLE_DS_INDEXER_QK_ROPE_CACHE_FUSION env (default 1). |
| atom/plugin/attention_mla_sparse.py | Plugin-mode sparse indexer routes to fused kernel and updates fake/dummy returns. |
| atom/models/deepseek_v2.py | Adds fused wk+weights_proj linear, fused QK-rope/cache call, fusion eligibility check, packed-modules and quant-exclude updates. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| model_prefix = maybe_prefix(prefix, "model") | ||
| use_indexer_wk_weights_proj_fusion = _can_fuse_indexer_wk_weights_proj( | ||
| config, | ||
| quant_config, | ||
| f"{model_prefix}.layers.0.self_attn.indexer", | ||
| ) |
| quant_exclude_name_mapping: dict[str, str] = { | ||
| # HF quant config uses "indexers_proj" but the ATOM module path is | ||
| # "indexer.weights_proj". str.replace translates each exclude entry. | ||
| "indexers_proj": "indexer.weights_proj", | ||
| # "indexer.wk_weights_proj". str.replace translates each exclude entry. | ||
| "indexers_proj": "indexer.wk_weights_proj", | ||
| } |
Keep indexer weight fusion decisions consistent with quant fallback paths and make dummy/profile behavior deterministic. Co-authored-by: Cursor <cursoragent@cursor.com>
Keep the PR compatible with the repository's Black formatting check after the review fixes. Co-authored-by: Cursor <cursoragent@cursor.com>
Collapse the review helper stack while keeping a single model-level fusion decision for weight loading. Co-authored-by: Cursor <cursoragent@cursor.com>
Ensure threaded checkpoint loading dequantizes pending FP8 indexer wk weights on the same device as their scales. Co-authored-by: Cursor <cursoragent@cursor.com>
Summary
wkandweights_projwhen the checkpoint layout is compatible, including FP8 block-scalewkload support.ATOM_ENABLE_DS_INDEXER_QK_ROPE_CACHE_FUSIONand enabled by default.indexer.wkuses an unsupported weight dtype such as FP4/MXFP4.Performance Validation
Environment:
/shared/data/amd_int/models/DeepSeek-V3.2NUM=10*CON, warmups=4*CONATOM_ENABLE_DS_INDEXER_QK_ROPE_CACHE_FUSION=0ATOM_ENABLE_DS_INDEXER_QK_ROPE_CACHE_FUSION=1/workdir/my_test/results/indexer_refusion_repro_20260514_120947Accuracy Validation
gsm8k, 3-shot, fused path enabled:No material accuracy regression was observed in this run.
Test plan
git diff --checkpassed for ATOM changes.python3 -m py_compile atom/models/deepseek_v2.py atom/plugin/attention_mla_sparse.py atom/utils/envs.pypassed.indexer_qk_rope_quant_and_cacheJIT compiled and ran during local validation.gsm8k3-shot accuracy with the fused path enabled.Notes
Made with Cursor