
[wave] NSA: water dialect lowering for NSA composite ops #1262

@harsh-nod


Parent

Part of #1243 — DeepSeek NSA kernels for MI350

Description

Define new water/wave dialect operations for NSA (DeepSeek Native Sparse Attention) and implement lowering paths to the existing wave compiler infrastructure.

New ops needed

  1. wave.nsa_compress — mean-pool KV into compressed blocks

    • Could lower through existing reduction + reshape ops if available
    • Otherwise, define a new dedicated op
  2. wave.nsa_topk_select — top-k block index selection

    • Novel op: combines scoring (compressed attention logits) with top-k extraction
    • Returns integer indices, unlike typical wave ops, which produce floating-point results
    • May need a specialized lowering, since top-k does not map cleanly onto matmul or elementwise patterns
  3. wave.nsa_selection_attn — sparse attention over gathered blocks

    • This is the core NSA primitive
    • Could potentially decompose into existing wave.attention + gather ops, but a fused op should be significantly more efficient, since it avoids materializing the gathered K/V blocks in memory
    • Needs to carry block_indices as an additional operand
  4. wave.nsa_gated_combine — weighted sum of three branches

    • Straightforward elementwise op; could lower through existing ops
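To pin down the intended semantics of the four ops above, here is a minimal NumPy reference sketch that a wave lowering could be checked against. It is not a wave API: the function names only mirror the proposed op names, and several simplifying assumptions are made. It handles a single head and a single decode-step query (so no causal mask), uses non-overlapping blocks with the same block size for compression and selection, uses mean pooling for compression as this issue describes, and passes the three gate values in as plain scalars (how they are produced is out of scope here).

```python
import numpy as np


def softmax(x: np.ndarray) -> np.ndarray:
    """Numerically stable softmax over the last axis."""
    z = x - x.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def attend(q: np.ndarray, k: np.ndarray, v: np.ndarray) -> np.ndarray:
    """Dense scaled dot-product attention for one query: q [d], k [n, d], v [n, dv] -> [dv]."""
    logits = k @ q / np.sqrt(q.shape[-1])
    return softmax(logits) @ v


def nsa_compress(x: np.ndarray, block_size: int) -> np.ndarray:
    """Mean-pool non-overlapping blocks along the sequence: [n, d] -> [n // block_size, d]."""
    n_blocks = x.shape[0] // block_size
    trimmed = x[: n_blocks * block_size]  # drop a trailing partial block (assumption)
    return trimmed.reshape(n_blocks, block_size, x.shape[1]).mean(axis=1)


def nsa_topk_select(q: np.ndarray, k_cmp: np.ndarray, top_k: int) -> np.ndarray:
    """Score compressed blocks by attention logits and return the top-k block indices."""
    scores = k_cmp @ q / np.sqrt(q.shape[-1])
    idx = np.argsort(-scores, kind="stable")[:top_k]
    # Integer output, sorted so blocks are visited in sequence order.
    return np.sort(idx).astype(np.int32)


def nsa_selection_attn(q, k, v, block_indices, block_size):
    """Attend only over the KV blocks named by block_indices (runtime values)."""
    # Expand each block index into the token rows it covers: [top_k, block_size] -> flat.
    token_idx = (block_indices[:, None] * block_size + np.arange(block_size)).reshape(-1)
    return attend(q, k[token_idx], v[token_idx])


def nsa_gated_combine(gates, branch_outputs):
    """Elementwise weighted sum of the compressed, selected, and sliding-window outputs."""
    return sum(g * o for g, o in zip(gates, branch_outputs))


# Toy end-to-end run of the three NSA branches plus the gated combine.
rng = np.random.default_rng(0)
n, d, block_size, top_k, window = 64, 16, 8, 2, 16
q = rng.standard_normal(d)
k = rng.standard_normal((n, d))
v = rng.standard_normal((n, d))

k_cmp = nsa_compress(k, block_size)
v_cmp = nsa_compress(v, block_size)
o_cmp = attend(q, k_cmp, v_cmp)                          # compressed branch
block_indices = nsa_topk_select(q, k_cmp, top_k)         # integer indices
o_sel = nsa_selection_attn(q, k, v, block_indices, block_size)  # selected branch
o_win = attend(q, k[-window:], v[-window:])              # sliding-window branch
out = nsa_gated_combine((0.5, 0.3, 0.2), (o_cmp, o_sel, o_win))
```

A useful property for testing a fused nsa_selection_attn lowering: selecting every block must reproduce dense attention exactly, and a one-hot gate vector in nsa_gated_combine must return the corresponding branch unchanged.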

Lowering strategy

  • Phase 1: Implement as opaque custom kernels called from Python (bypassing the water dialect and launching wave kernels directly). This gets us running quickly.
  • Phase 2: Define water dialect ops with proper index expression inference, EPT (elements-per-thread) propagation, and scheduling strategy support. This enables the wave compiler to optimize across NSA ops.

Index expression considerations (per #1081, #1189)

  • Selection attention's gather pattern means index expressions depend on runtime values (block_indices)
  • This is fundamentally different from the static index patterns in standard attention
  • May need a new InferIndexExprsOpInterface implementation that handles indirect indexing
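To make the difference concrete, here is a toy Python sketch of the two kinds of index mapping for a K/V load. It is not wave's index-expression machinery: BLOCK, dense_kv_index, and make_selected_kv_index are hypothetical names, and the model is deliberately scalar (one K/V row per (tile, lane) pair).

```python
from typing import Callable, Sequence

BLOCK = 4  # hypothetical tile size along the KV sequence dimension


def dense_kv_index(tile: int, lane: int) -> int:
    """Standard attention: the KV row is an affine function of tile and lane ids,
    so it can be derived entirely at compile time."""
    return tile * BLOCK + lane


def make_selected_kv_index(block_indices: Sequence[int]) -> Callable[[int, int], int]:
    """Selection attention: the KV row first reads block_indices[tile], a runtime
    value, and only then applies the same affine per-lane offset."""
    def index(tile: int, lane: int) -> int:
        return block_indices[tile] * BLOCK + lane
    return index


# Tile 1 of dense attention always reads rows 4..7; for selection attention the
# rows read by tile 0 depend on what block_indices holds at runtime.
selected = make_selected_kv_index([5, 2])
dense_rows = [dense_kv_index(1, lane) for lane in range(BLOCK)]
selected_rows = [selected(0, lane) for lane in range(BLOCK)]
```

One observation the sketch suggests (an assumption, not a settled design): only the block base offset is indirect, while rows within a selected block stay contiguous. An implementation could therefore keep the per-lane part affine and static, treating block_indices[tile] as a single runtime symbol, which may preserve vectorized (EPT > 1) loads inside each block.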

Depends on

Labels: enhancement (New feature or request), nsa (DeepSeek Native Sparse Attention)
