Skip to content

Commit ff8e8e1

Browse files
committed
Update generated requirements for post-training with JAX 0.9.2
1 parent 412902a commit ff8e8e1

8 files changed

Lines changed: 376 additions & 281 deletions

File tree

docs/development/update_dependencies.md

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ To update dependencies, you will follow these general steps:
3636
1. **Modify base requirements**: Update the desired dependencies in
3737
`src/dependencies/requirements/base_requirements/requirements.txt` or the hardware-specific pre-training files
3838
(`src/dependencies/requirements/base_requirements/tpu-requirements.txt`,
39-
`src/dependencies/requirements/base_requirements/cuda12-requirements.txt`).
39+
`src/dependencies/requirements/base_requirements/cuda12-requirements.txt`) or the post-training files (`src/dependencies/requirements/base_requirements/tpu-post-train-requirements.txt`).
4040
2. **Find the JAX build commit hash**: The dependency generation process is
4141
pinned to a specific nightly build of JAX. You need to find the commit hash
4242
for the desired JAX build.
@@ -66,7 +66,7 @@ if you want to build `seed-env` from source.
6666

6767
## Step 1: Modify base requirements
6868

69-
Update the desired dependencies in `src/dependencies/requirements/base_requirements/requirements.txt` or the hardware-specific pre-training files (`src/dependencies/requirements/base_requirements/tpu-requirements.txt`, `src/dependencies/requirements/base_requirements/cuda12-requirements.txt`).
69+
Update the desired dependencies in `src/dependencies/requirements/base_requirements/requirements.txt` or the hardware-specific pre-training files (`src/dependencies/requirements/base_requirements/tpu-requirements.txt`, `src/dependencies/requirements/base_requirements/cuda12-requirements.txt`) or the post-training files (`src/dependencies/requirements/base_requirements/tpu-post-train-requirements.txt`).
7070

7171
## Step 2: Find the JAX build commit hash
7272

@@ -98,6 +98,24 @@ mv generated_tpu_artifacts/tpu-requirements.txt \
9898
src/dependencies/requirements/generated_requirements/tpu-requirements.txt
9999
```
100100

101+
### TPU Post-Training
102+
103+
If you have made changes to the post-training dependencies in `src/dependencies/requirements/base_requirements/tpu-post-train-requirements.txt`, you need to regenerate the pinned post-training requirements in `generated_requirements/` directory. Run the following command, replacing `<jax-build-commit-hash>` with the hash you copied in the previous step:
104+
105+
```bash
106+
seed-env \
107+
--local-requirements=src/dependencies/requirements/base_requirements/tpu-post-train-requirements.txt \
108+
--host-name=MaxText \
109+
--seed-commit=<jax-build-commit-hash> \
110+
--python-version=3.12 \
111+
--requirements-txt=tpu-post-train-requirements.txt \
112+
--output-dir=generated_tpu_post_train_artifacts
113+
114+
# Copy generated requirements to src/dependencies/requirements/generated_requirements
115+
mv generated_tpu_post_train_artifacts/tpu-post-train-requirements.txt \
116+
src/dependencies/requirements/generated_requirements/tpu-post-train-requirements.txt
117+
```
118+
101119
### GPU Pre-Training
102120

103121
If you have made changes to the GPU pre-training dependencies in `src/dependencies/requirements/base_requirements/cuda12-requirements.txt`, you need to regenerate the pinned pre-training requirements in `generated_requirements/` directory. Run the following command, replacing `<jax-build-commit-hash>` with the hash you copied in the previous step:
Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
11
-r post_train_base_deps.txt
2-
google-jetstream @ https://github.com/AI-Hypercomputer/JetStream/archive/29329e8e73820993f77cfc8efe34eb2a73f5de98.zip
3-
mlperf-logging @ https://github.com/mlcommons/logging/archive/38ab22670527888c8eb7825a4ece176fcc36a95d.zip
42
tpu-inference @ https://github.com/vllm-project/tpu-inference/archive/40876e81f04226f9b7b1e4bbdc9051d6b1364b9d.zip
53
vllm @ git+https://github.com/vllm-project/vllm@595562651a5a4539ffa910d8570c08fb5169bdc9
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
google-metrax>=0.2.3
2+
libtpu>=0.0.39
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
-r requirements.txt
2+
gepa
3+
google-cloud-storage
4+
google-tunix
5+
hypothesis
6+
ijson
7+
mistral_common
8+
prometheus-fastapi-instrumentator
9+
pytest-mock
10+
runai-model-streamer[s3,gcs]
11+
sortedcontainers
12+
torchax
13+
gguf
14+
pyzmq
15+
pybase64
16+
cachetools
17+
openai
18+
openai-harmony
19+
torchvision
20+
py-cpuinfo
21+
llguidance
22+
tpu-info
23+
xgrammar
24+
xprof
25+
yapf
26+
anthropic
27+
openai
28+
ray
29+
gspread
30+
lxml
31+
loguru
32+
astor
33+
python-json-logger
34+
perfetto
35+
cbor2
36+
compressed-tensors

0 commit comments

Comments
 (0)