1+ import math
2+
13import torch
24import torch .nn .functional as F
35import triton
68from torch .nn .modules .utils import _pair
79
810from liger_kernel .ops .backends ._ascend .ub_manager import compute_default_tiling_strategy
11+ from liger_kernel .ops .backends ._ascend .ub_manager import get_ub_manager
912from liger_kernel .ops .utils import ensure_contiguous
1013from liger_kernel .ops .utils import get_npu_core_count
1114
@@ -22,13 +25,6 @@ def _fused_mask_softmax_fwd_kernel(
2225):
2326 """
2427 Fused forward kernel: causal masking + softmax with grid-stride loop.
25- Each program processes multiple rows for better resource utilization.
26-
27- Optimizations:
28- - Grid-stride loop to reduce kernel launch overhead
29- - Fuses masking and softmax to reduce memory traffic
30- - Online softmax algorithm for numerical stability
31- - Only loads valid elements (causal mask)
3228
3329 Args:
3430 scores_ptr: Input scores tensor pointer
@@ -58,23 +54,23 @@ def _fused_mask_softmax_fwd_kernel(
5854 for block_start in range (0 , valid_len , BLOCK_SIZE ):
5955 col_idx = block_start + tl .arange (0 , BLOCK_SIZE )
6056 col_mask = col_idx < valid_len
61- vals = tl .load (row_ptr + col_idx , mask = col_mask , other = float ("-inf" ), eviction_policy = "evict_first" )
57+ vals = tl .load (row_ptr + col_idx , mask = col_mask , other = float ("-inf" ))
6258 max_val = tl .maximum (max_val , tl .max (vals ))
6359
6460 # Second pass: compute exp and sum
6561 sum_exp = 0.0
6662 for block_start in range (0 , valid_len , BLOCK_SIZE ):
6763 col_idx = block_start + tl .arange (0 , BLOCK_SIZE )
6864 col_mask = col_idx < valid_len
69- vals = tl .load (row_ptr + col_idx , mask = col_mask , other = float ("-inf" ), eviction_policy = "evict_first" )
65+ vals = tl .load (row_ptr + col_idx , mask = col_mask , other = float ("-inf" ))
7066 exp_vals = tl .exp (vals - max_val )
7167 sum_exp += tl .sum (tl .where (col_mask , exp_vals , 0.0 ))
7268
7369 # Third pass: normalize and store
7470 for block_start in range (0 , valid_len , BLOCK_SIZE ):
7571 col_idx = block_start + tl .arange (0 , BLOCK_SIZE )
7672 col_mask = col_idx < valid_len
77- vals = tl .load (row_ptr + col_idx , mask = col_mask , other = float ("-inf" ), eviction_policy = "evict_first" )
73+ vals = tl .load (row_ptr + col_idx , mask = col_mask , other = float ("-inf" ))
7874 exp_vals = tl .exp (vals - max_val )
7975 probs = exp_vals / sum_exp
8076 tl .store (out_row_ptr + col_idx , probs , mask = col_mask )
@@ -133,8 +129,8 @@ def _fused_mask_softmax_bwd_kernel(
133129 for block_start in range (0 , valid_len , BLOCK_SIZE ):
134130 col_idx = block_start + tl .arange (0 , BLOCK_SIZE )
135131 col_mask = col_idx < valid_len
136- grad_vals = tl .load (grad_row_ptr + col_idx , mask = col_mask , other = 0.0 , eviction_policy = "evict_first" )
137- prob_vals = tl .load (probs_row_ptr + col_idx , mask = col_mask , other = 0.0 , eviction_policy = "evict_first" )
132+ grad_vals = tl .load (grad_row_ptr + col_idx , mask = col_mask , other = 0.0 )
133+ prob_vals = tl .load (probs_row_ptr + col_idx , mask = col_mask , other = 0.0 )
138134 dot += tl .sum (tl .where (col_mask , grad_vals * prob_vals , 0.0 ))
139135
140136 # Second pass: compute gradient
@@ -201,12 +197,8 @@ def _fused_mask_sparsemax_bwd_kernel(
201197 for block_start in range (0 , valid_len , BLOCK_SIZE ):
202198 col_idx = block_start + tl .arange (0 , BLOCK_SIZE )
203199 col_mask = col_idx < valid_len
204- prob_vals = tl .load (probs_row_ptr + col_idx , mask = col_mask , other = 0.0 , eviction_policy = "evict_first" ).to (
205- tl .float32
206- )
207- grad_vals = tl .load (grad_row_ptr + col_idx , mask = col_mask , other = 0.0 , eviction_policy = "evict_first" ).to (
208- tl .float32
209- )
200+ prob_vals = tl .load (probs_row_ptr + col_idx , mask = col_mask , other = 0.0 ).to (tl .float32 )
201+ grad_vals = tl .load (grad_row_ptr + col_idx , mask = col_mask , other = 0.0 ).to (tl .float32 )
210202 supp = prob_vals > 0.0
211203 go_sum += tl .sum (tl .where (supp & col_mask , grad_vals , 0.0 ))
212204 supp_cnt += tl .sum (tl .where (supp & col_mask , 1.0 , 0.0 ))
@@ -229,110 +221,97 @@ def _fused_mask_sparsemax_bwd_kernel(
229221 tl .store (out_row_ptr + col_idx , 0.0 , mask = col_mask )
230222
231223
def mask_zero_rowwise(scores: torch.Tensor) -> torch.Tensor:
    """
    Forward pass for causal masking with zero values.
    Uses 1D row-wise processing.

    Args:
        scores: Input scores tensor of shape (*batch, L, L)

    Returns:
        Masked scores tensor with future positions set to 0.0
    """
    *batch, L, _ = scores.shape
    # Collapse all leading batch dims into a single axis for the kernel.
    N = int(torch.prod(torch.tensor(batch))) if batch else 1
    scores_f = scores.view(N, L, L)
    out = torch.empty_like(scores_f)

    BLOCK_SIZE = get_optimal_block_size(L, is_forward=True)
    # Grid-stride launch: at most one program per NPU core, each program
    # walks multiple rows; never launch more programs than rows (N * L).
    num_cores = get_npu_core_count()
    grid_size = min(num_cores, N * L)

    _mask_row_kernel[(grid_size,)](
        scores_f,
        out,
        scores_f.stride(0),
        scores_f.stride(1),
        N,
        L,
        mask_val=0.0,  # future positions become exactly 0.0
        BLOCK_SIZE=BLOCK_SIZE,
    )

    return out.view(*batch, L, L)
265-
@triton.jit
def _mask_row_kernel(
    scores_ptr,
    out_ptr,
    stride_b,
    stride_row,
    L,
    mask_val: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """
    2D-tiled causal masking kernel.

    Grid:
        axis 0 (pid_m): row-block index
        axis 1 (pid_n): col-block index
        axis 2 (pid_b): batch index

    Each program handles a BLOCK_SIZE x BLOCK_SIZE tile: it loads the tile,
    keeps entries on or below the diagonal (col <= row) and overwrites the
    rest with `mask_val`, then stores the tile.

    NOTE(review): the pointer math adds `col_idx` directly, i.e. it assumes
    the innermost (column) stride is 1 — confirm callers pass row-contiguous
    tensors.
    """

    pid_m = tl.program_id(0)  # row block
    pid_n = tl.program_id(1)  # col block
    pid_b = tl.program_id(2)  # batch

    # Compute row/col indices of this tile's top-left corner.
    row_start = pid_m * BLOCK_SIZE
    col_start = pid_n * BLOCK_SIZE

    row_idx = row_start + tl.arange(0, BLOCK_SIZE)
    col_idx = col_start + tl.arange(0, BLOCK_SIZE)

    # Guards for the last partial tile along each dimension.
    row_mask = row_idx < L
    col_mask = col_idx < L

    ptrs = scores_ptr + pid_b * stride_b + row_idx[:, None] * stride_row + col_idx[None, :]
    out_ptrs = out_ptr + pid_b * stride_b + row_idx[:, None] * stride_row + col_idx[None, :]

    vals = tl.load(ptrs, mask=row_mask[:, None] & col_mask[None, :], other=0.0)
    # Causal condition: keep columns at or before the current row.
    causal = col_idx[None, :] <= row_idx[:, None]
    masked_vals = tl.where(causal, vals, mask_val)

    tl.store(out_ptrs, masked_vals, mask=row_mask[:, None] & col_mask[None, :])
310267
def get_2d_optimal_block_size(
    n: int,
    dtype_size: int,
    memory_multiplier: float,
    safety_margin: float = 0.9,
    min_block: int = 16,
    max_block: int = 512,
):
    """
    Compute the square tile side BLOCK_SIZE for 2D kernels over an n x n matrix.

    The tile area is chosen so that roughly `memory_multiplier` live copies of
    a BLOCK_SIZE x BLOCK_SIZE tile of `dtype_size`-byte elements fit in the UB
    capacity reported by the UB manager, with `safety_margin` headroom. The
    result is always a power of two, as required by `tl.arange(0, BLOCK_SIZE)`
    in the consuming kernels.

    Args:
        n: matrix size (n x n)
        dtype_size: bytes per element; values below 4 are treated as 4
            since on-chip intermediates are at least fp32
        memory_multiplier: estimated number of live tile-sized buffers
        safety_margin: fraction of UB capacity considered usable
        min_block: minimum block size
        max_block: maximum block size

    Returns:
        BLOCK_SIZE: power-of-two tile side
    """
    ub_manager = get_ub_manager()

    # Intermediates are at least fp32 on-chip.
    dtype_size = max(dtype_size, 4)

    # Usable UB budget in bits.
    safe_ub = int(ub_manager.ub_capacity_bits * safety_margin)

    # Max tile area in elements; keep the arithmetic in ints so the area
    # (and hence the block side) is an int, not a float.
    max_area = safe_ub // max(1, int(memory_multiplier * dtype_size * 8))
    max_area = max(1, max_area)

    # Ideal square tile side, clamped to the configured and problem limits.
    block = int(math.sqrt(max_area))
    block = min(block, n, max_block)
    block = max(block, min_block)

    # tl.arange requires a power-of-two extent. Rounding up past n is safe
    # because the kernels mask out-of-bounds rows/columns, so clamp against
    # next_power_of_2(n) rather than n itself — clamping to a non-power-of-2
    # n would produce an invalid BLOCK_SIZE.
    block = triton.next_power_of_2(block)
    block = min(block, triton.next_power_of_2(n))

    return block
336315
337316
338317def get_optimal_size_fused_mask_softmax (L : int , is_forward : bool = True , dtype_size : int = 2 ):
@@ -360,6 +339,38 @@ def get_optimal_size_fused_mask_softmax(L: int, is_forward: bool = True, dtype_s
360339 return 2048
361340
362341
def mask_zero_rowwise(scores: torch.Tensor) -> torch.Tensor:
    """
    Forward pass for causal masking with zero values.

    Launches the 2D-tiled `_mask_row_kernel` over a
    (row-blocks, col-blocks, batch) grid; future (upper-triangular)
    positions are overwritten with 0.0.

    Args:
        scores: Input scores tensor of shape (*batch, L, L)

    Returns:
        Masked scores tensor with future positions set to 0.0
    """
    *batch, L, _ = scores.shape
    # Collapse all leading batch dims into one axis; math.prod([]) == 1,
    # so a bare (L, L) input maps to a single batch element.
    N = math.prod(batch)
    scores_f = scores.view(N, L, L)
    out = torch.empty_like(scores_f)

    BLOCK_SIZE = get_2d_optimal_block_size(L, dtype_size=scores_f.element_size(), memory_multiplier=12.0)

    # One program per BLOCK_SIZE x BLOCK_SIZE tile per batch element.
    grid = (triton.cdiv(L, BLOCK_SIZE), triton.cdiv(L, BLOCK_SIZE), N)
    _mask_row_kernel[grid](
        scores_f,
        out,
        scores_f.stride(0),
        scores_f.stride(1),
        L,
        mask_val=0.0,
        BLOCK_SIZE=BLOCK_SIZE,
    )

    return out.view(*batch, L, L)
373+
363374def fused_mask_softmax_forward (scores : torch .Tensor ) -> torch .Tensor :
364375 """
365376 Fused forward pass: causal masking + softmax.
@@ -438,7 +449,6 @@ def fused_mask_sparsemax_forward(scores: torch.Tensor) -> tuple[torch.Tensor, to
438449 """
439450 Forward pass: causal masking + sparsemax using reference implementation.
440451 Uses one-axis grid (one program per row).
441- Because of the complexity of sparsemax, we implement it in PyTorch and fuse only the masking.
442452
443453 Args:
444454 scores: Input scores tensor of shape (*batch, L, L)
@@ -453,16 +463,15 @@ def fused_mask_sparsemax_forward(scores: torch.Tensor) -> tuple[torch.Tensor, to
453463 scores_f = scores .view (N , L , L )
454464 scores_masked = torch .empty_like (scores_f )
455465
456- BLOCK_SIZE = get_optimal_size_fused_mask_softmax (L , is_forward = True )
457- num_cores = get_npu_core_count ()
458- grid_size = min (num_cores , N * L )
466+ # BLOCK_SIZE = get_optimal_size_fused_mask_softmax(L, is_forward=True)
467+ BLOCK_SIZE = get_2d_optimal_block_size (L , dtype_size = scores_f .element_size (), memory_multiplier = 12.0 )
459468
460- _mask_row_kernel [(grid_size ,)](
469+ grid = (triton .cdiv (L , BLOCK_SIZE ), triton .cdiv (L , BLOCK_SIZE ), N )
470+ _mask_row_kernel [grid ](
461471 scores_f ,
462472 scores_masked ,
463473 scores_f .stride (0 ),
464474 scores_f .stride (1 ),
465- N ,
466475 L ,
467476 mask_val = - 1e9 ,
468477 BLOCK_SIZE = BLOCK_SIZE ,
@@ -513,11 +522,6 @@ def fused_mask_sparsemax_backward(grad_out: torch.Tensor, probs: torch.Tensor) -
513522class LigerMultiTokenAttentionFunction (torch .autograd .Function ):
514523 """
515524 NPU-optimized Multi-Token Attention using 1D row-wise processing.
516-
517- This implementation is optimized for NPU hardware by:
518- 1. Using 1D row-wise kernels instead of 2D block-based kernels
519- 2. Larger block sizes for better memory throughput
520- 3. Reduced kernel launch overhead
521525 """
522526
523527 @staticmethod
0 commit comments