Skip to content

Commit 0f8c79d

Browse files
authored
Move writing of CMake utility fails and ops wrapper to common (#249)
* build2cmake: always generate kernel components for all backends This change always generates the kernel components for all backends. This is a necessary step for merging the preambles later. * Factor out `write_torch_registration_macros` to `common` mod * Move `write_ops_py` to `common` mod * Factor out CMake helper file generation
1 parent 11a4637 commit 0f8c79d

5 files changed

Lines changed: 84 additions & 257 deletions

File tree

build2cmake/src/torch/common.rs

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use std::io::Write;
2+
use std::path::PathBuf;
23

34
use eyre::{Context, Result};
45
use itertools::Itertools;
@@ -8,6 +9,14 @@ use crate::config::{Backend, General, Torch};
89
use crate::metadata::Metadata;
910
use crate::FileSet;
1011

12+
static REGISTRATION_H: &str = include_str!("../templates/registration.h");
13+
static CMAKE_UTILS: &str = include_str!("../templates/utils.cmake");
14+
static CMAKE_KERNEL: &str = include_str!("../templates/kernel.cmake");
15+
static WINDOWS_UTILS: &str = include_str!("../templates/windows.cmake");
16+
static HIPIFY: &str = include_str!("../templates/cuda/hipify.py");
17+
static COMPILE_METAL_CMAKE: &str = include_str!("../templates/metal/compile-metal.cmake");
18+
static METALLIB_TO_HEADER_PY: &str = include_str!("../templates/metal/metallib_to_header.py");
19+
1120
pub fn write_setup_py(
1221
env: &Environment,
1322
torch: &crate::config::Torch,
@@ -94,6 +103,17 @@ where
94103
.join(";")
95104
}
96105

106+
pub fn write_torch_registration_macros(file_set: &mut FileSet) -> Result<()> {
107+
let mut path = PathBuf::new();
108+
path.push("torch-ext");
109+
path.push("registration.h");
110+
file_set
111+
.entry(path)
112+
.extend_from_slice(REGISTRATION_H.as_bytes());
113+
114+
Ok(())
115+
}
116+
97117
pub fn render_binding(
98118
env: &Environment,
99119
torch: &Torch,
@@ -117,6 +137,58 @@ pub fn render_binding(
117137
Ok(())
118138
}
119139

140+
pub fn write_ops_py(
141+
env: &Environment,
142+
name: &str,
143+
ops_name: &str,
144+
file_set: &mut FileSet,
145+
) -> Result<()> {
146+
let mut path = PathBuf::new();
147+
path.push("torch-ext");
148+
path.push(name);
149+
path.push("_ops.py");
150+
let writer = file_set.entry(path);
151+
152+
env.get_template("_ops.py")
153+
.wrap_err("Cannot get _ops.py template")?
154+
.render_to_write(
155+
context! {
156+
ops_name => ops_name,
157+
},
158+
writer,
159+
)
160+
.wrap_err("Cannot render kernel template")?;
161+
162+
Ok(())
163+
}
164+
165+
/// Helper function to write a file to the cmake subdirectory
166+
pub fn write_cmake_file(file_set: &mut FileSet, filename: &str, content: &[u8]) {
167+
let mut path = PathBuf::new();
168+
path.push("cmake");
169+
path.push(filename);
170+
file_set.entry(path).extend_from_slice(content);
171+
}
172+
173+
/// Writes all CMake helper files that any backend might need.
174+
/// Each backend will use only the files it references in its CMakeLists.txt.
175+
pub fn write_cmake_helpers(file_set: &mut FileSet) {
176+
write_cmake_file(file_set, "utils.cmake", CMAKE_UTILS.as_bytes());
177+
write_cmake_file(file_set, "kernel.cmake", CMAKE_KERNEL.as_bytes());
178+
write_cmake_file(file_set, "windows.cmake", WINDOWS_UTILS.as_bytes());
179+
write_cmake_file(file_set, "hipify.py", HIPIFY.as_bytes());
180+
write_cmake_file(
181+
file_set,
182+
"compile-metal.cmake",
183+
COMPILE_METAL_CMAKE.as_bytes(),
184+
);
185+
write_cmake_file(
186+
file_set,
187+
"metallib_to_header.py",
188+
METALLIB_TO_HEADER_PY.as_bytes(),
189+
);
190+
}
191+
120192
pub fn render_extension(
121193
env: &Environment,
122194
name: &str,

build2cmake/src/torch/cpu.rs

Lines changed: 3 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,13 @@ use minijinja::{context, Environment};
66
use crate::config::{Backend, Build, Torch};
77
use crate::fileset::FileSet;
88
use crate::torch::common::{
9-
render_binding, render_extension, write_metadata, write_pyproject_toml, write_setup_py,
9+
render_binding, render_extension, write_cmake_helpers, write_metadata, write_ops_py,
10+
write_pyproject_toml, write_setup_py, write_torch_registration_macros,
1011
};
1112
use crate::torch::kernel::render_kernel_components;
1213
use crate::torch::kernel_ops_identifier;
1314
use crate::version::Version;
1415

15-
static CMAKE_UTILS: &str = include_str!("../templates/utils.cmake");
16-
static CMAKE_KERNEL: &str = include_str!("../templates/kernel.cmake");
17-
static REGISTRATION_H: &str = include_str!("../templates/registration.h");
18-
1916
pub fn write_torch_ext_cpu(
2017
env: &Environment,
2118
build: &Build,
@@ -67,19 +64,7 @@ fn write_cmake(
6764
ops_name: &str,
6865
file_set: &mut FileSet,
6966
) -> Result<()> {
70-
let mut utils_path = PathBuf::new();
71-
utils_path.push("cmake");
72-
utils_path.push("utils.cmake");
73-
file_set
74-
.entry(utils_path.clone())
75-
.extend_from_slice(CMAKE_UTILS.as_bytes());
76-
77-
let mut kernel_path = PathBuf::new();
78-
kernel_path.push("cmake");
79-
kernel_path.push("kernel.cmake");
80-
file_set
81-
.entry(kernel_path.clone())
82-
.extend_from_slice(CMAKE_KERNEL.as_bytes());
67+
write_cmake_helpers(file_set);
8368

8469
let cmake_writer = file_set.entry("CMakeLists.txt");
8570

@@ -126,39 +111,3 @@ fn render_preamble(
126111

127112
Ok(())
128113
}
129-
130-
fn write_ops_py(
131-
env: &Environment,
132-
name: &str,
133-
ops_name: &str,
134-
file_set: &mut FileSet,
135-
) -> Result<()> {
136-
let mut path = PathBuf::new();
137-
path.push("torch-ext");
138-
path.push(name);
139-
path.push("_ops.py");
140-
let writer = file_set.entry(path);
141-
142-
env.get_template("_ops.py")
143-
.wrap_err("Cannot get _ops.py template")?
144-
.render_to_write(
145-
context! {
146-
ops_name => ops_name,
147-
},
148-
writer,
149-
)
150-
.wrap_err("Cannot render kernel template")?;
151-
152-
Ok(())
153-
}
154-
155-
fn write_torch_registration_macros(file_set: &mut FileSet) -> Result<()> {
156-
let mut path = PathBuf::new();
157-
path.push("torch-ext");
158-
path.push("registration.h");
159-
file_set
160-
.entry(path)
161-
.extend_from_slice(REGISTRATION_H.as_bytes());
162-
163-
Ok(())
164-
}

build2cmake/src/torch/cuda.rs

Lines changed: 3 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -8,19 +8,14 @@ use minijinja::{context, Environment};
88

99
use crate::config::{Backend, Build, Dependency, Torch};
1010
use crate::torch::common::{
11-
render_binding, render_extension, write_metadata, write_pyproject_toml, write_setup_py,
11+
render_binding, render_extension, write_cmake_helpers, write_metadata, write_ops_py,
12+
write_pyproject_toml, write_setup_py, write_torch_registration_macros,
1213
};
1314
use crate::torch::kernel::render_kernel_components;
1415
use crate::torch::kernel_ops_identifier;
1516
use crate::version::Version;
1617
use crate::FileSet;
1718

18-
static CMAKE_UTILS: &str = include_str!("../templates/utils.cmake");
19-
static CMAKE_KERNEL: &str = include_str!("../templates/kernel.cmake");
20-
static WINDOWS_UTILS: &str = include_str!("../templates/windows.cmake");
21-
static REGISTRATION_H: &str = include_str!("../templates/registration.h");
22-
static HIPIFY: &str = include_str!("../templates/cuda/hipify.py");
23-
2419
pub fn write_torch_ext_cuda(
2520
env: &Environment,
2621
backend: Backend,
@@ -66,42 +61,6 @@ pub fn write_torch_ext_cuda(
6661
Ok(file_set)
6762
}
6863

69-
fn write_torch_registration_macros(file_set: &mut FileSet) -> Result<()> {
70-
let mut path = PathBuf::new();
71-
path.push("torch-ext");
72-
path.push("registration.h");
73-
file_set
74-
.entry(path)
75-
.extend_from_slice(REGISTRATION_H.as_bytes());
76-
77-
Ok(())
78-
}
79-
80-
fn write_ops_py(
81-
env: &Environment,
82-
name: &str,
83-
ops_name: &str,
84-
file_set: &mut FileSet,
85-
) -> Result<()> {
86-
let mut path = PathBuf::new();
87-
path.push("torch-ext");
88-
path.push(name);
89-
path.push("_ops.py");
90-
let writer = file_set.entry(path);
91-
92-
env.get_template("_ops.py")
93-
.wrap_err("Cannot get _ops.py template")?
94-
.render_to_write(
95-
context! {
96-
ops_name => ops_name,
97-
},
98-
writer,
99-
)
100-
.wrap_err("Cannot render kernel template")?;
101-
102-
Ok(())
103-
}
104-
10564
fn write_cmake(
10665
env: &Environment,
10766
backend: Backend,
@@ -111,33 +70,7 @@ fn write_cmake(
11170
ops_name: &str,
11271
file_set: &mut FileSet,
11372
) -> Result<()> {
114-
let mut utils_path = PathBuf::new();
115-
utils_path.push("cmake");
116-
utils_path.push("utils.cmake");
117-
file_set
118-
.entry(utils_path.clone())
119-
.extend_from_slice(CMAKE_UTILS.as_bytes());
120-
121-
let mut kernel_path = PathBuf::new();
122-
kernel_path.push("cmake");
123-
kernel_path.push("kernel.cmake");
124-
file_set
125-
.entry(kernel_path.clone())
126-
.extend_from_slice(CMAKE_KERNEL.as_bytes());
127-
128-
let mut windows_utils_path = PathBuf::new();
129-
windows_utils_path.push("cmake");
130-
windows_utils_path.push("windows.cmake");
131-
file_set
132-
.entry(windows_utils_path.clone())
133-
.extend_from_slice(WINDOWS_UTILS.as_bytes());
134-
135-
let mut hipify_path = PathBuf::new();
136-
hipify_path.push("cmake");
137-
hipify_path.push("hipify.py");
138-
file_set
139-
.entry(hipify_path.clone())
140-
.extend_from_slice(HIPIFY.as_bytes());
73+
write_cmake_helpers(file_set);
14174

14275
let cmake_writer = file_set.entry("CMakeLists.txt");
14376

build2cmake/src/torch/metal.rs

Lines changed: 3 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,13 @@ use minijinja::{context, Environment};
66
use crate::config::{Backend, Build, Torch};
77
use crate::fileset::FileSet;
88
use crate::torch::common::{
9-
render_binding, render_extension, write_metadata, write_pyproject_toml, write_setup_py,
9+
render_binding, render_extension, write_cmake_helpers, write_metadata, write_ops_py,
10+
write_pyproject_toml, write_setup_py, write_torch_registration_macros,
1011
};
1112
use crate::torch::kernel::render_kernel_components;
1213
use crate::torch::kernel_ops_identifier;
1314
use crate::version::Version;
1415

15-
static CMAKE_UTILS: &str = include_str!("../templates/utils.cmake");
16-
static CMAKE_KERNEL: &str = include_str!("../templates/kernel.cmake");
17-
static REGISTRATION_H: &str = include_str!("../templates/registration.h");
18-
static COMPILE_METAL_CMAKE: &str = include_str!("../templates/metal/compile-metal.cmake");
19-
static METALLIB_TO_HEADER_PY: &str = include_str!("../templates/metal/metallib_to_header.py");
20-
2116
pub fn write_torch_ext_metal(
2217
env: &Environment,
2318
build: &Build,
@@ -69,33 +64,7 @@ fn write_cmake(
6964
ops_name: &str,
7065
file_set: &mut FileSet,
7166
) -> Result<()> {
72-
let mut utils_path = PathBuf::new();
73-
utils_path.push("cmake");
74-
utils_path.push("utils.cmake");
75-
file_set
76-
.entry(utils_path.clone())
77-
.extend_from_slice(CMAKE_UTILS.as_bytes());
78-
79-
let mut kernel_path = PathBuf::new();
80-
kernel_path.push("cmake");
81-
kernel_path.push("kernel.cmake");
82-
file_set
83-
.entry(kernel_path.clone())
84-
.extend_from_slice(CMAKE_KERNEL.as_bytes());
85-
86-
let mut compile_metal_path = PathBuf::new();
87-
compile_metal_path.push("cmake");
88-
compile_metal_path.push("compile-metal.cmake");
89-
file_set
90-
.entry(compile_metal_path)
91-
.extend_from_slice(COMPILE_METAL_CMAKE.as_bytes());
92-
93-
let mut metallib_to_header_path = PathBuf::new();
94-
metallib_to_header_path.push("cmake");
95-
metallib_to_header_path.push("metallib_to_header.py");
96-
file_set
97-
.entry(metallib_to_header_path)
98-
.extend_from_slice(METALLIB_TO_HEADER_PY.as_bytes());
67+
write_cmake_helpers(file_set);
9968

10069
let cmake_writer = file_set.entry("CMakeLists.txt");
10170

@@ -142,39 +111,3 @@ fn render_preamble(
142111

143112
Ok(())
144113
}
145-
146-
fn write_ops_py(
147-
env: &Environment,
148-
name: &str,
149-
ops_name: &str,
150-
file_set: &mut FileSet,
151-
) -> Result<()> {
152-
let mut path = PathBuf::new();
153-
path.push("torch-ext");
154-
path.push(name);
155-
path.push("_ops.py");
156-
let writer = file_set.entry(path);
157-
158-
env.get_template("_ops.py")
159-
.wrap_err("Cannot get _ops.py template")?
160-
.render_to_write(
161-
context! {
162-
ops_name => ops_name,
163-
},
164-
writer,
165-
)
166-
.wrap_err("Cannot render kernel template")?;
167-
168-
Ok(())
169-
}
170-
171-
fn write_torch_registration_macros(file_set: &mut FileSet) -> Result<()> {
172-
let mut path = PathBuf::new();
173-
path.push("torch-ext");
174-
path.push("registration.h");
175-
file_set
176-
.entry(path)
177-
.extend_from_slice(REGISTRATION_H.as_bytes());
178-
179-
Ok(())
180-
}

0 commit comments

Comments
 (0)