Skip to content

Commit 6831634

Browse files
KAN implementation (#611)
* Improve spline * Add KAN --------- Co-authored-by: Filippo Olivo <folivo@filippoolivo.com>
1 parent f12173e commit 6831634

File tree

3 files changed

+418
-1
lines changed

3 files changed

+418
-1
lines changed
Lines changed: 223 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,223 @@
1+
"""Create the infrastructure for a KAN layer"""
2+
import torch
3+
import numpy as np
4+
5+
from pina.model.spline import Spline
6+
7+
8+
class KAN_layer(torch.nn.Module):
    """A Kolmogorov-Arnold Network (KAN) layer built on B-splines.

    Every (input, output) edge carries a learnable univariate activation
    ``scale_base * base(x) + scale_spline * spline(x)``; the layer output is
    the sum of the masked edge activations over the input dimensions.
    """

    def __init__(self, k: int, input_dimensions: int, output_dimensions: int,
                 inner_nodes: int, num: int = 3, grid_eps: float = 0.1,
                 grid_range=(-1, 1), grid_extension: bool = True,
                 noise_scale: float = 0.1,
                 base_function=torch.nn.SiLU(),
                 scale_base_mu: float = 0.0, scale_base_sigma: float = 1.0,
                 scale_sp: float = 1.0, sparse_init: bool = True,
                 sp_trainable: bool = True, sb_trainable: bool = True) -> None:
        """Initialize the KAN layer.

        :param k: spline degree; the underlying ``Spline`` uses ``order=k+1``.
        :param input_dimensions: number of input features.
        :param output_dimensions: number of output features.
        :param inner_nodes: stored on the instance; not used by this layer's
            own computations.
        :param num: number of intervals of the (un-extended) knot grid.
        :param grid_eps: blend factor used in :meth:`update_grid_from_samples`
            between the uniform grid (``grid_eps=1``) and the
            sample-adaptive grid (``grid_eps=0``).
        :param grid_range: ``(min, max)`` of the spline domain. The default is
            a tuple (immutable) to avoid the shared-mutable-default pitfall.
        :param grid_extension: if True, extend the grid by ``k`` knots on each
            side so the B-spline basis is well defined at the boundaries.
        :param noise_scale: scale of the random initial control points.
        :param base_function: residual base activation ``base(x)``.
        :param scale_base_mu: mean of the per-edge base-branch scale init.
        :param scale_base_sigma: spread of the per-edge base-branch scale init.
        :param scale_sp: initial per-edge spline-branch scale.
        :param sparse_init: if True, start from a sparse connectivity mask.
        :param sp_trainable: whether the spline scales are trainable.
        :param sb_trainable: whether the base scales are trainable.
        """
        super().__init__()
        self.k = k
        self.input_dimensions = input_dimensions
        self.output_dimensions = output_dimensions
        self.inner_nodes = inner_nodes
        self.num = num
        self.grid_eps = grid_eps
        # Store as a tuple so later code cannot mutate a caller-supplied
        # (or shared default) sequence in place.
        self.grid_range = tuple(grid_range)
        self.grid_extension = grid_extension

        # Fixed (non-trainable) 0/1 connectivity mask between inputs/outputs.
        if sparse_init:
            mask = self.sparse_mask(input_dimensions, output_dimensions)
        else:
            mask = torch.ones(input_dimensions, output_dimensions)
        self.mask = torch.nn.Parameter(mask, requires_grad=False)

        # Uniform knot grid, one row per input dimension.
        grid = torch.linspace(self.grid_range[0], self.grid_range[1],
                              steps=self.num + 1)
        grid = grid[None, :].expand(self.input_dimensions, self.num + 1)
        if grid_extension:
            grid = self._extend_grid(grid)

        # One control point per basis function on each (input, output) edge.
        n_coef = grid.shape[1] - (self.k + 1)
        control_points = torch.nn.Parameter(
            torch.randn(self.input_dimensions, self.output_dimensions, n_coef)
            * noise_scale
        )

        self.spline = Spline(order=self.k + 1, knots=grid,
                             control_points=control_points,
                             grid_extension=grid_extension)

        # Per-edge scales for the base and spline branches. The 1/sqrt(n_in)
        # factor keeps the summed output variance roughly independent of the
        # number of inputs.
        self.scale_base = torch.nn.Parameter(
            scale_base_mu * 1 / np.sqrt(input_dimensions)
            + scale_base_sigma
            * (torch.rand(input_dimensions, output_dimensions) * 2 - 1)
            * 1 / np.sqrt(input_dimensions),
            requires_grad=sb_trainable)
        self.scale_spline = torch.nn.Parameter(
            torch.ones(input_dimensions, output_dimensions)
            * scale_sp * 1 / np.sqrt(input_dimensions) * self.mask,
            requires_grad=sp_trainable)
        self.base_function = base_function

    @staticmethod
    def _as_tensor(x):
        """Unwrap a pina ``LabelTensor`` (anything exposing ``.tensor``) into
        a plain ``torch.Tensor``; pass plain tensors through unchanged."""
        return x.tensor if hasattr(x, 'tensor') else x

    def _extend_grid(self, grid: torch.Tensor) -> torch.Tensor:
        """Append ``self.k`` equally spaced knots on both sides of ``grid``.

        :param grid: knot grid of shape (input_dimensions, n_knots).
        :return: extended grid of shape (input_dimensions, n_knots + 2*k).
        """
        # Spacing is the mean spacing of the incoming grid, computed once.
        h = (grid[:, [-1]] - grid[:, [0]]) / (grid.shape[1] - 1)
        for _ in range(self.k):
            grid = torch.cat([grid[:, [0]] - h, grid], dim=1)
            grid = torch.cat([grid, grid[:, [-1]] + h], dim=1)
        return grid

    def _spline_per_edge(self, x: torch.Tensor) -> torch.Tensor:
        """Evaluate the spline of every (input, output) edge at ``x``.

        :param x: tensor of shape (batch, input_dimensions).
        :return: tensor of shape (batch, input_dimensions, output_dimensions).
        """
        basis = self.spline.basis(x, self.spline.k, self.spline.knots)
        return torch.einsum("bil,iol->bio", basis, self.spline.control_points)

    @staticmethod
    def sparse_mask(in_dimensions: int, out_dimensions: int) -> torch.Tensor:
        """Build a sparse 0/1 connectivity mask.

        Inputs and outputs are placed uniformly on [0, 1]; each input is
        connected to its nearest output and each output to its nearest input,
        so every node keeps at least one connection.

        :param in_dimensions: number of input nodes.
        :param out_dimensions: number of output nodes.
        :return: float mask of shape (in_dimensions, out_dimensions).
        """
        in_coord = (torch.arange(in_dimensions) * 1 / in_dimensions
                    + 1 / (2 * in_dimensions))
        out_coord = (torch.arange(out_dimensions) * 1 / out_dimensions
                     + 1 / (2 * out_dimensions))

        # Pairwise distances: rows index outputs, columns index inputs.
        dist_mat = torch.abs(out_coord[:, None] - in_coord[None, :])
        in_nearest = torch.argmin(dist_mat, dim=0)
        in_connection = torch.stack(
            [torch.arange(in_dimensions), in_nearest]).permute(1, 0)
        out_nearest = torch.argmin(dist_mat, dim=1)
        out_connection = torch.stack(
            [out_nearest, torch.arange(out_dimensions)]).permute(1, 0)
        all_connection = torch.cat([in_connection, out_connection], dim=0)

        mask = torch.zeros(in_dimensions, out_dimensions)
        mask[all_connection[:, 0], all_connection[:, 1]] = 1.
        return mask

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass through the KAN layer.

        Each input goes through ``w_base*base(x) + w_spline*spline(x)``; the
        masked edge activations are then summed over the input dimension.

        :param x: tensor (or pina ``LabelTensor``) of shape
            (batch, input_dimensions).
        :return: tensor of shape (batch, output_dimensions).
        """
        x_tensor = self._as_tensor(x)

        base = self.base_function(x_tensor)           # (batch, in)
        spline_out = self._spline_per_edge(x_tensor)  # (batch, in, out)

        combined = (self.scale_base[None, :, :] * base[:, :, None]
                    + self.scale_spline[None, :, :] * spline_out)
        combined = self.mask[None, :, :] * combined

        return torch.sum(combined, dim=1)             # (batch, out)

    def update_grid_from_samples(self, x: torch.Tensor, mode: str = 'sample'):
        """Update the knot grid from input samples to better fit the data
        distribution (PyKAN-style, with boundary preservation).

        :param x: samples of shape (batch, input_dimensions); may be a pina
            ``LabelTensor``.
        :param mode: ``'sample'`` refits against the samples themselves;
            ``'grid'`` refits against a denser synthetic grid.
        """
        x_tensor = self._as_tensor(x)

        with torch.no_grad():
            batch_size = x_tensor.shape[0]
            x_sorted = torch.sort(x_tensor, dim=0)[0]  # (batch, in)

            # Current number of intervals, excluding the k-knot extensions.
            if self.grid_extension:
                num_interval = self.spline.knots.shape[1] - 1 - 2 * self.k
            else:
                num_interval = self.spline.knots.shape[1] - 1

            def get_grid(num_intervals: int) -> torch.Tensor:
                """Blend of a quantile-adaptive and a uniform grid."""
                ids = [int(batch_size * i / num_intervals)
                       for i in range(num_intervals)] + [-1]
                # (input_dimensions, num_intervals + 1); advanced indexing
                # copies, so the in-place clamps below are safe.
                grid_adaptive = x_sorted[ids, :].transpose(0, 1)

                # Never let the adaptive grid shrink inside the original
                # domain: keep left endpoint <= min, right endpoint >= max.
                original_min = self.grid_range[0]
                original_max = self.grid_range[1]
                grid_adaptive[:, 0] = torch.min(
                    grid_adaptive[:, 0],
                    torch.full_like(grid_adaptive[:, 0], original_min))
                grid_adaptive[:, -1] = torch.max(
                    grid_adaptive[:, -1],
                    torch.full_like(grid_adaptive[:, -1], original_max))

                margin = 0.0
                h = (grid_adaptive[:, [-1]] - grid_adaptive[:, [0]]
                     + 2 * margin) / num_intervals
                grid_uniform = (
                    grid_adaptive[:, [0]] - margin
                    + h * torch.arange(num_intervals + 1,
                                       device=x_tensor.device,
                                       dtype=x_tensor.dtype)[None, :])

                return (self.grid_eps * grid_uniform
                        + (1 - self.grid_eps) * grid_adaptive)

            # Evaluate the current spline at the samples plus the domain
            # boundaries so boundary behaviour survives the refit.
            boundary_points = torch.tensor(
                [[self.grid_range[0]], [self.grid_range[1]]],
                device=x_tensor.device,
                dtype=x_tensor.dtype).expand(-1, self.input_dimensions)
            x_augmented = torch.cat([x_sorted, boundary_points], dim=0)
            x_augmented = torch.sort(x_augmented, dim=0)[0]
            y_eval = self._spline_per_edge(x_augmented)

            new_grid = get_grid(num_interval)

            if mode == 'grid':
                # 'grid' mode refits against a denser synthetic grid instead
                # of the raw samples.
                sample_grid = get_grid(2 * num_interval)
                x_augmented = sample_grid.transpose(0, 1)  # (batch, in)
                y_eval = self._spline_per_edge(x_augmented)

            if self.grid_extension:
                new_grid = self._extend_grid(new_grid)

            # Swap in the new grid, then refit the control points so the new
            # spline reproduces the old one at the evaluation points.
            self.spline.knots = new_grid
            try:
                self.spline.compute_control_points(x_augmented, y_eval)
            except Exception as e:
                print(f"Warning: Failed to update coefficients during grid refinement: {e}")

    def update_grid_resolution(self, new_num: int):
        """Resample the knot grid to ``new_num`` intervals, refitting the
        control points so the represented function is preserved.

        :param new_num: new number of (un-extended) grid intervals.
        """
        with torch.no_grad():
            # Sample the current spline on a dense uniform grid.
            x_eval = torch.linspace(
                self.grid_range[0],
                self.grid_range[1],
                steps=2 * new_num,
                device=self.spline.knots.device
            )
            x_eval = x_eval.unsqueeze(1).expand(-1, self.input_dimensions)
            y_eval = self._spline_per_edge(x_eval)

            # Build the new uniform grid at the requested resolution.
            self.num = new_num
            new_grid = torch.linspace(
                self.grid_range[0],
                self.grid_range[1],
                steps=self.num + 1,
                device=self.spline.knots.device
            )
            new_grid = new_grid[None, :].expand(self.input_dimensions,
                                                self.num + 1)
            if self.grid_extension:
                new_grid = self._extend_grid(new_grid)

            # Update the spline and re-fit its control points.
            self.spline.knots = new_grid
            self.spline.compute_control_points(x_eval, y_eval)

    def get_grid_statistics(self):
        """Return a dict of grid statistics for debugging/analysis."""
        return {
            'grid_shape': self.spline.knots.shape,
            'grid_min': self.spline.knots.min().item(),
            'grid_max': self.spline.knots.max().item(),
            'grid_range': (self.spline.knots.max()
                           - self.spline.knots.min()).mean().item(),
            'num_intervals': self.spline.knots.shape[1] - 1
                             - (2 * self.k if self.spline.grid_extension else 0)
        }

0 commit comments

Comments
 (0)