Skip to content

Commit d3a0dd8

Browse files
author
niushengxiao
committed
feat: add hadamard_transform kernel
1 parent 2ac3630 commit d3a0dd8

4 files changed

Lines changed: 157 additions & 6 deletions

File tree

lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ def _get_indices(
251251
@staticmethod
252252
def _rotate_activation(x: torch.Tensor) -> torch.Tensor:
253253
assert x.dtype == torch.bfloat16
254-
from sgl_kernel import hadamard_transform
254+
from lightllm.models.deepseek3_2.triton_kernel.hadamard_transform import hadamard_transform
255255

256256
hidden_size = x.size(-1)
257257
assert (hidden_size & (hidden_size - 1)) == 0, "Hidden size must be a power of 2 for Hadamard transform."
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import torch
2+
import triton
3+
import triton.language as tl
4+
5+
6+
@triton.jit
def _butterfly_stage(x, GROUPS: tl.constexpr, STEP: tl.constexpr, BLOCK_N: tl.constexpr):
    # One butterfly stage over a flat vector of BLOCK_N values.
    # The vector is treated as GROUPS consecutive chunks of 2 * STEP values;
    # within each chunk the first half `a` and second half `b` are replaced
    # by `a + b` and `a - b` respectively.
    #
    # (GROUPS, 2, STEP) -> (GROUPS, STEP, 2): move the "which half" axis to
    # the end, because tl.split separates along the last axis.
    halves_last = tl.permute(tl.reshape(x, (GROUPS, 2, STEP)), (0, 2, 1))
    first_half, second_half = tl.split(halves_last)

    # tl.join stacks sums and differences along a new trailing axis; the
    # permute restores the (GROUPS, 2, STEP) layout before flattening.
    combined = tl.join(first_half + second_half, first_half - second_half)
    combined = tl.permute(combined, (0, 2, 1))
    return tl.reshape(combined, (BLOCK_N,))
14+
15+
16+
@triton.jit
def _hadamard_transform_kernel(
    X,
    Y,
    scale: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Each program transforms one row of BLOCK_N consecutive elements:
    # load the row from X, run seven butterfly stages in float32, multiply
    # by `scale`, and store the row at the same offsets in Y.
    #
    # The stage constants below satisfy GROUPS * 2 * STEP == 128, so the
    # kernel is only valid when launched with BLOCK_N == 128.

    # Widen the row index before multiplying: program_id is 32-bit, and
    # row * BLOCK_N would wrap for inputs with 2**31 or more elements.
    row = tl.program_id(0).to(tl.int64)
    offsets = row * BLOCK_N + tl.arange(0, BLOCK_N)
    x = tl.load(X + offsets).to(tl.float32)

    x = _butterfly_stage(x, 64, 1, BLOCK_N)
    x = _butterfly_stage(x, 32, 2, BLOCK_N)
    x = _butterfly_stage(x, 16, 4, BLOCK_N)
    x = _butterfly_stage(x, 8, 8, BLOCK_N)
    x = _butterfly_stage(x, 4, 16, BLOCK_N)
    x = _butterfly_stage(x, 2, 32, BLOCK_N)
    x = _butterfly_stage(x, 1, 64, BLOCK_N)

    tl.store(Y + offsets, x * scale)
36+
37+
38+
def hadamard_transform(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
    """Run the Triton Hadamard kernel over the last dimension of ``x``.

    Args:
        x: CUDA tensor of dtype bfloat16 whose last dimension is 128.
        scale: Factor passed to the kernel, which multiplies it into every
            output element.

    Returns:
        A newly allocated tensor with the same shape and dtype as ``x``.

    Raises:
        AssertionError: If ``x`` is not on CUDA, is not bfloat16, or its
            last dimension is not 128.
    """
    assert x.is_cuda, "hadamard_transform only supports CUDA tensors"
    assert x.dtype == torch.bfloat16, "hadamard_transform expects bfloat16 input"

    input_shape = x.shape
    width = x.size(-1)
    assert width == 128, "DeepSeek-V3.2 Hadamard transform expects hidden size 128"

    # The kernel addresses rows as row * BLOCK_N, so it needs dense storage.
    src = x.contiguous()
    dst = torch.empty_like(src)

    # One program per row of `width` elements.
    grid = (src.numel() // width,)
    _hadamard_transform_kernel[grid](
        src,
        dst,
        scale,
        BLOCK_N=width,
        num_warps=4,
    )

    return dst.view(input_shape)

lightllm/utils/backend_validator.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -196,12 +196,15 @@ def _validate_flashmla_sparse():
196196
except Exception as e:
197197
return False, f"sgl_kernel.flash_mla import failed: {type(e).__name__}: {e}"
198198

199-
batch, heads, seq, dim = 1, 64, 128, 512 + 64
199+
batch, heads, seq = 1, 64, 128
200+
kv_lora_rank = 512
201+
qk_rope_head_dim = 64
202+
qk_dim = kv_lora_rank + qk_rope_head_dim
200203
dtype = torch.bfloat16
201204
device = "cuda"
202205

203-
q = torch.randn(batch * seq, heads, dim, dtype=dtype, device=device)
204-
kv = torch.zeros(batch * seq, 1, dim, dtype=dtype, device=device)
206+
q = torch.randn(batch * seq, heads, qk_dim, dtype=dtype, device=device)
207+
kv = torch.zeros(batch * seq, 1, qk_dim, dtype=dtype, device=device)
205208

206209
index_topk = 128
207210
topk_indices = torch.zeros(batch * seq, index_topk, dtype=torch.int32, device=device)
@@ -210,8 +213,7 @@ def _validate_flashmla_sparse():
210213

211214
topk_indices = topk_indices.view(batch * seq, 1, index_topk)
212215

213-
softmax_scale = 1.0 / (dim ** 0.5)
214-
kv_lora_rank = dim
216+
softmax_scale = 1.0 / (qk_dim ** 0.5)
215217

216218
try:
217219
mla_out, _, _ = flash_mla_sparse_fwd(
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
import pytest
2+
import torch
3+
4+
from lightllm.models.deepseek3_2.triton_kernel.hadamard_transform import hadamard_transform
5+
6+
7+
TP = 8
8+
INDEX_N_HEADS = 64
9+
INDEX_HEAD_DIM = 128
10+
TP_INDEX_N_HEADS = INDEX_N_HEADS // TP
11+
SCALE = INDEX_HEAD_DIM ** -0.5
12+
13+
14+
def _get_sgl_kernel_hadamard_transform():
    """Return ``sgl_kernel.hadamard_transform`` for use as the reference.

    Skips the calling test (via ``pytest.skip``) when CUDA is unavailable or
    when the function cannot be imported from ``sgl_kernel``.
    """
    cuda_ready = torch.cuda.is_available()
    if not cuda_ready:
        pytest.skip("CUDA is required for hadamard_transform comparison")

    try:
        from sgl_kernel import hadamard_transform as reference_impl
    except ImportError:
        pytest.skip("sgl_kernel.hadamard_transform is not available")
    else:
        return reference_impl
22+
23+
24+
def _bench(fn, x, warmup=30, iters=300):
    """Time ``fn(x, scale=SCALE)`` with CUDA events.

    Args:
        fn: Callable invoked as ``fn(x, scale=SCALE)``.
        x: Input forwarded to ``fn`` on every call.
        warmup: Number of untimed calls made before measuring.
        iters: Number of timed calls.

    Returns:
        A tuple ``(mean_ms, last_output)``: the elapsed event time divided by
        ``iters``, and the value returned by the final timed call.
    """
    # Untimed calls first, then drain the queue so they do not leak into
    # the measured interval.
    for _ in range(warmup):
        fn(x, scale=SCALE)
    torch.cuda.synchronize()

    begin = torch.cuda.Event(enable_timing=True)
    finish = torch.cuda.Event(enable_timing=True)

    begin.record()
    for _ in range(iters):
        y = fn(x, scale=SCALE)
    finish.record()

    # elapsed_time needs both events to have completed.
    torch.cuda.synchronize()
    total_ms = begin.elapsed_time(finish)
    return total_ms / iters, y
37+
38+
39+
@pytest.mark.parametrize("tokens", [1, 16, 128, 512, 1024, 2048, 4096, 8192, 16384])
def test_hadamard_transform_matches_sgl_kernel_deepseek_v32_shapes(tokens):
    """Check the Triton kernel is bit-identical to sgl_kernel for q and k shapes."""
    sgl_hadamard_transform = _get_sgl_kernel_hadamard_transform()

    # q is [tokens, heads-per-rank, head_dim]; k is [tokens, head_dim].
    inputs = (
        torch.randn(tokens, TP_INDEX_N_HEADS, INDEX_HEAD_DIM, dtype=torch.bfloat16, device="cuda"),
        torch.randn(tokens, INDEX_HEAD_DIM, dtype=torch.bfloat16, device="cuda"),
    )

    # For each input, run the reference first and then the Triton version.
    outputs = [
        (sgl_hadamard_transform(tensor, scale=SCALE), hadamard_transform(tensor, scale=SCALE))
        for tensor in inputs
    ]
    torch.cuda.synchronize()

    for expected, actual in outputs:
        assert torch.equal(actual, expected)
54+
55+
56+
def test_hadamard_transform_perf_report_deepseek_v32_shapes():
    """Print a timing table of sgl_kernel vs. Triton and assert exact equality.

    For each token count, the max absolute difference between the two
    implementations is computed for q and k, both are timed with ``_bench``,
    one table row is printed, and the differences are asserted to be zero.
    """
    sgl_hadamard_transform = _get_sgl_kernel_hadamard_transform()

    header = (
        "\nDeepSeek-V3.2 per-rank shapes with tp=8:"
        "\n q: [tokens, 8, 128]"
        "\n k: [tokens, 128]"
        "\n\ntokens | q_diff | k_diff | sgl_q ms | tri_q ms | sgl_k ms | tri_k ms | tri(q+k) ms | slowdown q+k"
    )
    print(header)

    token_counts = (1, 16, 128, 512, 1024, 2048, 4096, 8192, 16384)
    for tokens in token_counts:
        q = torch.randn(tokens, TP_INDEX_N_HEADS, INDEX_HEAD_DIM, dtype=torch.bfloat16, device="cuda")
        k = torch.randn(tokens, INDEX_HEAD_DIM, dtype=torch.bfloat16, device="cuda")

        # Correctness pass: reference then Triton, for q and then k.
        q_expected = sgl_hadamard_transform(q, scale=SCALE)
        q_actual = hadamard_transform(q, scale=SCALE)
        k_expected = sgl_hadamard_transform(k, scale=SCALE)
        k_actual = hadamard_transform(k, scale=SCALE)
        torch.cuda.synchronize()

        # Compare in float32 and keep the largest absolute gap.
        q_diff = (q_expected.float() - q_actual.float()).abs().max().item()
        k_diff = (k_expected.float() - k_actual.float()).abs().max().item()

        # Timing pass; only the mean milliseconds are used here.
        sgl_q_ms = _bench(sgl_hadamard_transform, q)[0]
        tri_q_ms = _bench(hadamard_transform, q)[0]
        sgl_k_ms = _bench(sgl_hadamard_transform, k)[0]
        tri_k_ms = _bench(hadamard_transform, k)[0]
        sgl_sum_ms = sgl_q_ms + sgl_k_ms
        tri_sum_ms = tri_q_ms + tri_k_ms

        row = (
            f"{tokens:6d} | {q_diff:6.1g} | {k_diff:6.1g} | "
            f"{sgl_q_ms:8.4f} | {tri_q_ms:8.4f} | {sgl_k_ms:8.4f} | {tri_k_ms:8.4f} | "
            f"{tri_sum_ms:11.4f} | {tri_sum_ms / sgl_sum_ms:10.2f}x"
        )
        print(row)

        assert q_diff == 0
        assert k_diff == 0

0 commit comments

Comments
 (0)