@@ -1080,7 +1080,6 @@ def _sdpa_decode_splitk_kernel(
10801080 K_ptr ,
10811081 V_ptr ,
10821082 O_partial_ptr ,
1083- M_partial_ptr ,
10841083 L_partial_ptr ,
10851084 Mask_ptr ,
10861085 B ,
@@ -1102,13 +1101,14 @@ def _sdpa_decode_splitk_kernel(
11021101 stride_op_b ,
11031102 stride_op_h ,
11041103 stride_op_d ,
1105- stride_mp_s ,
1106- stride_mp_b ,
1107- stride_mp_h ,
1104+ stride_lp_s ,
1105+ stride_lp_b ,
1106+ stride_lp_h ,
11081107 stride_mb ,
11091108 stride_mq ,
11101109 stride_mk ,
11111110 sm_scale : tl .float32 ,
1111+ phi : tl .float32 ,
11121112 chunk_size ,
11131113 HAS_MASK : tl .constexpr ,
11141114 BLOCK_N : tl .constexpr ,
@@ -1138,7 +1138,7 @@ def _sdpa_decode_splitk_kernel(
11381138 )
11391139 q = tl .load (q_ptrs , mask = g_valid [:, None ], other = 0.0 ).to (tl .bfloat16 )
11401140
1141- m_i = tl . full ([ BLOCK_G ], - float ( "inf" ), dtype = tl . float32 )
1141+ # FlashDecoding++ async softmax: use unified max phi instead of tracking m_i
11421142 l_i = tl .zeros ([BLOCK_G ], dtype = tl .float32 )
11431143 acc = tl .zeros ([BLOCK_G , HEAD_DIM ], dtype = tl .float32 )
11441144
@@ -1175,15 +1175,10 @@ def _sdpa_decode_splitk_kernel(
11751175 mask_block , qk , tl .full (qk .shape , - float ("inf" ), dtype = tl .float32 )
11761176 )
11771177
1178- # Online softmax update
1179- m_ij = tl .maximum (m_i , tl .max (qk , axis = 1 ).to (tl .float32 ))
1180- safe_diff = tl .where (
1181- m_ij [:, None ] > - float ("inf" ), qk - m_ij [:, None ], - float ("inf" )
1182- )
1178+ # FlashDecoding++ async softmax: subtract unified phi instead of local max
1179+ safe_diff = tl .where (qk > - float ("inf" ), qk - phi , - float ("inf" ))
11831180 p_f32 = tl .exp (safe_diff ).to (tl .float32 )
11841181 l_ij = tl .sum (p_f32 , axis = 1 ).to (tl .float32 )
1185- safe_alpha_diff = tl .where (m_ij > - float ("inf" ), m_i - m_ij , 0.0 )
1186- alpha = tl .exp (safe_alpha_diff ).to (tl .float32 )
11871182
11881183 v_ptrs = V_ptr + (
11891184 b * stride_vb
@@ -1194,9 +1189,8 @@ def _sdpa_decode_splitk_kernel(
11941189 v = tl .load (v_ptrs , mask = n_valid [:, None ], other = 0.0 ).to (tl .bfloat16 )
11951190
11961191 p_bf16 = p_f32 .to (tl .bfloat16 )
1197- acc = (acc * alpha [:, None ] + tl .dot (p_bf16 , v )).to (tl .float32 )
1198- l_i = (l_i * alpha + l_ij ).to (tl .float32 )
1199- m_i = m_ij
1192+ acc = (acc + tl .dot (p_bf16 , v )).to (tl .float32 )
1193+ l_i = (l_i + l_ij ).to (tl .float32 )
12001194
12011195 # Store partial results for valid groups only
12021196 h_q_all = h_kv * NUM_GROUPS + offs_g # [BLOCK_G]
@@ -1208,31 +1202,25 @@ def _sdpa_decode_splitk_kernel(
12081202 )
12091203 tl .store (o_ptrs , acc , mask = g_valid [:, None ])
12101204
1211- ml_ptrs = M_partial_ptr + (
1212- split_id * stride_mp_s + b * stride_mp_b + h_q_all * stride_mp_h
1213- )
1214- tl .store (ml_ptrs , m_i , mask = g_valid )
1215-
12161205 ll_ptrs = L_partial_ptr + (
1217- split_id * stride_mp_s + b * stride_mp_b + h_q_all * stride_mp_h
1206+ split_id * stride_lp_s + b * stride_lp_b + h_q_all * stride_lp_h
12181207 )
12191208 tl .store (ll_ptrs , l_i , mask = g_valid )
12201209
12211210
12221211@triton .jit
12231212def _sdpa_decode_reduce_kernel (
12241213 O_partial_ptr ,
1225- M_partial_ptr ,
12261214 L_partial_ptr ,
12271215 O_ptr ,
12281216 num_splits ,
12291217 stride_op_s ,
12301218 stride_op_b ,
12311219 stride_op_h ,
12321220 stride_op_d ,
1233- stride_mp_s ,
1234- stride_mp_b ,
1235- stride_mp_h ,
1221+ stride_lp_s ,
1222+ stride_lp_b ,
1223+ stride_lp_h ,
12361224 stride_ob ,
12371225 stride_oh ,
12381226 stride_om ,
@@ -1242,40 +1230,25 @@ def _sdpa_decode_reduce_kernel(
12421230 pid = tl .program_id (axis = 0 )
12431231 offs_d = tl .arange (0 , HEAD_DIM )
12441232
1245- # pid indexes into flattened (B, H_q). Partial buffers are allocated
1246- # contiguous in _launch_decode_splitk, so pid * stride_*_h is valid.
1247- # Find global max across all splits
1248- m_global = tl .full ([1 ], - float ("inf" ), dtype = tl .float32 )
1249- for s in tl .range (0 , num_splits ):
1250- m_ptr = M_partial_ptr + s * stride_mp_s + pid * stride_mp_h
1251- m_s = tl .load (m_ptr )
1252- m_global = tl .maximum (m_global , m_s )
1253-
1254- # Accumulate rescaled outputs
1233+ # FlashDecoding++ async softmax: no rescaling needed, just sum partials
12551234 acc = tl .zeros ([HEAD_DIM ], dtype = tl .float32 )
12561235 l_global = tl .zeros ([1 ], dtype = tl .float32 )
1236+
12571237 for s in tl .range (0 , num_splits ):
1258- m_ptr = M_partial_ptr + s * stride_mp_s + pid * stride_mp_h
1259- l_ptr = L_partial_ptr + s * stride_mp_s + pid * stride_mp_h
1238+ l_ptr = L_partial_ptr + s * stride_lp_s + pid * stride_lp_h
12601239 o_ptrs = O_partial_ptr + (
12611240 s * stride_op_s + pid * stride_op_h + offs_d * stride_op_d
12621241 )
12631242
1264- m_s = tl .load (m_ptr )
12651243 l_s = tl .load (l_ptr )
12661244 o_s = tl .load (o_ptrs )
12671245
1268- safe_diff = tl .where (m_global > - float ("inf" ), m_s - m_global , 0.0 )
1269- scale = tl .exp (safe_diff ).to (tl .float32 )
1270- acc += o_s * scale
1271- l_global += l_s * scale
1246+ acc += o_s
1247+ l_global += l_s
12721248
12731249 inv_l = tl .where (l_global > 0 , 1.0 / l_global , 0.0 )
12741250 acc = acc * inv_l
12751251
1276- # pid = b*H_q + h_q. For contiguous output [B, H_q, 1, D] with L_q=1,
1277- # stride_ob == H_q * stride_oh, so pid * stride_oh is correct.
1278- # This relies on `out` being freshly allocated and contiguous.
12791252 o_out_ptrs = O_ptr + pid * stride_oh + offs_d * stride_od
12801253 tl .store (o_out_ptrs , acc .to (tl .bfloat16 ))
12811254
@@ -1297,16 +1270,14 @@ def _launch_decode_splitk(
12971270 stride_mq : int ,
12981271 stride_mk : int ,
12991272 num_groups : int ,
1273+ phi : float ,
13001274) -> None :
13011275 num_splits = min (max (triton .cdiv (L_kv , 256 ), 1 ), 128 )
13021276 chunk_size = triton .cdiv (L_kv , num_splits )
13031277
13041278 O_partial = torch .empty (
13051279 (num_splits , B , H_q , D ), device = query .device , dtype = torch .float32
13061280 )
1307- M_partial = torch .full (
1308- (num_splits , B , H_q ), - float ("inf" ), device = query .device , dtype = torch .float32
1309- )
13101281 L_partial = torch .zeros (
13111282 (num_splits , B , H_q ), device = query .device , dtype = torch .float32
13121283 )
@@ -1316,15 +1287,14 @@ def _launch_decode_splitk(
13161287 stride_vb , stride_vh , stride_vn , stride_vd = value .stride ()
13171288 stride_ob , stride_oh , stride_om , stride_od = out .stride ()
13181289 stride_op_s , stride_op_b , stride_op_h , stride_op_d = O_partial .stride ()
1319- stride_mp_s , stride_mp_b , stride_mp_h = M_partial .stride ()
1290+ stride_lp_s , stride_lp_b , stride_lp_h = L_partial .stride ()
13201291
13211292 grid_split = (num_splits , B * H_kv )
13221293 wrap_triton (_sdpa_decode_splitk_kernel )[grid_split ](
13231294 query ,
13241295 key ,
13251296 value ,
13261297 O_partial ,
1327- M_partial ,
13281298 L_partial ,
13291299 Mask_ptr if HAS_MASK else 0 ,
13301300 B ,
@@ -1346,13 +1316,14 @@ def _launch_decode_splitk(
13461316 stride_op_b ,
13471317 stride_op_h ,
13481318 stride_op_d ,
1349- stride_mp_s ,
1350- stride_mp_b ,
1351- stride_mp_h ,
1319+ stride_lp_s ,
1320+ stride_lp_b ,
1321+ stride_lp_h ,
13521322 stride_mb ,
13531323 stride_mq ,
13541324 stride_mk ,
13551325 sm_scale ,
1326+ phi ,
13561327 chunk_size ,
13571328 HAS_MASK = HAS_MASK ,
13581329 HEAD_DIM = D ,
@@ -1363,17 +1334,16 @@ def _launch_decode_splitk(
13631334 grid_reduce = (B * H_q ,)
13641335 wrap_triton (_sdpa_decode_reduce_kernel )[grid_reduce ](
13651336 O_partial ,
1366- M_partial ,
13671337 L_partial ,
13681338 out ,
13691339 num_splits ,
13701340 stride_op_s ,
13711341 stride_op_b ,
13721342 stride_op_h ,
13731343 stride_op_d ,
1374- stride_mp_s ,
1375- stride_mp_b ,
1376- stride_mp_h ,
1344+ stride_lp_s ,
1345+ stride_lp_b ,
1346+ stride_lp_h ,
13771347 stride_ob ,
13781348 stride_oh ,
13791349 stride_om ,
@@ -1394,9 +1364,13 @@ def sdpa_decode_splitk(
13941364 is_causal : bool = False ,
13951365 scale : float = 0.0 ,
13961366 enable_gqa : bool = False ,
1367+ phi : float = 5.0 ,
13971368) -> torch .Tensor :
13981369 """Split-K flash-decoding SDPA for L_q=1 (decode step).
13991370
1371+ Uses FlashDecoding++ async softmax with unified maximum value (phi)
1372+ to eliminate per-split max tracking and cross-split rescaling.
1373+
14001374 Signature mirrors sdpa() for drop-in use with torch.cond dispatch.
14011375 enable_gqa is accepted but ignored — GQA is handled natively via
14021376 H_q // H_kv grouping; no packed-GQA tradeoff exists at L_q=1.
@@ -1452,6 +1426,7 @@ def sdpa_decode_splitk(
14521426 stride_mq ,
14531427 stride_mk ,
14541428 num_groups ,
1429+ phi ,
14551430 )
14561431 return out
14571432
@@ -1466,6 +1441,7 @@ def _sdpa_decode_splitk_abstract(
14661441 is_causal : bool = False ,
14671442 scale : float = 0.0 ,
14681443 enable_gqa : bool = False ,
1444+ phi : float = 5.0 ,
14691445) -> torch .Tensor :
14701446 assert query .dtype == key .dtype == value .dtype , "Q, K, V must have the same dtype"
14711447 B , H_q , L_q , D = query .shape
0 commit comments