Parent
Part of #1243 — DeepSeek NSA kernels for MI350
Description
Define new water/wave dialect operations for NSA and implement lowering paths to the existing wave compiler infrastructure.
New ops needed
- wave.nsa_compress: mean-pool KV into compressed blocks
  - Could lower through existing reduction + reshape ops if available
  - Otherwise, a new dedicated op
- wave.nsa_topk_select: top-k block index selection
  - Novel op: combines scoring (compressed attention logits) with top-k extraction
  - Returns integer indices, unlike typical wave ops, which produce floats
  - May need a specialized lowering since top-k doesn't map cleanly to matmul/elementwise
- wave.nsa_selection_attn: sparse attention over gathered blocks
  - This is the core NSA primitive
  - Could potentially decompose into existing wave.attention + gather ops, but a fused op should be considerably more efficient because it avoids materializing the gathered K/V blocks
  - Needs to carry block_indices as an additional operand
- wave.nsa_gated_combine: weighted sum of three branches
  - Straightforward elementwise, could lower through existing ops
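For reference while the op definitions are discussed, here is a minimal pure-Python sketch of the intended semantics of the four ops for a single query and a single head. It is not wave code and makes several assumptions that the issue does not fix: blocks are non-overlapping with a fixed size, top-k scores are plain dot products against the compressed keys, selected indices are returned in ascending order, and gates are plain scalars. The function names mirror the proposed op names for readability only.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def softmax(xs):
    m = max(xs)  # subtract the max for numerical stability
    es = [math.exp(x - m) for x in xs]
    total = sum(es)
    return [e / total for e in es]


def attend(q, keys, values):
    """Single-query scaled dot-product attention over the given rows."""
    scale = 1.0 / math.sqrt(len(q))
    w = softmax([dot(q, k) * scale for k in keys])
    width = len(values[0])
    return [sum(wi * v[c] for wi, v in zip(w, values)) for c in range(width)]


def nsa_compress(x, block):
    """Mean-pool consecutive rows into non-overlapping blocks (assumed layout)."""
    assert len(x) % block == 0, "sequence length must be a multiple of block"
    out = []
    for start in range(0, len(x), block):
        rows = x[start:start + block]
        out.append([sum(col) / block for col in zip(*rows)])
    return out


def nsa_topk_select(q, k_cmp, top_k):
    """Score each compressed block against q and return top_k integer indices."""
    scores = [dot(q, k) for k in k_cmp]
    # Ties are broken by the lower block index (an assumption of this sketch).
    order = sorted(range(len(scores)), key=lambda i: (-scores[i], i))
    return sorted(order[:top_k])


def nsa_selection_attn(q, k, v, block_indices, block):
    """Attend only over the rows belonging to the selected blocks."""
    rows = [b * block + j for b in block_indices for j in range(block)]
    return attend(q, [k[r] for r in rows], [v[r] for r in rows])


def nsa_gated_combine(gates, branches):
    """Elementwise weighted sum of the branch outputs."""
    width = len(branches[0])
    return [sum(g * br[c] for g, br in zip(gates, branches)) for c in range(width)]


# Toy data: 8 key/value rows, block size 2, so 4 compressed blocks.
k = [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0],
     [-1.0, 0.0], [-1.0, 0.0], [0.0, -1.0], [0.0, -1.0]]
v = k
q = [1.0, 0.5]

k_cmp = nsa_compress(k, block=2)
selected = nsa_topk_select(q, k_cmp, top_k=2)
out_sel = nsa_selection_attn(q, k, v, selected, block=2)
```

Note that nsa_topk_select is the only function here that returns integers, and nsa_selection_attn is the only one whose memory accesses depend on a runtime operand. Those two properties are what set these ops apart from the existing reduction and elementwise paths.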
Lowering strategy
- Phase 1: Implement as opaque custom kernels called from Python (bypass water dialect, direct wave kernel launch). Gets us running quickly.
- Phase 2: Define water dialect ops with proper index expression inference, EPT (elements-per-thread) propagation, and scheduling strategy support. Enables the wave compiler to optimize across NSA ops.
Index expression considerations (per #1081, #1189)
- Selection attention's gather pattern means index expressions depend on runtime values (block_indices)
- This is fundamentally different from the static index patterns in standard attention
- May need a new InferIndexExprsOpInterface implementation that handles indirect indexing
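The difference between the two index patterns can be illustrated with a small Python sketch. Both helper functions and the block_indices layout (one list of selected blocks per query) are hypothetical and only model the arithmetic; they do not reflect how InferIndexExprsOpInterface is actually implemented.

```python
def dense_kv_row(tile, offset, block):
    """Standard attention: the row is an affine function of loop indices,
    so it can be written down without looking at any tensor contents."""
    return tile * block + offset


def gathered_kv_row(block_indices, query, slot, offset, block):
    """Selection attention: the base block is read from block_indices,
    so the row is only known once that operand has been loaded."""
    return block_indices[query][slot] * block + offset


BLOCK = 4
# Hypothetical selection result: each query picked two blocks.
block_indices = [[0, 3], [1, 2]]

dense_rows = [dense_kv_row(t, j, BLOCK) for t in range(2) for j in range(BLOCK)]
rows_q0 = [gathered_kv_row(block_indices, 0, s, j, BLOCK)
           for s in range(2) for j in range(BLOCK)]
rows_q1 = [gathered_kv_row(block_indices, 1, s, j, BLOCK)
           for s in range(2) for j in range(BLOCK)]
```

In the dense case the same closed-form expression holds for every query. In the gathered case the rows differ per query (rows_q0 and rows_q1 above), and only the within-block offset stays static.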
Depends on