diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index 00ee0232..7dbb847e 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -4,6 +4,7 @@ Change Logs 0.8.3 +++++ +* :pr:`326`: use ConcatFromSequence in LoopMHA with the loop * :pr:`325`: adds plug for LoopMHA, extends the unit tests to measure the discrepancies * :pr:`324`: supports FunctionProto with arguments in OnnxruntimeEvaluator * :pr:`323`: drops torch 2.8 on CI diff --git a/_unittests/ut_torch_export_patches/test_patch_transformers.py b/_unittests/ut_torch_export_patches/test_patch_transformers.py index 3fb0c587..0ff07656 100644 --- a/_unittests/ut_torch_export_patches/test_patch_transformers.py +++ b/_unittests/ut_torch_export_patches/test_patch_transformers.py @@ -626,6 +626,7 @@ def test_plug_packed_multi_head_attention_qwen25_loopmha(self): *inputs, scaling=0.5, num_heads=16, + itype=onnx.TensorProto.FLOAT16, dump_onnx_model=self.get_dump_file( "test_plug_packed_multi_head_attention_qwen25_loopmha.onnx" ), @@ -636,7 +637,7 @@ def test_plug_packed_multi_head_attention_qwen25_loopmha(self): self.assertLess(results.diffs[0]["abs"], 0.01) results = qwen_sdpa_attention_loopmha_versatile.verify( - *inputs, scaling=0.11180339887498948, num_heads=16 + *inputs, scaling=0.11180339887498948, num_heads=16, itype=onnx.TensorProto.FLOAT16 ) self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs)) self.assertEqual(len(results.eager_outputs), len(results.diffs)) diff --git a/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py b/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py index 5678868a..6a024470 100644 --- a/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +++ b/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py @@ -32,6 +32,7 @@ def LoopMHAAttention( cu_seqlens, scaling: float = 0.11180339887498948, num_heads: int = 16, + itype: int = onnx.TensorProto.FLOAT, ): to_3d_shape = op.Constant(value_ints=[0, 0, -1]) query_transposed = op.Transpose(query_states, perm=[0, 2, 1, 3]) @@ -43,7 +44,8 @@ def LoopMHAAttention( num_patches = op.Size(cu_seqlens) - 1 seq_axis = op.Constant(value_ints=[1]) seq_axis_int32 = op.Cast(seq_axis, to=onnx.TensorProto.INT32) - attn_output = op.Slice(value_3d, [0], [0], seq_axis) + # attn_output = op.Slice(value_3d, [0], [0], seq_axis) + seq_attn = op.SequenceEmpty(dtype=itype) for i_patch in range(num_patches): i_1d = op.Reshape(i_patch, [1]) i_plus_1_1d = i_1d + 1 @@ -59,7 +61,9 @@ def LoopMHAAttention( num_heads=num_heads, scale=scaling, ) - attn_output = op.Concat(attn_output, mha_output, axis=1) + # attn_output = op.Concat(attn_output, mha_output, axis=1) + seq_attn = op.SequenceInsert(seq_attn, mha_output) + attn_output = op.ConcatFromSequence(seq_attn, axis=1) attn_output_4d = op.Reshape(attn_output, output_shape) return attn_output_4d @@ -128,6 +132,7 @@ def qwen_sdpa_attention( cu_seqlens: torch.Tensor, # F7su19 scaling: float = 0, num_heads: int = 16, + itype: int = onnx.TensorProto.FLOAT, ) -> torch.Tensor: lengths = cu_seqlens[1:] - cu_seqlens[:-1] splits = [ @@ -162,7 +167,7 @@ def qwen_sdpa_attention( _add_com_microsoft_opset(PackedAttention.to_function_proto()), n_inputs=4, n_outputs=1, - kwargs=dict(scaling=0.11180339887498948, num_heads=16), + kwargs=dict(scaling=0.11180339887498948, num_heads=16, itype=onnx.TensorProto.FLOAT), name="qwen_sdpa_attention_packed", ) PLUGS.append(qwen_sdpa_attention_packed_versatile) @@ -177,7 +182,7 @@ def qwen_sdpa_attention( _add_com_microsoft_opset(LoopMHAAttention.to_function_proto()), n_inputs=4, n_outputs=1, - kwargs=dict(scaling=0.11180339887498948, num_heads=16), + kwargs=dict(scaling=0.11180339887498948, num_heads=16, itype=onnx.TensorProto.FLOAT), name="qwen_sdpa_attention_loopmha", ) PLUGS.append(qwen_sdpa_attention_loopmha_versatile) @@ -561,6 +566,15 @@ def forward( cu_seqlens, self.scaling, self.num_heads, + ( + onnx.TensorProto.FLOAT + if query_states.dtype == torch.float32 + else ( + onnx.TensorProto.FLOAT16 + if query_states.dtype == torch.float16 + else onnx.TensorProto.BFLOAT16 + ) + ), ) # to rewrite later with a for loop