Skip to content

Commit e1aa0bb

Browse files
committed
support of gdn kernel from tpu-inference
1 parent 57a6b30 commit e1aa0bb

10 files changed

Lines changed: 454 additions & 66 deletions

File tree

src/maxtext/configs/inference/vllm.yml

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ logical_axis_rules: [
3737
# Vocab Activations
3838
['activation_embed_and_logits_batch', ['data', 'attn_dp', 'attn_dp_expert']],
3939
['activation_embed_and_logits_batch_sequence', ['data', 'attn_dp', 'attn_dp_expert']],
40-
['activation_vocab', ['expert', 'model']],
40+
['activation_vocab', ['model', 'expert']],
4141
# Vocab Weights
4242
['vocab', []],
4343
['embed_vocab', []],
@@ -46,16 +46,17 @@ logical_axis_rules: [
4646
# ==========================================
4747
# Attention Activations
4848
['activation_batch_attn', ['data', 'attn_dp', 'attn_dp_expert']],
49-
['activation_heads', ['expert', 'model']],
50-
['activation_kv_heads', ['expert', 'model']],
49+
['activation_heads', ['model', 'expert']],
50+
['activation_kv_heads', ['model', 'expert']],
5151
['activation_embed_attn', []],
5252
['activation_kv', []],
5353
['activation_kv_batch', ['data', 'attn_dp', 'attn_dp_expert']],
5454
['activation_kv_head_dim', []],
5555
# Attention Weights
56-
['heads', ['expert', 'model']],
57-
['q_heads', ['expert', 'model']],
58-
['kv_heads', ['expert', 'model']],
56+
['heads', ['model', 'expert']],
57+
['gdn_head', ['model', 'expert']],
58+
['q_heads', ['model', 'expert']],
59+
['kv_heads', ['model', 'expert']],
5960
['qkv', []],
6061
['kv', []],
6162
['kv_head_dim', []],

src/maxtext/inference/vllm_decode.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
from maxtext.utils import max_logging
4343
from maxtext.utils.globals import MAXTEXT_CONFIGS_DIR
4444
from maxtext.common.common_types import Config
45+
from maxtext.common import profiler
4546
from maxtext.integration.tunix.tunix_adapter import TunixMaxTextAdapter
4647
from tunix.rl.rollout import base_rollout
4748
from tunix.rl.rollout.vllm_rollout import VllmRollout
@@ -52,6 +53,11 @@
5253

5354
adapter.register()
5455

56+
# Force uses_mrope to False to disable 3D multimodal position IDs in text-only runs.
57+
from vllm.config import ModelConfig
58+
59+
ModelConfig.uses_mrope = property(lambda _: False)
60+
5561
os.environ["SKIP_JAX_PRECOMPILE"] = "1"
5662
os.environ["NEW_MODEL_DESIGN"] = "1"
5763

@@ -106,6 +112,11 @@ def decode_with_vllm(config: Config) -> None:
106112
vllm_args["additional_config"]["sharding"]["sharding_strategy"]["expert_parallelism"] = config.ici_expert_parallelism
107113
vllm_args["enable_expert_parallel"] = enable_expert_parallel
108114

115+
if config.max_num_batched_tokens is not None:
116+
vllm_args["max_num_batched_tokens"] = config.max_num_batched_tokens
117+
if config.max_num_seqs is not None:
118+
vllm_args["max_num_seqs"] = config.max_num_seqs
119+
109120
max_logging.log(
110121
f"Initializing LLM with DP={config.ici_data_parallelism}, TP={config.ici_tensor_parallelism} "
111122
f"and EP={config.ici_expert_parallelism if enable_expert_parallel else 1}..."
@@ -156,7 +167,10 @@ def decode_with_vllm(config: Config) -> None:
156167
top_p=top_p,
157168
)
158169

170+
prof = profiler.Profiler(config)
171+
prof.activate()
159172
outputs = llm.generate(prompts, sampling_params)
173+
prof.deactivate()
160174

161175
# Log the generated outputs
162176
for output in outputs:

src/maxtext/integration/vllm/maxtext_vllm_adapter/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,4 +30,11 @@ def register():
3030
"""
3131
logger.info("Registering MaxTextForCausalLM model with tpu_inference and vllm.")
3232
register_model("MaxTextForCausalLM", MaxTextForCausalLM)
33+
34+
# Dynamically apply KVCacheManager patch when registering the adapter
35+
# pylint: disable=import-outside-toplevel
36+
from .adapter import patch_kv_cache_manager
37+
38+
patch_kv_cache_manager()
39+
3340
logger.info("Successfully registered MaxTextForCausalLM model.")

src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py

Lines changed: 176 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def next_power_of_two(x: int) -> int:
5656
return 1 << (x - 1).bit_length()
5757

5858

59-
def generate_maxtext_config(vllm_config: VllmConfig, mesh: Mesh) -> pyconfig.HyperParameters:
59+
def generate_maxtext_config(vllm_config: VllmConfig) -> pyconfig.HyperParameters:
6060
"""Generates a MaxText configuration from a vLLM configuration.
6161
6262
This function takes a vLLM configuration object and translates relevant
@@ -67,7 +67,6 @@ def generate_maxtext_config(vllm_config: VllmConfig, mesh: Mesh) -> pyconfig.Hyp
6767
Args:
6868
vllm_config: The vLLM configuration object containing model and load
6969
parameters.
70-
mesh: The JAX mesh device for model sharding.
7170
7271
Returns:
7372
A `pyconfig.HyperParameters` object configured for MaxText.
@@ -178,7 +177,7 @@ def __init__(self, vllm_config: VllmConfig, rng_key: jax.Array, mesh: Mesh):
178177
"""
179178
self.vllm_config = vllm_config
180179
self.cfg = vllm_config.model_config
181-
self.maxtext_config = generate_maxtext_config(vllm_config, mesh)
180+
self.maxtext_config = generate_maxtext_config(vllm_config)
182181

183182
# Model configuration
184183
self.mesh = mesh
@@ -228,6 +227,24 @@ def __call__(
228227
if not isinstance(self.model, nnx.Module):
229228
raise ValueError("Model must be an instance of type nnx.Module.")
230229

230+
# below, GDN layers don't touch block_tables — they index via
231+
# ``mamba_state_indices`` — and all full-attn layers belong to the same
232+
# kv_cache_group so they share one block_tables. Pick a metadata from a
233+
# full-attn (non-linear_attention) layer when possible; otherwise any
234+
# value works.
235+
if isinstance(attention_metadata, dict):
236+
hf_text_config = getattr(self.cfg, "hf_text_config", getattr(self.cfg, "hf_config", None))
237+
layer_types = getattr(hf_text_config, "layer_types", None) or []
238+
attention_metadata_picked = None
239+
for i, lt in enumerate(layer_types):
240+
if lt != "linear_attention":
241+
attention_metadata_picked = attention_metadata.get(f"layer.{i}")
242+
if attention_metadata_picked is not None:
243+
break
244+
if attention_metadata_picked is None:
245+
attention_metadata_picked = next(iter(attention_metadata.values()))
246+
attention_metadata = attention_metadata_picked
247+
231248
# Ensure inputs are at least 2D with a batch dimension
232249
input_ids = jnp.expand_dims(input_ids, axis=1)
233250
input_positions = jnp.expand_dims(attention_metadata.input_positions, axis=1)
@@ -324,3 +341,159 @@ def load_weights(self, rng_key: jax.Array) -> None:
324341
self.maxtext_config, mesh=self.mesh, model_mode=self.model_mode, rng_key=rng_key
325342
)
326343
self.model = nnx.data(model)
344+
345+
def get_mrope_input_positions(
346+
self,
347+
input_tokens: list[int],
348+
mm_features: list = None,
349+
) -> tuple[jax.Array, int]:
350+
"""Get dummy mrope input positions and delta value for text-only MaxText."""
351+
seq_len = len(input_tokens)
352+
pos_range = jnp.arange(seq_len, dtype=jnp.int32)
353+
# M-RoPE expects 3D position vectors (3, seq_len) and position_delta (int)
354+
positions = jnp.stack([pos_range, pos_range, pos_range], axis=0)
355+
return positions, 0
356+
357+
358+
# Monkey-patch KVCacheManager.get_kv_cache_spec to support GDN/Mamba layers in Pure JAX path.
359+
def patch_kv_cache_manager():
  """Monkey-patches KVCacheManager to support hybrid Attention + GDN/Mamba models.

  Replaces `KVCacheManager.get_kv_cache_spec` with a wrapper. The wrapper
  delegates to the original method unchanged unless the runner's model carries
  a `maxtext_config` whose `decoder_block` is `"qwen3_next"` or `"qwen3_5"`.
  For those decoder blocks it:

  * derives the GDN conv / recurrent state shapes from the MaxText config,
  * computes a uniform page size from the attention and mamba page sizes and
    stores it on the manager and on `runner.cache_config`,
  * calls the original method and then replaces the spec of every layer `i`
    with `(i + 1) % inhomogeneous_layer_cycle_interval != 0` by a `MambaSpec`.

  The function is idempotent: if the method has already been patched by a
  previous call it returns without wrapping it again.

  Raises:
    RuntimeError: If `KVCacheManager` is importable but has no
      `get_kv_cache_spec` attribute.
  """
  # pylint: disable=import-outside-toplevel,protected-access
  try:
    from tpu_inference.runner.kv_cache_manager import KVCacheManager
    from vllm.v1.kv_cache_interface import MambaSpec
    import torch
    import numpy as np
  except ImportError as e:
    # Gracefully handle missing imports in standard JAX environments (e.g. unit tests on CPU)
    max_logging.log(f"Skipping KVCacheManager patch (tpu_inference or dependencies not installed): {e}")
    return

  try:
    original_get_kv_cache_spec = KVCacheManager.get_kv_cache_spec
  except AttributeError as e:
    # Raise a clear error if packages exist but patch target is missing (indicating API change or mismatch)
    raise RuntimeError(
        "Failed to apply KVCacheManager patch: KVCacheManager.get_kv_cache_spec not found. "
        "This usually indicates a vLLM / tpu-inference API change or version mismatch."
    ) from e

  # Guard against double patching: wrapping an already-patched method would run
  # the hybrid sizing logic (and its side effects) twice on every call.
  if getattr(original_get_kv_cache_spec, "_maxtext_gdn_patch", False):
    max_logging.log("KVCacheManager patch for hybrid GDN models already applied; skipping.")
    return

  hybrid_decoder_blocks = ("qwen3_next", "qwen3_5")

  def patched_get_kv_cache_spec(self):
    """Returns the KV cache spec, using MambaSpec for the GDN layers of hybrid models."""
    runner = self.runner
    if not hasattr(runner, "model"):
      return original_get_kv_cache_spec(self)

    model = runner.model
    if not hasattr(model, "maxtext_config"):
      return original_get_kv_cache_spec(self)

    cfg = model.maxtext_config
    decoder_block = getattr(cfg, "decoder_block", "")

    # `decoder_block` may be a plain string or an object exposing `.value`.
    decoder_block_str = ""
    if isinstance(decoder_block, str):
      decoder_block_str = decoder_block
    elif hasattr(decoder_block, "value"):
      decoder_block_str = decoder_block.value

    # Every non-hybrid decoder keeps the stock behaviour.
    if decoder_block_str not in hybrid_decoder_blocks:
      return original_get_kv_cache_spec(self)

    interval = cfg.inhomogeneous_layer_cycle_interval

    num_v_heads = cfg.gdn_num_value_heads
    num_k_heads = cfg.gdn_num_key_heads
    head_k_dim = cfg.gdn_key_head_dim
    head_v_dim = cfg.gdn_value_head_dim
    conv_kernel_size = cfg.gdn_conv_kernel_dim

    key_dim = head_k_dim * num_k_heads
    value_dim = head_v_dim * num_v_heads
    conv_dim = key_dim * 2 + value_dim

    conv_state_shape = (conv_kernel_size - 1, conv_dim)
    recurrent_state_shape = (num_v_heads, head_k_dim, head_v_dim)

    mamba_shapes = (conv_state_shape, recurrent_state_shape)

    # bfloat16 unless the MaxText dtype stringifies to exactly "float32".
    torch_dtype = torch.bfloat16
    if str(cfg.dtype) == "float32":
      torch_dtype = torch.float32
    mamba_dtypes = (torch_dtype, torch_dtype)

    # Calculate unpadded mamba page size
    dtype_size = 2 if torch_dtype == torch.bfloat16 else 4
    unpadded_mamba_page_size = sum(int(np.prod(shape)) * dtype_size for shape in mamba_shapes)

    # Calculate attn_page_size_bytes
    from tpu_inference.layers.common.sharding import ShardingAxisName
    from tpu_inference import utils as common_utils
    from tpu_inference.runner.kv_cache import get_attention_page_size_bytes

    tp_axis_name = ShardingAxisName.ATTN_HEAD
    model_cnt = common_utils.get_mesh_shape_product(runner.mesh, tp_axis_name)

    model_config = runner.model_config
    text_config = getattr(model_config, "hf_text_config", getattr(model_config, "hf_config", None))
    base_num_kv_heads = model_config.get_total_num_kv_heads()
    base_head_size = model_config.get_head_size()

    # Prefer the "global" overrides from the HF text config when present and truthy.
    num_kv_heads = getattr(text_config, "num_global_key_value_heads", None) or base_num_kv_heads
    head_size = getattr(text_config, "global_head_dim", None) or base_head_size

    num_kv_heads = common_utils.get_padded_num_heads(num_kv_heads, model_cnt)
    head_size = common_utils.get_padded_head_dim(head_size)

    block_size = runner.cache_config.block_size

    # NOTE(review): the meaning of the trailing positional `False` is defined by
    # tpu_inference's get_attention_page_size_bytes; confirm against that signature.
    attn_page_size_bytes = get_attention_page_size_bytes(
        runner.mesh, block_size, num_kv_heads, head_size, runner.kv_cache_dtype, False
    )

    # Calculate groups
    num_layers = cfg.base_num_decoder_layers
    num_attn = num_layers // interval
    num_mamba = num_layers - num_attn

    min_count = min(num_attn, num_mamba)
    max_count = max(num_attn, num_mamba)
    if max_count < min_count * 1.5:
      group_size = max_count
    else:
      group_size = min_count
    if group_size <= 0:
      # Happens when one of the two layer kinds is absent (e.g. fewer layers
      # than one full cycle); fail with a clear message instead of dividing by zero.
      raise ValueError(
          "Cannot compute hybrid KV cache groups: expected at least one attention layer and one GDN layer, "
          f"got num_attn={num_attn}, num_mamba={num_mamba} "
          f"(base_num_decoder_layers={num_layers}, inhomogeneous_layer_cycle_interval={interval})."
      )
    num_attn_groups = (num_attn + group_size - 1) // group_size
    num_mamba_groups = (num_mamba + group_size - 1) // group_size

    uniform_page_size_bytes = num_attn_groups * attn_page_size_bytes + num_mamba_groups * unpadded_mamba_page_size

    # Set the padded page size on manager and config. This must happen before
    # the original method is invoked below, matching the original call order.
    self._hybrid_uniform_page_size_bytes = int(uniform_page_size_bytes)
    runner.cache_config.mamba_page_size_padded = int(uniform_page_size_bytes)

    self._maybe_set_compact_mamba_num_blocks_override(
        attn_page_size_bytes,
        int(unpadded_mamba_page_size),
        num_attn_groups,
        num_mamba_groups,
        num_attn,
        num_mamba,
        group_size,
    )

    kv_cache_spec = original_get_kv_cache_spec(self)

    # Every layer that is not the last one of its cycle gets a MambaSpec.
    for i in range(num_layers):
      if (i + 1) % interval == 0:
        continue
      layer_name = f"layer.{i}"
      if layer_name in kv_cache_spec:
        kv_cache_spec[layer_name] = MambaSpec(
            block_size=kv_cache_spec[layer_name].block_size,
            shapes=mamba_shapes,
            dtypes=mamba_dtypes,
            page_size_padded=self._hybrid_uniform_page_size_bytes,
        )

    return kv_cache_spec

  # Marker consulted by the idempotency guard above.
  patched_get_kv_cache_spec._maxtext_gdn_patch = True

  KVCacheManager.get_kv_cache_spec = patched_get_kv_cache_spec
  max_logging.log("Successfully applied KVCacheManager patch for hybrid GDN models.")

src/maxtext/layers/decoders.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1135,12 +1135,12 @@ def __call__(
11351135
if cfg.decoder_block in (DecoderBlockType.QWEN3_NEXT, DecoderBlockType.QWEN3_5):
11361136
layer_kwargs = {"layer_idx": lyr}
11371137
kv_cache = None
1138-
if kv_caches is not None and cfg.decoder_block not in (DecoderBlockType.QWEN3_NEXT, DecoderBlockType.QWEN3_5):
1138+
if kv_caches is not None:
1139+
# For all decoder blocks (including QWEN3_NEXT/QWEN3_5 with vLLM flat-list
1140+
# kv_caches), pass the per-layer cache directly. For hybrid attention+GDN
1141+
# models, kv_caches[lyr] is a regular attention cache for attention layers
1142+
# and a (conv_state, recurrent_state) paged-mamba tuple for GDN layers.
11391143
kv_cache = kv_caches[lyr]
1140-
elif kv_caches is not None and cfg.decoder_block in (DecoderBlockType.QWEN3_NEXT, DecoderBlockType.QWEN3_5):
1141-
# For Qwen3Next & Qwen3.5, kv_caches is a dictionary of lists of caches.
1142-
if (lyr + 1) % cfg.inhomogeneous_layer_cycle_interval == 0:
1143-
kv_cache = (kv_caches["key_cache"][lyr], kv_caches["value_cache"][lyr])
11441144

11451145
if cfg.decoder_block == DecoderBlockType.GPT_OSS:
11461146
layer_kwargs = {"attention_type": gpt_oss.get_attention_type(layer_id=lyr)}
@@ -1162,11 +1162,7 @@ def __call__(
11621162
**layer_call_kwargs,
11631163
)
11641164
if kv_caches is not None and returned_cache is not None:
1165-
if cfg.decoder_block not in (DecoderBlockType.QWEN3_NEXT, DecoderBlockType.QWEN3_5):
1166-
kv_caches[lyr] = returned_cache
1167-
elif (lyr + 1) % cfg.inhomogeneous_layer_cycle_interval == 0:
1168-
kv_caches["key_cache"][lyr] = returned_cache[0]
1169-
kv_caches["value_cache"][lyr] = returned_cache[1]
1165+
kv_caches[lyr] = returned_cache
11701166

11711167
if deepstack_visual_embeds is not None and lyr < len(deepstack_visual_embeds):
11721168
visual_embeds = deepstack_visual_embeds[lyr]

0 commit comments

Comments
 (0)