|
8 | 8 |
|
9 | 9 | import logging |
10 | 10 | import operator |
11 | | -from typing import Any, Optional, Union |
| 11 | +from collections.abc import Mapping, Sequence |
| 12 | +from typing import Any, cast, Optional, Union |
12 | 13 |
|
13 | 14 | import torch |
14 | 15 | from torch._inductor.decomposition import remove_decompositions |
@@ -301,23 +302,27 @@ def __init__( |
301 | 302 | "Warning: Using pre-quantized inputs. This should only be done when calibration has been confirmed." |
302 | 303 | "Incorrect quantization parameters can lead to significant accuracy degradation." |
303 | 304 | ) |
304 | | - if isinstance(input_args, list): |
305 | | - self.quant_args = extract_input_quant_params_from_graph(module, input_args) |
306 | | - elif isinstance(input_args, dict): |
| 305 | + if isinstance(input_args, Sequence) and not isinstance( |
| 306 | + input_args, (str, bytes) |
| 307 | + ): |
| 308 | + self.quant_args = extract_input_quant_params_from_graph( |
| 309 | + module, list(input_args) |
| 310 | + ) |
| 311 | + elif isinstance(input_args, Mapping): |
307 | 312 | # dict[int, QuantArgs] — use directly |
308 | 313 | # dict[int, list[str]] — extract quant params from graph, keyed by input index |
309 | 314 | first_value = next(iter(input_args.values()), None) |
310 | 315 | if ( |
311 | | - isinstance(first_value, (list, tuple)) |
| 316 | + isinstance(first_value, (list, tuple, Sequence)) |
| 317 | + and not isinstance(first_value, (str, bytes)) |
312 | 318 | and first_value |
313 | 319 | and isinstance(first_value[0], str) |
314 | 320 | ): |
315 | 321 | # Values are lists of node names: extract quant params and map |
316 | 322 | # to the caller-specified input indices. |
317 | 323 | for input_idx, node_names in input_args.items(): |
318 | | - assert isinstance(node_names, list) |
319 | 324 | extracted = extract_input_quant_params_from_graph( |
320 | | - module, node_names |
| 325 | + module, list(cast(Sequence[str], node_names)) |
321 | 326 | ) |
322 | 327 | # Use the first extracted quant params for this input index. |
323 | 328 | if extracted: |
@@ -430,6 +435,7 @@ def _get_transparent_ops() -> set[Any]: |
430 | 435 | torch.ops.aten.view.default, |
431 | 436 | torch.ops.aten.reshape.default, |
432 | 437 | torch.ops.aten.split.Tensor, |
| 438 | + torch.ops.aten.chunk.default, |
433 | 439 | torch.ops.aten.slice_copy.Tensor, |
434 | 440 | torch.ops.aten.permute_copy.default, |
435 | 441 | torch.ops.aten.permute.default, |
|
0 commit comments