@@ -1273,7 +1273,7 @@ def set_gguf_parameters(self):
12731273 if (f_norm_eps := self .find_hparam (["layer_norm_eps" , "layer_norm_epsilon" , "norm_epsilon" ], optional = True )) is not None :
12741274 self .gguf_writer .add_layer_norm_eps (f_norm_eps )
12751275 logger .info (f"gguf: layer norm epsilon = { f_norm_eps } " )
1276- if (n_experts := self .find_hparam (["num_local_experts" , "num_experts" ], optional = True )) is not None :
1276+ if (n_experts := self .find_hparam (["num_local_experts" , "num_experts" , "n_routed_experts" ], optional = True )) is not None :
12771277 self .gguf_writer .add_expert_count (n_experts )
12781278 logger .info (f"gguf: expert count = { n_experts } " )
12791279 if (n_experts_used := self .find_hparam (["num_experts_per_tok" , "num_experts_per_token" , "top_k_experts" ], optional = True )) is not None :
@@ -1291,6 +1291,8 @@ def set_gguf_parameters(self):
12911291 self .gguf_writer .add_expert_gating_func (gguf .ExpertGatingFuncType .SIGMOID )
12921292 elif score_func == "softmax" :
12931293 self .gguf_writer .add_expert_gating_func (gguf .ExpertGatingFuncType .SOFTMAX )
1294+ elif score_func == "sqrtsoftplus" :
1295+ self .gguf_writer .add_expert_gating_func (gguf .ExpertGatingFuncType .SQRTSOFTPLUS )
12941296 else :
12951297 raise ValueError (f"Unsupported expert score gating function value: { score_func } " )
12961298 logger .info (f"gguf: expert score gating function = { score_func } " )
@@ -2600,6 +2602,17 @@ def __torch_function__(cls, func, types, args=(), kwargs=None):
26002602 return cls ._wrap_fn (func )(* args , ** kwargs )
26012603
26022604
# Register the "F8_E8M0" dtype name with LazyTorchTensor's lookup tables.
# torch.float8_e8m0fnu is probed with hasattr because not every torch build
# exposes it.
if not hasattr(torch, "float8_e8m0fnu"):
    # Older torch builds do not expose F8_E8M0. Keep the raw bytes (uint8) so
    # callers that know the format can decode them explicitly.
    LazyTorchTensor._dtype_str_map["F8_E8M0"] = torch.uint8
else:
    _torch_float8_e8m0 = torch.float8_e8m0fnu
    # Name -> torch dtype.
    LazyTorchTensor._dtype_str_map["F8_E8M0"] = _torch_float8_e8m0
    # torch dtype -> numpy type; both tables pair the new dtype with np.uint8.
    LazyTorchTensor._dtype_map[_torch_float8_e8m0] = np.uint8
    LazyTorchTensor._dtype_byteswap_map[_torch_float8_e8m0] = np.uint8
2614+
2615+
26032616def get_model_architecture (hparams : dict [str , Any ], model_type : ModelType ) -> str :
26042617 # TODO @ngxson : this won't work correctly if the model has both audio & vision encoders
26052618 # maybe we should fallback to text model's arch in that case, since not many models have both
0 commit comments