1- # Copyright 2025 Arm Limited and/or its affiliates.
1+ # Copyright 2025-2026 Arm Limited and/or its affiliates.
22#
33# This source code is licensed under the BSD-style license found in the
44# LICENSE file in the root directory of this source tree.
55
66
7- from typing import Tuple
7+ from typing import Callable , Tuple
88
99import torch
1010
1717
1818
1919class SDPA (torch .nn .Module ):
20- def __init__ (self ):
20+ def __init__ (self , attn_mask = None , is_causal = False ):
2121 super ().__init__ ()
22+ self .attn_mask = attn_mask
23+ self .is_causal = is_causal
2224
2325 def forward (self , query , key , value ):
2426 return torch .nn .functional .scaled_dot_product_attention (
25- query , key , value , attn_mask = None , dropout_p = 0.0 , is_causal = False
27+ query , key , value , attn_mask = self . attn_mask , is_causal = self . is_causal
2628 )
2729
2830
2931input_t = Tuple [torch .Tensor , torch .Tensor , torch .Tensor ]
30-
31-
32- def test_sdpa_tosa_FP ():
33- test_input = tuple (torch .randn (1 , 3 , 197 , 64 ) for x in range (3 ))
34- pipeline = TosaPipelineFP [input_t ](SDPA (), test_input , [], [])
32+ test_case_t = Callable [[], Tuple [SDPA , input_t ]]
33+
34+ test_suite = {
35+ # test_name: generator(model, inputs)
36+ "randn_no_mask_non_causal" : lambda : (
37+ SDPA (attn_mask = None , is_causal = False ),
38+ tuple (torch .randn (1 , 3 , 197 , 64 ) for _ in range (3 )),
39+ ),
40+ "randn_no_mask_causal" : lambda : (
41+ SDPA (attn_mask = None , is_causal = True ),
42+ tuple (torch .randn (1 , 3 , 197 , 64 ) for _ in range (3 )),
43+ ),
44+ "randn_with_bool_mask_non_causal" : lambda : (
45+ SDPA (attn_mask = (torch .rand (1 , 3 , 197 , 1 ) > 0.5 ), is_causal = False ),
46+ tuple (torch .randn (1 , 3 , 197 , 64 ) for _ in range (3 )),
47+ ),
48+ "randn_with_additive_mask_non_causal" : lambda : (
49+ SDPA (
50+ attn_mask = torch .where (torch .rand (1 , 3 , 197 , 1 ) > 0.5 , 0.0 , - float ("inf" )),
51+ is_causal = False ,
52+ ),
53+ tuple (torch .randn (1 , 3 , 197 , 64 ) for _ in range (3 )),
54+ ),
55+ # causal with mask is not supported in PyTorch (https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
56+ }
57+
58+
59+ @common .parametrize ("test_case" , test_suite )
60+ def test_sdpa_tosa_FP (test_case : test_case_t ):
61+ model , test_input = test_case ()
62+ pipeline = TosaPipelineFP [input_t ](model , test_input , [], [])
3563 pipeline .pop_stage ("check_count.exir" )
3664 pipeline .run ()
3765
3866
39- def test_sdpa_tosa_INT ():
40- test_input = tuple (torch .randn (1 , 3 , 197 , 64 ) for x in range (3 ))
41- pipeline = TosaPipelineINT [input_t ](SDPA (), test_input , [], [])
67+ @common .parametrize ("test_case" , test_suite )
68+ def test_sdpa_tosa_INT (test_case : test_case_t ):
69+ model , test_input = test_case ()
70+ pipeline = TosaPipelineINT [input_t ](model , test_input , [], [])
4271 pipeline .pop_stage ("check.quant_nodes" )
4372 pipeline .pop_stage ("check_count.exir" )
4473 pipeline .pop_stage (
@@ -48,10 +77,11 @@ def test_sdpa_tosa_INT():
4877
4978
5079@common .SkipIfNoModelConverter
51- def test_sdpa_vgf_no_quant ():
52- test_input = tuple (torch .randn (1 , 3 , 197 , 64 ) for _ in range (3 ))
80+ @common .parametrize ("test_case" , test_suite )
81+ def test_sdpa_vgf_no_quant (test_case : test_case_t ):
82+ model , test_input = test_case ()
5383 pipeline = VgfPipeline [input_t ](
54- SDPA () ,
84+ model ,
5585 test_input ,
5686 [],
5787 [],
@@ -61,13 +91,10 @@ def test_sdpa_vgf_no_quant():
6191
6292
6393@common .SkipIfNoModelConverter
64- def test_sdpa_vgf_quant ():
65- test_input = tuple (torch .randn (1 , 3 , 197 , 64 ) for _ in range (3 ))
94+ @common .parametrize ("test_case" , test_suite )
95+ def test_sdpa_vgf_quant (test_case : test_case_t ):
96+ model , test_input = test_case ()
6697 pipeline = VgfPipeline [input_t ](
67- SDPA (),
68- test_input ,
69- [],
70- [],
71- quantize = True ,
98+ model , test_input , [], [], quantize = True , run_on_vulkan_runtime = False
7299 )
73100 pipeline .run ()
0 commit comments