Skip to content

Commit 08d9897

Browse files
authored
Remove the last backend-specific writer functions (#255)
* build2cmake: merge `write_cmake` functions in `common` * build2cmake: merge all `write_torch_ext` implementations in `common`
1 parent 24750e9 commit 08d9897

7 files changed

Lines changed: 99 additions & 403 deletions

File tree

build2cmake/src/main.rs

Lines changed: 3 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,7 @@ use eyre::{bail, ensure, Context, Result};
99
use minijinja::Environment;
1010

1111
mod torch;
12-
use torch::{
13-
write_torch_ext_cpu, write_torch_ext_cuda, write_torch_ext_metal, write_torch_ext_noarch,
14-
write_torch_ext_xpu,
15-
};
12+
use torch::{write_torch_ext, write_torch_ext_noarch};
1613

1714
mod config;
1815
use config::{v3, Backend, Build, BuildCompat};
@@ -176,14 +173,7 @@ fn generate_torch(
176173
let file_set = if build.is_noarch() {
177174
write_torch_ext_noarch(&env, backend, &build, target_dir.clone(), ops_id)?
178175
} else {
179-
match backend {
180-
Backend::Cpu => write_torch_ext_cpu(&env, &build, target_dir.clone(), ops_id)?,
181-
Backend::Cuda | Backend::Rocm => {
182-
write_torch_ext_cuda(&env, backend, &build, target_dir.clone(), ops_id)?
183-
}
184-
Backend::Metal => write_torch_ext_metal(&env, &build, target_dir.clone(), ops_id)?,
185-
Backend::Xpu => write_torch_ext_xpu(&env, &build, target_dir.clone(), ops_id)?,
186-
}
176+
write_torch_ext(&env, backend, &build, target_dir.clone(), ops_id)?
187177
};
188178
file_set.write(&target_dir, force)?;
189179

@@ -382,20 +372,7 @@ fn get_generated_files(
382372
let set = if build.is_noarch() {
383373
write_torch_ext_noarch(env, *backend, build, target_dir.clone(), ops_id.clone())?
384374
} else {
385-
match backend {
386-
Backend::Cpu => {
387-
write_torch_ext_cpu(env, build, target_dir.clone(), ops_id.clone())?
388-
}
389-
Backend::Cuda | Backend::Rocm => {
390-
write_torch_ext_cuda(env, *backend, build, target_dir.clone(), ops_id.clone())?
391-
}
392-
Backend::Metal => {
393-
write_torch_ext_metal(env, build, target_dir.clone(), ops_id.clone())?
394-
}
395-
Backend::Xpu => {
396-
write_torch_ext_xpu(env, build, target_dir.clone(), ops_id.clone())?
397-
}
398-
}
375+
write_torch_ext(env, *backend, build, target_dir.clone(), ops_id.clone())?
399376
};
400377
all_set.extend(set);
401378
}

build2cmake/src/torch/common.rs

Lines changed: 95 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
use std::io::Write;
22
use std::path::PathBuf;
33

4-
use eyre::{Context, Result};
4+
use eyre::{bail, Context, Result};
55
use itertools::Itertools;
66
use minijinja::{context, Environment};
77

8-
use crate::config::{Backend, General, Torch};
8+
use crate::config::{Backend, Build, General, Torch};
99
use crate::metadata::Metadata;
10+
use crate::torch::deps::render_deps;
11+
use crate::torch::kernel::render_kernel_components;
1012
use crate::version::Version;
1113
use crate::FileSet;
1214

@@ -245,3 +247,94 @@ pub fn render_preamble(
245247

246248
Ok(())
247249
}
250+
251+
pub fn write_cmake(
252+
env: &Environment,
253+
backend: Backend,
254+
build: &Build,
255+
torch: &Torch,
256+
name: &str,
257+
ops_name: &str,
258+
file_set: &mut FileSet,
259+
) -> Result<()> {
260+
write_cmake_helpers(file_set);
261+
262+
let cmake_writer = file_set.entry("CMakeLists.txt");
263+
264+
let (cuda_minver, cuda_maxver) = match backend {
265+
Backend::Cuda => (
266+
build.general.cuda.as_ref().and_then(|c| c.minver.as_ref()),
267+
build.general.cuda.as_ref().and_then(|c| c.maxver.as_ref()),
268+
),
269+
_ => (None, None),
270+
};
271+
272+
render_preamble(
273+
env,
274+
name,
275+
cuda_minver,
276+
cuda_maxver,
277+
torch.minver.as_ref(),
278+
torch.maxver.as_ref(),
279+
cmake_writer,
280+
)?;
281+
282+
render_deps(env, backend, build, cmake_writer)?;
283+
284+
render_binding(env, torch, name, cmake_writer)?;
285+
286+
render_kernel_components(env, build, cmake_writer)?;
287+
288+
render_extension(env, name, ops_name, cmake_writer)?;
289+
290+
Ok(())
291+
}
292+
293+
pub fn write_torch_ext(
294+
env: &Environment,
295+
backend: Backend,
296+
build: &Build,
297+
target_dir: PathBuf,
298+
ops_id: Option<String>,
299+
) -> Result<FileSet> {
300+
let torch_ext = match build.torch.as_ref() {
301+
Some(torch_ext) => torch_ext,
302+
None => bail!("Build configuration does not have `torch` section"),
303+
};
304+
305+
let mut file_set = FileSet::default();
306+
307+
let ops_name = crate::torch::ops_identifier::kernel_ops_identifier(
308+
&target_dir,
309+
&build.general.python_name(),
310+
ops_id,
311+
);
312+
313+
write_cmake(
314+
env,
315+
backend,
316+
build,
317+
torch_ext,
318+
&build.general.name,
319+
&ops_name,
320+
&mut file_set,
321+
)?;
322+
323+
write_setup_py(
324+
env,
325+
torch_ext,
326+
&build.general.name,
327+
&ops_name,
328+
&mut file_set,
329+
)?;
330+
331+
write_ops_py(env, &build.general.python_name(), &ops_name, &mut file_set)?;
332+
333+
write_pyproject_toml(env, backend, &build.general, &mut file_set)?;
334+
335+
write_torch_registration_macros(&mut file_set)?;
336+
337+
write_metadata(backend, &build.general, &mut file_set)?;
338+
339+
Ok(file_set)
340+
}

build2cmake/src/torch/cpu.rs

Lines changed: 0 additions & 90 deletions
This file was deleted.

build2cmake/src/torch/cuda.rs

Lines changed: 0 additions & 93 deletions
This file was deleted.

0 commit comments

Comments
 (0)