Skip to content

Commit a1f7f78

Browse files
committed
Improve baseline fallback logic and scenario transformation
1 parent 930c25c commit a1f7f78

2 files changed

Lines changed: 28 additions & 8 deletions

File tree

src/fyp/selfplay/solver.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,17 @@ def predict(
201201
if self.model is None or self.model.model is None:
202202
# Model not available or not trained yet, return naive forecast
203203
logger.warning("Model not available/trained, returning naive forecast")
204+
205+
# Calculate baseline with validation
204206
baseline = np.median(context_window)
207+
if np.isnan(baseline) or baseline == 0:
208+
# Fallback to mean if median is invalid
209+
baseline = np.nanmean(context_window)
210+
if np.isnan(baseline) or baseline == 0:
211+
# Ultimate fallback to a reasonable default (1 kWh)
212+
baseline = 1.0
213+
logger.warning("Context window invalid, using default baseline")
214+
205215
forecast = np.full(self.forecast_horizon, baseline)
206216

207217
if return_quantiles:
@@ -238,8 +248,14 @@ def predict(
238248

239249
except Exception as e:
240250
logger.error(f"Prediction failed: {e}")
241-
# Return fallback forecast
251+
# Return fallback forecast with validation
242252
baseline = np.median(context_window)
253+
if np.isnan(baseline) or baseline == 0:
254+
baseline = np.nanmean(context_window)
255+
if np.isnan(baseline) or baseline == 0:
256+
baseline = 1.0
257+
logger.warning("Context invalid in fallback, using default")
258+
243259
forecast = np.full(self.forecast_horizon, baseline)
244260

245261
if return_quantiles:
@@ -601,7 +617,9 @@ def load_checkpoint(self, checkpoint_path: str) -> None:
601617
else:
602618
if torch is None:
603619
raise ImportError("PyTorch required to load .pth checkpoint")
604-
checkpoint = torch.load(checkpoint_path, map_location=self.device)
620+
checkpoint = torch.load(
621+
checkpoint_path, map_location=self.device, weights_only=False
622+
)
605623

606624
# Recreate model with saved config
607625
self.model_config = checkpoint["model_config"]

src/fyp/selfplay/utils.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -184,15 +184,17 @@ def apply_scenario_transformation(
184184
if shift_intervals > 0:
185185
# Shift forward
186186
peak_values = transformed[peak_mask]
187-
transformed[peak_mask] *= 0.7 # Reduce original peak
188-
shift_mask = np.roll(peak_mask, shift_intervals)
189-
transformed[shift_mask] += peak_values.mean() * 0.3
187+
if len(peak_values) > 0:
188+
transformed[peak_mask] *= 0.7 # Reduce original peak
189+
shift_mask = np.roll(peak_mask, shift_intervals)
190+
transformed[shift_mask] += peak_values.mean() * 0.3
190191
else:
191192
# Shift backward
192193
peak_values = transformed[peak_mask]
193-
transformed[peak_mask] *= 0.7
194-
shift_mask = np.roll(peak_mask, shift_intervals)
195-
transformed[shift_mask] += peak_values.mean() * 0.3
194+
if len(peak_values) > 0:
195+
transformed[peak_mask] *= 0.7
196+
shift_mask = np.roll(peak_mask, shift_intervals)
197+
transformed[shift_mask] += peak_values.mean() * 0.3
196198

197199
elif scenario_type == "OUTAGE":
198200
# Zero consumption during outage

0 commit comments

Comments
 (0)