diff --git a/.github/workflows/build_and_test_maxtext.yml b/.github/workflows/build_and_test_maxtext.yml index 13efbcb3f8..d54f98fbec 100644 --- a/.github/workflows/build_and_test_maxtext.yml +++ b/.github/workflows/build_and_test_maxtext.yml @@ -231,9 +231,8 @@ jobs: steps: - name: Check test results run: | - # If doc-only, all tests should be skipped - if [ "${NEEDS_DOC_ONLY_CHECK_OUTPUTS_RUN_TESTS}" == "false" ]; then - echo "Documentation-only changes detected, tests were skipped" + if [ "${NEEDS_ANALYZE_CHANGES_OUTPUTS_RUN_TESTS}" == "false" ]; then + echo "Tests were skipped" exit 0 fi diff --git a/.github/workflows/run_pathways_tests.yml b/.github/workflows/run_pathways_tests.yml index b6776d7ddd..60b2dd3d93 100644 --- a/.github/workflows/run_pathways_tests.yml +++ b/.github/workflows/run_pathways_tests.yml @@ -107,32 +107,32 @@ jobs: PYTHONPATH: "${{ github.workspace }}/src" services: resource_manager: - image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:latest + image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:20260413-jax_0.9.2 ports: - "29001:29001" - "29002:29002" options: - --entrypoint=[/usr/pathways/run/cloud_pathways_server_sanitized, --server_port=29001, --node_type=resource_manager, --instance_count=1, --instance_type=tpuv6e:2x2, --gcs_scratch_location=gs://cloud-pathways-staging/tmp] + --entrypoint=[/usr/pathways/run/cloud_pathways_server_sanitized, --server_port=29001, --node_type=resource_manager, --enforce_kernel_ipv6_support=false, --instance_count=1, --instance_type=tpuv6e:2x2, --gcs_scratch_location=gs://cloud-pathways-staging/tmp] env: TPU_SKIP_MDS_QUERY: true worker: - image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:latest + image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:20260413-jax_0.9.2 ports: - "29005:29005" - "29006:29006" - "8471:8471" - "8080:8080" options: - --entrypoint=[/usr/pathways/run/cloud_pathways_server_sanitized, --server_port=29005, --resource_manager_address=localhost:29001, --gcs_scratch_location=gs://cloud-pathways-staging/tmp] + --entrypoint=[/usr/pathways/run/cloud_pathways_server_sanitized, --server_port=29005, --resource_manager_address=localhost:29001, --enforce_kernel_ipv6_support=false, --gcs_scratch_location=gs://cloud-pathways-staging/tmp] --tpu=4 proxy: - image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:latest + image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:20260413-jax_0.9.2 ports: - "29000:29000" env: IFRT_PROXY_USE_INSECURE_GRPC_CREDENTIALS: true XLA_FLAGS: "--xla_dump_to=/tmp/aot_test_dump --xla_dump_hlo_as_text --xla_dump_hlo_module_re=jit_train_step" options: - --entrypoint=[/usr/pathways/run/cloud_proxy_server_sanitized, --server_port=29000, --resource_manager_address=localhost:29001, --gcs_scratch_location=gs://cloud-pathways-staging/tmp, --xla_tpu_scoped_vmem_limit_kib=65536, --xla_tpu_spmd_rng_bit_generator_unsafe=true] + --entrypoint=[/usr/pathways/run/cloud_proxy_server_sanitized, --server_port=29000, --resource_manager_address=localhost:29001, --enforce_kernel_ipv6_support=false, --gcs_scratch_location=gs://cloud-pathways-staging/tmp, --xla_tpu_scoped_vmem_limit_kib=65536, --xla_tpu_spmd_rng_bit_generator_unsafe=true] diff --git a/.github/workflows/run_tests_against_package.yml b/.github/workflows/run_tests_against_package.yml index 4520b7580c..7817e2d678 100644 --- a/.github/workflows/run_tests_against_package.yml +++ b/.github/workflows/run_tests_against_package.yml @@ -152,6 +152,14 @@ jobs: # omit this libtpu init args for gpu tests if [ "${INPUTS_DEVICE_TYPE}" != "cuda12" ]; then export LIBTPU_INIT_ARGS='--xla_tpu_scoped_vmem_limit_kib=65536' + else + # For cuda12, explicitly point to the pip-installed CUDA libraries + # to avoid conflicts with system-level installations on the runner. + if [ -d ".venv/lib/python3.12/site-packages/nvidia" ]; then + export LD_LIBRARY_PATH=$(pwd)/.venv/lib/python3.12/site-packages/nvidia/cudnn/lib:${LD_LIBRARY_PATH} + else + echo "Warning: Could not find pinned nvidia libraries in .venv." + fi fi if [ "${INPUTS_TOTAL_WORKERS}" -gt 1 ]; then $PYTHON_EXE -m pip install --quiet pytest-split pytest-xdist diff --git a/docs/development/update_dependencies.md b/docs/development/update_dependencies.md index 4872ef788f..b9244bef69 100644 --- a/docs/development/update_dependencies.md +++ b/docs/development/update_dependencies.md @@ -34,19 +34,15 @@ to keep dependencies in sync for users installing MaxText from source. To update dependencies, you will follow these general steps: 1. **Modify base requirements**: Update the desired dependencies in - `src/dependencies/requirements/base_requirements/requirements.txt` or the hardware-specific files - (`src/dependencies/requirements/base_requirements/tpu-base-requirements.txt`, - `src/dependencies/requirements/base_requirements/gpu-base-requirements.txt`). + `src/dependencies/requirements/base_requirements/requirements.txt` or the hardware-specific pre-training files + (`src/dependencies/requirements/base_requirements/tpu-requirements.txt`, + `src/dependencies/requirements/base_requirements/cuda12-requirements.txt`). 2. **Find the JAX build commit hash**: The dependency generation process is pinned to a specific nightly build of JAX. You need to find the commit hash for the desired JAX build. 3. **Generate the requirement files**: Run the `seed-env` CLI tool to generate new, fully-pinned requirements files based on your changes. -4. **Update project files**: Copy the newly generated files into the - `src/dependencies/requirements/generated_requirements/` directory. If - necessary, also update any dependencies that are installed directly from - GitHub from the generated files to `src/dependencies/extra_deps`. -5. **Verify the new dependencies**: Test the new dependencies to ensure the +4. **Verify the new dependencies**: Test the new dependencies to ensure the project installs and runs correctly. The following sections provide detailed instructions for each step. @@ -70,20 +66,11 @@ if you want to build `seed-env` from source. ## Step 1: Modify base requirements -Update the desired dependencies in -`src/dependencies/requirements/base_requirements/requirements.txt` or the -hardware-specific files -(`src/dependencies/requirements/base_requirements/tpu-base-requirements.txt`, -`src/dependencies/requirements/base_requirements/gpu-base-requirements.txt`). +Update the desired dependencies in `src/dependencies/requirements/base_requirements/requirements.txt` or the hardware-specific pre-training files (`src/dependencies/requirements/base_requirements/tpu-requirements.txt`, `src/dependencies/requirements/base_requirements/cuda12-requirements.txt`). ## Step 2: Find the JAX build commit hash -The dependency generation process is pinned to a specific nightly build of JAX. -You need to find the commit hash for the desired JAX build. - -You can find the latest commit hashes in the -[JAX `build/` folder](https://github.com/jax-ml/jax/commits/main/build). Choose -a recent, successful build and copy its full commit hash. +The dependency generation process is pinned to a specific nightly build of JAX. You need to find the commit hash for the desired JAX build from [JAX `build/` folder](https://github.com/jax-ml/jax/commits/main/build) and copy its full commit hash. ## Step 3: Generate the requirements files @@ -91,10 +78,11 @@ Next, run the `seed-env` CLI to generate the new requirements files. You will need to do this separately for the TPU and GPU environments. The generated files will be placed in a directory specified by `--output-dir`. -### For TPU +> **Note:** The current `src/dependencies/requirements/generated_requirements/` in the repository were generated using JAX build commit hash: [e0d2967b50abbefd651d563dbcd7afbcb963d08c](https://github.com/jax-ml/jax/commit/e0d2967b50abbefd651d563dbcd7afbcb963d08c). + +### TPU Pre-Training -Run the following command, replacing `` with the hash you -copied in the previous step. +If you have made changes to TPU pre-training dependencies in `src/dependencies/requirements/base_requirements/tpu-requirements.txt`, you need to regenerate the pinned pre-training requirements in `generated_requirements/` directory. Run the following command, replacing `` with the hash you copied in the previous step: ```bash seed-env \ @@ -104,45 +92,36 @@ seed-env \ --python-version=3.12 \ --requirements-txt=tpu-requirements.txt \ --output-dir=generated_tpu_artifacts + +# Copy generated requirements to src/dependencies/requirements/generated_requirements +mv generated_tpu_artifacts/tpu-requirements.txt \ + src/dependencies/requirements/generated_requirements/tpu-requirements.txt ``` -### For GPU +### GPU Pre-Training -Similarly, run the command for the GPU requirements. +If you have made changes to the GPU pre-training dependencies in `src/dependencies/requirements/base_requirements/cuda12-requirements.txt`, you need to regenerate the pinned pre-training requirements in `generated_requirements/` directory. Run the following command, replacing `` with the hash you copied in the previous step: ```bash seed-env \ - --local-requirements=src/dependencies/requirements/base_requirements/gpu-base-requirements.txt \ + --local-requirements=src/dependencies/requirements/base_requirements/cuda12-requirements.txt \ --host-name=MaxText \ --seed-commit= \ --python-version=3.12 \ --requirements-txt=cuda12-requirements.txt \ --hardware=cuda12 \ --output-dir=generated_gpu_artifacts -``` - -## Step 4: Update project files -After generating the new requirements, you need to update the files in the -MaxText repository. - -1. **Copy the generated files:** - - - Move `generated_tpu_artifacts/tpu-requirements.txt` to `generated_requirements/tpu-requirements.txt`. - - Move `generated_gpu_artifacts/cuda12-requirements.txt` to `generated_requirements/cuda12-requirements.txt`. - -2. **Update `src/dependencies/extra_deps` (if necessary):** - Currently, MaxText uses a few dependencies, such as `mlperf-logging` and - `google-jetstream`, that are installed directly from GitHub source. These are - defined in `base_requirements/requirements.txt`, and the `seed-env` tool will - carry them over to the generated requirements files. +# Copy generated requirements to src/dependencies/requirements/generated_requirements +mv generated_gpu_artifacts/cuda12-requirements.txt \ + src/dependencies/requirements/generated_requirements/cuda12-requirements.txt +``` -## Step 5: Verify the new dependencies +## Step 4: Verify the new dependencies Finally, test that the new dependencies install correctly and that MaxText runs as expected. -1. **Install MaxText:** Follow the instructions to - [install MaxText from source](install-from-source). +1. **Install MaxText and dependencies**: For instructions on installing MaxText on your VM, please refer to the [official documentation](https://maxtext.readthedocs.io/en/latest/install_maxtext.html#from-source). 2. **Run tests:** Run MaxText tests to ensure there are no regressions. diff --git a/src/dependencies/requirements/base_requirements/gpu-base-requirements.txt b/src/dependencies/requirements/base_requirements/cuda12-requirements.txt similarity index 100% rename from src/dependencies/requirements/base_requirements/gpu-base-requirements.txt rename to src/dependencies/requirements/base_requirements/cuda12-requirements.txt diff --git a/src/dependencies/requirements/base_requirements/requirements.txt b/src/dependencies/requirements/base_requirements/requirements.txt index 61dc53e61d..52b68dd289 100644 --- a/src/dependencies/requirements/base_requirements/requirements.txt +++ b/src/dependencies/requirements/base_requirements/requirements.txt @@ -1,8 +1,9 @@ absl-py aqtp array-record +chex cloud-accelerator-diagnostics -cloud-tpu-diagnostics +cloud-tpu-diagnostics!=1.1.14 datasets drjax flax @@ -24,6 +25,7 @@ numpy omegaconf optax orbax-checkpoint +parameterized pathwaysutils pillow pre-commit @@ -34,15 +36,14 @@ pylint pytest pytype sentencepiece +seqio tensorboard-plugin-profile tensorboardx tensorflow-datasets tensorflow-text tensorflow tiktoken -tokamax +tokamax!=0.1.0 transformers uvloop qwix -google-jetstream @ https://github.com/AI-Hypercomputer/JetStream/archive/29329e8e73820993f77cfc8efe34eb2a73f5de98.zip -mlperf-logging @ https://github.com/mlcommons/logging/archive/38ab22670527888c8eb7825a4ece176fcc36a95d.zip diff --git a/src/dependencies/requirements/base_requirements/tpu-base-requirements.txt b/src/dependencies/requirements/base_requirements/tpu-requirements.txt similarity index 60% rename from src/dependencies/requirements/base_requirements/tpu-base-requirements.txt rename to src/dependencies/requirements/base_requirements/tpu-requirements.txt index 3771eb9143..bc04b4960a 100644 --- a/src/dependencies/requirements/base_requirements/tpu-base-requirements.txt +++ b/src/dependencies/requirements/base_requirements/tpu-requirements.txt @@ -1,2 +1 @@ -r requirements.txt -google-tunix diff --git a/src/dependencies/requirements/generated_requirements/cuda12-requirements.txt b/src/dependencies/requirements/generated_requirements/cuda12-requirements.txt index 364465d935..01eb07b9a2 100644 --- a/src/dependencies/requirements/generated_requirements/cuda12-requirements.txt +++ b/src/dependencies/requirements/generated_requirements/cuda12-requirements.txt @@ -1,259 +1,258 @@ # Generated by seed-env. Do not edit manually. # If you need to modify dependencies, please do so in the host requirements file and run seed-env again. -absl-py>=2.3.1 +absl-py>=2.4.0 aiofiles>=25.1.0 aiohappyeyeballs>=2.6.1 -aiohttp>=3.13.2 +aiohttp>=3.13.5 aiosignal>=1.4.0 annotated-doc>=0.0.4 annotated-types>=0.7.0 antlr4-python3-runtime>=4.9.3 -anyio>=4.11.0 +anyio>=4.13.0 aqtp>=0.9.0 array-record>=0.8.3 -astroid>=4.0.2 +astroid>=4.0.4 astunparse>=1.6.3 attrs>=25.4.0 -auditwheel>=6.5.0 -black>=24.10.0 -blobfile>=3.1.0 -build>=1.3.0 -cachetools>=6.2.2 -certifi>=2025.11.12 +auditwheel>=6.6.0 +black>=25.12.0 +build>=1.4.0 +certifi>=2026.2.25 +cffi>=2.0.0 ; platform_python_implementation != 'PyPy' cfgv>=3.5.0 -charset-normalizer>=3.4.4 -cheroot>=11.1.2 +charset-normalizer>=3.4.6 chex>=0.1.91 -click>=8.3.1 +click>=8.3.2 cloud-accelerator-diagnostics>=0.1.1 cloud-tpu-diagnostics>=0.1.5 cloudpickle>=3.1.2 clu>=0.0.12 colorama>=0.4.6 contourpy>=1.3.3 -coverage>=7.12.0 +cryptography>=46.0.7 cycler>=0.12.1 -datasets>=4.4.1 +dataclasses-json>=0.6.7 +datasets>=4.8.4 decorator>=5.2.1 -dill>=0.4.0 +deprecated>=1.3.1 +dill>=0.4.1 distlib>=0.4.0 -dm-tree>=0.1.9 +distro>=1.9.0 +dm-tree>=0.1.10 docstring-parser>=0.17.0 drjax>=0.1.4 editdistance>=0.8.1 -einops>=0.8.1 +einops>=0.8.2 einshape>=1.0 -etils>=1.13.0 -evaluate>=0.4.6 +etils>=1.14.0 execnet>=2.1.2 -fastapi>=0.122.0 -filelock>=3.20.0 -flatbuffers>=25.9.23 -flax>=0.12.1 -fonttools>=4.60.1 +fastapi>=0.135.3 +filelock>=3.20.3 +flatbuffers>=25.12.19 +flax>=0.12.6 +fonttools>=4.62.1 frozenlist>=1.8.0 -fsspec>=2025.10.0 -gast>=0.6.0 -gcsfs>=2025.10.0 -google-api-core>=2.28.1 -google-api-python-client>=2.187.0 -google-auth-httplib2>=0.2.1 -google-auth-oauthlib>=1.2.2 -google-auth>=2.43.0 -google-cloud-aiplatform>=1.128.0 -google-cloud-appengine-logging>=1.7.0 -google-cloud-audit-log>=0.4.0 -google-cloud-bigquery>=3.38.0 -google-cloud-core>=2.5.0 -google-cloud-logging>=3.12.1 -google-cloud-mldiagnostics>=0.5.10 -google-cloud-monitoring>=2.28.0 -google-cloud-resource-manager>=1.15.0 -google-cloud-storage>=3.6.0 -google-crc32c>=1.7.1 -google-genai>=1.52.0 +fsspec>=2026.2.0 +gast>=0.7.0 +gcsfs>=2026.2.0 +google-api-core>=2.30.3 +google-api-python-client>=2.194.0 +google-auth-httplib2>=0.3.1 +google-auth-oauthlib>=1.3.1 +google-auth>=2.49.2 +google-cloud-aiplatform>=1.147.0 +google-cloud-appengine-logging>=1.9.0 +google-cloud-audit-log>=0.5.0 +google-cloud-bigquery>=3.41.0 +google-cloud-core>=2.5.1 +google-cloud-logging>=3.15.0 +google-cloud-mldiagnostics>=1.0.2 +google-cloud-monitoring>=2.30.0 +google-cloud-resource-manager>=1.17.0 +google-cloud-storage-control>=1.11.0 +google-cloud-storage>=3.10.1 +google-crc32c>=1.8.0 +google-genai>=1.72.0 google-pasta>=0.2.0 -google-resumable-media>=2.8.0 -googleapis-common-protos>=1.72.0 -grain>=0.2.15 -grpc-google-iam-v1>=0.14.3 -grpcio-status>=1.71.2 -grpcio>=1.76.0 +google-resumable-media>=2.8.2 +googleapis-common-protos>=1.74.0 +grain>=0.2.16 +grpc-google-iam-v1>=0.14.4 +grpcio-status>=1.78.0 +grpcio>=1.78.0 gviz-api>=1.10.0 h11>=0.16.0 -h5py>=3.15.1 -hf-xet>=1.2.0 ; platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64' +h5py>=3.14.0 +hf-xet>=1.4.3 ; platform_machine == 'AMD64' or platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64' httpcore>=1.0.9 -httplib2>=0.31.0 +httplib2>=0.31.2 httpx>=0.28.1 -huggingface-hub>=0.36.0 -humanize>=4.14.0 +huggingface-hub>=1.10.1 +humanize>=4.15.0 hypothesis>=6.142.1 -identify>=2.6.15 +identify>=2.6.18 idna>=3.11 -immutabledict>=4.2.2 +immutabledict>=4.3.1 importlab>=0.8.1 -importlib-metadata>=8.7.0 -importlib-resources>=6.5.2 +importlib-metadata>=9.0.0 iniconfig>=2.3.0 -isort>=7.0.0 -jaraco-functools>=4.3.0 -jax-cuda12-pjrt>=0.8.1 ; sys_platform == 'linux' -jax-cuda12-plugin>=0.8.1 ; sys_platform == 'linux' -jax>=0.8.1 -jaxlib>=0.8.1 -jaxtyping>=0.3.3 +isort>=8.0.1 +jax-cuda12-pjrt>=0.9.2 ; sys_platform == 'linux' +jax-cuda12-plugin>=0.9.2 ; sys_platform == 'linux' +jax>=0.9.2 +jaxlib>=0.9.2 +jaxtyping>=0.3.9 jinja2>=3.1.6 -joblib>=1.5.2 jsonlines>=4.0.0 -keras>=3.12.0 -kiwisolver>=1.4.9 +keras>=3.13.2 +kiwisolver>=1.5.0 +latex2sympy2-extended>=1.11.0 libclang>=18.1.1 libcst>=1.8.6 -lxml>=6.0.2 markdown-it-py>=4.0.0 -markdown>=3.10 +markdown>=3.10.2 markupsafe>=3.0.3 -matplotlib>=3.10.7 +marshmallow>=3.26.2 +math-verify>=0.9.0 +matplotlib>=3.10.8 mccabe>=0.7.0 mdurl>=0.1.2 ml-collections>=1.1.0 ml-dtypes>=0.5.4 -ml-goodput-measurement>=0.0.15 -more-itertools>=10.8.0 +ml-goodput-measurement>=0.0.16 mpmath>=1.3.0 msgpack>=1.1.2 -msgspec>=0.20.0 -multidict>=6.7.0 -multiprocess>=0.70.18 +msgspec>=0.21.1 +multidict>=6.7.1 +multiprocess>=0.70.19 mypy-extensions>=1.1.0 namex>=0.1.0 -nest-asyncio>=1.6.0 -networkx>=3.6 +nest-asyncio>=1.6.0 ; sys_platform == 'win32' +networkx>=3.6.1 ninja>=1.13.0 -nltk>=3.9.2 -nodeenv>=1.9.1 -numpy-typing-compat>=20250818.2.0 +nodeenv>=1.10.0 +numpy-typing-compat>=20251206.2.0 numpy>=2.0.2 nvidia-cublas-cu12>=12.9.1.4 ; sys_platform == 'linux' +nvidia-cuda-cccl-cu12>=12.9.27 +nvidia-cuda-cccl>=13.2.27 nvidia-cuda-cupti-cu12>=12.9.79 ; sys_platform == 'linux' nvidia-cuda-nvcc-cu12>=12.9.86 ; sys_platform == 'linux' nvidia-cuda-nvrtc-cu12>=12.9.86 ; sys_platform == 'linux' nvidia-cuda-runtime-cu12>=12.9.79 ; sys_platform == 'linux' -nvidia-cudnn-cu12>=9.16.0.29 ; sys_platform == 'linux' +nvidia-cudnn-cu12>=9.20.0.48 ; sys_platform == 'linux' nvidia-cufft-cu12>=11.4.1.4 ; sys_platform == 'linux' nvidia-cusolver-cu12>=11.7.5.82 ; sys_platform == 'linux' nvidia-cusparse-cu12>=12.5.10.65 ; sys_platform == 'linux' -nvidia-nccl-cu12>=2.28.9 ; sys_platform == 'linux' +nvidia-nccl-cu12>=2.29.7 ; sys_platform == 'linux' nvidia-nvjitlink-cu12>=12.9.86 ; sys_platform == 'linux' -nvidia-nvshmem-cu12>=3.4.5 ; sys_platform == 'linux' +nvidia-nvshmem-cu12>=3.5.21 ; sys_platform == 'linux' oauthlib>=3.3.1 omegaconf>=2.3.0 -opentelemetry-api>=1.38.0 +opentelemetry-api>=1.16.0 opt-einsum>=3.4.0 -optax>=0.2.6 -optree>=0.18.0 -optype>=0.14.0 -orbax-checkpoint>=0.11.33 -packaging>=25.0 -pandas>=2.3.3 +optax>=0.2.8 +optree>=0.19.0 +optype>=0.17.0 +orbax-checkpoint>=0.11.34 +orbax-export>=0.0.8 +packaging>=26.0 +pandas>=3.0.2 parameterized>=0.9.0 -pathspec>=0.12.1 -pathwaysutils>=0.1.3 -pillow>=12.0.0 -platformdirs>=4.5.0 +pathspec>=1.0.4 +pathwaysutils>=0.1.7 +pillow>=12.1.1 +platformdirs>=4.9.6 pluggy>=1.6.0 portpicker>=1.6.0 -pre-commit>=4.5.0 -prometheus-client>=0.23.1 +pre-commit>=4.5.1 promise>=2.3 propcache>=0.4.1 -proto-plus>=1.26.1 -protobuf>=5.29.5 -psutil>=7.1.3 -pyarrow>=22.0.0 +proto-plus>=1.27.2 +protobuf>=6.33.6 +psutil>=7.2.2 +pyarrow>=23.0.1 pyasn1-modules>=0.4.2 -pyasn1>=0.6.1 +pyasn1>=0.6.3 pycnite>=2024.7.31 -pycryptodomex>=3.23.0 -pydantic-core>=2.41.5 -pydantic>=2.12.5 +pycparser>=3.0 ; implementation_name != 'PyPy' and platform_python_implementation != 'PyPy' +pydantic-core>=2.46.0 +pydantic>=2.13.0 pydot>=4.0.1 pyelftools>=0.32 pyglove>=0.4.5 pygments>=2.19.2 -pyink>=24.10.1 -pylint>=4.0.3 -pyparsing>=3.2.5 +pyink>=25.12.0 +pylint>=4.0.5 +pyparsing>=3.3.2 pyproject-hooks>=1.2.0 pytest-xdist>=3.8.0 pytest>=8.4.2 python-dateutil>=2.9.0.post0 +pytokens>=0.4.1 pytype>=2024.10.11 -pytz>=2025.2 pyyaml>=6.0.3 -qwix>=0.1.4 -regex>=2025.11.3 +qwix>=0.1.5 +regex>=2026.4.4 requests-oauthlib>=2.0.0 requests>=2.32.5 -rich>=14.2.0 -rsa>=4.9.1 +rich>=14.3.3 safetensors>=0.7.0 -scipy-stubs>=1.16.3.0 -scipy>=1.16.3 +scipy-stubs>=1.17.1.2 +scipy>=1.17.1 sentencepiece>=0.2.1 seqio>=0.0.20 -setuptools>=80.9.0 -shapely>=2.1.2 -shortuuid>=1.0.13 -simple-parsing>=0.1.7 +setuptools>=82.0.1 +shellingham>=1.5.4 +simple-parsing>=0.1.8 simplejson>=3.20.2 six>=1.17.0 sniffio>=1.3.1 sortedcontainers>=2.4.0 -starlette>=0.50.0 +starlette>=1.0.0 sympy>=1.14.0 -tabulate>=0.9.0 -tenacity>=9.1.2 +tabulate>=0.10.0 +tenacity>=9.1.4 tensorboard-data-server>=0.7.2 tensorboard-plugin-profile>=2.13.0 -tensorboard>=2.19.0 -tensorboardx>=2.6.4 +tensorboard>=2.20.0 +tensorboardx>=2.6.5 tensorflow-datasets>=4.9.9 -tensorflow-metadata>=1.17.2 -tensorflow-text>=2.19.0 -tensorflow>=2.19.1 -tensorstore>=0.1.79 -termcolor>=3.2.0 +tensorflow-metadata>=1.17.3 +tensorflow-text>=2.20.1 +tensorflow>=2.20.0 +tensorstore>=0.1.82 +termcolor>=3.3.0 tiktoken>=0.12.0 -tokamax>=0.0.8 -tokenizers>=0.22.1 +tokamax>=0.0.12 +tokenizers>=0.22.2 toml>=0.10.2 -tomlkit>=0.13.3 +tomlkit>=0.14.0 toolz>=1.1.0 -tqdm>=4.67.1 -transformer-engine-cu12>=2.9.0 -transformer-engine-jax>=2.9.0 -transformer-engine>=2.9.0 -transformers>=4.57.3 +tqdm>=4.67.3 +transformer-engine-cu12>=2.13.0 +transformer-engine-jax>=2.13.0 +transformer-engine>=2.13.0 +transformers>=5.5.4 treescope>=0.1.10 typeguard>=2.13.3 +typer>=0.24.1 typing-extensions>=4.15.0 +typing-inspect>=0.9.0 typing-inspection>=0.4.2 -tzdata>=2025.2 +tzdata>=2026.1 ; sys_platform == 'emscripten' or sys_platform == 'win32' uritemplate>=4.2.0 -urllib3>=2.5.0 -uvicorn>=0.38.0 -uvloop>=0.19.0 -virtualenv>=20.35.4 +urllib3>=2.6.3 +uvicorn>=0.44.0 +uvloop>=0.22.1 +virtualenv>=20.36.1 wadler-lindig>=0.1.7 -websockets>=15.0.1 -werkzeug>=3.1.3 -wheel>=0.45.1 -wrapt>=2.0.1 -xprof>=2.21.1 +websockets>=16.0 +werkzeug>=3.1.8 +wheel>=0.46.3 +wrapt>=2.1.2 xxhash>=3.6.0 -yarl>=1.22.0 +yarl>=1.23.0 zipp>=3.23.0 -zstandard>=0.25.0 +zstandard>=0.25.0 \ No newline at end of file diff --git a/src/dependencies/requirements/generated_requirements/tpu-requirements.txt b/src/dependencies/requirements/generated_requirements/tpu-requirements.txt index 08da4a3ab7..d2352b51a8 100644 --- a/src/dependencies/requirements/generated_requirements/tpu-requirements.txt +++ b/src/dependencies/requirements/generated_requirements/tpu-requirements.txt @@ -1,252 +1,240 @@ # Generated by seed-env. Do not edit manually. # If you need to modify dependencies, please do so in the host requirements file and run seed-env again. -absl-py>=2.3.1 +absl-py>=2.4.0 aiofiles>=25.1.0 aiohappyeyeballs>=2.6.1 -aiohttp>=3.13.2 +aiohttp>=3.13.5 aiosignal>=1.4.0 annotated-doc>=0.0.4 annotated-types>=0.7.0 antlr4-python3-runtime>=4.9.3 -anyio>=4.11.0 +anyio>=4.13.0 aqtp>=0.9.0 array-record>=0.8.3 -astroid>=4.0.2 +astroid>=4.0.4 astunparse>=1.6.3 attrs>=25.4.0 -auditwheel>=6.5.0 -black>=24.10.0 -blobfile>=3.1.0 -build>=1.3.0 -cachetools>=6.2.2 -certifi>=2025.11.12 +auditwheel>=6.6.0 +black>=25.12.0 +build>=1.4.0 +certifi>=2026.2.25 +cffi>=2.0.0 ; platform_python_implementation != 'PyPy' cfgv>=3.5.0 -charset-normalizer>=3.4.4 -cheroot>=11.1.2 +charset-normalizer>=3.4.6 chex>=0.1.91 -click>=8.3.1 +click>=8.3.2 cloud-accelerator-diagnostics>=0.1.1 cloud-tpu-diagnostics>=0.1.5 cloudpickle>=3.1.2 clu>=0.0.12 colorama>=0.4.6 contourpy>=1.3.3 -coverage>=7.12.0 +cryptography>=46.0.7 cycler>=0.12.1 -dacite>=1.9.2 -datasets>=4.4.1 +dataclasses-json>=0.6.7 +datasets>=4.8.4 decorator>=5.2.1 -dill>=0.4.0 +dill>=0.4.1 distlib>=0.4.0 -dm-tree>=0.1.9 +distro>=1.9.0 +dm-tree>=0.1.10 docstring-parser>=0.17.0 drjax>=0.1.4 editdistance>=0.8.1 -einops>=0.8.1 +einops>=0.8.2 einshape>=1.0 -etils>=1.13.0 -evaluate>=0.4.6 +etils>=1.14.0 execnet>=2.1.2 -fastapi>=0.122.0 -filelock>=3.20.0 -flatbuffers>=25.9.23 +fastapi>=0.135.3 +filelock>=3.20.3 +flatbuffers>=25.12.19 flax>=0.12.6 -fonttools>=4.60.1 +fonttools>=4.62.1 frozenlist>=1.8.0 -fsspec>=2025.10.0 -gast>=0.6.0 -gcsfs>=2025.10.0 -google-api-core>=2.28.1 -google-api-python-client>=2.187.0 -google-auth-httplib2>=0.2.1 -google-auth-oauthlib>=1.2.2 -google-auth>=2.43.0 -google-cloud-aiplatform>=1.128.0 -google-cloud-appengine-logging>=1.7.0 -google-cloud-audit-log>=0.4.0 -google-cloud-bigquery>=3.38.0 -google-cloud-core>=2.5.0 -google-cloud-logging>=3.12.1 -google-cloud-mldiagnostics>=0.5.10 -google-cloud-monitoring>=2.28.0 -google-cloud-resource-manager>=1.15.0 -google-cloud-storage>=3.6.0 -google-crc32c>=1.7.1 -google-genai>=1.52.0 +fsspec>=2026.2.0 +gast>=0.7.0 +gcsfs>=2026.2.0 +google-api-core>=2.30.3 +google-api-python-client>=2.194.0 +google-auth-httplib2>=0.3.1 +google-auth-oauthlib>=1.3.1 +google-auth>=2.49.2 +google-cloud-aiplatform>=1.147.0 +google-cloud-appengine-logging>=1.9.0 +google-cloud-audit-log>=0.5.0 +google-cloud-bigquery>=3.41.0 +google-cloud-core>=2.5.1 +google-cloud-logging>=3.15.0 +google-cloud-mldiagnostics>=1.0.2 +google-cloud-monitoring>=2.30.0 +google-cloud-resource-manager>=1.17.0 +google-cloud-storage-control>=1.11.0 +google-cloud-storage>=3.10.1 +google-crc32c>=1.8.0 +google-genai>=1.72.0 google-pasta>=0.2.0 -google-resumable-media>=2.8.0 -google-tunix>=0.1.3 -googleapis-common-protos>=1.72.0 -grain>=0.2.15 -grpc-google-iam-v1>=0.14.3 -grpcio-status>=1.71.2 -grpcio>=1.76.0 -gspread>=6.2.1 +google-resumable-media>=2.8.2 +googleapis-common-protos>=1.74.0 +grain>=0.2.16 +grpc-google-iam-v1>=0.14.4 +grpcio-status>=1.78.0 +grpcio>=1.78.0 gviz-api>=1.10.0 h11>=0.16.0 -h5py>=3.15.1 -hf-transfer>=0.1.9 -hf-xet>=1.2.0 ; platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64' +h5py>=3.14.0 +hf-xet>=1.4.3 ; platform_machine == 'AMD64' or platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64' httpcore>=1.0.9 -httplib2>=0.31.0 +httplib2>=0.31.2 httpx>=0.28.1 -huggingface-hub>=0.36.0 -humanize>=4.14.0 +huggingface-hub>=1.10.1 +humanize>=4.15.0 hypothesis>=6.142.1 -identify>=2.6.15 +identify>=2.6.18 idna>=3.11 -immutabledict>=4.2.2 +immutabledict>=4.3.1 importlab>=0.8.1 -importlib-metadata>=8.7.0 -importlib-resources>=6.5.2 +importlib-metadata>=8.7.1 iniconfig>=2.3.0 -isort>=7.0.0 -jaraco-functools>=4.3.0 -jax>=0.8.1 -jaxlib>=0.8.1 -jaxtyping>=0.3.3 +isort>=8.0.1 +jax>=0.9.2 +jaxlib>=0.9.2 +jaxtyping>=0.3.9 jinja2>=3.1.6 -joblib>=1.5.2 jsonlines>=4.0.0 -kagglehub>=0.3.13 -keras>=3.12.0 -kiwisolver>=1.4.9 +keras>=3.13.2 +kiwisolver>=1.5.0 +latex2sympy2-extended>=1.11.0 libclang>=18.1.1 libcst>=1.8.6 -libtpu>=0.0.30 ; platform_machine == 'x86_64' and sys_platform == 'linux' -llvmlite>=0.45.1 -lxml>=6.0.2 +libtpu>=0.0.37 ; platform_machine == 'x86_64' and sys_platform == 'linux' markdown-it-py>=4.0.0 -markdown>=3.10 +markdown>=3.10.2 markupsafe>=3.0.3 +marshmallow>=3.26.2 math-verify>=0.9.0 -matplotlib>=3.10.7 +matplotlib>=3.10.8 mccabe>=0.7.0 mdurl>=0.1.2 ml-collections>=1.1.0 ml-dtypes>=0.5.4 -ml-goodput-measurement>=0.0.15 -more-itertools>=10.8.0 +ml-goodput-measurement>=0.0.16 mpmath>=1.3.0 msgpack>=1.1.2 -msgspec>=0.20.0 -multidict>=6.7.0 -multiprocess>=0.70.18 +msgspec>=0.21.1 +multidict>=6.7.1 +multiprocess>=0.70.19 mypy-extensions>=1.1.0 namex>=0.1.0 -nest-asyncio>=1.6.0 -networkx>=3.6 +nest-asyncio>=1.6.0 ; sys_platform == 'win32' +networkx>=3.6.1 ninja>=1.13.0 -nltk>=3.9.2 -nodeenv>=1.9.1 -numba>=0.62.1 -numpy-typing-compat>=20250818.2.0 +nodeenv>=1.10.0 +numpy-typing-compat>=20251206.2.0 numpy>=2.0.2 +nvidia-cuda-cccl>=13.2.27 oauthlib>=3.3.1 omegaconf>=2.3.0 -opentelemetry-api>=1.38.0 +opentelemetry-api>=1.41.0 opt-einsum>=3.4.0 -optax>=0.2.6 -optree>=0.18.0 -optype>=0.14.0 -orbax-checkpoint>=0.11.33 -packaging>=25.0 -pandas>=2.3.3 +optax>=0.2.8 +optree>=0.19.0 +optype>=0.17.0 +orbax-checkpoint>=0.11.34 +orbax-export>=0.0.8 +packaging>=26.0 +pandas>=3.0.2 parameterized>=0.9.0 -pathspec>=0.12.1 -pathwaysutils>=0.1.4 -pillow>=12.0.0 -platformdirs>=4.5.0 +pathspec>=1.0.4 +pathwaysutils>=0.1.7 +pillow>=12.1.1 +platformdirs>=4.9.6 pluggy>=1.6.0 portpicker>=1.6.0 -pre-commit>=4.5.0 -prometheus-client>=0.23.1 +pre-commit>=4.5.1 promise>=2.3 propcache>=0.4.1 -proto-plus>=1.26.1 -protobuf>=5.29.5 -psutil>=7.1.3 -pyarrow>=22.0.0 +proto-plus>=1.27.2 +protobuf>=6.33.6 +psutil>=7.2.2 +pyarrow>=23.0.1 pyasn1-modules>=0.4.2 -pyasn1>=0.6.1 +pyasn1>=0.6.3 pycnite>=2024.7.31 -pycryptodomex>=3.23.0 -pydantic-core>=2.41.5 -pydantic>=2.12.5 +pycparser>=3.0 ; implementation_name != 'PyPy' and platform_python_implementation != 'PyPy' +pydantic-core>=2.46.0 +pydantic>=2.13.0 pydot>=4.0.1 pyelftools>=0.32 pyglove>=0.4.5 pygments>=2.19.2 -pyink>=24.10.1 -pylint>=4.0.3 -pyparsing>=3.2.5 +pyink>=25.12.0 +pylint>=4.0.5 +pyparsing>=3.3.2 pyproject-hooks>=1.2.0 pytest-xdist>=3.8.0 pytest>=8.4.2 python-dateutil>=2.9.0.post0 -python-dotenv>=1.2.1 +pytokens>=0.4.1 pytype>=2024.10.11 -pytz>=2025.2 pyyaml>=6.0.3 -qwix>=0.1.4 -regex>=2025.11.3 +qwix>=0.1.5 +regex>=2026.4.4 requests-oauthlib>=2.0.0 requests>=2.32.5 -rich>=14.2.0 -rsa>=4.9.1 +rich>=14.3.3 safetensors>=0.7.0 -scipy-stubs>=1.16.3.0 -scipy>=1.16.3 +scipy-stubs>=1.17.1.2 +scipy>=1.17.1 sentencepiece>=0.2.1 seqio>=0.0.20 -setuptools>=80.9.0 -shapely>=2.1.2 -shortuuid>=1.0.13 -simple-parsing>=0.1.7 +setuptools>=82.0.1 +shellingham>=1.5.4 +simple-parsing>=0.1.8 simplejson>=3.20.2 six>=1.17.0 sniffio>=1.3.1 sortedcontainers>=2.4.0 -starlette>=0.50.0 +starlette>=1.0.0 sympy>=1.14.0 -tabulate>=0.9.0 -tenacity>=9.1.2 +tabulate>=0.10.0 +tenacity>=9.1.4 tensorboard-data-server>=0.7.2 tensorboard-plugin-profile>=2.13.0 -tensorboard>=2.19.0 -tensorboardx>=2.6.4 +tensorboard>=2.20.0 +tensorboardx>=2.6.5 tensorflow-datasets>=4.9.9 -tensorflow-metadata>=1.17.2 -tensorflow-text>=2.19.0 -tensorflow>=2.19.1 -tensorstore>=0.1.79 -termcolor>=3.2.0 +tensorflow-metadata>=1.17.3 +tensorflow-text>=2.20.1 +tensorflow>=2.20.0 +tensorstore>=0.1.82 +termcolor>=3.3.0 tiktoken>=0.12.0 -tokamax>=0.0.8 -tokenizers>=0.22.1 +tokamax>=0.0.12 +tokenizers>=0.22.2 toml>=0.10.2 -tomlkit>=0.13.3 +tomlkit>=0.14.0 toolz>=1.1.0 -tqdm>=4.67.1 -transformers>=4.57.3 +tqdm>=4.67.3 +transformers>=5.5.4 treescope>=0.1.10 typeguard>=2.13.3 +typer>=0.24.1 typing-extensions>=4.15.0 +typing-inspect>=0.9.0 typing-inspection>=0.4.2 -tzdata>=2025.2 +tzdata>=2026.1 ; sys_platform == 'emscripten' or sys_platform == 'win32' uritemplate>=4.2.0 -urllib3>=2.5.0 -uvicorn>=0.38.0 -uvloop>=0.19.0 -virtualenv>=20.35.4 +urllib3>=2.6.3 +uvicorn>=0.44.0 +uvloop>=0.22.1 +virtualenv>=20.36.1 wadler-lindig>=0.1.7 -websockets>=15.0.1 -werkzeug>=3.1.3 -wheel>=0.45.1 -wrapt>=2.0.1 -xprof>=2.21.1 +websockets>=16.0 +werkzeug>=3.1.8 +wheel>=0.46.3 +wrapt>=2.1.2 xxhash>=3.6.0 -yarl>=1.22.0 +yarl>=1.23.0 zipp>=3.23.0 -zstandard>=0.25.0 +zstandard>=0.25.0 \ No newline at end of file diff --git a/src/maxtext/checkpoint_conversion/utils/hf_model_configs.py b/src/maxtext/checkpoint_conversion/utils/hf_model_configs.py index 0d187fca46..0ad4047ca4 100644 --- a/src/maxtext/checkpoint_conversion/utils/hf_model_configs.py +++ b/src/maxtext/checkpoint_conversion/utils/hf_model_configs.py @@ -855,6 +855,7 @@ # TODO(shuningjin): replace with DeepseekV32Config when available in transformers library class DeepseekV32Config(PTConfig): + model_type = "deepseek_v32" def __init__(self, **kwargs): self.max_position_embeddings = kwargs.get("max_position_embeddings", 163840) diff --git a/src/maxtext/examples/sft_llama3_demo_tpu.ipynb b/src/maxtext/examples/sft_llama3_demo_tpu.ipynb index 713c0dfb52..8a1a3cd664 100644 --- a/src/maxtext/examples/sft_llama3_demo_tpu.ipynb +++ b/src/maxtext/examples/sft_llama3_demo_tpu.ipynb @@ -154,11 +154,11 @@ "outputs": [], "source": [ "import datetime\n", + "import jax\n", "import os\n", "from maxtext.configs import pyconfig\n", "from maxtext.utils.globals import MAXTEXT_PKG_DIR\n", "from maxtext.trainers.post_train.sft import train_sft\n", - "import jax\n", "from huggingface_hub import login\n", "\n", "\n", @@ -168,13 +168,11 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "goXDCIB4Kgv5" - }, + "metadata": {}, "outputs": [], "source": [ "if not jax.distributed.is_initialized():\n", - " jax.distributed.initialize()\n", + " jax.distributed.initialize()\n", "print(f\"JAX version: {jax.__version__}\")\n", "print(f\"JAX devices: {jax.devices()}\")" ] @@ -223,8 +221,7 @@ }, "outputs": [], "source": [ - "MODEL_NAME = \"llama3.1-8b\"\n", - "TOKENIZER_PATH = \"meta-llama/Llama-3.1-8B-Instruct\"\n", + "MODEL_NAME = \"llama3.1-8b-Instruct\"\n", "\n", "# set the path to the model checkpoint (excluding `/0/items`) or leave empty to download from HuggingFace\n", "MODEL_CHECKPOINT_PATH = \"\"\n", @@ -326,7 +323,6 @@ " f\"hf_access_token={HF_TOKEN}\",\n", " f\"base_output_directory={BASE_OUTPUT_DIRECTORY}\",\n", " f\"run_name={RUN_NAME}\",\n", - " f\"tokenizer_path={TOKENIZER_PATH}\",\n", " \"profiler=xplane\",\n", "]\n", "\n", diff --git a/src/maxtext/input_pipeline/distillation_data_processing.py b/src/maxtext/input_pipeline/distillation_data_processing.py index 44495a39e8..4aebb9169f 100644 --- a/src/maxtext/input_pipeline/distillation_data_processing.py +++ b/src/maxtext/input_pipeline/distillation_data_processing.py @@ -121,7 +121,9 @@ def filter_dataset(config, dataset, tokenizer): max_output_tokens = min(max_output_length, len(tokenizer.encode(actual_completion))) if config.use_chat_template: message = [{"role": "user", "content": prompt}] - prompt_token_ids = tokenizer.apply_chat_template(message, add_generation_prompt=True, tokenize=True) + prompt_token_ids = input_pipeline_utils.extract_token_ids( + tokenizer.apply_chat_template(message, add_generation_prompt=True, tokenize=True) + ) else: prompt_token_ids = tokenizer.encode(prompt) diff --git a/src/maxtext/input_pipeline/input_pipeline_utils.py b/src/maxtext/input_pipeline/input_pipeline_utils.py index d8c93d141a..471f25e5ca 100644 --- a/src/maxtext/input_pipeline/input_pipeline_utils.py +++ b/src/maxtext/input_pipeline/input_pipeline_utils.py @@ -180,7 +180,7 @@ def is_conversational(features, data_columns): return False -def _extract_token_ids(tokens): +def extract_token_ids(tokens): """Extracts token IDs from various tokenizer output formats. This helper function standardizes the extraction of tokenized integer IDs @@ -248,21 +248,21 @@ def verify_chat_template_generation_prompt_logic(tokenizer_model): ) dummy_msgs.pop(0) prompt_wo_gen_tokens = tokenizer_model.apply_chat_template(dummy_msgs, add_generation_prompt=False, tokenize=True) - prompt_wo_gen_ids = _extract_token_ids(prompt_wo_gen_tokens) + prompt_wo_gen_ids = extract_token_ids(prompt_wo_gen_tokens) prompt_w_gen_tokens = tokenizer_model.apply_chat_template(dummy_msgs, add_generation_prompt=True, tokenize=True) - prompt_w_gen_ids = _extract_token_ids(prompt_w_gen_tokens) + prompt_w_gen_ids = extract_token_ids(prompt_w_gen_tokens) if prompt_w_gen_ids[: len(prompt_wo_gen_ids)] != prompt_wo_gen_ids: raise ValueError("Unable to extract generation prompt tokens.") # Extract the tokenized generation prompt (the expected assistant prefix) assistant_prefix = prompt_w_gen_ids[len(prompt_wo_gen_ids) :] - full_turn_tokens = _extract_token_ids( + full_turn_tokens = extract_token_ids( tokenizer_model.apply_chat_template( dummy_msgs + [{"role": "assistant", "content": "Dummy response"}], add_generation_prompt=False, tokenize=True ) ) - full_turn_ids = _extract_token_ids(full_turn_tokens) + full_turn_ids = extract_token_ids(full_turn_tokens) # Extract the actual tokens that appear right after the user message in the full turn actual_prefix_in_full_turn = full_turn_ids[len(prompt_wo_gen_ids) : len(prompt_wo_gen_ids) + len(assistant_prefix)] @@ -295,8 +295,8 @@ def _get_completion_in_chat_template(tokenizer_model, round_msgs): # include generation_prompt as part of the prompt tokens prompt_tokens = tokenizer_model.apply_chat_template(round_msgs[:-1], add_generation_prompt=True, tokenize=True) - prompt_completion_ids = _extract_token_ids(prompt_completion_tokens) - prompt_ids = _extract_token_ids(prompt_tokens) + prompt_completion_ids = extract_token_ids(prompt_completion_tokens) + prompt_ids = extract_token_ids(prompt_tokens) completion_tokens = prompt_completion_ids[len(prompt_ids) :] completion_in_chat_template = tokenizer_model.decode(completion_tokens, skip_special_tokens=False) diff --git a/tests/unit/attention_test.py b/tests/unit/attention_test.py index 4f4a5fed20..9d7fccc457 100644 --- a/tests/unit/attention_test.py +++ b/tests/unit/attention_test.py @@ -688,7 +688,6 @@ def test_share_kv_projections(self): "shard_mode": "explicit", }, ) - # TODO (b/454764135.) : This tests fails with new tokamax kernel @pytest.mark.tpu_only def test_tpu_flash_attention_context_parallel( self, @@ -698,6 +697,10 @@ def test_tpu_flash_attention_context_parallel( shard_mode, ): """Test equivalence between dot_product and flash attention + context/expert parallelism""" + # TODO: Enable these tests after b/454764135 is fixed + if shard_mode == "explicit": + self.skipTest("Skipping explicit shard_mode tests.") + num_kv_heads = self.num_kv_heads lnx, decoder_segment_ids, decoder_positions = self.get_data(self.dtype) # Dot product @@ -1470,7 +1473,6 @@ def test_projection_initialization(self): "shard_mode": "explicit", }, ) - # TODO (b/454764135.) : This tests fails with new tokamax kernel @pytest.mark.tpu_only def test_tpu_flash_attention_context_parallel( self, @@ -1480,6 +1482,9 @@ def test_tpu_flash_attention_context_parallel( shard_mode, ): """Test equivalence between dot_product and flash attention + context/expert parallelism""" + # TODO: Enable these tests after b/454764135 is fixed + if shard_mode == "explicit": + self.skipTest("Skipping explicit shard_mode tests.") config_arguments = { "per_device_batch_size": 1.0, diff --git a/tests/unit/train_compile_test.py b/tests/unit/train_compile_test.py index cb291e13bd..a583c40cdd 100644 --- a/tests/unit/train_compile_test.py +++ b/tests/unit/train_compile_test.py @@ -24,7 +24,9 @@ from tempfile import gettempdir import pytest +import transformers +from maxtext.checkpoint_conversion.utils.hf_model_configs import DeepseekV32Config from maxtext.trainers.pre_train.train_compile import main as train_compile_main from tests.utils.test_helpers import get_test_config_path @@ -893,6 +895,7 @@ def test_mhc_integration(self): def test_engram_integration(self): """AOT test for Engram implementation""" compiled_trainstep_file = "/tmp/test_engram_integration" + transformers.AutoConfig.register("deepseek_v32", DeepseekV32Config) train_compile_main( ( "",