44import onnx
55import torch
66from ..helpers import max_diff , string_type
7- from ..helpers .torch_helper import torch_dtype_to_onnx_dtype
7+ from ..helpers .torch_helper import (
8+ torch_dtype_to_onnx_dtype ,
9+ onnx_dtype_to_torch_dtype ,
10+ int_device_to_torch_device ,
11+ )
812from ..reference import OnnxruntimeEvaluator
913
1014TUPLE_TENSORS = Tuple [torch .Tensor , ...]
@@ -50,7 +54,10 @@ class EagerDirectReplacementWithOnnx:
5054 only tensors must be counted
5155 :param name: the name of the custom op, the function name if not specified
5256 :param kwargs: constants parameters with their default values
53- :param version_selector: selects the version based on the arguments
57+ :param version_selector: selects the version based on the opset and the
58+ arguments (see below for an example); this allows the user to define
59+ different onnx versions depending on the inputs
60+ :param default_opset: opset to use by default
5461 :param verbose: verbose level
5562
5663 Here is an example:
@@ -134,6 +141,60 @@ def forward(self, x):
134141 ).model_proto
135142
136143 print(pretty_onnx(onx))
144+
145+ This shows how to define multiple versions depending on the device,
146+ the type or the targeted onnx opset.
147+
148+ .. code-block:: python
149+
150+ def qwen_version_selector(opset: int, *args: torch.Tensor) -> Tuple[str, torch.dtype]:
151+ first_tensor = next(a for a in args if a is not None)
152+ dtype = first_tensor.dtype
153+ strategy = patched_Qwen2_5_VLVisionAttention.STRATEGY_FOR_ATTENTION()
154+ if strategy is not None:
155+ return strategy, dtype
156+ if dtype == torch.float32:
157+ if opset >= 24:
158+ return "LOOPA24", dtype
159+ return "LOOPMHA", dtype
160+ if dtype == torch.float16:
161+ if first_tensor.is_cuda:
162+ return "PACKED", dtype
163+ return "LOOPMHA", dtype
164+ raise AssertionError(
165+ f"Unable to handle type {dtype} on "
166+ f"device {first_tensor.device} with opset={opset}"
167+ )
168+
169+ qwen_sdpa_attention_versatile = EagerDirectReplacementWithOnnx(
170+ qwen_sdpa_attention,
171+ lambda qs, *args, **kwargs: torch.empty(
172+ (qs.shape[0], qs.shape[2], qs.shape[1], qs.shape[3]),
173+ dtype=qs.dtype,
174+ device=qs.device,
175+ ),
176+ {
177+ ("PACKED", onnx.TensorProto.FLOAT16): _add_com_microsoft_opset(
178+ PackedAttention.to_function_proto()
179+ ),
180+ ("LOOPA24", onnx.TensorProto.FLOAT): LoopAttention24.to_function_proto(),
181+ ("LOOPA24", onnx.TensorProto.FLOAT16): _update_sequence_type(
182+ onnx.TensorProto.FLOAT16, LoopAttention24.to_function_proto()
183+ ),
184+ ("LOOPMHA", onnx.TensorProto.FLOAT): _add_com_microsoft_opset(
185+ LoopMHAAttention.to_function_proto()
186+ ),
187+ ("LOOPMHA", onnx.TensorProto.FLOAT16): _update_sequence_type(
188+ onnx.TensorProto.FLOAT16,
189+ _add_com_microsoft_opset(LoopMHAAttention.to_function_proto()),
190+ ),
191+ },
192+ n_inputs=4,
193+ n_outputs=1,
194+ kwargs=dict(scaling=0.11180339887498948, num_heads=16),
195+ name="qwen_sdpa_attention_versatile",
196+ version_selector=qwen_version_selector,
197+ )
137198 """
138199
139200 def __init__ (
@@ -146,7 +207,8 @@ def __init__(
146207 name : Optional [str ] = None ,
147208 kwargs : Optional [Dict [str , Union [int , float ]]] = None ,
148209 verbose : int = 0 ,
149- version_selector : Optional [Callable [[Any ], Any ]] = None ,
210+ version_selector : Optional [Callable [..., Tuple [Any , ...]]] = None ,
211+ default_opset : int = 22 ,
150212 ):
151213 assert isinstance (function_proto , onnx .FunctionProto ) or (
152214 isinstance (function_proto , dict )
@@ -183,6 +245,7 @@ def __init__(
183245 self .verbose = verbose
184246 self .custom_op = self ._register ()
185247 self .version_selector = version_selector
248+ self .default_opset = default_opset
186249 self ._check_protos (params )
187250
188251 def _check_protos (self , params ):
@@ -221,21 +284,18 @@ def _check_protos(self, params):
221284 not self ._function_proto_versioned or self .version_selector
222285 ), "version_selector is needed when multiple protos are given."
223286
224- def get_function_proto (self , * args ) -> onnx .FunctionProto :
287+ def get_function_proto (self , opset : int , * args ) -> onnx .FunctionProto :
225288 """Returns the correct version based on the inputs."""
226289 if self ._function_proto :
227290 return self ._function_proto
228- if (
229- len (args ) == 1
230- and isinstance (args [0 ], (int , str ))
231- and args [0 ] in self ._function_proto_versioned
232- ):
233- return self ._function_proto_versioned [args [0 ]]
291+ assert isinstance (
292+ opset , int
293+ ), f"The first argument must be an integer for the onnx opset but it is { type (opset )} "
234294 assert any (
235295 a is not None for a in args
236296 ), f"Unexpected args={ string_type (args , with_shape = True )} "
237297 try :
238- key = self .version_selector (* args ) # type: ignore[misc]
298+ key = self .version_selector (opset , * args ) # type: ignore[misc]
239299 except (ValueError , AttributeError ) as e :
240300 raise AssertionError (
241301 f"Unable to select a version, fails to get a key, available="
@@ -302,6 +362,7 @@ def verify(
302362 * args ,
303363 engine : Optional [Callable ] = None ,
304364 dump_onnx_model : Optional [str ] = None ,
365+ opset : int = 22 ,
305366 ** kwargs ,
306367 ) -> VerifyResult :
307368 """
@@ -316,6 +377,7 @@ def verify(
316377 :class:`onnx_diagnostic.reference.OnnxruntimeEvaluator`.
317378 :param dump_onnx_model: to dump the onnx model used to verify
318379 eager and onnx produce the same results
380+ :param opset: onnx opset to use
319381 :param kwargs: additional arguments to the function
320382 :return: outputs of :func:`onnx_diagnostic.helpers.max_diff`
321383 """
@@ -350,7 +412,7 @@ def verify(
350412 assert engine is None , f"Not implemented yet with engine={ engine !r} "
351413 ags , kws = self ._make_args_kwargs (* args , ** kwargs )
352414 sess = OnnxruntimeEvaluator (
353- self .get_function_proto (* args ),
415+ self .get_function_proto (opset , * args ),
354416 whole = True ,
355417 dump_onnx_model = dump_onnx_model ,
356418 function_kwargs = kws ,
@@ -383,7 +445,17 @@ def converter(
383445 * args ,
384446 ** kwargs ,
385447 ) -> Any :
386- function_proto = self .get_function_proto (g .get_type (args [0 ]))
448+ has_devices = [a for a in args if g .has_device (a )]
449+ assert (
450+ has_devices
451+ ), f"Missing device for any of the inputs { args } { g .get_debug_msg ()} "
452+ arg_device = has_devices [0 ]
453+ fake_tensor = torch .empty (
454+ tuple ([(_ if isinstance (_ , int ) else 2 ) for _ in g .get_shape (args [0 ])]),
455+ dtype = onnx_dtype_to_torch_dtype (g .get_type (args [0 ])),
456+ device = int_device_to_torch_device (g .get_device (arg_device )),
457+ )
458+ function_proto = self .get_function_proto (g .main_opset , fake_tensor )
387459 if not g .has_local_function (function_proto .name , domain = function_proto .domain ):
388460 g .add_function (function_proto )
389461 ags , kws = self ._make_args_kwargs (* args , ** kwargs )
@@ -417,7 +489,7 @@ def onnx_dynamo_converter(self) -> Callable:
417489 onnx_plug_op = onnxscript .values .Opset (domain = self .domain , version = 1 )
418490
419491 def get_proto (* args ):
420- function_proto = self .get_function_proto (* args )
492+ function_proto = self .get_function_proto (self . default_opset , * args )
421493 schema = onnx_plug_op [function_proto .name ]
422494 if schema is None :
423495 all_types = [
0 commit comments