Skip to content

Commit a293b64

Browse files
committed
perf: Add maybe_group_terms_polars() helper in common.py that checks for duplicate (labels, vars) pairs before calling group_terms_polars. Use it in both Constraint.to_polars() and LinearExpression.to_polars() to avoid expensive group_by when terms already reference distinct variables
1 parent 68f1adc commit a293b64

3 files changed

Lines changed: 23 additions & 9 deletions

File tree

linopy/common.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -449,6 +449,25 @@ def group_terms_polars(df: pl.DataFrame) -> pl.DataFrame:
449449
return df
450450

451451

452+
def maybe_group_terms_polars(df: pl.DataFrame) -> pl.DataFrame:
453+
"""
454+
Group terms only if there are duplicate (labels, vars) pairs.
455+
456+
This avoids the expensive group_by operation when terms already
457+
reference distinct variables (e.g. ``x - y`` has ``_term=2`` but
458+
no duplicates). When skipping, columns are reordered to match the
459+
output of ``group_terms_polars``.
460+
"""
461+
varcols = [c for c in df.columns if c.startswith("vars")]
462+
keys = [c for c in ["labels"] + varcols if c in df.columns]
463+
key_count = df.select(pl.struct(keys).n_unique()).item()
464+
if key_count < df.height:
465+
return group_terms_polars(df)
466+
# Match column order of group_terms (group-by keys, coeffs, rest)
467+
rest = [c for c in df.columns if c not in keys and c != "coeffs"]
468+
return df.select(keys + ["coeffs"] + rest)
469+
470+
452471
def save_join(*dataarrays: DataArray, integer_dtype: bool = False) -> Dataset:
453472
"""
454473
Join multiple xarray Dataarray's to a Dataset and warn if coordinates are not equal.

linopy/constraints.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,9 @@
4040
generate_indices_for_printout,
4141
get_dims_with_index_levels,
4242
get_label_position,
43-
group_terms_polars,
4443
has_optimized_model,
4544
iterate_slices,
45+
maybe_group_terms_polars,
4646
maybe_replace_signs,
4747
print_coord,
4848
print_single_constraint,
@@ -622,7 +622,7 @@ def to_polars(self) -> pl.DataFrame:
622622

623623
long = filter_nulls_polars(long)
624624
if ds.sizes.get("_term", 1) > 1:
625-
long = group_terms_polars(long)
625+
long = maybe_group_terms_polars(long)
626626
check_has_nulls_polars(long, name=f"{self.type} {self.name}")
627627

628628
# Build short DataFrame (labels, rhs, sign) without xarray broadcast.

linopy/expressions.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
has_optimized_model,
6161
is_constant,
6262
iterate_slices,
63+
maybe_group_terms_polars,
6364
print_coord,
6465
print_single_expression,
6566
to_dataframe,
@@ -1463,13 +1464,7 @@ def to_polars(self) -> pl.DataFrame:
14631464

14641465
df = to_polars(self.data)
14651466
df = filter_nulls_polars(df)
1466-
if df["vars"].n_unique() < df.height:
1467-
df = group_terms_polars(df)
1468-
else:
1469-
# Match column order of group_terms (group-by keys, coeffs, rest)
1470-
varcols = [c for c in df.columns if c.startswith("vars")]
1471-
rest = [c for c in df.columns if c not in varcols and c != "coeffs"]
1472-
df = df.select(varcols + ["coeffs"] + rest)
1467+
df = maybe_group_terms_polars(df)
14731468
check_has_nulls_polars(df, name=self.type)
14741469
return df
14751470

0 commit comments

Comments
 (0)