Skip to content

Commit 297546c

Browse files
committed
softmax forward 2 passes
1 parent: 9d5afc4 · commit: 297546c

1 file changed

Lines changed: 7 additions & 24 deletions

File tree

src/liger_kernel/ops/backends/_ascend/ops/multi_token_attention.py

Lines changed: 7 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -64,17 +64,11 @@ def _fused_mask_softmax_fwd_kernel(
6464
# Second pass: normalize and store
6565
for block_start in range(0, valid_len, BLOCK_SIZE):
6666
col_idx = block_start + tl.arange(0, BLOCK_SIZE)
67-
col_mask = col_idx < valid_len
68-
vals = tl.load(row_ptr + col_idx, mask=col_mask, other=float("-inf"))
69-
exp_vals = tl.exp(vals - max_val)
70-
probs = exp_vals / d_sum
71-
tl.store(out_row_ptr + col_idx, probs, mask=col_mask)
72-
73-
# Store zeros for masked positions
74-
for block_start in range(valid_len, L, BLOCK_SIZE):
75-
col_idx = block_start + tl.arange(0, BLOCK_SIZE)
76-
col_mask = col_idx < L
77-
tl.store(out_row_ptr + col_idx, 0.0, mask=col_mask)
67+
mask = col_idx < L
68+
causal = col_idx <= row_idx
69+
vals = tl.load(row_ptr + col_idx, mask=mask & causal, other=float("-inf"))
70+
probs = tl.exp(vals - max_val) / d_sum
71+
tl.store(out_row_ptr + col_idx, probs, mask=mask)
7872

7973

8074
@triton.jit
@@ -137,11 +131,6 @@ def _fused_mask_softmax_bwd_kernel(
137131
grad_scores = prob_vals * (grad_vals - dot)
138132
tl.store(out_row_ptr + col_idx, grad_scores, mask=col_mask)
139133

140-
# Zero out masked positions
141-
for block_start in range(valid_len, L, BLOCK_SIZE):
142-
col_idx = block_start + tl.arange(0, BLOCK_SIZE)
143-
col_mask = col_idx < L
144-
tl.store(out_row_ptr + col_idx, 0.0, mask=col_mask)
145134

146135

147136
@triton.jit
@@ -209,12 +198,6 @@ def _fused_mask_sparsemax_bwd_kernel(
209198
grad_scores = tl.where(supp, grad_vals - avg_grad, 0.0)
210199
tl.store(out_row_ptr + col_idx, grad_scores, mask=col_mask)
211200

212-
# Zero out masked positions
213-
for block_start in range(valid_len, L, BLOCK_SIZE):
214-
col_idx = block_start + tl.arange(0, BLOCK_SIZE)
215-
col_mask = col_idx < L
216-
tl.store(out_row_ptr + col_idx, 0.0, mask=col_mask)
217-
218201

219202
@triton.jit
220203
def _mask_row_kernel(
@@ -418,7 +401,7 @@ def fused_mask_softmax_backward(grad_out: torch.Tensor, probs: torch.Tensor) ->
418401
N = int(torch.prod(torch.tensor(batch))) if batch else 1
419402
grad_out_f = grad_out.view(N, L, L)
420403
probs_f = probs.view(N, L, L)
421-
grad_scores = torch.empty_like(grad_out_f)
404+
grad_scores = torch.zeros_like(grad_out_f)
422405

423406
BLOCK_SIZE = get_optimal_size_fused_mask_softmax(L, is_forward=False)
424407

@@ -493,7 +476,7 @@ def fused_mask_sparsemax_backward(grad_out: torch.Tensor, probs: torch.Tensor) -
493476
N = int(torch.prod(torch.tensor(batch))) if batch else 1
494477
grad_out_f = grad_out.view(N, L, L)
495478
probs_f = probs.view(N, L, L)
496-
grad_scores = torch.empty_like(grad_out_f)
479+
grad_scores = torch.zeros_like(grad_out_f)
497480

498481
BLOCK_SIZE = get_optimal_size_fused_mask_softmax(L, is_forward=False)
499482

0 commit comments

Comments (0)