Skip to content

Commit 8c072f5

Browse files
committed
Update JAX to 0.10.0 for pre-training
1 parent 348355c commit 8c072f5

5 files changed

Lines changed: 108 additions & 129 deletions

File tree

docs/development/update_dependencies.md

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,10 +78,10 @@ Next, run `generate_requirements.sh` to generate the new requirements files. Thi
7878
script wraps the `seed-env` CLI and handles exporting the lock, and applying any
7979
overrides. You will need to do this separately for the TPU and GPU environments.
8080

81-
> **Note:** The current `src/dependencies/requirements/generated_requirements/` in the repository were generated using JAX build commit hash: [e0d2967b50abbefd651d563dbcd7afbcb963d08c](https://github.com/jax-ml/jax/commit/e0d2967b50abbefd651d563dbcd7afbcb963d08c).
82-
8381
### TPU Pre-Training
8482

83+
> **Note:** The current `src/dependencies/requirements/generated_requirements/tpu-requirements.txt` in the repository was generated using JAX build commit hash: efd6cf797ee9c4f29c6c6d5e91ae4432209063be. When regenerating the requirements, either use the same commit hash or update this hash if you use a different one.
84+
8585
If you have made changes to TPU pre-training dependencies in `src/dependencies/requirements/base_requirements/tpu-requirements.txt`, you need to regenerate the pinned pre-training requirements in `generated_requirements/` directory. Run the following command, replacing `<jax-build-commit-hash>` with the hash you copied in the previous step:
8686

8787
```bash
@@ -98,6 +98,8 @@ mv generated_artifacts/python3_12/tpu-requirements.txt \
9898

9999
### TPU Post-Training
100100

101+
> **Note:** The current `src/dependencies/requirements/generated_requirements/tpu-post-train-requirements.txt` in the repository was generated using JAX build commit hash: e0d2967b50abbefd651d563dbcd7afbcb963d08c. When regenerating the requirements, either use the same commit hash or update this hash if you use a different one.
102+
101103
If you have made changes to the post-training dependencies in `src/dependencies/requirements/base_requirements/tpu-post-train-requirements.txt`, you need to regenerate the pinned post-training requirements in `generated_requirements/` directory. Run the following command, replacing `<jax-build-commit-hash>` with the hash you copied in the previous step:
102104

103105
```bash
@@ -115,6 +117,8 @@ mv generated_artifacts/python3_12/tpu-post-train-requirements.txt \
115117

116118
### GPU Pre-Training
117119

120+
> **Note:** The current `src/dependencies/requirements/generated_requirements/cuda12-requirements.txt` in the repository was generated using JAX build commit hash: efd6cf797ee9c4f29c6c6d5e91ae4432209063be. When regenerating the requirements, either use the same commit hash or update this hash if you use a different one.
121+
118122
If you have made changes to the GPU pre-training dependencies in `src/dependencies/requirements/base_requirements/cuda12-requirements.txt`, you need to regenerate the pinned pre-training requirements in `generated_requirements/` directory. Run the following command, replacing `<jax-build-commit-hash>` with the hash you copied in the previous step:
119123

120124
```bash
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
1-
libtpu>=0.0.38
1+
datasets>=4.8.5
2+
fsspec>=2023.1.0,<=2026.2.0

src/dependencies/requirements/generated_requirements/cuda12-requirements.txt

Lines changed: 53 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,14 @@ aqtp>=0.9.0
1414
array-record>=0.8.3
1515
astroid>=4.0.4
1616
astunparse>=1.6.3
17-
attrs>=25.4.0
17+
attrs>=26.1.0
1818
auditwheel>=6.6.0
1919
black>=25.12.0
20-
build>=1.4.0
20+
build>=1.4.3
2121
certifi>=2026.2.25
2222
cffi>=2.0.0 ; platform_python_implementation != 'PyPy'
2323
cfgv>=3.5.0
24-
charset-normalizer>=3.4.6
24+
charset-normalizer>=3.4.7
2525
chex>=0.1.91
2626
click>=8.3.3
2727
cloud-accelerator-diagnostics>=0.1.1
@@ -30,13 +30,11 @@ cloudpickle>=3.1.2
3030
clu>=0.0.12
3131
colorama>=0.4.6
3232
contourpy>=1.3.3
33-
cryptography>=47.0.0
33+
cryptography>=48.0.0
3434
cycler>=0.12.1
35-
dataclasses-json>=0.6.7
36-
datasets>=4.8.5
35+
datasets>=2.14.4
3736
decorator>=5.2.1
38-
deprecated>=1.3.1
39-
dill>=0.4.1
37+
dill>=0.3.7
4038
distlib>=0.4.0
4139
distro>=1.9.0
4240
dm-tree>=0.1.10
@@ -48,72 +46,71 @@ einshape>=1.0
4846
etils>=1.14.0
4947
execnet>=2.1.2
5048
fastapi>=0.136.1
51-
filelock>=3.20.3
49+
filelock>=3.28.0
5250
flatbuffers>=25.12.19
53-
flax>=0.12.6
51+
flax>=0.12.7
5452
fonttools>=4.62.1
5553
frozenlist>=1.8.0
56-
fsspec>=2026.2.0
54+
fsspec>=2026.3.0
5755
gast>=0.7.0
58-
gcsfs>=2026.2.0
56+
gcsfs>=2026.5.0
5957
google-api-core>=2.30.3
60-
google-api-python-client>=2.194.0
61-
google-auth>=2.49.2
62-
google-auth-httplib2>=0.3.1
63-
google-auth-oauthlib>=1.3.1
64-
google-cloud-aiplatform>=1.148.1
58+
google-api-python-client>=2.196.0
59+
google-auth>=2.52.0
60+
google-auth-httplib2>=0.4.0
61+
google-auth-oauthlib>=1.4.0
62+
google-cloud-aiplatform>=1.151.0
6563
google-cloud-appengine-logging>=1.9.0
6664
google-cloud-audit-log>=0.5.0
6765
google-cloud-bigquery>=3.41.0
68-
google-cloud-core>=2.5.1
66+
google-cloud-core>=2.6.0
6967
google-cloud-logging>=3.15.0
7068
google-cloud-mldiagnostics>=1.0.2
7169
google-cloud-monitoring>=2.30.0
7270
google-cloud-resource-manager>=1.17.0
7371
google-cloud-storage>=3.10.1
7472
google-cloud-storage-control>=1.11.0
7573
google-crc32c>=1.8.0
76-
google-genai>=1.73.1
74+
google-genai>=1.75.0
7775
google-pasta>=0.2.0
78-
google-resumable-media>=2.8.2
79-
googleapis-common-protos>=1.74.0
76+
google-resumable-media>=2.9.0
77+
googleapis-common-protos>=1.75.0
8078
grain>=0.2.16
8179
grpc-google-iam-v1>=0.14.4
82-
grpcio>=1.78.0
83-
grpcio-status>=1.78.0
80+
grpcio>=1.80.0
81+
grpcio-status>=1.80.0
8482
gviz-api>=1.10.0
8583
h11>=0.16.0
8684
h5py>=3.14.0
87-
hf-xet>=1.4.3 ; platform_machine == 'AMD64' or platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'
85+
hf-xet>=1.5.0 ; platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'
8886
httpcore>=1.0.9
8987
httplib2>=0.31.2
9088
httpx>=0.28.1
91-
huggingface-hub>=1.12.0
89+
huggingface-hub>=0.36.2
9290
humanize>=4.15.0
9391
hypothesis>=6.142.1
9492
identify>=2.6.19
9593
idna>=3.11
9694
immutabledict>=4.3.1
9795
importlab>=0.8.1
98-
importlib-metadata>=9.0.0
96+
importlib-metadata>=8.7.1
9997
iniconfig>=2.3.0
10098
isort>=8.0.1
101-
jax>=0.9.2
102-
jax-cuda12-pjrt>=0.9.2 ; sys_platform == 'linux'
103-
jax-cuda12-plugin>=0.9.2 ; sys_platform == 'linux'
104-
jaxlib>=0.9.2
99+
jax>=0.10.0
100+
jax-cuda12-pjrt>=0.10.0 ; sys_platform == 'linux'
101+
jax-cuda12-plugin>=0.10.0 ; sys_platform == 'linux'
102+
jaxlib>=0.10.0
105103
jaxtyping>=0.3.9
106104
jinja2>=3.1.6
107105
jsonlines>=4.0.0
108-
keras>=3.13.2
106+
keras>=3.14.0
109107
kiwisolver>=1.5.0
110108
latex2sympy2-extended>=1.11.0
111109
libclang>=18.1.1
112110
libcst>=1.8.6
113111
markdown>=3.10.2
114112
markdown-it-py>=4.0.0
115113
markupsafe>=3.0.3
116-
marshmallow>=3.26.2
117114
math-verify>=0.9.0
118115
matplotlib>=3.10.8
119116
mccabe>=0.7.0
@@ -125,7 +122,7 @@ mpmath>=1.3.0
125122
msgpack>=1.1.2
126123
msgspec>=0.21.1
127124
multidict>=6.7.1
128-
multiprocess>=0.70.19
125+
multiprocess>=0.70.15
129126
mypy-extensions>=1.1.0
130127
namex>=0.1.0
131128
nest-asyncio>=1.6.0 ; sys_platform == 'win32'
@@ -134,77 +131,76 @@ ninja>=1.13.0
134131
nodeenv>=1.10.0
135132
numpy>=2.0.2
136133
numpy-typing-compat>=20251206.2.0
137-
nvidia-cublas-cu12>=12.9.1.4 ; sys_platform == 'linux'
138-
nvidia-cuda-cccl>=13.2.27
134+
nvidia-cublas-cu12>=12.9.2.10 ; sys_platform == 'linux'
135+
nvidia-cuda-cccl>=13.2.75
139136
nvidia-cuda-cccl-cu12>=12.9.27
140137
nvidia-cuda-cupti-cu12>=12.9.79 ; sys_platform == 'linux'
141138
nvidia-cuda-nvcc-cu12>=12.9.86 ; sys_platform == 'linux'
142139
nvidia-cuda-nvrtc-cu12>=12.9.86 ; sys_platform == 'linux'
143140
nvidia-cuda-runtime-cu12>=12.9.79 ; sys_platform == 'linux'
144-
nvidia-cudnn-cu12>=9.20.0.48 ; sys_platform == 'linux'
141+
nvidia-cudnn-cu12>=9.21.0.82 ; sys_platform == 'linux'
145142
nvidia-cufft-cu12>=11.4.1.4 ; sys_platform == 'linux'
146143
nvidia-cusolver-cu12>=11.7.5.82 ; sys_platform == 'linux'
147144
nvidia-cusparse-cu12>=12.5.10.65 ; sys_platform == 'linux'
148145
nvidia-nccl-cu12>=2.29.7 ; sys_platform == 'linux'
149146
nvidia-nvjitlink-cu12>=12.9.86 ; sys_platform == 'linux'
150-
nvidia-nvshmem-cu12>=3.5.21 ; sys_platform == 'linux'
147+
nvidia-nvshmem-cu12>=3.6.5 ; sys_platform == 'linux'
151148
oauthlib>=3.3.1
152149
omegaconf>=2.3.0
153-
opentelemetry-api>=1.16.0
150+
opentelemetry-api>=1.41.1
154151
opt-einsum>=3.4.0
155152
optax>=0.2.8
156153
optree>=0.19.0
157154
optype>=0.17.0
158-
orbax-checkpoint>=0.11.36
159-
orbax-export>=0.0.8
160-
packaging>=26.0
155+
orbax-checkpoint>=0.11.39
156+
packaging>=26.1
161157
pandas>=3.0.2
162158
parameterized>=0.9.0
163159
pathspec>=1.1.1
164160
pathwaysutils>=0.1.8
165-
pillow>=12.1.1
161+
pillow>=12.2.0
166162
platformdirs>=4.9.6
167163
pluggy>=1.6.0
168164
portpicker>=1.6.0
169165
pre-commit>=4.6.0
170166
promise>=2.3
171167
propcache>=0.4.1
172-
proto-plus>=1.27.2
168+
proto-plus>=1.28.0
173169
protobuf>=6.33.6
174170
psutil>=7.2.2
175171
pyarrow>=24.0.0
176172
pyasn1>=0.6.3
177173
pyasn1-modules>=0.4.2
178174
pycnite>=2024.7.31
179175
pycparser>=3.0 ; implementation_name != 'PyPy' and platform_python_implementation != 'PyPy'
180-
pydantic>=2.13.3
181-
pydantic-core>=2.46.3
176+
pydantic>=2.13.4
177+
pydantic-core>=2.46.4
182178
pydot>=4.0.1
183179
pyelftools>=0.32
184180
pyglove>=0.4.5
185-
pygments>=2.19.2
181+
pygments>=2.20.0
186182
pyink>=25.12.0
187183
pylint>=4.0.5
188184
pyparsing>=3.3.2
189185
pyproject-hooks>=1.2.0
190186
pytest>=8.4.2
191187
pytest-xdist>=3.8.0
192188
python-dateutil>=2.9.0.post0
189+
python-discovery>=1.3.0
193190
pytokens>=0.4.1
194191
pytype>=2024.10.11
195192
pyyaml>=6.0.3
196193
qwix>=0.1.6
197194
regex>=2026.4.4
198-
requests>=2.32.5
195+
requests>=2.33.1
199196
requests-oauthlib>=2.0.0
200-
rich>=14.3.3
197+
rich>=15.0.0
201198
safetensors>=0.7.0
202199
scipy>=1.17.1
203-
scipy-stubs>=1.17.1.2
200+
scipy-stubs>=1.17.1.4
204201
sentencepiece>=0.2.1
205202
seqio>=0.0.20
206203
setuptools>=82.0.1
207-
shellingham>=1.5.4
208204
simple-parsing>=0.1.8
209205
simplejson>=4.1.1
210206
six>=1.17.0
@@ -219,7 +215,7 @@ tensorboard-data-server>=0.7.2
219215
tensorboard-plugin-profile>=2.13.0
220216
tensorboardx>=2.6.5
221217
tensorflow>=2.20.0
222-
tensorflow-datasets>=4.9.9
218+
tensorflow-datasets>=4.9.10
223219
tensorflow-metadata>=1.17.3
224220
tensorflow-text>=2.20.1
225221
tensorstore>=0.1.82
@@ -231,28 +227,26 @@ toml>=0.10.2
231227
tomlkit>=0.14.0
232228
toolz>=1.1.0
233229
tqdm>=4.67.3
234-
transformer-engine>=2.14.0
235-
transformer-engine-cu12>=2.14.0
236-
transformer-engine-jax>=2.14.0
237-
transformers>=5.6.2
230+
transformer-engine>=2.14.1
231+
transformer-engine-cu12>=2.14.1
232+
transformer-engine-jax>=2.14.1
233+
transformers>=4.57.6
238234
treescope>=0.1.10
239235
typeguard>=2.13.3
240-
typer>=0.25.0
241236
typing-extensions>=4.15.0
242-
typing-inspect>=0.9.0
243237
typing-inspection>=0.4.2
244238
tzdata>=2026.2 ; sys_platform == 'emscripten' or sys_platform == 'win32'
245239
uritemplate>=4.2.0
246240
urllib3>=2.6.3
247241
uvicorn>=0.46.0
248242
uvloop>=0.22.1
249-
virtualenv>=20.36.1
243+
virtualenv>=21.3.1
250244
wadler-lindig>=0.1.7
251245
websockets>=16.0
252246
werkzeug>=3.1.8
253247
wheel>=0.46.3
254248
wrapt>=2.1.2
255249
xxhash>=3.7.0
256250
yarl>=1.23.0
257-
zipp>=3.23.0
251+
zipp>=3.23.1
258252
zstandard>=0.25.0

0 commit comments

Comments
 (0)