|
| 1 | + |
| 2 | +# Source: https://github.com/scikit-learn/scikit-learn |
| 3 | + |
| 4 | +"""Utilities for input validation""" |
| 5 | + |
| 6 | +# Authors: Olivier Grisel |
| 7 | +# Gael Varoquaux |
| 8 | +# Andreas Mueller |
| 9 | +# Lars Buitinck |
| 10 | +# Alexandre Gramfort |
| 11 | +# Nicolas Tresegnie |
| 12 | +# License: BSD 3 clause |
| 13 | + |
| 14 | + |
| 15 | +class NotFittedError(ValueError, AttributeError): |
| 16 | + """Exception class to raise if estimator is used before fitting. |
| 17 | + This class inherits from both ValueError and AttributeError to help with |
| 18 | + exception handling and backward compatibility. |
| 19 | + Examples |
| 20 | + -------- |
| 21 | + >>> from sklearn.svm import LinearSVC |
| 22 | + >>> from sklearn.exceptions import NotFittedError |
| 23 | + >>> try: |
| 24 | + ... LinearSVC().predict([[1, 2], [2, 3], [3, 4]]) |
| 25 | + ... except NotFittedError as e: |
| 26 | + ... print(repr(e)) |
| 27 | + ... # doctest: +NORMALIZE_WHITESPACE +ELLIPSIS |
| 28 | + NotFittedError('This LinearSVC instance is not fitted yet',) |
| 29 | + .. versionchanged:: 0.18 |
| 30 | + Moved from sklearn.utils.validation. |
| 31 | + """ |
| 32 | + |
| 33 | + |
| 34 | +def check_is_fitted(estimator, attributes, msg=None, all_or_any=all): |
| 35 | + """Perform is_fitted validation for estimator. |
| 36 | + Checks if the estimator is fitted by verifying the presence of |
| 37 | + "all_or_any" of the passed attributes and raises a NotFittedError with the |
| 38 | + given message. |
| 39 | + Parameters |
| 40 | + ---------- |
| 41 | + estimator : estimator instance. |
| 42 | + estimator instance for which the check is performed. |
| 43 | + attributes : attribute name(s) given as string or a list/tuple of strings |
| 44 | + Eg.: |
| 45 | + ``["coef_", "estimator_", ...], "coef_"`` |
| 46 | + msg : string |
| 47 | + The default error message is, "This %(name)s instance is not fitted |
| 48 | + yet. Call 'fit' with appropriate arguments before using this method." |
| 49 | + For custom messages if "%(name)s" is present in the message string, |
| 50 | + it is substituted for the estimator name. |
| 51 | + Eg. : "Estimator, %(name)s, must be fitted before sparsifying". |
| 52 | + all_or_any : callable, {all, any}, default all |
| 53 | + Specify whether all or any of the given attributes must exist. |
| 54 | + Returns |
| 55 | + ------- |
| 56 | + None |
| 57 | + Raises |
| 58 | + ------ |
| 59 | + NotFittedError |
| 60 | + If the attributes are not found. |
| 61 | + """ |
| 62 | + if msg is None: |
| 63 | + msg = ("This %(name)s instance is not fitted yet. Call 'fit' with " |
| 64 | + "appropriate arguments before using this method.") |
| 65 | + |
| 66 | + if not hasattr(estimator, 'fit'): |
| 67 | + raise TypeError("%s is not an estimator instance." % (estimator)) |
| 68 | + |
| 69 | + if not isinstance(attributes, (list, tuple)): |
| 70 | + attributes = [attributes] |
| 71 | + |
| 72 | + if not all_or_any([hasattr(estimator, attr) for attr in attributes]): |
| 73 | + raise NotFittedError(msg % {'name': type(estimator).__name__}) |
0 commit comments