|
17 | 17 | import logging |
18 | 18 |
|
19 | 19 | # Third Party |
20 | | -from fms_mo.utils.qconfig_utils import qconfig_save |
21 | 20 | from transformers.modeling_utils import PreTrainedModel |
22 | 21 | import torch |
23 | 22 |
|
| 23 | +# Local |
| 24 | +from fms_mo.utils.qconfig_utils import qconfig_save |
| 25 | + |
24 | 26 | # logging is only enabled for verbose output (performance is less critical during debug), |
25 | 27 | # and f-string style logging is preferred for code readability |
26 | 28 | # pylint: disable=logging-not-lazy |
@@ -218,23 +220,23 @@ def convert_sd_for_aiu( |
218 | 220 |
|
219 | 221 | def save_sd_for_aiu( |
220 | 222 | model: PreTrainedModel, |
221 | | - output_dir: str = "./", |
222 | | - savename: str = "qmodel_state_dict.pt", |
| 223 | + output_dir: str | Path = "./", |
| 224 | + savename: str | Path = "qmodel_for_aiu.pt", |
223 | 225 | verbose: bool = False, |
224 | 226 | ) -> None: |
225 | 227 | """Save model state dictionary after conversion for AIU compatibility.""" |
226 | 228 |
|
227 | 229 | converted_sd = convert_sd_for_aiu(model, verbose) |
228 | 230 | torch.save(converted_sd, Path(output_dir) / savename) |
229 | | - logger.info("Model saved.") |
| 231 | + logger.info(f"Quantized model checkpoint saved to {Path(output_dir) / savename}") |
230 | 232 |
|
231 | 233 |
|
232 | 234 | def save_for_aiu( |
233 | 235 | model: PreTrainedModel, |
234 | 236 | qcfg: dict, |
235 | | - output_dir: str = "./", |
236 | | - file_name: str = "qmodel.pt", |
237 | | - cfg_name: str = "qcfg.json", |
| 237 | + output_dir: str | Path = "./", |
| 238 | + file_name: str | Path = "qmodel_for_aiu.pt", |
| 239 | + cfg_name: str | Path = "qcfg.json", |
238 | 240 | recipe: str | None = None, |
239 | 241 | verbose: bool = False, |
240 | 242 | ) -> None: |
|
0 commit comments