|
7 | 7 |
|
8 | 8 | class KANBlock(torch.nn.Module): |
9 | 9 | """ |
10 | | - TODO: docstring. |
| 10 | + The inner block of the Kolmogorov-Arnold Network (KAN). |
| 11 | +
|
| 12 | + The block applies a spline transformation to the input, optionally combined |
| 13 | + with a linear transformation of a base activation function. The output is |
| 14 | + aggregated across input dimensions to produce the final output. |
| 15 | +
|
| 16 | + .. seealso:: |
| 17 | +
|
| 18 | + **Original reference**: |
| 19 | + Liu Z., Wang Y., Vaidya S., Ruehle F., Halverson J., Soljacic M., |
| 20 | + Hou T., Tegmark M. (2025). |
| 21 | + *KAN: Kolmogorov-Arnold Networks*. |
| 22 | + DOI: `arXiv preprint arXiv:2404.19756. |
| 23 | + <https://arxiv.org/abs/2404.19756>`_ |
11 | 24 | """ |
12 | 25 |
|
13 | 26 | def __init__( |
@@ -119,16 +132,15 @@ def __init__( |
119 | 132 |
|
120 | 133 | def forward(self, x): |
121 | 134 | """ |
122 | | - Forward pass of the :class:`KANBlock`. It transforms the input using a |
123 | | - vectorized spline basis and optionally adds a linear transformation of a |
124 | | - base activation function. |
125 | | -
|
126 | | - The input is expected to have shape (batch_size, input_dimensions) and |
127 | | - the output will have shape (batch_size, output_dimensions). |
| 135 | + Forward pass of the Kolmogorov-Arnold block. The input is passed through |
| 136 | + the spline transformation, optionally combined with a linear |
| 137 | + transformation of the base function output, and then aggregated across |
| 138 | + input dimensions to produce the final output. |
128 | 139 |
|
129 | | - :param torch.Tensor x: The input tensor for the model. |
| 140 | + :param x: The input tensor for the model. |
| 141 | + :type x: torch.Tensor | LabelTensor |
130 | 142 | :return: The output tensor of the model. |
131 | | - :rtype: torch.Tensor |
| 143 | + :rtype: torch.Tensor | LabelTensor |
132 | 144 | """ |
133 | 145 | y = self.spline(x) |
134 | 146 |
|
|
0 commit comments