Skip to content

Commit 5bd5902

Browse files
add docstrings
1 parent d4dfb65 commit 5bd5902

File tree

2 files changed

+42
-11
lines changed

2 files changed

+42
-11
lines changed

pina/_src/model/block/kan_block.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,20 @@
77

88
class KANBlock(torch.nn.Module):
99
"""
10-
TODO: docstring.
10+
The inner block of the Kolmogorov-Arnold Network (KAN).
11+
12+
The block applies a spline transformation to the input, optionally combined
13+
with a linear transformation of a base activation function. The output is
14+
aggregated across input dimensions to produce the final output.
15+
16+
.. seealso::
17+
18+
**Original reference**:
19+
Liu Z., Wang Y., Vaidya S., Ruehle F., Halverson J., Soljacic M.,
20+
Hou T., Tegmark M. (2025).
21+
*KAN: Kolmogorov-Arnold Networks*.
22+
DOI: `arXiv preprint arXiv:2404.19756.
23+
<https://arxiv.org/abs/2404.19756>`_
1124
"""
1225

1326
def __init__(
@@ -119,16 +132,15 @@ def __init__(
119132

120133
def forward(self, x):
121134
"""
122-
Forward pass of the :class:`KANBlock`. It transforms the input using a
123-
vectorized spline basis and optionally adds a linear transformation of a
124-
base activation function.
125-
126-
The input is expected to have shape (batch_size, input_dimensions) and
127-
the output will have shape (batch_size, output_dimensions).
135+
Forward pass of the Kolmogorov-Arnold block. The input is passed through
136+
the spline transformation, optionally combined with a linear
137+
transformation of the base function output, and then aggregated across
138+
input dimensions to produce the final output.
128139
129-
:param torch.Tensor x: The input tensor for the model.
140+
:param x: The input tensor for the model.
141+
:type x: torch.Tensor | LabelTensor
130142
:return: The output tensor of the model.
131-
:rtype: torch.Tensor
143+
:rtype: torch.Tensor | LabelTensor
132144
"""
133145
y = self.spline(x)
134146

pina/_src/model/kolmogorov_arnold_network.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,20 @@
55

66
class KolmogorovArnoldNetwork(torch.nn.Module):
77
"""
8-
TODO: add docstring.
8+
Implementation of Kolmogorov-Arnold Network (KAN).
9+
10+
The model consists of a sequence of KAN blocks, where each block applies a
11+
spline transformation to the input, optionally combined with a linear
12+
transformation of a base activation function.
13+
14+
.. seealso::
15+
16+
**Original reference**:
17+
Liu Z., Wang Y., Vaidya S., Ruehle F., Halverson J., Soljacic M.,
18+
Hou T., Tegmark M. (2025).
19+
*KAN: Kolmogorov-Arnold Networks*.
20+
DOI: `arXiv preprint arXiv:2404.19756.
21+
<https://arxiv.org/abs/2404.19756>`_
922
"""
1023

1124
def __init__(
@@ -78,7 +91,13 @@ def __init__(
7891

7992
def forward(self, x):
8093
"""
81-
TODO: add docstring.
94+
Forward pass of the KolmogorovArnoldNetwork model. It passes the input
95+
through each KAN block in the network and returns the final output.
96+
97+
:param x: The input tensor for the model.
98+
:type x: torch.Tensor | LabelTensor
99+
:return: The output tensor of the model.
100+
:rtype: torch.Tensor | LabelTensor
82101
"""
83102
for layer in self.kan_layers:
84103
x = layer(x)

0 commit comments

Comments
 (0)