From 01bb2c713272dbf06c49710c4259af136c07cdf3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Wed, 28 Jan 2026 13:19:28 +0100 Subject: [PATCH 1/5] build2cmake: always generate kernel components for all backends This change always generates the kernel components for all backends. This is a necessary step for merging the preambles later. --- build2cmake/src/templates/cpu/kernel.cmake | 12 +- build2cmake/src/templates/cpu/preamble.cmake | 2 + build2cmake/src/templates/metal/kernel.cmake | 12 +- .../src/templates/metal/preamble.cmake | 2 + build2cmake/src/templates/xpu/kernel.cmake | 14 +- build2cmake/src/torch/common.rs | 12 ++ build2cmake/src/torch/cpu.rs | 54 +---- build2cmake/src/torch/cuda.rs | 94 +-------- build2cmake/src/torch/kernel.rs | 191 ++++++++++++++++++ build2cmake/src/torch/metal.rs | 54 +---- build2cmake/src/torch/mod.rs | 2 + build2cmake/src/torch/xpu.rs | 53 +---- 12 files changed, 258 insertions(+), 244 deletions(-) create mode 100644 build2cmake/src/torch/kernel.rs diff --git a/build2cmake/src/templates/cpu/kernel.cmake b/build2cmake/src/templates/cpu/kernel.cmake index 092dd58d..b17bcef7 100644 --- a/build2cmake/src/templates/cpu/kernel.cmake +++ b/build2cmake/src/templates/cpu/kernel.cmake @@ -1,5 +1,7 @@ -cpu_kernel_component(SRC - SOURCES {{ sources }} - {% if includes %}INCLUDES "{{ includes }}"{% endif %} - {% if cxx_flags %}CXX_FLAGS "{{ cxx_flags }}"{% endif %} -) +if(GPU_LANG STREQUAL "CPU") + cpu_kernel_component(SRC + SOURCES {{ sources }} + {% if includes %}INCLUDES "{{ includes }}"{% endif %} + {% if cxx_flags %}CXX_FLAGS "{{ cxx_flags }}"{% endif %} + ) +endif() diff --git a/build2cmake/src/templates/cpu/preamble.cmake b/build2cmake/src/templates/cpu/preamble.cmake index c9efb313..2b118987 100644 --- a/build2cmake/src/templates/cpu/preamble.cmake +++ b/build2cmake/src/templates/cpu/preamble.cmake @@ -42,6 +42,8 @@ if (TORCH_VERSION VERSION_GREATER {{ torch_maxver }}) endif() {% endif %} +set(GPU_LANG "CPU") + 
add_compile_definitions(CPU_KERNEL) # Initialize SRC list for kernel and binding sources diff --git a/build2cmake/src/templates/metal/kernel.cmake b/build2cmake/src/templates/metal/kernel.cmake index 328e4f9a..6c198311 100644 --- a/build2cmake/src/templates/metal/kernel.cmake +++ b/build2cmake/src/templates/metal/kernel.cmake @@ -1,5 +1,7 @@ -metal_kernel_component(SRC - SOURCES {{ sources }} - {% if includes %}INCLUDES "{{ includes }}"{% endif %} - {% if cxx_flags %}CXX_FLAGS "{{ cxx_flags }}"{% endif %} -) +if(GPU_LANG STREQUAL "METAL") + metal_kernel_component(SRC + SOURCES {{ sources }} + {% if includes %}INCLUDES "{{ includes }}"{% endif %} + {% if cxx_flags %}CXX_FLAGS "{{ cxx_flags }}"{% endif %} + ) +endif() diff --git a/build2cmake/src/templates/metal/preamble.cmake b/build2cmake/src/templates/metal/preamble.cmake index c24ba779..8f67ca0f 100644 --- a/build2cmake/src/templates/metal/preamble.cmake +++ b/build2cmake/src/templates/metal/preamble.cmake @@ -42,6 +42,8 @@ if (TORCH_VERSION VERSION_GREATER {{ torch_maxver }}) endif() {% endif %} +set(GPU_LANG "METAL") + add_compile_definitions(METAL_KERNEL) # Initialize SRC list for kernel and binding sources diff --git a/build2cmake/src/templates/xpu/kernel.cmake b/build2cmake/src/templates/xpu/kernel.cmake index 3886aacf..2538f898 100644 --- a/build2cmake/src/templates/xpu/kernel.cmake +++ b/build2cmake/src/templates/xpu/kernel.cmake @@ -1,6 +1,8 @@ -xpu_kernel_component(SRC - SOURCES {{ sources }} - {% if includes %}INCLUDES "{{ includes }}"{% endif %} - {% if cxx_flags %}CXX_FLAGS "{{ cxx_flags }}"{% endif %} - {% if sycl_flags %}SYCL_FLAGS "{{ sycl_flags }}"{% endif %} -) +if(GPU_LANG STREQUAL "SYCL") + xpu_kernel_component(SRC + SOURCES {{ sources }} + {% if includes %}INCLUDES "{{ includes }}"{% endif %} + {% if cxx_flags %}CXX_FLAGS "{{ cxx_flags }}"{% endif %} + {% if sycl_flags %}SYCL_FLAGS "{{ sycl_flags }}"{% endif %} + ) +endif() diff --git a/build2cmake/src/torch/common.rs 
b/build2cmake/src/torch/common.rs index 5c4c995d..c5eb162d 100644 --- a/build2cmake/src/torch/common.rs +++ b/build2cmake/src/torch/common.rs @@ -52,3 +52,15 @@ pub fn write_metadata(backend: Backend, general: &General, file_set: &mut FileSe Ok(()) } + +pub fn prefix_and_join_includes<S>(includes: impl AsRef<[S]>) -> String +where + S: AsRef<str>, +{ + includes + .as_ref() + .iter() + .map(|include| format!("${{CMAKE_SOURCE_DIR}}/{}", include.as_ref())) + .collect_vec() + .join(";") +} diff --git a/build2cmake/src/torch/cpu.rs b/build2cmake/src/torch/cpu.rs index da98ab66..b93fa592 100644 --- a/build2cmake/src/torch/cpu.rs +++ b/build2cmake/src/torch/cpu.rs @@ -4,13 +4,13 @@ use eyre::{bail, Context, Result}; use itertools::Itertools; use minijinja::{context, Environment}; -use super::{common::write_pyproject_toml, kernel_ops_identifier}; -use crate::{ - config::{Backend, Build, Kernel, Torch}, - fileset::FileSet, - torch::common::write_metadata, - version::Version, -}; +use crate::config::{Backend, Build, Torch}; +use crate::fileset::FileSet; +use crate::torch::common::write_metadata; +use crate::torch::common::write_pyproject_toml; +use crate::torch::kernel::render_kernel_components; +use crate::torch::kernel_ops_identifier; +use crate::version::Version; static CMAKE_UTILS: &str = include_str!("../templates/utils.cmake"); static CMAKE_KERNEL: &str = include_str!("../templates/kernel.cmake"); @@ -96,13 +96,7 @@ fn write_cmake( render_binding(env, torch, name, cmake_writer)?; - for (kernel_name, kernel) in build - .kernels - .iter() - .filter(|(_, kernel)| matches!(kernel, Kernel::Cpu { .. 
})) - { - render_kernel(env, kernel_name, kernel, cmake_writer)?; - } + render_kernel_components(env, build, cmake_writer)?; render_extension(env, name, ops_name, cmake_writer)?; @@ -154,38 +148,6 @@ pub fn render_extension( Ok(()) } -pub fn render_kernel( - env: &Environment, - kernel_name: &str, - kernel: &Kernel, - write: &mut impl Write, -) -> Result<()> { - // Easier to do in Rust than Jinja. - let sources = kernel - .src() - .iter() - .map(|src| format!("\"{src}\"")) - .collect_vec() - .join("\n"); - - env.get_template("cpu/kernel.cmake") - .wrap_err("Cannot get kernel template")? - .render_to_write( - context! { - cxx_flags => kernel.cxx_flags().map(|flags| flags.join(";")), - includes => kernel.include().map(prefix_and_join_includes), - kernel_name => kernel_name, - sources => sources, - }, - &mut *write, - ) - .wrap_err("Cannot render kernel template")?; - - write.write_all(b"\n")?; - - Ok(()) -} - fn render_preamble( env: &Environment, name: &str, diff --git a/build2cmake/src/torch/cuda.rs b/build2cmake/src/torch/cuda.rs index d433feb2..970f5bc3 100644 --- a/build2cmake/src/torch/cuda.rs +++ b/build2cmake/src/torch/cuda.rs @@ -4,13 +4,14 @@ use std::io::Write; use std::path::PathBuf; use eyre::{bail, Context, Result}; -use itertools::Itertools; use minijinja::{context, Environment}; -use super::common::write_pyproject_toml; -use super::kernel_ops_identifier; -use crate::config::{Backend, Build, Dependency, Kernel, Torch}; +use crate::config::{Backend, Build, Dependency, Torch}; +use crate::torch::common::prefix_and_join_includes; use crate::torch::common::write_metadata; +use crate::torch::common::write_pyproject_toml; +use crate::torch::kernel::render_kernel_components; +use crate::torch::kernel_ops_identifier; use crate::version::Version; use crate::FileSet; @@ -181,13 +182,7 @@ fn write_cmake( render_binding(env, torch, name, cmake_writer)?; - for (kernel_name, kernel) in build - .kernels - .iter() - .filter(|(_, kernel)| kernel.backend() == backend) - 
{ - render_kernel(env, kernel_name, kernel, cmake_writer)?; - } + render_kernel_components(env, build, cmake_writer)?; render_extension(env, name, ops_name, cmake_writer)?; @@ -312,71 +307,6 @@ fn render_deps( Ok(()) } -pub fn render_kernel( - env: &Environment, - kernel_name: &str, - kernel: &Kernel, - write: &mut impl Write, -) -> Result<()> { - // Easier to do in Rust than Jinja. - let sources = kernel - .src() - .iter() - .map(|src| format!("\"{src}\"")) - .collect_vec() - .join("\n"); - - let (cuda_capabilities, rocm_archs, cuda_flags, hip_flags, cuda_minver) = match kernel { - Kernel::Cuda { - cuda_capabilities, - cuda_flags, - cuda_minver, - .. - } => ( - cuda_capabilities.as_deref(), - None, - cuda_flags.as_deref(), - None, - cuda_minver.as_ref(), - ), - Kernel::Rocm { - rocm_archs, - hip_flags, - .. - } => ( - None, - rocm_archs.as_deref(), - None, - hip_flags.as_deref(), - None, - ), - _ => unreachable!("Unsupported kernel type for CUDA rendering"), - }; - - env.get_template("cuda/kernel.cmake") - .wrap_err("Cannot get kernel template")? - .render_to_write( - context! { - cuda_capabilities => cuda_capabilities, - cuda_flags => cuda_flags.map(|flags| flags.join(";")), - cuda_minver => cuda_minver.map(ToString::to_string), - cxx_flags => kernel.cxx_flags().map(|flags| flags.join(";")), - rocm_archs => rocm_archs, - hip_flags => hip_flags.map(|flags| flags.join(";")), - includes => kernel.include().map(prefix_and_join_includes), - kernel_name => kernel_name, - supports_hipify => matches!(kernel, Kernel::Rocm{ .. 
}), - sources => sources, - }, - &mut *write, - ) - .wrap_err("Cannot render kernel template")?; - - write.write_all(b"\n")?; - - Ok(()) -} - pub fn render_extension( env: &Environment, name: &str, @@ -428,15 +358,3 @@ pub fn render_preamble( Ok(()) } - -fn prefix_and_join_includes(includes: impl AsRef<[S]>) -> String -where - S: AsRef, -{ - includes - .as_ref() - .iter() - .map(|include| format!("${{CMAKE_SOURCE_DIR}}/{}", include.as_ref())) - .collect_vec() - .join(";") -} diff --git a/build2cmake/src/torch/kernel.rs b/build2cmake/src/torch/kernel.rs new file mode 100644 index 00000000..236fb199 --- /dev/null +++ b/build2cmake/src/torch/kernel.rs @@ -0,0 +1,191 @@ +use std::io::Write; + +use eyre::{Context, Result}; +use itertools::Itertools; +use minijinja::{context, Environment}; + +use crate::config::{Build, Kernel}; +use crate::torch::common::prefix_and_join_includes; + +pub fn render_kernel_components( + env: &Environment, + build: &Build, + write: &mut impl Write, +) -> Result<()> { + for (kernel_name, kernel) in build.kernels.iter() { + render_kernel_component(env, kernel_name, kernel, write)?; + } + + Ok(()) +} + +fn render_kernel_component( + env: &Environment, + kernel_name: &str, + kernel: &Kernel, + write: &mut impl Write, +) -> Result<()> { + // Easier to do in Rust than Jinja. + let sources = kernel + .src() + .iter() + .map(|src| format!("\"{src}\"")) + .collect_vec() + .join("\n"); + + match kernel { + Kernel::Cpu { .. } => { + render_kernel_component_cpu(env, kernel_name, kernel, sources, write)? + } + Kernel::Cuda { .. } | Kernel::Rocm { .. } => { + render_kernel_component_cuda(env, kernel_name, kernel, sources, write)? + } + Kernel::Metal { .. } => { + render_kernel_component_metal(env, kernel_name, kernel, sources, write)? + } + Kernel::Xpu { .. } => { + render_kernel_component_xpu(env, kernel_name, kernel, sources, write)? 
+ } + } + + Ok(()) +} + +pub fn render_kernel_component_cpu( + env: &Environment, + kernel_name: &str, + kernel: &Kernel, + sources: String, + write: &mut impl Write, +) -> Result<()> { + env.get_template("cpu/kernel.cmake") + .wrap_err("Cannot get kernel template")? + .render_to_write( + context! { + cxx_flags => kernel.cxx_flags().map(|flags| flags.join(";")), + includes => kernel.include().map(prefix_and_join_includes), + kernel_name => kernel_name, + sources => sources, + }, + &mut *write, + ) + .wrap_err("Cannot render kernel template")?; + + write.write_all(b"\n")?; + + Ok(()) +} + +fn render_kernel_component_cuda( + env: &Environment, + kernel_name: &str, + kernel: &Kernel, + sources: String, + write: &mut impl Write, +) -> Result<()> { + let (cuda_capabilities, rocm_archs, cuda_flags, hip_flags, cuda_minver) = match kernel { + Kernel::Cuda { + cuda_capabilities, + cuda_flags, + cuda_minver, + .. + } => ( + cuda_capabilities.as_deref(), + None, + cuda_flags.as_deref(), + None, + cuda_minver.as_ref(), + ), + Kernel::Rocm { + rocm_archs, + hip_flags, + .. + } => ( + None, + rocm_archs.as_deref(), + None, + hip_flags.as_deref(), + None, + ), + _ => unreachable!("Unsupported kernel type for CUDA rendering"), + }; + + env.get_template("cuda/kernel.cmake") + .wrap_err("Cannot get kernel template")? + .render_to_write( + context! { + cuda_capabilities => cuda_capabilities, + cuda_flags => cuda_flags.map(|flags| flags.join(";")), + cuda_minver => cuda_minver.map(ToString::to_string), + cxx_flags => kernel.cxx_flags().map(|flags| flags.join(";")), + rocm_archs => rocm_archs, + hip_flags => hip_flags.map(|flags| flags.join(";")), + includes => kernel.include().map(prefix_and_join_includes), + kernel_name => kernel_name, + supports_hipify => matches!(kernel, Kernel::Rocm{ .. 
}), + sources => sources, + }, + &mut *write, + ) + .wrap_err("Cannot render kernel template")?; + + write.write_all(b"\n")?; + + Ok(()) +} + +pub fn render_kernel_component_metal( + env: &Environment, + kernel_name: &str, + kernel: &Kernel, + sources: String, + write: &mut impl Write, +) -> Result<()> { + env.get_template("metal/kernel.cmake") + .wrap_err("Cannot get kernel template")? + .render_to_write( + context! { + cxx_flags => kernel.cxx_flags().map(|flags| flags.join(";")), + includes => kernel.include().map(prefix_and_join_includes), + kernel_name => kernel_name, + sources => sources, + }, + &mut *write, + ) + .wrap_err("Cannot render kernel template")?; + + write.write_all(b"\n")?; + + Ok(()) +} + +pub fn render_kernel_component_xpu( + env: &Environment, + kernel_name: &str, + kernel: &Kernel, + sources: String, + write: &mut impl Write, +) -> Result<()> { + let sycl_flags = match kernel { + Kernel::Xpu { sycl_flags, .. } => sycl_flags.as_deref(), + _ => unreachable!("Unsupported kernel type for XPU rendering"), + }; + + env.get_template("xpu/kernel.cmake") + .wrap_err("Cannot get kernel template")? + .render_to_write( + context! 
{ + cxx_flags => kernel.cxx_flags().map(|flags| flags.join(";")), + sycl_flags => sycl_flags.map(|flags| flags.join(";")), + includes => kernel.include().map(prefix_and_join_includes), + kernel_name => kernel_name, + sources => sources, + }, + &mut *write, + ) + .wrap_err("Cannot render kernel template")?; + + write.write_all(b"\n")?; + + Ok(()) +} diff --git a/build2cmake/src/torch/metal.rs b/build2cmake/src/torch/metal.rs index 6891fb1f..93089fe9 100644 --- a/build2cmake/src/torch/metal.rs +++ b/build2cmake/src/torch/metal.rs @@ -4,13 +4,13 @@ use eyre::{bail, Context, Result}; use itertools::Itertools; use minijinja::{context, Environment}; -use super::{common::write_pyproject_toml, kernel_ops_identifier}; -use crate::{ - config::{Backend, Build, Kernel, Torch}, - fileset::FileSet, - torch::common::write_metadata, - version::Version, -}; +use crate::config::{Backend, Build, Torch}; +use crate::fileset::FileSet; +use crate::torch::common::write_metadata; +use crate::torch::common::write_pyproject_toml; +use crate::torch::kernel::render_kernel_components; +use crate::torch::kernel_ops_identifier; +use crate::version::Version; static CMAKE_UTILS: &str = include_str!("../templates/utils.cmake"); static CMAKE_KERNEL: &str = include_str!("../templates/kernel.cmake"); @@ -112,13 +112,7 @@ fn write_cmake( render_binding(env, torch, name, cmake_writer)?; - for (kernel_name, kernel) in build - .kernels - .iter() - .filter(|(_, kernel)| matches!(kernel, Kernel::Metal { .. })) - { - render_kernel(env, kernel_name, kernel, cmake_writer)?; - } + render_kernel_components(env, build, cmake_writer)?; render_extension(env, name, ops_name, cmake_writer)?; @@ -170,38 +164,6 @@ pub fn render_extension( Ok(()) } -pub fn render_kernel( - env: &Environment, - kernel_name: &str, - kernel: &Kernel, - write: &mut impl Write, -) -> Result<()> { - // Easier to do in Rust than Jinja. 
- let sources = kernel - .src() - .iter() - .map(|src| format!("\"{src}\"")) - .collect_vec() - .join("\n"); - - env.get_template("metal/kernel.cmake") - .wrap_err("Cannot get kernel template")? - .render_to_write( - context! { - cxx_flags => kernel.cxx_flags().map(|flags| flags.join(";")), - includes => kernel.include().map(prefix_and_join_includes), - kernel_name => kernel_name, - sources => sources, - }, - &mut *write, - ) - .wrap_err("Cannot render kernel template")?; - - write.write_all(b"\n")?; - - Ok(()) -} - fn render_preamble( env: &Environment, name: &str, diff --git a/build2cmake/src/torch/mod.rs b/build2cmake/src/torch/mod.rs index f9cbfeaf..03a737f9 100644 --- a/build2cmake/src/torch/mod.rs +++ b/build2cmake/src/torch/mod.rs @@ -6,6 +6,8 @@ pub use cuda::write_torch_ext_cuda; pub mod common; +pub mod kernel; + mod metal; pub use metal::write_torch_ext_metal; diff --git a/build2cmake/src/torch/xpu.rs b/build2cmake/src/torch/xpu.rs index 80dbc747..03738a68 100644 --- a/build2cmake/src/torch/xpu.rs +++ b/build2cmake/src/torch/xpu.rs @@ -6,10 +6,11 @@ use eyre::{bail, Context, Result}; use itertools::Itertools; use minijinja::{context, Environment}; -use super::common::write_pyproject_toml; -use super::kernel_ops_identifier; -use crate::config::{Backend, Build, Dependency, Kernel, Torch}; +use crate::config::{Backend, Build, Dependency, Torch}; use crate::torch::common::write_metadata; +use crate::torch::common::write_pyproject_toml; +use crate::torch::kernel::render_kernel_components; +use crate::torch::kernel_ops_identifier; use crate::version::Version; use crate::FileSet; @@ -168,13 +169,7 @@ fn write_cmake( render_binding(env, torch, name, cmake_writer)?; - for (kernel_name, kernel) in build - .kernels - .iter() - .filter(|(_, kernel)| matches!(kernel, Kernel::Xpu { .. 
})) - { - render_kernel(env, kernel_name, kernel, cmake_writer)?; - } + render_kernel_components(env, build, cmake_writer)?; render_extension(env, name, ops_name, cmake_writer)?; @@ -238,44 +233,6 @@ fn render_deps( Ok(()) } -pub fn render_kernel( - env: &Environment, - kernel_name: &str, - kernel: &Kernel, - write: &mut impl Write, -) -> Result<()> { - // Easier to do in Rust than Jinja. - let sources = kernel - .src() - .iter() - .map(|src| format!("\"{src}\"")) - .collect_vec() - .join("\n"); - - let sycl_flags = match kernel { - Kernel::Xpu { sycl_flags, .. } => sycl_flags.as_deref(), - _ => unreachable!("Unsupported kernel type for XPU rendering"), - }; - - env.get_template("xpu/kernel.cmake") - .wrap_err("Cannot get kernel template")? - .render_to_write( - context! { - cxx_flags => kernel.cxx_flags().map(|flags| flags.join(";")), - sycl_flags => sycl_flags.map(|flags| flags.join(";")), - includes => kernel.include().map(prefix_and_join_includes), - kernel_name => kernel_name, - sources => sources, - }, - &mut *write, - ) - .wrap_err("Cannot render kernel template")?; - - write.write_all(b"\n")?; - - Ok(()) -} - pub fn render_extension( env: &Environment, name: &str, From 2609a3f608034a43bc57fd11879d6d4fdb16ba16 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Wed, 28 Jan 2026 13:10:04 +0000 Subject: [PATCH 2/5] Factor out `write_torch_registration_macros` to `common` mod --- build2cmake/src/torch/common.rs | 15 +++++++++++++++ build2cmake/src/torch/cpu.rs | 13 +------------ build2cmake/src/torch/cuda.rs | 13 +------------ build2cmake/src/torch/metal.rs | 13 +------------ build2cmake/src/torch/xpu.rs | 13 +------------ 5 files changed, 19 insertions(+), 48 deletions(-) diff --git a/build2cmake/src/torch/common.rs b/build2cmake/src/torch/common.rs index c5eb162d..5d12b6fe 100644 --- a/build2cmake/src/torch/common.rs +++ b/build2cmake/src/torch/common.rs @@ -1,3 +1,5 @@ +use std::path::PathBuf; + use eyre::{Context, Result}; use 
itertools::Itertools; use minijinja::{context, Environment}; @@ -6,6 +8,8 @@ use crate::config::{Backend, General}; use crate::metadata::Metadata; use crate::FileSet; +static REGISTRATION_H: &str = include_str!("../templates/registration.h"); + pub fn write_pyproject_toml( env: &Environment, backend: Backend, @@ -64,3 +68,14 @@ where .collect_vec() .join(";") } + +pub fn write_torch_registration_macros(file_set: &mut FileSet) -> Result<()> { + let mut path = PathBuf::new(); + path.push("torch-ext"); + path.push("registration.h"); + file_set + .entry(path) + .extend_from_slice(REGISTRATION_H.as_bytes()); + + Ok(()) +} diff --git a/build2cmake/src/torch/cpu.rs b/build2cmake/src/torch/cpu.rs index b93fa592..b7499bc0 100644 --- a/build2cmake/src/torch/cpu.rs +++ b/build2cmake/src/torch/cpu.rs @@ -8,13 +8,13 @@ use crate::config::{Backend, Build, Torch}; use crate::fileset::FileSet; use crate::torch::common::write_metadata; use crate::torch::common::write_pyproject_toml; +use crate::torch::common::write_torch_registration_macros; use crate::torch::kernel::render_kernel_components; use crate::torch::kernel_ops_identifier; use crate::version::Version; static CMAKE_UTILS: &str = include_str!("../templates/utils.cmake"); static CMAKE_KERNEL: &str = include_str!("../templates/kernel.cmake"); -static REGISTRATION_H: &str = include_str!("../templates/registration.h"); pub fn write_torch_ext_cpu( env: &Environment, @@ -224,17 +224,6 @@ fn write_setup_py( Ok(()) } -fn write_torch_registration_macros(file_set: &mut FileSet) -> Result<()> { - let mut path = PathBuf::new(); - path.push("torch-ext"); - path.push("registration.h"); - file_set - .entry(path) - .extend_from_slice(REGISTRATION_H.as_bytes()); - - Ok(()) -} - fn prefix_and_join_includes(includes: impl AsRef<[S]>) -> String where S: AsRef, diff --git a/build2cmake/src/torch/cuda.rs b/build2cmake/src/torch/cuda.rs index 970f5bc3..ab849fb5 100644 --- a/build2cmake/src/torch/cuda.rs +++ b/build2cmake/src/torch/cuda.rs @@ 
-10,6 +10,7 @@ use crate::config::{Backend, Build, Dependency, Torch}; use crate::torch::common::prefix_and_join_includes; use crate::torch::common::write_metadata; use crate::torch::common::write_pyproject_toml; +use crate::torch::common::write_torch_registration_macros; use crate::torch::kernel::render_kernel_components; use crate::torch::kernel_ops_identifier; use crate::version::Version; @@ -18,7 +19,6 @@ use crate::FileSet; static CMAKE_UTILS: &str = include_str!("../templates/utils.cmake"); static CMAKE_KERNEL: &str = include_str!("../templates/kernel.cmake"); static WINDOWS_UTILS: &str = include_str!("../templates/windows.cmake"); -static REGISTRATION_H: &str = include_str!("../templates/registration.h"); static HIPIFY: &str = include_str!("../templates/cuda/hipify.py"); pub fn write_torch_ext_cuda( @@ -66,17 +66,6 @@ pub fn write_torch_ext_cuda( Ok(file_set) } -fn write_torch_registration_macros(file_set: &mut FileSet) -> Result<()> { - let mut path = PathBuf::new(); - path.push("torch-ext"); - path.push("registration.h"); - file_set - .entry(path) - .extend_from_slice(REGISTRATION_H.as_bytes()); - - Ok(()) -} - fn write_setup_py( env: &Environment, torch: &Torch, diff --git a/build2cmake/src/torch/metal.rs b/build2cmake/src/torch/metal.rs index 93089fe9..b430e46d 100644 --- a/build2cmake/src/torch/metal.rs +++ b/build2cmake/src/torch/metal.rs @@ -8,13 +8,13 @@ use crate::config::{Backend, Build, Torch}; use crate::fileset::FileSet; use crate::torch::common::write_metadata; use crate::torch::common::write_pyproject_toml; +use crate::torch::common::write_torch_registration_macros; use crate::torch::kernel::render_kernel_components; use crate::torch::kernel_ops_identifier; use crate::version::Version; static CMAKE_UTILS: &str = include_str!("../templates/utils.cmake"); static CMAKE_KERNEL: &str = include_str!("../templates/kernel.cmake"); -static REGISTRATION_H: &str = include_str!("../templates/registration.h"); static COMPILE_METAL_CMAKE: &str = 
include_str!("../templates/metal/compile-metal.cmake"); static METALLIB_TO_HEADER_PY: &str = include_str!("../templates/metal/metallib_to_header.py"); @@ -240,17 +240,6 @@ fn write_setup_py( Ok(()) } -fn write_torch_registration_macros(file_set: &mut FileSet) -> Result<()> { - let mut path = PathBuf::new(); - path.push("torch-ext"); - path.push("registration.h"); - file_set - .entry(path) - .extend_from_slice(REGISTRATION_H.as_bytes()); - - Ok(()) -} - fn prefix_and_join_includes(includes: impl AsRef<[S]>) -> String where S: AsRef, diff --git a/build2cmake/src/torch/xpu.rs b/build2cmake/src/torch/xpu.rs index 03738a68..5911edc7 100644 --- a/build2cmake/src/torch/xpu.rs +++ b/build2cmake/src/torch/xpu.rs @@ -9,6 +9,7 @@ use minijinja::{context, Environment}; use crate::config::{Backend, Build, Dependency, Torch}; use crate::torch::common::write_metadata; use crate::torch::common::write_pyproject_toml; +use crate::torch::common::write_torch_registration_macros; use crate::torch::kernel::render_kernel_components; use crate::torch::kernel_ops_identifier; use crate::version::Version; @@ -16,7 +17,6 @@ use crate::FileSet; static CMAKE_UTILS: &str = include_str!("../templates/utils.cmake"); static CMAKE_KERNEL: &str = include_str!("../templates/kernel.cmake"); -static REGISTRATION_H: &str = include_str!("../templates/registration.h"); static WINDOWS_UTILS: &str = include_str!("../templates/windows.cmake"); pub fn write_torch_ext_xpu( @@ -62,17 +62,6 @@ pub fn write_torch_ext_xpu( Ok(file_set) } -fn write_torch_registration_macros(file_set: &mut FileSet) -> Result<()> { - let mut path = PathBuf::new(); - path.push("torch-ext"); - path.push("registration.h"); - file_set - .entry(path) - .extend_from_slice(REGISTRATION_H.as_bytes()); - - Ok(()) -} - fn write_setup_py( env: &Environment, torch: &Torch, From 0c9b25de36f28ffd3a4299e8cb7e4ed6c13ea0a3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Wed, 28 Jan 2026 13:17:55 +0000 Subject: [PATCH 3/5] Move 
`write_ops_py` to `common` mod --- build2cmake/src/torch/common.rs | 25 +++++++++++++++++++++++++ build2cmake/src/torch/cpu.rs | 26 +------------------------- build2cmake/src/torch/cuda.rs | 26 +------------------------- build2cmake/src/torch/metal.rs | 26 +------------------------- build2cmake/src/torch/xpu.rs | 26 +------------------------- 5 files changed, 29 insertions(+), 100 deletions(-) diff --git a/build2cmake/src/torch/common.rs b/build2cmake/src/torch/common.rs index 5d12b6fe..c75811ac 100644 --- a/build2cmake/src/torch/common.rs +++ b/build2cmake/src/torch/common.rs @@ -79,3 +79,28 @@ pub fn write_torch_registration_macros(file_set: &mut FileSet) -> Result<()> { Ok(()) } + +pub fn write_ops_py( + env: &Environment, + name: &str, + ops_name: &str, + file_set: &mut FileSet, +) -> Result<()> { + let mut path = PathBuf::new(); + path.push("torch-ext"); + path.push(name); + path.push("_ops.py"); + let writer = file_set.entry(path); + + env.get_template("_ops.py") + .wrap_err("Cannot get _ops.py template")? + .render_to_write( + context! 
{ + ops_name => ops_name, + }, + writer, + ) + .wrap_err("Cannot render _ops.py template")?; + + Ok(()) +} diff --git a/build2cmake/src/torch/cpu.rs b/build2cmake/src/torch/cpu.rs index b7499bc0..0d7ea5f4 100644 --- a/build2cmake/src/torch/cpu.rs +++ b/build2cmake/src/torch/cpu.rs @@ -7,6 +7,7 @@ use minijinja::{context, Environment}; use crate::config::{Backend, Build, Torch}; use crate::fileset::FileSet; use crate::torch::common::write_metadata; +use crate::torch::common::write_ops_py; use crate::torch::common::write_pyproject_toml; use crate::torch::common::write_torch_registration_macros; use crate::torch::kernel::render_kernel_components; @@ -172,31 +173,6 @@ fn render_preamble( Ok(()) } -fn write_ops_py( - env: &Environment, - name: &str, - ops_name: &str, - file_set: &mut FileSet, -) -> Result<()> { - let mut path = PathBuf::new(); - path.push("torch-ext"); - path.push(name); - path.push("_ops.py"); - let writer = file_set.entry(path); - - env.get_template("_ops.py") - .wrap_err("Cannot get _ops.py template")? - .render_to_write( - context! 
{ - ops_name => ops_name, - }, - writer, - ) - .wrap_err("Cannot render kernel template")?; - - Ok(()) -} - fn write_setup_py( env: &Environment, torch: &Torch, diff --git a/build2cmake/src/torch/cuda.rs b/build2cmake/src/torch/cuda.rs index ab849fb5..fcb3db30 100644 --- a/build2cmake/src/torch/cuda.rs +++ b/build2cmake/src/torch/cuda.rs @@ -9,6 +9,7 @@ use minijinja::{context, Environment}; use crate::config::{Backend, Build, Dependency, Torch}; use crate::torch::common::prefix_and_join_includes; use crate::torch::common::write_metadata; +use crate::torch::common::write_ops_py; use crate::torch::common::write_pyproject_toml; use crate::torch::common::write_torch_registration_macros; use crate::torch::kernel::render_kernel_components; @@ -93,31 +94,6 @@ fn write_setup_py( Ok(()) } -fn write_ops_py( - env: &Environment, - name: &str, - ops_name: &str, - file_set: &mut FileSet, -) -> Result<()> { - let mut path = PathBuf::new(); - path.push("torch-ext"); - path.push(name); - path.push("_ops.py"); - let writer = file_set.entry(path); - - env.get_template("_ops.py") - .wrap_err("Cannot get _ops.py template")? - .render_to_write( - context! 
{ - ops_name => ops_name, - }, - writer, - ) - .wrap_err("Cannot render kernel template")?; - - Ok(()) -} - fn write_cmake( env: &Environment, backend: Backend, diff --git a/build2cmake/src/torch/metal.rs b/build2cmake/src/torch/metal.rs index b430e46d..f25424d1 100644 --- a/build2cmake/src/torch/metal.rs +++ b/build2cmake/src/torch/metal.rs @@ -7,6 +7,7 @@ use minijinja::{context, Environment}; use crate::config::{Backend, Build, Torch}; use crate::fileset::FileSet; use crate::torch::common::write_metadata; +use crate::torch::common::write_ops_py; use crate::torch::common::write_pyproject_toml; use crate::torch::common::write_torch_registration_macros; use crate::torch::kernel::render_kernel_components; @@ -188,31 +189,6 @@ fn render_preamble( Ok(()) } -fn write_ops_py( - env: &Environment, - name: &str, - ops_name: &str, - file_set: &mut FileSet, -) -> Result<()> { - let mut path = PathBuf::new(); - path.push("torch-ext"); - path.push(name); - path.push("_ops.py"); - let writer = file_set.entry(path); - - env.get_template("_ops.py") - .wrap_err("Cannot get _ops.py template")? - .render_to_write( - context! 
{ - ops_name => ops_name, - }, - writer, - ) - .wrap_err("Cannot render kernel template")?; - - Ok(()) -} - fn write_setup_py( env: &Environment, torch: &Torch, diff --git a/build2cmake/src/torch/xpu.rs b/build2cmake/src/torch/xpu.rs index 5911edc7..998fa0ad 100644 --- a/build2cmake/src/torch/xpu.rs +++ b/build2cmake/src/torch/xpu.rs @@ -8,6 +8,7 @@ use minijinja::{context, Environment}; use crate::config::{Backend, Build, Dependency, Torch}; use crate::torch::common::write_metadata; +use crate::torch::common::write_ops_py; use crate::torch::common::write_pyproject_toml; use crate::torch::common::write_torch_registration_macros; use crate::torch::kernel::render_kernel_components; @@ -89,31 +90,6 @@ fn write_setup_py( Ok(()) } -fn write_ops_py( - env: &Environment, - name: &str, - ops_name: &str, - file_set: &mut FileSet, -) -> Result<()> { - let mut path = PathBuf::new(); - path.push("torch-ext"); - path.push(name); - path.push("_ops.py"); - let writer = file_set.entry(path); - - env.get_template("_ops.py") - .wrap_err("Cannot get _ops.py template")? - .render_to_write( - context! 
{ - ops_name => ops_name, - }, - writer, - ) - .wrap_err("Cannot render _ops.py template")?; - - Ok(()) -} - fn write_cmake( env: &Environment, build: &Build, From 5f4d27b3ecf98c820a779dff1f683d7641d2d7c3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Wed, 28 Jan 2026 14:04:19 +0000 Subject: [PATCH 4/5] Factor out CMake helper file generation --- build2cmake/src/torch/common.rs | 33 ++++++++++++++++++++++++++++++++ build2cmake/src/torch/cpu.rs | 18 ++--------------- build2cmake/src/torch/cuda.rs | 34 ++------------------------------- build2cmake/src/torch/metal.rs | 34 ++------------------------------- build2cmake/src/torch/xpu.rs | 27 ++------------------------ 5 files changed, 41 insertions(+), 105 deletions(-) diff --git a/build2cmake/src/torch/common.rs b/build2cmake/src/torch/common.rs index c75811ac..53f8f405 100644 --- a/build2cmake/src/torch/common.rs +++ b/build2cmake/src/torch/common.rs @@ -9,6 +9,12 @@ use crate::metadata::Metadata; use crate::FileSet; static REGISTRATION_H: &str = include_str!("../templates/registration.h"); +static CMAKE_UTILS: &str = include_str!("../templates/utils.cmake"); +static CMAKE_KERNEL: &str = include_str!("../templates/kernel.cmake"); +static WINDOWS_UTILS: &str = include_str!("../templates/windows.cmake"); +static HIPIFY: &str = include_str!("../templates/cuda/hipify.py"); +static COMPILE_METAL_CMAKE: &str = include_str!("../templates/metal/compile-metal.cmake"); +static METALLIB_TO_HEADER_PY: &str = include_str!("../templates/metal/metallib_to_header.py"); pub fn write_pyproject_toml( env: &Environment, @@ -104,3 +110,30 @@ pub fn write_ops_py( Ok(()) } + +/// Helper function to write a file to the cmake subdirectory +pub fn write_cmake_file(file_set: &mut FileSet, filename: &str, content: &[u8]) { + let mut path = PathBuf::new(); + path.push("cmake"); + path.push(filename); + file_set.entry(path).extend_from_slice(content); +} + +/// Writes all CMake helper files that any backend might need. 
+/// Each backend will use only the files it references in its CMakeLists.txt. +pub fn write_cmake_helpers(file_set: &mut FileSet) { + write_cmake_file(file_set, "utils.cmake", CMAKE_UTILS.as_bytes()); + write_cmake_file(file_set, "kernel.cmake", CMAKE_KERNEL.as_bytes()); + write_cmake_file(file_set, "windows.cmake", WINDOWS_UTILS.as_bytes()); + write_cmake_file(file_set, "hipify.py", HIPIFY.as_bytes()); + write_cmake_file( + file_set, + "compile-metal.cmake", + COMPILE_METAL_CMAKE.as_bytes(), + ); + write_cmake_file( + file_set, + "metallib_to_header.py", + METALLIB_TO_HEADER_PY.as_bytes(), + ); +} diff --git a/build2cmake/src/torch/cpu.rs b/build2cmake/src/torch/cpu.rs index 0d7ea5f4..8b402969 100644 --- a/build2cmake/src/torch/cpu.rs +++ b/build2cmake/src/torch/cpu.rs @@ -6,6 +6,7 @@ use minijinja::{context, Environment}; use crate::config::{Backend, Build, Torch}; use crate::fileset::FileSet; +use crate::torch::common::write_cmake_helpers; use crate::torch::common::write_metadata; use crate::torch::common::write_ops_py; use crate::torch::common::write_pyproject_toml; @@ -14,9 +15,6 @@ use crate::torch::kernel::render_kernel_components; use crate::torch::kernel_ops_identifier; use crate::version::Version; -static CMAKE_UTILS: &str = include_str!("../templates/utils.cmake"); -static CMAKE_KERNEL: &str = include_str!("../templates/kernel.cmake"); - pub fn write_torch_ext_cpu( env: &Environment, build: &Build, @@ -68,19 +66,7 @@ fn write_cmake( ops_name: &str, file_set: &mut FileSet, ) -> Result<()> { - let mut utils_path = PathBuf::new(); - utils_path.push("cmake"); - utils_path.push("utils.cmake"); - file_set - .entry(utils_path.clone()) - .extend_from_slice(CMAKE_UTILS.as_bytes()); - - let mut kernel_path = PathBuf::new(); - kernel_path.push("cmake"); - kernel_path.push("kernel.cmake"); - file_set - .entry(kernel_path.clone()) - .extend_from_slice(CMAKE_KERNEL.as_bytes()); + write_cmake_helpers(file_set); let cmake_writer = file_set.entry("CMakeLists.txt"); diff 
--git a/build2cmake/src/torch/cuda.rs b/build2cmake/src/torch/cuda.rs index fcb3db30..e94cfcf6 100644 --- a/build2cmake/src/torch/cuda.rs +++ b/build2cmake/src/torch/cuda.rs @@ -8,6 +8,7 @@ use minijinja::{context, Environment}; use crate::config::{Backend, Build, Dependency, Torch}; use crate::torch::common::prefix_and_join_includes; +use crate::torch::common::write_cmake_helpers; use crate::torch::common::write_metadata; use crate::torch::common::write_ops_py; use crate::torch::common::write_pyproject_toml; @@ -17,11 +18,6 @@ use crate::torch::kernel_ops_identifier; use crate::version::Version; use crate::FileSet; -static CMAKE_UTILS: &str = include_str!("../templates/utils.cmake"); -static CMAKE_KERNEL: &str = include_str!("../templates/kernel.cmake"); -static WINDOWS_UTILS: &str = include_str!("../templates/windows.cmake"); -static HIPIFY: &str = include_str!("../templates/cuda/hipify.py"); - pub fn write_torch_ext_cuda( env: &Environment, backend: Backend, @@ -103,33 +99,7 @@ fn write_cmake( ops_name: &str, file_set: &mut FileSet, ) -> Result<()> { - let mut utils_path = PathBuf::new(); - utils_path.push("cmake"); - utils_path.push("utils.cmake"); - file_set - .entry(utils_path.clone()) - .extend_from_slice(CMAKE_UTILS.as_bytes()); - - let mut kernel_path = PathBuf::new(); - kernel_path.push("cmake"); - kernel_path.push("kernel.cmake"); - file_set - .entry(kernel_path.clone()) - .extend_from_slice(CMAKE_KERNEL.as_bytes()); - - let mut windows_utils_path = PathBuf::new(); - windows_utils_path.push("cmake"); - windows_utils_path.push("windows.cmake"); - file_set - .entry(windows_utils_path.clone()) - .extend_from_slice(WINDOWS_UTILS.as_bytes()); - - let mut hipify_path = PathBuf::new(); - hipify_path.push("cmake"); - hipify_path.push("hipify.py"); - file_set - .entry(hipify_path.clone()) - .extend_from_slice(HIPIFY.as_bytes()); + write_cmake_helpers(file_set); let cmake_writer = file_set.entry("CMakeLists.txt"); diff --git a/build2cmake/src/torch/metal.rs 
b/build2cmake/src/torch/metal.rs index f25424d1..bbafc33c 100644 --- a/build2cmake/src/torch/metal.rs +++ b/build2cmake/src/torch/metal.rs @@ -6,6 +6,7 @@ use minijinja::{context, Environment}; use crate::config::{Backend, Build, Torch}; use crate::fileset::FileSet; +use crate::torch::common::write_cmake_helpers; use crate::torch::common::write_metadata; use crate::torch::common::write_ops_py; use crate::torch::common::write_pyproject_toml; @@ -14,11 +15,6 @@ use crate::torch::kernel::render_kernel_components; use crate::torch::kernel_ops_identifier; use crate::version::Version; -static CMAKE_UTILS: &str = include_str!("../templates/utils.cmake"); -static CMAKE_KERNEL: &str = include_str!("../templates/kernel.cmake"); -static COMPILE_METAL_CMAKE: &str = include_str!("../templates/metal/compile-metal.cmake"); -static METALLIB_TO_HEADER_PY: &str = include_str!("../templates/metal/metallib_to_header.py"); - pub fn write_torch_ext_metal( env: &Environment, build: &Build, @@ -70,33 +66,7 @@ fn write_cmake( ops_name: &str, file_set: &mut FileSet, ) -> Result<()> { - let mut utils_path = PathBuf::new(); - utils_path.push("cmake"); - utils_path.push("utils.cmake"); - file_set - .entry(utils_path.clone()) - .extend_from_slice(CMAKE_UTILS.as_bytes()); - - let mut kernel_path = PathBuf::new(); - kernel_path.push("cmake"); - kernel_path.push("kernel.cmake"); - file_set - .entry(kernel_path.clone()) - .extend_from_slice(CMAKE_KERNEL.as_bytes()); - - let mut compile_metal_path = PathBuf::new(); - compile_metal_path.push("cmake"); - compile_metal_path.push("compile-metal.cmake"); - file_set - .entry(compile_metal_path) - .extend_from_slice(COMPILE_METAL_CMAKE.as_bytes()); - - let mut metallib_to_header_path = PathBuf::new(); - metallib_to_header_path.push("cmake"); - metallib_to_header_path.push("metallib_to_header.py"); - file_set - .entry(metallib_to_header_path) - .extend_from_slice(METALLIB_TO_HEADER_PY.as_bytes()); + write_cmake_helpers(file_set); let cmake_writer = 
file_set.entry("CMakeLists.txt"); diff --git a/build2cmake/src/torch/xpu.rs b/build2cmake/src/torch/xpu.rs index 998fa0ad..039f3ec7 100644 --- a/build2cmake/src/torch/xpu.rs +++ b/build2cmake/src/torch/xpu.rs @@ -7,6 +7,7 @@ use itertools::Itertools; use minijinja::{context, Environment}; use crate::config::{Backend, Build, Dependency, Torch}; +use crate::torch::common::write_cmake_helpers; use crate::torch::common::write_metadata; use crate::torch::common::write_ops_py; use crate::torch::common::write_pyproject_toml; @@ -16,10 +17,6 @@ use crate::torch::kernel_ops_identifier; use crate::version::Version; use crate::FileSet; -static CMAKE_UTILS: &str = include_str!("../templates/utils.cmake"); -static CMAKE_KERNEL: &str = include_str!("../templates/kernel.cmake"); -static WINDOWS_UTILS: &str = include_str!("../templates/windows.cmake"); - pub fn write_torch_ext_xpu( env: &Environment, build: &Build, @@ -98,27 +95,7 @@ fn write_cmake( ops_name: &str, file_set: &mut FileSet, ) -> Result<()> { - let mut utils_path = PathBuf::new(); - utils_path.push("cmake"); - utils_path.push("utils.cmake"); - file_set - .entry(utils_path.clone()) - .extend_from_slice(CMAKE_UTILS.as_bytes()); - - // Add windows.cmake for Windows-specific install targets - let mut windows_utils_path = PathBuf::new(); - windows_utils_path.push("cmake"); - windows_utils_path.push("windows.cmake"); - file_set - .entry(windows_utils_path.clone()) - .extend_from_slice(WINDOWS_UTILS.as_bytes()); - - let mut kernel_path = PathBuf::new(); - kernel_path.push("cmake"); - kernel_path.push("kernel.cmake"); - file_set - .entry(kernel_path.clone()) - .extend_from_slice(CMAKE_KERNEL.as_bytes()); + write_cmake_helpers(file_set); let cmake_writer = file_set.entry("CMakeLists.txt"); From e4f46f1e9081822b1af9a5e631b3f9a017bcfc0c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Thu, 29 Jan 2026 16:21:57 +0000 Subject: [PATCH 5/5] Factor out `render_deps` function --- build2cmake/src/torch/cuda.rs | 
99 +------------------------------ build2cmake/src/torch/deps.rs | 106 ++++++++++++++++++++++++++++++++++ build2cmake/src/torch/mod.rs | 2 + build2cmake/src/torch/xpu.rs | 38 +----------- 4 files changed, 112 insertions(+), 133 deletions(-) create mode 100644 build2cmake/src/torch/deps.rs diff --git a/build2cmake/src/torch/cuda.rs b/build2cmake/src/torch/cuda.rs index da319a6c..54093da1 100644 --- a/build2cmake/src/torch/cuda.rs +++ b/build2cmake/src/torch/cuda.rs @@ -1,4 +1,3 @@ -use std::collections::HashSet; use std::env; use std::io::Write; use std::path::PathBuf; @@ -6,11 +5,12 @@ use std::path::PathBuf; use eyre::{bail, Context, Result}; use minijinja::{context, Environment}; -use crate::config::{Backend, Build, Dependency, Torch}; +use crate::config::{Backend, Build, Torch}; use crate::torch::common::{ render_binding, render_extension, write_cmake_helpers, write_metadata, write_ops_py, write_pyproject_toml, write_setup_py, write_torch_registration_macros, }; +use crate::torch::deps::render_deps; use crate::torch::kernel::render_kernel_components; use crate::torch::kernel_ops_identifier; use crate::version::Version; @@ -95,101 +95,6 @@ fn write_cmake( Ok(()) } -fn render_deps( - env: &Environment, - backend: Backend, - build: &Build, - write: &mut impl Write, -) -> Result<()> { - let mut deps = HashSet::new(); - - for kernel in build - .kernels - .values() - .filter(|kernel| kernel.backend() == backend) - { - deps.extend(kernel.depends()); - } - - for dep in deps { - match dep { - Dependency::Cutlass2_10 => { - env.get_template("cuda/dep-cutlass.cmake") - .wrap_err("Cannot get CUTLASS dependency template")? - .render_to_write( - context! { - version => "2.10.0", - }, - &mut *write, - ) - .wrap_err("Cannot render CUTLASS dependency template")?; - } - Dependency::Cutlass3_5 => { - env.get_template("cuda/dep-cutlass.cmake") - .wrap_err("Cannot get CUTLASS dependency template")? - .render_to_write( - context! 
{ - version => "3.5.1", - }, - &mut *write, - ) - .wrap_err("Cannot render CUTLASS dependency template")?; - } - Dependency::Cutlass3_6 => { - env.get_template("cuda/dep-cutlass.cmake") - .wrap_err("Cannot get CUTLASS dependency template")? - .render_to_write( - context! { - version => "3.6.0", - }, - &mut *write, - ) - .wrap_err("Cannot render CUTLASS dependency template")?; - } - Dependency::Cutlass3_8 => { - env.get_template("cuda/dep-cutlass.cmake") - .wrap_err("Cannot get CUTLASS dependency template")? - .render_to_write( - context! { - version => "3.8.0", - }, - &mut *write, - ) - .wrap_err("Cannot render CUTLASS dependency template")?; - } - Dependency::Cutlass3_9 => { - env.get_template("cuda/dep-cutlass.cmake") - .wrap_err("Cannot get CUTLASS dependency template")? - .render_to_write( - context! { - version => "3.9.2", - }, - &mut *write, - ) - .wrap_err("Cannot render CUTLASS dependency template")?; - } - Dependency::Cutlass4_0 => { - env.get_template("cuda/dep-cutlass.cmake") - .wrap_err("Cannot get CUTLASS dependency template")? - .render_to_write( - context! 
{ - version => "4.0.0", - }, - &mut *write, - ) - .wrap_err("Cannot render CUTLASS dependency template")?; - } - Dependency::Torch => (), - _ => { - eprintln!("Warning: CUDA backend doesn't need/support dependency: {dep:?}"); - } - }; - write.write_all(b"\n")?; - } - - Ok(()) -} - pub fn render_preamble( env: &Environment, name: &str, diff --git a/build2cmake/src/torch/deps.rs b/build2cmake/src/torch/deps.rs new file mode 100644 index 00000000..5a333460 --- /dev/null +++ b/build2cmake/src/torch/deps.rs @@ -0,0 +1,106 @@ +use std::collections::HashSet; +use std::io::Write; + +use eyre::{Context, Result}; +use minijinja::{context, Environment}; + +use crate::config::{Backend, Build, Dependency}; + +pub fn render_deps( + env: &Environment, + backend: Backend, + build: &Build, + write: &mut impl Write, +) -> Result<()> { + let mut deps = HashSet::new(); + + for kernel in build + .kernels + .values() + .filter(|kernel| kernel.backend() == backend) + { + deps.extend(kernel.depends()); + } + + for dep in deps { + match dep { + Dependency::Cutlass2_10 => { + env.get_template("cuda/dep-cutlass.cmake") + .wrap_err("Cannot get CUTLASS dependency template")? + .render_to_write( + context! { + version => "2.10.0", + }, + &mut *write, + ) + .wrap_err("Cannot render CUTLASS dependency template")?; + } + Dependency::Cutlass3_5 => { + env.get_template("cuda/dep-cutlass.cmake") + .wrap_err("Cannot get CUTLASS dependency template")? + .render_to_write( + context! { + version => "3.5.1", + }, + &mut *write, + ) + .wrap_err("Cannot render CUTLASS dependency template")?; + } + Dependency::Cutlass3_6 => { + env.get_template("cuda/dep-cutlass.cmake") + .wrap_err("Cannot get CUTLASS dependency template")? + .render_to_write( + context! { + version => "3.6.0", + }, + &mut *write, + ) + .wrap_err("Cannot render CUTLASS dependency template")?; + } + Dependency::Cutlass3_8 => { + env.get_template("cuda/dep-cutlass.cmake") + .wrap_err("Cannot get CUTLASS dependency template")? 
+ .render_to_write( + context! { + version => "3.8.0", + }, + &mut *write, + ) + .wrap_err("Cannot render CUTLASS dependency template")?; + } + Dependency::Cutlass3_9 => { + env.get_template("cuda/dep-cutlass.cmake") + .wrap_err("Cannot get CUTLASS dependency template")? + .render_to_write( + context! { + version => "3.9.2", + }, + &mut *write, + ) + .wrap_err("Cannot render CUTLASS dependency template")?; + } + Dependency::Cutlass4_0 => { + env.get_template("cuda/dep-cutlass.cmake") + .wrap_err("Cannot get CUTLASS dependency template")? + .render_to_write( + context! { + version => "4.0.0", + }, + &mut *write, + ) + .wrap_err("Cannot render CUTLASS dependency template")?; + } + Dependency::CutlassSycl => { + env.get_template("xpu/dep-cutlass-sycl.cmake")? + .render_to_write(context! {}, &mut *write)?; + } + Dependency::Torch => (), + _ => { + eprintln!("Warning: {backend:?} backend doesn't need/support dependency: {dep:?}"); + } + } + write.write_all(b"\n")?; + } + + Ok(()) +} diff --git a/build2cmake/src/torch/mod.rs b/build2cmake/src/torch/mod.rs index 03a737f9..0a45d55b 100644 --- a/build2cmake/src/torch/mod.rs +++ b/build2cmake/src/torch/mod.rs @@ -6,6 +6,8 @@ pub use cuda::write_torch_ext_cuda; pub mod common; +pub(crate) mod deps; + pub mod kernel; mod metal; diff --git a/build2cmake/src/torch/xpu.rs b/build2cmake/src/torch/xpu.rs index 5404bb4c..8bb6f5fe 100644 --- a/build2cmake/src/torch/xpu.rs +++ b/build2cmake/src/torch/xpu.rs @@ -1,15 +1,15 @@ -use std::collections::HashSet; use std::io::Write; use std::path::PathBuf; use eyre::{bail, Context, Result}; use minijinja::{context, Environment}; -use crate::config::{Backend, Build, Dependency, Torch}; +use crate::config::{Backend, Build, Torch}; use crate::torch::common::{ render_binding, render_extension, write_cmake_helpers, write_metadata, write_ops_py, write_pyproject_toml, write_setup_py, write_torch_registration_macros, }; +use crate::torch::deps::render_deps; use 
crate::torch::kernel::render_kernel_components; use crate::torch::kernel_ops_identifier; use crate::version::Version; @@ -89,40 +89,6 @@ fn write_cmake( Ok(()) } -fn render_deps( - env: &Environment, - backend: Backend, - build: &Build, - write: &mut impl Write, -) -> Result<()> { - let mut deps = HashSet::new(); - - for kernel in build - .kernels - .values() - .filter(|kernel| kernel.backend() == backend) - { - deps.extend(kernel.depends()); - } - - for dep in deps { - match dep { - Dependency::CutlassSycl => { - env.get_template("xpu/dep-cutlass-sycl.cmake")? - .render_to_write(context! {}, &mut *write)?; - } - Dependency::Torch => (), - _ => { - // XPU supports CUTLASS-SYCL instead of CUTLASS - eprintln!("Warning: XPU backend doesn't need/support dependency: {dep:?}"); - } - } - write.write_all(b"\n")?; - } - - Ok(()) -} - pub fn render_preamble( env: &Environment, name: &str,