Skip to content

Commit 6b42ffb

Browse files
authored
Factor out render_deps function (#251)
* build2cmake: always generate kernel components for all backends This change always generates the kernel components for all backends. This is a necessary step for merging the preambles later. * Factor out `write_torch_registration_macros` to `common` mod * Move `write_ops_py` to `common` mod * Factor out CMake helper file generation * Factor out `render_deps` function
1 parent 26355c5 commit 6b42ffb

4 files changed

Lines changed: 112 additions & 133 deletions

File tree

build2cmake/src/torch/cuda.rs

Lines changed: 2 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
1-
use std::collections::HashSet;
21
use std::env;
32
use std::io::Write;
43
use std::path::PathBuf;
54

65
use eyre::{bail, Context, Result};
76
use minijinja::{context, Environment};
87

9-
use crate::config::{Backend, Build, Dependency, Torch};
8+
use crate::config::{Backend, Build, Torch};
109
use crate::torch::common::{
1110
render_binding, render_extension, write_cmake_helpers, write_metadata, write_ops_py,
1211
write_pyproject_toml, write_setup_py, write_torch_registration_macros,
1312
};
13+
use crate::torch::deps::render_deps;
1414
use crate::torch::kernel::render_kernel_components;
1515
use crate::torch::kernel_ops_identifier;
1616
use crate::version::Version;
@@ -95,101 +95,6 @@ fn write_cmake(
9595
Ok(())
9696
}
9797

98-
fn render_deps(
99-
env: &Environment,
100-
backend: Backend,
101-
build: &Build,
102-
write: &mut impl Write,
103-
) -> Result<()> {
104-
let mut deps = HashSet::new();
105-
106-
for kernel in build
107-
.kernels
108-
.values()
109-
.filter(|kernel| kernel.backend() == backend)
110-
{
111-
deps.extend(kernel.depends());
112-
}
113-
114-
for dep in deps {
115-
match dep {
116-
Dependency::Cutlass2_10 => {
117-
env.get_template("cuda/dep-cutlass.cmake")
118-
.wrap_err("Cannot get CUTLASS dependency template")?
119-
.render_to_write(
120-
context! {
121-
version => "2.10.0",
122-
},
123-
&mut *write,
124-
)
125-
.wrap_err("Cannot render CUTLASS dependency template")?;
126-
}
127-
Dependency::Cutlass3_5 => {
128-
env.get_template("cuda/dep-cutlass.cmake")
129-
.wrap_err("Cannot get CUTLASS dependency template")?
130-
.render_to_write(
131-
context! {
132-
version => "3.5.1",
133-
},
134-
&mut *write,
135-
)
136-
.wrap_err("Cannot render CUTLASS dependency template")?;
137-
}
138-
Dependency::Cutlass3_6 => {
139-
env.get_template("cuda/dep-cutlass.cmake")
140-
.wrap_err("Cannot get CUTLASS dependency template")?
141-
.render_to_write(
142-
context! {
143-
version => "3.6.0",
144-
},
145-
&mut *write,
146-
)
147-
.wrap_err("Cannot render CUTLASS dependency template")?;
148-
}
149-
Dependency::Cutlass3_8 => {
150-
env.get_template("cuda/dep-cutlass.cmake")
151-
.wrap_err("Cannot get CUTLASS dependency template")?
152-
.render_to_write(
153-
context! {
154-
version => "3.8.0",
155-
},
156-
&mut *write,
157-
)
158-
.wrap_err("Cannot render CUTLASS dependency template")?;
159-
}
160-
Dependency::Cutlass3_9 => {
161-
env.get_template("cuda/dep-cutlass.cmake")
162-
.wrap_err("Cannot get CUTLASS dependency template")?
163-
.render_to_write(
164-
context! {
165-
version => "3.9.2",
166-
},
167-
&mut *write,
168-
)
169-
.wrap_err("Cannot render CUTLASS dependency template")?;
170-
}
171-
Dependency::Cutlass4_0 => {
172-
env.get_template("cuda/dep-cutlass.cmake")
173-
.wrap_err("Cannot get CUTLASS dependency template")?
174-
.render_to_write(
175-
context! {
176-
version => "4.0.0",
177-
},
178-
&mut *write,
179-
)
180-
.wrap_err("Cannot render CUTLASS dependency template")?;
181-
}
182-
Dependency::Torch => (),
183-
_ => {
184-
eprintln!("Warning: CUDA backend doesn't need/support dependency: {dep:?}");
185-
}
186-
};
187-
write.write_all(b"\n")?;
188-
}
189-
190-
Ok(())
191-
}
192-
19398
pub fn render_preamble(
19499
env: &Environment,
195100
name: &str,

build2cmake/src/torch/deps.rs

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
use std::collections::HashSet;
2+
use std::io::Write;
3+
4+
use eyre::{Context, Result};
5+
use minijinja::{context, Environment};
6+
7+
use crate::config::{Backend, Build, Dependency};
8+
9+
pub fn render_deps(
10+
env: &Environment,
11+
backend: Backend,
12+
build: &Build,
13+
write: &mut impl Write,
14+
) -> Result<()> {
15+
let mut deps = HashSet::new();
16+
17+
for kernel in build
18+
.kernels
19+
.values()
20+
.filter(|kernel| kernel.backend() == backend)
21+
{
22+
deps.extend(kernel.depends());
23+
}
24+
25+
for dep in deps {
26+
match dep {
27+
Dependency::Cutlass2_10 => {
28+
env.get_template("cuda/dep-cutlass.cmake")
29+
.wrap_err("Cannot get CUTLASS dependency template")?
30+
.render_to_write(
31+
context! {
32+
version => "2.10.0",
33+
},
34+
&mut *write,
35+
)
36+
.wrap_err("Cannot render CUTLASS dependency template")?;
37+
}
38+
Dependency::Cutlass3_5 => {
39+
env.get_template("cuda/dep-cutlass.cmake")
40+
.wrap_err("Cannot get CUTLASS dependency template")?
41+
.render_to_write(
42+
context! {
43+
version => "3.5.1",
44+
},
45+
&mut *write,
46+
)
47+
.wrap_err("Cannot render CUTLASS dependency template")?;
48+
}
49+
Dependency::Cutlass3_6 => {
50+
env.get_template("cuda/dep-cutlass.cmake")
51+
.wrap_err("Cannot get CUTLASS dependency template")?
52+
.render_to_write(
53+
context! {
54+
version => "3.6.0",
55+
},
56+
&mut *write,
57+
)
58+
.wrap_err("Cannot render CUTLASS dependency template")?;
59+
}
60+
Dependency::Cutlass3_8 => {
61+
env.get_template("cuda/dep-cutlass.cmake")
62+
.wrap_err("Cannot get CUTLASS dependency template")?
63+
.render_to_write(
64+
context! {
65+
version => "3.8.0",
66+
},
67+
&mut *write,
68+
)
69+
.wrap_err("Cannot render CUTLASS dependency template")?;
70+
}
71+
Dependency::Cutlass3_9 => {
72+
env.get_template("cuda/dep-cutlass.cmake")
73+
.wrap_err("Cannot get CUTLASS dependency template")?
74+
.render_to_write(
75+
context! {
76+
version => "3.9.2",
77+
},
78+
&mut *write,
79+
)
80+
.wrap_err("Cannot render CUTLASS dependency template")?;
81+
}
82+
Dependency::Cutlass4_0 => {
83+
env.get_template("cuda/dep-cutlass.cmake")
84+
.wrap_err("Cannot get CUTLASS dependency template")?
85+
.render_to_write(
86+
context! {
87+
version => "4.0.0",
88+
},
89+
&mut *write,
90+
)
91+
.wrap_err("Cannot render CUTLASS dependency template")?;
92+
}
93+
Dependency::CutlassSycl => {
94+
env.get_template("xpu/dep-cutlass-sycl.cmake")?
95+
.render_to_write(context! {}, &mut *write)?;
96+
}
97+
Dependency::Torch => (),
98+
_ => {
99+
eprintln!("Warning: {backend:?} backend doesn't need/support dependency: {dep:?}");
100+
}
101+
}
102+
write.write_all(b"\n")?;
103+
}
104+
105+
Ok(())
106+
}

build2cmake/src/torch/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ pub use cuda::write_torch_ext_cuda;
66

77
pub mod common;
88

9+
pub(crate) mod deps;
10+
911
pub mod kernel;
1012

1113
mod metal;

build2cmake/src/torch/xpu.rs

Lines changed: 2 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
1-
use std::collections::HashSet;
21
use std::io::Write;
32
use std::path::PathBuf;
43

54
use eyre::{bail, Context, Result};
65
use minijinja::{context, Environment};
76

8-
use crate::config::{Backend, Build, Dependency, Torch};
7+
use crate::config::{Backend, Build, Torch};
98
use crate::torch::common::{
109
render_binding, render_extension, write_cmake_helpers, write_metadata, write_ops_py,
1110
write_pyproject_toml, write_setup_py, write_torch_registration_macros,
1211
};
12+
use crate::torch::deps::render_deps;
1313
use crate::torch::kernel::render_kernel_components;
1414
use crate::torch::kernel_ops_identifier;
1515
use crate::version::Version;
@@ -89,40 +89,6 @@ fn write_cmake(
8989
Ok(())
9090
}
9191

92-
fn render_deps(
93-
env: &Environment,
94-
backend: Backend,
95-
build: &Build,
96-
write: &mut impl Write,
97-
) -> Result<()> {
98-
let mut deps = HashSet::new();
99-
100-
for kernel in build
101-
.kernels
102-
.values()
103-
.filter(|kernel| kernel.backend() == backend)
104-
{
105-
deps.extend(kernel.depends());
106-
}
107-
108-
for dep in deps {
109-
match dep {
110-
Dependency::CutlassSycl => {
111-
env.get_template("xpu/dep-cutlass-sycl.cmake")?
112-
.render_to_write(context! {}, &mut *write)?;
113-
}
114-
Dependency::Torch => (),
115-
_ => {
116-
// XPU supports CUTLASS-SYCL instead of CUTLASS
117-
eprintln!("Warning: XPU backend doesn't need/support dependency: {dep:?}");
118-
}
119-
}
120-
write.write_all(b"\n")?;
121-
}
122-
123-
Ok(())
124-
}
125-
12692
pub fn render_preamble(
12793
env: &Environment,
12894
name: &str,

0 commit comments

Comments
 (0)