Skip to content

Commit 39dade2

Browse files
authored
Preserve GraphModule.meta in ExportPass
Differential Revision: D108172756 Pull Request resolved: pytorch#20197
1 parent 129c687 commit 39dade2

2 files changed

Lines changed: 29 additions & 0 deletions

File tree

exir/pass_base.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -692,6 +692,9 @@ def call_submodule(
692692

693693
new_graph_module = torch.fx.GraphModule(self.tracer.root, self.tracer.graph)
694694

695+
# Preserve GraphModule-level metadata from the input module.
696+
new_graph_module.meta = graph_module.meta.copy()
697+
695698
self.tracer = prev_tracer
696699
self.interpreter = prev_interpreter
697700
return PassResult(

exir/tests/test_passes.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -544,6 +544,32 @@ class NullPass(ExportPass):
544544
self.assertEqual(new_node.op, old_node.op)
545545
self.assertEqual(new_node.target, old_node.target)
546546

547+
def test_export_pass_preserves_graph_module_meta(self) -> None:
548+
"""ExportPass should preserve GraphModule-level meta through re-tracing."""
549+
550+
class Foo(torch.nn.Module):
551+
def forward(self, x: torch.Tensor) -> torch.Tensor:
552+
return x + 1
553+
554+
class NullPass(ExportPass):
555+
pass
556+
557+
prog = to_edge(
558+
export(Foo(), (torch.ones(3, 2),), strict=True),
559+
)
560+
# Set custom metadata on the graph module before the pass.
561+
prog.exported_program().graph_module.meta["custom"] = {
562+
"test_key": "test_value",
563+
"nested": {"a": 1},
564+
}
565+
566+
new_prog = prog.transform([NullPass()])
567+
new_meta = new_prog.exported_program().graph_module.meta
568+
569+
self.assertIn("custom", new_meta)
570+
self.assertEqual(new_meta["custom"]["test_key"], "test_value")
571+
self.assertEqual(new_meta["custom"]["nested"]["a"], 1)
572+
547573
def test_export_scalar_to_tensor_pass(self) -> None:
548574
# Build a graph with a scalar argument where schema expects tensor
549575
graph = torch.fx.Graph()

0 commit comments

Comments
 (0)