diff --git a/backends/cadence/aot/quantizer/fusion_pass.py b/backends/cadence/aot/quantizer/fusion_pass.py index 6b7990c0f2c..5375367b929 100644 --- a/backends/cadence/aot/quantizer/fusion_pass.py +++ b/backends/cadence/aot/quantizer/fusion_pass.py @@ -15,6 +15,8 @@ from executorch.backends.cadence.aot.quantizer.patterns import ( AddmmPattern, AddPattern, + AddReluPattern0, + AddReluPattern1, BmmPattern, CatPattern, Conv1dPattern, @@ -63,6 +65,7 @@ Conv2dReluPattern0, Conv2dReluPattern1, ) +AddReluPatterns = (AddReluPattern0, AddReluPattern1) def get_args_and_kwargs_add( @@ -616,7 +619,20 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901 inputs_inputs + weights_inputs + other_inputs + bias_inputs ) kwargs = {} - if isinstance(pattern, AddPattern): + if isinstance(pattern, AddReluPatterns): + # For AddReLU, we are fusing Add+ReLU. + # The quantized_add op performs requantization, + # so the relu is implicit in the output quant params. + check_out_zero_point_is_min_range( + quant_node.args[2], quant_node.args[5] + ) + args, kwargs = get_args_and_kwargs_add( + graph_module, + inputs_inputs, + dequants_inputs, + quant_node, + ) + elif isinstance(pattern, AddPattern): args, kwargs = get_args_and_kwargs_add( graph_module, inputs_inputs, diff --git a/backends/cadence/aot/quantizer/patterns.py b/backends/cadence/aot/quantizer/patterns.py index 2ce50871fc0..07aad18e36a 100644 --- a/backends/cadence/aot/quantizer/patterns.py +++ b/backends/cadence/aot/quantizer/patterns.py @@ -153,6 +153,61 @@ def replacement_op(self) -> OpOverload: return torch.ops.cadence.quantized_add.per_tensor +# This is a base class for Add+ReLU fusion, since it can be used with two different relu aten ops +class AddReluBasePattern(QuantizationPattern): + @abstractmethod + def partition_types(self) -> List[OpOverload]: + pass + + def get_anchors( + self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule] + ) -> Tuple[PartitionAnchors, fx.Node]: + # The first node should be add, the second should be relu + # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge... + add_node = fused_partition[0].nodes[-1] + # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge... + relu_node = fused_partition[1].nodes[-1] + + # Bail if: + # - the add node is not a tensor add + # - the add node has kwargs (e.g. alpha) + is_tensor_add = isinstance(add_node.args[0], fx.Node) and isinstance( + add_node.args[1], fx.Node + ) + if not is_tensor_add or len(add_node.kwargs) > 0: + return ( + PartitionAnchors( + empty=True, + ), + add_node, + ) + + return ( + PartitionAnchors( + inputs=[(add_node, 0), (add_node, 1)], + weights=[], + biases=[], + output=[(relu_node,)], # Output is from the relu node + ), + relu_node, + ) + + def replacement_op(self) -> OpOverload: + return torch.ops.cadence.quantized_add.per_tensor + + +# Add + regular relu op fusion +class AddReluPattern0(AddReluBasePattern): + def partition_types(self) -> List[OpOverload]: + return [torch.ops.aten.add.Tensor, torch.ops.aten.relu.default] + + +# Add + alternate relu op fusion +class AddReluPattern1(AddReluBasePattern): + def partition_types(self) -> List[OpOverload]: + return [torch.ops.aten.add.Tensor, torch.ops.aten.relu_.default] + + class BmmPattern(QuantizationPattern): def partition_types(self) -> List[OpOverload]: return [torch.ops.aten.bmm.default] diff --git a/backends/cadence/aot/quantizer/quantizer.py b/backends/cadence/aot/quantizer/quantizer.py index 4edcd96e132..d521b9f83cf 100644 --- a/backends/cadence/aot/quantizer/quantizer.py +++ b/backends/cadence/aot/quantizer/quantizer.py @@ -13,6 +13,8 @@ from executorch.backends.cadence.aot.quantizer.patterns import ( AddmmPattern, AddPattern, + AddReluPattern0, + AddReluPattern1, BmmPattern, CatPattern, Conv1dPattern, @@ -398,6 +400,8 @@ def __init__( quantizers.append(CadenceAtenQuantizer(Conv1dReluPattern1(), a8w8sym)) quantizers.append(CadenceAtenQuantizer(Conv2dReluPattern0(), a8w8sym)) quantizers.append(CadenceAtenQuantizer(Conv2dReluPattern1(), a8w8sym)) + quantizers.append(CadenceAtenQuantizer(AddReluPattern0(), a8w8)) + quantizers.append(CadenceAtenQuantizer(AddReluPattern1(), a8w8)) quantizers = quantizers + get_cadence_default_quantizers(is_qat=is_qat) quantizers.append(CadenceAtenQuantizer(AddPattern(), a8w8)) quantizers.append(CadenceAtenQuantizer(CatPattern(), a8w8)) diff --git a/backends/cadence/aot/tests/test_quantizer_ops.py b/backends/cadence/aot/tests/test_quantizer_ops.py index 06e2c08f4f4..dde26f06b7b 100644 --- a/backends/cadence/aot/tests/test_quantizer_ops.py +++ b/backends/cadence/aot/tests/test_quantizer_ops.py @@ -215,6 +215,15 @@ [qconfig_A8W8.input_activation], ), # CadenceFusedConvReluQuantizer test cases + ( + "fused_add_relu_A8W8", + lambda self: self._build_add_relu_graph(), + CadenceFusedConvReluQuantizer(), + torch.ops.aten.relu.default, + qconfig_A8W8.output_activation, + # For fused add+relu: both inputs are activations from add node + [qconfig_A8W8.input_activation, qconfig_A8W8.input_activation], + ), ( "fused_conv1d_relu_A8W8sym", lambda self: self._build_conv1d_relu_graph(), @@ -508,6 +517,50 @@ def _build_max_pool2d_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]: ) return gm, max_pool_nodes[0] + def _build_add_relu_graph( + self, + ) -> tuple[torch.fx.GraphModule, torch.fx.Node, torch.fx.Node]: + """Build a graph with an add followed by relu (fused pattern). + + Returns: + A tuple of (graph_module, relu_node, add_node). + The relu_node is the target node where the annotation is placed. + The add_node is the input source node whose args contain the quantized inputs. + """ + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(1, 10)) + y = builder.placeholder("y", torch.randn(1, 10)) + add = builder.call_operator( + op=torch.ops.aten.add.Tensor, + args=(x, y), + meta=NodeMetadata( + {"source_fn_stack": [("add", torch.ops.aten.add.Tensor)]} + ), + ) + relu = builder.call_operator( + op=torch.ops.aten.relu.default, + args=(add,), + meta=NodeMetadata( + {"source_fn_stack": [("relu", torch.ops.aten.relu.default)]} + ), + ) + builder.output([relu]) + gm = builder.get_graph_module() + + relu_nodes = gm.graph.find_nodes( + op="call_function", + target=torch.ops.aten.relu.default, + ) + self.assertEqual(len(relu_nodes), 1, "Should find exactly one relu node") + + add_nodes = gm.graph.find_nodes( + op="call_function", + target=torch.ops.aten.add.Tensor, + ) + self.assertEqual(len(add_nodes), 1, "Should find exactly one add node") + + return gm, relu_nodes[0], add_nodes[0] + def _build_conv2d_relu_graph( self, ) -> tuple[torch.fx.GraphModule, torch.fx.Node, torch.fx.Node]: