Skip to content

Commit 21bf5e1

Browse files
committed
alias maca version checker
Signed-off-by: Ceng23333 <441651826@qq.com>
1 parent 72e8bc7 commit 21bf5e1

7 files changed

Lines changed: 120 additions & 86 deletions

File tree

include/infinicore/nn/parameter.hpp

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,12 @@ class Parameter : public Tensor {
1212
Size tp_rank = 0,
1313
Size tp_size = 1);
1414

15+
Parameter(const Tensor &tensor,
16+
Size tp_dim,
17+
Size tp_rank,
18+
Size tp_size,
19+
Size num_shards);
20+
1521
Parameter(const Shape &shape,
1622
const DataType &dtype,
1723
const Device &device,
@@ -25,8 +31,9 @@ class Parameter : public Tensor {
2531

2632
protected:
2733
// Tensor parallel configs
28-
Size tp_dim_; // dimension partitioned
29-
Size tp_rank_; // rank of this partition among tp group
30-
Size tp_size_; // total number of partitions
34+
Size tp_dim_; // dimension partitioned
35+
Size tp_rank_; // rank of this partition among tp group
36+
Size tp_size_; // total number of partitions
37+
Size num_shards_ = 0; // logical shards (e.g. KV heads) when tp_size > num_kv_head
3138
};
3239
} // namespace infinicore::nn

scripts/install.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@
44
import sys
55
from set_env import (
66
set_env,
7-
ensure_aten_torch_compiler_includes,
8-
xmake_flags_need_aten_torch_compiler_includes,
7+
set_env_by_config,
98
)
109

1110
PROJECT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
@@ -16,8 +15,7 @@ def run_cmd(cmd):
1615

1716

1817
def install(xmake_config_flags=""):
19-
if xmake_flags_need_aten_torch_compiler_includes(xmake_config_flags):
20-
ensure_aten_torch_compiler_includes()
18+
set_env_by_config(xmake_config_flags)
2119
run_cmd(f"xmake f {xmake_config_flags} -cv")
2220
run_cmd("xmake")
2321
run_cmd("xmake install")

scripts/metax_env.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
import os
2+
3+
4+
def _first_existing_dir(paths: list[str]) -> str:
    """Return the first non-empty entry of *paths* that is an existing directory.

    Entries are tried in order; ``""`` is returned when none qualifies.
    """
    existing = (
        candidate for candidate in paths if candidate and os.path.isdir(candidate)
    )
    return next(existing, "")
9+
10+
11+
def _metax_toolkit_root(use_mc: bool) -> str:
    """Locate the toolkit root for a MetaX build.

    With *use_mc* set (MACA), the first non-blank value among the
    ``MACA_PATH``, ``MACA_HOME`` and ``MACA_ROOT`` environment variables is
    returned, whitespace-stripped and without checking that it exists; when
    none is set, ``/opt/maca`` is used if it is a directory.  Without
    *use_mc* (HPCC), only ``/opt/hpcc`` is considered.

    Returns ``""`` when no root can be determined.
    """
    if not use_mc:
        return _first_existing_dir(["/opt/hpcc"])

    env_values = (
        os.environ.get(name, "").strip()
        for name in ("MACA_PATH", "MACA_HOME", "MACA_ROOT")
    )
    configured = next((value for value in env_values if value), "")
    return configured or _first_existing_dir(["/opt/maca"])
20+
21+
22+
def _prepend_path_var(name: str, prefixes: list[str]) -> None:
    """Prepend *prefixes* to the path-list environment variable *name*.

    Entries are joined with ``os.pathsep`` (``:`` on POSIX).  The call is
    idempotent: prefixes already present in the variable are moved to the
    front instead of being duplicated, so running the setup twice does not
    grow the variable.  Empty prefixes are ignored, because compilers treat
    an empty element in ``CPATH``-style variables as the current directory.

    Args:
        name: Environment variable to update (e.g. ``CPATH``).
        prefixes: Directories to put at the front, highest priority first.
    """
    # Drop empty entries and repeated prefixes while keeping the given order.
    fresh = list(dict.fromkeys(p for p in prefixes if p))
    if not fresh:
        return

    parts = fresh
    cur = os.environ.get(name, "")
    if cur:
        # Keep the existing entries (including deliberately empty ones), but
        # remove those we are about to prepend; the first occurrence wins in
        # a search path, so this does not change lookup order.
        moved = set(fresh)
        parts = fresh + [entry for entry in cur.split(os.pathsep) if entry not in moved]
    os.environ[name] = os.pathsep.join(parts)
29+
30+
31+
def set_env_for_metax_gpu(
    flags: str,
    *,
    parse_xmake_cli_flag_values,
    truthy_flag_value,
) -> None:
    """Expose the MetaX toolkit headers to the compiler for ATen-enabled builds.

    The xmake flag string is parsed with the injected helpers.  Nothing is
    changed unless both ``aten`` and ``metax-gpu`` are truthy and a toolkit
    root is found (MACA when ``use-mc`` is truthy, HPCC otherwise).  In that
    case the toolkit's ``tools/cu-bridge/include``, ``include/hcr`` and
    ``include`` directories are prepended to ``CPATH``,
    ``CPLUS_INCLUDE_PATH`` and ``C_INCLUDE_PATH``.

    Args:
        flags: xmake configuration flags as passed on the command line.
        parse_xmake_cli_flag_values: Callable turning *flags* into a mapping
            of option name to value.
        truthy_flag_value: Callable deciding whether an option value counts
            as enabled.
    """
    options = parse_xmake_cli_flag_values(flags)

    def enabled(option: str):
        # Options that are absent are treated as "n".
        return truthy_flag_value(options.get(option, "n"))

    if not enabled("aten") or not enabled("metax-gpu"):
        return

    toolkit_root = _metax_toolkit_root(use_mc=enabled("use-mc"))
    if not toolkit_root:
        return

    include_dirs = [
        os.path.join(toolkit_root, *parts)
        for parts in (
            ("tools", "cu-bridge", "include"),
            ("include", "hcr"),
            ("include",),
        )
    ]
    for env_name in ("CPATH", "CPLUS_INCLUDE_PATH", "C_INCLUDE_PATH"):
        _prepend_path_var(env_name, include_dirs)
60+

scripts/set_env.py

Lines changed: 11 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,7 @@
11
import os
22
import platform
33

4-
5-
def _hpcc_toolkit_root() -> str:
6-
"""HPCC/MACA install root (cu-bridge, headers). Env vars first; else common container path."""
7-
for key in ("MACA_PATH", "MACA_HOME", "MACA_ROOT"):
8-
v = os.environ.get(key, "").strip()
9-
if v:
10-
return v
11-
if os.path.isdir("/opt/hpcc"):
12-
return "/opt/hpcc"
13-
return ""
14-
15-
16-
def _prepend_path_var(name, prefixes):
17-
"""Prepend colon-separated *prefixes* to env var *name* (POSIX)."""
18-
if not prefixes:
19-
return
20-
chunk = ":".join(prefixes)
21-
cur = os.environ.get(name, "")
22-
os.environ[name] = f"{chunk}:{cur}" if cur else chunk
23-
24-
25-
def ensure_aten_torch_compiler_includes() -> None:
26-
"""If HPCC root is known, prepend cu-bridge + HPCC headers for g++ compiling ATen .cc (c10/cuda)."""
27-
root = _hpcc_toolkit_root()
28-
if not root:
29-
return
30-
dirs = [
31-
os.path.join(root, "tools", "cu-bridge", "include"),
32-
os.path.join(root, "include", "hcr"),
33-
os.path.join(root, "include"),
34-
]
35-
for var in ("CPATH", "CPLUS_INCLUDE_PATH", "C_INCLUDE_PATH"):
36-
_prepend_path_var(var, dirs)
4+
from metax_env import set_env_for_metax_gpu
375

386

397
def _parse_xmake_cli_flag_values(flags: str):
@@ -61,20 +29,17 @@ def _truthy_flag_value(v: str) -> bool:
6129
return v in ("y", "yes", "true", "1", "on")
6230

6331

64-
# xmake.lua GPU / accelerator backends (any of these + aten may compile C++ against torch+cuda-style headers).
65-
_XMAKE_GPU_BACKEND_KEYS = frozenset(
66-
{
67-
"metax-gpu",
68-
}
69-
)
70-
71-
72-
def xmake_flags_need_aten_torch_compiler_includes(flags: str) -> bool:
73-
"""True when ATen is enabled with any GPU/accelerator backend (install.py / xmake f ...)."""
32+
def set_env_by_config(flags: str) -> None:
    """Apply backend-specific environment setup for the given xmake config flags.

    Only the MetaX backend (``metax-gpu`` truthy) has setup here; for any
    other configuration the environment is left untouched.
    """
    options = _parse_xmake_cli_flag_values(flags)
    if not _truthy_flag_value(options.get("metax-gpu", "n")):
        return

    # The parsing helpers are passed in so metax_env does not import this module.
    set_env_for_metax_gpu(
        flags,
        parse_xmake_cli_flag_values=_parse_xmake_cli_flag_values,
        truthy_flag_value=_truthy_flag_value,
    )
7843

7944

8045
def set_env():

src/infinicore/nn/parameter.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,11 @@ Parameter::Parameter(const Tensor &tensor, Size tp_dim, Size tp_rank, Size tp_si
3030
}
3131
}
3232

33+
// Constructs a Parameter through the (tensor, tp_dim, tp_rank, tp_size)
// constructor and then records the logical shard count.
Parameter::Parameter(const Tensor &tensor,
                     Size tp_dim,
                     Size tp_rank,
                     Size tp_size,
                     Size num_shards)
    : Parameter(tensor, tp_dim, tp_rank, tp_size) {
    // A delegating constructor cannot also carry member initializers,
    // so the field is assigned in the body.
    this->num_shards_ = num_shards;
}
37+
3338
Parameter::Parameter(
3439
const Shape &shape,
3540
const DataType &dtype,

xmake.lua

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -167,10 +167,13 @@ option_end()
167167

168168
if has_config("metax-gpu") then
169169
add_defines("ENABLE_METAX_API")
170-
-- Container torch build expects this for ATen headers on hpcc.
171-
add_defines("USE_HPCC")
172170
if has_config("use-mc") then
173171
add_defines("ENABLE_METAX_MC_API")
172+
-- MACA torch build expects USE_MACA for ATen headers (e.g. C10_WARP_SIZE).
173+
add_defines("USE_MACA")
174+
else
175+
-- HPCC torch build expects this for ATen headers on hpcc.
176+
add_defines("USE_HPCC")
174177
end
175178
includes("xmake/metax.lua")
176179
end
@@ -481,6 +484,30 @@ target("infinicore_cpp_api")
481484
-- and `xmake/qy.lua`.
482485

483486
before_build(function (target)
487+
-- MetaX + flash-attn: `flash_attn_2_cuda` may use a different `mha_fwd_kvcache` ABI
488+
-- depending on the underlying stack version. When building with MACA (`--use-mc=y`),
489+
-- the version file is typically `/opt/maca/Version.txt` (HPCC uses `/opt/hpcc/Version.txt`).
490+
if has_config("metax-gpu") and get_config("flash-attn") and get_config("flash-attn") ~= "" then
491+
local version_txt = "/opt/hpcc/Version.txt"
492+
if not os.isfile(version_txt) and has_config("use-mc") then
493+
version_txt = "/opt/maca/Version.txt"
494+
end
495+
if os.isfile(version_txt) then
496+
local content = os.iorunv("cat", {version_txt}) or ""
497+
content = content:trim()
498+
local major_str = content:match("Version:(%d+)") or content:match("^(%d+)")
499+
if major_str and major_str ~= "" then
500+
local major = tonumber(major_str)
501+
if major then
502+
local define = "INFINICORE_HPCC_VERSION_MAJOR=" .. tostring(major)
503+
target:add("defines", define)
504+
target:add("cxflags", "-D" .. define)
505+
target:add("cxxflags", "-D" .. define)
506+
end
507+
end
508+
end
509+
end
510+
484511
if has_config("aten") then
485512
local outdata = os.iorunv("python", {"-c", "import torch, os; print(os.path.dirname(torch.__file__))"}):trim()
486513
local TORCH_DIR = outdata
@@ -531,7 +558,6 @@ target("infinicore_cpp_api")
531558
target_end()
532559

533560
target("_infinicore")
534-
add_packages("boost")
535561
if is_mode("debug") then
536562
add_defines("BOOST_STACKTRACE_USE_BACKTRACE")
537563
add_links("backtrace")

xmake/metax.lua

Lines changed: 3 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ end
1111

1212
-- Resolve MetaX flash-attn .so path (used only from this file: `before_link` sandbox cannot see globals from `xmake.lua`).
1313
local FLASH_ATTN_METAX_CUDA_SO_CONTAINER_DEFAULT =
14-
"/opt/conda/lib/python3.10/site-packages/flash_attn_2_cuda.cpython-310-aarch64-linux-gnu.so"
14+
"/opt/conda/lib/python3.10/site-packages/flash_attn_2_cuda.cpython-310-x86_64-linux-gnu.so"
1515

1616
local function metax_flash_attn_cuda_so_path()
1717
-- Highest priority: override the exact `.so` file to link.
@@ -41,36 +41,9 @@ local function metax_flash_attn_cuda_so_path()
4141
return container_path
4242
end
4343

44-
-- Set numeric HPCC version macro for flash-attn signature/call compatibility.
45-
-- Must be done before compiling `infinicore_cpp_api` sources.
44+
-- MetaX flash-attn link flags for pip `flash_attn_2_cuda`.
45+
-- Version/ABI macros are set in `xmake.lua` for `infinicore_cpp_api` so they apply to all sources.
4646
target("infinicore_cpp_api")
47-
before_build(function (target)
48-
if not has_config("metax-gpu") then
49-
return
50-
end
51-
if not (get_config("flash-attn") and get_config("flash-attn") ~= "") then
52-
return
53-
end
54-
55-
local version_txt = "/opt/hpcc/Version.txt"
56-
if os.isfile(version_txt) then
57-
local content = os.iorunv("cat", {version_txt}) or ""
58-
content = content:trim()
59-
-- Example: `Version:2.32.0.6`
60-
local hpcc_major_str = content:match("Version:(%d+)") or content:match("^(%d+)")
61-
if hpcc_major_str and hpcc_major_str ~= "" then
62-
local hpcc_major = tonumber(hpcc_major_str)
63-
if hpcc_major then
64-
local define = "INFINICORE_HPCC_VERSION_MAJOR=" .. tostring(hpcc_major)
65-
-- `defines` is the logical flag list for the target,
66-
-- but we also pass `-D...` directly to ensure it reaches compilation.
67-
target:add("defines", define)
68-
target:add("cxflags", "-D" .. define)
69-
target:add("cxxflags", "-D" .. define)
70-
end
71-
end
72-
end
73-
end)
7447
if get_config("flash-attn") and get_config("flash-attn") ~= "" then
7548
before_link(function (target)
7649
local flash_so_metax = metax_flash_attn_cuda_so_path()

0 commit comments

Comments
 (0)