Skip to content

Commit 7870e55

Browse files
committed
metax support flash-attn
Signed-off-by: Ceng23333 <441651826@qq.com>
1 parent 6e88052 commit 7870e55

File tree

13 files changed

+438
-45
lines changed

13 files changed

+438
-45
lines changed

include/infinicore/adaptor/aten_adaptor.hpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,11 @@
55

66
#include <ATen/ATen.h>
77

8-
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_QY_API)
8+
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_METAX_API)
99
#include <c10/cuda/CUDAStream.h>
10-
#include <c10/cuda/CUDAGuard.h>
10+
#endif
11+
12+
#ifdef ENABLE_NVIDIA_API
1113
#include <ATen/cuda/CUDAContext.h>
1214
#endif
1315

@@ -30,20 +32,18 @@ inline at::ScalarType to_at_dtype(DataType dtype) {
3032
}
3133

3234
inline at::Device to_at_device(const Device &device) {
33-
if (device.getType() == Device::Type::NVIDIA) {
35+
if (device.getType() == Device::Type::NVIDIA || device.getType() == Device::Type::METAX) {
3436
return at::Device(at::kCUDA, device.getIndex());
3537
} else if (device.getType() == Device::Type::CPU) {
3638
return at::Device(at::kCPU);
37-
} else if (device.getType() == Device::Type::QY) {
38-
return at::Device(at::kCUDA, device.getIndex());
3939
} else {
4040
throw std::runtime_error("Unsupported device type for ATen");
4141
}
4242
}
4343

4444
at::Tensor to_aten_tensor(const infinicore::Tensor &t);
4545

46-
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_QY_API)
46+
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_METAX_API)
4747
c10::cuda::CUDAStream get_cuda_stream();
4848
#endif
4949
} // namespace infinicore::adaptor

include/infinicore/adaptor/flash_attention_adaptor.hpp

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,12 @@
22
#pragma once
33
#include "aten_adaptor.hpp"
44

5+
// NVIDIA flash-attn-nvidia.so uses namespace flash. The pip/MetaX flash_attn_2_cuda extension
6+
// exports the same entry points at global scope (no namespace), matching FLASH_NAMESPACE builds
7+
// where the namespace is empty.
8+
#if !defined(ENABLE_METAX_API)
59
namespace flash {
10+
#endif
611
std::vector<at::Tensor>
712
mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
813
const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
@@ -39,7 +44,13 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_hea
3944
int window_size_right,
4045
const float softcap,
4146
const bool return_softmax,
42-
std::optional<at::Generator> gen_);
47+
std::optional<at::Generator> gen_
48+
#if defined(ENABLE_METAX_API) && defined(INFINICORE_HPCC_VERSION_MAJOR) && (INFINICORE_HPCC_VERSION_MAJOR >= 3)
49+
// MetaX/Mars `flash_attn_2_cuda` (e.g. 2.6.x+mars) appends this argument vs upstream Dao-AILab flash-attn.
50+
,
51+
std::optional<at::Tensor> &flash_attn_mars_ext_
52+
#endif
53+
);
4354

4455
std::vector<at::Tensor>
4556
mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x multiple_of(head_size_og, 8)
@@ -108,7 +119,15 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size
108119
int window_size_right,
109120
const float softcap,
110121
bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
111-
int num_splits);
122+
int num_splits
123+
#if defined(ENABLE_METAX_API) && defined(INFINICORE_HPCC_VERSION_MAJOR) && (INFINICORE_HPCC_VERSION_MAJOR >= 3)
124+
// MetaX/Mars `flash_attn_2_cuda` (e.g. 2.6.x+mars) appends this argument vs upstream Dao-AILab flash-attn.
125+
,
126+
std::optional<at::Tensor> &flash_attn_mars_ext_
127+
#endif
128+
);
112129

130+
#if !defined(ENABLE_METAX_API)
113131
} // namespace flash
132+
#endif
114133
#endif // ENABLE_FLASH_ATTN

scripts/install.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,11 @@
22
import subprocess
33
import platform
44
import sys
5-
from set_env import set_env
5+
from set_env import (
6+
set_env,
7+
ensure_metax_hpc_compiler_includes,
8+
xmake_flags_need_metax_aten_torch_includes,
9+
)
610

711
PROJECT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
812
os.chdir(PROJECT_DIR)
@@ -12,6 +16,8 @@ def run_cmd(cmd):
1216

1317

1418
def install(xmake_config_flags=""):
19+
if xmake_flags_need_metax_aten_torch_includes(xmake_config_flags):
20+
ensure_metax_hpc_compiler_includes()
1521
run_cmd(f"xmake f {xmake_config_flags} -cv")
1622
run_cmd("xmake")
1723
run_cmd("xmake install")

scripts/set_env.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,81 @@
22
import platform
33

44

5+
def _maca_root_from_env():
6+
return (
7+
os.environ.get("MACA_PATH")
8+
or os.environ.get("MACA_HOME")
9+
or os.environ.get("MACA_ROOT")
10+
or ""
11+
).strip()
12+
13+
14+
def metax_hpc_compiler_include_dirs():
15+
"""Directories needed so g++ finds cuda_runtime_api.h (cu-bridge) when compiling against PyTorch c10/cuda headers on MetaX/HPCC."""
16+
maca = _maca_root_from_env()
17+
if not maca:
18+
return []
19+
return [
20+
os.path.join(maca, "tools", "cu-bridge", "include"),
21+
os.path.join(maca, "include", "hcr"),
22+
os.path.join(maca, "include"),
23+
]
24+
25+
26+
def _prepend_path_var(name, prefixes):
27+
"""Prepend colon-separated *prefixes* to env var *name* (POSIX)."""
28+
if not prefixes:
29+
return
30+
chunk = ":".join(prefixes)
31+
cur = os.environ.get(name, "")
32+
os.environ[name] = f"{chunk}:{cur}" if cur else chunk
33+
34+
35+
def ensure_metax_hpc_compiler_includes():
36+
"""
37+
Prepend HPCC/cu-bridge includes to CPATH, CPLUS_INCLUDE_PATH, and C_INCLUDE_PATH.
38+
g++ uses CPLUS_INCLUDE_PATH for .cc files; C_INCLUDE_PATH alone is not enough.
39+
"""
40+
dirs = metax_hpc_compiler_include_dirs()
41+
if not dirs:
42+
return
43+
for var in ("CPATH", "CPLUS_INCLUDE_PATH", "C_INCLUDE_PATH"):
44+
_prepend_path_var(var, dirs)
45+
46+
47+
def _parse_xmake_cli_flag_values(flags: str):
48+
"""Parse a string like '--metax-gpu=y --aten=y' into {key: value}."""
49+
parts = flags.replace("=", " ").split()
50+
d = {}
51+
i = 0
52+
n = len(parts)
53+
while i < n:
54+
p = parts[i]
55+
if p.startswith("--") and len(p) > 2:
56+
key = p[2:].lower()
57+
i += 1
58+
if i < n and not parts[i].startswith("--"):
59+
d[key] = parts[i].lower()
60+
i += 1
61+
else:
62+
d[key] = "y"
63+
else:
64+
i += 1
65+
return d
66+
67+
68+
def _truthy_flag_value(v: str) -> bool:
69+
return v in ("y", "yes", "true", "1", "on")
70+
71+
72+
def xmake_flags_need_metax_aten_torch_includes(flags: str) -> bool:
73+
"""True when install.py-style args enable MetaX GPU and ATen (PyTorch) together."""
74+
d = _parse_xmake_cli_flag_values(flags)
75+
return _truthy_flag_value(d.get("metax-gpu", "n")) and _truthy_flag_value(
76+
d.get("aten", "n")
77+
)
78+
79+
580
def set_env():
681
if os.environ.get("INFINI_ROOT") == None:
782
os.environ["INFINI_ROOT"] = os.path.expanduser("~/.infini")

src/infinicore/adaptor/aten_adaptor.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ at::Tensor to_aten_tensor(const infinicore::Tensor &t) {
3232
options);
3333
}
3434

35-
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_QY_API)
35+
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_METAX_API)
3636
c10::cuda::CUDAStream get_cuda_stream() {
3737
return c10::cuda::getStreamFromExternal(
3838
cudaStream_t(infinicore::context::getStream()), infinicore::context::getDevice().getIndex());

src/infinicore/ops/mha_kvcache/mha_kvcache_flashattn.cc

Lines changed: 38 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,18 @@
44

55
#include <stdexcept>
66

7+
#ifdef ENABLE_FLASH_ATTN
8+
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_METAX_API)
9+
#include <c10/cuda/CUDAGuard.h>
10+
#endif
11+
#endif
12+
13+
#if defined(ENABLE_METAX_API)
14+
#define INFINICORE_FLASH_OP(name) ::name
15+
#else
16+
#define INFINICORE_FLASH_OP(name) flash::name
17+
#endif
18+
719
namespace infinicore::op::mha_kvcache_impl::flashattn {
820

921
struct PlannedMeta {
@@ -33,22 +45,24 @@ void *plan(Tensor out,
3345

3446
void run(void *planned_meta) {
3547
#ifdef ENABLE_FLASH_ATTN
48+
#ifdef ENABLE_NVIDIA_API
3649
c10::cuda::CUDAStreamGuard guard(infinicore::adaptor::get_cuda_stream());
50+
#elif defined(ENABLE_METAX_API)
51+
c10::cuda::CUDAStreamGuard guard(infinicore::adaptor::get_cuda_stream());
52+
#endif
3753
auto *p = reinterpret_cast<PlannedMeta *>(planned_meta);
3854

39-
auto out_tensor = infinicore::adaptor::to_aten_tensor(p->out);
40-
auto q = infinicore::adaptor::to_aten_tensor(p->q);
41-
#if defined(ENABLE_NVIDIA_API)
42-
auto k_cache = infinicore::adaptor::to_aten_tensor(p->k_cache);
43-
auto v_cache = infinicore::adaptor::to_aten_tensor(p->v_cache);
44-
#elif defined(ENABLE_QY_API)
55+
// FlashAttention kernels expect standard dense layout (contiguous last dimension).
56+
auto out_at = infinicore::adaptor::to_aten_tensor(p->out);
57+
const bool out_need_copy_back = !out_at.is_contiguous();
58+
auto out_tensor = out_need_copy_back ? out_at.contiguous() : out_at;
59+
auto q = infinicore::adaptor::to_aten_tensor(p->q).contiguous();
4560
auto k_cache = infinicore::adaptor::to_aten_tensor(p->k_cache).contiguous();
4661
auto v_cache = infinicore::adaptor::to_aten_tensor(p->v_cache).contiguous();
47-
#endif
48-
auto seqlens_k = std::optional<const at::Tensor>(infinicore::adaptor::to_aten_tensor(p->seqlens_k));
49-
auto block_table = std::optional<at::Tensor>(infinicore::adaptor::to_aten_tensor(p->block_table));
62+
auto seqlens_k = std::optional<const at::Tensor>(infinicore::adaptor::to_aten_tensor(p->seqlens_k).contiguous());
63+
auto block_table = std::optional<at::Tensor>(infinicore::adaptor::to_aten_tensor(p->block_table).contiguous());
5064
auto alibi_slopes = p->alibi_slopes
51-
? std::optional<at::Tensor>(infinicore::adaptor::to_aten_tensor(*p->alibi_slopes))
65+
? std::optional<at::Tensor>(infinicore::adaptor::to_aten_tensor(*p->alibi_slopes).contiguous())
5266
: std::nullopt;
5367

5468
std::optional<const at::Tensor> k_new = std::nullopt;
@@ -65,7 +79,11 @@ void run(void *planned_meta) {
6579
auto out = use_dynamic_out ? std::optional<at::Tensor>(std::nullopt)
6680
: std::optional<at::Tensor>(out_tensor);
6781

68-
auto result = flash::mha_fwd_kvcache(
82+
#if defined(ENABLE_METAX_API) && defined(INFINICORE_HPCC_VERSION_MAJOR) && (INFINICORE_HPCC_VERSION_MAJOR >= 3)
83+
std::optional<at::Tensor> flash_attn_mars_ext = std::nullopt;
84+
#endif
85+
86+
auto result = INFINICORE_FLASH_OP(mha_fwd_kvcache)(
6987
q,
7088
k_cache,
7189
v_cache,
@@ -85,11 +103,19 @@ void run(void *planned_meta) {
85103
-1,
86104
0.0f,
87105
false,
88-
0);
106+
0
107+
#if defined(ENABLE_METAX_API) && defined(INFINICORE_HPCC_VERSION_MAJOR) && (INFINICORE_HPCC_VERSION_MAJOR >= 3)
108+
,
109+
flash_attn_mars_ext
110+
#endif
111+
);
89112

90113
if (use_dynamic_out) {
91114
out_tensor.copy_(result[0]);
92115
}
116+
if (out_need_copy_back) {
117+
out_at.copy_(out_tensor);
118+
}
93119
#else
94120
throw std::runtime_error("FlashAttention is not enabled in this build");
95121
#endif

0 commit comments

Comments
 (0)