@@ -1220,7 +1220,7 @@ def assert_onnx_disc(
12201220 test_name : str ,
12211221 proto : Union [str , "onnx.ModelProto" ], # noqa: F821
12221222 model : "torch.nn.Module" , # noqa: F821
1223- inputs : Union [Tuple [Any ], Dict [str , Any ]],
1223+ inputs : Union [Tuple [Any ], Dict [str , Any ], List [ Any ] ],
12241224 verbose : int = 0 ,
12251225 atol : float = 1e-5 ,
12261226 rtol : float = 1e-3 ,
@@ -1279,11 +1279,16 @@ def assert_onnx_disc(
12791279 if verbose :
12801280 print (f"[{ vname } ] make feeds { string_type (inputs , ** kws )} " )
12811281
1282+ if not isinstance (inputs , list ):
1283+ inputs = [inputs ]
1284+ if expected is not None :
1285+ expected = [expected ]
1286+
1287+ gots = []
12821288 if use_ort :
12831289 assert isinstance (
12841290 proto , onnx .ModelProto
12851291 ), f"Unexpected type { type (proto )} for proto"
1286- feeds = make_feeds (proto , inputs , use_numpy = True , copy = True )
12871292 import onnxruntime
12881293
12891294 options = onnxruntime .SessionOptions ()
@@ -1299,29 +1304,40 @@ def assert_onnx_disc(
12991304 sess = onnxruntime .InferenceSession (
13001305 model_file or proto .SerializeToString (), options , providers = providers
13011306 )
1302- if verbose :
1303- print (f"[{ vname } ] run ort feeds { string_type (feeds , ** kws )} " )
1304- got = sess .run (None , feeds )
1307+ for inp in inputs :
1308+ feeds = make_feeds (proto , inp , use_numpy = True , copy = True )
1309+ if verbose :
1310+ print (f"[{ vname } ] run ort feeds { string_type (feeds , ** kws )} " )
1311+ got = sess .run (None , feeds )
1312+ gots .append (got )
13051313 else :
1306- feeds = make_feeds (proto , inputs , copy = True )
13071314 if verbose :
13081315 print (f"[{ vname } ] create InferenceSessionForTorch" )
13091316 sess = InferenceSessionForTorch (proto , ** kwargs )
1310- if verbose :
1311- print (f"[{ vname } ] run orttorch feeds { string_type (feeds , ** kws )} " )
1312- got = sess .run (None , feeds )
1317+ for inp in inputs :
1318+ feeds = make_feeds (proto , inp , copy = True )
1319+ if verbose :
1320+ print (f"[{ vname } ] run orttorch feeds { string_type (feeds , ** kws )} " )
1321+ got = sess .run (None , feeds )
1322+ gots .append (got )
13131323 if verbose :
13141324 print (f"[{ vname } ] compute expected values" )
13151325
13161326 if expected is None :
13171327 if copy_inputs :
1318- expected = (
1319- model (* copy .deepcopy (inputs ))
1320- if isinstance (inputs , tuple )
1321- else model (** copy .deepcopy (inputs ))
1322- )
1328+ expected = [
1329+ (
1330+ model (* copy .deepcopy (inputs ))
1331+ if isinstance (inputs , tuple )
1332+ else model (** copy .deepcopy (inp ))
1333+ )
1334+ for inp in inputs
1335+ ]
13231336 else :
1324- expected = model (* inputs ) if isinstance (inputs , tuple ) else model (** inputs )
1337+ expected = [
1338+ model (* inp ) if isinstance (inputs , tuple ) else model (** inp )
1339+ for inp in inputs
1340+ ]
13251341
13261342 if verbose :
13271343 print (f"[{ vname } ] expected { string_type (expected , ** kws )} " )
@@ -1334,47 +1350,50 @@ def assert_onnx_disc(
13341350 import torch
13351351
13361352 ep = torch .export .load (ep )
1337- ep_inputs = copy . deepcopy ( inputs ) if copy_inputs else inputs
1353+
13381354 ep_model = ep .module () # type: ignore[union-attr]
1339- ep_expected = (
1340- ep_model (* copy .deepcopy (ep_inputs ))
1341- if isinstance (ep_inputs , tuple )
1342- else ep_model (** copy .deepcopy (ep_inputs ))
1343- )
1344- if verbose :
1345- print (f"[{ vname } ] ep_expected { string_type (ep_expected , ** kws )} " )
1346- ep_diff = max_diff (expected , ep_expected , hist = [0.1 , 0.01 ])
1355+ for expe , inp , got in zip (expected , inputs , gots ):
1356+ ep_inputs = copy .deepcopy (inp ) if copy_inputs else inp
1357+ ep_expected = (
1358+ ep_model (* copy .deepcopy (ep_inputs ))
1359+ if isinstance (ep_inputs , tuple )
1360+ else ep_model (** copy .deepcopy (ep_inputs ))
1361+ )
1362+ if verbose :
1363+ print (f"[{ vname } ] ep_expected { string_type (ep_expected , ** kws )} " )
1364+ ep_diff = max_diff (expe , ep_expected , hist = [0.1 , 0.01 ])
1365+ if verbose :
1366+ print (f"[{ vname } ] ep_diff { string_diff (ep_diff )} " )
1367+ assert (
1368+ isinstance (ep_diff ["abs" ], float )
1369+ and isinstance (ep_diff ["rel" ], float )
1370+ and not numpy .isnan (ep_diff ["abs" ])
1371+ and ep_diff ["abs" ] <= atol
1372+ and not numpy .isnan (ep_diff ["rel" ])
1373+ and ep_diff ["rel" ] <= rtol
1374+ ), (
1375+ f"discrepancies in { test_name !r} between the exported program "
1376+ f"and the exported model diff={ string_diff (ep_diff )} "
1377+ )
1378+ ep_nx_diff = max_diff (ep_expected , got , flatten = True , hist = [0.1 , 0.01 ])
1379+ if verbose :
1380+ print (f"[{ vname } ] ep_nx_diff { string_diff (ep_nx_diff )} " )
1381+
1382+ for expe , got in zip (expected , gots ):
1383+ diff = max_diff (expe , got , flatten = True , hist = [0.1 , 0.01 ])
13471384 if verbose :
1348- print (f"[{ vname } ] ep_diff { string_diff (ep_diff )} " )
1385+ print (f"[{ vname } ] diff { string_diff (diff )} " )
13491386 assert (
1350- isinstance (ep_diff ["abs" ], float )
1351- and isinstance (ep_diff ["rel" ], float )
1352- and not numpy .isnan (ep_diff ["abs" ])
1353- and ep_diff ["abs" ] <= atol
1354- and not numpy .isnan (ep_diff ["rel" ])
1355- and ep_diff ["rel" ] <= rtol
1387+ isinstance (diff ["abs" ], float )
1388+ and isinstance (diff ["rel" ], float )
1389+ and not numpy .isnan (diff ["abs" ])
1390+ and diff ["abs" ] <= atol
1391+ and not numpy .isnan (diff ["rel" ])
1392+ and diff ["rel" ] <= rtol
13561393 ), (
1357- f"discrepancies in { test_name !r} between the exported program "
1358- f"and the exported model diff={ string_diff (ep_diff )} "
1394+ f"discrepancies in { test_name !r} between the model and "
1395+ f"the onnx model diff={ string_diff (diff )} "
13591396 )
1360- ep_nx_diff = max_diff (ep_expected , got , flatten = True , hist = [0.1 , 0.01 ])
1361- if verbose :
1362- print (f"[{ vname } ] ep_nx_diff { string_diff (ep_nx_diff )} " )
1363-
1364- diff = max_diff (expected , got , flatten = True , hist = [0.1 , 0.01 ])
1365- if verbose :
1366- print (f"[{ vname } ] diff { string_diff (diff )} " )
1367- assert (
1368- isinstance (diff ["abs" ], float )
1369- and isinstance (diff ["rel" ], float )
1370- and not numpy .isnan (diff ["abs" ])
1371- and diff ["abs" ] <= atol
1372- and not numpy .isnan (diff ["rel" ])
1373- and diff ["rel" ] <= rtol
1374- ), (
1375- f"discrepancies in { test_name !r} between the model and "
1376- f"the onnx model diff={ string_diff (diff )} "
1377- )
13781397
13791398 def _debug (self ):
13801399 "Tells if DEBUG=1 is set up."
0 commit comments