Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

- (General) `visualize_report` helper function to open reports with the Khiops Visualization and
Khiops Co-Visualization app.
- (`sklearn`) Support for `float` and `boolean` targets in `KhiopsClassifier`.

## 10.2.3.1 - 2024-11-27

Expand Down
35 changes: 30 additions & 5 deletions khiops/sklearn/estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ def _check_categorical_target_type(ds):
or pd.api.types.is_string_dtype(ds.target_column.dtype)
or pd.api.types.is_integer_dtype(ds.target_column.dtype)
or pd.api.types.is_float_dtype(ds.target_column.dtype)
or pd.api.types.is_bool_dtype(ds.target_column.dtype)
):
raise ValueError(
f"'y' has invalid type '{ds.target_column_type}'. "
Expand Down Expand Up @@ -2093,6 +2094,24 @@ def _is_real_target_dtype_integer(self):
)
)

def _is_real_target_dtype_float(self):
return self._original_target_dtype is not None and (
pd.api.types.is_float_dtype(self._original_target_dtype)
or (
isinstance(self._original_target_dtype, pd.CategoricalDtype)
and pd.api.types.is_float_dtype(self._original_target_dtype.categories)
)
)

def _is_real_target_dtype_bool(self):
return self._original_target_dtype is not None and (
pd.api.types.is_bool_dtype(self._original_target_dtype)
or (
isinstance(self._original_target_dtype, pd.CategoricalDtype)
and pd.api.types.is_bool_dtype(self._original_target_dtype.categories)
)
)

def _sorted_prob_variable_names(self):
"""Returns the model probability variable names in the order of self.classes_"""
assert self.is_fitted_, "Model not fit yet"
Expand Down Expand Up @@ -2195,11 +2214,15 @@ def _fit_training_post_process(self, ds):
for key in variable.meta_data.keys:
if key.startswith("TargetProb"):
self.classes_.append(variable.meta_data.get_value(key))
if ds.is_in_memory and self._is_real_target_dtype_integer():
self.classes_ = [int(class_value) for class_value in self.classes_]
if ds.is_in_memory:
if self._is_real_target_dtype_integer():
self.classes_ = [int(class_value) for class_value in self.classes_]
elif self._is_real_target_dtype_float():
self.classes_ = [float(class_value) for class_value in self.classes_]
elif self._is_real_target_dtype_bool():
self.classes_ = [class_value == "True" for class_value in self.classes_]
self.classes_.sort()
self.classes_ = column_or_1d(self.classes_)

# Count number of classes
self.n_classes_ = len(self.classes_)

Expand Down Expand Up @@ -2259,13 +2282,11 @@ def predict(self, X):
"""
# Call the parent's method
y_pred = super().predict(X)

# Adjust the data type according to the original target type
# Note: String is coerced explictly because astype does not work as expected
if isinstance(y_pred, pd.DataFrame):
# Transform to numpy.ndarray
y_pred = y_pred.to_numpy(copy=False).ravel()

# If integer and string just transform
if pd.api.types.is_integer_dtype(self._original_target_dtype):
y_pred = y_pred.astype(self._original_target_dtype)
Expand All @@ -2275,6 +2296,10 @@ def predict(self, X):
self._original_target_dtype
):
y_pred = y_pred.astype(str, copy=False)
elif pd.api.types.is_float_dtype(self._original_target_dtype):
y_pred = y_pred.astype(float, copy=False)
elif pd.api.types.is_bool_dtype(self._original_target_dtype):
y_pred = y_pred.astype(bool, copy=False)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't work. A fix

y_pred = y_pred.map(lambda val: val == "True")

# If category first coerce the type to the categories' type
else:
assert isinstance(self._original_target_dtype, pd.CategoricalDtype), (
Expand Down
7 changes: 6 additions & 1 deletion tests/test_sklearn_output_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,12 @@ def test_classifier_output_types(self):
"""Test the KhiopsClassifier output types and classes of predict* methods"""
X, y = create_iris()
X_mt, X_sec_mt, _ = create_iris_mt()

fixtures = {
"ys": {
"int": y,
"int binary": y.replace({0: 0, 1: 0, 2: 1}),
"float": y.astype(float),
"bool": y.replace({0: True, 1: True, 2: False}),
"string": y.replace({0: "se", 1: "vi", 2: "ve"}),
"string binary": y.replace({0: "vi_or_se", 1: "vi_or_se", 2: "ve"}),
"int as string": y.replace({0: "8", 1: "9", 2: "10"}),
Expand All @@ -69,6 +70,8 @@ def test_classifier_output_types(self):
"y_type_check": {
"int": pd.api.types.is_integer_dtype,
"int binary": pd.api.types.is_integer_dtype,
"float": pd.api.types.is_float_dtype,
"bool": pd.api.types.is_bool_dtype,
"string": pd.api.types.is_string_dtype,
"string binary": pd.api.types.is_string_dtype,
"int as string": pd.api.types.is_string_dtype,
Expand All @@ -79,6 +82,8 @@ def test_classifier_output_types(self):
"expected_classes": {
"int": column_or_1d([0, 1, 2]),
"int binary": column_or_1d([0, 1]),
"float": column_or_1d([0.0, 1.0, 2.0]),
"bool": column_or_1d([False, True]),
"string": column_or_1d(["se", "ve", "vi"]),
"string binary": column_or_1d(["ve", "vi_or_se"]),
"int as string": column_or_1d(["10", "8", "9"]),
Expand Down