Skip to content

Commit 50c170c

Browse files
authored
Add QuantizedOutputWrapper API
Differential Revision: D92579429 Pull Request resolved: #17289
1 parent f655c51 commit 50c170c

1 file changed

Lines changed: 131 additions & 23 deletions

File tree

backends/cadence/aot/compiler_funcs.py

Lines changed: 131 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -19,29 +19,6 @@
1919
QuantArgs = tuple[float, int, int, int, torch.dtype]
2020

2121

22-
def extract_input_shapes_from_graph(
    module: GraphModule,
) -> dict[int, tuple[int, ...]]:
    """
    Collect the expected shape of each graph input from placeholder metadata.

    Placeholder nodes may carry a tensor (or tensor-like object) under
    ``meta["val"]`` recorded during tracing; its ``.shape`` is captured here.

    Returns:
        A dict mapping input (placeholder) index to its shape tuple.
        Placeholders without usable metadata still consume an index so the
        mapping stays aligned with call-argument order.
    """
    placeholders = [n for n in module.graph.nodes if n.op == "placeholder"]
    shapes: dict[int, tuple[int, ...]] = {}
    for position, node in enumerate(placeholders):
        if "val" not in node.meta:
            continue
        candidate = node.meta["val"]
        # Accept both real/fake Tensors and any other object exposing .shape.
        if isinstance(candidate, torch.Tensor) or hasattr(candidate, "shape"):
            shapes[position] = tuple(candidate.shape)
    return shapes
43-
44-
4522
@torch.no_grad()
4623
def trace(
4724
model: torch.nn.Module,
@@ -81,6 +58,29 @@ def prepare(
8158
return prepared_model
8259

8360

61+
def extract_input_shapes_from_graph(
    module: GraphModule,
) -> dict[int, tuple[int, ...]]:
    """
    Extract input shapes from the FX graph placeholder nodes.

    Args:
        module: FX GraphModule whose placeholder nodes may carry a tensor
            (or tensor-like object) under ``meta["val"]`` from tracing.

    Returns:
        A dict mapping input (placeholder) index to the expected shape
        tuple. Placeholders without shape metadata are skipped, but they
        still consume an index so the mapping stays aligned with the
        call-argument order.
    """
    input_shapes: dict[int, tuple[int, ...]] = {}
    idx = 0
    for node in module.graph.nodes:
        if node.op == "placeholder":
            # A single hasattr check covers both Tensors and tensor-like
            # objects (e.g. fake tensors) — the original isinstance branch
            # duplicated the hasattr branch's body and was redundant.
            val = node.meta.get("val")
            if val is not None and hasattr(val, "shape"):
                input_shapes[idx] = tuple(val.shape)
            idx += 1
    return input_shapes
82+
83+
8484
def extract_quant_params_through_permute(
8585
module: torch.fx.GraphModule,
8686
) -> dict[int, tuple[float, int, int, int, torch.dtype]]:
@@ -121,6 +121,83 @@ def extract_quant_params_through_permute(
121121
return quant_args
122122

123123

124+
def extract_output_dequant_params(
    module: torch.fx.GraphModule,
) -> tuple[float, int, int, int, torch.dtype]:
    """
    Extract dequantization parameters from the output of a quantized model.

    The graph is expected to end with:
        ... → dequantize_per_tensor(scale, zp, qmin, qmax, dtype) → output

    Returns:
        (scale, zero_point, qmin, qmax, dtype) read from the final
        dequantize_per_tensor call's arguments.

    Raises:
        ValueError: if the graph output is not an FX node, or no trailing
            dequantize_per_tensor call with enough arguments is found.
    """
    for out_node in (n for n in module.graph.nodes if n.op == "output"):
        wrapped = out_node.args[0]
        # The output may be wrapped in a tuple/list; inspect its first entry.
        producer = wrapped[0] if isinstance(wrapped, (tuple, list)) else wrapped
        if not isinstance(producer, torch.fx.Node):
            raise ValueError("Output node is not an FX node")
        if "dequantize_per_tensor" in str(producer.target):
            # Drop the input tensor; the rest are the quantization params.
            params = producer.args[1:]
            if len(params) >= 5:
                out_dtype = params[4]
                assert isinstance(out_dtype, torch.dtype)
                scale, zero_point, qmin, qmax = params[:4]
                return (
                    float(scale),
                    int(zero_point),
                    int(qmin),
                    int(qmax),
                    out_dtype,
                )
    raise ValueError("Could not find dequantize_per_tensor at the output of the graph")
156+
157+
158+
def extract_output_dequant_params_through_permute(
    module: torch.fx.GraphModule,
) -> tuple[float, int, int, int, torch.dtype]:
    """
    Extract output dequantization parameters through a trailing permute.

    For models with nhwc output, the graph ends with:
        ... → dequantize_per_tensor → permute(0, 2, 3, 1) → output

    Returns:
        (scale, zero_point, qmin, qmax, dtype) read from the
        dequantize_per_tensor feeding the final permute.

    Raises:
        ValueError: if the graph output is not an FX node, or the expected
            dequantize → permute tail is not present.
    """
    permute_targets = (
        torch.ops.aten.permute.default,
        torch.ops.aten.permute_copy.default,
    )
    for out_node in module.graph.nodes:
        if out_node.op != "output":
            continue
        wrapped = out_node.args[0]
        # The output may be wrapped in a tuple/list; inspect its first entry.
        producer = wrapped[0] if isinstance(wrapped, (tuple, list)) else wrapped
        if not isinstance(producer, torch.fx.Node):
            raise ValueError("Output node is not an FX node")
        if producer.target in permute_targets:
            source = producer.args[0]
            feeds_dequant = isinstance(source, torch.fx.Node) and (
                "dequantize_per_tensor" in str(source.target)
            )
            if feeds_dequant:
                # Drop the input tensor; the rest are the quantization params.
                params = source.args[1:]
                if len(params) >= 5:
                    out_dtype = params[4]
                    assert isinstance(out_dtype, torch.dtype)
                    return (
                        float(params[0]),
                        int(params[1]),
                        int(params[2]),
                        int(params[3]),
                        out_dtype,
                    )
    raise ValueError(
        "Could not find dequantize_per_tensor → permute at the output of the graph"
    )
199+
200+
124201
def extract_input_quant_params_from_graph(
125202
module: GraphModule,
126203
input_names: list[str],
@@ -241,3 +318,34 @@ def forward(self, *args: torch.Tensor) -> Any:
241318
dequantized_args.append(node)
242319

243320
return self.module(*dequantized_args)
321+
322+
323+
class QuantizedOutputWrapper(torch.nn.Module):
    """
    Wrap a module so its floating-point output is re-quantized to integers.

    The forward pass appends a quantize_per_tensor to whatever the wrapped
    module returns. Mirroring QuantizedInputWrapper: when the wrapped graph
    already ends with a dequantize using matching parameters, tracing folds
    the dequant → quant pair away, leaving the output in quantized form.

    Args:
        module: The module to wrap (may already be a QuantizedInputWrapper).
        output_quant_args: (scale, zero_point, qmin, qmax, dtype) for the
            output quantization.
    """

    def __init__(
        self,
        module: torch.nn.Module,
        output_quant_args: tuple[float, int, int, int, torch.dtype],
    ) -> None:
        super().__init__()
        # Wrapped model plus the quantization params applied to its output.
        self.module: torch.nn.Module = module
        self.output_quant_args: tuple[float, int, int, int, torch.dtype] = (
            output_quant_args
        )

    def forward(self, *args: torch.Tensor) -> Any:
        # Re-quantize the wrapped module's output with the stored params.
        scale, zero_point, quant_min, quant_max, out_dtype = self.output_quant_args
        raw_output = self.module(*args)
        return torch.ops.quantized_decomposed.quantize_per_tensor.default(
            raw_output, scale, zero_point, quant_min, quant_max, out_dtype
        )

0 commit comments

Comments
 (0)