Skip to content

Commit 7d7d3b4

Browse files
committed
WIP
1 parent 5a561ed commit 7d7d3b4

File tree

4 files changed

+120
-32
lines changed

4 files changed

+120
-32
lines changed

khiops/samples/samples_sklearn.py

Lines changed: 40 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,45 @@ def khiops_classifier():
8888
# khc.export_report_file("report.khj")
8989
# kh.visualize_report("report.khj")
9090

91+
def khiops_classifier_float_target():
    """Trains a `.KhiopsClassifier` on a monotable dataframe with a float target

    The ``class`` column of the "Adult" sample dataset is recoded from its
    string labels to floats (``"less"`` -> ``0.0``, ``"more"`` -> ``1.0``)
    before training, so that the classifier is fitted on a float ``y``.
    """
    # Imports
    import os
    import pandas as pd
    from khiops import core as kh
    from khiops.sklearn import KhiopsClassifier
    from sklearn.model_selection import train_test_split

    # Load the dataset into a pandas dataframe
    adult_path = os.path.join(kh.get_samples_dir(), "Adult", "Adult.txt")
    adult_df = pd.read_csv(adult_path, sep="\t")

    # Split the whole dataframe into train and test (70%-30%)
    adult_train_df, adult_test_df = train_test_split(
        adult_df, test_size=0.3, random_state=1
    )

    # Split the dataset into:
    # - the X feature tables
    # - the y target vector, recoded as floats ("class" column)
    X_train = adult_train_df.drop("class", axis=1)
    X_test = adult_test_df.drop("class", axis=1)
    # `map` builds a float64 series directly; `replace` on an object column
    # would keep the object dtype (and warns about downcasting in recent pandas)
    y_train = adult_train_df["class"].map({"less": 0.0, "more": 1.0})

    # Create the classifier object
    khc = KhiopsClassifier()

    # Train the classifier
    khc.fit(X_train, y_train)

    # Predict the classes on the test dataset
    y_test_pred = khc.predict(X_test)
    print("Predicted classes (first 10):")
    print(y_test_pred[0:10])
    print("---")
129+
91130

92131
def khiops_classifier_multiclass():
93132
"""Trains a multiclass `.KhiopsClassifier` on a monotable dataframe"""
@@ -1024,23 +1063,7 @@ def khiops_classifier_multitable_star_file():
10241063

10251064

10261065
# Keep every sample listed: the float-target sample is added next to the base
# classifier sample instead of replacing the whole list.
exported_samples = [
    khiops_classifier,
    khiops_classifier_float_target,
    khiops_classifier_multiclass,
    khiops_classifier_multitable_star,
    khiops_classifier_multitable_snowflake,
    khiops_classifier_sparse,
    khiops_classifier_pickle,
    khiops_classifier_with_hyperparameters,
    khiops_regressor,
    khiops_encoder,
    khiops_encoder_multitable_star,
    khiops_encoder_multitable_snowflake,
    khiops_encoder_pipeline_with_hgbc,
    khiops_encoder_with_hyperparameters,
    khiops_coclustering,
    khiops_coclustering_simplify,
    khiops_classifier_multitable_list,
    khiops_classifier_multitable_star_file,
]
10451068

10461069

khiops/sklearn/dataset.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import numpy as np
1616
import pandas as pd
17+
from Demos.win32cred_demo import target
1718
from scipy import sparse as sp
1819
from sklearn.utils import check_array
1920
from sklearn.utils.validation import column_or_1d
@@ -1258,6 +1259,8 @@ def create_table_file_for_khiops(
12581259
get_khiops_variable_name(column_id) for column_id in self.column_ids
12591260
]
12601261
if target_column is not None:
1262+
print(target_column)
1263+
print(target_column.dtype)
12611264
output_dataframe[get_khiops_variable_name(target_column_id)] = (
12621265
target_column.copy()
12631266
)

khiops/sklearn/estimators.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,7 @@ def _check_categorical_target_type(ds):
154154
or pd.api.types.is_string_dtype(ds.target_column.dtype)
155155
or pd.api.types.is_integer_dtype(ds.target_column.dtype)
156156
or pd.api.types.is_float_dtype(ds.target_column.dtype)
157+
or pd.api.types.is_bool_dtype(ds.target_column.dtype)
157158
):
158159
raise ValueError(
159160
f"'y' has invalid type '{ds.target_column_type}'. "
@@ -2123,6 +2124,24 @@ def _is_real_target_dtype_integer(self):
21232124
)
21242125
)
21252126

2127+
def _is_real_target_dtype_float(self):
    """Tells whether the original target dtype is float-like

    Returns ``True`` when ``self._original_target_dtype`` is a float dtype, or
    a pandas categorical dtype whose categories are floats. Returns ``False``
    otherwise, including when no original target dtype is stored (``None``).
    """
    target_dtype = self._original_target_dtype

    # No stored dtype: nothing to check
    if target_dtype is None:
        return False

    # Plain float dtype
    if pd.api.types.is_float_dtype(target_dtype):
        return True

    # Categorical dtype: the decision is taken on its categories
    if isinstance(target_dtype, pd.CategoricalDtype):
        return pd.api.types.is_float_dtype(target_dtype.categories)
    return False
2135+
2136+
def _is_real_target_dtype_bool(self):
    """Tells whether the original target dtype is boolean-like

    Returns ``True`` when ``self._original_target_dtype`` is a boolean dtype,
    or a pandas categorical dtype whose categories are booleans. Returns
    ``False`` otherwise, including when no original target dtype is stored
    (``None``).
    """
    target_dtype = self._original_target_dtype

    # No stored dtype: nothing to check
    if target_dtype is None:
        return False

    # Plain boolean dtype
    if pd.api.types.is_bool_dtype(target_dtype):
        return True

    # Categorical dtype: the decision is taken on its categories
    if isinstance(target_dtype, pd.CategoricalDtype):
        return pd.api.types.is_bool_dtype(target_dtype.categories)
    return False
2144+
21262145
def _sorted_prob_variable_names(self):
21272146
"""Returns the model probability variable names in the order of self.classes_"""
21282147
self._assert_is_fitted()
@@ -2227,8 +2246,15 @@ def _fit_training_post_process(self, ds):
22272246
for key in variable.meta_data.keys:
22282247
if key.startswith("TargetProb"):
22292248
self.classes_.append(variable.meta_data.get_value(key))
2230-
if ds.is_in_memory and self._is_real_target_dtype_integer():
2231-
self.classes_ = [int(class_value) for class_value in self.classes_]
2249+
print(self._get_main_dictionary())
2250+
if ds.is_in_memory:
2251+
if self._is_real_target_dtype_integer():
2252+
self.classes_ = [int(class_value) for class_value in self.classes_]
2253+
elif self._is_real_target_dtype_float():
2254+
print(self.classes_)
2255+
self.classes_ = [float(class_value) for class_value in self.classes_]
2256+
elif self._is_real_target_dtype_bool():
2257+
self.classes_ = [class_value == "True" for class_value in self.classes_]
22322258
self.classes_.sort()
22332259
self.classes_ = column_or_1d(self.classes_)
22342260

tests/test_sklearn_output_types.py

Lines changed: 49 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -71,52 +71,66 @@ def test_classifier_output_types(self):
7171
khc = KhiopsClassifier(n_trees=0)
7272
khc.fit(X, y)
7373
y_pred = khc.predict(X)
74+
khc.fit(X_mt, y)
75+
y_mt_pred = khc.predict(X_mt)
76+
7477
y_bin = y.replace({0: 0, 1: 0, 2: 1})
7578
khc.fit(X, y_bin)
7679
y_bin_pred = khc.predict(X)
77-
khc.fit(X_mt, y)
78-
khc.export_report_file("report.khj")
79-
y_mt_pred = khc.predict(X_mt)
8080
khc.fit(X_mt, y_bin)
8181
y_mt_bin_pred = khc.predict(X_mt)
8282

8383
# Create the fixtures
8484
fixtures = {
8585
"ys": {
86-
"int": y,
87-
"int binary": y_bin,
88-
"string": self._replace(y, {0: "se", 1: "vi", 2: "ve"}),
89-
"string binary": self._replace(y_bin, {0: "vi_or_se", 1: "ve"}),
90-
"int as string": self._replace(y, {0: "8", 1: "9", 2: "10"}),
91-
"int as string binary": self._replace(y_bin, {0: "89", 1: "10"}),
92-
"cat int": pd.Series(y).astype("category"),
93-
"cat string": pd.Series(
94-
self._replace(y, {0: "se", 1: "vi", 2: "ve"})
95-
).astype("category"),
86+
#"int": y,
87+
#"int binary": y_bin,
88+
#"float": y.astype(float),
89+
#"bool": y.replace({0: True, 1: True, 2: False}),
90+
#"string": self._replace(y, {0: "se", 1: "vi", 2: "ve"}),
91+
#"string binary": self._replace(y_bin, {0: "vi_or_se", 1: "ve"}),
92+
#"int as string": self._replace(y, {0: "8", 1: "9", 2: "10"}),
93+
#"int as string binary": self._replace(y_bin, {0: "89", 1: "10"}),
94+
#"cat int": pd.Series(y).astype("category"),
95+
#"cat string": pd.Series(
96+
#self._replace(y, {0: "se", 1: "vi", 2: "ve"})
97+
#).astype("category"),
98+
#"cat float": y.astype(float).astype("category"),
99+
"cat bool": y.replace({0: True, 1: True, 2: False}).astype("category"),
96100
},
97101
"y_type_check": {
98102
"int": pd.api.types.is_integer_dtype,
99103
"int binary": pd.api.types.is_integer_dtype,
104+
"float": pd.api.types.is_float_dtype,
105+
"bool": pd.api.types.is_bool_dtype,
100106
"string": pd.api.types.is_string_dtype,
101107
"string binary": pd.api.types.is_string_dtype,
102108
"int as string": pd.api.types.is_string_dtype,
103109
"int as string binary": pd.api.types.is_string_dtype,
104110
"cat int": pd.api.types.is_integer_dtype,
105111
"cat string": pd.api.types.is_string_dtype,
112+
"cat float": pd.api.types.is_float_dtype,
113+
"cat bool": pd.api.types.is_bool_dtype,
106114
},
107115
"expected_classes": {
108116
"int": column_or_1d([0, 1, 2]),
109117
"int binary": column_or_1d([0, 1]),
118+
"float": column_or_1d([0.0, 1.0, 2.0]),
119+
"bool": column_or_1d([False, True]),
110120
"string": column_or_1d(["se", "ve", "vi"]),
111121
"string binary": column_or_1d(["ve", "vi_or_se"]),
112122
"int as string": column_or_1d(["10", "8", "9"]),
113123
"int as string binary": column_or_1d(["10", "89"]),
114124
"cat int": column_or_1d([0, 1, 2]),
115125
"cat string": column_or_1d(["se", "ve", "vi"]),
126+
"cat float": column_or_1d([0.0, 1.0, 2.0]),
127+
"cat bool": column_or_1d([False, True]),
116128
},
117129
"expected_y_preds": {
118130
"mono": {
119131
"int": y_pred,
132+
"float": y_pred.astype(float),
133+
"bool": self._replace(y_bin_pred, {0: True, 1: False}),
120134
"int binary": y_bin_pred,
121135
"string": self._replace(y_pred, {0: "se", 1: "vi", 2: "ve"}),
122136
"string binary": self._replace(
@@ -128,9 +142,14 @@ def test_classifier_output_types(self):
128142
),
129143
"cat int": y_pred,
130144
"cat string": self._replace(y_pred, {0: "se", 1: "vi", 2: "ve"}),
145+
"cat float": self._replace(y_pred, {0: 0.0, 1: 1.0, 2: 2.0}),
146+
"cat bool": self._replace(y_bin_pred, {0: True, 1: False}),
147+
131148
},
132149
"multi": {
133150
"int": y_mt_pred,
151+
"float": y_mt_pred.astype(float),
152+
"bool": self._replace(y_mt_bin_pred, {0: True, 1: False}),
134153
"int binary": y_mt_bin_pred,
135154
"string": self._replace(y_mt_pred, {0: "se", 1: "vi", 2: "ve"}),
136155
"string binary": self._replace(
@@ -144,6 +163,9 @@ def test_classifier_output_types(self):
144163
),
145164
"cat int": y_mt_pred,
146165
"cat string": self._replace(y_mt_pred, {0: "se", 1: "vi", 2: "ve"}),
166+
"cat float": self._replace(y_mt_pred, {0: 0.0, 1: 1.0, 2: 2.0}),
167+
"cat bool": self._replace(y_mt_bin_pred, {0: True, 1: False}),
168+
147169
},
148170
},
149171
"Xs": {
@@ -154,6 +176,8 @@ def test_classifier_output_types(self):
154176

155177
# Test for each fixture configuration
156178
for y_type, y in fixtures["ys"].items():
179+
print()
180+
print(y_type)
157181
y_type_check = fixtures["y_type_check"][y_type]
158182
expected_classes = fixtures["expected_classes"][y_type]
159183
for dataset_type, X in fixtures["Xs"].items():
@@ -165,20 +189,32 @@ def test_classifier_output_types(self):
165189
# Train the classifier
166190
khc = KhiopsClassifier(n_trees=0)
167191
khc.fit(X, y)
192+
print('unique y values:', np.unique(y))
168193

169194
# Check the expected classes
195+
print()
196+
print('model classes:', khc.classes_)
197+
print('expected_classes:', expected_classes)
170198
assert_array_equal(khc.classes_, expected_classes)
171199

172200
# Check the return type of predict
173201
y_pred = khc.predict(X)
202+
174203
self.assertTrue(
175204
y_type_check(y_pred),
176205
f"'{y_type_check.__name__}' was False for "
177206
f"dtype '{y_pred.dtype}'.",
178207
)
179208

180209
# Check the predictions match
210+
print(dataset_type, y_type)
181211
expected_y_pred = fixtures["expected_y_preds"][dataset_type][y_type]
212+
diff_indices = np.where(y_pred != expected_y_pred)
213+
print("Indices où les valeurs diffèrent :", diff_indices)
214+
215+
for idx in zip(*diff_indices):
216+
print(f"À l'indice {idx}, y_pred={y_pred[idx]}, expected_y_pred={expected_y_pred[idx]}")
217+
182218
assert_array_equal(y_pred, expected_y_pred)
183219

184220
# Check the dimensions of predict_proba

0 commit comments

Comments
 (0)