Skip to content

Commit b32e9fa

Browse files
refactor: Rename EstimatorReport._estimator to EstimatorReport.learner_
Previously we had `estimator`, `estimator_` and `_estimator`.
1 parent 58155b7 commit b32e9fa

13 files changed

Lines changed: 94 additions & 103 deletions

File tree

skore/src/skore/_plugins/hub/artifact/media/model.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,10 @@ class EstimatorHtmlRepr(Media[Report]): # noqa: D101
1313
def content_to_upload(self) -> str: # noqa: D102
1414
import sklearn.utils
1515

16+
# FIXME: Unclear if we want to repr of estimator_
17+
# or the original_estimator
1618
estimator_html_repr: str = sklearn.utils.estimator_html_repr(
17-
self.report.estimator
19+
self.report.estimator_
1820
)
1921

2022
return estimator_html_repr

skore/src/skore/_sklearn/_checks/model_checks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ def check_function(self, report: _BaseReport) -> str | None:
290290
isinstance(report, EstimatorReport)
291291
and report.X_train is not None
292292
and report.X_test is not None
293-
and hasattr(report.estimator, "coef_")
293+
and hasattr(report.estimator_, "coef_")
294294
):
295295
raise CheckNotApplicable()
296296

skore/src/skore/_sklearn/_comparison/metrics_accessor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1182,10 +1182,10 @@ def confusion_matrix(
11821182
"""
11831183
do_thresholds = True
11841184
if not all(
1185-
hasattr(report._estimator, "predict_proba")
1185+
hasattr(report.learner_, "predict_proba")
11861186
for report in self._parent.reports_.values()
11871187
) and not all(
1188-
hasattr(report._estimator, "decision_function")
1188+
hasattr(report.learner_, "decision_function")
11891189
for report in self._parent.reports_.values()
11901190
):
11911191
warnings.warn(

skore/src/skore/_sklearn/_comparison/report.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -523,7 +523,7 @@ def create_estimator_report(
523523
X_train, y_train = estimator_report.X_train, estimator_report.y_train
524524

525525
return EstimatorReport(
526-
estimator_report._raw_estimator,
526+
estimator_report._original_estimator,
527527
fit=True,
528528
X_train=X_train,
529529
y_train=y_train,

skore/src/skore/_sklearn/_cross_validation/report.py

Lines changed: 31 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import skrub
1010
from joblib import Parallel
1111
from numpy.typing import ArrayLike
12-
from sklearn.base import BaseEstimator, clone, is_classifier
12+
from sklearn.base import clone, is_classifier
1313
from sklearn.model_selection import check_cv
1414
from sklearn.pipeline import Pipeline
1515

@@ -66,7 +66,7 @@ def _check_estimator_and_data(
6666
X: ArrayLike | None,
6767
y: ArrayLike | None,
6868
data: dict | None,
69-
) -> tuple[bool, EstimatorLike, dict]:
69+
) -> tuple[bool, skrub.SkrubLearner, dict]:
7070
if is_skrub_learner(estimator):
7171
initialized_with_data_op = True
7272
if X is not None or y is not None:
@@ -153,7 +153,13 @@ class CrossValidationReport(_BaseReport, DirNamesMixin):
153153
Attributes
154154
----------
155155
estimator_ : estimator object
156-
The cloned or copied estimator.
156+
The fitted estimator.
157+
158+
estimator : estimator object
159+
The estimator that was given as input.
160+
161+
learner_ : skrub.SkrubLearner
162+
The estimator wrapped in a skrub Learner.
157163
158164
estimator_name_ : str
159165
The name of the estimator.
@@ -204,7 +210,7 @@ def __init__(
204210
n_jobs: int | None = None,
205211
) -> None:
206212
super().__init__()
207-
self._raw_estimator = estimator
213+
self._original_estimator = estimator
208214
if isinstance(estimator, skrub.DataOp):
209215
if data is None:
210216
data = estimator.skb.get_data()
@@ -214,7 +220,7 @@ def __init__(
214220
"Clustering models are not supported yet. Please use a"
215221
" classification or regression model instead."
216222
)
217-
self._initialized_with_data_op, self._estimator, self._data = (
223+
self._initialized_with_data_op, self.learner_, self._data = (
218224
_check_estimator_and_data(clone(estimator), X, y, data)
219225
)
220226
self._pos_label = pos_label
@@ -244,12 +250,12 @@ def _fit_estimator_reports(self) -> list[EstimatorReport]:
244250
track(
245251
parallel(
246252
delayed(EstimatorReport)(
247-
clone(self._estimator),
253+
clone(self.learner_),
248254
train_data=split["train"],
249255
test_data=split["test"],
250256
pos_label=self._pos_label,
251257
)
252-
for split in self._estimator.data_op.skb.iter_cv_splits(
258+
for split in self.learner_.data_op.skb.iter_cv_splits(
253259
environment=self._data, cv=self.split_indices
254260
)
255261
),
@@ -264,7 +270,7 @@ def _fit_estimator_reports(self) -> list[EstimatorReport]:
264270
track(
265271
parallel(
266272
delayed(_generate_estimator_report)(
267-
clone(self._raw_estimator),
273+
clone(self._original_estimator),
268274
self.X,
269275
self.y,
270276
self.pos_label,
@@ -296,10 +302,10 @@ def get_state(self) -> dict[str, Any]:
296302
"version": _STATE_VERSION,
297303
"metadata": self._metadata,
298304
"initialized_with_data_op": self._initialized_with_data_op,
299-
"raw_estimator": self._raw_estimator,
305+
"original_estimator": self._original_estimator,
300306
"ml_task": self.ml_task,
301307
"pos_label": self.pos_label,
302-
"estimator": self._estimator,
308+
"estimator": self.learner_,
303309
"data": self._data,
304310
"split_indices": self._split_indices,
305311
"estimator_reports": sub_states,
@@ -323,8 +329,8 @@ def from_state(cls, state: dict[str, Any]) -> CrossValidationReport:
323329
report._initialized_with_data_op = state["initialized_with_data_op"]
324330
report._ml_task = state["ml_task"]
325331
report._pos_label = state["pos_label"]
326-
report._estimator = state["estimator"]
327-
report._raw_estimator = state["raw_estimator"]
332+
report.learner_ = state["estimator"]
333+
report._original_estimator = state["original_estimator"]
328334
report._data = state["data"]
329335
report._split_indices = state["split_indices"]
330336
# TODO? Include splitter in state?
@@ -333,7 +339,7 @@ def from_state(cls, state: dict[str, Any]) -> CrossValidationReport:
333339

334340
report.estimator_reports_ = []
335341
if report._initialized_with_data_op:
336-
split_data_iterator = report._estimator.data_op.skb.iter_cv_splits(
342+
split_data_iterator = report.learner_.data_op.skb.iter_cv_splits(
337343
environment=report._data,
338344
cv=report._split_indices,
339345
)
@@ -535,14 +541,14 @@ def create_estimator_report(
535541
"""
536542
if self._initialized_with_data_op:
537543
report = EstimatorReport(
538-
self._estimator,
544+
self.learner_,
539545
train_data=self._data,
540546
test_data=test_data,
541547
pos_label=self._pos_label,
542548
)
543549
else:
544550
report = EstimatorReport(
545-
self._raw_estimator,
551+
self._original_estimator,
546552
X_train=self.X,
547553
y_train=self.y,
548554
X_test=X_test,
@@ -586,21 +592,23 @@ def ml_task(self) -> str:
586592
return self._ml_task
587593

588594
@property
589-
def estimator(self) -> BaseEstimator:
590-
return self.estimator_
595+
def original_estimator(self) -> EstimatorLike:
596+
"""The estimator that was given as input."""
597+
return self._original_estimator
591598

592599
@property
593-
def estimator_(self) -> BaseEstimator:
600+
def estimator_(self) -> EstimatorLike:
601+
"""The report's fitted estimator."""
594602
if self._initialized_with_data_op:
595-
return self._estimator
596-
return to_estimator(self._estimator)
603+
return self.learner_
604+
return to_estimator(self.learner_)
597605

598606
@property
599607
def estimator_name_(self) -> str:
600-
if isinstance(self._raw_estimator, Pipeline):
601-
name = self._raw_estimator[-1].__class__.__name__
608+
if isinstance(self._original_estimator, Pipeline):
609+
name = self._original_estimator[-1].__class__.__name__
602610
else:
603-
name = self._raw_estimator.__class__.__name__
611+
name = self._original_estimator.__class__.__name__
604612
return name
605613

606614
@property

skore/src/skore/_sklearn/_estimator/metrics_accessor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1163,8 +1163,8 @@ def confusion_matrix(
11631163
"data_source='both' is not supported for confusion_matrix."
11641164
)
11651165

1166-
if hasattr(self._parent._estimator, "predict_proba") or hasattr(
1167-
self._parent._estimator, "decision_function"
1166+
if hasattr(self._parent.learner_, "predict_proba") or hasattr(
1167+
self._parent.learner_, "decision_function"
11681168
):
11691169
y_scores = self._parent._get_predictions(
11701170
data_source=data_source,

skore/src/skore/_sklearn/_estimator/report.py

Lines changed: 35 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import numpy as np
1212
import skrub
1313
from numpy.typing import ArrayLike
14-
from sklearn.base import BaseEstimator, clone
14+
from sklearn.base import clone
1515
from sklearn.exceptions import NotFittedError
1616
from sklearn.pipeline import Pipeline
1717
from sklearn.utils._response import (
@@ -52,7 +52,7 @@ def _check_estimator_and_data(
5252
y_test: ArrayLike | None,
5353
train_data: dict | None,
5454
test_data: dict | None,
55-
) -> tuple[bool, EstimatorLike, dict | None, dict | None]:
55+
) -> tuple[bool, skrub.SkrubLearner, dict | None, dict | None]:
5656
"""Check and validate the estimator and data."""
5757
if is_skrub_learner(estimator):
5858
initialized_with_data_op = True
@@ -140,7 +140,13 @@ class EstimatorReport(_BaseReport, DirNamesMixin):
140140
Attributes
141141
----------
142142
estimator_ : estimator object
143-
The cloned or copied estimator.
143+
The fitted estimator.
144+
145+
original_estimator : estimator object
146+
The estimator that was given as input.
147+
148+
learner_ : skrub.SkrubLearner
149+
The estimator wrapped in a skrub Learner.
144150
145151
estimator_name_ : str
146152
The name of the estimator.
@@ -230,7 +236,7 @@ def __init__(
230236
) -> None:
231237
super().__init__()
232238
estimator = self._copy_estimator(estimator)
233-
self._raw_estimator = estimator
239+
self._original_estimator = estimator
234240
self._fit = fit
235241

236242
if isinstance(estimator, skrub.DataOp):
@@ -254,17 +260,17 @@ def __init__(
254260
if fit == "auto":
255261
try:
256262
check_is_fitted(estimator)
257-
self._estimator = estimator
263+
self.learner_ = estimator
258264
except NotFittedError:
259-
self._estimator, self._fit_time = self._fit_estimator(
265+
self.learner_, self._fit_time = self._fit_estimator(
260266
estimator, self._train_data
261267
)
262268
elif fit is True:
263-
self._estimator, self._fit_time = self._fit_estimator(
269+
self.learner_, self._fit_time = self._fit_estimator(
264270
estimator, self._train_data
265271
)
266272
else: # fit is False
267-
self._estimator = estimator
273+
self.learner_ = estimator
268274

269275
self._pos_label = pos_label
270276
self.fit_time_ = self._fit_time
@@ -320,12 +326,12 @@ def get_state(self) -> dict[str, Any]:
320326
# -------- CORE STATE ---------
321327
"metadata": self._metadata,
322328
"initialized_with_data_op": self._initialized_with_data_op,
323-
"raw_estimator": self._raw_estimator,
329+
"original_estimator": self._original_estimator,
324330
"ml_task": self._ml_task,
325331
"fit": self._fit,
326332
"fit_time": self.fit_time_,
327333
"pos_label": self._pos_label,
328-
"estimator": self._estimator,
334+
"estimator": self.learner_,
329335
"data": {
330336
"train_data": self._train_data,
331337
"test_data": self._test_data,
@@ -360,8 +366,8 @@ def from_state(cls, state: dict[str, Any]) -> EstimatorReport:
360366
report._fit = state["fit"]
361367
report.fit_time_ = state["fit_time"]
362368
report._pos_label = state["pos_label"]
363-
report._estimator = state["estimator"]
364-
report._raw_estimator = state["raw_estimator"]
369+
report.learner_ = state["estimator"]
370+
report._original_estimator = state["original_estimator"]
365371
data = state["data"]
366372
report._train_data = data["train_data"]
367373
report._test_data = data["test_data"]
@@ -451,11 +457,11 @@ def cache_predictions(
451457
# from decision_function/predict_proba:
452458
if not self._can_skip_predict:
453459
with MeasureTime() as pred_time:
454-
self._cache[pred_key] = self._estimator.predict(data)
460+
self._cache[pred_key] = self.learner_.predict(data)
455461
self._cache[time_key] = pred_time()
456462

457-
has_proba = hasattr(self._estimator, "predict_proba")
458-
has_decision = hasattr(self._estimator, "decision_function")
463+
has_proba = hasattr(self.learner_, "predict_proba")
464+
has_decision = hasattr(self.learner_, "decision_function")
459465

460466
if not (has_proba or has_decision):
461467
return
@@ -517,8 +523,8 @@ def _get_response_and_derived_predictions(
517523
Time spent computing ``response_method(data)`` in seconds.
518524
"""
519525
with MeasureTime() as pred_time:
520-
response = getattr(self._estimator, response_method)(data)
521-
classes = self._estimator.classes_
526+
response = getattr(self.learner_, response_method)(data)
527+
classes = self.learner_.classes_
522528
if response_method == "decision_function":
523529
if self.ml_task == "binary-classification":
524530
response = np.vstack((-response, response)).T
@@ -538,7 +544,7 @@ def _can_skip_predict(self) -> bool:
538544
"""
539545
response_methods = ["decision_function", "predict_proba"]
540546
try:
541-
method = _check_response_method(self._estimator, response_methods)
547+
method = _check_response_method(self.learner_, response_methods)
542548
except AttributeError:
543549
return False
544550
data = self.train_data if self.test_data is None else self.test_data
@@ -557,7 +563,7 @@ def _can_skip_predict(self) -> bool:
557563
sampled_data = data | {"_skrub_X": X_sample}
558564

559565
# probe:
560-
predictions = self._estimator.predict(sampled_data)
566+
predictions = self.learner_.predict(sampled_data)
561567
_, deduced_predictions, _ = self._get_response_and_derived_predictions(
562568
sampled_data,
563569
response_method=method.__name__,
@@ -727,14 +733,16 @@ def ml_task(self):
727733
return self._ml_task
728734

729735
@property
730-
def estimator(self) -> BaseEstimator:
731-
return self.estimator_
736+
def original_estimator(self) -> EstimatorLike:
737+
"""The estimator that was given as input."""
738+
return self._original_estimator
732739

733740
@property
734-
def estimator_(self) -> BaseEstimator:
741+
def estimator_(self) -> EstimatorLike:
742+
"""The report's fitted estimator."""
735743
if self._initialized_with_data_op:
736-
return self._estimator
737-
return to_estimator(self._estimator)
744+
return self.learner_
745+
return to_estimator(self.learner_)
738746

739747
@property
740748
def X_train(self) -> ArrayLike | None:
@@ -770,10 +778,10 @@ def fit(self) -> str | bool:
770778

771779
@property
772780
def estimator_name_(self) -> str:
773-
if isinstance(self._raw_estimator, Pipeline):
774-
name = self._raw_estimator[-1].__class__.__name__
781+
if isinstance(self._original_estimator, Pipeline):
782+
name = self._original_estimator[-1].__class__.__name__
775783
else:
776-
name = self._raw_estimator.__class__.__name__
784+
name = self._original_estimator.__class__.__name__
777785
return name
778786

779787
####################################################################################

0 commit comments

Comments
 (0)