Skip to content

Commit ba74653

Browse files
committed
Update generated requirements for pre-training with JAX 0.9.2
1 parent 35e07f9 commit ba74653

16 files changed

Lines changed: 361 additions & 377 deletions

File tree

.github/workflows/build_and_test_maxtext.yml

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -231,9 +231,8 @@ jobs:
231231
steps:
232232
- name: Check test results
233233
run: |
234-
# If doc-only, all tests should be skipped
235-
if [ "${NEEDS_DOC_ONLY_CHECK_OUTPUTS_RUN_TESTS}" == "false" ]; then
236-
echo "Documentation-only changes detected, tests were skipped"
234+
if [ "${NEEDS_ANALYZE_CHANGES_OUTPUTS_RUN_TESTS}" == "false" ]; then
235+
echo "Tests were skipped"
237236
exit 0
238237
fi
239238

.github/workflows/run_pathways_tests.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ jobs:
112112
- "29001:29001"
113113
- "29002:29002"
114114
options:
115-
--entrypoint=[/usr/pathways/run/cloud_pathways_server_sanitized, --server_port=29001, --node_type=resource_manager, --instance_count=1, --instance_type=tpuv6e:2x2, --gcs_scratch_location=gs://cloud-pathways-staging/tmp]
115+
--entrypoint=[/usr/pathways/run/cloud_pathways_server_sanitized, --server_port=29001, --node_type=resource_manager, --enforce_kernel_ipv6_support=false, --instance_count=1, --instance_type=tpuv6e:2x2, --gcs_scratch_location=gs://cloud-pathways-staging/tmp]
116116
env:
117117
TPU_SKIP_MDS_QUERY: true
118118

@@ -124,7 +124,7 @@ jobs:
124124
- "8471:8471"
125125
- "8080:8080"
126126
options:
127-
--entrypoint=[/usr/pathways/run/cloud_pathways_server_sanitized, --server_port=29005, --resource_manager_address=localhost:29001, --gcs_scratch_location=gs://cloud-pathways-staging/tmp]
127+
--entrypoint=[/usr/pathways/run/cloud_pathways_server_sanitized, --server_port=29005, --resource_manager_address=localhost:29001, --enforce_kernel_ipv6_support=false, --gcs_scratch_location=gs://cloud-pathways-staging/tmp]
128128
--tpu=4
129129

130130
proxy:
@@ -135,4 +135,4 @@ jobs:
135135
IFRT_PROXY_USE_INSECURE_GRPC_CREDENTIALS: true
136136
XLA_FLAGS: "--xla_dump_to=/tmp/aot_test_dump --xla_dump_hlo_as_text --xla_dump_hlo_module_re=jit_train_step"
137137
options:
138-
--entrypoint=[/usr/pathways/run/cloud_proxy_server_sanitized, --server_port=29000, --resource_manager_address=localhost:29001, --gcs_scratch_location=gs://cloud-pathways-staging/tmp, --xla_tpu_scoped_vmem_limit_kib=65536, --xla_tpu_spmd_rng_bit_generator_unsafe=true]
138+
--entrypoint=[/usr/pathways/run/cloud_proxy_server_sanitized, --server_port=29000, --resource_manager_address=localhost:29001, --enforce_kernel_ipv6_support=false, --gcs_scratch_location=gs://cloud-pathways-staging/tmp, --xla_tpu_scoped_vmem_limit_kib=65536, --xla_tpu_spmd_rng_bit_generator_unsafe=true]

.github/workflows/run_tests_against_package.yml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,14 @@ jobs:
152152
# omit this libtpu init args for gpu tests
153153
if [ "${INPUTS_DEVICE_TYPE}" != "cuda12" ]; then
154154
export LIBTPU_INIT_ARGS='--xla_tpu_scoped_vmem_limit_kib=65536'
155+
else
156+
# For cuda12, explicitly point to the pip-installed CUDA libraries
157+
# to avoid conflicts with system-level installations on the runner.
158+
if [ -d ".venv/lib/python3.12/site-packages/nvidia" ]; then
159+
export LD_LIBRARY_PATH=$(pwd)/.venv/lib/python3.12/site-packages/nvidia/cudnn/lib:${LD_LIBRARY_PATH}
160+
else
161+
echo "Warning: Could not find pinned nvidia libraries in .venv."
162+
fi
155163
fi
156164
if [ "${INPUTS_TOTAL_WORKERS}" -gt 1 ]; then
157165
$PYTHON_EXE -m pip install --quiet pytest-split pytest-xdist

docs/development/update_dependencies.md

Lines changed: 21 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -34,19 +34,15 @@ to keep dependencies in sync for users installing MaxText from source.
3434
To update dependencies, you will follow these general steps:
3535

3636
1. **Modify base requirements**: Update the desired dependencies in
37-
`src/dependencies/requirements/base_requirements/requirements.txt` or the hardware-specific files
38-
(`src/dependencies/requirements/base_requirements/tpu-base-requirements.txt`,
39-
`src/dependencies/requirements/base_requirements/gpu-base-requirements.txt`).
37+
`src/dependencies/requirements/base_requirements/requirements.txt` or the hardware-specific pre-training files
38+
(`src/dependencies/requirements/base_requirements/tpu-requirements.txt`,
39+
`src/dependencies/requirements/base_requirements/cuda12-requirements.txt`).
4040
2. **Find the JAX build commit hash**: The dependency generation process is
4141
pinned to a specific nightly build of JAX. You need to find the commit hash
4242
for the desired JAX build.
4343
3. **Generate the requirement files**: Run the `seed-env` CLI tool to generate
4444
new, fully-pinned requirements files based on your changes.
45-
4. **Update project files**: Copy the newly generated files into the
46-
`src/dependencies/requirements/generated_requirements/` directory. If
47-
necessary, also update any dependencies that are installed directly from
48-
GitHub from the generated files to `src/dependencies/extra_deps`.
49-
5. **Verify the new dependencies**: Test the new dependencies to ensure the
45+
4. **Verify the new dependencies**: Test the new dependencies to ensure the
5046
project installs and runs correctly.
5147

5248
The following sections provide detailed instructions for each step.
@@ -70,31 +66,23 @@ if you want to build `seed-env` from source.
7066

7167
## Step 1: Modify base requirements
7268

73-
Update the desired dependencies in
74-
`src/dependencies/requirements/base_requirements/requirements.txt` or the
75-
hardware-specific files
76-
(`src/dependencies/requirements/base_requirements/tpu-base-requirements.txt`,
77-
`src/dependencies/requirements/base_requirements/gpu-base-requirements.txt`).
69+
Update the desired dependencies in `src/dependencies/requirements/base_requirements/requirements.txt` or the hardware-specific pre-training files (`src/dependencies/requirements/base_requirements/tpu-requirements.txt`, `src/dependencies/requirements/base_requirements/cuda12-requirements.txt`).
7870

7971
## Step 2: Find the JAX build commit hash
8072

81-
The dependency generation process is pinned to a specific nightly build of JAX.
82-
You need to find the commit hash for the desired JAX build.
83-
84-
You can find the latest commit hashes in the
85-
[JAX `build/` folder](https://github.com/jax-ml/jax/commits/main/build). Choose
86-
a recent, successful build and copy its full commit hash.
73+
The dependency generation process is pinned to a specific nightly build of JAX. You need to find the commit hash for the desired JAX build from [JAX `build/` folder](https://github.com/jax-ml/jax/commits/main/build) and copy its full commit hash.
8774

8875
## Step 3: Generate the requirements files
8976

9077
Next, run the `seed-env` CLI to generate the new requirements files. You will
9178
need to do this separately for the TPU and GPU environments. The generated files
9279
will be placed in a directory specified by `--output-dir`.
9380

94-
### For TPU
81+
> **Note:** The current `src/dependencies/requirements/generated_requirements/` in the repository were generated using JAX build commit hash: [e0d2967b50abbefd651d563dbcd7afbcb963d08c](https://github.com/jax-ml/jax/commit/e0d2967b50abbefd651d563dbcd7afbcb963d08c).
82+
83+
### TPU Pre-Training
9584

96-
Run the following command, replacing `<jax-build-commit-hash>` with the hash you
97-
copied in the previous step.
85+
If you have made changes to TPU pre-training dependencies in `src/dependencies/requirements/base_requirements/tpu-requirements.txt`, you need to regenerate the pinned pre-training requirements in `generated_requirements/` directory. Run the following command, replacing `<jax-build-commit-hash>` with the hash you copied in the previous step:
9886

9987
```bash
10088
seed-env \
@@ -104,45 +92,34 @@ seed-env \
10492
--python-version=3.12 \
10593
--requirements-txt=tpu-requirements.txt \
10694
--output-dir=generated_tpu_artifacts
95+
96+
# Copy generated requirements to src/dependencies/requirements/generated_requirements
97+
mv generated_tpu_artifacts/tpu-requirements.txt src/dependencies/requirements/generated_requirements/tpu-requirements.txt
10798
```
10899

109-
### For GPU
100+
### GPU Pre-Training
110101

111-
Similarly, run the command for the GPU requirements.
102+
If you have made changes to the GPU pre-training dependencies in `src/dependencies/requirements/base_requirements/cuda12-requirements.txt`, you need to regenerate the pinned pre-training requirements in `generated_requirements/` directory. Run the following command, replacing `<jax-build-commit-hash>` with the hash you copied in the previous step:
112103

113104
```bash
114105
seed-env \
115-
--local-requirements=src/dependencies/requirements/base_requirements/gpu-base-requirements.txt \
106+
--local-requirements=src/dependencies/requirements/base_requirements/cuda12-requirements.txt \
116107
--host-name=MaxText \
117108
--seed-commit=<jax-build-commit-hash> \
118109
--python-version=3.12 \
119110
--requirements-txt=cuda12-requirements.txt \
120111
--hardware=cuda12 \
121112
--output-dir=generated_gpu_artifacts
122-
```
123-
124-
## Step 4: Update project files
125113

126-
After generating the new requirements, you need to update the files in the
127-
MaxText repository.
128-
129-
1. **Copy the generated files:**
130-
131-
- Move `generated_tpu_artifacts/tpu-requirements.txt` to `generated_requirements/tpu-requirements.txt`.
132-
- Move `generated_gpu_artifacts/cuda12-requirements.txt` to `generated_requirements/cuda12-requirements.txt`.
133-
134-
2. **Update `src/dependencies/extra_deps` (if necessary):**
135-
Currently, MaxText uses a few dependencies, such as `mlperf-logging` and
136-
`google-jetstream`, that are installed directly from GitHub source. These are
137-
defined in `base_requirements/requirements.txt`, and the `seed-env` tool will
138-
carry them over to the generated requirements files.
114+
# Copy generated requirements to src/dependencies/requirements/generated_requirements
115+
mv generated_gpu_artifacts/cuda12-requirements.txt src/dependencies/requirements/generated_requirements/cuda12-requirements.txt
116+
```
139117

140-
## Step 5: Verify the new dependencies
118+
## Step 4: Verify the new dependencies
141119

142120
Finally, test that the new dependencies install correctly and that MaxText runs
143121
as expected.
144122

145-
1. **Install MaxText:** Follow the instructions to
146-
[install MaxText from source](install-from-source).
123+
1. **Install MaxText and dependencies**: For instructions on installing MaxText on your VM, please refer to the [official documentation](https://maxtext.readthedocs.io/en/latest/install_maxtext.html#from-source).
147124

148125
2. **Run tests:** Run MaxText tests to ensure there are no regressions.

src/dependencies/requirements/base_requirements/gpu-base-requirements.txt renamed to src/dependencies/requirements/base_requirements/cuda12-requirements.txt

File renamed without changes.

src/dependencies/requirements/base_requirements/requirements.txt

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
absl-py
22
aqtp
33
array-record
4+
chex
45
cloud-accelerator-diagnostics
5-
cloud-tpu-diagnostics
6+
cloud-tpu-diagnostics!=1.1.14
67
datasets
78
drjax
89
flax
@@ -24,6 +25,7 @@ numpy
2425
omegaconf
2526
optax
2627
orbax-checkpoint
28+
parameterized
2729
pathwaysutils
2830
pillow
2931
pre-commit
@@ -34,15 +36,14 @@ pylint
3436
pytest
3537
pytype
3638
sentencepiece
39+
seqio
3740
tensorboard-plugin-profile
3841
tensorboardx
3942
tensorflow-datasets
4043
tensorflow-text
4144
tensorflow
4245
tiktoken
43-
tokamax
46+
tokamax!=0.1.0
4447
transformers
4548
uvloop
4649
qwix
47-
google-jetstream @ https://github.com/AI-Hypercomputer/JetStream/archive/29329e8e73820993f77cfc8efe34eb2a73f5de98.zip
48-
mlperf-logging @ https://github.com/mlcommons/logging/archive/38ab22670527888c8eb7825a4ece176fcc36a95d.zip

src/dependencies/requirements/base_requirements/tpu-base-requirements.txt renamed to src/dependencies/requirements/base_requirements/tpu-requirements.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1 @@
11
-r requirements.txt
2-
google-tunix

0 commit comments

Comments
 (0)