[wave] NSA: gradient correctness tests (backward pass) #1260

@harsh-nod

Parent

Part of #1243 — DeepSeek NSA kernels for MI350

Description

Validate gradient correctness for all NSA backward kernels using torch.autograd.gradcheck and manual reference comparisons.
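The gating step shows what a manual reference comparison can look like without any attention math. NSA combines its branch outputs as a gated sum, so assume O = g_cmp·O_cmp + g_slc·O_slc + g_swa·O_swa, with one gate scalar per query and head. Under that assumption the gating backward has a closed form: dg_x is the dot product of dO with O_x over the head dimension, and dO_x = g_x·dO. The minimal sketch below checks that closed form against PyTorch autograd in FP64. The [B, H, M, D] output layout, the [B, H, M] gate shape, and the function names are illustrative assumptions, not the kernels' interface.

```python
import torch


def gate_forward(o_cmp, o_slc, o_swa, g_cmp, g_slc, g_swa):
    # Branch outputs: [B, H, M, D]; gates: [B, H, M] (one scalar per query and head)
    return (g_cmp.unsqueeze(-1) * o_cmp
            + g_slc.unsqueeze(-1) * o_slc
            + g_swa.unsqueeze(-1) * o_swa)


def gate_backward(d_out, o_cmp, o_slc, o_swa, g_cmp, g_slc, g_swa):
    # dg_x: dot product of dO with branch output O_x over the head dimension D
    d_gates = [(d_out * o).sum(-1) for o in (o_cmp, o_slc, o_swa)]
    # dO_x: upstream gradient scaled by the branch's gate
    d_branches = [g.unsqueeze(-1) * d_out for g in (g_cmp, g_slc, g_swa)]
    return d_gates + d_branches  # [dg_cmp, dg_slc, dg_swa, dO_cmp, dO_slc, dO_swa]


def check_gate_backward(B=1, H=4, M=64, D=32):
    torch.manual_seed(0)
    outs = [torch.randn(B, H, M, D, dtype=torch.float64, requires_grad=True) for _ in range(3)]
    gates = [torch.rand(B, H, M, dtype=torch.float64, requires_grad=True) for _ in range(3)]
    d_out = torch.randn(B, H, M, D, dtype=torch.float64)

    # Reference: let autograd differentiate the gated sum
    gate_forward(*outs, *gates).backward(d_out)
    ref = [g.grad for g in gates] + [o.grad for o in outs]

    # Manual closed-form backward, same output order as ref
    with torch.no_grad():
        manual = gate_backward(d_out, *outs, *gates)
    return all(torch.allclose(m, r, rtol=1e-12, atol=1e-12) for m, r in zip(manual, ref))


if __name__ == "__main__":
    print(check_gate_backward())
```

The same closed form could serve as the expected value for the kernel's dg_* and dO_* outputs once those are exposed.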

Test approach

  1. torch.autograd.gradcheck: run with FP64 inputs and small tensor sizes to verify that the analytical gradients match finite-difference numerical gradients

    • Requires FP64 kernel variants or CPU fallback
    • Inputs to check: Q, K, V, g_cmp, g_slc, g_swa
    • Use eps=1e-6, atol=1e-4, rtol=1e-3
  2. Reference comparison: compute gradients with pure PyTorch autograd on the dense-equivalent computation and compare them against the kernel gradients at FP16

    • Tolerance: 1e-2 relative for dQ, dK, dV; 1e-3 for dg_*
  3. Per-kernel backward tests:

    • Selection attention backward: dQ_slc, dK_slc, dV_slc
    • Compressed attention backward: dQ_cmp, dK_cmp → dK_pool, dV_cmp → dV_pool
    • Gating backward: dg_cmp, dg_slc, dg_swa, dO_cmp, dO_slc, dO_swa
    • Sliding window backward: reuse FA v3 backward (already tested)
  4. Atomic accumulation correctness

    • Selection attention backward uses atomic adds for dK/dV
    • Test with scenarios where many queries select the same blocks (high contention)
    • Verify that results are deterministic or stay within acceptable non-determinism bounds
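The gradcheck step can be prototyped before FP64 kernel variants exist by running it on a dense PyTorch stand-in for the NSA forward. Everything in this sketch is hypothetical. `nsa_dense_reference` mean-pools K/V for the compressed branch and omits the causal block restriction. A fixed random block mask replaces top-k selection. The sliding window is causal, with query i placed at key position i + N - M. The layouts are assumed to be [B, H, M, D] for Q, [B, G, N, D] for K and V, and [B, H, M] for each gate. The sketch uses the tolerances from item 1 at sizes well below the gradcheck configuration, so it runs in about a second.

```python
import torch


def dense_attention(q, k, v, mask=None):
    # q: [B, H, M, D]; k, v: [B, H, L, D]; mask: bool, broadcastable to [M, L]
    scores = q @ k.transpose(-2, -1) * q.shape[-1] ** -0.5
    if mask is not None:
        scores = scores.masked_fill(~mask, float("-inf"))  # False = not attended
    return torch.softmax(scores, dim=-1) @ v


def expand_kv(x, H):
    # GQA: repeat each of the G K/V heads for its H // G query heads
    B, G, L, D = x.shape
    return x[:, :, None].expand(B, G, H // G, L, D).reshape(B, H, L, D)


def random_block_mask(M, N, block, n_select):
    # Stand-in for top-k block selection: each query selects n_select whole blocks
    n_blocks = N // block
    top = torch.rand(M, n_blocks).topk(n_select, dim=-1).indices
    blk = torch.zeros(M, n_blocks, dtype=torch.bool)
    blk[torch.arange(M).unsqueeze(1), top] = True
    return blk.unsqueeze(-1).expand(M, n_blocks, block).reshape(M, N)


def nsa_dense_reference(q, k, v, g_cmp, g_slc, g_swa, slc_mask, block, window):
    # q: [B, H, M, D]; k, v: [B, G, N, D]; gates: [B, H, M]; slc_mask: bool [M, N]
    B, H, M, D = q.shape
    N = k.shape[-2]
    # Compressed branch: mean-pool K/V per block (assumed pooling, no causal block mask)
    k_pool = k.reshape(B, -1, N // block, block, D).mean(-2)
    v_pool = v.reshape(B, -1, N // block, block, D).mean(-2)
    o_cmp = dense_attention(q, expand_kv(k_pool, H), expand_kv(v_pool, H))
    k_full, v_full = expand_kv(k, H), expand_kv(v, H)
    # Selection branch: dense attention restricted to the selected blocks
    o_slc = dense_attention(q, k_full, v_full, slc_mask)
    # Sliding-window branch: causal window ending at key position i + N - M
    q_pos = torch.arange(M).unsqueeze(1) + (N - M)
    k_pos = torch.arange(N).unsqueeze(0)
    swa_mask = (k_pos <= q_pos) & (k_pos > q_pos - window)
    o_swa = dense_attention(q, k_full, v_full, swa_mask)
    # Gated sum of the three branches
    return (g_cmp.unsqueeze(-1) * o_cmp + g_slc.unsqueeze(-1) * o_slc
            + g_swa.unsqueeze(-1) * o_swa)


def run_gradcheck(B=1, M=8, N=32, H=4, G=2, D=4, block=8, window=8, n_select=2):
    torch.manual_seed(0)
    q = torch.randn(B, H, M, D, dtype=torch.float64, requires_grad=True)
    k = torch.randn(B, G, N, D, dtype=torch.float64, requires_grad=True)
    v = torch.randn(B, G, N, D, dtype=torch.float64, requires_grad=True)
    gates = [torch.rand(B, H, M, dtype=torch.float64, requires_grad=True) for _ in range(3)]
    slc_mask = random_block_mask(M, N, block, n_select)  # fixed, not differentiated

    def fn(q, k, v, g_cmp, g_slc, g_swa):
        return nsa_dense_reference(q, k, v, g_cmp, g_slc, g_swa, slc_mask, block, window)

    # Checks Q, K, V, g_cmp, g_slc, g_swa with the tolerances from item 1
    return torch.autograd.gradcheck(fn, (q, k, v, *gates), eps=1e-6, atol=1e-4, rtol=1e-3)


if __name__ == "__main__":
    print(run_gradcheck())
```

To check the real kernels, `fn` would instead wrap the FP64 kernel variant or CPU fallback, whose name and signature are not specified here. Two existing `gradcheck` arguments are worth considering. `fast_mode=True` avoids building the full Jacobian, which matters at the configuration sizes. `nondet_tol` sets how far repeated backward passes on identical inputs may differ, which is relevant to the atomic dK/dV accumulation in item 4. Its default of 0.0 requires identical results.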

Configurations

  • Small sizes for gradcheck: B=1, M=64, N=128, H=4, G=2, D=32
  • Realistic sizes for reference comparison: B=1, M=4096, N=4096, H=128, G=8, D=128
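Some quick arithmetic on these two configurations helps when planning the tests. The sketch below assumes Q is [B, H, M, D], K and V are [B, G, N, D], and each of the three gates is [B, H, M]; the issue does not fix a layout. Under those assumptions, a full slow-mode FP64 numerical Jacobian at the gradcheck sizes has 25,344 × 8,192 ≈ 208 million entries, about 1.5 GiB on its own. At the realistic sizes, a single materialized FP32 [B, H, M, N] score matrix is 8 GiB, so a dense reference there will probably need to be computed in chunks of heads or query rows.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class NSAConfig:
    B: int  # batch size
    M: int  # query length
    N: int  # key/value length
    H: int  # query heads
    G: int  # K/V heads (GQA groups), assumed to divide H
    D: int  # head dimension

    def input_numel(self):
        # Q: [B, H, M, D]; K, V: [B, G, N, D]; three gates: [B, H, M] (assumed layout)
        q = self.B * self.H * self.M * self.D
        kv = 2 * self.B * self.G * self.N * self.D
        gates = 3 * self.B * self.H * self.M
        return q + kv + gates

    def output_numel(self):
        # Attention output: [B, H, M, D]
        return self.B * self.H * self.M * self.D

    def dense_scores_bytes(self, bytes_per_elem):
        # One materialized [B, H, M, N] score matrix for a dense-equivalent branch
        return self.B * self.H * self.M * self.N * bytes_per_elem


GRADCHECK = NSAConfig(B=1, M=64, N=128, H=4, G=2, D=32)
REALISTIC = NSAConfig(B=1, M=4096, N=4096, H=128, G=8, D=128)

if __name__ == "__main__":
    # Slow-mode gradcheck builds an input_numel x output_numel numerical Jacobian
    jac = GRADCHECK.input_numel() * GRADCHECK.output_numel()
    print(f"gradcheck Jacobian: {jac:,} entries, {jac * 8 / 2**30:.2f} GiB in FP64")
    gib = REALISTIC.dense_scores_bytes(4) / 2**30
    print(f"dense scores at realistic size (FP32): {gib:.0f} GiB")
```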

Depends on

Labels: enhancement (New feature or request), nsa (DeepSeek Native Sparse Attention)