
[wave] NSA: numerical correctness tests vs dense attention reference #1259

@harsh-nod


Parent

Part of #1243 — DeepSeek NSA kernels for MI350

Description

Build a comprehensive correctness test suite that validates all NSA kernels against a pure PyTorch dense attention reference implementation.
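Every row of the test matrix bottoms out in some form of dense attention. Below is a minimal sketch of such a reference. It makes several assumptions that are not in the issue. Q is taken to be (B, M, H, D) and K/V to be (B, N, G, D), inferred from the mean-pooling reshape in the test matrix. Query i is assumed to sit at position N - M + i. Consecutive query heads are assumed to share a KV group. `dense_attention_ref` is a hypothetical helper, not an existing Wave API.

```python
import math

import torch


def dense_attention_ref(q, k, v, causal=True, mask=None, scale=None):
    """Dense GQA attention computed in float64 (hypothetical reference helper).

    Assumed layouts: q is (B, M, H, D); k and v are (B, N, G, D) with H % G == 0.
    Query i is assumed to sit at absolute position N - M + i, which covers both
    prefill (M == N) and decode (M < N).
    mask: optional bool tensor broadcastable to (B, H, M, N); True means "may attend".
    Returns a (B, M, H, D) float64 tensor.
    """
    B, M, H, D = q.shape
    N, G = k.shape[1], k.shape[2]
    assert H % G == 0, "query heads must be a multiple of KV groups"
    scale = 1.0 / math.sqrt(D) if scale is None else scale

    # Promote to FP64 and repeat each KV group for its H // G query heads.
    qf = q.double().transpose(1, 2)                                    # (B, H, M, D)
    kf = k.double().repeat_interleave(H // G, dim=2).transpose(1, 2)   # (B, H, N, D)
    vf = v.double().repeat_interleave(H // G, dim=2).transpose(1, 2)   # (B, H, N, D)

    scores = (qf @ kf.transpose(-1, -2)) * scale                       # (B, H, M, N)

    allowed = torch.ones(M, N, dtype=torch.bool, device=q.device)
    if causal:
        q_pos = torch.arange(N - M, N, device=q.device)[:, None]
        k_pos = torch.arange(N, device=q.device)[None, :]
        allowed = k_pos <= q_pos
    if mask is not None:
        allowed = allowed & mask

    probs = torch.softmax(scores.masked_fill(~allowed, float("-inf")), dim=-1)
    # A query with no allowed key gives an all-NaN row; define its output as zero.
    probs = torch.nan_to_num(probs, nan=0.0)
    return (probs @ vf).transpose(1, 2)                                # (B, M, H, D)
```

Computing in FP64 keeps the reference's own rounding error well below the FP16 tolerances. The optional `mask` argument is one way to express the selection-attention and sliding-window rows as masked dense attention.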

Test matrix

| Kernel | Reference | Tolerance |
|---|---|---|
| Mean pooling | `torch.mean(K.reshape(B, N//bs, bs, G, D), dim=2)` | exact (FP32), 1e-3 (FP16) |
| Compressed attention | `F.scaled_dot_product_attention(Q, K_cmp, V_cmp)` | 1e-3 abs, 1e-2 rel |
| Top-k selection | `torch.topk(scores, k=block_count)` | exact (integer indices) |
| Selection attention | Dense attention masked to the selected blocks | 1e-3 abs, 1e-2 rel |
| Sliding window | `flash_attn_func(..., window_size=...)` | 1e-3 abs |
| Gated combination | `g_cmp*O_cmp + g_slc*O_slc + g_swa*O_swa` | exact (FP32), 1e-4 (FP16) |
| Full NSA pipeline | Dense full attention | 1e-2 rel (aggregate) |
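The Reference column leaves a few details open, including score layout, index ordering, and window convention. Below is a hedged sketch of the non-attention pieces. It assumes K is (B, N, G, D), as the mean-pooling reshape implies, and that gates are per (batch, query, head). All function names are hypothetical. The boolean masks are meant to be passed to a dense attention reference for the selection-attention and sliding-window rows.

```python
import torch


def mean_pool_ref(k, bs):
    """Mean pooling row: (B, N, G, D) -> (B, N // bs, G, D); assumes N % bs == 0."""
    B, N, G, D = k.shape
    return torch.mean(k.reshape(B, N // bs, bs, G, D), dim=2)


def topk_blocks_ref(scores, block_count):
    """Top-k selection row: indices of the highest-scoring blocks, sorted ascending.

    scores: (..., num_blocks). block_count is clamped to num_blocks, and sorting
    lets the kernel's output be compared as a set, whatever order it emits.
    """
    k = min(block_count, scores.shape[-1])
    idx = torch.topk(scores, k=k, dim=-1).indices
    return torch.sort(idx, dim=-1).values


def block_mask_ref(block_idx, bs, N):
    """Selection-attention mask: (..., M, k) block indices -> (..., M, N) bool key mask."""
    key_block = torch.arange(N, device=block_idx.device) // bs   # block id of every key
    return (block_idx[..., None] == key_block).any(dim=-2)


def sliding_window_mask_ref(M, N, ws, device=None):
    """Sliding-window mask (M, N); True means the query may attend to the key.

    Assumes query i sits at position N - M + i and sees its ws most recent keys,
    itself included. Under flash-attn's documented window semantics that reading
    corresponds to window_size=(ws - 1, 0) with causal masking; confirm the
    convention the Wave kernel uses before relying on it.
    """
    q_pos = torch.arange(N - M, N, device=device)[:, None]
    k_pos = torch.arange(N, device=device)[None, :]
    return (k_pos <= q_pos) & (k_pos > q_pos - ws)


def gated_combine_ref(g_cmp, g_slc, g_swa, o_cmp, o_slc, o_swa):
    """Gated combination row, accumulated in float64.

    Gates are assumed to be (B, M, H) and broadcast over D of the (B, M, H, D) outputs.
    """
    return (g_cmp.double()[..., None] * o_cmp.double()
            + g_slc.double()[..., None] * o_slc.double()
            + g_swa.double()[..., None] * o_swa.double())
```

Sorting the top-k indices turns "exact (integer indices)" into a set comparison, so the kernel may emit blocks in any order. Scores that are exactly tied can still legitimately select different blocks and may need a looser check.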

Test configurations

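```python
# Key meanings (inferred from the rest of this issue): B batch, M query length,
# N key/value length, H query heads, G KV head groups, D head dim,
# bs block_size, bc block_count, ws window_size.
configs = [
    dict(B=1, M=1024, N=1024, H=32, G=4, D=64, bs=32, bc=8, ws=128),
    dict(B=2, M=4096, N=4096, H=64, G=8, D=128, bs=64, bc=16, ws=512),
    dict(B=1, M=64, N=65536, H=128, G=8, D=128, bs=64, bc=16, ws=512),  # decode-like
    dict(B=1, M=1, N=8192, H=128, G=8, D=128, bs=64, bc=16, ws=512),    # single-token decode
]
```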
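Before parametrizing over these configs, it can help to derive a few quantities per config and reject inconsistent ones. The sketch below does that. The key meanings (bs = block_size, bc = block_count, ws = window_size, G = KV head groups) are inferred from the rest of the issue, and `derived_params` is a hypothetical helper.

```python
def derived_params(cfg):
    """Derive per-config quantities for the NSA tests (hypothetical helper).

    Key meanings are inferred from the issue: bs = block_size, bc = block_count,
    ws = window_size, G = number of KV head groups, M/N = query/key lengths.
    """
    B, M, N, H, G = (cfg[key] for key in ("B", "M", "N", "H", "G"))
    bs, bc, ws = cfg["bs"], cfg["bc"], cfg["ws"]
    if H % G != 0:
        raise ValueError(f"H={H} is not a multiple of G={G}")
    num_blocks = -(-N // bs)  # ceiling division: a partial tail block still counts
    return dict(
        heads_per_group=H // G,               # 1 -> MHA, H -> MQA
        num_blocks=num_blocks,
        effective_bc=min(bc, num_blocks),     # cannot select more blocks than exist
        effective_ws=min(ws, N),              # a window longer than N covers every key
        # Bytes for a fully materialized FP64 (B, H, M, N) score matrix in a dense reference.
        dense_scores_bytes=B * H * M * N * 8,
    )
```

The last field shows why the CPU FP64 reference may need to be chunked over queries or heads. Consider the second config (B=2, H=64, M=N=4096). If the reference materializes the full score matrix, that matrix takes 2 * 64 * 4096 * 4096 * 8 bytes = 16 GiB. The 65536-token config needs 4 GiB.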

Edge cases

  • Sequence length not divisible by block_size
  • block_count > N // block_size (more blocks requested than available)
  • window_size > N (window larger than sequence)
  • Causal mask at sequence boundaries
  • GQA with HEADS_PER_GROUP=1 (MHA) and HEADS_PER_GROUP=H (MQA)
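The issue does not say how a partial tail block should be pooled when N is not a multiple of block_size. The sketch below assumes the tail is averaged over its valid tokens only, which the kernel may or may not do. It also lists a few illustrative configs that cover the edge cases above. Both `mean_pool_ragged_ref` and `edge_configs` are hypothetical.

```python
import torch
import torch.nn.functional as F


def mean_pool_ragged_ref(k, bs):
    """Block mean pooling that tolerates N % bs != 0 (hypothetical helper).

    ASSUMPTION: a partial tail block is averaged over its valid tokens only;
    the kernel may instead zero-pad or drop the tail, so match whichever it does.
    k: (B, N, G, D) float tensor -> (B, ceil(N / bs), G, D)
    """
    B, N, G, D = k.shape
    nblk = -(-N // bs)
    pad = nblk * bs - N
    # Zero-pad the sequence dim up to a multiple of bs, sum each block, then
    # divide by the number of real tokens in that block.
    kp = F.pad(k, (0, 0, 0, 0, 0, pad))
    sums = kp.reshape(B, nblk, bs, G, D).sum(dim=2)
    counts = torch.full((nblk,), bs, dtype=k.dtype, device=k.device)
    counts[-1] = bs - pad
    return sums / counts[None, :, None, None]


# Illustrative edge-case configs (values are hypothetical) covering the bullets above.
edge_configs = [
    dict(B=1, M=100, N=100, H=8, G=8, D=64, bs=32, bc=4, ws=64),   # N % bs != 0, MHA (H // G == 1)
    dict(B=1, M=128, N=128, H=8, G=1, D=64, bs=32, bc=16, ws=64),  # bc > N // bs, MQA (H // G == H)
    dict(B=1, M=64, N=64, H=8, G=2, D=64, bs=16, bc=2, ws=256),    # ws > N
    dict(B=1, M=1, N=33, H=8, G=2, D=64, bs=16, bc=2, ws=8),       # lone query at position 32, first token of a partial block
]
```

When N is a multiple of bs, `pad` is zero and the helper reduces to the plain `torch.mean` reference from the test matrix. This keeps one code path for both the divisible and the ragged cases.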

Infrastructure

  • Use pytest with parametrized fixtures
  • Run on both CPU (FP64 reference) and MI350 (FP16 kernel)
  • CI integration: add to wave test suite
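Below is a minimal pytest sketch of that setup, shown for the mean-pooling row only. It computes the reference on CPU in FP64 and casts the inputs to FP16 for the device under test. It skips cleanly until a kernel is wired in. `WAVE_MEAN_POOL`, `make_inputs`, and `check_against_reference` are hypothetical placeholders, not Wave APIs. The sketch also assumes the MI350 is reached through a ROCm build of PyTorch, which exposes AMD GPUs through the `torch.cuda` namespace and the `"cuda"` device string.

```python
import pytest
import torch

# Two of the issue's configs; extend with the rest and with the edge cases.
CONFIGS = [
    dict(B=1, M=1024, N=1024, H=32, G=4, D=64, bs=32, bc=8, ws=128),
    dict(B=1, M=1, N=8192, H=128, G=8, D=128, bs=64, bc=16, ws=512),
]

# (atol, rtol) per kernel row of the test matrix.
TOLERANCES = {"mean_pool_fp16": (1e-3, 0.0)}

# HYPOTHETICAL hook: set to the Wave mean-pooling kernel once it exists.
# Assumed signature: fn(k: (B, N, G, D) tensor, bs: int) -> (B, N // bs, G, D) tensor.
WAVE_MEAN_POOL = None


def make_inputs(cfg, seed=0):
    """Deterministic FP64 CPU inputs; q: (B, M, H, D), k and v: (B, N, G, D)."""
    g = torch.Generator().manual_seed(seed)
    shape_q = (cfg["B"], cfg["M"], cfg["H"], cfg["D"])
    shape_kv = (cfg["B"], cfg["N"], cfg["G"], cfg["D"])
    q = torch.randn(shape_q, generator=g, dtype=torch.float64)
    k = torch.randn(shape_kv, generator=g, dtype=torch.float64)
    v = torch.randn(shape_kv, generator=g, dtype=torch.float64)
    return q, k, v


def mean_pool_ref(k, bs):
    B, N, G, D = k.shape
    return torch.mean(k.reshape(B, N // bs, bs, G, D), dim=2)


def check_against_reference(kernel_fn, ref_fn, args, atol, rtol, device, dtype=torch.float16):
    """Run ref_fn on the FP64 CPU args and kernel_fn on `dtype` copies on `device`."""
    expected = ref_fn(*args)
    dev_args = [a.to(device=device, dtype=dtype) if torch.is_tensor(a) else a for a in args]
    actual = kernel_fn(*dev_args)
    torch.testing.assert_close(actual.cpu().double(), expected, atol=atol, rtol=rtol)


@pytest.mark.parametrize("cfg", CONFIGS, ids=lambda c: f"B{c['B']}-M{c['M']}-N{c['N']}-H{c['H']}")
def test_mean_pool(cfg):
    if WAVE_MEAN_POOL is None:
        pytest.skip("Wave mean-pooling kernel not wired up yet")
    if not torch.cuda.is_available():  # ROCm builds report AMD GPUs here too
        pytest.skip("no MI350 (ROCm) device available")
    _, k, _ = make_inputs(cfg)
    atol, rtol = TOLERANCES["mean_pool_fp16"]
    check_against_reference(WAVE_MEAN_POOL, mean_pool_ref, (k, cfg["bs"]), atol, rtol, "cuda")
```

Keeping tolerances in one table keyed by kernel makes it easy to add the remaining rows of the test matrix as further parametrized tests. With `pytest -v`, each config then shows up as its own test id.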

Depends on


Labels: enhancement (New feature or request), nsa (DeepSeek Native Sparse Attention)
