File tree Expand file tree Collapse file tree 1 file changed +5
-11
lines changed
src/liger_kernel/ops/backends/_ascend/ops Expand file tree Collapse file tree 1 file changed +5
-11
lines changed Original file line number Diff line number Diff line change @@ -64,17 +64,11 @@ def _fused_mask_softmax_fwd_kernel(
6464 # Second pass: normalize and store
6565 for block_start in range (0 , valid_len , BLOCK_SIZE ):
6666 col_idx = block_start + tl .arange (0 , BLOCK_SIZE )
67- col_mask = col_idx < valid_len
68- vals = tl .load (row_ptr + col_idx , mask = col_mask , other = float ("-inf" ))
69- exp_vals = tl .exp (vals - max_val )
70- probs = exp_vals / d_sum
71- tl .store (out_row_ptr + col_idx , probs , mask = col_mask )
72-
73- # Store zeros for masked positions
74- for block_start in range (valid_len , L , BLOCK_SIZE ):
75- col_idx = block_start + tl .arange (0 , BLOCK_SIZE )
76- col_mask = col_idx < L
77- tl .store (out_row_ptr + col_idx , 0.0 , mask = col_mask )
67+ mask = col_idx < L
68+ causal = col_idx <= row_idx
69+ vals = tl .load (row_ptr + col_idx , mask = mask & causal , other = float ("-inf" ))
70+ probs = tl .exp (vals - max_val ) / d_sum
71+ tl .store (out_row_ptr + col_idx , probs , mask = mask )
7872
7973
8074@triton .jit
You can’t perform that action at this time.
0 commit comments