Skip to content

Commit 2e985cb

Browse files
committed
Update generated requirements for JAX 0.9.2
1 parent 7f479f4 commit 2e985cb

10 files changed

Lines changed: 463 additions & 619 deletions

File tree

docs/install_maxtext.md

Lines changed: 35 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -110,11 +110,10 @@ Please keep dependencies updated throughout development. This will allow each co
110110

111111
To update dependencies, you will follow these general steps:
112112

113-
1. **Modify Base Requirements**: Update the desired dependencies in `base_requirements/requirements.txt` or the hardware-specific files (`base_requirements/tpu-base-requirements.txt`, `base_requirements/gpu-base-requirements.txt`).
113+
1. **Modify Base Requirements**: Update the desired dependencies in `src/dependencies/requirements/base_requirements/requirements.txt` or the hardware-specific pre-training files (`base_requirements/tpu-requirements.txt`, `base_requirements/gpu-requirements.txt`) or post-training requirements.
114114
2. **Generate New Files**: Run the `seed-env` CLI tool to generate new, fully-pinned requirements files based on your changes.
115115
3. **Update Project Files**: Copy the newly generated files into the `generated_requirements/` directory.
116-
4. **Handle GitHub Dependencies**: Move any dependencies that are installed directly from GitHub from the generated files to `src/dependencies/github_deps/pre_train_deps.txt`.
117-
5. **Verify**: Test the new dependencies to ensure the project installs and runs correctly.
116+
4. **Verify**: Test the new dependencies to ensure the project installs and runs correctly.
118117

119118
The following sections provide detailed instructions for each step.
120119

@@ -125,59 +124,70 @@ First, you need to install the `seed-env` command-line tool by running `pip inst
125124

126125
## Step 2: Find the JAX Build Commit Hash
127126

128-
The dependency generation process is pinned to a specific nightly build of JAX. You need to find the commit hash for the desired JAX build.
129-
130-
You can find the latest commit hashes in the [JAX `build/` folder](https://github.com/jax-ml/jax/commits/main/build). Choose a recent, successful build and copy its full commit hash.
127+
The dependency generation process is pinned to a specific nightly build of JAX. You need to find the commit hash for the desired JAX build from [JAX `build/` folder](https://github.com/jax-ml/jax/commits/main/build).
131128

132129
## Step 3: Generate the Requirements Files
133130

134131
Next, run the `seed-env` CLI to generate the new requirements files. You will need to do this separately for the TPU and GPU environments. The generated files will be placed in a directory specified by `--output-dir`.
135132

136-
### For TPU
133+
> **Note:** The current `src/dependencies/requirements/generated_requirements/` in the repository were generated using JAX build commit hash: [e0d2967b50abbefd651d563dbcd7afbcb963d08c](https://github.com/jax-ml/jax/commit/e0d2967b50abbefd651d563dbcd7afbcb963d08c).
134+
135+
### TPU Pre-Training
137136

138-
Run the following command, replacing `<jax-build-commit-hash>` with the hash you copied in the previous step.
137+
If you have made changes to TPU pre-training dependencies in `src/dependencies/requirements/base_requirements/tpu-requirements.txt`, you need to regenerate the pinned pre-training requirements in `generated_requirements/` directory. Run the following command, replacing `<jax-build-commit-hash>` with the hash you copied in the previous step:
139138

140139
```bash
141140
seed-env \
142-
--local-requirements=src/dependencies/requirements/base_requirements/tpu-base-requirements.txt \
141+
--local-requirements=src/dependencies/requirements/base_requirements/tpu-requirements.txt \
143142
--host-name=MaxText \
144143
--seed-commit=<jax-build-commit-hash> \
145144
--python-version=3.12 \
146145
--requirements-txt=tpu-requirements.txt \
147146
--output-dir=generated_tpu_artifacts
147+
148+
# Copy generated requirements to src/dependencies/requirements/generated_requirements
149+
mv generated_tpu_artifacts/tpu-requirements.txt src/dependencies/requirements/generated_requirements/tpu-requirements.txt
148150
```
149151

150-
### For GPU
152+
#### TPU Post-Training
151153

152-
Similarly, run the command for the GPU requirements.
154+
If you have made changes to the post-training dependencies in `src/dependencies/requirements/base_requirements/tpu-post-train-requirements.txt`, you need to regenerate the pinned post-training requirements in `generated_requirements/` directory. Run the following command, replacing `<jax-build-commit-hash>` with the hash you copied in the previous step:
153155

154156
```bash
155157
seed-env \
156-
--local-requirements=src/dependencies/requirements/base_requirements/cuda12-base-requirements.txt \
158+
--local-requirements=src/dependencies/requirements/base_requirements/tpu-post-train-requirements.txt \
157159
--host-name=MaxText \
158160
--seed-commit=<jax-build-commit-hash> \
159161
--python-version=3.12 \
160-
--requirements-txt=cuda12-requirements.txt \
161-
--hardware=cuda12 \
162-
--output-dir=generated_gpu_artifacts
163-
```
162+
--requirements-txt=tpu-post-train-requirements.txt \
163+
--output-dir=generated_tpu_post_train_artifacts
164164

165-
## Step 4: Update Project Files
165+
# Copy generated requirements to src/dependencies/requirements/generated_requirements
166+
mv generated_tpu_post_train_artifacts/tpu-post-train-requirements.txt src/dependencies/requirements/generated_requirements/tpu-post-train-requirements.txt
167+
```
166168

167-
After generating the new requirements, you need to update the files in the MaxText repository.
169+
### GPU Pre-Training
168170

169-
1. **Copy the generated files:**
171+
If you have made changes to the GPU pre-training dependencies in `src/dependencies/requirements/base_requirements/gpu-requirements.txt`, you need to regenerate the pinned pre-training requirements in `generated_requirements/` directory. Run the following command, replacing `<jax-build-commit-hash>` with the hash you copied in the previous step:
170172

171-
- Move `generated_tpu_artifacts/tpu-requirements.txt` to `generated_requirements/tpu-requirements.txt`.
172-
- Move `generated_gpu_artifacts/cuda12-requirements.txt` to `generated_requirements/cuda12-requirements.txt`.
173+
```bash
174+
seed-env \
175+
--local-requirements=src/dependencies/requirements/base_requirements/cuda12-requirements.txt \
176+
--host-name=MaxText \
177+
--seed-commit=<jax-build-commit-hash> \
178+
--python-version=3.12 \
179+
--requirements-txt=cuda12-requirements.txt \
180+
--hardware=cuda12 \
181+
--output-dir=generated_gpu_artifacts
173182

174-
2. **Update `pre_train_deps.txt` (if necessary):**
175-
Currently, MaxText uses a few dependencies, such as `mlperf-logging` and `google-jetstream`, that are installed directly from GitHub source. These are defined in `base_requirements/requirements.txt`, and the `seed-env` tool will carry them over to the generated requirements files.
183+
# Copy generated requirements to src/dependencies/requirements/generated_requirements
184+
mv generated_gpu_artifacts/cuda12-requirements.txt.txt src/dependencies/requirements/generated_requirements/cuda12-requirements.txt.txt
185+
```
176186

177-
## Step 5: Verify the New Dependencies
187+
## Step 4: Verify the New Dependencies
178188

179189
Finally, test that the new dependencies install correctly and that MaxText runs as expected.
180190

181-
1. **Install MaxText and dependencies**: For instructions on installing MaxText on your VM, please refer to the [official documentation](https://maxtext.readthedocs.io/en/maxtext-v0.2.0/install_maxtext.html#from-source).
191+
1. **Install MaxText and dependencies**: For instructions on installing MaxText on your VM, please refer to the [official documentation](https://maxtext.readthedocs.io/en/latest/install_maxtext.html#from-source).
182192

183193
2. **Verify the installation**: Run MaxText tests to ensure everything is working as expected with the newly installed dependencies and there are no regressions.

src/dependencies/requirements/base_requirements/gpu-base-requirements.txt renamed to src/dependencies/requirements/base_requirements/cuda12-requirements.txt

File renamed without changes.

src/dependencies/requirements/base_requirements/requirements.txt

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
absl-py
22
aqtp
33
array-record
4+
chex
45
cloud-accelerator-diagnostics
5-
cloud-tpu-diagnostics
6+
cloud-tpu-diagnostics!=1.1.14
67
datasets
78
drjax
89
flax
@@ -40,9 +41,7 @@ tensorflow-datasets
4041
tensorflow-text
4142
tensorflow
4243
tiktoken
43-
tokamax
44+
tokamax!=0.1.0
4445
transformers
4546
uvloop
4647
qwix
47-
google-jetstream @ https://github.com/AI-Hypercomputer/JetStream/archive/29329e8e73820993f77cfc8efe34eb2a73f5de98.zip
48-
mlperf-logging @ https://github.com/mlcommons/logging/archive/38ab22670527888c8eb7825a4ece176fcc36a95d.zip
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
-r requirements.txt
2+
clu
3+
google-metrax
4+
ipykernel
5+
papermill

src/dependencies/requirements/base_requirements/tpu-base-requirements.txt renamed to src/dependencies/requirements/base_requirements/tpu-requirements.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1 @@
11
-r requirements.txt
2-
google-tunix

0 commit comments

Comments
 (0)