# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import copy
import logging
from typing import List, Optional

import torch
from executorch.exir.dialects.edge._ops import EdgeOpOverload
from executorch.exir.dim_order_utils import get_dim_order, get_memory_format
from executorch.exir.pass_base import ExportPass, ProxyValue
from executorch.exir.passes.dim_order_ops_registry import (
    DimOrderOpsMap,
    MemoryFormatOpsMap,
)

logger = logging.getLogger(__file__)
logger.setLevel(logging.INFO)


class MemoryFormatOpsPass(ExportPass):
    """
    This pass replaces ops that take torch.memory_format as an argument with
    equivalent ops that take dim_order instead. This is part of the larger
    ExecuTorch goal of moving away from torch.memory_format. There is a 1:1
    mapping between each aten op and its edge dialect dim_order counterpart.
    """

    def call_operator(self, op, args, kwargs, meta):
        if not (isinstance(op, EdgeOpOverload) and op in DimOrderOpsMap):
            return super().call_operator(
                op,
                args,
                kwargs,
                meta,
            )
        # new kwargs with dim_order, and no memory_format for the new op
        nkwargs = dict(copy.deepcopy(kwargs))  # orig kwargs are immutable

        # Get the target memory format for the EdgeOp, defaulting to
        # preserve_format (clone() with no memory_format kwarg preserves
        # the input's layout instead of forcing contiguous).
        mem_format = nkwargs.pop("memory_format", torch.preserve_format)

        # Get input tensor and ndim
        input_tensor: Optional[torch.Tensor] = None
        if isinstance(args[0], ProxyValue) and args[0].is_tensor():
            input_tensor = args[0].to_tensor()
            ndim = input_tensor.dim()
        elif isinstance(args[0], torch.Tensor):
            input_tensor = args[0]
            ndim = input_tensor.dim()
        elif isinstance(args[0], torch.fx.immutable_collections.immutable_list):
            ndim = len(args[0])
        else:
            assert (
                0
            ), f"Expecting a Tensor, a ProxyValue, or a Sequence, but got {type(args[0])}"

        # Derive dim_order based on memory format
        dim_order: List[int]
        if mem_format in (None, torch.preserve_format):
            # preserve_format: inherit dim_order from the input tensor
            if input_tensor is not None:
                dim_order = [int(d) for d in input_tensor.dim_order()]
            else:
                # Fall back to contiguous if no single input tensor is
                # available (e.g. list inputs like torch.stack).
                dim_order = list(range(ndim))
        else:
            # Explicit memory format (contiguous_format, channels_last, etc.)
            dim_order = get_dim_order(mem_format, ndim)
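        # For example, on a rank-4 tensor torch.channels_last corresponds to
        # dim_order [0, 2, 3, 1], while torch.contiguous_format corresponds to
        # [0, 1, 2, 3].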
nkwargs["dim_order"] = dim_order
logger.debug(
f"{op.__name__} = rank: {ndim}, memory_format: {mem_format}."
f" {DimOrderOpsMap[op].__name__} = dim_order: {nkwargs['dim_order']}"
)
t = DimOrderOpsMap[op]
return super().call_operator(
t,
args,
nkwargs,
meta,
)


class DimOrderOpsRevertPass(ExportPass):
    """
    This pass reverts dim_order ops back to their memory format counterparts.
    """

    def call_operator(self, op, args, kwargs, meta):
        if not (isinstance(op, EdgeOpOverload) and op in MemoryFormatOpsMap):
            return super().call_operator(
                op,
                args,
                kwargs,
                meta,
            )
        # new kwargs with memory_format, and no dim_order for the new op
        nkwargs = dict(copy.deepcopy(kwargs))  # orig kwargs are immutable

        # can always get the shape, assuming rank is specialized
        if isinstance(args[0], ProxyValue) and args[0].is_tensor():
            ndim = args[0].to_tensor().dim()
        elif isinstance(args[0], torch.Tensor):
            ndim = args[0].dim()
        elif isinstance(args[0], torch.fx.immutable_collections.immutable_list):
            ndim = len(args[0])
        else:
            assert (
                0
            ), f"Expecting a Tensor, a ProxyValue, or a Sequence, but got {type(args[0])}"
# get the "to" memory format for the EdgeOp
default_dim_order = list(range(ndim))
dim_order = nkwargs.pop("dim_order", default_dim_order)
nkwargs["memory_format"] = get_memory_format(dim_order)
logger.debug(
f" {op.__name__} = dim_order: {dim_order}."
f" {MemoryFormatOpsMap[op].__name__} = rank: {ndim}, memory_format: {nkwargs['memory_format']}."
)
t = MemoryFormatOpsMap[op]
return super().call_operator(
t,
args,
nkwargs,
meta,
)
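

# A minimal usage sketch (illustrative only). These passes follow the ExportPass
# calling convention, where invoking a pass on a GraphModule is expected to
# return a PassResult carrying the transformed module; `exported_program` below
# is a placeholder name, and in practice both passes are typically scheduled by
# the edge-dialect lowering flow rather than invoked by hand.
#
#   result = MemoryFormatOpsPass()(exported_program.graph_module)
#   dim_order_graph = result.graph_module
#
#   reverted = DimOrderOpsRevertPass()(dim_order_graph)
#   memory_format_graph = reverted.graph_module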