Skip to content

Commit f9f988d

Browse files
Jiseong-oh, Chen03Zhao, and Sangsoo.ko
committed
Support Quantized MobileBert
- update annotator
- Support quantized MobileBert
- update quantization strategy

Co-authored-by: chen.zhao <chen03.zhao@samsung.com>
Co-authored-by: Sangsoo.ko <sangsoo.ko@samsung.com>
Signed-off-by: jiseong.oh <jiseong.oh@samsung.com>
1 parent 34f3723 commit f9f988d

23 files changed

+1502
-184
lines changed

backends/samsung/_passes/annotate_qparams.py

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from torch._export.utils import get_buffer
1515
from torch.export import ExportedProgram
1616
from torch.fx import GraphModule, Node
17+
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
1718

1819

1920
class AnnotateQparamsPass(ExportPass):
@@ -148,13 +149,34 @@ def _check_same(requant_obj, ori_obj) -> bool:
148149
_check_same(ori_quant_attrs[key], requantize_attrs[key])
149150
for key in key_map.values()
150151
):
151-
requantize_map[idx] = requantize_attrs
152+
if (
153+
ori_quant_attrs[QuantConstants.QUANT_KEY.quant_dtype]
154+
!= requantize_attrs[QuantConstants.QUANT_KEY.quant_dtype]
155+
):
156+
# For Q-DQ who will change quant dtype, we will insert requantization node
157+
requantize_map[idx] = requantize_attrs
158+
else:
159+
node.meta["quantize_attrs"] = requantize_attrs
152160

153161
def _annotate(self, graph_module: GraphModule):
154162
for node in graph_module.graph.nodes:
163+
if key_map := QuantConstants.DEQUANT_OPS_KEY_MAP.get(node.target, None):
164+
# We will fold node with constant output in the future pass as a constant node
165+
# example: Constant->Q->DQ->nodeN->Q->DQ, this seq will be folded to one
166+
# We need to store the q-params from last DQ params for quantizing constant value
167+
quant_attrs = self.get_quant_attrs(node, key_map)
168+
if node.args[0].target in QuantConstants.QUANT_OPS_KEY_MAP:
169+
node.meta["quantize_attrs"] = quant_attrs
170+
else:
171+
node.args[0].meta["quantize_attrs"] = quant_attrs
172+
continue
155173
key_map = QuantConstants.QUANT_OPS_KEY_MAP.get(node.target, None)
156174
if not key_map:
157175
continue
176+
quant_attrs = self.get_quant_attrs(node, key_map)
177+
if node.args[0].target in QuantConstants.QUANT_OPS_KEY_MAP:
178+
node.meta["quantize_attrs"] = quant_attrs
179+
continue
158180
source_node = node.args[0]
159181
if source_node.target in (
160182
*QuantConstants.QUANT_OPS_KEY_MAP,
@@ -164,13 +186,26 @@ def _annotate(self, graph_module: GraphModule):
164186
continue
165187
elif source_node.target == operator.getitem:
166188
source_node = source_node.args[0]
167-
quant_attrs = self.get_quant_attrs(node, key_map)
189+
168190
source_node.meta["quantize_attrs"] = quant_attrs
169191
self._annotate_requantize(source_node)
170192
self._propagate_quant_params(source_node)
171193

194+
def _annotate_decomposed_mm(self, graph_module: GraphModule):
195+
for source_list in get_source_partitions(graph_module.graph, ["matmul"]).get(
196+
"matmul", {}
197+
):
198+
final_view = source_list.output_nodes[0]
199+
if not (quantize_attrs := final_view.meta.get("quantize_attrs")):
200+
continue
201+
for node in source_list.nodes:
202+
if node.target == exir_ops.edge.aten.bmm.default:
203+
node.meta["quantize_attrs"] = quantize_attrs
204+
break
205+
172206
def call(self, graph_module: GraphModule):
173207
self._annotate(graph_module)
208+
self._annotate_decomposed_mm(graph_module)
174209
graph_module.recompile()
175210
return PassResult(graph_module, True)
176211

backends/samsung/_passes/annotate_scalar_parameters.py

Lines changed: 30 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import torch
8-
from executorch.backends.samsung.quantizer.quantizer import global_quant_info
98
from executorch.backends.samsung.utils.constants import QuantConstants
109
from executorch.backends.transforms.utils import get_param_tensor, is_param_node
1110
from executorch.exir.dialects._ops import ops as exir_ops
@@ -25,6 +24,7 @@ class AnnotateScalarParametersPass(ExportPass):
2524
exir_ops.edge.aten.mul.Tensor,
2625
exir_ops.edge.aten.add.Tensor,
2726
exir_ops.edge.aten.div.Tensor,
27+
exir_ops.edge.aten.sub.Tensor,
2828
}
2929

3030
def __init__(self, edge_program: ExportedProgram):
@@ -35,27 +35,37 @@ def annotate(self, graph_module: torch.fx.GraphModule):
3535
for node in graph_module.graph.nodes:
3636
if node.target not in self.TARGET_OPS or "quantize_attrs" not in node.meta:
3737
continue
38-
torch_quant_dtype = global_quant_info.weight_precison.torch_dtype
39-
for input_arg in node.all_input_nodes:
40-
if input_arg.op not in ("placeholder", "get_attr") or not is_param_node(
41-
self.edge_program, input_arg
38+
input0, input1 = node.all_input_nodes[0], node.all_input_nodes[1]
39+
if input0.op not in ("placeholder", "get_attr") or not is_param_node(
40+
self.edge_program, input0
41+
):
42+
if input1.op not in ("placeholder", "get_attr") or not is_param_node(
43+
self.edge_program, input1
4244
):
4345
continue
44-
else:
45-
tensor = get_param_tensor(self.edge_program, input_arg)
46-
if not tensor.shape:
47-
qparams = {
48-
QuantConstants.QUANT_KEY.scale: float(tensor),
49-
QuantConstants.QUANT_KEY.quant_dtype: torch_quant_dtype,
50-
QuantConstants.QUANT_KEY.quant_max: torch.iinfo(
51-
torch_quant_dtype
52-
).max,
53-
QuantConstants.QUANT_KEY.quant_min: torch.iinfo(
54-
torch_quant_dtype
55-
).min,
56-
QuantConstants.QUANT_KEY.zero_point: 0,
57-
}
58-
input_arg.meta["quantize_attrs"] = qparams
46+
ifm_node, param_tensor_node = input0, input1
47+
else:
48+
ifm_node, param_tensor_node = input1, input0
49+
if not (quantize_attrs := ifm_node.meta.get("quantize_attrs")):
50+
continue
51+
param_tensor = get_param_tensor(self.edge_program, param_tensor_node)
52+
if not param_tensor.shape:
53+
scale = (
54+
float(param_tensor) if param_tensor > 0 else -float(param_tensor)
55+
)
56+
else:
57+
continue
58+
q_dtype = quantize_attrs[QuantConstants.QUANT_KEY.quant_dtype]
59+
if scale == 0:
60+
scale = 1.0
61+
qparams = {
62+
QuantConstants.QUANT_KEY.scale: scale,
63+
QuantConstants.QUANT_KEY.quant_dtype: q_dtype,
64+
QuantConstants.QUANT_KEY.quant_max: torch.iinfo(q_dtype).max,
65+
QuantConstants.QUANT_KEY.quant_min: torch.iinfo(q_dtype).min,
66+
QuantConstants.QUANT_KEY.zero_point: 0,
67+
}
68+
param_tensor_node.meta["quantize_attrs"] = qparams
5969

6070
def call(self, graph_module: torch.fx.GraphModule):
6171
graph = graph_module.graph

backends/samsung/_passes/fuse_conv_act.py renamed to backends/samsung/_passes/fuse_activation.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def map_hardtan_relux(tanhnode: torch.fx.node.Node) -> Optional[str]:
2424
return None
2525

2626

27-
class FuseConvActPass(ExportPass):
27+
class FuseActivationPass(ExportPass):
2828
TARGET_ACTS_MAP = {
2929
exir_ops.edge.aten.relu.default: (lambda x: "RELU"),
3030
exir_ops.edge.aten.relu_.default: (lambda x: "RELU"),
@@ -33,45 +33,45 @@ class FuseConvActPass(ExportPass):
3333
exir_ops.edge.aten.hardtanh.default: map_hardtan_relux,
3434
exir_ops.edge.aten.hardtanh_.default: map_hardtan_relux,
3535
}
36+
TARGET_SOURCE_NODES = {
37+
exir_ops.edge.aten.convolution.default,
38+
exir_ops.edge.aten.linear.default,
39+
}
3640

3741
def _fuse(
3842
self,
3943
graph_module: GraphModule,
4044
):
41-
for target_conv, target_act in self.get_target_conv_act(graph_module):
45+
for target_src, target_act in self.get_target_src_act(graph_module):
4246
assert (
4347
act_name := self.TARGET_ACTS_MAP.get(target_act.target)(target_act)
4448
), f"Not supported {target_act.name} now."
45-
target_conv.meta["activation"] = act_name
49+
target_src.meta["activation"] = act_name
4650
if "quantize_attrs" in target_act.meta:
47-
target_conv.meta["quantize_attrs"] = target_act.meta["quantize_attrs"]
48-
49-
# If we merge the real out activation to conv, the conv should be the real out
50-
if "real_out" in target_act.meta:
51-
target_conv.meta["real_out"] = target_act.meta["real_out"]
51+
target_src.meta["quantize_attrs"] = target_act.meta["quantize_attrs"]
52+
else:
53+
continue
5254
for user in [user for user in target_act.users.keys()]: # noqa: C416
53-
user.replace_input_with(target_act, target_conv)
55+
user.replace_input_with(target_act, target_src)
5456
graph_module.graph.erase_node(target_act)
5557

56-
def get_target_conv_act(self, graph_module: GraphModule):
58+
def get_target_src_act(self, graph_module: GraphModule):
5759
for node in graph_module.graph.nodes:
58-
if node.target != exir_ops.edge.aten.convolution.default:
60+
if node.target not in self.TARGET_SOURCE_NODES:
5961
continue
6062
if len(node.users) != 1:
61-
# Such cases couldn't be conv + act
63+
# Such cases couldn't be src + act
6264
continue
6365
act_node = list(node.users.keys())[0]
6466
if act_node.target not in self.TARGET_ACTS_MAP:
6567
continue
6668
if "quantize_attrs" in node.meta:
67-
# If the conv's output is quantized
68-
# We do not fuse them
69+
# If we merge the real out activation to source, the source should be the real out
6970
continue
7071
yield node, act_node
7172

7273
def call(self, graph_module: GraphModule):
7374
self._fuse(graph_module)
7475
graph_module.recompile()
7576
dead_code_elimination_pass(graph_module)
76-
_ = super().call(graph_module).graph_module
7777
return PassResult(graph_module, True)

backends/samsung/_passes/insert_qdq.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,9 +156,18 @@ def _add_qdq(self, graph_module: GraphModule):
156156
elif is_graph_output(node):
157157
self._add_dq_after(graph_module, node)
158158

159+
def _add_q_for_cast(self, graph_module: GraphModule):
160+
for node in list(graph_module.graph.nodes):
161+
if not node.target == exir_ops.edge.aten._to_copy.default:
162+
continue
163+
if "quantize_attrs" not in node.meta:
164+
continue
165+
self._add_q_after(graph_module, node)
166+
159167
def call(self, graph_module: GraphModule):
160168
self._add_qdq(graph_module)
161169
self._add_qdq_for_requantize(graph_module)
170+
self._add_q_for_cast(graph_module)
162171
graph_module.graph.eliminate_dead_code()
163172
graph_module.recompile()
164173
return PassResult(graph_module, True)
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
# Copyright (c) 2025 Samsung Electronics Co. LTD
2+
# All rights reserved
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import torch
8+
from executorch.backends.samsung.utils.constants import QuantConstants
9+
from executorch.exir.dialects._ops import ops as exir_ops
10+
from executorch.exir.pass_base import ExportPass, PassResult
11+
from executorch.exir.passes import dead_code_elimination_pass
12+
from torch.export import ExportedProgram
13+
from torch.fx import GraphModule
14+
15+
16+
class TransformQuantizedMaskPass(ExportPass):
17+
def __init__(self, edge_program: ExportedProgram):
18+
super().__init__()
19+
self.edge_program = edge_program
20+
21+
def get_mask_mul(self, graph_module: GraphModule):
22+
"""
23+
Iterator for each patterns in the graph.
24+
The obj returned by iterator is the first node of the pattern.
25+
"""
26+
nodes_in_pattern = (
27+
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
28+
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
29+
exir_ops.edge.aten.sub.Tensor,
30+
exir_ops.edge.aten._to_copy.default,
31+
exir_ops.edge.aten.unsqueeze_copy.default,
32+
exir_ops.edge.aten.mul.Tensor,
33+
)
34+
mask_node = None
35+
for node in graph_module.graph.nodes:
36+
if node.target != "attention_mask":
37+
continue
38+
else:
39+
mask_node = node
40+
break
41+
if mask_node is None:
42+
return None
43+
while node.target != exir_ops.edge.aten.mul.Tensor:
44+
find_next = False
45+
for successor in list(node.users.keys()):
46+
if successor.target in nodes_in_pattern:
47+
node = successor
48+
find_next = True
49+
break
50+
if not find_next:
51+
return None
52+
return node
53+
54+
def transform(
55+
self,
56+
graph_module: GraphModule,
57+
):
58+
mask_mul = self.get_mask_mul(graph_module)
59+
if mask_mul is None:
60+
return
61+
rsub_node = mask_mul.args[0]
62+
manual_mul_idx = 0
63+
for add in list(mask_mul.users.keys()):
64+
custom_tensor_name = f"_custom_tensor_{manual_mul_idx}"
65+
div_node = add.args[0]
66+
if "quantize_attrs" not in div_node.meta:
67+
return
68+
div_quant_args = div_node.meta["quantize_attrs"]
69+
custom_tensor = torch.tensor(
70+
(
71+
div_node.meta["quantize_attrs"][QuantConstants.QUANT_KEY.quant_min]
72+
- div_node.meta["quantize_attrs"][
73+
QuantConstants.QUANT_KEY.zero_point
74+
]
75+
)
76+
* div_node.meta["quantize_attrs"][QuantConstants.QUANT_KEY.scale],
77+
dtype=torch.float32,
78+
)
79+
graph_module.register_buffer(custom_tensor_name, custom_tensor)
80+
add.meta["quantize_attrs"] = div_quant_args
81+
with graph_module.graph.inserting_after(rsub_node):
82+
custom_attr = graph_module.graph.get_attr(custom_tensor_name)
83+
with graph_module.graph.inserting_after(custom_attr):
84+
new_mul = graph_module.graph.create_node(
85+
"call_function",
86+
exir_ops.edge.aten.mul.Tensor,
87+
(mask_mul.args[0], custom_attr),
88+
)
89+
new_mul.meta["quantize_attrs"] = div_quant_args
90+
add.replace_input_with(mask_mul, new_mul)
91+
92+
rsub_in = rsub_node.args[1]
93+
with graph_module.graph.inserting_before(add):
94+
new_mul = graph_module.graph.create_node(
95+
"call_function", exir_ops.edge.aten.mul.Tensor, (div_node, rsub_in)
96+
)
97+
new_mul.meta["quantize_attrs"] = div_quant_args
98+
add.replace_input_with(div_node, new_mul)
99+
manual_mul_idx += 1
100+
101+
def call(self, graph_module: GraphModule):
102+
self.transform(graph_module)
103+
graph_module.recompile()
104+
dead_code_elimination_pass(graph_module)
105+
return PassResult(graph_module, True)

backends/samsung/builders/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
op_mul,
3535
op_permute,
3636
op_pixel_shuffle,
37+
op_placeholder,
3738
op_quantize,
3839
op_relu,
3940
op_reshape,
@@ -80,6 +81,7 @@
8081
op_mul,
8182
op_permute,
8283
op_pixel_shuffle,
84+
op_placeholder,
8385
op_quantize,
8486
op_relu,
8587
op_reshape,

backends/samsung/builders/op_constant_pad_nd.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,5 +52,5 @@ def define_node(
5252
"padding": "EXPLICIT",
5353
"padding_type": "CONSTANT",
5454
}
55-
55+
self._update_params_qdtype(node, params)
5656
enn_graph.define_op(node.name, "PAD", [input_id], [output_id], params)

backends/samsung/builders/op_embedding.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ def define_node(
3636
output_id = self.define_tensor(node, enn_graph, vals_to_ids)
3737

3838
params = {"axis": 0, "input_type": "indices"}
39+
self._update_params_qdtype(node, params)
3940
enn_graph.define_op(
4041
node.name, "GATHER", [input_id, weight_id], [output_id], params
4142
)

0 commit comments

Comments (0)