Commit 1994514

JakeStevens authored and facebook-github-bot committed
Move clamp to independent quantization in annotator (#17910)
Summary:
The clamp operation was incorrectly placed in `_one_to_one_shared_input_qspec`, which causes the input and output observers to be shared. This is problematic because clamp explicitly modifies the value range by enforcing min/max bounds.

When using clamp to prevent undefined behavior (e.g., clamping inputs to rsqrt to be positive), the pre-clamp and post-clamp ranges can be very different. With shared observers, the pre-clamp (smaller) values dominate the min_val, causing incorrect quantization parameters for the post-clamp tensor.

This fix moves clamp to `_one_to_one`, giving it independent input/output quantization so each observer properly tracks its respective range.

Differential Revision: D92408418
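A rough numeric sketch of the problem, assuming int8 symmetric quantization with scale = max_abs / 127 and reusing the [-50, 50] input clamped to [0, 1] exercised by the test added in this commit (the helper below is illustrative only, not part of the change):

# Illustrative only: why a shared observer mis-scales the clamp output.
# Assumes int8 symmetric quantization: scale = max(|min|, |max|) / 127.
def symmetric_scale(lo: float, hi: float) -> float:
    return max(abs(lo), abs(hi)) / 127.0

input_range = (-50.0, 50.0)   # range seen by the observer on the clamp input
output_range = (0.0, 1.0)     # range the clamp output actually spans

shared_scale = symmetric_scale(*input_range)        # ~0.394, dominated by the input
independent_scale = symmetric_scale(*output_range)  # ~0.0079, tracks only [0, 1]

# With the shared scale, the whole [0, 1] clamp output lands on roughly
# 1 / 0.394, i.e. 2-3 int8 steps; an independent output observer uses all 127.
print(shared_scale, independent_scale)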
1 parent eb77ed4 commit 1994514

4 files changed

Lines changed: 150 additions & 25 deletions

backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py

Lines changed: 45 additions & 23 deletions
@@ -11,6 +11,7 @@
 import torch
 from executorch.backends.arm._passes import ArmPass
 from executorch.backends.arm._passes.arm_pass_utils import (
+    create_node,
     get_param_tensor,
     is_param_node,
     set_node_arg,

@@ -347,53 +348,74 @@ def call(self, graph_module: GraphModule) -> PassResult:  # noqa: C901


 class QuantizeClampArgumentsPass(ArmPass):
-    """This pass makes sure that the arguments to clamp.default are quantized
-    correctly.
+    """This pass quantizes the scalar min/max arguments of clamp.default and
+    inserts a RESCALE when the input and output quantization parameters differ.

-    More specifically, this pass:
-    - Makes sure the min and max values to clamp.default are quantized, if it's a quantized operator.
+    When clamp has independent input/output quantization (different scales),
+    a RESCALE is inserted before the clamp to convert the input from the
+    input domain to the output domain. The min/max bounds are quantized
+    using the output quantization parameters, ensuring they are precise
+    even when the clamp range is much narrower than the input range.

     """

     _passes_required_after: Set[Type[ExportPass]] = set()

     def call(self, graph_module: GraphModule) -> PassResult:
         modified = False
-        # Loop over the graph nodes and find full.default nodes.
         for n in graph_module.graph.nodes:
             n = cast(Node, n)
-            if n.target not in {
-                exir_ops.edge.aten.clamp.default,
-            }:
+            if n.target != exir_ops.edge.aten.clamp.default:
                 continue

             try:
+                input_qparams = get_input_qparams(n)
                 output_qparams = get_output_qparams(n)
             except ValueError:
                 continue
-            if len(output_qparams) == 0:
+            if len(input_qparams) == 0 or len(output_qparams) == 0:
                 continue

-            # Qparams are stored per user index; use the first entry.
-            qargs = next(iter(output_qparams.values()))
+            input_qargs = next(iter(input_qparams.values()))
+            output_qargs = next(iter(output_qparams.values()))
+
+            if input_qargs != output_qargs:
+                input_node = n.args[0]
+                with graph_module.graph.inserting_before(n):
+                    rescale_node = create_node(
+                        graph_module.graph,
+                        exir_ops.backend.tosa.RESCALE.default,
+                        (
+                            input_node,
+                            output_qargs.dtype,
+                            [
+                                input_qargs.get_scale_per_tensor()
+                                / output_qargs.get_scale_per_tensor()
+                            ],
+                            input_qargs.get_zp_per_tensor(),
+                            output_qargs.get_zp_per_tensor(),
+                        ),
+                        from_node=n,
+                    )
+                n.replace_input_with(input_node, rescale_node)
+                n.meta["input_qparams"] = {0: output_qargs}
+
+            qargs = output_qargs

-            if n.target == exir_ops.edge.aten.clamp.default:
-                # Quantize the min and max arguments of clamp, if they are not None
-                min_val = n.args[1]
-                max_val = None if len(n.args) <= 2 else n.args[2]
+            min_val = n.args[1]
+            max_val = None if len(n.args) <= 2 else n.args[2]

-                if min_val is not None:
-                    quantized_min_val = qargs.quantize_value(min_val).item()
-                    n.update_arg(1, quantized_min_val)
+            if min_val is not None:
+                quantized_min_val = qargs.quantize_value(min_val).item()
+                n.update_arg(1, quantized_min_val)

-                if max_val is not None:
-                    quantized_max_val = qargs.quantize_value(max_val).item()
-                    n.update_arg(2, quantized_max_val)
+            if max_val is not None:
+                quantized_max_val = qargs.quantize_value(max_val).item()
+                n.update_arg(2, quantized_max_val)

-                modified = True
+            modified = True

         if modified:
-            # Retrace to refresh fake tensor metadata after updating clamp min/max.
             graph_module = super().call(graph_module).graph_module
             graph_module.recompile()
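For intuition, here is a minimal sketch of the requantization a RESCALE performs. The helper name and the example scales/zero points are made up for illustration; only the scale_in / scale_out multiplier and the two zero-point arguments mirror what the pass builds above.

# Illustrative sketch of RESCALE semantics, not the ExecuTorch/TOSA API.
def rescale_value(q_in: int, scale_in: float, zp_in: int,
                  scale_out: float, zp_out: int) -> int:
    multiplier = scale_in / scale_out                  # same ratio the pass passes to RESCALE
    q_out = round((q_in - zp_in) * multiplier) + zp_out
    return max(-128, min(127, q_out))                  # saturate to int8 (assumed dtype)

# Made-up example: input quantized over [-50, 50] (scale 0.394, zero point 0),
# output over [0, 1] (scale 1/255, zero point -128).
print(rescale_value(q_in=2, scale_in=0.394, zp_in=0,
                    scale_out=1 / 255, zp_out=-128))   # -> 73, i.e. about 0.79 real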
backends/arm/quantizer/quantization_annotator.py

Lines changed: 4 additions & 2 deletions
@@ -433,6 +433,10 @@ def _match_pattern(
     torch.ops.aten.acos.default,
     torch.ops.aten.cumsum.default,
     torch.ops.aten.tan.default,
+    # Clamp modifies the value range (enforces min/max bounds), so it needs
+    # independent input/output quantization to properly track the clamped range.
+    torch.ops.aten.clamp.default,
+    torch.ops.aten.clamp.Tensor,
 ]

 _one_to_one_shared_input_qspec = [
@@ -480,8 +484,6 @@ def _match_pattern(
     torch.ops.aten.pad.default,
     torch.ops.aten.amax.default,
     torch.ops.aten.amin.default,
-    torch.ops.aten.clamp.default,
-    torch.ops.aten.clamp.Tensor,
     torch.ops.aten.unflatten.int,
     torch.ops.aten.gather.default,
     torch.ops.aten.unfold_copy.default,
backends/arm/test/quantizer/test_clamp_quantization.py

Lines changed: 100 additions & 0 deletions
@@ -0,0 +1,100 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# Copyright 2024-2026 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Verify that clamp uses independent input/output quantization.
+
+Clamp modifies the value range by enforcing min/max bounds, so its output
+observer must be independent from its input observer. When observers are
+shared, the pre-clamp (wider) values dominate the observed range and the
+post-clamp tensor gets incorrect quantization parameters.
+
+This test feeds a wide-range input through a narrow clamp and checks that
+the quantization scale for the clamp output differs from the input scale.
+"""
+
+import torch
+from executorch.backends.arm.quantizer import (
+    get_symmetric_quantization_config,
+    TOSAQuantizer,
+)
+from executorch.backends.arm.tosa import TosaSpecification
+from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
+
+Q_PER_TENSOR = torch.ops.quantized_decomposed.quantize_per_tensor.default
+DQ_PER_TENSOR = torch.ops.quantized_decomposed.dequantize_per_tensor.default
+
+
+class ClampModel(torch.nn.Module):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return torch.clamp(x, min=0.0, max=1.0)
+
+
+def test_clamp_has_different_input_output_qparams():
+    """Input and output scales must differ when clamp narrows the range.
+
+    A wide-range input ([-50, 50]) clamped to [0, 1] should produce a much
+    smaller output scale than input scale, because the output observer only
+    sees values in [0, 1] while the input observer sees the full [-50, 50].
+
+    Before the fix (clamp in _one_to_one_shared_input_qspec), both observers
+    were shared and would produce identical scales — the wider input range
+    dominated, wasting output precision.
+    """
+    model = ClampModel()
+    model.eval()
+
+    # Use deterministic wide-range calibration data so the input observer
+    # sees [-50, 50] while the output observer sees only [0, 1].
+    calibration_input = torch.linspace(-50, 50, 200).reshape(1, 200)
+
+    tosa_spec = TosaSpecification.create_from_string("TOSA-1.0+INT")
+    quantizer = TOSAQuantizer(tosa_spec)
+    quantizer.set_global(get_symmetric_quantization_config(is_per_channel=False))
+
+    exported = torch.export.export(model, (calibration_input,))
+    prepared = prepare_pt2e(exported.module(), quantizer)
+    prepared(calibration_input)
+    converted = convert_pt2e(prepared)
+
+    # After conversion the graph has explicit quantize/dequantize nodes.
+    # For clamp with independent qspecs the pattern is:
+    #   dequantize_per_tensor(input_scale) -> clamp -> quantize_per_tensor(output_scale)
+    # With shared qspecs both scales would be identical.
+    clamp_nodes = [
+        n
+        for n in converted.graph.nodes
+        if n.target in (torch.ops.aten.clamp.default, torch.ops.aten.clamp.Tensor)
+    ]
+    assert (
+        len(clamp_nodes) == 1
+    ), f"Expected exactly 1 clamp node, found {len(clamp_nodes)}"
+    clamp_node = clamp_nodes[0]
+
+    # Get the dequant feeding clamp's input — its scale is arg[1].
+    input_dq = clamp_node.args[0]
+    assert (
+        input_dq.target == DQ_PER_TENSOR
+    ), f"Expected dequantize_per_tensor before clamp, got {input_dq.target}"
+    input_scale = float(input_dq.args[1])
+
+    # Get the quant consuming clamp's output — its scale is arg[1].
+    clamp_users = list(clamp_node.users)
+    assert (
+        len(clamp_users) == 1
+    ), f"Expected exactly 1 user of clamp, found {len(clamp_users)}"
+    output_q = clamp_users[0]
+    assert (
+        output_q.target == Q_PER_TENSOR
+    ), f"Expected quantize_per_tensor after clamp, got {output_q.target}"
+    output_scale = float(output_q.args[1])
+
+    # With independent quantization the output scale (tracking [0, 1]) must
+    # be much smaller than the input scale (tracking [-50, 50]).
+    assert output_scale < input_scale, (
+        f"Clamp output scale ({output_scale}) should be smaller than input "
+        f"scale ({input_scale}) because clamp narrows [−50, 50] → [0, 1]. "
+        "If they are equal, clamp is using shared observers (bug)."
+    )

backends/arm/test/targets.bzl

Lines changed: 1 addition & 0 deletions
@@ -36,6 +36,7 @@ def define_arm_tests():
     # Quantization
     test_files += [
         "quantizer/test_generic_annotater.py",
+        "quantizer/test_clamp_quantization.py",
     ]

     # Misc tests
