Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/maxtext/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,7 @@ logical_axis_rules: [
['decode_length', ['sequence']],
['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
['mlp_no_fsdp', ['tensor', 'tensor_sequence', 'autoregressive']],
['moe_mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
['vocab', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
['heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
['q_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
Expand Down Expand Up @@ -1119,6 +1120,8 @@ vllm_hf_config_path: ""
# A JSON string of overrides to apply to the HuggingFace-style config for the vLLM adapter.
# This can be used to override specific settings without modifying the original config file.
vllm_hf_overrides: {}
# Path to the YAML file used to load the vLLM config
vllm_config_path: ""
# JSON string containing additional configuration for the vLLM model (e.g. '{"maxtext_config": {...}}')
vllm_additional_config: {}
# When use_jax_splash=True, force the layout of the query tensor to be [..., NUM_HEADS, HEAD_DIM, SEQ_LENGTH]
Expand Down
69 changes: 69 additions & 0 deletions src/maxtext/configs/inference/vllm_deepseek.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# Copyright 2023–2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


base_config: "vllm.yml"

# Logical-axis -> mesh-axis sharding rules for serving DeepSeek-style MoE
# models through vLLM. Each entry maps a logical tensor axis name to the mesh
# axes it may be sharded over; an empty list means the axis is replicated.
# NOTE(review): entry order matters if the consumer resolves rules
# first-match (as flax logical partitioning does) — verify intent where a
# logical name appears more than once below.
logical_axis_rules: [
['activation_batch', []],
['activation_batch_no_exp', []],
['activation_embed_and_logits_batch', ['expert']],
['activation_embed_and_logits_batch_sequence', ['expert']],
['activation_heads', ['model']],
['activation_kv_heads', ['model']],
['activation_attn_length', ['expert']],
['activation_attn_length_no_exp', []],
['activation_length', ['data', 'expert']],
['activation_length_no_exp', 'data'],
['activation_q_length', ['expert']],
['activation_attn_embed', 'model'],
['activation_embed', ['model', 'attn_dp']],
['activation_mlp', ['model', 'attn_dp', 'expert']],
['activation_kv', ['model']],
['activation_prefill_kv_batch', ['expert']],
['activation_kv_batch', []],
['activation_kv_batch_no_exp', []],
['activation_kv_head_dim', ['model', 'attn_dp', 'expert']],
['activation_vocab', ['model', 'attn_dp']],
['activation_norm_length', []],
['activation_exp', ['expert']],
['decode_batch', ['expert']],
['decode_length', []],
['mlp_no_fsdp', ['model', 'attn_dp', 'expert']],
['vocab', ['model', 'attn_dp', 'expert']],
['heads', ['expert', 'attn_dp', 'model']],
['q_heads', []],
['kv_heads', []],
['kv_head_dim', ['model', 'attn_dp', 'expert']],
['kv', ['model', 'attn_dp', 'expert']],
# NOTE(review): duplicate 'kv' rule — under first-match resolution the entry
# below only takes effect when the mesh axes above are already consumed.
# Confirm whether both entries are intentional or this is a leftover.
['kv', []],
['embed', []],
['mlp', ['model', 'attn_dp', 'expert']],
['moe_mlp', []],
['embed_tensor_transpose', ['attn_dp', 'model']],
['embed_no_exp', []],
['q_lora', []],
['kv_lora', []],
['norm', []],
['cache_heads', ['model']],
['exp', ['expert', 'attn_dp', 'model']],
['paged_kv_heads', ['model']],
['cache_batch_prefill', []],
['cache_batch', []],
['cache_sequence', []],
['cache_heads_none', []],
['cache_kv', []],
['kv_lora_up_proj',['expert', 'attn_dp', 'model']],
['q_lora_up_proj',['expert', 'attn_dp', 'model']],
]
2 changes: 2 additions & 0 deletions src/maxtext/configs/post_train/rl.yml
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,8 @@ max_num_seqs: null
async_scheduling: True
# stop generation when any of these strings is generated
stop_strings: null
# Path to the YAML file used to initialize the vLLM config
vllm_config_path: 'src/maxtext/configs/inference/vllm.yml'

# ====== Checkpoint Configuration ======
enable_checkpointing: True
Expand Down
3 changes: 3 additions & 0 deletions src/maxtext/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1621,6 +1621,9 @@ class VLLM(BaseModel):
description="Overrides for HuggingFace model config for MaxText model.",
)
vllm_hf_config_path: str = Field("", description="Path to HuggingFace model config for MaxText model.")
vllm_config_path: str = Field(
"src/maxtext/configs/inference/vllm.yml", description="path to yaml file for loading vLLM config."
)


class RL(BaseModel):
Expand Down
36 changes: 19 additions & 17 deletions src/maxtext/inference/vllm_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
use_chat_template=True
"""

import copy
import os
from typing import Any, Sequence

Expand All @@ -40,7 +41,6 @@

from maxtext.utils import model_creation_utils
from maxtext.utils import max_logging
from maxtext.utils.globals import MAXTEXT_CONFIGS_DIR
from maxtext.common.common_types import Config
from maxtext.integration.tunix.tunix_adapter import TunixMaxTextAdapter
from tunix.rl.rollout import base_rollout
Expand All @@ -67,6 +67,21 @@ def decode_with_vllm(config: Config) -> None:
config: MaxText config.
"""
# Prepare vLLM Arguments
# Use user-provided vllm_additional_config as base (includes model-specific
# overrides like base_num_decoder_layers, override_model_config, etc.), then
# fill in defaults and runtime-derived values on top.
additional_config = copy.deepcopy(config.vllm_additional_config) if config.vllm_additional_config else {}
additional_config.setdefault("maxtext_config", {})
additional_config["maxtext_config"].setdefault("model_name", config.model_name)
additional_config["maxtext_config"].setdefault("weight_dtype", "bfloat16")
additional_config["maxtext_config"].setdefault("allow_split_physical_axes", True)
additional_config["maxtext_config"]["debug_sharding"] = config.debug_sharding
additional_config.setdefault("sharding", {})
additional_config["sharding"].setdefault("sharding_strategy", {})
additional_config["sharding"]["sharding_strategy"].setdefault("enable_dp_attention", config.enable_dp_attention)
# Pass vllm_config_path so the adapter can use it as the MaxText base config.
additional_config.setdefault("vllm_config_path", str(config.vllm_config_path))

vllm_args = {
"model": config.tokenizer_path,
"max_model_len": config.max_target_length,
Expand All @@ -76,19 +91,7 @@ def decode_with_vllm(config: Config) -> None:
"hf_overrides": config.vllm_hf_overrides,
"gpu_memory_utilization": config.hbm_utilization_vllm,
"async_scheduling": config.async_scheduling,
"additional_config": {
"maxtext_config": {
"model_name": config.model_name,
"weight_dtype": "bfloat16",
"allow_split_physical_axes": True,
"debug_sharding": config.debug_sharding,
},
"sharding": {
"sharding_strategy": {
"enable_dp_attention": config.enable_dp_attention,
},
},
},
"additional_config": additional_config,
}

if config.load_parameters_path:
Expand All @@ -106,8 +109,7 @@ def decode_with_vllm(config: Config) -> None:
f"and EP={config.ici_expert_parallelism if enable_expert_parallel else 1}..."
)

vllm_config_path = os.path.join(MAXTEXT_CONFIGS_DIR, "inference", "vllm.yml")
argv_list = ["", str(vllm_config_path), "log_config=False"]
argv_list = ["", str(config.vllm_config_path), "log_config=False"]
vllm_config = pyconfig.initialize(argv_list)

with nn_partitioning.axis_rules(vllm_config.logical_axis_rules):
Expand Down Expand Up @@ -145,7 +147,7 @@ def decode_with_vllm(config: Config) -> None:
max_tokens=max_tokens_to_generate,
top_k=config.decode_sampling_top_k,
top_p=config.decode_sampling_nucleus_p,
seed=FLAGS.seed,
# seed=FLAGS.seed,
)

outputs = llm.generate(prompts, sampling_params)
Expand Down
15 changes: 12 additions & 3 deletions src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,12 @@ def generate_maxtext_config(vllm_config: VllmConfig) -> pyconfig.HyperParameters
)
overrides["load_parameters_path"] = None

# Add base config path to positional args
base_config_path = os.path.join(MAXTEXT_CONFIGS_DIR, "inference", "vllm.yml")
# Add base config path to positional args — prefer the caller-supplied
# vllm_config_path from additional_config, fall back to vllm.yml default.
base_config_path = vllm_config.additional_config.get(
"vllm_config_path",
os.path.join(MAXTEXT_CONFIGS_DIR, "inference", "vllm.yml"),
)
argv_list = ["", str(base_config_path)]

maxtext_config = pyconfig.initialize(argv_list, **overrides)
Expand All @@ -86,6 +90,11 @@ class MaxTextForCausalLM(nnx.Module):
of the decoding step.
"""

# Signal to tpu-inference model_loader that this class manages its own
# JIT-sharded initialization (via create_nnx_model with out_shardings).
# When True, model_loader skips wrapping __init__ in an outer bare @jax.jit.
_self_manages_sharding: bool = True

def __init__(self, vllm_config: VllmConfig, rng_key: jax.Array, mesh: Mesh):
"""Initializes the MaxTextForCausalLM model.

Expand Down Expand Up @@ -232,7 +241,7 @@ def load_weights(self, rng_key: jax.Array) -> None:
if self.model is not None:
return

with self.mesh, nn.logical_axis_rules(""):
with self.mesh, nn.logical_axis_rules(self.maxtext_config.logical_axis_rules):
model, _ = model_creation_utils.create_nnx_model(
self.maxtext_config, mesh=self.mesh, model_mode=self.model_mode, rng_key=rng_key
)
Expand Down
Loading
Loading