|
25 | 25 | logger = logging.getLogger(__name__) |
26 | 26 |
|
27 | 27 |
|
| 28 | +def _safe_mean(values: list[float] | np.ndarray, default: float = 0.0) -> float: |
| 29 | + """Compute mean safely, handling empty lists and NaN values. |
| 30 | +
|
| 31 | + Args: |
| 32 | + values: List or array of values |
| 33 | + default: Default value if list is empty or all NaN |
| 34 | +
|
| 35 | + Returns: |
| 36 | + Mean value or default |
| 37 | + """ |
| 38 | + if not values or len(values) == 0: |
| 39 | + return float(default) |
| 40 | + |
| 41 | + # Convert to numpy array for easier filtering |
| 42 | + vals = np.array(values, dtype=float) |
| 43 | + |
| 44 | + # Filter out NaN and infinite values |
| 45 | + valid_vals = vals[np.isfinite(vals)] |
| 46 | + |
| 47 | + if len(valid_vals) == 0: |
| 48 | + return float(default) |
| 49 | + |
| 50 | + return float(np.mean(valid_vals)) |
| 51 | + |
| 52 | + |
28 | 53 | class SelfPlayTrainer: |
29 | 54 | """Orchestrates propose→solve→verify self-play training loop.""" |
30 | 55 |
|
@@ -186,16 +211,25 @@ def train_episode( |
186 | 211 |
|
187 | 212 | # Episode summary statistics |
188 | 213 | metrics["episode_time"] = time.time() - episode_start_time |
189 | | - metrics["avg_solver_loss"] = np.mean(metrics["solver_losses"]) |
190 | | - metrics["avg_verification_reward"] = np.mean(metrics["verification_rewards"]) |
191 | | - metrics["avg_proposer_reward"] = np.mean(metrics["proposer_rewards"]) |
192 | | - metrics["scenario_diversity"] = len(set(metrics["scenarios"])) / len( |
193 | | - metrics["scenarios"] |
| 214 | + metrics["avg_solver_loss"] = _safe_mean(metrics["solver_losses"], default=0.0) |
| 215 | + metrics["avg_verification_reward"] = _safe_mean( |
| 216 | + metrics["verification_rewards"], default=0.0 |
| 217 | + ) |
| 218 | + metrics["avg_proposer_reward"] = _safe_mean( |
| 219 | + metrics["proposer_rewards"], default=0.0 |
194 | 220 | ) |
195 | 221 |
|
| 222 | + # Handle scenario diversity safely |
| 223 | + if metrics["scenarios"]: |
| 224 | + metrics["scenario_diversity"] = len(set(metrics["scenarios"])) / len( |
| 225 | + metrics["scenarios"] |
| 226 | + ) |
| 227 | + else: |
| 228 | + metrics["scenario_diversity"] = 0.0 |
| 229 | + |
196 | 230 | for error_type in ["mae", "mape", "smape"]: |
197 | | - metrics[f"avg_{error_type}"] = np.mean( |
198 | | - metrics["forecast_errors"][error_type] |
| 231 | + metrics[f"avg_{error_type}"] = _safe_mean( |
| 232 | + metrics["forecast_errors"][error_type], default=0.0 |
199 | 233 | ) |
200 | 234 |
|
201 | 235 | self.episode_count += 1 |
@@ -351,13 +385,13 @@ def validate( |
351 | 385 | all_violations.append(has_violation) |
352 | 386 |
|
353 | 387 | return { |
354 | | - "avg_loss": np.mean(all_losses), |
355 | | - "mae": np.mean(all_mae), |
356 | | - "mape": np.mean(all_mape), |
357 | | - "smape": np.mean(all_smape), |
358 | | - "violation_rate": np.mean(all_violations), |
359 | | - "mae_std": np.std(all_mae), |
360 | | - "mape_std": np.std(all_mape), |
| 388 | + "avg_loss": _safe_mean(all_losses, default=0.0), |
| 389 | + "mae": _safe_mean(all_mae, default=0.0), |
| 390 | + "mape": _safe_mean(all_mape, default=0.0), |
| 391 | + "smape": _safe_mean(all_smape, default=0.0), |
| 392 | + "violation_rate": _safe_mean(all_violations, default=0.0), |
| 393 | + "mae_std": np.std(all_mae) if all_mae else 0.0, |
| 394 | + "mape_std": np.std(all_mape) if all_mape else 0.0, |
361 | 395 | } |
362 | 396 |
|
363 | 397 | def _prepare_data_windows( |
@@ -457,9 +491,13 @@ def _print_progress(self, episode: int, metrics: dict[str, Any]) -> None: |
457 | 491 | window = 10 |
458 | 492 | recent_metrics = self.metrics_history[-window:] |
459 | 493 |
|
460 | | - avg_loss = np.mean([m["avg_solver_loss"] for m in recent_metrics]) |
461 | | - avg_reward = np.mean([m["avg_verification_reward"] for m in recent_metrics]) |
462 | | - avg_mae = np.mean([m["avg_mae"] for m in recent_metrics]) |
| 494 | + avg_loss = _safe_mean( |
| 495 | + [m["avg_solver_loss"] for m in recent_metrics], default=0.0 |
| 496 | + ) |
| 497 | + avg_reward = _safe_mean( |
| 498 | + [m["avg_verification_reward"] for m in recent_metrics], default=0.0 |
| 499 | + ) |
| 500 | + avg_mae = _safe_mean([m["avg_mae"] for m in recent_metrics], default=0.0) |
463 | 501 |
|
464 | 502 | # Scenario distribution |
465 | 503 | scenario_counts = {} |
@@ -526,23 +564,29 @@ def _save_training_summary(self) -> None: |
526 | 564 | for scenario_type, success_rates in self.scenario_success_rates.items(): |
527 | 565 | scenario_performance[scenario_type] = { |
528 | 566 | "total_count": len(success_rates), |
529 | | - "avg_success_rate": np.mean(success_rates), |
530 | | - "std_success_rate": np.std(success_rates), |
| 567 | + "avg_success_rate": _safe_mean(success_rates, default=0.0), |
| 568 | + "std_success_rate": np.std(success_rates) if success_rates else 0.0, |
531 | 569 | } |
532 | 570 |
|
533 | 571 | summary = { |
534 | 572 | "total_episodes": self.episode_count, |
535 | 573 | "best_val_loss": self.best_val_loss, |
536 | 574 | "final_metrics": { |
537 | | - "avg_loss": np.mean(all_losses[-100:]), |
538 | | - "avg_reward": np.mean(all_rewards[-100:]), |
539 | | - "avg_mae": np.mean(all_mae[-100:]), |
| 575 | + "avg_loss": _safe_mean(all_losses[-100:], default=0.0), |
| 576 | + "avg_reward": _safe_mean(all_rewards[-100:], default=0.0), |
| 577 | + "avg_mae": _safe_mean(all_mae[-100:], default=0.0), |
540 | 578 | }, |
541 | 579 | "improvement": { |
542 | | - "loss_reduction": (all_losses[0] - all_losses[-1]) |
543 | | - / all_losses[0] |
544 | | - * 100, |
545 | | - "mae_reduction": (all_mae[0] - all_mae[-1]) / all_mae[0] * 100, |
| 580 | + "loss_reduction": ( |
| 581 | + ((all_losses[0] - all_losses[-1]) / all_losses[0] * 100) |
| 582 | + if all_losses and all_losses[0] != 0 |
| 583 | + else 0.0 |
| 584 | + ), |
| 585 | + "mae_reduction": ( |
| 586 | + ((all_mae[0] - all_mae[-1]) / all_mae[0] * 100) |
| 587 | + if all_mae and all_mae[0] != 0 |
| 588 | + else 0.0 |
| 589 | + ), |
546 | 590 | }, |
547 | 591 | "scenario_performance": scenario_performance, |
548 | 592 | "proposer_final_state": self.proposer.get_scenario_statistics(), |
|
0 commit comments