diff --git a/examples/experiments/paddlefleet/deepseek_v3_2_provider.py b/examples/experiments/paddlefleet/deepseek_v3_2_provider.py new file mode 100644 index 00000000000..dc73aee3623 --- /dev/null +++ b/examples/experiments/paddlefleet/deepseek_v3_2_provider.py @@ -0,0 +1,286 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +DeepSeek V3.2 Model Providers for PaddleFleet-based pretraining. + +Architecture: MLA (Multi-Latent Attention) + DSA Indexer (DeepSeek Sparse Attention) + + MoE (Mixture of Experts) + MTP (Multi-Token Prediction) + +Reference: DeepSeek-V3.2-Exp/inference/model.py +Config: DeepSeek-V3.2-Exp/inference/config_671B_v3.2.json + +Usage: + provider = DeepSeekV3_2_671BProvider() + model = provider.provide(loss_fn=loss_fn) + +Pattern follows glm45_provider.py exactly. +""" + +import logging +from dataclasses import dataclass, field +from typing import Callable, List, Optional, Union + +import paddle +import paddle.nn.functional as F + +from paddleformers.transformers.gpt_provider import GPTModelProvider + +logger = logging.getLogger(__name__) + + +@dataclass +class DeepSeekV3_2BaseProvider(GPTModelProvider): + """ + Base provider for DeepSeek V3.2 architecture. 
+ + Key components: + - MLA: Multi-Latent Attention with low-rank KV compression + - DSA: DeepSeek Sparse Attention (Indexer selects top-2048 tokens per query) + - MoE: Mixture of Experts with group-limited routing + - MTP: Multi-Token Prediction auxiliary loss + """ + + # ---- Normalization and activation ---- + normalization: str = "RMSNorm" + hidden_act: Callable = F.silu + gated_linear_unit: bool = True + use_bias: bool = False + attention_bias: bool = False + rms_norm_eps: float = 1e-6 + + # ---- Precision ---- + autocast_dtype: paddle.dtype = paddle.bfloat16 + params_dtype: paddle.dtype = paddle.bfloat16 + bf16: bool = True + + # ---- Embedding ---- + tie_word_embeddings: bool = False + + # ---- Sequence ---- + seq_length: int = 4096 + max_sequence_length: int = 4096 + hidden_dropout_prob: float = 0.0 + attention_dropout: float = 0.0 + init_method_std: float = 0.006 # ~1/sqrt(7168) + + # ---- MLA: Multi-Latent Attention ---- + # MLA de-interleave in rope_utils is NOT needed when rotary_interleaved=True, + # because _rotate_half(interleaved=True) already pairs adjacent dims correctly + # (matching DeepSeek-V3.2 reference apply_rotary_emb(interleaved=True)). 
+ multi_latent_attention: bool = False + num_attention_heads: int = 128 + # head_dim matches v_head_dim=128 so o_proj sizing in Attention base is correct + head_dim: int = 128 + # num_key_value_heads must be set for Attention base class; + # in MLA, KV is latent-compressed but we set this equal to num_attention_heads + # so TP sharding logic in Attention.__init__ works correctly + num_key_value_heads: int = 128 + + # MLA low-rank projection dimensions (matches DeepSeek V3.2 671B config) + q_lora_rank: int = 1536 # wq_a: hidden -> q_lora_rank + kv_lora_rank: int = 512 # wkv_a: hidden -> kv_lora_rank + qk_rope_head_dim + qk_nope_head_dim: int = 128 # per-head non-RoPE Q/K dim + qk_rope_head_dim: int = 64 # per-head RoPE Q/K dim + v_head_dim: int = 128 # per-head V dim (= head_dim, so o_proj ok) + + # ---- DSA: DeepSeek Sparse Attention Indexer ---- + # Non-None activates the DeepSeek V3.2 path in gpt_builders.py + # Field names mirror HuggingFace config.json keys for zero-copy from_config(). 
+ index_n_heads: int = 64 # Indexer scoring heads + index_head_dim: int = 128 # Indexer Q/K head dim + index_topk: int = 2048 # Tokens selected per query + # KL loss trains wq_b/wk/weights_proj via KL(true_attn_dist || indexer_dist) + # Coefficient ~0.01 matches Megatron-Core default; set to None to disable + indexer_loss_coeff: float = 0.01 + indexer_use_sparse_loss: bool = False # use full-sequence KL (denser gradients) + + # ---- RoPE ---- + position_embedding_type: str = "rope" + # DeepSeek V3.2 uses YaRN-style RoPE with base 10000 + rotary_base: float = 10000.0 + # MLA uses interleaved RoPE; Indexer uses non-interleaved (handled internally) + # Setting rotary_interleaved=True here enables the interleaved path for MLA Q/K + rotary_interleaved: bool = True + # Disable fused RoPE kernel: MLA applies RoPE only to qk_rope_head_dim subspace, + # which is incompatible with the fused kernel that expects full head_dim + apply_rope_fusion: bool = False + # Use fp32 RoPE for numerical stability (matches reference implementation) + high_precision_rope: bool = True + + # ---- MoE routing ---- + scoring_func: str = "sigmoid" # Score experts with sigmoid + num_experts_per_tok: int = 8 # n_activated_experts + n_group: int = 8 # n_expert_groups: 256 experts / 8 groups = 32 per group + topk_group: int = 4 # n_limited_groups: select top-4 groups + routed_scaling_factor: float = 2.5 # route_scale: scale selected expert weights + topk_method: str = "group_limited_greedy" # group-limited top-k routing + norm_topk_prob: bool = True # normalize expert weights to sum to 1 + moe_token_dispatcher_type: str = "deepep" + moe_router_load_balancing_type: str = "seq_aux_loss" + moe_router_pre_softmax: bool = False + moe_expert_fusion: bool = False + moe_shared_expert_overlap: bool = True + moe_router_dtype: str = "fp32" + moe_router_enable_expert_bias: bool = True + moe_router_bias_update_rate: float = 0.0 + + # ---- MTP: Multi-Token Prediction ---- + # 1 MTP layer for auxiliary next-token 
prediction loss + num_nextn_predict_layers: Optional[int] = 1 + mtp_loss_scaling_factor: float = 0.1 # MTP loss weight + + # ---- Optimization ---- + persist_layer_norm: bool = True + bias_activation_fusion: bool = True + bias_dropout_fusion: bool = True + + +@dataclass +class DeepSeekV3_2_671BProvider(DeepSeekV3_2BaseProvider): + """ + Provider for DeepSeek V3.2 671B model (full production config). + + Architecture: + - 61 transformer layers: first 3 dense MLP + 58 MoE + - All layers use MLA + DSA Indexer attention + - 256 routed experts + 1 shared expert per MoE layer + + Config reference: DeepSeek-V3.2-Exp/inference/config_671B_v3.2.json + """ + + # ---- Model dimensions ---- + hidden_size: int = 7168 # dim + num_hidden_layers: int = 61 # n_layers + vocab_size: int = 129280 + + # ---- FFN dimensions ---- + intermediate_size: int = 18432 # inter_dim: dense MLP hidden size + moe_intermediate_size: int = 2048 # moe_inter_dim: per-expert MLP hidden size + + # ---- MoE architecture ---- + n_routed_experts: int = 256 + n_shared_experts: int = 1 + # Layer pattern: first 3 layers dense (0), then 58 MoE (1) + moe_layer_freq: Union[int, List[int]] = field(default_factory=lambda: [0] * 3 + [1] * 58) + + +@dataclass +class DeepSeekV3_2_671BDebugProvider(DeepSeekV3_2_671BProvider): + """ + Small debug variant of DeepSeek V3.2 for single-card validation. + + Reduces all dimensions to fit on a single GPU for smoke testing. + Pattern: 1 dense layer + 3 MoE layers. 
+ """ + + # ---- Reduced model dimensions ---- + num_hidden_layers: int = 4 + hidden_size: int = 1024 + vocab_size: int = 129280 + + # ---- Reduced attention dimensions ---- + num_attention_heads: int = 16 + num_key_value_heads: int = 16 + head_dim: int = 64 + q_lora_rank: int = 256 + kv_lora_rank: int = 128 + qk_nope_head_dim: int = 64 + qk_rope_head_dim: int = 32 + v_head_dim: int = 64 + + # ---- Reduced Indexer dimensions ---- + index_n_heads: int = 8 + index_head_dim: int = 64 + index_topk: int = 128 + indexer_loss_coeff: float = 0.01 + indexer_use_sparse_loss: bool = False + + # ---- Reduced FFN dimensions ---- + intermediate_size: int = 2048 + moe_intermediate_size: int = 512 + + # ---- Reduced MoE ---- + n_routed_experts: int = 8 + n_shared_experts: int = 1 + moe_layer_freq: Union[int, List[int]] = field(default_factory=lambda: [0] * 1 + [1] * 3) + + # ---- Disable MTP for simplicity ---- + num_nextn_predict_layers: Optional[int] = 0 + + # ---- Short sequence for debug ---- + seq_length: int = 512 + max_sequence_length: int = 512 + + # ---- Single card: no model parallel ---- + sequence_parallel: bool = False + expert_model_parallel_size: int = 1 + tensor_model_parallel_size: int = 1 + moe_router_force_load_balancing: bool = True + + +@dataclass +class DeepSeekV3_2_8GPUDebugProvider(DeepSeekV3_2BaseProvider): + """ + Debug provider for DeepSeek V3.2 on a single node with 8 GPUs. + + Scales up from the single-card DebugProvider to exercise multi-card + communication paths (all-reduce, all-gather, DeepEP routing) without + the memory footprint of the full 671B model. + + Key dimension constraints for parallelism: + num_attention_heads (32) and index_n_heads (16) must be + divisible by whatever tensor_model_parallel_size is used. + n_routed_experts (16) must be divisible by expert_model_parallel_size. + + Pattern: 2 dense layers + 6 MoE layers (8 total). 
+ """ + + # ---- Reduced model dimensions ---- + num_hidden_layers: int = 8 + hidden_size: int = 2048 + vocab_size: int = 129280 + + # ---- Reduced attention dimensions ---- + num_attention_heads: int = 32 # divisible by TP=1/2/4/8 + num_key_value_heads: int = 32 + head_dim: int = 64 + q_lora_rank: int = 512 + kv_lora_rank: int = 128 + qk_nope_head_dim: int = 64 + qk_rope_head_dim: int = 32 + v_head_dim: int = 64 + + # ---- Reduced Indexer dimensions ---- + index_n_heads: int = 16 # divisible by TP=1/2/4/8 + index_head_dim: int = 64 + index_topk: int = 256 + indexer_loss_coeff: float = 0.01 + indexer_use_sparse_loss: bool = False + + # ---- Reduced FFN dimensions ---- + intermediate_size: int = 4096 + moe_intermediate_size: int = 1024 + + # ---- Reduced MoE ---- + n_routed_experts: int = 16 # divisible by EP=1/2/4/8 + n_shared_experts: int = 1 + moe_layer_freq: Union[int, List[int]] = field(default_factory=lambda: [0] * 2 + [1] * 6) + + # ---- Disable MTP for simplicity ---- + num_nextn_predict_layers: Optional[int] = 0 + + # ---- Moderate sequence length ---- + seq_length: int = 1024 + max_sequence_length: int = 1024 diff --git a/paddleformers/datasets/template/template.py b/paddleformers/datasets/template/template.py index 1c3757e34ad..6d0d42cd39c 100644 --- a/paddleformers/datasets/template/template.py +++ b/paddleformers/datasets/template/template.py @@ -977,6 +977,16 @@ def _get_gpt_oss_prefix(): chat_sep="<|im_end|>", ) +# copied from deepseekv3 template +register_template( + name="deepseek_v32", + format_system=StringFormatter(slots=["{{content}}\n\n"]), + format_user=StringFormatter(slots=["<|User|>{{content}}\n\n<|Assistant|>"]), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), + format_assistant=StringFormatter(slots=["{{content}}"]), + chat_sep="<|end▁of▁sentence|>", +) + register_template( name="glm_ocr", format_user=StringFormatter(slots=["<|user|>\n{{content}}\n"]), diff --git a/paddleformers/trainer/training_args.py 
b/paddleformers/trainer/training_args.py index a055326140f..7116c0e1aa2 100644 --- a/paddleformers/trainer/training_args.py +++ b/paddleformers/trainer/training_args.py @@ -1692,6 +1692,11 @@ class TrainingArguments: }, ) + dsa_indexer_loss_coeff: float = field( + default=0.01, + metadata={"help": "Loss coefficient for the DSA indexer; controls the weight of the indexer loss term."}, + ) + online_merge_ema: bool = field( default=True, metadata={"help": "Whether to perform online merge of the EMA parameters during training. "} ) diff --git a/paddleformers/transformers/__init__.py b/paddleformers/transformers/__init__.py index feeff603775..188f854d162 100644 --- a/paddleformers/transformers/__init__.py +++ b/paddleformers/transformers/__init__.py @@ -96,6 +96,11 @@ "auto.video_processing": ["AutoVideoProcessor", "VIDEO_PROCESSOR_MAPPING"], "auto.feature_extraction": ["AutoFeatureExtractor"], "deepseek_v3.configuration": ["DeepseekV3Config"], + "deepseek_v32.configuration": ["DeepseekV32Config"], + "deepseek_v32.modeling": [ + "DeepseekV32ForCausalLM", + "DeepseekV32ForCausalLMPipe", + ], "deepseek_v3.modeling": [ "masked_fill", "DeepseekV3Attention", diff --git a/paddleformers/transformers/aoa_config_base.py b/paddleformers/transformers/aoa_config_base.py index 2ff6870da09..8414d11fe26 100644 --- a/paddleformers/transformers/aoa_config_base.py +++ b/paddleformers/transformers/aoa_config_base.py @@ -667,7 +667,7 @@ def _get_inv_moe_layer_statements(cls, params: MoEAOAConfigParams) -> List[str]: if layer_idx >= params.num_hidden_layers: prefix_offset += ".transformer_layer" - statements.extend( + statements.extend( [ f"{prefix_offset}.input_layernorm.weight -> {prefix}.input_layernorm.weight", f"{prefix_offset}.post_attention_layernorm.weight -> {prefix}.post_attention_layernorm.weight", diff --git a/paddleformers/transformers/auto/configuration.py b/paddleformers/transformers/auto/configuration.py index fc8c594f4cb..6a47398bf1f 100644 --- 
a/paddleformers/transformers/auto/configuration.py +++ b/paddleformers/transformers/auto/configuration.py @@ -34,6 +34,7 @@ CONFIG_MAPPING_NAMES = OrderedDict( [ ("deepseek_v3", "DeepseekV3Config"), + ("deepseek_v32", "DeepseekV32Config"), ("ernie4_5", "Ernie4_5Config"), ("ernie4_5_moe", "Ernie4_5_MoeConfig"), ("ernie4_5_moe_vl", "Ernie4_5_VLConfig"), diff --git a/paddleformers/transformers/auto/modeling.py b/paddleformers/transformers/auto/modeling.py index f450dc95656..bc9e105600a 100644 --- a/paddleformers/transformers/auto/modeling.py +++ b/paddleformers/transformers/auto/modeling.py @@ -54,6 +54,7 @@ MAPPING_NAMES = OrderedDict( [ ("DeepseekV3", "deepseek_v3"), + ("DeepseekV32", "deepseek_v32"), ("Ernie4_5", "ernie4_5"), ("Ernie4_5_Moe", "ernie4_5_moe"), ("Ernie4_5_VLMoe", "ernie4_5_moe_vl"), diff --git a/paddleformers/transformers/configuration_utils.py b/paddleformers/transformers/configuration_utils.py index f232707964d..f0fd73678bb 100644 --- a/paddleformers/transformers/configuration_utils.py +++ b/paddleformers/transformers/configuration_utils.py @@ -410,6 +410,12 @@ class LlmMetaConfig: False, "Whether to use SonicMoE as the computation backend for the moelayer.", ), + ( + "dsa_indexer_loss_coeff", + float, + 0.01, + "Loss coefficient for the DSA indexer; controls the weight of the indexer loss term.", + ), ] mtp_attributes = [ diff --git a/paddleformers/transformers/deepseek_v32/__init__.py b/paddleformers/transformers/deepseek_v32/__init__.py new file mode 100644 index 00000000000..1889bf542a1 --- /dev/null +++ b/paddleformers/transformers/deepseek_v32/__init__.py @@ -0,0 +1,37 @@ +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +from typing import TYPE_CHECKING + +from ...utils.lazy_import import _LazyModule + +import_structure = { + "configuration": ["DeepseekV32Config"], + "modeling": [ + "DeepseekV32ForCausalLM", + "DeepseekV32ForCausalLMPipe", + ], +} + +if TYPE_CHECKING: + from .configuration import * + from .modeling import * +else: + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + import_structure, + module_spec=__spec__, + ) diff --git a/paddleformers/transformers/deepseek_v32/configuration.py b/paddleformers/transformers/deepseek_v32/configuration.py new file mode 100644 index 00000000000..121e75d07d2 --- /dev/null +++ b/paddleformers/transformers/deepseek_v32/configuration.py @@ -0,0 +1,153 @@ +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ..configuration_utils import PretrainedConfig + + +class DeepseekV32Config(PretrainedConfig): + r""" + Configuration for DeepSeek V3.2 model. 
+ + Architecture: MLA (Multi-Latent Attention) + DSA Indexer (DeepSeek Sparse Attention) + + MoE (Mixture of Experts) + MTP (Multi-Token Prediction) + + Field names are kept consistent with the HuggingFace config.json so that + ``TransformerConfig.from_config()`` can map them to the PaddleFleet provider + dataclass fields without any manual renaming. + """ + + model_type = "deepseek_v32" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=129280, + hidden_size=7168, + intermediate_size=18432, + moe_intermediate_size=2048, + num_hidden_layers=61, + num_attention_heads=128, + num_key_value_heads=128, + max_position_embeddings=163840, + rms_norm_eps=1e-6, + hidden_act="silu", + initializer_range=0.02, + use_cache=True, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + tie_word_embeddings=False, + # MLA parameters + q_lora_rank=1536, + kv_lora_rank=512, + qk_nope_head_dim=128, + qk_rope_head_dim=64, + v_head_dim=128, + head_dim=None, + multi_latent_attention=True, + use_qk_norm=True, + # DSA Indexer parameters (field names match HF config.json) + index_n_heads=64, + index_head_dim=128, + index_topk=2048, + indexer_loss_coeff=0.0, + indexer_use_sparse_loss=False, + # RoPE format control for DSA Indexer + # False = non-interleaved (default, compatible with MLA's interleaved YaRN) + # True = interleaved (paired frequency format) + indexer_rotary_interleaved=False, + # MoE parameters + n_routed_experts=256, + n_shared_experts=1, + num_experts_per_tok=8, + n_group=8, + topk_group=4, + routed_scaling_factor=2.5, + scoring_func="sigmoid", + norm_topk_prob=True, + topk_method="noaux_tc", + first_k_dense_replace=3, + moe_layer_freq=1, + # MTP parameters + num_nextn_predict_layers=1, + # Pipeline parallel segmentation + pp_seg_method="layer:TransformerLayer|EmptyLayer", + **kwargs, + ): + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + 
self.moe_intermediate_size = moe_intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.max_position_embeddings = max_position_embeddings + self.rms_norm_eps = rms_norm_eps + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + + # MLA + self.q_lora_rank = q_lora_rank + self.kv_lora_rank = kv_lora_rank + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = qk_rope_head_dim + self.v_head_dim = v_head_dim + # head_dim must equal v_head_dim for MLA: o_proj input size = num_heads * head_dim, + # and the attention output per head = v_head_dim. + self.head_dim = head_dim if head_dim is not None else v_head_dim + self.multi_latent_attention = multi_latent_attention + self.use_qk_norm = use_qk_norm + + # DSA Indexer + self.index_n_heads = index_n_heads + self.index_head_dim = index_head_dim + self.index_topk = index_topk + self.indexer_loss_coeff = indexer_loss_coeff + self.indexer_use_sparse_loss = indexer_use_sparse_loss + self.indexer_rotary_interleaved = indexer_rotary_interleaved + + # MoE + self.n_routed_experts = n_routed_experts + self.n_shared_experts = n_shared_experts + self.num_experts_per_tok = num_experts_per_tok + self.n_group = n_group + self.topk_group = topk_group + self.routed_scaling_factor = routed_scaling_factor + self.scoring_func = scoring_func + self.norm_topk_prob = norm_topk_prob + self.topk_method = topk_method + self.first_k_dense_replace = first_k_dense_replace + self.moe_layer_freq = moe_layer_freq + + # MTP + self.num_nextn_predict_layers = num_nextn_predict_layers + + # PP + self.pp_seg_method = pp_seg_method + + super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) + + # Re-set after super().__init__ 
because LlmMetaConfig defaults override these + self.multi_latent_attention = multi_latent_attention + self.use_qk_norm = use_qk_norm + self.num_nextn_predict_layers = num_nextn_predict_layers + + +__all__ = ["DeepseekV32Config"] diff --git a/paddleformers/transformers/deepseek_v32/modeling.py b/paddleformers/transformers/deepseek_v32/modeling.py new file mode 100644 index 00000000000..c14cda3aadb --- /dev/null +++ b/paddleformers/transformers/deepseek_v32/modeling.py @@ -0,0 +1,139 @@ +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +DeepSeek V3.2 PaddleFleet model bridge. + +This module bridges the HuggingFace-style PretrainedConfig/PretrainedModel +interface used by PaddleFormers with the PaddleFleet provider system. 
+ +Pattern follows glm_moe_dsa/modeling.py (GLM5 PR #3940) exactly: + - DeepseekV32ForCausalLM.__new__() calls DeepSeekV3_2BaseProvider.from_config(config) + - provider.provide() calls paddlefleet.gpt_builders.gpt_builder() + - Returns a PaddleFleet GPT model (Megatron-style) +""" + +import os +import sys + +from ..aoa_config_base import MoEAOAConfigGenerator +from ..model_utils import PretrainedModel +from .configuration import DeepseekV32Config + + +class DeepseekV32PreTrainedModel(PretrainedModel): + config_class = DeepseekV32Config + base_model_prefix = "model" + + # Layernorm weight names that need dtype cast (fleet model skips generic dtype mapping) + _NORM_WEIGHT_KEYS = ("input_layernorm.weight", "post_attention_layernorm.weight", + "q_a_layernorm.weight", "kv_a_layernorm.weight", + "k_norm.weight", "k_norm.bias", "norm.weight") + + @classmethod + def _gen_aoa_config(cls, config: DeepseekV32Config): + aoa_config = MoEAOAConfigGenerator.gen_aoa_config(config) + cls._inject_norm_dtype(aoa_config["aoa_statements"], "bfloat16") + return aoa_config + + @classmethod + def _gen_inv_aoa_config(cls, config: DeepseekV32Config): + inv_aoa_config = MoEAOAConfigGenerator.gen_inv_aoa_config(config) + cls._inject_norm_dtype(inv_aoa_config["aoa_statements"], "float32") + return inv_aoa_config + + @classmethod + def _inject_norm_dtype(cls, aoa_statements, target_dtype): + """Inject dtype into existing layernorm statements generated by base class.""" + for i, stmt in enumerate(aoa_statements): + if any(k in stmt for k in cls._NORM_WEIGHT_KEYS) and "dtype=" not in stmt: + aoa_statements[i] = f"{stmt}, dtype='{target_dtype}'" + + +def _build_model(config): + """ + Common __new__ logic shared by ForCausalLM and ForCausalLMPipe. + + Steps: + 1. Normalise parallel config attributes (same as GLM5). + 2. Call DeepSeekV3_2BaseProvider.from_config(config) to populate provider fields. + 3. Call provider.provide() which runs gpt_builder() and returns the PaddleFleet model. 
+ (moe_layer_freq + first_k_dense_replace conversion is handled by + TransformerConfig.__post_init__ automatically.) + """ + # 1. Normalise parallel config (guard against missing attrs from old configs) + config.tensor_model_parallel_size = max(getattr(config, "tensor_model_parallel_size", 1), 1) + config.pipeline_model_parallel_size = max(getattr(config, "pipeline_model_parallel_size", 1), 1) + config.context_parallel_size = max(getattr(config, "context_parallel_size", 1), 1) + config.virtual_pipeline_model_parallel_size = max(getattr(config, "virtual_pipeline_model_parallel_size", 1), 1) + config.expert_model_parallel_size = max(getattr(config, "expert_model_parallel_size", 1), 1) + + # 2. Resolve provider module path. + # The provider lives under examples/experiments/paddlefleet/ which is + # not a proper package. We add it to sys.path if needed. + _provider_dir = os.path.join( + os.path.dirname(__file__), # .../paddleformers/transformers/deepseek_v32/ + "..", + "..", + "..", + "examples", + "experiments", + "paddlefleet", + ) + _provider_dir = os.path.normpath(_provider_dir) + if _provider_dir not in sys.path: + sys.path.insert(0, _provider_dir) + + from deepseek_v3_2_provider import DeepSeekV3_2BaseProvider + + # 3. 
Build model via provider + model_provider = DeepSeekV3_2BaseProvider.from_config(config) + gpt_model = model_provider.provide() + gpt_model.config_to_save = config + return gpt_model + + +class DeepseekV32ForCausalLM(DeepseekV32PreTrainedModel): + """DeepSeek V3.2 model for pipeline_model_parallel_size == 1.""" + + is_fleet = True + + def __new__(cls, config): + gpt_model = _build_model(config) + gpt_model.is_fleet = cls.is_fleet + gpt_model._gen_aoa_config = cls._gen_aoa_config + gpt_model._gen_inv_aoa_config = cls._gen_inv_aoa_config + return gpt_model + + +class DeepseekV32ForCausalLMPipe(DeepseekV32PreTrainedModel): + """DeepSeek V3.2 model for pipeline_model_parallel_size > 1.""" + + is_fleet = True + + def __new__(cls, config): + if not getattr(config, "architectures", None): # also covers an explicit architectures=None / [] + config.architectures = ["DeepseekV32ForCausalLM"] + gpt_model = _build_model(config) + gpt_model.is_fleet = cls.is_fleet + gpt_model._gen_aoa_config = cls._gen_aoa_config + gpt_model._gen_inv_aoa_config = cls._gen_inv_aoa_config + return gpt_model + + +__all__ = [ + "DeepseekV32PreTrainedModel", + "DeepseekV32ForCausalLM", + "DeepseekV32ForCausalLMPipe", +] diff --git a/paddleformers/transformers/gpt_provider.py b/paddleformers/transformers/gpt_provider.py index d477e1ee8c3..ba02bc5b342 100644 --- a/paddleformers/transformers/gpt_provider.py +++ b/paddleformers/transformers/gpt_provider.py @@ -203,6 +203,9 @@ def provide(self, pre_process=None, post_process=None, vp_stage=None, loss_fn=No self.rope_type = self.rope_parameters["rope_type"] if "rope_theta" in self.rope_parameters: self.rope_theta = self.rope_parameters["rope_theta"] + if hasattr(self, "rope_scaling") and self.rope_scaling is not None: + if "mscale_all_dim" in self.rope_scaling: + self.mscale_all_dim = self.rope_scaling["mscale_all_dim"] # Check if mtp_block_spec parameter is supported kwargs = {}