Skip to content

Commit 472ecc9

Browse files
MaykThewessenclaudepre-commit-ci[bot]FBumannFabianHofmann
authored
perf: use numpy array lookup for solution unpacking (#619)
* perf: use numpy array lookup for solution unpacking Convert the primal/dual pandas Series to a dense numpy lookup array before the per-variable/per-constraint unpacking loop. This replaces pandas indexing (sol[idx].values) with direct numpy array indexing (sol_arr[idx]), avoiding pandas overhead per variable type. The loop over variable/constraint types still exists (needed to set each variable's .solution xr.DataArray), but the inner indexing operation is now pure numpy instead of pandas Series.__getitem__. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * Add reproducible benchmark script for PRs #616#619 Adds benchmark/scripts/benchmark_matrix_gen.py covering all four performance code paths: - #616 cached_property on MatrixAccessor (flat_vars / flat_cons) - #617 np.char.add label string concatenation - #618 single-step sparse matrix slicing - #619 numpy dense-array solution unpacking Reproduce with: python benchmark/scripts/benchmark_matrix_gen.py -o results.json python benchmark/scripts/benchmark_matrix_gen.py --include-solve # PR #619 path python benchmark/scripts/benchmark_matrix_gen.py --compare before.json after.json Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Delete benchmark/scripts/benchmark_matrix_gen.py * Replace pandas-based solution unpacking with numpy dense array lookup (2-6x faster) Extract series_to_lookup_array/lookup_vals helpers to linopy/common.py. Fix critical bug where out-of-range labels silently mapped to wrong values. --------- Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: FBumann <117816358+FBumann@users.noreply.github.com> Co-authored-by: Fabian Hofmann <fab.hof@gmx.de>
1 parent 43af227 commit 472ecc9

3 files changed

Lines changed: 136 additions & 15 deletions

File tree

linopy/common.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import numpy as np
1919
import pandas as pd
2020
import polars as pl
21-
from numpy import arange, signedinteger
21+
from numpy import arange, nan, signedinteger
2222
from xarray import DataArray, Dataset, apply_ufunc, broadcast
2323
from xarray import align as xr_align
2424
from xarray.core import dtypes, indexing
@@ -1393,3 +1393,51 @@ def is_constant(x: SideLike) -> bool:
13931393
"Expected a constant, variable, or expression on the constraint side, "
13941394
f"got {type(x)}."
13951395
)
1396+
1397+
1398+
def series_to_lookup_array(s: pd.Series) -> np.ndarray:
1399+
"""
1400+
Convert an integer-indexed Series to a dense numpy lookup array.
1401+
1402+
Non-negative indices are placed at their corresponding positions;
1403+
negative indices are ignored. Gaps are filled with NaN.
1404+
1405+
Parameters
1406+
----------
1407+
s : pd.Series
1408+
Series with an integer index.
1409+
1410+
Returns
1411+
-------
1412+
np.ndarray
1413+
Dense array of length ``max(index) + 1``.
1414+
"""
1415+
max_idx = max(int(s.index.max()), 0)
1416+
arr = np.full(max_idx + 1, nan)
1417+
mask = s.index >= 0
1418+
arr[s.index[mask]] = s.values[mask]
1419+
return arr
1420+
1421+
1422+
def lookup_vals(arr: np.ndarray, idx: np.ndarray) -> np.ndarray:
1423+
"""
1424+
Look up values from a dense array by integer labels.
1425+
1426+
Negative labels and labels beyond the array length map to NaN.
1427+
1428+
Parameters
1429+
----------
1430+
arr : np.ndarray
1431+
Dense lookup array (e.g. from :func:`series_to_lookup_array`).
1432+
idx : np.ndarray
1433+
Integer label indices.
1434+
1435+
Returns
1436+
-------
1437+
np.ndarray
1438+
Array of looked-up values with the same shape as *idx*.
1439+
"""
1440+
valid = (idx >= 0) & (idx < len(arr))
1441+
vals = np.full(idx.shape, nan)
1442+
vals[valid] = arr[idx[valid]]
1443+
return vals

linopy/model.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,10 @@
3131
assign_multiindex_safe,
3232
best_int,
3333
broadcast_mask,
34+
lookup_vals,
3435
maybe_replace_signs,
3536
replace_by_map,
37+
series_to_lookup_array,
3638
set_int_index,
3739
to_path,
3840
)
@@ -1591,26 +1593,24 @@ def solve(
15911593
sol = set_int_index(sol)
15921594
sol.loc[-1] = nan
15931595

1594-
for name, var in self.variables.items():
1595-
idx = np.ravel(var.labels)
1596-
try:
1597-
vals = sol[idx].values.reshape(var.labels.shape)
1598-
except KeyError:
1599-
vals = sol.reindex(idx).values.reshape(var.labels.shape)
1600-
var.solution = xr.DataArray(vals, var.coords)
1596+
sol_arr = series_to_lookup_array(sol)
1597+
1598+
for _, var in self.variables.items():
1599+
vals = lookup_vals(sol_arr, np.ravel(var.labels))
1600+
var.solution = xr.DataArray(vals.reshape(var.labels.shape), var.coords)
16011601

16021602
if not result.solution.dual.empty:
16031603
dual = result.solution.dual.copy()
16041604
dual = set_int_index(dual)
16051605
dual.loc[-1] = nan
16061606

1607-
for name, con in self.constraints.items():
1608-
idx = np.ravel(con.labels)
1609-
try:
1610-
vals = dual[idx].values.reshape(con.labels.shape)
1611-
except KeyError:
1612-
vals = dual.reindex(idx).values.reshape(con.labels.shape)
1613-
con.dual = xr.DataArray(vals, con.labels.coords)
1607+
dual_arr = series_to_lookup_array(dual)
1608+
1609+
for _, con in self.constraints.items():
1610+
vals = lookup_vals(dual_arr, np.ravel(con.labels))
1611+
con.dual = xr.DataArray(
1612+
vals.reshape(con.labels.shape), con.labels.coords
1613+
)
16141614

16151615
return result.status.status.value, result.status.termination_condition.value
16161616
finally:

test/test_solution_lookup.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
import numpy as np
2+
import pandas as pd
3+
from numpy import nan
4+
5+
from linopy.common import lookup_vals, series_to_lookup_array
6+
7+
8+
class TestSeriesToLookupArray:
9+
def test_basic(self) -> None:
10+
s = pd.Series([10.0, 20.0, 30.0], index=pd.Index([0, 1, 2]))
11+
arr = series_to_lookup_array(s)
12+
np.testing.assert_array_equal(arr, [10.0, 20.0, 30.0])
13+
14+
def test_with_negative_index(self) -> None:
15+
s = pd.Series([nan, 10.0, 20.0], index=pd.Index([-1, 0, 2]))
16+
arr = series_to_lookup_array(s)
17+
assert arr[0] == 10.0
18+
assert np.isnan(arr[1])
19+
assert arr[2] == 20.0
20+
21+
def test_sparse_index(self) -> None:
22+
s = pd.Series([5.0, 7.0], index=pd.Index([0, 100]))
23+
arr = series_to_lookup_array(s)
24+
assert len(arr) == 101
25+
assert arr[0] == 5.0
26+
assert arr[100] == 7.0
27+
assert np.isnan(arr[50])
28+
29+
def test_only_negative_index(self) -> None:
30+
s = pd.Series([nan], index=pd.Index([-1]))
31+
arr = series_to_lookup_array(s)
32+
assert len(arr) == 1
33+
assert np.isnan(arr[0])
34+
35+
36+
class TestLookupVals:
37+
def test_basic(self) -> None:
38+
arr = np.array([10.0, 20.0, 30.0])
39+
idx = np.array([0, 1, 2])
40+
result = lookup_vals(arr, idx)
41+
np.testing.assert_array_equal(result, [10.0, 20.0, 30.0])
42+
43+
def test_negative_labels_become_nan(self) -> None:
44+
arr = np.array([10.0, 20.0])
45+
idx = np.array([0, -1, 1, -1])
46+
result = lookup_vals(arr, idx)
47+
assert result[0] == 10.0
48+
assert np.isnan(result[1])
49+
assert result[2] == 20.0
50+
assert np.isnan(result[3])
51+
52+
def test_out_of_range_labels_become_nan(self) -> None:
53+
arr = np.array([10.0, 20.0])
54+
idx = np.array([0, 1, 999])
55+
result = lookup_vals(arr, idx)
56+
assert result[0] == 10.0
57+
assert result[1] == 20.0
58+
assert np.isnan(result[2])
59+
60+
def test_all_negative(self) -> None:
61+
arr = np.array([10.0])
62+
idx = np.array([-1, -1, -1])
63+
result = lookup_vals(arr, idx)
64+
assert all(np.isnan(result))
65+
66+
def test_no_mutation_of_source(self) -> None:
67+
arr = np.array([10.0, 20.0, 30.0])
68+
idx1 = np.array([-1, 1])
69+
idx2 = np.array([0, 2])
70+
lookup_vals(arr, idx1)
71+
result2 = lookup_vals(arr, idx2)
72+
np.testing.assert_array_equal(result2, [10.0, 30.0])
73+
np.testing.assert_array_equal(arr, [10.0, 20.0, 30.0])

0 commit comments

Comments
 (0)