@@ -284,6 +284,19 @@ def _get_main_dictionary(self):
284284 self ._assert_is_fitted ()
285285 return self .model_ .get_dictionary (self .model_main_dictionary_name_ )
286286
def _read_model_from_dictionary_file(self, model_dictionary_file_path):
    """Reads a dictionary file and keeps only the model dictionaries

    The file is loaded with ``kh.read_dictionary_file`` and every dictionary
    whose name does not start with ``self._khiops_model_prefix`` is removed
    from the loaded object before it is returned.

    This filtering is necessary for the regression case because Khiops
    generates a baseline model which has to be removed for sklearn predictor.

    Parameters
    ----------
    model_dictionary_file_path : str
        Path of the dictionary file to read (passed as-is to
        ``kh.read_dictionary_file``).

    Returns
    -------
    The object returned by ``kh.read_dictionary_file``, stripped of the
    dictionaries that do not carry the model prefix.
    """
    domain = kh.read_dictionary_file(model_dictionary_file_path)

    # The prefix is expected to be set by the concrete estimator
    model_prefix = self._khiops_model_prefix
    assert model_prefix is not None

    # Snapshot the names first: the dictionary collection is modified while
    # the removal loop runs
    all_names = tuple(kdic.name for kdic in domain.dictionaries)
    for name in all_names:
        if name.startswith(model_prefix):
            continue
        domain.remove_dictionary(name)
    return domain
299+
287300 def export_report_file (self , report_file_path ):
288301 """Exports the model report to a JSON file
289302
@@ -309,7 +322,7 @@ def export_dictionary_file(self, dictionary_file_path):
309322
def _import_model(self, kdic_path):
    """Sets the ``model_`` instance attribute by importing it from a ``.kdic``

    The file is loaded through ``_read_model_from_dictionary_file``, so the
    dictionaries whose name does not start with the estimator's model prefix
    are not part of the imported model.

    Parameters
    ----------
    kdic_path : str
        Path of the ``.kdic`` dictionary file to import.
    """
    imported_model = self._read_model_from_dictionary_file(kdic_path)
    self.model_ = imported_model
313326
314327 def _get_output_dir (self , fallback_dir ):
315328 if self .output_dir :
@@ -806,7 +819,7 @@ def _fit_train_model(self, ds, computation_dir, **kwargs):
806819
807820 # Update the `model_` attribute of the coclustering estimator to the
808821 # new coclustering model
809- self .model_ = kh . read_dictionary_file (
822+ self .model_ = self . _read_model_from_dictionary_file (
810823 fs .get_child_path (
811824 output_dir , f"{ self .model_main_dictionary_name_ } _deployed.kdic"
812825 )
@@ -1019,7 +1032,7 @@ def _simplify(
10191032
10201033 # Set the `model_` attribute of the new coclustering estimator to
10211034 # the new coclustering model
1022- simplified_cc .model_ = kh . read_dictionary_file (
1035+ simplified_cc .model_ = self . _read_model_from_dictionary_file (
10231036 fs .get_child_path (
10241037 output_dir , f"{ self .model_main_dictionary_name_ } _deployed.kdic"
10251038 )
@@ -1204,6 +1217,7 @@ def __init__(
12041217 self .construction_rules = construction_rules
12051218 self ._original_target_dtype = None
12061219 self ._predicted_target_meta_data_tag = None
1220+ self ._khiops_baseline_model_prefix = None
12071221
12081222 def __sklearn_tags__ (self ):
12091223 # If we don't implement this trivial method it's not found by the sklearn. This
@@ -1294,7 +1308,7 @@ def _fit_train_model(self, ds, computation_dir, **kwargs):
12941308 return
12951309
12961310 # Save the model domain object and report
1297- self .model_ = kh . read_dictionary_file (model_kdic_file_path )
1311+ self .model_ = self . _read_model_from_dictionary_file (model_kdic_file_path )
12981312 self .model_report_ = kh .read_analysis_results_file (report_file_path )
12991313
13001314 @abstractmethod
@@ -1383,15 +1397,24 @@ def _fit_training_post_process(self, ds):
13831397 self .model_main_dictionary_name_ = self .model_ .dictionaries [0 ].name
13841398 else :
13851399 for dictionary in self .model_ .dictionaries :
1386- assert dictionary .name .startswith (self ._khiops_model_prefix ), (
1400+
1401+ # The baseline model is mandatory for regression;
1402+ # absent for classification and encoding
1403+ assert dictionary .name .startswith (
1404+ self ._khiops_model_prefix
1405+ ) or dictionary .name .startswith (self ._khiops_baseline_model_prefix ), (
13871406 f"Dictionary '{ dictionary .name } ' "
1388- f"does not have prefix '{ self ._khiops_model_prefix } '"
1389- )
1390- initial_dictionary_name = dictionary .name .replace (
1391- self ._khiops_model_prefix , "" , 1
1407+ f"does not have prefix '{ self ._khiops_model_prefix } ' "
1408+ f"or '{ self ._khiops_baseline_model_prefix } '."
13921409 )
1393- if initial_dictionary_name == ds .main_table .name :
1394- self .model_main_dictionary_name_ = dictionary .name
1410+
1411+ # Skip baseline model
1412+ if dictionary .name .startswith (self ._khiops_model_prefix ):
1413+ initial_dictionary_name = dictionary .name .replace (
1414+ self ._khiops_model_prefix , "" , 1
1415+ )
1416+ if initial_dictionary_name == ds .main_table .name :
1417+ self .model_main_dictionary_name_ = dictionary .name
13951418 if self .model_main_dictionary_name_ is None :
13961419 raise ValueError ("No model dictionary after Khiops call" )
13971420
@@ -2183,6 +2206,7 @@ def __init__(
21832206 auto_sort = auto_sort ,
21842207 )
21852208 self ._khiops_model_prefix = "SNB_"
2209+ self ._khiops_baseline_model_prefix = "B_"
21862210 self ._predicted_target_meta_data_tag = "Mean"
21872211 self ._predicted_target_name_prefix = "M"
21882212 self ._original_target_dtype = np .float64
@@ -2284,6 +2308,9 @@ def predict(self, X):
22842308 - str (a path for the file containing the array) if X is a dataset spec
22852309 containing file-path tables.
22862310 """
2311+ assert (
2312+ self ._khiops_baseline_model_prefix is not None
2313+ ), "Baseline model prefix is not set (mandatory for regression)"
22872314 # Call the parent's method
22882315 y_pred = super ().predict (X )
22892316
0 commit comments