1111import numpy as np
1212import skrub
1313from numpy .typing import ArrayLike
14- from sklearn .base import BaseEstimator , clone
14+ from sklearn .base import clone
1515from sklearn .exceptions import NotFittedError
1616from sklearn .pipeline import Pipeline
1717from sklearn .utils ._response import (
@@ -52,7 +52,7 @@ def _check_estimator_and_data(
5252 y_test : ArrayLike | None ,
5353 train_data : dict | None ,
5454 test_data : dict | None ,
55- ) -> tuple [bool , EstimatorLike , dict | None , dict | None ]:
55+ ) -> tuple [bool , skrub . SkrubLearner , dict | None , dict | None ]:
5656 """Check and validate the estimator and data."""
5757 if is_skrub_learner (estimator ):
5858 initialized_with_data_op = True
@@ -140,7 +140,13 @@ class EstimatorReport(_BaseReport, DirNamesMixin):
140140 Attributes
141141 ----------
142142 estimator_ : estimator object
143- The cloned or copied estimator.
143+ The fitted estimator.
144+
145+ original_estimator : estimator object
146+ The estimator that was given as input.
147+
148+ learner_ : skrub.SkrubLearner
149+ The estimator wrapped in a skrub Learner.
144150
145151 estimator_name_ : str
146152 The name of the estimator.
@@ -230,7 +236,7 @@ def __init__(
230236 ) -> None :
231237 super ().__init__ ()
232238 estimator = self ._copy_estimator (estimator )
233- self ._raw_estimator = estimator
239+ self ._original_estimator = estimator
234240 self ._fit = fit
235241
236242 if isinstance (estimator , skrub .DataOp ):
@@ -254,17 +260,17 @@ def __init__(
254260 if fit == "auto" :
255261 try :
256262 check_is_fitted (estimator )
257- self ._estimator = estimator
263+ self .learner_ = estimator
258264 except NotFittedError :
259- self ._estimator , self ._fit_time = self ._fit_estimator (
265+ self .learner_ , self ._fit_time = self ._fit_estimator (
260266 estimator , self ._train_data
261267 )
262268 elif fit is True :
263- self ._estimator , self ._fit_time = self ._fit_estimator (
269+ self .learner_ , self ._fit_time = self ._fit_estimator (
264270 estimator , self ._train_data
265271 )
266272 else : # fit is False
267- self ._estimator = estimator
273+ self .learner_ = estimator
268274
269275 self ._pos_label = pos_label
270276 self .fit_time_ = self ._fit_time
@@ -320,12 +326,12 @@ def get_state(self) -> dict[str, Any]:
320326 # -------- CORE STATE ---------
321327 "metadata" : self ._metadata ,
322328 "initialized_with_data_op" : self ._initialized_with_data_op ,
323- "raw_estimator " : self ._raw_estimator ,
329+ "original_estimator " : self ._original_estimator ,
324330 "ml_task" : self ._ml_task ,
325331 "fit" : self ._fit ,
326332 "fit_time" : self .fit_time_ ,
327333 "pos_label" : self ._pos_label ,
328- "estimator" : self ._estimator ,
334+ "estimator" : self .learner_ ,
329335 "data" : {
330336 "train_data" : self ._train_data ,
331337 "test_data" : self ._test_data ,
@@ -360,8 +366,8 @@ def from_state(cls, state: dict[str, Any]) -> EstimatorReport:
360366 report ._fit = state ["fit" ]
361367 report .fit_time_ = state ["fit_time" ]
362368 report ._pos_label = state ["pos_label" ]
363- report ._estimator = state ["estimator" ]
364- report ._raw_estimator = state ["raw_estimator " ]
369+ report .learner_ = state ["estimator" ]
370+ report ._original_estimator = state ["original_estimator " ]
365371 data = state ["data" ]
366372 report ._train_data = data ["train_data" ]
367373 report ._test_data = data ["test_data" ]
@@ -451,11 +457,11 @@ def cache_predictions(
451457 # from decision_function/predict_proba:
452458 if not self ._can_skip_predict :
453459 with MeasureTime () as pred_time :
454- self ._cache [pred_key ] = self ._estimator .predict (data )
460+ self ._cache [pred_key ] = self .learner_ .predict (data )
455461 self ._cache [time_key ] = pred_time ()
456462
457- has_proba = hasattr (self ._estimator , "predict_proba" )
458- has_decision = hasattr (self ._estimator , "decision_function" )
463+ has_proba = hasattr (self .learner_ , "predict_proba" )
464+ has_decision = hasattr (self .learner_ , "decision_function" )
459465
460466 if not (has_proba or has_decision ):
461467 return
@@ -517,8 +523,8 @@ def _get_response_and_derived_predictions(
517523 Time spent computing ``response_method(data)`` in seconds.
518524 """
519525 with MeasureTime () as pred_time :
520- response = getattr (self ._estimator , response_method )(data )
521- classes = self ._estimator .classes_
526+ response = getattr (self .learner_ , response_method )(data )
527+ classes = self .learner_ .classes_
522528 if response_method == "decision_function" :
523529 if self .ml_task == "binary-classification" :
524530 response = np .vstack ((- response , response )).T
@@ -538,7 +544,7 @@ def _can_skip_predict(self) -> bool:
538544 """
539545 response_methods = ["decision_function" , "predict_proba" ]
540546 try :
541- method = _check_response_method (self ._estimator , response_methods )
547+ method = _check_response_method (self .learner_ , response_methods )
542548 except AttributeError :
543549 return False
544550 data = self .train_data if self .test_data is None else self .test_data
@@ -557,7 +563,7 @@ def _can_skip_predict(self) -> bool:
557563 sampled_data = data | {"_skrub_X" : X_sample }
558564
559565 # probe:
560- predictions = self ._estimator .predict (sampled_data )
566+ predictions = self .learner_ .predict (sampled_data )
561567 _ , deduced_predictions , _ = self ._get_response_and_derived_predictions (
562568 sampled_data ,
563569 response_method = method .__name__ ,
@@ -727,14 +733,16 @@ def ml_task(self):
727733 return self ._ml_task
728734
729735 @property
730- def estimator (self ) -> BaseEstimator :
731- return self .estimator_
736+ def original_estimator (self ) -> EstimatorLike :
737+ """The estimator that was given as input."""
738+ return self ._original_estimator
732739
733740 @property
734- def estimator_ (self ) -> BaseEstimator :
741+ def estimator_ (self ) -> EstimatorLike :
742+ """The report's fitted estimator."""
735743 if self ._initialized_with_data_op :
736- return self ._estimator
737- return to_estimator (self ._estimator )
744+ return self .learner_
745+ return to_estimator (self .learner_ )
738746
739747 @property
740748 def X_train (self ) -> ArrayLike | None :
@@ -770,10 +778,10 @@ def fit(self) -> str | bool:
770778
771779 @property
772780 def estimator_name_ (self ) -> str :
773- if isinstance (self ._raw_estimator , Pipeline ):
774- name = self ._raw_estimator [- 1 ].__class__ .__name__
781+ if isinstance (self ._original_estimator , Pipeline ):
782+ name = self ._original_estimator [- 1 ].__class__ .__name__
775783 else :
776- name = self ._raw_estimator .__class__ .__name__
784+ name = self ._original_estimator .__class__ .__name__
777785 return name
778786
779787 ####################################################################################
0 commit comments