Skip to content

Commit e72cb0d

Browse files
committed
type hints update
Signed-off-by: Andrea Fasoli <andrea.fasoli@ibm.com>
1 parent 60f1aa0 commit e72cb0d

2 files changed

Lines changed: 10 additions & 8 deletions

File tree

fms_mo/utils/aiu_utils.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,12 @@
1717
import logging
1818

1919
# Third Party
20-
from fms_mo.utils.qconfig_utils import qconfig_save
2120
from transformers.modeling_utils import PreTrainedModel
2221
import torch
2322

23+
# Local
24+
from fms_mo.utils.qconfig_utils import qconfig_save
25+
2426
# logging is only enabled for verbose output (performance is less critical during debug),
2527
# and f-string style logging is preferred for code readability
2628
# pylint: disable=logging-not-lazy
@@ -218,23 +220,23 @@ def convert_sd_for_aiu(
218220

219221
def save_sd_for_aiu(
220222
model: PreTrainedModel,
221-
output_dir: str = "./",
222-
savename: str = "qmodel_state_dict.pt",
223+
output_dir: str | Path = "./",
224+
savename: str | Path = "qmodel_for_aiu.pt",
223225
verbose: bool = False,
224226
) -> None:
225227
"""Save model state dictionary after conversion for AIU compatibility."""
226228

227229
converted_sd = convert_sd_for_aiu(model, verbose)
228230
torch.save(converted_sd, Path(output_dir) / savename)
229-
logger.info("Model saved.")
231+
logger.info(f"Quantized model checkpoint saved to {Path(output_dir) / savename}")
230232

231233

232234
def save_for_aiu(
233235
model: PreTrainedModel,
234236
qcfg: dict,
235-
output_dir: str = "./",
236-
file_name: str = "qmodel.pt",
237-
cfg_name: str = "qcfg.json",
237+
output_dir: str | Path = "./",
238+
file_name: str | Path = "qmodel_for_aiu.pt",
239+
cfg_name: str | Path = "qcfg.json",
238240
recipe: str | None = None,
239241
verbose: bool = False,
240242
) -> None:

fms_mo/utils/qconfig_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -540,7 +540,7 @@ def qconfig_save(
540540
qcfg: dict,
541541
recipe: str | None = None,
542542
minimal: bool = True,
543-
fname: str = "qcfg.json",
543+
fname: str | Path = "qcfg.json",
544544
) -> None:
545545
"""
546546
Try to save qcfg into a JSON file (or use .pt format if something really can't be text-only).

0 commit comments

Comments
 (0)