Skip to content

Commit 2d024de

Browse files
committed
build(jax): split pinned cpu and gpu groups
Keep the public jax extra unchanged and only split the pinned dependency groups used by CPU and CUDA CI. CPU uses plain jax, while GPU uses jax[cuda12]. Authored by OpenClaw (model: gpt-5.5)
1 parent 5bd0889 commit 2d024de

4 files changed

Lines changed: 7 additions & 4 deletions

File tree

.github/workflows/test_cc.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ jobs:
4444
- run: python -m pip install uv
4545
- name: Install Python dependencies
4646
run: |
47-
source/install/uv_with_retry.sh pip install --system --group pin_tensorflow_cpu --group pin_pytorch_cpu --group pin_jax --torch-backend cpu
47+
source/install/uv_with_retry.sh pip install --system --group pin_tensorflow_cpu --group pin_pytorch_cpu --group pin_jax_cpu --torch-backend cpu
4848
export TENSORFLOW_ROOT=$(python -c 'import importlib.util,pathlib;print(pathlib.Path(importlib.util.find_spec("tensorflow").origin).parent)')
4949
source/install/uv_with_retry.sh pip install --system -e .[cpu,test,lmp,jax] mpi4py mpich
5050
- name: Convert models

.github/workflows/test_cuda.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ jobs:
4343
&& sudo apt-get -y install cuda-12-3 libcudnn8=8.9.5.*-1+cuda12.3
4444
if: false # skip as we use nvidia image
4545
- run: python -m pip install -U uv
46-
- run: source/install/uv_with_retry.sh pip install --system --group pin_tensorflow_gpu --group pin_pytorch_gpu --group pin_jax "jax[cuda12]"
46+
- run: source/install/uv_with_retry.sh pip install --system --group pin_tensorflow_gpu --group pin_pytorch_gpu --group pin_jax_gpu
4747
- run: |
4848
export PYTORCH_ROOT=$(python -c 'import torch;print(torch.__path__[0])')
4949
export TENSORFLOW_ROOT=$(python -c 'import importlib.util,pathlib;print(pathlib.Path(importlib.util.find_spec("tensorflow").origin).parent)')

.github/workflows/test_python.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ jobs:
3131
source/install/uv_with_retry.sh pip install --system openmpi --group pin_tensorflow_cpu --group pin_pytorch_cpu --torch-backend cpu
3232
export TENSORFLOW_ROOT=$(python -c 'import importlib.util,pathlib;print(pathlib.Path(importlib.util.find_spec("tensorflow").origin).parent)')
3333
export PYTORCH_ROOT=$(python -c 'import torch;print(torch.__path__[0])')
34-
source/install/uv_with_retry.sh pip install --system -e .[test,jax] mpi4py --group pin_jax
34+
source/install/uv_with_retry.sh pip install --system -e .[test,jax] mpi4py --group pin_jax_cpu
3535
source/install/uv_with_retry.sh pip install --system --find-links "https://www.paddlepaddle.org.cn/packages/nightly/cpu/paddlepaddle/" --index-url https://pypi.org/simple --trusted-host www.paddlepaddle.org.cn --trusted-host paddlepaddle.org.cn paddlepaddle==3.4.0.dev20260310
3636
env:
3737
# Please note that uv has some issues with finding

pyproject.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,9 +175,12 @@ pin_pytorch_cpu = [
175175
pin_pytorch_gpu = [
176176
"torch==2.10.0",
177177
]
178-
pin_jax = [
178+
pin_jax_cpu = [
179179
"jax==0.5.0;python_version>='3.10'",
180180
]
181+
pin_jax_gpu = [
182+
"jax[cuda12]==0.5.0;python_version>='3.10'",
183+
]
181184

182185
[tool.setuptools_scm]
183186

0 commit comments

Comments
 (0)