Skip to content

Commit cb6d273

Browse files
committed
added eval
1 parent 4bf60e5 commit cb6d273

10 files changed

Lines changed: 5110 additions & 88 deletions

README.md

Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ Production-ready Python SDK for FAIM (Foundation AI Models) - a high-performance
1515
- **🔄 Async Support**: Built-in async/await support for concurrent requests
1616
- **📊 Rich Error Handling**: Machine-readable error codes with detailed diagnostics
1717
- **🧪 Battle-Tested**: Production-ready with comprehensive error handling
18+
- **📈 Evaluation Tools**: Built-in metrics (MSE, MASE, CRPS) and visualization utilities
1819

1920
## Installation
2021

@@ -140,6 +141,173 @@ print(response.metadata)
140141
# {'model_name': 'chronos2', 'model_version': '1.0', 'inference_time_ms': 123}
141142
```
142143

144+
## Evaluation & Metrics
145+
146+
The SDK includes a comprehensive evaluation toolkit (`faim_sdk.eval`) for measuring forecast quality with standard metrics and visualizations.
147+
148+
### Installation
149+
150+
For visualization support, install with the viz extra:
151+
152+
```bash
153+
pip install faim-sdk[viz]
154+
```
155+
156+
### Available Metrics
157+
158+
#### Mean Squared Error (MSE)
159+
160+
Measures average squared difference between predictions and ground truth.
161+
162+
```python
163+
from faim_sdk.eval import mse
164+
165+
# Evaluate point forecast
166+
mse_score = mse(test_data, response.point, reduction='mean')
167+
print(f"MSE: {mse_score:.4f}")
168+
169+
# Per-sample MSE
170+
mse_per_sample = mse(test_data, response.point, reduction='none')
171+
print(f"MSE per sample shape: {mse_per_sample.shape}") # (batch_size,)
172+
```
173+
174+
#### Mean Absolute Scaled Error (MASE)
175+
176+
Scale-independent metric comparing forecast to naive baseline (better than MAPE for series with zeros).
177+
178+
```python
179+
from faim_sdk.eval import mase
180+
181+
# MASE requires training data for baseline
182+
mase_score = mase(test_data, response.point, train_data, reduction='mean')
183+
print(f"MASE: {mase_score:.4f}")
184+
185+
# Interpretation:
186+
# MASE < 1: Better than naive baseline
187+
# MASE = 1: Equivalent to naive baseline
188+
# MASE > 1: Worse than naive baseline
189+
```
190+
191+
#### Continuous Ranked Probability Score (CRPS)
192+
193+
Proper scoring rule for probabilistic forecasts - generalizes MAE to distributions.
194+
195+
```python
196+
from faim_sdk.eval import crps_from_quantiles
197+
198+
# Evaluate probabilistic forecast with quantiles
199+
crps_score = crps_from_quantiles(
200+
test_data,
201+
response.quantiles,
202+
quantile_levels=[0.1, 0.5, 0.9],
203+
reduction='mean'
204+
)
205+
print(f"CRPS: {crps_score:.4f}")
206+
```
207+
208+
### Visualization
209+
210+
Plot forecasts with training context and ground truth:
211+
212+
```python
213+
from faim_sdk.eval import plot_forecast
214+
215+
# Plot single sample (remember to index batch dimension!)
216+
fig, ax = plot_forecast(
217+
train_data=train_data[0], # (seq_len, features) - 2D array
218+
forecast=response.point[0], # (horizon, features) - 2D array
219+
test_data=test_data[0], # (horizon, features) - optional
220+
title="Time Series Forecast"
221+
)
222+
223+
# Save to file
224+
fig.savefig("forecast.png", dpi=300, bbox_inches="tight")
225+
```
226+
227+
#### Multi-Feature Visualization
228+
229+
```python
230+
# Option 1: All features on same plot
231+
fig, ax = plot_forecast(
232+
train_data[0],
233+
response.point[0],
234+
test_data[0],
235+
features_on_same_plot=True,
236+
feature_names=["Temperature", "Humidity", "Pressure"]
237+
)
238+
239+
# Option 2: Separate subplots per feature
240+
fig, axes = plot_forecast(
241+
train_data[0],
242+
response.point[0],
243+
test_data[0],
244+
features_on_same_plot=False,
245+
feature_names=["Temperature", "Humidity", "Pressure"]
246+
)
247+
```
248+
249+
### Complete Evaluation Example
250+
251+
```python
252+
import numpy as np
253+
from faim_sdk import ForecastClient, Chronos2ForecastRequest
254+
from faim_sdk.eval import mse, mase, crps_from_quantiles, plot_forecast
255+
from faim_client.models import ModelName
256+
257+
# Initialize client
258+
client = ForecastClient(base_url="https://api.faim.example.com")
259+
260+
# Prepare data splits
261+
train_data = np.random.randn(32, 100, 1)
262+
test_data = np.random.randn(32, 24, 1)
263+
264+
# Generate forecast
265+
request = Chronos2ForecastRequest(
266+
x=train_data,
267+
horizon=24,
268+
output_type="quantiles",
269+
quantiles=[0.1, 0.5, 0.9]
270+
)
271+
response = client.forecast(ModelName.CHRONOS2, request)
272+
273+
# Evaluate point forecast (use median)
274+
point_pred = response.quantiles[:, :, 1:2] # Extract median, keep 3D shape
275+
mse_score = mse(test_data, point_pred)
276+
mase_score = mase(test_data, point_pred, train_data)
277+
278+
# Evaluate probabilistic forecast
279+
crps_score = crps_from_quantiles(
280+
test_data,
281+
response.quantiles,
282+
quantile_levels=[0.1, 0.5, 0.9]
283+
)
284+
285+
print(f"MSE: {mse_score:.4f}")
286+
print(f"MASE: {mase_score:.4f}")
287+
print(f"CRPS: {crps_score:.4f}")
288+
289+
# Visualize best and worst predictions
290+
mse_per_sample = mse(test_data, point_pred, reduction='none')
291+
best_idx = np.argmin(mse_per_sample)
292+
worst_idx = np.argmax(mse_per_sample)
293+
294+
fig1, ax1 = plot_forecast(
295+
train_data[best_idx],
296+
point_pred[best_idx],
297+
test_data[best_idx],
298+
title=f"Best Forecast (MSE: {mse_per_sample[best_idx]:.4f})"
299+
)
300+
fig1.savefig("best_forecast.png")
301+
302+
fig2, ax2 = plot_forecast(
303+
train_data[worst_idx],
304+
point_pred[worst_idx],
305+
test_data[worst_idx],
306+
title=f"Worst Forecast (MSE: {mse_per_sample[worst_idx]:.4f})"
307+
)
308+
fig2.savefig("worst_forecast.png")
309+
```
310+
143311
## Error Handling
144312

145313
The SDK provides **machine-readable error codes** for robust error handling:

0 commit comments

Comments
 (0)