diff --git a/.github/workflows/build_and_push_docker_image.yml b/.github/workflows/build_and_push_docker_image.yml index e92578217d..d198407a38 100644 --- a/.github/workflows/build_and_push_docker_image.yml +++ b/.github/workflows/build_and_push_docker_image.yml @@ -117,6 +117,7 @@ jobs: MODE=${{ inputs.build_mode }} WORKFLOW=${{ inputs.workflow }} PACKAGE_DIR=./src + TESTS_DIR=./tests JAX_VERSION=NONE LIBTPU_VERSION=NONE INCLUDE_TEST_ASSETS=true diff --git a/PREFLIGHT.md b/PREFLIGHT.md index 71f2d9e379..d777ae7b40 100644 --- a/PREFLIGHT.md +++ b/PREFLIGHT.md @@ -1,35 +1,39 @@ # Optimization 1: Multihost recommended network settings -We included all the recommended network settings in [rto_setup.sh](https://github.com/google/maxtext/blob/main/src/dependencies/scripts/rto_setup.sh). + +We included all the recommended network settings in [rto_setup.sh](https://github.com/google/maxtext/blob/main/src/dependencies/scripts/rto_setup.sh). [preflight.sh](https://github.com/google/maxtext/blob/main/src/dependencies/scripts/preflight.sh) will help you apply them based on GCE or GKE platform. Before you run ML workload on Multihost with GCE or GKE, simply apply `bash preflight.sh PLATFORM=[GCE or GKE]` to leverage the best DCN network performance. Here is an example for GCE: + ``` -bash src/dependencies/scripts/preflight.sh PLATFORM=GCE && python3 -m maxtext.trainers.pre_train.train run_name=${YOUR_JOB_NAME?} +bash preflight.sh PLATFORM=GCE && python3 -m maxtext.trainers.pre_train.train run_name=${YOUR_JOB_NAME?} ``` Here is an example for GKE: + ``` -bash src/dependencies/scripts/preflight.sh PLATFORM=GKE && python3 -m maxtext.trainers.pre_train.train run_name=${YOUR_JOB_NAME?} +bash preflight.sh PLATFORM=GKE && python3 -m maxtext.trainers.pre_train.train run_name=${YOUR_JOB_NAME?} ``` # Optimization 2: Numa binding (You can only apply this to v4 and v5p) + NUMA binding is recommended for enhanced performance, as it reduces memory latency and maximizes data throughput, ensuring that your high-performance applications operate more efficiently and effectively. -For GCE, +For GCE, [preflight.sh](https://github.com/google/maxtext/blob/main/src/dependencies/scripts/preflight.sh) will help you install `numactl` dependency, so you can use it directly, here is an example: ``` -bash src/dependencies/scripts/preflight.sh PLATFORM=GCE && numactl --membind 0 --cpunodebind=0 python3 -m maxtext.trainers.pre_train.train run_name=${YOUR_JOB_NAME?} +bash preflight.sh PLATFORM=GCE && numactl --membind 0 --cpunodebind=0 python3 -m maxtext.trainers.pre_train.train run_name=${YOUR_JOB_NAME?} ``` For GKE, `numactl` should be built into your docker image from [maxtext_tpu_dependencies.Dockerfile](https://github.com/google/maxtext/blob/main/src/dependencies/dockerfiles/maxtext_tpu_dependencies.Dockerfile), so you can use it directly if you built the maxtext docker image. Here is an example ``` -bash src/dependencies/scripts/preflight.sh PLATFORM=GKE && numactl --membind 0 --cpunodebind=0 python3 -m maxtext.trainers.pre_train.train run_name=${YOUR_JOB_NAME?} +bash preflight.sh PLATFORM=GKE && numactl --membind 0 --cpunodebind=0 python3 -m maxtext.trainers.pre_train.train run_name=${YOUR_JOB_NAME?} ``` 1. `numactl`: This is the command-line tool used for controlling NUMA policy for processes or shared memory. It's particularly useful on multi-socket systems where memory locality can impact performance. 
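The `numactl --membind 0 --cpunodebind=0` invocations above bind both memory allocation and CPU scheduling to NUMA node 0. If you are unsure which node indices your VM actually exposes before choosing a binding, `numactl` itself can report the topology; the quick check below is purely illustrative and not part of the MaxText scripts:

```
numactl --hardware   # list the NUMA nodes, their CPUs, and attached memory
numactl --show       # print the NUMA policy the current shell would apply
```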
diff --git a/docs/tutorials/first_run.md b/docs/tutorials/first_run.md index 3b7468129b..55ddb762d5 100644 --- a/docs/tutorials/first_run.md +++ b/docs/tutorials/first_run.md @@ -39,6 +39,15 @@ multiple hosts but is a good way to learn about MaxText. 2. For instructions on installing MaxText on your VM, please refer to the [official documentation](https://maxtext.readthedocs.io/en/latest/install_maxtext.html). -3. After installation completes, run training on synthetic data with the following command: +3. Set up a Python virtual environment and install the MaxText dependencies with the following commands: +```sh +python3 -m venv ~/venv-maxtext +source ~/venv-maxtext/bin/activate +bash tools/setup/setup.sh +pre-commit install +``` + +4. After installation completes, run training on synthetic data with the following command: + ```sh python3 -m maxtext.trainers.pre_train.train \ run_name=${YOUR_JOB_NAME?} \ diff --git a/docs/tutorials/posttraining/sft_on_multi_host.md b/docs/tutorials/posttraining/sft_on_multi_host.md index b54819cee0..f7bd0c514e 100644 --- a/docs/tutorials/posttraining/sft_on_multi_host.md +++ b/docs/tutorials/posttraining/sft_on_multi_host.md @@ -113,7 +113,7 @@ xpk workload create \ --workload=${WORKLOAD_NAME?} \ --tpu-type=${TPU_TYPE?} \ --num-slices=${TPU_SLICE?} \ ---command "python3 -m maxtext.trainers.post_train.sft.train_sft run_name=${WORKLOAD_NAME?} base_output_directory=${OUTPUT_PATH?} model_name=${MODEL?} load_parameters_path=${MAXTEXT_CKPT_PATH?} hf_access_token=${HF_TOKEN?} per_device_batch_size=1 steps=${STEPS?} profiler=xplane hf_path=${DATASET_NAME?} train_split=${TRAIN_SPLIT?} train_data_columns=${TRAIN_DATA_COLUMNS?}" +--command "python3 -m maxtext.trainers.post_train.sft.train_sft run_name=${WORKLOAD_NAME?} base_output_directory=${OUTPUT_PATH?} model_name=${MODEL_NAME?} load_parameters_path=${MODEL_CHECKPOINT_PATH?} hf_access_token=${HF_TOKEN?} tokenizer_path=${TOKENIZER_PATH?} per_device_batch_size=1 steps=${STEPS?} profiler=xplane hf_path=${DATASET_NAME?} train_split=${TRAIN_SPLIT?} train_data_columns=${TRAIN_DATA_COLUMNS?}" ``` Once the fine-tuning is completed, you can access your model checkpoints at `$OUTPUT_PATH/$WORKLOAD_NAME/checkpoints`. @@ -131,7 +131,7 @@ xpk workload create-pathways \ --workload=${WORKLOAD_NAME?} \ --tpu-type=${TPU_TYPE?} \ --num-slices=${TPU_SLICE?} \ ---command="JAX_PLATFORMS=proxy JAX_BACKEND_TARGET=grpc://127.0.0.1:29000 ENABLE_PATHWAYS_PERSISTENCE=1 python3 -m maxtext.trainers.post_train.sft.train_sft run_name=${WORKLOAD_NAME?} base_output_directory=${OUTPUT_PATH?} model_name=${MODEL?} load_parameters_path=${MAXTEXT_CKPT_PATH?} hf_access_token=${HF_TOKEN?} per_device_batch_size=1 steps=${STEPS?} profiler=xplane checkpoint_storage_use_zarr3=$((1 - USE_PATHWAYS)) checkpoint_storage_use_ocdbt=$((1 - USE_PATHWAYS)) enable_single_controller=True" +--command="JAX_PLATFORMS=proxy JAX_BACKEND_TARGET=grpc://127.0.0.1:29000 ENABLE_PATHWAYS_PERSISTENCE=1 python3 -m maxtext.trainers.post_train.sft.train_sft run_name=${WORKLOAD_NAME?} base_output_directory=${OUTPUT_PATH?} model_name=${MODEL_NAME?} load_parameters_path=${MODEL_CHECKPOINT_PATH?} hf_access_token=${HF_TOKEN?} tokenizer_path=${TOKENIZER_PATH?} per_device_batch_size=1 steps=${STEPS?} profiler=xplane checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False enable_single_controller=True" ``` Once the fine-tuning is completed, you can access your model checkpoints at `$OUTPUT_PATH/$WORKLOAD_NAME/checkpoints`.
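The `xpk workload create` commands above assume the referenced shell variables are already exported in your environment. A minimal sketch of setting them beforehand follows; every value is a placeholder to replace with your own cluster, bucket, model, and dataset settings:

```sh
# Placeholder values only: substitute your own settings before running xpk.
export WORKLOAD_NAME=sft-demo
export TPU_TYPE=v5litepod-256
export TPU_SLICE=1
export OUTPUT_PATH=gs://your-bucket/sft-output
export MODEL_NAME=llama3.1-8b
export MODEL_CHECKPOINT_PATH=gs://your-bucket/checkpoints/llama3.1-8b/0/items
export TOKENIZER_PATH=gs://your-bucket/tokenizers/llama3.1-8b
export HF_TOKEN=hf_your_token
export STEPS=100
export DATASET_NAME=HuggingFaceH4/ultrachat_200k
export TRAIN_SPLIT=train_sft
export TRAIN_DATA_COLUMNS=messages
```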
diff --git a/pyproject.toml b/pyproject.toml index a32239f9d0..840cbca67f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,7 @@ Repository = "https://github.com/AI-Hypercomputer/maxtext.git" allow-direct-references = true [tool.hatch.build.targets.wheel] -packages = ["src/MaxText", "src/maxtext", "src/dependencies"] +packages = ["src/MaxText", "src/maxtext", "src/dependencies/github_deps", "src/dependencies"] # TODO: Add this hook back when it handles device-type parsing # [tool.hatch.build.targets.wheel.hooks.custom] diff --git a/src/dependencies/dockerfiles/maxtext_gpu_dependencies.Dockerfile b/src/dependencies/dockerfiles/maxtext_gpu_dependencies.Dockerfile index b3caf4faa0..4d6a81a881 100644 --- a/src/dependencies/dockerfiles/maxtext_gpu_dependencies.Dockerfile +++ b/src/dependencies/dockerfiles/maxtext_gpu_dependencies.Dockerfile @@ -41,6 +41,9 @@ ENV ENV_DEVICE=$DEVICE ARG PACKAGE_DIR ENV PACKAGE_DIR=$PACKAGE_DIR +ARG TESTS_DIR +ENV TESTS_DIR=$TESTS_DIR + ENV MAXTEXT_ASSETS_ROOT=/deps/src/maxtext/assets ENV MAXTEXT_TEST_ASSETS_ROOT=/deps/tests/assets ENV MAXTEXT_PKG_DIR=/deps/src/maxtext @@ -63,9 +66,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \ # Now copy the remaining code (source files that may change frequently) COPY ${PACKAGE_DIR}/maxtext/ src/maxtext/ -COPY ${PACKAGE_DIR}/MaxText/ src/MaxText/ -COPY tests*/ tests/ -COPY benchmarks*/ benchmarks/ +COPY ${TESTS_DIR}*/ tests/ # Download test assets from GCS if building image with test assets ARG INCLUDE_TEST_ASSETS=false diff --git a/src/dependencies/dockerfiles/maxtext_tpu_dependencies.Dockerfile b/src/dependencies/dockerfiles/maxtext_tpu_dependencies.Dockerfile index 131fdabaf1..427640fefe 100644 --- a/src/dependencies/dockerfiles/maxtext_tpu_dependencies.Dockerfile +++ b/src/dependencies/dockerfiles/maxtext_tpu_dependencies.Dockerfile @@ -38,6 +38,9 @@ ENV ENV_DEVICE=$DEVICE ARG PACKAGE_DIR ENV PACKAGE_DIR=$PACKAGE_DIR +ARG TESTS_DIR +ENV TESTS_DIR=$TESTS_DIR + ENV MAXTEXT_ASSETS_ROOT=/deps/src/maxtext/assets ENV MAXTEXT_TEST_ASSETS_ROOT=/deps/tests/assets ENV MAXTEXT_PKG_DIR=/deps/src/maxtext @@ -63,9 +66,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \ # Now copy the remaining code (source files that may change frequently) COPY ${PACKAGE_DIR}/maxtext/ src/maxtext/ -COPY ${PACKAGE_DIR}/MaxText/ src/MaxText/ -COPY tests*/ tests/ -COPY benchmarks*/ benchmarks/ +COPY ${TESTS_DIR}*/ tests/ # Download test assets from GCS if building image with test assets ARG INCLUDE_TEST_ASSETS=false diff --git a/src/dependencies/github_deps/install_post_train_deps.py b/src/dependencies/github_deps/install_post_train_deps.py index fd09cd2109..45f7f5b769 100644 --- a/src/dependencies/github_deps/install_post_train_deps.py +++ b/src/dependencies/github_deps/install_post_train_deps.py @@ -34,12 +34,14 @@ def main(): """ os.environ["VLLM_TARGET_DEVICE"] = "tpu" - current_dir = os.path.dirname(os.path.abspath(__file__)) - repo_root = os.path.abspath(os.path.join(current_dir, "..", "..")) - extra_deps_path = os.path.join(current_dir, "post_train_deps.txt") - if not os.path.exists(extra_deps_path): - raise FileNotFoundError(f"Dependencies file not found at {extra_deps_path}") + # Adjust this path if your post_train_deps.txt is in a different location, + # e.g., a "data" subdirectory next to this script + extra_deps_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), "post_train_deps.txt") + if not os.path.exists(extra_deps_file): + print(f"Error: '{extra_deps_file}' not found.") + print("Please ensure 'post_train_deps.txt' is in the correct location relative to the script.") + sys.exit(1) # Check if 'uv' is available in the environment try: subprocess.run([sys.executable, "-m", "pip", "install", "uv"], check=True, capture_output=True)
diff --git a/src/dependencies/github_deps/install_pre_train_deps.py b/src/dependencies/github_deps/install_pre_train_deps.py index d2cbe15ccb..9844b2e3d7 100644 --- a/src/dependencies/github_deps/install_pre_train_deps.py +++ b/src/dependencies/github_deps/install_pre_train_deps.py @@ -37,6 +37,14 @@ def main(): if not os.path.exists(extra_deps_path): raise FileNotFoundError(f"Dependencies file not found at {extra_deps_path}") + # Adjust this path if your pre_train_deps.txt is in a different location, + # e.g., a "data" subdirectory next to this script + extra_deps_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), "pre_train_deps.txt") + + if not os.path.exists(extra_deps_file): + print(f"Error: '{extra_deps_file}' not found.") + print("Please ensure 'pre_train_deps.txt' is in the correct location relative to the script.") + sys.exit(1) # Check if 'uv' is available in the environment try: subprocess.run([sys.executable, "-m", "pip", "install", "uv"], check=True, capture_output=True) diff --git a/src/dependencies/scripts/docker_build_dependency_image.sh b/src/dependencies/scripts/docker_build_dependency_image.sh index 3705334014..a452d55798 100644 --- a/src/dependencies/scripts/docker_build_dependency_image.sh +++ b/src/dependencies/scripts/docker_build_dependency_image.sh @@ -20,8 +20,43 @@ # For instructions on building the MaxText Docker image, please refer to the https://maxtext.readthedocs.io/en/latest/build_maxtext.html. -PACKAGE_DIR="${PACKAGE_DIR:-src}" -echo "PACKAGE_DIR: $PACKAGE_DIR" +# Build docker image with stable dependencies +## bash src/dependencies/scripts/docker_build_dependency_image.sh DEVICE={{gpu|tpu}} MODE=stable + +# Build docker image with nightly dependencies +## bash src/dependencies/scripts/docker_build_dependency_image.sh DEVICE={{gpu|tpu}} MODE=nightly + +# Build docker image with stable dependencies and a pinned JAX_VERSION for TPUs +## bash src/dependencies/scripts/docker_build_dependency_image.sh MODE=stable JAX_VERSION=0.4.13 + +# Build docker image with a pinned JAX_VERSION and a pinned LIBTPU_VERSION for TPUs +## bash src/dependencies/scripts/docker_build_dependency_image.sh MODE={{stable|nightly}} JAX_VERSION=0.8.1 LIBTPU_VERSION=0.0.31.dev20251119+nightly + +# Build docker image with a custom libtpu.so for TPUs +# Note: libtpu.so file must be present in the root directory of the MaxText repository +## bash src/dependencies/scripts/docker_build_dependency_image.sh MODE={{stable|nightly}} + +# Build docker image with nightly dependencies and a pinned JAX_VERSION for GPUs +# Available versions listed at https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/jax +## bash src/dependencies/scripts/docker_build_dependency_image.sh DEVICE=gpu MODE=nightly JAX_VERSION=0.4.36.dev20241109 + +# ================================== +# POST-TRAINING BUILD EXAMPLES +# ================================== + +# Build docker image with post-training dependencies +## bash src/dependencies/scripts/docker_build_dependency_image.sh WORKFLOW=post-training + +if [ "${BASH_SOURCE-}" ]; then + this_file="${BASH_SOURCE[0]}" +elif [ "${ZSH_VERSION-}" ]; then + # shellcheck disable=SC2296 + this_file="${(%):-%x}" +else + this_file="${0}" +fi + +MAXTEXT_REPO_ROOT="${MAXTEXT_REPO_ROOT:-$(CDPATH='' cd -- "$(dirname -- "${this_file}")"'/../../..' 
&& pwd)}" # Enable "exit immediately if any command fails" option set -e @@ -71,6 +106,7 @@ docker_build_args=( "MODE=${MODE}" "JAX_VERSION=${JAX_VERSION}" "PACKAGE_DIR=${PACKAGE_DIR}" + "TESTS_DIR=${TESTS_DIR}" ) run_docker_build() { @@ -104,7 +140,7 @@ build_tpu_image() { fi echo "Building docker image with arguments: ${docker_build_args[*]}" - run_docker_build "$PACKAGE_DIR/dependencies/dockerfiles/maxtext_tpu_dependencies.Dockerfile" "${docker_build_args[@]}" + run_docker_build "$MAXTEXT_REPO_ROOT/src/dependencies/dockerfiles/maxtext_tpu_dependencies.Dockerfile" "${docker_build_args[@]}" } if [[ ${DEVICE} == "gpu" ]]; then diff --git a/src/maxtext/checkpoint_conversion/utils/param_mapping.py b/src/maxtext/checkpoint_conversion/utils/param_mapping.py index 7e318d7fe5..d7d27e7420 100644 --- a/src/maxtext/checkpoint_conversion/utils/param_mapping.py +++ b/src/maxtext/checkpoint_conversion/utils/param_mapping.py @@ -659,29 +659,22 @@ def QWEN_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False): # This follows the (experts, layers, ...) tensor layout. mapping.update( { - "params-decoder-layers-moe_block-gate-kernel": [ - f"model.layers.{i}.mlp.gate.weight" for i in range(n_layers) - ], + "params-decoder-layers-moe_block-gate-kernel": [f"model.layers.{i}.mlp.gate.weight" for i in range(n_layers)], "params-decoder-layers-moe_block-wi_0": [ - [f"model.layers.{l}.mlp.experts.{e}.gate_proj.weight" for l in range(n_layers)] - for e in range(num_experts) + [f"model.layers.{l}.mlp.experts.{e}.gate_proj.weight" for l in range(n_layers)] for e in range(num_experts) ], "params-decoder-layers-moe_block-wi_1": [ - [f"model.layers.{l}.mlp.experts.{e}.up_proj.weight" for l in range(n_layers)] - for e in range(num_experts) + [f"model.layers.{l}.mlp.experts.{e}.up_proj.weight" for l in range(n_layers)] for e in range(num_experts) ], "params-decoder-layers-moe_block-wo": [ - [f"model.layers.{l}.mlp.experts.{e}.down_proj.weight" for l in range(n_layers)] - for e in range(num_experts) + [f"model.layers.{l}.mlp.experts.{e}.down_proj.weight" for l in range(n_layers)] for e in range(num_experts) ], } ) else: # Dense MLP mapping.update( { - "params-decoder-layers-mlp-wi_0-kernel": [ - f"model.layers.{i}.mlp.gate_proj.weight" for i in range(n_layers) - ], + "params-decoder-layers-mlp-wi_0-kernel": [f"model.layers.{i}.mlp.gate_proj.weight" for i in range(n_layers)], "params-decoder-layers-mlp-wi_1-kernel": [f"model.layers.{i}.mlp.up_proj.weight" for i in range(n_layers)], "params-decoder-layers-mlp-wo-kernel": [f"model.layers.{i}.mlp.down_proj.weight" for i in range(n_layers)], } @@ -780,9 +773,12 @@ def reshape_kernel(input_tensor, target_shape): def reshape_bias(input_tensor, target_shape=None): """Reshapes biases between MaxText 2D (heads, dim) and HF 1D (hidden).""" - # saving_to_hf: MaxText [heads, head_dim] -> HF [hidden_dim] (flatten) - # loading_to_maxtext: HF [hidden_dim] -> MaxText [heads, head_dim] - return input_tensor.reshape(target_shape) + if saving_to_hf: + # MaxText [heads, head_dim] -> HF [hidden_dim] (flatten) + return input_tensor.reshape(target_shape) + else: + # HF [hidden_dim] -> MaxText [heads, head_dim] + return input_tensor.reshape(target_shape) mapping = { "params-token_embedder-embedding": pad_embedding_layer, @@ -900,12 +896,8 @@ def QWEN3_NEXT_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=F f"{prefix}-attention-conv1d-kernel": [f"model.layers.{i}.linear_attn.conv1d.weight" for i in hf_indices], f"{prefix}-attention-A_log": 
[f"model.layers.{i}.linear_attn.A_log" for i in hf_indices], f"{prefix}-attention-dt_bias": [f"model.layers.{i}.linear_attn.dt_bias" for i in hf_indices], - f"{prefix}-attention-norm-rms_norm-scale": [ - f"model.layers.{i}.linear_attn.norm.weight" for i in hf_indices - ], - f"{prefix}-attention-out_proj-kernel": [ - f"model.layers.{i}.linear_attn.out_proj.weight" for i in hf_indices - ], + f"{prefix}-attention-norm-rms_norm-scale": [f"model.layers.{i}.linear_attn.norm.weight" for i in hf_indices], + f"{prefix}-attention-out_proj-kernel": [f"model.layers.{i}.linear_attn.out_proj.weight" for i in hf_indices], } ) @@ -1237,9 +1229,7 @@ def GPT_OSS_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=Fals prefix = f"params-decoder-layers-layers_{block_idx}" block_mapping = { # Layer Norms - f"{prefix}-pre_self_attention_layer_norm-scale": [ - f"model.layers.{i}.input_layernorm.weight" for i in hf_indices - ], + f"{prefix}-pre_self_attention_layer_norm-scale": [f"model.layers.{i}.input_layernorm.weight" for i in hf_indices], f"{prefix}-post_self_attention_layer_norm-scale": [ f"model.layers.{i}.post_attention_layernorm.weight" for i in hf_indices ], @@ -1429,9 +1419,7 @@ def add_prefix_recursive(value): mapping["params-vision_encoder-Qwen3OmniMoeVisionEncoder_0-patch_embed-proj-kernel"] = ( "thinker.visual.patch_embed.proj.weight" ) - mapping["params-vision_encoder-Qwen3OmniMoeVisionEncoder_0-patch_embed-proj-bias"] = ( - "thinker.visual.patch_embed.proj.bias" - ) + mapping["params-vision_encoder-Qwen3OmniMoeVisionEncoder_0-patch_embed-proj-bias"] = "thinker.visual.patch_embed.proj.bias" # Vision positional embedding mapping["params-vision_encoder-Qwen3OmniMoeVisionEncoder_0-pos_embed_interpolate-pos_embed"] = ( @@ -1482,13 +1470,9 @@ def add_prefix_recursive(value): # Vision projector (final merger) mapping["params-vision_encoder-Qwen3OmniMoeVisionProjector_0-merger-ln_q-scale"] = "thinker.visual.merger.ln_q.weight" mapping["params-vision_encoder-Qwen3OmniMoeVisionProjector_0-merger-ln_q-bias"] = "thinker.visual.merger.ln_q.bias" - mapping["params-vision_encoder-Qwen3OmniMoeVisionProjector_0-merger-mlp_0-kernel"] = ( - "thinker.visual.merger.mlp.0.weight" - ) + mapping["params-vision_encoder-Qwen3OmniMoeVisionProjector_0-merger-mlp_0-kernel"] = "thinker.visual.merger.mlp.0.weight" mapping["params-vision_encoder-Qwen3OmniMoeVisionProjector_0-merger-mlp_0-bias"] = "thinker.visual.merger.mlp.0.bias" - mapping["params-vision_encoder-Qwen3OmniMoeVisionProjector_0-merger-mlp_2-kernel"] = ( - "thinker.visual.merger.mlp.2.weight" - ) + mapping["params-vision_encoder-Qwen3OmniMoeVisionProjector_0-merger-mlp_2-kernel"] = "thinker.visual.merger.mlp.2.weight" mapping["params-vision_encoder-Qwen3OmniMoeVisionProjector_0-merger-mlp_2-bias"] = "thinker.visual.merger.mlp.2.bias" # Audio mapping @@ -2359,26 +2343,24 @@ def pad_hf_embedding_layer(input_tensor, target_shape): "gemma3-4b": GEMMA3_MAXTEXT_TO_HF_PARAM_MAPPING, "gemma3-12b": GEMMA3_MAXTEXT_TO_HF_PARAM_MAPPING, "gemma3-27b": GEMMA3_MAXTEXT_TO_HF_PARAM_MAPPING, - "qwen2.5-7b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING, - "qwen2.5-14b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING, - "qwen3-0.6b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING, - "qwen3-1.7b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING, - "qwen3-1.7b-base": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING, - "qwen3-4b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING, - "qwen3-4b-base": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING, - "qwen3-4b-thinking-2507": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING, - "qwen3-8b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING, - "qwen3-8b-base": 
QWEN_MAXTEXT_TO_HF_PARAM_MAPPING, - "qwen3-14b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING, - "qwen3-14b-base": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING, - "qwen3-32b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING, + "qwen3-0.6b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING, + "qwen3-1.7b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING, + "qwen3-1.7b-base": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING, + "qwen3-4b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING, + "qwen3-4b-base": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING, + "qwen3-4b-thinking-2507": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING, + "qwen3-8b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING, + "qwen3-8b-base": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING, + "qwen3-14b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING, + "qwen3-14b-base": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING, + "qwen3-32b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING, "llama3.1-8b": LLAMA31_MAXTEXT_TO_HF_PARAM_MAPPING, "llama3.1-70b": LLAMA31_MAXTEXT_TO_HF_PARAM_MAPPING, "llama3.1-405b": LLAMA31_MAXTEXT_TO_HF_PARAM_MAPPING, - "qwen3-30b-a3b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING, - "qwen3-30b-a3b-base": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING, - "qwen3-235b-a22b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING, - "qwen3-coder-480b-a35b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING, + "qwen3-30b-a3b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING, + "qwen3-30b-a3b-base": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING, + "qwen3-235b-a22b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING, + "qwen3-coder-480b-a35b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING, "deepseek3-671b": DEEPSEEK_MAXTEXT_TO_HF_PARAM_MAPPING, "gpt-oss-20b": GPT_OSS_MAXTEXT_TO_HF_PARAM_MAPPING, "gpt-oss-120b": GPT_OSS_MAXTEXT_TO_HF_PARAM_MAPPING, @@ -2399,26 +2381,24 @@ def pad_hf_embedding_layer(input_tensor, target_shape): "gemma3-4b": GEMMA3_MAXTEXT_TO_HF_PARAM_HOOK_FN, "gemma3-12b": GEMMA3_MAXTEXT_TO_HF_PARAM_HOOK_FN, "gemma3-27b": GEMMA3_MAXTEXT_TO_HF_PARAM_HOOK_FN, - "qwen2.5-7b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN, - "qwen2.5-14b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN, - "qwen3-0.6b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN, - "qwen3-1.7b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN, - "qwen3-1.7b-base": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN, - "qwen3-4b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN, - "qwen3-4b-base": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN, - "qwen3-4b-thinking-2507": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN, - "qwen3-8b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN, - "qwen3-8b-base": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN, - "qwen3-14b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN, - "qwen3-14b-base": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN, - "qwen3-32b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN, + "qwen3-0.6b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN, + "qwen3-1.7b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN, + "qwen3-1.7b-base": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN, + "qwen3-4b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN, + "qwen3-4b-base": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN, + "qwen3-4b-thinking-2507": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN, + "qwen3-8b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN, + "qwen3-8b-base": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN, + "qwen3-14b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN, + "qwen3-14b-base": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN, + "qwen3-32b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN, "llama3.1-8b": LLAMA31_MAXTEXT_TO_HF_PARAM_HOOK_FN, "llama3.1-70b": LLAMA31_MAXTEXT_TO_HF_PARAM_HOOK_FN, "llama3.1-405b": LLAMA31_MAXTEXT_TO_HF_PARAM_HOOK_FN, - "qwen3-30b-a3b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN, - "qwen3-30b-a3b-base": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN, - "qwen3-235b-a22b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN, - "qwen3-coder-480b-a35b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN, + "qwen3-30b-a3b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN, + "qwen3-30b-a3b-base": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN, + "qwen3-235b-a22b": 
QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN, + "qwen3-coder-480b-a35b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN, "deepseek3-671b": DEEPSEEK_MAXTEXT_TO_HF_PARAM_HOOK_FN, "gpt-oss-20b": GPT_OSS_TO_HF_PARAM_HOOK_FN, "gpt-oss-120b": GPT_OSS_TO_HF_PARAM_HOOK_FN, diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index 814eefe0b5..04bb9d9da5 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -42,7 +42,7 @@ normalization_layer_epsilon: 1.e-05 # epsilon value for rmsnorm, layernorm. load_parameters_path: "" # LoRA adapter support configs -lora_input_adapters_path: "" # Input GCS path for a parent directory which has all the LoRA adapters (lora_id as subdir) +lora_input_adapters_path: "" # Input GCS path for a parent directory which has all the LoRA adapters (lora_id as subdir) # Loads a full checkpoint including optimizer state and step count from a specific directory # e.g. gs://my-base-output-directory/my-previous-run-name/checkpoints/items/NUMBER or NUMBER/items @@ -81,14 +81,12 @@ checkpoint_conversion_fn: none # optional checkpoint context to use for loading. options: "orbax", "safetensors" source_checkpoint_layout: "orbax" -# Only applicable to Single Controller/Pathways on Cloud. Experimental feature, under testing +# Only applicable to Single Controller/Pathways on Cloud. Experimental feature, under testing colocated_python_checkpointing: False ############################### end checkpointing ################################## - reuse_example_batch: 0 # for testing tpu performance, this options repeated uses the same batch. - metrics_file: "" # for testing, local file that stores scalar metrics. if empty, no metrics are written. # if true save metrics such as loss and tflops to gcs in {base_output_directory}/{run_name}/metrics/ gcs_metrics: false @@ -159,8 +157,8 @@ mlp_activations: ["silu", "linear"] mlp_activations_limit: -1.0 dropout_rate: 0.0 logits_via_embedding: false -normalize_embedding_logits: true # whether to normalize pre-softmax logits if logits_via_embedding is true -logits_dot_in_fp32: false # whether to use fp32 in logits_dense or shared_embedding dot product for stability +normalize_embedding_logits: true # whether to normalize pre-softmax logits if logits_via_embedding is true +logits_dot_in_fp32: false # whether to use fp32 in logits_dense or shared_embedding dot product for stability cast_logits_to_fp32: true # whether to cast the logits to fp32. the higher precision is generally beneficial, but it can vary slightly. float32_qk_product: false # in dot_product attention, whether to cast to fp32 the inputs to qk product float32_logits: false # in dot_product attention, whether to cast to fp32 the inputs to softmax @@ -304,33 +302,27 @@ scan_layers_per_stage: False set_remat_policy_on_pipeline_iterations: True set_remat_policy_on_layers_per_stage: False - # Choose 'remat_policy' between 'minimal_with_context', 'minimal', 'save_dot_with_context_except_mlp', 'save_dot_except_mlpwi', 'save_dot_except_mlp', # 'save_qkv_proj', 'qkv_proj_offloaded', 'custom', 'minimal_offloaded', 'save_out_proj' and 'full'. # These options offer a trade-off between speed (fastest to slowest) and HBM usage (highest to lowest) -remat_policy: 'full' +remat_policy: "full" # If "custom" remat_policy is chosen, you can select tensors from the following list to offload on host memory, rematerialize or save on device memory. 
# Pick one of these options for following tensors: ['remat','device','offload'] -decoder_layer_input: 'device' # this tensor cannot be rematerialized - it serves as periodic checkpoints that act as the remat start points -context: 'remat' # From https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/jax/attention.py#L581-L583 -mlpwi: 'remat' -mlpwi_0: 'remat' -mlpwi_1: 'remat' -mlpwo: 'remat' -moe_mlpwi_0: 'remat' -moe_mlpwi_1: 'remat' -moe_mlpwo: 'remat' -query_proj: 'remat' -key_proj: 'remat' -value_proj: 'remat' -qkv_proj: 'remat' -out_proj: 'remat' -query_wa_proj: 'remat' -kv_wa_proj: 'remat' -mla_q: 'remat' -mla_kv: 'remat' -attention_out: 'remat' -engram: 'remat' +decoder_layer_input: "device" # this tensor cannot be rematerialized - it serves as periodic checkpoints that act as the remat start points +context: "remat" # From https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/jax/attention.py#L581-L583 +mlpwi: "remat" +mlpwi_0: "remat" +mlpwi_1: "remat" +mlpwo: "remat" +query_proj: "remat" +key_proj: "remat" +value_proj: "remat" +qkv_proj: "remat" +out_proj: "remat" +mla_q: "remat" +mla_kv: "remat" +attention_out: "remat" +engram: "remat" optimizer_memory_host_offload: False parameter_memory_host_offload: False @@ -339,8 +331,8 @@ param_scan_axis: 1 # The attention parameter dictates the specific algorithm/methodology used to compute the attention scores # The attention_type parameter determines the variants of attention, e.g. global or local_sliding -attention: 'autoselected' # Supported attention: autoselected, dot_product, flash, cudnn_flash_te -attention_type: 'global' # Supported attention_type: global, local_sliding, chunk, mla +attention: "autoselected" # Supported attention: autoselected, dot_product, flash, cudnn_flash_te +attention_type: "global" # Supported attention_type: global, local_sliding, chunk, mla share_kv_projections: False # Note: Not compatible with attention_type='mla' attention_bias: False # If True, adds a learnable bias to the query, key, and value projections attention_sink: False @@ -353,7 +345,6 @@ use_post_attn_norm: False use_post_ffw_norm: False mla_naive_kvcache: True - # Adding Mixture of Block Attention Support (MoBA): https://github.com/MoonshotAI/MoBA/blob/master/MoBA_Tech_Report.pdf moba: False moba_chunk_size: 1024 @@ -361,13 +352,13 @@ moba_topk: 8 # DeepSeek Sparse Attention (DSA) # deepseek3.2 introduces indexer in MLA -use_indexer: False -indexer_head_dim: 128 -indexer_n_heads: 64 -indexer_topk: 2048 -# Determines the training strategy for the indexer: -# - False (Dense Warm-up): Computes indexer loss over all tokens. Used with `trainable_parameters_mask` to freeze other model parameters. -# - True (Sparse Training): Computes indexer loss over top-k tokens only and detaches the indexer input for independent optimization. +use_sparse_indexer: False +index_head_dim: 128 +index_n_heads: 64 +index_topk: 2048 +# Determines the token selection strategy for indexer loss: +# - False: Uses all tokens (Dense Warm-up). +# - True: Uses only top-k tokens (Sparse Training). # Note: This is only active when `indexer_loss_scaling_factor` > 0. 
indexer_sparse_training: False # Multiplier for the indexer KL divergence loss @@ -381,8 +372,8 @@ qk_rope_head_dim: 64 v_head_dim: 128 # QK-Clip (Muon Clip) Configuration -use_qk_clip: False # Enable QK-Clip (supported in MLA with DotProduct or Tokamax Splash) -qk_clip_threshold: 100.0 # Threshold for clipping (tau in the paper) +use_qk_clip: False # Enable QK-Clip (supported in MLA with DotProduct or Tokamax Splash) +qk_clip_threshold: 100.0 # Threshold for clipping (tau in the paper) # Combine matmuls for QKV and MLP fused_qkv: False @@ -411,7 +402,6 @@ multi_tier_checkpointing_backup_interval_minutes: 0 # It should be a positive number when enabling multi-tier checkpointing. If set to 0, it will be set to num of slices. mtc_data_parallelism: 0 - # Whether to enable emergency checkpoint. If True, `local_checkpoint_directory` and a non-zero `local_checkpoint_period` must also be specified. # Emergency checkpoint is an experimental Orbax feature that: periodically saves to persistent storage and, with a larger invertal, saves to a local directory. # During restore, if a local copy is available in any slice, it will be broadcast to other slices without having to fetch from persistent storage. @@ -428,7 +418,7 @@ local_checkpoint_period: 0 jax_cache_dir: "~/jax_cache" # Hardware -hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu', 'gpu_multiprocess' and 'cpu' +hardware: "tpu" # Supported hardware types are 'tpu', 'gpu', 'gpu_multiprocess' and 'cpu' # internal_compile allows bypassing open-source topology name mappings when using internal topologies directly via get_topology_desc. internal_compile: False @@ -436,111 +426,212 @@ internal_compile_num_devices: -1 # You must specify the number of devices when u # Parallelism shard_mode: "auto" # can be either auto or explicit -custom_mesh_and_rule: "" # replace default mesh and logical rule by specifying yml name under config/mesh_and_rule/. 
-mesh_axes: ['diloco', 'data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive'] -logical_axis_rules: [ - ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], - ['activation_batch_moe', ['data', 'fsdp', 'fsdp_transpose', 'expert']], - ['activation_batch_no_exp', ['data', 'fsdp', 'fsdp_transpose']], - ['activation_batch_no_exp_moe', ['data', 'fsdp', 'fsdp_transpose']], - ['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert']], - ['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']], - ['activation_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence','autoregressive']], - ['activation_kv_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence']], - ['activation_length', ['sequence', 'context', 'expert']], - ['activation_length', ['context', 'expert']], - ['activation_attn_length', ['sequence', 'context', 'expert']], - ['activation_attn_length', ['context', 'expert']], - ['activation_attn_length_no_exp', ['sequence', 'context']], - ['activation_attn_length_no_exp', ['context']], - ['activation_length_no_exp', ['sequence', 'context']], - ['activation_length_no_exp', ['context']], - ['activation_length_no_exp_moe', ['sequence', 'context']], - ['activation_length_no_exp_moe', ['context']], - ['activation_norm_length', ['tensor_sequence', 'context', 'sequence']], - ['activation_norm_length_moe', ['tensor_sequence', 'context', 'sequence']], - ['activation_q_length', ['context', 'expert']], - ['activation_q_length_no_exp', ['context']], - ['prefill_activation_length', ['sequence', 'context']], - ['prefill_activation_norm_length', ['tensor_sequence', 'context', 'sequence']], - ['activation_kv_length', []], - ['activation_attn_embed', ['tensor', 'tensor_transpose']], - ['activation_embed', ['tensor', 'tensor_transpose']], - ['activation_embed_moe', ['tensor', 'tensor_transpose']], - ['activation_mlp', ['tensor', 'tensor_transpose', 'tensor_sequence']], - ['activation_kv', ['tensor', 'tensor_transpose', 'tensor_sequence']], - ['activation_prefill_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], - ['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], - ['activation_kv_batch_no_exp', ['data', 'fsdp', 'fsdp_transpose']], - ['activation_kv_head_dim', ['tensor', 'tensor_transpose', 'tensor_sequence']], - ['activation_vocab', ['tensor', 'tensor_transpose', 'tensor_sequence']], - ['activation_vocab', ['tensor', 'tensor_transpose']], - ['activation_vocab', 'tensor_sequence'], - ['activation_vocab', ['sequence','context']], - ['activation_stage', 'stage'], - ['activation_exp', ['expert']], - ['decode_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], - ['decode_length', ['sequence']], - ['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']], - ['mlp_no_fsdp', ['tensor', 'tensor_sequence', 'autoregressive']], - ['vocab', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], - ['heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], - ['q_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], - ['kv_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], - ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context', 'expert']], - ['embed', ['fsdp', 'sequence', 'tensor_transpose', 'context' , 'expert']], - 
['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']], - ['embed', ['fsdp', 'sequence', 'context', 'expert']], - ['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context']], - ['embed_no_exp', ['fsdp', 'sequence', 'tensor_transpose', 'context']], - ['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence', 'context']], - ['embed_no_exp', ['fsdp', 'sequence', 'context']], - ['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context', 'expert']], - ['embed_moe', ['fsdp', 'sequence', 'tensor_transpose', 'context' , 'expert']], - ['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']], - ['embed_moe', ['fsdp', 'sequence', 'context', 'expert']], - ['embed_no_exp_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context']], - ['embed_no_exp_moe', ['fsdp', 'sequence', 'tensor_transpose', 'context']], - ['embed_no_exp_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'context']], - ['embed_no_exp_moe', ['fsdp', 'sequence', 'context']], - ['embed_tensor_transpose', ['tensor_transpose']], - ['q_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'tensor_transpose', 'expert']], - ['q_lora', ['fsdp', 'sequence', 'context', 'tensor_transpose', 'expert']], - ['q_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']], - ['q_lora', ['fsdp', 'sequence', 'context', 'expert']], - ["q_lora_up_proj",[]], - ['kv_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'tensor_transpose', 'expert']], - ['kv_lora', ['fsdp', 'sequence', 'context', 'tensor_transpose', 'expert']], - ['kv_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']], - ['kv_lora', ['fsdp', 'sequence', 'context', 'expert']], - ["kv_lora_up_proj",[]], - ['norm', ['tensor', 'tensor_transpose']], - ['layers', 'stage'], - ['qkv', []], - ['kv', []], - ['kv_head_dim', []], - ['cache_batch_prefill', []], - ['cache_batch', []], - ['cache_heads_none', []], - ['cache_heads', ['autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence']], - ['cache_heads', ['autoregressive', 'tensor', 'tensor_sequence']], - ['cache_kv', []], - ['cache_sequence', []], - ['exp', 'expert'], - ['exp_with_fsdp', 'fsdp'], - ['paged_kv_heads', ['tensor']], - ['num_pages', []], - ['tokens_per_page', []], - ['paged_kv_head_dim_size', []], - ['dense_layers', []], - ['moe_layers', []], - ['engram_dim', ['tensor']], - ['mhc', []], - ['diloco', 'diloco'], - ] +mesh_axes: + [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive", + ] +logical_axis_rules: + [ + ["activation_batch", ["data", "fsdp", "fsdp_transpose", "expert"]], + ["activation_batch_no_exp", ["data", "fsdp", "fsdp_transpose"]], + [ + "activation_embed_and_logits_batch", + ["data", "stage", "fsdp", "fsdp_transpose", "expert"], + ], + [ + "activation_embed_and_logits_batch_sequence", + [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "expert", + ], + ], + [ + "activation_heads", + [ + "tensor", + "tensor_transpose", + "sequence", + "tensor_sequence", + "autoregressive", + ], + ], + [ + "activation_kv_heads", + ["tensor", "tensor_transpose", "sequence", "tensor_sequence"], + ], + ["activation_length", ["sequence", "context", "expert"]], + ["activation_length", ["context", "expert"]], + ["activation_attn_length", ["sequence", "context", "expert"]], + ["activation_attn_length", ["context", "expert"]], + 
["activation_attn_length_no_exp", ["sequence", "context"]], + ["activation_attn_length_no_exp", ["context"]], + ["activation_length_no_exp", ["sequence", "context"]], + ["activation_length_no_exp", ["context"]], + ["activation_norm_length", ["tensor_sequence", "context", "sequence"]], + ["activation_q_length", ["context", "expert"]], + ["activation_q_length_no_exp", ["context"]], + ["prefill_activation_length", ["sequence", "context"]], + [ + "prefill_activation_norm_length", + ["tensor_sequence", "context", "sequence"], + ], + ["activation_kv_length", []], + ["activation_attn_embed", ["tensor", "tensor_transpose"]], + ["activation_embed", ["tensor", "tensor_transpose"]], + ["activation_mlp", ["tensor", "tensor_transpose", "tensor_sequence"]], + ["activation_kv", ["tensor", "tensor_transpose", "tensor_sequence"]], + [ + "activation_prefill_kv_batch", + ["data", "fsdp", "fsdp_transpose", "expert"], + ], + ["activation_kv_batch", ["data", "fsdp", "fsdp_transpose", "expert"]], + ["activation_kv_batch_no_exp", ["data", "fsdp", "fsdp_transpose"]], + [ + "activation_kv_head_dim", + ["tensor", "tensor_transpose", "tensor_sequence"], + ], + ["activation_vocab", ["tensor", "tensor_transpose", "tensor_sequence"]], + ["activation_vocab", ["tensor", "tensor_transpose"]], + ["activation_vocab", "tensor_sequence"], + ["activation_vocab", ["sequence", "context"]], + ["activation_stage", "stage"], + ["activation_exp", ["expert"]], + ["decode_batch", ["data", "fsdp", "fsdp_transpose", "expert"]], + ["decode_length", ["sequence"]], + ["mlp", ["fsdp_transpose", "tensor", "tensor_sequence", "autoregressive"]], + ["mlp_no_fsdp", ["tensor", "tensor_sequence", "autoregressive"]], + [ + "vocab", + ["tensor", "tensor_transpose", "tensor_sequence", "autoregressive"], + ], + [ + "heads", + ["tensor", "tensor_transpose", "tensor_sequence", "autoregressive"], + ], + [ + "q_heads", + ["tensor", "tensor_transpose", "tensor_sequence", "autoregressive"], + ], + [ + "kv_heads", + ["tensor", "tensor_transpose", "tensor_sequence", "autoregressive"], + ], + [ + "embed", + [ + "fsdp", + "fsdp_transpose", + "sequence", + "tensor_transpose", + "context", + "expert", + ], + ], + ["embed", ["fsdp", "sequence", "tensor_transpose", "context", "expert"]], + ["embed", ["fsdp", "fsdp_transpose", "sequence", "context", "expert"]], + ["embed", ["fsdp", "sequence", "context", "expert"]], + [ + "embed_no_exp", + ["fsdp", "fsdp_transpose", "sequence", "tensor_transpose", "context"], + ], + ["embed_no_exp", ["fsdp", "sequence", "tensor_transpose", "context"]], + ["embed_no_exp", ["fsdp", "fsdp_transpose", "sequence", "context"]], + ["embed_no_exp", ["fsdp", "sequence", "context"]], + ["embed_tensor_transpose", ["tensor_transpose"]], + [ + "q_lora", + [ + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "tensor_transpose", + "expert", + ], + ], + ["q_lora", ["fsdp", "sequence", "context", "tensor_transpose", "expert"]], + ["q_lora", ["fsdp", "fsdp_transpose", "sequence", "context", "expert"]], + ["q_lora", ["fsdp", "sequence", "context", "expert"]], + ["q_lora_up_proj", []], + [ + "kv_lora", + [ + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "tensor_transpose", + "expert", + ], + ], + ["kv_lora", ["fsdp", "sequence", "context", "tensor_transpose", "expert"]], + ["kv_lora", ["fsdp", "fsdp_transpose", "sequence", "context", "expert"]], + ["kv_lora", ["fsdp", "sequence", "context", "expert"]], + ["kv_lora_up_proj", []], + ["norm", ["tensor", "tensor_transpose"]], + ["layers", "stage"], + ["qkv", []], + ["kv", []], + 
["kv_head_dim", []], + ["cache_batch_prefill", []], + ["cache_batch", []], + ["cache_heads_none", []], + [ + "cache_heads", + ["autoregressive", "tensor", "tensor_transpose", "tensor_sequence"], + ], + ["cache_heads", ["autoregressive", "tensor", "tensor_sequence"]], + ["cache_kv", []], + ["cache_sequence", []], + ["exp", "expert"], + ["exp_with_fsdp", "fsdp"], + ["paged_kv_heads", ["tensor"]], + ["num_pages", []], + ["tokens_per_page", []], + ["paged_kv_head_dim_size", []], + ["dense_layers", []], + ["moe_layers", []], + ["engram_dim", ["tensor"]], + ["mhc", []], + ["diloco", "diloco"], + ] # Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details -data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']] -input_data_sharding_logical_axes: ['activation_embed_and_logits_batch', 'activation_norm_length'] +data_sharding: + [ + [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive", + ], + ] +input_data_sharding_logical_axes: + ["activation_embed_and_logits_batch", "activation_norm_length"] # sharding tolerance: float between 0.0 and 1.0 representing the allowed percentage of non-sharded parameters. sharding_tolerance: 0.02 @@ -550,10 +641,10 @@ sharding_tolerance: 0.02 # By default, product of the DCN axes should equal number of slices # and product of the ICI axes should equal number of devices per slice. dcn_diloco_parallelism: 1 -dcn_data_parallelism: -1 # recommended DCN axis to be auto-sharded +dcn_data_parallelism: -1 # recommended DCN axis to be auto-sharded dcn_fsdp_parallelism: 1 dcn_fsdp_transpose_parallelism: 1 -dcn_sequence_parallelism: 1 # never recommended +dcn_sequence_parallelism: 1 # never recommended dcn_context_parallelism: 1 dcn_context_autoregressive_parallelism: 1 dcn_tensor_parallelism: 1 # never recommended @@ -600,8 +691,8 @@ tokenizer_path: "" tokenizer_type: "sentencepiece" # Currently supporting: "tiktoken", "sentencepiece", "huggingface" use_chat_template: False chat_template_path: "" # path to chat template json file -tokenize_train_data: True # False if the dataset is pre-tokenized -tokenize_eval_data: True # False if the dataset is pre-tokenized +tokenize_train_data: True # False if the dataset is pre-tokenized +tokenize_eval_data: True # False if the dataset is pre-tokenized add_bos: True add_eos: True # If False, use chunking for long sequences instead of truncation. @@ -619,10 +710,10 @@ per_device_batch_size: 12.0 expansion_factor_real_data: -1.0 eval_per_device_batch_size: 0.0 max_corpus_chars: 10_000_000 -train_data_columns: ['text'] # for DPO dataset containing "chosen" and "rejected" -train_image_column: 'image' -eval_data_columns: ['text'] # for DPO dataset containing "chosen" and "rejected" -eval_image_column: 'image' +train_data_columns: ["text"] # for DPO dataset containing "chosen" and "rejected" +train_image_column: "image" +eval_data_columns: ["text"] # for DPO dataset containing "chosen" and "rejected" +eval_image_column: "image" packing: True num_epoch: 1 generate_padding_batch_train: False @@ -661,30 +752,30 @@ sft_train_on_completion_only: False dataset_type: tfds # for TFDS input pipeline (dataset_type=tfds) dataset_path: "" # your path given as argument in download_dataset.sh, e.g. 
"gs://my-maxtext-dataset/" -dataset_name: 'c4/en:3.0.1' -eval_dataset_name: 'c4/en:3.0.1' -train_split: 'train' -eval_split: 'validation' +dataset_name: "c4/en:3.0.1" +eval_dataset_name: "c4/en:3.0.1" +train_split: "train" +eval_split: "validation" # for HuggingFace input pipeline (dataset_type=hf) # Check definition at https://github.com/huggingface/datasets/blob/0feb65dd8733191dd2d1e74215b422fc5939a56a/src/datasets/load.py#L1338-L1408 -hf_path: '' -hf_name: '' -hf_data_dir: '' -hf_train_files: '' -hf_eval_split: '' -hf_eval_files: '' -hf_access_token: '' +hf_path: "" +hf_name: "" +hf_data_dir: "" +hf_train_files: "" +hf_eval_split: "" +hf_eval_files: "" +hf_access_token: "" # for Grain input pipeline (dataset_type=grain) # Path to grain data files. Can be a single pattern or multiple patterns with weights. # For multiple patterns, use semicolon (;) to separate and comma (,) to specify weights. # Example: "path/to/data1.array_record*,0.3;path/to/data2.array_record*,0.7" # Note: When using multiple files (separated by ';'), only ArrayRecord format is supported. # For more details, see https://github.com/AI-Hypercomputer/maxtext/blob/main/docs/guides/data_input_pipeline/data_input_grain.md -grain_train_files: '' -grain_eval_files: '' -grain_train_mixture_config_path: '' # Path to a JSON file specifying the mixture weights for Grain training data. -grain_file_type: 'arrayrecord' # arrayrecord or parquet -grain_packing_type: 'first_fit' # 'first_fit', 'best_fit' or 'concat_then_split'. See details of the corresponding module in https://google-grain.readthedocs.io/en/latest/grain.experimental.html +grain_train_files: "" +grain_eval_files: "" +grain_train_mixture_config_path: "" # Path to a JSON file specifying the mixture weights for Grain training data. +grain_file_type: "arrayrecord" # arrayrecord or parquet +grain_packing_type: "first_fit" # 'first_fit', 'best_fit' or 'concat_then_split'. See details of the corresponding module in https://google-grain.readthedocs.io/en/latest/grain.experimental.html grain_worker_count: 1 # Set to -1 to enable auto-tuning: automatically determines optimal worker count. See https://google-grain.readthedocs.io/en/latest/_autosummary/grain.experimental.pick_performance_config.html grain_per_worker_buffer_size: 1 # num_threads and prefetch_buffer_size are per-worker per-dataset. @@ -698,10 +789,9 @@ grain_per_worker_buffer_size_eval: 1 grain_ram_budget_mb: 1024 # RAM budget (MB) for auto-tuning worker count. Only used when grain_worker_count is -1. grain_num_threads_eval: 16 grain_prefetch_buffer_size_eval: 500 -grain_data_source_max_workers: 16 # Max workers for ThreadPoolExecutor when mixing multiple Grain data sources. -grain_shuffle_buffer_size: 100 # shuffle buffer when using sequential access formats such as Parquet, TFRecord. +grain_data_source_max_workers: 16 # Max workers for ThreadPoolExecutor when mixing multiple Grain data sources. # for using pathways -colocated_python_data_input: False # experimental feature, under testing +colocated_python_data_input: False # experimental feature, under testing # Training loop steps: 150_001 # If set to -1 then will inherit value from learning_rate_schedule_steps @@ -733,11 +823,11 @@ skip_jax_distributed_system: False # If True we will not initialize the jax dist # # The zero learning rate section can be used to more accurately measure the fully trained model's performance. 
learning_rate: 3.e-5 -lr_schedule_type: 'cosine' # Options: 'cosine' or 'wsd' -learning_rate_final_fraction: 0.1 # Final LR as fraction of peak LR (applies to both cosine and WSD schedules) -wsd_decay_steps_fraction: 0.1 # Fraction of learning_rate_schedule_steps used for decay phase in WSD (e.g., 0.1 = 10%) -wsd_decay_style: 'linear' # Decay style for WSD schedule: 'linear' or 'cosine' -warmup_steps_fraction: 0.1 # Fraction of learning_rate_schedule_steps used for warmup phase (applies to both schedules) +lr_schedule_type: "cosine" # Options: 'cosine' or 'wsd' +learning_rate_final_fraction: 0.1 # Final LR as fraction of peak LR (applies to both cosine and WSD schedules) +wsd_decay_steps_fraction: 0.1 # Fraction of learning_rate_schedule_steps used for decay phase in WSD (e.g., 0.1 = 10%) +wsd_decay_style: "linear" # Decay style for WSD schedule: 'linear' or 'cosine' +warmup_steps_fraction: 0.1 # Fraction of learning_rate_schedule_steps used for warmup phase (applies to both schedules) learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps. # However you may choose a longer schedule (learning_rate_schedule_steps > steps), in which case the training will end before # dropping fully down. Or you may choose a shorter schedule, where the unspecified steps will have a learning rate of 0. @@ -769,7 +859,7 @@ profile_periodically_period: -1 # If set to a positive integer, profile every pr # - upload xplane profiling, if it is enabled. # - upload training metrics, at the defined log_period interval. managed_mldiagnostics: False # Whether to enable the managed diagnostics -managed_mldiagnostics_run_group: "" # Optional. Used to group multiple runs. +managed_mldiagnostics_run_group: "" # Optional. Used to group multiple runs. # Dump HLO and jaxpr options dump_hlo: False @@ -808,11 +898,7 @@ gradient_clipping_threshold: 1.0 # batch by accumulating the gradient over a set of steps. gradient_accumulation_steps: 1 -opt_type: "adamw" # one of "adamw", "adam_pax", "sgd", or "muon" -# List of parameter names/patterns to train. -# If non-empty, all other parameters will be frozen. Example: ['.*indexer.*']. -# If empty (default), all parameters are trained. -trainable_parameters_mask: [] +opt_type: "adamw" # one of "adamw", "adam_pax", "sgd", or "muon" # AdamW optimizer parameters # We use AdamW following Llama2's training details, see https://arxiv.org/pdf/2307.09288.pdf section 2.2 @@ -836,14 +922,14 @@ muon_consistent_rms: None # If None, apply width scaling to updates. If float, a # Stack trace parameters collect_stack_trace: False -stack_trace_to_cloud: False # Uploads to cloud logging if True, else to the console if False. -stack_trace_interval_seconds: 600 # Stack trace collection frequency in seconds. +stack_trace_to_cloud: False # Uploads to cloud logging if True, else to the console if False. +stack_trace_interval_seconds: 600 # Stack trace collection frequency in seconds. 
# Use iota operator in Embed use_iota_embed: False # use positional embedding use_untrainable_positional_embedding: False -trainable_position_size: -1 # enable gpt3 position embedding with a positive trainable_position_size +trainable_position_size: -1 # enable gpt3 position embedding with a positive trainable_position_size # RoPE parameters rope_type: "default" # one of "default", "llama3.1" or "yarn" rope_linear_scaling_factor: 1.0 # linear scaling factor for "default" RoPE (see class `RotaryEmbedding` for more) @@ -866,7 +952,7 @@ rope_attention_scaling: False # Scale the rotary embedding output # Ahead of time Compilation (aka AOT) # Only set these arguments if you are running train_compile or loading a compiled train step. compiled_trainstep_file: "" # Name of saved serialized compiled train_step, e.g. compiled_train_v5e-256.pickle -compile_topology: '' # Target hardware version, e.g. 'v5e-256' +compile_topology: "" # Target hardware version, e.g. 'v5e-256' compile_topology_num_slices: -1 # Number of target slices, set to a positive integer. decode_sampling_strategy: "greedy" # decode_sampling_strategy should be one of greedy, weighted, nucleus, topk, or composite(top_k -> top_p -> weighted temperature) @@ -874,11 +960,9 @@ decode_sampling_nucleus_p: -1 # set if you're doing nucleus / top-p decode_sampling_top_k: 0 # set if you're doing top-k decode_sampling_temperature: 1. -eval_interval: -1 # the specific number of train step between eval_step -eval_steps: -1 # run this number of steps for eval, recommend setting this to prevent error due to running out of evel data -target_eval_loss: 0. # early stop once reaching target eval_loss -abort_on_nan_loss: True # Check for NaN and abort if found in training loss -abort_on_inf_loss: True # Check for Inf and abort if found in training loss +eval_interval: -1 # the specific number of train step between eval_step +eval_steps: -1 # run this number of steps for eval, recommend setting this to prevent error due to running out of evel data +target_eval_loss: 0. # early stop once reaching target eval_loss # Goodput parameters enable_goodput_recording: False @@ -917,12 +1001,12 @@ inference_microbenchmark_loop_iters: 10 inference_microbenchmark_log_file_path: "" inference_microbenchmark_num_samples: [1, 2, 3, 4, 5] inference_metadata_file: "" # path to a json file -inference_server: "MaxtextInterleavedServer" # inference server to start +inference_server: "MaxtextInterleavedServer" # inference server to start prefill_slice: "v5e-16" # slice to use for prefill in disaggregation mode generate_slice: "v5e-16" # slice to use for generatation in disaggregation mode inference_benchmark_test: False enable_model_warmup: False -enable_llm_inference_pool: False # Bool to launch inference server for llm_inference_gateway with their specified APIs +enable_llm_inference_pool: False # Bool to launch inference server for llm_inference_gateway with their specified APIs multi_sampling: False return_log_prob: False @@ -1005,15 +1089,14 @@ context_parallel_strategy: "all_gather" # "all_gather" or "ring" # These settings take effect only when `attention=paged`. # They should be adjusted based on the available HBM and model config. 
# Note: one page group corresponds to one request/slot -pagedattn_num_pages: 64 # total number of pages to allocate -pagedattn_tokens_per_page: 32 # number of tokens each page can hold -pagedattn_pages_per_compute_block: 4 # number of pages processed together in pallas kernels -pagedattn_max_pages_per_group: -1 # defaults to number of pages needed to reach max_target_length +pagedattn_num_pages: 64 # total number of pages to allocate +pagedattn_tokens_per_page: 32 # number of tokens each page can hold +pagedattn_pages_per_compute_block: 4 # number of pages processed together in pallas kernels +pagedattn_max_pages_per_group: -1 # defaults to number of pages needed to reach max_target_length # Alignment of head_dim to the nearest multiple of this value, set to 0 to disable alignment. On # TPUs, the head_dim is padded to the nearest multiple of 128. pagedattn_head_dim_alignment: 128 - # Chunked Prefill Parameters prefill_chunk_size: 256 use_chunked_prefill: False @@ -1045,8 +1128,8 @@ use_multimodal: False use_audio: False freeze_vision_encoder_params: True freeze_audio_encoder_params: True -dtype_mm: "float32" # Data type for multimodal model's vision encoder -remat_policy_for_vit: "minimal" # Remat policy for multimodal model's vision encoder. Check `remat_policy` for options. +dtype_mm: "float32" # Data type for multimodal model's vision encoder +remat_policy_for_vit: "minimal" # Remat policy for multimodal model's vision encoder. Check `remat_policy` for options. image_size_for_vit: 896 # Default for Gemma3, and should be overwritten by model's config image_path: "" # Local image path used for decoding, can be multiple paths separated by comma, exp "/path/image1.jpg,/path/image2.jpg" video_path: "" # Local video path used for decoding, can be multiple paths separated by comma, exp "/path/video1.mp4,/path/video2.mp4" @@ -1144,7 +1227,7 @@ use_jax_splash: false # Path to the HuggingFace-style config directory for the adapter (e.g. src/MaxText/integration/vllm/maxtext_vllm_adapter) vllm_hf_config_path: "" # A JSON string of overrides to apply to the HuggingFace-style config for the vLLM adapter. -# This can be used to override specific settings without modifying the original config file. +# This can be used to override specific settings without modifying the original config file. vllm_hf_overrides: {} # JSON string containing additional configuration for the vLLM model (e.g. '{"maxtext_config": {...}}') vllm_additional_config: {} @@ -1159,7 +1242,7 @@ sinkhorn_iterations: 20 ################################## DeepSeek Engram ################################## # Indices of transformer layers where Engram are integrated; leave empty [] to disable. -# Example: [1, 4] attaches to the 2nd and 5th layer. +# Example: [1, 4] attaches to the 2nd and 5th layer. engram_layers: [] # The max 'n' in N-gram. Example: n=3 means it covers both 2-grams and 3-grams. engram_max_ngram_size: 3 diff --git a/src/maxtext/configs/models/deepseek3-671b-2dfsdp.yml b/src/maxtext/configs/models/deepseek3-671b-2dfsdp.yml index e55a173f1f..589e9d80b2 100644 --- a/src/maxtext/configs/models/deepseek3-671b-2dfsdp.yml +++ b/src/maxtext/configs/models/deepseek3-671b-2dfsdp.yml @@ -14,7 +14,7 @@ # model config for DeepSeek V3 - 671B that uses fsdp on two logical axes -# For DeepSeek default device-limited routing, +# For DeepSeek default device-limited routing, # please set n_routing_groups=8 and topk_routing_group=4 in your command-line arguments. 
base_emb_dim: 7168 @@ -24,7 +24,7 @@ base_mlp_dim: 18432 base_moe_mlp_dim: 2048 base_num_decoder_layers: 61 first_num_dense_layers: 3 -mlp_activations: ["silu","linear"] +mlp_activations: ["silu", "linear"] vocab_size: 129280 enable_dropout: False logits_via_embedding: False @@ -56,31 +56,44 @@ rope_truncate: True rope_attention_scaling: False override_logical_axis_rules: True -mesh_axes: ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert', 'context'] -data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert', 'context']] -logical_axis_rules: [ - ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']], - ['activation_batch_moe', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']], - ['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert', 'context']], - ['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']], - ['activation_embed_and_logits_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], - ['activation_norm_length', ['context']], - ['activation_norm_length_moe', ['context']], - ['activation_heads', []], - ['activation_stage', 'stage'], - ['embed', ['fsdp']], - ['embed_moe', ['fsdp']], - ['embed_no_exp', ['fsdp']], - ['embed_no_exp_moe', ['fsdp']], - ['q_lora', ['fsdp']], - ['kv_lora', ['fsdp']], - ['layers', 'stage'], - ['q_lora_up_proj', ['fsdp_transpose', 'expert']], - ['kv_lora_up_proj', ['fsdp_transpose', 'expert']], - ['q_heads', ['fsdp_transpose', 'expert']], - ['kv_heads', ['fsdp_transpose', 'expert']], - ['heads', ['fsdp_transpose', 'expert']], - ['mlp', ['fsdp_transpose', 'expert']], - ['mlp_only_fsdp_transpose', ['fsdp_transpose']], - ['mlp_only_tensor', ['expert']], -] +mesh_axes: ["data", "stage", "fsdp", "fsdp_transpose", "expert", "context"] +data_sharding: + [["data", "stage", "fsdp", "fsdp_transpose", "expert", "context"]] +logical_axis_rules: + [ + [ + "activation_batch", + ["data", "fsdp", "fsdp_transpose", "expert", "context"], + ], + [ + "activation_embed_and_logits_batch", + ["data", "stage", "fsdp", "fsdp_transpose", "expert", "context"], + ], + [ + "activation_kv_batch", + ["data", "fsdp", "fsdp_transpose", "expert", "context"], + ], + [ + "activation_embed_and_logits_batch", + ["data", "fsdp", "fsdp_transpose", "expert"], + ], + ["activation_norm_length", ["context"]], + ["activation_norm_length_moe", ["context"]], + ["activation_heads", []], + ["activation_stage", "stage"], + ["embed", ["fsdp"]], + ["embed_moe", ["fsdp"]], + ["embed_no_exp", ["fsdp"]], + ["embed_no_exp_moe", ["fsdp"]], + ["q_lora", ["fsdp"]], + ["kv_lora", ["fsdp"]], + ["layers", "stage"], + ["q_lora_up_proj", ["fsdp_transpose", "expert"]], + ["kv_lora_up_proj", ["fsdp_transpose", "expert"]], + ["q_heads", ["fsdp_transpose", "expert"]], + ["kv_heads", ["fsdp_transpose", "expert"]], + ["heads", ["fsdp_transpose", "expert"]], + ["mlp", ["fsdp_transpose", "expert"]], + ["mlp_only_fsdp_transpose", ["fsdp_transpose"]], + ["mlp_only_tensor", ["expert"]], + ] diff --git a/src/maxtext/configs/pyconfig.py b/src/maxtext/configs/pyconfig.py index 78d783270f..d276349999 100644 --- a/src/maxtext/configs/pyconfig.py +++ b/src/maxtext/configs/pyconfig.py @@ -29,7 +29,7 @@ import omegaconf from maxtext.configs import pyconfig_deprecated -from maxtext.utils.globals import MAXTEXT_CONFIGS_DIR, MAXTEXT_ASSETS_ROOT, HF_IDS, MAXTEXT_PKG_DIR +from maxtext.utils.globals import MAXTEXT_CONFIGS_DIR, MAXTEXT_ASSETS_ROOT, HF_IDS from maxtext.common.common_types import DecoderBlockType, ShardMode from maxtext.configs 
import types from maxtext.configs.types import MaxTextConfig @@ -83,9 +83,7 @@ def _resolve_or_infer_config(argv: list[str]) -> tuple[str, list[str]]: return resolve_config_path(argv[1]), argv[2:] module = _module_from_path(argv[0]) if module not in _CONFIG_FILE_MAPPING: - raise ValueError( - f"No config file provided and no default config found for module '{module}'" - ) + raise ValueError(f"No config file provided and no default config found for module '{module}'") config_path = os.path.join(MAXTEXT_CONFIGS_DIR, _CONFIG_FILE_MAPPING[module]) logger.warning("No config file provided, using default config mapping: %s", config_path) return config_path, argv[1:] @@ -242,9 +240,7 @@ def __init__(self, pydantic_config: types.MaxTextConfig): final_dict["dtype"] = jnp.dtype(final_dict["dtype"]) final_dict["grad_dtype"] = jnp.dtype(final_dict["grad_dtype"]) final_dict["weight_dtype"] = jnp.dtype(final_dict["weight_dtype"]) - final_dict["mu_dtype"] = ( - final_dict["weight_dtype"] if not final_dict["mu_dtype"] else jnp.dtype(final_dict["mu_dtype"]) - ) + final_dict["mu_dtype"] = final_dict["weight_dtype"] if not final_dict["mu_dtype"] else jnp.dtype(final_dict["mu_dtype"]) final_dict["logical_axis_rules"] = _lists_to_tuples(final_dict["logical_axis_rules"]) final_dict["data_sharding"] = _lists_to_tuples(final_dict["data_sharding"]) diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index 0dc581d1d0..15bf9d05b1 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -306,9 +306,7 @@ class Checkpointing(BaseModel): async_checkpointing: bool = Field(True, description="If True, uses an asynchronous checkpointer for performance.") checkpoint_period: int = Field(10_000, description="The frequency (in steps) at which to save checkpoints.") max_num_checkpoints_to_keep: int | None = Field(None, description="Maximum number of checkpoints to keep.") - enable_single_replica_ckpt_restoring: bool = Field( - False, description="One replica reads and broadcasts the checkpoint." - ) + enable_single_replica_ckpt_restoring: bool = Field(False, description="One replica reads and broadcasts the checkpoint.") force_unroll: bool = Field( False, description="During param-only checkpoint generation, whether to unroll the loop.", @@ -340,9 +338,7 @@ class OrbaxStorage(BaseModel): 2147483648, description="Target file size for chunking large arrays in Orbax." ) checkpoint_storage_use_ocdbt: bool = Field(True, description="Whether to use the OCDbT storage format for checkpoints.") - checkpoint_storage_use_zarr3: bool = Field( - True, description="Whether to use Zarr3 with OCDbT. Requires use_ocdbt=True." - ) + checkpoint_storage_use_zarr3: bool = Field(True, description="Whether to use Zarr3 with OCDbT. Requires use_ocdbt=True.") checkpoint_storage_concurrent_gb: int = Field(96, description="Concurrent GB for I/O operations during checkpointing.") @@ -831,9 +827,7 @@ class DcnParallelism(BaseModel): dcn_context_autoregressive_parallelism: int = Field(1, description="DCN axis for context autoregressive parallelism.") dcn_tensor_parallelism: int = Field(1, description="DCN axis for tensor parallelism (not recommended).") dcn_tensor_transpose_parallelism: int = Field(1, description="DCN axis for tensor transpose parallelism.") - dcn_tensor_sequence_parallelism: int = Field( - 1, description="DCN axis for tensor sequence parallelism (not recommended)." 
- ) + dcn_tensor_sequence_parallelism: int = Field(1, description="DCN axis for tensor sequence parallelism (not recommended).") dcn_pipeline_parallelism: int = Field(1, description="DCN axis for pipeline parallelism.") dcn_expert_parallelism: int = Field(1, description="DCN axis for expert parallelism.") dcn_autoregressive_parallelism: int = Field(1, description="DCN axis for autoregressive parallelism (not recommended).") @@ -892,9 +886,7 @@ class RematAndOffload(BaseModel): description="The rematerialization policy, trading off speed and memory.", ) remat_policy_for_vit: str = Field("minimal", description="Remat policy for multimodal model's vision encoder.") - decoder_layer_input: RematLocation = Field( - RematLocation.DEVICE, description="Remat policy for the decoder layer's input." - ) + decoder_layer_input: RematLocation = Field(RematLocation.DEVICE, description="Remat policy for the decoder layer's input.") context: RematLocation = Field(RematLocation.REMAT, description="Remat policy for the attention context.") mlpwi: RematLocation = Field( RematLocation.REMAT, @@ -1089,9 +1081,7 @@ class FineTuning(BaseModel): dpo_label_smoothing: float = Field(0.0, ge=0.0, le=1.0, description="Label smoothing for DPO.") dpo_beta: float = Field(0.1, description="Beta parameter for DPO.") use_sft: bool = Field(False, description="If True, enables Supervised Fine-Tuning.") - sft_train_on_completion_only: bool = Field( - False, description="If True, trains only on the completion part of the text." - ) + sft_train_on_completion_only: bool = Field(False, description="If True, trains only on the completion part of the text.") use_grpo: None | bool = Field(None, description="If True, enables Group Relative Policy Optimization.") @@ -1172,9 +1162,7 @@ class Optimizer(BaseModel): """Configuration for the optimizer and learning rate schedule.""" opt_type: OptimizerType = Field(OptimizerType.ADAMW, description="The type of optimizer to use.") - gradient_accumulation_steps: PositiveInt = Field( - 1, description="Number of steps to accumulate gradients before updating." - ) + gradient_accumulation_steps: PositiveInt = Field(1, description="Number of steps to accumulate gradients before updating.") use_tunix_gradient_accumulation: bool = Field( False, description="Whether to use the Tunix implementation for gradient accumulation.", @@ -1269,9 +1257,7 @@ class PositionalEmbedding(BaseModel): False, description="Use iota operator in Embed, an efficient way to represent positions.", ) - use_untrainable_positional_embedding: bool = Field( - False, description="Use untrainable sinusoidal positional embeddings." - ) + use_untrainable_positional_embedding: bool = Field(False, description="Use untrainable sinusoidal positional embeddings.") trainable_position_size: int = Field( -1, description="Enables GPT-3 style trainable positional embeddings if positive.", @@ -1358,9 +1344,7 @@ class InferenceServer(BaseModel): class InferenceBenchmark(BaseModel): """Configuration for running inference microbenchmarks.""" - inference_microbenchmark_prefill_lengths: str = Field( - "64,128,256,512,1024", description="Prefill lengths to benchmark." 
- ) + inference_microbenchmark_prefill_lengths: str = Field("64,128,256,512,1024", description="Prefill lengths to benchmark.") inference_microbenchmark_stages: str = Field("prefill,generate", description="Stages to benchmark.") inference_microbenchmark_loop_iters: int = Field(10, description="Number of iterations for the benchmark loop.") inference_microbenchmark_log_file_path: PathStr = Field("", description="Path to log benchmark results.") @@ -1518,9 +1502,7 @@ class Goodput(BaseModel): class GcpMonitoring(BaseModel): """Configuration for GCP-specific workload monitoring.""" - report_heartbeat_metric_for_gcp_monitoring: bool = Field( - False, description="Report heartbeat metric for GCP monitoring." - ) + report_heartbeat_metric_for_gcp_monitoring: bool = Field(False, description="Report heartbeat metric for GCP monitoring.") heartbeat_reporting_interval_in_seconds: int = Field(5, description="Interval for heartbeat metric.") report_performance_metric_for_gcp_monitoring: bool = Field( False, description="Report performance metric for GCP monitoring." @@ -1850,16 +1832,12 @@ class DerivedValues(BaseModel): ) rampup_end_step: None | int = Field(None, description="The step at which the batch size ramp-up phase concludes.") - tensors_on_device: None | list[str] = Field( - None, description="List of tensors to keep on device memory for custom remat." - ) + tensors_on_device: None | list[str] = Field(None, description="List of tensors to keep on device memory for custom remat.") tensors_to_offload: None | list[str] = Field( None, description="List of tensors to offload to host memory for custom remat." ) global_batch_size_to_load_start: None | int = Field(None, description="Starting global batch size for rampup.") - global_batch_size_to_load_increment: None | int = Field( - None, description="Increment for global batch size during rampup." - ) + global_batch_size_to_load_increment: None | int = Field(None, description="Increment for global batch size during rampup.") rampup_samples_per_increment_to_load: None | float = Field(None, description="Samples per increment for rampup.") @@ -2122,9 +2100,7 @@ def validate_and_set_hlo_dump_defaults(): if not os.environ.get("XLA_FLAGS") and not self.dump_hlo_xla_flags: self.dump_hlo_xla_flags = f"--xla_dump_to={self.dump_hlo_local_dir} --xla_dump_large_constants" if self.dump_hlo_local_module_name: - self.dump_hlo_xla_flags = ( - f"{self.dump_hlo_xla_flags} --xla_dump_hlo_module_re={self.dump_hlo_local_module_name}" - ) + self.dump_hlo_xla_flags = f"{self.dump_hlo_xla_flags} --xla_dump_hlo_module_re={self.dump_hlo_local_module_name}" if not self.dump_hlo_gcs_dir: self.dump_hlo_gcs_dir = os.path.join(self.base_output_directory, self.run_name, "xla_dump") else: @@ -2328,9 +2304,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de assert ( self.num_layers_per_pipeline_stage == 1 ), "Pipeline weight prefetching currently only supports one layer per pipeline stage." - assert ( - not self.pipeline_delay_activation_forwarding - ), "Pipeline weight prefetching does not support pipeline delay." + assert not self.pipeline_delay_activation_forwarding, "Pipeline weight prefetching does not support pipeline delay." assert not self.quantization, "Quantization is currently not supported for pipeline prefetching." assert not self.scan_layers_per_stage, "Pipeline weight prefetching currently does not support scan." 
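The HLO-dump defaulting above assembles an XLA flag string when none is provided. Here is a standalone sketch of that assembly with a hypothetical dump directory and module filter; only the flag names come from the code above.

```python
# Illustrative re-creation of the dump_hlo_xla_flags assembly; not MaxText code.
import os

dump_hlo_local_dir = "/tmp/xla_dump"           # hypothetical local dump directory
dump_hlo_local_module_name = "jit_train_step"  # hypothetical module filter; may be empty

dump_hlo_xla_flags = ""
if not os.environ.get("XLA_FLAGS") and not dump_hlo_xla_flags:
  dump_hlo_xla_flags = f"--xla_dump_to={dump_hlo_local_dir} --xla_dump_large_constants"
  if dump_hlo_local_module_name:
    dump_hlo_xla_flags = f"{dump_hlo_xla_flags} --xla_dump_hlo_module_re={dump_hlo_local_module_name}"

print(dump_hlo_xla_flags)
# (assuming XLA_FLAGS is unset)
# --xla_dump_to=/tmp/xla_dump --xla_dump_large_constants --xla_dump_hlo_module_re=jit_train_step
```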
@@ -2497,18 +2471,13 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de ) if self.quantization: raise ValueError("Quantization is not supported with 'explicit' sharding.") - if ( - self.per_device_batch_size > 0 - and (self.per_device_batch_size * self.max_target_length) % self.num_vocab_tiling != 0 - ): + if self.per_device_batch_size > 0 and (self.per_device_batch_size * self.max_target_length) % self.num_vocab_tiling != 0: raise ValueError("Per device batch size times sequence length should be divisible by the number of vocab tiles.") if self.num_vocab_tiling > 1 and self.enable_nnx: raise ValueError("We currently don't support vocab tiling on NNX module.") if self.context_parallel_size > 1 and self.context_parallel_strategy.lower() == "ring": if "gpu" not in self.hardware: - raise ValueError( - "Ring context parallelism strategy (context_parallel_strategy='ring') is only supported on GPUs." - ) + raise ValueError("Ring context parallelism strategy (context_parallel_strategy='ring') is only supported on GPUs.") if self.hardware == "gpu" and self.packing and self.attention == "cudnn_flash_te" and self.max_segments_per_seq <= 0: raise ValueError("max_segments_per_seq must be set when using TransformerEngine attention and packing") dcn_product = ( @@ -2597,9 +2566,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de if not (self.decoder_block == DecoderBlockType.DEEPSEEK and self.sparse_matmul and self.use_tokamax_gmm): raise ValueError("Batch split only supports deepseek, with `sparse_matmul=True` and `use_tokamax_gmm=True`") if self.quantization and not (self.use_qwix_quantization and self.quantization == "fp8_full"): - raise ValueError( - "Batch split quantization only supports `use_qwix_quantization=True` and `quantization=fp8_full`" - ) + raise ValueError("Batch split quantization only supports `use_qwix_quantization=True` and `quantization=fp8_full`") if self.opt_type == "muon" and self.decoder_block not in [ DecoderBlockType.DEEPSEEK, @@ -2632,6 +2599,39 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de raise ValueError("`share_kv_projections` is not compatible with `attention_type='mla'`.") # I. FINAL TYPE CONVERSIONS AND DERIVED LISTS + # Create the ici_parallelism and dcn_parallelism lists for legacy compatibility. 
+ # if self.using_pipeline_parallelism and self.mesh_axes and self.mesh_axes[0] == "stage": + # self.ici_parallelism = [ + # self.ici_diloco_parallelism, + # self.ici_pipeline_parallelism, + # self.ici_data_parallelism, + # self.ici_fsdp_parallelism, + # self.ici_fsdp_transpose_parallelism, + # self.ici_sequence_parallelism, + # self.ici_context_parallelism, + # self.ici_context_autoregressive_parallelism, + # self.ici_tensor_parallelism, + # self.ici_tensor_transpose_parallelism, + # self.ici_tensor_sequence_parallelism, + # self.ici_expert_parallelism, + # self.ici_autoregressive_parallelism, + # ] + # self.dcn_parallelism = [ + # self.dcn_diloco_parallelism, + # self.dcn_pipeline_parallelism, + # self.dcn_data_parallelism, + # self.dcn_fsdp_parallelism, + # self.dcn_fsdp_transpose_parallelism, + # self.dcn_sequence_parallelism, + # self.dcn_context_parallelism, + # self.dcn_context_autoregressive_parallelism, + # self.dcn_tensor_parallelism, + # self.dcn_tensor_transpose_parallelism, + # self.dcn_tensor_sequence_parallelism, + # self.dcn_expert_parallelism, + # self.dcn_autoregressive_parallelism, + # ] + # else: ici_map = { "diloco": self.ici_diloco_parallelism, "data": self.ici_data_parallelism, diff --git a/src/maxtext/kernels/megablox/backend.py b/src/maxtext/kernels/megablox/backend.py index 0b8804d610..71464a1abf 100644 --- a/src/maxtext/kernels/megablox/backend.py +++ b/src/maxtext/kernels/megablox/backend.py @@ -298,6 +298,7 @@ def _calculate_bytes(x: jax.Array | qpl.QArray) -> int: "tiling", "transpose_rhs", "interpret", + "vma_axes", ], ) def gmm( @@ -310,6 +311,7 @@ def gmm( existing_out: jnp.ndarray | None = None, transpose_rhs: bool = False, interpret: bool = False, + vma_axes: tuple = tuple(), ) -> jnp.ndarray: """Compute lhs[sizes[i-1]:sizes[i], :] @ rhs for each group 'i'. @@ -522,7 +524,7 @@ def out_transform_indices(n_i, grid_id, k_i, group_metadata, group_offset): } call_gmm = qpl.pallas_call( kernel, - out_shape=jax.ShapeDtypeStruct((m, n), preferred_element_type), + out_shape=jax.ShapeDtypeStruct((m, n), preferred_element_type, vma=set(vma_axes)), grid_spec=pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=2, in_specs=[ @@ -558,6 +560,7 @@ def out_transform_indices(n_i, grid_id, k_i, group_metadata, group_offset): return out +# calculates drhs - expert weight gradient @functools.partial( jax.jit, static_argnames=[ @@ -565,6 +568,7 @@ def out_transform_indices(n_i, grid_id, k_i, group_metadata, group_offset): "tiling", "num_actual_groups", "interpret", + "vma_axes", ], ) def tgmm( @@ -577,6 +581,7 @@ def tgmm( num_actual_groups: int | None = None, existing_out: jnp.ndarray | None = None, interpret: bool = False, + vma_axes: tuple = tuple(), ) -> jnp.ndarray: """Compute lhs[:, sizes[i-1]:sizes[i]] @ rhs[sizes[i-1]:sizes[i], :]. 
@@ -773,9 +778,10 @@ def out_transform_indices(n_i, k_i, grid_id, group_metadata, group_offset): "prefer_element_type": jnp.dtype(preferred_element_type).name, "num_actual_groups": num_actual_groups, } + # computes call_gmm = qpl.pallas_call( kernel, - out_shape=jax.ShapeDtypeStruct((num_actual_groups, k, n), preferred_element_type), + out_shape=jax.ShapeDtypeStruct((num_actual_groups, k, n), preferred_element_type, vma=set(vma_axes)), grid_spec=pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=2, in_specs=[ diff --git a/src/maxtext/kernels/megablox/ops.py b/src/maxtext/kernels/megablox/ops.py index 4edacae2b3..9df9117767 100644 --- a/src/maxtext/kernels/megablox/ops.py +++ b/src/maxtext/kernels/megablox/ops.py @@ -44,8 +44,8 @@ def gmm( weight_gather_axes: List[Tuple[str, int]] | None = None, input_buffer_count: tuple[int, int, int] = (2, 2, 2), combine_scopes: bool = False, - # TODO(amandaliang): get rid of the qwix_rule in favor of Qwix's interception feature - qwix_rule: qwix.QtRule | None = None, + lhs_vma_axes: tuple = tuple(), + rhs_vma_axes: tuple = tuple(), ): """Grouped matrix multiplication operation.""" quantization_rule = None @@ -64,9 +64,13 @@ def gmm( act_calibration_method="absmax", ) - gmm_fwd_bwd = lambda *args: _gmm_fwd(*args)[0] # pylint: disable=C3001 + _gmm_fwd_vma = functools.partial(_gmm_fwd, lhs_vma_axes=tuple()) + _gmm_bwd_vma = functools.partial(_gmm_bwd, lhs.dtype, rhs.dtype, lhs_vma_axes=tuple(), rhs_vma_axes=("expert",)) + gmm_fwd_bwd = lambda *args: _gmm_fwd_vma(*args)[0] # pylint: disable=C3001 + # defined custom backward propagation to be more efficient + # computes: dlhs: gradients of activations (for previous layers); drhs: gradients of weights gmm_fwd_bwd = jax.custom_vjp(gmm_fwd_bwd, nondiff_argnums=(3, 4, 5, 6, 9, 10, 11, 12, 13)) - gmm_fwd_bwd.defvjp(_gmm_fwd, functools.partial(_gmm_bwd, lhs.dtype, rhs.dtype)) + gmm_fwd_bwd.defvjp(_gmm_fwd_vma, _gmm_bwd_vma) return gmm_fwd_bwd( lhs, rhs, @@ -85,6 +89,7 @@ def gmm( ) +# wraps backend kernel def _gmm_fwd( lhs: jnp.ndarray, rhs: jnp.ndarray, @@ -100,6 +105,7 @@ def _gmm_fwd( quantization_rule: qwix.QtRule | None = None, use_tokamax_backend: bool = False, weight_gather_axes: List[Tuple[str, int]] | None = None, + lhs_vma_axes: tuple = tuple(), ) -> tuple[ jnp.ndarray, tuple[ @@ -129,7 +135,7 @@ def _gmm_fwd( calibration_method=quantization_rule.weight_calibration_method, ) # QAG is only supported for following conditions - if use_tokamax_backend: + if use_tokamax_backend: # false if quantization_rule and quantization_rule.bwd_qtype: if quantization_rule.weight_calibration_method.startswith("fixed") and isinstance(rhs, qpl.QArray): if weight_gather_axes: @@ -159,10 +165,12 @@ def _gmm_fwd( existing_out, transpose_rhs=transpose_rhs, interpret=interpret, + vma_axes=tuple(), ) return out, (lhs, rhs, group_sizes, group_offset) +# custom backward function def _gmm_bwd( lhs_dtype: jax.typing.DTypeLike, rhs_dtype: jax.typing.DTypeLike, @@ -182,6 +190,8 @@ def _gmm_bwd( jnp.ndarray | None, ], grad: jnp.ndarray, + lhs_vma_axes: tuple = tuple(), # axes for SiLU output - fsdp + rhs_vma_axes: tuple = tuple(), # axes for W_out - expert ) -> tuple[jnp.ndarray, jnp.ndarray, None, None, jnp.ndarray]: """Backward function for throughput GMM VJP.""" del preferred_element_type @@ -223,7 +233,7 @@ def _gmm_bwd( channelwise_axes=[] if quantization_rule.disable_channelwise_axes else [1], calibration_method=quantization_rule.bwd_calibration_method, ) - if use_tokamax_backend: + if use_tokamax_backend: # false dlhs = 
tokamax_backend.gmm( lhs=dlhs_dout, rhs=rhs, @@ -263,6 +273,7 @@ def _gmm_bwd( group_offset, transpose_rhs=not transpose_rhs, interpret=interpret, + vma_axes=lhs_vma_axes, ) drhs = backend.tgmm( lhs.swapaxes(0, 1), @@ -273,6 +284,7 @@ def _gmm_bwd( group_offset, num_actual_groups, interpret=interpret, + vma_axes=("expert",), ) # NOTE: If the rhs transposition is fused into the forward pass we need to diff --git a/src/maxtext/layers/attention_op.py b/src/maxtext/layers/attention_op.py index c0033b4bae..f70e6f9e75 100644 --- a/src/maxtext/layers/attention_op.py +++ b/src/maxtext/layers/attention_op.py @@ -580,7 +580,9 @@ def maybe_create_nnx(einsum, *args): def _logical_to_mesh_axes(self, logical_name): logical_rules = None if self.config.using_pipeline_parallelism else self.config.logical_axis_rules - return logical_to_mesh_axes(logical_name, mesh=self.mesh, rules=logical_rules) + return logical_to_mesh_axes( + logical_name, mesh=self.mesh, rules=logical_rules + ) def check_attention_inputs(self, query: Array, key: Array | KVTensor, value: Array | KVTensor) -> None: """Check attention inputs.""" diff --git a/src/maxtext/layers/moe.py b/src/maxtext/layers/moe.py index 9d53d1149d..f12f1040bc 100644 --- a/src/maxtext/layers/moe.py +++ b/src/maxtext/layers/moe.py @@ -640,9 +640,7 @@ def permute(self, inputs, gate_logits, pre_bias_logits, use_custom_sort_vjp=True sorted_selected_experts = jnp.argsort(flatten_selected_experts) # sort inputs for number of selected experts replicated_inputs_2d = jnp.repeat(inputs_2d, self.num_experts_per_tok, axis=0) - sorted_inputs = _sort_activations(replicated_inputs_2d, sorted_selected_experts, use_custom_sort_vjp).astype( - self.dtype - ) + sorted_inputs = _sort_activations(replicated_inputs_2d, sorted_selected_experts, use_custom_sort_vjp).astype(self.dtype) group_size = jnp.bincount(flatten_selected_experts, length=self.num_experts) # Return the experts for each sorted input. 
expert_indices = jnp.arange(self.num_experts) @@ -893,9 +891,10 @@ def sparse_matmul( ): """Perform sparse matrix multiplication of inputs and Experts.""" - def gmm( - inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_axes, input_buffer_count, combine_scopes - ): + vma_axes = tuple(axis for axis in self.config.mesh_axes if self.mesh.shape[axis] > 1) + use_vma = not self.config.use_tokamax_gmm + + def gmm(inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_axes, input_buffer_count, combine_scopes): # TODO (b/491979205) pipeline fsdp ag per repeat fails tokamax gmm if self.config.using_pipeline_parallelism and self.config.pipeline_fsdp_ag_per_repeat: tokamax_group_sizes = group_sizes @@ -904,6 +903,7 @@ def gmm( group_sizes, max_utils.generate_representative_group_sizes(inputs.shape[0], kernel.shape[0]), ) + pad_length = self.config.wi_tile_fwd_batch_seq hs_shape = inputs.shape # pad length is the 1st dimension of tiling size in gmm call @@ -1033,14 +1033,13 @@ def gmm( batch_logical_axis = "activation_batch_no_exp_moe" if self.get_tensor_transpose_parallelism_size() > 1: - input_partition_pspec = self._logical_to_mesh_axes( - (batch_logical_axis, "activation_norm_length_moe", "activation_embed_moe") - ) + input_partition_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length", "activation_embed")) w0_bias_pspec = self._logical_to_mesh_axes(("exp", None)) w1_bias_pspec = self._logical_to_mesh_axes(("exp", None)) wo_bias_pspec = self._logical_to_mesh_axes(("exp", "activation_embed_moe")) else: - input_partition_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length_moe", None)) + input_partition_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length", None)) + # expert weights are sharded by exp w0_bias_pspec = self._logical_to_mesh_axes(("exp", "activation_mlp")) w1_bias_pspec = self._logical_to_mesh_axes(("exp", "activation_mlp")) wo_bias_pspec = self._logical_to_mesh_axes(("exp", "activation_embed_moe")) @@ -1103,7 +1102,7 @@ def gmm( P(), # Handle None or replicate the output P(), # Handle None or replicate the output ), - check_vma=False, + check_vma=use_vma, ) def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, rngs): batch_size, sequence_length, _ = x.shape @@ -1119,8 +1118,7 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r # Duplicate inputs to all expert shards. x, logits, pre_bias_logits = tuple( - jax.lax.all_gather(z, axis_name=self._expert_parallelism_name, tiled=True) - for z in (x, logits, pre_bias_logits) + jax.lax.all_gather(z, axis_name=self._expert_parallelism_name, tiled=True) for z in (x, logits, pre_bias_logits) ) # "Route" tokens within each shard. 
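The permute step earlier in this moe.py hunk repeats each token once per selected expert, sorts the copies by expert id, and derives per-expert group sizes with bincount; those group sizes are what the grouped matmul (gmm) consumes. A toy, self-contained illustration with made-up shapes:

```python
# Toy permutation ahead of gmm; shapes and values are illustrative only.
import jax.numpy as jnp

num_experts, num_experts_per_tok = 4, 2
inputs_2d = jnp.arange(6.0).reshape(3, 2)               # 3 tokens, hidden size 2
selected_experts = jnp.array([[0, 2], [1, 2], [0, 3]])  # top-2 expert ids per token

flat_experts = selected_experts.reshape(-1)
sort_order = jnp.argsort(flat_experts)                  # token copies grouped by expert id
replicated = jnp.repeat(inputs_2d, num_experts_per_tok, axis=0)
sorted_inputs = replicated[sort_order]                  # rows fed to the grouped matmul
group_sizes = jnp.bincount(flat_experts, length=num_experts)

print(group_sizes)  # [2 1 2 1] -> rows per expert
```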
@@ -1262,6 +1260,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index): wi_combine_scopes = self.config.wi_combine_scopes wo_combine_scopes = self.config.wo_combine_scopes + # x * W_gate layer_w0 = gmm_fn( x, w0, @@ -1274,8 +1273,8 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index): layer_w0 = jax.lax.psum(layer_w0, "tensor_transpose") if self.config.mlp_bias: layer_w0 = layer_w0 + w0_bias - layer_w0 = adc.checkpoint_name(layer_w0, "moe_mlpwi_0") - + layer_w0 = adc.checkpoint_name(layer_w0, "mlpwi_0") + # x * W_up layer_w1 = gmm_fn( x, w1, @@ -1288,9 +1287,10 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index): layer_w1 = jax.lax.psum(layer_w1, "tensor_transpose") if self.config.mlp_bias: layer_w1 = layer_w1 + w1_bias - layer_w1 = adc.checkpoint_name(layer_w1, "moe_mlpwi_1") + layer_w1 = adc.checkpoint_name(layer_w1, "mlpwi_1") + # multiplied result from W_gate and W_up before downward projection intermediate_layer = self.apply_ffn_activation(layer_w0, layer_w1) - + # output of FFN intermediate_output = gmm_fn( intermediate_layer, wo, @@ -1435,10 +1435,8 @@ def reshape_and_update_weights(self, weights, indices): # output of updated weights: (batch_size, seq_len, num_experts) update_weights = jnp.zeros((weights.shape[0], weights.shape[1], self.num_experts), dtype=self.dtype) index_update = ( - self._maybe_shard_with_logical( - jnp.arange(weights.shape[0])[:, None, None], ("activation_batch_no_exp_moe", None, None) - ), - self._maybe_shard_with_logical(jnp.arange(weights.shape[1])[:, None], ("activation_length_no_exp_moe", None)), + self._maybe_shard_with_logical(jnp.arange(weights.shape[0])[:, None, None], ("activation_batch_no_exp", None, None)), + self._maybe_shard_with_logical(jnp.arange(weights.shape[1])[:, None], ("activation_length_no_exp", None)), indices, ) weight_sharding = ( @@ -1664,9 +1662,7 @@ def aqt_einsum(*args, **kwargs): # pylint: disable=unused-argument einsum_op = jnp.einsum return einsum_op - def maybe_all_gather_kernel_weight_in_expert_parallelism( - self, kernel: jax.Array, kernel_axes: Tuple[Optional[str], ...] 
- ): + def maybe_all_gather_kernel_weight_in_expert_parallelism(self, kernel: jax.Array, kernel_axes: Tuple[Optional[str], ...]): """All-gather kernel weight in expert parallelism if needed.""" if self.get_expert_parallelism_size() > 1: # This will trigger all-gather using weight_dtype @@ -1691,14 +1687,10 @@ def dense_matmul( ) -> tuple[jax.Array, Optional[jax.Array], Optional[jax.Array]]: """Dense matrix multiplication.""" # gate_logits: batch, length, expert - gate_logits = self._maybe_shard_with_logical( - gate_logits, ("activation_batch_moe", "activation_length_no_exp_moe", None) - ) + gate_logits = self._maybe_shard_with_logical(gate_logits, ("activation_batch_moe", "activation_length_no_exp_moe", None)) if self.config.model_name.startswith("deepseek3"): # pre_bias_logits is None for non-DeepSeek v3 models - pre_bias_logits = self._maybe_shard_with_logical( - pre_bias_logits, ("activation_batch_moe", "activation_length_no_exp_moe", None) - ) + pre_bias_logits = self._maybe_shard_with_logical(pre_bias_logits, ("activation_batch", "activation_norm_length", None)) top_k_weights, top_k_indices = self.get_topk(gate_logits, pre_bias_logits, self.rngs) is_llama4_decoder_layer = self.config.decoder_block == ctypes.DecoderBlockType.LLAMA4 if is_llama4_decoder_layer: @@ -1711,9 +1703,7 @@ def dense_matmul( # Calculate load balance loss if self.config.model_call_mode != "inference": softmax_probs = jax.nn.softmax(gate_logits.astype(jnp.float32), axis=-1).astype(self.dtype) - lb_loss = ( - self.load_balance_loss(top_k_indices, softmax_probs) if self.config.load_balance_loss_weight > 0.0 else None - ) + lb_loss = self.load_balance_loss(top_k_indices, softmax_probs) if self.config.load_balance_loss_weight > 0.0 else None else: lb_loss = None @@ -1990,9 +1980,7 @@ def retrieve_quantized_weight( # This is called only during tracing. This is to invoke creation of # quantized tensor inside AqtEinsum. After jit, this will become no-op and # will not affect performance. - _ = self.dense_matmul( - inputs, gate_logits, pre_bias_logits, w0_kernel, w1_kernel, wo_kernel, w0_bias, w1_bias, wo_bias - ) + _ = self.dense_matmul(inputs, gate_logits, pre_bias_logits, w0_kernel, w1_kernel, wo_kernel, w0_bias, w1_bias, wo_bias) w0_kernel = self.variables["aqt"]["AqtEinsum_0"]["AqtDotGeneral_0"]["qrhs"]["frozen"] w1_kernel = self.variables["aqt"]["AqtEinsum_1"]["AqtDotGeneral_0"]["qrhs"]["frozen"] diff --git a/src/maxtext/layers/pipeline.py b/src/maxtext/layers/pipeline.py index 60cb7d2ac2..5097763f54 100644 --- a/src/maxtext/layers/pipeline.py +++ b/src/maxtext/layers/pipeline.py @@ -14,10 +14,10 @@ """Pipeline layer wrapping a decoder layer(s). Supports circular pipelining.""" -import functools from typing import Any +import functools -import numpy as np +from maxtext.utils import pipeline_utils from jax import numpy as jnp from jax.sharding import Mesh, NamedSharding, PartitionSpec as P @@ -26,7 +26,6 @@ from flax.core import meta from flax import linen as nn -from flax.linen.spmd import LogicallyPartitioned from maxtext.common.common_types import Config, MODEL_MODE_TRAIN, EP_AS_CONTEXT, ShardMode from maxtext.utils.sharding import ( @@ -247,9 +246,7 @@ def get_main_vmap_func_for_iterations(self): else a set of layers if body_instance is a set of layers. 
""" - def func_to_vmap( - body_instance, weights, stages_inputs, stages_segment_ids, stages_positions, deterministic, model_mode - ): + def func_to_vmap(body_instance, weights, stages_inputs, stages_segment_ids, stages_positions, deterministic, model_mode): weights = meta.remove_axis( weights, 0, @@ -304,15 +301,11 @@ def _run_weight_initialization( else None ) example_position = ( - jax.lax.broadcast(example_position, [self.config.num_pipeline_repeats]) - if example_position is not None - else None + jax.lax.broadcast(example_position, [self.config.num_pipeline_repeats]) if example_position is not None else None ) example_inputs = self._maybe_shard_with_logical(example_inputs, (None, None, None, None)) - stage_outputs = vmap_func( - self.layers, example_inputs, example_segmentation, example_position, deterministic, model_mode - ) + stage_outputs = vmap_func(self.layers, example_inputs, example_segmentation, example_position, deterministic, model_mode) if self.config.scan_layers: stage_outputs = stage_outputs[0] if self.config.num_pipeline_repeats > 1: @@ -404,11 +397,30 @@ def init_states(self, inputs): # of delay circ_storage_mover shape is same as shift: [num_stages, micro_size, sequence, embed] if self.use_circ_storage: circ_storage = jnp.zeros((self.num_stages,) + inputs.shape, dtype=inputs.dtype, out_sharding=self.state_io_sharding) - circ_storage_mover = shift else: circ_storage = None + + # circ_storage_mover is used to push the microbatches from the pipeline into circ_storage with one buffer iteration + # of delay circ_storage_mover shape is same as shift: [num_stages, micro_size, sequence, embed] + if self.use_circ_storage: + circ_storage_mover = shift + else: circ_storage_mover = None + def _init_bsw_from_weights(variables): + """Buffer space for two copies of weights.""" + # take idx 0 slice assuming num_layers_per_pipeline_stage=1 + return ( + jax.tree.map(lambda x: jnp.zeros_like(x[0]), variables), + jax.tree.map(lambda x: jnp.zeros_like(x[0]), variables), + ) + + if self.is_initializing(): + bsw = None + else: + variables = pipeline_utils.remove_logically_partition(self.layers.variables) + bsw = _init_bsw_from_weights(variables) + init_loop_state = { "state_io": state_io, "shift": shift, @@ -416,6 +428,8 @@ def init_states(self, inputs): "circ_storage_mover": circ_storage_mover, "loop_iteration": 0, "prev_outputs": prev_outputs, + "bsw": bsw, + "weights": self.layers.variables, } return init_loop_state @@ -443,9 +457,50 @@ def shard_dim_by_stages(self, x, dim: int, physical_partition_spec: P | None, is sharding = jax.sharding.NamedSharding(self.mesh, pspec) return self._maybe_shard_with_name(x, sharding) - def vmap_parallel_gather( - self, weights, physical_partition_spec, repeat_ids, repeat_dim_in_weights, stages_dim_in_weights - ): + if self.use_circ_storage: + # Setup potential input from circ_storage, which also has a rotating index for microbatch, + # size of num_microbatches + circ_storage_batch_idx = loop_iteration % self.config.num_pipeline_microbatches + circular_stage_in = circ_storage[:, circ_storage_batch_idx] + else: + # The last stage immediately flows into the first stage, use this rotated shift instead of circular storage + circular_stage_in = shift + + # For early loop iterations we grab a new input for stage 0 from the state_io. Once each microbatch has left + # state_io we instead grab from the last stage's output (possibly buffered when num_microbatches > num_stages, e.g. + # from circ_storage). 
+ first_stage_in = jnp.where(loop_iteration < self.config.num_pipeline_microbatches, state_io_slice, circular_stage_in) + first_stage_in = self._maybe_shard_with_logical(first_stage_in, self.stages_in_logical) + + # Note that first_stage_in may correspond to bubble computation during the last few iterations. + # However, these bubble computation results remain in the shift buffer (do not make it back to state_io) and are + # thus discarded / not returned. + # The final returned output is stored in the state_io, which has the appropriate total size of num_microbatches. The + # state_io will not contain bubble results at the end of the last iteration. + + def select_state_or_input(first_stage_in, shift): + # Selects input for stage 0, shift for other stages + return jnp.where( + jax.lax.broadcasted_iota("int32", shift.shape, 0, out_sharding=self.stages_in_sharding) == 0, + first_stage_in, + shift, + ) + + # Selects input (from stream_io) for stage 0, other stages get from shift (the rotated previous output) + stages_in = select_state_or_input(first_stage_in, shift) + stages_in = self._maybe_shard_with_logical(stages_in, self.stages_in_logical) + return stages_in + + def get_microbatch_and_repeat_ids(self, loop_iteration): + """Gets the microbatch_ids and repeat_ids for all stages on this loop_iteration. Works for both circular and + non-circular""" + # Stage 0 has processed one microbatch every loop_iter, but Stage 1 is 1 behind due to bubble, etc for other stages + microbatches_processed = jnp.maximum(loop_iteration - self.forwarding_delay * jnp.arange(self.num_stages), 0) + microbatch_ids = microbatches_processed % self.config.num_pipeline_microbatches + repeat_ids = microbatches_processed // self.config.num_pipeline_microbatches + return microbatch_ids, repeat_ids + + def vmap_parallel_gather(self, weights, repeat_ids, repeat_dim_in_weights, stages_dim_in_weights): """Use vmap to implement a sharded parallel gather. Parallel gather means each stage has its own weights, and gets one slice from it. Args: @@ -624,8 +679,8 @@ def permute_output_micro_per_stage_dim(self, output): # The first real output (microbatch 0) takes a certain amount of loop iterations to finish and be pushed to # state_io - it will land on a different index of state_io depending on the number of iterations. microbatch_0_idx = self.iterations_to_complete_first_microbatch() % self.microbatches_per_stage - permutation = (np.arange(self.microbatches_per_stage) + microbatch_0_idx) % self.microbatches_per_stage - output = output[:, permutation] + output = jnp.roll(output, shift=-microbatch_0_idx, axis=1) + output = self._maybe_shard_with_logical(output, self.state_io_logical) return output def get_current_stage_weights(self, pipeline_weights, loop_iteration, physical_partition_spec=None): @@ -635,10 +690,63 @@ def get_current_stage_weights(self, pipeline_weights, loop_iteration, physical_p For non-circular pipelines, this simply returns all weights - every weight is used in every iteraiton. However for circular pipelines each stage grabs only the weights corresponding to the current repeat. 
""" + pipeline_weights = pipeline_utils.remove_logically_partition(pipeline_weights) if self.config.num_pipeline_repeats > 1: return self.get_current_repeat_from_stages( pipeline_weights, loop_iteration, physical_partition_spec=physical_partition_spec ) + return pipeline_weights + + def get_current_weights_from_bsw(self, bsw, loop_iteration, physical_partition_spec, is_initializing=None): + """Collect and gather weights from given bsw (buffer sliding window)""" + bsw_pps = jax.tree.map(pipeline_utils.remove_fsdp_from_physical_partition_spec, physical_partition_spec) + _, repeat_ids = self.get_microbatch_and_repeat_ids(loop_iteration) + target_repeat_id = repeat_ids[0] + + @jax.shard_map( + mesh=self.mesh, + in_specs=((bsw_pps, bsw_pps), P("stage")), + out_specs=(bsw_pps), + check_vma=True, + ) + def select_weights_from_bsw(bsw, repeat_id): + weights = jax.tree.map( + lambda x, y: jax.lax.select(repeat_id[0] == target_repeat_id, y, x), + bsw[0], + bsw[1], + ) + + return weights + + weights = select_weights_from_bsw(bsw, repeat_ids) + + if is_initializing is None: + is_initializing = self.is_initializing() + + circular_metadata_params = { + nn.PARTITION_NAME: "circular_repeats", + "sub_weight_split_dims_mapping": (None,), + "is_initializing": is_initializing, + "x_times": self.config.num_pipeline_repeats, + "optimizer_dims_mapping": None, + } + weights = meta.remove_axis( + weights, 0, circular_metadata_params + ) # Remove the circular metadata axis, this axis will be removed when passed to the main vmap, only one circular + # entry per stage. + + return weights + + def from_all_variables_to_repeat_weights(self, weights, loop_iteration, physical_partition_spec): + """Generate one single repeat weight from all variables.""" + _, repeat_ids = self.get_microbatch_and_repeat_ids(loop_iteration) + + def gather_weights_for_stages_in(w, spec): + return self.vmap_parallel_gather(w, repeat_ids=repeat_ids, repeat_dim_in_weights=0, stages_dim_in_weights=1) + + weights = pipeline_utils.remove_logically_partition(weights) + if physical_partition_spec is None: + weights = jax.tree.map(gather_weights_for_stages_in, weights) else: return pipeline_weights @@ -657,9 +765,152 @@ def get_current_repeat_from_stages(self, weights, loop_iteration, physical_parti weights = meta.remove_axis(weights, 0, circular_metadata_params) weights = self._remove_logically_partition(weights) - def gather_weights_for_stages_in(w, spec=None): - return self.vmap_parallel_gather( - w, repeat_ids=repeat_ids, repeat_dim_in_weights=0, stages_dim_in_weights=1, physical_partition_spec=spec + def from_all_variables_to_bsw(self, repeat_weights, physical_partition_spec): + """All gather one branch of bsw using shardmap.""" + + bsw_pps = pipeline_utils.generate_bsw_pps_from_pps(physical_partition_spec) + repeat_weights_pps = jax.tree.map(lambda p: P(*p[1:]), physical_partition_spec) + fsdp_idx = pipeline_utils.get_fsdp_index_pytree(physical_partition_spec, "fsdp") + fsdpt_idx = pipeline_utils.get_fsdp_index_pytree(physical_partition_spec, "fsdp_transpose") + expert_idx = pipeline_utils.get_fsdp_index_pytree(physical_partition_spec, "expert") + + @jax.shard_map( + mesh=self.mesh, + in_specs=(repeat_weights_pps, None, None, None), + out_specs=bsw_pps, + check_vma=True, + ) + def _all_gather_inner(sharded_weights, fsdp_idx, fsdpt_idx, expert_idx): + def _all_gather_with_path(path, x, i, j, k): + if i >= 0: + x = all_gather_invariant(x, axis_name="fsdp", axis=i - 1, tiled=True) + if j >= 0: + x = all_gather_invariant(x, 
axis_name="fsdp_transpose", axis=j - 1, tiled=True) + # path_keys = [getattr(p, "key", str(p)) for p in path] + is_moe_block = True # "MoeBlock_0" in path_keys TODO: Enable it + if k >= 0 and not is_moe_block: + x = all_gather_invariant(x, axis_name="expert", axis=k - 1, tiled=True) + return x + + return jax.tree_util.tree_map_with_path(_all_gather_with_path, sharded_weights, fsdp_idx, fsdpt_idx, expert_idx) + + return _all_gather_inner(repeat_weights, fsdp_idx, fsdpt_idx, expert_idx) + + def bsw_all_gather_over_fsdp(self, weights, physical_partition_spec, loop_iteration): + """All gather all bsw over fsdp mesh axis using shardmap.""" + cur_repeat_weights = self.from_all_variables_to_repeat_weights(weights, loop_iteration, physical_partition_spec) + nxt_repeat_weights = self.from_all_variables_to_repeat_weights(weights, loop_iteration + 1, physical_partition_spec) + bsw_0 = self.from_all_variables_to_bsw(cur_repeat_weights, physical_partition_spec) + bsw_1 = self.from_all_variables_to_bsw(nxt_repeat_weights, physical_partition_spec) + return jax.ad_checkpoint.checkpoint_name((bsw_0, bsw_1), "bsw") + + def _run_initialization( + self, + example_inputs, + example_segmentation, + example_position, + segment_idx, + position_idx, + deterministic, + model_mode, + ): + """Runs the initialization sequence mapping layers appropriately based on pipeline settings.""" + vmap_func = self.get_vmap_func_for_init() + + if self.config.num_pipeline_repeats > 1: + # To shard the weights on initialization for the circular pipeline we create weights of + # shape [num_repeat, num_stages, ...] (e.g. [num_repeat, num_stages, embed, mlp]) and shard the num_stages axis. + # We wrap the main stage vmap with a num_repeat vmap to generate this axis only for parameter initialization. + vmap_func = nn.vmap( + vmap_func, + in_axes=(0, segment_idx, position_idx, None, None), + variable_axes={ + "params": 0, + "_overwrite_with_gradient": 0, + "non_trainable": 0, + "hyper_params": 0, + }, + split_rngs={"params": True, "dropout": self.config.enable_dropout}, + metadata_params={ + nn.PARTITION_NAME: "circular_repeats", + "sub_weight_split_dims_mapping": (None,), + "is_initializing": True, + "x_times": self.config.num_pipeline_repeats, + "optimizer_dims_mapping": None, + }, + ) + + example_inputs = jax.lax.broadcast(example_inputs, [self.config.num_pipeline_repeats]) + example_segmentation = ( + jax.lax.broadcast(example_segmentation, [self.config.num_pipeline_repeats]) + if example_segmentation is not None + else None + ) + example_position = ( + jax.lax.broadcast(example_position, [self.config.num_pipeline_repeats]) if example_position is not None else None + ) + + # We only need to run one set of stages to initialize the variables, instead of looping over all microbatches for + # the full total_iterations. 
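The schedule helper get_microbatch_and_repeat_ids above encodes the pipeline bubble: stage i lags stage 0 by forwarding_delay iterations, and repeat ids advance once a stage has cycled through all microbatches. A standalone sketch with made-up stage and microbatch counts:

```python
# Toy evaluation of the microbatch/repeat schedule; numbers are illustrative only.
import jax.numpy as jnp

num_stages, num_pipeline_microbatches, forwarding_delay = 4, 4, 1

def microbatch_and_repeat_ids(loop_iteration):
  processed = jnp.maximum(loop_iteration - forwarding_delay * jnp.arange(num_stages), 0)
  return processed % num_pipeline_microbatches, processed // num_pipeline_microbatches

print(microbatch_and_repeat_ids(0))  # all stages at microbatch 0, repeat 0 (later stages still in the bubble)
print(microbatch_and_repeat_ids(5))  # stage 0 on microbatch 1 of repeat 1, stage 3 on microbatch 2 of repeat 0
```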
+ example_inputs = self._maybe_shard_with_logical(example_inputs, (None, None, None, None)) + stage_outputs = vmap_func(self.layers, example_inputs, example_segmentation, example_position, deterministic, model_mode) + if self.config.scan_layers: + stage_outputs = stage_outputs[0] + + # We return something of the correct shape (global_batch, sequence, embed) by reshaping a single stages output + # which has shape [pipeline_microbatch_size, sequence, embed] + if self.config.num_pipeline_repeats > 1: + stage_outputs = stage_outputs[0] # Remove extra dimension created for the circular vmap + broadcasted_stage_outpus = jax.lax.broadcast( + stage_outputs[0], [self.config.micro_batch_size_to_train_on // self.pipeline_microbatch_size] + ) + + return jnp.reshape( + broadcasted_stage_outpus, + [self.config.micro_batch_size_to_train_on, self.config.max_target_length, self.config.emb_dim], + out_sharding=self.output_sharding, + ) + + def get_vmap_func_for_init(self): + """This vmap func is used to initialize the weights only on init.""" + + def func_to_vmap(body_instance, stages_inputs, stages_segment_ids, stages_positions, deterministic, model_mode): + """nn.vmap requires either a nn.module class or a function whose first argument is a nn.module instance.""" + return body_instance(stages_inputs, stages_segment_ids, stages_positions, deterministic, model_mode) + + vmap_func = nn.vmap( + func_to_vmap, + in_axes=(0, 0, 0, None, None), + spmd_axis_name=self.spmd_axis_name, # TODO(b/470167805): replace self.spmd_axis_name with "stage" when JAX >= 0.8.2. + variable_axes={"params": 0, "_overwrite_with_gradient": 0}, + split_rngs={"params": self.is_initializing(), "dropout": self.config.enable_dropout}, + metadata_params={ + nn.PARTITION_NAME: "layers", + "sub_weight_split_dims_mapping": (None), + "is_initializing": self.is_initializing(), + "x_times": self.num_stages, + }, + ) + return vmap_func + + def get_main_vmap_func_for_iterations(self): + """ + Returns main stage function vmapped by number of stages. + This becomes a vmap over a single layer instance if body_instance is a single layer, + else a set of layers if body_instance is a set of layers. + """ + + def func_to_vmap(body_instance, weights, stages_inputs, stages_segment_ids, stages_positions, deterministic, model_mode): + """nn.vmap requires either a nn.module class or a function whose first argument is a nn.module instance.""" + weights = meta.remove_axis( + weights, + 0, + { + nn.PARTITION_NAME: "layers", + "sub_weight_split_dims_mapping": (None,), + "is_initializing": self.is_initializing(), + "x_times": self.num_stages, + }, ) if physical_partition_spec is None: @@ -676,8 +927,7 @@ def run_one_iteration( segment_ids, deterministic, model_mode, - decoder_layer_instance, - logical_partition_spec=None, + logical_partition_spec, ): """Run one loop iteration - gets weights and inputs for each stage, run the stages in parallel, and update the loop state. 
@@ -705,46 +955,15 @@ def run_one_iteration( vmap_func = self.get_main_vmap_func_for_iterations() - if self.config.num_pipeline_repeats > 1: - _, repeat_ids = self.get_microbatch_and_repeat_ids(loop_iteration) - - def prepare_vars_for_main_vmap(weights, physical_partition_spec=None): - circular_metadata_params = { - nn.PARTITION_NAME: "circular_repeats", - "sub_weight_split_dims_mapping": (None,), - "is_initializing": self.is_initializing(), - "x_times": self.config.num_pipeline_repeats, - "optimizer_dims_mapping": None, - } - weights = meta.remove_axis(weights, 0, circular_metadata_params) - weights = self._remove_logically_partition(weights) - - def gather_weights_for_stages_in(w, spec=None): - return self.vmap_parallel_gather( - w, repeat_ids=repeat_ids, repeat_dim_in_weights=0, stages_dim_in_weights=1, physical_partition_spec=spec - ) - - if physical_partition_spec is None: - weights = jax.tree.map(gather_weights_for_stages_in, weights) - else: - weights = jax.tree.map(gather_weights_for_stages_in, weights, physical_partition_spec) - return weights - - prepare_vars_for_main_vmap_partial = functools.partial( - prepare_vars_for_main_vmap, physical_partition_spec=physical_partition_spec - ) - vmap_func = nn.map_variables( - vmap_func, - mapped_collections=["params", "_overwrite_with_gradient", "non_trainable", "summaries", "intermediates"], - mutable=True, - trans_in_fn=prepare_vars_for_main_vmap_partial, - ) - stage_weights = self.get_current_stage_weights( - pipeline_weights, loop_iteration, physical_partition_spec=physical_partition_spec + pipeline_weights, + loop_state["bsw"], + loop_iteration, + physical_partition_spec=physical_partition_spec, + is_initializing=self.is_initializing(), ) stages_output = vmap_func( - decoder_layer_instance, + self.layers, stage_weights, stages_inputs, stages_segment_ids, @@ -758,34 +977,41 @@ def gather_weights_for_stages_in(w, spec=None): new_state = self.get_new_loop_state(stages_output, loop_state) return new_state - @staticmethod - def get_logical_spec_repeats_removed(full_logical): - """Returns a new logical spec with 'circular_repeats' removed.""" - if full_logical is None: - return None + def get_pipeline_remat_policy(self): + """Returns the pipeline remat policy for this pipeline.""" + # We ensure that the decoder layer inputs are saved, although we leave it to a custom + # policy if they should be saved to device or offloaded. + if self.config.remat_policy == "custom": + return self.remat_policy - def _remove_from_spec(spec): - return jax.sharding.PartitionSpec(*[dim for dim in spec if dim != "circular_repeats"]) + save_input_policy = jax.checkpoint_policies.save_only_these_names("iteration_input", "decoder_layer_input") + if self.remat_policy is not None: + remat_policy = jax.checkpoint_policies.save_from_both_policies(self.remat_policy, save_input_policy) + else: + remat_policy = save_input_policy + return remat_policy - return jax.tree.map(_remove_from_spec, full_logical) + def get_weight_sharding(self, *init_args): + """get weight sharding function for this pipeline.""" + # Returns a partition spec of all weights. Requires passing in arguments to init. 
+ key = jax.random.PRNGKey(0) + keys = {"params": key, "dropout": key, "aqt": key} + weights = self.init(keys, *init_args) - @staticmethod - def _remove_logically_partition(weights): - """Removes LogicallyPartitioned wrappers from the variables.""" + def get_partition_spec(pytree): + def _is_leaf(x): + return isinstance(x, nn.spmd.LogicallyPartitioned) - def _remove_logically_partition_leaf(v): - return getattr(v, "value") if isinstance(v, LogicallyPartitioned) else v + def get_partition_spec_leaf(leaf): + return leaf.get_partition_spec() - return jax.tree.map(_remove_logically_partition_leaf, weights, is_leaf=lambda v: isinstance(v, LogicallyPartitioned)) + partition_spec_tree = jax.tree.map(get_partition_spec_leaf, pytree, is_leaf=_is_leaf) + return partition_spec_tree def all_gather_over_fsdp(self, variables, logical_partition_spec): """Gathers FSDP partitioned variables to reconstruct them fully.""" - physical_partition_spec = logical_to_mesh( - logical_partition_spec, mesh=self.mesh, rules=self.config.logical_axis_rules - ) - physical_partition_spec_no_fsdp = jax.tree.map( - self._remove_fsdp_from_physical_partition_spec, physical_partition_spec - ) + physical_partition_spec = logical_to_mesh(logical_partition_spec, mesh=self.mesh, rules=self.config.logical_axis_rules) + physical_partition_spec_no_fsdp = jax.tree.map(self._remove_fsdp_from_physical_partition_spec, physical_partition_spec) return jax.tree.map( lambda w, p: self._maybe_shard_with_name(w, NamedSharding(self.mesh, p)), variables, @@ -852,21 +1078,13 @@ def __call__( # Thus the total iterations is num_micro * repeat + num_stages - 1, & we may consider the num_stages - 1 as bubble. # The bubble doubles when we use forwarding delay. bubble_iterations = self.forwarding_delay * (self.num_stages - 1) - real_iterations = self.config.num_pipeline_microbatches * self.config.num_pipeline_repeats - total_iterations = real_iterations + bubble_iterations if self.is_initializing(): return self._run_weight_initialization( example_inputs, example_segmentation, example_position, segment_idx, position_idx, deterministic, model_mode ) - if self.config.pipeline_fsdp_ag_once: - variables = self._remove_logically_partition(self.layers.variables) - all_pipeline_weights = self.all_gather_over_fsdp(variables, logical_partition_spec) - else: - all_pipeline_weights = self.layers.variables - - logical_partition_spec = self.get_logical_spec_repeats_removed(logical_partition_spec) + logical_partition_spec = pipeline_utils.get_logical_spec_repeats_removed(logical_partition_spec) def run_iteration_scannable(model, loop_state, xs): # flax transforms like nn.scan and nn.remat can only be applied to nn.module classes or nn.module instances, so we @@ -879,7 +1097,6 @@ def run_iteration_scannable(model, loop_state, xs): segment_ids, deterministic, model_mode, - model.layers, logical_partition_spec=logical_partition_spec, ), None, @@ -996,9 +1213,9 @@ def _gather_single_repeat(x, repeat_id): return jnp.squeeze(jax.lax.dynamic_slice_in_dim(x, repeat_id, 1, repeat_dim_in_weights), repeat_dim_in_weights) gathered_weights_stage_dim = 0 - stage_weights = jax.vmap( - _gather_single_repeat, in_axes=(stages_dim_in_weights, 0), out_axes=gathered_weights_stage_dim - )(weights, repeat_ids) + stage_weights = jax.vmap(_gather_single_repeat, in_axes=(stages_dim_in_weights, 0), out_axes=gathered_weights_stage_dim)( + weights, repeat_ids + ) return stage_weights def gather_microbatch_inputs_vmap(self, xs, ids, ids_dim): @@ -1340,9 +1557,7 @@ def __call__( loop_state, bsw = 
self.init_states(inputs) weights = self.layers.variables - physical_partition_spec = logical_to_mesh( - logical_partition_spec, mesh=self.mesh, rules=self.config.logical_axis_rules - ) + physical_partition_spec = logical_to_mesh(logical_partition_spec, mesh=self.mesh, rules=self.config.logical_axis_rules) bubble_iterations = self.forwarding_delay * (self.num_stages - 1) diff --git a/src/maxtext/models/deepseek.py b/src/maxtext/models/deepseek.py index 6d502d92c4..7c4ac2e6b1 100644 --- a/src/maxtext/models/deepseek.py +++ b/src/maxtext/models/deepseek.py @@ -180,7 +180,7 @@ def mlp_op(self, x, deterministic, *args, **kwargs): def with_logical_constraint(self, x): return maybe_shard_with_logical( x, - logical_axes=self.logical_axis_names, + logical_axes=tuple(self.logical_axis_names), mesh=self.mesh, shard_mode=self.config.shard_mode, debug_sharding=self.config.debug_sharding, diff --git a/src/maxtext/models/deepseek_batchsplit.py b/src/maxtext/models/deepseek_batchsplit.py index 24ccf1c7b5..92b016ff8b 100644 --- a/src/maxtext/models/deepseek_batchsplit.py +++ b/src/maxtext/models/deepseek_batchsplit.py @@ -790,12 +790,8 @@ def yarn( # (Note: We use jnp.arange with float32 for precision.) freqs = 1.0 / (rope_theta ** (2.0 * jnp.arange(0, half_dim, dtype=jnp.float32) / embedding_dims)) - low = ( - embedding_dims * math.log(original_max_position_embeddings / (beta_fast * 2 * math.pi)) / (2 * math.log(rope_theta)) - ) - high = ( - embedding_dims * math.log(original_max_position_embeddings / (beta_slow * 2 * math.pi)) / (2 * math.log(rope_theta)) - ) + low = embedding_dims * math.log(original_max_position_embeddings / (beta_fast * 2 * math.pi)) / (2 * math.log(rope_theta)) + high = embedding_dims * math.log(original_max_position_embeddings / (beta_slow * 2 * math.pi)) / (2 * math.log(rope_theta)) low = max(math.floor(low), 0) high = min(math.ceil(high), embedding_dims - 1) diff = high - low if high > low else 0.001 @@ -952,6 +948,11 @@ def gmm( input_buffer_count, combine_scopes, ): + + tokamax_group_sizes = tokamax.RaggedDotGroupSizes( + group_sizes, + max_utils.generate_representative_group_sizes(inputs.shape[0], kernel.shape[0]), + ) if config.use_qwix_quantization: output = megablox.gmm( lhs=inputs, @@ -1106,9 +1107,7 @@ def route_compute_unroute( def route_fn(inputs): # Shared expert. 
-    y = dot(
-        jax.nn.silu(dot(inputs, shared_w0, quant=quant)) * dot(inputs, shared_w1, quant=quant), shared_wo, quant=quant
-    )
+    y = dot(jax.nn.silu(dot(inputs, shared_w0, quant=quant)) * dot(inputs, shared_w1, quant=quant), shared_wo, quant=quant)
     inputs = jnp.reshape(inputs, (-1, inputs.shape[-1]))
     selected_experts, weights, group_sizes = expert_selection(
diff --git a/src/maxtext/models/models.py b/src/maxtext/models/models.py
index 0d1fcab700..b34f14953a 100644
--- a/src/maxtext/models/models.py
+++ b/src/maxtext/models/models.py
@@ -335,8 +335,8 @@ def __init__(
     if cfg.pure_nnx_decoder:
       self.decoder = NNXDecoder(config=cfg, mesh=mesh, quant=self.quant, model_mode=self.model_mode, rngs=rngs)
     else:
-      decoder_linen = Decoder(config=cfg, mesh=mesh, quant=self.quant, model_mode=self.model_mode)
-      self.decoder = nnx_wrappers.ToNNX(decoder_linen, rngs=rngs)
+      self.decoder = Decoder(config=cfg, mesh=mesh, quant=self.quant, model_mode=self.model_mode)
+      self.decoder = nnx_wrappers.ToNNX(self.decoder, rngs=rngs)
     self.hidden_states = None
     batch_size, seq_len = max_utils.get_batch_seq_len_for_mode(config=cfg, model_mode=model_mode)
diff --git a/src/maxtext/utils/globals.py b/src/maxtext/utils/globals.py
index 203d7a6165..f379a85f1e 100644
--- a/src/maxtext/utils/globals.py
+++ b/src/maxtext/utils/globals.py
@@ -24,9 +24,7 @@
 MAXTEXT_REPO_ROOT = os.environ.get(
     "MAXTEXT_REPO_ROOT",
     r
-    if os.path.isdir(
-        os.path.join(r := os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))), ".git")
-    )
+    if os.path.isdir(os.path.join(r := os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))), ".git"))
     else MAXTEXT_PKG_DIR,
 )
diff --git a/src/maxtext/utils/gradient_accumulation.py b/src/maxtext/utils/gradient_accumulation.py
index e4cad14906..4e68ca8c39 100644
--- a/src/maxtext/utils/gradient_accumulation.py
+++ b/src/maxtext/utils/gradient_accumulation.py
@@ -68,7 +68,7 @@ def _maybe_shard_with_name(inputs, sharding_names):
     return maybe_shard_with_name(inputs, sharding_names, config.shard_mode, debug_sharding=config.debug_sharding)
   # For more efficient DP/ZeRO-1 + GA
-  if config.shard_mode == ShardMode.EXPLICIT and config.ici_data_parallelism > 1:
+  if config.shard_mode == ShardMode.EXPLICIT and model.mesh.shape.get("data", 1) > 1:
     ga_params_shardings = jax.tree.map(update_sharding_for_reduced, params_shardings)
     grad_shardings = jax.tree.map(update_sharding_for_unreduced, params_shardings)
   else:
diff --git a/tests/unit/pipeline_parallelism_test.py b/tests/unit/pipeline_parallelism_test.py
index b2582d822c..5441023102 100644
--- a/tests/unit/pipeline_parallelism_test.py
+++ b/tests/unit/pipeline_parallelism_test.py
@@ -118,9 +118,7 @@ def get_inputs(batch_size, sequence, features):
     single_pipeline_stage = simple_layer.SimpleDecoderLayerToLinen(
        config=config, mesh=mesh, model_mode=model_mode, rngs=rngs
     )
-    my_pipeline = pipeline.create_pipeline(
-        config=config, layers=single_pipeline_stage, mesh=mesh
-    )
+    my_pipeline = pipeline.create_pipeline(config=config, layers=single_pipeline_stage, mesh=mesh)
     init_pipeline_params = my_pipeline.init(
         jax.random.PRNGKey(0), inputs, inputs_position, inputs_segmentation, deterministic, model_mode
     )
@@ -351,35 +349,68 @@ def test_full_train_circular(self):
   def test_full_train_circular_pipeline_ag_per_repeat(self):
     # Run a full train.py call with 4 stages, 32 layers (2 layers per stage, 4 circular repeats),
     # 8 microbatches and using pipeline ag per repeat
-    train_main([
-        None,
-        get_test_config_path(),
-        f"base_output_directory={self.base_output_directory}",
-        "run_name=runner_pipeline_parallelism_test",
-        f"dataset_path={self.dataset_path}",
-        "base_emb_dim=28",
-        "base_num_query_heads=4",
-        "base_num_kv_heads=4",
-        "base_mlp_dim=32",
-        "base_num_decoder_layers=32",
-        "head_dim=128",
-        "per_device_batch_size=2",
-        "max_target_length=1024",
-        "vocab_size=32",
-        "dataset_type=synthetic",
-        "steps=3",
-        "enable_checkpointing=False",
-        "enable_goodput_recording=False",
-        "ici_pipeline_parallelism=2",
-        "num_layers_per_pipeline_stage=1",
-        "num_pipeline_microbatches=4",
-        "pipeline_fsdp_ag_per_repeat=True",
-        (
-            rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}"
-        ),
-    ])
+    train_main(
+        [
+            None,
+            get_test_config_path(),
+            f"base_output_directory={self.base_output_directory}",
+            "run_name=runner_pipeline_parallelism_test",
+            f"dataset_path={self.dataset_path}",
+            "base_emb_dim=28",
+            "base_num_query_heads=4",
+            "base_num_kv_heads=4",
+            "base_mlp_dim=32",
+            "base_num_decoder_layers=32",
+            "head_dim=128",
+            "per_device_batch_size=2",
+            "max_target_length=1024",
+            "vocab_size=32",
+            "dataset_type=synthetic",
+            "steps=3",
+            "enable_checkpointing=False",
+            "enable_goodput_recording=False",
+            "ici_pipeline_parallelism=2",
+            "num_layers_per_pipeline_stage=1",
+            "num_pipeline_microbatches=4",
+            "pipeline_fsdp_ag_per_repeat=True",
+            rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}",
+        ]
+    )
   @pytest.mark.tpu_only
+  @pytest.mark.skip(reason="Circular pipeline does not support pipeline delay.")
   def test_delay_activation_forwarding_same_output_and_grad(self):
     # 4 stages, delayed activation forwarding, 8 layers (2 repeats, 1 layer per stage), 8 microbatches
     config = pyconfig.initialize(
@@ -398,6 +429,7 @@ def test_delay_activation_forwarding_same_output_and_grad(self):
     self.assert_pipeline_same_output_and_grad(config)
   @pytest.mark.integration_test
+  @pytest.mark.skip(reason="Non-circular pipeline is not supported.")
   @pytest.mark.tpu_only
   def test_full_train_non_circular(self):
     # Run a full train.py call with 4 stages, 32 layers (8 layers per stage), 8 microbatches
@@ -463,7 +495,8 @@ def test_subset_layers(self):
         ]
     )
-  @pytest.mark.skipif(is_decoupled(), reason="Pipeline parallelism not supported in decoupled mode")
+  # @pytest.mark.skipif(is_decoupled(), reason="Pipeline parallelism not supported in decoupled mode")
+  @pytest.mark.skip(reason="Circular pipeline does not support fp8.")
   @pytest.mark.integration_test
   def test_full_train_fp8(self):
     # Run a full train.py call with fp8 quantization, which adds extra
@@ -496,7 +529,8 @@ def test_full_train_fp8(self):
     _adapt_parallelism(args, pipeline_stages=4)
     train_main(args)
-  @pytest.mark.skipif(is_decoupled(), reason="Pipeline parallelism not supported in decoupled mode")
+  # @pytest.mark.skipif(is_decoupled(), reason="Pipeline parallelism not supported in decoupled mode")
+  @pytest.mark.skip(reason="Circular pipeline does not support fp8.")
   @pytest.mark.integration_test
   def test_full_train_nanoo_fp8(self):
     # Run a full train.py call with NANOO fp8 quantization, which adds extra