@@ -1080,7 +1080,6 @@ def _sdpa_decode_splitk_kernel(
     K_ptr,
     V_ptr,
     O_partial_ptr,
-    M_partial_ptr,
     L_partial_ptr,
     Mask_ptr,
     B,
@@ -1102,13 +1101,14 @@ def _sdpa_decode_splitk_kernel(
     stride_op_b,
     stride_op_h,
     stride_op_d,
-    stride_mp_s,
-    stride_mp_b,
-    stride_mp_h,
+    stride_lp_s,
+    stride_lp_b,
+    stride_lp_h,
     stride_mb,
     stride_mq,
     stride_mk,
     sm_scale: tl.float32,
+    phi: tl.float32,
     chunk_size,
     HAS_MASK: tl.constexpr,
     BLOCK_N: tl.constexpr,
@@ -1138,7 +1138,7 @@ def _sdpa_decode_splitk_kernel(
     )
     q = tl.load(q_ptrs, mask=g_valid[:, None], other=0.0).to(tl.bfloat16)
 
-    m_i = tl.full([BLOCK_G], -float("inf"), dtype=tl.float32)
+    # FlashDecoding++ async softmax: use unified max phi instead of tracking m_i
     l_i = tl.zeros([BLOCK_G], dtype=tl.float32)
     acc = tl.zeros([BLOCK_G, HEAD_DIM], dtype=tl.float32)
 
@@ -1175,15 +1175,12 @@ def _sdpa_decode_splitk_kernel(
             mask_block, qk, tl.full(qk.shape, -float("inf"), dtype=tl.float32)
         )
 
-        # Online softmax update
-        m_ij = tl.maximum(m_i, tl.max(qk, axis=1).to(tl.float32))
+        # FlashDecoding++ async softmax: subtract unified phi instead of local max
         safe_diff = tl.where(
-            m_ij[:, None] > -float("inf"), qk - m_ij[:, None], -float("inf")
+            qk > -float("inf"), qk - phi, -float("inf")
         )
         p_f32 = tl.exp(safe_diff).to(tl.float32)
         l_ij = tl.sum(p_f32, axis=1).to(tl.float32)
-        safe_alpha_diff = tl.where(m_ij > -float("inf"), m_i - m_ij, 0.0)
-        alpha = tl.exp(safe_alpha_diff).to(tl.float32)
 
         v_ptrs = V_ptr + (
             b * stride_vb
@@ -1194,9 +1191,8 @@ def _sdpa_decode_splitk_kernel(
11941191 v = tl .load (v_ptrs , mask = n_valid [:, None ], other = 0.0 ).to (tl .bfloat16 )
11951192
11961193 p_bf16 = p_f32 .to (tl .bfloat16 )
1197- acc = (acc * alpha [:, None ] + tl .dot (p_bf16 , v )).to (tl .float32 )
1198- l_i = (l_i * alpha + l_ij ).to (tl .float32 )
1199- m_i = m_ij
1194+ acc = (acc + tl .dot (p_bf16 , v )).to (tl .float32 )
1195+ l_i = (l_i + l_ij ).to (tl .float32 )
12001196
12011197 # Store partial results for valid groups only
12021198 h_q_all = h_kv * NUM_GROUPS + offs_g # [BLOCK_G]
@@ -1208,31 +1204,25 @@ def _sdpa_decode_splitk_kernel(
     )
     tl.store(o_ptrs, acc, mask=g_valid[:, None])
 
-    ml_ptrs = M_partial_ptr + (
-        split_id * stride_mp_s + b * stride_mp_b + h_q_all * stride_mp_h
-    )
-    tl.store(ml_ptrs, m_i, mask=g_valid)
-
     ll_ptrs = L_partial_ptr + (
-        split_id * stride_mp_s + b * stride_mp_b + h_q_all * stride_mp_h
+        split_id * stride_lp_s + b * stride_lp_b + h_q_all * stride_lp_h
     )
     tl.store(ll_ptrs, l_i, mask=g_valid)
 
 
 @triton.jit
 def _sdpa_decode_reduce_kernel(
     O_partial_ptr,
-    M_partial_ptr,
     L_partial_ptr,
     O_ptr,
     num_splits,
     stride_op_s,
     stride_op_b,
     stride_op_h,
     stride_op_d,
-    stride_mp_s,
-    stride_mp_b,
-    stride_mp_h,
+    stride_lp_s,
+    stride_lp_b,
+    stride_lp_h,
     stride_ob,
     stride_oh,
     stride_om,
@@ -1242,40 +1232,25 @@ def _sdpa_decode_reduce_kernel(
     pid = tl.program_id(axis=0)
     offs_d = tl.arange(0, HEAD_DIM)
 
-    # pid indexes into flattened (B, H_q). Partial buffers are allocated
-    # contiguous in _launch_decode_splitk, so pid * stride_*_h is valid.
-    # Find global max across all splits
-    m_global = tl.full([1], -float("inf"), dtype=tl.float32)
-    for s in tl.range(0, num_splits):
-        m_ptr = M_partial_ptr + s * stride_mp_s + pid * stride_mp_h
-        m_s = tl.load(m_ptr)
-        m_global = tl.maximum(m_global, m_s)
-
-    # Accumulate rescaled outputs
+    # FlashDecoding++ async softmax: no rescaling needed, just sum partials
     acc = tl.zeros([HEAD_DIM], dtype=tl.float32)
     l_global = tl.zeros([1], dtype=tl.float32)
+
     for s in tl.range(0, num_splits):
-        m_ptr = M_partial_ptr + s * stride_mp_s + pid * stride_mp_h
-        l_ptr = L_partial_ptr + s * stride_mp_s + pid * stride_mp_h
+        l_ptr = L_partial_ptr + s * stride_lp_s + pid * stride_lp_h
         o_ptrs = O_partial_ptr + (
             s * stride_op_s + pid * stride_op_h + offs_d * stride_op_d
         )
 
-        m_s = tl.load(m_ptr)
         l_s = tl.load(l_ptr)
         o_s = tl.load(o_ptrs)
 
-        safe_diff = tl.where(m_global > -float("inf"), m_s - m_global, 0.0)
-        scale = tl.exp(safe_diff).to(tl.float32)
-        acc += o_s * scale
-        l_global += l_s * scale
+        acc += o_s
+        l_global += l_s
 
     inv_l = tl.where(l_global > 0, 1.0 / l_global, 0.0)
     acc = acc * inv_l
 
-    # pid = b*H_q + h_q. For contiguous output [B, H_q, 1, D] with L_q=1,
-    # stride_ob == H_q * stride_oh, so pid * stride_oh is correct.
-    # This relies on `out` being freshly allocated and contiguous.
     o_out_ptrs = O_ptr + pid * stride_oh + offs_d * stride_od
     tl.store(o_out_ptrs, acc.to(tl.bfloat16))
 
@@ -1297,16 +1272,14 @@ def _launch_decode_splitk(
     stride_mq: int,
     stride_mk: int,
     num_groups: int,
+    phi: float,
 ) -> None:
     num_splits = min(max(triton.cdiv(L_kv, 256), 1), 128)
     chunk_size = triton.cdiv(L_kv, num_splits)
 
     O_partial = torch.empty(
         (num_splits, B, H_q, D), device=query.device, dtype=torch.float32
     )
-    M_partial = torch.full(
-        (num_splits, B, H_q), -float("inf"), device=query.device, dtype=torch.float32
-    )
     L_partial = torch.zeros(
         (num_splits, B, H_q), device=query.device, dtype=torch.float32
     )
@@ -1316,15 +1289,14 @@ def _launch_decode_splitk(
     stride_vb, stride_vh, stride_vn, stride_vd = value.stride()
     stride_ob, stride_oh, stride_om, stride_od = out.stride()
     stride_op_s, stride_op_b, stride_op_h, stride_op_d = O_partial.stride()
-    stride_mp_s, stride_mp_b, stride_mp_h = M_partial.stride()
+    stride_lp_s, stride_lp_b, stride_lp_h = L_partial.stride()
 
     grid_split = (num_splits, B * H_kv)
     wrap_triton(_sdpa_decode_splitk_kernel)[grid_split](
         query,
         key,
         value,
         O_partial,
-        M_partial,
         L_partial,
         Mask_ptr if HAS_MASK else 0,
         B,
@@ -1346,13 +1318,14 @@ def _launch_decode_splitk(
         stride_op_b,
         stride_op_h,
         stride_op_d,
-        stride_mp_s,
-        stride_mp_b,
-        stride_mp_h,
+        stride_lp_s,
+        stride_lp_b,
+        stride_lp_h,
         stride_mb,
         stride_mq,
         stride_mk,
         sm_scale,
+        phi,
         chunk_size,
         HAS_MASK=HAS_MASK,
         HEAD_DIM=D,
@@ -1363,17 +1336,16 @@ def _launch_decode_splitk(
     grid_reduce = (B * H_q,)
     wrap_triton(_sdpa_decode_reduce_kernel)[grid_reduce](
         O_partial,
-        M_partial,
         L_partial,
         out,
         num_splits,
         stride_op_s,
         stride_op_b,
         stride_op_h,
         stride_op_d,
-        stride_mp_s,
-        stride_mp_b,
-        stride_mp_h,
+        stride_lp_s,
+        stride_lp_b,
+        stride_lp_h,
         stride_ob,
         stride_oh,
         stride_om,
@@ -1394,9 +1366,13 @@ def sdpa_decode_splitk(
     is_causal: bool = False,
     scale: float = 0.0,
     enable_gqa: bool = False,
+    phi: float = 5.0,
 ) -> torch.Tensor:
     """Split-K flash-decoding SDPA for L_q=1 (decode step).
 
+    Uses FlashDecoding++ async softmax with a unified maximum value (phi)
+    to eliminate per-split max tracking and cross-split rescaling.
+
     Signature mirrors sdpa() for drop-in use with torch.cond dispatch.
     enable_gqa is accepted but ignored — GQA is handled natively via
     H_q // H_kv grouping; no packed-GQA tradeoff exists at L_q=1.
@@ -1452,6 +1428,7 @@ def sdpa_decode_splitk(
         stride_mq,
         stride_mk,
         num_groups,
+        phi,
     )
     return out
 
@@ -1466,6 +1443,7 @@ def _sdpa_decode_splitk_abstract(
     is_causal: bool = False,
     scale: float = 0.0,
     enable_gqa: bool = False,
+    phi: float = 5.0,
 ) -> torch.Tensor:
     assert query.dtype == key.dtype == value.dtype, "Q, K, V must have the same dtype"
     B, H_q, L_q, D = query.shape
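
Why dropping M_partial is safe: softmax is shift-invariant, so normalizing
exp(qk - phi) by its own sum reproduces softmax(qk) exactly for any constant
phi. That identity is what lets each split accumulate independently and lets
the reduce kernel replace the m_global pass and per-split rescaling with two
plain additions. A minimal plain-PyTorch sketch of the identity (illustrative
names and shapes, not the kernel itself):

import torch

torch.manual_seed(0)
L_kv, D, num_splits, phi = 1024, 64, 4, 5.0
qk = torch.randn(L_kv, dtype=torch.float64) * 2.0  # stand-in for sm_scale * q @ k^T
v = torch.randn(L_kv, D, dtype=torch.float64)

# Reference: standard max-subtracted softmax attention for one decode query.
ref = torch.softmax(qk, dim=-1) @ v

# FlashDecoding++-style: every split subtracts the same constant phi, so the
# partial accumulators (O_partial) and normalizers (L_partial) sum directly.
acc = torch.zeros(D, dtype=torch.float64)
l_global = torch.zeros((), dtype=torch.float64)
for qk_s, v_s in zip(qk.chunk(num_splits), v.chunk(num_splits)):
    p = torch.exp(qk_s - phi)  # unified max, not a per-split running max
    acc += p @ v_s             # unnormalized partial output
    l_global += p.sum()        # partial normalizer

assert torch.allclose(acc / l_global, ref)  # identical up to rounding

The split sizing in _launch_decode_splitk is unchanged by this: for
L_kv = 8192, num_splits = min(max(cdiv(8192, 256), 1), 128) = 32 splits of
chunk_size = 256 keys each, and the reduce kernel now just sums their 32
(acc, l) pairs.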
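
The tradeoff is numerical range rather than correctness: with phi fixed at
5.0, float32 exp(qk - phi) overflows once any scaled logit exceeds phi by
roughly 88 (the log of the float32 max), and rows whose logits sit far below
phi underflow toward zero and lose precision. The FlashDecoding++ paper pairs
its unified max with a recomputation fallback for out-of-range values; this
diff adds no such guard, so a check along these lines (a hypothetical helper,
an assumption rather than part of this change) may be worth running on
representative logits:

import math

import torch

def phi_in_safe_range(qk: torch.Tensor, phi: float, headroom: float = 8.0) -> bool:
    """True when exp(qk - phi) stays finite in float32 for every logit.

    float32 overflows near exp(88.7); `headroom` reserves margin for the
    row sum l_i accumulated on top of the individual exponentials.
    """
    exp_max = math.log(torch.finfo(torch.float32).max)  # ~88.72
    return bool((qk.max().item() - phi) < exp_max - headroom)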