From 42f5ed5ed36d39c81b975b3ceb8a5a804fadcb44 Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis Date: Thu, 2 Apr 2026 10:52:54 -0700 Subject: [PATCH 1/3] add support for arcee-ai/Trinity-Large-Thinking Signed-off-by: Alexandros Koumparoulis --- docs/model-coverage/llm.md | 1 + .../afmoe/trinity_large_thinking_sft.yaml | 88 +++++ nemo_automodel/_transformers/registry.py | 5 + .../components/models/afmoe/__init__.py | 17 + .../components/models/afmoe/config.py | 123 +++++++ .../components/models/afmoe/layers.py | 153 ++++++++ .../components/models/afmoe/model.py | 334 ++++++++++++++++++ .../models/afmoe/state_dict_adapter.py | 131 +++++++ tests/unit_tests/models/afmoe/__init__.py | 0 .../models/afmoe/test_afmoe_layers.py | 100 ++++++ .../models/afmoe/test_afmoe_model.py | 197 +++++++++++ .../afmoe/test_afmoe_state_dict_adapter.py | 189 ++++++++++ 12 files changed, 1338 insertions(+) create mode 100644 examples/llm_finetune/afmoe/trinity_large_thinking_sft.yaml create mode 100644 nemo_automodel/components/models/afmoe/__init__.py create mode 100644 nemo_automodel/components/models/afmoe/config.py create mode 100644 nemo_automodel/components/models/afmoe/layers.py create mode 100644 nemo_automodel/components/models/afmoe/model.py create mode 100644 nemo_automodel/components/models/afmoe/state_dict_adapter.py create mode 100644 tests/unit_tests/models/afmoe/__init__.py create mode 100644 tests/unit_tests/models/afmoe/test_afmoe_layers.py create mode 100644 tests/unit_tests/models/afmoe/test_afmoe_model.py create mode 100644 tests/unit_tests/models/afmoe/test_afmoe_state_dict_adapter.py diff --git a/docs/model-coverage/llm.md b/docs/model-coverage/llm.md index 005acf27d7..d736b9831e 100644 --- a/docs/model-coverage/llm.md +++ b/docs/model-coverage/llm.md @@ -22,6 +22,7 @@ The table below lists the main architectures we test against (FSDP2 combined wit | Architecture | Models | Example HF Models | |---------------------------------------|---------------------------------------|-----------------------------------------------------------------------------------| +| `AfmoeForCausalLM` | Afmoe (Arcee Fusion MoE) | `arcee-ai/Trinity-Large-Thinking` — example recipe: [trinity_large_thinking_sft.yaml](../../examples/llm_finetune/afmoe/trinity_large_thinking_sft.yaml) | | `AquilaForCausalLM` | Aquila, Aquila2 | `BAAI/Aquila-7B`, `BAAI/AquilaChat-7B`, etc. | | `BaiChuanForCausalLM` | Baichuan2, Baichuan | `baichuan-inc/Baichuan2-13B-Chat`, `baichuan-inc/Baichuan-7B`, etc. — example recipes: [baichuan_2_7b_squad.yaml](../../examples/llm_finetune/baichuan/baichuan_2_7b_squad.yaml), [baichuan_2_7b_squad_peft.yaml](../../examples/llm_finetune/baichuan/baichuan_2_7b_squad_peft.yaml) | | `BambaForCausalLM` | Bamba | `ibm-ai-platform/Bamba-9B` | diff --git a/examples/llm_finetune/afmoe/trinity_large_thinking_sft.yaml b/examples/llm_finetune/afmoe/trinity_large_thinking_sft.yaml new file mode 100644 index 0000000000..08e0fc30c5 --- /dev/null +++ b/examples/llm_finetune/afmoe/trinity_large_thinking_sft.yaml @@ -0,0 +1,88 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Afmoe (Arcee Trinity-Large-Thinking) SFT example +# 256 experts, 4 active per token, 60 layers, ~large model +# +# To run this recipe: +# automodel examples/llm_finetune/afmoe/trinity_large_thinking_sft.yaml --nproc-per-node 8 +# Adjust --nproc-per-node to the number of GPUs available on your machine. + +recipe: TrainFinetuneRecipeForNextTokenPrediction + +step_scheduler: + global_batch_size: 32 + local_batch_size: 1 + ckpt_every_steps: 200 + val_every_steps: 100 + num_epochs: 1 + +dist_env: + backend: nccl + timeout_minutes: 10 + +rng: + _target_: nemo_automodel.components.training.rng.StatefulRNG + seed: 1111 + ranked: true + +model: + _target_: nemo_automodel.NeMoAutoModelForCausalLM.from_pretrained + pretrained_model_name_or_path: arcee-ai/Trinity-Large-Thinking + +checkpoint: + enabled: false + checkpoint_dir: checkpoints/ + model_save_format: safetensors + save_consolidated: false + +distributed: + strategy: fsdp2 + dp_size: none + tp_size: 1 + cp_size: 1 + sequence_parallel: false + +loss_fn: + _target_: nemo_automodel.components.loss.masked_ce.MaskedCrossEntropy + +dataset: + _target_: nemo_automodel.components.datasets.llm.hellaswag.HellaSwag + path_or_dataset: rowan/hellaswag + split: train + +packed_sequence: + packed_sequence_size: 0 + +dataloader: + _target_: torchdata.stateful_dataloader.StatefulDataLoader + collate_fn: nemo_automodel.components.datasets.utils.default_collater + shuffle: false + +validation_dataset: + _target_: nemo_automodel.components.datasets.llm.hellaswag.HellaSwag + path_or_dataset: rowan/hellaswag + split: validation + num_samples_limit: 64 + +validation_dataloader: + _target_: torchdata.stateful_dataloader.StatefulDataLoader + collate_fn: nemo_automodel.components.datasets.utils.default_collater + +optimizer: + _target_: torch.optim.Adam + betas: [0.9, 0.999] + eps: 1e-8 + lr: 1.0e-5 + weight_decay: 0 diff --git a/nemo_automodel/_transformers/registry.py b/nemo_automodel/_transformers/registry.py index 50a3377ef1..6319970a0b 100644 --- a/nemo_automodel/_transformers/registry.py +++ b/nemo_automodel/_transformers/registry.py @@ -31,6 +31,10 @@ # downstream code to classify model archs without importing them. MODEL_ARCH_MAPPING = OrderedDict( [ + ( + "AfmoeForCausalLM", + ("nemo_automodel.components.models.afmoe.model", "AfmoeForCausalLM"), + ), ( "BaichuanForCausalLM", ("nemo_automodel.components.models.baichuan.model", "BaichuanForCausalLM"), @@ -150,6 +154,7 @@ # checkpoint config.json. Registered eagerly with AutoConfig so that # AutoConfig.from_pretrained can resolve them without trust_remote_code. 
_CUSTOM_CONFIG_REGISTRATIONS: Dict[str, Tuple[str, str]] = { + "afmoe": ("nemo_automodel.components.models.afmoe.config", "AfmoeConfig"), "baichuan": ("nemo_automodel.components.models.baichuan.configuration", "BaichuanConfig"), "kimi_k25": ("nemo_automodel.components.models.kimi_k25_vl.model", "KimiK25VLConfig"), "kimi_vl": ("nemo_automodel.components.models.kimivl.model", "KimiVLConfig"), diff --git a/nemo_automodel/components/models/afmoe/__init__.py b/nemo_automodel/components/models/afmoe/__init__.py new file mode 100644 index 0000000000..aba2500962 --- /dev/null +++ b/nemo_automodel/components/models/afmoe/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nemo_automodel.components.models.afmoe.model import AfmoeForCausalLM + +__all__ = ["AfmoeForCausalLM"] diff --git a/nemo_automodel/components/models/afmoe/config.py b/nemo_automodel/components/models/afmoe/config.py new file mode 100644 index 0000000000..072e06eaf9 --- /dev/null +++ b/nemo_automodel/components/models/afmoe/config.py @@ -0,0 +1,123 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from transformers.configuration_utils import PretrainedConfig +from transformers.modeling_rope_utils import rope_config_validation + + +class AfmoeConfig(PretrainedConfig): + """Configuration for the Afmoe (Arcee Fusion MoE) model. + + This is a Mixture-of-Experts model with hybrid sliding-window / full attention, + gated attention output, QK normalization, and dual pre/post normalization. 
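+
+    With the default ``global_attn_every_n_layers=4``, every fourth layer uses
+    full attention and the rest use sliding-window attention, e.g.:
+
+        >>> AfmoeConfig(num_hidden_layers=4).layer_types
+        ['sliding_attention', 'sliding_attention', 'sliding_attention', 'full_attention']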
+ """ + + model_type = "afmoe" + + def __init__( + self, + num_hidden_layers: int = 32, + vocab_size: int = 200192, + hidden_size: int = 2048, + intermediate_size: int = 6144, + moe_intermediate_size=1408, + num_dense_layers=1, + num_attention_heads=16, + num_key_value_heads=None, + head_dim=128, + hidden_act="silu", + max_position_embeddings=16384, + initializer_range=0.02, + rms_norm_eps=1e-5, + use_cache=True, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + num_experts=64, + num_experts_per_tok=6, + num_shared_experts=2, + num_expert_groups=1, + num_limited_groups=1, + score_func="sigmoid", + route_norm=True, + route_scale=1.0, + global_attn_every_n_layers=4, + sliding_window=1024, + mup_enabled=False, + layer_types=None, + attention_dropout: float = 0.0, + n_group: int = 1, + topk_group: int = 1, + load_balance_coeff: float = 0.0, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_dense_layers = num_dense_layers + self.num_attention_heads = num_attention_heads + self.head_dim = head_dim + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + + # MoE specific + self.moe_intermediate_size = moe_intermediate_size + self.num_experts_per_tok = num_experts_per_tok + self.n_group = n_group + self.topk_group = topk_group + self.num_experts = num_experts + self.num_shared_experts = num_shared_experts + self.num_expert_groups = num_expert_groups + self.num_limited_groups = num_limited_groups + self.score_func = score_func + self.route_norm = route_norm + self.route_scale = route_scale + self.load_balance_coeff = load_balance_coeff + + # Attention specific + self.attention_dropout = attention_dropout + self.global_attn_every_n_layers = global_attn_every_n_layers + self.sliding_window = sliding_window + self.layer_types = layer_types + if self.layer_types is None: + self.layer_types = [ + "sliding_attention" if bool((i + 1) % global_attn_every_n_layers) else "full_attention" + for i in range(self.num_hidden_layers) + ] + + # muP specific + self.mup_enabled = mup_enabled + + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + + # Validate rope configs + if self.rope_scaling is not None and "type" in self.rope_scaling: + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + rope_config_validation(self) + + super().__init__( + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +__all__ = ["AfmoeConfig"] diff --git a/nemo_automodel/components/models/afmoe/layers.py b/nemo_automodel/components/models/afmoe/layers.py new file mode 100644 index 0000000000..4cad789960 --- /dev/null +++ b/nemo_automodel/components/models/afmoe/layers.py @@ -0,0 +1,153 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Afmoe attention layer with gated output, QK normalization, and conditional RoPE.""" + +from typing import Any + +import torch +import torch.nn.functional as F +from torch import nn + +from nemo_automodel.components.attention.utils import ( + initialize_attn_module_and_func, + postprocess_output_for_attn, + preprocess_args_and_kwargs_for_attn, +) +from nemo_automodel.components.models.common import ( + BackendConfig, + initialize_linear_module, + initialize_rms_norm_module, +) +from nemo_automodel.components.models.gpt_oss.rope_utils import apply_rotary_emb_qk + + +class AfmoeAttention(nn.Module): + """Afmoe attention with gated output, per-head QK RMSNorm, and conditional RoPE. + + Key differences from standard attention: + - RoPE is applied only to sliding-window (local) attention layers. + - Attention output is gated: ``output = output * sigmoid(gate_proj(x))``. + - Per-head RMSNorm on Q and K before attention. + """ + + def __init__(self, config, layer_idx: int, backend: BackendConfig): + super().__init__() + self.backend = backend + self.layer_idx = layer_idx + + self.num_heads = config.num_attention_heads + self.num_kv_heads = config.num_key_value_heads + self.head_dim = getattr(config, "head_dim", config.hidden_size // self.num_heads) + self.is_local_attention = config.layer_types[layer_idx] == "sliding_attention" + self.sliding_window = config.sliding_window if self.is_local_attention else None + + self.q_proj = initialize_linear_module( + backend.linear, config.hidden_size, self.num_heads * self.head_dim, bias=False + ) + self.k_proj = initialize_linear_module( + backend.linear, config.hidden_size, self.num_kv_heads * self.head_dim, bias=False + ) + self.v_proj = initialize_linear_module( + backend.linear, config.hidden_size, self.num_kv_heads * self.head_dim, bias=False + ) + self.o_proj = initialize_linear_module( + backend.linear, self.num_heads * self.head_dim, config.hidden_size, bias=False + ) + self.gate_proj = initialize_linear_module( + backend.linear, config.hidden_size, self.num_heads * self.head_dim, bias=False + ) + + # Per-head RMSNorm on Q and K + self.q_norm = initialize_rms_norm_module(backend.rms_norm, self.head_dim, eps=config.rms_norm_eps) + self.k_norm = initialize_rms_norm_module(backend.rms_norm, self.head_dim, eps=config.rms_norm_eps) + + # Attention implementation + softmax_scale = self.head_dim**-0.5 + self.attn_module, self.attn_func = initialize_attn_module_and_func( + attn_impl=backend.attn, + num_attention_heads=self.num_heads, + num_qk_channels=self.head_dim, + num_v_channels=self.head_dim, + softmax_scale=softmax_scale, + num_gqa_groups=self.num_kv_heads, + ) + + def forward( + self, + x: torch.Tensor, + *, + freqs_cis: torch.Tensor, + attention_mask: torch.Tensor | None = None, + **attn_kwargs: Any, + ) -> torch.Tensor: + if len(x.shape) == 2: + qkv_format = "thd" + num_tokens = x.shape[0] + else: + qkv_format = "bshd" + bsz, seqlen, _ = x.size() + + q = self.q_proj(x) + k = self.k_proj(x) + v = self.v_proj(x) + gate = self.gate_proj(x) + + if qkv_format == "thd": + q = q.view(num_tokens, self.num_heads, self.head_dim) + k = k.view(num_tokens, self.num_kv_heads, self.head_dim) + v = v.view(num_tokens, self.num_kv_heads, self.head_dim) + else: + q = q.view(bsz, seqlen, self.num_heads, self.head_dim) + k = k.view(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = v.view(bsz, seqlen, self.num_kv_heads, self.head_dim) + + # Per-head RMSNorm + 
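+        # (q_norm/k_norm are RMSNorms over the trailing head_dim dimension,
+        # so each attention head is normalized independently)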
q = self.q_norm(q) + k = self.k_norm(k) + + # RoPE only on local (sliding-window) attention layers + if self.is_local_attention: + q, k = apply_rotary_emb_qk( + q, + k, + freqs_cis, + format=qkv_format, + rope_fusion=self.backend.rope_fusion, + cu_seqlens=attn_kwargs.get("cu_seqlens", None), + cp_size=attn_kwargs.get("cp_size", 1), + cp_rank=attn_kwargs.get("cp_rank", 0), + ) + + # Backend-specific attention + window_size = (self.sliding_window, 0) if self.is_local_attention else (-1, 0) + q, k, v, _attn_kwargs = preprocess_args_and_kwargs_for_attn( + q, k, v, attention_mask, self.backend.attn, window_size=window_size, **attn_kwargs + ) + out = self.attn_func(q, k, v, **_attn_kwargs) + out = postprocess_output_for_attn(out, self.backend.attn) + + flatten_dim = 2 if qkv_format == "bshd" else 1 + out = out.flatten(flatten_dim) + + # Gated attention output + out = out * F.sigmoid(gate) + out = self.o_proj(out) + return out + + def init_weights(self, buffer_device: torch.device, init_std: float = 0.02): + for linear in (self.q_proj, self.k_proj, self.v_proj, self.o_proj, self.gate_proj): + nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std) + for norm in (self.q_norm, self.k_norm): + norm.reset_parameters() diff --git a/nemo_automodel/components/models/afmoe/model.py b/nemo_automodel/components/models/afmoe/model.py new file mode 100644 index 0000000000..6434b8fcf6 --- /dev/null +++ b/nemo_automodel/components/models/afmoe/model.py @@ -0,0 +1,334 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Afmoe (Arcee Fusion MoE) model implementation for NeMo AutoModel. 
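+
+It is assembled from NeMo AutoModel's shared MoE and attention building blocks
+and registered as the ``AfmoeForCausalLM`` architecture.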
+ +Key architectural features: +- Mixture-of-Experts with sigmoid routing, shared experts, and expert bias correction +- Hybrid attention: sliding-window (local) + full (global) every N layers +- Gated attention output with per-head QK RMSNorm +- Dual pre/post normalization around both attention and MLP +- RoPE only on local attention layers +- Optional muP input scaling +""" + +from typing import Any + +import torch +import torch.nn as nn + +from nemo_automodel.components.models.afmoe.config import AfmoeConfig +from nemo_automodel.components.models.afmoe.layers import AfmoeAttention +from nemo_automodel.components.models.afmoe.state_dict_adapter import AfmoeStateDictAdapter +from nemo_automodel.components.models.common import ( + BackendConfig, + get_rope_config, + initialize_linear_module, + initialize_rms_norm_module, +) +from nemo_automodel.components.models.common.hf_checkpointing_mixin import HFCheckpointingMixin +from nemo_automodel.components.models.common.utils import cast_model_to_dtype +from nemo_automodel.components.models.gpt_oss.rope_utils import RotaryEmbedding, position_ids_to_freqs_cis +from nemo_automodel.components.moe.config import MoEConfig +from nemo_automodel.components.moe.fsdp_mixin import MoEFSDPSyncMixin +from nemo_automodel.components.moe.layers import MLP, MoE +from nemo_automodel.components.utils.model_utils import squeeze_input_for_thd +from nemo_automodel.shared.utils import dtype_from_str as get_dtype + + +def _build_moe_config(config: AfmoeConfig) -> MoEConfig: + """Build MoEConfig from the HF AfmoeConfig.""" + return MoEConfig( + dim=config.hidden_size, + inter_dim=config.intermediate_size, + moe_inter_dim=config.moe_intermediate_size, + n_routed_experts=config.num_experts, + n_shared_experts=config.num_shared_experts, + n_activated_experts=config.num_experts_per_tok, + n_expert_groups=config.n_group, + n_limited_groups=config.topk_group, + train_gate=True, + gate_bias_update_factor=0.001, + score_func=config.score_func, + route_scale=config.route_scale, + aux_loss_coeff=getattr(config, "load_balance_coeff", 0.0), + norm_topk_prob=config.route_norm, + expert_bias=False, + router_bias=False, + expert_activation="swiglu", + force_e_score_correction_bias=True, + shared_expert_inter_dim=config.moe_intermediate_size, + ) + + +class Block(nn.Module): + """Afmoe decoder block with dual normalization and conditional MoE/dense MLP.""" + + def __init__(self, layer_idx: int, config: AfmoeConfig, moe_config: MoEConfig, backend: BackendConfig): + super().__init__() + self.self_attn = AfmoeAttention(config, layer_idx, backend) + + # Dense MLP for first num_dense_layers, MoE for the rest + self.moe_enabled = layer_idx >= config.num_dense_layers + if self.moe_enabled: + self.mlp = MoE(moe_config, backend) + else: + self.mlp = MLP(config.hidden_size, config.intermediate_size, backend.linear) + + # Dual normalization: pre/post around both attention and MLP + self.input_layernorm = initialize_rms_norm_module(backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = initialize_rms_norm_module( + backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps + ) + self.pre_mlp_layernorm = initialize_rms_norm_module( + backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps + ) + self.post_mlp_layernorm = initialize_rms_norm_module( + backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps + ) + self.layer_idx = layer_idx + + def forward( + self, + x: torch.Tensor, + *, + freqs_cis: torch.Tensor, + attention_mask: torch.Tensor 
| None = None, + padding_mask: torch.Tensor | None = None, + **attn_kwargs: Any, + ) -> torch.Tensor: + if attention_mask is not None and padding_mask is None: + padding_mask = attention_mask.bool().logical_not() + + # Attention with dual normalization + residual = x + x = self.input_layernorm(x) + x = self.self_attn( + x=x, + freqs_cis=freqs_cis, + attention_mask=attention_mask, + **attn_kwargs, + ) + x = self.post_attention_layernorm(x) + x = residual + x + + # MLP with dual normalization + residual = x + x = self.pre_mlp_layernorm(x) + x = self._mlp(x=x, padding_mask=padding_mask) + x = self.post_mlp_layernorm(x) + x = residual + x + return x + + def _mlp(self, x: torch.Tensor, padding_mask: torch.Tensor | None) -> torch.Tensor: + if isinstance(self.mlp, MLP): + return self.mlp(x) + else: + assert isinstance(self.mlp, MoE) + return self.mlp(x, padding_mask) + + def init_weights(self, buffer_device: torch.device): + for norm in ( + self.input_layernorm, + self.post_attention_layernorm, + self.pre_mlp_layernorm, + self.post_mlp_layernorm, + ): + norm.reset_parameters() + self.self_attn.init_weights(buffer_device) + self.mlp.init_weights(buffer_device) + + +class AfmoeModel(nn.Module): + """Afmoe transformer backbone.""" + + def __init__(self, config: AfmoeConfig, backend: BackendConfig, *, moe_config: MoEConfig | None = None): + super().__init__() + self.backend = backend + self.config = config + self.moe_config = moe_config or _build_moe_config(config) + + self.embed_tokens = nn.Embedding( + config.vocab_size, config.hidden_size, dtype=get_dtype(config.torch_dtype, torch.bfloat16) + ) + self.layers = torch.nn.ModuleDict() + for layer_id in range(config.num_hidden_layers): + self.layers[str(layer_id)] = Block(layer_id, config, self.moe_config, backend) + self.norm = initialize_rms_norm_module(backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps) + + # Rotary embedding + self.max_seq_len = config.max_position_embeddings + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + + base, rope_scaling, _ = get_rope_config(config) + + self.rotary_emb = RotaryEmbedding( + head_dim=self.head_dim, + base=base, + dtype=torch.float32, + initial_context_length=rope_scaling.get("original_max_position_embeddings", 4096), + scaling_factor=rope_scaling.get("factor", 1.0), + ntk_alpha=rope_scaling.get("beta_slow", 1.0), + ntk_beta=rope_scaling.get("beta_fast", 32.0), + device=torch.device(f"cuda:{torch.cuda.current_device()}"), + ) + + # muP: scale embeddings by sqrt(hidden_size) + self.mup_enabled = getattr(config, "mup_enabled", False) + + def forward( + self, + input_ids: torch.Tensor, + *, + position_ids: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + padding_mask: torch.Tensor | None = None, + **attn_kwargs: Any, + ) -> torch.Tensor: + if position_ids is None: + position_ids = ( + torch.arange(0, input_ids.shape[1], device=input_ids.device).unsqueeze(0).expand(input_ids.shape[0], -1) + ) + + freqs_cis = position_ids_to_freqs_cis( + self.rotary_emb, + position_ids, + qkv_format=attn_kwargs.get("qkv_format", "bshd"), + for_fused_rope=self.backend.rope_fusion, + cp_size=attn_kwargs.get("cp_size", 1), + ) + + h = self.embed_tokens(input_ids) if self.embed_tokens is not None else input_ids + + if self.mup_enabled: + h = h * (self.config.hidden_size**0.5) + + for layer in self.layers.values(): + h = layer( + x=h, + freqs_cis=freqs_cis, + attention_mask=attention_mask, + padding_mask=padding_mask, + **attn_kwargs, + ) + + h = 
self.norm(h) if self.norm else h + return h + + @torch.no_grad() + def init_weights(self, buffer_device: torch.device | None = None) -> None: + buffer_device = buffer_device or torch.device(f"cuda:{torch.cuda.current_device()}") + with buffer_device: + if self.embed_tokens is not None: + nn.init.normal_(self.embed_tokens.weight) + if self.norm is not None: + self.norm.reset_parameters() + self.rotary_emb.device = buffer_device + for layer in self.layers.values(): + if layer is not None: + layer.init_weights(buffer_device=buffer_device) + + +class AfmoeForCausalLM(HFCheckpointingMixin, nn.Module, MoEFSDPSyncMixin): + """Afmoe MoE causal language model.""" + + _keep_in_fp32_modules_strict = ["e_score_correction_bias"] + + @classmethod + def from_config(cls, config: AfmoeConfig, moe_config=None, backend=None, **kwargs): + return cls(config, moe_config, backend, **kwargs) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: str, *model_args, **kwargs): + config = AfmoeConfig.from_pretrained(pretrained_model_name_or_path) + return cls.from_config(config, *model_args, **kwargs) + + def __init__(self, config: AfmoeConfig, moe_config=None, backend=None, **kwargs): + super().__init__() + self.config = config + self.backend = backend or BackendConfig() + self.model = AfmoeModel(config, backend=self.backend, moe_config=moe_config) + self.lm_head = initialize_linear_module(self.backend.linear, config.hidden_size, config.vocab_size, bias=False) + if self.backend.enable_hf_state_dict_adapter: + self.state_dict_adapter = AfmoeStateDictAdapter( + self.config, self.model.moe_config, self.backend, dtype=get_dtype(config.torch_dtype, torch.bfloat16) + ) + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def forward( + self, + input_ids: torch.Tensor, + *, + position_ids: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + padding_mask: torch.Tensor | None = None, + **attn_kwargs: Any, + ) -> torch.Tensor: + if "qkv_format" in attn_kwargs and attn_kwargs["qkv_format"] == "thd": + input_ids, position_ids, padding_mask, attn_kwargs = squeeze_input_for_thd( + input_ids, position_ids, padding_mask, attn_kwargs + ) + attention_mask = None + + hidden = self.model( + input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + padding_mask=padding_mask, + **attn_kwargs, + ) + logits = self.lm_head(hidden) if self.lm_head else hidden + if "qkv_format" in attn_kwargs and attn_kwargs["qkv_format"] == "thd": + logits = logits.unsqueeze(0) + return logits + + def update_moe_gate_bias(self) -> None: + with torch.no_grad(): + for _, block in self.model.layers.named_children(): + if isinstance(block.mlp, MoE): + block.mlp.gate.update_bias() + + @torch.no_grad() + def initialize_weights(self, buffer_device=None, dtype=torch.bfloat16): + buffer_device = buffer_device or torch.device(f"cuda:{torch.cuda.current_device()}") + with buffer_device: + self.model.init_weights(buffer_device=buffer_device) + final_out_std = self.config.hidden_size**-0.5 + cutoff_factor = 3 + if self.lm_head is not None: + nn.init.trunc_normal_( + self.lm_head.weight, + mean=0.0, + std=final_out_std, + a=-cutoff_factor * final_out_std, + b=cutoff_factor * final_out_std, + ) + + cast_model_to_dtype(self, dtype) + with buffer_device: + 
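+            # after casting, re-point the rotary embedding at the buffer device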
self.model.rotary_emb.device = buffer_device + + +ModelClass = AfmoeForCausalLM diff --git a/nemo_automodel/components/models/afmoe/state_dict_adapter.py b/nemo_automodel/components/models/afmoe/state_dict_adapter.py new file mode 100644 index 0000000000..98f50c329c --- /dev/null +++ b/nemo_automodel/components/models/afmoe/state_dict_adapter.py @@ -0,0 +1,131 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""State dict adapter for Afmoe HF checkpoints. + +Handles conversion between HF per-expert format and NeMo grouped-expert format, +plus key renaming for the router gate and expert bias. + +HF key format: + model.layers.{L}.mlp.router.gate.weight -> model.layers.{L}.mlp.gate.weight + model.layers.{L}.mlp.expert_bias -> model.layers.{L}.mlp.gate.e_score_correction_bias + model.layers.{L}.mlp.experts.{E}.gate_proj.weight -> (stacked into gate_and_up_projs) + model.layers.{L}.mlp.experts.{E}.up_proj.weight -> (stacked into gate_and_up_projs) + model.layers.{L}.mlp.experts.{E}.down_proj.weight -> (stacked into down_projs) + +Other keys (attention projections, norms, shared experts, dense MLP) pass through unchanged. +""" + +import logging +import re +from typing import Any, Optional + +import torch +from torch.distributed.device_mesh import DeviceMesh + +from nemo_automodel.components.checkpoint.state_dict_adapter import StateDictAdapter +from nemo_automodel.components.models.common import BackendConfig +from nemo_automodel.components.moe.config import MoEConfig +from nemo_automodel.components.moe.state_dict_mixin import MoESplitExpertsStateDictMixin + +logger = logging.getLogger(__name__) + +# Bidirectional key renaming rules: (hf_pattern, nemo_pattern) +_KEY_RENAMES_HF_TO_NEMO = [ + (".mlp.router.gate.weight", ".mlp.gate.weight"), + (".mlp.expert_bias", ".mlp.gate.e_score_correction_bias"), +] + + +class AfmoeStateDictAdapter(MoESplitExpertsStateDictMixin, StateDictAdapter): + """Converts between HF Afmoe checkpoints and NeMo grouped-experts native format.""" + + def __init__( + self, + config: Any, + moe_config: MoEConfig, + backend: BackendConfig, + dtype: torch.dtype = torch.float32, + ): + self.config = config + self.moe_config = moe_config + self.backend = backend + self.dtype = dtype + self._uses_model_prefix = True + + def from_hf( + self, + hf_state_dict: dict[str, Any], + device_mesh: Optional[DeviceMesh] = None, + **kwargs, + ) -> dict[str, Any]: + # Detect whether HF checkpoints use the "model." prefix + for key in hf_state_dict.keys(): + if ".mlp.experts." 
in key and key.endswith(".weight"): + self._uses_model_prefix = key.startswith("model.") + break + + # Rename HF keys to NeMo keys before expert merging + renamed = {} + for key, value in list(hf_state_dict.items()): + new_key = key + for hf_pat, nemo_pat in _KEY_RENAMES_HF_TO_NEMO: + if hf_pat in new_key: + new_key = new_key.replace(hf_pat, nemo_pat) + break + renamed[new_key] = value + hf_state_dict = renamed + + return self._from_hf_w_merged_experts(hf_state_dict, device_mesh) + + def to_hf( + self, + state_dict: dict[str, Any], + exclude_key_regex: Optional[str] = None, + quantization: bool = False, + **kwargs, + ) -> dict[str, Any]: + hf_state_dict = {} + for fqn, tensor in state_dict.items(): + converted_tensors = self.convert_single_tensor_to_hf( + fqn, tensor, exclude_key_regex=exclude_key_regex, quantization=quantization, **kwargs + ) + for key, value in converted_tensors: + hf_state_dict[key] = value + + # Rename NeMo keys back to HF keys + renamed = {} + for key, value in hf_state_dict.items(): + new_key = key + for hf_pat, nemo_pat in _KEY_RENAMES_HF_TO_NEMO: + if nemo_pat in new_key: + new_key = new_key.replace(nemo_pat, hf_pat) + break + renamed[new_key] = value + + return renamed + + def convert_single_tensor_to_hf(self, fqn: str, tensor: Any, **kwargs) -> list[tuple[str, Any]]: + exclude_key_regex = kwargs.get("exclude_key_regex", None) + + expert_result = self._convert_single_merged_expert_to_hf_split_experts(fqn, tensor, **kwargs) + if expert_result is not None: + result = expert_result + else: + result = [(fqn, tensor)] + + if exclude_key_regex: + result = [(k, v) for k, v in result if not re.match(exclude_key_regex, k)] + + return result diff --git a/tests/unit_tests/models/afmoe/__init__.py b/tests/unit_tests/models/afmoe/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit_tests/models/afmoe/test_afmoe_layers.py b/tests/unit_tests/models/afmoe/test_afmoe_layers.py new file mode 100644 index 0000000000..74740ace5e --- /dev/null +++ b/tests/unit_tests/models/afmoe/test_afmoe_layers.py @@ -0,0 +1,100 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
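+
+"""Unit tests for AfmoeAttention: layer types, output gating, QK norm, and forward shapes."""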
+ +import pytest +import torch + +from nemo_automodel.components.models.afmoe.config import AfmoeConfig +from nemo_automodel.components.models.afmoe.layers import AfmoeAttention +from nemo_automodel.components.models.common import BackendConfig + +pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + + +@pytest.fixture +def device(): + return torch.device(f"cuda:{torch.cuda.current_device()}") + + +@pytest.fixture +def tiny_config(): + return AfmoeConfig( + vocab_size=256, + hidden_size=64, + num_attention_heads=4, + num_key_value_heads=2, + head_dim=16, + num_hidden_layers=4, + intermediate_size=128, + moe_intermediate_size=32, + num_experts=4, + num_experts_per_tok=2, + num_shared_experts=1, + num_dense_layers=1, + max_position_embeddings=128, + rms_norm_eps=1e-5, + global_attn_every_n_layers=2, + sliding_window=64, + ) + + +@pytest.fixture +def backend_config(): + return BackendConfig( + linear="torch", + attn="sdpa", + rms_norm="torch", + rope_fusion=False, + ) + + +class TestAfmoeAttention: + def test_local_attention_has_sliding_window(self, tiny_config, backend_config): + attn = AfmoeAttention(tiny_config, layer_idx=0, backend=backend_config) + assert attn.is_local_attention is True + assert attn.sliding_window == tiny_config.sliding_window + + def test_global_attention_no_sliding_window(self, tiny_config, backend_config): + attn = AfmoeAttention(tiny_config, layer_idx=1, backend=backend_config) + assert attn.is_local_attention is False + assert attn.sliding_window is None + + def test_has_gate_proj(self, tiny_config, backend_config): + attn = AfmoeAttention(tiny_config, layer_idx=0, backend=backend_config) + assert hasattr(attn, "gate_proj") + + def test_has_qk_norm(self, tiny_config, backend_config): + attn = AfmoeAttention(tiny_config, layer_idx=0, backend=backend_config) + assert hasattr(attn, "q_norm") + assert hasattr(attn, "k_norm") + + def test_forward_shape(self, tiny_config, backend_config, device): + attn = AfmoeAttention(tiny_config, layer_idx=0, backend=backend_config).to(device).to(torch.float32) + + batch, seq_len = 2, 8 + x = torch.randn(batch, seq_len, tiny_config.hidden_size, device=device) + freqs_cis = torch.randn(batch, seq_len, tiny_config.head_dim, device=device) + + out = attn(x, freqs_cis=freqs_cis) + assert out.shape == (batch, seq_len, tiny_config.hidden_size) + + def test_global_attention_forward_shape(self, tiny_config, backend_config, device): + attn = AfmoeAttention(tiny_config, layer_idx=1, backend=backend_config).to(device).to(torch.float32) + + batch, seq_len = 2, 8 + x = torch.randn(batch, seq_len, tiny_config.hidden_size, device=device) + freqs_cis = torch.randn(batch, seq_len, tiny_config.head_dim, device=device) + + out = attn(x, freqs_cis=freqs_cis) + assert out.shape == (batch, seq_len, tiny_config.hidden_size) diff --git a/tests/unit_tests/models/afmoe/test_afmoe_model.py b/tests/unit_tests/models/afmoe/test_afmoe_model.py new file mode 100644 index 0000000000..767c348754 --- /dev/null +++ b/tests/unit_tests/models/afmoe/test_afmoe_model.py @@ -0,0 +1,197 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import patch + +import pytest +import torch + +from nemo_automodel.components.models.afmoe.config import AfmoeConfig +from nemo_automodel.components.models.afmoe.model import AfmoeForCausalLM, AfmoeModel, Block, _build_moe_config +from nemo_automodel.components.models.common import BackendConfig +from nemo_automodel.components.moe.layers import MLP, MoE + +pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + + +@pytest.fixture +def device(): + if torch.cuda.is_available(): + return torch.device(f"cuda:{torch.cuda.current_device()}") + return torch.device("cpu") + + +@pytest.fixture +def tiny_config(): + return AfmoeConfig( + vocab_size=256, + hidden_size=64, + num_attention_heads=4, + num_key_value_heads=2, + head_dim=16, + num_hidden_layers=4, + intermediate_size=128, + moe_intermediate_size=32, + num_experts=4, + num_experts_per_tok=2, + num_shared_experts=1, + num_dense_layers=1, + max_position_embeddings=128, + rms_norm_eps=1e-5, + rope_theta=10000.0, + score_func="sigmoid", + route_norm=True, + route_scale=2.0, + global_attn_every_n_layers=2, + sliding_window=64, + mup_enabled=False, + n_group=1, + topk_group=1, + ) + + +@pytest.fixture +def backend_config(): + return BackendConfig( + linear="torch", + attn="sdpa", + rms_norm="torch", + experts="torch", + dispatcher="torch", + fake_balanced_gate=False, + enable_hf_state_dict_adapter=False, + ) + + +class TestBlock: + def test_dense_layer_uses_mlp(self, tiny_config, backend_config): + moe_config = _build_moe_config(tiny_config) + block = Block(layer_idx=0, config=tiny_config, moe_config=moe_config, backend=backend_config) + assert isinstance(block.mlp, MLP) + + def test_moe_layer_uses_moe(self, tiny_config, backend_config): + moe_config = _build_moe_config(tiny_config) + block = Block(layer_idx=1, config=tiny_config, moe_config=moe_config, backend=backend_config) + assert isinstance(block.mlp, MoE) + + def test_has_four_norms(self, tiny_config, backend_config): + moe_config = _build_moe_config(tiny_config) + block = Block(layer_idx=0, config=tiny_config, moe_config=moe_config, backend=backend_config) + assert hasattr(block, "input_layernorm") + assert hasattr(block, "post_attention_layernorm") + assert hasattr(block, "pre_mlp_layernorm") + assert hasattr(block, "post_mlp_layernorm") + + def test_forward_shape(self, tiny_config, backend_config, device): + moe_config = _build_moe_config(tiny_config) + block = Block(layer_idx=0, config=tiny_config, moe_config=moe_config, backend=backend_config).to(device) + + batch, seq_len = 2, 8 + x = torch.randn(batch, seq_len, tiny_config.hidden_size, device=device) + freqs_cis = torch.randn(batch, seq_len, tiny_config.head_dim, device=device) + + with ( + patch.object(block.self_attn, "forward", return_value=torch.zeros_like(x)) as mock_attn, + patch.object(block, "_mlp", return_value=torch.zeros_like(x)) as mock_mlp, + ): + out = block(x, freqs_cis=freqs_cis) + + assert out.shape == x.shape + mock_attn.assert_called_once() + mock_mlp.assert_called_once() + + +class TestAfmoeModel: + def test_initialization(self, tiny_config, 
backend_config): + model = AfmoeModel(tiny_config, backend=backend_config) + + assert model.config == tiny_config + assert len(model.layers) == tiny_config.num_hidden_layers + assert model.embed_tokens.num_embeddings == tiny_config.vocab_size + + def test_layer_types_correct(self, tiny_config, backend_config): + model = AfmoeModel(tiny_config, backend=backend_config) + + # layer 0: sliding, layer 1: full, layer 2: sliding, layer 3: full + assert model.layers["0"].self_attn.is_local_attention is True + assert model.layers["1"].self_attn.is_local_attention is False + assert model.layers["2"].self_attn.is_local_attention is True + assert model.layers["3"].self_attn.is_local_attention is False + + def test_dense_vs_moe_layers(self, tiny_config, backend_config): + model = AfmoeModel(tiny_config, backend=backend_config) + + # layer 0 is dense (num_dense_layers=1), layers 1-3 are MoE + assert isinstance(model.layers["0"].mlp, MLP) + assert isinstance(model.layers["1"].mlp, MoE) + assert isinstance(model.layers["2"].mlp, MoE) + assert isinstance(model.layers["3"].mlp, MoE) + + +class TestAfmoeForCausalLM: + def test_forward_returns_logits(self, tiny_config, backend_config, device): + model = AfmoeForCausalLM(tiny_config, backend=backend_config).to(device) + + batch, seq_len = 2, 8 + input_ids = torch.randint(0, tiny_config.vocab_size, (batch, seq_len), device=device) + + with patch.object( + model.model, + "forward", + return_value=torch.randn(batch, seq_len, tiny_config.hidden_size, device=device, dtype=torch.bfloat16), + ): + logits = model(input_ids) + + assert logits.shape == (batch, seq_len, tiny_config.vocab_size) + + def test_state_dict_adapter_created(self, tiny_config): + backend = BackendConfig( + linear="torch", + attn="sdpa", + rms_norm="torch", + experts="torch", + dispatcher="torch", + enable_hf_state_dict_adapter=True, + ) + model = AfmoeForCausalLM(tiny_config, backend=backend) + assert hasattr(model, "state_dict_adapter") + + def test_modelclass_export(self): + from nemo_automodel.components.models.afmoe import model as afmoe_mod + + assert hasattr(afmoe_mod, "ModelClass") + assert afmoe_mod.ModelClass is AfmoeForCausalLM + + def test_from_pretrained_classmethod(self, tiny_config): + with patch.object(AfmoeConfig, "from_pretrained", return_value=tiny_config): + with patch.object(AfmoeForCausalLM, "from_config", wraps=AfmoeForCausalLM.from_config) as mock: + model = AfmoeForCausalLM.from_pretrained("arcee-ai/Trinity-Large-Thinking") + assert isinstance(model, AfmoeForCausalLM) + mock.assert_called_once() + + +class TestBuildMoeConfig: + def test_fields_mapped_correctly(self, tiny_config): + moe_cfg = _build_moe_config(tiny_config) + + assert moe_cfg.dim == tiny_config.hidden_size + assert moe_cfg.inter_dim == tiny_config.intermediate_size + assert moe_cfg.moe_inter_dim == tiny_config.moe_intermediate_size + assert moe_cfg.n_routed_experts == tiny_config.num_experts + assert moe_cfg.n_shared_experts == tiny_config.num_shared_experts + assert moe_cfg.n_activated_experts == tiny_config.num_experts_per_tok + assert moe_cfg.score_func == "sigmoid" + assert moe_cfg.route_scale == tiny_config.route_scale + assert moe_cfg.norm_topk_prob is True + assert moe_cfg.force_e_score_correction_bias is True diff --git a/tests/unit_tests/models/afmoe/test_afmoe_state_dict_adapter.py b/tests/unit_tests/models/afmoe/test_afmoe_state_dict_adapter.py new file mode 100644 index 0000000000..337012cc4a --- /dev/null +++ b/tests/unit_tests/models/afmoe/test_afmoe_state_dict_adapter.py @@ -0,0 +1,189 @@ +# 
Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import Mock + +import pytest +import torch + +from nemo_automodel.components.models.afmoe.config import AfmoeConfig +from nemo_automodel.components.models.afmoe.state_dict_adapter import AfmoeStateDictAdapter +from nemo_automodel.components.models.common import BackendConfig +from nemo_automodel.components.moe.config import MoEConfig + +pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + + +@pytest.fixture +def config(): + cfg = Mock(spec=AfmoeConfig) + cfg.num_hidden_layers = 2 + cfg.hidden_size = 64 + cfg.intermediate_size = 128 + cfg.moe_intermediate_size = 32 + cfg.num_attention_heads = 4 + cfg.num_key_value_heads = 2 + cfg.num_experts = 4 + cfg.num_experts_per_tok = 2 + cfg.num_shared_experts = 1 + cfg.num_dense_layers = 1 + return cfg + + +@pytest.fixture +def moe_config(): + return MoEConfig( + dim=64, + inter_dim=128, + moe_inter_dim=32, + n_routed_experts=4, + n_shared_experts=1, + n_activated_experts=2, + n_expert_groups=1, + n_limited_groups=1, + train_gate=True, + gate_bias_update_factor=0.001, + score_func="sigmoid", + route_scale=2.0, + aux_loss_coeff=0.0, + norm_topk_prob=True, + expert_bias=False, + router_bias=False, + expert_activation="swiglu", + force_e_score_correction_bias=True, + shared_expert_inter_dim=32, + ) + + +@pytest.fixture +def backend(): + return BackendConfig( + linear="torch", + attn="sdpa", + rms_norm="torch", + experts="torch", + dispatcher="torch", + ) + + +@pytest.fixture +def adapter(config, moe_config, backend): + return AfmoeStateDictAdapter(config, moe_config, backend, dtype=torch.bfloat16) + + +def _make_hf_expert_state_dict(n_layers=2, n_experts=4, hidden=64, moe_inter=32, num_dense=1): + """Create a minimal HF-format state dict with router, experts, and expert_bias.""" + sd = {} + for layer_idx in range(n_layers): + prefix = f"model.layers.{layer_idx}" + if layer_idx >= num_dense: + # Router gate + sd[f"{prefix}.mlp.router.gate.weight"] = torch.randn(n_experts, hidden) + # Expert bias + sd[f"{prefix}.mlp.expert_bias"] = torch.zeros(n_experts) + # Per-expert weights + for e in range(n_experts): + sd[f"{prefix}.mlp.experts.{e}.gate_proj.weight"] = torch.randn(moe_inter, hidden) + sd[f"{prefix}.mlp.experts.{e}.up_proj.weight"] = torch.randn(moe_inter, hidden) + sd[f"{prefix}.mlp.experts.{e}.down_proj.weight"] = torch.randn(hidden, moe_inter) + # Shared expert + sd[f"{prefix}.mlp.shared_experts.gate_proj.weight"] = torch.randn(moe_inter, hidden) + sd[f"{prefix}.mlp.shared_experts.up_proj.weight"] = torch.randn(moe_inter, hidden) + sd[f"{prefix}.mlp.shared_experts.down_proj.weight"] = torch.randn(hidden, moe_inter) + else: + # Dense MLP + sd[f"{prefix}.mlp.gate_proj.weight"] = torch.randn(128, hidden) + sd[f"{prefix}.mlp.up_proj.weight"] = torch.randn(128, hidden) + sd[f"{prefix}.mlp.down_proj.weight"] = torch.randn(hidden, 128) + return sd + + +class TestAfmoeStateDictAdapter: + def 
test_router_key_renamed_from_hf(self, adapter): + hf_sd = _make_hf_expert_state_dict() + nemo_sd = adapter.from_hf(hf_sd) + + # Router gate should be renamed + assert "model.layers.1.mlp.gate.weight" in nemo_sd + assert "model.layers.1.mlp.router.gate.weight" not in nemo_sd + + def test_expert_bias_renamed_from_hf(self, adapter): + hf_sd = _make_hf_expert_state_dict() + nemo_sd = adapter.from_hf(hf_sd) + + assert "model.layers.1.mlp.gate.e_score_correction_bias" in nemo_sd + assert "model.layers.1.mlp.expert_bias" not in nemo_sd + + def test_experts_merged_from_hf(self, adapter): + hf_sd = _make_hf_expert_state_dict() + nemo_sd = adapter.from_hf(hf_sd) + + # Per-expert keys should be merged into grouped format + assert "model.layers.1.mlp.experts.gate_and_up_projs" in nemo_sd + assert "model.layers.1.mlp.experts.down_projs" in nemo_sd + # Individual expert keys should be gone + assert "model.layers.1.mlp.experts.0.gate_proj.weight" not in nemo_sd + + def test_experts_merged_shape(self, adapter): + hf_sd = _make_hf_expert_state_dict() + nemo_sd = adapter.from_hf(hf_sd) + + gate_up = nemo_sd["model.layers.1.mlp.experts.gate_and_up_projs"] + down = nemo_sd["model.layers.1.mlp.experts.down_projs"] + # gate_and_up: [n_experts, dim, 2*moe_inter] + assert gate_up.shape == (4, 64, 64) # 4 experts, dim=64, 2*32=64 + # down: [n_experts, moe_inter, dim] + assert down.shape == (4, 32, 64) + + def test_shared_experts_pass_through(self, adapter): + hf_sd = _make_hf_expert_state_dict() + nemo_sd = adapter.from_hf(hf_sd) + + # Shared experts should pass through unchanged + assert "model.layers.1.mlp.shared_experts.gate_proj.weight" in nemo_sd + assert "model.layers.1.mlp.shared_experts.up_proj.weight" in nemo_sd + assert "model.layers.1.mlp.shared_experts.down_proj.weight" in nemo_sd + + def test_dense_mlp_pass_through(self, adapter): + hf_sd = _make_hf_expert_state_dict() + nemo_sd = adapter.from_hf(hf_sd) + + # Dense layer (layer 0) MLP keys pass through + assert "model.layers.0.mlp.gate_proj.weight" in nemo_sd + assert "model.layers.0.mlp.up_proj.weight" in nemo_sd + assert "model.layers.0.mlp.down_proj.weight" in nemo_sd + + def test_to_hf_reverses_router_rename(self, adapter): + nemo_sd = { + "model.layers.1.mlp.gate.weight": torch.randn(4, 64), + "model.layers.1.mlp.gate.e_score_correction_bias": torch.zeros(4), + "model.layers.0.mlp.gate_proj.weight": torch.randn(128, 64), + } + hf_sd = adapter.to_hf(nemo_sd) + + assert "model.layers.1.mlp.router.gate.weight" in hf_sd + assert "model.layers.1.mlp.expert_bias" in hf_sd + assert "model.layers.0.mlp.gate_proj.weight" in hf_sd + + def test_to_hf_splits_experts(self, adapter): + nemo_sd = { + "model.layers.1.mlp.experts.gate_and_up_projs": torch.randn(4, 64, 64), + "model.layers.1.mlp.experts.down_projs": torch.randn(4, 32, 64), + } + hf_sd = adapter.to_hf(nemo_sd) + + for e in range(4): + assert f"model.layers.1.mlp.experts.{e}.gate_proj.weight" in hf_sd + assert f"model.layers.1.mlp.experts.{e}.up_proj.weight" in hf_sd + assert f"model.layers.1.mlp.experts.{e}.down_proj.weight" in hf_sd From 244e199180d8a994b6e2116027d3f5b27c2fc281 Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis Date: Thu, 2 Apr 2026 11:18:16 -0700 Subject: [PATCH 2/3] update tests Signed-off-by: Alexandros Koumparoulis --- .../models/afmoe/test_afmoe_layers.py | 45 +++++++++- .../models/afmoe/test_afmoe_model.py | 85 ++++++++++++++++++- .../afmoe/test_afmoe_state_dict_adapter.py | 39 ++++++--- 3 files changed, 153 insertions(+), 16 deletions(-) diff --git 
a/tests/unit_tests/models/afmoe/test_afmoe_layers.py b/tests/unit_tests/models/afmoe/test_afmoe_layers.py index 74740ace5e..6133f91d8c 100644 --- a/tests/unit_tests/models/afmoe/test_afmoe_layers.py +++ b/tests/unit_tests/models/afmoe/test_afmoe_layers.py @@ -80,21 +80,58 @@ def test_has_qk_norm(self, tiny_config, backend_config): assert hasattr(attn, "k_norm") def test_forward_shape(self, tiny_config, backend_config, device): - attn = AfmoeAttention(tiny_config, layer_idx=0, backend=backend_config).to(device).to(torch.float32) + attn = AfmoeAttention(tiny_config, layer_idx=0, backend=backend_config).to(device) batch, seq_len = 2, 8 - x = torch.randn(batch, seq_len, tiny_config.hidden_size, device=device) + x = torch.randn(batch, seq_len, tiny_config.hidden_size, device=device, dtype=torch.bfloat16) freqs_cis = torch.randn(batch, seq_len, tiny_config.head_dim, device=device) out = attn(x, freqs_cis=freqs_cis) assert out.shape == (batch, seq_len, tiny_config.hidden_size) def test_global_attention_forward_shape(self, tiny_config, backend_config, device): - attn = AfmoeAttention(tiny_config, layer_idx=1, backend=backend_config).to(device).to(torch.float32) + attn = AfmoeAttention(tiny_config, layer_idx=1, backend=backend_config).to(device) batch, seq_len = 2, 8 - x = torch.randn(batch, seq_len, tiny_config.hidden_size, device=device) + x = torch.randn(batch, seq_len, tiny_config.hidden_size, device=device, dtype=torch.bfloat16) freqs_cis = torch.randn(batch, seq_len, tiny_config.head_dim, device=device) out = attn(x, freqs_cis=freqs_cis) assert out.shape == (batch, seq_len, tiny_config.hidden_size) + + +class TestAfmoeAttentionParity: + def test_rope_conditional_local_vs_global(self, tiny_config, backend_config, device): + """Local attention (with RoPE) and global attention (without) must diverge given shared weights.""" + torch.manual_seed(42) + local_attn = AfmoeAttention(tiny_config, layer_idx=0, backend=backend_config).to(device) + global_attn = AfmoeAttention(tiny_config, layer_idx=1, backend=backend_config).to(device) + global_attn.load_state_dict(local_attn.state_dict()) + + batch, seq_len = 2, 8 + x = torch.randn(batch, seq_len, tiny_config.hidden_size, device=device, dtype=torch.bfloat16) + freqs_cis = torch.randn(batch, seq_len, tiny_config.head_dim, device=device) + + with torch.no_grad(): + local_out = local_attn(x, freqs_cis=freqs_cis) + global_out = global_attn(x, freqs_cis=freqs_cis) + + max_diff = (local_out - global_out).abs().max().item() + assert max_diff > 0.01, f"RoPE should cause divergence, but max_diff={max_diff}" + + def test_qk_norm_reduces_head_variance(self, tiny_config, backend_config, device): + """Per-head QK RMSNorm should equalize magnitudes across heads.""" + attn = AfmoeAttention(tiny_config, layer_idx=0, backend=backend_config).to(device) + + batch, seq_len = 1, 4 + q = torch.randn( + batch, seq_len, tiny_config.num_attention_heads, tiny_config.head_dim, device=device, dtype=torch.bfloat16 + ) + q[:, :, 0, :] *= 10.0 # Make first head 10x larger + + with torch.no_grad(): + q_normed = attn.q_norm(q) + + pre_var = q.norm(dim=-1).var(dim=-1).mean().item() + post_var = q_normed.norm(dim=-1).var(dim=-1).mean().item() + assert post_var < pre_var, "QK norm should reduce variance across heads" diff --git a/tests/unit_tests/models/afmoe/test_afmoe_model.py b/tests/unit_tests/models/afmoe/test_afmoe_model.py index 767c348754..2bf8c1164b 100644 --- a/tests/unit_tests/models/afmoe/test_afmoe_model.py +++ b/tests/unit_tests/models/afmoe/test_afmoe_model.py @@ -20,7 
+20,8 @@ from nemo_automodel.components.models.afmoe.config import AfmoeConfig from nemo_automodel.components.models.afmoe.model import AfmoeForCausalLM, AfmoeModel, Block, _build_moe_config from nemo_automodel.components.models.common import BackendConfig -from nemo_automodel.components.moe.layers import MLP, MoE +from nemo_automodel.components.moe.config import MoEConfig +from nemo_automodel.components.moe.layers import MLP, Gate, MoE pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @@ -195,3 +196,85 @@ def test_fields_mapped_correctly(self, tiny_config): assert moe_cfg.route_scale == tiny_config.route_scale assert moe_cfg.norm_topk_prob is True assert moe_cfg.force_e_score_correction_bias is True + + +class TestDualNormParity: + def test_manual_trace_matches_forward(self, tiny_config, backend_config, device): + """Manual 4-norm residual trace must be bit-identical to Block.forward().""" + torch.manual_seed(42) + moe_config = _build_moe_config(tiny_config) + block = Block(layer_idx=0, config=tiny_config, moe_config=moe_config, backend=backend_config).to(device) + block.eval() + + batch, seq_len = 1, 4 + x = torch.randn(batch, seq_len, tiny_config.hidden_size, device=device, dtype=torch.bfloat16) + freqs_cis = torch.randn(batch, seq_len, tiny_config.head_dim, device=device) + + with torch.no_grad(): + # Manual trace: attention sublayer + residual = x + h = block.input_layernorm(x) + h = block.self_attn(h, freqs_cis=freqs_cis) + h = block.post_attention_layernorm(h) + after_attn = residual + h + + # Manual trace: MLP sublayer + residual = after_attn + h = block.pre_mlp_layernorm(after_attn) + h = block._mlp(h, padding_mask=None) + h = block.post_mlp_layernorm(h) + expected = residual + h + + # Block forward + actual = block(x, freqs_cis=freqs_cis) + + torch.testing.assert_close(actual, expected, rtol=0, atol=0) + + +class TestMoeRoutingParity: + def test_sigmoid_norm_scale(self, device): + """Manual sigmoid -> topk -> normalize -> scale must match Gate.forward().""" + torch.manual_seed(42) + + moe_config = MoEConfig( + dim=64, + inter_dim=128, + moe_inter_dim=32, + n_routed_experts=4, + n_shared_experts=1, + n_activated_experts=2, + n_expert_groups=1, + n_limited_groups=1, + train_gate=False, + gate_bias_update_factor=0.0, + score_func="sigmoid", + route_scale=2.0, + aux_loss_coeff=0.0, + norm_topk_prob=True, + force_e_score_correction_bias=True, + dtype=torch.bfloat16, + ) + + gate = Gate(moe_config).to(device) + torch.manual_seed(123) + gate.weight.data = torch.randn(4, 64, device=device, dtype=torch.bfloat16) + + x = torch.randn(8, 64, device=device, dtype=torch.bfloat16) # 8 tokens + token_mask = torch.ones(8, dtype=torch.bool, device=device) + + with torch.no_grad(): + weights, indices, aux_loss = gate(x, token_mask, cp_mesh=None) + + # Manual reference: sigmoid -> bias -> topk -> gather original -> normalize -> scale + with torch.no_grad(): + scores = torch.sigmoid(x @ gate.weight.data.T) # [8, 4] + original_scores = scores.clone() + biased = scores + gate.e_score_correction_bias # zeros, no-op + manual_idx = torch.topk(biased, 2, dim=-1)[1] + manual_w = original_scores.gather(1, manual_idx) + manual_w = manual_w / (manual_w.sum(dim=-1, keepdim=True) + 1e-20) + manual_w = manual_w * 2.0 + + assert torch.equal(indices, manual_idx), "Expert indices mismatch" + torch.testing.assert_close(weights, manual_w, rtol=1e-3, atol=1e-3) + assert aux_loss is None diff --git a/tests/unit_tests/models/afmoe/test_afmoe_state_dict_adapter.py 
b/tests/unit_tests/models/afmoe/test_afmoe_state_dict_adapter.py index 337012cc4a..d7e5060915 100644 --- a/tests/unit_tests/models/afmoe/test_afmoe_state_dict_adapter.py +++ b/tests/unit_tests/models/afmoe/test_afmoe_state_dict_adapter.py @@ -82,30 +82,30 @@ def adapter(config, moe_config, backend): return AfmoeStateDictAdapter(config, moe_config, backend, dtype=torch.bfloat16) -def _make_hf_expert_state_dict(n_layers=2, n_experts=4, hidden=64, moe_inter=32, num_dense=1): +def _make_hf_expert_state_dict(n_layers=2, n_experts=4, hidden=64, moe_inter=32, num_dense=1, dtype=torch.bfloat16): """Create a minimal HF-format state dict with router, experts, and expert_bias.""" sd = {} for layer_idx in range(n_layers): prefix = f"model.layers.{layer_idx}" if layer_idx >= num_dense: # Router gate - sd[f"{prefix}.mlp.router.gate.weight"] = torch.randn(n_experts, hidden) + sd[f"{prefix}.mlp.router.gate.weight"] = torch.randn(n_experts, hidden, dtype=dtype) # Expert bias sd[f"{prefix}.mlp.expert_bias"] = torch.zeros(n_experts) # Per-expert weights for e in range(n_experts): - sd[f"{prefix}.mlp.experts.{e}.gate_proj.weight"] = torch.randn(moe_inter, hidden) - sd[f"{prefix}.mlp.experts.{e}.up_proj.weight"] = torch.randn(moe_inter, hidden) - sd[f"{prefix}.mlp.experts.{e}.down_proj.weight"] = torch.randn(hidden, moe_inter) + sd[f"{prefix}.mlp.experts.{e}.gate_proj.weight"] = torch.randn(moe_inter, hidden, dtype=dtype) + sd[f"{prefix}.mlp.experts.{e}.up_proj.weight"] = torch.randn(moe_inter, hidden, dtype=dtype) + sd[f"{prefix}.mlp.experts.{e}.down_proj.weight"] = torch.randn(hidden, moe_inter, dtype=dtype) # Shared expert - sd[f"{prefix}.mlp.shared_experts.gate_proj.weight"] = torch.randn(moe_inter, hidden) - sd[f"{prefix}.mlp.shared_experts.up_proj.weight"] = torch.randn(moe_inter, hidden) - sd[f"{prefix}.mlp.shared_experts.down_proj.weight"] = torch.randn(hidden, moe_inter) + sd[f"{prefix}.mlp.shared_experts.gate_proj.weight"] = torch.randn(moe_inter, hidden, dtype=dtype) + sd[f"{prefix}.mlp.shared_experts.up_proj.weight"] = torch.randn(moe_inter, hidden, dtype=dtype) + sd[f"{prefix}.mlp.shared_experts.down_proj.weight"] = torch.randn(hidden, moe_inter, dtype=dtype) else: # Dense MLP - sd[f"{prefix}.mlp.gate_proj.weight"] = torch.randn(128, hidden) - sd[f"{prefix}.mlp.up_proj.weight"] = torch.randn(128, hidden) - sd[f"{prefix}.mlp.down_proj.weight"] = torch.randn(hidden, 128) + sd[f"{prefix}.mlp.gate_proj.weight"] = torch.randn(128, hidden, dtype=dtype) + sd[f"{prefix}.mlp.up_proj.weight"] = torch.randn(128, hidden, dtype=dtype) + sd[f"{prefix}.mlp.down_proj.weight"] = torch.randn(hidden, 128, dtype=dtype) return sd @@ -187,3 +187,20 @@ def test_to_hf_splits_experts(self, adapter): assert f"model.layers.1.mlp.experts.{e}.gate_proj.weight" in hf_sd assert f"model.layers.1.mlp.experts.{e}.up_proj.weight" in hf_sd assert f"model.layers.1.mlp.experts.{e}.down_proj.weight" in hf_sd + + def test_roundtrip_preserves_all_values(self, adapter): + """HF -> NeMo -> HF round-trip must preserve exact tensor values.""" + torch.manual_seed(42) + hf_sd = _make_hf_expert_state_dict() + originals = {k: v.clone() for k, v in hf_sd.items()} + + nemo_sd = adapter.from_hf(hf_sd) + roundtrip_sd = adapter.to_hf(nemo_sd) + + assert set(roundtrip_sd.keys()) == set(originals.keys()), ( + f"Missing: {set(originals.keys()) - set(roundtrip_sd.keys())}, " + f"Extra: {set(roundtrip_sd.keys()) - set(originals.keys())}" + ) + for key in originals: + max_diff = (originals[key].float() - roundtrip_sd[key].float()).abs().max().item() + 
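+            # the merge/split round-trip is expected to be value-preserving,
+            # so require bit-exact equality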
assert max_diff == 0.0, f"Round-trip mismatch for {key}: max_diff={max_diff}" From 3741031b86f4b2279b9868cb5ccc583dda914928 Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis Date: Thu, 2 Apr 2026 12:16:04 -0700 Subject: [PATCH 3/3] fix Signed-off-by: Alexandros Koumparoulis --- tests/unit_tests/models/afmoe/test_afmoe_model.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/unit_tests/models/afmoe/test_afmoe_model.py b/tests/unit_tests/models/afmoe/test_afmoe_model.py index 2bf8c1164b..abbbf4d130 100644 --- a/tests/unit_tests/models/afmoe/test_afmoe_model.py +++ b/tests/unit_tests/models/afmoe/test_afmoe_model.py @@ -208,7 +208,10 @@ def test_manual_trace_matches_forward(self, tiny_config, backend_config, device) batch, seq_len = 1, 4 x = torch.randn(batch, seq_len, tiny_config.hidden_size, device=device, dtype=torch.bfloat16) - freqs_cis = torch.randn(batch, seq_len, tiny_config.head_dim, device=device) + if backend_config.rope_fusion: + freqs_cis = torch.randn(seq_len, 1, 1, tiny_config.head_dim, device=device) + else: + freqs_cis = torch.randn(batch, seq_len, tiny_config.head_dim, device=device) with torch.no_grad(): # Manual trace: attention sublayer