|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
| 3 | +import json |
3 | 4 | import re |
4 | 5 |
|
5 | 6 | from typing import Any, Callable, Iterable, TYPE_CHECKING |
6 | 7 |
|
| 8 | +import numpy as np |
7 | 9 | import torch |
8 | 10 |
|
9 | 11 | if TYPE_CHECKING: |
10 | 12 | from torch import Tensor |
11 | 13 |
|
12 | | -from .base import MmprojModel, ModelBase, TextModel, gguf, logger |
| 14 | +from .base import LazyTorchTensor, MmprojModel, ModelBase, TextModel, gguf, logger |
13 | 15 |
|
14 | 16 | from .qwen import QwenModel |
15 | 17 |
|
@@ -459,3 +461,336 @@ def set_gguf_parameters(self): |
459 | 461 | self.gguf_writer.add_indexer_head_count(self.hparams["index_n_heads"]) |
460 | 462 | self.gguf_writer.add_indexer_key_length(self.hparams["index_head_dim"]) |
461 | 463 | self.gguf_writer.add_indexer_top_k(self.hparams["index_topk"]) |
| 464 | + |
| 465 | + |
@ModelBase.register("DeepseekV4ForCausalLM")
class DeepseekV4FlashModel(TextModel):
    """GGUF converter for ``DeepseekV4ForCausalLM`` checkpoints.

    What this class does, as far as this file shows:

    * drops every source tensor whose name starts with ``mtp.``;
    * replaces FP8 ``*.weight`` tensors that have a sibling ``*.scale`` tensor
      by a lazily dequantized float version (later stored as Q8_0);
    * repacks the routed-expert ``w1``/``w2``/``w3`` tensors of every layer
      into stacked raw MXFP4 blocks and writes them directly to the GGUF writer;
    * writes the per-layer ``ffn.gate.tid2eid`` hash routing tables as int32;
    * maps all remaining checkpoint tensor names to GGUF tensor names.
    """

    model_arch = gguf.MODEL_ARCH.DEEPSEEK_V4_FLASH

    # Counter lives on the class because filter_tensors() is a classmethod;
    # __init__ resets it for every new converter instance.
    _skipped_mtp_tensors = 0

    # Edge length of the square tile that shares one FP8 scale value.
    _DSV4_FP8_SCALE_BLOCK = 128

    # Precompiled patterns (previously looked up again for every tensor).
    _DSV4_LAYER_RE = re.compile(r"layers\.(\d+)\.(.+)$")
    _DSV4_ROUTED_EXPERT_RE = re.compile(r"layers\.\d+\.ffn\.experts\.\d+\.w[123]\.(weight|scale)$")
    _DSV4_ROUTED_EXPERT_IN_LAYER_RE = re.compile(r"ffn\.experts\.\d+\.w[123]\.(weight|scale)$")

    # Checkpoint name -> (GGUF tensor kind, GGUF name suffix) for tensors that
    # do not belong to a layer.
    _DSV4_ROOT_TENSORS = {
        "embed.weight": (gguf.MODEL_TENSOR.TOKEN_EMBD, ".weight"),
        "norm.weight": (gguf.MODEL_TENSOR.OUTPUT_NORM, ".weight"),
        "head.weight": (gguf.MODEL_TENSOR.OUTPUT, ".weight"),
        "hc_head_fn": (gguf.MODEL_TENSOR.HC_HEAD_FN, ""),
        "hc_head_base": (gguf.MODEL_TENSOR.HC_HEAD_BASE, ""),
        "hc_head_scale": (gguf.MODEL_TENSOR.HC_HEAD_SCALE, ""),
    }

    # Same mapping for names relative to ``layers.<bid>.``.
    _DSV4_LAYER_TENSORS = {
        "hc_attn_fn": (gguf.MODEL_TENSOR.HC_ATTN_FN, ""),
        "hc_attn_base": (gguf.MODEL_TENSOR.HC_ATTN_BASE, ""),
        "hc_attn_scale": (gguf.MODEL_TENSOR.HC_ATTN_SCALE, ""),
        "hc_ffn_fn": (gguf.MODEL_TENSOR.HC_FFN_FN, ""),
        "hc_ffn_base": (gguf.MODEL_TENSOR.HC_FFN_BASE, ""),
        "hc_ffn_scale": (gguf.MODEL_TENSOR.HC_FFN_SCALE, ""),
        "attn.attn_sink": (gguf.MODEL_TENSOR.ATTN_SINKS, ""),
        "attn.wq_a.weight": (gguf.MODEL_TENSOR.ATTN_Q_A, ".weight"),
        "attn.wq_b.weight": (gguf.MODEL_TENSOR.ATTN_Q_B, ".weight"),
        "attn.q_norm.weight": (gguf.MODEL_TENSOR.ATTN_Q_A_NORM, ".weight"),
        "attn.wkv.weight": (gguf.MODEL_TENSOR.ATTN_KV, ".weight"),
        "attn.kv_norm.weight": (gguf.MODEL_TENSOR.ATTN_KV_NORM, ".weight"),
        "attn.wo_a.weight": (gguf.MODEL_TENSOR.ATTN_OUT_A, ".weight"),
        "attn.wo_b.weight": (gguf.MODEL_TENSOR.ATTN_OUT_B, ".weight"),
        "attn.compressor.ape": (gguf.MODEL_TENSOR.ATTN_COMPRESSOR_APE, ""),
        "attn.compressor.wkv.weight": (gguf.MODEL_TENSOR.ATTN_COMPRESSOR_WKV, ".weight"),
        "attn.compressor.wgate.weight": (gguf.MODEL_TENSOR.ATTN_COMPRESSOR_WGATE, ".weight"),
        "attn.compressor.norm.weight": (gguf.MODEL_TENSOR.ATTN_COMPRESSOR_NORM, ".weight"),
        "attn.indexer.wq_b.weight": (gguf.MODEL_TENSOR.INDEXER_ATTN_Q_B, ".weight"),
        "attn.indexer.weights_proj.weight": (gguf.MODEL_TENSOR.INDEXER_PROJ, ".weight"),
        "attn.indexer.compressor.ape": (gguf.MODEL_TENSOR.INDEXER_COMPRESSOR_APE, ""),
        "attn.indexer.compressor.wkv.weight": (gguf.MODEL_TENSOR.INDEXER_COMPRESSOR_WKV, ".weight"),
        "attn.indexer.compressor.wgate.weight": (gguf.MODEL_TENSOR.INDEXER_COMPRESSOR_WGATE, ".weight"),
        "attn.indexer.compressor.norm.weight": (gguf.MODEL_TENSOR.INDEXER_COMPRESSOR_NORM, ".weight"),
        "attn_norm.weight": (gguf.MODEL_TENSOR.ATTN_NORM, ".weight"),
        "ffn_norm.weight": (gguf.MODEL_TENSOR.FFN_NORM, ".weight"),
        "ffn.gate.weight": (gguf.MODEL_TENSOR.FFN_GATE_INP, ".weight"),
        "ffn.gate.bias": (gguf.MODEL_TENSOR.FFN_EXP_PROBS_B, ".bias"),
        "ffn.gate.tid2eid": (gguf.MODEL_TENSOR.FFN_GATE_TID2EID, ""),
        "ffn.shared_experts.w1.weight": (gguf.MODEL_TENSOR.FFN_GATE_SHEXP, ".weight"),
        "ffn.shared_experts.w2.weight": (gguf.MODEL_TENSOR.FFN_DOWN_SHEXP, ".weight"),
        "ffn.shared_experts.w3.weight": (gguf.MODEL_TENSOR.FFN_UP_SHEXP, ".weight"),
    }

    def __init__(self, *args, **kwargs):
        """Set up converter state, merge raw ``config.json`` keys and record source dtypes.

        All positional and keyword arguments are forwarded unchanged to the
        base-class constructor.
        """
        type(self)._skipped_mtp_tensors = 0

        # Create the bookkeeping state *before* the base constructor runs.
        # The base constructor invokes subclass hooks (the filter_tensors()
        # counter is reset above and read below). If it also calls
        # dequant_model() -- NOTE(review): confirm against ModelBase -- that hook
        # needs these attributes to exist and its results must not be
        # overwritten afterwards.
        self._dsv4_fp8_dequantized: set[str] = set()
        self._dsv4_bf16_tensors: set[str] = set()
        self._dsv4_f32_tensors: set[str] = set()
        self._dsv4_mxfp4_generated = False

        super().__init__(*args, **kwargs)

        # Expose every key of the raw config.json that the already loaded
        # hparams do not contain; existing values are never overridden.
        with open(self.dir_model / "config.json", "r", encoding="utf-8") as f:
            raw_hparams = json.load(f)
        for key, value in raw_hparams.items():
            self.hparams.setdefault(key, value)

        self.block_count = self.hparams["num_hidden_layers"]
        self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)

        self._collect_source_dtypes()

        if type(self)._skipped_mtp_tensors:
            logger.info("Skipping %d DeepSeek-V4 MTP tensor(s) for conversion v0", type(self)._skipped_mtp_tensors)

    @classmethod
    def filter_tensors(cls, item: tuple[str, Callable[[], Tensor]]) -> tuple[str, Callable[[], Tensor]] | None:
        """Drop ``mtp.*`` tensors (counting them) and defer everything else to the base class.

        Returns ``None`` for a dropped tensor, otherwise whatever the base
        implementation returns for ``item``.
        """
        name, _ = item
        if name.startswith("mtp."):
            cls._skipped_mtp_tensors += 1
            return None
        return super().filter_tensors(item)

    def set_vocab(self):
        """Write the vocabulary using the base class's GPT-2 vocabulary helper."""
        self._set_vocab_gpt2()

    @staticmethod
    def _float8_dtypes() -> tuple[torch.dtype, ...]:
        """Return the FP8 dtypes (e4m3fn, e5m2) that the installed torch provides."""
        candidates = (
            getattr(torch, "float8_e4m3fn", None),
            getattr(torch, "float8_e5m2", None),
        )
        return tuple(dtype for dtype in candidates if dtype is not None)

    @staticmethod
    def _e8m0_to_float(scale: Tensor) -> Tensor:
        """Convert a scale tensor to float32.

        Tensors typed ``torch.float8_e8m0fnu`` (when torch has that dtype) are
        converted with ``.float()``. Any other tensor is reinterpreted as
        uint8 and decoded as ``2 ** (byte - 127)``.
        """
        torch_float8_e8m0 = getattr(torch, "float8_e8m0fnu", None)
        if torch_float8_e8m0 is not None and scale.dtype == torch_float8_e8m0:
            return scale.float()

        # assumes one byte per element here -- TODO confirm no multi-byte scale
        # dtype can reach this branch, since view() would reinterpret its bytes.
        bits = scale.view(torch.uint8).float()
        return torch.pow(torch.tensor(2.0, device=bits.device), bits - 127.0)

    def _collect_source_dtypes(self) -> None:
        """Remember which source tensors are bf16 or f32 for tensor_force_quant()."""
        for name, gen in self.model_tensors.items():
            dtype = gen().dtype
            if dtype == torch.bfloat16:
                self._dsv4_bf16_tensors.add(name)
            elif dtype == torch.float32:
                self._dsv4_f32_tensors.add(name)

    def set_gguf_parameters(self):
        """Write all model hyper-parameters to the GGUF writer.

        Raises ``KeyError`` when a required hyper-parameter is missing from
        ``self.hparams``.
        """
        hparams = self.hparams
        writer = self.gguf_writer
        arch = gguf.MODEL_ARCH_NAMES[self.model_arch]

        # Core transformer geometry.
        writer.add_block_count(self.block_count)
        writer.add_context_length(hparams["max_position_embeddings"])
        writer.add_embedding_length(hparams["hidden_size"])
        writer.add_vocab_size(hparams["vocab_size"])
        writer.add_head_count(hparams["num_attention_heads"])
        writer.add_head_count_kv(hparams["num_key_value_heads"])
        writer.add_key_length(hparams["head_dim"])
        writer.add_value_length(hparams["head_dim"])
        writer.add_rope_dimension_count(hparams["qk_rope_head_dim"])
        writer.add_rope_freq_base(hparams["rope_theta"])
        writer.add_q_lora_rank(hparams["q_lora_rank"])
        writer.add_sliding_window(hparams["sliding_window"])
        writer.add_layer_norm_rms_eps(hparams["rms_norm_eps"])

        # RoPE scaling: only "yarn" is recognised, anything else is written as NONE.
        rope_scaling = hparams.get("rope_scaling") or {}
        rope_type = rope_scaling.get("type", rope_scaling.get("rope_type"))
        if rope_type == "yarn":
            writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
            writer.add_rope_scaling_factor(rope_scaling["factor"])
            writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])
            if (yarn_beta_fast := rope_scaling.get("beta_fast")) is not None:
                writer.add_rope_scaling_yarn_beta_fast(yarn_beta_fast)
            if (yarn_beta_slow := rope_scaling.get("beta_slow")) is not None:
                writer.add_rope_scaling_yarn_beta_slow(yarn_beta_slow)
        else:
            writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)

        # Mixture-of-experts configuration; the same swiglu limit is repeated per layer.
        writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"])
        writer.add_expert_count(hparams["n_routed_experts"])
        writer.add_expert_used_count(hparams["num_experts_per_tok"])
        writer.add_expert_shared_count(hparams["n_shared_experts"])
        writer.add_expert_weights_scale(hparams["routed_scaling_factor"])
        writer.add_expert_weights_norm(hparams["norm_topk_prob"])
        writer.add_swiglu_clamp_exp([hparams["swiglu_limit"]] * self.block_count)
        writer.add_swiglu_clamp_shexp([hparams["swiglu_limit"]] * self.block_count)

        # Indexer parameters.
        writer.add_indexer_head_count(hparams["index_n_heads"])
        writer.add_indexer_key_length(hparams["index_head_dim"])
        writer.add_indexer_top_k(hparams["index_topk"])

        # Architecture-specific keys without a dedicated writer method.
        writer.add_uint32(f"{arch}.attention.o_group_count", hparams["o_groups"])
        writer.add_uint32(f"{arch}.attention.o_lora_rank", hparams["o_lora_rank"])
        writer.add_array(f"{arch}.attention.compress_ratios", hparams["compress_ratios"])
        writer.add_float32(f"{arch}.attention.compress_rope.freq_base", hparams["compress_rope_theta"])
        writer.add_uint32(f"{arch}.hc.mult", hparams["hc_mult"])
        writer.add_uint32(f"{arch}.hc.sinkhorn_iters", hparams["hc_sinkhorn_iters"])
        writer.add_float32(f"{arch}.hc.eps", hparams["hc_eps"])
        writer.add_uint32(f"{arch}.moe.hash_layer_count", hparams["num_hash_layers"])
        writer.add_string(f"{arch}.moe.score_func", hparams["scoring_func"])
        writer.add_string(f"{arch}.moe.topk_method", hparams["topk_method"])

        writer.add_file_type(self.ftype)
        logger.info(f"gguf: file type = {self.ftype}")

    def _dsv4_dequant_fp8_weight(self, weight: Tensor, scale: Tensor) -> Tensor:
        """Return ``weight`` as float32 multiplied by its tile-wise expanded scale.

        ``weight`` must be two-dimensional. Each scale entry is repeated over a
        ``_DSV4_FP8_SCALE_BLOCK`` x ``_DSV4_FP8_SCALE_BLOCK`` tile, and the
        expanded scale is cropped to the weight's shape.
        """
        out_features, in_features = weight.shape
        tile = self._DSV4_FP8_SCALE_BLOCK
        scale_f = self._e8m0_to_float(scale)
        scale_f = scale_f.repeat_interleave(tile, 0)[:out_features]
        scale_f = scale_f.repeat_interleave(tile, 1)[:, :in_features]
        return weight.float() * scale_f

    def dequant_model(self):
        """Swap FP8 weights for lazily dequantized ones and drop their scale tensors.

        Only ``X.scale`` tensors whose ``X.weight`` sibling exists and has an
        FP8 dtype are handled; all other tensors (including routed-expert
        scales with non-FP8 weights) are left untouched.
        """
        fp8_dtypes = self._float8_dtypes()
        tensors_to_remove: list[str] = []

        # Iterate over a snapshot of the keys: entries are replaced in the loop.
        for name in list(self.model_tensors.keys()):
            if not name.endswith(".scale"):
                continue
            weight_name = name.removesuffix(".scale") + ".weight"
            if weight_name not in self.model_tensors:
                continue

            weight = self.model_tensors[weight_name]
            scale = self.model_tensors[name]
            if weight().dtype not in fp8_dtypes:
                continue

            # Default arguments bind the current generators; a plain closure
            # would see the values of the last loop iteration.
            self.model_tensors[weight_name] = lambda w=weight, s=scale: self._dsv4_dequant_fp8_weight(w(), s())
            self._dsv4_fp8_dequantized.add(weight_name)
            tensors_to_remove.append(name)

        for name in tensors_to_remove:
            del self.model_tensors[name]

    @staticmethod
    def _pack_mxfp4_blocks(weight: Tensor, scale: Tensor) -> np.ndarray:
        """Re-order one packed 4-bit weight matrix into raw ggml MXFP4 blocks.

        ``weight`` is viewed as bytes of shape ``(rows, cols / 2)`` holding two
        values per byte; ``scale`` is viewed as bytes and must have shape
        ``(rows, cols / 32)``. The result is a uint8 array of shape
        ``(rows, n_blocks * 17)``: for every block one scale byte followed by
        16 bytes of quantized values.

        Raises ``ValueError`` when the row length is not a multiple of 32 or
        the scale shape does not match.
        """
        packed = weight.contiguous().view(torch.uint8)
        scale_u8 = scale.contiguous().view(torch.uint8)

        out_features, packed_cols = packed.shape
        logical_cols = packed_cols * 2
        if logical_cols % 32 != 0:
            raise ValueError(f"MXFP4 source row has {logical_cols} values, expected a multiple of 32")

        n_blocks = logical_cols // 32
        if tuple(scale_u8.shape) != (out_features, n_blocks):
            raise ValueError(f"MXFP4 scale shape {tuple(scale_u8.shape)} does not match {(out_features, n_blocks)}")

        src = packed.reshape(out_features, n_blocks, 16)
        low = src & 0x0F
        high = (src >> 4) & 0x0F

        # The safetensors bytes store adjacent values as low/high nibbles.
        # ggml MXFP4 blocks store values 0..15 in low nibbles and 16..31 in high nibbles.
        vals = torch.stack((low, high), dim=-1).reshape(out_features, n_blocks, 32)
        qs = vals[:, :, :16] | (vals[:, :, 16:] << 4)
        raw = torch.cat((scale_u8.unsqueeze(-1), qs.to(torch.uint8)), dim=-1)
        return raw.reshape(out_features, n_blocks * 17).cpu().numpy()

    def _write_mxfp4_expert_tensor(self, bid: int, proj: str, tensor_key: gguf.MODEL_TENSOR) -> list[str]:
        """Stack one projection of all routed experts of a layer and write it as MXFP4.

        Args:
            bid: layer index.
            proj: checkpoint projection name (``"w1"``, ``"w2"`` or ``"w3"``).
            tensor_key: GGUF tensor kind the stacked tensor is written as.

        Returns:
            The checkpoint tensor names (weights and scales) that were consumed.

        Raises:
            KeyError: a weight or scale tensor of some expert is missing.
            ValueError: there are no routed experts to pack, or a tensor has an
                unexpected shape (see ``_pack_mxfp4_blocks``).
        """
        n_experts = self.hparams["n_routed_experts"]
        data: np.ndarray | None = None
        consumed: list[str] = []

        for eid in range(n_experts):
            weight_name = f"layers.{bid}.ffn.experts.{eid}.{proj}.weight"
            scale_name = f"layers.{bid}.ffn.experts.{eid}.{proj}.scale"
            if weight_name not in self.model_tensors or scale_name not in self.model_tensors:
                raise KeyError(f"Missing routed expert tensors for {weight_name}")

            weight = LazyTorchTensor.to_eager(self.model_tensors[weight_name]())
            scale = LazyTorchTensor.to_eager(self.model_tensors[scale_name]())
            packed = self._pack_mxfp4_blocks(weight, scale)
            if data is None:
                # Allocate the stacked buffer once the per-expert shape is known.
                data = np.empty((n_experts, *packed.shape), dtype=packed.dtype)
            data[eid] = packed
            consumed.extend((weight_name, scale_name))

        # Explicit check instead of `assert`, which is removed under `python -O`.
        if data is None:
            raise ValueError(f"No routed experts to pack for layer {bid} ({proj}): n_routed_experts = {n_experts}")

        new_name = self.format_tensor_name(tensor_key, bid)
        shape = gguf.quant_shape_from_byte_shape(data.shape, gguf.GGMLQuantizationType.MXFP4)
        logger.info(f"{new_name}: repacked routed experts to MXFP4, shape = {{{', '.join(str(n) for n in reversed(shape))}}}")
        self.gguf_writer.add_tensor(new_name, data, raw_dtype=gguf.GGMLQuantizationType.MXFP4)

        return consumed

    def _write_hash_routing_tensors(self) -> list[str]:
        """Write the ``tid2eid`` table of the first ``num_hash_layers`` layers as int32.

        Returns the consumed checkpoint tensor names. Raises ``KeyError`` when
        a table is missing.
        """
        consumed: list[str] = []

        for bid in range(self.hparams["num_hash_layers"]):
            name = f"layers.{bid}.ffn.gate.tid2eid"
            if name not in self.model_tensors:
                raise KeyError(f"Missing hash routing tensor {name}")

            data_torch = LazyTorchTensor.to_eager(self.model_tensors[name]())
            data = data_torch.to(torch.int32).cpu().numpy()
            new_name = self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_TID2EID, bid, "")
            logger.info(f"{new_name}: converted hash routing table to I32, shape = {{{', '.join(str(n) for n in reversed(data.shape))}}}")
            self.gguf_writer.add_tensor(new_name, data)
            consumed.append(name)

        return consumed

    def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
        """Write hash routing tables and MXFP4 routed experts straight to the writer.

        The tensors are added to ``self.gguf_writer`` directly, so nothing is
        yielded. Consumed source tensors are removed from
        ``self.model_tensors``. The work is done only on the first call.
        """
        if self._dsv4_mxfp4_generated:
            return ()

        consumed: list[str] = self._write_hash_routing_tensors()
        expert_projections = (
            ("w1", gguf.MODEL_TENSOR.FFN_GATE_EXP),
            ("w2", gguf.MODEL_TENSOR.FFN_DOWN_EXP),
            ("w3", gguf.MODEL_TENSOR.FFN_UP_EXP),
        )
        for bid in range(self.block_count):
            for proj, tensor_key in expert_projections:
                consumed.extend(self._write_mxfp4_expert_tensor(bid, proj, tensor_key))

        for name in consumed:
            del self.model_tensors[name]

        self._dsv4_mxfp4_generated = True
        return ()

    def _format_dsv4_tensor_name(self, key: gguf.MODEL_TENSOR, bid: int | None, suffix: str = ".weight") -> str:
        """Format a GGUF tensor name; thin wrapper around ``format_tensor_name``."""
        return self.format_tensor_name(key, bid, suffix)

    def _map_dsv4_tensor_name(self, name: str, bid: int | None) -> tuple[gguf.MODEL_TENSOR, str]:
        """Map a checkpoint tensor name to ``(GGUF tensor kind, name suffix)``.

        Raises ``ValueError`` for unknown names and when ``bid`` disagrees with
        the layer index embedded in ``name``.
        """
        root_entry = self._DSV4_ROOT_TENSORS.get(name)
        if root_entry is not None:
            return root_entry

        match = self._DSV4_LAYER_RE.match(name)
        if match is None:
            raise ValueError(f"Unsupported DeepSeek-V4 tensor {name!r}")

        layer = int(match.group(1))
        if bid != layer:
            raise ValueError(f"Tensor {name!r} parsed bid {bid} but layer name has {layer}")

        tensor_name = match.group(2)
        layer_entry = self._DSV4_LAYER_TENSORS.get(tensor_name)
        if layer_entry is not None:
            return layer_entry

        # Routed expert weights/scales are all reported as FFN_GATE_EXP here;
        # modify_tensors() skips them before it ever asks for a mapping.
        if self._DSV4_ROUTED_EXPERT_IN_LAYER_RE.match(tensor_name):
            return gguf.MODEL_TENSOR.FFN_GATE_EXP, ""

        raise ValueError(f"Unsupported DeepSeek-V4 tensor {name!r}")

    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
        """Rename one source tensor to its GGUF name.

        Routed-expert tensors and hash routing tables yield nothing because
        generate_extra_tensors() writes them. ``attn.wo_a.weight`` is reshaped
        to ``(o_groups, o_lora_rank, hidden_size)``.
        """
        if self._DSV4_ROUTED_EXPERT_RE.match(name):
            return []

        tensor_key, suffix = self._map_dsv4_tensor_name(name, bid)
        if tensor_key == gguf.MODEL_TENSOR.FFN_GATE_TID2EID:
            return []

        if tensor_key == gguf.MODEL_TENSOR.ATTN_OUT_A:
            data_torch = data_torch.reshape(self.hparams["o_groups"], self.hparams["o_lora_rank"], self.hparams["hidden_size"])

        return [(self._format_dsv4_tensor_name(tensor_key, bid, suffix), data_torch)]

    def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool:
        """Choose the output type of a tensor from the dtype of its source.

        Dequantized FP8 tensors with at least two dimensions become Q8_0,
        source-f32 tensors stay F32, and source-bf16 tensors with at least two
        dimensions stay BF16. ``False`` is returned for everything else.
        """
        del new_name, bid  # unused

        multi_dim = n_dims >= 2
        if multi_dim and name in self._dsv4_fp8_dequantized:
            return gguf.GGMLQuantizationType.Q8_0
        if name in self._dsv4_f32_tensors:
            return gguf.GGMLQuantizationType.F32
        if multi_dim and name in self._dsv4_bf16_tensors:
            return gguf.GGMLQuantizationType.BF16

        return False

    def prepare_tensors(self):
        """Run the base tensor preparation, then mark the output as MXFP4 MoE."""
        super().prepare_tensors()
        self._is_mxfp4 = True
        self.ftype = gguf.LlamaFileType.MOSTLY_MXFP4_MOE
0 commit comments