diff --git a/chemotools/model_selection/__init__.py b/chemotools/model_selection/__init__.py new file mode 100644 index 00000000..0288e7f7 --- /dev/null +++ b/chemotools/model_selection/__init__.py @@ -0,0 +1,4 @@ +from ._candidate_selector import CandidateSelector +from ._fitted_model import BaseFittedModel + +__all__ = ["CandidateSelector", "BaseFittedModel"] diff --git a/chemotools/model_selection/_candidate_selector.py b/chemotools/model_selection/_candidate_selector.py new file mode 100644 index 00000000..1a1d7af8 --- /dev/null +++ b/chemotools/model_selection/_candidate_selector.py @@ -0,0 +1,426 @@ +""" +The :mod:`chemotools.model_selection._candidate_selector` module implements +model selection with enhanced candidate evaluation and RMSE metrics. +""" + +# Authors: Nusret Emirhan Salli +# License: MIT + +from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union + +import matplotlib.pyplot as plt +import numpy as np +from sklearn.base import BaseEstimator +from sklearn.model_selection import GridSearchCV +from sklearn.utils.validation import check_is_fitted +import operator + +from ._fitted_model import BaseFittedModel + +__all__ = ["CandidateSelector"] + + +class CandidateSelector(BaseEstimator): + """Model selection wrapper that produces ranked candidates with RMSE metrics. + + Wraps :class:`~sklearn.model_selection.GridSearchCV` and converts every + evaluated parameter set into a :class:`BaseFittedModel` instance with + RMSE-based metrics (RMSECV, RMSE ratio, etc.). + + Parameters + ---------- + estimator : estimator object + A scikit-learn compatible estimator. + param_grid : dict or list of dict + Parameter grid to search. + scoring : str or callable, default=None + Scoring strategy (passed to ``GridSearchCV``). + cv : int, default=5 + Number of cross-validation folds. + n_jobs : int, default=None + Parallelism for grid search. + verbose : int, default=0 + Verbosity level. + return_train_score : bool, default=True + Whether to compute training scores (required for RMSE ratio). + + Attributes + ---------- + cv_results_ : dict + Raw results from the underlying ``GridSearchCV``. + best_estimator_ : estimator + Estimator refitted on the full training set with best parameters. + best_params_ : dict + Parameters that achieved the best score. + best_score_ : float + Best cross-validation score. + candidates_ : list of BaseFittedModel + All evaluated candidates sorted by rank. + """ + + def __init__( + self, + estimator: BaseEstimator, + param_grid: Union[Dict, List[Dict]], + *, + scoring: Optional[Union[str, Callable[..., float]]] = None, + cv: int = 5, + n_jobs: Optional[int] = None, + verbose: int = 0, + return_train_score: bool = True, + ) -> None: + self.estimator = estimator + self.param_grid = param_grid + self.scoring = scoring + self.cv = cv + self.n_jobs = n_jobs + self.verbose = verbose + self.return_train_score = return_train_score + + def fit( + self, X: np.ndarray, y: Optional[np.ndarray] = None, **fit_params + ) -> "CandidateSelector": + """Run grid search and build candidate list.""" + grid = GridSearchCV( + estimator=self.estimator, + param_grid=self.param_grid, + scoring=self.scoring, + cv=self.cv, + n_jobs=self.n_jobs, + verbose=self.verbose, + return_train_score=self.return_train_score, + refit=True, + ) + grid.fit(X, y, **fit_params) + + self.cv_results_ = grid.cv_results_ + self.best_estimator_ = grid.best_estimator_ + self.best_params_ = grid.best_params_ + self.best_score_ = grid.best_score_ + self.best_index_ = grid.best_index_ + + self.candidates_ = self._build_candidates() + return self + + def _build_candidates(self) -> List[BaseFittedModel]: + """Create sorted list of ``BaseFittedModel`` from cv_results_.""" + candidates = [] + for idx in range(len(self.cv_results_["params"])): + candidate = BaseFittedModel.from_cv_results( + estimator=self.estimator, + cv_results=self.cv_results_, + index=idx, + scoring=self.scoring, + ) + candidates.append(candidate) + + candidates.sort(key=lambda c: (c.rank or float("inf"), c.cv_results_index or 0)) + return candidates + + def get_candidates(self, n: Optional[int] = None) -> List[BaseFittedModel]: + """Return top *n* candidates (all if n is None).""" + check_is_fitted(self, ["candidates_"]) + if n is None: + return self.candidates_ + return self.candidates_[:n] + + def get_candidate(self, rank: int = 1) -> BaseFittedModel: + """Return candidate by rank (1 = best).""" + check_is_fitted(self, ["candidates_"]) + for c in self.candidates_: + if c.rank == rank: + return c + raise ValueError(f"No candidate with rank {rank}.") + + def filter_candidates( + self, + metric: str = "rmse_ratio", + threshold: float = 1.1, + mode: str = "<=", + ) -> List[BaseFittedModel]: + """Filter candidates based on a metric threshold. + + Parameters + ---------- + metric : str, default='rmse_ratio' + The metric to filter by. Options: 'rmsecv', 'rmse_train', 'rmse_ratio', + 'mean_test_score', 'std_test_score', 'variance'. + threshold : float, default=1.1 + The threshold value for filtering. + mode : str, default='<=' + Comparison mode. Options: '<=', '>=', '<', '>', '=='. + + Returns + ------- + list of BaseFittedModel + Candidates that match the filter criteria. + + Examples + -------- + >>> # Get candidates with RMSE ratio <= 1.2 (good generalization) + >>> robust = selector.filter_candidates(metric='rmse_ratio', threshold=1.2) + >>> # Get candidates with low variance (stable across folds) + >>> stable = selector.filter_candidates(metric='variance', threshold=0.01) + """ + + check_is_fitted(self, ["candidates_"]) + + ops = { + "<=": operator.le, + ">=": operator.ge, + "<": operator.lt, + ">": operator.gt, + "==": operator.eq, + } + cmp = ops.get(mode) + if cmp is None: + raise ValueError(f"mode must be one of {list(ops.keys())}") + + result = [] + for c in self.candidates_: + val = getattr(c, metric, None) + if val is None: + val = c.to_dict().get(metric) + if val is not None and cmp(val, threshold): + result.append(c) + return result + + def predict(self, X: np.ndarray) -> np.ndarray: + """Predict using best estimator.""" + check_is_fitted(self, ["best_estimator_"]) + return self.best_estimator_.predict(X) + + def score(self, X: np.ndarray, y: np.ndarray) -> float: + """Score using best estimator.""" + check_is_fitted(self, ["best_estimator_"]) + return self.best_estimator_.score(X, y) + + def __len__(self) -> int: + """Return the number of candidates.""" + check_is_fitted(self, ["candidates_"]) + return len(self.candidates_) + + def __iter__(self) -> Iterator[BaseFittedModel]: + """Iterate over candidates.""" + check_is_fitted(self, ["candidates_"]) + return iter(self.candidates_) + + def summary(self, n: int = 10) -> str: + """Return a formatted summary of the top candidates. + + Parameters + ---------- + n : int, default=10 + Number of top candidates to include in summary. + + Returns + ------- + str + Formatted summary string. + """ + check_is_fitted(self, ["candidates_"]) + + lines = [ + "CandidateSelector Summary", + f"{'=' * 60}", + f"Total candidates: {len(self.candidates_)}", + f"Best score: {self.best_score_:.6f}", + f"Best params: {self.best_params_}", + "", + f"Top {min(n, len(self.candidates_))} Candidates:", + f"{'-' * 60}", + ] + + # Header + lines.append( + f"{'Rank':>4} {'RMSECV':>10} {'RMSE_train':>10} {'Ratio':>8} {'Variance':>12}" + ) + lines.append(f"{'-' * 4} {'-' * 10} {'-' * 10} {'-' * 8} {'-' * 12}") + + for c in self.candidates_[:n]: + rmsecv = f"{c.rmsecv:.4f}" if c.rmsecv is not None else "N/A" + rmse_train = f"{c.rmse_train:.4f}" if c.rmse_train is not None else "N/A" + ratio = f"{c.rmse_ratio:.3f}" if c.rmse_ratio is not None else "N/A" + var = f"{c.variance:.2e}" if c.variance is not None else "N/A" + lines.append( + f"{c.rank:>4} {rmsecv:>10} {rmse_train:>10} {ratio:>8} {var:>12}" + ) + + return "\n".join(lines) + + def to_dataframe(self): + """Convert all candidates to a pandas DataFrame. + + Returns + ------- + pandas.DataFrame + DataFrame with all candidate metrics and parameters. + + Raises + ------ + ImportError + If pandas is not installed. + """ + check_is_fitted(self, ["candidates_"]) + + try: + import pandas as pd + except ImportError: + raise ImportError( + "pandas is required for to_dataframe(). Install with: pip install pandas" + ) + + records = [] + for c in self.candidates_: + record = c.to_dict() + # Flatten params into separate columns + for key, value in c.params.items(): + record[f"param_{key}"] = value + records.append(record) + + return pd.DataFrame(records) + + def _create_scatter_plot( + self, + x_metric: str, + y_metric: str, + color_by: Optional[str], + ax, + figsize: Tuple[int, int], + title: str, + xlabel: str, + ylabel: str, + hline: Optional[float] = None, + ): + """Internal helper to create scatter plots with consistent styling.""" + check_is_fitted(self, ["candidates_"]) + + # Auto-detect color_by parameter + if color_by is None and self.candidates_: + color_by = next(iter(self.candidates_[0].params), None) + + # Group data by color_by parameter + groups: Dict[Any, List[Tuple[float, float]]] = {} + for c in self.candidates_: + x_val = getattr(c, x_metric, None) or c.to_dict().get(x_metric) + y_val = getattr(c, y_metric, None) or c.to_dict().get(y_metric) + if x_val is None or y_val is None: + continue + key = c.params.get(color_by) if color_by else c.rank + groups.setdefault(key, []).append((x_val, y_val)) + + if not groups: + raise ValueError( + f"No valid data found for metrics '{x_metric}' and '{y_metric}'." + ) + + if ax is None: + _, ax = plt.subplots(figsize=figsize) + + markers = ["o", "s", "^", "D", "v", "*", "p", "h"] + cmap = plt.colormaps.get_cmap("tab10") + + for idx, key in enumerate(sorted(groups.keys())): + data = groups[key] + ax.scatter( + [d[0] for d in data], + [d[1] for d in data], + marker=markers[idx % len(markers)], + c=[cmap(idx % 10)], + s=80, + label=str(key), + edgecolors="black", + linewidths=0.5, + alpha=0.8, + ) + + if hline is not None: + ax.axhline(y=hline, linestyle="-", color="green", linewidth=2, alpha=0.8) + + ax.set_xlabel(xlabel, fontsize=11) + ax.set_ylabel(ylabel, fontsize=11) + ax.set_title(title, fontsize=12) + + param_label = ( + color_by.split("__")[-1] if color_by and "__" in color_by else color_by + ) + ax.legend(title=param_label or "Group", loc="best", fontsize=9) + ax.grid(True, alpha=0.3) + + return ax + + def plot_cv_metrics( + self, + color_by: Optional[str] = None, + *, + ax=None, + figsize: Tuple[int, int] = (10, 6), + show_ratio_threshold: Optional[float] = 1.0, + title: Optional[str] = None, + ): + """Plot RMSECV vs RMSE ratio for model selection. + + Parameters + ---------- + color_by : str, optional + Parameter name to use for coloring points. If None, auto-detects. + ax : matplotlib.axes.Axes, optional + Axes to plot on. If None, creates a new figure. + figsize : tuple, default=(10, 6) + Figure size if creating a new figure. + show_ratio_threshold : float, default=1.0 + Draws a horizontal line at this RMSE ratio value. + title : str, optional + Custom title for the plot. + + Returns + ------- + ax : matplotlib.axes.Axes + """ + return self._create_scatter_plot( + x_metric="rmsecv", + y_metric="rmse_ratio", + color_by=color_by, + ax=ax, + figsize=figsize, + title=title or "Cross-validation Error vs Overfitting", + xlabel="RMSECV", + ylabel="RMSECV / RMSEC", + hline=show_ratio_threshold, + ) + + def plot_score_vs_variance( + self, + color_by: Optional[str] = None, + *, + ax=None, + figsize: Tuple[int, int] = (10, 6), + title: Optional[str] = None, + ): + """Plot test score vs variance for model selection. + + Parameters + ---------- + color_by : str, optional + Parameter name to use for coloring points. If None, auto-detects. + ax : matplotlib.axes.Axes, optional + Axes to plot on. If None, creates a new figure. + figsize : tuple, default=(10, 6) + Figure size if creating a new figure. + title : str, optional + Custom title for the plot. + + Returns + ------- + ax : matplotlib.axes.Axes + """ + return self._create_scatter_plot( + x_metric="variance", + y_metric="mean_test_score", + color_by=color_by, + ax=ax, + figsize=figsize, + title=title or "Model Stability vs Performance", + xlabel="Variance", + ylabel="Mean Test Score", + ) diff --git a/chemotools/model_selection/_fitted_model.py b/chemotools/model_selection/_fitted_model.py new file mode 100644 index 00000000..80f7d1f8 --- /dev/null +++ b/chemotools/model_selection/_fitted_model.py @@ -0,0 +1,163 @@ +""" +The :mod:`chemotools.model_selection._fitted_model` module implements +a container for storing candidate model information during model selection. +""" + +# Authors: Nusret Emirhan Salli +# License: MIT + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Callable, Dict, Optional, Sequence, Union + +from sklearn.base import BaseEstimator, clone + +__all__ = ["BaseFittedModel"] + +ScoringType = Optional[Union[str, Callable[..., float]]] + + +@dataclass(slots=True) +class BaseFittedModel: + """Container for a single candidate model's results. + + Parameters + ---------- + estimator : BaseEstimator + The estimator instance (scikit-learn compatible). + params : dict + Parameters used when fitting the estimator. + rank : int, optional + Rank assigned by model selection (1 = best). + mean_test_score : float, optional + Mean cross-validation score. + std_test_score : float, optional + Standard deviation of cross-validation score. + mean_train_score : float, optional + Mean training score. + scoring : str or callable, optional + Scoring function used. + cv_results_index : int, optional + Index in ``cv_results_``. + rmsecv : float, optional + Root-mean-square error from cross-validation. + rmse_train : float, optional + Training RMSE. + rmse_ratio : float, optional + Ratio of RMSECV to training RMSE (overfitting indicator). + """ + + estimator: BaseEstimator + params: Dict[str, Any] + rank: Optional[int] = None + mean_test_score: Optional[float] = None + std_test_score: Optional[float] = None + mean_train_score: Optional[float] = None + scoring: ScoringType = None + cv_results_index: Optional[int] = None + rmsecv: Optional[float] = None + rmse_train: Optional[float] = None + rmse_ratio: Optional[float] = None + + def __post_init__(self) -> None: + if self.rank is not None and self.rank < 1: + raise ValueError("rank must be a positive integer.") + if not isinstance(self.params, dict): + raise TypeError("params must be a dictionary.") + + @staticmethod + def _to_native(value: Any) -> Any: + """Convert numpy scalars to native Python types.""" + if hasattr(value, "item"): + try: + return value.item() + except Exception: + return value + return value + + @classmethod + def from_cv_results( + cls, + estimator: BaseEstimator, + cv_results: Dict[str, Sequence[Any]], + index: int, + *, + scoring: ScoringType = None, + ) -> BaseFittedModel: + """Create a ``BaseFittedModel`` from a ``GridSearchCV`` result entry.""" + params = cv_results["params"][index] + estimator_clone = clone(estimator) + estimator_clone.set_params(**params) + + rank_values = cv_results.get("rank_test_score") + rank = ( + int(cls._to_native(rank_values[index])) if rank_values is not None else None + ) + + mean_test = cls._to_native(cv_results["mean_test_score"][index]) + std_test = cls._to_native(cv_results["std_test_score"][index]) + mean_train = ( + cls._to_native(cv_results["mean_train_score"][index]) + if "mean_train_score" in cv_results + else None + ) + + # Calculate RMSE metrics if using neg_root_mean_squared_error + rmsecv, rmse_train_val, rmse_ratio = None, None, None + if isinstance(scoring, str) and "neg_root_mean_squared_error" in scoring: + rmsecv = -mean_test + if mean_train is not None: + rmse_train_val = -mean_train + if rmse_train_val != 0: + rmse_ratio = rmsecv / rmse_train_val + + return cls( + estimator=estimator_clone, + params=params, + rank=rank, + mean_test_score=mean_test, + std_test_score=std_test, + mean_train_score=mean_train, + scoring=scoring, + cv_results_index=index, + rmsecv=rmsecv, + rmse_train=rmse_train_val, + rmse_ratio=rmse_ratio, + ) + + @property + def variance(self) -> Optional[float]: + """Variance of the test score (std_test_score²).""" + return self.std_test_score**2 if self.std_test_score is not None else None + + def to_dict(self) -> Dict[str, Any]: + """Return a dictionary representation of the candidate.""" + return { + "rank": self.rank, + "params": self.params, + "mean_test_score": self.mean_test_score, + "std_test_score": self.std_test_score, + "variance": self.variance, + "mean_train_score": self.mean_train_score, + "rmsecv": self.rmsecv, + "rmse_train": self.rmse_train, + "rmse_ratio": self.rmse_ratio, + } + + def clone_estimator(self) -> BaseEstimator: + """Return a fresh estimator with the stored parameters.""" + cloned = clone(self.estimator) + cloned.set_params(**self.params) + return cloned + + def __repr__(self) -> str: + """Return a concise string representation.""" + parts = [f"Rank {self.rank}"] + if self.rmsecv is not None: + parts.append(f"RMSECV={self.rmsecv:.4f}") + if self.rmse_ratio is not None: + parts.append(f"ratio={self.rmse_ratio:.3f}") + if self.variance is not None: + parts.append(f"var={self.variance:.2e}") + return f"BaseFittedModel({', '.join(parts)})" diff --git a/tests/model_selection/test_candidate_selector.py b/tests/model_selection/test_candidate_selector.py new file mode 100644 index 00000000..1a7adb5e --- /dev/null +++ b/tests/model_selection/test_candidate_selector.py @@ -0,0 +1,273 @@ +import numpy as np +import pytest +from sklearn.linear_model import Ridge + +from chemotools.model_selection import BaseFittedModel, CandidateSelector + + +# -- Fixtures ------------------------------------------------------------------ + + +@pytest.fixture +def fitted_selector(dummy_data_loader): + """Return a fitted CandidateSelector for testing.""" + X, y = dummy_data_loader + selector = CandidateSelector( + estimator=Ridge(random_state=0), + param_grid={"alpha": [0.1, 1.0, 10.0]}, + cv=3, + scoring="neg_root_mean_squared_error", + return_train_score=True, + n_jobs=1, + ) + selector.fit(X, y) + return selector + + +@pytest.fixture +def unfitted_selector(): + """Return an unfitted CandidateSelector for testing.""" + return CandidateSelector( + estimator=Ridge(random_state=0), + param_grid={"alpha": [0.1, 1.0, 10.0]}, + cv=3, + scoring="neg_root_mean_squared_error", + return_train_score=True, + n_jobs=1, + ) + + +# -- Test instantiation and fitting -------------------------------------------- + + +def test_instantiation_and_fit(dummy_data_loader): + """Test that CandidateSelector can be instantiated with valid params and fitted.""" + # Arrange + X, y = dummy_data_loader + selector = CandidateSelector( + estimator=Ridge(random_state=0), + param_grid={"alpha": [0.1, 1.0, 10.0]}, + cv=3, + scoring="neg_root_mean_squared_error", + return_train_score=True, + n_jobs=1, + ) + + # Act + result = selector.fit(X, y) + + # Assert + assert result is selector + assert hasattr(selector, "cv_results_") + assert hasattr(selector, "best_estimator_") + assert hasattr(selector, "best_params_") + assert hasattr(selector, "best_score_") + assert hasattr(selector, "candidates_") + assert len(selector.candidates_) == 3 + + +# -- Test get_candidates and get_candidate ------------------------------------- + + +def test_get_candidates(fitted_selector): + """Test retrieving all candidates and specific candidates by rank.""" + # Act + all_candidates = fitted_selector.get_candidates() + top_2 = fitted_selector.get_candidates(n=2) + best = fitted_selector.get_candidate(rank=1) + + # Assert + assert len(all_candidates) == 3 + assert all(isinstance(c, BaseFittedModel) for c in all_candidates) + assert len(top_2) == 2 + assert isinstance(best, BaseFittedModel) + assert best.rank == 1 + # Verify candidates are sorted by rank + ranks = [c.rank for c in all_candidates] + assert ranks == sorted(ranks) + + +def test_get_candidate_invalid_rank_raises_error(fitted_selector): + """Test that requesting an invalid rank raises an error.""" + with pytest.raises(ValueError, match="No candidate with rank"): + fitted_selector.get_candidate(rank=999) + + +# -- Test filter_candidates ---------------------------------------------------- + + +def test_filter_candidates(fitted_selector): + """Test filtering candidates by different metrics and modes.""" + # Test filter by rmse_ratio with <= mode + filtered_le = fitted_selector.filter_candidates( + metric="rmse_ratio", threshold=2.0, mode="<=" + ) + assert isinstance(filtered_le, list) + assert all(isinstance(c, BaseFittedModel) for c in filtered_le) + + # Test filter with >= mode + filtered_ge = fitted_selector.filter_candidates( + metric="rmse_ratio", threshold=0.5, mode=">=" + ) + assert all(c.rmse_ratio >= 0.5 for c in filtered_ge if c.rmse_ratio is not None) + + +def test_filter_candidates_invalid_mode_raises_error(fitted_selector): + """Test that an invalid filter mode raises an error.""" + with pytest.raises(ValueError, match="mode must be one of"): + fitted_selector.filter_candidates(metric="rmse_ratio", threshold=1.0, mode="!=") + + +# -- Test predict and score ---------------------------------------------------- + + +def test_predict_and_score(fitted_selector, dummy_data_loader): + """Test prediction and scoring methods.""" + X, y = dummy_data_loader + + # Act + predictions = fitted_selector.predict(X) + score = fitted_selector.score(X, y) + + # Assert + assert isinstance(predictions, np.ndarray) + assert predictions.shape[0] == X.shape[0] + assert isinstance(score, float) + + +# -- Test __len__ and __iter__ ------------------------------------------------- + + +def test_len_and_iter(fitted_selector): + """Test length and iteration over candidates.""" + # Act + length = len(fitted_selector) + candidates = list(fitted_selector) + + # Assert + assert length == 3 + assert len(candidates) == 3 + assert all(isinstance(c, BaseFittedModel) for c in candidates) + + +# -- Test summary and to_dataframe --------------------------------------------- + + +def test_summary(fitted_selector): + """Test summary method returns expected format.""" + # Act + summary = fitted_selector.summary() + summary_n = fitted_selector.summary(n=2) + + # Assert + assert isinstance(summary, str) + assert "CandidateSelector Summary" in summary + assert "Total candidates:" in summary + assert "Best score:" in summary + assert "Top 2 Candidates:" in summary_n + + +# -- Test RMSE metrics and candidate properties -------------------------------- + + +def test_candidates_have_rmse_metrics(fitted_selector): + """Test that candidates have valid RMSE metrics.""" + candidates = fitted_selector.get_candidates() + + for c in candidates: + assert c.rmsecv is not None and c.rmsecv > 0 + assert c.rmse_train is not None and c.rmse_train > 0 + assert c.rmse_ratio is not None and c.rmse_ratio > 0 + + +def test_candidate_clone_estimator(fitted_selector, dummy_data_loader): + """Test cloning an estimator from a candidate.""" + X, y = dummy_data_loader + best_candidate = fitted_selector.get_candidate(rank=1) + + # Act + cloned = best_candidate.clone_estimator() + cloned.fit(X, y) + predictions = cloned.predict(X[:3]) + + # Assert + assert predictions.shape == (3,) + + +# -- Test plot methods --------------------------------------------------------- + + +def test_plot_cv_metrics(fitted_selector): + """Test plot_cv_metrics returns valid axes.""" + ax = fitted_selector.plot_cv_metrics() + + assert ax is not None + assert hasattr(ax, "get_xlabel") + assert ax.get_xlabel() == "RMSECV" + assert ax.get_ylabel() == "RMSECV / RMSEC" + + +def test_plot_score_vs_variance(fitted_selector): + """Test plot_score_vs_variance returns valid axes.""" + ax = fitted_selector.plot_score_vs_variance() + + assert ax is not None + assert ax.get_xlabel() == "Variance" + assert ax.get_ylabel() == "Mean Test Score" + + +# -- Test unfitted selector raises errors -------------------------------------- + + +def test_unfitted_selector_raises_errors(unfitted_selector, dummy_data_loader): + """Test that unfitted selector raises errors for methods requiring fit.""" + X, y = dummy_data_loader + + with pytest.raises(Exception): + unfitted_selector.get_candidates() + + with pytest.raises(Exception): + unfitted_selector.get_candidate(rank=1) + + with pytest.raises(Exception): + unfitted_selector.filter_candidates() + + with pytest.raises(Exception): + unfitted_selector.predict(X) + + with pytest.raises(Exception): + unfitted_selector.score(X, y) + + with pytest.raises(Exception): + len(unfitted_selector) + + with pytest.raises(Exception): + list(unfitted_selector) + + with pytest.raises(Exception): + unfitted_selector.summary() + + with pytest.raises(Exception): + unfitted_selector.to_dataframe() + + with pytest.raises(Exception): + unfitted_selector.plot_cv_metrics() + + with pytest.raises(Exception): + unfitted_selector.plot_score_vs_variance() + + +# -- Test best_estimator_ usage ------------------------------------------------ + + +def test_best_estimator_and_params(fitted_selector, dummy_data_loader): + """Test that best_estimator_ is fitted and best_params_ matches best candidate.""" + X, _ = dummy_data_loader + best_candidate = fitted_selector.get_candidate(rank=1) + + # Act + predictions = fitted_selector.best_estimator_.predict(X) + + # Assert + assert predictions.shape[0] == X.shape[0] + assert fitted_selector.best_params_ == best_candidate.params