1+ import numpy as np
2+ from unittest import TestCase
3+ from ezyrb import Database
4+ from ezyrb .plugin .database_splitter import DatabaseSplitter , DatabaseDictionarySplitter
5+
6+ class DummyROM :
7+ train_full_database = None
8+ test_full_database = None
9+ validation_full_database = None
10+ predict_full_database = None
11+
12+ def __init__ (self , db ):
13+ self ._database = db
14+
15+
16+ class TestDatabaseSplitter (TestCase ):
17+
18+ def test_split_integers_train_size (self ):
19+ db = Database (np .random .uniform (size = (100 , 2 )),
20+ np .random .uniform (size = (100 , 5 )))
21+ rom = DummyROM (db )
22+ splitter = DatabaseSplitter (train = 80 , test = 20 , validation = 0 , predict = 0 )
23+ splitter .fit_preprocessing (rom )
24+ self .assertEqual (len (rom .train_full_database ), 80 )
25+
26+ def test_split_integers_test_size (self ):
27+ db = Database (np .random .uniform (size = (100 , 2 )),
28+ np .random .uniform (size = (100 , 5 )))
29+ rom = DummyROM (db )
30+ splitter = DatabaseSplitter (train = 80 , test = 20 , validation = 0 , predict = 0 )
31+ splitter .fit_preprocessing (rom )
32+ self .assertEqual (len (rom .test_full_database ), 20 )
33+
34+ def test_split_integers_validation_predict_empty (self ):
35+ db = Database (np .random .uniform (size = (100 , 2 )),
36+ np .random .uniform (size = (100 , 5 )))
37+ rom = DummyROM (db )
38+ splitter = DatabaseSplitter (train = 80 , test = 20 , validation = 0 , predict = 0 )
39+ splitter .fit_preprocessing (rom )
40+ self .assertEqual (len (rom .validation_full_database ), 0 )
41+ self .assertEqual (len (rom .predict_full_database ), 0 )
42+
43+ def test_split_integers_total_conserved (self ):
44+ db = Database (np .random .uniform (size = (100 , 2 )),
45+ np .random .uniform (size = (100 , 5 )))
46+ rom = DummyROM (db )
47+ splitter = DatabaseSplitter (train = 70 , test = 20 , validation = 5 , predict = 5 )
48+ splitter .fit_preprocessing (rom )
49+ total = (len (rom .train_full_database ) +
50+ len (rom .test_full_database ) +
51+ len (rom .validation_full_database ) +
52+ len (rom .predict_full_database ))
53+ self .assertEqual (total , 100 )
54+
55+ def test_split_integers_returns_database_instances (self ):
56+ db = Database (np .random .uniform (size = (100 , 2 )),
57+ np .random .uniform (size = (100 , 5 )))
58+ rom = DummyROM (db )
59+ splitter = DatabaseSplitter (train = 80 , test = 20 , validation = 0 , predict = 0 )
60+ splitter .fit_preprocessing (rom )
61+ self .assertIsInstance (rom .train_full_database , Database )
62+ self .assertIsInstance (rom .test_full_database , Database )
63+ self .assertIsInstance (rom .validation_full_database , Database )
64+ self .assertIsInstance (rom .predict_full_database , Database )
65+
66+ def test_split_integers_inconsistent_chunks_raises (self ):
67+ db = Database (np .random .uniform (size = (100 , 2 )),
68+ np .random .uniform (size = (100 , 5 )))
69+ rom = DummyROM (db )
70+ splitter = DatabaseSplitter (train = 70 , test = 20 , validation = 0 , predict = 0 )
71+ with self .assertRaises (ValueError ):
72+ splitter .fit_preprocessing (rom )
73+
74+
75+ def test_split_floats_total_conserved (self ):
76+ db = Database (np .random .uniform (size = (100 , 2 )),
77+ np .random .uniform (size = (100 , 5 )))
78+ rom = DummyROM (db )
79+ splitter = DatabaseSplitter (train = 0.7 , test = 0.2 , validation = 0.05 ,
80+ predict = 0.05 , seed = 0 )
81+ splitter .fit_preprocessing (rom )
82+ total = (len (rom .train_full_database ) +
83+ len (rom .test_full_database ) +
84+ len (rom .validation_full_database ) +
85+ len (rom .predict_full_database ))
86+ self .assertEqual (total , 100 )
87+
88+ def test_split_floats_returns_database_instances (self ):
89+ db = Database (np .random .uniform (size = (100 , 2 )),
90+ np .random .uniform (size = (100 , 5 )))
91+ rom = DummyROM (db )
92+ splitter = DatabaseSplitter (train = 0.8 , test = 0.2 , seed = 0 )
93+ splitter .fit_preprocessing (rom )
94+ self .assertIsInstance (rom .train_full_database , Database )
95+ self .assertIsInstance (rom .test_full_database , Database )
96+
97+ def test_split_floats_inconsistent_ratios_raises (self ):
98+ db = Database (np .random .uniform (size = (100 , 2 )),
99+ np .random .uniform (size = (100 , 5 )))
100+ rom = DummyROM (db )
101+ splitter = DatabaseSplitter (train = 0.7 , test = 0.2 , validation = 0.0 ,
102+ predict = 0.0 )
103+ with self .assertRaises (ValueError ):
104+ splitter .fit_preprocessing (rom )
105+
106+
107+ def test_split_seed_reproducibility (self ):
108+ db = Database (np .random .uniform (size = (100 , 2 )),
109+ np .random .uniform (size = (100 , 5 )))
110+ rom1 = DummyROM (db )
111+ DatabaseSplitter (train = 0.8 , test = 0.2 , seed = 42 ).fit_preprocessing (rom1 )
112+
113+ rom2 = DummyROM (db )
114+ DatabaseSplitter (train = 0.8 , test = 0.2 , seed = 42 ).fit_preprocessing (rom2 )
115+
116+ np .testing .assert_array_equal (
117+ rom1 .train_full_database .parameters_matrix ,
118+ rom2 .train_full_database .parameters_matrix ,
119+ )
120+
121+ def test_split_different_seeds_differ (self ):
122+ db = Database (np .random .uniform (size = (100 , 2 )),
123+ np .random .uniform (size = (100 , 5 )))
124+ rom1 = DummyROM (db )
125+ DatabaseSplitter (train = 0.8 , test = 0.2 , seed = 0 ).fit_preprocessing (rom1 )
126+
127+ rom2 = DummyROM (db )
128+ DatabaseSplitter (train = 0.8 , test = 0.2 , seed = 99 ).fit_preprocessing (rom2 )
129+
130+ with self .assertRaises (AssertionError ):
131+ np .testing .assert_array_equal (
132+ rom1 .train_full_database .parameters_matrix ,
133+ rom2 .train_full_database .parameters_matrix ,
134+ )
135+
136+ def test_split_dict_database_explicit_flattening (self ):
137+ db_a = Database (np .random .uniform (size = (100 , 2 )),
138+ np .random .uniform (size = (100 , 5 )))
139+ db_b = Database (np .random .uniform (size = (50 , 2 )),
140+ np .random .uniform (size = (50 , 5 )))
141+
142+ rom = DummyROM ({'a' : db_a , 'b' : db_b })
143+
144+ splitter = DatabaseSplitter (train = 80 , test = 20 , validation = 0 , predict = 0 )
145+ splitter .fit_preprocessing (rom )
146+
147+ self .assertIsInstance (rom .train_full_database , Database )
148+
149+ self .assertEqual (len (rom .train_full_database ), 80 )
150+
151+
152+ class TestDatabaseDictionarySplitter (TestCase ):
153+
154+ def _make_dict_rom (self ):
155+ db_train = Database (np .random .uniform (size = (60 , 2 )),
156+ np .random .uniform (size = (60 , 5 )))
157+ db_test = Database (np .random .uniform (size = (20 , 2 )),
158+ np .random .uniform (size = (20 , 5 )))
159+ db_val = Database (np .random .uniform (size = (10 , 2 )),
160+ np .random .uniform (size = (10 , 5 )))
161+ db_pred = Database (np .random .uniform (size = (10 , 2 )),
162+ np .random .uniform (size = (10 , 5 )))
163+ db_dict = {
164+ 'train' : db_train ,
165+ 'test' : db_test ,
166+ 'val' : db_val ,
167+ 'pred' : db_pred ,
168+ }
169+ return DummyROM (db_dict ), db_dict
170+
171+ def test_train_key_assigned (self ):
172+ rom , db_dict = self ._make_dict_rom ()
173+ DatabaseDictionarySplitter (train_key = 'train' ).fit_preprocessing (rom )
174+ self .assertEqual (len (rom .train_full_database ), 60 )
175+
176+ def test_test_key_assigned (self ):
177+ rom , db_dict = self ._make_dict_rom ()
178+ DatabaseDictionarySplitter (test_key = 'test' ).fit_preprocessing (rom )
179+ self .assertEqual (len (rom .test_full_database ), 20 )
180+
181+ def test_validation_key_assigned (self ):
182+ rom , db_dict = self ._make_dict_rom ()
183+ DatabaseDictionarySplitter (validation_key = 'val' ).fit_preprocessing (rom )
184+ self .assertEqual (len (rom .validation_full_database ), 10 )
185+
186+ def test_predict_key_assigned (self ):
187+ rom , db_dict = self ._make_dict_rom ()
188+ DatabaseDictionarySplitter (predict_key = 'pred' ).fit_preprocessing (rom )
189+ self .assertEqual (len (rom .predict_full_database ), 10 )
190+
191+ def test_all_keys_assigned (self ):
192+ rom , db_dict = self ._make_dict_rom ()
193+ splitter = DatabaseDictionarySplitter (
194+ train_key = 'train' , test_key = 'test' ,
195+ validation_key = 'val' , predict_key = 'pred' ,
196+ )
197+ splitter .fit_preprocessing (rom )
198+ self .assertEqual (len (rom .train_full_database ), 60 )
199+ self .assertEqual (len (rom .test_full_database ), 20 )
200+ self .assertEqual (len (rom .validation_full_database ), 10 )
201+ self .assertEqual (len (rom .predict_full_database ), 10 )
202+
203+ def test_assigned_database_is_same_object (self ):
204+ rom , db_dict = self ._make_dict_rom ()
205+ splitter = DatabaseDictionarySplitter (
206+ train_key = 'train' , test_key = 'test' ,
207+ )
208+ splitter .fit_preprocessing (rom )
209+ self .assertIs (rom .train_full_database , db_dict ['train' ])
210+ self .assertIs (rom .test_full_database , db_dict ['test' ])
211+
212+ def test_unset_key_leaves_attribute_none (self ):
213+ rom , _ = self ._make_dict_rom ()
214+ DatabaseDictionarySplitter (train_key = 'train' ).fit_preprocessing (rom )
215+ self .assertIsNone (rom .test_full_database )
216+ self .assertIsNone (rom .validation_full_database )
217+ self .assertIsNone (rom .predict_full_database )
218+
219+ def test_non_dict_database_raises (self ):
220+ db = Database (np .random .uniform (size = (10 , 2 )),
221+ np .random .uniform (size = (10 , 5 )))
222+ rom = DummyROM (db )
223+ splitter = DatabaseDictionarySplitter (train_key = 'train' )
224+ with self .assertRaises (ValueError ):
225+ splitter .fit_preprocessing (rom )
0 commit comments