1 change: 1 addition & 0 deletions docs/model-coverage/llm.md
@@ -22,6 +22,7 @@ The table below lists the main architectures we test against (FSDP2 combined with …

| Architecture | Models | Example HF Models |
|---------------------------------------|---------------------------------------|-----------------------------------------------------------------------------------|
| `AfmoeForCausalLM` | Afmoe (Arcee Fusion MoE) | `arcee-ai/Trinity-Large-Thinking` — example recipe: [trinity_large_thinking_sft.yaml](../../examples/llm_finetune/afmoe/trinity_large_thinking_sft.yaml) |
| `AquilaForCausalLM` | Aquila, Aquila2 | `BAAI/Aquila-7B`, `BAAI/AquilaChat-7B`, etc. |
| `BaiChuanForCausalLM` | Baichuan2, Baichuan | `baichuan-inc/Baichuan2-13B-Chat`, `baichuan-inc/Baichuan-7B`, etc. — example recipes: [baichuan_2_7b_squad.yaml](../../examples/llm_finetune/baichuan/baichuan_2_7b_squad.yaml), [baichuan_2_7b_squad_peft.yaml](../../examples/llm_finetune/baichuan/baichuan_2_7b_squad_peft.yaml) |
| `BambaForCausalLM` | Bamba | `ibm-ai-platform/Bamba-9B` |
88 changes: 88 additions & 0 deletions examples/llm_finetune/afmoe/trinity_large_thinking_sft.yaml
@@ -0,0 +1,88 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Afmoe (Arcee Trinity-Large-Thinking) SFT example
# 256 experts, 4 active per token, 60 layers (large-scale model)
#
# To run this recipe:
# automodel examples/llm_finetune/afmoe/trinity_large_thinking_sft.yaml --nproc-per-node 8
# Adjust --nproc-per-node to the number of GPUs available on your machine.

recipe: TrainFinetuneRecipeForNextTokenPrediction

step_scheduler:
  global_batch_size: 32
  local_batch_size: 1
  ckpt_every_steps: 200
  val_every_steps: 100
  num_epochs: 1

dist_env:
  backend: nccl
  timeout_minutes: 10

rng:
  _target_: nemo_automodel.components.training.rng.StatefulRNG
  seed: 1111
  ranked: true

model:
  _target_: nemo_automodel.NeMoAutoModelForCausalLM.from_pretrained
  pretrained_model_name_or_path: arcee-ai/Trinity-Large-Thinking

checkpoint:
  enabled: false
  checkpoint_dir: checkpoints/
  model_save_format: safetensors
  save_consolidated: false

distributed:
  strategy: fsdp2
  dp_size: none
  tp_size: 1
  cp_size: 1
  sequence_parallel: false

loss_fn:
  _target_: nemo_automodel.components.loss.masked_ce.MaskedCrossEntropy

dataset:
  _target_: nemo_automodel.components.datasets.llm.hellaswag.HellaSwag
  path_or_dataset: rowan/hellaswag
  split: train

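# A packed_sequence_size of 0 leaves sequence packing disabled.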
packed_sequence:
  packed_sequence_size: 0

dataloader:
  _target_: torchdata.stateful_dataloader.StatefulDataLoader
  collate_fn: nemo_automodel.components.datasets.utils.default_collater
  shuffle: false

validation_dataset:
  _target_: nemo_automodel.components.datasets.llm.hellaswag.HellaSwag
  path_or_dataset: rowan/hellaswag
  split: validation
  num_samples_limit: 64

validation_dataloader:
  _target_: torchdata.stateful_dataloader.StatefulDataLoader
  collate_fn: nemo_automodel.components.datasets.utils.default_collater

optimizer:
  _target_: torch.optim.Adam
  betas: [0.9, 0.999]
  eps: 1e-8
  lr: 1.0e-5
  weight_decay: 0
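For readers unfamiliar with the `_target_` convention used throughout this recipe: each such entry names a dotted import path that the recipe runner resolves and calls with the sibling keys as keyword arguments. A minimal sketch of that Hydra-style pattern, assuming an illustrative `instantiate` helper (not NeMo Automodel's actual loader):

```python
# Minimal sketch of Hydra-style `_target_` resolution; `instantiate` is a
# hypothetical helper, not NeMo Automodel's actual implementation.
import importlib


def instantiate(cfg: dict, **overrides):
    # Split "torch.optim.Adam" into module path and attribute name.
    module_path, _, attr = cfg["_target_"].rpartition(".")
    target = getattr(importlib.import_module(module_path), attr)
    # The remaining keys become keyword arguments for the target callable.
    kwargs = {k: v for k, v in cfg.items() if k != "_target_"}
    kwargs.update(overrides)
    return target(**kwargs)


optimizer_cfg = {"_target_": "torch.optim.Adam", "lr": 1.0e-5, "weight_decay": 0}
# In the recipe, model parameters would be supplied by the trainer:
# optimizer = instantiate(optimizer_cfg, params=model.parameters())
```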
5 changes: 5 additions & 0 deletions nemo_automodel/_transformers/registry.py
@@ -31,6 +31,10 @@
# downstream code to classify model archs without importing them.
MODEL_ARCH_MAPPING = OrderedDict(
    [
        (
            "AfmoeForCausalLM",
            ("nemo_automodel.components.models.afmoe.model", "AfmoeForCausalLM"),
        ),
        (
            "BaichuanForCausalLM",
            ("nemo_automodel.components.models.baichuan.model", "BaichuanForCausalLM"),
@@ -154,6 +158,7 @@
# checkpoint config.json. Registered eagerly with AutoConfig so that
# AutoConfig.from_pretrained can resolve them without trust_remote_code.
_CUSTOM_CONFIG_REGISTRATIONS: Dict[str, Tuple[str, str]] = {
"afmoe": ("nemo_automodel.components.models.afmoe.config", "AfmoeConfig"),
"baichuan": ("nemo_automodel.components.models.baichuan.configuration", "BaichuanConfig"),
"kimi_k25": ("nemo_automodel.components.models.kimi_k25_vl.model", "KimiK25VLConfig"),
"kimi_vl": ("nemo_automodel.components.models.kimivl.model", "KimiVLConfig"),
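As a reading aid, a hedged sketch of how a lazy mapping like `MODEL_ARCH_MAPPING` is typically consumed: the module/class pair is imported only when the architecture is actually requested (`resolve_arch` below is hypothetical, not part of the registry's public API):

```python
# Hypothetical consumer of a MODEL_ARCH_MAPPING-style entry; the deferred
# import keeps model modules out of memory until an architecture is needed.
import importlib

ARCH_ENTRY = ("nemo_automodel.components.models.afmoe.model", "AfmoeForCausalLM")


def resolve_arch(entry):
    module_path, class_name = entry
    module = importlib.import_module(module_path)  # deferred until first use
    return getattr(module, class_name)

# model_cls = resolve_arch(ARCH_ENTRY)  # requires nemo_automodel to be installed
```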
17 changes: 17 additions & 0 deletions nemo_automodel/components/models/afmoe/__init__.py
@@ -0,0 +1,17 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from nemo_automodel.components.models.afmoe.model import AfmoeForCausalLM

__all__ = ["AfmoeForCausalLM"]
123 changes: 123 additions & 0 deletions nemo_automodel/components/models/afmoe/config.py
@@ -0,0 +1,123 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_rope_utils import rope_config_validation


class AfmoeConfig(PretrainedConfig):
"""Configuration for the Afmoe (Arcee Fusion MoE) model.

This is a Mixture-of-Experts model with hybrid sliding-window / full attention,
gated attention output, QK normalization, and dual pre/post normalization.
"""

model_type = "afmoe"

def __init__(
self,
num_hidden_layers: int = 32,
vocab_size: int = 200192,
hidden_size: int = 2048,
intermediate_size: int = 6144,
moe_intermediate_size=1408,
num_dense_layers=1,
num_attention_heads=16,
num_key_value_heads=None,
head_dim=128,
hidden_act="silu",
max_position_embeddings=16384,
initializer_range=0.02,
rms_norm_eps=1e-5,
use_cache=True,
tie_word_embeddings=False,
rope_theta=10000.0,
rope_scaling=None,
num_experts=64,
num_experts_per_tok=6,
num_shared_experts=2,
num_expert_groups=1,
num_limited_groups=1,
score_func="sigmoid",
route_norm=True,
route_scale=1.0,
global_attn_every_n_layers=4,
sliding_window=1024,
mup_enabled=False,
layer_types=None,
attention_dropout: float = 0.0,
n_group: int = 1,
topk_group: int = 1,
load_balance_coeff: float = 0.0,
**kwargs,
):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_dense_layers = num_dense_layers
        self.num_attention_heads = num_attention_heads
        self.head_dim = head_dim
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling

        # MoE specific
        self.moe_intermediate_size = moe_intermediate_size
        self.num_experts_per_tok = num_experts_per_tok
        self.n_group = n_group
        self.topk_group = topk_group
        self.num_experts = num_experts
        self.num_shared_experts = num_shared_experts
        self.num_expert_groups = num_expert_groups
        self.num_limited_groups = num_limited_groups
        self.score_func = score_func
        self.route_norm = route_norm
        self.route_scale = route_scale
        self.load_balance_coeff = load_balance_coeff

        # Attention specific
        self.attention_dropout = attention_dropout
        self.global_attn_every_n_layers = global_attn_every_n_layers
        self.sliding_window = sliding_window
        self.layer_types = layer_types
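        # Every global_attn_every_n_layers-th layer uses full attention;
        # all other layers fall back to sliding-window attention.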
        if self.layer_types is None:
            self.layer_types = [
                "sliding_attention" if bool((i + 1) % global_attn_every_n_layers) else "full_attention"
                for i in range(self.num_hidden_layers)
            ]

        # muP specific
        self.mup_enabled = mup_enabled

        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads

        # Validate rope configs
        if self.rope_scaling is not None and "type" in self.rope_scaling:
            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
        rope_config_validation(self)

        super().__init__(
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )


__all__ = ["AfmoeConfig"]