Skip to content

Commit 656d283

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent d15e3f2 commit 656d283

File tree

1 file changed

+69
-33
lines changed

1 file changed

+69
-33
lines changed

benchmark/scripts/benchmark_matrix_gen.py

Lines changed: 69 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,11 @@
3636
import numpy as np
3737
import pandas as pd
3838

39-
4039
# ---------------------------------------------------------------------------
4140
# Model builders
4241
# ---------------------------------------------------------------------------
4342

43+
4444
def build_scigrid_network(n_snapshots: int):
4545
"""Return a PyPSA Network (SciGrid-DE) with extended snapshots, without building the model."""
4646
import pypsa
@@ -58,9 +58,15 @@ def build_scigrid_network(n_snapshots: int):
5858
if df is not None and not df.empty:
5959
tiles = int(np.ceil(n_snapshots / orig_len)) + 1
6060
tiled = np.tile(df.values, (tiles, 1))[:n_snapshots]
61-
setattr(component_t, attr, pd.DataFrame(
62-
tiled, index=new_snapshots, columns=df.columns,
63-
))
61+
setattr(
62+
component_t,
63+
attr,
64+
pd.DataFrame(
65+
tiled,
66+
index=new_snapshots,
67+
columns=df.columns,
68+
),
69+
)
6470
return n
6571

6672

@@ -89,6 +95,7 @@ def build_synthetic(n: int):
8995
# Benchmark phases
9096
# ---------------------------------------------------------------------------
9197

98+
9299
def time_phase(func, label: str, repeats: int = 3) -> dict:
93100
"""Time a callable, return best-of-N result."""
94101
times = []
@@ -101,8 +108,12 @@ def time_phase(func, label: str, repeats: int = 3) -> dict:
101108
gc.enable()
102109
times.append(elapsed)
103110
del result
104-
return {"phase": label, "best_s": min(times), "median_s": sorted(times)[len(times) // 2],
105-
"times": times}
111+
return {
112+
"phase": label,
113+
"best_s": min(times),
114+
"median_s": sorted(times)[len(times) // 2],
115+
"times": times,
116+
}
106117

107118

108119
def benchmark_model(model, repeats: int = 3) -> list[dict]:
@@ -162,6 +173,7 @@ def do_full():
162173
# Solution-unpacking benchmark (#619)
163174
# ---------------------------------------------------------------------------
164175

176+
165177
def benchmark_solution_unpack(n_snapshots: int, repeats: int = 3) -> list[dict]:
166178
"""
167179
Benchmark the solution-assignment loop in Model.solve (PR #619).
@@ -198,7 +210,9 @@ def benchmark_solution_unpack(n_snapshots: int, repeats: int = 3) -> list[dict]:
198210
sol_series = pd.concat(parts).drop_duplicates()
199211
sol_series.loc[-1] = nan
200212

201-
n_vars = sum(np.ravel(model.variables[name].labels).size for name in model.variables)
213+
n_vars = sum(
214+
np.ravel(model.variables[name].labels).size for name in model.variables
215+
)
202216
results = []
203217

204218
# ----- Old path (pandas label-based, pre-#619) -----
@@ -243,26 +257,31 @@ def unpack_numpy():
243257
SYNTHETIC_SIZES = [20, 50, 100, 200]
244258

245259

246-
def run_benchmarks(model_type: str, quick: bool, repeats: int, include_solve: bool = False) -> list[dict]:
260+
def run_benchmarks(
261+
model_type: str, quick: bool, repeats: int, include_solve: bool = False
262+
) -> list[dict]:
247263
"""Run benchmarks across problem sizes, return flat list of results."""
248264
all_results = []
249265

250266
if model_type in ("scigrid", "all"):
251267
sizes = QUICK_SNAPSHOTS if quick else FULL_SNAPSHOTS
252268
for n_snap in sizes:
253-
print(f"\n{'='*60}")
269+
print(f"\n{'=' * 60}")
254270
print(f"SciGrid-DE {n_snap} snapshots")
255-
print(f"{'='*60}")
271+
print(f"{'=' * 60}")
256272
model = build_scigrid(n_snap)
257273
n_vars = len(model.variables.flat)
258274
n_cons = len(model.constraints.flat)
259275
print(f" {n_vars:,} variables, {n_cons:,} constraints")
260276

261277
for r in benchmark_model(model, repeats):
262-
r.update(model_type="scigrid", size=n_snap,
263-
n_vars=n_vars, n_cons=n_cons)
278+
r.update(
279+
model_type="scigrid", size=n_snap, n_vars=n_vars, n_cons=n_cons
280+
)
264281
all_results.append(r)
265-
print(f" {r['phase']:20s} {r['best_s']:.4f}s (median {r['median_s']:.4f}s)")
282+
print(
283+
f" {r['phase']:20s} {r['best_s']:.4f}s (median {r['median_s']:.4f}s)"
284+
)
266285

267286
del model
268287
gc.collect()
@@ -271,29 +290,32 @@ def run_benchmarks(model_type: str, quick: bool, repeats: int, include_solve: bo
271290
# Solution-unpacking benchmark for PR #619 (SciGrid-DE only, small sizes)
272291
solve_sizes = QUICK_SNAPSHOTS if quick else [24, 100]
273292
for n_snap in solve_sizes:
274-
print(f"\n{'='*60}")
293+
print(f"\n{'=' * 60}")
275294
print(f"SciGrid-DE solve + unpack {n_snap} snapshots (#619)")
276-
print(f"{'='*60}")
295+
print(f"{'=' * 60}")
277296
for r in benchmark_solution_unpack(n_snap, repeats):
278297
all_results.append(r)
279-
print(f" {r['phase']:30s} {r['best_s']:.4f}s (median {r['median_s']:.4f}s)")
298+
print(
299+
f" {r['phase']:30s} {r['best_s']:.4f}s (median {r['median_s']:.4f}s)"
300+
)
280301
gc.collect()
281302

282303
if model_type in ("synthetic", "all"):
283304
sizes = [20, 50] if quick else SYNTHETIC_SIZES
284305
for n in sizes:
285-
print(f"\n{'='*60}")
286-
print(f"Synthetic N={n} ({2*n*n} vars, {2*n*n} cons)")
287-
print(f"{'='*60}")
306+
print(f"\n{'=' * 60}")
307+
print(f"Synthetic N={n} ({2 * n * n} vars, {2 * n * n} cons)")
308+
print(f"{'=' * 60}")
288309
model = build_synthetic(n)
289310
n_vars = 2 * n * n
290311
n_cons = 2 * n * n
291312

292313
for r in benchmark_model(model, repeats):
293-
r.update(model_type="synthetic", size=n,
294-
n_vars=n_vars, n_cons=n_cons)
314+
r.update(model_type="synthetic", size=n, n_vars=n_vars, n_cons=n_cons)
295315
all_results.append(r)
296-
print(f" {r['phase']:20s} {r['best_s']:.4f}s (median {r['median_s']:.4f}s)")
316+
print(
317+
f" {r['phase']:20s} {r['best_s']:.4f}s (median {r['median_s']:.4f}s)"
318+
)
297319

298320
del model
299321
gc.collect()
@@ -305,7 +327,9 @@ def format_comparison(before: list[dict], after: list[dict]) -> str:
305327
"""Format a before/after comparison table."""
306328
df_b = pd.DataFrame(before).set_index(["model_type", "size", "phase"])
307329
df_a = pd.DataFrame(after).set_index(["model_type", "size", "phase"])
308-
merged = df_b[["best_s"]].join(df_a[["best_s"]], lsuffix="_before", rsuffix="_after")
330+
merged = df_b[["best_s"]].join(
331+
df_a[["best_s"]], lsuffix="_before", rsuffix="_after"
332+
)
309333
merged["speedup"] = merged["best_s_before"] / merged["best_s_after"]
310334
lines = [
311335
f"{'Model':>10s} {'Size':>6s} {'Phase':>20s} {'Before':>8s} {'After':>8s} {'Speedup':>8s}",
@@ -324,24 +348,34 @@ def format_comparison(before: list[dict], after: list[dict]) -> str:
324348
# CLI
325349
# ---------------------------------------------------------------------------
326350

351+
327352
def main():
328353
parser = argparse.ArgumentParser(
329354
description="Benchmark linopy matrix generation (PRs #616–#619)"
330355
)
331356
parser.add_argument(
332-
"--model", choices=["scigrid", "synthetic", "all"], default="all",
357+
"--model",
358+
choices=["scigrid", "synthetic", "all"],
359+
default="all",
333360
help="Model type to benchmark (default: all)",
334361
)
335-
parser.add_argument("--quick", action="store_true", help="Quick mode (smallest sizes only)")
336-
parser.add_argument("--repeats", type=int, default=3, help="Timing repeats per phase (default: 3)")
362+
parser.add_argument(
363+
"--quick", action="store_true", help="Quick mode (smallest sizes only)"
364+
)
365+
parser.add_argument(
366+
"--repeats", type=int, default=3, help="Timing repeats per phase (default: 3)"
367+
)
337368
parser.add_argument("-o", "--output", type=str, help="Save results to JSON file")
338369
parser.add_argument("--label", type=str, default="", help="Label for this run")
339370
parser.add_argument(
340-
"--compare", nargs=2, metavar=("BEFORE", "AFTER"),
371+
"--compare",
372+
nargs=2,
373+
metavar=("BEFORE", "AFTER"),
341374
help="Compare two JSON result files instead of running benchmarks",
342375
)
343376
parser.add_argument(
344-
"--include-solve", action="store_true",
377+
"--include-solve",
378+
action="store_true",
345379
help="Also benchmark solution unpacking (PR #619); requires HiGHS solver",
346380
)
347381
args = parser.parse_args()
@@ -352,9 +386,11 @@ def main():
352386
print(format_comparison(before, after))
353387
return
354388

355-
print(f"linopy matrix generation benchmark")
356-
print(f"Python {sys.version.split()[0]}, numpy {np.__version__}, "
357-
f"{platform.machine()}, {platform.system()}")
389+
print("linopy matrix generation benchmark")
390+
print(
391+
f"Python {sys.version.split()[0]}, numpy {np.__version__}, "
392+
f"{platform.machine()}, {platform.system()}"
393+
)
358394

359395
results = run_benchmarks(args.model, args.quick, args.repeats, args.include_solve)
360396

@@ -370,9 +406,9 @@ def main():
370406
print(f"\nResults saved to {args.output}")
371407

372408
# Summary table
373-
print(f"\n{'='*60}")
409+
print(f"\n{'=' * 60}")
374410
print("Summary (best times)")
375-
print(f"{'='*60}")
411+
print(f"{'=' * 60}")
376412
df = pd.DataFrame(results)
377413
for (mtype, size), group in df.groupby(["model_type", "size"]):
378414
n_vars = group.iloc[0]["n_vars"]

0 commit comments

Comments
 (0)