3030
3131from monai .apps .utils import get_logger
3232from monai .config import PathLike
33+ from monai .utils .deprecate_utils import deprecated
3334from monai .utils .misc import ensure_tuple , save_obj , set_determinism
34- from monai .utils .module import look_up_option , optional_import
35+ from monai .utils .module import look_up_option , optional_import , pytorch_after
3536from monai .utils .type_conversion import convert_to_dst_type , convert_to_tensor
3637
3738onnx , _ = optional_import ("onnx" )
5758 "save_state" ,
5859 "convert_to_onnx" ,
5960 "convert_to_torchscript" ,
61+ "convert_to_export" ,
6062 "convert_to_trt" ,
6163 "meshgrid_ij" ,
6264 "meshgrid_xy" ,
@@ -793,6 +795,16 @@ def convert_to_onnx(
793795 return onnx_model
794796
795797
798+ def _recursive_to (x , device ):
799+ """Recursively move tensors (and nested tuples/lists of tensors) to *device*."""
800+ if isinstance (x , torch .Tensor ):
801+ return x .to (device )
802+ if isinstance (x , (tuple , list )):
803+ return type (x )(_recursive_to (i , device ) for i in x )
804+ return x
805+
806+
807+ @deprecated (since = "1.5" , removed = "1.7" , msg_suffix = "Use convert_to_export() instead." )
796808def convert_to_torchscript (
797809 model : nn .Module ,
798810 filename_or_obj : Any | None = None ,
@@ -863,6 +875,82 @@ def convert_to_torchscript(
863875 return script_module
864876
865877
def convert_to_export(
    model: nn.Module,
    filename_or_obj: Any | None = None,
    extra_files: dict | None = None,
    verify: bool = False,
    inputs: Sequence[Any] | None = None,
    dynamic_shapes: dict | tuple | None = None,
    device: str | torch.device | None = None,
    rtol: float = 1e-4,
    atol: float = 0.0,
    **kwargs,
) -> torch.export.ExportedProgram:
    """
    Export ``model`` via :func:`torch.export.export`, optionally persist the result
    to a ``.pt2`` file, and optionally check the exported program against the
    original model on sample inputs.

    Args:
        model: source PyTorch model to export.
        filename_or_obj: if not None, a file path string where the exported program is saved.
        extra_files: mapping of filename to contents stored alongside the saved archive.
        verify: if True, run both the original model and the exported program on
            ``inputs`` and compare their outputs. When ``filename_or_obj`` is set,
            verification still uses the in-memory export (see inline note below).
        inputs: required sequence of test tensors mapped to positional arguments of ``model()``.
        dynamic_shapes: dynamic shape specification forwarded to :func:`torch.export.export`;
            see the PyTorch documentation for the accepted formats.
        device: device used for verification; defaults to CUDA when available, else CPU.
        rtol: relative tolerance for output comparison.
        atol: absolute tolerance for output comparison.
        kwargs: extra keyword arguments forwarded to :func:`torch.export.export`.

    Returns:
        The :class:`torch.export.ExportedProgram` produced from ``model``.

    Raises:
        ValueError: if ``inputs`` is None.
    """
    if inputs is None:
        raise ValueError("Input data is required for torch.export.export.")

    model.eval()
    with torch.no_grad():
        exported_program = torch.export.export(model, args=tuple(inputs), dynamic_shapes=dynamic_shapes, **kwargs)

    if filename_or_obj is not None:
        files_to_save: dict[str, Any] = {}
        if extra_files is not None:
            for key, value in extra_files.items():
                # torch.export.save requires str values; decode bytes from legacy callers
                files_to_save[key] = value.decode() if isinstance(value, bytes) else value
        torch.export.save(exported_program, filename_or_obj, extra_files=files_to_save or None)

    if verify:
        target = device if device is not None else torch.device("cuda" if torch.cuda.is_available() else "cpu")

        args_on_device = tuple(_recursive_to(item, target) for item in inputs)

        # Verification always runs against the in-memory export rather than a
        # reloaded file: torch.export.load has no map_location, so reloading can
        # hit device-placement issues.
        exported_module = exported_program.module()
        exported_module.to(target)
        model.to(target)

        with torch.no_grad():
            set_determinism(seed=0)
            expected = ensure_tuple(model(*args_on_device))
            set_determinism(seed=0)
            actual = ensure_tuple(exported_module(*args_on_device))
            set_determinism(seed=None)

        for expected_item, actual_item in zip(expected, actual):
            if isinstance(expected_item, torch.Tensor) or isinstance(actual_item, torch.Tensor):
                torch.testing.assert_close(expected_item, actual_item, rtol=rtol, atol=atol)  # type: ignore

    return exported_program
952+
953+
866954def _onnx_trt_compile (
867955 onnx_model ,
868956 min_shape : Sequence [int ],
@@ -1012,9 +1100,9 @@ def convert_to_trt(
10121100 convert_precision = torch .float32 if precision == "fp32" else torch .half
10131101 inputs = [torch .rand (ensure_tuple (input_shape )).to (target_device )]
10141102
1015- # convert the torch model to a TorchScript model on target device
10161103 model = model .eval ().to (target_device )
10171104 min_input_shape , opt_input_shape , max_input_shape = get_profile_shapes (input_shape , dynamic_batchsize )
1105+ _use_dynamo = pytorch_after (2 , 9 )
10181106
10191107 if use_onnx :
10201108 # set the batch dim as dynamic
@@ -1035,40 +1123,55 @@ def convert_to_trt(
10351123 output_names = onnx_output_names ,
10361124 )
10371125 else :
1038- ir_model = convert_to_torchscript (model , device = target_device , inputs = inputs , use_trace = use_trace )
1039- ir_model .eval ()
1040- # convert the model through the Torch-TensorRT way
1041- ir_model .to (target_device )
1126+ # Torch-TensorRT compilation path
10421127 with torch .no_grad ():
10431128 with torch .cuda .device (device = device ):
10441129 input_placeholder = [
10451130 torch_tensorrt .Input (
10461131 min_shape = min_input_shape , opt_shape = opt_input_shape , max_shape = max_input_shape
10471132 )
10481133 ]
1049- trt_model = torch_tensorrt .compile (
1050- ir_model ,
1051- inputs = input_placeholder ,
1052- enabled_precisions = convert_precision ,
1053- device = torch_tensorrt .Device (f"cuda:{ device } " ),
1054- ir = "torchscript" ,
1055- ** kwargs ,
1056- )
1134+ # Use dynamo IR (torch.export-based) which is the default in newer torch-tensorrt
1135+ if _use_dynamo :
1136+ trt_model = torch_tensorrt .compile (
1137+ model ,
1138+ inputs = input_placeholder ,
1139+ enabled_precisions = convert_precision ,
1140+ device = torch_tensorrt .Device (f"cuda:{ device } " ),
1141+ ir = "dynamo" ,
1142+ ** kwargs ,
1143+ )
1144+ else :
1145+ ir_model = convert_to_torchscript (
1146+ model , device = target_device , inputs = inputs , use_trace = use_trace
1147+ )
1148+ trt_model = torch_tensorrt .compile (
1149+ ir_model ,
1150+ inputs = input_placeholder ,
1151+ enabled_precisions = convert_precision ,
1152+ device = torch_tensorrt .Device (f"cuda:{ device } " ),
1153+ ir = "torchscript" ,
1154+ ** kwargs ,
1155+ )
10571156
10581157 # verify the outputs between the TensorRT model and PyTorch model
10591158 if verify :
10601159 if inputs is None :
10611160 raise ValueError ("Missing input data for verification." )
10621161
1063- trt_model = torch .jit .load (filename_or_obj ) if filename_or_obj is not None else trt_model
1162+ if filename_or_obj is not None :
1163+ if _use_dynamo :
1164+ trt_model = torch .export .load (filename_or_obj ).module ()
1165+ else :
1166+ trt_model = torch .jit .load (filename_or_obj )
10641167
10651168 with torch .no_grad ():
10661169 set_determinism (seed = 0 )
10671170 torch_out = ensure_tuple (model (* inputs ))
10681171 set_determinism (seed = 0 )
10691172 trt_out = ensure_tuple (trt_model (* inputs ))
10701173 set_determinism (seed = None )
1071- # compare TorchScript and PyTorch results
1174+ # compare TensorRT and PyTorch results
10721175 for r1 , r2 in zip (torch_out , trt_out ):
10731176 if isinstance (r1 , torch .Tensor ) or isinstance (r2 , torch .Tensor ):
10741177 torch .testing .assert_close (r1 , r2 , rtol = rtol , atol = atol ) # type: ignore
0 commit comments