Skip to content

Commit 747d9c4

Browse files
committed
WIP
1 parent df90b0e commit 747d9c4

File tree

2 files changed

+54
-2
lines changed

2 files changed

+54
-2
lines changed

khiops/samples/samples_sklearn.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,53 @@ def khiops_classifier():
8989
# kh.visualize_report("report.khj")
9090

9191

92+
def khiops_classifier_float_target():
93+
"""Trains a `.KhiopsClassifier` on a monotable dataframe
94+
with a float target"""
95+
# Imports
96+
import os
97+
import pandas as pd
98+
from khiops import core as kh
99+
from khiops.sklearn import KhiopsClassifier
100+
from sklearn.model_selection import train_test_split
101+
102+
# Load the dataset into a pandas dataframe
103+
adult_path = os.path.join(kh.get_samples_dir(), "Adult", "Adult.txt")
104+
adult_df = pd.read_csv(adult_path, sep="\t")
105+
adult_df["class"] = adult_df["class"].replace({"less": 0.0, "more": 1.0})
106+
107+
# Split the whole dataframe into train and test (70%-30%)
108+
adult_train_df, adult_test_df = train_test_split(
109+
adult_df, test_size=0.3, random_state=1
110+
)
111+
112+
# Split the dataset into:
113+
# - the X feature table
114+
# - the y target vector ("class" column)
115+
X_train = adult_train_df.drop("class", axis=1)
116+
X_test = adult_test_df.drop("class", axis=1)
117+
y_train = adult_train_df["class"]
118+
119+
# Create the classifier object
120+
khc = KhiopsClassifier()
121+
122+
# Train the classifier
123+
khc.fit(X_train, y_train)
124+
125+
# Predict the classes on the test dataset
126+
y_test_pred = khc.predict(X_test)
127+
print("Predicted classes (first 10):")
128+
print(y_test_pred[0:10])
129+
print("---")
130+
131+
# Predict the class probabilities on the test dataset
132+
y_test_probas = khc.predict_proba(X_test)
133+
print(f"Class order: {khc.classes_}")
134+
print("Predicted class probabilities (first 10):")
135+
print(y_test_probas[0:10])
136+
print("---")
137+
138+
92139
def khiops_classifier_multiclass():
93140
"""Trains a multiclass `.KhiopsClassifier` on a monotable dataframe"""
94141
# Imports
@@ -1061,6 +1108,8 @@ def khiops_classifier_multitable_star_file():
10611108
print(f"Test auc = {test_auc}")
10621109

10631110

1111+
exported_samples = [khiops_classifier_float_target]
1112+
"""
10641113
exported_samples = [
10651114
khiops_classifier,
10661115
khiops_classifier_multiclass,
@@ -1080,6 +1129,7 @@ def khiops_classifier_multitable_star_file():
10801129
khiops_classifier_multitable_list,
10811130
khiops_classifier_multitable_star_file,
10821131
]
1132+
"""
10831133

10841134

10851135
def execute_samples(args):

khiops/sklearn/estimators.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2259,13 +2259,11 @@ def predict(self, X):
22592259
"""
22602260
# Call the parent's method
22612261
y_pred = super().predict(X)
2262-
22632262
# Adjust the data type according to the original target type
22642263
# Note: String is coerced explictly because astype does not work as expected
22652264
if isinstance(y_pred, pd.DataFrame):
22662265
# Transform to numpy.ndarray
22672266
y_pred = y_pred.to_numpy(copy=False).ravel()
2268-
22692267
# If integer and string just transform
22702268
if pd.api.types.is_integer_dtype(self._original_target_dtype):
22712269
y_pred = y_pred.astype(self._original_target_dtype)
@@ -2275,6 +2273,10 @@ def predict(self, X):
22752273
self._original_target_dtype
22762274
):
22772275
y_pred = y_pred.astype(str, copy=False)
2276+
elif pd.api.types.is_float_dtype(self._original_target_type):
2277+
print(self._original_target_type)
2278+
y_pred = y_pred.astype(str, copy=False)
2279+
print(y_pred)
22782280
# If category first coerce the type to the categories' type
22792281
else:
22802282
assert isinstance(self._original_target_dtype, pd.CategoricalDtype), (

0 commit comments

Comments
 (0)