|
1 | 1 | import torch |
2 | 2 | import pytest |
3 | | - |
4 | 3 | from pina.model import KolmogorovArnoldNetwork |
5 | 4 |
|
# Shared fixture: a small batch of random inputs used by every test below.
input_dim = 3
data = torch.rand(10, input_dim)
9 | 8 |
|
10 | 9 |
|
@pytest.mark.parametrize("use_base_linear", [True, False])
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.parametrize("grid_range", [[-1, 1], [0, 2]])
@pytest.mark.parametrize("layers", [[input_dim, 5, 1], [input_dim, 2]])
def test_constructor(use_base_linear, use_bias, grid_range, layers):
    """Valid configurations construct; malformed ones raise ValueError."""
    kwargs = dict(
        layers=layers,
        spline_order=3,
        n_knots=10,
        grid_range=grid_range,
        base_function=torch.nn.SiLU,
        use_base_linear=use_base_linear,
        use_bias=use_bias,
        init_scale_spline=1e-2,
        init_scale_base=1.0,
    )
    KolmogorovArnoldNetwork(**kwargs)

    # A grid_range that is not a [low, high] pair must be rejected.
    with pytest.raises(ValueError):
        KolmogorovArnoldNetwork(layers=layers, grid_range=[-1, 0, 1])

    # Fewer than two layer sizes cannot define a network.
    with pytest.raises(ValueError):
        KolmogorovArnoldNetwork(layers=[input_dim])
| 37 | + |
@pytest.mark.parametrize("use_base_linear", [True, False])
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.parametrize("grid_range", [[-1, 1], [0, 2]])
@pytest.mark.parametrize("layers", [[input_dim, 5, 1], [input_dim, 2]])
def test_forward(use_base_linear, use_bias, grid_range, layers):
    """A forward pass yields a (batch, out_dim) tensor."""
    net = KolmogorovArnoldNetwork(
        layers=layers,
        spline_order=3,
        n_knots=10,
        grid_range=grid_range,
        base_function=torch.nn.SiLU,
        use_base_linear=use_base_linear,
        use_bias=use_bias,
        init_scale_spline=1e-2,
        init_scale_base=1.0,
    )

    out = net(data)
    # Batch size is preserved; last layer width fixes the output dimension.
    assert out.shape == (data.shape[0], layers[-1])
| 58 | + |
| 59 | + |
@pytest.mark.parametrize("use_base_linear", [True, False])
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.parametrize("grid_range", [[-1, 1], [0, 2]])
@pytest.mark.parametrize("layers", [[input_dim, 5, 1], [input_dim, 2]])
def test_backward(use_base_linear, use_bias, grid_range, layers):
    """Gradients flow back to the input through the network."""
    model = KolmogorovArnoldNetwork(
        layers=layers,
        spline_order=3,
        n_knots=10,
        grid_range=grid_range,
        base_function=torch.nn.SiLU,
        use_base_linear=use_base_linear,
        use_bias=use_bias,
        init_scale_spline=1e-2,
        init_scale_base=1.0,
    )

    # Work on a detached local copy: calling requires_grad_() on the shared
    # module-level `data` tensor would mutate it in place, so every later
    # forward pass (in this parametrized test and in others) would build
    # autograd graphs and accumulate into data.grad across runs, making the
    # tests order-dependent.
    x = data.detach().clone().requires_grad_(True)
    output_ = model(x)

    loss = torch.mean(output_)
    loss.backward()
    # The input must have received a gradient of matching shape.
    assert x.grad is not None
    assert x.grad.shape == x.shape
0 commit comments