diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index dba062c7..ed2b273f 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -4,6 +4,7 @@ Change Logs 0.8.4 +++++ +* :pr:`340`: supports devices in onnx plugs * :pr:`338`: fixes ReplayConfiguration.dump, add function to select of part of a model * :pr:`337`: fixes extract_subset_of_nodes * :pr:`336`: implements versioned onnx plugs diff --git a/_unittests/ut_tasks/try_export.py b/_unittests/ut_tasks/try_export.py index f79a7667..d1ca0425 100644 --- a/_unittests/ut_tasks/try_export.py +++ b/_unittests/ut_tasks/try_export.py @@ -52,7 +52,6 @@ def test_qwen25_vli_visual(self): .. code-block:: bash NEVERTEST=1 \\ - QWEN25ATTENTION=BIGMASK \\ PRETRAINED=1 \\ TESTDEVICE=cuda \\ TESTDTYPE=float16 \\ @@ -164,9 +163,11 @@ def _config_reduction(config, task): if qwen25_attention: attention_options = [qwen25_attention] elif device == "cuda" and dtype in ("float16", "bfloat16"): - attention_options = ["PACKED", "BIGMASK"] + attention_options = [ + "PACKED", + ] else: - attention_options = ["LOOPMHA", "LOOPA24", "BIGMASK"] + attention_options = ["LOOPMHA", "LOOPA24"] # fake_inputs = make_fake_with_dynamic_dimensions(inputs, dynamic_shapes)[0] for attention in attention_options: diff --git a/_unittests/ut_torch_onnx/test_sbs.py b/_unittests/ut_torch_onnx/test_sbs.py index 450ce6d6..706fe1d3 100644 --- a/_unittests/ut_torch_onnx/test_sbs.py +++ b/_unittests/ut_torch_onnx/test_sbs.py @@ -682,7 +682,7 @@ def test_sbs_with_loops(self): PLUGS_Qwen25, ) from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import ( - qwen_sdpa_attention_loopmha_versatile, + qwen_sdpa_attention_versatile, ) class Model(torch.nn.Module): @@ -693,9 +693,7 @@ def forward(self, query, key, value, seq_lens): qs = query * mask ks = key * mask vs = value * mask - attn_output = qwen_sdpa_attention_loopmha_versatile( - qs, ks, vs, seq_lens, 0.11, 16 - ) + attn_output = qwen_sdpa_attention_versatile(qs, ks, vs, seq_lens, 0.11, 16) red = 
attn_output.mean(dim=-1, keepdim=True) return attn_output - red diff --git a/onnx_diagnostic/export/onnx_plug.py b/onnx_diagnostic/export/onnx_plug.py index 300e339b..e3ed995d 100644 --- a/onnx_diagnostic/export/onnx_plug.py +++ b/onnx_diagnostic/export/onnx_plug.py @@ -4,7 +4,11 @@ import onnx import torch from ..helpers import max_diff, string_type -from ..helpers.torch_helper import torch_dtype_to_onnx_dtype +from ..helpers.torch_helper import ( + torch_dtype_to_onnx_dtype, + onnx_dtype_to_torch_dtype, + int_device_to_torch_device, +) from ..reference import OnnxruntimeEvaluator TUPLE_TENSORS = Tuple[torch.Tensor, ...] @@ -50,7 +54,10 @@ class EagerDirectReplacementWithOnnx: only tensors must be counted :param name: the name of the custom op, the function name if not specified :param kwargs: constants parameters with their default values - :param version_selector: selects the version based on the arguments + :param version_selector: selects the version based on the arguments, + see below for an example, this allows the user to define different + onnx version depending on the inputs + :param default_opset: opset to use by default :param verbose: verbose level Here is an example: @@ -134,6 +141,58 @@ def forward(self, x): ).model_proto print(pretty_onnx(onx)) + + This shows how to define multiple versions depending on the device, + the type or the targeted onnx opset. + + .. 
code-block:: python + + def qwen_version_selector(opset: int, *args: torch.Tensor) -> Tuple[str, int]: + first_tensor = next(a for a in args if a is not None) + dtype = first_tensor.dtype + itype = torch_dtype_to_onnx_dtype(dtype) + if dtype == torch.float32: + if opset >= 24: + return "LOOPA24", itype + return "LOOPMHA", itype + if dtype == torch.float16: + if first_tensor.is_cuda: + return "PACKED", itype + return "LOOPMHA", itype + raise AssertionError( + f"Unable to handle type {dtype} (itype={itype}) " + f"on device {first_tensor.device} with opset={opset}" + ) + + qwen_sdpa_attention_versatile = EagerDirectReplacementWithOnnx( + qwen_sdpa_attention, + lambda qs, *args, **kwargs: torch.empty( + (qs.shape[0], qs.shape[2], qs.shape[1], qs.shape[3]), + dtype=qs.dtype, + device=qs.device, + ), + { + ("PACKED", onnx.TensorProto.FLOAT16): _add_com_microsoft_opset( + PackedAttention.to_function_proto() + ), + ("LOOPA24", onnx.TensorProto.FLOAT): LoopAttention24.to_function_proto(), + ("LOOPA24", onnx.TensorProto.FLOAT16): _update_sequence_type( + onnx.TensorProto.FLOAT16, LoopAttention24.to_function_proto() + ), + ("LOOPMHA", onnx.TensorProto.FLOAT): _add_com_microsoft_opset( + LoopMHAAttention.to_function_proto() + ), + ("LOOPMHA", onnx.TensorProto.FLOAT16): _update_sequence_type( + onnx.TensorProto.FLOAT16, + _add_com_microsoft_opset(LoopMHAAttention.to_function_proto()), + ), + }, + n_inputs=4, + n_outputs=1, + kwargs=dict(scaling=0.11180339887498948, num_heads=16), + name="qwen_sdpa_attention_versatile", + version_selector=qwen_version_selector, + ) """ def __init__( @@ -146,7 +205,8 @@ def __init__( name: Optional[str] = None, kwargs: Optional[Dict[str, Union[int, float]]] = None, verbose: int = 0, - version_selector: Optional[Callable[[Any], Any]] = None, + version_selector: Optional[Callable[..., Tuple[Any, ...]]] = None, + default_opset: int = 22, ): assert isinstance(function_proto, onnx.FunctionProto) or ( isinstance(function_proto, dict) @@ -183,6 
+243,7 @@ def __init__( self.verbose = verbose self.custom_op = self._register() self.version_selector = version_selector + self.default_opset = default_opset self._check_protos(params) def _check_protos(self, params): @@ -221,21 +282,18 @@ def _check_protos(self, params): not self._function_proto_versioned or self.version_selector ), "version_selector is needed when multiple protos are given." - def get_function_proto(self, *args) -> onnx.FunctionProto: + def get_function_proto(self, opset: int, *args) -> onnx.FunctionProto: """Returns the correct version based on the inputs.""" if self._function_proto: return self._function_proto - if ( - len(args) == 1 - and isinstance(args[0], (int, str)) - and args[0] in self._function_proto_versioned - ): - return self._function_proto_versioned[args[0]] + assert isinstance( + opset, int + ), f"The first argument must be an integer for the onnx opset but it is {type(opset)}" assert any( a is not None for a in args ), f"Unexpected args={string_type(args, with_shape=True)}" try: - key = self.version_selector(*args) # type: ignore[misc] + key = self.version_selector(opset, *args) # type: ignore[misc] except (ValueError, AttributeError) as e: raise AssertionError( f"Unable to select a version, fails to get a key, available=" @@ -278,6 +336,8 @@ def _register(self): input_args.append(f"int {p}={val}") elif isinstance(val, float): input_args.append(f"float {p}={val}") + elif isinstance(val, str): + input_args.append(f"str {p}={val}") else: raise NotImplementedError( f"kwargs {p!r} has a default value of unsupported type {type(val)}" @@ -302,6 +362,7 @@ def verify( *args, engine: Optional[Callable] = None, dump_onnx_model: Optional[str] = None, + opset: int = 22, **kwargs, ) -> VerifyResult: """ @@ -316,6 +377,7 @@ def verify( :class:`onnx_diagnostic.reference.OnnxruntimeEvaluator`. 
:param dump_onnx_model: to dump the onnx model used to verify eager and onnx produce the same results + :param opset: onnx opset to use :param kwargs: additional arguments to the function :return: outputs of :func:`onnx_diagnostic.helpers.max_diff` """ @@ -350,7 +412,7 @@ def verify( assert engine is None, f"Not implemented yet with engine={engine!r}" ags, kws = self._make_args_kwargs(*args, **kwargs) sess = OnnxruntimeEvaluator( - self.get_function_proto(*args), + self.get_function_proto(opset, *args), whole=True, dump_onnx_model=dump_onnx_model, function_kwargs=kws, @@ -383,7 +445,17 @@ def converter( *args, **kwargs, ) -> Any: - function_proto = self.get_function_proto(g.get_type(args[0])) + has_devices = [a for a in args if isinstance(a, str) and g.has_device(a)] + assert ( + has_devices + ), f"Missing device for any of the inputs {args}{g.get_debug_msg()}" + arg_device = has_devices[0] + fake_tensor = torch.empty( + tuple([(_ if isinstance(_, int) else 2) for _ in g.get_shape(args[0])]), + dtype=onnx_dtype_to_torch_dtype(g.get_type(args[0])), + device=int_device_to_torch_device(g.get_device(arg_device)), + ) + function_proto = self.get_function_proto(g.main_opset, fake_tensor) if not g.has_local_function(function_proto.name, domain=function_proto.domain): g.add_function(function_proto) ags, kws = self._make_args_kwargs(*args, **kwargs) @@ -417,7 +489,7 @@ def onnx_dynamo_converter(self) -> Callable: onnx_plug_op = onnxscript.values.Opset(domain=self.domain, version=1) def get_proto(*args): - function_proto = self.get_function_proto(*args) + function_proto = self.get_function_proto(self.default_opset, *args) schema = onnx_plug_op[function_proto.name] if schema is None: all_types = [ diff --git a/onnx_diagnostic/helpers/onnx_helper.py b/onnx_diagnostic/helpers/onnx_helper.py index 74dcea1b..fd5a7b23 100644 --- a/onnx_diagnostic/helpers/onnx_helper.py +++ b/onnx_diagnostic/helpers/onnx_helper.py @@ -17,7 +17,6 @@ Union, ) import numpy as np -import numpy.typing 
as npt import onnx import onnx.helper as oh import onnx.numpy_helper as onh @@ -33,6 +32,8 @@ load as onnx_load, ) +TensorLike = Union[np.ndarray, "torch.Tensor"] # noqa: F821 + def _make_stat(init: TensorProto) -> Dict[str, float]: """ @@ -490,7 +491,7 @@ def convert_endian(tensor: TensorProto) -> None: tensor.raw_data = np.frombuffer(tensor.raw_data, dtype=np_dtype).byteswap().tobytes() -def from_array_ml_dtypes(arr: npt.ArrayLike, name: Optional[str] = None) -> TensorProto: +def from_array_ml_dtypes(arr: TensorLike, name: Optional[str] = None) -> TensorProto: """ Converts a numpy array to a tensor def assuming the dtype is defined in ml_dtypes. @@ -536,7 +537,7 @@ def from_array_ml_dtypes(arr: npt.ArrayLike, name: Optional[str] = None) -> Tens } -def from_array_extended(tensor: npt.ArrayLike, name: Optional[str] = None) -> TensorProto: +def from_array_extended(tensor: TensorLike, name: Optional[str] = None) -> TensorProto: """ Converts an array into a :class:`onnx.TensorProto`. @@ -603,7 +604,7 @@ def from_array_extended(tensor: npt.ArrayLike, name: Optional[str] = None) -> Te return t -def to_array_extended(proto: TensorProto) -> npt.ArrayLike: +def to_array_extended(proto: TensorProto) -> TensorLike: """Converts :class:`onnx.TensorProto` into a numpy array.""" arr = onh.to_array(proto) if proto.data_type >= onnx.TensorProto.BFLOAT16: diff --git a/onnx_diagnostic/helpers/ort_session.py b/onnx_diagnostic/helpers/ort_session.py index 56f260c4..12812835 100644 --- a/onnx_diagnostic/helpers/ort_session.py +++ b/onnx_diagnostic/helpers/ort_session.py @@ -1,7 +1,6 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union import onnx import numpy as np -import numpy.typing as npt import torch from torch._C import _from_dlpack import onnxruntime @@ -16,6 +15,7 @@ DEVICES = {-1: ORTC.OrtDevice(ORTC.OrtDevice.cpu(), ORTC.OrtDevice.default_memory(), 0)} +TensorLike = Union[np.ndarray, torch.Tensor] class _InferenceSession: @@ -243,16 +243,16 @@ def __init__( 
) def run( - self, output_names: Optional[List[str]], feeds: Dict[str, npt.ArrayLike] - ) -> List[Optional[npt.ArrayLike]]: + self, output_names: Optional[List[str]], feeds: Dict[str, TensorLike] + ) -> List[Optional[TensorLike]]: """Calls :meth:`onnxruntime.InferenceSession.run`.""" # sess.run does not support blfoat16 # res = self.sess.run(output_names, feeds) return self._post_process_inplace(list(self.run_dlpack(output_names, feeds))) def run_dlpack( - self, output_names: Optional[List[str]], feeds: Dict[str, npt.ArrayLike] - ) -> Tuple[Optional[npt.ArrayLike], ...]: + self, output_names: Optional[List[str]], feeds: Dict[str, TensorLike] + ) -> Tuple[Optional[TensorLike], ...]: """ Same as :meth:`onnxruntime.InferenceSession.run` except that feeds is a dictionary of :class:`np.ndarray`. @@ -289,13 +289,13 @@ def run_dlpack( def _ortvalues_to_numpy_tensor( self, ortvalues: Union[List[ORTC.OrtValue], ORTC.OrtValueVector], - ) -> Tuple[Optional[npt.ArrayLike], ...]: + ) -> Tuple[Optional[TensorLike], ...]: if len(ortvalues) == 0: return tuple() if self.nvtx: self.torch.cuda.nvtx.range_push("_ortvalues_to_numpy_tensor") - res: List[Optional[npt.ArrayLike]] = [] # noqa: F823 + res: List[Optional[TensorLike]] = [] # noqa: F823 for i in range(len(ortvalues)): if not ortvalues[i].has_value(): res.append(None) @@ -556,7 +556,7 @@ def investigate_onnxruntime_issue( Union[str, Callable[[onnx.ModelProto], onnxruntime.InferenceSession]] ] = None, # if model needs to be run. 
def int_device_to_torch_device(device_id: int) -> torch.device:
    """
    Builds the ``torch.device`` matching an integer device index
    (such as the value returned by :meth:`torch.Tensor.get_device`).

    :param device_id: device index, a negative value stands for the CPU,
        any other value is used as the index of a CUDA device
    :return: ``torch.device("cpu")`` or ``torch.device("cuda", device_id)``
    """
    is_cpu = device_id < 0
    return torch.device("cpu") if is_cpu else torch.device("cuda", device_id)
def qwen_version_selector(opset: int, *args: torch.Tensor) -> Tuple[str, int]:
    """
    Selects the version of the attention plug to use for the given inputs.

    The result has the same form as the keys of the versioned function protos,
    for example ``("PACKED", onnx.TensorProto.FLOAT16)``.

    :param opset: targeted onnx opset
    :param args: inputs of the plug, some of them may be None, the first one
        which is not None gives the dtype and the device
    :return: ``(strategy, itype)``, *itype* is the value returned by
        ``torch_dtype_to_onnx_dtype`` for the dtype of that first tensor
    :raises AssertionError: if every argument is None or if no strategy
        is available for the dtype
    """
    first_tensor = next((a for a in args if a is not None), None)
    if first_tensor is None:
        # A bare next(...) would leak StopIteration, which says nothing useful.
        raise AssertionError(
            f"Unable to select a version, all {len(args)} arguments are None "
            f"with opset={opset}"
        )
    dtype = first_tensor.dtype
    strategy = patched_Qwen2_5_VLVisionAttention.STRATEGY_FOR_ATTENTION()
    itype = torch_dtype_to_onnx_dtype(dtype)
    if strategy is not None:
        # An explicit strategy takes precedence over the automated choice.
        return strategy, itype
    if dtype == torch.float32:
        return ("LOOPA24" if opset >= 24 else "LOOPMHA"), itype
    if dtype == torch.float16:
        return ("PACKED" if first_tensor.is_cuda else "LOOPMHA"), itype
    # Report the actual dtype and device, not the torch.dtype / torch.device classes.
    raise AssertionError(
        f"Unable to handle type {dtype} (itype={itype}) "
        f"on device {first_tensor.device} with opset={opset}"
    )
n_inputs=4, - n_outputs=1, - kwargs=dict(scaling=0.11180339887498948, num_heads=16), - name="qwen_sdpa_attention_loopa24", - version_selector=lambda *args: torch_dtype_to_onnx_dtype( - next(a for a in args if a is not None).dtype - ), - ) - PLUGS.append(qwen_sdpa_attention_loopa24_versatile) + PLUGS.append(qwen_sdpa_attention_versatile) class patched_Qwen2_5_VLForConditionalGeneration: _PATCHES_ = ["prepare_inputs_for_generation"] @@ -575,9 +594,7 @@ class patched_Qwen2_5_VLVisionAttention: _PATCHED_CLASS_ = ( transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLVisionAttention ) - STRATEGY_FOR_ATTENTION = lambda: os.environ.get( # noqa: E731 - "QWEN25ATTENTION", "PACKED" - ) + STRATEGY_FOR_ATTENTION = lambda: os.environ.get("QWEN25ATTENTION", None) # noqa: E731 def forward( self, @@ -626,9 +643,8 @@ def forward( is transformers.integrations.sdpa_attention.sdpa_attention_forward or attention_interface is patched_sdpa_attention_forward ) - attention_strategy = patched_Qwen2_5_VLVisionAttention.STRATEGY_FOR_ATTENTION() - if is_sdpa and attention_strategy in "PACKED": - attn_output = qwen_sdpa_attention_packed_versatile( + if is_sdpa: + attn_output = qwen_sdpa_attention_versatile( query_states, key_states, value_states, @@ -660,78 +676,10 @@ def forward( ), version=1, ) - elif is_sdpa and attention_strategy == "LOOPA24": - attn_output = qwen_sdpa_attention_loopa24_versatile( - query_states, - key_states, - value_states, - cu_seqlens, - self.scaling, - self.num_heads, - ) - elif is_sdpa and attention_strategy == "LOOPMHA": - attn_output = qwen_sdpa_attention_loopmha_versatile( - query_states, - key_states, - value_states, - cu_seqlens, - self.scaling, - self.num_heads, - ) - - # to rewrite later with a for loop - # def _iteration(start_end, query_states, key_states, value_states): - # return patched_Qwen2_5_VLVisionAttentionOneIteration.forward( - # self, - # start_end, - # query_states, - # key_states, - # value_states, - # scaling=self.scaling, - # 
dropout=0.0 if not self.training else self.attention_dropout, - # ) - - # starts = cu_seqlens[:-1] - # ends = cu_seqlens[1:] - # torch._check(starts.shape[0] > 0) - # torch._check(ends.shape[0] > 0) - # starts_ends = torch.cat([starts.unsqueeze(1), ends.unsqueeze(1)], dim=1) - # attn_outputs = [ - # _iteration(start_end, query_states, key_states, value_states) - # for start_end in starts_ends - # ] - # attn_output = torch.cat(attn_outputs, dim=1) - elif is_sdpa and attention_strategy == "BIGMASK": - # make square mask - indices = torch.arange( - cu_seqlens.max(), dtype=cu_seqlens.dtype, device=cu_seqlens.device - ) - dot = (cu_seqlens.unsqueeze(1) <= indices.unsqueeze(0)).to( - cu_seqlens.dtype - ) - dot = dot.sum(dim=0) - mask = dot.unsqueeze(1) - dot.unsqueeze(0) - bool_mask = mask == 0 - bool_mask = bool_mask.unsqueeze(0).unsqueeze(0) - - torch._check(bool_mask.shape[2] == key_states.shape[2]) - torch._check(bool_mask.shape[3] == key_states.shape[2]) - - attn_output, _ = attention_interface( - self, - query_states, - key_states, - value_states, - attention_mask=bool_mask, - scaling=self.scaling, - dropout=0.0 if not self.training else self.attention_dropout, - is_causal=False, - **kwargs, - ) else: raise NotImplementedError( - f"No corresponding export strategy for " - f"{attention_strategy!r}, " + f"No corresponding export strategy for implementation " + f"{self.config._attn_implementation!r}, " f"(use QWEN25ATTENTION to change it), and attention_interface=" f"{attention_interface!r} (use sdpa)" ) @@ -755,6 +703,7 @@ def forward( ) else: # Other implementations: Process each chunk separately + # = qwen_sdpa_attention lengths = cu_seqlens[1:] - cu_seqlens[:-1] splits = [ torch.split(tensor, lengths.tolist(), dim=2)