Skip to content

Commit 7b5dcc1

Browse files
authored
Add add-relu fusion in the quantizer
Differential Revision: D102189156 Pull Request resolved: pytorch#19077
1 parent 6d23e41 commit 7b5dcc1

4 files changed

Lines changed: 129 additions & 1 deletion

File tree

backends/cadence/aot/quantizer/fusion_pass.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
from executorch.backends.cadence.aot.quantizer.patterns import (
1616
AddmmPattern,
1717
AddPattern,
18+
AddReluPattern0,
19+
AddReluPattern1,
1820
BmmPattern,
1921
CatPattern,
2022
Conv1dPattern,
@@ -63,6 +65,7 @@
6365
Conv2dReluPattern0,
6466
Conv2dReluPattern1,
6567
)
68+
AddReluPatterns = (AddReluPattern0, AddReluPattern1)
6669

6770

6871
def get_args_and_kwargs_add(
@@ -616,7 +619,20 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
616619
inputs_inputs + weights_inputs + other_inputs + bias_inputs
617620
)
618621
kwargs = {}
619-
if isinstance(pattern, AddPattern):
622+
if isinstance(pattern, AddReluPatterns):
623+
# For AddReLU, we are fusing Add+ReLU.
624+
# The quantized_add op performs requantization,
625+
# so the relu is implicit in the output quant params.
626+
check_out_zero_point_is_min_range(
627+
quant_node.args[2], quant_node.args[5]
628+
)
629+
args, kwargs = get_args_and_kwargs_add(
630+
graph_module,
631+
inputs_inputs,
632+
dequants_inputs,
633+
quant_node,
634+
)
635+
elif isinstance(pattern, AddPattern):
620636
args, kwargs = get_args_and_kwargs_add(
621637
graph_module,
622638
inputs_inputs,

backends/cadence/aot/quantizer/patterns.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,61 @@ def replacement_op(self) -> OpOverload:
153153
return torch.ops.cadence.quantized_add.per_tensor
154154

155155

156+
# Base class for Add+ReLU fusion; shared by the two relu aten-op variants.
class AddReluBasePattern(QuantizationPattern):
    @abstractmethod
    def partition_types(self) -> List[OpOverload]:
        pass

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
    ) -> Tuple[PartitionAnchors, fx.Node]:
        """Anchor quantization on the add's two inputs and the relu's output.

        Returns an empty ``PartitionAnchors`` (and the add node) when the
        partition is not a plain two-tensor add, or when the add carries
        kwargs (e.g. alpha), since those cannot be fused.
        """
        # Partition order: [0] holds the add, [1] holds the relu.
        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
        add_node = fused_partition[0].nodes[-1]
        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
        relu_node = fused_partition[1].nodes[-1]

        # Only fuse a true tensor+tensor add with no kwargs (e.g. alpha).
        both_args_are_nodes = isinstance(add_node.args[0], fx.Node) and isinstance(
            add_node.args[1], fx.Node
        )
        if not both_args_are_nodes or add_node.kwargs:
            return (
                PartitionAnchors(empty=True),
                add_node,
            )

        anchors = PartitionAnchors(
            inputs=[(add_node, 0), (add_node, 1)],
            weights=[],
            biases=[],
            # The relu node produces the pattern's output.
            output=[(relu_node,)],
        )
        return (anchors, relu_node)

    def replacement_op(self) -> OpOverload:
        # quantized_add requantizes its output, so the relu is folded into
        # the output quantization parameters.
        return torch.ops.cadence.quantized_add.per_tensor
198+
199+
# Fusion of add with the out-of-place relu op.
class AddReluPattern0(AddReluBasePattern):
    def partition_types(self) -> List[OpOverload]:
        # Match aten.add followed by the regular (non-inplace) relu.
        ops = [torch.ops.aten.add.Tensor, torch.ops.aten.relu.default]
        return ops
203+
204+
205+
# Fusion of add with the in-place relu_ op.
class AddReluPattern1(AddReluBasePattern):
    def partition_types(self) -> List[OpOverload]:
        # Match aten.add followed by the alternate (inplace) relu variant.
        ops = [torch.ops.aten.add.Tensor, torch.ops.aten.relu_.default]
        return ops
209+
210+
156211
class BmmPattern(QuantizationPattern):
157212
def partition_types(self) -> List[OpOverload]:
158213
return [torch.ops.aten.bmm.default]

backends/cadence/aot/quantizer/quantizer.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
from executorch.backends.cadence.aot.quantizer.patterns import (
1414
AddmmPattern,
1515
AddPattern,
16+
AddReluPattern0,
17+
AddReluPattern1,
1618
BmmPattern,
1719
CatPattern,
1820
Conv1dPattern,
@@ -398,6 +400,8 @@ def __init__(
398400
quantizers.append(CadenceAtenQuantizer(Conv1dReluPattern1(), a8w8sym))
399401
quantizers.append(CadenceAtenQuantizer(Conv2dReluPattern0(), a8w8sym))
400402
quantizers.append(CadenceAtenQuantizer(Conv2dReluPattern1(), a8w8sym))
403+
quantizers.append(CadenceAtenQuantizer(AddReluPattern0(), a8w8))
404+
quantizers.append(CadenceAtenQuantizer(AddReluPattern1(), a8w8))
401405
quantizers = quantizers + get_cadence_default_quantizers(is_qat=is_qat)
402406
quantizers.append(CadenceAtenQuantizer(AddPattern(), a8w8))
403407
quantizers.append(CadenceAtenQuantizer(CatPattern(), a8w8))

backends/cadence/aot/tests/test_quantizer_ops.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,15 @@
215215
[qconfig_A8W8.input_activation],
216216
),
217217
# CadenceFusedConvReluQuantizer test cases
218+
(
219+
"fused_add_relu_A8W8",
220+
lambda self: self._build_add_relu_graph(),
221+
CadenceFusedConvReluQuantizer(),
222+
torch.ops.aten.relu.default,
223+
qconfig_A8W8.output_activation,
224+
# For fused add+relu: both inputs are activations from add node
225+
[qconfig_A8W8.input_activation, qconfig_A8W8.input_activation],
226+
),
218227
(
219228
"fused_conv1d_relu_A8W8sym",
220229
lambda self: self._build_conv1d_relu_graph(),
@@ -508,6 +517,50 @@ def _build_max_pool2d_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
508517
)
509518
return gm, max_pool_nodes[0]
510519

520+
def _build_add_relu_graph(
521+
self,
522+
) -> tuple[torch.fx.GraphModule, torch.fx.Node, torch.fx.Node]:
523+
"""Build a graph with an add followed by relu (fused pattern).
524+
525+
Returns:
526+
A tuple of (graph_module, relu_node, add_node).
527+
The relu_node is the target node where the annotation is placed.
528+
The add_node is the input source node whose args contain the quantized inputs.
529+
"""
530+
builder = GraphBuilder()
531+
x = builder.placeholder("x", torch.randn(1, 10))
532+
y = builder.placeholder("y", torch.randn(1, 10))
533+
add = builder.call_operator(
534+
op=torch.ops.aten.add.Tensor,
535+
args=(x, y),
536+
meta=NodeMetadata(
537+
{"source_fn_stack": [("add", torch.ops.aten.add.Tensor)]}
538+
),
539+
)
540+
relu = builder.call_operator(
541+
op=torch.ops.aten.relu.default,
542+
args=(add,),
543+
meta=NodeMetadata(
544+
{"source_fn_stack": [("relu", torch.ops.aten.relu.default)]}
545+
),
546+
)
547+
builder.output([relu])
548+
gm = builder.get_graph_module()
549+
550+
relu_nodes = gm.graph.find_nodes(
551+
op="call_function",
552+
target=torch.ops.aten.relu.default,
553+
)
554+
self.assertEqual(len(relu_nodes), 1, "Should find exactly one relu node")
555+
556+
add_nodes = gm.graph.find_nodes(
557+
op="call_function",
558+
target=torch.ops.aten.add.Tensor,
559+
)
560+
self.assertEqual(len(add_nodes), 1, "Should find exactly one add node")
561+
562+
return gm, relu_nodes[0], add_nodes[0]
563+
511564
def _build_conv2d_relu_graph(
512565
self,
513566
) -> tuple[torch.fx.GraphModule, torch.fx.Node, torch.fx.Node]:

0 commit comments

Comments
 (0)