1+ """Create the infrastructure for a KAN layer"""
2+ import torch
3+ import numpy as np
4+
5+ from pina .model .spline import Spline
6+
7+
class KAN_layer(torch.nn.Module):
    """A single Kolmogorov-Arnold Network (KAN) layer built on B-splines.

    Each (input, output) connection carries a learnable univariate function
    ``scale_base * base(x) + scale_spline * spline(x)``; the layer output is
    the sum of these per-connection values over the input dimensions.
    """

    def __init__(self, k: int, input_dimensions: int, output_dimensions: int,
                 inner_nodes: int, num=3, grid_eps=0.1, grid_range=(-1, 1),
                 grid_extension=True, noise_scale=0.1, base_function=None,
                 scale_base_mu=0.0, scale_base_sigma=1.0, scale_sp=1.0,
                 sparse_init=True, sp_trainable=True, sb_trainable=True) -> None:
        """
        Initialize the KAN layer.

        :param k: spline degree (the underlying spline order is ``k + 1``).
        :param input_dimensions: number of input features.
        :param output_dimensions: number of output features.
        :param inner_nodes: stored on the layer; not used inside this class.
        :param num: number of grid intervals of the initial knot grid.
        :param grid_eps: blend factor used during grid updates
            (1 -> fully uniform grid, 0 -> fully sample-adaptive grid).
        :param grid_range: ``(min, max)`` of the initial knot grid.
        :param grid_extension: if True, pad the grid with ``k`` extra knots
            on each side so the B-spline basis is complete at the boundaries.
        :param noise_scale: scale of the random control-point initialization.
        :param base_function: residual activation applied alongside the
            spline; defaults to ``torch.nn.SiLU()`` when ``None``.
        :param scale_base_mu: mean of the base-branch scale initialization.
        :param scale_base_sigma: spread of the base-branch scale init.
        :param scale_sp: initial magnitude of the spline-branch scales.
        :param sparse_init: if True, start from a sparse connection mask
            (see :meth:`sparse_mask`); otherwise fully connected.
        :param sp_trainable: whether the spline scales are trainable.
        :param sb_trainable: whether the base scales are trainable.
        """
        super().__init__()
        self.k = k
        self.input_dimensions = input_dimensions
        self.output_dimensions = output_dimensions
        self.inner_nodes = inner_nodes
        self.num = num
        self.grid_eps = grid_eps
        self.grid_range = grid_range
        self.grid_extension = grid_extension

        # Fixed (non-trainable) 0/1 connection mask between inputs/outputs.
        if sparse_init:
            mask = self.sparse_mask(input_dimensions, output_dimensions)
        else:
            mask = torch.ones(input_dimensions, output_dimensions)
        self.mask = torch.nn.Parameter(mask, requires_grad=False)

        # One uniform knot vector per input dimension over grid_range.
        grid = torch.linspace(grid_range[0], grid_range[1], steps=self.num + 1)[None, :].expand(self.input_dimensions, self.num + 1)

        if grid_extension:
            # Pad k knots on each side, spaced by the average interval h,
            # so boundary intervals have a full set of basis functions.
            h = (grid[:, [-1]] - grid[:, [0]]) / (grid.shape[1] - 1)
            for _ in range(self.k):
                grid = torch.cat([grid[:, [0]] - h, grid], dim=1)
                grid = torch.cat([grid, grid[:, [-1]] + h], dim=1)

        # Number of B-spline coefficients implied by knots and order k+1.
        n_coef = grid.shape[1] - (self.k + 1)

        control_points = torch.nn.Parameter(
            torch.randn(self.input_dimensions, self.output_dimensions, n_coef) * noise_scale
        )

        self.spline = Spline(order=self.k + 1, knots=grid, control_points=control_points, grid_extension=grid_extension)

        # Per-connection scales for the base (residual) and spline branches,
        # initialized with 1/sqrt(fan_in) scaling as in PyKAN.
        self.scale_base = torch.nn.Parameter(
            scale_base_mu * 1 / np.sqrt(input_dimensions)
            + scale_base_sigma * (torch.rand(input_dimensions, output_dimensions) * 2 - 1) * 1 / np.sqrt(input_dimensions),
            requires_grad=sb_trainable)
        # Masked so that pruned connections start (and stay) at zero scale.
        self.scale_spline = torch.nn.Parameter(
            torch.ones(input_dimensions, output_dimensions) * scale_sp * 1 / np.sqrt(input_dimensions) * self.mask,
            requires_grad=sp_trainable)
        # SiLU() is stateless, but a fresh instance per layer avoids the
        # shared-mutable-default pitfall of `base_function=torch.nn.SiLU()`.
        self.base_function = torch.nn.SiLU() if base_function is None else base_function

    @staticmethod
    def sparse_mask(in_dimensions: int, out_dimensions: int) -> torch.Tensor:
        """
        Build a sparse 0/1 connection mask of shape
        ``(in_dimensions, out_dimensions)``.

        Inputs and outputs are placed uniformly on [0, 1]; every input is
        connected to its nearest output and every output to its nearest
        input, so each row and each column has at least one connection.
        """
        in_coord = torch.arange(in_dimensions) * 1 / in_dimensions + 1 / (2 * in_dimensions)
        out_coord = torch.arange(out_dimensions) * 1 / out_dimensions + 1 / (2 * out_dimensions)

        # dist_mat[o, i] = distance between output o and input i.
        dist_mat = torch.abs(out_coord[:, None] - in_coord[None, :])
        in_nearest = torch.argmin(dist_mat, dim=0)
        in_connection = torch.stack([torch.arange(in_dimensions), in_nearest]).permute(1, 0)
        out_nearest = torch.argmin(dist_mat, dim=1)
        out_connection = torch.stack([out_nearest, torch.arange(out_dimensions)]).permute(1, 0)
        all_connection = torch.cat([in_connection, out_connection], dim=0)
        mask = torch.zeros(in_dimensions, out_dimensions)
        mask[all_connection[:, 0], all_connection[:, 1]] = 1.
        return mask

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the KAN layer.

        Each input goes through ``w_base*base(x) + w_spline*spline(x)``,
        then contributions are summed across input dimensions per output.

        :param x: input of shape (batch, input_dimensions); may also be a
            LabelTensor-like object exposing a ``tensor`` attribute.
        :return: tensor of shape (batch, output_dimensions).
        """
        # Unwrap LabelTensor-like inputs to a plain tensor.
        if hasattr(x, 'tensor'):
            x_tensor = x.tensor
        else:
            x_tensor = x

        base = self.base_function(x_tensor)  # (batch, input_dimensions)

        # Spline value for every (input, output) pair: (batch, in, out).
        basis = self.spline.basis(x_tensor, self.spline.k, self.spline.knots)
        spline_out_per_input = torch.einsum("bil,iol->bio", basis, self.spline.control_points)

        base_term = self.scale_base[None, :, :] * base[:, :, None]
        spline_term = self.scale_spline[None, :, :] * spline_out_per_input
        combined = base_term + spline_term
        # Zero out pruned connections.
        combined = self.mask[None, :, :] * combined

        output = torch.sum(combined, dim=1)  # (batch, output_dimensions)

        return output

    def update_grid_from_samples(self, x: torch.Tensor, mode: str = 'sample'):
        """
        Update the knot grid from input samples to better fit the data
        distribution, then refit the control points so the spline keeps
        (approximately) its current function values.

        Based on the PyKAN implementation, with boundary preservation: the
        new grid never shrinks inside the original ``grid_range`` and the
        refit includes the original boundary points.

        :param x: samples of shape (batch, input_dimensions); may be a
            LabelTensor-like object exposing a ``tensor`` attribute.
        :param mode: ``'sample'`` refits on the (augmented) samples;
            ``'grid'`` refits on a 2x-denser blended grid instead.
        """
        # Convert LabelTensor to regular tensor for spline operations.
        if hasattr(x, 'tensor'):
            x_tensor = x.tensor
        else:
            x_tensor = x

        with torch.no_grad():
            batch_size = x_tensor.shape[0]
            x_sorted = torch.sort(x_tensor, dim=0)[0]  # (batch_size, input_dimensions)

            # Current number of intervals (excluding the k-knot extensions).
            if self.grid_extension:
                num_interval = self.spline.knots.shape[1] - 1 - 2 * self.k
            else:
                num_interval = self.spline.knots.shape[1] - 1

            def get_grid(num_intervals: int):
                """PyKAN-style grid creation with boundary preservation."""
                # Quantile-like knot placement: one knot every
                # batch_size/num_intervals sorted samples, plus the max.
                ids = [int(batch_size * i / num_intervals) for i in range(num_intervals)] + [-1]
                grid_adaptive = x_sorted[ids, :].transpose(0, 1)  # (input_dimensions, num_intervals+1)

                original_min = self.grid_range[0]
                original_max = self.grid_range[1]

                # Ensure the adaptive grid still covers the original domain.
                grid_adaptive[:, 0] = torch.min(grid_adaptive[:, 0],
                                                torch.full_like(grid_adaptive[:, 0], original_min))
                grid_adaptive[:, -1] = torch.max(grid_adaptive[:, -1],
                                                 torch.full_like(grid_adaptive[:, -1], original_max))

                margin = 0.0
                h = (grid_adaptive[:, [-1]] - grid_adaptive[:, [0]] + 2 * margin) / num_intervals
                grid_uniform = (grid_adaptive[:, [0]] - margin +
                                h * torch.arange(num_intervals + 1, device=x_tensor.device, dtype=x_tensor.dtype)[None, :])

                # Blend uniform and adaptive grids via grid_eps.
                grid_blended = (self.grid_eps * grid_uniform +
                                (1 - self.grid_eps) * grid_adaptive)

                return grid_blended

            # Augment the samples with the original boundary points so the
            # refit preserves boundary behavior.
            boundary_points = torch.tensor([[self.grid_range[0]], [self.grid_range[1]]],
                                           device=x_tensor.device, dtype=x_tensor.dtype).expand(-1, self.input_dimensions)

            x_augmented = torch.cat([x_sorted, boundary_points], dim=0)
            x_augmented = torch.sort(x_augmented, dim=0)[0]  # re-sort with boundaries included

            # Evaluate the CURRENT spline at the augmented points; these
            # values are the refit targets after the grid changes.
            basis = self.spline.basis(x_augmented, self.spline.k, self.spline.knots)
            y_eval = torch.einsum("bil,iol->bio", basis, self.spline.control_points)

            new_grid = get_grid(num_interval)

            if mode == 'grid':
                # 'grid' mode: refit on a 2x-denser blended grid instead of
                # the raw samples.
                sample_grid = get_grid(2 * num_interval)
                x_augmented = sample_grid.transpose(0, 1)  # (batch_size, input_dimensions)
                basis = self.spline.basis(x_augmented, self.spline.k, self.spline.knots)
                y_eval = torch.einsum("bil,iol->bio", basis, self.spline.control_points)

            # Re-apply the k-knot extensions to the new grid if needed.
            if self.grid_extension:
                h = (new_grid[:, [-1]] - new_grid[:, [0]]) / (new_grid.shape[1] - 1)
                for _ in range(self.k):
                    new_grid = torch.cat([new_grid[:, [0]] - h, new_grid], dim=1)
                    new_grid = torch.cat([new_grid, new_grid[:, [-1]] + h], dim=1)

            # Install the new grid, then refit the coefficients.
            self.spline.knots = new_grid

            try:
                # Refit coefficients using augmented points (preserves boundaries).
                self.spline.compute_control_points(x_augmented, y_eval)
            except Exception as e:
                # Best-effort: keep the new grid even if the refit fails.
                print(f"Warning: Failed to update coefficients during grid refinement: {e}")

    def update_grid_resolution(self, new_num: int):
        """
        Change the grid resolution to ``new_num`` intervals, refitting the
        control points so the spline keeps its current function values.

        :param new_num: new number of grid intervals.
        """
        with torch.no_grad():
            # Sample the current spline on a dense grid (2x the new
            # resolution) to produce refit targets.
            x_eval = torch.linspace(
                self.grid_range[0],
                self.grid_range[1],
                steps=2 * new_num,
                device=self.spline.knots.device
            )
            x_eval = x_eval.unsqueeze(1).expand(-1, self.input_dimensions)

            basis = self.spline.basis(x_eval, self.spline.k, self.spline.knots)
            y_eval = torch.einsum("bil,iol->bio", basis, self.spline.control_points)

            # Update num and rebuild a uniform grid at the new resolution.
            self.num = new_num
            new_grid = torch.linspace(
                self.grid_range[0],
                self.grid_range[1],
                steps=self.num + 1,
                device=self.spline.knots.device
            )
            new_grid = new_grid[None, :].expand(self.input_dimensions, self.num + 1)

            if self.grid_extension:
                # Same k-knot boundary extension as in __init__.
                h = (new_grid[:, [-1]] - new_grid[:, [0]]) / (new_grid.shape[1] - 1)
                for _ in range(self.k):
                    new_grid = torch.cat([new_grid[:, [0]] - h, new_grid], dim=1)
                    new_grid = torch.cat([new_grid, new_grid[:, [-1]] + h], dim=1)

            # Install the new grid and re-compute the control points.
            self.spline.knots = new_grid
            self.spline.compute_control_points(x_eval, y_eval)

    def get_grid_statistics(self):
        """Return a dict of grid statistics for debugging/analysis."""
        return {
            'grid_shape': self.spline.knots.shape,
            'grid_min': self.spline.knots.min().item(),
            'grid_max': self.spline.knots.max().item(),
            'grid_range': (self.spline.knots.max() - self.spline.knots.min()).mean().item(),
            'num_intervals': self.spline.knots.shape[1] - 1 - (2 * self.k if self.spline.grid_extension else 0)
        }