Skip to content

Commit 11e6ce7

Browse files
FabianHofmannclaude
andcommitted
Fix mypy type checking issues
- Add type assertion in expressions.py to clarify pandas Series type - Keep necessary type: ignore[assignment] comments in test files where intentional type mismatches are tested - Remove Python 3.9 from classifiers in pyproject.toml (was part of staged changes) - Update model.py and expressions.py to use union types with | operator 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
1 parent 8a4821c commit 11e6ce7

8 files changed

Lines changed: 51 additions & 56 deletions

linopy/expressions.py

Lines changed: 29 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ def _exprwrap(
129129
def _expr_unwrap(
130130
maybe_expr: Any | LinearExpression | QuadraticExpression,
131131
) -> Any:
132-
if isinstance(maybe_expr, (LinearExpression, QuadraticExpression)):
132+
if isinstance(maybe_expr, LinearExpression | QuadraticExpression):
133133
return maybe_expr.data
134134

135135
return maybe_expr
@@ -249,6 +249,8 @@ def sum(self, use_fallback: bool = False, **kwargs: Any) -> LinearExpression:
249249
orig_group = group
250250
group = group.apply(tuple, axis=1).map(int_map)
251251

252+
# At this point, group is always a pandas Series
253+
assert isinstance(group, pd.Series)
252254
group_dim = group.index.name
253255

254256
arrays = [group, group.groupby(group).cumcount()]
@@ -529,13 +531,11 @@ def __div__(self: GenericExpression, other: SideLike) -> GenericExpression:
529531
try:
530532
if isinstance(
531533
other,
532-
(
533-
variables.Variable,
534-
variables.ScalarVariable,
535-
LinearExpression,
536-
ScalarLinearExpression,
537-
QuadraticExpression,
538-
),
534+
variables.Variable
535+
| variables.ScalarVariable
536+
| LinearExpression
537+
| ScalarLinearExpression
538+
| QuadraticExpression,
539539
):
540540
raise TypeError(
541541
"unsupported operand type(s) for /: "
@@ -930,7 +930,7 @@ def where(
930930
_other = FILL_VALUE
931931
else:
932932
_other = None
933-
elif isinstance(other, (int, float, DataArray)):
933+
elif isinstance(other, int | float | DataArray):
934934
_other = {**self._fill_value, "const": other}
935935
else:
936936
_other = _expr_unwrap(other)
@@ -970,7 +970,7 @@ def fillna(
970970
A new object with missing values filled with the given value.
971971
"""
972972
value = _expr_unwrap(value)
973-
if isinstance(value, (DataArray, np.floating, np.integer, int, float)):
973+
if isinstance(value, DataArray | np.floating | np.integer | int | float):
974974
value = {"const": value}
975975
return self.__class__(self.data.fillna(value), self.model)
976976

@@ -1362,10 +1362,10 @@ def __mul__(
13621362
return other.__rmul__(self)
13631363

13641364
try:
1365-
if isinstance(other, (variables.Variable, variables.ScalarVariable)):
1365+
if isinstance(other, variables.Variable | variables.ScalarVariable):
13661366
other = other.to_linexpr()
13671367

1368-
if isinstance(other, (LinearExpression, ScalarLinearExpression)):
1368+
if isinstance(other, LinearExpression | ScalarLinearExpression):
13691369
return self._multiply_by_linear_expression(other)
13701370
else:
13711371
return self._multiply_by_constant(other)
@@ -1403,7 +1403,7 @@ def __matmul__(
14031403
"""
14041404
Matrix multiplication with other, similar to xarray dot.
14051405
"""
1406-
if not isinstance(other, (LinearExpression, variables.Variable)):
1406+
if not isinstance(other, LinearExpression | variables.Variable):
14071407
other = as_dataarray(other, coords=self.coords, dims=self.coord_dims)
14081408

14091409
common_dims = list(set(self.coord_dims).intersection(other.dims))
@@ -1620,7 +1620,7 @@ def process_one(
16201620
# assume first element is coefficient and second is variable
16211621
c, v = t
16221622
if isinstance(v, variables.ScalarVariable):
1623-
if not isinstance(c, (int, float)):
1623+
if not isinstance(c, int | float):
16241624
raise TypeError(
16251625
"Expected int or float as coefficient of scalar variable (first element of tuple)."
16261626
)
@@ -1703,12 +1703,10 @@ def __mul__(self, other: SideLike) -> QuadraticExpression:
17031703
"""
17041704
if isinstance(
17051705
other,
1706-
(
1707-
BaseExpression,
1708-
ScalarLinearExpression,
1709-
variables.Variable,
1710-
variables.ScalarVariable,
1711-
),
1706+
BaseExpression
1707+
| ScalarLinearExpression
1708+
| variables.Variable
1709+
| variables.ScalarVariable,
17121710
):
17131711
raise TypeError(
17141712
"unsupported operand type(s) for *: "
@@ -1787,12 +1785,10 @@ def __matmul__(
17871785
"""
17881786
if isinstance(
17891787
other,
1790-
(
1791-
BaseExpression,
1792-
ScalarLinearExpression,
1793-
variables.Variable,
1794-
variables.ScalarVariable,
1795-
),
1788+
BaseExpression
1789+
| ScalarLinearExpression
1790+
| variables.Variable
1791+
| variables.ScalarVariable,
17961792
):
17971793
raise TypeError(
17981794
"Higher order non-linear expressions are not yet supported."
@@ -1915,9 +1911,9 @@ def as_expression(
19151911
ValueError
19161912
If object cannot be converted to LinearExpression.
19171913
"""
1918-
if isinstance(obj, (LinearExpression, QuadraticExpression)):
1914+
if isinstance(obj, LinearExpression | QuadraticExpression):
19191915
return obj
1920-
elif isinstance(obj, (variables.Variable, variables.ScalarVariable)):
1916+
elif isinstance(obj, variables.Variable | variables.ScalarVariable):
19211917
return obj.to_linexpr()
19221918
else:
19231919
try:
@@ -2134,7 +2130,7 @@ def __neg__(self) -> ScalarLinearExpression:
21342130
)
21352131

21362132
def __mul__(self, other: float | int) -> ScalarLinearExpression:
2137-
if not isinstance(other, (int, float, np.number)):
2133+
if not isinstance(other, int | float | np.number):
21382134
raise TypeError(
21392135
f"unsupported operand type(s) for *: {type(self)} and {type(other)}"
21402136
)
@@ -2147,7 +2143,7 @@ def __rmul__(self, other: int) -> ScalarLinearExpression:
21472143
return self.__mul__(other)
21482144

21492145
def __div__(self, other: float | int) -> ScalarLinearExpression:
2150-
if not isinstance(other, (int, float, np.number)):
2146+
if not isinstance(other, int | float | np.number):
21512147
raise TypeError(
21522148
f"unsupported operand type(s) for /: {type(self)} and {type(other)}"
21532149
)
@@ -2157,23 +2153,23 @@ def __truediv__(self, other: float | int) -> ScalarLinearExpression:
21572153
return self.__div__(other)
21582154

21592155
def __le__(self, other: int | float) -> AnonymousScalarConstraint:
2160-
if not isinstance(other, (int, float, np.number)):
2156+
if not isinstance(other, int | float | np.number):
21612157
raise TypeError(
21622158
f"unsupported operand type(s) for <=: {type(self)} and {type(other)}"
21632159
)
21642160

21652161
return constraints.AnonymousScalarConstraint(self, LESS_EQUAL, other)
21662162

21672163
def __ge__(self, other: int | float) -> AnonymousScalarConstraint:
2168-
if not isinstance(other, (int, float, np.number)):
2164+
if not isinstance(other, int | float | np.number):
21692165
raise TypeError(
21702166
f"unsupported operand type(s) for >=: {type(self)} and {type(other)}"
21712167
)
21722168

21732169
return constraints.AnonymousScalarConstraint(self, GREATER_EQUAL, other)
21742170

21752171
def __eq__(self, other: int | float) -> AnonymousScalarConstraint: # type: ignore
2176-
if not isinstance(other, (int, float, np.number)):
2172+
if not isinstance(other, int | float | np.number):
21772173
raise TypeError(
21782174
f"unsupported operand type(s) for ==: {type(self)} and {type(other)}"
21792175
)

linopy/model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,7 @@ def solver_dir(self) -> Path:
317317

318318
@solver_dir.setter
319319
def solver_dir(self, value: str | Path) -> None:
320-
if not isinstance(value, (str, Path)):
320+
if not isinstance(value, str | Path):
321321
raise TypeError("'solver_dir' must path-like.")
322322
self._solver_dir = Path(value)
323323

@@ -614,7 +614,7 @@ def add_constraints(
614614
if sign is None or rhs is None:
615615
raise ValueError(msg_sign_rhs_not_none)
616616
data = lhs.to_constraint(sign, rhs).data
617-
elif isinstance(lhs, (list, tuple)):
617+
elif isinstance(lhs, list | tuple):
618618
if sign is None or rhs is None:
619619
raise ValueError(msg_sign_rhs_none)
620620
data = self.linexpr(*lhs).to_constraint(sign, rhs).data
@@ -633,7 +633,7 @@ def add_constraints(
633633
if sign is not None or rhs is not None:
634634
raise ValueError(msg_sign_rhs_none)
635635
data = lhs.data
636-
elif isinstance(lhs, (Variable, ScalarVariable, ScalarLinearExpression)):
636+
elif isinstance(lhs, Variable | ScalarVariable | ScalarLinearExpression):
637637
if sign is None or rhs is None:
638638
raise ValueError(msg_sign_rhs_not_none)
639639
data = lhs.to_linexpr().to_constraint(sign, rhs).data

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ readme = "README.md"
1010
authors = [{ name = "Fabian Hofmann", email = "fabianmarikhofmann@gmail.com" }]
1111
license = { file = "LICENSE" }
1212
classifiers = [
13-
"Programming Language :: Python :: 3.9",
1413
"Programming Language :: Python :: 3.10",
1514
"Programming Language :: Python :: 3.11",
1615
"Programming Language :: Python :: 3.12",

test/test_constraint.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,7 @@ def test_constraint_rhs_getter(c: linopy.constraints.Constraint) -> None:
310310
def test_constraint_vars_setter(
311311
c: linopy.constraints.Constraint, x: linopy.Variable
312312
) -> None:
313-
c.vars = x # type: ignore
313+
c.vars = x # type: ignore[assignment]
314314
assert_equal(c.vars, x.labels)
315315

316316

@@ -329,7 +329,7 @@ def test_constraint_vars_setter_invalid(
329329

330330

331331
def test_constraint_coeffs_setter(c: linopy.constraints.Constraint) -> None:
332-
c.coeffs = 3 # type: ignore
332+
c.coeffs = 3 # type: ignore[assignment]
333333
assert (c.coeffs == 3).all()
334334

335335

@@ -345,32 +345,32 @@ def test_constraint_lhs_setter(
345345
def test_constraint_lhs_setter_with_variable(
346346
c: linopy.constraints.Constraint, x: linopy.Variable
347347
) -> None:
348-
c.lhs = x # type: ignore
348+
c.lhs = x # type: ignore[assignment]
349349
assert c.lhs.nterm == 1
350350

351351

352352
def test_constraint_lhs_setter_with_constant(c: linopy.constraints.Constraint) -> None:
353353
sizes = c.sizes
354-
c.lhs = 10 # type: ignore
354+
c.lhs = 10 # type: ignore[assignment]
355355
assert (c.rhs == -10).all()
356356
assert c.lhs.nterm == 0
357357
assert c.sizes["first"] == sizes["first"]
358358

359359

360360
def test_constraint_sign_setter(c: linopy.constraints.Constraint) -> None:
361-
c.sign = EQUAL # type: ignore
361+
c.sign = EQUAL # type: ignore[assignment]
362362
assert (c.sign == EQUAL).all()
363363

364364

365365
def test_constraint_sign_setter_alternative(c: linopy.constraints.Constraint) -> None:
366-
c.sign = long_EQUAL # type: ignore
366+
c.sign = long_EQUAL # type: ignore[assignment]
367367
assert (c.sign == EQUAL).all()
368368

369369

370370
def test_constraint_sign_setter_invalid(c: linopy.constraints.Constraint) -> None:
371371
# Test that assigning lhs with other type that LinearExpression raises TypeError
372372
with pytest.raises(ValueError):
373-
c.sign = "asd" # type: ignore
373+
c.sign = "asd" # type: ignore[assignment]
374374

375375

376376
def test_constraint_rhs_setter(c: linopy.constraints.Constraint) -> None:
@@ -392,7 +392,7 @@ def test_constraint_rhs_setter_with_variable(
392392
def test_constraint_rhs_setter_with_expression(
393393
c: linopy.constraints.Constraint, x: linopy.Variable, y: linopy.Variable
394394
) -> None:
395-
c.rhs = x + y # type: ignore
395+
c.rhs = x + y # type: ignore[assignment]
396396
assert (c.rhs == 0).all()
397397
assert (c.coeffs.isel({c.term_dim: -1}) == -1).all()
398398
assert c.lhs.nterm == 3
@@ -401,7 +401,7 @@ def test_constraint_rhs_setter_with_expression(
401401
def test_constraint_rhs_setter_with_expression_and_constant(
402402
c: linopy.constraints.Constraint, x: linopy.Variable
403403
) -> None:
404-
c.rhs = x + 1 # type: ignore
404+
c.rhs = x + 1 # type: ignore[assignment]
405405
assert (c.rhs == 1).all()
406406
assert (c.coeffs.sum(c.term_dim) == 0).all()
407407
assert c.lhs.nterm == 2

test/test_linear_expression.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1151,7 +1151,7 @@ def test_merge(x: Variable, y: Variable, z: Variable) -> None:
11511151
assert res.sel(dim_1=0).vars[2].item() == -1
11521152

11531153
with pytest.warns(DeprecationWarning):
1154-
merge(expr1, expr2) # type: ignore
1154+
merge(expr1, expr2) # type: ignore[arg-type]
11551155

11561156

11571157
def test_linear_expression_outer_sum(x: Variable, y: Variable) -> None:

test/test_optimization.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ def model_with_inf() -> Model:
129129
m.add_constraints(x + y, GREATER_EQUAL, 10)
130130
m.add_constraints(1 * x, "<=", np.inf)
131131

132-
m.objective = 2 * x + y # type: ignore
132+
m.objective = 2 * x + y # type: ignore[assignment]
133133

134134
return m
135135

@@ -142,7 +142,7 @@ def model_with_duplicated_variables() -> Model:
142142
x = m.add_variables(coords=[lower.index], name="x")
143143

144144
m.add_constraints(x + x, GREATER_EQUAL, 10)
145-
m.objective = 1 * x # type: ignore
145+
m.objective = 1 * x # type: ignore[assignment]
146146

147147
return m
148148

@@ -157,7 +157,7 @@ def model_with_non_aligned_variables() -> Model:
157157
y = m.add_variables(lower=lower, coords=[lower.index], name="y")
158158

159159
m.add_constraints(x + y, GREATER_EQUAL, 10.5)
160-
m.objective = 1 * x + 0.5 * y # type: ignore
160+
m.objective = 1 * x + 0.5 * y # type: ignore[assignment]
161161

162162
return m
163163

@@ -270,9 +270,9 @@ def modified_model() -> Model:
270270

271271
c = m.add_constraints(x + y, GREATER_EQUAL, 10)
272272

273-
y.lower = 9 # type: ignore
273+
y.lower = 9 # type: ignore[assignment]
274274
c.lhs = 2 * x + y
275-
m.objective = 2 * x + y # type: ignore
275+
m.objective = 2 * x + y # type: ignore[assignment]
276276

277277
return m
278278

test/test_quadratic_expression.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,7 @@ def test_quadratic_expression_to_matrix(model: Model, x: Variable, y: Variable)
303303

304304
def test_matrices_matrix(model: Model, x: Variable, y: Variable) -> None:
305305
expr = 10 * x * y
306-
model.objective = expr # type: ignore
306+
model.objective = expr # type: ignore[assignment]
307307

308308
Q = model.matrices.Q
309309
assert isinstance(Q, csc_matrix)
@@ -314,7 +314,7 @@ def test_matrices_matrix_mixed_linear_and_quadratic(
314314
model: Model, x: Variable, y: Variable
315315
) -> None:
316316
quad_expr = x * y + x
317-
model.objective = quad_expr + x # type: ignore
317+
model.objective = quad_expr + x # type: ignore[assignment]
318318

319319
Q = model.matrices.Q
320320
assert isinstance(Q, csc_matrix)

test/test_variable.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,12 +128,12 @@ def test_variable_lower_getter(z: linopy.Variable) -> None:
128128

129129

130130
def test_variable_upper_setter(z: linopy.Variable) -> None:
131-
z.upper = 20 # type: ignore
131+
z.upper = 20 # type: ignore[assignment]
132132
assert z.upper.item() == 20
133133

134134

135135
def test_variable_lower_setter(z: linopy.Variable) -> None:
136-
z.lower = 8 # type: ignore
136+
z.lower = 8 # type: ignore[assignment]
137137
assert z.lower == 8
138138

139139

0 commit comments

Comments
 (0)