@@ -284,6 +284,19 @@ def _get_main_dictionary(self):
284284 self ._assert_is_fitted ()
285285 return self .model_ .get_dictionary (self .model_main_dictionary_name_ )
286286
287+ def _read_model_from_dictionary_file (self , model_dictionary_file_path ):
288+ """Removes dictionaries that do not have the model prefix in their name
289+
290+ This function is necessary for the regression case because Khiops generates a
291+ baseline model which has to be removed for sklearn predictor.
292+ """
293+ model = kh .read_dictionary_file (model_dictionary_file_path )
294+ assert self ._khiops_model_prefix is not None
295+ for dictionary_name in [kdic .name for kdic in model .dictionaries ]:
296+ if not dictionary_name .startswith (self ._khiops_model_prefix ):
297+ model .remove_dictionary (dictionary_name )
298+ return model
299+
287300 def export_report_file (self , report_file_path ):
288301 """Exports the model report to a JSON file
289302
@@ -309,7 +322,7 @@ def export_dictionary_file(self, dictionary_file_path):
309322
310323 def _import_model (self , kdic_path ):
311324 """Sets model instance attribute by importing model from ``.kdic``"""
312- self .model_ = kh . read_dictionary_file (kdic_path )
325+ self .model_ = self . _read_model_from_dictionary_file (kdic_path )
313326
314327 def _get_output_dir (self , fallback_dir ):
315328 if self .output_dir :
@@ -808,7 +821,7 @@ def _fit_train_model(self, ds, computation_dir, **kwargs):
808821
809822 # Update the `model_` attribute of the coclustering estimator to the
810823 # new coclustering model
811- self .model_ = kh . read_dictionary_file (
824+ self .model_ = self . _read_model_from_dictionary_file (
812825 fs .get_child_path (
813826 output_dir , f"{ self .model_main_dictionary_name_ } _deployed.kdic"
814827 )
@@ -1021,7 +1034,7 @@ def _simplify(
10211034
10221035 # Set the `model_` attribute of the new coclustering estimator to
10231036 # the new coclustering model
1024- simplified_cc .model_ = kh . read_dictionary_file (
1037+ simplified_cc .model_ = self . _read_model_from_dictionary_file (
10251038 fs .get_child_path (
10261039 output_dir , f"{ self .model_main_dictionary_name_ } _deployed.kdic"
10271040 )
@@ -1205,6 +1218,7 @@ def __init__(
12051218 self .construction_rules = construction_rules
12061219 self ._original_target_dtype = None
12071220 self ._predicted_target_meta_data_tag = None
1221+ self ._khiops_baseline_model_prefix = None
12081222
12091223 def __sklearn_tags__ (self ):
12101224 # If we don't implement this trivial method it's not found by the sklearn. This
@@ -1295,7 +1309,7 @@ def _fit_train_model(self, ds, computation_dir, **kwargs):
12951309 return
12961310
12971311 # Save the model domain object and report
1298- self .model_ = kh . read_dictionary_file (model_kdic_file_path )
1312+ self .model_ = self . _read_model_from_dictionary_file (model_kdic_file_path )
12991313 self .model_report_ = kh .read_analysis_results_file (report_file_path )
13001314
13011315 @abstractmethod
@@ -1384,15 +1398,24 @@ def _fit_training_post_process(self, ds):
13841398 self .model_main_dictionary_name_ = self .model_ .dictionaries [0 ].name
13851399 else :
13861400 for dictionary in self .model_ .dictionaries :
1387- assert dictionary .name .startswith (self ._khiops_model_prefix ), (
1401+
1402+ # The baseline model is mandatory for regression;
1403+ # absent for classification and encoding
1404+ assert dictionary .name .startswith (
1405+ self ._khiops_model_prefix
1406+ ) or dictionary .name .startswith (self ._khiops_baseline_model_prefix ), (
13881407 f"Dictionary '{ dictionary .name } ' "
1389- f"does not have prefix '{ self ._khiops_model_prefix } '"
1390- )
1391- initial_dictionary_name = dictionary .name .replace (
1392- self ._khiops_model_prefix , "" , 1
1408+ f"does not have prefix '{ self ._khiops_model_prefix } ' "
1409+ f"or '{ self ._khiops_baseline_model_prefix } '."
13931410 )
1394- if initial_dictionary_name == ds .main_table .name :
1395- self .model_main_dictionary_name_ = dictionary .name
1411+
1412+ # Skip baseline model
1413+ if dictionary .name .startswith (self ._khiops_model_prefix ):
1414+ initial_dictionary_name = dictionary .name .replace (
1415+ self ._khiops_model_prefix , "" , 1
1416+ )
1417+ if initial_dictionary_name == ds .main_table .name :
1418+ self .model_main_dictionary_name_ = dictionary .name
13961419 if self .model_main_dictionary_name_ is None :
13971420 raise ValueError ("No model dictionary after Khiops call" )
13981421
@@ -2185,6 +2208,7 @@ def __init__(
21852208 auto_sort = auto_sort ,
21862209 )
21872210 self ._khiops_model_prefix = "SNB_"
2211+ self ._khiops_baseline_model_prefix = "B_"
21882212 self ._predicted_target_meta_data_tag = "Mean"
21892213 self ._predicted_target_name_prefix = "M"
21902214 self ._original_target_dtype = np .float64
@@ -2286,6 +2310,9 @@ def predict(self, X):
22862310 - str (a path for the file containing the array) if X is a dataset spec
22872311 containing file-path tables.
22882312 """
2313+ assert (
2314+ self ._khiops_baseline_model_prefix is not None
2315+ ), "Baseline model prefix is not set (mandatory for regression)"
22892316 # Call the parent's method
22902317 y_pred = super ().predict (X )
22912318
0 commit comments