
Commit bf2243a

Move optimization passes from opt_level=0 to opt_level=1 (#18206)
Differential Revision: D96766073
Pull Request resolved: #18206
1 parent fb90480 commit bf2243a

19 files changed

Lines changed: 613 additions & 50 deletions

backends/cadence/aot/BUCK

Lines changed: 1 addition & 0 deletions
@@ -300,6 +300,7 @@ fbcode_target(_kind = runtime.python_library,
     ],
     typing = True,
     deps = [
+        ":fuse_ops",
         ":ops_registrations",
         "//caffe2:torch",
         "//executorch/backends/cadence/aot:pass_utils",

backends/cadence/aot/decompose_ops.py

Lines changed: 4 additions & 2 deletions
@@ -23,10 +23,12 @@
 from torch.fx.node import Argument


-@register_cadence_pass(CadencePassAttribute(opt_level=0))
+@register_cadence_pass(CadencePassAttribute(opt_level=1))
 class DecomposeAtenApproxGeluPass(ExportPass):
     """
-    Decompose the aten gelu op with an approximate arg to a series of simpler ops
+    Decompose the aten gelu op with an approximate arg to a series of simpler ops.
+    This is an optimization - gelu has a portable kernel fallback, but decomposing
+    may be more efficient on some backends.
     """

     def call_operator(
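
Note (illustration, not part of this diff): the pass lowers aten gelu with an approximate argument into primitive ops. The standard tanh approximation that such a decomposition targets can be sanity-checked against torch's own kernel:

# Illustrative only. Verifies the well-known tanh approximation of GELU,
# i.e. the kind of "series of simpler ops" the pass lowers to:
# 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3))).
import math
import torch

x = torch.randn(1024)
reference = torch.nn.functional.gelu(x, approximate="tanh")
decomposed = 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x.pow(3))))
assert torch.allclose(reference, decomposed, atol=1e-5)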

backends/cadence/aot/functions.yaml

Lines changed: 7 additions & 2 deletions
@@ -309,10 +309,15 @@
     - arg_meta: null
       kernel_name: impl::generic::quantized_relu_asym8u_asym8u_per_tensor_out

-- func: cadence::quantized_max_pool2d.out(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode, *, Tensor(a!) out) -> Tensor(a!)
+- func: cadence::quantized_max_pool2d_nchw.out(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode, *, Tensor(a!) out) -> Tensor(a!)
   kernels:
     - arg_meta: null
-      kernel_name: impl::generic::quantized_max_pool2d_out
+      kernel_name: impl::generic::quantized_max_pool2d_nchw_out
+
+- func: cadence::quantized_max_pool2d_nhwc.out(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode, *, Tensor(a!) out) -> Tensor(a!)
+  kernels:
+    - arg_meta: null
+      kernel_name: impl::generic::quantized_max_pool2d_nhwc_out

 - func: cadence::quantized_add.per_tensor_out(Tensor X, float X_scale, int X_zero_point, Tensor Y, float Y_scale, int Y_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!)
   kernels:

backends/cadence/aot/fuse_ops.py

Lines changed: 4 additions & 1 deletion
@@ -1170,7 +1170,10 @@ def can_fuse_for_chain(
        return False

    # checking that permut2(permut1(identity)) == identity, modulo unitary dimensions
-    input_shape = cast(torch.fx.Node, producer.args[0]).meta["val"].shape
+    producer_input = cast(torch.fx.Node, producer.args[0])
+    if "val" not in producer_input.meta:
+        return False
+    input_shape = producer_input.meta["val"].shape
    ident_dims = list(range(len(input_shape)))
    # this mapping helps to handle both transpose and permutations
    f: dict[Any, Callable] = {
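
Note (illustration, not part of the diff): the added guard matters because node.meta["val"] is only populated by shape propagation or export, not by tracing alone. A small sketch of the failure mode the guard avoids:

# Illustrative only - shows why checking `"val" in node.meta` is safer than
# indexing node.meta["val"] directly: plain symbolic tracing does not attach
# FakeTensor values to nodes, so the key may simply be absent.
import torch
import torch.fx


def double_permute(x: torch.Tensor) -> torch.Tensor:
    return x.permute(1, 0).permute(1, 0)


gm = torch.fx.symbolic_trace(double_permute)
for node in gm.graph.nodes:
    # Without shape propagation, "val" is missing and .meta["val"] would raise KeyError.
    print(node.name, "val" in node.meta)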

backends/cadence/aot/ops_registrations.py

Lines changed: 51 additions & 4 deletions
@@ -214,10 +214,16 @@ def register_fake(
 )

 lib.define(
-    "quantized_max_pool2d(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> Tensor"
+    "quantized_max_pool2d_nchw(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> Tensor"
 )
 lib.define(
-    "quantized_max_pool2d.out(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode, *, Tensor(a!) out) -> Tensor(a!)"
+    "quantized_max_pool2d_nchw.out(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode, *, Tensor(a!) out) -> Tensor(a!)"
+)
+lib.define(
+    "quantized_max_pool2d_nhwc(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> Tensor"
+)
+lib.define(
+    "quantized_max_pool2d_nhwc.out(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode, *, Tensor(a!) out) -> Tensor(a!)"
 )

 lib.define(

@@ -2277,8 +2283,8 @@ def quantized_relu_asym8u_asym8u_per_tensor_meta(
     return input.new_empty(input.size(), dtype=input.dtype)


-@register_fake("cadence::quantized_max_pool2d")
-def quantized_max_pool2d_meta(
+@register_fake("cadence::quantized_max_pool2d_nchw")
+def quantized_max_pool2d_nchw_meta(
     input: torch.Tensor,
     kernel_size: list[int],
     stride: list[int],

@@ -2318,6 +2324,47 @@ def quantized_max_pool2d_meta(
     return input.new_empty([batch, channels, height_out, width_out], dtype=input.dtype)


+@register_fake("cadence::quantized_max_pool2d_nhwc")
+def quantized_max_pool2d_nhwc_meta(
+    input: torch.Tensor,
+    kernel_size: list[int],
+    stride: list[int],
+    padding: list[int],
+    dilation: list[int],
+    ceil_mode: bool,
+) -> torch.Tensor:
+    assert (
+        len(kernel_size) == 2
+    ), f"kernel_size must have 2 elements, got {len(kernel_size)}"
+    assert len(stride) == 2, f"stride must have 2 elements, got {len(stride)}"
+    assert len(padding) == 2, f"padding must have 2 elements, got {len(padding)}"
+    assert len(dilation) == 2, f"dilation must have 2 elements, got {len(dilation)}"
+    assert (
+        len(input.size()) == 4
+    ), f"input must be 4D (N, H, W, C), got {len(input.size())}D"
+
+    batch = input.size(0)
+    height_in = input.size(1)
+    width_in = input.size(2)
+    channels = input.size(3)
+
+    height_out_raw = (
+        height_in + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1
+    ) / stride[0] + 1
+    width_out_raw = (
+        width_in + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1
+    ) / stride[1] + 1
+
+    if ceil_mode:
+        height_out = ceil(height_out_raw)
+        width_out = ceil(width_out_raw)
+    else:
+        height_out = int(height_out_raw)
+        width_out = int(width_out_raw)
+
+    return input.new_empty([batch, height_out, width_out, channels], dtype=input.dtype)
+
+
 @register_fake("cadence::fully_connected")
 def fully_connected_meta(
     src: torch.Tensor,
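
For reference (not from the diff), the output-size computation in both meta functions is the standard pooling arithmetic. A quick worked example with assumed values:

# Illustrative only - evaluates the same output-size arithmetic the meta
# functions use, for a hypothetical input height of 8 with a 3x3 kernel,
# stride 2, padding 1, dilation 1.
from math import ceil

height_in, kernel, stride, padding, dilation = 8, 3, 2, 1, 1
raw = (height_in + 2 * padding - dilation * (kernel - 1) - 1) / stride + 1
print(raw)        # 4.5
print(int(raw))   # 4  (ceil_mode=False truncates)
print(ceil(raw))  # 5  (ceil_mode=True rounds up)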

backends/cadence/aot/passes.py

Lines changed: 2 additions & 0 deletions
@@ -25,6 +25,7 @@
 from executorch.backends.cadence.aot.remove_ops import (
     CadenceRemoveNops,
     RemoveNopSliceOrViewOpPass,
+    RemovePermutesAroundElementwiseOps,
     RemoveRedundantOps,
 )
 from executorch.backends.cadence.aot.reorder_ops import CadenceReorderOpsInGraph

@@ -89,6 +90,7 @@ def get_passes_in_default_order() -> list[Type[ExportPass]]:
         CadenceSimplifyOpsInGraph.passes,
         FinalizePipeline,
         FuseFullThenReshapePass,
+        RemovePermutesAroundElementwiseOps,
         FuseTransposeOrPermuteOpPairsPass,
         RemoveNopSliceOrViewOpPass,
         CompileTimeTypeDispatchPass,
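
Note (sketch, not the actual Cadence implementation): the effect of moving a pass from opt_level=0 to opt_level=1, as this commit does throughout, is that it drops out of the minimal pipeline and only runs when the requested optimization level is at least 1. The real gating lives in pass_utils and the pass manager; a minimal stand-in showing the idea:

# Minimal sketch under assumptions: PassAttr and REGISTRY are hypothetical
# stand-ins for CadencePassAttribute and the real pass registry. Passes whose
# declared opt_level exceeds the requested level are simply skipped.
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class PassAttr:  # hypothetical stand-in for CadencePassAttribute
    opt_level: int


REGISTRY: List[Tuple[str, PassAttr]] = [
    ("RemoveToOpsPass", PassAttr(opt_level=1)),            # after this change
    ("SomeMandatoryLoweringPass", PassAttr(opt_level=0)),   # hypothetical always-on pass
]


def passes_for(requested_level: int) -> List[str]:
    # Keep only passes whose opt_level is within the requested level.
    return [name for name, attr in REGISTRY if attr.opt_level <= requested_level]


print(passes_for(0))  # only the opt_level=0 pass
print(passes_for(1))  # both passes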

backends/cadence/aot/quantizer/patterns.py

Lines changed: 5 additions & 2 deletions
@@ -459,7 +459,7 @@ def get_anchors(
         )

     def replacement_op(self) -> OpOverload:
-        return torch.ops.cadence.quantized_max_pool2d.default
+        return torch.ops.cadence.quantized_max_pool2d_nchw.default


 class MaxPool2dWithoutIndicesPattern(QuantizationPattern):

@@ -498,7 +498,10 @@ def get_anchors(
         )

     def replacement_op(self) -> OpOverload:
-        return torch.ops.cadence.quantized_max_pool2d.default
+        return torch.ops.cadence.quantized_max_pool2d_nchw.default
+
+
+# This is a base class for ReLU


 # This is a base class for ReLU, since it can be used with two different aten ops

backends/cadence/aot/ref_implementations.py

Lines changed: 33 additions & 2 deletions
@@ -1868,8 +1868,8 @@ def rms_norm(
     return W * nn.RMSNorm(list(normalized_shape), eps=eps, dtype=X.dtype)(X)


-@impl_tracked(m, "quantized_max_pool2d")
-def quantized_max_pool2d(
+@impl_tracked(m, "quantized_max_pool2d_nchw")
+def quantized_max_pool2d_nchw(
     input: torch.Tensor,
     kernel_size: list[int],
     stride: list[int],

@@ -1897,6 +1897,37 @@ def quantized_max_pool2d(
     )


+@impl_tracked(m, "quantized_max_pool2d_nhwc")
+def quantized_max_pool2d_nhwc(
+    input: torch.Tensor,
+    kernel_size: list[int],
+    stride: list[int],
+    padding: list[int],
+    dilation: list[int],
+    ceil_mode: bool,
+) -> torch.Tensor:
+    """
+    Quantized max pooling in NHWC layout.
+
+    Converts NHWC→NCHW, performs max pooling, then converts back NCHW→NHWC.
+    """
+    # Convert NHWC [N, H, W, C] to NCHW [N, C, H, W]
+    input_nchw = input.permute(0, 3, 1, 2).contiguous()
+
+    # Call the NCHW version
+    output_nchw = quantized_max_pool2d_nchw(
+        input_nchw,
+        kernel_size=kernel_size,
+        stride=stride,
+        padding=padding,
+        dilation=dilation,
+        ceil_mode=ceil_mode,
+    )
+
+    # Convert NCHW [N, C, H_out, W_out] back to NHWC [N, H_out, W_out, C]
+    return output_nchw.permute(0, 2, 3, 1).contiguous()
+
+
 @impl_tracked(m, "where_Scalar")
 def where_Scalar(
     condition: torch.Tensor,
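
Note (illustration, not from the diff): the NHWC reference implementation above is purely a layout wrapper around the NCHW one. The same permute/pool/permute structure can be checked with torch's unquantized pooling op:

# Illustrative only - mirrors the permute/pool/permute structure of the NHWC
# wrapper using torch's unquantized max_pool2d (the cadence reference calls
# its quantized NCHW implementation instead).
import torch
import torch.nn.functional as F

nhwc = torch.randn(1, 8, 8, 4)                    # [N, H, W, C]
nchw = nhwc.permute(0, 3, 1, 2).contiguous()      # [N, C, H, W]
pooled_nchw = F.max_pool2d(nchw, kernel_size=2, stride=2)
pooled_nhwc = pooled_nchw.permute(0, 2, 3, 1).contiguous()
print(pooled_nhwc.shape)  # torch.Size([1, 4, 4, 4])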

backends/cadence/aot/remove_ops.py

Lines changed: 15 additions & 10 deletions
@@ -14,14 +14,15 @@

 import torch
 import torch.fx
+
+from executorch.backends.cadence.aot.fuse_ops import FuseTransposeOrPermuteOpPairsPass
 from executorch.backends.cadence.aot.pass_utils import (
     CadencePassAttribute,
     get_arg,
     register_cadence_pass,
     RemoveOrReplacePassInterface,
     set_arg,
 )
-
 from executorch.backends.cadence.aot.simplify_ops import SimplifySliceOpPass
 from executorch.backends.cadence.aot.utils import get_edge_overload_packet
 from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform

@@ -33,7 +34,7 @@
 from torch.fx.node import Node


-@register_cadence_pass(CadencePassAttribute(opt_level=0))
+@register_cadence_pass(CadencePassAttribute(opt_level=1))
 class RemoveCloneOpsTransformImported(ExportPass):
     def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
         finalize_passes: List[PassType] = [

@@ -44,7 +45,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
         return result


-@register_cadence_pass(CadencePassAttribute(opt_level=0))
+@register_cadence_pass(CadencePassAttribute(opt_level=1))
 class RemoveDetachCopyPass(RemoveOrReplacePassInterface):
     @property
     def targets(self) -> list[EdgeOpOverload]:

@@ -66,7 +67,7 @@ class RemoveRedundantOps:
     ]


-@register_cadence_pass(CadencePassAttribute(opt_level=0))
+@register_cadence_pass(CadencePassAttribute(opt_level=1))
 class RemoveZeroSizedCatArgsPass(RemoveOrReplacePassInterface):
     @property
     def targets(self) -> list[EdgeOpOverload]:

@@ -120,11 +121,11 @@ def maybe_remove_or_replace(self, node: Node) -> bool:
         return False


-@register_cadence_pass(CadencePassAttribute(opt_level=0))
+@register_cadence_pass(CadencePassAttribute(opt_level=1))
 class RemoveNopExpandOpPass(RemoveOrReplacePassInterface):
     """
     For an expand op, if the operator shape matches the expand shape, then the
-    expand is a nop.
+    expand is a nop. This is an optimization that removes unnecessary ops.
     """

     @property

@@ -143,9 +144,9 @@ def maybe_remove_or_replace(self, node: Node) -> bool:
         return False


-@register_cadence_pass(CadencePassAttribute(opt_level=0))
+@register_cadence_pass(CadencePassAttribute(opt_level=1))
 class RemoveToOpsPass(RemoveOrReplacePassInterface):
-    # aten.to.* as of now are all nops
+    # aten.to.* ops are no-ops in inference - this is an optimization
     @property
     def targets(self) -> list[EdgeOpOverload]:
         return [

@@ -264,11 +265,11 @@ def maybe_remove_or_replace(self, node: Node) -> bool:
         return True


-@register_cadence_pass(CadencePassAttribute(opt_level=0))
+@register_cadence_pass(CadencePassAttribute(opt_level=1))
 class RemoveAliasCopyOpPass(RemoveOrReplacePassInterface):
     """
-
     alias_copy is a no-op and can be removed.
+    This is an optimization that removes unnecessary ops.
     """

     @property

@@ -412,6 +413,9 @@ class Subgraph:
         exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
         exir_ops.edge.cadence.quantize_per_tensor.default,
         exir_ops.edge.cadence.dequantize_per_tensor.default,
+        exir_ops.edge.cadence.quantized_relu.per_tensor,
+        exir_ops.edge.cadence.requantize.per_tensor,
+        exir_ops.edge.cadence.quantized_add.per_tensor,
         # Ops that require special handling.
         exir_ops.edge.aten.cat.default,
         exir_ops.edge.aten.mean.dim,

@@ -804,6 +808,7 @@ class CommonRemovePasses:
         RemoveToOpsPass,
         RemoveZeroSizedCatArgsPass,
         RemovePermutesAroundElementwiseOps,
+        FuseTransposeOrPermuteOpPairsPass,
         RemoveSqueezeViewBeforeElementwiseOps,
         RemoveCatFromSliceCopyPass,
         RemoveCloneOpsTransformImported,
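
Note (illustration, not from the diff): the removal passes touched here target ops that are no-ops at inference time, e.g. an expand to the tensor's own shape (RemoveNopExpandOpPass) or a dtype-preserving aten.to (RemoveToOpsPass):

# Illustrative only - the kinds of no-ops these passes eliminate from the
# graph: both calls below leave the tensor unchanged.
import torch

x = torch.randn(2, 3)

# expand to the tensor's existing shape changes nothing
assert torch.equal(x.expand(2, 3), x)

# .to() with the tensor's existing dtype/device returns the same tensor object
assert x.to(torch.float32) is x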
