Skip to content

Commit 2db1d8f

Browse files
committed
Add torch.export save/load infrastructure and deprecate TorchScript utils
Introduce save_exported_program() and load_exported_program() in a new monai/data/export_utils.py module for the torch.export-based .pt2 serialisation format. Mark the legacy save_net_with_metadata() and load_net_with_metadata() as deprecated. Also adds ExportMetadataKeys enum and inlines warn_deprecated() into deprecate_utils to remove a circular-import hazard. Signed-off-by: Soumya Snigdha Kundu <soumya_snigdha.kundu@kcl.ac.uk>
1 parent 9ddd5e6 commit 2db1d8f

File tree

6 files changed

+146
-9
lines changed

6 files changed

+146
-9
lines changed

monai/data/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@
7777
from .test_time_augmentation import TestTimeAugmentation
7878
from .thread_buffer import ThreadBuffer, ThreadDataLoader
7979
from .torchscript_utils import load_net_with_metadata, save_net_with_metadata
80+
from .export_utils import load_exported_program, save_exported_program
8081
from .utils import (
8182
affine_to_spacing,
8283
compute_importance_map,

monai/data/export_utils.py

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
from __future__ import annotations
13+
14+
import datetime
15+
import json
16+
import logging
17+
import os
18+
from collections.abc import Mapping, Sequence
19+
from typing import IO, Any
20+
21+
import torch
22+
23+
from monai.config import get_config_values
24+
from monai.data.torchscript_utils import METADATA_FILENAME
25+
from monai.utils import ExportMetadataKeys
26+
27+
__all__ = ["save_exported_program", "load_exported_program"]
28+
29+
30+
def save_exported_program(
31+
exported_program: torch.export.ExportedProgram,
32+
filename_prefix_or_stream: str | os.PathLike | IO[bytes],
33+
include_config_vals: bool = True,
34+
append_timestamp: bool = False,
35+
meta_values: Mapping[str, Any] | None = None,
36+
more_extra_files: Mapping[str, Any] | None = None,
37+
) -> None:
38+
"""
39+
Save an ``ExportedProgram`` produced by :func:`torch.export.export` with metadata included
40+
as a JSON file inside the ``.pt2`` archive.
41+
42+
Examples::
43+
44+
import torch
45+
from monai.networks.nets import UNet
46+
47+
net = UNet(spatial_dims=2, in_channels=1, out_channels=1, channels=[8, 16], strides=[2])
48+
ep = torch.export.export(net, args=(torch.rand(1, 1, 32, 32),))
49+
50+
meta = {"name": "Test UNet", "input_dims": 2}
51+
save_exported_program(ep, "test", meta_values=meta)
52+
53+
loaded_ep, loaded_meta, _ = load_exported_program("test.pt2")
54+
55+
Args:
56+
exported_program: an ``ExportedProgram`` returned by :func:`torch.export.export`.
57+
filename_prefix_or_stream: filename or file-like stream object.
58+
If a string filename has no extension it becomes ``.pt2``.
59+
include_config_vals: if True, MONAI, PyTorch, and NumPy versions are included in metadata.
60+
append_timestamp: if True, a timestamp is appended to the filename before the extension.
61+
meta_values: metadata values to store, compatible with JSON serialization.
62+
more_extra_files: additional data items to include in the archive.
63+
"""
64+
now = datetime.datetime.now()
65+
metadict: dict[str, Any] = {}
66+
67+
if include_config_vals:
68+
metadict.update(get_config_values())
69+
metadict[ExportMetadataKeys.TIMESTAMP.value] = now.astimezone().isoformat()
70+
71+
if meta_values is not None:
72+
metadict.update(meta_values)
73+
74+
json_data = json.dumps(metadict)
75+
76+
extra_files: dict[str, Any] = {METADATA_FILENAME: json_data}
77+
78+
if more_extra_files is not None:
79+
extra_files.update(more_extra_files)
80+
81+
# torch.export.save requires str values; decode bytes from legacy callers (e.g. _export helper)
82+
extra_files = {k: v.decode() if isinstance(v, bytes) else v for k, v in extra_files.items()}
83+
84+
if isinstance(filename_prefix_or_stream, (str, os.PathLike)):
85+
filename_prefix_or_stream = str(filename_prefix_or_stream)
86+
filename_no_ext, ext = os.path.splitext(filename_prefix_or_stream)
87+
if ext == "":
88+
ext = ".pt2"
89+
90+
if append_timestamp:
91+
filename_prefix_or_stream = now.strftime(f"{filename_no_ext}_%Y%m%d%H%M%S{ext}")
92+
else:
93+
filename_prefix_or_stream = filename_no_ext + ext
94+
95+
torch.export.save(exported_program, filename_prefix_or_stream, extra_files=extra_files)
96+
97+
98+
def load_exported_program(
99+
filename_prefix_or_stream: str | os.PathLike | IO[bytes],
100+
more_extra_files: Sequence[str] = (),
101+
) -> tuple[torch.export.ExportedProgram, dict, dict]:
102+
"""
103+
Load an ``ExportedProgram`` from a ``.pt2`` file and extract stored JSON metadata.
104+
105+
Args:
106+
filename_prefix_or_stream: filename or file-like stream object.
107+
more_extra_files: additional extra file names to load from the archive.
108+
109+
Returns:
110+
Triple of (ExportedProgram, metadata dict, extra files dict).
111+
"""
112+
extra_files: dict[str, Any] = dict.fromkeys(more_extra_files, "")
113+
extra_files[METADATA_FILENAME] = ""
114+
115+
exported_program = torch.export.load(filename_prefix_or_stream, extra_files=extra_files)
116+
117+
extra_files = dict(extra_files)
118+
119+
json_data = extra_files.pop(METADATA_FILENAME, "{}")
120+
121+
try:
122+
json_data_dict = json.loads(json_data)
123+
except json.JSONDecodeError:
124+
logging.getLogger(__name__).warning(
125+
"Failed to parse metadata JSON from exported program, returning empty metadata."
126+
)
127+
json_data_dict = {}
128+
129+
return exported_program, json_data_dict, extra_files

monai/data/torchscript_utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,12 @@
2121

2222
from monai.config import get_config_values
2323
from monai.utils import JITMetadataKeys
24+
from monai.utils.deprecate_utils import deprecated
2425

2526
METADATA_FILENAME = "metadata.json"
2627

2728

29+
@deprecated(since="1.5", removed="1.7", msg_suffix="Use monai.data.save_exported_program() instead.")
2830
def save_net_with_metadata(
2931
jit_obj: torch.nn.Module,
3032
filename_prefix_or_stream: str | IO[Any],
@@ -100,6 +102,7 @@ def save_net_with_metadata(
100102
torch.jit.save(jit_obj, filename_prefix_or_stream, extra_files)
101103

102104

105+
@deprecated(since="1.5", removed="1.7", msg_suffix="Use monai.data.load_exported_program() instead.")
103106
def load_net_with_metadata(
104107
filename_prefix_or_stream: str | IO[Any],
105108
map_location: torch.device | None = None,

monai/utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
DiceCEReduction,
3232
DownsampleMode,
3333
EngineStatsKeys,
34+
ExportMetadataKeys,
3435
FastMRIKeys,
3536
ForwardMode,
3637
GanKeys,

monai/utils/deprecate_utils.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,6 @@ class DeprecatedError(Exception):
3030
pass
3131

3232

33-
def warn_deprecated(obj, msg, warning_category=FutureWarning):
34-
"""
35-
Issue the warning message `msg`.
36-
"""
37-
warnings.warn(f"{obj}: {msg}", category=warning_category, stacklevel=2)
38-
3933

4034
def deprecated(
4135
since: str | None = None,
@@ -107,7 +101,7 @@ def _wrapper(*args, **kwargs):
107101
if is_removed:
108102
raise DeprecatedError(msg)
109103
if is_deprecated:
110-
warn_deprecated(obj, msg, warning_category)
104+
warnings.warn(f"{obj}: {msg}", category=warning_category, stacklevel=2)
111105

112106
return call_obj(*args, **kwargs)
113107

@@ -217,7 +211,7 @@ def _wrapper(*args, **kwargs):
217211
if is_removed:
218212
raise DeprecatedError(msg)
219213
if is_deprecated:
220-
warn_deprecated(argname, msg, warning_category)
214+
warnings.warn(f"{argname}: {msg}", category=warning_category, stacklevel=2)
221215

222216
return func(*args, **kwargs)
223217

@@ -317,7 +311,7 @@ def _decorator(func):
317311
def _wrapper(*args, **kwargs):
318312
if name not in sig.bind(*args, **kwargs).arguments and is_deprecated:
319313
# arg was not found so the default value is used
320-
warn_deprecated(argname, msg, warning_category)
314+
warnings.warn(f"{argname}: {msg}", category=warning_category, stacklevel=2)
321315

322316
return func(*args, **kwargs)
323317

monai/utils/enums.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@
6262
"BundlePropertyConfig",
6363
"AlgoKeys",
6464
"IgniteInfo",
65+
"JITMetadataKeys",
66+
"ExportMetadataKeys",
6567
]
6668

6769

@@ -423,6 +425,9 @@ class JITMetadataKeys(StrEnum):
423425
"""
424426
Keys stored in the metadata file for saved Torchscript models. Some of these are generated by the routines
425427
and others are optionally provided by users.
428+
429+
.. deprecated:: 1.5
430+
Use :class:`ExportMetadataKeys` instead.
426431
"""
427432

428433
NAME = "name"
@@ -431,6 +436,10 @@ class JITMetadataKeys(StrEnum):
431436
DESCRIPTION = "description"
432437

433438

439+
# ExportMetadataKeys shares the same members as JITMetadataKeys; alias to avoid duplication.
440+
ExportMetadataKeys = JITMetadataKeys
441+
442+
434443
class BoxModeName(StrEnum):
435444
"""
436445
Box mode names.

0 commit comments

Comments
 (0)