1+ import math
2+
13import torch
24import torch .nn .functional as F
35import triton
68from torch .nn .modules .utils import _pair
79
810from liger_kernel .ops .backends ._ascend .ub_manager import compute_default_tiling_strategy
11+ from liger_kernel .ops .backends ._ascend .ub_manager import get_ub_manager
912from liger_kernel .ops .utils import ensure_contiguous
1013from liger_kernel .ops .utils import get_npu_core_count
1114
@@ -58,23 +61,23 @@ def _fused_mask_softmax_fwd_kernel(
5861 for block_start in range (0 , valid_len , BLOCK_SIZE ):
5962 col_idx = block_start + tl .arange (0 , BLOCK_SIZE )
6063 col_mask = col_idx < valid_len
61- vals = tl .load (row_ptr + col_idx , mask = col_mask , other = float ("-inf" ), eviction_policy = "evict_first" )
64+ vals = tl .load (row_ptr + col_idx , mask = col_mask , other = float ("-inf" ))
6265 max_val = tl .maximum (max_val , tl .max (vals ))
6366
6467 # Second pass: compute exp and sum
6568 sum_exp = 0.0
6669 for block_start in range (0 , valid_len , BLOCK_SIZE ):
6770 col_idx = block_start + tl .arange (0 , BLOCK_SIZE )
6871 col_mask = col_idx < valid_len
69- vals = tl .load (row_ptr + col_idx , mask = col_mask , other = float ("-inf" ), eviction_policy = "evict_first" )
72+ vals = tl .load (row_ptr + col_idx , mask = col_mask , other = float ("-inf" ))
7073 exp_vals = tl .exp (vals - max_val )
7174 sum_exp += tl .sum (tl .where (col_mask , exp_vals , 0.0 ))
7275
7376 # Third pass: normalize and store
7477 for block_start in range (0 , valid_len , BLOCK_SIZE ):
7578 col_idx = block_start + tl .arange (0 , BLOCK_SIZE )
7679 col_mask = col_idx < valid_len
77- vals = tl .load (row_ptr + col_idx , mask = col_mask , other = float ("-inf" ), eviction_policy = "evict_first" )
80+ vals = tl .load (row_ptr + col_idx , mask = col_mask , other = float ("-inf" ))
7881 exp_vals = tl .exp (vals - max_val )
7982 probs = exp_vals / sum_exp
8083 tl .store (out_row_ptr + col_idx , probs , mask = col_mask )
@@ -133,8 +136,8 @@ def _fused_mask_softmax_bwd_kernel(
133136 for block_start in range (0 , valid_len , BLOCK_SIZE ):
134137 col_idx = block_start + tl .arange (0 , BLOCK_SIZE )
135138 col_mask = col_idx < valid_len
136- grad_vals = tl .load (grad_row_ptr + col_idx , mask = col_mask , other = 0.0 , eviction_policy = "evict_first" )
137- prob_vals = tl .load (probs_row_ptr + col_idx , mask = col_mask , other = 0.0 , eviction_policy = "evict_first" )
139+ grad_vals = tl .load (grad_row_ptr + col_idx , mask = col_mask , other = 0.0 )
140+ prob_vals = tl .load (probs_row_ptr + col_idx , mask = col_mask , other = 0.0 )
138141 dot += tl .sum (tl .where (col_mask , grad_vals * prob_vals , 0.0 ))
139142
140143 # Second pass: compute gradient
@@ -201,12 +204,8 @@ def _fused_mask_sparsemax_bwd_kernel(
201204 for block_start in range (0 , valid_len , BLOCK_SIZE ):
202205 col_idx = block_start + tl .arange (0 , BLOCK_SIZE )
203206 col_mask = col_idx < valid_len
204- prob_vals = tl .load (probs_row_ptr + col_idx , mask = col_mask , other = 0.0 , eviction_policy = "evict_first" ).to (
205- tl .float32
206- )
207- grad_vals = tl .load (grad_row_ptr + col_idx , mask = col_mask , other = 0.0 , eviction_policy = "evict_first" ).to (
208- tl .float32
209- )
207+ prob_vals = tl .load (probs_row_ptr + col_idx , mask = col_mask , other = 0.0 ).to (tl .float32 )
208+ grad_vals = tl .load (grad_row_ptr + col_idx , mask = col_mask , other = 0.0 ).to (tl .float32 )
210209 supp = prob_vals > 0.0
211210 go_sum += tl .sum (tl .where (supp & col_mask , grad_vals , 0.0 ))
212211 supp_cnt += tl .sum (tl .where (supp & col_mask , 1.0 , 0.0 ))
@@ -229,110 +228,103 @@ def _fused_mask_sparsemax_bwd_kernel(
229228 tl .store (out_row_ptr + col_idx , 0.0 , mask = col_mask )
230229
231230
232- def mask_zero_rowwise (scores : torch .Tensor ) -> torch .Tensor :
233- """
234- Forward pass for causal masking with zero values.
235- Uses 1D row-wise processing.
236-
237- Args:
238- scores: Input scores tensor of shape (*batch, L, L)
239-
240- Returns:
241- Masked scores tensor with future positions set to 0.0
242- """
243- * batch , L , _ = scores .shape
244- N = int (torch .prod (torch .tensor (batch ))) if batch else 1
245- scores_f = scores .view (N , L , L )
246- out = torch .empty_like (scores_f )
247-
248- BLOCK_SIZE = get_optimal_block_size (L , is_forward = True )
249- num_cores = get_npu_core_count ()
250- grid_size = min (num_cores , N * L )
251-
252- _mask_row_kernel [(grid_size ,)](
253- scores_f ,
254- out ,
255- scores_f .stride (0 ),
256- scores_f .stride (1 ),
257- N ,
258- L ,
259- mask_val = 0.0 ,
260- BLOCK_SIZE = BLOCK_SIZE ,
261- )
262-
263- return out .view (* batch , L , L )
264-
265-
@triton.jit
def _mask_row_kernel(
    scores_ptr,
    out_ptr,
    stride_b,
    stride_row,
    L,
    mask_val: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Causal masking over one BLOCK_SIZE x BLOCK_SIZE tile of an (L, L) matrix.

    Grid layout:
        axis 0: row-block index
        axis 1: col-block index
        axis 2: batch index

    Positions with col > row are overwritten with ``mask_val``; every
    in-bounds position with col <= row is copied through unchanged.
    """
    tile_r = tl.program_id(0)  # row block
    tile_c = tl.program_id(1)  # col block
    batch = tl.program_id(2)  # batch

    # Absolute row/col indices covered by this tile.
    rows = tile_r * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    cols = tile_c * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

    # One in-bounds mask shared by the load and the store.
    inbounds = (rows[:, None] < L) & (cols[None, :] < L)

    # Input and output share the same layout, so one offset grid serves both.
    offsets = batch * stride_b + rows[:, None] * stride_row + cols[None, :]

    tile = tl.load(scores_ptr + offsets, mask=inbounds, other=0.0)

    # Causal condition: keep positions at or before the diagonal.
    keep = cols[None, :] <= rows[:, None]
    tl.store(out_ptr + offsets, tl.where(keep, tile, mask_val), mask=inbounds)
310274
def get_2d_optimal_block_size(
    n: int,
    dtype_size: int,
    memory_multiplier: float,
    safety_margin: float = 0.9,
    min_block: int = 16,
    max_block: int = 512,
):
    """
    Compute the square tile edge (BLOCK_SIZE) for an (n x n) 2D-tiled kernel.

    Strategy:
        - Derive the maximum tile area (in elements) that fits in the UB
          under the given live-buffer multiplier and safety margin.
        - Take the square root for a balanced square tile.
        - Clamp to the configured limits and round up to a power of two,
          as required by ``tl.arange``.

    Args:
        n: matrix size (n x n).
        dtype_size: bytes per element (values below 4 are promoted to 4 to
            budget for fp32 intermediates).
        memory_multiplier: estimated number of live tile-sized buffers.
        safety_margin: fraction of UB capacity considered usable.
        min_block: minimum tile edge.
        max_block: maximum tile edge.

    Returns:
        A power-of-two BLOCK_SIZE.
    """
    ub_manager = get_ub_manager()

    # Budget for fp32 intermediates even when inputs are fp16/bf16.
    dtype_size = max(dtype_size, 4)

    # Usable UB budget in bits.
    safe_ub = int(ub_manager.ub_capacity_bits * safety_margin)

    # Max tile area in elements. memory_multiplier is a float, so force the
    # result back to int (capacity is in bits -> * 8 bits per byte).
    max_area = int(safe_ub // (memory_multiplier * dtype_size * 8))
    max_area = max(1, max_area)

    # Balanced square tile, clamped to the problem and configured limits.
    block = math.isqrt(max_area)
    block = min(block, n, max_block)
    block = max(block, min_block)

    # tl.arange requires a power-of-two extent, so never clamp back down to
    # a non-power-of-two n here: the kernel's bounds masks already handle
    # tiles that overhang the matrix edge.
    return triton.next_power_of_2(block)
336328
337329
338330def get_optimal_size_fused_mask_softmax (L : int , is_forward : bool = True , dtype_size : int = 2 ):
@@ -360,6 +352,38 @@ def get_optimal_size_fused_mask_softmax(L: int, is_forward: bool = True, dtype_s
360352 return 2048
361353
362354
def mask_zero_rowwise(scores: torch.Tensor) -> torch.Tensor:
    """
    Forward pass for causal masking with zero values.

    Launches the 2D-tiled causal-mask kernel: each program handles one
    BLOCK_SIZE x BLOCK_SIZE tile of one (L, L) score matrix.

    Args:
        scores: Input scores tensor of shape (*batch, L, L)

    Returns:
        Masked scores tensor with future positions (col > row) set to 0.0
    """
    *batch, L, _ = scores.shape
    # math.prod avoids building a throwaway tensor just to multiply ints.
    N = math.prod(batch) if batch else 1
    scores_f = scores.view(N, L, L)
    out = torch.empty_like(scores_f)

    BLOCK_SIZE = get_2d_optimal_block_size(L, dtype_size=scores_f.element_size(), memory_multiplier=12.0)

    # One program per (row-tile, col-tile, batch) triple.
    grid = (triton.cdiv(L, BLOCK_SIZE), triton.cdiv(L, BLOCK_SIZE), N)
    _mask_row_kernel[grid](
        scores_f,
        out,
        scores_f.stride(0),
        scores_f.stride(1),
        L,
        mask_val=0.0,
        BLOCK_SIZE=BLOCK_SIZE,
    )

    return out.view(*batch, L, L)
385+
386+
363387def fused_mask_softmax_forward (scores : torch .Tensor ) -> torch .Tensor :
364388 """
365389 Fused forward pass: causal masking + softmax.
@@ -453,16 +477,15 @@ def fused_mask_sparsemax_forward(scores: torch.Tensor) -> tuple[torch.Tensor, to
453477 scores_f = scores .view (N , L , L )
454478 scores_masked = torch .empty_like (scores_f )
455479
456- BLOCK_SIZE = get_optimal_size_fused_mask_softmax (L , is_forward = True )
457- num_cores = get_npu_core_count ()
458- grid_size = min (num_cores , N * L )
480+ # BLOCK_SIZE = get_optimal_size_fused_mask_softmax(L, is_forward=True)
481+ BLOCK_SIZE = get_2d_optimal_block_size (L , dtype_size = scores_f .element_size (), memory_multiplier = 12.0 )
459482
460- _mask_row_kernel [(grid_size ,)](
483+ grid = (triton .cdiv (L , BLOCK_SIZE ), triton .cdiv (L , BLOCK_SIZE ), N )
484+ _mask_row_kernel [grid ](
461485 scores_f ,
462486 scores_masked ,
463487 scores_f .stride (0 ),
464488 scores_f .stride (1 ),
465- N ,
466489 L ,
467490 mask_val = - 1e9 ,
468491 BLOCK_SIZE = BLOCK_SIZE ,
0 commit comments