Skip to content

Commit 53f8c1d

Browse files
committed
Migrate tests from TorchScript to torch.export
Add new test suites for export_utils, convert_to_export, and the export_checkpoint bundle CLI. Migrate ~60 existing test files from test_script (torch.jit.script) to test_export (torch.export.export), removing TorchScript-specific skip decorators where no longer needed. Signed-off-by: Soumya Snigdha Kundu <soumya_snigdha.kundu@kcl.ac.uk>
1 parent fc612cb commit 53f8c1d

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

63 files changed

+561
-234
lines changed

tests/apps/detection/networks/test_retinanet.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from monai.networks import eval_mode
2121
from monai.networks.nets import resnet10, resnet18, resnet34, resnet50, resnet101, resnet152, resnet200
2222
from monai.utils import ensure_tuple, optional_import
23-
from tests.test_utils import dict_product, skip_if_quick, test_onnx_save, test_script_save
23+
from tests.test_utils import dict_product, skip_if_quick, test_export_save, test_onnx_save
2424

2525
_, has_torchvision = optional_import("torchvision")
2626
_, has_onnxruntime = optional_import("onnxruntime")
@@ -92,7 +92,7 @@
9292
MODEL_LIST = [resnet10, resnet18, resnet34, resnet50, resnet101, resnet152, resnet200]
9393

9494
TEST_CASES = [[params["model"], *params["case"]] for params in dict_product(model=MODEL_LIST, case=CASE_LIST)]
95-
TEST_CASES_TS = [[params["model"], *params["case"]] for params in dict_product(model=MODEL_LIST, case=[TEST_CASE_1])]
95+
TEST_CASES_EXPORT = [[params["model"], *params["case"]] for params in dict_product(model=MODEL_LIST, case=[TEST_CASE_1])]
9696

9797

9898
@unittest.skipUnless(has_torchvision, "Requires torchvision")
@@ -136,18 +136,18 @@ def test_retina_shape(self, model, input_param, input_shape):
136136
self.assertEqual(tuple(cc.shape for cc in result[net.cls_key]), expected_cls_shape)
137137
self.assertEqual(tuple(cc.shape for cc in result[net.box_reg_key]), expected_box_shape)
138138

139-
@parameterized.expand(TEST_CASES_TS)
140-
def test_script(self, model, input_param, input_shape):
139+
@parameterized.expand(TEST_CASES_EXPORT)
140+
def test_export(self, model, input_param, input_shape):
141141
try:
142-
idx = int(self.id().split("test_script_")[-1])
142+
idx = int(self.id().split("test_export_")[-1])
143143
except BaseException:
144144
idx = 0
145145
idx %= 3
146-
# test whether support torchscript
146+
# test whether support torch.export
147147
data = torch.randn(input_shape)
148148
backbone = model(**input_param)
149149
if idx == 0:
150-
test_script_save(backbone, data)
150+
test_export_save(backbone, data)
151151
return
152152
feature_extractor = resnet_fpn_feature_extractor(
153153
backbone=backbone,
@@ -157,7 +157,7 @@ def test_script(self, model, input_param, input_shape):
157157
returned_layers=[1, 2],
158158
)
159159
if idx == 1:
160-
test_script_save(feature_extractor, data)
160+
test_export_save(feature_extractor, data)
161161
return
162162
net = RetinaNet(
163163
spatial_dims=input_param["spatial_dims"],
@@ -167,17 +167,17 @@ def test_script(self, model, input_param, input_shape):
167167
size_divisible=32,
168168
)
169169
if idx == 2:
170-
test_script_save(net, data)
170+
test_export_save(net, data)
171171

172-
@parameterized.expand(TEST_CASES_TS)
172+
@parameterized.expand(TEST_CASES_EXPORT)
173173
@unittest.skipUnless(has_onnxruntime, "onnxruntime not installed")
174174
def test_onnx(self, model, input_param, input_shape):
175175
try:
176176
idx = int(self.id().split("test_onnx_")[-1])
177177
except BaseException:
178178
idx = 0
179179
idx %= 3
180-
# test whether support torchscript
180+
# test whether support torch.export
181181
data = torch.randn(input_shape)
182182
backbone = model(**input_param)
183183
if idx == 0:

tests/apps/detection/networks/test_retinanet_detector.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from monai.apps.detection.utils.anchor_utils import AnchorGeneratorWithAnchorShape
2222
from monai.networks import eval_mode, train_mode
2323
from monai.utils import optional_import
24-
from tests.test_utils import skip_if_quick, test_script_save
24+
from tests.test_utils import skip_if_quick, test_export_save
2525

2626
_, has_torchvision = optional_import("torchvision")
2727

@@ -89,7 +89,7 @@
8989
TEST_CASES = []
9090
TEST_CASES = [TEST_CASE_1, TEST_CASE_2, TEST_CASE_2_A]
9191

92-
TEST_CASES_TS = [TEST_CASE_1]
92+
TEST_CASES_EXPORT = [TEST_CASE_1]
9393

9494

9595
class NaiveNetwork(torch.nn.Module):
@@ -183,9 +183,9 @@ def test_naive_retina_detector_shape(self, input_param, input_shape):
183183
targets = [one_target] * len(input_data)
184184
result = detector.forward(input_data, targets)
185185

186-
@parameterized.expand(TEST_CASES_TS)
187-
def test_script(self, input_param, input_shape):
188-
# test whether support torchscript
186+
@parameterized.expand(TEST_CASES_EXPORT)
187+
def test_export(self, input_param, input_shape):
188+
# test whether support torch.export
189189
returned_layers = [1]
190190
anchor_generator = AnchorGeneratorWithAnchorShape(
191191
feature_map_scales=(1, 2), base_anchor_shapes=((8,) * input_param["spatial_dims"],)
@@ -195,7 +195,7 @@ def test_script(self, input_param, input_shape):
195195
)
196196
with eval_mode(detector):
197197
input_data = torch.randn(input_shape)
198-
test_script_save(detector.network, input_data)
198+
test_export_save(detector.network, input_data)
199199

200200

201201
if __name__ == "__main__":

tests/apps/detection/utils/test_anchor_box.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
from monai.apps.detection.utils.anchor_utils import AnchorGenerator, AnchorGeneratorWithAnchorShape
2020
from monai.utils import optional_import
21-
from tests.test_utils import assert_allclose, test_script_save
21+
from tests.test_utils import assert_allclose, test_export_save
2222

2323
_, has_torchvision = optional_import("torchvision")
2424

@@ -67,20 +67,20 @@ def test_anchor_2d(self, input_param, image_shape, feature_maps_shapes):
6767
assert_allclose(a, a_f, type_test=True, device_test=False, atol=0.1)
6868

6969
@parameterized.expand(TEST_CASES_2D)
70-
def test_script_2d(self, input_param, image_shape, feature_maps_shapes):
71-
# test whether support torchscript
70+
def test_export_2d(self, input_param, image_shape, feature_maps_shapes):
71+
# test whether support torch.export
7272
anchor = AnchorGenerator(**input_param, indexing="xy")
7373
images = torch.rand(image_shape)
7474
feature_maps = tuple(torch.rand(fs) for fs in feature_maps_shapes)
75-
test_script_save(anchor, images, feature_maps)
75+
test_export_save(anchor, images, feature_maps)
7676

7777
@parameterized.expand(TEST_CASES_SHAPE_3D)
78-
def test_script_3d(self, input_param, image_shape, feature_maps_shapes):
79-
# test whether support torchscript
78+
def test_export_3d(self, input_param, image_shape, feature_maps_shapes):
79+
# test whether support torch.export
8080
anchor = AnchorGeneratorWithAnchorShape(**input_param, indexing="ij")
8181
images = torch.rand(image_shape)
8282
feature_maps = tuple(torch.rand(fs) for fs in feature_maps_shapes)
83-
test_script_save(anchor, images, feature_maps)
83+
test_export_save(anchor, images, feature_maps)
8484

8585

8686
if __name__ == "__main__":
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
from __future__ import annotations
13+
14+
import json
15+
import os
16+
import tempfile
17+
import unittest
18+
from pathlib import Path
19+
20+
from parameterized import parameterized
21+
22+
from monai.bundle import ConfigParser
23+
from monai.data import load_exported_program
24+
from monai.networks import save_state
25+
from tests.test_utils import command_line_tests, skip_if_windows
26+
27+
TESTS_PATH = Path(__file__).parents[1]
28+
29+
# key_in_ckpt
30+
TEST_CASE_1 = [""]
31+
TEST_CASE_2 = ["model"]
32+
33+
34+
@skip_if_windows
35+
class TestExportCheckpoint(unittest.TestCase):
36+
37+
def setUp(self):
38+
self._orig_cuda_env = os.environ.get("CUDA_VISIBLE_DEVICES")
39+
40+
def tearDown(self):
41+
if self._orig_cuda_env is not None:
42+
os.environ["CUDA_VISIBLE_DEVICES"] = self._orig_cuda_env
43+
else:
44+
os.environ.pop("CUDA_VISIBLE_DEVICES", None)
45+
46+
@parameterized.expand([TEST_CASE_1, TEST_CASE_2])
47+
def test_export(self, key_in_ckpt):
48+
meta_file = os.path.join(TESTS_PATH, "testing_data", "metadata.json")
49+
config_file = os.path.join(TESTS_PATH, "testing_data", "inference.json")
50+
with tempfile.TemporaryDirectory() as tempdir:
51+
def_args = {"meta_file": "will be replaced by `meta_file` arg"}
52+
def_args_file = os.path.join(tempdir, "def_args.yaml")
53+
54+
ckpt_file = os.path.join(tempdir, "model.pt")
55+
pt2_file = os.path.join(tempdir, "model.pt2")
56+
57+
parser = ConfigParser()
58+
parser.export_config_file(config=def_args, filepath=def_args_file)
59+
parser.read_config(config_file)
60+
net = parser.get_parsed_content("network_def")
61+
save_state(src=net if key_in_ckpt == "" else {key_in_ckpt: net}, path=ckpt_file)
62+
63+
cmd = [
64+
"coverage", "run", "-m", "monai.bundle", "export_checkpoint",
65+
"network_def", "--filepath", pt2_file,
66+
"--meta_file", meta_file,
67+
"--config_file", f"['{config_file}','{def_args_file}']",
68+
"--ckpt_file", ckpt_file,
69+
"--key_in_ckpt", key_in_ckpt,
70+
"--args_file", def_args_file,
71+
"--input_shape", "[1, 1, 96, 96, 96]",
72+
]
73+
command_line_tests(cmd)
74+
self.assertTrue(os.path.exists(pt2_file))
75+
76+
_, _metadata, extra_files = load_exported_program(
77+
pt2_file, more_extra_files=["inference.json", "def_args.json"]
78+
)
79+
self.assertIn("meta_file", json.loads(extra_files["def_args.json"]))
80+
self.assertIn("network_def", json.loads(extra_files["inference.json"]))
81+
82+
@parameterized.expand([TEST_CASE_1, TEST_CASE_2])
83+
def test_default_value(self, key_in_ckpt):
84+
config_file = os.path.join(TESTS_PATH, "testing_data", "inference.json")
85+
with tempfile.TemporaryDirectory() as tempdir:
86+
def_args = {"meta_file": "will be replaced by `meta_file` arg"}
87+
def_args_file = os.path.join(tempdir, "def_args.yaml")
88+
ckpt_file = os.path.join(tempdir, "models", "model.pt")
89+
pt2_file = os.path.join(tempdir, "models", "model.pt2")
90+
91+
parser = ConfigParser()
92+
parser.export_config_file(config=def_args, filepath=def_args_file)
93+
parser.read_config(config_file)
94+
net = parser.get_parsed_content("network_def")
95+
save_state(src=net if key_in_ckpt == "" else {key_in_ckpt: net}, path=ckpt_file)
96+
97+
# check with default value
98+
cmd = [
99+
"coverage", "run", "-m", "monai.bundle", "export_checkpoint",
100+
"--key_in_ckpt", key_in_ckpt,
101+
"--config_file", config_file,
102+
"--bundle_root", tempdir,
103+
"--input_shape", "[1, 1, 96, 96, 96]",
104+
]
105+
command_line_tests(cmd)
106+
self.assertTrue(os.path.exists(pt2_file))
107+
108+
109+
if __name__ == "__main__":
110+
unittest.main()

tests/data/meta_tensor/test_meta_tensor.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -224,21 +224,22 @@ def test_get_set_meta_fns(self):
224224
self.assertTrue(get_track_meta())
225225

226226
@parameterized.expand(TEST_DEVICES)
227-
def test_torchscript(self, device):
227+
def test_export(self, device):
228228
shape = (1, 3, 10, 8)
229229
im, _ = self.get_im(shape, device=device)
230230
conv = torch.nn.Conv2d(im.shape[1], 5, 3)
231231
conv.to(device)
232232
im_conv = conv(im)
233-
traced_fn = torch.jit.trace(conv, im.as_tensor())
233+
exported = torch.export.export(conv, args=(im.as_tensor(),))
234234
# save it, load it, use it
235235
with tempfile.TemporaryDirectory() as tmp_dir:
236-
fname = os.path.join(tmp_dir, "im.pt")
237-
torch.jit.save(traced_fn, f=fname)
238-
traced_fn = torch.jit.load(fname)
239-
out = traced_fn(im)
236+
fname = os.path.join(tmp_dir, "im.pt2")
237+
torch.export.save(exported, fname)
238+
loaded = torch.export.load(fname)
239+
out = loaded.module()(im.as_tensor())
240240
self.assertIsInstance(out, torch.Tensor)
241-
self.check(out, im_conv, ids=False)
241+
# exported module returns plain Tensor, compare values only
242+
assert_allclose(out, im_conv)
242243

243244
def test_pickling(self):
244245
m, _ = self.get_im()

tests/data/test_export_utils.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
from __future__ import annotations
13+
14+
import os
15+
import tempfile
16+
import unittest
17+
18+
import torch
19+
20+
from monai.config import get_config_values
21+
from monai.data import load_exported_program, save_exported_program
22+
from monai.utils import ExportMetadataKeys
23+
24+
25+
class TestModule(torch.nn.Module):
26+
__test__ = False
27+
28+
def forward(self, x):
29+
return x + 10
30+
31+
32+
class TestExportUtils(unittest.TestCase):
33+
34+
def test_save_exported_program(self):
35+
"""Save an exported program without metadata to a file."""
36+
ep = torch.export.export(TestModule(), args=(torch.tensor(1.0),))
37+
with tempfile.TemporaryDirectory() as tempdir:
38+
save_exported_program(ep, f"{tempdir}/test")
39+
self.assertTrue(os.path.isfile(f"{tempdir}/test.pt2"))
40+
41+
def test_save_exported_program_ext(self):
42+
"""Save an exported program to a file with custom extension."""
43+
ep = torch.export.export(TestModule(), args=(torch.tensor(1.0),))
44+
with tempfile.TemporaryDirectory() as tempdir:
45+
save_exported_program(ep, f"{tempdir}/test.zip")
46+
self.assertTrue(os.path.isfile(f"{tempdir}/test.zip"))
47+
48+
def test_save_with_metadata(self):
49+
"""Save an exported program with metadata to a file."""
50+
ep = torch.export.export(TestModule(), args=(torch.tensor(1.0),))
51+
test_metadata = {"foo": [1, 2], "bar": "string"}
52+
53+
with tempfile.TemporaryDirectory() as tempdir:
54+
save_exported_program(ep, f"{tempdir}/test", meta_values=test_metadata)
55+
self.assertTrue(os.path.isfile(f"{tempdir}/test.pt2"))
56+
57+
def test_load_exported_program(self):
58+
"""Save then load an exported program with no extra metadata."""
59+
ep = torch.export.export(TestModule(), args=(torch.tensor(1.0),))
60+
61+
with tempfile.TemporaryDirectory() as tempdir:
62+
save_exported_program(ep, f"{tempdir}/test")
63+
loaded_ep, meta, extra_files = load_exported_program(f"{tempdir}/test.pt2")
64+
65+
del meta[ExportMetadataKeys.TIMESTAMP.value]
66+
self.assertEqual(meta, get_config_values())
67+
self.assertEqual(extra_files, {})
68+
69+
# Verify the loaded program produces the same output
70+
result = loaded_ep.module()(torch.tensor(5.0))
71+
self.assertEqual(result.item(), 15.0)
72+
73+
def test_load_with_metadata(self):
74+
"""Save then load an exported program with metadata."""
75+
ep = torch.export.export(TestModule(), args=(torch.tensor(1.0),))
76+
test_metadata = {"foo": [1, 2], "bar": "string"}
77+
78+
with tempfile.TemporaryDirectory() as tempdir:
79+
save_exported_program(ep, f"{tempdir}/test", meta_values=test_metadata)
80+
_, meta, extra_files = load_exported_program(f"{tempdir}/test.pt2")
81+
82+
del meta[ExportMetadataKeys.TIMESTAMP.value]
83+
84+
test_compare = get_config_values()
85+
test_compare.update(test_metadata)
86+
self.assertEqual(meta, test_compare)
87+
self.assertEqual(extra_files, {})
88+
89+
def test_save_load_more_extra_files(self):
90+
"""Save then load extra file data from an exported program."""
91+
ep = torch.export.export(TestModule(), args=(torch.tensor(1.0),))
92+
test_metadata = {"foo": [1, 2], "bar": "string"}
93+
more_extra_files = {"test.txt": "This is test data"}
94+
95+
with tempfile.TemporaryDirectory() as tempdir:
96+
save_exported_program(ep, f"{tempdir}/test", meta_values=test_metadata, more_extra_files=more_extra_files)
97+
self.assertTrue(os.path.isfile(f"{tempdir}/test.pt2"))
98+
99+
_, _, loaded_extra_files = load_exported_program(f"{tempdir}/test.pt2", more_extra_files=("test.txt",))
100+
self.assertEqual(more_extra_files["test.txt"], loaded_extra_files["test.txt"])
101+
102+
103+
if __name__ == "__main__":
104+
unittest.main()

0 commit comments

Comments
 (0)