
bug: MoE expert LoRA checkpoints incompatible with HF PEFT and vLLM (Nemotron Super 120B) #1814

@pst2154

Description

Bug Description

Training LoRA on Nemotron Super 120B with AutoModel produces an adapter that cannot be merged or served: the failure is silent for MoE layers. The adapter loads fine for attention layers, but MoE expert LoRA weights are saved in a format that neither HF PEFT nor vLLM can read.

The Problem (simple version)

AutoModel stores MoE experts as one big grouped tensor. HF and vLLM store them as separate per-expert modules. When LoRA is applied and saved, AutoModel produces grouped LoRA keys that don't exist in HF or vLLM's model:

```
AutoModel saves:
  experts.lora_gate_and_up_A        shape: [512, 4096, 8]   ← one 3D tensor for ALL experts

HF PEFT expects:
  experts.0.up_proj.lora_A.weight   shape: [8, 4096]        ← one 2D tensor PER expert
  experts.1.up_proj.lora_A.weight   shape: [8, 4096]
  experts.2.up_proj.lora_A.weight   shape: [8, 4096]
  ...× 512 experts
```

4 grouped tensors vs 2048 per-expert tensors. Completely different names, shapes, and count.
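The conversion itself is mechanical once the layout is known: slice the grouped tensor along the expert dimension and transpose each slice into PEFT's `[out, in]` layout. A minimal sketch with toy sizes (the transpose direction is inferred from the shapes above; the real split must also separate the fused gate/up halves, which this issue's shape dump doesn't fully pin down):

```python
import torch

# Toy stand-in for the grouped AutoModel tensor (real sizes: 512 experts,
# in_dim 4096, rank 8 -- shrunk here for illustration).
num_experts, in_dim, rank = 4, 6, 2
grouped_A = torch.randn(num_experts, in_dim, rank)   # lora_gate_and_up_A

# PEFT stores lora_A.weight as [rank, in_dim], so each expert slice
# [in_dim, rank] gets transposed. (Assumed layout -- verify against the
# base-weight handling in nemotron_v3/state_dict_adapter.py.)
per_expert = {
    f"experts.{e}.up_proj.lora_A.weight": grouped_A[e].T.contiguous()
    for e in range(num_experts)
}

assert per_expert["experts.0.up_proj.lora_A.weight"].shape == (rank, in_dim)
```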

What works and what doesn't

| Layer type | LoRA keys match HF PEFT? | `merge_and_unload()` works? |
|---|---|---|
| Attention (`q/k/v/o_proj`) | ✅ Yes | ✅ Yes |
| MLP (`up/down/gate_proj`) | ✅ Yes | ✅ Yes |
| MoE experts | ❌ No | ❌ No |

Why it happens

The PEFT save path in `stateful_wrappers.py` (line ~288) just collects raw parameter FQNs and adds the `base_model.model.` prefix. It never calls `state_dict_adapter.to_hf()`, which already knows how to split grouped expert tensors into per-expert format.

For base model weights, the adapter handles this conversion correctly; for LoRA weights, it is skipped entirely.

Reproduce

```python
# After training LoRA on Nemotron Super 120B with AutoModel:
from safetensors.torch import load_file
sd = load_file("checkpoints/adapter_model.safetensors")

# MoE expert LoRA keys look like this (WRONG for HF PEFT):
for k in sorted(sd):
    if "expert" in k:
        print(k, sd[k].shape)
# base_model.model.model.layers.79.mixer.experts.lora_gate_and_up_A  [512, 4096, 8]
# base_model.model.model.layers.79.mixer.experts.lora_gate_and_up_B  [512, 8, 2688]
# ...

# HF PEFT expects these (which don't exist in the file):
# base_model.model.model.layers.79.mixer.experts.0.up_proj.lora_A.weight  [8, 4096]
# base_model.model.model.layers.79.mixer.experts.0.up_proj.lora_B.weight  [2688, 8]
```
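An adapter file can be screened for this problem before handing it to PEFT: grouped keys end in `experts.lora_*` with no per-expert index, while PEFT-style keys carry a numeric expert index and a `lora_A`/`lora_B.weight` suffix. A sketch of such a check (key names taken from the dump above; the state dict is stubbed so the snippet is self-contained):

```python
import re

# Stub state-dict keys as dumped above (values elided; only names matter here).
keys = [
    "base_model.model.model.layers.79.mixer.experts.lora_gate_and_up_A",
    "base_model.model.model.layers.79.mixer.experts.lora_gate_and_up_B",
    "base_model.model.model.layers.79.self_attn.q_proj.lora_A.weight",
]

# PEFT-compatible expert keys look like "experts.<idx>.<proj>.lora_A.weight".
peft_expert = re.compile(r"experts\.\d+\.\w+\.lora_[AB]\.weight$")
grouped = [k for k in keys if "experts.lora_" in k]
compatible = [k for k in keys if peft_expert.search(k) or "experts" not in k]

print(f"{len(grouped)} grouped expert keys PEFT cannot load: {grouped}")
```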

Fix

In `nemo_automodel/components/checkpoint/stateful_wrappers.py`, apply `state_dict_adapter` to LoRA weights the same way it's already applied to base weights:

Save path (~line 288):

```python
        if self.is_peft and not _has_quantized_params(self.model[0]):
            # NEW: convert grouped expert LoRA tensors to per-expert HF format
            adapter = getattr(self.model[0], "state_dict_adapter", None)
            if adapter:
                model_state_dict = adapter.to_hf(model_state_dict)
            _add_outer_prefix(model_state_dict, "base_model.model.")
            _rename_dora_keys_to_hf(model_state_dict)
```

Load path (~line 311):

```python
        if self.is_peft:
            _drop_outer_prefix(state_dict, "base_model.model.")
            # NEW: convert per-expert HF format back to grouped
            adapter = getattr(self.model[0], "state_dict_adapter", None)
            if adapter:
                state_dict = adapter.from_hf(state_dict)
            _rename_dora_keys_from_hf(state_dict)
```

The `state_dict_adapter` is a no-op for models without MoE expert splitting, so this is safe for all models.

Note: `state_dict_adapter.to_hf()` currently only handles base weight tensors (`gate_and_up_projs` → per-expert `up_proj.weight`). It would also need to handle LoRA tensors (`lora_gate_and_up_A/B` → per-expert `up_proj.lora_A/B.weight`), likely via an extension to `convert_single_tensor_to_hf()` in `nemotron_v3/state_dict_adapter.py`.
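The LoRA extension would mostly be a key/shape dispatch layered on the existing split logic. A hypothetical sketch of the key expansion (function name, key templates, and the `up_proj`-only mapping are assumptions made for illustration, not the real adapter API; the corresponding `gate_proj` keys and the tensor splitting itself are omitted because the fused gate/up column layout isn't documented in this issue):

```python
# Hypothetical mapping from a grouped LoRA key to its per-expert HF key
# template. NOT the real convert_single_tensor_to_hf() -- names assumed.
GROUPED_TO_HF = {
    "lora_gate_and_up_A": "{prefix}.{expert}.up_proj.lora_A.weight",
    "lora_gate_and_up_B": "{prefix}.{expert}.up_proj.lora_B.weight",
}

def expand_grouped_lora_key(key: str, num_experts: int) -> list[str]:
    """Expand e.g. '...experts.lora_gate_and_up_A' into per-expert HF keys;
    pass non-grouped keys through unchanged."""
    prefix, _, leaf = key.rpartition(".")
    template = GROUPED_TO_HF.get(leaf)
    if template is None:          # attention/MLP LoRA or base weight: keep as-is
        return [key]
    return [template.format(prefix=prefix, expert=e) for e in range(num_experts)]

hf_keys = expand_grouped_lora_key(
    "model.layers.79.mixer.experts.lora_gate_and_up_A", num_experts=4
)
```

The inverse direction (`from_hf` on the load path) would regroup per-expert keys back into one 3D tensor by the same mapping, run in reverse.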

Correction from original issue

The original version of this issue reported a `model.*` vs `backbone.*` prefix mismatch. That is not a bug — HuggingFace updated their NemotronH code to use `self.model` (matching AutoModel), with `_get_key_renaming_mapping({"^backbone": "model"})` for backward compatibility. Non-MoE LoRA keys are fully compatible as-is.
