77import io
88import json
99import logging
10+ import warnings
11+ from copy import copy
1012
1113import httpx
14+ import numpy as np
1215
1316from faim_client import AuthenticatedClient , Client
1417from faim_client .api .forecast import forecast_v1_ts_forecast_model_name_model_version_post
@@ -61,6 +64,148 @@ def _parse_error_response(response) -> ErrorResponse | None:
6164 return None
6265
6366
67+ def _needs_univariate_transformation (request : ForecastRequest ) -> bool :
68+ """Check if request requires univariate transformation.
69+
70+ FlowState and TiRex models only support univariate forecasting.
71+ When they receive multivariate input (features > 1), the input
72+ must be transformed to forecast each feature independently.
73+
74+ Args:
75+ request: Forecast request to check
76+
77+ Returns:
78+ True if transformation is needed, False otherwise
79+ """
80+ # Only FlowState and TiRex require transformation
81+ if request .model_name not in (ModelName .FLOWSTATE , ModelName .TIREX ):
82+ return False
83+
84+ # Check if input is multivariate (features > 1)
85+ num_features = request .x .shape [2 ] # Shape is (batch, seq_len, features)
86+ return num_features > 1
87+
88+
89+ def _prepare_univariate_request (request : ForecastRequest ) -> tuple [ForecastRequest , tuple [int , int ]]:
90+ """Prepare request for univariate-only models with multivariate input.
91+
92+ Transforms input from (batch, seq_len, features) to (batch*features, seq_len, 1)
93+ and issues a warning to the user that features will be forecast independently.
94+
95+ Args:
96+ request: Original forecast request with multivariate input
97+
98+ Returns:
99+ Tuple of (modified request, (original_batch_size, num_features))
100+
101+ Example:
102+ Input shape: (batch=2, seq_len=100, features=3)
103+ Output shape: (batch=6, seq_len=100, features=1)
104+ Mapping:
105+ - Feature 0 of series 0 → new batch index 0
106+ - Feature 1 of series 0 → new batch index 1
107+ - Feature 2 of series 0 → new batch index 2
108+ - Feature 0 of series 1 → new batch index 3
109+ - Feature 1 of series 1 → new batch index 4
110+ - Feature 2 of series 1 → new batch index 5
111+ """
112+ original_batch_size , seq_len , num_features = request .x .shape
113+
114+ # Issue user warning
115+ warnings .warn (
116+ f"{ request .model_name .value .title ()} model only supports univariate forecasting. "
117+ f"Input with { num_features } features will be forecast independently. "
118+ f"Each feature will be treated as a separate time series." ,
119+ UserWarning ,
120+ stacklevel = 3 , # Point to the user's code, not this internal function
121+ )
122+
123+ logger .info (
124+ f"Transforming multivariate input for { request .model_name .value } : "
125+ f"shape { request .x .shape } → ({ original_batch_size * num_features } , { seq_len } , 1)"
126+ )
127+
128+ # Reshape: (batch, seq_len, features) → (batch, features, seq_len) → (batch*features, seq_len) → (batch*features, seq_len, 1)
129+ # We want to interleave features across the batch dimension
130+ x_transposed = request .x .transpose (0 , 2 , 1 ) # (batch, features, seq_len)
131+ x_flattened = x_transposed .reshape (original_batch_size * num_features , seq_len ) # (batch*features, seq_len)
132+ x_univariate = x_flattened [:, :, np .newaxis ] # (batch*features, seq_len, 1)
133+
134+ # Create modified request with reshaped x
135+ # Use copy to avoid modifying the original request
136+ modified_request = copy (request )
137+ modified_request .x = x_univariate
138+
139+ return modified_request , (original_batch_size , num_features )
140+
141+
142+ def _reshape_univariate_response (
143+ response : ForecastResponse ,
144+ original_batch_size : int ,
145+ num_features : int ,
146+ ) -> ForecastResponse :
147+ """Reshape response from univariate transformation back to multivariate format.
148+
149+ Reverses the transformation done by _prepare_univariate_request() to restore
150+ the original batch and feature dimensions.
151+
152+ Args:
153+ response: Response from server with univariate format
154+ original_batch_size: Original batch size before transformation
155+ num_features: Number of features in original input
156+
157+ Returns:
158+ Response with proper multivariate shape
159+
160+ Example:
161+ Point forecast:
162+ Input shape: (batch*features=6, horizon=24, features=1)
163+ Output shape: (batch=2, horizon=24, features=3)
164+
165+ Quantile forecast:
166+ Input shape: (batch*features=6, horizon=24, quantiles=5, features=1)
167+ Output shape: (batch=2, horizon=24, quantiles=5, features=3)
168+ """
169+ modified_response = ForecastResponse (metadata = response .metadata )
170+
171+ # Reshape point predictions if present
172+ if response .point is not None :
173+ # Input: (batch*features, horizon, 1)
174+ # Output: (batch, horizon, features)
175+ batch_times_features , horizon , _ = response .point .shape
176+
177+ # Reshape to (batch, features, horizon, 1)
178+ reshaped = response .point .reshape (original_batch_size , num_features , horizon , 1 )
179+ # Transpose to (batch, horizon, features, 1)
180+ transposed = reshaped .transpose (0 , 2 , 1 , 3 ) # (batch, horizon, features, 1)
181+ # Squeeze last dimension to get (batch, horizon, features)
182+ modified_response .point = transposed .squeeze (- 1 )
183+
184+ # Reshape quantile predictions if present
185+ if response .quantiles is not None :
186+ # Input: (batch*features, horizon, quantiles, 1)
187+ # Output: (batch, horizon, quantiles, features)
188+ batch_times_features , horizon , num_quantiles , _ = response .quantiles .shape
189+
190+ # Reshape to (batch, features, horizon, quantiles, 1)
191+ reshaped = response .quantiles .reshape (original_batch_size , num_features , horizon , num_quantiles , 1 )
192+ # Transpose to (batch, horizon, quantiles, features, 1)
193+ transposed = reshaped .transpose (0 , 2 , 3 , 1 , 4 ) # (batch, horizon, quantiles, features, 1)
194+ # Squeeze last dimension to get (batch, horizon, quantiles, features)
195+ modified_response .quantiles = transposed .squeeze (- 1 )
196+
197+ # Samples - keep as is for now (not common for FlowState/TiRex)
198+ if response .samples is not None :
199+ modified_response .samples = response .samples
200+
201+ logger .debug (
202+ f"Reshaped univariate response: "
203+ f"original_batch={ original_batch_size } , features={ num_features } "
204+ )
205+
206+ return modified_response
207+
208+
64209class ForecastClient :
65210 """High-level client for FAIM time-series forecasting.
66211
@@ -86,8 +231,8 @@ class ForecastClient:
86231
87232 def __init__ (
88233 self ,
89- base_url : str ,
90- timeout : float = 120 .0 ,
234+ base_url : str = "https://api.faim.it.com" ,
235+ timeout : float = 60 .0 ,
91236 verify_ssl : bool = True ,
92237 api_key : str | None = None ,
93238 ** httpx_kwargs ,
@@ -96,7 +241,7 @@ def __init__(
96241
97242 Args:
98243 base_url: Base URL of FAIM inference API
99- timeout: Request timeout in seconds. Default: 120s
244+ timeout: Request timeout in seconds. Default: 60s
100245 verify_ssl: Whether to verify SSL certificates. Default: True
101246 api_key: Optional API key for authentication. If provided, all requests
102247 will include "Authorization: Bearer <api_key>" header. Default: None
@@ -174,6 +319,11 @@ def forecast(self, request: ForecastRequest) -> ForecastResponse:
174319 f"x.shape={ request .x .shape } , horizon={ request .horizon } "
175320 )
176321
322+ # Check if univariate transformation is needed
323+ transform_shape_info = None
324+ if _needs_univariate_transformation (request ):
325+ request , transform_shape_info = _prepare_univariate_request (request )
326+
177327 # Serialize request to Arrow format
178328 try :
179329 arrays , metadata = request .to_arrays_and_metadata ()
@@ -313,6 +463,13 @@ def forecast(self, request: ForecastRequest) -> ForecastResponse:
313463 arrays , metadata = deserialize_from_arrow (response_bytes )
314464 forecast_response = ForecastResponse .from_arrays_and_metadata (arrays , metadata )
315465
466+ # If univariate transformation was applied, reshape response back
467+ if transform_shape_info is not None :
468+ original_batch_size , num_features = transform_shape_info
469+ forecast_response = _reshape_univariate_response (
470+ forecast_response , original_batch_size , num_features
471+ )
472+
316473 logger .info (f"Forecast successful: { forecast_response } " )
317474 return forecast_response
318475
@@ -348,6 +505,11 @@ async def forecast_async(self, request: ForecastRequest) -> ForecastResponse:
348505 model = request .model_name
349506 logger .debug (f"Starting async forecast: model={ model } , version={ request .model_version } " )
350507
508+ # Check if univariate transformation is needed
509+ transform_shape_info = None
510+ if _needs_univariate_transformation (request ):
511+ request , transform_shape_info = _prepare_univariate_request (request )
512+
351513 # Serialize request
352514 try :
353515 arrays , metadata = request .to_arrays_and_metadata ()
@@ -487,6 +649,13 @@ async def forecast_async(self, request: ForecastRequest) -> ForecastResponse:
487649 arrays , metadata = deserialize_from_arrow (response_bytes )
488650 forecast_response = ForecastResponse .from_arrays_and_metadata (arrays , metadata )
489651
652+ # If univariate transformation was applied, reshape response back
653+ if transform_shape_info is not None :
654+ original_batch_size , num_features = transform_shape_info
655+ forecast_response = _reshape_univariate_response (
656+ forecast_response , original_batch_size , num_features
657+ )
658+
490659 logger .info (f"Async forecast successful: { forecast_response } " )
491660 return forecast_response
492661
0 commit comments