Skip to content

Commit 7aa6960

Browse files
committed
Add unit test for 2:4 sparse softmax
Signed-off-by: Kai Xu <kaix@nvidia.com>
1 parent a9430be commit 7aa6960

5 files changed

Lines changed: 457 additions & 295 deletions

File tree

CHANGELOG.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,13 @@
11
NVIDIA Model Optimizer Changelog
22
================================
33

4+
0.44 (2026-04-xx)
5+
^^^^^^^^^^^^^^^^^
6+
7+
**New Features**
8+
9+
- Add N:M sparse softmax support to the Triton flash attention kernel (``modelopt.torch.kernels.triton_fa``). For every M consecutive key positions, the top-N attention scores are kept and the rest are set to -inf before softmax.
10+
411
0.43 (2026-03-xx)
512
^^^^^^^^^^^^^^^^^
613

modelopt/torch/kernels/triton_fa.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,10 @@ def _apply_sparse_nm_to_qk_tile(
111111
For every ``SPARSITY_M`` consecutive elements along the N (key) dimension,
112112
keeps the top ``SPARSITY_N`` values and sets the rest to ``-inf``.
113113
``BLOCK_N`` must be divisible by ``SPARSITY_M``.
114+
115+
For M=4, exactly N values are retained (ties broken by position).
116+
For M=8, a threshold-based approach (``tl.sort``) may retain more
117+
than N values when ties straddle the threshold boundary.
114118
"""
115119
tl.static_assert(SPARSITY_M == 4 or SPARSITY_M == 8, "SPARSITY_M must be 4 or 8") # noqa: PLR1714
116120
MASK_VAL: tl.constexpr = float("-inf")
@@ -141,7 +145,7 @@ def _apply_sparse_nm_to_qk_tile(
141145
sorted_vals = tl.sort(reshaped, dim=2)
142146
KTH_IDX: tl.constexpr = SPARSITY_M - SPARSITY_N # index of N-th largest in ascending order
143147

144-
# Extract the threshold value (one extraction vs eight before)
148+
# Extract the threshold value at KTH_IDX via masked sum
145149
# Use 0.0 as fill (not -inf) so sum equals just the KTH element
146150
cols = tl.arange(0, 8)[None, None, :]
147151
threshold = tl.sum(tl.where(cols == KTH_IDX, sorted_vals, 0.0), axis=2)
@@ -272,7 +276,7 @@ def _attn_fwd(
272276
scores = tl.dot(q, k) * qk_scale
273277
scores = _apply_mask(scores, q_pos, kv_pos, seq_len_q, seq_len_kv, kv_start, IS_CAUSAL)
274278

275-
# --- Optional 2:4 structured sparsity ---
279+
# --- Optional N:M structured sparsity ---
276280
if SPARSITY_N > 0:
277281
# Check if this KV tile should be kept dense
278282
is_sink = kv_start < NUM_SINK_TOKENS
@@ -473,7 +477,7 @@ def _attn_bwd_dq(
473477
scores = tl.dot(q, kT) * qk_scale
474478
scores = _apply_mask(scores, q_pos, kv_pos, seq_len_q, seq_len_kv, kv_start, IS_CAUSAL)
475479

476-
# Re-apply 2:4 sparsity to match forward pass
480+
# Re-apply N:M sparsity to match forward pass
477481
if SPARSITY_N > 0:
478482
is_sink = kv_start < NUM_SINK_TOKENS
479483
causal_offset = seq_len_kv - seq_len_q
@@ -613,7 +617,7 @@ def _attn_bwd_dkdv(
613617
scores = tl.dot(q_tile, kT) * qk_scale
614618
scores = _apply_mask(scores, q_pos, kv_pos, seq_len_q, seq_len_kv, kv_start, IS_CAUSAL)
615619

616-
# Re-apply 2:4 sparsity to match forward pass
620+
# Re-apply N:M sparsity to match forward pass
617621
if SPARSITY_N > 0:
618622
is_sink = kv_start < NUM_SINK_TOKENS
619623
causal_offset = seq_len_kv - seq_len_q
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Shared fixtures and helpers for Triton flash attention tests."""
17+
18+
import pytest
19+
import torch
20+
import torch.nn.functional as F
21+
22+
23+
def make_qkv(total, num_heads, num_kv_heads, head_dim, device="cuda", dtype=torch.float16):
    """Build random packed query, key and value tensors.

    Args:
        total: Size of the leading (packed token) dimension of every tensor.
        num_heads: Size of the head dimension of the query tensor.
        num_kv_heads: Size of the head dimension of the key and value tensors.
        head_dim: Size of the trailing dimension of every tensor.
        device: Device the tensors are allocated on.
        dtype: Element type of the tensors.

    Returns:
        Tuple ``(q, k, v)`` where ``q`` has shape ``(total, num_heads, head_dim)``
        and ``k`` / ``v`` have shape ``(total, num_kv_heads, head_dim)``, all
        filled by ``torch.randn``.
    """

    def _sample(heads):
        return torch.randn(total, heads, head_dim, device=device, dtype=dtype)

    # Tuple elements are evaluated left to right, so the random stream is
    # consumed in q, k, v order.
    return _sample(num_heads), _sample(num_kv_heads), _sample(num_kv_heads)
29+
30+
31+
def make_varlen_meta(seq_lens, device="cuda"):
    """Derive variable-length batch metadata from per-sequence lengths.

    Args:
        seq_lens: List with the length of each sequence in the packed batch.
        device: Device the returned tensors are allocated on.

    Returns:
        Tuple ``(b_start_loc, b_seq_len)`` of ``int32`` tensors with one entry
        per sequence. ``b_start_loc[i]`` is the sum of the lengths of all
        sequences before ``i`` (so the first entry is 0); ``b_seq_len`` holds
        the lengths themselves.
    """
    lengths = torch.tensor(seq_lens, device=device, dtype=torch.int32)
    # Exclusive prefix sum: the inclusive running total minus each sequence's
    # own length. cumsum yields int64 for integer input, so cast back to int32.
    running_total = torch.cumsum(lengths, dim=0)
    starts = (running_total - lengths).to(torch.int32)
    return starts, lengths
37+
38+
39+
def sdpa_reference(q, k, v, b_start_loc, b_seq_len, is_causal=True):
    """Compute reference attention for a packed batch, one sequence at a time.

    Each sequence is sliced out of the packed tensors, run through
    ``F.scaled_dot_product_attention`` and the results are concatenated back
    into packed layout. Grouped-query attention is supported by repeating
    every key/value head ``num_q // num_kv`` times.

    Args:
        q: Packed queries, indexed as ``[token, q_head, dim]``.
        k: Packed keys, indexed as ``[token, kv_head, dim]``.
        v: Packed values, indexed as ``[token, kv_head, dim]``.
        b_start_loc: 1-D tensor with the first token index of each sequence.
        b_seq_len: 1-D tensor with the length of each sequence.
        is_causal: Forwarded to ``F.scaled_dot_product_attention``.

    Returns:
        Tensor of shape ``[total_tokens, num_heads, dim]``. An empty batch
        yields an empty tensor with zero rows instead of failing in
        ``torch.cat``.

    Raises:
        ValueError: If the number of query heads is not a multiple of the
            number of key/value heads.
    """
    num_q, num_kv = q.shape[1], k.shape[1]
    if num_kv == 0 or num_q % num_kv != 0:
        raise ValueError(
            f"number of query heads ({num_q}) must be a multiple of "
            f"the number of key/value heads ({num_kv})"
        )
    group = num_q // num_kv

    # Pull the metadata to the host once instead of calling .item() twice per
    # sequence; on a GPU every .item() is a separate synchronisation.
    starts = b_start_loc.tolist()
    lengths = b_seq_len.tolist()
    if not lengths:
        return q.new_empty((0, num_q, v.shape[-1]))

    def _to_bhsd(packed, start, length):
        # [length, heads, dim] -> [1, heads, length, dim], the layout SDPA expects.
        return packed[start : start + length].unsqueeze(0).permute(0, 2, 1, 3)

    parts = []
    for b, n in enumerate(lengths):
        s, n = int(starts[b]), int(n)
        qb = _to_bhsd(q, s, n)
        kb = _to_bhsd(k, s, n)
        vb = _to_bhsd(v, s, n)
        if group != 1:
            # GQA: expand key/value heads so they line up with the query heads.
            kb = kb.repeat_interleave(group, dim=1)
            vb = vb.repeat_interleave(group, dim=1)
        ob = F.scaled_dot_product_attention(qb, kb, vb, is_causal=is_causal)
        # Back to packed layout: [1, heads, length, dim] -> [length, heads, dim].
        parts.append(ob.permute(0, 2, 1, 3).squeeze(0))
    return torch.cat(parts, dim=0)
56+
57+
58+
@pytest.fixture(scope="module")
def tiny_llama_dir(tmp_path_factory):
    """Module-scoped fixture that creates a tiny Llama checkpoint directory.

    The model is configured with 2 hidden layers, hidden size 64, 4 attention
    heads and 2 key/value heads (so 64 / 4 = 16 dims per head), and a
    tokenizer is requested. The fixture returns whatever
    ``create_tiny_llama_dir`` returns for a fresh ``tiny_llama`` temp dir.
    """
    # Imported lazily so that merely collecting this module does not require
    # the test utilities.
    from _test_utils.torch.transformers_models import create_tiny_llama_dir

    model_overrides = {
        "num_hidden_layers": 2,
        "hidden_size": 64,
        "num_attention_heads": 4,
        "num_key_value_heads": 2,
        "intermediate_size": 64,
        "max_position_embeddings": 64,
    }
    target_dir = tmp_path_factory.mktemp("tiny_llama")
    return create_tiny_llama_dir(target_dir, with_tokenizer=True, **model_overrides)

0 commit comments

Comments
 (0)