215 | 215 | [qconfig_A8W8.input_activation], |
216 | 216 | ), |
217 | 217 | # CadenceFusedConvReluQuantizer test cases |
| 218 | + ( |
| 219 | + "fused_add_relu_A8W8", |
| 220 | + lambda self: self._build_add_relu_graph(), |
| 221 | + CadenceFusedConvReluQuantizer(), |
| 222 | + torch.ops.aten.relu.default, |
| 223 | + qconfig_A8W8.output_activation, |
| 224 | + # For fused add+relu: both inputs are activations from the add node |
| 225 | + [qconfig_A8W8.input_activation, qconfig_A8W8.input_activation], |
| 226 | + ), |
218 | 227 | ( |
219 | 228 | "fused_conv1d_relu_A8W8sym", |
220 | 229 | lambda self: self._build_conv1d_relu_graph(), |
@@ -508,6 +517,50 @@ def _build_max_pool2d_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]: |
508 | 517 | ) |
509 | 518 | return gm, max_pool_nodes[0] |
510 | 519 | |
| 520 | + def _build_add_relu_graph( |
| 521 | + self, |
| 522 | + ) -> tuple[torch.fx.GraphModule, torch.fx.Node, torch.fx.Node]: |
| 523 | + """Build a graph with an add followed by relu (fused pattern). |
| 524 | + |
| 525 | + Returns: |
| 526 | + A tuple of (graph_module, relu_node, add_node). |
| 527 | + The relu_node is the target node where the annotation is placed. |
| 528 | + The add_node is the input source node whose args contain the quantized inputs. |
| 529 | + """ |
| 530 | + builder = GraphBuilder() |
| 531 | + x = builder.placeholder("x", torch.randn(1, 10)) |
| 532 | + y = builder.placeholder("y", torch.randn(1, 10)) |
| 533 | + add = builder.call_operator( |
| 534 | + op=torch.ops.aten.add.Tensor, |
| 535 | + args=(x, y), |
| 536 | + meta=NodeMetadata( |
| 537 | + {"source_fn_stack": [("add", torch.ops.aten.add.Tensor)]} |
| 538 | + ), |
| 539 | + ) |
| 540 | + relu = builder.call_operator( |
| 541 | + op=torch.ops.aten.relu.default, |
| 542 | + args=(add,), |
| 543 | + meta=NodeMetadata( |
| 544 | + {"source_fn_stack": [("relu", torch.ops.aten.relu.default)]} |
| 545 | + ), |
| 546 | + ) |
| 547 | + builder.output([relu]) |
| 548 | + gm = builder.get_graph_module() |
| 549 | + |
| 550 | + relu_nodes = gm.graph.find_nodes( |
| 551 | + op="call_function", |
| 552 | + target=torch.ops.aten.relu.default, |
| 553 | + ) |
| 554 | + self.assertEqual(len(relu_nodes), 1, "Should find exactly one relu node") |
| 555 | + |
| 556 | + add_nodes = gm.graph.find_nodes( |
| 557 | + op="call_function", |
| 558 | + target=torch.ops.aten.add.Tensor, |
| 559 | + ) |
| 560 | + self.assertEqual(len(add_nodes), 1, "Should find exactly one add node") |
| 561 | + |
| 562 | + return gm, relu_nodes[0], add_nodes[0] |
| 563 | + |
511 | 564 | def _build_conv2d_relu_graph( |
512 | 565 | self, |
513 | 566 | ) -> tuple[torch.fx.GraphModule, torch.fx.Node, torch.fx.Node]: |
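For context, a minimal sketch of how a parameterized runner might consume one of these case tuples of (name, build_graph, quantizer, target_op, output_qspec, input_qspecs). Everything named below (run_case, the annotate() call, the "quantization_annotation" meta key) is an assumption for illustration, not taken from this diff; the actual harness in the test file may differ.

    def run_case(self, name, build_graph, quantizer, target_op, output_qspec, input_qspecs):
        # Assumes a 3-tuple builder such as _build_add_relu_graph above.
        gm, target_node, source_node = build_graph(self)
        quantizer.annotate(gm)  # hypothetical entry point for the quantizer
        self.assertEqual(target_node.target, target_op)
        # "quantization_annotation" is the meta key used by torch.ao quantizers;
        # assumed (not confirmed by this diff) to apply here.
        annotation = target_node.meta["quantization_annotation"]
        self.assertEqual(annotation.output_qspec, output_qspec)
        # One expected input qspec per arg of the source (add) node.
        for arg, expected in zip(source_node.args, input_qspecs):
            self.assertEqual(annotation.input_qspec_map[arg], expected)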