@@ -15,10 +15,6 @@ class Database:
1515
1616 :param array_like parameters: the input parameters
1717 :param array_like snapshots: the input snapshots
18- :param Scale scaler_parameters: the scaler for the parameters. Default
19- is None meaning no scaling.
20- :param Scale scaler_snapshots: the scaler for the snapshots. Default is
21- None meaning no scaling.
2218 :param array_like space: the input spatial data
2319
2420 :Example:
@@ -46,6 +42,9 @@ def __init__(self, parameters=None, snapshots=None, space=None):
4642 )
4743 self ._pairs = []
4844
45+ self .scaler_parameters = None
46+ self .scaler_snapshots = None
47+
4948 if parameters is None and snapshots is None :
5049 logger .debug ("Empty database created" )
5150 return
@@ -149,18 +148,66 @@ def add(self, parameter, snapshot):
149148 """
150149 if not isinstance (parameter , Parameter ):
151150 logger .error ("Invalid parameter type: %s" , type (parameter ))
152- raise ValueError
151+ raise TypeError ( f"Expected a Parameter object, got { type ( parameter ) } " )
153152
154153 if not isinstance (snapshot , Snapshot ):
155154 logger .error ("Invalid snapshot type: %s" , type (snapshot ))
156- raise ValueError
155+ raise TypeError ( f"Expected a Snapshot object, got { type ( snapshot ) } " )
157156
158157 self ._pairs .append ((parameter , snapshot ))
159158 logger .debug (
160159 "Added parameter-snapshot pair. Total pairs: %d" , len (self ._pairs )
161160 )
162161
163162 return self
163+
def normalize_parameters(self, scaler=None):
    """
    Normalize the parameters stored in the database, in place.

    The scaler is fitted on ``self.parameters_matrix`` and row ``i`` of the
    transformed matrix is written back to the ``values`` of the parameter
    of the i-th stored pair. The fitted scaler is then stored in
    ``self.scaler_parameters``.

    :param scaler: an object exposing ``fit_transform`` (e.g., a scaler
        from sklearn.preprocessing). If None, a
        ``sklearn.preprocessing.MinMaxScaler`` is created.
    :return: the database itself, to allow method chaining. An empty
        database is returned untouched (no scaler is fitted or stored).
    """
    if not self._pairs:
        return self

    if scaler is None:
        # Import lazily so that scikit-learn is required only when the
        # default scaler is actually needed, not when the caller supplies
        # their own scaling object.
        from sklearn.preprocessing import MinMaxScaler

        scaler = MinMaxScaler()

    normalized_params = scaler.fit_transform(self.parameters_matrix)

    for i, (parameter, _) in enumerate(self._pairs):
        parameter.values = normalized_params[i]

    self.scaler_parameters = scaler
    return self
186+
187+
def normalize_snapshots(self, scaler=None):
    """
    Normalize the snapshots stored in the database, in place.

    The scaler is fitted on ``self.snapshots_matrix`` and row ``i`` of the
    transformed matrix is reshaped to the shape of the i-th snapshot's
    current ``values`` before being written back. The fitted scaler is
    then stored in ``self.scaler_snapshots``.

    :param scaler: an object exposing ``fit_transform`` (e.g., a scaler
        from sklearn.preprocessing). If None, a
        ``sklearn.preprocessing.MinMaxScaler`` is created.
    :return: the database itself, to allow method chaining. An empty
        database is returned untouched (no scaler is fitted or stored).
    """
    if not self._pairs:
        return self

    if scaler is None:
        # Import lazily so that scikit-learn is required only when the
        # default scaler is actually needed, not when the caller supplies
        # their own scaling object.
        from sklearn.preprocessing import MinMaxScaler

        scaler = MinMaxScaler()

    normalized_snaps = scaler.fit_transform(self.snapshots_matrix)

    for i, (_, snapshot) in enumerate(self._pairs):
        # Each matrix row is flat; restore the snapshot's original
        # (possibly multidimensional) shape before storing it.
        original_shape = snapshot.values.shape
        snapshot.values = normalized_snaps[i].reshape(original_shape)

    self.scaler_snapshots = scaler
    return self
164211
165212 def split (self , chunks , seed = None ):
166213 """
@@ -209,7 +256,7 @@ def split(self, chunks, seed=None):
209256
210257 else :
211258 logger .error ("Invalid chunk type" )
212- ValueError
259+ raise TypeError ( f"Invalid chunk type. Expected a list of integers or floats, but got { type ( chunks ) } ." )
213260
214261 new_database = [Database () for _ in range (len (chunks ))]
215262 for i , chunk in enumerate (chunks ):
@@ -235,4 +282,4 @@ def get_snapshot_space(self, index):
235282 """
236283 if index < 0 or index >= len (self ._pairs ):
237284 raise IndexError ("Snapshot index out of range." )
238- return self ._pairs [index ][1 ].space
285+ return self ._pairs [index ][1 ].space
0 commit comments