Skip to content

Commit a17f175

Browse files
committed
Fix imputer string dtype encoding
1 parent 8cf69e2 commit a17f175

1 file changed

Lines changed: 13 additions & 4 deletions

File tree

toad/impute.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
import numpy as np
22
import pandas as pd
3-
from pandas.api.types import is_numeric_dtype
43
from sklearn.experimental import enable_iterative_imputer
54
from sklearn.impute import IterativeImputer
65
from sklearn.ensemble import RandomForestRegressor
7-
from sklearn.preprocessing import LabelEncoder
86

97

108

@@ -62,12 +60,18 @@ def _fit_encode(self, X, mask):
6260
X (DataFrame)
6361
mask (Mask): empty mask for X
6462
"""
63+
X = X.copy()
6564
category_data = X.select_dtypes(exclude = np.number).columns
6665

6766
for col in category_data:
68-
unique, X[col].loc[~mask[col]] = np.unique(X[col][~mask[col]], return_inverse = True)
67+
valid_mask = ~mask[col]
68+
values = X.loc[valid_mask, col].to_numpy(dtype = object)
69+
unique, encoded = np.unique(values, return_inverse = True)
6970

7071
self.encoder_dict[col] = unique
72+
encoded_col = pd.Series(np.nan, index = X.index, dtype = float)
73+
encoded_col.loc[valid_mask] = encoded.astype(float)
74+
X[col] = encoded_col
7175

7276
return X
7377

@@ -78,9 +82,14 @@ def _encode(self, X, mask):
7882
X (DataFrame)
7983
mask (Mask): empty mask for X
8084
"""
85+
X = X.copy()
8186
for col, unique in self.encoder_dict.items():
87+
valid_mask = ~mask[col]
8288
table = dict(zip(unique, np.arange(len(unique))))
83-
X[col].loc[~mask[col]] = np.array([table[v] for v in X[col][~mask[col]]])
89+
encoded_col = pd.Series(np.nan, index = X.index, dtype = float)
90+
values = X.loc[valid_mask, col].to_numpy(dtype = object)
91+
encoded_col.loc[valid_mask] = np.array([table[v] for v in values], dtype = float)
92+
X[col] = encoded_col
8493

8594
return X
8695

0 commit comments

Comments
 (0)