Skip to content

Commit 331ac4a

Browse files
authored
Support two modes of depthwise conv
Differential Revision: D94432987 Pull Request resolved: #17723
1 parent c015241 commit 331ac4a

2 files changed

Lines changed: 19 additions & 10 deletions

File tree

backends/cadence/aot/ops_registrations.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1179,18 +1179,25 @@ def quantized_conv2d_nhwc_per_tensor_meta(
11791179
in_size = input.shape
11801180
# Assert that the input tensor has at least 3 dimensions, and at most 6
11811181
assert len(in_size) > 2
1182-
assert len(in_size) < 6
1182+
assert len(in_size) < 5
11831183

11841184
# Determine weight layout based on depthwise vs regular conv.
11851185
in_channels = in_size[-1]
1186-
if is_depthwise_conv(groups, in_channels):
1187-
*kernel_size, out_channels = weight.shape
1188-
elif len(in_size) == 3:
1189-
# 1D conv: weight is [OC, K, IC]
1190-
out_channels, *kernel_size, _ = weight.shape
1186+
depthwise = is_depthwise_conv(groups, in_channels)
1187+
if len(in_size) == 3:
1188+
if len(weight.shape) == 2:
1189+
assert depthwise
1190+
*kernel_size, out_channels = weight.shape
1191+
else:
1192+
out_channels, *kernel_size, _ = weight.shape
1193+
elif len(in_size) == 4:
1194+
if len(weight.shape) == 3:
1195+
assert depthwise
1196+
*kernel_size, out_channels = weight.shape
1197+
else:
1198+
out_channels, *kernel_size, _ = weight.shape
11911199
else:
1192-
# 2D regular conv: weight is [OC, KH, KW, IC]
1193-
out_channels, *kernel_size, _ = weight.shape
1200+
raise ValueError("Unsupported input tensor size")
11941201

11951202
# Compute the output tensor size
11961203
output_size = (

backends/cadence/aot/ref_implementations.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1110,7 +1110,8 @@ def quantized_conv2d_nhwc_per_tensor(
11101110
if len(input_tensor.shape) == 3:
11111111
# 1D conv: input is [N, L, C] -> [N, C, L]
11121112
input_tensor = input_tensor.movedim(-1, 1).contiguous()
1113-
if depthwise:
1113+
if len(weight.shape) == 2:
1114+
assert depthwise, "1D depthwise conv requires 2D weight tensor"
11141115
# 1D depthwise: weight is [K, OC] -> [OC, 1, K]
11151116
weight = weight.permute(1, 0).unsqueeze(1).contiguous()
11161117
else:
@@ -1120,7 +1121,8 @@ def quantized_conv2d_nhwc_per_tensor(
11201121
else:
11211122
# 2D conv: input is [N, H, W, C] -> [N, C, H, W]
11221123
input_tensor = input_tensor.movedim(-1, -3)
1123-
if depthwise:
1124+
if len(weight.shape) == 3:
1125+
assert depthwise, "2D depthwise conv requires 3D weight tensor"
11241126
# 2D depthwise: weight is [KH, KW, OC] -> [OC, 1, KH, KW]
11251127
weight = weight.permute(2, 0, 1).unsqueeze(1).contiguous()
11261128
else:

0 commit comments

Comments
 (0)