@@ -710,15 +710,15 @@ def _generate_nvfp4_tensors(self):
710710 self._repack_nvfp4(name, weight, scale, scale2, input_scale)
711711
712712 # Flush any remaining experts (fallback if n_experts was unknown)
713- for bid, proj_type in expert_blocks.keys():
713+ for bid, proj_type in list( expert_blocks.keys() ):
714714 self._flush_nvfp4_experts((bid, proj_type), expert_blocks, expert_scales, expert_input_scales, expert_shapes, bid, proj_type)
715715
716716 # Remove consumed tensors so get_tensors/modify_tensors won't see them
717717 for name in consumed:
718718 self.model_tensors.pop(name, None)
719719
720720 # Remove any remaining unused auxiliary tensors
721- for name in self.model_tensors.keys():
721+ for name in list( self.model_tensors.keys() ):
722722 if name.endswith((".k_scale", ".v_scale")):
723723 del self.model_tensors[name]
724724
@@ -7988,13 +7988,37 @@ def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
79887988 rope_freqs_full = torch.tensor(values, dtype=torch.float32)
79897989 yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), rope_freqs_full)
79907990
def _generate_nvfp4_tensors(self):
    """Fold per-expert router scales into the NVFP4 scales, then run the parent NVFP4 pass.

    Gemma-4 stores a per-layer router.per_expert_scale ([n_expert]) that scales
    each expert's contribution. It's mathematically equivalent to a per-expert
    scalar on the down_proj output, which is exactly where ffn_down_exps_s is
    applied at inference. Each expert's ``down_proj.weight_scale_2`` loader is
    therefore wrapped so that it yields the original scale multiplied by that
    expert's router scale, and the router tensor is removed from
    ``self.model_tensors``.

    Layers where no expert has a ``down_proj.weight_scale_2`` entry are left
    untouched (their router scale tensor stays in ``self.model_tensors``).

    Raises:
        ValueError: if only some of a layer's experts have a
            ``down_proj.weight_scale_2`` entry.
    """
    n_experts = self.find_hparam(["num_local_experts", "num_experts"], optional=True) or 0

    # Snapshot the matching names first: entries are popped from
    # self.model_tensors inside the loop.
    router_names = [n for n in self.model_tensors if n.endswith(".router.per_expert_scale")]

    for name in router_names:
        layer_match = re.search(r"\.layers\.(\d+)\.", name)
        if layer_match is None:
            continue
        bid = layer_match.group(1)
        # Everything up to and including ".layers.<bid>."
        prefix = name[:layer_match.end()]

        w2_targets = [f"{prefix}experts.{e}.down_proj.weight_scale_2" for e in range(n_experts)]
        missing = [w2 for w2 in w2_targets if w2 not in self.model_tensors]
        if len(missing) == len(w2_targets):
            # No expert of this layer carries the scale (also covers n_experts == 0).
            continue
        if missing:
            # Explicit raise instead of assert so the check survives `python -O`.
            raise ValueError(f"layer {bid}: partial NVFP4 quantization across experts")

        router_scale = self.model_tensors.pop(name)
        for e, w2 in enumerate(w2_targets):
            base_scale = self.model_tensors[w2]
            # Bind the current values as defaults to avoid late-binding closures;
            # the product is only evaluated when the loader is called.
            self.model_tensors[w2] = lambda s=base_scale, r=router_scale, i=e: s() * r()[i]

    super()._generate_nvfp4_tensors()
8014+
@classmethod
def filter_tensors(cls, item: tuple[str, Callable[[], Tensor]]) -> tuple[str, Callable[[], Tensor]] | None:
    """Normalise a tensor name, then hand the item to the parent filter.

    ``item`` is a ``(name, loader)`` pair. The name may get a ``.weight``
    suffix appended (see below); the loader is passed through unchanged.
    Returns whatever the parent ``filter_tensors`` returns.
    """
    tensor_name, loader = item

    # These two tensor kinds are stored without a ".weight" suffix.
    if tensor_name.endswith(("per_dim_scale", "layer_scalar")):
        tensor_name = f"{tensor_name}.weight"

    # Expert tensors get ".weight" appended unless they already end in one of
    # these suffixes; the scale suffixes are kept as-is (presumably so the
    # NVFP4 scale tensors keep their names - confirm against the NVFP4 path).
    kept_suffixes = (".weight", ".weight_scale", ".weight_scale_2", ".input_scale")
    is_expert_tensor = ".experts." in tensor_name
    if is_expert_tensor and not tensor_name.endswith(kept_suffixes):
        tensor_name = f"{tensor_name}.weight"

    return super().filter_tensors((tensor_name, loader))
@@ -13684,6 +13708,27 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
1368413708 yield from super().modify_tensors(data_torch, name, bid)
1368513709
1368613710
@ModelBase.register("Sarashina2VisionForCausalLM")
class Sarashina2VLTextModel(LlamaModel):
    """Text-side converter registered for ``Sarashina2VisionForCausalLM``.

    Reuses the LLaMA conversion path and only adjusts tensor names first.
    """

    model_arch = gguf.MODEL_ARCH.LLAMA

    @classmethod
    def filter_tensors(cls, item: tuple[str, Callable[[], Tensor]]) -> tuple[str, Callable[[], Tensor]] | None:
        """Strip the ``llm.`` prefix, drop ``norm.*`` tensors, delegate the rest.

        Returns ``None`` for tensors whose name starts with ``norm.``;
        otherwise returns the parent filter's result for the (possibly
        renamed) tensor.
        """
        tensor_name, loader = item

        # Names starting with "norm." are not passed to the LLaMA converter.
        if tensor_name.startswith("norm."):
            return None

        # Remove only the leading "llm." so the remainder is what the parent sees.
        if tensor_name.startswith("llm."):
            tensor_name = tensor_name[len("llm."):]

        return super().filter_tensors((tensor_name, loader))
13723+
13724+
@ModelBase.register("Sarashina2VisionForCausalLM")
class Sarashina2VLVisionModel(Qwen2VLVisionModel):
    """Vision-side converter registered for ``Sarashina2VisionForCausalLM``.

    Behaves like ``Qwen2VLVisionModel`` except that the stored model type
    is overridden after construction.
    """

    def __init__(self, *args, **kwargs):
        """Run the parent initialiser, then set ``model_type`` to ``qwen2_vl``."""
        super().__init__(*args, **kwargs)
        # Presumably lets downstream code treat this checkpoint as Qwen2-VL;
        # confirm against where global_config["model_type"] is read.
        config = self.global_config
        config['model_type'] = "qwen2_vl"
13730+
13731+
1368713732###### CONVERSION LOGIC ######
1368813733
1368913734
@@ -13940,7 +13985,7 @@ def get_model_architecture(hparams: dict[str, Any], model_type: ModelType) -> st
1394013985 # Step3-VL keeps text config under text_config but uses a custom top-level architecture.
1394113986 # For text conversion we route to a dedicated text-only class.
1394213987 # TODO: refactor this later to avoid adding exception here
13943- if model_type == ModelType.TEXT and arch == "StepVLForConditionalGeneration":
13988+ if model_type == ModelType.TEXT and arch in ( "StepVLForConditionalGeneration", "Sarashina2VisionForCausalLM") :
1394413989 return arch
1394513990
1394613991 # if "architectures" is found in the sub-config, use that instead
0 commit comments