Skip to content

Commit c7fbcee

Browse files
authored
support float and boolean targets in KhiopsClassifier (#375)
1 parent a9a2c5f commit c7fbcee

File tree

4 files changed

+91
-9
lines changed

4 files changed

+91
-9
lines changed

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,11 @@
66
- Example: 10.2.1.4 is the 5th version that supports khiops 10.2.1.
77
- Internals: Changes in *Internals* sections are unlikely to be of interest for data scientists.
88

9+
## Unreleased
10+
11+
### Added
12+
- (`sklearn`) Support for boolean and float targets in `KhiopsClassifier`.
13+
914
## 10.2.4.0 - 2024-12-19
1015

1116
### Added

khiops/sklearn/dataset.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,13 @@
77
"""Classes for handling diverse data tables"""
88
import csv
99
import io
10+
import warnings
1011
from abc import ABC, abstractmethod
1112
from collections.abc import Iterable, Mapping, Sequence
1213

1314
import numpy as np
1415
import pandas as pd
16+
import sklearn
1517
from scipy import sparse as sp
1618
from sklearn.utils import check_array
1719
from sklearn.utils.validation import column_or_1d
@@ -422,6 +424,19 @@ def write_internal_data_table(dataframe, file_path_or_stream):
422424
)
423425

424426

427+
def _column_or_1d_with_dtype(y, dtype=None):
428+
# 'dtype' has been introduced on `column_or_1d' since Scikit-learn 1.2;
429+
if sklearn.__version__ < "1.2":
430+
if pd.api.types.is_string_dtype(dtype) and y.isin(["True", "False"]).all():
431+
warnings.warn(
432+
"'y' stores strings restricted to 'True'/'False' values: "
433+
"The predict method may return a bool vector."
434+
)
435+
return column_or_1d(y, warn=True)
436+
else:
437+
return column_or_1d(y, warn=True, dtype=dtype)
438+
439+
425440
class Dataset:
426441
"""A representation of a dataset
427442
@@ -618,8 +633,22 @@ def _init_target_column(self, y):
618633
y_checked = y
619634
# pandas.Series, pandas.DataFrame or numpy.ndarray
620635
else:
621-
y_checked = column_or_1d(y, warn=True)
622-
636+
if hasattr(y, "dtype"):
637+
if isinstance(y.dtype, pd.CategoricalDtype):
638+
y_checked = _column_or_1d_with_dtype(
639+
y, dtype=y.dtype.categories.dtype
640+
)
641+
else:
642+
y_checked = _column_or_1d_with_dtype(y, dtype=y.dtype)
643+
elif hasattr(y, "dtypes"):
644+
if isinstance(y.dtypes[0], pd.CategoricalDtype):
645+
y_checked = _column_or_1d_with_dtype(
646+
y, dtype=y.dtypes[0].categories.dtype
647+
)
648+
else:
649+
y_checked = _column_or_1d_with_dtype(y)
650+
else:
651+
y_checked = _column_or_1d_with_dtype(y)
623652
# Check the target type coherence with those of X's tables
624653
if isinstance(
625654
self.main_table, (PandasTable, SparseTable, NumpyTable)

khiops/sklearn/estimators.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ def _check_categorical_target_type(ds):
148148
or pd.api.types.is_string_dtype(ds.target_column.dtype)
149149
or pd.api.types.is_integer_dtype(ds.target_column.dtype)
150150
or pd.api.types.is_float_dtype(ds.target_column.dtype)
151+
or pd.api.types.is_bool_dtype(ds.target_column.dtype)
151152
):
152153
raise ValueError(
153154
f"'y' has invalid type '{ds.target_column_type}'. "
@@ -1856,6 +1857,24 @@ def _is_real_target_dtype_integer(self):
18561857
)
18571858
)
18581859

1860+
def _is_real_target_dtype_float(self):
1861+
return self._original_target_dtype is not None and (
1862+
pd.api.types.is_float_dtype(self._original_target_dtype)
1863+
or (
1864+
isinstance(self._original_target_dtype, pd.CategoricalDtype)
1865+
and pd.api.types.is_float_dtype(self._original_target_dtype.categories)
1866+
)
1867+
)
1868+
1869+
def _is_real_target_dtype_bool(self):
1870+
return self._original_target_dtype is not None and (
1871+
pd.api.types.is_bool_dtype(self._original_target_dtype)
1872+
or (
1873+
isinstance(self._original_target_dtype, pd.CategoricalDtype)
1874+
and pd.api.types.is_bool_dtype(self._original_target_dtype.categories)
1875+
)
1876+
)
1877+
18591878
def _sorted_prob_variable_names(self):
18601879
"""Returns the model probability variable names in the order of self.classes_"""
18611880
self._assert_is_fitted()
@@ -1949,7 +1968,11 @@ def _fit_training_post_process(self, ds):
19491968
self.classes_.append(variable.meta_data.get_value(key))
19501969
if self._is_real_target_dtype_integer():
19511970
self.classes_ = [int(class_value) for class_value in self.classes_]
1952-
self.classes_.sort()
1971+
elif self._is_real_target_dtype_float():
1972+
self.classes_ = [float(class_value) for class_value in self.classes_]
1973+
elif self._is_real_target_dtype_bool():
1974+
self.classes_ = [class_value == "True" for class_value in self.classes_]
1975+
self.classes_.sort()
19531976
self.classes_ = column_or_1d(self.classes_)
19541977

19551978
# Count number of classes
@@ -1996,9 +2019,10 @@ def predict(self, X):
19962019
-------
19972020
`ndarray <numpy.ndarray>`
19982021
An array containing the encoded columns. A first column containing key
1999-
column ids is added in multi-table mode. The `numpy.dtype` of the array is
2000-
integer if the classifier was learned with an integer ``y``. Otherwise it
2001-
will be ``str``.
2022+
column ids is added in multi-table mode. The `numpy.dtype` of the array
2023+
matches the type of ``y`` used during training. It will be integer, float,
2024+
or boolean if the classifier was trained with a ``y`` of the corresponding
2025+
type. Otherwise it will be ``str``.
20022026
20032027
The key columns are added for multi-table tasks.
20042028
"""

tests/test_sklearn_output_types.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,12 +71,12 @@ def test_classifier_output_types(self):
7171
khc = KhiopsClassifier(n_trees=0)
7272
khc.fit(X, y)
7373
y_pred = khc.predict(X)
74+
khc.fit(X_mt, y)
75+
y_mt_pred = khc.predict(X_mt)
76+
7477
y_bin = y.replace({0: 0, 1: 0, 2: 1})
7578
khc.fit(X, y_bin)
7679
y_bin_pred = khc.predict(X)
77-
khc.fit(X_mt, y)
78-
khc.export_report_file("report.khj")
79-
y_mt_pred = khc.predict(X_mt)
8080
khc.fit(X_mt, y_bin)
8181
y_mt_bin_pred = khc.predict(X_mt)
8282

@@ -85,6 +85,8 @@ def test_classifier_output_types(self):
8585
"ys": {
8686
"int": y,
8787
"int binary": y_bin,
88+
"float": y.astype(float),
89+
"bool": y.replace({0: True, 1: True, 2: False}),
8890
"string": self._replace(y, {0: "se", 1: "vi", 2: "ve"}),
8991
"string binary": self._replace(y_bin, {0: "vi_or_se", 1: "ve"}),
9092
"int as string": self._replace(y, {0: "8", 1: "9", 2: "10"}),
@@ -93,30 +95,42 @@ def test_classifier_output_types(self):
9395
"cat string": pd.Series(
9496
self._replace(y, {0: "se", 1: "vi", 2: "ve"})
9597
).astype("category"),
98+
"cat float": y.astype(float).astype("category"),
99+
"cat bool": y.replace({0: True, 1: True, 2: False}).astype("category"),
96100
},
97101
"y_type_check": {
98102
"int": pd.api.types.is_integer_dtype,
99103
"int binary": pd.api.types.is_integer_dtype,
104+
"float": pd.api.types.is_float_dtype,
105+
"bool": pd.api.types.is_bool_dtype,
100106
"string": pd.api.types.is_string_dtype,
101107
"string binary": pd.api.types.is_string_dtype,
102108
"int as string": pd.api.types.is_string_dtype,
103109
"int as string binary": pd.api.types.is_string_dtype,
104110
"cat int": pd.api.types.is_integer_dtype,
105111
"cat string": pd.api.types.is_string_dtype,
112+
"cat float": pd.api.types.is_float_dtype,
113+
"cat bool": pd.api.types.is_bool_dtype,
106114
},
107115
"expected_classes": {
108116
"int": column_or_1d([0, 1, 2]),
109117
"int binary": column_or_1d([0, 1]),
118+
"float": column_or_1d([0.0, 1.0, 2.0]),
119+
"bool": column_or_1d([False, True]),
110120
"string": column_or_1d(["se", "ve", "vi"]),
111121
"string binary": column_or_1d(["ve", "vi_or_se"]),
112122
"int as string": column_or_1d(["10", "8", "9"]),
113123
"int as string binary": column_or_1d(["10", "89"]),
114124
"cat int": column_or_1d([0, 1, 2]),
115125
"cat string": column_or_1d(["se", "ve", "vi"]),
126+
"cat float": column_or_1d([0.0, 1.0, 2.0]),
127+
"cat bool": column_or_1d([False, True]),
116128
},
117129
"expected_y_preds": {
118130
"mono": {
119131
"int": y_pred,
132+
"float": y_pred.astype(float),
133+
"bool": self._replace(y_bin_pred, {0: True, 1: False}),
120134
"int binary": y_bin_pred,
121135
"string": self._replace(y_pred, {0: "se", 1: "vi", 2: "ve"}),
122136
"string binary": self._replace(
@@ -128,9 +142,15 @@ def test_classifier_output_types(self):
128142
),
129143
"cat int": y_pred,
130144
"cat string": self._replace(y_pred, {0: "se", 1: "vi", 2: "ve"}),
145+
"cat float": self._replace(
146+
y_pred, {target: float(target) for target in (0, 1, 2)}
147+
),
148+
"cat bool": self._replace(y_bin_pred, {0: True, 1: False}),
131149
},
132150
"multi": {
133151
"int": y_mt_pred,
152+
"float": y_mt_pred.astype(float),
153+
"bool": self._replace(y_mt_bin_pred, {0: True, 1: False}),
134154
"int binary": y_mt_bin_pred,
135155
"string": self._replace(y_mt_pred, {0: "se", 1: "vi", 2: "ve"}),
136156
"string binary": self._replace(
@@ -144,6 +164,10 @@ def test_classifier_output_types(self):
144164
),
145165
"cat int": y_mt_pred,
146166
"cat string": self._replace(y_mt_pred, {0: "se", 1: "vi", 2: "ve"}),
167+
"cat float": self._replace(
168+
y_mt_pred, {target: float(target) for target in (0, 1, 2)}
169+
),
170+
"cat bool": self._replace(y_mt_bin_pred, {0: True, 1: False}),
147171
},
148172
},
149173
"Xs": {

0 commit comments

Comments
 (0)