@@ -150,20 +150,18 @@ def forward(self, x):
150150 def qwen_version_selector(opset: int, *args: torch.Tensor) -> Tuple[str, torch.dtype]:
151151 first_tensor = next(a for a in args if a is not None)
152152 dtype = first_tensor.dtype
153- strategy = patched_Qwen2_5_VLVisionAttention.STRATEGY_FOR_ATTENTION()
154- if strategy is not None:
155- return strategy, dtype
153+ itype = torch_dtype_to_onnx_dtype(dtype)
156154 if dtype == torch.float32:
157155 if opset >= 24:
158- return "LOOPA24", dtype
159- return "LOOPMHA", dtype
156+ return "LOOPA24", itype
157+ return "LOOPMHA", itype
160158 if dtype == torch.float16:
161159 if first_tensor.is_cuda:
162- return "PACKED", dtype
163- return "LOOPMHA", dtype
160+ return "PACKED", itype
161+ return "LOOPMHA", itype
164162 raise AssertionError(
165- f"Unable to handle type {torch.dtype} on "
166- f"device {torch.device} with opset={opset}"
163+ f"Unable to handle type {torch.dtype} (itype={itype}) "
164+ f"on device {torch.device} with opset={opset}"
167165 )
168166
169167 qwen_sdpa_attention_versatile = EagerDirectReplacementWithOnnx(
@@ -338,6 +336,8 @@ def _register(self):
338336 input_args .append (f"int { p } ={ val } " )
339337 elif isinstance (val , float ):
340338 input_args .append (f"float { p } ={ val } " )
339+ elif isinstance (val , str ):
340+ input_args .append (f"str { p } ={ val } " )
341341 else :
342342 raise NotImplementedError (
343343 f"kwargs { p !r} has a default value of unsupported type { type (val )} "
@@ -445,7 +445,7 @@ def converter(
445445 * args ,
446446 ** kwargs ,
447447 ) -> Any :
448- has_devices = [a for a in args if g .has_device (a )]
448+ has_devices = [a for a in args if isinstance ( a , str ) and g .has_device (a )]
449449 assert (
450450 has_devices
451451 ), f"Missing device for any of the inputs { args } { g .get_debug_msg ()} "
0 commit comments