|
1 | 1 | import numpy as np |
| 2 | +import pandas as pd |
2 | 3 | from .batch_transformer import BatchTransformer |
3 | 4 |
|
4 | 5 |
|
@@ -315,16 +316,25 @@ def transform(self, batch): |
315 | 316 | raise KeyError(f'Error: The data passed to {type(self).__name__} is not forked, while fork parameter ' |
316 | 317 | f'is specified. Please add multiindex level to columns of your data or use DataFork ' |
317 | 318 | f'batch transform before.') |
318 | | - subset = batch[self._fork][self._cols] |
| 319 | + if self._fork not in batch.columns.get_level_values(0): |
| 320 | + raise KeyError(f"Error: fork {self._fork} specified as a parameter 'data_fork' was not found in data. " |
| 321 | + f"The following forks were found: {set(batch.columns.get_level_values(0))}. Please " |
| 322 | + f"make sure you are using DataFork that is configured to provide this a fork with the" |
| 323 | + f"name specified.") |
| 324 | + # the top level of multiindex is dropped here to avoid a hassle of handling it in methods _make_mask and |
| 325 | + # _make_augmented_version. This dropped level will be added later when merged back with the batch |
| 326 | + subset = batch[self._fork][self._cols].copy() |
319 | 327 | else: |
320 | | - subset = batch[self._cols] |
| 328 | + subset = batch[self._cols].copy() |
321 | 329 | mask = self._make_mask(subset) |
322 | 330 | augmented_batch = self._make_augmented_version(subset) |
323 | | - transformed = subset.mask(mask.astype(bool), augmented_batch.values) |
| 331 | + transformed = subset.mask(mask.astype(bool), augmented_batch) |
324 | 332 | if self._fork: |
325 | | - batch.loc(axis=1)[self._fork, self._cols] = transformed.values |
| 333 | + # in order for loc to work, the top level index must be restored |
| 334 | + transformed.columns = pd.MultiIndex.from_product([[self._fork], transformed.columns]) |
| 335 | + batch.loc[:, (self._fork, self._cols)] = transformed |
326 | 336 | else: |
327 | | - batch[self._cols] = transformed |
| 337 | + batch.loc[:, self._cols] = transformed |
328 | 338 | return batch |
329 | 339 |
|
330 | 340 | def inverse_transform(self, batch): |
|
0 commit comments