Skip to content

Commit 3f52fef

Browse files
committed
test: add coverage for streaming fallback and maybe_group_terms_polars
1 parent 04c4bea commit 3f52fef

2 files changed

Lines changed: 65 additions & 1 deletion

File tree

test/test_common.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
get_dims_with_index_levels,
2424
is_constant,
2525
iterate_slices,
26+
maybe_group_terms_polars,
2627
)
2728
from linopy.testing import assert_linequal, assert_varequal
2829

@@ -737,3 +738,20 @@ def test_is_constant() -> None:
737738
]
738739
for cv in constant_values:
739740
assert is_constant(cv)
741+
742+
743+
def test_maybe_group_terms_polars_no_duplicates():
744+
"""Fast path: distinct (labels, vars) pairs skip group_by."""
745+
df = pl.DataFrame({"labels": [0, 0], "vars": [1, 2], "coeffs": [3.0, 4.0]})
746+
result = maybe_group_terms_polars(df)
747+
assert result.shape == (2, 3)
748+
assert result.columns == ["labels", "vars", "coeffs"]
749+
assert result["coeffs"].to_list() == [3.0, 4.0]
750+
751+
752+
def test_maybe_group_terms_polars_with_duplicates():
753+
"""Slow path: duplicate (labels, vars) pairs trigger group_by."""
754+
df = pl.DataFrame({"labels": [0, 0], "vars": [1, 1], "coeffs": [3.0, 4.0]})
755+
result = maybe_group_terms_polars(df)
756+
assert result.shape == (1, 3)
757+
assert result["coeffs"].to_list() == [7.0]

test/test_io.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import pickle
99
from pathlib import Path
10+
from unittest.mock import patch
1011

1112
import numpy as np
1213
import pandas as pd
@@ -15,7 +16,7 @@
1516
import xarray as xr
1617

1718
from linopy import LESS_EQUAL, Model, available_solvers, read_netcdf
18-
from linopy.io import signed_number
19+
from linopy.io import _format_and_write, signed_number
1920
from linopy.testing import assert_model_equal
2021

2122

@@ -336,3 +337,48 @@ def test_to_file_lp_with_negative_zero_coefficients(tmp_path: Path) -> None:
336337

337338
# Verify Gurobi can read it without errors
338339
gurobipy.read(str(fn))
340+
341+
342+
def test_format_and_write_streaming_fallback(tmp_path):
343+
"""Test that _format_and_write falls back to eager when streaming fails."""
344+
df = pl.DataFrame({"a": ["x", "y"], "b": ["1", "2"]})
345+
columns = [pl.col("a"), pl.lit(" "), pl.col("b")]
346+
347+
# Normal path
348+
fn1 = tmp_path / "normal.lp"
349+
with open(fn1, "wb") as f:
350+
_format_and_write(df, columns, f)
351+
content_normal = fn1.read_text()
352+
353+
# Force streaming to fail
354+
original_collect = pl.LazyFrame.collect
355+
356+
def failing_collect(self, *args, **kwargs):
357+
if kwargs.get("engine") == "streaming":
358+
raise RuntimeError("simulated streaming failure")
359+
return original_collect(self, *args, **kwargs)
360+
361+
fn2 = tmp_path / "fallback.lp"
362+
with patch.object(pl.LazyFrame, "collect", failing_collect):
363+
with open(fn2, "wb") as f:
364+
_format_and_write(df, columns, f)
365+
content_fallback = fn2.read_text()
366+
367+
assert content_normal == content_fallback
368+
369+
370+
def test_to_file_lp_same_sign_constraints(tmp_path):
371+
"""Test LP writing when all constraints have the same sign operator."""
372+
m = Model()
373+
N = np.arange(5)
374+
x = m.add_variables(coords=[N], name="x")
375+
# All constraints use <=
376+
m.add_constraints(x <= 10, name="upper")
377+
m.add_constraints(x <= 20, name="upper2")
378+
m.add_objective(x.sum())
379+
380+
fn = tmp_path / "same_sign.lp"
381+
m.to_file(fn)
382+
content = fn.read_text()
383+
assert "s.t." in content
384+
assert "<=" in content

0 commit comments

Comments
 (0)