Skip to content

Commit 5a71eb9

Browse files
committed
deepseek sharding and mla attention plumbing
1 parent cf051eb commit 5a71eb9

10 files changed

Lines changed: 271 additions & 44 deletions

File tree

src/maxtext/configs/base.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -471,6 +471,7 @@ logical_axis_rules: [
471471
['decode_length', ['sequence']],
472472
['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
473473
['mlp_no_fsdp', ['tensor', 'tensor_sequence', 'autoregressive']],
474+
['moe_mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
474475
['vocab', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
475476
['heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
476477
['q_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
@@ -1119,6 +1120,8 @@ vllm_hf_config_path: ""
11191120
# A JSON string of overrides to apply to the HuggingFace-style config for the vLLM adapter.
11201121
# This can be used to override specific settings without modifying the original config file.
11211122
vllm_hf_overrides: {}
1123+
# Path to yaml file for loading vLLM config
1124+
vllm_config_path: ""
11221125
# JSON string containing additional configuration for the vLLM model (e.g. '{"maxtext_config": {...}}')
11231126
vllm_additional_config: {}
11241127
# When use_jax_splash=True, force the layout of the query tensor to be [..., NUM_HEADS, HEAD_DIM, SEQ_LENGTH]
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# Copyright 2023–2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
base_config: "vllm.yml"
17+
18+
logical_axis_rules: [
19+
['activation_batch', ['']],
20+
['activation_batch_no_exp', []],
21+
['activation_embed_and_logits_batch', ['expert']],
22+
['activation_embed_and_logits_batch_sequence', ['expert']],
23+
['activation_heads', ['model']],
24+
['activation_kv_heads', ['model']],
25+
['activation_attn_length', ['expert']],
26+
['activation_attn_length_no_exp', []],
27+
['activation_length', ['data', 'expert']],
28+
['activation_length_no_exp', 'data'],
29+
['activation_q_length', ['expert']],
30+
['activation_attn_embed', 'model'],
31+
['activation_embed', ['model', 'attn_dp']],
32+
['activation_mlp', ['model', 'attn_dp', 'expert']],
33+
['activation_kv', ['model']],
34+
['activation_prefill_kv_batch', ['expert']],
35+
['activation_kv_batch', ['']],
36+
['activation_kv_batch_no_exp', []],
37+
['activation_kv_head_dim', ['model', 'attn_dp', 'expert']],
38+
['activation_vocab', ['model', 'attn_dp']],
39+
['activation_norm_length', []],
40+
['activation_exp', ['expert']],
41+
['decode_batch', ['expert']],
42+
['decode_length', []],
43+
['mlp_no_fsdp', ['model', 'attn_dp', 'expert']],
44+
['vocab', ['model', 'attn_dp', 'expert']],
45+
['heads', ['expert', 'attn_dp', 'model']],
46+
['q_heads', []],
47+
['kv_heads', []],
48+
['kv_head_dim', ['model', 'attn_dp', 'expert']],
49+
['kv', ['model', 'attn_dp', 'expert']],
50+
['kv', []],
51+
['embed', []],
52+
['mlp', ['model', 'attn_dp', 'expert']],
53+
['moe_mlp', []],
54+
['embed_tensor_transpose', ['attn_dp', 'model']],
55+
['embed_no_exp', []],
56+
['q_lora', []],
57+
['kv_lora', []],
58+
['norm', []],
59+
['cache_heads', ['model']],
60+
['exp', ['expert', 'attn_dp', 'model']],
61+
['paged_kv_heads', ['model']],
62+
['cache_batch_prefill', []],
63+
['cache_batch', []],
64+
['cache_sequence', []],
65+
['cache_heads_none', []],
66+
['cache_kv', []],
67+
['kv_lora_up_proj',['expert', 'attn_dp', 'model']],
68+
['q_lora_up_proj',['expert', 'attn_dp', 'model']],
69+
]

src/maxtext/configs/post_train/rl.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,8 @@ max_num_seqs: null
155155
async_scheduling: True
156156
# stop generation when any of these strings is generated
157157
stop_strings: null
158+
# Path to the YAML file used to initialize the vLLM config
159+
vllm_config_path: 'src/maxtext/configs/inference/vllm.yml'
158160

159161
# ====== Checkpoint Configuration ======
160162
enable_checkpointing: True

src/maxtext/configs/types.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1621,6 +1621,9 @@ class VLLM(BaseModel):
16211621
description="Overrides for HuggingFace model config for MaxText model.",
16221622
)
16231623
vllm_hf_config_path: str = Field("", description="Path to HuggingFace model config for MaxText model.")
1624+
vllm_config_path: str = Field(
1625+
"src/maxtext/configs/inference/vllm.yml", description="path to yaml file for loading vLLM config."
1626+
)
16241627

16251628

16261629
class RL(BaseModel):

src/maxtext/inference/vllm_decode.py

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
use_chat_template=True
3030
"""
3131

32+
import copy
3233
import os
3334
from typing import Any, Sequence
3435

@@ -40,7 +41,6 @@
4041

4142
from maxtext.utils import model_creation_utils
4243
from maxtext.utils import max_logging
43-
from maxtext.utils.globals import MAXTEXT_CONFIGS_DIR
4444
from maxtext.common.common_types import Config
4545
from maxtext.integration.tunix.tunix_adapter import TunixMaxTextAdapter
4646
from tunix.rl.rollout import base_rollout
@@ -67,6 +67,21 @@ def decode_with_vllm(config: Config) -> None:
6767
config: MaxText config.
6868
"""
6969
# Prepare vLLM Arguments
70+
# Use user-provided vllm_additional_config as base (includes model-specific
71+
# overrides like base_num_decoder_layers, override_model_config, etc.), then
72+
# fill in defaults and runtime-derived values on top.
73+
additional_config = copy.deepcopy(config.vllm_additional_config) if config.vllm_additional_config else {}
74+
additional_config.setdefault("maxtext_config", {})
75+
additional_config["maxtext_config"].setdefault("model_name", config.model_name)
76+
additional_config["maxtext_config"].setdefault("weight_dtype", "bfloat16")
77+
additional_config["maxtext_config"].setdefault("allow_split_physical_axes", True)
78+
additional_config["maxtext_config"]["debug_sharding"] = config.debug_sharding
79+
additional_config.setdefault("sharding", {})
80+
additional_config["sharding"].setdefault("sharding_strategy", {})
81+
additional_config["sharding"]["sharding_strategy"].setdefault("enable_dp_attention", config.enable_dp_attention)
82+
# Pass vllm_config_path so the adapter can use it as the MaxText base config.
83+
additional_config.setdefault("vllm_config_path", str(config.vllm_config_path))
84+
7085
vllm_args = {
7186
"model": config.tokenizer_path,
7287
"max_model_len": config.max_target_length,
@@ -76,19 +91,7 @@ def decode_with_vllm(config: Config) -> None:
7691
"hf_overrides": config.vllm_hf_overrides,
7792
"gpu_memory_utilization": config.hbm_utilization_vllm,
7893
"async_scheduling": config.async_scheduling,
79-
"additional_config": {
80-
"maxtext_config": {
81-
"model_name": config.model_name,
82-
"weight_dtype": "bfloat16",
83-
"allow_split_physical_axes": True,
84-
"debug_sharding": config.debug_sharding,
85-
},
86-
"sharding": {
87-
"sharding_strategy": {
88-
"enable_dp_attention": config.enable_dp_attention,
89-
},
90-
},
91-
},
94+
"additional_config": additional_config,
9295
}
9396

9497
if config.load_parameters_path:
@@ -106,8 +109,7 @@ def decode_with_vllm(config: Config) -> None:
106109
f"and EP={config.ici_expert_parallelism if enable_expert_parallel else 1}..."
107110
)
108111

109-
vllm_config_path = os.path.join(MAXTEXT_CONFIGS_DIR, "inference", "vllm.yml")
110-
argv_list = ["", str(vllm_config_path), "log_config=False"]
112+
argv_list = ["", str(config.vllm_config_path), "log_config=False"]
111113
vllm_config = pyconfig.initialize(argv_list)
112114

113115
with nn_partitioning.axis_rules(vllm_config.logical_axis_rules):
@@ -145,7 +147,7 @@ def decode_with_vllm(config: Config) -> None:
145147
max_tokens=max_tokens_to_generate,
146148
top_k=config.decode_sampling_top_k,
147149
top_p=config.decode_sampling_nucleus_p,
148-
seed=FLAGS.seed,
150+
# seed=FLAGS.seed,
149151
)
150152

151153
outputs = llm.generate(prompts, sampling_params)

src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,12 @@ def generate_maxtext_config(vllm_config: VllmConfig) -> pyconfig.HyperParameters
6969
)
7070
overrides["load_parameters_path"] = None
7171

72-
# Add base config path to positional args
73-
base_config_path = os.path.join(MAXTEXT_CONFIGS_DIR, "inference", "vllm.yml")
72+
# Add base config path to positional args — prefer the caller-supplied
73+
# vllm_config_path from additional_config, fall back to vllm.yml default.
74+
base_config_path = vllm_config.additional_config.get(
75+
"vllm_config_path",
76+
os.path.join(MAXTEXT_CONFIGS_DIR, "inference", "vllm.yml"),
77+
)
7478
argv_list = ["", str(base_config_path)]
7579

7680
maxtext_config = pyconfig.initialize(argv_list, **overrides)
@@ -86,6 +90,11 @@ class MaxTextForCausalLM(nnx.Module):
8690
of the decoding step.
8791
"""
8892

93+
# Signal to tpu-inference model_loader that this class manages its own
94+
# JIT-sharded initialization (via create_nnx_model with out_shardings).
95+
# When True, model_loader skips wrapping __init__ in an outer bare @jax.jit,
96+
_self_manages_sharding: bool = True
97+
8998
def __init__(self, vllm_config: VllmConfig, rng_key: jax.Array, mesh: Mesh):
9099
"""Initializes the MaxTextForCausalLM model.
91100
@@ -232,7 +241,7 @@ def load_weights(self, rng_key: jax.Array) -> None:
232241
if self.model is not None:
233242
return
234243

235-
with self.mesh, nn.logical_axis_rules(""):
244+
with self.mesh, nn.logical_axis_rules(self.maxtext_config.logical_axis_rules):
236245
model, _ = model_creation_utils.create_nnx_model(
237246
self.maxtext_config, mesh=self.mesh, model_mode=self.model_mode, rng_key=rng_key
238247
)

0 commit comments

Comments
 (0)