Skip to content

Commit 5a05269

Browse files
author
kshitij-maths
committed
feat: add normalization and improve error handling in Database
1 parent 615bd66 commit 5a05269

1 file changed

Lines changed: 55 additions & 8 deletions

File tree

ezyrb/database.py

Lines changed: 55 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,6 @@ class Database:
1515
1616
:param array_like parameters: the input parameters
1717
:param array_like snapshots: the input snapshots
18-
:param Scale scaler_parameters: the scaler for the parameters. Default
19-
is None meaning no scaling.
20-
:param Scale scaler_snapshots: the scaler for the snapshots. Default is
21-
None meaning no scaling.
2218
:param array_like space: the input spatial data
2319
2420
:Example:
@@ -46,6 +42,9 @@ def __init__(self, parameters=None, snapshots=None, space=None):
4642
)
4743
self._pairs = []
4844

45+
self.scaler_parameters = None
46+
self.scaler_snapshots = None
47+
4948
if parameters is None and snapshots is None:
5049
logger.debug("Empty database created")
5150
return
@@ -149,18 +148,66 @@ def add(self, parameter, snapshot):
149148
"""
150149
if not isinstance(parameter, Parameter):
151150
logger.error("Invalid parameter type: %s", type(parameter))
152-
raise ValueError
151+
raise TypeError(f"Expected a Parameter object, got {type(parameter)}")
153152

154153
if not isinstance(snapshot, Snapshot):
155154
logger.error("Invalid snapshot type: %s", type(snapshot))
156-
raise ValueError
155+
raise TypeError(f"Expected a Snapshot object, got {type(snapshot)}")
157156

158157
self._pairs.append((parameter, snapshot))
159158
logger.debug(
160159
"Added parameter-snapshot pair. Total pairs: %d", len(self._pairs)
161160
)
162161

163162
return self
163+
164+
def normalize_parameters(self, scaler=None):
165+
"""
166+
Normalize the parameters in the database.
167+
168+
:param scaler: A scaling object (e.g., from sklearn.preprocessing).
169+
If None, it defaults to a MinMaxScaler.
170+
"""
171+
if len(self._pairs) == 0:
172+
return self
173+
174+
from sklearn.preprocessing import MinMaxScaler
175+
if scaler is None:
176+
scaler = MinMaxScaler()
177+
178+
params = self.parameters_matrix
179+
normalized_params = scaler.fit_transform(params)
180+
181+
for i, pair in enumerate(self._pairs):
182+
pair[0].values = normalized_params[i]
183+
184+
self.scaler_parameters = scaler
185+
return self
186+
187+
188+
def normalize_snapshots(self, scaler=None):
189+
"""
190+
Normalize the snapshots in the database.
191+
192+
:param scaler: A scaling object (e.g., from sklearn.preprocessing).
193+
If None, it defaults to a MinMaxScaler.
194+
"""
195+
if len(self._pairs) == 0:
196+
return self
197+
198+
from sklearn.preprocessing import MinMaxScaler
199+
if scaler is None:
200+
scaler = MinMaxScaler()
201+
202+
snaps = self.snapshots_matrix
203+
normalized_snaps = scaler.fit_transform(snaps)
204+
205+
for i, pair in enumerate(self._pairs):
206+
# reshape the flat array back to its original multidimensional shape
207+
pair[1].values = normalized_snaps[i].reshape(pair[1].values.shape)
208+
209+
self.scaler_snapshots = scaler
210+
return self
164211

165212
def split(self, chunks, seed=None):
166213
"""
@@ -209,7 +256,7 @@ def split(self, chunks, seed=None):
209256

210257
else:
211258
logger.error("Invalid chunk type")
212-
ValueError
259+
raise TypeError(f"Invalid chunk type. Expected a list of integers or floats, but got {type(chunks)}.")
213260

214261
new_database = [Database() for _ in range(len(chunks))]
215262
for i, chunk in enumerate(chunks):
@@ -235,4 +282,4 @@ def get_snapshot_space(self, index):
235282
"""
236283
if index < 0 or index >= len(self._pairs):
237284
raise IndexError("Snapshot index out of range.")
238-
return self._pairs[index][1].space
285+
return self._pairs[index][1].space

0 commit comments

Comments
 (0)