Skip to content

Commit ad8a27f

Browse files
add tests
1 parent d9b59ff commit ad8a27f

File tree

4 files changed

+473
-170
lines changed

4 files changed

+473
-170
lines changed

tests/test_block/test_kan_block.py

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
import torch
2+
import pytest
3+
from pina.model.block import KANBlock
4+
5+
# Data
6+
input_dim = 3
7+
data = torch.rand((10, input_dim))
8+
9+
10+
@pytest.mark.parametrize("output_dimensions", [1, 5])
11+
@pytest.mark.parametrize("spline_order", [3, 4])
12+
@pytest.mark.parametrize("n_knots", [10, 20])
13+
@pytest.mark.parametrize("init_scale_spline", [1e-2, 1e-1])
14+
@pytest.mark.parametrize("init_scale_base", [1.0, 0.1])
15+
def test_constructor(
16+
output_dimensions, spline_order, n_knots, init_scale_spline, init_scale_base
17+
):
18+
19+
KANBlock(
20+
input_dimensions=data.shape[1],
21+
output_dimensions=output_dimensions,
22+
spline_order=spline_order,
23+
n_knots=n_knots,
24+
init_scale_spline=init_scale_spline,
25+
init_scale_base=init_scale_base,
26+
)
27+
28+
# Should fail if input_dimensions is not a positive integer
29+
with pytest.raises(AssertionError):
30+
KANBlock(input_dimensions=-1, output_dimensions=output_dimensions)
31+
32+
# Should fail if output_dimensions is not a positive integer
33+
with pytest.raises(AssertionError):
34+
KANBlock(input_dimensions=data.shape[1], output_dimensions=-1)
35+
36+
# Should fail if spline_order is not a positive integer
37+
with pytest.raises(AssertionError):
38+
KANBlock(
39+
input_dimensions=data.shape[1],
40+
output_dimensions=output_dimensions,
41+
spline_order=-1,
42+
)
43+
44+
# Should fail if n_knots is not a positive integer
45+
with pytest.raises(AssertionError):
46+
KANBlock(
47+
input_dimensions=data.shape[1],
48+
output_dimensions=output_dimensions,
49+
n_knots=-1,
50+
)
51+
52+
# Should fail if grid_range is not of length 2
53+
with pytest.raises(ValueError):
54+
KANBlock(
55+
input_dimensions=data.shape[1],
56+
output_dimensions=output_dimensions,
57+
grid_range=[-1, 0, 1],
58+
)
59+
60+
# Should fail if base_function is not a torch.nn.Module subclass
61+
with pytest.raises(ValueError):
62+
KANBlock(
63+
input_dimensions=data.shape[1],
64+
output_dimensions=output_dimensions,
65+
base_function="not_a_module",
66+
)
67+
68+
# Should fail if use_base_linear is not a boolean
69+
with pytest.raises(ValueError):
70+
KANBlock(
71+
input_dimensions=data.shape[1],
72+
output_dimensions=output_dimensions,
73+
use_base_linear="not_a_bool",
74+
)
75+
76+
# Should fail if use_bias is not a boolean
77+
with pytest.raises(ValueError):
78+
KANBlock(
79+
input_dimensions=data.shape[1],
80+
output_dimensions=output_dimensions,
81+
use_bias="not_a_bool",
82+
)
83+
84+
# Should fail if init_scale_spline is not a float or int
85+
with pytest.raises(ValueError):
86+
KANBlock(
87+
input_dimensions=data.shape[1],
88+
output_dimensions=output_dimensions,
89+
init_scale_spline="not_a_number",
90+
)
91+
92+
# Should fail if init_scale_base is not a float or int
93+
with pytest.raises(ValueError):
94+
KANBlock(
95+
input_dimensions=data.shape[1],
96+
output_dimensions=output_dimensions,
97+
init_scale_base="not_a_number",
98+
)
99+
100+
101+
@pytest.mark.parametrize("output_dimensions", [1, 5])
102+
@pytest.mark.parametrize("spline_order", [3, 4])
103+
@pytest.mark.parametrize("n_knots", [10, 20])
104+
@pytest.mark.parametrize("init_scale_spline", [1e-2, 1e-1])
105+
@pytest.mark.parametrize("init_scale_base", [1.0, 0.1])
106+
def test_forward(
107+
output_dimensions, spline_order, n_knots, init_scale_spline, init_scale_base
108+
):
109+
110+
model = KANBlock(
111+
input_dimensions=data.shape[1],
112+
output_dimensions=output_dimensions,
113+
spline_order=spline_order,
114+
n_knots=n_knots,
115+
init_scale_spline=init_scale_spline,
116+
init_scale_base=init_scale_base,
117+
)
118+
119+
output_ = model(data)
120+
assert output_.shape == (data.shape[0], output_dimensions)
121+
122+
123+
@pytest.mark.parametrize("output_dimensions", [1, 5])
124+
@pytest.mark.parametrize("spline_order", [3, 4])
125+
@pytest.mark.parametrize("n_knots", [10, 20])
126+
@pytest.mark.parametrize("init_scale_spline", [1e-2, 1e-1])
127+
@pytest.mark.parametrize("init_scale_base", [1.0, 0.1])
128+
def test_backward(
129+
output_dimensions, spline_order, n_knots, init_scale_spline, init_scale_base
130+
):
131+
132+
model = KANBlock(
133+
input_dimensions=data.shape[1],
134+
output_dimensions=output_dimensions,
135+
spline_order=spline_order,
136+
n_knots=n_knots,
137+
init_scale_spline=init_scale_spline,
138+
init_scale_base=init_scale_base,
139+
)
140+
141+
data.requires_grad_()
142+
output_ = model(data)
143+
144+
loss = torch.mean(output_)
145+
loss.backward()
146+
assert data.grad.shape == data.shape
Lines changed: 67 additions & 137 deletions
Original file line numberDiff line numberDiff line change
@@ -1,153 +1,83 @@
11
import torch
22
import pytest
3-
43
from pina.model import KolmogorovArnoldNetwork
54

6-
data = torch.rand((20, 3))
7-
input_vars = 3
8-
output_vars = 1
5+
# Data
6+
input_dim = 3
7+
data = torch.rand((10, input_dim))
98

109

11-
def test_constructor():
12-
KolmogorovArnoldNetwork([input_vars, output_vars])
13-
KolmogorovArnoldNetwork([input_vars, 10, 20, output_vars])
14-
KolmogorovArnoldNetwork(
15-
[input_vars, 10, 20, output_vars],
16-
k=3,
17-
num=5
18-
)
19-
KolmogorovArnoldNetwork(
20-
[input_vars, 10, 20, output_vars],
21-
k=3,
22-
num=5,
23-
grid_eps=0.05,
24-
grid_range=[-2, 2]
25-
)
10+
@pytest.mark.parametrize("use_base_linear", [True, False])
11+
@pytest.mark.parametrize("use_bias", [True, False])
12+
@pytest.mark.parametrize("grid_range", [[-1, 1], [0, 2]])
13+
@pytest.mark.parametrize("layers", [[input_dim, 5, 1], [input_dim, 2]])
14+
def test_constructor(use_base_linear, use_bias, grid_range, layers):
15+
16+
# Constructor
2617
KolmogorovArnoldNetwork(
27-
[input_vars, 10, output_vars],
28-
base_function=torch.nn.Tanh(),
29-
scale_sp=0.5,
30-
sparse_init=True
18+
layers=layers,
19+
spline_order=3,
20+
n_knots=10,
21+
grid_range=grid_range,
22+
base_function=torch.nn.SiLU,
23+
use_base_linear=use_base_linear,
24+
use_bias=use_bias,
25+
init_scale_spline=1e-2,
26+
init_scale_base=1.0,
3127
)
3228

33-
34-
def test_constructor_wrong():
29+
# Should fail if grid_range is not of length 2
3530
with pytest.raises(ValueError):
36-
KolmogorovArnoldNetwork([input_vars])
37-
with pytest.raises(ValueError):
38-
KolmogorovArnoldNetwork([])
39-
40-
41-
def test_forward():
42-
dim_in, dim_out = 3, 2
43-
kan = KolmogorovArnoldNetwork([dim_in, dim_out])
44-
output_ = kan(data)
45-
assert output_.shape == (data.shape[0], dim_out)
31+
KolmogorovArnoldNetwork(layers=layers, grid_range=[-1, 0, 1])
4632

33+
# Should fail if layers has less than 2 elements
34+
with pytest.raises(ValueError):
35+
KolmogorovArnoldNetwork(layers=[input_dim])
36+
37+
38+
@pytest.mark.parametrize("use_base_linear", [True, False])
39+
@pytest.mark.parametrize("use_bias", [True, False])
40+
@pytest.mark.parametrize("grid_range", [[-1, 1], [0, 2]])
41+
@pytest.mark.parametrize("layers", [[input_dim, 5, 1], [input_dim, 2]])
42+
def test_forward(use_base_linear, use_bias, grid_range, layers):
43+
44+
model = KolmogorovArnoldNetwork(
45+
layers=layers,
46+
spline_order=3,
47+
n_knots=10,
48+
grid_range=grid_range,
49+
base_function=torch.nn.SiLU,
50+
use_base_linear=use_base_linear,
51+
use_bias=use_bias,
52+
init_scale_spline=1e-2,
53+
init_scale_base=1.0,
54+
)
4755

48-
def test_forward_multilayer():
49-
dim_in, dim_out = 3, 2
50-
kan = KolmogorovArnoldNetwork([dim_in, 10, 5, dim_out])
51-
output_ = kan(data)
52-
assert output_.shape == (data.shape[0], dim_out)
56+
output_ = model(data)
57+
assert output_.shape == (data.shape[0], layers[-1])
58+
59+
60+
@pytest.mark.parametrize("use_base_linear", [True, False])
61+
@pytest.mark.parametrize("use_bias", [True, False])
62+
@pytest.mark.parametrize("grid_range", [[-1, 1], [0, 2]])
63+
@pytest.mark.parametrize("layers", [[input_dim, 5, 1], [input_dim, 2]])
64+
def test_backward(use_base_linear, use_bias, grid_range, layers):
65+
66+
model = KolmogorovArnoldNetwork(
67+
layers=layers,
68+
spline_order=3,
69+
n_knots=10,
70+
grid_range=grid_range,
71+
base_function=torch.nn.SiLU,
72+
use_base_linear=use_base_linear,
73+
use_bias=use_bias,
74+
init_scale_spline=1e-2,
75+
init_scale_base=1.0,
76+
)
5377

78+
data.requires_grad_()
79+
output_ = model(data)
5480

55-
def test_backward():
56-
dim_in, dim_out = 3, 2
57-
kan = KolmogorovArnoldNetwork([dim_in, dim_out])
58-
data.requires_grad = True
59-
output_ = kan(data)
6081
loss = torch.mean(output_)
6182
loss.backward()
62-
assert data._grad.shape == torch.Size([20, 3])
63-
64-
65-
def test_get_num_parameters():
66-
kan = KolmogorovArnoldNetwork([3, 5, 2])
67-
num_params = kan.get_num_parameters()
68-
assert num_params > 0
69-
assert isinstance(num_params, int)
70-
71-
from pina.problem.zoo import Poisson2DSquareProblem
72-
from pina.solver import PINN
73-
from pina.trainer import Trainer
74-
75-
def test_train_poisson():
76-
problem = Poisson2DSquareProblem()
77-
problem.discretise_domain(n=10, mode="random", domains="all")
78-
79-
model = KolmogorovArnoldNetwork([2, 3, 1], k=3, num=5)
80-
solver = PINN(model=model, problem=problem)
81-
trainer = Trainer(
82-
solver=solver,
83-
max_epochs=10,
84-
accelerator="cpu",
85-
batch_size=100,
86-
train_size=1.0,
87-
val_size=0.0,
88-
test_size=0.0,
89-
)
90-
trainer.train()
91-
92-
93-
94-
# def test_update_grid_from_samples():
95-
# kan = KolmogorovArnoldNetwork([3, 5, 2])
96-
# samples = torch.randn(50, 3)
97-
# kan.update_grid_from_samples(samples, mode='sample')
98-
# # Check that the network still works after grid update
99-
# output = kan(data)
100-
# assert output.shape == (data.shape[0], 2)
101-
102-
103-
# def test_update_grid_resolution():
104-
# kan = KolmogorovArnoldNetwork([3, 5, 2], num=3)
105-
# kan.update_grid_resolution(5)
106-
# # Check that the network still works after resolution update
107-
# output = kan(data)
108-
# assert output.shape == (data.shape[0], 2)
109-
110-
111-
# def test_enable_sparsification():
112-
# kan = KolmogorovArnoldNetwork([3, 5, 2])
113-
# kan.enable_sparsification(threshold=1e-4)
114-
# # Check that the network still works after sparsification
115-
# output = kan(data)
116-
# assert output.shape == (data.shape[0], 2)
117-
118-
119-
# def test_get_activation_statistics():
120-
# kan = KolmogorovArnoldNetwork([3, 5, 2])
121-
# stats = kan.get_activation_statistics(data)
122-
# assert isinstance(stats, dict)
123-
# assert 'layer_0' in stats
124-
# assert 'layer_1' in stats
125-
# assert 'mean' in stats['layer_0']
126-
# assert 'std' in stats['layer_0']
127-
# assert 'min' in stats['layer_0']
128-
# assert 'max' in stats['layer_0']
129-
130-
131-
# def test_get_network_grid_statistics():
132-
# kan = KolmogorovArnoldNetwork([3, 5, 2])
133-
# stats = kan.get_network_grid_statistics()
134-
# assert isinstance(stats, dict)
135-
# assert 'layer_0' in stats
136-
# assert 'layer_1' in stats
137-
138-
139-
# def test_save_act():
140-
# kan = KolmogorovArnoldNetwork([3, 5, 2], save_act=True)
141-
# output = kan(data)
142-
# assert hasattr(kan, 'acts')
143-
# assert len(kan.acts) == 3 # input + 2 layers
144-
# assert kan.acts[0].shape == data.shape
145-
# assert kan.acts[-1].shape == output.shape
146-
147-
148-
# def test_save_act_disabled():
149-
# kan = KolmogorovArnoldNetwork([3, 5, 2], save_act=False)
150-
# _ = kan(data)
151-
# assert hasattr(kan, 'acts')
152-
# # Only the first activation (input) is saved
153-
# assert len(kan.acts) == 1
83+
assert data.grad.shape == data.shape

0 commit comments

Comments
 (0)