Commit 5b76b29

Add and update kernel for Wan (#415)
* update_scale_shift
* add rmsnorm split
* add test
* fix
* fix variance
* fix
* cleancode
* cleancode
1 parent 5c2013b commit 5b76b29

4 files changed

Lines changed: 314 additions & 55 deletions

Lines changed: 161 additions & 0 deletions
@@ -0,0 +1,161 @@
+import torch
+import triton
+import triton.language as tl
+from sgl_kernel_npu.utils.triton_utils import get_device_properties
+
+
+@triton.autotune(
+    configs=[
+        triton.Config(
+            {"block_l": 256, "block_c": 256},
+        ),
+        triton.Config(
+            {"block_l": 128, "block_c": 256},
+        ),
+        triton.Config(
+            {"block_l": 128, "block_c": 128},
+        ),
+        triton.Config(
+            {"block_l": 64, "block_c": 128},
+        ),
+        triton.Config(
+            {"block_l": 64, "block_c": 64},
+        ),
+        triton.Config(
+            {"block_l": 64, "block_c": 32},
+        ),
+        triton.Config(
+            {"block_l": 32, "block_c": 32},
+        ),
+    ],
+    key=["num_tokens", "hidden_size"],
+)
+@triton.jit
+def fused_rsqrt_mul_kernel(
+    x_ptr,
+    variance_ptr,
+    weight_ptr,
+    eps,
+    output_ptr,
+    num_tokens,
+    hidden_size,
+    block_l: tl.constexpr,
+    block_c: tl.constexpr,
+    kernel_num: tl.constexpr,
+):
+    pid = tl.program_id(0)
+    row_tasks = tl.cdiv(num_tokens, block_l)
+    col_tasks = tl.cdiv(hidden_size, block_c)
+    total_tasks = row_tasks * col_tasks
+
+    for task_id in range(pid, total_tasks, kernel_num):
+        row_pid = task_id // col_tasks
+        col_pid = task_id % col_tasks
+
+        token_offsets = row_pid * block_l + tl.arange(0, block_l)
+        dim_offsets = col_pid * block_c + tl.arange(0, block_c)
+        offset = token_offsets[:, None] * hidden_size + dim_offsets[None, :]
+
+        mask_token = token_offsets < num_tokens
+        mask_dim = dim_offsets < hidden_size
+        mask = mask_token[:, None] & mask_dim[None, :]
+
+        x = tl.load(x_ptr + offset, mask=mask, other=0.0)
+        variance = tl.load(
+            variance_ptr + token_offsets[:, None], mask=mask_token[:, None], other=0.0
+        )
+        weight = tl.load(weight_ptr + dim_offsets, mask=mask_dim, other=0.0)
+
+        rsqrt = tl.math.rsqrt(variance + eps)
+        output = x * rsqrt * weight
+        tl.store(output_ptr + offset, output, mask=mask)
+
+
+def fused_rsqrt_mul(x, variance, weight, eps=1e-6):
+    _, kernel_num = get_device_properties()
+    B, L, C = x.shape[0], x.shape[1], x.shape[2]
+    grid = (kernel_num,)
+
+    output = torch.empty_like(x)
+
+    fused_rsqrt_mul_kernel[grid](
+        x,
+        variance,
+        weight,
+        eps,
+        output,
+        B * L,
+        C,
+        kernel_num=kernel_num,
+    )
+
+    return output
+
+
+@triton.autotune(
+    configs=[
+        triton.Config(
+            {"block_l": 96},
+        ),
+        triton.Config(
+            {"block_l": 64},
+        ),
+        triton.Config(
+            {"block_l": 32},
+        ),
+        triton.Config(
+            {"block_l": 16},
+        ),
+        triton.Config(
+            {"block_l": 8},
+        ),
+        triton.Config(
+            {"block_l": 4},
+        ),
+        triton.Config(
+            {"block_l": 2},
+        ),
+        triton.Config(
+            {"block_l": 1},
+        ),
+    ],
+    key=["num_tokens"],
+)
+@triton.jit
+def fused_variance_kernel(
+    x_ptr,
+    output_ptr,
+    num_tokens,
+    hidden_size: tl.constexpr,
+    block_l: tl.constexpr,
+    kernel_num: tl.constexpr,
+):
+    pid = tl.program_id(0)
+    total_tasks = tl.cdiv(num_tokens, block_l)
+
+    for task_id in range(pid, total_tasks, kernel_num):
+        token_offsets = task_id * block_l + tl.arange(0, block_l)
+        dim_offsets = tl.arange(0, hidden_size)
+
+        offset = token_offsets[:, None] * hidden_size + dim_offsets[None, :]
+        mask_out = token_offsets[:, None] < num_tokens
+        mask = mask_out & (dim_offsets[None, :] < hidden_size)
+
+        x = tl.load(x_ptr + offset, mask=mask, other=0.0)
+
+        x_sq = x * x
+        sum_sq = tl.sum(x_sq, axis=1)
+        variance = sum_sq / hidden_size
+
+        tl.store(output_ptr + token_offsets[:, None], variance[:, None], mask=mask_out)
+
+
+def fused_variance(x: torch.Tensor):
+    _, kernel_num = get_device_properties()
+    B, L, C = x.shape[0], x.shape[1], x.shape[2]
+    grid = (kernel_num,)
+
+    output = torch.empty((B, L, 1), device=x.device, dtype=x.dtype)
+
+    fused_variance_kernel[grid](x, output, B * L, C, kernel_num=kernel_num)
+    return output
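
Taken together, the two kernels in this new file implement a split RMSNorm: fused_variance produces the per-token mean of squares over the hidden dimension, and fused_rsqrt_mul applies x * rsqrt(variance + eps) * weight. A minimal sketch of how a caller might sanity-check the pair against a plain PyTorch reference; the import path is a placeholder (the file path is not shown on this page), and the "npu" device and shapes are illustrative assumptions:

    import torch
    # Placeholder import path; the actual module location is not shown in this diff.
    from rmsnorm_split import fused_rsqrt_mul, fused_variance

    B, L, C = 2, 77, 1536  # illustrative (batch, sequence, hidden) shapes
    x = torch.randn(B, L, C, device="npu", dtype=torch.float32)
    weight = torch.randn(C, device="npu", dtype=torch.float32)
    eps = 1e-6

    variance = fused_variance(x)                     # (B, L, 1): mean of x**2 per token
    out = fused_rsqrt_mul(x, variance, weight, eps)  # (B, L, C): normalized and weighted

    # Plain PyTorch RMSNorm as the reference computation.
    ref = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight
    torch.testing.assert_close(out, ref, rtol=1e-4, atol=1e-4)

Note the launch pattern: both wrappers launch a persistent grid of kernel_num programs (the core count reported by get_device_properties) and stride each program over the block-tiled task space, leaving block sizes to triton.autotune rather than to the caller.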
Lines changed: 91 additions & 53 deletions
@@ -1,8 +1,20 @@
 import torch
 import triton
 import triton.language as tl
-
-
+from sgl_kernel_npu.utils.triton_utils import get_device_properties
+
+
+@triton.autotune(
+    configs=[
+        triton.Config(
+            {"block_l": 128, "block_c": 128},
+        ),
+        triton.Config(
+            {"block_l": 112, "block_c": 128},
+        ),
+    ],
+    key=["num_tokens", "hidden_size"],
+)
 @triton.jit
 def fused_scale_shift_kernel(
     x_ptr,
@@ -15,39 +27,62 @@ def fused_scale_shift_kernel(
     shift_numel: tl.constexpr,
     block_l: tl.constexpr,
     block_c: tl.constexpr,
+    kernel_num: tl.constexpr,
 ):
-    row_pid = tl.program_id(0)
-    col_pid = tl.program_id(1)
-
-    token_offsets = row_pid * block_l + tl.arange(0, block_l)
-    dim_offsets = col_pid * block_c + tl.arange(0, block_c)
+    pid = tl.program_id(0)
+    row_tasks = tl.cdiv(num_tokens, block_l)
+    col_tasks = tl.cdiv(hidden_size, block_c)
+    total_tasks = row_tasks * col_tasks
 
-    mask = (token_offsets[:, None] < num_tokens) & (dim_offsets[None, :] < hidden_size)
-    offset = token_offsets[:, None] * hidden_size + dim_offsets[None, :]
+    for task_id in range(pid, total_tasks, kernel_num):
+        row_pid = task_id // col_tasks
+        col_pid = task_id % col_tasks
 
-    x = tl.load(x_ptr + offset, mask=mask, other=0.0)
+        token_offsets = row_pid * block_l + tl.arange(0, block_l)
+        dim_offsets = col_pid * block_c + tl.arange(0, block_c)
 
-    if scale_numel == 1:
-        scale = tl.load(scale_ptr)
-    else:
-        scale_offsets = dim_offsets[None, :]
-        scale_mask = dim_offsets[None, :] < hidden_size
-        scale = tl.load(scale_ptr + scale_offsets, mask=scale_mask, other=0.0)
-
-    if shift_numel == 1:
-        shift = tl.load(shift_ptr)
-    else:
-        shift_offsets = dim_offsets[None, :]
-        shift_mask = dim_offsets[None, :] < hidden_size
-        shift = tl.load(shift_ptr + shift_offsets, mask=shift_mask, other=0.0).to(
-            tl.float32
+        mask = (token_offsets[:, None] < num_tokens) & (
+            dim_offsets[None, :] < hidden_size
         )
-
-    output = x * (1.0 + scale) + shift
-
-    tl.store(output_ptr + offset, output.to(output_ptr.dtype.element_ty), mask=mask)
-
-
+        offset = token_offsets[:, None] * hidden_size + dim_offsets[None, :]
+
+        x = tl.load(x_ptr + offset, mask=mask, other=0.0)
+
+        if scale_numel == 1:
+            scale = tl.load(scale_ptr)
+        else:
+            scale_offsets = dim_offsets[None, :]
+            scale_mask = dim_offsets[None, :] < hidden_size
+            scale = tl.load(scale_ptr + scale_offsets, mask=scale_mask, other=0.0)
+
+        if shift_numel == 1:
+            shift = tl.load(shift_ptr)
+        else:
+            shift_offsets = dim_offsets[None, :]
+            shift_mask = dim_offsets[None, :] < hidden_size
+            shift = tl.load(shift_ptr + shift_offsets, mask=shift_mask, other=0.0).to(
+                tl.float32
+            )
+
+        output = x * (1.0 + scale) + shift
+
+        tl.store(output_ptr + offset, output.to(output_ptr.dtype.element_ty), mask=mask)
+
+
+@triton.autotune(
+    configs=[
+        triton.Config(
+            {"block_l": 96, "block_c": 128},
+        ),
+        triton.Config(
+            {"block_l": 80, "block_c": 128},
+        ),
+        triton.Config(
+            {"block_l": 64, "block_c": 128},
+        ),
+    ],
+    key=["num_tokens", "hidden_size"],
+)
 @triton.jit
 def fused_scale_shift_kernel_2(
     x_ptr,
@@ -59,36 +94,43 @@ def fused_scale_shift_kernel_2(
     scale_constant: tl.constexpr,
     block_l: tl.constexpr,
     block_c: tl.constexpr,
+    kernel_num: tl.constexpr,
 ):
-    row_pid = tl.program_id(0)
-    col_pid = tl.program_id(1)
+    pid = tl.program_id(0)
+    row_tasks = tl.cdiv(num_tokens, block_l)
+    col_tasks = tl.cdiv(hidden_size, block_c)
+    total_tasks = row_tasks * col_tasks
 
-    token_offsets = row_pid * block_l + tl.arange(0, block_l)
-    dim_offsets = col_pid * block_c + tl.arange(0, block_c)
+    for task_id in range(pid, total_tasks, kernel_num):
+        row_pid = task_id // col_tasks
+        col_pid = task_id % col_tasks
 
-    mask = (token_offsets[:, None] < num_tokens) & (dim_offsets[None, :] < hidden_size)
-    offset = token_offsets[:, None] * hidden_size + dim_offsets[None, :]
+        token_offsets = row_pid * block_l + tl.arange(0, block_l)
+        dim_offsets = col_pid * block_c + tl.arange(0, block_c)
 
-    x = tl.load(x_ptr + offset, mask=mask, other=0.0)
+        mask = (token_offsets[:, None] < num_tokens) & (
+            dim_offsets[None, :] < hidden_size
+        )
+        offset = token_offsets[:, None] * hidden_size + dim_offsets[None, :]
 
-    scale_offsets = dim_offsets[None, :]
-    scale_mask = dim_offsets[None, :] < hidden_size
-    scale = tl.load(scale_ptr + scale_offsets, mask=scale_mask, other=0.0)
+        x = tl.load(x_ptr + offset, mask=mask, other=0.0)
+
+        scale_offsets = dim_offsets[None, :]
+        scale_mask = dim_offsets[None, :] < hidden_size
+        scale = tl.load(scale_ptr + scale_offsets, mask=scale_mask, other=0.0)
 
-    shift = tl.load(shift_ptr + offset, mask=mask, other=0.0).to(tl.float32)
+        shift = tl.load(shift_ptr + offset, mask=mask, other=0.0).to(tl.float32)
 
-    output = x * (scale_constant + scale) + shift
+        output = x * (scale_constant + scale) + shift
 
-    tl.store(output_ptr + offset, output.to(output_ptr.dtype.element_ty), mask=mask)
+        tl.store(output_ptr + offset, output.to(output_ptr.dtype.element_ty), mask=mask)
 
 
 def fused_scale_shift(
     x: torch.Tensor,
     scale: torch.Tensor,
     shift: torch.Tensor,
     scale_constant: float = 1.0,
-    block_l: int = 128,
-    block_c: int = 128,
 ):
     orig_shape = x.shape
     num_tokens = orig_shape[0] * orig_shape[1]
@@ -110,10 +152,8 @@ def fused_scale_shift(
 
     output = torch.empty_like(x)
 
-    grid = (
-        triton.cdiv(num_tokens, block_l),
-        triton.cdiv(hidden_size, block_c),
-    )
+    kernel_num = get_device_properties()[1]
+    grid = (kernel_num,)
 
     if shift_numel == x_numel:
         fused_scale_shift_kernel_2[grid](
@@ -124,8 +164,7 @@ def fused_scale_shift(
             num_tokens,
             hidden_size,
             scale_constant,
-            block_l=block_l,
-            block_c=block_c,
+            kernel_num=kernel_num,
         )
 
     else:
@@ -138,8 +177,7 @@ def fused_scale_shift(
             hidden_size,
             scale_numel=scale_numel,
             shift_numel=shift_numel,
-            block_l=block_l,
-            block_c=block_c,
+            kernel_num=kernel_num,
        )
 
     return output
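
The rewritten fused_scale_shift keeps its two dispatch paths but drops the caller-supplied block_l/block_c in favor of autotuned block sizes and the same persistent kernel_num grid used in the new file above. A minimal usage sketch covering both paths, as read from the kernels in this diff; the import path is a placeholder, and the "npu" device and shapes are illustrative assumptions:

    import torch
    # Placeholder import path; the actual module location is not shown in this diff.
    from scale_shift import fused_scale_shift

    B, L, C = 2, 77, 1536
    x = torch.randn(B, L, C, device="npu", dtype=torch.float32)
    scale = torch.randn(C, device="npu", dtype=torch.float32)

    # Path 1: shift.numel() != x.numel() -> fused_scale_shift_kernel,
    # which computes x * (1.0 + scale) + shift with per-channel (or scalar) operands.
    shift_vec = torch.randn(C, device="npu", dtype=torch.float32)
    out = fused_scale_shift(x, scale, shift_vec)
    torch.testing.assert_close(out, x * (1.0 + scale) + shift_vec, rtol=1e-4, atol=1e-4)

    # Path 2: shift.numel() == x.numel() -> fused_scale_shift_kernel_2,
    # which computes x * (scale_constant + scale) + shift with an elementwise shift.
    shift_full = torch.randn(B, L, C, device="npu", dtype=torch.float32)
    out2 = fused_scale_shift(x, scale, shift_full, scale_constant=1.0)
    torch.testing.assert_close(out2, x * (1.0 + scale) + shift_full, rtol=1e-4, atol=1e-4)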
