1+ import numpy as np
2+ import pytest
3+ import torch
4+ import torch .nn as nn
5+ from unittest import TestCase
6+ from unittest .mock import Mock
7+
8+ from ezyrb import Database , Parameter , Snapshot
9+ from ezyrb .plugin .automatic_shift import AutomaticShiftSnapshots
10+
11+ class DummyModel (nn .Module ):
12+ def __init__ (self ):
13+ super ().__init__ ()
14+ self .dummy_param = nn .Parameter (torch .zeros (1 ))
15+
16+ def forward (self , x ):
17+ val = x .view (x .shape [0 ], - 1 ).sum (dim = 1 , keepdim = True ) * 0.0
18+ return val + self .dummy_param
19+
20+
21+ class SimpleANN :
22+ def __init__ (self , stop_training = None ):
23+ self .model = DummyModel ()
24+ self .lr = 0.01
25+ self .l2_regularization = 0.0
26+ self .loss_trend = []
27+ self .stop_training = stop_training if stop_training else [1 ]
28+ self .frequency_print = 100
29+
30+ def _build_model (self , x , y ):
31+ pass
32+
33+ def fit (self , x , y ):
34+ pass
35+
36+ def optimizer (self , params , lr , weight_decay ):
37+ return torch .optim .SGD (params , lr = lr , weight_decay = weight_decay )
38+
39+ def predict (self , x ):
40+ return np .zeros ((x .shape [0 ], 1 ))
41+
42+
43+ class SimpleInterpolator :
44+ def fit (self , x , y ):
45+ self .x_fit = np .asarray (x )
46+ self .y_fit = np .asarray (y )
47+
48+ def predict (self , x ):
49+ if hasattr (self , 'y_fit' ):
50+ return np .full ((x .shape [0 ],), self .y_fit .mean ())
51+ return np .zeros ((x .shape [0 ],))
52+
53+
54+ class MockROM :
55+ def __init__ (self , db ):
56+ self .database = db
57+ self .predict_full_database = db
58+ self ._full_database = None
59+
60+
61+ class TestAutomaticShiftSnapshots (TestCase ):
62+
63+ def setUp (self ):
64+ self .space = np .array ([0.0 , 1.0 , 2.0 ])
65+ self .db = Database ()
66+
67+ snap1 = Snapshot (values = np .array ([1.0 , 2.0 , 3.0 ]), space = self .space .copy ())
68+ snap2 = Snapshot (values = np .array ([2.0 , 3.0 , 4.0 ]), space = self .space .copy ())
69+ snap3 = Snapshot (values = np .array ([3.0 , 4.0 , 5.0 ]), space = self .space .copy ())
70+
71+ self .db .add (Parameter ([1.0 ]), snap1 )
72+ self .db .add (Parameter ([2.0 ]), snap2 )
73+ self .db .add (Parameter ([3.0 ]), snap3 )
74+
75+ self .rom = MockROM (self .db )
76+
77+ def test_constructor_stores_parameters (self ):
78+ shift_net = SimpleANN ()
79+ interp_net = SimpleANN ()
80+ interpolator = SimpleInterpolator ()
81+
82+ plugin = AutomaticShiftSnapshots (
83+ shift_network = shift_net ,
84+ interp_network = interp_net ,
85+ interpolator = interpolator ,
86+ parameter_index = 1 ,
87+ reference_index = 2 ,
88+ barycenter_loss = 5.0 ,
89+ )
90+
91+ self .assertIs (plugin .shift_network , shift_net )
92+ self .assertIs (plugin .interp_network , interp_net )
93+ self .assertIs (plugin .interpolator , interpolator )
94+ self .assertEqual (plugin .parameter_index , 1 )
95+ self .assertEqual (plugin .reference_index , 2 )
96+ self .assertEqual (plugin .barycenter_loss , 5.0 )
97+
98+ def test_fit_preprocessing_sets_reference_snapshot (self ):
99+ plugin = AutomaticShiftSnapshots (
100+ shift_network = SimpleANN (),
101+ interp_network = SimpleANN (),
102+ interpolator = SimpleInterpolator (),
103+ reference_index = 1 ,
104+ )
105+
106+ plugin .fit_preprocessing (self .rom )
107+
108+ expected_snap = self .db ._pairs [1 ][1 ]
109+ np .testing .assert_array_equal (plugin .reference_snapshot .values , expected_snap .values )
110+ np .testing .assert_array_equal (plugin .reference_snapshot .space , expected_snap .space )
111+
112+ def test_fit_preprocessing_calls_train_interp_network (self ):
113+ shift_net = SimpleANN ()
114+ interp_net = SimpleANN ()
115+ interp_net .fit = Mock ()
116+
117+ plugin = AutomaticShiftSnapshots (
118+ shift_network = shift_net ,
119+ interp_network = interp_net ,
120+ interpolator = SimpleInterpolator (),
121+ reference_index = 0 ,
122+ )
123+
124+ plugin .fit_preprocessing (self .rom )
125+
126+ interp_net .fit .assert_called_once ()
127+ args , _ = interp_net .fit .call_args
128+ np .testing .assert_array_equal (args [0 ], self .space .reshape (- 1 , 1 ))
129+
130+ def test_fit_preprocessing_calls_train_shift_network (self ):
131+ shift_net = SimpleANN ()
132+ shift_net ._build_model = Mock ()
133+
134+ plugin = AutomaticShiftSnapshots (
135+ shift_network = shift_net ,
136+ interp_network = SimpleANN (),
137+ interpolator = SimpleInterpolator (),
138+ )
139+
140+ plugin .fit_preprocessing (self .rom )
141+ shift_net ._build_model .assert_called_once ()
142+
143+ def test_fit_preprocessing_modifies_snapshots (self ):
144+ plugin = AutomaticShiftSnapshots (
145+ shift_network = SimpleANN (),
146+ interp_network = SimpleANN (),
147+ interpolator = SimpleInterpolator (),
148+ )
149+ plugin .fit_preprocessing (self .rom )
150+ self .assertIsNotNone (self .db ._pairs [0 ][1 ].values )
151+
152+ def test_fit_preprocessing_with_barycenter_loss_zero (self ):
153+ plugin = AutomaticShiftSnapshots (
154+ shift_network = SimpleANN (),
155+ interp_network = SimpleANN (),
156+ interpolator = SimpleInterpolator (),
157+ barycenter_loss = 0.0 ,
158+ )
159+ plugin .fit_preprocessing (self .rom )
160+ self .assertIsNotNone (plugin .reference_snapshot )
161+
162+ def test_fit_preprocessing_with_barycenter_loss_nonzero (self ):
163+ plugin = AutomaticShiftSnapshots (
164+ shift_network = SimpleANN (),
165+ interp_network = SimpleANN (),
166+ interpolator = SimpleInterpolator (),
167+ barycenter_loss = 10.0 ,
168+ )
169+ plugin .fit_preprocessing (self .rom )
170+ self .assertIsNotNone (plugin .reference_snapshot )
171+
172+ def test_predict_postprocessing_creates_full_database (self ):
173+ plugin = AutomaticShiftSnapshots (
174+ shift_network = SimpleANN (),
175+ interp_network = SimpleANN (),
176+ interpolator = SimpleInterpolator (),
177+ )
178+ plugin .fit_preprocessing (self .rom )
179+ plugin .predict_postprocessing (self .rom )
180+
181+ self .assertIsInstance (self .rom ._full_database , Database )
182+
183+ def test_predict_postprocessing_preserves_snapshot_count (self ):
184+ plugin = AutomaticShiftSnapshots (
185+ shift_network = SimpleANN (),
186+ interp_network = SimpleANN (),
187+ interpolator = SimpleInterpolator (),
188+ )
189+ plugin .fit_preprocessing (self .rom )
190+ original_count = len (self .rom .predict_full_database )
191+ plugin .predict_postprocessing (self .rom )
192+
193+ self .assertEqual (len (self .rom ._full_database ), original_count )
194+
195+ def test_predict_postprocessing_modifies_space (self ):
196+ plugin = AutomaticShiftSnapshots (
197+ shift_network = SimpleANN (),
198+ interp_network = SimpleANN (),
199+ interpolator = SimpleInterpolator (),
200+ )
201+ plugin .fit_preprocessing (self .rom )
202+ plugin .predict_postprocessing (self .rom )
203+
204+ new_spaces = [snap .space .copy () for _ , snap in self .rom ._full_database ._pairs ]
205+ for new_space in new_spaces :
206+ self .assertEqual (len (new_space ), len (self .space ))
207+
208+ def test_stop_training_integer_criterion (self ):
209+ shift_net = SimpleANN (stop_training = [2 ])
210+ plugin = AutomaticShiftSnapshots (
211+ shift_network = shift_net ,
212+ interp_network = SimpleANN (),
213+ interpolator = SimpleInterpolator (),
214+ )
215+ plugin .fit_preprocessing (self .rom )
216+ self .assertEqual (len (shift_net .loss_trend ), 2 )
217+
218+ def test_stop_training_float_criterion (self ):
219+ shift_net = SimpleANN (stop_training = [100.0 ])
220+ plugin = AutomaticShiftSnapshots (
221+ shift_network = shift_net ,
222+ interp_network = SimpleANN (),
223+ interpolator = SimpleInterpolator (),
224+ )
225+ plugin .fit_preprocessing (self .rom )
226+ self .assertGreaterEqual (len (shift_net .loss_trend ), 1 )
227+
228+ def test_single_snapshot_database (self ):
229+ db = Database ()
230+ snap = Snapshot (values = np .array ([1.0 , 2.0 , 3.0 ]), space = self .space )
231+ db .add (Parameter ([1.0 ]), snap )
232+ rom = MockROM (db )
233+
234+ plugin = AutomaticShiftSnapshots (
235+ shift_network = SimpleANN (),
236+ interp_network = SimpleANN (),
237+ interpolator = SimpleInterpolator (),
238+ )
239+ plugin .fit_preprocessing (rom )
240+ plugin .predict_postprocessing (rom )
241+ self .assertEqual (len (rom ._full_database ), 1 )
242+
243+ def test_reference_index_boundary (self ):
244+ db = Database ()
245+ for i in range (5 ):
246+ snap = Snapshot (values = np .array ([float (i )]), space = np .array ([0.5 ]))
247+ db .add (Parameter ([float (i )]), snap )
248+
249+ rom = MockROM (db )
250+ plugin = AutomaticShiftSnapshots (
251+ shift_network = SimpleANN (),
252+ interp_network = SimpleANN (),
253+ interpolator = SimpleInterpolator (),
254+ reference_index = 4 ,
255+ )
256+ plugin .fit_preprocessing (rom )
257+ self .assertEqual (plugin .reference_snapshot .values [0 ], 4.0 )
258+
259+ def test_multidimensional_parameters_raise_valueerror (self ):
260+ db = Database ()
261+ snap1 = Snapshot (values = np .array ([1.0 , 2.0 , 3.0 ]), space = self .space )
262+ db .add (Parameter ([1.0 , 10.0 ]), snap1 )
263+ rom = MockROM (db )
264+
265+ plugin = AutomaticShiftSnapshots (
266+ shift_network = SimpleANN (),
267+ interp_network = SimpleANN (),
268+ interpolator = SimpleInterpolator (),
269+ parameter_index = 1 ,
270+ )
271+ with self .assertRaises (ValueError ):
272+ plugin .fit_preprocessing (rom )
273+
274+
275+ if __name__ == '__main__' :
276+ pytest .main ([__file__ , '-v' ])
0 commit comments