Skip to content

Commit 6bbe50b

Browse files
committed
build(jax): split cpu and gpu extras
Add explicit JAX CPU and CUDA extras so CPU workflows avoid pulling GPU runtime dependencies while CUDA jobs request jax[cuda12] through the dedicated GPU extra. Authored by OpenClaw (model: gpt-5.5)
1 parent 5bd0889 commit 6bbe50b

4 files changed

Lines changed: 37 additions & 9 deletions

File tree

.github/workflows/test_cc.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,9 @@ jobs:
4444
- run: python -m pip install uv
4545
- name: Install Python dependencies
4646
run: |
47-
source/install/uv_with_retry.sh pip install --system --group pin_tensorflow_cpu --group pin_pytorch_cpu --group pin_jax --torch-backend cpu
47+
source/install/uv_with_retry.sh pip install --system --group pin_tensorflow_cpu --group pin_pytorch_cpu --group pin_jax_cpu --torch-backend cpu
4848
export TENSORFLOW_ROOT=$(python -c 'import importlib.util,pathlib;print(pathlib.Path(importlib.util.find_spec("tensorflow").origin).parent)')
49-
source/install/uv_with_retry.sh pip install --system -e .[cpu,test,lmp,jax] mpi4py mpich
49+
source/install/uv_with_retry.sh pip install --system -e .[cpu,test,lmp,jax-cpu] mpi4py mpich
5050
- name: Convert models
5151
run: source/tests/infer/convert-models.sh
5252
# https://github.com/actions/runner-images/issues/9491

.github/workflows/test_cuda.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,12 @@ jobs:
4343
&& sudo apt-get -y install cuda-12-3 libcudnn8=8.9.5.*-1+cuda12.3
4444
if: false # skip as we use nvidia image
4545
- run: python -m pip install -U uv
46-
- run: source/install/uv_with_retry.sh pip install --system --group pin_tensorflow_gpu --group pin_pytorch_gpu --group pin_jax "jax[cuda12]"
46+
- run: source/install/uv_with_retry.sh pip install --system --group pin_tensorflow_gpu --group pin_pytorch_gpu --group pin_jax_gpu
4747
- run: |
4848
export PYTORCH_ROOT=$(python -c 'import torch;print(torch.__path__[0])')
4949
export TENSORFLOW_ROOT=$(python -c 'import importlib.util,pathlib;print(pathlib.Path(importlib.util.find_spec("tensorflow").origin).parent)')
5050
pip install --find-links "https://www.paddlepaddle.org.cn/packages/nightly/cu126/paddlepaddle-gpu/" --index-url https://pypi.org/simple --trusted-host www.paddlepaddle.org.cn --trusted-host paddlepaddle.org.cn "paddlepaddle-gpu==3.4.0.dev20260310"
51-
source/install/uv_with_retry.sh pip install --system -v -e .[gpu,test,lmp,cu12,torch,jax] mpi4py --reinstall-package deepmd-kit
51+
source/install/uv_with_retry.sh pip install --system -v -e .[gpu,test,lmp,cu12,torch,jax-gpu] mpi4py --reinstall-package deepmd-kit
5252
# See https://github.com/jax-ml/jax/issues/29042
5353
source/install/uv_with_retry.sh pip install --system -U 'nvidia-cublas-cu12>=12.9.0.13'
5454
env:

.github/workflows/test_python.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ jobs:
3131
source/install/uv_with_retry.sh pip install --system openmpi --group pin_tensorflow_cpu --group pin_pytorch_cpu --torch-backend cpu
3232
export TENSORFLOW_ROOT=$(python -c 'import importlib.util,pathlib;print(pathlib.Path(importlib.util.find_spec("tensorflow").origin).parent)')
3333
export PYTORCH_ROOT=$(python -c 'import torch;print(torch.__path__[0])')
34-
source/install/uv_with_retry.sh pip install --system -e .[test,jax] mpi4py --group pin_jax
34+
source/install/uv_with_retry.sh pip install --system -e .[test,jax-cpu] mpi4py --group pin_jax_cpu
3535
source/install/uv_with_retry.sh pip install --system --find-links "https://www.paddlepaddle.org.cn/packages/nightly/cpu/paddlepaddle/" --index-url https://pypi.org/simple --trusted-host www.paddlepaddle.org.cn --trusted-host paddlepaddle.org.cn paddlepaddle==3.4.0.dev20260310
3636
env:
3737
# Please note that uv has some issues with finding

pyproject.toml

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -132,10 +132,35 @@ cu12 = [
132132
"nvidia-cuda-nvcc-cu12",
133133
]
134134
jax = [
135+
# Backward-compatible alias for CPU JAX.
135136
# below is a funny workaround for
136137
# https://github.com/astral-sh/uv/issues/8601
137-
'jax>=0.4.33;python_version>="3.10"',
138-
'jax>=0.4.33;python_version>="3.10"',
138+
'jax[cpu]>=0.4.33;python_version>="3.10"',
139+
'jax[cpu]>=0.4.33;python_version>="3.10"',
140+
'flax>=0.10.0;python_version>="3.10"',
141+
'flax>=0.10.0;python_version>="3.10"',
142+
'orbax-checkpoint;python_version>="3.10"',
143+
'orbax-checkpoint;python_version>="3.10"',
144+
# The pinning of ml_dtypes may conflict with TF
145+
# 'jax-ai-stack;python_version>="3.10"',
146+
]
147+
jax-cpu = [
148+
# below is a funny workaround for
149+
# https://github.com/astral-sh/uv/issues/8601
150+
'jax[cpu]>=0.4.33;python_version>="3.10"',
151+
'jax[cpu]>=0.4.33;python_version>="3.10"',
152+
'flax>=0.10.0;python_version>="3.10"',
153+
'flax>=0.10.0;python_version>="3.10"',
154+
'orbax-checkpoint;python_version>="3.10"',
155+
'orbax-checkpoint;python_version>="3.10"',
156+
# The pinning of ml_dtypes may conflict with TF
157+
# 'jax-ai-stack;python_version>="3.10"',
158+
]
159+
jax-gpu = [
160+
# below is a funny workaround for
161+
# https://github.com/astral-sh/uv/issues/8601
162+
'jax[cuda12]>=0.4.33;python_version>="3.10"',
163+
'jax[cuda12]>=0.4.33;python_version>="3.10"',
139164
'flax>=0.10.0;python_version>="3.10"',
140165
'flax>=0.10.0;python_version>="3.10"',
141166
'orbax-checkpoint;python_version>="3.10"',
@@ -175,8 +200,11 @@ pin_pytorch_cpu = [
175200
pin_pytorch_gpu = [
176201
"torch==2.10.0",
177202
]
178-
pin_jax = [
179-
"jax==0.5.0;python_version>='3.10'",
203+
pin_jax_cpu = [
204+
"jax[cpu]==0.5.0;python_version>='3.10'",
205+
]
206+
pin_jax_gpu = [
207+
"jax[cuda12]==0.5.0;python_version>='3.10'",
180208
]
181209

182210
[tool.setuptools_scm]

0 commit comments

Comments
 (0)