Skip to content

Commit 5e26346

Browse files
FBumannclaudeFabianHofmann
authored
perf: direct CSR-to-LP writer for frozen constraints (#631)
* perf: direct CSR-to-LP writer for frozen constraints Override Constraint.to_polars() to expand CSR data directly into a polars DataFrame, bypassing the expensive mutable() → xarray Dataset reconstruction. Also override iterate_slices() to yield CSR row-batches instead of relying on xarray's isel(). Move eliminate_zeros() to freeze time (from_mutable) so the cleanup happens once rather than on every to_polars() call. LP write is now 20-40% faster than master across all benchmark models. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * fix: handle mixed per-row signs in CSR-to-LP writer When _sign is a numpy array (per-row signs from from_rule with mixed <=/>=/= constraints), expand it per-nonzero via _sign[rows] instead of using pl.lit() which only works for scalar strings. Also slice _sign in iterate_slices when it's an array. Add test for frozen mixed-sign constraint LP output equivalence. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> --------- Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Co-authored-by: Fabian Hofmann <fab.hof@gmx.de>
1 parent 2ab5ed4 commit 5e26346

2 files changed

Lines changed: 128 additions & 2 deletions

File tree

linopy/constraints.py

Lines changed: 68 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -938,8 +938,73 @@ def mutable(self) -> Constraint:
938938
return Constraint(self.data, self._model, self._name)
939939

940940
def to_polars(self) -> pl.DataFrame:
941-
"""Convert to polars DataFrame — delegates to mutable()."""
942-
return self.mutable().to_polars()
941+
"""Convert frozen constraint to polars DataFrame directly from CSR."""
942+
csr = self._csr
943+
sign_dtype = pl.Enum(["=", "<=", ">="])
944+
if csr.nnz == 0:
945+
return pl.DataFrame(
946+
schema={
947+
"labels": pl.Int64,
948+
"coeffs": pl.Float64,
949+
"vars": pl.Int64,
950+
"sign": sign_dtype,
951+
"rhs": pl.Float64,
952+
}
953+
)
954+
955+
rows = np.repeat(np.arange(csr.shape[0]), np.diff(csr.indptr))
956+
vlabels = self._model.variables.label_index.vlabels
957+
958+
data: dict[str, Any] = {
959+
"labels": self._con_labels[rows],
960+
"coeffs": csr.data,
961+
"vars": vlabels[csr.indices],
962+
"rhs": self._rhs[rows],
963+
}
964+
if isinstance(self._sign, str):
965+
data["sign"] = pl.Series(
966+
"sign", [self._sign], dtype=sign_dtype
967+
).new_from_index(0, len(rows))
968+
else:
969+
data["sign"] = pl.Series("sign", self._sign[rows], dtype=sign_dtype)
970+
return pl.DataFrame(data)[["labels", "coeffs", "vars", "sign", "rhs"]]
971+
972+
def iterate_slices(
973+
self,
974+
slice_size: int | None = 2_000_000,
975+
slice_dims: list | None = None,
976+
) -> Iterator[CSRConstraint]:
977+
"""Yield row-batched sub-Constraints without Dataset reconstruction."""
978+
nnz = self._csr.nnz
979+
if slice_size is None or nnz <= slice_size:
980+
yield self
981+
return
982+
983+
n = self._csr.shape[0]
984+
cumulative = np.cumsum(np.diff(self._csr.indptr))
985+
batch_start = 0
986+
for batch_end_nnz in range(slice_size, nnz + slice_size, slice_size):
987+
batch_end = int(np.searchsorted(cumulative, batch_end_nnz, side="right"))
988+
batch_end = max(batch_end, batch_start + 1)
989+
if batch_end >= n:
990+
batch_end = n
991+
sign = (
992+
self._sign
993+
if isinstance(self._sign, str)
994+
else self._sign[batch_start:batch_end]
995+
)
996+
yield CSRConstraint(
997+
csr=self._csr[batch_start:batch_end],
998+
con_labels=self._con_labels[batch_start:batch_end],
999+
rhs=self._rhs[batch_start:batch_end],
1000+
sign=sign,
1001+
coords=self._coords,
1002+
model=self._model,
1003+
name=self._name,
1004+
)
1005+
batch_start = batch_end
1006+
if batch_start >= n:
1007+
break
9431008

9441009
@classmethod
9451010
def from_mutable(
@@ -958,6 +1023,7 @@ def from_mutable(
9581023
"""
9591024
label_index = con.model.variables.label_index
9601025
csr, con_labels = con.to_matrix(label_index)
1026+
csr.eliminate_zeros()
9611027
coords = [con.indexes[d] for d in con.coord_dims]
9621028
# Build active_mask aligned with con_labels (rows in csr)
9631029
# Use same filter as to_matrix: label != -1 AND at least one var != -1

test/test_io.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -447,3 +447,63 @@ def test_to_file_lp_mixed_sign_constraints(tmp_path: Path) -> None:
447447
assert "<=" in content
448448
assert ">=" in content
449449
assert "=" in content
450+
451+
452+
def test_to_file_lp_frozen_vs_mutable(tmp_path: Path) -> None:
453+
"""Test that frozen and mutable constraints produce identical LP output."""
454+
m_frozen = Model()
455+
N = np.arange(5)
456+
x = m_frozen.add_variables(coords=[N], name="x")
457+
y = m_frozen.add_variables(coords=[N], name="y")
458+
m_frozen.add_constraints(x + y <= 10, name="upper")
459+
m_frozen.add_constraints(x >= 1, name="lower")
460+
m_frozen.add_constraints(2 * x + y == 8, name="eq")
461+
m_frozen.add_objective(x.sum() + 2 * y.sum())
462+
463+
m_mutable = Model()
464+
x2 = m_mutable.add_variables(coords=[N], name="x")
465+
y2 = m_mutable.add_variables(coords=[N], name="y")
466+
m_mutable.add_constraints(x2 + y2 <= 10, name="upper", freeze=False)
467+
m_mutable.add_constraints(x2 >= 1, name="lower", freeze=False)
468+
m_mutable.add_constraints(2 * x2 + y2 == 8, name="eq", freeze=False)
469+
m_mutable.add_objective(x2.sum() + 2 * y2.sum())
470+
471+
fn_frozen = tmp_path / "frozen.lp"
472+
fn_mutable = tmp_path / "mutable.lp"
473+
m_frozen.to_file(fn_frozen)
474+
m_mutable.to_file(fn_mutable)
475+
476+
assert fn_frozen.read_text() == fn_mutable.read_text()
477+
478+
479+
def test_to_file_lp_frozen_mixed_sign(tmp_path: Path) -> None:
480+
"""Test LP writing for frozen constraint with per-row signs."""
481+
m_frozen = Model()
482+
N = pd.RangeIndex(4, name="i")
483+
x = m_frozen.add_variables(coords=[N], name="x")
484+
485+
def bound(m: Model, i: int) -> object:
486+
if i % 2:
487+
return x.at[i] >= i
488+
return x.at[i] <= 10
489+
490+
m_frozen.add_constraints(bound, coords=[N], name="mixed", freeze=True)
491+
m_frozen.add_objective(x.sum())
492+
493+
m_mutable = Model()
494+
x2 = m_mutable.add_variables(coords=[N], name="x")
495+
496+
def bound2(m: Model, i: int) -> object:
497+
if i % 2:
498+
return x2.at[i] >= i
499+
return x2.at[i] <= 10
500+
501+
m_mutable.add_constraints(bound2, coords=[N], name="mixed", freeze=False)
502+
m_mutable.add_objective(x2.sum())
503+
504+
fn_frozen = tmp_path / "frozen_mixed.lp"
505+
fn_mutable = tmp_path / "mutable_mixed.lp"
506+
m_frozen.to_file(fn_frozen)
507+
m_mutable.to_file(fn_mutable)
508+
509+
assert fn_frozen.read_text() == fn_mutable.read_text()

0 commit comments

Comments
 (0)