Skip to content

Commit 1d2eefa

Browse files
NicoGrande and khatwanimohit
authored and committed
change EP axis to expert from attn_dp_expert.
1 parent 060fcd4 commit 1d2eefa

5 files changed

Lines changed: 282 additions & 9 deletions

File tree

src/maxtext/configs/inference/vllm.yml

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@ logical_axis_rules: [
3333
['activation_batch_no_exp', []],
3434
['activation_embed_and_logits_batch', ['expert']],
3535
['activation_embed_and_logits_batch_sequence', ['expert']],
36-
['activation_heads', ['model']],
37-
['activation_kv_heads', ['model']],
36+
['activation_heads', ['model', 'expert']],
37+
['activation_kv_heads', ['model', 'expert']],
3838
['activation_attn_length', ['expert']],
3939
['activation_attn_length_no_exp', []],
4040
['activation_length', ['data', 'expert']],
@@ -58,11 +58,12 @@ logical_axis_rules: [
5858
['moe_mlp', ['model', 'attn_dp']],
5959
['vocab', ['model', 'attn_dp']],
6060
['heads', ['model']],
61-
['q_heads', ['model']],
62-
['kv_heads', ['model']],
61+
['q_heads', ['model', 'expert']],
62+
['kv_heads', ['model', 'expert']],
6363
['kv_head_dim', []],
6464
['kv', []],
6565
['embed', ['expert', 'attn_dp_expert']],
66+
['embed', ['attn_dp_expert']],
6667
['embed_tensor_transpose', ['attn_dp', 'model']],
6768
['embed_no_exp', []],
6869
['q_lora', ['expert', 'attn_dp_expert']],

src/maxtext/inference/vllm_decode.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,6 @@ def decode_with_vllm(config: Config) -> None:
145145
max_tokens=max_tokens_to_generate,
146146
top_k=config.decode_sampling_top_k,
147147
top_p=config.decode_sampling_nucleus_p,
148-
seed=FLAGS.seed,
149148
)
150149

151150
outputs = llm.generate(prompts, sampling_params)

src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030

3131
try:
3232
from tpu_inference.layers.common.attention_metadata import AttentionMetadata
33+
from tpu_inference.layers.common.attention_interface import ShardingAxisName
3334
except ImportError:
3435
# Mock for documentation build or environments without tpu_inference
3536
class AttentionMetadata:
@@ -39,7 +40,7 @@ class AttentionMetadata:
3940
from vllm.config import VllmConfig
4041

4142

42-
def generate_maxtext_config(vllm_config: VllmConfig) -> pyconfig.HyperParameters:
43+
def generate_maxtext_config(vllm_config: VllmConfig, mesh: Mesh) -> pyconfig.HyperParameters:
4344
"""Generates a MaxText configuration from a vLLM configuration.
4445
4546
This function takes a vLLM configuration object and translates relevant
@@ -50,6 +51,7 @@ def generate_maxtext_config(vllm_config: VllmConfig) -> pyconfig.HyperParameters
5051
Args:
5152
vllm_config: The vLLM configuration object containing model and load
5253
parameters.
54+
mesh: The JAX mesh device for model sharding.
5355
5456
Returns:
5557
A `pyconfig.HyperParameters` object configured for MaxText.
@@ -73,6 +75,22 @@ def generate_maxtext_config(vllm_config: VllmConfig) -> pyconfig.HyperParameters
7375
base_config_path = os.path.join(MAXTEXT_CONFIGS_DIR, "inference", "vllm.yml")
7476
argv_list = ["", str(base_config_path)]
7577

78+
# Pad the number of KV heads if it's less than the TP / EP size
79+
if isinstance(ShardingAxisName.ATTN_HEAD, tuple):
80+
tp_sizes = [mesh.shape[axis_name] for axis_name in ShardingAxisName.ATTN_HEAD]
81+
max_tp_size = max(tp_sizes)
82+
else:
83+
max_tp_size = mesh.shape[ShardingAxisName.ATTN_HEAD]
84+
85+
if (
86+
max_tp_size % vllm_config.model_config.get_total_num_kv_heads() == 0
87+
and vllm_config.model_config.get_total_num_kv_heads() < max_tp_size
88+
):
89+
max_logging.log(
90+
f"Padding num_kv_heads from {vllm_config.model_config.get_total_num_kv_heads()} to {max_tp_size} to match tp_size."
91+
)
92+
overrides["base_num_kv_heads"] = max_tp_size
93+
7694
maxtext_config = pyconfig.initialize(argv_list, **overrides)
7795
return maxtext_config
7896

@@ -96,7 +114,7 @@ def __init__(self, vllm_config: VllmConfig, rng_key: jax.Array, mesh: Mesh):
96114
"""
97115
self.vllm_config = vllm_config
98116
self.cfg = vllm_config.model_config
99-
self.maxtext_config = generate_maxtext_config(vllm_config)
117+
self.maxtext_config = generate_maxtext_config(vllm_config, mesh)
100118

101119
# Model configuration
102120
self.mesh = mesh

src/maxtext/utils/model_creation_utils.py

Lines changed: 147 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
# pylint: disable=bare-except, consider-using-generator
1616
""" Utils that are only interesting for creating a model in MaxText. """
1717

18+
import dataclasses
1819
from collections.abc import Sequence
1920
from functools import partial
2021
from typing import overload
@@ -23,15 +24,128 @@
2324
from flax import nnx
2425
import flax.linen as nn
2526
import jax
27+
import jax.numpy as jnp
2628
from jax.sharding import AxisType, Mesh
2729
from maxtext.configs import pyconfig
2830
from maxtext.common.common_types import MODEL_MODE_TRAIN, ShardMode
2931
from maxtext.layers import quantizations
3032
from maxtext.models import models
33+
from maxtext.utils import max_logging
3134
from maxtext.utils import max_utils
3235
from maxtext.utils import maxtext_utils
3336
from orbax import checkpoint as ocp
3437

38+
try:
39+
from orbax.checkpoint.metadata import ArrayMetadata as _OrbaxArrayMetadata
40+
41+
def _is_orbax_array_metadata(x):
42+
return isinstance(x, _OrbaxArrayMetadata)
43+
44+
except ImportError:
45+
46+
def _is_orbax_array_metadata(x):
47+
return hasattr(x, "shape") and hasattr(x, "sharding") and hasattr(x, "dtype") and not isinstance(x, jax.Array)
48+
49+
50+
def _expand_checkpoint_to_model_shapes(ckpt_arr, model_arr):
51+
"""Expand ckpt_arr to model_arr's shape and re-shard to model_arr's sharding.
52+
53+
Used to expand checkpoint KV-head (and similar) arrays that were saved with
54+
fewer heads than the padded model shape requires (e.g. due to TP/EP padding
55+
in adapter.py). Each dimension must divide evenly into the corresponding
56+
model dimension.
57+
58+
Uses jnp.repeat so that each original slice is placed adjacent to its copies.
59+
For GQA with TP, device i needs KV head i//ratio from the original checkpoint,
60+
so the correct layout is e.g. [h0, h0, h1, h1, h2, h2, h3, h3] rather than
61+
[h0, h1, h2, h3, h0, h1, h2, h3].
62+
"""
63+
ckpt_shape = ckpt_arr.shape
64+
model_shape = model_arr.shape
65+
if ckpt_shape == model_shape:
66+
return jax.device_put(ckpt_arr, model_arr.sharding)
67+
if len(ckpt_shape) != len(model_shape):
68+
raise ValueError(
69+
f"Checkpoint and model arrays have different ranks: {ckpt_shape} vs {model_shape}. "
70+
"If the checkpoint was saved with scan_layers=True (stacked layers), convert it to "
71+
"unscanned format before loading with vLLM (vllm.yml sets scan_layers=False)."
72+
)
73+
result = ckpt_arr
74+
for axis, (ckpt_dim, model_dim) in enumerate(zip(ckpt_shape, model_shape)):
75+
if model_dim % ckpt_dim != 0:
76+
raise ValueError(
77+
f"Model dimension {model_dim} is not evenly divisible by checkpoint dimension {ckpt_dim}."
78+
f" Full shapes — checkpoint: {ckpt_shape}, model: {model_shape}"
79+
)
80+
if model_dim != ckpt_dim:
81+
result = jnp.repeat(result, model_dim // ckpt_dim, axis=axis)
82+
return jax.device_put(result, model_arr.sharding)
83+
84+
85+
def _fix_restore_args_for_shape_mismatch(restore_args, stored_metadata_tree, mesh):
86+
"""Use replicated sharding for arrays whose checkpoint shape differs from the model shape.
87+
88+
When the model is initialized with padded shapes (e.g. KV heads padded to match
89+
TP size) but the checkpoint was saved with smaller shapes, Orbax will reject the
90+
restore because the provided sharding is incompatible with the stored shape.
91+
For those arrays we switch to a fully-replicated sharding and clear global_shape
92+
so Orbax loads the array as-written. _expand_checkpoint_to_model_shapes then
93+
expands and re-shards the loaded arrays to match the model.
94+
95+
Uses tree_map_with_path so each ArrayRestoreArgs is looked up by path in the
96+
metadata dict — avoids ordering/count mismatches from flattening two trees with
97+
different pytree node types (e.g. nnx.State vs plain dict) independently.
98+
"""
99+
replicated = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
100+
101+
def _key_str(key):
102+
"""Extract string name from a JAX path key (DictKey, GetAttrKey, etc.)."""
103+
if hasattr(key, "key"):
104+
return str(key.key)
105+
if hasattr(key, "attr"):
106+
return str(key.attr)
107+
return str(key)
108+
109+
def _lookup_stored_meta(path):
110+
"""Navigate stored_metadata_tree using path keys from the restore_args tree."""
111+
node = stored_metadata_tree
112+
for key in path:
113+
name = _key_str(key)
114+
if isinstance(node, dict) and name in node:
115+
node = node[name]
116+
else:
117+
return None
118+
return node
119+
120+
mismatched_paths = []
121+
122+
def _fix_one(path, restore_arg):
123+
if not isinstance(restore_arg, ocp.ArrayRestoreArgs):
124+
return restore_arg
125+
stored_meta = _lookup_stored_meta(path)
126+
if stored_meta is not None and _is_orbax_array_metadata(stored_meta):
127+
stored_shape = tuple(stored_meta.shape)
128+
if (
129+
restore_arg.global_shape is not None
130+
and restore_arg.global_shape != stored_shape
131+
and len(stored_shape) == len(restore_arg.global_shape)
132+
):
133+
mismatched_paths.append(
134+
f" {'.'.join(_key_str(k) for k in path)}: stored={stored_shape} -> model={restore_arg.global_shape}"
135+
)
136+
return dataclasses.replace(
137+
restore_arg, global_shape=None, shape=None, sharding=replicated, mesh=None, mesh_axes=None
138+
)
139+
return restore_arg
140+
141+
fixed = jax.tree_util.tree_map_with_path(_fix_one, restore_args, is_leaf=lambda x: isinstance(x, ocp.ArrayRestoreArgs))
142+
if mismatched_paths:
143+
max_logging.log(
144+
f"Checkpoint shape mismatches ({len(mismatched_paths)} arrays): loading with replicated "
145+
"sharding and expanding to model shape after restore.\n" + "\n".join(mismatched_paths)
146+
)
147+
return fixed
148+
35149

36150
@overload
37151
def from_config(
@@ -154,6 +268,7 @@ def create_sharded_state():
154268
with nn.logical_axis_rules(config.logical_axis_rules):
155269
sharded_state = create_sharded_state()
156270
model = nnx.merge(graphdef, sharded_state)
271+
157272
# print weights sharding info under debug sharding mode
158273
if config.debug_sharding:
159274
max_utils.print_non_trivial_mesh_axis(model.mesh)
@@ -163,6 +278,7 @@ def create_sharded_state():
163278
mesh=model.mesh,
164279
logical_annotations=specs,
165280
)
281+
166282
if config.load_parameters_path:
167283
try:
168284
ckptr = ocp.Checkpointer(
@@ -196,7 +312,16 @@ def create_sharded_state():
196312
)
197313

198314
item_to_restore = {"params": {"params": target_for_restore}}
199-
restore_args = {"params": {"params": ocp.checkpoint_utils.construct_restore_args(target_for_restore)}}
315+
base_restore_args = ocp.checkpoint_utils.construct_restore_args(target_for_restore)
316+
restore_args = {
317+
"params": {
318+
"params": _fix_restore_args_for_shape_mismatch(
319+
base_restore_args,
320+
metadata.item_metadata.tree["params"]["params"],
321+
mesh,
322+
)
323+
}
324+
}
200325
else:
201326
# structure of nnx checkpoint: {'decoder': {'value': ...}}
202327
target_for_restore = jax.tree.map(
@@ -205,7 +330,12 @@ def create_sharded_state():
205330
is_leaf=lambda n: isinstance(n, nnx.Variable),
206331
)
207332
item_to_restore = target_for_restore
208-
restore_args = ocp.checkpoint_utils.construct_restore_args(target_for_restore)
333+
base_restore_args = ocp.checkpoint_utils.construct_restore_args(target_for_restore)
334+
restore_args = _fix_restore_args_for_shape_mismatch(
335+
base_restore_args,
336+
metadata.item_metadata.tree,
337+
mesh,
338+
)
209339

210340
restored = ckptr.restore(
211341
epath.Path(config.load_parameters_path),
@@ -223,7 +353,22 @@ def create_sharded_state():
223353
else:
224354
checkpoint = restored["params"]["params"]
225355

356+
loaded_count = len(jax.tree_util.tree_leaves(checkpoint))
357+
expected_count = len(jax.tree_util.tree_leaves(target_for_restore))
358+
if loaded_count < expected_count:
359+
raise ValueError(
360+
f"Checkpoint at '{config.load_parameters_path}' loaded only {loaded_count} of {expected_count} "
361+
"expected parameter arrays. This usually means a scanned (stacked-layers) checkpoint was provided "
362+
"where an unscanned checkpoint is required. Please convert the checkpoint to unscanned format first."
363+
)
364+
226365
if checkpoint:
366+
model_arrays = jax.tree.map(
367+
lambda v: v.value,
368+
sharded_state,
369+
is_leaf=lambda n: isinstance(n, nnx.Variable),
370+
)
371+
checkpoint = jax.tree.map(_expand_checkpoint_to_model_shapes, checkpoint, model_arrays)
227372
nnx.update(model, checkpoint)
228373

229374
except Exception as e:

0 commit comments

Comments (0)