Skip to content

Commit 2490610

Browse files
committed
added new tests
1 parent 06d77a6 commit 2490610

7 files changed

Lines changed: 813 additions & 1411 deletions

File tree

examples/flowstate_simple_example.ipynb

Lines changed: 0 additions & 716 deletions
This file was deleted.

examples/model_comparison_example.ipynb

Lines changed: 0 additions & 687 deletions
This file was deleted.

faim_sdk/client.py

Lines changed: 172 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,11 @@
77
import io
88
import json
99
import logging
10+
import warnings
11+
from copy import copy
1012

1113
import httpx
14+
import numpy as np
1215

1316
from faim_client import AuthenticatedClient, Client
1417
from faim_client.api.forecast import forecast_v1_ts_forecast_model_name_model_version_post
@@ -61,6 +64,148 @@ def _parse_error_response(response) -> ErrorResponse | None:
6164
return None
6265

6366

67+
def _needs_univariate_transformation(request: ForecastRequest) -> bool:
68+
"""Check if request requires univariate transformation.
69+
70+
FlowState and TiRex models only support univariate forecasting.
71+
When they receive multivariate input (features > 1), the input
72+
must be transformed to forecast each feature independently.
73+
74+
Args:
75+
request: Forecast request to check
76+
77+
Returns:
78+
True if transformation is needed, False otherwise
79+
"""
80+
# Only FlowState and TiRex require transformation
81+
if request.model_name not in (ModelName.FLOWSTATE, ModelName.TIREX):
82+
return False
83+
84+
# Check if input is multivariate (features > 1)
85+
num_features = request.x.shape[2] # Shape is (batch, seq_len, features)
86+
return num_features > 1
87+
88+
89+
def _prepare_univariate_request(request: ForecastRequest) -> tuple[ForecastRequest, tuple[int, int]]:
90+
"""Prepare request for univariate-only models with multivariate input.
91+
92+
Transforms input from (batch, seq_len, features) to (batch*features, seq_len, 1)
93+
and issues a warning to the user that features will be forecast independently.
94+
95+
Args:
96+
request: Original forecast request with multivariate input
97+
98+
Returns:
99+
Tuple of (modified request, (original_batch_size, num_features))
100+
101+
Example:
102+
Input shape: (batch=2, seq_len=100, features=3)
103+
Output shape: (batch=6, seq_len=100, features=1)
104+
Mapping:
105+
- Feature 0 of series 0 → new batch index 0
106+
- Feature 1 of series 0 → new batch index 1
107+
- Feature 2 of series 0 → new batch index 2
108+
- Feature 0 of series 1 → new batch index 3
109+
- Feature 1 of series 1 → new batch index 4
110+
- Feature 2 of series 1 → new batch index 5
111+
"""
112+
original_batch_size, seq_len, num_features = request.x.shape
113+
114+
# Issue user warning
115+
warnings.warn(
116+
f"{request.model_name.value.title()} model only supports univariate forecasting. "
117+
f"Input with {num_features} features will be forecast independently. "
118+
f"Each feature will be treated as a separate time series.",
119+
UserWarning,
120+
stacklevel=3, # Point to the user's code, not this internal function
121+
)
122+
123+
logger.info(
124+
f"Transforming multivariate input for {request.model_name.value}: "
125+
f"shape {request.x.shape} → ({original_batch_size * num_features}, {seq_len}, 1)"
126+
)
127+
128+
# Reshape: (batch, seq_len, features) → (batch, features, seq_len) → (batch*features, seq_len) → (batch*features, seq_len, 1)
129+
# We want to interleave features across the batch dimension
130+
x_transposed = request.x.transpose(0, 2, 1) # (batch, features, seq_len)
131+
x_flattened = x_transposed.reshape(original_batch_size * num_features, seq_len) # (batch*features, seq_len)
132+
x_univariate = x_flattened[:, :, np.newaxis] # (batch*features, seq_len, 1)
133+
134+
# Create modified request with reshaped x
135+
# Use copy to avoid modifying the original request
136+
modified_request = copy(request)
137+
modified_request.x = x_univariate
138+
139+
return modified_request, (original_batch_size, num_features)
140+
141+
142+
def _reshape_univariate_response(
143+
response: ForecastResponse,
144+
original_batch_size: int,
145+
num_features: int,
146+
) -> ForecastResponse:
147+
"""Reshape response from univariate transformation back to multivariate format.
148+
149+
Reverses the transformation done by _prepare_univariate_request() to restore
150+
the original batch and feature dimensions.
151+
152+
Args:
153+
response: Response from server with univariate format
154+
original_batch_size: Original batch size before transformation
155+
num_features: Number of features in original input
156+
157+
Returns:
158+
Response with proper multivariate shape
159+
160+
Example:
161+
Point forecast:
162+
Input shape: (batch*features=6, horizon=24, features=1)
163+
Output shape: (batch=2, horizon=24, features=3)
164+
165+
Quantile forecast:
166+
Input shape: (batch*features=6, horizon=24, quantiles=5, features=1)
167+
Output shape: (batch=2, horizon=24, quantiles=5, features=3)
168+
"""
169+
modified_response = ForecastResponse(metadata=response.metadata)
170+
171+
# Reshape point predictions if present
172+
if response.point is not None:
173+
# Input: (batch*features, horizon, 1)
174+
# Output: (batch, horizon, features)
175+
batch_times_features, horizon, _ = response.point.shape
176+
177+
# Reshape to (batch, features, horizon, 1)
178+
reshaped = response.point.reshape(original_batch_size, num_features, horizon, 1)
179+
# Transpose to (batch, horizon, features, 1)
180+
transposed = reshaped.transpose(0, 2, 1, 3) # (batch, horizon, features, 1)
181+
# Squeeze last dimension to get (batch, horizon, features)
182+
modified_response.point = transposed.squeeze(-1)
183+
184+
# Reshape quantile predictions if present
185+
if response.quantiles is not None:
186+
# Input: (batch*features, horizon, quantiles, 1)
187+
# Output: (batch, horizon, quantiles, features)
188+
batch_times_features, horizon, num_quantiles, _ = response.quantiles.shape
189+
190+
# Reshape to (batch, features, horizon, quantiles, 1)
191+
reshaped = response.quantiles.reshape(original_batch_size, num_features, horizon, num_quantiles, 1)
192+
# Transpose to (batch, horizon, quantiles, features, 1)
193+
transposed = reshaped.transpose(0, 2, 3, 1, 4) # (batch, horizon, quantiles, features, 1)
194+
# Squeeze last dimension to get (batch, horizon, quantiles, features)
195+
modified_response.quantiles = transposed.squeeze(-1)
196+
197+
# Samples - keep as is for now (not common for FlowState/TiRex)
198+
if response.samples is not None:
199+
modified_response.samples = response.samples
200+
201+
logger.debug(
202+
f"Reshaped univariate response: "
203+
f"original_batch={original_batch_size}, features={num_features}"
204+
)
205+
206+
return modified_response
207+
208+
64209
class ForecastClient:
65210
"""High-level client for FAIM time-series forecasting.
66211
@@ -86,8 +231,8 @@ class ForecastClient:
86231

87232
def __init__(
88233
self,
89-
base_url: str,
90-
timeout: float = 120.0,
234+
base_url: str = "https://api.faim.it.com",
235+
timeout: float = 60.0,
91236
verify_ssl: bool = True,
92237
api_key: str | None = None,
93238
**httpx_kwargs,
@@ -96,7 +241,7 @@ def __init__(
96241
97242
Args:
98243
base_url: Base URL of FAIM inference API
99-
timeout: Request timeout in seconds. Default: 120s
244+
timeout: Request timeout in seconds. Default: 60s
100245
verify_ssl: Whether to verify SSL certificates. Default: True
101246
api_key: Optional API key for authentication. If provided, all requests
102247
will include "Authorization: Bearer <api_key>" header. Default: None
@@ -174,6 +319,11 @@ def forecast(self, request: ForecastRequest) -> ForecastResponse:
174319
f"x.shape={request.x.shape}, horizon={request.horizon}"
175320
)
176321

322+
# Check if univariate transformation is needed
323+
transform_shape_info = None
324+
if _needs_univariate_transformation(request):
325+
request, transform_shape_info = _prepare_univariate_request(request)
326+
177327
# Serialize request to Arrow format
178328
try:
179329
arrays, metadata = request.to_arrays_and_metadata()
@@ -313,6 +463,13 @@ def forecast(self, request: ForecastRequest) -> ForecastResponse:
313463
arrays, metadata = deserialize_from_arrow(response_bytes)
314464
forecast_response = ForecastResponse.from_arrays_and_metadata(arrays, metadata)
315465

466+
# If univariate transformation was applied, reshape response back
467+
if transform_shape_info is not None:
468+
original_batch_size, num_features = transform_shape_info
469+
forecast_response = _reshape_univariate_response(
470+
forecast_response, original_batch_size, num_features
471+
)
472+
316473
logger.info(f"Forecast successful: {forecast_response}")
317474
return forecast_response
318475

@@ -348,6 +505,11 @@ async def forecast_async(self, request: ForecastRequest) -> ForecastResponse:
348505
model = request.model_name
349506
logger.debug(f"Starting async forecast: model={model}, version={request.model_version}")
350507

508+
# Check if univariate transformation is needed
509+
transform_shape_info = None
510+
if _needs_univariate_transformation(request):
511+
request, transform_shape_info = _prepare_univariate_request(request)
512+
351513
# Serialize request
352514
try:
353515
arrays, metadata = request.to_arrays_and_metadata()
@@ -487,6 +649,13 @@ async def forecast_async(self, request: ForecastRequest) -> ForecastResponse:
487649
arrays, metadata = deserialize_from_arrow(response_bytes)
488650
forecast_response = ForecastResponse.from_arrays_and_metadata(arrays, metadata)
489651

652+
# If univariate transformation was applied, reshape response back
653+
if transform_shape_info is not None:
654+
original_batch_size, num_features = transform_shape_info
655+
forecast_response = _reshape_univariate_response(
656+
forecast_response, original_batch_size, num_features
657+
)
658+
490659
logger.info(f"Async forecast successful: {forecast_response}")
491660
return forecast_response
492661

faim_sdk/models.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ class ForecastRequest:
2828
_model_name: ClassVar[ModelName]
2929

3030
x: np.ndarray
31-
"""Time series data. Shape: (batch_size, sequence_length, features) or (sequence_length, features)"""
31+
"""Time series data. Shape: (batch_size, sequence_length, features)"""
3232

3333
horizon: int
3434
"""Forecast horizon length (number of time steps to predict)"""
@@ -60,14 +60,21 @@ def __post_init__(self) -> None:
6060
6161
Raises:
6262
TypeError: If x is not a numpy ndarray
63-
ValueError: If x is empty or horizon is non-positive
63+
ValueError: If x is empty, not 3D, or horizon is non-positive
6464
"""
6565
if not isinstance(self.x, np.ndarray):
6666
raise TypeError(f"x must be numpy.ndarray, got {type(self.x).__name__}")
6767

6868
if self.x.size == 0:
6969
raise ValueError("x cannot be empty")
7070

71+
# Ensure x is 3D: (batch_size, sequence_length, features)
72+
if self.x.ndim != 3:
73+
raise ValueError(
74+
f"x must be a 3D array with shape (batch_size, sequence_length, features), "
75+
f"got shape {self.x.shape} with {self.x.ndim} dimensions"
76+
)
77+
7178
if self.horizon <= 0:
7279
raise ValueError(f"horizon must be positive, got {self.horizon}")
7380

@@ -165,7 +172,6 @@ def to_arrays_and_metadata(self) -> tuple[dict[str, np.ndarray], dict[str, Any]]
165172
"""
166173
arrays, metadata = super().to_arrays_and_metadata()
167174

168-
# Add TiRex-specific metadata
169175
metadata["output_type"] = self.output_type
170176

171177
return arrays, metadata
@@ -269,7 +275,7 @@ class ForecastResponse:
269275
"""Point predictions. Shape: (batch_size, horizon, features)"""
270276

271277
quantiles: np.ndarray | None = None
272-
"""Quantile predictions. Shape: (batch_size, horizon, num_quantiles)"""
278+
"""Quantile predictions. Shape: (batch_size, horizon, num_quantiles, features)"""
273279

274280
samples: np.ndarray | None = None
275281
"""Sample predictions. Shape: (batch_size, horizon, num_samples)"""

0 commit comments

Comments
 (0)