From 50f9662f347e91891d55890d987c96b8806fe101 Mon Sep 17 00:00:00 2001 From: Mohit Khatwani Date: Fri, 27 Feb 2026 20:34:10 +0000 Subject: [PATCH] deepseek sharding and mla attention plumbing --- src/maxtext/configs/base.yml | 3 + .../configs/inference/vllm_deepseek.yml | 69 +++++++++ src/maxtext/configs/post_train/rl.yml | 2 + src/maxtext/configs/types.py | 3 + src/maxtext/inference/vllm_decode.py | 36 ++--- .../vllm/maxtext_vllm_adapter/adapter.py | 15 +- src/maxtext/layers/attention_mla.py | 137 +++++++++++++++++- src/maxtext/layers/moe.py | 20 +-- src/maxtext/models/deepseek.py | 26 +++- .../trainers/post_train/rl/train_rl.py | 4 +- 10 files changed, 271 insertions(+), 44 deletions(-) create mode 100644 src/maxtext/configs/inference/vllm_deepseek.yml diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index 398df849fe..83b20eeda6 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -471,6 +471,7 @@ logical_axis_rules: [ ['decode_length', ['sequence']], ['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']], ['mlp_no_fsdp', ['tensor', 'tensor_sequence', 'autoregressive']], + ['moe_mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']], ['vocab', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['q_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], @@ -1119,6 +1120,8 @@ vllm_hf_config_path: "" # A JSON string of overrides to apply to the HuggingFace-style config for the vLLM adapter. # This can be used to override specific settings without modifying the original config file. vllm_hf_overrides: {} +# Path to yaml file for loading vLLM config +vllm_config_path: "" # JSON string containing additional configuration for the vLLM model (e.g. '{"maxtext_config": {...}}') vllm_additional_config: {} # When use_jax_splash=True, force the layout of the query tensor to be [..., NUM_HEADS, HEAD_DIM, SEQ_LENGTH] diff --git a/src/maxtext/configs/inference/vllm_deepseek.yml b/src/maxtext/configs/inference/vllm_deepseek.yml new file mode 100644 index 0000000000..a584d9469a --- /dev/null +++ b/src/maxtext/configs/inference/vllm_deepseek.yml @@ -0,0 +1,69 @@ +# Copyright 2023–2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+
+base_config: "vllm.yml"
+
+logical_axis_rules: [
+  ['activation_batch', []],
+  ['activation_batch_no_exp', []],
+  ['activation_embed_and_logits_batch', ['expert']],
+  ['activation_embed_and_logits_batch_sequence', ['expert']],
+  ['activation_heads', ['model']],
+  ['activation_kv_heads', ['model']],
+  ['activation_attn_length', ['expert']],
+  ['activation_attn_length_no_exp', []],
+  ['activation_length', ['data', 'expert']],
+  ['activation_length_no_exp', 'data'],
+  ['activation_q_length', ['expert']],
+  ['activation_attn_embed', 'model'],
+  ['activation_embed', ['model', 'attn_dp']],
+  ['activation_mlp', ['model', 'attn_dp', 'expert']],
+  ['activation_kv', ['model']],
+  ['activation_prefill_kv_batch', ['expert']],
+  ['activation_kv_batch', []],
+  ['activation_kv_batch_no_exp', []],
+  ['activation_kv_head_dim', ['model', 'attn_dp', 'expert']],
+  ['activation_vocab', ['model', 'attn_dp']],
+  ['activation_norm_length', []],
+  ['activation_exp', ['expert']],
+  ['decode_batch', ['expert']],
+  ['decode_length', []],
+  ['mlp_no_fsdp', ['model', 'attn_dp', 'expert']],
+  ['vocab', ['model', 'attn_dp', 'expert']],
+  ['heads', ['expert', 'attn_dp', 'model']],
+  ['q_heads', []],
+  ['kv_heads', []],
+  ['kv_head_dim', ['model', 'attn_dp', 'expert']],
+  ['kv', ['model', 'attn_dp', 'expert']],
+  ['kv', []],
+  ['embed', []],
+  ['mlp', ['model', 'attn_dp', 'expert']],
+  ['moe_mlp', []],
+  ['embed_tensor_transpose', ['attn_dp', 'model']],
+  ['embed_no_exp', []],
+  ['q_lora', []],
+  ['kv_lora', []],
+  ['norm', []],
+  ['cache_heads', ['model']],
+  ['exp', ['expert', 'attn_dp', 'model']],
+  ['paged_kv_heads', ['model']],
+  ['cache_batch_prefill', []],
+  ['cache_batch', []],
+  ['cache_sequence', []],
+  ['cache_heads_none', []],
+  ['cache_kv', []],
+  ['kv_lora_up_proj', ['expert', 'attn_dp', 'model']],
+  ['q_lora_up_proj', ['expert', 'attn_dp', 'model']],
+  ]
\ No newline at end of file
diff --git a/src/maxtext/configs/post_train/rl.yml b/src/maxtext/configs/post_train/rl.yml
index da455a13e2..5e2153d2c3 100644
--- a/src/maxtext/configs/post_train/rl.yml
+++ b/src/maxtext/configs/post_train/rl.yml
@@ -155,6 +155,8 @@ max_num_seqs: null
 async_scheduling: True
 # stop generation when any of these strings is generated
 stop_strings: null
+# Path to the YAML file used to initialize the vLLM config
+vllm_config_path: 'src/maxtext/configs/inference/vllm.yml'
 
 # ====== Checkpoint Configuration ======
 enable_checkpointing: True
diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py
index 3df51ac106..de5d30dd5d 100644
--- a/src/maxtext/configs/types.py
+++ b/src/maxtext/configs/types.py
@@ -1621,6 +1621,9 @@ class VLLM(BaseModel):
       description="Overrides for HuggingFace model config for MaxText model.",
   )
   vllm_hf_config_path: str = Field("", description="Path to HuggingFace model config for MaxText model.")
+  vllm_config_path: str = Field(
+      "src/maxtext/configs/inference/vllm.yml", description="Path to YAML file for loading vLLM config."
+ ) class RL(BaseModel): diff --git a/src/maxtext/inference/vllm_decode.py b/src/maxtext/inference/vllm_decode.py index f7df999547..620254b173 100644 --- a/src/maxtext/inference/vllm_decode.py +++ b/src/maxtext/inference/vllm_decode.py @@ -29,6 +29,7 @@ use_chat_template=True """ +import copy import os from typing import Any, Sequence @@ -40,7 +41,6 @@ from maxtext.utils import model_creation_utils from maxtext.utils import max_logging -from maxtext.utils.globals import MAXTEXT_CONFIGS_DIR from maxtext.common.common_types import Config from maxtext.integration.tunix.tunix_adapter import TunixMaxTextAdapter from tunix.rl.rollout import base_rollout @@ -67,6 +67,21 @@ def decode_with_vllm(config: Config) -> None: config: MaxText config. """ # Prepare vLLM Arguments + # Use user-provided vllm_additional_config as base (includes model-specific + # overrides like base_num_decoder_layers, override_model_config, etc.), then + # fill in defaults and runtime-derived values on top. + additional_config = copy.deepcopy(config.vllm_additional_config) if config.vllm_additional_config else {} + additional_config.setdefault("maxtext_config", {}) + additional_config["maxtext_config"].setdefault("model_name", config.model_name) + additional_config["maxtext_config"].setdefault("weight_dtype", "bfloat16") + additional_config["maxtext_config"].setdefault("allow_split_physical_axes", True) + additional_config["maxtext_config"]["debug_sharding"] = config.debug_sharding + additional_config.setdefault("sharding", {}) + additional_config["sharding"].setdefault("sharding_strategy", {}) + additional_config["sharding"]["sharding_strategy"].setdefault("enable_dp_attention", config.enable_dp_attention) + # Pass vllm_config_path so the adapter can use it as the MaxText base config. + additional_config.setdefault("vllm_config_path", str(config.vllm_config_path)) + vllm_args = { "model": config.tokenizer_path, "max_model_len": config.max_target_length, @@ -76,19 +91,7 @@ def decode_with_vllm(config: Config) -> None: "hf_overrides": config.vllm_hf_overrides, "gpu_memory_utilization": config.hbm_utilization_vllm, "async_scheduling": config.async_scheduling, - "additional_config": { - "maxtext_config": { - "model_name": config.model_name, - "weight_dtype": "bfloat16", - "allow_split_physical_axes": True, - "debug_sharding": config.debug_sharding, - }, - "sharding": { - "sharding_strategy": { - "enable_dp_attention": config.enable_dp_attention, - }, - }, - }, + "additional_config": additional_config, } if config.load_parameters_path: @@ -106,8 +109,7 @@ def decode_with_vllm(config: Config) -> None: f"and EP={config.ici_expert_parallelism if enable_expert_parallel else 1}..." 
  )
 
-  vllm_config_path = os.path.join(MAXTEXT_CONFIGS_DIR, "inference", "vllm.yml")
-  argv_list = ["", str(vllm_config_path), "log_config=False"]
+  argv_list = ["", str(config.vllm_config_path), "log_config=False"]
   vllm_config = pyconfig.initialize(argv_list)
 
   with nn_partitioning.axis_rules(vllm_config.logical_axis_rules):
@@ -145,7 +147,7 @@ def decode_with_vllm(config: Config) -> None:
       max_tokens=max_tokens_to_generate,
       top_k=config.decode_sampling_top_k,
       top_p=config.decode_sampling_nucleus_p,
-      seed=FLAGS.seed,
+      # seed=FLAGS.seed,  # per-request seeding is disabled in this change
   )
 
   outputs = llm.generate(prompts, sampling_params)
diff --git a/src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py b/src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py
index a0f3afba76..47ad2228aa 100644
--- a/src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py
+++ b/src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py
@@ -69,8 +69,12 @@ def generate_maxtext_config(vllm_config: VllmConfig) -> pyconfig.HyperParameters
   )
   overrides["load_parameters_path"] = None
 
-  # Add base config path to positional args
-  base_config_path = os.path.join(MAXTEXT_CONFIGS_DIR, "inference", "vllm.yml")
+  # Add base config path to positional args: prefer the caller-supplied
+  # vllm_config_path from additional_config, falling back to the vllm.yml default.
+  base_config_path = vllm_config.additional_config.get(
+      "vllm_config_path",
+      os.path.join(MAXTEXT_CONFIGS_DIR, "inference", "vllm.yml"),
+  )
   argv_list = ["", str(base_config_path)]
 
   maxtext_config = pyconfig.initialize(argv_list, **overrides)
@@ -86,6 +90,11 @@ class MaxTextForCausalLM(nnx.Module):
   of the decoding step.
   """
 
+  # Signal to tpu-inference model_loader that this class manages its own
+  # JIT-sharded initialization (via create_nnx_model with out_shardings).
+  # When True, model_loader skips wrapping __init__ in an outer bare @jax.jit.
+  _self_manages_sharding: bool = True
+
   def __init__(self, vllm_config: VllmConfig, rng_key: jax.Array, mesh: Mesh):
     """Initializes the MaxTextForCausalLM model.
@@ -232,7 +241,7 @@ def load_weights(self, rng_key: jax.Array) -> None: if self.model is not None: return - with self.mesh, nn.logical_axis_rules(""): + with self.mesh, nn.logical_axis_rules(self.maxtext_config.logical_axis_rules): model, _ = model_creation_utils.create_nnx_model( self.maxtext_config, mesh=self.mesh, model_mode=self.model_mode, rng_key=rng_key ) diff --git a/src/maxtext/layers/attention_mla.py b/src/maxtext/layers/attention_mla.py index e0d6e4e9f1..9f42350c23 100644 --- a/src/maxtext/layers/attention_mla.py +++ b/src/maxtext/layers/attention_mla.py @@ -21,6 +21,8 @@ import jax from jax.ad_checkpoint import checkpoint_name from jax.experimental import layout +from jax.sharding import PartitionSpec as P +from jax.experimental import shard_map import jax.numpy as jnp from jax.sharding import Mesh, NamedSharding @@ -624,7 +626,11 @@ def __init__( ) # Module attribute names must match names previously passed to Linen for checkpointing - self.MlaKVCache_0 = self.init_mla_kv_caches(inputs_kv_shape) if model_mode != MODEL_MODE_TRAIN else None + self.MlaKVCache_0 = ( + self.init_mla_kv_caches(inputs_kv_shape) + if model_mode != MODEL_MODE_TRAIN and config.attention != "vllm_rpa" + else None + ) def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> None: """Initializes the MLA-specific projections.""" @@ -942,7 +948,7 @@ def mla_kv_projection(self, inputs: Array, inputs_positions: Array, decoder_segm key, value = self.mla_get_key_value(low_rank_main, key_rope, model_mode) cached_values = [None, None] - if self.config.attention != "paged" and model_mode != MODEL_MODE_TRAIN: + if self.config.attention != "paged" and self.config.attention != "vllm_rpa" and model_mode != MODEL_MODE_TRAIN: if self.config.mla_naive_kvcache: cached_values = self.update_kv_caches(key, value, decoder_segment_ids, model_mode, previous_chunk) else: @@ -950,7 +956,115 @@ def mla_kv_projection(self, inputs: Array, inputs_positions: Array, decoder_segm low_rank_main, key_rope, decoder_segment_ids, model_mode, previous_chunk ) - return key, value, cached_values + return key, value, cached_values, low_rank_main, key_rope + + def mla_rpa_vllm(self, q_nope, q_rope, k_latent, k_rope, mla_kv_cache, mla_metadata): + """Forward function for vLLM serving with MLA attention. + + Args: + q_nope: Query nope part [T, N, qk_nope_head_dim] + q_rope: Query rope part [T, N, qk_rope_head_dim] + k_latent: Latent KV representation [S, kv_lora_rank] (NOT expanded k_nope) + k_rope: Key rope part [S, qk_rope_head_dim] (NO head dimension) + mla_kv_cache: The KV cache + mla_metadata: Attention metadata + """ + md = mla_metadata + try: + # pylint: disable=import-outside-toplevel + # pytype: disable=import-error + from tpu_inference.kernels.mla.v1.kernel import mla_ragged_paged_attention + from tpu_inference.kernels.ragged_paged_attention.v3.tuned_block_sizes import get_tuned_block_sizes + except ImportError as e: + raise ImportError( + "vLLM RPA attention ops require the vllm-tpu package. Please install it with `pip install vllm-tpu`." 
      ) from e
+
+    if mla_kv_cache is None or mla_metadata is None:
+      raise ValueError("kv_cache and attention_metadata must be provided when using vLLM.")
+
+    wkv_b_kernel = self.wkv_b.kernel.value
+    wk_b_kernel = wkv_b_kernel[..., : self.qk_nope_head_dim]
+    wv_b_kernel = wkv_b_kernel[..., self.qk_nope_head_dim :]
+    q_absorbed = jnp.einsum("TNH,ANH->TNA", q_nope, wk_b_kernel)
+
+    def _mla_ragged_paged_attention(q, q_rope, k, k_rope, kv_cache, *args):
+      seq_lens_local, block_tables_local = args[0], args[1]
+
+      def _initialize_block_sizes():
+        # Use local (per-shard) shapes inside shard_map to get correct block sizes.
+        max_num_tokens = q.shape[0]
+        max_num_seqs = seq_lens_local.shape[0]
+        num_page_indices = block_tables_local.shape[0]
+        assert num_page_indices % max_num_seqs == 0
+        pages_per_seq = num_page_indices // max_num_seqs
+        # Clamp the tuned block sizes below; larger blocks have been observed to OOM.
+        bkv_p, bq_sz = get_tuned_block_sizes(
+            q_nope.dtype,
+            q_nope.dtype,  # kv dtype; q_nope.dtype is used in place of mla_kv_cache.dtype
+            self.num_query_heads,
+            1,  # num_kv_heads for the MLA kernel
+            self.qk_nope_head_dim,
+            q_nope.shape[1],  # page size; TODO: verify whether this should be kv_cache.shape[1]
+            max_num_tokens,
+            pages_per_seq,
+        )
+        num_kv_pages_per_block = min(pages_per_seq, bkv_p, 4)
+        num_queries_per_block = min(max_num_tokens, bq_sz, 4)  # a value of 8 OOMs
+        return num_kv_pages_per_block, num_queries_per_block
+
+      num_kv_pages_per_block, num_queries_per_block = _initialize_block_sizes()
+      output, kv_cache = mla_ragged_paged_attention(
+          q,
+          q_rope,
+          k,
+          k_rope,
+          kv_cache,
+          *args,
+          sm_scale=1.0,
+          num_kv_pages_per_block=num_kv_pages_per_block,
+          num_queries_per_block=num_queries_per_block,
+      )
+      return kv_cache, output
+
+    in_specs = (
+        P(("attn_dp", "model", "expert", "attn_dp_expert"), None, None),  # q
+        P(("attn_dp", "model", "expert", "attn_dp_expert"), None, None),  # q_rope
+        P(("attn_dp", "model", "expert", "attn_dp_expert"), None),  # k
+        P(("attn_dp", "model", "expert", "attn_dp_expert"), None),  # k_rope
+        P(("attn_dp", "model", "expert", "attn_dp_expert")),  # kv_cache
+        P(("data", "attn_dp", "attn_dp_expert")),  # md.seq_lens
+        P(("data", "attn_dp", "attn_dp_expert")),  # md.block_tables
+        P(("data", "attn_dp", "attn_dp_expert")),  # md.query_start_loc
+        P(("data", "attn_dp", "attn_dp_expert")),  # md.request_distribution
+    )
+
+    out_specs = (
+        P(("attn_dp", "model", "expert", "attn_dp_expert"), None, None),
+        P(("attn_dp", "model", "expert", "attn_dp_expert")),
+    )
+
+    kv_cache, output = jax.jit(
+        shard_map.shard_map(
+            _mla_ragged_paged_attention,
+            mesh=self.mesh,
+            in_specs=in_specs,
+            out_specs=out_specs,
+            check_rep=False,
+        ),
+    )(
+        q_absorbed,
+        q_rope,
+        k_latent,
+        k_rope,
+        mla_kv_cache,
+        md.seq_lens,
+        md.block_tables,
+        md.query_start_loc,
+        md.request_distribution,
+    )
+    output = jnp.einsum("TNA,ANH->TNH", output, wv_b_kernel)
+    return kv_cache, output
 
   def calculate_indexer_loss(
       self,
@@ -1071,7 +1185,7 @@
     query, low_rank_q = self.mla_query_projection(inputs_q, inputs_positions, model_mode)
     if self.config.force_q_layout:
       query = layout.with_layout_constraint(query, DLL(major_to_minor=(0, 2, 3, 1)))
-    key, value, cached_values = self.mla_kv_projection(
+    key, value, cached_values, low_rank_main, key_rope = self.mla_kv_projection(
         inputs_kv, inputs_positions, decoder_segment_ids, model_mode, previous_chunk
     )
     query = checkpoint_name(query, "query_proj")
@@ -1119,7 +1233,22 @@
        )
        unnormalized_out = unnormalized_out[..., : self.v_head_dim]
        out = unnormalized_out /
(exp_sum + 1e-9) if exp_sum is not None else unnormalized_out
+    elif self.config.attention == "vllm_rpa" and model_mode != MODEL_MODE_TRAIN and kv_cache is not None:
+      batch, seq_len, num_heads, _ = query.shape
+      query = query.reshape(-1, query.shape[2], query.shape[3])
+      q_nope, q_rope = jnp.split(query, [self.qk_nope_head_dim], axis=-1)
+
+      k_latent = low_rank_main.reshape(-1, self.kv_lora_rank)
+      k_rope_squeezed = key_rope.reshape(-1, self.qk_rope_head_dim)
+
+      updated_kv, attn_out = self.mla_rpa_vllm(
+          q_nope, q_rope, k_latent, k_rope_squeezed, mla_kv_cache=kv_cache, mla_metadata=attention_metadata
+      )
+      out = attn_out.reshape(batch, seq_len, num_heads, self.v_head_dim)
+      kv_cache = updated_kv
     else:
+      if self.config.attention == "vllm_rpa" and kv_cache is None and model_mode != MODEL_MODE_TRAIN:
+        model_mode = MODEL_MODE_TRAIN  # no cache yet (e.g., init/tracing); use the cache-free train-mode path
       out = self.attention_op(
           query,
           key,
diff --git a/src/maxtext/layers/moe.py b/src/maxtext/layers/moe.py
index e7f548c847..ba90f66eea 100644
--- a/src/maxtext/layers/moe.py
+++ b/src/maxtext/layers/moe.py
@@ -351,16 +351,16 @@ def __init__(
 
     if self.config.shard_exp_on_fsdp:
       # special sharding for dsv3
-      self.wi_kernel_axes = ("embed_no_exp", None, "mlp")
-      self.wo_kernel_axes = ("embed_no_exp", "mlp", None)
+      self.wi_kernel_axes = ("embed_no_exp", None, "moe_mlp")
+      self.wo_kernel_axes = ("embed_no_exp", "moe_mlp", None)
     elif self.config.use_2d_fsdp_sharding:
       self.wi_kernel_axes = ("embed_no_exp", "mlp", None)
       self.wo_kernel_axes = ("embed_no_exp", "mlp", None)
     elif self.config.use_batch_split_schedule:
       self.wi_kernel_axes, self.wo_kernel_axes = get_batchsplit_init_kernel_axes()
     else:
-      self.wi_kernel_axes = ("exp", "embed_no_exp", "mlp")
-      self.wo_kernel_axes = ("exp", "mlp", "embed_no_exp")
+      self.wi_kernel_axes = ("exp", "embed_no_exp", "moe_mlp")
+      self.wo_kernel_axes = ("exp", "moe_mlp", "embed_no_exp")
 
     if self.config.attention == "vllm_rpa":
       # vLLM uses 'model' as the tensor parallelism axis name
@@ -1393,11 +1393,11 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
 
     if self.config.moe_fsdp_use_two_stage_all_gather:
       # Unshard on fsdp axis
-      w0_kernel = self._maybe_shard_with_logical(w0_kernel, ("exp_with_fsdp", "embed_tensor_transpose", "mlp"))
-      w1_kernel = self._maybe_shard_with_logical(w1_kernel, ("exp_with_fsdp", "embed_tensor_transpose", "mlp"))
+      w0_kernel = self._maybe_shard_with_logical(w0_kernel, ("exp_with_fsdp", "embed_tensor_transpose", "moe_mlp"))
+      w1_kernel = self._maybe_shard_with_logical(w1_kernel, ("exp_with_fsdp", "embed_tensor_transpose", "moe_mlp"))
 
       # Unshard on fsdp_transpose axis
-      wo_kernel = self._maybe_shard_with_logical(wo_kernel, ("exp_with_fsdp", "mlp", "embed_tensor_transpose"))
+      wo_kernel = self._maybe_shard_with_logical(wo_kernel, ("exp_with_fsdp", "moe_mlp", "embed_tensor_transpose"))
 
       # Make sure XLA does not optimize by combining above All-Gather to unshard
       # on FSDP axis and the subsequent unshard on fsdp_transpose axis
@@ -1845,7 +1845,7 @@ def dense_matmul(
           dispatch_axis,
       )
       with jax.named_scope("wi_0"):
-        w0_kernel_axes = ("exp", None, "mlp")
+        w0_kernel_axes = ("exp", None, "moe_mlp")
         w0_kernel = self.maybe_all_gather_kernel_weight_in_expert_parallelism(w0_kernel, w0_kernel_axes)
         layer_w0 = self.get_einsum(rhs_mesh_axes=w0_kernel_axes)(
             mlp_up_einsum, dispatch, w0_kernel, precision=matmul_precision
         )
@@ -1862,7 +1862,7 @@
         )
         layer_w0 = adc.checkpoint_name(layer_w0, "mlpwi_0")
       with jax.named_scope("wi_1"):
-        w1_kernel_axes = ("exp", None, "mlp")
+        w1_kernel_axes = ("exp", None, "moe_mlp")
w1_kernel = self.maybe_all_gather_kernel_weight_in_expert_parallelism(w1_kernel, w1_kernel_axes) layer_w1 = self.get_einsum(rhs_mesh_axes=w1_kernel_axes)( mlp_up_einsum, dispatch, w1_kernel, precision=matmul_precision @@ -1879,7 +1879,7 @@ def dense_matmul( layer_w1 = adc.checkpoint_name(layer_w1, "mlpwi_1") layer_multiply = self.apply_ffn_activation(layer_w0, layer_w1) with jax.named_scope("wo"): - wo_kernel_axes = ("exp", "mlp", None) + wo_kernel_axes = ("exp", "moe_mlp", None) wo_kernel = self.maybe_all_gather_kernel_weight_in_expert_parallelism(wo_kernel, wo_kernel_axes) intermediate_layer = self.get_einsum(rhs_mesh_axes=wo_kernel_axes)( mlp_down_einsum, diff --git a/src/maxtext/models/deepseek.py b/src/maxtext/models/deepseek.py index 6d502d92c4..e9b5b9f6af 100644 --- a/src/maxtext/models/deepseek.py +++ b/src/maxtext/models/deepseek.py @@ -17,7 +17,7 @@ # pylint: disable=no-name-in-module import functools -from typing import Optional +from typing import Optional, Any from flax import nnx import jax @@ -208,9 +208,11 @@ def attention_op( previous_chunk=None, page_state: None | page_manager.PageState = None, slot: None | int = None, + kv_cache: None | jnp.ndarray = None, + attention_metadata: None | dict[str, Any] = None, ): """Executes the attention layer.""" - attention_result, _ = self.self_attention( + attention_result, kv_cache = self.self_attention( x, x, decoder_positions, @@ -221,8 +223,10 @@ def attention_op( previous_chunk=previous_chunk, page_state=page_state, slot=slot, + kv_cache=kv_cache, + attention_metadata=attention_metadata, ) - return self.with_logical_constraint(attention_result) + return self.with_logical_constraint(attention_result), kv_cache @property def logical_axis_names(self): @@ -269,6 +273,8 @@ def self_attention_with_norm_op( previous_chunk=None, page_state: None | page_manager.PageState = None, slot: None | int = None, + kv_cache: None | jnp.ndarray = None, + attention_metadata: None | dict[str, Any] = None, ): """self-attention with normalization""" if self.is_mhc_enabled: @@ -288,7 +294,7 @@ def self_attention_with_norm_op( ) else: lnx = self.pre_attention_norm_op(inputs) - attention_lnx = self.attention_op( + attention_lnx, kv_cache = self.attention_op( lnx, decoder_segment_ids, decoder_positions, @@ -296,11 +302,13 @@ def self_attention_with_norm_op( previous_chunk, page_state, slot, + kv_cache, + attention_metadata, ) intermediate_inputs = inputs + attention_lnx # Normalization hidden_states = self.post_attention_norm_op(intermediate_inputs) - return hidden_states, intermediate_inputs + return hidden_states, intermediate_inputs, kv_cache def engram_op(self, x, decoder_input_tokens): normed_x = self.engram_layer_norm(x) @@ -363,7 +371,7 @@ def __call__( engram_output = self.engram_op(x, decoder_input_tokens) x = x + engram_output - hidden_states, intermediate_inputs = self.self_attention_with_norm_op( + hidden_states, intermediate_inputs, kv_cache = self.self_attention_with_norm_op( x, decoder_segment_ids, decoder_positions, @@ -371,6 +379,8 @@ def __call__( previous_chunk, page_state, slot, + kv_cache, + attention_metadata, ) if self.is_mhc_enabled: @@ -492,7 +502,7 @@ def __call__( engram_output = self.engram_op(x, decoder_input_tokens) x = x + engram_output - hidden_states, intermediate_inputs = self.self_attention_with_norm_op( + hidden_states, intermediate_inputs, kv_cache = self.self_attention_with_norm_op( x, decoder_segment_ids, decoder_positions, @@ -500,6 +510,8 @@ def __call__( previous_chunk, page_state, slot, + kv_cache, + 
attention_metadata, ) if self.is_mhc_enabled: diff --git a/src/maxtext/trainers/post_train/rl/train_rl.py b/src/maxtext/trainers/post_train/rl/train_rl.py index 7f5a33eed6..f096349800 100644 --- a/src/maxtext/trainers/post_train/rl/train_rl.py +++ b/src/maxtext/trainers/post_train/rl/train_rl.py @@ -72,7 +72,6 @@ os.environ["SKIP_JAX_PRECOMPILE"] = "1" from maxtext.configs import pyconfig -from maxtext.utils.globals import MAXTEXT_CONFIGS_DIR from maxtext.integration.tunix.tunix_adapter import TunixMaxTextAdapter from maxtext.trainers.post_train.rl.evaluate_rl import evaluate from maxtext.trainers.post_train.rl import utils_rl @@ -472,8 +471,7 @@ def create_rl_components( raise ValueError(f"Failed to parse additional_config JSON: {e}") from e # We need to parse vLLM config to get the logical axis rules for the sampler config. - vllm_config_path = os.path.join(MAXTEXT_CONFIGS_DIR, "inference", "vllm.yml") - argv_list = ["", str(vllm_config_path), "log_config=False"] + argv_list = ["", str(trainer_config.vllm_config_path), "log_config=False"] vllm_config = pyconfig.initialize(argv_list) cluster_config = rl_cluster_lib.ClusterConfig(
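--
Usage sketch (illustrative note, not part of the applied diff): with this change
the vLLM base config becomes selectable at runtime instead of being hard-coded
to inference/vllm.yml. Assuming the usual MaxText CLI invocation for
vllm_decode.py and a DeepSeek model preset (the flag values below are examples,
not taken from this patch), a decode run against the new sharding config would
look like:

    python3 -m maxtext.inference.vllm_decode src/maxtext/configs/base.yml \
        model_name=deepseek3-671b \
        attention=vllm_rpa \
        vllm_config_path=src/maxtext/configs/inference/vllm_deepseek.yml

decode_with_vllm() forwards vllm_config_path through additional_config, and
generate_maxtext_config() in adapter.py reads it back, so the decode path and
the vLLM adapter resolve the same logical_axis_rules.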