18 changes: 17 additions & 1 deletion backends/cadence/aot/quantizer/fusion_pass.py
@@ -15,6 +15,8 @@
from executorch.backends.cadence.aot.quantizer.patterns import (
AddmmPattern,
AddPattern,
AddReluPattern0,
AddReluPattern1,
BmmPattern,
CatPattern,
Conv1dPattern,
@@ -63,6 +65,7 @@
Conv2dReluPattern0,
Conv2dReluPattern1,
)
AddReluPatterns = (AddReluPattern0, AddReluPattern1)


def get_args_and_kwargs_add(
@@ -616,7 +619,20 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
inputs_inputs + weights_inputs + other_inputs + bias_inputs
)
kwargs = {}
if isinstance(pattern, AddPattern):
if isinstance(pattern, AddReluPatterns):
# For Add+ReLU fusion, the quantized_add op performs
# requantization, so the relu is implicit in the output quant
# params: the check below ensures the output zero point sits at
# the bottom of the quantized range.
check_out_zero_point_is_min_range(
quant_node.args[2], quant_node.args[5]
)
args, kwargs = get_args_and_kwargs_add(
graph_module,
inputs_inputs,
dequants_inputs,
quant_node,
)
elif isinstance(pattern, AddPattern):
args, kwargs = get_args_and_kwargs_add(
graph_module,
inputs_inputs,
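
Why the relu can be absorbed: with an affine scheme, dequant(q) = scale * (q - zero_point), and requantization clamps q to [qmin, qmax]. If the output zero point equals qmin, every representable output is >= 0, so the clamp itself performs the ReLU; this is what check_out_zero_point_is_min_range verifies. A minimal numeric sketch of that equivalence (the requantize helper below is illustrative, not the Cadence kernel):

import torch

# Illustrative helper, not the Cadence kernel: quantize, clamp to the
# representable range, and dequantize again.
def requantize(x, scale, zero_point, qmin=-128, qmax=127):
    q = torch.clamp(torch.round(x / scale) + zero_point, qmin, qmax)
    return (q - zero_point) * scale

x = torch.tensor([-2.0, -0.5, 0.0, 0.5, 2.0])
scale, zero_point = 2.0 / 255, -128  # output zero point == qmin

# With zero_point pinned at qmin, clamping discards exactly the values a
# trailing relu would have discarded, so the two paths agree bit-for-bit.
assert torch.equal(requantize(x, scale, zero_point),
                   requantize(torch.relu(x), scale, zero_point))
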
55 changes: 55 additions & 0 deletions backends/cadence/aot/quantizer/patterns.py
@@ -153,6 +153,61 @@ def replacement_op(self) -> OpOverload:
return torch.ops.cadence.quantized_add.per_tensor


# Base class for Add+ReLU fusion, shared by the two relu aten op variants
# (relu and relu_)
class AddReluBasePattern(QuantizationPattern):
@abstractmethod
def partition_types(self) -> List[OpOverload]:
pass

def get_anchors(
self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
) -> Tuple[PartitionAnchors, fx.Node]:
# The first node should be add, the second should be relu
# pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
add_node = fused_partition[0].nodes[-1]
# pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
relu_node = fused_partition[1].nodes[-1]

# Bail if:
# - the add node is not a tensor add
# - the add node has kwargs (e.g. alpha)
is_tensor_add = isinstance(add_node.args[0], fx.Node) and isinstance(
add_node.args[1], fx.Node
)
if not is_tensor_add or len(add_node.kwargs) > 0:
return (
PartitionAnchors(
empty=True,
),
add_node,
)

return (
PartitionAnchors(
inputs=[(add_node, 0), (add_node, 1)],
weights=[],
biases=[],
output=[(relu_node,)], # Output is from the relu node
),
relu_node,
)

def replacement_op(self) -> OpOverload:
return torch.ops.cadence.quantized_add.per_tensor


# Add + out-of-place relu op fusion
class AddReluPattern0(AddReluBasePattern):
def partition_types(self) -> List[OpOverload]:
return [torch.ops.aten.add.Tensor, torch.ops.aten.relu.default]


# Add + in-place relu op (relu_) fusion
class AddReluPattern1(AddReluBasePattern):
def partition_types(self) -> List[OpOverload]:
return [torch.ops.aten.add.Tensor, torch.ops.aten.relu_.default]


class BmmPattern(QuantizationPattern):
def partition_types(self) -> List[OpOverload]:
return [torch.ops.aten.bmm.default]
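
For reference, a minimal sketch (module names are illustrative) of the eager code these two variants are meant to catch:

import torch

class AddRelu(torch.nn.Module):
    def forward(self, x, y):
        # Traces to aten.add.Tensor followed by aten.relu.default
        # (AddReluPattern0).
        return torch.relu(x + y)

class AddReluInplace(torch.nn.Module):
    def forward(self, x, y):
        # Depending on the capture pipeline, the in-place relu may be kept
        # as aten.relu_.default (AddReluPattern1).
        return torch.relu_(x + y)

Covering both ops means models written with in-place activations are not silently excluded from the fusion.
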
4 changes: 4 additions & 0 deletions backends/cadence/aot/quantizer/quantizer.py
@@ -13,6 +13,8 @@
from executorch.backends.cadence.aot.quantizer.patterns import (
AddmmPattern,
AddPattern,
AddReluPattern0,
AddReluPattern1,
BmmPattern,
CatPattern,
Conv1dPattern,
@@ -398,6 +400,8 @@ def __init__(
quantizers.append(CadenceAtenQuantizer(Conv1dReluPattern1(), a8w8sym))
quantizers.append(CadenceAtenQuantizer(Conv2dReluPattern0(), a8w8sym))
quantizers.append(CadenceAtenQuantizer(Conv2dReluPattern1(), a8w8sym))
quantizers.append(CadenceAtenQuantizer(AddReluPattern0(), a8w8))
quantizers.append(CadenceAtenQuantizer(AddReluPattern1(), a8w8))
quantizers = quantizers + get_cadence_default_quantizers(is_qat=is_qat)
quantizers.append(CadenceAtenQuantizer(AddPattern(), a8w8))
quantizers.append(CadenceAtenQuantizer(CatPattern(), a8w8))
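
Note the ordering: the two AddRelu quantizers are appended before the default quantizers and before the standalone AddPattern, so the fused pattern gets the first chance to annotate the add/relu nodes. A hedged end-to-end sketch of exercising this with the standard PT2E flow, assuming CadenceFusedConvReluQuantizer (the class the tests below use) is importable from this module:

import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

# Assumption: this is where the quantizer configured above is exported from.
from executorch.backends.cadence.aot.quantizer.quantizer import (
    CadenceFusedConvReluQuantizer,
)

class AddRelu(torch.nn.Module):
    def forward(self, x, y):
        return torch.relu(x + y)

example_inputs = (torch.randn(1, 10), torch.randn(1, 10))
gm = torch.export.export_for_training(AddRelu(), example_inputs).module()

prepared = prepare_pt2e(gm, CadenceFusedConvReluQuantizer())
prepared(*example_inputs)  # calibrate with representative inputs
quantized = convert_pt2e(prepared)
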
53 changes: 53 additions & 0 deletions backends/cadence/aot/tests/test_quantizer_ops.py
@@ -215,6 +215,15 @@
[qconfig_A8W8.input_activation],
),
# CadenceFusedConvReluQuantizer test cases
(
"fused_add_relu_A8W8",
lambda self: self._build_add_relu_graph(),
CadenceFusedConvReluQuantizer(),
torch.ops.aten.relu.default,
qconfig_A8W8.output_activation,
# For fused add+relu: both inputs are activations from the add node
[qconfig_A8W8.input_activation, qconfig_A8W8.input_activation],
),
(
"fused_conv1d_relu_A8W8sym",
lambda self: self._build_conv1d_relu_graph(),
@@ -508,6 +517,50 @@ def _build_max_pool2d_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
)
return gm, max_pool_nodes[0]

def _build_add_relu_graph(
self,
) -> tuple[torch.fx.GraphModule, torch.fx.Node, torch.fx.Node]:
"""Build a graph with an add followed by relu (fused pattern).

Returns:
A tuple of (graph_module, relu_node, add_node).
The relu_node is the target node where the annotation is placed.
The add_node is the input source node whose args contain the quantized inputs.
"""
builder = GraphBuilder()
x = builder.placeholder("x", torch.randn(1, 10))
y = builder.placeholder("y", torch.randn(1, 10))
add = builder.call_operator(
op=torch.ops.aten.add.Tensor,
args=(x, y),
meta=NodeMetadata(
{"source_fn_stack": [("add", torch.ops.aten.add.Tensor)]}
),
)
relu = builder.call_operator(
op=torch.ops.aten.relu.default,
args=(add,),
meta=NodeMetadata(
{"source_fn_stack": [("relu", torch.ops.aten.relu.default)]}
),
)
builder.output([relu])
gm = builder.get_graph_module()

relu_nodes = gm.graph.find_nodes(
op="call_function",
target=torch.ops.aten.relu.default,
)
self.assertEqual(len(relu_nodes), 1, "Should find exactly one relu node")

add_nodes = gm.graph.find_nodes(
op="call_function",
target=torch.ops.aten.add.Tensor,
)
self.assertEqual(len(add_nodes), 1, "Should find exactly one add node")

return gm, relu_nodes[0], add_nodes[0]

def _build_conv2d_relu_graph(
self,
) -> tuple[torch.fx.GraphModule, torch.fx.Node, torch.fx.Node]:
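
For context, the parametrized case added above amounts to an annotation check along these lines; this is a sketch against the standard PT2E annotation metadata (field names from torch.ao.quantization.quantizer), not the test's exact helper:

import torch
from torch.ao.quantization.quantizer import QuantizationAnnotation

def check_add_relu_annotation(relu_node: torch.fx.Node, add_node: torch.fx.Node) -> None:
    # The relu node, as the pattern output, should carry the output qspec.
    relu_ann: QuantizationAnnotation = relu_node.meta["quantization_annotation"]
    assert relu_ann.output_qspec is not None
    # The add node should map both of its inputs to activation qspecs.
    add_ann: QuantizationAnnotation = add_node.meta["quantization_annotation"]
    assert len(add_ann.input_qspec_map) == 2
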