diff --git a/src/maxtext/configs/inference/vllm.yml b/src/maxtext/configs/inference/vllm.yml index 98c56bb61a..cc8f40c7e1 100644 --- a/src/maxtext/configs/inference/vllm.yml +++ b/src/maxtext/configs/inference/vllm.yml @@ -29,14 +29,14 @@ weight_dtype: bfloat16 # -------------- Logical Axis Rules -------------- mesh_axes: ['data', 'attn_dp', 'model', 'expert', 'attn_dp_expert'] logical_axis_rules: [ - ['activation_batch', ['expert']], - ['activation_batch_moe', ['expert']], + ['activation_batch', ['data']], + ['activation_batch_moe', []], ['activation_batch_no_exp', []], ['activation_batch_no_exp_moe', []], - ['activation_embed_and_logits_batch', ['expert']], - ['activation_embed_and_logits_batch_sequence', ['expert']], - ['activation_heads', ['model']], - ['activation_kv_heads', ['model']], + ['activation_embed_and_logits_batch', ['data', 'expert']], + ['activation_embed_and_logits_batch_sequence', ['data', 'expert']], + ['activation_heads', ['model', 'expert']], + ['activation_kv_heads', ['model', 'expert']], ['activation_attn_length', ['expert']], ['activation_attn_length_no_exp', []], ['activation_length', ['data', 'expert']], @@ -50,8 +50,8 @@ logical_axis_rules: [ ['activation_mlp', ['model', 'attn_dp']], ['activation_kv', ['model']], ['activation_prefill_kv_batch', ['expert', 'attn_dp_expert']], - ['activation_kv_batch', ['expert', 'attn_dp_expert']], - ['activation_kv_batch_no_exp', []], + ['activation_kv_batch', ['data', 'expert', 'attn_dp_expert']], + ['activation_kv_batch_no_exp', ['data']], ['activation_kv_head_dim', ['model']], ['activation_vocab', ['model', 'attn_dp']], ['activation_norm_length', []], @@ -64,12 +64,14 @@ logical_axis_rules: [ ['moe_mlp', ['model', 'attn_dp']], ['vocab', ['model', 'attn_dp']], ['heads', ['model']], - ['q_heads', ['model']], - ['kv_heads', ['model']], + ['q_heads', ['model', 'expert']], + ['kv_heads', ['model', 'expert']], ['kv_head_dim', []], ['kv', []], ['embed', ['expert', 'attn_dp_expert']], + ['embed', 
['attn_dp_expert']], ['embed_moe', ['expert', 'attn_dp_expert']], + ['embed_moe', ['attn_dp_expert']], ['embed_tensor_transpose', ['attn_dp', 'model']], ['embed_no_exp', []], ['embed_no_exp_moe', []], diff --git a/src/maxtext/inference/vllm_decode.py b/src/maxtext/inference/vllm_decode.py index f7df999547..66af92e209 100644 --- a/src/maxtext/inference/vllm_decode.py +++ b/src/maxtext/inference/vllm_decode.py @@ -145,7 +145,6 @@ def decode_with_vllm(config: Config) -> None: max_tokens=max_tokens_to_generate, top_k=config.decode_sampling_top_k, top_p=config.decode_sampling_nucleus_p, - seed=FLAGS.seed, ) outputs = llm.generate(prompts, sampling_params) diff --git a/src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py b/src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py index a0f3afba76..a4cd924672 100644 --- a/src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py +++ b/src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py @@ -30,6 +30,7 @@ try: from tpu_inference.layers.common.attention_metadata import AttentionMetadata + from tpu_inference.layers.common.attention_interface import ShardingAxisName except ImportError: # Mock for documentation build or environments without tpu_inference class AttentionMetadata: @@ -39,7 +40,7 @@ class AttentionMetadata: from vllm.config import VllmConfig -def generate_maxtext_config(vllm_config: VllmConfig) -> pyconfig.HyperParameters: +def generate_maxtext_config(vllm_config: VllmConfig, mesh: Mesh) -> pyconfig.HyperParameters: """Generates a MaxText configuration from a vLLM configuration. This function takes a vLLM configuration object and translates relevant @@ -50,6 +51,7 @@ def generate_maxtext_config(vllm_config: VllmConfig) -> pyconfig.HyperParameters Args: vllm_config: The vLLM configuration object containing model and load parameters. + mesh: The JAX mesh device for model sharding. Returns: A `pyconfig.HyperParameters` object configured for MaxText. 
@@ -73,6 +75,22 @@ def generate_maxtext_config(vllm_config: VllmConfig) -> pyconfig.HyperParameters base_config_path = os.path.join(MAXTEXT_CONFIGS_DIR, "inference", "vllm.yml") argv_list = ["", str(base_config_path)] + # Pad the number of KV heads if it's less than the TP / EP size + if isinstance(ShardingAxisName.ATTN_HEAD, tuple): + tp_sizes = [mesh.shape[axis_name] for axis_name in ShardingAxisName.ATTN_HEAD] + max_tp_size = max(tp_sizes) + else: + max_tp_size = mesh.shape[ShardingAxisName.ATTN_HEAD] + + if ( + max_tp_size % vllm_config.model_config.get_total_num_kv_heads() == 0 + and vllm_config.model_config.get_total_num_kv_heads() < max_tp_size + ): + max_logging.log( + f"Padding num_kv_heads from {vllm_config.model_config.get_total_num_kv_heads()} to {max_tp_size} to match tp_size." + ) + overrides["base_num_kv_heads"] = max_tp_size + maxtext_config = pyconfig.initialize(argv_list, **overrides) return maxtext_config @@ -96,7 +114,7 @@ def __init__(self, vllm_config: VllmConfig, rng_key: jax.Array, mesh: Mesh): """ self.vllm_config = vllm_config self.cfg = vllm_config.model_config - self.maxtext_config = generate_maxtext_config(vllm_config) + self.maxtext_config = generate_maxtext_config(vllm_config, mesh) # Model configuration self.mesh = mesh diff --git a/src/maxtext/layers/moe.py b/src/maxtext/layers/moe.py index 779c43fb67..5f6570bc88 100644 --- a/src/maxtext/layers/moe.py +++ b/src/maxtext/layers/moe.py @@ -911,6 +911,8 @@ def gmm( # TODO (b/491979205) pipeline fsdp ag per repeat fails tokamax gmm if self.config.using_pipeline_parallelism and self.config.pipeline_fsdp_ag_per_repeat: tokamax_group_sizes = group_sizes + elif self.config.attention == "vllm_rpa": + tokamax_group_sizes = group_sizes else: tokamax_group_sizes = tokamax.RaggedDotGroupSizes( group_sizes, diff --git a/src/maxtext/trainers/post_train/rl/train_rl.py b/src/maxtext/trainers/post_train/rl/train_rl.py index 5083a2194f..c17f74b5b5 100644 ---
a/src/maxtext/trainers/post_train/rl/train_rl.py +++ b/src/maxtext/trainers/post_train/rl/train_rl.py @@ -440,9 +440,14 @@ def create_rl_components( optimizer = utils_rl.get_optimizer(trainer_config, max_train_steps) # Setup checkpointing - checkpointing_options = ocp.CheckpointManagerOptions( - save_interval_steps=trainer_config.checkpoint_period, max_to_keep=trainer_config.max_num_checkpoints_to_keep - ) + if trainer_config.enable_checkpointing: + checkpointing_options = ocp.CheckpointManagerOptions( + save_interval_steps=trainer_config.checkpoint_period, max_to_keep=trainer_config.max_num_checkpoints_to_keep + ) + checkpoint_dir = trainer_config.checkpoint_dir + else: + checkpointing_options = None + checkpoint_dir = None # Set up micro batching micro_batch_size = None if trainer_config.micro_batch_size == -1 else trainer_config.micro_batch_size @@ -499,7 +504,7 @@ def create_rl_components( rollout_micro_batch_size=micro_batch_size, metrics_logging_options=metrics_logging_options, profiler_options=profiler_options, - checkpoint_root_directory=trainer_config.checkpoint_dir, + checkpoint_root_directory=checkpoint_dir, checkpointing_options=checkpointing_options, ), rollout_config=base_rollout.RolloutConfig( diff --git a/src/maxtext/utils/model_creation_utils.py b/src/maxtext/utils/model_creation_utils.py index b3057d0518..f492744b24 100644 --- a/src/maxtext/utils/model_creation_utils.py +++ b/src/maxtext/utils/model_creation_utils.py @@ -15,6 +15,7 @@ # pylint: disable=bare-except, consider-using-generator """ Utils that are only interesting for creating a model in MaxText. 
""" +import dataclasses from collections.abc import Sequence from functools import partial from typing import overload @@ -23,15 +24,128 @@ from flax import nnx import flax.linen as nn import jax +import jax.numpy as jnp from jax.sharding import AxisType, Mesh from maxtext.configs import pyconfig from maxtext.common.common_types import MODEL_MODE_TRAIN, ShardMode from maxtext.layers import quantizations from maxtext.models import models +from maxtext.utils import max_logging from maxtext.utils import max_utils from maxtext.utils import maxtext_utils from orbax import checkpoint as ocp +try: + from orbax.checkpoint.metadata import ArrayMetadata as _OrbaxArrayMetadata + + def _is_orbax_array_metadata(x): + return isinstance(x, _OrbaxArrayMetadata) + +except ImportError: + + def _is_orbax_array_metadata(x): + return hasattr(x, "shape") and hasattr(x, "sharding") and hasattr(x, "dtype") and not isinstance(x, jax.Array) + + +def _expand_checkpoint_to_model_shapes(ckpt_arr, model_arr): + """Expand ckpt_arr to model_arr's shape and re-shard to model_arr's sharding. + + Used to expand checkpoint KV-head (and similar) arrays that were saved with + fewer heads than the padded model shape requires (e.g. due to TP/EP padding + in adapter.py). Each dimension must divide evenly into the corresponding + model dimension. + + Uses jnp.repeat so that each original slice is placed adjacent to its copies. + For GQA with TP, device i needs KV head i//ratio from the original checkpoint, + so the correct layout is e.g. [h0, h0, h1, h1, h2, h2, h3, h3] rather than + [h0, h1, h2, h3, h0, h1, h2, h3]. + """ + ckpt_shape = ckpt_arr.shape + model_shape = model_arr.shape + if ckpt_shape == model_shape: + return jax.device_put(ckpt_arr, model_arr.sharding) + if len(ckpt_shape) != len(model_shape): + raise ValueError( + f"Checkpoint and model arrays have different ranks: {ckpt_shape} vs {model_shape}. 
" + "If the checkpoint was saved with scan_layers=True (stacked layers), convert it to " + "unscanned format before loading with vLLM (vllm.yml sets scan_layers=False)." + ) + result = ckpt_arr + for axis, (ckpt_dim, model_dim) in enumerate(zip(ckpt_shape, model_shape)): + if model_dim % ckpt_dim != 0: + raise ValueError( + f"Model dimension {model_dim} is not evenly divisible by checkpoint dimension {ckpt_dim}." + f" Full shapes — checkpoint: {ckpt_shape}, model: {model_shape}" + ) + if model_dim != ckpt_dim: + result = jnp.repeat(result, model_dim // ckpt_dim, axis=axis) + return jax.device_put(result, model_arr.sharding) + + +def _fix_restore_args_for_shape_mismatch(restore_args, stored_metadata_tree, mesh): + """Use replicated sharding for arrays whose checkpoint shape differs from the model shape. + + When the model is initialized with padded shapes (e.g. KV heads padded to match + TP size) but the checkpoint was saved with smaller shapes, Orbax will reject the + restore because the provided sharding is incompatible with the stored shape. + For those arrays we switch to a fully-replicated sharding and clear global_shape + so Orbax loads the array as-written. _expand_checkpoint_to_model_shapes then + expands and re-shards the loaded arrays to match the model. + + Uses tree_map_with_path so each ArrayRestoreArgs is looked up by path in the + metadata dict — avoids ordering/count mismatches from flattening two trees with + different pytree node types (e.g. nnx.State vs plain dict) independently. 
+ """ + replicated = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()) + + def _key_str(key): + """Extract string name from a JAX path key (DictKey, GetAttrKey, etc.).""" + if hasattr(key, "key"): + return str(key.key) + if hasattr(key, "attr"): + return str(key.attr) + return str(key) + + def _lookup_stored_meta(path): + """Navigate stored_metadata_tree using path keys from the restore_args tree.""" + node = stored_metadata_tree + for key in path: + name = _key_str(key) + if isinstance(node, dict) and name in node: + node = node[name] + else: + return None + return node + + mismatched_paths = [] + + def _fix_one(path, restore_arg): + if not isinstance(restore_arg, ocp.ArrayRestoreArgs): + return restore_arg + stored_meta = _lookup_stored_meta(path) + if stored_meta is not None and _is_orbax_array_metadata(stored_meta): + stored_shape = tuple(stored_meta.shape) + if ( + restore_arg.global_shape is not None + and restore_arg.global_shape != stored_shape + and len(stored_shape) == len(restore_arg.global_shape) + ): + mismatched_paths.append( + f" {'.'.join(_key_str(k) for k in path)}: stored={stored_shape} -> model={restore_arg.global_shape}" + ) + return dataclasses.replace( + restore_arg, global_shape=None, shape=None, sharding=replicated, mesh=None, mesh_axes=None + ) + return restore_arg + + fixed = jax.tree_util.tree_map_with_path(_fix_one, restore_args, is_leaf=lambda x: isinstance(x, ocp.ArrayRestoreArgs)) + if mismatched_paths: + max_logging.log( + f"Checkpoint shape mismatches ({len(mismatched_paths)} arrays): loading with replicated " + "sharding and expanding to model shape after restore.\n" + "\n".join(mismatched_paths) + ) + return fixed + @overload def from_config( @@ -154,6 +268,7 @@ def create_sharded_state(): with nn.logical_axis_rules(config.logical_axis_rules): sharded_state = create_sharded_state() model = nnx.merge(graphdef, sharded_state) + # print weights sharding info under debug sharding mode if config.debug_sharding: 
max_utils.print_non_trivial_mesh_axis(model.mesh) @@ -163,6 +278,7 @@ def create_sharded_state(): mesh=model.mesh, logical_annotations=specs, ) + if config.load_parameters_path: try: ckptr = ocp.Checkpointer( @@ -196,7 +312,16 @@ def create_sharded_state(): ) item_to_restore = {"params": {"params": target_for_restore}} - restore_args = {"params": {"params": ocp.checkpoint_utils.construct_restore_args(target_for_restore)}} + base_restore_args = ocp.checkpoint_utils.construct_restore_args(target_for_restore) + restore_args = { + "params": { + "params": _fix_restore_args_for_shape_mismatch( + base_restore_args, + metadata.item_metadata.tree["params"]["params"], + mesh, + ) + } + } else: # structure of nnx checkpoint: {'decoder': {'value': ...}} target_for_restore = jax.tree.map( @@ -205,7 +330,12 @@ def create_sharded_state(): is_leaf=lambda n: isinstance(n, nnx.Variable), ) item_to_restore = target_for_restore - restore_args = ocp.checkpoint_utils.construct_restore_args(target_for_restore) + base_restore_args = ocp.checkpoint_utils.construct_restore_args(target_for_restore) + restore_args = _fix_restore_args_for_shape_mismatch( + base_restore_args, + metadata.item_metadata.tree, + mesh, + ) restored = ckptr.restore( epath.Path(config.load_parameters_path), @@ -223,7 +353,22 @@ def create_sharded_state(): else: checkpoint = restored["params"]["params"] + loaded_count = len(jax.tree_util.tree_leaves(checkpoint)) + expected_count = len(jax.tree_util.tree_leaves(target_for_restore)) + if loaded_count < expected_count: + raise ValueError( + f"Checkpoint at '{config.load_parameters_path}' loaded only {loaded_count} of {expected_count} " + "expected parameter arrays. This usually means a scanned (stacked-layers) checkpoint was provided " + "where an unscanned checkpoint is required. Please convert the checkpoint to unscanned format first." 
+ ) + if checkpoint: + model_arrays = jax.tree.map( + lambda v: v.value, + sharded_state, + is_leaf=lambda n: isinstance(n, nnx.Variable), + ) + checkpoint = jax.tree.map(_expand_checkpoint_to_model_shapes, checkpoint, model_arrays) nnx.update(model, checkpoint) except Exception as e: diff --git a/tests/unit/model_creation_utils_test.py b/tests/unit/model_creation_utils_test.py new file mode 100644 index 0000000000..7f8c784176 --- /dev/null +++ b/tests/unit/model_creation_utils_test.py @@ -0,0 +1,110 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for model_creation_utils.""" + +import dataclasses +import unittest + +import jax +import jax.numpy as jnp + +from orbax import checkpoint as ocp + +# Import the private helpers under test. +from maxtext.utils.model_creation_utils import _fix_restore_args_for_shape_mismatch + + +# --------------------------------------------------------------------------- +# Minimal stub for ArrayMetadata (avoids a real Orbax checkpoint on disk). +# --------------------------------------------------------------------------- +@dataclasses.dataclass +class _FakeArrayMetadata: + shape: tuple + dtype: object = jnp.float32 + sharding: object = None + + +def _is_fake_meta(x): + return isinstance(x, _FakeArrayMetadata) + + +# Monkey-patch the module-level helper so our fake metadata is recognised. 
+import maxtext.utils.model_creation_utils as _mcu + +_orig_is_orbax = _mcu._is_orbax_array_metadata # pylint: disable=protected-access +_mcu._is_orbax_array_metadata = _is_fake_meta # pylint: disable=protected-access + + +def _make_restore_arg(global_shape): + """Return an ArrayRestoreArgs with a trivial NamedSharding.""" + mesh = jax.sharding.Mesh(jax.local_devices()[:1], ("x",)) + sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()) + return ocp.ArrayRestoreArgs( + global_shape=global_shape, + shape=global_shape, + sharding=sharding, + dtype=jnp.float32, + ) + + +class FixRestoreArgsRankGuardTest(unittest.TestCase): + """_fix_restore_args_for_shape_mismatch must not touch args when stored rank != model rank.""" + + def setUp(self): + self.mesh = jax.sharding.Mesh(jax.local_devices()[:1], ("x",)) + + def _run_fix(self, stored_shape, model_shape): + restore_args = {"kernel": _make_restore_arg(model_shape)} + metadata_tree = {"kernel": _FakeArrayMetadata(shape=stored_shape)} + return _fix_restore_args_for_shape_mismatch(restore_args, metadata_tree, self.mesh) + + def test_scanned_ckpt_unscanned_model_not_modified(self): + """Rank mismatch (scanned ckpt rank 4 vs unscanned model rank 3): arg must be unchanged.""" + # Simulates: scanned checkpoint key kernel (94, 4096, 4, 128) vs vLLM model (4096, 64, 128). + stored_shape = (94, 4096, 4, 128) + model_shape = (4096, 64, 128) + fixed = self._run_fix(stored_shape, model_shape) + arg = fixed["kernel"] + # The restore arg should be unchanged — global_shape still points to model_shape. + self.assertEqual(arg.global_shape, model_shape) + + def test_same_rank_shape_mismatch_is_modified(self): + """Same rank, shape mismatch (KV padding): arg should be switched to replicated.""" + # Simulates: unscanned checkpoint (4096, 4, 128) vs padded model (4096, 64, 128). 
+ stored_shape = (4096, 4, 128) + model_shape = (4096, 64, 128) + fixed = self._run_fix(stored_shape, model_shape) + arg = fixed["kernel"] + # global_shape must be cleared (set to None) so Orbax loads the stored shape as-is. + self.assertIsNone(arg.global_shape) + + def test_same_shape_no_modification(self): + """Identical shapes: arg must be unchanged.""" + shape = (4096, 4, 128) + fixed = self._run_fix(shape, shape) + arg = fixed["kernel"] + self.assertEqual(arg.global_shape, shape) + + def test_scanned_both_same_rank_shape_mismatch_is_modified(self): + """Scanned ckpt + scanned model + KV padding (rank 4 both): arg must be modified.""" + stored_shape = (94, 4096, 4, 128) + model_shape = (94, 4096, 64, 128) + fixed = self._run_fix(stored_shape, model_shape) + arg = fixed["kernel"] + self.assertIsNone(arg.global_shape) + + +if __name__ == "__main__": + unittest.main()