99import skrub
1010from joblib import Parallel
1111from numpy .typing import ArrayLike
12- from sklearn .base import BaseEstimator , clone , is_classifier
12+ from sklearn .base import clone , is_classifier
1313from sklearn .model_selection import check_cv
1414from sklearn .pipeline import Pipeline
1515
@@ -66,7 +66,7 @@ def _check_estimator_and_data(
6666 X : ArrayLike | None ,
6767 y : ArrayLike | None ,
6868 data : dict | None ,
69- ) -> tuple [bool , EstimatorLike , dict ]:
69+ ) -> tuple [bool , skrub . SkrubLearner , dict ]:
7070 if is_skrub_learner (estimator ):
7171 initialized_with_data_op = True
7272 if X is not None or y is not None :
@@ -153,7 +153,13 @@ class CrossValidationReport(_BaseReport, DirNamesMixin):
153153 Attributes
154154 ----------
155155 estimator_ : estimator object
156- The cloned or copied estimator.
156+ The fitted estimator.
157+
158+ estimator : estimator object
159+ The estimator that was given as input.
160+
161+ learner_ : skrub.SkrubLearner
162+ The estimator wrapped in a skrub Learner.
157163
158164 estimator_name_ : str
159165 The name of the estimator.
@@ -204,7 +210,7 @@ def __init__(
204210 n_jobs : int | None = None ,
205211 ) -> None :
206212 super ().__init__ ()
207- self ._raw_estimator = estimator
213+ self ._original_estimator = estimator
208214 if isinstance (estimator , skrub .DataOp ):
209215 if data is None :
210216 data = estimator .skb .get_data ()
@@ -214,7 +220,7 @@ def __init__(
214220 "Clustering models are not supported yet. Please use a"
215221 " classification or regression model instead."
216222 )
217- self ._initialized_with_data_op , self ._estimator , self ._data = (
223+ self ._initialized_with_data_op , self .learner_ , self ._data = (
218224 _check_estimator_and_data (clone (estimator ), X , y , data )
219225 )
220226 self ._pos_label = pos_label
@@ -244,12 +250,12 @@ def _fit_estimator_reports(self) -> list[EstimatorReport]:
244250 track (
245251 parallel (
246252 delayed (EstimatorReport )(
247- clone (self ._estimator ),
253+ clone (self .learner_ ),
248254 train_data = split ["train" ],
249255 test_data = split ["test" ],
250256 pos_label = self ._pos_label ,
251257 )
252- for split in self ._estimator .data_op .skb .iter_cv_splits (
258+ for split in self .learner_ .data_op .skb .iter_cv_splits (
253259 environment = self ._data , cv = self .split_indices
254260 )
255261 ),
@@ -264,7 +270,7 @@ def _fit_estimator_reports(self) -> list[EstimatorReport]:
264270 track (
265271 parallel (
266272 delayed (_generate_estimator_report )(
267- clone (self ._raw_estimator ),
273+ clone (self ._original_estimator ),
268274 self .X ,
269275 self .y ,
270276 self .pos_label ,
@@ -296,10 +302,10 @@ def get_state(self) -> dict[str, Any]:
296302 "version" : _STATE_VERSION ,
297303 "metadata" : self ._metadata ,
298304 "initialized_with_data_op" : self ._initialized_with_data_op ,
299- "raw_estimator " : self ._raw_estimator ,
305+ "original_estimator " : self ._original_estimator ,
300306 "ml_task" : self .ml_task ,
301307 "pos_label" : self .pos_label ,
302- "estimator" : self ._estimator ,
308+ "estimator" : self .learner_ ,
303309 "data" : self ._data ,
304310 "split_indices" : self ._split_indices ,
305311 "estimator_reports" : sub_states ,
@@ -323,8 +329,8 @@ def from_state(cls, state: dict[str, Any]) -> CrossValidationReport:
323329 report ._initialized_with_data_op = state ["initialized_with_data_op" ]
324330 report ._ml_task = state ["ml_task" ]
325331 report ._pos_label = state ["pos_label" ]
326- report ._estimator = state ["estimator" ]
327- report ._raw_estimator = state ["raw_estimator " ]
332+ report .learner_ = state ["estimator" ]
333+ report ._original_estimator = state ["original_estimator " ]
328334 report ._data = state ["data" ]
329335 report ._split_indices = state ["split_indices" ]
330336 # TODO? Include splitter in state?
@@ -333,7 +339,7 @@ def from_state(cls, state: dict[str, Any]) -> CrossValidationReport:
333339
334340 report .estimator_reports_ = []
335341 if report ._initialized_with_data_op :
336- split_data_iterator = report ._estimator .data_op .skb .iter_cv_splits (
342+ split_data_iterator = report .learner_ .data_op .skb .iter_cv_splits (
337343 environment = report ._data ,
338344 cv = report ._split_indices ,
339345 )
@@ -535,14 +541,14 @@ def create_estimator_report(
535541 """
536542 if self ._initialized_with_data_op :
537543 report = EstimatorReport (
538- self ._estimator ,
544+ self .learner_ ,
539545 train_data = self ._data ,
540546 test_data = test_data ,
541547 pos_label = self ._pos_label ,
542548 )
543549 else :
544550 report = EstimatorReport (
545- self ._raw_estimator ,
551+ self ._original_estimator ,
546552 X_train = self .X ,
547553 y_train = self .y ,
548554 X_test = X_test ,
@@ -586,21 +592,23 @@ def ml_task(self) -> str:
586592 return self ._ml_task
587593
588594 @property
589- def estimator (self ) -> BaseEstimator :
590- return self .estimator_
595+ def original_estimator (self ) -> EstimatorLike :
596+ """The estimator that was given as input."""
597+ return self ._original_estimator
591598
592599 @property
593- def estimator_ (self ) -> BaseEstimator :
600+ def estimator_ (self ) -> EstimatorLike :
601+ """The report's fitted estimator."""
594602 if self ._initialized_with_data_op :
595- return self ._estimator
596- return to_estimator (self ._estimator )
603+ return self .learner_
604+ return to_estimator (self .learner_ )
597605
598606 @property
599607 def estimator_name_ (self ) -> str :
600- if isinstance (self ._raw_estimator , Pipeline ):
601- name = self ._raw_estimator [- 1 ].__class__ .__name__
608+ if isinstance (self ._original_estimator , Pipeline ):
609+ name = self ._original_estimator [- 1 ].__class__ .__name__
602610 else :
603- name = self ._raw_estimator .__class__ .__name__
611+ name = self ._original_estimator .__class__ .__name__
604612 return name
605613
606614 @property
0 commit comments