Skip to content

Commit 92cc4b5

Browse files
committed
update vllm logical rule
1 parent 7262655 commit 92cc4b5

2 files changed

Lines changed: 70 additions & 49 deletions

File tree

src/maxtext/configs/inference/vllm.yml

Lines changed: 69 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -29,56 +29,77 @@ weight_dtype: bfloat16
2929
# -------------- Logical Axis Rules --------------
3030
mesh_axes: ['data', 'attn_dp', 'model', 'expert', 'attn_dp_expert']
3131
logical_axis_rules: [
32-
['activation_batch', ['data']],
33-
['activation_batch_moe', ['data']],
34-
['activation_batch_attn', ['data']],
35-
['activation_embed_and_logits_batch', ['data', 'expert']],
36-
['activation_embed_and_logits_batch_sequence', ['data', 'expert']],
37-
['activation_heads', ['model', 'expert']],
38-
['activation_kv_heads', ['model', 'expert']],
39-
['activation_length_attn', []],
40-
['activation_length', []],
41-
['activation_length_moe', []],
42-
['activation_q_length', ['expert', 'attn_dp_expert']],
43-
['activation_embed_attn', 'model'],
44-
# Expert is missing explicitly from activation_embed despite using TP.
45-
# We are going for a replicate-AR style of TP as opposed to our typical AG-RS style of TP
46-
# due to the output sharding of the fused_moe_gmm kernel in tpu-inference.
47-
['activation_embed', ['model', 'attn_dp']],
48-
['activation_embed_moe', ['model', 'attn_dp']],
49-
['activation_mlp', ['model', 'attn_dp']],
50-
['activation_mlp_moe', ['model', 'attn_dp']],
51-
['activation_kv', ['model']],
52-
['activation_prefill_kv_batch', ['expert', 'attn_dp_expert']],
53-
['activation_kv_batch', ['data']],
54-
['activation_kv_head_dim', ['model']],
55-
['activation_vocab', ['model', 'attn_dp']],
56-
['activation_norm_length', []],
57-
['activation_norm_length_moe', []],
58-
['activation_exp', ['expert', 'attn_dp_expert']],
59-
['decode_batch', ['data']],
60-
['decode_batch_moe', ['data']],
61-
['decode_length', []],
62-
['mlp', ['model', 'attn_dp']],
63-
['mlp_moe', ['model', 'attn_dp']],
64-
['mlp_no_fsdp', ['model', 'attn_dp']],
65-
['vocab', ['model', 'attn_dp']],
66-
# Expert is intended to act like TP for attention.
67-
# We target two all-reduces, one at the end of attention out projection and one at the end of the feedforward.
68-
['heads', ['model', 'expert']],
69-
['q_heads', ['model', 'expert']],
70-
['kv_heads', ['model', 'expert']],
71-
['kv_head_dim', []],
32+
# ==========================================
33+
# Vocabulary Embedding
34+
# ==========================================
35+
# Vocab Activations
36+
['activation_embed_and_logits_batch', ['data', 'attn_dp', 'attn_dp_expert']],
37+
['activation_embed_and_logits_batch_sequence', ['data', 'attn_dp', 'attn_dp_expert']],
38+
['activation_vocab', ['expert', 'model']],
39+
# Vocab Weights
40+
['vocab', []],
41+
['embed_vocab', []],
42+
# ==========================================
43+
# Attention
44+
# ==========================================
45+
# Attention Activations
46+
['activation_batch_attn', ['data', 'attn_dp', 'attn_dp_expert']],
47+
['activation_heads', ['expert', 'model']],
48+
['activation_kv_heads', ['expert', 'model']],
49+
['activation_embed_attn', []],
50+
['activation_kv', []],
51+
['activation_kv_batch', ['data', 'attn_dp', 'attn_dp_expert']],
52+
['activation_kv_head_dim', []],
53+
# Attention Weights
54+
['heads', ['expert', 'model']],
55+
['q_heads', ['expert', 'model']],
56+
['kv_heads', ['expert', 'model']],
57+
['qkv', []],
7258
['kv', []],
73-
['embed', []],
59+
['kv_head_dim', []],
60+
['q_lora', []],
61+
['q_lora_up_proj', []],
62+
['kv_lora', []],
63+
['kv_lora_up_proj', []],
64+
# ==========================================
65+
# Mixture of Experts (MoE)
66+
# ==========================================
67+
# MoE Activations
68+
['activation_batch_moe', ['data', 'attn_dp', 'attn_dp_expert']],
69+
['activation_embed_moe', ['attn_dp', 'model']],
70+
['activation_mlp_moe', ['attn_dp', 'model']],
71+
['activation_exp', ['attn_dp_expert', 'expert']],
72+
# MoE Weights
73+
['exp', ['attn_dp_expert', 'expert']],
74+
['mlp_moe', ['attn_dp', 'model']],
7475
['embed_moe', []],
75-
['embed_tensor_transpose', ['attn_dp', 'model']],
76-
['q_lora', ['expert', 'attn_dp_expert']],
77-
['kv_lora', ['expert', 'attn_dp_expert']],
76+
# ==========================================
77+
# Standard MLP / Dense Layers / Model Structure
78+
# ==========================================
79+
# Dense Activations
80+
['activation_mlp', ['attn_dp', 'model']],
81+
# Note: activation batch and length are also used in attention and vocab
82+
['activation_batch', ['data', 'attn_dp', 'attn_dp_expert']],
83+
['activation_embed', []],
84+
# General Weights
85+
['mlp', ['attn_dp', 'model']],
86+
['embed', []],
7887
['norm', []],
79-
['cache_heads', ['model']],
80-
['exp', ['expert', 'attn_dp_expert']],
81-
['paged_kv_heads', ['model']],
82-
]
88+
# ==========================================
89+
# Inference (Prefill, Decode, Cache)
90+
# ==========================================
91+
['activation_prefill_kv_batch', ['data', 'attn_dp', 'attn_dp_expert']],
92+
['decode_batch', ['data', 'attn_dp', 'attn_dp_expert']],
93+
['cache_heads', ['expert', 'model']],
94+
['paged_kv_heads', ['expert', 'model']],
95+
['cache_batch_prefill', []],
96+
['cache_batch', []],
97+
['cache_heads_none', []],
98+
['cache_kv', []],
99+
['cache_sequence', []],
100+
['num_pages', []],
101+
['tokens_per_page', []],
102+
['paged_kv_head_dim_size', []],
103+
]
83104
data_sharding: [['data', 'attn_dp', 'model', 'expert', 'attn_dp_expert']]
84105
input_data_sharding_logical_axes: ['activation_embed_and_logits_batch']

src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,7 @@ def load_weights(self, rng_key: jax.Array) -> None:
254254
return
255255

256256
with self.mesh, nn.logical_axis_rules(self.maxtext_config.logical_axis_rules):
257-
model, _ = model_creation_utils.create_nnx_model(
257+
model = model_creation_utils.from_pretrained(
258258
self.maxtext_config, mesh=self.mesh, model_mode=self.model_mode, rng_key=rng_key
259259
)
260260
self.model = nnx.data(model)

0 commit comments

Comments (0)