Skip to content

Commit c5a5644

Browse files
committed
support float and boolean targets in KhiopsClassifier
1 parent 7d7d3b4 commit c5a5644

File tree

4 files changed

+49
-85
lines changed

4 files changed

+49
-85
lines changed

khiops/samples/samples_sklearn.py

Lines changed: 17 additions & 40 deletions
Original file line number | Diff line number | Diff line change
@@ -88,45 +88,6 @@ def khiops_classifier():
8888
# khc.export_report_file("report.khj")
8989
# kh.visualize_report("report.khj")
9090

91-
def khiops_classifier_float_target():
92-
"""Trains a `.KhiopsClassifier` on a monotable dataframe"""
93-
# Imports
94-
import os
95-
import pandas as pd
96-
from khiops import core as kh
97-
from khiops.sklearn import KhiopsClassifier
98-
from sklearn import metrics
99-
from sklearn.model_selection import train_test_split
100-
101-
# Load the dataset into a pandas dataframe
102-
adult_path = os.path.join(kh.get_samples_dir(), "Adult", "Adult.txt")
103-
adult_df = pd.read_csv(adult_path, sep="\t")
104-
105-
# Split the whole dataframe into train and test (70%-30%)
106-
adult_train_df, adult_test_df = train_test_split(
107-
adult_df, test_size=0.3, random_state=1
108-
)
109-
110-
X_train = adult_train_df.drop("class", axis=1)
111-
X_test = adult_test_df.drop("class", axis=1)
112-
#y_train = adult_train_df["class"].replace({"less": 0.0, "more": 1.0})
113-
y_train = adult_train_df["class"].replace({"less": True, "more": False})
114-
115-
print(y_train.dtype)
116-
#y_train.replace()
117-
118-
# Create the classifier object
119-
khc = KhiopsClassifier()
120-
121-
# Train the classifier
122-
khc.fit(X_train, y_train)
123-
124-
# Predict the classes on the test dataset
125-
y_test_pred = khc.predict(X_test)
126-
print("Predicted classes (first 10):")
127-
print(y_test_pred[0:10])
128-
print("---")
129-
13091

13192
def khiops_classifier_multiclass():
13293
"""Trains a multiclass `.KhiopsClassifier` on a monotable dataframe"""
@@ -1063,7 +1024,23 @@ def khiops_classifier_multitable_star_file():
10631024

10641025

10651026
exported_samples = [
1066-
khiops_classifier_float_target
1027+
khiops_classifier,
1028+
khiops_classifier_multiclass,
1029+
khiops_classifier_multitable_star,
1030+
khiops_classifier_multitable_snowflake,
1031+
khiops_classifier_sparse,
1032+
khiops_classifier_pickle,
1033+
khiops_classifier_with_hyperparameters,
1034+
khiops_regressor,
1035+
khiops_encoder,
1036+
khiops_encoder_multitable_star,
1037+
khiops_encoder_multitable_snowflake,
1038+
khiops_encoder_pipeline_with_hgbc,
1039+
khiops_encoder_with_hyperparameters,
1040+
khiops_coclustering,
1041+
khiops_coclustering_simplify,
1042+
khiops_classifier_multitable_list,
1043+
khiops_classifier_multitable_star_file,
10671044
]
10681045

10691046

khiops/sklearn/dataset.py

Lines changed: 9 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,6 @@
1414

1515
import numpy as np
1616
import pandas as pd
17-
from Demos.win32cred_demo import target
1817
from scipy import sparse as sp
1918
from sklearn.utils import check_array
2019
from sklearn.utils.validation import column_or_1d
@@ -739,8 +738,15 @@ def _init_target_column(self, y):
739738
if isinstance(y, str):
740739
y_checked = y
741740
else:
742-
y_checked = column_or_1d(y, warn=True)
743-
741+
if hasattr(y, "dtype"):
742+
if isinstance(y.dtype, pd.CategoricalDtype):
743+
y_checked = column_or_1d(
744+
y, warn=True, dtype=y.dtype.categories.dtype
745+
)
746+
else:
747+
y_checked = column_or_1d(y, warn=True, dtype=y.dtype)
748+
else:
749+
y_checked = column_or_1d(y, warn=True)
744750
# Check the target type coherence with those of X's tables
745751
if isinstance(
746752
self.main_table, (PandasTable, SparseTable, NumpyTable)
@@ -1259,8 +1265,6 @@ def create_table_file_for_khiops(
12591265
get_khiops_variable_name(column_id) for column_id in self.column_ids
12601266
]
12611267
if target_column is not None:
1262-
print(target_column)
1263-
print(target_column.dtype)
12641268
output_dataframe[get_khiops_variable_name(target_column_id)] = (
12651269
target_column.copy()
12661270
)

khiops/sklearn/estimators.py

Lines changed: 10 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -2126,20 +2126,20 @@ def _is_real_target_dtype_integer(self):
21262126

21272127
def _is_real_target_dtype_float(self):
21282128
return self._original_target_dtype is not None and (
2129-
pd.api.types.is_float_dtype(self._original_target_dtype)
2130-
or (
2131-
isinstance(self._original_target_dtype, pd.CategoricalDtype)
2132-
and pd.api.types.is_float_dtype(self._original_target_dtype.categories)
2133-
)
2129+
pd.api.types.is_float_dtype(self._original_target_dtype)
2130+
or (
2131+
isinstance(self._original_target_dtype, pd.CategoricalDtype)
2132+
and pd.api.types.is_float_dtype(self._original_target_dtype.categories)
2133+
)
21342134
)
21352135

21362136
def _is_real_target_dtype_bool(self):
21372137
return self._original_target_dtype is not None and (
2138-
pd.api.types.is_bool_dtype(self._original_target_dtype)
2139-
or (
2140-
isinstance(self._original_target_dtype, pd.CategoricalDtype)
2141-
and pd.api.types.is_bool_dtype(self._original_target_dtype.categories)
2142-
)
2138+
pd.api.types.is_bool_dtype(self._original_target_dtype)
2139+
or (
2140+
isinstance(self._original_target_dtype, pd.CategoricalDtype)
2141+
and pd.api.types.is_bool_dtype(self._original_target_dtype.categories)
2142+
)
21432143
)
21442144

21452145
def _sorted_prob_variable_names(self):
@@ -2246,12 +2246,10 @@ def _fit_training_post_process(self, ds):
22462246
for key in variable.meta_data.keys:
22472247
if key.startswith("TargetProb"):
22482248
self.classes_.append(variable.meta_data.get_value(key))
2249-
print(self._get_main_dictionary())
22502249
if ds.is_in_memory:
22512250
if self._is_real_target_dtype_integer():
22522251
self.classes_ = [int(class_value) for class_value in self.classes_]
22532252
elif self._is_real_target_dtype_float():
2254-
print(self.classes_)
22552253
self.classes_ = [float(class_value) for class_value in self.classes_]
22562254
elif self._is_real_target_dtype_bool():
22572255
self.classes_ = [class_value == "True" for class_value in self.classes_]

tests/test_sklearn_output_types.py

Lines changed: 13 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -83,19 +83,19 @@ def test_classifier_output_types(self):
8383
# Create the fixtures
8484
fixtures = {
8585
"ys": {
86-
#"int": y,
87-
#"int binary": y_bin,
88-
#"float": y.astype(float),
89-
#"bool": y.replace({0: True, 1: True, 2: False}),
90-
#"string": self._replace(y, {0: "se", 1: "vi", 2: "ve"}),
91-
#"string binary": self._replace(y_bin, {0: "vi_or_se", 1: "ve"}),
92-
#"int as string": self._replace(y, {0: "8", 1: "9", 2: "10"}),
93-
#"int as string binary": self._replace(y_bin, {0: "89", 1: "10"}),
94-
#"cat int": pd.Series(y).astype("category"),
95-
#"cat string": pd.Series(
96-
#self._replace(y, {0: "se", 1: "vi", 2: "ve"})
97-
#).astype("category"),
98-
#"cat float": y.astype(float).astype("category"),
86+
"int": y,
87+
"int binary": y_bin,
88+
"float": y.astype(float),
89+
"bool": y.replace({0: True, 1: True, 2: False}),
90+
"string": self._replace(y, {0: "se", 1: "vi", 2: "ve"}),
91+
"string binary": self._replace(y_bin, {0: "vi_or_se", 1: "ve"}),
92+
"int as string": self._replace(y, {0: "8", 1: "9", 2: "10"}),
93+
"int as string binary": self._replace(y_bin, {0: "89", 1: "10"}),
94+
"cat int": pd.Series(y).astype("category"),
95+
"cat string": pd.Series(
96+
self._replace(y, {0: "se", 1: "vi", 2: "ve"})
97+
).astype("category"),
98+
"cat float": y.astype(float).astype("category"),
9999
"cat bool": y.replace({0: True, 1: True, 2: False}).astype("category"),
100100
},
101101
"y_type_check": {
@@ -144,7 +144,6 @@ def test_classifier_output_types(self):
144144
"cat string": self._replace(y_pred, {0: "se", 1: "vi", 2: "ve"}),
145145
"cat float": self._replace(y_pred, {0: 0.0, 1: 1.0, 2: 2.0}),
146146
"cat bool": self._replace(y_bin_pred, {0: True, 1: False}),
147-
148147
},
149148
"multi": {
150149
"int": y_mt_pred,
@@ -165,7 +164,6 @@ def test_classifier_output_types(self):
165164
"cat string": self._replace(y_mt_pred, {0: "se", 1: "vi", 2: "ve"}),
166165
"cat float": self._replace(y_mt_pred, {0: 0.0, 1: 1.0, 2: 2.0}),
167166
"cat bool": self._replace(y_mt_bin_pred, {0: True, 1: False}),
168-
169167
},
170168
},
171169
"Xs": {
@@ -176,8 +174,6 @@ def test_classifier_output_types(self):
176174

177175
# Test for each fixture configuration
178176
for y_type, y in fixtures["ys"].items():
179-
print()
180-
print(y_type)
181177
y_type_check = fixtures["y_type_check"][y_type]
182178
expected_classes = fixtures["expected_classes"][y_type]
183179
for dataset_type, X in fixtures["Xs"].items():
@@ -189,12 +185,8 @@ def test_classifier_output_types(self):
189185
# Train the classifier
190186
khc = KhiopsClassifier(n_trees=0)
191187
khc.fit(X, y)
192-
print('unique y values:', np.unique(y))
193188

194189
# Check the expected classes
195-
print()
196-
print('model classes:', khc.classes_)
197-
print('expected_classes:', expected_classes)
198190
assert_array_equal(khc.classes_, expected_classes)
199191

200192
# Check the return type of predict
@@ -207,14 +199,7 @@ def test_classifier_output_types(self):
207199
)
208200

209201
# Check the predictions match
210-
print(dataset_type, y_type)
211202
expected_y_pred = fixtures["expected_y_preds"][dataset_type][y_type]
212-
diff_indices = np.where(y_pred != expected_y_pred)
213-
print("Indices où les valeurs diffèrent :", diff_indices)
214-
215-
for idx in zip(*diff_indices):
216-
print(f"À l'indice {idx}, y_pred={y_pred[idx]}, expected_y_pred={expected_y_pred[idx]}")
217-
218203
assert_array_equal(y_pred, expected_y_pred)
219204

220205
# Check the dimensions of predict_proba

0 commit comments

Comments
 (0)