-
Notifications
You must be signed in to change notification settings - Fork 1k
Imp/tsmixer basic #2555
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Imp/tsmixer basic #2555
Changes from all commits
e12a751
1bd17da
34e431f
bb72922
1934928
c5de724
b8a2e88
c6c7e97
79612f0
6cdc9db
171dc34
f9797a3
e528f85
133547e
5afff01
3a392a3
ebe02d1
0a90f24
2674e1c
2bf09ce
245ae09
f9c0d15
51e2b11
c86b559
e6647c0
1483096
97c83b2
bf912d9
46cf888
c374724
67a89f1
ef1a322
0dd0c30
a4891e2
a46645a
c5a5fe4
af74b5b
8550314
df935cc
2c156f5
694ca4a
94c6d71
f31af84
784e9ea
7ca964b
176dd5c
7ad36bd
552c92b
83264d4
1e48108
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -264,6 +264,7 @@ def __init__( | |||||||||||||||||||||||||||||||
| super().__init__() | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| mixing_input = input_dim | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| if static_cov_dim != 0: | ||||||||||||||||||||||||||||||||
| self.feature_mixing_static = _FeatureMixing( | ||||||||||||||||||||||||||||||||
| sequence_length=sequence_length, | ||||||||||||||||||||||||||||||||
|
|
@@ -312,28 +313,34 @@ def __init__( | |||||||||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||||||||||
| input_dim: int, | ||||||||||||||||||||||||||||||||
| output_dim: int, | ||||||||||||||||||||||||||||||||
| num_encoder_blocks: int, | ||||||||||||||||||||||||||||||||
| num_decoder_blocks: int, | ||||||||||||||||||||||||||||||||
| past_cov_dim: int, | ||||||||||||||||||||||||||||||||
| future_cov_dim: int, | ||||||||||||||||||||||||||||||||
| static_cov_dim: int, | ||||||||||||||||||||||||||||||||
| nr_params: int, | ||||||||||||||||||||||||||||||||
| hidden_size: int, | ||||||||||||||||||||||||||||||||
| ff_size: int, | ||||||||||||||||||||||||||||||||
| num_blocks: int, | ||||||||||||||||||||||||||||||||
| activation: str, | ||||||||||||||||||||||||||||||||
| dropout: float, | ||||||||||||||||||||||||||||||||
| norm_type: str | nn.Module, | ||||||||||||||||||||||||||||||||
| normalize_before: bool, | ||||||||||||||||||||||||||||||||
| **kwargs, | ||||||||||||||||||||||||||||||||
| ) -> None: | ||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||
| Initializes the TSMixer module for use within a Darts forecasting model. | ||||||||||||||||||||||||||||||||
| Initializes the TSMixer module. | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| Parameters | ||||||||||||||||||||||||||||||||
| ---------- | ||||||||||||||||||||||||||||||||
| input_dim | ||||||||||||||||||||||||||||||||
| Number of input target features. | ||||||||||||||||||||||||||||||||
| output_dim | ||||||||||||||||||||||||||||||||
| Number of output target features. | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| num_encoder_blocks | ||||||||||||||||||||||||||||||||
| Number of encoder blocks (that operate on input_chunk_length). | ||||||||||||||||||||||||||||||||
| num_decoder_blocks | ||||||||||||||||||||||||||||||||
| Number of decoder blocks (that operate on output_chunk_length). | ||||||||||||||||||||||||||||||||
| past_cov_dim | ||||||||||||||||||||||||||||||||
| Number of past covariate features. | ||||||||||||||||||||||||||||||||
| future_cov_dim | ||||||||||||||||||||||||||||||||
|
|
@@ -347,8 +354,6 @@ def __init__( | |||||||||||||||||||||||||||||||
| Hidden state size of the TSMixer. | ||||||||||||||||||||||||||||||||
| ff_size | ||||||||||||||||||||||||||||||||
| Dimension of the feedforward network internal to the module. | ||||||||||||||||||||||||||||||||
| num_blocks | ||||||||||||||||||||||||||||||||
| Number of mixer blocks. | ||||||||||||||||||||||||||||||||
| activation | ||||||||||||||||||||||||||||||||
| Activation function to use. | ||||||||||||||||||||||||||||||||
| dropout | ||||||||||||||||||||||||||||||||
|
|
@@ -368,7 +373,7 @@ def __init__( | |||||||||||||||||||||||||||||||
| if activation not in ACTIVATIONS: | ||||||||||||||||||||||||||||||||
| raise_log( | ||||||||||||||||||||||||||||||||
| ValueError( | ||||||||||||||||||||||||||||||||
| f"Invalid `activation={activation}`. Must be on of {ACTIVATIONS}." | ||||||||||||||||||||||||||||||||
| f"Invalid `activation={activation}`. Must be one of {ACTIVATIONS}." | ||||||||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||||||||
| logger=logger, | ||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||
|
|
@@ -378,7 +383,7 @@ def __init__( | |||||||||||||||||||||||||||||||
| if norm_type not in NORMS: | ||||||||||||||||||||||||||||||||
| raise_log( | ||||||||||||||||||||||||||||||||
| ValueError( | ||||||||||||||||||||||||||||||||
| f"Invalid `norm_type={norm_type}`. Must be on of {NORMS}." | ||||||||||||||||||||||||||||||||
| f"Invalid `norm_type={norm_type}`. Must be one of {NORMS}." | ||||||||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||||||||
| logger=logger, | ||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||
|
|
@@ -397,13 +402,17 @@ def __init__( | |||||||||||||||||||||||||||||||
| "normalize_before": normalize_before, | ||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| # Projects from the input time dimension to the output time dimension | ||||||||||||||||||||||||||||||||
| self.fc_hist = nn.Linear(self.input_chunk_length, self.output_chunk_length) | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| self.feature_mixing_hist = _FeatureMixing( | ||||||||||||||||||||||||||||||||
| sequence_length=self.output_chunk_length, | ||||||||||||||||||||||||||||||||
| sequence_length=self.input_chunk_length, | ||||||||||||||||||||||||||||||||
| input_dim=input_dim + past_cov_dim + future_cov_dim, | ||||||||||||||||||||||||||||||||
| output_dim=hidden_size, | ||||||||||||||||||||||||||||||||
| **mixer_params, | ||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| # Process future covariates in decoder (if exists) | ||||||||||||||||||||||||||||||||
| if future_cov_dim: | ||||||||||||||||||||||||||||||||
| self.feature_mixing_future = _FeatureMixing( | ||||||||||||||||||||||||||||||||
| sequence_length=self.output_chunk_length, | ||||||||||||||||||||||||||||||||
|
|
@@ -413,19 +422,40 @@ def __init__( | |||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||
| self.feature_mixing_future = None | ||||||||||||||||||||||||||||||||
| self.conditional_mixer = self._build_mixer( | ||||||||||||||||||||||||||||||||
| prediction_length=self.output_chunk_length, | ||||||||||||||||||||||||||||||||
| num_blocks=num_blocks, | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| # Remove previous fc_hist and fc_future. | ||||||||||||||||||||||||||||||||
| # New projection from encoder (input_chunk_length) to decoder (output_chunk_length) | ||||||||||||||||||||||||||||||||
| self.encoder_to_decoder = nn.Linear( | ||||||||||||||||||||||||||||||||
| self.input_chunk_length, self.output_chunk_length | ||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| # Build encoder mixer (operating on input_chunk_length) | ||||||||||||||||||||||||||||||||
| self.encoder_mixer = self._build_mixer( | ||||||||||||||||||||||||||||||||
| sequence_length=self.input_chunk_length, | ||||||||||||||||||||||||||||||||
| num_blocks=num_encoder_blocks, | ||||||||||||||||||||||||||||||||
| hidden_size=hidden_size, | ||||||||||||||||||||||||||||||||
| future_cov_dim=0, # encoder mixing uses only historical features | ||||||||||||||||||||||||||||||||
| static_cov_dim=static_cov_dim, | ||||||||||||||||||||||||||||||||
| **mixer_params, | ||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||
| # Build decoder mixer (operating on output_chunk_length) | ||||||||||||||||||||||||||||||||
| self.decoder_mixer = self._build_mixer( | ||||||||||||||||||||||||||||||||
| sequence_length=self.output_chunk_length, | ||||||||||||||||||||||||||||||||
| num_blocks=num_decoder_blocks, | ||||||||||||||||||||||||||||||||
| hidden_size=hidden_size, | ||||||||||||||||||||||||||||||||
| future_cov_dim=future_cov_dim, | ||||||||||||||||||||||||||||||||
| static_cov_dim=static_cov_dim, | ||||||||||||||||||||||||||||||||
| **mixer_params, | ||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||
| self.fc_out = nn.Linear(hidden_size, output_dim * nr_params) | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| self.fc_out = nn.Linear( | ||||||||||||||||||||||||||||||||
| hidden_size * (1 + int((num_decoder_blocks == 0) and (future_cov_dim > 0))), | ||||||||||||||||||||||||||||||||
| output_dim * nr_params, | ||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| @staticmethod | ||||||||||||||||||||||||||||||||
| def _build_mixer( | ||||||||||||||||||||||||||||||||
| prediction_length: int, | ||||||||||||||||||||||||||||||||
| sequence_length: int, | ||||||||||||||||||||||||||||||||
| num_blocks: int, | ||||||||||||||||||||||||||||||||
| hidden_size: int, | ||||||||||||||||||||||||||||||||
| future_cov_dim: int, | ||||||||||||||||||||||||||||||||
|
|
@@ -436,14 +466,15 @@ def _build_mixer( | |||||||||||||||||||||||||||||||
| # the first block takes `x` consisting of concatenated features with size `hidden_size`: | ||||||||||||||||||||||||||||||||
| # - historic features | ||||||||||||||||||||||||||||||||
| # - optional future features | ||||||||||||||||||||||||||||||||
| input_dim_block = hidden_size * (1 + int(future_cov_dim > 0)) | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| input_dim_block = hidden_size * ( | ||||||||||||||||||||||||||||||||
| 1 + int(future_cov_dim > 0) | ||||||||||||||||||||||||||||||||
| ) # starting dimension for mixer layers | ||||||||||||||||||||||||||||||||
| mixer_layers = nn.ModuleList() | ||||||||||||||||||||||||||||||||
| for _ in range(num_blocks): | ||||||||||||||||||||||||||||||||
| layer = _ConditionalMixerLayer( | ||||||||||||||||||||||||||||||||
| input_dim=input_dim_block, | ||||||||||||||||||||||||||||||||
| output_dim=hidden_size, | ||||||||||||||||||||||||||||||||
| sequence_length=prediction_length, | ||||||||||||||||||||||||||||||||
| sequence_length=sequence_length, | ||||||||||||||||||||||||||||||||
| static_cov_dim=static_cov_dim, | ||||||||||||||||||||||||||||||||
| **kwargs, | ||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||
|
|
@@ -478,39 +509,53 @@ def forward(self, x_in: PLModuleInput) -> torch.Tensor: | |||||||||||||||||||||||||||||||
| # S: static cov features | ||||||||||||||||||||||||||||||||
| # H = C + P + F: historic features | ||||||||||||||||||||||||||||||||
| # H_S: hidden Size | ||||||||||||||||||||||||||||||||
| # N_P: likelihood parameters | ||||||||||||||||||||||||||||||||
| # N_P: number of parameters to predict per target feature | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| # `x`: (B, L, H), `x_future`: (B, T, F), `x_static`: (B, C or 1, S) | ||||||||||||||||||||||||||||||||
| x, x_future, x_static = x_in | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| # swap feature and time dimensions (B, L, H) -> (B, H, L) | ||||||||||||||||||||||||||||||||
| x = _time_to_feature(x) | ||||||||||||||||||||||||||||||||
| # linear transformations to horizon (B, H, L) -> (B, H, T) | ||||||||||||||||||||||||||||||||
| x = self.fc_hist(x) | ||||||||||||||||||||||||||||||||
| # (B, H, T) -> (B, T, H) | ||||||||||||||||||||||||||||||||
| x = _time_to_feature(x) | ||||||||||||||||||||||||||||||||
| if self.static_cov_dim: | ||||||||||||||||||||||||||||||||
| # (B, C, S) -> (B, 1, C * S) | ||||||||||||||||||||||||||||||||
| x_static_hist = x_static.reshape(x_static.shape[0], 1, -1) | ||||||||||||||||||||||||||||||||
| # repeat to match lookback time dim: (B, 1, C * S) -> (B, L, C * S) | ||||||||||||||||||||||||||||||||
| x_static_hist = x_static_hist.repeat(1, self.input_chunk_length, 1) | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| # (B, C, S) -> (B, 1, C * S) | ||||||||||||||||||||||||||||||||
| x_static_future = x_static.reshape(x_static.shape[0], 1, -1) | ||||||||||||||||||||||||||||||||
| # repeat to match horizon time dim: (B, 1, C * S) -> (B, T, C * S) | ||||||||||||||||||||||||||||||||
| x_static_future = x_static_future.repeat(1, self.output_chunk_length, 1) | ||||||||||||||||||||||||||||||||
|
Comment on lines
+518
to
+526
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This can be simplified. Also, it would be nice to only create x_static_hist / x_static_future if they are really required (e.g. future only if the encoder is not used). Any type of operation that can be avoided has positive impact on the model throughput :)
Suggested change
|
||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||
| x_static_hist = None | ||||||||||||||||||||||||||||||||
| x_static_future = None | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| # feature mixing for historical features (B, T, H) -> (B, T, H_S) | ||||||||||||||||||||||||||||||||
| # Process historical data (B, L, H) -> (B, L, H_S) | ||||||||||||||||||||||||||||||||
| x = self.feature_mixing_hist(x) | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| # Process future data (B, T, F) -> (B, T, H_S) | ||||||||||||||||||||||||||||||||
| if self.future_cov_dim: | ||||||||||||||||||||||||||||||||
| # feature mixing for future features (B, T, F) -> (B, T, H_S) | ||||||||||||||||||||||||||||||||
| x_future = self.feature_mixing_future(x_future) | ||||||||||||||||||||||||||||||||
| # (B, T, H_S) + (B, T, H_S) -> (B, T, 2*H_S) | ||||||||||||||||||||||||||||||||
| x = torch.cat([x, x_future], dim=-1) | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| if self.static_cov_dim: | ||||||||||||||||||||||||||||||||
| # (B, C, S) -> (B, 1, C * S) | ||||||||||||||||||||||||||||||||
| x_static = x_static.reshape(x_static.shape[0], 1, -1) | ||||||||||||||||||||||||||||||||
| # repeat to match horizon (B, 1, C * S) -> (B, T, C * S) | ||||||||||||||||||||||||||||||||
| x_static = x_static.repeat(1, self.output_chunk_length, 1) | ||||||||||||||||||||||||||||||||
| # Apply encoder mixer layers | ||||||||||||||||||||||||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can you add the shape transformation (e.g. |
||||||||||||||||||||||||||||||||
| for layer in self.encoder_mixer: | ||||||||||||||||||||||||||||||||
| x = layer(x, x_static_hist) | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| # Project time dimension (B, L, H_S) -> (B, T, H_S) | ||||||||||||||||||||||||||||||||
| x = x.transpose(1, 2) | ||||||||||||||||||||||||||||||||
| x = self.encoder_to_decoder(x) # Linear map | ||||||||||||||||||||||||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hmm.. The This means Meaning, users will probably not be able to reproduce results from earlier Darts versions. |
||||||||||||||||||||||||||||||||
| x = x.transpose(1, 2) | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| # If future covariates are provided, mix and concatenate them with the encoder output | ||||||||||||||||||||||||||||||||
| if self.future_cov_dim: | ||||||||||||||||||||||||||||||||
| # (B, T, H_S) -> (B, T, 2 * H_S) | ||||||||||||||||||||||||||||||||
| x = torch.cat([x, x_future], dim=-1) | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| for mixing_layer in self.conditional_mixer: | ||||||||||||||||||||||||||||||||
| # conditional mixer layers with static covariates (B, T, 2 * H_S), (B, T, C * S) -> (B, T, H_S) | ||||||||||||||||||||||||||||||||
| x = mixing_layer(x, x_static=x_static) | ||||||||||||||||||||||||||||||||
| # Apply decoder mixer layers | ||||||||||||||||||||||||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. let's add the shape transformation |
||||||||||||||||||||||||||||||||
| for layer in self.decoder_mixer: | ||||||||||||||||||||||||||||||||
| x = layer(x, x_static_future) | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| # linear transformation to generate the forecast (B, T, H_S) -> (B, T, C * N_P) | ||||||||||||||||||||||||||||||||
| # Forecast generation | ||||||||||||||||||||||||||||||||
| # (B, T, H_S) -> (B, T, C * N_P) | ||||||||||||||||||||||||||||||||
| x = self.fc_out(x) | ||||||||||||||||||||||||||||||||
| # (B, T, C * N_P) -> (B, T, C, N_P) | ||||||||||||||||||||||||||||||||
| x = x.view(-1, self.output_chunk_length, self.output_dim, self.nr_params) | ||||||||||||||||||||||||||||||||
| return x | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
|
|
@@ -529,6 +574,7 @@ def __init__( | |||||||||||||||||||||||||||||||
| norm_type: str | nn.Module = "LayerNorm", | ||||||||||||||||||||||||||||||||
| normalize_before: bool = False, | ||||||||||||||||||||||||||||||||
| use_static_covariates: bool = True, | ||||||||||||||||||||||||||||||||
| project_after_n_blocks: int = 0, | ||||||||||||||||||||||||||||||||
| **kwargs, | ||||||||||||||||||||||||||||||||
| ) -> None: | ||||||||||||||||||||||||||||||||
| """Time-Series Mixer (TSMixer): An All-MLP Architecture for Time Series. | ||||||||||||||||||||||||||||||||
|
|
@@ -570,8 +616,10 @@ def __init__( | |||||||||||||||||||||||||||||||
| The hidden state size / size of the second feed-forward layer in the feature mixing MLP. | ||||||||||||||||||||||||||||||||
| ff_size | ||||||||||||||||||||||||||||||||
| The size of the first feed-forward layer in the feature mixing MLP. | ||||||||||||||||||||||||||||||||
| num_blocks | ||||||||||||||||||||||||||||||||
| The number of mixer blocks in the model. The number includes the first block and all subsequent blocks. | ||||||||||||||||||||||||||||||||
| num_encoder_blocks | ||||||||||||||||||||||||||||||||
| The number of mixer blocks in the encoder. | ||||||||||||||||||||||||||||||||
| num_decoder_blocks | ||||||||||||||||||||||||||||||||
| The number of mixer blocks in the decoder. | ||||||||||||||||||||||||||||||||
|
Comment on lines
+619
to
+622
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These are not available, and description of
Suggested change
|
||||||||||||||||||||||||||||||||
| activation | ||||||||||||||||||||||||||||||||
| The activation function to use in the mixer layers (default='ReLU'). | ||||||||||||||||||||||||||||||||
| Supported activations: ['ReLU', 'RReLU', 'PReLU', 'ELU', 'Softplus', 'Tanh', 'SELU', 'LeakyReLU', 'Sigmoid', | ||||||||||||||||||||||||||||||||
|
|
@@ -762,11 +810,12 @@ def encode_year(idx): | |||||||||||||||||||||||||||||||
| # Model specific parameters | ||||||||||||||||||||||||||||||||
| self.ff_size = ff_size | ||||||||||||||||||||||||||||||||
| self.dropout = dropout | ||||||||||||||||||||||||||||||||
| self.num_blocks = num_blocks | ||||||||||||||||||||||||||||||||
| self.activation = activation | ||||||||||||||||||||||||||||||||
| self.normalize_before = normalize_before | ||||||||||||||||||||||||||||||||
| self.norm_type = norm_type | ||||||||||||||||||||||||||||||||
| self.hidden_size = hidden_size | ||||||||||||||||||||||||||||||||
| self.num_blocks = num_blocks | ||||||||||||||||||||||||||||||||
| self.project_after_n_blocks = project_after_n_blocks | ||||||||||||||||||||||||||||||||
| self._considers_static_covariates = use_static_covariates | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| def _create_model(self, train_sample: TorchTrainingSample) -> nn.Module: | ||||||||||||||||||||||||||||||||
|
|
@@ -793,6 +842,22 @@ def _create_model(self, train_sample: TorchTrainingSample) -> nn.Module: | |||||||||||||||||||||||||||||||
| input_dim = past_target.shape[1] | ||||||||||||||||||||||||||||||||
| output_dim = future_target.shape[1] | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| num_encoder_blocks = self.project_after_n_blocks | ||||||||||||||||||||||||||||||||
| num_decoder_blocks = self.num_blocks - num_encoder_blocks | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| # Raise exception for nonsensical number of encoder and decoder blocks | ||||||||||||||||||||||||||||||||
| if ( | ||||||||||||||||||||||||||||||||
| num_encoder_blocks < 0 | ||||||||||||||||||||||||||||||||
| or num_decoder_blocks < 0 | ||||||||||||||||||||||||||||||||
| or (num_encoder_blocks + num_decoder_blocks != self.num_blocks) | ||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||
| raise_log( | ||||||||||||||||||||||||||||||||
| ValueError( | ||||||||||||||||||||||||||||||||
| f"Invalid number of encoder and decoder blocks. " | ||||||||||||||||||||||||||||||||
| f"project_after_n_blocks must be between 0 and {self.num_blocks} inclusive." | ||||||||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||
|
Comment on lines
+848
to
+859
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should move this sanity check to |
||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| static_cov_dim = ( | ||||||||||||||||||||||||||||||||
| static_covariates.shape[0] * static_covariates.shape[1] | ||||||||||||||||||||||||||||||||
| if static_covariates is not None | ||||||||||||||||||||||||||||||||
|
|
@@ -813,7 +878,8 @@ def _create_model(self, train_sample: TorchTrainingSample) -> nn.Module: | |||||||||||||||||||||||||||||||
| nr_params=nr_params, | ||||||||||||||||||||||||||||||||
| hidden_size=self.hidden_size, | ||||||||||||||||||||||||||||||||
| ff_size=self.ff_size, | ||||||||||||||||||||||||||||||||
| num_blocks=self.num_blocks, | ||||||||||||||||||||||||||||||||
| num_encoder_blocks=num_encoder_blocks, | ||||||||||||||||||||||||||||||||
| num_decoder_blocks=num_decoder_blocks, | ||||||||||||||||||||||||||||||||
| activation=self.activation, | ||||||||||||||||||||||||||||||||
| dropout=self.dropout, | ||||||||||||||||||||||||||||||||
| norm_type=self.norm_type, | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
self.fc_histis not used anymore (I guess replaced by the newencoder_to_decoder?). We should remove one of the two