Skip to content

Commit aee0eee

Browse files
committed
Improve robustness of metrics calculation and scenario handling
1 parent 073024d commit aee0eee

2 files changed

Lines changed: 80 additions & 28 deletions

File tree

src/fyp/selfplay/solver.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -214,8 +214,12 @@ def predict(
214214

215215
forecast = np.full(self.forecast_horizon, baseline)
216216

217+
# Apply scenario transformation to the baseline if provided
218+
if scenario is not None:
219+
forecast = scenario.apply_to_timeseries(forecast.copy())
220+
217221
if return_quantiles:
218-
return {"0.1": forecast * 0.8, "0.5": forecast, "0.9": forecast * 1.2}
222+
return {"0.1": forecast * 0.9, "0.5": forecast, "0.9": forecast * 1.1}
219223
else:
220224
return {"point": forecast}
221225

@@ -258,8 +262,12 @@ def predict(
258262

259263
forecast = np.full(self.forecast_horizon, baseline)
260264

265+
# Apply scenario transformation to the baseline if provided
266+
if scenario is not None:
267+
forecast = scenario.apply_to_timeseries(forecast.copy())
268+
261269
if return_quantiles:
262-
return {"0.1": forecast * 0.8, "0.5": forecast, "0.9": forecast * 1.2}
270+
return {"0.1": forecast * 0.9, "0.5": forecast, "0.9": forecast * 1.1}
263271
else:
264272
return {"point": forecast}
265273

src/fyp/selfplay/trainer.py

Lines changed: 70 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,31 @@
2525
logger = logging.getLogger(__name__)
2626

2727

28+
def _safe_mean(values: list[float] | np.ndarray, default: float = 0.0) -> float:
29+
"""Compute mean safely, handling empty lists and NaN values.
30+
31+
Args:
32+
values: List or array of values
33+
default: Default value if list is empty or all NaN
34+
35+
Returns:
36+
Mean value or default
37+
"""
38+
if not values or len(values) == 0:
39+
return float(default)
40+
41+
# Convert to numpy array for easier filtering
42+
vals = np.array(values, dtype=float)
43+
44+
# Filter out NaN and infinite values
45+
valid_vals = vals[np.isfinite(vals)]
46+
47+
if len(valid_vals) == 0:
48+
return float(default)
49+
50+
return float(np.mean(valid_vals))
51+
52+
2853
class SelfPlayTrainer:
2954
"""Orchestrates propose→solve→verify self-play training loop."""
3055

@@ -186,16 +211,25 @@ def train_episode(
186211

187212
# Episode summary statistics
188213
metrics["episode_time"] = time.time() - episode_start_time
189-
metrics["avg_solver_loss"] = np.mean(metrics["solver_losses"])
190-
metrics["avg_verification_reward"] = np.mean(metrics["verification_rewards"])
191-
metrics["avg_proposer_reward"] = np.mean(metrics["proposer_rewards"])
192-
metrics["scenario_diversity"] = len(set(metrics["scenarios"])) / len(
193-
metrics["scenarios"]
214+
metrics["avg_solver_loss"] = _safe_mean(metrics["solver_losses"], default=0.0)
215+
metrics["avg_verification_reward"] = _safe_mean(
216+
metrics["verification_rewards"], default=0.0
217+
)
218+
metrics["avg_proposer_reward"] = _safe_mean(
219+
metrics["proposer_rewards"], default=0.0
194220
)
195221

222+
# Handle scenario diversity safely
223+
if metrics["scenarios"]:
224+
metrics["scenario_diversity"] = len(set(metrics["scenarios"])) / len(
225+
metrics["scenarios"]
226+
)
227+
else:
228+
metrics["scenario_diversity"] = 0.0
229+
196230
for error_type in ["mae", "mape", "smape"]:
197-
metrics[f"avg_{error_type}"] = np.mean(
198-
metrics["forecast_errors"][error_type]
231+
metrics[f"avg_{error_type}"] = _safe_mean(
232+
metrics["forecast_errors"][error_type], default=0.0
199233
)
200234

201235
self.episode_count += 1
@@ -351,13 +385,13 @@ def validate(
351385
all_violations.append(has_violation)
352386

353387
return {
354-
"avg_loss": np.mean(all_losses),
355-
"mae": np.mean(all_mae),
356-
"mape": np.mean(all_mape),
357-
"smape": np.mean(all_smape),
358-
"violation_rate": np.mean(all_violations),
359-
"mae_std": np.std(all_mae),
360-
"mape_std": np.std(all_mape),
388+
"avg_loss": _safe_mean(all_losses, default=0.0),
389+
"mae": _safe_mean(all_mae, default=0.0),
390+
"mape": _safe_mean(all_mape, default=0.0),
391+
"smape": _safe_mean(all_smape, default=0.0),
392+
"violation_rate": _safe_mean(all_violations, default=0.0),
393+
"mae_std": np.std(all_mae) if all_mae else 0.0,
394+
"mape_std": np.std(all_mape) if all_mape else 0.0,
361395
}
362396

363397
def _prepare_data_windows(
@@ -457,9 +491,13 @@ def _print_progress(self, episode: int, metrics: dict[str, Any]) -> None:
457491
window = 10
458492
recent_metrics = self.metrics_history[-window:]
459493

460-
avg_loss = np.mean([m["avg_solver_loss"] for m in recent_metrics])
461-
avg_reward = np.mean([m["avg_verification_reward"] for m in recent_metrics])
462-
avg_mae = np.mean([m["avg_mae"] for m in recent_metrics])
494+
avg_loss = _safe_mean(
495+
[m["avg_solver_loss"] for m in recent_metrics], default=0.0
496+
)
497+
avg_reward = _safe_mean(
498+
[m["avg_verification_reward"] for m in recent_metrics], default=0.0
499+
)
500+
avg_mae = _safe_mean([m["avg_mae"] for m in recent_metrics], default=0.0)
463501

464502
# Scenario distribution
465503
scenario_counts = {}
@@ -526,23 +564,29 @@ def _save_training_summary(self) -> None:
526564
for scenario_type, success_rates in self.scenario_success_rates.items():
527565
scenario_performance[scenario_type] = {
528566
"total_count": len(success_rates),
529-
"avg_success_rate": np.mean(success_rates),
530-
"std_success_rate": np.std(success_rates),
567+
"avg_success_rate": _safe_mean(success_rates, default=0.0),
568+
"std_success_rate": np.std(success_rates) if success_rates else 0.0,
531569
}
532570

533571
summary = {
534572
"total_episodes": self.episode_count,
535573
"best_val_loss": self.best_val_loss,
536574
"final_metrics": {
537-
"avg_loss": np.mean(all_losses[-100:]),
538-
"avg_reward": np.mean(all_rewards[-100:]),
539-
"avg_mae": np.mean(all_mae[-100:]),
575+
"avg_loss": _safe_mean(all_losses[-100:], default=0.0),
576+
"avg_reward": _safe_mean(all_rewards[-100:], default=0.0),
577+
"avg_mae": _safe_mean(all_mae[-100:], default=0.0),
540578
},
541579
"improvement": {
542-
"loss_reduction": (all_losses[0] - all_losses[-1])
543-
/ all_losses[0]
544-
* 100,
545-
"mae_reduction": (all_mae[0] - all_mae[-1]) / all_mae[0] * 100,
580+
"loss_reduction": (
581+
((all_losses[0] - all_losses[-1]) / all_losses[0] * 100)
582+
if all_losses and all_losses[0] != 0
583+
else 0.0
584+
),
585+
"mae_reduction": (
586+
((all_mae[0] - all_mae[-1]) / all_mae[0] * 100)
587+
if all_mae and all_mae[0] != 0
588+
else 0.0
589+
),
546590
},
547591
"scenario_performance": scenario_performance,
548592
"proposer_final_state": self.proposer.get_scenario_statistics(),

0 commit comments

Comments
 (0)