Skip to content

Commit a38f483

Browse files
njzjzpre-commit-ci[bot]Copilot
authored
CI: pin cibuildwheel TF/PT deps to global pinnings (#5071)
This will make the CI more robust, i.e., new TF/PT versions will not break the CI until we bump the version. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Added CPU-optimized build support for PyTorch and TensorFlow. * Platform-specific dependency resolution for Linux x86_64, macOS, and other architectures. * **Chores** * Updated build configuration to use configurable dependency groups and conditional pins for CPU vs GPU builds. * Adjusted CI/build selection logic to better handle CPU-targeted packaging across environments. <sub>✏️ Tip: You can customize this high-level summary in your review settings.</sub> <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Jinzhe Zeng <njzjz@qq.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent f95cb74 commit a38f483

4 files changed

Lines changed: 63 additions & 12 deletions

File tree

backend/find_pytorch.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
22
import importlib
33
import os
4-
import platform
54
import site
65
from functools import (
76
lru_cache,
@@ -30,6 +29,10 @@
3029
Version,
3130
)
3231

32+
from .utils import (
33+
read_dependencies_from_dependency_group,
34+
)
35+
3336

3437
@lru_cache
3538
def find_pytorch() -> tuple[Optional[str], list[str]]:
@@ -108,15 +111,15 @@ def get_pt_requirement(pt_version: str = "") -> dict:
108111
"""
109112
if pt_version is None:
110113
return {"torch": []}
111-
if (
112-
os.environ.get("CIBUILDWHEEL", "0") == "1"
113-
and platform.system() == "Linux"
114-
and platform.machine() == "x86_64"
115-
):
114+
cibw_requirement = []
115+
if os.environ.get("CIBUILDWHEEL", "0") == "1":
116116
cuda_version = os.environ.get("CUDA_VERSION", "12.2")
117117
if cuda_version == "" or cuda_version in SpecifierSet(">=12,<13"):
118118
# CUDA 12.2, cudnn 9
119-
pt_version = "2.8.0"
119+
# or CPU builds
120+
cibw_requirement = read_dependencies_from_dependency_group(
121+
"pin_pytorch_cpu"
122+
)
120123
elif cuda_version in SpecifierSet(">=11,<12"):
121124
# CUDA 11.8, cudnn 8
122125
pt_version = "2.3.1"
@@ -141,6 +144,7 @@ def get_pt_requirement(pt_version: str = "") -> dict:
141144
# https://github.com/pytorch/pytorch/commit/7e0c26d4d80d6602aed95cb680dfc09c9ce533bc
142145
else "torch>=2.1.0",
143146
*mpi_requirement,
147+
*cibw_requirement,
144148
],
145149
}
146150

backend/find_tensorflow.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,10 @@
2626
SpecifierSet,
2727
)
2828

29+
from .utils import (
30+
read_dependencies_from_dependency_group,
31+
)
32+
2933

3034
@lru_cache
3135
def find_tensorflow() -> tuple[Optional[str], list[str]]:
@@ -91,10 +95,9 @@ def find_tensorflow() -> tuple[Optional[str], list[str]]:
9195
cuda_version = os.environ.get("CUDA_VERSION", "12.2")
9296
if cuda_version == "" or cuda_version in SpecifierSet(">=12,<13"):
9397
# CUDA 12.2, cudnn 9
98+
# or CPU builds
9499
requires.extend(
95-
[
96-
"tensorflow-cpu>=2.18.0; platform_machine=='x86_64' and platform_system == 'Linux'",
97-
]
100+
read_dependencies_from_dependency_group("pin_tensorflow_cpu")
98101
)
99102
elif cuda_version in SpecifierSet(">=11,<12"):
100103
# CUDA 11.8, cudnn 8

backend/utils.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
import sys
3+
from pathlib import (
4+
Path,
5+
)
6+
7+
from dependency_groups import (
8+
resolve,
9+
)
10+
11+
if sys.version_info >= (3, 11):
12+
import tomllib
13+
else:
14+
import tomli as tomllib
15+
16+
17+
def read_dependencies_from_dependency_group(group: str) -> tuple[str, ...]:
18+
"""
19+
Reads dependencies from a dependency group.
20+
21+
Parameters
22+
----------
23+
group : str
24+
The name of the dependency group.
25+
26+
Returns
27+
-------
28+
tuple[str, ...]
29+
A tuple of dependencies in the specified group.
30+
"""
31+
with Path("pyproject.toml").open("rb") as f:
32+
pyproject = tomllib.load(f)
33+
34+
groups = pyproject["dependency-groups"]
35+
36+
return resolve(groups, group)

pyproject.toml

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ requires = [
55
"scikit-build-core>=0.5,<0.11,!=0.6.0",
66
"packaging",
77
'tomli >= 1.1.0 ; python_version < "3.11"',
8+
"dependency_groups",
89
]
910
build-backend = "backend.dp_backend"
1011
backend-path = ["."]
@@ -159,13 +160,20 @@ dev = [
159160
"mpich",
160161
]
161162
pin_tensorflow_cpu = [
162-
"tensorflow-cpu~=2.18.0",
163+
# https://github.com/tensorflow/tensorflow/issues/75279
164+
# macos x86 has been deprecated
165+
"tensorflow-cpu~=2.18.0; platform_machine=='x86_64' and platform_system == 'Linux'",
166+
"tensorflow~=2.18.0; (platform_machine!='x86_64' or platform_system != 'Linux') and (platform_machine!='x86_64' or platform_system != 'Darwin')",
167+
"tensorflow; platform_machine=='x86_64' and platform_system == 'Darwin'",
163168
]
164169
pin_tensorflow_gpu = [
165170
"tensorflow~=2.18.0",
166171
]
167172
pin_pytorch_cpu = [
168-
"torch~=2.8.0",
173+
# https://github.com/pytorch/pytorch/issues/114602
174+
# macos x86 has been deprecated
175+
"torch~=2.8.0; platform_machine!='x86_64' or platform_system != 'Darwin'",
176+
"torch; platform_machine=='x86_64' and platform_system == 'Darwin'",
169177
]
170178
pin_pytorch_gpu = [
171179
"torch~=2.7.0",

0 commit comments

Comments
 (0)