@@ -56,7 +56,7 @@ def next_power_of_two(x: int) -> int:
5656 return 1 << (x - 1 ).bit_length ()
5757
5858
59- def generate_maxtext_config (vllm_config : VllmConfig , mesh : Mesh ) -> pyconfig .HyperParameters :
59+ def generate_maxtext_config (vllm_config : VllmConfig ) -> pyconfig .HyperParameters :
6060 """Generates a MaxText configuration from a vLLM configuration.
6161
6262 This function takes a vLLM configuration object and translates relevant
@@ -67,7 +67,6 @@ def generate_maxtext_config(vllm_config: VllmConfig, mesh: Mesh) -> pyconfig.Hyp
6767 Args:
6868 vllm_config: The vLLM configuration object containing model and load
6969 parameters.
70- mesh: The JAX mesh device for model sharding.
7170
7271 Returns:
7372 A `pyconfig.HyperParameters` object configured for MaxText.
@@ -178,7 +177,7 @@ def __init__(self, vllm_config: VllmConfig, rng_key: jax.Array, mesh: Mesh):
178177 """
179178 self .vllm_config = vllm_config
180179 self .cfg = vllm_config .model_config
181- self .maxtext_config = generate_maxtext_config (vllm_config , mesh )
180+ self .maxtext_config = generate_maxtext_config (vllm_config )
182181
183182 # Model configuration
184183 self .mesh = mesh
@@ -228,6 +227,24 @@ def __call__(
228227 if not isinstance (self .model , nnx .Module ):
229228 raise ValueError ("Model must be an instance of type nnx.Module." )
230229
230+ # below, GDN layers don't touch block_tables — they index via
231+ # ``mamba_state_indices`` — and all full-attn layers belong to the same
232+ # kv_cache_group so they share one block_tables. Pick a metadata from a
233+ # full-attn (non-linear_attention) layer when possible; otherwise any
234+ # value works.
235+ if isinstance (attention_metadata , dict ):
236+ hf_text_config = getattr (self .cfg , "hf_text_config" , getattr (self .cfg , "hf_config" , None ))
237+ layer_types = getattr (hf_text_config , "layer_types" , None ) or []
238+ attention_metadata_picked = None
239+ for i , lt in enumerate (layer_types ):
240+ if lt != "linear_attention" :
241+ attention_metadata_picked = attention_metadata .get (f"layer.{ i } " )
242+ if attention_metadata_picked is not None :
243+ break
244+ if attention_metadata_picked is None :
245+ attention_metadata_picked = next (iter (attention_metadata .values ()))
246+ attention_metadata = attention_metadata_picked
247+
231248 # Ensure inputs are at least 2D with a batch dimension
232249 input_ids = jnp .expand_dims (input_ids , axis = 1 )
233250 input_positions = jnp .expand_dims (attention_metadata .input_positions , axis = 1 )
@@ -324,3 +341,159 @@ def load_weights(self, rng_key: jax.Array) -> None:
324341 self .maxtext_config , mesh = self .mesh , model_mode = self .model_mode , rng_key = rng_key
325342 )
326343 self .model = nnx .data (model )
344+
def get_mrope_input_positions(
    self,
    input_tokens: list[int],
    mm_features: list = None,
) -> tuple[jax.Array, int]:
  """Build placeholder M-RoPE positions for a text-only prompt.

  Each of the three position axes receives the same sequential indices
  ``0 .. len(input_tokens) - 1`` and the position delta is always zero.

  Args:
    input_tokens: Prompt token ids; only the length is used.
    mm_features: Accepted for interface compatibility; not read here.

  Returns:
    A ``(positions, delta)`` pair where ``positions`` is an int32 array of
    shape ``(3, len(input_tokens))`` and ``delta`` is ``0``.
  """
  num_tokens = len(input_tokens)
  token_index = jnp.arange(num_tokens, dtype=jnp.int32)
  # Replicate the single row of indices across the three position axes.
  tiled_positions = jnp.tile(token_index, (3, 1))
  return tiled_positions, 0
356+
357+
358+ # Monkey-patch KVCacheManager.get_kv_cache_spec to support GDN/Mamba layers in Pure JAX path.
def patch_kv_cache_manager():
  """Monkey-patches KVCacheManager to support hybrid Attention + GDN/Mamba models.

  Replaces ``KVCacheManager.get_kv_cache_spec`` with a wrapper. When the
  runner's model exposes a ``maxtext_config`` whose ``decoder_block`` is
  ``qwen3_next`` or ``qwen3_5``, the wrapper:

    1. derives the GDN conv/recurrent state shapes and their byte size,
    2. derives the attention page size via ``get_attention_page_size_bytes``,
    3. stores the combined page size on the manager
       (``_hybrid_uniform_page_size_bytes``) and on
       ``runner.cache_config.mamba_page_size_padded``,
    4. calls ``_maybe_set_compact_mamba_num_blocks_override``,
    5. calls the original method and replaces the spec of every layer ``i``
       with ``(i + 1) % interval != 0`` by a ``MambaSpec``.

  In every other case the wrapper simply returns the original method's result.

  If ``tpu_inference``, ``vllm``, ``torch`` or ``numpy`` cannot be imported the
  patch is skipped and a message is logged. Calling this function again after
  the patch has been applied is a no-op, so the wrapper is never stacked on
  top of itself.

  Raises:
    RuntimeError: If ``KVCacheManager`` is importable but has no
      ``get_kv_cache_spec`` attribute.
  """
  # pylint: disable=import-outside-toplevel,protected-access
  import functools

  try:
    from tpu_inference.runner.kv_cache_manager import KVCacheManager
    from vllm.v1.kv_cache_interface import MambaSpec
    import torch
    import numpy as np
  except ImportError as e:
    # Gracefully handle missing imports in standard JAX environments (e.g. unit tests on CPU)
    max_logging.log(f"Skipping KVCacheManager patch (tpu_inference or dependencies not installed): {e}")
    return

  try:
    original_get_kv_cache_spec = KVCacheManager.get_kv_cache_spec
  except AttributeError as e:
    # Raise a clear error if packages exist but patch target is missing (indicating API change or mismatch)
    raise RuntimeError(
        "Failed to apply KVCacheManager patch: KVCacheManager.get_kv_cache_spec not found. "
        "This usually indicates a vLLM / tpu-inference API change or version mismatch."
    ) from e

  # Attribute set on the wrapper so a repeated call can detect an existing patch.
  patch_marker = "_maxtext_hybrid_gdn_patch"
  if getattr(original_get_kv_cache_spec, patch_marker, False):
    max_logging.log("KVCacheManager patch for hybrid GDN models already applied; skipping.")
    return

  hybrid_decoder_blocks = ("qwen3_next", "qwen3_5")

  @functools.wraps(original_get_kv_cache_spec)
  def patched_get_kv_cache_spec(self):
    runner = self.runner
    # A runner without a model, or a model without a MaxText config, is not ours to handle.
    model = getattr(runner, "model", None)
    if not hasattr(model, "maxtext_config"):
      return original_get_kv_cache_spec(self)

    cfg = model.maxtext_config
    decoder_block = getattr(cfg, "decoder_block", "")

    # ``decoder_block`` may be a plain string or an object carrying ``.value``.
    decoder_block_str = ""
    if isinstance(decoder_block, str):
      decoder_block_str = decoder_block
    elif hasattr(decoder_block, "value"):
      decoder_block_str = decoder_block.value

    if decoder_block_str not in hybrid_decoder_blocks:
      return original_get_kv_cache_spec(self)

    interval = cfg.inhomogeneous_layer_cycle_interval

    num_v_heads = cfg.gdn_num_value_heads
    num_k_heads = cfg.gdn_num_key_heads
    head_k_dim = cfg.gdn_key_head_dim
    head_v_dim = cfg.gdn_value_head_dim
    conv_kernel_size = cfg.gdn_conv_kernel_dim

    key_dim = head_k_dim * num_k_heads
    value_dim = head_v_dim * num_v_heads
    conv_dim = key_dim * 2 + value_dim

    conv_state_shape = (conv_kernel_size - 1, conv_dim)
    recurrent_state_shape = (num_v_heads, head_k_dim, head_v_dim)
    mamba_shapes = (conv_state_shape, recurrent_state_shape)

    # NOTE(review): any ``cfg.dtype`` whose ``str()`` is not exactly "float32"
    # falls back to bfloat16 — confirm this matches how MaxText represents dtypes.
    torch_dtype = torch.bfloat16
    if str(cfg.dtype) == "float32":
      torch_dtype = torch.float32
    mamba_dtypes = (torch_dtype, torch_dtype)

    # Calculate unpadded mamba page size
    dtype_size = 2 if torch_dtype == torch.bfloat16 else 4
    unpadded_mamba_page_size = sum(int(np.prod(shape)) * dtype_size for shape in mamba_shapes)

    # Calculate attn_page_size_bytes. These imports stay lazy so that a missing
    # symbol surfaces as an error at call time instead of silently skipping the patch.
    from tpu_inference.layers.common.sharding import ShardingAxisName
    from tpu_inference import utils as common_utils
    from tpu_inference.runner.kv_cache import get_attention_page_size_bytes

    tp_axis_name = ShardingAxisName.ATTN_HEAD
    model_cnt = common_utils.get_mesh_shape_product(runner.mesh, tp_axis_name)

    model_config = runner.model_config
    text_config = getattr(model_config, "hf_text_config", getattr(model_config, "hf_config", None))
    base_num_kv_heads = model_config.get_total_num_kv_heads()
    base_head_size = model_config.get_head_size()

    # Prefer the "global" attention dimensions when the HF config defines them.
    num_kv_heads = getattr(text_config, "num_global_key_value_heads", None) or base_num_kv_heads
    head_size = getattr(text_config, "global_head_dim", None) or base_head_size

    num_kv_heads = common_utils.get_padded_num_heads(num_kv_heads, model_cnt)
    head_size = common_utils.get_padded_head_dim(head_size)

    block_size = runner.cache_config.block_size

    # NOTE(review): the meaning of the trailing positional ``False`` is defined
    # by tpu_inference — confirm against ``get_attention_page_size_bytes``.
    attn_page_size_bytes = get_attention_page_size_bytes(
        runner.mesh, block_size, num_kv_heads, head_size, runner.kv_cache_dtype, False
    )

    # Calculate groups: one layer in every ``interval`` is full attention.
    # NOTE(review): assumes ``interval >= 1`` and at least one layer of each
    # kind; otherwise the integer divisions below raise ZeroDivisionError.
    num_layers = cfg.base_num_decoder_layers
    num_attn = num_layers // interval
    num_mamba = num_layers - num_attn

    # Use the larger layer count as the group size only when the two counts
    # are within a factor of 1.5 of each other; otherwise use the smaller.
    min_count = min(num_attn, num_mamba)
    max_count = max(num_attn, num_mamba)
    group_size = max_count if max_count < min_count * 1.5 else min_count
    num_attn_groups = (num_attn + group_size - 1) // group_size
    num_mamba_groups = (num_mamba + group_size - 1) // group_size

    uniform_page_size_bytes = int(
        num_attn_groups * attn_page_size_bytes + num_mamba_groups * unpadded_mamba_page_size
    )

    # Set the padded page size on manager and config before the original
    # method runs, preserving the original ordering of these side effects.
    self._hybrid_uniform_page_size_bytes = uniform_page_size_bytes
    runner.cache_config.mamba_page_size_padded = uniform_page_size_bytes

    self._maybe_set_compact_mamba_num_blocks_override(
        attn_page_size_bytes,
        int(unpadded_mamba_page_size),
        num_attn_groups,
        num_mamba_groups,
        num_attn,
        num_mamba,
        group_size,
    )

    kv_cache_spec = original_get_kv_cache_spec(self)

    for i in range(num_layers):
      if (i + 1) % interval == 0:
        # Every ``interval``-th layer keeps the spec produced by the original method.
        continue
      layer_name = f"layer.{i}"
      if layer_name in kv_cache_spec:
        kv_cache_spec[layer_name] = MambaSpec(
            block_size=kv_cache_spec[layer_name].block_size,
            shapes=mamba_shapes,
            dtypes=mamba_dtypes,
            page_size_padded=self._hybrid_uniform_page_size_bytes,
        )

    return kv_cache_spec

  setattr(patched_get_kv_cache_spec, patch_marker, True)
  KVCacheManager.get_kv_cache_spec = patched_get_kv_cache_spec
  max_logging.log("Successfully applied KVCacheManager patch for hybrid GDN models.")
0 commit comments