Skip to content

Commit fd5dd4f

Browse files
committed
Support extra data files
This was supported by the Nix builder before, but not the Windows builder.
1 parent 13c4af4 commit fd5dd4f

23 files changed

Lines changed: 535 additions & 36 deletions

File tree

.github/workflows/build_kernel.yaml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,17 @@ jobs:
2828
USER: runner
2929
- name: Nix info
3030
run: nix-shell -p nix-info --run "nix-info -m"
31+
3132
- name: Build relu kernel
3233
run: ( cd builder/examples/relu && nix build .\#redistributable.torch29-cxx11-cu126-x86_64-linux )
3334
- name: Copy relu kernel
3435
run: cp -rL builder/examples/relu/result relu-kernel
3536

37+
- name: Build extra-data kernel
38+
run: ( cd builder/examples/extra-data && nix build .\#redistributable.torch29-cxx11-cu126-x86_64-linux )
39+
- name: Copy extra-data kernel
40+
run: cp -rL builder/examples/extra-data/result extra-data
41+
3642
- name: Build relu kernel (CPU)
3743
run: ( cd builder/examples/relu && nix build .\#redistributable.torch29-cxx11-cpu-x86_64-linux )
3844
- name: Copy relu kernel (CPU)
@@ -70,6 +76,7 @@ jobs:
7076
path: |
7177
activation-kernel
7278
cutlass-gemm-kernel
79+
extra-data
7380
relu-kernel
7481
relu-kernel-cpu
7582
relu-backprop-compile-kernel

build2cmake/src/config/mod.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -122,18 +122,18 @@ pub struct Torch {
122122
}
123123

124124
impl Torch {
125-
pub fn data_globs(&self) -> Option<Vec<String>> {
125+
pub fn data_extensions(&self) -> Option<Vec<String>> {
126126
match self.pyext.as_ref() {
127127
Some(exts) => {
128-
let globs = exts
128+
let extensions = exts
129129
.iter()
130130
.filter(|&ext| ext != "py" && ext != "pyi")
131-
.map(|ext| format!("\"**/*.{ext}\""))
131+
.map(|ext| ext.clone())
132132
.collect_vec();
133-
if globs.is_empty() {
133+
if extensions.is_empty() {
134134
None
135135
} else {
136-
Some(globs)
136+
Some(extensions)
137137
}
138138
}
139139

build2cmake/src/templates/build-variants.cmake

Lines changed: 48 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -120,24 +120,25 @@ endfunction()
120120
#
121121
function(add_kernels_install_target TARGET_NAME PACKAGE_NAME BUILD_VARIANT_NAME)
122122
set(oneValueArgs INSTALL_PREFIX)
123-
cmake_parse_arguments(ARG "" "${oneValueArgs}" "" ${ARGN})
123+
set(multiValueArgs DATA_EXTENSIONS)
124+
cmake_parse_arguments(ARG "" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
124125

125126
if(NOT ARG_INSTALL_PREFIX)
126127
set(ARG_INSTALL_PREFIX "${CMAKE_INSTALL_PREFIX}")
127128
endif()
128129

129130
if (${GPU_LANG} STREQUAL "CPU")
130-
set(_BACKEND "cpu")
131+
set(_BACKEND "cpu")
131132
elseif (${GPU_LANG} STREQUAL "CUDA")
132-
set(_BACKEND "cuda")
133+
set(_BACKEND "cuda")
133134
elseif (${GPU_LANG} STREQUAL "HIP")
134-
set(_BACKEND "rocm")
135+
set(_BACKEND "rocm")
135136
elseif (${GPU_LANG} STREQUAL "METAL")
136-
set(_BACKEND "metal")
137+
set(_BACKEND "metal")
137138
elseif (${GPU_LANG} STREQUAL "SYCL")
138-
set(_BACKEND "xpu")
139+
set(_BACKEND "xpu")
139140
else()
140-
message(FATAL_ERROR "Unsupported GPU_LANG: ${GPU_LANG}")
141+
message(FATAL_ERROR "Unsupported GPU_LANG: ${GPU_LANG}")
141142
endif()
142143

143144
# Set the installation directory
@@ -155,8 +156,8 @@ function(add_kernels_install_target TARGET_NAME PACKAGE_NAME BUILD_VARIANT_NAME)
155156
# Glob Python files to install recursively.
156157
file(GLOB_RECURSE PYTHON_FILES RELATIVE "${CMAKE_SOURCE_DIR}/torch-ext/${PACKAGE_NAME}" "${CMAKE_SOURCE_DIR}/torch-ext/${PACKAGE_NAME}/*.py")
157158
foreach(python_file IN LISTS PYTHON_FILES)
158-
get_filename_component(python_file_dir "${python_file}" DIRECTORY)
159-
install(FILES "${CMAKE_SOURCE_DIR}/torch-ext/${PACKAGE_NAME}/${python_file}"
159+
get_filename_component(python_file_dir "${python_file}" DIRECTORY)
160+
install(FILES "${CMAKE_SOURCE_DIR}/torch-ext/${PACKAGE_NAME}/${python_file}"
160161
DESTINATION "${KERNEL_INSTALL_DIR}/${python_file_dir}"
161162
COMPONENT ${TARGET_NAME})
162163
endforeach()
@@ -172,6 +173,17 @@ function(add_kernels_install_target TARGET_NAME PACKAGE_NAME BUILD_VARIANT_NAME)
172173
RENAME "__init__.py"
173174
COMPONENT ${TARGET_NAME})
174175

176+
# Install data files with specified extensions
177+
foreach(ext IN LISTS ARG_DATA_EXTENSIONS)
178+
file(GLOB_RECURSE DATA_FILES RELATIVE "${CMAKE_SOURCE_DIR}/torch-ext/${PACKAGE_NAME}" "${CMAKE_SOURCE_DIR}/torch-ext/${PACKAGE_NAME}/*.${ext}")
179+
foreach(data_file IN LISTS DATA_FILES)
180+
get_filename_component(data_file_dir "${data_file}" DIRECTORY)
181+
install(FILES "${CMAKE_SOURCE_DIR}/torch-ext/${PACKAGE_NAME}/${data_file}"
182+
DESTINATION "${KERNEL_INSTALL_DIR}/${data_file_dir}"
183+
COMPONENT ${TARGET_NAME})
184+
endforeach()
185+
endforeach()
186+
175187
message(STATUS "Added install rules for ${TARGET_NAME} -> ${BUILD_VARIANT_NAME}")
176188
endfunction()
177189

@@ -192,6 +204,9 @@ endfunction()
192204
# BUILD_VARIANT_NAME - Build variant name (e.g., "torch271-cxx11-cu124-x86_64-linux")
193205
#
194206
function(add_local_install_target TARGET_NAME PACKAGE_NAME BUILD_VARIANT_NAME)
207+
set(multiValueArgs DATA_EXTENSIONS)
208+
cmake_parse_arguments(ARG "" "" "${multiValueArgs}" ${ARGN})
209+
195210
# Define your local, folder based, installation directory
196211
set(LOCAL_INSTALL_DIR "${CMAKE_SOURCE_DIR}/build/${BUILD_VARIANT_NAME}")
197212
# Variant directory is where metadata.json should go (for kernels upload discovery)
@@ -206,17 +221,17 @@ function(add_local_install_target TARGET_NAME PACKAGE_NAME BUILD_VARIANT_NAME)
206221
)
207222

208223
if (${GPU_LANG} STREQUAL "CPU")
209-
set(_BACKEND "cpu")
224+
set(_BACKEND "cpu")
210225
elseif (${GPU_LANG} STREQUAL "CUDA")
211-
set(_BACKEND "cuda")
226+
set(_BACKEND "cuda")
212227
elseif (${GPU_LANG} STREQUAL "HIP")
213-
set(_BACKEND "rocm")
228+
set(_BACKEND "rocm")
214229
elseif (${GPU_LANG} STREQUAL "METAL")
215-
set(_BACKEND "metal")
230+
set(_BACKEND "metal")
216231
elseif (${GPU_LANG} STREQUAL "SYCL")
217-
set(_BACKEND "xpu")
232+
set(_BACKEND "xpu")
218233
else()
219-
message(FATAL_ERROR "Unsupported GPU_LANG: ${GPU_LANG}")
234+
message(FATAL_ERROR "Unsupported GPU_LANG: ${GPU_LANG}")
220235
endif()
221236

222237
# Add custom commands to copy files
@@ -242,8 +257,8 @@ function(add_local_install_target TARGET_NAME PACKAGE_NAME BUILD_VARIANT_NAME)
242257

243258
# Copy each Python file preserving directory structure
244259
foreach(python_file IN LISTS PYTHON_FILES)
245-
get_filename_component(python_file_dir "${python_file}" DIRECTORY)
246-
add_custom_command(TARGET local_install POST_BUILD
260+
get_filename_component(python_file_dir "${python_file}" DIRECTORY)
261+
add_custom_command(TARGET local_install POST_BUILD
247262
COMMAND ${CMAKE_COMMAND} -E make_directory
248263
${LOCAL_INSTALL_DIR}/${python_file_dir}
249264
COMMAND ${CMAKE_COMMAND} -E copy_if_different
@@ -253,6 +268,22 @@ function(add_local_install_target TARGET_NAME PACKAGE_NAME BUILD_VARIANT_NAME)
253268
)
254269
endforeach()
255270

271+
# Copy data files with specified extensions
272+
foreach(ext IN LISTS ARG_DATA_EXTENSIONS)
273+
file(GLOB_RECURSE DATA_FILES RELATIVE "${CMAKE_SOURCE_DIR}/torch-ext/${PACKAGE_NAME}" "${CMAKE_SOURCE_DIR}/torch-ext/${PACKAGE_NAME}/*.${ext}")
274+
foreach(data_file IN LISTS DATA_FILES)
275+
get_filename_component(data_file_dir "${data_file}" DIRECTORY)
276+
add_custom_command(TARGET local_install POST_BUILD
277+
COMMAND ${CMAKE_COMMAND} -E make_directory
278+
${LOCAL_INSTALL_DIR}/${data_file_dir}
279+
COMMAND ${CMAKE_COMMAND} -E copy_if_different
280+
${CMAKE_SOURCE_DIR}/torch-ext/${PACKAGE_NAME}/${data_file}
281+
${LOCAL_INSTALL_DIR}/${data_file_dir}/
282+
COMMENT "Copying ${data_file} to ${LOCAL_INSTALL_DIR}/${data_file_dir}"
283+
)
284+
endforeach()
285+
endforeach()
286+
256287
# Create both directories: variant dir for metadata.json, package dir for binaries
257288
file(MAKE_DIRECTORY ${VARIANT_DIR})
258289
file(MAKE_DIRECTORY ${LOCAL_INSTALL_DIR})
Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Include Metal shader compilation utilities if needed
22
if(GPU_LANG STREQUAL "METAL")
3-
include(${CMAKE_CURRENT_LIST_DIR}/cmake/compile-metal.cmake)
3+
include(${CMAKE_CURRENT_LIST_DIR}/cmake/compile-metal.cmake)
44
endif()
55

66
# Define the extension target with unified parameters
@@ -15,23 +15,25 @@ define_gpu_extension_target(
1515
WITH_SOABI)
1616

1717
if(NOT (MSVC OR GPU_LANG STREQUAL "SYCL"))
18-
target_link_options({{ ops_name }} PRIVATE -static-libstdc++)
18+
target_link_options({{ ops_name }} PRIVATE -static-libstdc++)
1919
endif()
2020

2121
if(GPU_LANG STREQUAL "SYCL")
22-
target_link_options({{ ops_name }} PRIVATE ${sycl_link_flags})
23-
target_link_libraries({{ ops_name }} PRIVATE dnnl)
22+
target_link_options({{ ops_name }} PRIVATE ${sycl_link_flags})
23+
target_link_libraries({{ ops_name }} PRIVATE dnnl)
2424
endif()
2525

2626
# Compile Metal shaders if any were found
2727
if(GPU_LANG STREQUAL "METAL")
28-
if(ALL_METAL_SOURCES)
29-
compile_metal_shaders({{ ops_name }} "${ALL_METAL_SOURCES}" "${METAL_INCLUDE_DIRS}")
30-
endif()
28+
if(ALL_METAL_SOURCES)
29+
compile_metal_shaders({{ ops_name }} "${ALL_METAL_SOURCES}" "${METAL_INCLUDE_DIRS}")
30+
endif()
3131
endif()
3232

3333
# Add kernels_install target for huggingface/kernels library layout
34-
add_kernels_install_target({{ ops_name }} "{{ python_name }}" "${BUILD_VARIANT_NAME}")
34+
add_kernels_install_target({{ ops_name }} "{{ python_name }}" "${BUILD_VARIANT_NAME}"
35+
DATA_EXTENSIONS "{{ data_extensions | join(';') }}")
3536

3637
# Add local_install target for local development with get_local_kernel()
37-
add_local_install_target({{ ops_name }} "{{ python_name }}" "${BUILD_VARIANT_NAME}")
38+
add_local_install_target({{ ops_name }} "{{ python_name }}" "${BUILD_VARIANT_NAME}"
39+
DATA_EXTENSIONS "{{ data_extensions | join(';') }}")

build2cmake/src/torch/common.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,9 @@ pub fn write_setup_py(
3232
) -> Result<()> {
3333
let writer = file_set.entry("setup.py");
3434

35-
let data_globs = torch.data_globs().map(|globs| globs.join(", "));
35+
let data_globs = torch
36+
.data_extensions()
37+
.map(|exts| exts.iter().map(|ext| format!("\"**/*.{ext}\"")).join(", "));
3638

3739
env.get_template("setup.py")
3840
.wrap_err("Cannot get setup.py template")?
@@ -225,6 +227,7 @@ pub fn write_cmake_helpers(file_set: &mut FileSet) {
225227
pub fn render_extension(
226228
env: &Environment,
227229
general: &General,
230+
torch: &Torch,
228231
ops_name: &str,
229232
write: &mut impl Write,
230233
) -> Result<()> {
@@ -234,6 +237,7 @@ pub fn render_extension(
234237
context! {
235238
python_name => general.python_name(),
236239
ops_name => ops_name,
240+
data_extensions => torch.data_extensions(),
237241
},
238242
&mut *write,
239243
)
@@ -300,7 +304,7 @@ pub fn write_cmake(
300304

301305
render_kernel_components(env, build, cmake_writer)?;
302306

303-
render_extension(env, &build.general, ops_name, cmake_writer)?;
307+
render_extension(env, &build.general, torch, ops_name, cmake_writer)?;
304308

305309
Ok(())
306310
}

build2cmake/src/torch/noarch.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,11 @@ fn write_pyproject_toml(
6565
let writer = file_set.entry("pyproject.toml");
6666

6767
let name = &general.name;
68-
let data_globs = torch.and_then(|torch| torch.data_globs().map(|globs| globs.join(", ")));
68+
let data_globs = torch.and_then(|torch| {
69+
torch
70+
.data_extensions()
71+
.map(|exts| exts.iter().map(|ext| format!("\"**/*.{ext}\"")).join(", "))
72+
});
6973

7074
// Common python dependencies (no backend-specific ones)
7175
let python_dependencies = itertools::process_results(general.python_depends(), |iter| {
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
[general]
2+
name = "extra-data"
3+
backends = [
4+
"cpu",
5+
"cuda",
6+
"metal",
7+
"rocm",
8+
"xpu",
9+
]
10+
11+
[torch]
12+
src = [
13+
"torch-ext/torch_binding.cpp",
14+
"torch-ext/torch_binding.h",
15+
]
16+
pyext = ["json", "py"]
17+
18+
[kernel.relu]
19+
backend = "cuda"
20+
depends = ["torch"]
21+
src = ["relu_cuda/relu.cu"]
22+
23+
[kernel.relu_metal]
24+
backend = "metal"
25+
src = [
26+
"relu_metal/relu.mm",
27+
"relu_metal/relu.metal",
28+
"relu_metal/common.h",
29+
]
30+
depends = [ "torch" ]
31+
32+
[kernel.relu_rocm]
33+
backend = "rocm"
34+
rocm-archs = [
35+
"gfx906",
36+
"gfx908",
37+
"gfx90a",
38+
"gfx940",
39+
"gfx941",
40+
"gfx942",
41+
"gfx1030",
42+
"gfx1100",
43+
"gfx1101",
44+
]
45+
depends = ["torch"]
46+
src = ["relu_cuda/relu.cu"]
47+
48+
[kernel.relu_xpu]
49+
backend = "xpu"
50+
depends = ["torch"]
51+
src = ["relu_xpu/relu.cpp"]
52+
53+
[kernel.relu_cpu]
54+
backend = "cpu"
55+
depends = ["torch"]
56+
src = ["relu_cpu/relu_cpu.cpp"]
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
{
2+
description = "Flake for ReLU kernel";
3+
4+
inputs = {
5+
kernel-builder.url = "path:../../..";
6+
};
7+
8+
outputs =
9+
{
10+
self,
11+
kernel-builder,
12+
}:
13+
kernel-builder.lib.genKernelFlakeOutputs {
14+
inherit self;
15+
path = ./.;
16+
};
17+
}

0 commit comments

Comments
 (0)