@@ -351,10 +351,6 @@ def register_fake(
351351 "quantized_matmul_asym8uxasym8u_asym8u.out(Tensor X, int X_zero_point, Tensor Y, int Y_zero_point, Tensor? bias, int out_multiplier, int out_shift, int out_zero_point, bool transposed=False, *, Tensor(a!) out) -> Tensor(a!)"
352352)
353353
354- lib .define (
355- "convolution(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, "
356- "int[] dilation, int groups, bool channel_last=False) -> (Tensor Y)"
357- )
358354lib .define (
359355 "transposed_convolution(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, "
360356 "int[] dilation, SymInt[] output_padding, int groups, bool channel_last=False) -> (Tensor Y)"
@@ -489,8 +485,28 @@ def register_fake(
489485# ------------------------------------ #
490486# Migrated from the custom_ops.yaml files containing different operator variants (e.g., .out, .tensor_out)
491487lib .define (
492- "convolution.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, "
493- "int groups, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!)"
488+ "conv1d(Tensor input, Tensor weight, Tensor bias, int[1] stride, SymInt[1] padding, int[1] dilation, "
489+ "int groups) -> Tensor"
490+ )
491+ lib .define (
492+ "conv1d.out(Tensor input, Tensor weight, Tensor bias, int[1] stride, SymInt[1] padding, int[1] dilation, "
493+ "int groups, *, Tensor(a!) out) -> Tensor(a!)"
494+ )
495+ lib .define (
496+ "conv2d(Tensor input, Tensor weight, Tensor bias, int[2] stride, SymInt[2] padding, int[2] dilation, "
497+ "int groups) -> Tensor"
498+ )
499+ lib .define (
500+ "conv2d.out(Tensor input, Tensor weight, Tensor bias, int[2] stride, SymInt[2] padding, int[2] dilation, "
501+ "int groups, *, Tensor(a!) out) -> Tensor(a!)"
502+ )
503+ lib .define (
504+ "conv3d(Tensor input, Tensor weight, Tensor bias, int[3] stride, SymInt[3] padding, int[3] dilation, "
505+ "int groups) -> Tensor"
506+ )
507+ lib .define (
508+ "conv3d.out(Tensor input, Tensor weight, Tensor bias, int[3] stride, SymInt[3] padding, int[3] dilation, "
509+ "int groups, *, Tensor(a!) out) -> Tensor(a!)"
494510)
495511lib .define (
496512 "transposed_convolution.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, "
@@ -2152,41 +2168,102 @@ def quantized_fully_connected_asym8uxasym8u_asym8u_per_tensor_meta(
21522168 return src .new_empty (out_size , dtype = src .dtype )
21532169
21542170
2155- @register_fake ("cadence::convolution " )
2156- def convolution_meta (
2171+ @register_fake ("cadence::conv1d " )
2172+ def conv1d_meta (
21572173 input : torch .Tensor ,
21582174 weight : torch .Tensor ,
21592175 bias : torch .Tensor ,
21602176 stride : Tuple [int ],
21612177 padding : Tuple [int ],
21622178 dilation : Tuple [int ],
21632179 groups : int ,
2164- channel_last : bool = False ,
21652180) -> torch .Tensor :
2166- if channel_last :
2167- out_channels , * kernel_size , _ = weight .shape
2168- else :
2169- out_channels , _ , * kernel_size = weight .shape
2181+ assert (
2182+ len ( weight .shape ) == 3
2183+ ), f"Conv1d expects a 3D weight, got { len ( weight . shape ) } D"
2184+ out_channels , _ , kernel_size = weight .shape
21702185 in_size = input .shape
2171- # Assert that the input tensor has at least 3 dimensions, and at most 6
2172- assert len (in_size ) > 2
2173- assert len (in_size ) < 6
2186+ assert len (in_size ) == 3 , f"conv1d expects 3D input, got { len (in_size )} D"
21742187
2175- # Compute the output tensor size
2176- output_size = (
2177- get_conv1d_output_size (
2178- in_size ,
2179- out_channels ,
2180- stride [0 ],
2181- padding [0 ],
2182- dilation [0 ],
2183- kernel_size [0 ],
2184- channel_last ,
2185- )
2186- if len (in_size ) == 3
2187- else get_conv2d_output_size (
2188- in_size , out_channels , stride , padding , dilation , kernel_size , channel_last
2189- )
2188+ output_size = get_conv1d_output_size (
2189+ in_size ,
2190+ out_channels ,
2191+ stride [0 ],
2192+ padding [0 ],
2193+ dilation [0 ],
2194+ kernel_size ,
2195+ False ,
2196+ )
2197+
2198+ return input .new_empty (output_size , dtype = input .dtype )
2199+
2200+
2201+ @register_fake ("cadence::conv2d" )
2202+ def conv2d_meta (
2203+ input : torch .Tensor ,
2204+ weight : torch .Tensor ,
2205+ bias : torch .Tensor ,
2206+ stride : Tuple [int ],
2207+ padding : Tuple [int ],
2208+ dilation : Tuple [int ],
2209+ groups : int ,
2210+ ) -> torch .Tensor :
2211+ assert (
2212+ len (weight .shape ) == 4
2213+ ), f"Conv2d expects a 4D weight, got { len (weight .shape )} D"
2214+ out_channels , _ , * kernel_size = weight .shape
2215+ in_size = input .shape
2216+ assert len (in_size ) == 4 , f"conv2d expects 4D input, got { len (in_size )} D"
2217+
2218+ output_size = get_conv2d_output_size (
2219+ in_size , out_channels , stride , padding , dilation , kernel_size , False
2220+ )
2221+
2222+ return input .new_empty (output_size , dtype = input .dtype )
2223+
2224+
2225+ @register_fake ("cadence::conv3d" )
2226+ def conv3d_meta (
2227+ input : torch .Tensor ,
2228+ weight : torch .Tensor ,
2229+ bias : torch .Tensor ,
2230+ stride : Tuple [int , int , int ],
2231+ padding : Tuple [int , int , int ],
2232+ dilation : Tuple [int , int , int ],
2233+ groups : int ,
2234+ ) -> torch .Tensor :
2235+ assert (
2236+ len (weight .shape ) == 5
2237+ ), f"Conv3d expects a 5D weight, got { len (weight .shape )} D"
2238+ out_channels , _ , * kernel_size = weight .shape
2239+ in_size = input .shape
2240+ assert len (in_size ) == 5 , f"conv3d expects 5D input, got { len (in_size )} D"
2241+
2242+ # Helper to compute 3D convolution output size
2243+ def get_conv3d_output_size (
2244+ in_size : torch .Size ,
2245+ out_channels : int ,
2246+ stride : Tuple [int , int , int ],
2247+ padding : Tuple [int , int , int ],
2248+ dilation : Tuple [int , int , int ],
2249+ kernel_size : list [int ],
2250+ ) -> torch .Size :
2251+ N , C , D , H , W = in_size
2252+
2253+ dout = (D + 2 * padding [0 ] - dilation [0 ] * (kernel_size [0 ] - 1 ) - 1 ) // stride [
2254+ 0
2255+ ] + 1
2256+ hout = (H + 2 * padding [1 ] - dilation [1 ] * (kernel_size [1 ] - 1 ) - 1 ) // stride [
2257+ 1
2258+ ] + 1
2259+ wout = (W + 2 * padding [2 ] - dilation [2 ] * (kernel_size [2 ] - 1 ) - 1 ) // stride [
2260+ 2
2261+ ] + 1
2262+
2263+ return torch .Size ((N , out_channels , dout , hout , wout ))
2264+
2265+ output_size = get_conv3d_output_size (
2266+ in_size , out_channels , stride , padding , dilation , kernel_size
21902267 )
21912268
21922269 return input .new_empty (output_size , dtype = input .dtype )
0 commit comments