Skip to content

Commit 127de4f

Browse files
Eliminate is_fitted_ sklearn attribute
1 parent 4e9aba0 commit 127de4f

File tree

2 files changed

+25
-26
lines changed

2 files changed

+25
-26
lines changed

khiops/sklearn/estimators.py

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
RegressorMixin,
3838
TransformerMixin,
3939
)
40+
from sklearn.exceptions import NotFittedError
4041
from sklearn.utils.multiclass import type_of_target
4142
from sklearn.utils.validation import assert_all_finite, check_is_fitted, column_or_1d
4243

@@ -302,7 +303,7 @@ def _undefine_estimator_attributes(self):
302303

303304
def _get_main_dictionary(self):
304305
"""Returns the model's main Khiops dictionary"""
305-
assert self.model_ is not None, "Model dictionary domain not available."
306+
self._assert_is_fitted()
306307
return self.model_.get_dictionary(self.model_main_dictionary_name_)
307308

308309
def export_report_file(self, report_file_path):
@@ -318,16 +319,14 @@ def export_report_file(self, report_file_path):
318319
`ValueError`
319320
When the instance is not fitted.
320321
"""
321-
if not self.is_fitted_:
322-
raise ValueError(f"{self.__class__.__name__} not fitted yet.")
322+
check_is_fitted(self)
323323
if self.model_report_ is None:
324324
raise ValueError("Report not available (imported model?).")
325325
self.model_report_.write_khiops_json_file(report_file_path)
326326

327327
def export_dictionary_file(self, dictionary_file_path):
328328
"""Export the model's Khiops dictionary file (.kdic)"""
329-
if not self.is_fitted_:
330-
raise ValueError(f"{self.__class__.__name__} not fitted yet.")
329+
check_is_fitted(self)
331330
self.model_.export_khiops_dictionary_file(dictionary_file_path)
332331

333332
def _import_model(self, kdic_path):
@@ -384,13 +383,19 @@ def fit(self, X, y=None, **kwargs):
384383
# If on "fitted" state then:
385384
# - self.model_ must be a DictionaryDomain
386385
# - self.model_report_ must be a KhiopsJSONObject
387-
if hasattr(self, "is_fitted_") and self.is_fitted_:
388-
assert hasattr(self, "model_") and isinstance(
389-
self.model_, kh.DictionaryDomain
390-
)
391-
assert hasattr(self, "model_report_") and isinstance(
392-
self.model_report_, kh.KhiopsJSONObject
393-
)
386+
try:
387+
check_is_fitted(self)
388+
assert isinstance(self.model_, kh.DictionaryDomain)
389+
assert isinstance(self.model_report_, kh.KhiopsJSONObject)
390+
assert isinstance(self.model_, kh.DictionaryDomain)
391+
# Note:
392+
# We ignore any NotFittedError raised by check_is_fitted because we are using
393+
# the try/except as an if/else. The intended code is
394+
# if check_is_fitted(self):
395+
# # asserts
396+
# But check_is_fitted has a do-nothing or raise pattern.
397+
except NotFittedError:
398+
pass
394399

395400
return self
396401

@@ -424,7 +429,6 @@ def _fit(self, ds, computation_dir, **kwargs):
424429
and isinstance(self.model_report_, kh.KhiopsJSONObject)
425430
):
426431
self._fit_training_post_process(ds)
427-
self.is_fitted_ = True
428432
self.is_multitable_model_ = ds.is_multitable
429433
self.n_features_in_ = ds.main_table.n_features()
430434

@@ -649,6 +653,12 @@ def _create_computation_dir(self, method_name):
649653
prefix=f"{self.__class__.__name__}_{method_name}_"
650654
)
651655

656+
def _assert_is_fitted(self):
657+
try:
658+
check_is_fitted(self)
659+
except NotFittedError:
660+
raise AssertionError("Model not fitted")
661+
652662

653663
# Note: scikit-learn **requires** inherit first the mixins and then other classes
654664
class KhiopsCoclustering(ClusterMixin, KhiopsEstimator):
@@ -704,8 +714,6 @@ class KhiopsCoclustering(ClusterMixin, KhiopsEstimator):
704714
705715
Attributes
706716
----------
707-
is_fitted_ : bool
708-
``True`` if the estimator is fitted.
709717
is_multitable_model_ : bool
710718
``True`` if the model was fitted on a multi-table dataset.
711719
model_ : `.DictionaryDomain`
@@ -1152,7 +1160,6 @@ def _simplify(
11521160
# Copy relevant attributes
11531161
# Note: do not copy `model_*` attributes, that get rebuilt anyway
11541162
for attribute_name in (
1155-
"is_fitted_",
11561163
"is_multitable_model_",
11571164
"model_main_dictionary_name_",
11581165
"model_id_column",
@@ -1215,8 +1222,7 @@ def simplify(
12151222
A *new*, simplified `.KhiopsCoclustering` estimator instance.
12161223
"""
12171224
# Check that the estimator is fitted:
1218-
if not self.is_fitted_:
1219-
raise ValueError("Only fitted coclustering estimators can be simplified")
1225+
check_is_fitted(self)
12201226

12211227
return self._simplify(
12221228
max_preserved_information=max_preserved_information,
@@ -2015,8 +2021,6 @@ class KhiopsClassifier(ClassifierMixin, KhiopsPredictor):
20152021
20162022
- Importance: The geometric mean between the Level and the Weight.
20172023
2018-
is_fitted_ : bool
2019-
``True`` if the estimator is fitted.
20202024
is_multitable_model_ : bool
20212025
``True`` if the model was fitted on a multi-table dataset.
20222026
model_ : `.DictionaryDomain`
@@ -2097,7 +2101,7 @@ def _is_real_target_dtype_integer(self):
20972101

20982102
def _sorted_prob_variable_names(self):
20992103
"""Returns the model probability variable names in the order of self.classes_"""
2100-
assert self.is_fitted_, "Model not fit yet"
2104+
self._assert_is_fitted()
21012105

21022106
# Collect the probability variables from the model main dictionary
21032107
prob_variables = []
@@ -2483,8 +2487,6 @@ class KhiopsRegressor(RegressorMixin, KhiopsPredictor):
24832487
24842488
- Importance: The geometric mean between the Level and the Weight.
24852489
2486-
is_fitted_ : bool
2487-
``True`` if the estimator is fitted.
24882490
is_multitable_model_ : bool
24892491
``True`` if the model was fitted on a multi-table dataset.
24902492
model_ : `.DictionaryDomain`
@@ -2774,8 +2776,6 @@ class KhiopsEncoder(TransformerMixin, KhiopsSupervisedEstimator):
27742776
Level of the features evaluated by the classifier. The Level is a measure of the
27752777
predictive importance of the feature taken individually. It ranges between 0 (no
27762778
predictive interest) and 1 (optimal predictive importance).
2777-
is_fitted_ : bool
2778-
``True`` if the estimator is fitted.
27792779
is_multitable_model_ : bool
27802780
``True`` if the model was fitted on a multi-table dataset.
27812781
model_ : `.DictionaryDomain`

tests/test_estimator_attributes.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,6 @@ def assert_attribute_values_ok(self, model, X, y):
172172
self.assertEqual(
173173
model.n_features_used_, len(feature_used_importances_report)
174174
)
175-
self.assertTrue(model.is_fitted_)
176175

177176
def test_classifier_attributes_monotable(self):
178177
"""Test consistency of KhiopsClassifier's attributes with the output reports

0 commit comments

Comments
 (0)