diff --git a/README.md b/README.md index dbe2c363a563..ae37b13e1236 100644 --- a/README.md +++ b/README.md @@ -143,6 +143,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo - [x] [LFM2 models](https://huggingface.co/collections/LiquidAI/lfm2-686d721927015b2ad73eaa38) - [x] [Hunyuan models](https://huggingface.co/collections/tencent/hunyuan-dense-model-6890632cda26b19119c9c5e7) - [x] [BailingMoeV2 (Ring/Ling 2.0) models](https://huggingface.co/collections/inclusionAI/ling-v2-68bf1dd2fc34c306c1fa6f86) +- [x] [Mellum models](https://huggingface.co/JetBrains/models?search=mellum) #### Multimodal diff --git a/conversion/__init__.py b/conversion/__init__.py index 3ceb2d38536a..8415c65f9432 100644 --- a/conversion/__init__.py +++ b/conversion/__init__.py @@ -135,6 +135,7 @@ "Mamba2ForCausalLM": "mamba", "MambaForCausalLM": "mamba", "MambaLMHeadModel": "mamba", + "MellumForCausalLM": "mellum", "MiMoV2FlashForCausalLM": "mimo", "MiMoV2ForCausalLM": "mimo", "MiniCPM3ForCausalLM": "minicpm", diff --git a/conversion/base.py b/conversion/base.py index 729ddbca4ae1..408e209aa884 100644 --- a/conversion/base.py +++ b/conversion/base.py @@ -1657,6 +1657,15 @@ def get_vocab_base_pre(self, tokenizer) -> str: if chkhsh == "36f3066e97b7f3994b379aaacde306c1444c6ae84e81a5ae3cd2b7ed3b8c42d4": # ref: https://huggingface.co/openbmb/MiniCPM5-1B res = "minicpm5" + if chkhsh == "f241072145675bf8322086f115aebad05e9f869557a238bf2150a2a417d1bf60": + # ref: https://huggingface.co/ibm-granite/granite-embedding-97m-multilingual-r2 + res = "granite-embed-multi-97m" + if chkhsh == "789696f5946cc0fc59371f39f6097cafed196b3acded6140432f26bbb1ae1669": + # ref: https://huggingface.co/ibm-granite/granite-embedding-311m-multilingual-r2 + res = "granite-embed-multi-311m" + if chkhsh == "9dcf830ee9990cdbf78cc523a5f7bd9ad8f3f9890c2d3581d2785ad10f07049d": + # ref: https://huggingface.co/JetBrains/Mellum2-12B-A2.5B-Base + res = "mellum2" if res is None: logger.warning("\n") diff --git a/conversion/bert.py b/conversion/bert.py index 9eb320e58aad..49a6948f6ce5 100644 --- a/conversion/bert.py +++ b/conversion/bert.py @@ -603,6 +603,12 @@ def set_gguf_parameters(self): self.gguf_writer.add_sliding_window_pattern(sliding_window_pattern) self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE) self.gguf_writer.add_vocab_size(self.hparams["vocab_size"]) + # FFN activation: ModernBert uses a GLU pair (ffn_up output is 2*n_ff). The + # original ModernBERT uses GELU (-> GeGLU); some derivatives such as IBM + # Granite Embedding 97m R2 use SiLU (-> SwiGLU). Persist this so the + # llama.cpp graph can pick the matching activation. + if hidden_act := self.hparams.get("hidden_activation"): + self.gguf_writer.add_hidden_act(hidden_act) @classmethod def filter_tensors(cls, item: tuple[str, Callable[[], Tensor]]) -> tuple[str, Callable[[], Tensor]] | None: diff --git a/conversion/mellum.py b/conversion/mellum.py new file mode 100644 index 000000000000..79bc6755ccca --- /dev/null +++ b/conversion/mellum.py @@ -0,0 +1,61 @@ +from __future__ import annotations + +from typing import Iterable, TYPE_CHECKING + +import torch + +if TYPE_CHECKING: + from torch import Tensor + +from .base import ModelBase, TextModel, gguf, logger + + +@ModelBase.register("MellumForCausalLM") +class MellumModel(TextModel): + model_arch = gguf.MODEL_ARCH.MELLUM + + def set_gguf_parameters(self): + super().set_gguf_parameters() + if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None: + self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size) + logger.info(f"gguf: expert feed forward length = {moe_intermediate_size}") + + use_sliding_window = self.hparams.get("use_sliding_window") + sliding_window = self.hparams.get("sliding_window") + if (use_sliding_window is True or use_sliding_window is None) and sliding_window is not None: + self.gguf_writer.add_sliding_window(sliding_window) + logger.info(f"gguf: sliding window = {sliding_window}") + self.gguf_writer.add_sliding_window_pattern([t == "sliding_attention" for t in self.hparams["layer_types"]]) + logger.info(f"gguf: sliding window pattern length = {len(self.hparams['layer_types'])}") + + _experts: list[dict[str, Tensor]] | None = None + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + if name.find("experts") != -1: + n_experts = self.find_hparam(["num_local_experts", "num_experts"]) + assert bid is not None + + if self._experts is None: + self._experts = [{} for _ in range(self.block_count)] + + self._experts[bid][name] = data_torch + + if len(self._experts[bid]) >= n_experts * 3: + for w_name in ["down_proj", "gate_proj", "up_proj"]: + datas: list[Tensor] = [] + + for xid in range(n_experts): + ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight" + datas.append(self._experts[bid][ename]) + del self._experts[bid][ename] + + data_torch = torch.stack(datas, dim=0) + + merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight" + + yield from super().modify_tensors(data_torch, merged_name, bid) + return + else: + return + + yield from super().modify_tensors(data_torch, name, bid) diff --git a/conversion/step3.py b/conversion/step3.py index eeba66c7a8a9..8c45b61c954a 100644 --- a/conversion/step3.py +++ b/conversion/step3.py @@ -99,6 +99,34 @@ class Step3VLTextModel(Qwen3Model): class Step35Model(TextModel): model_arch = gguf.MODEL_ARCH.STEP35 + # The --mtp / --no-mtp toggles are ModelBase.mtp_only / no_mtp (set in + # convert_hf_to_gguf.py main()). Unlike Qwen3.5, which stores MTP under a + # `mtp.*` namespace, Step3.5 appends MTP layers at + # `model.layers.{num_hidden_layers + i}`, so we filter them by layer index. + # The trunk layer count is captured before indexing so the classmethod + # filter_tensors can tell the appended MTP block(s) apart from the trunk. + _n_main_layers: int | None = None + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # NextN/MTP layers are appended past num_hidden_layers; extend the + # tensor map to cover them so the MTP block's tensors get correctly + # indexed names. When --no-mtp drops the MTP blocks, fall back to the + # base num_hidden_layers so we don't reserve unused slots. + n_nextn = int(self.hparams.get("num_nextn_predict_layers", 0)) + if n_nextn > 0 and not self.no_mtp: + self.block_count += n_nextn + self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count) + + def index_tensors(self, remote_hf_model_id: str | None = None): + # filter_tensors is a classmethod and can't reach self.hparams; stash + # the trunk layer count here (before indexing runs) so it can detect + # the appended MTP layers by index. + hparams = {**self.hparams, **self.hparams.get("text_config", {})} + key = next((k for k in ["n_layers", "num_hidden_layers", "n_layer", "num_layers"] if k in hparams), None) + type(self)._n_main_layers = hparams.get(key) + return super().index_tensors(remote_hf_model_id=remote_hf_model_id) + def set_gguf_parameters(self): rope_theta = self.hparams.get("rope_theta") if isinstance(rope_theta, list): @@ -119,8 +147,25 @@ def set_gguf_parameters(self): n_head_swa = attn_other.get("num_attention_heads", n_head_base) n_kv_swa = attn_other.get("num_attention_groups", n_kv_base) - layer_types = layer_types[: self.block_count] - partial_rotary_factors = partial_rotary_factors[: self.block_count] + n_nextn = int(self.hparams.get("num_nextn_predict_layers", 0)) + + # The Step3p5 HF checkpoint stores layer_types/partial_rotary_factors + # entries for the MTP blocks past num_hidden_layers; preserve them so + # the MTP layer's attention shape, SWA flag, and partial RoPE dim are + # set correctly. Pad with full-attention defaults if the checkpoint + # truncated them. + def _pad(arr, n, default): + arr = list(arr) + if len(arr) < n: + arr = arr + [default] * (n - len(arr)) + return arr[:n] + + layer_types = _pad(layer_types, self.block_count, "full_attention") + partial_rotary_factors = _pad( + partial_rotary_factors, + self.block_count, + 0.5, # full_attention default for Step3p5 + ) assert [1.0 if lt == "sliding_attention" else 0.5 for lt in layer_types] == partial_rotary_factors head_arr = [n_head_swa if lt == "sliding_attention" else n_head_base for lt in layer_types] kv_arr = [n_kv_swa if lt == "sliding_attention" else n_kv_base for lt in layer_types] @@ -157,31 +202,61 @@ def set_gguf_parameters(self): self.gguf_writer.add_layer_norm_rms_eps(self.hparams.get("rms_norm_eps", 1e-5)) - # Optional per-layer SwiGLU clamps. + # Optional per-layer SwiGLU clamps. MTP layers default to no clamping (0.0). if (limits := self.hparams.get("swiglu_limits")) is not None: - limits_f = [0.0 if v is None else float(v) for v in limits[: self.block_count]] + limits_f = _pad( + [0.0 if v is None else float(v) for v in limits], + self.block_count, + 0.0, + ) self.gguf_writer.add_swiglu_clamp_exp(limits_f) if (limits_shared := self.hparams.get("swiglu_limits_shared")) is not None: - limits_shared_f = [0.0 if v is None else float(v) for v in limits_shared[: self.block_count]] + limits_shared_f = _pad( + [0.0 if v is None else float(v) for v in limits_shared], + self.block_count, + 0.0, + ) self.gguf_writer.add_swiglu_clamp_shexp(limits_shared_f) + if n_nextn > 0 and not self.no_mtp: + self.gguf_writer.add_nextn_predict_layers(n_nextn) + @classmethod def filter_tensors(cls, item: tuple[str, Callable[[], Tensor]]) -> tuple[str, Callable[[], Tensor]] | None: - name, gen = item + if (titem := super().filter_tensors(item)) is None: + return None + name, gen = titem # Map router bias (expert selection bias) to a GGUF bias tensor if name.endswith(".moe.router_bias"): name += ".bias" - return super().filter_tensors((name, gen)) + # Step3.5 appends the MTP block(s) past num_hidden_layers. + assert cls._n_main_layers is not None + is_mtp = (m := re.match(r"model\.layers\.(\d+)\.", name)) is not None and int(m.group(1)) >= cls._n_main_layers + + # --no-mtp: drop the appended MTP block(s) entirely. + if is_mtp and cls.no_mtp: + return None + # --mtp: keep ONLY MTP-block tensors plus the shared embeddings/norm/ + # lm_head (so the resulting GGUF carries just the draft head). + if cls.mtp_only and not is_mtp and name not in ( + "model.embed_tokens.weight", "model.norm.weight", "lm_head.weight", + ): + return None + + # The checkpoint nests the per-MTP-layer shared head under + # `model.layers.{N+i}.transformer.shared_head.{norm,output}.weight`; + # strip the `transformer.` infix and rename `output` → `head` so the + # existing NEXTN_SHARED_HEAD_{NORM,HEAD} tensor mapping picks them up. + # Mirrors vllm's `_rewrite_spec_layer_name` (step3p5_mtp.py). + if is_mtp: + name = name.replace(".transformer.", ".") + name = name.replace("shared_head.output", "shared_head.head") + + return name, gen def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None): - # remove mtp layers - if (m := re.match(r"model\.layers\.(\d+)\.", name)) is not None: - il = int(m.group(1)) - n_main = int(self.hparams.get("num_hidden_layers", self.block_count)) - if il >= n_main: - return if name.endswith("norm.weight"): data_torch += 1.0 @@ -190,6 +265,21 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None): yield from super().modify_tensors(data_torch, name, bid) + def prepare_metadata(self, vocab_only: bool): + from_dir = self.fname_out.is_dir() + super().prepare_metadata(vocab_only=vocab_only) + + # Mirror Qwen3.5's behavior: when emitting a draft-only file into a + # directory, prefix with "mtp-" so it doesn't collide with the trunk. + if not self.mtp_only or not from_dir: + return + + output_type: str = self.ftype.name.partition("_")[2] + fname_default: str = gguf.naming_convention( + self.metadata.name, self.metadata.basename, self.metadata.finetune, + self.metadata.version, size_label=None, output_type=output_type, model_type=None) + self.fname_out = self.fname_out.parent / f"mtp-{fname_default}.gguf" + def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: # Step35 can optionally use Llama-3 style RoPE scaling (HF: rope_scaling.rope_type == "llama3"). # llama.cpp represents this via a single extra tensor: "rope_freqs.weight" (aka MODEL_TENSOR.ROPE_FREQS). diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 85527553563d..cd19eebdfa34 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -251,8 +251,9 @@ def main() -> None: if args.mtp or args.no_mtp: from conversion.qwen import _Qwen35MtpMixin - if not issubclass(model_class, _Qwen35MtpMixin): - logger.error("--mtp / --no-mtp are only supported for Qwen3.5/3.6 text variants today") + from conversion.step3 import Step35Model + if not (issubclass(model_class, _Qwen35MtpMixin) or issubclass(model_class, Step35Model)): + logger.error("--mtp / --no-mtp are only supported for Qwen3.5/3.6 and Step3.5 text variants today") sys.exit(1) if args.no_mtp: model_class.no_mtp = True diff --git a/convert_hf_to_gguf_update.py b/convert_hf_to_gguf_update.py index 827af277b929..b4c8a7cf00a3 100755 --- a/convert_hf_to_gguf_update.py +++ b/convert_hf_to_gguf_update.py @@ -158,6 +158,9 @@ class TOKENIZER_TYPE(IntEnum): {"name": "sarvam-moe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sarvamai/sarvam-30b", }, {"name": "talkie", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/lewtun/talkie-1930-13b-it-hf", }, {"name": "minicpm5", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/openbmb/MiniCPM5-1B"}, + {"name": "granite-embed-multi-97m", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ibm-granite/granite-embedding-97m-multilingual-r2", }, + {"name": "granite-embed-multi-311m", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ibm-granite/granite-embedding-311m-multilingual-r2", }, + {"name": "mellum2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/JetBrains/Mellum2-12B-A2.5B-Base"}, ] # some models are known to be broken upstream, so we will skip them as exceptions diff --git a/ggml/src/ggml-hexagon/htp-opnode.h b/ggml/src/ggml-hexagon/htp-opnode.h index 8a1228ccdc0b..52c727c6206b 100644 --- a/ggml/src/ggml-hexagon/htp-opnode.h +++ b/ggml/src/ggml-hexagon/htp-opnode.h @@ -56,6 +56,20 @@ struct htp_opnode { } std::vector get_inputs() const { + if (fused.empty()) { + int last_non_null = -1; + for (int i = 0; i < GGML_MAX_SRC; i++) { + if (node->src[i]) { + last_non_null = i; + } + } + std::vector inputs(last_non_null + 1, nullptr); + for (int i = 0; i <= last_non_null; i++) { + inputs[i] = node->src[i]; + } + return inputs; + } + std::vector inputs(GGML_MAX_SRC, nullptr); std::vector outputs; outputs.push_back(node); @@ -82,12 +96,8 @@ struct htp_opnode { }; for (int i = 0; i < GGML_MAX_SRC; i++) { - if (fused.empty()) { - inputs[i] = node->src[i]; - } else { - if (node->src[i]) { - add_input(node->src[i]); - } + if (node->src[i]) { + add_input(node->src[i]); } } for (const auto * f : fused) { @@ -98,10 +108,7 @@ struct htp_opnode { } } - if (!fused.empty()) { - inputs.resize(count); - } - + inputs.resize(count); return inputs; } diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index b67ea46bce8e..c411e4aeaec4 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -4950,6 +4950,21 @@ inline bool enable_adreno_trans_weight(const ggml_backend_opencl_context *backen return ((elem_num < 128 * 1024 * 1024) && adreno_kernel); // max element num: 2**27 } +static inline bool use_flat_gemv_for_large_m_q4_K(const ggml_tensor *tensor) { + // gemv_noshuffle variant perf drops for large M, use flat variant for large M. + // threshold is well above typical hidden/FFN dims, but below typical vocab sizes. + // note that this forces large M weights to use LM GEMM. + return tensor->ne[1] >= 32768 && tensor->ne[2] == 1 && tensor->ne[3] == 1; +} + +static inline bool use_flat_gemv_for_large_m_q6_K(const ggml_tensor *tensor) { + // gemv_noshuffle variant perf drops for large M, use flat variant for large M. + // threshold is well above typical hidden/FFN dims, but below typical vocab sizes. + // q6_K flat gemv is worse for smaller K; 2048 seems to be a reasonable threshold. + // note that this forces large M weights to use LM GEMM. + return tensor->ne[1] >= 32768 && tensor->ne[0] >= 2048 && tensor->ne[2] == 1 && tensor->ne[3] == 1; +} + static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) { ggml_backend_opencl_device_context * dev_ctx = (ggml_backend_opencl_device_context *)dev->context; ggml_backend_opencl_context * backend_ctx = dev_ctx->backend_ctx; @@ -6595,7 +6610,7 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, #ifdef GGML_OPENCL_USE_ADRENO_KERNELS cl_kernel kernel = backend_ctx->kernel_convert_block_q4_K; - if (use_adreno_kernels(backend_ctx, tensor)) { + if (use_adreno_kernels(backend_ctx, tensor) && !use_flat_gemv_for_large_m_q4_K(tensor)) { kernel = backend_ctx->kernel_convert_block_q4_K_noshuffle; } #else @@ -6623,7 +6638,7 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, tensor->extra = extra; #ifdef GGML_OPENCL_USE_ADRENO_KERNELS - if (use_adreno_kernels(backend_ctx, tensor)) { + if (use_adreno_kernels(backend_ctx, tensor) && !use_flat_gemv_for_large_m_q4_K(tensor)) { int M = tensor->ne[1]; int K = tensor->ne[0]; @@ -6923,7 +6938,7 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, cl_kernel kernel; #ifdef GGML_OPENCL_USE_ADRENO_KERNELS kernel = backend_ctx->kernel_convert_block_q6_K; - if (use_adreno_kernels(backend_ctx, tensor)) { + if (use_adreno_kernels(backend_ctx, tensor) && !use_flat_gemv_for_large_m_q6_K(tensor)) { kernel = backend_ctx->kernel_convert_block_q6_K_noshuffle; } #else @@ -6956,7 +6971,7 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, tensor->extra = extra; #ifdef GGML_OPENCL_USE_ADRENO_KERNELS - if (use_adreno_kernels(backend_ctx, tensor)) { + if (use_adreno_kernels(backend_ctx, tensor) && !use_flat_gemv_for_large_m_q6_K(tensor)) { cl_int M = tensor->ne[1]; // ne01 cl_int K = tensor->ne[0]; // ne00 @@ -7599,7 +7614,7 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, CL_CHECK(clReleaseMemObject(data_device)); return; } - if (use_adreno_kernels(backend_ctx, tensor)) { + if (use_adreno_kernels(backend_ctx, tensor) && !use_flat_gemv_for_large_m_q4_K(tensor)) { int M = tensor->ne[1]; int K = tensor->ne[0]; @@ -7820,7 +7835,7 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, CL_CHECK(clReleaseMemObject(data_device)); return; } - if (use_adreno_kernels(backend_ctx, tensor)) { + if (use_adreno_kernels(backend_ctx, tensor) && !use_flat_gemv_for_large_m_q6_K(tensor)) { static ggml_cl_buffer buf_trans_ql; static ggml_cl_buffer buf_trans_qh; static ggml_cl_buffer buf_trans_s; @@ -13213,13 +13228,13 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co } // q4_k x fp32 - if (src0t == GGML_TYPE_Q4_K && src1t == GGML_TYPE_F32) { + if (src0t == GGML_TYPE_Q4_K && src1t == GGML_TYPE_F32 && !use_flat_gemv_for_large_m_q4_K(src0)) { ggml_cl_mul_mat_q4_k_f32_adreno(backend, src0, src1, dst); return; } // q6_K x fp32 - if (src0t == GGML_TYPE_Q6_K && src1t == GGML_TYPE_F32) { + if (src0t == GGML_TYPE_Q6_K && src1t == GGML_TYPE_F32 && !use_flat_gemv_for_large_m_q6_K(src0)) { ggml_cl_mul_mat_q6_K_f32_adreno(backend, src0, src1, dst); return; } diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 8aed0d76671a..207cc2a1933f 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -150,6 +150,7 @@ class LLM: EMBD_LENGTH_PER_LAYER_INP = "{arch}.embedding_length_per_layer_input" SWIGLU_CLAMP_EXP = "{arch}.swiglu_clamp_exp" SWIGLU_CLAMP_SHEXP = "{arch}.swiglu_clamp_shexp" + HIDDEN_ACT = "{arch}.hidden_activation" DENSE_FEAT_IN_SIZE = "{arch}.{dense}_feat_in" DENSE_FEAT_OUT_SIZE = "{arch}.{dense}_feat_out" @@ -509,6 +510,7 @@ class MODEL_ARCH(IntEnum): MAINCODER = auto() KIMI_LINEAR = auto() TALKIE = auto() + MELLUM = auto() class VISION_PROJECTOR_TYPE(IntEnum): @@ -1029,6 +1031,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.MAINCODER: "maincoder", MODEL_ARCH.KIMI_LINEAR: "kimi-linear", MODEL_ARCH.TALKIE: "talkie", + MODEL_ARCH.MELLUM: "mellum", } VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = { @@ -3994,6 +3997,13 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_GATE_SHEXP, MODEL_TENSOR.FFN_DOWN_SHEXP, MODEL_TENSOR.FFN_EXP_PROBS_B, + # NextN/MTP tensors (Step3p5 draft head) + MODEL_TENSOR.NEXTN_EH_PROJ, + MODEL_TENSOR.NEXTN_EMBED_TOKENS, + MODEL_TENSOR.NEXTN_ENORM, + MODEL_TENSOR.NEXTN_HNORM, + MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD, + MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM, ], MODEL_ARCH.LLAMA_EMBED: [ MODEL_TENSOR.TOKEN_EMBD, @@ -4085,6 +4095,23 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_UP, MODEL_TENSOR.LAYER_OUT_SCALE, ], + MODEL_ARCH.MELLUM: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + ], # TODO } diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index e94b47badb41..63cf6debcc91 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -853,6 +853,9 @@ def add_swiglu_clamp_exp(self, values: Sequence[float]) -> None: def add_swiglu_clamp_shexp(self, values: Sequence[float]) -> None: self.add_array(Keys.LLM.SWIGLU_CLAMP_SHEXP.format(arch=self.arch), values) + def add_hidden_act(self, value: str) -> None: + self.add_string(Keys.LLM.HIDDEN_ACT.format(arch=self.arch), value) + def add_expert_group_scale(self, value: float) -> None: self.add_float32(Keys.LLM.EXPERT_GROUP_SCALE.format(arch=self.arch), value) diff --git a/pyproject.toml b/pyproject.toml index e4f8c86b9516..46cf68ca1a39 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,7 @@ requires-python = '>=3.10,<3.15' dependencies = [ 'numpy (>=1.26.4,<3.0.0)', 'sentencepiece (>=0.1.98,<0.3.0)', - 'transformers (==5.5.1)', + 'transformers (==4.57.6)', 'protobuf (>=4.21.0,<5.0.0)', 'torch (>=2.6.0,<3.0.0)', 'gguf @ ./gguf-py', diff --git a/requirements/requirements-convert_legacy_llama.txt b/requirements/requirements-convert_legacy_llama.txt index 18d39801066c..28221fad0ce9 100644 --- a/requirements/requirements-convert_legacy_llama.txt +++ b/requirements/requirements-convert_legacy_llama.txt @@ -1,7 +1,7 @@ numpy~=1.26.4 sentencepiece>=0.1.98,<0.3.0 -transformers==5.5.1 +transformers==4.57.6 gguf>=0.1.0 protobuf>=4.21.0,<5.0.0 diff --git a/requirements/requirements-tool_bench.txt b/requirements/requirements-tool_bench.txt index 17d6b866c6b8..3e6f824165c4 100644 --- a/requirements/requirements-tool_bench.txt +++ b/requirements/requirements-tool_bench.txt @@ -1,6 +1,5 @@ aiohttp~=3.9.3 pytest~=8.3.3 -huggingface_hub>=1.5.0,<2.0 matplotlib~=3.10.0 numpy~=1.26.4 openai~=2.14.0 diff --git a/scripts/snapdragon/ggml-hexagon-profile.py b/scripts/snapdragon/ggml-hexagon-profile.py index aa1f20dcc23e..fe94eb6c1903 100755 --- a/scripts/snapdragon/ggml-hexagon-profile.py +++ b/scripts/snapdragon/ggml-hexagon-profile.py @@ -11,6 +11,7 @@ # Mapping of cli-friendly names to (internal_data_key, Display Header, numeric_sort_key) COL_MAP = { + "tot-usec": ("tot_usec", "Tot usec", "_sort_tot_usec"), "op": ("op", "Op", "op"), "dims": ("dims", "Dims", "dims"), "dtypes": ("dtypes", "DTypes", "dtypes"), @@ -24,7 +25,7 @@ } op_pattern = re.compile( - r"profile-op\s+(?P[A-Z_0-9+]+):\s+.*?\s+:\s+(?P[\d:x\s\->!]+)\s+:\s+(?P[a-z\d_\s\->x]+)\s+:\s+.*?\s+usec\s+(?P\d+)\s+cycles\s+(?P\d+)(?:\s+pmu\s+\[(?P[\d,\s]+)\])?" + r"profile-op\s+(?P[A-Z_0-9+]+):\s+.*?\s+:\s+(?P[\d:x\s\->!]+)\s+:\s+(?P[a-z\d_\s\->x]+)\s+:\s+.*?\s+(?:op-)?usec\s+(?P\d+)\s+(?:op-)?cycles\s+(?P\d+)(?:\s+pmu\s+\[(?P[\d,\s]+)\])?" ) logger = logging.getLogger("ggml-hexagon-profile") @@ -85,21 +86,27 @@ def generate_report(ops, top_n, width_overrides, sort_col, pmu_name=None): cycles = [o['cycles'] for o in group_ops] pmu_vals = [o['pmu_val'] for o in group_ops if o['pmu_val'] is not None] + avg_usec_val = statistics.mean(usecs) + count_val = len(group_ops) + tot_usec_val = avg_usec_val * count_val + group_stats.append({ 'op': name, 'dims': dims, 'dtypes': types, - 'count': str(len(group_ops)), + 'count': str(count_val), 'max_usec': str(max(usecs)), - 'avg_usec': f"{statistics.mean(usecs):.2f}", + 'avg_usec': f"{avg_usec_val:.2f}", + 'tot_usec': f"{tot_usec_val:.2f}", 'max_cycles': str(max(cycles)), 'avg_cycles': f"{statistics.mean(cycles):.2f}", 'max_pmu': str(max(pmu_vals)) if pmu_vals else "0", 'avg_pmu': f"{statistics.mean(pmu_vals):.2f}" if pmu_vals else "0.00", # Numeric values for accurate sorting - '_sort_count': len(group_ops), + '_sort_count': count_val, '_sort_max_usec': max(usecs), - '_sort_avg_usec': statistics.mean(usecs), + '_sort_avg_usec': avg_usec_val, + '_sort_tot_usec': tot_usec_val, '_sort_max_cycles': max(cycles), '_sort_avg_cycles': statistics.mean(cycles), '_sort_max_pmu': max(pmu_vals) if pmu_vals else 0, @@ -116,7 +123,7 @@ def generate_report(ops, top_n, width_overrides, sort_col, pmu_name=None): active_cols = ["op", "dims", "dtypes"] if pmu_name: active_cols += ["max-pmu", "avg-pmu"] - active_cols += ["max-usec", "avg-usec", "max-cycles", "avg-cycles", "count"] + active_cols += ["tot-usec", "avg-usec", "avg-cycles", "max-usec", "max-cycles", "count"] final_headers, final_keys, final_widths = [], [], [] @@ -156,7 +163,7 @@ def main(): parser = argparse.ArgumentParser(description="Post-process Op profile info.") parser.add_argument("logfile") parser.add_argument("-n", "--top", type=int, default=100) - parser.add_argument("--sort", type=str, default="max-usec", choices=list(COL_MAP.keys())) + parser.add_argument("--sort", type=str, default="tot-usec", choices=list(COL_MAP.keys())) parser.add_argument("--pmu-index", type=int) parser.add_argument("--pmu-name", type=str) parser.add_argument("--width", action='append', default=['dims:40'], help="Override column width, e.g. --width dims:50") diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index be8f73cc1edd..8f462396f4a7 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -135,6 +135,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_MAINCODER, "maincoder" }, { LLM_ARCH_KIMI_LINEAR, "kimi-linear" }, { LLM_ARCH_TALKIE, "talkie" }, + { LLM_ARCH_MELLUM, "mellum" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -195,6 +196,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_MOE_LATENT_SIZE, "%s.moe_latent_size" }, { LLM_KV_NEXTN_PREDICT_LAYERS, "%s.nextn_predict_layers" }, { LLM_KV_NUM_DEEPSTACK_LAYERS, "%s.n_deepstack_layers" }, + { LLM_KV_HIDDEN_ACT, "%s.hidden_activation" }, { LLM_KV_POOLING_TYPE, "%s.pooling_type" }, { LLM_KV_LOGIT_SCALE, "%s.logit_scale" }, { LLM_KV_DECODER_START_TOKEN_ID, "%s.decoder_start_token_id" }, diff --git a/src/llama-arch.h b/src/llama-arch.h index 2c71bbe81562..b47c05d90d5b 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -139,6 +139,7 @@ enum llm_arch { LLM_ARCH_MAINCODER, LLM_ARCH_KIMI_LINEAR, LLM_ARCH_TALKIE, + LLM_ARCH_MELLUM, LLM_ARCH_UNKNOWN, }; @@ -199,6 +200,7 @@ enum llm_kv { LLM_KV_MOE_LATENT_SIZE, LLM_KV_NEXTN_PREDICT_LAYERS, LLM_KV_NUM_DEEPSTACK_LAYERS, + LLM_KV_HIDDEN_ACT, LLM_KV_POOLING_TYPE, LLM_KV_LOGIT_SCALE, LLM_KV_DECODER_START_TOKEN_ID, diff --git a/src/llama-graph.h b/src/llama-graph.h index eab82bd0d706..f2b952b2c3f8 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -36,7 +36,8 @@ enum llm_graph_type { LLM_GRAPH_TYPE_DECODER_MTP, }; -enum llm_ffn_op_type { +enum llm_ffn_op_type : int { + LLM_FFN_NONE = 0, // sentinel: unset; archs must assign before use LLM_FFN_SILU, LLM_FFN_GELU, LLM_FFN_RELU, diff --git a/src/llama-hparams.h b/src/llama-hparams.h index e2d051edc6cd..e4601d30f510 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -23,6 +23,9 @@ enum llama_swa_type { LLAMA_SWA_TYPE_SYMMETRIC = 3, }; +// forward declaration; full definition in llama-graph.h +enum llm_ffn_op_type : int; + struct llama_hparams_posnet { uint32_t n_embd; uint32_t n_layer; @@ -227,6 +230,14 @@ struct llama_hparams { enum llama_rope_scaling_type rope_scaling_type_train = LLAMA_ROPE_SCALING_TYPE_NONE; + // Resolved FFN gated activation flavor for archs that read + // `.hidden_activation` from the GGUF (e.g. ModernBert derivatives). + // Defaults to LLM_FFN_NONE (sentinel = 0); the mapping from the GGUF + // string to a real op is done at hparam-load time via + // llm_ffn_op_type_from_string() in llama-model.cpp, mirroring how + // rope_scaling_type_train is handled. + enum llm_ffn_op_type llm_ffn_op; + // Step35: optional per-layer clamps for (Swi)GLU std::array swiglu_clamp_exp; // clamping for expert FFN std::array swiglu_clamp_shexp; // shared expert diff --git a/src/llama-model-saver.cpp b/src/llama-model-saver.cpp index 528e4c9c069f..539d17eebc60 100644 --- a/src/llama-model-saver.cpp +++ b/src/llama-model-saver.cpp @@ -29,6 +29,7 @@ bool llama_model_saver_supports_arch(llm_arch arch) { case LLM_ARCH_APERTUS: case LLM_ARCH_MIMO2: case LLM_ARCH_STEP35: + case LLM_ARCH_MELLUM: return false; default: return true; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 3e236f8c17d2..bd5635ed4561 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -81,6 +81,8 @@ static llama_model * llama_model_mapping(llm_arch arch, const llama_model_params return new llama_model_mpt(params); case LLM_ARCH_STABLELM: return new llama_model_stablelm(params); + case LLM_ARCH_MELLUM: + return new llama_model_mellum(params); case LLM_ARCH_QWEN: return new llama_model_qwen(params); case LLM_ARCH_QWEN2: @@ -764,6 +766,7 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_A13B: return "A13B"; case LLM_TYPE_7B_A1B: return "7B.A1B"; case LLM_TYPE_8B_A1B: return "8B.A1B"; + case LLM_TYPE_12B_A2_5B: return "12B.A2.5B"; case LLM_TYPE_16B_A1B: return "16B.A1B"; case LLM_TYPE_21B_A3B: return "21B.A3B"; case LLM_TYPE_24B_A2B: return "24B.A2B"; @@ -822,6 +825,28 @@ static llama_rope_scaling_type llama_rope_scaling_type_from_string(const std::st return LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED; } +// Maps the GGUF `.hidden_activation` string to the FFN op type used by the +// graph builders. Only gated activations that map cleanly to llm_ffn_op_type are +// listed; unrecognized values fall back to GeGLU, which matches the historical +// default for ModernBert-style architectures. +static const std::map LLM_FFN_OP_TYPES_FROM_STRING = { + { "gelu", LLM_FFN_GEGLU }, + { "geglu", LLM_FFN_GEGLU }, + { "silu", LLM_FFN_SWIGLU }, + { "swish", LLM_FFN_SWIGLU }, + { "swiglu", LLM_FFN_SWIGLU }, + { "relu", LLM_FFN_RELU }, + { "reglu", LLM_FFN_REGLU }, +}; + +llm_ffn_op_type llm_ffn_op_type_from_string(const std::string & name, llm_ffn_op_type fallback) { + const auto it = LLM_FFN_OP_TYPES_FROM_STRING.find(name); + if (it != LLM_FFN_OP_TYPES_FROM_STRING.end()) { + return it->second; + } + return fallback; +} + // CPU: ACCEL -> GPU host -> CPU extra -> CPU static buft_list_t make_cpu_buft_list(const std::vector & devices, bool use_extra_bufts, bool no_host) { buft_list_t buft_list; @@ -1794,7 +1819,11 @@ void llama_model::print_info() const { LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); } - if (arch == LLM_ARCH_QWEN3MOE || arch == LLM_ARCH_OPENAI_MOE || arch == LLM_ARCH_QWEN3VLMOE || arch == LLM_ARCH_RND1) { + if (arch == LLM_ARCH_MELLUM || + arch == LLM_ARCH_QWEN3MOE || + arch == LLM_ARCH_OPENAI_MOE || + arch == LLM_ARCH_QWEN3VLMOE || + arch == LLM_ARCH_RND1) { LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); } @@ -2382,6 +2411,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_MIMO2: case LLM_ARCH_STEP35: case LLM_ARCH_TALKIE: + case LLM_ARCH_MELLUM: return LLAMA_ROPE_TYPE_NEOX; case LLM_ARCH_QWEN2VL: diff --git a/src/llama-model.h b/src/llama-model.h index 743feb970d99..a561374ed956 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -116,6 +116,7 @@ enum llm_type { LLM_TYPE_A13B, LLM_TYPE_7B_A1B, LLM_TYPE_8B_A1B, // lfm2moe + LLM_TYPE_12B_A2_5B, LLM_TYPE_16B_A1B, LLM_TYPE_21B_A3B, // Ernie MoE small LLM_TYPE_24B_A2B, // lfm2moe @@ -145,6 +146,10 @@ enum llm_type { std::string llama_rope_scaling_type_name(llama_rope_scaling_type rope_scaling_type); +// Map a GGUF activation-name string to llm_ffn_op_type. Returns `fallback` if +// the string is empty or not recognized. +llm_ffn_op_type llm_ffn_op_type_from_string(const std::string & name, llm_ffn_op_type fallback); + struct llama_layer_posnet { // resnet struct ggml_tensor * norm1 = nullptr; diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index 04183efc4d0c..520502398162 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -353,6 +353,7 @@ struct llm_tokenizer_bpe : llm_tokenizer { case LLAMA_VOCAB_PRE_TYPE_CODESHELL: case LLAMA_VOCAB_PRE_TYPE_EXAONE: case LLAMA_VOCAB_PRE_TYPE_MINERVA: + case LLAMA_VOCAB_PRE_TYPE_MELLUM2: regex_exprs = { "\\p{N}", "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)", @@ -432,6 +433,15 @@ struct llm_tokenizer_bpe : llm_tokenizer { "[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))*((?=[\\p{L}])([^A-Z]))+(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))+((?=[\\p{L}])([^A-Z]))*(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", }; break; + case LLAMA_VOCAB_PRE_TYPE_GRANITE_EMB_MULTI: + // Same lookaheads as GPT4O but with \p{M} added so combining marks + // (diacritics) attach to their base letters. Avoids excessive + // backtracking on scripts that use them heavily (Bengali, Hindi, + // Telugu, Thai, ...). See PR #22716 for benchmarks. + regex_exprs = { + "[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}\\p{M}])([^a-z]))*((?=[\\p{L}\\p{M}])([^A-Z]))+(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}\\p{M}])([^a-z]))+((?=[\\p{L}\\p{M}])([^A-Z]))*(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + }; + break; case LLAMA_VOCAB_PRE_TYPE_TINY_AYA: regex_exprs = { // original regex from tokenizer.json: "\\d{1,3}(?=(?:\\d{3})*\\b)" @@ -2142,7 +2152,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "jais-2") { pre_type = LLAMA_VOCAB_PRE_TYPE_JAIS2; } else if ( - tokenizer_pre == "gemma4") { + tokenizer_pre == "gemma4" || + tokenizer_pre == "granite-embed-multi-311m") { pre_type = LLAMA_VOCAB_PRE_TYPE_GEMMA4; escape_whitespaces = true; } else if ( @@ -2252,6 +2263,11 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "talkie") { pre_type = LLAMA_VOCAB_PRE_TYPE_GPT4O; clean_spaces = false; + } else if ( + tokenizer_pre == "granite-embed-multi-97m") { + pre_type = LLAMA_VOCAB_PRE_TYPE_GRANITE_EMB_MULTI; + clean_spaces = false; + ignore_merges = true; } else if ( tokenizer_pre == "tiny_aya") { pre_type = LLAMA_VOCAB_PRE_TYPE_TINY_AYA; @@ -2310,6 +2326,9 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "solar-open") { pre_type = LLAMA_VOCAB_PRE_TYPE_SOLAR_OPEN; clean_spaces = false; + } else if ( + tokenizer_pre == "mellum2") { + pre_type = LLAMA_VOCAB_PRE_TYPE_MELLUM2; } else { throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str())); } diff --git a/src/llama-vocab.h b/src/llama-vocab.h index 093e5d02cdaf..b3991b53228c 100644 --- a/src/llama-vocab.h +++ b/src/llama-vocab.h @@ -8,60 +8,62 @@ // pre-tokenization types enum llama_vocab_pre_type { - LLAMA_VOCAB_PRE_TYPE_DEFAULT = 0, - LLAMA_VOCAB_PRE_TYPE_LLAMA3 = 1, - LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM = 2, - LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER = 3, - LLAMA_VOCAB_PRE_TYPE_FALCON = 4, - LLAMA_VOCAB_PRE_TYPE_MPT = 5, - LLAMA_VOCAB_PRE_TYPE_STARCODER = 6, - LLAMA_VOCAB_PRE_TYPE_GPT2 = 7, - LLAMA_VOCAB_PRE_TYPE_REFACT = 8, - LLAMA_VOCAB_PRE_TYPE_COMMAND_R = 9, - LLAMA_VOCAB_PRE_TYPE_STABLELM2 = 10, - LLAMA_VOCAB_PRE_TYPE_QWEN2 = 11, - LLAMA_VOCAB_PRE_TYPE_OLMO = 12, - LLAMA_VOCAB_PRE_TYPE_DBRX = 13, - LLAMA_VOCAB_PRE_TYPE_SMAUG = 14, - LLAMA_VOCAB_PRE_TYPE_PORO = 15, - LLAMA_VOCAB_PRE_TYPE_CHATGLM3 = 16, - LLAMA_VOCAB_PRE_TYPE_CHATGLM4 = 17, - LLAMA_VOCAB_PRE_TYPE_VIKING = 18, - LLAMA_VOCAB_PRE_TYPE_JAIS = 19, - LLAMA_VOCAB_PRE_TYPE_TEKKEN = 20, - LLAMA_VOCAB_PRE_TYPE_SMOLLM = 21, - LLAMA_VOCAB_PRE_TYPE_CODESHELL = 22, - LLAMA_VOCAB_PRE_TYPE_BLOOM = 23, - LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH = 24, - LLAMA_VOCAB_PRE_TYPE_EXAONE = 25, - LLAMA_VOCAB_PRE_TYPE_CHAMELEON = 26, - LLAMA_VOCAB_PRE_TYPE_MINERVA = 27, - LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 28, - LLAMA_VOCAB_PRE_TYPE_GPT4O = 29, - LLAMA_VOCAB_PRE_TYPE_SUPERBPE = 30, - LLAMA_VOCAB_PRE_TYPE_TRILLION = 31, - LLAMA_VOCAB_PRE_TYPE_BAILINGMOE = 32, - LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33, - LLAMA_VOCAB_PRE_TYPE_PIXTRAL = 34, - LLAMA_VOCAB_PRE_TYPE_SEED_CODER = 35, - LLAMA_VOCAB_PRE_TYPE_HUNYUAN = 36, - LLAMA_VOCAB_PRE_TYPE_KIMI_K2 = 37, - LLAMA_VOCAB_PRE_TYPE_HUNYUAN_DENSE = 38, - LLAMA_VOCAB_PRE_TYPE_GROK_2 = 39, - LLAMA_VOCAB_PRE_TYPE_GRANITE_DOCLING = 40, - LLAMA_VOCAB_PRE_TYPE_MINIMAX_M2 = 41, - LLAMA_VOCAB_PRE_TYPE_AFMOE = 42, - LLAMA_VOCAB_PRE_TYPE_SOLAR_OPEN = 43, - LLAMA_VOCAB_PRE_TYPE_YOUTU = 44, - LLAMA_VOCAB_PRE_TYPE_EXAONE_MOE = 45, - LLAMA_VOCAB_PRE_TYPE_QWEN35 = 46, - LLAMA_VOCAB_PRE_TYPE_TINY_AYA = 47, - LLAMA_VOCAB_PRE_TYPE_JOYAI_LLM = 48, - LLAMA_VOCAB_PRE_TYPE_JAIS2 = 49, - LLAMA_VOCAB_PRE_TYPE_GEMMA4 = 50, - LLAMA_VOCAB_PRE_TYPE_SARVAM_MOE = 51, - LLAMA_VOCAB_PRE_TYPE_MINICPM5 = 52, - LLAMA_VOCAB_PRE_TYPE_WHITESPACE = 53, + LLAMA_VOCAB_PRE_TYPE_DEFAULT = 0, + LLAMA_VOCAB_PRE_TYPE_LLAMA3 = 1, + LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM = 2, + LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER = 3, + LLAMA_VOCAB_PRE_TYPE_FALCON = 4, + LLAMA_VOCAB_PRE_TYPE_MPT = 5, + LLAMA_VOCAB_PRE_TYPE_STARCODER = 6, + LLAMA_VOCAB_PRE_TYPE_GPT2 = 7, + LLAMA_VOCAB_PRE_TYPE_REFACT = 8, + LLAMA_VOCAB_PRE_TYPE_COMMAND_R = 9, + LLAMA_VOCAB_PRE_TYPE_STABLELM2 = 10, + LLAMA_VOCAB_PRE_TYPE_QWEN2 = 11, + LLAMA_VOCAB_PRE_TYPE_OLMO = 12, + LLAMA_VOCAB_PRE_TYPE_DBRX = 13, + LLAMA_VOCAB_PRE_TYPE_SMAUG = 14, + LLAMA_VOCAB_PRE_TYPE_PORO = 15, + LLAMA_VOCAB_PRE_TYPE_CHATGLM3 = 16, + LLAMA_VOCAB_PRE_TYPE_CHATGLM4 = 17, + LLAMA_VOCAB_PRE_TYPE_VIKING = 18, + LLAMA_VOCAB_PRE_TYPE_JAIS = 19, + LLAMA_VOCAB_PRE_TYPE_TEKKEN = 20, + LLAMA_VOCAB_PRE_TYPE_SMOLLM = 21, + LLAMA_VOCAB_PRE_TYPE_CODESHELL = 22, + LLAMA_VOCAB_PRE_TYPE_BLOOM = 23, + LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH = 24, + LLAMA_VOCAB_PRE_TYPE_EXAONE = 25, + LLAMA_VOCAB_PRE_TYPE_CHAMELEON = 26, + LLAMA_VOCAB_PRE_TYPE_MINERVA = 27, + LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 28, + LLAMA_VOCAB_PRE_TYPE_GPT4O = 29, + LLAMA_VOCAB_PRE_TYPE_SUPERBPE = 30, + LLAMA_VOCAB_PRE_TYPE_TRILLION = 31, + LLAMA_VOCAB_PRE_TYPE_BAILINGMOE = 32, + LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33, + LLAMA_VOCAB_PRE_TYPE_PIXTRAL = 34, + LLAMA_VOCAB_PRE_TYPE_SEED_CODER = 35, + LLAMA_VOCAB_PRE_TYPE_HUNYUAN = 36, + LLAMA_VOCAB_PRE_TYPE_KIMI_K2 = 37, + LLAMA_VOCAB_PRE_TYPE_HUNYUAN_DENSE = 38, + LLAMA_VOCAB_PRE_TYPE_GROK_2 = 39, + LLAMA_VOCAB_PRE_TYPE_GRANITE_DOCLING = 40, + LLAMA_VOCAB_PRE_TYPE_MINIMAX_M2 = 41, + LLAMA_VOCAB_PRE_TYPE_AFMOE = 42, + LLAMA_VOCAB_PRE_TYPE_SOLAR_OPEN = 43, + LLAMA_VOCAB_PRE_TYPE_YOUTU = 44, + LLAMA_VOCAB_PRE_TYPE_EXAONE_MOE = 45, + LLAMA_VOCAB_PRE_TYPE_QWEN35 = 46, + LLAMA_VOCAB_PRE_TYPE_TINY_AYA = 47, + LLAMA_VOCAB_PRE_TYPE_JOYAI_LLM = 48, + LLAMA_VOCAB_PRE_TYPE_JAIS2 = 49, + LLAMA_VOCAB_PRE_TYPE_GEMMA4 = 50, + LLAMA_VOCAB_PRE_TYPE_SARVAM_MOE = 51, + LLAMA_VOCAB_PRE_TYPE_MINICPM5 = 52, + LLAMA_VOCAB_PRE_TYPE_WHITESPACE = 53, + LLAMA_VOCAB_PRE_TYPE_GRANITE_EMB_MULTI = 54, + LLAMA_VOCAB_PRE_TYPE_MELLUM2 = 55, }; struct LLM_KV; diff --git a/src/models/mellum.cpp b/src/models/mellum.cpp new file mode 100644 index 000000000000..a2372399bbdc --- /dev/null +++ b/src/models/mellum.cpp @@ -0,0 +1,225 @@ +#include "models.h" + +void llama_model_mellum::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); + + if (hparams.n_swa > 0) { + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + + uint32_t swa_period = 4; + const auto res = ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); + if (res) { + hparams.set_swa_pattern(swa_period); + } else { + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.swa_layers, hparams.n_layer); + } + + hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; + hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; + + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); + } else { + hparams.swa_type = LLAMA_SWA_TYPE_NONE; + } + + switch (hparams.n_layer) { + case 28: type = LLM_TYPE_12B_A2_5B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_mellum::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + + if (n_expert == 0) { + throw std::runtime_error("n_expert must be > 0 for Mellum"); + } + if (n_expert_used == 0) { + throw std::runtime_error("n_expert_used must be > 0 for Mellum"); + } + + const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; + + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + } +} + +std::unique_ptr llama_model_mellum::build_arch_graph(const llm_graph_params & params) const { + if (hparams.swa_type == LLAMA_SWA_TYPE_STANDARD) { + return std::make_unique>(*this, params); + } + return std::make_unique>(*this, params); +} + +template +llama_model_mellum::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + using inp_attn_type = std::conditional_t; + inp_attn_type * inp_attn = nullptr; + + if constexpr (iswa) { + inp_attn = build_attn_inp_kv_iswa(); + } else { + inp_attn = build_attn_inp_kv(); + } + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, nullptr, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self_attention + { + // compute Q and K and RoPE them + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); + + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, il); + cb(Kcur, "Kcur_normed", il); + + const bool is_swa = hparams.is_swa(il); + + if (is_swa) { + // For sliding window layers, use regular rope with no yarn rope scaling. + // This is achieved here by setting freq_scale and attn_factor to 1. + // We also set ext_factor to 0 to avoid a few unnecessary computations. + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, 1.0, + 0.0, 1.0, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, 1.0, + 0.0, 1.0, beta_fast, beta_slow + ); + } else { + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + } + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // MoE + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, nullptr, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + ggml_tensor * moe_out = + build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + nullptr, + n_expert, n_expert_used, + LLM_FFN_SILU, true, + hparams.expert_weights_scale, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, + il, + nullptr, nullptr, + model.layers[il].ffn_up_exps_s, + model.layers[il].ffn_gate_exps_s, + model.layers[il].ffn_down_exps_s); + cb(moe_out, "ffn_moe_out", il); + cur = moe_out; + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + cur = inpL; + + cur = build_norm(cur, + model.output_norm, nullptr, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur, model.output_s); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); +} + +template struct llama_model_mellum::graph; +template struct llama_model_mellum::graph; diff --git a/src/models/models.h b/src/models/models.h index 5251e2d82802..866e0d0be3ed 100644 --- a/src/models/models.h +++ b/src/models/models.h @@ -411,6 +411,18 @@ struct llama_model_stablelm : public llama_model_base { std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; +struct llama_model_mellum : public llama_model_base { + llama_model_mellum(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + template + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; +}; struct llama_model_qwen : public llama_model_base { llama_model_qwen(const struct llama_model_params & params) : llama_model_base(params) {} @@ -1913,5 +1925,9 @@ struct llama_model_step35 : public llama_model_base { graph(const llama_model & model, const llm_graph_params & params); }; + struct graph_mtp : public llm_graph_context { + graph_mtp(const llama_model & model, const llm_graph_params & params); + }; + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; diff --git a/src/models/modern-bert.cpp b/src/models/modern-bert.cpp index e9b79ffc6dc0..5ab51867cc03 100644 --- a/src/models/modern-bert.cpp +++ b/src/models/modern-bert.cpp @@ -14,6 +14,14 @@ void llama_model_modern_bert::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + // Some ModernBert derivatives (e.g. IBM Granite Embedding 97m R2) use + // SiLU/SwiGLU in the FFN instead of the default GELU/GeGLU. + hparams.llm_ffn_op = LLM_FFN_GEGLU; + std::string hidden_act; + if (ml.get_key(LLM_KV_HIDDEN_ACT, hidden_act, false)) { + hparams.llm_ffn_op = llm_ffn_op_type_from_string(hidden_act, LLM_FFN_GEGLU); + } + switch (hparams.n_layer) { case 12: type = LLM_TYPE_47M; break; // granite-embedding-small @@ -144,7 +152,8 @@ llama_model_modern_bert::graph::graph(const llama_model & model, const llm_graph NULL, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, NULL, - LLM_FFN_GEGLU, LLM_FFN_SEQ, il); + hparams.llm_ffn_op, + LLM_FFN_SEQ, il); // attentions bypass the intermediate layer cur = ggml_add(ctx0, cur, ffn_inp); diff --git a/src/models/step35.cpp b/src/models/step35.cpp index 3b68e68707ae..caf18c743ff4 100644 --- a/src/models/step35.cpp +++ b/src/models/step35.cpp @@ -26,20 +26,36 @@ void llama_model_step35::load_arch_hparams(llama_model_loader & ml) { ml.get_key_or_arr(LLM_KV_SWIGLU_CLAMP_EXP, hparams.swiglu_clamp_exp, hparams.n_layer, false); ml.get_key_or_arr(LLM_KV_SWIGLU_CLAMP_SHEXP, hparams.swiglu_clamp_shexp, hparams.n_layer, false); - switch (hparams.n_layer) { + // NextN/MTP (Step3p5): extra decoder block appended beyond the main stack. + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); + GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); + + switch (hparams.n_layer - hparams.nextn_predict_layers) { case 45: type = LLM_TYPE_196B_A11B; break; default: type = LLM_TYPE_UNKNOWN; } } -void llama_model_step35::load_arch_tensors(llama_model_loader &) { +void llama_model_step35::load_arch_tensors(llama_model_loader & ml) { LLAMA_LOAD_LOCALS; + const uint32_t n_main = n_layer - hparams.nextn_predict_layers; + const bool mtp_only = (hparams.nextn_predict_layers > 0) && + (ml.get_weight("blk.0.attn_norm.weight") == nullptr); + // Trunk-only: the GGUF declares MTP layers in metadata but the actual MTP + // tensors live in a separate file (e.g. user split target/draft). Mark + // MTP tensors NOT_REQUIRED so the trunk loads cleanly. + const std::string mtp_probe = "blk." + std::to_string(n_main) + ".nextn.eh_proj.weight"; + const bool trunk_only = (hparams.nextn_predict_layers > 0) && + (ml.get_weight(mtp_probe.c_str()) == nullptr); + const int trunk_flags = mtp_only ? TENSOR_NOT_REQUIRED : 0; + const int mtp_flags = trunk_only ? TENSOR_NOT_REQUIRED : 0; + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); // output output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, trunk_flags); // STEP35 supports per-layer partial RoPE dims; rope factors are stored as a single shared tensor // ("rope_freqs.weight") and ggml uses only the first (n_rot_l/2) entries per layer. @@ -51,14 +67,14 @@ void llama_model_step35::load_arch_tensors(llama_model_loader &) { n_rot_max = n_rot; } - for (int i = 0; i < n_layer; ++i) { + auto load_block_trunk = [&](int i, int flags) { auto & layer = layers[i]; const uint32_t n_head_l = hparams.n_head(i); const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i); const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i); - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, flags); layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, TENSOR_NOT_REQUIRED); layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, TENSOR_NOT_REQUIRED); @@ -70,13 +86,13 @@ void llama_model_step35::load_arch_tensors(llama_model_loader &) { layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot_max/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); } - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head_l, n_embd_k_gqa, n_embd_v_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_v * n_head_l, n_embd}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head_l, n_embd_k_gqa, n_embd_v_gqa, flags); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_v * n_head_l, n_embd}, flags); // head-wise attention gate (Step35 self_attn.g_proj) layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), {n_embd, n_head_l}, TENSOR_NOT_REQUIRED); - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, flags); // dense MLP (leading dense blocks) layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); @@ -95,10 +111,86 @@ void llama_model_step35::load_arch_tensors(llama_model_loader &) { layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, TENSOR_NOT_REQUIRED); layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, TENSOR_NOT_REQUIRED); layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, TENSOR_NOT_REQUIRED); + }; + + auto load_block_mtp = [&](int i, bool is_first_mtp) { + auto & layer = layers[i]; + + const uint32_t n_head_l = hparams.n_head(i); + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i); + + // The MTP block is a full Step3p5 decoder layer (mtp_block) plus the + // NextN-specific wiring (enorm/hnorm/eh_proj + optional shared head). + // `mtp_flags` becomes NOT_REQUIRED when the GGUF is trunk-only. + // + // Only the FIRST MTP block (i == n_main) is required for the + // single-block MTP runtime; trailing MTP blocks are always tolerated + // as missing so pruned GGUFs (block 0 only) load cleanly. Override + // mtp_flags to NOT_REQUIRED for those. + const int eff_mtp_flags = is_first_mtp ? mtp_flags : (mtp_flags | TENSOR_NOT_REQUIRED); + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, eff_mtp_flags); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, TENSOR_NOT_REQUIRED); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, TENSOR_NOT_REQUIRED); + + if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) { + layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), {n_rot_max/2}, TENSOR_NOT_REQUIRED | TENSOR_DUPLICATED); + layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), {n_rot_max/2}, TENSOR_NOT_REQUIRED | TENSOR_DUPLICATED); + } else { + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot_max/2}, TENSOR_NOT_REQUIRED | TENSOR_DUPLICATED); + } + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head_l, n_embd_k_gqa, n_embd_v_gqa, eff_mtp_flags); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_v * n_head_l, n_embd}, eff_mtp_flags); + + layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), {n_embd, n_head_l}, TENSOR_NOT_REQUIRED); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, eff_mtp_flags); + + // dense MLP (leading dense blocks) — present if the MTP block isn't MoE + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); + + // MoE routed experts + selection bias (router_bias) + const int64_t n_ff_exp = hparams.n_ff_exp; + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); + + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, TENSOR_NOT_REQUIRED); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, TENSOR_NOT_REQUIRED); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, TENSOR_NOT_REQUIRED); + + // NextN-specific tensors that define the MTP block. + layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, eff_mtp_flags); + layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, eff_mtp_flags); + layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, eff_mtp_flags); + layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED); + }; + + for (int i = 0; i < (int) n_main; ++i) { + load_block_trunk(i, trunk_flags); + } + // Only the first MTP block (i == n_main) is required at runtime — the + // single-block-MTP graph in build_arch_graph always uses that one. + // Trailing MTP blocks are loaded if present (so an un-pruned GGUF with + // all MTP layers still works) but tolerated when absent via the pruning + // path. See scripts/prune_step35_extra_mtp.py for the pruner. + for (int i = (int) n_main; i < n_layer; ++i) { + load_block_mtp(i, /*is_first_mtp=*/ i == (int) n_main); } } std::unique_ptr llama_model_step35::build_arch_graph(const llm_graph_params & params) const { + if (params.gtype == LLM_GRAPH_TYPE_DECODER_MTP) { + return std::make_unique(*this, params); + } return std::make_unique(*this, params); } @@ -111,7 +203,9 @@ llama_model_step35::graph::graph(const llama_model & model, const llm_graph_para auto * inp_attn = build_attn_inp_kv_iswa(); ggml_tensor * inp_out_ids = build_inp_out_ids(); - for (int il = 0; il < n_layer; ++il) { + // MTP/NextN layers are loaded as extra decoder blocks but not executed in the main pass. + const int n_transformer_layers = n_layer - (int) hparams.nextn_predict_layers; + for (int il = 0; il < n_transformer_layers; ++il) { ggml_tensor * inpSA = inpL; const uint32_t n_head_l = hparams.n_head(il); @@ -198,8 +292,8 @@ llama_model_step35::graph::graph(const llama_model & model, const llm_graph_para cb(cur, "attn_proj", il); } - if (il == n_layer - 1 && inp_out_ids) { - cur = ggml_get_rows(ctx0, cur, inp_out_ids); + if (il == n_transformer_layers - 1 && inp_out_ids && cparams.embeddings_pre_norm_masked) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } @@ -257,6 +351,13 @@ llama_model_step35::graph::graph(const llama_model & model, const llm_graph_para cur = inpL; + cb(cur, "h_pre_norm", -1); + res->t_h_pre_norm = cur; + + if (!cparams.embeddings_pre_norm_masked && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + } + cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1); cb(cur, "result_norm", -1); res->t_embd = cur; @@ -267,3 +368,192 @@ llama_model_step35::graph::graph(const llama_model & model, const llm_graph_para ggml_build_forward_expand(gf, cur); } + +// LLM_GRAPH_TYPE_DECODER_MTP draft head for Step3p5 (MoE) +llama_model_step35::graph_mtp::graph_mtp(const llama_model & model, const llm_graph_params & params) + : llm_graph_context(params) { + GGML_ASSERT(hparams.nextn_predict_layers > 0 && "STEP35 MTP requires nextn_predict_layers > 0"); + + // Single-block MTP only: always run the first trained MTP block (Qwen + // MTP / vLLM single-MTP-layer style). Multi-block round-robin proved to + // be a much deeper refactor than this PR justifies; the trailing MTP + // blocks are loaded with TENSOR_NOT_REQUIRED so pruned GGUFs (with just + // block 0) also work — see load_arch_tensors below and + // scripts/prune_step35_extra_mtp.py. + const int il = (int) hparams.n_layer - (int) hparams.nextn_predict_layers; + const auto & layer = model.layers[il]; + + GGML_ASSERT(layer.nextn.eh_proj && "MTP block missing nextn.eh_proj"); + GGML_ASSERT(layer.nextn.enorm && "MTP block missing nextn.enorm"); + GGML_ASSERT(layer.nextn.hnorm && "MTP block missing nextn.hnorm"); + + const uint32_t n_head_l = hparams.n_head(il); + const uint32_t n_head_kv_l = hparams.n_head_kv(il); + + const float freq_base_l = model.get_rope_freq_base(cparams, il); + const float freq_scale_l = model.get_rope_freq_scale(cparams, il); + + auto inp = std::make_unique(hparams.n_embd); + + inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + ggml_set_input(inp->tokens); + + inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd, n_tokens); + ggml_set_input(inp->embd); + ggml_set_name(inp->embd, "mtp_h_input"); + + ggml_tensor * tok_embd_w = layer.nextn.embed_tokens ? layer.nextn.embed_tokens : model.tok_embd; + + ggml_tensor * h_input = inp->embd; + ggml_tensor * tok_embd = ggml_get_rows(ctx0, tok_embd_w, inp->tokens); + cb(tok_embd, "mtp_tok_embd", il); + + res->add_input(std::move(inp)); + + ggml_tensor * inp_pos = build_inp_pos(); + auto * inp_attn = build_attn_inp_kv_iswa(); + + ggml_tensor * h_norm = build_norm(h_input, layer.nextn.hnorm, nullptr, LLM_NORM_RMS, il); + cb(h_norm, "mtp_hnorm", il); + + ggml_tensor * e_norm = build_norm(tok_embd, layer.nextn.enorm, nullptr, LLM_NORM_RMS, il); + cb(e_norm, "mtp_enorm", il); + + ggml_tensor * concat = ggml_concat(ctx0, e_norm, h_norm, /*dim=*/ 0); + cb(concat, "mtp_concat", il); + + ggml_tensor * cur = build_lora_mm(layer.nextn.eh_proj, concat); + cb(cur, "mtp_eh_proj", il); + + ggml_tensor * inpSA = cur; + + // mtp_block: full Step3p5 decoder layer (attention with optional head-wise gate, then MoE/dense FFN) + cur = build_norm(cur, layer.attn_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "mtp_attn_norm", il); + + ggml_tensor * Qcur = build_lora_mm(layer.wq, cur, layer.wq_s); + ggml_tensor * Kcur = build_lora_mm(layer.wk, cur, layer.wk_s); + ggml_tensor * Vcur = build_lora_mm(layer.wv, cur, layer.wv_s); + cb(Qcur, "mtp_Qcur", il); + cb(Kcur, "mtp_Kcur", il); + cb(Vcur, "mtp_Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head_l, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv_l, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head_v, n_head_kv_l, n_tokens); + + if (layer.attn_q_norm) { + Qcur = build_norm(Qcur, layer.attn_q_norm, nullptr, LLM_NORM_RMS, il); + cb(Qcur, "mtp_Qcur_normed", il); + } + if (layer.attn_k_norm) { + Kcur = build_norm(Kcur, layer.attn_k_norm, nullptr, LLM_NORM_RMS, il); + cb(Kcur, "mtp_Kcur_normed", il); + } + + const bool is_swa = hparams.is_swa(il); + ggml_tensor * rope_factors = is_swa ? nullptr : model.get_rope_factors(cparams, il); + const int64_t n_rot_l = hparams.n_rot(il); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, rope_factors, + n_rot_l, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, + ext_factor, attn_factor, beta_fast, beta_slow); + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, rope_factors, + n_rot_l, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, + ext_factor, attn_factor, beta_fast, beta_slow); + cb(Qcur, "mtp_Qcur_pos", il); + cb(Kcur, "mtp_Kcur_pos", il); + + const float kq_scale = 1.0f / sqrtf(float(n_embd_head_k)); + ggml_tensor * attn_out = build_attn(inp_attn, + nullptr, nullptr, nullptr, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + cb(attn_out, "mtp_attn_out", il); + + // head-wise attention gate: sigmoid(g_proj(x)) + if (layer.wqkv_gate) { + ggml_tensor * gate = build_lora_mm(layer.wqkv_gate, cur); // [n_head_l, n_tokens] + cb(gate, "mtp_attn_gate", il); + + gate = ggml_sigmoid(ctx0, gate); + cb(gate, "mtp_attn_gate_sigmoid", il); + + ggml_tensor * attn_3d = ggml_reshape_3d(ctx0, attn_out, n_embd_head_v, n_head_l, n_tokens); + ggml_tensor * gate_3d = ggml_reshape_3d(ctx0, gate, 1, n_head_l, n_tokens); + cb(gate_3d, "mtp_attn_gate_3d", il); + + attn_3d = ggml_mul(ctx0, attn_3d, gate_3d); + cb(attn_3d, "mtp_attn_gated_3d", il); + + attn_out = ggml_reshape_2d(ctx0, attn_3d, n_embd_head_v * n_head_l, n_tokens); + cb(attn_out, "mtp_attn_gated", il); + } + + cur = build_lora_mm(layer.wo, attn_out, layer.wo_s); + cb(cur, "mtp_attn_proj", il); + + cur = ggml_add(ctx0, cur, inpSA); + cb(cur, "mtp_attn_residual", il); + + ggml_tensor * ffn_inp = cur; + cur = build_norm(cur, layer.ffn_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "mtp_ffn_norm", il); + + // FFN: dense MLP or MoE (mirrors trunk path) + if (layer.ffn_gate_inp == nullptr) { + cur = build_ffn(cur, + layer.ffn_up, layer.ffn_up_b, nullptr, + layer.ffn_gate, layer.ffn_gate_b, nullptr, + layer.ffn_down, layer.ffn_down_b, nullptr, + nullptr, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "mtp_ffn_out", il); + } else { + ggml_tensor * moe_out = build_moe_ffn(cur, + layer.ffn_gate_inp, + layer.ffn_up_exps, + layer.ffn_gate_exps, + layer.ffn_down_exps, + layer.ffn_exp_probs_b, + n_expert, n_expert_used, + LLM_FFN_SILU, hparams.expert_weights_norm, + hparams.expert_weights_scale, + (llama_expert_gating_func_type) hparams.expert_gating_func, + il); + cb(moe_out, "mtp_ffn_moe_out", il); + + ggml_tensor * sh_out = build_ffn(cur, + layer.ffn_up_shexp, nullptr, nullptr, + layer.ffn_gate_shexp, nullptr, nullptr, + layer.ffn_down_shexp, nullptr, nullptr, + nullptr, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(sh_out, "mtp_ffn_shared_out", il); + + cur = ggml_add(ctx0, moe_out, sh_out); + cb(cur, "mtp_ffn_out", il); + } + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "mtp_post_ffn", il); + + // Pre-norm hidden state: used by the AR draft loop to seed the next MTP step. + cb(cur, "h_pre_norm", -1); + res->t_h_pre_norm = cur; + + ggml_tensor * head_norm_w = layer.nextn.shared_head_norm + ? layer.nextn.shared_head_norm + : model.output_norm; + GGML_ASSERT(head_norm_w && "STEP35 MTP: missing both nextn.shared_head_norm and output_norm"); + cur = build_norm(cur, head_norm_w, nullptr, LLM_NORM_RMS, -1); + cb(cur, "mtp_shared_head_norm", -1); + + ggml_tensor * head_w = layer.nextn.shared_head_head ? layer.nextn.shared_head_head : model.output; + GGML_ASSERT(head_w && "STEP35 MTP: missing LM head (nextn.shared_head_head or model.output)"); + cur = build_lora_mm(head_w, cur); + cb(cur, "result_output", -1); + + res->t_logits = cur; + ggml_build_forward_expand(gf, cur); +} diff --git a/tests/test-llama-archs.cpp b/tests/test-llama-archs.cpp index 1def7faff605..4fe585e29a22 100644 --- a/tests/test-llama-archs.cpp +++ b/tests/test-llama-archs.cpp @@ -357,6 +357,7 @@ static bool moe_mandatory(const llm_arch arch) { case LLM_ARCH_KIMI_LINEAR: case LLM_ARCH_STEP35: case LLM_ARCH_MISTRAL4: + case LLM_ARCH_MELLUM: return true; default: return false; diff --git a/tools/server/tests/requirements.txt b/tools/server/tests/requirements.txt index 92d27e2a13c1..ca7a0281fa14 100644 --- a/tools/server/tests/requirements.txt +++ b/tools/server/tests/requirements.txt @@ -1,6 +1,5 @@ aiohttp~=3.9.3 pytest~=8.3.3 -huggingface_hub>=1.5.0,<2.0 numpy~=1.26.4 openai~=2.14.0 prometheus-client~=0.20.0