1+ import copy
2+ import unittest
3+ import numpy as np
4+ from unittest import TestCase
5+ from ezyrb import Database , RBF
6+ from ezyrb .approximation .linear import Linear
7+ from ezyrb .reduction .pod import POD
8+ from ezyrb .reducedordermodel import ReducedOrderModel as ROM
9+ from ezyrb .reducedordermodel import MultiReducedOrderModel as MROM
10+ from ezyrb .plugin .aggregation import Aggregation
11+ from ezyrb .plugin .database_splitter import DatabaseSplitter
12+
13+ class MockROM :
14+ validation_full_database = None
15+
16+ def __init__ (self , db ):
17+ self .validation_full_database = db
18+
19+ def predict (self , db ):
20+ return db
21+
22+ class MockMROM :
23+ train_full_database = None
24+ validation_full_database = None
25+ predict_full_database = None
26+ multi_predict_database = None
27+ weights_predict = None
28+
29+ def __init__ (self , db , n_roms = 2 ):
30+ self .roms = {f'rom{ i } ' : MockROM (db ) for i in range (n_roms )}
31+ self .train_full_database = db
32+ self .validation_full_database = db
33+ self .predict_full_database = db
34+ self .multi_predict_database = {f'rom{ i } ' : db for i in range (n_roms )}
35+ self .weights_predict = {}
36+
37+
38+ def _make_unit_db ():
39+ space = np .array ([[0.0 , 0.0 ], [1.0 , 0.0 ], [0.0 , 1.0 ]])
40+ params = np .array ([[0.5 ], [1.5 ]])
41+ snaps = np .array ([[10.0 , 20.0 , 30.0 ], [40.0 , 50.0 , 60.0 ]])
42+ return Database (params , snaps , space = space )
43+
44+
45+ def _make_integration_db (n_params = 5 , n_space = 3 ):
46+ mu = np .linspace (0.5 , 3.0 , n_params )
47+ x = np .linspace (0 , 2 * np .pi , n_space )
48+ snaps = np .array ([np .sin (m * x ) for m in mu ])
49+ space = x .reshape (- 1 , 1 )
50+ return Database (mu .reshape (- 1 , 1 ), snaps , space = space )
51+
52+ def _relative_error (predicted , actual ):
53+ norms = np .linalg .norm (actual , axis = 1 )
54+ norms = np .where (norms < 1e-12 , 1.0 , norms )
55+ return np .mean (np .linalg .norm (predicted - actual , axis = 1 ) / norms )
56+
57+ class TestAggregation (TestCase ):
58+
59+ def setUp (self ):
60+ self .db = _make_unit_db ()
61+
62+ def test_constructor_default_fit_function_is_none (self ):
63+ agg = Aggregation ()
64+ self .assertIsNone (agg .fit_function )
65+
66+ def test_constructor_default_predict_function_is_linear (self ):
67+ agg = Aggregation ()
68+ self .assertIsInstance (agg .predict_function , Linear )
69+
70+ def test_constructor_custom_arguments (self ):
71+ agg = Aggregation (fit_function = RBF (), predict_function = RBF ())
72+ self .assertIsInstance (agg .fit_function , RBF )
73+ self .assertIsInstance (agg .predict_function , RBF )
74+
75+
76+ def test_check_sum_gaussians_partial_zeros (self ):
77+ agg = Aggregation ()
78+ mrom = MockMROM (self .db , n_roms = 2 )
79+ gaussians = np .array ([[0.0 , 0.8 ], [0.0 , 0.2 ]])
80+ res = agg ._check_sum_gaussians (mrom , gaussians .sum (axis = 0 ), gaussians .copy ())
81+ np .testing .assert_array_equal (res [:, 0 ], [0.5 , 0.5 ])
82+ np .testing .assert_array_equal (res [:, 1 ], [0.8 , 0.2 ])
83+
84+ def test_check_sum_gaussians_no_zeros_unchanged (self ):
85+ agg = Aggregation ()
86+ mrom = MockMROM (self .db , n_roms = 2 )
87+ gaussians = np .array ([[0.3 , 0.7 ], [0.6 , 0.3 ]])
88+ original = gaussians .copy ()
89+ res = agg ._check_sum_gaussians (mrom , gaussians .sum (axis = 0 ), gaussians .copy ())
90+ np .testing .assert_array_equal (res , original )
91+
92+ def test_check_sum_gaussians_all_zeros (self ):
93+ agg = Aggregation ()
94+ mrom = MockMROM (self .db , n_roms = 2 )
95+ gaussians = np .zeros ((2 , 3 ))
96+ res = agg ._check_sum_gaussians (mrom , gaussians .sum (axis = 0 ), gaussians .copy ())
97+ np .testing .assert_array_equal (res , np .full ((2 , 3 ), 0.5 ))
98+
99+ def test_check_sum_gaussians_equal_weight_matches_n_roms (self ):
100+ n_roms = 4
101+ agg = Aggregation ()
102+ mrom = MockMROM (self .db , n_roms = n_roms )
103+ gaussians = np .zeros ((n_roms , 2 ))
104+ res = agg ._check_sum_gaussians (mrom , gaussians .sum (axis = 0 ), gaussians .copy ())
105+ np .testing .assert_array_almost_equal (res , np .full ((n_roms , 2 ), 1.0 / n_roms ))
106+
107+
108+ def test_compute_validation_weights_perfect_prediction_values (self ):
109+ mrom = MockMROM (self .db , n_roms = 2 )
110+ agg = Aggregation ()
111+ g = agg ._compute_validation_weights (mrom , sigma = 1.0 , normalized = False )
112+ np .testing .assert_array_almost_equal (g , np .ones_like (g ))
113+
114+ def test_compute_validation_weights_normalized_sums_to_one (self ):
115+ mrom = MockMROM (self .db , n_roms = 2 )
116+ agg = Aggregation ()
117+ g = agg ._compute_validation_weights (mrom , sigma = 1.0 , normalized = True )
118+ np .testing .assert_array_almost_equal (g .sum (axis = 0 ), np .ones_like (g [0 ]))
119+
120+ def test_compute_validation_weights_shape (self ):
121+ mrom = MockMROM (self .db , n_roms = 3 )
122+ agg = Aggregation ()
123+ g = agg ._compute_validation_weights (mrom , sigma = 1.0 )
124+ self .assertEqual (g .shape [0 ], 3 )
125+
126+ def test_compute_validation_weights_sigma_effect (self ):
127+ mrom = MockMROM (self .db , n_roms = 2 )
128+ agg = Aggregation ()
129+ g_large = agg ._compute_validation_weights (mrom , sigma = 1e6 , normalized = False )
130+ g_small = agg ._compute_validation_weights (mrom , sigma = 1e-6 , normalized = False )
131+ np .testing .assert_array_almost_equal (g_large , np .ones_like (g_large ))
132+ np .testing .assert_array_almost_equal (g_small , np .ones_like (g_small ))
133+
134+
135+ def test_optimize_sigma_returns_finite_value (self ):
136+ mrom = MockMROM (self .db , n_roms = 2 )
137+ agg = Aggregation ()
138+ sigma = agg ._optimize_sigma (mrom )
139+ self .assertTrue (np .isfinite (sigma ).all ())
140+
141+ def test_optimize_sigma_within_default_range (self ):
142+ mrom = MockMROM (self .db , n_roms = 2 )
143+ agg = Aggregation ()
144+ sigma = agg ._optimize_sigma (mrom )
145+ self .assertGreaterEqual (float (sigma ), 1e-5 )
146+ self .assertLessEqual (float (sigma ), 1e-2 )
147+
148+ def test_aggregation_no_fit_function (self ):
149+ mrom = MockMROM (self .db , n_roms = 2 )
150+ agg = Aggregation (fit_function = None , predict_function = RBF ())
151+ agg .fit_postprocessing (mrom )
152+ agg .predict_postprocessing (mrom )
153+ self .assertIsNotNone (mrom .predict_full_database )
154+ self .assertEqual (len (agg .predict_functions ), 2 )
155+
156+ def test_aggregation_with_fit_function (self ):
157+ mrom = MockMROM (self .db , n_roms = 1 )
158+ agg = Aggregation (fit_function = RBF (), predict_function = RBF ())
159+ agg .fit_postprocessing (mrom )
160+ agg .predict_postprocessing (mrom )
161+ self .assertIsNotNone (mrom .predict_full_database )
162+
163+ def test_nan_handling_in_weights (self ):
164+ mrom = MockMROM (self .db , n_roms = 2 )
165+ agg = Aggregation (fit_function = None , predict_function = RBF ())
166+ agg ._compute_validation_weights = (
167+ lambda mrom , sigma , normalized = False : np .full ((2 , 2 , 3 ), np .nan )
168+ )
169+ agg ._optimize_sigma = lambda mrom : 1e-3
170+ agg .fit_postprocessing (mrom )
171+ self .assertEqual (len (agg .predict_functions ), 2 )
172+
173+
174+ class TestAggregationIntegration (TestCase ):
175+
176+ @classmethod
177+ def setUpClass (cls ):
178+ cls .db = _make_integration_db (n_params = 5 , n_space = 3 )
179+
180+ def _make_splitter (self , seed = 0 ):
181+ return DatabaseSplitter (
182+ train = 2 , test = 0 , validation = 2 , predict = 1 , seed = seed
183+ )
184+
185+ def _build_and_fit_mrom (self , agg , seed = 0 ):
186+ splitter = self ._make_splitter (seed = seed )
187+ rom1 = ROM (self .db , POD (rank = 1 ), RBF ())
188+ rom2 = ROM (self .db , POD (rank = 1 ), Linear ())
189+ agg ._optimize_sigma = lambda mrom : 1e-3
190+ mrom = MROM (
191+ {'rbf' : rom1 , 'lin' : rom2 },
192+ plugins = [splitter , agg ],
193+ rom_plugin = splitter ,
194+ )
195+ mrom .fit ()
196+ return mrom
197+
198+ def test_fit_does_not_raise (self ):
199+ agg = Aggregation (fit_function = None , predict_function = RBF ())
200+ self ._build_and_fit_mrom (agg )
201+
202+ def test_fit_regression_path_does_not_raise (self ):
203+ splitter = self ._make_splitter ()
204+ rom1 = ROM (self .db , POD (rank = 1 ), RBF ())
205+ agg = Aggregation (fit_function = RBF (), predict_function = RBF ())
206+ mrom = MROM ({'rbf' : rom1 }, plugins = [splitter , agg ], rom_plugin = splitter )
207+ mrom .fit ()
208+
209+ def test_predict_returns_database_instance (self ):
210+ agg = Aggregation (fit_function = None , predict_function = RBF ())
211+ mrom = self ._build_and_fit_mrom (agg )
212+ mrom .predict (mrom .predict_full_database )
213+ self .assertIsInstance (mrom .predict_full_database , Database )
214+
215+ def test_predict_snapshot_shape (self ):
216+ agg = Aggregation (fit_function = None , predict_function = RBF ())
217+ mrom = self ._build_and_fit_mrom (agg )
218+ mrom .predict (mrom .predict_full_database )
219+ self .assertEqual (mrom .predict_full_database .snapshots_matrix .shape [1 ], 3 )
220+
221+ def test_predict_functions_count_matches_n_roms (self ):
222+ agg = Aggregation (fit_function = None , predict_function = RBF ())
223+ self ._build_and_fit_mrom (agg )
224+ self .assertEqual (len (agg .predict_functions ), 2 )
225+
226+ def test_weights_are_finite (self ):
227+ agg = Aggregation (fit_function = None , predict_function = RBF ())
228+ mrom = self ._build_and_fit_mrom (agg )
229+ mrom .predict (mrom .predict_full_database )
230+ for key , w in mrom .weights_predict .items ():
231+ self .assertTrue (np .isfinite (w ).all (),
232+ msg = f"Non-finite weight for ROM '{ key } '" )
233+
234+ def test_weights_sum_to_one (self ):
235+ agg = Aggregation (fit_function = None , predict_function = RBF ())
236+ mrom = self ._build_and_fit_mrom (agg )
237+ mrom .predict (mrom .predict_full_database )
238+ weight_sum = np .sum (list (mrom .weights_predict .values ()), axis = 0 )
239+ np .testing .assert_array_almost_equal (
240+ weight_sum , np .ones_like (weight_sum ), decimal = 5
241+ )
242+
243+ def test_fit_reproducible_with_same_seed (self ):
244+ agg1 = Aggregation (fit_function = None , predict_function = RBF ())
245+ agg2 = Aggregation (fit_function = None , predict_function = RBF ())
246+ mrom1 = self ._build_and_fit_mrom (agg1 , seed = 7 )
247+ mrom2 = self ._build_and_fit_mrom (agg2 , seed = 7 )
248+
249+ pred_db1 = copy .deepcopy (mrom1 .predict_full_database )
250+ pred_db2 = copy .deepcopy (mrom2 .predict_full_database )
251+ mrom1 .predict (pred_db1 )
252+ mrom2 .predict (pred_db2 )
253+
254+ np .testing .assert_array_almost_equal (
255+ mrom1 .predict_full_database .snapshots_matrix ,
256+ mrom2 .predict_full_database .snapshots_matrix ,
257+ decimal = 10 ,
258+ )
259+
260+ def test_fit_different_seeds_produce_different_predictions (self ):
261+ agg1 = Aggregation (fit_function = None , predict_function = RBF ())
262+ agg2 = Aggregation (fit_function = None , predict_function = RBF ())
263+ mrom1 = self ._build_and_fit_mrom (agg1 , seed = 0 )
264+ mrom2 = self ._build_and_fit_mrom (agg2 , seed = 99 )
265+
266+ pred_db1 = copy .deepcopy (mrom1 .predict_full_database )
267+ pred_db2 = copy .deepcopy (mrom2 .predict_full_database )
268+ mrom1 .predict (pred_db1 )
269+ mrom2 .predict (pred_db2 )
270+
271+ with self .assertRaises (AssertionError ):
272+ np .testing .assert_array_almost_equal (
273+ mrom1 .predict_full_database .snapshots_matrix ,
274+ mrom2 .predict_full_database .snapshots_matrix ,
275+ decimal = 10 ,
276+ )
277+
278+ if __name__ == '__main__' :
279+ unittest .main ()
0 commit comments