@@ -1179,18 +1179,25 @@ def quantized_conv2d_nhwc_per_tensor_meta(
11791179 in_size = input .shape
11801180 # Assert that the input tensor has at least 3 dimensions, and at most 6
11811181 assert len (in_size ) > 2
1182- assert len (in_size ) < 6
1182+ assert len (in_size ) < 5
11831183
11841184 # Determine weight layout based on depthwise vs regular conv.
11851185 in_channels = in_size [- 1 ]
1186- if is_depthwise_conv (groups , in_channels ):
1187- * kernel_size , out_channels = weight .shape
1188- elif len (in_size ) == 3 :
1189- # 1D conv: weight is [OC, K, IC]
1190- out_channels , * kernel_size , _ = weight .shape
1186+ depthwise = is_depthwise_conv (groups , in_channels )
1187+ if len (in_size ) == 3 :
1188+ if len (weight .shape ) == 2 :
1189+ assert depthwise
1190+ * kernel_size , out_channels = weight .shape
1191+ else :
1192+ out_channels , * kernel_size , _ = weight .shape
1193+ elif len (in_size ) == 4 :
1194+ if len (weight .shape ) == 3 :
1195+ assert depthwise
1196+ * kernel_size , out_channels = weight .shape
1197+ else :
1198+ out_channels , * kernel_size , _ = weight .shape
11911199 else :
1192- # 2D regular conv: weight is [OC, KH, KW, IC]
1193- out_channels , * kernel_size , _ = weight .shape
1200+ raise ValueError ("Unsupported input tensor size" )
11941201
11951202 # Compute the output tensor size
11961203 output_size = (
0 commit comments