Skip to content

Commit 3373a1b

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 8290812 commit 3373a1b

1 file changed

Lines changed: 55 additions & 27 deletions

File tree

benchmark/scripts/benchmark_matrix_gen.py

Lines changed: 55 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,11 @@
3333
import numpy as np
3434
import pandas as pd
3535

36-
3736
# ---------------------------------------------------------------------------
3837
# Model builders
3938
# ---------------------------------------------------------------------------
4039

40+
4141
def build_scigrid(n_snapshots: int):
4242
"""Return a linopy Model from PyPSA SciGrid-DE with extended snapshots."""
4343
import pypsa
@@ -57,9 +57,15 @@ def build_scigrid(n_snapshots: int):
5757
if df is not None and not df.empty:
5858
tiles = int(np.ceil(n_snapshots / orig_len)) + 1
5959
tiled = np.tile(df.values, (tiles, 1))[:n_snapshots]
60-
setattr(component_t, attr, pd.DataFrame(
61-
tiled, index=new_snapshots, columns=df.columns,
62-
))
60+
setattr(
61+
component_t,
62+
attr,
63+
pd.DataFrame(
64+
tiled,
65+
index=new_snapshots,
66+
columns=df.columns,
67+
),
68+
)
6369

6470
n.optimize.create_model(include_objective_constant=False)
6571
return n.model
@@ -83,6 +89,7 @@ def build_synthetic(n: int):
8389
# Benchmark phases
8490
# ---------------------------------------------------------------------------
8591

92+
8693
def time_phase(func, label: str, repeats: int = 3) -> dict:
8794
"""Time a callable, return best-of-N result."""
8895
times = []
@@ -95,8 +102,12 @@ def time_phase(func, label: str, repeats: int = 3) -> dict:
95102
gc.enable()
96103
times.append(elapsed)
97104
del result
98-
return {"phase": label, "best_s": min(times), "median_s": sorted(times)[len(times) // 2],
99-
"times": times}
105+
return {
106+
"phase": label,
107+
"best_s": min(times),
108+
"median_s": sorted(times)[len(times) // 2],
109+
"times": times,
110+
}
100111

101112

102113
def benchmark_model(model, repeats: int = 3) -> list[dict]:
@@ -168,38 +179,42 @@ def run_benchmarks(model_type: str, quick: bool, repeats: int) -> list[dict]:
168179
if model_type in ("scigrid", "all"):
169180
sizes = QUICK_SNAPSHOTS if quick else FULL_SNAPSHOTS
170181
for n_snap in sizes:
171-
print(f"\n{'='*60}")
182+
print(f"\n{'=' * 60}")
172183
print(f"SciGrid-DE {n_snap} snapshots")
173-
print(f"{'='*60}")
184+
print(f"{'=' * 60}")
174185
model = build_scigrid(n_snap)
175186
n_vars = len(model.variables.flat)
176187
n_cons = len(model.constraints.flat)
177188
print(f" {n_vars:,} variables, {n_cons:,} constraints")
178189

179190
for r in benchmark_model(model, repeats):
180-
r.update(model_type="scigrid", size=n_snap,
181-
n_vars=n_vars, n_cons=n_cons)
191+
r.update(
192+
model_type="scigrid", size=n_snap, n_vars=n_vars, n_cons=n_cons
193+
)
182194
all_results.append(r)
183-
print(f" {r['phase']:20s} {r['best_s']:.4f}s (median {r['median_s']:.4f}s)")
195+
print(
196+
f" {r['phase']:20s} {r['best_s']:.4f}s (median {r['median_s']:.4f}s)"
197+
)
184198

185199
del model
186200
gc.collect()
187201

188202
if model_type in ("synthetic", "all"):
189203
sizes = [20, 50] if quick else SYNTHETIC_SIZES
190204
for n in sizes:
191-
print(f"\n{'='*60}")
192-
print(f"Synthetic N={n} ({2*n*n} vars, {2*n*n} cons)")
193-
print(f"{'='*60}")
205+
print(f"\n{'=' * 60}")
206+
print(f"Synthetic N={n} ({2 * n * n} vars, {2 * n * n} cons)")
207+
print(f"{'=' * 60}")
194208
model = build_synthetic(n)
195209
n_vars = 2 * n * n
196210
n_cons = 2 * n * n
197211

198212
for r in benchmark_model(model, repeats):
199-
r.update(model_type="synthetic", size=n,
200-
n_vars=n_vars, n_cons=n_cons)
213+
r.update(model_type="synthetic", size=n, n_vars=n_vars, n_cons=n_cons)
201214
all_results.append(r)
202-
print(f" {r['phase']:20s} {r['best_s']:.4f}s (median {r['median_s']:.4f}s)")
215+
print(
216+
f" {r['phase']:20s} {r['best_s']:.4f}s (median {r['median_s']:.4f}s)"
217+
)
203218

204219
del model
205220
gc.collect()
@@ -211,7 +226,9 @@ def format_comparison(before: list[dict], after: list[dict]) -> str:
211226
"""Format a before/after comparison table."""
212227
df_b = pd.DataFrame(before).set_index(["model_type", "size", "phase"])
213228
df_a = pd.DataFrame(after).set_index(["model_type", "size", "phase"])
214-
merged = df_b[["best_s"]].join(df_a[["best_s"]], lsuffix="_before", rsuffix="_after")
229+
merged = df_b[["best_s"]].join(
230+
df_a[["best_s"]], lsuffix="_before", rsuffix="_after"
231+
)
215232
merged["speedup"] = merged["best_s_before"] / merged["best_s_after"]
216233
lines = [
217234
f"{'Model':>10s} {'Size':>6s} {'Phase':>20s} {'Before':>8s} {'After':>8s} {'Speedup':>8s}",
@@ -230,20 +247,29 @@ def format_comparison(before: list[dict], after: list[dict]) -> str:
230247
# CLI
231248
# ---------------------------------------------------------------------------
232249

250+
233251
def main():
234252
parser = argparse.ArgumentParser(
235253
description="Benchmark linopy matrix generation (PRs #616–#619)"
236254
)
237255
parser.add_argument(
238-
"--model", choices=["scigrid", "synthetic", "all"], default="all",
256+
"--model",
257+
choices=["scigrid", "synthetic", "all"],
258+
default="all",
239259
help="Model type to benchmark (default: all)",
240260
)
241-
parser.add_argument("--quick", action="store_true", help="Quick mode (smallest sizes only)")
242-
parser.add_argument("--repeats", type=int, default=3, help="Timing repeats per phase (default: 3)")
261+
parser.add_argument(
262+
"--quick", action="store_true", help="Quick mode (smallest sizes only)"
263+
)
264+
parser.add_argument(
265+
"--repeats", type=int, default=3, help="Timing repeats per phase (default: 3)"
266+
)
243267
parser.add_argument("-o", "--output", type=str, help="Save results to JSON file")
244268
parser.add_argument("--label", type=str, default="", help="Label for this run")
245269
parser.add_argument(
246-
"--compare", nargs=2, metavar=("BEFORE", "AFTER"),
270+
"--compare",
271+
nargs=2,
272+
metavar=("BEFORE", "AFTER"),
247273
help="Compare two JSON result files instead of running benchmarks",
248274
)
249275
args = parser.parse_args()
@@ -254,9 +280,11 @@ def main():
254280
print(format_comparison(before, after))
255281
return
256282

257-
print(f"linopy matrix generation benchmark")
258-
print(f"Python {sys.version.split()[0]}, numpy {np.__version__}, "
259-
f"{platform.machine()}, {platform.system()}")
283+
print("linopy matrix generation benchmark")
284+
print(
285+
f"Python {sys.version.split()[0]}, numpy {np.__version__}, "
286+
f"{platform.machine()}, {platform.system()}"
287+
)
260288

261289
results = run_benchmarks(args.model, args.quick, args.repeats)
262290

@@ -272,9 +300,9 @@ def main():
272300
print(f"\nResults saved to {args.output}")
273301

274302
# Summary table
275-
print(f"\n{'='*60}")
303+
print(f"\n{'=' * 60}")
276304
print("Summary (best times)")
277-
print(f"{'='*60}")
305+
print(f"{'=' * 60}")
278306
df = pd.DataFrame(results)
279307
for (mtype, size), group in df.groupby(["model_type", "size"]):
280308
n_vars = group.iloc[0]["n_vars"]

0 commit comments

Comments
 (0)