Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions src/maxtext/configs/inference/vllm.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ logical_axis_rules: [
# Vocab Activations
['activation_embed_and_logits_batch', ['data', 'attn_dp', 'attn_dp_expert']],
['activation_embed_and_logits_batch_sequence', ['data', 'attn_dp', 'attn_dp_expert']],
['activation_vocab', ['expert', 'model']],
['activation_vocab', ['model', 'expert']],
# Vocab Weights
['vocab', []],
['embed_vocab', []],
Expand All @@ -46,16 +46,17 @@ logical_axis_rules: [
# ==========================================
# Attention Activations
['activation_batch_attn', ['data', 'attn_dp', 'attn_dp_expert']],
['activation_heads', ['expert', 'model']],
['activation_kv_heads', ['expert', 'model']],
['activation_heads', ['model', 'expert']],
['activation_kv_heads', ['model', 'expert']],
['activation_embed_attn', []],
['activation_kv', []],
['activation_kv_batch', ['data', 'attn_dp', 'attn_dp_expert']],
['activation_kv_head_dim', []],
# Attention Weights
['heads', ['expert', 'model']],
['q_heads', ['expert', 'model']],
['kv_heads', ['expert', 'model']],
['heads', ['model', 'expert']],
['gdn_head', ['model', 'expert']],
['q_heads', ['model', 'expert']],
['kv_heads', ['model', 'expert']],
['qkv', []],
['kv', []],
['kv_head_dim', []],
Expand Down
14 changes: 14 additions & 0 deletions src/maxtext/inference/vllm_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from maxtext.utils import max_logging
from maxtext.utils.globals import MAXTEXT_CONFIGS_DIR
from maxtext.common.common_types import Config
from maxtext.common import profiler
from maxtext.integration.tunix.tunix_adapter import TunixMaxTextAdapter
from tunix.rl.rollout import base_rollout
from tunix.rl.rollout.vllm_rollout import VllmRollout
Expand All @@ -52,6 +53,11 @@

adapter.register()

# Force uses_mrope to False to disable 3D multimodal position IDs in text-only runs.
# NOTE: the property is installed on the ModelConfig class itself, so every
# ModelConfig instance in this process reports uses_mrope == False from here on.
from vllm.config import ModelConfig

ModelConfig.uses_mrope = property(lambda _: False)

os.environ["SKIP_JAX_PRECOMPILE"] = "1"
os.environ["NEW_MODEL_DESIGN"] = "1"

Expand Down Expand Up @@ -106,6 +112,11 @@ def decode_with_vllm(config: Config) -> None:
vllm_args["additional_config"]["sharding"]["sharding_strategy"]["expert_parallelism"] = config.ici_expert_parallelism
vllm_args["enable_expert_parallel"] = enable_expert_parallel

if config.max_num_batched_tokens is not None:
vllm_args["max_num_batched_tokens"] = config.max_num_batched_tokens
if config.max_num_seqs is not None:
vllm_args["max_num_seqs"] = config.max_num_seqs

max_logging.log(
f"Initializing LLM with DP={config.ici_data_parallelism}, TP={config.ici_tensor_parallelism} "
f"and EP={config.ici_expert_parallelism if enable_expert_parallel else 1}..."
Expand Down Expand Up @@ -156,7 +167,10 @@ def decode_with_vllm(config: Config) -> None:
top_p=top_p,
)

prof = profiler.Profiler(config)
prof.activate()
outputs = llm.generate(prompts, sampling_params)
prof.deactivate()

# max_logging.log Outputs
for output in outputs:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,11 @@ def register():
"""
logger.info("Registering MaxTextForCausalLM model with tpu_inference and vllm.")
register_model("MaxTextForCausalLM", MaxTextForCausalLM)

# Dynamically apply KVCacheManager patch when registering the adapter
# pylint: disable=import-outside-toplevel
from .adapter import patch_kv_cache_manager

patch_kv_cache_manager()

logger.info("Successfully registered MaxTextForCausalLM model.")
179 changes: 176 additions & 3 deletions src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def next_power_of_two(x: int) -> int:
return 1 << (x - 1).bit_length()


def generate_maxtext_config(vllm_config: VllmConfig, mesh: Mesh) -> pyconfig.HyperParameters:
def generate_maxtext_config(vllm_config: VllmConfig) -> pyconfig.HyperParameters:
"""Generates a MaxText configuration from a vLLM configuration.
This function takes a vLLM configuration object and translates relevant
Expand All @@ -67,7 +67,6 @@ def generate_maxtext_config(vllm_config: VllmConfig, mesh: Mesh) -> pyconfig.Hyp
Args:
vllm_config: The vLLM configuration object containing model and load
parameters.
mesh: The JAX mesh device for model sharding.
Returns:
A `pyconfig.HyperParameters` object configured for MaxText.
Expand Down Expand Up @@ -178,7 +177,7 @@ def __init__(self, vllm_config: VllmConfig, rng_key: jax.Array, mesh: Mesh):
"""
self.vllm_config = vllm_config
self.cfg = vllm_config.model_config
self.maxtext_config = generate_maxtext_config(vllm_config, mesh)
self.maxtext_config = generate_maxtext_config(vllm_config)

# Model configuration
self.mesh = mesh
Expand Down Expand Up @@ -228,6 +227,24 @@ def __call__(
if not isinstance(self.model, nnx.Module):
raise ValueError("Model must be an instance of type nnx.Module.")

# attention_metadata may arrive as a dict keyed by layer name ("layer.{i}");
# collapse it to a single metadata object here. GDN layers don't touch
# block_tables — they index via ``mamba_state_indices`` — and all full-attn
# layers belong to the same kv_cache_group so they share one block_tables.
# Pick a metadata from a full-attn (non-linear_attention) layer when
# possible; otherwise any value works.
if isinstance(attention_metadata, dict):
hf_text_config = getattr(self.cfg, "hf_text_config", getattr(self.cfg, "hf_config", None))
layer_types = getattr(hf_text_config, "layer_types", None) or []
attention_metadata_picked = None
for i, lt in enumerate(layer_types):
if lt != "linear_attention":
attention_metadata_picked = attention_metadata.get(f"layer.{i}")
if attention_metadata_picked is not None:
break
if attention_metadata_picked is None:
attention_metadata_picked = next(iter(attention_metadata.values()))
attention_metadata = attention_metadata_picked

# Ensure inputs are at least 2D with a batch dimension
input_ids = jnp.expand_dims(input_ids, axis=1)
input_positions = jnp.expand_dims(attention_metadata.input_positions, axis=1)
Expand Down Expand Up @@ -324,3 +341,159 @@ def load_weights(self, rng_key: jax.Array) -> None:
self.maxtext_config, mesh=self.mesh, model_mode=self.model_mode, rng_key=rng_key
)
self.model = nnx.data(model)

def get_mrope_input_positions(
self,
input_tokens: list[int],
mm_features: list = None,
) -> tuple[jax.Array, int]:
"""Get dummy mrope input positions and delta value for text-only MaxText."""
seq_len = len(input_tokens)
pos_range = jnp.arange(seq_len, dtype=jnp.int32)
# M-RoPE expects 3D position vectors (3, seq_len) and position_delta (int)
positions = jnp.stack([pos_range, pos_range, pos_range], axis=0)
return positions, 0


# Monkey-patch KVCacheManager.get_kv_cache_spec to support GDN/Mamba layers in Pure JAX path.
def patch_kv_cache_manager():
  """Monkey-patches KVCacheManager to support hybrid Attention + GDN/Mamba models.

  Replaces ``KVCacheManager.get_kv_cache_spec`` with a wrapper. For runners
  whose model carries a ``maxtext_config`` with ``decoder_block`` equal to
  ``qwen3_next`` or ``qwen3_5`` the wrapper:

  1. Computes the GDN conv/recurrent state shapes and a uniform page size,
     stores them on the manager and on ``runner.cache_config``, and calls
     ``_maybe_set_compact_mamba_num_blocks_override``.
  2. Calls the original ``get_kv_cache_spec``.
  3. Replaces the spec of every ``layer.{i}`` for which ``(i + 1)`` is not a
     multiple of ``inhomogeneous_layer_cycle_interval`` with a ``MambaSpec``.

  For every other runner/model the original method is called unchanged.

  The patch is idempotent: calling this function again after the wrapper has
  been installed is a no-op, so repeated adapter registration does not stack
  wrappers on top of each other.

  If ``tpu_inference``, ``vllm``, ``torch`` or ``numpy`` cannot be imported the
  patch is skipped with a log message.

  Raises:
    RuntimeError: If ``KVCacheManager`` is importable but has no
      ``get_kv_cache_spec`` attribute.
  """
  # pylint: disable=import-outside-toplevel,protected-access
  import functools

  try:
    from tpu_inference.runner.kv_cache_manager import KVCacheManager
    from vllm.v1.kv_cache_interface import MambaSpec
    import torch
    import numpy as np
  except ImportError as e:
    # Gracefully handle missing imports in standard JAX environments (e.g. unit tests on CPU)
    max_logging.log(f"Skipping KVCacheManager patch (tpu_inference or dependencies not installed): {e}")
    return

  try:
    original_get_kv_cache_spec = KVCacheManager.get_kv_cache_spec
  except AttributeError as e:
    # Raise a clear error if packages exist but patch target is missing (indicating API change or mismatch)
    raise RuntimeError(
        "Failed to apply KVCacheManager patch: KVCacheManager.get_kv_cache_spec not found. "
        "This usually indicates a vLLM / tpu-inference API change or version mismatch."
    ) from e

  # Marker attribute set on the wrapper below; its presence means the method
  # currently installed on the class is already our wrapper.
  patch_marker = "_maxtext_hybrid_gdn_patch"
  if getattr(original_get_kv_cache_spec, patch_marker, False):
    max_logging.log("KVCacheManager patch for hybrid GDN models already applied; skipping.")
    return

  hybrid_decoder_blocks = ("qwen3_next", "qwen3_5")

  @functools.wraps(original_get_kv_cache_spec)
  def patched_get_kv_cache_spec(self):
    runner = self.runner
    if not hasattr(runner, "model"):
      return original_get_kv_cache_spec(self)

    model = runner.model
    if not hasattr(model, "maxtext_config"):
      # Not a MaxText-backed model: leave behaviour untouched.
      return original_get_kv_cache_spec(self)

    cfg = model.maxtext_config
    decoder_block = getattr(cfg, "decoder_block", "")

    # decoder_block may be a plain string or an enum-like object with `.value`.
    decoder_block_str = ""
    if isinstance(decoder_block, str):
      decoder_block_str = decoder_block
    elif hasattr(decoder_block, "value"):
      decoder_block_str = decoder_block.value

    is_hybrid = decoder_block_str in hybrid_decoder_blocks

    if is_hybrid:
      interval = cfg.inhomogeneous_layer_cycle_interval

      num_v_heads = cfg.gdn_num_value_heads
      num_k_heads = cfg.gdn_num_key_heads
      head_k_dim = cfg.gdn_key_head_dim
      head_v_dim = cfg.gdn_value_head_dim
      conv_kernel_size = cfg.gdn_conv_kernel_dim

      key_dim = head_k_dim * num_k_heads
      value_dim = head_v_dim * num_v_heads
      conv_dim = key_dim * 2 + value_dim

      conv_state_shape = (conv_kernel_size - 1, conv_dim)
      recurrent_state_shape = (num_v_heads, head_k_dim, head_v_dim)

      mamba_shapes = (conv_state_shape, recurrent_state_shape)

      # NOTE(review): anything whose str() is not exactly "float32" falls back
      # to bfloat16 — confirm this matches how cfg.dtype is represented.
      torch_dtype = torch.bfloat16
      if str(cfg.dtype) == "float32":
        torch_dtype = torch.float32
      mamba_dtypes = (torch_dtype, torch_dtype)

      # Calculate unpadded mamba page size
      dtype_size = 2 if torch_dtype == torch.bfloat16 else 4
      unpadded_mamba_page_size = sum(int(np.prod(shape)) * dtype_size for shape in mamba_shapes)

      # Calculate attn_page_size_bytes
      from tpu_inference.layers.common.sharding import ShardingAxisName
      from tpu_inference import utils as common_utils

      tp_axis_name = ShardingAxisName.ATTN_HEAD
      model_cnt = common_utils.get_mesh_shape_product(self.runner.mesh, tp_axis_name)

      model_config = self.runner.model_config
      text_config = getattr(model_config, "hf_text_config", getattr(model_config, "hf_config", None))
      base_num_kv_heads = model_config.get_total_num_kv_heads()
      base_head_size = model_config.get_head_size()

      # Prefer the "global" attention dimensions from the HF text config when
      # present; otherwise use the values reported by the model config.
      num_kv_heads = getattr(text_config, "num_global_key_value_heads", None) or base_num_kv_heads
      head_size = getattr(text_config, "global_head_dim", None) or base_head_size

      num_kv_heads = common_utils.get_padded_num_heads(num_kv_heads, model_cnt)
      head_size = common_utils.get_padded_head_dim(head_size)

      from tpu_inference.runner.kv_cache import get_attention_page_size_bytes

      block_size = self.runner.cache_config.block_size

      attn_page_size_bytes = get_attention_page_size_bytes(
          self.runner.mesh, block_size, num_kv_heads, head_size, self.runner.kv_cache_dtype, False
      )

      # Calculate groups
      num_layers = cfg.base_num_decoder_layers
      num_attn = num_layers // interval
      num_mamba = num_layers - num_attn

      min_count = min(num_attn, num_mamba)
      max_count = max(num_attn, num_mamba)
      # Use the larger count as group size when the two counts are within 1.5x
      # of each other, otherwise the smaller one.
      # NOTE(review): if either count is 0 (e.g. interval == 1 or
      # num_layers < interval) group_size becomes 0 and the divisions below
      # raise ZeroDivisionError — confirm such configs cannot reach this path.
      if max_count < min_count * 1.5:
        group_size = max_count
      else:
        group_size = min_count
      num_attn_groups = (num_attn + group_size - 1) // group_size
      num_mamba_groups = (num_mamba + group_size - 1) // group_size

      uniform_page_size_bytes = num_attn_groups * attn_page_size_bytes + num_mamba_groups * unpadded_mamba_page_size

      # Set the padded page size on manager and config
      self._hybrid_uniform_page_size_bytes = int(uniform_page_size_bytes)
      self.runner.cache_config.mamba_page_size_padded = int(uniform_page_size_bytes)

      self._maybe_set_compact_mamba_num_blocks_override(
          attn_page_size_bytes,
          int(unpadded_mamba_page_size),
          num_attn_groups,
          num_mamba_groups,
          num_attn,
          num_mamba,
          group_size,
      )

    kv_cache_spec = original_get_kv_cache_spec(self)

    if is_hybrid:
      # Layers where (i + 1) is not a multiple of the cycle interval get a
      # MambaSpec in place of whatever spec the original method produced.
      for i in range(cfg.base_num_decoder_layers):
        if (i + 1) % interval != 0:
          layer_name = f"layer.{i}"
          if layer_name in kv_cache_spec:
            kv_cache_spec[layer_name] = MambaSpec(
                block_size=kv_cache_spec[layer_name].block_size,
                shapes=mamba_shapes,
                dtypes=mamba_dtypes,
                page_size_padded=self._hybrid_uniform_page_size_bytes,
            )

    return kv_cache_spec

  setattr(patched_get_kv_cache_spec, patch_marker, True)
  KVCacheManager.get_kv_cache_spec = patched_get_kv_cache_spec
  max_logging.log("Successfully applied KVCacheManager patch for hybrid GDN models.")
16 changes: 6 additions & 10 deletions src/maxtext/layers/decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -1135,12 +1135,12 @@ def __call__(
if cfg.decoder_block in (DecoderBlockType.QWEN3_NEXT, DecoderBlockType.QWEN3_5):
layer_kwargs = {"layer_idx": lyr}
kv_cache = None
if kv_caches is not None and cfg.decoder_block not in (DecoderBlockType.QWEN3_NEXT, DecoderBlockType.QWEN3_5):
if kv_caches is not None:
# For all decoder blocks (including QWEN3_NEXT/QWEN3_5 with vLLM flat-list
# kv_caches), pass the per-layer cache directly. For hybrid attention+GDN
# models, kv_caches[lyr] is a regular attention cache for attention layers
# and a (conv_state, recurrent_state) paged-mamba tuple for GDN layers.
kv_cache = kv_caches[lyr]
elif kv_caches is not None and cfg.decoder_block in (DecoderBlockType.QWEN3_NEXT, DecoderBlockType.QWEN3_5):
# For Qwen3Next & Qwen3.5, kv_caches is a dictionary of lists of caches.
if (lyr + 1) % cfg.inhomogeneous_layer_cycle_interval == 0:
kv_cache = (kv_caches["key_cache"][lyr], kv_caches["value_cache"][lyr])

if cfg.decoder_block == DecoderBlockType.GPT_OSS:
layer_kwargs = {"attention_type": gpt_oss.get_attention_type(layer_id=lyr)}
Expand All @@ -1162,11 +1162,7 @@ def __call__(
**layer_call_kwargs,
)
if kv_caches is not None and returned_cache is not None:
if cfg.decoder_block not in (DecoderBlockType.QWEN3_NEXT, DecoderBlockType.QWEN3_5):
kv_caches[lyr] = returned_cache
elif (lyr + 1) % cfg.inhomogeneous_layer_cycle_interval == 0:
kv_caches["key_cache"][lyr] = returned_cache[0]
kv_caches["value_cache"][lyr] = returned_cache[1]
kv_caches[lyr] = returned_cache

if deepstack_visual_embeds is not None and lyr < len(deepstack_visual_embeds):
visual_embeds = deepstack_visual_embeds[lyr]
Expand Down
Loading
Loading