Skip to content

Commit be18b63

Browse files
committed
Add export_checkpoint() bundle CLI and update bundle load()
Add export_checkpoint() as the torch.export replacement for the deprecated ckpt_export() command. Update load() to support .pt2 files via load_exported_module(). Register the new command in __main__.py and wire up the TRT save wrapper.

Signed-off-by: Soumya Snigdha Kundu <soumya_snigdha.kundu@kcl.ac.uk>
1 parent a769c7e commit be18b63

File tree

3 files changed

+208
-26
lines changed

3 files changed

+208
-26
lines changed

monai/bundle/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
create_workflow,
2121
download,
2222
download_large_files,
23+
export_checkpoint,
2324
get_all_bundles_list,
2425
get_bundle_info,
2526
get_bundle_versions,

monai/bundle/__main__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
ckpt_export,
1616
download,
1717
download_large_files,
18+
export_checkpoint,
1819
init_bundle,
1920
onnx_export,
2021
run,

monai/bundle/scripts.py

Lines changed: 206 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,9 @@
3535
from monai.bundle.workflows import BundleWorkflow, ConfigWorkflow
3636
from monai.config import PathLike
3737
from monai.data import load_net_with_metadata, save_net_with_metadata
38+
from monai.data.export_utils import load_exported_program, save_exported_program
3839
from monai.networks import (
40+
convert_to_export,
3941
convert_to_onnx,
4042
convert_to_torchscript,
4143
convert_to_trt,
@@ -46,6 +48,7 @@
4648
from monai.utils import (
4749
IgniteInfo,
4850
check_parent_dir,
51+
deprecated,
4952
ensure_tuple,
5053
get_equivalent_dtype,
5154
min_version,
@@ -632,6 +635,7 @@ def load(
632635
workflow_type: str = "train",
633636
model_file: str | None = None,
634637
load_ts_module: bool = False,
638+
load_exported_module: bool = False,
635639
bundle_dir: PathLike | None = None,
636640
source: str = DEFAULT_DOWNLOAD_SOURCE,
637641
repo: str | None = None,
@@ -646,7 +650,7 @@ def load(
646650
net_override: dict | None = None,
647651
) -> object | tuple[torch.nn.Module, dict, dict] | Any:
648652
"""
649-
Load model weights or TorchScript module of a bundle.
653+
Load model weights, TorchScript module, or exported program of a bundle.
650654
651655
Args:
652656
name: bundle name. If `None` and `url` is `None`, it must be provided in `args_file`.
@@ -664,10 +668,16 @@ def load(
664668
or "infer", "inference", "eval", "evaluation" for an inference workflow,
665669
other unsupported string will raise a ValueError.
666670
default to `train` for training workflow.
667-
model_file: the relative path of the model weights or TorchScript module within bundle.
668-
If `None`, "models/model.pt" or "models/model.ts" will be used.
671+
model_file: the relative path of the model weights or exported module within bundle.
672+
If `None`, "models/model.pt", "models/model.ts", or "models/model.pt2" will be used
673+
depending on the loading mode.
669674
load_ts_module: a flag to specify if loading the TorchScript module.
670-
bundle_dir: directory the weights/TorchScript module will be loaded from.
675+
676+
.. deprecated:: 1.5
677+
Use ``load_exported_module=True`` instead.
678+
679+
load_exported_module: a flag to specify if loading a ``torch.export`` ``.pt2`` module.
680+
bundle_dir: directory the weights/module will be loaded from.
671681
Default is `bundle` subfolder under `torch.hub.get_dir()`.
672682
source: storage location name. This argument is used when `model_file` does not exist locally and needs to be
673683
downloaded first.
@@ -684,32 +694,48 @@ def load(
684694
device: target device of returned weights or module, if `None`, prefer to "cuda" if existing.
685695
key_in_ckpt: for nested checkpoint like `{"model": XXX, "optimizer": XXX, ...}`, specify the key of model
686696
weights. if not nested checkpoint, no need to set.
687-
config_files: extra filenames would be loaded. The argument only works when loading a TorchScript module,
688-
see `_extra_files` in `torch.jit.load` for more details.
697+
config_files: extra filenames would be loaded. The argument only works when loading a TorchScript
698+
or exported module, see ``_extra_files`` in ``torch.jit.load`` / ``torch.export.load`` for details.
689699
workflow_name: specified bundle workflow name, should be a string or class, default to "ConfigWorkflow".
690700
args_file: a JSON or YAML file to provide default values for all the args in "download" function.
691701
copy_model_args: other arguments for the `monai.networks.copy_model_state` function.
692702
net_override: id-value pairs to override the parameters in the network of the bundle, default to `None`.
693703
694704
Returns:
695-
1. If `load_ts_module` is `False` and `model` is `None`,
705+
1. If ``load_ts_module`` and ``load_exported_module`` are both ``False`` and ``model`` is ``None``,
696706
return model weights if can't find "network_def" in the bundle,
697707
else return an instantiated network that loaded the weights.
698-
2. If `load_ts_module` is `False` and `model` is not `None`,
708+
2. If ``load_ts_module`` and ``load_exported_module`` are both ``False`` and ``model`` is not ``None``,
699709
return an instantiated network that loaded the weights.
700-
3. If `load_ts_module` is `True`, return a triple that include a TorchScript module,
710+
3. If ``load_ts_module`` is ``True``, return a triple that include a TorchScript module,
701711
the corresponding metadata dict, and extra files dict.
702-
please check `monai.data.load_net_with_metadata` for more details.
712+
please check ``monai.data.load_net_with_metadata`` for more details.
713+
4. If ``load_exported_module`` is ``True``, return a triple of
714+
(ExportedProgram, metadata dict, extra files dict).
715+
See :func:`monai.data.load_exported_program` for more details.
703716
704717
"""
705718
bundle_dir_ = _process_bundle_dir(bundle_dir)
706719
net_override = {} if net_override is None else net_override
707720
copy_model_args = {} if copy_model_args is None else copy_model_args
708721

722+
if load_ts_module:
723+
warnings.warn(
724+
"load_ts_module is deprecated since v1.5 and will be removed in v1.7. "
725+
"Use load_exported_module=True instead.",
726+
FutureWarning,
727+
stacklevel=2,
728+
)
729+
709730
if device is None:
710731
device = "cuda:0" if is_available() else "cpu"
711732
if model_file is None:
712-
model_file = os.path.join("models", "model.ts" if load_ts_module is True else "model.pt")
733+
if load_exported_module:
734+
model_file = os.path.join("models", "model.pt2")
735+
elif load_ts_module:
736+
model_file = os.path.join("models", "model.ts")
737+
else:
738+
model_file = os.path.join("models", "model.pt")
713739
if source == "ngc":
714740
name = _add_ngc_prefix(name)
715741
if remove_prefix:
@@ -727,14 +753,25 @@ def load(
727753
args_file=args_file,
728754
)
729755

756+
# loading with `torch.export.load`
757+
if load_exported_module:
758+
return load_exported_program(full_path, more_extra_files=config_files or ())
730759
# loading with `torch.jit.load`
731760
if load_ts_module is True:
732-
return load_net_with_metadata(full_path, map_location=torch.device(device), more_extra_files=config_files)
761+
# Suppress the @deprecated warning from load_net_with_metadata since the user
762+
# already received a FutureWarning about load_ts_module above.
763+
with warnings.catch_warnings():
764+
warnings.filterwarnings("ignore", category=FutureWarning, message=".*load_net_with_metadata.*")
765+
return load_net_with_metadata(full_path, map_location=torch.device(device), more_extra_files=config_files)
733766
# loading with `torch.load`
734767
model_dict = torch.load(full_path, map_location=torch.device(device), weights_only=True)
735768

736769
if not isinstance(model_dict, Mapping):
737-
warnings.warn(f"the state dictionary from {full_path} should be a dictionary but got {type(model_dict)}.")
770+
warnings.warn(
771+
f"the state dictionary from {full_path} should be a dictionary but got {type(model_dict)}.",
772+
category=UserWarning,
773+
stacklevel=2,
774+
)
738775
model_dict = get_state_dict(model_dict)
739776

740777
_workflow = None
@@ -750,11 +787,17 @@ def load(
750787
**_net_override,
751788
)
752789
else:
753-
warnings.warn(f"Cannot find the config file: {bundle_config_file}, return state dict instead.")
790+
warnings.warn(
791+
f"Cannot find the config file: {bundle_config_file}, return state dict instead.",
792+
stacklevel=2,
793+
)
754794
return model_dict
755795
if _workflow is not None:
756796
if not hasattr(_workflow, "network_def"):
757-
warnings.warn("No available network definition in the bundle, return state dict instead.")
797+
warnings.warn(
798+
"No available network definition in the bundle, return state dict instead.",
799+
stacklevel=2,
800+
)
758801
return model_dict
759802
else:
760803
model = _workflow.network_def
@@ -1277,7 +1320,7 @@ def _export(
12771320
(extracted from the parser), and a dictionary of extra JSON files (name -> contents) as input.
12781321
parser: a ConfigParser of the bundle to be converted.
12791322
net_id: ID name of the network component in the parser, it must be `torch.nn.Module`.
1280-
filepath: filepath to export, if filename has no extension, it becomes `.ts`.
1323+
filepath: filepath to export.
12811324
ckpt_file: filepath of the model checkpoint to load.
12821325
config_file: filepath of the config file to save in the converted model, the saved key in the converted
12831326
model is the config filename without extension, and the saved config value is always serialized in
@@ -1434,6 +1477,7 @@ def save_onnx(onnx_obj: Any, filename_prefix_or_stream: str, **kwargs: Any) -> N
14341477
)
14351478

14361479

1480+
@deprecated(since="1.5", removed="1.7", msg_suffix="Use export_checkpoint() instead.")
14371481
def ckpt_export(
14381482
net_id: str | None = None,
14391483
filepath: PathLike | None = None,
@@ -1568,6 +1612,139 @@ def ckpt_export(
15681612
)
15691613

15701614

1615+
def export_checkpoint(
1616+
net_id: str | None = None,
1617+
filepath: PathLike | None = None,
1618+
ckpt_file: str | None = None,
1619+
meta_file: str | Sequence[str] | None = None,
1620+
config_file: str | Sequence[str] | None = None,
1621+
key_in_ckpt: str | None = None,
1622+
input_shape: Sequence[int] | None = None,
1623+
dynamic_shapes: dict | tuple | None = None,
1624+
args_file: str | None = None,
1625+
converter_kwargs: Mapping | None = None,
1626+
**override: Any,
1627+
) -> None:
1628+
"""
1629+
Export the model checkpoint to a ``.pt2`` file using :func:`torch.export.export`, with metadata and
1630+
config included.
1631+
1632+
Typical usage examples:
1633+
1634+
.. code-block:: bash
1635+
1636+
python -m monai.bundle export_checkpoint network --filepath <path> --ckpt_file <checkpoint path> ...
1637+
1638+
Args:
1639+
net_id: ID name of the network component in the config, it must be ``torch.nn.Module``.
1640+
Default to ``"network_def"``.
1641+
filepath: filepath to export. If filename has no extension it becomes ``.pt2``.
1642+
Default to ``"models/model.pt2"`` under ``"os.getcwd()"`` if ``bundle_root`` is not specified.
1643+
ckpt_file: filepath of the model checkpoint to load.
1644+
Default to ``"models/model.pt"`` under ``"os.getcwd()"`` if ``bundle_root`` is not specified.
1645+
meta_file: filepath of the metadata file. If it is a list of file paths, contents will be merged.
1646+
Default to ``"configs/metadata.json"`` under ``"os.getcwd()"`` if ``bundle_root`` is not specified.
1647+
config_file: filepath of the config file to save in the exported model. The saved key is the
1648+
config filename without extension; the value is always serialized in JSON format.
1649+
It can be a single file or a list of files. If ``None``, must be provided in ``args_file``.
1650+
key_in_ckpt: for nested checkpoints like ``{"model": XXX, "optimizer": XXX, ...}``, specify the
1651+
key of model weights. If not nested, no need to set.
1652+
input_shape: a shape used to generate random input for the network, e.g. ``[N, C, H, W]`` or
1653+
``[N, C, H, W, D]``. If not given, will try to parse from ``metadata``.
1654+
dynamic_shapes: dynamic shape specifications passed to :func:`torch.export.export`.
1655+
args_file: a JSON or YAML file to provide default values for all the parameters.
1656+
converter_kwargs: extra arguments for :func:`~monai.networks.utils.convert_to_export`,
1657+
except ones that already exist in the input parameters.
1658+
override: id-value pairs to override or add the corresponding config content.
1659+
"""
1660+
_args = update_kwargs(
1661+
args=args_file,
1662+
net_id=net_id,
1663+
filepath=filepath,
1664+
meta_file=meta_file,
1665+
config_file=config_file,
1666+
ckpt_file=ckpt_file,
1667+
key_in_ckpt=key_in_ckpt,
1668+
input_shape=input_shape,
1669+
dynamic_shapes=dynamic_shapes,
1670+
converter_kwargs=converter_kwargs,
1671+
**override,
1672+
)
1673+
_log_input_summary(tag="export_checkpoint", args=_args)
1674+
(
1675+
config_file_,
1676+
filepath_,
1677+
ckpt_file_,
1678+
net_id_,
1679+
meta_file_,
1680+
key_in_ckpt_,
1681+
input_shape_,
1682+
dynamic_shapes_,
1683+
converter_kwargs_,
1684+
) = _pop_args(
1685+
_args,
1686+
"config_file",
1687+
filepath=None,
1688+
ckpt_file=None,
1689+
net_id=None,
1690+
meta_file=None,
1691+
key_in_ckpt="",
1692+
input_shape=None,
1693+
dynamic_shapes=None,
1694+
converter_kwargs={},
1695+
)
1696+
bundle_root = _args.get("bundle_root", os.getcwd())
1697+
1698+
parser = ConfigParser()
1699+
parser.read_config(f=config_file_)
1700+
meta_file_ = os.path.join(bundle_root, "configs", "metadata.json") if meta_file_ is None else meta_file_
1701+
if os.path.exists(meta_file_):
1702+
parser.read_meta(f=meta_file_)
1703+
1704+
for k, v in _args.items():
1705+
parser[k] = v
1706+
1707+
filepath_ = os.path.join(bundle_root, "models", "model.pt2") if filepath_ is None else filepath_
1708+
ckpt_file_ = os.path.join(bundle_root, "models", "model.pt") if ckpt_file_ is None else ckpt_file_
1709+
if not os.path.exists(ckpt_file_):
1710+
raise FileNotFoundError(f'Checkpoint file "{ckpt_file_}" not found, please specify it in argument "ckpt_file".')
1711+
1712+
net_id_ = "network_def" if net_id_ is None else net_id_
1713+
try:
1714+
parser.get_parsed_content(net_id_)
1715+
except ValueError as e:
1716+
raise ValueError(
1717+
f'Network definition "{net_id_}" cannot be found in "{config_file_}", specify name with argument "net_id".'
1718+
) from e
1719+
1720+
if not input_shape_:
1721+
input_shape_ = _get_fake_input_shape(parser=parser)
1722+
1723+
if not input_shape_:
1724+
raise ValueError(
1725+
"Cannot determine input shape automatically. "
1726+
"Please provide it explicitly via the 'input_shape' argument."
1727+
)
1728+
1729+
inputs_: Sequence[Any] = [torch.rand(input_shape_)]
1730+
1731+
converter_kwargs_.update({"inputs": inputs_, "dynamic_shapes": dynamic_shapes_})
1732+
1733+
save_ep = partial(save_exported_program, include_config_vals=False, append_timestamp=False)
1734+
1735+
_export(
1736+
convert_to_export,
1737+
save_ep,
1738+
parser,
1739+
net_id=net_id_,
1740+
filepath=filepath_,
1741+
ckpt_file=ckpt_file_,
1742+
config_file=config_file_,
1743+
key_in_ckpt=key_in_ckpt_,
1744+
**converter_kwargs_,
1745+
)
1746+
1747+
15711748
def trt_export(
15721749
net_id: str | None = None,
15731750
filepath: PathLike | None = None,
@@ -1588,20 +1765,19 @@ def trt_export(
15881765
**override: Any,
15891766
) -> None:
15901767
"""
1591-
Export the model checkpoint to the given filepath as a TensorRT engine-based TorchScript.
1768+
Export the model checkpoint to the given filepath as a TensorRT engine.
15921769
Currently, this API only supports converting models whose inputs are all tensors.
15931770
Note: NVIDIA Volta support (GPUs with compute capability 7.0) has been removed starting with TensorRT 10.5.
15941771
Review the TensorRT Support Matrix for which GPUs are supported.
15951772
15961773
There are two ways to export a model:
1597-
1, Torch-TensorRT way: PyTorch module ---> TorchScript module ---> TensorRT engine-based TorchScript.
1598-
2, ONNX-TensorRT way: PyTorch module ---> TorchScript module ---> ONNX model ---> TensorRT engine --->
1599-
TensorRT engine-based TorchScript.
1774+
1, Torch-TensorRT way: PyTorch module ---> TensorRT engine (via ``torch.export`` on PyTorch >= 2.9,
1775+
or via TorchScript on older versions).
1776+
2, ONNX-TensorRT way: PyTorch module ---> ONNX model ---> TensorRT engine.
16001777
16011778
When exporting through the first way, some models suffer from the slowdown problem, since Torch-TensorRT
16021779
may only convert a little part of the PyTorch model to the TensorRT engine. However when exporting through
1603-
the second way, some Python data structures like `dict` are not supported. And some TorchScript models are
1604-
not supported by the ONNX if exported through `torch.jit.script`.
1780+
the second way, some Python data structures like ``dict`` are not supported.
16051781
16061782
Typical usage examples:
16071783
@@ -1624,8 +1800,8 @@ def trt_export(
16241800
precision: the weight precision of the converted TensorRT engine based TorchScript models. Should be 'fp32' or 'fp16'.
16251801
input_shape: the input shape that is used to convert the model. Should be a list like [N, C, H, W] or
16261802
[N, C, H, W, D]. If not given, will try to parse from the `metadata` config.
1627-
use_trace: whether using `torch.jit.trace` to convert the PyTorch model to a TorchScript model and then convert to
1628-
a TensorRT engine based TorchScript model or an ONNX model (if `use_onnx` is True).
1803+
use_trace: whether using ``torch.jit.trace`` to convert the PyTorch model to a TorchScript model
1804+
(only used on PyTorch < 2.9 when ``use_onnx`` is ``False``; on 2.9+ ``torch.export`` is used instead).
16291805
dynamic_batchsize: a sequence with three elements to define the batch size range of the input for the model to be
16301806
converted. Should be a sequence like [MIN_BATCH, OPT_BATCH, MAX_BATCH]. After converted, the batchsize of
16311807
model input should between `MIN_BATCH` and `MAX_BATCH` and the `OPT_BATCH` is the best performance batchsize
@@ -1729,11 +1905,15 @@ def trt_export(
17291905
}
17301906
converter_kwargs_.update(trt_api_parameters)
17311907

1732-
save_ts = partial(save_net_with_metadata, include_config_vals=False, append_timestamp=False)
1908+
def _save_trt_model(trt_obj, filepath, **kwargs):
1909+
"""Save TRT model without triggering deprecation warnings from internal calls."""
1910+
with warnings.catch_warnings():
1911+
warnings.filterwarnings("ignore", category=FutureWarning, message=".*save_net_with_metadata.*")
1912+
save_net_with_metadata(trt_obj, filepath, include_config_vals=False, append_timestamp=False, **kwargs)
17331913

17341914
_export(
17351915
convert_to_trt,
1736-
save_ts,
1916+
_save_trt_model,
17371917
parser,
17381918
net_id=net_id_,
17391919
filepath=filepath_,

0 commit comments

Comments
 (0)