Skip to content

Commit 314ec6c

Browse files
committed
fix: skip tests that require torchvision operators
1 parent 503ca51 commit 314ec6c

File tree

6 files changed

+38
-46
lines changed

6 files changed

+38
-46
lines changed

tests/py/ts/api/test_e2e_behavior.py

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,24 @@
11
import copy
2-
import importlib.util
32
import unittest
43
from typing import Dict
54

65
import torch
76
import torch_tensorrt as torchtrt
87
from utils import same_output_format
98

10-
if importlib.util.find_spec("torchvision"):
9+
try:
1110
import torchvision.models as models
1211

12+
HAS_TORCHVISION = True
13+
except (ImportError, RuntimeError):
14+
HAS_TORCHVISION = False
15+
1316

1417
@unittest.skipIf(
1518
torchtrt.ENABLED_FEATURES.tensorrt_rtx,
1619
"aten::adaptive_avg_pool2d is implemented via plugins which is not supported for tensorrt_rtx",
1720
)
18-
@unittest.skipIf(
19-
not importlib.util.find_spec("torchvision"), "torchvision not installed"
20-
)
21+
@unittest.skipIf(not HAS_TORCHVISION, "torchvision not available")
2122
class TestInputTypeDefaultsFP32Model(unittest.TestCase):
2223

2324
def test_input_use_default_fp32(self):
@@ -67,9 +68,7 @@ class TestInputTypeDefaultsFP16Model(unittest.TestCase):
6768
torchtrt.ENABLED_FEATURES.tensorrt_rtx,
6869
"aten::adaptive_avg_pool2d is implemented via plugins which is not supported for tensorrt_rtx",
6970
)
70-
@unittest.skipIf(
71-
not importlib.util.find_spec("torchvision"), "torchvision not installed"
72-
)
71+
@unittest.skipIf(not HAS_TORCHVISION, "torchvision not available")
7372
def test_input_use_default_fp16(self):
7473
self.model = models.resnet18(pretrained=True).eval().to("cuda")
7574
self.input = torch.randn((1, 3, 224, 224)).to("cuda")
@@ -88,9 +87,7 @@ def test_input_use_default_fp16(self):
8887
torchtrt.ENABLED_FEATURES.tensorrt_rtx,
8988
"aten::adaptive_avg_pool2d is implemented via plugins which is not supported for tensorrt_rtx",
9089
)
91-
@unittest.skipIf(
92-
not importlib.util.find_spec("torchvision"), "torchvision not installed"
93-
)
90+
@unittest.skipIf(not HAS_TORCHVISION, "torchvision not available")
9491
def test_input_use_default_fp16_without_fp16_enabled(self):
9592
self.model = models.resnet18(pretrained=True).eval().to("cuda")
9693
self.input = torch.randn((1, 3, 224, 224)).to("cuda")
@@ -107,9 +104,7 @@ def test_input_use_default_fp16_without_fp16_enabled(self):
107104
torchtrt.ENABLED_FEATURES.tensorrt_rtx,
108105
"aten::adaptive_avg_pool2d is implemented via plugins which is not supported for tensorrt_rtx",
109106
)
110-
@unittest.skipIf(
111-
not importlib.util.find_spec("torchvision"), "torchvision not installed"
112-
)
107+
@unittest.skipIf(not HAS_TORCHVISION, "torchvision not available")
113108
def test_input_respect_user_setting_fp16_weights_fp32_in(self):
114109
self.model = models.resnet18(pretrained=True).eval().to("cuda")
115110
self.input = torch.randn((1, 3, 224, 224)).to("cuda")
@@ -129,9 +124,7 @@ def test_input_respect_user_setting_fp16_weights_fp32_in(self):
129124
torchtrt.ENABLED_FEATURES.tensorrt_rtx,
130125
"aten::adaptive_avg_pool2d is implemented via plugins which is not supported for tensorrt_rtx",
131126
)
132-
@unittest.skipIf(
133-
not importlib.util.find_spec("torchvision"), "torchvision not installed"
134-
)
127+
@unittest.skipIf(not HAS_TORCHVISION, "torchvision not available")
135128
def test_input_respect_user_setting_fp16_weights_fp32_in_non_constuctor(self):
136129
self.model = models.resnet18(pretrained=True).eval().to("cuda")
137130
self.input = torch.randn((1, 3, 224, 224)).to("cuda")

tests/py/ts/api/test_embed_engines.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,22 @@
11
import copy
2-
import importlib
32
import unittest
43
from typing import Dict
54

65
import torch
76
import torch_tensorrt as torchtrt
87
from utils import COSINE_THRESHOLD, cosine_similarity
98

10-
if importlib.util.find_spec("torchvision"):
9+
try:
1110
import timm
1211
import torchvision.models as models
1312

13+
HAS_TORCHVISION = True
14+
except (ImportError, RuntimeError):
15+
HAS_TORCHVISION = False
16+
1417

1518
class TestModelToEngineToModel(unittest.TestCase):
16-
@unittest.skipIf(
17-
not importlib.util.find_spec("torchvision"), "torchvision not installed"
18-
)
19+
@unittest.skipIf(not HAS_TORCHVISION, "torchvision not available")
1920
@unittest.skipIf(
2021
torchtrt.ENABLED_FEATURES.tensorrt_rtx,
2122
"aten::adaptive_avg_pool2d is implemented via plugins which is not supported for tensorrt_rtx",
@@ -49,9 +50,8 @@ def test_resnet50(self):
4950
)
5051

5152
@unittest.skipIf(
52-
not importlib.util.find_spec("timm")
53-
or not importlib.util.find_spec("torchvision"),
54-
"timm or torchvision not installed",
53+
not HAS_TORCHVISION,
54+
"timm or torchvision not available",
5555
)
5656
@unittest.skipIf(
5757
torchtrt.ENABLED_FEATURES.tensorrt_rtx,

tests/py/ts/api/test_module_fallback.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,22 @@
1-
import importlib.util
21
import unittest
32

43
import torch
54
import torch_tensorrt as torchtrt
65
from utils import COSINE_THRESHOLD, cosine_similarity
76

8-
if importlib.util.find_spec("torchvision"):
7+
try:
98
import torchvision.models as models
109

10+
HAS_TORCHVISION = True
11+
except (ImportError, RuntimeError):
12+
HAS_TORCHVISION = False
13+
1114

1215
@unittest.skipIf(
1316
torchtrt.ENABLED_FEATURES.tensorrt_rtx,
1417
"aten::adaptive_avg_pool2d is implemented via plugins which is not supported for tensorrt_rtx",
1518
)
16-
@unittest.skipIf(
17-
not importlib.util.find_spec("torchvision"), "torchvision not installed"
18-
)
19+
@unittest.skipIf(not HAS_TORCHVISION, "torchvision not available")
1920
class TestModuleFallback(unittest.TestCase):
2021
def test_fallback_resnet18(self):
2122
self.model = models.resnet18(pretrained=True).eval().to("cuda")

tests/py/ts/api/test_operator_fallback.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,22 @@
1-
import importlib.util
21
import unittest
32

43
import torch
54
import torch_tensorrt as torchtrt
65
from utils import COSINE_THRESHOLD, cosine_similarity
76

8-
if importlib.util.find_spec("torchvision"):
7+
try:
98
import torchvision.models as models
109

10+
HAS_TORCHVISION = True
11+
except (ImportError, RuntimeError):
12+
HAS_TORCHVISION = False
13+
1114

1215
@unittest.skipIf(
1316
torchtrt.ENABLED_FEATURES.tensorrt_rtx,
1417
"aten::adaptive_avg_pool2d is implemented via plugins which is not supported for tensorrt_rtx",
1518
)
16-
@unittest.skipIf(
17-
not importlib.util.find_spec("torchvision"), "torchvision not installed"
18-
)
19+
@unittest.skipIf(not HAS_TORCHVISION, "torchvision not available")
1920
class TestFallbackModels(unittest.TestCase):
2021
def test_fallback_resnet18(self):
2122
self.model = models.resnet18(pretrained=True).eval().to("cuda")

tests/py/ts/api/test_ts_backend.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,20 @@
11
import copy
2-
import importlib.util
32
import unittest
43
from typing import Dict
54

65
import torch
76
import torch_tensorrt as torchtrt
87
from utils import COSINE_THRESHOLD, cosine_similarity
98

10-
if importlib.util.find_spec("torchvision"):
9+
try:
1110
import torchvision.models as models
1211

12+
HAS_TORCHVISION = True
13+
except (ImportError, RuntimeError):
14+
HAS_TORCHVISION = False
1315

14-
@unittest.skipIf(
15-
not importlib.util.find_spec("torchvision"), "torchvision not installed"
16-
)
16+
17+
@unittest.skipIf(not HAS_TORCHVISION, "torchvision not available")
1718
class TestCompile(unittest.TestCase):
1819
def test_compile_traced(self):
1920
self.model = models.vgg16(pretrained=True).eval().to("cuda")
@@ -128,9 +129,7 @@ def test_default_device(self):
128129
torchtrt.ENABLED_FEATURES.tensorrt_rtx,
129130
"aten::adaptive_avg_pool2d is implemented via plugins which is not supported for tensorrt_rtx",
130131
)
131-
@unittest.skipIf(
132-
not importlib.util.find_spec("torchvision"), "torchvision not installed"
133-
)
132+
@unittest.skipIf(not HAS_TORCHVISION, "torchvision not available")
134133
class TestCheckMethodOpSupport(unittest.TestCase):
135134
def test_check_support(self):
136135
module = models.alexnet(pretrained=True).eval().to("cuda")
@@ -139,6 +138,7 @@ def test_check_support(self):
139138
self.assertTrue(torchtrt.ts.check_method_op_support(self.module, "forward"))
140139

141140

141+
@unittest.skipIf(not HAS_TORCHVISION, "torchvision not available")
142142
class TestModuleIdentification(unittest.TestCase):
143143
def test_module_type(self):
144144
nn_module = models.alexnet(pretrained=True).eval().to("cuda")

uv.lock

Lines changed: 0 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)