Skip to content

Commit 7a1a367

Browse files
kshitij-mathsndem0
authored andcommitted
test: database_splitter
1 parent 9c52d1d commit 7a1a367

1 file changed

Lines changed: 225 additions & 0 deletions

File tree

tests/test_database_splitter.py

Lines changed: 225 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,225 @@
1+
import numpy as np
2+
from unittest import TestCase
3+
from ezyrb import Database
4+
from ezyrb.plugin.database_splitter import DatabaseSplitter, DatabaseDictionarySplitter
5+
6+
class DummyROM:
7+
train_full_database = None
8+
test_full_database = None
9+
validation_full_database = None
10+
predict_full_database = None
11+
12+
def __init__(self, db):
13+
self._database = db
14+
15+
16+
class TestDatabaseSplitter(TestCase):
17+
18+
def test_split_integers_train_size(self):
19+
db = Database(np.random.uniform(size=(100, 2)),
20+
np.random.uniform(size=(100, 5)))
21+
rom = DummyROM(db)
22+
splitter = DatabaseSplitter(train=80, test=20, validation=0, predict=0)
23+
splitter.fit_preprocessing(rom)
24+
self.assertEqual(len(rom.train_full_database), 80)
25+
26+
def test_split_integers_test_size(self):
27+
db = Database(np.random.uniform(size=(100, 2)),
28+
np.random.uniform(size=(100, 5)))
29+
rom = DummyROM(db)
30+
splitter = DatabaseSplitter(train=80, test=20, validation=0, predict=0)
31+
splitter.fit_preprocessing(rom)
32+
self.assertEqual(len(rom.test_full_database), 20)
33+
34+
def test_split_integers_validation_predict_empty(self):
35+
db = Database(np.random.uniform(size=(100, 2)),
36+
np.random.uniform(size=(100, 5)))
37+
rom = DummyROM(db)
38+
splitter = DatabaseSplitter(train=80, test=20, validation=0, predict=0)
39+
splitter.fit_preprocessing(rom)
40+
self.assertEqual(len(rom.validation_full_database), 0)
41+
self.assertEqual(len(rom.predict_full_database), 0)
42+
43+
def test_split_integers_total_conserved(self):
44+
db = Database(np.random.uniform(size=(100, 2)),
45+
np.random.uniform(size=(100, 5)))
46+
rom = DummyROM(db)
47+
splitter = DatabaseSplitter(train=70, test=20, validation=5, predict=5)
48+
splitter.fit_preprocessing(rom)
49+
total = (len(rom.train_full_database) +
50+
len(rom.test_full_database) +
51+
len(rom.validation_full_database) +
52+
len(rom.predict_full_database))
53+
self.assertEqual(total, 100)
54+
55+
def test_split_integers_returns_database_instances(self):
56+
db = Database(np.random.uniform(size=(100, 2)),
57+
np.random.uniform(size=(100, 5)))
58+
rom = DummyROM(db)
59+
splitter = DatabaseSplitter(train=80, test=20, validation=0, predict=0)
60+
splitter.fit_preprocessing(rom)
61+
self.assertIsInstance(rom.train_full_database, Database)
62+
self.assertIsInstance(rom.test_full_database, Database)
63+
self.assertIsInstance(rom.validation_full_database, Database)
64+
self.assertIsInstance(rom.predict_full_database, Database)
65+
66+
def test_split_integers_inconsistent_chunks_raises(self):
67+
db = Database(np.random.uniform(size=(100, 2)),
68+
np.random.uniform(size=(100, 5)))
69+
rom = DummyROM(db)
70+
splitter = DatabaseSplitter(train=70, test=20, validation=0, predict=0)
71+
with self.assertRaises(ValueError):
72+
splitter.fit_preprocessing(rom)
73+
74+
75+
def test_split_floats_total_conserved(self):
76+
db = Database(np.random.uniform(size=(100, 2)),
77+
np.random.uniform(size=(100, 5)))
78+
rom = DummyROM(db)
79+
splitter = DatabaseSplitter(train=0.7, test=0.2, validation=0.05,
80+
predict=0.05, seed=0)
81+
splitter.fit_preprocessing(rom)
82+
total = (len(rom.train_full_database) +
83+
len(rom.test_full_database) +
84+
len(rom.validation_full_database) +
85+
len(rom.predict_full_database))
86+
self.assertEqual(total, 100)
87+
88+
def test_split_floats_returns_database_instances(self):
89+
db = Database(np.random.uniform(size=(100, 2)),
90+
np.random.uniform(size=(100, 5)))
91+
rom = DummyROM(db)
92+
splitter = DatabaseSplitter(train=0.8, test=0.2, seed=0)
93+
splitter.fit_preprocessing(rom)
94+
self.assertIsInstance(rom.train_full_database, Database)
95+
self.assertIsInstance(rom.test_full_database, Database)
96+
97+
def test_split_floats_inconsistent_ratios_raises(self):
98+
db = Database(np.random.uniform(size=(100, 2)),
99+
np.random.uniform(size=(100, 5)))
100+
rom = DummyROM(db)
101+
splitter = DatabaseSplitter(train=0.7, test=0.2, validation=0.0,
102+
predict=0.0)
103+
with self.assertRaises(ValueError):
104+
splitter.fit_preprocessing(rom)
105+
106+
107+
def test_split_seed_reproducibility(self):
108+
db = Database(np.random.uniform(size=(100, 2)),
109+
np.random.uniform(size=(100, 5)))
110+
rom1 = DummyROM(db)
111+
DatabaseSplitter(train=0.8, test=0.2, seed=42).fit_preprocessing(rom1)
112+
113+
rom2 = DummyROM(db)
114+
DatabaseSplitter(train=0.8, test=0.2, seed=42).fit_preprocessing(rom2)
115+
116+
np.testing.assert_array_equal(
117+
rom1.train_full_database.parameters_matrix,
118+
rom2.train_full_database.parameters_matrix,
119+
)
120+
121+
def test_split_different_seeds_differ(self):
122+
db = Database(np.random.uniform(size=(100, 2)),
123+
np.random.uniform(size=(100, 5)))
124+
rom1 = DummyROM(db)
125+
DatabaseSplitter(train=0.8, test=0.2, seed=0).fit_preprocessing(rom1)
126+
127+
rom2 = DummyROM(db)
128+
DatabaseSplitter(train=0.8, test=0.2, seed=99).fit_preprocessing(rom2)
129+
130+
with self.assertRaises(AssertionError):
131+
np.testing.assert_array_equal(
132+
rom1.train_full_database.parameters_matrix,
133+
rom2.train_full_database.parameters_matrix,
134+
)
135+
136+
def test_split_dict_database_explicit_flattening(self):
137+
db_a = Database(np.random.uniform(size=(100, 2)),
138+
np.random.uniform(size=(100, 5)))
139+
db_b = Database(np.random.uniform(size=(50, 2)),
140+
np.random.uniform(size=(50, 5)))
141+
142+
rom = DummyROM({'a': db_a, 'b': db_b})
143+
144+
splitter = DatabaseSplitter(train=80, test=20, validation=0, predict=0)
145+
splitter.fit_preprocessing(rom)
146+
147+
self.assertIsInstance(rom.train_full_database, Database)
148+
149+
self.assertEqual(len(rom.train_full_database), 80)
150+
151+
152+
class TestDatabaseDictionarySplitter(TestCase):
153+
154+
def _make_dict_rom(self):
155+
db_train = Database(np.random.uniform(size=(60, 2)),
156+
np.random.uniform(size=(60, 5)))
157+
db_test = Database(np.random.uniform(size=(20, 2)),
158+
np.random.uniform(size=(20, 5)))
159+
db_val = Database(np.random.uniform(size=(10, 2)),
160+
np.random.uniform(size=(10, 5)))
161+
db_pred = Database(np.random.uniform(size=(10, 2)),
162+
np.random.uniform(size=(10, 5)))
163+
db_dict = {
164+
'train': db_train,
165+
'test': db_test,
166+
'val': db_val,
167+
'pred': db_pred,
168+
}
169+
return DummyROM(db_dict), db_dict
170+
171+
def test_train_key_assigned(self):
172+
rom, db_dict = self._make_dict_rom()
173+
DatabaseDictionarySplitter(train_key='train').fit_preprocessing(rom)
174+
self.assertEqual(len(rom.train_full_database), 60)
175+
176+
def test_test_key_assigned(self):
177+
rom, db_dict = self._make_dict_rom()
178+
DatabaseDictionarySplitter(test_key='test').fit_preprocessing(rom)
179+
self.assertEqual(len(rom.test_full_database), 20)
180+
181+
def test_validation_key_assigned(self):
182+
rom, db_dict = self._make_dict_rom()
183+
DatabaseDictionarySplitter(validation_key='val').fit_preprocessing(rom)
184+
self.assertEqual(len(rom.validation_full_database), 10)
185+
186+
def test_predict_key_assigned(self):
187+
rom, db_dict = self._make_dict_rom()
188+
DatabaseDictionarySplitter(predict_key='pred').fit_preprocessing(rom)
189+
self.assertEqual(len(rom.predict_full_database), 10)
190+
191+
def test_all_keys_assigned(self):
192+
rom, db_dict = self._make_dict_rom()
193+
splitter = DatabaseDictionarySplitter(
194+
train_key='train', test_key='test',
195+
validation_key='val', predict_key='pred',
196+
)
197+
splitter.fit_preprocessing(rom)
198+
self.assertEqual(len(rom.train_full_database), 60)
199+
self.assertEqual(len(rom.test_full_database), 20)
200+
self.assertEqual(len(rom.validation_full_database), 10)
201+
self.assertEqual(len(rom.predict_full_database), 10)
202+
203+
def test_assigned_database_is_same_object(self):
204+
rom, db_dict = self._make_dict_rom()
205+
splitter = DatabaseDictionarySplitter(
206+
train_key='train', test_key='test',
207+
)
208+
splitter.fit_preprocessing(rom)
209+
self.assertIs(rom.train_full_database, db_dict['train'])
210+
self.assertIs(rom.test_full_database, db_dict['test'])
211+
212+
def test_unset_key_leaves_attribute_none(self):
213+
rom, _ = self._make_dict_rom()
214+
DatabaseDictionarySplitter(train_key='train').fit_preprocessing(rom)
215+
self.assertIsNone(rom.test_full_database)
216+
self.assertIsNone(rom.validation_full_database)
217+
self.assertIsNone(rom.predict_full_database)
218+
219+
def test_non_dict_database_raises(self):
220+
db = Database(np.random.uniform(size=(10, 2)),
221+
np.random.uniform(size=(10, 5)))
222+
rom = DummyROM(db)
223+
splitter = DatabaseDictionarySplitter(train_key='train')
224+
with self.assertRaises(ValueError):
225+
splitter.fit_preprocessing(rom)

0 commit comments

Comments
 (0)