Skip to content

Commit 72e8bc7

Browse files
committed
resolve comments
Signed-off-by: Ceng23333 <441651826@qq.com>
1 parent 5338147 commit 72e8bc7

2 files changed

Lines changed: 24 additions & 20 deletions

File tree

src/infinicore/ops/mha_kvcache/mha_kvcache_flashattn.cc

Lines changed: 9 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -45,24 +45,24 @@ void *plan(Tensor out,
4545

4646
void run(void *planned_meta) {
4747
#ifdef ENABLE_FLASH_ATTN
48-
#ifdef ENABLE_NVIDIA_API
49-
c10::cuda::CUDAStreamGuard guard(infinicore::adaptor::get_cuda_stream());
50-
#elif defined(ENABLE_METAX_API)
48+
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_METAX_API) || defined(ENABLE_QY_API)
5149
c10::cuda::CUDAStreamGuard guard(infinicore::adaptor::get_cuda_stream());
5250
#endif
5351
auto *p = reinterpret_cast<PlannedMeta *>(planned_meta);
5452

5553
// Paged KV caches must be contiguous for flash-attn; avoid extra copies for q/metadata when already dense.
56-
auto out_at = infinicore::adaptor::to_aten_tensor(p->out);
57-
const bool out_need_copy_back = !out_at.is_contiguous();
58-
auto out_tensor = out_need_copy_back ? out_at.contiguous() : out_at;
54+
const bool out_need_copy_back = !p->out->is_contiguous();
55+
Tensor out_work = out_need_copy_back ? p->out->contiguous() : Tensor(p->out);
56+
auto out_tensor = infinicore::adaptor::to_aten_tensor(out_work);
5957
auto q = infinicore::adaptor::to_aten_tensor(p->q);
6058
#if defined(ENABLE_NVIDIA_API)
6159
auto k_cache = infinicore::adaptor::to_aten_tensor(p->k_cache);
6260
auto v_cache = infinicore::adaptor::to_aten_tensor(p->v_cache);
6361
#elif defined(ENABLE_QY_API) || defined(ENABLE_METAX_API)
64-
auto k_cache = infinicore::adaptor::to_aten_tensor(p->k_cache).contiguous();
65-
auto v_cache = infinicore::adaptor::to_aten_tensor(p->v_cache).contiguous();
62+
Tensor k_cache_work = p->k_cache->contiguous();
63+
Tensor v_cache_work = p->v_cache->contiguous();
64+
auto k_cache = infinicore::adaptor::to_aten_tensor(k_cache_work);
65+
auto v_cache = infinicore::adaptor::to_aten_tensor(v_cache_work);
6666
#endif
6767
auto seqlens_k = std::optional<const at::Tensor>(infinicore::adaptor::to_aten_tensor(p->seqlens_k));
6868
auto block_table = std::optional<at::Tensor>(infinicore::adaptor::to_aten_tensor(p->block_table));
@@ -119,7 +119,7 @@ void run(void *planned_meta) {
119119
out_tensor.copy_(result[0]);
120120
}
121121
if (out_need_copy_back) {
122-
out_at.copy_(out_tensor);
122+
p->out->copy_from(out_work);
123123
}
124124
#else
125125
throw std::runtime_error("FlashAttention is not enabled in this build");

src/infinicore/ops/multi_head_attention_varlen/mha_varlen_flashattn.cc

Lines changed: 15 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -49,12 +49,13 @@ namespace {
4949

5050
#ifdef ENABLE_FLASH_ATTN
5151
struct VarlenFlashPrepared {
52+
Tensor k_work;
53+
Tensor v_work;
54+
Tensor out_work_ic;
5255
at::Tensor q;
5356
at::Tensor k;
5457
at::Tensor v;
55-
at::Tensor out_at;
5658
bool out_need_copy_back;
57-
at::Tensor out_work;
5859
std::optional<at::Tensor> out_opt;
5960
at::Tensor cu_seqlens_q;
6061
at::Tensor cu_seqlens_kv;
@@ -69,12 +70,15 @@ VarlenFlashPrepared prepare_varlen_flash_tensors(PlannedMeta *p) {
6970
VarlenFlashPrepared t;
7071
// Varlen flash-attn: keep k/v contiguous for dense/paged layout; avoid extra copies for q/metadata when already dense.
7172
t.q = infinicore::adaptor::to_aten_tensor(p->q);
72-
t.k = infinicore::adaptor::to_aten_tensor(p->k).contiguous();
73-
t.v = infinicore::adaptor::to_aten_tensor(p->v).contiguous();
74-
t.out_at = infinicore::adaptor::to_aten_tensor(p->out);
75-
t.out_need_copy_back = !t.out_at.is_contiguous();
76-
t.out_work = t.out_need_copy_back ? t.out_at.contiguous() : t.out_at;
77-
t.out_opt = std::optional<at::Tensor>(t.out_work);
73+
t.k_work = p->k->contiguous();
74+
t.v_work = p->v->contiguous();
75+
t.k = infinicore::adaptor::to_aten_tensor(t.k_work);
76+
t.v = infinicore::adaptor::to_aten_tensor(t.v_work);
77+
78+
t.out_need_copy_back = !p->out->is_contiguous();
79+
t.out_work_ic = t.out_need_copy_back ? p->out->contiguous() : Tensor(p->out);
80+
auto out_work = infinicore::adaptor::to_aten_tensor(t.out_work_ic);
81+
t.out_opt = std::optional<at::Tensor>(out_work);
7882
t.cu_seqlens_q = infinicore::adaptor::to_aten_tensor(p->cum_seqlens_q);
7983
t.cu_seqlens_kv = infinicore::adaptor::to_aten_tensor(p->cum_seqlens_k);
8084
t.block_table = std::optional<at::Tensor>(infinicore::adaptor::to_aten_tensor(p->block_table));
@@ -87,9 +91,9 @@ VarlenFlashPrepared prepare_varlen_flash_tensors(PlannedMeta *p) {
8791
return t;
8892
}
8993

90-
void copy_varlen_flash_output_back(VarlenFlashPrepared &t) {
94+
void copy_varlen_flash_output_back(PlannedMeta *p, VarlenFlashPrepared &t) {
9195
if (t.out_need_copy_back) {
92-
t.out_at.copy_(t.out_work);
96+
p->out->copy_from(t.out_work_ic);
9397
}
9498
}
9599

@@ -135,7 +139,7 @@ void run_flashattn_varlen_metax(PlannedMeta *p) {
135139
flash_attn_mars_ext
136140
#endif
137141
);
138-
copy_varlen_flash_output_back(t);
142+
copy_varlen_flash_output_back(p, t);
139143
}
140144
#endif
141145

0 commit comments

Comments
 (0)