Skip to content

Commit af889fd

Browse files
committed
fix format
Signed-off-by: Ceng23333 <441651826@qq.com>
1 parent 9c4d486 commit af889fd

File tree

6 files changed

+50
-38
lines changed

6 files changed

+50
-38
lines changed

include/infinicore/adaptor/aten_adaptor.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_METAX_API) || defined(ENABLE_QY_API)
99
#include <c10/cuda/CUDAStream.h>
10+
#include <c10/cuda/CUDAGuard.h>
1011
#endif
1112

1213
#ifdef ENABLE_NVIDIA_API

scripts/install.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
import sys
55
from set_env import (
66
set_env,
7-
ensure_metax_hpc_compiler_includes,
8-
xmake_flags_need_metax_aten_torch_includes,
7+
ensure_aten_torch_compiler_includes,
8+
xmake_flags_need_aten_torch_compiler_includes,
99
)
1010

1111
PROJECT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
@@ -16,8 +16,8 @@ def run_cmd(cmd):
1616

1717

1818
def install(xmake_config_flags=""):
19-
if xmake_flags_need_metax_aten_torch_includes(xmake_config_flags):
20-
ensure_metax_hpc_compiler_includes()
19+
if xmake_flags_need_aten_torch_compiler_includes(xmake_config_flags):
20+
ensure_aten_torch_compiler_includes()
2121
run_cmd(f"xmake f {xmake_config_flags} -cv")
2222
run_cmd("xmake")
2323
run_cmd("xmake install")

scripts/set_env.py

Lines changed: 31 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -2,25 +2,15 @@
22
import platform
33

44

5-
def _maca_root_from_env():
6-
return (
7-
os.environ.get("MACA_PATH")
8-
or os.environ.get("MACA_HOME")
9-
or os.environ.get("MACA_ROOT")
10-
or ""
11-
).strip()
12-
13-
14-
def metax_hpc_compiler_include_dirs():
15-
"""Directories needed so g++ finds cuda_runtime_api.h (cu-bridge) when compiling against PyTorch c10/cuda headers on MetaX/HPCC."""
16-
maca = _maca_root_from_env()
17-
if not maca:
18-
return []
19-
return [
20-
os.path.join(maca, "tools", "cu-bridge", "include"),
21-
os.path.join(maca, "include", "hcr"),
22-
os.path.join(maca, "include"),
23-
]
5+
def _hpcc_toolkit_root() -> str:
6+
"""HPCC/MACA install root (cu-bridge, headers). Env vars first; else common container path."""
7+
for key in ("MACA_PATH", "MACA_HOME", "MACA_ROOT"):
8+
v = os.environ.get(key, "").strip()
9+
if v:
10+
return v
11+
if os.path.isdir("/opt/hpcc"):
12+
return "/opt/hpcc"
13+
return ""
2414

2515

2616
def _prepend_path_var(name, prefixes):
@@ -32,14 +22,16 @@ def _prepend_path_var(name, prefixes):
3222
os.environ[name] = f"{chunk}:{cur}" if cur else chunk
3323

3424

35-
def ensure_metax_hpc_compiler_includes():
36-
"""
37-
Prepend HPCC/cu-bridge includes to CPATH, CPLUS_INCLUDE_PATH, and C_INCLUDE_PATH.
38-
g++ uses CPLUS_INCLUDE_PATH for .cc files; C_INCLUDE_PATH alone is not enough.
39-
"""
40-
dirs = metax_hpc_compiler_include_dirs()
41-
if not dirs:
25+
def ensure_aten_torch_compiler_includes() -> None:
26+
"""If HPCC root is known, prepend cu-bridge + HPCC headers for g++ compiling ATen .cc (c10/cuda)."""
27+
root = _hpcc_toolkit_root()
28+
if not root:
4229
return
30+
dirs = [
31+
os.path.join(root, "tools", "cu-bridge", "include"),
32+
os.path.join(root, "include", "hcr"),
33+
os.path.join(root, "include"),
34+
]
4335
for var in ("CPATH", "CPLUS_INCLUDE_PATH", "C_INCLUDE_PATH"):
4436
_prepend_path_var(var, dirs)
4537

@@ -69,12 +61,20 @@ def _truthy_flag_value(v: str) -> bool:
6961
return v in ("y", "yes", "true", "1", "on")
7062

7163

72-
def xmake_flags_need_metax_aten_torch_includes(flags: str) -> bool:
73-
"""True when install.py-style args enable MetaX GPU and ATen (PyTorch) together."""
64+
# xmake.lua GPU / accelerator backends (any of these + aten may compile C++ against torch+cuda-style headers).
65+
_XMAKE_GPU_BACKEND_KEYS = frozenset(
66+
{
67+
"metax-gpu",
68+
}
69+
)
70+
71+
72+
def xmake_flags_need_aten_torch_compiler_includes(flags: str) -> bool:
73+
"""True when ATen is enabled with any GPU/accelerator backend (install.py / xmake f ...)."""
7474
d = _parse_xmake_cli_flag_values(flags)
75-
return _truthy_flag_value(d.get("metax-gpu", "n")) and _truthy_flag_value(
76-
d.get("aten", "n")
77-
)
75+
if not _truthy_flag_value(d.get("aten", "n")):
76+
return False
77+
return any(_truthy_flag_value(d.get(k, "n")) for k in _XMAKE_GPU_BACKEND_KEYS)
7878

7979

8080
def set_env():

src/infinicore/ops/multi_head_attention_varlen/mha_varlen_flashattn.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,8 @@ VarlenFlashPrepared prepare_varlen_flash_tensors(PlannedMeta *p) {
8181
t.max_seqlen_q = p->max_seqlen_q;
8282
t.max_seqlen_k = p->max_seqlen_k;
8383
t.alibi_slopes = p->alibi_slopes
84-
? std::optional<at::Tensor>(infinicore::adaptor::to_aten_tensor(*p->alibi_slopes).contiguous())
85-
: std::nullopt;
84+
? std::optional<at::Tensor>(infinicore::adaptor::to_aten_tensor(*p->alibi_slopes).contiguous())
85+
: std::nullopt;
8686
t.scale = p->scale;
8787
return t;
8888
}

src/infiniop/ops/binary_cross_entropy_with_logits/metax/binary_cross_entropy_with_logits_metax.maca

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
#include "../../../devices/metax/metax_common.h"
22
#include "../../../devices/metax/metax_handle.h"
33
#include "../../../devices/metax/metax_kernel_common.h"
4+
45
#include "binary_cross_entropy_with_logits_metax.h"
5-
#include <hc_runtime.h>
6+
67
#include <type_traits>
78

89
namespace op::bce_with_logits::metax {

src/infiniop/ops/equal/metax/equal_metax.maca

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,27 +13,37 @@ infiniStatus_t Descriptor::create(
1313
Descriptor **desc_ptr,
1414
infiniopTensorDescriptor_t out_desc,
1515
std::vector<infiniopTensorDescriptor_t> input_desc_vec) {
16+
1617
auto handle = reinterpret_cast<device::metax::Handle *>(handle_);
18+
1719
const auto &a_desc = input_desc_vec.at(0);
1820
auto compute_dtype = a_desc->dtype();
1921
auto out_dtype = out_desc->dtype();
22+
2023
const auto &b_desc = input_desc_vec.at(1);
2124
const auto &c_shape = out_desc->shape();
2225
const auto &a_shape = a_desc->shape();
2326
const auto &b_shape = b_desc->shape();
27+
2428
CHECK_DTYPE(compute_dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_BF16,
2529
INFINI_DTYPE_I32, INFINI_DTYPE_I64, INFINI_DTYPE_F64);
30+
2631
CHECK_DTYPE(out_dtype, INFINI_DTYPE_BOOL);
32+
2733
CHECK_SAME_SHAPE(c_shape, a_shape, b_shape);
34+
2835
CREATE_ELEMENTWISE_METAX_DESCRIPTOR(handle, compute_dtype, out_desc, input_desc_vec)
36+
2937
return INFINI_STATUS_SUCCESS;
3038
}
39+
3140
infiniStatus_t Descriptor::calculate(
3241
void *workspace,
3342
size_t workspace_size,
3443
void *output,
3544
std::vector<const void *> inputs,
3645
void *stream) const {
46+
3747
if (workspace_size < _workspace_size) {
3848
return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
3949
}

0 commit comments

Comments
 (0)