diff --git a/conversion/__init__.py b/conversion/__init__.py
index 4a1fd5bb70f0..02ea6385208a 100644
--- a/conversion/__init__.py
+++ b/conversion/__init__.py
@@ -51,6 +51,7 @@
     "DeepseekV3ForCausalLM": "deepseek",
     "DeepseekV32ForCausalLM": "deepseek",
+    "DeepseekV4ForCausalLM": "deepseek",
     "DFlashDraftModel": "qwen",
     "DistilBertForMaskedLM": "bert",
     "DistilBertForSequenceClassification": "bert",
     "DistilBertModel": "bert",
diff --git a/conversion/base.py b/conversion/base.py
index 08fd3747c408..0421aa4bc4d3 100644
--- a/conversion/base.py
+++ b/conversion/base.py
@@ -1273,7 +1273,7 @@ def set_gguf_parameters(self):
         if (f_norm_eps := self.find_hparam(["layer_norm_eps", "layer_norm_epsilon", "norm_epsilon"], optional=True)) is not None:
             self.gguf_writer.add_layer_norm_eps(f_norm_eps)
             logger.info(f"gguf: layer norm epsilon = {f_norm_eps}")
-        if (n_experts := self.find_hparam(["num_local_experts", "num_experts"], optional=True)) is not None:
+        if (n_experts := self.find_hparam(["num_local_experts", "num_experts", "n_routed_experts"], optional=True)) is not None:
             self.gguf_writer.add_expert_count(n_experts)
             logger.info(f"gguf: expert count = {n_experts}")
         if (n_experts_used := self.find_hparam(["num_experts_per_tok", "num_experts_per_token", "top_k_experts"], optional=True)) is not None:
@@ -1291,6 +1291,8 @@ def set_gguf_parameters(self):
                 self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
             elif score_func == "softmax":
                 self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX)
+            elif score_func == "sqrtsoftplus":
+                self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SQRTSOFTPLUS)
             else:
                 raise ValueError(f"Unsupported expert score gating function value: {score_func}")
             logger.info(f"gguf: expert score gating function = {score_func}")
@@ -2600,6 +2602,17 @@ def __torch_function__(cls, func, types, args=(), kwargs=None):
         return cls._wrap_fn(func)(*args, **kwargs)
 
 
+if hasattr(torch, "float8_e8m0fnu"):
+    _torch_float8_e8m0 = torch.float8_e8m0fnu
+    LazyTorchTensor._dtype_map[_torch_float8_e8m0] = np.uint8
+    LazyTorchTensor._dtype_byteswap_map[_torch_float8_e8m0] = np.uint8
+    LazyTorchTensor._dtype_str_map["F8_E8M0"] = _torch_float8_e8m0
+else:
+    # Older torch builds do not expose F8_E8M0. Keep the raw bytes so callers
+    # that know the format can decode them explicitly.
+    LazyTorchTensor._dtype_str_map["F8_E8M0"] = torch.uint8
+
+
 def get_model_architecture(hparams: dict[str, Any], model_type: ModelType) -> str:
     # TODO @ngxson : this won't work correctly if the model has both audio & vision encoders
     # maybe we should fallback to text model's arch in that case, since not many models have both
diff --git a/conversion/deepseek.py b/conversion/deepseek.py
index 4c93fb66df64..cfac5201ec46 100644
--- a/conversion/deepseek.py
+++ b/conversion/deepseek.py
@@ -1,15 +1,18 @@
 from __future__ import annotations
 
+import json
 import re
 
+from pathlib import Path
 from typing import Any, Callable, Iterable, TYPE_CHECKING
 
+import numpy as np
 import torch
 
 if TYPE_CHECKING:
     from torch import Tensor
 
-from .base import MmprojModel, ModelBase, TextModel, gguf, logger
+from .base import LazyTorchTensor, MmprojModel, ModelBase, TextModel, gguf, logger
 from .qwen import QwenModel
 
 
@@ -467,3 +470,395 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_indexer_head_count(self.hparams["index_n_heads"])
         self.gguf_writer.add_indexer_key_length(self.hparams["index_head_dim"])
         self.gguf_writer.add_indexer_top_k(self.hparams["index_topk"])
+
+
+@ModelBase.register("DeepseekV4ForCausalLM")
+class DeepseekV4Model(TextModel):
+    """Convert DeepSeek-V4 checkpoints to GGUF.
+
+    FP8 weights that ship with a ``.scale`` tensor are dequantized (and stored
+    as Q8_0 when at least 2-D), routed-expert weights are repacked into ggml
+    MXFP4 blocks, hash routing tables are written as I32, and ``mtp.*``
+    tensors are skipped.
+    """
+
+    model_arch = gguf.MODEL_ARCH.DEEPSEEK4
+    # filter_tensors() is a classmethod, so the counter lives on the class;
+    # __init__ resets it for every conversion.
+    _skipped_mtp_tensors = 0
+
+    # Edge length of the square weight blocks that share one FP8 scale.
+    _DSV4_FP8_BLOCK_SIZE = 128
+
+    # Checkpoint name -> (GGUF tensor, suffix) for tensors outside the layers.
+    _DSV4_ROOT_TENSORS: dict[str, tuple[gguf.MODEL_TENSOR, str]] = {
+        "embed.weight": (gguf.MODEL_TENSOR.TOKEN_EMBD, ".weight"),
+        "norm.weight": (gguf.MODEL_TENSOR.OUTPUT_NORM, ".weight"),
+        "head.weight": (gguf.MODEL_TENSOR.OUTPUT, ".weight"),
+        "hc_head_fn": (gguf.MODEL_TENSOR.HC_HEAD_FN, ".weight"),
+        "hc_head_base": (gguf.MODEL_TENSOR.HC_HEAD_BASE, ".weight"),
+        "hc_head_scale": (gguf.MODEL_TENSOR.HC_HEAD_SCALE, ".weight"),
+    }
+
+    # Name below "layers.<bid>." -> (GGUF tensor, suffix).
+    _DSV4_LAYER_TENSORS: dict[str, tuple[gguf.MODEL_TENSOR, str]] = {
+        "hc_attn_fn": (gguf.MODEL_TENSOR.HC_ATTN_FN, ".weight"),
+        "hc_attn_base": (gguf.MODEL_TENSOR.HC_ATTN_BASE, ".weight"),
+        "hc_attn_scale": (gguf.MODEL_TENSOR.HC_ATTN_SCALE, ".weight"),
+        "hc_ffn_fn": (gguf.MODEL_TENSOR.HC_FFN_FN, ".weight"),
+        "hc_ffn_base": (gguf.MODEL_TENSOR.HC_FFN_BASE, ".weight"),
+        "hc_ffn_scale": (gguf.MODEL_TENSOR.HC_FFN_SCALE, ".weight"),
+        "attn.attn_sink": (gguf.MODEL_TENSOR.ATTN_SINKS, ".weight"),
+        "attn.wq_a.weight": (gguf.MODEL_TENSOR.ATTN_Q_A, ".weight"),
+        "attn.wq_b.weight": (gguf.MODEL_TENSOR.ATTN_Q_B, ".weight"),
+        "attn.q_norm.weight": (gguf.MODEL_TENSOR.ATTN_Q_A_NORM, ".weight"),
+        "attn.wkv.weight": (gguf.MODEL_TENSOR.ATTN_KV, ".weight"),
+        "attn.kv_norm.weight": (gguf.MODEL_TENSOR.ATTN_KV_NORM, ".weight"),
+        "attn.wo_a.weight": (gguf.MODEL_TENSOR.ATTN_OUT_A, ".weight"),
+        "attn.wo_b.weight": (gguf.MODEL_TENSOR.ATTN_OUT_B, ".weight"),
+        "attn.compressor.ape": (gguf.MODEL_TENSOR.ATTN_COMPRESSOR_APE, ".weight"),
+        "attn.compressor.wkv.weight": (gguf.MODEL_TENSOR.ATTN_COMPRESSOR_WKV, ".weight"),
+        "attn.compressor.wgate.weight": (gguf.MODEL_TENSOR.ATTN_COMPRESSOR_WGATE, ".weight"),
+        "attn.compressor.norm.weight": (gguf.MODEL_TENSOR.ATTN_COMPRESSOR_NORM, ".weight"),
+        "attn.indexer.wq_b.weight": (gguf.MODEL_TENSOR.INDEXER_ATTN_Q_B, ".weight"),
+        "attn.indexer.weights_proj.weight": (gguf.MODEL_TENSOR.INDEXER_PROJ, ".weight"),
+        "attn.indexer.compressor.ape": (gguf.MODEL_TENSOR.INDEXER_COMPRESSOR_APE, ".weight"),
+        "attn.indexer.compressor.wkv.weight": (gguf.MODEL_TENSOR.INDEXER_COMPRESSOR_WKV, ".weight"),
+        "attn.indexer.compressor.wgate.weight": (gguf.MODEL_TENSOR.INDEXER_COMPRESSOR_WGATE, ".weight"),
+        "attn.indexer.compressor.norm.weight": (gguf.MODEL_TENSOR.INDEXER_COMPRESSOR_NORM, ".weight"),
+        "attn_norm.weight": (gguf.MODEL_TENSOR.ATTN_NORM, ".weight"),
+        "ffn_norm.weight": (gguf.MODEL_TENSOR.FFN_NORM, ".weight"),
+        "ffn.gate.weight": (gguf.MODEL_TENSOR.FFN_GATE_INP, ".weight"),
+        "ffn.gate.bias": (gguf.MODEL_TENSOR.FFN_EXP_PROBS_B, ".bias"),
+        "ffn.gate.tid2eid": (gguf.MODEL_TENSOR.FFN_GATE_TID2EID, ".weight"),
+        "ffn.shared_experts.w1.weight": (gguf.MODEL_TENSOR.FFN_GATE_SHEXP, ".weight"),
+        "ffn.shared_experts.w2.weight": (gguf.MODEL_TENSOR.FFN_DOWN_SHEXP, ".weight"),
+        "ffn.shared_experts.w3.weight": (gguf.MODEL_TENSOR.FFN_UP_SHEXP, ".weight"),
+    }
+
+    # Routed-expert projection -> stacked GGUF expert tensor.
+    _DSV4_EXPERT_TENSORS: dict[str, gguf.MODEL_TENSOR] = {
+        "w1": gguf.MODEL_TENSOR.FFN_GATE_EXP,
+        "w2": gguf.MODEL_TENSOR.FFN_DOWN_EXP,
+        "w3": gguf.MODEL_TENSOR.FFN_UP_EXP,
+    }
+
+    def __init__(self, *args, **kwargs):
+        type(self)._skipped_mtp_tensors = 0
+        # State shared by the conversion hooks below. It is created before the
+        # base constructor runs so that a hook invoked from there finds it and
+        # nothing such a hook records is reset afterwards.
+        self._dsv4_fp8_dequantized: set[str] = set()
+        self._dsv4_bf16_tensors: set[str] = set()
+        self._dsv4_f32_tensors: set[str] = set()
+        self._dsv4_mxfp4_generated = False
+        super().__init__(*args, **kwargs)
+
+        # Add config.json keys that are not already present in hparams.
+        with open(self.dir_model / "config.json", "r", encoding="utf-8") as f:
+            raw_hparams = json.load(f)
+        for key, value in raw_hparams.items():
+            self.hparams.setdefault(key, value)
+
+        self.block_count = self.hparams["num_hidden_layers"]
+        self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
+        self._collect_source_dtypes()
+
+        if type(self)._skipped_mtp_tensors:
+            logger.info("Skipping %d DeepSeek-V4 MTP tensor(s) for conversion v0", type(self)._skipped_mtp_tensors)
+
+        # add a default chat template; if the model has a built-in template, it will be overridden later
+        template_path = Path(__file__).parent.parent / "models" / "templates" / "deepseek-ai-DeepSeek-V4.jinja"
+        if template_path.is_file():
+            with open(template_path, "r", encoding="utf-8") as f:
+                self.gguf_writer.add_chat_template(f.read())
+
+    @classmethod
+    def filter_tensors(cls, item: tuple[str, Callable[[], Tensor]]) -> tuple[str, Callable[[], Tensor]] | None:
+        """Skip (and count) ``mtp.*`` tensors; defer everything else to the base filter."""
+        name, _ = item
+        if name.startswith("mtp."):
+            cls._skipped_mtp_tensors += 1
+            return None
+        return super().filter_tensors(item)
+
+    def set_vocab(self):
+        self._set_vocab_gpt2()
+
+    @staticmethod
+    def _float8_dtypes() -> tuple[torch.dtype, ...]:
+        """Return the FP8 weight dtypes that the installed torch build exposes."""
+        return tuple(
+            dtype for dtype in (
+                getattr(torch, "float8_e4m3fn", None),
+                getattr(torch, "float8_e5m2", None),
+            ) if dtype is not None
+        )
+
+    @staticmethod
+    def _e8m0_to_float(scale: Tensor) -> Tensor:
+        """Decode E8M0 scale factors to float32.
+
+        A byte ``b`` encodes ``2 ** (b - 127)`` and ``0xFF`` encodes NaN. The
+        input is either the native torch E8M0 dtype or the raw bytes held in
+        a one-byte integer tensor.
+
+        Raises:
+            ValueError: for any other dtype, because reinterpreting wider
+                elements as bytes would silently produce wrong scales.
+        """
+        torch_float8_e8m0 = getattr(torch, "float8_e8m0fnu", None)
+        if torch_float8_e8m0 is not None and scale.dtype == torch_float8_e8m0:
+            return scale.float()
+        if scale.dtype not in (torch.uint8, torch.int8):
+            raise ValueError(f"Expected E8M0 scale bytes, got dtype {scale.dtype}")
+
+        bits = scale.view(torch.uint8).float()
+        decoded = torch.exp2(bits - 127.0)
+        # exp2(255 - 127) overflows to +inf in float32, but 0xFF means NaN.
+        return torch.where(torch.eq(bits, 255.0), torch.full_like(decoded, float("nan")), decoded)
+
+    def _collect_source_dtypes(self) -> None:
+        """Remember which tensors are BF16 or F32 for tensor_force_quant()."""
+        for name, gen in self.model_tensors.items():
+            dtype = gen().dtype
+            if dtype == torch.bfloat16:
+                self._dsv4_bf16_tensors.add(name)
+            elif dtype == torch.float32:
+                self._dsv4_f32_tensors.add(name)
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        hparams = self.hparams
+
+        self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"])
+        self.gguf_writer.add_q_lora_rank(hparams["q_lora_rank"])
+        self.gguf_writer.add_sliding_window(hparams["sliding_window"])
+
+        self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"])
+        self.gguf_writer.add_expert_shared_count(hparams["n_shared_experts"])
+        self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"])
+        self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"])
+        self.gguf_writer.add_swiglu_clamp_exp([hparams["swiglu_limit"]] * self.block_count)
+        self.gguf_writer.add_swiglu_clamp_shexp([hparams["swiglu_limit"]] * self.block_count)
+
+        self.gguf_writer.add_indexer_head_count(hparams["index_n_heads"])
+        self.gguf_writer.add_indexer_key_length(hparams["index_head_dim"])
+        self.gguf_writer.add_indexer_top_k(hparams["index_topk"])
+
+        self.gguf_writer.add_attention_output_group_count(hparams["o_groups"])
+        self.gguf_writer.add_attention_output_lora_rank(hparams["o_lora_rank"])
+        self.gguf_writer.add_attention_compress_ratios(hparams["compress_ratios"])
+        self.gguf_writer.add_attention_compress_rope_freq_base(hparams["compress_rope_theta"])
+        self.gguf_writer.add_hyper_connection_count(hparams["hc_mult"])
+        self.gguf_writer.add_hyper_connection_sinkhorn_iterations(hparams["hc_sinkhorn_iters"])
+        self.gguf_writer.add_hyper_connection_epsilon(hparams["hc_eps"])
+        self.gguf_writer.add_hash_layer_count(hparams["num_hash_layers"])
+
+    def dequant_model(self):
+        """Swap every FP8 ``*.weight`` that has a ``*.scale`` sibling for a dequantizing closure.
+
+        The ``.scale`` tensors consumed this way are removed from
+        ``model_tensors``; scales of non-FP8 weights are left in place.
+        """
+        fp8_dtypes = self._float8_dtypes()
+        block = self._DSV4_FP8_BLOCK_SIZE
+        tensors_to_remove: list[str] = []
+
+        def dequant_fp8_weight(weight: Tensor, scale: Tensor) -> Tensor:
+            out_features, in_features = weight.shape
+            scale_f = self._e8m0_to_float(scale)
+            # One scale per block -> one per element; the slices trim the
+            # overhang of partial blocks at the edges.
+            scale_f = scale_f.repeat_interleave(block, 0)[:out_features]
+            scale_f = scale_f.repeat_interleave(block, 1)[:, :in_features]
+            return weight.float() * scale_f
+
+        for name in list(self.model_tensors.keys()):
+            if not name.endswith(".scale"):
+                continue
+            weight_name = name.removesuffix(".scale") + ".weight"
+            if weight_name not in self.model_tensors:
+                continue
+
+            weight = self.model_tensors[weight_name]
+            scale = self.model_tensors[name]
+            if weight().dtype not in fp8_dtypes:
+                continue
+
+            self.model_tensors[weight_name] = lambda w=weight, s=scale: dequant_fp8_weight(w(), s())
+            self._dsv4_fp8_dequantized.add(weight_name)
+            tensors_to_remove.append(name)
+
+        for name in tensors_to_remove:
+            del self.model_tensors[name]
+
+    @staticmethod
+    def _pack_mxfp4_blocks(weight: Tensor, scale: Tensor) -> np.ndarray:
+        """Repack checkpoint FP4 data and its per-block scales into raw ggml MXFP4 blocks.
+
+        Returns a uint8 array of shape ``(rows, n_blocks * 17)``: per 32-value
+        block, one scale byte followed by 16 bytes of packed nibbles.
+
+        Raises:
+            ValueError: if a row does not hold a multiple of 32 values, or the
+                scale shape is not ``(rows, n_blocks)``.
+        """
+        packed = weight.contiguous().view(torch.uint8)
+        scale_u8 = scale.contiguous().view(torch.uint8)
+
+        out_features, packed_cols = packed.shape
+        logical_cols = packed_cols * 2
+        if logical_cols % 32 != 0:
+            raise ValueError(f"MXFP4 source row has {logical_cols} values, expected a multiple of 32")
+
+        n_blocks = logical_cols // 32
+        if tuple(scale_u8.shape) != (out_features, n_blocks):
+            raise ValueError(f"MXFP4 scale shape {tuple(scale_u8.shape)} does not match {(out_features, n_blocks)}")
+
+        src = packed.reshape(out_features, n_blocks, 16)
+        low = src & 0x0F
+        high = (src >> 4) & 0x0F
+
+        # The safetensors bytes store adjacent values as low/high nibbles.
+        # ggml MXFP4 blocks store values 0..15 in low nibbles and 16..31 in high nibbles.
+        vals = torch.stack((low, high), dim=-1).reshape(out_features, n_blocks, 32)
+        qs = vals[:, :, :16] | (vals[:, :, 16:] << 4)
+        raw = torch.cat((scale_u8.unsqueeze(-1), qs.to(torch.uint8)), dim=-1)
+        return raw.reshape(out_features, n_blocks * 17).cpu().numpy()
+
+    def _write_mxfp4_expert_tensor(self, bid: int, proj: str, tensor_key: gguf.MODEL_TENSOR) -> list[str]:
+        """Stack one projection of every routed expert of layer ``bid`` and write it as MXFP4.
+
+        Returns the names of the checkpoint tensors that were consumed.
+
+        Raises:
+            KeyError: if an expert's weight or scale tensor is missing.
+            ValueError: if ``n_routed_experts`` is not positive.
+        """
+        n_experts = self.hparams["n_routed_experts"]
+        data: np.ndarray | None = None
+        consumed: list[str] = []
+
+        for eid in range(n_experts):
+            weight_name = f"layers.{bid}.ffn.experts.{eid}.{proj}.weight"
+            scale_name = f"layers.{bid}.ffn.experts.{eid}.{proj}.scale"
+            if weight_name not in self.model_tensors or scale_name not in self.model_tensors:
+                raise KeyError(f"Missing routed expert tensors for {weight_name}")
+
+            weight = LazyTorchTensor.to_eager(self.model_tensors[weight_name]())
+            scale = LazyTorchTensor.to_eager(self.model_tensors[scale_name]())
+            packed = self._pack_mxfp4_blocks(weight, scale)
+            if data is None:
+                data = np.empty((n_experts, *packed.shape), dtype=packed.dtype)
+            data[eid] = packed
+            consumed.extend((weight_name, scale_name))
+
+        if data is None:
+            raise ValueError(f"n_routed_experts must be positive, got {n_experts}")
+
+        new_name = self.format_tensor_name(tensor_key, bid)
+        shape = gguf.quant_shape_from_byte_shape(data.shape, gguf.GGMLQuantizationType.MXFP4)
+        logger.info(f"{new_name}: repacked routed experts to MXFP4, shape = {{{', '.join(str(n) for n in reversed(shape))}}}")
+        self.gguf_writer.add_tensor(new_name, data, raw_dtype=gguf.GGMLQuantizationType.MXFP4)
+
+        return consumed
+
+    def _write_hash_routing_tensors(self) -> list[str]:
+        """Write the ``tid2eid`` table of every hash layer as I32 and return the consumed names."""
+        consumed: list[str] = []
+
+        for bid in range(self.hparams["num_hash_layers"]):
+            name = f"layers.{bid}.ffn.gate.tid2eid"
+            if name not in self.model_tensors:
+                raise KeyError(f"Missing hash routing tensor {name}")
+
+            data_torch = LazyTorchTensor.to_eager(self.model_tensors[name]())
+            data = data_torch.to(torch.int32).cpu().numpy()
+            new_name = self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_TID2EID, bid, ".weight")
+            logger.info(f"{new_name}: converted hash routing table to I32, shape = {{{', '.join(str(n) for n in reversed(data.shape))}}}")
+            self.gguf_writer.add_tensor(new_name, data)
+            consumed.append(name)
+
+        return consumed
+
+    def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
+        """Write the hash routing tables and MXFP4 expert tensors straight to the writer.
+
+        Nothing is yielded; the consumed source tensors are removed from
+        ``model_tensors``. Later calls are no-ops.
+        """
+        if self._dsv4_mxfp4_generated:
+            return ()
+
+        consumed: list[str] = self._write_hash_routing_tensors()
+        for bid in range(self.block_count):
+            for proj, tensor_key in self._DSV4_EXPERT_TENSORS.items():
+                consumed.extend(self._write_mxfp4_expert_tensor(bid, proj, tensor_key))
+
+        for name in consumed:
+            del self.model_tensors[name]
+
+        self._dsv4_mxfp4_generated = True
+        return ()
+
+    def _format_dsv4_tensor_name(self, key: gguf.MODEL_TENSOR, bid: int | None, suffix: str = ".weight") -> str:
+        return self.format_tensor_name(key, bid, suffix)
+
+    def _map_dsv4_tensor_name(self, name: str, bid: int | None) -> tuple[gguf.MODEL_TENSOR, str]:
+        """Map a checkpoint tensor name to its GGUF tensor and suffix.
+
+        Raises:
+            ValueError: if the name is not recognised, or its layer index
+                differs from ``bid``.
+        """
+        if name in self._DSV4_ROOT_TENSORS:
+            return self._DSV4_ROOT_TENSORS[name]
+
+        match = re.match(r"layers\.(\d+)\.(.+)$", name)
+        if match is None:
+            raise ValueError(f"Unsupported DeepSeek-V4 tensor {name!r}")
+
+        layer = int(match.group(1))
+        if bid != layer:
+            raise ValueError(f"Tensor {name!r} parsed bid {bid} but layer name has {layer}")
+
+        tensor_name = match.group(2)
+        if tensor_name in self._DSV4_LAYER_TENSORS:
+            return self._DSV4_LAYER_TENSORS[tensor_name]
+
+        expert = re.match(r"ffn\.experts\.\d+\.(w[123])\.(weight|scale)$", tensor_name)
+        if expert is not None:
+            return self._DSV4_EXPERT_TENSORS[expert.group(1)], ".weight"
+
+        raise ValueError(f"Unsupported DeepSeek-V4 tensor {name!r}")
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        # Routed experts and hash routing tables are written by generate_extra_tensors().
+        if re.match(r"layers\.\d+\.ffn\.experts\.\d+\.w[123]\.(weight|scale)$", name):
+            return []
+
+        tensor_key, suffix = self._map_dsv4_tensor_name(name, bid)
+        if tensor_key == gguf.MODEL_TENSOR.FFN_GATE_TID2EID:
+            return []
+
+        return [(self._format_dsv4_tensor_name(tensor_key, bid, suffix), data_torch)]
+
+    def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool:
+        """Return the forced type for tensors with a recorded source dtype, else ``False``."""
+        del new_name, bid  # unused
+
+        if name in self._dsv4_fp8_dequantized and n_dims >= 2:
+            return gguf.GGMLQuantizationType.Q8_0
+        if name in self._dsv4_f32_tensors:
+            return gguf.GGMLQuantizationType.F32
+        if name in self._dsv4_bf16_tensors and n_dims >= 2:
+            return gguf.GGMLQuantizationType.BF16
+
+        return False
+
+    def prepare_tensors(self):
+        super().prepare_tensors()
+        # Set only after the base pass, so it ran with the original values.
+        self._is_mxfp4 = True
+        self.ftype = gguf.LlamaFileType.MOSTLY_MXFP4_MOE
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index bcd10beb0418..b26fab727dd3 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -145,6 +145,7 @@ class LLM:
         TOKEN_SHIFT_COUNT                 = "{arch}.token_shift_count"
         INTERLEAVE_MOE_LAYER_STEP         = "{arch}.interleave_moe_layer_step"
         FULL_ATTENTION_INTERVAL           = "{arch}.full_attention_interval"
+        HASH_LAYER_COUNT                  = "{arch}.hash_layer_count"
ACTIVATION_SPARSITY_SCALE = "{arch}.activation_sparsity_scale" ALTUP_ACTIVE_IDX = "{arch}.altup.active_idx" ALTUP_NUM_INPUTS = "{arch}.altup.num_inputs" @@ -179,8 +180,12 @@ class Attention: REL_BUCKETS_COUNT = "{arch}.attention.relative_buckets_count" SLIDING_WINDOW = "{arch}.attention.sliding_window" SCALE = "{arch}.attention.scale" + OUTPUT_GROUP_COUNT = "{arch}.attention.output_group_count" + OUTPUT_LORA_RANK = "{arch}.attention.output_lora_rank" OUTPUT_SCALE = "{arch}.attention.output_scale" VALUE_SCALE = "{arch}.attention.value_scale" + COMPRESS_RATIOS = "{arch}.attention.compress_ratios" + COMPRESS_ROPE_FREQ_BASE = "{arch}.attention.compress_rope_freq_base" TEMPERATURE_LENGTH = "{arch}.attention.temperature_length" KEY_LENGTH_MLA = "{arch}.attention.key_length_mla" VALUE_LENGTH_MLA = "{arch}.attention.value_length_mla" @@ -195,6 +200,11 @@ class Indexer: KEY_LENGTH = "{arch}.attention.indexer.key_length" TOP_K = "{arch}.attention.indexer.top_k" + class HyperConnection: + COUNT = "{arch}.hyper_connection.count" + SINKHORN_ITERATIONS = "{arch}.hyper_connection.sinkhorn_iterations" + EPSILON = "{arch}.hyper_connection.epsilon" + class Rope: DIMENSION_COUNT = "{arch}.rope.dimension_count" DIMENSION_COUNT_SWA = "{arch}.rope.dimension_count_swa" @@ -469,6 +479,7 @@ class MODEL_ARCH(IntEnum): DEEPSEEK2 = auto() DEEPSEEK2OCR = auto() DEEPSEEK32 = auto() + DEEPSEEK4 = auto() CHATGLM = auto() GLM4 = auto() GLM4_MOE = auto() @@ -554,6 +565,9 @@ class MODEL_TENSOR(IntEnum): DENSE_2_OUT = auto() # embeddinggemma 2_Dense DENSE_3_OUT = auto() # embeddinggemma 3_Dense OUTPUT_NORM = auto() + HC_HEAD_FN = auto() + HC_HEAD_BASE = auto() + HC_HEAD_SCALE = auto() ROPE_FREQS = auto() ROPE_FACTORS_LONG = auto() ROPE_FACTORS_SHORT = auto() @@ -593,6 +607,7 @@ class MODEL_TENSOR(IntEnum): FFN_DOWN_CHEXP = auto() FFN_UP_CHEXP = auto() FFN_EXP_PROBS_B = auto() + FFN_GATE_TID2EID = auto() MOE_LATENT_DOWN = auto() # nemotron 3 super MOE_LATENT_UP = auto() # nemotron 3 super ATTN_Q_NORM 
= auto() @@ -680,6 +695,20 @@ class MODEL_TENSOR(IntEnum): ATTN_V_B = auto() ATTN_Q_A_NORM = auto() ATTN_KV_A_NORM = auto() + ATTN_KV = auto() + ATTN_KV_NORM = auto() + ATTN_OUT_A = auto() + ATTN_OUT_B = auto() + HC_ATTN_FN = auto() + HC_ATTN_BASE = auto() + HC_ATTN_SCALE = auto() + HC_FFN_FN = auto() + HC_FFN_BASE = auto() + HC_FFN_SCALE = auto() + ATTN_COMPRESSOR_WKV = auto() + ATTN_COMPRESSOR_WGATE = auto() + ATTN_COMPRESSOR_APE = auto() + ATTN_COMPRESSOR_NORM = auto() FFN_SUB_NORM = auto() ATTN_SUB_NORM = auto() DEC_ATTN_NORM = auto() @@ -741,6 +770,10 @@ class MODEL_TENSOR(IntEnum): INDEXER_PROJ = auto() INDEXER_ATTN_K = auto() INDEXER_ATTN_Q_B = auto() + INDEXER_COMPRESSOR_WKV = auto() + INDEXER_COMPRESSOR_WGATE = auto() + INDEXER_COMPRESSOR_APE = auto() + INDEXER_COMPRESSOR_NORM = auto() # vision V_MMPROJ = auto() V_MMPROJ_FC = auto() @@ -1026,6 +1059,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.DEEPSEEK2: "deepseek2", MODEL_ARCH.DEEPSEEK2OCR: "deepseek2-ocr", MODEL_ARCH.DEEPSEEK32: "deepseek32", + MODEL_ARCH.DEEPSEEK4: "deepseek4", MODEL_ARCH.CHATGLM: "chatglm", MODEL_ARCH.GLM4: "glm4", MODEL_ARCH.GLM4_MOE: "glm4moe", @@ -1110,6 +1144,9 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT: "output", MODEL_TENSOR.DENSE_2_OUT: "dense_2", # embeddinggemma 2_Dense MODEL_TENSOR.DENSE_3_OUT: "dense_3", # embeddinggemma 2_Dense + MODEL_TENSOR.HC_HEAD_FN: "output_hc_fn", + MODEL_TENSOR.HC_HEAD_BASE: "output_hc_base", + MODEL_TENSOR.HC_HEAD_SCALE: "output_hc_scale", MODEL_TENSOR.ROPE_FREQS: "rope_freqs", MODEL_TENSOR.ROPE_FACTORS_LONG: "rope_factors_long", MODEL_TENSOR.ROPE_FACTORS_SHORT: "rope_factors_short", @@ -1151,6 +1188,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_UP_EXP: "blk.{bid}.ffn_up_exps", MODEL_TENSOR.FFN_GATE_UP_EXP: "blk.{bid}.ffn_gate_up_exps", MODEL_TENSOR.FFN_EXP_PROBS_B: "blk.{bid}.exp_probs_b", + MODEL_TENSOR.FFN_GATE_TID2EID: "blk.{bid}.ffn_gate_tid2eid", MODEL_TENSOR.MOE_LATENT_DOWN: "blk.{bid}.ffn_latent_down", # nemotron 3 super 
MODEL_TENSOR.MOE_LATENT_UP: "blk.{bid}.ffn_latent_up", # nemotron 3 super MODEL_TENSOR.LAYER_OUT_NORM: "blk.{bid}.layer_output_norm", @@ -1236,6 +1274,20 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.ATTN_V_B: "blk.{bid}.attn_v_b", MODEL_TENSOR.ATTN_Q_A_NORM: "blk.{bid}.attn_q_a_norm", MODEL_TENSOR.ATTN_KV_A_NORM: "blk.{bid}.attn_kv_a_norm", + MODEL_TENSOR.ATTN_KV: "blk.{bid}.attn_kv", + MODEL_TENSOR.ATTN_KV_NORM: "blk.{bid}.attn_kv_a_norm", + MODEL_TENSOR.ATTN_OUT_A: "blk.{bid}.attn_output_a", + MODEL_TENSOR.ATTN_OUT_B: "blk.{bid}.attn_output_b", + MODEL_TENSOR.HC_ATTN_FN: "blk.{bid}.hc_attn_fn", + MODEL_TENSOR.HC_ATTN_BASE: "blk.{bid}.hc_attn_base", + MODEL_TENSOR.HC_ATTN_SCALE: "blk.{bid}.hc_attn_scale", + MODEL_TENSOR.HC_FFN_FN: "blk.{bid}.hc_ffn_fn", + MODEL_TENSOR.HC_FFN_BASE: "blk.{bid}.hc_ffn_base", + MODEL_TENSOR.HC_FFN_SCALE: "blk.{bid}.hc_ffn_scale", + MODEL_TENSOR.ATTN_COMPRESSOR_WKV: "blk.{bid}.attn_compressor_kv", + MODEL_TENSOR.ATTN_COMPRESSOR_WGATE: "blk.{bid}.attn_compressor_gate", + MODEL_TENSOR.ATTN_COMPRESSOR_APE: "blk.{bid}.attn_compressor_ape", + MODEL_TENSOR.ATTN_COMPRESSOR_NORM: "blk.{bid}.attn_compressor_norm", MODEL_TENSOR.ATTN_SUB_NORM: "blk.{bid}.attn_sub_norm", MODEL_TENSOR.FFN_SUB_NORM: "blk.{bid}.ffn_sub_norm", MODEL_TENSOR.DEC_ATTN_NORM: "dec.blk.{bid}.attn_norm", @@ -1297,6 +1349,10 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.INDEXER_PROJ: "blk.{bid}.indexer.proj", MODEL_TENSOR.INDEXER_ATTN_K: "blk.{bid}.indexer.attn_k", MODEL_TENSOR.INDEXER_ATTN_Q_B: "blk.{bid}.indexer.attn_q_b", + MODEL_TENSOR.INDEXER_COMPRESSOR_WKV: "blk.{bid}.indexer_compressor_kv", + MODEL_TENSOR.INDEXER_COMPRESSOR_WGATE: "blk.{bid}.indexer_compressor_gate", + MODEL_TENSOR.INDEXER_COMPRESSOR_APE: "blk.{bid}.indexer_compressor_ape", + MODEL_TENSOR.INDEXER_COMPRESSOR_NORM: "blk.{bid}.indexer_compressor_norm", # vision MODEL_TENSOR.V_MMPROJ: "mm.{bid}", MODEL_TENSOR.V_MMPROJ_FC: "mm.model.fc", @@ -3137,6 +3193,49 @@ class MODEL_TENSOR(IntEnum): 
MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD, MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM, ], + MODEL_ARCH.DEEPSEEK4: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.HC_HEAD_FN, + MODEL_TENSOR.HC_HEAD_BASE, + MODEL_TENSOR.HC_HEAD_SCALE, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_SINKS, + MODEL_TENSOR.ATTN_Q_A, + MODEL_TENSOR.ATTN_Q_B, + MODEL_TENSOR.ATTN_Q_A_NORM, + MODEL_TENSOR.ATTN_KV, + MODEL_TENSOR.ATTN_KV_NORM, + MODEL_TENSOR.ATTN_OUT_A, + MODEL_TENSOR.ATTN_OUT_B, + MODEL_TENSOR.HC_ATTN_FN, + MODEL_TENSOR.HC_ATTN_BASE, + MODEL_TENSOR.HC_ATTN_SCALE, + MODEL_TENSOR.HC_FFN_FN, + MODEL_TENSOR.HC_FFN_BASE, + MODEL_TENSOR.HC_FFN_SCALE, + MODEL_TENSOR.ATTN_COMPRESSOR_WKV, + MODEL_TENSOR.ATTN_COMPRESSOR_WGATE, + MODEL_TENSOR.ATTN_COMPRESSOR_APE, + MODEL_TENSOR.ATTN_COMPRESSOR_NORM, + MODEL_TENSOR.INDEXER_PROJ, + MODEL_TENSOR.INDEXER_ATTN_Q_B, + MODEL_TENSOR.INDEXER_COMPRESSOR_WKV, + MODEL_TENSOR.INDEXER_COMPRESSOR_WGATE, + MODEL_TENSOR.INDEXER_COMPRESSOR_APE, + MODEL_TENSOR.INDEXER_COMPRESSOR_NORM, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_GATE_TID2EID, + MODEL_TENSOR.FFN_EXP_PROBS_B, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + MODEL_TENSOR.FFN_GATE_SHEXP, + MODEL_TENSOR.FFN_DOWN_SHEXP, + MODEL_TENSOR.FFN_UP_SHEXP, + ], MODEL_ARCH.ERNIE4_5_MOE: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, @@ -4436,8 +4535,9 @@ class GGMLQuantizationType(IntEnum): class ExpertGatingFuncType(IntEnum): - SOFTMAX = 1 - SIGMOID = 2 + SOFTMAX = 1 + SIGMOID = 2 + SQRTSOFTPLUS = 4 # TODO: add GGMLFileType from ggml_ftype in ggml.h diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index a06ec88b32ca..a95b4c117a56 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -715,6 +715,9 @@ def add_leading_dense_block_count(self, length: int) -> None: def add_full_attention_interval(self, interval: int) -> None: 
self.add_uint32(Keys.LLM.FULL_ATTENTION_INTERVAL.format(arch=self.arch), interval) + def add_hash_layer_count(self, count: int) -> None: + self.add_uint32(Keys.LLM.HASH_LAYER_COUNT.format(arch=self.arch), count) + def add_feed_forward_length(self, length: int | Sequence[int]) -> None: if isinstance(length, int): self.add_uint32(Keys.LLM.FEED_FORWARD_LENGTH.format(arch=self.arch), length) @@ -940,6 +943,27 @@ def add_relative_attn_buckets_count(self, value: int) -> None: def add_sliding_window(self, value: int) -> None: self.add_uint32(Keys.Attention.SLIDING_WINDOW.format(arch=self.arch), value) + def add_attention_output_group_count(self, count: int) -> None: + self.add_uint32(Keys.Attention.OUTPUT_GROUP_COUNT.format(arch=self.arch), count) + + def add_attention_output_lora_rank(self, length: int) -> None: + self.add_uint32(Keys.Attention.OUTPUT_LORA_RANK.format(arch=self.arch), length) + + def add_attention_compress_ratios(self, values: Sequence[int]) -> None: + self.add_array(Keys.Attention.COMPRESS_RATIOS.format(arch=self.arch), values) + + def add_attention_compress_rope_freq_base(self, value: float) -> None: + self.add_float32(Keys.Attention.COMPRESS_ROPE_FREQ_BASE.format(arch=self.arch), value) + + def add_hyper_connection_count(self, count: int) -> None: + self.add_uint32(Keys.HyperConnection.COUNT.format(arch=self.arch), count) + + def add_hyper_connection_sinkhorn_iterations(self, count: int) -> None: + self.add_uint32(Keys.HyperConnection.SINKHORN_ITERATIONS.format(arch=self.arch), count) + + def add_hyper_connection_epsilon(self, value: float) -> None: + self.add_float32(Keys.HyperConnection.EPSILON.format(arch=self.arch), value) + def add_attention_scale(self, value: float) -> None: self.add_float32(Keys.Attention.SCALE.format(arch=self.arch), value) diff --git a/models/templates/deepseek-ai-DeepSeek-V4.jinja b/models/templates/deepseek-ai-DeepSeek-V4.jinja new file mode 100644 index 000000000000..f19f787b1b7e --- /dev/null +++ 
b/models/templates/deepseek-ai-DeepSeek-V4.jinja @@ -0,0 +1,112 @@ +{%- if not add_generation_prompt is defined -%} + {%- set add_generation_prompt = false -%} +{%- endif -%} +{%- if not thinking is defined -%} + {%- if enable_thinking is defined -%} + {%- set thinking = enable_thinking -%} + {%- else -%} + {%- set thinking = false -%} + {%- endif -%} +{%- endif -%} +{%- set dsml_token = '|DSML|' -%} +{%- set thinking_start_token = '' -%} +{%- set thinking_end_token = '' -%} +{%- set tools_header = '## Tools\n\nYou have access to a set of tools to help answer the user\'s question. You can invoke tools by writing a "<' + dsml_token + 'tool_calls>" block like the following:\n\n<' + dsml_token + 'tool_calls>\n<' + dsml_token + 'invoke name="$TOOL_NAME">\n<' + dsml_token + 'parameter name="$PARAMETER_NAME" string="true|false">$PARAMETER_VALUE\n...\n\n<' + dsml_token + 'invoke name="$TOOL_NAME2">\n...\n\n\n\nString parameters should be specified as is and set `string="true"`. For all other types (numbers, booleans, arrays, objects), pass the value in JSON format and set `string="false"`.\n\nIf thinking_mode is enabled (triggered by ' + thinking_start_token + '), you MUST output your complete reasoning inside ' + thinking_start_token + '...' 
+ thinking_end_token + ' BEFORE any tool calls or final response.\n\nOtherwise, output directly after ' + thinking_end_token + ' with tool calls or final response.\n\n### Available Tool Schemas\n\n' -%} +{%- set tools_footer = '\nYou MUST strictly follow the above defined tool name and parameter schemas to invoke tool calls.\n' -%} +{%- set ns = namespace(system_prompt='', is_first_sp=true) -%} +{%- for message in messages -%} + {%- if message['role'] == 'system' -%} + {%- if ns.is_first_sp -%} + {%- set ns.system_prompt = ns.system_prompt + (message['content'] or '') -%} + {%- set ns.is_first_sp = false -%} + {%- else -%} + {%- set ns.system_prompt = ns.system_prompt + '\n\n' + (message['content'] or '') -%} + {%- endif -%} + {%- endif -%} +{%- endfor -%} +{%- if tools is defined and tools -%} + {%- set ts = namespace(schemas='') -%} + {%- for tool in tools -%} + {%- if tool['type'] == 'function' -%} + {%- set ts.schemas = ts.schemas + (tool['function'] | tojson) + '\n' -%} + {%- endif -%} + {%- endfor -%} + {%- if ns.system_prompt -%} + {%- set ns.system_prompt = ns.system_prompt + '\n\n' + tools_header + ts.schemas + tools_footer -%} + {%- else -%} + {%- set ns.system_prompt = tools_header + ts.schemas + tools_footer -%} + {%- endif -%} +{%- endif -%} +{{- bos_token -}} +{{- ns.system_prompt -}} +{%- set last_user_idx = namespace(value=-1) -%} +{%- for message in messages -%} + {%- if message['role'] == 'user' or message['role'] == 'developer' or message['role'] == 'tool' -%} + {%- set last_user_idx.value = loop.index0 -%} + {%- endif -%} +{%- endfor -%} +{%- set state = namespace(in_user=false) -%} +{%- for message in messages -%} + {%- if message['role'] == 'user' or message['role'] == 'developer' -%} + {%- if state.in_user -%} + {{- '\n\n' -}} + {%- else -%} + {{- '<|User|>' -}} + {%- set state.in_user = true -%} + {%- endif -%} + {{- message['content'] or '' -}} + {%- elif message['role'] == 'tool' -%} + {%- if state.in_user -%} + {{- '\n\n' -}} + {%- else 
-%} + {{- '<|User|>' -}} + {%- set state.in_user = true -%} + {%- endif -%} + {{- '' + (message['content'] or '') + '' -}} + {%- elif message['role'] == 'assistant' -%} + {%- set state.in_user = false -%} + {{- '<|Assistant|>' -}} + {%- set is_after_last_user = loop.index0 > last_user_idx.value -%} + {%- if is_after_last_user and thinking -%} + {{- thinking_start_token -}} + {%- if message['reasoning_content'] is defined and message['reasoning_content'] -%} + {{- message['reasoning_content'] -}} + {%- endif -%} + {{- thinking_end_token -}} + {%- else -%} + {{- thinking_end_token -}} + {%- endif -%} + {%- if message['content'] is defined and message['content'] -%} + {{- message['content'] -}} + {%- endif -%} + {%- if message['tool_calls'] -%} + {{- '\n\n<' + dsml_token + 'tool_calls>\n' -}} + {%- for tool in message['tool_calls'] -%} + {%- set func = tool['function'] -%} + {{- '<' + dsml_token + 'invoke name="' + func['name'] + '">\n' -}} + {%- set args = func['arguments'] -%} + {%- if args is string -%} + {%- set args = args | from_json -%} + {%- endif -%} + {%- for key, val in args.items() -%} + {%- if val is string -%} + {{- '<' + dsml_token + 'parameter name="' + key + '" string="true">' + val + '\n' -}} + {%- else -%} + {{- '<' + dsml_token + 'parameter name="' + key + '" string="false">' + (val | tojson) + '\n' -}} + {%- endif -%} + {%- endfor -%} + {{- '\n' -}} + {%- endfor -%} + {{- '' -}} + {%- endif -%} + {{- '<|end▁of▁sentence|>' -}} + {%- endif -%} +{%- endfor -%} +{%- if add_generation_prompt -%} + {{- '<|Assistant|>' -}} + {%- if thinking -%} + {{- thinking_start_token -}} + {%- else -%} + {{- thinking_end_token -}} + {%- endif -%} +{%- endif -%} \ No newline at end of file diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index d15ccfd99f14..320784c3a8cc 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -25,6 +25,7 @@ add_library(llama llama-kv-cache.cpp llama-kv-cache-iswa.cpp llama-kv-cache-dsa.cpp + llama-kv-cache-dsv4.cpp 
llama-memory.cpp llama-memory-hybrid.cpp llama-memory-hybrid-iswa.cpp diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index d80915ffdba5..98f391a9115f 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -77,6 +77,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_DEEPSEEK2, "deepseek2" }, { LLM_ARCH_DEEPSEEK2OCR, "deepseek2-ocr" }, { LLM_ARCH_DEEPSEEK32, "deepseek32" }, + { LLM_ARCH_DEEPSEEK4, "deepseek4" }, { LLM_ARCH_CHATGLM, "chatglm" }, { LLM_ARCH_GLM4, "glm4" }, { LLM_ARCH_GLM4_MOE, "glm4moe" }, @@ -440,6 +441,23 @@ static const std::map LLM_TENSOR_NAMES = { { LLM_TENSOR_ATTN_Q_B, "blk.%d.attn_q_b" }, { LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" }, { LLM_TENSOR_ATTN_KV_B, "blk.%d.attn_kv_b" }, + { LLM_TENSOR_ATTN_KV, "blk.%d.attn_kv" }, + { LLM_TENSOR_ATTN_KV_NORM, "blk.%d.attn_kv_a_norm" }, + { LLM_TENSOR_ATTN_OUT_A, "blk.%d.attn_output_a" }, + { LLM_TENSOR_ATTN_OUT_B, "blk.%d.attn_output_b" }, + { LLM_TENSOR_HC_HEAD_FN, "output_hc_fn" }, + { LLM_TENSOR_HC_HEAD_BASE, "output_hc_base" }, + { LLM_TENSOR_HC_HEAD_SCALE, "output_hc_scale" }, + { LLM_TENSOR_HC_ATTN_FN, "blk.%d.hc_attn_fn" }, + { LLM_TENSOR_HC_ATTN_BASE, "blk.%d.hc_attn_base" }, + { LLM_TENSOR_HC_ATTN_SCALE, "blk.%d.hc_attn_scale" }, + { LLM_TENSOR_HC_FFN_FN, "blk.%d.hc_ffn_fn" }, + { LLM_TENSOR_HC_FFN_BASE, "blk.%d.hc_ffn_base" }, + { LLM_TENSOR_HC_FFN_SCALE, "blk.%d.hc_ffn_scale" }, + { LLM_TENSOR_ATTN_COMPRESSOR_WKV, "blk.%d.attn_compressor_kv" }, + { LLM_TENSOR_ATTN_COMPRESSOR_WGATE, "blk.%d.attn_compressor_gate" }, + { LLM_TENSOR_ATTN_COMPRESSOR_APE, "blk.%d.attn_compressor_ape" }, + { LLM_TENSOR_ATTN_COMPRESSOR_NORM, "blk.%d.attn_compressor_norm" }, { LLM_TENSOR_PER_LAYER_TOKEN_EMBD, "per_layer_token_embd" }, { LLM_TENSOR_PER_LAYER_MODEL_PROJ, "per_layer_model_proj" }, { LLM_TENSOR_PER_LAYER_PROJ_NORM, "per_layer_proj_norm" }, @@ -566,6 +584,11 @@ static const std::map LLM_TENSOR_NAMES = { { LLM_TENSOR_INDEXER_PROJ, "blk.%d.indexer.proj" }, { LLM_TENSOR_INDEXER_ATTN_K, 
"blk.%d.indexer.attn_k" }, { LLM_TENSOR_INDEXER_ATTN_Q_B, "blk.%d.indexer.attn_q_b" }, + { LLM_TENSOR_INDEXER_COMPRESSOR_WKV, "blk.%d.indexer_compressor_kv" }, + { LLM_TENSOR_INDEXER_COMPRESSOR_WGATE, "blk.%d.indexer_compressor_gate" }, + { LLM_TENSOR_INDEXER_COMPRESSOR_APE, "blk.%d.indexer_compressor_ape" }, + { LLM_TENSOR_INDEXER_COMPRESSOR_NORM, "blk.%d.indexer_compressor_norm" }, + { LLM_TENSOR_FFN_GATE_TID2EID, "blk.%d.ffn_gate_tid2eid" }, { LLM_TENSOR_MASKED_EMBD_CENTROIDS, "masked_embd_centroids" }, { LLM_TENSOR_MASKED_EMBD_ORDERING, "masked_embd_ordering" }, { LLM_TENSOR_FC, "fc" }, @@ -616,6 +639,23 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_ATTN_KV_A_MQA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_ATTN_KV_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ATTN_KV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ATTN_KV_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_ATTN_OUT_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ATTN_OUT_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_HC_HEAD_FN, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_HC_HEAD_BASE, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_ADD}}, + {LLM_TENSOR_HC_HEAD_SCALE, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, + {LLM_TENSOR_HC_ATTN_FN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_HC_ATTN_BASE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, + {LLM_TENSOR_HC_ATTN_SCALE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_HC_FFN_FN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_HC_FFN_BASE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, + {LLM_TENSOR_HC_FFN_SCALE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_ATTN_COMPRESSOR_WKV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ATTN_COMPRESSOR_WGATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + 
{LLM_TENSOR_ATTN_COMPRESSOR_APE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, + {LLM_TENSOR_ATTN_COMPRESSOR_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_ATTN_K_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_ATTN_V_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_ATTN_SINKS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SCALE}}, @@ -779,6 +819,11 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_INDEXER_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_INDEXER_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_INDEXER_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_INDEXER_COMPRESSOR_WKV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_INDEXER_COMPRESSOR_WGATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_INDEXER_COMPRESSOR_APE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, + {LLM_TENSOR_INDEXER_COMPRESSOR_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_FFN_GATE_TID2EID, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_GET_ROWS}}, {LLM_TENSOR_NEXTN_PROJ_PRE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_NEXTN_PROJ_POST, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // NextN/MTP tensors are stored per-block (blk.%d.nextn.*) even though only the @@ -933,6 +978,7 @@ bool llm_arch_supports_sm_tensor(const llm_arch & arch) { case LLM_ARCH_OLMOE: case LLM_ARCH_DEEPSEEK2: case LLM_ARCH_DEEPSEEK32: + case LLM_ARCH_DEEPSEEK4: case LLM_ARCH_GLM_DSA: case LLM_ARCH_BITNET: case LLM_ARCH_T5: diff --git a/src/llama-arch.h b/src/llama-arch.h index 946518d5f224..7087785d522d 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -82,6 +82,7 @@ enum llm_arch { LLM_ARCH_DEEPSEEK2, LLM_ARCH_DEEPSEEK2OCR, LLM_ARCH_DEEPSEEK32, + LLM_ARCH_DEEPSEEK4, LLM_ARCH_CHATGLM, LLM_ARCH_GLM4, LLM_ARCH_GLM4_MOE, @@ -501,10 +502,27 @@ enum llm_tensor { LLM_TENSOR_ATTN_Q_B, LLM_TENSOR_ATTN_KV_A_MQA, LLM_TENSOR_ATTN_KV_B, + LLM_TENSOR_ATTN_KV, + 
LLM_TENSOR_ATTN_KV_NORM, + LLM_TENSOR_ATTN_OUT_A, + LLM_TENSOR_ATTN_OUT_B, LLM_TENSOR_ATTN_K_B, LLM_TENSOR_ATTN_V_B, LLM_TENSOR_ATTN_Q_A_NORM, LLM_TENSOR_ATTN_KV_A_NORM, + LLM_TENSOR_HC_HEAD_FN, + LLM_TENSOR_HC_HEAD_BASE, + LLM_TENSOR_HC_HEAD_SCALE, + LLM_TENSOR_HC_ATTN_FN, + LLM_TENSOR_HC_ATTN_BASE, + LLM_TENSOR_HC_ATTN_SCALE, + LLM_TENSOR_HC_FFN_FN, + LLM_TENSOR_HC_FFN_BASE, + LLM_TENSOR_HC_FFN_SCALE, + LLM_TENSOR_ATTN_COMPRESSOR_WKV, + LLM_TENSOR_ATTN_COMPRESSOR_WGATE, + LLM_TENSOR_ATTN_COMPRESSOR_APE, + LLM_TENSOR_ATTN_COMPRESSOR_NORM, LLM_TENSOR_ATTN_SUB_NORM, LLM_TENSOR_FFN_SUB_NORM, LLM_TENSOR_DEC_ATTN_NORM, @@ -566,6 +584,11 @@ enum llm_tensor { LLM_TENSOR_INDEXER_PROJ, LLM_TENSOR_INDEXER_ATTN_K, LLM_TENSOR_INDEXER_ATTN_Q_B, + LLM_TENSOR_INDEXER_COMPRESSOR_WKV, + LLM_TENSOR_INDEXER_COMPRESSOR_WGATE, + LLM_TENSOR_INDEXER_COMPRESSOR_APE, + LLM_TENSOR_INDEXER_COMPRESSOR_NORM, + LLM_TENSOR_FFN_GATE_TID2EID, LLM_TENSOR_NEXTN_PROJ_PRE, LLM_TENSOR_NEXTN_PROJ_POST, LLM_TENSOR_NEXTN_EH_PROJ, diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 029141e2aaf2..0465430df43a 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -2321,7 +2321,11 @@ void llama_context::output_reorder() { // uint32_t llama_context::graph_max_nodes(uint32_t n_tokens) const { - if (model.arch == LLM_ARCH_QWEN3NEXT || model.arch == LLM_ARCH_KIMI_LINEAR || model.arch == LLM_ARCH_QWEN35 || model.arch == LLM_ARCH_QWEN35MOE) { + if (model.arch == LLM_ARCH_QWEN3NEXT || + model.arch == LLM_ARCH_KIMI_LINEAR || + model.arch == LLM_ARCH_QWEN35 || + model.arch == LLM_ARCH_QWEN35MOE || + model.arch == LLM_ARCH_DEEPSEEK4) { return std::max(n_tokens * 40, 32u * model.n_tensors()); } uint32_t res = std::max(1024u, 8u*model.n_tensors()); diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 3ded70bc0f71..4c86e43c1f74 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -8,6 +8,7 @@ #include "llama-kv-cache.h" #include "llama-kv-cache-iswa.h" #include 
"llama-kv-cache-dsa.h" +#include "llama-kv-cache-dsv4.h" #include "llama-memory-hybrid.h" #include "llama-memory-hybrid-iswa.h" #include "llama-memory-recurrent.h" @@ -17,6 +18,7 @@ #include #include #include +#include #include // dedup helpers @@ -568,7 +570,9 @@ void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) { // base tensors may not be allocated if there are no non-SWA attention layers if (self_k_idxs && self_k_idxs->buffer) { mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch); - mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch); + if (self_v_idxs) { + mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch); + } } // the kq mask guards on its own buffer: shared cells leave idxs unbacked while the mask stays live @@ -579,7 +583,9 @@ void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) { // swa tensors may not be allocated if there are no SWA attention layers if (self_k_idxs_swa && self_k_idxs_swa->buffer) { mctx->get_swa()->set_input_k_idxs(self_k_idxs_swa, ubatch); - mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch); + if (self_v_idxs_swa) { + mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch); + } } if (self_kq_mask_swa && self_kq_mask_swa->buffer) { @@ -633,6 +639,283 @@ bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) { return res; } +static void dsv4_set_i64(ggml_tensor * dst, const std::vector & src) { + if (!dst || !dst->buffer) { + return; + } + + GGML_ASSERT(dst->ne[0] == (int64_t) src.size()); + ggml_backend_tensor_set(dst, src.data(), 0, src.size()*ggml_element_size(dst)); +} + +static void dsv4_set_i32(ggml_tensor * dst, const std::vector & src) { + if (!dst || !dst->buffer) { + return; + } + + GGML_ASSERT(dst->ne[0] == (int64_t) src.size()); + ggml_backend_tensor_set(dst, src.data(), 0, src.size()*ggml_element_size(dst)); +} + +static void dsv4_set_kq_mask( + ggml_tensor * dst, + const llama_kv_cache_dsv4_context::comp_plan & plan, + uint32_t 
n_tokens, + int64_t n_stream) { + if (!dst || !dst->buffer) { + return; + } + + GGML_ASSERT(dst->type == GGML_TYPE_F32); + GGML_ASSERT(n_stream > 0); + GGML_ASSERT(n_tokens%n_stream == 0); + GGML_ASSERT(dst->ne[0] == plan.n_kv); + GGML_ASSERT(dst->ne[1] == (int64_t) n_tokens/n_stream); + GGML_ASSERT(dst->ne[2] == 1); + GGML_ASSERT(dst->ne[3] == n_stream); + GGML_ASSERT((int64_t) plan.n_visible.size() == (int64_t) n_tokens); + GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); + + float * data = (float *) dst->data; + + for (int64_t i = 0; i < (int64_t) n_tokens; ++i) { + const int32_t n_visible = plan.n_visible[i]; + + for (int64_t j = 0; j < dst->ne[0]; ++j) { + data[i*dst->ne[0] + j] = j < n_visible ? 0.0f : -INFINITY; + } + } +} + +static ggml_tensor * dsv4_build_raw_kq_mask( + ggml_context * ctx, + const llama_kv_cache_dsv4_raw_context * mctx, + const llama_ubatch & ubatch, + const llama_cparams & cparams, + int64_t n_stream) { + const auto n_kv = mctx->get_n_kv(); + const auto n_tokens = ubatch.n_tokens; + + GGML_ASSERT(n_stream > 0); + GGML_ASSERT(n_tokens%n_stream == 0); + + const bool use_fattn = cparams.flash_attn && (!cparams.kv_unified || n_stream == 1); + const auto type = use_fattn ? 
GGML_TYPE_F16 : GGML_TYPE_F32; + + ggml_tensor * res = ggml_new_tensor_4d(ctx, type, n_kv, n_tokens/n_stream, 1, n_stream); + ggml_set_input(res); + ggml_set_name(res, "attn_inp_kq_mask"); + + return res; +} + +static bool dsv4_can_reuse_raw_kq_mask( + ggml_tensor * kq_mask, + const llama_kv_cache_dsv4_raw_context * mctx, + const llama_ubatch & ubatch, + int64_t n_stream) { + const auto n_kv = mctx->get_n_kv(); + const auto n_tokens = ubatch.n_tokens; + + GGML_ASSERT(n_stream > 0); + + bool res = true; + + res &= (kq_mask->ne[0] == n_kv); + res &= (kq_mask->ne[1] == n_tokens/n_stream); + res &= (kq_mask->ne[2] == 1); + res &= (kq_mask->ne[3] == n_stream); + + return res; +} + +static std::string dsv4_plan_positions(const std::vector & values) { + std::ostringstream ss; + ss << "["; + for (size_t i = 0; i < values.size(); ++i) { + if (i > 0) { + ss << ", "; + } + ss << values[i]; + } + ss << "]"; + return ss.str(); +} + +static bool dsv4_compress_debug() { + static const bool debug = []() { + const char * env = getenv("LLAMA_DSV4_COMPRESS_DEBUG"); + return env && atoi(env) > 0; + }(); + + return debug; +} + +static void dsv4_set_comp_inputs( + const llm_graph_input_dsv4::comp_input & inp, + const llama_kv_cache_dsv4_context::comp_plan & plan, + const char * name, + bool debug, + uint32_t n_tokens, + int64_t n_stream) { + dsv4_set_i32(inp.state_pos, plan.state_pos); + dsv4_set_i32(inp.state_persist_src_idxs, plan.state_persist_src_idxs); + dsv4_set_i32(inp.state_persist_dst_idxs, plan.state_persist_dst_idxs); + dsv4_set_i32(inp.state_read_idxs, plan.state_read_idxs); + dsv4_set_i64(inp.state_write_idxs, plan.state_write_idxs); + dsv4_set_i32(inp.state_write_pos, plan.state_write_pos); + dsv4_set_kq_mask(inp.kq_mask, plan, n_tokens, n_stream); + + if (debug || dsv4_compress_debug()) { + LLAMA_LOG_INFO("%s: %s n_tokens=%u, n_stream=%d, state_persist_dst=%s, state_write_pos=%s\n", + __func__, name, n_tokens, (int) n_stream, + 
dsv4_plan_positions(plan.state_persist_dst_idxs).c_str(), + dsv4_plan_positions(plan.state_write_pos).c_str()); + } +} + +static bool dsv4_can_reuse_tensor_1d(ggml_tensor * t, int64_t ne0) { + return (t == nullptr && ne0 == 0) || (t != nullptr && t->ne[0] == ne0); +} + +static bool dsv4_can_reuse_kq_mask( + ggml_tensor * t, + const llama_kv_cache_dsv4_context::comp_plan & plan, + uint32_t n_tokens, + int64_t n_stream) { + if (plan.n_kv == 0) { + return t == nullptr; + } + + GGML_ASSERT(n_stream > 0); + + return t != nullptr && + t->ne[0] == plan.n_kv && + t->ne[1] == (int64_t) n_tokens/n_stream && + t->ne[2] == 1 && + t->ne[3] == n_stream; +} + +static bool dsv4_can_reuse_comp_input( + const llm_graph_input_dsv4::comp_input & inp, + const llama_kv_cache_dsv4_context::comp_plan & plan, + uint32_t n_tokens, + int64_t n_stream) { + bool res = true; + res &= dsv4_can_reuse_tensor_1d(inp.state_pos, plan.state_pos.size()); + res &= dsv4_can_reuse_tensor_1d(inp.state_persist_src_idxs, plan.state_persist_src_idxs.size()); + res &= dsv4_can_reuse_tensor_1d(inp.state_persist_dst_idxs, plan.state_persist_dst_idxs.size()); + res &= dsv4_can_reuse_tensor_1d(inp.state_read_idxs, plan.state_read_idxs.size()); + res &= dsv4_can_reuse_tensor_1d(inp.state_write_idxs, plan.state_write_idxs.size()); + res &= dsv4_can_reuse_tensor_1d(inp.state_write_pos, plan.state_write_pos.size()); + res &= dsv4_can_reuse_kq_mask(inp.kq_mask, plan, n_tokens, n_stream); + + return res; +} + +static ggml_tensor * dsv4_build_input_1d( + ggml_context * ctx, + ggml_type type, + int64_t ne0, + const std::string & name) { + if (ne0 == 0) { + return nullptr; + } + + ggml_tensor * res = ggml_new_tensor_1d(ctx, type, ne0); + ggml_set_input(res); + ggml_set_name(res, name.c_str()); + + return res; +} + +static void dsv4_build_comp_inputs( + ggml_context * ctx, + llm_graph_input_dsv4::comp_input & inp, + const llama_kv_cache_dsv4_context::comp_plan & plan, + const char * name, + int64_t n_stream) { + 
inp.state_pos = dsv4_build_input_1d(ctx, GGML_TYPE_I32, plan.state_pos.size(), std::string("dsv4_") + name + "_state_pos"); + inp.state_persist_src_idxs = dsv4_build_input_1d(ctx, GGML_TYPE_I32, plan.state_persist_src_idxs.size(), std::string("dsv4_") + name + "_state_persist_src_idxs"); + inp.state_persist_dst_idxs = dsv4_build_input_1d(ctx, GGML_TYPE_I32, plan.state_persist_dst_idxs.size(), std::string("dsv4_") + name + "_state_persist_dst_idxs"); + inp.state_read_idxs = dsv4_build_input_1d(ctx, GGML_TYPE_I32, plan.state_read_idxs.size(), std::string("dsv4_") + name + "_state_read_idxs"); + inp.state_write_idxs = dsv4_build_input_1d(ctx, GGML_TYPE_I64, plan.state_write_idxs.size(), std::string("dsv4_") + name + "_state_write_idxs"); + inp.state_write_pos = dsv4_build_input_1d(ctx, GGML_TYPE_I32, plan.state_write_pos.size(), std::string("dsv4_") + name + "_state_write_pos"); + + if (plan.n_kv > 0) { + const int64_t n_tokens = (int64_t) plan.n_visible.size(); + + GGML_ASSERT(n_stream > 0); + GGML_ASSERT(n_tokens%n_stream == 0); + + inp.kq_mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, plan.n_kv, n_tokens/n_stream, 1, n_stream); + ggml_set_input(inp.kq_mask); + ggml_set_name(inp.kq_mask, (std::string("dsv4_") + name + "_kq_mask").c_str()); + } +} + +void llm_graph_input_dsv4_raw::set_input(const llama_ubatch * ubatch) { + if (self_k_idxs && self_k_idxs->buffer) { + mctx->set_input_k_idxs(self_k_idxs); + } + + if (self_kq_mask && self_kq_mask->buffer) { + mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); + } + + if (self_k_rot) { + mctx->set_input_k_rot(self_k_rot); + } +} + +void llm_graph_input_dsv4::set_input(const llama_ubatch * ubatch) { + const auto & plan_csa = mctx->get_csa_plan(*ubatch); + const auto & plan_hca = mctx->get_hca_plan(*ubatch); + const auto & plan_lid = mctx->get_lid_plan(*ubatch); + const int64_t n_stream = plan_csa.n_stream; + + inp_raw->mctx = mctx->get_raw(); + inp_raw->set_input(ubatch); + + dsv4_set_comp_inputs(inp_csa, 
plan_csa, "csa", debug > 0, ubatch->n_tokens, n_stream); + dsv4_set_comp_inputs(inp_hca, plan_hca, "hca", debug > 0, ubatch->n_tokens, n_stream); + dsv4_set_comp_inputs(inp_lid, plan_lid, "lid", debug > 0, ubatch->n_tokens, n_stream); + + if (inp_lid.k_rot && inp_lid.k_rot->buffer) { + mctx->get_lid()->set_input_k_rot(inp_lid.k_rot); + } +} + +bool llm_graph_input_dsv4::can_reuse(const llm_graph_params & params) { + const auto * mctx = static_cast(params.mctx); + + this->mctx = mctx; + inp_raw->mctx = mctx->get_raw(); + + bool res = true; + + const auto & plan_csa = mctx->get_csa_plan(params.ubatch); + const auto & plan_hca = mctx->get_hca_plan(params.ubatch); + const auto & plan_lid = mctx->get_lid_plan(params.ubatch); + const int64_t n_stream = plan_csa.n_stream; + + const auto * raw_ctx = mctx->get_raw(); + inp_raw->mctx = raw_ctx; + + if (inp_raw->self_k_idxs && inp_raw->self_k_idxs->buffer) { + res &= inp_raw->self_k_idxs->ne[0] == raw_ctx->get_n_write(); + } + if (inp_raw->self_kq_mask && inp_raw->self_kq_mask->buffer) { + res &= dsv4_can_reuse_raw_kq_mask(inp_raw->self_kq_mask, raw_ctx, params.ubatch, n_stream); + } + + res &= dsv4_can_reuse_comp_input(inp_csa, plan_csa, params.ubatch.n_tokens, n_stream); + res &= dsv4_can_reuse_comp_input(inp_hca, plan_hca, params.ubatch.n_tokens, n_stream); + res &= dsv4_can_reuse_comp_input(inp_lid, plan_lid, params.ubatch.n_tokens, n_stream); + + return res; +} + void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) { GGML_ASSERT(cross_kq_mask); @@ -1351,20 +1634,24 @@ ggml_tensor * llm_graph_context::build_ffn( switch (type_op) { case LLM_FFN_SILU: if (gate && type_gate == LLM_FFN_PAR) { - // Step35: HF clamps gate (after SiLU) and up before multiplication - if (arch == LLM_ARCH_STEP35 && il >= 0) { + if (il >= 0) { const float limit = hparams.swiglu_clamp_shexp[il]; constexpr float eps = 1e-6f; if (limit > eps) { - ggml_tensor * gate_act = ggml_silu(ctx0, cur); - cb(gate_act, "ffn_silu", il); - 
gate_act = ggml_clamp(ctx0, gate_act, -INFINITY, limit); - cb(gate_act, "ffn_silu_clamped", il); - tmp = ggml_clamp(ctx0, tmp, -limit, limit); cb(tmp, "ffn_up_clamped", il); - cur = ggml_mul(ctx0, gate_act, tmp); + if (arch == LLM_ARCH_DEEPSEEK4) { + cur = ggml_clamp(ctx0, cur, -INFINITY, limit); + cb(cur, "ffn_gate_clamped", il); + cur = ggml_swiglu_split(ctx0, cur, tmp); + } else { + ggml_tensor * gate_act = ggml_silu(ctx0, cur); + cb(gate_act, "ffn_silu", il); + gate_act = ggml_clamp(ctx0, gate_act, -INFINITY, limit); + cb(gate_act, "ffn_silu_clamped", il); + cur = ggml_mul(ctx0, gate_act, tmp); + } cb(cur, "ffn_swiglu_limited", il); type_gate = LLM_FFN_SEQ; break; @@ -1474,7 +1761,8 @@ ggml_tensor * llm_graph_context::build_moe_ffn( ggml_tensor * gate_up_exps, ggml_tensor * up_exps_s, ggml_tensor * gate_exps_s, - ggml_tensor * down_exps_s) const { + ggml_tensor * down_exps_s, + ggml_tensor * selected_experts_in) const { return build_moe_ffn( cur, gate_inp, /* gate_inp_b */ nullptr, @@ -1494,7 +1782,8 @@ ggml_tensor * llm_graph_context::build_moe_ffn( /* gate_up_exps_b */ nullptr, up_exps_s, gate_exps_s, - down_exps_s + down_exps_s, + selected_experts_in ); } @@ -1521,7 +1810,8 @@ ggml_tensor * llm_graph_context::build_moe_ffn( ggml_tensor * gate_up_exps_b, ggml_tensor * up_exps_s, ggml_tensor * gate_exps_s, - ggml_tensor * down_exps_s) const { + ggml_tensor * down_exps_s, + ggml_tensor * selected_experts_in) const { const int64_t n_embd = cur->ne[0]; const int64_t n_tokens = cur->ne[1]; const bool weight_before_ffn = arch == LLM_ARCH_LLAMA4; // for llama4, we apply the sigmoid-ed weights before the FFN @@ -1530,6 +1820,9 @@ ggml_tensor * llm_graph_context::build_moe_ffn( if (probs_in == nullptr) { logits = build_lora_mm(gate_inp, cur); // [n_expert, n_tokens] + if (gating_op == LLAMA_EXPERT_GATING_FUNC_TYPE_SQRT_SOFTPLUS) { + ggml_mul_mat_set_prec(logits, GGML_PREC_F32); + } cb(logits, "ffn_moe_logits", il); } else { logits = probs_in; @@ -1554,6 +1847,10 @@ 
ggml_tensor * llm_graph_context::build_moe_ffn( { probs = logits; // [n_expert, n_tokens] } break; + case LLAMA_EXPERT_GATING_FUNC_TYPE_SQRT_SOFTPLUS: + { + probs = ggml_sqrt(ctx0, ggml_softplus(ctx0, logits)); // [n_expert, n_tokens] + } break; default: GGML_ABORT("fatal error"); } @@ -1604,8 +1901,11 @@ ggml_tensor * llm_graph_context::build_moe_ffn( } // select experts - ggml_tensor * selected_experts = ggml_argsort_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens] - cb(selected_experts->src[0], "ffn_moe_argsort", il); + ggml_tensor * selected_experts = selected_experts_in; + if (selected_experts == nullptr) { + selected_experts = ggml_argsort_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens] + cb(selected_experts->src[0], "ffn_moe_argsort", il); + } cb(selected_experts, "ffn_moe_topk", il); if (arch == LLM_ARCH_GROVEMOE && n_expert != hparams.n_expert) { @@ -1718,20 +2018,24 @@ ggml_tensor * llm_graph_context::build_moe_ffn( switch (type_op) { case LLM_FFN_SILU: if (gate_exps) { - // Step35: per-layer clamp for routed experts - if (arch == LLM_ARCH_STEP35 && il >= 0) { + if (il >= 0) { const float limit = hparams.swiglu_clamp_exp[il]; constexpr float eps = 1e-6f; if (limit > eps) { - ggml_tensor * gate_act = ggml_silu(ctx0, cur); - cb(gate_act, "ffn_moe_silu", il); - gate_act = ggml_clamp(ctx0, gate_act, -INFINITY, limit); - cb(gate_act, "ffn_moe_silu_clamped", il); - up = ggml_clamp(ctx0, up, -limit, limit); cb(up, "ffn_moe_up_clamped", il); - cur = ggml_mul(ctx0, gate_act, up); + if (arch == LLM_ARCH_DEEPSEEK4) { + cur = ggml_clamp(ctx0, cur, -INFINITY, limit); + cb(cur, "ffn_moe_gate_clamped", il); + cur = ggml_swiglu_split(ctx0, cur, up); + } else { + ggml_tensor * gate_act = ggml_silu(ctx0, cur); + cb(gate_act, "ffn_moe_silu", il); + gate_act = ggml_clamp(ctx0, gate_act, -INFINITY, limit); + cb(gate_act, "ffn_moe_silu_clamped", il); + cur = ggml_mul(ctx0, gate_act, up); + } cb(cur, "ffn_moe_swiglu_limited", 
il); break; } @@ -2760,6 +3064,31 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const return (llm_graph_input_attn_kv_iswa *) res->add_input(std::move(inp)); } +llm_graph_input_dsv4 * llm_graph_context::build_inp_dsv4() const { + const auto * mctx_cur = static_cast(mctx); + const auto * raw_ctx = mctx_cur->get_raw(); + + auto inp_raw = std::make_unique(cparams, raw_ctx); + + const int64_t n_stream = mctx_cur->get_csa_plan(ubatch).n_stream; + + GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "DSV4 expects SWA raw cache"); + + inp_raw->self_k_idxs = raw_ctx->build_input_k_idxs(ctx0, ubatch); + inp_raw->self_kq_mask = dsv4_build_raw_kq_mask(ctx0, raw_ctx, ubatch, cparams, n_stream); + inp_raw->self_kq_mask_cnv = inp_raw->self_kq_mask; + + inp_raw->self_k_rot = raw_ctx->build_input_k_rot(ctx0); + auto inp = std::make_unique(cparams, std::move(inp_raw), mctx_cur); + + dsv4_build_comp_inputs(ctx0, inp->inp_csa, mctx_cur->get_csa_plan(ubatch), "csa", n_stream); + dsv4_build_comp_inputs(ctx0, inp->inp_hca, mctx_cur->get_hca_plan(ubatch), "hca", n_stream); + dsv4_build_comp_inputs(ctx0, inp->inp_lid, mctx_cur->get_lid_plan(ubatch), "lid", n_stream); + inp->inp_lid.k_rot = mctx_cur->get_lid()->build_input_k_rot(ctx0); + + return (llm_graph_input_dsv4 *) res->add_input(std::move(inp)); +} + ggml_tensor * llm_graph_context::build_rs( ggml_tensor * s, ggml_tensor * state_copy_main, diff --git a/src/llama-graph.h b/src/llama-graph.h index a6e8c3985ba5..4b5b75c632ab 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -23,6 +23,8 @@ struct llama_memory_context_i; class llama_kv_cache_context; class llama_kv_cache_dsa_context; +class llama_kv_cache_dsv4_raw_context; +class llama_kv_cache_dsv4_context; class llama_kv_cache_iswa_context; class llama_memory_recurrent_context; class llama_memory_hybrid_context; @@ -459,6 +461,79 @@ class llm_graph_input_attn_kv_iswa : public llm_graph_input_i { const llama_kv_cache_iswa_context * mctx; }; 
+// DSV4 raw graph inputs are SWA-only, but their mask may be stream-shaped +// so raw K can be concatenated with DSV4 compressed K in one attention op. +class llm_graph_input_dsv4_raw { +public: + llm_graph_input_dsv4_raw( + const llama_cparams & cparams, + const llama_kv_cache_dsv4_raw_context * mctx) : + cparams(cparams), + mctx(mctx) { + } + + void set_input(const llama_ubatch * ubatch); + + ggml_tensor * get_k_idxs() const { return self_k_idxs; } + ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; } + + ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch] + + ggml_tensor * self_kq_mask = nullptr; // F32/F16 [n_kv, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream] + + ggml_tensor * self_k_rot = nullptr; + + const llama_cparams cparams; + + const llama_kv_cache_dsv4_raw_context * mctx; +}; + +class llm_graph_input_dsv4 : public llm_graph_input_i { +public: + struct comp_input { + ggml_tensor * state_pos = nullptr; // I32 [n_state] + ggml_tensor * state_persist_src_idxs = nullptr; // I32 [n_state_persist] + ggml_tensor * state_persist_dst_idxs = nullptr; // I32 [n_state_persist] + ggml_tensor * state_read_idxs = nullptr; // I32 [ratio*n_state_write] + ggml_tensor * state_write_idxs = nullptr; // I64 [n_state_write] + ggml_tensor * state_write_pos = nullptr; // I32 [n_state_write] + + ggml_tensor * kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream] + + ggml_tensor * k_rot = nullptr; + }; + + llm_graph_input_dsv4( + const llama_cparams & cparams, + std::unique_ptr inp_raw, + const llama_kv_cache_dsv4_context * mctx) : + inp_raw(std::move(inp_raw)), + cparams(cparams), + mctx(mctx) { + } + ~llm_graph_input_dsv4() = default; + + void set_input(const llama_ubatch * ubatch) override; + + bool can_reuse(const llm_graph_params & params) override; + + llm_graph_input_dsv4_raw * get_raw() const { return inp_raw.get(); } + const comp_input & get_csa() const { return inp_csa; } + const 
comp_input & get_hca() const { return inp_hca; } + const comp_input & get_lid() const { return inp_lid; } + + std::unique_ptr inp_raw; + + comp_input inp_csa; + comp_input inp_hca; + comp_input inp_lid; + + const llama_cparams cparams; + + const llama_kv_cache_dsv4_context * mctx; +}; + class llm_graph_input_attn_cross : public llm_graph_input_i { public: llm_graph_input_attn_cross(const llama_cross * cross) : cross(cross) {} @@ -920,7 +995,8 @@ struct llm_graph_context { ggml_tensor * gate_up_exps = nullptr, ggml_tensor * up_exps_s = nullptr, ggml_tensor * gate_exps_s = nullptr, - ggml_tensor * down_exps_s = nullptr) const; + ggml_tensor * down_exps_s = nullptr, + ggml_tensor * selected_experts_in = nullptr) const; ggml_tensor * build_moe_ffn( ggml_tensor * cur, @@ -945,7 +1021,8 @@ struct llm_graph_context { ggml_tensor * gate_up_exps_b = nullptr, ggml_tensor * up_exps_s = nullptr, ggml_tensor * gate_exps_s = nullptr, - ggml_tensor * down_exps_s = nullptr) const; + ggml_tensor * down_exps_s = nullptr, + ggml_tensor * selected_experts_in = nullptr) const; // // inputs @@ -1045,6 +1122,8 @@ struct llm_graph_context { llm_graph_input_attn_kv_iswa * build_attn_inp_kv_iswa() const; + llm_graph_input_dsv4 * build_inp_dsv4() const; + // note: if k_cur or v_cur are not provided, they will not be stored in the memory ggml_tensor * build_attn( llm_graph_input_attn_kv_iswa * inp, diff --git a/src/llama-hparams.h b/src/llama-hparams.h index 2eadeb214811..8be5f28f39e6 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -14,6 +14,7 @@ enum llama_expert_gating_func_type { LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX = 1, LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID = 2, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT = 3, // applied to the router weights instead of the logits + LLAMA_EXPERT_GATING_FUNC_TYPE_SQRT_SOFTPLUS = 4, }; enum llama_swa_type { @@ -226,6 +227,16 @@ struct llama_hparams { uint32_t indexer_head_size = 0; uint32_t indexer_top_k = 0; + // DeepSeek-V4 + uint32_t 
dsv4_o_group_count = 0; + uint32_t dsv4_o_lora_rank = 0; + uint32_t dsv4_hc_mult = 0; + uint32_t dsv4_hc_sinkhorn_iters = 0; + uint32_t dsv4_hash_layer_count = 0; + float dsv4_compress_rope_base = 0.0f; + float dsv4_hc_eps = 0.0f; + std::array dsv4_compress_ratios; + // qwen3vl deepstack // When parsed from GGUF, this implies the first N layers consume the first // N deepstack embeddings. Use deepstack_mapping_arr if you need a more diff --git a/src/llama-kv-cache-dsv4.cpp b/src/llama-kv-cache-dsv4.cpp new file mode 100644 index 000000000000..dfb2fc2620a8 --- /dev/null +++ b/src/llama-kv-cache-dsv4.cpp @@ -0,0 +1,1841 @@ +#include "llama-kv-cache-dsv4.h" + +#include "ggml-backend.h" +#include "llama-impl.h" +#include "llama-batch.h" +#include "llama-io.h" +#include "llama-model.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +static constexpr uint32_t DSV4_CSA_RATIO = 4; +static constexpr uint32_t DSV4_HCA_RATIO = 128; + +static constexpr uint32_t DSV4_STATE_MAGIC = 0x34565344; // DSV4 +static constexpr uint32_t DSV4_STATE_VERSION = 1; +static constexpr uint32_t DSV4_STATE_MODE_FULL = 0; +static constexpr uint32_t DSV4_STATE_MODE_PARTIAL = 1; +static constexpr uint32_t DSV4_K_CACHE_STATE_VER = 1; +static constexpr uint32_t DSV4_COMP_STATE_VER = 1; + +static uint32_t dsv4_comp_size(uint32_t kv_size, uint32_t ratio) { + return std::max(1, (kv_size + ratio - 1)/ratio); +} + +static int64_t dsv4_stream_offset(uint32_t n_stream, llama_seq_id seq_id, uint32_t size) { + if (n_stream <= 1) { + return 0; + } + if (seq_id < 0 || (uint32_t) seq_id >= n_stream) { + throw std::runtime_error("DSV4 sequence id out of stream range"); + } + + return (int64_t) seq_id*size; +} + +static bool dsv4_ubatch_has_coupled(const llama_ubatch & ubatch) { + for (uint32_t i = 0; i < ubatch.n_tokens; ++i) { + if (ubatch.n_seq_id[i] > 1) { + return true; + } + } + + return false; +} + +static bool dsv4_token_has_seq(const llama_ubatch & ubatch, uint32_t i, 
llama_seq_id seq_id) { + for (int32_t s = 0; s < ubatch.n_seq_id[i]; ++s) { + if (ubatch.seq_id[i][s] == seq_id) { + return true; + } + } + + return false; +} + +static llama_ubatch dsv4_build_raw_write_ubatch(const llama_ubatch & ubatch) { + if (!dsv4_ubatch_has_coupled(ubatch)) { + return ubatch; + } + if (ubatch.embd) { + throw std::runtime_error("DSV4 coupled embedding ubatches are not supported"); + } + + std::vector counts(ubatch.n_seqs_unq, 0); + uint32_t n_tokens = 0; + for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) { + const llama_seq_id seq_id = ubatch.seq_id_unq[s]; + for (uint32_t i = 0; i < ubatch.n_tokens; ++i) { + if (dsv4_token_has_seq(ubatch, i, seq_id)) { + ++counts[s]; + ++n_tokens; + } + } + } + + if (n_tokens == 0) { + return ubatch; + } + + const uint32_t n_seq_tokens = counts[0]; + for (uint32_t s = 1; s < counts.size(); ++s) { + if (counts[s] != n_seq_tokens) { + throw std::runtime_error("DSV4 coupled raw writes require equal sequence lengths"); + } + } + + auto data = std::make_shared(); + data->pos.resize((size_t) n_tokens*ubatch.n_pos); + data->n_seq_id.reserve(n_tokens); + data->seq_id.reserve(n_tokens); + data->seq_id_data.reserve(n_tokens); + data->seq_id_unq.assign(ubatch.seq_id_unq, ubatch.seq_id_unq + ubatch.n_seqs_unq); + data->seq_idx.assign(LLAMA_MAX_SEQ, -1); + data->output.assign(n_tokens, 0); + if (ubatch.token) { + data->token.reserve(n_tokens); + } + + for (uint32_t s = 0; s < data->seq_id_unq.size(); ++s) { + data->seq_idx[data->seq_id_unq[s]] = s; + } + + for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) { + const llama_seq_id seq_id = ubatch.seq_id_unq[s]; + for (uint32_t i = 0; i < ubatch.n_tokens; ++i) { + if (!dsv4_token_has_seq(ubatch, i, seq_id)) { + continue; + } + + const uint32_t dst = data->n_seq_id.size(); + if (ubatch.token) { + data->token.push_back(ubatch.token[i]); + } + for (uint32_t p = 0; p < ubatch.n_pos; ++p) { + data->pos[(size_t) p*n_tokens + dst] = ubatch.pos[(size_t) p*ubatch.n_tokens + i]; + } + 
data->n_seq_id.push_back(1); + data->seq_id_data.push_back(seq_id); + } + } + + for (uint32_t i = 0; i < n_tokens; ++i) { + data->seq_id.push_back(&data->seq_id_data[i]); + } + + llama_ubatch res { + /*.b_equal_seqs =*/ true, + /*.n_tokens =*/ n_tokens, + /*.n_seq_tokens =*/ n_seq_tokens, + /*.n_seqs =*/ ubatch.n_seqs_unq, + /*.n_seqs_unq =*/ ubatch.n_seqs_unq, + /*.n_pos =*/ ubatch.n_pos, + /*.token =*/ data->token.empty() ? nullptr : data->token.data(), + /*.embd =*/ nullptr, + /*.pos =*/ data->pos.data(), + /*.n_seq_id =*/ data->n_seq_id.data(), + /*.seq_id =*/ data->seq_id.data(), + /*.seq_id_unq =*/ data->seq_id_unq.data(), + /*.seq_idx =*/ data->seq_idx.data(), + /*.output =*/ data->output.data(), + /*.data =*/ data, + }; + + return res; +} + +static std::vector dsv4_build_raw_write_ubatches(const std::vector & ubatches) { + std::vector res; + res.reserve(ubatches.size()); + for (const llama_ubatch & ubatch : ubatches) { + res.push_back(dsv4_build_raw_write_ubatch(ubatch)); + } + return res; +} + +static bool dsv4_batch_has_coupled(const llama_batch & batch) { + if (!batch.n_seq_id) { + return false; + } + + for (int32_t i = 0; i < batch.n_tokens; ++i) { + if (batch.n_seq_id[i] > 1) { + return true; + } + } + + return false; +} + +static bool dsv4_batch_same_seq_set(const llama_batch & batch) { + if (!batch.n_seq_id || !batch.seq_id || batch.n_tokens <= 1) { + return true; + } + + const int32_t n_seq_id_ref = batch.n_seq_id[0]; + + for (int32_t i = 1; i < batch.n_tokens; ++i) { + if (batch.n_seq_id[i] != n_seq_id_ref) { + return false; + } + + for (int32_t r = 0; r < n_seq_id_ref; ++r) { + bool found = false; + for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) { + if (batch.seq_id[0][r] == batch.seq_id[i][s]) { + found = true; + break; + } + } + if (!found) { + return false; + } + } + } + + return true; +} + +static int64_t dsv4_comp_graph_n_stream(const llama_ubatch & ubatch, uint32_t n_stream) { + // Coupled sequence sets must stay in one graph stream because 
their + // compressed state is shared. Independent per-seq state can fan out. + if (n_stream <= 1 || ubatch.n_seqs_unq <= 1 || dsv4_ubatch_has_coupled(ubatch)) { + return 1; + } + + return ubatch.n_seqs_unq; +} + +static void dsv4_state_src_stream_range( + uint32_t n_stream, + llama_seq_id seq_id, + uint32_t & s0, + uint32_t & ns) { + if (seq_id >= 0 && n_stream > 1) { + if ((uint32_t) seq_id >= n_stream) { + throw std::runtime_error("DSV4 state sequence id out of stream range"); + } + + s0 = (uint32_t) seq_id; + ns = 1; + return; + } + + s0 = 0; + ns = seq_id >= 0 ? 1 : n_stream; +} + +static void dsv4_state_dst_stream_range( + uint32_t n_stream, + llama_seq_id seq_id, + uint32_t ns, + uint32_t & s0) { + if (seq_id >= 0) { + if (ns != 1) { + throw std::runtime_error("DSV4 sequence state stream count mismatch"); + } + if (n_stream > 1 && (uint32_t) seq_id >= n_stream) { + throw std::runtime_error("DSV4 state sequence id out of stream range"); + } + + s0 = n_stream > 1 ? (uint32_t) seq_id : 0; + return; + } + + if (ns != n_stream) { + throw std::runtime_error("DSV4 full state stream count mismatch"); + } + + s0 = 0; +} + +static void dsv4_state_write_tensor_streams( + llama_io_write_i & io, + ggml_tensor * tensor, + uint32_t n_rows, + uint32_t s0, + uint32_t ns) { + const int32_t type_i = (int32_t) tensor->type; + const uint64_t ne0 = tensor->ne[0]; + const uint64_t rows = n_rows; + const uint64_t row_size = ggml_row_size(tensor->type, tensor->ne[0]); + + io.write(&type_i, sizeof(type_i)); + io.write(&ne0, sizeof(ne0)); + io.write(&rows, sizeof(rows)); + io.write(&row_size, sizeof(row_size)); + + const size_t offset = (size_t) s0*n_rows*row_size; + const size_t size = (size_t) ns*n_rows*row_size; + + io.write_tensor(tensor, offset, size); +} + +static void dsv4_state_read_tensor_streams( + llama_io_read_i & io, + ggml_tensor * tensor, + uint32_t n_rows, + uint32_t s0, + uint32_t ns) { + int32_t type_i_ref; + uint64_t ne0_ref; + uint64_t rows_ref; + uint64_t 
row_size_ref; + + io.read(&type_i_ref, sizeof(type_i_ref)); + io.read(&ne0_ref, sizeof(ne0_ref)); + io.read(&rows_ref, sizeof(rows_ref)); + io.read(&row_size_ref, sizeof(row_size_ref)); + + const int32_t type_i = (int32_t) tensor->type; + const uint64_t ne0 = tensor->ne[0]; + const uint64_t rows = n_rows; + const uint64_t row_size = ggml_row_size(tensor->type, tensor->ne[0]); + + if (type_i != type_i_ref || ne0 != ne0_ref || rows != rows_ref || row_size != row_size_ref) { + throw std::runtime_error("DSV4 state tensor metadata mismatch"); + } + + const size_t offset = (size_t) s0*n_rows*row_size; + const size_t size = (size_t) ns*n_rows*row_size; + + io.read_tensor(tensor, offset, size); +} + +static void dsv4_state_write_k_cache( + llama_io_write_i & io, + const llama_kv_cache * kv, + llama_seq_id seq_id, + llama_state_seq_flags flags) { + GGML_UNUSED(flags); + + uint32_t s0; + uint32_t ns; + dsv4_state_src_stream_range(kv->get_n_stream(), seq_id, s0, ns); + + const uint32_t version = DSV4_K_CACHE_STATE_VER; + const uint32_t kv_size = kv->get_size(); + const auto layer_ids = kv->get_layer_ids(); + const uint32_t n_layer = layer_ids.size(); + + io.write(&version, sizeof(version)); + io.write(&kv_size, sizeof(kv_size)); + io.write(&ns, sizeof(ns)); + io.write(&n_layer, sizeof(n_layer)); + + for (uint32_t il : layer_ids) { + io.write(&il, sizeof(il)); + dsv4_state_write_tensor_streams(io, kv->get_k_storage(il), kv_size, s0, ns); + } +} + +static void dsv4_state_read_k_cache( + llama_io_read_i & io, + llama_kv_cache * kv, + llama_seq_id seq_id, + llama_state_seq_flags flags) { + GGML_UNUSED(flags); + + uint32_t version; + uint32_t kv_size_ref; + uint32_t ns; + uint32_t n_layer_ref; + + io.read(&version, sizeof(version)); + io.read(&kv_size_ref, sizeof(kv_size_ref)); + io.read(&ns, sizeof(ns)); + io.read(&n_layer_ref, sizeof(n_layer_ref)); + + if (version != DSV4_K_CACHE_STATE_VER) { + throw std::runtime_error("DSV4 K-cache state version mismatch"); + } + if 
(kv_size_ref != kv->get_size()) { + throw std::runtime_error("DSV4 K-cache state size mismatch"); + } + + uint32_t s0; + dsv4_state_dst_stream_range(kv->get_n_stream(), seq_id, ns, s0); + + const auto layer_ids = kv->get_layer_ids(); + if (n_layer_ref != layer_ids.size()) { + throw std::runtime_error("DSV4 K-cache layer count mismatch"); + } + + for (uint32_t il : layer_ids) { + uint32_t il_ref; + io.read(&il_ref, sizeof(il_ref)); + if (il_ref != il) { + throw std::runtime_error("DSV4 K-cache layer id mismatch"); + } + + dsv4_state_read_tensor_streams(io, kv->get_k_storage(il), kv->get_size(), s0, ns); + } +} + +static std::string dsv4_plan_positions(const std::vector & values) { + std::ostringstream ss; + ss << "["; + for (size_t i = 0; i < values.size(); ++i) { + if (i > 0) { + ss << ", "; + } + ss << values[i]; + } + ss << "]"; + return ss.str(); +} + +static llama_kv_cache_dsv4_context::comp_plan dsv4_build_comp_plan( + const llama_ubatch & ubatch, + uint32_t ratio, + bool overlap, + uint32_t state_size, + uint32_t kv_size, + uint32_t n_stream) { + llama_kv_cache_dsv4_context::comp_plan plan; + plan.n_visible.resize(ubatch.n_tokens); + plan.n_stream = dsv4_comp_graph_n_stream(ubatch, n_stream); + + // n_stream is the persistent cache/state layout; plan.n_stream is the + // graph view for this ubatch and can be a subset of those streams. + if (n_stream <= 1 && ubatch.n_seqs_unq > 1) { + throw std::runtime_error("DSV4 single compressed stream cannot serve multiple sequences"); + } + + const int64_t state_rows = (int64_t) state_size*n_stream; + + struct persist_row { + int32_t dst; + int32_t src; + llama_pos pos; + }; + + std::vector persist_rows; + + // For the overlap compressor, build_overlap_compressed_kv_from_state() consumes + // state_read_idxs as two contiguous halves: the first ratio*n_blocks entries are + // the "previous-window" gather indices for every block, followed by the + // "current-window" indices for every block. 
Collect them separately here and + // append cur after prev once the loop has visited all completed blocks + std::vector overlap_prev_reads; + std::vector overlap_cur_reads; + + std::map, int64_t> curr_token_idx_map; + + for (uint32_t i = 0; i < ubatch.n_tokens; ++i) { + for (int32_t s = 0; s < ubatch.n_seq_id[i]; ++s) { + curr_token_idx_map[std::make_pair(ubatch.seq_id[i][s], ubatch.pos[i])] = i; + } + } + + const auto state_source_idx = [&](llama_seq_id seq_id, llama_pos pos) -> int32_t { + if (pos < 0) { + // The overlap compressor needs a zero/-inf source for the first + // block's previous half. The graph appends that row after the + // current-ubatch scratch rows. + return (int32_t) (state_rows + ubatch.n_tokens); + } + + const auto key = std::make_pair(seq_id, pos); + if (curr_token_idx_map.find(key) != curr_token_idx_map.end()) { + return (int32_t) (state_rows + curr_token_idx_map.at(key)); + } + + const int64_t stream_off = dsv4_stream_offset(n_stream, seq_id, state_size); + return (int32_t) (stream_off + pos%state_size); + }; + + for (uint32_t i = 0; i < ubatch.n_tokens; ++i) { + const llama_pos pos = ubatch.pos[i]; + + if (pos < 0) { + continue; + } + + plan.state_pos.push_back((int32_t) (pos%ratio)); + + const int64_t n_visible = (int64_t) (pos + 1)/ratio; + plan.n_visible[i] = (int32_t) n_visible; + plan.n_kv = std::max(plan.n_kv, n_visible); + + for (int32_t s = 0; s < ubatch.n_seq_id[i]; ++s) { + const llama_seq_id seq_id = ubatch.seq_id[i][s]; + const int64_t stream_off = dsv4_stream_offset(n_stream, seq_id, state_size); + const int32_t state_idx = (int32_t) (stream_off + pos%state_size); + + const auto it = std::find_if(persist_rows.begin(), persist_rows.end(), + [state_idx](const persist_row & row) { + return row.dst == state_idx; + }); + if (it == persist_rows.end()) { + persist_rows.push_back({ state_idx, (int32_t) i, pos }); + } else if (pos > it->pos) { + it->src = (int32_t) i; + it->pos = pos; + } + + if ((pos + 1) % ratio != 0) { + continue; 
+ } + + const llama_pos source_start = pos + 1 - ratio; + const int64_t cache_off = dsv4_stream_offset(n_stream, seq_id, kv_size); + + plan.state_write_idxs.push_back(cache_off + pos/ratio); + plan.state_write_pos.push_back((int32_t) source_start); + + if (overlap) { + const llama_pos prev_start = source_start - ratio; + + for (uint32_t j = 0; j < ratio; ++j) { + overlap_prev_reads.push_back(state_source_idx(seq_id, prev_start + j)); + } + for (uint32_t j = 0; j < ratio; ++j) { + overlap_cur_reads.push_back(state_source_idx(seq_id, source_start + j)); + } + } else { + for (uint32_t j = 0; j < ratio; ++j) { + plan.state_read_idxs.push_back(state_source_idx(seq_id, source_start + j)); + } + } + } + } + + if (ratio == DSV4_CSA_RATIO && plan.state_write_idxs.empty() && !plan.state_pos.empty()) { + // Non-boundary CSA steps still need a write op so their graph matches + // boundary steps. Use a padded scratch row that is masked from attention. + assert(kv_size > 0); + + uint32_t i = 0; + while (i < ubatch.n_tokens && ubatch.pos[i] < 0) { + ++i; + } + assert(i < ubatch.n_tokens); + + const llama_pos pos = ubatch.pos[i]; + const llama_seq_id seq_id = ubatch.seq_id[i][0]; + const int64_t cache_off = dsv4_stream_offset(n_stream, seq_id, kv_size); + const int32_t source_idx = state_source_idx(seq_id, pos); + + plan.state_write_idxs.push_back(cache_off + kv_size - 1); + plan.state_write_pos .push_back(0); + + if (overlap) { + for (uint32_t j = 0; j < ratio; ++j) { + overlap_prev_reads.push_back(source_idx); + overlap_cur_reads .push_back(source_idx); + } + } else { + for (uint32_t j = 0; j < ratio; ++j) { + plan.state_read_idxs.push_back(source_idx); + } + } + } + + if (overlap) { + // [ all blocks' prev-window indices | all blocks' cur-window indices ] + plan.state_read_idxs.reserve(overlap_prev_reads.size() + overlap_cur_reads.size()); + plan.state_read_idxs.insert(plan.state_read_idxs.end(), + overlap_prev_reads.begin(), overlap_prev_reads.end()); + 
plan.state_read_idxs.insert(plan.state_read_idxs.end(), + overlap_cur_reads.begin(), overlap_cur_reads.end()); + } + + plan.n_kv = GGML_PAD(plan.n_kv, 256u); + + std::sort(persist_rows.begin(), persist_rows.end(), + [](const persist_row & a, const persist_row & b) { + return a.dst < b.dst; + }); + + for (const persist_row & row : persist_rows) { + plan.state_persist_src_idxs.push_back(row.src); + plan.state_persist_dst_idxs.push_back(row.dst); + } + + static const bool debug = []() { + const char * env = getenv("LLAMA_DSV4_COMPRESS_DEBUG"); + return env && atoi(env) > 0; + }(); + + if (debug) { + LLAMA_LOG_INFO("%s: ratio=%u, n_tokens=%u, state_persist_dst=%s, state_write_pos=%s\n", + __func__, ratio, ubatch.n_tokens, + dsv4_plan_positions(plan.state_persist_dst_idxs).c_str(), + dsv4_plan_positions(plan.state_write_pos).c_str()); + } + + return plan; +} + +static std::vector dsv4_build_comp_plans( + const std::vector & ubatches, + uint32_t ratio, + bool overlap, + uint32_t state_size, + uint32_t kv_size, + uint32_t n_stream) { + std::vector plans; + plans.reserve(ubatches.size()); + + for (const llama_ubatch & ubatch : ubatches) { + plans.push_back(dsv4_build_comp_plan(ubatch, ratio, overlap, state_size, kv_size, n_stream)); + } + + return plans; +} + +static llama_kv_cache::slot_info_vec_t dsv4_build_comp_sinfos( + const std::vector & ubatches, + uint32_t n_stream) { + llama_kv_cache::slot_info_vec_t sinfos; + sinfos.reserve(ubatches.size()); + + for (const llama_ubatch & ubatch : ubatches) { + if (n_stream <= 1 && ubatch.n_seqs_unq > 1) { + throw std::runtime_error("DSV4 single compressed stream cannot serve multiple sequences"); + } + + const uint32_t ns = (uint32_t) dsv4_comp_graph_n_stream(ubatch, n_stream); + llama_kv_cache::slot_info sinfo; + sinfo.s0 = n_stream > 1 ? LLAMA_MAX_SEQ : 0; + sinfo.s1 = 0; + sinfo.resize(ns); + + for (uint32_t s = 0; s < ns; ++s) { + const llama_seq_id seq_id = n_stream > 1 ? 
ubatch.seq_id_unq[s] : 0; + const uint32_t strm = (uint32_t) dsv4_stream_offset(n_stream, seq_id, 1); + + sinfo.s0 = std::min(sinfo.s0, strm); + sinfo.s1 = std::max(sinfo.s1, strm); + sinfo.strm[s] = strm; + sinfo.idxs[s].resize(1, 0); + } + + if (n_stream > 1 && sinfo.s1 - sinfo.s0 + 1 != ns) { + throw std::runtime_error("DSV4 compressed streams are not contiguous in ubatch"); + } + + sinfos.push_back(std::move(sinfo)); + } + + return sinfos; +} + +static llama_kv_cache::slot_info_vec_t dsv4_build_raw_read_sinfos( + const llama_kv_cache::slot_info_vec_t & sinfos_write, + const std::vector & ubatches) { + llama_kv_cache::slot_info_vec_t sinfos; + sinfos.reserve(ubatches.size()); + + for (size_t i = 0; i < ubatches.size(); ++i) { + const llama_ubatch & ubatch = ubatches[i]; + const auto & sinfo_write = sinfos_write[i]; + + if (!dsv4_ubatch_has_coupled(ubatch)) { + sinfos.push_back(sinfo_write); + continue; + } + + const llama_seq_id seq_id = ubatch.seq_id[0][0]; + uint32_t i_stream = 0; + for (; i_stream < sinfo_write.n_stream(); ++i_stream) { + if (sinfo_write.strm[i_stream] == seq_id) { + break; + } + } + if (i_stream == sinfo_write.n_stream()) { + throw std::runtime_error("DSV4 raw write stream not found for coupled read"); + } + + llama_kv_cache::slot_info sinfo; + sinfo.s0 = sinfo_write.strm[i_stream]; + sinfo.s1 = sinfo_write.strm[i_stream]; + sinfo.resize(1); + sinfo.strm[0] = sinfo_write.strm[i_stream]; + sinfo.idxs[0] = sinfo_write.idxs[i_stream]; + sinfos.push_back(std::move(sinfo)); + } + + return sinfos; +} + +static llama_kv_cache_dsv4_context::comp_plan dsv4_build_reserve_comp_plan( + const llama_ubatch & ubatch, + uint32_t ratio, + bool overlap, + uint32_t state_size, + uint32_t kv_size, + uint32_t n_stream) { + llama_kv_cache_dsv4_context::comp_plan plan; + plan.n_visible.resize(ubatch.n_tokens); + plan.n_stream = dsv4_comp_graph_n_stream(ubatch, n_stream); + plan.n_kv = kv_size; + + if (ubatch.n_tokens == 0) { + return plan; + } + + const uint32_t 
n_seqs = std::max(1, ubatch.n_seqs); + const uint32_t n_seq_tokens = std::max(1, ubatch.n_seq_tokens); + const uint64_t n_blocks_u64 = (uint64_t) n_seqs*((n_seq_tokens + ratio - 1)/ratio); + const size_t n_blocks = (size_t) std::max(1, n_blocks_u64); + GGML_ASSERT((uint64_t) n_blocks == std::max(1, n_blocks_u64)); + + const uint64_t state_rows = (uint64_t) state_size*n_stream; + const size_t n_persist = (size_t) std::min(ubatch.n_tokens, state_rows); + + plan.state_pos .resize(ubatch.n_tokens); + plan.state_persist_src_idxs.resize(n_persist); + plan.state_persist_dst_idxs.resize(n_persist); + plan.state_read_idxs .resize((overlap ? 2u : 1u)*ratio*n_blocks); + plan.state_write_idxs.resize(n_blocks); + plan.state_write_pos .resize(n_blocks); + + return plan; +} + +static void dsv4_make_k_only(llama_hparams & hparams) { + // llama_kv_cache uses hparams.is_mla() to allocate K-only storage. + hparams.n_embd_head_k_mla_impl = hparams.n_embd_head_k(); + hparams.n_embd_head_v_mla_impl = hparams.n_embd_head_k(); +} + +// +// llama_dsv4_comp_state +// + +llama_dsv4_comp_state::llama_dsv4_comp_state( + const llama_model & model, + bool offload, + bool unified, + uint32_t n_seq_max, + uint32_t ratio, + uint32_t state_size, + uint32_t n_embd_state, + const char * name, + const llama_memory_i::layer_filter_cb & filter) : + ratio(ratio), + state_size(state_size), + n_embd_state(n_embd_state), + n_stream(unified ? 
1 : n_seq_max) { + const llama_hparams & hparams = model.hparams; + + struct ggml_backend_buft_comparator { + bool operator()(const ggml_backend_buffer_type_t & lhs, const ggml_backend_buffer_type_t & rhs) const { + return strcmp(ggml_backend_buft_name(lhs), ggml_backend_buft_name(rhs)) < 0; + } + }; + + std::map ctx_map; + + auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * { + auto it = ctx_map.find(buft); + if (it == ctx_map.end()) { + ggml_init_params params = { + /*.mem_size =*/ size_t(2u*hparams.n_layer()*ggml_tensor_overhead()), + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + + ggml_context * ctx = ggml_init(params); + if (!ctx) { + return nullptr; + } + + ctx_map.emplace(buft, ctx); + + return ctx; + } + + return it->second.get(); + }; + + for (uint32_t il = 0; il < hparams.n_layer(); ++il) { + if (filter && !filter(il)) { + continue; + } + + const char * dev_name = "CPU"; + + ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type(); + + if (offload) { + auto * dev = model.dev_layer(il); + buft = ggml_backend_dev_buffer_type(dev); + + dev_name = ggml_backend_dev_name(dev); + } + + LLAMA_LOG_DEBUG("%s: layer %3d: dev = %s\n", __func__, il, dev_name); + + ggml_context * ctx = ctx_for_buft(buft); + if (!ctx) { + throw std::runtime_error("failed to create ggml context for DSV4 compressor state"); + } + + ggml_tensor * kv = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd_state, state_size, n_stream); + ggml_tensor * score = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd_state, state_size, n_stream); + + ggml_format_name(kv, "dsv4_%s_state_kv_l%d", name, il); + ggml_format_name(score, "dsv4_%s_state_score_l%d", name, il); + + map_layer_ids[il] = layers.size(); + + layers.push_back({ il, kv, score }); + } + + for (auto & [buft, ctx] : ctx_map) { + ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx.get(), buft); + if (!buf) { + throw std::runtime_error("failed to allocate buffer for DSV4 compressor 
state"); + } + + ggml_backend_buffer_clear(buf, 0); + + LLAMA_LOG_INFO("%s: %10s DSV4 %s state buffer size = %8.2f MiB\n", + __func__, ggml_backend_buffer_name(buf), name, ggml_backend_buffer_get_size(buf)/1024.0/1024.0); + + ctxs_bufs.emplace_back(std::move(ctx), buf); + } + + LLAMA_LOG_INFO("%s: %s ratio = %u, state = %u x %u, streams = %u, layers = %zu, size = %7.2f MiB\n", + __func__, name, ratio, state_size, n_embd_state, n_stream, layers.size(), total_size()/1024.0/1024.0); +} + +void llama_dsv4_comp_state::clear(bool data) { + if (!data) { + return; + } + + for (auto & [_, buf] : ctxs_bufs) { + ggml_backend_buffer_clear(buf.get(), 0); + } +} + +uint32_t llama_dsv4_comp_state::get_ratio() const { + return ratio; +} + +uint32_t llama_dsv4_comp_state::get_state_size() const { + return state_size; +} + +uint32_t llama_dsv4_comp_state::get_n_stream() const { + return n_stream; +} + +std::map llama_dsv4_comp_state::memory_breakdown() const { + std::map ret; + for (const auto & [_, buf] : ctxs_bufs) { + ggml_backend_buffer_type_t buft = ggml_backend_buffer_get_type(buf.get()); + ret[buft] += ggml_backend_buffer_get_size(buf.get()); + } + return ret; +} + +void llama_dsv4_comp_state::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const { + GGML_UNUSED(flags); + + uint32_t s0; + uint32_t ns; + dsv4_state_src_stream_range(n_stream, seq_id, s0, ns); + + const uint32_t version = DSV4_COMP_STATE_VER; + const uint32_t n_layer = layers.size(); + + io.write(&version, sizeof(version)); + io.write(&ratio, sizeof(ratio)); + io.write(&state_size, sizeof(state_size)); + io.write(&n_embd_state, sizeof(n_embd_state)); + io.write(&ns, sizeof(ns)); + io.write(&n_layer, sizeof(n_layer)); + + for (const auto & layer : layers) { + io.write(&layer.il, sizeof(layer.il)); + + dsv4_state_write_tensor_streams(io, layer.kv, state_size, s0, ns); + dsv4_state_write_tensor_streams(io, layer.score, state_size, s0, ns); + } +} + +void 
llama_dsv4_comp_state::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) { + GGML_UNUSED(flags); + + uint32_t version; + uint32_t ratio_ref; + uint32_t state_size_ref; + uint32_t n_embd_state_ref; + uint32_t ns; + uint32_t n_layer_ref; + + io.read(&version, sizeof(version)); + io.read(&ratio_ref, sizeof(ratio_ref)); + io.read(&state_size_ref, sizeof(state_size_ref)); + io.read(&n_embd_state_ref, sizeof(n_embd_state_ref)); + io.read(&ns, sizeof(ns)); + io.read(&n_layer_ref, sizeof(n_layer_ref)); + + if (version != DSV4_COMP_STATE_VER) { + throw std::runtime_error("DSV4 compressor state version mismatch"); + } + if (ratio_ref != ratio || state_size_ref != state_size || n_embd_state_ref != n_embd_state) { + throw std::runtime_error("DSV4 compressor state metadata mismatch"); + } + if (n_layer_ref != layers.size()) { + throw std::runtime_error("DSV4 compressor state layer count mismatch"); + } + + uint32_t s0; + dsv4_state_dst_stream_range(n_stream, seq_id, ns, s0); + + for (const auto & layer : layers) { + uint32_t il_ref; + io.read(&il_ref, sizeof(il_ref)); + if (il_ref != layer.il) { + throw std::runtime_error("DSV4 compressor state layer id mismatch"); + } + + dsv4_state_read_tensor_streams(io, layer.kv, state_size, s0, ns); + dsv4_state_read_tensor_streams(io, layer.score, state_size, s0, ns); + } +} + +ggml_tensor * llama_dsv4_comp_state::get_kv(ggml_context * ctx, int32_t il) const { + const int32_t ids = map_layer_ids.at(il); + + ggml_tensor * state = layers[ids].kv; + + return ggml_reshape_2d(ctx, state, state->ne[0], state->ne[1]*state->ne[2]); +} + +ggml_tensor * llama_dsv4_comp_state::get_score(ggml_context * ctx, int32_t il) const { + const int32_t ids = map_layer_ids.at(il); + + ggml_tensor * state = layers[ids].score; + + return ggml_reshape_2d(ctx, state, state->ne[0], state->ne[1]*state->ne[2]); +} + +ggml_tensor * llama_dsv4_comp_state::cpy_kv(ggml_context * ctx, ggml_tensor * cur, ggml_tensor * idxs, int32_t il) const 
{ + return ggml_set_rows(ctx, get_kv(ctx, il), cur, idxs); +} + +ggml_tensor * llama_dsv4_comp_state::cpy_score(ggml_context * ctx, ggml_tensor * cur, ggml_tensor * idxs, int32_t il) const { + return ggml_set_rows(ctx, get_score(ctx, il), cur, idxs); +} + +size_t llama_dsv4_comp_state::total_size() const { + size_t size = 0; + + for (const auto & [_, buf] : ctxs_bufs) { + size += ggml_backend_buffer_get_size(buf.get()); + } + + return size; +} + +// +// llama_kv_cache_dsv4 +// + +llama_kv_cache_dsv4::llama_kv_cache_dsv4( + const llama_model & model, + ggml_type type_k, + ggml_type type_v, + bool v_trans, + bool offload, + bool swa_full, + bool unified, + uint32_t kv_size, + uint32_t n_seq_max, + uint32_t n_ubatch, + uint32_t n_pad, + const layer_filter_cb & filter, + const layer_reuse_cb & reuse) : + hparams_raw(model.hparams), + hparams_csa(model.hparams), + hparams_hca(model.hparams), + hparams_lid(model.hparams), + n_seq_max(n_seq_max) { + + const layer_filter_cb filter_raw = [&](int32_t il) { + if (filter && !filter(il)) { + return false; + } + + return true; + }; + + GGML_UNUSED(unified); + + // Keep DSV4 KV/state streams per sequence even when public KV mode is unified. 
+ const bool unified_raw = false; + + LLAMA_LOG_INFO("%s: creating DSV4 raw KV cache\n", __func__); + + dsv4_make_k_only(hparams_raw); + + kv_raw = std::make_unique( + model, hparams_raw, type_k, type_v, + v_trans, offload, swa_full, unified_raw, kv_size, n_seq_max, n_ubatch, n_pad, + nullptr, filter_raw, reuse, nullptr); + + dsv4_make_k_only(hparams_csa); + dsv4_make_k_only(hparams_hca); + + std::fill(hparams_lid.n_head_kv_arr.begin(), hparams_lid.n_head_kv_arr.end(), 1); + hparams_lid.n_embd_head_k_full = model.hparams.indexer_head_size; + hparams_lid.n_embd_head_v_full = model.hparams.indexer_head_size; + hparams_lid.n_embd_head_k_swa = model.hparams.indexer_head_size; + hparams_lid.n_embd_head_v_swa = model.hparams.indexer_head_size; + hparams_lid.rope_type = LLAMA_ROPE_TYPE_NEOX; + dsv4_make_k_only(hparams_lid); + + const layer_filter_cb filter_csa = [&](int32_t il) { + if (filter && !filter(il)) { + return false; + } + + return model.hparams.dsv4_compress_ratios[il] == DSV4_CSA_RATIO; + }; + + const layer_filter_cb filter_hca = [&](int32_t il) { + if (filter && !filter(il)) { + return false; + } + + return model.hparams.dsv4_compress_ratios[il] == DSV4_HCA_RATIO; + }; + + const bool unified_compressed = false; + + LLAMA_LOG_INFO("%s: creating DSV4 CSA compressed KV cache, size = %u cells\n", + __func__, dsv4_comp_size(kv_size, DSV4_CSA_RATIO)); + + kv_csa = std::make_unique( + model, hparams_csa, type_k, type_v, + v_trans, offload, unified_compressed, GGML_PAD(dsv4_comp_size(kv_size, DSV4_CSA_RATIO), 256u), n_seq_max, n_pad, + 0, LLAMA_SWA_TYPE_NONE, nullptr, filter_csa, nullptr, nullptr); + + LLAMA_LOG_INFO("%s: creating DSV4 HCA compressed KV cache, size = %u cells\n", + __func__, dsv4_comp_size(kv_size, DSV4_HCA_RATIO)); + + kv_hca = std::make_unique( + model, hparams_hca, type_k, type_v, + v_trans, offload, unified_compressed, GGML_PAD(dsv4_comp_size(kv_size, DSV4_HCA_RATIO), 256u), n_seq_max, n_pad, + 0, LLAMA_SWA_TYPE_NONE, nullptr, filter_hca, nullptr, 
nullptr); + + LLAMA_LOG_INFO("%s: creating DSV4 lightning-indexer KV cache, size = %u cells\n", + __func__, dsv4_comp_size(kv_size, DSV4_CSA_RATIO)); + + kv_lid = std::make_unique( + model, hparams_lid, type_k, type_v, + v_trans, offload, unified_compressed, GGML_PAD(dsv4_comp_size(kv_size, DSV4_CSA_RATIO), 256u), n_seq_max, n_pad, + 0, LLAMA_SWA_TYPE_NONE, nullptr, filter_csa, nullptr, nullptr); + + LLAMA_LOG_INFO("%s: creating DSV4 CSA compressor state\n", __func__); + + csa_state = std::make_unique( + model, offload, unified_compressed, n_seq_max, DSV4_CSA_RATIO, 2*DSV4_CSA_RATIO, + 2*model.hparams.n_embd_head_k(), "csa", filter_csa); + + LLAMA_LOG_INFO("%s: creating DSV4 HCA compressor state\n", __func__); + + hca_state = std::make_unique( + model, offload, unified_compressed, n_seq_max, DSV4_HCA_RATIO, DSV4_HCA_RATIO, + model.hparams.n_embd_head_k(), "hca", filter_hca); + + LLAMA_LOG_INFO("%s: creating DSV4 lightning-indexer compressor state\n", __func__); + + lid_state = std::make_unique( + model, offload, unified_compressed, n_seq_max, DSV4_CSA_RATIO, 2*DSV4_CSA_RATIO, + 2*model.hparams.indexer_head_size, "lid", filter_csa); + + // DSV4 attention reads compressed-K / compressor-state rows that the current + // graph does not necessarily overwrite; uninitialized buffer contents would + // otherwise leak in (instance-specific garbage) and corrupt recall. Zero all + // compressed buffers up front so reads of un-written rows are deterministic. 
+ clear_compressed(true); +} + +llama_memory_context_ptr llama_kv_cache_dsv4::init_batch( + llama_batch_allocr & balloc, + uint32_t n_ubatch, + bool embd_all) { + GGML_UNUSED(embd_all); + + const bool raw_per_seq = kv_raw->get_base()->get_n_stream() != 1; + const bool comp_per_seq = csa_state->get_n_stream() > 1; + const bool comp_coupled = comp_per_seq && !raw_per_seq && dsv4_batch_has_coupled(balloc.get_batch()); + const bool comp_coupled_same_set = comp_coupled && dsv4_batch_same_seq_set(balloc.get_batch()); + + const auto make_context = [&](std::vector ubatches) -> llama_memory_context_ptr { + auto ubatches_raw = dsv4_build_raw_write_ubatches(ubatches); + + auto sinfos_raw_base_write = kv_raw->get_base()->prepare(ubatches_raw); + if (sinfos_raw_base_write.empty()) { + return nullptr; + } + + auto sinfos_raw_swa_write = kv_raw->get_swa()->prepare(ubatches_raw); + if (sinfos_raw_swa_write.empty()) { + return nullptr; + } + + auto sinfos_raw_swa_read = dsv4_build_raw_read_sinfos(sinfos_raw_swa_write, ubatches); + + return std::make_unique( + this, + std::move(sinfos_raw_base_write), + std::move(sinfos_raw_swa_write), + std::move(sinfos_raw_swa_read), + std::move(ubatches), + std::move(ubatches_raw)); + }; + + // Match llama_kv_cache_iswa splitting when DSV4 compressed state does not + // require per-sequence graph layout. + do { + if (raw_per_seq || comp_per_seq) { + break; + } + + balloc.split_reset(); + + std::vector ubatches; + while (true) { + auto ubatch = balloc.split_simple(n_ubatch); + if (ubatch.n_tokens == 0) { + break; + } + ubatches.push_back(std::move(ubatch)); // NOLINT + } + + if (balloc.get_n_used() < balloc.get_n_tokens()) { + break; + } + + if (auto ctx = make_context(std::move(ubatches))) { + return ctx; + } + } while (false); + + // When either raw or compressed state is per-sequence, split ubatches so + // every token maps cleanly to its stream. 
This may serialize independent + // non-unified sequences, but keeps compressed state ownership explicit. + do { + balloc.split_reset(); + + std::vector ubatches; + while (true) { + llama_ubatch ubatch; + if (comp_coupled_same_set) { + ubatch = balloc.split_equal(n_ubatch, false); + } else if (comp_coupled) { + ubatch = balloc.split_seq(1); + } else if (comp_per_seq) { + ubatch = balloc.split_seq(n_ubatch); + } else { + ubatch = balloc.split_equal(n_ubatch, raw_per_seq); + } + + if (ubatch.n_tokens == 0) { + break; + } + ubatches.push_back(std::move(ubatch)); // NOLINT + } + + if (balloc.get_n_used() < balloc.get_n_tokens()) { + break; + } + + if (auto ctx = make_context(std::move(ubatches))) { + return ctx; + } + } while (false); + + return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); +} + +llama_memory_context_ptr llama_kv_cache_dsv4::init_full() { + return std::make_unique(this); +} + +llama_memory_context_ptr llama_kv_cache_dsv4::init_update(llama_context * lctx, bool optimize) { + return std::make_unique(this, lctx, optimize); +} + +bool llama_kv_cache_dsv4::get_can_shift() const { + // Compressed row metadata uses block-derived positions. Keep shifting + // disabled until DSV4 compressed-cache shift semantics are wired. + return false; +} + +void llama_kv_cache_dsv4::clear(bool data) { + kv_raw->clear(data); + clear_compressed(true); // DSV4 compressed buffers must never expose stale/uninit rows +} + +bool llama_kv_cache_dsv4::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { + if (p1 >= 0) { + return false; + } + + if (p0 > 0) { + // DSV4 compressed cache rows are derived from running compressor state, + // so arbitrary rollback is not reconstructible from the raw cache alone. + // Allow the common prompt-cache cleanup no-op: remove [end, infinity). 
+ if (seq_id >= 0 && p0 > kv_raw->seq_pos_max(seq_id)) { + return true; + } + + return false; + } + + const bool res = kv_raw->seq_rm(seq_id, p0, p1); + + if (res) { + clear_compressed(true); + } + + return res; +} + +void llama_kv_cache_dsv4::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { + kv_raw->seq_cp(seq_id_src, seq_id_dst, p0, p1); + clear_compressed(true); +} + +void llama_kv_cache_dsv4::seq_keep(llama_seq_id seq_id) { + kv_raw->seq_keep(seq_id); + clear_compressed(true); +} + +void llama_kv_cache_dsv4::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) { + kv_raw->seq_add(seq_id, p0, p1, shift); + clear_compressed(true); +} + +void llama_kv_cache_dsv4::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { + kv_raw->seq_div(seq_id, p0, p1, d); + clear_compressed(true); +} + +llama_pos llama_kv_cache_dsv4::seq_pos_min(llama_seq_id seq_id) const { + if (seq_id < 0 || (uint32_t) seq_id >= n_seq_max) { + return -1; + } + + // The raw SWA cache may contain a wider window, but the compressed DSV4 + // state cannot be rolled back within that window. Report only the current + // boundary so server-context uses checkpoints for rollback. 
+ return kv_raw->seq_pos_max(seq_id); +} + +llama_pos llama_kv_cache_dsv4::seq_pos_max(llama_seq_id seq_id) const { + if (seq_id < 0 || (uint32_t) seq_id >= n_seq_max) { + return -1; + } + + return kv_raw->seq_pos_max(seq_id); +} + +std::map llama_kv_cache_dsv4::memory_breakdown() const { + std::map mb = kv_raw->memory_breakdown(); + for (const auto & buft_size : kv_csa->memory_breakdown()) { + mb[buft_size.first] += buft_size.second; + } + for (const auto & buft_size : kv_hca->memory_breakdown()) { + mb[buft_size.first] += buft_size.second; + } + for (const auto & buft_size : kv_lid->memory_breakdown()) { + mb[buft_size.first] += buft_size.second; + } + for (const auto & buft_size : csa_state->memory_breakdown()) { + mb[buft_size.first] += buft_size.second; + } + for (const auto & buft_size : hca_state->memory_breakdown()) { + mb[buft_size.first] += buft_size.second; + } + for (const auto & buft_size : lid_state->memory_breakdown()) { + mb[buft_size.first] += buft_size.second; + } + return mb; +} + +void llama_kv_cache_dsv4::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const { + const bool partial_only = flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY; + + const uint32_t magic = DSV4_STATE_MAGIC; + const uint32_t version = DSV4_STATE_VERSION; + const uint32_t mode = partial_only ? 
DSV4_STATE_MODE_PARTIAL : DSV4_STATE_MODE_FULL; + + io.write(&magic, sizeof(magic)); + io.write(&version, sizeof(version)); + io.write(&mode, sizeof(mode)); + + kv_raw->state_write(io, seq_id, flags); + + if (!partial_only) { + dsv4_state_write_k_cache(io, kv_csa.get(), seq_id, flags); + dsv4_state_write_k_cache(io, kv_hca.get(), seq_id, flags); + dsv4_state_write_k_cache(io, kv_lid.get(), seq_id, flags); + } + + csa_state->state_write(io, seq_id, flags); + hca_state->state_write(io, seq_id, flags); + lid_state->state_write(io, seq_id, flags); +} + +void llama_kv_cache_dsv4::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) { + uint32_t magic; + uint32_t version; + uint32_t mode = DSV4_STATE_MODE_FULL; + + io.read(&magic, sizeof(magic)); + io.read(&version, sizeof(version)); + + if (magic != DSV4_STATE_MAGIC) { + throw std::runtime_error("DSV4 state magic mismatch"); + } + if (version != DSV4_STATE_VERSION) { + throw std::runtime_error("DSV4 state version mismatch"); + } + + io.read(&mode, sizeof(mode)); + if (mode != DSV4_STATE_MODE_FULL && mode != DSV4_STATE_MODE_PARTIAL) { + throw std::runtime_error("DSV4 state mode mismatch"); + } + + const bool partial_only = mode == DSV4_STATE_MODE_PARTIAL; + if (partial_only != !!(flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY)) { + throw std::runtime_error("DSV4 state flags mismatch"); + } + + kv_raw->state_read(io, seq_id, flags); + + if (!partial_only) { + dsv4_state_read_k_cache(io, kv_csa.get(), seq_id, flags); + dsv4_state_read_k_cache(io, kv_hca.get(), seq_id, flags); + dsv4_state_read_k_cache(io, kv_lid.get(), seq_id, flags); + } + + csa_state->state_read(io, seq_id, flags); + hca_state->state_read(io, seq_id, flags); + lid_state->state_read(io, seq_id, flags); + +} + +llama_kv_cache_iswa * llama_kv_cache_dsv4::get_raw() const { + return kv_raw.get(); +} + +llama_kv_cache * llama_kv_cache_dsv4::get_csa() const { + return kv_csa.get(); +} + +llama_kv_cache * llama_kv_cache_dsv4::get_hca() 
const { + return kv_hca.get(); +} + +llama_kv_cache * llama_kv_cache_dsv4::get_lid() const { + return kv_lid.get(); +} + +llama_dsv4_comp_state * llama_kv_cache_dsv4::get_csa_state() const { + return csa_state.get(); +} + +llama_dsv4_comp_state * llama_kv_cache_dsv4::get_hca_state() const { + return hca_state.get(); +} + +llama_dsv4_comp_state * llama_kv_cache_dsv4::get_lid_state() const { + return lid_state.get(); +} + +void llama_kv_cache_dsv4::clear_compressed(bool data) { + kv_csa->clear(data); + kv_hca->clear(data); + kv_lid->clear(data); + csa_state->clear(data); + hca_state->clear(data); + lid_state->clear(data); +} + +// +// llama_kv_cache_dsv4_raw_context +// + +static llama_kv_cache::slot_info dsv4_build_full_sinfo(const llama_kv_cache * kv) { + const uint32_t n_stream = kv->get_n_stream(); + + llama_kv_cache::slot_info sinfo; + sinfo.s0 = 0; + sinfo.s1 = n_stream - 1; + sinfo.resize(n_stream); + for (uint32_t s = 0; s < n_stream; ++s) { + sinfo.strm[s] = s; + sinfo.idxs[s].resize(1, 0); + } + + return sinfo; +} + +llama_kv_cache_dsv4_raw_context::llama_kv_cache_dsv4_raw_context(llama_kv_cache_iswa * kv) : + kv_swa(kv->get_swa()), + ctx_base_mem(nullptr), + ctx_swa_mem(nullptr), + n_kv(kv_swa->get_size()), + status(LLAMA_MEMORY_STATUS_SUCCESS) { + sinfos_read.push_back(dsv4_build_full_sinfo(kv_swa)); + sinfos_write = sinfos_read; +} + +llama_kv_cache_dsv4_raw_context::llama_kv_cache_dsv4_raw_context( + llama_kv_cache_iswa * kv, + llama_context * lctx, + bool optimize) : + kv_swa(kv->get_swa()), + ctx_base_mem(kv->get_base()->init_update(lctx, optimize)), + ctx_swa_mem(kv->get_swa()->init_update(lctx, optimize)), + n_kv(kv_swa->get_size()), + status(llama_memory_status_combine(ctx_base_mem->get_status(), ctx_swa_mem->get_status())) { +} + +llama_kv_cache_dsv4_raw_context::llama_kv_cache_dsv4_raw_context( + llama_kv_cache_iswa * kv, + slot_info_vec_t sinfos_base_write, + slot_info_vec_t sinfos_swa_write, + slot_info_vec_t sinfos_swa_read, + std::vector 
ubatches, + std::vector ubatches_write) : + kv_swa(kv->get_swa()), + sinfos_write(std::move(sinfos_swa_write)), + sinfos_read(std::move(sinfos_swa_read)), + ubatches(std::move(ubatches)), + ubatches_write(std::move(ubatches_write)), + ctx_base_mem(std::make_unique( + kv->get_base(), std::move(sinfos_base_write), this->ubatches_write)), + ctx_swa_mem(nullptr), + n_kv(kv_swa->get_size()), + status(LLAMA_MEMORY_STATUS_SUCCESS) { +} + +bool llama_kv_cache_dsv4_raw_context::next() { + if (ubatches.empty()) { + return true; + } + + if (ctx_base_mem) { + ctx_base_mem->next(); + } + + if (++i_next >= ubatches.size()) { + return false; + } + + return true; +} + +bool llama_kv_cache_dsv4_raw_context::apply() { + bool res = true; + + if (ctx_base_mem) { + res = res & ctx_base_mem->apply(); + } + if (ctx_swa_mem) { + res = res & ctx_swa_mem->apply(); + } + if (!ubatches_write.empty()) { + kv_swa->apply_ubatch(sinfos_write[i_next], ubatches_write[i_next]); + n_kv = kv_swa->get_n_kv(sinfos_read[i_next]); + } + + return res; +} + +llama_memory_status llama_kv_cache_dsv4_raw_context::get_status() const { + return status; +} + +const llama_ubatch & llama_kv_cache_dsv4_raw_context::get_ubatch() const { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + return ubatches[i_next]; +} + +uint32_t llama_kv_cache_dsv4_raw_context::get_n_kv() const { + return n_kv; +} + +uint32_t llama_kv_cache_dsv4_raw_context::get_n_write() const { + if (ubatches_write.empty()) { + return 0; + } + + return ubatches_write[i_next].n_tokens; +} + +ggml_tensor * llama_kv_cache_dsv4_raw_context::get_k(ggml_context * ctx, int32_t il) const { + return kv_swa->get_k(ctx, il, n_kv, sinfos_read[i_next]); +} + +ggml_tensor * llama_kv_cache_dsv4_raw_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const { + const auto & sinfo = sinfos_write[i_next]; + + if (k_cur->ne[2] == k_idxs->ne[0]) { + return kv_swa->cpy_k(ctx, k_cur, k_idxs, il, sinfo); + } + + // k_idxs may be 
expanded to one block per stream while k_cur is only + // the token block. Keep zero deps on all copies so each write executes. + const int64_t n_fanout = (int64_t) sinfo.size()*sinfo.n_stream(); + + GGML_ASSERT(sinfo.n_stream() > 1); + GGML_ASSERT(k_cur->ne[2] == (int64_t) sinfo.size()); + GGML_ASSERT(k_idxs->ne[0] == n_fanout); + + ggml_tensor * res = nullptr; + for (uint32_t s = 0; s < sinfo.n_stream(); ++s) { + ggml_tensor * k_idxs_s = ggml_view_1d(ctx, k_idxs, sinfo.size(), s*sinfo.size()*ggml_element_size(k_idxs)); + ggml_tensor * cur = kv_swa->cpy_k(ctx, k_cur, k_idxs_s, il, sinfo); + if (res == nullptr) { + res = cur; + } else { + res = ggml_add(ctx, res, ggml_sub(ctx, cur, cur)); + } + } + + return res; +} + +ggml_tensor * llama_kv_cache_dsv4_raw_context::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const { + const uint32_t n_tokens = ubatches_write.empty() ? ubatch.n_tokens : ubatches_write[i_next].n_tokens; + + ggml_tensor * k_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens); + ggml_set_input(k_idxs); + + return k_idxs; +} + +ggml_tensor * llama_kv_cache_dsv4_raw_context::build_input_k_rot(ggml_context * ctx) const { + return kv_swa->build_input_k_rot(ctx); +} + +void llama_kv_cache_dsv4_raw_context::set_input_k_idxs(ggml_tensor * dst) const { + kv_swa->set_input_k_idxs(dst, &ubatches_write[i_next], sinfos_write[i_next]); +} + +void llama_kv_cache_dsv4_raw_context::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const { + kv_swa->set_input_kq_mask(dst, ubatch, causal_attn); +} + +void llama_kv_cache_dsv4_raw_context::set_input_k_rot(ggml_tensor * dst) const { + kv_swa->set_input_k_rot(dst); +} + +// +// llama_kv_cache_dsv4_comp_context +// + +llama_kv_cache_dsv4_comp_context::llama_kv_cache_dsv4_comp_context(llama_kv_cache * kv) : kv(kv), n_kv(kv->get_size()) { + const uint32_t n_stream = kv->get_n_stream(); + + sinfos.resize(1); + sinfos[0].s0 = 0; + sinfos[0].s1 = n_stream - 1; + 
sinfos[0].idxs.resize(n_stream); + for (uint32_t s = 0; s < n_stream; ++s) { + sinfos[0].strm.push_back(s); + sinfos[0].idxs[s].resize(1, 0); + } +} + +llama_kv_cache_dsv4_comp_context::llama_kv_cache_dsv4_comp_context( + llama_kv_cache * kv, + slot_info_vec_t sinfos, + std::vector ubatches) : + kv(kv), + sinfos(std::move(sinfos)), + ubatches(std::move(ubatches)), + n_kv(kv->get_size()) { +} + +bool llama_kv_cache_dsv4_comp_context::next() { + if (ubatches.empty()) { + return true; + } + + if (++i_cur >= ubatches.size()) { + return false; + } + + return true; +} + +uint32_t llama_kv_cache_dsv4_comp_context::get_n_kv() const { + return n_kv; +} + +ggml_tensor * llama_kv_cache_dsv4_comp_context::get_k(ggml_context * ctx, int32_t il) const { + return kv->get_k(ctx, il, n_kv, sinfos[i_cur]); +} + +ggml_tensor * llama_kv_cache_dsv4_comp_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const { + return kv->cpy_k(ctx, k_cur, k_idxs, il, sinfos[i_cur]); +} + +ggml_tensor * llama_kv_cache_dsv4_comp_context::build_input_k_rot(ggml_context * ctx) const { + return kv->build_input_k_rot(ctx); +} + +void llama_kv_cache_dsv4_comp_context::set_input_k_rot(ggml_tensor * dst) const { + kv->set_input_k_rot(dst); +} + +// +// llama_kv_cache_dsv4_context +// + +llama_kv_cache_dsv4_context::llama_kv_cache_dsv4_context(llama_memory_status status) : status(status) {} + +llama_kv_cache_dsv4_context::llama_kv_cache_dsv4_context( + llama_kv_cache_dsv4 * kv) : + ctx_raw(std::make_unique(kv->get_raw())), + ctx_csa_mem(kv->get_csa()->init_full()), + ctx_hca_mem(kv->get_hca()->init_full()), + ctx_lid_mem(kv->get_lid()->init_full()), + ctx_csa(std::make_unique(kv->get_csa())), + ctx_hca(std::make_unique(kv->get_hca())), + ctx_lid(std::make_unique(kv->get_lid())), + csa_state(kv->get_csa_state()), + hca_state(kv->get_hca_state()), + lid_state(kv->get_lid_state()), + reserve_plans(true), + status(llama_memory_status_combine( + 
llama_memory_status_combine(ctx_raw->get_status(), ctx_csa_mem->get_status()), + llama_memory_status_combine(ctx_hca_mem->get_status(), ctx_lid_mem->get_status()))) { +} + +llama_kv_cache_dsv4_context::llama_kv_cache_dsv4_context( + llama_kv_cache_dsv4 * kv, + llama_context * lctx, + bool optimize) : + ctx_raw(std::make_unique(kv->get_raw(), lctx, optimize)), + ctx_csa_mem(kv->get_csa()->init_update(lctx, optimize)), + ctx_hca_mem(kv->get_hca()->init_update(lctx, optimize)), + ctx_lid_mem(kv->get_lid()->init_update(lctx, optimize)), + ctx_csa(std::make_unique(kv->get_csa())), + ctx_hca(std::make_unique(kv->get_hca())), + ctx_lid(std::make_unique(kv->get_lid())), + csa_state(kv->get_csa_state()), + hca_state(kv->get_hca_state()), + lid_state(kv->get_lid_state()), + status(llama_memory_status_combine( + llama_memory_status_combine(ctx_raw->get_status(), ctx_csa_mem->get_status()), + llama_memory_status_combine(ctx_hca_mem->get_status(), ctx_lid_mem->get_status()))) { +} + +llama_kv_cache_dsv4_context::llama_kv_cache_dsv4_context( + llama_kv_cache_dsv4 * kv, + slot_info_vec_t sinfos_raw_base_write, + slot_info_vec_t sinfos_raw_swa_write, + slot_info_vec_t sinfos_raw_swa_read, + std::vector ubatches, + std::vector ubatches_raw) : + ubatches(std::move(ubatches)), + plans_csa(dsv4_build_comp_plans(this->ubatches, DSV4_CSA_RATIO, true, + kv->get_csa_state()->get_state_size(), kv->get_csa()->get_size(), kv->get_csa_state()->get_n_stream())), + plans_hca(dsv4_build_comp_plans(this->ubatches, DSV4_HCA_RATIO, false, + kv->get_hca_state()->get_state_size(), kv->get_hca()->get_size(), kv->get_hca_state()->get_n_stream())), + plans_lid(plans_csa), + ctx_raw(std::make_unique( + kv->get_raw(), + std::move(sinfos_raw_base_write), + std::move(sinfos_raw_swa_write), + std::move(sinfos_raw_swa_read), + this->ubatches, + std::move(ubatches_raw))), + ctx_csa_mem(nullptr), + ctx_hca_mem(nullptr), + ctx_lid_mem(nullptr), + ctx_csa(std::make_unique( + kv->get_csa(), + 
dsv4_build_comp_sinfos(this->ubatches, kv->get_csa()->get_n_stream()), + this->ubatches)), + ctx_hca(std::make_unique( + kv->get_hca(), + dsv4_build_comp_sinfos(this->ubatches, kv->get_hca()->get_n_stream()), + this->ubatches)), + ctx_lid(std::make_unique( + kv->get_lid(), + dsv4_build_comp_sinfos(this->ubatches, kv->get_lid()->get_n_stream()), + this->ubatches)), + csa_state(kv->get_csa_state()), + hca_state(kv->get_hca_state()), + lid_state(kv->get_lid_state()), + status(ctx_raw->get_status()) { +} + +llama_kv_cache_dsv4_context::~llama_kv_cache_dsv4_context() = default; + +bool llama_kv_cache_dsv4_context::next() { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + ctx_raw->next(); + ctx_csa->next(); + ctx_hca->next(); + ctx_lid->next(); + + if (++i_next >= ubatches.size()) { + return false; + } + + return true; +} + +bool llama_kv_cache_dsv4_context::apply() { + assert(!llama_memory_status_is_fail(status)); + + bool res = true; + + res = res & ctx_raw->apply(); + + return res; +} + +llama_memory_status llama_kv_cache_dsv4_context::get_status() const { + return status; +} + +const llama_ubatch & llama_kv_cache_dsv4_context::get_ubatch() const { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + return ubatches[i_next]; +} + +const llama_kv_cache_dsv4_raw_context * llama_kv_cache_dsv4_context::get_raw() const { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + return ctx_raw.get(); +} + +const llama_kv_cache_dsv4_comp_context * llama_kv_cache_dsv4_context::get_csa() const { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + return ctx_csa.get(); +} + +const llama_kv_cache_dsv4_comp_context * llama_kv_cache_dsv4_context::get_hca() const { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + return ctx_hca.get(); +} + +const llama_kv_cache_dsv4_comp_context * llama_kv_cache_dsv4_context::get_lid() const { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + return ctx_lid.get(); +} + +const llama_dsv4_comp_state * 
llama_kv_cache_dsv4_context::get_csa_state() const { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + return csa_state; +} + +const llama_dsv4_comp_state * llama_kv_cache_dsv4_context::get_hca_state() const { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + return hca_state; +} + +const llama_dsv4_comp_state * llama_kv_cache_dsv4_context::get_lid_state() const { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + return lid_state; +} + +const llama_kv_cache_dsv4_context::comp_plan & llama_kv_cache_dsv4_context::get_csa_plan() const { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + static const comp_plan empty; + if (plans_csa.empty()) { + return empty; + } + + return plans_csa[i_next]; +} + +const llama_kv_cache_dsv4_context::comp_plan & llama_kv_cache_dsv4_context::get_hca_plan() const { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + static const comp_plan empty; + if (plans_hca.empty()) { + return empty; + } + + return plans_hca[i_next]; +} + +const llama_kv_cache_dsv4_context::comp_plan & llama_kv_cache_dsv4_context::get_lid_plan() const { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + static const comp_plan empty; + if (plans_lid.empty()) { + return empty; + } + + return plans_lid[i_next]; +} + +const llama_kv_cache_dsv4_context::comp_plan & llama_kv_cache_dsv4_context::get_csa_plan(const llama_ubatch & ubatch) const { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + if (!reserve_plans) { + return get_csa_plan(); + } + + reserve_plan_csa = dsv4_build_reserve_comp_plan( + ubatch, DSV4_CSA_RATIO, true, + csa_state->get_state_size(), get_csa()->get_n_kv(), csa_state->get_n_stream()); + + return reserve_plan_csa; +} + +const llama_kv_cache_dsv4_context::comp_plan & llama_kv_cache_dsv4_context::get_hca_plan(const llama_ubatch & ubatch) const { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + if (!reserve_plans) { + return get_hca_plan(); + } + + reserve_plan_hca = dsv4_build_reserve_comp_plan( + ubatch, DSV4_HCA_RATIO, false, + 
hca_state->get_state_size(), get_hca()->get_n_kv(), hca_state->get_n_stream()); + + return reserve_plan_hca; +} + +const llama_kv_cache_dsv4_context::comp_plan & llama_kv_cache_dsv4_context::get_lid_plan(const llama_ubatch & ubatch) const { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + if (!reserve_plans) { + return get_lid_plan(); + } + + reserve_plan_lid = dsv4_build_reserve_comp_plan( + ubatch, DSV4_CSA_RATIO, true, + lid_state->get_state_size(), get_lid()->get_n_kv(), lid_state->get_n_stream()); + + return reserve_plan_lid; +} diff --git a/src/llama-kv-cache-dsv4.h b/src/llama-kv-cache-dsv4.h new file mode 100644 index 000000000000..3be49cd97e3a --- /dev/null +++ b/src/llama-kv-cache-dsv4.h @@ -0,0 +1,362 @@ +#pragma once + +#include "llama-kv-cache.h" +#include "llama-kv-cache-iswa.h" + +#include +#include +#include +#include + +class llama_dsv4_comp_state { +public: + llama_dsv4_comp_state( + const llama_model & model, + bool offload, + bool unified, + uint32_t n_seq_max, + uint32_t ratio, + uint32_t state_size, + uint32_t n_embd_state, + const char * name, + const llama_memory_i::layer_filter_cb & filter); + + void clear(bool data); + + uint32_t get_ratio() const; + uint32_t get_state_size() const; + uint32_t get_n_stream() const; + + std::map memory_breakdown() const; + + void state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const; + void state_read (llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags); + + ggml_tensor * get_kv (ggml_context * ctx, int32_t il) const; + ggml_tensor * get_score(ggml_context * ctx, int32_t il) const; + + ggml_tensor * cpy_kv (ggml_context * ctx, ggml_tensor * cur, ggml_tensor * idxs, int32_t il) const; + ggml_tensor * cpy_score(ggml_context * ctx, ggml_tensor * cur, ggml_tensor * idxs, int32_t il) const; + +private: + struct layer { + uint32_t il; + + ggml_tensor * kv; + ggml_tensor * score; + }; + + const uint32_t ratio; + const uint32_t state_size; + const 
uint32_t n_embd_state; + const uint32_t n_stream; + + std::vector> ctxs_bufs; + + std::vector layers; + + std::unordered_map map_layer_ids; + + size_t total_size() const; +}; + +// +// llama_kv_cache_dsv4 +// + +// DSV4 uses a normal raw/SWA token cache plus compressed K-only block caches. +// The compressed caches are storage only; DSV4-specific visibility and block +// planning are handled by llama_kv_cache_dsv4_context / llm_graph_input_dsv4. + +class llama_kv_cache_dsv4 : public llama_memory_i { +public: + llama_kv_cache_dsv4( + const llama_model & model, + ggml_type type_k, + ggml_type type_v, + bool v_trans, + bool offload, + bool swa_full, + bool unified, + uint32_t kv_size, + uint32_t n_seq_max, + uint32_t n_ubatch, + uint32_t n_pad, + const layer_filter_cb & filter, + const layer_reuse_cb & reuse); + + ~llama_kv_cache_dsv4() = default; + + // + // llama_memory_i + // + + llama_memory_context_ptr init_batch( + llama_batch_allocr & balloc, + uint32_t n_ubatch, + bool embd_all) override; + + llama_memory_context_ptr init_full() override; + + llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override; + + bool get_can_shift() const override; + + void clear(bool data) override; + + bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; + void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; + void seq_keep(llama_seq_id seq_id) override; + void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override; + void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override; + + llama_pos seq_pos_min(llama_seq_id seq_id) const override; + llama_pos seq_pos_max(llama_seq_id seq_id) const override; + + std::map memory_breakdown() const override; + + void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override; + void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, 
llama_state_seq_flags flags = 0) override; + + // + // llama_kv_cache_dsv4 specific API + // + + llama_kv_cache_iswa * get_raw() const; + llama_kv_cache * get_csa() const; + llama_kv_cache * get_hca() const; + llama_kv_cache * get_lid() const; + llama_dsv4_comp_state * get_csa_state() const; + llama_dsv4_comp_state * get_hca_state() const; + llama_dsv4_comp_state * get_lid_state() const; + +private: + llama_hparams hparams_raw; + llama_hparams hparams_csa; + llama_hparams hparams_hca; + llama_hparams hparams_lid; + + const uint32_t n_seq_max; + + std::unique_ptr kv_raw; + std::unique_ptr kv_csa; + std::unique_ptr kv_hca; + std::unique_ptr kv_lid; + std::unique_ptr csa_state; + std::unique_ptr hca_state; + std::unique_ptr lid_state; + + void clear_compressed(bool data); +}; + +// DSV4 raw attention only uses the SWA half of kv_raw. The base half is kept +// for generic ISWA bookkeeping, but it has no DSV4 layers to expose here. +class llama_kv_cache_dsv4_raw_context : public llama_memory_context_i { +public: + using slot_info_vec_t = llama_kv_cache::slot_info_vec_t; + + llama_kv_cache_dsv4_raw_context(llama_kv_cache_iswa * kv); + + llama_kv_cache_dsv4_raw_context( + llama_kv_cache_iswa * kv, + llama_context * lctx, + bool optimize); + + llama_kv_cache_dsv4_raw_context( + llama_kv_cache_iswa * kv, + slot_info_vec_t sinfos_base_write, + slot_info_vec_t sinfos_swa_write, + slot_info_vec_t sinfos_swa_read, + std::vector ubatches, + std::vector ubatches_write); + + bool next() override; + bool apply() override; + + llama_memory_status get_status() const override; + const llama_ubatch & get_ubatch() const override; + + uint32_t get_n_kv() const; + uint32_t get_n_write() const; + + ggml_tensor * get_k(ggml_context * ctx, int32_t il) const; + ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const; + + ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const; + ggml_tensor * 
build_input_k_rot(ggml_context * ctx) const; + + void set_input_k_idxs(ggml_tensor * dst) const; + void set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const; + void set_input_k_rot(ggml_tensor * dst) const; + +private: + size_t i_next = 0; + + llama_kv_cache * kv_swa = nullptr; + + slot_info_vec_t sinfos_write; + slot_info_vec_t sinfos_read; + std::vector ubatches; + std::vector ubatches_write; + + const llama_memory_context_ptr ctx_base_mem; + const llama_memory_context_ptr ctx_swa_mem; + + uint32_t n_kv = 0; + + const llama_memory_status status; +}; + +// DSV4 compressed KV rows are graph outputs, not normal token KV writes. +// Keep a small context that exposes K tensors without generic apply() semantics. +class llama_kv_cache_dsv4_comp_context { +public: + using slot_info_vec_t = llama_kv_cache::slot_info_vec_t; + + llama_kv_cache_dsv4_comp_context(llama_kv_cache * kv); + + llama_kv_cache_dsv4_comp_context( + llama_kv_cache * kv, + slot_info_vec_t sinfos, + std::vector ubatches); + + bool next(); + + uint32_t get_n_kv() const; + + ggml_tensor * get_k(ggml_context * ctx, int32_t il) const; + ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const; + + ggml_tensor * build_input_k_rot(ggml_context * ctx) const; + void set_input_k_rot(ggml_tensor * dst) const; + +private: + llama_kv_cache * kv; + + size_t i_cur = 0; + slot_info_vec_t sinfos; + std::vector ubatches; + + uint32_t n_kv; +}; + +class llama_kv_cache_dsv4_context : public llama_memory_context_i { +public: + using slot_info_vec_t = llama_kv_cache::slot_info_vec_t; + + struct comp_plan { + // Per-ubatch recipe for updating compressor state, committing completed + // compressed rows, and masking the compressed attention source. + + // APE row ids, i.e. pos % ratio, for the compressor-state updates. 
+ std::vector state_pos; + + // Current-ubatch source row ids and unique persistent-state + // destination row ids for deterministic ring-state updates. + std::vector state_persist_src_idxs; + std::vector state_persist_dst_idxs; + + // Flattened source row ids used for state-backed commits. Source rows + // index the graph-local [persistent_state | current_ubatch_scratch] + // tensor. For overlapped compression the first half is previous rows + // and the second half is current rows; a final synthetic zero/-inf row + // may be addressed for the first block's previous half. + std::vector state_read_idxs; + + // Final compressed-cache row ids written by state-backed commits. + // A non-boundary CSA/LID decode step can target a masked scratch row. + std::vector state_write_idxs; + + // RoPE positions for state-backed commits. + std::vector state_write_pos; + + // Number of completed compressed rows visible for each query token. + std::vector n_visible; + + // Number of streams used by the attention graph for this ubatch. + int64_t n_stream = 1; + + // Graph-width for compressed rows. This can be larger than n_visible + // so masked padding rows do not force a new graph at every CSA block. 
+ int64_t n_kv = 0; + }; + + llama_kv_cache_dsv4_context(llama_memory_status status); + + llama_kv_cache_dsv4_context( + llama_kv_cache_dsv4 * kv); + + llama_kv_cache_dsv4_context( + llama_kv_cache_dsv4 * kv, + llama_context * lctx, + bool optimize); + + llama_kv_cache_dsv4_context( + llama_kv_cache_dsv4 * kv, + slot_info_vec_t sinfos_raw_base_write, + slot_info_vec_t sinfos_raw_swa_write, + slot_info_vec_t sinfos_raw_swa_read, + std::vector ubatches, + std::vector ubatches_raw); + + virtual ~llama_kv_cache_dsv4_context(); + + // + // llama_memory_context_i + // + + bool next() override; + bool apply() override; + + llama_memory_status get_status() const override; + const llama_ubatch & get_ubatch() const override; + + // + // llama_kv_cache_dsv4_context specific API + // + + const llama_kv_cache_dsv4_raw_context * get_raw() const; + const llama_kv_cache_dsv4_comp_context * get_csa() const; + const llama_kv_cache_dsv4_comp_context * get_hca() const; + const llama_kv_cache_dsv4_comp_context * get_lid() const; + const llama_dsv4_comp_state * get_csa_state() const; + const llama_dsv4_comp_state * get_hca_state() const; + const llama_dsv4_comp_state * get_lid_state() const; + + const comp_plan & get_csa_plan() const; + const comp_plan & get_hca_plan() const; + const comp_plan & get_lid_plan() const; + + const comp_plan & get_csa_plan(const llama_ubatch & ubatch) const; + const comp_plan & get_hca_plan(const llama_ubatch & ubatch) const; + const comp_plan & get_lid_plan(const llama_ubatch & ubatch) const; + +private: + size_t i_next = 0; + + std::vector ubatches; + + std::vector plans_csa; + std::vector plans_hca; + std::vector plans_lid; + + const std::unique_ptr ctx_raw; + const llama_memory_context_ptr ctx_csa_mem; + const llama_memory_context_ptr ctx_hca_mem; + const llama_memory_context_ptr ctx_lid_mem; + + const std::unique_ptr ctx_csa; + const std::unique_ptr ctx_hca; + const std::unique_ptr ctx_lid; + + const llama_dsv4_comp_state * csa_state = nullptr; + const 
llama_dsv4_comp_state * hca_state = nullptr; + const llama_dsv4_comp_state * lid_state = nullptr; + + bool reserve_plans = false; + mutable comp_plan reserve_plan_csa; + mutable comp_plan reserve_plan_hca; + mutable comp_plan reserve_plan_lid; + + const llama_memory_status status; +}; diff --git a/src/llama-kv-cache-iswa.cpp b/src/llama-kv-cache-iswa.cpp index aa1b1b72ebe6..2fcf238d9173 100644 --- a/src/llama-kv-cache-iswa.cpp +++ b/src/llama-kv-cache-iswa.cpp @@ -26,7 +26,28 @@ llama_kv_cache_iswa::llama_kv_cache_iswa( llama_memory_t mem_other, const layer_filter_cb & filter, const layer_reuse_cb & reuse, - const layer_share_cb & share) : hparams(model.hparams), unified(unified) { + const layer_share_cb & share) : + llama_kv_cache_iswa(model, model.hparams, type_k, type_v, v_trans, offload, swa_full, unified, + kv_size, n_seq_max, n_ubatch, n_pad, mem_other, filter, reuse, share) { +} + +llama_kv_cache_iswa::llama_kv_cache_iswa( + const llama_model & model, + const llama_hparams & hparams, + ggml_type type_k, + ggml_type type_v, + bool v_trans, + bool offload, + bool swa_full, + bool unified, + uint32_t kv_size, + uint32_t n_seq_max, + uint32_t n_ubatch, + uint32_t n_pad, + llama_memory_t mem_other, + const layer_filter_cb & filter, + const layer_reuse_cb & reuse, + const layer_share_cb & share) : unified(unified) { // chain filters const layer_filter_cb filter_base = [&](int32_t il) { diff --git a/src/llama-kv-cache-iswa.h b/src/llama-kv-cache-iswa.h index dfafc1ef510b..7dab6eaa82c8 100644 --- a/src/llama-kv-cache-iswa.h +++ b/src/llama-kv-cache-iswa.h @@ -30,6 +30,24 @@ class llama_kv_cache_iswa : public llama_memory_i { const layer_reuse_cb & reuse, const layer_share_cb & share); + llama_kv_cache_iswa( + const llama_model & model, + const llama_hparams & hparams, + ggml_type type_k, + ggml_type type_v, + bool v_trans, + bool offload, + bool swa_full, + bool unified, + uint32_t kv_size, + uint32_t n_seq_max, + uint32_t n_ubatch, + uint32_t n_pad, + 
llama_memory_t mem_other, + const layer_filter_cb & filter, + const layer_reuse_cb & reuse, + const layer_share_cb & share); + ~llama_kv_cache_iswa() = default; // @@ -73,8 +91,6 @@ class llama_kv_cache_iswa : public llama_memory_i { llama_kv_cache * get_swa () const; private: - const llama_hparams & hparams; - const bool unified; std::unique_ptr kv_base; diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 2802103bdd82..12bf5c37914d 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -211,10 +211,12 @@ llama_kv_cache::llama_kv_cache( n_embd_head_k_all = -1; } - if (n_embd_head_v_all == 0) { - n_embd_head_v_all = (int32_t) hparams.n_embd_head_v(il); - } else if (n_embd_head_v_all > 0 && n_embd_head_v_all != (int32_t) hparams.n_embd_head_v(il)) { - n_embd_head_v_all = -1; + if (!is_mla) { + if (n_embd_head_v_all == 0) { + n_embd_head_v_all = (int32_t) hparams.n_embd_head_v(il); + } else if (n_embd_head_v_all > 0 && n_embd_head_v_all != (int32_t) hparams.n_embd_head_v(il)) { + n_embd_head_v_all = -1; + } } // [TAG_V_CACHE_VARIABLE] @@ -336,8 +338,9 @@ llama_kv_cache::llama_kv_cache( ggml_is_quantized(type_k) && hparams.n_embd_head_k() % 64 == 0; - // always create Hadamard rotation tensors for DeepSeek V3.2 DSA lightning indexer - if (model.arch == LLM_ARCH_DEEPSEEK32 && hparams.n_embd_head_k_full == hparams.indexer_head_size) { + // always create Hadamard rotation tensors for DeepSeek lightning indexers + if ((model.arch == LLM_ARCH_DEEPSEEK32 || model.arch == LLM_ARCH_DEEPSEEK4) && + hparams.n_embd_head_k_full == hparams.indexer_head_size) { attn_rot_k = true; } @@ -1220,6 +1223,23 @@ ggml_type llama_kv_cache::type_v() const { return layers[0].v->type; } +std::vector llama_kv_cache::get_layer_ids() const { + std::vector res; + res.reserve(layers.size()); + + for (const auto & layer : layers) { + res.push_back(layer.il); + } + + return res; +} + +ggml_tensor * llama_kv_cache::get_k_storage(int32_t il) const { + const int32_t ikv = 
map_layer_ids.at(il); + + return layers[ikv].k; +} + uint32_t llama_kv_cache::get_n_kv(const slot_info & sinfo) const { uint32_t result = 0; diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index 3d68f98c1424..531d99dbdec1 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -161,6 +161,9 @@ class llama_kv_cache : public llama_memory_i { ggml_type type_k() const; ggml_type type_v() const; + std::vector get_layer_ids() const; + ggml_tensor * get_k_storage(int32_t il) const; + // // graph_build API // diff --git a/src/llama-model-loader.cpp b/src/llama-model-loader.cpp index 474cabdfc095..229da5076d53 100644 --- a/src/llama-model-loader.cpp +++ b/src/llama-model-loader.cpp @@ -294,6 +294,8 @@ namespace GGUFMeta { } template bool llama_model_loader::get_arr_n(enum llm_kv kid, uint32_t & result, bool required); + template std::enable_if::value, bool>::type + llama_model_loader::get_arr_n(const std::string & key, uint32_t & result, bool required); template bool llama_model_loader::get_arr(const std::string & key, std::vector & result, bool required) { @@ -395,6 +397,10 @@ namespace GGUFMeta { template bool llama_model_loader::get_arr>(enum llm_kv kid, std::vector & result, bool required); template bool llama_model_loader::get_arr>(enum llm_kv kid, std::array & result, bool required); template bool llama_model_loader::get_arr>(enum llm_kv kid, std::vector & result, bool required); + template bool llama_model_loader::get_arr( + const std::string & key, + std::array & result, + bool required); template bool llama_model_loader::get_key(const std::string & key, T & result, bool required) { diff --git a/src/llama-model.cpp b/src/llama-model.cpp index eaf29505c33a..d58ebac28b9b 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -11,6 +11,7 @@ #include "llama-kv-cache.h" #include "llama-kv-cache-iswa.h" #include "llama-kv-cache-dsa.h" +#include "llama-kv-cache-dsv4.h" #include "llama-memory-hybrid.h" #include "llama-memory-hybrid-iswa.h" #include 
"llama-memory-recurrent.h" @@ -181,6 +182,8 @@ static llama_model * llama_model_mapping(llm_arch arch, const llama_model_params return new llama_model_deepseek2ocr(params); case LLM_ARCH_DEEPSEEK32: return new llama_model_deepseek32(params); + case LLM_ARCH_DEEPSEEK4: + return new llama_model_deepseek4(params); case LLM_ARCH_GLM_DSA: return new llama_model_glm_dsa(params); case LLM_ARCH_MISTRAL4: @@ -817,6 +820,7 @@ static const char * llama_expert_gating_func_name(llama_expert_gating_func_type switch (type) { case LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX: return "softmax"; case LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID: return "sigmoid"; + case LLAMA_EXPERT_GATING_FUNC_TYPE_SQRT_SOFTPLUS: return "sqrtsoftplus"; default: return "unknown"; } } @@ -2156,7 +2160,24 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, } } - if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { + if (arch == LLM_ARCH_DEEPSEEK4) { + GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE); + + res = new llama_kv_cache_dsv4( + *this, + params.type_k, + params.type_v, + !cparams.flash_attn, + cparams.offload_kqv, + params.swa_full, + cparams.kv_unified, + cparams.n_ctx_seq, + cparams.n_seq_max, + cparams.n_ubatch, + 1, + filter, + reuse); + } else if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { GGML_ASSERT(hparams.is_swa_any()); if (arch == LLM_ARCH_GEMMA4_ASSISTANT) { @@ -2328,6 +2349,11 @@ int32_t llama_model_n_head_kv(const llama_model * model) { } int32_t llama_model_n_swa(const llama_model * model) { + // dsv4 kv-cache has SWA but it cannot be used as a rollback because of + // other compression ratios, so we return 0 here + if (model->arch == LLM_ARCH_DEEPSEEK4) { + return 0; + } return model->hparams.n_swa; } @@ -2409,6 +2435,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_DEEPSEEK2: case LLM_ARCH_DEEPSEEK2OCR: case LLM_ARCH_DEEPSEEK32: + case LLM_ARCH_DEEPSEEK4: case LLM_ARCH_PLM: case LLM_ARCH_CHATGLM: case LLM_ARCH_GRANITE: diff 
--git a/src/llama-model.h b/src/llama-model.h index 77d8d3b6258a..4800d2928c52 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -255,9 +255,11 @@ struct llama_layer { struct ggml_tensor * wq_b = nullptr; struct ggml_tensor * wkv_a_mqa = nullptr; struct ggml_tensor * wkv_b = nullptr; + struct ggml_tensor * wkv = nullptr; struct ggml_tensor * wk_b = nullptr; struct ggml_tensor * wv_b = nullptr; struct ggml_tensor * wqkv_b = nullptr; + struct ggml_tensor * wo_a = nullptr; struct ggml_tensor * wo_b = nullptr; struct ggml_tensor * wq_cross = nullptr; struct ggml_tensor * wk_cross = nullptr; @@ -333,6 +335,7 @@ struct llama_layer { struct ggml_tensor * ffn_up_b = nullptr; // b3 struct ggml_tensor * ffn_act = nullptr; struct ggml_tensor * ffn_exp_probs_b = nullptr; + struct ggml_tensor * ffn_gate_tid2eid = nullptr; // mamba proj struct ggml_tensor * ssm_in = nullptr; @@ -463,6 +466,23 @@ struct llama_layer { // openai-moe struct ggml_tensor * attn_sinks = nullptr; + // DeepSeek-V4 + struct ggml_tensor * attn_kv_norm = nullptr; + struct ggml_tensor * hc_attn_fn = nullptr; + struct ggml_tensor * hc_attn_base = nullptr; + struct ggml_tensor * hc_attn_scale = nullptr; + struct ggml_tensor * hc_ffn_fn = nullptr; + struct ggml_tensor * hc_ffn_base = nullptr; + struct ggml_tensor * hc_ffn_scale = nullptr; + struct ggml_tensor * attn_comp_wkv = nullptr; + struct ggml_tensor * attn_comp_wgate = nullptr; + struct ggml_tensor * attn_comp_ape = nullptr; + struct ggml_tensor * attn_comp_norm = nullptr; + struct ggml_tensor * indexer_comp_wkv = nullptr; + struct ggml_tensor * indexer_comp_wgate = nullptr; + struct ggml_tensor * indexer_comp_ape = nullptr; + struct ggml_tensor * indexer_comp_norm = nullptr; + // cogvlm struct ggml_tensor * visexp_attn_wqkv = nullptr; struct ggml_tensor * visexp_attn_wo = nullptr; @@ -553,6 +573,11 @@ struct llama_model { struct ggml_tensor * nextn_proj_pre = nullptr; struct ggml_tensor * nextn_proj_post = nullptr; + // DeepSeek-V4 + struct 
ggml_tensor * hc_head_fn = nullptr; + struct ggml_tensor * hc_head_base = nullptr; + struct ggml_tensor * hc_head_scale = nullptr; + // classifier struct ggml_tensor * cls = nullptr; struct ggml_tensor * cls_b = nullptr; diff --git a/src/models/deepseek4.cpp b/src/models/deepseek4.cpp new file mode 100644 index 000000000000..2dd5eb0b123e --- /dev/null +++ b/src/models/deepseek4.cpp @@ -0,0 +1,1190 @@ +#include "models.h" + +#include "llama-kv-cache-dsv4.h" + +#include +#include +#include +#include + +static std::string dsv4_kv(const char * suffix) { + return std::string("deepseek4.") + suffix; +} + +static float dsv4_rope_attn_factor(float freq_scale, float ext_factor) { + if (ext_factor == 0.0f) { + return 1.0f; + } + + return 1.0f / (1.0f + 0.1f*logf(1.0f/freq_scale)); +} + +void llama_model_deepseek4::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q); + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm); + ml.get_key_or_arr(LLM_KV_SWIGLU_CLAMP_EXP, hparams.swiglu_clamp_exp, hparams.n_layer()); + if (!ml.get_key_or_arr(LLM_KV_SWIGLU_CLAMP_SHEXP, hparams.swiglu_clamp_shexp, hparams.n_layer(), 0)) { + hparams.swiglu_clamp_shexp = hparams.swiglu_clamp_exp; + } + + ml.get_key(LLM_KV_ATTENTION_INDEXER_HEAD_COUNT, hparams.indexer_n_head); + ml.get_key(LLM_KV_ATTENTION_INDEXER_KEY_LENGTH, hparams.indexer_head_size); + ml.get_key(LLM_KV_ATTENTION_INDEXER_TOP_K, hparams.indexer_top_k); + + ml.get_key(dsv4_kv("attention.output_group_count"), hparams.dsv4_o_group_count); + ml.get_key(dsv4_kv("attention.output_lora_rank"), hparams.dsv4_o_lora_rank); + 
ml.get_key(dsv4_kv("attention.compress_rope_freq_base"), hparams.dsv4_compress_rope_base); + ml.get_key(dsv4_kv("hyper_connection.count"), hparams.dsv4_hc_mult); + ml.get_key(dsv4_kv("hyper_connection.sinkhorn_iterations"), hparams.dsv4_hc_sinkhorn_iters); + ml.get_key(dsv4_kv("hyper_connection.epsilon"), hparams.dsv4_hc_eps); + ml.get_key(dsv4_kv("hash_layer_count"), hparams.dsv4_hash_layer_count); + + uint32_t n_compress_ratios = 0; + ml.get_arr_n(dsv4_kv("attention.compress_ratios"), n_compress_ratios); + if (n_compress_ratios < hparams.n_layer()) { + throw std::runtime_error("DeepSeek-V4 compress_ratios is shorter than block_count"); + } + ml.get_arr(dsv4_kv("attention.compress_ratios"), hparams.dsv4_compress_ratios); + + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func); + if (hparams.expert_gating_func != LLAMA_EXPERT_GATING_FUNC_TYPE_SQRT_SOFTPLUS) { + throw std::runtime_error("DeepSeek-V4 loader currently expects sqrtsoftplus MoE scoring"); + } + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + hparams.set_swa_pattern(0); + + switch (hparams.n_layer()) { + case 43: type = LLM_TYPE_UNKNOWN; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_deepseek4::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + const int64_t q_lora_rank = hparams.n_lora_q; + const int64_t n_ff_exp = hparams.n_ff_exp; + const int64_t n_expert_shared = hparams.n_expert_shared; + + const int64_t n_embd_head = hparams.n_embd_head_k(); + const int64_t o_groups = hparams.dsv4_o_group_count; + const int64_t o_lora_rank = hparams.dsv4_o_lora_rank; + const int64_t hc_mult = hparams.dsv4_hc_mult; + const int64_t hc_dim = hc_mult * n_embd; + const int64_t hc_mix_dim = (2 + hc_mult) * hc_mult; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + 
hc_head_fn = create_tensor(tn(LLM_TENSOR_HC_HEAD_FN, "weight"), {hc_dim, hc_mult}, 0); + hc_head_base = create_tensor(tn(LLM_TENSOR_HC_HEAD_BASE, "weight"), {hc_mult}, 0); + hc_head_scale = create_tensor(tn(LLM_TENSOR_HC_HEAD_SCALE, "weight"), {1}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_sinks = create_tensor(tn(LLM_TENSOR_ATTN_SINKS, "weight", i), {n_head}, 0); + layer.wq_a = create_tensor(tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}, 0); + layer.attn_q_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}, 0); + layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head}, 0); + layer.wkv = create_tensor(tn(LLM_TENSOR_ATTN_KV, "weight", i), {n_embd, n_embd_head}, 0); + layer.attn_kv_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_NORM, "weight", i), {n_embd_head}, 0); + layer.wo_a = create_tensor(tn(LLM_TENSOR_ATTN_OUT_A, "weight", i), {n_head * n_embd_head / o_groups, o_lora_rank * o_groups}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT_B, "weight", i), {o_groups * o_lora_rank, n_embd}, 0); + + layer.hc_attn_fn = create_tensor(tn(LLM_TENSOR_HC_ATTN_FN, "weight", i), {hc_dim, hc_mix_dim}, 0); + layer.hc_attn_base = create_tensor(tn(LLM_TENSOR_HC_ATTN_BASE, "weight", i), {hc_mix_dim}, 0); + layer.hc_attn_scale = create_tensor(tn(LLM_TENSOR_HC_ATTN_SCALE, "weight", i), {3}, 0); + layer.hc_ffn_fn = create_tensor(tn(LLM_TENSOR_HC_FFN_FN, "weight", i), {hc_dim, hc_mix_dim}, 0); + layer.hc_ffn_base = create_tensor(tn(LLM_TENSOR_HC_FFN_BASE, "weight", i), {hc_mix_dim}, 0); + layer.hc_ffn_scale = create_tensor(tn(LLM_TENSOR_HC_FFN_SCALE, "weight", i), {3}, 0); + + const int64_t ratio = hparams.dsv4_compress_ratios[i]; + if (ratio != 0) { + const int64_t coff = ratio == 4 ? 
2 : 1; + + layer.attn_comp_wkv = create_tensor(tn(LLM_TENSOR_ATTN_COMPRESSOR_WKV, "weight", i), {n_embd, coff * n_embd_head}, 0); + layer.attn_comp_wgate = create_tensor(tn(LLM_TENSOR_ATTN_COMPRESSOR_WGATE, "weight", i), {n_embd, coff * n_embd_head}, 0); + layer.attn_comp_ape = create_tensor(tn(LLM_TENSOR_ATTN_COMPRESSOR_APE, "weight", i), {coff * n_embd_head, ratio}, 0); + layer.attn_comp_norm = create_tensor(tn(LLM_TENSOR_ATTN_COMPRESSOR_NORM, "weight", i), {n_embd_head}, 0); + + if (ratio == 4) { + const int64_t n_embd_indexer = hparams.indexer_head_size; + + layer.indexer_proj = create_tensor(tn(LLM_TENSOR_INDEXER_PROJ, "weight", i), {n_embd, hparams.indexer_n_head}, 0); + layer.indexer_attn_q_b = create_tensor(tn(LLM_TENSOR_INDEXER_ATTN_Q_B, "weight", i), {q_lora_rank, hparams.indexer_n_head * n_embd_indexer}, 0); + + layer.indexer_comp_wkv = create_tensor(tn(LLM_TENSOR_INDEXER_COMPRESSOR_WKV, "weight", i), {n_embd, 2 * n_embd_indexer}, 0); + layer.indexer_comp_wgate = create_tensor(tn(LLM_TENSOR_INDEXER_COMPRESSOR_WGATE, "weight", i), {n_embd, 2 * n_embd_indexer}, 0); + layer.indexer_comp_ape = create_tensor(tn(LLM_TENSOR_INDEXER_COMPRESSOR_APE, "weight", i), {2 * n_embd_indexer, ratio}, 0); + layer.indexer_comp_norm = create_tensor(tn(LLM_TENSOR_INDEXER_COMPRESSOR_NORM, "weight", i), {n_embd_indexer}, 0); + } else if (ratio != 128) { + throw std::runtime_error("DeepSeek-V4 loader only supports compression ratios 0, 4, and 128"); + } + } + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + if ((uint32_t) i < hparams.dsv4_hash_layer_count) { + layer.ffn_gate_tid2eid = create_tensor(tn(LLM_TENSOR_FFN_GATE_TID2EID, "weight", i), {n_expert_used, n_vocab}, 0); + } else { + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, 0); + } + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate_exps = 
create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); + + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_exp * n_expert_shared, n_embd }, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); + } +} + +std::unique_ptr llama_model_deepseek4::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +static size_t dsv4_elem_offset(const ggml_tensor * t, int64_t i) { + return ggml_row_size(t->type, i); +} + +static ggml_tensor * dsv4_view_1d(ggml_context * ctx, ggml_tensor * t, int64_t ne0, int64_t i0) { + return ggml_view_1d(ctx, t, ne0, dsv4_elem_offset(t, i0)); +} + +static ggml_tensor * dsv4_view_2d( + ggml_context * ctx, + ggml_tensor * t, + int64_t ne0, + int64_t ne1, + int64_t i0) { + return ggml_view_2d(ctx, t, ne0, ne1, t->nb[1], dsv4_elem_offset(t, i0)); +} + +static ggml_tensor * dsv4_append_zero_row(ggml_context * ctx, ggml_tensor * t, bool neg_inf) { + ggml_tensor * row = ggml_view_1d(ctx, t, t->ne[0], 0); + row = neg_inf ? ggml_scale_bias(ctx, row, 0.0f, -INFINITY) : ggml_scale(ctx, row, 0.0f); + row = ggml_reshape_2d(ctx, row, t->ne[0], 1); + + return ggml_concat(ctx, t, row, 1); +} + +static ggml_tensor * dsv4_with_zero_dep(ggml_context * ctx, ggml_tensor * t, ggml_tensor * dep) { + if (dep == nullptr) { + return t; + } + + ggml_tensor * zero = ggml_scale(ctx, ggml_sum(ctx, dep), 0.0f); + return ggml_add(ctx, t, zero); +} + +// Raw SWA K is stored once, but compressed K/masks can carry a stream axis. 
+// Repeat raw K at graph build time before concatenating raw and compressed K. +static ggml_tensor * dsv4_repeat_streams(ggml_context * ctx, ggml_tensor * t, int64_t n_stream) { + if (t->ne[3] == n_stream) { + return t; + } + + GGML_ASSERT(t->ne[3] == 1); + return ggml_repeat_4d(ctx, t, t->ne[0], t->ne[1], t->ne[2], n_stream); +} + +static ggml_tensor * dsv4_build_kq_zero_bias( + ggml_context * ctx, + const llama_cparams & cparams, + ggml_tensor * kq_mask, + int64_t n_head) { + if (!cparams.kv_unified || !cparams.flash_attn || kq_mask->ne[3] == 1) { + return nullptr; + } + + // Keep multi-stream unified DSV4 on the explicit attention path. + ggml_tensor * res = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, + kq_mask->ne[0], kq_mask->ne[1], n_head, kq_mask->ne[3]); + return ggml_fill(ctx, res, 0.0f); +} + +static constexpr int64_t DSV4_CSA_RATIO = 4; +static constexpr int64_t DSV4_HCA_RATIO = 128; + +static ggml_tensor * dsv4_hc_affine( + ggml_context * ctx, + ggml_tensor * x, + ggml_tensor * scale, + ggml_tensor * base) { + x = ggml_mul(ctx, x, scale); + x = ggml_add(ctx, x, base); + return x; +} + +ggml_tensor * llama_model_deepseek4::graph::build_hc_weighted_sum( + ggml_tensor * x, + ggml_tensor * weights) const { + const int64_t hc = hparams.dsv4_hc_mult; + const int64_t nt = x->ne[2]; + + ggml_tensor * acc = nullptr; + for (int64_t ih = 0; ih < hc; ++ih) { + ggml_tensor * xh = ggml_view_2d(ctx0, x, n_embd, nt, x->nb[2], ih*x->nb[1]); + ggml_tensor * wh = ggml_view_2d(ctx0, weights, 1, nt, weights->nb[1], ih*weights->nb[0]); + + ggml_tensor * cur = ggml_mul(ctx0, xh, wh); + acc = acc ? ggml_add(ctx0, acc, cur) : cur; + } + + return acc; +} + +ggml_tensor * llama_model_deepseek4::graph::build_hc_sinkhorn( + ggml_tensor * comb, + int il) const { + GGML_UNUSED(il); + + // comb is [dst_hc, src_hc, n_tokens]. Sinkhorn follows the reference: + // row softmax over dst, one column normalization, then repeated row/column normalization. 
+ comb = ggml_soft_max(ctx0, comb); + + ggml_tensor * eps = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); + eps = ggml_fill(ctx0, eps, hparams.dsv4_hc_eps); + + comb = ggml_add(ctx0, comb, eps); + + auto norm_cols = [&]() { + ggml_tensor * comb_src_dst = ggml_cont(ctx0, ggml_permute(ctx0, comb, 1, 0, 2, 3)); + ggml_tensor * col_sum = ggml_sum_rows(ctx0, comb_src_dst); + col_sum = ggml_add(ctx0, col_sum, eps); + col_sum = ggml_permute(ctx0, col_sum, 1, 0, 2, 3); + comb = ggml_div(ctx0, comb, col_sum); + }; + + auto norm_rows = [&]() { + ggml_tensor * row_sum = ggml_sum_rows(ctx0, comb); + row_sum = ggml_add(ctx0, row_sum, eps); + comb = ggml_div(ctx0, comb, row_sum); + }; + + norm_cols(); + for (uint32_t i = 1; i < hparams.dsv4_hc_sinkhorn_iters; ++i) { + norm_rows(); + norm_cols(); + } + + return comb; +} + +ggml_tensor * llama_model_deepseek4::graph::build_hc_pre( + ggml_tensor * x, + ggml_tensor * hc_fn, + ggml_tensor * hc_scale, + ggml_tensor * hc_base, + ggml_tensor ** post, + ggml_tensor ** comb, + int il) const { + const int64_t hc = hparams.dsv4_hc_mult; + const int64_t hc_dim = hc*n_embd; + const int64_t hc_mix_dim = (2 + hc)*hc; + const int64_t nt = x->ne[2]; + + GGML_ASSERT(hc == 4); + GGML_ASSERT(hc_fn->ne[1] == hc_mix_dim); + + ggml_tensor * flat = ggml_reshape_2d(ctx0, x, hc_dim, nt); + ggml_tensor * flat_norm = ggml_rms_norm(ctx0, flat, norm_rms_eps); + ggml_tensor * mixes = ggml_mul_mat(ctx0, hc_fn, flat_norm); + cb(mixes, "hc_mixes", il); + + ggml_tensor * scale_pre = dsv4_view_1d(ctx0, hc_scale, 1, 0); + ggml_tensor * scale_post = dsv4_view_1d(ctx0, hc_scale, 1, 1); + ggml_tensor * scale_comb = dsv4_view_1d(ctx0, hc_scale, 1, 2); + + ggml_tensor * base_pre = dsv4_view_1d(ctx0, hc_base, hc, 0); + ggml_tensor * base_post = dsv4_view_1d(ctx0, hc_base, hc, hc); + ggml_tensor * base_comb = dsv4_view_1d(ctx0, hc_base, hc*hc, 2*hc); + + ggml_tensor * pre = dsv4_view_2d(ctx0, mixes, hc, nt, 0); + pre = dsv4_hc_affine(ctx0, pre, scale_pre, base_pre); + pre = 
ggml_sigmoid(ctx0, pre); + pre = ggml_scale_bias(ctx0, pre, 1.0f, hparams.dsv4_hc_eps); + cb(pre, "hc_pre", il); + + *post = dsv4_view_2d(ctx0, mixes, hc, nt, hc); + *post = dsv4_hc_affine(ctx0, *post, scale_post, base_post); + *post = ggml_sigmoid(ctx0, *post); + *post = ggml_scale(ctx0, *post, 2.0f); + cb(*post, "hc_post", il); + + *comb = dsv4_view_2d(ctx0, mixes, hc*hc, nt, 2*hc); + *comb = dsv4_hc_affine(ctx0, *comb, scale_comb, base_comb); + *comb = ggml_reshape_3d(ctx0, *comb, hc, hc, nt); + *comb = build_hc_sinkhorn(*comb, il); + cb(*comb, "hc_comb", il); + + return build_hc_weighted_sum(x, pre); +} + +ggml_tensor * llama_model_deepseek4::graph::build_hc_post( + ggml_tensor * x, + ggml_tensor * residual, + ggml_tensor * post, + ggml_tensor * comb, + int il) const { + GGML_UNUSED(il); + + const int64_t hc = hparams.dsv4_hc_mult; + const int64_t nt = x->ne[1]; + + ggml_tensor * out = nullptr; + for (int64_t dst = 0; dst < hc; ++dst) { + ggml_tensor * post_dst = ggml_view_2d(ctx0, post, 1, nt, post->nb[1], dst*post->nb[0]); + ggml_tensor * cur = ggml_mul(ctx0, x, post_dst); + + for (int64_t src = 0; src < hc; ++src) { + ggml_tensor * res_src = ggml_view_2d(ctx0, residual, n_embd, nt, residual->nb[2], src*residual->nb[1]); + ggml_tensor * comb_src_dst = ggml_view_2d(ctx0, comb, 1, nt, comb->nb[2], dst*comb->nb[0] + src*comb->nb[1]); + cur = ggml_add(ctx0, cur, ggml_mul(ctx0, res_src, comb_src_dst)); + } + + cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, nt); + out = out ? 
ggml_concat(ctx0, out, cur, 1) : cur; + } + + return out; +} + +ggml_tensor * llama_model_deepseek4::graph::build_hc_head( + ggml_tensor * x, + ggml_tensor * hc_fn, + ggml_tensor * hc_scale, + ggml_tensor * hc_base) const { + const int64_t hc = hparams.dsv4_hc_mult; + const int64_t hc_dim = hc*n_embd; + const int64_t nt = x->ne[2]; + + ggml_tensor * flat = ggml_reshape_2d(ctx0, x, hc_dim, nt); + ggml_tensor * flat_norm = ggml_rms_norm(ctx0, flat, norm_rms_eps); + ggml_tensor * mixes = ggml_mul_mat(ctx0, hc_fn, flat_norm); + cb(mixes, "hc_head_mixes", -1); + + ggml_tensor * pre = dsv4_hc_affine(ctx0, mixes, hc_scale, hc_base); + pre = ggml_sigmoid(ctx0, pre); + pre = ggml_scale_bias(ctx0, pre, 1.0f, hparams.dsv4_hc_eps); + cb(pre, "hc_head_pre", -1); + + return build_hc_weighted_sum(x, pre); +} + +ggml_tensor * llama_model_deepseek4::graph::build_hca_compressed_kv_from_state( + ggml_tensor * kv_state, + ggml_tensor * score_state, + ggml_tensor * state_read_idxs, + ggml_tensor * comp_pos, + ggml_tensor * norm, + int64_t n_embd_head, + const char * name, + int il) const { + const int64_t n_embd_head_rope = hparams.n_rot(); + const int64_t n_embd_head_nope = n_embd_head - n_embd_head_rope; + const int64_t n_blocks = comp_pos ? 
comp_pos->ne[0] : 0; + + GGML_ASSERT(n_blocks > 0); + GGML_ASSERT(state_read_idxs); + GGML_ASSERT(state_read_idxs->ne[0] == DSV4_HCA_RATIO*n_blocks); + GGML_ASSERT(n_embd_head >= n_embd_head_rope); + + ggml_tensor * kv = ggml_get_rows(ctx0, kv_state, state_read_idxs); + kv = ggml_reshape_3d(ctx0, kv, n_embd_head, DSV4_HCA_RATIO, n_blocks); + cb(kv, name, il); + + ggml_tensor * score = ggml_get_rows(ctx0, score_state, state_read_idxs); + score = ggml_reshape_3d(ctx0, score, n_embd_head, DSV4_HCA_RATIO, n_blocks); + cb(score, name, il); + + ggml_tensor * values = ggml_cont(ctx0, ggml_permute(ctx0, kv, 1, 0, 2, 3)); + ggml_tensor * scores = ggml_cont(ctx0, ggml_permute(ctx0, score, 1, 0, 2, 3)); + + ggml_tensor * weights = ggml_soft_max(ctx0, scores); + ggml_tensor * comp = ggml_mul(ctx0, values, weights); + comp = ggml_sum_rows(ctx0, comp); + comp = ggml_cont(ctx0, ggml_permute(ctx0, comp, 1, 0, 2, 3)); + cb(comp, name, il); + + comp = build_norm(comp, norm, nullptr, LLM_NORM_RMS, il); + cb(comp, name, il); + + ggml_tensor * comp_nope = ggml_view_3d(ctx0, comp, n_embd_head_nope, 1, n_blocks, + ggml_row_size(comp->type, n_embd_head), + ggml_row_size(comp->type, n_embd_head), + 0); + ggml_tensor * comp_pe = ggml_view_3d(ctx0, comp, n_embd_head_rope, 1, n_blocks, + ggml_row_size(comp->type, n_embd_head), + ggml_row_size(comp->type, n_embd_head), + ggml_row_size(comp->type, n_embd_head_nope)); + + comp_pe = ggml_rope_ext(ctx0, comp_pe, comp_pos, nullptr, n_embd_head_rope, rope_type, n_ctx_orig, + hparams.dsv4_compress_rope_base, freq_scale, ext_factor, + dsv4_rope_attn_factor(freq_scale, ext_factor), beta_fast, beta_slow); + cb(comp_pe, name, il); + + comp = ggml_concat(ctx0, comp_nope, comp_pe, 0); + cb(comp, name, il); + + return comp; +} + +ggml_tensor * llama_model_deepseek4::graph::build_overlap_compressed_kv_from_state( + ggml_tensor * kv_state, + ggml_tensor * score_state, + ggml_tensor * state_read_idxs, + ggml_tensor * comp_pos, + ggml_tensor * norm, + int64_t 
ratio, + int64_t n_embd_head, + const char * name, + int il) const { + const int64_t n_embd_head_rope = hparams.n_rot(); + const int64_t n_embd_head_nope = n_embd_head - n_embd_head_rope; + const int64_t n_blocks = comp_pos ? comp_pos->ne[0] : 0; + + GGML_ASSERT(n_blocks > 0); + GGML_ASSERT(state_read_idxs); + GGML_ASSERT(state_read_idxs->ne[0] == 2*ratio*n_blocks); + GGML_ASSERT(kv_state->ne[0] == 2*n_embd_head); + GGML_ASSERT(score_state->ne[0] == 2*n_embd_head); + GGML_ASSERT(n_embd_head >= n_embd_head_rope); + + kv_state = dsv4_append_zero_row(ctx0, kv_state, false); + score_state = dsv4_append_zero_row(ctx0, score_state, true); + + ggml_tensor * prev_idxs = dsv4_view_1d(ctx0, state_read_idxs, ratio*n_blocks, 0); + ggml_tensor * cur_idxs = dsv4_view_1d(ctx0, state_read_idxs, ratio*n_blocks, ratio*n_blocks); + + ggml_tensor * kv_prev = ggml_get_rows(ctx0, kv_state, prev_idxs); + kv_prev = ggml_cont(ctx0, ggml_view_2d(ctx0, kv_prev, n_embd_head, ratio*n_blocks, kv_prev->nb[1], 0)); + kv_prev = ggml_reshape_3d(ctx0, kv_prev, n_embd_head, ratio, n_blocks); + cb(kv_prev, name, il); + + ggml_tensor * score_prev = ggml_get_rows(ctx0, score_state, prev_idxs); + score_prev = ggml_cont(ctx0, ggml_view_2d(ctx0, score_prev, n_embd_head, ratio*n_blocks, score_prev->nb[1], 0)); + score_prev = ggml_reshape_3d(ctx0, score_prev, n_embd_head, ratio, n_blocks); + cb(score_prev, name, il); + + ggml_tensor * kv_cur = ggml_get_rows(ctx0, kv_state, cur_idxs); + kv_cur = ggml_cont(ctx0, ggml_view_2d(ctx0, kv_cur, n_embd_head, ratio*n_blocks, kv_cur->nb[1], + ggml_row_size(kv_cur->type, n_embd_head))); + kv_cur = ggml_reshape_3d(ctx0, kv_cur, n_embd_head, ratio, n_blocks); + + ggml_tensor * score_cur = ggml_get_rows(ctx0, score_state, cur_idxs); + score_cur = ggml_cont(ctx0, ggml_view_2d(ctx0, score_cur, n_embd_head, ratio*n_blocks, score_cur->nb[1], + ggml_row_size(score_cur->type, n_embd_head))); + score_cur = ggml_reshape_3d(ctx0, score_cur, n_embd_head, ratio, n_blocks); + + 
ggml_tensor * values = ggml_concat(ctx0, kv_prev, kv_cur, 1); + ggml_tensor * scores = ggml_concat(ctx0, score_prev, score_cur, 1); + + values = ggml_cont(ctx0, ggml_permute(ctx0, values, 1, 0, 2, 3)); + scores = ggml_cont(ctx0, ggml_permute(ctx0, scores, 1, 0, 2, 3)); + + ggml_tensor * weights = ggml_soft_max(ctx0, scores); + ggml_tensor * comp = ggml_mul(ctx0, values, weights); + comp = ggml_sum_rows(ctx0, comp); + comp = ggml_cont(ctx0, ggml_permute(ctx0, comp, 1, 0, 2, 3)); + cb(comp, name, il); + + comp = build_norm(comp, norm, nullptr, LLM_NORM_RMS, il); + cb(comp, name, il); + + ggml_tensor * comp_nope = ggml_view_3d(ctx0, comp, n_embd_head_nope, 1, n_blocks, + ggml_row_size(comp->type, n_embd_head), + ggml_row_size(comp->type, n_embd_head), + 0); + ggml_tensor * comp_pe = ggml_view_3d(ctx0, comp, n_embd_head_rope, 1, n_blocks, + ggml_row_size(comp->type, n_embd_head), + ggml_row_size(comp->type, n_embd_head), + ggml_row_size(comp->type, n_embd_head_nope)); + + comp_pe = ggml_rope_ext(ctx0, comp_pe, comp_pos, nullptr, n_embd_head_rope, rope_type, n_ctx_orig, + hparams.dsv4_compress_rope_base, freq_scale, ext_factor, + dsv4_rope_attn_factor(freq_scale, ext_factor), beta_fast, beta_slow); + cb(comp_pe, name, il); + + comp = ggml_concat(ctx0, comp_nope, comp_pe, 0); + cb(comp, name, il); + + return comp; +} + +ggml_tensor * llama_model_deepseek4::graph::build_lid_top_k( + const llama_model & model, + llm_graph_input_dsv4 * inp_dsv4, + ggml_tensor * qr, + ggml_tensor * cur, + ggml_tensor * inp_pos, + int il) const { + const auto & layer = model.layers[il]; + const auto & inp_lid = inp_dsv4->get_lid(); + const int64_t n_embd_indexer_head = hparams.indexer_head_size; + const int64_t n_embd_indexer_head_rope = hparams.n_rot(); + const int64_t n_embd_indexer_head_nope = n_embd_indexer_head - n_embd_indexer_head_rope; + const int64_t n_indexer_head = hparams.indexer_n_head; + const int64_t nt = cur->ne[1]; + + GGML_ASSERT(inp_lid.kq_mask); + 
GGML_ASSERT(inp_lid.k_rot); + GGML_ASSERT(n_embd_indexer_head >= n_embd_indexer_head_rope); + + ggml_tensor * indexer_q = build_lora_mm(layer.indexer_attn_q_b, qr); + indexer_q = ggml_reshape_3d(ctx0, indexer_q, n_embd_indexer_head, n_indexer_head, nt); + cb(indexer_q, "lid_q", il); + + ggml_tensor * indexer_q_nope = ggml_view_3d(ctx0, indexer_q, n_embd_indexer_head_nope, n_indexer_head, nt, + ggml_row_size(indexer_q->type, n_embd_indexer_head), + ggml_row_size(indexer_q->type, n_embd_indexer_head)*n_indexer_head, + 0); + ggml_tensor * indexer_q_pe = ggml_view_3d(ctx0, indexer_q, n_embd_indexer_head_rope, n_indexer_head, nt, + ggml_row_size(indexer_q->type, n_embd_indexer_head), + ggml_row_size(indexer_q->type, n_embd_indexer_head)*n_indexer_head, + ggml_row_size(indexer_q->type, n_embd_indexer_head_nope)); + + indexer_q_pe = ggml_rope_ext(ctx0, indexer_q_pe, inp_pos, nullptr, n_embd_indexer_head_rope, + rope_type, n_ctx_orig, hparams.dsv4_compress_rope_base, freq_scale, + ext_factor, dsv4_rope_attn_factor(freq_scale, ext_factor), beta_fast, beta_slow); + cb(indexer_q_pe, "lid_q_pe", il); + + indexer_q = ggml_concat(ctx0, indexer_q_nope, indexer_q_pe, 0); + indexer_q = ggml_mul_mat(ctx0, inp_lid.k_rot, indexer_q); + cb(indexer_q, "lid_q_rot", il); + + ggml_tensor * indexer_weights = build_lora_mm(layer.indexer_proj, cur); + indexer_weights = ggml_scale(ctx0, indexer_weights, 1.0f/sqrtf(float(n_embd_indexer_head*n_indexer_head))); + cb(indexer_weights, "lid_weights", il); + + ggml_tensor * indexer_k = inp_dsv4->mctx->get_lid()->get_k(ctx0, il); + const int64_t n_lid = inp_lid.kq_mask->ne[0]; + GGML_ASSERT(n_lid > 0); + GGML_ASSERT(n_lid <= indexer_k->ne[2]); + + indexer_k = ggml_view_4d(ctx0, indexer_k, + indexer_k->ne[0], indexer_k->ne[1], n_lid, indexer_k->ne[3], + indexer_k->nb[1], indexer_k->nb[2], indexer_k->nb[3], 0); + cb(indexer_k, "lid_k", il); + + const int64_t n_stream = indexer_k->ne[3]; + indexer_q = ggml_view_4d(ctx0, indexer_q, + indexer_q->ne[0], 
indexer_q->ne[1], indexer_q->ne[2]/n_stream, n_stream, + indexer_q->nb[1], indexer_q->nb[2], indexer_q->nb[3]/n_stream, 0); + indexer_weights = ggml_view_4d(ctx0, indexer_weights, + indexer_weights->ne[0], indexer_weights->ne[1]/n_stream, indexer_weights->ne[2], n_stream, + indexer_weights->nb[1], indexer_weights->nb[2]/n_stream, indexer_weights->nb[3]/n_stream, 0); + + indexer_q = ggml_permute(ctx0, indexer_q, 0, 2, 1, 3); + cb(indexer_q, "lid_q", il); + indexer_k = ggml_permute(ctx0, indexer_k, 0, 2, 1, 3); + cb(indexer_k, "lid_k", il); + + ggml_tensor * indexer_kq = ggml_mul_mat(ctx0, indexer_k, indexer_q); + cb(indexer_kq, "lid_kq", il); + + indexer_kq = ggml_cont(ctx0, ggml_permute(ctx0, indexer_kq, 2, 1, 0, 3)); + cb(indexer_kq, "lid_kq", il); + + ggml_tensor * indexer_score = ggml_relu(ctx0, indexer_kq); + indexer_score = ggml_mul(ctx0, indexer_score, indexer_weights); + indexer_score = ggml_sum_rows(ctx0, indexer_score); + indexer_score = ggml_cont(ctx0, ggml_permute(ctx0, indexer_score, 2, 1, 0, 3)); + cb(indexer_score, "lid_score", il); + + indexer_score = ggml_add(ctx0, indexer_score, inp_lid.kq_mask); + cb(indexer_score, "lid_score_masked", il); + + const uint32_t n_top_k = indexer_score->ne[0] < hparams.indexer_top_k ? 
indexer_score->ne[0] : hparams.indexer_top_k; + ggml_tensor * top_k = ggml_cont(ctx0, ggml_top_k(ctx0, indexer_score, n_top_k)); + cb(top_k, "lid_top_k", il); + + return top_k; +} + +ggml_tensor * llama_model_deepseek4::graph::build_top_k_mask( + ggml_tensor * kq_mask, + ggml_tensor * top_k, + const char * name, + int il) const { + GGML_ASSERT(kq_mask); + GGML_ASSERT(top_k); + + ggml_tensor * kq_mask_all = ggml_fill(ctx0, kq_mask, -INFINITY); + kq_mask_all = ggml_view_4d(ctx0, kq_mask_all, 1, kq_mask_all->ne[0], kq_mask_all->ne[1], kq_mask_all->ne[3], + kq_mask_all->nb[0], kq_mask_all->nb[1], kq_mask_all->nb[2], 0); + + ggml_tensor * top_k_3d = ggml_view_4d(ctx0, top_k, top_k->ne[0], top_k->ne[1], top_k->ne[3], 1, + top_k->nb[1], top_k->nb[2], top_k->ne[3]*top_k->nb[3], 0); + + ggml_tensor * zeros = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, 1, top_k_3d->ne[0], top_k_3d->ne[1], top_k_3d->ne[2]); + zeros = ggml_fill(ctx0, zeros, 0.0f); + + ggml_tensor * kq_mask_top_k = ggml_set_rows(ctx0, kq_mask_all, zeros, top_k_3d); + kq_mask_top_k = ggml_view_4d(ctx0, kq_mask_top_k, + kq_mask_top_k->ne[1], kq_mask_top_k->ne[2], 1, kq_mask_top_k->ne[3], + kq_mask_top_k->nb[2], kq_mask_top_k->nb[3], kq_mask_top_k->nb[3], 0); + + kq_mask_top_k = ggml_add(ctx0, kq_mask_top_k, kq_mask); + cb(kq_mask_top_k, name, il); + + return kq_mask_top_k; +} + +ggml_tensor * llama_model_deepseek4::graph::build_csa_lid_attention( + const llama_model & model, + llm_graph_input_dsv4 * inp_dsv4, + llm_graph_input_dsv4_raw * inp_attn, + ggml_tensor * q, + ggml_tensor * kv, + ggml_tensor * qr, + ggml_tensor * cur, + ggml_tensor * inp_pos, + ggml_tensor * sinks, + float kq_scale, + int il) const { + const auto & inp_csa = inp_dsv4->get_csa(); + GGML_ASSERT(inp_csa.kq_mask); + GGML_ASSERT(inp_attn->self_k_rot == nullptr); + + ggml_tensor * top_k = build_lid_top_k(model, inp_dsv4, qr, cur, inp_pos, il); + + ggml_build_forward_expand(gf, q); + ggml_build_forward_expand(gf, kv); + + const 
llama_kv_cache_dsv4_raw_context * mctx_raw = inp_attn->mctx; + + ggml_build_forward_expand(gf, mctx_raw->cpy_k(ctx0, kv, inp_attn->get_k_idxs(), il)); + + ggml_tensor * raw_k = mctx_raw->get_k(ctx0, il); + cb(raw_k, "csa_raw_k", il); + + ggml_tensor * csa_k = inp_dsv4->mctx->get_csa()->get_k(ctx0, il); + const int64_t n_csa = inp_csa.kq_mask->ne[0]; + GGML_ASSERT(n_csa > 0); + GGML_ASSERT(n_csa <= csa_k->ne[2]); + + csa_k = ggml_view_4d(ctx0, csa_k, + csa_k->ne[0], csa_k->ne[1], n_csa, csa_k->ne[3], + csa_k->nb[1], csa_k->nb[2], csa_k->nb[3], 0); + cb(csa_k, "csa_comp_k", il); + + raw_k = dsv4_repeat_streams(ctx0, raw_k, csa_k->ne[3]); + + ggml_tensor * k_all = ggml_concat(ctx0, raw_k, csa_k, 2); + cb(k_all, "csa_k_all", il); + + ggml_tensor * raw_mask = inp_attn->get_kq_mask(); + ggml_tensor * csa_mask = build_top_k_mask(inp_csa.kq_mask, top_k, "csa_top_k_mask", il); + const bool use_fattn = cparams.flash_attn && (!cparams.kv_unified || csa_mask->ne[3] == 1); + if (use_fattn && csa_mask->type != GGML_TYPE_F16) { + csa_mask = ggml_cast(ctx0, csa_mask, GGML_TYPE_F16); + } + if (raw_mask->type != csa_mask->type) { + raw_mask = ggml_cast(ctx0, raw_mask, csa_mask->type); + } + + ggml_tensor * kq_mask = ggml_concat(ctx0, raw_mask, csa_mask, 0); + cb(kq_mask, "csa_lid_kq_mask", il); + + ggml_tensor * kq_b = dsv4_build_kq_zero_bias(ctx0, cparams, kq_mask, q->ne[1]); + ggml_tensor * out = build_attn_mha(q, k_all, k_all, kq_b, kq_mask, sinks, nullptr, kq_scale, il); + cb(out, "attn_csa_lid", il); + + return out; +} + +ggml_tensor * llama_model_deepseek4::graph::build_hca_attention( + llm_graph_input_dsv4 * inp_dsv4, + llm_graph_input_dsv4_raw * inp_attn, + ggml_tensor * q, + ggml_tensor * kv, + ggml_tensor * sinks, + float kq_scale, + int il) const { + const auto & inp_hca = inp_dsv4->get_hca(); + GGML_ASSERT(inp_hca.kq_mask); + GGML_ASSERT(inp_attn->self_k_rot == nullptr); + + ggml_build_forward_expand(gf, q); + ggml_build_forward_expand(gf, kv); + + const 
llama_kv_cache_dsv4_raw_context * mctx_raw = inp_attn->mctx; + + ggml_build_forward_expand(gf, mctx_raw->cpy_k(ctx0, kv, inp_attn->get_k_idxs(), il)); + + ggml_tensor * raw_k = mctx_raw->get_k(ctx0, il); + cb(raw_k, "hca_raw_k", il); + + ggml_tensor * hca_k = inp_dsv4->mctx->get_hca()->get_k(ctx0, il); + const int64_t n_hca = inp_hca.kq_mask->ne[0]; + GGML_ASSERT(n_hca > 0); + GGML_ASSERT(n_hca <= hca_k->ne[2]); + + hca_k = ggml_view_4d(ctx0, hca_k, + hca_k->ne[0], hca_k->ne[1], n_hca, hca_k->ne[3], + hca_k->nb[1], hca_k->nb[2], hca_k->nb[3], 0); + cb(hca_k, "hca_comp_k", il); + + raw_k = dsv4_repeat_streams(ctx0, raw_k, hca_k->ne[3]); + + ggml_tensor * k_all = ggml_concat(ctx0, raw_k, hca_k, 2); + cb(k_all, "hca_k_all", il); + + ggml_tensor * raw_mask = inp_attn->get_kq_mask(); + ggml_tensor * hca_mask = inp_hca.kq_mask; + const bool use_fattn = cparams.flash_attn && (!cparams.kv_unified || hca_mask->ne[3] == 1); + if (use_fattn && hca_mask->type != GGML_TYPE_F16) { + hca_mask = ggml_cast(ctx0, hca_mask, GGML_TYPE_F16); + } + if (raw_mask->type != hca_mask->type) { + raw_mask = ggml_cast(ctx0, raw_mask, hca_mask->type); + } + + ggml_tensor * kq_mask = ggml_concat(ctx0, raw_mask, hca_mask, 0); + cb(kq_mask, "hca_kq_mask", il); + + ggml_tensor * kq_b = dsv4_build_kq_zero_bias(ctx0, cparams, kq_mask, q->ne[1]); + ggml_tensor * out = build_attn_mha(q, k_all, k_all, kq_b, kq_mask, sinks, nullptr, kq_scale, il); + cb(out, "attn_hca", il); + + return out; +} + +ggml_tensor * llama_model_deepseek4::graph::build_raw_attention( + llm_graph_input_dsv4_raw * inp_attn, + ggml_tensor * q, + ggml_tensor * kv, + ggml_tensor * sinks, + float kq_scale, + int il) const { + GGML_ASSERT(hparams.is_swa(il)); + + ggml_tensor * k_rot = inp_attn->self_k_rot; + + if (k_rot) { + q = ggml_mul_mat(ctx0, k_rot, q); + kv = ggml_mul_mat(ctx0, k_rot, kv); + } + + ggml_build_forward_expand(gf, q); + ggml_build_forward_expand(gf, kv); + + const llama_kv_cache_dsv4_raw_context * mctx_cur = 
// Build the attention sub-graph of layer `il` and return its output projection.
//
// Steps, in order:
//   1. project `cur` to per-head queries (wq_a -> RMS norm -> wq_b -> RMS norm) and
//      to a single-head kv tensor (wkv -> RMS norm), applying RoPE to the trailing
//      n_rot elements of each head;
//   2. depending on the layer's compress ratio, update the HCA or the CSA + LID
//      compression state and write newly compressed rows into the matching cache;
//   3. run one of build_csa_lid_attention / build_hca_attention / build_raw_attention;
//   4. apply ggml_rope_ext_back to the rope part of the attention output and run
//      the grouped wo_a / wo_b output projection.
//
// `cur` is the normalized layer input, `inp_pos` the token positions used for RoPE.
ggml_tensor * llama_model_deepseek4::graph::build_attention(
        const llama_model & model,
        llm_graph_input_dsv4 * inp_dsv4,
        ggml_tensor * cur,
        ggml_tensor * inp_pos,
        int il) const {
    const auto & layer = model.layers[il];
    llm_graph_input_dsv4_raw * inp_attn = inp_dsv4->get_raw();

    const int64_t n_embd_head = hparams.n_embd_head_k();
    const int64_t n_embd_head_rope = hparams.n_rot();
    const int64_t n_embd_head_nope = n_embd_head - n_embd_head_rope;
    const int64_t n_groups = hparams.dsv4_o_group_count;
    const int64_t n_heads_group = n_head / n_groups;
    const int64_t o_lora_rank = hparams.dsv4_o_lora_rank;
    const int64_t o_group_dim = n_heads_group*n_embd_head;
    const int64_t nt = cur->ne[1]; // number of tokens in this batch

    GGML_ASSERT(n_embd_head == n_embd_head_v);
    GGML_ASSERT(n_head % n_groups == 0);

    // layers with a non-zero compress ratio use the dedicated rope base together with
    // the context's scaling parameters; all other layers use freq_base with the
    // scaling parameters neutralized (scale 1, ext_factor 0, betas 0, n_ctx_orig 0)
    const bool use_compress_rope = hparams.dsv4_compress_ratios[il] != 0;
    const float freq_base_l = use_compress_rope ? hparams.dsv4_compress_rope_base : freq_base;
    const float freq_scale_l = use_compress_rope ? freq_scale : 1.0f;
    const float ext_factor_l = use_compress_rope ? ext_factor : 0.0f;
    const float attn_factor_l = dsv4_rope_attn_factor(freq_scale_l, ext_factor_l);
    const float beta_fast_l = use_compress_rope ? beta_fast : 0.0f;
    const float beta_slow_l = use_compress_rope ? beta_slow : 0.0f;
    const int32_t n_ctx_orig_l = use_compress_rope ? n_ctx_orig : 0;

    // --- query: wq_a -> RMS norm -> wq_b, then a weightless per-head RMS norm ---
    // `qr` (the normalized wq_a output) is kept because the CSA path forwards it
    ggml_tensor * qr = build_lora_mm(layer.wq_a, cur);
    cb(qr, "qr", il);

    qr = build_norm(qr, layer.attn_q_a_norm, nullptr, LLM_NORM_RMS, il);
    cb(qr, "qr_norm", il);

    ggml_tensor * q = build_lora_mm(layer.wq_b, qr);
    q = ggml_reshape_3d(ctx0, q, n_embd_head, n_head, nt);
    q = ggml_rms_norm(ctx0, q, norm_rms_eps);
    cb(q, "q_norm", il);

    // split each head into its leading no-rope part and trailing rope part
    ggml_tensor * q_nope = ggml_view_3d(ctx0, q, n_embd_head_nope, n_head, nt,
            ggml_row_size(q->type, n_embd_head),
            ggml_row_size(q->type, n_embd_head)*n_head,
            0);
    ggml_tensor * q_pe = ggml_view_3d(ctx0, q, n_embd_head_rope, n_head, nt,
            ggml_row_size(q->type, n_embd_head),
            ggml_row_size(q->type, n_embd_head)*n_head,
            ggml_row_size(q->type, n_embd_head_nope));
    q_pe = ggml_rope_ext(ctx0, q_pe, inp_pos, nullptr, n_embd_head_rope, rope_type, n_ctx_orig_l,
            freq_base_l, freq_scale_l, ext_factor_l, attn_factor_l, beta_fast_l, beta_slow_l);
    cb(q_pe, "q_pe", il);
    q = ggml_concat(ctx0, q_nope, q_pe, 0);
    cb(q, "q", il);

    // --- kv: one head per token, same nope/rope split and RoPE as the query ---
    ggml_tensor * kv = build_lora_mm(layer.wkv, cur);
    kv = build_norm(kv, layer.attn_kv_norm, nullptr, LLM_NORM_RMS, il);
    kv = ggml_reshape_3d(ctx0, kv, n_embd_head, 1, nt);
    cb(kv, "kv_norm", il);

    ggml_tensor * kv_nope = ggml_view_3d(ctx0, kv, n_embd_head_nope, 1, nt,
            ggml_row_size(kv->type, n_embd_head),
            ggml_row_size(kv->type, n_embd_head),
            0);
    ggml_tensor * kv_pe = ggml_view_3d(ctx0, kv, n_embd_head_rope, 1, nt,
            ggml_row_size(kv->type, n_embd_head),
            ggml_row_size(kv->type, n_embd_head),
            ggml_row_size(kv->type, n_embd_head_nope));
    kv_pe = ggml_rope_ext(ctx0, kv_pe, inp_pos, nullptr, n_embd_head_rope, rope_type, n_ctx_orig_l,
            freq_base_l, freq_scale_l, ext_factor_l, attn_factor_l, beta_fast_l, beta_slow_l);
    cb(kv_pe, "kv_pe", il);
    kv = ggml_concat(ctx0, kv_nope, kv_pe, 0);
    cb(kv, "kv", il);

    // the compress ratio selects which compression state (if any) this layer maintains
    const int64_t ratio = hparams.dsv4_compress_ratios[il];

    // --- HCA, part 1: per-token compressor inputs (kv and gate score + position rows of attn_comp_ape) ---
    ggml_tensor * hca_state_kv = nullptr;
    ggml_tensor * hca_state_score = nullptr;
    if (ratio == DSV4_HCA_RATIO && inp_dsv4->get_hca().state_pos) {
        hca_state_kv = build_lora_mm(layer.attn_comp_wkv, cur);
        cb(hca_state_kv, "hca_state_kv", il);

        hca_state_score = build_lora_mm(layer.attn_comp_wgate, cur);
        cb(hca_state_score, "hca_state_score", il);

        ggml_tensor * ape = layer.attn_comp_ape;

        ggml_tensor * ape_rows = ggml_get_rows(ctx0, ape, inp_dsv4->get_hca().state_pos);
        hca_state_score = ggml_add(ctx0, hca_state_score, ape_rows);
        cb(hca_state_score, "hca_state_score_ape", il);

    }

    // --- CSA + LID: both states are updated together on CSA layers ---
    if (ratio == DSV4_CSA_RATIO && inp_dsv4->get_csa().state_pos) {
        ggml_tensor * csa_state_kv = build_lora_mm(layer.attn_comp_wkv, cur);
        cb(csa_state_kv, "csa_state_kv", il);

        ggml_tensor * csa_state_score = build_lora_mm(layer.attn_comp_wgate, cur);
        cb(csa_state_score, "csa_state_score", il);

        ggml_tensor * csa_ape = layer.attn_comp_ape;

        ggml_tensor * csa_ape_rows = ggml_get_rows(ctx0, csa_ape, inp_dsv4->get_csa().state_pos);
        csa_state_score = ggml_add(ctx0, csa_state_score, csa_ape_rows);
        cb(csa_state_score, "csa_state_score_ape", il);

        // unlike the HCA path, the CSA path requires write indices whenever state_pos is set
        GGML_ASSERT(inp_dsv4->get_csa().state_write_idxs);

        // compressor source = persisted state followed by this batch's rows (concat on dim 1)
        ggml_tensor * csa_source_kv = ggml_concat(ctx0,
                inp_dsv4->mctx->get_csa_state()->get_kv(ctx0, il), csa_state_kv, 1);
        ggml_tensor * csa_source_score = ggml_concat(ctx0,
                inp_dsv4->mctx->get_csa_state()->get_score(ctx0, il), csa_state_score, 1);

        ggml_tensor * kv_comp_csa_state = build_overlap_compressed_kv_from_state(
                csa_source_kv,
                csa_source_score,
                inp_dsv4->get_csa().state_read_idxs,
                inp_dsv4->get_csa().state_write_pos,
                layer.attn_comp_norm,
                DSV4_CSA_RATIO,
                n_embd_head,
                "csa_state_compress",
                il);

        // store the compressed rows into the CSA cache
        ggml_build_forward_expand(gf, inp_dsv4->mctx->get_csa()->cpy_k(ctx0,
                kv_comp_csa_state, inp_dsv4->get_csa().state_write_idxs, il));

        // presumably adds a value-neutral dependency on the compressed result so the state
        // is read before it is overwritten below - TODO confirm in dsv4_with_zero_dep
        csa_state_kv = dsv4_with_zero_dep(ctx0, csa_state_kv, kv_comp_csa_state);
        csa_state_score = dsv4_with_zero_dep(ctx0, csa_state_score, kv_comp_csa_state);

        // persist the selected rows of this batch into the CSA state
        ggml_tensor * csa_persist_kv = ggml_get_rows(ctx0, csa_state_kv, inp_dsv4->get_csa().state_persist_src_idxs);
        ggml_tensor * csa_persist_score = ggml_get_rows(ctx0, csa_state_score, inp_dsv4->get_csa().state_persist_src_idxs);

        csa_state_kv = inp_dsv4->mctx->get_csa_state()->cpy_kv(ctx0,
                csa_persist_kv, inp_dsv4->get_csa().state_persist_dst_idxs, il);
        csa_state_score = inp_dsv4->mctx->get_csa_state()->cpy_score(ctx0,
                csa_persist_score, inp_dsv4->get_csa().state_persist_dst_idxs, il);

        ggml_build_forward_expand(gf, csa_state_kv);
        ggml_build_forward_expand(gf, csa_state_score);

        // same sequence for the LID (indexer) state, using the indexer_comp_* weights
        ggml_tensor * lid_state_kv = build_lora_mm(layer.indexer_comp_wkv, cur);
        cb(lid_state_kv, "lid_state_kv", il);

        ggml_tensor * lid_state_score = build_lora_mm(layer.indexer_comp_wgate, cur);
        cb(lid_state_score, "lid_state_score", il);

        ggml_tensor * lid_ape = layer.indexer_comp_ape;

        ggml_tensor * lid_ape_rows = ggml_get_rows(ctx0, lid_ape, inp_dsv4->get_lid().state_pos);
        lid_state_score = ggml_add(ctx0, lid_state_score, lid_ape_rows);
        cb(lid_state_score, "lid_state_score_ape", il);

        GGML_ASSERT(inp_dsv4->get_lid().state_write_idxs);

        ggml_tensor * lid_source_kv = ggml_concat(ctx0,
                inp_dsv4->mctx->get_lid_state()->get_kv(ctx0, il), lid_state_kv, 1);
        ggml_tensor * lid_source_score = ggml_concat(ctx0,
                inp_dsv4->mctx->get_lid_state()->get_score(ctx0, il), lid_state_score, 1);

        // the indexer compresses with the CSA ratio but its own head size
        ggml_tensor * kv_comp_lid_state = build_overlap_compressed_kv_from_state(
                lid_source_kv,
                lid_source_score,
                inp_dsv4->get_lid().state_read_idxs,
                inp_dsv4->get_lid().state_write_pos,
                layer.indexer_comp_norm,
                DSV4_CSA_RATIO,
                hparams.indexer_head_size,
                "lid_state_compress",
                il);

        // optional rotation of the compressed indexer keys before they are cached
        if (inp_dsv4->get_lid().k_rot) {
            kv_comp_lid_state = ggml_mul_mat(ctx0, inp_dsv4->get_lid().k_rot, kv_comp_lid_state);
            cb(kv_comp_lid_state, "lid_state_compress_rot", il);
        }

        ggml_build_forward_expand(gf, inp_dsv4->mctx->get_lid()->cpy_k(ctx0,
                kv_comp_lid_state, inp_dsv4->get_lid().state_write_idxs, il));

        lid_state_kv = dsv4_with_zero_dep(ctx0, lid_state_kv, kv_comp_lid_state);
        lid_state_score = dsv4_with_zero_dep(ctx0, lid_state_score, kv_comp_lid_state);

        ggml_tensor * lid_persist_kv = ggml_get_rows(ctx0, lid_state_kv, inp_dsv4->get_lid().state_persist_src_idxs);
        ggml_tensor * lid_persist_score = ggml_get_rows(ctx0, lid_state_score, inp_dsv4->get_lid().state_persist_src_idxs);

        lid_state_kv = inp_dsv4->mctx->get_lid_state()->cpy_kv(ctx0,
                lid_persist_kv, inp_dsv4->get_lid().state_persist_dst_idxs, il);
        lid_state_score = inp_dsv4->mctx->get_lid_state()->cpy_score(ctx0,
                lid_persist_score, inp_dsv4->get_lid().state_persist_dst_idxs, il);

        ggml_build_forward_expand(gf, lid_state_kv);
        ggml_build_forward_expand(gf, lid_state_score);
    }

    // --- HCA, part 2: compress and store, only when write indices are provided ---
    ggml_tensor * hca_state_dep = nullptr;
    if (ratio == DSV4_HCA_RATIO && inp_dsv4->get_hca().state_write_idxs) {
        // part 1 must have run, i.e. state_pos must be set whenever state_write_idxs is
        GGML_ASSERT(hca_state_kv);
        GGML_ASSERT(hca_state_score);

        ggml_tensor * hca_source_kv = ggml_concat(ctx0,
                inp_dsv4->mctx->get_hca_state()->get_kv(ctx0, il), hca_state_kv, 1);
        ggml_tensor * hca_source_score = ggml_concat(ctx0,
                inp_dsv4->mctx->get_hca_state()->get_score(ctx0, il), hca_state_score, 1);

        ggml_tensor * kv_comp_hca = build_hca_compressed_kv_from_state(
                hca_source_kv,
                hca_source_score,
                inp_dsv4->get_hca().state_read_idxs,
                inp_dsv4->get_hca().state_write_pos,
                layer.attn_comp_norm,
                n_embd_head,
                "hca_state_compress",
                il);

        ggml_build_forward_expand(gf, inp_dsv4->mctx->get_hca()->cpy_k(ctx0,
                kv_comp_hca, inp_dsv4->get_hca().state_write_idxs, il));
        hca_state_dep = kv_comp_hca;
    }

    // --- HCA, part 3: persist this batch's rows into the HCA state ---
    if (ratio == DSV4_HCA_RATIO && inp_dsv4->get_hca().state_pos) {
        GGML_ASSERT(hca_state_kv);
        GGML_ASSERT(hca_state_score);

        // NOTE(review): hca_state_dep is still nullptr here when state_write_idxs was not
        // set; this relies on dsv4_with_zero_dep accepting a null dependency - confirm
        hca_state_kv = dsv4_with_zero_dep(ctx0, hca_state_kv, hca_state_dep);
        hca_state_score = dsv4_with_zero_dep(ctx0, hca_state_score, hca_state_dep);

        ggml_tensor * hca_persist_kv = ggml_get_rows(ctx0, hca_state_kv, inp_dsv4->get_hca().state_persist_src_idxs);
        ggml_tensor * hca_persist_score = ggml_get_rows(ctx0, hca_state_score, inp_dsv4->get_hca().state_persist_src_idxs);

        hca_state_kv = inp_dsv4->mctx->get_hca_state()->cpy_kv(ctx0,
                hca_persist_kv, inp_dsv4->get_hca().state_persist_dst_idxs, il);
        hca_state_score = inp_dsv4->mctx->get_hca_state()->cpy_score(ctx0,
                hca_persist_score, inp_dsv4->get_hca().state_persist_dst_idxs, il);

        ggml_build_forward_expand(gf, hca_state_kv);
        ggml_build_forward_expand(gf, hca_state_score);
    }

    // --- attention: pick the variant matching the ratio and the available inputs ---
    // NOTE(review): a CSA layer without a LID k_rot (or with a raw self_k_rot) silently
    // falls through to raw attention - confirm this fallback is intended
    ggml_tensor * out = nullptr;
    if (ratio == DSV4_CSA_RATIO &&
            inp_dsv4->get_csa().kq_mask &&
            inp_dsv4->get_lid().kq_mask &&
            inp_dsv4->get_lid().k_rot &&
            inp_attn->self_k_rot == nullptr) {
        out = build_csa_lid_attention(model, inp_dsv4, inp_attn, q, kv, qr, cur, inp_pos, layer.attn_sinks,
                1.0f/sqrtf(float(n_embd_head)), il);
    } else if (ratio == DSV4_HCA_RATIO &&
            inp_dsv4->get_hca().kq_mask &&
            inp_attn->self_k_rot == nullptr) {
        out = build_hca_attention(inp_dsv4, inp_attn, q, kv, layer.attn_sinks,
                1.0f/sqrtf(float(n_embd_head)), il);
    } else {
        out = build_raw_attention(inp_attn, q, kv, layer.attn_sinks,
                1.0f/sqrtf(float(n_embd_head)), il);
    }

    // --- "de-rope": apply ggml_rope_ext_back to the rope part of each output head ---
    // (the attention helpers pass the same cached tensor as K and V, so the rope part of
    //  the values carries the rotation too - presumably that is what is undone here)
    out = ggml_reshape_3d(ctx0, out, n_embd_head, n_head, nt);
    ggml_tensor * out_nope = ggml_view_3d(ctx0, out, n_embd_head_nope, n_head, nt,
            ggml_row_size(out->type, n_embd_head),
            ggml_row_size(out->type, n_embd_head)*n_head,
            0);
    ggml_tensor * out_pe = ggml_view_3d(ctx0, out, n_embd_head_rope, n_head, nt,
            ggml_row_size(out->type, n_embd_head),
            ggml_row_size(out->type, n_embd_head)*n_head,
            ggml_row_size(out->type, n_embd_head_nope));
    out_pe = ggml_rope_ext_back(ctx0, out_pe, inp_pos, nullptr, n_embd_head_rope, rope_type, n_ctx_orig_l,
            freq_base_l, freq_scale_l, ext_factor_l, attn_factor_l, beta_fast_l, beta_slow_l);
    out = ggml_concat(ctx0, out_nope, out_pe, 0);
    cb(out, "attn_derope", il);

    // --- grouped output projection: wo_a is applied per head group, wo_b mixes the groups ---
    out = ggml_reshape_3d(ctx0, out, o_group_dim, n_groups, nt);
    out = ggml_permute(ctx0, out, 0, 2, 1, 3);
    // wo_a is viewed as n_groups matrices with o_lora_rank rows each
    ggml_tensor * oa = ggml_mul_mat(ctx0, ggml_reshape_3d(ctx0, layer.wo_a, layer.wo_a->ne[0], o_lora_rank, n_groups), out);
    cb(oa, "attn_wo_a", il);
    oa = ggml_permute(ctx0, oa, 0, 2, 1, 3);
    oa = ggml_cont_2d(ctx0, oa, o_lora_rank*n_groups, nt);

    out = build_lora_mm(layer.wo_b, oa);
    cb(out, "attn_out", il);

    return out;
}
// Model definition for the DeepSeek-V4 architecture: hparam/tensor loading plus
// the compute-graph builder used for inference.
struct llama_model_deepseek4 : public llama_model_base {
    llama_model_deepseek4(const struct llama_model_params & params) : llama_model_base(params) {}
    void load_arch_hparams(llama_model_loader & ml) override;
    void load_arch_tensors(llama_model_loader & ml) override;

    // Graph builder; the constructor assembles the whole forward graph.
    struct graph : public llm_graph_context {
        graph(const llama_model & model, const llm_graph_params & params);

        // Runs before each attention / FFN sub-block; also fills `*post` and `*comb`,
        // which the caller hands to the following build_hc_post() call.
        ggml_tensor * build_hc_pre(
                ggml_tensor * x,
                ggml_tensor * hc_fn,
                ggml_tensor * hc_scale,
                ggml_tensor * hc_base,
                ggml_tensor ** post,
                ggml_tensor ** comb,
                int il) const;

        // Runs after each sub-block; combines the sub-block output `x` with `residual`
        // using the `post` / `comb` tensors produced by build_hc_pre().
        ggml_tensor * build_hc_post(
                ggml_tensor * x,
                ggml_tensor * residual,
                ggml_tensor * post,
                ggml_tensor * comb,
                int il) const;

        // Applied once after the last layer, before the output norm.
        ggml_tensor * build_hc_head(
                ggml_tensor * x,
                ggml_tensor * hc_fn,
                ggml_tensor * hc_scale,
                ggml_tensor * hc_base) const;

        // Attention sub-graph of layer `il`: q/kv projections with RoPE, compression
        // state updates, one of the three attention variants below, output projection.
        ggml_tensor * build_attention(
                const llama_model & model,
                llm_graph_input_dsv4 * inp_dsv4,
                ggml_tensor * cur,
                ggml_tensor * inp_pos,
                int il) const;

        // Compressed KV rows for HCA layers, built from the concatenated state + batch rows.
        ggml_tensor * build_hca_compressed_kv_from_state(
                ggml_tensor * kv_state,
                ggml_tensor * score_state,
                ggml_tensor * state_read_idxs,
                ggml_tensor * comp_pos,
                ggml_tensor * norm,
                int64_t n_embd_head,
                const char * name,
                int il) const;

        // Compressed KV rows for the CSA and LID states; takes the ratio explicitly.
        ggml_tensor * build_overlap_compressed_kv_from_state(
                ggml_tensor * kv_state,
                ggml_tensor * score_state,
                ggml_tensor * state_read_idxs,
                ggml_tensor * comp_pos,
                ggml_tensor * norm,
                int64_t ratio,
                int64_t n_embd_head,
                const char * name,
                int il) const;

        // Indexer ("lid") scores for the current tokens; returns the top-k indices.
        ggml_tensor * build_lid_top_k(
                const llama_model & model,
                llm_graph_input_dsv4 * inp_dsv4,
                ggml_tensor * qr,
                ggml_tensor * cur,
                ggml_tensor * inp_pos,
                int il) const;

        // `kq_mask` plus a mask that is 0 at the `top_k` indices and -inf elsewhere.
        ggml_tensor * build_top_k_mask(
                ggml_tensor * kq_mask,
                ggml_tensor * top_k,
                const char * name,
                int il) const;

        // Attention over raw KV + CSA compressed KV, restricted by the indexer's top-k.
        ggml_tensor * build_csa_lid_attention(
                const llama_model & model,
                llm_graph_input_dsv4 * inp_dsv4,
                llm_graph_input_dsv4_raw * inp_attn,
                ggml_tensor * q,
                ggml_tensor * kv,
                ggml_tensor * qr,
                ggml_tensor * cur,
                ggml_tensor * inp_pos,
                ggml_tensor * sinks,
                float kq_scale,
                int il) const;

        // Attention over raw KV + HCA compressed KV.
        ggml_tensor * build_hca_attention(
                llm_graph_input_dsv4 * inp_dsv4,
                llm_graph_input_dsv4_raw * inp_attn,
                ggml_tensor * q,
                ggml_tensor * kv,
                ggml_tensor * sinks,
                float kq_scale,
                int il) const;

        // Attention over the raw KV cache only; asserts that the layer is an SWA layer.
        ggml_tensor * build_raw_attention(
                llm_graph_input_dsv4_raw * inp_attn,
                ggml_tensor * q,
                ggml_tensor * kv,
                ggml_tensor * sinks,
                float kq_scale,
                int il) const;

        // presumably a helper shared by the build_hc_* methods - definition not shown here
        ggml_tensor * build_hc_weighted_sum(
                ggml_tensor * x,
                ggml_tensor * weights) const;

        // presumably normalizes `comb` for build_hc_pre/post - definition not shown here
        ggml_tensor * build_hc_sinkhorn(
                ggml_tensor * comb,
                int il) const;
    };

    // NOTE(review): std::unique_ptr has no template argument here, which will not compile;
    // it looks like the <...> part was lost - confirm against the sibling model structs.
    std::unique_ptr build_arch_graph(const llm_graph_params & params) const override;
};