@@ -201,7 +201,17 @@ def predict(
201201 if self .model is None or self .model .model is None :
202202 # Model not available or not trained yet, return naive forecast
203203 logger .warning ("Model not available/trained, returning naive forecast" )
204+
205+ # Calculate baseline with validation
204206 baseline = np .median (context_window )
207+ if np .isnan (baseline ) or baseline == 0 :
208+ # Fallback to mean if median is invalid
209+ baseline = np .nanmean (context_window )
210+ if np .isnan (baseline ) or baseline == 0 :
211+ # Ultimate fallback to a reasonable default (1 kWh)
212+ baseline = 1.0
213+ logger .warning ("Context window invalid, using default baseline" )
214+
205215 forecast = np .full (self .forecast_horizon , baseline )
206216
207217 if return_quantiles :
@@ -238,8 +248,14 @@ def predict(
238248
239249 except Exception as e :
240250 logger .error (f"Prediction failed: { e } " )
241- # Return fallback forecast
251+ # Return fallback forecast with validation
242252 baseline = np .median (context_window )
253+ if np .isnan (baseline ) or baseline == 0 :
254+ baseline = np .nanmean (context_window )
255+ if np .isnan (baseline ) or baseline == 0 :
256+ baseline = 1.0
257+ logger .warning ("Context invalid in fallback, using default" )
258+
243259 forecast = np .full (self .forecast_horizon , baseline )
244260
245261 if return_quantiles :
@@ -601,7 +617,9 @@ def load_checkpoint(self, checkpoint_path: str) -> None:
601617 else :
602618 if torch is None :
603619 raise ImportError ("PyTorch required to load .pth checkpoint" )
604- checkpoint = torch .load (checkpoint_path , map_location = self .device )
620+ checkpoint = torch .load (
621+ checkpoint_path , map_location = self .device , weights_only = False
622+ )
605623
606624 # Recreate model with saved config
607625 self .model_config = checkpoint ["model_config" ]
0 commit comments