|
28 | 28 | from tqdm import tqdm |
29 | 29 | from transformers.models.deepseek_v3 import DeepseekV3Config |
30 | 30 |
|
31 | | -from ....utils import print_info |
32 | | -from ..modules import QDQModule, QDQSingleModule |
| 31 | +from ....utils import find_layers, find_parent_layer_and_sub_name, print_info |
| 32 | +from ..modules import QDQModule, QDQSingleModule, QLinear |
33 | 33 | from .packing_utils import pack_weight_to_int8 |
34 | 34 | from .quant_func import fake_quant_dequant, tensor_quant, weight_dequant |
35 | 35 |
|
@@ -188,6 +188,96 @@ def save(self, save_path): |
188 | 188 | self.quant_model.tokenizer.save_pretrained(save_path) |
189 | 189 |
|
190 | 190 |
|
| 191 | +class PTQDiffusionSave(PTQSaveBase): |
| 192 | + def __init__(self, quant_model): |
| 193 | + super().__init__(quant_model=quant_model) |
| 194 | + |
| 195 | + def save(self, save_path): |
| 196 | + a_quant_algo = self.quant_model.quant_config.quant_algo_info["a"] |
| 197 | + ignored_layers = self.quant_model.skip_layer_names() |
| 198 | + |
| 199 | + static_q_dict = { |
| 200 | + "quantization_config": { |
| 201 | + "quant_method": "fp8", |
| 202 | + "activation_scheme": ( |
| 203 | + "dynamic" if "dynamic" in a_quant_algo else "static" |
| 204 | + ), |
| 205 | + "ignored_layers": ignored_layers, |
| 206 | + } |
| 207 | + } |
| 208 | + |
| 209 | + os.makedirs(save_path, exist_ok=True) |
| 210 | + with open(os.path.join(save_path, "hf_quant_config.json"), "w") as f: |
| 211 | + json.dump(static_q_dict, f, indent=4) |
| 212 | + |
| 213 | + save_scales = {} |
| 214 | + layers_dict = find_layers( |
| 215 | + self.quant_model.get_model().transformer, layers=[QDQModule] |
| 216 | + ) |
| 217 | + for name, sub_layer in layers_dict.items(): |
| 218 | + parent_layer, sub_name = find_parent_layer_and_sub_name( |
| 219 | + self.quant_model.get_model().transformer, name |
| 220 | + ) |
| 221 | + q_module = QLinear( |
| 222 | + quant_algo=sub_layer.quant_algo, |
| 223 | + weight=sub_layer.weight, |
| 224 | + bias=sub_layer.bias, |
| 225 | + weight_scale=sub_layer.weight_scale.data.clone().detach(), |
| 226 | + input_scale=sub_layer.input_scale.data.clone().detach(), |
| 227 | + ) |
| 228 | + setattr(parent_layer, sub_name, q_module) |
| 229 | + save_scales[name + ".input_scale"] = sub_layer.input_scale |
| 230 | + save_scales[name + ".weight_scale"] = sub_layer.weight_scale |
| 231 | + |
| 232 | + self.quant_model.get_model().save_pretrained(save_path) |
| 233 | + safetensor_file = os.path.join(save_path, "model-scales.safetensors") |
| 234 | + safe_save(save_scales, safetensor_file) |
| 235 | + |
| 236 | + |
| 237 | +class PTQOnlyScaleSave(PTQSaveBase): |
| 238 | + def __init__(self, quant_model): |
| 239 | + super().__init__(quant_model=quant_model) |
| 240 | + |
| 241 | + def save(self, save_path): |
| 242 | + a_quant_algo = self.quant_model.quant_config.quant_algo_info["a"] |
| 243 | + ignored_layers = self.quant_model.skip_layer_names() |
| 244 | + |
| 245 | + static_q_dict = { |
| 246 | + "quantization_config": { |
| 247 | + "quant_method": "fp8", |
| 248 | + "activation_scheme": ( |
| 249 | + "dynamic" if "dynamic" in a_quant_algo else "static" |
| 250 | + ), |
| 251 | + "ignored_layers": ignored_layers, |
| 252 | + } |
| 253 | + } |
| 254 | + |
| 255 | + os.makedirs(save_path, exist_ok=True) |
| 256 | + with open(os.path.join(save_path, "hf_quant_config.json"), "w") as f: |
| 257 | + json.dump(static_q_dict, f, indent=4) |
| 258 | + |
| 259 | + save_scales = {} |
| 260 | + new_model_index = { |
| 261 | + "metadata": {}, |
| 262 | + "weight_map": {}, |
| 263 | + } |
| 264 | + safetensor_name = "model-scales.safetensors" |
| 265 | + for name, value in self.quant_model.act_scales_dict.items(): |
| 266 | + save_scales[name + ".input_scale"] = value |
| 267 | + new_model_index["weight_map"][name + ".input_scale"] = safetensor_name |
| 268 | + for name, value in self.quant_model.weight_scales_dict.items(): |
| 269 | + save_scales[name + ".weight_scale"] = value |
| 270 | + new_model_index["weight_map"][name + ".weight_scale"] = safetensor_name |
| 271 | + |
| 272 | + safetensor_file = os.path.join(save_path, safetensor_name) |
| 273 | + safe_save(save_scales, safetensor_file) |
| 274 | + |
| 275 | + # update model index json |
| 276 | + new_model_index_file = os.path.join(save_path, "model.safetensors.index.json") |
| 277 | + with open(new_model_index_file, "w") as f: |
| 278 | + json.dump(new_model_index, f, indent=2) |
| 279 | + |
| 280 | + |
191 | 281 | class PTQTorchSave(PTQSaveBase): |
192 | 282 | def __init__(self, quant_model): |
193 | 283 | super(PTQTorchSave, self).__init__(quant_model=quant_model) |
@@ -594,7 +684,7 @@ def merge_model(self, input_path, save_model_path, mp=16): |
594 | 684 | param_list.append(param) |
595 | 685 | newparam = torch.cat(param_list, dim=0) |
596 | 686 | new_save_dict[k] = newparam |
597 | | - print(f"shape of {k}: {new_save_dict[k].shape}") |
| 687 | + print_info(f"shape of {k}: {new_save_dict[k].shape}") |
598 | 688 | index_dict["weight_map"][k] = str(filename) |
599 | 689 | safe_save(new_save_dict, os.path.join(save_model_path, filename)) |
600 | 690 | # process others |
@@ -625,7 +715,7 @@ def merge_model(self, input_path, save_model_path, mp=16): |
625 | 715 | index_dict, |
626 | 716 | filename, |
627 | 717 | ) |
628 | | - print(f"shape of {k}: {new_save_dict[k].shape}") |
| 718 | + print_info(f"shape of {k}: {new_save_dict[k].shape}") |
629 | 719 | safe_save(new_save_dict, os.path.join(save_model_path, filename)) |
630 | 720 |
|
631 | 721 | # update scales map |
|
0 commit comments