Skip to content

Commit 8b5328e

Browse files
committed
Support float and boolean targets in KhiopsClassifier
1 parent df90b0e commit 8b5328e

File tree

3 files changed

+37
-6
lines changed

3 files changed

+37
-6
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
- (General) `visualize_report` helper function to open reports with the Khiops Visualization and
1414
Khiops Co-Visualization app.
15+
- (`sklearn`) Support for `float` and `boolean` targets in `KhiopsClassifier`.
1516

1617
## 10.2.3.1 - 2024-11-27
1718

khiops/sklearn/estimators.py

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,7 @@ def _check_categorical_target_type(ds):
153153
or pd.api.types.is_string_dtype(ds.target_column.dtype)
154154
or pd.api.types.is_integer_dtype(ds.target_column.dtype)
155155
or pd.api.types.is_float_dtype(ds.target_column.dtype)
156+
or pd.api.types.is_bool_dtype(ds.target_column.dtype)
156157
):
157158
raise ValueError(
158159
f"'y' has invalid type '{ds.target_column_type}'. "
@@ -2093,6 +2094,24 @@ def _is_real_target_dtype_integer(self):
20932094
)
20942095
)
20952096

2097+
def _is_real_target_dtype_float(self):
2098+
return self._original_target_dtype is not None and (
2099+
pd.api.types.is_float_dtype(self._original_target_dtype)
2100+
or (
2101+
isinstance(self._original_target_dtype, pd.CategoricalDtype)
2102+
and pd.api.types.is_float_dtype(self._original_target_dtype.categories)
2103+
)
2104+
)
2105+
2106+
def _is_real_target_dtype_bool(self):
2107+
return self._original_target_dtype is not None and (
2108+
pd.api.types.is_bool_dtype(self._original_target_dtype)
2109+
or (
2110+
isinstance(self._original_target_dtype, pd.CategoricalDtype)
2111+
and pd.api.types.is_bool_dtype(self._original_target_dtype.categories)
2112+
)
2113+
)
2114+
20962115
def _sorted_prob_variable_names(self):
20972116
"""Returns the model probability variable names in the order of self.classes_"""
20982117
assert self.is_fitted_, "Model not fit yet"
@@ -2195,11 +2214,15 @@ def _fit_training_post_process(self, ds):
21952214
for key in variable.meta_data.keys:
21962215
if key.startswith("TargetProb"):
21972216
self.classes_.append(variable.meta_data.get_value(key))
2198-
if ds.is_in_memory and self._is_real_target_dtype_integer():
2199-
self.classes_ = [int(class_value) for class_value in self.classes_]
2217+
if ds.is_in_memory:
2218+
if self._is_real_target_dtype_integer():
2219+
self.classes_ = [int(class_value) for class_value in self.classes_]
2220+
elif self._is_real_target_dtype_float():
2221+
self.classes_ = [float(class_value) for class_value in self.classes_]
2222+
elif self._is_real_target_dtype_bool():
2223+
self.classes_ = [class_value == "True" for class_value in self.classes_]
22002224
self.classes_.sort()
22012225
self.classes_ = column_or_1d(self.classes_)
2202-
22032226
# Count number of classes
22042227
self.n_classes_ = len(self.classes_)
22052228

@@ -2259,13 +2282,11 @@ def predict(self, X):
22592282
"""
22602283
# Call the parent's method
22612284
y_pred = super().predict(X)
2262-
22632285
# Adjust the data type according to the original target type
22642286
# Note: String is coerced explicitly because astype does not work as expected
22652287
if isinstance(y_pred, pd.DataFrame):
22662288
# Transform to numpy.ndarray
22672289
y_pred = y_pred.to_numpy(copy=False).ravel()
2268-
22692290
# If integer and string just transform
22702291
if pd.api.types.is_integer_dtype(self._original_target_dtype):
22712292
y_pred = y_pred.astype(self._original_target_dtype)
@@ -2275,6 +2296,10 @@ def predict(self, X):
22752296
self._original_target_dtype
22762297
):
22772298
y_pred = y_pred.astype(str, copy=False)
2299+
elif pd.api.types.is_float_dtype(self._original_target_dtype):
2300+
y_pred = y_pred.astype(float, copy=False)
2301+
elif pd.api.types.is_bool_dtype(self._original_target_dtype):
2302+
y_pred = y_pred.astype(bool, copy=False)
22782303
# If category first coerce the type to the categories' type
22792304
else:
22802305
assert isinstance(self._original_target_dtype, pd.CategoricalDtype), (

tests/test_sklearn_output_types.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,11 +54,12 @@ def test_classifier_output_types(self):
5454
"""Test the KhiopsClassifier output types and classes of predict* methods"""
5555
X, y = create_iris()
5656
X_mt, X_sec_mt, _ = create_iris_mt()
57-
5857
fixtures = {
5958
"ys": {
6059
"int": y,
6160
"int binary": y.replace({0: 0, 1: 0, 2: 1}),
61+
"float": y.astype(float),
62+
"bool": y.replace({0: True, 1: True, 2: False}),
6263
"string": y.replace({0: "se", 1: "vi", 2: "ve"}),
6364
"string binary": y.replace({0: "vi_or_se", 1: "vi_or_se", 2: "ve"}),
6465
"int as string": y.replace({0: "8", 1: "9", 2: "10"}),
@@ -69,6 +70,8 @@ def test_classifier_output_types(self):
6970
"y_type_check": {
7071
"int": pd.api.types.is_integer_dtype,
7172
"int binary": pd.api.types.is_integer_dtype,
73+
"float": pd.api.types.is_float_dtype,
74+
"bool": pd.api.types.is_bool_dtype,
7275
"string": pd.api.types.is_string_dtype,
7376
"string binary": pd.api.types.is_string_dtype,
7477
"int as string": pd.api.types.is_string_dtype,
@@ -79,6 +82,8 @@ def test_classifier_output_types(self):
7982
"expected_classes": {
8083
"int": column_or_1d([0, 1, 2]),
8184
"int binary": column_or_1d([0, 1]),
85+
"float": column_or_1d([0.0, 1.0, 2.0]),
86+
"bool": column_or_1d([False, True]),
8287
"string": column_or_1d(["se", "ve", "vi"]),
8388
"string binary": column_or_1d(["ve", "vi_or_se"]),
8489
"int as string": column_or_1d(["10", "8", "9"]),

0 commit comments

Comments
 (0)