Skip to content

Commit b21e13c

Browse files
Control the types on sklearn internal read table
We do this only for KhiopsClassifier and KhiopsRegressor: It is critical for KhiopsClassifier as it accepts many target types and it is trivial in the case of KhiopsRegressor. For KhiopsEncoder and KhiopsCoclustering it is less critical, and for the former it is very complex, so we left them as TODOs. Additionally, we now also check in the "output type" tests that the result of predict is correct. Before, we only checked that the classes_ attribute was ok. This is to further ensure correctness.
1 parent 8b52433 commit b21e13c

File tree

3 files changed

+213
-122
lines changed

3 files changed

+213
-122
lines changed

khiops/sklearn/estimators.py

Lines changed: 109 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ class KhiopsEstimator(ABC, BaseEstimator):
231231
The name of the column to be used as key.
232232
**Deprecated** will be removed in Khiops 11.
233233
internal_sort : bool, optional
234-
*Advanced.*: See concrete estimator classes for information about this
234+
*Advanced*: See concrete estimator classes for information about this
235235
parameter.
236236
**Deprecated** will be removed in Khiops 11. Use the ``auto_sort``
237237
estimator parameter instead.
@@ -466,7 +466,7 @@ def _transform(
466466
self,
467467
ds,
468468
computation_dir,
469-
_transform_create_deployment_model_fun,
469+
_transform_prepare_deployment_fun,
470470
drop_key,
471471
transformed_file_name,
472472
):
@@ -478,11 +478,13 @@ def _transform(
478478
self._transform_check_dataset(ds)
479479

480480
# Create a deployment dataset
481-
# Note: The input dataset is not necessarily ready to be deployed
481+
# Note: The input dataset isn't ready for deployment in the case of coclustering
482482
deployment_ds = self._transform_create_deployment_dataset(ds, computation_dir)
483483

484-
# Create a deployment dictionary
485-
deployment_dictionary_domain = _transform_create_deployment_model_fun(ds)
484+
# Create a deployment dictionary and the internal table column dtypes
485+
deployment_dictionary_domain, internal_table_column_dtypes = (
486+
_transform_prepare_deployment_fun(ds)
487+
)
486488

487489
# Deploy the model
488490
output_table_path = self._transform_deploy_model(
@@ -493,10 +495,36 @@ def _transform(
493495
transformed_file_name,
494496
)
495497

496-
# Post-process to return the correct output type
497-
return self._transform_deployment_post_process(
498-
deployment_ds, output_table_path, drop_key
499-
)
498+
# Post-process to return the correct output type and order
499+
if deployment_ds.is_in_memory:
500+
# Load the table as a dataframe
501+
with io.BytesIO(fs.read(output_table_path)) as output_table_stream:
502+
output_table_df = read_internal_data_table(
503+
output_table_stream, column_dtypes=internal_table_column_dtypes
504+
)
505+
506+
# On multi-table:
507+
# - Reorder the table to the original table order
508+
# - Because transformed data table file is sorted by key
509+
# - Drop the key columns if specified
510+
if deployment_ds.is_multitable:
511+
key_df = deployment_ds.main_table.data_source[
512+
deployment_ds.main_table.key
513+
]
514+
output_table_df_or_path = key_df.merge(
515+
output_table_df, on=deployment_ds.main_table.key
516+
)
517+
if drop_key:
518+
output_table_df_or_path.drop(
519+
deployment_ds.main_table.key, axis=1, inplace=True
520+
)
521+
# On mono-table: Return the read dataframe as-is
522+
else:
523+
output_table_df_or_path = output_table_df
524+
else:
525+
output_table_df_or_path = output_table_path
526+
527+
return output_table_df_or_path
500528

501529
def _transform_create_deployment_dataset(self, ds, _):
502530
"""Creates if necessary a new dataset to execute the model deployment
@@ -605,44 +633,6 @@ def _transform_check_dataset(self, ds):
605633
if ds.table_type == FileTable and self.output_dir is None:
606634
raise ValueError("'output_dir' is not set but dataset is file-based")
607635

608-
def _transform_deployment_post_process(
609-
self, deployment_ds, output_table_path, drop_key
610-
):
611-
# Return a dataframe for dataframe based datasets
612-
if deployment_ds.is_in_memory:
613-
# Read the transformed table with the internal table settings
614-
with io.BytesIO(fs.read(output_table_path)) as output_table_stream:
615-
output_table_df = read_internal_data_table(output_table_stream)
616-
617-
# On multi-table:
618-
# - Reorder the table to the original table order
619-
# - Because transformed data table file is sorted by key
620-
# - Drop the key columns if specified
621-
if deployment_ds.is_multitable:
622-
key_df = deployment_ds.main_table.data_source[
623-
deployment_ds.main_table.key
624-
]
625-
output_table_df_or_path = key_df.merge(
626-
output_table_df, on=deployment_ds.main_table.key
627-
)
628-
if drop_key:
629-
output_table_df_or_path.drop(
630-
deployment_ds.main_table.key, axis=1, inplace=True
631-
)
632-
# On mono-table: Return the read dataframe as-is
633-
else:
634-
output_table_df_or_path = output_table_df
635-
# Return a file path for file based datasets
636-
else:
637-
output_table_df_or_path = output_table_path
638-
639-
assert isinstance(
640-
output_table_df_or_path, (str, pd.DataFrame)
641-
), type_error_message(
642-
"output_table_df_or_path", output_table_df_or_path, str, pd.DataFrame
643-
)
644-
return output_table_df_or_path
645-
646636
def _create_computation_dir(self, method_name):
647637
"""Creates a temporary computation directory"""
648638
return kh.get_runner().create_temp_dir(
@@ -1260,7 +1250,7 @@ def predict(self, X):
12601250
y_pred = super()._transform(
12611251
ds,
12621252
computation_dir,
1263-
self._transform_prepare_deployment_model_for_predict,
1253+
self._transform_prepare_deployment_for_predict,
12641254
False,
12651255
"predict.txt",
12661256
)
@@ -1366,16 +1356,12 @@ def _transform_create_deployment_dataset(self, ds, computation_dir):
13661356

13671357
return Dataset(deploy_dataset_spec)
13681358

1369-
def _transform_prepare_deployment_model_for_predict(self, _):
1370-
return self.model_.copy()
1371-
1372-
def _transform_deployment_post_process(
1373-
self, deployment_ds, output_table_path, drop_key
1374-
):
1375-
assert deployment_ds.is_multitable
1376-
return super()._transform_deployment_post_process(
1377-
deployment_ds, output_table_path, drop_key
1378-
)
1359+
def _transform_prepare_deployment_for_predict(self, _):
1360+
# TODO
1361+
# Replace the second return value (the output columns' dtypes) with a proper
1362+
# value instead of `None`. In the current state, it will use pandas type
1363+
# auto-detection to load the internal table into memory.
1364+
return self.model_.copy(), None
13791365

13801366
def fit_predict(self, X, y=None, **kwargs):
13811367
"""Performs clustering on X and returns result (instead of labels)"""
@@ -1412,6 +1398,7 @@ def __init__(
14121398
self.specific_pairs = specific_pairs
14131399
self.all_possible_pairs = all_possible_pairs
14141400
self.construction_rules = construction_rules
1401+
self._original_target_dtype = None
14151402
self._predicted_target_meta_data_tag = None
14161403

14171404
# Deprecation message for 'key' constructor parameter
@@ -1619,6 +1606,22 @@ def _fit_training_post_process(self, ds):
16191606
# Call parent method
16201607
super()._fit_training_post_process(ds)
16211608

1609+
# Save the target and key column dtype's
1610+
if ds.is_in_memory:
1611+
if self._original_target_dtype is None:
1612+
self._original_target_dtype = ds.target_column.dtype
1613+
if ds.main_table.key is not None:
1614+
self._original_key_dtypes = {}
1615+
for column_id in ds.main_table.key:
1616+
self._original_key_dtypes[column_id] = (
1617+
ds.main_table.get_column_dtype(column_id)
1618+
)
1619+
else:
1620+
self._original_key_dtypes = None
1621+
else:
1622+
self._original_target_dtype = None
1623+
self._original_key_dtypes = None
1624+
16221625
# Set the target variable name
16231626
self.model_target_variable_name_ = get_khiops_variable_name(ds.target_column_id)
16241627

@@ -1794,6 +1797,7 @@ def __init__(
17941797
)
17951798
# Data to be specified by inherited classes
17961799
self._predicted_target_meta_data_tag = None
1800+
self._predicted_target_name_prefix = None
17971801
self.n_evaluated_features = n_evaluated_features
17981802
self.n_selected_features = n_selected_features
17991803

@@ -1821,7 +1825,7 @@ def predict(self, X):
18211825
y_pred = super()._transform(
18221826
ds,
18231827
computation_dir,
1824-
self._transform_prepare_deployment_model_for_predict,
1828+
self._transform_prepare_deployment_for_predict,
18251829
True,
18261830
"predict.txt",
18271831
)
@@ -1849,7 +1853,7 @@ def _fit_prepare_training_function_inputs(self, ds, computation_dir):
18491853

18501854
return args, kwargs
18511855

1852-
def _transform_prepare_deployment_model_for_predict(self, ds):
1856+
def _transform_prepare_deployment_for_predict(self, ds):
18531857
assert (
18541858
self._predicted_target_meta_data_tag is not None
18551859
), "Predicted target metadata tag is not set"
@@ -1874,7 +1878,20 @@ def _transform_prepare_deployment_model_for_predict(self, ds):
18741878
if self.model_target_variable_name_ not in list(ds.main_table.column_ids):
18751879
model_dictionary.remove_variable(self.model_target_variable_name_)
18761880

1877-
return model_copy
1881+
# Create the output column dtype dict
1882+
if ds.is_in_memory:
1883+
predicted_target_column_name = (
1884+
self._predicted_target_name_prefix + self.model_target_variable_name_
1885+
)
1886+
output_columns_dtype = {
1887+
predicted_target_column_name: self._original_target_dtype
1888+
}
1889+
if self.is_multitable_model_:
1890+
output_columns_dtype.update(self._original_key_dtypes)
1891+
else:
1892+
output_columns_dtype = None
1893+
1894+
return model_copy, output_columns_dtype
18781895

18791896
def get_feature_used_statistics(self, modeling_report):
18801897
# Extract, from the modeling report, names, levels, weights and importances
@@ -1889,7 +1906,7 @@ def get_feature_used_statistics(self, modeling_report):
18891906
for var in modeling_report.selected_variables
18901907
]
18911908
)
1892-
# Return empty arrays if not selected_variables is available
1909+
# Return empty arrays if no selected variables are available
18931910
else:
18941911
feature_used_names_ = np.array([], dtype=np.dtype("<U1"))
18951912
feature_used_importances_ = np.array([])
@@ -2075,6 +2092,7 @@ def __init__(
20752092
self.group_target_value = group_target_value
20762093
self._khiops_model_prefix = "SNB_"
20772094
self._predicted_target_meta_data_tag = "Prediction"
2095+
self._predicted_target_name_prefix = "Predicted"
20782096

20792097
def __sklearn_tags__(self):
20802098
# If we don't implement this trivial method it's not found by the sklearn. This
@@ -2183,12 +2201,6 @@ def _fit_training_post_process(self, ds):
21832201
# Call the parent's method
21842202
super()._fit_training_post_process(ds)
21852203

2186-
# Save the target datatype
2187-
if ds.is_in_memory:
2188-
self._original_target_dtype = ds.target_column.dtype
2189-
else:
2190-
self._original_target_dtype = None
2191-
21922204
# Save class values in the order of deployment
21932205
self.classes_ = []
21942206
for variable in self._get_main_dictionary().variables:
@@ -2260,37 +2272,15 @@ def predict(self, X):
22602272
# Call the parent's method
22612273
y_pred = super().predict(X)
22622274

2263-
# Adjust the data type according to the original target type
2264-
# Note: String is coerced explictly because astype does not work as expected
2275+
# Convert to numpy if it is in memory
22652276
if isinstance(y_pred, pd.DataFrame):
2266-
# Transform to numpy.ndarray
2267-
y_pred = y_pred.to_numpy(copy=False).ravel()
2268-
2269-
# If integer and string just transform
2270-
if pd.api.types.is_integer_dtype(self._original_target_dtype):
2271-
y_pred = y_pred.astype(self._original_target_dtype)
2272-
# If str transform to str
2273-
# Note: If the original type is None then it was learned with a file dataset
2274-
elif self._original_target_dtype is None or pd.api.types.is_string_dtype(
2275-
self._original_target_dtype
2276-
):
2277-
y_pred = y_pred.astype(str, copy=False)
2278-
# If category first coerce the type to the categories' type
2279-
else:
2280-
assert isinstance(self._original_target_dtype, pd.CategoricalDtype), (
2281-
"_original_target_dtype is not categorical"
2282-
f", it is '{self._original_target_dtype}'"
2283-
)
2284-
if pd.api.types.is_integer_dtype(
2285-
self._original_target_dtype.categories.dtype
2286-
):
2287-
y_pred = y_pred.astype(
2288-
self._original_target_dtype.categories.dtype, copy=False
2289-
)
2290-
else:
2291-
y_pred = y_pred.astype(str, copy=False)
2277+
y_pred = y_pred[
2278+
self._predicted_target_name_prefix + self.model_target_variable_name_
2279+
].to_numpy(copy=False)
22922280

2293-
assert isinstance(y_pred, (str, np.ndarray)), "Expected str or np.array"
2281+
assert isinstance(y_pred, (np.ndarray, str)), type_error_message(
2282+
"y_pred", y_pred, np.ndarray, str
2283+
)
22942284
return y_pred
22952285

22962286
def predict_proba(self, X):
@@ -2336,7 +2326,7 @@ def predict_proba(self, X):
23362326
y_probas = self._transform(
23372327
ds,
23382328
computation_dir,
2339-
self._transform_prepare_deployment_model_for_predict_proba,
2329+
self._transform_prepare_deployment_for_predict_proba,
23402330
True,
23412331
"predict_proba.txt",
23422332
)
@@ -2359,7 +2349,7 @@ def predict_proba(self, X):
23592349
assert isinstance(y_probas, (str, np.ndarray)), "Expected str or np.ndarray"
23602350
return y_probas
23612351

2362-
def _transform_prepare_deployment_model_for_predict_proba(self, ds):
2352+
def _transform_prepare_deployment_for_predict_proba(self, ds):
23632353
assert hasattr(
23642354
self, "model_target_variable_name_"
23652355
), "Target variable name has not been set"
@@ -2382,7 +2372,17 @@ def _transform_prepare_deployment_model_for_predict_proba(self, ds):
23822372
if self.model_target_variable_name_ not in list(ds.main_table.column_ids):
23832373
model_dictionary.remove_variable(self.model_target_variable_name_)
23842374

2385-
return model_copy
2375+
if ds.is_in_memory:
2376+
output_columns_dtype = {}
2377+
if self.is_multitable_model_:
2378+
output_columns_dtype.update(self._original_key_dtypes)
2379+
for variable in model_dictionary.variables:
2380+
if variable.used and variable.name not in model_dictionary.key:
2381+
output_columns_dtype[variable.name] = np.float64
2382+
else:
2383+
output_columns_dtype = None
2384+
2385+
return model_copy, output_columns_dtype
23862386

23872387

23882388
# Note: scikit-learn **requires** inherit first the mixins and then other classes
@@ -2534,6 +2534,8 @@ def __init__(
25342534
)
25352535
self._khiops_model_prefix = "SNB_"
25362536
self._predicted_target_meta_data_tag = "Mean"
2537+
self._predicted_target_name_prefix = "M"
2538+
self._original_target_dtype = np.float64
25372539

25382540
def fit(self, X, y=None, **kwargs):
25392541
"""Fits a Selective Naive Bayes regressor according to X, y
@@ -3090,7 +3092,7 @@ def transform(self, X):
30903092
X_transformed = super()._transform(
30913093
ds,
30923094
computation_dir,
3093-
self._transform_prepare_deployment_model,
3095+
self._transform_prepare_deployment_for_transform,
30943096
True,
30953097
"transform.txt",
30963098
)
@@ -3102,7 +3104,7 @@ def transform(self, X):
31023104
return X_transformed.to_numpy(copy=False)
31033105
return X_transformed
31043106

3105-
def _transform_prepare_deployment_model(self, ds):
3107+
def _transform_prepare_deployment_for_transform(self, ds):
31063108
assert hasattr(
31073109
self, "model_target_variable_name_"
31083110
), "Target variable name has not been set"
@@ -3115,7 +3117,11 @@ def _transform_prepare_deployment_model(self, ds):
31153117
if self.model_target_variable_name_ not in list(ds.main_table.column_ids):
31163118
model_dictionary.remove_variable(self.model_target_variable_name_)
31173119

3118-
return model_copy
3120+
# TODO
3121+
# Replace the second return value (the output columns' dtypes) with a proper
3122+
# value instead of `None`. In the current state, it will use pandas type
3123+
# auto-detection to load the internal table into memory.
3124+
return model_copy, None
31193125

31203126
def fit_transform(self, X, y=None, **kwargs):
31213127
"""Fit and transforms its inputs

0 commit comments

Comments
 (0)