Skip to content

Commit 393915f

Browse files
committed
softmax forward 2 passes
1 parent 9d5afc4 commit 393915f

File tree

1 file changed

+5
-11
lines changed

1 file changed

+5
-11
lines changed

src/liger_kernel/ops/backends/_ascend/ops/multi_token_attention.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -64,17 +64,11 @@ def _fused_mask_softmax_fwd_kernel(
6464
# Second pass: normalize and store
6565
for block_start in range(0, valid_len, BLOCK_SIZE):
6666
col_idx = block_start + tl.arange(0, BLOCK_SIZE)
67-
col_mask = col_idx < valid_len
68-
vals = tl.load(row_ptr + col_idx, mask=col_mask, other=float("-inf"))
69-
exp_vals = tl.exp(vals - max_val)
70-
probs = exp_vals / d_sum
71-
tl.store(out_row_ptr + col_idx, probs, mask=col_mask)
72-
73-
# Store zeros for masked positions
74-
for block_start in range(valid_len, L, BLOCK_SIZE):
75-
col_idx = block_start + tl.arange(0, BLOCK_SIZE)
76-
col_mask = col_idx < L
77-
tl.store(out_row_ptr + col_idx, 0.0, mask=col_mask)
67+
mask = col_idx < L
68+
causal = col_idx <= row_idx
69+
vals = tl.load(row_ptr + col_idx, mask=mask & causal, other=float("-inf"))
70+
probs = tl.exp(vals - max_val) / d_sum
71+
tl.store(out_row_ptr + col_idx, probs, mask=mask)
7872

7973

8074
@triton.jit

0 commit comments

Comments (0)