@@ -97,26 +97,6 @@ def round_up(a, b):
9797 return nopack_eff < 0.9 * pack_eff
9898
9999
def _compute_num_splits(L_kv: int, B: int, H_kv: int, device: torch.device) -> int:
    """Choose the number of KV splits for the flash-decoding split-K launch.

    The result is the smallest of three limits:

    * Occupancy: enough splits that ``splits * B * H_kv`` reaches at least
      twice the SM count reported for ``device`` (two full waves, so the
      GPU stays busy despite tail effects).
    * Work per split: no more than one split per 64 KV tokens
      (``L_kv // 64``), so launch and reduce overhead is amortised.
    * A fixed ceiling of 128 splits, bounding the cost of the reduce step.

    The occupancy and work limits are each clamped to at least 1, so the
    return value always lies in ``[1, 128]``.

    The SM count is read at call time through
    ``torch.cuda.get_device_properties``, so the heuristic follows whatever
    GPU ``device`` refers to (for example 108 SMs on A100, 128 on RTX 4090).
    """
    props = torch.cuda.get_device_properties(device)
    wanted_ctas = props.multi_processor_count * 2

    # Every split contributes B * H_kv programs; never divide by zero.
    programs_per_split = max(B * H_kv, 1)
    occupancy_limit = max(triton.cdiv(wanted_ctas, programs_per_split), 1)

    # Keep at least 64 KV tokens in each split.
    work_limit = max(L_kv // 64, 1)

    hard_cap = 128
    return min(occupancy_limit, work_limit, hard_cap)
118-
119-
120100def _validate_qkv_shapes (
121101 query : torch .Tensor ,
122102 key : torch .Tensor ,
@@ -1091,8 +1071,6 @@ def _sdpa_abstract(
10911071 triton .Config ({"BLOCK_N" : 128 }, num_warps = 8 , num_stages = 2 ),
10921072 triton .Config ({"BLOCK_N" : 256 }, num_warps = 4 , num_stages = 2 ),
10931073 triton .Config ({"BLOCK_N" : 256 }, num_warps = 8 , num_stages = 2 ),
1094- triton .Config ({"BLOCK_N" : 256 }, num_warps = 8 , num_stages = 3 ),
1095- triton .Config ({"BLOCK_N" : 128 }, num_warps = 8 , num_stages = 3 ),
10961074 ],
10971075 key = ["Lk" , "HEAD_DIM" , "NUM_GROUPS" , "HAS_MASK" ],
10981076)
@@ -1129,24 +1107,19 @@ def _sdpa_decode_splitk_kernel(
11291107 stride_mb ,
11301108 stride_mq ,
11311109 stride_mk ,
1132- sm_scale_log2 : tl .float32 ,
1133- phi_log2 : tl .float32 ,
1110+ sm_scale : tl .float32 ,
1111+ phi : tl .float32 ,
11341112 chunk_size ,
11351113 HAS_MASK : tl .constexpr ,
11361114 BLOCK_N : tl .constexpr ,
11371115 HEAD_DIM : tl .constexpr ,
11381116 NUM_GROUPS : tl .constexpr ,
11391117 BLOCK_G : tl .constexpr ,
1140- BATCH_ONE : tl .constexpr ,
11411118):
11421119 split_id = tl .program_id (axis = 0 )
11431120 pid_bh = tl .program_id (axis = 1 )
1144- if BATCH_ONE :
1145- b = 0
1146- h_kv = pid_bh
1147- else :
1148- b = pid_bh // H_kv
1149- h_kv = pid_bh % H_kv
1121+ b = pid_bh // H_kv
1122+ h_kv = pid_bh % H_kv
11501123
11511124 start_n = split_id * chunk_size
11521125 end_n = tl .minimum (start_n + chunk_size , Lk )
@@ -1163,11 +1136,9 @@ def _sdpa_decode_splitk_kernel(
11631136 + 0 * stride_qm
11641137 + offs_d [None , :] * stride_qd
11651138 )
1166- q = tl .load (q_ptrs , mask = g_valid [:, None ], other = 0.0 )
1167- # Pre-scale Q so the inner loop avoids a per-element multiply on [G,N] QK
1168- q = (q .to (tl .float32 ) * sm_scale_log2 ).to (tl .bfloat16 )
1139+ q = tl .load (q_ptrs , mask = g_valid [:, None ], other = 0.0 ).to (tl .bfloat16 )
11691140
1170- # FlashDecoding++ async softmax with exp2: all scores in log2 space
1141+ # FlashDecoding++ async softmax: use unified max phi instead of tracking m_i
11711142 l_i = tl .zeros ([BLOCK_G ], dtype = tl .float32 )
11721143 acc = tl .zeros ([BLOCK_G , HEAD_DIM ], dtype = tl .float32 )
11731144
@@ -1185,8 +1156,8 @@ def _sdpa_decode_splitk_kernel(
11851156 )
11861157 k = tl .load (k_ptrs , mask = n_valid [:, None ], other = 0.0 ).to (tl .bfloat16 )
11871158
1188- # QK: [BLOCK_G, BLOCK_N] — Q already scaled, result in log2 space
1189- qk = tl .dot (q , tl .trans (k )).to (tl .float32 )
1159+ # QK: [BLOCK_G, BLOCK_N]
1160+ qk = ( tl .dot (q , tl .trans (k )). to ( tl . float32 ) * sm_scale ).to (tl .float32 )
11901161
11911162 # Mask out-of-bounds KV positions
11921163 qk = tl .where (
@@ -1204,9 +1175,9 @@ def _sdpa_decode_splitk_kernel(
12041175 mask_block , qk , tl .full (qk .shape , - float ("inf" ), dtype = tl .float32 )
12051176 )
12061177
1207- # FlashDecoding++ async softmax: exp2 maps to single PTX ex2 instruction
1208- safe_diff = tl .where (qk > - float ("inf" ), qk - phi_log2 , - float ("inf" ))
1209- p_f32 = tl .math . exp2 (safe_diff ).to (tl .float32 )
1178+ # FlashDecoding++ async softmax: subtract unified phi instead of local max
1179+ safe_diff = tl .where (qk > - float ("inf" ), qk - phi , - float ("inf" ))
1180+ p_f32 = tl .exp (safe_diff ).to (tl .float32 )
12101181 l_ij = tl .sum (p_f32 , axis = 1 ).to (tl .float32 )
12111182
12121183 v_ptrs = V_ptr + (
@@ -1263,7 +1234,7 @@ def _sdpa_decode_reduce_kernel(
12631234 acc = tl .zeros ([HEAD_DIM ], dtype = tl .float32 )
12641235 l_global = tl .zeros ([1 ], dtype = tl .float32 )
12651236
1266- for s in tl .range (0 , num_splits , num_stages = 2 ):
1237+ for s in tl .range (0 , num_splits ):
12671238 l_ptr = L_partial_ptr + s * stride_lp_s + pid * stride_lp_h
12681239 o_ptrs = O_partial_ptr + (
12691240 s * stride_op_s + pid * stride_op_h + offs_d * stride_op_d
@@ -1282,6 +1253,9 @@ def _sdpa_decode_reduce_kernel(
12821253 tl .store (o_out_ptrs , acc .to (tl .bfloat16 ))
12831254
12841255
1256+ _splitk_buf_cache : dict = {}
1257+
1258+
12851259def _launch_decode_splitk (
12861260 query : torch .Tensor ,
12871261 key : torch .Tensor ,
@@ -1301,19 +1275,23 @@ def _launch_decode_splitk(
13011275 num_groups : int ,
13021276 phi : float ,
13031277) -> None :
1304- num_splits = _compute_num_splits ( L_kv , B , H_kv , query . device )
1278+ num_splits = min ( max ( triton . cdiv ( L_kv , 256 ), 1 ), 128 )
13051279 chunk_size = triton .cdiv (L_kv , num_splits )
13061280
1307- _LOG2E = 1.4426950408889634
1308- sm_scale_log2 = sm_scale * _LOG2E
1309- phi_log2 = phi * _LOG2E
1310-
1311- O_partial = torch .empty (
1312- (num_splits , B , H_q , D ), device = query .device , dtype = torch .float32
1313- )
1314- L_partial = torch .zeros (
1315- (num_splits , B , H_q ), device = query .device , dtype = torch .float32
1316- )
1281+ # Cache partial buffers to avoid CUDA allocator overhead per call.
1282+ # The split kernel fully writes every entry before the reduce kernel
1283+ # reads, so stale data from a previous call is harmless.
1284+ buf_key = (num_splits , B , H_q , D , query .device .index )
1285+ bufs = _splitk_buf_cache .get (buf_key )
1286+ if bufs is None :
1287+ bufs = (
1288+ torch .empty (
1289+ (num_splits , B , H_q , D ), device = query .device , dtype = torch .float32
1290+ ),
1291+ torch .empty ((num_splits , B , H_q ), device = query .device , dtype = torch .float32 ),
1292+ )
1293+ _splitk_buf_cache [buf_key ] = bufs
1294+ O_partial , L_partial = bufs
13171295
13181296 stride_qb , stride_qh , stride_qm , stride_qd = query .stride ()
13191297 stride_kb , stride_kh , stride_kn , stride_kd = key .stride ()
@@ -1355,14 +1333,13 @@ def _launch_decode_splitk(
13551333 stride_mb ,
13561334 stride_mq ,
13571335 stride_mk ,
1358- sm_scale_log2 ,
1359- phi_log2 ,
1336+ sm_scale ,
1337+ phi ,
13601338 chunk_size ,
13611339 HAS_MASK = HAS_MASK ,
13621340 HEAD_DIM = D ,
13631341 NUM_GROUPS = num_groups ,
13641342 BLOCK_G = _next_power_of_2_unclamped (num_groups ),
1365- BATCH_ONE = B == 1 ,
13661343 )
13671344
13681345 grid_reduce = (B * H_q ,)
0 commit comments