Skip to content

Commit 7aee9d0

Browse files
committed
Update generated requirements for JAX 0.9.0
1 parent 3817525 commit 7aee9d0

8 files changed

Lines changed: 336 additions & 444 deletions

File tree

docs/install_maxtext.md

Lines changed: 49 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -124,9 +124,11 @@ You can find the latest commit hashes in the [JAX `build/` folder](https://githu
124124

125125
Next, run the `seed-env` CLI to generate the new requirements files. You will need to do this separately for the TPU and GPU environments. The generated files will be placed in a directory specified by `--output-dir`.
126126

127-
### For TPU
127+
> **Note:** The current `src/dependencies/requirements/generated_requirements/` in the repository were generated using JAX build commit: [d83b06508d669add43a8875ae7fd9e9fe7abf160](https://github.com/jax-ml/jax/commit/d83b06508d669add43a8875ae7fd9e9fe7abf160).
128128
129-
Run the following command, replacing `<jax-build-commit-hash>` with the hash you copied in the previous step.
129+
### TPU Pre-Training
130+
131+
If you have made changes to the pre-training dependencies in `src/dependencies/requirements/base_requirements/tpu-base-requirements.txt`, you need to regenerate the pinned pre-training requirements in `generated_requirements/` directory. Run the following command, replacing `<jax-build-commit-hash>` with the hash you copied in the previous step:
130132

131133
```bash
132134
seed-env \
@@ -138,53 +140,77 @@ seed-env \
138140
--output-dir=generated_tpu_artifacts
139141
```
140142

141-
### For GPU
143+
After generating the new requirements, you need to copy the generated files from `generated_tpu_artifacts/tpu-requirements.txt` to `src/dependencies/requirements/generated_requirements/tpu-requirements.txt`.
144+
145+
#### TPU Post-Training
142146

143-
Similarly, run the command for the GPU requirements.
147+
If you have made changes to the post-training dependencies in `src/dependencies/requirements/base_requirements/tpu-post-train-requirements.txt`, you need to regenerate the pinned post-training requirements in `generated_requirements/` directory. Run the following command, replacing `<jax-build-commit-hash>` with the hash you copied in the previous step:
144148

145149
```bash
146150
seed-env \
147-
--local-requirements=src/dependencies/requirements/base_requirements/cuda12-base-requirements.txt \
151+
--local-requirements=src/dependencies/requirements/base_requirements/tpu-post-train-requirements.txt \
148152
--host-name=MaxText \
149153
--seed-commit=<jax-build-commit-hash> \
150154
--python-version=3.12 \
151-
--requirements-txt=cuda12-requirements.txt \
152-
--hardware=cuda12 \
153-
--output-dir=generated_gpu_artifacts
155+
--requirements-txt=tpu-post-train-requirements.txt \
156+
--output-dir=generated_tpu_post_train_artifacts
154157
```
155158

156-
## Step 4: Update Project Files
159+
After generating the new requirements, you need to copy the generated files from `generated_tpu_post_train_artifacts/tpu-post-train-requirements.txt` to `src/dependencies/requirements/generated_requirements/tpu-post-train-requirements.txt`.
157160

158-
After generating the new requirements, you need to update the files in the MaxText repository.
161+
### GPU Pre-Training
159162

160-
1. **Copy the generated files:**
163+
If you have made changes to the GPU pre-training dependencies in `src/dependencies/requirements/base_requirements/gpu-base-requirements.txt`, you need to regenerate the pinned pre-training requirements in `generated_requirements/` directory. Run the following command, replacing `<jax-build-commit-hash>` with the hash you copied in the previous step:
161164

162-
- Move `generated_tpu_artifacts/tpu-requirements.txt` to `generated_requirements/tpu-requirements.txt`.
163-
- Move `generated_gpu_artifacts/cuda12-requirements.txt` to `generated_requirements/cuda12-requirements.txt`.
165+
```bash
166+
seed-env \
167+
--local-requirements=src/dependencies/requirements/base_requirements/gpu-base-requirements.txt \
168+
--host-name=MaxText \
169+
--seed-commit=<jax-build-commit-hash> \
170+
--python-version=3.12 \
171+
--requirements-txt=cuda12-requirements.txt \
172+
--hardware=cuda12 \
173+
--output-dir=generated_gpu_artifacts
174+
```
164175

165-
2. **Update `extra_deps_from_github.txt` (if necessary):**
166-
Currently, MaxText uses a few dependencies, such as `mlperf-logging` and `google-jetstream`, that are installed directly from GitHub source. These are defined in `base_requirements/requirements.txt`, and the `seed-env` tool will carry them over to the generated requirements files.
176+
After generating the new requirements, you need to copy the generated files from `generated_gpu_artifacts/cuda12-requirements.txt` to `src/dependencies/requirements/generated_requirements/cuda12-requirements.txt`.
167177

168-
## Step 5: Verify the New Dependencies
178+
## Step 4: Verify the New Dependencies
169179

170180
Finally, test that the new dependencies install correctly and that MaxText runs as expected.
171181

172182
1. **Create a clean environment:** It's best to start with a fresh Python virtual environment.
173183

174184
```bash
185+
# Ensure uv is installed
186+
pip install uv
187+
188+
# Create and activate the virtual environment
175189
uv venv --python 3.12 --seed maxtext_venv
176190
source maxtext_venv/bin/activate
177191
```
178192

179-
2. **Run the setup script:** Execute `bash setup.sh` to install the new dependencies.
193+
2. **Install MaxText and dependencies**: Install the package in editable mode with the appropriate extras. Choose the command that matches your hardware:
194+
195+
**TPU Pre-Training**:
180196

181197
```bash
182-
pip install uv
183-
# install the tpu package
184198
uv pip install -e .[tpu] --resolution=lowest
185-
# or install the gpu package by running the following line:
186-
# uv pip install -e .[cuda12] --resolution=lowest
187-
install_maxtext_github_deps
199+
install_maxtext_tpu_github_deps
200+
```
201+
202+
**TPU Post-Training**:
203+
204+
```bash
205+
uv pip install -e .[tpu-post-train] --resolution=lowest
206+
install_maxtext_tpu_post_train_extra_deps
207+
```
208+
209+
**GPU Pre-Training**:
210+
211+
```bash
212+
uv pip install -e .[cuda12] --resolution=lowest
213+
install_maxtext_cuda12_github_dep
188214
```
189215

190-
3. **Run tests:** Run MaxText tests to ensure there are no regressions.
216+
3. **Verify the installation**: Run MaxText tests to ensure everything is working as expected with the newly installed dependencies and there are no regressions.

src/dependencies/requirements/base_requirements/requirements.txt

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
absl-py
22
aqtp
33
array-record
4+
chex
45
cloud-accelerator-diagnostics
5-
cloud-tpu-diagnostics
6+
cloud-tpu-diagnostics!=1.1.14
67
datasets
78
drjax
89
flax
@@ -40,9 +41,7 @@ tensorflow-datasets
4041
tensorflow-text
4142
tensorflow
4243
tiktoken
43-
tokamax
44+
tokamax!=0.1.0
4445
transformers
4546
uvloop
4647
qwix
47-
google-jetstream @ https://github.com/AI-Hypercomputer/JetStream/archive/29329e8e73820993f77cfc8efe34eb2a73f5de98.zip
48-
mlperf-logging @ https://github.com/mlcommons/logging/archive/38ab22670527888c8eb7825a4ece176fcc36a95d.zip
Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1 @@
11
-r requirements.txt
2-
google-tunix
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
-r requirements.txt
2+
ipykernel
3+
papermill

0 commit comments

Comments
 (0)