Skip to content

Commit 368dbbc

Browse files
committed
minor fix
1 parent 249fcb5 commit 368dbbc

3 files changed

Lines changed: 13 additions & 11 deletions

File tree

pina/_src/model/block/kan_block.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,16 @@
33
import numpy as np
44

55
from pina._src.model.spline import Spline
6+
from pina._src.model.vectorized_spline import VectorizedSpline
67

78

89
class KANBlock(torch.nn.Module):
910
"""define a KAN layer using splines"""
10-
def __init__(self, k, input_dimensions, output_dimensions, inner_nodes, num=3, grid_eps=0.1, grid_range=[-1, 1], grid_extension=True, noise_scale=0.1, base_function=torch.nn.SiLU(), scale_base_mu=0.0, scale_base_sigma=1.0, scale_sp=1.0, sparse_init=True, sp_trainable=True, sb_trainable=True):
11+
def __init__(self, k, input_dimensions, output_dimensions, inner_nodes,
12+
num=3, grid_eps=0.1, grid_range=[-1, 1], grid_extension=True,
13+
noise_scale=0.1, base_function=torch.nn.SiLU(), scale_base_mu=0.0,
14+
scale_base_sigma=1.0, scale_sp=1.0, sparse_init=True, sp_trainable=True,
15+
sb_trainable=True):
1116
"""
1217
Initialize the KAN layer.
1318
@@ -46,7 +51,6 @@ def __init__(self, k, input_dimensions, output_dimensions, inner_nodes, num=3, g
4651
# )
4752
# print(control_points.shape)
4853
if self.vec:
49-
from pina.model.spline import SplineVectorized as VectorizedSpline
5054
control_points = torch.randn(self.input_dimensions * self.output_dimensions, n_control_points)
5155
print('control points', control_points.shape)
5256
control_points = torch.stack([

tests/test_model/test_kolmogorov_arnold_network.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import torch
22
import pytest
33

4-
from pina.model.block import KANBlock
54
from pina.model import KolmogorovArnoldNetwork
65

76
data = torch.rand((20, 3))
@@ -81,16 +80,14 @@ def test_train_poisson():
8180
solver = PINN(model=model, problem=problem)
8281
trainer = Trainer(
8382
solver=solver,
84-
max_epochs=1000,
83+
max_epochs=10,
8584
accelerator="cpu",
8685
batch_size=100,
8786
train_size=1.0,
8887
val_size=0.0,
8988
test_size=0.0,
9089
)
9190
trainer.train()
92-
assert False
93-
9491

9592

9693

tests/test_model/test_spline.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -193,22 +193,22 @@ def test_derivative(args, pts):
193193
assert torch.allclose(first_der, first_der_auto, atol=1e-4, rtol=1e-4)
194194

195195

196-
@pytest.mark.parametrize("out_dim", [1, 3, 5])
197-
def test_vectorized(out_dim):
196+
#@pytest.mark.parametrize("args", valid_args) # TODO
197+
def test_vectorized():
198198

199199
N = 7
200200
cps = []
201201
splines = []
202202
for i in range(N):
203-
cp = torch.rand(n_ctrl_pts, 3)
203+
cp = torch.rand(n_ctrl_pts)
204204
cps.append(cp)
205205
spline = Spline(
206206
order=order,
207207
control_points=cp
208208
)
209209
splines.append(spline)
210210

211-
from pina.model.spline import SplineVectorized as VectorizedSpline
211+
from pina.model import VectorizedSpline
212212
unique_cps = torch.stack(cps, dim=0)
213213
print(unique_cps.shape)
214214
print(cps[0].shape)
@@ -224,7 +224,8 @@ def test_vectorized(out_dim):
224224
result_single = torch.stack([
225225
splines[i](x) for i in range(N)
226226
])
227-
result_single = result_single.squeeze(2).permute(1, 0, 2)
227+
print(result_single.shape)
228+
result_single = result_single.permute(1, 2, 0)
228229
out_vectorized = vectorized_spline(x)
229230
print(out_vectorized.shape)
230231
print(result_single.shape)

0 commit comments

Comments
 (0)