diff --git a/python/tvm/target/codegen.py b/python/tvm/target/codegen.py index ace0e8f7184f..79d5f3dc0908 100644 --- a/python/tvm/target/codegen.py +++ b/python/tvm/target/codegen.py @@ -229,6 +229,24 @@ def llvm_get_vector_width(target=None): return _ffi_api.llvm_get_vector_width(target) +def llvm_is_valid_cpu(cpu, triple): + """Check if a CPU name is valid for the given LLVM triple. + + Parameters + ---------- + cpu : str + The CPU name to check (e.g. "apple-m1"). + triple : str + The LLVM target triple (e.g. "arm64-apple-macos"). + + Returns + ------- + is_valid : bool + True if the CPU name is recognized by LLVM for the given triple. + """ + return _ffi_api.llvm_is_valid_cpu(cpu, triple) + + def llvm_version_major(allow_none=False): """Get the major LLVM version. diff --git a/python/tvm/target/tag_registry/metal.py b/python/tvm/target/tag_registry/metal.py index 7e1d68b1f8c4..6727db7c3046 100644 --- a/python/tvm/target/tag_registry/metal.py +++ b/python/tvm/target/tag_registry/metal.py @@ -18,8 +18,17 @@ from .registry import register_tag +_METAL_HOST_TRIPLE = "arm64-apple-macos" + def _register_metal_tag(name, max_threads, shared_mem, warp_size, mcpu): + try: + from ..codegen import llvm_is_valid_cpu + + if not llvm_is_valid_cpu(mcpu, _METAL_HOST_TRIPLE): + return + except Exception: # pylint: disable=broad-except + pass # LLVM not available; register unconditionally register_tag( name, { @@ -29,7 +38,7 @@ def _register_metal_tag(name, max_threads, shared_mem, warp_size, mcpu): "thread_warp_size": warp_size, "host": { "kind": "llvm", - "mtriple": "arm64-apple-macos", + "mtriple": _METAL_HOST_TRIPLE, "mcpu": mcpu, }, }, diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc index 36cab28bc6c0..1f6a4e743b49 100644 --- a/src/target/llvm/llvm_module.cc +++ b/src/target/llvm/llvm_module.cc @@ -743,6 +743,15 @@ static void LLVMReflectionRegister() { LLVMTargetInfo llvm_target(*llvm_instance, use_target); return llvm_target.TargetHasCPUFeature(feature); }) + .def("target.llvm_is_valid_cpu", + [](ffi::String cpu, ffi::String triple) -> bool { + auto llvm_instance = std::make_unique(); + ffi::Map target_map; + target_map.Set("kind", ffi::String("llvm")); + target_map.Set("mtriple", triple); + LLVMTargetInfo llvm_backend(*llvm_instance, Target(target_map)); + return llvm_backend.IsValidCPU(std::string(cpu)); + }) .def("target.llvm_version_major", []() -> int { return TVM_LLVM_VERSION / 10; }) .def("ffi.Module.load_from_file.ll", [](std::string filename, std::string fmt) -> ffi::Module {