Skip to content

Commit bc3b49e

Browse files
committed
Update benchmark: add --include-solve flag for PR #619 solution-unpack benchmark
1 parent 8290812 commit bc3b49e

File tree

1 file changed

+109
-11
lines changed

1 file changed

+109
-11
lines changed

benchmark/scripts/benchmark_matrix_gen.py

Lines changed: 109 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#!/usr/bin/env python3
22
"""
3-
Benchmark script for linopy matrix generation performance.
3+
Benchmark script for linopy matrix generation and solution-unpacking performance.
44
55
Covers the code paths optimised by PRs #616–#619:
66
- #616 cached_property on MatrixAccessor (flat_vars / flat_cons)
@@ -11,13 +11,16 @@
1111
Usage
1212
-----
1313
# Quick run (24 snapshots only):
14-
python dev-scripts/benchmark_matrix_gen.py --quick
14+
python benchmark/scripts/benchmark_matrix_gen.py --quick
1515
16-
# Full sweep with JSON output:
17-
python dev-scripts/benchmark_matrix_gen.py -o results.json --label "after-PR-616"
16+
# Full matrix-generation sweep with JSON output:
17+
python benchmark/scripts/benchmark_matrix_gen.py -o results.json --label "after-PR-616"
18+
19+
# Include solution-unpacking benchmark (requires HiGHS solver, #619):
20+
python benchmark/scripts/benchmark_matrix_gen.py --include-solve -o results.json
1821
1922
# Compare two runs:
20-
python dev-scripts/benchmark_matrix_gen.py --compare before.json after.json
23+
python benchmark/scripts/benchmark_matrix_gen.py --compare before.json after.json
2124
"""
2225

2326
from __future__ import annotations
@@ -38,19 +41,17 @@
3841
# Model builders
3942
# ---------------------------------------------------------------------------
4043

41-
def build_scigrid(n_snapshots: int):
42-
"""Return a linopy Model from PyPSA SciGrid-DE with extended snapshots."""
44+
def build_scigrid_network(n_snapshots: int):
45+
"""Return a PyPSA Network (SciGrid-DE) with extended snapshots, without building the model."""
4346
import pypsa
4447

4548
n = pypsa.examples.scigrid_de()
4649
orig_snapshots = n.snapshots
4750
orig_len = len(orig_snapshots)
4851

49-
# Create unique snapshots by extending from the first timestamp
5052
new_snapshots = pd.date_range(orig_snapshots[0], periods=n_snapshots, freq="h")
5153
n.set_snapshots(new_snapshots)
5254

53-
# Tile time-varying data to fill extended snapshots
5455
for component_t in (n.generators_t, n.loads_t, n.storage_units_t):
5556
for attr in list(component_t):
5657
df = getattr(component_t, attr)
@@ -60,7 +61,12 @@ def build_scigrid(n_snapshots: int):
6061
setattr(component_t, attr, pd.DataFrame(
6162
tiled, index=new_snapshots, columns=df.columns,
6263
))
64+
return n
6365

66+
67+
def build_scigrid(n_snapshots: int):
68+
"""Return a linopy Model from PyPSA SciGrid-DE with extended snapshots."""
69+
n = build_scigrid_network(n_snapshots)
6470
n.optimize.create_model(include_objective_constant=False)
6571
return n.model
6672

@@ -152,6 +158,82 @@ def do_full():
152158
return results
153159

154160

161+
# ---------------------------------------------------------------------------
162+
# Solution-unpacking benchmark (#619)
163+
# ---------------------------------------------------------------------------
164+
165+
def benchmark_solution_unpack(n_snapshots: int, repeats: int = 3) -> list[dict]:
166+
"""
167+
Benchmark the solution-assignment loop in Model.solve (PR #619).
168+
169+
Strategy: solve once with HiGHS to get a real solution vector, then
170+
re-run only the assignment loop (sol[idx] → var.solution) repeatedly
171+
without re-solving, isolating the unpacking cost from solver time.
172+
"""
173+
import xarray as xr
174+
175+
n = build_scigrid_network(n_snapshots)
176+
n.optimize.create_model(include_objective_constant=False)
177+
model = n.model
178+
179+
# Solve once to populate the raw solution
180+
status, _ = model.solve(solver_name="highs", io_api="direct")
181+
if status != "ok":
182+
print(f" WARNING: solve failed ({status}), skipping solution-unpack benchmark")
183+
return []
184+
185+
# Reconstruct the raw solution Series (as returned by the solver):
186+
# a float-indexed Series mapping variable label → solution value.
187+
nan = float("nan")
188+
parts = []
189+
for name, var in model.variables.items():
190+
if var.solution is None:
191+
continue
192+
labels = np.ravel(var.labels)
193+
values = np.ravel(var.solution.values)
194+
parts.append(pd.Series(values, index=labels.astype(float)))
195+
if not parts:
196+
print(" WARNING: no solution found on variables, was solve successful?")
197+
return []
198+
sol_series = pd.concat(parts).drop_duplicates()
199+
sol_series.loc[-1] = nan
200+
201+
n_vars = sum(np.ravel(model.variables[name].labels).size for name in model.variables)
202+
results = []
203+
204+
# ----- Old path (pandas label-based, pre-#619) -----
205+
def unpack_pandas():
206+
for name, var in model.variables.items():
207+
idx = np.ravel(var.labels).astype(float)
208+
try:
209+
vals = sol_series[idx].values.reshape(var.labels.shape)
210+
except KeyError:
211+
vals = sol_series.reindex(idx).values.reshape(var.labels.shape)
212+
var.solution = xr.DataArray(vals, var.coords)
213+
214+
results.append(time_phase(unpack_pandas, "unpack_pandas (before)", repeats))
215+
216+
# ----- New path (numpy dense array, #619) -----
217+
def unpack_numpy():
218+
sol_max_idx = int(max(sol_series.index.max(), 0))
219+
sol_arr = np.full(sol_max_idx + 1, nan)
220+
mask = sol_series.index >= 0
221+
valid = sol_series.index[mask].astype(int)
222+
sol_arr[valid] = sol_series.values[mask]
223+
for name, var in model.variables.items():
224+
idx = np.ravel(var.labels)
225+
safe_idx = np.clip(idx, 0, sol_max_idx)
226+
vals = sol_arr[safe_idx]
227+
vals[idx < 0] = nan
228+
var.solution = xr.DataArray(vals.reshape(var.labels.shape), var.coords)
229+
230+
results.append(time_phase(unpack_numpy, "unpack_numpy (after)", repeats))
231+
232+
for r in results:
233+
r.update(model_type="scigrid_solve", size=n_snapshots, n_vars=n_vars, n_cons=0)
234+
return results
235+
236+
155237
# ---------------------------------------------------------------------------
156238
# Runners
157239
# ---------------------------------------------------------------------------
@@ -161,7 +243,7 @@ def do_full():
161243
SYNTHETIC_SIZES = [20, 50, 100, 200]
162244

163245

164-
def run_benchmarks(model_type: str, quick: bool, repeats: int) -> list[dict]:
246+
def run_benchmarks(model_type: str, quick: bool, repeats: int, include_solve: bool = False) -> list[dict]:
165247
"""Run benchmarks across problem sizes, return flat list of results."""
166248
all_results = []
167249

@@ -185,6 +267,18 @@ def run_benchmarks(model_type: str, quick: bool, repeats: int) -> list[dict]:
185267
del model
186268
gc.collect()
187269

270+
if include_solve:
271+
# Solution-unpacking benchmark for PR #619 (SciGrid-DE only, small sizes)
272+
solve_sizes = QUICK_SNAPSHOTS if quick else [24, 100]
273+
for n_snap in solve_sizes:
274+
print(f"\n{'='*60}")
275+
print(f"SciGrid-DE solve + unpack {n_snap} snapshots (#619)")
276+
print(f"{'='*60}")
277+
for r in benchmark_solution_unpack(n_snap, repeats):
278+
all_results.append(r)
279+
print(f" {r['phase']:30s} {r['best_s']:.4f}s (median {r['median_s']:.4f}s)")
280+
gc.collect()
281+
188282
if model_type in ("synthetic", "all"):
189283
sizes = [20, 50] if quick else SYNTHETIC_SIZES
190284
for n in sizes:
@@ -246,6 +340,10 @@ def main():
246340
"--compare", nargs=2, metavar=("BEFORE", "AFTER"),
247341
help="Compare two JSON result files instead of running benchmarks",
248342
)
343+
parser.add_argument(
344+
"--include-solve", action="store_true",
345+
help="Also benchmark solution unpacking (PR #619); requires HiGHS solver",
346+
)
249347
args = parser.parse_args()
250348

251349
if args.compare:
@@ -258,7 +356,7 @@ def main():
258356
print(f"Python {sys.version.split()[0]}, numpy {np.__version__}, "
259357
f"{platform.machine()}, {platform.system()}")
260358

261-
results = run_benchmarks(args.model, args.quick, args.repeats)
359+
results = run_benchmarks(args.model, args.quick, args.repeats, args.include_solve)
262360

263361
if args.output:
264362
out = {

0 commit comments

Comments
 (0)