Skip to content

Commit 9e3afc1

Browse files
committed
update vLLM logical axis rules
1 parent f67d8b1 commit 9e3afc1

5 files changed

Lines changed: 77 additions & 56 deletions

File tree

src/maxtext/configs/inference/vllm.yml

Lines changed: 69 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -29,56 +29,77 @@ weight_dtype: bfloat16
2929
# -------------- Logical Axis Rules --------------
3030
mesh_axes: ['data', 'attn_dp', 'model', 'expert', 'attn_dp_expert']
3131
logical_axis_rules: [
32-
['activation_batch', ['data']],
33-
['activation_batch_moe', ['data']],
34-
['activation_batch_attn', ['data']],
35-
['activation_embed_and_logits_batch', ['data', 'expert']],
36-
['activation_embed_and_logits_batch_sequence', ['data', 'expert']],
37-
['activation_heads', ['model', 'expert']],
38-
['activation_kv_heads', ['model', 'expert']],
39-
['activation_length_attn', []],
40-
['activation_length', []],
41-
['activation_length_moe', []],
42-
['activation_q_length', ['expert', 'attn_dp_expert']],
43-
['activation_embed_attn', 'model'],
44-
# Expert is missing explicitly from activation_embed despite using TP.
45-
# We are going for a replicate-AR style of TP as opposed to our typical AG-RS style of TP
46-
# due to the output sharding of the fused_moe_gmm kernel in tpu-inference.
47-
['activation_embed', ['model', 'attn_dp']],
48-
['activation_embed_moe', ['model', 'attn_dp']],
49-
['activation_mlp', ['model', 'attn_dp']],
50-
['activation_mlp_moe', ['model', 'attn_dp']],
51-
['activation_kv', ['model']],
52-
['activation_prefill_kv_batch', ['expert', 'attn_dp_expert']],
53-
['activation_kv_batch', ['data']],
54-
['activation_kv_head_dim', ['model']],
55-
['activation_vocab', ['model', 'attn_dp']],
56-
['activation_norm_length', []],
57-
['activation_norm_length_moe', []],
58-
['activation_exp', ['expert', 'attn_dp_expert']],
59-
['decode_batch', ['data']],
60-
['decode_batch_moe', ['data']],
61-
['decode_length', []],
62-
['mlp', ['model', 'attn_dp']],
63-
['mlp_moe', ['model', 'attn_dp']],
64-
['mlp_no_fsdp', ['model', 'attn_dp']],
65-
['vocab', ['model', 'attn_dp']],
66-
# Expert is intended to act like TP for attention.
67-
# We target two all-reduces, one at the end of attention out projection and one at the end of the feedforward.
68-
['heads', ['model', 'expert']],
69-
['q_heads', ['model', 'expert']],
70-
['kv_heads', ['model', 'expert']],
71-
['kv_head_dim', []],
32+
# ==========================================
33+
# Vocabulary Embedding
34+
# ==========================================
35+
# Vocab Activations
36+
['activation_embed_and_logits_batch', ['data', 'attn_dp', 'attn_dp_expert']],
37+
['activation_embed_and_logits_batch_sequence', ['data', 'attn_dp', 'attn_dp_expert']],
38+
['activation_vocab', ['expert', 'model']],
39+
# Vocab Weights
40+
['vocab', []],
41+
['embed_vocab', []],
42+
# ==========================================
43+
# Attention
44+
# ==========================================
45+
# Attention Activations
46+
['activation_batch_attn', ['data', 'attn_dp', 'attn_dp_expert']],
47+
['activation_heads', ['expert', 'model']],
48+
['activation_kv_heads', ['expert', 'model']],
49+
['activation_embed_attn', []],
50+
['activation_kv', []],
51+
['activation_kv_batch', ['data', 'attn_dp', 'attn_dp_expert']],
52+
['activation_kv_head_dim', []],
53+
# Attention Weights
54+
['heads', ['expert', 'model']],
55+
['q_heads', ['expert', 'model']],
56+
['kv_heads', ['expert', 'model']],
57+
['qkv', []],
7258
['kv', []],
73-
['embed', []],
59+
['kv_head_dim', []],
60+
['q_lora', []],
61+
["q_lora_up_proj", []],
62+
['kv_lora', []],
63+
["kv_lora_up_proj", []],
64+
# ==========================================
65+
# Mixture of Experts (MoE)
66+
# ==========================================
67+
# MoE Activations
68+
['activation_batch_moe', ['data', 'attn_dp', 'attn_dp_expert']],
69+
['activation_embed_moe', ['attn_dp', 'model']],
70+
['activation_mlp_moe', ['attn_dp', 'model']],
71+
['activation_exp', ['attn_dp_expert', 'expert']],
72+
# MoE Weights
73+
['exp', ['attn_dp_expert', 'expert']],
74+
['mlp_moe', ['attn_dp', 'model']],
7475
['embed_moe', []],
75-
['embed_tensor_transpose', ['attn_dp', 'model']],
76-
['q_lora', ['expert', 'attn_dp_expert']],
77-
['kv_lora', ['expert', 'attn_dp_expert']],
76+
# ==========================================
77+
# Standard MLP / Dense Layers / Model Structure
78+
# ==========================================
79+
# Dense Activations
80+
['activation_mlp', ['attn_dp', 'model']],
81+
# Note: activation batch and length are also used in attention and vocab
82+
['activation_batch', ['data', 'attn_dp', 'attn_dp_expert']],
83+
['activation_embed', []],
84+
# General Weights
85+
['mlp', ['attn_dp', 'model']],
86+
['embed', []],
7887
['norm', []],
79-
['cache_heads', ['model']],
80-
['exp', ['expert', 'attn_dp_expert']],
81-
['paged_kv_heads', ['model']],
82-
]
88+
# ==========================================
89+
# Inference (Prefill, Decode, Cache)
90+
# ==========================================
91+
['activation_prefill_kv_batch', ['data', 'attn_dp', 'attn_dp_expert']],
92+
['decode_batch', ['data', 'attn_dp', 'attn_dp_expert']],
93+
['cache_heads', ['expert', 'model']],
94+
['paged_kv_heads', ['expert', 'model']],
95+
['cache_batch_prefill', []],
96+
['cache_batch', []],
97+
['cache_heads_none', []],
98+
['cache_kv', []],
99+
['cache_sequence', []],
100+
['num_pages', []],
101+
['tokens_per_page', []],
102+
['paged_kv_head_dim_size', []],
103+
]
83104
data_sharding: [['data', 'attn_dp', 'model', 'expert', 'attn_dp_expert']]
84105
input_data_sharding_logical_axes: ['activation_embed_and_logits_batch']

src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,7 @@ def load_weights(self, rng_key: jax.Array) -> None:
254254
return
255255

256256
with self.mesh, nn.logical_axis_rules(self.maxtext_config.logical_axis_rules):
257-
model, _ = model_creation_utils.create_nnx_model(
257+
model = model_creation_utils.from_pretrained(
258258
self.maxtext_config, mesh=self.mesh, model_mode=self.model_mode, rng_key=rng_key
259259
)
260260
self.model = nnx.data(model)

src/maxtext/trainers/post_train/distillation/save_top_k_teacher_logits.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def generate_and_save_data(config, local_args):
9696

9797
# Loading teacher model and dataset iterator
9898
max_logging.log(f"Loading Teacher Model from {config.load_parameters_path}...")
99-
teacher_model, _ = model_creation_utils.create_nnx_model(config, mesh=mesh)
99+
teacher_model = model_creation_utils.from_pretrained(config, mesh=mesh)
100100
train_iter, _ = input_pipeline_interface.create_data_iterator(config, mesh)
101101

102102
# Determine start_step for resuming

src/maxtext/utils/model_creation_utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
# pylint: disable=bare-except, consider-using-generator
16-
""" Utils that are only interesting for creating a model in MaxText. """
16+
"""Utils that are only interesting for creating a model in MaxText."""
1717

1818
import dataclasses
1919
import collections
@@ -226,9 +226,9 @@ def create_model(config, mesh, model_mode: str = MODEL_MODE_TRAIN, rngs: nnx.Rng
226226
def create_nnx_abstract_model(config, mesh, model_mode=MODEL_MODE_TRAIN, rng_key=None):
227227
"""Returns (_create_model_partial, abstract_model) for AOT compilation.
228228
229-
Unlike create_nnx_model, this does not shard parameters or load checkpoints.
230-
It only builds the abstract shape/dtype structure needed by get_abstract_state
231-
and optimizer construction (e.g. Muon).
229+
This does not shard parameters or load checkpoints. It only builds the
230+
abstract shape/dtype structure needed by get_abstract_state and optimizer
231+
construction (e.g. Muon).
232232
233233
Args:
234234
config: the configuration

tests/unit/model_creation_utils_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -345,7 +345,7 @@ def _make_nnx_metadata_mock(self):
345345
@patch("maxtext.utils.model_creation_utils.ocp")
346346
def test_load_nnx_checkpoint(self, mock_ocp):
347347
"""NNX-format checkpoint: restored values are wrapped under a 'value' key."""
348-
# Echo back the `item` argument passed by create_nnx_model to ckptr.restore.
348+
# Echo back the `item` argument passed by from_pretrained to ckptr.restore.
349349
# For NNX checkpoints, item IS already {leaf: {"value": array}, ...}, so
350350
# returning it directly gives a correctly-structured restored dict that
351351
# matches the model's own state — regardless of the exact leaf count.
@@ -364,7 +364,7 @@ def test_load_nnx_checkpoint(self, mock_ocp):
364364
@patch("maxtext.utils.model_creation_utils.ocp")
365365
def test_load_linen_checkpoint(self, mock_ocp):
366366
"""Linen-format checkpoint: restored values are nested under 'params'/'params'."""
367-
# Echo back the `item` argument passed by create_nnx_model to ckptr.restore.
367+
# Echo back the `item` argument passed by from_pretrained to ckptr.restore.
368368
# For Linen checkpoints, item IS already {"params": {"params": arrays}}, so
369369
# returning it directly gives a correctly-structured restored dict that
370370
# matches the model's own state — regardless of the exact leaf count.

0 commit comments

Comments
 (0)