Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOGS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ Change Logs
0.8.3
+++++

* :pr:`326`: use ConcatFromSequence in LoopMHA with the loop
* :pr:`325`: adds plug for LoopMHA, extends the unit tests to measure the discrepancies
* :pr:`324`: supports FunctionProto with arguments in OnnxruntimeEvaluator
* :pr:`323`: drops torch 2.8 on CI
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -626,6 +626,7 @@ def test_plug_packed_multi_head_attention_qwen25_loopmha(self):
*inputs,
scaling=0.5,
num_heads=16,
itype=onnx.TensorProto.FLOAT16,
dump_onnx_model=self.get_dump_file(
"test_plug_packed_multi_head_attention_qwen25_loopmha.onnx"
),
Expand All @@ -636,7 +637,7 @@ def test_plug_packed_multi_head_attention_qwen25_loopmha(self):
self.assertLess(results.diffs[0]["abs"], 0.01)

results = qwen_sdpa_attention_loopmha_versatile.verify(
*inputs, scaling=0.11180339887498948, num_heads=16
*inputs, scaling=0.11180339887498948, num_heads=16, itype=onnx.TensorProto.FLOAT16
)
self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs))
self.assertEqual(len(results.eager_outputs), len(results.diffs))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def LoopMHAAttention(
cu_seqlens,
scaling: float = 0.11180339887498948,
num_heads: int = 16,
itype: int = onnx.TensorProto.FLOAT,
):
to_3d_shape = op.Constant(value_ints=[0, 0, -1])
query_transposed = op.Transpose(query_states, perm=[0, 2, 1, 3])
Expand All @@ -43,7 +44,8 @@ def LoopMHAAttention(
num_patches = op.Size(cu_seqlens) - 1
seq_axis = op.Constant(value_ints=[1])
seq_axis_int32 = op.Cast(seq_axis, to=onnx.TensorProto.INT32)
attn_output = op.Slice(value_3d, [0], [0], seq_axis)
# attn_output = op.Slice(value_3d, [0], [0], seq_axis)
seq_attn = op.SequenceEmpty(dtype=itype)
for i_patch in range(num_patches):
i_1d = op.Reshape(i_patch, [1])
i_plus_1_1d = i_1d + 1
Expand All @@ -59,7 +61,9 @@ def LoopMHAAttention(
num_heads=num_heads,
scale=scaling,
)
attn_output = op.Concat(attn_output, mha_output, axis=1)
# attn_output = op.Concat(attn_output, mha_output, axis=1)
seq_attn = op.SequenceInsert(seq_attn, mha_output)
attn_output = op.ConcatFromSequence(seq_attn, axis=1)
attn_output_4d = op.Reshape(attn_output, output_shape)
return attn_output_4d

Expand Down Expand Up @@ -128,6 +132,7 @@ def qwen_sdpa_attention(
cu_seqlens: torch.Tensor, # F7su19
scaling: float = 0,
num_heads: int = 16,
itype: int = onnx.TensorProto.FLOAT,
) -> torch.Tensor:
lengths = cu_seqlens[1:] - cu_seqlens[:-1]
splits = [
Expand Down Expand Up @@ -162,7 +167,7 @@ def qwen_sdpa_attention(
_add_com_microsoft_opset(PackedAttention.to_function_proto()),
n_inputs=4,
n_outputs=1,
kwargs=dict(scaling=0.11180339887498948, num_heads=16),
kwargs=dict(scaling=0.11180339887498948, num_heads=16, itype=onnx.TensorProto.FLOAT),
name="qwen_sdpa_attention_packed",
)
PLUGS.append(qwen_sdpa_attention_packed_versatile)
Expand All @@ -177,7 +182,7 @@ def qwen_sdpa_attention(
_add_com_microsoft_opset(LoopMHAAttention.to_function_proto()),
n_inputs=4,
n_outputs=1,
kwargs=dict(scaling=0.11180339887498948, num_heads=16),
kwargs=dict(scaling=0.11180339887498948, num_heads=16, itype=onnx.TensorProto.FLOAT),
name="qwen_sdpa_attention_loopmha",
)
PLUGS.append(qwen_sdpa_attention_loopmha_versatile)
Expand Down Expand Up @@ -561,6 +566,15 @@ def forward(
cu_seqlens,
self.scaling,
self.num_heads,
(
onnx.TensorProto.FLOAT
if query_states.dtype == torch.float32
else (
onnx.TensorProto.FLOAT16
if query_states.dtype == torch.float16
else onnx.TensorProto.BFLOAT16
)
),
)

# to rewrite later with a for loop
Expand Down
Loading