Skip to content

Commit 9d5afc4

Browse files
committed
use online softmax
1 parent bbd247b commit 9d5afc4

1 file changed

Lines changed: 8 additions & 13 deletions

File tree

src/liger_kernel/ops/backends/_ascend/ops/multi_token_attention.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -49,30 +49,25 @@ def _fused_mask_softmax_fwd_kernel(
4949

5050
valid_len = row_idx + 1
5151

52-
# First pass: compute max for numerical stability
52+
# First pass: online softmax
5353
max_val = float("-inf")
54+
d_sum = 0.0
5455
for block_start in range(0, valid_len, BLOCK_SIZE):
5556
col_idx = block_start + tl.arange(0, BLOCK_SIZE)
5657
col_mask = col_idx < valid_len
5758
vals = tl.load(row_ptr + col_idx, mask=col_mask, other=float("-inf"))
58-
max_val = tl.maximum(max_val, tl.max(vals))
59+
m_block = tl.max(vals)
60+
m_new = tl.maximum(max_val, m_block)
61+
d_sum = d_sum * tl.exp(max_val - m_new) + tl.sum(tl.exp(vals - m_new))
62+
max_val = m_new
5963

60-
# Second pass: compute exp and sum
61-
sum_exp = 0.0
64+
# Second pass: normalize and store
6265
for block_start in range(0, valid_len, BLOCK_SIZE):
6366
col_idx = block_start + tl.arange(0, BLOCK_SIZE)
6467
col_mask = col_idx < valid_len
6568
vals = tl.load(row_ptr + col_idx, mask=col_mask, other=float("-inf"))
6669
exp_vals = tl.exp(vals - max_val)
67-
sum_exp += tl.sum(tl.where(col_mask, exp_vals, 0.0))
68-
69-
# Third pass: normalize and store
70-
for block_start in range(0, valid_len, BLOCK_SIZE):
71-
col_idx = block_start + tl.arange(0, BLOCK_SIZE)
72-
col_mask = col_idx < valid_len
73-
vals = tl.load(row_ptr + col_idx, mask=col_mask, other=float("-inf"))
74-
exp_vals = tl.exp(vals - max_val)
75-
probs = exp_vals / sum_exp
70+
probs = exp_vals / d_sum
7671
tl.store(out_row_ptr + col_idx, probs, mask=col_mask)
7772

7873
# Store zeros for masked positions

0 commit comments

Comments
 (0)