Skip to content

Commit d3eb796

Browse files
authored
improves side-by-side with more verification (#339)
* remove torch_deepcopy * many tiny fixes * fix * last changes * fix * fix * sr * fix export * api
1 parent 93b21c6 commit d3eb796

11 files changed

Lines changed: 174 additions & 82 deletions

File tree

_unittests/ut_export/test_api.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
hide_stdout,
77
has_transformers,
88
ignore_warnings,
9+
requires_transformers,
910
)
1011
from onnx_diagnostic.helpers import max_diff
1112
from onnx_diagnostic.helpers.torch_helper import torch_deepcopy
@@ -43,6 +44,7 @@ def forward(self, x, y):
4344

4445
@hide_stdout()
4546
@ignore_warnings(FutureWarning)
47+
@requires_transformers("4.50")
4648
def test_tiny_llm_to_onnx(self):
4749
import onnxruntime
4850

_unittests/ut_tasks/try_export.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ def test_qwen25_vli_visual(self):
5757
TESTDEVICE=cuda \\
5858
TESTDTYPE=float16 \\
5959
EXPORTER=custom \\
60+
CUT_EXPORTED_PROGRAM=qwen_sdpa_attention_loopmha_16 \\
6061
python _unittests/ut_tasks/try_export.py -k qwen25_vli_visual
6162
6263
.. code-block:: bash
@@ -125,25 +126,32 @@ def _config_reduction(config, task):
125126
print(f"-- processor={type(processor)}")
126127
print(f"-- PROCESSOR LOADED IN {time.perf_counter() - begin}")
127128

128-
big_inputs = dict(
129-
hidden_states=torch.rand((14308, 1176), dtype=torch_dtype).to(device),
130-
grid_thw=torch.tensor([[1, 98, 146]], dtype=torch.int64).to(device),
131-
)
132-
print("-- save inputs")
133129
inputs = dict(
134130
hidden_states=torch.rand((1292, 1176), dtype=torch_dtype).to(device),
135131
grid_thw=torch.tensor([[1, 34, 38]], dtype=torch.int64).to(device),
136132
)
137133
if not self.unit_test_going():
138134
print("-- save inputs")
139-
torch.save(big_inputs, self.get_dump_file("qwen25_vli_visual.inputs.big.pt"))
140135
torch.save(inputs, self.get_dump_file("qwen25_vli_visual.inputs.pt"))
136+
print("-- save big inputs")
137+
big_inputs = dict(
138+
hidden_states=torch.rand((14308, 1176), dtype=torch_dtype).to(device),
139+
grid_thw=torch.tensor([[1, 98, 146]], dtype=torch.int64).to(device),
140+
)
141+
torch.save(big_inputs, self.get_dump_file("qwen25_vli_visual.inputs.big.pt"))
142+
else:
143+
big_inputs = None
141144

142145
print(f"-- inputs: {self.string_type(inputs, with_shape=True)}")
143146
# this is too long
144147
model_to_export = model.visual if hasattr(model, "visual") else model.model.visual
145148
begin = time.perf_counter()
146-
expected = model_to_export(**inputs)
149+
if not os.environ.get("STOPAT", ""):
150+
expected = model_to_export(**inputs)
151+
expected_big = None if big_inputs is None else model_to_export(**big_inputs)
152+
else:
153+
expected = None
154+
expected_big = None
147155
print(f"-- MODEL RUN IN {time.perf_counter() - begin}")
148156
print(f"-- expected: {self.string_type(expected, with_shape=True)}")
149157

@@ -184,6 +192,11 @@ def _config_reduction(config, task):
184192
verbose=1,
185193
stop_if_static=2,
186194
):
195+
if expected is None:
196+
expected = model_to_export(**inputs)
197+
expected_big = (
198+
None if big_inputs is None else model_to_export(**big_inputs)
199+
)
187200
to_onnx(
188201
model_to_export,
189202
kwargs=export_inputs,
@@ -256,7 +269,7 @@ def _config_reduction(config, task):
256269
(f"test_qwen25_vli_visual.{device}.{dtype}.{attention}.{exporter}"),
257270
filename,
258271
model_to_export,
259-
export_inputs,
272+
[_ for _ in [export_inputs, big_inputs] if _ is not None],
260273
verbose=1,
261274
providers=(
262275
["CUDAExecutionProvider", "CPUExecutionProvider"]
@@ -267,7 +280,9 @@ def _config_reduction(config, task):
267280
atol=0.05,
268281
rtol=10,
269282
# ep=pt2_file,
270-
expected=expected,
283+
expected=[_ for _ in [expected, expected_big] if _ is not None],
284+
log_severity_level=0,
285+
log_verbosity_level=0,
271286
)
272287
print(f"-- MODEL VERIFIED IN {time.perf_counter() - begin}")
273288
os.environ["QWEN25ATTENTION"] = qwen25_attention

_unittests/ut_torch_export_patches/test_patch_transformers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,7 @@ def forward(self, q, k, cos, sin):
270270
"test_qwen_apply_multimodal_rotary_pos_emb",
271271
proto,
272272
model,
273-
inputs,
273+
[inputs],
274274
verbose=1,
275275
atol=1e-3,
276276
rtol=1,

onnx_diagnostic/_command_lines_parser.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1286,7 +1286,13 @@ def get_parser_sbs() -> ArgumentParser:
12861286
"--first",
12871287
action=BooleanOptionalAction,
12881288
default=False,
1289-
help="First runs the whole model.",
1289+
help="First runs the whole model (default is False).",
1290+
)
1291+
parser.add_argument(
1292+
"--sbs",
1293+
action=BooleanOptionalAction,
1294+
default=True,
1295+
help="Runs the side-by-side (default is True).",
12901296
)
12911297
parser.add_argument(
12921298
"-2",
@@ -1431,6 +1437,10 @@ def _size(name):
14311437
print("-- done")
14321438
del sess
14331439

1440+
if not args.sbs:
1441+
print("-- done")
1442+
return
1443+
14341444
print(f"-- load onnx {args.onnx!r}")
14351445
begin = time.perf_counter()
14361446
onx = onnx.load(args.onnx)

onnx_diagnostic/export/onnx_plug.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,9 @@ def get_function_proto(self, *args) -> onnx.FunctionProto:
231231
and args[0] in self._function_proto_versioned
232232
):
233233
return self._function_proto_versioned[args[0]]
234+
assert any(
235+
a is not None for a in args
236+
), f"Unexpected args={string_type(args, with_shape=True)}"
234237
try:
235238
key = self.version_selector(*args) # type: ignore[misc]
236239
except (ValueError, AttributeError) as e:
@@ -414,7 +417,7 @@ def onnx_dynamo_converter(self) -> Callable:
414417
onnx_plug_op = onnxscript.values.Opset(domain=self.domain, version=1)
415418

416419
def get_proto(*args):
417-
function_proto = self.get_function_proto()
420+
function_proto = self.get_function_proto(*args)
418421
schema = onnx_plug_op[function_proto.name]
419422
if schema is None:
420423
all_types = [

onnx_diagnostic/ext_test_case.py

Lines changed: 77 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1218,9 +1218,9 @@ def tryCall(
12181218
def assert_onnx_disc(
12191219
self,
12201220
test_name: str,
1221-
proto: "onnx.ModelProto", # noqa: F821
1221+
proto: Union[str, "onnx.ModelProto"], # noqa: F821
12221222
model: "torch.nn.Module", # noqa: F821
1223-
inputs: Union[Tuple[Any], Dict[str, Any]],
1223+
inputs: Union[Tuple[Any], Dict[str, Any], List[Any]],
12241224
verbose: int = 0,
12251225
atol: float = 1e-5,
12261226
rtol: float = 1e-3,
@@ -1264,7 +1264,9 @@ def assert_onnx_disc(
12641264
name = f"{test_name}.onnx"
12651265
if verbose:
12661266
print(f"[{vname}] save the onnx model into {name!r}")
1267+
model_file = None
12671268
if isinstance(proto, str):
1269+
model_file = proto
12681270
name = proto
12691271
proto = onnx.load(name)
12701272
elif not self.unit_test_going():
@@ -1277,45 +1279,64 @@ def assert_onnx_disc(
12771279
if verbose:
12781280
print(f"[{vname}] make feeds {string_type(inputs, **kws)}")
12791281

1282+
if not isinstance(inputs, list):
1283+
inputs = [inputs]
1284+
if expected is not None:
1285+
expected = [expected]
1286+
1287+
gots = []
12801288
if use_ort:
12811289
assert isinstance(
12821290
proto, onnx.ModelProto
12831291
), f"Unexpected type {type(proto)} for proto"
1284-
feeds = make_feeds(proto, inputs, use_numpy=True, copy=True)
12851292
import onnxruntime
12861293

12871294
options = onnxruntime.SessionOptions()
12881295
if ort_optimized_graph:
12891296
options.optimized_model_filepath = f"{name}.optort.onnx"
1297+
if "log_severity_level" in kwargs:
1298+
options.log_severity_level = kwargs["log_severity_level"]
1299+
if "log_verbosity_level" in kwargs:
1300+
options.log_verbosity_level = kwargs["log_verbosity_level"]
12901301
providers = kwargs.get("providers", ["CPUExecutionProvider"])
12911302
if verbose:
12921303
print(f"[{vname}] create onnxruntime.InferenceSession with {providers}")
12931304
sess = onnxruntime.InferenceSession(
1294-
proto.SerializeToString(), options, providers=providers
1305+
model_file or proto.SerializeToString(), options, providers=providers
12951306
)
1296-
if verbose:
1297-
print(f"[{vname}] run ort feeds {string_type(feeds, **kws)}")
1298-
got = sess.run(None, feeds)
1307+
for inp in inputs:
1308+
feeds = make_feeds(proto, inp, use_numpy=True, copy=True)
1309+
if verbose:
1310+
print(f"[{vname}] run ort feeds {string_type(feeds, **kws)}")
1311+
got = sess.run(None, feeds)
1312+
gots.append(got)
12991313
else:
1300-
feeds = make_feeds(proto, inputs, copy=True)
13011314
if verbose:
13021315
print(f"[{vname}] create InferenceSessionForTorch")
13031316
sess = InferenceSessionForTorch(proto, **kwargs)
1304-
if verbose:
1305-
print(f"[{vname}] run orttorch feeds {string_type(feeds, **kws)}")
1306-
got = sess.run(None, feeds)
1317+
for inp in inputs:
1318+
feeds = make_feeds(proto, inp, copy=True)
1319+
if verbose:
1320+
print(f"[{vname}] run orttorch feeds {string_type(feeds, **kws)}")
1321+
got = sess.run(None, feeds)
1322+
gots.append(got)
13071323
if verbose:
13081324
print(f"[{vname}] compute expected values")
13091325

13101326
if expected is None:
13111327
if copy_inputs:
1312-
expected = (
1313-
model(*copy.deepcopy(inputs))
1314-
if isinstance(inputs, tuple)
1315-
else model(**copy.deepcopy(inputs))
1316-
)
1328+
expected = [
1329+
(
1330+
model(*copy.deepcopy(inp))
1331+
if isinstance(inp, tuple)
1332+
else model(**copy.deepcopy(inp))
1333+
)
1334+
for inp in inputs
1335+
]
13171336
else:
1318-
expected = model(*inputs) if isinstance(inputs, tuple) else model(**inputs)
1337+
expected = [
1338+
model(*inp) if isinstance(inp, tuple) else model(**inp) for inp in inputs
1339+
]
13191340

13201341
if verbose:
13211342
print(f"[{vname}] expected {string_type(expected, **kws)}")
@@ -1328,47 +1349,50 @@ def assert_onnx_disc(
13281349
import torch
13291350

13301351
ep = torch.export.load(ep)
1331-
ep_inputs = copy.deepcopy(inputs) if copy_inputs else inputs
1352+
13321353
ep_model = ep.module() # type: ignore[union-attr]
1333-
ep_expected = (
1334-
ep_model(*copy.deepcopy(ep_inputs))
1335-
if isinstance(ep_inputs, tuple)
1336-
else ep_model(**copy.deepcopy(ep_inputs))
1337-
)
1338-
if verbose:
1339-
print(f"[{vname}] ep_expected {string_type(ep_expected, **kws)}")
1340-
ep_diff = max_diff(expected, ep_expected, hist=[0.1, 0.01])
1354+
for expe, inp, got in zip(expected, inputs, gots):
1355+
ep_inputs = copy.deepcopy(inp) if copy_inputs else inp
1356+
ep_expected = (
1357+
ep_model(*copy.deepcopy(ep_inputs))
1358+
if isinstance(ep_inputs, tuple)
1359+
else ep_model(**copy.deepcopy(ep_inputs))
1360+
)
1361+
if verbose:
1362+
print(f"[{vname}] ep_expected {string_type(ep_expected, **kws)}")
1363+
ep_diff = max_diff(expe, ep_expected, hist=[0.1, 0.01])
1364+
if verbose:
1365+
print(f"[{vname}] ep_diff {string_diff(ep_diff)}")
1366+
assert (
1367+
isinstance(ep_diff["abs"], float)
1368+
and isinstance(ep_diff["rel"], float)
1369+
and not numpy.isnan(ep_diff["abs"])
1370+
and ep_diff["abs"] <= atol
1371+
and not numpy.isnan(ep_diff["rel"])
1372+
and ep_diff["rel"] <= rtol
1373+
), (
1374+
f"discrepancies in {test_name!r} between the exported program "
1375+
f"and the exported model diff={string_diff(ep_diff)}"
1376+
)
1377+
ep_nx_diff = max_diff(ep_expected, got, flatten=True, hist=[0.1, 0.01])
1378+
if verbose:
1379+
print(f"[{vname}] ep_nx_diff {string_diff(ep_nx_diff)}")
1380+
1381+
for expe, got in zip(expected, gots):
1382+
diff = max_diff(expe, got, flatten=True, hist=[0.1, 0.01])
13411383
if verbose:
1342-
print(f"[{vname}] ep_diff {string_diff(ep_diff)}")
1384+
print(f"[{vname}] diff {string_diff(diff)}")
13431385
assert (
1344-
isinstance(ep_diff["abs"], float)
1345-
and isinstance(ep_diff["rel"], float)
1346-
and not numpy.isnan(ep_diff["abs"])
1347-
and ep_diff["abs"] <= atol
1348-
and not numpy.isnan(ep_diff["rel"])
1349-
and ep_diff["rel"] <= rtol
1386+
isinstance(diff["abs"], float)
1387+
and isinstance(diff["rel"], float)
1388+
and not numpy.isnan(diff["abs"])
1389+
and diff["abs"] <= atol
1390+
and not numpy.isnan(diff["rel"])
1391+
and diff["rel"] <= rtol
13501392
), (
1351-
f"discrepancies in {test_name!r} between the exported program "
1352-
f"and the exported model diff={string_diff(ep_diff)}"
1393+
f"discrepancies in {test_name!r} between the model and "
1394+
f"the onnx model diff={string_diff(diff)}"
13531395
)
1354-
ep_nx_diff = max_diff(ep_expected, got, flatten=True, hist=[0.1, 0.01])
1355-
if verbose:
1356-
print(f"[{vname}] ep_nx_diff {string_diff(ep_nx_diff)}")
1357-
1358-
diff = max_diff(expected, got, flatten=True, hist=[0.1, 0.01])
1359-
if verbose:
1360-
print(f"[{vname}] diff {string_diff(diff)}")
1361-
assert (
1362-
isinstance(diff["abs"], float)
1363-
and isinstance(diff["rel"], float)
1364-
and not numpy.isnan(diff["abs"])
1365-
and diff["abs"] <= atol
1366-
and not numpy.isnan(diff["rel"])
1367-
and diff["rel"] <= rtol
1368-
), (
1369-
f"discrepancies in {test_name!r} between the model and "
1370-
f"the onnx model diff={string_diff(diff)}"
1371-
)
13721396

13731397
def _debug(self):
13741398
"Tells if DEBUG=1 is set up."

onnx_diagnostic/helpers/helper.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,20 @@ def size_type(dtype: Any) -> int:
9494
raise AssertionError(f"Unexpected dtype={dtype}")
9595

9696

97+
def _string_tensor(obj, cls: str, with_shape: bool, with_device: bool, verbose: int) -> str:
98+
from .torch_helper import torch_dtype_to_onnx_dtype
99+
100+
i = torch_dtype_to_onnx_dtype(obj.dtype)
101+
prefix = ("G" if obj.get_device() >= 0 else "C") if with_device else ""
102+
if not with_shape:
103+
if verbose:
104+
print(f"[string_type] {cls}1:{type(obj)}")
105+
return f"{prefix}{cls}{i}r{len(obj.shape)}"
106+
if verbose:
107+
print(f"[string_type] {cls}2:{type(obj)}")
108+
return f"{prefix}{cls}{i}s{'x'.join(map(str, obj.shape))}"
109+
110+
97111
def string_type(
98112
obj: Any,
99113
with_shape: bool = False,
@@ -453,17 +467,7 @@ def string_type(
453467

454468
# Tensors
455469
if isinstance(obj, torch._subclasses.fake_tensor.FakeTensor):
456-
from .torch_helper import torch_dtype_to_onnx_dtype
457-
458-
i = torch_dtype_to_onnx_dtype(obj.dtype)
459-
prefix = ("G" if obj.get_device() >= 0 else "C") if with_device else ""
460-
if not with_shape:
461-
if verbose:
462-
print(f"[string_type] F1:{type(obj)}")
463-
return f"{prefix}F{i}r{len(obj.shape)}"
464-
if verbose:
465-
print(f"[string_type] F2:{type(obj)}")
466-
return f"{prefix}F{i}s{'x'.join(map(str, obj.shape))}"
470+
return _string_tensor(obj, "F", with_shape, with_device, verbose)
467471

468472
if isinstance(obj, torch.Tensor):
469473
from .torch_helper import torch_dtype_to_onnx_dtype
@@ -544,6 +548,9 @@ def string_type(
544548
print(f"[string_type] V6:{type(obj)}")
545549
return f"{dev}OV{dt}r{len(shape)}"
546550

551+
if obj.__class__.__name__ == "SymbolicTensor":
552+
return _string_tensor(obj, "ST", with_shape, with_device, verbose)
553+
547554
# others classes
548555

549556
if obj.__class__.__name__ == "MambaCache":

0 commit comments

Comments
 (0)