@@ -49,12 +49,13 @@ namespace {
4949
5050#ifdef ENABLE_FLASH_ATTN
5151struct VarlenFlashPrepared {
52+ Tensor k_work;
53+ Tensor v_work;
54+ Tensor out_work_ic;
5255 at::Tensor q;
5356 at::Tensor k;
5457 at::Tensor v;
55- at::Tensor out_at;
5658 bool out_need_copy_back;
57- at::Tensor out_work;
5859 std::optional<at::Tensor> out_opt;
5960 at::Tensor cu_seqlens_q;
6061 at::Tensor cu_seqlens_kv;
@@ -69,12 +70,15 @@ VarlenFlashPrepared prepare_varlen_flash_tensors(PlannedMeta *p) {
6970 VarlenFlashPrepared t;
7071 // Varlen flash-attn: keep k/v contiguous for dense/paged layout; avoid extra copies for q/metadata when already dense.
7172 t.q = infinicore::adaptor::to_aten_tensor (p->q );
72- t.k = infinicore::adaptor::to_aten_tensor (p->k ).contiguous ();
73- t.v = infinicore::adaptor::to_aten_tensor (p->v ).contiguous ();
74- t.out_at = infinicore::adaptor::to_aten_tensor (p->out );
75- t.out_need_copy_back = !t.out_at .is_contiguous ();
76- t.out_work = t.out_need_copy_back ? t.out_at .contiguous () : t.out_at ;
77- t.out_opt = std::optional<at::Tensor>(t.out_work );
73+ t.k_work = p->k ->contiguous ();
74+ t.v_work = p->v ->contiguous ();
75+ t.k = infinicore::adaptor::to_aten_tensor (t.k_work );
76+ t.v = infinicore::adaptor::to_aten_tensor (t.v_work );
77+
78+ t.out_need_copy_back = !p->out ->is_contiguous ();
79+ t.out_work_ic = t.out_need_copy_back ? p->out ->contiguous () : Tensor (p->out );
80+ auto out_work = infinicore::adaptor::to_aten_tensor (t.out_work_ic );
81+ t.out_opt = std::optional<at::Tensor>(out_work);
7882 t.cu_seqlens_q = infinicore::adaptor::to_aten_tensor (p->cum_seqlens_q );
7983 t.cu_seqlens_kv = infinicore::adaptor::to_aten_tensor (p->cum_seqlens_k );
8084 t.block_table = std::optional<at::Tensor>(infinicore::adaptor::to_aten_tensor (p->block_table ));
@@ -87,9 +91,9 @@ VarlenFlashPrepared prepare_varlen_flash_tensors(PlannedMeta *p) {
8791 return t;
8892}
8993
// Copies the flash-attention result from the working output tensor back into the
// caller's output, but only when t.out_need_copy_back is set.
// In this diff the helper gains a PlannedMeta* parameter: the removed line copied via
// t.out_at.copy_(t.out_work), the added line copies via p->out->copy_from(t.out_work_ic).
90- void copy_varlen_flash_output_back (VarlenFlashPrepared &t) {
94+ void copy_varlen_flash_output_back (PlannedMeta *p, VarlenFlashPrepared &t) {
9195 if (t.out_need_copy_back ) {
92- t. out_at . copy_ (t.out_work );
96+ p-> out -> copy_from (t.out_work_ic );
9397 }
9498}
9599
@@ -135,7 +139,7 @@ void run_flashattn_varlen_metax(PlannedMeta *p) {
135139 flash_attn_mars_ext
136140#endif
137141 );
138- copy_varlen_flash_output_back (t);
142+ copy_varlen_flash_output_back (p, t);
139143}
140144#endif
141145
0 commit comments