Skip to content

Commit 3ded2db

Browse files
committed
Feat 22103. Update to resolve review comments #2
1 parent 0adede8 commit 3ded2db

1 file changed

Lines changed: 24 additions & 0 deletions

File tree

convert_hf_to_gguf.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13197,6 +13197,28 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
1319713197
yield from super().modify_tensors(data_torch, name, bid)
1319813198

1319913199

13200+
@ModelBase.register("Sarashina2VisionForCausalLM")
13201+
class Sarashina2VLTextModel(LlamaModel):
13202+
model_arch = gguf.MODEL_ARCH.LLAMA
13203+
13204+
def __init__(self, *args, **kwargs):
13205+
super().__init__(*args, **kwargs)
13206+
hparams = ModelBase.load_hparams(self.dir_model, is_mistral_format=False)
13207+
self.origin_hf_arch = hparams.get('architectures', [None])[0]
13208+
13209+
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
13210+
if name.startswith("llm."):
13211+
name = name[len("llm."):]
13212+
yield from super().modify_tensors(data_torch, name, bid)
13213+
13214+
13215+
@ModelBase.register("Sarashina2VisionForCausalLM")
13216+
class Sarashina2VLVisionModel(Qwen2VLVisionModel):
13217+
def __init__(self, *args, **kwargs):
13218+
super().__init__(*args, **kwargs)
13219+
self.global_config['model_type'] = "qwen2_vl"
13220+
13221+
1320013222
###### CONVERSION LOGIC ######
1320113223

1320213224

@@ -13454,6 +13476,8 @@ def get_model_architecture(hparams: dict[str, Any], model_type: ModelType) -> st
1345413476
# TODO: refactor this later to avoid adding exception here
1345513477
if model_type == ModelType.TEXT and arch == "StepVLForConditionalGeneration":
1345613478
return arch
13479+
if model_type == ModelType.TEXT and arch == "Sarashina2VisionForCausalLM":
13480+
return "Sarashina2VisionForCausalLM"
1345713481

1345813482
# if "architectures" is found in the sub-config, use that instead
1345913483
if model_type == ModelType.TEXT and text_config.get("architectures") is not None:

0 commit comments

Comments
 (0)