
#include <stdexcept>

+#ifdef ENABLE_FLASH_ATTN
+#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_METAX_API) || defined(ENABLE_QY_API)
+#include <c10/cuda/CUDAGuard.h>
+#endif
+#endif
+
+#if defined(ENABLE_METAX_API)
+#define INFINICORE_FLASH_OP(name) ::name
+#else
+#define INFINICORE_FLASH_OP(name) flash::name
+#endif
+
namespace infinicore::op::mha_kvcache_impl::flashattn {

struct PlannedMeta {
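The new INFINICORE_FLASH_OP macro routes every flash-attention call through one name: MetaX builds resolve to a global-namespace symbol (::name), while all other backends keep calling into the flash:: namespace. A minimal sketch of the same dispatch pattern, with stub functions standing in for the real kernels:

#include <cstdio>

namespace flash {
inline void mha_fwd_kvcache() { std::puts("flash::mha_fwd_kvcache"); }
} // namespace flash

// Global-namespace variant, as the MetaX toolchain exposes it in this scheme.
inline void mha_fwd_kvcache() { std::puts("::mha_fwd_kvcache"); }

#if defined(ENABLE_METAX_API)
#define INFINICORE_FLASH_OP(name) ::name
#else
#define INFINICORE_FLASH_OP(name) flash::name
#endif

int main() {
    // Expands to flash::mha_fwd_kvcache() unless ENABLE_METAX_API is defined.
    INFINICORE_FLASH_OP(mha_fwd_kvcache)();
}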
@@ -33,17 +45,24 @@ void *plan(Tensor out,

void run(void *planned_meta) {
#ifdef ENABLE_FLASH_ATTN
+#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_METAX_API) || defined(ENABLE_QY_API)
    c10::cuda::CUDAStreamGuard guard(infinicore::adaptor::get_cuda_stream());
+#endif
    auto *p = reinterpret_cast<PlannedMeta *>(planned_meta);

-    auto out_tensor = infinicore::adaptor::to_aten_tensor(p->out);
+    // Paged KV caches must be contiguous for flash-attn; avoid extra copies for q/metadata when already dense.
+    const bool out_need_copy_back = !p->out->is_contiguous();
+    Tensor out_work = out_need_copy_back ? p->out->contiguous() : Tensor(p->out);
+    auto out_tensor = infinicore::adaptor::to_aten_tensor(out_work);
    auto q = infinicore::adaptor::to_aten_tensor(p->q);
#if defined(ENABLE_NVIDIA_API)
    auto k_cache = infinicore::adaptor::to_aten_tensor(p->k_cache);
    auto v_cache = infinicore::adaptor::to_aten_tensor(p->v_cache);
-#elif defined(ENABLE_QY_API)
-    auto k_cache = infinicore::adaptor::to_aten_tensor(p->k_cache).contiguous();
-    auto v_cache = infinicore::adaptor::to_aten_tensor(p->v_cache).contiguous();
+#elif defined(ENABLE_QY_API) || defined(ENABLE_METAX_API)
+    Tensor k_cache_work = p->k_cache->contiguous();
+    Tensor v_cache_work = p->v_cache->contiguous();
+    auto k_cache = infinicore::adaptor::to_aten_tensor(k_cache_work);
+    auto v_cache = infinicore::adaptor::to_aten_tensor(v_cache_work);
#endif
    auto seqlens_k = std::optional<const at::Tensor>(infinicore::adaptor::to_aten_tensor(p->seqlens_k));
    auto block_table = std::optional<at::Tensor>(infinicore::adaptor::to_aten_tensor(p->block_table));
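The out-tensor handling in this hunk is a work-buffer pattern: write through a contiguous copy only when the destination is strided, then copy the result back once. A minimal sketch of the same pattern against the public ATen API (the infinicore Tensor wrapper and copy_from are this repo's own types; plain at::Tensor is used here instead):

#include <ATen/ATen.h>

// Runs a kernel that requires contiguous memory, copying through a
// dense work buffer only when the destination tensor is strided.
void write_contiguous(at::Tensor &out) {
    const bool need_copy_back = !out.is_contiguous();
    at::Tensor work = need_copy_back ? out.contiguous() : out;

    work.fill_(1.0f); // stand-in for the flash-attn kernel writing its output

    if (need_copy_back) {
        out.copy_(work); // propagate the result to the original strided tensor
    }
}

When out is already contiguous, work aliases it and both copies are skipped, which is why the diff gates the final copy_from on out_need_copy_back.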
@@ -65,7 +84,11 @@ void run(void *planned_meta) {
    auto out = use_dynamic_out ? std::optional<at::Tensor>(std::nullopt)
                               : std::optional<at::Tensor>(out_tensor);

-    auto result = flash::mha_fwd_kvcache(
+#if defined(ENABLE_METAX_API) && defined(INFINICORE_HPCC_VERSION_MAJOR) && (INFINICORE_HPCC_VERSION_MAJOR >= 3)
+    std::optional<at::Tensor> flash_attn_mars_ext = std::nullopt;
+#endif
+
+    auto result = INFINICORE_FLASH_OP(mha_fwd_kvcache)(
        q,
        k_cache,
        v_cache,
@@ -85,11 +108,19 @@ void run(void *planned_meta) {
        -1,
        0.0f,
        false,
-        0);
+        0
+#if defined(ENABLE_METAX_API) && defined(INFINICORE_HPCC_VERSION_MAJOR) && (INFINICORE_HPCC_VERSION_MAJOR >= 3)
+        ,
+        flash_attn_mars_ext
+#endif
+    );

    if (use_dynamic_out) {
        out_tensor.copy_(result[0]);
    }
+    if (out_need_copy_back) {
+        p->out->copy_from(out_work);
+    }
#else
    throw std::runtime_error("FlashAttention is not enabled in this build");
#endif
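The split around the final argument exists because the MetaX HPCC 3.x build of mha_fwd_kvcache takes one extra optional tensor (flash_attn_mars_ext) at the end of its signature. A minimal sketch of how a preprocessor-gated trailing argument keeps a single call site compiling against both signatures (EXTENDED_API is a stand-in for the METAX/HPCC version gate; the functions are illustrative):

#include <optional>

struct Tensor {};

#if defined(EXTENDED_API)
// Newer toolchain: same function, one extra optional parameter at the end.
int kernel(int a, int b, std::optional<Tensor> ext) { return a + b + ext.has_value(); }
#else
int kernel(int a, int b) { return a + b; }
#endif

int call_site() {
    return kernel(1, 2
#if defined(EXTENDED_API)
                  ,
                  std::nullopt
#endif
    );
}

Keeping the comma inside the gate is what makes the call well-formed under both configurations.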