Skip to content
This repository was archived by the owner on Apr 6, 2026. It is now read-only.

Commit da3340e

Browse files
authored
Add basic support for building Metal 🤘 kernels (#146)
This change adds support for building Metal kernels (`backend = metal`) for Apple Silicon Macs.
1 parent 31c6e9b commit da3340e

28 files changed

Lines changed: 699 additions & 73 deletions

build2cmake/src/config/v1.rs

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -51,13 +51,15 @@ pub enum Language {
5151
#[default]
5252
Cuda,
5353
CudaHipify,
54+
Metal,
5455
}
5556

5657
impl Display for Language {
5758
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
5859
match self {
5960
Language::Cuda => f.write_str("cuda"),
6061
Language::CudaHipify => f.write_str("cuda-hipify"),
62+
Language::Metal => f.write_str("metal"),
6163
}
6264
}
6365
}

build2cmake/src/config/v2.rs

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -87,13 +87,15 @@ pub struct Kernel {
8787
#[serde(deny_unknown_fields, rename_all = "kebab-case")]
8888
pub enum Backend {
8989
Cuda,
90+
Metal,
9091
Rocm,
9192
}
9293

9394
impl Display for Backend {
9495
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
9596
match self {
9697
Backend::Cuda => write!(f, "cuda"),
98+
Backend::Metal => write!(f, "metal"),
9799
Backend::Rocm => write!(f, "rocm"),
98100
}
99101
}
@@ -105,6 +107,7 @@ impl FromStr for Backend {
105107
fn from_str(s: &str) -> Result<Self, Self::Err> {
106108
match s.to_lowercase().as_str() {
107109
"cuda" => Ok(Backend::Cuda),
110+
"metal" => Ok(Backend::Metal),
108111
"rocm" => Ok(Backend::Rocm),
109112
_ => Err(format!("Unknown backend: {}", s)),
110113
}

build2cmake/src/main.rs

Lines changed: 12 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -9,16 +9,13 @@ use eyre::{bail, ensure, Context, Result};
99
use minijinja::Environment;
1010

1111
mod torch;
12-
use torch::write_torch_ext;
13-
14-
mod torch_universal;
12+
use torch::{write_torch_ext, write_torch_ext_metal, write_torch_universal_ext};
1513

1614
mod config;
1715
use config::{Backend, Build, BuildCompat};
1816

1917
mod fileset;
2018
use fileset::FileSet;
21-
use torch_universal::write_torch_universal_ext;
2219

2320
#[derive(Parser, Debug)]
2421
#[command(version, about, long_about = None)]
@@ -108,16 +105,15 @@ fn generate_torch(
108105
env.set_trim_blocks(true);
109106
minijinja_embed::load_templates!(&mut env);
110107

111-
match (backend, build.general.universal) {
112-
(None, true) => write_torch_universal_ext(&env, &build, target_dir, force, ops_id)?,
108+
let backend = match (backend, build.general.universal) {
109+
(None, true) => return write_torch_universal_ext(&env, &build, target_dir, force, ops_id),
113110
(Some(backend), true) => bail!("Universal kernel, cannot generate for backend {}", backend),
114-
// TODO: add check if that type of backend has at least one kernel.
115111
(Some(backend), false) => {
116112
if !build.has_kernel_with_backend(&backend) {
117113
bail!("No kernels found for backend {}", backend);
118114
}
119115

120-
write_torch_ext(&env, &build, target_dir, force, ops_id)?
116+
backend
121117
}
122118
(None, false) => {
123119
let mut kernel_backends = build.backends();
@@ -139,15 +135,16 @@ fn generate_torch(
139135
);
140136
}
141137

142-
match backend {
143-
Backend::Cuda | Backend::Rocm => {
144-
write_torch_ext(&env, &build, target_dir, force, ops_id)?
145-
}
146-
}
138+
backend
147139
}
148-
}
140+
};
149141

150-
Ok(())
142+
match backend {
143+
Backend::Cuda | Backend::Rocm => {
144+
write_torch_ext(&env, backend, &build, target_dir, force, ops_id)
145+
}
146+
Backend::Metal => write_torch_ext_metal(&env, &build, target_dir, force, ops_id),
147+
}
151148
}
152149

153150
fn update_build(build_toml: PathBuf) -> Result<()> {
File renamed without changes.
File renamed without changes.

build2cmake/src/templates/preamble.cmake renamed to build2cmake/src/templates/cuda/preamble.cmake

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -60,12 +60,16 @@ if(GPU_LANG STREQUAL "CUDA")
6060
if(NVCC_THREADS AND GPU_LANG STREQUAL "CUDA")
6161
list(APPEND GPU_FLAGS "--threads=${NVCC_THREADS}")
6262
endif()
63+
64+
add_compile_definitions(CUDA_KERNEL)
6365
elseif(GPU_LANG STREQUAL "HIP")
6466
set(ROCM_ARCHS "${HIP_SUPPORTED_ARCHS}")
6567
# TODO: remove this once we can set specific archs per source file set.
6668
override_gpu_arches(GPU_ARCHES
6769
${GPU_LANG}
6870
"${${GPU_LANG}_SUPPORTED_ARCHS}")
71+
72+
add_compile_definitions(ROCM_KERNEL)
6973
else()
7074
override_gpu_arches(GPU_ARCHES
7175
${GPU_LANG}
File renamed without changes.
File renamed without changes.

0 commit comments

Comments
(0)