-
Notifications
You must be signed in to change notification settings - Fork 1k
Expand file tree
/
Copy pathutils.py
More file actions
executable file
·162 lines (135 loc) · 5.04 KB
/
utils.py
File metadata and controls
executable file
·162 lines (135 loc) · 5.04 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict, Optional
import torch
from torch._export.utils import (
get_buffer,
get_lifted_tensor_constant,
get_param,
is_buffer,
is_lifted_tensor_constant,
is_param,
)
def is_parameter(
node: torch.fx.Node, edge_program: torch.export.ExportedProgram
) -> bool:
return (
is_param(edge_program, node)
or is_buffer(edge_program, node)
or is_lifted_tensor_constant(edge_program, node)
)
def get_parameter(
node: torch.fx.Node, edge_program: torch.export.ExportedProgram
) -> torch.Tensor:
param = None
if is_param(edge_program, node):
param = get_param(edge_program, node)
if is_buffer(edge_program, node):
param = get_buffer(edge_program, node)
if is_lifted_tensor_constant(edge_program, node):
param = get_lifted_tensor_constant(edge_program, node)
assert param is not None, (
f"Expect {node.name} to be parameter, buffer, or lifted tensor constant"
)
# update node.meta["val"] to qualified QNN datatype (e.g. i64 to i32)
assert isinstance(param, torch.Tensor), "Expect parameter to be tensor"
param = param.type(node.meta["val"].dtype)
return param
def set_parameter(
param: torch.Tensor, node: torch.fx.Node, edge_program: torch.export.ExportedProgram
):
status = False
if is_param(edge_program, node):
edge_program.state_dict[
edge_program.graph_signature.inputs_to_parameters[node.name]
] = param
status = True
if is_buffer(edge_program, node):
buffer_name = edge_program.graph_signature.inputs_to_buffers[node.name]
if buffer_name in edge_program.graph_signature.non_persistent_buffers:
edge_program.constants[buffer_name] = param
else:
edge_program.state_dict[buffer_name] = param
status = True
assert status, "Failed to set parameter"
def is_graph_input(
    tensor: torch.fx.Node, edge_program: torch.export.ExportedProgram
) -> bool:
    """
    Check whether a node is a runtime input of the graph
    Args:
        tensor: EdgeIR node being classified
        edge_program: program whose graph signature is consulted
    Returns:
        True for placeholder nodes that are NOT backed by a parameter,
        buffer, or lifted tensor constant; False otherwise.
    """
    if tensor.op != "placeholder":
        return False
    # Placeholders carrying stored weights/buffers/constants are not inputs
    # supplied at run time.
    return not is_parameter(tensor, edge_program)
def is_mutable_buffer_input(
tensor: torch.fx.Node, edge_program: torch.export.ExportedProgram
) -> bool:
"""
Check if the given tensor is a mutable buffer input
Args:
tensor: EdgeIR Tensor that is being checked for mutable buffer input
"""
if tensor.op == "placeholder" and is_buffer(edge_program, tensor):
fqn = edge_program.graph_signature.inputs_to_buffers[tensor.target]
# if the buffer is mutated then record that
return fqn in edge_program.graph_signature.buffers_to_mutate.values()
def is_graph_output(node: torch.fx.Node) -> bool:
    """
    Check if the given node is used as a graph output
    Args:
        node: EdgeIR node that is being checked for graph output
    Returns:
        True if the node feeds the graph's output node directly, or through
        a chain of getitem nodes that does; False otherwise.
    """
    for user in node.users:
        if user.op == "output":
            return True
        # getitem node is skipped, check the op_skip_ops.py
        # call_method / call_module users have plain string targets with no
        # __name__, so look it up defensively instead of raising.
        if getattr(user.target, "__name__", None) == "getitem" and is_graph_output(
            user
        ):
            return True
    return False
def is_mutable_buffer_output(
    tensor: torch.fx.Node, edge_program: torch.export.ExportedProgram
) -> bool:
    """
    Check if the given tensor is a mutable buffer output
    Args:
        tensor: EdgeIR node that is being checked for mutable buffer output
        edge_program: program whose graph signature is consulted
    Returns:
        True if the node's name is a key of graph_signature.buffers_to_mutate
        and the node reaches the graph output (see is_graph_output).
    """
    # Cheap dict membership first; only then walk the node's users.
    if tensor.name not in edge_program.graph_signature.buffers_to_mutate:
        return False
    # Reuse the shared reachability rule (output directly or via getitem)
    # instead of duplicating it here.
    return is_graph_output(tensor)
def is_constant(
    tensor: torch.fx.Node, edge_program: torch.export.ExportedProgram
) -> bool:
    """
    Check if the given tensor is a constant
    Args:
        tensor: EdgeIR node that is being checked for being a constant
        edge_program: program whose graph signature is consulted
    Returns:
        True if the node is a parameter/buffer/lifted constant whose
        ``meta["val"].constant`` is set; False otherwise.
    """
    # Constants should not be treated as input placeholders. This relies on
    # PyTorch's constant-lifting design; revisit if it breaks:
    # pytorch/torch/_export/passes/lift_constant_tensor_pass.py
    if not is_parameter(tensor, edge_program):
        return False
    # meta["val"] is presumably a FakeTensor, whose `constant` field holds the
    # real value when one was recorded — TODO confirm against caller.
    fake_value = tensor.meta["val"]
    return fake_value.constant is not None
def deduce_dtype(
tensor: torch.Tensor, quant_infos: Optional[Dict] = None
) -> torch.dtype:
if quant_infos:
quant_range = quant_infos["quant_max"] - quant_infos["quant_min"]
unsigned = quant_infos["quant_min"] >= 0
if quant_range <= torch.iinfo(torch.int8).max - torch.iinfo(torch.int8).min:
return torch.uint8 if unsigned else torch.int8
elif quant_range <= torch.iinfo(torch.int16).max - torch.iinfo(torch.int16).min:
return torch.uint16 if unsigned else torch.int16
return quant_infos["dtype"]
return tensor.dtype