Skip to content

Commit 65861e7

Browse files
committed
convert: add dsv4 conversion
1 parent af6528e commit 65861e7

4 files changed

Lines changed: 437 additions & 1 deletion

File tree

conversion/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
"DeepseekV2ForCausalLM": "deepseek",
4949
"DeepseekV3ForCausalLM": "deepseek",
5050
"DeepseekV32ForCausalLM": "deepseek",
51+
"DeepseekV4ForCausalLM": "deepseek",
5152
"DistilBertForMaskedLM": "bert",
5253
"DistilBertForSequenceClassification": "bert",
5354
"DistilBertModel": "bert",

conversion/base.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2578,6 +2578,17 @@ def __torch_function__(cls, func, types, args=(), kwargs=None):
25782578
return cls._wrap_fn(func)(*args, **kwargs)
25792579

25802580

2581+
if hasattr(torch, "float8_e8m0fnu"):
2582+
_torch_float8_e8m0 = torch.float8_e8m0fnu
2583+
LazyTorchTensor._dtype_map[_torch_float8_e8m0] = np.uint8
2584+
LazyTorchTensor._dtype_byteswap_map[_torch_float8_e8m0] = np.uint8
2585+
LazyTorchTensor._dtype_str_map["F8_E8M0"] = _torch_float8_e8m0
2586+
else:
2587+
# Older torch builds do not expose F8_E8M0. Keep the raw bytes so callers
2588+
# that know the format can decode them explicitly.
2589+
LazyTorchTensor._dtype_str_map["F8_E8M0"] = torch.uint8
2590+
2591+
25812592
def get_model_architecture(hparams: dict[str, Any], model_type: ModelType) -> str:
25822593
# TODO @ngxson : this won't work correctly if the model has both audio & vision encoders
25832594
# maybe we should fallback to text model's arch in that case, since not many models have both

conversion/deepseek.py

Lines changed: 336 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
11
from __future__ import annotations
22

3+
import json
34
import re
45

56
from typing import Any, Callable, Iterable, TYPE_CHECKING
67

8+
import numpy as np
79
import torch
810

911
if TYPE_CHECKING:
1012
from torch import Tensor
1113

12-
from .base import MmprojModel, ModelBase, TextModel, gguf, logger
14+
from .base import LazyTorchTensor, MmprojModel, ModelBase, TextModel, gguf, logger
1315

1416
from .qwen import QwenModel
1517

@@ -459,3 +461,336 @@ def set_gguf_parameters(self):
459461
self.gguf_writer.add_indexer_head_count(self.hparams["index_n_heads"])
460462
self.gguf_writer.add_indexer_key_length(self.hparams["index_head_dim"])
461463
self.gguf_writer.add_indexer_top_k(self.hparams["index_topk"])
464+
465+
466+
@ModelBase.register("DeepseekV4ForCausalLM")
467+
class DeepseekV4FlashModel(TextModel):
468+
model_arch = gguf.MODEL_ARCH.DEEPSEEK_V4_FLASH
469+
_skipped_mtp_tensors = 0
470+
471+
def __init__(self, *args, **kwargs):
472+
type(self)._skipped_mtp_tensors = 0
473+
super().__init__(*args, **kwargs)
474+
475+
with open(self.dir_model / "config.json", "r", encoding="utf-8") as f:
476+
raw_hparams = json.load(f)
477+
for key, value in raw_hparams.items():
478+
self.hparams.setdefault(key, value)
479+
480+
self.block_count = self.hparams["num_hidden_layers"]
481+
self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
482+
483+
self._dsv4_fp8_dequantized: set[str] = set()
484+
self._dsv4_bf16_tensors: set[str] = set()
485+
self._dsv4_f32_tensors: set[str] = set()
486+
self._dsv4_mxfp4_generated = False
487+
self._collect_source_dtypes()
488+
489+
if type(self)._skipped_mtp_tensors:
490+
logger.info("Skipping %d DeepSeek-V4 MTP tensor(s) for conversion v0", type(self)._skipped_mtp_tensors)
491+
492+
@classmethod
493+
def filter_tensors(cls, item: tuple[str, Callable[[], Tensor]]) -> tuple[str, Callable[[], Tensor]] | None:
494+
name, _ = item
495+
if name.startswith("mtp."):
496+
cls._skipped_mtp_tensors += 1
497+
return None
498+
return super().filter_tensors(item)
499+
500+
def set_vocab(self):
501+
self._set_vocab_gpt2()
502+
503+
@staticmethod
504+
def _float8_dtypes() -> tuple[torch.dtype, ...]:
505+
return tuple(
506+
dtype for dtype in (
507+
getattr(torch, "float8_e4m3fn", None),
508+
getattr(torch, "float8_e5m2", None),
509+
) if dtype is not None
510+
)
511+
512+
@staticmethod
513+
def _e8m0_to_float(scale: Tensor) -> Tensor:
514+
torch_float8_e8m0 = getattr(torch, "float8_e8m0fnu", None)
515+
if torch_float8_e8m0 is not None and scale.dtype == torch_float8_e8m0:
516+
return scale.float()
517+
518+
bits = scale.view(torch.uint8).float()
519+
return torch.pow(torch.tensor(2.0, device=bits.device), bits - 127.0)
520+
521+
def _collect_source_dtypes(self) -> None:
522+
for name, gen in self.model_tensors.items():
523+
dtype = gen().dtype
524+
if dtype == torch.bfloat16:
525+
self._dsv4_bf16_tensors.add(name)
526+
elif dtype == torch.float32:
527+
self._dsv4_f32_tensors.add(name)
528+
529+
def set_gguf_parameters(self):
530+
hparams = self.hparams
531+
arch = gguf.MODEL_ARCH_NAMES[self.model_arch]
532+
533+
self.gguf_writer.add_block_count(self.block_count)
534+
self.gguf_writer.add_context_length(hparams["max_position_embeddings"])
535+
self.gguf_writer.add_embedding_length(hparams["hidden_size"])
536+
self.gguf_writer.add_vocab_size(hparams["vocab_size"])
537+
self.gguf_writer.add_head_count(hparams["num_attention_heads"])
538+
self.gguf_writer.add_head_count_kv(hparams["num_key_value_heads"])
539+
self.gguf_writer.add_key_length(hparams["head_dim"])
540+
self.gguf_writer.add_value_length(hparams["head_dim"])
541+
self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"])
542+
self.gguf_writer.add_rope_freq_base(hparams["rope_theta"])
543+
self.gguf_writer.add_q_lora_rank(hparams["q_lora_rank"])
544+
self.gguf_writer.add_sliding_window(hparams["sliding_window"])
545+
self.gguf_writer.add_layer_norm_rms_eps(hparams["rms_norm_eps"])
546+
547+
rope_scaling = hparams.get("rope_scaling") or {}
548+
rope_type = rope_scaling.get("type", rope_scaling.get("rope_type"))
549+
if rope_type == "yarn":
550+
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
551+
self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
552+
self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])
553+
if (yarn_beta_fast := rope_scaling.get("beta_fast")) is not None:
554+
self.gguf_writer.add_rope_scaling_yarn_beta_fast(yarn_beta_fast)
555+
if (yarn_beta_slow := rope_scaling.get("beta_slow")) is not None:
556+
self.gguf_writer.add_rope_scaling_yarn_beta_slow(yarn_beta_slow)
557+
else:
558+
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
559+
560+
self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"])
561+
self.gguf_writer.add_expert_count(hparams["n_routed_experts"])
562+
self.gguf_writer.add_expert_used_count(hparams["num_experts_per_tok"])
563+
self.gguf_writer.add_expert_shared_count(hparams["n_shared_experts"])
564+
self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"])
565+
self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"])
566+
self.gguf_writer.add_swiglu_clamp_exp([hparams["swiglu_limit"]] * self.block_count)
567+
self.gguf_writer.add_swiglu_clamp_shexp([hparams["swiglu_limit"]] * self.block_count)
568+
569+
self.gguf_writer.add_indexer_head_count(hparams["index_n_heads"])
570+
self.gguf_writer.add_indexer_key_length(hparams["index_head_dim"])
571+
self.gguf_writer.add_indexer_top_k(hparams["index_topk"])
572+
573+
self.gguf_writer.add_uint32(f"{arch}.attention.o_group_count", hparams["o_groups"])
574+
self.gguf_writer.add_uint32(f"{arch}.attention.o_lora_rank", hparams["o_lora_rank"])
575+
self.gguf_writer.add_array(f"{arch}.attention.compress_ratios", hparams["compress_ratios"])
576+
self.gguf_writer.add_float32(f"{arch}.attention.compress_rope.freq_base", hparams["compress_rope_theta"])
577+
self.gguf_writer.add_uint32(f"{arch}.hc.mult", hparams["hc_mult"])
578+
self.gguf_writer.add_uint32(f"{arch}.hc.sinkhorn_iters", hparams["hc_sinkhorn_iters"])
579+
self.gguf_writer.add_float32(f"{arch}.hc.eps", hparams["hc_eps"])
580+
self.gguf_writer.add_uint32(f"{arch}.moe.hash_layer_count", hparams["num_hash_layers"])
581+
self.gguf_writer.add_string(f"{arch}.moe.score_func", hparams["scoring_func"])
582+
self.gguf_writer.add_string(f"{arch}.moe.topk_method", hparams["topk_method"])
583+
584+
self.gguf_writer.add_file_type(self.ftype)
585+
logger.info(f"gguf: file type = {self.ftype}")
586+
587+
def dequant_model(self):
588+
fp8_dtypes = self._float8_dtypes()
589+
tensors_to_remove: list[str] = []
590+
591+
def dequant_fp8_weight(weight: Tensor, scale: Tensor) -> Tensor:
592+
out_features, in_features = weight.shape
593+
scale_f = self._e8m0_to_float(scale)
594+
scale_f = scale_f.repeat_interleave(128, 0)[:out_features]
595+
scale_f = scale_f.repeat_interleave(128, 1)[:, :in_features]
596+
return weight.float() * scale_f
597+
598+
for name in list(self.model_tensors.keys()):
599+
if not name.endswith(".scale"):
600+
continue
601+
weight_name = name.removesuffix(".scale") + ".weight"
602+
if weight_name not in self.model_tensors:
603+
continue
604+
605+
weight = self.model_tensors[weight_name]
606+
scale = self.model_tensors[name]
607+
if weight().dtype not in fp8_dtypes:
608+
continue
609+
610+
self.model_tensors[weight_name] = lambda w=weight, s=scale: dequant_fp8_weight(w(), s())
611+
self._dsv4_fp8_dequantized.add(weight_name)
612+
tensors_to_remove.append(name)
613+
614+
for name in tensors_to_remove:
615+
del self.model_tensors[name]
616+
617+
@staticmethod
618+
def _pack_mxfp4_blocks(weight: Tensor, scale: Tensor) -> np.ndarray:
619+
packed = weight.contiguous().view(torch.uint8)
620+
scale_u8 = scale.contiguous().view(torch.uint8)
621+
622+
out_features, packed_cols = packed.shape
623+
logical_cols = packed_cols * 2
624+
if logical_cols % 32 != 0:
625+
raise ValueError(f"MXFP4 source row has {logical_cols} values, expected a multiple of 32")
626+
627+
n_blocks = logical_cols // 32
628+
if tuple(scale_u8.shape) != (out_features, n_blocks):
629+
raise ValueError(f"MXFP4 scale shape {tuple(scale_u8.shape)} does not match {(out_features, n_blocks)}")
630+
631+
src = packed.reshape(out_features, n_blocks, 16)
632+
low = src & 0x0F
633+
high = (src >> 4) & 0x0F
634+
635+
# The safetensors bytes store adjacent values as low/high nibbles.
636+
# ggml MXFP4 blocks store values 0..15 in low nibbles and 16..31 in high nibbles.
637+
vals = torch.stack((low, high), dim=-1).reshape(out_features, n_blocks, 32)
638+
qs = vals[:, :, :16] | (vals[:, :, 16:] << 4)
639+
raw = torch.cat((scale_u8.unsqueeze(-1), qs.to(torch.uint8)), dim=-1)
640+
return raw.reshape(out_features, n_blocks * 17).cpu().numpy()
641+
642+
def _write_mxfp4_expert_tensor(self, bid: int, proj: str, tensor_key: gguf.MODEL_TENSOR) -> list[str]:
643+
n_experts = self.hparams["n_routed_experts"]
644+
data: np.ndarray | None = None
645+
consumed: list[str] = []
646+
647+
for eid in range(n_experts):
648+
weight_name = f"layers.{bid}.ffn.experts.{eid}.{proj}.weight"
649+
scale_name = f"layers.{bid}.ffn.experts.{eid}.{proj}.scale"
650+
if weight_name not in self.model_tensors or scale_name not in self.model_tensors:
651+
raise KeyError(f"Missing routed expert tensors for {weight_name}")
652+
653+
weight = LazyTorchTensor.to_eager(self.model_tensors[weight_name]())
654+
scale = LazyTorchTensor.to_eager(self.model_tensors[scale_name]())
655+
packed = self._pack_mxfp4_blocks(weight, scale)
656+
if data is None:
657+
data = np.empty((n_experts, *packed.shape), dtype=packed.dtype)
658+
data[eid] = packed
659+
consumed.extend((weight_name, scale_name))
660+
661+
assert data is not None
662+
new_name = self.format_tensor_name(tensor_key, bid)
663+
shape = gguf.quant_shape_from_byte_shape(data.shape, gguf.GGMLQuantizationType.MXFP4)
664+
logger.info(f"{new_name}: repacked routed experts to MXFP4, shape = {{{', '.join(str(n) for n in reversed(shape))}}}")
665+
self.gguf_writer.add_tensor(new_name, data, raw_dtype=gguf.GGMLQuantizationType.MXFP4)
666+
667+
return consumed
668+
669+
def _write_hash_routing_tensors(self) -> list[str]:
670+
consumed: list[str] = []
671+
672+
for bid in range(self.hparams["num_hash_layers"]):
673+
name = f"layers.{bid}.ffn.gate.tid2eid"
674+
if name not in self.model_tensors:
675+
raise KeyError(f"Missing hash routing tensor {name}")
676+
677+
data_torch = LazyTorchTensor.to_eager(self.model_tensors[name]())
678+
data = data_torch.to(torch.int32).cpu().numpy()
679+
new_name = self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_TID2EID, bid, "")
680+
logger.info(f"{new_name}: converted hash routing table to I32, shape = {{{', '.join(str(n) for n in reversed(data.shape))}}}")
681+
self.gguf_writer.add_tensor(new_name, data)
682+
consumed.append(name)
683+
684+
return consumed
685+
686+
def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
687+
if self._dsv4_mxfp4_generated:
688+
return ()
689+
690+
consumed: list[str] = self._write_hash_routing_tensors()
691+
for bid in range(self.block_count):
692+
consumed.extend(self._write_mxfp4_expert_tensor(bid, "w1", gguf.MODEL_TENSOR.FFN_GATE_EXP))
693+
consumed.extend(self._write_mxfp4_expert_tensor(bid, "w2", gguf.MODEL_TENSOR.FFN_DOWN_EXP))
694+
consumed.extend(self._write_mxfp4_expert_tensor(bid, "w3", gguf.MODEL_TENSOR.FFN_UP_EXP))
695+
696+
for name in consumed:
697+
del self.model_tensors[name]
698+
699+
self._dsv4_mxfp4_generated = True
700+
return ()
701+
702+
def _format_dsv4_tensor_name(self, key: gguf.MODEL_TENSOR, bid: int | None, suffix: str = ".weight") -> str:
703+
return self.format_tensor_name(key, bid, suffix)
704+
705+
def _map_dsv4_tensor_name(self, name: str, bid: int | None) -> tuple[gguf.MODEL_TENSOR, str]:
706+
root_map: dict[str, tuple[gguf.MODEL_TENSOR, str]] = {
707+
"embed.weight": (gguf.MODEL_TENSOR.TOKEN_EMBD, ".weight"),
708+
"norm.weight": (gguf.MODEL_TENSOR.OUTPUT_NORM, ".weight"),
709+
"head.weight": (gguf.MODEL_TENSOR.OUTPUT, ".weight"),
710+
"hc_head_fn": (gguf.MODEL_TENSOR.HC_HEAD_FN, ""),
711+
"hc_head_base": (gguf.MODEL_TENSOR.HC_HEAD_BASE, ""),
712+
"hc_head_scale": (gguf.MODEL_TENSOR.HC_HEAD_SCALE, ""),
713+
}
714+
if name in root_map:
715+
return root_map[name]
716+
717+
match = re.match(r"layers\.(\d+)\.(.+)$", name)
718+
if match is None:
719+
raise ValueError(f"Unsupported DeepSeek-V4 tensor {name!r}")
720+
721+
layer = int(match.group(1))
722+
if bid != layer:
723+
raise ValueError(f"Tensor {name!r} parsed bid {bid} but layer name has {layer}")
724+
725+
layer_map: dict[str, tuple[gguf.MODEL_TENSOR, str]] = {
726+
"hc_attn_fn": (gguf.MODEL_TENSOR.HC_ATTN_FN, ""),
727+
"hc_attn_base": (gguf.MODEL_TENSOR.HC_ATTN_BASE, ""),
728+
"hc_attn_scale": (gguf.MODEL_TENSOR.HC_ATTN_SCALE, ""),
729+
"hc_ffn_fn": (gguf.MODEL_TENSOR.HC_FFN_FN, ""),
730+
"hc_ffn_base": (gguf.MODEL_TENSOR.HC_FFN_BASE, ""),
731+
"hc_ffn_scale": (gguf.MODEL_TENSOR.HC_FFN_SCALE, ""),
732+
"attn.attn_sink": (gguf.MODEL_TENSOR.ATTN_SINKS, ""),
733+
"attn.wq_a.weight": (gguf.MODEL_TENSOR.ATTN_Q_A, ".weight"),
734+
"attn.wq_b.weight": (gguf.MODEL_TENSOR.ATTN_Q_B, ".weight"),
735+
"attn.q_norm.weight": (gguf.MODEL_TENSOR.ATTN_Q_A_NORM, ".weight"),
736+
"attn.wkv.weight": (gguf.MODEL_TENSOR.ATTN_KV, ".weight"),
737+
"attn.kv_norm.weight": (gguf.MODEL_TENSOR.ATTN_KV_NORM, ".weight"),
738+
"attn.wo_a.weight": (gguf.MODEL_TENSOR.ATTN_OUT_A, ".weight"),
739+
"attn.wo_b.weight": (gguf.MODEL_TENSOR.ATTN_OUT_B, ".weight"),
740+
"attn.compressor.ape": (gguf.MODEL_TENSOR.ATTN_COMPRESSOR_APE, ""),
741+
"attn.compressor.wkv.weight": (gguf.MODEL_TENSOR.ATTN_COMPRESSOR_WKV, ".weight"),
742+
"attn.compressor.wgate.weight": (gguf.MODEL_TENSOR.ATTN_COMPRESSOR_WGATE, ".weight"),
743+
"attn.compressor.norm.weight": (gguf.MODEL_TENSOR.ATTN_COMPRESSOR_NORM, ".weight"),
744+
"attn.indexer.wq_b.weight": (gguf.MODEL_TENSOR.INDEXER_ATTN_Q_B, ".weight"),
745+
"attn.indexer.weights_proj.weight": (gguf.MODEL_TENSOR.INDEXER_PROJ, ".weight"),
746+
"attn.indexer.compressor.ape": (gguf.MODEL_TENSOR.INDEXER_COMPRESSOR_APE, ""),
747+
"attn.indexer.compressor.wkv.weight": (gguf.MODEL_TENSOR.INDEXER_COMPRESSOR_WKV, ".weight"),
748+
"attn.indexer.compressor.wgate.weight": (gguf.MODEL_TENSOR.INDEXER_COMPRESSOR_WGATE, ".weight"),
749+
"attn.indexer.compressor.norm.weight": (gguf.MODEL_TENSOR.INDEXER_COMPRESSOR_NORM, ".weight"),
750+
"attn_norm.weight": (gguf.MODEL_TENSOR.ATTN_NORM, ".weight"),
751+
"ffn_norm.weight": (gguf.MODEL_TENSOR.FFN_NORM, ".weight"),
752+
"ffn.gate.weight": (gguf.MODEL_TENSOR.FFN_GATE_INP, ".weight"),
753+
"ffn.gate.bias": (gguf.MODEL_TENSOR.FFN_EXP_PROBS_B, ".bias"),
754+
"ffn.gate.tid2eid": (gguf.MODEL_TENSOR.FFN_GATE_TID2EID, ""),
755+
"ffn.shared_experts.w1.weight": (gguf.MODEL_TENSOR.FFN_GATE_SHEXP, ".weight"),
756+
"ffn.shared_experts.w2.weight": (gguf.MODEL_TENSOR.FFN_DOWN_SHEXP, ".weight"),
757+
"ffn.shared_experts.w3.weight": (gguf.MODEL_TENSOR.FFN_UP_SHEXP, ".weight"),
758+
}
759+
760+
tensor_name = match.group(2)
761+
if tensor_name in layer_map:
762+
return layer_map[tensor_name]
763+
764+
if re.match(r"ffn\.experts\.\d+\.w[123]\.(weight|scale)$", tensor_name):
765+
return gguf.MODEL_TENSOR.FFN_GATE_EXP, ""
766+
767+
raise ValueError(f"Unsupported DeepSeek-V4 tensor {name!r}")
768+
769+
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
770+
if re.match(r"layers\.\d+\.ffn\.experts\.\d+\.w[123]\.(weight|scale)$", name):
771+
return []
772+
773+
tensor_key, suffix = self._map_dsv4_tensor_name(name, bid)
774+
if tensor_key == gguf.MODEL_TENSOR.FFN_GATE_TID2EID:
775+
return []
776+
elif tensor_key == gguf.MODEL_TENSOR.ATTN_OUT_A:
777+
data_torch = data_torch.reshape(self.hparams["o_groups"], self.hparams["o_lora_rank"], self.hparams["hidden_size"])
778+
779+
return [(self._format_dsv4_tensor_name(tensor_key, bid, suffix), data_torch)]
780+
781+
def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool:
782+
del new_name, bid # unused
783+
784+
if name in self._dsv4_fp8_dequantized and n_dims >= 2:
785+
return gguf.GGMLQuantizationType.Q8_0
786+
if name in self._dsv4_f32_tensors:
787+
return gguf.GGMLQuantizationType.F32
788+
if name in self._dsv4_bf16_tensors and n_dims >= 2:
789+
return gguf.GGMLQuantizationType.BF16
790+
791+
return False
792+
793+
def prepare_tensors(self):
794+
super().prepare_tensors()
795+
self._is_mxfp4 = True
796+
self.ftype = gguf.LlamaFileType.MOSTLY_MXFP4_MOE

0 commit comments

Comments
 (0)