Skip to content

Commit f72ca99

Browse files
committed
KAN with non-vectorized spline
1 parent 8307d12 commit f72ca99

File tree

6 files changed

+322
-46
lines changed

6 files changed

+322
-46
lines changed

pina/_src/model/spline.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def __init__(self, order=4, knots=None, control_points=None):
117117
raise ValueError("knots must be one-dimensional.")
118118

119119
# Check dimensionality of control points
120-
if self.control_points.ndim > 1:
120+
if self.control_points.ndim > 2:
121121
raise ValueError("control_points must be one-dimensional.")
122122

123123
# Raise error if #knots != order + #control_points
@@ -277,9 +277,10 @@ def forward(self, x):
277277
:return: The output tensor.
278278
:rtype: torch.Tensor
279279
"""
280+
basis = self.basis(x.as_subclass(torch.Tensor))
280281
return torch.einsum(
281-
"...bi, i -> ...b",
282-
self.basis(x.as_subclass(torch.Tensor)).squeeze(-1),
282+
"...bi, ...i -> ...b",
283+
basis,
283284
self.control_points,
284285
)
285286

pina/condition/tensor_condition.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
"""Module for the DataCondition class."""
2+
3+
import torch
4+
from torch_geometric.data import Data
5+
from .condition_interface import ConditionInterface
6+
from ..label_tensor import LabelTensor
7+
from ..graph import Graph
8+
9+
10+
class _TensorCondition(ConditionInterface):
11+
12+
__slots__ = ["input", "conditional_variables"]
13+
_avail_input_cls = (torch.Tensor, LabelTensor, Data, Graph, list, tuple)
14+
_avail_conditional_variables_cls = (torch.Tensor, LabelTensor)
15+
16+
def __new__(cls, input, conditional_variables=None):
17+
"""
18+
Instantiate the appropriate subclass of :class:`DataCondition` based on
19+
the type of ``input``.
20+
21+
:param input: Input data for the condition.
22+
:type input: torch.Tensor | LabelTensor | Graph |
23+
Data | list[Graph] | list[Data] | tuple[Graph] | tuple[Data]
24+
:param conditional_variables: Conditional variables for the condition.
25+
:type conditional_variables: torch.Tensor | LabelTensor, optional
26+
:return: Subclass of DataCondition.
27+
:rtype: pina.condition.data_condition.TensorDataCondition |
28+
pina.condition.data_condition.GraphDataCondition
29+
30+
:raises ValueError: If input is not of type :class:`torch.Tensor`,
31+
:class:`~pina.label_tensor.LabelTensor`, :class:`~pina.graph.Graph`,
32+
or :class:`~torch_geometric.data.Data`.
33+
"""
34+
35+
if cls != DataCondition:
36+
return super().__new__(cls)
37+
if isinstance(input, (torch.Tensor, LabelTensor)):
38+
subclass = TensorDataCondition
39+
return subclass.__new__(subclass, input, conditional_variables)
40+
41+
if isinstance(input, (Graph, Data, list, tuple)):
42+
cls._check_graph_list_consistency(input)
43+
subclass = GraphDataCondition
44+
return subclass.__new__(subclass, input, conditional_variables)
45+
46+
raise ValueError(
47+
"Invalid input types. "
48+
"Please provide either torch_geometric.data.Data or Graph objects."
49+
)
50+
51+
def __init__(self, input, conditional_variables=None):
52+
"""
53+
Initialize the object by storing the input and conditional
54+
variables (if any).
55+
56+
:param input: Input data for the condition.
57+
:type input: torch.Tensor | LabelTensor | Graph | Data | list[Graph] |
58+
list[Data] | tuple[Graph] | tuple[Data]
59+
:param conditional_variables: Conditional variables for the condition.
60+
:type conditional_variables: torch.Tensor | LabelTensor
61+
62+
.. note::
63+
If ``input`` consists of a list of :class:`~pina.graph.Graph` or
64+
:class:`~torch_geometric.data.Data`, all elements must have the same
65+
structure (keys and data types)
66+
"""
67+
68+
super().__init__()
69+
self.input = input
70+
self.conditional_variables = conditional_variables
71+
72+
73+
class TensorDataCondition(DataCondition):
74+
"""
75+
DataCondition for :class:`torch.Tensor` or
76+
:class:`~pina.label_tensor.LabelTensor` input data
77+
"""
78+
79+
80+
class GraphDataCondition(DataCondition):
81+
"""
82+
DataCondition for :class:`~pina.graph.Graph` or
83+
:class:`~torch_geometric.data.Data` input data
84+
"""

pina/model/block/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
"RBFBlock",
2626
"GNOBlock",
2727
"PirateNetBlock",
28+
"KANBlock",
2829
]
2930

3031
from pina._src.model.block.convolution_2d import ContinuousConvBlock

pina/model/kolmogorov_arnold_network/kan_layer.py renamed to pina/model/block/kan_block.py

Lines changed: 44 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,13 @@
55
from pina.model.spline import Spline
66

77

8-
class KAN_layer(torch.nn.Module):
8+
class KANBlock(torch.nn.Module):
99
"""define a KAN layer using splines"""
1010
def __init__(self, k: int, input_dimensions: int, output_dimensions: int, inner_nodes: int, num=3, grid_eps=0.1, grid_range=[-1, 1], grid_extension=True, noise_scale=0.1, base_function=torch.nn.SiLU(), scale_base_mu=0.0, scale_base_sigma=1.0, scale_sp=1.0, sparse_init=True, sp_trainable=True, sb_trainable=True) -> None:
1111
"""
1212
Initialize the KAN layer.
13+
14+
num è il numero di intervalli nella griglia iniziale (esclusi gli eventuali nodi di estensione)
1315
"""
1416
super().__init__()
1517
self.k = k
@@ -27,24 +29,46 @@ def __init__(self, k: int, input_dimensions: int, output_dimensions: int, inner_
2729
self.mask = torch.nn.Parameter(torch.ones(input_dimensions, output_dimensions)).requires_grad_(False)
2830

2931
grid = torch.linspace(grid_range[0], grid_range[1], steps=self.num + 1)[None,:].expand(self.input_dimensions, self.num+1)
32+
knots = torch.linspace(grid_range[0], grid_range[1], steps=self.num + 1)
3033

3134
if grid_extension:
3235
h = (grid[:, [-1]] - grid[:, [0]]) / (grid.shape[1] - 1)
3336
for i in range(self.k):
3437
grid = torch.cat([grid[:, [0]] - h, grid], dim=1)
3538
grid = torch.cat([grid, grid[:, [-1]] + h], dim=1)
3639

37-
n_coef = grid.shape[1] - (self.k + 1)
40+
n_control_points = len(knots) - (self.k )
3841

39-
control_points = torch.nn.Parameter(
40-
torch.randn(self.input_dimensions, self.output_dimensions, n_coef) * noise_scale
41-
)
42+
# control_points = torch.nn.Parameter(
43+
# torch.randn(self.input_dimensions, self.output_dimensions, n_control_points) * noise_scale
44+
# )
45+
# print(control_points.shape)
46+
spline_q = []
47+
for q in range(self.output_dimensions):
48+
spline_p = []
49+
for p in range(self.input_dimensions):
50+
spline_ = Spline(
51+
order=self.k,
52+
knots=knots,
53+
control_points=torch.randn(n_control_points)
54+
)
55+
spline_p.append(spline_)
56+
spline_p = torch.nn.ModuleList(spline_p)
57+
spline_q.append(spline_p)
58+
self.spline_q = torch.nn.ModuleList(spline_q)
59+
60+
61+
# control_points = torch.nn.Parameter(
62+
# torch.randn(n_control_points, self.output_dimensions) * noise_scale)
63+
# print(control_points)
64+
# print('uuu')
4265

43-
self.spline = Spline(order=self.k+1, knots=grid, control_points=control_points, grid_extension=grid_extension)
66+
# self.spline = Spline(
67+
# order=self.k, knots=knots, control_points=control_points)
4468

45-
self.scale_base = torch.nn.Parameter(scale_base_mu * 1 / np.sqrt(input_dimensions) + \
46-
scale_base_sigma * (torch.rand(input_dimensions, output_dimensions)*2-1) * 1/np.sqrt(input_dimensions), requires_grad=sb_trainable)
47-
self.scale_spline = torch.nn.Parameter(torch.ones(input_dimensions, output_dimensions) * scale_sp * 1 / np.sqrt(input_dimensions) * self.mask, requires_grad=sp_trainable)
69+
# self.scale_base = torch.nn.Parameter(scale_base_mu * 1 / np.sqrt(input_dimensions) + \
70+
# scale_base_sigma * (torch.rand(input_dimensions, output_dimensions)*2-1) * 1/np.sqrt(input_dimensions), requires_grad=sb_trainable)
71+
# self.scale_spline = torch.nn.Parameter(torch.ones(input_dimensions, output_dimensions) * scale_sp * 1 / np.sqrt(input_dimensions) * self.mask, requires_grad=sp_trainable)
4872
self.base_function = base_function
4973

5074
@staticmethod
@@ -76,19 +100,17 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
76100
else:
77101
x_tensor = x
78102

79-
base = self.base_function(x_tensor) # (batch, input_dimensions)
80-
81-
basis = self.spline.basis(x_tensor, self.spline.k, self.spline.knots)
82-
spline_out_per_input = torch.einsum("bil,iol->bio", basis, self.spline.control_points)
83-
84-
base_term = self.scale_base[None, :, :] * base[:, :, None]
85-
spline_term = self.scale_spline[None, :, :] * spline_out_per_input
86-
combined = base_term + spline_term
87-
combined = self.mask[None,:,:] * combined
88-
89-
output = torch.sum(combined, dim=1) # (batch, output_dimensions)
90-
91-
return output
103+
y = []
104+
for q in range(self.output_dimensions):
105+
y_q = []
106+
for p in range(self.input_dimensions):
107+
spline_out = self.spline_q[q][p].forward(x_tensor[:, p]) # (batch, input_dimensions, output_dimensions)
108+
base_out = self.base_function(x_tensor[:, p]) # (batch, input_dimensions)
109+
y_q.append(spline_out + base_out)
110+
y.append(torch.stack(y_q, dim=1).sum(dim=1))
111+
y = torch.stack(y, dim=1)
112+
113+
return y
92114

93115
def update_grid_from_samples(self, x: torch.Tensor, mode: str = 'sample'):
94116
"""

pina/model/kolmogorov_arnold_network/kan_network.py renamed to pina/model/kolmogorov_arnold_network.py

Lines changed: 33 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,20 @@
33
import torch.nn as nn
44
from typing import List
55

6-
try:
7-
from .kan_layer import KAN_layer
8-
except ImportError:
9-
from kan_layer import KAN_layer
6+
from pina.model.block import KANBlock
107

11-
class KAN_Network(torch.nn.Module):
8+
class KolmogorovArnoldNetwork(torch.nn.Module):
129
"""
13-
Kolmogorov Arnold Network - A neural network using KAN layers instead of traditional MLP layers.
14-
Each layer uses learnable univariate functions (B-splines + base functions) on edges.
10+
Kolmogorov Arnold Network, a neural network using KAN layers instead of
11+
traditional MLP layers. Each layer uses learnable univariate functions
12+
(B-splines + base functions) on edges.
13+
14+
.. references::
15+
16+
Liu, Z., Wang, Y., Vaidya, S., Ruehle, F., Halverson, J., Soljačić, M.,
17+
... & Tegmark, M. (2024). Kan: Kolmogorov-arnold networks. arXiv
18+
preprint arXiv:2404.19756.
19+
1520
"""
1621

1722
def __init__(
@@ -35,19 +40,25 @@ def __init__(
3540
):
3641
"""
3742
Initialize the KAN network.
38-
39-
Args:
40-
layer_sizes: List of integers defining the size of each layer [input_dim, hidden1, hidden2, ..., output_dim]
41-
k: Order of the B-spline
42-
num: Number of grid points for B-splines
43-
grid_eps: Epsilon for grid spacing
44-
grid_range: Range for the grid [min, max]
45-
grid_extension: Whether to extend the grid
46-
noise_scale: Scale for initialization noise
47-
base_function: Base activation function (e.g., SiLU)
48-
scale_base_mu: Mean for base function scaling
49-
scale_base_sigma: Std for base function scaling
50-
scale_sp: Scale for spline functions
43+
44+
:param iterable layer_sizes: List of layer sizes including input and
45+
output dimensions.
46+
:param int k: Order of the B-spline.
47+
:param int num: Number of grid points for B-splines.
48+
:param float grid_eps: Epsilon for grid spacing.
49+
:param list grid_range: Range for the grid [min, max].
50+
:param bool grid_extension: Whether to extend the grid.
51+
:param float noise_scale: Scale for initialization noise.
52+
:param base_function: Base activation function (e.g., SiLU).
53+
:param float scale_base_mu: Mean for base function scaling.
54+
:param float scale_base_sigma: Std for base function scaling.
55+
:param float scale_sp: Scale for spline functions.
56+
:param int inner_nodes: Number of inner nodes for KAN layers.
57+
:param bool sparse_init: Whether to use sparse initialization.
58+
:param bool sp_trainable: Whether spline parameters are trainable.
59+
:param bool sb_trainable: Whether base function parameters are
60+
trainable.
61+
:param bool save_act: Whether to save activations after each layer.
5162
"""
5263
super().__init__()
5364

@@ -62,7 +73,7 @@ def __init__(
6273
self.kan_layers = nn.ModuleList()
6374

6475
for i in range(self.num_layers):
65-
layer = KAN_layer(
76+
layer = KANBlock(
6677
k=k,
6778
input_dimensions=layer_sizes[i],
6879
output_dimensions=layer_sizes[i+1],
@@ -97,6 +108,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
97108

98109
for i, layer in enumerate(self.kan_layers):
99110
current = layer(current)
111+
current = torch.nn.functional.sigmoid(current)
100112

101113
if self.save_act:
102114
self.acts.append(current.detach())

0 commit comments

Comments
 (0)