@@ -920,7 +920,7 @@ def set_gguf_parameters(self):
             self.gguf_writer.add_expert_group_used_count(n_group_used)
             logger.info(f"gguf: expert groups used count = {n_group_used}")
 
-        if (score_func := self.find_hparam(["score_function", "scoring_func", "score_func", "moe_router_activation_func"], optional=True)) is not None:
+        if (score_func := self.find_hparam(["score_function", "scoring_func", "score_func", "moe_router_activation", "moe_router_activation_func"], optional=True)) is not None:
             if score_func == "sigmoid":
                 self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
             elif score_func == "softmax":
@@ -7912,6 +7912,135 @@ def prepare_tensors(self):
                 raise ValueError(f"Unprocessed experts: {experts}")
 
 
+@ModelBase.register("Step3p5ForCausalLM")
+class Step35Model(TextModel):
+    model_arch = gguf.MODEL_ARCH.STEP35
+
+    def set_gguf_parameters(self):
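+        # rope_theta can arrive as a [full-attention, sliding-window] pair; split it so the
+        # base writer sees a scalar theta and sliding-window layers get their own value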
+        rope_theta = self.hparams.get("rope_theta")
+        if isinstance(rope_theta, list):
+            self.hparams["rope_theta"] = float(rope_theta[0])
+            self.hparams["local_rope_theta"] = float(rope_theta[1])
+            self.rope_parameters["rope_theta"] = self.hparams["rope_theta"]
+            self.rope_parameters["sliding_attention"] = {"rope_theta": self.hparams["local_rope_theta"]}
+
+        super().set_gguf_parameters()
+
+        layer_types = self.hparams.get("layer_types") or []
+        partial_rotary_factors = self.hparams.get("partial_rotary_factors") or []
+        attn_other = self.hparams.get("attention_other_setting") or {}
+
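+        # sliding-window layers may override the head / KV-group counts via
+        # "attention_other_setting"; fall back to the base counts when absent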
+        n_head_base = self.hparams["num_attention_heads"]
+        n_kv_base = self.hparams["num_attention_groups"]
+
+        n_head_swa = attn_other.get("num_attention_heads", n_head_base)
+        n_kv_swa = attn_other.get("num_attention_groups", n_kv_base)
+
+        layer_types = layer_types[: self.block_count]
+        partial_rotary_factors = partial_rotary_factors[: self.block_count]
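+        # sanity check: the configs this converter targets are expected to use the full
+        # rotary dim (factor 1.0) on sliding-window layers and half of it (0.5) elsewhere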
+        assert [1.0 if lt == "sliding_attention" else 0.5 for lt in layer_types] == partial_rotary_factors
+        head_arr = [n_head_swa if lt == "sliding_attention" else n_head_base for lt in layer_types]
+        kv_arr = [n_kv_swa if lt == "sliding_attention" else n_kv_base for lt in layer_types]
+        swa_pat = [lt == "sliding_attention" for lt in layer_types]
+
+        self.gguf_writer.add_head_count(head_arr)
+        self.gguf_writer.add_head_count_kv(kv_arr)
+
+        self.gguf_writer.add_sliding_window(self.hparams["sliding_window"])
+        self.gguf_writer.add_sliding_window_pattern(swa_pat)
+
+        self.gguf_writer.add_value_length(self.hparams["head_dim"])
+
+        # MoE params
+        self.gguf_writer.add_expert_count(self.hparams["moe_num_experts"])
+        self.gguf_writer.add_expert_used_count(self.hparams["moe_top_k"])
+        self.gguf_writer.add_expert_feed_forward_length(self.hparams["moe_intermediate_size"])
+        self.gguf_writer.add_expert_shared_feed_forward_length(self.hparams["share_expert_dim"])
+
+        if (moe_router_scaling_factor := self.hparams.get("moe_router_scaling_factor")) is not None:
+            self.gguf_writer.add_expert_weights_scale(moe_router_scaling_factor)
+        if (norm_expert_weight := self.hparams.get("norm_expert_weight")) is not None:
+            self.gguf_writer.add_expert_weights_norm(norm_expert_weight)
+
+        # leading dense blocks
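+        # e.g. a hypothetical moe_layers_enum of "2,3,4" puts the first MoE block at
+        # layer 2, so layers 0-1 are dense and leading_dense = 2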
+        leading_dense = 0
+        moe_layers_enum = self.hparams.get("moe_layers_enum")
+        if isinstance(moe_layers_enum, str) and moe_layers_enum.strip():
+            moe_layers = sorted(int(i) for i in moe_layers_enum.strip().split(","))
+            if moe_layers:
+                leading_dense = max(0, moe_layers[0])
+        self.gguf_writer.add_leading_dense_block_count(leading_dense)
+        self.gguf_writer.add_moe_every_n_layers(int(self.hparams.get("moe_every_n_layer", 1)))
+
+        self.gguf_writer.add_layer_norm_rms_eps(self.hparams.get("rms_norm_eps", 1e-5))
+
+        # Optional per-layer SwiGLU clamps.
+        if (limits := self.hparams.get("swiglu_limits")) is not None:
+            limits_f = [0.0 if v is None else float(v) for v in limits[: self.block_count]]
+            self.gguf_writer.add_swiglu_clamp_exp(limits_f)
+        if (limits_shared := self.hparams.get("swiglu_limits_shared")) is not None:
+            limits_shared_f = [0.0 if v is None else float(v) for v in limits_shared[: self.block_count]]
+            self.gguf_writer.add_swiglu_clamp_shexp(limits_shared_f)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        # drop multi-token-prediction (MTP) layers that sit past the main decoder stack
+        if (m := re.match(r"model\.layers\.(\d+)\.", name)) is not None:
+            il = int(m.group(1))
+            n_main = int(self.hparams.get("num_hidden_layers", self.block_count))
+            if il >= n_main:
+                return
+        if name.endswith("norm.weight"):
+            data_torch += 1.0
+        # Map router bias (expert selection bias) to a GGUF bias tensor
+        if name.endswith(".moe.router_bias"):
+            name += ".bias"
+
+        if name.endswith((".self_attn.g_proj.weight", ".moe.gate.weight", ".moe.up_proj.weight", ".moe.gate_proj.weight", ".moe.down_proj.weight")):
+            data_torch = data_torch.squeeze().contiguous()
+
+        yield from super().modify_tensors(data_torch, name, bid)
+
+    def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
+        # Step35 can optionally use Llama-3 style RoPE scaling (HF: rope_scaling.rope_type == "llama3").
+        # llama.cpp represents this via a single extra tensor: "rope_freqs.weight" (aka MODEL_TENSOR.ROPE_FREQS).
+        rope_params = self.rope_parameters.get("full_attention", self.rope_parameters)
+        rope_type = rope_params.get("rope_type") or ""
+        if rope_type.lower() != "llama3":
+            return
+
+        # Step35 configs can carry per-layer rope_theta as a list; for llama3 rope factors we use the base value.
+        rope_theta = self.hparams.get("rope_theta", 10000.0)
+        if isinstance(rope_theta, list):
+            rope_theta = rope_theta[0]
+        base = float(rope_theta)
+        if (dim := self.hparams.get("head_dim")) is None:
+            dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
+        dim = int(dim)
+
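+        # standard RoPE inverse frequencies: freq_i = 1 / base**(2*i/dim) for each even index 2*i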
+        freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
+
+        factor = float(rope_params.get("factor", 8.0))
+        low_freq_factor = float(rope_params.get("low_freq_factor", 1.0))
+        high_freq_factor = float(rope_params.get("high_freq_factor", 4.0))
+        old_context_len = int(rope_params.get("original_max_position_embeddings", self.hparams.get("original_max_position_embeddings", 8192)))
+
+        low_freq_wavelen = old_context_len / low_freq_factor
+        high_freq_wavelen = old_context_len / high_freq_factor
+
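+        # Llama-3 style smoothing: short wavelengths (< high_freq_wavelen) keep factor 1.0,
+        # long wavelengths (> low_freq_wavelen) are scaled by `factor`, and the band in
+        # between blends the two (linear in the reciprocal, weighted by `smooth`)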
+        rope_factors: list[float] = []
+        for freq in freqs:
+            wavelen = 2 * math.pi / float(freq)
+            if wavelen < high_freq_wavelen:
+                rope_factors.append(1.0)
+            elif wavelen > low_freq_wavelen:
+                rope_factors.append(factor)
+            else:
+                smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
+                rope_factors.append(1.0 / ((1.0 - smooth) / factor + smooth))
+
+        yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), torch.tensor(rope_factors, dtype=torch.float32))
+
+
 @ModelBase.register("PanguEmbeddedForCausalLM")
 class PanguEmbeddedModel(TextModel):
     model_arch = gguf.MODEL_ARCH.PANGU_EMBED