Skip to content

Commit 62d6dc1

Browse files
transforms: Fix bug in SDPA decomposition (#17238)
When SDPA with is_causal=True was lowered, a bug in decompose_sdpa was discovered: the pass assumed that every entry in node.args is a node, which caused a crash when an argument is not a node. The pass has been updated to read node.all_input_nodes instead of node.args. New test cases are added to backends/arm/test/ops/test_sdpa.py. cc @freddan80 @per @zingo @digantdesai Signed-off-by: Oscar Andersson <oscar.andersson@arm.com>
1 parent 9d45afc commit 62d6dc1

2 files changed

Lines changed: 51 additions & 24 deletions

File tree

backends/arm/test/ops/test_sdpa.py

Lines changed: 49 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
# Copyright 2025 Arm Limited and/or its affiliates.
1+
# Copyright 2025-2026 Arm Limited and/or its affiliates.
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

66

7-
from typing import Tuple
7+
from typing import Callable, Tuple
88

99
import torch
1010

@@ -17,28 +17,57 @@
1717

1818

1919
class SDPA(torch.nn.Module):
20-
def __init__(self):
20+
def __init__(self, attn_mask=None, is_causal=False):
2121
super().__init__()
22+
self.attn_mask = attn_mask
23+
self.is_causal = is_causal
2224

2325
def forward(self, query, key, value):
2426
return torch.nn.functional.scaled_dot_product_attention(
25-
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
27+
query, key, value, attn_mask=self.attn_mask, is_causal=self.is_causal
2628
)
2729

2830

2931
input_t = Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
30-
31-
32-
def test_sdpa_tosa_FP():
33-
test_input = tuple(torch.randn(1, 3, 197, 64) for x in range(3))
34-
pipeline = TosaPipelineFP[input_t](SDPA(), test_input, [], [])
32+
test_case_t = Callable[[], Tuple[SDPA, input_t]]
33+
34+
test_suite = {
35+
# test_name: generator(model, inputs)
36+
"randn_no_mask_non_causal": lambda: (
37+
SDPA(attn_mask=None, is_causal=False),
38+
tuple(torch.randn(1, 3, 197, 64) for _ in range(3)),
39+
),
40+
"randn_no_mask_causal": lambda: (
41+
SDPA(attn_mask=None, is_causal=True),
42+
tuple(torch.randn(1, 3, 197, 64) for _ in range(3)),
43+
),
44+
"randn_with_bool_mask_non_causal": lambda: (
45+
SDPA(attn_mask=(torch.rand(1, 3, 197, 1) > 0.5), is_causal=False),
46+
tuple(torch.randn(1, 3, 197, 64) for _ in range(3)),
47+
),
48+
"randn_with_additive_mask_non_causal": lambda: (
49+
SDPA(
50+
attn_mask=torch.where(torch.rand(1, 3, 197, 1) > 0.5, 0.0, -float("inf")),
51+
is_causal=False,
52+
),
53+
tuple(torch.randn(1, 3, 197, 64) for _ in range(3)),
54+
),
55+
# causal with mask is not supported in PyTorch (https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
56+
}
57+
58+
59+
@common.parametrize("test_case", test_suite)
60+
def test_sdpa_tosa_FP(test_case: test_case_t):
61+
model, test_input = test_case()
62+
pipeline = TosaPipelineFP[input_t](model, test_input, [], [])
3563
pipeline.pop_stage("check_count.exir")
3664
pipeline.run()
3765

3866

39-
def test_sdpa_tosa_INT():
40-
test_input = tuple(torch.randn(1, 3, 197, 64) for x in range(3))
41-
pipeline = TosaPipelineINT[input_t](SDPA(), test_input, [], [])
67+
@common.parametrize("test_case", test_suite)
68+
def test_sdpa_tosa_INT(test_case: test_case_t):
69+
model, test_input = test_case()
70+
pipeline = TosaPipelineINT[input_t](model, test_input, [], [])
4271
pipeline.pop_stage("check.quant_nodes")
4372
pipeline.pop_stage("check_count.exir")
4473
pipeline.pop_stage(
@@ -48,10 +77,11 @@ def test_sdpa_tosa_INT():
4877

4978

5079
@common.SkipIfNoModelConverter
51-
def test_sdpa_vgf_no_quant():
52-
test_input = tuple(torch.randn(1, 3, 197, 64) for _ in range(3))
80+
@common.parametrize("test_case", test_suite)
81+
def test_sdpa_vgf_no_quant(test_case: test_case_t):
82+
model, test_input = test_case()
5383
pipeline = VgfPipeline[input_t](
54-
SDPA(),
84+
model,
5585
test_input,
5686
[],
5787
[],
@@ -61,13 +91,10 @@ def test_sdpa_vgf_no_quant():
6191

6292

6393
@common.SkipIfNoModelConverter
64-
def test_sdpa_vgf_quant():
65-
test_input = tuple(torch.randn(1, 3, 197, 64) for _ in range(3))
94+
@common.parametrize("test_case", test_suite)
95+
def test_sdpa_vgf_quant(test_case: test_case_t):
96+
model, test_input = test_case()
6697
pipeline = VgfPipeline[input_t](
67-
SDPA(),
68-
test_input,
69-
[],
70-
[],
71-
quantize=True,
98+
model, test_input, [], [], quantize=True, run_on_vulkan_runtime=False
7299
)
73100
pipeline.run()

backends/transforms/decompose_sdpa.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Copyright (c) Meta Platforms, Inc. and affiliates.
22
# All rights reserved.
3-
# Copyright 2025 Arm Limited and/or its affiliates.
3+
# Copyright 2025-2026 Arm Limited and/or its affiliates.
44
#
55
# This source code is licensed under the BSD-style license found in the
66
# LICENSE file in the root directory of this source tree.
@@ -49,7 +49,7 @@ def _decompose_sdpa_node(
4949
allow_non_fake_inputs: bool,
5050
) -> None:
5151
graph = graph_module.graph
52-
input_tensors = (arg.meta["val"] for arg in node.args)
52+
input_tensors = (input_node.meta["val"] for input_node in node.all_input_nodes)
5353
scale = node.kwargs.get("scale", None)
5454

5555
# refer to pytorch/test/test_decomp.py

0 commit comments

Comments
 (0)