forked from ggml-org/llama.cpp
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmistral3.py
More file actions
67 lines (50 loc) · 2.25 KB
/
Copy pathmistral3.py
File metadata and controls
67 lines (50 loc) · 2.25 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
from __future__ import annotations
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from torch import Tensor
from .base import ModelBase, TextModel, gguf
from .deepseek import DeepseekV2Model
from .llama import LlamaModel
@ModelBase.register(
    "Mistral3ForConditionalGeneration",
    "Ministral3ForCausalLM",
)
class Mistral3Model(TextModel):
    """Dispatching converter for the two registered Mistral architecture names.

    The class itself performs no conversion work.  At construction time it
    inspects ``hparams["model_type"]`` and builds one of the nested
    implementation classes, then forwards every conversion hook to it.
    """

    class Ministral3Model(LlamaModel):
        """Llama-based implementation, used whenever model_type is not "mistral4"."""

        model_arch = gguf.MODEL_ARCH.MISTRAL3

        def set_gguf_parameters(self):
            """Write the parent parameters, plus extra rope keys for "ministral3".

            When ``hparams["model_type"]`` is ``"ministral3"``, the values
            ``rope_parameters["mscale_all_dim"]`` and
            ``rope_parameters["llama_4_scaling_beta"]`` are handed to
            ``gguf_writer.add_rope_scaling_yarn_log_mul`` and
            ``gguf_writer.add_attn_temperature_scale`` respectively.
            """
            super().set_gguf_parameters()

            # Read before the model_type check, same as the original ordering.
            rope_cfg = self.rope_parameters
            if self.hparams.get("model_type") != "ministral3":
                return

            assert rope_cfg, "ministral3 must have 'rope_parameters' config"
            assert rope_cfg["rope_type"] == "yarn", "ministral3 rope_type must be 'yarn'"

            writer = self.gguf_writer
            writer.add_rope_scaling_yarn_log_mul(rope_cfg["mscale_all_dim"])
            writer.add_attn_temperature_scale(rope_cfg["llama_4_scaling_beta"])

    class Mistral4Model(DeepseekV2Model):
        """DeepseekV2-based implementation, used when model_type is "mistral4"."""

        model_arch = gguf.MODEL_ARCH.MISTRAL4
        skip_mtp = False  # model contains no MTP layers, so no need to skip
        merge_expert = False  # experts are already stacked as 3D

        def modify_tensors(self, data_torch, name, bid):
            """Forward tensors to the parent after normalising expert names.

            Names ending in ".down_proj" or ".gate_up_proj" get ".weight"
            appended before being passed to the parent implementation.
            """
            if name.endswith((".down_proj", ".gate_up_proj")):
                name = f"{name}.weight"
            yield from super().modify_tensors(data_torch, name, bid)

    model_arch = gguf.MODEL_ARCH.MISTRAL3  # unused
    # Concrete converter that every hook below delegates to.
    impl: TextModel

    def __init__(self, *args, **kwargs):
        """Initialise the wrapper, then build the matching implementation.

        The same ``*args`` / ``**kwargs`` are passed both to the base-class
        initialiser and to the selected nested implementation class.
        """
        super().__init__(*args, **kwargs)
        impl_cls = (
            Mistral3Model.Mistral4Model
            if self.hparams.get("model_type") == "mistral4"
            else Mistral3Model.Ministral3Model
        )
        self.impl = impl_cls(*args, **kwargs)

    def set_vocab(self):
        """Delegate to ``impl.set_vocab``."""
        self.impl.set_vocab()

    def set_gguf_parameters(self):
        """Delegate to ``impl.set_gguf_parameters``."""
        self.impl.set_gguf_parameters()

    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
        """Yield whatever ``impl.modify_tensors`` yields for the given tensor."""
        yield from self.impl.modify_tensors(data_torch, name, bid)

    def prepare_tensors(self):
        """Delegate to ``impl.prepare_tensors``."""
        self.impl.prepare_tensors()

    def write_vocab(self):
        """Delegate to ``impl.write_vocab``."""
        self.impl.write_vocab()

    def write(self):
        """Delegate to ``impl.write``."""
        self.impl.write()