3535from monai .bundle .workflows import BundleWorkflow , ConfigWorkflow
3636from monai .config import PathLike
3737from monai .data import load_net_with_metadata , save_net_with_metadata
38+ from monai .data .export_utils import load_exported_program , save_exported_program
3839from monai .networks import (
40+ convert_to_export ,
3941 convert_to_onnx ,
4042 convert_to_torchscript ,
4143 convert_to_trt ,
4648from monai .utils import (
4749 IgniteInfo ,
4850 check_parent_dir ,
51+ deprecated ,
4952 ensure_tuple ,
5053 get_equivalent_dtype ,
5154 min_version ,
@@ -632,6 +635,7 @@ def load(
632635 workflow_type : str = "train" ,
633636 model_file : str | None = None ,
634637 load_ts_module : bool = False ,
638+ load_exported_module : bool = False ,
635639 bundle_dir : PathLike | None = None ,
636640 source : str = DEFAULT_DOWNLOAD_SOURCE ,
637641 repo : str | None = None ,
@@ -646,7 +650,7 @@ def load(
646650 net_override : dict | None = None ,
647651) -> object | tuple [torch .nn .Module , dict , dict ] | Any :
648652 """
649- Load model weights or TorchScript module of a bundle.
653+ Load model weights, TorchScript module, or exported program of a bundle.
650654
651655 Args:
652656 name: bundle name. If `None` and `url` is `None`, it must be provided in `args_file`.
@@ -664,10 +668,16 @@ def load(
664668 or "infer", "inference", "eval", "evaluation" for a inference workflow,
665669 other unsupported string will raise a ValueError.
666670 default to `train` for training workflow.
667- model_file: the relative path of the model weights or TorchScript module within bundle.
668- If `None`, "models/model.pt" or "models/model.ts" will be used.
671+ model_file: the relative path of the model weights or exported module within bundle.
672+ If `None`, "models/model.pt", "models/model.ts", or "models/model.pt2" will be used
673+ depending on the loading mode.
669674 load_ts_module: a flag to specify if loading the TorchScript module.
670- bundle_dir: directory the weights/TorchScript module will be loaded from.
675+
676+ .. deprecated:: 1.5
677+ Use ``load_exported_module=True`` instead.
678+
679+ load_exported_module: a flag to specify if loading a ``torch.export`` ``.pt2`` module.
680+ bundle_dir: directory the weights/module will be loaded from.
671681 Default is `bundle` subfolder under `torch.hub.get_dir()`.
672682 source: storage location name. This argument is used when `model_file` is not existing locally and need to be
673683 downloaded first.
@@ -684,32 +694,48 @@ def load(
684694 device: target device of returned weights or module, if `None`, prefer to "cuda" if existing.
685695 key_in_ckpt: for nested checkpoint like `{"model": XXX, "optimizer": XXX, ...}`, specify the key of model
686696 weights. if not nested checkpoint, no need to set.
687- config_files: extra filenames would be loaded. The argument only works when loading a TorchScript module,
688- see `_extra_files` in `torch.jit.load` for more details.
697+ config_files: extra filenames would be loaded. The argument only works when loading a TorchScript
698 or exported module, see ``_extra_files`` in ``torch.jit.load`` / ``torch.export.load`` for details.
689699 workflow_name: specified bundle workflow name, should be a string or class, default to "ConfigWorkflow".
690700 args_file: a JSON or YAML file to provide default values for all the args in "download" function.
691701 copy_model_args: other arguments for the `monai.networks.copy_model_state` function.
692702 net_override: id-value pairs to override the parameters in the network of the bundle, default to `None`.
693703
694704 Returns:
695- 1. If `load_ts_module` is ` False` and `model` is `None`,
705 1. If ``load_ts_module`` and ``load_exported_module`` are both ``False`` and ``model`` is ``None``,
696706 return model weights if can't find "network_def" in the bundle,
697707 else return an instantiated network that loaded the weights.
698- 2. If `load_ts_module` is ` False` and `model` is not `None`,
708 2. If ``load_ts_module`` and ``load_exported_module`` are both ``False`` and ``model`` is not ``None``,
699709 return an instantiated network that loaded the weights.
700- 3. If `load_ts_module` is `True`, return a triple that include a TorchScript module,
700- 3. If ``load_ts_module`` is ``True``, return a triple that includes a TorchScript module,
701711 the corresponding metadata dict, and extra files dict.
702- please check `monai.data.load_net_with_metadata` for more details.
712+ please check ``monai.data.load_net_with_metadata`` for more details.
713+ 4. If ``load_exported_module`` is ``True``, return a triple of
714+ (ExportedProgram, metadata dict, extra files dict).
715+ See :func:`monai.data.load_exported_program` for more details.
703716
704717 """
705718 bundle_dir_ = _process_bundle_dir (bundle_dir )
706719 net_override = {} if net_override is None else net_override
707720 copy_model_args = {} if copy_model_args is None else copy_model_args
708721
722+ if load_ts_module :
723+ warnings .warn (
724+ "load_ts_module is deprecated since v1.5 and will be removed in v1.7. "
725+ "Use load_exported_module=True instead." ,
726+ FutureWarning ,
727+ stacklevel = 2 ,
728+ )
729+
709730 if device is None :
710731 device = "cuda:0" if is_available () else "cpu"
711732 if model_file is None :
712- model_file = os .path .join ("models" , "model.ts" if load_ts_module is True else "model.pt" )
733+ if load_exported_module :
734+ model_file = os .path .join ("models" , "model.pt2" )
735+ elif load_ts_module :
736+ model_file = os .path .join ("models" , "model.ts" )
737+ else :
738+ model_file = os .path .join ("models" , "model.pt" )
713739 if source == "ngc" :
714740 name = _add_ngc_prefix (name )
715741 if remove_prefix :
@@ -727,14 +753,25 @@ def load(
727753 args_file = args_file ,
728754 )
729755
756+ # loading with `torch.export.load`
757+ if load_exported_module :
758+ return load_exported_program (full_path , more_extra_files = config_files or ())
730759 # loading with `torch.jit.load`
731760 if load_ts_module is True :
732- return load_net_with_metadata (full_path , map_location = torch .device (device ), more_extra_files = config_files )
761+ # Suppress the @deprecated warning from load_net_with_metadata since the user
762+ # already received a FutureWarning about load_ts_module above.
763+ with warnings .catch_warnings ():
764+ warnings .filterwarnings ("ignore" , category = FutureWarning , message = ".*load_net_with_metadata.*" )
765+ return load_net_with_metadata (full_path , map_location = torch .device (device ), more_extra_files = config_files )
733766 # loading with `torch.load`
734767 model_dict = torch .load (full_path , map_location = torch .device (device ), weights_only = True )
735768
736769 if not isinstance (model_dict , Mapping ):
737- warnings .warn (f"the state dictionary from { full_path } should be a dictionary but got { type (model_dict )} ." )
770+ warnings .warn (
771+ f"the state dictionary from { full_path } should be a dictionary but got { type (model_dict )} ." ,
772+ category = UserWarning ,
773+ stacklevel = 2 ,
774+ )
738775 model_dict = get_state_dict (model_dict )
739776
740777 _workflow = None
@@ -750,11 +787,17 @@ def load(
750787 ** _net_override ,
751788 )
752789 else :
753- warnings .warn (f"Cannot find the config file: { bundle_config_file } , return state dict instead." )
790+ warnings .warn (
791+ f"Cannot find the config file: { bundle_config_file } , return state dict instead." ,
792+ stacklevel = 2 ,
793+ )
754794 return model_dict
755795 if _workflow is not None :
756796 if not hasattr (_workflow , "network_def" ):
757- warnings .warn ("No available network definition in the bundle, return state dict instead." )
797+ warnings .warn (
798+ "No available network definition in the bundle, return state dict instead." ,
799+ stacklevel = 2 ,
800+ )
758801 return model_dict
759802 else :
760803 model = _workflow .network_def
@@ -1277,7 +1320,7 @@ def _export(
12771320 (extracted from the parser), and a dictionary of extra JSON files (name -> contents) as input.
12781321 parser: a ConfigParser of the bundle to be converted.
12791322 net_id: ID name of the network component in the parser, it must be `torch.nn.Module`.
1280- filepath: filepath to export, if filename has no extension, it becomes `.ts` .
1323+ filepath: filepath to export.
12811324 ckpt_file: filepath of the model checkpoint to load.
12821325 config_file: filepath of the config file to save in the converted model,the saved key in the converted
12831326 model is the config filename without extension, and the saved config value is always serialized in
@@ -1434,6 +1477,7 @@ def save_onnx(onnx_obj: Any, filename_prefix_or_stream: str, **kwargs: Any) -> N
14341477 )
14351478
14361479
1480+ @deprecated (since = "1.5" , removed = "1.7" , msg_suffix = "Use export_checkpoint() instead." )
14371481def ckpt_export (
14381482 net_id : str | None = None ,
14391483 filepath : PathLike | None = None ,
@@ -1568,6 +1612,139 @@ def ckpt_export(
15681612 )
15691613
15701614
def export_checkpoint(
    net_id: str | None = None,
    filepath: PathLike | None = None,
    ckpt_file: str | None = None,
    meta_file: str | Sequence[str] | None = None,
    config_file: str | Sequence[str] | None = None,
    key_in_ckpt: str | None = None,
    input_shape: Sequence[int] | None = None,
    dynamic_shapes: dict | tuple | None = None,
    args_file: str | None = None,
    converter_kwargs: Mapping | None = None,
    **override: Any,
) -> None:
    """
    Export a bundle model checkpoint to a ``.pt2`` file via :func:`torch.export.export`,
    with the bundle metadata and config files included in the saved archive.

    Typical usage examples:

    .. code-block:: bash

        python -m monai.bundle export_checkpoint network --filepath <path> --ckpt_file <checkpoint path> ...

    Args:
        net_id: ID name of the network component in the config, it must be ``torch.nn.Module``.
            Default to ``"network_def"``.
        filepath: filepath to export. If filename has no extension it becomes ``.pt2``.
            Default to ``"models/model.pt2"`` under ``"os.getcwd()"`` if ``bundle_root`` is not specified.
        ckpt_file: filepath of the model checkpoint to load.
            Default to ``"models/model.pt"`` under ``"os.getcwd()"`` if ``bundle_root`` is not specified.
        meta_file: filepath of the metadata file. If it is a list of file paths, contents will be merged.
            Default to ``"configs/metadata.json"`` under ``"os.getcwd()"`` if ``bundle_root`` is not specified.
        config_file: filepath of the config file to save in the exported model. The saved key is the
            config filename without extension; the value is always serialized in JSON format.
            It can be a single file or a list of files. If ``None``, must be provided in ``args_file``.
        key_in_ckpt: for nested checkpoints like ``{"model": XXX, "optimizer": XXX, ...}``, specify the
            key of model weights. If not nested, no need to set.
        input_shape: a shape used to generate random input for the network, e.g. ``[N, C, H, W]`` or
            ``[N, C, H, W, D]``. If not given, will try to parse from ``metadata``.
        dynamic_shapes: dynamic shape specifications passed to :func:`torch.export.export`.
        args_file: a JSON or YAML file to provide default values for all the parameters.
        converter_kwargs: extra arguments for :func:`~monai.networks.utils.convert_to_export`,
            except ones that already exist in the input parameters.
        override: id-value pairs to override or add the corresponding config content.
    """
    # merge defaults from `args_file` with the explicit keyword arguments
    _args = update_kwargs(
        args=args_file,
        net_id=net_id,
        filepath=filepath,
        meta_file=meta_file,
        config_file=config_file,
        ckpt_file=ckpt_file,
        key_in_ckpt=key_in_ckpt,
        input_shape=input_shape,
        dynamic_shapes=dynamic_shapes,
        converter_kwargs=converter_kwargs,
        **override,
    )
    _log_input_summary(tag="export_checkpoint", args=_args)
    (
        config_file_,
        filepath_,
        ckpt_file_,
        net_id_,
        meta_file_,
        key_in_ckpt_,
        input_shape_,
        dynamic_shapes_,
        converter_kwargs_,
    ) = _pop_args(
        _args,
        "config_file",
        filepath=None,
        ckpt_file=None,
        net_id=None,
        meta_file=None,
        key_in_ckpt="",
        input_shape=None,
        dynamic_shapes=None,
        converter_kwargs={},
    )
    bundle_root = _args.get("bundle_root", os.getcwd())

    parser = ConfigParser()
    parser.read_config(f=config_file_)
    if meta_file_ is None:
        meta_file_ = os.path.join(bundle_root, "configs", "metadata.json")
    if os.path.exists(meta_file_):
        parser.read_meta(f=meta_file_)

    # any remaining args are overrides to apply on top of the parsed config
    for key, value in _args.items():
        parser[key] = value

    if filepath_ is None:
        filepath_ = os.path.join(bundle_root, "models", "model.pt2")
    if ckpt_file_ is None:
        ckpt_file_ = os.path.join(bundle_root, "models", "model.pt")
    if not os.path.exists(ckpt_file_):
        raise FileNotFoundError(f'Checkpoint file "{ckpt_file_}" not found, please specify it in argument "ckpt_file".')

    net_id_ = "network_def" if net_id_ is None else net_id_
    try:
        parser.get_parsed_content(net_id_)
    except ValueError as e:
        raise ValueError(
            f'Network definition "{net_id_}" cannot be found in "{config_file_}", specify name with argument "net_id".'
        ) from e

    if not input_shape_:
        # fall back to the shape recorded in the bundle metadata, if any
        input_shape_ = _get_fake_input_shape(parser=parser)
        if not input_shape_:
            raise ValueError(
                "Cannot determine input shape automatically. "
                "Please provide it explicitly via the 'input_shape' argument."
            )

    example_inputs: Sequence[Any] = [torch.rand(input_shape_)]
    converter_kwargs_.update({"inputs": example_inputs, "dynamic_shapes": dynamic_shapes_})

    # saver that writes the ExportedProgram plus metadata/config extra files
    save_ep = partial(save_exported_program, include_config_vals=False, append_timestamp=False)

    _export(
        convert_to_export,
        save_ep,
        parser,
        net_id=net_id_,
        filepath=filepath_,
        ckpt_file=ckpt_file_,
        config_file=config_file_,
        key_in_ckpt=key_in_ckpt_,
        **converter_kwargs_,
    )
1746+
1747+
15711748def trt_export (
15721749 net_id : str | None = None ,
15731750 filepath : PathLike | None = None ,
@@ -1588,20 +1765,19 @@ def trt_export(
15881765 ** override : Any ,
15891766) -> None :
15901767 """
1591- Export the model checkpoint to the given filepath as a TensorRT engine-based TorchScript .
1768+ Export the model checkpoint to the given filepath as a TensorRT engine.
15921769 Currently, this API only supports converting models whose inputs are all tensors.
15931770 Note: NVIDIA Volta support (GPUs with compute capability 7.0) has been removed starting with TensorRT 10.5.
15941771 Review the TensorRT Support Matrix for which GPUs are supported.
15951772
15961773 There are two ways to export a model:
1597- 1, Torch-TensorRT way: PyTorch module ---> TorchScript module ---> TensorRT engine-based TorchScript.
1598- 2, ONNX-TensorRT way: PyTorch module ---> TorchScript module ---> ONNX model ---> TensorRT engine --->
1599- TensorRT engine-based TorchScript .
1774+ 1, Torch-TensorRT way: PyTorch module ---> TensorRT engine (via ``torch.export`` on PyTorch >= 2.9,
1775+ or via TorchScript on older versions).
1776+ 2, ONNX-TensorRT way: PyTorch module ---> ONNX model ---> TensorRT engine.
16001777
16011778 When exporting through the first way, some models suffer from the slowdown problem, since Torch-TensorRT
16021779 may only convert a little part of the PyTorch model to the TensorRT engine. However when exporting through
1603- the second way, some Python data structures like `dict` are not supported. And some TorchScript models are
1604- not supported by the ONNX if exported through `torch.jit.script`.
1780+ the second way, some Python data structures like ``dict`` are not supported.
16051781
16061782 Typical usage examples:
16071783
@@ -1624,8 +1800,8 @@ def trt_export(
16241800 precision: the weight precision of the converted TensorRT engine based TorchScript models. Should be 'fp32' or 'fp16'.
16251801 input_shape: the input shape that is used to convert the model. Should be a list like [N, C, H, W] or
16261802 [N, C, H, W, D]. If not given, will try to parse from the `metadata` config.
1627- use_trace: whether using `torch.jit.trace` to convert the PyTorch model to a TorchScript model and then convert to
1628- a TensorRT engine based TorchScript model or an ONNX model (if `use_onnx` is True ).
1803+ use_trace: whether using ``torch.jit.trace`` to convert the PyTorch model to a TorchScript model
1804+ (only used on PyTorch < 2.9 when ``use_onnx`` is ``False``; on 2.9+ ``torch.export`` is used instead).
16291805 dynamic_batchsize: a sequence with three elements to define the batch size range of the input for the model to be
16301806 converted. Should be a sequence like [MIN_BATCH, OPT_BATCH, MAX_BATCH]. After converted, the batchsize of
16311807 model input should between `MIN_BATCH` and `MAX_BATCH` and the `OPT_BATCH` is the best performance batchsize
@@ -1729,11 +1905,15 @@ def trt_export(
17291905 }
17301906 converter_kwargs_ .update (trt_api_parameters )
17311907
1732- save_ts = partial (save_net_with_metadata , include_config_vals = False , append_timestamp = False )
1908+ def _save_trt_model (trt_obj , filepath , ** kwargs ):
1909+ """Save TRT model without triggering deprecation warnings from internal calls."""
1910+ with warnings .catch_warnings ():
1911+ warnings .filterwarnings ("ignore" , category = FutureWarning , message = ".*save_net_with_metadata.*" )
1912+ save_net_with_metadata (trt_obj , filepath , include_config_vals = False , append_timestamp = False , ** kwargs )
17331913
17341914 _export (
17351915 convert_to_trt ,
1736- save_ts ,
1916+ _save_trt_model ,
17371917 parser ,
17381918 net_id = net_id_ ,
17391919 filepath = filepath_ ,
0 commit comments