@@ -77,7 +77,7 @@ def read_config(fname):
7777 return json .load (f )
7878
7979
80- def create_models (model_type , model_path , download_dir , force_onnx_adapter = False ):
80+ def create_models (model_type , model_path , download_dir , force_onnx_adapter = False , device = "CPU" ):
8181 if model_path .endswith (".onnx" ) and force_onnx_adapter :
8282 wrapper_type = model_type .get_model_class (
8383 load_parameters_from_onnx (onnx .load (model_path ))["model_info" ]["model_type" ],
@@ -92,15 +92,15 @@ def create_models(model_type, model_path, download_dir, force_onnx_adapter=False
9292 return [model ]
9393
9494 models = [
95- model_type .create_model (model_path , device = "CPU" , download_dir = download_dir ),
95+ model_type .create_model (model_path , device = device , download_dir = download_dir ),
9696 ]
9797 if model_path .endswith (".xml" ):
9898 model = create_core ().read_model (model_path )
9999 if model .has_rt_info (["model_info" , "model_type" ]):
100100 wrapper_type = model_type .get_model_class (
101101 create_core ().read_model (model_path ).get_rt_info (["model_info" , "model_type" ]).astype (str ),
102102 )
103- model = wrapper_type (OpenvinoAdapter (create_core (), model_path , device = "CPU" ))
103+ model = wrapper_type (OpenvinoAdapter (create_core (), model_path , device = device ))
104104 model .load ()
105105 models .append (model )
106106 return models
@@ -111,6 +111,11 @@ def data(pytestconfig):
111111 return pytestconfig .getoption ("data" )
112112
113113
114+ @pytest .fixture (scope = "session" )
115+ def device (pytestconfig ):
116+ return pytestconfig .getoption ("device" )
117+
118+
114119@pytest .fixture (scope = "session" )
115120def dump (pytestconfig ):
116121 return pytestconfig .getoption ("dump" )
@@ -125,7 +130,7 @@ def result(pytestconfig):
125130 ("model_data" ),
126131 read_config (Path (__file__ ).resolve ().parent / "public_scope.json" ),
127132)
128- def test_image_models (data , dump , result , model_data ): # noqa: C901
133+ def test_image_models (data , device , dump , result , model_data ): # noqa: C901
129134 name = model_data ["name" ]
130135 if name .endswith ((".xml" , ".onnx" )):
131136 name = f"{ data } /{ name } "
@@ -135,13 +140,14 @@ def test_image_models(data, dump, result, model_data): # noqa: C901
135140 name ,
136141 data ,
137142 model_data .get ("force_ort" , False ),
143+ device = device ,
138144 ):
139145 if "tiler" in model_data :
140146 if "extra_model" in model_data :
141147 extra_adapter = OpenvinoAdapter (
142148 create_core (),
143149 f"{ data } /{ model_data ['extra_model' ]} " ,
144- device = "CPU" ,
150+ device = device ,
145151 )
146152
147153 extra_model = MODEL_TYPE_MAPPING [model_data ["extra_type" ]](
@@ -160,7 +166,7 @@ def test_image_models(data, dump, result, model_data): # noqa: C901
160166 encoder_adapter = OpenvinoAdapter (
161167 create_core (),
162168 f"{ data } /{ model_data ['encoder' ]} " ,
163- device = "CPU" ,
169+ device = device ,
164170 )
165171
166172 encoder_model = MODEL_TYPE_MAPPING [model_data ["encoder_type" ]](
0 commit comments