Skip to content

Commit cca956b

Browse files
refactor: Rename EstimatorReport._estimator to EstimatorReport.learner_
Previously we had `estimator`, `estimator_` and `_estimator`.
1 parent c6a52a3 commit cca956b

16 files changed

Lines changed: 98 additions & 107 deletions

File tree

skore-mlflow-project/src/skore_mlflow_project/protocol.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ class EstimatorReport(Protocol):
1717
"""Protocol equivalent to ``skore.EstimatorReport``."""
1818

1919
ml_task: str
20-
estimator: BaseEstimator
20+
original_estimator: BaseEstimator
2121
estimator_: BaseEstimator
2222
estimator_name_: str
2323
X_train: DatasetLike | None
@@ -33,7 +33,7 @@ class CrossValidationReport(Protocol):
3333
"""Protocol equivalent to ``skore.CrossValidationReport``."""
3434

3535
ml_task: str
36-
estimator: BaseEstimator
36+
original_estimator: BaseEstimator
3737
estimator_: BaseEstimator
3838
estimator_name_: str
3939
estimator_reports_: list[EstimatorReport]

skore-mlflow-project/src/skore_mlflow_project/reports.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ def iter_cv(report: CrossValidationReport) -> Generator[NestedLogItem, None, Non
199199
"""Yield loggable objects for a cross-validation report."""
200200
yield from iter_cv_metrics(report)
201201

202-
estimator = clone(report.estimator).fit(report.X, report.y)
202+
estimator = clone(report.original_estimator).fit(report.X, report.y)
203203
yield Params(estimator.get_params())
204204
yield Model(estimator, _sample_input_example(report.X))
205205

skore/src/skore/_plugins/hub/artifact/media/model.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,10 @@ class EstimatorHtmlRepr(Media[Report]): # noqa: D101
1313
def content_to_upload(self) -> str: # noqa: D102
1414
import sklearn.utils
1515

16+
# FIXME: Unclear if we want to repr of estimator_
17+
# or the original_estimator
1618
estimator_html_repr: str = sklearn.utils.estimator_html_repr(
17-
self.report.estimator
19+
self.report.estimator_
1820
)
1921

2022
return estimator_html_repr

skore/src/skore/_sklearn/_checks/model_checks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ def check_function(self, report: _BaseReport) -> str | None:
290290
isinstance(report, EstimatorReport)
291291
and report.X_train is not None
292292
and report.X_test is not None
293-
and hasattr(report.estimator, "coef_")
293+
and hasattr(report.estimator_, "coef_")
294294
):
295295
raise CheckNotApplicable()
296296

skore/src/skore/_sklearn/_comparison/metrics_accessor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1182,10 +1182,10 @@ def confusion_matrix(
11821182
"""
11831183
do_thresholds = True
11841184
if not all(
1185-
hasattr(report._estimator, "predict_proba")
1185+
hasattr(report.learner_, "predict_proba")
11861186
for report in self._parent.reports_.values()
11871187
) and not all(
1188-
hasattr(report._estimator, "decision_function")
1188+
hasattr(report.learner_, "decision_function")
11891189
for report in self._parent.reports_.values()
11901190
):
11911191
warnings.warn(

skore/src/skore/_sklearn/_comparison/report.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -523,7 +523,7 @@ def create_estimator_report(
523523
X_train, y_train = estimator_report.X_train, estimator_report.y_train
524524

525525
return EstimatorReport(
526-
estimator_report._raw_estimator,
526+
estimator_report._original_estimator,
527527
fit=True,
528528
X_train=X_train,
529529
y_train=y_train,

skore/src/skore/_sklearn/_cross_validation/report.py

Lines changed: 31 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import skrub
1010
from joblib import Parallel
1111
from numpy.typing import ArrayLike
12-
from sklearn.base import BaseEstimator, clone, is_classifier
12+
from sklearn.base import clone, is_classifier
1313
from sklearn.model_selection import check_cv
1414
from sklearn.pipeline import Pipeline
1515

@@ -66,7 +66,7 @@ def _check_estimator_and_data(
6666
X: ArrayLike | None,
6767
y: ArrayLike | None,
6868
data: dict | None,
69-
) -> tuple[bool, EstimatorLike, dict]:
69+
) -> tuple[bool, skrub.SkrubLearner, dict]:
7070
if is_skrub_learner(estimator):
7171
initialized_with_data_op = True
7272
if X is not None or y is not None:
@@ -153,7 +153,13 @@ class CrossValidationReport(_BaseReport, DirNamesMixin):
153153
Attributes
154154
----------
155155
estimator_ : estimator object
156-
The cloned or copied estimator.
156+
The fitted estimator.
157+
158+
estimator : estimator object
159+
The estimator that was given as input.
160+
161+
learner_ : skrub.SkrubLearner
162+
The estimator wrapped in a skrub Learner.
157163
158164
estimator_name_ : str
159165
The name of the estimator.
@@ -204,7 +210,7 @@ def __init__(
204210
n_jobs: int | None = None,
205211
) -> None:
206212
super().__init__()
207-
self._raw_estimator = estimator
213+
self._original_estimator = estimator
208214
if isinstance(estimator, skrub.DataOp):
209215
if data is None:
210216
data = estimator.skb.get_data()
@@ -214,7 +220,7 @@ def __init__(
214220
"Clustering models are not supported yet. Please use a"
215221
" classification or regression model instead."
216222
)
217-
self._initialized_with_data_op, self._estimator, self._data = (
223+
self._initialized_with_data_op, self.learner_, self._data = (
218224
_check_estimator_and_data(clone(estimator), X, y, data)
219225
)
220226
self._pos_label = pos_label
@@ -244,12 +250,12 @@ def _fit_estimator_reports(self) -> list[EstimatorReport]:
244250
track(
245251
parallel(
246252
delayed(EstimatorReport)(
247-
clone(self._estimator),
253+
clone(self.learner_),
248254
train_data=split["train"],
249255
test_data=split["test"],
250256
pos_label=self._pos_label,
251257
)
252-
for split in self._estimator.data_op.skb.iter_cv_splits(
258+
for split in self.learner_.data_op.skb.iter_cv_splits(
253259
environment=self._data, cv=self.split_indices
254260
)
255261
),
@@ -264,7 +270,7 @@ def _fit_estimator_reports(self) -> list[EstimatorReport]:
264270
track(
265271
parallel(
266272
delayed(_generate_estimator_report)(
267-
clone(self._raw_estimator),
273+
clone(self._original_estimator),
268274
self.X,
269275
self.y,
270276
self.pos_label,
@@ -296,10 +302,10 @@ def get_state(self) -> dict[str, Any]:
296302
"version": _STATE_VERSION,
297303
"metadata": self._metadata,
298304
"initialized_with_data_op": self._initialized_with_data_op,
299-
"raw_estimator": self._raw_estimator,
305+
"original_estimator": self._original_estimator,
300306
"ml_task": self.ml_task,
301307
"pos_label": self.pos_label,
302-
"estimator": self._estimator,
308+
"estimator": self.learner_,
303309
"data": self._data,
304310
"split_indices": self._split_indices,
305311
"estimator_reports": sub_states,
@@ -323,8 +329,8 @@ def from_state(cls, state: dict[str, Any]) -> CrossValidationReport:
323329
report._initialized_with_data_op = state["initialized_with_data_op"]
324330
report._ml_task = state["ml_task"]
325331
report._pos_label = state["pos_label"]
326-
report._estimator = state["estimator"]
327-
report._raw_estimator = state["raw_estimator"]
332+
report.learner_ = state["estimator"]
333+
report._original_estimator = state["original_estimator"]
328334
report._data = state["data"]
329335
report._split_indices = state["split_indices"]
330336
# TODO? Include splitter in state?
@@ -333,7 +339,7 @@ def from_state(cls, state: dict[str, Any]) -> CrossValidationReport:
333339

334340
report.estimator_reports_ = []
335341
if report._initialized_with_data_op:
336-
split_data_iterator = report._estimator.data_op.skb.iter_cv_splits(
342+
split_data_iterator = report.learner_.data_op.skb.iter_cv_splits(
337343
environment=report._data,
338344
cv=report._split_indices,
339345
)
@@ -535,14 +541,14 @@ def create_estimator_report(
535541
"""
536542
if self._initialized_with_data_op:
537543
report = EstimatorReport(
538-
self._estimator,
544+
self.learner_,
539545
train_data=self._data,
540546
test_data=test_data,
541547
pos_label=self._pos_label,
542548
)
543549
else:
544550
report = EstimatorReport(
545-
self._raw_estimator,
551+
self._original_estimator,
546552
X_train=self.X,
547553
y_train=self.y,
548554
X_test=X_test,
@@ -586,21 +592,23 @@ def ml_task(self) -> str:
586592
return self._ml_task
587593

588594
@property
589-
def estimator(self) -> BaseEstimator:
590-
return self.estimator_
595+
def original_estimator(self) -> EstimatorLike:
596+
"""The estimator that was given as input."""
597+
return self._original_estimator
591598

592599
@property
593-
def estimator_(self) -> BaseEstimator:
600+
def estimator_(self) -> EstimatorLike:
601+
"""The report's fitted estimator."""
594602
if self._initialized_with_data_op:
595-
return self._estimator
596-
return to_estimator(self._estimator)
603+
return self.learner_
604+
return to_estimator(self.learner_)
597605

598606
@property
599607
def estimator_name_(self) -> str:
600-
if isinstance(self._raw_estimator, Pipeline):
601-
name = self._raw_estimator[-1].__class__.__name__
608+
if isinstance(self._original_estimator, Pipeline):
609+
name = self._original_estimator[-1].__class__.__name__
602610
else:
603-
name = self._raw_estimator.__class__.__name__
611+
name = self._original_estimator.__class__.__name__
604612
return name
605613

606614
@property

skore/src/skore/_sklearn/_estimator/metrics_accessor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1163,8 +1163,8 @@ def confusion_matrix(
11631163
"data_source='both' is not supported for confusion_matrix."
11641164
)
11651165

1166-
if hasattr(self._parent._estimator, "predict_proba") or hasattr(
1167-
self._parent._estimator, "decision_function"
1166+
if hasattr(self._parent.learner_, "predict_proba") or hasattr(
1167+
self._parent.learner_, "decision_function"
11681168
):
11691169
y_scores = self._parent._get_predictions(
11701170
data_source=data_source,

0 commit comments

Comments
 (0)