From f650f8a077509dd75bcb6c3f62ddac224be57c6f Mon Sep 17 00:00:00 2001 From: madtoinou Date: Mon, 3 Mar 2025 15:05:52 +0100 Subject: [PATCH 1/2] feat: allow stride=forecast_horizon for anomaly forecasting model --- darts/ad/anomaly_model/anomaly_model.py | 2 +- darts/ad/anomaly_model/forecasting_am.py | 28 ++++++++++++++--- darts/tests/ad/test_anomaly_model.py | 40 ++++++++++++++++++++++++ 3 files changed, 64 insertions(+), 6 deletions(-) diff --git a/darts/ad/anomaly_model/anomaly_model.py b/darts/ad/anomaly_model/anomaly_model.py index 63655db40c..03773fbd48 100644 --- a/darts/ad/anomaly_model/anomaly_model.py +++ b/darts/ad/anomaly_model/anomaly_model.py @@ -50,7 +50,7 @@ def fit( if not allow_model_training and not self.scorers_are_trainable: return self - # check input series and covert to sequences + # check input series and convert to sequences series, kwargs = self._process_input_series(series, **kwargs) self._fit_core( series=series, allow_model_training=allow_model_training, **kwargs diff --git a/darts/ad/anomaly_model/forecasting_am.py b/darts/ad/anomaly_model/forecasting_am.py index 8b4339cd9c..31bf1560c7 100644 --- a/darts/ad/anomaly_model/forecasting_am.py +++ b/darts/ad/anomaly_model/forecasting_am.py @@ -23,7 +23,7 @@ from darts.ad.scorers.scorers import AnomalyScorer from darts.logging import get_logger, raise_log from darts.models.forecasting.forecasting_model import GlobalForecastingModel -from darts.timeseries import TimeSeries +from darts.timeseries import TimeSeries, concatenate logger = get_logger(__name__) @@ -248,6 +248,7 @@ def predict_series( start: Union[pd.Timestamp, float, int] = None, start_format: Literal["position", "value"] = "value", num_samples: int = 1, + stride: int = 1, verbose: bool = False, show_warnings: bool = True, enable_optimization: bool = True, @@ -288,6 +289,9 @@ def predict_series( an error if the value is not in `series`' index. Default: `'value'` num_samples Number of times a prediction is sampled from a probabilistic model. Must be `1` for deterministic models. + stride + The number of time steps between two consecutive predictions. Must be either `1` or + `forecast_horizon` (caution, the prediction will be faster but less accurate). verbose Whether to print the progress. show_warnings @@ -300,7 +304,7 @@ def predict_series( Returns ------- Sequence[TimeSeries] - A sequence of `TimeSeries` with the historical forecasts for each series (with `last_points_only=True`). + A sequence of `TimeSeries`, with one historical forecasts for each series. """ if not self.model._fit_called: raise_log( @@ -309,14 +313,21 @@ def predict_series( ), logger=logger, ) - return self.model.historical_forecasts( + if not (stride == 1 or stride == forecast_horizon): + raise_log( + ValueError( + f"`stride` must be equal to either `1` or `forecast_horizon`, received {stride}." + ), + logger=logger, + ) + forecasts = self.model.historical_forecasts( series, past_covariates=past_covariates, future_covariates=future_covariates, forecast_horizon=forecast_horizon, - stride=1, + stride=stride, retrain=False, - last_points_only=True, + last_points_only=(stride == 1), start=start, start_format=start_format, num_samples=num_samples, @@ -324,6 +335,13 @@ def predict_series( show_warnings=show_warnings, enable_optimization=enable_optimization, ) + if stride == 1: + return forecasts + # concatenate the strided historical forecasts blocks (last_point_only=False) + if isinstance(series, Sequence): + return [concatenate(hist_fc) for hist_fc in forecasts] + else: + return concatenate(forecasts) def eval_metric( self, diff --git a/darts/tests/ad/test_anomaly_model.py b/darts/tests/ad/test_anomaly_model.py index c6da7555e8..5e8446b8c1 100644 --- a/darts/tests/ad/test_anomaly_model.py +++ b/darts/tests/ad/test_anomaly_model.py @@ -1480,3 +1480,43 @@ def show_anomalies_function(self, visualization_function): pred_scores=[self.test, self.test], names_of_scorers=["scorer1", "scorer2", "scorer3"], ) + + def test_pred_series(self): + """Basics tests for the `pred_series` method for ForecastingAnomalyModel""" + model = ForecastingAnomalyModel( + model=RegressionModel(lags=10, output_chunk_length=2), scorer=Norm() + ) + # cannot predict without training + with pytest.raises(ValueError) as err: + model.predict_series(self.test) + assert str(err.value).endswith("has not been trained yet. Call `fit()` before.") + + # must set `allow_model_training=True` to fit the underlying model "in place" + model.fit(self.train, allow_model_training=True) + + # stride must be equal to 1 or forecast_horizon + err_msg = "`stride` must be equal to either `1` or `forecast_horizon`, received" + with pytest.raises(ValueError) as err: + model.predict_series(self.test, stride=2, forecast_horizon=1) + assert str(err.value).startswith(err_msg) + with pytest.raises(ValueError) as err: + model.predict_series(self.test, stride=2, forecast_horizon=3) + assert str(err.value).startswith(err_msg) + + # predict single series + pred_no_stride = model.predict_series(self.test, stride=1) + pred_no_stride_list = model.predict_series([self.test], stride=1) + assert pred_no_stride.time_index.equals(pred_no_stride_list[0].time_index) + np.testing.assert_almost_equal( + pred_no_stride.values(), pred_no_stride_list[0].values() + ) + + # stride == horizon + pred_strided = model.predict_series(self.test, forecast_horizon=3, stride=3) + pred_strided_list = model.predict_series( + [self.test], forecast_horizon=3, stride=3 + ) + assert pred_strided.time_index.equals(pred_strided_list[0].time_index) + np.testing.assert_almost_equal( + pred_strided.values(), pred_strided_list[0].values() + ) From fdfb3005f20d3ccdc3949acda9d5e8916bdf694a Mon Sep 17 00:00:00 2001 From: madtoinou Date: Mon, 3 Mar 2025 15:10:42 +0100 Subject: [PATCH 2/2] updated changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index a592e63e5d..e40ff0a454 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co - Made method `ForecastingModel.untrained_model()` public. Use this method to get a new (untrained) model instance created with the same parameters. [#2684](https://github.com/unit8co/darts/pull/2684) by [Timon Erhart](https://github.com/turbotimon) - `TimeSeries.plot()` now supports setting the color for each component in the series. Simply pass a list / sequence of colors with length matching the number of components as parameters "c" or "colors". [#2680](https://github.com/unit8co/darts/pull/2680) by [Jules Authier](https://github.com/authierj) - Made it possible to run the quickstart notebook `00-quickstart.ipynb` locally. [#2691](https://github.com/unit8co/darts/pull/2691) by [Jules Authier](https://github.com/authierj) +- Exposed the `stride` argument of the `ForecastingAnomalyModel.predict_series` method, allowing for faster inference when `stride=forecast_horizon` at the cost of accuracy. [#2709](https://github.com/unit8co/darts/pull/2709) by [Antoine Madrona](https://github.com/madtoinou) **Fixed**