diff --git a/CHANGELOG.md b/CHANGELOG.md
index afc8aefb..74038f34 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -6,6 +6,11 @@
   - Example: 10.2.1.4 is the 5th version that supports khiops 10.2.1.
 - Internals: Changes in *Internals* sections are unlikely to be of interest for data scientists.
 
+## Unreleased
+
+### Added
+- (`sklearn`) Support for boolean and float targets in `KhiopsClassifier`.
+
 ## 10.3.0.0 - 2025-02-10
 
 ### Fixed
diff --git a/khiops/sklearn/dataset.py b/khiops/sklearn/dataset.py
index cc50d66c..03cd3f10 100644
--- a/khiops/sklearn/dataset.py
+++ b/khiops/sklearn/dataset.py
@@ -14,6 +14,7 @@
 
 import numpy as np
 import pandas as pd
+import sklearn
 from scipy import sparse as sp
 from sklearn.utils import check_array
 from sklearn.utils.validation import column_or_1d
@@ -430,6 +431,25 @@ def write_internal_data_table(dataframe, file_path_or_stream):
     )
 
 
+def _column_or_1d_with_dtype(y, dtype=None):
+    """Calls `column_or_1d` on ``y``, passing ``dtype`` only if it is supported"""
+    # 'dtype' has been introduced on `column_or_1d` since Scikit-learn 1.2.
+    # Compare numerically: as plain strings, "1.10" would sort before "1.2".
+    sklearn_version = tuple(int(part) for part in sklearn.__version__.split(".")[:2])
+    if sklearn_version < (1, 2):
+        # `y` may be a NumPy array, which has no `isin` method
+        if (
+            pd.api.types.is_string_dtype(dtype)
+            and np.isin(np.asarray(y), ["True", "False"]).all()
+        ):
+            warnings.warn(
+                "'y' stores strings restricted to 'True'/'False' values: "
+                "The predict method may return a bool vector."
+            )
+        return column_or_1d(y, warn=True)
+    return column_or_1d(y, warn=True, dtype=dtype)
+
+
 class Dataset:
     """A representation of a dataset
 
@@ -738,8 +758,25 @@ def _init_target_column(self, y):
         if isinstance(y, str):
             y_checked = y
         else:
-            y_checked = column_or_1d(y, warn=True)
-
+            # Objects exposing a single `dtype` (e.g. Series or NumPy arrays)
+            if hasattr(y, "dtype"):
+                if isinstance(y.dtype, pd.CategoricalDtype):
+                    y_checked = _column_or_1d_with_dtype(
+                        y, dtype=y.dtype.categories.dtype
+                    )
+                else:
+                    y_checked = _column_or_1d_with_dtype(y, dtype=y.dtype)
+            elif hasattr(y, "dtypes"):
+                # Objects exposing only `dtypes` (e.g. dataframes): read the first
+                # entry positionally, since `dtypes[0]` is a label-based lookup
+                if isinstance(y.dtypes.iloc[0], pd.CategoricalDtype):
+                    y_checked = _column_or_1d_with_dtype(
+                        y, dtype=y.dtypes.iloc[0].categories.dtype
+                    )
+                else:
+                    y_checked = _column_or_1d_with_dtype(y)
+            else:
+                y_checked = _column_or_1d_with_dtype(y)
         # Check the target type coherence with those of X's tables
         if isinstance(
             self.main_table, (PandasTable, SparseTable, NumpyTable)
diff --git a/khiops/sklearn/estimators.py b/khiops/sklearn/estimators.py
index 470afef4..214b3d9e 100644
--- a/khiops/sklearn/estimators.py
+++ b/khiops/sklearn/estimators.py
@@ -154,6 +154,7 @@ def _check_categorical_target_type(ds):
         or pd.api.types.is_string_dtype(ds.target_column.dtype)
         or pd.api.types.is_integer_dtype(ds.target_column.dtype)
         or pd.api.types.is_float_dtype(ds.target_column.dtype)
+        or pd.api.types.is_bool_dtype(ds.target_column.dtype)
     ):
         raise ValueError(
             f"'y' has invalid type '{ds.target_column_type}'. "
@@ -2123,6 +2124,24 @@ def _is_real_target_dtype_integer(self):
             )
         )
 
+    def _is_real_target_dtype_float(self):
+        return self._original_target_dtype is not None and (
+            pd.api.types.is_float_dtype(self._original_target_dtype)
+            or (
+                isinstance(self._original_target_dtype, pd.CategoricalDtype)
+                and pd.api.types.is_float_dtype(self._original_target_dtype.categories)
+            )
+        )
+
+    def _is_real_target_dtype_bool(self):
+        return self._original_target_dtype is not None and (
+            pd.api.types.is_bool_dtype(self._original_target_dtype)
+            or (
+                isinstance(self._original_target_dtype, pd.CategoricalDtype)
+                and pd.api.types.is_bool_dtype(self._original_target_dtype.categories)
+            )
+        )
+
     def _sorted_prob_variable_names(self):
         """Returns the model probability variable names in the order of self.classes_"""
         self._assert_is_fitted()
@@ -2227,8 +2246,13 @@ def _fit_training_post_process(self, ds):
             for key in variable.meta_data.keys:
                 if key.startswith("TargetProb"):
                     self.classes_.append(variable.meta_data.get_value(key))
-        if ds.is_in_memory and self._is_real_target_dtype_integer():
-            self.classes_ = [int(class_value) for class_value in self.classes_]
+        if ds.is_in_memory:
+            if self._is_real_target_dtype_integer():
+                self.classes_ = [int(class_value) for class_value in self.classes_]
+            elif self._is_real_target_dtype_float():
+                self.classes_ = [float(class_value) for class_value in self.classes_]
+            elif self._is_real_target_dtype_bool():
+                self.classes_ = [class_value == "True" for class_value in self.classes_]
         self.classes_.sort()
         self.classes_ = column_or_1d(self.classes_)
 
@@ -2283,9 +2307,10 @@ def predict(self, X):
         -------
         `ndarray <numpy.ndarray>`
             An array containing the encoded columns. A first column containing key
-            column ids is added in multi-table mode. The `numpy.dtype` of the array is
-            integer if the classifier was learned with an integer ``y``. Otherwise it
-            will be ``str``.
+            column ids is added in multi-table mode. The `numpy.dtype` of the array
+            matches the type of ``y`` used during training. It will be integer, float,
+            or boolean if the classifier was trained with a ``y`` of the corresponding
+            type. Otherwise it will be ``str``.
 
             The key columns are added for multi-table tasks.
         """
diff --git a/tests/test_sklearn_output_types.py b/tests/test_sklearn_output_types.py
index 572f378e..5d75e3af 100644
--- a/tests/test_sklearn_output_types.py
+++ b/tests/test_sklearn_output_types.py
@@ -71,12 +71,12 @@ def test_classifier_output_types(self):
         khc = KhiopsClassifier(n_trees=0)
         khc.fit(X, y)
         y_pred = khc.predict(X)
+        khc.fit(X_mt, y)
+        y_mt_pred = khc.predict(X_mt)
+
         y_bin = y.replace({0: 0, 1: 0, 2: 1})
         khc.fit(X, y_bin)
         y_bin_pred = khc.predict(X)
-        khc.fit(X_mt, y)
-        khc.export_report_file("report.khj")
-        y_mt_pred = khc.predict(X_mt)
         khc.fit(X_mt, y_bin)
         y_mt_bin_pred = khc.predict(X_mt)
 
@@ -85,6 +85,8 @@ def test_classifier_output_types(self):
             "ys": {
                 "int": y,
                 "int binary": y_bin,
+                "float": y.astype(float),
+                "bool": y.replace({0: True, 1: True, 2: False}),
                 "string": self._replace(y, {0: "se", 1: "vi", 2: "ve"}),
                 "string binary": self._replace(y_bin, {0: "vi_or_se", 1: "ve"}),
                 "int as string": self._replace(y, {0: "8", 1: "9", 2: "10"}),
@@ -93,30 +95,42 @@ def test_classifier_output_types(self):
                 "cat string": pd.Series(
                     self._replace(y, {0: "se", 1: "vi", 2: "ve"})
                 ).astype("category"),
+                "cat float": y.astype(float).astype("category"),
+                "cat bool": y.replace({0: True, 1: True, 2: False}).astype("category"),
             },
             "y_type_check": {
                 "int": pd.api.types.is_integer_dtype,
                 "int binary": pd.api.types.is_integer_dtype,
+                "float": pd.api.types.is_float_dtype,
+                "bool": pd.api.types.is_bool_dtype,
                 "string": pd.api.types.is_string_dtype,
                 "string binary": pd.api.types.is_string_dtype,
                 "int as string": pd.api.types.is_string_dtype,
                 "int as string binary": pd.api.types.is_string_dtype,
                 "cat int": pd.api.types.is_integer_dtype,
                 "cat string": pd.api.types.is_string_dtype,
+                "cat float": pd.api.types.is_float_dtype,
+                "cat bool": pd.api.types.is_bool_dtype,
             },
             "expected_classes": {
                 "int": column_or_1d([0, 1, 2]),
                 "int binary": column_or_1d([0, 1]),
+                "float": column_or_1d([0.0, 1.0, 2.0]),
+                "bool": column_or_1d([False, True]),
                 "string": column_or_1d(["se", "ve", "vi"]),
                 "string binary": column_or_1d(["ve", "vi_or_se"]),
                 "int as string": column_or_1d(["10", "8", "9"]),
                 "int as string binary": column_or_1d(["10", "89"]),
                 "cat int": column_or_1d([0, 1, 2]),
                 "cat string": column_or_1d(["se", "ve", "vi"]),
+                "cat float": column_or_1d([0.0, 1.0, 2.0]),
+                "cat bool": column_or_1d([False, True]),
             },
             "expected_y_preds": {
                 "mono": {
                     "int": y_pred,
+                    "float": y_pred.astype(float),
+                    "bool": self._replace(y_bin_pred, {0: True, 1: False}),
                     "int binary": y_bin_pred,
                     "string": self._replace(y_pred, {0: "se", 1: "vi", 2: "ve"}),
                     "string binary": self._replace(
@@ -128,9 +142,15 @@ def test_classifier_output_types(self):
                     ),
                     "cat int": y_pred,
                     "cat string": self._replace(y_pred, {0: "se", 1: "vi", 2: "ve"}),
+                    "cat float": self._replace(
+                        y_pred, {target: float(target) for target in (0, 1, 2)}
+                    ),
+                    "cat bool": self._replace(y_bin_pred, {0: True, 1: False}),
                 },
                 "multi": {
                     "int": y_mt_pred,
+                    "float": y_mt_pred.astype(float),
+                    "bool": self._replace(y_mt_bin_pred, {0: True, 1: False}),
                     "int binary": y_mt_bin_pred,
                     "string": self._replace(y_mt_pred, {0: "se", 1: "vi", 2: "ve"}),
                     "string binary": self._replace(
@@ -144,6 +164,10 @@ def test_classifier_output_types(self):
                     ),
                     "cat int": y_mt_pred,
                     "cat string": self._replace(y_mt_pred, {0: "se", 1: "vi", 2: "ve"}),
+                    "cat float": self._replace(
+                        y_mt_pred, {target: float(target) for target in (0, 1, 2)}
+                    ),
+                    "cat bool": self._replace(y_mt_bin_pred, {0: True, 1: False}),
                 },
             },
             "Xs": {