diff --git a/CHANGELOG.md b/CHANGELOG.md index dfdd5158d5..51c7dbab6a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,8 @@ but cannot always guarantee backwards compatibility. Changes that may **break co **Improved** +- Added `project_after_n_blocks` hyperparameter to `TSMixerModel`, allowing some or all of the backbone to operate in the lookback rather than forecasted time space [#2555](https://github.com/unit8co/darts/pull/2555) by [Eric Schibli](https://github.com/eschibli) + **Fixed** - Updated the restrictive type hint for the timezone parameter `tz` to `Any`. This allows the use of more timezone definitions supported by Pandas [tz_convert](https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DatetimeIndex.tz_convert.html). [#3015](https://github.com/unit8co/darts/pull/3015) by [Moritz Waldleben](https://github.com/mwaldleben). diff --git a/darts/models/forecasting/tsmixer_model.py b/darts/models/forecasting/tsmixer_model.py index 680152cbe5..fd1ba57d5e 100644 --- a/darts/models/forecasting/tsmixer_model.py +++ b/darts/models/forecasting/tsmixer_model.py @@ -264,6 +264,7 @@ def __init__( super().__init__() mixing_input = input_dim + if static_cov_dim != 0: self.feature_mixing_static = _FeatureMixing( sequence_length=sequence_length, @@ -312,13 +313,14 @@ def __init__( self, input_dim: int, output_dim: int, + num_encoder_blocks: int, + num_decoder_blocks: int, past_cov_dim: int, future_cov_dim: int, static_cov_dim: int, nr_params: int, hidden_size: int, ff_size: int, - num_blocks: int, activation: str, dropout: float, norm_type: str | nn.Module, @@ -326,7 +328,7 @@ def __init__( **kwargs, ) -> None: """ - Initializes the TSMixer module for use within a Darts forecasting model. + Initializes the TSMixer module. Parameters ---------- @@ -334,6 +336,11 @@ def __init__( Number of input target features. output_dim Number of output target features. + + num_encoder_blocks + Number of encoder blocks (that operate on input_chunk_length). + num_decoder_blocks + Number of decoder blocks (that operate on output_chunk_length). past_cov_dim Number of past covariate features. future_cov_dim @@ -347,8 +354,6 @@ def __init__( Hidden state size of the TSMixer. ff_size Dimension of the feedforward network internal to the module. - num_blocks - Number of mixer blocks. activation Activation function to use. dropout @@ -368,7 +373,7 @@ def __init__( if activation not in ACTIVATIONS: raise_log( ValueError( - f"Invalid `activation={activation}`. Must be on of {ACTIVATIONS}." + f"Invalid `activation={activation}`. Must be one of {ACTIVATIONS}." ), logger=logger, ) @@ -378,7 +383,7 @@ def __init__( if norm_type not in NORMS: raise_log( ValueError( - f"Invalid `norm_type={norm_type}`. Must be on of {NORMS}." + f"Invalid `norm_type={norm_type}`. Must be one of {NORMS}." ), logger=logger, ) @@ -397,13 +402,17 @@ def __init__( "normalize_before": normalize_before, } + # Projects from the input time dimension to the output time dimension self.fc_hist = nn.Linear(self.input_chunk_length, self.output_chunk_length) + self.feature_mixing_hist = _FeatureMixing( - sequence_length=self.output_chunk_length, + sequence_length=self.input_chunk_length, input_dim=input_dim + past_cov_dim + future_cov_dim, output_dim=hidden_size, **mixer_params, ) + + # Process future covariates in decoder (if exists) if future_cov_dim: self.feature_mixing_future = _FeatureMixing( sequence_length=self.output_chunk_length, @@ -413,19 +422,40 @@ def __init__( ) else: self.feature_mixing_future = None - self.conditional_mixer = self._build_mixer( - prediction_length=self.output_chunk_length, - num_blocks=num_blocks, + + # Remove previous fc_hist and fc_future. + # New projection from encoder (input_chunk_length) to decoder (output_chunk_length) + self.encoder_to_decoder = nn.Linear( + self.input_chunk_length, self.output_chunk_length + ) + + # Build encoder mixer (operating on input_chunk_length) + self.encoder_mixer = self._build_mixer( + sequence_length=self.input_chunk_length, + num_blocks=num_encoder_blocks, + hidden_size=hidden_size, + future_cov_dim=0, # encoder mixing uses only historical features + static_cov_dim=static_cov_dim, + **mixer_params, + ) + # Build decoder mixer (operating on output_chunk_length) + self.decoder_mixer = self._build_mixer( + sequence_length=self.output_chunk_length, + num_blocks=num_decoder_blocks, hidden_size=hidden_size, future_cov_dim=future_cov_dim, static_cov_dim=static_cov_dim, **mixer_params, ) - self.fc_out = nn.Linear(hidden_size, output_dim * nr_params) + + self.fc_out = nn.Linear( + hidden_size * (1 + int((num_decoder_blocks == 0) and (future_cov_dim > 0))), + output_dim * nr_params, + ) @staticmethod def _build_mixer( - prediction_length: int, + sequence_length: int, num_blocks: int, hidden_size: int, future_cov_dim: int, @@ -436,14 +466,15 @@ def _build_mixer( # the first block takes `x` consisting of concatenated features with size `hidden_size`: # - historic features # - optional future features - input_dim_block = hidden_size * (1 + int(future_cov_dim > 0)) - + input_dim_block = hidden_size * ( + 1 + int(future_cov_dim > 0) + ) # starting dimension for mixer layers mixer_layers = nn.ModuleList() for _ in range(num_blocks): layer = _ConditionalMixerLayer( input_dim=input_dim_block, output_dim=hidden_size, - sequence_length=prediction_length, + sequence_length=sequence_length, static_cov_dim=static_cov_dim, **kwargs, ) @@ -478,39 +509,53 @@ def forward(self, x_in: PLModuleInput) -> torch.Tensor: # S: static cov features # H = C + P + F: historic features # H_S: hidden Size - # N_P: likelihood parameters + # N_P: number of parameters to predict per target feature # `x`: (B, L, H), `x_future`: (B, T, F), `x_static`: (B, C or 1, S) x, x_future, x_static = x_in - # swap feature and time dimensions (B, L, H) -> (B, H, L) - x = _time_to_feature(x) - # linear transformations to horizon (B, H, L) -> (B, H, T) - x = self.fc_hist(x) - # (B, H, T) -> (B, T, H) - x = _time_to_feature(x) + if self.static_cov_dim: + # (B, C, S) -> (B, 1, C * S) + x_static_hist = x_static.reshape(x_static.shape[0], 1, -1) + # repeat to match lookback time dim: (B, 1, C * S) -> (B, L, C * S) + x_static_hist = x_static_hist.repeat(1, self.input_chunk_length, 1) + + # (B, C, S) -> (B, 1, C * S) + x_static_future = x_static.reshape(x_static.shape[0], 1, -1) + # repeat to match horizon time dim: (B, 1, C * S) -> (B, T, C * S) + x_static_future = x_static_future.repeat(1, self.output_chunk_length, 1) + else: + x_static_hist = None + x_static_future = None - # feature mixing for historical features (B, T, H) -> (B, T, H_S) + # Process historical data (B, L, H) -> (B, L, H_S) x = self.feature_mixing_hist(x) + + # Process future data (B, T, F) -> (B, T, H_S) if self.future_cov_dim: - # feature mixing for future features (B, T, F) -> (B, T, H_S) x_future = self.feature_mixing_future(x_future) - # (B, T, H_S) + (B, T, H_S) -> (B, T, 2*H_S) - x = torch.cat([x, x_future], dim=-1) - if self.static_cov_dim: - # (B, C, S) -> (B, 1, C * S) - x_static = x_static.reshape(x_static.shape[0], 1, -1) - # repeat to match horizon (B, 1, C * S) -> (B, T, C * S) - x_static = x_static.repeat(1, self.output_chunk_length, 1) + # Apply encoder mixer layers + for layer in self.encoder_mixer: + x = layer(x, x_static_hist) + + # Project time dimension (B, L, H_S) -> (B, T, H_S) + x = x.transpose(1, 2) + x = self.encoder_to_decoder(x) # Linear map + x = x.transpose(1, 2) + + # If future covariates are provided, mix and concatenate them with the encoder output + if self.future_cov_dim: + # (B, T, H_S) -> (B, T, 2 * H_S) + x = torch.cat([x, x_future], dim=-1) - for mixing_layer in self.conditional_mixer: - # conditional mixer layers with static covariates (B, T, 2 * H_S), (B, T, C * S) -> (B, T, H_S) - x = mixing_layer(x, x_static=x_static) + # Apply decoder mixer layers + for layer in self.decoder_mixer: + x = layer(x, x_static_future) - # linear transformation to generate the forecast (B, T, H_S) -> (B, T, C * N_P) + # Forecast generation + # (B, T, H_S) -> (B, T, C * N_P) x = self.fc_out(x) - # (B, T, C * N_P) -> (B, T, C, N_P) x = x.view(-1, self.output_chunk_length, self.output_dim, self.nr_params) return x @@ -529,6 +574,7 @@ def __init__( norm_type: str | nn.Module = "LayerNorm", normalize_before: bool = False, use_static_covariates: bool = True, + project_after_n_blocks: int = 0, **kwargs, ) -> None: """Time-Series Mixer (TSMixer): An All-MLP Architecture for Time Series. @@ -570,8 +616,10 @@ def __init__( The hidden state size / size of the second feed-forward layer in the feature mixing MLP. ff_size The size of the first feed-forward layer in the feature mixing MLP. - num_blocks - The number of mixer blocks in the model. The number includes the first block and all subsequent blocks. + num_encoder_blocks + The number of mixer blocks in the encoder. + num_decoder_blocks + The number of mixer blocks in the decoder. activation The activation function to use in the mixer layers (default='ReLU'). Supported activations: ['ReLU', 'RReLU', 'PReLU', 'ELU', 'Softplus', 'Tanh', 'SELU', 'LeakyReLU', 'Sigmoid', @@ -762,11 +810,12 @@ def encode_year(idx): # Model specific parameters self.ff_size = ff_size self.dropout = dropout - self.num_blocks = num_blocks self.activation = activation self.normalize_before = normalize_before self.norm_type = norm_type self.hidden_size = hidden_size + self.num_blocks = num_blocks + self.project_after_n_blocks = project_after_n_blocks self._considers_static_covariates = use_static_covariates def _create_model(self, train_sample: TorchTrainingSample) -> nn.Module: @@ -793,6 +842,22 @@ def _create_model(self, train_sample: TorchTrainingSample) -> nn.Module: input_dim = past_target.shape[1] output_dim = future_target.shape[1] + num_encoder_blocks = self.project_after_n_blocks + num_decoder_blocks = self.num_blocks - num_encoder_blocks + + # Raise exception for nonsensical number of encoder and decoder blocks + if ( + num_encoder_blocks < 0 + or num_decoder_blocks < 0 + or (num_encoder_blocks + num_decoder_blocks != self.num_blocks) + ): + raise_log( + ValueError( + f"Invalid number of encoder and decoder blocks. " + f"project_after_n_blocks must be between 0 and {self.num_blocks} inclusive." + ), + ) + static_cov_dim = ( static_covariates.shape[0] * static_covariates.shape[1] if static_covariates is not None @@ -813,7 +878,8 @@ def _create_model(self, train_sample: TorchTrainingSample) -> nn.Module: nr_params=nr_params, hidden_size=self.hidden_size, ff_size=self.ff_size, - num_blocks=self.num_blocks, + num_encoder_blocks=num_encoder_blocks, + num_decoder_blocks=num_decoder_blocks, activation=self.activation, dropout=self.dropout, norm_type=self.norm_type, diff --git a/darts/tests/models/forecasting/test_global_forecasting_models.py b/darts/tests/models/forecasting/test_global_forecasting_models.py index 37b9aaa22c..2428552333 100644 --- a/darts/tests/models/forecasting/test_global_forecasting_models.py +++ b/darts/tests/models/forecasting/test_global_forecasting_models.py @@ -170,7 +170,7 @@ "n_epochs": 10, "pl_trainer_kwargs": tfm_kwargs["pl_trainer_kwargs"], }, - 60.0, + 75.0, ), ( GlobalNaiveAggregate, diff --git a/darts/tests/models/forecasting/test_tsmixer.py b/darts/tests/models/forecasting/test_tsmixer.py index 2932ee5b90..1e671bb999 100644 --- a/darts/tests/models/forecasting/test_tsmixer.py +++ b/darts/tests/models/forecasting/test_tsmixer.py @@ -362,3 +362,39 @@ def test_time_batch_norm_2d_gradients(self): output.mean().backward() assert input_tensor.grad is not None + + @pytest.mark.parametrize("project_after_n_blocks", [-1, 0, 1, 2, 3]) + def test_project_after_n_blocks(self, project_after_n_blocks): + ts = tg.sine_timeseries(length=36, freq="h") + input_len = 12 + output_len = 6 + + expect_exception = project_after_n_blocks == -1 or project_after_n_blocks == 3 + + if expect_exception: + with pytest.raises(ValueError): + model = TSMixerModel( + input_chunk_length=input_len, + output_chunk_length=output_len, + n_epochs=1, + project_after_n_blocks=project_after_n_blocks, + **tfm_kwargs, + ) + model.fit(ts) + + else: + model = TSMixerModel( + input_chunk_length=input_len, + output_chunk_length=output_len, + n_epochs=1, + project_after_n_blocks=project_after_n_blocks, + # Cover case of projecting future covs back to input dims + add_encoders={"cyclic": {"future": "hour"}}, + **tfm_kwargs, + ) + model.fit(ts) + model.predict(n=output_len, series=ts) + + # Assert that the encoder and decoder mixers have the expected number of blocks + assert len(model.model.decoder_mixer) == 2 - project_after_n_blocks + assert len(model.model.encoder_mixer) == project_after_n_blocks diff --git a/examples/21-TSMixer-examples.ipynb b/examples/21-TSMixer-examples.ipynb index 8abe8addb1..1346eb7700 100644 --- a/examples/21-TSMixer-examples.ipynb +++ b/examples/21-TSMixer-examples.ipynb @@ -103,6 +103,558 @@ "outputs": [ { "data": { + "application/vnd.microsoft.datawrangler.viewer.v0+json": { + "columns": [ + { + "name": "date", + "rawType": "datetime64[ns]", + "type": "datetime" + }, + { + "name": "HUFL", + "rawType": "float32", + "type": "float" + }, + { + "name": "HULL", + "rawType": "float32", + "type": "float" + }, + { + "name": "MUFL", + "rawType": "float32", + "type": "float" + }, + { + "name": "MULL", + "rawType": "float32", + "type": "float" + }, + { + "name": "LUFL", + "rawType": "float32", + "type": "float" + }, + { + "name": "LULL", + "rawType": "float32", + "type": "float" + }, + { + "name": "OT", + "rawType": "float32", + "type": "float" + } + ], + "conversionMethod": "pd.DataFrame", + "ref": "bc97ec41-d673-4779-b6b1-e7e39822f6fa", + "rows": [ + [ + "2016-07-01 00:00:00", + "5.827", + "2.009", + "1.599", + "0.462", + "4.203", + "1.34", + "30.531" + ], + [ + "2016-07-01 01:00:00", + "5.693", + "2.076", + "1.492", + "0.426", + "4.142", + "1.371", + "27.787" + ], + [ + "2016-07-01 02:00:00", + "5.157", + "1.741", + "1.279", + "0.355", + "3.777", + "1.218", + "27.787" + ], + [ + "2016-07-01 03:00:00", + "5.09", + "1.942", + "1.279", + "0.391", + "3.807", + "1.279", + "25.044" + ], + [ + "2016-07-01 04:00:00", + "5.358", + "1.942", + "1.492", + "0.462", + "3.868", + "1.279", + "21.948" + ], + [ + "2016-07-01 05:00:00", + "5.626", + "2.143", + "1.528", + "0.533", + "4.051", + "1.371", + "21.174" + ], + [ + "2016-07-01 06:00:00", + "7.167", + "2.947", + "2.132", + "0.782", + "5.026", + "1.858", + "22.792" + ], + [ + "2016-07-01 07:00:00", + "7.435", + "3.282", + "2.31", + "1.031", + "5.087", + "2.224", + "23.144" + ], + [ + "2016-07-01 08:00:00", + "5.559", + "3.014", + "2.452", + "1.173", + "2.955", + "1.432", + "21.667" + ], + [ + "2016-07-01 09:00:00", + "4.555", + "2.545", + "1.919", + "0.817", + "2.68", + "1.371", + "17.446" + ], + [ + "2016-07-01 10:00:00", + "4.957", + "2.545", + "1.99", + "0.853", + "2.955", + "1.492", + "19.979" + ], + [ + "2016-07-01 11:00:00", + "5.76", + "2.545", + "2.203", + "0.853", + "3.442", + "1.492", + "20.119" + ], + [ + "2016-07-01 12:00:00", + "4.689", + "2.545", + "1.812", + "0.853", + "2.833", + "1.523", + "19.205" + ], + [ + "2016-07-01 13:00:00", + "4.689", + "2.679", + "1.777", + "1.244", + "3.107", + "1.614", + "18.572" + ], + [ + "2016-07-01 14:00:00", + "5.09", + "2.947", + "2.452", + "1.35", + "2.559", + "1.432", + "19.556" + ], + [ + "2016-07-01 15:00:00", + "5.09", + "3.148", + "2.487", + "1.35", + "2.589", + "1.523", + "17.305" + ], + [ + "2016-07-01 16:00:00", + "4.22", + "2.411", + "1.706", + "0.782", + "2.619", + "1.492", + "19.486" + ], + [ + "2016-07-01 17:00:00", + "4.756", + "2.344", + "1.635", + "0.711", + "3.076", + "1.492", + "19.134" + ], + [ + "2016-07-01 18:00:00", + "5.626", + "2.88", + "2.523", + "1.208", + "3.076", + "1.492", + "20.682" + ], + [ + "2016-07-01 19:00:00", + "5.492", + "3.014", + "2.452", + "1.208", + "3.015", + "1.553", + "18.712" + ], + [ + "2016-07-01 20:00:00", + "5.358", + "3.014", + "2.452", + "1.208", + "2.863", + "1.523", + "17.868" + ], + [ + "2016-07-01 21:00:00", + "5.09", + "2.947", + "2.381", + "1.208", + "2.68", + "1.523", + "18.009" + ], + [ + "2016-07-01 22:00:00", + "4.823", + "2.947", + "2.203", + "1.173", + "2.619", + "1.523", + "18.009" + ], + [ + "2016-07-01 23:00:00", + "4.622", + "2.88", + "2.132", + "1.137", + "2.467", + "1.492", + "19.768" + ], + [ + "2016-07-02 00:00:00", + "5.224", + "3.081", + "2.701", + "1.315", + "2.437", + "1.523", + "21.104" + ], + [ + "2016-07-02 01:00:00", + "5.157", + "3.014", + "2.878", + "1.35", + "2.345", + "1.432", + "19.697" + ], + [ + "2016-07-02 02:00:00", + "5.157", + "3.148", + "2.878", + "1.492", + "2.284", + "1.432", + "20.049" + ], + [ + "2016-07-02 03:00:00", + "5.157", + "3.081", + "2.914", + "1.492", + "2.193", + "1.401", + "20.752" + ], + [ + "2016-07-02 04:00:00", + "4.555", + "3.081", + "2.452", + "1.492", + "2.193", + "1.401", + "21.385" + ], + [ + "2016-07-02 05:00:00", + "5.425", + "3.282", + "3.092", + "1.706", + "2.437", + "1.462", + "22.23" + ], + [ + "2016-07-02 06:00:00", + "5.492", + "3.282", + "2.523", + "1.492", + "2.985", + "1.462", + "20.26" + ], + [ + "2016-07-02 07:00:00", + "5.626", + "3.215", + "2.487", + "1.492", + "3.076", + "1.523", + "21.104" + ], + [ + "2016-07-02 08:00:00", + "5.559", + "3.282", + "2.594", + "1.67", + "2.924", + "1.523", + "20.612" + ], + [ + "2016-07-02 09:00:00", + "5.224", + "3.215", + "2.559", + "1.564", + "2.68", + "1.462", + "18.361" + ], + [ + "2016-07-02 10:00:00", + "9.913", + "4.957", + "6.645", + "3.305", + "3.046", + "1.553", + "20.963" + ], + [ + "2016-07-02 11:00:00", + "11.788", + "5.425", + "8.173", + "2.523", + "3.686", + "1.675", + "19.416" + ], + [ + "2016-07-02 12:00:00", + "9.645", + "4.957", + "6.752", + "2.132", + "3.107", + "1.828", + "20.823" + ], + [ + "2016-07-02 13:00:00", + "10.382", + "5.76", + "7.462", + "2.559", + "2.985", + "1.767", + "20.19" + ], + [ + "2016-07-02 14:00:00", + "8.774", + "4.689", + "6.112", + "2.025", + "2.894", + "1.919", + "21.315" + ], + [ + "2016-07-02 15:00:00", + "10.449", + "5.157", + "6.965", + "2.452", + "2.772", + "1.736", + "22.019" + ], + [ + "2016-07-02 16:00:00", + "9.846", + "4.823", + "7.036", + "2.665", + "2.894", + "1.767", + "20.682" + ], + [ + "2016-07-02 17:00:00", + "9.913", + "4.823", + "6.894", + "2.416", + "3.229", + "1.736", + "25.466" + ], + [ + "2016-07-02 18:00:00", + "10.65", + "4.689", + "6.929", + "2.452", + "3.381", + "1.797", + "25.888" + ], + [ + "2016-07-02 19:00:00", + "10.114", + "4.354", + "6.645", + "1.812", + "3.107", + "1.736", + "27.857" + ], + [ + "2016-07-02 20:00:00", + "9.98", + "4.153", + "6.574", + "1.954", + "3.411", + "1.767", + "27.295" + ], + [ + "2016-07-02 21:00:00", + "9.31", + "4.22", + "6.005", + "2.132", + "3.229", + "1.858", + "22.23" + ], + [ + "2016-07-02 22:00:00", + "9.444", + "4.622", + "6.965", + "2.168", + "2.955", + "1.858", + "21.948" + ], + [ + "2016-07-02 23:00:00", + "9.444", + "4.287", + "6.823", + "2.559", + "2.589", + "1.736", + "27.295" + ], + [ + "2016-07-03 00:00:00", + "10.382", + "5.425", + "7.604", + "2.31", + "2.955", + "1.675", + "29.335" + ], + [ + "2016-07-03 01:00:00", + "9.779", + "5.224", + "6.716", + "2.843", + "2.65", + "1.675", + "26.028" + ] + ], + "shape": { + "columns": 7, + "rows": 17420 + } + }, "text/html": [ "
\n", "