Skip to content

Commit 6821c1a

Browse files
authored
Arm backend: Update _adjust_weight_qspec_for_conv_transpose (#18094)
* FusedMovingAvgObsFakeQuantize only supports channel axis 0; fall back to the non-fused combination (MovingAveragePerChannelMinMaxObserver + FakeQuantize) when applicable. * Always check/update ch_axis in the FakeQuantize/observer constructor, regardless of whether the channel axis is already correct in the QuantizationSpec. * Adds unit tests. cc @digantdesai @freddan80 @per @zingo @oscarandersson8218 @Sebastian-Larsson @robell --------- Signed-off-by: Måns Nilsson <mans.nilsson@arm.com>
1 parent 5c8f9e5 commit 6821c1a

2 files changed

Lines changed: 279 additions & 29 deletions

File tree

backends/arm/quantizer/quantization_annotator.py

Lines changed: 68 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,12 @@
2323
from torch._subclasses import FakeTensor
2424

2525
from torch.fx import Node
26-
from torchao.quantization.pt2e import PartialWrapper
26+
from torchao.quantization.pt2e import (
27+
FakeQuantize,
28+
FusedMovingAvgObsFakeQuantize,
29+
MovingAveragePerChannelMinMaxObserver,
30+
PartialWrapper,
31+
)
2732
from torchao.quantization.pt2e.quantizer import (
2833
annotate_input_qspec_map,
2934
annotate_output_qspec,
@@ -41,6 +46,12 @@
4146
logger = logging.getLogger(__name__)
4247

4348

49+
def _is_fused_moving_avg_obs_fake_quant_ctor(func: object) -> bool:
50+
"""Return True when ``func`` is the fused fake-quant class or a subclass."""
51+
52+
return isinstance(func, type) and issubclass(func, FusedMovingAvgObsFakeQuantize)
53+
54+
4455
@dataclass(frozen=True)
4556
class _QuantProperty:
4657
"""Specify how the input/output at 'index' must be quantized."""
@@ -85,10 +96,29 @@ def _as_list(x):
8596
]
8697

8798

88-
def _adjust_weight_qspec_for_conv_transpose(node: Node, weight_qspec):
99+
def _adjust_weight_qspec_for_conv_transpose(
100+
node: Node, weight_qspec: QuantizationSpec | None
101+
) -> QuantizationSpec | None:
102+
"""Adjust weight qspec axis/ctor for conv_transpose2d per-channel
103+
quantization.
104+
105+
Use axis 1 for ungrouped ConvTranspose2d weights because the weight layout is
106+
(in_channels, out_channels / groups, kH, kW). Grouped transpose conv keeps axis 0.
107+
108+
If the weight qspec contains a TorchAO QAT fake-quant/observer constructor
109+
(e.g. PartialWrapper(partial(...)) or a with_args-based constructor), the
110+
constructor is rebuilt with the corrected axis. For fused per-channel
111+
FakeQuantize, which only supports axis 0, the constructor is replaced with
112+
a non-fused FakeQuantize + MovingAveragePerChannelMinMaxObserver when the
113+
required axis is not 0.
114+
115+
Return the qspec unchanged when weights are unset.
116+
117+
"""
118+
89119
if (
90120
node.target != torch.ops.aten.conv_transpose2d.input
91-
or not isinstance(weight_qspec, QuantizationSpec)
121+
or weight_qspec is None
92122
or weight_qspec.qscheme != torch.per_channel_symmetric
93123
):
94124
return weight_qspec
@@ -101,27 +131,42 @@ def _adjust_weight_qspec_for_conv_transpose(node: Node, weight_qspec):
101131
if len(node.args) > 6 and isinstance(node.args[6], int):
102132
groups = node.args[6]
103133
expected_axis = 0 if groups != 1 else 1
104-
if weight_qspec.ch_axis == expected_axis:
105-
return weight_qspec
106134

107135
observer_or_fake_quant_ctr = weight_qspec.observer_or_fake_quant_ctr
108-
# TorchAO PT2e QAT commonly represents the ctor as PartialWrapper(partial(...)).
109-
# Rebuild it to update ch_axis while preserving callable_args.
136+
observer_or_fake_quant_ctr_changed = False
137+
# QAT FakeQuantize uses PartialWrapper; rebuild its partial to update ch_axis
138+
# without breaking TorchAO introspection.
110139
if isinstance(observer_or_fake_quant_ctr, PartialWrapper):
111140
original_callable_args = dict(observer_or_fake_quant_ctr.callable_args)
112141
base_partial = observer_or_fake_quant_ctr.p
113142
if isinstance(base_partial, functools.partial):
114143
base_keywords = dict(base_partial.keywords or {})
115144
base_keywords["ch_axis"] = expected_axis
116-
observer_or_fake_quant_ctr = PartialWrapper(
117-
functools.partial(base_partial.func, **base_keywords)
118-
)
145+
if (
146+
_is_fused_moving_avg_obs_fake_quant_ctor(base_partial.func)
147+
and expected_axis != 0
148+
):
149+
# Fused per-channel FakeQuant only supports axis 0; for other axes,
150+
# fall back to FakeQuantize with a per-channel observer.
151+
base_keywords["observer"] = MovingAveragePerChannelMinMaxObserver
152+
observer_or_fake_quant_ctr = PartialWrapper(
153+
functools.partial(FakeQuantize, **base_keywords)
154+
)
155+
else:
156+
observer_or_fake_quant_ctr = PartialWrapper(
157+
functools.partial(base_partial.func, **base_keywords)
158+
)
119159
observer_or_fake_quant_ctr.callable_args = original_callable_args
120-
# Non-QAT observer/fake-quant constructors can be updated via with_args.
160+
observer_or_fake_quant_ctr_changed = True
161+
# Non-QAT observer/fake-quant ctrs can be updated via with_args.
121162
elif hasattr(observer_or_fake_quant_ctr, "with_args"):
122163
observer_or_fake_quant_ctr = observer_or_fake_quant_ctr.with_args(
123164
ch_axis=expected_axis
124165
)
166+
observer_or_fake_quant_ctr_changed = True
167+
168+
if weight_qspec.ch_axis == expected_axis and not observer_or_fake_quant_ctr_changed:
169+
return weight_qspec
125170

126171
return QuantizationSpec(
127172
dtype=weight_qspec.dtype,
@@ -581,9 +626,10 @@ def any_or_hardtanh_min_zero(n: Node):
581626
filter_fn=any_or_hardtanh_min_zero,
582627
):
583628
if node.target in _conv_ops:
629+
conv_weight_qspec = ensure_type(QuantizationSpec, weight_qspec) # For MyPy
584630
quant_properties.quant_inputs = [
585631
_QuantProperty(0, input_act_qspec),
586-
_QuantProperty(1, weight_qspec, mark_annotated=True),
632+
_QuantProperty(1, conv_weight_qspec, mark_annotated=True),
587633
_QuantProperty(2, bias_qspec, optional=True, mark_annotated=True),
588634
]
589635
elif node.target in (
@@ -602,9 +648,10 @@ def any_or_hardtanh_min_zero(n: Node):
602648
],
603649
):
604650
if node.target in _conv_ops:
651+
conv_weight_qspec = ensure_type(QuantizationSpec, weight_qspec) # For MyPy
605652
quant_properties.quant_inputs = [
606653
_QuantProperty(0, input_act_qspec),
607-
_QuantProperty(1, weight_qspec, mark_annotated=True),
654+
_QuantProperty(1, conv_weight_qspec, mark_annotated=True),
608655
_QuantProperty(2, bias_qspec, optional=True, mark_annotated=True),
609656
]
610657
elif node.target in [
@@ -631,9 +678,12 @@ def any_or_hardtanh_min_zero(n: Node):
631678
*_conv_ops,
632679
torch.ops.aten.linear.default,
633680
):
681+
conv_or_linear_weight_qspec = ensure_type(
682+
QuantizationSpec, weight_qspec
683+
) # For MyPy
634684
quant_properties.quant_inputs = [
635685
_QuantProperty(0, input_act_qspec),
636-
_QuantProperty(1, weight_qspec, mark_annotated=True),
686+
_QuantProperty(1, conv_or_linear_weight_qspec, mark_annotated=True),
637687
_QuantProperty(2, bias_qspec, optional=True, mark_annotated=True),
638688
]
639689
else:
@@ -642,9 +692,12 @@ def any_or_hardtanh_min_zero(n: Node):
642692
*_conv_ops,
643693
torch.ops.aten.linear.default,
644694
):
695+
conv_or_linear_weight_qspec = ensure_type(
696+
QuantizationSpec, weight_qspec
697+
) # For MyPy
645698
quant_properties.quant_inputs = [
646699
_QuantProperty(0, input_act_qspec),
647-
_QuantProperty(1, weight_qspec, mark_annotated=True),
700+
_QuantProperty(1, conv_or_linear_weight_qspec, mark_annotated=True),
648701
_QuantProperty(2, bias_qspec, optional=True, mark_annotated=True),
649702
]
650703
quant_properties.quant_output = _QuantProperty(0, output_act_qspec)

0 commit comments

Comments (0)