diff --git a/mlx_lm/models/nemotron_h.py b/mlx_lm/models/nemotron_h.py
index 353de36c9..e3cddf2d5 100644
--- a/mlx_lm/models/nemotron_h.py
+++ b/mlx_lm/models/nemotron_h.py
@@ -1,6 +1,6 @@
 # Copyright © 2025 Apple Inc.
 
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from functools import partial
 from typing import Any, List, Optional, Tuple
 
@@ -56,6 +56,9 @@ class ModelArgs(BaseModelArgs):
     time_step_limit: Optional[Tuple[float, float]] = None
     time_step_min: Optional[float] = None
     time_step_max: Optional[float] = None
+    num_nextn_predict_layers: int = 0
+    mtp_hybrid_override_pattern: Optional[str] = None
+    mtp_layers_block_type: Optional[List[str]] = None
 
     # Map from layers_block_type names to single-char pattern codes
     _block_type_to_char = {"mamba": "M", "attention": "*", "moe": "E", "mlp": "-"}
@@ -72,6 +75,16 @@ def __post_init__(self):
         if self.hybrid_override_pattern is not None:
             self.num_hidden_layers = len(self.hybrid_override_pattern)
 
+        # Normalize the MTP pattern; list() handles both str and list inputs
+        if self.mtp_hybrid_override_pattern is not None:
+            self._mtp_pattern = list(self.mtp_hybrid_override_pattern)
+        elif self.mtp_layers_block_type is not None:
+            self._mtp_pattern = [
+                self._block_type_to_char[t] for t in self.mtp_layers_block_type
+            ]
+        else:
+            self._mtp_pattern = []
+
 
 class MambaRMSNormGated(nn.Module):
     def __init__(self, hidden_size: int, eps: float, group_size: int):
@@ -455,6 +468,102 @@ def __call__(
 
         return x + hidden_states
 
+
+class NemotronHMTPBlock(nn.Module):
+    """A single block in the MTP head. Follows the same pattern as
+    NemotronHBlock but supports only the attention ('*'), MoE ('E'), and
+    MLP ('-') types that may appear in ``mtp_hybrid_override_pattern``."""
+
+    def __init__(self, args: ModelArgs, block_type: str):
+        super().__init__()
+        self.block_type = block_type
+        self.norm = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
+
+        if block_type == "*":
+            self.mixer = NemotronHAttention(args)
+        elif block_type == "E":
+            self.mixer = NemotronHMoE(args)
+        elif block_type == "-":
+            self.mixer = NemotronHMLP(args)
+        else:
+            raise ValueError(f"Unsupported MTP block type: {block_type}")
+
+    def __call__(
+        self,
+        x: mx.array,
+        mask: Optional[mx.array] = None,
+        cache: Optional[Any] = None,
+    ) -> mx.array:
+        h = self.norm(x)
+        if self.block_type == "*":
+            h = self.mixer(h, mask=mask, cache=cache)
+        else:
+            h = self.mixer(h)
+        return x + h
+
+
+class NemotronHMTPModule(nn.Module):
+    """Multi-Token Prediction head for Nemotron-H.
+
+    Predicts token t+2 from the backbone's pre-norm hidden state h_t and the
+    sampled token t+1, using a shared ``lm_head`` with the backbone.
+
+    Architecture (for ``mtp_hybrid_override_pattern = "*E"``):
+      1. Embed next_token via the shared embedding
+      2. Dual-norm fusion: ``eh_proj(cat(enorm(embed), hnorm(hidden)))``
+      3. Attention block (``*``) with its own KVCache
+      4. MoE block (``E``) — same structure as the backbone MoE
+      5. Final layernorm
+
+    Weight mapping from HF checkpoints:
+      ``mtp.layers.0.hnorm``            → ``hnorm``
+      ``mtp.layers.0.enorm``            → ``enorm``
+      ``mtp.layers.0.eh_proj``          → ``eh_proj``
+      ``mtp.layers.0.norm``             → ``layers.0.norm``
+      ``mtp.layers.0.mixer.*``          → ``layers.0.mixer.*``
+      ``mtp.layers.1.norm``             → ``layers.1.norm``
+      ``mtp.layers.1.mixer.*``          → ``layers.1.mixer.*``
+      ``mtp.layers.1.final_layernorm``  → ``final_layernorm``
+    """
+
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        self.hnorm = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
+        self.enorm = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
+        self.eh_proj = nn.Linear(args.hidden_size * 2, args.hidden_size, bias=False)
+        self.layers = [NemotronHMTPBlock(args, bt) for bt in args._mtp_pattern]
+        self.final_layernorm = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
+
+    def __call__(
+        self,
+        hidden_states: mx.array,
+        next_token_ids: mx.array,
+        embed_tokens: nn.Embedding,
+        cache: Optional[Any] = None,
+    ) -> mx.array:
+        embeds = embed_tokens(next_token_ids)
+        e = self.enorm(embeds)
+        h = self.hnorm(hidden_states)
+        fused = self.eh_proj(mx.concatenate([e, h], axis=-1))
+
+        if cache is None:
+            cache = [None] * len(self.layers)
+
+        # ``cache`` holds one entry per attention layer, in order, so the
+        # first attention layer always reads ``cache[0]``.
+        has_attention = any(layer.block_type == "*" for layer in self.layers)
+        mask = create_attention_mask(fused, cache[0]) if has_attention else None
+
+        cache_idx = 0
+        for layer in self.layers:
+            if layer.block_type == "*":
+                fused = layer(fused, mask=mask, cache=cache[cache_idx])
+                cache_idx += 1
+            else:
+                fused = layer(fused)
+
+        return self.final_layernorm(fused)
+
+
 class NemotronHModel(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
@@ -503,7 +612,7 @@ def __call__(
                 mask = ssm_mask
             hidden_states = layer(hidden_states, mask=mask, cache=c)
 
-        return self.norm_f(hidden_states)
+        return hidden_states
 
 
 class Model(nn.Module):
@@ -513,14 +622,45 @@ def __init__(self, args: ModelArgs):
         self.backbone = NemotronHModel(args)
         self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
         self.model_type = args.model_type
+        if args.num_nextn_predict_layers > 0 and len(args._mtp_pattern) > 0:
+            self.mtp = NemotronHMTPModule(args)
 
     def __call__(
         self,
         inputs: mx.array,
         cache: Optional[Any] = None,
+        return_hidden: bool = False,
+        n_confirmed: int = 0,
     ):
-        out = self.backbone(inputs, cache=cache)
-        return self.lm_head(out)
+        hidden = self.backbone(inputs, cache=cache)
+        out = self.lm_head(self.backbone.norm_f(hidden))
+        if return_hidden:
+            return out, hidden
+        return out
+
+    def mtp_forward(
+        self,
+        hidden_states: mx.array,
+        next_token_ids: mx.array,
+        mtp_cache: Any,
+    ) -> mx.array:
+        """Run the MTP head and apply the shared lm_head.
+
+        Args:
+            hidden_states: (B, 1, H) — backbone pre-norm hidden at the last position.
+            next_token_ids: (B, 1) — sampled main token.
+            mtp_cache: list of KVCache entries for MTP attention layers.
+
+        Returns:
+            logits: (B, 1, vocab_size)
+        """
+        mtp_out = self.mtp(
+            hidden_states,
+            next_token_ids,
+            self.backbone.embeddings,
+            mtp_cache,
+        )
+        return self.lm_head(mtp_out)
 
     @property
     def layers(self):
@@ -535,13 +675,22 @@ def make_cache(self):
             caches.append(KVCache())
         return caches
 
+    def make_mtp_cache(self):
+        """Return a fresh list of KVCache entries for MTP attention layers."""
+        if hasattr(self, "mtp"):
+            return [KVCache() for layer in self.mtp.layers if layer.block_type == "*"]
+        return []
+
     def sanitize(self, weights):
-        weights = {k: v for (k, v) in weights.items() if not k.startswith("mtp.")}
-        for k, v in weights.items():
+        has_mtp = hasattr(self, "mtp")
+        if not has_mtp:
+            weights = {k: v for (k, v) in weights.items() if not k.startswith("mtp.")}
+
+        for k, v in list(weights.items()):
             if "conv1d.weight" in k and v.shape[-1] != 1:
                 weights[k] = v.moveaxis(2, 1)
 
-        # Stack experts
+        # Stack backbone experts
         for l in range(self.args.num_hidden_layers):
             prefix = f"backbone.layers.{l}.mixer"
             for m, n in [("down_proj", "fc2"), ("up_proj", "fc1")]:
@@ -552,6 +701,72 @@ def sanitize(self, weights):
                     ]
                     weights[f"{prefix}.switch_mlp.{n}.weight"] = mx.stack(to_join)
 
+        if has_mtp:
+            # Remap MTP weights from HF naming to our module structure.
+            #
+            # HF layout (num_nextn_predict_layers=1, pattern "*E"):
+            #   mtp.layers.0.hnorm            → fusion norms
+            #   mtp.layers.0.enorm
+            #   mtp.layers.0.eh_proj          → fusion projection
+            #   mtp.layers.0.norm             → attention block norm
+            #   mtp.layers.0.mixer.*          → attention block mixer
+            #   mtp.layers.1.norm             → MoE block norm
+            #   mtp.layers.1.mixer.*          → MoE block mixer
+            #   mtp.layers.1.final_layernorm  → final output norm
+            #
+            # Our layout:
+            #   mtp.hnorm, mtp.enorm, mtp.eh_proj
+            #   mtp.layers.0.norm, mtp.layers.0.mixer.*  (attention)
+            #   mtp.layers.1.norm, mtp.layers.1.mixer.*  (MoE)
+            #   mtp.final_layernorm
+
+            remap = {}
+            mtp_keys = [k for k in weights if k.startswith("mtp.")]
+            for k in mtp_keys:
+                v = weights.pop(k)
+                rest = k[len("mtp.") :]
+
+                # Fusion components live on HF layer 0
+                if rest.startswith("layers.0.hnorm."):
+                    new_k = "mtp." + rest.replace("layers.0.hnorm.", "hnorm.")
+                elif rest.startswith("layers.0.enorm."):
+                    new_k = "mtp." + rest.replace("layers.0.enorm.", "enorm.")
+                elif rest.startswith("layers.0.eh_proj."):
+                    new_k = "mtp." + rest.replace("layers.0.eh_proj.", "eh_proj.")
+                # Attention block: HF layer 0 norm/mixer → our layers.0
+                elif rest.startswith("layers.0.norm."):
+                    new_k = "mtp.layers.0.norm." + rest[len("layers.0.norm.") :]
+                elif rest.startswith("layers.0.mixer."):
+                    new_k = "mtp.layers.0.mixer." + rest[len("layers.0.mixer.") :]
+                # MoE block: HF layer 1 → our layers.1
+                elif rest.startswith("layers.1.norm."):
+                    new_k = "mtp.layers.1.norm." + rest[len("layers.1.norm.") :]
+                elif rest.startswith("layers.1.final_layernorm."):
+                    new_k = (
+                        "mtp.final_layernorm."
+                        + rest[len("layers.1.final_layernorm.") :]
+                    )
+                elif rest.startswith("layers.1.mixer."):
+                    new_k = "mtp.layers.1.mixer." + rest[len("layers.1.mixer.") :]
+                else:
+                    new_k = "mtp." + rest
+
+                remap[new_k] = v
+
+            # Stack MTP MoE experts (same pattern as the backbone)
+            for m, n in [("down_proj", "fc2"), ("up_proj", "fc1")]:
+                expert_key = f"mtp.layers.1.mixer.experts.0.{m}.weight"
+                if expert_key in remap:
+                    to_join = [
+                        remap.pop(f"mtp.layers.1.mixer.experts.{e}.{m}.weight")
+                        for e in range(self.args.n_routed_experts)
+                    ]
+                    remap[f"mtp.layers.1.mixer.switch_mlp.{n}.weight"] = mx.stack(
+                        to_join
+                    )
+
+            weights.update(remap)
+
         return weights
 
     @property
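
For reviewers, here is how the pieces above compose into one decode step. This is a minimal sketch, not code from this diff: the `mtp_step` helper, the greedy `mx.argmax` sampling, and the `model`/`prompt_ids` inputs are illustrative assumptions (a loaded MTP-enabled `Model` and a `(B, T)` token array).

```python
import mlx.core as mx

def mtp_step(model, inputs, cache, mtp_cache):
    # Hypothetical helper: one backbone step plus one MTP draft step.
    # return_hidden=True also returns the pre-norm hidden states that
    # the MTP head consumes.
    logits, hidden = model(inputs, cache=cache, return_hidden=True)
    token = mx.argmax(logits[:, -1:, :], axis=-1)  # (B, 1) main token t+1
    # Draft token t+2 from the last hidden state and the sampled token.
    draft_logits = model.mtp_forward(hidden[:, -1:, :], token, mtp_cache)
    draft = mx.argmax(draft_logits, axis=-1)  # (B, 1) speculative token t+2
    return token, draft

cache = model.make_cache()          # backbone caches (Mamba + attention)
mtp_cache = model.make_mtp_cache()  # one KVCache per MTP attention layer
token, draft = mtp_step(model, prompt_ids, cache, mtp_cache)
```

A full speculative loop would then verify `draft` against the next backbone step and discard it on mismatch; that acceptance bookkeeping (presumably what the currently unused `n_confirmed` argument is reserved for) is out of scope for this sketch.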