Skip to content
Merged
99 changes: 2 additions & 97 deletions build2cmake/src/torch/cuda.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
use std::collections::HashSet;
use std::env;
use std::io::Write;
use std::path::PathBuf;

use eyre::{bail, Context, Result};
use minijinja::{context, Environment};

use crate::config::{Backend, Build, Dependency, Torch};
use crate::config::{Backend, Build, Torch};
use crate::torch::common::{
render_binding, render_extension, write_cmake_helpers, write_metadata, write_ops_py,
write_pyproject_toml, write_setup_py, write_torch_registration_macros,
};
use crate::torch::deps::render_deps;
use crate::torch::kernel::render_kernel_components;
use crate::torch::kernel_ops_identifier;
use crate::version::Version;
Expand Down Expand Up @@ -95,101 +95,6 @@ fn write_cmake(
Ok(())
}

fn render_deps(
env: &Environment,
backend: Backend,
build: &Build,
write: &mut impl Write,
) -> Result<()> {
let mut deps = HashSet::new();

for kernel in build
.kernels
.values()
.filter(|kernel| kernel.backend() == backend)
{
deps.extend(kernel.depends());
}

for dep in deps {
match dep {
Dependency::Cutlass2_10 => {
env.get_template("cuda/dep-cutlass.cmake")
.wrap_err("Cannot get CUTLASS dependency template")?
.render_to_write(
context! {
version => "2.10.0",
},
&mut *write,
)
.wrap_err("Cannot render CUTLASS dependency template")?;
}
Dependency::Cutlass3_5 => {
env.get_template("cuda/dep-cutlass.cmake")
.wrap_err("Cannot get CUTLASS dependency template")?
.render_to_write(
context! {
version => "3.5.1",
},
&mut *write,
)
.wrap_err("Cannot render CUTLASS dependency template")?;
}
Dependency::Cutlass3_6 => {
env.get_template("cuda/dep-cutlass.cmake")
.wrap_err("Cannot get CUTLASS dependency template")?
.render_to_write(
context! {
version => "3.6.0",
},
&mut *write,
)
.wrap_err("Cannot render CUTLASS dependency template")?;
}
Dependency::Cutlass3_8 => {
env.get_template("cuda/dep-cutlass.cmake")
.wrap_err("Cannot get CUTLASS dependency template")?
.render_to_write(
context! {
version => "3.8.0",
},
&mut *write,
)
.wrap_err("Cannot render CUTLASS dependency template")?;
}
Dependency::Cutlass3_9 => {
env.get_template("cuda/dep-cutlass.cmake")
.wrap_err("Cannot get CUTLASS dependency template")?
.render_to_write(
context! {
version => "3.9.2",
},
&mut *write,
)
.wrap_err("Cannot render CUTLASS dependency template")?;
}
Dependency::Cutlass4_0 => {
env.get_template("cuda/dep-cutlass.cmake")
.wrap_err("Cannot get CUTLASS dependency template")?
.render_to_write(
context! {
version => "4.0.0",
},
&mut *write,
)
.wrap_err("Cannot render CUTLASS dependency template")?;
}
Dependency::Torch => (),
_ => {
eprintln!("Warning: CUDA backend doesn't need/support dependency: {dep:?}");
}
};
write.write_all(b"\n")?;
}

Ok(())
}

pub fn render_preamble(
env: &Environment,
name: &str,
Expand Down
106 changes: 106 additions & 0 deletions build2cmake/src/torch/deps.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
use std::collections::HashSet;
use std::io::Write;

use eyre::{Context, Result};
use minijinja::{context, Environment};

use crate::config::{Backend, Build, Dependency};

pub fn render_deps(
env: &Environment,
backend: Backend,
build: &Build,
write: &mut impl Write,
) -> Result<()> {
let mut deps = HashSet::new();

for kernel in build
.kernels
.values()
.filter(|kernel| kernel.backend() == backend)
{
deps.extend(kernel.depends());
}

for dep in deps {
match dep {
Dependency::Cutlass2_10 => {
env.get_template("cuda/dep-cutlass.cmake")
.wrap_err("Cannot get CUTLASS dependency template")?
.render_to_write(
context! {
version => "2.10.0",
},
&mut *write,
)
.wrap_err("Cannot render CUTLASS dependency template")?;
}
Dependency::Cutlass3_5 => {
env.get_template("cuda/dep-cutlass.cmake")
.wrap_err("Cannot get CUTLASS dependency template")?
.render_to_write(
context! {
version => "3.5.1",
},
&mut *write,
)
.wrap_err("Cannot render CUTLASS dependency template")?;
}
Dependency::Cutlass3_6 => {
env.get_template("cuda/dep-cutlass.cmake")
.wrap_err("Cannot get CUTLASS dependency template")?
.render_to_write(
context! {
version => "3.6.0",
},
&mut *write,
)
.wrap_err("Cannot render CUTLASS dependency template")?;
}
Dependency::Cutlass3_8 => {
env.get_template("cuda/dep-cutlass.cmake")
.wrap_err("Cannot get CUTLASS dependency template")?
.render_to_write(
context! {
version => "3.8.0",
},
&mut *write,
)
.wrap_err("Cannot render CUTLASS dependency template")?;
}
Dependency::Cutlass3_9 => {
env.get_template("cuda/dep-cutlass.cmake")
.wrap_err("Cannot get CUTLASS dependency template")?
.render_to_write(
context! {
version => "3.9.2",
},
&mut *write,
)
.wrap_err("Cannot render CUTLASS dependency template")?;
}
Dependency::Cutlass4_0 => {
env.get_template("cuda/dep-cutlass.cmake")
.wrap_err("Cannot get CUTLASS dependency template")?
.render_to_write(
context! {
version => "4.0.0",
},
&mut *write,
)
.wrap_err("Cannot render CUTLASS dependency template")?;
}
Dependency::CutlassSycl => {
env.get_template("xpu/dep-cutlass-sycl.cmake")?
.render_to_write(context! {}, &mut *write)?;
}
Dependency::Torch => (),
_ => {
eprintln!("Warning: {backend:?} backend doesn't need/support dependency: {dep:?}");
}
}
write.write_all(b"\n")?;
}

Ok(())
}
2 changes: 2 additions & 0 deletions build2cmake/src/torch/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ pub use cuda::write_torch_ext_cuda;

pub mod common;

pub(crate) mod deps;

pub mod kernel;

mod metal;
Expand Down
38 changes: 2 additions & 36 deletions build2cmake/src/torch/xpu.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
use std::collections::HashSet;
use std::io::Write;
use std::path::PathBuf;

use eyre::{bail, Context, Result};
use minijinja::{context, Environment};

use crate::config::{Backend, Build, Dependency, Torch};
use crate::config::{Backend, Build, Torch};
use crate::torch::common::{
render_binding, render_extension, write_cmake_helpers, write_metadata, write_ops_py,
write_pyproject_toml, write_setup_py, write_torch_registration_macros,
};
use crate::torch::deps::render_deps;
use crate::torch::kernel::render_kernel_components;
use crate::torch::kernel_ops_identifier;
use crate::version::Version;
Expand Down Expand Up @@ -89,40 +89,6 @@ fn write_cmake(
Ok(())
}

fn render_deps(
env: &Environment,
backend: Backend,
build: &Build,
write: &mut impl Write,
) -> Result<()> {
let mut deps = HashSet::new();

for kernel in build
.kernels
.values()
.filter(|kernel| kernel.backend() == backend)
{
deps.extend(kernel.depends());
}

for dep in deps {
match dep {
Dependency::CutlassSycl => {
env.get_template("xpu/dep-cutlass-sycl.cmake")?
.render_to_write(context! {}, &mut *write)?;
}
Dependency::Torch => (),
_ => {
// XPU supports CUTLASS-SYCL instead of CUTLASS
eprintln!("Warning: XPU backend doesn't need/support dependency: {dep:?}");
}
}
write.write_all(b"\n")?;
}

Ok(())
}

pub fn render_preamble(
env: &Environment,
name: &str,
Expand Down
Loading