55
66local CUTLASS_ROOT = os.getenv (" CUTLASS_ROOT" ) or os.getenv (" CUTLASS_HOME" ) or os.getenv (" CUTLASS_PATH" )
77
8- if CUTLASS_ROOT ~= nil then
9- add_includedirs (CUTLASS_ROOT )
10- end
11-
128local FLASH_ATTN_ROOT = get_config (" flash-attn" )
139
1410local INFINI_ROOT = os.getenv (" INFINI_ROOT" ) or (os.getenv (is_host (" windows" ) and " HOMEPATH" or " HOME" ) .. " /.infini" )
@@ -38,6 +34,40 @@ target("infiniop-nvidia")
3834 target :add (" linkdirs" , path .directory (path .directory (nvcc_path )) .. " /lib64/stubs" )
3935 target :add (" links" , " cuda" )
4036 end
37+
38+ -- Auto-detect CUDA arch when no explicit --cuda_arch
39+ local arch_opt = get_config (" cuda_arch" )
40+ if not arch_opt or type (arch_opt ) ~= " string" then
41+ local ok , sm_str = os .iorunv (" nvidia-smi" , {" --query-gpu=compute_cap" , " --format=csv,noheader,nounits" })
42+ if ok and sm_str then
43+ local major , minor = sm_str :match (" (%d+)%.(%d+)" )
44+ if major then
45+ local sm = tonumber (major ) * 10 + tonumber (minor )
46+ local archs = {}
47+ if sm >= 75 then table.insert (archs , " sm_75" ) end
48+ if sm >= 80 then table.insert (archs , " sm_80" ) end
49+ if sm >= 86 then table.insert (archs , " sm_86" ) end
50+ if sm >= 89 then table.insert (archs , " sm_89" ) end
51+ -- H100 (sm_90a): use sm_90a for cutlass 3.x
52+ if sm == 90 then
53+ target :add (" cuflags" , " -gencode=arch=compute_90a,code=sm_90a" )
54+ elseif sm > 90 then
55+ table.insert (archs , " sm_90" )
56+ end
57+ if # archs == 0 then
58+ target :add (" cugencodes" , " native" )
59+ end
60+ for _ , arch in ipairs (archs ) do
61+ local compute = arch :gsub (" sm_" , " compute_" )
62+ target :add (" cuflags" , " -gencode=arch=" .. compute .. " ,code=" .. arch )
63+ end
64+ else
65+ target :add (" cugencodes" , " native" )
66+ end
67+ else
68+ target :add (" cugencodes" , " native" )
69+ end
70+ end
4171 end )
4272
4373 if is_plat (" windows" ) then
@@ -63,15 +93,19 @@ target("infiniop-nvidia")
6393
6494 add_cuflags (" -Xcompiler=-Wno-error=deprecated-declarations" , " -Xcompiler=-Wno-error=unused-function" )
6595
96+ -- Cutlass: enable I8 Gemm when CUTLASS_ROOT is set
97+ if CUTLASS_ROOT ~= nil then
98+ add_defines (" ENABLE_CUTLASS_API" )
99+ add_includedirs (CUTLASS_ROOT , CUTLASS_ROOT .. " /include" , CUTLASS_ROOT .. " /tools/util/include" )
100+ end
101+
66102 local arch_opt = get_config (" cuda_arch" )
67103 if arch_opt and type (arch_opt ) == " string" then
68104 for _ , arch in ipairs (arch_opt :split (" ," )) do
69105 arch = arch :trim ()
70106 local compute = arch :gsub (" sm_" , " compute_" )
71107 add_cuflags (" -gencode=arch=" .. compute .. " ,code=" .. arch )
72108 end
73- else
74- add_cugencodes (" native" )
75109 end
76110
77111 set_languages (" cxx17" )
@@ -151,13 +185,15 @@ target("flash-attn-nvidia")
151185 local PYTHON_INCLUDE = os .iorunv (" python" , {" -c" , " import sysconfig; print(sysconfig.get_paths()['include'])" }):trim ()
152186 local PYTHON_LIB_DIR = os .iorunv (" python" , {" -c" , " import sysconfig; print(sysconfig.get_config_var('LIBDIR'))" }):trim ()
153187 local LIB_PYTHON = os .iorunv (" python" , {" -c" , " import glob,sysconfig,os;print(glob.glob(os.path.join(sysconfig.get_config_var('LIBDIR'),'libpython*.so'))[0])" }):trim ()
154-
188+
155189 -- Include dirs (needed for both device and host)
156190 target :add (" includedirs" , FLASH_ATTN_ROOT .. " /csrc/flash_attn/src" , {public = false })
157191 target :add (" includedirs" , TORCH_DIR .. " /include/torch/csrc/api/include" , {public = false })
158192 target :add (" includedirs" , TORCH_DIR .. " /include" , {public = false })
159193 target :add (" includedirs" , PYTHON_INCLUDE , {public = false })
160- target :add (" includedirs" , CUTLASS_ROOT .. " /include" , {public = false })
194+ if CUTLASS_ROOT ~= nil then
195+ target :add (" includedirs" , CUTLASS_ROOT .. " /include" , {public = false })
196+ end
161197 target :add (" includedirs" , FLASH_ATTN_ROOT .. " /csrc/flash_attn" , {public = false })
162198
163199 -- Link libraries
@@ -167,10 +203,10 @@ target("flash-attn-nvidia")
167203
168204 add_files (FLASH_ATTN_ROOT .. " /csrc/flash_attn/flash_api.cpp" )
169205 add_files (FLASH_ATTN_ROOT .. " /csrc/flash_attn/src/*.cu" )
170-
206+
171207 -- Link options
172208 add_ldflags (" -Wl,--no-undefined" , {force = true })
173-
209+
174210 -- Compile options
175211 add_cxflags (" -fPIC" , {force = true })
176212 add_cuflags (" -Xcompiler=-fPIC" )
0 commit comments