44
55#include < stdexcept>
66
7+ #ifdef ENABLE_FLASH_ATTN
8+ #if defined(ENABLE_NVIDIA_API) || defined(ENABLE_METAX_API)
9+ #include < c10/cuda/CUDAGuard.h>
10+ #endif
11+ #endif
12+
// INFINICORE_FLASH_OP selects how a FlashAttention entry point is qualified:
// when ENABLE_METAX_API is defined the name is resolved in the global
// namespace (::name); otherwise it is qualified as flash::name.
// NOTE(review): the space between the macro name and its parameter list looks
// like an artifact of this diff rendering; written literally it would define
// an object-like macro — confirm against the real file.
13+ #if defined(ENABLE_METAX_API)
14+ #define INFINICORE_FLASH_OP (name ) ::name
15+ #else
16+ #define INFINICORE_FLASH_OP (name ) flash::name
17+ #endif
18+
719namespace infinicore ::op::mha_kvcache_impl::flashattn {
820
921struct PlannedMeta {
@@ -33,22 +45,24 @@ void *plan(Tensor out,
3345
3446void run (void *planned_meta) {
3547#ifdef ENABLE_FLASH_ATTN
48+ #ifdef ENABLE_NVIDIA_API
3649 c10::cuda::CUDAStreamGuard guard (infinicore::adaptor::get_cuda_stream ());
50+ #elif defined(ENABLE_METAX_API)
51+ c10::cuda::CUDAStreamGuard guard (infinicore::adaptor::get_cuda_stream ());
52+ #endif
3753 auto *p = reinterpret_cast <PlannedMeta *>(planned_meta);
3854
39- auto out_tensor = infinicore::adaptor::to_aten_tensor (p->out );
40- auto q = infinicore::adaptor::to_aten_tensor (p->q );
41- #if defined(ENABLE_NVIDIA_API)
42- auto k_cache = infinicore::adaptor::to_aten_tensor (p->k_cache );
43- auto v_cache = infinicore::adaptor::to_aten_tensor (p->v_cache );
44- #elif defined(ENABLE_QY_API)
55+ // FlashAttention kernels expect standard dense layout (contiguous last dimension).
56+ auto out_at = infinicore::adaptor::to_aten_tensor (p->out );
57+ const bool out_need_copy_back = !out_at.is_contiguous ();
58+ auto out_tensor = out_need_copy_back ? out_at.contiguous () : out_at;
59+ auto q = infinicore::adaptor::to_aten_tensor (p->q ).contiguous ();
4560 auto k_cache = infinicore::adaptor::to_aten_tensor (p->k_cache ).contiguous ();
4661 auto v_cache = infinicore::adaptor::to_aten_tensor (p->v_cache ).contiguous ();
47- #endif
48- auto seqlens_k = std::optional<const at::Tensor>(infinicore::adaptor::to_aten_tensor (p->seqlens_k ));
49- auto block_table = std::optional<at::Tensor>(infinicore::adaptor::to_aten_tensor (p->block_table ));
62+ auto seqlens_k = std::optional<const at::Tensor>(infinicore::adaptor::to_aten_tensor (p->seqlens_k ).contiguous ());
63+ auto block_table = std::optional<at::Tensor>(infinicore::adaptor::to_aten_tensor (p->block_table ).contiguous ());
5064 auto alibi_slopes = p->alibi_slopes
51- ? std::optional<at::Tensor>(infinicore::adaptor::to_aten_tensor (*p->alibi_slopes ))
65+ ? std::optional<at::Tensor>(infinicore::adaptor::to_aten_tensor (*p->alibi_slopes ). contiguous () )
5266 : std::nullopt ;
5367
5468 std::optional<const at::Tensor> k_new = std::nullopt ;
@@ -65,7 +79,11 @@ void run(void *planned_meta) {
6579 auto out = use_dynamic_out ? std::optional<at::Tensor>(std::nullopt )
6680 : std::optional<at::Tensor>(out_tensor);
6781
68- auto result = flash::mha_fwd_kvcache (
82+ #if defined(ENABLE_METAX_API) && defined(INFINICORE_HPCC_VERSION_MAJOR) && (INFINICORE_HPCC_VERSION_MAJOR >= 3)
83+ std::optional<at::Tensor> flash_attn_mars_ext = std::nullopt ;
84+ #endif
85+
86+ auto result = INFINICORE_FLASH_OP (mha_fwd_kvcache)(
6987 q,
7088 k_cache,
7189 v_cache,
@@ -85,11 +103,19 @@ void run(void *planned_meta) {
85103 -1 ,
86104 0 .0f ,
87105 false ,
88- 0 );
106+ 0
107+ #if defined(ENABLE_METAX_API) && defined(INFINICORE_HPCC_VERSION_MAJOR) && (INFINICORE_HPCC_VERSION_MAJOR >= 3)
108+ ,
109+ flash_attn_mars_ext
110+ #endif
111+ );
89112
90113 if (use_dynamic_out) {
91114 out_tensor.copy_ (result[0 ]);
92115 }
116+ if (out_need_copy_back) {
117+ out_at.copy_ (out_tensor);
118+ }
93119#else
94120 throw std::runtime_error (" FlashAttention is not enabled in this build" );
95121#endif
0 commit comments