Skip to content

Commit 512b3bb

Browse files
committed
Feat 22103. Update to resolve review comments
1 parent e583f3b commit 512b3bb

1 file changed

Lines changed: 71 additions & 0 deletions

File tree

convert_hf_to_gguf.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2815,6 +2815,7 @@ def prepare_tensors(self):
28152815
"LlavaForConditionalGeneration",
28162816
"VoxtralForConditionalGeneration",
28172817
"IQuestCoderForCausalLM",
2818+
"Sarashina2VisionForCausalLM",
28182819
"LlamaModel")
28192820
class LlamaModel(TextModel):
28202821
model_arch = gguf.MODEL_ARCH.LLAMA
@@ -13197,6 +13198,74 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
1319713198
yield from super().modify_tensors(data_torch, name, bid)
1319813199

1319913200

13201+
@ModelBase.register("Sarashina2VisionForCausalLM")
class Sarashina2VLTextModel(LlamaModel):
    """Text-side converter registered for `Sarashina2VisionForCausalLM`.

    Reuses the `LlamaModel` conversion path and targets the LLAMA GGUF
    architecture. The only adjustments made here are recording the
    architecture name found in the checkpoint config and removing the
    `llm.` prefix from tensor names before delegating to the parent class.
    """

    model_arch = gguf.MODEL_ARCH.LLAMA

    def __init__(self, *args, **kwargs):
        """Run the parent initialisation, then record the configured architecture.

        All positional and keyword arguments are forwarded unchanged to
        `LlamaModel.__init__`.
        """
        super().__init__(*args, **kwargs)
        # Re-read the hyperparameters from the model directory (non-Mistral
        # format) and keep the first listed architecture; when the key is
        # absent the stored value is None.
        root_config = ModelBase.load_hparams(self.dir_model, is_mistral_format=False)
        declared_archs = root_config.get('architectures', [None])
        self.origin_hf_arch = declared_archs[0]

    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
        """Yield whatever the parent produces for the tensor, minus the `llm.` prefix.

        Args:
            data_torch: tensor data, passed through untouched.
            name: checkpoint tensor name; a leading `llm.` is stripped.
            bid: block index (or None), passed through untouched.
        """
        prefix = "llm."
        stripped = name[len(prefix):] if name.startswith(prefix) else name
        yield from super().modify_tensors(data_torch, stripped, bid)
13214+
13215+
13216+
@ModelBase.register("Sarashina2VisionForCausalLM")
class Sarashina2VLVisionModel(MmprojModel):
    """Multimodal-projector (vision) converter for `Sarashina2VisionForCausalLM`.

    Normalises the vision hyperparameter names, writes the vision-specific
    GGUF metadata, and rewrites the `visual.*` tensors: fused QKV tensors are
    split into separate Q/K/V tensors and the patch-embedding kernel is split
    into two slices along its third axis.
    """

    model_type = ModelType.MMPROJ

    def __init__(self, *args, **kwargs):
        """Run the parent initialisation, then alias the vision config keys.

        All positional and keyword arguments are forwarded unchanged to
        `MmprojModel.__init__`.
        """
        super().__init__(*args, **kwargs)
        vision_cfg = self.hparams_vision
        assert vision_cfg is not None

        # Fall back to an image size of 560 when the config does not give one.
        vision_cfg.setdefault("image_size", 560)

        # Copy config.json values to the key names used downstream
        # (a missing source key results in None being stored).
        for target_key, source_key in (
            ("num_attention_heads", "num_heads"),
            ("num_hidden_layers", "depth"),
        ):
            vision_cfg[target_key] = vision_cfg.get(source_key)

        # The original author tagged this branch "qwen2vl": when `embed_dim`
        # is present, the old `hidden_size` becomes `intermediate_size` and
        # `embed_dim` takes over as `hidden_size`. Order matters: the old
        # `hidden_size` must be read before it is overwritten.
        if "embed_dim" in vision_cfg:
            previous_hidden = vision_cfg.get("hidden_size")
            vision_cfg["intermediate_size"] = previous_hidden
            vision_cfg["hidden_size"] = vision_cfg.get("embed_dim")

    def set_gguf_parameters(self):
        """Write the parent's metadata plus the vision-specific GGUF fields."""
        super().set_gguf_parameters()
        writer = self.gguf_writer
        writer.add_clip_projector_type(gguf.VisionProjectorType.QWEN2VL)
        writer.add_vision_spatial_merge_size(2)
        # Epsilon comes from the global config, defaulting to 1e-6.
        layernorm_eps = self.global_config.get("rms_norm_eps", 1e-6)
        writer.add_vision_attention_layernorm_eps(layernorm_eps)

    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
        """Yield the converted `(name, tensor)` pairs for one checkpoint tensor.

        Args:
            data_torch: tensor data read from the checkpoint.
            name: checkpoint tensor name.
            bid: block index (or None), forwarded to the parent mapping.

        Only names starting with `visual.` produce output. Anything else
        yields nothing; the original author's note says two tensors are
        ignored this way and names `norm.weight` / `norm.bias`.
        """
        assert self.hparams_vision is not None

        if not name.startswith("visual."):
            return

        if ".qkv." in name:
            # Fused QKV: the leading dimension holds Q, K and V stacked in
            # that order, for both 2-D weights and 1-D biases.
            fused_rows = data_torch.shape[0]
            assert fused_rows % 3 == 0
            rows_each = fused_rows // 3
            for position, part in enumerate(("q", "k", "v")):
                begin = position * rows_each
                piece = data_torch[begin:begin + rows_each]
                yield from super().modify_tensors(piece, name.replace("qkv", part), bid)
        elif 'patch_embed.proj.weight' in name:
            # Split the Conv3D kernel into two Conv2D kernels, one per index
            # of the third (temporal) axis. Unpacking also enforces rank 5.
            _, _, kt, _, _ = data_torch.shape
            assert kt == 2, "Current implementation only support temporal_patch_size of 2"
            patch_tensor_name = gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH]
            yield (patch_tensor_name + ".weight", data_torch[:, :, 0, ...])
            yield (patch_tensor_name + ".weight.1", data_torch[:, :, 1, ...])
        else:
            yield from super().modify_tensors(data_torch, name, bid)
13267+
13268+
1320013269
###### CONVERSION LOGIC ######
1320113270

1320213271

@@ -13454,6 +13523,8 @@ def get_model_architecture(hparams: dict[str, Any], model_type: ModelType) -> st
1345413523
# TODO: refactor this later to avoid adding exception here
1345513524
if model_type == ModelType.TEXT and arch == "StepVLForConditionalGeneration":
1345613525
return arch
13526+
if model_type == ModelType.TEXT and arch == "Sarashina2VisionForCausalLM":
13527+
return "Sarashina2VisionForCausalLM"
1345713528

1345813529
# if "architectures" is found in the sub-config, use that instead
1345913530
if model_type == ModelType.TEXT and text_config.get("architectures") is not None:

0 commit comments

Comments
 (0)