|
12 | 12 | from typing import Any, cast, Optional, Union |
13 | 13 |
|
14 | 14 | import torch |
| 15 | + |
| 16 | +from executorch.backends.transforms.permute_pass_utils import get_arg |
15 | 17 | from torch._inductor.decomposition import remove_decompositions |
16 | 18 | from torch.fx import GraphModule |
17 | 19 | from torch.fx.passes.infra.pass_base import PassBase, PassResult |
@@ -159,6 +161,40 @@ def extract_output_dequant_params( |
159 | 161 | raise ValueError("Could not find dequantize_per_tensor at the output of the graph") |
160 | 162 |
|
161 | 163 |
|
| 164 | +def extract_all_output_dequant_params( |
| 165 | + module: torch.fx.GraphModule, |
| 166 | +) -> list[QuantArgs | None]: |
| 167 | + """ |
| 168 | + Extract per-output dequantization parameters from a multi-output model. |
| 169 | +
|
| 170 | + Returns a QuantArgs tuple for outputs ending in dequantize_per_tensor |
| 171 | + or None for outputs that aren't dequantized. |
| 172 | + """ |
| 173 | + output_nodes = module.graph.find_nodes(op="output") |
| 174 | + if not output_nodes: |
| 175 | + raise ValueError("No output node in graph") |
| 176 | + output_args = output_nodes[0].args[0] |
| 177 | + if not isinstance(output_args, (tuple, list)): |
| 178 | + output_args = (output_args,) |
| 179 | + |
| 180 | + dequant_ops = _get_dequantize_ops() |
| 181 | + params: list[QuantArgs | None] = [] |
| 182 | + for out in output_args: |
| 183 | + if not isinstance(out, torch.fx.Node) or out.target not in dequant_ops: |
| 184 | + params.append(None) |
| 185 | + continue |
| 186 | + params.append( |
| 187 | + ( |
| 188 | + float(get_arg(out, "scale", float)), |
| 189 | + int(get_arg(out, "zero_point", int)), |
| 190 | + int(get_arg(out, "quant_min", int)), |
| 191 | + int(get_arg(out, "quant_max", int)), |
| 192 | + get_arg(out, "dtype", torch.dtype), |
| 193 | + ) |
| 194 | + ) |
| 195 | + return params |
| 196 | + |
| 197 | + |
162 | 198 | def extract_output_dequant_params_through_permute( |
163 | 199 | module: torch.fx.GraphModule, |
164 | 200 | ) -> QuantArgs: |
@@ -400,33 +436,60 @@ def sink_dequants(program: torch.export.ExportedProgram) -> None: |
400 | 436 |
|
401 | 437 | class QuantizedOutputWrapper(torch.nn.Module): |
402 | 438 | """ |
403 | | - Wrapper that quantizes a model's output so it produces uint8 tensors. |
| 439 | + Wrapper that quantizes a model's output(s) so they produce quantized tensors. |
404 | 440 |
|
405 | 441 | Mirrors QuantizedInputWrapper: the wrapper adds a quantize_per_tensor after |
406 | | - the model's output. When the graph is traced, the dequant (from the model) → |
| 442 | + each output. When the graph is traced, the dequant (from the model) → |
407 | 443 | quant (from the wrapper) pair with matching parameters folds away, leaving |
408 | 444 | the output in its quantized form. |
409 | 445 |
|
410 | 446 | Args: |
411 | 447 | module: The module to wrap (may already be a QuantizedInputWrapper). |
412 | | - output_quant_args: (scale, zero_point, qmin, qmax, dtype) for the output. |
| 448 | + output_quant_args: Quantization parameters — either a single QuantArgs |
| 449 | + tuple or a list with one entry per output. |
413 | 450 | """ |
414 | 451 |
|
415 | 452 | def __init__( |
416 | 453 | self, |
417 | 454 | module: torch.nn.Module, |
418 | | - output_quant_args: QuantArgs, |
| 455 | + output_quant_args: Union[QuantArgs, list[QuantArgs | None]], |
419 | 456 | ) -> None: |
420 | 457 | super().__init__() |
421 | 458 | self.module: torch.nn.Module = module |
422 | | - self.output_quant_args: QuantArgs = output_quant_args |
| 459 | + if isinstance(output_quant_args, list): |
| 460 | + self._multi_output: bool = True |
| 461 | + self._per_output_args: list[QuantArgs | None] = output_quant_args |
| 462 | + else: |
| 463 | + self._multi_output = False |
| 464 | + self._per_output_args = [output_quant_args] |
423 | 465 |
|
424 | 466 | def forward(self, *args: torch.Tensor) -> Any: |
425 | | - result = self.module(*args) |
426 | | - scale, zp, qmin, qmax, dtype = self.output_quant_args |
427 | | - return torch.ops.quantized_decomposed.quantize_per_tensor.default( |
428 | | - result, scale, zp, qmin, qmax, dtype |
429 | | - ) |
| 467 | + model_output = self.module(*args) |
| 468 | + if not self._multi_output: |
| 469 | + quant_args = self._per_output_args[0] |
| 470 | + assert quant_args is not None |
| 471 | + scale, zero_point, quant_min, quant_max, dtype = quant_args |
| 472 | + return torch.ops.quantized_decomposed.quantize_per_tensor.default( |
| 473 | + model_output, scale, zero_point, quant_min, quant_max, dtype |
| 474 | + ) |
| 475 | + |
| 476 | + quantized_outputs: list[torch.Tensor] = [] |
| 477 | + for output_index, output_tensor in enumerate(model_output): |
| 478 | + quant_args = ( |
| 479 | + self._per_output_args[output_index] |
| 480 | + if output_index < len(self._per_output_args) |
| 481 | + else None |
| 482 | + ) |
| 483 | + if quant_args is None: |
| 484 | + quantized_outputs.append(output_tensor) |
| 485 | + else: |
| 486 | + scale, zero_point, quant_min, quant_max, dtype = quant_args |
| 487 | + quantized_outputs.append( |
| 488 | + torch.ops.quantized_decomposed.quantize_per_tensor.default( |
| 489 | + output_tensor, scale, zero_point, quant_min, quant_max, dtype |
| 490 | + ) |
| 491 | + ) |
| 492 | + return tuple(quantized_outputs) |
430 | 493 |
|
431 | 494 |
|
432 | 495 | def _get_transparent_ops() -> set[Any]: |
|
0 commit comments