Skip to content

Commit 3ce9687

Browse files
danieldk and drbh authored
build2cmake: always generate kernel components for all backends (#245)
* build2cmake: always generate kernel components for all backends

  This change always generates the kernel components for all backends. This is a necessary step for merging the preambles later.

* Fix kernel render visibility

  Co-authored-by: drbh <david.richard.holtz@gmail.com>

---------

Co-authored-by: drbh <david.richard.holtz@gmail.com>
1 parent b7275c7 commit 3ce9687

12 files changed

Lines changed: 258 additions & 244 deletions

File tree

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1-
cpu_kernel_component(SRC
2-
SOURCES {{ sources }}
3-
{% if includes %}INCLUDES "{{ includes }}"{% endif %}
4-
{% if cxx_flags %}CXX_FLAGS "{{ cxx_flags }}"{% endif %}
5-
)
1+
if(GPU_LANG STREQUAL "CPU")
2+
cpu_kernel_component(SRC
3+
SOURCES {{ sources }}
4+
{% if includes %}INCLUDES "{{ includes }}"{% endif %}
5+
{% if cxx_flags %}CXX_FLAGS "{{ cxx_flags }}"{% endif %}
6+
)
7+
endif()

build2cmake/src/templates/cpu/preamble.cmake

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ if (TORCH_VERSION VERSION_GREATER {{ torch_maxver }})
4242
endif()
4343
{% endif %}
4444

45+
set(GPU_LANG "CPU")
46+
4547
add_compile_definitions(CPU_KERNEL)
4648

4749
# Initialize SRC list for kernel and binding sources
Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1-
metal_kernel_component(SRC
2-
SOURCES {{ sources }}
3-
{% if includes %}INCLUDES "{{ includes }}"{% endif %}
4-
{% if cxx_flags %}CXX_FLAGS "{{ cxx_flags }}"{% endif %}
5-
)
1+
if(GPU_LANG STREQUAL "METAL")
2+
metal_kernel_component(SRC
3+
SOURCES {{ sources }}
4+
{% if includes %}INCLUDES "{{ includes }}"{% endif %}
5+
{% if cxx_flags %}CXX_FLAGS "{{ cxx_flags }}"{% endif %}
6+
)
7+
endif()

build2cmake/src/templates/metal/preamble.cmake

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ if (TORCH_VERSION VERSION_GREATER {{ torch_maxver }})
4242
endif()
4343
{% endif %}
4444

45+
set(GPU_LANG "METAL")
46+
4547
add_compile_definitions(METAL_KERNEL)
4648

4749
# Initialize SRC list for kernel and binding sources
Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1-
xpu_kernel_component(SRC
2-
SOURCES {{ sources }}
3-
{% if includes %}INCLUDES "{{ includes }}"{% endif %}
4-
{% if cxx_flags %}CXX_FLAGS "{{ cxx_flags }}"{% endif %}
5-
{% if sycl_flags %}SYCL_FLAGS "{{ sycl_flags }}"{% endif %}
6-
)
1+
if(GPU_LANG STREQUAL "SYCL")
2+
xpu_kernel_component(SRC
3+
SOURCES {{ sources }}
4+
{% if includes %}INCLUDES "{{ includes }}"{% endif %}
5+
{% if cxx_flags %}CXX_FLAGS "{{ cxx_flags }}"{% endif %}
6+
{% if sycl_flags %}SYCL_FLAGS "{{ sycl_flags }}"{% endif %}
7+
)
8+
endif()

build2cmake/src/torch/common.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,3 +52,15 @@ pub fn write_metadata(backend: Backend, general: &General, file_set: &mut FileSe
5252

5353
Ok(())
5454
}
55+
56+
pub fn prefix_and_join_includes<S>(includes: impl AsRef<[S]>) -> String
57+
where
58+
S: AsRef<str>,
59+
{
60+
includes
61+
.as_ref()
62+
.iter()
63+
.map(|include| format!("${{CMAKE_SOURCE_DIR}}/{}", include.as_ref()))
64+
.collect_vec()
65+
.join(";")
66+
}

build2cmake/src/torch/cpu.rs

Lines changed: 8 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,13 @@ use eyre::{bail, Context, Result};
44
use itertools::Itertools;
55
use minijinja::{context, Environment};
66

7-
use super::{common::write_pyproject_toml, kernel_ops_identifier};
8-
use crate::{
9-
config::{Backend, Build, Kernel, Torch},
10-
fileset::FileSet,
11-
torch::common::write_metadata,
12-
version::Version,
13-
};
7+
use crate::config::{Backend, Build, Torch};
8+
use crate::fileset::FileSet;
9+
use crate::torch::common::write_metadata;
10+
use crate::torch::common::write_pyproject_toml;
11+
use crate::torch::kernel::render_kernel_components;
12+
use crate::torch::kernel_ops_identifier;
13+
use crate::version::Version;
1414

1515
static CMAKE_UTILS: &str = include_str!("../templates/utils.cmake");
1616
static CMAKE_KERNEL: &str = include_str!("../templates/kernel.cmake");
@@ -96,13 +96,7 @@ fn write_cmake(
9696

9797
render_binding(env, torch, name, cmake_writer)?;
9898

99-
for (kernel_name, kernel) in build
100-
.kernels
101-
.iter()
102-
.filter(|(_, kernel)| matches!(kernel, Kernel::Cpu { .. }))
103-
{
104-
render_kernel(env, kernel_name, kernel, cmake_writer)?;
105-
}
99+
render_kernel_components(env, build, cmake_writer)?;
106100

107101
render_extension(env, name, ops_name, cmake_writer)?;
108102

@@ -154,38 +148,6 @@ pub fn render_extension(
154148
Ok(())
155149
}
156150

157-
pub fn render_kernel(
158-
env: &Environment,
159-
kernel_name: &str,
160-
kernel: &Kernel,
161-
write: &mut impl Write,
162-
) -> Result<()> {
163-
// Easier to do in Rust than Jinja.
164-
let sources = kernel
165-
.src()
166-
.iter()
167-
.map(|src| format!("\"{src}\""))
168-
.collect_vec()
169-
.join("\n");
170-
171-
env.get_template("cpu/kernel.cmake")
172-
.wrap_err("Cannot get kernel template")?
173-
.render_to_write(
174-
context! {
175-
cxx_flags => kernel.cxx_flags().map(|flags| flags.join(";")),
176-
includes => kernel.include().map(prefix_and_join_includes),
177-
kernel_name => kernel_name,
178-
sources => sources,
179-
},
180-
&mut *write,
181-
)
182-
.wrap_err("Cannot render kernel template")?;
183-
184-
write.write_all(b"\n")?;
185-
186-
Ok(())
187-
}
188-
189151
fn render_preamble(
190152
env: &Environment,
191153
name: &str,

build2cmake/src/torch/cuda.rs

Lines changed: 6 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,14 @@ use std::io::Write;
44
use std::path::PathBuf;
55

66
use eyre::{bail, Context, Result};
7-
use itertools::Itertools;
87
use minijinja::{context, Environment};
98

10-
use super::common::write_pyproject_toml;
11-
use super::kernel_ops_identifier;
12-
use crate::config::{Backend, Build, Dependency, Kernel, Torch};
9+
use crate::config::{Backend, Build, Dependency, Torch};
10+
use crate::torch::common::prefix_and_join_includes;
1311
use crate::torch::common::write_metadata;
12+
use crate::torch::common::write_pyproject_toml;
13+
use crate::torch::kernel::render_kernel_components;
14+
use crate::torch::kernel_ops_identifier;
1415
use crate::version::Version;
1516
use crate::FileSet;
1617

@@ -181,13 +182,7 @@ fn write_cmake(
181182

182183
render_binding(env, torch, name, cmake_writer)?;
183184

184-
for (kernel_name, kernel) in build
185-
.kernels
186-
.iter()
187-
.filter(|(_, kernel)| kernel.backend() == backend)
188-
{
189-
render_kernel(env, kernel_name, kernel, cmake_writer)?;
190-
}
185+
render_kernel_components(env, build, cmake_writer)?;
191186

192187
render_extension(env, name, ops_name, cmake_writer)?;
193188

@@ -312,71 +307,6 @@ fn render_deps(
312307
Ok(())
313308
}
314309

315-
pub fn render_kernel(
316-
env: &Environment,
317-
kernel_name: &str,
318-
kernel: &Kernel,
319-
write: &mut impl Write,
320-
) -> Result<()> {
321-
// Easier to do in Rust than Jinja.
322-
let sources = kernel
323-
.src()
324-
.iter()
325-
.map(|src| format!("\"{src}\""))
326-
.collect_vec()
327-
.join("\n");
328-
329-
let (cuda_capabilities, rocm_archs, cuda_flags, hip_flags, cuda_minver) = match kernel {
330-
Kernel::Cuda {
331-
cuda_capabilities,
332-
cuda_flags,
333-
cuda_minver,
334-
..
335-
} => (
336-
cuda_capabilities.as_deref(),
337-
None,
338-
cuda_flags.as_deref(),
339-
None,
340-
cuda_minver.as_ref(),
341-
),
342-
Kernel::Rocm {
343-
rocm_archs,
344-
hip_flags,
345-
..
346-
} => (
347-
None,
348-
rocm_archs.as_deref(),
349-
None,
350-
hip_flags.as_deref(),
351-
None,
352-
),
353-
_ => unreachable!("Unsupported kernel type for CUDA rendering"),
354-
};
355-
356-
env.get_template("cuda/kernel.cmake")
357-
.wrap_err("Cannot get kernel template")?
358-
.render_to_write(
359-
context! {
360-
cuda_capabilities => cuda_capabilities,
361-
cuda_flags => cuda_flags.map(|flags| flags.join(";")),
362-
cuda_minver => cuda_minver.map(ToString::to_string),
363-
cxx_flags => kernel.cxx_flags().map(|flags| flags.join(";")),
364-
rocm_archs => rocm_archs,
365-
hip_flags => hip_flags.map(|flags| flags.join(";")),
366-
includes => kernel.include().map(prefix_and_join_includes),
367-
kernel_name => kernel_name,
368-
supports_hipify => matches!(kernel, Kernel::Rocm{ .. }),
369-
sources => sources,
370-
},
371-
&mut *write,
372-
)
373-
.wrap_err("Cannot render kernel template")?;
374-
375-
write.write_all(b"\n")?;
376-
377-
Ok(())
378-
}
379-
380310
pub fn render_extension(
381311
env: &Environment,
382312
name: &str,
@@ -428,15 +358,3 @@ pub fn render_preamble(
428358

429359
Ok(())
430360
}
431-
432-
fn prefix_and_join_includes<S>(includes: impl AsRef<[S]>) -> String
433-
where
434-
S: AsRef<str>,
435-
{
436-
includes
437-
.as_ref()
438-
.iter()
439-
.map(|include| format!("${{CMAKE_SOURCE_DIR}}/{}", include.as_ref()))
440-
.collect_vec()
441-
.join(";")
442-
}

0 commit comments

Comments
 (0)