55from pina .model .spline import Spline
66
77
8- class KAN_layer (torch .nn .Module ):
8+ class KANBlock (torch .nn .Module ):
99 """define a KAN layer using splines"""
1010 def __init__ (self , k : int , input_dimensions : int , output_dimensions : int , inner_nodes : int , num = 3 , grid_eps = 0.1 , grid_range = [- 1 , 1 ], grid_extension = True , noise_scale = 0.1 , base_function = torch .nn .SiLU (), scale_base_mu = 0.0 , scale_base_sigma = 1.0 , scale_sp = 1.0 , sparse_init = True , sp_trainable = True , sb_trainable = True ) -> None :
1111 """
1212 Initialize the KAN layer.
13+
14+ num è il numero di intervalli nella griglia iniziale (esclusi gli eventuali nodi di estensione)
1315 """
1416 super ().__init__ ()
1517 self .k = k
@@ -27,24 +29,46 @@ def __init__(self, k: int, input_dimensions: int, output_dimensions: int, inner_
2729 self .mask = torch .nn .Parameter (torch .ones (input_dimensions , output_dimensions )).requires_grad_ (False )
2830
2931 grid = torch .linspace (grid_range [0 ], grid_range [1 ], steps = self .num + 1 )[None ,:].expand (self .input_dimensions , self .num + 1 )
32+ knots = torch .linspace (grid_range [0 ], grid_range [1 ], steps = self .num + 1 )
3033
3134 if grid_extension :
3235 h = (grid [:, [- 1 ]] - grid [:, [0 ]]) / (grid .shape [1 ] - 1 )
3336 for i in range (self .k ):
3437 grid = torch .cat ([grid [:, [0 ]] - h , grid ], dim = 1 )
3538 grid = torch .cat ([grid , grid [:, [- 1 ]] + h ], dim = 1 )
3639
37- n_coef = grid . shape [ 1 ] - (self .k + 1 )
40+ n_control_points = len ( knots ) - (self .k )
3841
39- control_points = torch .nn .Parameter (
40- torch .randn (self .input_dimensions , self .output_dimensions , n_coef ) * noise_scale
41- )
42+ # control_points = torch.nn.Parameter(
43+ # torch.randn(self.input_dimensions, self.output_dimensions, n_control_points) * noise_scale
44+ # )
45+ # print(control_points.shape)
46+ spline_q = []
47+ for q in range (self .output_dimensions ):
48+ spline_p = []
49+ for p in range (self .input_dimensions ):
50+ spline_ = Spline (
51+ order = self .k ,
52+ knots = knots ,
53+ control_points = torch .randn (n_control_points )
54+ )
55+ spline_p .append (spline_ )
56+ spline_p = torch .nn .ModuleList (spline_p )
57+ spline_q .append (spline_p )
58+ self .spline_q = torch .nn .ModuleList (spline_q )
59+
60+
61+ # control_points = torch.nn.Parameter(
62+ # torch.randn(n_control_points, self.output_dimensions) * noise_scale)
63+ # print(control_points)
64+ # print('uuu')
4265
43- self .spline = Spline (order = self .k + 1 , knots = grid , control_points = control_points , grid_extension = grid_extension )
66+ # self.spline = Spline(
67+ # order=self.k, knots=knots, control_points=control_points)
4468
45- self .scale_base = torch .nn .Parameter (scale_base_mu * 1 / np .sqrt (input_dimensions ) + \
46- scale_base_sigma * (torch .rand (input_dimensions , output_dimensions )* 2 - 1 ) * 1 / np .sqrt (input_dimensions ), requires_grad = sb_trainable )
47- self .scale_spline = torch .nn .Parameter (torch .ones (input_dimensions , output_dimensions ) * scale_sp * 1 / np .sqrt (input_dimensions ) * self .mask , requires_grad = sp_trainable )
69+ # self.scale_base = torch.nn.Parameter(scale_base_mu * 1 / np.sqrt(input_dimensions) + \
70+ # scale_base_sigma * (torch.rand(input_dimensions, output_dimensions)*2-1) * 1/np.sqrt(input_dimensions), requires_grad=sb_trainable)
71+ # self.scale_spline = torch.nn.Parameter(torch.ones(input_dimensions, output_dimensions) * scale_sp * 1 / np.sqrt(input_dimensions) * self.mask, requires_grad=sp_trainable)
4872 self .base_function = base_function
4973
5074 @staticmethod
@@ -76,19 +100,17 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
76100 else :
77101 x_tensor = x
78102
79- base = self .base_function (x_tensor ) # (batch, input_dimensions)
80-
81- basis = self .spline .basis (x_tensor , self .spline .k , self .spline .knots )
82- spline_out_per_input = torch .einsum ("bil,iol->bio" , basis , self .spline .control_points )
83-
84- base_term = self .scale_base [None , :, :] * base [:, :, None ]
85- spline_term = self .scale_spline [None , :, :] * spline_out_per_input
86- combined = base_term + spline_term
87- combined = self .mask [None ,:,:] * combined
88-
89- output = torch .sum (combined , dim = 1 ) # (batch, output_dimensions)
90-
91- return output
103+ y = []
104+ for q in range (self .output_dimensions ):
105+ y_q = []
106+ for p in range (self .input_dimensions ):
107+ spline_out = self .spline_q [q ][p ].forward (x_tensor [:, p ]) # (batch, input_dimensions, output_dimensions)
108+ base_out = self .base_function (x_tensor [:, p ]) # (batch, input_dimensions)
109+ y_q .append (spline_out + base_out )
110+ y .append (torch .stack (y_q , dim = 1 ).sum (dim = 1 ))
111+ y = torch .stack (y , dim = 1 )
112+
113+ return y
92114
93115 def update_grid_from_samples (self , x : torch .Tensor , mode : str = 'sample' ):
94116 """
0 commit comments