Skip to content

Commit 3ff6579

Browse files
Add get_delegated_payload utility function (pytorch#11651) (pytorch#16852)
## Summary

Add `get_delegated_payload()` utility function to extract delegate payloads from a graph module. The function returns a dictionary mapping delegate names to tuples of (backend_id, compile_specs, processed_bytes). Fixes pytorch#11651

## Test plan

Added 3 unit tests:

```bash
python -m unittest exir.backend.test.test_utils.TestUtils.test_get_delegated_payload_with_delegates -v
python -m unittest exir.backend.test.test_utils.TestUtils.test_get_delegated_payload_without_delegates -v
python -m unittest exir.backend.test.test_utils.TestUtils.test_get_delegated_payload_keys_match_delegates -v
```
1 parent 2c9f2b3 commit 3ff6579

2 files changed

Lines changed: 179 additions & 1 deletion

File tree

exir/backend/test/test_utils.py

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from executorch.exir.backend.test.op_partitioner_demo import AddMulPartitionerDemo
1515
from executorch.exir.backend.utils import (
1616
format_delegated_graph,
17+
get_delegated_payload,
1718
get_delegates,
1819
get_non_lowered_nodes,
1920
is_identical_graph,
@@ -262,3 +263,130 @@ def forward(self, a, x, b):
262263
graph_str,
263264
"Expect to see the aten.mm in the delegated graph",
264265
)
266+
267+
def test_get_delegated_payload_with_delegates(self):
    """Check that get_delegated_payload exposes a well-formed payload for every delegate."""

    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, a, x, b):
            y = torch.mm(a, x)
            z = y + b
            a = z - a
            y = torch.mm(a, x)
            z = y + b
            return z

    model = Model()
    example_inputs = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2))

    # Lower the mm/add portions of the graph to the demo backend.
    edge_program = to_edge(export(model, example_inputs, strict=True))
    edge_program = edge_program.to_backend(AddMulPartitionerDemo())

    payloads = get_delegated_payload(edge_program.exported_program().graph_module)

    # Partitioning produces 2 delegates: (mm + add) -> sub -> (mm + add)
    self.assertEqual(len(payloads), 2)

    # Every delegate entry must carry a name, backend id, spec list and blob.
    for name, (backend_id, compile_specs, processed_bytes) in payloads.items():
        self.assertTrue(
            name.startswith("lowered_module_"),
            f"Delegate name should start with 'lowered_module_', got {name}",
        )

        self.assertEqual(
            backend_id,
            "BackendWithCompilerDemo",
            f"Expected backend_id 'BackendWithCompilerDemo', got {backend_id}",
        )

        self.assertIsInstance(
            compile_specs,
            list,
            f"compile_specs should be a list, got {type(compile_specs)}",
        )

        self.assertIsInstance(
            processed_bytes,
            bytes,
            f"processed_bytes should be bytes, got {type(processed_bytes)}",
        )

        # The backend's preprocess step must have produced a non-empty blob.
        self.assertGreater(
            len(processed_bytes),
            0,
            "processed_bytes should not be empty",
        )
329+
330+
def test_get_delegated_payload_without_delegates(self):
    """Check that get_delegated_payload yields an empty dict for an un-lowered graph."""

    class SimpleModel(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x):
            return x + 1

    model = SimpleModel()
    example_inputs = (torch.randn(2, 2),)

    # No to_backend() call, so the edge program contains no lowered modules.
    edge_program = to_edge(export(model, example_inputs, strict=True))

    payloads = get_delegated_payload(edge_program.exported_program().graph_module)

    self.assertEqual(
        len(payloads),
        0,
        "Expected empty payload dict when no delegates present",
    )
354+
355+
def test_get_delegated_payload_keys_match_delegates(self):
    """Check that get_delegated_payload's keys agree with get_delegates' node names."""

    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, a, x, b):
            y = torch.mm(a, x)
            z = y + b
            a = z - a
            y = torch.mm(a, x)
            z = y + b
            return z

    model = Model()
    example_inputs = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2))

    edge_program = to_edge(export(model, example_inputs, strict=True))
    edge_program = edge_program.to_backend(AddMulPartitionerDemo())

    graph_module = edge_program.exported_program().graph_module

    # Collect the delegate node names reported by the pre-existing utility.
    delegate_names = {node.name for node in get_delegates(graph_module.graph)}

    # Collect the keys produced by the new utility under test.
    payload_names = set(get_delegated_payload(graph_module).keys())

    # Both utilities must describe exactly the same set of delegates.
    self.assertEqual(
        delegate_names,
        payload_names,
        f"Delegate names mismatch: {delegate_names} vs {payload_names}",
    )

exir/backend/utils.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import operator
1212
from collections import defaultdict, OrderedDict
1313
from functools import lru_cache
14-
from typing import Dict, Iterable, List, Optional, Set, Tuple, Union
14+
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
1515

1616
import torch
1717
from executorch.exir.backend.backend_details import ExportedProgram
@@ -659,3 +659,53 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
659659
if not is_supported:
660660
self.reporter.report_reject(node, self.message)
661661
return is_supported
662+
663+
664+
def get_delegated_payload(
    graph_module: torch.fx.GraphModule,
) -> Dict[str, Tuple[str, List[Any], bytes]]:
    """
    Extracts the payload for delegates from a graph module that has been lowered
    to one or more backends.

    This function iterates through all lowered modules (delegates) in the graph
    and returns a dictionary mapping each delegate's name to a tuple containing:
    - backend_id: The name/identifier of the backend
    - compile_specs: A list of backend-specific compilation specifications
    - processed_bytes: The delegate blob created by the backend's preprocess method

    Args:
        graph_module: A torch.fx.GraphModule that may contain lowered backend modules.
            This is typically obtained from an EdgeProgramManager or ExecutorchProgram
            via `.exported_program().graph_module`.

    Returns:
        Dict[str, Tuple[str, List[Any], bytes]]: A dictionary where:
            - Keys are the delegate names (e.g., "lowered_module_0", "lowered_module_1")
            - Values are tuples of (backend_id, compile_specs, processed_bytes)
        The dictionary is empty when the graph contains no delegates.

    Example:
        >>> edge = to_edge(export(model, inputs))
        >>> edge = edge.to_backend(MyPartitioner())
        >>> payloads = get_delegated_payload(edge.exported_program().graph_module)
        >>> for name, (backend_id, specs, data) in payloads.items():
        ...     print(f"{name}: backend={backend_id}, data_size={len(data)}")
    """
    # Function-scope import — presumably to avoid a circular import at module
    # load time; keep it local.
    from executorch.exir.lowered_backend_module import LoweredBackendModule

    delegate_payloads: Dict[str, Tuple[str, List[Any], bytes]] = {}

    # Find all lowered modules in the graph.
    for node in graph_module.graph.nodes:
        if node.op == "get_attr" and node.name.startswith("lowered_module_"):
            # For get_attr nodes the attribute path is node.target, which can
            # differ from node.name when FX has deduplicated/renamed the node.
            # Look the submodule up by target, but key the result by node.name
            # so the keys match what get_delegates() reports.
            lowered_module = getattr(graph_module, node.target, None)
            # isinstance(None, ...) is False, so this also filters the
            # missing-attribute case.
            if isinstance(lowered_module, LoweredBackendModule):
                delegate_payloads[node.name] = (
                    lowered_module.backend_id,
                    lowered_module.compile_specs,
                    lowered_module.processed_bytes,
                )

    return delegate_payloads

0 commit comments

Comments
 (0)