Skip to content

Commit 4acc528

Browse files
authored
Merge pull request #1152 from qinyiqun/fix_cutlass_root
[fix] fix CUTLASS_ROOT error when compiling nvidia_int8_gemm and simplify build setup
2 parents e16f15d + 4c57d79 commit 4acc528

2 files changed

Lines changed: 46 additions & 20 deletions

File tree

xmake.lua

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -67,16 +67,6 @@ if has_config("cudnn") then
6767
add_defines("ENABLE_CUDNN_API")
6868
end
6969

70-
option("cutlass")
71-
set_default(false)
72-
set_showmenu(true)
73-
set_description("Whether to compile cutlass for Nvidia GPU")
74-
option_end()
75-
76-
if has_config("cutlass") then
77-
add_defines("ENABLE_CUTLASS_API")
78-
end
79-
8070
option("cuda_arch")
8171
set_showmenu(true)
8272
set_description("Set CUDA GPU architecture (e.g. sm_90)")

xmake/nvidia.lua

Lines changed: 46 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,6 @@ end
55

66
local CUTLASS_ROOT = os.getenv("CUTLASS_ROOT") or os.getenv("CUTLASS_HOME") or os.getenv("CUTLASS_PATH")
77

8-
if CUTLASS_ROOT ~= nil then
9-
add_includedirs(CUTLASS_ROOT)
10-
end
11-
128
local FLASH_ATTN_ROOT = get_config("flash-attn")
139

1410
local INFINI_ROOT = os.getenv("INFINI_ROOT") or (os.getenv(is_host("windows") and "HOMEPATH" or "HOME") .. "/.infini")
@@ -38,6 +34,40 @@ target("infiniop-nvidia")
3834
target:add("linkdirs", path.directory(path.directory(nvcc_path)) .. "/lib64/stubs")
3935
target:add("links", "cuda")
4036
end
37+
38+
-- Auto-detect the CUDA arch when --cuda_arch is not given explicitly
39+
local arch_opt = get_config("cuda_arch")
40+
if not arch_opt or type(arch_opt) ~= "string" then
41+
local ok, sm_str = os.iorunv("nvidia-smi", {"--query-gpu=compute_cap", "--format=csv,noheader,nounits"})
42+
if ok and sm_str then
43+
local major, minor = sm_str:match("(%d+)%.(%d+)")
44+
if major then
45+
local sm = tonumber(major) * 10 + tonumber(minor)
46+
local archs = {}
47+
if sm >= 75 then table.insert(archs, "sm_75") end
48+
if sm >= 80 then table.insert(archs, "sm_80") end
49+
if sm >= 86 then table.insert(archs, "sm_86") end
50+
if sm >= 89 then table.insert(archs, "sm_89") end
51+
-- H100 (compute capability 9.0): target sm_90a for CUTLASS 3.x
52+
if sm == 90 then
53+
target:add("cuflags", "-gencode=arch=compute_90a,code=sm_90a")
54+
elseif sm > 90 then
55+
table.insert(archs, "sm_90")
56+
end
57+
if #archs == 0 then
58+
target:add("cugencodes", "native")
59+
end
60+
for _, arch in ipairs(archs) do
61+
local compute = arch:gsub("sm_", "compute_")
62+
target:add("cuflags", "-gencode=arch=" .. compute .. ",code=" .. arch)
63+
end
64+
else
65+
target:add("cugencodes", "native")
66+
end
67+
else
68+
target:add("cugencodes", "native")
69+
end
70+
end
4171
end)
4272

4373
if is_plat("windows") then
@@ -63,15 +93,19 @@ target("infiniop-nvidia")
6393

6494
add_cuflags("-Xcompiler=-Wno-error=deprecated-declarations", "-Xcompiler=-Wno-error=unused-function")
6595

96+
-- CUTLASS: enable the INT8 GEMM when CUTLASS_ROOT is set
97+
if CUTLASS_ROOT ~= nil then
98+
add_defines("ENABLE_CUTLASS_API")
99+
add_includedirs(CUTLASS_ROOT, CUTLASS_ROOT .. "/include", CUTLASS_ROOT .. "/tools/util/include")
100+
end
101+
66102
local arch_opt = get_config("cuda_arch")
67103
if arch_opt and type(arch_opt) == "string" then
68104
for _, arch in ipairs(arch_opt:split(",")) do
69105
arch = arch:trim()
70106
local compute = arch:gsub("sm_", "compute_")
71107
add_cuflags("-gencode=arch=" .. compute .. ",code=" .. arch)
72108
end
73-
else
74-
add_cugencodes("native")
75109
end
76110

77111
set_languages("cxx17")
@@ -151,13 +185,15 @@ target("flash-attn-nvidia")
151185
local PYTHON_INCLUDE = os.iorunv("python", {"-c", "import sysconfig; print(sysconfig.get_paths()['include'])"}):trim()
152186
local PYTHON_LIB_DIR = os.iorunv("python", {"-c", "import sysconfig; print(sysconfig.get_config_var('LIBDIR'))"}):trim()
153187
local LIB_PYTHON = os.iorunv("python", {"-c", "import glob,sysconfig,os;print(glob.glob(os.path.join(sysconfig.get_config_var('LIBDIR'),'libpython*.so'))[0])"}):trim()
154-
188+
155189
-- Include dirs (needed for both device and host)
156190
target:add("includedirs", FLASH_ATTN_ROOT .. "/csrc/flash_attn/src", {public = false})
157191
target:add("includedirs", TORCH_DIR .. "/include/torch/csrc/api/include", {public = false})
158192
target:add("includedirs", TORCH_DIR .. "/include", {public = false})
159193
target:add("includedirs", PYTHON_INCLUDE, {public = false})
160-
target:add("includedirs", CUTLASS_ROOT .. "/include", {public = false})
194+
if CUTLASS_ROOT ~= nil then
195+
target:add("includedirs", CUTLASS_ROOT .. "/include", {public = false})
196+
end
161197
target:add("includedirs", FLASH_ATTN_ROOT .. "/csrc/flash_attn", {public = false})
162198

163199
-- Link libraries
@@ -167,10 +203,10 @@ target("flash-attn-nvidia")
167203

168204
add_files(FLASH_ATTN_ROOT .. "/csrc/flash_attn/flash_api.cpp")
169205
add_files(FLASH_ATTN_ROOT .. "/csrc/flash_attn/src/*.cu")
170-
206+
171207
-- Link options
172208
add_ldflags("-Wl,--no-undefined", {force = true})
173-
209+
174210
-- Compile options
175211
add_cxflags("-fPIC", {force = true})
176212
add_cuflags("-Xcompiler=-fPIC")

0 commit comments

Comments
 (0)