
Commit 12244c2

metax fla-attn
Signed-off-by: Ceng23333 <441651826@qq.com>
1 parent 73fb6a8 commit 12244c2

11 files changed

Lines changed: 424 additions & 34 deletions


include/infinicore/adaptor/aten_adaptor.hpp

Lines changed: 5 additions & 5 deletions
@@ -5,7 +5,7 @@
 
 #include <ATen/ATen.h>
 
-#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_QY_API)
+#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_METAX_API) || defined(ENABLE_QY_API)
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
 #include <c10/cuda/CUDAStream.h>
@@ -30,20 +30,20 @@ inline at::ScalarType to_at_dtype(DataType dtype) {
 }
 
 inline at::Device to_at_device(const Device &device) {
-    if (device.getType() == Device::Type::NVIDIA) {
+    // PyTorch ATen only exposes standard device types (e.g. kCPU/kCUDA).
+    // Treat MetaX/QY devices as CUDA devices for ATen tensor interoperability.
+    if (device.getType() == Device::Type::NVIDIA || device.getType() == Device::Type::METAX || device.getType() == Device::Type::QY) {
         return at::Device(at::kCUDA, device.getIndex());
     } else if (device.getType() == Device::Type::CPU) {
         return at::Device(at::kCPU);
-    } else if (device.getType() == Device::Type::QY) {
-        return at::Device(at::kCUDA, device.getIndex());
     } else {
         throw std::runtime_error("Unsupported device type for ATen");
     }
 }
 
 at::Tensor to_aten_tensor(const infinicore::Tensor &t);
 
-#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_QY_API)
+#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_METAX_API) || defined(ENABLE_QY_API)
 c10::cuda::CUDAStream get_cuda_stream();
 #endif
 } // namespace infinicore::adaptor

include/infinicore/adaptor/flash_attention_adaptor.hpp

Lines changed: 21 additions & 2 deletions
@@ -2,7 +2,12 @@
 #pragma once
 #include "aten_adaptor.hpp"
 
+// NVIDIA flash-attn-nvidia.so uses namespace flash. The pip/MetaX flash_attn_2_cuda extension
+// exports the same entry points at global scope (no namespace), matching FLASH_NAMESPACE builds
+// where the namespace is empty.
+#if !defined(ENABLE_METAX_API)
 namespace flash {
+#endif
 std::vector<at::Tensor>
 mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
         const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
@@ -39,7 +44,13 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_hea
                int window_size_right,
                const float softcap,
                const bool return_softmax,
-               std::optional<at::Generator> gen_);
+               std::optional<at::Generator> gen_
+#if defined(ENABLE_METAX_API) && defined(INFINICORE_HPCC_VERSION_MAJOR) && (INFINICORE_HPCC_VERSION_MAJOR >= 3)
+               // MetaX/Mars `flash_attn_2_cuda` (e.g. 2.6.x+mars) appends this argument vs upstream Dao-AILab flash-attn.
+               ,
+               std::optional<at::Tensor> &flash_attn_mars_ext_
+#endif
+);
 
 std::vector<at::Tensor>
 mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x multiple_of(head_size_og, 8)
@@ -108,7 +119,15 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size
                 int window_size_right,
                 const float softcap,
                 bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
-                int num_splits);
+                int num_splits
+#if defined(ENABLE_METAX_API) && defined(INFINICORE_HPCC_VERSION_MAJOR) && (INFINICORE_HPCC_VERSION_MAJOR >= 3)
+                // MetaX/Mars `flash_attn_2_cuda` (e.g. 2.6.x+mars) appends this argument vs upstream Dao-AILab flash-attn.
+                ,
+                std::optional<at::Tensor> &flash_attn_mars_ext_
+#endif
+);
 
+#if !defined(ENABLE_METAX_API)
 } // namespace flash
+#endif
 #endif // ENABLE_FLASH_ATTN

scripts/install.py

Lines changed: 5 additions & 1 deletion
@@ -2,7 +2,10 @@
 import subprocess
 import platform
 import sys
-from set_env import set_env
+from set_env import (
+    set_env,
+    set_env_by_config,
+)
 
 PROJECT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
 os.chdir(PROJECT_DIR)
@@ -12,6 +15,7 @@ def run_cmd(cmd):
 
 
 def install(xmake_config_flags=""):
+    set_env_by_config(xmake_config_flags)
     run_cmd(f"xmake f {xmake_config_flags} -cv")
     run_cmd("xmake")
     run_cmd("xmake install")

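For context, a minimal sketch of how the new hook is exercised (flag values here are illustrative, and it assumes the scripts/ directory is on sys.path). The point of the ordering is that set_env_by_config() mutates the environment before `xmake f` probes the toolchain, so the MetaX include paths it exports are already visible to the configure step.

from set_env import set_env_by_config

flags = "--metax-gpu=y --aten=y --use-mc=y"  # illustrative; install() forwards whatever it receives
set_env_by_config(flags)                     # prepends MetaX toolkit include dirs to CPATH and friends
# install(flags) then runs: xmake f <flags> -cv, xmake, xmake install
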
scripts/metax_env.py

Lines changed: 60 additions & 0 deletions
@@ -0,0 +1,60 @@
+import os
+
+
+def _first_existing_dir(paths: list[str]) -> str:
+    for p in paths:
+        if p and os.path.isdir(p):
+            return p
+    return ""
+
+
+def _metax_toolkit_root(use_mc: bool) -> str:
+    """Return toolkit root for MetaX builds (MACA when use-mc; otherwise HPCC)."""
+    if use_mc:
+        for key in ("MACA_PATH", "MACA_HOME", "MACA_ROOT"):
+            v = os.environ.get(key, "").strip()
+            if v:
+                return v
+        return _first_existing_dir(["/opt/maca"])
+    return _first_existing_dir(["/opt/hpcc"])
+
+
+def _prepend_path_var(name: str, prefixes: list[str]) -> None:
+    """Prepend colon-separated *prefixes* to env var *name* (POSIX)."""
+    if not prefixes:
+        return
+    chunk = ":".join(prefixes)
+    cur = os.environ.get(name, "")
+    os.environ[name] = f"{chunk}:{cur}" if cur else chunk
+
+
+def set_env_for_metax_gpu(
+    flags: str,
+    *,
+    parse_xmake_cli_flag_values,
+    truthy_flag_value,
+) -> None:
+    """
+    Prepend compiler include paths needed when building ATen-enabled C++ against torch headers.
+
+    This chooses paths based on xmake backend flags (e.g. --metax-gpu) and toolkit selection
+    (e.g. MetaX HPCC vs MACA when --use-mc=y).
+    """
+    d = parse_xmake_cli_flag_values(flags)
+    if not truthy_flag_value(d.get("aten", "n")):
+        return
+
+    if truthy_flag_value(d.get("metax-gpu", "n")):
+        use_mc = truthy_flag_value(d.get("use-mc", "n"))
+        root = _metax_toolkit_root(use_mc=use_mc)
+        if not root:
+            return
+        dirs = [
+            os.path.join(root, "tools", "cu-bridge", "include"),
+            os.path.join(root, "include", "hcr"),
+            os.path.join(root, "include"),
+        ]
+        for var in ("CPATH", "CPLUS_INCLUDE_PATH", "C_INCLUDE_PATH"):
+            _prepend_path_var(var, dirs)
+        return
+

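A small illustrative check of the helpers above, assuming scripts/ is on sys.path and an HPCC toolkit installed at /opt/hpcc (a MACA build would set MACA_PATH and pass use_mc=True instead). Any pre-existing CPATH value is preserved after the new prefixes.

import os
from metax_env import _metax_toolkit_root, _prepend_path_var

root = _metax_toolkit_root(use_mc=False)   # "/opt/hpcc" when present, "" otherwise
if root:
    _prepend_path_var("CPATH", [os.path.join(root, "include", "hcr"), os.path.join(root, "include")])
print(os.environ.get("CPATH"))
# e.g. /opt/hpcc/include/hcr:/opt/hpcc/include:<previous CPATH, if any>
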
scripts/set_env.py

Lines changed: 40 additions & 0 deletions
@@ -1,6 +1,46 @@
 import os
 import platform
 
+from metax_env import set_env_for_metax_gpu
+
+
+def _parse_xmake_cli_flag_values(flags: str):
+    """Parse a string like '--metax-gpu=y --aten=y' into {key: value}."""
+    parts = flags.replace("=", " ").split()
+    d = {}
+    i = 0
+    n = len(parts)
+    while i < n:
+        p = parts[i]
+        if p.startswith("--") and len(p) > 2:
+            key = p[2:].lower()
+            i += 1
+            if i < n and not parts[i].startswith("--"):
+                d[key] = parts[i].lower()
+                i += 1
+            else:
+                d[key] = "y"
+        else:
+            i += 1
+    return d
+
+
+def _truthy_flag_value(v: str) -> bool:
+    return v in ("y", "yes", "true", "1", "on")
+
+
+def set_env_by_config(flags: str) -> None:
+    """Set environment variables for InfiniCore builds with xmake config flags."""
+    d = _parse_xmake_cli_flag_values(flags)
+    if _truthy_flag_value(d.get("metax-gpu", "n")):
+        set_env_for_metax_gpu(
+            flags,
+            parse_xmake_cli_flag_values=_parse_xmake_cli_flag_values,
+            truthy_flag_value=_truthy_flag_value,
+        )
+    else:
+        pass
+
 
 def set_env():
     if os.environ.get("INFINI_ROOT") == None:

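An illustrative check of the flag parsing and dispatch above (assumes scripts/ is on sys.path; the flag names are the ones this commit reads):

from set_env import _parse_xmake_cli_flag_values, _truthy_flag_value, set_env_by_config

d = _parse_xmake_cli_flag_values("--metax-gpu=y --aten y --use-mc")
assert d == {"metax-gpu": "y", "aten": "y", "use-mc": "y"}   # bare flags default to "y"
assert _truthy_flag_value(d["aten"])

# With metax-gpu truthy, this delegates to set_env_for_metax_gpu(), which only
# touches the environment when --aten=y and a MetaX toolkit root is found.
set_env_by_config("--metax-gpu=y --aten=y")
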
src/infinicore/adaptor/aten_adaptor.cc

Lines changed: 1 addition & 1 deletion
@@ -32,7 +32,7 @@ at::Tensor to_aten_tensor(const infinicore::Tensor &t) {
         options);
 }
 
-#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_QY_API)
+#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_METAX_API) || defined(ENABLE_QY_API)
 c10::cuda::CUDAStream get_cuda_stream() {
     return c10::cuda::getStreamFromExternal(
         cudaStream_t(infinicore::context::getStream()), infinicore::context::getDevice().getIndex());

src/infinicore/ops/mha_kvcache/mha_kvcache_flashattn.cc

Lines changed: 37 additions & 6 deletions
@@ -4,6 +4,18 @@
 
 #include <stdexcept>
 
+#ifdef ENABLE_FLASH_ATTN
+#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_METAX_API) || defined(ENABLE_QY_API)
+#include <c10/cuda/CUDAGuard.h>
+#endif
+#endif
+
+#if defined(ENABLE_METAX_API)
+#define INFINICORE_FLASH_OP(name) ::name
+#else
+#define INFINICORE_FLASH_OP(name) flash::name
+#endif
+
 namespace infinicore::op::mha_kvcache_impl::flashattn {
 
 struct PlannedMeta {
@@ -33,17 +45,24 @@ void *plan(Tensor out,
 
 void run(void *planned_meta) {
 #ifdef ENABLE_FLASH_ATTN
+#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_METAX_API) || defined(ENABLE_QY_API)
     c10::cuda::CUDAStreamGuard guard(infinicore::adaptor::get_cuda_stream());
+#endif
     auto *p = reinterpret_cast<PlannedMeta *>(planned_meta);
 
-    auto out_tensor = infinicore::adaptor::to_aten_tensor(p->out);
+    // Paged KV caches must be contiguous for flash-attn; avoid extra copies for q/metadata when already dense.
+    const bool out_need_copy_back = !p->out->is_contiguous();
+    Tensor out_work = out_need_copy_back ? p->out->contiguous() : Tensor(p->out);
+    auto out_tensor = infinicore::adaptor::to_aten_tensor(out_work);
     auto q = infinicore::adaptor::to_aten_tensor(p->q);
 #if defined(ENABLE_NVIDIA_API)
     auto k_cache = infinicore::adaptor::to_aten_tensor(p->k_cache);
     auto v_cache = infinicore::adaptor::to_aten_tensor(p->v_cache);
-#elif defined(ENABLE_QY_API)
-    auto k_cache = infinicore::adaptor::to_aten_tensor(p->k_cache).contiguous();
-    auto v_cache = infinicore::adaptor::to_aten_tensor(p->v_cache).contiguous();
+#elif defined(ENABLE_QY_API) || defined(ENABLE_METAX_API)
+    Tensor k_cache_work = p->k_cache->contiguous();
+    Tensor v_cache_work = p->v_cache->contiguous();
+    auto k_cache = infinicore::adaptor::to_aten_tensor(k_cache_work);
+    auto v_cache = infinicore::adaptor::to_aten_tensor(v_cache_work);
 #endif
     auto seqlens_k = std::optional<const at::Tensor>(infinicore::adaptor::to_aten_tensor(p->seqlens_k));
     auto block_table = std::optional<at::Tensor>(infinicore::adaptor::to_aten_tensor(p->block_table));
@@ -65,7 +84,11 @@ void run(void *planned_meta) {
     auto out = use_dynamic_out ? std::optional<at::Tensor>(std::nullopt)
                                : std::optional<at::Tensor>(out_tensor);
 
-    auto result = flash::mha_fwd_kvcache(
+#if defined(ENABLE_METAX_API) && defined(INFINICORE_HPCC_VERSION_MAJOR) && (INFINICORE_HPCC_VERSION_MAJOR >= 3)
+    std::optional<at::Tensor> flash_attn_mars_ext = std::nullopt;
+#endif
+
+    auto result = INFINICORE_FLASH_OP(mha_fwd_kvcache)(
         q,
         k_cache,
         v_cache,
@@ -85,11 +108,19 @@ void run(void *planned_meta) {
         -1,
         0.0f,
         false,
-        0);
+        0
+#if defined(ENABLE_METAX_API) && defined(INFINICORE_HPCC_VERSION_MAJOR) && (INFINICORE_HPCC_VERSION_MAJOR >= 3)
+        ,
+        flash_attn_mars_ext
+#endif
+    );
 
     if (use_dynamic_out) {
         out_tensor.copy_(result[0]);
     }
+    if (out_need_copy_back) {
+        p->out->copy_from(out_work);
+    }
 #else
     throw std::runtime_error("FlashAttention is not enabled in this build");
 #endif
