@@ -10055,20 +10055,23 @@ def transform_nibble_layout(self, tensor):
1005510055 out = (out_h >> 4) | (out_l << 4)
1005610056 return out
1005710057
def _repack_mxfp4(self, blocks: Tensor, scales: Tensor) -> Tensor:
    """Repack blocks and scales into MXFP4 format, returns tensor.

    The caller is responsible for whatever is done with the result; this
    method only builds the tensor.

    Args:
        blocks: ``torch.uint8`` tensor with exactly 4 dimensions.
        scales: ``torch.uint8`` tensor with exactly 3 dimensions; a trailing
            singleton dimension is appended here so it lines up with
            ``blocks`` for concatenation.

    Returns:
        A 3-D tensor. For every index over the first three dimensions, the
        scale entry is followed by the nibble-transformed block entries, and
        the last two dimensions are then merged into one.

    Raises:
        AssertionError: if a dtype or rank requirement above is not met.
    """
    assert blocks.dtype == torch.uint8, f"blocks must be uint8, got {blocks.dtype}"
    assert scales.dtype == torch.uint8, f"scales must be uint8, got {scales.dtype}"

    # Give scales a trailing axis so it can be concatenated with blocks.
    scales = scales.unsqueeze(-1)
    assert len(blocks.shape) == 4, f"blocks must be 4-D, got shape {tuple(blocks.shape)}"
    assert len(scales.shape) == 4, f"scales must be 4-D after unsqueeze, got shape {tuple(scales.shape)}"

    blocks = self.transform_nibble_layout(blocks)

    # Scale goes first, then the block data, along the last dimension.
    new_data = torch.concat((scales, blocks), dim=-1)

    # flatten last dim
    dim0, dim1, dim2, dim3 = new_data.shape
    return new_data.view(dim0, dim1, dim2 * dim3)
10070+
def tensor_force_quant(self, name, new_name, bid, n_dims):
    """Force MXFP4 for tensors whose new name contains a listed marker.

    If ``new_name`` contains any of the substrings below, the MXFP4
    quantization type is returned. Otherwise the decision is delegated
    unchanged to the parent class implementation.
    """
    forced_mxfp4_markers = (
        "ffn_gate_exps",
        "ffn_up_exps",
        "ffn_down_exps",
        "ffn_gate_up_exps",
    )
    for marker in forced_mxfp4_markers:
        if marker in new_name:
            return gguf.GGMLQuantizationType.MXFP4

    # No marker matched: fall back to the parent class's choice.
    return super().tensor_force_quant(name, new_name, bid, n_dims)
1007210075
1007310076 def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
1007410077 blocks0: Tensor = torch.zeros(1)
0 commit comments