Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 12 additions & 10 deletions src/maxtext/configs/inference/vllm.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,14 @@ weight_dtype: bfloat16
# -------------- Logical Axis Rules --------------
mesh_axes: ['data', 'attn_dp', 'model', 'expert', 'attn_dp_expert']
logical_axis_rules: [
['activation_batch', ['expert']],
['activation_batch_moe', ['expert']],
['activation_batch', ['data']],
['activation_batch_moe', []],
['activation_batch_no_exp', []],
['activation_batch_no_exp_moe', []],
['activation_embed_and_logits_batch', ['expert']],
['activation_embed_and_logits_batch_sequence', ['expert']],
['activation_heads', ['model']],
['activation_kv_heads', ['model']],
['activation_embed_and_logits_batch', ['data', 'expert']],
['activation_embed_and_logits_batch_sequence', ['data', 'expert']],
['activation_heads', ['model', 'expert']],
['activation_kv_heads', ['model', 'expert']],
['activation_attn_length', ['expert']],
['activation_attn_length_no_exp', []],
['activation_length', ['data', 'expert']],
Expand All @@ -50,8 +50,8 @@ logical_axis_rules: [
['activation_mlp', ['model', 'attn_dp']],
['activation_kv', ['model']],
['activation_prefill_kv_batch', ['expert', 'attn_dp_expert']],
['activation_kv_batch', ['expert', 'attn_dp_expert']],
['activation_kv_batch_no_exp', []],
['activation_kv_batch', ['data', 'expert', 'attn_dp_expert']],
['activation_kv_batch_no_exp', ['data']],
['activation_kv_head_dim', ['model']],
['activation_vocab', ['model', 'attn_dp']],
['activation_norm_length', []],
Expand All @@ -64,12 +64,14 @@ logical_axis_rules: [
['moe_mlp', ['model', 'attn_dp']],
['vocab', ['model', 'attn_dp']],
['heads', ['model']],
['q_heads', ['model']],
['kv_heads', ['model']],
['q_heads', ['model', 'expert']],
['kv_heads', ['model', 'expert']],
['kv_head_dim', []],
['kv', []],
['embed', ['expert', 'attn_dp_expert']],
['embed', ['attn_dp_expert']],
['embed_moe', ['expert', 'attn_dp_expert']],
['embed_moe', ['attn_dp_expert']],
['embed_tensor_transpose', ['attn_dp', 'model']],
['embed_no_exp', []],
['embed_no_exp_moe', []],
Expand Down
1 change: 0 additions & 1 deletion src/maxtext/inference/vllm_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,6 @@ def decode_with_vllm(config: Config) -> None:
max_tokens=max_tokens_to_generate,
top_k=config.decode_sampling_top_k,
top_p=config.decode_sampling_nucleus_p,
seed=FLAGS.seed,
)

outputs = llm.generate(prompts, sampling_params)
Expand Down
22 changes: 20 additions & 2 deletions src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

try:
from tpu_inference.layers.common.attention_metadata import AttentionMetadata
from tpu_inference.layers.common.attention_interface import ShardingAxisName
except ImportError:
# Mock for documentation build or environments without tpu_inference
class AttentionMetadata:
Expand All @@ -39,7 +40,7 @@ class AttentionMetadata:
from vllm.config import VllmConfig


def generate_maxtext_config(vllm_config: VllmConfig) -> pyconfig.HyperParameters:
def generate_maxtext_config(vllm_config: VllmConfig, mesh: Mesh) -> pyconfig.HyperParameters:
"""Generates a MaxText configuration from a vLLM configuration.

This function takes a vLLM configuration object and translates relevant
Expand All @@ -50,6 +51,7 @@ def generate_maxtext_config(vllm_config: VllmConfig) -> pyconfig.HyperParameters
Args:
vllm_config: The vLLM configuration object containing model and load
parameters.
mesh: The JAX mesh device for model sharding.

Returns:
A `pyconfig.HyperParameters` object configured for MaxText.
Expand All @@ -73,6 +75,22 @@ def generate_maxtext_config(vllm_config: VllmConfig) -> pyconfig.HyperParameters
base_config_path = os.path.join(MAXTEXT_CONFIGS_DIR, "inference", "vllm.yml")
argv_list = ["", str(base_config_path)]

# Pad the number of KV heads if it's less than the TP / EP size
if isinstance(ShardingAxisName.ATTN_HEAD, tuple):
tp_sizes = [mesh.shape[axis_name] for axis_name in ShardingAxisName.ATTN_HEAD]
max_tp_size = max(tp_sizes)
else:
max_tp_size = mesh.shape[ShardingAxisName.ATTN_HEAD]

if (
max_tp_size % vllm_config.model_config.get_total_num_kv_heads() == 0
and vllm_config.model_config.get_total_num_kv_heads() < max_tp_size
):
max_logging.log(
f"Padding num_kv_heads from {vllm_config.model_config.get_total_num_kv_heads()} to {max_tp_size} to match tp_size."
)
overrides["base_num_kv_heads"] = max_tp_size

maxtext_config = pyconfig.initialize(argv_list, **overrides)
return maxtext_config

Expand All @@ -96,7 +114,7 @@ def __init__(self, vllm_config: VllmConfig, rng_key: jax.Array, mesh: Mesh):
"""
self.vllm_config = vllm_config
self.cfg = vllm_config.model_config
self.maxtext_config = generate_maxtext_config(vllm_config)
self.maxtext_config = generate_maxtext_config(vllm_config, mesh)

# Model configuration
self.mesh = mesh
Expand Down
2 changes: 2 additions & 0 deletions src/maxtext/layers/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -911,6 +911,8 @@ def gmm(
# TODO (b/491979205) pipeline fsdp ag per repeat fails tokamax gmm
if self.config.using_pipeline_parallelism and self.config.pipeline_fsdp_ag_per_repeat:
tokamax_group_sizes = group_sizes
elif self.config.attention == "vllm_rpa":
tokamax_group_sizes = group_sizes
else:
tokamax_group_sizes = tokamax.RaggedDotGroupSizes(
group_sizes,
Expand Down
13 changes: 9 additions & 4 deletions src/maxtext/trainers/post_train/rl/train_rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,9 +440,14 @@ def create_rl_components(
optimizer = utils_rl.get_optimizer(trainer_config, max_train_steps)

# Setup checkpointing
checkpointing_options = ocp.CheckpointManagerOptions(
save_interval_steps=trainer_config.checkpoint_period, max_to_keep=trainer_config.max_num_checkpoints_to_keep
)
if trainer_config.enable_checkpointing:
checkpointing_options = ocp.CheckpointManagerOptions(
save_interval_steps=trainer_config.checkpoint_period, max_to_keep=trainer_config.max_num_checkpoints_to_keep
)
checkpoint_dir = trainer_config.checkpoint_dir
else:
checkpointing_options = None
checkpoint_dir = None

# Set up micro batching
micro_batch_size = None if trainer_config.micro_batch_size == -1 else trainer_config.micro_batch_size
Expand Down Expand Up @@ -499,7 +504,7 @@ def create_rl_components(
rollout_micro_batch_size=micro_batch_size,
metrics_logging_options=metrics_logging_options,
profiler_options=profiler_options,
checkpoint_root_directory=trainer_config.checkpoint_dir,
checkpoint_root_directory=checkpoint_dir,
checkpointing_options=checkpointing_options,
),
rollout_config=base_rollout.RolloutConfig(
Expand Down
149 changes: 147 additions & 2 deletions src/maxtext/utils/model_creation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# pylint: disable=bare-except, consider-using-generator
""" Utils that are only interesting for creating a model in MaxText. """

import dataclasses
from collections.abc import Sequence
from functools import partial
from typing import overload
Expand All @@ -23,15 +24,128 @@
from flax import nnx
import flax.linen as nn
import jax
import jax.numpy as jnp
from jax.sharding import AxisType, Mesh
from maxtext.configs import pyconfig
from maxtext.common.common_types import MODEL_MODE_TRAIN, ShardMode
from maxtext.layers import quantizations
from maxtext.models import models
from maxtext.utils import max_logging
from maxtext.utils import max_utils
from maxtext.utils import maxtext_utils
from orbax import checkpoint as ocp

try:
from orbax.checkpoint.metadata import ArrayMetadata as _OrbaxArrayMetadata

def _is_orbax_array_metadata(x):
return isinstance(x, _OrbaxArrayMetadata)

except ImportError:

def _is_orbax_array_metadata(x):
return hasattr(x, "shape") and hasattr(x, "sharding") and hasattr(x, "dtype") and not isinstance(x, jax.Array)


def _expand_checkpoint_to_model_shapes(ckpt_arr, model_arr):
"""Expand ckpt_arr to model_arr's shape and re-shard to model_arr's sharding.

Used to expand checkpoint KV-head (and similar) arrays that were saved with
fewer heads than the padded model shape requires (e.g. due to TP/EP padding
in adapter.py). Each dimension must divide evenly into the corresponding
model dimension.

Uses jnp.repeat so that each original slice is placed adjacent to its copies.
For GQA with TP, device i needs KV head i//ratio from the original checkpoint,
so the correct layout is e.g. [h0, h0, h1, h1, h2, h2, h3, h3] rather than
[h0, h1, h2, h3, h0, h1, h2, h3].
"""
ckpt_shape = ckpt_arr.shape
model_shape = model_arr.shape
if ckpt_shape == model_shape:
return jax.device_put(ckpt_arr, model_arr.sharding)
if len(ckpt_shape) != len(model_shape):
raise ValueError(
f"Checkpoint and model arrays have different ranks: {ckpt_shape} vs {model_shape}. "
"If the checkpoint was saved with scan_layers=True (stacked layers), convert it to "
"unscanned format before loading with vLLM (vllm.yml sets scan_layers=False)."
)
result = ckpt_arr
for axis, (ckpt_dim, model_dim) in enumerate(zip(ckpt_shape, model_shape)):
if model_dim % ckpt_dim != 0:
raise ValueError(
f"Model dimension {model_dim} is not evenly divisible by checkpoint dimension {ckpt_dim}."
f" Full shapes — checkpoint: {ckpt_shape}, model: {model_shape}"
)
if model_dim != ckpt_dim:
result = jnp.repeat(result, model_dim // ckpt_dim, axis=axis)
return jax.device_put(result, model_arr.sharding)


def _fix_restore_args_for_shape_mismatch(restore_args, stored_metadata_tree, mesh):
  """Use replicated sharding for arrays whose checkpoint shape differs from the model shape.

  When the model is initialized with padded shapes (e.g. KV heads padded to match
  TP size) but the checkpoint was saved with smaller shapes, Orbax will reject the
  restore because the provided sharding is incompatible with the stored shape.
  For those arrays we switch to a fully-replicated sharding and clear global_shape
  so Orbax loads the array as-written. _expand_checkpoint_to_model_shapes then
  expands and re-shards the loaded arrays to match the model.

  Uses tree_map_with_path so each ArrayRestoreArgs is looked up by path in the
  metadata dict — avoids ordering/count mismatches from flattening two trees with
  different pytree node types (e.g. nnx.State vs plain dict) independently.

  Args:
    restore_args: pytree of ocp.ArrayRestoreArgs (leaves) describing the desired
      restore target; non-ArrayRestoreArgs leaves are passed through untouched.
    stored_metadata_tree: nested dict of per-array checkpoint metadata, keyed the
      same way as restore_args' paths.
    mesh: JAX mesh used to build the fully-replicated NamedSharding fallback.

  Returns:
    A pytree with the same structure as restore_args, where mismatched leaves
    have been replaced by shape-free, replicated ArrayRestoreArgs.
  """
  # Fully-replicated layout on this mesh — compatible with any stored array shape.
  replicated = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())

  def _key_str(key):
    """Extract string name from a JAX path key (DictKey, GetAttrKey, etc.)."""
    if hasattr(key, "key"):
      return str(key.key)
    if hasattr(key, "attr"):
      return str(key.attr)
    return str(key)

  def _lookup_stored_meta(path):
    """Navigate stored_metadata_tree using path keys from the restore_args tree."""
    node = stored_metadata_tree
    for key in path:
      name = _key_str(key)
      if isinstance(node, dict) and name in node:
        node = node[name]
      else:
        # Path not present in the stored metadata — caller treats None as "no info".
        return None
    return node

  # Mutated by _fix_one (closure); used only for the diagnostic log below.
  mismatched_paths = []

  def _fix_one(path, restore_arg):
    """Rewrite a single ArrayRestoreArgs leaf when its shape disagrees with the stored one."""
    if not isinstance(restore_arg, ocp.ArrayRestoreArgs):
      return restore_arg
    stored_meta = _lookup_stored_meta(path)
    if stored_meta is not None and _is_orbax_array_metadata(stored_meta):
      stored_shape = tuple(stored_meta.shape)
      # Only same-rank mismatches are rewritten: rank differences are rejected
      # later by _expand_checkpoint_to_model_shapes with a clearer error.
      if (
          restore_arg.global_shape is not None
          and restore_arg.global_shape != stored_shape
          and len(stored_shape) == len(restore_arg.global_shape)
      ):
        mismatched_paths.append(
            f"  {'.'.join(_key_str(k) for k in path)}: stored={stored_shape} -> model={restore_arg.global_shape}"
        )
        # Clear every shape/sharding hint so Orbax restores the array exactly
        # as written, replicated across the mesh.
        return dataclasses.replace(
            restore_arg, global_shape=None, shape=None, sharding=replicated, mesh=None, mesh_axes=None
        )
    return restore_arg

  fixed = jax.tree_util.tree_map_with_path(_fix_one, restore_args, is_leaf=lambda x: isinstance(x, ocp.ArrayRestoreArgs))
  if mismatched_paths:
    max_logging.log(
        f"Checkpoint shape mismatches ({len(mismatched_paths)} arrays): loading with replicated "
        "sharding and expanding to model shape after restore.\n" + "\n".join(mismatched_paths)
    )
  return fixed


@overload
def from_config(
Expand Down Expand Up @@ -154,6 +268,7 @@ def create_sharded_state():
with nn.logical_axis_rules(config.logical_axis_rules):
sharded_state = create_sharded_state()
model = nnx.merge(graphdef, sharded_state)

# print weights sharding info under debug sharding mode
if config.debug_sharding:
max_utils.print_non_trivial_mesh_axis(model.mesh)
Expand All @@ -163,6 +278,7 @@ def create_sharded_state():
mesh=model.mesh,
logical_annotations=specs,
)

if config.load_parameters_path:
try:
ckptr = ocp.Checkpointer(
Expand Down Expand Up @@ -196,7 +312,16 @@ def create_sharded_state():
)

item_to_restore = {"params": {"params": target_for_restore}}
restore_args = {"params": {"params": ocp.checkpoint_utils.construct_restore_args(target_for_restore)}}
base_restore_args = ocp.checkpoint_utils.construct_restore_args(target_for_restore)
restore_args = {
"params": {
"params": _fix_restore_args_for_shape_mismatch(
base_restore_args,
metadata.item_metadata.tree["params"]["params"],
mesh,
)
}
}
else:
# structure of nnx checkpoint: {'decoder': {'value': ...}}
target_for_restore = jax.tree.map(
Expand All @@ -205,7 +330,12 @@ def create_sharded_state():
is_leaf=lambda n: isinstance(n, nnx.Variable),
)
item_to_restore = target_for_restore
restore_args = ocp.checkpoint_utils.construct_restore_args(target_for_restore)
base_restore_args = ocp.checkpoint_utils.construct_restore_args(target_for_restore)
restore_args = _fix_restore_args_for_shape_mismatch(
base_restore_args,
metadata.item_metadata.tree,
mesh,
)

restored = ckptr.restore(
epath.Path(config.load_parameters_path),
Expand All @@ -223,7 +353,22 @@ def create_sharded_state():
else:
checkpoint = restored["params"]["params"]

loaded_count = len(jax.tree_util.tree_leaves(checkpoint))
expected_count = len(jax.tree_util.tree_leaves(target_for_restore))
if loaded_count < expected_count:
raise ValueError(
f"Checkpoint at '{config.load_parameters_path}' loaded only {loaded_count} of {expected_count} "
"expected parameter arrays. This usually means a scanned (stacked-layers) checkpoint was provided "
"where an unscanned checkpoint is required. Please convert the checkpoint to unscanned format first."
)

if checkpoint:
model_arrays = jax.tree.map(
lambda v: v.value,
sharded_state,
is_leaf=lambda n: isinstance(n, nnx.Variable),
)
checkpoint = jax.tree.map(_expand_checkpoint_to_model_shapes, checkpoint, model_arrays)
nnx.update(model, checkpoint)

except Exception as e:
Expand Down
Loading
Loading