@@ -126,9 +126,6 @@ def __call__(self, t):
126126 return ys , yps
127127
128128
129- # TODO: Compare this with
130- # - ddassl.f by Petzold
131- # - epsode.f by Bryne and Hindmarsh
132129def select_initial_step (t0 , y0 , yp0 , t_bound , rtol , atol , max_step ):
133130 """Empirically select a good initial step.
134131
@@ -160,18 +157,28 @@ def select_initial_step(t0, y0, yp0, t_bound, rtol, atol, max_step):
160157
161158 References
162159 ----------
163- .. [1] TODO: Find a reference .
160+ .. [1] L. F. Shampine, "Starting an ODE solver", November 1977 .
164161 """
165- min_step = 0.0
162+ safety = 0.8
163+ min_step = 16 * EPS * abs (t0 )
166164 threshold = atol / rtol
167165 hspan = abs (t_bound - t0 )
168166
169- # compute an initial step size h using yp = y'(t0)
167+ # compute scaling
170168 wt = np .maximum (np .abs (y0 ), threshold )
171- rh = 1.25 * np .linalg .norm (yp0 / wt , np .inf ) / np .sqrt (rtol )
169+
170+ # error
171+ e = np .linalg .norm (yp0 / wt , np .inf )
172+
173+ # reciprocal step size
174+ rh = e / np .sqrt (rtol ) / safety
175+
176+ # compute an initial step size
172177 h_abs = min (max_step , hspan )
173178 if h_abs * rh > 1 :
174179 h_abs = 1 / rh
180+
181+ # ensure h_abs >= min_step
175182 h_abs = max (h_abs , min_step )
176183 return h_abs
177184
@@ -184,7 +191,7 @@ def consistent_initial_conditions(fun, t0, y0, yp0, jac=None, fixed_y0=None,
184191
185192 References
186193 ----------
187- .. [1] L. F. Shampine, "Solving 0 = F(t, y(t), y′ (t)) in Matlab", Journal
194+ .. [1] L. F. Shampine, "Solving 0 = F(t, y(t), y' (t)) in Matlab", Journal
188195 of Numerical Mathematics, vol. 10, no. 4, 2002, pp. 291-310.
189196 """
190197 n = len (y0 )
@@ -220,8 +227,8 @@ def fun_composite(t, z):
220227 if not (isinstance (rtol , float ) and rtol > 0 ):
221228 raise ValueError ("Relative tolerance must be a positive scalar." )
222229
223- if rtol < 100 * np . finfo ( float ). eps :
224- rtol = 100 * np . finfo ( float ). eps
230+ if rtol < 100 * EPS :
231+ rtol = 100 * EPS
225232 print (f"Relative tolerance increased to { rtol } " )
226233
227234 if np .any (np .array (atol ) <= 0 ):
@@ -235,29 +242,13 @@ def fun_composite(t, z):
235242 Jy , Jyp = jac (t0 , y0 , yp0 )
236243
237244 scale_f = atol + np .abs (f ) * rtol
238- # z0 = np.concatenate([y0, yp0])
239- # scale_z = atol + np.abs(z0) * rtol
240- # dz_norm_old = None
241- # rate_z = None
242- # tol = max(10 * EPS / rtol, min(0.03, rtol ** 0.5))
243245
244246 for _ in range (newton_maxiter ):
245247 for _ in range (chord_iter ):
246248 dy , dyp = solve_underdetermined_system (f , Jy , Jyp , free_y , free_yp )
247249 y0 += dy
248250 yp0 += dyp
249251
250- # dz = np.concatenate([dy, dyp])
251- # with np.errstate(divide='ignore'):
252- # dz_norm = norm(dz / scale_z)
253- # if dz_norm_old is not None:
254- # rate_z = dz_norm / dz_norm_old
255-
256- # if (dz_norm == 0 or (rate_z is not None and rate_z / (1 - rate_z) * dz_norm < safety * tol)):
257- # return y0, yp0, f
258-
259- # dz_norm_old = dz_norm
260-
261252 f = fun (t0 , y0 , yp0 , * args )
262253 error = norm (f / scale_f )
263254 if error < safety :
@@ -271,7 +262,7 @@ def fun_composite(t, z):
271262def qrank (A ):
272263 """Compute QR-decomposition with column pivoting of A and estimate the rank."""
273264 Q , R , p = qr (A , pivoting = True )
274- tol = max (A .shape ) * np . finfo ( float ). eps * abs (R [0 , 0 ])
265+ tol = max (A .shape ) * EPS * abs (R [0 , 0 ])
275266 rank = np .sum (abs (np .diag (R )) > tol )
276267 return rank , Q , R , p
277268
@@ -353,10 +344,17 @@ def solve_underdetermined_system(f, Jy, Jyp, free_y, free_yp):
353344 # [S21, S22] [w1] = d2
354345 # [w2]
355346 # using column pivoting QR-decomposition
356- w_ = np .zeros (RS .shape [1 ])
357- w_ [:rankS ] = solve_triangular (RS [:rankS , :rankS ], (QS .T @ d2 [:rankS ]))
358- w = np .zeros_like (w_ )
359- w [pS ] = w_
347+ # [RS11, RS12] [v1] = [c1]
348+ # [ 0, 0] [v2] [c2]
349+ # with v2 = 0 this gives
350+ # RS11 @ v1 = c1
351+ c = QS .T @ d2
352+ v = np .zeros (RS .shape [1 ])
353+ v [:rankS ] = solve_triangular (RS [:rankS , :rankS ], c [:rankS ])
354+
355+ # apply permutation
356+ w = np .zeros_like (v )
357+ w [pS ] = v
360358
361359 # set w2' = 0 and solve the remaining system
362360 # [R11] w1' = d1 - [S11, S12] [w1]
0 commit comments