66import triton
77import triton .language as tl
88
9- from tensorrt_llm ._torch .modules .fla .utils import input_guard
9+ from tensorrt_llm ._torch .modules .fla .utils import custom_device_ctx
1010
1111
1212@triton .heuristics ({
@@ -30,6 +30,12 @@ def fused_sigmoid_gating_delta_rule_update_kernel(
3030 cu_seqlens ,
3131 scale ,
3232 T ,
33+ total_nh ,
34+ stride_q ,
35+ stride_k ,
36+ stride_v ,
37+ stride_a ,
38+ stride_b ,
3339 s_h0_0 ,
3440 h0_dim0 ,
3541 B : tl .constexpr ,
@@ -46,117 +52,127 @@ def fused_sigmoid_gating_delta_rule_update_kernel(
4652 """
4753 Fused kernel that combines sigmoid gating computation with recurrent delta rule update.
4854 """
49- i_nh , i_v , i_k = tl .program_id (0 ), tl .program_id (1 ), tl .program_id (2 )
50- i_n , i_hv = i_nh // HV , i_nh % HV
51- i_h = i_hv // (HV // H )
52-
53- if IS_VARLEN :
54- bos , eos = (
55- tl .load (cu_seqlens + i_n ).to (tl .int64 ),
56- tl .load (cu_seqlens + i_n + 1 ).to (tl .int64 ),
57- )
58- all = T
59- T = eos - bos
60- else :
61- bos , eos = i_n * T , i_n * T + T
62- all = B * T
63-
55+ i_k , i_v , i_nh = tl .program_id (0 ), tl .program_id (1 ), tl .program_id (2 )
6456 o_k = i_k * BK + tl .arange (0 , BK )
6557 o_v = i_v * BV + tl .arange (0 , BV )
66-
67- p_q = q + (bos * H + i_h ) * K + o_k
68- p_k = k + (bos * H + i_h ) * K + o_k
69- p_v = v + (bos * HV + i_hv ) * V + o_v
70- p_b = b + bos * HV + i_hv
71- p_o = o + ((i_k * all + bos ) * HV + i_hv ) * V + o_v
72-
73- # Gating computation pointers
74- p_A_log = A_log + i_hv
75- p_a = a + bos * HV + i_hv
76- p_dt_bias = dt_bias + i_hv
77-
7858 mask_k = o_k < K
7959 mask_v = o_v < V
8060 mask_h = mask_k [:, None ] & mask_v [None , :]
61+ grid_stride_nh = tl .num_programs (2 )
8162
82- b_h = tl .zeros ([BK , BV ], dtype = tl .float32 )
83- if USE_INITIAL_STATE :
84- idx = tl .load (h0_indices + i_n ).to (tl .int64 ) # prevent int32 overflow
85- if idx >= 0 :
86- tl .device_assert (idx < h0_dim0 ,
87- "idx out of bounds in h0_source load" )
88- p_h0 = (h0_source + idx * s_h0_0 + i_hv * K * V + o_k [:, None ] * V +
89- o_v [None , :])
90- b_h += tl .load (p_h0 , mask = mask_h , other = 0 ).to (tl .float32 )
91-
92- for _ in range (0 , T ):
93- # Load inputs
94- b_q = tl .load (p_q , mask = mask_k , other = 0 ).to (tl .float32 )
95- b_k = tl .load (p_k , mask = mask_k , other = 0 ).to (tl .float32 )
96- b_v = tl .load (p_v , mask = mask_v , other = 0 ).to (tl .float32 )
97- b_b = tl .load (p_b ).to (tl .float32 )
98-
99- # Compute sigmoid gating
100- # Load gating parameters
101- b_A_log = tl .load (p_A_log ).to (tl .float32 )
102- b_a = tl .load (p_a ).to (tl .float32 )
103- b_dt_bias = tl .load (p_dt_bias ).to (tl .float32 )
104-
105- # Compute g = -exp(A_log) * softplus(a + dt_bias)
106- x = b_a + b_dt_bias
107- beta_x = softplus_beta * x
108- # Apply softplus with numerical stability
109- softplus_x = tl .where (
110- beta_x <= softplus_threshold ,
111- (1.0 / softplus_beta ) * tl .log (1.0 + tl .exp (beta_x )),
112- x ,
113- )
114- b_g = - tl .exp (b_A_log ) * softplus_x
63+ while i_nh < total_nh :
64+ i_n , i_hv = i_nh // HV , i_nh % HV
65+ i_h = i_hv // (HV // H )
66+
67+ if IS_VARLEN :
68+ bos , eos = (
69+ tl .load (cu_seqlens + i_n ).to (tl .int64 ),
70+ tl .load (cu_seqlens + i_n + 1 ).to (tl .int64 ),
71+ )
72+ all = T
73+ seq_T = eos - bos
74+ else :
75+ bos , eos = i_n * T , i_n * T + T
76+ all = B * T
77+ seq_T = T
78+
79+ # Decode q/k/v/a/b often arrive as views sliced out of larger packed tensors.
80+ # Use the caller-provided token strides so the kernel can consume those views
81+ # directly instead of relying on a packed contiguous layout.
82+ p_q = q + bos * stride_q + i_h * K + o_k
83+ p_k = k + bos * stride_k + i_h * K + o_k
84+ p_v = v + bos * stride_v + i_hv * V + o_v
85+ p_b = b + bos * stride_b + i_hv
86+ # o is allocated in this wrapper and kept contiguous, so the output
87+ # pointer arithmetic can use the packed [NK, B, T, HV, V] layout.
88+ p_o = o + ((i_k * all + bos ) * HV + i_hv ) * V + o_v
11589
116- # Compute beta = sigmoid(b)
117- b_beta = 1.0 / (1.0 + tl .exp (- b_b ))
90+ # Gating computation pointers
91+ p_A_log = A_log + i_hv
92+ p_a = a + bos * stride_a + i_hv
93+ p_dt_bias = dt_bias + i_hv
11894
119- # Apply L2 normalization if enabled
120- if USE_QK_L2NORM_IN_KERNEL :
121- b_q = b_q / (tl .sqrt (tl .sum (b_q * b_q )) + 1e-6 )
122- b_k = b_k / (tl .sqrt (tl .sum (b_k * b_k )) + 1e-6 )
95+ b_h = tl .zeros ([BK , BV ], dtype = tl .float32 )
96+ if USE_INITIAL_STATE :
97+ idx = tl .load (h0_indices + i_n ).to (tl .int64 )
98+ if idx >= 0 :
99+ tl .device_assert (idx < h0_dim0 ,
100+ "idx out of bounds in h0_source load" )
101+ p_h0 = (h0_source + idx * s_h0_0 + i_hv * K * V +
102+ o_k [:, None ] * V + o_v [None , :])
103+ b_h += tl .load (p_h0 , mask = mask_h , other = 0 ).to (tl .float32 )
123104
124- b_q = b_q * scale
105+ for _ in range (0 , seq_T ):
106+ # Load inputs
107+ b_q = tl .load (p_q , mask = mask_k , other = 0 ).to (tl .float32 )
108+ b_k = tl .load (p_k , mask = mask_k , other = 0 ).to (tl .float32 )
109+ b_v = tl .load (p_v , mask = mask_v , other = 0 ).to (tl .float32 )
110+ b_b = tl .load (p_b ).to (tl .float32 )
125111
126- # Apply gating to hidden state: h *= exp(g)
127- b_h *= tl .exp (b_g )
112+ # Compute sigmoid gating
113+ # Load gating parameters
114+ b_A_log = tl .load (p_A_log ).to (tl .float32 )
115+ b_a = tl .load (p_a ).to (tl .float32 )
116+ b_dt_bias = tl .load (p_dt_bias ).to (tl .float32 )
128117
129- # Delta rule: v -= sum(h * k, dim=0)
130- b_v -= tl .sum (b_h * b_k [:, None ], 0 )
118+ # Compute g = -exp(A_log) * softplus(a + dt_bias)
119+ x = b_a + b_dt_bias
120+ beta_x = softplus_beta * x
121+ # Apply softplus with numerical stability
122+ softplus_x = tl .where (
123+ beta_x <= softplus_threshold ,
124+ (1.0 / softplus_beta ) * tl .log (1.0 + tl .exp (beta_x )),
125+ x ,
126+ )
127+ b_g = - tl .exp (b_A_log ) * softplus_x
131128
132- # Apply beta gating: v *= beta
133- b_v *= b_beta
129+ # Compute beta = sigmoid(b)
130+ b_beta = 1.0 / ( 1.0 + tl . exp ( - b_b ))
134131
135- # Update hidden state: h += k[:, None] * v[None, :]
136- b_h += b_k [:, None ] * b_v [None , :]
132+ # Apply L2 normalization if enabled
133+ if USE_QK_L2NORM_IN_KERNEL :
134+ b_q = b_q / (tl .sqrt (tl .sum (b_q * b_q )) + 1e-6 )
135+ b_k = b_k / (tl .sqrt (tl .sum (b_k * b_k )) + 1e-6 )
137136
138- # Compute output: o = sum(h * q, dim=0)
139- b_o = tl .sum (b_h * b_q [:, None ], 0 )
140- tl .store (p_o , b_o .to (p_o .dtype .element_ty ), mask = mask_v )
137+ b_q = b_q * scale
141138
142- # Update pointers for next timestep
143- p_q += H * K
144- p_k += H * K
145- p_o += HV * V
146- p_v += HV * V
147- p_b += HV
148- p_a += HV
139+ # Apply gating to hidden state: h *= exp(g)
140+ b_h *= tl .exp (b_g )
149141
150- # Store final state back to h0_source with bounds checking
151- if USE_INITIAL_STATE :
152- idx = tl .load (h0_indices + i_n ).to (tl .int64 )
153- if idx >= 0 :
154- p_h0 = (h0_source + idx * s_h0_0 + i_hv * K * V + o_k [:, None ] * V +
155- o_v [None , :])
156- tl .store (p_h0 , b_h .to (p_h0 .dtype .element_ty ), mask = mask_h )
142+ # Delta rule: v -= sum(h * k, dim=0)
143+ b_v -= tl .sum (b_h * b_k [:, None ], 0 )
144+
145+ # Apply beta gating: v *= beta
146+ b_v *= b_beta
147+
148+ # Update hidden state: h += k[:, None] * v[None, :]
149+ b_h += b_k [:, None ] * b_v [None , :]
150+
151+ # Compute output: o = sum(h * q, dim=0)
152+ b_o = tl .sum (b_h * b_q [:, None ], 0 )
153+ tl .store (p_o , b_o .to (p_o .dtype .element_ty ), mask = mask_v )
154+
155+ # Update pointers for next timestep
156+ p_q += stride_q
157+ p_k += stride_k
158+ p_o += HV * V
159+ p_v += stride_v
160+ p_b += stride_b
161+ p_a += stride_a
162+
163+ # Store final state back to h0_source with bounds checking
164+ if USE_INITIAL_STATE :
165+ idx = tl .load (h0_indices + i_n ).to (tl .int64 )
166+ if idx >= 0 :
167+ tl .device_assert (idx < h0_dim0 ,
168+ "idx out of bounds in h0_source store" )
169+ p_h0 = (h0_source + idx * s_h0_0 + i_hv * K * V +
170+ o_k [:, None ] * V + o_v [None , :])
171+ tl .store (p_h0 , b_h .to (p_h0 .dtype .element_ty ), mask = mask_h )
172+
173+ i_nh += grid_stride_nh
157174
158175
159- @input_guard (exclude_args = ["initial_state_source" ])
160176def fused_sigmoid_gating_delta_rule_update (
161177 A_log : torch .Tensor ,
162178 a : torch .Tensor ,
@@ -181,6 +197,14 @@ def fused_sigmoid_gating_delta_rule_update(
181197 B , T , H , K , V = * k .shape , v .shape [- 1 ]
182198 HV = v .shape [2 ]
183199 N = B if cu_seqlens is None else len (cu_seqlens ) - 1
200+
201+ # Accept native view layouts from forward_decode rather than forcing packed
202+ # copies through input_guard.
203+ stride_q = q .stride (1 )
204+ stride_k = k .stride (1 )
205+ stride_v = v .stride (1 )
206+ stride_a = a .stride (- 2 )
207+ stride_b = b .stride (- 2 )
184208 BK , BV = triton .next_power_of_2 (K ), min (triton .next_power_of_2 (V ), 32 )
185209 NK , NV = triton .cdiv (K , BK ), triton .cdiv (V , BV )
186210 assert NK == 1 , "NK > 1 is not supported yet"
@@ -193,7 +217,10 @@ def fused_sigmoid_gating_delta_rule_update(
193217 assert scale > 0 , "scale must be positive"
194218
195219 o = q .new_empty (NK , * v .shape )
196- grid = (N * HV , NV , NK )
220+ # A (NK, NV, N * HV) launch grid was measured to be faster than (N * HV, NV, NK).
221+ # Since the maximum of grid.z is 65535, we cap grid.z and let each Triton program
222+ # grid-stride across the remaining N * HV tiles.
223+ grid = (NK , NV , min (N * HV , 65535 ))
197224
198225 if initial_state_source is not None :
199226 s_h0_0 , s_h0_1 , s_h0_2 , s_h0_3 = initial_state_source .stride ()
@@ -205,34 +232,44 @@ def fused_sigmoid_gating_delta_rule_update(
205232 s_h0_0 = 0
206233 slot_num = 0
207234
208- fused_sigmoid_gating_delta_rule_update_kernel [grid ](
209- A_log = A_log ,
210- a = a ,
211- dt_bias = dt_bias ,
212- softplus_beta = softplus_beta ,
213- softplus_threshold = softplus_threshold ,
214- q = q ,
215- k = k ,
216- v = v ,
217- b = b ,
218- o = o ,
219- h0_source = initial_state_source ,
220- h0_indices = initial_state_indices ,
221- cu_seqlens = cu_seqlens ,
222- scale = scale ,
223- T = T ,
224- s_h0_0 = s_h0_0 ,
225- h0_dim0 = slot_num ,
226- B = B ,
227- H = H ,
228- HV = HV ,
229- K = K ,
230- V = V ,
231- BK = BK ,
232- BV = BV ,
233- USE_QK_L2NORM_IN_KERNEL = use_qk_l2norm_in_kernel ,
234- num_warps = num_warps ,
235- num_stages = num_stages ,
236- )
235+ # input_guard used to set the active CUDA device and make inputs contiguous.
236+ # We keep only the device-context part here so Triton launches on q's device
237+ # without re-packing the decode views.
238+ with custom_device_ctx (q .device .index ):
239+ fused_sigmoid_gating_delta_rule_update_kernel [grid ](
240+ A_log = A_log ,
241+ a = a ,
242+ dt_bias = dt_bias ,
243+ softplus_beta = softplus_beta ,
244+ softplus_threshold = softplus_threshold ,
245+ q = q ,
246+ k = k ,
247+ v = v ,
248+ b = b ,
249+ o = o ,
250+ h0_source = initial_state_source ,
251+ h0_indices = initial_state_indices ,
252+ cu_seqlens = cu_seqlens ,
253+ scale = scale ,
254+ T = T ,
255+ total_nh = N * HV ,
256+ stride_q = stride_q ,
257+ stride_k = stride_k ,
258+ stride_v = stride_v ,
259+ stride_a = stride_a ,
260+ stride_b = stride_b ,
261+ s_h0_0 = s_h0_0 ,
262+ h0_dim0 = slot_num ,
263+ B = B ,
264+ H = H ,
265+ HV = HV ,
266+ K = K ,
267+ V = V ,
268+ BK = BK ,
269+ BV = BV ,
270+ USE_QK_L2NORM_IN_KERNEL = use_qk_l2norm_in_kernel ,
271+ num_warps = num_warps ,
272+ num_stages = num_stages ,
273+ )
237274 o = o .squeeze (0 )
238275 return o