 from torch._subclasses import FakeTensor

 from torch.fx import Node
-from torchao.quantization.pt2e import PartialWrapper
+from torchao.quantization.pt2e import (
+    FakeQuantize,
+    FusedMovingAvgObsFakeQuantize,
+    MovingAveragePerChannelMinMaxObserver,
+    PartialWrapper,
+)
 from torchao.quantization.pt2e.quantizer import (
     annotate_input_qspec_map,
     annotate_output_qspec,
@@ -41,6 +46,12 @@
 logger = logging.getLogger(__name__)


+def _is_fused_moving_avg_obs_fake_quant_ctor(func: object) -> bool:
+    """Return True when ``func`` is the fused fake-quant class or a subclass."""
+
+    return isinstance(func, type) and issubclass(func, FusedMovingAvgObsFakeQuantize)
+
+
 @dataclass(frozen=True)
 class _QuantProperty:
     """Specify how the input/output at 'index' must be quantized."""
@@ -85,10 +96,29 @@ def _as_list(x):
     ]


-def _adjust_weight_qspec_for_conv_transpose(node: Node, weight_qspec):
+def _adjust_weight_qspec_for_conv_transpose(
+    node: Node, weight_qspec: QuantizationSpec | None
+) -> QuantizationSpec | None:
+    """Adjust weight qspec axis/ctor for conv_transpose2d per-channel
+    quantization.
+
+    Use axis 1 for ungrouped ConvTranspose2d weights because the weight layout is
+    (in_channels, out_channels / groups, kH, kW). Grouped transpose conv keeps axis 0.
+
+    If the weight qspec contains a TorchAO QAT fake-quant/observer constructor
+    (e.g. PartialWrapper(partial(...)) or a with_args-based constructor), the
+    constructor is rebuilt with the corrected axis. For fused per-channel
+    FakeQuantize, which only supports axis 0, the constructor is replaced with
+    a non-fused FakeQuantize + MovingAveragePerChannelMinMaxObserver when the
+    required axis is not 0.
+
+    Return the qspec unchanged when weights are unset.
+
+    """
+
     if (
         node.target != torch.ops.aten.conv_transpose2d.input
-        or not isinstance(weight_qspec, QuantizationSpec)
+        or weight_qspec is None
         or weight_qspec.qscheme != torch.per_channel_symmetric
     ):
         return weight_qspec
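Quick illustration of the axis choice described in the docstring, using the standard torch.nn ConvTranspose2d weight layout (example shapes are made up, not part of the change):

    import torch

    w = torch.nn.ConvTranspose2d(in_channels=8, out_channels=16, kernel_size=3).weight
    print(w.shape)  # torch.Size([8, 16, 3, 3]) -> output channels sit on dim 1, hence ch_axis=1

    w_grouped = torch.nn.ConvTranspose2d(8, 16, 3, groups=4).weight
    print(w_grouped.shape)  # torch.Size([8, 4, 3, 3]) -> the grouped case keeps per-channel params on dim 0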
@@ -101,27 +131,42 @@ def _adjust_weight_qspec_for_conv_transpose(node: Node, weight_qspec):
     if len(node.args) > 6 and isinstance(node.args[6], int):
         groups = node.args[6]
     expected_axis = 0 if groups != 1 else 1
-    if weight_qspec.ch_axis == expected_axis:
-        return weight_qspec

     observer_or_fake_quant_ctr = weight_qspec.observer_or_fake_quant_ctr
-    # TorchAO PT2e QAT commonly represents the ctor as PartialWrapper(partial(...)).
-    # Rebuild it to update ch_axis while preserving callable_args.
+    observer_or_fake_quant_ctr_changed = False
+    # QAT FakeQuantize uses PartialWrapper; rebuild its partial to update ch_axis
+    # without breaking TorchAO introspection.
     if isinstance(observer_or_fake_quant_ctr, PartialWrapper):
         original_callable_args = dict(observer_or_fake_quant_ctr.callable_args)
         base_partial = observer_or_fake_quant_ctr.p
         if isinstance(base_partial, functools.partial):
             base_keywords = dict(base_partial.keywords or {})
             base_keywords["ch_axis"] = expected_axis
-            observer_or_fake_quant_ctr = PartialWrapper(
-                functools.partial(base_partial.func, **base_keywords)
-            )
+            if (
+                _is_fused_moving_avg_obs_fake_quant_ctor(base_partial.func)
+                and expected_axis != 0
+            ):
+                # Fused per-channel FakeQuant only supports axis 0; for other axes,
+                # fall back to FakeQuantize with a per-channel observer.
+                base_keywords["observer"] = MovingAveragePerChannelMinMaxObserver
+                observer_or_fake_quant_ctr = PartialWrapper(
+                    functools.partial(FakeQuantize, **base_keywords)
+                )
+            else:
+                observer_or_fake_quant_ctr = PartialWrapper(
+                    functools.partial(base_partial.func, **base_keywords)
+                )
         observer_or_fake_quant_ctr.callable_args = original_callable_args
-    # Non-QAT observer/fake-quant constructors can be updated via with_args.
+        observer_or_fake_quant_ctr_changed = True
+    # Non-QAT observer/fake-quant ctrs can be updated via with_args.
     elif hasattr(observer_or_fake_quant_ctr, "with_args"):
         observer_or_fake_quant_ctr = observer_or_fake_quant_ctr.with_args(
             ch_axis=expected_axis
         )
+        observer_or_fake_quant_ctr_changed = True
+
+    if weight_qspec.ch_axis == expected_axis and not observer_or_fake_quant_ctr_changed:
+        return weight_qspec

     return QuantizationSpec(
         dtype=weight_qspec.dtype,
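Hedged sketch of the constructor rebuild above, assuming (as the code does) that a with_args-style QAT constructor is a PartialWrapper around a functools.partial; values are illustrative only, not part of the change:

    import functools
    from torchao.quantization.pt2e import (
        FakeQuantize,
        MovingAveragePerChannelMinMaxObserver,
        PartialWrapper,
    )

    ctr = FakeQuantize.with_args(observer=MovingAveragePerChannelMinMaxObserver, ch_axis=0)
    base = ctr.p                                     # functools.partial(FakeQuantize, observer=..., ch_axis=0)
    keywords = dict(base.keywords or {})
    keywords["ch_axis"] = 1                          # corrected axis for ungrouped conv_transpose2d
    rebuilt = PartialWrapper(functools.partial(base.func, **keywords))
    rebuilt.callable_args = dict(ctr.callable_args)  # preserve lazily evaluated args, as the code above does
    fake_quant = rebuilt()                           # FakeQuantize observing channels along axis 1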
@@ -581,9 +626,10 @@ def any_or_hardtanh_min_zero(n: Node):
         filter_fn=any_or_hardtanh_min_zero,
     ):
         if node.target in _conv_ops:
+            conv_weight_qspec = ensure_type(QuantizationSpec, weight_qspec)  # For MyPy
             quant_properties.quant_inputs = [
                 _QuantProperty(0, input_act_qspec),
-                _QuantProperty(1, weight_qspec, mark_annotated=True),
+                _QuantProperty(1, conv_weight_qspec, mark_annotated=True),
                 _QuantProperty(2, bias_qspec, optional=True, mark_annotated=True),
             ]
         elif node.target in (
@@ -602,9 +648,10 @@ def any_or_hardtanh_min_zero(n: Node):
         ],
     ):
         if node.target in _conv_ops:
+            conv_weight_qspec = ensure_type(QuantizationSpec, weight_qspec)  # For MyPy
             quant_properties.quant_inputs = [
                 _QuantProperty(0, input_act_qspec),
-                _QuantProperty(1, weight_qspec, mark_annotated=True),
+                _QuantProperty(1, conv_weight_qspec, mark_annotated=True),
                 _QuantProperty(2, bias_qspec, optional=True, mark_annotated=True),
             ]
         elif node.target in [
@@ -631,9 +678,12 @@ def any_or_hardtanh_min_zero(n: Node):
             *_conv_ops,
             torch.ops.aten.linear.default,
         ):
+            conv_or_linear_weight_qspec = ensure_type(
+                QuantizationSpec, weight_qspec
+            )  # For MyPy
             quant_properties.quant_inputs = [
                 _QuantProperty(0, input_act_qspec),
-                _QuantProperty(1, weight_qspec, mark_annotated=True),
+                _QuantProperty(1, conv_or_linear_weight_qspec, mark_annotated=True),
                 _QuantProperty(2, bias_qspec, optional=True, mark_annotated=True),
             ]
         else:
@@ -642,9 +692,12 @@ def any_or_hardtanh_min_zero(n: Node):
         *_conv_ops,
         torch.ops.aten.linear.default,
     ):
+        conv_or_linear_weight_qspec = ensure_type(
+            QuantizationSpec, weight_qspec
+        )  # For MyPy
         quant_properties.quant_inputs = [
             _QuantProperty(0, input_act_qspec),
-            _QuantProperty(1, weight_qspec, mark_annotated=True),
+            _QuantProperty(1, conv_or_linear_weight_qspec, mark_annotated=True),
             _QuantProperty(2, bias_qspec, optional=True, mark_annotated=True),
         ]
         quant_properties.quant_output = _QuantProperty(0, output_act_qspec)
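For context on why the corrected ch_axis matters downstream: per-channel quantization parameters need one entry per element along the chosen axis. Minimal sketch with plain PyTorch ops and made-up shapes, not part of the change:

    import torch

    w = torch.randn(8, 16, 3, 3)                   # ConvTranspose2d weight: (in, out / groups, kH, kW)
    scales = w.abs().amax(dim=(0, 2, 3)) / 127.0   # one scale per output channel -> shape [16]
    zero_points = torch.zeros(16, dtype=torch.int32)
    wq = torch.fake_quantize_per_channel_affine(w, scales, zero_points, 1, -127, 127)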