1010
1111import torch
1212import torch .nn .functional as F
13- from executorch import exir
14-
1513from executorch .backends .xnnpack .partition .xnnpack_partitioner import (
1614 XnnpackDynamicallyQuantizedPartitioner ,
1715 XnnpackPartitioner ,
2523 get_xnnpack_edge_compile_config ,
2624 get_xnnpack_executorch_backend_config ,
2725)
28- from executorch .backends .xnnpack .utils .utils import capture_graph_for_xnnpack
2926
3027# import the xnnpack backend implementation
3128from executorch .backends .xnnpack .xnnpack_preprocess import XnnpackBackend
3532from executorch .devtools .bundled_program .serialize import (
3633 serialize_from_bundled_program_to_flatbuffer ,
3734)
38- from executorch .exir import ExecutorchProgram , ExirExportedProgram
35+ from executorch .exir import EdgeProgramManager , to_edge
3936from executorch .exir .backend .backend_api import to_backend , validation_disabled
4037
4138from executorch .exir .passes .spec_prop_pass import SpecPropPass
@@ -157,6 +154,14 @@ def assert_outputs_equal(self, model_output, ref_output):
157154 torch .allclose (model_output [0 ], ref_output , atol = 1e-03 , rtol = 1e-03 )
158155 )
159156
157+ def _capture_graph_for_xnnpack (
158+ self , module : torch .nn .Module , sample_inputs : Tuple [torch .Tensor ]
159+ ) -> EdgeProgramManager :
160+ return to_edge (
161+ export (module , sample_inputs , strict = True ),
162+ compile_config = get_xnnpack_edge_compile_config (),
163+ ).transform (* get_transform_passes ())
164+
160165 def lower_module_and_test_output (
161166 self ,
162167 module : Any ,
@@ -167,15 +172,15 @@ def lower_module_and_test_output(
167172 # TODO: remove this after we migrate to use long term flow
168173 quantizer_api_test : bool = False ,
169174 dump_bundled_program : bool = False , # for debugging, dump the generated bundled program file
170- ) -> ExirExportedProgram :
175+ ) -> EdgeProgramManager :
171176 """
172177 Helper testing function that takes a torch.nn.Module and lowers it to XNNPACK with
173178 the given sample inputs. It then runs the lowered module and compares its
174179 outputs with the outputs of the eager module.
175180 """
176181
177182 if quantizer_api_test :
178- assert isinstance (module , ExirExportedProgram )
183+ assert isinstance (module , EdgeProgramManager )
179184 edge_program = module
180185 else :
181186
@@ -187,7 +192,9 @@ def __init__(self):
187192 def forward (self , * args ):
188193 return self .one_module (* args )
189194
190- edge_program = capture_graph_for_xnnpack (WrappedModule (), sample_inputs )
195+ edge_program = self ._capture_graph_for_xnnpack (
196+ WrappedModule (), sample_inputs
197+ )
191198
192199 partitioner = None
193200 if quantized :
@@ -201,35 +208,32 @@ def forward(self, *args):
201208 if use_partitioner :
202209 with validation_disabled ():
203210 delegated_program = edge_program
204- delegated_program .exported_program = to_backend (
205- edge_program .exported_program , partitioner
211+ delegated_program ._edge_programs [ "forward" ] = to_backend (
212+ edge_program .exported_program () , partitioner
206213 )
207214
208- executorch_program : ExecutorchProgram = delegated_program .to_executorch (
215+ executorch_program = delegated_program .to_executorch (
209216 get_xnnpack_executorch_backend_config ([SpecPropPass ()]),
210217 )
211218 else :
212- delegated_program = to_backend (
213- "XnnpackBackend" , edge_program .exported_program , []
219+ delegated_module = to_backend (
220+ "XnnpackBackend" , edge_program .exported_program () , []
214221 )
215222
216- exported_program : ExirExportedProgram = capture_graph_for_xnnpack (
217- delegated_program , sample_inputs
223+ exported_program = self . _capture_graph_for_xnnpack (
224+ delegated_module , sample_inputs
218225 )
219- executorch_program : ExecutorchProgram = exported_program .to_executorch (
226+ executorch_program = exported_program .to_executorch (
220227 get_xnnpack_executorch_backend_config (),
221228 )
222229
223- # print("Graph Module with delegate:")
224- # delegated_module.print_readable()
225-
226230 # Assert the backend name is xnnpack
227231 self .assertEqual (
228- executorch_program .program .execution_plan [0 ].delegates [0 ].id ,
232+ executorch_program .executorch_program .execution_plan [0 ].delegates [0 ].id ,
229233 XnnpackBackend .__name__ ,
230234 )
231235
232- ref_output = delegated_program (* sample_inputs )
236+ ref_output = delegated_program . exported_program (). module () (* sample_inputs )
233237 if dump_bundled_program :
234238 save_bundled_program (
235239 representative_inputs = sample_inputs ,
@@ -325,14 +329,9 @@ def quantize_and_test_model_with_quantizer(
325329 prepared = prepare_pt2e (m , quantizer )
326330 converted = convert_pt2e (prepared )
327331
328- captured_program = exir .capture (
329- converted ,
330- example_inputs ,
331- config = exir .CaptureConfig (enable_aot = True , _unlift = True ),
332- )
333-
334- edge_program = captured_program .to_edge (
335- get_xnnpack_edge_compile_config ()
332+ edge_program = to_edge (
333+ export (converted , example_inputs , strict = True ),
334+ compile_config = get_xnnpack_edge_compile_config (),
336335 ).transform (* get_transform_passes ())
337336 delegated_module = self .lower_module_and_test_output (
338337 module = edge_program ,
@@ -350,7 +349,7 @@ def quantize_and_test_model_with_quantizer(
350349 }
351350 for op in supported_ops :
352351 FileCheck ().check_count (op , 0 , exactly = True ).run (
353- delegated_module .exported_program .graph_module .code
352+ delegated_module .exported_program () .graph_module .code
354353 )
355354
356355 def _test_xnnpack_dqlinear (
@@ -398,12 +397,14 @@ def _test_xnnpack_dqlinear(
398397 prepared_linear ,
399398 )
400399
401- captured_dqlinear = capture_graph_for_xnnpack (converted_linear , example_inputs )
400+ captured_dqlinear = self ._capture_graph_for_xnnpack (
401+ converted_linear , example_inputs
402+ )
402403
403- captured_dqlinear .exported_program .graph_module .graph .print_tabular ()
404+ captured_dqlinear .exported_program () .graph_module .graph .print_tabular ()
404405
405406 lowered_module = to_backend (
406- "XnnpackBackend" , captured_dqlinear .exported_program , []
407+ "XnnpackBackend" , captured_dqlinear .exported_program () , []
407408 )
408409
409410 class CompositeModule (torch .nn .Module ):
@@ -417,19 +418,19 @@ def forward(self, x):
417418 composite_model = CompositeModule ()
418419 composite_model (* example_inputs )
419420
420- exported_program : ExirExportedProgram = capture_graph_for_xnnpack (
421+ exported_program = self . _capture_graph_for_xnnpack (
421422 composite_model , example_inputs
422423 )
423- executorch_program : ExecutorchProgram = exported_program .to_executorch (
424+ executorch_program = exported_program .to_executorch (
424425 get_xnnpack_executorch_backend_config (),
425426 )
426427
427428 self .assertEqual (
428- executorch_program .program .execution_plan [0 ].delegates [0 ].id ,
429+ executorch_program .executorch_program .execution_plan [0 ].delegates [0 ].id ,
429430 XnnpackBackend .__name__ ,
430431 )
431432
432- ref_output = captured_dqlinear (* example_inputs )
433+ ref_output = captured_dqlinear . exported_program (). module () (* example_inputs )
433434 ref_output = composite_model (* example_inputs )
434435 print ("ref_output:" , ref_output )
435436
0 commit comments