
Commit 2548ee1

Revert "Move optimization passes from opt_level=0 to opt_level=1 (pytorch#18206)" (pytorch#18331)
Reverts pytorch#18206
1 parent 9076110 commit 2548ee1

19 files changed, with 50 additions and 613 deletions.
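For orientation: a Cadence pass registered with @register_cadence_pass(CadencePassAttribute(opt_level=N)) is only scheduled when the pipeline is built with an optimization level of at least N, so moving passes from opt_level=0 to opt_level=1 (the change being reverted here) effectively dropped them from the default level-0 pipeline. A minimal sketch of that gating; the registry and passes_for_level helper below are illustrative, not the actual pass_utils implementation:

from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
class CadencePassAttribute:
    opt_level: Optional[int] = None


# Illustrative registry mapping a pass class to its attribute.
PASS_REGISTRY: dict[type, CadencePassAttribute] = {}


def register_cadence_pass(attr: CadencePassAttribute) -> Callable[[type], type]:
    def wrapper(cls: type) -> type:
        PASS_REGISTRY[cls] = attr
        return cls

    return wrapper


def passes_for_level(opt_level: int) -> list[type]:
    # A pass runs only when its registered opt_level is <= the requested level,
    # so a pass registered at opt_level=1 never runs in an opt_level=0 build.
    return [
        cls
        for cls, attr in PASS_REGISTRY.items()
        if attr.opt_level is not None and attr.opt_level <= opt_level
    ]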

backends/cadence/aot/BUCK

Lines changed: 0 additions & 1 deletion
@@ -300,7 +300,6 @@ fbcode_target(_kind = runtime.python_library,
     ],
     typing = True,
     deps = [
-        ":fuse_ops",
         ":ops_registrations",
         "//caffe2:torch",
         "//executorch/backends/cadence/aot:pass_utils",

backends/cadence/aot/decompose_ops.py

Lines changed: 2 additions & 4 deletions
@@ -23,12 +23,10 @@
 from torch.fx.node import Argument


-@register_cadence_pass(CadencePassAttribute(opt_level=1))
+@register_cadence_pass(CadencePassAttribute(opt_level=0))
 class DecomposeAtenApproxGeluPass(ExportPass):
     """
-    Decompose the aten gelu op with an approximate arg to a series of simpler ops.
-    This is an optimization - gelu has a portable kernel fallback, but decomposing
-    may be more efficient on some backends.
+    Decompose the aten gelu op with an approximate arg to a series of simpler ops
     """

     def call_operator(
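The docstring above refers to aten.gelu with an approximate argument; the standard "tanh" approximation expands into a handful of elementwise ops. A rough sketch of the arithmetic such a decomposition produces (this is the textbook formula, not the pass itself):

import math

import torch


def gelu_tanh_approx(x: torch.Tensor) -> torch.Tensor:
    # gelu(x, approximate="tanh") ~= 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x.pow(3))
    return 0.5 * x * (1.0 + torch.tanh(inner))


x = torch.randn(8)
assert torch.allclose(
    gelu_tanh_approx(x), torch.nn.functional.gelu(x, approximate="tanh"), atol=1e-6
)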

backends/cadence/aot/functions.yaml

Lines changed: 2 additions & 7 deletions
@@ -309,15 +309,10 @@
     - arg_meta: null
       kernel_name: impl::generic::quantized_relu_asym8u_asym8u_per_tensor_out

-- func: cadence::quantized_max_pool2d_nchw.out(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode, *, Tensor(a!) out) -> Tensor(a!)
+- func: cadence::quantized_max_pool2d.out(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode, *, Tensor(a!) out) -> Tensor(a!)
   kernels:
     - arg_meta: null
-      kernel_name: impl::generic::quantized_max_pool2d_nchw_out
-
-- func: cadence::quantized_max_pool2d_nhwc.out(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode, *, Tensor(a!) out) -> Tensor(a!)
-  kernels:
-    - arg_meta: null
-      kernel_name: impl::generic::quantized_max_pool2d_nhwc_out
+      kernel_name: impl::generic::quantized_max_pool2d_out

 - func: cadence::quantized_add.per_tensor_out(Tensor X, float X_scale, int X_zero_point, Tensor Y, float Y_scale, int Y_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!)
   kernels:

backends/cadence/aot/fuse_ops.py

Lines changed: 1 addition & 4 deletions
@@ -1170,10 +1170,7 @@ def can_fuse_for_chain(
             return False

         # checking that permut2(permut1(identity)) == identity, modulo unitary dimensions
-        producer_input = cast(torch.fx.Node, producer.args[0])
-        if "val" not in producer_input.meta:
-            return False
-        input_shape = producer_input.meta["val"].shape
+        input_shape = cast(torch.fx.Node, producer.args[0]).meta["val"].shape
         ident_dims = list(range(len(input_shape)))
         # this mapping helps to handle both transpose and permutations
         f: dict[Any, Callable] = {
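For reference, the check mentioned in the comment above asks whether applying the producer's permutation and then the consumer's gives back the original dimension order, treating moves that only shuffle size-1 dimensions as harmless. A toy standalone version of that idea; the function name and the exact handling of unitary dims are illustrative, not the pass's code:

def composition_is_identity(perm1: list[int], perm2: list[int], shape: list[int]) -> bool:
    dims = list(range(len(shape)))
    after_first = [dims[i] for i in perm1]      # apply the first permute
    composed = [after_first[i] for i in perm2]  # then the second
    # Accept mismatches only where both positions hold size-1 dimensions.
    return all(
        c == d or (shape[c] == 1 and shape[d] == 1) for c, d in zip(composed, dims)
    )


# permute(0, 2, 1) applied twice is a round trip:
assert composition_is_identity([0, 2, 1], [0, 2, 1], [2, 3, 4])
# Swapping two size-1 dims is a no-op "modulo unitary dimensions":
assert composition_is_identity([0, 2, 1, 3], [0, 1, 2, 3], [2, 1, 1, 5])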

backends/cadence/aot/ops_registrations.py

Lines changed: 4 additions & 51 deletions
@@ -214,16 +214,10 @@ def register_fake(
 )

 lib.define(
-    "quantized_max_pool2d_nchw(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> Tensor"
+    "quantized_max_pool2d(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> Tensor"
 )
 lib.define(
-    "quantized_max_pool2d_nchw.out(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode, *, Tensor(a!) out) -> Tensor(a!)"
-)
-lib.define(
-    "quantized_max_pool2d_nhwc(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> Tensor"
-)
-lib.define(
-    "quantized_max_pool2d_nhwc.out(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode, *, Tensor(a!) out) -> Tensor(a!)"
+    "quantized_max_pool2d.out(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode, *, Tensor(a!) out) -> Tensor(a!)"
 )

 lib.define(
@@ -2283,8 +2277,8 @@ def quantized_relu_asym8u_asym8u_per_tensor_meta(
     return input.new_empty(input.size(), dtype=input.dtype)


-@register_fake("cadence::quantized_max_pool2d_nchw")
-def quantized_max_pool2d_nchw_meta(
+@register_fake("cadence::quantized_max_pool2d")
+def quantized_max_pool2d_meta(
     input: torch.Tensor,
     kernel_size: list[int],
     stride: list[int],
@@ -2324,47 +2318,6 @@ def quantized_max_pool2d_nchw_meta(
     return input.new_empty([batch, channels, height_out, width_out], dtype=input.dtype)


-@register_fake("cadence::quantized_max_pool2d_nhwc")
-def quantized_max_pool2d_nhwc_meta(
-    input: torch.Tensor,
-    kernel_size: list[int],
-    stride: list[int],
-    padding: list[int],
-    dilation: list[int],
-    ceil_mode: bool,
-) -> torch.Tensor:
-    assert (
-        len(kernel_size) == 2
-    ), f"kernel_size must have 2 elements, got {len(kernel_size)}"
-    assert len(stride) == 2, f"stride must have 2 elements, got {len(stride)}"
-    assert len(padding) == 2, f"padding must have 2 elements, got {len(padding)}"
-    assert len(dilation) == 2, f"dilation must have 2 elements, got {len(dilation)}"
-    assert (
-        len(input.size()) == 4
-    ), f"input must be 4D (N, H, W, C), got {len(input.size())}D"
-
-    batch = input.size(0)
-    height_in = input.size(1)
-    width_in = input.size(2)
-    channels = input.size(3)
-
-    height_out_raw = (
-        height_in + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1
-    ) / stride[0] + 1
-    width_out_raw = (
-        width_in + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1
-    ) / stride[1] + 1
-
-    if ceil_mode:
-        height_out = ceil(height_out_raw)
-        width_out = ceil(width_out_raw)
-    else:
-        height_out = int(height_out_raw)
-        width_out = int(width_out_raw)
-
-    return input.new_empty([batch, height_out, width_out, channels], dtype=input.dtype)
-
-
 @register_fake("cadence::fully_connected")
 def fully_connected_meta(
     src: torch.Tensor,
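Both meta kernels compute the output spatial size with the usual pooling arithmetic visible in the removed NHWC body: out = (in + 2*padding - dilation*(kernel - 1) - 1) / stride + 1, floored normally and ceiled when ceil_mode is set. A quick worked check against eager PyTorch, with example numbers:

from math import floor

import torch

h_in, kernel, stride, padding, dilation = 32, 3, 2, 1, 1
h_out_raw = (h_in + 2 * padding - dilation * (kernel - 1) - 1) / stride + 1
print(floor(h_out_raw))  # 16 with ceil_mode=False

# Cross-check against eager max_pool2d on an NCHW tensor:
x = torch.randn(1, 4, h_in, h_in)
y = torch.nn.functional.max_pool2d(x, kernel, stride, padding, dilation, ceil_mode=False)
print(y.shape)  # torch.Size([1, 4, 16, 16])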

backends/cadence/aot/passes.py

Lines changed: 0 additions & 2 deletions
@@ -25,7 +25,6 @@
 from executorch.backends.cadence.aot.remove_ops import (
     CadenceRemoveNops,
     RemoveNopSliceOrViewOpPass,
-    RemovePermutesAroundElementwiseOps,
     RemoveRedundantOps,
 )
 from executorch.backends.cadence.aot.reorder_ops import CadenceReorderOpsInGraph
@@ -90,7 +89,6 @@ def get_passes_in_default_order() -> list[Type[ExportPass]]:
         CadenceSimplifyOpsInGraph.passes,
         FinalizePipeline,
         FuseFullThenReshapePass,
-        RemovePermutesAroundElementwiseOps,
         FuseTransposeOrPermuteOpPairsPass,
         RemoveNopSliceOrViewOpPass,
         CompileTimeTypeDispatchPass,

backends/cadence/aot/quantizer/patterns.py

Lines changed: 2 additions & 5 deletions
@@ -459,7 +459,7 @@ def get_anchors(
         )

     def replacement_op(self) -> OpOverload:
-        return torch.ops.cadence.quantized_max_pool2d_nchw.default
+        return torch.ops.cadence.quantized_max_pool2d.default


 class MaxPool2dWithoutIndicesPattern(QuantizationPattern):
@@ -498,10 +498,7 @@ def get_anchors(
         )

     def replacement_op(self) -> OpOverload:
-        return torch.ops.cadence.quantized_max_pool2d_nchw.default
-
-
-# This is a base class for ReLU
+        return torch.ops.cadence.quantized_max_pool2d.default


 # This is a base class for ReLU, since it can be used with two different aten ops

backends/cadence/aot/ref_implementations.py

Lines changed: 2 additions & 33 deletions
@@ -1868,8 +1868,8 @@ def rms_norm(
     return W * nn.RMSNorm(list(normalized_shape), eps=eps, dtype=X.dtype)(X)


-@impl_tracked(m, "quantized_max_pool2d_nchw")
-def quantized_max_pool2d_nchw(
+@impl_tracked(m, "quantized_max_pool2d")
+def quantized_max_pool2d(
     input: torch.Tensor,
     kernel_size: list[int],
     stride: list[int],
@@ -1897,37 +1897,6 @@ def quantized_max_pool2d_nchw(
     )


-@impl_tracked(m, "quantized_max_pool2d_nhwc")
-def quantized_max_pool2d_nhwc(
-    input: torch.Tensor,
-    kernel_size: list[int],
-    stride: list[int],
-    padding: list[int],
-    dilation: list[int],
-    ceil_mode: bool,
-) -> torch.Tensor:
-    """
-    Quantized max pooling in NHWC layout.
-
-    Converts NHWC→NCHW, performs max pooling, then converts back NCHW→NHWC.
-    """
-    # Convert NHWC [N, H, W, C] to NCHW [N, C, H, W]
-    input_nchw = input.permute(0, 3, 1, 2).contiguous()
-
-    # Call the NCHW version
-    output_nchw = quantized_max_pool2d_nchw(
-        input_nchw,
-        kernel_size=kernel_size,
-        stride=stride,
-        padding=padding,
-        dilation=dilation,
-        ceil_mode=ceil_mode,
-    )
-
-    # Convert NCHW [N, C, H_out, W_out] back to NHWC [N, H_out, W_out, C]
-    return output_nchw.permute(0, 2, 3, 1).contiguous()
-
-
 @impl_tracked(m, "where_Scalar")
 def where_Scalar(
     condition: torch.Tensor,
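The removed NHWC wrapper leans on permute(0, 3, 1, 2) and permute(0, 2, 3, 1) being inverses of each other; a tiny round-trip check:

import torch

x = torch.arange(2 * 4 * 5 * 3).reshape(2, 4, 5, 3)  # NHWC: [N, H, W, C]
nchw = x.permute(0, 3, 1, 2).contiguous()             # -> [N, C, H, W]
back = nchw.permute(0, 2, 3, 1).contiguous()          # -> [N, H, W, C] again
assert back.shape == x.shape
assert torch.equal(back, x)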

backends/cadence/aot/remove_ops.py

Lines changed: 10 additions & 15 deletions
@@ -14,15 +14,14 @@

 import torch
 import torch.fx
-
-from executorch.backends.cadence.aot.fuse_ops import FuseTransposeOrPermuteOpPairsPass
 from executorch.backends.cadence.aot.pass_utils import (
     CadencePassAttribute,
     get_arg,
     register_cadence_pass,
     RemoveOrReplacePassInterface,
     set_arg,
 )
+
 from executorch.backends.cadence.aot.simplify_ops import SimplifySliceOpPass
 from executorch.backends.cadence.aot.utils import get_edge_overload_packet
 from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform
@@ -34,7 +33,7 @@
 from torch.fx.node import Node


-@register_cadence_pass(CadencePassAttribute(opt_level=1))
+@register_cadence_pass(CadencePassAttribute(opt_level=0))
 class RemoveCloneOpsTransformImported(ExportPass):
     def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
         finalize_passes: List[PassType] = [
@@ -45,7 +44,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
         return result


-@register_cadence_pass(CadencePassAttribute(opt_level=1))
+@register_cadence_pass(CadencePassAttribute(opt_level=0))
 class RemoveDetachCopyPass(RemoveOrReplacePassInterface):
     @property
     def targets(self) -> list[EdgeOpOverload]:
@@ -67,7 +66,7 @@ class RemoveRedundantOps:
     ]


-@register_cadence_pass(CadencePassAttribute(opt_level=1))
+@register_cadence_pass(CadencePassAttribute(opt_level=0))
 class RemoveZeroSizedCatArgsPass(RemoveOrReplacePassInterface):
     @property
     def targets(self) -> list[EdgeOpOverload]:
@@ -121,11 +120,11 @@ def maybe_remove_or_replace(self, node: Node) -> bool:
         return False


-@register_cadence_pass(CadencePassAttribute(opt_level=1))
+@register_cadence_pass(CadencePassAttribute(opt_level=0))
 class RemoveNopExpandOpPass(RemoveOrReplacePassInterface):
     """
     For an expand op, if the operator shape matches the expand shape, then the
-    expand is a nop. This is an optimization that removes unnecessary ops.
+    expand is a nop.
     """

     @property
@@ -144,9 +143,9 @@ def maybe_remove_or_replace(self, node: Node) -> bool:
         return False


-@register_cadence_pass(CadencePassAttribute(opt_level=1))
+@register_cadence_pass(CadencePassAttribute(opt_level=0))
 class RemoveToOpsPass(RemoveOrReplacePassInterface):
-    # aten.to.* ops are no-ops in inference - this is an optimization
+    # aten.to.* as of now are all nops
     @property
     def targets(self) -> list[EdgeOpOverload]:
         return [
@@ -265,11 +264,11 @@ def maybe_remove_or_replace(self, node: Node) -> bool:
         return True


-@register_cadence_pass(CadencePassAttribute(opt_level=1))
+@register_cadence_pass(CadencePassAttribute(opt_level=0))
 class RemoveAliasCopyOpPass(RemoveOrReplacePassInterface):
     """
+
     alias_copy is a no-op and can be removed.
-    This is an optimization that removes unnecessary ops.
     """

     @property
@@ -413,9 +412,6 @@ class Subgraph:
         exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
         exir_ops.edge.cadence.quantize_per_tensor.default,
         exir_ops.edge.cadence.dequantize_per_tensor.default,
-        exir_ops.edge.cadence.quantized_relu.per_tensor,
-        exir_ops.edge.cadence.requantize.per_tensor,
-        exir_ops.edge.cadence.quantized_add.per_tensor,
         # Ops that require special handling.
         exir_ops.edge.aten.cat.default,
         exir_ops.edge.aten.mean.dim,
@@ -808,7 +804,6 @@ class CommonRemovePasses:
         RemoveToOpsPass,
         RemoveZeroSizedCatArgsPass,
         RemovePermutesAroundElementwiseOps,
-        FuseTransposeOrPermuteOpPairsPass,
         RemoveSqueezeViewBeforeElementwiseOps,
         RemoveCatFromSliceCopyPass,
         RemoveCloneOpsTransformImported,
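As the RemoveNopExpandOpPass docstring above notes, an expand whose target shape equals the input shape moves no data, which is why it can be dropped; a quick eager-mode illustration:

import torch

x = torch.randn(2, 3)
y = x.expand(2, 3)  # target shape == input shape, so nothing changes
assert y.shape == x.shape
assert y.data_ptr() == x.data_ptr()  # expand returns a view; no copy was made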

0 commit comments
