 QuantArgs = tuple[float, int, int, int, torch.dtype]
 
 
-def extract_input_shapes_from_graph(
-    module: GraphModule,
-) -> dict[int, tuple[int, ...]]:
-    """
-    Extract input shapes from the FX graph placeholder nodes.
-
-    Returns a dict mapping input index to expected shape tuple.
-    """
-    input_shapes: dict[int, tuple[int, ...]] = {}
-    idx = 0
-    for node in module.graph.nodes:
-        if node.op == "placeholder":
-            # Get the tensor_meta from the node if available
-            if "val" in node.meta:
-                val = node.meta["val"]
-                if isinstance(val, torch.Tensor):
-                    input_shapes[idx] = tuple(val.shape)
-                elif hasattr(val, "shape"):
-                    input_shapes[idx] = tuple(val.shape)
-            idx += 1
-    return input_shapes
-
-
 @torch.no_grad()
 def trace(
     model: torch.nn.Module,
@@ -81,6 +58,29 @@ def prepare(
     return prepared_model
 
 
+def extract_input_shapes_from_graph(
+    module: GraphModule,
+) -> dict[int, tuple[int, ...]]:
+    """
+    Extract input shapes from the FX graph placeholder nodes.
+
+    Returns a dict mapping input index to expected shape tuple.
+    """
+    input_shapes: dict[int, tuple[int, ...]] = {}
+    idx = 0
+    for node in module.graph.nodes:
+        if node.op == "placeholder":
+            # Get the tensor_meta from the node if available
+            if "val" in node.meta:
+                val = node.meta["val"]
+                if isinstance(val, torch.Tensor):
+                    input_shapes[idx] = tuple(val.shape)
+                elif hasattr(val, "shape"):
+                    input_shapes[idx] = tuple(val.shape)
+            idx += 1
+    return input_shapes
+
+
 def extract_quant_params_through_permute(
     module: torch.fx.GraphModule,
 ) -> dict[int, tuple[float, int, int, int, torch.dtype]]:
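
Note on the relocated helper: a minimal sketch (hypothetical two-input model, not part of this patch) of where the shapes come from. `torch.export` records a FakeTensor, a `torch.Tensor` subclass, in each placeholder's `node.meta["val"]`, which is the branch the helper reads.

```python
import torch

class TwoInputs(torch.nn.Module):
    def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        return a + b

# Placeholder nodes carry FakeTensors in meta["val"], so shapes are
# available without running the model.
ep = torch.export.export(TwoInputs(), (torch.randn(1, 3), torch.randn(4, 3)))
print(extract_input_shapes_from_graph(ep.graph_module))
# expected: {0: (1, 3), 1: (4, 3)}
```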
@@ -121,6 +121,83 @@ def extract_quant_params_through_permute(
     return quant_args
 
 
+def extract_output_dequant_params(
+    module: torch.fx.GraphModule,
+) -> QuantArgs:
+    """
+    Extract dequantization parameters from the output of a quantized model.
+
+    The graph is expected to end with:
+        ... → dequantize_per_tensor(scale, zp, qmin, qmax, dtype) → output
+    """
+    for node in module.graph.nodes:
+        if node.op != "output":
+            continue
+        output_args = node.args[0]
+        if isinstance(output_args, (tuple, list)):
+            target_output = output_args[0]
+        else:
+            target_output = output_args
+        if not isinstance(target_output, torch.fx.Node):
+            raise ValueError("Output node is not an FX node")
+        if "dequantize_per_tensor" in str(target_output.target):
+            args = target_output.args[1:]
+            if len(args) >= 5:
+                dtype = args[4]
+                assert isinstance(dtype, torch.dtype)
+                return (
+                    float(args[0]),  # scale
+                    int(args[1]),  # zero_point
+                    int(args[2]),  # qmin
+                    int(args[3]),  # qmax
+                    dtype,
+                )
+    raise ValueError("Could not find dequantize_per_tensor at the output of the graph")
+
+
+def extract_output_dequant_params_through_permute(
+    module: torch.fx.GraphModule,
+) -> QuantArgs:
+    """
+    Extract dequantization parameters from the output through a permute.
+
+    For models with nhwc output, the graph ends with:
+        ... → dequantize_per_tensor → permute(0, 2, 3, 1) → output
+    """
+    for node in module.graph.nodes:
+        if node.op != "output":
+            continue
+        output_args = node.args[0]
+        if isinstance(output_args, (tuple, list)):
+            target_output = output_args[0]
+        else:
+            target_output = output_args
+        if not isinstance(target_output, torch.fx.Node):
+            raise ValueError("Output node is not an FX node")
+        if target_output.target in (
+            torch.ops.aten.permute.default,
+            torch.ops.aten.permute_copy.default,
+        ):
+            permute_input = target_output.args[0]
+            if isinstance(
+                permute_input, torch.fx.Node
+            ) and "dequantize_per_tensor" in str(permute_input.target):
+                args = permute_input.args[1:]
+                if len(args) >= 5:
+                    dtype = args[4]
+                    assert isinstance(dtype, torch.dtype)
+                    return (
+                        float(args[0]),  # scale
+                        int(args[1]),  # zero_point
+                        int(args[2]),  # qmin
+                        int(args[3]),  # qmax
+                        dtype,
+                    )
+    raise ValueError(
+        "Could not find dequantize_per_tensor → permute at the output of the graph"
+    )
+
+
 def extract_input_quant_params_from_graph(
     module: GraphModule,
     input_names: list[str],
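
A minimal sketch of the graph tail `extract_output_dequant_params` expects, using a hypothetical module that is not part of this patch. It assumes the `quantized_decomposed` op library is registered (importing `torch.ao.quantization.fx._decomposed` does this in current PyTorch releases); the permute variant matches the same tail with a trailing `permute(0, 2, 3, 1)` in between.

```python
import torch
import torch.ao.quantization.fx._decomposed  # noqa: F401  registers quantized_decomposed ops
from torch.fx import symbolic_trace

class TailDequant(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # quantized body elided; only the tail matters for extraction
        return torch.ops.quantized_decomposed.dequantize_per_tensor.default(
            x, 0.05, 128, 0, 255, torch.uint8
        )

gm = symbolic_trace(TailDequant())
assert extract_output_dequant_params(gm) == (0.05, 128, 0, 255, torch.uint8)
```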
@@ -241,3 +318,34 @@ def forward(self, *args: torch.Tensor) -> Any:
             dequantized_args.append(node)
 
         return self.module(*dequantized_args)
+
+
+class QuantizedOutputWrapper(torch.nn.Module):
+    """
+    Wrapper that quantizes a model's output so it produces uint8 tensors.
+
+    Mirrors QuantizedInputWrapper: the wrapper adds a quantize_per_tensor after
+    the model's output. When the graph is traced, the dequant (from the model) →
+    quant (from the wrapper) pair with matching parameters folds away, leaving
+    the output in its quantized form.
+
+    Args:
+        module: The module to wrap (may already be a QuantizedInputWrapper).
+        output_quant_args: (scale, zero_point, qmin, qmax, dtype) for the output.
+    """
+
+    def __init__(
+        self,
+        module: torch.nn.Module,
+        output_quant_args: QuantArgs,
+    ) -> None:
+        super().__init__()
+        self.module: torch.nn.Module = module
+        self.output_quant_args: QuantArgs = output_quant_args
+
+    def forward(self, *args: torch.Tensor) -> Any:
+        result = self.module(*args)
+        scale, zp, qmin, qmax, dtype = self.output_quant_args
+        return torch.ops.quantized_decomposed.quantize_per_tensor.default(
+            result, scale, zp, qmin, qmax, dtype
+        )
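
A hedged sketch of how the two additions compose; the variable names and input shape below are assumptions, and the calling pipeline sits outside this diff.

```python
# `gm` is assumed to be a traced GraphModule ending in dequantize_per_tensor.
out_qargs = extract_output_dequant_params(gm)
wrapped = QuantizedOutputWrapper(gm, out_qargs)
# Re-exporting inlines the wrapper's quantize_per_tensor directly after the
# model's dequantize_per_tensor; with matching parameters the pair is a
# no-op, so a later fold pass can remove it and leave a uint8 graph output.
example = (torch.randn(1, 3, 224, 224),)  # assumed input signature
requantized = torch.export.export(wrapped, example).module()
```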