Skip to content

Commit f8fab93

Browse files
committed
simplify mha_varlen_flashattn.cc
Signed-off-by: Ceng23333 <441651826@qq.com>
1 parent af889fd commit f8fab93

File tree

1 file changed

+12
-31
lines changed

1 file changed

+12
-31
lines changed

src/infinicore/ops/multi_head_attention_varlen/mha_varlen_flashattn.cc

Lines changed: 12 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -67,21 +67,21 @@ struct VarlenFlashPrepared {
6767

6868
VarlenFlashPrepared prepare_varlen_flash_tensors(PlannedMeta *p) {
6969
VarlenFlashPrepared t;
70-
// FlashAttention kernels expect standard dense layout (contiguous last dimension).
71-
t.q = infinicore::adaptor::to_aten_tensor(p->q).contiguous();
70+
// Varlen flash-attn: keep k/v contiguous for dense/paged layout; avoid extra copies for q/metadata when already dense.
71+
t.q = infinicore::adaptor::to_aten_tensor(p->q);
7272
t.k = infinicore::adaptor::to_aten_tensor(p->k).contiguous();
7373
t.v = infinicore::adaptor::to_aten_tensor(p->v).contiguous();
7474
t.out_at = infinicore::adaptor::to_aten_tensor(p->out);
7575
t.out_need_copy_back = !t.out_at.is_contiguous();
7676
t.out_work = t.out_need_copy_back ? t.out_at.contiguous() : t.out_at;
7777
t.out_opt = std::optional<at::Tensor>(t.out_work);
78-
t.cu_seqlens_q = infinicore::adaptor::to_aten_tensor(p->cum_seqlens_q).contiguous();
79-
t.cu_seqlens_kv = infinicore::adaptor::to_aten_tensor(p->cum_seqlens_k).contiguous();
80-
t.block_table = std::optional<at::Tensor>(infinicore::adaptor::to_aten_tensor(p->block_table).contiguous());
78+
t.cu_seqlens_q = infinicore::adaptor::to_aten_tensor(p->cum_seqlens_q);
79+
t.cu_seqlens_kv = infinicore::adaptor::to_aten_tensor(p->cum_seqlens_k);
80+
t.block_table = std::optional<at::Tensor>(infinicore::adaptor::to_aten_tensor(p->block_table));
8181
t.max_seqlen_q = p->max_seqlen_q;
8282
t.max_seqlen_k = p->max_seqlen_k;
8383
t.alibi_slopes = p->alibi_slopes
84-
? std::optional<at::Tensor>(infinicore::adaptor::to_aten_tensor(*p->alibi_slopes).contiguous())
84+
? std::optional<at::Tensor>(infinicore::adaptor::to_aten_tensor(*p->alibi_slopes))
8585
: std::nullopt;
8686
t.scale = p->scale;
8787
return t;
@@ -107,6 +107,7 @@ void run_flashattn_varlen_metax(PlannedMeta *p) {
107107
// depending on the HPCC/MetaX stack version.
108108
#if defined(INFINICORE_HPCC_VERSION_MAJOR) && (INFINICORE_HPCC_VERSION_MAJOR >= 3)
109109
std::optional<at::Tensor> flash_attn_mars_ext = std::nullopt;
110+
#endif
110111
::mha_varlen_fwd(
111112
t.q,
112113
t.k,
@@ -128,32 +129,12 @@ void run_flashattn_varlen_metax(PlannedMeta *p) {
128129
-1,
129130
0.0,
130131
false,
131-
std::nullopt,
132-
flash_attn_mars_ext);
133-
#else
134-
::mha_varlen_fwd(
135-
t.q,
136-
t.k,
137-
t.v,
138-
t.out_opt,
139-
t.cu_seqlens_q,
140-
t.cu_seqlens_kv,
141-
seqused_k,
142-
leftpad_k,
143-
t.block_table,
144-
t.alibi_slopes,
145-
t.max_seqlen_q,
146-
t.max_seqlen_k,
147-
0.0,
148-
t.scale,
149-
false,
150-
true,
151-
-1,
152-
-1,
153-
0.0,
154-
false,
155-
std::nullopt);
132+
std::nullopt
133+
#if defined(INFINICORE_HPCC_VERSION_MAJOR) && (INFINICORE_HPCC_VERSION_MAJOR >= 3)
134+
,
135+
flash_attn_mars_ext
156136
#endif
137+
);
157138
copy_varlen_flash_output_back(t);
158139
}
159140
#endif

0 commit comments

Comments
 (0)