diff --git a/src/maxtext/common/common_types.py b/src/maxtext/common/common_types.py index ec1b6b4fe2..65e6d28bc1 100644 --- a/src/maxtext/common/common_types.py +++ b/src/maxtext/common/common_types.py @@ -102,6 +102,7 @@ class DecoderBlockType(enum.Enum): SIMPLE_MLP = "simple_mlp" LLAMA4 = "llama4" OLMO3 = "olmo3" + DEEPSEEK_CUSTOM = "deepseek_custom" class AttentionType(enum.Enum): diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index 77751479ce..2876a01cf9 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -155,6 +155,12 @@ base_num_kv_heads: 16 base_mlp_dim: 7168 base_num_decoder_layers: 16 head_dim: 128 +attention_output_dim: -1 +local_num_query_heads: -1 +local_num_kv_heads: -1 +global_num_query_heads: -1 +global_num_kv_heads: -1 +attention_layer_hybrid_ratio: -1 mlp_activations: ["silu", "linear"] mlp_activations_limit: -1.0 dropout_rate: 0.0 @@ -184,6 +190,11 @@ num_experts_per_tok: 1 megablox: true sparse_matmul: true capacity_factor: -1.0 # a factor to decide expert capacity for token dropping, and no dropping by default +ragged_buffer_factor: -1.0 # a factor to determine the size of the ragged buffer for routed MoE activations. +# By default (-1), this buffer will be worst case size to ensure no dropping. +# When set to 1.0 this buffer is set to the size assuming perfectly balanced routing. If the routing dictates +# a size larger than this then tokens will be dropped. +# In general if ragged_buffer_factor>0, the ragged_buffer_size is balanced_size * ragged_buffer_factor. load_balance_loss_weight: 0.0 # weight for the load balance loss use_random_routing: false # whether to use random routing for debug/test purpose use_custom_sort_vjp: true # whether to use a custom VJP sort for efficient backward pass processing in sparse matmul @@ -240,6 +251,8 @@ use_2d_fsdp_sharding: False # deepseek moe base_moe_mlp_dim: 7168 # intermediate dimension at MoE layer. 
For a fully MoE model, base_mlp_dim must be equal to base_moe_mlp_dim. +moe_model_dim: -1 # dimension of token entering moe layer. +shared_expert_mlp_dim: -1 # intermediate dimension of the shared expert. first_num_dense_layers: 0 # number of initial dense layers in the model shared_experts: 1 routed_scaling_factor: 1.0 # scaling factor for routing scores @@ -485,6 +498,7 @@ logical_axis_rules: [ ['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence', 'context']], ['embed_no_exp', ['fsdp', 'sequence', 'context']], ['embed_tensor_transpose', ['tensor_transpose']], + ['attention_out_proj', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context', 'expert']], ['q_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'tensor_transpose', 'expert']], ['q_lora', ['fsdp', 'sequence', 'context', 'tensor_transpose', 'expert']], ['q_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']], diff --git a/src/maxtext/configs/models/deepseek3-custom-large.yml b/src/maxtext/configs/models/deepseek3-custom-large.yml new file mode 100644 index 0000000000..d78e9f17a4 --- /dev/null +++ b/src/maxtext/configs/models/deepseek3-custom-large.yml @@ -0,0 +1,62 @@ +# Copyright 2023–2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# model config for DeepSeek Custom + + +base_emb_dim: 16384 +moe_model_dim: 8192 +base_moe_mlp_dim: 16384 # (2 * 8192) +shared_expert_mlp_dim: 32768 # (4 * 8192) + +base_num_decoder_layers: 61 +first_num_dense_layers: 3 +mlp_activations: ["silu","linear"] +vocab_size: 129280 +enable_dropout: False +logits_via_embedding: False +normalization_layer_epsilon: 1.0e-6 + +num_experts: 256 +num_experts_per_tok: 4 # (1 shared + 4 routed) +shared_experts: 1 +routed_scaling_factor: 2.5 +routed_score_func: "sigmoid" +routed_bias: True +decoder_block: "deepseek_custom" + +# Hybrid GQA Attention +attention_output_dim: 8192 # same as moe_model_dim +attention_layer_hybrid_ratio: 2 # 1 Local : 1 Global ratio +inhomogeneous_layer_cycle_interval: 2 # same as attention_layer_hybrid_ratio +head_dim: 256 + +local_num_query_heads: 64 +local_num_kv_heads: 8 +sliding_window_size: 1024 + +global_num_query_heads: 64 +global_num_kv_heads: 4 + +mscale: 1.0 +# RoPE +rope_type: "yarn" +rope_max_timescale: 10_000 # DeepSeek uses "rope_theta": 10000 +max_position_embeddings: 163840 +original_max_position_embeddings: 4096 +rope_factor: 40 +beta_fast: 32 +rope_interleave: True +rope_truncate: True +rope_attention_scaling: False diff --git a/src/maxtext/configs/models/deepseek3-custom-small.yml b/src/maxtext/configs/models/deepseek3-custom-small.yml new file mode 100644 index 0000000000..142c4db397 --- /dev/null +++ b/src/maxtext/configs/models/deepseek3-custom-small.yml @@ -0,0 +1,49 @@ +# Copyright 2023–2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +# model config for DeepSeek Custom + +base_emb_dim: 1024 +moe_model_dim: 512 +base_moe_mlp_dim: 1024 +shared_expert_mlp_dim: 4096 +num_experts: 16 +num_experts_per_tok: 2 +shared_experts: 1 +base_num_decoder_layers: 4 +first_num_dense_layers: 1 +mlp_activations: ["silu", "linear"] +vocab_size: 129280 +enable_dropout: False +logits_via_embedding: False +normalization_layer_epsilon: 1.0e-6 +routed_scaling_factor: 2.5 +routed_score_func: "sigmoid" +routed_bias: True +decoder_block: "deepseek_custom" + + +# Hybrid GQA Attention + +attention_output_dim: 512 # same as moe_model_dim +attention_layer_hybrid_ratio: 2 # 1 Local : 1 Global ratio +inhomogeneous_layer_cycle_interval: 2 # same as attention_layer_hybrid_ratio +head_dim: 256 + +local_num_query_heads: 4 +local_num_kv_heads: 2 +sliding_window_size: 128 + +global_num_query_heads: 4 +global_num_kv_heads: 1 \ No newline at end of file diff --git a/src/maxtext/configs/models/deepseek3-custom.yml b/src/maxtext/configs/models/deepseek3-custom.yml new file mode 100644 index 0000000000..c3ffdb4033 --- /dev/null +++ b/src/maxtext/configs/models/deepseek3-custom.yml @@ -0,0 +1,62 @@ +# Copyright 2023–2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# model config for DeepSeek Custom + + +base_emb_dim: 7168 +moe_model_dim: 3072 +base_moe_mlp_dim: 6144 # (2 * 3072) +shared_expert_mlp_dim: 15360 # (5 * 3072) + +base_num_decoder_layers: 61 +first_num_dense_layers: 3 +mlp_activations: ["silu","linear"] +vocab_size: 129280 +enable_dropout: False +logits_via_embedding: False +normalization_layer_epsilon: 1.0e-6 + +num_experts: 256 +num_experts_per_tok: 4 # (1 shared + 4 routed) +shared_experts: 1 +routed_scaling_factor: 2.5 +routed_score_func: "sigmoid" +routed_bias: True +decoder_block: "deepseek_custom" + +# Hybrid GQA Attention +attention_output_dim: 3072 # same as moe_model_dim +attention_layer_hybrid_ratio: 2 # 1 Local : 1 Global ratio +inhomogeneous_layer_cycle_interval: 2 # same as attention_layer_hybrid_ratio +head_dim: 256 + +local_num_query_heads: 64 +local_num_kv_heads: 8 +sliding_window_size: 1024 + +global_num_query_heads: 64 +global_num_kv_heads: 4 + +mscale: 1.0 +# RoPE +rope_type: "yarn" +rope_max_timescale: 10_000 # DeepSeek uses "rope_theta": 10000 +max_position_embeddings: 163840 +original_max_position_embeddings: 4096 +rope_factor: 40 +beta_fast: 32 +rope_interleave: True +rope_truncate: True +rope_attention_scaling: False diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index 21296e965d..135622ce42 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -224,6 +224,9 @@ class ProfilerType(str, Enum): "deepseek3-tiny", "deepseek3.2-671b", "deepseek-custom", + "deepseek3-custom-small", + "deepseek3-custom", + "deepseek3-custom-large", "kimi-k2-1t", "gemma-7b", "gemma-2b", @@ -437,6 +440,14 @@ class ModelArchitecture(BaseModel): base_mlp_dim: int = Field(7168, description="Base dimension of the MLP layer.") base_num_decoder_layers: int = Field(16, description="Base number of decoder layers.") head_dim: int = Field(128, description="Dimension of each attention head.") + attention_output_dim: int = Field(-1, description="Override output dimension for 
attention block") + local_num_query_heads: int = Field(-1, description="Number of query heads in local context layers.") + local_num_kv_heads: int = Field(-1, description="Number of KV heads in local context layers.") + global_num_query_heads: int = Field(-1, description="Number of query heads in global context layers.") + global_num_kv_heads: int = Field(-1, description="Number of KV heads in global context layers.") + attention_layer_hybrid_ratio: int = Field( + -1, description="Ratio of layer context styles (e.g. 5 means 4 local followed by 1 global)." + ) mlp_activations: list[str] = Field(["silu", "linear"], description="Activation functions in the MLP layer.") mlp_activations_limit: float = Field( -1.0, @@ -617,6 +628,7 @@ class MoEGeneral(BaseModel): num_experts: PositiveInt = Field(1, description="The total number of experts in each MoE layer.") num_experts_per_tok: PositiveInt = Field(1, description="The number of experts to route each token to.") capacity_factor: float = Field(-1.0, description="Expert capacity factor. If < 0, no token dropping.") + ragged_buffer_factor: float = Field(-1.0, description="Ragged buffer factor. 
If < 0, ragged buffer is worst case size.") load_balance_loss_weight: NonNegativeFloat = Field(0.0, description="Weight for the load balancing auxiliary loss.") use_custom_sort_vjp: bool = Field( True, @@ -714,6 +726,8 @@ class DeepSeekMoE(BaseModel): """Configuration specific to DeepSeek-style MoE layers.""" base_moe_mlp_dim: int = Field(7168, description="Intermediate dimension at MoE layer (DeepSeek style).") + moe_model_dim: int = Field(-1, description="Dimension of tokens entering the MoE layer.") + shared_expert_mlp_dim: int = Field(-1, description="Intermediate dimension for the shared expert.") first_num_dense_layers: NonNegativeInt = Field(0, description="Number of initial dense layers in the model.") shared_experts: PositiveInt = Field(1, description="Number of shared experts.") routed_scaling_factor: float = Field(1.0, description="Scaling factor for routing scores.") diff --git a/src/maxtext/layers/decoders.py b/src/maxtext/layers/decoders.py index ce69b7a396..3d455010a7 100644 --- a/src/maxtext/layers/decoders.py +++ b/src/maxtext/layers/decoders.py @@ -42,6 +42,7 @@ from maxtext.models import ( deepseek, deepseek_batchsplit, + deepseek_custom, gemma, gemma2, gemma3, @@ -52,7 +53,6 @@ mistral, mixtral, olmo3, - qwen2, qwen3, simple_layer, ) @@ -458,6 +458,14 @@ def get_decoder_layers(self): deepseek.DeepSeekDenseLayerToLinen, deepseek.DeepSeekMoELayerToLinen, ] + case DecoderBlockType.DEEPSEEK_CUSTOM: + deepseek_custom_moe_layer = deepseek_custom.DeepSeekMoELayerToLinen + if self.config.scan_layers and self.config.attention_layer_hybrid_ratio > 1: + deepseek_custom_moe_layer = deepseek_custom.DeepSeekMoEScannableBlockToLinen + return [ + deepseek_custom.DeepSeekDenseLayerToLinen, + deepseek_custom_moe_layer, + ] case DecoderBlockType.GEMMA: return [gemma.GemmaDecoderLayerToLinen] case DecoderBlockType.GEMMA2: @@ -468,8 +476,6 @@ def get_decoder_layers(self): return [gpt3.Gpt3DecoderLayerToLinen] case DecoderBlockType.GPT_OSS: return 
[gpt_oss.GptOssScannableBlockToLinen] if self.config.scan_layers else [gpt_oss.GptOssDecoderLayerToLinen] - case DecoderBlockType.QWEN2: - return [qwen2.Qwen2DecoderLayerToLinen] case DecoderBlockType.QWEN3: return [qwen3.Qwen3DecoderLayerToLinen] case DecoderBlockType.QWEN3_MOE: @@ -525,10 +531,10 @@ def get_norm_layer(self, num_features: int): DecoderBlockType.MISTRAL, DecoderBlockType.MIXTRAL, DecoderBlockType.DEEPSEEK, + DecoderBlockType.DEEPSEEK_CUSTOM, DecoderBlockType.GEMMA, DecoderBlockType.GEMMA2, DecoderBlockType.GEMMA3, - DecoderBlockType.QWEN2, DecoderBlockType.QWEN3, DecoderBlockType.QWEN3_MOE, DecoderBlockType.GPT_OSS, @@ -577,7 +583,7 @@ def get_pipeline_stage_module(self, decoder_blocks): """get pipeline stage module""" def get_layer_to_pipeline(blocks, cfg): - if cfg.decoder_block == DecoderBlockType.DEEPSEEK: + if cfg.decoder_block in (DecoderBlockType.DEEPSEEK, DecoderBlockType.DEEPSEEK_CUSTOM): return blocks[1] # return the sparse block else: return blocks[0] @@ -803,7 +809,7 @@ def __call__( if cfg.pipeline_fsdp_ag_once or cfg.pipeline_fsdp_ag_per_repeat else None ) - if cfg.decoder_block == DecoderBlockType.DEEPSEEK: + if cfg.decoder_block in (DecoderBlockType.DEEPSEEK, DecoderBlockType.DEEPSEEK_CUSTOM): assert len(RemattedBlockLayers) == 2, "Scanned layers must have a length of 2 using deepseek." dense_layer = RemattedBlockLayers[0] moe_layer = RemattedBlockLayers[1] @@ -849,7 +855,7 @@ def __call__( )(y, *broadcast_args) else: if cfg.scan_layers: - if cfg.decoder_block == DecoderBlockType.DEEPSEEK: + if cfg.decoder_block in (DecoderBlockType.DEEPSEEK, DecoderBlockType.DEEPSEEK_CUSTOM): assert len(RemattedBlockLayers) == 2, "Scanned layers must have a length of 2 using deepseek." 
layer_call_kwargs = { "page_state": page_state, @@ -927,10 +933,31 @@ def __call__( policy=policy, ) else: + scan_length = num_moe_layers + if cfg.decoder_block == DecoderBlockType.DEEPSEEK_CUSTOM and cfg.scan_layers: + if num_moe_layers % cfg.inhomogeneous_layer_cycle_interval != 0: + raise ValueError( + f"num_moe_layers ({num_moe_layers}) must be divisible by " + f"inhomogeneous_layer_cycle_interval ({cfg.inhomogeneous_layer_cycle_interval}) " + "when using DeepSeek Custom and scan_layers is True." + ) + if cfg.attention_layer_hybrid_ratio != cfg.inhomogeneous_layer_cycle_interval: + raise ValueError( + f"attention_layer_hybrid_ratio ({cfg.attention_layer_hybrid_ratio}) and " + f"inhomogeneous_layer_cycle_interval ({cfg.inhomogeneous_layer_cycle_interval}) " + "must be the same." + ) + scan_length = num_moe_layers // cfg.inhomogeneous_layer_cycle_interval + max_logging.log( + f"scan_length: {scan_length}, " + f"num_moe_layers // cfg.inhomogeneous_layer_cycle_interval: " + f"{num_moe_layers // cfg.inhomogeneous_layer_cycle_interval}" + ) + y, _ = self.scan_decoder_layers( cfg, moe_layer, - num_moe_layers, + scan_length, "moe_layers", mesh, in_axes_tuple=(nn.broadcast,) * len(broadcast_args), @@ -968,7 +995,7 @@ def __call__( **layer_kwargs, )(y, *broadcast_args) else: - if cfg.decoder_block == DecoderBlockType.DEEPSEEK: + if cfg.decoder_block in (DecoderBlockType.DEEPSEEK, DecoderBlockType.DEEPSEEK_CUSTOM): assert len(RemattedBlockLayers) == 2, "Unscanned layers must have a length of 2 using deepseek." 
dense_layer = RemattedBlockLayers[0] moe_layer = RemattedBlockLayers[1] @@ -1058,11 +1085,14 @@ def __call__( kv_caches["key_cache"][lyr] = returned_cache[0] kv_caches["value_cache"][lyr] = returned_cache[1] - if deepstack_visual_embeds is not None and lyr < len(deepstack_visual_embeds): - visual_embeds = deepstack_visual_embeds[lyr] + if ( + deepstack_visual_embeds is not None + and lyr < len(deepstack_visual_embeds) + and bidirectional_mask is not None + and deepstack_visual_embeds[lyr] is not None + ): # Use bidirectional_mask to identify visual token positions - if bidirectional_mask is not None and visual_embeds is not None: - y = deepstack_process(y, bidirectional_mask, visual_embeds) + y = deepstack_process(y, bidirectional_mask, deepstack_visual_embeds[lyr]) assert isinstance(y, jax.Array) diff --git a/src/maxtext/layers/linears.py b/src/maxtext/layers/linears.py index 4af9c5c530..9c25935ced 100644 --- a/src/maxtext/layers/linears.py +++ b/src/maxtext/layers/linears.py @@ -474,6 +474,7 @@ def get_norm_layer(self, num_features: int): DecoderBlockType.GEMMA3, DecoderBlockType.QWEN3, DecoderBlockType.DEEPSEEK, + DecoderBlockType.DEEPSEEK_CUSTOM, DecoderBlockType.LLAMA4, ): return functools.partial(normalizations.RMSNorm, num_features=num_features) diff --git a/src/maxtext/layers/moe.py b/src/maxtext/layers/moe.py index e7f548c847..8a3dd161c1 100644 --- a/src/maxtext/layers/moe.py +++ b/src/maxtext/layers/moe.py @@ -373,8 +373,15 @@ def __init__( else: self._expert_parallelism_name = "expert" + # Some architectures may have a dedicated model dimension for the MoE layers, + # different from the primary embedding dimension. 
+ self.in_features = self.config.moe_model_dim if self.config.moe_model_dim > 0 else self.config.emb_dim + max_logging.log( + f" [RoutedMoE Block] In Features: {self.in_features}, Intermediate Dim: {intermediate_dim}, " + f"Total Experts: {num_experts}, Active Experts/Token: {num_experts_per_tok}" + ) self.gate = GateLogit( - in_features_shape=self.config.emb_dim, + in_features_shape=self.in_features, out_features_shape=self.num_experts, mesh=self.mesh, model_name=self.config.model_name, @@ -392,22 +399,20 @@ def __init__( # pylint: disable=protected-access self.activation_fn = linears._convert_to_activation_function(self.config.mlp_activations[0]) - - kernel_in_axis = np.arange(1) + kernel_in_axis = np.arange(1) # the first dimension is the number of experts kernel_out_axis = np.arange(1, 2) if quantizations.in_serve_mode(self.quant): # During aqt convert state we delete kernel weight from params to save # memory. Instead they are retrieved from the tensors stored in the 'aqt' - # collection. 
- self.wi_0 = jnp.zeros((num_experts, self.config.emb_dim, intermediate_dim)) - self.wi_1 = jnp.zeros((num_experts, self.config.emb_dim, intermediate_dim)) - self.wo = jnp.zeros((num_experts, intermediate_dim, self.config.emb_dim)) + self.wi_0 = jnp.zeros((num_experts, self.in_features, intermediate_dim)) + self.wi_1 = jnp.zeros((num_experts, self.in_features, intermediate_dim)) + self.wo = jnp.zeros((num_experts, intermediate_dim, self.in_features)) else: self.wi_0 = nnx.Param( self.kernel_init( self.rngs.params(), - (num_experts, self.config.emb_dim, intermediate_dim), + (num_experts, self.in_features, intermediate_dim), weight_dtype, kernel_in_axis, kernel_out_axis, @@ -417,7 +422,7 @@ def __init__( self.wi_1 = nnx.Param( self.kernel_init( self.rngs.params(), - (num_experts, self.config.emb_dim, intermediate_dim), + (num_experts, self.in_features, intermediate_dim), weight_dtype, kernel_in_axis, kernel_out_axis, @@ -427,7 +432,7 @@ def __init__( self.wo = nnx.Param( self.kernel_init( self.rngs.params(), - (self.num_experts, self.intermediate_dim, self.config.emb_dim), + (self.num_experts, self.intermediate_dim, self.in_features), self.weight_dtype, kernel_in_axis, kernel_out_axis, @@ -439,7 +444,7 @@ def __init__( wi_bias_axes = ("exp", "activation_mlp") wo_bias_axes = ("exp", "activation_embed") wi_bias_shape = (self.num_experts, self.intermediate_dim) - wo_bias_shape = (self.num_experts, self.config.emb_dim) + wo_bias_shape = (self.num_experts, self.in_features) self.wi_0_bias = nnx.Param( default_bias_init(self.rngs.params(), wi_bias_shape, self.weight_dtype), sharding=wi_bias_axes, @@ -875,6 +880,14 @@ def transform_array(input_array, shard_id, strategy, is_batch_sharded): ) return input_offsets, send_sizes, output_offsets, recv_sizes + def get_ragged_buffer_size(self, local_expert_size, local_batch): + if self.config.ragged_buffer_factor > 0.0: + balanced_size = local_batch + return int(balanced_size * self.config.ragged_buffer_factor) + else: + 
max_local_experts_per_tok = min(local_expert_size, self.config.num_experts_per_tok) + return int(local_batch * max_local_experts_per_tok) + def transform_bias(self, experts_index, *biases): """Selects bias values for a variable number of bias tensors based on chosen experts.""" return tuple(bias[experts_index] for bias in biases) @@ -902,7 +915,7 @@ def gmm( else: tokamax_group_sizes = tokamax.RaggedDotGroupSizes( group_sizes, - max_utils.generate_representative_group_sizes(inputs.shape[0], kernel.shape[0]), + max_utils.generate_representative_group_sizes(inputs.shape[0], group_sizes.shape[0]), ) pad_length = self.config.wi_tile_fwd_batch_seq hs_shape = inputs.shape @@ -1166,9 +1179,8 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r # In the worst case, all of the global input data is assigned to each expert in the current shard. # This would result in num_expert_shards * input_size * experts_per_shard assignments. However, if # experts_per_shard > num_experts_per_tok we cannot assign more than num_experts_per_tok to all of the inputs. - max_local_experts_per_tok = min(local_expert_size, self.config.num_experts_per_tok) - buffer_size = int(num_expert_parallelism * batch_size * sequence_length * max_local_experts_per_tok) - output_shape = jnp.zeros((buffer_size, self.config.emb_dim), dtype=x.dtype) + buffer_size = self.get_ragged_buffer_size(local_expert_size, jnp.shape(x)[0]) + output_shape = jnp.zeros((buffer_size, self.in_features), dtype=x.dtype) x = jax.lax.ragged_all_to_all( x, @@ -1323,7 +1335,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index): ) # Sum up the partial outputs across the expert shards. 
- output = jnp.reshape(output, (-1, sequence_length, self.config.emb_dim // self.get_tensor_parallelism_size())) + output = jnp.reshape(output, (-1, sequence_length, self.in_features // self.get_tensor_parallelism_size())) output = jax.lax.psum_scatter(output, self._expert_parallelism_name, scatter_dimension=0, tiled=True) else: @@ -1334,7 +1346,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index): output_shape = jnp.zeros( ( original_inputs_first_dim, - self.config.emb_dim // self.get_tensor_parallelism_size(), + self.in_features // self.get_tensor_parallelism_size(), ), dtype=intermediate_output.dtype, ) @@ -2076,23 +2088,33 @@ def __init__( self.rngs = rngs # NOTE: the name MoeBlock_0 is to ensure reverse compatibility with # existing checkpoints for routed experts. + in_features = self.config.moe_model_dim if self.config.moe_model_dim > 0 else self.config.emb_dim self.MoeBlock_0 = RoutedMoE( config=self.config, num_experts=self.config.num_experts, num_experts_per_tok=self.config.num_experts_per_tok, mesh=self.mesh, kernel_init=nd_dense_init(1.0, "fan_in", "truncated_normal"), - kernel_axes=("embed", None), - intermediate_dim=self.config.moe_mlp_dim, + kernel_axes=kernel_axes, + intermediate_dim=self.config.base_moe_mlp_dim, dtype=self.config.dtype, weight_dtype=self.config.weight_dtype, quant=self.quant, rngs=self.rngs, ) + # If shared_expert_mlp_dim is not set, use the base_moe_mlp_dim. 
+ shared_expert_dim = self.config.shared_expert_mlp_dim if self.config.shared_expert_mlp_dim > 0 else self.config.base_moe_mlp_dim + max_logging.log( + " [RoutedAndSharedMoE] Shared Experts: %s, " + "Shared Exp MLP Dim: %s" % ( + self.config.shared_experts, + self.config.shared_experts * shared_expert_dim, + ) + ) self.shared_experts = linears.MlpBlock( mesh=self.mesh, - in_features=self.config.emb_dim, - intermediate_dim=self.config.shared_experts * self.config.moe_mlp_dim, + in_features=in_features, + intermediate_dim=self.config.shared_experts * shared_expert_dim, activations=self.config.mlp_activations, intermediate_dropout_rate=self.config.dropout_rate, dtype=self.config.dtype, diff --git a/src/maxtext/models/deepseek_custom.py b/src/maxtext/models/deepseek_custom.py new file mode 100644 index 0000000000..8c3335dac7 --- /dev/null +++ b/src/maxtext/models/deepseek_custom.py @@ -0,0 +1,604 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Transformer model definition.""" +# pylint: disable=arguments-differ +# pylint: disable=no-name-in-module + +from typing import Optional + +from flax import nnx +from jax.ad_checkpoint import checkpoint_name +import jax.numpy as jnp +from jax.sharding import Mesh +from maxtext.common.common_types import AttentionType +from maxtext.common.common_types import Config +from maxtext.common.common_types import MODEL_MODE_PREFILL +from maxtext.inference import page_manager +from maxtext.layers import attentions +from maxtext.layers import initializers +from maxtext.layers import linears +from maxtext.layers import moe +from maxtext.layers import nnx_wrappers +from maxtext.layers import quantizations +from maxtext.layers.linears import Dropout +from maxtext.layers.normalizations import RMSNorm +from maxtext.utils import max_logging +from maxtext.utils import max_utils +from maxtext.utils.sharding import create_sharding +from maxtext.utils.sharding import maybe_shard_with_logical + + +class CustomAttention(attentions.Attention): + """Custom GQA attention that supports sub-dimensional output.""" + + def init_out_w(self, output_dim: int) -> nnx.Module: + """Initializes the output projection.""" + if not self.config.attention_output_dim > 0: + raise ValueError( + "attention_output_dim must be set to a positive integer for CustomAttention." 
+ ) + + in_features = (self.num_query_heads, self.head_dim) + out_kernel_axis = ( + (None, None, None) + if self.config.ici_context_autoregressive_parallelism > 1 + else ("heads", "kv", "embed") + ) + axis = (-2, -1) + + return linears.DenseGeneral( + in_features_shape=in_features, + out_features_shape=self.config.attention_output_dim, + axis=axis, + kernel_init=self.kernel_init, + kernel_axes=out_kernel_axis, # trade speed with memory + dtype=self.dtype, + weight_dtype=self.weight_dtype, + quant=self.quant, + shard_mode=self.config.shard_mode, + matmul_precision=self.config.matmul_precision, + use_bias=self.use_bias_in_projections, + rngs=self.rngs, + ) + + +class DeepSeekGenericLayer(nnx.Module): + """Generic DeepSeek layer with Multi-Head Latent Attention. + + This is to be used as a base class for DeepSeek layers with dense/sparse MLPs. + This class follows a pattern of separating module creation from execution. + """ + + def __init__( + self, + config: Config, + model_mode: str, + mesh: Mesh, + rngs: nnx.Rngs, + quant: Optional[quantizations.AqtQuantization] = None, + layer_idx: int = -1, + ) -> None: + self.config = config + self.model_mode = model_mode + self.mesh = mesh + self.quant = quant + self.rngs = rngs + self.is_mhc_enabled = config.mhc_expansion_rate > 1 + self.layer_idx = layer_idx + + # GQA Hybrid routing calculation + attention_layer_hybrid_ratio = self.config.attention_layer_hybrid_ratio + + # All dense layers use local attention. 
+ is_global_attention = False + attention_cls = CustomAttention + if isinstance(self, DeepSeekDenseLayer): + attention_cls = attentions.Attention + elif attention_layer_hybrid_ratio > 0: + is_global_attention = (self.layer_idx + 1) % attention_layer_hybrid_ratio == 0 + + self.attention_type = AttentionType.GLOBAL if is_global_attention else AttentionType.LOCAL_SLIDING + + if is_global_attention and self.config.global_num_query_heads > 0: + self.num_query_heads = self.config.global_num_query_heads + self.num_kv_heads = self.config.global_num_kv_heads + self.sliding_window_size = None + elif not is_global_attention and self.config.local_num_query_heads > 0: + self.num_query_heads = self.config.local_num_query_heads + self.num_kv_heads = self.config.local_num_kv_heads + self.sliding_window_size = self.config.sliding_window_size + else: + self.num_query_heads = self.config.base_num_query_heads + self.num_kv_heads = self.config.base_num_kv_heads + self.sliding_window_size = None + + max_logging.log( + f"Initializing {self.__class__.__name__} - Layer: {layer_idx}, " + f"Context: {'Global' if is_global_attention else 'Local'} " + f"(Q_Heads: {self.num_query_heads}, KV_Heads: {self.num_kv_heads})" + ) + + batch_size, sequence_length = max_utils.get_batch_seq_len_for_mode( + self.config, self.model_mode + ) + self.dummy_inputs_shape = (batch_size, sequence_length, self.config.emb_dim) + + self.out_sharding = create_sharding(self.mesh, self.logical_axis_names) + self.mlp_intermediate_sharding = create_sharding( + self.mesh, self.mlp_logical_axis_names + ) + + self.pre_self_attention_layer_norm = RMSNorm( + num_features=self.dummy_inputs_shape[-1], + dtype=self.config.dtype, + weight_dtype=self.config.weight_dtype, + kernel_axes=("norm",), + epsilon=self.config.normalization_layer_epsilon, + rngs=rngs, + ) + + self.post_self_attention_layer_norm = RMSNorm( + num_features=self.dummy_inputs_shape[-1], + dtype=self.config.dtype, + weight_dtype=self.config.weight_dtype, + 
kernel_axes=("norm",), + epsilon=self.config.normalization_layer_epsilon, + rngs=rngs, + ) + + self.self_attention = attention_cls( + config=self.config, + num_query_heads=self.num_query_heads, + num_kv_heads=self.num_kv_heads, + head_dim=self.config.head_dim, + max_target_length=self.config.max_target_length, + max_prefill_predict_length=self.config.max_prefill_predict_length, + attention_kernel=self.config.attention, + attention_type=self.attention_type, + sliding_window_size=self.sliding_window_size, + inputs_q_shape=self.dummy_inputs_shape, + inputs_kv_shape=self.dummy_inputs_shape, + mesh=mesh, + dtype=self.config.dtype, + weight_dtype=self.config.weight_dtype, + dropout_rate=self.config.dropout_rate, + name="self_attention", + quant=quant, + kv_quant=quantizations.configure_kv_quant(config), + model_mode=model_mode, + rngs=rngs, + attn_logits_soft_cap=self.config.attn_logits_soft_cap, + ) + + self.dropout = Dropout(rate=self.config.dropout_rate, broadcast_dims=(-2,), rngs=self.rngs) + + # Optional projection up from intermediate state back to emb dim. + # This corresponds to the transition from the latent/compressed space back to the main model dimension. 
+ skip_projection = isinstance(self, DeepSeekDenseLayer) + + if ( + self.config.attention_output_dim > 0 + and self.config.attention_output_dim != self.config.emb_dim + and not skip_projection + ): + out_kernel_axis = ( + (None, None) if self.config.ici_context_autoregressive_parallelism > 1 else ("mlp", "embed") + ) + self.layer_up_projection = linears.DenseGeneral( + in_features_shape=self.config.attention_output_dim, + out_features_shape=self.config.emb_dim, + axis=-1, + kernel_init=initializers.nd_dense_init(1.0, "fan_in", "truncated_normal"), + kernel_axes=out_kernel_axis, + dtype=self.config.dtype, + weight_dtype=self.config.weight_dtype, + quant=quant, + shard_mode=self.config.shard_mode, + matmul_precision=self.config.matmul_precision, + use_bias=False, + rngs=self.rngs, + ) + else: + self.layer_up_projection = None + + def mlp_op(self, x, deterministic, *args, **kwargs): + """Executes the MLP operation. To be implemented by subclasses.""" + raise NotImplementedError() + + def with_logical_constraint(self, x): + return maybe_shard_with_logical( + x, + logical_axes=self.logical_axis_names, + mesh=self.mesh, + shard_mode=self.config.shard_mode, + debug_sharding=self.config.debug_sharding, + extra_stack_level=1, + ) + + def dropout_op(self, x, deterministic): + dropout = self.dropout(x, deterministic=deterministic) + return self.with_logical_constraint(dropout) + + def pre_attention_norm_op(self, x): + pre_attention_norm = self.pre_self_attention_layer_norm(x) + return self.with_logical_constraint(pre_attention_norm) + + def post_attention_norm_op(self, x): + post_attention_norm = self.post_self_attention_layer_norm(x) + return self.with_logical_constraint(post_attention_norm) + + def attention_op( + self, + x, + decoder_segment_ids, + decoder_positions, + deterministic, + previous_chunk=None, + page_state: None | page_manager.PageState = None, + slot: None | int = None, + ): + """Executes the attention layer.""" + attention_result, _ = self.self_attention( + 
x, + x, + decoder_positions, + decoder_segment_ids=decoder_segment_ids, + deterministic=deterministic, + model_mode=self.model_mode, + out_sharding=self.out_sharding, + previous_chunk=previous_chunk, + page_state=page_state, + slot=slot, + ) + return self.with_logical_constraint(attention_result) + + @property + def logical_axis_names(self): + """Generate logical names for activations generally.""" + length_name = "prefill_activation_norm_length" if self.model_mode == MODEL_MODE_PREFILL else "activation_norm_length" + axis_names = ["activation_batch", length_name, "activation_embed"] + return axis_names + + @property + def mlp_logical_axis_names(self): + """Generate logical names for activations in MLP.""" + length_name = "prefill_activation_norm_length" if self.model_mode == MODEL_MODE_PREFILL else "activation_norm_length" + axis_names = ["activation_batch", length_name, "activation_mlp"] + return axis_names + + def post_process(self, layer_output, load_balance_loss, moe_bias_updates, kv_cache=None): + """postprocessing.""" + + if self.config.load_balance_loss_weight > 0.0 and load_balance_loss is not None: + self.sow(nnx.Intermediate, "moe_lb_loss", load_balance_loss) + + if self.config.routed_bias and self.config.routed_bias_update_rate > 0.0 and moe_bias_updates is not None: + self.sow(nnx.Intermediate, "moe_bias_updates", moe_bias_updates) + + if self.config.record_internal_nn_metrics: + self.sow(nnx.Intermediate, "activation_mean", jnp.mean(layer_output)) + self.sow(nnx.Intermediate, "activation_stdev", jnp.std(layer_output)) + self.sow( + nnx.Intermediate, + "activation_fraction_zero", + jnp.sum(layer_output == 0) / jnp.size(layer_output), + ) + + if self.config.scan_layers: + return layer_output, None + return layer_output, kv_cache + + +class DeepSeekDenseLayer(DeepSeekGenericLayer): + """DeepSeek-style dense layer with Multi-Head Latent Attention.""" + + def __init__( + self, + config: Config, + model_mode: str, + mesh: Mesh, + rngs: nnx.Rngs, + quant: 
Optional[quantizations.AqtQuantization] = None, + layer_idx: int = -1, + ) -> None: + super().__init__(config, model_mode, mesh, rngs, quant, layer_idx) + + # Dense MLP Block uses emb_dim as input and output does not go through the + # bottleneck. + # Input Shape: [Batch, SeqLen, emb_dim] (e.g., [B, S, 7168]) + # Output Shape: [Batch, SeqLen, emb_dim] (e.g., [B, S, 7168]) + mlp_in_features = self.dummy_inputs_shape[-1] + + max_logging.log(f" [Layer {layer_idx} - Dense] Feature Sizes -> " + f"MLP In: {mlp_in_features}, " + f"MLP Dim: {self.config.mlp_dim}, ") + + self.mlp = linears.MlpBlock( + in_features=mlp_in_features, + intermediate_dim=self.config.mlp_dim, + activations=self.config.mlp_activations, + intermediate_dropout_rate=self.config.dropout_rate, + dtype=self.config.dtype, + weight_dtype=self.config.weight_dtype, + config=self.config, + quant=quant, + model_mode=model_mode, + mesh=mesh, + rngs=self.rngs, + ) + + def mlp_op(self, x, deterministic): + mlp = self.mlp(x, deterministic, intermediate_sharding=self.mlp_intermediate_sharding, out_sharding=self.out_sharding) + return self.with_logical_constraint(mlp) + + def __call__( + self, + inputs, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + previous_chunk=None, + page_state: None | page_manager.PageState = None, + slot: None | int = None, + kv_cache=None, + attention_metadata=None, + decoder_input_tokens=None, + ): + # Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache)) + if isinstance(inputs, tuple): + inputs = inputs[0] + + x = self.with_logical_constraint(inputs) + x = checkpoint_name(x, "decoder_layer_input") + + # 1. Attention: Takes [B, S, base_emb_dim (e.g. 7168)] tokens and outputs + # [B, S, attention_output_dim (e.g. 3072)] tokens directly. + attn_out = self.attention_op( + self.pre_attention_norm_op(x), + decoder_segment_ids, + decoder_positions, + deterministic, + previous_chunk, + page_state, + slot, + ) + + # 2. 
MLP processing: Takes [B, S, attention_output_dim (e.g. 3072)] tokens directly + # skipping any intermediate residual and feeds into the MLP block. + # Outputs [B, S, attention_output_dim] tokens. + mlp_lnx = self.mlp_op(attn_out, deterministic) + layer_output = mlp_lnx + + # 3. Final Projection: Maps the [B, S, attention_output_dim] + # combined output back to [B, S, base_emb_dim (e.g. 7168)]. + # This corresponds to the transition from the latent/compressed space back to the main model dimension. + if self.layer_up_projection is not None: + layer_output = self.layer_up_projection(layer_output) + layer_output = self.with_logical_constraint(layer_output) + + layer_output = self.dropout_op(layer_output, deterministic=deterministic) + + # 4. Single residual connection for the whole layer + x = inputs + layer_output + hidden_states = self.post_attention_norm_op(x) + + return self.post_process(hidden_states, None, None, kv_cache) + + +DeepSeekDenseLayerToLinen = nnx_wrappers.to_linen_class( + DeepSeekDenseLayer, + base_metadata_fn=initializers.variable_to_logically_partitioned, +) + + +class DeepSeekMoELayer(DeepSeekGenericLayer): + """DeepSeek-style MoE layer with Multi-Head Latent Attention. + + Supports dropless and dropping based on configs. Uses a bias in routing instead + of load balancing loss. 
+ """ + + def __init__( + self, + config: Config, + model_mode: str, + mesh: Mesh, + rngs: nnx.Rngs, + quant: Optional[quantizations.AqtQuantization] = None, + layer_idx: int = -1, + ) -> None: + super().__init__(config, model_mode, mesh, rngs, quant, layer_idx) + if config.attention_output_dim <= 0 or config.attention_output_dim != config.moe_model_dim: + raise ValueError("attention_output_dim must be positive and equal to moe_model_dim for DeepSeekMoELayer.") + + max_logging.log(f" [Layer {layer_idx} - MoE] Feature Sizes -> " + f"Emb Dim: {self.dummy_inputs_shape[-1]}, " + f"Attn Out: {config.attention_output_dim}, " + f"MoE In: {config.moe_model_dim}") + + self.DeepSeekMoeBlock_0 = moe.RoutedAndSharedMoE( + config=self.config, + mesh=mesh, + kernel_init=initializers.nd_dense_init(1.0, "fan_in", "truncated_normal"), + kernel_axes=("embed", None), + dtype=self.config.dtype, + weight_dtype=self.config.weight_dtype, + quant=quant, + rngs=self.rngs, + ) + + def __call__( + self, + inputs, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + previous_chunk=None, + page_state: None | page_manager.PageState = None, + slot: None | int = None, + kv_cache=None, + attention_metadata=None, + decoder_input_tokens=None, + ): + # Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache)) + if isinstance(inputs, tuple): + inputs = inputs[0] + + x = self.with_logical_constraint(inputs) + x = checkpoint_name(x, "decoder_layer_input") + + # ========================================================================= + # 1. ATTENTION (Down-Projection) + # Input Shape: [Batch, SeqLen, emb_dim] (e.g., [B, S, 7168]) + # Output Shape: [Batch, SeqLen, moe_model_dim] (e.g., [B, S, 3072]) + # The attention block implicitly acts as our down-projection bottleneck. 
+ # ========================================================================= + attn_out = self.attention_op( + self.pre_attention_norm_op(x), + decoder_segment_ids, + decoder_positions, + deterministic, + previous_chunk, + page_state, + slot, + ) + + # ========================================================================= + # 2. MIXTURE OF EXPERTS (A2A Communication & Computation) + # Input Shape: [Batch, SeqLen, moe_model_dim] (e.g., [B, S, 3072]) + # + # Routing Flow inside `self.mlp_op`: + # a. Token Routing (A2A Dispatch): Tokens are routed to experts across devices. + # Because we pass `attn_out` directly without restoring to `emb_dim`, + # the A2A payload is only moe_model_dim per token, saving massive bandwidth. + # b. Expert Compute: [B, S, moe_model_dim] -> [B, S, expert_hidden] -> [B, S, moe_model_dim] + # c. Token Return (A2A Combine): Processed tokens are returned to their + # original devices. The payload over the network remains moe_model_dim. + # + # Output Shape: [Batch, SeqLen, moe_model_dim] (e.g., [B, S, 3072]) + # ========================================================================= + + mlp_lnx, load_balance_loss, moe_bias_updates = self.mlp_op(attn_out, deterministic) + layer_output = mlp_lnx + + # ========================================================================= + # 3. FINAL UP-PROJECTION + # Input Shape: [Batch, SeqLen, moe_model_dim] (e.g., [B, S, 3072]) + # Output Shape: [Batch, SeqLen, emb_dim] (e.g., [B, S, 7168]) + # Restores the token dimension before the outer residual connection. + # ========================================================================= + if self.layer_up_projection is not None: + layer_output = self.layer_up_projection(layer_output) + layer_output = self.with_logical_constraint(layer_output) + + layer_output = self.dropout_op(layer_output, deterministic=deterministic) + + # ========================================================================= + # 4. 
RESIDUAL CONNECTION + # [Batch, SeqLen, 7168] + [Batch, SeqLen, 7168] + # ========================================================================= + x = inputs + layer_output + hidden_states = self.post_attention_norm_op(x) + + return self.post_process(hidden_states, load_balance_loss, moe_bias_updates, kv_cache) + + def mlp_op(self, x, deterministic, *args, **kwargs): + mlp_lnx, load_balance_loss, moe_bias_updates = self.DeepSeekMoeBlock_0( + x, intermediate_sharding=self.mlp_intermediate_sharding, out_sharding=self.out_sharding + ) + return self.with_logical_constraint(mlp_lnx), load_balance_loss, moe_bias_updates + + +DeepSeekMoELayerToLinen = nnx_wrappers.to_linen_class( + DeepSeekMoELayer, + base_metadata_fn=initializers.variable_to_logically_partitioned, +) + + +class DeepSeekMoEScannableBlock(nnx.Module): + """A repeatable block of DeepSeek Custom MoE layers.""" + + def __init__( + self, + config: Config, + model_mode: str, + mesh: Mesh, + rngs: nnx.Rngs, + quant: Optional[quantizations.AqtQuantization] = None, + ): + self.config = config + self.mesh = mesh + self.model_mode = model_mode + self.quant = quant + self.rngs = rngs + + for i in range(self.config.inhomogeneous_layer_cycle_interval): + layer_idx = self.config.first_num_dense_layers + i + layer_name = f"layers_{i}" + layer = DeepSeekMoELayer( + config=config, + model_mode=model_mode, + mesh=mesh, + rngs=rngs, + quant=quant, + layer_idx=layer_idx, + ) + setattr(self, layer_name, layer) + + def __call__( + self, + inputs, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + previous_chunk=None, + page_state: None | page_manager.PageState = None, + slot: None | int = None, + kv_cache=None, + attention_metadata=None, + decoder_input_tokens=None, + ): + cfg = self.config + y = inputs + for i in range(cfg.inhomogeneous_layer_cycle_interval): + layer = getattr(self, f"layers_{i}") + y = layer( + y, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + 
previous_chunk=previous_chunk, + page_state=page_state, + slot=slot, + kv_cache=kv_cache, + attention_metadata=attention_metadata, + decoder_input_tokens=decoder_input_tokens, + ) + if cfg.scan_layers: + y = y[0] + if cfg.scan_layers: + return y, None + else: + return y + + +DeepSeekMoEScannableBlockToLinen = nnx_wrappers.to_linen_class( + DeepSeekMoEScannableBlock, + base_metadata_fn=initializers.variable_to_logically_partitioned, +) diff --git a/src/maxtext/utils/maxtext_utils.py b/src/maxtext/utils/maxtext_utils.py index dab8103a4f..e27dc2ed87 100644 --- a/src/maxtext/utils/maxtext_utils.py +++ b/src/maxtext/utils/maxtext_utils.py @@ -214,19 +214,14 @@ def calculate_gemma2_tflops_training_per_device(config, total_ffn_flops, qkv_flo Calculate training TFLOP for Gemma2 as in Gemma2 we combine [local_attention, global_attention] into one decoder layer and we use sliding window attention in local_attention """ - noncausal_attention_flops = ( - # global attention - 4 * config.per_device_batch_size * config.max_target_length**2 * config.num_query_heads * config.head_dim - + - # local attention - 4 - * config.per_device_batch_size - * config.max_target_length - * min(config.sliding_window_size, config.max_target_length) - * config.num_query_heads - * config.head_dim - ) - causal_attention_flops = noncausal_attention_flops / 2 + S = config.max_target_length + W = min(config.sliding_window_size, S) + + global_attention_flops = 4 * config.per_device_batch_size * (S**2) * config.num_query_heads * config.head_dim / 2 + + local_attention_flops = 4 * config.per_device_batch_size * S * W * config.num_query_heads * config.head_dim + + causal_attention_flops = global_attention_flops + local_attention_flops attention_tflops = causal_attention_flops * config.num_decoder_layers * 3 / 10**12 # multiply num_decoder_layers by 2 because we combine [local_attention, global_attention] into one decoder layer @@ -253,7 +248,7 @@ def 
calculate_mixed_attention_model_tflops_training_per_device( # Formula: 4 * batch_size * seq_len^2 * num_heads * head_dim global_attention_flops_per_layer = ( 4 * config.per_device_batch_size * config.max_target_length**2 * config.num_query_heads * config.head_dim - ) + ) / 2 # FLOPs for a single local attention layer (sliding window) # Formula: 4 * batch_size * seq_len * window_size * num_heads * head_dim @@ -266,11 +261,10 @@ def calculate_mixed_attention_model_tflops_training_per_device( * config.head_dim ) - # Total attention FLOPs = (num_global_layers * FLOPs_per_global) + (num_local_layers * FLOPs_per_local) - noncausal_attention_flops = ( + # Total causal attention FLOPs = (num_global_layers * FLOPs_per_global) + (num_local_layers * FLOPs_per_local) + causal_attention_flops = ( num_global_layers * global_attention_flops_per_layer + num_local_layers * local_attention_flops_per_layer ) - causal_attention_flops = noncausal_attention_flops / 2 # Convert to TFLOPs and multiply by 3 for fwd/bwd pass attention_tflops = causal_attention_flops * 3 / 10**12 @@ -456,7 +450,7 @@ def calculate_mla_tflops_per_device(config): return qkv_flops, attention_flops, projection_flops -def calculate_ffn_mamtul_tflops_per_device(config, mlp_dim): +def calculate_ffn_mamtul_tflops_per_device(config, mlp_dim, in_features=None): """Helper function to calculate matmul TFLOP in ffn based on MLP dimension. Applies to: @@ -464,21 +458,32 @@ def calculate_ffn_mamtul_tflops_per_device(config, mlp_dim): - MoE FFN layers (mlp_dim = config.moe_mlp_dim), need to scale by shared_experts or num_experts_per_tok. 
""" + if in_features is None: + in_features = config.emb_dim ffn1_flops = ( - 2 * config.per_device_batch_size * config.max_target_length * mlp_dim * config.emb_dim * len(config.mlp_activations) + 2 * config.per_device_batch_size * config.max_target_length * mlp_dim * in_features * len(config.mlp_activations) ) - ffn2_flops = 2 * config.per_device_batch_size * config.max_target_length * mlp_dim * config.emb_dim + ffn2_flops = 2 * config.per_device_batch_size * config.max_target_length * mlp_dim * in_features return ffn1_flops + ffn2_flops def calculate_routed_and_shared_ffn_tflops_per_device(config): """Helper function to calculate DeepSeek-style ffn TFLOP""" - gate_flops = 2 * config.per_device_batch_size * config.max_target_length * config.emb_dim * config.num_experts + if config.decoder_block == DecoderBlockType.DEEPSEEK_CUSTOM: + in_features = config.moe_model_dim if config.moe_model_dim > 0 else config.emb_dim + dense_in_features = config.emb_dim + shared_expert_mlp_dim = config.shared_expert_mlp_dim if config.shared_expert_mlp_dim > 0 else config.moe_mlp_dim * config.shared_experts + else: + in_features = config.emb_dim + dense_in_features = config.emb_dim + shared_expert_mlp_dim = config.moe_mlp_dim * config.shared_experts + + gate_flops = 2 * config.per_device_batch_size * config.max_target_length * in_features * config.num_experts # Due to the mixed decoder layers, the flops is multiplied by num of layers for both dense and moe num_dense_layers, num_moe_layers = get_dense_moe_layers(config) - dense_ffn_flops = calculate_ffn_mamtul_tflops_per_device(config, config.mlp_dim) * num_dense_layers - shared_experts_flops = calculate_ffn_mamtul_tflops_per_device(config, config.moe_mlp_dim) * config.shared_experts - routed_experts_flops = calculate_ffn_mamtul_tflops_per_device(config, config.moe_mlp_dim) * config.num_experts_per_tok + dense_ffn_flops = calculate_ffn_mamtul_tflops_per_device(config, config.mlp_dim, dense_in_features) * num_dense_layers + 
shared_experts_flops = calculate_ffn_mamtul_tflops_per_device(config, shared_expert_mlp_dim, in_features) + routed_experts_flops = calculate_ffn_mamtul_tflops_per_device(config, config.moe_mlp_dim, in_features) * config.num_experts_per_tok moe_ffn_flops = (gate_flops + shared_experts_flops + routed_experts_flops) * num_moe_layers total_ffn_flops = dense_ffn_flops + moe_ffn_flops return total_ffn_flops @@ -486,7 +491,7 @@ def calculate_routed_and_shared_ffn_tflops_per_device(config): def get_dense_moe_layers(config): """Helper function to calculate number of dense and moe layers""" - if config.decoder_block == DecoderBlockType.DEEPSEEK: + if config.decoder_block in (DecoderBlockType.DEEPSEEK, DecoderBlockType.DEEPSEEK_CUSTOM): num_dense_layers = config.first_num_dense_layers num_moe_layers = config.num_decoder_layers - config.first_num_dense_layers return num_dense_layers, num_moe_layers @@ -502,6 +507,66 @@ def get_dense_moe_layers(config): return num_dense_layers, num_moe_layers +def calculate_deepseek_custom_attention_and_proj_tflops(config): + """Calculates attention and projection FLOPs for DeepSeek Custom model layer by layer.""" + total_qkv_proj_flops = 0 + total_causal_attention_flops = 0 + + B = config.per_device_batch_size + S = config.max_target_length + E = config.emb_dim + D_head = config.head_dim + + for layer_idx in range(config.num_decoder_layers): + is_dense_layer = layer_idx < config.first_num_dense_layers + is_global_attention = False + + if not is_dense_layer and config.attention_layer_hybrid_ratio > 0: + is_global_attention = (layer_idx + 1) % config.attention_layer_hybrid_ratio == 0 + + if is_global_attention and config.global_num_query_heads > 0: + H_q = config.global_num_query_heads + H_kv = config.global_num_kv_heads + W = None + elif not is_global_attention and config.local_num_query_heads > 0: + H_q = config.local_num_query_heads + H_kv = config.local_num_kv_heads + W = config.sliding_window_size + else: + H_q = config.base_num_query_heads 
+ H_kv = config.base_num_kv_heads + W = None + + attention_output_dim = config.attention_output_dim if (not is_dense_layer and config.attention_output_dim > 0) else E + + # QKV projection + # Q, K, V + qkv_flops = 2 * B * S * E * (H_q + 2 * H_kv) * D_head + + # Attention + if W is not None: + # Local Sliding window attention + # Formula: 4 * B * S * W * H_q * D_head + causal_attention_flops = 4 * B * S * W * H_q * D_head + else: + # Global attention + noncausal_attention_flops = 4 * B * S * (S ** 2) * H_q * D_head / S + causal_attention_flops = noncausal_attention_flops / 2 + + # Out projection + projection_flops = 2 * B * S * H_q * D_head * attention_output_dim + + # Up projection (if MoE layer) + up_projection_flops = 0 + if not is_dense_layer and config.attention_output_dim > 0 and config.attention_output_dim != E: + up_projection_flops = 2 * B * S * config.attention_output_dim * E + + total_qkv_proj_flops += (qkv_flops + projection_flops + up_projection_flops) + total_causal_attention_flops += causal_attention_flops + + return total_qkv_proj_flops, total_causal_attention_flops + + def calculate_gated_delta_net_flops_per_device(config): """ - Calculates the FLOPs for a single Gated Delta Net (Linear Attention) layer. 
@@ -725,7 +790,7 @@ def calculate_tflops_training_per_device(config, log=True): # MLP flops if config.num_experts > 1: # calculation based on dropless implementation - if config.decoder_block in (DecoderBlockType.DEEPSEEK, DecoderBlockType.LLAMA4, DecoderBlockType.QWEN3_NEXT): + if config.decoder_block in (DecoderBlockType.DEEPSEEK, DecoderBlockType.LLAMA4, DecoderBlockType.QWEN3_NEXT, DecoderBlockType.DEEPSEEK_CUSTOM): total_ffn_flops = calculate_routed_and_shared_ffn_tflops_per_device(config) else: gate_flops = 2 * config.per_device_batch_size * config.max_target_length * config.emb_dim * config.num_experts @@ -793,6 +858,10 @@ def calculate_tflops_training_per_device(config, log=True): (total_ffn_flops + (qkv_flops + projection_flops) * config.num_decoder_layers + embedding_flops) * 3 / 10**12 ) attention_tflops = causal_attention_flops * config.num_decoder_layers * 3 / 10**12 + elif config.decoder_block == DecoderBlockType.DEEPSEEK_CUSTOM: + total_qkv_proj_flops, total_causal_attention_flops = calculate_deepseek_custom_attention_and_proj_tflops(config) + learnable_weight_tflops = (total_ffn_flops + total_qkv_proj_flops + embedding_flops) * 3 / 10**12 + attention_tflops = total_causal_attention_flops * 3 / 10**12 elif config.decoder_block == DecoderBlockType.QWEN3_NEXT: gdn_weight_flops_per_layer, gdn_attn_flops_per_layer = calculate_gated_delta_net_flops_per_device(config) cycle_interval = config.inhomogeneous_layer_cycle_interval diff --git a/tests/integration/smoke/train_smoke_test.py b/tests/integration/smoke/train_smoke_test.py index 3ed0b40c14..f317a928e2 100644 --- a/tests/integration/smoke/train_smoke_test.py +++ b/tests/integration/smoke/train_smoke_test.py @@ -42,26 +42,69 @@ def test_tiny_config(self): train_main( [ None, - get_test_config_path(), + get_test_config_path("base.yml"), + "model_name=deepseek3-custom-small", # pylint: disable=f-string-without-interpolation f"base_output_directory={self.base_output_directory}", "run_name=runner_test", 
r"dataset_path={self.dataset_path}", - "base_emb_dim=8", + "base_emb_dim=128", + "base_num_query_heads=4", + "base_num_kv_heads=4", + "base_mlp_dim=32", + "base_num_decoder_layers=9", + "head_dim=128", + "attention_output_dim=64", + "moe_model_dim=64", + "per_device_batch_size=2", + "max_target_length=1024", + "dataset_type=synthetic", + "steps=10", + "skip_jax_distributed_system=True", + "attention=dot_product", + "enable_checkpointing=False", + rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}", + "enable_goodput_recording=False", + "enable_checkpoint_cloud_logger=False", + "monitor_goodput=False", + ] + ) + + def test_tiny_config_deepseek_custom_scan(self): + test_tmpdir = os.environ.get("TEST_TMPDIR") # pylint: disable=unused-variable + train_main( + [ + None, + get_test_config_path("base.yml"), + "model_name=deepseek3-custom-small", + "decoder_block=deepseek_custom", + # pylint: disable=f-string-without-interpolation + f"base_output_directory={self.base_output_directory}", + "run_name=runner_test", + r"dataset_path={self.dataset_path}", + "base_emb_dim=128", "base_num_query_heads=4", "base_num_kv_heads=4", "base_mlp_dim=32", - "base_num_decoder_layers=8", + "base_num_decoder_layers=9", + "first_num_dense_layers=1", "head_dim=128", + "attention_output_dim=64", + "moe_model_dim=64", "per_device_batch_size=2", "max_target_length=1024", "dataset_type=synthetic", "steps=10", + "skip_jax_distributed_system=True", + "attention=dot_product", "enable_checkpointing=False", rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}", "enable_goodput_recording=False", "enable_checkpoint_cloud_logger=False", "monitor_goodput=False", + "scan_layers=True", + "inhomogeneous_layer_cycle_interval=2", + "attention_layer_hybrid_ratio=2", ] ) @@ -71,18 +114,23 @@ def test_tiny_config_no_scan(self): [ None, get_test_config_path(), + "model_name=deepseek3-custom-small", # pylint: 
disable=f-string-without-interpolation f"base_output_directory={self.base_output_directory}", "run_name=runner_test", r"dataset_path={self.dataset_path}", - "base_emb_dim=8", + "base_emb_dim=128", "base_num_query_heads=4", "base_num_kv_heads=4", "base_mlp_dim=32", - "base_num_decoder_layers=8", + "base_num_decoder_layers=9", "head_dim=128", + "attention_output_dim=64", + "moe_model_dim=64", "per_device_batch_size=2", "max_target_length=1024", + "inhomogeneous_layer_cycle_interval=2", + "attention_layer_hybrid_ratio=2", "dataset_type=synthetic", "steps=10", "enable_checkpointing=False", @@ -108,8 +156,10 @@ def test_tiny_config_explicit_shardmode(self): "base_num_query_heads=4", "base_num_kv_heads=4", "base_mlp_dim=32", - "base_num_decoder_layers=8", + "base_num_decoder_layers=9", "head_dim=128", + "inhomogeneous_layer_cycle_interval=2", + "attention_layer_hybrid_ratio=2", "per_device_batch_size=2", "max_target_length=1024", "dataset_type=synthetic", @@ -123,6 +173,115 @@ def test_tiny_config_explicit_shardmode(self): ] ) + def test_tiny_config_moe_megablox(self): + test_tmpdir = os.environ.get("TEST_TMPDIR") # pylint: disable=unused-variable + train_main( + [ + None, + get_test_config_path(), + "model_name=deepseek3-custom-small", + # pylint: disable=f-string-without-interpolation + f"base_output_directory={self.base_output_directory}", + "run_name=runner_test", + r"dataset_path={self.dataset_path}", + "base_emb_dim=128", + "base_num_query_heads=4", + "base_num_kv_heads=4", + "base_mlp_dim=32", + "base_num_decoder_layers=9", + "head_dim=128", + "attention_output_dim=128", + "moe_model_dim=128", + "base_moe_mlp_dim=128", + "shared_expert_mlp_dim=128", + "per_device_batch_size=2", + "max_target_length=1024", + "dataset_type=synthetic", + "steps=10", + "enable_checkpointing=False", + rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}", + "enable_goodput_recording=False", + "enable_checkpoint_cloud_logger=False", + 
"monitor_goodput=False", + "sparse_matmul=True", + "megablox=True", + "use_tokamax_gmm=False", + ] + ) + + def test_tiny_config_moe_tokamax(self): + test_tmpdir = os.environ.get("TEST_TMPDIR") # pylint: disable=unused-variable + train_main( + [ + None, + get_test_config_path(), + "model_name=deepseek3-custom-small", + # pylint: disable=f-string-without-interpolation + f"base_output_directory={self.base_output_directory}", + "run_name=runner_test", + r"dataset_path={self.dataset_path}", + "base_emb_dim=128", + "base_num_query_heads=4", + "base_num_kv_heads=4", + "base_mlp_dim=32", + "base_num_decoder_layers=9", + "head_dim=128", + "attention_output_dim=128", + "moe_model_dim=128", + "base_moe_mlp_dim=128", + "shared_expert_mlp_dim=128", + "per_device_batch_size=2", + "max_target_length=1024", + "inhomogeneous_layer_cycle_interval=1", + "attention_layer_hybrid_ratio=1", + "dataset_type=synthetic", + "steps=10", + "enable_checkpointing=False", + rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}", + "enable_goodput_recording=False", + "enable_checkpoint_cloud_logger=False", + "monitor_goodput=False", + "sparse_matmul=True", + "megablox=False", + "use_tokamax_gmm=True", + ] + ) + + def test_tiny_config_moe_vanilla(self): + test_tmpdir = os.environ.get("TEST_TMPDIR") # pylint: disable=unused-variable + train_main( + [ + None, + get_test_config_path(), + "model_name=deepseek3-custom-small", + # pylint: disable=f-string-without-interpolation + f"base_output_directory={self.base_output_directory}", + "run_name=runner_test", + r"dataset_path={self.dataset_path}", + "base_emb_dim=128", + "base_num_query_heads=4", + "base_num_kv_heads=4", + "base_mlp_dim=32", + "base_num_decoder_layers=9", + "head_dim=128", + "attention_output_dim=128", + "moe_model_dim=128", + "base_moe_mlp_dim=128", + "shared_expert_mlp_dim=128", + "per_device_batch_size=2", + "max_target_length=1024", + "dataset_type=synthetic", + "steps=10", + 
"enable_checkpointing=False", + rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}", + "enable_goodput_recording=False", + "enable_checkpoint_cloud_logger=False", + "monitor_goodput=False", + "sparse_matmul=True", + "megablox=False", + "use_tokamax_gmm=False", + ] + ) if __name__ == "__main__": absltest.main() diff --git a/tests/unit/flop_calculation_test.py b/tests/unit/flop_calculation_test.py index db839e066d..fc8268e893 100644 --- a/tests/unit/flop_calculation_test.py +++ b/tests/unit/flop_calculation_test.py @@ -97,6 +97,80 @@ def compute_gpt_attention_flops_per_device(self, kwargs: dict) -> float: return attention_flops / 1e12 # return tflops + def compute_deepseek_custom_flops_training_per_device(self, kwargs: dict) -> float: + """Computes the total training TFLOPs per device for DeepSeek Custom model.""" + B = kwargs["per_device_batch_size"] + S = kwargs["max_target_length"] + E = kwargs["base_emb_dim"] + + total_ffn_flops = 0 + total_qkv_proj_flops = 0 + total_causal_attention_flops = 0 + + num_layers = kwargs["base_num_decoder_layers"] + for layer_idx in range(num_layers): + is_dense = layer_idx < kwargs["first_num_dense_layers"] + + is_global = False + if not is_dense and kwargs["attention_layer_hybrid_ratio"] > 0: + is_global = (layer_idx + 1) % kwargs["attention_layer_hybrid_ratio"] == 0 + + # FFN + if is_dense: + # Dense layers use full emb_dim for both in and out. 
+ # FFN1 + FFN2 ops: 2 (matmul) * 3 (due to silu + linear activations) + total_ffn_flops += 2 * B * S * kwargs["base_mlp_dim"] * E * 3 + else: + # MoE layers operate over moe_model_dim rather than emb_dim + in_features = kwargs["moe_model_dim"] + gate = 2 * B * S * in_features * kwargs["num_experts"] + shared = 2 * B * S * kwargs["shared_expert_mlp_dim"] * in_features * 3 + routed = 2 * B * S * kwargs["base_moe_mlp_dim"] * in_features * 3 * kwargs["num_experts_per_tok"] + total_ffn_flops += gate + shared + routed + + # Attention Configuration Based on Hybrid Layer type + if is_global and kwargs["global_num_query_heads"] > 0: + H_q = kwargs["global_num_query_heads"] + H_kv = kwargs["global_num_kv_heads"] + W = None + elif not is_global and kwargs["local_num_query_heads"] > 0: + H_q = kwargs["local_num_query_heads"] + H_kv = kwargs["local_num_kv_heads"] + W = kwargs["sliding_window_size"] + else: + H_q = kwargs["base_num_query_heads"] + H_kv = kwargs["base_num_kv_heads"] + W = None + + attention_output_dim = kwargs["attention_output_dim"] if (not is_dense and kwargs["attention_output_dim"] > 0) else E + + # Attention Math + # QKV proj + qkv_flops = 2 * B * S * E * (H_q + 2 * H_kv) * kwargs["head_dim"] + + # Core Attention (QK, SV) + if W is not None: + caus_attn = 4 * B * S * W * H_q * kwargs["head_dim"] + else: + caus_attn = (4 * B * S * S * H_q * kwargs["head_dim"]) / 2 + + # Out Proj + projection_flops = 2 * B * S * H_q * kwargs["head_dim"] * attention_output_dim + + # Up proj (only for MoE layer) + up_proj = 0 + if not is_dense and kwargs["attention_output_dim"] > 0 and kwargs["attention_output_dim"] != E: + up_proj = 2 * B * S * kwargs["attention_output_dim"] * E + + total_qkv_proj_flops += qkv_flops + projection_flops + up_proj + total_causal_attention_flops += caus_attn + + embedding_flops = 2 * B * S * E * kwargs["vocab_size"] + + learnable_tflops = (total_ffn_flops + total_qkv_proj_flops + embedding_flops) * 3 / 1e12 + attn_tflops = 
total_causal_attention_flops * 3 / 1e12 + return learnable_tflops + attn_tflops + def compute_qwen3_next_attention_flops_per_device(self, kwargs: dict) -> float: """ Computes the total training TFLOPs per device for a Qwen3-Next model. @@ -560,3 +634,51 @@ def test_custom_engram_flops(self): ) calculated_tflops, _, _ = calculate_tflops_training_per_device(cfg) self.assertFlopsAlmostEqual(calculated_tflops, golden_tflops) + + @pytest.mark.cpu_only + def test_deepseek_custom_flops(self): + """Test DeepSeek-Custom FLops calculation""" + kwargs = { + # Model bases + "model_name": "deepseek-custom", + "override_model_config": True, + # Core workload parameters + "per_device_batch_size": 4, + "max_target_length": 8192, + "num_experts": 256, + "num_experts_per_tok": 4, + "shared_experts": 1, + # Model dimensions + "base_emb_dim": 7168, + "moe_model_dim": 3072, + "base_moe_mlp_dim": 6144, + "shared_expert_mlp_dim": 15360, + "attention_output_dim": 3072, + "attention_layer_hybrid_ratio": 2, + "inhomogeneous_layer_cycle_interval": 2, + "head_dim": 256, + "local_num_query_heads": 64, + "local_num_kv_heads": 8, + "sliding_window_size": 1024, + "global_num_query_heads": 64, + "global_num_kv_heads": 4, + "base_num_decoder_layers": 61, + "first_num_dense_layers": 3, + "base_num_query_heads": 64, + "base_num_kv_heads": 8, + "base_mlp_dim": 6144, + "mlp_activations": ["silu", "linear"], + "vocab_size": 102400, + "skip_jax_distributed_system": True, + "decoder_block": "deepseek_custom", + } + + golden_tflops = self.compute_deepseek_custom_flops_training_per_device(kwargs) + + cfg = pyconfig.initialize( + [None, get_test_config_path()], + **kwargs, + ) + calculated_tflops, _, _ = calculate_tflops_training_per_device(cfg) + self.assertFlopsAlmostEqual(calculated_tflops, golden_tflops) +