-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathtrain.py
More file actions
538 lines (485 loc) · 35.8 KB
/
Copy pathtrain.py
File metadata and controls
538 lines (485 loc) · 35.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
from stable_baselines3 import PPO
from stable_baselines3.common.utils import set_random_seed
from torchinfo import summary
import os
from src.environment import ComputeClusterEnv, Weights, PlottingComplete
from src.plot_config import PlotConfig
from src.callbacks import ComputeClusterCallback
from src.plotter import plot_dashboard, plot_cumulative_savings, plot_episode_summary
import re
import glob
import argparse
import sys
import pandas as pd
from src.arrival_scale import validate_job_arrival_scale
from src.analysis_naming import build_model_weight_dir_name
from src.evaluation_summary import build_episode_summary_line, mean_occupancy_pct
from src.workloadgen import WorkloadGenerator
from src.workloadgen_cli import add_workloadgen_args, build_workloadgen_config
from src.config import MAX_NODES, CORES_PER_NODE, EPISODE_HOURS
from src.session_log import log_invocation
import time
def norm_path(x):
    """Collapse a blank optional path argument to None.

    The CLI defaults every path option to the empty string, and the environment
    does not treat "" consistently (falsy in some places, truthy in others), so
    None and any empty or whitespace-only value are mapped to None here.
    Any other value is handed back unchanged (it is NOT stripped).
    """
    if x is None:
        return None
    if not str(x).strip():
        return None
    return x
def safe_ratio(numerator: float, denominator: float) -> float | None:
"""Return numerator/denominator, or None when denominator is not positive."""
return (numerator / denominator) if denominator > 0 else None
def fmt_optional(value: float | None, precision: int = 2, thousands: bool = False) -> str:
"""Format float values for logs, using 'n/a' when value is undefined."""
if value is None:
return "n/a"
return f"{value:,.{precision}f}" if thousands else f"{value:.{precision}f}"
# Environment steps run per training iteration; checkpoints are saved as
# "<STEPS_PER_ITERATION * iteration>.zip".
STEPS_PER_ITERATION = 100_000
def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for training and evaluation runs."""
    parser = argparse.ArgumentParser(description="Run the Compute Cluster Environment with optional rendering.")
    parser.add_argument('--render', type=str, default='none', choices=['human', 'none'], help='Render mode for the environment (default: none).')
    parser.add_argument('--quick-plot', action='store_true', help='In "human" render mode, skip quickly to the plot (default: False).')
    parser.add_argument('--plot-once', action='store_true', help='In "human" render mode, exit after the first plot.')
    parser.add_argument('--prices', type=str, nargs='?', const="", default="", help='Path to the CSV file containing electricity prices (Date,Price)')
    parser.add_argument('--job-durations', type=str, nargs='?', const="", default="", help='Path to a file containing job duration samples (for use with durations_sampler)')
    parser.add_argument('--jobs', type=str, nargs='?', const="", default="", help='Path to a file containing job samples (for use with jobs_sampler)')
    parser.add_argument('--hourly-jobs', type=str, nargs='?', const="", default="", help='Path to Slurm log file for hourly statistical sampling (for use with hourly_sampler)')
    parser.add_argument('--job-arrival-scale', type=float, default=1.0, help='Scale sampled arrivals per step (1.0 = unchanged).')
    parser.add_argument('--jobs-exact-replay', action='store_true', help='For --jobs mode, replay raw jobs in timeline order (no template aggregation).')
    parser.add_argument('--jobs-exact-replay-aggregate', action='store_true', help='With --jobs-exact-replay, aggregate each sampled raw time-bin before enqueueing.')
    parser.add_argument('--plot-rewards', action='store_true', help='Per step, plot rewards for all possible num_idle_nodes & num_used_nodes (default: False).')
    parser.add_argument('--plot-eff-reward', action=argparse.BooleanOptionalAction, default=True, help='Include efficiency reward in the plot (dashed line).')
    parser.add_argument('--plot-price-reward', action=argparse.BooleanOptionalAction, default=True, help='Include price reward in the plot (dashed line).')
    parser.add_argument('--plot-idle-penalty', action=argparse.BooleanOptionalAction, default=True, help='Include idle penalty in the plot (dashed line).')
    parser.add_argument('--plot-job-age-penalty', action=argparse.BooleanOptionalAction, default=True, help='Include job age penalty in the plot (dashed line).')
    parser.add_argument('--plot-total-reward', action=argparse.BooleanOptionalAction, default=True, help='Include total reward per step in the dashboard (raw values).')
    parser.add_argument('--plot-price', action=argparse.BooleanOptionalAction, default=True, help='Plot electricity price.')
    parser.add_argument('--plot-online-nodes', action=argparse.BooleanOptionalAction, default=True, help='Plot online nodes.')
    parser.add_argument('--plot-used-nodes', action=argparse.BooleanOptionalAction, default=True, help='Plot used nodes.')
    parser.add_argument('--plot-job-queue', action=argparse.BooleanOptionalAction, default=True, help='Plot job queue.')
    parser.add_argument('--ent-coef', type=float, default=0.0, help='Entropy coefficient for the loss calculation (default: 0.0) (Passed to PPO).')
    parser.add_argument("--efficiency-weight", type=float, default=0.7, help="Weight for efficiency reward")
    parser.add_argument("--price-weight", type=float, default=0.2, help="Weight for price reward")
    parser.add_argument("--idle-weight", type=float, default=0.1, help="Weight for idle penalty")
    parser.add_argument("--job-age-weight", type=float, default=0.0, help="Weight for job age penalty")
    parser.add_argument("--drop-weight", type=float, default=0.0, help="Weight for lost jobs penalty (age expiry or queue-full rejection) (WIP - default 0.0)")
    parser.add_argument("--iter-limit", type=int, default=0, help=f"Max number of training iterations (1 iteration = {STEPS_PER_ITERATION} steps)")
    parser.add_argument("--session", default="default", help="Session ID")
    parser.add_argument(
        "--output-dir",
        dest="output_dir",
        default="sessions",
        help="Base directory for all output (models, logs, plots). Defaults to 'sessions'.",
    )
    parser.add_argument("--evaluate-savings", action='store_true', help="Load latest model and evaluate long-term savings (no training)")
    parser.add_argument("--oracle", action='store_true', help="Enable both liquid and contiguous oracles alongside simulation to compute theoretical minimum cost lower bounds.")
    parser.add_argument("--eval-months", type=int, default=12, help="Months to evaluate for savings analysis (default: 12, only used with --evaluate-savings)")
    add_workloadgen_args(parser)
    parser.add_argument("--plot-dashboard", action="store_true", help="Generate dashboard plot (per-hour panels + cumulative savings).")
    parser.add_argument("--dashboard-hours", type=int, default=24*14, help="Hours to show in dashboard time-series panels (default: 336).")
    parser.add_argument("--dashboard-interval", type=int, default=10000, help="Hours between dashboard plots (default: 10000).")
    parser.add_argument("--model", type=int, default=None, help="Load a specific model by timestep number (e.g. 5000000 loads 5000000.zip).")
    parser.add_argument("--net-arch", type=str, default="64,64", help="Hidden layer sizes for policy and value networks (comma-separated, e.g., '256,128' or '512,256,128')")
    parser.add_argument("--device", type=str, default="auto", help="Device for training: 'auto' (default, uses CUDA if available), 'cuda', 'cpu'")
    parser.add_argument("--seed", type=int, default=None, help="Random seed for reproducibility (seeds environment, numpy, torch, and PPO)")
    parser.add_argument("--seed-sweep", action="store_true", help="Treat this run as part of a --seeds sweep and isolate outputs under a seed-specific session subdirectory.")
    parser.add_argument("--print-policy", action="store_true", help="Print structure of the policy network.")
    parser.add_argument("--seed-path", default="", help="Path if models are saved by seed (forwarded to train.py) - only used by analyze_seed_occupancy.py, ignored otherwise.")
    parser.add_argument(
        "--flush-after-drop-streak",
        type=int,
        default=0,
        help="Immediately flush and terminate the episode after this many consecutive dropped-job steps (0 disables).",
    )
    return parser


def _parse_args(parser: argparse.ArgumentParser) -> argparse.Namespace:
    """Parse the CLI and reject (or normalize) inconsistent option combinations."""
    args = parser.parse_args()
    try:
        args.job_arrival_scale = validate_job_arrival_scale(args.job_arrival_scale)
    except ValueError as exc:
        parser.error(str(exc))
    if args.jobs_exact_replay and not norm_path(args.jobs):
        parser.error("--jobs-exact-replay requires --jobs")
    if args.jobs_exact_replay_aggregate and not args.jobs_exact_replay:
        parser.error("--jobs-exact-replay-aggregate requires --jobs-exact-replay")
    # Both values are later used as divisors (averages / modulo); fail fast with a
    # clear CLI error instead of a ZeroDivisionError deep inside a run.
    if args.evaluate_savings and args.eval_months < 1:
        parser.error("--eval-months must be >= 1")
    if args.plot_dashboard and args.dashboard_interval <= 0:
        parser.error("--dashboard-interval must be > 0")
    if args.workload_gen and args.job_arrival_scale != 1.0:
        print(
            "Warning: --job-arrival-scale is not allowed with --workload-gen; "
            "resetting it to 1.0. Use workload generator arrival settings instead.",
            file=sys.stderr,
        )
        args.job_arrival_scale = 1.0
    return args


def _load_prices(prices_file_path):
    """Return the 'Price' column of a Date,Price CSV as a list, or None when no path is given."""
    if not norm_path(prices_file_path):
        print("No CSV file provided. Using default price generation.")
        return None
    df = pd.read_csv(prices_file_path, parse_dates=['Date'])
    prices = df['Price'].values.tolist()
    print(f"Loaded {len(prices)} prices from CSV: {prices_file_path}")
    return prices


def _list_model_files(models_dir: str) -> list[tuple[int, str]]:
    """Return (timestep, path) for every '<digits>.zip' checkpoint in models_dir, sorted by timestep.

    Files whose names are not purely numeric are ignored rather than crashing
    the sort / the timestep parse.
    """
    # Escape the directory so session names containing glob metacharacters
    # ('[', '*', '?') are matched literally.
    pattern = os.path.join(glob.escape(models_dir), "*.zip")
    checkpoints = []
    for path in glob.glob(pattern):
        match = re.fullmatch(r"(\d+)\.zip", os.path.basename(path))
        if match:
            checkpoints.append((int(match.group(1)), path))
    checkpoints.sort()
    return checkpoints


def _select_model_file(model_files, models_dir, requested):
    """Return (timestep, path) of the requested checkpoint if present, else of the latest one."""
    if requested is not None:
        selected = os.path.join(models_dir, f"{requested}.zip")
        if os.path.exists(selected):
            return requested, selected
        print(f"Requested model not found: {selected}. Falling back to latest model.")
    return model_files[-1]


def _load_or_create_model(args, env, models_dir, log_dir, session_root, weights):
    """Resume from a saved checkpoint in models_dir, or build a fresh PPO model.

    Returns:
        (model, latest_model_file, evaluation_plots_dir, iters). The file and
        plots dir are None for a fresh model. iters is the checkpoint timestep
        divided by STEPS_PER_ITERATION (0 for a fresh model) so that new
        checkpoint names continue the existing sequence.
    """
    model_files = _list_model_files(models_dir)
    if model_files:
        model_id, latest_model_file = _select_model_file(model_files, models_dir, args.model)
        print(f"Found a saved model: {latest_model_file}")
        seed_suffix = ""
        if args.seed_path != "":
            seed_suffix = "_train" + args.seed_path + "_evalseed_" + str(args.seed)
        evaluation_plots_dir = os.path.join(
            session_root,
            "plots-eval",
            build_model_weight_dir_name(
                model=model_id,
                efficiency_weight=weights.efficiency_weight,
                price_weight=weights.price_weight,
                idle_weight=weights.idle_weight,
                job_age_weight=weights.job_age_weight,
            ) + seed_suffix,
        )
        model = PPO.load(latest_model_file, env=env, tensorboard_log=log_dir, n_steps=64, batch_size=64, device=args.device)
        return model, latest_model_file, evaluation_plots_dir, model_id // STEPS_PER_ITERATION

    print("Starting a new model training...")
    # Parse network architecture from comma-separated string (e.g., "256,128" -> [256, 128])
    net_arch_layers = [int(x) for x in args.net_arch.split(',')]
    # pi = policy (actor) network, vf = value function (critic) network
    policy_kwargs = dict(net_arch=dict(pi=net_arch_layers, vf=net_arch_layers))
    print(f"Network architecture: {net_arch_layers}")
    model = PPO('MultiInputPolicy', env, policy_kwargs=policy_kwargs, tensorboard_log=log_dir, ent_coef=args.ent_coef, n_steps=64, batch_size=64, device=args.device, verbose=1)
    return model, None, None, 0


def _run_evaluation_episode(env, model, episode):
    """Roll out one episode with the deterministic policy, logging progress every 1000 steps."""
    obs, _ = env.reset()
    episode_reward = 0
    step_count = 0
    done = False
    while not done:
        action, _ = model.predict(obs, deterministic=True)
        obs, reward, terminated, truncated, _ = env.step(action)
        episode_reward += reward
        step_count += 1
        if step_count % 1000 == 0:
            print(f"Episode {episode + 1}, Step {step_count}, Action: {action}, Reward: {reward:.2f}, Total Reward: {episode_reward:.2f}, Total Cost: €{env.metrics.total_cost:.2f}")
        done = terminated or truncated


def _print_episode_summary(args, env, episode):
    """Print the one-line cost/occupancy summary of the episode that just finished."""
    if not env.metrics.episode_costs:
        raise RuntimeError("Episode metrics were not recorded before evaluation summary output.")
    episode_data = env.metrics.episode_costs[-1]
    total_cores = CORES_PER_NODE * MAX_NODES
    agent_occupancy_cores_pct = mean_occupancy_pct(env.metrics.episode_used_cores, total_cores)
    baseline_occupancy_cores_pct = mean_occupancy_pct(env.metrics.episode_baseline_used_cores, total_cores)
    agent_occupancy_nodes_pct = mean_occupancy_pct(env.metrics.episode_used_nodes, MAX_NODES)
    baseline_occupancy_nodes_pct = mean_occupancy_pct(env.metrics.episode_baseline_used_nodes, MAX_NODES)
    summary_line = build_episode_summary_line(
        episode_number=episode + 1,
        episode_data=episode_data,
        timeline_max_queue=env.metrics.max_queue_size_reached,
        agent_occupancy_cores_pct=agent_occupancy_cores_pct,
        baseline_occupancy_cores_pct=baseline_occupancy_cores_pct,
        agent_occupancy_nodes_pct=agent_occupancy_nodes_pct,
        baseline_occupancy_nodes_pct=baseline_occupancy_nodes_pct,
    )
    if args.oracle:
        # A 0.0 cost is treated as "oracle value not available" and omitted.
        liq = float(episode_data.get('oracle_cost', 0.0))
        con = float(episode_data.get('oracle_contiguous_cost', 0.0))
        if liq != 0.0:
            summary_line += f", OracleLiq=€{liq:.0f}"
        if con != 0.0:
            above = float(episode_data['agent_cost']) - con
            summary_line += f", OracleCon=€{con:.0f}, AboveOracleCon=€{above:.0f}"
    print(summary_line)


def _print_cumulative_savings(results, eval_months):
    """Print the flat-cost savings summary returned by plot_cumulative_savings."""
    print("\n=== CUMULATIVE SAVINGS ANALYSIS ===")
    print("\nVs Baseline (with idle nodes):")
    print(f" Total Savings: €{results['total_savings']:,.0f}")
    print(f" Average Monthly Reduction (avg monthly flat-cost %): {results['avg_monthly_savings_pct']:.1f}%")
    print(f" Annual Savings Rate: €{results['total_savings'] * 12 / eval_months:,.0f}/year")
    print("\nVs Baseline_off (no idle nodes):")
    print(f" Total Savings: €{results['total_savings_off']:,.0f}")
    print(f" Average Monthly Reduction (avg monthly flat-cost %): {results['avg_monthly_savings_pct_off']:.1f}%")
    print(f" Annual Savings Rate: €{results['total_savings_off'] * 12 / eval_months:,.0f}/year")


def _print_oracle_report(env, total_agent_cost, total_baseline_off_cost):
    """Compare total agent/baseline_off cost with the liquid and contiguous oracle costs."""
    total_oracle_liquid_cost = env.metrics.oracle_cost
    total_oracle_contiguous_cost = env.metrics.oracle_contiguous_cost
    print("\n=== ORACLE BENCHMARKS (THEORETICAL LOWER BOUNDS) ===")
    # A 0.0 total is treated as "this oracle produced no value" and its section is skipped.
    if total_oracle_liquid_cost != 0.0:
        print("\n Liquid Oracle (optimistic — allows job splitting):")
        print(f" Total Cost: €{total_oracle_liquid_cost:,.0f}")
        liq_window = total_baseline_off_cost - total_oracle_liquid_cost
        liq_capture = safe_ratio((total_baseline_off_cost - total_agent_cost) * 100.0, liq_window)
        print(f" Max Shifting Window (fluid): €{liq_window:,.0f} (baseline_off - oracle_liq)")
        print(f" Agent Capture Rate: {fmt_optional(liq_capture, 1)}%")
    if total_oracle_contiguous_cost != 0.0:
        total_oracle_contiguous_unscheduled = env.metrics.oracle_contiguous_unscheduled
        print("\n Contiguous Oracle (realistic — honors job continuity):")
        print(f" Total Cost: €{total_oracle_contiguous_cost:,.0f}")
        con_window = total_baseline_off_cost - total_oracle_contiguous_cost
        con_capture = safe_ratio((total_baseline_off_cost - total_agent_cost) * 100.0, con_window)
        agent_above = total_agent_cost - total_oracle_contiguous_cost
        print(f" Max Shifting Window (real jobs): €{con_window:,.0f} (baseline_off - oracle_jcg)")
        print(f" Agent Gap to Oracle: €{agent_above:,.0f} (agent - oracle_jcg)")
        print(f" Agent Capture Rate: {fmt_optional(con_capture, 1)}%")
        agent_above_rel = safe_ratio((total_agent_cost - total_oracle_contiguous_cost) * 100.0, total_oracle_contiguous_cost)
        print(f" Agent above oracle (rel): {fmt_optional(agent_above_rel, 1)}%")
        oracle_below_rel = safe_ratio((total_agent_cost - total_oracle_contiguous_cost) * 100.0, total_agent_cost)
        print(f" Oracle below agent (rel): {fmt_optional(oracle_below_rel, 1)}%")
        oracle_below_bl_rel = safe_ratio((total_baseline_off_cost - total_oracle_contiguous_cost) * 100.0, total_baseline_off_cost)
        print(f" Oracle below baseline (rel): {fmt_optional(oracle_below_bl_rel, 1)}%")
        oba_vs_obb = (
            safe_ratio(oracle_below_rel * 100.0, oracle_below_bl_rel)
            if (oracle_below_rel is not None and oracle_below_bl_rel is not None)
            else None
        )
        print(f" oba vs obb: {fmt_optional(oba_vs_obb, 1)}%")
        total_oracle_contiguous_spillover = env.metrics.oracle_contiguous_spillover
        if total_oracle_contiguous_unscheduled > 0:
            print(f" Unscheduled Jobs (oracle): {total_oracle_contiguous_unscheduled} (capacity-blocked)")
        if total_oracle_contiguous_spillover > 0:
            print(f" Cross-Episode Spillover (oracle): {total_oracle_contiguous_spillover} (carried to next episode)")
    if total_oracle_liquid_cost != 0.0 and total_oracle_contiguous_cost != 0.0:
        continuity_cost = total_oracle_contiguous_cost - total_oracle_liquid_cost
        print(f"\n Continuity Constraint Cost: €{continuity_cost:,.0f} (oracle_jcg - oracle_liq)")


def _print_evaluation_metrics(args, env):
    """Aggregate env.metrics.episode_costs over the evaluation and print job, cost, power and oracle reports."""
    episodes = env.metrics.episode_costs
    if not episodes:
        print("No episode metrics recorded; skipping job/cost metrics.")
        return
    num_eps = len(episodes)

    # Launched counts fall back to completed counts for records without a launched field.
    total_jobs_submitted = sum(ep['jobs_submitted'] for ep in episodes)
    total_jobs_launched = sum(int(ep.get('jobs_launched', ep['jobs_completed'])) for ep in episodes)
    total_jobs_completed = sum(ep['jobs_completed'] for ep in episodes)
    total_baseline_submitted = sum(ep['baseline_jobs_submitted'] for ep in episodes)
    total_baseline_launched = sum(int(ep.get('baseline_jobs_launched', ep['baseline_jobs_completed'])) for ep in episodes)
    total_baseline_completed = sum(ep['baseline_jobs_completed'] for ep in episodes)
    # Per-episode average wait times are combined weighted by each episode's launched-job count.
    avg_wait_time = sum(ep['avg_wait_time'] * int(ep.get('jobs_launched', ep['jobs_completed'])) for ep in episodes) / total_jobs_launched if total_jobs_launched > 0 else 0
    avg_baseline_wait_time = sum(ep['baseline_avg_wait_time'] * int(ep.get('baseline_jobs_launched', ep['baseline_jobs_completed'])) for ep in episodes) / total_baseline_launched if total_baseline_launched > 0 else 0
    avg_max_queue = sum(ep['max_queue_size'] for ep in episodes) / num_eps
    avg_baseline_max_queue = sum(ep['baseline_max_queue_size'] for ep in episodes) / num_eps
    avg_pending_jobs_end = sum(int(ep.get('pending_jobs_end', 0)) for ep in episodes) / num_eps
    avg_overdue_jobs_end = sum(int(ep.get('overdue_jobs_end', 0)) for ep in episodes) / num_eps

    total_agent_cost = sum(float(ep['agent_cost']) for ep in episodes)
    total_baseline_cost = sum(float(ep['baseline_cost']) for ep in episodes)
    total_baseline_off_cost = sum(float(ep['baseline_cost_off']) for ep in episodes)
    total_jobs_dropped = sum(int(ep.get('jobs_lost_total', ep.get('jobs_dropped', 0))) for ep in episodes)
    total_baseline_jobs_dropped = sum(int(ep.get('baseline_jobs_lost_total', ep.get('baseline_jobs_dropped', 0))) for ep in episodes)

    total_agent_prop_power_mwh = sum(float(ep.get('agent_prop_power_mwh', 0.0)) for ep in episodes)
    total_baseline_prop_power_mwh = sum(float(ep.get('baseline_prop_power_mwh', 0.0)) for ep in episodes)
    total_baseline_off_prop_power_mwh = sum(float(ep.get('baseline_off_prop_power_mwh', 0.0)) for ep in episodes)
    total_agent_prop_cost = sum(float(ep.get('agent_prop_cost', 0.0)) for ep in episodes)
    total_baseline_prop_cost = sum(float(ep.get('baseline_prop_cost', 0.0)) for ep in episodes)
    total_baseline_off_prop_cost = sum(float(ep.get('baseline_off_prop_cost', 0.0)) for ep in episodes)
    total_savings_prop_cost_vs_baseline = total_baseline_prop_cost - total_agent_prop_cost
    total_savings_prop_cost_vs_baseline_off = total_baseline_off_prop_cost - total_agent_prop_cost
    total_agent_prop_mean_price = (total_agent_prop_cost / total_agent_prop_power_mwh) if total_agent_prop_power_mwh > 0 else 0.0
    total_baseline_prop_mean_price = (total_baseline_prop_cost / total_baseline_prop_power_mwh) if total_baseline_prop_power_mwh > 0 else 0.0
    total_baseline_off_prop_mean_price = (total_baseline_off_prop_cost / total_baseline_off_prop_power_mwh) if total_baseline_off_prop_power_mwh > 0 else 0.0
    prop_savings_pct_vs_baseline = safe_ratio(total_savings_prop_cost_vs_baseline * 100.0, total_baseline_prop_cost)
    prop_savings_pct_vs_baseline_off = safe_ratio(total_savings_prop_cost_vs_baseline_off * 100.0, total_baseline_off_prop_cost)

    total_agent_completion_rate = (total_jobs_completed / total_jobs_submitted * 100) if total_jobs_submitted > 0 else 0.0
    total_baseline_completion_rate = (total_baseline_completed / total_baseline_submitted * 100) if total_baseline_submitted > 0 else 0.0
    total_savings_vs_baseline = total_baseline_cost - total_agent_cost
    total_savings_vs_baseline_off = total_baseline_off_cost - total_agent_cost
    total_agent_cost_per_1000_completed = safe_ratio(total_agent_cost * 1000.0, total_jobs_completed)
    total_baseline_cost_per_1000_completed = safe_ratio(total_baseline_cost * 1000.0, total_baseline_completed)
    # baseline_off is a cost variant of baseline scheduling, so it uses the same completed-job count.
    total_baseline_off_cost_per_1000_completed = safe_ratio(total_baseline_off_cost * 1000.0, total_baseline_completed)
    # safe_ratio already yields None when there were no savings (denominator <= 0).
    total_dropped_jobs_per_saved_euro = safe_ratio(total_jobs_dropped, total_savings_vs_baseline)
    total_dropped_jobs_per_saved_euro_off = safe_ratio(total_jobs_dropped, total_savings_vs_baseline_off)

    # Population mean/std of per-episode arrival rates.
    arrivals_per_hour_by_episode = [float(ep['jobs_submitted']) / float(EPISODE_HOURS) for ep in episodes]
    mean_arrivals_per_hour = sum(arrivals_per_hour_by_episode) / num_eps
    arrivals_variance = sum((x - mean_arrivals_per_hour) ** 2 for x in arrivals_per_hour_by_episode) / num_eps
    std_arrivals_per_hour = arrivals_variance ** 0.5

    print("\n=== JOB PROCESSING METRICS ===")
    print("\nAgent:")
    print(f" Jobs Launched: {total_jobs_launched:,} / {total_jobs_submitted:,}")
    print(f" Jobs Completed: {total_jobs_completed:,} / {total_jobs_submitted:,} ({total_agent_completion_rate:.1f}%)")
    print(f" Average Wait Time: {avg_wait_time:.1f} hours")
    print(f" Average Max Queue Size: {avg_max_queue:.0f}")
    print(f" Average Pending Jobs At Episode End: {avg_pending_jobs_end:.1f}")
    print(f" Average Overdue Jobs At Episode End: {avg_overdue_jobs_end:.1f}")
    print(f" Total Cost: €{total_agent_cost:,.0f}")
    print(f" Job Arrivals/Hour (mean ± std): {mean_arrivals_per_hour:.2f} ± {std_arrivals_per_hour:.2f}")
    print("\nBaseline:")
    print(f" Jobs Launched: {total_baseline_launched:,} / {total_baseline_submitted:,}")
    print(f" Jobs Completed: {total_baseline_completed:,} / {total_baseline_submitted:,} ({total_baseline_completion_rate:.1f}%)")
    print(f" Average Wait Time: {avg_baseline_wait_time:.1f} hours")
    print(f" Average Max Queue Size: {avg_baseline_max_queue:.0f}")
    print(f" Baseline Total Cost: €{total_baseline_cost:,.0f}")
    print(f" Baseline_off Total Cost: €{total_baseline_off_cost:,.0f}")
    print("\n=== COST PER 1,000 COMPLETED JOBS ===")
    print(f" Agent: {fmt_optional(total_agent_cost_per_1000_completed, 2, thousands=True)} €/1k jobs")
    print(f" Baseline: {fmt_optional(total_baseline_cost_per_1000_completed, 2, thousands=True)} €/1k jobs")
    print(f" Baseline_off: {fmt_optional(total_baseline_off_cost_per_1000_completed, 2, thousands=True)} €/1k jobs")
    print("\n=== AGENT LOST JOBS PER SAVED EURO ===")
    print(f" Total Lost Jobs (Agent): {total_jobs_dropped:,}")
    print(f" Total Lost Jobs (Baseline): {total_baseline_jobs_dropped:,}")
    print(f" Vs Baseline: {fmt_optional(total_dropped_jobs_per_saved_euro, 6)} jobs/€")
    print(f" Vs Baseline_off: {fmt_optional(total_dropped_jobs_per_saved_euro_off, 6)} jobs/€")
    print("\n=== POWER & PRICE METRICS (TOTAL OVER EVALUATION) ===")
    print(f" Agent: Power={total_agent_prop_power_mwh:,.1f} MWh, Mean Price={total_agent_prop_mean_price:.2f} €/MWh")
    print(f" Baseline: Power={total_baseline_prop_power_mwh:,.1f} MWh, Mean Price={total_baseline_prop_mean_price:.2f} €/MWh")
    print(f" Baseline_off: Power={total_baseline_off_prop_power_mwh:,.1f} MWh, Mean Price={total_baseline_off_prop_mean_price:.2f} €/MWh")
    print(" Note: the monthly reduction above is an average of monthly flat-cost percentages;")
    print(" the proportional savings below are total-over-evaluation ratios on proportional costs.")
    print("\n=== PROPORTIONAL COST SAVINGS (TOTAL OVER EVALUATION) ===")
    print(f" Vs Baseline: €{total_savings_prop_cost_vs_baseline:,.0f}, {fmt_optional(prop_savings_pct_vs_baseline, 1)}%")
    print(f" Vs Baseline_off: €{total_savings_prop_cost_vs_baseline_off:,.0f}, {fmt_optional(prop_savings_pct_vs_baseline_off, 1)}%")
    if args.oracle:
        _print_oracle_report(env, total_agent_cost, total_baseline_off_cost)


def _report_evaluation(args, env, evaluation_plots_dir):
    """Save the savings/episode-summary plots, then print the savings and metrics reports.

    Plotting and reporting are separately best-effort: a plotting failure no
    longer suppresses the metrics, and a metrics failure is no longer
    misreported as a plotting failure.
    """
    show = args.render == 'human'
    results = None
    try:
        results = plot_cumulative_savings(env, env.metrics.episode_costs, evaluation_plots_dir, save=True, show=show)
        plot_episode_summary(env, env.metrics.episode_costs, evaluation_plots_dir, save=True, show=show, suffix=f"eval_{args.eval_months}m")
    except Exception as e:
        print(f"Could not generate cumulative savings plot: {e}")
    try:
        if results:
            _print_cumulative_savings(results, args.eval_months)
        _print_evaluation_metrics(args, env)
    except Exception as e:
        print(f"Could not compute evaluation metrics: {e}")


def _run_evaluation(args, env, model, latest_model_file, evaluation_plots_dir, models_dir):
    """Evaluate the loaded model over args.eval_months months and print reports; always closes env."""
    try:
        if not latest_model_file:
            print("Error: No trained model found for evaluation!")
            print(f"Expected model files in: {models_dir}")
            print("Train a model first, then run evaluation mode.")
            return
        print("=== EVALUATION MODE ===")
        print(f"Evaluation period: {args.eval_months} months ({args.eval_months * 2} episodes, Each episode = 2 weeks)")
        if evaluation_plots_dir is None:
            raise RuntimeError("Evaluation plots directory could not be determined for the selected model.")
        os.makedirs(evaluation_plots_dir, exist_ok=True)
        env.plots_dir = f"{evaluation_plots_dir}/"
        print(f"Evaluation plots directory: {evaluation_plots_dir}")
        num_episodes = args.eval_months * 2  # 2 episodes per month
        for episode in range(num_episodes):
            _run_evaluation_episode(env, model, episode)
            _print_episode_summary(args, env, episode)
        print(f"\nEvaluation complete! Generated {num_episodes} episodes of cost data.")
        _report_evaluation(args, env, evaluation_plots_dir)
        # Optional: single dashboard plot combining the per-hour traces from the LAST episode
        # and cumulative savings across all evaluated episodes.
        if args.plot_dashboard:
            try:
                plot_dashboard(
                    env,
                    num_hours=args.dashboard_hours,
                    max_nodes=MAX_NODES,  # was hard-coded 335; use the shared cluster size
                    save=True,
                    show=(args.render == "human"),
                    suffix=f"eval_{args.eval_months}m",
                )
            except Exception as e:
                print(f"Could not generate dashboard plot: {e}")
        print("\nEvaluation complete!")
    finally:
        env.close()


def _train(args, env, model, iters, models_dir):
    """Train in STEPS_PER_ITERATION chunks, checkpointing after each, until a stop condition; closes env."""
    try:
        while True:
            iters += 1
            # Check the limit before announcing the iteration so the log never
            # shows an iteration that is not actually run.
            if args.iter_limit > 0 and iters > args.iter_limit:
                print(f"iterations limit ({args.iter_limit}) reached: {iters}.")
                break
            print(f"Training iteration {iters} ({STEPS_PER_ITERATION * iters} steps)...")
            t0 = time.time()
            if (iters + 1) % 10 == 0:
                print(f"Running... at {iters + 1} of {STEPS_PER_ITERATION * (iters + 1)} steps")
            try:
                # Only plot after the first iteration to avoid empty data.
                if args.plot_dashboard and iters > 1 and (STEPS_PER_ITERATION * (iters - 1)) % args.dashboard_interval == 0:
                    try:
                        plot_dashboard(
                            env,
                            num_hours=args.dashboard_hours,
                            max_nodes=MAX_NODES,  # was hard-coded 335; use the shared cluster size
                            save=True,
                            show=False,
                            suffix=STEPS_PER_ITERATION * iters,
                        )
                    except Exception as e:
                        print(f"Dashboard plot failed (non-fatal): {e}")
                model.learn(total_timesteps=STEPS_PER_ITERATION, reset_num_timesteps=False, tb_log_name="PPO", callback=ComputeClusterCallback())
                print(f"Iteration {iters} finished in {time.time() - t0:.2f}s")
                # Checkpoint name is the cumulative timestep count; _list_model_files relies on it.
                model.save(os.path.join(models_dir, f"{STEPS_PER_ITERATION * iters}.zip"))
            except PlottingComplete:
                print("Plotting complete, terminating training...")
                break
    except KeyboardInterrupt:
        print("Training interrupted by user.")
    finally:
        print("Exiting training...")
        env.close()


def main():
    """CLI entry point: build the environment, load or create a PPO model, then train or evaluate."""
    parser = _build_arg_parser()
    args = _parse_args(parser)

    # Set random seed for reproducibility
    if args.seed is not None:
        set_random_seed(args.seed)
        print(f"Random seed set to: {args.seed}")

    prices = _load_prices(args.prices)

    weights = Weights(
        efficiency_weight=args.efficiency_weight,
        price_weight=args.price_weight,
        idle_weight=args.idle_weight,
        job_age_weight=args.job_age_weight,
        drop_weight=args.drop_weight
    )
    weights_prefix = f"e{weights.efficiency_weight}_p{weights.price_weight}_i{weights.idle_weight}_a{weights.job_age_weight}_d{weights.drop_weight}"

    session_root = os.path.join(args.output_dir, args.session)
    if args.seed_sweep and args.seed is not None:
        session_root = os.path.join(session_root, f"seed_{args.seed}")
    # Trailing slashes are kept: env.plots_dir consumers may concatenate file names directly.
    models_dir = f"{session_root}/models/{weights_prefix}/"
    log_dir = f"{session_root}/logs/{weights_prefix}/"
    plots_dir = f"{session_root}/plots/"
    for directory in (models_dir, log_dir, plots_dir):
        os.makedirs(directory, exist_ok=True)
    log_invocation(session_root)

    # Load Workload Generator:
    wg_cfg = build_workloadgen_config(args)
    workload_gen = WorkloadGenerator(wg_cfg) if wg_cfg is not None else None

    plot_config = PlotConfig(
        quick_plot=args.quick_plot,
        plot_once=args.plot_once,
        plot_eff_reward=args.plot_eff_reward,
        plot_price_reward=args.plot_price_reward,
        plot_idle_penalty=args.plot_idle_penalty,
        plot_job_age_penalty=args.plot_job_age_penalty,
        plot_total_reward=args.plot_total_reward,
        plot_price=args.plot_price,
        plot_online_nodes=args.plot_online_nodes,
        plot_used_nodes=args.plot_used_nodes,
        plot_job_queue=args.plot_job_queue,
    )
    env = ComputeClusterEnv(weights=weights,
                            session=args.session,
                            render_mode=args.render,
                            external_prices=prices,
                            external_durations=norm_path(args.job_durations),
                            external_jobs=norm_path(args.jobs),
                            external_hourly_jobs=norm_path(args.hourly_jobs),
                            plot_config=plot_config,
                            steps_per_iteration=STEPS_PER_ITERATION,
                            evaluation_mode=args.evaluate_savings,
                            workload_gen=workload_gen,
                            job_arrival_scale=args.job_arrival_scale,
                            jobs_exact_replay=args.jobs_exact_replay,
                            output_dir=args.output_dir,
                            jobs_exact_replay_aggregate=args.jobs_exact_replay_aggregate,
                            flush_after_drop_streak=args.flush_after_drop_streak,
                            enable_oracle=args.oracle)
    env.session_dir = session_root
    env.plots_dir = plots_dir
    env.reset(seed=args.seed)

    model, latest_model_file, evaluation_plots_dir, iters = _load_or_create_model(
        args, env, models_dir, log_dir, session_root, weights
    )
    print(f"Device: {model.device}")
    if args.print_policy:
        print(model.policy)
        summary(model.policy, depth=4)
    env.set_progress(iters)

    if args.evaluate_savings:
        _run_evaluation(args, env, model, latest_model_file, evaluation_plots_dir, models_dir)
        return
    _train(args, env, model, iters, models_dir)
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()