Skip to content

Commit b621b28

Browse files
committed
opt memory
1 parent 7cd054c commit b621b28

9 files changed

Lines changed: 250 additions & 181 deletions

File tree

custom_ops/gpu_ops/decode_unified_attention/config_for_attention.cu

Lines changed: 118 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -83,11 +83,14 @@ __global__ void GetMaxLenKernel(const int* seq_lens_decoder,
8383
}
8484
}
8585

86-
template <int min_chunk_size, uint32_t block_size>
86+
template <int min_chunk_size,
87+
int chunk_step,
88+
uint32_t block_size,
89+
int max_chunk_size>
8790
__global__ void config_decode_attn(const int* __restrict__ seq_lens_this_time,
8891
const int* __restrict__ seq_lens_encoder,
8992
const int* __restrict__ seq_lens_decoder,
90-
int* __restrict__ block_indices,
93+
int4* __restrict__ block_indices,
9194
int* __restrict__ num_blocks,
9295
int* __restrict__ chunk_size,
9396
const int bsz,
@@ -96,166 +99,168 @@ __global__ void config_decode_attn(const int* __restrict__ seq_lens_this_time,
9699
const int q_tile_size,
97100
const int max_tokens_per_batch,
98101
const int config_gridx) {
99-
// one block one warp
100102
const int tid = threadIdx.x, wid = threadIdx.y;
101103
const uint32_t warp_size = blockDim.x;
102104
__shared__ int num_block_all_shared[block_size];
103105
__shared__ int chunk_size_res[1];
104-
__shared__ int use_scheme_e_res[1];
105106

106107
const int lane_id = tid + wid * warp_size;
107108

108-
// Step 1: compute num_block_all WITHOUT chunk splitting (Scheme E)
109+
// Merged Step 1+2: single bsz loop computing both Scheme E metrics and
110+
// split-KV block counts per lane. Avoids redundant seq_lens reads and
111+
// shared intermediate values (token_num, kv_len, q_tile_num).
112+
const int target_blocks = config_gridx / 3; // sm_count * 3
113+
// Search chunk_size from 512 with step 128: {512, 640, 768, ...}
114+
115+
const int cur_chunk_size =
116+
min(min_chunk_size + lane_id * chunk_step, max_chunk_size);
109117
int num_block_no_chunk = 0;
118+
int max_kv_len_no_chunk = 0;
119+
int num_block_all = 0;
110120
for (int bid = 0; bid < bsz; bid++) {
111121
if (seq_lens_this_time[bid] <= 0 || seq_lens_encoder[bid] > 0) {
112122
continue;
113123
}
114-
int token_num_cur_batch = seq_lens_this_time[bid];
115-
int q_tile_num = div_up(token_num_cur_batch * group_size, q_tile_size);
124+
const int token_num_cur_batch = seq_lens_this_time[bid];
125+
const int kv_len_cur_batch = seq_lens_decoder[bid] + token_num_cur_batch;
126+
const int q_tile_num =
127+
div_up(token_num_cur_batch * group_size, q_tile_size);
116128
num_block_no_chunk += q_tile_num * kv_num_heads;
129+
max_kv_len_no_chunk = max(max_kv_len_no_chunk, kv_len_cur_batch);
130+
const int kv_chunk_num = div_up(kv_len_cur_batch, cur_chunk_size);
131+
num_block_all += q_tile_num * kv_chunk_num * kv_num_heads;
117132
}
133+
num_block_all_shared[lane_id] = num_block_all;
134+
__syncthreads();
118135

119-
// Step 2: decide mode — Scheme E if enough blocks, else split-kv
120-
// Adaptive strategy: prefer Scheme E (zero merge overhead) when blocks
121-
// already fill all SMs. When splitting is needed, use the LARGEST
122-
// chunk_size that still creates enough blocks to fill SMs, minimizing
123-
// merge count while ensuring SM utilization.
124-
// Target: at least sm_count*3 blocks to ensure 3+ waves for GPU utilization.
125-
// Too few waves (e.g. 1 waves with target=sm_count*3) leaves SMs idle between
126-
// waves; 3 waves is a balanced tradeoff between utilization and merge
127-
// overhead.
128-
const int target_blocks = config_gridx / 2; // sm_count * 3
129-
const bool use_scheme_e = (num_block_no_chunk >= target_blocks);
130-
131-
if (use_scheme_e) {
132-
// Scheme E: no chunk splitting, chunk_size = INT_MAX
133-
if (tid == 0 && wid == 0) {
134-
num_blocks[0] = num_block_no_chunk;
135-
chunk_size[0] = INT_MAX;
136-
chunk_size_res[0] = INT_MAX;
137-
use_scheme_e_res[0] = 1;
138-
}
139-
} else {
140-
// Split-kv: find the LARGEST chunk_size whose total blocks >= target_blocks
141-
// This minimizes merge count while ensuring SM utilization.
142-
int cur_chunk_size = min_chunk_size * (lane_id + 1);
143-
int num_block_all = 0;
144-
for (int bid = 0; bid < bsz; bid++) {
145-
if (seq_lens_this_time[bid] <= 0 || seq_lens_encoder[bid] > 0) {
146-
continue;
136+
// Step 3: find best chunk_size, then decide Scheme E vs split-KV
137+
if (tid == 0 && wid == 0) {
138+
// Strategy:
139+
// 1. Must fill target_blocks (2*sm_count) to maintain SM concurrency
140+
// 2. Among valid choices, prefer minimum per-SM max KV traffic
141+
// (= waves * chunk_size, since kernel time = slowest SM)
142+
// 3. Within 5% of minimum KV traffic, prefer larger chunk_size
143+
int chunk_size_best = min_chunk_size;
144+
int num_block_all_best = num_block_all_shared[0];
145+
// Step 1: find minimum kv_traffic among chunk_sizes that fill SMs
146+
int64_t kv_traffic_min = INT64_MAX;
147+
for (int i = 0; i < static_cast<int>(block_size); i++) {
148+
const int nb = num_block_all_shared[i];
149+
if (nb < target_blocks) continue;
150+
const int cs = min(min_chunk_size + i * chunk_step, max_chunk_size);
151+
const int w = div_up(nb, target_blocks);
152+
const int64_t kv_traffic = static_cast<int64_t>(w) * cs;
153+
if (kv_traffic < kv_traffic_min) {
154+
kv_traffic_min = kv_traffic;
147155
}
148-
int token_num_cur_batch = seq_lens_this_time[bid];
149-
int kv_len_cur_batch = seq_lens_decoder[bid] + token_num_cur_batch;
150-
int q_tile_num = div_up(token_num_cur_batch * group_size, q_tile_size);
151-
int kv_chunk_num = div_up(kv_len_cur_batch, cur_chunk_size);
152-
num_block_all += q_tile_num * kv_chunk_num * kv_num_heads;
153156
}
154-
num_block_all_shared[lane_id] = num_block_all;
155-
__syncthreads();
156-
157-
int chunk_size_best;
158-
int num_block_all_best;
159-
if (tid == 0 && wid == 0) {
160-
// Search from largest chunk_size to smallest:
161-
// pick the first (largest) chunk_size with enough blocks
162-
chunk_size_best = min_chunk_size; // fallback: smallest chunk
157+
// Step 2: if no chunk_size fills SMs, fall back to smallest
158+
if (kv_traffic_min == INT64_MAX) {
159+
chunk_size_best = min_chunk_size;
163160
num_block_all_best = num_block_all_shared[0];
161+
} else {
162+
// Step 3: scan from largest chunk_size downward; accept the first
163+
// one that fills SMs AND has kv_traffic within 20% of minimum
164164
for (int i = block_size - 1; i >= 0; i--) {
165-
if (num_block_all_shared[i] >= target_blocks) {
166-
chunk_size_best = min_chunk_size * (i + 1);
167-
num_block_all_best = num_block_all_shared[i];
165+
const int nb = num_block_all_shared[i];
166+
if (nb < target_blocks) continue;
167+
const int cs = min(min_chunk_size + i * chunk_step, max_chunk_size);
168+
const int w = div_up(nb, target_blocks);
169+
const int64_t kv_traffic = static_cast<int64_t>(w) * cs;
170+
if (kv_traffic <= kv_traffic_min + kv_traffic_min / 4) {
171+
chunk_size_best = cs;
172+
num_block_all_best = nb;
168173
break;
169174
}
170175
}
171-
// If even the smallest chunk doesn't reach target_blocks,
172-
// use the smallest chunk to maximize parallelism
173-
if (num_block_all_best < target_blocks) {
174-
chunk_size_best = min_chunk_size;
175-
num_block_all_best = num_block_all_shared[0];
176+
}
177+
178+
// Decide Scheme E: prefer when blocks fill SMs AND estimated latency
179+
// is no worse than split-KV.
180+
// Scheme E: waves_E * max_kv_len (few heavy blocks)
181+
// Split-KV: waves_split * chunk_size_best (many light blocks)
182+
// When no splitting is needed (num_block_all_best == num_block_no_chunk),
183+
// Scheme E is strictly better (saves merge overhead).
184+
bool use_scheme_e = false;
185+
if (num_block_no_chunk >= target_blocks) {
186+
if (num_block_all_best == num_block_no_chunk) {
187+
use_scheme_e = true;
188+
} else {
189+
// target_blocks = sm_count * 3 ≈ CTAs per wave (sm_count × occupancy).
190+
// Using target_blocks as denominator correctly accounts for occupancy
191+
// in wave count estimation.
192+
const int waves_e = div_up(num_block_no_chunk, target_blocks);
193+
const int waves_split = div_up(num_block_all_best, target_blocks);
194+
use_scheme_e = (static_cast<int64_t>(waves_e) * max_kv_len_no_chunk <=
195+
static_cast<int64_t>(waves_split) * chunk_size_best);
176196
}
197+
}
198+
199+
if (use_scheme_e) {
200+
num_blocks[0] = num_block_no_chunk;
201+
chunk_size[0] = INT_MAX;
202+
chunk_size_res[0] = INT_MAX;
203+
} else {
177204
num_blocks[0] = num_block_all_best;
178205
chunk_size[0] = chunk_size_best;
179206
chunk_size_res[0] = chunk_size_best;
180-
use_scheme_e_res[0] = 0;
181207
}
182208
}
183209

184210
__syncthreads();
185211
if (wid == 0) {
186-
const bool use_scheme_e_local = use_scheme_e_res[0];
187-
const int chunk_size_best = chunk_size_res[0];
212+
const int chunk_size_final = chunk_size_res[0];
188213

189-
// one block one warp
190214
int prev_offset = 0;
191-
// loop on warp tile:[base, base+32)
192215
for (int base = 0; base < bsz; base += warp_size) {
193216
const int bid = base + tid;
217+
int num_block_cur = 0;
194218
int q_tile_num = 0;
195219
int kv_chunk_num = 0;
196220

197-
// calculate loop_times for bid
198-
int num_block_all = 0;
199221
if (bid < bsz) {
200222
int token_num_cur_batch = seq_lens_this_time[bid];
201223
if (seq_lens_encoder && seq_lens_encoder[bid] > 0) {
202224
token_num_cur_batch = 0;
203225
}
204226
q_tile_num = div_up(token_num_cur_batch * group_size, q_tile_size);
205-
if (use_scheme_e_local) {
206-
num_block_all += q_tile_num * kv_num_heads;
207-
} else {
208-
int kv_len_cur_batch = seq_lens_decoder[bid] + token_num_cur_batch;
209-
kv_chunk_num = div_up(kv_len_cur_batch, chunk_size_best);
210-
num_block_all += q_tile_num * kv_chunk_num * kv_num_heads;
211-
}
227+
const int kv_len_cur_batch =
228+
seq_lens_decoder[bid] + token_num_cur_batch;
229+
kv_chunk_num = div_up(kv_len_cur_batch, chunk_size_final);
230+
num_block_cur = q_tile_num * kv_chunk_num * kv_num_heads;
212231
}
213232

214-
// prefix sum for each lane, get the start offset in this tile
215-
// inclusive scan
216-
int x = num_block_all;
233+
// inclusive prefix sum
234+
int x = num_block_cur;
217235
for (int offset = 1; offset < warp_size; offset <<= 1) {
218236
int y = __shfl_up_sync(0xffffffff, x, offset);
219237
if (tid >= offset) x += y;
220238
}
221-
// exclusive prefix sum
222-
int bid_offset = x - num_block_all;
239+
int bid_offset = x - num_block_cur;
223240
int tile_sum = __shfl_sync(0xffffffff, x, warp_size - 1);
224241

225-
// write batch_ids and tile_ids_per_batch
226-
if (bid < bsz && num_block_all > 0) {
227-
int write_base = prev_offset + bid_offset;
228-
if (use_scheme_e_local) {
229-
for (int kv_head_id = 0; kv_head_id < kv_num_heads; kv_head_id++) {
242+
// Write block_indices using int4 vectorized stores.
243+
// Each entry is exactly 4 ints (bid, kv_head_id, kv_chunk_id, q_tile_id),
244+
// matching int4 layout. This reduces 4 scalar stores to 1 vector store.
245+
if (bid < bsz && num_block_cur > 0) {
246+
int4* write_ptr = block_indices + prev_offset + bid_offset;
247+
int flat_idx = 0;
248+
const int kv_chunk_num_x_q_tile_num = kv_chunk_num * q_tile_num;
249+
#pragma unroll 2
250+
for (int kv_head_id = 0; kv_head_id < kv_num_heads; kv_head_id++) {
251+
const int head_base = kv_head_id * kv_chunk_num_x_q_tile_num;
252+
#pragma unroll 2
253+
for (int kv_chunk_id = 0; kv_chunk_id < kv_chunk_num; kv_chunk_id++) {
254+
const int chunk_base = head_base + kv_chunk_id * q_tile_num;
255+
#pragma unroll
230256
for (int q_tile_id = 0; q_tile_id < q_tile_num; q_tile_id++) {
231-
int idx =
232-
write_base * 4 + (kv_head_id * q_tile_num + q_tile_id) * 4;
233-
block_indices[idx] = bid;
234-
block_indices[idx + 1] = kv_head_id;
235-
block_indices[idx + 2] = 0;
236-
block_indices[idx + 3] = q_tile_id;
237-
}
238-
}
239-
} else {
240-
for (int kv_head_id = 0; kv_head_id < kv_num_heads; kv_head_id++) {
241-
for (int kv_chunk_id = 0; kv_chunk_id < kv_chunk_num;
242-
kv_chunk_id++) {
243-
for (int q_tile_id = 0; q_tile_id < q_tile_num; q_tile_id++) {
244-
int idx =
245-
write_base * 4 +
246-
((kv_head_id * kv_chunk_num + kv_chunk_id) * q_tile_num +
247-
q_tile_id) *
248-
4;
249-
block_indices[idx] = bid;
250-
block_indices[idx + 1] = kv_head_id;
251-
block_indices[idx + 2] = kv_chunk_id;
252-
block_indices[idx + 3] = q_tile_id;
253-
}
257+
write_ptr[flat_idx] =
258+
make_int4(bid, kv_head_id, kv_chunk_id, q_tile_id);
259+
flat_idx++;
254260
}
255261
}
256262
}
257263
}
258-
// for next warp tile
259264
prev_offset += tile_sum;
260265
}
261266
}
@@ -318,12 +323,16 @@ void ConfigForAttention(
318323

319324
const int q_tile_size = 16;
320325
dim3 blocks(32, 4);
326+
// Cast block_indices to int4* for vectorized stores.
327+
// Each block_indices entry is 4 ints = 16 bytes = sizeof(int4),
328+
// and block_num * 4 ints = block_num int4s, so the reinterpret is valid.
329+
int4* block_indices_i4 = reinterpret_cast<int4*>(block_indices.data<int>());
321330
if (cache_quant_type == "cache_int4_zp") {
322-
config_decode_attn<256, 128>
331+
config_decode_attn<512, 256, 128, 32768>
323332
<<<1, blocks, 0, stream>>>(seq_lens_this_time.data<int>(),
324333
seq_lens_encoder.data<int>(),
325334
seq_lens_decoder.data<int>(),
326-
block_indices.data<int>(),
335+
block_indices_i4,
327336
num_blocks.data<int>(),
328337
chunk_size.data<int>(),
329338
bsz,
@@ -333,11 +342,11 @@ void ConfigForAttention(
333342
max_tokens_per_batch,
334343
config_gridx);
335344
} else {
336-
config_decode_attn<128, 128>
345+
config_decode_attn<512, 128, 128, 16384>
337346
<<<1, blocks, 0, stream>>>(seq_lens_this_time.data<int>(),
338347
seq_lens_encoder.data<int>(),
339348
seq_lens_decoder.data<int>(),
340-
block_indices.data<int>(),
349+
block_indices_i4,
341350
num_blocks.data<int>(),
342351
chunk_size.data<int>(),
343352
bsz,

custom_ops/gpu_ops/decode_unified_attention/decode_unified_attention_c16_impl.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -400,7 +400,7 @@ void DecodeUnifiedC16Attention(
400400
int sm_count;
401401
cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id);
402402

403-
const int max_num_chunks = div_up(max_seq_len, 128);
403+
const int max_num_chunks = div_up(max_seq_len, 512);
404404
uint32_t attn_mask_len;
405405
if (attn_mask) {
406406
attn_mask_len = attn_mask.get().shape()[1];

custom_ops/gpu_ops/decode_unified_attention/decode_unified_attention_c8_impl.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -610,7 +610,7 @@ void DecodeUnifiedC8Attention(const AppendAttnMetaData& meta_data,
610610
int sm_count;
611611
cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id);
612612

613-
const int max_num_chunks = div_up(max_seq_len, 128);
613+
const int max_num_chunks = div_up(max_seq_len, 512);
614614
uint32_t attn_mask_len;
615615
if (attn_mask) {
616616
attn_mask_len = attn_mask.get().shape()[1];

fastdeploy/model_executor/layers/attention/append_attn_backend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def allocate_launch_related_buffer(
112112

113113
# Decode unified attention split ops buffers
114114
if envs.USE_DECODE_UNIFIED_ATTENTION:
115-
min_chunk_size = 128
115+
min_chunk_size = 512
116116
max_num_chunk = (max_model_len + min_chunk_size - 1) // min_chunk_size
117117
q_tile_size = 16
118118
q_tile_num = (decoder_step_token_num * group_size + q_tile_size - 1) // q_tile_size

0 commit comments

Comments
 (0)