99import mlflow .xgboost
1010import optuna
1111import pandas as pd
12-
1312from mlforecast import MLForecast
1413from mlforecast .lag_transforms import RollingMean , RollingStd
1514from sklearn .metrics import (mean_absolute_error ,
@@ -180,7 +179,8 @@ def _write_metrics(self,
180179
181180 def forecast_xgb (self ,
182181 input_data : pd .DataFrame ,
183- forecast_horizon : int ) -> pd .DataFrame :
182+ forecast_horizon : int ,
183+ full_data : pd .DataFrame = None ) -> pd .DataFrame :
184184 """
185185 Train an XGBoost forecaster with Optuna hyperparameter search,
186186 persist model + metrics + forecast to S3, and return the forecast.
@@ -194,17 +194,20 @@ def forecast_xgb(self,
194194 ----------
195195 input_data: pd.DataFrame
196196 Formatted DataFrame with columns 'ds', 'y', 'unique_id'.
197+ Used for Optuna cross-validation — typically the train split.
197198 forecast_horizon: int
198199 Number of future weekly steps to forecast.
200+ full_data: pd.DataFrame, optional
201+ Full dataset including all observations. When provided, the final
202+ model is fit on this before predicting so that forecast dates
203+ extend beyond the end of the training period into the future.
204+ If None, input_data is used for both CV and the final fit.
199205
200206 Returns
201207 -------
202208 pd.DataFrame
203- Forecast DataFrame.
209+ Forecast DataFrame with future dates .
204210 """
205- import boto3
206- import traceback
207-
208211 run_date = date .today ().strftime ('%Y-%m-%d' )
209212
210213 def objective (trial ):
@@ -215,69 +218,66 @@ def objective(trial):
215218 'subsample' : trial .suggest_float ('subsample' , 0.5 , 1.0 ),
216219 'colsample_bytree' : trial .suggest_float ('colsample_bytree' , 0.5 , 1.0 ),
217220 }
221+
218222 mlf = self ._get_mlforecast (XGBRegressor (** params , verbosity = 0 ))
219223 cv = mlf .cross_validation (input_data , n_windows = 3 , h = forecast_horizon )
224+
220225 mae = mean_absolute_error (cv ['y' ], cv ['XGBRegressor' ])
221- # log to mlflow if available — never fatal
222- try :
223- with mlflow . start_run ( nested = True ):
224- mlflow .log_params ( params )
225- mlflow .log_metric ( "mae" , mae )
226- mlflow .log_metric ("rmse " , root_mean_squared_error ( cv [ 'y' ], cv [ 'XGBRegressor' ]) )
227- mlflow .log_metric ("mape " , mean_absolute_percentage_error ( cv [ 'y' ], cv [ 'XGBRegressor' ]) )
228- except Exception :
229- pass
226+ rmse = root_mean_squared_error ( cv [ 'y' ], cv [ 'XGBRegressor' ])
227+ mape = mean_absolute_percentage_error ( cv [ 'y' ], cv [ 'XGBRegressor' ])
228+
229+ with mlflow .start_run ( nested = True ):
230+ mlflow .log_params ( params )
231+ mlflow .log_metric ("mae " , mae )
232+ mlflow .log_metric ("rmse " , rmse )
233+ mlflow . log_metric ( "mape" , mape )
234+
230235 return mae
231236
232- # ── 1. Optuna study ───────────────────────────────────────────────────
233- print ("Starting Optuna hyperparameter search..." )
237+ # ── 1. Optuna study ───────────────────────────────────────────────────────
234238 study = optuna .create_study (direction = 'minimize' )
235239 try :
236240 with mlflow .start_run (run_name = f"{ self .aoi_name } _{ run_date } " ):
237241 study .optimize (objective , n_trials = 100 , show_progress_bar = True )
238242 mlflow .log_params ({f"best_{ k } " : v for k , v in study .best_params .items ()})
239243 mlflow .log_metric ("best_mae" , study .best_value )
240244 except Exception as e :
241- print (f"MLflow parent run warning (non-fatal): { e } " )
242- # run without mlflow if it failed before optimizing
243- if study .best_params is None :
245+ print (f"MLflow logging warning (non-fatal): { e } " )
246+ if not study .trials :
244247 study .optimize (objective , n_trials = 100 , show_progress_bar = True )
245248
246249 print (f"Best MAE: { study .best_value } " )
247250 print (f"Best params: { study .best_params } " )
248-
249- # ── 2. Refit on full dataset ──────────────────────────────────────────
250- print ("Refitting best model on full dataset..." )
251251 best = study .best_params
252- mlf_best = self ._get_mlforecast (XGBRegressor (** best , verbosity = 0 ))
253- mlf_best .fit (input_data )
254- print ("Refit complete." )
255252
256- # ── 3 . CV with best params for metrics ───────────────────────────────
253+ # ── 2 . CV on train split for metrics ─── ───────────────────────────────
257254 print ("Running CV for metrics..." )
258- cv_best = mlf_best .cross_validation (input_data , n_windows = 3 , h = forecast_horizon )
255+ mlf_cv = self ._get_mlforecast (XGBRegressor (** best , verbosity = 0 ))
256+ cv_best = mlf_cv .cross_validation (input_data , n_windows = 3 , h = forecast_horizon )
259257 print ("CV complete." )
260258
261- # refit after cross_validation — cv resets internal model state in MLForecast
262- print ("Refitting after CV..." )
263- mlf_best .fit (input_data )
264- print ("Refit complete." )
259+ # ── 3. Final fit on full data so forecast extends into the future ─────
260+ fit_data = full_data if full_data is not None else input_data
261+ print (f"Fitting final model on { 'full' if full_data is not None else 'train' } "
262+ f"dataset (last obs: { fit_data ['ds' ].max ().date ()} )..." )
263+ mlf_best = self ._get_mlforecast (XGBRegressor (** best , verbosity = 0 ))
264+ mlf_best .fit (fit_data )
265+ print ("Fit complete." )
265266
266- # ── 4. Persist model pickle to S3 ────────────────────────────────────
267- print ("Saving model pickle..." )
267+ # ── 4. Persist model pickle to S3 ─────────────────────────────────────
268+ import boto3
269+ s3 = boto3 .client ("s3" )
268270 local_pkl = os .path .join (
269271 self .forecast_models_dir ,
270272 f"model_{ self .aoi_name } _{ run_date } .pkl"
271273 )
272274 with open (local_pkl , "wb" ) as f :
273275 pickle .dump (mlf_best , f )
274276
275- s3 = boto3 .client ("s3" )
276277 model_key = f"{ self .country } /{ self .aoi_name } /ml/model_{ self .aoi_name } _{ run_date } .pkl"
277278 s3 .upload_file (local_pkl , BUCKET_NAME , model_key )
278279 print (f"Model written to: s3://{ BUCKET_NAME } /{ model_key } " )
279280
280- # log artifact to mlflow — never fatal
281281 try :
282282 with mlflow .start_run (run_name = f"{ self .aoi_name } _best_model" ):
283283 mlflow .log_artifact (local_pkl )
@@ -286,7 +286,7 @@ def objective(trial):
286286 except Exception as e :
287287 print (f"MLflow artifact logging warning (non-fatal): { e } " )
288288
289- # ── 5. Generate forecast ─────────────── ───────────────────────────────
289+ # ── 5. Generate forecast (future dates) ───────────────────────────────
290290 print ("Generating forecast..." )
291291 try :
292292 forecast = mlf_best .predict (h = forecast_horizon )
@@ -297,22 +297,23 @@ def objective(trial):
297297 forecast .to_parquet (forecast_s3 , index = False )
298298 print (f"Forecast written to: { forecast_s3 } " )
299299 except Exception as e :
300+ import traceback
300301 print (f"ERROR writing forecast: { e } " )
301302 traceback .print_exc ()
302303 raise
303304
304- # ── 6. Persist metrics JSON to S3 ─────────────────────────────────────
305+ # ── 6. Persist metrics JSON ─────── ─────────────────────────────────────
305306 print ("Writing metrics..." )
306307 try :
307308 self ._write_metrics (run_date , best , study .best_value , cv_best )
308309 except Exception as e :
310+ import traceback
309311 print (f"ERROR writing metrics: { e } " )
310312 traceback .print_exc ()
311313 raise
312314
313315 return forecast
314316
315-
316317 def predict_xgb (self ,
317318 forecast_horizon : int ) -> pd .DataFrame :
318319 """
0 commit comments