Skip to content

Commit 703d869

Browse files
JakeStevens and facebook-github-bot
authored and committed
Move clamp to independent quantization in annotator
Summary: The clamp operation was incorrectly placed in `_one_to_one_shared_input_qspec`, which causes the input and output observers to be shared. This is problematic because clamp explicitly modifies the value range by enforcing min/max bounds. When using clamp to prevent undefined behavior (e.g., clamping inputs to rsqrt to be positive), the pre-clamp and post-clamp ranges can be very different. With shared observers, the pre-clamp (smaller) values dominate the min_val, causing incorrect quantization parameters for the post-clamp tensor. This fix moves clamp to `_one_to_one`, giving it independent input/output quantization so each observer properly tracks its respective range. Differential Revision: D92408418
1 parent 2f11c0c commit 703d869

3 files changed

Lines changed: 105 additions & 2 deletions

File tree

backends/arm/quantizer/quantization_annotator.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -433,6 +433,10 @@ def _match_pattern(
433433
torch.ops.aten.acos.default,
434434
torch.ops.aten.cumsum.default,
435435
torch.ops.aten.tan.default,
436+
# Clamp modifies the value range (enforces min/max bounds), so it needs
437+
# independent input/output quantization to properly track the clamped range.
438+
torch.ops.aten.clamp.default,
439+
torch.ops.aten.clamp.Tensor,
436440
]
437441

438442
_one_to_one_shared_input_qspec = [
@@ -480,8 +484,6 @@ def _match_pattern(
480484
torch.ops.aten.pad.default,
481485
torch.ops.aten.amax.default,
482486
torch.ops.aten.amin.default,
483-
torch.ops.aten.clamp.default,
484-
torch.ops.aten.clamp.Tensor,
485487
torch.ops.aten.unflatten.int,
486488
torch.ops.aten.gather.default,
487489
torch.ops.aten.unfold_copy.default,
Lines changed: 100 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,100 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# Copyright 2024-2026 Arm Limited and/or its affiliates.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
"""Verify that clamp uses independent input/output quantization.
8+
9+
Clamp modifies the value range by enforcing min/max bounds, so its output
10+
observer must be independent from its input observer. When observers are
11+
shared, the pre-clamp (wider) values dominate the observed range and the
12+
post-clamp tensor gets incorrect quantization parameters.
13+
14+
This test feeds a wide-range input through a narrow clamp and checks that
15+
the quantization scale for the clamp output differs from the input scale.
16+
"""
17+
18+
import torch
19+
from executorch.backends.arm.quantizer import (
20+
get_symmetric_quantization_config,
21+
TOSAQuantizer,
22+
)
23+
from executorch.backends.arm.tosa import TosaSpecification
24+
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
25+
26+
Q_PER_TENSOR = torch.ops.quantized_decomposed.quantize_per_tensor.default
27+
DQ_PER_TENSOR = torch.ops.quantized_decomposed.dequantize_per_tensor.default
28+
29+
30+
class ClampModel(torch.nn.Module):
31+
def forward(self, x: torch.Tensor) -> torch.Tensor:
32+
return torch.clamp(x, min=0.0, max=1.0)
33+
34+
35+
def test_clamp_has_different_input_output_qparams():
36+
"""Input and output scales must differ when clamp narrows the range.
37+
38+
A wide-range input ([-50, 50]) clamped to [0, 1] should produce a much
39+
smaller output scale than input scale, because the output observer only
40+
sees values in [0, 1] while the input observer sees the full [-50, 50].
41+
42+
Before the fix (clamp in _one_to_one_shared_input_qspec), both observers
43+
were shared and would produce identical scales — the wider input range
44+
dominated, wasting output precision.
45+
"""
46+
model = ClampModel()
47+
model.eval()
48+
49+
# Use deterministic wide-range calibration data so the input observer
50+
# sees [-50, 50] while the output observer sees only [0, 1].
51+
calibration_input = torch.linspace(-50, 50, 200).reshape(1, 200)
52+
53+
tosa_spec = TosaSpecification.create_from_string("TOSA-1.0+INT")
54+
quantizer = TOSAQuantizer(tosa_spec)
55+
quantizer.set_global(get_symmetric_quantization_config(is_per_channel=False))
56+
57+
exported = torch.export.export(model, (calibration_input,))
58+
prepared = prepare_pt2e(exported.module(), quantizer)
59+
prepared(calibration_input)
60+
converted = convert_pt2e(prepared)
61+
62+
# After conversion the graph has explicit quantize/dequantize nodes.
63+
# For clamp with independent qspecs the pattern is:
64+
# dequantize_per_tensor(input_scale) -> clamp -> quantize_per_tensor(output_scale)
65+
# With shared qspecs both scales would be identical.
66+
clamp_nodes = [
67+
n
68+
for n in converted.graph.nodes
69+
if n.target in (torch.ops.aten.clamp.default, torch.ops.aten.clamp.Tensor)
70+
]
71+
assert (
72+
len(clamp_nodes) == 1
73+
), f"Expected exactly 1 clamp node, found {len(clamp_nodes)}"
74+
clamp_node = clamp_nodes[0]
75+
76+
# Get the dequant feeding clamp's input — its scale is arg[1].
77+
input_dq = clamp_node.args[0]
78+
assert (
79+
input_dq.target == DQ_PER_TENSOR
80+
), f"Expected dequantize_per_tensor before clamp, got {input_dq.target}"
81+
input_scale = float(input_dq.args[1])
82+
83+
# Get the quant consuming clamp's output — its scale is arg[1].
84+
clamp_users = list(clamp_node.users)
85+
assert (
86+
len(clamp_users) == 1
87+
), f"Expected exactly 1 user of clamp, found {len(clamp_users)}"
88+
output_q = clamp_users[0]
89+
assert (
90+
output_q.target == Q_PER_TENSOR
91+
), f"Expected quantize_per_tensor after clamp, got {output_q.target}"
92+
output_scale = float(output_q.args[1])
93+
94+
# With independent quantization the output scale (tracking [0, 1]) must
95+
# be much smaller than the input scale (tracking [-50, 50]).
96+
assert output_scale < input_scale, (
97+
f"Clamp output scale ({output_scale}) should be smaller than input "
98+
f"scale ({input_scale}) because clamp narrows [−50, 50] → [0, 1]. "
99+
"If they are equal, clamp is using shared observers (bug)."
100+
)

backends/arm/test/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -36,6 +36,7 @@ def define_arm_tests():
3636
# Quantization
3737
test_files += [
3838
"quantizer/test_generic_annotater.py",
39+
"quantizer/test_clamp_quantization.py",
3940
]
4041

4142
# Misc tests

0 commit comments

Comments
 (0)