Skip to content
This repository was archived by the owner on Apr 6, 2026. It is now read-only.

Commit 0918eaa

Browse files
shadeMedanieldk
andauthored
Add support for lower- and upper-bounds on Torch versions in build.toml (#320)
Bounds are specified via the `minver` and `maxver` keys in the `[torch]` section. They are applied to all kernels defined in the config file. --------- Co-authored-by: Daniël de Kok <me@danieldk.eu>
1 parent 5f9c8af commit 0918eaa

22 files changed

Lines changed: 299 additions & 10 deletions

File tree

.github/workflows/nix_fmt.yaml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
name: "Check Nix formatting"
1+
name: "Nix checks"
22
on:
33
push:
44
branches: [main]
@@ -9,7 +9,7 @@ on:
99

1010
jobs:
1111
build:
12-
name: Check Nix formatting
12+
name: Nix checks
1313
runs-on: ubuntu-latest
1414
steps:
1515
- uses: actions/checkout@v4
@@ -18,3 +18,5 @@ jobs:
1818
nix_path: nixpkgs=channel:nixos-unstable
1919
- name: Check formatting
2020
run: nix fmt -- --ci
21+
- name: Nix checks
22+
run: nix build .\#checks.x86_64-linux.default

build2cmake/src/config/v2.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,8 @@ impl Display for PythonDependency {
9292
#[serde(deny_unknown_fields)]
9393
pub struct Torch {
9494
pub include: Option<Vec<String>>,
95+
pub minver: Option<Version>,
96+
pub maxver: Option<Version>,
9597
pub pyext: Option<Vec<String>>,
9698

9799
#[serde(default)]
@@ -352,6 +354,8 @@ impl From<v1::Torch> for Torch {
352354
fn from(torch: v1::Torch) -> Self {
353355
Self {
354356
include: torch.include,
357+
minver: None,
358+
maxver: None,
355359
pyext: torch.pyext,
356360
src: torch.src,
357361
}

build2cmake/src/templates/cpu/preamble.cmake

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,4 +25,20 @@ append_cmake_prefix_path("torch" "torch.utils.cmake_prefix_path")
2525

2626
find_package(Torch REQUIRED)
2727

28+
run_python(TORCH_VERSION "import torch; print(torch.__version__.split('+')[0])" "Failed to get Torch version")
29+
30+
{% if torch_minver %}
31+
if (TORCH_VERSION VERSION_LESS {{ torch_minver }})
32+
message(FATAL_ERROR "Torch version ${TORCH_VERSION} is too old. "
33+
"Minimum required version is {{ torch_minver }}.")
34+
endif()
35+
{% endif %}
36+
37+
{% if torch_maxver %}
38+
if (TORCH_VERSION VERSION_GREATER {{ torch_maxver }})
39+
message(FATAL_ERROR "Torch version ${TORCH_VERSION} is too new. "
40+
"Maximum supported version is {{ torch_maxver }}.")
41+
endif()
42+
{% endif %}
43+
2844
add_compile_definitions(CPU_KERNEL)

build2cmake/src/templates/cuda/preamble.cmake

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,22 @@ append_cmake_prefix_path("torch" "torch.utils.cmake_prefix_path")
2929

3030
find_package(Torch REQUIRED)
3131

32+
run_python(TORCH_VERSION "import torch; print(torch.__version__.split('+')[0])" "Failed to get Torch version")
33+
34+
{% if torch_minver %}
35+
if (TORCH_VERSION VERSION_LESS {{ torch_minver }})
36+
message(FATAL_ERROR "Torch version ${TORCH_VERSION} is too old. "
37+
"Minimum required version is {{ torch_minver }}.")
38+
endif()
39+
{% endif %}
40+
41+
{% if torch_maxver %}
42+
if (TORCH_VERSION VERSION_GREATER {{ torch_maxver }})
43+
message(FATAL_ERROR "Torch version ${TORCH_VERSION} is too new. "
44+
"Maximum supported version is {{ torch_maxver }}.")
45+
endif()
46+
{% endif %}
47+
3248
if (NOT TARGET_DEVICE STREQUAL "cuda" AND
3349
NOT TARGET_DEVICE STREQUAL "rocm")
3450
return()

build2cmake/src/templates/metal/preamble.cmake

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,22 @@ append_cmake_prefix_path("torch" "torch.utils.cmake_prefix_path")
2525

2626
find_package(Torch REQUIRED)
2727

28+
run_python(TORCH_VERSION "import torch; print(torch.__version__.split('+')[0])" "Failed to get Torch version")
29+
30+
{% if torch_minver %}
31+
if (TORCH_VERSION VERSION_LESS {{ torch_minver }})
32+
message(FATAL_ERROR "Torch version ${TORCH_VERSION} is too old. "
33+
"Minimum required version is {{ torch_minver }}.")
34+
endif()
35+
{% endif %}
36+
37+
{% if torch_maxver %}
38+
if (TORCH_VERSION VERSION_GREATER {{ torch_maxver }})
39+
message(FATAL_ERROR "Torch version ${TORCH_VERSION} is too new. "
40+
"Maximum supported version is {{ torch_maxver }}.")
41+
endif()
42+
{% endif %}
43+
2844
add_compile_definitions(METAL_KERNEL)
2945

3046
# Initialize list for Metal shader sources

build2cmake/src/templates/xpu/preamble.cmake

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,23 @@ find_package(Torch REQUIRED)
4343

4444
# Intel XPU backend detection and setup
4545
if(NOT TORCH_VERSION)
46-
run_python(TORCH_VERSION "import torch; print(torch.__version__)" "Failed to get Torch version")
46+
run_python(TORCH_VERSION "import torch; print(torch.__version__.split('+')[0])" "Failed to get Torch version")
4747
endif()
4848

49+
{% if torch_minver %}
50+
if (TORCH_VERSION VERSION_LESS {{ torch_minver }})
51+
message(FATAL_ERROR "Torch version ${TORCH_VERSION} is too old. "
52+
"Minimum required version is {{ torch_minver }}.")
53+
endif()
54+
{% endif %}
55+
56+
{% if torch_maxver %}
57+
if (TORCH_VERSION VERSION_GREATER {{ torch_maxver }})
58+
message(FATAL_ERROR "Torch version ${TORCH_VERSION} is too new. "
59+
"Maximum supported version is {{ torch_maxver }}.")
60+
endif()
61+
{% endif %}
62+
4963
# Check for Intel XPU support in PyTorch
5064
run_python(XPU_AVAILABLE
5165
"import torch; print('true' if hasattr(torch, 'xpu') else 'false')"

build2cmake/src/torch/cpu.rs

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ use super::{common::write_pyproject_toml, kernel_ops_identifier};
88
use crate::{
99
config::{Build, Kernel, Torch},
1010
fileset::FileSet,
11+
version::Version,
1112
};
1213

1314
static CMAKE_UTILS: &str = include_str!("../templates/utils.cmake");
@@ -71,7 +72,13 @@ fn write_cmake(
7172

7273
let cmake_writer = file_set.entry("CMakeLists.txt");
7374

74-
render_preamble(env, name, cmake_writer)?;
75+
render_preamble(
76+
env,
77+
name,
78+
torch.minver.as_ref(),
79+
torch.maxver.as_ref(),
80+
cmake_writer,
81+
)?;
7582

7683
// Add deps once we have any non-CUDA deps.
7784
// render_deps(env, build, cmake_writer)?;
@@ -168,12 +175,20 @@ pub fn render_kernel(
168175
Ok(())
169176
}
170177

171-
fn render_preamble(env: &Environment, name: &str, write: &mut impl Write) -> Result<()> {
178+
fn render_preamble(
179+
env: &Environment,
180+
name: &str,
181+
torch_minver: Option<&Version>,
182+
torch_maxver: Option<&Version>,
183+
write: &mut impl Write,
184+
) -> Result<()> {
172185
env.get_template("cpu/preamble.cmake")
173186
.wrap_err("Cannot get CMake prelude template")?
174187
.render_to_write(
175188
context! {
176189
name => name,
190+
torch_minver => torch_minver.map(|v| v.to_string()),
191+
torch_maxver => torch_maxver.map(|v| v.to_string()),
177192
},
178193
&mut *write,
179194
)

build2cmake/src/torch/cuda.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,8 @@ fn write_cmake(
168168
name,
169169
build.general.cuda_minver.as_ref(),
170170
build.general.cuda_maxver.as_ref(),
171+
torch.minver.as_ref(),
172+
torch.maxver.as_ref(),
171173
cmake_writer,
172174
)?;
173175

@@ -390,6 +392,8 @@ pub fn render_preamble(
390392
name: &str,
391393
cuda_minver: Option<&Version>,
392394
cuda_maxver: Option<&Version>,
395+
torch_minver: Option<&Version>,
396+
torch_maxver: Option<&Version>,
393397
write: &mut impl Write,
394398
) -> Result<()> {
395399
env.get_template("cuda/preamble.cmake")
@@ -399,6 +403,8 @@ pub fn render_preamble(
399403
name => name,
400404
cuda_minver => cuda_minver.map(|v| v.to_string()),
401405
cuda_maxver => cuda_maxver.map(|v| v.to_string()),
406+
torch_minver => torch_minver.map(|v| v.to_string()),
407+
torch_maxver => torch_maxver.map(|v| v.to_string()),
402408
cuda_supported_archs => cuda_supported_archs(),
403409
platform => env::consts::OS
404410
},

build2cmake/src/torch/metal.rs

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ use super::{common::write_pyproject_toml, kernel_ops_identifier};
88
use crate::{
99
config::{Build, Kernel, Torch},
1010
fileset::FileSet,
11+
version::Version,
1112
};
1213

1314
static CMAKE_UTILS: &str = include_str!("../templates/utils.cmake");
@@ -87,7 +88,13 @@ fn write_cmake(
8788

8889
let cmake_writer = file_set.entry("CMakeLists.txt");
8990

90-
render_preamble(env, name, cmake_writer)?;
91+
render_preamble(
92+
env,
93+
name,
94+
torch.minver.as_ref(),
95+
torch.maxver.as_ref(),
96+
cmake_writer,
97+
)?;
9198

9299
// Add deps once we have any non-CUDA deps.
93100
// render_deps(env, build, cmake_writer)?;
@@ -184,12 +191,20 @@ pub fn render_kernel(
184191
Ok(())
185192
}
186193

187-
fn render_preamble(env: &Environment, name: &str, write: &mut impl Write) -> Result<()> {
194+
fn render_preamble(
195+
env: &Environment,
196+
name: &str,
197+
torch_minver: Option<&Version>,
198+
torch_maxver: Option<&Version>,
199+
write: &mut impl Write,
200+
) -> Result<()> {
188201
env.get_template("metal/preamble.cmake")
189202
.wrap_err("Cannot get CMake prelude template")?
190203
.render_to_write(
191204
context! {
192205
name => name,
206+
torch_minver => torch_minver.map(|v| v.to_string()),
207+
torch_maxver => torch_maxver.map(|v| v.to_string()),
193208
},
194209
&mut *write,
195210
)

build2cmake/src/torch/xpu.rs

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ use minijinja::{context, Environment};
99
use super::common::write_pyproject_toml;
1010
use super::kernel_ops_identifier;
1111
use crate::config::{Build, Dependency, Kernel, Torch};
12+
use crate::version::Version;
1213
use crate::FileSet;
1314

1415
static CMAKE_UTILS: &str = include_str!("../templates/utils.cmake");
@@ -135,7 +136,13 @@ fn write_cmake(
135136

136137
let cmake_writer = file_set.entry("CMakeLists.txt");
137138

138-
render_preamble(env, name, cmake_writer)?;
139+
render_preamble(
140+
env,
141+
name,
142+
torch.minver.as_ref(),
143+
torch.maxver.as_ref(),
144+
cmake_writer,
145+
)?;
139146

140147
render_deps(env, build, cmake_writer)?;
141148

@@ -263,12 +270,20 @@ pub fn render_extension(
263270
Ok(())
264271
}
265272

266-
pub fn render_preamble(env: &Environment, name: &str, write: &mut impl Write) -> Result<()> {
273+
pub fn render_preamble(
274+
env: &Environment,
275+
name: &str,
276+
torch_minver: Option<&Version>,
277+
torch_maxver: Option<&Version>,
278+
write: &mut impl Write,
279+
) -> Result<()> {
267280
env.get_template("xpu/preamble.cmake")
268281
.wrap_err("Cannot get CMake prelude template")?
269282
.render_to_write(
270283
context! {
271284
name => name,
285+
torch_minver => torch_minver.map(|v| v.to_string()),
286+
torch_maxver => torch_maxver.map(|v| v.to_string()),
272287
},
273288
&mut *write,
274289
)

0 commit comments

Comments
 (0)