@@ -67,21 +67,21 @@ struct VarlenFlashPrepared {
6767
6868VarlenFlashPrepared prepare_varlen_flash_tensors (PlannedMeta *p) {
6969 VarlenFlashPrepared t;
70- // FlashAttention kernels expect standard dense layout (contiguous last dimension).
71- t.q = infinicore::adaptor::to_aten_tensor (p->q ). contiguous () ;
70+ // Varlen flash-attn: keep k/v contiguous for dense/paged layout; avoid extra copies for q/metadata when already dense.
71+ t.q = infinicore::adaptor::to_aten_tensor (p->q );
7272 t.k = infinicore::adaptor::to_aten_tensor (p->k ).contiguous ();
7373 t.v = infinicore::adaptor::to_aten_tensor (p->v ).contiguous ();
7474 t.out_at = infinicore::adaptor::to_aten_tensor (p->out );
7575 t.out_need_copy_back = !t.out_at .is_contiguous ();
7676 t.out_work = t.out_need_copy_back ? t.out_at .contiguous () : t.out_at ;
7777 t.out_opt = std::optional<at::Tensor>(t.out_work );
78- t.cu_seqlens_q = infinicore::adaptor::to_aten_tensor (p->cum_seqlens_q ). contiguous () ;
79- t.cu_seqlens_kv = infinicore::adaptor::to_aten_tensor (p->cum_seqlens_k ). contiguous () ;
80- t.block_table = std::optional<at::Tensor>(infinicore::adaptor::to_aten_tensor (p->block_table ). contiguous () );
78+ t.cu_seqlens_q = infinicore::adaptor::to_aten_tensor (p->cum_seqlens_q );
79+ t.cu_seqlens_kv = infinicore::adaptor::to_aten_tensor (p->cum_seqlens_k );
80+ t.block_table = std::optional<at::Tensor>(infinicore::adaptor::to_aten_tensor (p->block_table ));
8181 t.max_seqlen_q = p->max_seqlen_q ;
8282 t.max_seqlen_k = p->max_seqlen_k ;
8383 t.alibi_slopes = p->alibi_slopes
84- ? std::optional<at::Tensor>(infinicore::adaptor::to_aten_tensor (*p->alibi_slopes ). contiguous () )
84+ ? std::optional<at::Tensor>(infinicore::adaptor::to_aten_tensor (*p->alibi_slopes ))
8585 : std::nullopt ;
8686 t.scale = p->scale ;
8787 return t;
@@ -107,6 +107,7 @@ void run_flashattn_varlen_metax(PlannedMeta *p) {
107107 // depending on the HPCC/MetaX stack version.
108108#if defined(INFINICORE_HPCC_VERSION_MAJOR) && (INFINICORE_HPCC_VERSION_MAJOR >= 3)
109109 std::optional<at::Tensor> flash_attn_mars_ext = std::nullopt ;
110+ #endif
110111 ::mha_varlen_fwd (
111112 t.q,
112113 t.k,
@@ -128,32 +129,12 @@ void run_flashattn_varlen_metax(PlannedMeta *p) {
128129 -1 ,
129130 0.0 ,
130131 false ,
131- std::nullopt ,
132- flash_attn_mars_ext);
133- #else
134- ::mha_varlen_fwd (
135- t.q,
136- t.k,
137- t.v,
138- t.out_opt,
139- t.cu_seqlens_q,
140- t.cu_seqlens_kv,
141- seqused_k,
142- leftpad_k,
143- t.block_table,
144- t.alibi_slopes,
145- t.max_seqlen_q,
146- t.max_seqlen_k,
147- 0.0 ,
148- t.scale,
149- false ,
150- true ,
151- -1 ,
152- -1 ,
153- 0.0 ,
154- false ,
155- std::nullopt );
132+ std::nullopt
133+ #if defined(INFINICORE_HPCC_VERSION_MAJOR) && (INFINICORE_HPCC_VERSION_MAJOR >= 3)
134+ ,
135+ flash_attn_mars_ext
156136#endif
137+ );
157138 copy_varlen_flash_output_back (t);
158139}
159140#endif
0 commit comments