Commit c1d057b

Merge pull request #3684 from AI-Hypercomputer:seed-post-train
PiperOrigin-RevId: 907099588
2 parents 02ebebe + 9481832 commit c1d057b

9 files changed

Lines changed: 463 additions & 341 deletions

docs/development/update_dependencies.md

Lines changed: 35 additions & 24 deletions
@@ -36,12 +36,12 @@ To update dependencies, you will follow these general steps:
 1. **Modify base requirements**: Update the desired dependencies in
 `src/dependencies/requirements/base_requirements/requirements.txt` or the hardware-specific pre-training files
 (`src/dependencies/requirements/base_requirements/tpu-requirements.txt`,
-`src/dependencies/requirements/base_requirements/cuda12-requirements.txt`).
+`src/dependencies/requirements/base_requirements/cuda12-requirements.txt`) or the post-training files (`src/dependencies/requirements/base_requirements/tpu-post-train-requirements.txt`).
 2. **Find the JAX build commit hash**: The dependency generation process is
 pinned to a specific nightly build of JAX. You need to find the commit hash
 for the desired JAX build.
-3. **Generate the requirement files**: Run the `seed-env` CLI tool to generate
-new, fully-pinned requirements files based on your changes.
+3. **Generate the requirement files**: Run `src/dependencies/scripts/generate_requirements.sh`,
+which internally invokes `seed-env` to produce fully-pinned requirements files.
 4. **Verify the new dependencies**: Test the new dependencies to ensure the
 project installs and runs correctly.

@@ -66,17 +66,17 @@ if you want to build `seed-env` from source.

 ## Step 1: Modify base requirements

-Update the desired dependencies in `src/dependencies/requirements/base_requirements/requirements.txt` or the hardware-specific pre-training files (`src/dependencies/requirements/base_requirements/tpu-requirements.txt`, `src/dependencies/requirements/base_requirements/cuda12-requirements.txt`).
+Update the desired dependencies in `src/dependencies/requirements/base_requirements/requirements.txt` or the hardware-specific pre-training files (`src/dependencies/requirements/base_requirements/tpu-requirements.txt`, `src/dependencies/requirements/base_requirements/cuda12-requirements.txt`) or the post-training files (`src/dependencies/requirements/base_requirements/tpu-post-train-requirements.txt`).

 ## Step 2: Find the JAX build commit hash

 The dependency generation process is pinned to a specific nightly build of JAX. You need to find the commit hash for the desired JAX build from [JAX `build/` folder](https://github.com/jax-ml/jax/commits/main/build) and copy its full commit hash.

 ## Step 3: Generate the requirements files

-Next, run the `seed-env` CLI to generate the new requirements files. You will
-need to do this separately for the TPU and GPU environments. The generated files
-will be placed in a directory specified by `--output-dir`.
+Next, run `generate_requirements.sh` to generate the new requirements files. This
+script wraps the `seed-env` CLI and handles exporting the lock, and applying any
+overrides. You will need to do this separately for the TPU and GPU environments.

 > **Note:** The current `src/dependencies/requirements/generated_requirements/` in the repository were generated using JAX build commit hash: [e0d2967b50abbefd651d563dbcd7afbcb963d08c](https://github.com/jax-ml/jax/commit/e0d2967b50abbefd651d563dbcd7afbcb963d08c).

@@ -85,35 +85,46 @@ will be placed in a directory specified by `--output-dir`.
 If you have made changes to TPU pre-training dependencies in `src/dependencies/requirements/base_requirements/tpu-requirements.txt`, you need to regenerate the pinned pre-training requirements in `generated_requirements/` directory. Run the following command, replacing `<jax-build-commit-hash>` with the hash you copied in the previous step:

 ```bash
-seed-env \
---local-requirements=src/dependencies/requirements/base_requirements/tpu-base-requirements.txt \
---host-name=MaxText \
---seed-commit=<jax-build-commit-hash> \
---python-version=3.12 \
---requirements-txt=tpu-requirements.txt \
---output-dir=generated_tpu_artifacts
+bash src/dependencies/scripts/generate_requirements.sh \
+--base-requirements src/dependencies/requirements/base_requirements/tpu-requirements.txt \
+--generated-requirements tpu-requirements.txt \
+--seed-commit <jax-build-commit-hash>

 # Copy generated requirements to src/dependencies/requirements/generated_requirements
-mv generated_tpu_artifacts/tpu-requirements.txt \
+mv generated_artifacts/python3_12/tpu-requirements.txt \
 src/dependencies/requirements/generated_requirements/tpu-requirements.txt
 ```

+### TPU Post-Training
+
+If you have made changes to the post-training dependencies in `src/dependencies/requirements/base_requirements/tpu-post-train-requirements.txt`, you need to regenerate the pinned post-training requirements in `generated_requirements/` directory. Run the following command, replacing `<jax-build-commit-hash>` with the hash you copied in the previous step:
+
+```bash
+
+bash src/dependencies/scripts/generate_requirements.sh \
+--base-requirements src/dependencies/requirements/base_requirements/tpu-post-train-requirements.txt \
+--generated-requirements tpu-post-train-requirements.txt \
+--override-requirements src/dependencies/extra_deps/post_train_overrides.txt \
+--seed-commit <jax-build-commit-hash>
+
+# Copy generated requirements to src/dependencies/requirements/generated_requirements
+mv generated_artifacts/python3_12/tpu-post-train-requirements.txt \
+src/dependencies/requirements/generated_requirements/tpu-post-train-requirements.txt
+```
+
 ### GPU Pre-Training

 If you have made changes to the GPU pre-training dependencies in `src/dependencies/requirements/base_requirements/cuda12-requirements.txt`, you need to regenerate the pinned pre-training requirements in `generated_requirements/` directory. Run the following command, replacing `<jax-build-commit-hash>` with the hash you copied in the previous step:

 ```bash
-seed-env \
---local-requirements=src/dependencies/requirements/base_requirements/cuda12-requirements.txt \
---host-name=MaxText \
---seed-commit=<jax-build-commit-hash> \
---python-version=3.12 \
---requirements-txt=cuda12-requirements.txt \
---hardware=cuda12 \
---output-dir=generated_gpu_artifacts
+bash src/dependencies/scripts/generate_requirements.sh \
+--base-requirements src/dependencies/requirements/base_requirements/cuda12-requirements.txt \
+--generated-requirements cuda12-requirements.txt \
+--seed-commit <jax-build-commit-hash> \
+--hardware cuda12

 # Copy generated requirements to src/dependencies/requirements/generated_requirements
-mv generated_artifacts_placeholder \
 src/dependencies/requirements/generated_requirements/cuda12-requirements.txt
 ```
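
The copy step that follows each generation command in the documentation hunk above can be guarded so that a failed or partial run does not silently leave the old file in place. The following is a minimal sketch, not part of this commit: `install_generated_requirements` is a hypothetical helper name, and the default source directory `generated_artifacts/python3_12` is an assumption taken from the `mv` commands in the updated docs.

```bash
# Hypothetical helper (not in the repository): move one generated requirements
# file into the generated_requirements directory, failing loudly if the
# generation step did not produce a non-empty file.
#   $1: file name, e.g. tpu-requirements.txt
#   $2: source directory (assumed default: generated_artifacts/python3_12)
#   $3: destination directory
install_generated_requirements() {
  local name="$1"
  local src_dir="${2:-generated_artifacts/python3_12}"
  local dst_dir="${3:-src/dependencies/requirements/generated_requirements}"

  if [ ! -s "${src_dir}/${name}" ]; then
    echo "error: ${src_dir}/${name} is missing or empty" >&2
    return 1
  fi
  mkdir -p "${dst_dir}" && mv "${src_dir}/${name}" "${dst_dir}/${name}"
}
```

Under these assumptions, `install_generated_requirements tpu-post-train-requirements.txt` would stand in for the corresponding `mv` line and return a non-zero status when the expected output is absent.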

Lines changed: 0 additions & 2 deletions
@@ -1,5 +1,3 @@
 -r post_train_base_deps.txt
-google-jetstream @ https://github.com/AI-Hypercomputer/JetStream/archive/29329e8e73820993f77cfc8efe34eb2a73f5de98.zip
-mlperf-logging @ https://github.com/mlcommons/logging/archive/38ab22670527888c8eb7825a4ece176fcc36a95d.zip
 tpu-inference @ https://github.com/vllm-project/tpu-inference/archive/40876e81f04226f9b7b1e4bbdc9051d6b1364b9d.zip
 vllm @ git+https://github.com/vllm-project/vllm@595562651a5a4539ffa910d8570c08fb5169bdc9
Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+flax==0.12.4
+google-metrax>=0.2.3
+libtpu>=0.0.39
+optax==0.2.6
Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
+-r requirements.txt
+anthropic
+astor
+cachetools
+cbor2
+compressed-tensors
+gepa
+gguf
+google-cloud-storage
+google-tunix
+gspread
+hypothesis
+ipykernel
+ipywidgets
+ijson
+llguidance
+loguru
+lxml
+mistral_common
+openai
+openai-harmony
+papermill
+partial-json-parser
+perfetto
+prometheus-fastapi-instrumentator
+py-cpuinfo
+pybase64
+pytest-mock
+python-json-logger
+pyzmq
+ray
+runai-model-streamer[s3,gcs]
+sortedcontainers
+torchax
+torchvision
+tpu-info
+xgrammar
+xprof
+yapf

src/dependencies/requirements/generated_requirements/cuda12-requirements.txt

Lines changed: 30 additions & 30 deletions
@@ -1,5 +1,5 @@
-# Generated by seed-env. Do not edit manually.
-# If you need to modify dependencies, please do so in the host requirements file and run seed-env again.
+# Generated by generate_requirements.sh using seed-env tool. Do not edit manually.
+# See https://maxtext.readthedocs.io/en/latest/development/update_dependencies.html for details.

 absl-py>=2.4.0
 aiofiles>=25.1.0
@@ -30,10 +30,10 @@ cloudpickle>=3.1.2
 clu>=0.0.12
 colorama>=0.4.6
 contourpy>=1.3.3
-cryptography>=46.0.7
+cryptography>=47.0.0
 cycler>=0.12.1
 dataclasses-json>=0.6.7
-datasets>=4.8.4
+datasets>=4.8.5
 decorator>=5.2.1
 deprecated>=1.3.1
 dill>=0.4.1
@@ -58,9 +58,9 @@ gast>=0.7.0
 gcsfs>=2026.2.0
 google-api-core>=2.30.3
 google-api-python-client>=2.194.0
+google-auth>=2.49.2
 google-auth-httplib2>=0.3.1
 google-auth-oauthlib>=1.3.1
-google-auth>=2.49.2
 google-cloud-aiplatform>=1.148.1
 google-cloud-appengine-logging>=1.9.0
 google-cloud-audit-log>=0.5.0
@@ -70,25 +70,25 @@ google-cloud-logging>=3.15.0
 google-cloud-mldiagnostics>=1.0.2
 google-cloud-monitoring>=2.30.0
 google-cloud-resource-manager>=1.17.0
-google-cloud-storage-control>=1.11.0
 google-cloud-storage>=3.10.1
+google-cloud-storage-control>=1.11.0
 google-crc32c>=1.8.0
 google-genai>=1.73.1
 google-pasta>=0.2.0
 google-resumable-media>=2.8.2
 googleapis-common-protos>=1.74.0
 grain>=0.2.16
 grpc-google-iam-v1>=0.14.4
-grpcio-status>=1.78.0
 grpcio>=1.78.0
+grpcio-status>=1.78.0
 gviz-api>=1.10.0
 h11>=0.16.0
 h5py>=3.14.0
 hf-xet>=1.4.3 ; platform_machine == 'AMD64' or platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'
 httpcore>=1.0.9
 httplib2>=0.31.2
 httpx>=0.28.1
-huggingface-hub>=1.11.0
+huggingface-hub>=1.12.0
 humanize>=4.15.0
 hypothesis>=6.142.1
 identify>=2.6.19
@@ -98,9 +98,9 @@ importlab>=0.8.1
 importlib-metadata>=9.0.0
 iniconfig>=2.3.0
 isort>=8.0.1
+jax>=0.9.2
 jax-cuda12-pjrt>=0.9.2 ; sys_platform == 'linux'
 jax-cuda12-plugin>=0.9.2 ; sys_platform == 'linux'
-jax>=0.9.2
 jaxlib>=0.9.2
 jaxtyping>=0.3.9
 jinja2>=3.1.6
@@ -110,8 +110,8 @@ kiwisolver>=1.5.0
 latex2sympy2-extended>=1.11.0
 libclang>=18.1.1
 libcst>=1.8.6
-markdown-it-py>=4.0.0
 markdown>=3.10.2
+markdown-it-py>=4.0.0
 markupsafe>=3.0.3
 marshmallow>=3.26.2
 math-verify>=0.9.0
@@ -132,11 +132,11 @@ nest-asyncio>=1.6.0 ; sys_platform == 'win32'
 networkx>=3.6.1
 ninja>=1.13.0
 nodeenv>=1.10.0
-numpy-typing-compat>=20251206.2.0
 numpy>=2.0.2
+numpy-typing-compat>=20251206.2.0
 nvidia-cublas-cu12>=12.9.1.4 ; sys_platform == 'linux'
-nvidia-cuda-cccl-cu12>=12.9.27
 nvidia-cuda-cccl>=13.2.27
+nvidia-cuda-cccl-cu12>=12.9.27
 nvidia-cuda-cupti-cu12>=12.9.79 ; sys_platform == 'linux'
 nvidia-cuda-nvcc-cu12>=12.9.86 ; sys_platform == 'linux'
 nvidia-cuda-nvrtc-cu12>=12.9.86 ; sys_platform == 'linux'
@@ -160,8 +160,8 @@ orbax-export>=0.0.8
 packaging>=26.0
 pandas>=3.0.2
 parameterized>=0.9.0
-pathspec>=1.1.0
-pathwaysutils>=0.1.7
+pathspec>=1.1.1
+pathwaysutils>=0.1.8
 pillow>=12.1.1
 platformdirs>=4.9.6
 pluggy>=1.6.0
@@ -173,12 +173,12 @@ proto-plus>=1.27.2
 protobuf>=6.33.6
 psutil>=7.2.2
 pyarrow>=24.0.0
-pyasn1-modules>=0.4.2
 pyasn1>=0.6.3
+pyasn1-modules>=0.4.2
 pycnite>=2024.7.31
 pycparser>=3.0 ; implementation_name != 'PyPy' and platform_python_implementation != 'PyPy'
-pydantic-core>=2.46.3
 pydantic>=2.13.3
+pydantic-core>=2.46.3
 pydot>=4.0.1
 pyelftools>=0.32
 pyglove>=0.4.5
@@ -187,41 +187,41 @@ pyink>=25.12.0
 pylint>=4.0.5
 pyparsing>=3.3.2
 pyproject-hooks>=1.2.0
-pytest-xdist>=3.8.0
 pytest>=8.4.2
+pytest-xdist>=3.8.0
 python-dateutil>=2.9.0.post0
 pytokens>=0.4.1
 pytype>=2024.10.11
 pyyaml>=6.0.3
 qwix>=0.1.6
 regex>=2026.4.4
-requests-oauthlib>=2.0.0
 requests>=2.32.5
+requests-oauthlib>=2.0.0
 rich>=14.3.3
 safetensors>=0.7.0
-scipy-stubs>=1.17.1.2
 scipy>=1.17.1
+scipy-stubs>=1.17.1.2
 sentencepiece>=0.2.1
 seqio>=0.0.20
 setuptools>=82.0.1
 shellingham>=1.5.4
 simple-parsing>=0.1.8
-simplejson>=4.1.0
+simplejson>=4.1.1
 six>=1.17.0
 sniffio>=1.3.1
 sortedcontainers>=2.4.0
 starlette>=1.0.0
 sympy>=1.14.0
 tabulate>=0.10.0
 tenacity>=9.1.4
+tensorboard>=2.20.0
 tensorboard-data-server>=0.7.2
 tensorboard-plugin-profile>=2.13.0
-tensorboard>=2.20.0
 tensorboardx>=2.6.5
+tensorflow>=2.20.0
 tensorflow-datasets>=4.9.9
 tensorflow-metadata>=1.17.3
 tensorflow-text>=2.20.1
-tensorflow>=2.20.0
 tensorstore>=0.1.82
 termcolor>=3.3.0
 tiktoken>=0.12.0
@@ -231,17 +231,17 @@ toml>=0.10.2
 tomlkit>=0.14.0
 toolz>=1.1.0
 tqdm>=4.67.3
-transformer-engine-cu12>=2.13.0
-transformer-engine-jax>=2.13.0
-transformer-engine>=2.13.0
-transformers>=5.6.1
+transformer-engine>=2.14.0
+transformer-engine-cu12>=2.14.0
+transformer-engine-jax>=2.14.0
+transformers>=5.6.2
 treescope>=0.1.10
 typeguard>=2.13.3
-typer>=0.24.2
+typer>=0.25.0
 typing-extensions>=4.15.0
 typing-inspect>=0.9.0
 typing-inspection>=0.4.2
-tzdata>=2026.1 ; sys_platform == 'emscripten' or sys_platform == 'win32'
+tzdata>=2026.2 ; sys_platform == 'emscripten' or sys_platform == 'win32'
 uritemplate>=4.2.0
 urllib3>=2.6.3
 uvicorn>=0.46.0
@@ -252,7 +252,7 @@ websockets>=16.0
 werkzeug>=3.1.8
 wheel>=0.46.3
 wrapt>=2.1.2
-xxhash>=3.6.0
+xxhash>=3.7.0
 yarl>=1.23.0
 zipp>=3.23.0
-zstandard>=0.25.0
+zstandard>=0.25.0
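
Several hunks in the `cuda12-requirements.txt` diff above only reorder entries (for example `google-auth`, `jax`, `tensorflow`), which makes the real version bumps harder to spot. As a review aid, the sketch below prints only the lines of a new requirements file that have no exact match in the old one. `changed_requirements` is a hypothetical helper name, not something this commit adds, and it assumes one requirement per line as in the generated files.

```bash
# Hypothetical review helper (not in the repository): list requirement lines
# that are new or changed in $2 relative to $1, ignoring pure reordering.
#   $1: old requirements file
#   $2: new requirements file
changed_requirements() {
  # Drop comment lines from the new file, then keep only lines that do not
  # match any whole line of the old file (-F fixed strings, -x whole line,
  # -v invert, -f read patterns from file). "|| true" keeps the exit status
  # zero when nothing changed.
  grep -v '^#' "$2" | grep -Fxv -f "$1" || true
}
```

Run against the old and new copies of a generated file, this would be expected to report entries such as the `xxhash` bump while skipping lines that merely moved.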
