Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 88 additions & 51 deletions src/transformers/models/timesfm/modeling_timesfm.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
# limitations under the License.

import math
import warnings
from collections.abc import Callable, Sequence
from dataclasses import dataclass

Expand Down Expand Up @@ -381,8 +382,10 @@ def forward(
Past values of the time series that serves as input to the model.
past_values_padding (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
The padding indicator of the time series.
freq (`torch.LongTensor` of shape `(batch_size,)`):
Frequency indices for the time series data.
freq (`torch.LongTensor` of shape `(batch_size,)` or `Sequence[int]`, *optional*):
Frequency indices for the time series data. Defaults to a zero tensor (high-frequency). A
sequence of ints is also accepted and converted to a tensor internally. Tensor inputs are
preferred and required for export.
"""
# Reshape into patches (using view for efficiency)
bsize = past_values.shape[0]
Expand Down Expand Up @@ -590,13 +593,20 @@ def __init__(self, config: TimesFmConfig):
self.post_init()

def _preprocess(
self, inputs: Sequence[torch.Tensor], freq: Sequence[int] | None = None, context_len: int | None = None
self,
inputs: torch.Tensor,
observed_mask: torch.Tensor,
freq: torch.Tensor | None = None,
context_len: int | None = None,
) -> tuple[torch.Tensor, ...]:
"""Pad/truncate input time series to `context_len` and build a padding mask.

Args:
inputs: A list of 1d Tensors. Each Tensor is the context time series of a single forecast task.
freq: Optional list of frequencies (returned as a tensor when provided).
inputs: A 2D `torch.Tensor` of shape `(batch_size, sequence_length)`.
observed_mask: A 2D `torch.Tensor` of the same shape as `inputs` where `1` indicates an observed
value and `0` indicates a missing value. Missing positions are marked as padded in the
returned padding mask.
freq: Optional 1D `torch.Tensor` of frequency indices.
context_len: Optional context length override (defaults to `self.context_len`).

Returns:
Expand All @@ -605,25 +615,20 @@ def _preprocess(
if context_len is None:
context_len = self.context_len

input_ts, input_padding = [], []

for ts in inputs:
input_len = ts.shape[0]
padding = torch.zeros(input_len + self.horizon_len, dtype=ts.dtype, device=ts.device)
if input_len < context_len:
num_front_pad = context_len - input_len
ts = torch.cat([torch.zeros(num_front_pad, dtype=ts.dtype, device=ts.device), ts], dim=0)
padding = torch.cat([torch.ones(num_front_pad, dtype=ts.dtype, device=padding.device), padding], dim=0)
elif input_len > context_len:
ts = ts[-context_len:]
padding = padding[-(context_len + self.horizon_len) :]
x = inputs[:, -context_len:]
num_front_pad = context_len - x.shape[1]
x = F.pad(x, (num_front_pad, 0))

input_ts.append(ts)
input_padding.append(padding)
obs = observed_mask[:, -context_len:].to(dtype=x.dtype)
body_padding = 1 - obs
front_padding = torch.ones(x.shape[0], num_front_pad, dtype=x.dtype, device=x.device)
horizon_padding = torch.zeros(x.shape[0], self.horizon_len, dtype=x.dtype, device=x.device)
padding = torch.cat([front_padding, body_padding, horizon_padding], dim=1)
result = (x, padding)

result = (torch.stack(input_ts, dim=0), torch.stack(input_padding, dim=0))
if freq is not None:
result = result + (torch.tensor(freq[: len(inputs)], dtype=torch.int32).reshape(-1, 1),)
freq_tensor = freq[: x.shape[0]].to(dtype=torch.int32)
result = result + (freq_tensor.reshape(-1, 1),)
return result

def _postprocess_output(
Expand Down Expand Up @@ -653,8 +658,9 @@ def _quantile_loss(self, predictions: torch.Tensor, targets: torch.Tensor) -> to
@auto_docstring
def forward(
self,
past_values: Sequence[torch.Tensor],
freq: Sequence[torch.Tensor | int] | None = None,
past_values: Sequence[torch.Tensor] | torch.Tensor,
past_observed_mask: torch.Tensor | None = None,
freq: Sequence[int] | torch.Tensor | None = None,
window_size: int | None = None,
future_values: torch.Tensor | None = None,
forecast_context_len: int | None = None,
Expand All @@ -664,9 +670,23 @@ def forward(
) -> TimesFmOutputForPrediction:
r"""
past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
Past values of the time series that serves as input to the model.
freq (`torch.LongTensor` of shape `(batch_size,)`):
Frequency indices for the time series data.
Past values of the time series that serves as input to the model. A list of 1D tensors with
possibly differing lengths is also accepted (deprecated): each tensor is front-padded with zeros
and stacked into a 2D tensor. Tensor inputs are preferred and required for export.
past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
Boolean mask indicating which `past_values` were observed and which are padding/missing. Mask
values selected in `[0, 1]`:

- `1` for values that are **observed**,
- `0` for values that are **missing** (i.e. padded or NaNs that were replaced by zeros).

Defaults to a tensor of ones (everything observed). When `past_values` is passed as a list of
variable-length tensors, you should provide a matching `past_observed_mask` so the front-padding
zeros are not treated as observed values.
freq (`torch.LongTensor` of shape `(batch_size,)` or `Sequence[int]`, *optional*):
Frequency indices for the time series data. Defaults to a zero tensor (high-frequency). A
sequence of ints is also accepted and converted to a tensor internally. Tensor inputs are
preferred and required for export.
window_size (`int`, *optional*):
Window size of trend + residual decomposition. If None then we do not do decomposition.
future_values (`torch.Tensor`, *optional*):
Expand All @@ -686,7 +706,7 @@ def forward(

>>> model = TimesFmModelForPrediction.from_pretrained("google/timesfm-2.0-500m-pytorch")

>>> forecast_input = [torch.linspace(0, 20, 100).sin(), torch.linspace(0, 20, 200).sin(), torch.linspace(0, 20, 400).sin()]
>>> forecast_input = torch.stack([torch.linspace(0, 20, 400).sin() for _ in range(3)])
>>> frequency_input = torch.tensor([0, 1, 2], dtype=torch.long)

>>> # Generate
Expand All @@ -701,27 +721,36 @@ def forward(
else:
fcontext_len = forecast_context_len

device = past_values[0].device

inputs = [ts[-fcontext_len:] for ts in past_values]
inp_min = torch.min(torch.stack([torch.min(ts) for ts in inputs]))

if window_size is not None:
new_inputs = []
new_freqs = []
for i, ts in enumerate(inputs):
new_inputs.extend(self._timesfm_moving_average(ts, window_size))
if freq is not None:
new_freqs.extend([freq[i]] * 2)
inputs = new_inputs
if freq is not None:
freq = new_freqs
if isinstance(past_values, list):
warnings.warn(
"Passing `past_values` as a list of 1D tensors is deprecated and will be removed in a future "
"version. Please pass a 2D `torch.Tensor` of shape `(batch_size, sequence_length)` and, when "
"needed, a `past_observed_mask` of the same shape (1 = observed, 0 = padded/missing).",
FutureWarning,
)
past_values = self._past_values_to_tensor(past_values)

device = past_values.device
if past_observed_mask is None:
past_observed_mask = torch.ones_like(past_values)
if freq is None:
logger.info("No frequency provided via `freq`. Default to high (0).")
freq = [0] * len(inputs)
freq = torch.zeros(past_values.shape[0], dtype=torch.int32, device=device)
else:
freq = torch.as_tensor(freq, dtype=torch.int32, device=device)

inputs = past_values[:, -fcontext_len:]
observed_mask = past_observed_mask[:, -fcontext_len:].to(device=device)
sentinel = torch.full_like(inputs, torch.finfo(inputs.dtype).max)
inp_min = torch.where(observed_mask.bool(), inputs, sentinel).min()

if window_size is not None:
trend, residual = self._timesfm_moving_average(inputs, window_size)
inputs = torch.stack([trend, residual], dim=1).view(2 * inputs.shape[0], -1)
observed_mask = torch.repeat_interleave(observed_mask, 2, dim=0)
freq = torch.repeat_interleave(freq, 2, dim=0)

input_ts, input_padding, inp_freq = self._preprocess(inputs, freq)
input_ts, input_padding, inp_freq = self._preprocess(inputs, observed_mask, freq=freq)
input_ts = input_ts.to(device)
input_padding = input_padding.to(device)
inp_freq = inp_freq.to(device)
Expand Down Expand Up @@ -794,15 +823,23 @@ def forward(
)

@staticmethod
def _timesfm_moving_average(arr: torch.Tensor, window_size: int) -> list[torch.Tensor]:
"""Calculates the moving average using PyTorch's convolution function."""
# Pad with zeros to handle initial window positions
def _past_values_to_tensor(past_values: Sequence[torch.Tensor]) -> torch.Tensor:
"""Convert a list of variable-length 1D tensors into a 2D tensor of shape `(batch_size, max_len)`
by left-padding each entry with zeros. Equivalent to `torch.nn.utils.rnn.pad_sequence(past_values,
batch_first=True, padding_side="left")`, re-implemented here because `padding_side` requires
`torch>=2.5`.
"""
max_len = max(ts.shape[0] for ts in past_values)
return torch.stack([F.pad(ts, (max_len - ts.shape[0], 0)) for ts in past_values], dim=0)

@staticmethod
def _timesfm_moving_average(arr: torch.Tensor, window_size: int) -> tuple[torch.Tensor, torch.Tensor]:
"""Calculates the moving average using PyTorch's convolution function. `arr` shape: `(B, T)`."""
arr_padded = F.pad(arr, (window_size - 1, 0), "constant", 0)
# Create a convolution kernel
kernel = torch.ones(window_size, dtype=arr.dtype, device=arr.device) / window_size
# Apply convolution to calculate the moving average
smoothed_arr = F.conv1d(arr_padded.view(1, 1, -1), kernel.view(1, 1, -1)).squeeze()
return [smoothed_arr, arr - smoothed_arr]
kernel = kernel.view(1, 1, -1)
smoothed_arr = F.conv1d(arr_padded.unsqueeze(1), kernel).squeeze(1)
return smoothed_arr, arr - smoothed_arr


__all__ = ["TimesFmModelForPrediction", "TimesFmPreTrainedModel", "TimesFmModel"]
Loading
Loading