@@ -1218,9 +1218,9 @@ def tryCall(
12181218 def assert_onnx_disc (
12191219 self ,
12201220 test_name : str ,
1221- proto : "onnx.ModelProto" , # noqa: F821
1221+ proto : Union [ str , "onnx.ModelProto" ] , # noqa: F821
12221222 model : "torch.nn.Module" , # noqa: F821
1223- inputs : Union [Tuple [Any ], Dict [str , Any ]],
1223+ inputs : Union [Tuple [Any ], Dict [str , Any ], List [ Any ] ],
12241224 verbose : int = 0 ,
12251225 atol : float = 1e-5 ,
12261226 rtol : float = 1e-3 ,
@@ -1264,7 +1264,9 @@ def assert_onnx_disc(
12641264 name = f"{ test_name } .onnx"
12651265 if verbose :
12661266 print (f"[{ vname } ] save the onnx model into { name !r} " )
1267+ model_file = None
12671268 if isinstance (proto , str ):
1269+ model_file = proto
12681270 name = proto
12691271 proto = onnx .load (name )
12701272 elif not self .unit_test_going ():
@@ -1277,45 +1279,64 @@ def assert_onnx_disc(
12771279 if verbose :
12781280 print (f"[{ vname } ] make feeds { string_type (inputs , ** kws )} " )
12791281
1282+ if not isinstance (inputs , list ):
1283+ inputs = [inputs ]
1284+ if expected is not None :
1285+ expected = [expected ]
1286+
1287+ gots = []
12801288 if use_ort :
12811289 assert isinstance (
12821290 proto , onnx .ModelProto
12831291 ), f"Unexpected type { type (proto )} for proto"
1284- feeds = make_feeds (proto , inputs , use_numpy = True , copy = True )
12851292 import onnxruntime
12861293
12871294 options = onnxruntime .SessionOptions ()
12881295 if ort_optimized_graph :
12891296 options .optimized_model_filepath = f"{ name } .optort.onnx"
1297+ if "log_severity_level" in kwargs :
1298+ options .log_severity_level = kwargs ["log_severity_level" ]
1299+ if "log_verbosity_level" in kwargs :
1300+ options .log_verbosity_level = kwargs ["log_verbosity_level" ]
12901301 providers = kwargs .get ("providers" , ["CPUExecutionProvider" ])
12911302 if verbose :
12921303 print (f"[{ vname } ] create onnxruntime.InferenceSession with { providers } " )
12931304 sess = onnxruntime .InferenceSession (
1294- proto .SerializeToString (), options , providers = providers
1305+ model_file or proto .SerializeToString (), options , providers = providers
12951306 )
1296- if verbose :
1297- print (f"[{ vname } ] run ort feeds { string_type (feeds , ** kws )} " )
1298- got = sess .run (None , feeds )
1307+ for inp in inputs :
1308+ feeds = make_feeds (proto , inp , use_numpy = True , copy = True )
1309+ if verbose :
1310+ print (f"[{ vname } ] run ort feeds { string_type (feeds , ** kws )} " )
1311+ got = sess .run (None , feeds )
1312+ gots .append (got )
12991313 else :
1300- feeds = make_feeds (proto , inputs , copy = True )
13011314 if verbose :
13021315 print (f"[{ vname } ] create InferenceSessionForTorch" )
13031316 sess = InferenceSessionForTorch (proto , ** kwargs )
1304- if verbose :
1305- print (f"[{ vname } ] run orttorch feeds { string_type (feeds , ** kws )} " )
1306- got = sess .run (None , feeds )
1317+ for inp in inputs :
1318+ feeds = make_feeds (proto , inp , copy = True )
1319+ if verbose :
1320+ print (f"[{ vname } ] run orttorch feeds { string_type (feeds , ** kws )} " )
1321+ got = sess .run (None , feeds )
1322+ gots .append (got )
13071323 if verbose :
13081324 print (f"[{ vname } ] compute expected values" )
13091325
13101326 if expected is None :
13111327 if copy_inputs :
1312- expected = (
1313- model (* copy .deepcopy (inputs ))
1314- if isinstance (inputs , tuple )
1315- else model (** copy .deepcopy (inputs ))
1316- )
1328+ expected = [
1329+ (
1330+ model (* copy .deepcopy (inp ))
1331+ if isinstance (inp , tuple )
1332+ else model (** copy .deepcopy (inp ))
1333+ )
1334+ for inp in inputs
1335+ ]
13171336 else :
1318- expected = model (* inputs ) if isinstance (inputs , tuple ) else model (** inputs )
1337+ expected = [
1338+ model (* inp ) if isinstance (inp , tuple ) else model (** inp ) for inp in inputs
1339+ ]
13191340
13201341 if verbose :
13211342 print (f"[{ vname } ] expected { string_type (expected , ** kws )} " )
@@ -1328,47 +1349,50 @@ def assert_onnx_disc(
13281349 import torch
13291350
13301351 ep = torch .export .load (ep )
1331- ep_inputs = copy . deepcopy ( inputs ) if copy_inputs else inputs
1352+
13321353 ep_model = ep .module () # type: ignore[union-attr]
1333- ep_expected = (
1334- ep_model (* copy .deepcopy (ep_inputs ))
1335- if isinstance (ep_inputs , tuple )
1336- else ep_model (** copy .deepcopy (ep_inputs ))
1337- )
1338- if verbose :
1339- print (f"[{ vname } ] ep_expected { string_type (ep_expected , ** kws )} " )
1340- ep_diff = max_diff (expected , ep_expected , hist = [0.1 , 0.01 ])
1354+ for expe , inp , got in zip (expected , inputs , gots ):
1355+ ep_inputs = copy .deepcopy (inp ) if copy_inputs else inp
1356+ ep_expected = (
1357+ ep_model (* copy .deepcopy (ep_inputs ))
1358+ if isinstance (ep_inputs , tuple )
1359+ else ep_model (** copy .deepcopy (ep_inputs ))
1360+ )
1361+ if verbose :
1362+ print (f"[{ vname } ] ep_expected { string_type (ep_expected , ** kws )} " )
1363+ ep_diff = max_diff (expe , ep_expected , hist = [0.1 , 0.01 ])
1364+ if verbose :
1365+ print (f"[{ vname } ] ep_diff { string_diff (ep_diff )} " )
1366+ assert (
1367+ isinstance (ep_diff ["abs" ], float )
1368+ and isinstance (ep_diff ["rel" ], float )
1369+ and not numpy .isnan (ep_diff ["abs" ])
1370+ and ep_diff ["abs" ] <= atol
1371+ and not numpy .isnan (ep_diff ["rel" ])
1372+ and ep_diff ["rel" ] <= rtol
1373+ ), (
1374+ f"discrepancies in { test_name !r} between the exported program "
1375+ f"and the exported model diff={ string_diff (ep_diff )} "
1376+ )
1377+ ep_nx_diff = max_diff (ep_expected , got , flatten = True , hist = [0.1 , 0.01 ])
1378+ if verbose :
1379+ print (f"[{ vname } ] ep_nx_diff { string_diff (ep_nx_diff )} " )
1380+
1381+ for expe , got in zip (expected , gots ):
1382+ diff = max_diff (expe , got , flatten = True , hist = [0.1 , 0.01 ])
13411383 if verbose :
1342- print (f"[{ vname } ] ep_diff { string_diff (ep_diff )} " )
1384+ print (f"[{ vname } ] diff { string_diff (diff )} " )
13431385 assert (
1344- isinstance (ep_diff ["abs" ], float )
1345- and isinstance (ep_diff ["rel" ], float )
1346- and not numpy .isnan (ep_diff ["abs" ])
1347- and ep_diff ["abs" ] <= atol
1348- and not numpy .isnan (ep_diff ["rel" ])
1349- and ep_diff ["rel" ] <= rtol
1386+ isinstance (diff ["abs" ], float )
1387+ and isinstance (diff ["rel" ], float )
1388+ and not numpy .isnan (diff ["abs" ])
1389+ and diff ["abs" ] <= atol
1390+ and not numpy .isnan (diff ["rel" ])
1391+ and diff ["rel" ] <= rtol
13501392 ), (
1351- f"discrepancies in { test_name !r} between the exported program "
1352- f"and the exported model diff={ string_diff (ep_diff )} "
1393+ f"discrepancies in { test_name !r} between the model and "
1394+ f"the onnx model diff={ string_diff (diff )} "
13531395 )
1354- ep_nx_diff = max_diff (ep_expected , got , flatten = True , hist = [0.1 , 0.01 ])
1355- if verbose :
1356- print (f"[{ vname } ] ep_nx_diff { string_diff (ep_nx_diff )} " )
1357-
1358- diff = max_diff (expected , got , flatten = True , hist = [0.1 , 0.01 ])
1359- if verbose :
1360- print (f"[{ vname } ] diff { string_diff (diff )} " )
1361- assert (
1362- isinstance (diff ["abs" ], float )
1363- and isinstance (diff ["rel" ], float )
1364- and not numpy .isnan (diff ["abs" ])
1365- and diff ["abs" ] <= atol
1366- and not numpy .isnan (diff ["rel" ])
1367- and diff ["rel" ] <= rtol
1368- ), (
1369- f"discrepancies in { test_name !r} between the model and "
1370- f"the onnx model diff={ string_diff (diff )} "
1371- )
13721396
13731397 def _debug (self ):
13741398 "Tells if DEBUG=1 is set up."
0 commit comments