File tree Expand file tree Collapse file tree 3 files changed +7
-12
lines changed
Expand file tree Collapse file tree 3 files changed +7
-12
lines changed Original file line number Diff line number Diff line change 4040 copy_node_metadata ,
4141 create_zero_bias_int32 ,
4242 find_sequential_partitions_aten ,
43- get_conv_args ,
4443 quantize_tensor_multiplier ,
4544)
4645from executorch .exir .pass_base import ExportPass
@@ -263,10 +262,10 @@ def get_args_and_kwargs_conv(
263262 weight_zero_point = dequants_weights [0 ].args [2 ]
264263 # pyre-fixme[58]: Unsupported operand types
265264 bias_scale = dequants_inputs [0 ].args [1 ] * weight_scale
266- stride = [ 1 , 1 ] if len (op_node . args ) < 4 else get_conv_args ( op_node . args [ 3 ], 1 )
267- padding = [ 0 , 0 ] if len (op_node . args ) < 5 else get_conv_args ( op_node . args [ 4 ], 0 )
268- dilation = [ 1 , 1 ] if len (op_node . args ) < 6 else get_conv_args ( op_node . args [ 5 ], 1 )
269- groups = 1 if len (op_node . args ) < 7 else op_node . args [ 6 ]
265+ stride = get_arg (op_node , "stride" , list [ int ] )
266+ padding = get_arg (op_node , "padding" , list [ int ] )
267+ dilation = get_arg (op_node , "dilation" , list [ int ] )
268+ groups = get_arg (op_node , "groups" , int )
270269
271270 # If bias is not available, create a bias tensor with the shape of weight[0]
272271 if not bias_inputs :
Original file line number Diff line number Diff line change @@ -170,10 +170,6 @@ def get_bias_qparams(
170170 return bias_scale , bias_zero_point
171171
172172
173- def get_conv_args (arg , first_val : int ) -> List [fx .Node ]:
174- return arg if len (arg ) == 2 else [first_val , arg [0 ]]
175-
176-
177173def get_aten_node_target_partitions (
178174 graph : torch .fx .Graph ,
179175 wanted_original_aten_op : List [OpOverload ],
Original file line number Diff line number Diff line change @@ -262,9 +262,9 @@ void quantized_conv1d_nlc_per_tensor_out(
262262 ScalarType dtype = out.scalar_type ();
263263
264264 if (dtype == ScalarType::Char) {
265- // HiFi nnlib conv2d kernel produces incorrect results with stride > 1
266- // on some backends (e.g., Artemis HiFi4). Fall back to generic.
267- if (stride[0 ] > 1 ) {
265+ // HiFi nnlib conv2d kernel does not support depthwise (groups > 1)
266+ // or stride > 1. Fall back to generic implementation .
267+ if (groups > 1 || stride[0 ] > 1 ) {
268268 impl::generic::native::quantized_conv1d_nlc_per_tensor_out (
269269 ctx,
270270 input,
You can’t perform that action at this time.
0 commit comments