Skip to content

Commit c4e7daa

Browse files
authored
Merge pull request #353 from rasbt/nonfitted
raise notfittederror in stacking estimators
2 parents bac2d26 + 8fed8ee commit c4e7daa

10 files changed

Lines changed: 123 additions & 62 deletions

docs/sources/CHANGELOG.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,13 @@ The fit method of the SequentialFeatureSelector now optionally accepts **fit_par
2727

2828

2929
- Replaces `plot_decision_regions` colors by a colorblind-friendly palette and adds contour lines for decision regions. ([#348](https://github.com/rasbt/mlxtend/issues/348))
30+
- All stacking estimators now raise a `NotFittedError` if any inference method is called prior to fitting the estimator. ([#353](https://github.com/rasbt/mlxtend/issues/353))
3031

3132

3233
##### Bug Fixes
3334

34-
- -
35+
36+
3537

3638

3739
### Version 0.11.0 (2018-03-14)

mlxtend/classifier/stacking_classification.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,13 @@
88
#
99
# License: BSD 3 clause
1010

11+
from ..externals.estimator_checks import check_is_fitted
12+
from ..externals.name_estimators import _name_estimators
13+
from ..externals import six
1114
from sklearn.base import BaseEstimator
1215
from sklearn.base import ClassifierMixin
1316
from sklearn.base import TransformerMixin
1417
from sklearn.base import clone
15-
from sklearn.utils.validation import check_is_fitted
16-
from ..externals.name_estimators import _name_estimators
17-
from ..externals import six
1818
import numpy as np
1919

2020

@@ -188,6 +188,7 @@ def predict_meta_features(self, X):
188188
Returns the meta-features for test data.
189189
190190
"""
191+
check_is_fitted(self, 'clfs_')
191192
if self.use_probas:
192193
probas = np.asarray([clf.predict_proba(X)
193194
for clf in self.clfs_])

mlxtend/classifier/stacking_cv_classification.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,11 @@
1010
# License: BSD 3 clause
1111

1212
from ..externals.name_estimators import _name_estimators
13+
from ..externals.estimator_checks import check_is_fitted
1314
from sklearn.base import BaseEstimator
1415
from sklearn.base import ClassifierMixin
1516
from sklearn.base import TransformerMixin
1617
from sklearn.base import clone
17-
from sklearn.utils.validation import check_is_fitted
1818
from sklearn.externals import six
1919
from sklearn.model_selection._split import check_cv
2020
import numpy as np
@@ -321,6 +321,7 @@ def predict(self, X):
321321
Predicted class labels.
322322
323323
"""
324+
check_is_fitted(self, 'clfs_')
324325
all_model_predictions = self.predict_meta_features(X)
325326
if not self.use_features_in_secondary:
326327
return self.meta_clf_.predict(all_model_predictions)

mlxtend/classifier/tests/test_stacking_classifier.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,13 @@
55
# License: BSD 3 clause
66

77
from mlxtend.classifier import StackingClassifier
8+
from mlxtend.externals.estimator_checks import NotFittedError
89
from sklearn.linear_model import LogisticRegression
910
from sklearn.naive_bayes import GaussianNB
1011
from sklearn.ensemble import RandomForestClassifier
1112
from sklearn.neighbors import KNeighborsClassifier
1213
from sklearn.model_selection import GridSearchCV
1314
from sklearn.model_selection import cross_val_score
14-
from sklearn.exceptions import NotFittedError
1515
import numpy as np
1616
from sklearn import datasets
1717
from mlxtend.utils import assert_raises
@@ -197,6 +197,13 @@ def test_not_fitted():
197197
sclf.predict_proba,
198198
iris.data)
199199

200+
assert_raises(NotFittedError,
201+
"This StackingClassifier instance is not fitted yet."
202+
" Call 'fit' with appropriate arguments"
203+
" before using this method.",
204+
sclf.predict_meta_features,
205+
iris.data)
206+
200207

201208
def test_verbose():
202209
np.random.seed(123)

mlxtend/classifier/tests/test_stacking_cv_classifier.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,18 @@
55
#
66
# License: BSD 3 clause
77

8-
from mlxtend.classifier import StackingCVClassifier
9-
108
import pandas as pd
9+
import numpy as np
10+
from mlxtend.classifier import StackingCVClassifier
11+
from mlxtend.externals.estimator_checks import NotFittedError
12+
from mlxtend.utils import assert_raises
1113
from sklearn.linear_model import LogisticRegression
1214
from sklearn.naive_bayes import GaussianNB
1315
from sklearn.ensemble import RandomForestClassifier
1416
from sklearn.neighbors import KNeighborsClassifier
15-
import numpy as np
1617
from sklearn import datasets
17-
from mlxtend.utils import assert_raises
1818
from sklearn.model_selection import GridSearchCV
1919
from sklearn.model_selection import KFold
20-
from sklearn.exceptions import NotFittedError
2120
from sklearn.model_selection import cross_val_score
2221
from sklearn.model_selection import train_test_split
2322

@@ -203,6 +202,13 @@ def test_not_fitted():
203202
sclf.predict_proba,
204203
iris.data)
205204

205+
assert_raises(NotFittedError,
206+
"This StackingCVClassifier instance is not fitted yet."
207+
" Call 'fit' with appropriate arguments"
208+
" before using this method.",
209+
sclf.predict_meta_features,
210+
iris.data)
211+
206212

207213
def test_verbose():
208214
np.random.seed(123)
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
2+
# Source: https://github.com/scikit-learn/scikit-learn
3+
4+
"""Utilities for input validation"""
5+
6+
# Authors: Olivier Grisel
7+
# Gael Varoquaux
8+
# Andreas Mueller
9+
# Lars Buitinck
10+
# Alexandre Gramfort
11+
# Nicolas Tresegnie
12+
# License: BSD 3 clause
13+
14+
15+
class NotFittedError(ValueError, AttributeError):
    """Error raised when an estimator is used before it has been fitted.

    The class derives from both ``ValueError`` and ``AttributeError`` so
    that code catching either of those exception types keeps working
    (exception handling convenience and backward compatibility).

    Examples
    --------
    >>> try:
    ...     raise NotFittedError('This estimator is not fitted yet')
    ... except ValueError as e:
    ...     print(e)
    This estimator is not fitted yet

    Notes
    -----
    Vendored from scikit-learn, where the class was moved out of
    ``sklearn.utils.validation`` in version 0.18.

    """
32+
33+
34+
def check_is_fitted(estimator, attributes, msg=None, all_or_any=all):
    """Perform is_fitted validation for estimator.

    Checks if the estimator is fitted by verifying the presence of
    "all_or_any" of the passed attributes and raises a NotFittedError with
    the given message.

    Parameters
    ----------
    estimator : estimator instance.
        Estimator instance for which the check is performed. It must
        expose a ``fit`` attribute.
    attributes : attribute name(s) given as string or a list/tuple of strings
        Eg.:
            ``["coef_", "estimator_", ...], "coef_"``
    msg : string
        The default error message is, "This %(name)s instance is not fitted
        yet. Call 'fit' with appropriate arguments before using this method."
        For custom messages if "%(name)s" is present in the message string,
        it is substituted for the estimator name.
        Eg. : "Estimator, %(name)s, must be fitted before sparsifying".
    all_or_any : callable, {all, any}, default all
        Specify whether all or any of the given attributes must exist.

    Returns
    -------
    None

    Raises
    ------
    TypeError
        If `estimator` has no ``fit`` attribute.
    NotFittedError
        If the attributes are not found.

    """
    if msg is None:
        msg = ("This %(name)s instance is not fitted yet. Call 'fit' with "
               "appropriate arguments before using this method.")

    if not hasattr(estimator, 'fit'):
        # Wrap the argument in a 1-tuple: with a bare `% estimator`, a
        # tuple-valued argument would be unpacked as the format arguments
        # and produce a misleading formatting error instead of this message.
        raise TypeError("%s is not an estimator instance." % (estimator,))

    # Accept a single attribute name as well as a list/tuple of names.
    if not isinstance(attributes, (list, tuple)):
        attributes = [attributes]

    if not all_or_any([hasattr(estimator, attr) for attr in attributes]):
        raise NotFittedError(msg % {'name': type(estimator).__name__})

mlxtend/regressor/stacking_cv_regression.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,15 @@
1313
#
1414
# License: BSD 3 clause
1515

16+
from ..externals.estimator_checks import check_is_fitted
17+
from ..externals import six
18+
from ..externals.name_estimators import _name_estimators
1619
from sklearn.base import BaseEstimator
1720
from sklearn.base import RegressorMixin
1821
from sklearn.base import TransformerMixin
1922
from sklearn.base import clone
20-
from sklearn.exceptions import NotFittedError
2123
from sklearn.model_selection._split import check_cv
22-
from ..externals import six
23-
from ..externals.name_estimators import _name_estimators
24+
2425
import numpy as np
2526

2627

@@ -203,9 +204,7 @@ def predict(self, X):
203204
# the meta-model from that info.
204205
#
205206

206-
if not hasattr(self, 'regr_'):
207-
raise NotFittedError("Estimator not fitted, "
208-
"call `fit` before exploiting the model.")
207+
check_is_fitted(self, 'regr_')
209208

210209
meta_features = np.column_stack([
211210
regr.predict(X) for regr in self.regr_
@@ -233,9 +232,7 @@ def predict_meta_features(self, X):
233232
of regressors.
234233
235234
"""
236-
if not hasattr(self, 'regr_'):
237-
raise NotFittedError("Estimator not fitted, "
238-
"call `fit` before exploiting the model.")
235+
check_is_fitted(self, 'regr_')
239236
return np.column_stack([regr.predict(X) for regr in self.regr_])
240237

241238
def get_params(self, deep=True):

mlxtend/regressor/stacking_regression.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,13 @@
88
#
99
# License: BSD 3 clause
1010

11+
from ..externals.estimator_checks import check_is_fitted
12+
from ..externals.name_estimators import _name_estimators
13+
from ..externals import six
1114
from sklearn.base import BaseEstimator
1215
from sklearn.base import RegressorMixin
1316
from sklearn.base import TransformerMixin
1417
from sklearn.base import clone
15-
from sklearn.exceptions import NotFittedError
16-
from ..externals.name_estimators import _name_estimators
17-
from ..externals import six
1818
import numpy as np
1919

2020

@@ -183,9 +183,7 @@ def predict_meta_features(self, X):
183183
of regressors.
184184
185185
"""
186-
if not hasattr(self, 'regr_'):
187-
raise NotFittedError("Estimator not fitted, "
188-
"call `fit` before exploiting the model.")
186+
check_is_fitted(self, 'regr_')
189187
return np.column_stack([r.predict(X) for r in self.regr_])
190188

191189
def predict(self, X):
@@ -202,5 +200,6 @@ def predict(self, X):
202200
y_target : array-like, shape = [n_samples] or [n_samples, n_targets]
203201
Predicted target values.
204202
"""
203+
check_is_fitted(self, 'regr_')
205204
meta_features = self.predict_meta_features(X)
206205
return self.meta_regr_.predict(meta_features)

mlxtend/regressor/tests/test_stacking_cv_regression.py

Lines changed: 4 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,13 @@
88
# License: BSD 3 clause
99

1010
import numpy as np
11+
from mlxtend.externals.estimator_checks import NotFittedError
1112
from mlxtend.regressor import StackingCVRegressor
13+
from mlxtend.utils import assert_raises
1214
from sklearn.linear_model import LinearRegression
1315
from sklearn.linear_model import Ridge
1416
from sklearn.svm import SVR
15-
from sklearn.exceptions import NotFittedError
1617
from sklearn.model_selection import GridSearchCV, train_test_split
17-
from mlxtend.utils import assert_raises
1818

1919

2020
# Some test data
@@ -180,27 +180,14 @@ def test_not_fitted_predict():
180180
store_train_meta_features=True)
181181
X_train, X_test, y_train, y_test = train_test_split(X2, y, test_size=0.3)
182182

183-
expect = ("Estimator not fitted, "
184-
"call `fit` before exploiting the model.")
183+
expect = ("This StackingCVRegressor instance is not fitted yet. Call "
184+
"'fit' with appropriate arguments before using this method.")
185185

186186
assert_raises(NotFittedError,
187187
expect,
188188
stregr.predict,
189189
X_train)
190190

191-
192-
def test_not_fitted_predict_meta_features():
193-
lr = LinearRegression()
194-
svr_rbf = SVR(kernel='rbf')
195-
ridge = Ridge(random_state=1)
196-
stregr = StackingCVRegressor(regressors=[lr, ridge],
197-
meta_regressor=svr_rbf,
198-
store_train_meta_features=True)
199-
X_train, X_test, y_train, y_test = train_test_split(X2, y, test_size=0.3)
200-
201-
expect = ("Estimator not fitted, "
202-
"call `fit` before exploiting the model.")
203-
204191
assert_raises(NotFittedError,
205192
expect,
206193
stregr.predict_meta_features,

mlxtend/regressor/tests/test_stacking_regression.py

Lines changed: 7 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,18 @@
44
#
55
# License: BSD 3 clause
66

7+
from mlxtend.externals.estimator_checks import NotFittedError
8+
from mlxtend.utils import assert_raises
79
from mlxtend.regressor import StackingRegressor
810
from sklearn.linear_model import LinearRegression
911
from sklearn.linear_model import Ridge
1012
from sklearn.svm import SVR
13+
from sklearn.model_selection import GridSearchCV
14+
from sklearn.model_selection import train_test_split
1115
import numpy as np
1216
from numpy.testing import assert_almost_equal
1317
from nose.tools import raises
14-
from sklearn.model_selection import GridSearchCV
15-
from sklearn.model_selection import train_test_split
16-
from sklearn.exceptions import NotFittedError
17-
from mlxtend.utils import assert_raises
18+
1819

1920

2021
# Generating a sample dataset
@@ -220,27 +221,14 @@ def test_not_fitted_predict():
220221
store_train_meta_features=True)
221222
X_train, X_test, y_train, y_test = train_test_split(X2, y, test_size=0.3)
222223

223-
expect = ("Estimator not fitted, "
224-
"call `fit` before exploiting the model.")
224+
expect = ("This StackingRegressor instance is not fitted yet. Call "
225+
"'fit' with appropriate arguments before using this method.")
225226

226227
assert_raises(NotFittedError,
227228
expect,
228229
stregr.predict,
229230
X_train)
230231

231-
232-
def test_not_fitted_predict_meta_features():
233-
lr = LinearRegression()
234-
svr_rbf = SVR(kernel='rbf')
235-
ridge = Ridge(random_state=1)
236-
stregr = StackingRegressor(regressors=[lr, ridge],
237-
meta_regressor=svr_rbf,
238-
store_train_meta_features=True)
239-
X_train, X_test, y_train, y_test = train_test_split(X2, y, test_size=0.3)
240-
241-
expect = ("Estimator not fitted, "
242-
"call `fit` before exploiting the model.")
243-
244232
assert_raises(NotFittedError,
245233
expect,
246234
stregr.predict_meta_features,

0 commit comments

Comments
 (0)