Skip to content

Commit 734fbcf

Browse files
authored
[BugFix] Fix Async D2H copy bug & flash mask attn cache V out-of-bound bug (#7221)
1 parent 3c54a41 commit 734fbcf

File tree

3 files changed

+23
-6
lines changed

3 files changed

+23
-6
lines changed

custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -296,7 +296,7 @@ void GetBlockShapeAndSplitKVBlock(
296296
if (!phi::backends::gpu::IsCUDAGraphCapturing())
297297
#endif
298298
max_len_tensor_cpu.copy_(
299-
max_len_tensor_gpu, max_len_tensor_cpu.place(), false);
299+
max_len_tensor_gpu, max_len_tensor_cpu.place(), true);
300300

301301
auto max_len_cpu_ptr = max_len_tensor_cpu.data<int>();
302302
int max_len_this_time = max_len_cpu_ptr[0];
@@ -378,7 +378,7 @@ void GetBlockShapeAndSplitKVBlock(
378378
if (!phi::backends::gpu::IsCUDAGraphCapturing())
379379
#endif
380380
decoder_num_blocks_cpu.copy_(
381-
decoder_num_blocks_device, decoder_num_blocks_cpu.place(), false);
381+
decoder_num_blocks_device, decoder_num_blocks_cpu.place(), true);
382382
}
383383
}
384384
// mla_backend not need run the following code.
@@ -409,7 +409,7 @@ void GetBlockShapeAndSplitKVBlock(
409409
block_size);
410410

411411
kv_num_blocks_x_cpu.copy_(
412-
kv_num_blocks_x, kv_num_blocks_x_cpu.place(), false);
412+
kv_num_blocks_x, kv_num_blocks_x_cpu.place(), true);
413413
// Clear buffer
414414
const uint32_t encoder_max_tile_size_per_bs_q =
415415
div_up((max_enc_dec_len_this_time * group_size), encoder_block_shape_q);
@@ -433,7 +433,7 @@ void GetBlockShapeAndSplitKVBlock(
433433
encoder_block_shape_q,
434434
group_size);
435435
encoder_num_blocks_x_cpu.copy_(
436-
encoder_num_blocks_x, encoder_num_blocks_x_cpu.place(), false);
436+
encoder_num_blocks_x, encoder_num_blocks_x_cpu.place(), true);
437437
}
438438
}
439439

custom_ops/gpu_ops/append_attn/pre_cache_len_concat.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,9 +87,9 @@ std::vector<paddle::Tensor> PreCacheLenConcat(
8787
bsz,
8888
block_size);
8989
paddle::Tensor pre_cache_num_blocks_cpu =
90-
pre_cache_num_blocks.copy_to(paddle::CPUPlace(), false);
90+
pre_cache_num_blocks.copy_to(paddle::CPUPlace(), true);
9191
paddle::Tensor kv_token_num_cpu =
92-
kv_token_num.copy_to(paddle::CPUPlace(), false);
92+
kv_token_num.copy_to(paddle::CPUPlace(), true);
9393

9494
return {
9595
cu_seqlens_k,

custom_ops/gpu_ops/flash_mask_attn/mainloop_attn.hpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -490,6 +490,23 @@ struct CollectiveMainloopAttn {
490490

491491
softmax.rescale_o(tOrO, scores_scale);
492492
consumer_wait(pipeline_v, smem_pipe_read_v);
493+
if (seq_len_k - n_block * kBlockN < kBlockN) {
494+
int valid_k = seq_len_k - n_block * kBlockN;
495+
auto sVt_this = sVt(_, _, smem_pipe_read_v.index());
496+
constexpr int kHdLo = decltype(get<0, 0>(shape(sVt_this)))::value;
497+
constexpr int kHdHi = decltype(get<0, 1>(shape(sVt_this)))::value;
498+
if (thread_idx >= valid_k && thread_idx < kBlockN) {
499+
#pragma unroll
500+
for (int hd_hi = 0; hd_hi < kHdHi; ++hd_hi) {
501+
#pragma unroll
502+
for (int hd_lo = 0; hd_lo < kHdLo; ++hd_lo) {
503+
sVt_this(make_coord(make_coord(hd_lo, hd_hi), thread_idx)) =
504+
Element(0);
505+
}
506+
}
507+
}
508+
cutlass::arch::fence_view_async_shared();
509+
}
493510
gemm</*zero_init=*/false, /*wg_wait=*/-1>(
494511
tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), tOrO);
495512
warp_scheduler_barrier_arrive();

0 commit comments

Comments (0)