11import torch
22import triton
33import triton .language as tl
4-
5-
4+ from sgl_kernel_npu .utils .triton_utils import get_device_properties
5+
6+
@triton.autotune(
    configs=[
        triton.Config({"block_l": 128, "block_c": 128}),
        triton.Config({"block_l": 112, "block_c": 128}),
    ],
    key=["num_tokens", "hidden_size"],
)
@triton.jit
def fused_scale_shift_kernel(
    x_ptr,
    # NOTE(review): the next six parameters were elided by a diff hunk in the
    # paste; reconstructed from the kernel body and the visible call site
    # (x, scale, shift, output, num_tokens, hidden_size, scale_numel=...,
    # shift_numel=..., kernel_num=...) — confirm against the repository.
    scale_ptr,
    shift_ptr,
    output_ptr,
    num_tokens,
    hidden_size,
    scale_numel: tl.constexpr,
    shift_numel: tl.constexpr,
    block_l: tl.constexpr,
    block_c: tl.constexpr,
    kernel_num: tl.constexpr,
):
    """Compute ``output = x * (1.0 + scale) + shift`` over a flat
    (num_tokens, hidden_size) buffer.

    The (rows x cols) tile grid is flattened into ``total_tasks`` task ids and
    walked grid-stride style: program ``pid`` handles tasks
    ``pid, pid + kernel_num, pid + 2 * kernel_num, ...`` so a fixed number of
    programs (``kernel_num``) covers any problem size.

    ``scale`` and ``shift`` are each either a scalar (numel == 1, one broadcast
    load) or a per-hidden-dim vector of length ``hidden_size``.
    """
    pid = tl.program_id(0)
    row_tasks = tl.cdiv(num_tokens, block_l)
    col_tasks = tl.cdiv(hidden_size, block_c)
    total_tasks = row_tasks * col_tasks

    for task_id in range(pid, total_tasks, kernel_num):
        # Recover the 2D tile coordinates from the flattened task id.
        row_pid = task_id // col_tasks
        col_pid = task_id % col_tasks

        token_offsets = row_pid * block_l + tl.arange(0, block_l)
        dim_offsets = col_pid * block_c + tl.arange(0, block_c)

        mask = (token_offsets[:, None] < num_tokens) & (
            dim_offsets[None, :] < hidden_size
        )
        offset = token_offsets[:, None] * hidden_size + dim_offsets[None, :]

        x = tl.load(x_ptr + offset, mask=mask, other=0.0)

        if scale_numel == 1:
            # Scalar scale: single load, broadcast over the tile.
            scale = tl.load(scale_ptr)
        else:
            scale_offsets = dim_offsets[None, :]
            scale_mask = dim_offsets[None, :] < hidden_size
            scale = tl.load(scale_ptr + scale_offsets, mask=scale_mask, other=0.0)

        if shift_numel == 1:
            shift = tl.load(shift_ptr)
        else:
            shift_offsets = dim_offsets[None, :]
            shift_mask = dim_offsets[None, :] < hidden_size
            # NOTE(review): only the vector shift is upcast to float32 here;
            # the vector scale above is not — preserved as in the original,
            # but worth confirming the asymmetry is intentional.
            shift = tl.load(shift_ptr + shift_offsets, mask=shift_mask, other=0.0).to(
                tl.float32
            )

        output = x * (1.0 + scale) + shift

        tl.store(output_ptr + offset, output.to(output_ptr.dtype.element_ty), mask=mask)
70+
71+
@triton.autotune(
    configs=[
        triton.Config({"block_l": 96, "block_c": 128}),
        triton.Config({"block_l": 80, "block_c": 128}),
        triton.Config({"block_l": 64, "block_c": 128}),
    ],
    key=["num_tokens", "hidden_size"],
)
@triton.jit
def fused_scale_shift_kernel_2(
    x_ptr,
    # NOTE(review): the next five parameters were elided by a diff hunk in the
    # paste; reconstructed from the kernel body and the visible call site —
    # confirm against the repository.
    scale_ptr,
    shift_ptr,
    output_ptr,
    num_tokens,
    hidden_size,
    scale_constant: tl.constexpr,
    block_l: tl.constexpr,
    block_c: tl.constexpr,
    kernel_num: tl.constexpr,
):
    """Compute ``output = x * (scale_constant + scale) + shift`` where
    ``shift`` is a full per-element tensor (same numel as ``x``) and ``scale``
    is a per-hidden-dim vector of length ``hidden_size``.

    Same grid-stride task loop as ``fused_scale_shift_kernel``: the
    (rows x cols) tile grid is flattened and strided by ``kernel_num``.
    """
    pid = tl.program_id(0)
    row_tasks = tl.cdiv(num_tokens, block_l)
    col_tasks = tl.cdiv(hidden_size, block_c)
    total_tasks = row_tasks * col_tasks

    for task_id in range(pid, total_tasks, kernel_num):
        row_pid = task_id // col_tasks
        col_pid = task_id % col_tasks

        token_offsets = row_pid * block_l + tl.arange(0, block_l)
        dim_offsets = col_pid * block_c + tl.arange(0, block_c)

        mask = (token_offsets[:, None] < num_tokens) & (
            dim_offsets[None, :] < hidden_size
        )
        offset = token_offsets[:, None] * hidden_size + dim_offsets[None, :]

        x = tl.load(x_ptr + offset, mask=mask, other=0.0)

        # Per-hidden-dim scale, broadcast across the token axis.
        scale_offsets = dim_offsets[None, :]
        scale_mask = dim_offsets[None, :] < hidden_size
        scale = tl.load(scale_ptr + scale_offsets, mask=scale_mask, other=0.0)

        # Elementwise shift, same layout as x; upcast to float32 for the fma.
        shift = tl.load(shift_ptr + offset, mask=mask, other=0.0).to(tl.float32)

        output = x * (scale_constant + scale) + shift

        tl.store(output_ptr + offset, output.to(output_ptr.dtype.element_ty), mask=mask)
83127
84128
# Host-side wrapper dispatching to one of the two fused scale/shift Triton
# kernels. NOTE(review): this span is a pasted diff, not clean source — three
# `@@` hunks below elide parts of the body (including where x_numel,
# scale_numel, shift_numel and hidden_size are computed), so comments here
# describe only what the visible lines show.
85129 def fused_scale_shift (
86130 x : torch .Tensor ,
87131 scale : torch .Tensor ,
88132 shift : torch .Tensor ,
89133 scale_constant : float = 1.0 ,
# The block_l/block_c parameters are removed by this diff: tile sizes are now
# chosen by @triton.autotune on the kernels instead of being caller-supplied.
90- block_l : int = 128 ,
91- block_c : int = 128 ,
92134 ):
93135 orig_shape = x .shape
# num_tokens flattens the leading two dims — assumes x is at least 2D,
# presumably (batch, seq, hidden); confirm against callers.
94136 num_tokens = orig_shape [0 ] * orig_shape [1 ]
# Elided hunk: the lines computing hidden_size / x_numel / scale_numel /
# shift_numel (used below) are not visible in this paste.
@@ -110,10 +152,8 @@ def fused_scale_shift(
110152
111153 output = torch .empty_like (x )
112154
# Old launch: one program per (row-tile, col-tile). New launch: a fixed-size
# 1D grid of `kernel_num` persistent programs that grid-stride over tiles;
# get_device_properties()[1] presumably yields the device core/program count —
# verify against sgl_kernel_npu.utils.triton_utils.
113- grid = (
114- triton .cdiv (num_tokens , block_l ),
115- triton .cdiv (hidden_size , block_c ),
116- )
155+ kernel_num = get_device_properties ()[1 ]
156+ grid = (kernel_num ,)
117157
# Dispatch: a full per-element shift (shift_numel == x_numel) takes the
# specialized kernel_2; otherwise the general kernel handles scalar-or-vector
# scale/shift via its constexpr numel flags.
118158 if shift_numel == x_numel :
119159 fused_scale_shift_kernel_2 [grid ](
@@ -124,8 +164,7 @@ def fused_scale_shift(
124164 num_tokens ,
125165 hidden_size ,
126166 scale_constant ,
# block_l/block_c are no longer forwarded — autotune supplies them.
127- block_l = block_l ,
128- block_c = block_c ,
167+ kernel_num = kernel_num ,
129168 )
130169
131170 else :
@@ -138,8 +177,7 @@ def fused_scale_shift(
138177 hidden_size ,
139178 scale_numel = scale_numel ,
140179 shift_numel = shift_numel ,
141- block_l = block_l ,
142- block_c = block_c ,
180+ kernel_num = kernel_num ,
143181 )
144182
# Returned with x's original shape (empty_like); kernels wrote it in place.
145183 return output
0 commit comments