Skip to content

Commit a8f1ace

Browse files
authored
[FIX] Skip metal target tag registration for unsupported LLVM CPUs (#19427)
1 parent d7282a3 commit a8f1ace

3 files changed

Lines changed: 37 additions & 1 deletion

File tree

python/tvm/target/codegen.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,24 @@ def llvm_get_vector_width(target=None):
229229
return _ffi_api.llvm_get_vector_width(target)
230230

231231

232+
def llvm_is_valid_cpu(cpu, triple):
233+
"""Check if a CPU name is valid for the given LLVM triple.
234+
235+
Parameters
236+
----------
237+
cpu : str
238+
The CPU name to check (e.g. "apple-m1").
239+
triple : str
240+
The LLVM target triple (e.g. "arm64-apple-macos").
241+
242+
Returns
243+
-------
244+
is_valid : bool
245+
True if the CPU name is recognized by LLVM for the given triple.
246+
"""
247+
return _ffi_api.llvm_is_valid_cpu(cpu, triple)
248+
249+
232250
def llvm_version_major(allow_none=False):
233251
"""Get the major LLVM version.
234252

python/tvm/target/tag_registry/metal.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,17 @@
1818

1919
from .registry import register_tag
2020

21+
_METAL_HOST_TRIPLE = "arm64-apple-macos"
22+
2123

2224
def _register_metal_tag(name, max_threads, shared_mem, warp_size, mcpu):
25+
try:
26+
from ..codegen import llvm_is_valid_cpu
27+
28+
if not llvm_is_valid_cpu(mcpu, _METAL_HOST_TRIPLE):
29+
return
30+
except Exception: # pylint: disable=broad-except
31+
pass # LLVM not available; register unconditionally
2332
register_tag(
2433
name,
2534
{
@@ -29,7 +38,7 @@ def _register_metal_tag(name, max_threads, shared_mem, warp_size, mcpu):
2938
"thread_warp_size": warp_size,
3039
"host": {
3140
"kind": "llvm",
32-
"mtriple": "arm64-apple-macos",
41+
"mtriple": _METAL_HOST_TRIPLE,
3342
"mcpu": mcpu,
3443
},
3544
},

src/target/llvm/llvm_module.cc

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -743,6 +743,15 @@ static void LLVMReflectionRegister() {
743743
LLVMTargetInfo llvm_target(*llvm_instance, use_target);
744744
return llvm_target.TargetHasCPUFeature(feature);
745745
})
746+
.def("target.llvm_is_valid_cpu",
747+
[](ffi::String cpu, ffi::String triple) -> bool {
748+
auto llvm_instance = std::make_unique<LLVMInstance>();
749+
ffi::Map<ffi::String, ffi::Any> target_map;
750+
target_map.Set("kind", ffi::String("llvm"));
751+
target_map.Set("mtriple", triple);
752+
LLVMTargetInfo llvm_backend(*llvm_instance, Target(target_map));
753+
return llvm_backend.IsValidCPU(std::string(cpu));
754+
})
746755
.def("target.llvm_version_major", []() -> int { return TVM_LLVM_VERSION / 10; })
747756
.def("ffi.Module.load_from_file.ll",
748757
[](std::string filename, std::string fmt) -> ffi::Module {

0 commit comments

Comments
 (0)