Skip to content

Commit 4a19791

Browse files
committed
Update generated requirements for JAX 0.9.0
1 parent 7ac9fb4 commit 4a19791

9 files changed

Lines changed: 315 additions & 439 deletions

File tree

docs/install_maxtext.md

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ Please keep dependencies updated throughout development. This will allow each co
108108

109109
To update dependencies, you will follow these general steps:
110110

111-
1. **Modify Base Requirements**: Update the desired dependencies in `base_requirements/requirements.txt` or the hardware-specific files (`base_requirements/tpu-base-requirements.txt`, `base_requirements/gpu-base-requirements.txt`).
111+
1. **Modify Base Requirements**: Update the desired dependencies in `base_requirements/requirements.txt` or the hardware-specific files (`base_requirements/tpu-requirements.txt`, `base_requirements/gpu-requirements.txt`).
112112
2. **Generate New Files**: Run the `seed-env` CLI tool to generate new, fully-pinned requirements files based on your changes.
113113
3. **Update Project Files**: Copy the newly generated files into the `generated_requirements/` directory.
114114
4. **Handle GitHub Dependencies**: Move any dependencies that are installed directly from GitHub from the generated files to `src/dependencies/github_deps/pre_train_deps.txt`.
@@ -131,48 +131,58 @@ You can find the latest commit hashes in the [JAX `build/` folder](https://githu
131131

132132
Next, run the `seed-env` CLI to generate the new requirements files. You will need to do this separately for the TPU and GPU environments. The generated files will be placed in a directory specified by `--output-dir`.
133133

134-
### For TPU
134+
> **Note:** The current `src/dependencies/requirements/generated_requirements/` in the repository were generated using JAX build commit: [d83b06508d669add43a8875ae7fd9e9fe7abf160](https://github.com/jax-ml/jax/commit/d83b06508d669add43a8875ae7fd9e9fe7abf160).
135135
136-
Run the following command, replacing `<jax-build-commit-hash>` with the hash you copied in the previous step.
136+
### TPU Pre-Training
137+
138+
If you have made changes to the pre-training dependencies in `src/dependencies/requirements/base_requirements/tpu-requirements.txt`, you need to regenerate the pinned pre-training requirements in `generated_requirements/` directory. Run the following command, replacing `<jax-build-commit-hash>` with the hash you copied in the previous step:
137139

138140
```bash
139141
seed-env \
140-
--local-requirements=src/dependencies/requirements/base_requirements/tpu-base-requirements.txt \
142+
--local-requirements=src/dependencies/requirements/base_requirements/tpu-requirements.txt \
141143
--host-name=MaxText \
142144
--seed-commit=<jax-build-commit-hash> \
143145
--python-version=3.12 \
144146
--requirements-txt=tpu-requirements.txt \
145147
--output-dir=generated_tpu_artifacts
146148
```
147149

148-
### For GPU
150+
After generating the new requirements, you need to copy the generated files from `generated_tpu_artifacts/tpu-requirements.txt` to `src/dependencies/requirements/generated_requirements/tpu-requirements.txt`.
151+
152+
#### TPU Post-Training
149153

150-
Similarly, run the command for the GPU requirements.
154+
If you have made changes to the post-training dependencies in `src/dependencies/requirements/base_requirements/tpu-post-train-requirements.txt`, you need to regenerate the pinned post-training requirements in `generated_requirements/` directory. Run the following command, replacing `<jax-build-commit-hash>` with the hash you copied in the previous step:
151155

152156
```bash
153157
seed-env \
154-
--local-requirements=src/dependencies/requirements/base_requirements/cuda12-base-requirements.txt \
158+
--local-requirements=src/dependencies/requirements/base_requirements/tpu-post-train-requirements.txt \
155159
--host-name=MaxText \
156160
--seed-commit=<jax-build-commit-hash> \
157161
--python-version=3.12 \
158-
--requirements-txt=cuda12-requirements.txt \
159-
--hardware=cuda12 \
160-
--output-dir=generated_gpu_artifacts
162+
--requirements-txt=tpu-post-train-requirements.txt \
163+
--output-dir=generated_tpu_post_train_artifacts
161164
```
162165

163-
## Step 4: Update Project Files
166+
After generating the new requirements, you need to copy the generated files from `generated_tpu_post_train_artifacts/tpu-post-train-requirements.txt` to `src/dependencies/requirements/generated_requirements/tpu-post-train-requirements.txt`.
164167

165-
After generating the new requirements, you need to update the files in the MaxText repository.
168+
### GPU Pre-Training
166169

167-
1. **Copy the generated files:**
170+
If you have made changes to the GPU pre-training dependencies in `src/dependencies/requirements/base_requirements/gpu-requirements.txt`, you need to regenerate the pinned pre-training requirements in `generated_requirements/` directory. Run the following command, replacing `<jax-build-commit-hash>` with the hash you copied in the previous step:
168171

169-
- Move `generated_tpu_artifacts/tpu-requirements.txt` to `generated_requirements/tpu-requirements.txt`.
170-
- Move `generated_gpu_artifacts/cuda12-requirements.txt` to `generated_requirements/cuda12-requirements.txt`.
172+
```bash
173+
seed-env \
174+
--local-requirements=src/dependencies/requirements/base_requirements/gpu-requirements.txt \
175+
--host-name=MaxText \
176+
--seed-commit=<jax-build-commit-hash> \
177+
--python-version=3.12 \
178+
--requirements-txt=cuda12-requirements.txt \
179+
--hardware=cuda12 \
180+
--output-dir=generated_gpu_artifacts
181+
```
171182

172-
2. **Update `pre_train_deps.txt` (if necessary):**
173-
Currently, MaxText uses a few dependencies, such as `mlperf-logging` and `google-jetstream`, that are installed directly from GitHub source. These are defined in `base_requirements/requirements.txt`, and the `seed-env` tool will carry them over to the generated requirements files.
183+
After generating the new requirements, you need to copy the generated files from `generated_gpu_artifacts/cuda12-requirements.txt` to `src/dependencies/requirements/generated_requirements/cuda12-requirements.txt`.
174184

175-
## Step 5: Verify the New Dependencies
185+
## Step 4: Verify the New Dependencies
176186

177187
Finally, test that the new dependencies install correctly and that MaxText runs as expected.
178188

src/dependencies/requirements/base_requirements/gpu-base-requirements.txt renamed to src/dependencies/requirements/base_requirements/gpu-requirements.txt

File renamed without changes.

src/dependencies/requirements/base_requirements/requirements.txt

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
absl-py
22
aqtp
33
array-record
4+
chex
45
cloud-accelerator-diagnostics
5-
cloud-tpu-diagnostics
6+
cloud-tpu-diagnostics!=1.1.14
67
datasets
78
drjax
89
flax
@@ -40,9 +41,7 @@ tensorflow-datasets
4041
tensorflow-text
4142
tensorflow
4243
tiktoken
43-
tokamax
44+
tokamax!=0.1.0
4445
transformers
4546
uvloop
4647
qwix
47-
google-jetstream @ https://github.com/AI-Hypercomputer/JetStream/archive/29329e8e73820993f77cfc8efe34eb2a73f5de98.zip
48-
mlperf-logging @ https://github.com/mlcommons/logging/archive/38ab22670527888c8eb7825a4ece176fcc36a95d.zip
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
-r requirements.txt
2+
ipykernel
3+
papermill

src/dependencies/requirements/base_requirements/tpu-base-requirements.txt renamed to src/dependencies/requirements/base_requirements/tpu-requirements.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1 @@
11
-r requirements.txt
2-
google-tunix

0 commit comments

Comments
 (0)