Skip to content

Commit 8e67a7a

Browse files
authored
Fix conv1d QuantFusion issue (#18623)
Differential Revision: D98941237 Pull Request resolved: #18623
1 parent 37ea3ff commit 8e67a7a

File tree

3 files changed

+7
-12
lines changed

3 files changed

+7
-12
lines changed

backends/cadence/aot/quantizer/fusion_pass.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@
4040
copy_node_metadata,
4141
create_zero_bias_int32,
4242
find_sequential_partitions_aten,
43-
get_conv_args,
4443
quantize_tensor_multiplier,
4544
)
4645
from executorch.exir.pass_base import ExportPass
@@ -263,10 +262,10 @@ def get_args_and_kwargs_conv(
263262
weight_zero_point = dequants_weights[0].args[2]
264263
# pyre-fixme[58]: Unsupported operand types
265264
bias_scale = dequants_inputs[0].args[1] * weight_scale
266-
stride = [1, 1] if len(op_node.args) < 4 else get_conv_args(op_node.args[3], 1)
267-
padding = [0, 0] if len(op_node.args) < 5 else get_conv_args(op_node.args[4], 0)
268-
dilation = [1, 1] if len(op_node.args) < 6 else get_conv_args(op_node.args[5], 1)
269-
groups = 1 if len(op_node.args) < 7 else op_node.args[6]
265+
stride = get_arg(op_node, "stride", list[int])
266+
padding = get_arg(op_node, "padding", list[int])
267+
dilation = get_arg(op_node, "dilation", list[int])
268+
groups = get_arg(op_node, "groups", int)
270269

271270
# If bias is not available, create a bias tensor with the shape of weight[0]
272271
if not bias_inputs:

backends/cadence/aot/quantizer/utils.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -170,10 +170,6 @@ def get_bias_qparams(
170170
return bias_scale, bias_zero_point
171171

172172

173-
def get_conv_args(arg, first_val: int) -> List[fx.Node]:
174-
return arg if len(arg) == 2 else [first_val, arg[0]]
175-
176-
177173
def get_aten_node_target_partitions(
178174
graph: torch.fx.Graph,
179175
wanted_original_aten_op: List[OpOverload],

backends/cadence/hifi/operators/op_quantized_conv1d_nlc.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -262,9 +262,9 @@ void quantized_conv1d_nlc_per_tensor_out(
262262
ScalarType dtype = out.scalar_type();
263263

264264
if (dtype == ScalarType::Char) {
265-
// HiFi nnlib conv2d kernel produces incorrect results with stride > 1
266-
// on some backends (e.g., Artemis HiFi4). Fall back to generic.
267-
if (stride[0] > 1) {
265+
// HiFi nnlib conv2d kernel does not support depthwise (groups > 1)
266+
// or stride > 1. Fall back to generic implementation.
267+
if (groups > 1 || stride[0] > 1) {
268268
impl::generic::native::quantized_conv1d_nlc_per_tensor_out(
269269
ctx,
270270
input,

0 commit comments

Comments (0)