1717 VgfPipeline ,
1818)
1919
20+ amax_aten_op = "torch.ops.aten.amax"
21+ amax_exir_op = "executorch_exir_dialects_edge__ops_aten_amax_default"
22+
23+ max_aten_op = "torch.ops.aten.max"
24+ max_exir_op = "executorch_exir_dialects_edge__ops_aten_max_default"
25+
2026
2127class Amax (torch .nn .Module ):
2228 input_t = Tuple [Tuple [torch .Tensor ], int | Tuple [int ], bool ]
23- aten_op = ["torch.ops.aten.amax" ]
2429
2530 def __init__ (self , dim , keep_dims ):
2631 self .dim = dim
@@ -38,10 +43,35 @@ def forward(self, x):
3843 "rank_4_mult_batches" : lambda : ((torch .rand ([2 , 2 , 2 , 2 ]),), (0 ), True ),
3944 }
4045
46+ test_data_fp16 : Dict [str , input_t ] = {
47+ "rank_1_dim_0_fp16" : lambda : (
48+ (torch .rand ([10 ], dtype = torch .float16 ),),
49+ 0 ,
50+ False ,
51+ ),
52+ "rank_2_dim_1_keep_dims_fp16" : lambda : (
53+ (torch .rand ([2 , 2 ], dtype = torch .float16 ),),
54+ (1 ,),
55+ True ,
56+ ),
57+ }
58+
59+ test_data_bf16 : Dict [str , input_t ] = {
60+ "rank_1_dim_0_bf16" : lambda : (
61+ (torch .rand ([10 ], dtype = torch .bfloat16 ),),
62+ 0 ,
63+ False ,
64+ ),
65+ "rank_2_dim_1_keep_dims_bf16" : lambda : (
66+ (torch .rand ([2 , 2 ], dtype = torch .bfloat16 ),),
67+ (1 ,),
68+ True ,
69+ ),
70+ }
71+
4172
4273class Max (torch .nn .Module ):
4374 input_t = Tuple [Tuple [torch .Tensor ], int ]
44- aten_op = ["torch.ops.aten.amax" ]
4575
4676 def __init__ (self , dim ):
4777 self .dim = dim
@@ -58,8 +88,20 @@ def forward(self, x):
5888 "rank_4_dim_3" : lambda : ((torch .rand ([2 , 2 , 2 , 2 ]),), 3 ),
5989 }
6090
91+ test_data_fp16 : Dict [str , input_t ] = {
92+ "rank_1_dim_0_fp16" : lambda : ((torch .rand ([10 ], dtype = torch .float16 ),), 0 ),
93+ "rank_2_dim_1_fp16" : lambda : ((torch .rand ([2 , 2 ], dtype = torch .float16 ),), 1 ),
94+ }
95+
96+ test_data_bf16 : Dict [str , input_t ] = {
97+ "rank_1_dim_0_bf16" : lambda : ((torch .rand ([10 ], dtype = torch .bfloat16 ),), 0 ),
98+ "rank_2_dim_1_bf16" : lambda : ((torch .rand ([2 , 2 ], dtype = torch .bfloat16 ),), 1 ),
99+ }
100+
61101
62102class MaxWithIndex (torch .nn .Module ):
103+ input_t = Tuple [Tuple [torch .Tensor ], int ]
104+
63105 def __init__ (self , dim ):
64106 self .dim = dim
65107 super ().__init__ ()
@@ -68,18 +110,29 @@ def forward(self, x):
68110 x , i = torch .max (x , self .dim )
69111 return x , i
70112
113+ test_data : Dict [str , input_t ] = Max .test_data
114+ test_data_fp16 : Dict [str , input_t ] = Max .test_data_fp16
115+ test_data_bf16 : Dict [str , input_t ] = Max .test_data_bf16
71116
72- @common .parametrize ("test_data" , Amax .test_data )
117+
118+ @common .parametrize (
119+ "test_data" , Amax .test_data | Amax .test_data_fp16 | Amax .test_data_bf16
120+ )
73121def test_amax_tosa_FP (test_data : Amax .input_t ):
74122 data , dim , keep_dims = test_data ()
75- pipeline = TosaPipelineFP [Amax .input_t ](Amax (dim , keep_dims ), data , Amax .aten_op )
123+ pipeline = TosaPipelineFP [Amax .input_t ](
124+ Amax (dim , keep_dims ),
125+ data ,
126+ amax_aten_op ,
127+ tosa_extensions = ["bf16" ],
128+ )
76129 pipeline .run ()
77130
78131
79132@common .parametrize ("test_data" , Amax .test_data )
80133def test_amax_tosa_INT (test_data : Amax .input_t ):
81134 data , dim , keep_dims = test_data ()
82- pipeline = TosaPipelineINT [Amax .input_t ](Amax (dim , keep_dims ), data , Amax . aten_op )
135+ pipeline = TosaPipelineINT [Amax .input_t ](Amax (dim , keep_dims ), data , amax_aten_op )
83136 pipeline .run ()
84137
85138
@@ -88,7 +141,7 @@ def test_amax_u55_INT_not_delegated():
88141 pipeline = OpNotSupportedPipeline [Amax .input_t ](
89142 Amax (dim , keep_dims ),
90143 data ,
91- {" executorch_exir_dialects_edge__ops_aten_amax_default" : 1 },
144+ {"executorch_exir_dialects_edge__ops_aten_amax_default" : 1 },
92145 quantize = True ,
93146 u55_subset = True ,
94147 )
@@ -102,23 +155,30 @@ def test_amax_u85_INT(test_data: Amax.input_t):
102155 pipeline = EthosU85PipelineINT [Amax .input_t ](
103156 Amax (dim , keep_dims ),
104157 data ,
105- Amax . aten_op ,
158+ amax_aten_op ,
106159 )
107160 pipeline .run ()
108161
109162
110- @common .parametrize ("test_data" , Max .test_data )
163+ @common .parametrize (
164+ "test_data" , Max .test_data | Max .test_data_fp16 | Max .test_data_bf16
165+ )
111166def test_max_dim_tosa_FP_to_amax (test_data : Max .input_t ):
112167 data , dim = test_data ()
113- pipeline = TosaPipelineFP [Max .input_t ](Max (dim ), data , "torch.ops.aten.max" )
168+ pipeline = TosaPipelineFP [Max .input_t ](
169+ Max (dim ),
170+ data ,
171+ max_aten_op ,
172+ tosa_extensions = ["bf16" ],
173+ )
114174 pipeline .run ()
115175
116176
117177@common .parametrize ("test_data" , Max .test_data )
118178def test_max_dim_tosa_INT_to_amax (test_data : Max .input_t ):
119179 data , dim = test_data ()
120180 module = Max (dim )
121- pipeline = TosaPipelineINT [Max .input_t ](module , data , "torch.ops.aten.amax" )
181+ pipeline = TosaPipelineINT [Max .input_t ](module , data , amax_aten_op )
122182 pipeline .run ()
123183
124184
@@ -137,15 +197,15 @@ def test_max_dim_tosa_FP_not_delegated():
137197 pipeline .run ()
138198
139199
140- @common .parametrize ("test_data" , Amax .test_data )
200+ @common .parametrize ("test_data" , Amax .test_data | Amax . test_data_fp16 )
141201@common .SkipIfNoModelConverter
142202def test_amax_vgf_no_quant (test_data : Amax .input_t ):
143203 data , dim , keep_dims = test_data ()
144204 module = Amax (dim , keep_dims )
145205 pipeline = VgfPipeline [Amax .input_t ](
146206 module ,
147207 data ,
148- Amax . aten_op ,
208+ amax_aten_op ,
149209 quantize = False ,
150210 )
151211 pipeline .run ()
@@ -159,20 +219,20 @@ def test_amax_vgf_quant(test_data: Amax.input_t):
159219 pipeline = VgfPipeline [Amax .input_t ](
160220 module ,
161221 data ,
162- Amax . aten_op ,
222+ amax_aten_op ,
163223 quantize = True ,
164224 )
165225 pipeline .run ()
166226
167227
168- @common .parametrize ("test_data" , Max .test_data )
228+ @common .parametrize ("test_data" , Max .test_data | Max . test_data_fp16 )
169229@common .SkipIfNoModelConverter
170230def test_max_dim_vgf_no_quant_to_amax (test_data : Max .input_t ):
171231 data , dim = test_data ()
172232 pipeline = VgfPipeline [Max .input_t ](
173233 Max (dim ),
174234 data ,
175- "torch.ops.aten.max" ,
235+ max_aten_op ,
176236 quantize = False ,
177237 )
178238 pipeline .run ()
@@ -185,7 +245,7 @@ def test_max_dim_vgf_quant_to_amax(test_data: Max.input_t):
185245 pipeline = VgfPipeline [Max .input_t ](
186246 Max (dim ),
187247 data ,
188- "torch.ops.aten.amax" ,
248+ amax_aten_op ,
189249 quantize = True ,
190250 )
191251 pipeline .run ()
@@ -199,7 +259,7 @@ def test_amax_tosa_INT_a16w8(test_data: Amax.input_t):
199259 pipeline = TosaPipelineINT [Max .input_t ](
200260 module ,
201261 data ,
202- "torch.ops.aten.amax" ,
262+ amax_aten_op ,
203263 tosa_extensions = ["int16" ],
204264 )
205265 pipeline .run ()
@@ -214,7 +274,7 @@ def test_amax_u85_INT_a16w8(test_data: Amax.input_t):
214274 pipeline = EthosU85PipelineINT [Max .input_t ](
215275 module ,
216276 data ,
217- "torch.ops.aten.amax" ,
277+ amax_aten_op ,
218278 a16w8_quantization = True ,
219279 use_to_edge_transform_and_lower = True ,
220280 )
0 commit comments