From 711bbbd1e359bd902f731c1f4864dedb1e48f8f4 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Sun, 26 Apr 2026 21:28:38 +0200 Subject: [PATCH] chore(skore,skore-*-project)!: Remove clear_cache and cache_predictions from public API --- CHANGELOG.rst | 7 ++ .../model_evaluation/plot_estimator_report.py | 62 +++-------- .../technical_details/plot_cache_mechanism.py | 100 +++++++----------- examples/use_cases/plot_employee_salaries.py | 28 ++--- .../artifact/pickle/pickle.py | 2 +- .../src/skore_hub_project/protocol.py | 8 +- .../src/skore_hub_project/report/report.py | 2 +- .../unit/artifact/media/test_inspection.py | 2 +- .../report/test_cross_validation_report.py | 4 +- .../unit/report/test_estimator_report.py | 2 +- .../tests/unit/test_project.py | 8 +- .../_sklearn/_comparison/metrics_accessor.py | 2 +- .../src/skore/_sklearn/_comparison/report.py | 46 ++------ .../_cross_validation/metrics_accessor.py | 2 +- .../_sklearn/_cross_validation/report.py | 39 ++----- .../_sklearn/_estimator/metrics_accessor.py | 2 +- skore/src/skore/_sklearn/_estimator/report.py | 44 ++------ .../cross_validation/metrics/test_numeric.py | 2 +- .../cross_validation/metrics/test_numeric.py | 2 +- .../reports/cross_validation/test_report.py | 14 +-- .../unit/reports/estimator/test_report.py | 10 +- sphinx/reference/report/comparison_report.rst | 2 - .../report/cross_validation_report.rst | 2 - sphinx/reference/report/estimator_report.rst | 2 - sphinx/user_guide/reporters.rst | 16 ++- 25 files changed, 121 insertions(+), 289 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 44b1d1107c..97c646a062 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -48,6 +48,13 @@ Added Removed ------- +- **Breaking change:** The public methods ``cache_predictions`` and ``clear_cache`` on + :class:`~skore.EstimatorReport`, :class:`~skore.CrossValidationReport`, and + :class:`~skore.ComparisonReport` are removed. 
The library still uses + ``_cache_predictions`` and ``_clear_cache`` internally; in application code, rely on + :meth:`~skore.EstimatorReport.get_predictions` and the metrics and inspection APIs, + which populate and reuse the in-memory store automatically. + Fixed ----- diff --git a/examples/model_evaluation/plot_estimator_report.py b/examples/model_evaluation/plot_estimator_report.py index 2c5c8f6335..5160ffb9a9 100644 --- a/examples/model_evaluation/plot_estimator_report.py +++ b/examples/model_evaluation/plot_estimator_report.py @@ -96,7 +96,7 @@ # %% # -# Metrics computation with aggressive caching +# Metrics computation and repeated evaluation # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ # # At this point, we might be interested to have a first look at the statistical @@ -116,13 +116,10 @@ # %% # -# An interesting feature provided by the :class:`skore.EstimatorReport` is the -# the caching mechanism. Indeed, when we have a large enough dataset, computing the -# predictions for a model is not cheap anymore. For instance, on our smallish dataset, -# it took a couple of seconds to compute the metrics. The report will cache the -# predictions and if we are interested in computing a metric again or an alternative -# metric that requires the same predictions, it will be faster. Let's check by -# requesting the same metrics report again. +# On large enough data, getting predictions is often the expensive step. The report +# keeps intermediate results in memory for the same session, so when we ask for the +# same :meth:`~skore.EstimatorReport.metrics.summarize` again, it can complete much +# faster. Let's request the same summary a second time. start = time.time() metric_report = report.metrics.summarize().frame() @@ -147,22 +144,8 @@ # %% # -# Whenever computing a metric, we check if the predictions are available in the cache -# and reload them if available. So for instance, let's compute the log loss. 
- -start = time.time() -log_loss = report.metrics.log_loss() -end = time.time() -log_loss - -# %% -print(f"Time taken to compute the log loss: {end - start:.2f} seconds") - -# %% -# -# We can show that without initial cache, it would have taken more time to compute -# the log loss. -report.clear_cache() +# Another metric on the test set, such as log loss, can reuse the same underlying +# predictions if they were already required for a previous call. start = time.time() log_loss = report.metrics.log_loss() @@ -181,10 +164,9 @@ # %% # -# Be aware that we can also benefit from the caching mechanism with our own custom -# metrics. Skore only expects that we define our own metric function to take `y_true` -# and `y_pred` as the first two positional arguments. It can take any other arguments. -# Let's see an example. +# Custom metrics also go through the same path: they receive `y_true` and `y_pred` +# as the first two arguments, and the report supplies predictions consistently with +# built-in metrics. The callable can take any other arguments. Let's see an example. def operational_decision_cost(y_true, y_pred, amount): @@ -259,10 +241,8 @@ def operational_decision_cost(y_true, y_pred, amount): # %% # -# Similarly to the metrics, we aggressively use the caching to avoid recomputing the -# predictions of the model. We also cache the plot display object by detection if the -# input parameters are the same as the previous call. Let's demonstrate the kind of -# performance gain we can get. +# Similarly to the metrics, repeated calls for the same ROC display can be much +# faster in the same session once the underlying values have been computed. 
start = time.time() # we already trigger the computation of the predictions in a previous call display = report.metrics.roc() @@ -273,24 +253,6 @@ def operational_decision_cost(y_true, y_pred, amount): # %% print(f"Time taken to compute the ROC curve: {end - start:.2f} seconds") -# %% -# -# Now, let's clean the cache and check if we get a slowdown. -report.clear_cache() - -# %% -start = time.time() -display = report.metrics.roc() -fig = display.plot() -end = time.time() -fig - -# %% -print(f"Time taken to compute the ROC curve: {end - start:.2f} seconds") - -# %% -# As expected, since we need to recompute the predictions, it takes more time. - # %% # Visualizing the confusion matrix # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/examples/technical_details/plot_cache_mechanism.py b/examples/technical_details/plot_cache_mechanism.py index fd4d632772..fbd0970a48 100644 --- a/examples/technical_details/plot_cache_mechanism.py +++ b/examples/technical_details/plot_cache_mechanism.py @@ -1,12 +1,13 @@ """ .. _example_cache_mechanism: -=============== -Cache mechanism -=============== +==================================== +Fast repeated metrics and evaluation +==================================== -This example shows how :class:`~skore.EstimatorReport` and -:class:`~skore.CrossValidationReport` use caching to speed up computations. +This example shows that :class:`~skore.EstimatorReport` and +:class:`~skore.CrossValidationReport` avoid redundant work when you compute metrics +or displays several times, so the second call is often much faster than the first. """ # %% @@ -38,8 +39,8 @@ # Some categories are not well defined. 
# %% -# Caching with :class:`~skore.EstimatorReport` and :class:`~skore.CrossValidationReport` -# ====================================================================================== +# :class:`~skore.EstimatorReport` and repeated evaluation +# ======================================================= # # We use `skrub` to create a simple predictive model that handles our dataset's # challenges. @@ -62,14 +63,11 @@ ) # %% -# Caching the predictions for fast metric computation -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# First and second calls to a metric +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ # -# First, we focus on :class:`~skore.EstimatorReport`, as the same philosophy will -# apply to :class:`~skore.CrossValidationReport`. -# -# Let's explore how :class:`~skore.EstimatorReport` uses caching to speed up -# predictions. We start by training the model: +# We build an :class:`~skore.EstimatorReport` and time how long successive metric +# calls take. from skore import EstimatorReport report = EstimatorReport( @@ -112,8 +110,7 @@ # # Both approaches take similar time. # -# Now, watch what happens when we compute the accuracy again with our skore estimator -# report: +# Now, we compute the accuracy again through the same report: start = time.time() result = report.metrics.accuracy() end = time.time() @@ -124,13 +121,13 @@ # %% # -# The second calculation is instant! This happens because the report saves previous -# calculations in its cache. Let's look inside the cache: -report._cache +# The second calculation is much faster, because the report does not repeat the +# expensive ``predict`` work when the same information is still available for this +# session. # %% -# The cache stores predictions by type and data source. This means that computing -# metrics that use the same type of predictions will be faster. 
+# A different metric that needs the same predictions +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ # Let's try the precision metric: start = time.time() result = report.metrics.precision() @@ -141,23 +138,16 @@ print(f"Time taken: {end - start:.2f} seconds") # %% -# We observe that it takes only a few milliseconds to compute the precision because we -# don't need to re-compute the predictions and only have to compute the precision -# metric itself. -# Since the predictions are the bottleneck in terms of computation time, we observe -# an interesting speedup. - -# %% -# Caching all the possible predictions at once -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ # -# We can pre-compute all predictions at once: -report.cache_predictions() +# It typically stays fast, because the same type of test-set predictions is reused +# where possible. # %% +# Another data source +# ^^^^^^^^^^^^^^^^^^^ # -# Now, all possible predictions are stored. Any metric calculation will be much faster, -# even on different data (like the training set): +# The first time we ask for a training-set metric, the model must be run on the +# training set as well. Later calls on that data source also benefit from reuse. start = time.time() result = report.metrics.log_loss(data_source="train") end = time.time() @@ -167,10 +157,11 @@ print(f"Time taken: {end - start:.2f} seconds") # %% -# Caching for plotting -# ^^^^^^^^^^^^^^^^^^^^ +# Plots +# ^^^^^ # -# The cache also speeds up plots. Let's create a ROC curve: +# Displays (for example a ROC curve) also benefit: the first request builds the +# underlying arrays; a second request for the same display is quick. start = time.time() display = report.metrics.roc() @@ -182,7 +173,6 @@ # %% # -# The second plot is instant because it uses cached data: start = time.time() display = report.metrics.roc() display.plot() @@ -193,26 +183,18 @@ # %% # -# We only use the cache to retrieve the `display` object and not directly the matplotlib -# figure. 
It means that we can still customize the cached plot before displaying it: +# We can still customize the display (for example style) and plot again; the +# evaluation work behind the same metric does not need to be redone in full. display.set_style(relplot_kwargs={"color": "tab:orange"}) _ = display.plot() # %% +# Cross-validation: :class:`~skore.CrossValidationReport` +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ # -# Be aware that we can clear the cache if we want to: -report.clear_cache() -report._cache - -# %% -# -# It means that nothing is stored anymore in the cache. -# -# Caching with :class:`~skore.CrossValidationReport` -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -# -# :class:`~skore.CrossValidationReport` uses the same caching system for each split -# in cross-validation by leveraging the previous :class:`~skore.EstimatorReport`: +# A :class:`~skore.CrossValidationReport` uses one +# :class:`~skore.EstimatorReport` per split, so the same idea applies: the first +# heavy summary of metrics walks every fold; a second run reuses work where possible. from skore import CrossValidationReport report = CrossValidationReport(model, X=df, y=y, splitter=5, n_jobs=4) @@ -220,10 +202,8 @@ # %% # -# Since a :class:`~skore.CrossValidationReport` uses many -# :class:`~skore.EstimatorReport`, we will observe the same behaviour as we previously -# exposed. -# The first call will be slow because it computes the predictions for each split. +# The first call to a full summary of metrics can take a while because each fold +# is evaluated. start = time.time() result = report.metrics.summarize().frame() end = time.time() @@ -234,7 +214,7 @@ # %% # -# But the subsequent calls are fast because the predictions are cached. +# The second call is typically much faster. 
start = time.time() result = report.metrics.summarize().frame() end = time.time() @@ -242,7 +222,3 @@ # %% print(f"Time taken: {end - start:.2f} seconds") - -# %% -# -# Hence, we observe the same type of behaviour as we previously exposed. diff --git a/examples/use_cases/plot_employee_salaries.py b/examples/use_cases/plot_employee_salaries.py index 6d9e035e48..2a39e46606 100644 --- a/examples/use_cases/plot_employee_salaries.py +++ b/examples/use_cases/plot_employee_salaries.py @@ -110,17 +110,11 @@ # %% # A report provides a collection of useful information. For instance, it allows to # compute on demand the predictions of the model and some performance metrics. -# -# Let's cache the predictions of the cross-validated models once and for all. - -# %% -hgbt_model_report.cache_predictions() +# The first time you call a summary of metrics, the report performs the per-fold +# work it needs; later calls in the same session can reuse a lot of that work. # %% -# Now that the predictions are cached, any request to compute a metric will be -# performed using the cached predictions and will thus be fast. -# -# We can now have a look at the performance of the model with some standard metrics. +# We can have a look at the performance of the model with some standard metrics. # %% hgbt_model_report.metrics.summarize().frame() @@ -254,17 +248,9 @@ def periodic_spline_transformer(period, n_splines=None, degree=3): # We observe that the cross-validation report has detected that we have a regression # task at hand and thus provides us with some metrics and plots that make sense with # regards to our specific problem at hand. -# -# To accelerate any future computation (e.g. of a metric), we cache the predictions of -# our model once and for all. -# Note that we do not necessarily need to cache the predictions as the report will -# compute them on the fly (if not cached) and cache them for us. 
- -# %% -linear_model_report.cache_predictions() # %% -# We can now have a look at the performance of the model with some standard metrics. +# We can have a look at the performance of the model with some standard metrics. # %% linear_model_report.metrics.summarize().frame(favorability=True) @@ -285,9 +271,9 @@ def periodic_spline_transformer(period, n_splines=None, degree=3): # %% # In addition, if we forgot to compute a specific metric # (e.g. :func:`~sklearn.metrics.mean_absolute_error`), -# we can easily add it to the report, without re-training the model and even -# without re-computing the predictions since they are cached internally in the report. -# This allows us to save some potentially huge computation time. +# we can easily add it to the report, without re-training the model. The +# comparison reuses the underlying reports' stored evaluation where possible, so +# you can avoid redundant prediction work in the same session. # %% comparator.metrics.add(metric="neg_mean_absolute_error", name="MAE") diff --git a/skore-hub-project/src/skore_hub_project/artifact/pickle/pickle.py b/skore-hub-project/src/skore_hub_project/artifact/pickle/pickle.py index 2bf655619e..66fc8ef837 100644 --- a/skore-hub-project/src/skore_hub_project/artifact/pickle/pickle.py +++ b/skore-hub-project/src/skore_hub_project/artifact/pickle/pickle.py @@ -51,7 +51,7 @@ def content_to_upload(self) -> Generator[bytes, None, None]: reports_with_cache = [ (report, report._cache) for report in reports if hasattr(report, "_cache") ] - self.report.clear_cache() + self.report._clear_cache() try: with BytesIO() as stream: diff --git a/skore-hub-project/src/skore_hub_project/protocol.py b/skore-hub-project/src/skore_hub_project/protocol.py index 82607f66e7..805071dcd9 100644 --- a/skore-hub-project/src/skore_hub_project/protocol.py +++ b/skore-hub-project/src/skore_hub_project/protocol.py @@ -26,8 +26,8 @@ class EstimatorReport(Protocol): """Protocol equivalent to ``skore.EstimatorReport``.""" 
_hash: int - cache_predictions: Any - clear_cache: Any + _cache_predictions: Any + _clear_cache: Any _cache: Any metrics: Any data: Any @@ -49,8 +49,8 @@ class CrossValidationReport(Protocol): """Protocol equivalent to ``skore.CrossValidationReport``.""" _hash: int - cache_predictions: Any - clear_cache: Any + _cache_predictions: Any + _clear_cache: Any metrics: Any data: Any estimator_reports_: Any diff --git a/skore-hub-project/src/skore_hub_project/report/report.py b/skore-hub-project/src/skore_hub_project/report/report.py index 13a6887b02..520ffcd96e 100644 --- a/skore-hub-project/src/skore_hub_project/report/report.py +++ b/skore-hub-project/src/skore_hub_project/report/report.py @@ -106,7 +106,7 @@ def metrics(self) -> list[Metric[Report]]: - int [0, inf[, to be displayed at the position, - None, not to be displayed. """ - self.report.cache_predictions() + self.report._cache_predictions() metrics = [metric_cls(report=self.report) for metric_cls in self.METRICS] diff --git a/skore-hub-project/tests/unit/artifact/media/test_inspection.py b/skore-hub-project/tests/unit/artifact/media/test_inspection.py index 8bb69e3c22..e451f5febc 100644 --- a/skore-hub-project/tests/unit/artifact/media/test_inspection.py +++ b/skore-hub-project/tests/unit/artifact/media/test_inspection.py @@ -156,7 +156,7 @@ def test_inspection( } # unavailable accessor - report.clear_cache() + report._clear_cache() monkeypatch.delattr(report.inspection.__class__, accessor) upload_mock.reset_mock() diff --git a/skore-hub-project/tests/unit/report/test_cross_validation_report.py b/skore-hub-project/tests/unit/report/test_cross_validation_report.py index fc214413ef..a4450d6641 100644 --- a/skore-hub-project/tests/unit/report/test_cross_validation_report.py +++ b/skore-hub-project/tests/unit/report/test_cross_validation_report.py @@ -79,7 +79,7 @@ def serialize(object: EstimatorReport | CrossValidationReport) -> tuple[bytes, s reports_with_cache = [ (report, report._cache) for report in reports if 
hasattr(report, "_cache") ] - object.clear_cache() + object._clear_cache() try: with BytesIO() as stream: @@ -363,7 +363,7 @@ def test_target_range_classification(self, payload): ) @mark.respx() def test_estimators(self, project, payload, upload_mock): - payload.report.cache_predictions() + payload.report._cache_predictions() assert len(payload.estimators) == len(payload.report.estimator_reports_) for i, estimator in enumerate(payload.estimators): diff --git a/skore-hub-project/tests/unit/report/test_estimator_report.py b/skore-hub-project/tests/unit/report/test_estimator_report.py index 7bdb988449..bf2f117b6d 100644 --- a/skore-hub-project/tests/unit/report/test_estimator_report.py +++ b/skore-hub-project/tests/unit/report/test_estimator_report.py @@ -53,7 +53,7 @@ def serialize(object: EstimatorReport | CrossValidationReport) -> tuple[bytes, s reports_with_cache = [ (report, report._cache) for report in reports if hasattr(report, "_cache") ] - object.clear_cache() + object._clear_cache() try: with io.BytesIO() as stream: diff --git a/skore-local-project/tests/unit/test_project.py b/skore-local-project/tests/unit/test_project.py index f30680efc4..4b0bd9d2ef 100644 --- a/skore-local-project/tests/unit/test_project.py +++ b/skore-local-project/tests/unit/test_project.py @@ -126,7 +126,7 @@ def test_put_estimator_report_reuses_artifact_id(self, tmp_path, regression): project = Project("", workspace=tmp_path) project.put("", regression) - regression.cache_predictions() + _ = regression.get_predictions(data_source="test") project.put("", regression) # Ensure only one artifact was persisted: @@ -136,7 +136,7 @@ def test_put_estimator_report_reuses_artifact_id(self, tmp_path, regression): # Make sure the pickle is not broken: report = project.get(str(regression.id)) - report.cache_predictions() + _ = report.get_predictions(data_source="test") def test_put_cross_validation_report_reuses_artifact_id( self, tmp_path, cv_regression @@ -144,7 +144,7 @@ def 
test_put_cross_validation_report_reuses_artifact_id( project = Project("", workspace=tmp_path) project.put("", cv_regression) - cv_regression.cache_predictions() + _ = cv_regression.get_predictions(data_source="test") project.put("", cv_regression) # Ensure only one artifact was persisted: @@ -154,7 +154,7 @@ def test_put_cross_validation_report_reuses_artifact_id( # Make sure the pickle is not broken: report = project.get(str(cv_regression.id)) - report.cache_predictions() + _ = report.get_predictions(data_source="test") def test_init_with_envar(self, monkeypatch, tmp_path): monkeypatch.setenv("SKORE_WORKSPACE", str(tmp_path)) diff --git a/skore/src/skore/_sklearn/_comparison/metrics_accessor.py b/skore/src/skore/_sklearn/_comparison/metrics_accessor.py index d32c891231..f58bfb1f48 100644 --- a/skore/src/skore/_sklearn/_comparison/metrics_accessor.py +++ b/skore/src/skore/_sklearn/_comparison/metrics_accessor.py @@ -220,7 +220,7 @@ def timings( >>> report.metrics.timings() LogisticRegression_1 LogisticRegression_2 Fit time (s) ... ... - >>> report.cache_predictions() + >>> _ = report.get_predictions(data_source="test") >>> report.metrics.timings() LogisticRegression_1 LogisticRegression_2 Fit time (s) ... ... diff --git a/skore/src/skore/_sklearn/_comparison/report.py b/skore/src/skore/_sklearn/_comparison/report.py index ac50a33392..ba630e7fc9 100644 --- a/skore/src/skore/_sklearn/_comparison/report.py +++ b/skore/src/skore/_sklearn/_comparison/report.py @@ -258,54 +258,21 @@ def __init__( self.n_jobs = n_jobs self._ml_task = next(iter(self.reports_.values()))._ml_task # type: ignore - def clear_cache(self) -> None: - """Clear the cache. 
- - Examples - -------- - >>> from sklearn.datasets import make_classification - >>> from sklearn.linear_model import LogisticRegression - >>> from skore import train_test_split - >>> from skore import ComparisonReport, EstimatorReport - >>> X, y = make_classification(random_state=42) - >>> split_data = train_test_split(X=X, y=y, random_state=42, as_dict=True) - >>> estimator_1 = LogisticRegression() - >>> estimator_report_1 = EstimatorReport(estimator_1, **split_data) - >>> estimator_2 = LogisticRegression(C=2) # Different regularization - >>> estimator_report_2 = EstimatorReport(estimator_2, **split_data) - >>> report = ComparisonReport([estimator_report_1, estimator_report_2]) - >>> report.cache_predictions() - >>> report.clear_cache() - """ + def _clear_cache(self) -> None: + """Reset in-memory caches on each underlying report.""" for report in self.reports_.values(): - report.clear_cache() + report._clear_cache() - def cache_predictions( + def _cache_predictions( self, ) -> None: - """Cache the predictions for sub-estimators reports. 
- - Examples - -------- - >>> from sklearn.datasets import make_classification - >>> from sklearn.linear_model import LogisticRegression - >>> from skore import train_test_split - >>> from skore import ComparisonReport, EstimatorReport - >>> X, y = make_classification(random_state=42) - >>> split_data = train_test_split(X=X, y=y, random_state=42, as_dict=True) - >>> estimator_1 = LogisticRegression() - >>> estimator_report_1 = EstimatorReport(estimator_1, **split_data) - >>> estimator_2 = LogisticRegression(C=2) # Different regularization - >>> estimator_report_2 = EstimatorReport(estimator_2, **split_data) - >>> report = ComparisonReport([estimator_report_1, estimator_report_2]) - >>> report.cache_predictions() - """ + """Precompute predictions for each model in the comparison.""" for report in track( self.reports_.values(), description="Estimator predictions", total=len(self.reports_), ): - report.cache_predictions() + report._cache_predictions() def get_predictions( self, @@ -357,7 +324,6 @@ def get_predictions( >>> estimator_2 = LogisticRegression(C=2) # Different regularization >>> estimator_report_2 = EstimatorReport(estimator_2, **split_data) >>> report = ComparisonReport([estimator_report_1, estimator_report_2]) - >>> report.cache_predictions() >>> predictions = report.get_predictions(data_source="test") >>> print([split_predictions.shape for split_predictions in predictions]) [(25,), (25,)] diff --git a/skore/src/skore/_sklearn/_cross_validation/metrics_accessor.py b/skore/src/skore/_sklearn/_cross_validation/metrics_accessor.py index a1dbf7f085..89d77f7453 100644 --- a/skore/src/skore/_sklearn/_cross_validation/metrics_accessor.py +++ b/skore/src/skore/_sklearn/_cross_validation/metrics_accessor.py @@ -199,7 +199,7 @@ def timings( >>> report.metrics.timings() mean std Fit time (s) ... ... - >>> report.cache_predictions() + >>> _ = report.get_predictions(data_source="test") >>> report.metrics.timings() mean std Fit time (s) ... ... 
diff --git a/skore/src/skore/_sklearn/_cross_validation/report.py b/skore/src/skore/_sklearn/_cross_validation/report.py index fb22da67f7..9afa5bb256 100644 --- a/skore/src/skore/_sklearn/_cross_validation/report.py +++ b/skore/src/skore/_sklearn/_cross_validation/report.py @@ -357,47 +357,20 @@ def from_state(cls, state: dict[str, Any]) -> CrossValidationReport: return report - def clear_cache(self) -> None: - """Clear the cache. - - Examples - -------- - >>> from sklearn.datasets import load_breast_cancer - >>> from sklearn.linear_model import LogisticRegression - >>> from skore import CrossValidationReport - >>> X, y = load_breast_cancer(return_X_y=True) - >>> classifier = LogisticRegression(max_iter=10_000) - >>> report = CrossValidationReport(classifier, X=X, y=y, splitter=2) - >>> report.cache_predictions() - >>> report.clear_cache() - >>> report.estimator_reports_[0]._cache - {} - """ + def _clear_cache(self) -> None: + """Reset in-memory caches on each fold's estimator report.""" for report in self.estimator_reports_: - report.clear_cache() + report._clear_cache() - def cache_predictions( + def _cache_predictions( self, ) -> None: - """Cache the predictions for sub-estimators reports. 
- - Examples - -------- - >>> from sklearn.datasets import load_breast_cancer - >>> from sklearn.linear_model import LogisticRegression - >>> from skore import CrossValidationReport - >>> X, y = load_breast_cancer(return_X_y=True) - >>> classifier = LogisticRegression(max_iter=10_000) - >>> report = CrossValidationReport(classifier, X=X, y=y, splitter=2) - >>> report.cache_predictions() - >>> report.estimator_reports_[0]._cache - {...} - """ + """Precompute predictions for each cross-validation split report.""" for estimator_report in track( self.estimator_reports_, description="Cross-validation predictions for split", ): - estimator_report.cache_predictions() + estimator_report._cache_predictions() def get_predictions( self, diff --git a/skore/src/skore/_sklearn/_estimator/metrics_accessor.py b/skore/src/skore/_sklearn/_estimator/metrics_accessor.py index 08e47c9ee0..e00d38d5b9 100644 --- a/skore/src/skore/_sklearn/_estimator/metrics_accessor.py +++ b/skore/src/skore/_sklearn/_estimator/metrics_accessor.py @@ -295,7 +295,7 @@ def timings(self) -> dict: >>> report = evaluate(estimator, X, y, splitter=0.2) >>> report.metrics.timings() {'fit_time': ...} - >>> report.cache_predictions() + >>> _ = report.get_predictions(data_source="test") >>> report.metrics.timings() {'fit_time': ..., 'predict_time_test': ...} """ diff --git a/skore/src/skore/_sklearn/_estimator/report.py b/skore/src/skore/_sklearn/_estimator/report.py index 7ae6911471..9c6106d065 100644 --- a/skore/src/skore/_sklearn/_estimator/report.py +++ b/skore/src/skore/_sklearn/_estimator/report.py @@ -346,31 +346,15 @@ def from_state(cls, state: dict[str, Any]) -> EstimatorReport: return report - def clear_cache(self) -> None: - """Clear the cache. 
- - Examples - -------- - >>> from sklearn.datasets import load_breast_cancer - >>> from sklearn.linear_model import LogisticRegression - >>> from skore import train_test_split - >>> from skore import EstimatorReport - >>> X, y = load_breast_cancer(return_X_y=True) - >>> split_data = train_test_split(X=X, y=y, random_state=0, as_dict=True) - >>> classifier = LogisticRegression(max_iter=10_000) - >>> report = EstimatorReport(classifier, **split_data) - >>> report.cache_predictions() - >>> report.clear_cache() - >>> report._cache - {} - """ + def _clear_cache(self) -> None: + """Reset the in-memory cache used for predictions and timings.""" self._cache = Cache() - def cache_predictions( + def _cache_predictions( self, data_source: DataSource | Literal["both"] = "both", ) -> None: - """Cache estimator's predictions. + """Precompute and store predictions and related outputs for a data source. Parameters ---------- @@ -380,25 +364,11 @@ def cache_predictions( - "test" : cache predictions for the test set only. - "train" : cache predictions for the train set only. - "both" : cache predictions for both train and test sets when available. 
- - Examples - -------- - >>> from sklearn.datasets import load_breast_cancer - >>> from sklearn.linear_model import LogisticRegression - >>> from skore import train_test_split - >>> from skore import EstimatorReport - >>> X, y = load_breast_cancer(return_X_y=True) - >>> split_data = train_test_split(X=X, y=y, random_state=0, as_dict=True) - >>> classifier = LogisticRegression(max_iter=10_000) - >>> report = EstimatorReport(classifier, **split_data) - >>> report.cache_predictions() - >>> report._cache - {...} """ if data_source == "both": - self.cache_predictions(data_source="test") + self._cache_predictions(data_source="test") if self.X_train is not None: - self.cache_predictions(data_source="train") + self._cache_predictions(data_source="train") return data = self._test_data if data_source == "test" else self._train_data @@ -655,7 +625,7 @@ def _get_predictions( raise ValueError(f"Cannot specify a `pos_label` for task {self.ml_task}") method_name = _check_response_method(self.estimator_, response_method).__name__ - self.cache_predictions(data_source=data_source) + self._cache_predictions(data_source=data_source) cache_key = make_cache_key(data_source, method_name) predictions = self._cache[cache_key] diff --git a/skore/tests/unit/reports/comparison/cross_validation/metrics/test_numeric.py b/skore/tests/unit/reports/comparison/cross_validation/metrics/test_numeric.py index 60ed999c7d..3e3d281670 100644 --- a/skore/tests/unit/reports/comparison/cross_validation/metrics/test_numeric.py +++ b/skore/tests/unit/reports/comparison/cross_validation/metrics/test_numeric.py @@ -43,7 +43,7 @@ def case_timings_with_predictions( ) report = comparison_cross_validation_reports_binary_classification - report.cache_predictions() + report._cache_predictions() return ( report, "timings", diff --git a/skore/tests/unit/reports/cross_validation/metrics/test_numeric.py b/skore/tests/unit/reports/cross_validation/metrics/test_numeric.py index 7224475d75..19c10e0cea 100644 --- 
a/skore/tests/unit/reports/cross_validation/metrics/test_numeric.py +++ b/skore/tests/unit/reports/cross_validation/metrics/test_numeric.py @@ -49,7 +49,7 @@ def _check_results_single_metric(report, metric, expected_n_splits, expected_nb_ # check that something was written to the children's cache assert all(report._cache != {} for report in report.estimator_reports_) - report.clear_cache() + report._clear_cache() _check_metrics_names(result, [metric], expected_nb_stats) diff --git a/skore/tests/unit/reports/cross_validation/test_report.py b/skore/tests/unit/reports/cross_validation/test_report.py index 57936dea75..9551f26a87 100644 --- a/skore/tests/unit/reports/cross_validation/test_report.py +++ b/skore/tests/unit/reports/cross_validation/test_report.py @@ -112,18 +112,18 @@ def test_attributes(fixture_name, request, cv, n_jobs): ) @pytest.mark.parametrize("n_jobs", [None, 1, 2]) def test_cache_predictions(request, fixture_name, expected_n_keys, n_jobs): - """Check that calling cache_predictions fills the cache.""" + """Check that :meth:`CrossValidationReport._cache_predictions` fills the cache.""" estimator, X, y = request.getfixturevalue(fixture_name) report = CrossValidationReport(estimator, X, y, splitter=2, n_jobs=n_jobs) for estimator_report in report.estimator_reports_: assert estimator_report._cache == {} - report.cache_predictions() + report._cache_predictions() for estimator_report in report.estimator_reports_: assert len(estimator_report._cache) == expected_n_keys - report.clear_cache() + report._clear_cache() for estimator_report in report.estimator_reports_: assert estimator_report._cache == {} @@ -182,7 +182,7 @@ def test_pickle(tmp_path, logistic_binary_classification_data): """ estimator, X, y = logistic_binary_classification_data report = CrossValidationReport(estimator, X, y, splitter=2) - report.cache_predictions() + report._cache_predictions() joblib.dump(report, tmp_path / "report.joblib") @@ -317,7 +317,7 @@ def 
test_from_state_bypasses_init_and_restores_state( estimator, X, y = logistic_binary_classification_data report = CrossValidationReport(estimator, X, y, splitter=2) expected_accuracy = report.metrics.accuracy() - report.cache_predictions() + report._cache_predictions() state = report.get_state() def _unexpected_init(self, *args, **kwargs): @@ -342,7 +342,7 @@ def _unexpected_init(self, *args, **kwargs): # check new metrics/predictions can be computed: restored.metrics.roc_auc() _ = report.get_predictions(data_source="test") - report.cache_predictions() + report._cache_predictions() def test_get_from_state_with_complex_data_op(): @@ -379,7 +379,7 @@ def _join_features(left, right): restored = CrossValidationReport.from_state(state) # check fresh computations still work after restoring from state: - restored.clear_cache() + restored._clear_cache() assert restored.metrics.accuracy().equals(expected_accuracy) preds = restored.get_predictions(data_source="test") for pred, expected_pred in zip(preds, expected_preds, strict=True): diff --git a/skore/tests/unit/reports/estimator/test_report.py b/skore/tests/unit/reports/estimator/test_report.py index bff5979a4d..250442e4f4 100644 --- a/skore/tests/unit/reports/estimator/test_report.py +++ b/skore/tests/unit/reports/estimator/test_report.py @@ -172,7 +172,7 @@ def test_check_support_plot( ], ) def test_cache_predictions(request, fixture_name, pass_train_data, expected_n_keys): - """Check that calling cache_predictions fills the cache.""" + """Check that :meth:`EstimatorReport._cache_predictions` fills the cache.""" estimator, X_test, y_test = request.getfixturevalue(fixture_name) if pass_train_data: report = EstimatorReport( @@ -182,11 +182,11 @@ def test_cache_predictions(request, fixture_name, pass_train_data, expected_n_ke report = EstimatorReport(estimator, X_test=X_test, y_test=y_test) assert report._cache == {} - report.cache_predictions() + report._cache_predictions() assert len(report._cache) == expected_n_keys 
assert report._cache != {} stored_cache = deepcopy(report._cache) - report.cache_predictions() + report._cache_predictions() # check that the keys are exactly the same assert report._cache.keys() == stored_cache.keys() @@ -222,7 +222,7 @@ def test_pickle(forest_binary_classification_with_test): """ estimator, X_test, y_test = forest_binary_classification_with_test report = EstimatorReport(estimator, X_test=X_test, y_test=y_test) - report.cache_predictions() + report._cache_predictions() with BytesIO() as stream: joblib.dump(report, stream) @@ -539,7 +539,7 @@ def test_from_state_bypasses_init_and_restores_state( pos_label=1, ) expected_accuracy = report.metrics.accuracy() - report.cache_predictions() + report._cache_predictions() report.metrics.add("f1", name="F1") state = report.get_state() assert state["metadata"]["report_type"] == report._report_type diff --git a/sphinx/reference/report/comparison_report.rst b/sphinx/reference/report/comparison_report.rst index 97258a1648..7ca4adddcf 100644 --- a/sphinx/reference/report/comparison_report.rst +++ b/sphinx/reference/report/comparison_report.rst @@ -22,8 +22,6 @@ way. The functionalities of the report are accessible through accessors. ComparisonReport.help ComparisonReport.diagnose ComparisonReport.add_checks - ComparisonReport.cache_predictions - ComparisonReport.clear_cache ComparisonReport.create_estimator_report ComparisonReport.get_predictions diff --git a/sphinx/reference/report/cross_validation_report.rst b/sphinx/reference/report/cross_validation_report.rst index 77be3484d6..4fa805c753 100644 --- a/sphinx/reference/report/cross_validation_report.rst +++ b/sphinx/reference/report/cross_validation_report.rst @@ -22,8 +22,6 @@ functionalities of the report are exposed through accessors. 
CrossValidationReport.help CrossValidationReport.diagnose CrossValidationReport.add_checks - CrossValidationReport.cache_predictions - CrossValidationReport.clear_cache CrossValidationReport.create_estimator_report CrossValidationReport.get_predictions diff --git a/sphinx/reference/report/estimator_report.rst b/sphinx/reference/report/estimator_report.rst index f3110a3085..e64020a4d9 100644 --- a/sphinx/reference/report/estimator_report.rst +++ b/sphinx/reference/report/estimator_report.rst @@ -22,8 +22,6 @@ report are accessible through accessors. EstimatorReport.help EstimatorReport.diagnose EstimatorReport.add_checks - EstimatorReport.cache_predictions - EstimatorReport.clear_cache EstimatorReport.get_predictions .. rubric:: Accessors diff --git a/sphinx/user_guide/reporters.rst b/sphinx/user_guide/reporters.rst index 19aee11dbf..adf6d0387f 100644 --- a/sphinx/user_guide/reporters.rst +++ b/sphinx/user_guide/reporters.rst @@ -113,13 +113,11 @@ intermediate information that is expensive to compute, such as predictions. It efficiently re-uses this information when recomputing the same metric or a metric requiring the same intermediate information. -We expose three methods to interact with the cache: - -- :meth:`EstimatorReport.cache_predictions` to cache the predictions of the estimator - without awaiting the computation when calling the evaluation metrics. -- :meth:`EstimatorReport.clear_cache` to clear the cache. -- :meth:`EstimatorReport.get_predictions` to get the predictions from the cache or - compute them if they are not in the cache. +Predictions and related outputs are stored in memory the first time they are needed +(for example when you compute a metric or call :meth:`EstimatorReport.get_predictions`). +Later calls that need the same values reuse them, so you typically do not need to +think about warming or invalidating a cache. 
Use :meth:`EstimatorReport.get_predictions` +when you need the raw model outputs; metrics and displays use the same path internally. .. note:: The current implementation of the caching mechanism happens in-memory. It is @@ -128,8 +126,8 @@ We expose three methods to interact with the cache: section :ref:`project` for more details. Refer to the example entitled -:ref:`sphx_glr_auto_examples_technical_details_plot_cache_mechanism.py` to get a -detailed view of the caching mechanism. +:ref:`sphx_glr_auto_examples_technical_details_plot_cache_mechanism.py` for a +concrete walkthrough comparing first-time and repeated evaluation times. .. _cross_validation_report: