Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@
- Example: 10.2.1.4 is the 5th version that supports khiops 10.2.1.
- Internals: Changes in *Internals* sections are unlikely to be of interest for data scientists.

## Unreleased

### Added
- (`sklearn`) Support for boolean and float targets in `KhiopsClassifier`.

## 10.3.0.0 - 2025-02-10

### Fixed
Expand Down
32 changes: 30 additions & 2 deletions khiops/sklearn/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import numpy as np
import pandas as pd
import sklearn
from scipy import sparse as sp
from sklearn.utils import check_array
from sklearn.utils.validation import column_or_1d
Expand Down Expand Up @@ -430,6 +431,19 @@ def write_internal_data_table(dataframe, file_path_or_stream):
)


def _column_or_1d_with_dtype(y, dtype=None):
# 'dtype' has been introduced on `column_or_1d' since Scikit-learn 1.2;
if sklearn.__version__ < "1.2":
if pd.api.types.is_string_dtype(dtype) and y.isin(["True", "False"]).all():
warnings.warn(
"'y' stores strings restricted to 'True'/'False' values: "
"The predict method may return a bool vector."
)
return column_or_1d(y, warn=True)
else:
return column_or_1d(y, warn=True, dtype=dtype)


class Dataset:
"""A representation of a dataset

Expand Down Expand Up @@ -738,8 +752,22 @@ def _init_target_column(self, y):
if isinstance(y, str):
y_checked = y
else:
y_checked = column_or_1d(y, warn=True)

if hasattr(y, "dtype"):
Comment thread
folmos-at-orange marked this conversation as resolved.
if isinstance(y.dtype, pd.CategoricalDtype):
y_checked = _column_or_1d_with_dtype(
y, dtype=y.dtype.categories.dtype
)
else:
y_checked = _column_or_1d_with_dtype(y, dtype=y.dtype)
elif hasattr(y, "dtypes"):
if isinstance(y.dtypes[0], pd.CategoricalDtype):
y_checked = _column_or_1d_with_dtype(
y, dtype=y.dtypes[0].categories.dtype
)
else:
y_checked = _column_or_1d_with_dtype(y)
else:
y_checked = _column_or_1d_with_dtype(y)
# Check the target type coherence with those of X's tables
if isinstance(
self.main_table, (PandasTable, SparseTable, NumpyTable)
Expand Down
35 changes: 30 additions & 5 deletions khiops/sklearn/estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ def _check_categorical_target_type(ds):
or pd.api.types.is_string_dtype(ds.target_column.dtype)
or pd.api.types.is_integer_dtype(ds.target_column.dtype)
or pd.api.types.is_float_dtype(ds.target_column.dtype)
or pd.api.types.is_bool_dtype(ds.target_column.dtype)
):
raise ValueError(
f"'y' has invalid type '{ds.target_column_type}'. "
Expand Down Expand Up @@ -2123,6 +2124,24 @@ def _is_real_target_dtype_integer(self):
)
)

def _is_real_target_dtype_float(self):
return self._original_target_dtype is not None and (
pd.api.types.is_float_dtype(self._original_target_dtype)
or (
isinstance(self._original_target_dtype, pd.CategoricalDtype)
and pd.api.types.is_float_dtype(self._original_target_dtype.categories)
)
)

def _is_real_target_dtype_bool(self):
return self._original_target_dtype is not None and (
pd.api.types.is_bool_dtype(self._original_target_dtype)
or (
isinstance(self._original_target_dtype, pd.CategoricalDtype)
and pd.api.types.is_bool_dtype(self._original_target_dtype.categories)
)
)

def _sorted_prob_variable_names(self):
"""Returns the model probability variable names in the order of self.classes_"""
self._assert_is_fitted()
Expand Down Expand Up @@ -2227,8 +2246,13 @@ def _fit_training_post_process(self, ds):
for key in variable.meta_data.keys:
if key.startswith("TargetProb"):
self.classes_.append(variable.meta_data.get_value(key))
if ds.is_in_memory and self._is_real_target_dtype_integer():
self.classes_ = [int(class_value) for class_value in self.classes_]
if ds.is_in_memory:
if self._is_real_target_dtype_integer():
self.classes_ = [int(class_value) for class_value in self.classes_]
elif self._is_real_target_dtype_float():
self.classes_ = [float(class_value) for class_value in self.classes_]
elif self._is_real_target_dtype_bool():
self.classes_ = [class_value == "True" for class_value in self.classes_]
self.classes_.sort()
self.classes_ = column_or_1d(self.classes_)

Expand Down Expand Up @@ -2283,9 +2307,10 @@ def predict(self, X):
-------
`ndarray <numpy.ndarray>`
An array containing the encoded columns. A first column containing key
column ids is added in multi-table mode. The `numpy.dtype` of the array is
integer if the classifier was learned with an integer ``y``. Otherwise it
will be ``str``.
column ids is added in multi-table mode. The `numpy.dtype` of the array
matches the type of ``y`` used during training. It will be integer, float,
or boolean if the classifier was trained with a ``y`` of the corresponding
type. Otherwise it will be ``str``.

The key columns are added for multi-table tasks.
"""
Expand Down
30 changes: 27 additions & 3 deletions tests/test_sklearn_output_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,12 @@ def test_classifier_output_types(self):
khc = KhiopsClassifier(n_trees=0)
khc.fit(X, y)
y_pred = khc.predict(X)
khc.fit(X_mt, y)
y_mt_pred = khc.predict(X_mt)

y_bin = y.replace({0: 0, 1: 0, 2: 1})
khc.fit(X, y_bin)
y_bin_pred = khc.predict(X)
khc.fit(X_mt, y)
khc.export_report_file("report.khj")
y_mt_pred = khc.predict(X_mt)
Comment thread
popescu-v marked this conversation as resolved.
khc.fit(X_mt, y_bin)
y_mt_bin_pred = khc.predict(X_mt)

Expand All @@ -85,6 +85,8 @@ def test_classifier_output_types(self):
"ys": {
"int": y,
"int binary": y_bin,
"float": y.astype(float),
"bool": y.replace({0: True, 1: True, 2: False}),
"string": self._replace(y, {0: "se", 1: "vi", 2: "ve"}),
"string binary": self._replace(y_bin, {0: "vi_or_se", 1: "ve"}),
"int as string": self._replace(y, {0: "8", 1: "9", 2: "10"}),
Expand All @@ -93,30 +95,42 @@ def test_classifier_output_types(self):
"cat string": pd.Series(
self._replace(y, {0: "se", 1: "vi", 2: "ve"})
).astype("category"),
"cat float": y.astype(float).astype("category"),
"cat bool": y.replace({0: True, 1: True, 2: False}).astype("category"),
},
"y_type_check": {
"int": pd.api.types.is_integer_dtype,
"int binary": pd.api.types.is_integer_dtype,
"float": pd.api.types.is_float_dtype,
"bool": pd.api.types.is_bool_dtype,
"string": pd.api.types.is_string_dtype,
"string binary": pd.api.types.is_string_dtype,
"int as string": pd.api.types.is_string_dtype,
"int as string binary": pd.api.types.is_string_dtype,
"cat int": pd.api.types.is_integer_dtype,
"cat string": pd.api.types.is_string_dtype,
"cat float": pd.api.types.is_float_dtype,
"cat bool": pd.api.types.is_bool_dtype,
},
"expected_classes": {
"int": column_or_1d([0, 1, 2]),
"int binary": column_or_1d([0, 1]),
"float": column_or_1d([0.0, 1.0, 2.0]),
"bool": column_or_1d([False, True]),
"string": column_or_1d(["se", "ve", "vi"]),
"string binary": column_or_1d(["ve", "vi_or_se"]),
"int as string": column_or_1d(["10", "8", "9"]),
"int as string binary": column_or_1d(["10", "89"]),
"cat int": column_or_1d([0, 1, 2]),
"cat string": column_or_1d(["se", "ve", "vi"]),
"cat float": column_or_1d([0.0, 1.0, 2.0]),
"cat bool": column_or_1d([False, True]),
},
"expected_y_preds": {
"mono": {
"int": y_pred,
"float": y_pred.astype(float),
"bool": self._replace(y_bin_pred, {0: True, 1: False}),
"int binary": y_bin_pred,
"string": self._replace(y_pred, {0: "se", 1: "vi", 2: "ve"}),
"string binary": self._replace(
Expand All @@ -128,9 +142,15 @@ def test_classifier_output_types(self):
),
"cat int": y_pred,
"cat string": self._replace(y_pred, {0: "se", 1: "vi", 2: "ve"}),
"cat float": self._replace(
y_pred, {target: float(target) for target in (0, 1, 2)}
),
"cat bool": self._replace(y_bin_pred, {0: True, 1: False}),
},
"multi": {
"int": y_mt_pred,
"float": y_mt_pred.astype(float),
"bool": self._replace(y_mt_bin_pred, {0: True, 1: False}),
"int binary": y_mt_bin_pred,
"string": self._replace(y_mt_pred, {0: "se", 1: "vi", 2: "ve"}),
"string binary": self._replace(
Expand All @@ -144,6 +164,10 @@ def test_classifier_output_types(self):
),
"cat int": y_mt_pred,
"cat string": self._replace(y_mt_pred, {0: "se", 1: "vi", 2: "ve"}),
"cat float": self._replace(
y_mt_pred, {target: float(target) for target in (0, 1, 2)}
),
"cat bool": self._replace(y_mt_bin_pred, {0: True, 1: False}),
},
},
"Xs": {
Expand Down