Skip to content

Commit 300e368

Browse files
Fix DecomposeConcatenate passing invalid out_dtype kwarg to quantize_per_tensor (#18569)
### Summary When decomposing a quantized `cat` with more than 5 inputs, `DecomposeConcatenate` inserts an intermediate Q/DQ pair. Previously it shared the same `kwargs` dict (copied from the original `dequantize_per_tensor` node) for both the `quantize_per_tensor` and `dequantize_per_tensor` nodes. However, `quantize_per_tensor.default` does not accept `out_dtype` — only `dequantize_per_tensor.default` does. This causes failures for models where the dequantize node carries `out_dtype` (e.g., fp16 quantized models). This PR splits the kwargs so that `out_dtype` is excluded from the quantize node kwargs while preserved for the dequantize node. ### Test plan Existing tests pass: ``` python -m unittest backends.xnnpack.test.passes.test_decompose_cat_pass -v ``` ``` test_cat_gt_10 ... ok test_cat_gt_5 ... ok test_qs8_cat_gt_10 ... ok test_qs8_cat_gt_5 ... ok ``` cc @GregoryComer @digantdesai @cbilgin
1 parent 042151d commit 300e368

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

backends/xnnpack/_passes/decompose_cat.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,11 @@ def call(self, graph_module: torch.fx.GraphModule):
6969
# if quantized we need to enforce the q/dq pattern for the newly inserted
7070
# concat node
7171
q_params = nodes_to_concat[0].args[1:]
72-
q_kwargs = nodes_to_concat[0].kwargs
72+
dq_kwargs = nodes_to_concat[0].kwargs
73+
# quantize_per_tensor does not accept out_dtype, so exclude
74+
# it from kwargs passed to the quantize node. out_dtype is
75+
# only valid for dequantize_per_tensor (e.g. fp16 models).
76+
q_kwargs = {k: v for k, v in dq_kwargs.items() if k != "out_dtype"}
7377
# Quantizer enforces all the inputs and output to a concat node must share
7478
# the same qparams, this means the newly inserted q/dq pair must share the
7579
# same qparams as the first quantized input in the concat node.
@@ -89,7 +93,7 @@ def call(self, graph_module: torch.fx.GraphModule):
8993
"call_function",
9094
target=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
9195
args=(q_node,) + q_params,
92-
kwargs=q_kwargs,
96+
kwargs=dq_kwargs,
9397
)
9498
tag_as_implicit_q_dq(dq_node)
9599
remainder_concat_node.args = (

0 commit comments

Comments
 (0)