Skip to content

Commit ef4613b

Browse files
committed
samuraieng/sarashina22-00 Update to follow review comments
1 parent 5905b0a commit ef4613b

1 file changed

Lines changed: 71 additions & 9 deletions

File tree

convert_hf_to_gguf.py

Lines changed: 71 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2810,6 +2810,7 @@ def prepare_tensors(self):
28102810
"LlavaForConditionalGeneration",
28112811
"VoxtralForConditionalGeneration",
28122812
"IQuestCoderForCausalLM",
2813+
"Sarashina2VisionForCausalLM",
28132814
"LlamaModel")
28142815
class LlamaModel(TextModel):
28152816
model_arch = gguf.MODEL_ARCH.LLAMA
@@ -2955,13 +2956,6 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
29552956
else:
29562957
return
29572958

2958-
if self.origin_hf_arch.startswith('Sarashina2VisionForCausalLM'):
2959-
# Remove llm. from name
2960-
if name.startswith("llm."):
2961-
name = name[len("llm."):]
2962-
elif name.startswith("visual.") or name in ("norm.weight", "norm.bias"):
2963-
return #Skip processing "modify_tensors"
2964-
29652959
yield from super().modify_tensors(data_torch, name, bid)
29662960

29672961
def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
@@ -13117,6 +13111,74 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
1311713111
yield from super().modify_tensors(data_torch, name, bid)
1311813112

1311913113

13114+
@ModelBase.register("Sarashina2VisionForCausalLM")
class Sarashina2VLTextModel(LlamaModel):
    """Text (language-model) converter for Sarashina2-Vision checkpoints.

    The checkpoint is converted with the LLAMA architecture. Language-model
    weights carry an ``llm.`` prefix, which is stripped before the tensor is
    handed to ``LlamaModel``. Tensors that belong to the vision side are
    dropped here; ``visual.*`` tensors are converted by
    ``Sarashina2VLVisionModel`` instead.
    """
    model_arch = gguf.MODEL_ARCH.LLAMA

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Remember the first entry of "architectures" from the checkpoint
        # config (None when the key is absent).
        hparams = ModelBase.load_hparams(self.dir_model, is_mistral_format=False)
        self.origin_hf_arch = hparams.get('architectures', [None])[0]

    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
        """Yield the converted text tensors for one checkpoint tensor.

        Args:
            data_torch: The tensor data as loaded from the checkpoint.
            name: The tensor name as stored in the checkpoint.
            bid: Block (layer) index, or None when not applicable.

        Yields:
            Whatever ``LlamaModel.modify_tensors`` yields for the tensor with
            its ``llm.`` prefix removed; nothing for vision-side tensors.
        """
        # Vision-side tensors must not reach the Llama tensor mapping: the
        # "visual.*" weights have no text-model counterpart, and the un-prefixed
        # "norm.weight" / "norm.bias" pair is not part of the language model
        # (language-model weights are the ones under "llm.").
        if name.startswith("visual.") or name in ("norm.weight", "norm.bias"):
            return

        if name.startswith("llm."):
            name = name[len("llm."):]

        yield from super().modify_tensors(data_torch, name, bid)
13127+
13128+
13129+
@ModelBase.register("Sarashina2VisionForCausalLM")
class Sarashina2VLVisionModel(MmprojModel):
    """Multimodal-projector (mmproj) converter for Sarashina2-Vision checkpoints.

    The vision tower is written with the QWEN2VL projector type. Only tensors
    whose name starts with ``visual.`` are emitted; every other tensor
    (including the top-level ``norm.weight`` / ``norm.bias`` pair) produces no
    output from this class.
    """
    model_type = ModelType.MMPROJ

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        assert self.hparams_vision is not None
        vision_cfg = self.hparams_vision

        # Fall back to 560 when the vision config does not state an image size.
        vision_cfg.setdefault("image_size", 560)

        # Expose the config.json values under the key names the base class reads.
        vision_cfg["num_attention_heads"] = vision_cfg.get("num_heads")
        vision_cfg["num_hidden_layers"] = vision_cfg.get("depth")

        # Config layout marked "qwen2vl" by the original author: "hidden_size"
        # is re-published as "intermediate_size" and "embed_dim" becomes the
        # new "hidden_size". The order matters, the old value is read first.
        if "embed_dim" in vision_cfg:
            vision_cfg["intermediate_size"] = vision_cfg.get("hidden_size")
            vision_cfg["hidden_size"] = vision_cfg["embed_dim"]

    def set_gguf_parameters(self):
        """Write the base parameters plus the QWEN2VL-specific vision metadata."""
        super().set_gguf_parameters()
        writer = self.gguf_writer
        writer.add_clip_projector_type(gguf.VisionProjectorType.QWEN2VL)
        writer.add_vision_spatial_merge_size(2)
        # Layer-norm epsilon comes from the global config, defaulting to 1e-6.
        norm_eps = self.global_config.get("rms_norm_eps", 1e-6)
        writer.add_vision_attention_layernorm_eps(norm_eps)

    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
        """Yield the converted vision tensors for one checkpoint tensor.

        Args:
            data_torch: The tensor data as loaded from the checkpoint.
            name: The tensor name as stored in the checkpoint.
            bid: Block (layer) index, or None when not applicable.

        Yields:
            ``(name, tensor)`` pairs. Fused ``.qkv.`` tensors are split into
            three equal parts along the first dimension, the patch-embedding
            kernel is split into two tensors along its third dimension, and
            any other ``visual.*`` tensor is delegated to the base class.
        """
        assert self.hparams_vision is not None

        if not name.startswith("visual."):
            # Non-vision tensors are not part of the mmproj file.
            return

        if ".qkv." in name:
            # Works for both the 2-D weight and the 1-D bias: the fused
            # dimension is always the first one.
            fused_rows = data_torch.shape[0]
            assert fused_rows % 3 == 0
            rows = fused_rows // 3
            pieces = (
                ("q", data_torch[:rows]),
                ("k", data_torch[rows: rows * 2]),
                ("v", data_torch[rows * 2:]),
            )
            for part, chunk in pieces:
                yield from super().modify_tensors(chunk, name.replace("qkv", part), bid)
            return

        if 'patch_embed.proj.weight' in name:
            # Split the Conv3D kernel into two Conv2D kernels, one per
            # temporal slice. The unpacking also enforces a 5-D kernel.
            _, _, kt, _, _ = data_torch.shape
            assert kt == 2, "Current implementation only support temporal_patch_size of 2"
            patch_key = gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH]
            yield (patch_key + ".weight", data_torch[:, :, 0, ...])
            yield (patch_key + ".weight.1", data_torch[:, :, 1, ...])
            return

        yield from super().modify_tensors(data_torch, name, bid)
13180+
13181+
1312013182
###### CONVERSION LOGIC ######
1312113183

1312213184

@@ -13374,14 +13436,14 @@ def get_model_architecture(hparams: dict[str, Any], model_type: ModelType) -> st
1337413436
# TODO: refactor this later to avoid adding exception here
1337513437
if model_type == ModelType.TEXT and arch == "StepVLForConditionalGeneration":
1337613438
return arch
13439+
if model_type == ModelType.TEXT and arch == "Sarashina2VisionForCausalLM":
13440+
return "Sarashina2VisionForCausalLM"
1337713441

1337813442
# if "architectures" is found in the sub-config, use that instead
1337913443
if model_type == ModelType.TEXT and text_config.get("architectures") is not None:
1338013444
arch = text_config["architectures"][0]
1338113445
elif model_type == ModelType.MMPROJ and vision_config.get("architectures") is not None:
1338213446
arch = vision_config["architectures"][0]
13383-
if "Sarashina2VisionForCausalLM" in arch:
13384-
arch = "Qwen2VLForConditionalGeneration"
1338513447
if arch is None:
1338613448
raise ValueError("Failed to detect model architecture")
1338713449
return arch

0 commit comments

Comments
 (0)