55from solve_dae .integrate import solve_dae
66
77
8- # initial conditions
98y0 = [1 ]
109yp0 = [0.5 * y0 [0 ]]
11- t_span = (0 , 1 )
10+ t_span = (0 , 1.2 )
1211
1312
1413def F (t , y , yp ):
1514 return yp - 0.5 * y
1615
1716def jac_dense (t , y , yp ):
18- return np .eye (len (y )), - 0.5 * np .eye (len (yp ))
17+ return - 0.5 * np .eye (len (y )), np .eye (len (yp ))
1918
2019def jac_sparse (t , y , yp ):
21- return csc_matrix (np .eye (len (y ))), csc_matrix (- 0.5 * np .eye (len (yp )))
20+ return csc_matrix (- 0.5 * np .eye (len (y ))), csc_matrix (np .eye (len (yp )))
2221
23- def jac_wrong_shape_dense (t , y , yp ):
24- return np .eye (len (y ) + 1 ), - 0.5 * np .eye (len (yp ) + 1 )
22+ def jac_wrong_shape_Jyp_dense (t , y , yp ):
23+ return - 0.5 * np .eye (len (y )), np .eye (len (yp ) + 1 ),
2524
26- def jac_wrong_shape_sparse (t , y , yp ):
27- return csc_matrix (np .eye (len (y ) + 1 )), csc_matrix (- 0.5 * np .eye (len (yp ) + 1 ))
25+ def jac_wrong_shape_Jy_dense (t , y , yp ):
26+ return - 0.5 * np .eye (len (y ) + 1 ), np .eye (len (yp ))
27+
28+ def jac_wrong_shape_Jyp_sparse (t , y , yp ):
29+ return csc_matrix (- 0.5 * np .eye (len (y ))), csc_matrix (np .eye (len (yp ) + 1 ))
30+
31+ def jac_wrong_shape_Jy_sparse (t , y , yp ):
32+ return csc_matrix (- 0.5 * np .eye (len (y ) + 1 )), csc_matrix (np .eye (len (yp )))
33+
34+ jac_wrong_shape_Jyp_dense_constant = (
35+ np .eye (2 ),
36+ - 0.5 * np .eye (1 ),
37+ )
38+
39+ jac_wrong_shape_Jy_dense_constant = (
40+ np .eye (1 ),
41+ - 0.5 * np .eye (2 ),
42+ )
43+
44+ jac_wrong_shape_Jyp_sparse_constant = (
45+ csc_matrix (np .eye (2 )),
46+ csc_matrix (- 0.5 * np .eye (1 )),
47+ )
48+
49+ jac_wrong_shape_Jy_sparse_constant = (
50+ csc_matrix (np .eye (1 )),
51+ csc_matrix (- 0.5 * np .eye (2 )),
52+ )
2853
2954
3055parameters_method = ["BDF" , "Radau" ]
31- parameters_jac = [jac_dense , jac_sparse , jac_wrong_shape_dense , jac_wrong_shape_sparse ]
56+ parameters_jac_correct_shape = [jac_dense , jac_sparse ]
57+ parameters_jac_wrong_shape = [
58+ jac_wrong_shape_Jyp_dense ,
59+ jac_wrong_shape_Jy_dense ,
60+ jac_wrong_shape_Jyp_sparse ,
61+ jac_wrong_shape_Jy_sparse ,
62+ jac_wrong_shape_Jyp_dense_constant ,
63+ jac_wrong_shape_Jy_dense_constant ,
64+ jac_wrong_shape_Jyp_sparse_constant ,
65+ jac_wrong_shape_Jy_sparse_constant ,
66+ ]
67+ parameters_jac = parameters_jac_correct_shape + parameters_jac_wrong_shape
3268
3369
3470parameters = product (
@@ -39,7 +75,7 @@ def jac_wrong_shape_sparse(t, y, yp):
3975
4076@pytest .mark .parametrize ("method, jac" , parameters )
4177def test_jacobian_shape (method , jac ):
42- if jac in [ jac_wrong_shape_dense , jac_wrong_shape_sparse ] :
78+ if not callable ( jac ) or jac in parameters_jac_wrong_shape :
4379 with pytest .raises (ValueError ) as excinfo :
4480 solve_dae (F , t_span , y0 , yp0 , method = method , jac = jac )
4581
@@ -61,12 +97,34 @@ def test_jacobian_shape(method, jac):
6197 solve_dae (F , t_span , y0 , yp0 , method = method , jac = jac )
6298
6399
64- # @pytest.mark.parametrize("method,", parameters_method)
65- # def test_step(method):
66- # sol = solve_dae(F, t_span, y0, yp0, method=method)
67- # pass
100+ @pytest .mark .parametrize ("method" , parameters_method )
101+ def test_small_max_step (method ):
102+ solve_dae (F , t_span , y0 , yp0 , method = method , max_step = 1e-2 )
103+
104+
105+ def F (t , y , yp ):
106+ return yp - 1 / (1 - t )
107+
108+ y0 = [0 ]
109+ yp0 = [1 ]
110+ t_span = (0 , 1.2 )
111+
112+
113+ @pytest .mark .filterwarnings ("ignore:divide by zero encountered in scalar divide" )
114+ @pytest .mark .parametrize ("method" , parameters_method )
115+ def test_overflow (method ):
116+ sol = solve_dae (F , t_span , y0 , yp0 , method = method )
117+ assert sol .status == - 1
118+ assert sol .success == False
119+ assert sol .message == "Required step size is less than spacing between numbers."
68120
69121
70122# if __name__ == "__main__":
71123# for params in parameters:
72124# test_jacobian_shape(*params)
125+
126+ # for params in parameters_method:
127+ # test_small_max_step(params)
128+
129+ # for params in parameters_method:
130+ # test_overflow(params)
0 commit comments