Skip to content

Commit fb2f692

Browse files
committed
fix: skip tests that require torchvision operators
1 parent 503ca51 commit fb2f692

6 files changed

Lines changed: 40 additions & 27 deletions

File tree

tests/py/ts/api/test_e2e_behavior.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,25 @@
11
import copy
2-
import importlib.util
32
import unittest
43
from typing import Dict
54

65
import torch
76
import torch_tensorrt as torchtrt
87
from utils import same_output_format
98

10-
if importlib.util.find_spec("torchvision"):
9+
try:
1110
import torchvision.models as models
1211

12+
HAS_TORCHVISION = True
13+
except (ImportError, RuntimeError):
14+
HAS_TORCHVISION = False
15+
1316

1417
@unittest.skipIf(
1518
torchtrt.ENABLED_FEATURES.tensorrt_rtx,
1619
"aten::adaptive_avg_pool2d is implemented via plugins which is not supported for tensorrt_rtx",
1720
)
1821
@unittest.skipIf(
19-
not importlib.util.find_spec("torchvision"), "torchvision not installed"
22+
not HAS_TORCHVISION, "torchvision not available"
2023
)
2124
class TestInputTypeDefaultsFP32Model(unittest.TestCase):
2225

@@ -68,7 +71,7 @@ class TestInputTypeDefaultsFP16Model(unittest.TestCase):
6871
"aten::adaptive_avg_pool2d is implemented via plugins which is not supported for tensorrt_rtx",
6972
)
7073
@unittest.skipIf(
71-
not importlib.util.find_spec("torchvision"), "torchvision not installed"
74+
not HAS_TORCHVISION, "torchvision not available"
7275
)
7376
def test_input_use_default_fp16(self):
7477
self.model = models.resnet18(pretrained=True).eval().to("cuda")
@@ -89,7 +92,7 @@ def test_input_use_default_fp16(self):
8992
"aten::adaptive_avg_pool2d is implemented via plugins which is not supported for tensorrt_rtx",
9093
)
9194
@unittest.skipIf(
92-
not importlib.util.find_spec("torchvision"), "torchvision not installed"
95+
not HAS_TORCHVISION, "torchvision not available"
9396
)
9497
def test_input_use_default_fp16_without_fp16_enabled(self):
9598
self.model = models.resnet18(pretrained=True).eval().to("cuda")
@@ -108,7 +111,7 @@ def test_input_use_default_fp16_without_fp16_enabled(self):
108111
"aten::adaptive_avg_pool2d is implemented via plugins which is not supported for tensorrt_rtx",
109112
)
110113
@unittest.skipIf(
111-
not importlib.util.find_spec("torchvision"), "torchvision not installed"
114+
not HAS_TORCHVISION, "torchvision not available"
112115
)
113116
def test_input_respect_user_setting_fp16_weights_fp32_in(self):
114117
self.model = models.resnet18(pretrained=True).eval().to("cuda")
@@ -130,7 +133,7 @@ def test_input_respect_user_setting_fp16_weights_fp32_in(self):
130133
"aten::adaptive_avg_pool2d is implemented via plugins which is not supported for tensorrt_rtx",
131134
)
132135
@unittest.skipIf(
133-
not importlib.util.find_spec("torchvision"), "torchvision not installed"
136+
not HAS_TORCHVISION, "torchvision not available"
134137
)
135138
def test_input_respect_user_setting_fp16_weights_fp32_in_non_constuctor(self):
136139
self.model = models.resnet18(pretrained=True).eval().to("cuda")

tests/py/ts/api/test_embed_engines.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,23 @@
11
import copy
2-
import importlib
32
import unittest
43
from typing import Dict
54

65
import torch
76
import torch_tensorrt as torchtrt
87
from utils import COSINE_THRESHOLD, cosine_similarity
98

10-
if importlib.util.find_spec("torchvision"):
9+
try:
1110
import timm
1211
import torchvision.models as models
1312

13+
HAS_TORCHVISION = True
14+
except (ImportError, RuntimeError):
15+
HAS_TORCHVISION = False
16+
1417

1518
class TestModelToEngineToModel(unittest.TestCase):
1619
@unittest.skipIf(
17-
not importlib.util.find_spec("torchvision"), "torchvision not installed"
20+
not HAS_TORCHVISION, "torchvision not available"
1821
)
1922
@unittest.skipIf(
2023
torchtrt.ENABLED_FEATURES.tensorrt_rtx,
@@ -49,9 +52,8 @@ def test_resnet50(self):
4952
)
5053

5154
@unittest.skipIf(
52-
not importlib.util.find_spec("timm")
53-
or not importlib.util.find_spec("torchvision"),
54-
"timm or torchvision not installed",
55+
not HAS_TORCHVISION,
56+
"timm or torchvision not available",
5557
)
5658
@unittest.skipIf(
5759
torchtrt.ENABLED_FEATURES.tensorrt_rtx,

tests/py/ts/api/test_module_fallback.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,23 @@
1-
import importlib.util
21
import unittest
32

43
import torch
54
import torch_tensorrt as torchtrt
65
from utils import COSINE_THRESHOLD, cosine_similarity
76

8-
if importlib.util.find_spec("torchvision"):
7+
try:
98
import torchvision.models as models
109

10+
HAS_TORCHVISION = True
11+
except (ImportError, RuntimeError):
12+
HAS_TORCHVISION = False
13+
1114

1215
@unittest.skipIf(
1316
torchtrt.ENABLED_FEATURES.tensorrt_rtx,
1417
"aten::adaptive_avg_pool2d is implemented via plugins which is not supported for tensorrt_rtx",
1518
)
1619
@unittest.skipIf(
17-
not importlib.util.find_spec("torchvision"), "torchvision not installed"
20+
not HAS_TORCHVISION, "torchvision not available"
1821
)
1922
class TestModuleFallback(unittest.TestCase):
2023
def test_fallback_resnet18(self):

tests/py/ts/api/test_operator_fallback.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,23 @@
1-
import importlib.util
21
import unittest
32

43
import torch
54
import torch_tensorrt as torchtrt
65
from utils import COSINE_THRESHOLD, cosine_similarity
76

8-
if importlib.util.find_spec("torchvision"):
7+
try:
98
import torchvision.models as models
109

10+
HAS_TORCHVISION = True
11+
except (ImportError, RuntimeError):
12+
HAS_TORCHVISION = False
13+
1114

1215
@unittest.skipIf(
1316
torchtrt.ENABLED_FEATURES.tensorrt_rtx,
1417
"aten::adaptive_avg_pool2d is implemented via plugins which is not supported for tensorrt_rtx",
1518
)
1619
@unittest.skipIf(
17-
not importlib.util.find_spec("torchvision"), "torchvision not installed"
20+
not HAS_TORCHVISION, "torchvision not available"
1821
)
1922
class TestFallbackModels(unittest.TestCase):
2023
def test_fallback_resnet18(self):

tests/py/ts/api/test_ts_backend.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,21 @@
11
import copy
2-
import importlib.util
32
import unittest
43
from typing import Dict
54

65
import torch
76
import torch_tensorrt as torchtrt
87
from utils import COSINE_THRESHOLD, cosine_similarity
98

10-
if importlib.util.find_spec("torchvision"):
9+
try:
1110
import torchvision.models as models
1211

12+
HAS_TORCHVISION = True
13+
except (ImportError, RuntimeError):
14+
HAS_TORCHVISION = False
15+
1316

1417
@unittest.skipIf(
15-
not importlib.util.find_spec("torchvision"), "torchvision not installed"
18+
not HAS_TORCHVISION, "torchvision not available"
1619
)
1720
class TestCompile(unittest.TestCase):
1821
def test_compile_traced(self):
@@ -129,7 +132,7 @@ def test_default_device(self):
129132
"aten::adaptive_avg_pool2d is implemented via plugins which is not supported for tensorrt_rtx",
130133
)
131134
@unittest.skipIf(
132-
not importlib.util.find_spec("torchvision"), "torchvision not installed"
135+
not HAS_TORCHVISION, "torchvision not available"
133136
)
134137
class TestCheckMethodOpSupport(unittest.TestCase):
135138
def test_check_support(self):
@@ -138,7 +141,9 @@ def test_check_support(self):
138141

139142
self.assertTrue(torchtrt.ts.check_method_op_support(self.module, "forward"))
140143

141-
144+
@unittest.skipIf(
145+
not HAS_TORCHVISION, "torchvision not available"
146+
)
142147
class TestModuleIdentification(unittest.TestCase):
143148
def test_module_type(self):
144149
nn_module = models.alexnet(pretrained=True).eval().to("cuda")

uv.lock

Lines changed: 0 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)