Skip to content

Commit f8da665

Browse files
committed
Update base for Update on "[ET Device Support] Schema changes: device info on Tensor and buffer-level device array"
This diff adds device placement information to the ExecuTorch schema to support representing tensor-level device type information, which will be the basic requirement for the following tensor_parser updates. This is part of the Phase 1 implementation to make ET device type work E2E without user-specified device placement. Design doc: https://docs.google.com/document/d/1lwd9BlohmwkN5EEvRulO_b-XnZBwv1nMb5l2K3jfuwA/edit?tab=t.0#heading=h.o6anuvkix4bu Differential Revision: [D93635657](https://our.internmc.facebook.com/intern/diff/D93635657/) [ghstack-poisoned]
2 parents a11e78b + bf2243a commit f8da665

24 files changed

Lines changed: 707 additions & 100 deletions

backends/arm/quantizer/quantization_annotator.py

Lines changed: 73 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -13,25 +13,27 @@
1313
import logging
1414
import operator
1515
from dataclasses import dataclass, replace
16-
from typing import Callable, cast, List, Optional, Sequence
16+
from typing import Any, Callable, cast, Iterable, List, NamedTuple, Optional, Sequence
1717

1818
import torch
1919
import torch.fx
2020
from executorch.backends.arm.common.debug import get_node_debug_info
2121
from executorch.backends.arm.common.type import ensure_type
2222
from executorch.backends.arm.quantizer import QuantizationConfig
23-
from torch._subclasses import FakeTensor
2423

24+
from torch._subclasses import FakeTensor
2525
from torch.fx import Node
2626
from torchao.quantization.pt2e import (
2727
FakeQuantize,
2828
FusedMovingAvgObsFakeQuantize,
2929
MovingAveragePerChannelMinMaxObserver,
3030
PartialWrapper,
3131
)
32+
3233
from torchao.quantization.pt2e.quantizer import (
3334
annotate_input_qspec_map,
3435
annotate_output_qspec,
36+
FixedQParamsQuantizationSpec,
3537
QuantizationSpec,
3638
QuantizationSpecBase,
3739
SharedQuantizationSpec,
@@ -78,6 +80,11 @@ def __init__(self):
7880
self.quant_output: Optional[_QuantProperty] = None
7981

8082

83+
class _QParams(NamedTuple):
84+
scale: float
85+
zero_point: int
86+
87+
8188
def _as_list(x):
8289
"""Return ``x`` wrapped as a list if needed.
8390
@@ -391,14 +398,16 @@ def _annotate_output(node: Node, quant_property: _QuantProperty):
391398

392399

393400
def _match_pattern(
394-
node: Node, pattern: List[List], filter_fn: Optional[Callable[[Node], bool]] = None
401+
node: Node,
402+
pattern: Sequence[Iterable[object]],
403+
filter_fn: Optional[Callable[[Node], bool]] = None,
395404
) -> bool:
396405
"""Check whether a node chain matches a pattern.
397406
398407
Verify a chain of ancestors -> node -> descendants matches the provided
399408
``pattern``. If ``filter_fn`` is provided, require all nodes in the chain
400-
to pass the filter. Each pattern element is a list of disjunctive node
401-
targets.
409+
to pass the filter. Each pattern element is an iterable of disjunctive
410+
node targets.
402411
403412
"""
404413
if len(pattern) < 1:
@@ -432,16 +441,39 @@ def _match_pattern(
432441
return left_condition and right_condition
433442

434443

435-
_conv_ops = [
444+
_conv_ops = {
436445
torch.ops.aten.conv1d.default,
437446
torch.ops.aten.conv2d.default,
438447
torch.ops.aten.conv2d.padding,
439448
torch.ops.aten.conv_transpose2d.input,
440449
torch.ops.aten.conv3d.default,
441450
torch.ops.aten.conv3d.padding,
442-
]
451+
}
443452

444-
_one_to_one = [
453+
# For these ops, we use fixed qspecs, meaning that quantization params for
454+
# these are statically defined. This is to prevent issues with out-of-range
455+
# values when using dynamic quantization.
456+
#
457+
# Maps each operator to a dict from num_bits to the fixed qparams for that operator.
458+
_fixed_input_qspec_ops: dict[Any, dict[int, _QParams]] = {
459+
# acos has a valid input domain of [-1, 1]
460+
torch.ops.aten.acos.default: {
461+
8: _QParams((1.0 - (-1.0)) / (1 << 8), 0),
462+
16: _QParams((1.0 - (-1.0)) / (1 << 16), 0),
463+
},
464+
# asin has a valid input domain of [-1, 1]
465+
torch.ops.aten.asin.default: {
466+
8: _QParams((1.0 - (-1.0)) / (1 << 8), 0),
467+
16: _QParams((1.0 - (-1.0)) / (1 << 16), 0),
468+
},
469+
# atanh has a valid input domain of (-1, 1) (excluding -1 and 1).
470+
torch.ops.aten.atanh.default: {
471+
8: _QParams((0.999 - (-0.999)) / (1 << 8), 0),
472+
16: _QParams((0.99999 - (-0.99999)) / (1 << 16), 0),
473+
},
474+
}
475+
476+
_one_to_one = {
445477
torch.ops.aten.abs.default,
446478
torch.ops.aten.ceil.default,
447479
torch.ops.aten.erf.default,
@@ -472,16 +504,13 @@ def _match_pattern(
472504
torch.ops.aten.log1p.default,
473505
torch.ops.aten.acosh.default,
474506
torch.ops.aten.sign.default,
475-
torch.ops.aten.asin.default,
476-
torch.ops.aten.atanh.default,
477507
torch.ops.aten.asinh.default,
478508
torch.ops.aten.cosh.default,
479-
torch.ops.aten.acos.default,
480509
torch.ops.aten.cumsum.default,
481510
torch.ops.aten.tan.default,
482-
]
511+
}
483512

484-
_one_to_one_shared_input_qspec = [
513+
_one_to_one_shared_input_qspec = {
485514
torch.ops.aten.squeeze.default,
486515
torch.ops.aten.squeeze_copy.default,
487516
torch.ops.aten.squeeze_copy.dim,
@@ -539,9 +568,9 @@ def _match_pattern(
539568
# dequant -> neg -> requant chain.
540569
torch.ops.aten.neg.default,
541570
torch.ops.aten.detach_copy.default,
542-
]
571+
}
543572

544-
_one_to_one_shared_input_or_input_act_qspec = [
573+
_one_to_one_shared_input_or_input_act_qspec = {
545574
torch.ops.aten.alias.default,
546575
torch.ops.aten.clone.default,
547576
torch.ops.aten.hardtanh.default,
@@ -562,7 +591,7 @@ def _match_pattern(
562591
torch.ops.aten.alias_copy.default,
563592
torch.ops.aten.pixel_shuffle.default,
564593
torch.ops.aten.pixel_unshuffle.default,
565-
]
594+
}
566595

567596

568597
def get_quant_properties( # noqa: C901
@@ -615,13 +644,13 @@ def any_or_hardtanh_min_zero(n: Node):
615644
node,
616645
[
617646
_conv_ops,
618-
[torch.ops.aten.batch_norm.default],
619-
[
647+
{torch.ops.aten.batch_norm.default},
648+
{
620649
torch.ops.aten.relu.default,
621650
torch.ops.aten.relu_.default,
622651
torch.ops.aten.hardtanh.default,
623652
torch.ops.aten.hardtanh_.default,
624-
],
653+
},
625654
],
626655
filter_fn=any_or_hardtanh_min_zero,
627656
):
@@ -644,7 +673,7 @@ def any_or_hardtanh_min_zero(n: Node):
644673
node,
645674
[
646675
_conv_ops,
647-
[torch.ops.aten.batch_norm.default],
676+
{torch.ops.aten.batch_norm.default},
648677
],
649678
):
650679
if node.target in _conv_ops:
@@ -654,23 +683,21 @@ def any_or_hardtanh_min_zero(n: Node):
654683
_QuantProperty(1, conv_weight_qspec, mark_annotated=True),
655684
_QuantProperty(2, bias_qspec, optional=True, mark_annotated=True),
656685
]
657-
elif node.target in [
658-
torch.ops.aten.batch_norm.default,
659-
]:
686+
elif node.target in {torch.ops.aten.batch_norm.default}:
660687
quant_properties.quant_output = _QuantProperty(0, output_act_qspec)
661688
elif not is_symmetric and _match_pattern(
662689
node,
663690
[
664-
[
691+
{
665692
*_conv_ops,
666693
torch.ops.aten.linear.default,
667-
],
668-
[
694+
},
695+
{
669696
torch.ops.aten.relu.default,
670697
torch.ops.aten.relu_.default,
671698
torch.ops.aten.hardtanh.default,
672699
torch.ops.aten.hardtanh_.default,
673-
],
700+
},
674701
],
675702
any_or_hardtanh_min_zero,
676703
):
@@ -784,6 +811,25 @@ def any_or_hardtanh_min_zero(n: Node):
784811
elif node.target in _one_to_one:
785812
quant_properties.quant_inputs = [_QuantProperty(0, input_act_qspec)]
786813
quant_properties.quant_output = _QuantProperty(0, output_act_qspec)
814+
elif node.target in _fixed_input_qspec_ops:
815+
num_bits = torch.iinfo(input_act_qspec.dtype).bits
816+
qparams = _fixed_input_qspec_ops[node.target][num_bits]
817+
818+
quant_properties.quant_inputs = [
819+
_QuantProperty(
820+
0,
821+
FixedQParamsQuantizationSpec(
822+
dtype=input_act_qspec.dtype,
823+
scale=qparams.scale,
824+
zero_point=qparams.zero_point,
825+
quant_min=input_act_qspec.quant_min,
826+
quant_max=input_act_qspec.quant_max,
827+
qscheme=input_act_qspec.qscheme,
828+
is_dynamic=input_act_qspec.is_dynamic,
829+
),
830+
)
831+
]
832+
quant_properties.quant_output = _QuantProperty(0, output_act_qspec)
787833
elif node.target in _one_to_one_shared_input_qspec:
788834
input_node = ensure_type(Node, node.args[0])
789835
quant_properties.quant_inputs = [_QuantProperty(0, input_act_qspec)]

backends/arm/test/ops/test_acos.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,6 @@ def test_acos_tosa_INT(test_data: Tuple):
6565
(test_data(),),
6666
aten_op=aten_op,
6767
exir_op=exir_op,
68-
frobenius_threshold=0.5, # MLETORCH-1709
6968
)
7069
pipeline.run()
7170

backends/arm/test/ops/test_asin.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,6 @@ def test_asin_tosa_INT(test_data: Tuple):
5555
(test_data(),),
5656
aten_op=[],
5757
exir_op=[],
58-
frobenius_threshold=0.6, # MLETORCH-1709
59-
cosine_threshold=0.8, # MLETORCH-1709
6058
)
6159
pipeline.run()
6260

backends/arm/test/ops/test_atanh.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,10 @@
2626
test_data_suite = {
2727
"zeros": torch.zeros(1, 10, 10, 10),
2828
"zeros_alt_shape": torch.zeros(1, 10, 3, 5),
29-
"ones": torch.ones(10, 10, 10),
3029
"rand": torch.rand(10, 10) - 0.5,
3130
"rand_alt_shape": torch.rand(1, 10, 3, 5) - 0.5,
3231
"ramp": torch.arange(-1, 1, 0.2),
33-
"near_bounds": torch.tensor([-0.999999, -0.999, -0.9, 0.9, 0.999, 0.999999]),
32+
"near_bounds": torch.tensor([-0.99, -0.9, 0.9, 0.99]),
3433
"on_bounds": torch.tensor([-1.0, 1.0]),
3534
}
3635

@@ -58,9 +57,11 @@ def test_atanh_tosa_INT(test_data: Tuple):
5857
(test_data,),
5958
aten_op=aten_op,
6059
exir_op=exir_op,
61-
frobenius_threshold=None, # MLETORCH-1709
62-
cosine_threshold=0.7,
6360
)
61+
if torch.any(test_data >= 1) or torch.any(test_data <= -1):
62+
# The quantized model will saturate to max/min values while the
63+
# original model will return inf/-inf, so the comparison won't be valid here.
64+
pipeline.pop_stage("run_method_and_compare_outputs.original_model")
6465
pipeline.run()
6566

6667

backends/arm/test/tester/analyze_output_utils.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -337,22 +337,6 @@ def dump_error_output(
337337
logger.error(f"{atol=}, {rtol=}, {qtol=}")
338338

339339

340-
if __name__ == "__main__":
341-
"""This is expected to produce the example output of print_diff."""
342-
torch.manual_seed(0)
343-
a = torch.rand(3, 3, 2, 2) * 0.01
344-
b = a.clone().detach()
345-
logger.info(b)
346-
347-
# Errors in all channels in element (1,1)
348-
a[1, :, 1, 1] = 0
349-
# Errors in (0,0) and (1,1) in channel 1
350-
a[2, 1, 1, 1] = 0
351-
a[2, 1, 0, 0] = 0
352-
353-
print_error_diffs(a, b)
354-
355-
356340
def compare_rel_frobenius_and_cosine_similarity(
357341
reference_output: torch.Tensor,
358342
test_output: torch.Tensor,
@@ -452,3 +436,19 @@ def compare_rel_frobenius_and_cosine_similarity(
452436
f"Tensor-wise comparison failed: Cosine similarity {cosine_similarity} is below threshold {cosine_threshold}."
453437
f" (Relative frobenius error: {relative_frobenius_error}, threshold {frobenius_threshold})."
454438
)
439+
440+
441+
if __name__ == "__main__":
442+
"""This is expected to produce the example output of print_diff."""
443+
torch.manual_seed(0)
444+
a = torch.rand(3, 3, 2, 2) * 0.01
445+
b = a.clone().detach()
446+
logger.info(b)
447+
448+
# Errors in all channels in element (1,1)
449+
a[1, :, 1, 1] = 0
450+
# Errors in (0,0) and (1,1) in channel 1
451+
a[2, 1, 1, 1] = 0
452+
a[2, 1, 0, 0] = 0
453+
454+
print_error_diffs(a, b)

backends/cadence/aot/BUCK

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,7 @@ fbcode_target(_kind = runtime.python_library,
300300
],
301301
typing = True,
302302
deps = [
303+
":fuse_ops",
303304
":ops_registrations",
304305
"//caffe2:torch",
305306
"//executorch/backends/cadence/aot:pass_utils",

backends/cadence/aot/decompose_ops.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,12 @@
2323
from torch.fx.node import Argument
2424

2525

26-
@register_cadence_pass(CadencePassAttribute(opt_level=0))
26+
@register_cadence_pass(CadencePassAttribute(opt_level=1))
2727
class DecomposeAtenApproxGeluPass(ExportPass):
2828
"""
29-
Decompose the aten gelu op with an approximate arg to a series of simpler ops
29+
Decompose the aten gelu op with an approximate arg to a series of simpler ops.
30+
This is an optimization - gelu has a portable kernel fallback, but decomposing
31+
may be more efficient on some backends.
3032
"""
3133

3234
def call_operator(

backends/cadence/aot/functions.yaml

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -309,10 +309,15 @@
309309
- arg_meta: null
310310
kernel_name: impl::generic::quantized_relu_asym8u_asym8u_per_tensor_out
311311

312-
- func: cadence::quantized_max_pool2d.out(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode, *, Tensor(a!) out) -> Tensor(a!)
312+
- func: cadence::quantized_max_pool2d_nchw.out(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode, *, Tensor(a!) out) -> Tensor(a!)
313313
kernels:
314314
- arg_meta: null
315-
kernel_name: impl::generic::quantized_max_pool2d_out
315+
kernel_name: impl::generic::quantized_max_pool2d_nchw_out
316+
317+
- func: cadence::quantized_max_pool2d_nhwc.out(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode, *, Tensor(a!) out) -> Tensor(a!)
318+
kernels:
319+
- arg_meta: null
320+
kernel_name: impl::generic::quantized_max_pool2d_nhwc_out
316321

317322
- func: cadence::quantized_add.per_tensor_out(Tensor X, float X_scale, int X_zero_point, Tensor Y, float Y_scale, int Y_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!)
318323
kernels:

backends/cadence/aot/fuse_ops.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1170,7 +1170,10 @@ def can_fuse_for_chain(
11701170
return False
11711171

11721172
# checking that permut2(permut1(identity)) == identity, modulo unitary dimensions
1173-
input_shape = cast(torch.fx.Node, producer.args[0]).meta["val"].shape
1173+
producer_input = cast(torch.fx.Node, producer.args[0])
1174+
if "val" not in producer_input.meta:
1175+
return False
1176+
input_shape = producer_input.meta["val"].shape
11741177
ident_dims = list(range(len(input_shape)))
11751178
# this mapping helps to handle both transpose and permutations
11761179
f: dict[Any, Callable] = {

0 commit comments

Comments
 (0)