Skip to content

Commit cf7de9d

Browse files
committed
Follow comments
1 parent 0c1f7d8 commit cf7de9d

1 file changed

Lines changed: 6 additions & 10 deletions

File tree

convert_hf_to_gguf.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13201,15 +13201,13 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
1320113201
class Sarashina2VLTextModel(LlamaModel):
1320213202
model_arch = gguf.MODEL_ARCH.LLAMA
1320313203

13204-
def __init__(self, *args, **kwargs):
13205-
super().__init__(*args, **kwargs)
13206-
hparams = ModelBase.load_hparams(self.dir_model, is_mistral_format=False)
13207-
self.origin_hf_arch = hparams.get('architectures', [None])[0]
13208-
1320913204
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
1321013205
if name.startswith("llm."):
13211-
name = name[len("llm."):]
13212-
yield from super().modify_tensors(data_torch, name, bid)
13206+
name = name.replace("llm.", "", 1)
13207+
elif name.startswith("norm.") or name.startswith("visual."):
13208+
return
13209+
13210+
yield from super().modify_tensors(data_torch, name, bid)
1321313211

1321413212

1321513213
@ModelBase.register("Sarashina2VisionForCausalLM")
@@ -13474,10 +13472,8 @@ def get_model_architecture(hparams: dict[str, Any], model_type: ModelType) -> st
1347413472
# Step3-VL keeps text config under text_config but uses a custom top-level architecture.
1347513473
# For text conversion we route to a dedicated text-only class.
1347613474
# TODO: refactor this later to avoid adding exception here
13477-
if model_type == ModelType.TEXT and arch == "StepVLForConditionalGeneration":
13475+
if model_type == ModelType.TEXT and arch in ("StepVLForConditionalGeneration", "Sarashina2VisionForCausalLM"):
1347813476
return arch
13479-
if model_type == ModelType.TEXT and arch == "Sarashina2VisionForCausalLM":
13480-
return "Sarashina2VisionForCausalLM"
1348113477

1348213478
# if "architectures" is found in the sub-config, use that instead
1348313479
if model_type == ModelType.TEXT and text_config.get("architectures") is not None:

0 commit comments

Comments
 (0)