11import copy
2- import importlib .util
32import unittest
43from typing import Dict
54
65import torch
76import torch_tensorrt as torchtrt
87from utils import same_output_format
98
10- if importlib . util . find_spec ( "torchvision" ) :
9+ try :
1110 import torchvision .models as models
1211
12+ HAS_TORCHVISION = True
13+ except (ImportError , RuntimeError ):
14+ HAS_TORCHVISION = False
15+
1316
1417@unittest .skipIf (
1518 torchtrt .ENABLED_FEATURES .tensorrt_rtx ,
1619 "aten::adaptive_avg_pool2d is implemented via plugins which is not supported for tensorrt_rtx" ,
1720)
1821@unittest .skipIf (
19- not importlib . util . find_spec ( "torchvision" ) , "torchvision not installed "
22+ not HAS_TORCHVISION , "torchvision not available "
2023)
2124class TestInputTypeDefaultsFP32Model (unittest .TestCase ):
2225
@@ -68,7 +71,7 @@ class TestInputTypeDefaultsFP16Model(unittest.TestCase):
6871 "aten::adaptive_avg_pool2d is implemented via plugins which is not supported for tensorrt_rtx" ,
6972 )
7073 @unittest .skipIf (
71- not importlib . util . find_spec ( "torchvision" ) , "torchvision not installed "
74+ not HAS_TORCHVISION , "torchvision not available "
7275 )
7376 def test_input_use_default_fp16 (self ):
7477 self .model = models .resnet18 (pretrained = True ).eval ().to ("cuda" )
@@ -89,7 +92,7 @@ def test_input_use_default_fp16(self):
8992 "aten::adaptive_avg_pool2d is implemented via plugins which is not supported for tensorrt_rtx" ,
9093 )
9194 @unittest .skipIf (
92- not importlib . util . find_spec ( "torchvision" ) , "torchvision not installed "
95+ not HAS_TORCHVISION , "torchvision not available "
9396 )
9497 def test_input_use_default_fp16_without_fp16_enabled (self ):
9598 self .model = models .resnet18 (pretrained = True ).eval ().to ("cuda" )
@@ -108,7 +111,7 @@ def test_input_use_default_fp16_without_fp16_enabled(self):
108111 "aten::adaptive_avg_pool2d is implemented via plugins which is not supported for tensorrt_rtx" ,
109112 )
110113 @unittest .skipIf (
111- not importlib . util . find_spec ( "torchvision" ) , "torchvision not installed "
114+ not HAS_TORCHVISION , "torchvision not available "
112115 )
113116 def test_input_respect_user_setting_fp16_weights_fp32_in (self ):
114117 self .model = models .resnet18 (pretrained = True ).eval ().to ("cuda" )
@@ -130,7 +133,7 @@ def test_input_respect_user_setting_fp16_weights_fp32_in(self):
130133 "aten::adaptive_avg_pool2d is implemented via plugins which is not supported for tensorrt_rtx" ,
131134 )
132135 @unittest .skipIf (
133- not importlib . util . find_spec ( "torchvision" ) , "torchvision not installed "
136+ not HAS_TORCHVISION , "torchvision not available "
134137 )
135138 def test_input_respect_user_setting_fp16_weights_fp32_in_non_constuctor (self ):
136139 self .model = models .resnet18 (pretrained = True ).eval ().to ("cuda" )
0 commit comments