Skip to content

Commit 2737be7

Browse files
committed
feat(viz): add population evolution visualization and history interval support
1 parent 2832128 commit 2737be7

5 files changed

Lines changed: 178 additions & 5 deletions

File tree

scripts/experiment_evolution.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
import os
2+
import sys
3+
import random
4+
import numpy as np
5+
from datetime import datetime
6+
7+
# Add project root to sys.path
8+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
9+
10+
from src.config import (
11+
ASSET_NAMES,
12+
DATA_FOLDER,
13+
OUTPUT_FOLDER,
14+
TRAINING_ITERATIONS,
15+
NUM_MIXTURES,
16+
)
17+
from src.data_handler import load_data
18+
from src.predictors import ExactGPPredictor
19+
from src.portfolio import (
20+
calculate_expected_returns_and_cov,
21+
MOEADOptimizer,
22+
)
23+
from src.visualization import plot_population_evolution
24+
25+
26+
def seed_everything(seed=42):
27+
random.seed(seed)
28+
np.random.seed(seed)
29+
try:
30+
import torch
31+
32+
torch.manual_seed(seed)
33+
except ImportError:
34+
pass
35+
36+
37+
def run_evolution_demo():
38+
seed_everything(42)
39+
40+
# Setup folders
41+
run_id = datetime.now().strftime("%Y%m%d_%H%M%S_moead_evolution")
42+
run_folder = os.path.join(OUTPUT_FOLDER, run_id)
43+
os.makedirs(run_folder, exist_ok=True)
44+
45+
print(f"Starting MOEA/D Evolution Demo. Result will be saved to {run_folder}")
46+
47+
# 1. Load data and calculate expected returns/cov
48+
# (Simplified data hash for speed in this demo)
49+
raw_data_dict = load_data(DATA_FOLDER, ASSET_NAMES, verbose=False)
50+
predictor = ExactGPPredictor(
51+
training_iterations=TRAINING_ITERATIONS,
52+
num_mixtures=NUM_MIXTURES,
53+
verbose=False,
54+
)
55+
56+
sample_timestamps = raw_data_dict[ASSET_NAMES[0]]["timestamps"]
57+
target_timestamp = sample_timestamps[-1] + 100.0
58+
59+
expected_returns, cov_matrix = calculate_expected_returns_and_cov(
60+
raw_data_dict, predictor, target_timestamp, ASSET_NAMES, verbose=False
61+
)
62+
63+
data_kwargs = {"expected_returns": expected_returns, "cov_matrix": cov_matrix}
64+
65+
# 2. Run MOEA/D with history recording
66+
print("Running MOEA/D with history recording (interval=10)...")
67+
moead = MOEADOptimizer()
68+
history_interval = 10
69+
generations = 200
70+
71+
# generate_pareto_front returns: metrics, population, weight_vectors, history (if record_history=True)
72+
results = moead.generate_pareto_front(
73+
num_points=100,
74+
generations=generations,
75+
verbose=True,
76+
record_history=True,
77+
history_interval=history_interval,
78+
**data_kwargs,
79+
)
80+
81+
# unpack results
82+
moead_metrics, _, _, history = results
83+
84+
# 3. Plot Population Evolution
85+
print("Generating Population Evolution Plot...")
86+
obj_names = ["Expected Return", "Expected Risk"]
87+
save_path = os.path.join(run_folder, "population_evolution.png")
88+
89+
plot_population_evolution(
90+
history, obj_names, history_interval=history_interval, save_path=save_path
91+
)
92+
93+
print(f"Success! Plot saved to {save_path}")
94+
95+
96+
if __name__ == "__main__":
97+
run_evolution_demo()

src/portfolio/moead.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ def generate_pareto_front(
2525
nr=2,
2626
verbose=False,
2727
crossover_operator="sbx",
28+
history_interval=10,
2829
**kwargs,
2930
):
3031
"""
@@ -204,7 +205,8 @@ def normalise(f):
204205
z_ideal[k] = off_f[k]
205206

206207
if kwargs.get("record_history", False):
207-
history.append(f_phys.copy())
208+
if _gen % history_interval == 0 or _gen == generations - 1:
209+
history.append(f_phys.copy())
208210

209211
# ── 6. Return physical-space metrics ─────────────────────────────────
210212
pareto_metrics = {

src/portfolio/moead_awa.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,14 @@ def _get_non_dominated(self, solutions, objectives):
2323
return is_efficient
2424

2525
def generate_pareto_front(
26-
self, num_points=100, generations=100, T=10, nr=2, verbose=False, **kwargs
26+
self,
27+
num_points=100,
28+
generations=100,
29+
T=10,
30+
nr=2,
31+
verbose=False,
32+
history_interval=10,
33+
**kwargs,
2734
):
2835
"""
2936
MOEA/D-AWA (Adaptive Weight vector Adjustment).
@@ -222,7 +229,8 @@ def normalise(f):
222229
neighbors = self._get_neighbors(weight_vectors, T)
223230

224231
if kwargs.get("record_history", False):
225-
history.append(f_phys.copy())
232+
if gen % history_interval == 0 or gen == generations - 1:
233+
history.append(f_phys.copy())
226234

227235
pareto_metrics = {
228236
obj.name: f_phys[:, i] for i, obj in enumerate(self.problem.objectives)

src/portfolio/moead_dra.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,14 @@ def __init__(self, problem=None):
88
super().__init__(problem)
99

1010
def generate_pareto_front(
11-
self, num_points=100, generations=100, T=10, nr=2, verbose=False, **kwargs
11+
self,
12+
num_points=100,
13+
generations=100,
14+
T=10,
15+
nr=2,
16+
verbose=False,
17+
history_interval=10,
18+
**kwargs,
1219
):
1320
"""
1421
MOEA/D-DRA (Dynamic Resource Allocation).
@@ -163,7 +170,8 @@ def normalise(f):
163170
z_ideal[k] = off_f[k]
164171

165172
if kwargs.get("record_history", False):
166-
history.append(f_phys.copy())
173+
if gen % history_interval == 0 or gen == generations - 1:
174+
history.append(f_phys.copy())
167175

168176
pareto_metrics = {
169177
obj.name: f_phys[:, i] for i, obj in enumerate(self.problem.objectives)

src/visualization.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -395,3 +395,61 @@ def plot_variants_comparison(ref_metrics, variants_results, save_path=None):
395395
plt.grid(True, alpha=0.3)
396396
if save_path:
397397
plt.savefig(save_path)
398+
399+
400+
def plot_population_evolution(history, obj_names, history_interval=10, save_path=None):
401+
"""
402+
Plots the evolution of the population through different generations.
403+
404+
Args:
405+
history: Array of shape (n_recorded, n_points, n_objectives)
406+
obj_names: List of objective names (e.g., ["Return", "Risk"])
407+
history_interval: Interval at which history was recorded
408+
save_path: Path to save the plot
409+
"""
410+
n_recorded = len(history)
411+
if n_recorded == 0:
412+
print("Empty history provided.")
413+
return
414+
415+
plt.figure(figsize=(12, 8))
416+
cmap = plt.get_cmap("viridis")
417+
418+
for i, pop in enumerate(history):
419+
# Calculate approximate generation number
420+
gen = i * history_interval
421+
color = cmap(i / max(1, n_recorded - 1))
422+
423+
plt.scatter(
424+
pop[:, 0],
425+
pop[:, 1],
426+
color=color,
427+
alpha=0.6,
428+
s=20,
429+
label=f"Gen {gen}" if i % 2 == 0 or i == n_recorded - 1 else "",
430+
)
431+
432+
plt.title("Evolution of Pareto Front (Population Dynamics)")
433+
plt.xlabel(obj_names[0])
434+
plt.ylabel(obj_names[1])
435+
plt.grid(True, alpha=0.3)
436+
437+
# Add colorbar
438+
sm = plt.cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(vmin=0, vmax=gen))
439+
plt.colorbar(sm, label="Generation", ax=plt.gca())
440+
441+
# Limit legend items if too many
442+
handles, labels = plt.gca().get_legend_handles_labels()
443+
if len(labels) > 10:
444+
indices = np.linspace(0, len(labels) - 1, 6, dtype=int)
445+
plt.legend(
446+
[handles[i] for i in indices],
447+
[labels[i] for i in indices],
448+
title="Generations Sample",
449+
)
450+
else:
451+
plt.legend(title="Generations Sample")
452+
453+
plt.tight_layout()
454+
if save_path:
455+
plt.savefig(save_path)

0 commit comments

Comments
 (0)