|
7 | 7 |
|
8 | 8 | import logging |
9 | 9 | import os |
10 | | -from typing import Any, Tuple |
| 10 | +from typing import Any, Optional, Tuple |
11 | 11 |
|
12 | 12 | import serializer.tosa_serializer as ts # type: ignore |
13 | 13 | import torch |
14 | 14 | from executorch.backends.arm.tosa_mapping import TosaArg |
15 | 15 |
|
16 | 16 | from executorch.exir.dialects._ops import ops as exir_ops |
| 17 | +from executorch.exir.print_program import inspect_node |
17 | 18 | from serializer.tosa_serializer import TosaOp |
18 | 19 | from torch.fx import Node |
19 | 20 |
|
20 | 21 | logger = logging.getLogger(__name__) |
21 | | -logger.setLevel(logging.WARNING) |
22 | 22 | TOSA_DBG_VERBOSE = os.environ.get("TOSA_DBG_VERBOSE") == "1" |
23 | 23 | if TOSA_DBG_VERBOSE: |
24 | 24 | logging.basicConfig(level=logging.INFO) |
25 | 25 | logger.setLevel(logging.INFO) |
26 | 26 |
|
27 | 27 |
|
28 | | -def dbg_node(node: torch.fx.Node): |
| 28 | +def dbg_node(node: torch.fx.Node, graph_module: torch.fx.GraphModule): |
29 | 29 | # Debug output of node information |
30 | | - logger.info(get_node_debug_info(node)) |
| 30 | + logger.info(get_node_debug_info(node, graph_module)) |
31 | 31 |
|
32 | 32 |
|
33 | | -def get_node_debug_info(node: torch.fx.Node) -> str: |
| 33 | +def get_node_debug_info(node: torch.fx.Node, graph_module: torch.fx.GraphModule) -> str: |
34 | 34 | output = ( |
| 35 | + f" {inspect_node(graph=graph_module.graph, node=node)}\n" |
35 | 36 | "-- NODE DEBUG INFO --\n" |
36 | 37 | f" Op is {node.op}\n" |
37 | 38 | f" Name is {node.name}\n" |
@@ -71,21 +72,24 @@ def dbg_tosa_dump(tosa_graph: ts.TosaSerializer, path: str, suffix: str = ""): |
71 | 72 | assert os.path.exists(filepath_desc_json), "Failed to write TOSA JSON" |
72 | 73 |
|
73 | 74 |
|
74 | | -def dbg_fail(node, tosa_graph, path): |
75 | | - dbg_tosa_dump(tosa_graph, path) |
| 75 | +def dbg_fail( |
| 76 | + node, |
| 77 | + graph_module, |
| 78 | + tosa_graph: Optional[ts.TosaSerializer] = None, |
| 79 | + path: Optional[str] = None, |
| 80 | +): |
76 | 81 | logger.warning("Internal error due to poorly handled node:") |
77 | | - dbg_node(node) |
78 | | - logger.warning(f"Debug output captured in '{path}'.") |
79 | | - raise RuntimeError("TOSA Internal Error on node, enable logging for further info.") |
| 82 | + if tosa_graph is not None and path is not None: |
| 83 | + dbg_tosa_dump(tosa_graph, path) |
| 84 | + logger.warning(f"Debug output captured in '{path}'.") |
| 85 | + dbg_node(node, graph_module) |
80 | 86 |
|
81 | 87 |
|
82 | 88 | def getNodeArgs(node: Node) -> list[TosaArg]: |
83 | 89 | try: |
84 | 90 | return [TosaArg(arg) for arg in node.args] |
85 | 91 | except ValueError as e: |
86 | | - raise ValueError( |
87 | | - f"Failed processing args to op:\n{get_node_debug_info(node)}" |
88 | | - ) from e |
| 92 | + raise ValueError(f"Failed processing args to op:\n{node}") from e |
89 | 93 |
|
90 | 94 |
|
91 | 95 | def get_output_node(node: Node) -> Node: |
|
0 commit comments