@@ -83,11 +83,14 @@ __global__ void GetMaxLenKernel(const int* seq_lens_decoder,
8383 }
8484}
8585
86- template <int min_chunk_size, uint32_t block_size>
86+ template <int min_chunk_size,
87+ int chunk_step,
88+ uint32_t block_size,
89+ int max_chunk_size>
8790__global__ void config_decode_attn (const int * __restrict__ seq_lens_this_time,
8891 const int * __restrict__ seq_lens_encoder,
8992 const int * __restrict__ seq_lens_decoder,
90- int * __restrict__ block_indices,
93+ int4 * __restrict__ block_indices,
9194 int * __restrict__ num_blocks,
9295 int * __restrict__ chunk_size,
9396 const int bsz,
@@ -96,166 +99,168 @@ __global__ void config_decode_attn(const int* __restrict__ seq_lens_this_time,
9699 const int q_tile_size,
97100 const int max_tokens_per_batch,
98101 const int config_gridx) {
99- // one block one warp
100102 const int tid = threadIdx .x , wid = threadIdx .y ;
101103 const uint32_t warp_size = blockDim .x ;
102104 __shared__ int num_block_all_shared[block_size];
103105 __shared__ int chunk_size_res[1 ];
104- __shared__ int use_scheme_e_res[1 ];
105106
106107 const int lane_id = tid + wid * warp_size;
107108
108- // Step 1: compute num_block_all WITHOUT chunk splitting (Scheme E)
109+ // Merged Step 1+2: single bsz loop computing both Scheme E metrics and
110+ // split-KV block counts per lane. Avoids redundant seq_lens reads and
111+ // shared intermediate values (token_num, kv_len, q_tile_num).
112+ const int target_blocks = config_gridx / 3 ; // sm_count * 3
113+ // Each lane evaluates one candidate chunk_size:
114+ // min(min_chunk_size + lane_id * chunk_step, max_chunk_size).
115+ const int cur_chunk_size =
116+ min (min_chunk_size + lane_id * chunk_step, max_chunk_size);
109117 int num_block_no_chunk = 0 ;
118+ int max_kv_len_no_chunk = 0 ;
119+ int num_block_all = 0 ;
110120 for (int bid = 0 ; bid < bsz; bid++) {
111121 if (seq_lens_this_time[bid] <= 0 || seq_lens_encoder[bid] > 0 ) {
112122 continue ;
113123 }
114- int token_num_cur_batch = seq_lens_this_time[bid];
115- int q_tile_num = div_up (token_num_cur_batch * group_size, q_tile_size);
124+ const int token_num_cur_batch = seq_lens_this_time[bid];
125+ const int kv_len_cur_batch = seq_lens_decoder[bid] + token_num_cur_batch;
126+ const int q_tile_num =
127+ div_up (token_num_cur_batch * group_size, q_tile_size);
116128 num_block_no_chunk += q_tile_num * kv_num_heads;
129+ max_kv_len_no_chunk = max (max_kv_len_no_chunk, kv_len_cur_batch);
130+ const int kv_chunk_num = div_up (kv_len_cur_batch, cur_chunk_size);
131+ num_block_all += q_tile_num * kv_chunk_num * kv_num_heads;
117132 }
133+ num_block_all_shared[lane_id] = num_block_all;
134+ __syncthreads ();
118135
119- // Step 2: decide mode — Scheme E if enough blocks, else split-kv
120- // Adaptive strategy: prefer Scheme E (zero merge overhead) when blocks
121- // already fill all SMs. When splitting is needed, use the LARGEST
122- // chunk_size that still creates enough blocks to fill SMs, minimizing
123- // merge count while ensuring SM utilization.
124- // Target: at least sm_count*3 blocks to ensure 3+ waves for GPU utilization.
125- // Too few waves (e.g. 1 waves with target=sm_count*3) leaves SMs idle between
126- // waves; 3 waves is a balanced tradeoff between utilization and merge
127- // overhead.
128- const int target_blocks = config_gridx / 2 ; // sm_count * 3
129- const bool use_scheme_e = (num_block_no_chunk >= target_blocks);
130-
131- if (use_scheme_e) {
132- // Scheme E: no chunk splitting, chunk_size = INT_MAX
133- if (tid == 0 && wid == 0 ) {
134- num_blocks[0 ] = num_block_no_chunk;
135- chunk_size[0 ] = INT_MAX;
136- chunk_size_res[0 ] = INT_MAX;
137- use_scheme_e_res[0 ] = 1 ;
138- }
139- } else {
140- // Split-kv: find the LARGEST chunk_size whose total blocks >= target_blocks
141- // This minimizes merge count while ensuring SM utilization.
142- int cur_chunk_size = min_chunk_size * (lane_id + 1 );
143- int num_block_all = 0 ;
144- for (int bid = 0 ; bid < bsz; bid++) {
145- if (seq_lens_this_time[bid] <= 0 || seq_lens_encoder[bid] > 0 ) {
146- continue ;
136+ // Step 3: find best chunk_size, then decide Scheme E vs split-KV
137+ if (tid == 0 && wid == 0 ) {
138+ // Strategy:
139+ // 1. Must fill target_blocks (= config_gridx / 3) to maintain SM concurrency
140+ // 2. Among valid choices, prefer minimum per-SM max KV traffic
141+ // (= waves * chunk_size, since kernel time = slowest SM)
142+ // 3. Within 25% of minimum KV traffic, prefer larger chunk_size
143+ int chunk_size_best = min_chunk_size;
144+ int num_block_all_best = num_block_all_shared[0 ];
145+ // Step 1: find minimum kv_traffic among chunk_sizes that fill SMs
146+ int64_t kv_traffic_min = INT64_MAX;
147+ for (int i = 0 ; i < static_cast <int >(block_size); i++) {
148+ const int nb = num_block_all_shared[i];
149+ if (nb < target_blocks) continue ;
150+ const int cs = min (min_chunk_size + i * chunk_step, max_chunk_size);
151+ const int w = div_up (nb, target_blocks);
152+ const int64_t kv_traffic = static_cast <int64_t >(w) * cs;
153+ if (kv_traffic < kv_traffic_min) {
154+ kv_traffic_min = kv_traffic;
147155 }
148- int token_num_cur_batch = seq_lens_this_time[bid];
149- int kv_len_cur_batch = seq_lens_decoder[bid] + token_num_cur_batch;
150- int q_tile_num = div_up (token_num_cur_batch * group_size, q_tile_size);
151- int kv_chunk_num = div_up (kv_len_cur_batch, cur_chunk_size);
152- num_block_all += q_tile_num * kv_chunk_num * kv_num_heads;
153156 }
154- num_block_all_shared[lane_id] = num_block_all;
155- __syncthreads ();
156-
157- int chunk_size_best;
158- int num_block_all_best;
159- if (tid == 0 && wid == 0 ) {
160- // Search from largest chunk_size to smallest:
161- // pick the first (largest) chunk_size with enough blocks
162- chunk_size_best = min_chunk_size; // fallback: smallest chunk
157+ // Step 2: if no chunk_size fills SMs, fall back to smallest
158+ if (kv_traffic_min == INT64_MAX) {
159+ chunk_size_best = min_chunk_size;
163160 num_block_all_best = num_block_all_shared[0 ];
161+ } else {
162+ // Step 3: scan from largest chunk_size downward; accept the first
163+ // one that fills SMs AND has kv_traffic within 25% of minimum
164164 for (int i = block_size - 1 ; i >= 0 ; i--) {
165- if (num_block_all_shared[i] >= target_blocks) {
166- chunk_size_best = min_chunk_size * (i + 1 );
167- num_block_all_best = num_block_all_shared[i];
165+ const int nb = num_block_all_shared[i];
166+ if (nb < target_blocks) continue ;
167+ const int cs = min (min_chunk_size + i * chunk_step, max_chunk_size);
168+ const int w = div_up (nb, target_blocks);
169+ const int64_t kv_traffic = static_cast <int64_t >(w) * cs;
170+ if (kv_traffic <= kv_traffic_min + kv_traffic_min / 4 ) {
171+ chunk_size_best = cs;
172+ num_block_all_best = nb;
168173 break ;
169174 }
170175 }
171- // If even the smallest chunk doesn't reach target_blocks,
172- // use the smallest chunk to maximize parallelism
173- if (num_block_all_best < target_blocks) {
174- chunk_size_best = min_chunk_size;
175- num_block_all_best = num_block_all_shared[0 ];
176+ }
177+
178+ // Decide Scheme E: prefer when blocks fill SMs AND estimated latency
179+ // is no worse than split-KV.
180+ // Scheme E: waves_E * max_kv_len (few heavy blocks)
181+ // Split-KV: waves_split * chunk_size_best (many light blocks)
182+ // When no splitting is needed (num_block_all_best == num_block_no_chunk),
183+ // Scheme E is strictly better (saves merge overhead).
184+ bool use_scheme_e = false ;
185+ if (num_block_no_chunk >= target_blocks) {
186+ if (num_block_all_best == num_block_no_chunk) {
187+ use_scheme_e = true ;
188+ } else {
189+ // target_blocks = sm_count * 3 ≈ CTAs per wave (sm_count × occupancy).
190+ // Using target_blocks as denominator correctly accounts for occupancy
191+ // in wave count estimation.
192+ const int waves_e = div_up (num_block_no_chunk, target_blocks);
193+ const int waves_split = div_up (num_block_all_best, target_blocks);
194+ use_scheme_e = (static_cast <int64_t >(waves_e) * max_kv_len_no_chunk <=
195+ static_cast <int64_t >(waves_split) * chunk_size_best);
176196 }
197+ }
198+
199+ if (use_scheme_e) {
200+ num_blocks[0 ] = num_block_no_chunk;
201+ chunk_size[0 ] = INT_MAX;
202+ chunk_size_res[0 ] = INT_MAX;
203+ } else {
177204 num_blocks[0 ] = num_block_all_best;
178205 chunk_size[0 ] = chunk_size_best;
179206 chunk_size_res[0 ] = chunk_size_best;
180- use_scheme_e_res[0 ] = 0 ;
181207 }
182208 }
183209
184210 __syncthreads ();
185211 if (wid == 0 ) {
186- const bool use_scheme_e_local = use_scheme_e_res[0 ];
187- const int chunk_size_best = chunk_size_res[0 ];
212+ const int chunk_size_final = chunk_size_res[0 ];
188213
189- // one block one warp
190214 int prev_offset = 0 ;
191- // loop on warp tile:[base, base+32)
192215 for (int base = 0 ; base < bsz; base += warp_size) {
193216 const int bid = base + tid;
217+ int num_block_cur = 0 ;
194218 int q_tile_num = 0 ;
195219 int kv_chunk_num = 0 ;
196220
197- // calculate loop_times for bid
198- int num_block_all = 0 ;
199221 if (bid < bsz) {
200222 int token_num_cur_batch = seq_lens_this_time[bid];
201223 if (seq_lens_encoder && seq_lens_encoder[bid] > 0 ) {
202224 token_num_cur_batch = 0 ;
203225 }
204226 q_tile_num = div_up (token_num_cur_batch * group_size, q_tile_size);
205- if (use_scheme_e_local) {
206- num_block_all += q_tile_num * kv_num_heads;
207- } else {
208- int kv_len_cur_batch = seq_lens_decoder[bid] + token_num_cur_batch;
209- kv_chunk_num = div_up (kv_len_cur_batch, chunk_size_best);
210- num_block_all += q_tile_num * kv_chunk_num * kv_num_heads;
211- }
227+ const int kv_len_cur_batch =
228+ seq_lens_decoder[bid] + token_num_cur_batch;
229+ kv_chunk_num = div_up (kv_len_cur_batch, chunk_size_final);
230+ num_block_cur = q_tile_num * kv_chunk_num * kv_num_heads;
212231 }
213232
214- // prefix sum for each lane, get the start offset in this tile
215- // inclusive scan
216- int x = num_block_all;
233+ // inclusive prefix sum
234+ int x = num_block_cur;
217235 for (int offset = 1 ; offset < warp_size; offset <<= 1 ) {
218236 int y = __shfl_up_sync (0xffffffff , x, offset);
219237 if (tid >= offset) x += y;
220238 }
221- // exclusive prefix sum
222- int bid_offset = x - num_block_all;
239+ int bid_offset = x - num_block_cur;
223240 int tile_sum = __shfl_sync (0xffffffff , x, warp_size - 1 );
224241
225- // write batch_ids and tile_ids_per_batch
226- if (bid < bsz && num_block_all > 0 ) {
227- int write_base = prev_offset + bid_offset;
228- if (use_scheme_e_local) {
229- for (int kv_head_id = 0 ; kv_head_id < kv_num_heads; kv_head_id++) {
242+ // Write block_indices using int4 vectorized stores.
243+ // Each entry is exactly 4 ints (bid, kv_head_id, kv_chunk_id, q_tile_id),
244+ // matching int4 layout. This reduces 4 scalar stores to 1 vector store.
245+ if (bid < bsz && num_block_cur > 0 ) {
246+ int4 * write_ptr = block_indices + prev_offset + bid_offset;
247+ int flat_idx = 0 ;
248+ const int kv_chunk_num_x_q_tile_num = kv_chunk_num * q_tile_num;
249+ #pragma unroll 2
250+ for (int kv_head_id = 0 ; kv_head_id < kv_num_heads; kv_head_id++) {
251+ const int head_base = kv_head_id * kv_chunk_num_x_q_tile_num;
252+ #pragma unroll 2
253+ for (int kv_chunk_id = 0 ; kv_chunk_id < kv_chunk_num; kv_chunk_id++) {
254+ const int chunk_base = head_base + kv_chunk_id * q_tile_num;
255+ #pragma unroll
230256 for (int q_tile_id = 0 ; q_tile_id < q_tile_num; q_tile_id++) {
231- int idx =
232- write_base * 4 + (kv_head_id * q_tile_num + q_tile_id) * 4 ;
233- block_indices[idx] = bid;
234- block_indices[idx + 1 ] = kv_head_id;
235- block_indices[idx + 2 ] = 0 ;
236- block_indices[idx + 3 ] = q_tile_id;
237- }
238- }
239- } else {
240- for (int kv_head_id = 0 ; kv_head_id < kv_num_heads; kv_head_id++) {
241- for (int kv_chunk_id = 0 ; kv_chunk_id < kv_chunk_num;
242- kv_chunk_id++) {
243- for (int q_tile_id = 0 ; q_tile_id < q_tile_num; q_tile_id++) {
244- int idx =
245- write_base * 4 +
246- ((kv_head_id * kv_chunk_num + kv_chunk_id) * q_tile_num +
247- q_tile_id) *
248- 4 ;
249- block_indices[idx] = bid;
250- block_indices[idx + 1 ] = kv_head_id;
251- block_indices[idx + 2 ] = kv_chunk_id;
252- block_indices[idx + 3 ] = q_tile_id;
253- }
257+ write_ptr[flat_idx] =
258+ make_int4 (bid, kv_head_id, kv_chunk_id, q_tile_id);
259+ flat_idx++;
254260 }
255261 }
256262 }
257263 }
258- // for next warp tile
259264 prev_offset += tile_sum;
260265 }
261266 }
@@ -318,12 +323,16 @@ void ConfigForAttention(
318323
319324 const int q_tile_size = 16 ;
320325 dim3 blocks (32 , 4 );
326+ // Cast block_indices to int4* for vectorized stores. Each entry is 4 ints
327+ // = 16 bytes = sizeof(int4), so block_num * 4 ints = block_num int4s.
328+ // NOTE(review): int4 stores also need a 16-byte-aligned buffer; confirm.
329+ int4 * block_indices_i4 = reinterpret_cast <int4 *>(block_indices.data <int >());
321330 if (cache_quant_type == " cache_int4_zp" ) {
322- config_decode_attn<256 , 128 >
331+ config_decode_attn<512 , 256 , 128 , 32768 >
323332 <<<1 , blocks, 0 , stream>>> (seq_lens_this_time.data <int >(),
324333 seq_lens_encoder.data <int >(),
325334 seq_lens_decoder.data <int >(),
326- block_indices. data < int >() ,
335+ block_indices_i4 ,
327336 num_blocks.data <int >(),
328337 chunk_size.data <int >(),
329338 bsz,
@@ -333,11 +342,11 @@ void ConfigForAttention(
333342 max_tokens_per_batch,
334343 config_gridx);
335344 } else {
336- config_decode_attn<128 , 128 >
345+ config_decode_attn<512 , 128 , 128 , 16384 >
337346 <<<1 , blocks, 0 , stream>>> (seq_lens_this_time.data <int >(),
338347 seq_lens_encoder.data <int >(),
339348 seq_lens_decoder.data <int >(),
340- block_indices. data < int >() ,
349+ block_indices_i4 ,
341350 num_blocks.data <int >(),
342351 chunk_size.data <int >(),
343352 bsz,
0 commit comments