
Commit eab5f4b

NefAI and cursoragent committed
fix: ensure output/out-var node specs before memory planning
- Add ensure_graph_node_specs() to set meta['spec'] from meta['val'] when it is missing (e.g. after delegation), so memory planning and the verifier succeed.
- Call it at the start of MemoryPlanningPass.run() for all graph modules.

Co-authored-by: Cursor <cursoragent@cursor.com>
1 parent 25bb325 · commit eab5f4b

2 files changed: 25 additions & 1 deletion


exir/memory_planning.py

Lines changed: 21 additions & 1 deletion
```diff
@@ -41,7 +41,7 @@
     InputKind,
 )
 from torch.fx import Node
-from torch.utils._pytree import tree_flatten
+from torch.utils._pytree import tree_flatten, tree_map

 REGISTERED_ALGOS: Dict[str, Callable[..., List[int]]] = {}

@@ -755,6 +755,26 @@ def get_node_tensor_specs(
     ]


+def ensure_graph_node_specs(graph_module: torch.fx.GraphModule) -> None:
+    """
+    Set meta["spec"] from meta["val"] for nodes that are missing spec (e.g. output
+    or out-var nodes in delegated graphs that were built after SpecPropPass).
+    """
+    for node in graph_module.graph.nodes:
+        if "spec" in node.meta:
+            continue
+        if "val" not in node.meta:
+            continue
+        val = node.meta["val"]
+
+        def to_spec(x: Any) -> Any:
+            if isinstance(x, torch.Tensor):
+                return TensorSpec.from_tensor(x)
+            return x
+
+        node.meta["spec"] = tree_map(to_spec, val)
+
+
 # Little bit hacky to check if the graph contains
 # XNNPACK delegate
 # Why?
```
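
For context, here is a minimal sketch of how the new helper could be exercised on a small torch.fx graph. It assumes an executorch checkout that includes this commit; the `FakeTensorProp`-based setup is an assumption used only to populate `node.meta["val"]`, standing in for what delegation leaves behind, and is not part of the diff:

```python
# Hypothetical usage sketch, not part of the commit. Assumes executorch with
# this change is installed. FakeTensorProp fills node.meta["val"] for every
# node; ensure_graph_node_specs then derives the missing meta["spec"] entries.
import torch
import torch.fx as fx
from torch.fx.passes.fake_tensor_prop import FakeTensorProp

from executorch.exir.memory_planning import ensure_graph_node_specs


def f(x):
    return x * 2


gm = fx.symbolic_trace(f)
FakeTensorProp(gm).propagate(torch.randn(4))  # sets node.meta["val"]

# Nodes now carry meta["val"] but no meta["spec"]; the helper fills the gap.
ensure_graph_node_specs(gm)
for node in gm.graph.nodes:
    print(node.name, node.meta.get("spec"))
```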

exir/passes/memory_planning_pass.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -19,6 +19,7 @@
     _is_out_var_node,
     apply_algo,
     collect_specs_from_nodes,
+    ensure_graph_node_specs,
     filter_nodes,
     get_node_tensor_specs,
     MemoryPlanningAlgorithmSuite,
@@ -233,6 +234,9 @@ def run(
         A pass for memory planning. The actual algorithm used will be picked by
         memory_planning_algo
         """
+        for subgm in graph_module.modules():
+            if isinstance(subgm, torch.fx.GraphModule):
+                ensure_graph_node_specs(subgm)
         self._set_alloc_node_spec(graph_module)
         # TODO(shunting) if people have concern of adding a field to GraphModule
         # directly, we should define a GraphModule subclass that we can add our
```
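
The pass iterates `graph_module.modules()` rather than only the top-level graph so that nested GraphModules (e.g. delegated subgraphs) are covered too. A minimal, executorch-free sketch of that traversal pattern follows; the `delegate_payload` name is made up for illustration:

```python
# Sketch of the traversal the pass now performs; plain torch.fx, no executorch.
# nn.Module.modules() yields the module itself plus every registered submodule,
# so nested GraphModules are visited along with the root.
import torch
import torch.fx as fx


def f(x):
    return x + 1


outer = fx.symbolic_trace(f)
inner = fx.symbolic_trace(f)
outer.add_module("delegate_payload", inner)  # stand-in for a delegated subgraph

graph_modules = [m for m in outer.modules() if isinstance(m, fx.GraphModule)]
assert len(graph_modules) == 2  # the root module and the nested one
```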
