diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index 219908c1e47c..7f6a01bc3b16 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -19,6 +19,7 @@ # limitations under the License. import math +import warnings from collections.abc import Callable, Sequence from dataclasses import dataclass @@ -381,8 +382,10 @@ def forward( Past values of the time series that serves as input to the model. past_values_padding (`torch.LongTensor` of shape `(batch_size, sequence_length)`): The padding indicator of the time series. - freq (`torch.LongTensor` of shape `(batch_size,)`): - Frequency indices for the time series data. + freq (`torch.LongTensor` of shape `(batch_size,)` or `Sequence[int]`, *optional*): + Frequency indices for the time series data. Defaults to a zero tensor (high-frequency). A + sequence of ints is also accepted and converted to a tensor internally. Tensor inputs are + preferred and required for export. """ # Reshape into patches (using view for efficiency) bsize = past_values.shape[0] @@ -590,13 +593,20 @@ def __init__(self, config: TimesFmConfig): self.post_init() def _preprocess( - self, inputs: Sequence[torch.Tensor], freq: Sequence[int] | None = None, context_len: int | None = None + self, + inputs: torch.Tensor, + observed_mask: torch.Tensor, + freq: torch.Tensor | None = None, + context_len: int | None = None, ) -> tuple[torch.Tensor, ...]: """Pad/truncate input time series to `context_len` and build a padding mask. Args: - inputs: A list of 1d Tensors. Each Tensor is the context time series of a single forecast task. - freq: Optional list of frequencies (returned as a tensor when provided). + inputs: A 2D `torch.Tensor` of shape `(batch_size, sequence_length)`. + observed_mask: A 2D `torch.Tensor` of the same shape as `inputs` where `1` indicates an observed + value and `0` indicates a missing value. Missing positions are marked as padded in the + returned padding mask. + freq: Optional 1D `torch.Tensor` of frequency indices. context_len: Optional context length override (defaults to `self.context_len`). Returns: @@ -605,25 +615,20 @@ def _preprocess( if context_len is None: context_len = self.context_len - input_ts, input_padding = [], [] - - for ts in inputs: - input_len = ts.shape[0] - padding = torch.zeros(input_len + self.horizon_len, dtype=ts.dtype, device=ts.device) - if input_len < context_len: - num_front_pad = context_len - input_len - ts = torch.cat([torch.zeros(num_front_pad, dtype=ts.dtype, device=ts.device), ts], dim=0) - padding = torch.cat([torch.ones(num_front_pad, dtype=ts.dtype, device=padding.device), padding], dim=0) - elif input_len > context_len: - ts = ts[-context_len:] - padding = padding[-(context_len + self.horizon_len) :] + x = inputs[:, -context_len:] + num_front_pad = context_len - x.shape[1] + x = F.pad(x, (num_front_pad, 0)) - input_ts.append(ts) - input_padding.append(padding) + obs = observed_mask[:, -context_len:].to(dtype=x.dtype) + body_padding = 1 - obs + front_padding = torch.ones(x.shape[0], num_front_pad, dtype=x.dtype, device=x.device) + horizon_padding = torch.zeros(x.shape[0], self.horizon_len, dtype=x.dtype, device=x.device) + padding = torch.cat([front_padding, body_padding, horizon_padding], dim=1) + result = (x, padding) - result = (torch.stack(input_ts, dim=0), torch.stack(input_padding, dim=0)) if freq is not None: - result = result + (torch.tensor(freq[: len(inputs)], dtype=torch.int32).reshape(-1, 1),) + freq_tensor = freq[: x.shape[0]].to(dtype=torch.int32) + result = result + (freq_tensor.reshape(-1, 1),) return result def _postprocess_output( @@ -653,8 +658,9 @@ def _quantile_loss(self, predictions: torch.Tensor, targets: torch.Tensor) -> to @auto_docstring def forward( self, - past_values: Sequence[torch.Tensor], - freq: Sequence[torch.Tensor | int] | None = None, + past_values: Sequence[torch.Tensor] | torch.Tensor, + past_observed_mask: torch.Tensor | None = None, + freq: Sequence[int] | torch.Tensor | None = None, window_size: int | None = None, future_values: torch.Tensor | None = None, forecast_context_len: int | None = None, @@ -664,9 +670,23 @@ def forward( ) -> TimesFmOutputForPrediction: r""" past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): - Past values of the time series that serves as input to the model. - freq (`torch.LongTensor` of shape `(batch_size,)`): - Frequency indices for the time series data. + Past values of the time series that serves as input to the model. A list of 1D tensors with + possibly differing lengths is also accepted (deprecated): each tensor is front-padded with zeros + and stacked into a 2D tensor. Tensor inputs are preferred and required for export. + past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*): + Boolean mask indicating which `past_values` were observed and which are padding/missing. Mask + values selected in `[0, 1]`: + + - `1` for values that are **observed**, + - `0` for values that are **missing** (i.e. padded or NaNs that were replaced by zeros). + + Defaults to a tensor of ones (everything observed). When `past_values` is passed as a list of + variable-length tensors, you should provide a matching `past_observed_mask` so the front-padding + zeros are not treated as observed values. + freq (`torch.LongTensor` of shape `(batch_size,)` or `Sequence[int]`, *optional*): + Frequency indices for the time series data. Defaults to a zero tensor (high-frequency). A + sequence of ints is also accepted and converted to a tensor internally. Tensor inputs are + preferred and required for export. window_size (`int`, *optional*): Window size of trend + residual decomposition. If None then we do not do decomposition. future_values (`torch.Tensor`, *optional*): @@ -686,7 +706,7 @@ def forward( >>> model = TimesFmModelForPrediction.from_pretrained("google/timesfm-2.0-500m-pytorch") - >>> forecast_input = [torch.linspace(0, 20, 100).sin(), torch.linspace(0, 20, 200).sin(), torch.linspace(0, 20, 400).sin()] + >>> forecast_input = torch.stack([torch.linspace(0, 20, 400).sin() for _ in range(3)]) >>> frequency_input = torch.tensor([0, 1, 2], dtype=torch.long) >>> # Generate @@ -701,27 +721,36 @@ def forward( else: fcontext_len = forecast_context_len - device = past_values[0].device - - inputs = [ts[-fcontext_len:] for ts in past_values] - inp_min = torch.min(torch.stack([torch.min(ts) for ts in inputs])) - - if window_size is not None: - new_inputs = [] - new_freqs = [] - for i, ts in enumerate(inputs): - new_inputs.extend(self._timesfm_moving_average(ts, window_size)) - if freq is not None: - new_freqs.extend([freq[i]] * 2) - inputs = new_inputs - if freq is not None: - freq = new_freqs + if isinstance(past_values, list): + warnings.warn( + "Passing `past_values` as a list of 1D tensors is deprecated and will be removed in a future " + "version. Please pass a 2D `torch.Tensor` of shape `(batch_size, sequence_length)` and, when " + "needed, a `past_observed_mask` of the same shape (1 = observed, 0 = padded/missing).", + FutureWarning, + ) + past_values = self._past_values_to_tensor(past_values) + device = past_values.device + if past_observed_mask is None: + past_observed_mask = torch.ones_like(past_values) if freq is None: logger.info("No frequency provided via `freq`. Default to high (0).") - freq = [0] * len(inputs) + freq = torch.zeros(past_values.shape[0], dtype=torch.int32, device=device) + else: + freq = torch.as_tensor(freq, dtype=torch.int32, device=device) + + inputs = past_values[:, -fcontext_len:] + observed_mask = past_observed_mask[:, -fcontext_len:].to(device=device) + sentinel = torch.full_like(inputs, torch.finfo(inputs.dtype).max) + inp_min = torch.where(observed_mask.bool(), inputs, sentinel).min() + + if window_size is not None: + trend, residual = self._timesfm_moving_average(inputs, window_size) + inputs = torch.stack([trend, residual], dim=1).view(2 * inputs.shape[0], -1) + observed_mask = torch.repeat_interleave(observed_mask, 2, dim=0) + freq = torch.repeat_interleave(freq, 2, dim=0) - input_ts, input_padding, inp_freq = self._preprocess(inputs, freq) + input_ts, input_padding, inp_freq = self._preprocess(inputs, observed_mask, freq=freq) input_ts = input_ts.to(device) input_padding = input_padding.to(device) inp_freq = inp_freq.to(device) @@ -794,15 +823,23 @@ def forward( ) @staticmethod - def _timesfm_moving_average(arr: torch.Tensor, window_size: int) -> list[torch.Tensor]: - """Calculates the moving average using PyTorch's convolution function.""" - # Pad with zeros to handle initial window positions + def _past_values_to_tensor(past_values: Sequence[torch.Tensor]) -> torch.Tensor: + """Convert a list of variable-length 1D tensors into a 2D tensor of shape `(batch_size, max_len)` + by left-padding each entry with zeros. Equivalent to `torch.nn.utils.rnn.pad_sequence(past_values, + batch_first=True, padding_side="left")`, re-implemented here because `padding_side` requires + `torch>=2.5`. + """ + max_len = max(ts.shape[0] for ts in past_values) + return torch.stack([F.pad(ts, (max_len - ts.shape[0], 0)) for ts in past_values], dim=0) + + @staticmethod + def _timesfm_moving_average(arr: torch.Tensor, window_size: int) -> tuple[torch.Tensor, torch.Tensor]: + """Calculates the moving average using PyTorch's convolution function. `arr` shape: `(B, T)`.""" arr_padded = F.pad(arr, (window_size - 1, 0), "constant", 0) - # Create a convolution kernel kernel = torch.ones(window_size, dtype=arr.dtype, device=arr.device) / window_size - # Apply convolution to calculate the moving average - smoothed_arr = F.conv1d(arr_padded.view(1, 1, -1), kernel.view(1, 1, -1)).squeeze() - return [smoothed_arr, arr - smoothed_arr] + kernel = kernel.view(1, 1, -1) + smoothed_arr = F.conv1d(arr_padded.unsqueeze(1), kernel).squeeze(1) + return smoothed_arr, arr - smoothed_arr __all__ = ["TimesFmModelForPrediction", "TimesFmPreTrainedModel", "TimesFmModel"] diff --git a/src/transformers/models/timesfm/modular_timesfm.py b/src/transformers/models/timesfm/modular_timesfm.py index ca53ec7dd668..429f47428c8f 100644 --- a/src/transformers/models/timesfm/modular_timesfm.py +++ b/src/transformers/models/timesfm/modular_timesfm.py @@ -14,6 +14,7 @@ """PyTorch TimesFM model.""" import math +import warnings from collections.abc import Callable, Sequence from dataclasses import dataclass @@ -338,8 +339,10 @@ def forward( Past values of the time series that serves as input to the model. past_values_padding (`torch.LongTensor` of shape `(batch_size, sequence_length)`): The padding indicator of the time series. - freq (`torch.LongTensor` of shape `(batch_size,)`): - Frequency indices for the time series data. + freq (`torch.LongTensor` of shape `(batch_size,)` or `Sequence[int]`, *optional*): + Frequency indices for the time series data. Defaults to a zero tensor (high-frequency). A + sequence of ints is also accepted and converted to a tensor internally. Tensor inputs are + preferred and required for export. """ # Reshape into patches (using view for efficiency) bsize = past_values.shape[0] @@ -547,13 +550,20 @@ def __init__(self, config: TimesFmConfig): self.post_init() def _preprocess( - self, inputs: Sequence[torch.Tensor], freq: Sequence[int] | None = None, context_len: int | None = None + self, + inputs: torch.Tensor, + observed_mask: torch.Tensor, + freq: torch.Tensor | None = None, + context_len: int | None = None, ) -> tuple[torch.Tensor, ...]: """Pad/truncate input time series to `context_len` and build a padding mask. Args: - inputs: A list of 1d Tensors. Each Tensor is the context time series of a single forecast task. - freq: Optional list of frequencies (returned as a tensor when provided). + inputs: A 2D `torch.Tensor` of shape `(batch_size, sequence_length)`. + observed_mask: A 2D `torch.Tensor` of the same shape as `inputs` where `1` indicates an observed + value and `0` indicates a missing value. Missing positions are marked as padded in the + returned padding mask. + freq: Optional 1D `torch.Tensor` of frequency indices. context_len: Optional context length override (defaults to `self.context_len`). Returns: @@ -562,25 +572,20 @@ def _preprocess( if context_len is None: context_len = self.context_len - input_ts, input_padding = [], [] - - for ts in inputs: - input_len = ts.shape[0] - padding = torch.zeros(input_len + self.horizon_len, dtype=ts.dtype, device=ts.device) - if input_len < context_len: - num_front_pad = context_len - input_len - ts = torch.cat([torch.zeros(num_front_pad, dtype=ts.dtype, device=ts.device), ts], dim=0) - padding = torch.cat([torch.ones(num_front_pad, dtype=ts.dtype, device=padding.device), padding], dim=0) - elif input_len > context_len: - ts = ts[-context_len:] - padding = padding[-(context_len + self.horizon_len) :] + x = inputs[:, -context_len:] + num_front_pad = context_len - x.shape[1] + x = F.pad(x, (num_front_pad, 0)) - input_ts.append(ts) - input_padding.append(padding) + obs = observed_mask[:, -context_len:].to(dtype=x.dtype) + body_padding = 1 - obs + front_padding = torch.ones(x.shape[0], num_front_pad, dtype=x.dtype, device=x.device) + horizon_padding = torch.zeros(x.shape[0], self.horizon_len, dtype=x.dtype, device=x.device) + padding = torch.cat([front_padding, body_padding, horizon_padding], dim=1) + result = (x, padding) - result = (torch.stack(input_ts, dim=0), torch.stack(input_padding, dim=0)) if freq is not None: - result = result + (torch.tensor(freq[: len(inputs)], dtype=torch.int32).reshape(-1, 1),) + freq_tensor = freq[: x.shape[0]].to(dtype=torch.int32) + result = result + (freq_tensor.reshape(-1, 1),) return result def _postprocess_output( @@ -610,8 +615,9 @@ def _quantile_loss(self, predictions: torch.Tensor, targets: torch.Tensor) -> to @auto_docstring def forward( self, - past_values: Sequence[torch.Tensor], - freq: Sequence[torch.Tensor | int] | None = None, + past_values: Sequence[torch.Tensor] | torch.Tensor, + past_observed_mask: torch.Tensor | None = None, + freq: Sequence[int] | torch.Tensor | None = None, window_size: int | None = None, future_values: torch.Tensor | None = None, forecast_context_len: int | None = None, @@ -621,9 +627,23 @@ def forward( ) -> TimesFmOutputForPrediction: r""" past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): - Past values of the time series that serves as input to the model. - freq (`torch.LongTensor` of shape `(batch_size,)`): - Frequency indices for the time series data. + Past values of the time series that serves as input to the model. A list of 1D tensors with + possibly differing lengths is also accepted (deprecated): each tensor is front-padded with zeros + and stacked into a 2D tensor. Tensor inputs are preferred and required for export. + past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*): + Boolean mask indicating which `past_values` were observed and which are padding/missing. Mask + values selected in `[0, 1]`: + + - `1` for values that are **observed**, + - `0` for values that are **missing** (i.e. padded or NaNs that were replaced by zeros). + + Defaults to a tensor of ones (everything observed). When `past_values` is passed as a list of + variable-length tensors, you should provide a matching `past_observed_mask` so the front-padding + zeros are not treated as observed values. + freq (`torch.LongTensor` of shape `(batch_size,)` or `Sequence[int]`, *optional*): + Frequency indices for the time series data. Defaults to a zero tensor (high-frequency). A + sequence of ints is also accepted and converted to a tensor internally. Tensor inputs are + preferred and required for export. window_size (`int`, *optional*): Window size of trend + residual decomposition. If None then we do not do decomposition. future_values (`torch.Tensor`, *optional*): @@ -643,7 +663,7 @@ def forward( >>> model = TimesFmModelForPrediction.from_pretrained("google/timesfm-2.0-500m-pytorch") - >>> forecast_input = [torch.linspace(0, 20, 100).sin(), torch.linspace(0, 20, 200).sin(), torch.linspace(0, 20, 400).sin()] + >>> forecast_input = torch.stack([torch.linspace(0, 20, 400).sin() for _ in range(3)]) >>> frequency_input = torch.tensor([0, 1, 2], dtype=torch.long) >>> # Generate @@ -658,27 +678,36 @@ def forward( else: fcontext_len = forecast_context_len - device = past_values[0].device - - inputs = [ts[-fcontext_len:] for ts in past_values] - inp_min = torch.min(torch.stack([torch.min(ts) for ts in inputs])) - - if window_size is not None: - new_inputs = [] - new_freqs = [] - for i, ts in enumerate(inputs): - new_inputs.extend(self._timesfm_moving_average(ts, window_size)) - if freq is not None: - new_freqs.extend([freq[i]] * 2) - inputs = new_inputs - if freq is not None: - freq = new_freqs + if isinstance(past_values, list): + warnings.warn( + "Passing `past_values` as a list of 1D tensors is deprecated and will be removed in a future " + "version. Please pass a 2D `torch.Tensor` of shape `(batch_size, sequence_length)` and, when " + "needed, a `past_observed_mask` of the same shape (1 = observed, 0 = padded/missing).", + FutureWarning, + ) + past_values = self._past_values_to_tensor(past_values) + device = past_values.device + if past_observed_mask is None: + past_observed_mask = torch.ones_like(past_values) if freq is None: logger.info("No frequency provided via `freq`. Default to high (0).") - freq = [0] * len(inputs) + freq = torch.zeros(past_values.shape[0], dtype=torch.int32, device=device) + else: + freq = torch.as_tensor(freq, dtype=torch.int32, device=device) + + inputs = past_values[:, -fcontext_len:] + observed_mask = past_observed_mask[:, -fcontext_len:].to(device=device) + sentinel = torch.full_like(inputs, torch.finfo(inputs.dtype).max) + inp_min = torch.where(observed_mask.bool(), inputs, sentinel).min() + + if window_size is not None: + trend, residual = self._timesfm_moving_average(inputs, window_size) + inputs = torch.stack([trend, residual], dim=1).view(2 * inputs.shape[0], -1) + observed_mask = torch.repeat_interleave(observed_mask, 2, dim=0) + freq = torch.repeat_interleave(freq, 2, dim=0) - input_ts, input_padding, inp_freq = self._preprocess(inputs, freq) + input_ts, input_padding, inp_freq = self._preprocess(inputs, observed_mask, freq=freq) input_ts = input_ts.to(device) input_padding = input_padding.to(device) inp_freq = inp_freq.to(device) @@ -751,15 +780,23 @@ def forward( ) @staticmethod - def _timesfm_moving_average(arr: torch.Tensor, window_size: int) -> list[torch.Tensor]: - """Calculates the moving average using PyTorch's convolution function.""" - # Pad with zeros to handle initial window positions + def _past_values_to_tensor(past_values: Sequence[torch.Tensor]) -> torch.Tensor: + """Convert a list of variable-length 1D tensors into a 2D tensor of shape `(batch_size, max_len)` + by left-padding each entry with zeros. Equivalent to `torch.nn.utils.rnn.pad_sequence(past_values, + batch_first=True, padding_side="left")`, re-implemented here because `padding_side` requires + `torch>=2.5`. + """ + max_len = max(ts.shape[0] for ts in past_values) + return torch.stack([F.pad(ts, (max_len - ts.shape[0], 0)) for ts in past_values], dim=0) + + @staticmethod + def _timesfm_moving_average(arr: torch.Tensor, window_size: int) -> tuple[torch.Tensor, torch.Tensor]: + """Calculates the moving average using PyTorch's convolution function. `arr` shape: `(B, T)`.""" arr_padded = F.pad(arr, (window_size - 1, 0), "constant", 0) - # Create a convolution kernel kernel = torch.ones(window_size, dtype=arr.dtype, device=arr.device) / window_size - # Apply convolution to calculate the moving average - smoothed_arr = F.conv1d(arr_padded.view(1, 1, -1), kernel.view(1, 1, -1)).squeeze() - return [smoothed_arr, arr - smoothed_arr] + kernel = kernel.view(1, 1, -1) + smoothed_arr = F.conv1d(arr_padded.unsqueeze(1), kernel).squeeze(1) + return smoothed_arr, arr - smoothed_arr __all__ = ["TimesFmModelForPrediction", "TimesFmPreTrainedModel", "TimesFmModel"] diff --git a/src/transformers/models/timesfm2_5/modeling_timesfm2_5.py b/src/transformers/models/timesfm2_5/modeling_timesfm2_5.py index e7b4e799d20b..280352a1e1df 100644 --- a/src/transformers/models/timesfm2_5/modeling_timesfm2_5.py +++ b/src/transformers/models/timesfm2_5/modeling_timesfm2_5.py @@ -19,6 +19,7 @@ # limitations under the License. import math +import warnings from collections.abc import Callable, Sequence from dataclasses import dataclass from typing import Optional @@ -584,12 +585,15 @@ def forward( """ batch_size, seq_len = past_values.shape patch_len = self.config.patch_length + torch._check(seq_len % patch_len == 0) if past_values_padding is None: past_values_padding = torch.zeros_like(past_values, dtype=torch.long) + else: + past_values_padding = past_values_padding.narrow(1, 0, seq_len) - patched_inputs = past_values.view(batch_size, -1, patch_len) - patched_masks = past_values_padding[:, :seq_len].view(batch_size, -1, patch_len) + patched_inputs = past_values.unflatten(-1, (-1, patch_len)) + patched_masks = past_values_padding.unflatten(-1, (-1, patch_len)) patched_masks_bool = patched_masks >= 0.5 count = past_values.new_zeros(batch_size) @@ -682,13 +686,20 @@ def __init__(self, config: TimesFm2_5Config): self.post_init() def _preprocess( - self, inputs: Sequence[torch.Tensor], freq: Sequence[int] | None = None, context_len: int | None = None + self, + inputs: torch.Tensor, + observed_mask: torch.Tensor, + freq: torch.Tensor | None = None, + context_len: int | None = None, ) -> tuple[torch.Tensor, ...]: """Pad/truncate input time series to `context_len` and build a padding mask. Args: - inputs: A list of 1d Tensors. Each Tensor is the context time series of a single forecast task. - freq: Optional list of frequencies (returned as a tensor when provided). + inputs: A 2D `torch.Tensor` of shape `(batch_size, sequence_length)`. + observed_mask: A 2D `torch.Tensor` of the same shape as `inputs` where `1` indicates an observed + value and `0` indicates a missing value. Missing positions are marked as padded in the + returned padding mask. + freq: Optional 1D `torch.Tensor` of frequency indices. context_len: Optional context length override (defaults to `self.context_len`). Returns: @@ -697,25 +708,20 @@ def _preprocess( if context_len is None: context_len = self.context_len - input_ts, input_padding = [], [] - - for ts in inputs: - input_len = ts.shape[0] - padding = torch.zeros(input_len + self.horizon_len, dtype=ts.dtype, device=ts.device) - if input_len < context_len: - num_front_pad = context_len - input_len - ts = torch.cat([torch.zeros(num_front_pad, dtype=ts.dtype, device=ts.device), ts], dim=0) - padding = torch.cat([torch.ones(num_front_pad, dtype=ts.dtype, device=padding.device), padding], dim=0) - elif input_len > context_len: - ts = ts[-context_len:] - padding = padding[-(context_len + self.horizon_len) :] + x = inputs[:, -context_len:] + num_front_pad = context_len - x.shape[1] + x = F.pad(x, (num_front_pad, 0)) - input_ts.append(ts) - input_padding.append(padding) + obs = observed_mask[:, -context_len:].to(dtype=x.dtype) + body_padding = 1 - obs + front_padding = torch.ones(x.shape[0], num_front_pad, dtype=x.dtype, device=x.device) + horizon_padding = torch.zeros(x.shape[0], self.horizon_len, dtype=x.dtype, device=x.device) + padding = torch.cat([front_padding, body_padding, horizon_padding], dim=1) + result = (x, padding) - result = (torch.stack(input_ts, dim=0), torch.stack(input_padding, dim=0)) if freq is not None: - result = result + (torch.tensor(freq[: len(inputs)], dtype=torch.int32).reshape(-1, 1),) + freq_tensor = freq[: x.shape[0]].to(dtype=torch.int32) + result = result + (freq_tensor.reshape(-1, 1),) return result def _postprocess_output( @@ -745,7 +751,8 @@ def _quantile_loss(self, predictions: torch.Tensor, targets: torch.Tensor) -> to @auto_docstring def forward( self, - past_values: Sequence[torch.Tensor], + past_values: Sequence[torch.Tensor] | torch.Tensor, + past_observed_mask: torch.Tensor | None = None, window_size: int | None = None, future_values: torch.Tensor | None = None, forecast_context_len: int | None = None, @@ -754,8 +761,20 @@ def forward( **kwargs: Unpack[TransformersKwargs], ) -> TimesFm2_5OutputForPrediction: r""" - past_values (`Sequence[torch.Tensor]`): - Past values of the time series that serves as input to the model. Each tensor is a 1D time series. + past_values (`torch.Tensor` of shape `(batch_size, sequence_length)`): + Past values of the time series that serves as input to the model. A list of 1D tensors with + possibly differing lengths is also accepted (deprecated): each tensor is front-padded with zeros + and stacked into a 2D tensor. Tensor inputs are preferred and required for export. + past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*): + Boolean mask indicating which `past_values` were observed and which are padding/missing. Mask + values selected in `[0, 1]`: + + - `1` for values that are **observed**, + - `0` for values that are **missing** (i.e. padded or NaNs that were replaced by zeros). + + Defaults to a tensor of ones (everything observed). When `past_values` is passed as a list of + variable-length tensors, you should provide a matching `past_observed_mask` so the front-padding + zeros are not treated as observed values. window_size (`int`, *optional*): Window size of trend + residual decomposition. If `None`, decomposition is not applied. future_values (`torch.Tensor`, *optional*): @@ -769,23 +788,36 @@ def forward( `config.force_flip_invariance`. """ forecast_context_len = forecast_context_len or self.context_len - device = past_values[0].device - inputs = [ts[-forecast_context_len:] for ts in past_values] - input_min = torch.min(torch.stack([torch.min(ts) for ts in inputs])) + if isinstance(past_values, list): + warnings.warn( + "Passing `past_values` as a list of 1D tensors is deprecated and will be removed in a future " + "version. Please pass a 2D `torch.Tensor` of shape `(batch_size, sequence_length)` and, when " + "needed, a `past_observed_mask` of the same shape (1 = observed, 0 = padded/missing).", + FutureWarning, + ) + past_values = self._past_values_to_tensor(past_values) + + device = past_values.device + if past_observed_mask is None: + past_observed_mask = torch.ones_like(past_values) + + inputs = past_values[:, -forecast_context_len:] + observed_mask = past_observed_mask[:, -forecast_context_len:].to(device=device) + sentinel = torch.full_like(inputs, torch.finfo(inputs.dtype).max) + input_min = torch.where(observed_mask.bool(), inputs, sentinel).min() if window_size is not None: - new_inputs: list[torch.Tensor] = [] - for ts in inputs: - new_inputs.extend(self._timesfm_moving_average(ts, window_size)) - inputs = new_inputs + trend, residual = self._timesfm2_5_moving_average(inputs, window_size) + inputs = torch.stack([trend, residual], dim=1).view(2 * inputs.shape[0], -1) + observed_mask = torch.repeat_interleave(observed_mask, 2, dim=0) if truncate_negative is None: truncate_negative = self.config.infer_is_positive if force_flip_invariance is None: force_flip_invariance = self.config.force_flip_invariance - input_ts, input_padding = self._preprocess(inputs, context_len=forecast_context_len) + input_ts, input_padding = self._preprocess(inputs, observed_mask, context_len=forecast_context_len) input_ts = input_ts.to(device) input_padding = input_padding.to(device) @@ -864,15 +896,23 @@ def _flip_quantiles(x: torch.Tensor) -> torch.Tensor: ) @staticmethod - def _timesfm2_5_moving_average(arr: torch.Tensor, window_size: int) -> list[torch.Tensor]: - """Calculates the moving average using PyTorch's convolution function.""" - # Pad with zeros to handle initial window positions + def _past_values_to_tensor(past_values: Sequence[torch.Tensor]) -> torch.Tensor: + """Convert a list of variable-length 1D tensors into a 2D tensor of shape `(batch_size, max_len)` + by left-padding each entry with zeros. Equivalent to `torch.nn.utils.rnn.pad_sequence(past_values, + batch_first=True, padding_side="left")`, re-implemented here because `padding_side` requires + `torch>=2.5`. + """ + max_len = max(ts.shape[0] for ts in past_values) + return torch.stack([F.pad(ts, (max_len - ts.shape[0], 0)) for ts in past_values], dim=0) + + @staticmethod + def _timesfm2_5_moving_average(arr: torch.Tensor, window_size: int) -> tuple[torch.Tensor, torch.Tensor]: + """Calculates the moving average using PyTorch's convolution function. `arr` shape: `(B, T)`.""" arr_padded = F.pad(arr, (window_size - 1, 0), "constant", 0) - # Create a convolution kernel kernel = torch.ones(window_size, dtype=arr.dtype, device=arr.device) / window_size - # Apply convolution to calculate the moving average - smoothed_arr = F.conv1d(arr_padded.view(1, 1, -1), kernel.view(1, 1, -1)).squeeze() - return [smoothed_arr, arr - smoothed_arr] + kernel = kernel.view(1, 1, -1) + smoothed_arr = F.conv1d(arr_padded.unsqueeze(1), kernel).squeeze(1) + return smoothed_arr, arr - smoothed_arr def _decode_and_project( self, diff --git a/src/transformers/models/timesfm2_5/modular_timesfm2_5.py b/src/transformers/models/timesfm2_5/modular_timesfm2_5.py index 3a912d07946b..762366d291c5 100644 --- a/src/transformers/models/timesfm2_5/modular_timesfm2_5.py +++ b/src/transformers/models/timesfm2_5/modular_timesfm2_5.py @@ -13,6 +13,7 @@ # limitations under the License. import math +import warnings from collections.abc import Callable, Sequence from dataclasses import dataclass @@ -387,12 +388,15 @@ def forward( """ batch_size, seq_len = past_values.shape patch_len = self.config.patch_length + torch._check(seq_len % patch_len == 0) if past_values_padding is None: past_values_padding = torch.zeros_like(past_values, dtype=torch.long) + else: + past_values_padding = past_values_padding.narrow(1, 0, seq_len) - patched_inputs = past_values.view(batch_size, -1, patch_len) - patched_masks = past_values_padding[:, :seq_len].view(batch_size, -1, patch_len) + patched_inputs = past_values.unflatten(-1, (-1, patch_len)) + patched_masks = past_values_padding.unflatten(-1, (-1, patch_len)) patched_masks_bool = patched_masks >= 0.5 count = past_values.new_zeros(batch_size) @@ -532,7 +536,8 @@ def _decode_and_project( @auto_docstring def forward( self, - past_values: Sequence[torch.Tensor], + past_values: Sequence[torch.Tensor] | torch.Tensor, + past_observed_mask: torch.Tensor | None = None, window_size: int | None = None, future_values: torch.Tensor | None = None, forecast_context_len: int | None = None, @@ -541,8 +546,20 @@ def forward( **kwargs: Unpack[TransformersKwargs], ) -> TimesFm2_5OutputForPrediction: r""" - past_values (`Sequence[torch.Tensor]`): - Past values of the time series that serves as input to the model. Each tensor is a 1D time series. + past_values (`torch.Tensor` of shape `(batch_size, sequence_length)`): + Past values of the time series that serves as input to the model. A list of 1D tensors with + possibly differing lengths is also accepted (deprecated): each tensor is front-padded with zeros + and stacked into a 2D tensor. Tensor inputs are preferred and required for export. + past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*): + Boolean mask indicating which `past_values` were observed and which are padding/missing. Mask + values selected in `[0, 1]`: + + - `1` for values that are **observed**, + - `0` for values that are **missing** (i.e. padded or NaNs that were replaced by zeros). + + Defaults to a tensor of ones (everything observed). When `past_values` is passed as a list of + variable-length tensors, you should provide a matching `past_observed_mask` so the front-padding + zeros are not treated as observed values. window_size (`int`, *optional*): Window size of trend + residual decomposition. If `None`, decomposition is not applied. future_values (`torch.Tensor`, *optional*): @@ -556,23 +573,36 @@ def forward( `config.force_flip_invariance`. """ forecast_context_len = forecast_context_len or self.context_len - device = past_values[0].device - inputs = [ts[-forecast_context_len:] for ts in past_values] - input_min = torch.min(torch.stack([torch.min(ts) for ts in inputs])) + if isinstance(past_values, list): + warnings.warn( + "Passing `past_values` as a list of 1D tensors is deprecated and will be removed in a future " + "version. Please pass a 2D `torch.Tensor` of shape `(batch_size, sequence_length)` and, when " + "needed, a `past_observed_mask` of the same shape (1 = observed, 0 = padded/missing).", + FutureWarning, + ) + past_values = self._past_values_to_tensor(past_values) + + device = past_values.device + if past_observed_mask is None: + past_observed_mask = torch.ones_like(past_values) + + inputs = past_values[:, -forecast_context_len:] + observed_mask = past_observed_mask[:, -forecast_context_len:].to(device=device) + sentinel = torch.full_like(inputs, torch.finfo(inputs.dtype).max) + input_min = torch.where(observed_mask.bool(), inputs, sentinel).min() if window_size is not None: - new_inputs: list[torch.Tensor] = [] - for ts in inputs: - new_inputs.extend(self._timesfm_moving_average(ts, window_size)) - inputs = new_inputs + trend, residual = self._timesfm2_5_moving_average(inputs, window_size) + inputs = torch.stack([trend, residual], dim=1).view(2 * inputs.shape[0], -1) + observed_mask = torch.repeat_interleave(observed_mask, 2, dim=0) if truncate_negative is None: truncate_negative = self.config.infer_is_positive if force_flip_invariance is None: force_flip_invariance = self.config.force_flip_invariance - input_ts, input_padding = self._preprocess(inputs, context_len=forecast_context_len) + input_ts, input_padding = self._preprocess(inputs, observed_mask, context_len=forecast_context_len) input_ts = input_ts.to(device) input_padding = input_padding.to(device) diff --git a/tests/models/timesfm/test_modeling_timesfm.py b/tests/models/timesfm/test_modeling_timesfm.py index 31ed60ce9d5c..e63fb7d86ed8 100644 --- a/tests/models/timesfm/test_modeling_timesfm.py +++ b/tests/models/timesfm/test_modeling_timesfm.py @@ -209,22 +209,49 @@ def test_model_main_input_name(self): observed_main_input_name = list(model_signature.parameters.keys())[1] self.assertEqual(TimesFmModelForPrediction.main_input_name, observed_main_input_name) + def test_past_values_to_tensor_left_pads_and_stacks(self): + past_values = [ + torch.tensor([1.0, 2.0, 3.0]), + torch.tensor([4.0]), + torch.tensor([5.0, 6.0]), + ] + expected = torch.tensor( + [ + [1.0, 2.0, 3.0], + [0.0, 0.0, 4.0], + [0.0, 5.0, 6.0], + ] + ) + + out = TimesFmModelForPrediction._past_values_to_tensor(past_values) + + self.assertEqual(out.shape, (3, 3)) + self.assertEqual(out.dtype, past_values[0].dtype) + self.assertTrue(torch.equal(out, expected)) + @require_torch @slow class TimesFmModelIntegrationTests(unittest.TestCase): def test_inference(self): model = TimesFmModelForPrediction.from_pretrained("google/timesfm-2.0-500m-pytorch").to(torch_device) - forecast_input = [ - np.sin(np.linspace(0, 20, 100)), - np.sin(np.linspace(0, 20, 200)), - np.sin(np.linspace(0, 20, 400)), + sequences = [ + torch.sin(torch.linspace(0, 20, 100, dtype=torch.float32, device=torch_device)), + torch.sin(torch.linspace(0, 20, 200, dtype=torch.float32, device=torch_device)), + torch.sin(torch.linspace(0, 20, 400, dtype=torch.float32, device=torch_device)), ] - forecast_input_tensor = [torch.tensor(ts, dtype=torch.float32, device=torch_device) for ts in forecast_input] - frequency_input = [0, 1, 2] + past_values = TimesFmModelForPrediction._past_values_to_tensor(sequences) + past_observed_mask = torch.zeros_like(past_values, dtype=torch.long) + for i, ts in enumerate(sequences): + past_observed_mask[i, past_values.shape[1] - ts.shape[0] :] = 1 + frequency_input = torch.tensor([0, 1, 2], dtype=torch.long, device=torch_device) with torch.no_grad(): - output = model(past_values=forecast_input_tensor, freq=frequency_input) + output = model( + past_values=past_values, + past_observed_mask=past_observed_mask, + freq=frequency_input, + ) mean_predictions = output.mean_predictions self.assertEqual(mean_predictions.shape, torch.Size([3, model.config.horizon_length])) diff --git a/tests/models/timesfm2_5/test_modeling_timesfm2_5.py b/tests/models/timesfm2_5/test_modeling_timesfm2_5.py index 7a909da6d78c..ebe4c9f2adcb 100644 --- a/tests/models/timesfm2_5/test_modeling_timesfm2_5.py +++ b/tests/models/timesfm2_5/test_modeling_timesfm2_5.py @@ -290,15 +290,18 @@ def test_inference(self): model = TimesFm2_5ModelForPrediction.from_pretrained( "google/timesfm-2.5-200m-transformers", revision="refs/pr/3" ).to(torch_device) - forecast_input = [ - np.sin(np.linspace(0, 20, 100)), - np.sin(np.linspace(0, 20, 200)), - np.sin(np.linspace(0, 20, 400)), + sequences = [ + torch.sin(torch.linspace(0, 20, 100, dtype=torch.float32, device=torch_device)), + torch.sin(torch.linspace(0, 20, 200, dtype=torch.float32, device=torch_device)), + torch.sin(torch.linspace(0, 20, 400, dtype=torch.float32, device=torch_device)), ] - forecast_input_tensor = [torch.tensor(ts, dtype=torch.float32, device=torch_device) for ts in forecast_input] + past_values = TimesFm2_5ModelForPrediction._past_values_to_tensor(sequences) + past_observed_mask = torch.zeros_like(past_values, dtype=torch.long) + for i, ts in enumerate(sequences): + past_observed_mask[i, past_values.shape[1] - ts.shape[0] :] = 1 with torch.no_grad(): - output = model(past_values=forecast_input_tensor) + output = model(past_values=past_values, past_observed_mask=past_observed_mask) mean_predictions = output.mean_predictions self.assertEqual(mean_predictions.shape, torch.Size([3, model.config.horizon_length]))