Skip to content

Commit f3b66dc

Browse files
authored
Generalize QuantizedOutputWrapper for multi-output models (pytorch#19987)
Differential Revision: D107429509 Pull Request resolved: pytorch#19987
1 parent 1bf982a commit f3b66dc

2 files changed

Lines changed: 74 additions & 10 deletions

File tree

backends/cadence/aot/BUCK

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -426,6 +426,7 @@ fbcode_target(_kind = runtime.python_library,
426426
typing = True,
427427
deps = [
428428
"//caffe2:torch",
429+
"//executorch/backends/transforms:permute_pass_utils",
429430
"//pytorch/ao:torchao",
430431
],
431432
)

backends/cadence/aot/compiler_funcs.py

Lines changed: 73 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
from typing import Any, cast, Optional, Union
1313

1414
import torch
15+
16+
from executorch.backends.transforms.permute_pass_utils import get_arg
1517
from torch._inductor.decomposition import remove_decompositions
1618
from torch.fx import GraphModule
1719
from torch.fx.passes.infra.pass_base import PassBase, PassResult
@@ -159,6 +161,40 @@ def extract_output_dequant_params(
159161
raise ValueError("Could not find dequantize_per_tensor at the output of the graph")
160162

161163

164+
def extract_all_output_dequant_params(
    module: torch.fx.GraphModule,
) -> list[QuantArgs | None]:
    """
    Collect dequantization parameters for every output of a graph module.

    The first ``output`` node of ``module.graph`` is inspected. Its returned
    value is treated as a sequence of outputs (a bare, non-tuple/list value is
    handled as a single output).

    Args:
        module: The graph module whose outputs are examined.

    Returns:
        A list with one entry per output, in output order. An entry is a
        ``(scale, zero_point, quant_min, quant_max, dtype)`` tuple when the
        output is an fx node whose target is one of the ops returned by
        ``_get_dequantize_ops()``, and ``None`` otherwise.

    Raises:
        ValueError: If the graph contains no ``output`` node.
    """
    graph_outputs = module.graph.find_nodes(op="output")
    if not graph_outputs:
        raise ValueError("No output node in graph")

    returned = graph_outputs[0].args[0]
    # Normalise the single-output case so the rest of the code sees a sequence.
    outputs = returned if isinstance(returned, (tuple, list)) else (returned,)

    dequant_targets = _get_dequantize_ops()

    def _quant_args_of(candidate: object) -> QuantArgs | None:
        # Only fx nodes that are themselves dequantize ops carry parameters;
        # constants and any other producer yield None.
        if isinstance(candidate, torch.fx.Node) and candidate.target in dequant_targets:
            return (
                float(get_arg(candidate, "scale", float)),
                int(get_arg(candidate, "zero_point", int)),
                int(get_arg(candidate, "quant_min", int)),
                int(get_arg(candidate, "quant_max", int)),
                get_arg(candidate, "dtype", torch.dtype),
            )
        return None

    return [_quant_args_of(candidate) for candidate in outputs]
196+
197+
162198
def extract_output_dequant_params_through_permute(
163199
module: torch.fx.GraphModule,
164200
) -> QuantArgs:
@@ -400,33 +436,60 @@ def sink_dequants(program: torch.export.ExportedProgram) -> None:
400436

401437
class QuantizedOutputWrapper(torch.nn.Module):
    """
    Wrapper that quantizes a model's output(s) so they produce quantized tensors.

    Mirrors QuantizedInputWrapper: the wrapper adds a quantize_per_tensor after
    each output. When the graph is traced, the dequant (from the model) →
    quant (from the wrapper) pair with matching parameters folds away, leaving
    the output in its quantized form.

    Args:
        module: The module to wrap (may already be a QuantizedInputWrapper).
        output_quant_args: Quantization parameters — either a single QuantArgs
            tuple ``(scale, zero_point, quant_min, quant_max, dtype)`` or a list
            with one entry per output. A ``None`` list entry leaves the
            corresponding output untouched.
    """

    def __init__(
        self,
        module: torch.nn.Module,
        output_quant_args: Union[QuantArgs, list[QuantArgs | None]],
    ) -> None:
        super().__init__()
        self.module: torch.nn.Module = module
        if isinstance(output_quant_args, list):
            self._multi_output: bool = True
            # Copy the list so that later mutation by the caller cannot
            # silently change what this wrapper does.
            self._per_output_args: list[QuantArgs | None] = list(output_quant_args)
        else:
            self._multi_output = False
            self._per_output_args = [output_quant_args]

    @staticmethod
    def _quantize(tensor: torch.Tensor, quant_args: QuantArgs) -> torch.Tensor:
        """Apply quantize_per_tensor to ``tensor`` with the given parameters."""
        scale, zero_point, quant_min, quant_max, dtype = quant_args
        return torch.ops.quantized_decomposed.quantize_per_tensor.default(
            tensor, scale, zero_point, quant_min, quant_max, dtype
        )

    def forward(self, *args: torch.Tensor) -> Any:
        """
        Run the wrapped module and quantize its output(s).

        Returns:
            A single quantized tensor when the wrapped module returns a single
            tensor; otherwise a tuple with one entry per module output, where
            outputs without quantization parameters are passed through as-is.

        Raises:
            ValueError: If the wrapper was built with a single (non-list)
                ``output_quant_args`` that is ``None``.
        """
        model_output = self.module(*args)

        if not self._multi_output:
            quant_args = self._per_output_args[0]
            # Explicit raise instead of `assert`, which is stripped under -O.
            if quant_args is None:
                raise ValueError(
                    "QuantizedOutputWrapper requires output_quant_args for a "
                    "single-output module, got None"
                )
            return self._quantize(model_output, quant_args)

        if isinstance(model_output, torch.Tensor):
            # Per-output args were supplied as a list, but the module returned
            # one bare tensor. Iterating it would slice along dim 0 and
            # quantize only the first slice, so treat it as the sole output
            # and keep the module's (non-tuple) return structure.
            quant_args = self._per_output_args[0] if self._per_output_args else None
            if quant_args is None:
                return model_output
            return self._quantize(model_output, quant_args)

        quantized_outputs: list[torch.Tensor] = []
        for output_index, output_tensor in enumerate(model_output):
            # Outputs beyond the end of the args list are passed through
            # unchanged, the same as outputs with an explicit None entry.
            quant_args = (
                self._per_output_args[output_index]
                if output_index < len(self._per_output_args)
                else None
            )
            if quant_args is None:
                quantized_outputs.append(output_tensor)
            else:
                quantized_outputs.append(self._quantize(output_tensor, quant_args))
        return tuple(quantized_outputs)
430493

431494

432495
def _get_transparent_ops() -> set[Any]:

0 commit comments

Comments
 (0)