diff --git a/CHANGELOG.md b/CHANGELOG.md index eab8505bb0..c3221981c5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co - `TimeSeries.plot()` now supports setting the color for each component in the series. Simply pass a list / sequence of colors with length matching the number of components as parameters "c" or "colors". [#2680](https://github.com/unit8co/darts/pull/2680) by [Jules Authier](https://github.com/authierj) - Made it possible to run the quickstart notebook `00-quickstart.ipynb` locally. [#2691](https://github.com/unit8co/darts/pull/2691) by [Jules Authier](https://github.com/authierj) - Added `quantile` parameter to `RegressionModel.get_estimator()` to get the specific quantile estimator for probabilistic regression models using the `quantile` likelihood. [#2716](https://github.com/unit8co/darts/pull/2716) by [Antoine Madrona](https://github.com/madtoinou) +- Exposed the `stride` argument of the `ForecastingAnomalyModel.predict_series` method, allowing for faster inference when `stride=forecast_horizon` at the cost of accuracy. [#2709](https://github.com/unit8co/darts/pull/2709) by [Antoine Madrona](https://github.com/madtoinou) **Removed** diff --git a/darts/ad/anomaly_model/anomaly_model.py b/darts/ad/anomaly_model/anomaly_model.py index 63655db40c..03773fbd48 100644 --- a/darts/ad/anomaly_model/anomaly_model.py +++ b/darts/ad/anomaly_model/anomaly_model.py @@ -50,7 +50,7 @@ def fit( if not allow_model_training and not self.scorers_are_trainable: return self - # check input series and covert to sequences + # check input series and convert to sequences series, kwargs = self._process_input_series(series, **kwargs) self._fit_core( series=series, allow_model_training=allow_model_training, **kwargs diff --git a/darts/ad/anomaly_model/forecasting_am.py b/darts/ad/anomaly_model/forecasting_am.py index 8b4339cd9c..31bf1560c7 100644 --- a/darts/ad/anomaly_model/forecasting_am.py +++ b/darts/ad/anomaly_model/forecasting_am.py @@ -23,7 +23,7 @@ from darts.ad.scorers.scorers import AnomalyScorer from darts.logging import get_logger, raise_log from darts.models.forecasting.forecasting_model import GlobalForecastingModel -from darts.timeseries import TimeSeries +from darts.timeseries import TimeSeries, concatenate logger = get_logger(__name__) @@ -248,6 +248,7 @@ def predict_series( start: Union[pd.Timestamp, float, int] = None, start_format: Literal["position", "value"] = "value", num_samples: int = 1, + stride: int = 1, verbose: bool = False, show_warnings: bool = True, enable_optimization: bool = True, @@ -288,6 +289,9 @@ def predict_series( an error if the value is not in `series`' index. Default: `'value'` num_samples Number of times a prediction is sampled from a probabilistic model. Must be `1` for deterministic models. + stride + The number of time steps between two consecutive predictions. Must be either `1` or + `forecast_horizon` (caution, the prediction will be faster but less accurate). verbose Whether to print the progress. show_warnings @@ -300,7 +304,7 @@ def predict_series( Returns ------- Sequence[TimeSeries] - A sequence of `TimeSeries` with the historical forecasts for each series (with `last_points_only=True`). + A sequence of `TimeSeries`, with one historical forecasts for each series. """ if not self.model._fit_called: raise_log( @@ -309,14 +313,21 @@ def predict_series( ), logger=logger, ) - return self.model.historical_forecasts( + if not (stride == 1 or stride == forecast_horizon): + raise_log( + ValueError( + f"`stride` must be equal to either `1` or `forecast_horizon`, received {stride}." + ), + logger=logger, + ) + forecasts = self.model.historical_forecasts( series, past_covariates=past_covariates, future_covariates=future_covariates, forecast_horizon=forecast_horizon, - stride=1, + stride=stride, retrain=False, - last_points_only=True, + last_points_only=(stride == 1), start=start, start_format=start_format, num_samples=num_samples, @@ -324,6 +335,13 @@ def predict_series( show_warnings=show_warnings, enable_optimization=enable_optimization, ) + if stride == 1: + return forecasts + # concatenate the strided historical forecasts blocks (last_point_only=False) + if isinstance(series, Sequence): + return [concatenate(hist_fc) for hist_fc in forecasts] + else: + return concatenate(forecasts) def eval_metric( self, diff --git a/darts/tests/ad/test_anomaly_model.py b/darts/tests/ad/test_anomaly_model.py index c6da7555e8..5e8446b8c1 100644 --- a/darts/tests/ad/test_anomaly_model.py +++ b/darts/tests/ad/test_anomaly_model.py @@ -1480,3 +1480,43 @@ def show_anomalies_function(self, visualization_function): pred_scores=[self.test, self.test], names_of_scorers=["scorer1", "scorer2", "scorer3"], ) + + def test_pred_series(self): + """Basics tests for the `pred_series` method for ForecastingAnomalyModel""" + model = ForecastingAnomalyModel( + model=RegressionModel(lags=10, output_chunk_length=2), scorer=Norm() + ) + # cannot predict without training + with pytest.raises(ValueError) as err: + model.predict_series(self.test) + assert str(err.value).endswith("has not been trained yet. Call `fit()` before.") + + # must set `allow_model_training=True` to fit the underlying model "in place" + model.fit(self.train, allow_model_training=True) + + # stride must be equal to 1 or forecast_horizon + err_msg = "`stride` must be equal to either `1` or `forecast_horizon`, received" + with pytest.raises(ValueError) as err: + model.predict_series(self.test, stride=2, forecast_horizon=1) + assert str(err.value).startswith(err_msg) + with pytest.raises(ValueError) as err: + model.predict_series(self.test, stride=2, forecast_horizon=3) + assert str(err.value).startswith(err_msg) + + # predict single series + pred_no_stride = model.predict_series(self.test, stride=1) + pred_no_stride_list = model.predict_series([self.test], stride=1) + assert pred_no_stride.time_index.equals(pred_no_stride_list[0].time_index) + np.testing.assert_almost_equal( + pred_no_stride.values(), pred_no_stride_list[0].values() + ) + + # stride == horizon + pred_strided = model.predict_series(self.test, forecast_horizon=3, stride=3) + pred_strided_list = model.predict_series( + [self.test], forecast_horizon=3, stride=3 + ) + assert pred_strided.time_index.equals(pred_strided_list[0].time_index) + np.testing.assert_almost_equal( + pred_strided.values(), pred_strided_list[0].values() + )