@@ -1218,7 +1218,7 @@ def tryCall(
12181218 def assert_onnx_disc (
12191219 self ,
12201220 test_name : str ,
1221- proto : "onnx.ModelProto" , # noqa: F821
1221+ proto : Union [ str , "onnx.ModelProto" ] , # noqa: F821
12221222 model : "torch.nn.Module" , # noqa: F821
12231223 inputs : Union [Tuple [Any ], Dict [str , Any ]],
12241224 verbose : int = 0 ,
@@ -1264,7 +1264,9 @@ def assert_onnx_disc(
12641264 name = f"{ test_name } .onnx"
12651265 if verbose :
12661266 print (f"[{ vname } ] save the onnx model into { name !r} " )
1267+ model_file = None
12671268 if isinstance (proto , str ):
1269+ model_file = proto
12681270 name = proto
12691271 proto = onnx .load (name )
12701272 elif not self .unit_test_going ():
@@ -1287,11 +1289,15 @@ def assert_onnx_disc(
12871289 options = onnxruntime .SessionOptions ()
12881290 if ort_optimized_graph :
12891291 options .optimized_model_filepath = f"{ name } .optort.onnx"
1292+ if "log_severity_level" in kwargs :
1293+ options .log_severity_level = kwargs ["log_severity_level" ]
1294+ if "log_verbosity_level" in kwargs :
1295+ options .log_verbosity_level = kwargs ["log_verbosity_level" ]
12901296 providers = kwargs .get ("providers" , ["CPUExecutionProvider" ])
12911297 if verbose :
12921298 print (f"[{ vname } ] create onnxruntime.InferenceSession with { providers } " )
12931299 sess = onnxruntime .InferenceSession (
1294- proto .SerializeToString (), options , providers = providers
1300+ model_file or proto .SerializeToString (), options , providers = providers
12951301 )
12961302 if verbose :
12971303 print (f"[{ vname } ] run ort feeds { string_type (feeds , ** kws )} " )
0 commit comments