Skip to content

Commit 8d1bb83

Browse files
kshitij-mathsndem0
authored andcommitted
test: aggregation
1 parent 7a1a367 commit 8d1bb83

1 file changed

Lines changed: 279 additions & 0 deletions

File tree

tests/test_aggregation.py

Lines changed: 279 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,279 @@
1+
import copy
2+
import unittest
3+
import numpy as np
4+
from unittest import TestCase
5+
from ezyrb import Database, RBF
6+
from ezyrb.approximation.linear import Linear
7+
from ezyrb.reduction.pod import POD
8+
from ezyrb.reducedordermodel import ReducedOrderModel as ROM
9+
from ezyrb.reducedordermodel import MultiReducedOrderModel as MROM
10+
from ezyrb.plugin.aggregation import Aggregation
11+
from ezyrb.plugin.database_splitter import DatabaseSplitter
12+
13+
class MockROM:
14+
validation_full_database = None
15+
16+
def __init__(self, db):
17+
self.validation_full_database = db
18+
19+
def predict(self, db):
20+
return db
21+
22+
class MockMROM:
23+
train_full_database = None
24+
validation_full_database = None
25+
predict_full_database = None
26+
multi_predict_database = None
27+
weights_predict = None
28+
29+
def __init__(self, db, n_roms=2):
30+
self.roms = {f'rom{i}': MockROM(db) for i in range(n_roms)}
31+
self.train_full_database = db
32+
self.validation_full_database = db
33+
self.predict_full_database = db
34+
self.multi_predict_database = {f'rom{i}': db for i in range(n_roms)}
35+
self.weights_predict = {}
36+
37+
38+
def _make_unit_db():
39+
space = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
40+
params = np.array([[0.5], [1.5]])
41+
snaps = np.array([[10.0, 20.0, 30.0], [40.0, 50.0, 60.0]])
42+
return Database(params, snaps, space=space)
43+
44+
45+
def _make_integration_db(n_params=5, n_space=3):
46+
mu = np.linspace(0.5, 3.0, n_params)
47+
x = np.linspace(0, 2 * np.pi, n_space)
48+
snaps = np.array([np.sin(m * x) for m in mu])
49+
space = x.reshape(-1, 1)
50+
return Database(mu.reshape(-1, 1), snaps, space=space)
51+
52+
def _relative_error(predicted, actual):
53+
norms = np.linalg.norm(actual, axis=1)
54+
norms = np.where(norms < 1e-12, 1.0, norms)
55+
return np.mean(np.linalg.norm(predicted - actual, axis=1) / norms)
56+
57+
class TestAggregation(TestCase):
58+
59+
def setUp(self):
60+
self.db = _make_unit_db()
61+
62+
def test_constructor_default_fit_function_is_none(self):
63+
agg = Aggregation()
64+
self.assertIsNone(agg.fit_function)
65+
66+
def test_constructor_default_predict_function_is_linear(self):
67+
agg = Aggregation()
68+
self.assertIsInstance(agg.predict_function, Linear)
69+
70+
def test_constructor_custom_arguments(self):
71+
agg = Aggregation(fit_function=RBF(), predict_function=RBF())
72+
self.assertIsInstance(agg.fit_function, RBF)
73+
self.assertIsInstance(agg.predict_function, RBF)
74+
75+
76+
def test_check_sum_gaussians_partial_zeros(self):
77+
agg = Aggregation()
78+
mrom = MockMROM(self.db, n_roms=2)
79+
gaussians = np.array([[0.0, 0.8], [0.0, 0.2]])
80+
res = agg._check_sum_gaussians(mrom, gaussians.sum(axis=0), gaussians.copy())
81+
np.testing.assert_array_equal(res[:, 0], [0.5, 0.5])
82+
np.testing.assert_array_equal(res[:, 1], [0.8, 0.2])
83+
84+
def test_check_sum_gaussians_no_zeros_unchanged(self):
85+
agg = Aggregation()
86+
mrom = MockMROM(self.db, n_roms=2)
87+
gaussians = np.array([[0.3, 0.7], [0.6, 0.3]])
88+
original = gaussians.copy()
89+
res = agg._check_sum_gaussians(mrom, gaussians.sum(axis=0), gaussians.copy())
90+
np.testing.assert_array_equal(res, original)
91+
92+
def test_check_sum_gaussians_all_zeros(self):
93+
agg = Aggregation()
94+
mrom = MockMROM(self.db, n_roms=2)
95+
gaussians = np.zeros((2, 3))
96+
res = agg._check_sum_gaussians(mrom, gaussians.sum(axis=0), gaussians.copy())
97+
np.testing.assert_array_equal(res, np.full((2, 3), 0.5))
98+
99+
def test_check_sum_gaussians_equal_weight_matches_n_roms(self):
100+
n_roms = 4
101+
agg = Aggregation()
102+
mrom = MockMROM(self.db, n_roms=n_roms)
103+
gaussians = np.zeros((n_roms, 2))
104+
res = agg._check_sum_gaussians(mrom, gaussians.sum(axis=0), gaussians.copy())
105+
np.testing.assert_array_almost_equal(res, np.full((n_roms, 2), 1.0 / n_roms))
106+
107+
108+
def test_compute_validation_weights_perfect_prediction_values(self):
109+
mrom = MockMROM(self.db, n_roms=2)
110+
agg = Aggregation()
111+
g = agg._compute_validation_weights(mrom, sigma=1.0, normalized=False)
112+
np.testing.assert_array_almost_equal(g, np.ones_like(g))
113+
114+
def test_compute_validation_weights_normalized_sums_to_one(self):
115+
mrom = MockMROM(self.db, n_roms=2)
116+
agg = Aggregation()
117+
g = agg._compute_validation_weights(mrom, sigma=1.0, normalized=True)
118+
np.testing.assert_array_almost_equal(g.sum(axis=0), np.ones_like(g[0]))
119+
120+
def test_compute_validation_weights_shape(self):
121+
mrom = MockMROM(self.db, n_roms=3)
122+
agg = Aggregation()
123+
g = agg._compute_validation_weights(mrom, sigma=1.0)
124+
self.assertEqual(g.shape[0], 3)
125+
126+
def test_compute_validation_weights_sigma_effect(self):
127+
mrom = MockMROM(self.db, n_roms=2)
128+
agg = Aggregation()
129+
g_large = agg._compute_validation_weights(mrom, sigma=1e6, normalized=False)
130+
g_small = agg._compute_validation_weights(mrom, sigma=1e-6, normalized=False)
131+
np.testing.assert_array_almost_equal(g_large, np.ones_like(g_large))
132+
np.testing.assert_array_almost_equal(g_small, np.ones_like(g_small))
133+
134+
135+
def test_optimize_sigma_returns_finite_value(self):
136+
mrom = MockMROM(self.db, n_roms=2)
137+
agg = Aggregation()
138+
sigma = agg._optimize_sigma(mrom)
139+
self.assertTrue(np.isfinite(sigma).all())
140+
141+
def test_optimize_sigma_within_default_range(self):
142+
mrom = MockMROM(self.db, n_roms=2)
143+
agg = Aggregation()
144+
sigma = agg._optimize_sigma(mrom)
145+
self.assertGreaterEqual(float(sigma), 1e-5)
146+
self.assertLessEqual(float(sigma), 1e-2)
147+
148+
def test_aggregation_no_fit_function(self):
149+
mrom = MockMROM(self.db, n_roms=2)
150+
agg = Aggregation(fit_function=None, predict_function=RBF())
151+
agg.fit_postprocessing(mrom)
152+
agg.predict_postprocessing(mrom)
153+
self.assertIsNotNone(mrom.predict_full_database)
154+
self.assertEqual(len(agg.predict_functions), 2)
155+
156+
def test_aggregation_with_fit_function(self):
157+
mrom = MockMROM(self.db, n_roms=1)
158+
agg = Aggregation(fit_function=RBF(), predict_function=RBF())
159+
agg.fit_postprocessing(mrom)
160+
agg.predict_postprocessing(mrom)
161+
self.assertIsNotNone(mrom.predict_full_database)
162+
163+
def test_nan_handling_in_weights(self):
164+
mrom = MockMROM(self.db, n_roms=2)
165+
agg = Aggregation(fit_function=None, predict_function=RBF())
166+
agg._compute_validation_weights = (
167+
lambda mrom, sigma, normalized=False: np.full((2, 2, 3), np.nan)
168+
)
169+
agg._optimize_sigma = lambda mrom: 1e-3
170+
agg.fit_postprocessing(mrom)
171+
self.assertEqual(len(agg.predict_functions), 2)
172+
173+
174+
class TestAggregationIntegration(TestCase):
175+
176+
@classmethod
177+
def setUpClass(cls):
178+
cls.db = _make_integration_db(n_params=5, n_space=3)
179+
180+
def _make_splitter(self, seed=0):
181+
return DatabaseSplitter(
182+
train=2, test=0, validation=2, predict=1, seed=seed
183+
)
184+
185+
def _build_and_fit_mrom(self, agg, seed=0):
186+
splitter = self._make_splitter(seed=seed)
187+
rom1 = ROM(self.db, POD(rank=1), RBF())
188+
rom2 = ROM(self.db, POD(rank=1), Linear())
189+
agg._optimize_sigma = lambda mrom: 1e-3
190+
mrom = MROM(
191+
{'rbf': rom1, 'lin': rom2},
192+
plugins=[splitter, agg],
193+
rom_plugin=splitter,
194+
)
195+
mrom.fit()
196+
return mrom
197+
198+
def test_fit_does_not_raise(self):
199+
agg = Aggregation(fit_function=None, predict_function=RBF())
200+
self._build_and_fit_mrom(agg)
201+
202+
def test_fit_regression_path_does_not_raise(self):
203+
splitter = self._make_splitter()
204+
rom1 = ROM(self.db, POD(rank=1), RBF())
205+
agg = Aggregation(fit_function=RBF(), predict_function=RBF())
206+
mrom = MROM({'rbf': rom1}, plugins=[splitter, agg], rom_plugin=splitter)
207+
mrom.fit()
208+
209+
def test_predict_returns_database_instance(self):
210+
agg = Aggregation(fit_function=None, predict_function=RBF())
211+
mrom = self._build_and_fit_mrom(agg)
212+
mrom.predict(mrom.predict_full_database)
213+
self.assertIsInstance(mrom.predict_full_database, Database)
214+
215+
def test_predict_snapshot_shape(self):
216+
agg = Aggregation(fit_function=None, predict_function=RBF())
217+
mrom = self._build_and_fit_mrom(agg)
218+
mrom.predict(mrom.predict_full_database)
219+
self.assertEqual(mrom.predict_full_database.snapshots_matrix.shape[1], 3)
220+
221+
def test_predict_functions_count_matches_n_roms(self):
222+
agg = Aggregation(fit_function=None, predict_function=RBF())
223+
self._build_and_fit_mrom(agg)
224+
self.assertEqual(len(agg.predict_functions), 2)
225+
226+
def test_weights_are_finite(self):
227+
agg = Aggregation(fit_function=None, predict_function=RBF())
228+
mrom = self._build_and_fit_mrom(agg)
229+
mrom.predict(mrom.predict_full_database)
230+
for key, w in mrom.weights_predict.items():
231+
self.assertTrue(np.isfinite(w).all(),
232+
msg=f"Non-finite weight for ROM '{key}'")
233+
234+
def test_weights_sum_to_one(self):
235+
agg = Aggregation(fit_function=None, predict_function=RBF())
236+
mrom = self._build_and_fit_mrom(agg)
237+
mrom.predict(mrom.predict_full_database)
238+
weight_sum = np.sum(list(mrom.weights_predict.values()), axis=0)
239+
np.testing.assert_array_almost_equal(
240+
weight_sum, np.ones_like(weight_sum), decimal=5
241+
)
242+
243+
def test_fit_reproducible_with_same_seed(self):
244+
agg1 = Aggregation(fit_function=None, predict_function=RBF())
245+
agg2 = Aggregation(fit_function=None, predict_function=RBF())
246+
mrom1 = self._build_and_fit_mrom(agg1, seed=7)
247+
mrom2 = self._build_and_fit_mrom(agg2, seed=7)
248+
249+
pred_db1 = copy.deepcopy(mrom1.predict_full_database)
250+
pred_db2 = copy.deepcopy(mrom2.predict_full_database)
251+
mrom1.predict(pred_db1)
252+
mrom2.predict(pred_db2)
253+
254+
np.testing.assert_array_almost_equal(
255+
mrom1.predict_full_database.snapshots_matrix,
256+
mrom2.predict_full_database.snapshots_matrix,
257+
decimal=10,
258+
)
259+
260+
def test_fit_different_seeds_produce_different_predictions(self):
261+
agg1 = Aggregation(fit_function=None, predict_function=RBF())
262+
agg2 = Aggregation(fit_function=None, predict_function=RBF())
263+
mrom1 = self._build_and_fit_mrom(agg1, seed=0)
264+
mrom2 = self._build_and_fit_mrom(agg2, seed=99)
265+
266+
pred_db1 = copy.deepcopy(mrom1.predict_full_database)
267+
pred_db2 = copy.deepcopy(mrom2.predict_full_database)
268+
mrom1.predict(pred_db1)
269+
mrom2.predict(pred_db2)
270+
271+
with self.assertRaises(AssertionError):
272+
np.testing.assert_array_almost_equal(
273+
mrom1.predict_full_database.snapshots_matrix,
274+
mrom2.predict_full_database.snapshots_matrix,
275+
decimal=10,
276+
)
277+
278+
if __name__ == '__main__':
279+
unittest.main()

0 commit comments

Comments
 (0)