Skip to content

Commit 2d36fad

Browse files
authored
Merge pull request #115 from maxsch3/bug/114
#110 fixed and tests updated. Indices were fixed and numpy conversion…
2 parents 0276211 + 0fd3361 commit 2d36fad

3 files changed

Lines changed: 44 additions & 8 deletions

File tree

keras_batchflow/base/batch_transformers/base_random_cell.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import numpy as np
2+
import pandas as pd
23
from .batch_transformer import BatchTransformer
34

45

@@ -315,16 +316,25 @@ def transform(self, batch):
315316
raise KeyError(f'Error: The data passed to {type(self).__name__} is not forked, while fork parameter '
316317
f'is specified. Please add multiindex level to columns of your data or use DataFork '
317318
f'batch transform before.')
318-
subset = batch[self._fork][self._cols]
319+
if self._fork not in batch.columns.get_level_values(0):
320+
raise KeyError(f"Error: fork {self._fork} specified as a parameter 'data_fork' was not found in data. "
321+
f"The following forks were found: {set(batch.columns.get_level_values(0))}. Please "
322+
f"make sure you are using DataFork that is configured to provide this a fork with the"
323+
f"name specified.")
324+
# the top level of the multiindex is dropped here to avoid the hassle of handling it in methods _make_mask and
325+
# _make_augmented_version. This dropped level will be added later when merged back with the batch
326+
subset = batch[self._fork][self._cols].copy()
319327
else:
320-
subset = batch[self._cols]
328+
subset = batch[self._cols].copy()
321329
mask = self._make_mask(subset)
322330
augmented_batch = self._make_augmented_version(subset)
323-
transformed = subset.mask(mask.astype(bool), augmented_batch.values)
331+
transformed = subset.mask(mask.astype(bool), augmented_batch)
324332
if self._fork:
325-
batch.loc(axis=1)[self._fork, self._cols] = transformed.values
333+
# in order for loc to work, the top level index must be restored
334+
transformed.columns = pd.MultiIndex.from_product([[self._fork], transformed.columns])
335+
batch.loc[:, (self._fork, self._cols)] = transformed
326336
else:
327-
batch[self._cols] = transformed
337+
batch.loc[:, self._cols] = transformed
328338
return batch
329339

330340
def inverse_transform(self, batch):

keras_batchflow/base/batch_transformers/shuffle_noise.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,5 +33,5 @@ class ShuffleNoise(BaseRandomCellTransform):
3333
"""
3434

3535
def _make_augmented_version(self, batch):
36-
batch1 = batch.apply(lambda x: x.sample(frac=1).values)
37-
return batch1
36+
augmented_batch = batch.apply(lambda x: x.sample(frac=1).values)
37+
return augmented_batch

test/test_base_random_cell_transform.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import pandas as pd
33
import numpy as np
44
from scipy.stats import binom_test, chisquare
5-
from keras_batchflow.base.batch_transformers import BaseRandomCellTransform
5+
from keras_batchflow.base.batch_transformers import BaseRandomCellTransform, BatchFork
66

77

88
class LocalVersionTransform(BaseRandomCellTransform):
@@ -17,6 +17,19 @@ def _make_augmented_version(self, batch):
1717
return batch
1818

1919

20+
class TestTransformInt(BaseRandomCellTransform):
21+
"""
22+
BaseRandomCellTransform raises a NotImplemented error in the function below, which does not allow one to test
23+
transform functionality. I'm re-defining this method here to be able to test it.
24+
"""
25+
26+
def _make_augmented_version(self, batch):
27+
batch = batch.copy()
28+
for c in self._cols:
29+
batch[c] = 0
30+
return batch
31+
32+
2033
class TestFeatureDropout:
2134

2235
df = None
@@ -188,6 +201,19 @@ def test_transform_fork_many_cols(self):
188201
assert (batch1['x']['var2'] == '').all()
189202
assert (batch1['x']['label'] != '').all()
190203

204+
def test_non_numpy_dtype(self):
205+
"""
206+
This test is to make sure the transform does not convert data to numpy behind the scenes, causing
207+
unpredictable dtype changes
208+
:return:
209+
"""
210+
data = pd.DataFrame({'var1': np.random.randint(low=0, high=10, size=100)}).astype('Int64')
211+
data.iloc[0, 0] = None
212+
data_forked = BatchFork().transform(data.copy())
213+
ct = TestTransformInt([0., 1.], cols=['var1'], data_fork='x')
214+
data_transformed = ct.transform(data_forked)
215+
assert all(dt.name == 'Int64' for dt in data_transformed.dtypes)
216+
191217

192218
# def test_row_dist(self):
193219
# fd = FeatureDropout(.6, 'var1', '')

0 commit comments

Comments
 (0)