forked from pytorch/executorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpartitioner.py
More file actions
114 lines (91 loc) · 4.76 KB
/
partitioner.py
File metadata and controls
114 lines (91 loc) · 4.76 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from abc import ABC, abstractmethod
from dataclasses import dataclass
from types import MappingProxyType
from typing import Callable, Dict, List, Mapping, NamedTuple, Optional, Tuple, Union
import torch
from executorch.exir.backend.backend_details import enforcedmethod
from executorch.exir.backend.compile_spec_schema import CompileSpec
from torch.export import ExportedProgram
class DelegationSpec(NamedTuple):
    """
    Pairs a backend with the compile specs used when delegating to it.

    Instances are the values of `PartitionResult.partition_tags`: each tag
    placed on a node maps to one DelegationSpec.
    """

    # Identifier of the backend the tagged nodes are delegated to.
    backend_id: str
    # Compile specs used for that backend.
    compile_specs: List[CompileSpec]
@dataclass
class PartitionResult:
    """
    Outcome of running a partitioner over an exported program.

    Attributes:
        tagged_exported_program: the graph in which every node that is meant
            to be delegated carries "DelegationSpec" metadata.
        partition_tags: dictionary that keeps track of each tag and its
            corresponding DelegationSpec. Tags are defined by users and are
            the same values used in node.meta.
    """

    tagged_exported_program: ExportedProgram
    partition_tags: Dict[str, DelegationSpec]
class Partitioner(ABC):
    """
    Callable interface for partitioning an exported program for backend
    delegation.

    An implementation receives an exported program, determines which portions
    of it can be delegated to a certain backend (a single partitioner may
    target multiple backends as well), and returns a PartitionResult holding:
        - the same input module, with specific nodes in the input graph
          tagged for delegation
        - the "partition_tags" that map each tag to its DelegationSpec

    Nodes that are intended to be delegated must be tagged (by setting
    node.meta["delegation_tag"]), and every such tag must be present in the
    `partition_tags` dictionary, mapped to an instance of
    DelegationSpec(backend_id, method_compilation_spec). Each tag must
    represent a distinct submodule that is intended to be lowered, and that
    submodule should be fully contained.

    For details on method_compilation_spec see the to_backend API, as these
    objects follow the same format.

    Calling an instance with an ExportedProgram (in Edge dialect) forwards to
    `partition` and returns its result.

    Args:
        spec: mapping stored on the instance and exposed through the `spec`
            property. This base class only stores it.
            NOTE(review): the meaning of its keys and values is not defined
            in this class - confirm against concrete partitioners.
    """

    def __init__(
        self,
        # The default is an empty read-only proxy, so sharing the one default
        # object across instances cannot leak mutations between them.
        spec: Mapping[Union[str, int, float, bool], object] = MappingProxyType({}),
    ):
        """Store `spec` for later retrieval through the `spec` property."""
        self._spec = spec

    def __call__(self, exported_program: ExportedProgram) -> PartitionResult:
        """Partition `exported_program`; shorthand for calling `partition`."""
        partition_result = self.partition(exported_program)
        return partition_result

    @property
    def spec(self) -> Mapping[Union[str, int, float, bool], object]:
        """The mapping that was passed to the constructor."""
        return self._spec

    @enforcedmethod
    @abstractmethod
    def partition(self, exported_program: ExportedProgram) -> PartitionResult:
        """
        Returns the input exported program with newly created sub-Modules
        encapsulating the portions of the input that are "tagged" for
        delegation.

        Each implementation is free to decide how the existing computation in
        the input exported program is delegated to one, or more than one,
        specific backend.

        The contract is stringent in that:
            * Each node that is intended to be delegated must be tagged
            * The original input exported program (ExportedProgram)
              representation must not change, other than by adding
              sub-Modules that encapsulate existing portions of the input
              exported program, plus the associated metadata for tagging.

        Args:
            exported_program: An ExportedProgram in Edge dialect to be
                partitioned for backend delegation.

        Returns:
            PartitionResult: includes the tagged graph and the delegation
            spec indicating which backend_id and compile_spec is used for
            each node, keyed by the tag created by the backend developers.
        """

    def ops_to_not_decompose(
        self,
        ep: ExportedProgram,
    ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
        """
        Returns the operators that should not be decomposed. When these ops
        are registered and `to_backend` is invoked through
        to_edge_transform_and_lower, it is guaranteed that the program the
        backend receives will not have any of these ops decomposed.

        This base implementation does not inspect `ep`; it reports no
        operators and no filter.

        Args:
            ep: the ExportedProgram under consideration.

        Returns:
            List[torch._ops.OpOverload]: the operators that should not be
                decomposed.
            Optional[Callable[[torch.fx.Node], bool]]: an optional callable,
                acting as a filter, that is called for each node in the
                graph. It lets users pick out certain nodes that should
                continue to be decomposed even though the op they correspond
                to is in the list returned by ops_to_not_decompose.
        """
        preserved_ops: List[torch._ops.OpOverload] = []
        node_filter: Optional[Callable[[torch.fx.Node], bool]] = None
        return preserved_ops, node_filter