diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index 0ecd3951df..c973087a09 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -551,6 +551,9 @@ logical_axis_rules: [ ['activation_stage', 'stage'], # General Weights ['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']], + # GDN (linear-attention) projections shard like 'mlp' during training; the + # vLLM serving config overrides this to match tpu-inference's ATTN_HEAD order. + ['gdn_head', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']], ['embed', ['fsdp', 'fsdp_transpose', 'tensor_transpose', 'context', 'expert']], ['embed', ['fsdp', 'tensor_transpose', 'context', 'expert']], ['embed', ['fsdp', 'fsdp_transpose', 'context', 'expert']], diff --git a/src/maxtext/configs/inference/vllm.yml b/src/maxtext/configs/inference/vllm.yml index 3f9e4f5290..157845c5b7 100644 --- a/src/maxtext/configs/inference/vllm.yml +++ b/src/maxtext/configs/inference/vllm.yml @@ -37,7 +37,7 @@ logical_axis_rules: [ # Vocab Activations ['activation_embed_and_logits_batch', ['data', 'attn_dp', 'attn_dp_expert']], ['activation_embed_and_logits_batch_sequence', ['data', 'attn_dp', 'attn_dp_expert']], - ['activation_vocab', ['expert', 'model']], + ['activation_vocab', ['model', 'expert']], # Vocab Weights ['vocab', []], ['embed_vocab', []], @@ -46,16 +46,17 @@ logical_axis_rules: [ # ========================================== # Attention Activations ['activation_batch_attn', ['data', 'attn_dp', 'attn_dp_expert']], - ['activation_heads', ['expert', 'model']], - ['activation_kv_heads', ['expert', 'model']], + ['activation_heads', ['model', 'expert']], + ['activation_kv_heads', ['model', 'expert']], ['activation_embed_attn', []], ['activation_kv', []], ['activation_kv_batch', ['data', 'attn_dp', 'attn_dp_expert']], ['activation_kv_head_dim', []], # Attention Weights - ['heads', ['expert', 'model']], - ['q_heads', ['expert', 'model']], - ['kv_heads', 
['expert', 'model']], + ['heads', ['model', 'expert']], + ['gdn_head', ['model', 'expert']], + ['q_heads', ['model', 'expert']], + ['kv_heads', ['model', 'expert']], ['qkv', []], ['kv', []], ['kv_head_dim', []], diff --git a/src/maxtext/inference/vllm_decode.py b/src/maxtext/inference/vllm_decode.py index 0962a33b5f..d54d40c2d9 100644 --- a/src/maxtext/inference/vllm_decode.py +++ b/src/maxtext/inference/vllm_decode.py @@ -42,6 +42,7 @@ from maxtext.utils import max_logging from maxtext.utils.globals import MAXTEXT_CONFIGS_DIR from maxtext.common.common_types import Config +from maxtext.common import profiler from maxtext.integration.tunix.tunix_adapter import TunixMaxTextAdapter from tunix.rl.rollout import base_rollout from tunix.rl.rollout.vllm_rollout import VllmRollout @@ -50,6 +51,13 @@ from maxtext.configs import pyconfig import maxtext.integration.vllm.maxtext_vllm_adapter as adapter +# Force uses_mrope to False to disable 3D multimodal position IDs in text-only runs. +# TODO(b/520142315): Class-level monkey patching is required here because ModelConfig.uses_mrope +# is evaluated during the internal initialization of the LLM engine, making instance-level +# configuration impossible. Check if a cleaner configuration option can be added in upstream vLLM. 
+from vllm.config import ModelConfig + +ModelConfig.uses_mrope = property(lambda _: False) # --- DEFINE FLAGS GLOBALLY --- FLAGS = flags.FLAGS @@ -110,6 +118,11 @@ def decode_with_vllm(config: Config) -> None: vllm_args["additional_config"]["sharding"]["sharding_strategy"]["expert_parallelism"] = config.ici_expert_parallelism vllm_args["enable_expert_parallel"] = enable_expert_parallel + if config.max_num_batched_tokens is not None: + vllm_args["max_num_batched_tokens"] = config.max_num_batched_tokens + if config.max_num_seqs is not None: + vllm_args["max_num_seqs"] = config.max_num_seqs + max_logging.log( f"Initializing LLM with DP={config.ici_data_parallelism}, TP={config.ici_tensor_parallelism} " f"and EP={config.ici_expert_parallelism if enable_expert_parallel else 1}..." @@ -157,7 +170,10 @@ def decode_with_vllm(config: Config) -> None: top_p=top_p, ) + prof = profiler.Profiler(config) + prof.activate() outputs = llm.generate(prompts, sampling_params) + prof.deactivate() # max_logging.log Outputs for output in outputs: diff --git a/src/maxtext/integration/vllm/maxtext_vllm_adapter/__init__.py b/src/maxtext/integration/vllm/maxtext_vllm_adapter/__init__.py index 216737a7fc..ee14be16b8 100644 --- a/src/maxtext/integration/vllm/maxtext_vllm_adapter/__init__.py +++ b/src/maxtext/integration/vllm/maxtext_vllm_adapter/__init__.py @@ -30,4 +30,11 @@ def register(): """ logger.info("Registering MaxTextForCausalLM model with tpu_inference and vllm.") register_model("MaxTextForCausalLM", MaxTextForCausalLM) + + # Dynamically apply KVCacheManager patch when registering the adapter + # pylint: disable=import-outside-toplevel + from .adapter import patch_kv_cache_manager + + patch_kv_cache_manager() + logger.info("Successfully registered MaxTextForCausalLM model.") diff --git a/src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py b/src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py index 217806f887..bff36397af 100644 --- 
a/src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py +++ b/src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py @@ -40,6 +40,11 @@ class AttentionMetadata: from vllm.config import VllmConfig +# Threshold to determine if the ratio of attention to mamba layers is highly imbalanced. +# If max_count / min_count >= this threshold, we group KV cache allocations by the +# smaller count to prevent excessive memory padding for the minority layer type. +_HYBRID_LAYER_IMBALANCE_THRESHOLD = 1.5 + def next_power_of_two(x: int) -> int: """Finds the smallest power of 2 >= x using bit manipulation. @@ -56,7 +61,7 @@ def next_power_of_two(x: int) -> int: return 1 << (x - 1).bit_length() -def generate_maxtext_config(vllm_config: VllmConfig, mesh: Mesh) -> pyconfig.HyperParameters: +def generate_maxtext_config(vllm_config: VllmConfig) -> pyconfig.HyperParameters: """Generates a MaxText configuration from a vLLM configuration. This function takes a vLLM configuration object and translates relevant @@ -67,7 +72,6 @@ def generate_maxtext_config(vllm_config: VllmConfig, mesh: Mesh) -> pyconfig.Hyp Args: vllm_config: The vLLM configuration object containing model and load parameters. - mesh: The JAX mesh device for model sharding. Returns: A `pyconfig.HyperParameters` object configured for MaxText. @@ -178,7 +182,7 @@ def __init__(self, vllm_config: VllmConfig, rng_key: jax.Array, mesh: Mesh): """ self.vllm_config = vllm_config self.cfg = vllm_config.model_config - self.maxtext_config = generate_maxtext_config(vllm_config, mesh) + self.maxtext_config = generate_maxtext_config(vllm_config) # Model configuration self.mesh = mesh @@ -228,6 +232,24 @@ def __call__( if not isinstance(self.model, nnx.Module): raise ValueError("Model must be an instance of type nnx.Module.") + # below, GDN layers don't touch block_tables — they index via + # ``mamba_state_indices`` — and all full-attn layers belong to the same + # kv_cache_group so they share one block_tables. 
Pick a metadata from a + # full-attn (non-linear_attention) layer when possible; otherwise any + # value works. + if isinstance(attention_metadata, dict): + hf_text_config = getattr(self.cfg, "hf_text_config", getattr(self.cfg, "hf_config", None)) + layer_types = getattr(hf_text_config, "layer_types", None) or [] + attention_metadata_picked = None + for i, lt in enumerate(layer_types): + if lt != "linear_attention": + attention_metadata_picked = attention_metadata.get(f"layer.{i}") + if attention_metadata_picked is not None: + break + if attention_metadata_picked is None: + attention_metadata_picked = next(iter(attention_metadata.values())) + attention_metadata = attention_metadata_picked + # Ensure inputs are at least 2D with a batch dimension input_ids = jnp.expand_dims(input_ids, axis=1) input_positions = jnp.expand_dims(attention_metadata.input_positions, axis=1) @@ -324,3 +346,168 @@ def load_weights(self, rng_key: jax.Array) -> None: self.maxtext_config, mesh=self.mesh, model_mode=self.model_mode, rng_key=rng_key ) self.model = nnx.data(model) + + def get_mrope_input_positions( + self, + input_tokens: list[int], + mm_features: list = None, + ) -> tuple[jax.Array, int]: + """Get dummy mrope input positions and delta value for text-only MaxText.""" + seq_len = len(input_tokens) + pos_range = jnp.arange(seq_len, dtype=jnp.int32) + # M-RoPE expects 3D position vectors (3, seq_len) and position_delta (int) + positions = jnp.stack([pos_range, pos_range, pos_range], axis=0) + return positions, 0 + + +# Monkey-patch KVCacheManager.get_kv_cache_spec to support GDN/Mamba layers in Pure JAX path. 
def _hybrid_layer_group_counts(num_attn, num_mamba, imbalance_threshold):
  """Chooses the KV-cache group size for a hybrid attention + mamba layer stack.

  When the two layer counts are balanced (``max < min * imbalance_threshold``)
  the larger count is used as the group size, which minimizes the number of
  groups. When they are imbalanced the smaller count is used so the uniform
  page is not inflated by padding for the minority layer type.

  Args:
    num_attn: Number of full-attention layers.
    num_mamba: Number of GDN/mamba layers.
    imbalance_threshold: Ratio ``max / min`` at or above which the stack is
      treated as imbalanced.

  Returns:
    A ``(group_size, num_attn_groups, num_mamba_groups)`` tuple of ints.
  """
  min_count = min(num_attn, num_mamba)
  max_count = max(num_attn, num_mamba)
  # A missing layer type (min_count == 0) cannot be used as a group size: it
  # would make the ceil-divisions below divide by zero.
  if min_count == 0 or max_count < min_count * imbalance_threshold:
    group_size = max_count
  else:
    group_size = min_count
  group_size = max(group_size, 1)  # Degenerate zero-layer stack.
  num_attn_groups = (num_attn + group_size - 1) // group_size
  num_mamba_groups = (num_mamba + group_size - 1) // group_size
  return group_size, num_attn_groups, num_mamba_groups


def patch_kv_cache_manager():
  """Monkey-patches KVCacheManager to support hybrid Attention + GDN/Mamba models.

  Replaces ``KVCacheManager.get_kv_cache_spec`` with a wrapper that, for MaxText
  models whose ``decoder_block`` is ``qwen3_next`` or ``qwen3_5``:

    * computes a uniform page size covering both attention and mamba groups and
      stores it on the manager and on ``runner.cache_config``;
    * calls the original ``get_kv_cache_spec``;
    * swaps the spec of every non-attention layer (``(i + 1) % interval != 0``)
      for a ``MambaSpec`` describing the conv and recurrent state.

  For any other runner/model the wrapper defers to the original method.

  The patch is idempotent: calling this function again after a successful patch
  is a no-op, so repeated adapter registration does not stack wrappers.

  Raises:
    RuntimeError: If ``tpu_inference`` is importable but
      ``KVCacheManager.get_kv_cache_spec`` does not exist.
  """
  # pylint: disable=import-outside-toplevel,protected-access
  import functools

  try:
    from tpu_inference.runner.kv_cache_manager import KVCacheManager
    from vllm.v1.kv_cache_interface import MambaSpec
    import torch
    import numpy as np
  except ImportError as e:
    # Gracefully handle missing imports in standard JAX environments (e.g. unit tests on CPU)
    max_logging.log(f"Skipping KVCacheManager patch (tpu_inference or dependencies not installed): {e}")
    return

  try:
    original_get_kv_cache_spec = KVCacheManager.get_kv_cache_spec
  except AttributeError as e:
    # Raise a clear error if packages exist but patch target is missing (indicating API change or mismatch)
    raise RuntimeError(
        "Failed to apply KVCacheManager patch: KVCacheManager.get_kv_cache_spec not found. "
        "This usually indicates a vLLM / tpu-inference API change or version mismatch."
    ) from e

  # Do not wrap our own wrapper: a second layer would redo the page-size
  # computation and its side effects on every call.
  if getattr(original_get_kv_cache_spec, "_maxtext_hybrid_gdn_patch", False):
    max_logging.log("KVCacheManager patch for hybrid GDN models already applied; skipping.")
    return

  hybrid_decoder_blocks = ("qwen3_next", "qwen3_5")

  @functools.wraps(original_get_kv_cache_spec)
  def patched_get_kv_cache_spec(self):
    runner = self.runner
    cfg = getattr(getattr(runner, "model", None), "maxtext_config", None)
    if cfg is None:
      # Not a MaxText-backed runner: nothing to adjust.
      return original_get_kv_cache_spec(self)

    # ``decoder_block`` may be a plain string or an enum exposing ``.value``.
    decoder_block = getattr(cfg, "decoder_block", "")
    if not isinstance(decoder_block, str):
      decoder_block = getattr(decoder_block, "value", "")
    if decoder_block not in hybrid_decoder_blocks:
      return original_get_kv_cache_spec(self)

    interval = cfg.inhomogeneous_layer_cycle_interval
    if interval < 1:
      raise ValueError(f"inhomogeneous_layer_cycle_interval must be >= 1 for hybrid GDN models, got {interval}.")

    # ---- Mamba (GDN) state shapes and dtypes ------------------------------
    num_v_heads = cfg.gdn_num_value_heads
    num_k_heads = cfg.gdn_num_key_heads
    head_k_dim = cfg.gdn_key_head_dim
    head_v_dim = cfg.gdn_value_head_dim
    conv_kernel_size = cfg.gdn_conv_kernel_dim

    key_dim = head_k_dim * num_k_heads
    value_dim = head_v_dim * num_v_heads
    conv_dim = key_dim * 2 + value_dim  # q + k + v channels share the conv.

    conv_state_shape = (conv_kernel_size - 1, conv_dim)
    recurrent_state_shape = (num_v_heads, head_k_dim, head_v_dim)
    mamba_shapes = (conv_state_shape, recurrent_state_shape)

    # Any dtype other than float32/float16 falls back to bfloat16.
    torch_dtype = {"float32": torch.float32, "float16": torch.float16}.get(str(cfg.dtype), torch.bfloat16)
    mamba_dtypes = (torch_dtype, torch_dtype)

    # Unpadded mamba page size in bytes (conv state + recurrent state).
    dtype_size = torch.finfo(torch_dtype).bits // 8
    unpadded_mamba_page_size = sum(int(np.prod(shape)) * dtype_size for shape in mamba_shapes)

    # ---- Attention page size ----------------------------------------------
    from tpu_inference.layers.common.sharding import ShardingAxisName
    from tpu_inference import utils as common_utils
    from tpu_inference.runner.kv_cache import get_attention_page_size_bytes

    model_cnt = common_utils.get_mesh_shape_product(runner.mesh, ShardingAxisName.ATTN_HEAD)

    model_config = runner.model_config
    text_config = getattr(model_config, "hf_text_config", getattr(model_config, "hf_config", None))
    # Prefer the "global" head settings from the HF config when present and
    # truthy; otherwise use the values reported by the vLLM model config.
    num_kv_heads = getattr(text_config, "num_global_key_value_heads", None) or model_config.get_total_num_kv_heads()
    head_size = getattr(text_config, "global_head_dim", None) or model_config.get_head_size()

    num_kv_heads = common_utils.get_padded_num_heads(num_kv_heads, model_cnt)
    head_size = common_utils.get_padded_head_dim(head_size)

    attn_page_size_bytes = get_attention_page_size_bytes(
        runner.mesh, runner.cache_config.block_size, num_kv_heads, head_size, runner.kv_cache_dtype, False
    )

    # ---- Layer grouping ---------------------------------------------------
    # Every ``interval``-th layer is full attention; the rest are GDN/mamba.
    num_layers = cfg.base_num_decoder_layers
    num_attn = num_layers // interval
    num_mamba = num_layers - num_attn

    group_size, num_attn_groups, num_mamba_groups = _hybrid_layer_group_counts(
        num_attn, num_mamba, _HYBRID_LAYER_IMBALANCE_THRESHOLD
    )

    uniform_page_size_bytes = int(num_attn_groups * attn_page_size_bytes + num_mamba_groups * unpadded_mamba_page_size)

    # Set the padded page size on manager and config
    self._hybrid_uniform_page_size_bytes = uniform_page_size_bytes
    runner.cache_config.mamba_page_size_padded = uniform_page_size_bytes

    # NOTE(review): private KVCacheManager hook; its semantics are not visible
    # here — confirm it exists in the pinned tpu-inference version.
    self._maybe_set_compact_mamba_num_blocks_override(
        attn_page_size_bytes,
        int(unpadded_mamba_page_size),
        num_attn_groups,
        num_mamba_groups,
        num_attn,
        num_mamba,
        group_size,
    )

    kv_cache_spec = original_get_kv_cache_spec(self)

    # Replace the spec of each GDN layer with a MambaSpec; attention layers
    # ((i + 1) % interval == 0) keep the spec produced by the original method.
    for i in range(num_layers):
      if (i + 1) % interval == 0:
        continue
      layer_name = f"layer.{i}"
      if layer_name in kv_cache_spec:
        kv_cache_spec[layer_name] = MambaSpec(
            block_size=kv_cache_spec[layer_name].block_size,
            shapes=mamba_shapes,
            dtypes=mamba_dtypes,
            page_size_padded=self._hybrid_uniform_page_size_bytes,
        )

    return kv_cache_spec

  # Marker consulted by the idempotency check above.
  patched_get_kv_cache_spec._maxtext_hybrid_gdn_patch = True
  KVCacheManager.get_kv_cache_spec = patched_get_kv_cache_spec
  max_logging.log("Successfully applied KVCacheManager patch for hybrid GDN models.")
a/src/maxtext/layers/decoders.py +++ b/src/maxtext/layers/decoders.py @@ -1137,6 +1137,10 @@ def __call__( layer_kwargs = {"layer_idx": lyr} kv_cache = None if kv_caches is not None: + # For all decoder blocks (including QWEN3_NEXT/QWEN3_5 with vLLM flat-list + # kv_caches), pass the per-layer cache directly. For hybrid attention+GDN + # models, kv_caches[lyr] is a regular attention cache for attention layers + # and a (conv_state, recurrent_state) paged-mamba tuple for GDN layers. kv_cache = kv_caches[lyr] if cfg.decoder_block == DecoderBlockType.GPT_OSS: diff --git a/src/maxtext/models/qwen3.py b/src/maxtext/models/qwen3.py index e527b59b24..647ffb841e 100644 --- a/src/maxtext/models/qwen3.py +++ b/src/maxtext/models/qwen3.py @@ -16,8 +16,10 @@ # pylint: disable=arguments-differ # pylint: disable=no-name-in-module +import functools from typing import Any, cast import math +import os import jax import jax.nn @@ -30,6 +32,8 @@ from flax import nnx from maxtext.common.common_types import AttentionType, Config, DType, Array, BATCH, EMBED, MODEL_MODE_TRAIN, LENGTH, MODEL_MODE_AUTOREGRESSIVE +from maxtext.common.common_types import KV_BATCH, KV_HEAD +from maxtext.utils.sharding import logical_to_mesh_axes from maxtext.layers import attentions from maxtext.layers import initializers as max_initializers from maxtext.layers import moe @@ -449,7 +453,8 @@ class Qwen3NextGatedDeltaNet(nnx.Module): def __init__( self, config: Config, - inputs_shape: tuple, + inputs_shape: tuple | None = None, + mesh=None, dtype: DType = jnp.float32, model_mode: str = MODEL_MODE_TRAIN, *, @@ -458,9 +463,13 @@ def __init__( """ Args: config: MaxText configuration object. + mesh: Optional JAX device mesh (required for vLLM paged-state path). rngs: The random number generators for initialization, passed by the nnx.to_linen wrapper. 
""" self.config = config + self.mesh = mesh + + self._gdn_replicate_expert = os.environ.get("MAXTEXT_GDN_REPLICATE_EXPERT", "False").lower() == "true" cfg = self.config in_features = cfg.emb_dim @@ -474,7 +483,7 @@ def __init__( conv_kernel_size = cfg.gdn_conv_kernel_dim self.v_heads_per_k_head = self.num_v_heads // self.num_k_heads - if model_mode != MODEL_MODE_TRAIN: + if model_mode != MODEL_MODE_TRAIN and inputs_shape is not None: runtime_batch_size = inputs_shape[0] self.cache = kvcache.KVCache( @@ -495,7 +504,7 @@ def __init__( rngs=rngs, ) else: - self.cache = None + self.cache = None # No cache for train mode or when inputs_shape not provided # Submodule instantiations self.in_proj_qkvz = DenseGeneral( @@ -503,7 +512,7 @@ def __init__( out_features_shape=(self.key_dim * 2 + self.value_dim * 2), dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, - kernel_axes=("embed", "mlp"), + kernel_axes=("embed", "gdn_head"), matmul_precision=cfg.matmul_precision, rngs=rngs, ) @@ -512,7 +521,7 @@ def __init__( out_features_shape=(self.num_v_heads * 2), dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, - kernel_axes=("embed", "mlp"), + kernel_axes=("embed", "gdn_head"), matmul_precision=cfg.matmul_precision, rngs=rngs, ) @@ -551,7 +560,7 @@ def a_log_init(key, shape, dtype=jnp.float32): out_features_shape=(in_features,), dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, - kernel_axes=("mlp", "embed"), + kernel_axes=("gdn_head", "embed"), matmul_precision=cfg.matmul_precision, rngs=rngs, ) @@ -562,6 +571,7 @@ def __call__( model_mode: str = MODEL_MODE_TRAIN, kv_cache=None, decoder_segment_ids: None | Array = None, + attention_metadata=None, **kwargs, ) -> tuple[Array, Any | None]: # hidden_states: (B, S, E) @@ -570,6 +580,17 @@ def __call__( active_cache = kv_cache if kv_cache is not None else self.cache + # When kv_cache is a 2-tuple of paged mamba state arrays from vLLM, use + # run_jax_gdn_attention from tpu_inference for correct sequential token processing. 
+ use_paged_state = ( + kv_cache is not None + and isinstance(kv_cache, tuple) + and len(kv_cache) == 2 + and attention_metadata is not None + and getattr(attention_metadata, "mamba_state_indices", None) is not None + and self.mesh is not None + ) + # ========================================================================= # STEP A: Input Projections # ========================================================================= @@ -578,7 +599,9 @@ def __call__( # ba: (B, S, 2 * H_v) ba = self.in_proj_ba(hidden_states) - # QKVZ Reshaping and Splitting + # ========================================================================= + # QKVZ and BA Reshaping and Splitting (shared by both paths) + # ========================================================================= # Per-K_head group dim: 2 * D_k + 2 * D_v * V_per_K new_shape_qkvz = ( batch, @@ -588,6 +611,11 @@ def __call__( ) # mixed_qkvz: (B, S, H_k, 2*D_k + 2*D_v*V_per_K) mixed_qkvz = qkvz.reshape(new_shape_qkvz) + if self.mesh is not None: + logical_rules = None if self.config.using_pipeline_parallelism else self.config.logical_axis_rules + qkvz_pspec = logical_to_mesh_axes((KV_BATCH, None, KV_HEAD, None), mesh=self.mesh, rules=logical_rules) + qkvz_sharding = jax.sharding.NamedSharding(self.mesh, qkvz_pspec) + mixed_qkvz = jax.lax.with_sharding_constraint(mixed_qkvz, qkvz_sharding) split_indices_qkvz = [ self.head_k_dim, # D_k @@ -625,6 +653,92 @@ def __call__( # a: (B, S, H_v) a = a_raw.reshape(batch, seq_len, self.num_v_heads) + if use_paged_state: + # ========================================================================= + # vLLM PAGED STATE PATH: use tpu_inference fused conv + ragged delta-rule. 
+ # ========================================================================= + try: + # pylint: disable=import-outside-toplevel + # pytype: disable=import-error + from tpu_inference.layers.common.gdn_attention import GdnAttentionConfig, run_jax_gdn_attention # pylint: disable=import-outside-toplevel + from tpu_inference.layers.common.ragged_gated_delta_rule_wrapper import RaggedGatedDeltaRuleImpl # pylint: disable=import-outside-toplevel + from tpu_inference.layers.common.sharding import ShardingAxisName # pylint: disable=import-outside-toplevel + from tpu_inference.layers.common.utils import reorder_concatenated_tensor_for_sharding # pylint: disable=import-outside-toplevel + from tpu_inference.utils import get_mesh_shape_product # pylint: disable=import-outside-toplevel + from jax.sharding import PartitionSpec as P_spec # pylint: disable=import-outside-toplevel + except ImportError as e: + raise ImportError( + "GDN attention kernel require the vllm-tpu package. Please install it with `pip install vllm-tpu`." + ) from e + + attn_data = ShardingAxisName.ATTN_DATA + # Head axis for the GDN kernel + the producer-side reshapes. Default ATTN_HEAD + # (model*expert); the experimental MAXTEXT_GDN_REPLICATE_EXPERT path uses 'model' only + # so GDN replicates over the expert axis (no expert-axis transpose all-to-all). + attn_head = ShardingAxisName.MODEL if self._gdn_replicate_expert else ShardingAxisName.ATTN_HEAD + tp_size = get_mesh_shape_product(self.mesh, attn_head) + num_tokens = batch * seq_len + + # Build mixed_qkv in the kernel's per-shard layout via shard_map concatenation. + # Each TP shard already holds its local q/k/v head slices → concatenate locally + # to get [q_local | k_local | v_local] with no cross-device communication. 
+ q_flat = query.reshape(num_tokens, self.key_dim) # (T, key_dim) sharded on ATTN_HEAD + k_flat = key.reshape(num_tokens, self.key_dim) + v_flat = value_raw.reshape(num_tokens, self.value_dim) # (T, value_dim) sharded on ATTN_HEAD + mixed_qkv = jax.shard_map( + lambda q, k, v: jnp.concatenate([q, k, v], axis=-1), + mesh=self.mesh, + in_specs=(P_spec(attn_data, attn_head),) * 3, + out_specs=P_spec(attn_data, attn_head), + check_vma=False, + )(q_flat, k_flat, v_flat) + + b_flat = b.reshape(num_tokens, self.num_v_heads) + a_flat = a.reshape(num_tokens, self.num_v_heads) + + # Conv weight: transpose from (kernel_size, 1, conv_dim) → (conv_dim, 1, kernel_size), + # then reorder so each TP shard gets its local [q_local | k_local | v_local] channels. + conv_weight = jnp.transpose(self.conv1d.kernel.value, (2, 1, 0)) + conv_weight = reorder_concatenated_tensor_for_sharding( + conv_weight, [self.key_dim, self.key_dim, self.value_dim], tp_size, 0 + ) + + conv_state_paged, recurrent_state_paged = kv_cache + + # Use REF impl (pure JAX) to avoid Mosaic kernel compilation issues. + gdn_config = GdnAttentionConfig(ragged_gated_delta_rule_impl=RaggedGatedDeltaRuleImpl.REF) + + (new_conv_state_paged, new_recurrent_state_paged), gdn_output = run_jax_gdn_attention( + mixed_qkv, + b_flat, + a_flat, + conv_state_paged, + recurrent_state_paged, + conv_weight, + None, # conv_bias: MaxText conv1d uses use_bias=False. + jnp.asarray(self.A_log[...], dtype=cfg.dtype), + jnp.asarray(self.dt_bias[...], dtype=cfg.dtype), + attention_metadata.mamba_state_indices.astype(jnp.int32), + attention_metadata.query_start_loc, + attention_metadata.request_distribution, + attention_metadata.seq_lens, + self.num_k_heads, + self.num_v_heads, + self.head_k_dim, + self.head_v_dim, + cfg.gdn_conv_kernel_dim, + mesh=self.mesh, + config=gdn_config, + ) + + # Reshape GDN output and apply gated norm + out projection. 
+ gdn_output = gdn_output.reshape(batch, seq_len, self.num_v_heads, self.head_v_dim) + gated_output = self.norm(gdn_output, z) + gated_output = gated_output.reshape(batch, seq_len, -1) + output = self.out_proj(gated_output) + + return output, (new_conv_state_paged, new_recurrent_state_paged) + # Flatten head dimensions for concatenation before conv # q: (B, S, K_dim) q = query.reshape(batch, seq_len, -1) @@ -732,6 +846,48 @@ def extract_state(c_in, v_len): use_qk_norm_in_gdn=cfg.use_qk_norm_in_gdn, compute_dtype=cfg.dtype, ) + elif self.mesh is not None: + logical_rules = self.config.logical_axis_rules + recurrent_state_arg = ( + recurrent_state + if recurrent_state is not None + else jnp.zeros((batch, self.num_v_heads, self.head_k_dim, self.head_v_dim), dtype=cfg.dtype) + ) + qkv_pspec = logical_to_mesh_axes((KV_BATCH, None, KV_HEAD, None), mesh=self.mesh, rules=logical_rules) + g_beta_pspec = logical_to_mesh_axes((KV_BATCH, None, KV_HEAD), mesh=self.mesh, rules=logical_rules) + state_pspec = logical_to_mesh_axes((KV_BATCH, KV_HEAD, None, None), mesh=self.mesh, rules=logical_rules) + + @functools.partial( + jax.shard_map, + mesh=self.mesh, + in_specs=( + qkv_pspec, # query + qkv_pspec, # key + qkv_pspec, # value + g_beta_pspec, # g + g_beta_pspec, # beta + state_pspec, # initial_state + ), + out_specs=( + qkv_pspec, # core_attn_out + state_pspec, # final_state + ), + check_vma=False, + ) + def shard_mapped_delta_rule(q, k, v, g_val, beta_val, init_h): + return jax_chunk_gated_delta_rule( + query=q, + key=k, + value=v, + g=g_val, + beta=beta_val, + chunk_size=cfg.gdn_chunk_size, + initial_state=init_h, + use_qk_norm_in_gdn=cfg.use_qk_norm_in_gdn, + compute_dtype=cfg.dtype, + ) + + core_attn_out, next_recurrent_state = shard_mapped_delta_rule(query, key, value, g, beta, recurrent_state_arg) else: core_attn_out, next_recurrent_state = jax_chunk_gated_delta_rule( query, @@ -1125,7 +1281,7 @@ def __init__( batch_size, seq_len = 
max_utils.get_batch_seq_len_for_mode(config, model_mode) dummy_inputs_shape = (batch_size, seq_len, config.emb_dim) self.attention = Qwen3NextGatedDeltaNet( - config=cfg, inputs_shape=dummy_inputs_shape, dtype=cfg.dtype, model_mode=model_mode, rngs=rngs + config=cfg, inputs_shape=dummy_inputs_shape, mesh=self.mesh, dtype=cfg.dtype, model_mode=model_mode, rngs=rngs ) # Second LayerNorm, applied before the MoE block. @@ -1178,6 +1334,7 @@ def __call__( model_mode=model_mode, kv_cache=kv_cache, decoder_segment_ids=decoder_segment_ids, + attention_metadata=attention_metadata, ) # First residual connection after attention diff --git a/src/maxtext/models/qwen3_5.py b/src/maxtext/models/qwen3_5.py index 759fd180cd..ba6b0ab4c3 100644 --- a/src/maxtext/models/qwen3_5.py +++ b/src/maxtext/models/qwen3_5.py @@ -159,7 +159,7 @@ def __init__( batch_size, seq_len = max_utils.get_batch_seq_len_for_mode(config, model_mode) dummy_inputs_shape = (batch_size, seq_len, config.emb_dim) self.attention = Qwen3_5GatedDeltaNet( - config=cfg, inputs_shape=dummy_inputs_shape, dtype=cfg.dtype, model_mode=model_mode, rngs=rngs + config=cfg, inputs_shape=dummy_inputs_shape, mesh=self.mesh, dtype=cfg.dtype, model_mode=model_mode, rngs=rngs ) # Second LayerNorm, applied before the MoE block. @@ -212,6 +212,7 @@ def __call__( model_mode=model_mode, kv_cache=kv_cache, decoder_segment_ids=decoder_segment_ids, + attention_metadata=attention_metadata, ) # First residual connection after attention diff --git a/src/maxtext/utils/model_creation_utils.py b/src/maxtext/utils/model_creation_utils.py index 0e9519c343..bb53b95d6d 100644 --- a/src/maxtext/utils/model_creation_utils.py +++ b/src/maxtext/utils/model_creation_utils.py @@ -442,10 +442,7 @@ def _fix_one(path, restore_arg): if rank_mismatched_paths: sample = "\n".join(rank_mismatched_paths[:5]) more = f"\n ... 
and {len(rank_mismatched_paths) - 5} more" if len(rank_mismatched_paths) > 5 else "" - raise ValueError( - f"Checkpoint rank mismatches detected ({len(rank_mismatched_paths)}" - f" arrays):\n{sample}{more}" - ) + raise ValueError(f"Checkpoint rank mismatches detected ({len(rank_mismatched_paths)}" f" arrays):\n{sample}{more}") # Detect structural mismatch (e.g. scanned checkpoint loaded into unscanned model). # In that case the checkpoint tree has "layers" (all layers stacked) but the model @@ -892,9 +889,7 @@ def from_pretrained( with mesh: if config.load_parameters_path: - with handle_checkpoint_mismatch( - "load parameters", config.load_parameters_path - ): + with handle_checkpoint_mismatch("load parameters", config.load_parameters_path): ckptr = ocp.Checkpointer( ocp.PyTreeCheckpointHandler( restore_concurrent_gb=config.checkpoint_storage_concurrent_gb, @@ -1039,15 +1034,18 @@ def _build_value_target(v): restore_args = {"base": restore_args} if has_base_key else restore_args # Free memory used by initial sharded_state before restore, to make room for the incoming checkpoint arrays. + # Skip nnx.Cache variables — they hold runtime state (e.g. GDN conv/recurrent state) that is + # not present in the checkpoint and must remain valid after the restore. def _free_device_memory(node): - val = node - if isinstance(node, nnx.Variable) and not isinstance(node, nnx.RngState): + if isinstance(node, nnx.Variable) and not isinstance(node, (nnx.RngState, nnx.Cache)): inner = node.get_value() if hasattr(node, "get_value") else node[...] # Same QTensor caveat as `_build_value_target`: AQT serve-mode `qrhs.frozen` # wraps a QTensor whose `__getitem__` fails on `LogicallyPartitioned`. # We only need to free a single jax.Array leaf — for composite values # there's nothing to free at this level, so skip. 
val = inner if hasattr(inner, "shape") else None + else: + val = node if isinstance(val, jax.Array) and not val.is_deleted(): val.delete() diff --git a/tests/unit/attention_test.py b/tests/unit/attention_test.py index 404e089210..3fa8391833 100644 --- a/tests/unit/attention_test.py +++ b/tests/unit/attention_test.py @@ -1925,6 +1925,8 @@ def setUp(self): [sys.argv[0], get_test_config_path()], **self.config_arguments, ) + devices_array = maxtext_utils.create_device_mesh(self.cfg) + self.mesh = Mesh(devices_array, self.cfg.mesh_axes) self.rng = jax.random.PRNGKey(0) self.nnx_rng = nnx.Rngs(params=0, dropout=jax.random.PRNGKey(42)) @@ -1950,6 +1952,7 @@ def test_autoregression(self): gdn = Qwen3NextGatedDeltaNet( config=cfg, inputs_shape=lnx.shape, + mesh=self.mesh, dtype=cfg.dtype, model_mode=MODEL_MODE_PREFILL, rngs=self.nnx_rng, diff --git a/tests/unit/qwen3_next_vs_reference_test.py b/tests/unit/qwen3_next_vs_reference_test.py index f22cf72a17..7df214f57e 100644 --- a/tests/unit/qwen3_next_vs_reference_test.py +++ b/tests/unit/qwen3_next_vs_reference_test.py @@ -904,12 +904,13 @@ def test_gated_delta_net_structure(self): print("Running test_gated_delta_net_structure...") hidden_states_jax = jnp.ones((self.batch_size, self.seq_len, self.hidden_size), dtype=self.cfg.dtype) - jax_model = qwen3.Qwen3NextGatedDeltaNet(config=self.cfg, rngs=self.nnx_rngs) + jax_model = qwen3.Qwen3NextGatedDeltaNet(config=self.cfg, mesh=self.mesh, rngs=self.nnx_rngs) @jax.jit def run_jax(hidden_states): """Runs the JAX GatedDeltaNet model.""" - return jax_model(hidden_states) + output, _ = jax_model(hidden_states) + return output output_jax = run_jax(hidden_states_jax) @@ -932,22 +933,22 @@ def test_qwen3_next_rms_norm(self): expected_output = pt_model(hidden_states_pt) # 2. Set up the JAX implementation. 
- jax_model = Qwen3NextRMSNorm( - num_features=self.hidden_size, - eps=self.cfg.normalization_layer_epsilon, - dtype=jnp.float32, - weight_dtype=jnp.float32, - rngs=self.nnx_rngs, - ) + class DummyModule(nnx.Module): + + def __init__(self, hidden_size, eps, rngs): + self.norm = Qwen3NextRMSNorm(hidden_size, eps=eps, rngs=rngs) + + jax_model_wrapped = DummyModule(self.hidden_size, self.cfg.normalization_layer_epsilon, self.nnx_rngs) + jax_model = jax_model_wrapped.norm params = {"scale": nnx.Param(jnp.array(weight_pt.numpy()))} - nnx.update(jax_model.value, params) + nnx.update(jax_model, params) hidden_states_jax = jnp.array(hidden_states_pt.numpy()) @jax.jit def run_jax(x): """Runs the JAX Qwen3NextRMSNorm model.""" - return jax_model.value(x) # Call the module inside DataAttr + return jax_model(x) # Call the module actual_output = run_jax(hidden_states_jax) @@ -1047,17 +1048,50 @@ def test_gated_delta_net_full(self): with torch.no_grad(): expected_output = pt_model(hidden_states_pt) + def reorder_pt_qkvz_to_jax(w, num_heads, head_k_dim, head_v_dim): + key_dim = num_heads * head_k_dim + value_dim = num_heads * head_v_dim + q, k, v, z = np.split(w, [key_dim, 2 * key_dim, 2 * key_dim + value_dim], axis=0) + jax_heads = [] + for i in range(num_heads): + head_i = np.concatenate( + [ + q[i * head_k_dim : (i + 1) * head_k_dim], + k[i * head_k_dim : (i + 1) * head_k_dim], + v[i * head_v_dim : (i + 1) * head_v_dim], + z[i * head_v_dim : (i + 1) * head_v_dim], + ], + axis=0, + ) + jax_heads.append(head_i) + return np.concatenate(jax_heads, axis=0) + + def reorder_pt_ba_to_jax(w, num_heads): + b, a = np.split(w, 2, axis=0) + jax_heads = [] + for i in range(num_heads): + head_i = np.concatenate([b[i : i + 1], a[i : i + 1]], axis=0) + jax_heads.append(head_i) + return np.concatenate(jax_heads, axis=0) + # 2. 
Setup JAX model and map weights - jax_model = qwen3.Qwen3NextGatedDeltaNet(config=self.cfg, rngs=self.nnx_rngs) + jax_model = qwen3.Qwen3NextGatedDeltaNet(config=self.cfg, mesh=self.mesh, rngs=self.nnx_rngs) + assert jax_model.num_k_heads == jax_model.num_v_heads conv1d_weight_pt = pt_model.conv1d.weight.detach().numpy() # Transpose PT (out, in/groups, kw) -> JAX (kw, in/groups, out) # For depthwise, out=in=groups, so PT=(C, 1, kw) -> JAX=(kw, 1, C) conv1d_weight_jax = np.transpose(conv1d_weight_pt, (2, 1, 0)) + w_qkvz_pt = pt_model.in_proj_qkvz.weight.detach().numpy() + w_qkvz_jax = reorder_pt_qkvz_to_jax(w_qkvz_pt, jax_model.num_v_heads, jax_model.head_k_dim, jax_model.head_v_dim) + + w_ba_pt = pt_model.in_proj_ba.weight.detach().numpy() + w_ba_jax = reorder_pt_ba_to_jax(w_ba_pt, jax_model.num_v_heads) + params = { - "in_proj_qkvz": {"kernel": nnx.Param(jnp.array(pt_model.in_proj_qkvz.weight.T.detach().numpy()))}, - "in_proj_ba": {"kernel": nnx.Param(jnp.array(pt_model.in_proj_ba.weight.T.detach().numpy()))}, + "in_proj_qkvz": {"kernel": nnx.Param(jnp.array(w_qkvz_jax.T))}, + "in_proj_ba": {"kernel": nnx.Param(jnp.array(w_ba_jax.T))}, "conv1d": {"kernel": nnx.Param(jnp.array(conv1d_weight_jax))}, "A_log": nnx.Param(jnp.array(pt_model.A_log.detach().numpy())), "dt_bias": nnx.Param(jnp.array(pt_model.dt_bias.detach().numpy())), @@ -1070,7 +1104,8 @@ def test_gated_delta_net_full(self): @jax.jit def run_jax(x): """Runs the JAX GatedDeltaNet model.""" - return jax_model(x) + output, _ = jax_model(x) + return output actual_output = run_jax(hidden_states_jax) @@ -1242,13 +1277,14 @@ def _run_full_attention_jax_vs_pytorch_attention(self, attention_type): # 8. 
Run JAX Model @jax.jit def run_jax(inputs, segment_ids, positions): - return jax_model( + output, _ = jax_model( inputs, decoder_segment_ids=segment_ids, decoder_positions=positions, deterministic=True, model_mode="train", ) + return output jax_output = run_jax(hidden_states_jax, decoder_segment_ids_jax, decoder_positions_jax)