Skip to content

Commit eeddace

Browse files
Modify CategoricalDataProcessor to avoid regrouping of dummy variables
1 parent 8d8d553 commit eeddace

1 file changed

Lines changed: 10 additions & 2 deletions

File tree

cobra/preprocessing/categorical_data_processor.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,13 @@ def _fit_column(self, data: pd.DataFrame, column_name: str,
191191

192192
unique_categories = list(X.unique())
193193

194+
# do not merge categories in case of dummies, i.e. 0 and 1
195+
# (and possibly "Missings")
196+
if (len(unique_categories) == 2
197+
or (len(unique_categories) == 3
198+
and "Missing" in unique_categories)):
199+
return set(unique_categories)
200+
194201
# get small categories and add them to the merged category list
195202
small_categories = (CategoricalDataProcessor
196203
._get_small_categories(
@@ -420,7 +427,8 @@ def _compute_p_value(X: pd.Series, y: pd.Series, category: str,
420427

421428
@staticmethod
422429
def _replace_categories(data: pd.Series, categories: set) -> pd.Series:
423-
"""replace categories in set with "Other"
430+
"""replace categories in set with "Other" and transform the remaining
431+
categories to strings to avoid type errors later on in the pipeline
424432
425433
Parameters
426434
----------
@@ -434,4 +442,4 @@ def _replace_categories(data: pd.Series, categories: set) -> pd.Series:
434442
pd.Series
435443
Description
436444
"""
437-
return data.apply(lambda x: x if x in categories else "Other")
445+
return data.apply(lambda x: str(x) if x in categories else "Other")

0 commit comments

Comments
 (0)