Skip to content

Commit 3d56a94

Browse files
committed
support float and boolean targets in KhiopsClassifier
1 parent 5a561ed commit 3d56a94

File tree

4 files changed

+65
-7
lines changed

4 files changed

+65
-7
lines changed

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,12 @@
66
- Example: 10.2.1.4 is the 5th version that supports khiops 10.2.1.
77
- Internals: Changes in *Internals* sections are unlikely to be of interest for data scientists.
88

9+
## 10.3.0.1
10+
11+
### Added
12+
13+
- (`sklearn`) Supports boolean and float targets in `KhiopsClassifier`.
14+
915
## 10.2.4.0 - 2024-12-19
1016

1117
### Added

khiops/sklearn/dataset.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -738,8 +738,15 @@ def _init_target_column(self, y):
738738
if isinstance(y, str):
739739
y_checked = y
740740
else:
741-
y_checked = column_or_1d(y, warn=True)
742-
741+
if hasattr(y, "dtype"):
742+
if isinstance(y.dtype, pd.CategoricalDtype):
743+
y_checked = column_or_1d(
744+
y, warn=True, dtype=y.dtype.categories.dtype
745+
)
746+
else:
747+
y_checked = column_or_1d(y, warn=True, dtype=y.dtype)
748+
else:
749+
y_checked = column_or_1d(y, warn=True)
743750
# Check the target type coherence with those of X's tables
744751
if isinstance(
745752
self.main_table, (PandasTable, SparseTable, NumpyTable)

khiops/sklearn/estimators.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,7 @@ def _check_categorical_target_type(ds):
154154
or pd.api.types.is_string_dtype(ds.target_column.dtype)
155155
or pd.api.types.is_integer_dtype(ds.target_column.dtype)
156156
or pd.api.types.is_float_dtype(ds.target_column.dtype)
157+
or pd.api.types.is_bool_dtype(ds.target_column.dtype)
157158
):
158159
raise ValueError(
159160
f"'y' has invalid type '{ds.target_column_type}'. "
@@ -2123,6 +2124,24 @@ def _is_real_target_dtype_integer(self):
21232124
)
21242125
)
21252126

2127+
def _is_real_target_dtype_float(self):
2128+
return self._original_target_dtype is not None and (
2129+
pd.api.types.is_float_dtype(self._original_target_dtype)
2130+
or (
2131+
isinstance(self._original_target_dtype, pd.CategoricalDtype)
2132+
and pd.api.types.is_float_dtype(self._original_target_dtype.categories)
2133+
)
2134+
)
2135+
2136+
def _is_real_target_dtype_bool(self):
2137+
return self._original_target_dtype is not None and (
2138+
pd.api.types.is_bool_dtype(self._original_target_dtype)
2139+
or (
2140+
isinstance(self._original_target_dtype, pd.CategoricalDtype)
2141+
and pd.api.types.is_bool_dtype(self._original_target_dtype.categories)
2142+
)
2143+
)
2144+
21262145
def _sorted_prob_variable_names(self):
21272146
"""Returns the model probability variable names in the order of self.classes_"""
21282147
self._assert_is_fitted()
@@ -2227,8 +2246,13 @@ def _fit_training_post_process(self, ds):
22272246
for key in variable.meta_data.keys:
22282247
if key.startswith("TargetProb"):
22292248
self.classes_.append(variable.meta_data.get_value(key))
2230-
if ds.is_in_memory and self._is_real_target_dtype_integer():
2231-
self.classes_ = [int(class_value) for class_value in self.classes_]
2249+
if ds.is_in_memory:
2250+
if self._is_real_target_dtype_integer():
2251+
self.classes_ = [int(class_value) for class_value in self.classes_]
2252+
elif self._is_real_target_dtype_float():
2253+
self.classes_ = [float(class_value) for class_value in self.classes_]
2254+
elif self._is_real_target_dtype_bool():
2255+
self.classes_ = [class_value == "True" for class_value in self.classes_]
22322256
self.classes_.sort()
22332257
self.classes_ = column_or_1d(self.classes_)
22342258

tests/test_sklearn_output_types.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,12 +71,12 @@ def test_classifier_output_types(self):
7171
khc = KhiopsClassifier(n_trees=0)
7272
khc.fit(X, y)
7373
y_pred = khc.predict(X)
74+
khc.fit(X_mt, y)
75+
y_mt_pred = khc.predict(X_mt)
76+
7477
y_bin = y.replace({0: 0, 1: 0, 2: 1})
7578
khc.fit(X, y_bin)
7679
y_bin_pred = khc.predict(X)
77-
khc.fit(X_mt, y)
78-
khc.export_report_file("report.khj")
79-
y_mt_pred = khc.predict(X_mt)
8080
khc.fit(X_mt, y_bin)
8181
y_mt_bin_pred = khc.predict(X_mt)
8282

@@ -85,6 +85,8 @@ def test_classifier_output_types(self):
8585
"ys": {
8686
"int": y,
8787
"int binary": y_bin,
88+
"float": y.astype(float),
89+
"bool": y.replace({0: True, 1: True, 2: False}),
8890
"string": self._replace(y, {0: "se", 1: "vi", 2: "ve"}),
8991
"string binary": self._replace(y_bin, {0: "vi_or_se", 1: "ve"}),
9092
"int as string": self._replace(y, {0: "8", 1: "9", 2: "10"}),
@@ -93,30 +95,42 @@ def test_classifier_output_types(self):
9395
"cat string": pd.Series(
9496
self._replace(y, {0: "se", 1: "vi", 2: "ve"})
9597
).astype("category"),
98+
"cat float": y.astype(float).astype("category"),
99+
"cat bool": y.replace({0: True, 1: True, 2: False}).astype("category"),
96100
},
97101
"y_type_check": {
98102
"int": pd.api.types.is_integer_dtype,
99103
"int binary": pd.api.types.is_integer_dtype,
104+
"float": pd.api.types.is_float_dtype,
105+
"bool": pd.api.types.is_bool_dtype,
100106
"string": pd.api.types.is_string_dtype,
101107
"string binary": pd.api.types.is_string_dtype,
102108
"int as string": pd.api.types.is_string_dtype,
103109
"int as string binary": pd.api.types.is_string_dtype,
104110
"cat int": pd.api.types.is_integer_dtype,
105111
"cat string": pd.api.types.is_string_dtype,
112+
"cat float": pd.api.types.is_float_dtype,
113+
"cat bool": pd.api.types.is_bool_dtype,
106114
},
107115
"expected_classes": {
108116
"int": column_or_1d([0, 1, 2]),
109117
"int binary": column_or_1d([0, 1]),
118+
"float": column_or_1d([0.0, 1.0, 2.0]),
119+
"bool": column_or_1d([False, True]),
110120
"string": column_or_1d(["se", "ve", "vi"]),
111121
"string binary": column_or_1d(["ve", "vi_or_se"]),
112122
"int as string": column_or_1d(["10", "8", "9"]),
113123
"int as string binary": column_or_1d(["10", "89"]),
114124
"cat int": column_or_1d([0, 1, 2]),
115125
"cat string": column_or_1d(["se", "ve", "vi"]),
126+
"cat float": column_or_1d([0.0, 1.0, 2.0]),
127+
"cat bool": column_or_1d([False, True]),
116128
},
117129
"expected_y_preds": {
118130
"mono": {
119131
"int": y_pred,
132+
"float": y_pred.astype(float),
133+
"bool": self._replace(y_bin_pred, {0: True, 1: False}),
120134
"int binary": y_bin_pred,
121135
"string": self._replace(y_pred, {0: "se", 1: "vi", 2: "ve"}),
122136
"string binary": self._replace(
@@ -128,9 +142,13 @@ def test_classifier_output_types(self):
128142
),
129143
"cat int": y_pred,
130144
"cat string": self._replace(y_pred, {0: "se", 1: "vi", 2: "ve"}),
145+
"cat float": self._replace(y_pred, {0: 0.0, 1: 1.0, 2: 2.0}),
146+
"cat bool": self._replace(y_bin_pred, {0: True, 1: False}),
131147
},
132148
"multi": {
133149
"int": y_mt_pred,
150+
"float": y_mt_pred.astype(float),
151+
"bool": self._replace(y_mt_bin_pred, {0: True, 1: False}),
134152
"int binary": y_mt_bin_pred,
135153
"string": self._replace(y_mt_pred, {0: "se", 1: "vi", 2: "ve"}),
136154
"string binary": self._replace(
@@ -144,6 +162,8 @@ def test_classifier_output_types(self):
144162
),
145163
"cat int": y_mt_pred,
146164
"cat string": self._replace(y_mt_pred, {0: "se", 1: "vi", 2: "ve"}),
165+
"cat float": self._replace(y_mt_pred, {0: 0.0, 1: 1.0, 2: 2.0}),
166+
"cat bool": self._replace(y_mt_bin_pred, {0: True, 1: False}),
147167
},
148168
},
149169
"Xs": {
@@ -171,6 +191,7 @@ def test_classifier_output_types(self):
171191

172192
# Check the return type of predict
173193
y_pred = khc.predict(X)
194+
174195
self.assertTrue(
175196
y_type_check(y_pred),
176197
f"'{y_type_check.__name__}' was False for "

0 commit comments

Comments
 (0)