Skip to content

Commit 8e5ec80

Browse files
authored
Enable Q/DQ opt with multiple inputs
Differential Revision: D101700404. Pull Request resolved: #19010.
1 parent c391738 commit 8e5ec80

1 file changed

Lines changed: 13 additions & 7 deletions

File tree

backends/cadence/aot/compiler_funcs.py

Lines changed: 13 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,8 @@
88

99
import logging
1010
import operator
11-
from typing import Any, Optional, Union
11+
from collections.abc import Mapping, Sequence
12+
from typing import Any, cast, Optional, Union
1213

1314
import torch
1415
from torch._inductor.decomposition import remove_decompositions
@@ -301,23 +302,27 @@ def __init__(
301302
"Warning: Using pre-quantized inputs. This should only be done when calibration has been confirmed."
302303
"Incorrect quantization parameters can lead to significant accuracy degradation."
303304
)
304-
if isinstance(input_args, list):
305-
self.quant_args = extract_input_quant_params_from_graph(module, input_args)
306-
elif isinstance(input_args, dict):
305+
if isinstance(input_args, Sequence) and not isinstance(
306+
input_args, (str, bytes)
307+
):
308+
self.quant_args = extract_input_quant_params_from_graph(
309+
module, list(input_args)
310+
)
311+
elif isinstance(input_args, Mapping):
307312
# dict[int, QuantArgs] — use directly
308313
# dict[int, list[str]] — extract quant params from graph, keyed by input index
309314
first_value = next(iter(input_args.values()), None)
310315
if (
311-
isinstance(first_value, (list, tuple))
316+
isinstance(first_value, (list, tuple, Sequence))
317+
and not isinstance(first_value, (str, bytes))
312318
and first_value
313319
and isinstance(first_value[0], str)
314320
):
315321
# Values are lists of node names: extract quant params and map
316322
# to the caller-specified input indices.
317323
for input_idx, node_names in input_args.items():
318-
assert isinstance(node_names, list)
319324
extracted = extract_input_quant_params_from_graph(
320-
module, node_names
325+
module, list(cast(Sequence[str], node_names))
321326
)
322327
# Use the first extracted quant params for this input index.
323328
if extracted:
@@ -430,6 +435,7 @@ def _get_transparent_ops() -> set[Any]:
430435
torch.ops.aten.view.default,
431436
torch.ops.aten.reshape.default,
432437
torch.ops.aten.split.Tensor,
438+
torch.ops.aten.chunk.default,
433439
torch.ops.aten.slice_copy.Tensor,
434440
torch.ops.aten.permute_copy.default,
435441
torch.ops.aten.permute.default,

0 commit comments

Comments
 (0)