235 changes: 228 additions & 7 deletions mlx_lm/models/nemotron_h.py
@@ -1,6 +1,6 @@
# Copyright © 2025 Apple Inc.

from dataclasses import dataclass
from dataclasses import dataclass, field
from functools import partial
from typing import Any, List, Optional, Tuple

@@ -56,6 +56,9 @@ class ModelArgs(BaseModelArgs):
time_step_limit: Optional[Tuple[float, float]] = None
time_step_min: Optional[float] = None
time_step_max: Optional[float] = None
num_nextn_predict_layers: int = 0
mtp_hybrid_override_pattern: Optional[str] = None
mtp_layers_block_type: Optional[List[str]] = None

# Map from layers_block_type names to single-char pattern codes
_block_type_to_char = {"mamba": "M", "attention": "*", "moe": "E", "mlp": "-"}
@@ -72,6 +75,19 @@ def __post_init__(self):
if self.hybrid_override_pattern is not None:
self.num_hidden_layers = len(self.hybrid_override_pattern)

# Normalize MTP pattern
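# (illustrative example: mtp_layers_block_type=["attention", "moe"]
# maps to _mtp_pattern=["*", "E"] via _block_type_to_char)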
if self.mtp_hybrid_override_pattern is not None:
self._mtp_pattern = list(self.mtp_hybrid_override_pattern)
elif self.mtp_layers_block_type is not None:
self._mtp_pattern = [
self._block_type_to_char[t] for t in self.mtp_layers_block_type
]
else:
self._mtp_pattern = []


class MambaRMSNormGated(nn.Module):
def __init__(self, hidden_size: int, eps: float, group_size: int):
@@ -455,6 +471,105 @@ def __call__(
return x + hidden_states


class NemotronHMTPBlock(nn.Module):
"""A single block in the MTP head. Follows the same pattern as
NemotronHBlock but only supports attention ('*') and MoE ('E') types,
matching the ``mtp_hybrid_override_pattern``."""

def __init__(self, args: ModelArgs, block_type: str):
super().__init__()
self.block_type = block_type
self.norm = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)

if block_type == "*":
self.mixer = NemotronHAttention(args)
elif block_type == "E":
self.mixer = NemotronHMoE(args)
elif block_type == "-":
self.mixer = NemotronHMLP(args)
else:
raise ValueError(f"Unsupported MTP block type: {block_type}")

def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
h = self.norm(x)
if self.block_type == "*":
h = self.mixer(h, mask=mask, cache=cache)
else:
h = self.mixer(h)
return x + h


class NemotronHMTPModule(nn.Module):
"""Multi-Token Prediction head for Nemotron-H.

Predicts token t+2 from the backbone's pre-norm hidden state h_t and the
sampled token t+1, using a shared ``lm_head`` with the backbone.

Architecture (for ``mtp_hybrid_override_pattern = "*E"``):
1. Embed next_token via shared embedding
2. Dual-norm fusion: ``eh_proj(cat(enorm(embed), hnorm(hidden)))``
3. Attention block (``*``) with its own KVCache
4. MoE block (``E``) — same structure as backbone MoE
5. Final layernorm

Weight mapping from HF checkpoints:
``mtp.layers.0.hnorm`` → ``hnorm``
``mtp.layers.0.enorm`` → ``enorm``
``mtp.layers.0.eh_proj`` → ``eh_proj``
``mtp.layers.0.norm`` → ``layers.0.norm``
``mtp.layers.0.mixer.*`` → ``layers.0.mixer.*``
``mtp.layers.1.norm`` → ``layers.1.norm``
``mtp.layers.1.mixer.*`` → ``layers.1.mixer.*``
``mtp.layers.1.final_layernorm`` → ``final_layernorm``
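
Example (illustrative; shapes only, assuming batch size B and hidden size H)::

out = mtp(hidden, next_token, embed_tokens, cache=mtp_cache)
# hidden:     (B, 1, H) pre-norm backbone state at the last position
# next_token: (B, 1)    token sampled from the main logits
# out:        (B, 1, H) normalized MTP hidden state, passed to lm_head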
"""

def __init__(self, args: ModelArgs):
super().__init__()
self.hnorm = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
self.enorm = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
self.eh_proj = nn.Linear(args.hidden_size * 2, args.hidden_size, bias=False)
self.layers = [NemotronHMTPBlock(args, bt) for bt in args._mtp_pattern]
self.final_layernorm = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)

def __call__(
self,
hidden_states: mx.array,
next_token_ids: mx.array,
embed_tokens: nn.Embedding,
cache: Optional[Any] = None,
) -> mx.array:
embeds = embed_tokens(next_token_ids)
e = self.enorm(embeds)
h = self.hnorm(hidden_states)
fused = self.eh_proj(mx.concatenate([e, h], axis=-1))

if cache is None:
cache = [None] * len(self.layers)

# Build the attention mask from the first attention layer's cache.
# Cache entries are stored one per attention layer, so that cache is
# always at index 0 when a cache list is provided.
mask = create_attention_mask(fused, cache[0] if len(cache) > 0 else None)

cache_idx = 0
for layer in self.layers:
if layer.block_type == "*":
fused = layer(fused, mask=mask, cache=cache[cache_idx])
cache_idx += 1
else:
fused = layer(fused)

return self.final_layernorm(fused)


class NemotronHModel(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
@@ -503,7 +618,7 @@ def __call__(
mask = ssm_mask
hidden_states = layer(hidden_states, mask=mask, cache=c)

return self.norm_f(hidden_states)
return hidden_states


class Model(nn.Module):
@@ -513,14 +628,45 @@ def __init__(self, args: ModelArgs):
self.backbone = NemotronHModel(args)
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
self.model_type = args.model_type
if args.num_nextn_predict_layers > 0 and len(args._mtp_pattern) > 0:
self.mtp = NemotronHMTPModule(args)

def __call__(
self,
inputs: mx.array,
cache: Optional[Any] = None,
return_hidden: bool = False,
n_confirmed: int = 0,
):
out = self.backbone(inputs, cache=cache)
return self.lm_head(out)
hidden = self.backbone(inputs, cache=cache)
out = self.lm_head(self.backbone.norm_f(hidden))
if return_hidden:
return out, hidden
return out

def mtp_forward(
self,
hidden_states: mx.array,
next_token_ids: mx.array,
mtp_cache: Any,
) -> mx.array:
"""Run the MTP head and apply the shared lm_head.

Args:
hidden_states: (B, 1, H) — backbone pre-norm hidden at last position.
next_token_ids: (B, 1) — sampled main token.
mtp_cache: list of KVCache entries for MTP attention layers.

Returns:
logits: (B, 1, vocab_size)
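
Example (illustrative sketch of one draft step; ``sample`` is a
hypothetical sampling helper)::

logits, hidden = model(tokens, cache=cache, return_hidden=True)
next_tok = sample(logits[:, -1:, :])       # token t+1, shape (B, 1)
mtp_cache = model.make_mtp_cache()
draft_logits = model.mtp_forward(hidden[:, -1:, :], next_tok, mtp_cache)
draft_tok = sample(draft_logits)           # candidate for token t+2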
"""
mtp_out = self.mtp(
hidden_states,
next_token_ids,
self.backbone.embeddings,
mtp_cache,
)
return self.lm_head(mtp_out)

@property
def layers(self):
@@ -535,13 +681,22 @@ def make_cache(self):
caches.append(KVCache())
return caches

def make_mtp_cache(self):
"""Return a fresh list of KVCache entries for MTP attention layers."""
if hasattr(self, "mtp"):
return [KVCache() for layer in self.mtp.layers if layer.block_type == "*"]
return []

def sanitize(self, weights):
weights = {k: v for (k, v) in weights.items() if not k.startswith("mtp.")}
for k, v in weights.items():
has_mtp = self.args.num_nextn_predict_layers > 0
if not has_mtp:
weights = {k: v for (k, v) in weights.items() if not k.startswith("mtp.")}

for k, v in list(weights.items()):
if "conv1d.weight" in k and v.shape[-1] != 1:
weights[k] = v.moveaxis(2, 1)

# Stack experts
# Stack backbone experts
for l in range(self.args.num_hidden_layers):
prefix = f"backbone.layers.{l}.mixer"
for m, n in [("down_proj", "fc2"), ("up_proj", "fc1")]:
@@ -552,6 +707,72 @@ def sanitize(self, weights):
]
weights[f"{prefix}.switch_mlp.{n}.weight"] = mx.stack(to_join)

if has_mtp:
# Remap MTP weights from HF naming to our module structure.
#
# HF layout (num_nextn_predict_layers=1, pattern "*E"):
# mtp.layers.0.hnorm → fusion norms
# mtp.layers.0.enorm
# mtp.layers.0.eh_proj → fusion projection
# mtp.layers.0.norm → attention block norm
# mtp.layers.0.mixer.* → attention block mixer
# mtp.layers.1.norm → MoE block norm
# mtp.layers.1.mixer.* → MoE block mixer
# mtp.layers.1.final_layernorm → final output norm
#
# Our layout:
# mtp.hnorm, mtp.enorm, mtp.eh_proj
# mtp.layers.0.norm, mtp.layers.0.mixer.* (attention)
# mtp.layers.1.norm, mtp.layers.1.mixer.* (MoE)
# mtp.final_layernorm

remap = {}
mtp_keys = [k for k in weights if k.startswith("mtp.")]
for k in mtp_keys:
v = weights.pop(k)
rest = k[len("mtp.") :]

# Fusion components live on HF layer 0
if rest.startswith("layers.0.hnorm."):
new_k = "mtp." + rest.replace("layers.0.hnorm.", "hnorm.")
elif rest.startswith("layers.0.enorm."):
new_k = "mtp." + rest.replace("layers.0.enorm.", "enorm.")
elif rest.startswith("layers.0.eh_proj."):
new_k = "mtp." + rest.replace("layers.0.eh_proj.", "eh_proj.")
# Attention block: HF layer 0 norm/mixer → our layers.0
elif rest.startswith("layers.0.norm."):
new_k = "mtp.layers.0.norm." + rest[len("layers.0.norm.") :]
elif rest.startswith("layers.0.mixer."):
new_k = "mtp.layers.0.mixer." + rest[len("layers.0.mixer.") :]
# MoE block: HF layer 1 → our layers.1
elif rest.startswith("layers.1.norm."):
new_k = "mtp.layers.1.norm." + rest[len("layers.1.norm.") :]
elif rest.startswith("layers.1.final_layernorm."):
new_k = (
"mtp.final_layernorm."
+ rest[len("layers.1.final_layernorm.") :]
)
elif rest.startswith("layers.1.mixer."):
new_k = "mtp.layers.1.mixer." + rest[len("layers.1.mixer.") :]
else:
new_k = "mtp." + rest

remap[new_k] = v

# Stack MTP MoE experts (same pattern as backbone)
for m, n in [("down_proj", "fc2"), ("up_proj", "fc1")]:
expert_key = f"mtp.layers.1.mixer.experts.0.{m}.weight"
if expert_key in remap:
to_join = [
remap.pop(f"mtp.layers.1.mixer.experts.{e}.{m}.weight")
for e in range(self.args.n_routed_experts)
]
remap[f"mtp.layers.1.mixer.switch_mlp.{n}.weight"] = mx.stack(
to_join
)

weights.update(remap)

return weights

@property