Commit 06d458d

NicoGrande authored and khatwanimohit committed

change EP axis to expert from attn_dp_expert.

1 parent 8b86c71 · commit 06d458d

3 files changed: 147 additions & 5 deletions

File tree

src/maxtext/inference/vllm_decode.py
src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py
src/maxtext/utils/model_creation_utils.py

src/maxtext/inference/vllm_decode.py

Lines changed: 0 additions & 1 deletion
@@ -145,7 +145,6 @@ def decode_with_vllm(config: Config) -> None:
       max_tokens=max_tokens_to_generate,
       top_k=config.decode_sampling_top_k,
       top_p=config.decode_sampling_nucleus_p,
-      seed=FLAGS.seed,
   )

   outputs = llm.generate(prompts, sampling_params)

src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py

Lines changed: 17 additions & 2 deletions
@@ -30,6 +30,7 @@

 try:
   from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+  from tpu_inference.layers.common.attention_interface import ShardingAxisName
 except ImportError:
   # Mock for documentation build or environments without tpu_inference
   class AttentionMetadata:
@@ -39,7 +40,7 @@ class AttentionMetadata:
 from vllm.config import VllmConfig


-def generate_maxtext_config(vllm_config: VllmConfig) -> pyconfig.HyperParameters:
+def generate_maxtext_config(vllm_config: VllmConfig, mesh: Mesh) -> pyconfig.HyperParameters:
   """Generates a MaxText configuration from a vLLM configuration.

   This function takes a vLLM configuration object and translates relevant
@@ -50,6 +51,7 @@ def generate_maxtext_config(vllm_config: VllmConfig) -> pyconfig.HyperParameters
   Args:
     vllm_config: The vLLM configuration object containing model and load
       parameters.
+    mesh: The JAX device mesh used for model sharding.

   Returns:
     A `pyconfig.HyperParameters` object configured for MaxText.
@@ -73,6 +75,19 @@ def generate_maxtext_config(vllm_config: VllmConfig) -> pyconfig.HyperParameters
   base_config_path = os.path.join(MAXTEXT_CONFIGS_DIR, "inference", "vllm.yml")
   argv_list = ["", str(base_config_path)]

+  # Pad the number of KV heads if it's less than the TP / EP size.
+  if isinstance(ShardingAxisName.ATTN_HEAD, tuple):
+    tp_sizes = [mesh.shape[axis_name] for axis_name in ShardingAxisName.ATTN_HEAD]
+    max_tp_size = max(tp_sizes)
+  else:
+    max_tp_size = mesh.shape[ShardingAxisName.ATTN_HEAD]
+
+  if vllm_config.model_config.get_total_num_kv_heads() < max_tp_size:
+    max_logging.log(
+        f"Padding num_kv_heads from {vllm_config.model_config.get_total_num_kv_heads()} to {max_tp_size} to match tp_size."
+    )
+    overrides["base_num_kv_heads"] = max_tp_size
+
   maxtext_config = pyconfig.initialize(argv_list, **overrides)
   return maxtext_config

@@ -96,7 +111,7 @@ def __init__(self, vllm_config: VllmConfig, rng_key: jax.Array, mesh: Mesh):
     """
     self.vllm_config = vllm_config
     self.cfg = vllm_config.model_config
-    self.maxtext_config = generate_maxtext_config(vllm_config)
+    self.maxtext_config = generate_maxtext_config(vllm_config, mesh)

     # Model configuration
     self.mesh = mesh
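
The padding rule above generalizes over one or several attention-head mesh axes: ShardingAxisName.ATTN_HEAD may name a single axis or a tuple of axes, and the override targets the largest of their extents. A minimal sketch of the same computation, assuming a hypothetical ("data", "tensor") mesh and a hard-coded attn_head_axes tuple standing in for ShardingAxisName.ATTN_HEAD (run on CPU with XLA_FLAGS=--xla_force_host_platform_device_count=8 to get eight host devices):

import jax
import numpy as np
from jax.sharding import Mesh

# Hypothetical 2x4 mesh: 2-way data parallelism, 4-way tensor parallelism.
devices = np.array(jax.devices()).reshape(2, 4)
mesh = Mesh(devices, axis_names=("data", "tensor"))

attn_head_axes = ("tensor",)  # assumption: stand-in for ShardingAxisName.ATTN_HEAD
if isinstance(attn_head_axes, tuple):
  # mesh.shape maps each axis name to its size.
  max_tp_size = max(mesh.shape[name] for name in attn_head_axes)
else:
  max_tp_size = mesh.shape[attn_head_axes]

total_num_kv_heads = 2  # e.g. a GQA checkpoint with 2 KV heads
if total_num_kv_heads < max_tp_size:
  # Mirrors overrides["base_num_kv_heads"] = max_tp_size in the diff above:
  # every tensor-parallel shard must own at least one whole KV head.
  total_num_kv_heads = max_tp_size
print(total_num_kv_heads)  # -> 4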

src/maxtext/utils/model_creation_utils.py

Lines changed: 130 additions & 2 deletions
@@ -15,6 +15,7 @@
 # pylint: disable=bare-except, consider-using-generator
 """ Utils that are only interesting for creating a model in MaxText. """

+import dataclasses
 from collections.abc import Sequence
 from functools import partial
 from typing import overload
@@ -23,15 +24,120 @@
 from flax import nnx
 import flax.linen as nn
 import jax
+import jax.numpy as jnp
 from jax.sharding import AxisType, Mesh
 from maxtext.configs import pyconfig
 from maxtext.common.common_types import MODEL_MODE_TRAIN, ShardMode
 from maxtext.layers import quantizations
 from maxtext.models import models
+from maxtext.utils import max_logging
 from maxtext.utils import max_utils
 from maxtext.utils import maxtext_utils
 from orbax import checkpoint as ocp

+try:
+  from orbax.checkpoint.metadata import ArrayMetadata as _OrbaxArrayMetadata
+
+  def _is_orbax_array_metadata(x):
+    return isinstance(x, _OrbaxArrayMetadata)
+
+except ImportError:
+
+  def _is_orbax_array_metadata(x):
+    return hasattr(x, "shape") and hasattr(x, "sharding") and hasattr(x, "dtype") and not isinstance(x, jax.Array)
+
+
+def _expand_checkpoint_to_model_shapes(ckpt_arr, model_arr):
+  """Expand ckpt_arr to model_arr's shape and re-shard to model_arr's sharding.
+
+  Used to expand checkpoint KV-head (and similar) arrays that were saved with
+  fewer heads than the padded model shape requires (e.g. due to TP/EP padding
+  in adapter.py). Each dimension must divide evenly into the corresponding
+  model dimension.
+
+  Uses jnp.repeat so that each original slice is placed adjacent to its copies.
+  For GQA with TP, device i needs KV head i//ratio from the original checkpoint,
+  so the correct layout is e.g. [h0, h0, h1, h1, h2, h2, h3, h3] rather than
+  [h0, h1, h2, h3, h0, h1, h2, h3].
+  """
+  ckpt_shape = ckpt_arr.shape
+  model_shape = model_arr.shape
+  if ckpt_shape == model_shape:
+    return jax.device_put(ckpt_arr, model_arr.sharding)
+  if len(ckpt_shape) != len(model_shape):
+    raise ValueError(f"Checkpoint and model arrays have different ranks: {ckpt_shape} vs {model_shape}")
+  result = ckpt_arr
+  for axis, (ckpt_dim, model_dim) in enumerate(zip(ckpt_shape, model_shape)):
+    if model_dim % ckpt_dim != 0:
+      raise ValueError(
+          f"Model dimension {model_dim} is not evenly divisible by checkpoint dimension {ckpt_dim}."
+          f" Full shapes — checkpoint: {ckpt_shape}, model: {model_shape}"
+      )
+    if model_dim != ckpt_dim:
+      result = jnp.repeat(result, model_dim // ckpt_dim, axis=axis)
+  return jax.device_put(result, model_arr.sharding)
+
+
+def _fix_restore_args_for_shape_mismatch(restore_args, stored_metadata_tree, mesh):
+  """Use replicated sharding for arrays whose checkpoint shape differs from the model shape.
+
+  When the model is initialized with padded shapes (e.g. KV heads padded to match
+  TP size) but the checkpoint was saved with smaller shapes, Orbax will reject the
+  restore because the provided sharding is incompatible with the stored shape.
+  For those arrays we switch to a fully-replicated sharding and clear global_shape
+  so Orbax loads the array as-written. _expand_checkpoint_to_model_shapes then
+  expands and re-shards the loaded arrays to match the model.
+
+  Uses tree_map_with_path so each ArrayRestoreArgs is looked up by path in the
+  metadata dict — avoids ordering/count mismatches from flattening two trees with
+  different pytree node types (e.g. nnx.State vs plain dict) independently.
+  """
+  replicated = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
+
+  def _key_str(key):
+    """Extract string name from a JAX path key (DictKey, GetAttrKey, etc.)."""
+    if hasattr(key, "key"):
+      return str(key.key)
+    if hasattr(key, "attr"):
+      return str(key.attr)
+    return str(key)
+
+  def _lookup_stored_meta(path):
+    """Navigate stored_metadata_tree using path keys from the restore_args tree."""
+    node = stored_metadata_tree
+    for key in path:
+      name = _key_str(key)
+      if isinstance(node, dict) and name in node:
+        node = node[name]
+      else:
+        return None
+    return node
+
+  mismatched_paths = []
+
+  def _fix_one(path, restore_arg):
+    if not isinstance(restore_arg, ocp.ArrayRestoreArgs):
+      return restore_arg
+    stored_meta = _lookup_stored_meta(path)
+    if stored_meta is not None and _is_orbax_array_metadata(stored_meta):
+      stored_shape = tuple(stored_meta.shape)
+      if restore_arg.global_shape is not None and restore_arg.global_shape != stored_shape:
+        mismatched_paths.append(
+            f" {'.'.join(_key_str(k) for k in path)}: stored={stored_shape} -> model={restore_arg.global_shape}"
+        )
+        return dataclasses.replace(
+            restore_arg, global_shape=None, shape=None, sharding=replicated, mesh=None, mesh_axes=None
+        )
+    return restore_arg
+
+  fixed = jax.tree_util.tree_map_with_path(_fix_one, restore_args, is_leaf=lambda x: isinstance(x, ocp.ArrayRestoreArgs))
+  if mismatched_paths:
+    max_logging.log(
+        f"Checkpoint shape mismatches ({len(mismatched_paths)} arrays): loading with replicated "
+        "sharding and expanding to model shape after restore.\n" + "\n".join(mismatched_paths)
+    )
+  return fixed
+

 @overload
 def from_config(
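
The layout argument in the _expand_checkpoint_to_model_shapes docstring (repeat, not tile) is easy to verify in isolation. A minimal sketch, with toy 1-D values standing in for per-head checkpoint slices:

import jax.numpy as jnp

heads = jnp.arange(4)  # stand-ins for checkpoint KV heads [h0, h1, h2, h3]

# jnp.repeat keeps each original head adjacent to its copies, so under TP
# sharding device i reads head i // ratio, as the docstring requires.
print(jnp.repeat(heads, 2, axis=0))  # [0 0 1 1 2 2 3 3]

# jnp.tile would cycle whole blocks instead, pairing shards with wrong heads.
print(jnp.tile(heads, 2))            # [0 1 2 3 0 1 2 3]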
@@ -154,6 +260,7 @@ def create_sharded_state():
   with nn.logical_axis_rules(config.logical_axis_rules):
     sharded_state = create_sharded_state()
   model = nnx.merge(graphdef, sharded_state)
+
   # print weights sharding info under debug sharding mode
   if config.debug_sharding:
     max_utils.print_non_trivial_mesh_axis(model.mesh)
@@ -163,6 +270,7 @@
       mesh=model.mesh,
       logical_annotations=specs,
   )
+
   if config.load_parameters_path:
     try:
       ckptr = ocp.Checkpointer(
@@ -196,7 +304,16 @@
         )

         item_to_restore = {"params": {"params": target_for_restore}}
-        restore_args = {"params": {"params": ocp.checkpoint_utils.construct_restore_args(target_for_restore)}}
+        base_restore_args = ocp.checkpoint_utils.construct_restore_args(target_for_restore)
+        restore_args = {
+            "params": {
+                "params": _fix_restore_args_for_shape_mismatch(
+                    base_restore_args,
+                    metadata.item_metadata.tree["params"]["params"],
+                    mesh,
+                )
+            }
+        }
       else:
         # structure of nnx checkpoint: {'decoder': {'value': ...}}
         target_for_restore = jax.tree.map(
@@ -205,7 +322,12 @@
             is_leaf=lambda n: isinstance(n, nnx.Variable),
         )
         item_to_restore = target_for_restore
-        restore_args = ocp.checkpoint_utils.construct_restore_args(target_for_restore)
+        base_restore_args = ocp.checkpoint_utils.construct_restore_args(target_for_restore)
+        restore_args = _fix_restore_args_for_shape_mismatch(
+            base_restore_args,
+            metadata.item_metadata.tree,
+            mesh,
+        )

       restored = ckptr.restore(
           epath.Path(config.load_parameters_path),
@@ -224,6 +346,12 @@
       checkpoint = restored["params"]["params"]

       if checkpoint:
+        model_arrays = jax.tree.map(
+            lambda v: v.value,
+            sharded_state,
+            is_leaf=lambda n: isinstance(n, nnx.Variable),
+        )
+        checkpoint = jax.tree.map(_expand_checkpoint_to_model_shapes, checkpoint, model_arrays)
         nnx.update(model, checkpoint)

     except Exception as e:
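
For readers unfamiliar with the path-keyed traversal that _fix_restore_args_for_shape_mismatch relies on: jax.tree_util.tree_map_with_path hands the callback each leaf's path as DictKey/GetAttrKey entries, which is what lets leaves of one tree be looked up by name in a parallel metadata dict even when the two trees use different pytree node types. A minimal sketch, with a made-up metadata value in place of real Orbax array metadata:

import jax
import jax.numpy as jnp

stored_metadata = {"decoder": {"kernel": "stored shape (2, 4)"}}  # made-up stand-in
restore_tree = {"decoder": {"kernel": jnp.zeros((2, 8))}}

def lookup(path, leaf):
  node = stored_metadata
  for key in path:
    # DictKey exposes .key and GetAttrKey exposes .attr, as in _key_str above.
    name = key.key if hasattr(key, "key") else getattr(key, "attr", str(key))
    node = node.get(str(name)) if isinstance(node, dict) else None
  print(".".join(str(getattr(k, "key", k)) for k in path), "->", node)
  return leaf

jax.tree_util.tree_map_with_path(lookup, restore_tree)
# prints: decoder.kernel -> stored shape (2, 4)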
