File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change 3030 TosaSpecification ,
3131 TosaSpecMapping ,
3232)
33- from torch .export import ExportedProgram
3433
3534logger = logging .getLogger (__name__ )
3635
@@ -39,7 +38,6 @@ class NodeVisitor:
3938 """Provide a visitor pattern to lower edge IR to TOSA.
4039
4140 Attributes:
42- _exported_program (torch.export.ExportedProgram): Source program being lowered.
4341 tosa_spec (TosaSpecification): Active TOSA specification for lowering.
4442 debug_hook (Optional[DebugHook]): Optional hook for debug metadata.
4543
@@ -54,11 +52,9 @@ class NodeVisitor:
5452
5553 def __init__ (
5654 self ,
57- exported_program : ExportedProgram ,
5855 tosa_spec : TosaSpecification ,
5956 debug_hook : Optional [DebugHook ] = None ,
6057 ):
61- self ._exported_program = exported_program
6258 self .tosa_spec = tosa_spec
6359 self .debug_hook = debug_hook
6460
Original file line number Diff line number Diff line change @@ -52,21 +52,17 @@ def define_node(
5252 # The name of the table constant is a bit complex.
5353 # The name of the pytorch buffer will be the target of last node argument.
5454 # However, when it is serialized to TOSA, a submodule suffix might be added. The TOSA buffer name thus
55- # needs to be taken from the last TosaArg.
56- pytorch_table_buffer_name = node .args [- 1 ].target # type: ignore[union-attr]
57- tosa_table_buffer_name = inputs [- 1 ].name
58- if pytorch_table_buffer_name not in self ._exported_program .state_dict .keys ():
59- raise RuntimeError (
60- f"Did not find key { node .name } in state_dict { self ._exported_program .state_dict .keys ()} ."
61- )
55+ # needs to be taken from the buffer TosaArg.
56+ input , table_buffer = inputs
57+ tosa_table_buffer_name = table_buffer .name
6258
6359 attr = ts .TosaSerializerAttribute ()
6460 attr .TableAttribute ()
6561 self ._serialize_operator (
6662 node ,
6763 tosa_graph ,
6864 ts .Op .TABLE ,
69- [inputs [ 0 ] .name , tosa_table_buffer_name ],
65+ [input .name , tosa_table_buffer_name ],
7066 [output .name ],
7167 attr ,
7268 )
Original file line number Diff line number Diff line change @@ -340,7 +340,7 @@ def _preprocess_module( # noqa: C901
340340 # TODO: Fix the need to lazily import this.
341341 from executorch .backends .arm .operators .node_visitor import get_node_visitors
342342
343- node_visitors = get_node_visitors (edge_program , tosa_spec , debug_hook )
343+ node_visitors = get_node_visitors (tosa_spec , debug_hook )
344344
345345 if output_order_workaround :
346346 logger .debug ("Re-sorting outputs during TOSA lowering." )
You can’t perform that action at this time.
0 commit comments