66import numpy as np
77import pytest
88import torch
9-
109from executorch .backends .nxp .backend .edge_program_converter import (
1110 EdgeProgramToIRConverter ,
1211)
13- from executorch .backends .nxp .tests .executorch_pipeline import to_quantized_edge_program
12+ from executorch .backends .nxp .backend .ir .converter .builder .aten_model_builder_director import (
13+ AtenModelBuilderDirector ,
14+ )
15+ from executorch .backends .nxp .backend .ir .lib .tflite .BuiltinOperator import (
16+ BuiltinOperator as Ops ,
17+ )
18+ from executorch .backends .nxp .tests .executorch_pipeline import (
19+ ModelInputSpec ,
20+ to_quantized_edge_program ,
21+ )
1422from executorch .backends .nxp .tests .executors import (
1523 convert_run_compare ,
1624 graph_contains_any_of_ops ,
1725)
18- from executorch .exir .dialects ._ops import ops as exir_ops
26+ from executorch .backends .nxp .tests .graph_verifier import DetailedGraphVerifier
27+ from executorch .backends .nxp .tests .model_output_comparator import (
28+ NumericalStatsOutputComparator ,
29+ )
30+ from executorch .backends .nxp .tests .nsys_testing import lower_run_compare
31+ from executorch .backends .nxp .tests .ops_aliases import (
32+ AddTensor ,
33+ Clamp ,
34+ ExecutorchDelegateCall ,
35+ )
36+ from executorch .backends .nxp .tests .use_qat import * # noqa: F403
1937
2038
2139@pytest .fixture (autouse = True )
@@ -24,11 +42,6 @@ def reseed_model_per_test_run():
2442 np .random .seed (23 )
2543
2644
27- # noinspection PyProtectedMember
28- ExecutorchDelegateCall = torch .ops .higher_order .executorch_call_delegate
29- Clamp = exir_ops .edge .aten .clamp .default
30-
31-
3245class ClampModule (torch .nn .Module ):
3346
3447 # noinspection PyShadowingBuiltins
@@ -180,3 +193,119 @@ def test_convert_clamp__no_delegation__unsupported_bounds(min, max):
180193
181194 # Make sure the `clamp` was NOT delegated.
182195 assert graph_contains_any_of_ops (delegated_ep .graph , [Clamp ])
196+
197+
198+ class TestClampNewNeutronFlow :
199+ @pytest .mark .parametrize (
200+ "min, max" ,
201+ [
202+ pytest .param (- 1 , 2 , id = "min = -1, max = 2 (Max/Min)" ),
203+ pytest .param (None , 1 , id = "min = None, max = 1 (Max/Min)" ),
204+ pytest .param (1 , None , id = "min = 1, max = None (Max/Min)" ),
205+ pytest .param (0 , 2 , id = "min = 0, max = 2 (Max/Min)" ),
206+ pytest .param (0 , 1 , id = "min = 0, max = 1 (Relu0To1)" ),
207+ pytest .param (- 1 , 1 , id = "min = -1, max = 1 (ReluN1To1)" ),
208+ pytest .param (0 , None , id = "min = 0, max = None (Relu)" ),
209+ # # Float bounds
210+ pytest .param (- 1.0 , 2.0 , id = "min = -1.0, max = 2.0 (Max/Min)" ),
211+ pytest .param (None , 1.0 , id = "min = None, max = 1.0 (Max/Min)" ),
212+ pytest .param (1.0 , None , id = "min = 1.0, max = None (Max/Min)" ),
213+ pytest .param (1.0 , float ("inf" ), id = "min = 1.0, max = infinity (Max/Min)" ),
214+ pytest .param (- float ("inf" ), 1.0 , id = "min = infinity, max = 1.0 (Max/Min)" ),
215+ pytest .param (0.1 , 0.5 , id = "min = 0.1, max = 0.5 (Max/Min)" ),
216+ pytest .param (0.0 , 1.0 , id = "min = 0.0, max = 1.0 (Relu0To1)" ),
217+ pytest .param (- 1.0 , 1.0 , id = "min = -1.0, max = 1.0 (ReluN1To1)" ),
218+ pytest .param (0.0 , None , id = "min = 0, max = None (Relu)" ),
219+ ],
220+ )
221+ def test_convert_clamp__full_pipeline (self , mocker , min , max , use_qat ):
222+ input_shape = (2 , 7 , 2 ) # Indivisible by num_macs
223+ model = AddClampModule (min , max )
224+
225+ x_input_spec = ModelInputSpec (input_shape )
226+ comparator = NumericalStatsOutputComparator ()
227+ graph_verifier = DetailedGraphVerifier (
228+ mocker ,
229+ expected_delegated_ops = {
230+ AddTensor : 1 ,
231+ Clamp : 1 ,
232+ },
233+ expected_non_delegated_ops = {},
234+ )
235+
236+ lower_run_compare (
237+ model = model ,
238+ input_spec = [x_input_spec ],
239+ dlg_model_verifier = graph_verifier ,
240+ output_comparator = comparator ,
241+ use_new_flow_neutron_c = True ,
242+ use_qat = use_qat ,
243+ )
244+
245+ # noinspection PyShadowingBuiltins
246+ @pytest .mark .parametrize (
247+ "min, max, expected_tflite_ops" ,
248+ [
249+ pytest .param (
250+ 0.1 ,
251+ 0.5 ,
252+ [Ops .ADD , Ops .MAXIMUM , Ops .MINIMUM ],
253+ id = "min = 0.1, max = 0.5 (Max/Min)" ,
254+ ),
255+ pytest .param (
256+ 0.0 , 1.0 , [Ops .ADD , Ops .RELU_0_TO_1 ], id = "min = 0, max = 1 (Relu0To1)"
257+ ),
258+ pytest .param (
259+ - 1.0 ,
260+ 1.0 ,
261+ [Ops .ADD , Ops .RELU_N1_TO_1 ],
262+ id = "min = -1, max = 1 (ReluN1To1)" ,
263+ ),
264+ pytest .param (
265+ 0.0 , None , [Ops .ADD , Ops .RELU ], id = "min = 0, max = None (Relu)"
266+ ),
267+ pytest .param (
268+ 0.0 ,
269+ float ("inf" ),
270+ [Ops .ADD , Ops .RELU ],
271+ id = "min = 0, max = infinity (Relu)" ,
272+ ),
273+ ],
274+ )
275+ def test_convert_clamp__relu_vs_maxmin (self , mocker , min , max , expected_tflite_ops ):
276+ input_shape = (23 ,)
277+ model = AddClampModule (min , max )
278+
279+ converter_spy = mocker .spy (EdgeProgramToIRConverter , "convert_program" )
280+ tflite_spy = mocker .spy (AtenModelBuilderDirector , "finish" )
281+
282+ delegated_ep = to_quantized_edge_program (
283+ model ,
284+ input_shape ,
285+ use_new_flow_neutron_c = True ,
286+ ).exported_program ()
287+
288+ # Make sure the `clamp` was delegated.
289+ assert graph_contains_any_of_ops (delegated_ep .graph , [ExecutorchDelegateCall ])
290+ assert not graph_contains_any_of_ops (delegated_ep .graph , [Clamp ])
291+
292+ intermediate_ep = converter_spy .call_args .args [1 ]
293+ quant_node = list (intermediate_ep .graph .nodes )[- 2 ]
294+ dequant_node = list (intermediate_ep .graph .nodes )[- 4 ]
295+ tflite_internal_ops = list (
296+ op .builtin_code for op in tflite_spy .spy_return .operator_codes .vector
297+ )
298+
299+ assert graph_contains_any_of_ops (intermediate_ep .graph , [Clamp ])
300+ assert len (tflite_internal_ops ) == len (expected_tflite_ops ) + 1 # Transpose
301+ assert all (op in tflite_internal_ops for op in expected_tflite_ops )
302+
303+ if len (expected_tflite_ops ) == 3 :
304+ # Min/Max variant should have same input and output quantization
305+ assert all (
306+ q == dq for q , dq in zip (quant_node .args [1 :], dequant_node .args [1 :])
307+ )
308+ else :
309+ assert not all (
310+ q == dq for q , dq in zip (quant_node .args [1 :], dequant_node .args [1 :])
311+ )
0 commit comments