Skip to content

Commit 2bdb49b

Browse files
FBumannclaude
andcommitted
Fix mypy errors in test files: add type annotations and fix Hashable check
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent ae760b1 commit 2bdb49b

2 files changed

Lines changed: 29 additions & 24 deletions

File tree

test/test_algebraic_properties.py

Lines changed: 28 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -37,56 +37,59 @@
3737
a * 0 == 0 multiplication by zero
3838
"""
3939

40+
from __future__ import annotations
41+
4042
import numpy as np
4143
import pandas as pd
4244
import pytest
4345
import xarray as xr
4446

4547
from linopy import Model
4648
from linopy.expressions import LinearExpression
49+
from linopy.variables import Variable
4750

4851

4952
@pytest.fixture
50-
def m():
53+
def m() -> Model:
5154
return Model()
5255

5356

5457
@pytest.fixture
55-
def time():
58+
def time() -> pd.RangeIndex:
5659
return pd.RangeIndex(3, name="time")
5760

5861

5962
@pytest.fixture
60-
def tech():
63+
def tech() -> pd.Index:
6164
return pd.Index(["solar", "wind"], name="tech")
6265

6366

6467
@pytest.fixture
65-
def x(m, time):
68+
def x(m: Model, time: pd.RangeIndex) -> Variable:
6669
"""Variable with dims [time]."""
6770
return m.add_variables(lower=0, coords=[time], name="x")
6871

6972

7073
@pytest.fixture
71-
def y(m, time):
74+
def y(m: Model, time: pd.RangeIndex) -> Variable:
7275
"""Variable with dims [time]."""
7376
return m.add_variables(lower=0, coords=[time], name="y")
7477

7578

7679
@pytest.fixture
77-
def z(m, time):
80+
def z(m: Model, time: pd.RangeIndex) -> Variable:
7881
"""Variable with dims [time]."""
7982
return m.add_variables(lower=0, coords=[time], name="z")
8083

8184

8285
@pytest.fixture
83-
def g(m, time, tech):
86+
def g(m: Model, time: pd.RangeIndex, tech: pd.Index) -> Variable:
8487
"""Variable with dims [time, tech]."""
8588
return m.add_variables(lower=0, coords=[time, tech], name="g")
8689

8790

8891
@pytest.fixture
89-
def c(tech):
92+
def c(tech: pd.Index) -> xr.DataArray:
9093
"""Constant (DataArray) with dims [tech]."""
9194
return xr.DataArray([2.0, 3.0], dims=["tech"], coords={"tech": tech})
9295

@@ -95,7 +98,7 @@ def assert_linequal(a: LinearExpression, b: LinearExpression) -> None:
9598
"""Assert two linear expressions are algebraically equivalent."""
9699
assert set(a.dims) == set(b.dims), f"dims differ: {a.dims} vs {b.dims}"
97100
for dim in a.dims:
98-
if dim.startswith("_"):
101+
if isinstance(dim, str) and dim.startswith("_"):
99102
continue
100103
np.testing.assert_array_equal(
101104
sorted(a.coords[dim].values), sorted(b.coords[dim].values)
@@ -109,15 +112,15 @@ def assert_linequal(a: LinearExpression, b: LinearExpression) -> None:
109112

110113

111114
class TestCommutativity:
112-
def test_add_expr_expr(self, x, y):
115+
def test_add_expr_expr(self, x: Variable, y: Variable) -> None:
113116
"""X + y == y + x"""
114117
assert_linequal(x + y, y + x)
115118

116-
def test_mul_expr_constant(self, g, c):
119+
def test_mul_expr_constant(self, g: Variable, c: xr.DataArray) -> None:
117120
"""G * c == c * g"""
118121
assert_linequal(g * c, c * g)
119122

120-
def test_add_expr_constant(self, g, c):
123+
def test_add_expr_constant(self, g: Variable, c: xr.DataArray) -> None:
121124
"""G + c == c + g"""
122125
assert_linequal(g + c, c + g)
123126

@@ -128,11 +131,11 @@ def test_add_expr_constant(self, g, c):
128131

129132

130133
class TestAssociativity:
131-
def test_add_same_dims(self, x, y, z):
134+
def test_add_same_dims(self, x: Variable, y: Variable, z: Variable) -> None:
132135
"""(x + y) + z == x + (y + z)"""
133136
assert_linequal((x + y) + z, x + (y + z))
134137

135-
def test_add_with_constant(self, x, g, c):
138+
def test_add_with_constant(self, x: Variable, g: Variable, c: xr.DataArray) -> None:
136139
"""(x[A] + c[B]) + g[A,B] == x[A] + (c[B] + g[A,B])"""
137140
assert_linequal((x + c) + g, x + (c + g))
138141

@@ -143,15 +146,17 @@ def test_add_with_constant(self, x, g, c):
143146

144147

145148
class TestDistributivity:
146-
def test_scalar(self, x, y):
149+
def test_scalar(self, x: Variable, y: Variable) -> None:
147150
"""S * (x + y) == s*x + s*y"""
148151
assert_linequal(3 * (x + y), 3 * x + 3 * y)
149152

150-
def test_constant_subset_dims(self, g, c):
153+
def test_constant_subset_dims(self, g: Variable, c: xr.DataArray) -> None:
151154
"""c[B] * (g[A,B] + g[A,B]) == c*g + c*g"""
152155
assert_linequal(c * (g + g), c * g + c * g)
153156

154-
def test_constant_mixed_dims(self, x, g, c):
157+
def test_constant_mixed_dims(
158+
self, x: Variable, g: Variable, c: xr.DataArray
159+
) -> None:
155160
"""c[B] * (x[A] + g[A,B]) == c*x + c*g"""
156161
assert_linequal(c * (x + g), c * x + c * g)
157162

@@ -162,14 +167,14 @@ def test_constant_mixed_dims(self, x, g, c):
162167

163168

164169
class TestIdentity:
165-
def test_additive(self, x):
170+
def test_additive(self, x: Variable) -> None:
166171
"""X + 0 == x"""
167172
result = x + 0
168173
assert isinstance(result, LinearExpression)
169174
assert (result.const == 0).all()
170175
np.testing.assert_array_equal(result.coeffs.squeeze().values, [1, 1, 1])
171176

172-
def test_multiplicative(self, x):
177+
def test_multiplicative(self, x: Variable) -> None:
173178
"""X * 1 == x"""
174179
result = x * 1
175180
assert isinstance(result, LinearExpression)
@@ -182,15 +187,15 @@ def test_multiplicative(self, x):
182187

183188

184189
class TestNegation:
185-
def test_subtraction_is_add_negation(self, x, y):
190+
def test_subtraction_is_add_negation(self, x: Variable, y: Variable) -> None:
186191
"""X - y == x + (-y)"""
187192
assert_linequal(x - y, x + (-y))
188193

189-
def test_subtraction_definition(self, x, y):
194+
def test_subtraction_definition(self, x: Variable, y: Variable) -> None:
190195
"""X - y == x + (-1) * y"""
191196
assert_linequal(x - y, x + (-1) * y)
192197

193-
def test_double_negation(self, x):
198+
def test_double_negation(self, x: Variable) -> None:
194199
"""-(-x) has same coefficients as x"""
195200
result = -(-x)
196201
np.testing.assert_array_equal(
@@ -205,7 +210,7 @@ def test_double_negation(self, x):
205210

206211

207212
class TestZero:
208-
def test_multiplication_by_zero(self, x):
213+
def test_multiplication_by_zero(self, x: Variable) -> None:
209214
"""X * 0 has zero coefficients"""
210215
result = x * 0
211216
assert (result.coeffs == 0).all()

test/test_linear_expression.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1920,7 +1920,7 @@ def test_add_constant_join_override(self, a: Variable, c: Variable) -> None:
19201920
def test_add_same_coords_all_joins(self, a: Variable, c: Variable) -> None:
19211921
expr_a = 1 * a + 5
19221922
const = xr.DataArray([1, 2, 3], dims=["i"], coords={"i": [0, 1, 2]})
1923-
for join in ["override", "outer", "inner"]:
1923+
for join in ("override", "outer", "inner"):
19241924
result = expr_a.add(const, join=join)
19251925
assert list(result.coords["i"].values) == [0, 1, 2]
19261926
np.testing.assert_array_equal(result.const.values, [6, 7, 8])

0 commit comments

Comments
 (0)