@@ -1473,3 +1473,173 @@ def _conv_forward(
14731473
def forward(self, input: Tensor) -> Tensor:
    """Run the layer on ``input``.

    Delegates to ``self._conv_forward``, passing the module's own
    ``weight``, ``bias``, ``P`` and ``SIG`` attributes (in that order)
    after ``input``.

    Args:
        input: tensor handed unchanged to ``_conv_forward``.

    Returns:
        Whatever ``_conv_forward`` returns for these arguments.
    """
    # Gather the module state once, then forward it positionally.
    module_state = (self.weight, self.bias, self.P, self.SIG)
    return self._conv_forward(input, *module_state)
1476+
1477+
class Dcls2dK1d(_DclsNd):
    r"""Applies a 2D convolution whose kernel is rebuilt on every forward pass.

    The kernel is produced by a :class:`ConstructKernel1d` instance from the
    module's ``weight``, ``P`` and ``SIG`` tensors. A singleton axis is then
    inserted into it at position ``-1 - flat_dim`` and the result is passed to
    :func:`torch.nn.functional.conv2d` with a dilation of ``(1, 1)``.

    In the simplest case, the output value of the layer with input size
    :math:`(N, C_{\text{in}}, H, W)` and output
    :math:`(N, C_{\text{out}}, H_{\text{out}}, W_{\text{out}})`
    can be described as:

    .. math::
        \text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) +
        \sum_{k = 0}^{C_{\text{in}} - 1}
        \text{kernel}(C_{\text{out}_j}, k) \star \text{input}(N_i, k)

    where :math:`\star` is the valid 2D `cross-correlation`_ operator,
    :math:`\text{kernel}` is the constructed kernel described above,
    :math:`N` is a batch size, :math:`C` denotes a number of channels,
    :math:`H` is a height of input planes in pixels, and :math:`W` is
    width in pixels.

    * :attr:`stride` controls the stride for the cross-correlation, a single
      number or a tuple.

    * :attr:`padding` controls the amount of padding applied to the input for
      each spatial dimension. With ``padding_mode='zeros'`` it is passed
      straight to ``conv2d``; with any other mode the input is padded
      explicitly with :func:`torch.nn.functional.pad` first and ``conv2d``
      is called with zero padding.

    * :attr:`groups` controls the connections between inputs and outputs
      (it is forwarded unchanged to ``conv2d``).
      :attr:`in_channels` and :attr:`out_channels` must both be divisible by
      :attr:`groups`. For example,

        * At groups=1, all inputs are convolved to all outputs.
        * At groups=2, the operation becomes equivalent to having two conv
          layers side by side, each seeing half the input channels
          and producing half the output channels, and both subsequently
          concatenated.
        * At groups= :attr:`in_channels`, each input channel is convolved
          with its own set of filters (of size
          :math:`\frac{\text{out\_channels}}{\text{in\_channels}}`).

    The parameters :attr:`stride` and :attr:`padding` can either be:

        - a single ``int`` -- in which case the same value is used for the
          height and width dimension
        - a ``tuple`` of two ints -- in which case, the first `int` is used
          for the height dimension, and the second `int` for the width
          dimension

    Note:
        When `groups == in_channels` and `out_channels == K * in_channels`,
        where `K` is a positive integer, this operation is also known as a
        "depthwise convolution".

    Note:
        The computation is carried out by
        :func:`torch.nn.functional.conv2d`, so PyTorch's notes on
        TensorFloat32 and on cuDNN possibly selecting a nondeterministic
        algorithm apply here as well. See the PyTorch reproducibility notes
        (``torch.backends.cudnn.deterministic``) for details.

    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels produced by the convolution
        kernel_count (int): Number of elements in the convolving kernel
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Padding added to both sides of
            the input. Default: 0
        dilated_kernel_size (int or tuple, optional): Size of dilated kernel.
            Default: 1
        groups (int, optional): Number of blocked connections from input
            channels to output channels. Default: 1
        bias (bool, optional): If ``True``, adds a learnable bias to the
            output. Default: ``True``
        padding_mode (string, optional): ``'zeros'``, ``'reflect'``,
            ``'replicate'`` or ``'circular'``. Default: ``'zeros'``
        version (string, optional): Forwarded unchanged to the parent class
            and to :class:`ConstructKernel1d`. Default: ``'v1'``
        flat_dim (int, optional): Selects where the singleton axis is
            inserted into the constructed kernel: the kernel is unsqueezed at
            dimension ``-1 - flat_dim``, i.e. ``-1`` (last axis) for ``0``
            and ``-2`` for ``1``. Default: 0

    Shape:
        - Input: :math:`(N, C_{in}, H_{in}, W_{in})`
        - Output: :math:`(N, C_{out}, H_{out}, W_{out})` where

          .. math::
              H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[0]
                        - k_H}{\text{stride}[0]} + 1\right\rfloor

          .. math::
              W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[1]
                        - k_W}{\text{stride}[1]} + 1\right\rfloor

          with :math:`(k_H, k_W)` the two trailing dimensions of the
          constructed kernel after the singleton axis has been inserted
          (the dilation passed to ``conv2d`` is always 1).

    Attributes:
        DCK (ConstructKernel1d): module that builds the kernel from
            ``weight``, ``P`` and ``SIG``.
        flat_dim (int): value of the ``flat_dim`` constructor argument.

        ``weight``, ``bias``, ``P`` and ``SIG`` are read by :meth:`forward`
        but are not created in this class; they are expected to be provided
        by the parent class.

    Examples:

        >>> m = Dcls2dK1d(16, 33, kernel_count=3, dilated_kernel_size=7)
        >>> # unequal stride and with padding
        >>> m = Dcls2dK1d(16, 33, kernel_count=3, stride=(2, 1),
        ...               padding=(3, 0), dilated_kernel_size=7)
        >>> # singleton kernel axis inserted at position -2 instead of -1
        >>> m = Dcls2dK1d(16, 33, kernel_count=3, dilated_kernel_size=7,
        ...               flat_dim=1)
        >>> input = torch.randn(20, 16, 50, 100)
        >>> output = m(input)

    .. _cross-correlation:
        https://en.wikipedia.org/wiki/Cross-correlation
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_count: int,
        stride: _size_2_t = 1,
        padding: _size_2_t = 0,
        dilated_kernel_size: _size_1_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",  # TODO: refine this type
        version: str = "v1",
        flat_dim: int = 0,
    ) -> None:
        # stride/padding are two-dimensional (H, W); the dilated kernel size
        # is one-dimensional because the kernel itself is built in 1D.
        stride_ = _pair(stride)
        padding_ = _pair(padding)
        dilated_kernel_size_ = _single(dilated_kernel_size)
        super().__init__(
            in_channels,
            out_channels,
            kernel_count,
            stride_,
            padding_,
            dilated_kernel_size_,
            # NOTE(review): the next two positional values are presumably
            # `transposed` and `output_padding`, mirroring torch's _ConvNd --
            # confirm against the _DclsNd signature.
            False,
            _pair(0),
            groups,
            bias,
            padding_mode,
            version,
        )

        # Kernel builder; it is configured from the attributes stored by the
        # parent constructor rather than from the raw arguments.
        self.DCK = ConstructKernel1d(
            self.out_channels,
            self.in_channels,
            self.groups,
            self.kernel_count,
            self.dilated_kernel_size,
            self.version,
        )

        self.flat_dim = flat_dim

    def extra_repr(self) -> str:
        """Return the parent's extra representation formatted with this
        instance's attributes (``str.format(**self.__dict__)``)."""
        s = super().extra_repr()
        return s.format(**self.__dict__)

    def _conv_forward(
        self,
        input: Tensor,
        weight: Tensor,
        bias: Optional[Tensor],
        P: Tensor,
        SIG: Optional[Tensor],
    ) -> Tensor:
        """Build the kernel from ``weight``, ``P`` and ``SIG`` and convolve.

        Args:
            input: tensor to convolve.
            weight: passed as first argument to ``self.DCK``.
            bias: optional bias forwarded to ``conv2d``.
            P: passed as second argument to ``self.DCK``.
            SIG: passed as third argument to ``self.DCK``; may be ``None``.

        Returns:
            The result of :func:`torch.nn.functional.conv2d`.
        """
        if self.padding_mode != "zeros":
            # conv2d is called with zero padding in this case, so apply the
            # requested padding mode explicitly beforehand.
            input = F.pad(
                input,
                self._reversed_padding_repeated_twice,
                mode=self.padding_mode,
            )
            padding = _pair(0)
        else:
            padding = self.padding

        # Insert the singleton axis selected by `flat_dim` into the freshly
        # constructed kernel before handing it to conv2d.
        kernel = self.DCK(weight, P, SIG).unsqueeze(-1 - self.flat_dim)

        # Dilation is fixed to 1: the constructed kernel is used as-is.
        return F.conv2d(
            input,
            kernel,
            bias,
            self.stride,
            padding,
            _pair(1),
            self.groups,
        )

    def forward(self, input: Tensor) -> Tensor:
        """Apply the layer to ``input`` using the module's own ``weight``,
        ``bias``, ``P`` and ``SIG``."""
        return self._conv_forward(input, self.weight, self.bias, self.P, self.SIG)
0 commit comments