Skip to content

Commit a769c7e

Browse files
committed
Add convert_to_export() and deprecate convert_to_torchscript()
Introduce convert_to_export() as the torch.export-based replacement for convert_to_torchscript(), along with a _recursive_to() helper for moving ExportedProgram state dicts across devices. Mark convert_to_torchscript() as deprecated, and update convert_to_trt() to use the dynamo IR path when available, gated behind pytorch_after().

Signed-off-by: Soumya Snigdha Kundu <soumya_snigdha.kundu@kcl.ac.uk>
1 parent 2db1d8f commit a769c7e

File tree

2 files changed

+120
-16
lines changed

2 files changed

+120
-16
lines changed

monai/networks/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from .trt_compiler import trt_compile
1515
from .utils import (
1616
add_casts_around_norms,
17+
convert_to_export,
1718
convert_to_onnx,
1819
convert_to_torchscript,
1920
convert_to_trt,

monai/networks/utils.py

Lines changed: 119 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,9 @@
3030

3131
from monai.apps.utils import get_logger
3232
from monai.config import PathLike
33+
from monai.utils.deprecate_utils import deprecated
3334
from monai.utils.misc import ensure_tuple, save_obj, set_determinism
34-
from monai.utils.module import look_up_option, optional_import
35+
from monai.utils.module import look_up_option, optional_import, pytorch_after
3536
from monai.utils.type_conversion import convert_to_dst_type, convert_to_tensor
3637

3738
onnx, _ = optional_import("onnx")
@@ -57,6 +58,7 @@
5758
"save_state",
5859
"convert_to_onnx",
5960
"convert_to_torchscript",
61+
"convert_to_export",
6062
"convert_to_trt",
6163
"meshgrid_ij",
6264
"meshgrid_xy",
@@ -793,6 +795,16 @@ def convert_to_onnx(
793795
return onnx_model
794796

795797

798+
def _recursive_to(x, device):
799+
"""Recursively move tensors (and nested tuples/lists of tensors) to *device*."""
800+
if isinstance(x, torch.Tensor):
801+
return x.to(device)
802+
if isinstance(x, (tuple, list)):
803+
return type(x)(_recursive_to(i, device) for i in x)
804+
return x
805+
806+
807+
@deprecated(since="1.5", removed="1.7", msg_suffix="Use convert_to_export() instead.")
796808
def convert_to_torchscript(
797809
model: nn.Module,
798810
filename_or_obj: Any | None = None,
@@ -863,6 +875,82 @@ def convert_to_torchscript(
863875
return script_module
864876

865877

878+
def convert_to_export(
879+
model: nn.Module,
880+
filename_or_obj: Any | None = None,
881+
extra_files: dict | None = None,
882+
verify: bool = False,
883+
inputs: Sequence[Any] | None = None,
884+
dynamic_shapes: dict | tuple | None = None,
885+
device: str | torch.device | None = None,
886+
rtol: float = 1e-4,
887+
atol: float = 0.0,
888+
**kwargs,
889+
) -> torch.export.ExportedProgram:
890+
"""
891+
Utility to export a model using :func:`torch.export.export` and optionally save to a ``.pt2`` file,
892+
with optional input/output data verification.
893+
894+
Args:
895+
model: source PyTorch model to export.
896+
filename_or_obj: if not None, a file path string to save the exported program.
897+
extra_files: map from filename to contents to store in the saved archive.
898+
verify: whether to verify the input and output of the exported model.
899+
If ``filename_or_obj`` is not None, loads the saved model and verifies.
900+
inputs: input test data for export and verification. Should be a sequence of
901+
tensors that map to positional arguments of ``model()``.
902+
dynamic_shapes: dynamic shape specifications passed to :func:`torch.export.export`.
903+
See PyTorch docs for format details.
904+
device: target device to verify the model. If None, uses CUDA if available.
905+
rtol: the relative tolerance when comparing outputs.
906+
atol: the absolute tolerance when comparing outputs.
907+
kwargs: additional keyword arguments for :func:`torch.export.export`.
908+
909+
Returns:
910+
A :class:`torch.export.ExportedProgram` representing the exported model.
911+
"""
912+
if inputs is None:
913+
raise ValueError("Input data is required for torch.export.export.")
914+
915+
model.eval()
916+
with torch.no_grad():
917+
export_args = tuple(inputs)
918+
exported = torch.export.export(model, args=export_args, dynamic_shapes=dynamic_shapes, **kwargs)
919+
920+
if filename_or_obj is not None:
921+
save_extra: dict[str, Any] = {}
922+
if extra_files is not None:
923+
# torch.export.save requires str values; decode bytes from legacy callers
924+
save_extra.update({k: v.decode() if isinstance(v, bytes) else v for k, v in extra_files.items()})
925+
torch.export.save(exported, filename_or_obj, extra_files=save_extra if save_extra else None)
926+
927+
if verify:
928+
if device is None:
929+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
930+
931+
verify_args = tuple(_recursive_to(i, device) for i in inputs)
932+
933+
# Always verify against the in-memory export to avoid device placement
934+
# issues that can occur when reloading from file (torch.export.load does
935+
# not support map_location).
936+
loaded_module = exported.module()
937+
loaded_module.to(device)
938+
model.to(device)
939+
940+
with torch.no_grad():
941+
set_determinism(seed=0)
942+
torch_out = ensure_tuple(model(*verify_args))
943+
set_determinism(seed=0)
944+
export_out = ensure_tuple(loaded_module(*verify_args))
945+
set_determinism(seed=None)
946+
947+
for r1, r2 in zip(torch_out, export_out):
948+
if isinstance(r1, torch.Tensor) or isinstance(r2, torch.Tensor):
949+
torch.testing.assert_close(r1, r2, rtol=rtol, atol=atol) # type: ignore
950+
951+
return exported
952+
953+
866954
def _onnx_trt_compile(
867955
onnx_model,
868956
min_shape: Sequence[int],
@@ -1012,9 +1100,9 @@ def convert_to_trt(
10121100
convert_precision = torch.float32 if precision == "fp32" else torch.half
10131101
inputs = [torch.rand(ensure_tuple(input_shape)).to(target_device)]
10141102

1015-
# convert the torch model to a TorchScript model on target device
10161103
model = model.eval().to(target_device)
10171104
min_input_shape, opt_input_shape, max_input_shape = get_profile_shapes(input_shape, dynamic_batchsize)
1105+
_use_dynamo = pytorch_after(2, 9)
10181106

10191107
if use_onnx:
10201108
# set the batch dim as dynamic
@@ -1035,40 +1123,55 @@ def convert_to_trt(
10351123
output_names=onnx_output_names,
10361124
)
10371125
else:
1038-
ir_model = convert_to_torchscript(model, device=target_device, inputs=inputs, use_trace=use_trace)
1039-
ir_model.eval()
1040-
# convert the model through the Torch-TensorRT way
1041-
ir_model.to(target_device)
1126+
# Torch-TensorRT compilation path
10421127
with torch.no_grad():
10431128
with torch.cuda.device(device=device):
10441129
input_placeholder = [
10451130
torch_tensorrt.Input(
10461131
min_shape=min_input_shape, opt_shape=opt_input_shape, max_shape=max_input_shape
10471132
)
10481133
]
1049-
trt_model = torch_tensorrt.compile(
1050-
ir_model,
1051-
inputs=input_placeholder,
1052-
enabled_precisions=convert_precision,
1053-
device=torch_tensorrt.Device(f"cuda:{device}"),
1054-
ir="torchscript",
1055-
**kwargs,
1056-
)
1134+
# Use dynamo IR (torch.export-based) which is the default in newer torch-tensorrt
1135+
if _use_dynamo:
1136+
trt_model = torch_tensorrt.compile(
1137+
model,
1138+
inputs=input_placeholder,
1139+
enabled_precisions=convert_precision,
1140+
device=torch_tensorrt.Device(f"cuda:{device}"),
1141+
ir="dynamo",
1142+
**kwargs,
1143+
)
1144+
else:
1145+
ir_model = convert_to_torchscript(
1146+
model, device=target_device, inputs=inputs, use_trace=use_trace
1147+
)
1148+
trt_model = torch_tensorrt.compile(
1149+
ir_model,
1150+
inputs=input_placeholder,
1151+
enabled_precisions=convert_precision,
1152+
device=torch_tensorrt.Device(f"cuda:{device}"),
1153+
ir="torchscript",
1154+
**kwargs,
1155+
)
10571156

10581157
# verify the outputs between the TensorRT model and PyTorch model
10591158
if verify:
10601159
if inputs is None:
10611160
raise ValueError("Missing input data for verification.")
10621161

1063-
trt_model = torch.jit.load(filename_or_obj) if filename_or_obj is not None else trt_model
1162+
if filename_or_obj is not None:
1163+
if _use_dynamo:
1164+
trt_model = torch.export.load(filename_or_obj).module()
1165+
else:
1166+
trt_model = torch.jit.load(filename_or_obj)
10641167

10651168
with torch.no_grad():
10661169
set_determinism(seed=0)
10671170
torch_out = ensure_tuple(model(*inputs))
10681171
set_determinism(seed=0)
10691172
trt_out = ensure_tuple(trt_model(*inputs))
10701173
set_determinism(seed=None)
1071-
# compare TorchScript and PyTorch results
1174+
# compare TensorRT and PyTorch results
10721175
for r1, r2 in zip(torch_out, trt_out):
10731176
if isinstance(r1, torch.Tensor) or isinstance(r2, torch.Tensor):
10741177
torch.testing.assert_close(r1, r2, rtol=rtol, atol=atol) # type: ignore

0 commit comments

Comments
 (0)