@@ -116,6 +116,130 @@ def test_to_tosa_FP_bf16_with_extension():
116116 pipeline .run ()
117117
118118
119+ _TO_COPY_TEST_DATA_FP_FP8 = {
120+ "fp32_to_fp8e4m3" : lambda : (
121+ torch .rand ((1 , 2 , 3 , 4 ), dtype = torch .float32 ),
122+ torch .float8_e4m3fn ,
123+ "fp8e4m3" ,
124+ ),
125+ "fp16_to_fp8e5m2" : lambda : (
126+ torch .rand ((1 , 2 , 3 , 4 ), dtype = torch .float16 ),
127+ torch .float8_e5m2 ,
128+ "fp8e5m2" ,
129+ ),
130+ "fp8e4m3_to_fp32" : lambda : (
131+ torch .rand ((1 , 2 , 3 , 4 ), dtype = torch .float32 ).to (torch .float8_e4m3fn ),
132+ torch .float32 ,
133+ "fp8e4m3" ,
134+ ),
135+ "fp8e5m2_to_fp16" : lambda : (
136+ torch .rand ((1 , 2 , 3 , 4 ), dtype = torch .float32 ).to (torch .float8_e5m2 ),
137+ torch .float16 ,
138+ "fp8e5m2" ,
139+ ),
140+ }
141+
142+
143+ def test_to_tosa_FP_fp8e4m3_requires_extension ():
144+ test_tensor = torch .rand ((1 , 2 , 3 , 4 ), dtype = torch .float32 )
145+ pipeline = OpNotSupportedPipeline [input_t1 ](
146+ Cast (torch .float8_e4m3fn ),
147+ (test_tensor ,),
148+ {
149+ "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default" : 1
150+ },
151+ )
152+ pipeline .run ()
153+
154+
155+ def test_to_tosa_FP_fp8e5m2_requires_extension ():
156+ test_tensor = torch .rand ((1 , 2 , 3 , 4 ), dtype = torch .float16 )
157+ pipeline = OpNotSupportedPipeline [input_t1 ](
158+ Cast (torch .float8_e5m2 ),
159+ (test_tensor ,),
160+ {
161+ "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default" : 1
162+ },
163+ )
164+ pipeline .run ()
165+
166+
167+ def test_to_tosa_FP_bf16_to_fp8e4m3_requires_both_extensions ():
168+ test_tensor = torch .rand ((1 , 2 , 3 , 4 ), dtype = torch .bfloat16 )
169+ pipeline = OpNotSupportedPipeline [input_t1 ](
170+ Cast (torch .float8_e4m3fn ),
171+ (test_tensor ,),
172+ {
173+ "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default" : 1
174+ },
175+ tosa_extensions = ["bf16" ],
176+ )
177+ pipeline .run ()
178+
179+
180+ def test_to_tosa_FP_bf16_to_fp8e5m2_requires_both_extensions ():
181+ test_tensor = torch .rand ((1 , 2 , 3 , 4 ), dtype = torch .bfloat16 )
182+ pipeline = OpNotSupportedPipeline [input_t1 ](
183+ Cast (torch .float8_e5m2 ),
184+ (test_tensor ,),
185+ {
186+ "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default" : 1
187+ },
188+ tosa_extensions = ["bf16" ],
189+ )
190+ pipeline .run ()
191+
192+
193+ @common .parametrize ("test_data" , _TO_COPY_TEST_DATA_FP_FP8 )
194+ def test_to_tosa_FP_fp8_with_extension (test_data : Tuple ):
195+ test_tensor , new_dtype , tosa_extension = test_data ()
196+ pipeline = TosaPipelineFP [input_t1 ](
197+ Cast (new_dtype ),
198+ (test_tensor ,),
199+ aten_op = [],
200+ exir_op = [],
201+ tosa_extensions = [tosa_extension ],
202+ )
203+ pipeline .run ()
204+
205+
206+ _TO_COPY_TEST_DATA_BF16_FP8 = {
207+ "bf16_to_fp8e4m3" : lambda : (
208+ torch .rand ((1 , 2 , 3 , 4 ), dtype = torch .bfloat16 ),
209+ torch .float8_e4m3fn ,
210+ ["bf16" , "fp8e4m3" ],
211+ ),
212+ "fp8e4m3_to_bf16" : lambda : (
213+ torch .rand ((1 , 2 , 3 , 4 ), dtype = torch .float32 ).to (torch .float8_e4m3fn ),
214+ torch .bfloat16 ,
215+ ["bf16" , "fp8e4m3" ],
216+ ),
217+ "bf16_to_fp8e5m2" : lambda : (
218+ torch .rand ((1 , 2 , 3 , 4 ), dtype = torch .bfloat16 ),
219+ torch .float8_e5m2 ,
220+ ["bf16" , "fp8e5m2" ],
221+ ),
222+ "fp8e5m2_to_bf16" : lambda : (
223+ torch .rand ((1 , 2 , 3 , 4 ), dtype = torch .float32 ).to (torch .float8_e5m2 ),
224+ torch .bfloat16 ,
225+ ["bf16" , "fp8e5m2" ],
226+ ),
227+ }
228+
229+
230+ @common .parametrize ("test_data" , _TO_COPY_TEST_DATA_BF16_FP8 )
231+ def test_to_tosa_FP_bf16_fp8_with_extensions (test_data : Tuple ):
232+ test_tensor , new_dtype , tosa_extensions = test_data ()
233+ pipeline = TosaPipelineFP [input_t1 ](
234+ Cast (new_dtype ),
235+ (test_tensor ,),
236+ aten_op = [],
237+ exir_op = [],
238+ tosa_extensions = tosa_extensions ,
239+ )
240+ pipeline .run ()
241+
242+
119243@common .parametrize ("test_data" , _TO_COPY_TEST_DATA_FP )
120244@common .SkipIfNoModelConverter
121245def test_to_vgf_no_quant (test_data : Tuple ):
0 commit comments