3737 RegressorMixin ,
3838 TransformerMixin ,
3939)
40+ from sklearn .exceptions import NotFittedError
4041from sklearn .utils .multiclass import type_of_target
4142from sklearn .utils .validation import assert_all_finite , check_is_fitted , column_or_1d
4243
@@ -302,7 +303,7 @@ def _undefine_estimator_attributes(self):
302303
303304 def _get_main_dictionary (self ):
304305 """Returns the model's main Khiops dictionary"""
305- assert self .model_ is not None , "Model dictionary domain not available."
306+ self ._assert_is_fitted ()
306307 return self .model_ .get_dictionary (self .model_main_dictionary_name_ )
307308
308309 def export_report_file (self , report_file_path ):
@@ -318,16 +319,14 @@ def export_report_file(self, report_file_path):
318319 `ValueError`
319320 When the instance is not fitted.
320321 """
321- if not self .is_fitted_ :
322- raise ValueError (f"{ self .__class__ .__name__ } not fitted yet." )
322+ check_is_fitted (self )
323323 if self .model_report_ is None :
324324 raise ValueError ("Report not available (imported model?)." )
325325 self .model_report_ .write_khiops_json_file (report_file_path )
326326
327327 def export_dictionary_file (self , dictionary_file_path ):
328328 """Export the model's Khiops dictionary file (.kdic)"""
329- if not self .is_fitted_ :
330- raise ValueError (f"{ self .__class__ .__name__ } not fitted yet." )
329+ check_is_fitted (self )
331330 self .model_ .export_khiops_dictionary_file (dictionary_file_path )
332331
333332 def _import_model (self , kdic_path ):
@@ -384,13 +383,19 @@ def fit(self, X, y=None, **kwargs):
384383 # If on "fitted" state then:
385384 # - self.model_ must be a DictionaryDomain
386385 # - self.model_report_ must be a KhiopsJSONObject
387- if hasattr (self , "is_fitted_" ) and self .is_fitted_ :
388- assert hasattr (self , "model_" ) and isinstance (
389- self .model_ , kh .DictionaryDomain
390- )
391- assert hasattr (self , "model_report_" ) and isinstance (
392- self .model_report_ , kh .KhiopsJSONObject
393- )
386+ try :
387+ check_is_fitted (self )
388+ assert isinstance (self .model_ , kh .DictionaryDomain )
389+ assert isinstance (self .model_report_ , kh .KhiopsJSONObject )
390+ assert isinstance (self .model_ , kh .DictionaryDomain )
391+ # Note:
392+ # We ignore any raised NotFittedError by check_is_fitted because we are using
393+ # the try/catch as an if/else. The code intended is
394+ # if check_is_fitted(self):
395+ # # asserts
396+ # But check_is_fitted has a do-nothing or raise pattern.
397+ except NotFittedError :
398+ pass
394399
395400 return self
396401
@@ -424,7 +429,6 @@ def _fit(self, ds, computation_dir, **kwargs):
424429 and isinstance (self .model_report_ , kh .KhiopsJSONObject )
425430 ):
426431 self ._fit_training_post_process (ds )
427- self .is_fitted_ = True
428432 self .is_multitable_model_ = ds .is_multitable
429433 self .n_features_in_ = ds .main_table .n_features ()
430434
@@ -649,6 +653,12 @@ def _create_computation_dir(self, method_name):
649653 prefix = f"{ self .__class__ .__name__ } _{ method_name } _"
650654 )
651655
656+ def _assert_is_fitted (self ):
657+ try :
658+ check_is_fitted (self )
659+ except NotFittedError :
660+ raise AssertionError ("Model not fitted" )
661+
652662
653663# Note: scikit-learn **requires** inherit first the mixins and then other classes
654664class KhiopsCoclustering (ClusterMixin , KhiopsEstimator ):
@@ -704,8 +714,6 @@ class KhiopsCoclustering(ClusterMixin, KhiopsEstimator):
704714
705715 Attributes
706716 ----------
707- is_fitted_ : bool
708- ``True`` if the estimator is fitted.
709717 is_multitable_model_ : bool
710718 ``True`` if the model was fitted on a multi-table dataset.
711719 model_ : `.DictionaryDomain`
@@ -1152,7 +1160,6 @@ def _simplify(
11521160 # Copy relevant attributes
11531161 # Note: do not copy `model_*` attributes, that get rebuilt anyway
11541162 for attribute_name in (
1155- "is_fitted_" ,
11561163 "is_multitable_model_" ,
11571164 "model_main_dictionary_name_" ,
11581165 "model_id_column" ,
@@ -1215,8 +1222,7 @@ def simplify(
12151222 A *new*, simplified `.KhiopsCoclustering` estimator instance.
12161223 """
12171224 # Check that the estimator is fitted:
1218- if not self .is_fitted_ :
1219- raise ValueError ("Only fitted coclustering estimators can be simplified" )
1225+ check_is_fitted (self )
12201226
12211227 return self ._simplify (
12221228 max_preserved_information = max_preserved_information ,
@@ -2015,8 +2021,6 @@ class KhiopsClassifier(ClassifierMixin, KhiopsPredictor):
20152021
20162022 - Importance: The geometric mean between the Level and the Weight.
20172023
2018- is_fitted_ : bool
2019- ``True`` if the estimator is fitted.
20202024 is_multitable_model_ : bool
20212025 ``True`` if the model was fitted on a multi-table dataset.
20222026 model_ : `.DictionaryDomain`
@@ -2097,7 +2101,7 @@ def _is_real_target_dtype_integer(self):
20972101
20982102 def _sorted_prob_variable_names (self ):
20992103 """Returns the model probability variable names in the order of self.classes_"""
2100- assert self .is_fitted_ , "Model not fit yet"
2104+ self ._assert_is_fitted ()
21012105
21022106 # Collect the probability variables from the model main dictionary
21032107 prob_variables = []
@@ -2483,8 +2487,6 @@ class KhiopsRegressor(RegressorMixin, KhiopsPredictor):
24832487
24842488 - Importance: The geometric mean between the Level and the Weight.
24852489
2486- is_fitted_ : bool
2487- ``True`` if the estimator is fitted.
24882490 is_multitable_model_ : bool
24892491 ``True`` if the model was fitted on a multi-table dataset.
24902492 model_ : `.DictionaryDomain`
@@ -2774,8 +2776,6 @@ class KhiopsEncoder(TransformerMixin, KhiopsSupervisedEstimator):
27742776 Level of the features evaluated by the classifier. The Level is measure of the
27752777 predictive importance of the feature taken individually. It ranges between 0 (no
27762778 predictive interest) and 1 (optimal predictive importance).
2777- is_fitted_ : bool
2778- ``True`` if the estimator is fitted.
27792779 is_multitable_model_ : bool
27802780 ``True`` if the model was fitted on a multi-table dataset.
27812781 model_ : `.DictionaryDomain`
0 commit comments