Skip to content

Commit eab3419

Browse files
authored
build(jax): split pinned cpu and gpu groups (#5436)
Problem - CPU and CUDA CI need different pinned JAX dependency groups. - The public `jax` extra should stay unchanged. Change - Keep the existing `jax` optional dependency as plain `jax`. - Split the pinned groups into `pin_jax_cpu` and `pin_jax_gpu`. - Use `pin_jax_cpu` in CPU jobs and `pin_jax_gpu` in the CUDA job. Validation - `git diff --check` - `uv pip compile pyproject.toml --group pin_jax_cpu --python-version 3.10` - `uv pip compile pyproject.toml --group pin_jax_gpu --python-version 3.10` Authored by OpenClaw (model: gpt-5.5) <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Chores** * Split JAX dependency pins into separate CPU and GPU variants. * Updated CI workflows to install the appropriate JAX package group for CPU or GPU runs. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
1 parent 78bffde commit eab3419

4 files changed

Lines changed: 7 additions & 4 deletions

File tree

.github/workflows/test_cc.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ jobs:
4444
- run: python -m pip install uv
4545
- name: Install Python dependencies
4646
run: |
47-
source/install/uv_with_retry.sh pip install --system --group pin_tensorflow_cpu --group pin_pytorch_cpu --group pin_jax --torch-backend cpu
47+
source/install/uv_with_retry.sh pip install --system --group pin_tensorflow_cpu --group pin_pytorch_cpu --group pin_jax_cpu --torch-backend cpu
4848
export TENSORFLOW_ROOT=$(python -c 'import importlib.util,pathlib;print(pathlib.Path(importlib.util.find_spec("tensorflow").origin).parent)')
4949
source/install/uv_with_retry.sh pip install --system -e .[cpu,test,lmp,jax] mpi4py mpich
5050
- name: Convert models

.github/workflows/test_cuda.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ jobs:
4343
&& sudo apt-get -y install cuda-12-3 libcudnn8=8.9.5.*-1+cuda12.3
4444
if: false # skip as we use nvidia image
4545
- run: python -m pip install -U uv
46-
- run: source/install/uv_with_retry.sh pip install --system --group pin_tensorflow_gpu --group pin_pytorch_gpu --group pin_jax "jax[cuda12]"
46+
- run: source/install/uv_with_retry.sh pip install --system --group pin_tensorflow_gpu --group pin_pytorch_gpu --group pin_jax_gpu
4747
- run: |
4848
export PYTORCH_ROOT=$(python -c 'import torch;print(torch.__path__[0])')
4949
export TENSORFLOW_ROOT=$(python -c 'import importlib.util,pathlib;print(pathlib.Path(importlib.util.find_spec("tensorflow").origin).parent)')

.github/workflows/test_python.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ jobs:
3131
source/install/uv_with_retry.sh pip install --system openmpi --group pin_tensorflow_cpu --group pin_pytorch_cpu --torch-backend cpu
3232
export TENSORFLOW_ROOT=$(python -c 'import importlib.util,pathlib;print(pathlib.Path(importlib.util.find_spec("tensorflow").origin).parent)')
3333
export PYTORCH_ROOT=$(python -c 'import torch;print(torch.__path__[0])')
34-
source/install/uv_with_retry.sh pip install --system -e .[test,jax] mpi4py --group pin_jax
34+
source/install/uv_with_retry.sh pip install --system -e .[test,jax] mpi4py --group pin_jax_cpu
3535
source/install/uv_with_retry.sh pip install --system --find-links "https://www.paddlepaddle.org.cn/packages/nightly/cpu/paddlepaddle/" --index-url https://pypi.org/simple --trusted-host www.paddlepaddle.org.cn --trusted-host paddlepaddle.org.cn paddlepaddle==3.4.0.dev20260310
3636
env:
3737
# Please note that uv has some issues with finding

pyproject.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,9 +175,12 @@ pin_pytorch_cpu = [
175175
pin_pytorch_gpu = [
176176
"torch==2.10.0",
177177
]
178-
pin_jax = [
178+
pin_jax_cpu = [
179179
"jax==0.5.0;python_version>='3.10'",
180180
]
181+
pin_jax_gpu = [
182+
"jax[cuda12]==0.5.0;python_version>='3.10'",
183+
]
181184

182185
[tool.setuptools_scm]
183186

0 commit comments

Comments
 (0)