Skip to content

Commit 0dbe488

Browse files
committed
fix: mypy
1 parent 3d4a815 commit 0dbe488

2 files changed

Lines changed: 34 additions & 0 deletions

File tree

test/test_constraint.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -437,6 +437,20 @@ def test_constraint_to_polars(c: linopy.constraints.Constraint) -> None:
437437
assert isinstance(c.to_polars(), pl.DataFrame)
438438

439439

440+
def test_constraint_to_polars_mixed_signs(m: Model, x: linopy.Variable) -> None:
441+
"""Test to_polars when a constraint has mixed sign values across dims."""
442+
# Create a constraint, then manually patch the sign to have mixed values
443+
m.add_constraints(x >= 0, name="mixed")
444+
con = m.constraints["mixed"]
445+
# Replace sign data with mixed signs across the first dimension
446+
n = con.data.sizes["first"]
447+
signs = np.array(["<=" if i % 2 == 0 else ">=" for i in range(n)])
448+
con.data["sign"] = xr.DataArray(signs, dims=con.data["sign"].dims)
449+
df = con.to_polars()
450+
assert isinstance(df, pl.DataFrame)
451+
assert set(df["sign"].to_list()) == {"<=", ">="}
452+
453+
440454
def test_constraint_assignment_with_anonymous_constraints(
441455
m: Model, x: linopy.Variable, y: linopy.Variable
442456
) -> None:

test/test_io.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -383,3 +383,23 @@ def test_to_file_lp_same_sign_constraints(tmp_path: Path) -> None:
383383
content = fn.read_text()
384384
assert "s.t." in content
385385
assert "<=" in content
386+
387+
388+
def test_to_file_lp_mixed_sign_constraints(tmp_path: Path) -> None:
389+
"""Test LP writing when constraints have different sign operators."""
390+
m = Model()
391+
N = np.arange(5)
392+
x = m.add_variables(coords=[N], name="x")
393+
# Mix of <= and >= constraints in the same container
394+
m.add_constraints(x <= 10, name="upper")
395+
m.add_constraints(x >= 1, name="lower")
396+
m.add_constraints(2 * x == 8, name="eq")
397+
m.add_objective(x.sum())
398+
399+
fn = tmp_path / "mixed_sign.lp"
400+
m.to_file(fn)
401+
content = fn.read_text()
402+
assert "s.t." in content
403+
assert "<=" in content
404+
assert ">=" in content
405+
assert "=" in content

0 commit comments

Comments
 (0)