diff --git a/build2cmake/src/torch/cuda.rs b/build2cmake/src/torch/cuda.rs index da319a6c..54093da1 100644 --- a/build2cmake/src/torch/cuda.rs +++ b/build2cmake/src/torch/cuda.rs @@ -1,4 +1,3 @@ -use std::collections::HashSet; use std::env; use std::io::Write; use std::path::PathBuf; @@ -6,11 +5,12 @@ use std::path::PathBuf; use eyre::{bail, Context, Result}; use minijinja::{context, Environment}; -use crate::config::{Backend, Build, Dependency, Torch}; +use crate::config::{Backend, Build, Torch}; use crate::torch::common::{ render_binding, render_extension, write_cmake_helpers, write_metadata, write_ops_py, write_pyproject_toml, write_setup_py, write_torch_registration_macros, }; +use crate::torch::deps::render_deps; use crate::torch::kernel::render_kernel_components; use crate::torch::kernel_ops_identifier; use crate::version::Version; @@ -95,101 +95,6 @@ fn write_cmake( Ok(()) } -fn render_deps( - env: &Environment, - backend: Backend, - build: &Build, - write: &mut impl Write, -) -> Result<()> { - let mut deps = HashSet::new(); - - for kernel in build - .kernels - .values() - .filter(|kernel| kernel.backend() == backend) - { - deps.extend(kernel.depends()); - } - - for dep in deps { - match dep { - Dependency::Cutlass2_10 => { - env.get_template("cuda/dep-cutlass.cmake") - .wrap_err("Cannot get CUTLASS dependency template")? - .render_to_write( - context! { - version => "2.10.0", - }, - &mut *write, - ) - .wrap_err("Cannot render CUTLASS dependency template")?; - } - Dependency::Cutlass3_5 => { - env.get_template("cuda/dep-cutlass.cmake") - .wrap_err("Cannot get CUTLASS dependency template")? - .render_to_write( - context! { - version => "3.5.1", - }, - &mut *write, - ) - .wrap_err("Cannot render CUTLASS dependency template")?; - } - Dependency::Cutlass3_6 => { - env.get_template("cuda/dep-cutlass.cmake") - .wrap_err("Cannot get CUTLASS dependency template")? - .render_to_write( - context! { - version => "3.6.0", - }, - &mut *write, - ) - .wrap_err("Cannot render CUTLASS dependency template")?; - } - Dependency::Cutlass3_8 => { - env.get_template("cuda/dep-cutlass.cmake") - .wrap_err("Cannot get CUTLASS dependency template")? - .render_to_write( - context! { - version => "3.8.0", - }, - &mut *write, - ) - .wrap_err("Cannot render CUTLASS dependency template")?; - } - Dependency::Cutlass3_9 => { - env.get_template("cuda/dep-cutlass.cmake") - .wrap_err("Cannot get CUTLASS dependency template")? - .render_to_write( - context! { - version => "3.9.2", - }, - &mut *write, - ) - .wrap_err("Cannot render CUTLASS dependency template")?; - } - Dependency::Cutlass4_0 => { - env.get_template("cuda/dep-cutlass.cmake") - .wrap_err("Cannot get CUTLASS dependency template")? - .render_to_write( - context! { - version => "4.0.0", - }, - &mut *write, - ) - .wrap_err("Cannot render CUTLASS dependency template")?; - } - Dependency::Torch => (), - _ => { - eprintln!("Warning: CUDA backend doesn't need/support dependency: {dep:?}"); - } - }; - write.write_all(b"\n")?; - } - - Ok(()) -} - pub fn render_preamble( env: &Environment, name: &str, diff --git a/build2cmake/src/torch/deps.rs b/build2cmake/src/torch/deps.rs new file mode 100644 index 00000000..5a333460 --- /dev/null +++ b/build2cmake/src/torch/deps.rs @@ -0,0 +1,106 @@ +use std::collections::HashSet; +use std::io::Write; + +use eyre::{Context, Result}; +use minijinja::{context, Environment}; + +use crate::config::{Backend, Build, Dependency}; + +pub fn render_deps( + env: &Environment, + backend: Backend, + build: &Build, + write: &mut impl Write, +) -> Result<()> { + let mut deps = HashSet::new(); + + for kernel in build + .kernels + .values() + .filter(|kernel| kernel.backend() == backend) + { + deps.extend(kernel.depends()); + } + + for dep in deps { + match dep { + Dependency::Cutlass2_10 => { + env.get_template("cuda/dep-cutlass.cmake") + .wrap_err("Cannot get CUTLASS dependency template")? + .render_to_write( + context! { + version => "2.10.0", + }, + &mut *write, + ) + .wrap_err("Cannot render CUTLASS dependency template")?; + } + Dependency::Cutlass3_5 => { + env.get_template("cuda/dep-cutlass.cmake") + .wrap_err("Cannot get CUTLASS dependency template")? + .render_to_write( + context! { + version => "3.5.1", + }, + &mut *write, + ) + .wrap_err("Cannot render CUTLASS dependency template")?; + } + Dependency::Cutlass3_6 => { + env.get_template("cuda/dep-cutlass.cmake") + .wrap_err("Cannot get CUTLASS dependency template")? + .render_to_write( + context! { + version => "3.6.0", + }, + &mut *write, + ) + .wrap_err("Cannot render CUTLASS dependency template")?; + } + Dependency::Cutlass3_8 => { + env.get_template("cuda/dep-cutlass.cmake") + .wrap_err("Cannot get CUTLASS dependency template")? + .render_to_write( + context! { + version => "3.8.0", + }, + &mut *write, + ) + .wrap_err("Cannot render CUTLASS dependency template")?; + } + Dependency::Cutlass3_9 => { + env.get_template("cuda/dep-cutlass.cmake") + .wrap_err("Cannot get CUTLASS dependency template")? + .render_to_write( + context! { + version => "3.9.2", + }, + &mut *write, + ) + .wrap_err("Cannot render CUTLASS dependency template")?; + } + Dependency::Cutlass4_0 => { + env.get_template("cuda/dep-cutlass.cmake") + .wrap_err("Cannot get CUTLASS dependency template")? + .render_to_write( + context! { + version => "4.0.0", + }, + &mut *write, + ) + .wrap_err("Cannot render CUTLASS dependency template")?; + } + Dependency::CutlassSycl => { + env.get_template("xpu/dep-cutlass-sycl.cmake")? + .render_to_write(context! {}, &mut *write)?; + } + Dependency::Torch => (), + _ => { + eprintln!("Warning: {backend:?} backend doesn't need/support dependency: {dep:?}"); + } + } + write.write_all(b"\n")?; + } + + Ok(()) +} diff --git a/build2cmake/src/torch/mod.rs b/build2cmake/src/torch/mod.rs index 03a737f9..0a45d55b 100644 --- a/build2cmake/src/torch/mod.rs +++ b/build2cmake/src/torch/mod.rs @@ -6,6 +6,8 @@ pub use cuda::write_torch_ext_cuda; pub mod common; +pub(crate) mod deps; + pub mod kernel; mod metal; diff --git a/build2cmake/src/torch/xpu.rs b/build2cmake/src/torch/xpu.rs index 5404bb4c..8bb6f5fe 100644 --- a/build2cmake/src/torch/xpu.rs +++ b/build2cmake/src/torch/xpu.rs @@ -1,15 +1,15 @@ -use std::collections::HashSet; use std::io::Write; use std::path::PathBuf; use eyre::{bail, Context, Result}; use minijinja::{context, Environment}; -use crate::config::{Backend, Build, Dependency, Torch}; +use crate::config::{Backend, Build, Torch}; use crate::torch::common::{ render_binding, render_extension, write_cmake_helpers, write_metadata, write_ops_py, write_pyproject_toml, write_setup_py, write_torch_registration_macros, }; +use crate::torch::deps::render_deps; use crate::torch::kernel::render_kernel_components; use crate::torch::kernel_ops_identifier; use crate::version::Version; @@ -89,40 +89,6 @@ fn write_cmake( Ok(()) } -fn render_deps( - env: &Environment, - backend: Backend, - build: &Build, - write: &mut impl Write, -) -> Result<()> { - let mut deps = HashSet::new(); - - for kernel in build - .kernels - .values() - .filter(|kernel| kernel.backend() == backend) - { - deps.extend(kernel.depends()); - } - - for dep in deps { - match dep { - Dependency::CutlassSycl => { - env.get_template("xpu/dep-cutlass-sycl.cmake")? - .render_to_write(context! {}, &mut *write)?; - } - Dependency::Torch => (), - _ => { - // XPU supports CUTLASS-SYCL instead of CUTLASS - eprintln!("Warning: XPU backend doesn't need/support dependency: {dep:?}"); - } - } - write.write_all(b"\n")?; - } - - Ok(()) -} - pub fn render_preamble( env: &Environment, name: &str,