@@ -49,30 +49,25 @@ def _fused_mask_softmax_fwd_kernel(
4949
5050 valid_len = row_idx + 1
5151
52- # First pass: compute max for numerical stability
52+ # First pass: online softmax
5353 max_val = float ("-inf" )
54+ d_sum = 0.0
5455 for block_start in range (0 , valid_len , BLOCK_SIZE ):
5556 col_idx = block_start + tl .arange (0 , BLOCK_SIZE )
5657 col_mask = col_idx < valid_len
5758 vals = tl .load (row_ptr + col_idx , mask = col_mask , other = float ("-inf" ))
58- max_val = tl .maximum (max_val , tl .max (vals ))
59+ m_block = tl .max (vals )
60+ m_new = tl .maximum (max_val , m_block )
61+ d_sum = d_sum * tl .exp (max_val - m_new ) + tl .sum (tl .exp (vals - m_new ))
62+ max_val = m_new
5963
60- # Second pass: compute exp and sum
61- sum_exp = 0.0
64+ # Second pass: normalize and store
6265 for block_start in range (0 , valid_len , BLOCK_SIZE ):
6366 col_idx = block_start + tl .arange (0 , BLOCK_SIZE )
6467 col_mask = col_idx < valid_len
6568 vals = tl .load (row_ptr + col_idx , mask = col_mask , other = float ("-inf" ))
6669 exp_vals = tl .exp (vals - max_val )
67- sum_exp += tl .sum (tl .where (col_mask , exp_vals , 0.0 ))
68-
69- # Third pass: normalize and store
70- for block_start in range (0 , valid_len , BLOCK_SIZE ):
71- col_idx = block_start + tl .arange (0 , BLOCK_SIZE )
72- col_mask = col_idx < valid_len
73- vals = tl .load (row_ptr + col_idx , mask = col_mask , other = float ("-inf" ))
74- exp_vals = tl .exp (vals - max_val )
75- probs = exp_vals / sum_exp
70+ probs = exp_vals / d_sum
7671 tl .store (out_row_ptr + col_idx , probs , mask = col_mask )
7772
7873 # Store zeros for masked positions
0 commit comments