Skip to content

Commit 9c52d1d

Browse files
kshitij-mathsndem0
authored andcommitted
test: automatic_shift
1 parent 603dcc0 commit 9c52d1d

2 files changed

Lines changed: 277 additions & 1 deletion

File tree

ezyrb/plugin/automatic_shift.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Module for Scaler plugin"""
22

33
import numpy as np
4-
4+
import torch
55
from ezyrb import Database, Snapshot, Parameter
66
from .plugin import Plugin
77

tests/test_automatic_shift.py

Lines changed: 276 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,276 @@
1+
import numpy as np
2+
import pytest
3+
import torch
4+
import torch.nn as nn
5+
from unittest import TestCase
6+
from unittest.mock import Mock
7+
8+
from ezyrb import Database, Parameter, Snapshot
9+
from ezyrb.plugin.automatic_shift import AutomaticShiftSnapshots
10+
11+
class DummyModel(nn.Module):
12+
def __init__(self):
13+
super().__init__()
14+
self.dummy_param = nn.Parameter(torch.zeros(1))
15+
16+
def forward(self, x):
17+
val = x.view(x.shape[0], -1).sum(dim=1, keepdim=True) * 0.0
18+
return val + self.dummy_param
19+
20+
21+
class SimpleANN:
22+
def __init__(self, stop_training=None):
23+
self.model = DummyModel()
24+
self.lr = 0.01
25+
self.l2_regularization = 0.0
26+
self.loss_trend = []
27+
self.stop_training = stop_training if stop_training else [1]
28+
self.frequency_print = 100
29+
30+
def _build_model(self, x, y):
31+
pass
32+
33+
def fit(self, x, y):
34+
pass
35+
36+
def optimizer(self, params, lr, weight_decay):
37+
return torch.optim.SGD(params, lr=lr, weight_decay=weight_decay)
38+
39+
def predict(self, x):
40+
return np.zeros((x.shape[0], 1))
41+
42+
43+
class SimpleInterpolator:
44+
def fit(self, x, y):
45+
self.x_fit = np.asarray(x)
46+
self.y_fit = np.asarray(y)
47+
48+
def predict(self, x):
49+
if hasattr(self, 'y_fit'):
50+
return np.full((x.shape[0],), self.y_fit.mean())
51+
return np.zeros((x.shape[0],))
52+
53+
54+
class MockROM:
55+
def __init__(self, db):
56+
self.database = db
57+
self.predict_full_database = db
58+
self._full_database = None
59+
60+
61+
class TestAutomaticShiftSnapshots(TestCase):
62+
63+
def setUp(self):
64+
self.space = np.array([0.0, 1.0, 2.0])
65+
self.db = Database()
66+
67+
snap1 = Snapshot(values=np.array([1.0, 2.0, 3.0]), space=self.space.copy())
68+
snap2 = Snapshot(values=np.array([2.0, 3.0, 4.0]), space=self.space.copy())
69+
snap3 = Snapshot(values=np.array([3.0, 4.0, 5.0]), space=self.space.copy())
70+
71+
self.db.add(Parameter([1.0]), snap1)
72+
self.db.add(Parameter([2.0]), snap2)
73+
self.db.add(Parameter([3.0]), snap3)
74+
75+
self.rom = MockROM(self.db)
76+
77+
def test_constructor_stores_parameters(self):
78+
shift_net = SimpleANN()
79+
interp_net = SimpleANN()
80+
interpolator = SimpleInterpolator()
81+
82+
plugin = AutomaticShiftSnapshots(
83+
shift_network=shift_net,
84+
interp_network=interp_net,
85+
interpolator=interpolator,
86+
parameter_index=1,
87+
reference_index=2,
88+
barycenter_loss=5.0,
89+
)
90+
91+
self.assertIs(plugin.shift_network, shift_net)
92+
self.assertIs(plugin.interp_network, interp_net)
93+
self.assertIs(plugin.interpolator, interpolator)
94+
self.assertEqual(plugin.parameter_index, 1)
95+
self.assertEqual(plugin.reference_index, 2)
96+
self.assertEqual(plugin.barycenter_loss, 5.0)
97+
98+
def test_fit_preprocessing_sets_reference_snapshot(self):
99+
plugin = AutomaticShiftSnapshots(
100+
shift_network=SimpleANN(),
101+
interp_network=SimpleANN(),
102+
interpolator=SimpleInterpolator(),
103+
reference_index=1,
104+
)
105+
106+
plugin.fit_preprocessing(self.rom)
107+
108+
expected_snap = self.db._pairs[1][1]
109+
np.testing.assert_array_equal(plugin.reference_snapshot.values, expected_snap.values)
110+
np.testing.assert_array_equal(plugin.reference_snapshot.space, expected_snap.space)
111+
112+
def test_fit_preprocessing_calls_train_interp_network(self):
113+
shift_net = SimpleANN()
114+
interp_net = SimpleANN()
115+
interp_net.fit = Mock()
116+
117+
plugin = AutomaticShiftSnapshots(
118+
shift_network=shift_net,
119+
interp_network=interp_net,
120+
interpolator=SimpleInterpolator(),
121+
reference_index=0,
122+
)
123+
124+
plugin.fit_preprocessing(self.rom)
125+
126+
interp_net.fit.assert_called_once()
127+
args, _ = interp_net.fit.call_args
128+
np.testing.assert_array_equal(args[0], self.space.reshape(-1, 1))
129+
130+
def test_fit_preprocessing_calls_train_shift_network(self):
131+
shift_net = SimpleANN()
132+
shift_net._build_model = Mock()
133+
134+
plugin = AutomaticShiftSnapshots(
135+
shift_network=shift_net,
136+
interp_network=SimpleANN(),
137+
interpolator=SimpleInterpolator(),
138+
)
139+
140+
plugin.fit_preprocessing(self.rom)
141+
shift_net._build_model.assert_called_once()
142+
143+
def test_fit_preprocessing_modifies_snapshots(self):
144+
plugin = AutomaticShiftSnapshots(
145+
shift_network=SimpleANN(),
146+
interp_network=SimpleANN(),
147+
interpolator=SimpleInterpolator(),
148+
)
149+
plugin.fit_preprocessing(self.rom)
150+
self.assertIsNotNone(self.db._pairs[0][1].values)
151+
152+
def test_fit_preprocessing_with_barycenter_loss_zero(self):
153+
plugin = AutomaticShiftSnapshots(
154+
shift_network=SimpleANN(),
155+
interp_network=SimpleANN(),
156+
interpolator=SimpleInterpolator(),
157+
barycenter_loss=0.0,
158+
)
159+
plugin.fit_preprocessing(self.rom)
160+
self.assertIsNotNone(plugin.reference_snapshot)
161+
162+
def test_fit_preprocessing_with_barycenter_loss_nonzero(self):
163+
plugin = AutomaticShiftSnapshots(
164+
shift_network=SimpleANN(),
165+
interp_network=SimpleANN(),
166+
interpolator=SimpleInterpolator(),
167+
barycenter_loss=10.0,
168+
)
169+
plugin.fit_preprocessing(self.rom)
170+
self.assertIsNotNone(plugin.reference_snapshot)
171+
172+
def test_predict_postprocessing_creates_full_database(self):
173+
plugin = AutomaticShiftSnapshots(
174+
shift_network=SimpleANN(),
175+
interp_network=SimpleANN(),
176+
interpolator=SimpleInterpolator(),
177+
)
178+
plugin.fit_preprocessing(self.rom)
179+
plugin.predict_postprocessing(self.rom)
180+
181+
self.assertIsInstance(self.rom._full_database, Database)
182+
183+
def test_predict_postprocessing_preserves_snapshot_count(self):
184+
plugin = AutomaticShiftSnapshots(
185+
shift_network=SimpleANN(),
186+
interp_network=SimpleANN(),
187+
interpolator=SimpleInterpolator(),
188+
)
189+
plugin.fit_preprocessing(self.rom)
190+
original_count = len(self.rom.predict_full_database)
191+
plugin.predict_postprocessing(self.rom)
192+
193+
self.assertEqual(len(self.rom._full_database), original_count)
194+
195+
def test_predict_postprocessing_modifies_space(self):
196+
plugin = AutomaticShiftSnapshots(
197+
shift_network=SimpleANN(),
198+
interp_network=SimpleANN(),
199+
interpolator=SimpleInterpolator(),
200+
)
201+
plugin.fit_preprocessing(self.rom)
202+
plugin.predict_postprocessing(self.rom)
203+
204+
new_spaces = [snap.space.copy() for _, snap in self.rom._full_database._pairs]
205+
for new_space in new_spaces:
206+
self.assertEqual(len(new_space), len(self.space))
207+
208+
def test_stop_training_integer_criterion(self):
209+
shift_net = SimpleANN(stop_training=[2])
210+
plugin = AutomaticShiftSnapshots(
211+
shift_network=shift_net,
212+
interp_network=SimpleANN(),
213+
interpolator=SimpleInterpolator(),
214+
)
215+
plugin.fit_preprocessing(self.rom)
216+
self.assertEqual(len(shift_net.loss_trend), 2)
217+
218+
def test_stop_training_float_criterion(self):
219+
shift_net = SimpleANN(stop_training=[100.0])
220+
plugin = AutomaticShiftSnapshots(
221+
shift_network=shift_net,
222+
interp_network=SimpleANN(),
223+
interpolator=SimpleInterpolator(),
224+
)
225+
plugin.fit_preprocessing(self.rom)
226+
self.assertGreaterEqual(len(shift_net.loss_trend), 1)
227+
228+
def test_single_snapshot_database(self):
229+
db = Database()
230+
snap = Snapshot(values=np.array([1.0, 2.0, 3.0]), space=self.space)
231+
db.add(Parameter([1.0]), snap)
232+
rom = MockROM(db)
233+
234+
plugin = AutomaticShiftSnapshots(
235+
shift_network=SimpleANN(),
236+
interp_network=SimpleANN(),
237+
interpolator=SimpleInterpolator(),
238+
)
239+
plugin.fit_preprocessing(rom)
240+
plugin.predict_postprocessing(rom)
241+
self.assertEqual(len(rom._full_database), 1)
242+
243+
def test_reference_index_boundary(self):
244+
db = Database()
245+
for i in range(5):
246+
snap = Snapshot(values=np.array([float(i)]), space=np.array([0.5]))
247+
db.add(Parameter([float(i)]), snap)
248+
249+
rom = MockROM(db)
250+
plugin = AutomaticShiftSnapshots(
251+
shift_network=SimpleANN(),
252+
interp_network=SimpleANN(),
253+
interpolator=SimpleInterpolator(),
254+
reference_index=4,
255+
)
256+
plugin.fit_preprocessing(rom)
257+
self.assertEqual(plugin.reference_snapshot.values[0], 4.0)
258+
259+
def test_multidimensional_parameters_raise_valueerror(self):
260+
db = Database()
261+
snap1 = Snapshot(values=np.array([1.0, 2.0, 3.0]), space=self.space)
262+
db.add(Parameter([1.0, 10.0]), snap1)
263+
rom = MockROM(db)
264+
265+
plugin = AutomaticShiftSnapshots(
266+
shift_network=SimpleANN(),
267+
interp_network=SimpleANN(),
268+
interpolator=SimpleInterpolator(),
269+
parameter_index=1,
270+
)
271+
with self.assertRaises(ValueError):
272+
plugin.fit_preprocessing(rom)
273+
274+
275+
if __name__ == '__main__':
276+
pytest.main([__file__, '-v'])

0 commit comments

Comments
 (0)