Skip to content

Commit 44c8816

Browse files
committed
Once CV is done, fit on all the data and forecast future dates
1 parent 4c3813d commit 44c8816

2 files changed

Lines changed: 55 additions & 57 deletions

File tree

scripts/forecast_ts.py

Lines changed: 40 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
import mlflow.xgboost
1010
import optuna
1111
import pandas as pd
12-
1312
from mlforecast import MLForecast
1413
from mlforecast.lag_transforms import RollingMean, RollingStd
1514
from sklearn.metrics import (mean_absolute_error,
@@ -180,7 +179,8 @@ def _write_metrics(self,
180179

181180
def forecast_xgb(self,
182181
input_data: pd.DataFrame,
183-
forecast_horizon: int) -> pd.DataFrame:
182+
forecast_horizon: int,
183+
full_data: pd.DataFrame = None) -> pd.DataFrame:
184184
"""
185185
Train an XGBoost forecaster with Optuna hyperparameter search,
186186
persist model + metrics + forecast to S3, and return the forecast.
@@ -194,17 +194,20 @@ def forecast_xgb(self,
194194
----------
195195
input_data: pd.DataFrame
196196
Formatted DataFrame with columns 'ds', 'y', 'unique_id'.
197+
Used for Optuna cross-validation — typically the train split.
197198
forecast_horizon: int
198199
Number of future weekly steps to forecast.
200+
full_data: pd.DataFrame, optional
201+
Full dataset including all observations. When provided, the final
202+
model is fit on this before predicting so that forecast dates
203+
extend beyond the end of the training period into the future.
204+
If None, input_data is used for both CV and the final fit.
199205
200206
Returns
201207
-------
202208
pd.DataFrame
203-
Forecast DataFrame.
209+
Forecast DataFrame with future dates.
204210
"""
205-
import boto3
206-
import traceback
207-
208211
run_date = date.today().strftime('%Y-%m-%d')
209212

210213
def objective(trial):
@@ -215,69 +218,66 @@ def objective(trial):
215218
'subsample': trial.suggest_float('subsample', 0.5, 1.0),
216219
'colsample_bytree': trial.suggest_float('colsample_bytree', 0.5, 1.0),
217220
}
221+
218222
mlf = self._get_mlforecast(XGBRegressor(**params, verbosity=0))
219223
cv = mlf.cross_validation(input_data, n_windows=3, h=forecast_horizon)
224+
220225
mae = mean_absolute_error(cv['y'], cv['XGBRegressor'])
221-
# log to mlflow if available — never fatal
222-
try:
223-
with mlflow.start_run(nested=True):
224-
mlflow.log_params(params)
225-
mlflow.log_metric("mae", mae)
226-
mlflow.log_metric("rmse", root_mean_squared_error(cv['y'], cv['XGBRegressor']))
227-
mlflow.log_metric("mape", mean_absolute_percentage_error(cv['y'], cv['XGBRegressor']))
228-
except Exception:
229-
pass
226+
rmse = root_mean_squared_error(cv['y'], cv['XGBRegressor'])
227+
mape = mean_absolute_percentage_error(cv['y'], cv['XGBRegressor'])
228+
229+
with mlflow.start_run(nested=True):
230+
mlflow.log_params(params)
231+
mlflow.log_metric("mae", mae)
232+
mlflow.log_metric("rmse", rmse)
233+
mlflow.log_metric("mape", mape)
234+
230235
return mae
231236

232-
# ── 1. Optuna study ───────────────────────────────────────────────────
233-
print("Starting Optuna hyperparameter search...")
237+
# ── 1. Optuna study ───────────────────────────────────────────────────────
234238
study = optuna.create_study(direction='minimize')
235239
try:
236240
with mlflow.start_run(run_name=f"{self.aoi_name}_{run_date}"):
237241
study.optimize(objective, n_trials=100, show_progress_bar=True)
238242
mlflow.log_params({f"best_{k}": v for k, v in study.best_params.items()})
239243
mlflow.log_metric("best_mae", study.best_value)
240244
except Exception as e:
241-
print(f"MLflow parent run warning (non-fatal): {e}")
242-
# run without mlflow if it failed before optimizing
243-
if study.best_params is None:
245+
print(f"MLflow logging warning (non-fatal): {e}")
246+
if not study.trials:
244247
study.optimize(objective, n_trials=100, show_progress_bar=True)
245248

246249
print(f"Best MAE: {study.best_value}")
247250
print(f"Best params: {study.best_params}")
248-
249-
# ── 2. Refit on full dataset ──────────────────────────────────────────
250-
print("Refitting best model on full dataset...")
251251
best = study.best_params
252-
mlf_best = self._get_mlforecast(XGBRegressor(**best, verbosity=0))
253-
mlf_best.fit(input_data)
254-
print("Refit complete.")
255252

256-
# ── 3. CV with best params for metrics ───────────────────────────────
253+
# ── 2. CV on train split for metrics ──────────────────────────────────
257254
print("Running CV for metrics...")
258-
cv_best = mlf_best.cross_validation(input_data, n_windows=3, h=forecast_horizon)
255+
mlf_cv = self._get_mlforecast(XGBRegressor(**best, verbosity=0))
256+
cv_best = mlf_cv.cross_validation(input_data, n_windows=3, h=forecast_horizon)
259257
print("CV complete.")
260258

261-
# refit after cross_validation — cv resets internal model state in MLForecast
262-
print("Refitting after CV...")
263-
mlf_best.fit(input_data)
264-
print("Refit complete.")
259+
# ── 3. Final fit on full data so forecast extends into the future ─────
260+
fit_data = full_data if full_data is not None else input_data
261+
print(f"Fitting final model on {'full' if full_data is not None else 'train'} "
262+
f"dataset (last obs: {fit_data['ds'].max().date()})...")
263+
mlf_best = self._get_mlforecast(XGBRegressor(**best, verbosity=0))
264+
mlf_best.fit(fit_data)
265+
print("Fit complete.")
265266

266-
# ── 4. Persist model pickle to S3 ────────────────────────────────────
267-
print("Saving model pickle...")
267+
# ── 4. Persist model pickle to S3 ─────────────────────────────────────
268+
import boto3
269+
s3 = boto3.client("s3")
268270
local_pkl = os.path.join(
269271
self.forecast_models_dir,
270272
f"model_{self.aoi_name}_{run_date}.pkl"
271273
)
272274
with open(local_pkl, "wb") as f:
273275
pickle.dump(mlf_best, f)
274276

275-
s3 = boto3.client("s3")
276277
model_key = f"{self.country}/{self.aoi_name}/ml/model_{self.aoi_name}_{run_date}.pkl"
277278
s3.upload_file(local_pkl, BUCKET_NAME, model_key)
278279
print(f"Model written to: s3://{BUCKET_NAME}/{model_key}")
279280

280-
# log artifact to mlflow — never fatal
281281
try:
282282
with mlflow.start_run(run_name=f"{self.aoi_name}_best_model"):
283283
mlflow.log_artifact(local_pkl)
@@ -286,7 +286,7 @@ def objective(trial):
286286
except Exception as e:
287287
print(f"MLflow artifact logging warning (non-fatal): {e}")
288288

289-
# ── 5. Generate forecast ──────────────────────────────────────────────
289+
# ── 5. Generate forecast (future dates) ───────────────────────────────
290290
print("Generating forecast...")
291291
try:
292292
forecast = mlf_best.predict(h=forecast_horizon)
@@ -297,22 +297,23 @@ def objective(trial):
297297
forecast.to_parquet(forecast_s3, index=False)
298298
print(f"Forecast written to: {forecast_s3}")
299299
except Exception as e:
300+
import traceback
300301
print(f"ERROR writing forecast: {e}")
301302
traceback.print_exc()
302303
raise
303304

304-
# ── 6. Persist metrics JSON to S3 ─────────────────────────────────────
305+
# ── 6. Persist metrics JSON ────────────────────────────────────────────
305306
print("Writing metrics...")
306307
try:
307308
self._write_metrics(run_date, best, study.best_value, cv_best)
308309
except Exception as e:
310+
import traceback
309311
print(f"ERROR writing metrics: {e}")
310312
traceback.print_exc()
311313
raise
312314

313315
return forecast
314316

315-
316317
def predict_xgb(self,
317318
forecast_horizon: int) -> pd.DataFrame:
318319
"""

scripts/pipeline.py

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,8 @@
22

33
import json
44
import datetime
5-
import os
65

76
import boto3
8-
import dotenv
97
import duckdb
108
import pandas as pd
119

@@ -14,25 +12,20 @@
1412
from scripts.process_ts import DataAnalysis
1513
from scripts.read_bucket import DataReader
1614

17-
from dotenv import load_dotenv
18-
19-
load_dotenv()
20-
21-
BUCKET_NAME = os.getenv("S3_BUCKET_NAME")
22-
23-
print(f"Using S3 bucket: {BUCKET_NAME}")
15+
import os
16+
BUCKET_NAME = os.getenv("S3_BUCKET_NAME", "environment-monitor")
2417

2518

2619
class Pipeline:
2720
"""
2821
End-to-end ML pipeline for a single AOI.
2922
3023
S3 structure assumed:
31-
s3://{BUCKET_NAME}/aois.json
32-
s3://{BUCKET_NAME}/{country}/{aoi_name}/ts/*.parquet
33-
s3://{BUCKET_NAME}/{country}/{aoi_name}/ml/model_{aoi_name}_{date}.pkl
34-
s3://{BUCKET_NAME}/{country}/{aoi_name}/ml/metrics_{aoi_name}_{date}.json
35-
s3://{BUCKET_NAME}/{country}/{aoi_name}/ml/forecast_{aoi_name}_{date}.parquet
24+
s3://env_monitor/aois.json
25+
s3://env_monitor/{country}/{aoi_name}/ts/*.parquet
26+
s3://env_monitor/{country}/{aoi_name}/ml/model_{aoi_name}_{date}.pkl
27+
s3://env_monitor/{country}/{aoi_name}/ml/metrics_{aoi_name}_{date}.json
28+
s3://env_monitor/{country}/{aoi_name}/ml/forecast_{aoi_name}_{date}.parquet
3629
"""
3730

3831
def __init__(self,
@@ -52,9 +45,11 @@ def __init__(self,
5245
self.conn = duckdb.connect()
5346
self.conn.execute("INSTALL spatial;")
5447
self.conn.execute("LOAD spatial;")
55-
self.conn.execute("""CREATE SECRET (
48+
self.conn.execute(f"""CREATE SECRET (
5649
TYPE s3,
57-
PROVIDER credential_chain
50+
KEY_ID '{os.getenv("AWS_ACCESS_KEY_ID")}',
51+
SECRET '{os.getenv("AWS_SECRET_ACCESS_KEY")}',
52+
REGION '{os.getenv("AWS_DEFAULT_REGION", "us-east-1")}'
5853
);
5954
""")
6055

@@ -72,7 +67,7 @@ def register_aoi(self, lat: float, lon: float, rad: float) -> None:
7267
"""
7368
Add or update this AOI's entry in the top-level aois.json registry.
7469
75-
Reads s3://{BUCKET_NAME}/aois.json, upserts this AOI under its
70+
Reads s3://env_monitor/aois.json, upserts this AOI under its
7671
country key, and writes the file back.
7772
7873
Parameters
@@ -256,7 +251,9 @@ def train_model(self, mae_threshold: float = 0.05) -> None:
256251
print(f"Total train shape: {train_df.shape}")
257252
print(f"Total test shape: {test_df.shape}")
258253

259-
forecast = forecast_ts.forecast_xgb(train_df, h)
254+
# Pass full dataset so the final model fit anchors forecasts
255+
# to the end of all available data, not just the train split
256+
forecast = forecast_ts.forecast_xgb(train_df, h, full_data=input_df)
260257

261258
# Load the metrics just written to check MAE
262259
try:

0 commit comments

Comments
 (0)