Skip to content

Commit f37590d

Browse files
authored
Merge pull request #106 from maxsch3/feature/105
Feature/105
2 parents 8e78061 + 509d526 commit f37590d

7 files changed

Lines changed: 52 additions & 8 deletions

File tree

build_requirements.txt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
setuptools-scm==3.5.0
2-
mkdocs==1.0.4
3-
mkdocs-jupyter==0.10.2
2+
mkdocs==1.1.2
3+
mkdocs-jupyter==0.13.0
44
notebook==6.0.2
55
pymdown-extensions==6.3
6+
nbconvert==5.6.1
67
git+https://github.com/tomchristie/mkautodoc.git#egg=mkautodoc

keras_batchflow/base/batch_shapers/var_shaper.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,19 @@
55

66
class VarShaper:
77

8+
"""
9+
This class is a wrapper around encoder. It abstracts away the encoding of a single column into a
10+
model-ready data. This class is used in BatchShaper where all leaves (a tuple (column name, encoder))
11+
in a structure are replaced with equivalent VarShaper objects in a BatchShaper's constructor method
12+
After that, BatchShaper use this equivalent structure of VarShaper objects rather than original
13+
"human-readable" structure of tuples.
14+
15+
This class is a backend class and you normally would not need to use it manually.
16+
17+
It implements core interface functions like transform, inverse_transform,
18+
as well as additional metadata functions like, shape, n_classes, dtype, etc.
19+
"""
20+
821
_dummy_constant_counter = 0
922

1023
def __init__(self, var_name, encoder, data_sample=None):
@@ -109,7 +122,6 @@ def _get_dtypes(self, sample):
109122
else:
110123
RuntimeError(f"Error: the class type {self._encoder} is not supported in '_get_dtypes' method")
111124

112-
113125
def _get_n_classes(self, encoder):
114126
"""
115127
Calculates number of classes provided by the encoder. This is required for creating embedding layer for

keras_batchflow/base/batch_transformers/base_random_cell.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,3 +326,6 @@ def transform(self, batch):
326326
else:
327327
batch[self._cols] = transformed
328328
return batch
329+
330+
def inverse_transform(self, batch):
331+
return batch

keras_batchflow/base/batch_transformers/batch_fork.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,6 @@ def _verify_levels(self, levels):
1717
def transform(self, batch):
1818
batch = pd.concat([batch]*len(self._levels), axis=1, keys=self._levels)
1919
return batch
20+
21+
def inverse_transform(self, batch):
22+
return batch
Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,17 @@
1+
from abc import ABC, abstractmethod
12

23

3-
class BatchTransformer:
4+
class BatchTransformer(ABC):
45
"""
56
This is an abstract class that defines basic functionality and interfaces of all BatchTransformers
67
"""
78
def __init__(self):
89
pass
910

11+
@abstractmethod
1012
def transform(self, batch):
11-
return batch
13+
pass
14+
15+
@abstractmethod
16+
def inverse_transform(self, batch):
17+
pass

test/test_batch_generator.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,18 @@ def transform(self, batch):
132132
batch[self.col_name] = 'Red'
133133
return batch
134134

135-
bt1 = BatchTransformer()
135+
def inverse_transform(self, batch):
136+
return batch
137+
138+
class TransparentTransform(BatchTransformer):
139+
140+
def transform(self, batch):
141+
return batch
142+
143+
def inverse_transform(self, batch):
144+
return batch
145+
146+
bt1 = TransparentTransform()
136147
bt2 = TestTransform('label')
137148
bg = BatchGenerator(
138149
self.df,

test/test_triplet_pk_generator.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,17 @@ def test_basic(self):
5454
assert batch[0].shape == (6, 1)
5555
assert batch[1].shape == (6, 1)
5656

57-
5857
def test_kwargs_pass_to_parent(self):
59-
bt = BatchTransformer()
58+
59+
class TransparentTransform(BatchTransformer):
60+
61+
def transform(self, batch):
62+
return batch
63+
64+
def inverse_transform(self, batch):
65+
return batch
66+
67+
bt = TransparentTransform()
6068
tg = TripletPKGenerator(
6169
data=self.df,
6270
triplet_label='label',

0 commit comments

Comments
 (0)