diff --git a/_unittests/ut_export/test_api.py b/_unittests/ut_export/test_api.py index c1fa8b09..b483696d 100644 --- a/_unittests/ut_export/test_api.py +++ b/_unittests/ut_export/test_api.py @@ -6,6 +6,7 @@ hide_stdout, has_transformers, ignore_warnings, + requires_transformers, ) from onnx_diagnostic.helpers import max_diff from onnx_diagnostic.helpers.torch_helper import torch_deepcopy @@ -43,6 +44,7 @@ def forward(self, x, y): @hide_stdout() @ignore_warnings(FutureWarning) + @requires_transformers("4.50") def test_tiny_llm_to_onnx(self): import onnxruntime diff --git a/_unittests/ut_tasks/try_export.py b/_unittests/ut_tasks/try_export.py index d248f9dc..f79a7667 100644 --- a/_unittests/ut_tasks/try_export.py +++ b/_unittests/ut_tasks/try_export.py @@ -57,6 +57,7 @@ def test_qwen25_vli_visual(self): TESTDEVICE=cuda \\ TESTDTYPE=float16 \\ EXPORTER=custom \\ + CUT_EXPORTED_PROGRAM=qwen_sdpa_attention_loopmha_16 \\ python _unittests/ut_tasks/try_export.py -k qwen25_vli_visual .. code-block:: bash @@ -125,25 +126,32 @@ def _config_reduction(config, task): print(f"-- processor={type(processor)}") print(f"-- PROCESSOR LOADED IN {time.perf_counter() - begin}") - big_inputs = dict( - hidden_states=torch.rand((14308, 1176), dtype=torch_dtype).to(device), - grid_thw=torch.tensor([[1, 98, 146]], dtype=torch.int64).to(device), - ) - print("-- save inputs") inputs = dict( hidden_states=torch.rand((1292, 1176), dtype=torch_dtype).to(device), grid_thw=torch.tensor([[1, 34, 38]], dtype=torch.int64).to(device), ) if not self.unit_test_going(): print("-- save inputs") - torch.save(big_inputs, self.get_dump_file("qwen25_vli_visual.inputs.big.pt")) torch.save(inputs, self.get_dump_file("qwen25_vli_visual.inputs.pt")) + print("-- save big inputs") + big_inputs = dict( + hidden_states=torch.rand((14308, 1176), dtype=torch_dtype).to(device), + grid_thw=torch.tensor([[1, 98, 146]], dtype=torch.int64).to(device), + ) + torch.save(big_inputs, self.get_dump_file("qwen25_vli_visual.inputs.big.pt")) + else: + big_inputs = None print(f"-- inputs: {self.string_type(inputs, with_shape=True)}") # this is too long model_to_export = model.visual if hasattr(model, "visual") else model.model.visual begin = time.perf_counter() - expected = model_to_export(**inputs) + if not os.environ.get("STOPAT", ""): + expected = model_to_export(**inputs) + expected_big = None if big_inputs is None else model_to_export(**big_inputs) + else: + expected = None + expected_big = None print(f"-- MODEL RUN IN {time.perf_counter() - begin}") print(f"-- expected: {self.string_type(expected, with_shape=True)}") @@ -184,6 +192,11 @@ def _config_reduction(config, task): verbose=1, stop_if_static=2, ): + if expected is None: + expected = model_to_export(**inputs) + expected_big = ( + None if big_inputs is None else model_to_export(**big_inputs) + ) to_onnx( model_to_export, kwargs=export_inputs, @@ -256,7 +269,7 @@ def _config_reduction(config, task): (f"test_qwen25_vli_visual.{device}.{dtype}.{attention}.{exporter}"), filename, model_to_export, - export_inputs, + [_ for _ in [export_inputs, big_inputs] if _ is not None], verbose=1, providers=( ["CUDAExecutionProvider", "CPUExecutionProvider"] @@ -267,7 +280,9 @@ def _config_reduction(config, task): atol=0.05, rtol=10, # ep=pt2_file, - expected=expected, + expected=[_ for _ in [expected, expected_big] if _ is not None], + log_severity_level=0, + log_verbosity_level=0, ) print(f"-- MODEL VERIFIED IN {time.perf_counter() - begin}") os.environ["QWEN25ATTENTION"] = qwen25_attention diff --git a/_unittests/ut_torch_export_patches/test_patch_transformers.py b/_unittests/ut_torch_export_patches/test_patch_transformers.py index 6a53d07c..8c58d796 100644 --- a/_unittests/ut_torch_export_patches/test_patch_transformers.py +++ b/_unittests/ut_torch_export_patches/test_patch_transformers.py @@ -270,7 +270,7 @@ def forward(self, q, k, cos, sin): "test_qwen_apply_multimodal_rotary_pos_emb", proto, model, - inputs, + [inputs], verbose=1, atol=1e-3, rtol=1, diff --git a/onnx_diagnostic/_command_lines_parser.py b/onnx_diagnostic/_command_lines_parser.py index ae885966..37582f0b 100644 --- a/onnx_diagnostic/_command_lines_parser.py +++ b/onnx_diagnostic/_command_lines_parser.py @@ -1286,7 +1286,13 @@ def get_parser_sbs() -> ArgumentParser: "--first", action=BooleanOptionalAction, default=False, - help="First runs the whole model.", + help="First runs the whole model (default is False).", + ) + parser.add_argument( + "--sbs", + action=BooleanOptionalAction, + default=True, + help="Runs the side-by-side (default is True).", ) parser.add_argument( "-2", @@ -1431,6 +1437,10 @@ def _size(name): print("-- done") del sess + if not args.sbs: + print("-- done") + return + print(f"-- load onnx {args.onnx!r}") begin = time.perf_counter() onx = onnx.load(args.onnx) diff --git a/onnx_diagnostic/export/onnx_plug.py b/onnx_diagnostic/export/onnx_plug.py index cfc65cb2..300e339b 100644 --- a/onnx_diagnostic/export/onnx_plug.py +++ b/onnx_diagnostic/export/onnx_plug.py @@ -231,6 +231,9 @@ def get_function_proto(self, *args) -> onnx.FunctionProto: and args[0] in self._function_proto_versioned ): return self._function_proto_versioned[args[0]] + assert any( + a is not None for a in args + ), f"Unexpected args={string_type(args, with_shape=True)}" try: key = self.version_selector(*args) # type: ignore[misc] except (ValueError, AttributeError) as e: @@ -414,7 +417,7 @@ def onnx_dynamo_converter(self) -> Callable: onnx_plug_op = onnxscript.values.Opset(domain=self.domain, version=1) def get_proto(*args): - function_proto = self.get_function_proto() + function_proto = self.get_function_proto(*args) schema = onnx_plug_op[function_proto.name] if schema is None: all_types = [ diff --git a/onnx_diagnostic/ext_test_case.py b/onnx_diagnostic/ext_test_case.py index 0da11c02..211794c9 100644 --- a/onnx_diagnostic/ext_test_case.py +++ b/onnx_diagnostic/ext_test_case.py @@ -1218,9 +1218,9 @@ def tryCall( def assert_onnx_disc( self, test_name: str, - proto: "onnx.ModelProto", # noqa: F821 + proto: Union[str, "onnx.ModelProto"], # noqa: F821 model: "torch.nn.Module", # noqa: F821 - inputs: Union[Tuple[Any], Dict[str, Any]], + inputs: Union[Tuple[Any], Dict[str, Any], List[Any]], verbose: int = 0, atol: float = 1e-5, rtol: float = 1e-3, @@ -1264,7 +1264,9 @@ def assert_onnx_disc( name = f"{test_name}.onnx" if verbose: print(f"[{vname}] save the onnx model into {name!r}") + model_file = None if isinstance(proto, str): + model_file = proto name = proto proto = onnx.load(name) elif not self.unit_test_going(): @@ -1277,45 +1279,64 @@ def assert_onnx_disc( if verbose: print(f"[{vname}] make feeds {string_type(inputs, **kws)}") + if not isinstance(inputs, list): + inputs = [inputs] + if expected is not None: + expected = [expected] + + gots = [] if use_ort: assert isinstance( proto, onnx.ModelProto ), f"Unexpected type {type(proto)} for proto" - feeds = make_feeds(proto, inputs, use_numpy=True, copy=True) import onnxruntime options = onnxruntime.SessionOptions() if ort_optimized_graph: options.optimized_model_filepath = f"{name}.optort.onnx" + if "log_severity_level" in kwargs: + options.log_severity_level = kwargs["log_severity_level"] + if "log_verbosity_level" in kwargs: + options.log_verbosity_level = kwargs["log_verbosity_level"] providers = kwargs.get("providers", ["CPUExecutionProvider"]) if verbose: print(f"[{vname}] create onnxruntime.InferenceSession with {providers}") sess = onnxruntime.InferenceSession( - proto.SerializeToString(), options, providers=providers + model_file or proto.SerializeToString(), options, providers=providers ) - if verbose: - print(f"[{vname}] run ort feeds {string_type(feeds, **kws)}") - got = sess.run(None, feeds) + for inp in inputs: + feeds = make_feeds(proto, inp, use_numpy=True, copy=True) + if verbose: + print(f"[{vname}] run ort feeds {string_type(feeds, **kws)}") + got = sess.run(None, feeds) + gots.append(got) else: - feeds = make_feeds(proto, inputs, copy=True) if verbose: print(f"[{vname}] create InferenceSessionForTorch") sess = InferenceSessionForTorch(proto, **kwargs) - if verbose: - print(f"[{vname}] run orttorch feeds {string_type(feeds, **kws)}") - got = sess.run(None, feeds) + for inp in inputs: + feeds = make_feeds(proto, inp, copy=True) + if verbose: + print(f"[{vname}] run orttorch feeds {string_type(feeds, **kws)}") + got = sess.run(None, feeds) + gots.append(got) if verbose: print(f"[{vname}] compute expected values") if expected is None: if copy_inputs: - expected = ( - model(*copy.deepcopy(inputs)) - if isinstance(inputs, tuple) - else model(**copy.deepcopy(inputs)) - ) + expected = [ + ( + model(*copy.deepcopy(inp)) + if isinstance(inp, tuple) + else model(**copy.deepcopy(inp)) + ) + for inp in inputs + ] else: - expected = model(*inputs) if isinstance(inputs, tuple) else model(**inputs) + expected = [ + model(*inp) if isinstance(inp, tuple) else model(**inp) for inp in inputs + ] if verbose: print(f"[{vname}] expected {string_type(expected, **kws)}") @@ -1328,47 +1349,50 @@ def assert_onnx_disc( import torch ep = torch.export.load(ep) - ep_inputs = copy.deepcopy(inputs) if copy_inputs else inputs + ep_model = ep.module() # type: ignore[union-attr] - ep_expected = ( - ep_model(*copy.deepcopy(ep_inputs)) - if isinstance(ep_inputs, tuple) - else ep_model(**copy.deepcopy(ep_inputs)) - ) - if verbose: - print(f"[{vname}] ep_expected {string_type(ep_expected, **kws)}") - ep_diff = max_diff(expected, ep_expected, hist=[0.1, 0.01]) + for expe, inp, got in zip(expected, inputs, gots): + ep_inputs = copy.deepcopy(inp) if copy_inputs else inp + ep_expected = ( + ep_model(*copy.deepcopy(ep_inputs)) + if isinstance(ep_inputs, tuple) + else ep_model(**copy.deepcopy(ep_inputs)) + ) + if verbose: + print(f"[{vname}] ep_expected {string_type(ep_expected, **kws)}") + ep_diff = max_diff(expe, ep_expected, hist=[0.1, 0.01]) + if verbose: + print(f"[{vname}] ep_diff {string_diff(ep_diff)}") + assert ( + isinstance(ep_diff["abs"], float) + and isinstance(ep_diff["rel"], float) + and not numpy.isnan(ep_diff["abs"]) + and ep_diff["abs"] <= atol + and not numpy.isnan(ep_diff["rel"]) + and ep_diff["rel"] <= rtol + ), ( + f"discrepancies in {test_name!r} between the exported program " + f"and the exported model diff={string_diff(ep_diff)}" + ) + ep_nx_diff = max_diff(ep_expected, got, flatten=True, hist=[0.1, 0.01]) + if verbose: + print(f"[{vname}] ep_nx_diff {string_diff(ep_nx_diff)}") + + for expe, got in zip(expected, gots): + diff = max_diff(expe, got, flatten=True, hist=[0.1, 0.01]) if verbose: - print(f"[{vname}] ep_diff {string_diff(ep_diff)}") + print(f"[{vname}] diff {string_diff(diff)}") assert ( - isinstance(ep_diff["abs"], float) - and isinstance(ep_diff["rel"], float) - and not numpy.isnan(ep_diff["abs"]) - and ep_diff["abs"] <= atol - and not numpy.isnan(ep_diff["rel"]) - and ep_diff["rel"] <= rtol + isinstance(diff["abs"], float) + and isinstance(diff["rel"], float) + and not numpy.isnan(diff["abs"]) + and diff["abs"] <= atol + and not numpy.isnan(diff["rel"]) + and diff["rel"] <= rtol ), ( - f"discrepancies in {test_name!r} between the exported program " - f"and the exported model diff={string_diff(ep_diff)}" + f"discrepancies in {test_name!r} between the model and " + f"the onnx model diff={string_diff(diff)}" ) - ep_nx_diff = max_diff(ep_expected, got, flatten=True, hist=[0.1, 0.01]) - if verbose: - print(f"[{vname}] ep_nx_diff {string_diff(ep_nx_diff)}") - - diff = max_diff(expected, got, flatten=True, hist=[0.1, 0.01]) - if verbose: - print(f"[{vname}] diff {string_diff(diff)}") - assert ( - isinstance(diff["abs"], float) - and isinstance(diff["rel"], float) - and not numpy.isnan(diff["abs"]) - and diff["abs"] <= atol - and not numpy.isnan(diff["rel"]) - and diff["rel"] <= rtol - ), ( - f"discrepancies in {test_name!r} between the model and " - f"the onnx model diff={string_diff(diff)}" - ) def _debug(self): "Tells if DEBUG=1 is set up." diff --git a/onnx_diagnostic/helpers/helper.py b/onnx_diagnostic/helpers/helper.py index b9f6c927..95182801 100644 --- a/onnx_diagnostic/helpers/helper.py +++ b/onnx_diagnostic/helpers/helper.py @@ -94,6 +94,20 @@ def size_type(dtype: Any) -> int: raise AssertionError(f"Unexpected dtype={dtype}") +def _string_tensor(obj, cls: str, with_shape: bool, with_device: bool, verbose: int) -> str: + from .torch_helper import torch_dtype_to_onnx_dtype + + i = torch_dtype_to_onnx_dtype(obj.dtype) + prefix = ("G" if obj.get_device() >= 0 else "C") if with_device else "" + if not with_shape: + if verbose: + print(f"[string_type] {cls}1:{type(obj)}") + return f"{prefix}{cls}{i}r{len(obj.shape)}" + if verbose: + print(f"[string_type] {cls}2:{type(obj)}") + return f"{prefix}{cls}{i}s{'x'.join(map(str, obj.shape))}" + + def string_type( obj: Any, with_shape: bool = False, @@ -453,17 +467,7 @@ def string_type( # Tensors if isinstance(obj, torch._subclasses.fake_tensor.FakeTensor): - from .torch_helper import torch_dtype_to_onnx_dtype - - i = torch_dtype_to_onnx_dtype(obj.dtype) - prefix = ("G" if obj.get_device() >= 0 else "C") if with_device else "" - if not with_shape: - if verbose: - print(f"[string_type] F1:{type(obj)}") - return f"{prefix}F{i}r{len(obj.shape)}" - if verbose: - print(f"[string_type] F2:{type(obj)}") - return f"{prefix}F{i}s{'x'.join(map(str, obj.shape))}" + return _string_tensor(obj, "F", with_shape, with_device, verbose) if isinstance(obj, torch.Tensor): from .torch_helper import torch_dtype_to_onnx_dtype @@ -544,6 +548,9 @@ def string_type( print(f"[string_type] V6:{type(obj)}") return f"{dev}OV{dt}r{len(shape)}" + if obj.__class__.__name__ == "SymbolicTensor": + return _string_tensor(obj, "ST", with_shape, with_device, verbose) + # others classes if obj.__class__.__name__ == "MambaCache": diff --git a/onnx_diagnostic/helpers/onnx_helper.py b/onnx_diagnostic/helpers/onnx_helper.py index 68fcf599..74dcea1b 100644 --- a/onnx_diagnostic/helpers/onnx_helper.py +++ b/onnx_diagnostic/helpers/onnx_helper.py @@ -1646,6 +1646,7 @@ def select_model_inputs_outputs( known_shapes[shape.name] = shape.type var_in = [] + existing = {i.name: i for i in model.graph.input} for name in inputs: if overwrite is not None and name in overwrite: dtype, shape = overwrite[name] @@ -1660,12 +1661,15 @@ def select_model_inputs_outputs( else: shape = get_tensor_shape(known_shapes[name]) value_info = oh.make_tensor_value_info(name, proto_dtype, shape) + elif name in existing: + value_info = existing[name] else: value_info = ValueInfoProto() value_info.name = name var_in.append(value_info) var_out = [] + existing = {i.name: i for i in model.graph.output} for name in outputs: if overwrite is not None and name in overwrite: dtype, shape = overwrite[name] @@ -1680,6 +1684,8 @@ def select_model_inputs_outputs( else: shape = get_tensor_shape(known_shapes[name]) value_info = oh.make_tensor_value_info(name, proto_dtype, shape) + elif name in existing: + value_info = existing[name] else: value_info = ValueInfoProto() value_info.name = name diff --git a/onnx_diagnostic/helpers/torch_helper.py b/onnx_diagnostic/helpers/torch_helper.py index f7c0bbfc..7b0eb566 100644 --- a/onnx_diagnostic/helpers/torch_helper.py +++ b/onnx_diagnostic/helpers/torch_helper.py @@ -139,6 +139,15 @@ def onnx_dtype_to_torch_dtype(itype: int) -> torch.dtype: ) +_TYPENAME = dict( + FLOAT=onnx.TensorProto.FLOAT, + INT64=onnx.TensorProto.INT64, + INT32=onnx.TensorProto.INT32, + FLOAT16=onnx.TensorProto.FLOAT16, + BFLOAT16=onnx.TensorProto.BFLOAT16, +) + + def torch_dtype_to_onnx_dtype(to: torch.dtype) -> int: """ Converts a torch dtype into a onnx element type. @@ -182,7 +191,13 @@ def torch_dtype_to_onnx_dtype(to: torch.dtype) -> int: return onnx.TensorProto.COMPLEX64 if to == torch.complex128: return onnx.TensorProto.COMPLEX128 - raise NotImplementedError(f"Unable to convert torch dtype {to!r} to onnx dtype.") + # SymbolicTensor + sto = str(to) + if sto in _TYPENAME: + return _TYPENAME[sto] + raise NotImplementedError( + f"Unable to convert torch dtype {to!r} ({type(to)}) to onnx dtype." + ) def _forward_( diff --git a/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py b/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py index dfa5698e..1a1c61c3 100644 --- a/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +++ b/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py @@ -26,6 +26,11 @@ op = onnxscript.opset22 op24 = onnxscript.onnx_opset.opset24 msft_op = onnxscript.values.Opset("com.microsoft", 1) + STOPAT = ( + int(os.environ.get("STOPAT", None)) + if os.environ.get("STOPAT", None) is not None + else None + ) def _add_com_microsoft_opset(function_proto: onnx.FunctionProto) -> onnx.FunctionProto: opsets = {d.domain: d.version for d in function_proto.opset_import} @@ -253,7 +258,9 @@ def qwen_sdpa_attention( n_outputs=1, kwargs=dict(scaling=0.11180339887498948, num_heads=16), name="qwen_sdpa_attention_loopmha", - version_selector=lambda *args: torch_dtype_to_onnx_dtype(args[0].dtype), + version_selector=lambda *args: torch_dtype_to_onnx_dtype( + next(a for a in args if a is not None).dtype + ), ) PLUGS.append(qwen_sdpa_attention_loopmha_versatile) @@ -274,7 +281,9 @@ def qwen_sdpa_attention( n_outputs=1, kwargs=dict(scaling=0.11180339887498948, num_heads=16), name="qwen_sdpa_attention_loopa24", - version_selector=lambda *args: torch_dtype_to_onnx_dtype(args[0].dtype), + version_selector=lambda *args: torch_dtype_to_onnx_dtype( + next(a for a in args if a is not None).dtype + ), ) PLUGS.append(qwen_sdpa_attention_loopa24_versatile) @@ -525,6 +534,8 @@ def forward( position_embeddings=position_embeddings, **kwargs, ) + if STOPAT is not None and layer_num > STOPAT: + break hidden_states = self.merger(hidden_states) reverse_indices = torch.argsort(window_index) diff --git a/onnx_diagnostic/torch_onnx/sbs.py b/onnx_diagnostic/torch_onnx/sbs.py index c6018e90..729b2416 100644 --- a/onnx_diagnostic/torch_onnx/sbs.py +++ b/onnx_diagnostic/torch_onnx/sbs.py @@ -12,7 +12,6 @@ from_numpy, to_tensor, torch_dtype_to_onnx_dtype, - torch_deepcopy, ) from ..helpers.torch_fx_graph_helper import prepare_args_kwargs, run_fx_node from ..reference.ort_evaluator import OnnxList, OnnxruntimeEvaluator @@ -194,7 +193,7 @@ def _loop_onnx_node( print(f"[run_aligned] feeds={string_type(feeds, **str_kws)}") begin = time.perf_counter() try: - res = ref.run(None, torch_deepcopy(feeds)) # type: ignore[attr-defined] + res = ref.run(None, feeds) # type: ignore[attr-defined] except Exception as e: raise RuntimeError( f"Unable to run node {node.op_type}, domain={node.domain} " @@ -247,7 +246,7 @@ def _loop_onnx_node( f"[run_aligned] feeds for second run=" f"{string_type(new_feeds, **str_kws)}" ) - cross = ref.run(None, torch_deepcopy(new_feeds)) + cross = ref.run(None, new_feeds) if verbose > 1: print(f"[run_aligned] got for second run={string_type(cross, **str_kws)}") # Gemm = torch.nn.function.linear, in that case, we just run it as well