@@ -238,6 +238,12 @@ def register_fake(
238238lib .define (
239239 "quantized_conv2d_nhwc.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, Tensor? offset=None, *, Tensor(a!) out) -> Tensor(a!)"
240240)
241+ lib .define (
242+ "quantized_conv2d_depthwise_nhwc(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift) -> (Tensor Z)"
243+ )
244+ lib .define (
245+ "quantized_conv2d_depthwise_nhwc.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)"
246+ )
241247lib .define (
242248 "quantized_conv1d_ncl(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift) -> (Tensor Z)"
243249)
@@ -2105,6 +2111,49 @@ def quantized_conv2d_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor_meta(
21052111 return input .new_empty (output_size , dtype = input .dtype )
21062112
21072113
2114+ @register_fake ("cadence::quantized_conv2d_depthwise_nhwc" )
2115+ def quantized_conv2d_depthwise_nhwc_meta (
2116+ input : torch .Tensor ,
2117+ weight : torch .Tensor ,
2118+ bias : torch .Tensor ,
2119+ stride : Tuple [int ],
2120+ padding : Tuple [int ],
2121+ dilation : Tuple [int ],
2122+ groups : int ,
2123+ in_zero_point : int ,
2124+ weight_zero_point : int ,
2125+ bias_scale : float ,
2126+ output_scale : float ,
2127+ output_zero_point : int ,
2128+ out_multiplier : int ,
2129+ out_shift : int ,
2130+ ) -> torch .Tensor :
2131+ in_size = input .shape
2132+ assert len (in_size ) > 2
2133+ assert len (in_size ) < 6
2134+ # Depthwise weight is always [*kernel_size, OC]:
2135+ # 2D: [KH, KW, OC], 1D: [K, OC]
2136+ * kernel_size , out_channels = weight .shape
2137+
2138+ output_size = (
2139+ get_conv1d_output_size (
2140+ in_size ,
2141+ out_channels ,
2142+ stride [- 1 ],
2143+ padding [- 1 ],
2144+ dilation [- 1 ],
2145+ kernel_size [0 ],
2146+ True ,
2147+ )
2148+ if len (in_size ) == 3
2149+ else get_conv2d_output_size (
2150+ in_size , out_channels , stride , padding , dilation , kernel_size , True
2151+ )
2152+ )
2153+
2154+ return input .new_empty (output_size , dtype = input .dtype )
2155+
2156+
21082157@register_fake ("cadence::quantized_layer_norm" )
21092158def quantized_layer_norm_meta (
21102159 input : torch .Tensor ,
0 commit comments