3737 a * 0 == 0 multiplication by zero
3838"""
3939
40+ from __future__ import annotations
41+
4042import numpy as np
4143import pandas as pd
4244import pytest
4345import xarray as xr
4446
4547from linopy import Model
4648from linopy .expressions import LinearExpression
49+ from linopy .variables import Variable
4750
4851
4952@pytest .fixture
50- def m ():
53+ def m () -> Model :
5154 return Model ()
5255
5356
5457@pytest .fixture
55- def time ():
58+ def time () -> pd . RangeIndex :
5659 return pd .RangeIndex (3 , name = "time" )
5760
5861
5962@pytest .fixture
60- def tech ():
63+ def tech () -> pd . Index :
6164 return pd .Index (["solar" , "wind" ], name = "tech" )
6265
6366
6467@pytest .fixture
65- def x (m , time ) :
68+ def x (m : Model , time : pd . RangeIndex ) -> Variable :
6669 """Variable with dims [time]."""
6770 return m .add_variables (lower = 0 , coords = [time ], name = "x" )
6871
6972
7073@pytest .fixture
71- def y (m , time ) :
74+ def y (m : Model , time : pd . RangeIndex ) -> Variable :
7275 """Variable with dims [time]."""
7376 return m .add_variables (lower = 0 , coords = [time ], name = "y" )
7477
7578
7679@pytest .fixture
77- def z (m , time ) :
80+ def z (m : Model , time : pd . RangeIndex ) -> Variable :
7881 """Variable with dims [time]."""
7982 return m .add_variables (lower = 0 , coords = [time ], name = "z" )
8083
8184
8285@pytest .fixture
83- def g (m , time , tech ) :
86+ def g (m : Model , time : pd . RangeIndex , tech : pd . Index ) -> Variable :
8487 """Variable with dims [time, tech]."""
8588 return m .add_variables (lower = 0 , coords = [time , tech ], name = "g" )
8689
8790
8891@pytest .fixture
89- def c (tech ) :
92+ def c (tech : pd . Index ) -> xr . DataArray :
9093 """Constant (DataArray) with dims [tech]."""
9194 return xr .DataArray ([2.0 , 3.0 ], dims = ["tech" ], coords = {"tech" : tech })
9295
@@ -95,7 +98,7 @@ def assert_linequal(a: LinearExpression, b: LinearExpression) -> None:
9598 """Assert two linear expressions are algebraically equivalent."""
9699 assert set (a .dims ) == set (b .dims ), f"dims differ: { a .dims } vs { b .dims } "
97100 for dim in a .dims :
98- if dim .startswith ("_" ):
101+ if isinstance ( dim , str ) and dim .startswith ("_" ):
99102 continue
100103 np .testing .assert_array_equal (
101104 sorted (a .coords [dim ].values ), sorted (b .coords [dim ].values )
@@ -109,15 +112,15 @@ def assert_linequal(a: LinearExpression, b: LinearExpression) -> None:
109112
110113
111114class TestCommutativity :
112- def test_add_expr_expr (self , x , y ) :
115+ def test_add_expr_expr (self , x : Variable , y : Variable ) -> None :
113116 """X + y == y + x"""
114117 assert_linequal (x + y , y + x )
115118
116- def test_mul_expr_constant (self , g , c ) :
119+ def test_mul_expr_constant (self , g : Variable , c : xr . DataArray ) -> None :
117120 """G * c == c * g"""
118121 assert_linequal (g * c , c * g )
119122
120- def test_add_expr_constant (self , g , c ) :
123+ def test_add_expr_constant (self , g : Variable , c : xr . DataArray ) -> None :
121124 """G + c == c + g"""
122125 assert_linequal (g + c , c + g )
123126
@@ -128,11 +131,11 @@ def test_add_expr_constant(self, g, c):
128131
129132
130133class TestAssociativity :
131- def test_add_same_dims (self , x , y , z ) :
134+ def test_add_same_dims (self , x : Variable , y : Variable , z : Variable ) -> None :
132135 """(x + y) + z == x + (y + z)"""
133136 assert_linequal ((x + y ) + z , x + (y + z ))
134137
135- def test_add_with_constant (self , x , g , c ) :
138+ def test_add_with_constant (self , x : Variable , g : Variable , c : xr . DataArray ) -> None :
136139 """(x[A] + c[B]) + g[A,B] == x[A] + (c[B] + g[A,B])"""
137140 assert_linequal ((x + c ) + g , x + (c + g ))
138141
@@ -143,15 +146,17 @@ def test_add_with_constant(self, x, g, c):
143146
144147
145148class TestDistributivity :
146- def test_scalar (self , x , y ) :
149+ def test_scalar (self , x : Variable , y : Variable ) -> None :
147150 """S * (x + y) == s*x + s*y"""
148151 assert_linequal (3 * (x + y ), 3 * x + 3 * y )
149152
150- def test_constant_subset_dims (self , g , c ) :
153+ def test_constant_subset_dims (self , g : Variable , c : xr . DataArray ) -> None :
151154 """c[B] * (g[A,B] + g[A,B]) == c*g + c*g"""
152155 assert_linequal (c * (g + g ), c * g + c * g )
153156
154- def test_constant_mixed_dims (self , x , g , c ):
157+ def test_constant_mixed_dims (
158+ self , x : Variable , g : Variable , c : xr .DataArray
159+ ) -> None :
155160 """c[B] * (x[A] + g[A,B]) == c*x + c*g"""
156161 assert_linequal (c * (x + g ), c * x + c * g )
157162
@@ -162,14 +167,14 @@ def test_constant_mixed_dims(self, x, g, c):
162167
163168
164169class TestIdentity :
165- def test_additive (self , x ) :
170+ def test_additive (self , x : Variable ) -> None :
166171 """X + 0 == x"""
167172 result = x + 0
168173 assert isinstance (result , LinearExpression )
169174 assert (result .const == 0 ).all ()
170175 np .testing .assert_array_equal (result .coeffs .squeeze ().values , [1 , 1 , 1 ])
171176
172- def test_multiplicative (self , x ) :
177+ def test_multiplicative (self , x : Variable ) -> None :
173178 """X * 1 == x"""
174179 result = x * 1
175180 assert isinstance (result , LinearExpression )
@@ -182,15 +187,15 @@ def test_multiplicative(self, x):
182187
183188
184189class TestNegation :
185- def test_subtraction_is_add_negation (self , x , y ) :
190+ def test_subtraction_is_add_negation (self , x : Variable , y : Variable ) -> None :
186191 """X - y == x + (-y)"""
187192 assert_linequal (x - y , x + (- y ))
188193
189- def test_subtraction_definition (self , x , y ) :
194+ def test_subtraction_definition (self , x : Variable , y : Variable ) -> None :
190195 """X - y == x + (-1) * y"""
191196 assert_linequal (x - y , x + (- 1 ) * y )
192197
193- def test_double_negation (self , x ) :
198+ def test_double_negation (self , x : Variable ) -> None :
194199 """-(-x) has same coefficients as x"""
195200 result = - (- x )
196201 np .testing .assert_array_equal (
@@ -205,7 +210,7 @@ def test_double_negation(self, x):
205210
206211
207212class TestZero :
208- def test_multiplication_by_zero (self , x ) :
213+ def test_multiplication_by_zero (self , x : Variable ) -> None :
209214 """X * 0 has zero coefficients"""
210215 result = x * 0
211216 assert (result .coeffs == 0 ).all ()
0 commit comments