Skip to content

Commit 8ac5666

Browse files
authored
add Dcls2dK1d for 2D with a flat dimension
1 parent 0d14a4c commit 8ac5666

1 file changed

Lines changed: 170 additions & 0 deletions

File tree

DCLS/construct/modules.py

Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1473,3 +1473,173 @@ def _conv_forward(
14731473

14741474
def forward(self, input: Tensor) -> Tensor:
    """Run the layer on ``input``.

    Delegates to ``self._conv_forward``, handing it the input together with
    the module's own ``weight``, ``bias``, ``P`` and ``SIG`` attributes, and
    returns whatever that call returns.
    """
    return self._conv_forward(
        input,
        self.weight,
        self.bias,
        self.P,
        self.SIG,
    )
1476+
1477+
1478+
class Dcls2dK1d(_DclsNd):
    r"""Applies a 2D convolution whose kernel is flat along one spatial dimension.

    A 1D kernel is built by :class:`ConstructKernel1d` from :attr:`weight`,
    :attr:`P` and :attr:`SIG`. It is given one extra axis of size 1 at index
    ``-1 - flat_dim`` and is then used as the weight of a regular
    :func:`torch.nn.functional.conv2d` call with dilation 1.

    In the simplest case, the output value of the layer with input size
    :math:`(N, C_{\text{in}}, H, W)` and output
    :math:`(N, C_{\text{out}}, H_{\text{out}}, W_{\text{out}})`
    can be precisely described as:

    .. math::
        \text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) +
        \sum_{k = 0}^{C_{\text{in}} - 1} \text{kernel}(C_{\text{out}_j}, k) \star \text{input}(N_i, k)

    where :math:`\star` is the valid 2D `cross-correlation`_ operator,
    :math:`N` is a batch size, :math:`C` denotes a number of channels,
    :math:`H` is a height of input planes in pixels, :math:`W` is
    width in pixels, and :math:`\text{kernel}` is the constructed kernel
    described above.

    This module supports :ref:`TensorFloat32<tf32_on_ampere>`.

    * :attr:`stride` controls the stride for the cross-correlation, a single
      number or a tuple.

    * :attr:`padding` controls the amount of implicit padding on both
      sides for :attr:`padding` number of points for each dimension.

    * :attr:`groups` is forwarded unchanged to
      :func:`torch.nn.functional.conv2d`, which requires :attr:`in_channels`
      and :attr:`out_channels` to both be divisible by it.

    * :attr:`flat_dim` selects the spatial dimension along which the kernel
      has size 1, counted from the last axis: ``0`` makes the kernel flat
      along the width, ``1`` makes it flat along the height.

    The parameters :attr:`stride` and :attr:`padding` can either be:

        - a single ``int`` -- in which case the same value is used for the
          height and width dimension
        - a ``tuple`` of two ints -- in which case, the first `int` is used
          for the height dimension, and the second `int` for the width
          dimension

    Note:
        The notes on grouped / depthwise convolution and on cuDNN
        reproducibility in the :class:`torch.nn.Conv2d` documentation apply
        to the underlying ``conv2d`` call.

    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels produced by the convolution
        kernel_count (int): Number of elements in the convolving kernel
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of
            the input. Default: 0
        dilated_kernel_size (int or tuple, optional): Size of dilated kernel.
            Default: 1
        groups (int, optional): Number of blocked connections from input
            channels to output channels. Default: 1
        bias (bool, optional): If ``True``, adds a learnable bias to the
            output. Default: ``True``
        padding_mode (string, optional): ``'zeros'``, ``'reflect'``,
            ``'replicate'`` or ``'circular'``. Default: ``'zeros'``
        version (string, optional): Forwarded to the base class and to
            :class:`ConstructKernel1d`; the accepted values are defined
            there. Default: ``'v1'``
        flat_dim (int, optional): ``0`` or ``1``; see above. Default: 0

    Raises:
        ValueError: if :attr:`flat_dim` is neither 0 nor 1.

    Shape:
        - Input: :math:`(N, C_{in}, H_{in}, W_{in})`
        - Output: :math:`(N, C_{out}, H_{out}, W_{out})` where

          .. math::
              H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[0]
                        - K_H}{\text{stride}[0]} + 1\right\rfloor

          .. math::
              W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[1]
                        - K_W}{\text{stride}[1]} + 1\right\rfloor

          and :math:`(K_H, K_W)` is the spatial shape of the constructed
          kernel: 1 along the flat dimension and, presumably,
          :attr:`dilated_kernel_size` along the other one (that extent is
          decided by :class:`ConstructKernel1d` -- TODO confirm).

    Attributes:
        weight, bias, P, SIG: tensors read by :meth:`forward`; they are not
            created in this class and are expected to be provided by the
            base class.
        DCK (ConstructKernel1d): module that builds the 1D kernel.
        flat_dim (int): the validated :attr:`flat_dim` argument.

    Examples:

        >>> # kernel spanning the height, flat along the width
        >>> m = Dcls2dK1d(16, 33, kernel_count=3, dilated_kernel_size=7, stride=2)
        >>> # kernel spanning the width, flat along the height, with padding
        >>> m = Dcls2dK1d(16, 33, kernel_count=3, dilated_kernel_size=7,
        ...               padding=(0, 3), flat_dim=1)
        >>> input = torch.randn(20, 16, 50, 100)
        >>> output = m(input)

    .. _cross-correlation:
        https://en.wikipedia.org/wiki/Cross-correlation
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_count: int,
        stride: _size_2_t = 1,
        padding: _size_2_t = 0,
        dilated_kernel_size: _size_1_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",  # TODO: refine this type
        version: str = "v1",
        flat_dim: int = 0,
    ):
        # conv2d needs a weight laid out as (out, in / groups, kH, kW). The
        # singleton axis is inserted at index -1 - flat_dim, so only 0 and 1
        # keep it among the two trailing (spatial) axes. Fail early instead of
        # building a malformed kernel on the first forward pass.
        if flat_dim not in (0, 1):
            raise ValueError(f"flat_dim must be 0 or 1, got {flat_dim!r}")

        stride_ = _pair(stride)
        padding_ = _pair(padding)
        dilated_kernel_size_ = _single(dilated_kernel_size)
        super().__init__(
            in_channels,
            out_channels,
            kernel_count,
            stride_,
            padding_,
            dilated_kernel_size_,
            # NOTE(review): presumably `transposed` and `output_padding`, as
            # in torch's _ConvNd -- confirm against _DclsNd.__init__.
            False,
            _pair(0),
            groups,
            bias,
            padding_mode,
            version,
        )

        # Builds the 1D kernel from (weight, P, SIG) on every forward pass.
        self.DCK = ConstructKernel1d(
            self.out_channels,
            self.in_channels,
            self.groups,
            self.kernel_count,
            self.dilated_kernel_size,
            self.version,
        )

        self.flat_dim = flat_dim

    def extra_repr(self):
        """Return the base class's ``extra_repr`` string, passed through
        ``str.format`` with this instance's attributes."""
        s = super().extra_repr()
        return s.format(**self.__dict__)

    def _conv_forward(
        self,
        input: Tensor,
        weight: Tensor,
        bias: Optional[Tensor],
        P: Tensor,
        SIG: Optional[Tensor],
    ) -> Tensor:
        """Construct the flat 2D kernel and convolve ``input`` with it.

        Args:
            input: tensor handed to ``F.conv2d`` (after explicit padding when
                ``padding_mode`` is not ``"zeros"``).
            weight, P, SIG: passed as-is to ``self.DCK`` to build the kernel.
            bias: optional bias forwarded to ``F.conv2d``.

        Returns:
            The result of ``F.conv2d`` with ``self.stride``, dilation 1 and
            ``self.groups``.
        """
        if self.padding_mode != "zeros":
            # Non-zero padding modes are applied explicitly with F.pad, so
            # the convolution itself must not pad again.
            input = F.pad(
                input,
                self._reversed_padding_repeated_twice,
                mode=self.padding_mode,
            )
            conv_padding = _pair(0)
        else:
            conv_padding = self.padding

        # Give the 1D kernel a size-1 axis so it becomes a 2D kernel that is
        # flat along the dimension selected by flat_dim.
        kernel = self.DCK(weight, P, SIG).unsqueeze(-1 - self.flat_dim)

        # Dilation is fixed to 1: the constructed kernel is used directly.
        return F.conv2d(
            input,
            kernel,
            bias,
            self.stride,
            conv_padding,
            _pair(1),
            self.groups,
        )

    def forward(self, input: Tensor) -> Tensor:
        """Apply the layer to ``input`` using the module's own ``weight``,
        ``bias``, ``P`` and ``SIG``."""
        return self._conv_forward(input, self.weight, self.bias, self.P, self.SIG)

0 commit comments

Comments
 (0)