Skip to content

Commit 3924a22

Browse files
authored
Arm backend: remove exported_program arg from NodeVisitor (#19370)
It was only used for a single check in op_tosa_table. Node visitors should only require local knowledge. cc @digantdesai @freddan80 @per @zingo @oscarandersson8218 @mansnils @Sebastian-Larsson @robell Signed-off-by: Erik Lundell <erik.lundell@arm.com>
1 parent 960c492 commit 3924a22

3 files changed

Lines changed: 5 additions & 13 deletions

File tree

backends/arm/operators/node_visitor.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
TosaSpecification,
3131
TosaSpecMapping,
3232
)
33-
from torch.export import ExportedProgram
3433

3534
logger = logging.getLogger(__name__)
3635

@@ -39,7 +38,6 @@ class NodeVisitor:
3938
"""Provide a visitor pattern to lower edge IR to TOSA.
4039
4140
Attributes:
42-
_exported_program (torch.export.ExportedProgram): Source program being lowered.
4341
tosa_spec (TosaSpecification): Active TOSA specification for lowering.
4442
debug_hook (Optional[DebugHook]): Optional hook for debug metadata.
4543
@@ -54,11 +52,9 @@ class NodeVisitor:
5452

5553
def __init__(
5654
self,
57-
exported_program: ExportedProgram,
5855
tosa_spec: TosaSpecification,
5956
debug_hook: Optional[DebugHook] = None,
6057
):
61-
self._exported_program = exported_program
6258
self.tosa_spec = tosa_spec
6359
self.debug_hook = debug_hook
6460

backends/arm/operators/op_tosa_table.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -52,21 +52,17 @@ def define_node(
5252
# The name of the table constant is a bit complex.
5353
# The name of the pytorch buffer will be the target of last node argument.
5454
# However, when it is serialized to TOSA, a submodule suffix might be added. The TOSA buffer name thus
55-
# needs to be taken from the last TosaArg.
56-
pytorch_table_buffer_name = node.args[-1].target # type: ignore[union-attr]
57-
tosa_table_buffer_name = inputs[-1].name
58-
if pytorch_table_buffer_name not in self._exported_program.state_dict.keys():
59-
raise RuntimeError(
60-
f"Did not find key {node.name} in state_dict {self._exported_program.state_dict.keys()}."
61-
)
55+
# needs to be taken from the buffer TosaArg.
56+
input, table_buffer = inputs
57+
tosa_table_buffer_name = table_buffer.name
6258

6359
attr = ts.TosaSerializerAttribute()
6460
attr.TableAttribute()
6561
self._serialize_operator(
6662
node,
6763
tosa_graph,
6864
ts.Op.TABLE,
69-
[inputs[0].name, tosa_table_buffer_name],
65+
[input.name, tosa_table_buffer_name],
7066
[output.name],
7167
attr,
7268
)

backends/arm/tosa/backend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,7 @@ def _preprocess_module( # noqa: C901
340340
# TODO: Fix the need to lazily import this.
341341
from executorch.backends.arm.operators.node_visitor import get_node_visitors
342342

343-
node_visitors = get_node_visitors(edge_program, tosa_spec, debug_hook)
343+
node_visitors = get_node_visitors(tosa_spec, debug_hook)
344344

345345
if output_order_workaround:
346346
logger.debug("Re-sorting outputs during TOSA lowering.")

0 commit comments

Comments
 (0)