Skip to content

Commit d6ea7a2

Browse files
committed
Improved tests and minor cleanup of Radau.
1 parent 58bde78 commit d6ea7a2

2 files changed

Lines changed: 38 additions & 136 deletions

File tree

pydaes/integrate/_dae/radau.py

Lines changed: 8 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import numpy as np
2-
from warnings import warn
32
from scipy.integrate._ivp.common import norm, EPS, warn_extraneous
43
from scipy.integrate._ivp.base import DenseOutput
54
from .dae import DaeSolver
@@ -37,11 +36,6 @@
3736
gamma = 3 + 3 ** (2 / 3) - 3 ** (1 / 3)
3837
alpha = 3 + 0.5 * (3 ** (1 / 3) - 3 ** (2 / 3))
3938
beta = 0.5 * (3 ** (5 / 6) + 3 ** (7 / 6))
40-
# Lambda = np.array([
41-
# [gamma, 0, 0],
42-
# [0, alpha, -beta],
43-
# [0, beta, alpha],
44-
# ])
4539
Lambda = np.array([
4640
[gamma, 0, 0],
4741
[0, alpha, beta],
@@ -60,14 +54,6 @@
6054
b = A[-1, :]
6155
b_hat = b + (E * gamma) @ A
6256

63-
# print(f"gamma, alpha, beta: {[gamma, alpha, beta]}")
64-
# print(f"A:\n{A}")
65-
# print(f"np.linalg.inv(A):\n{np.linalg.inv(A)}")
66-
# print(f"A_inv:\n{A_inv}")
67-
# print(f"b:\n{b}")
68-
# print(f"b_hat:\n{b_hat}")
69-
# exit()
70-
7157
# Interpolator coefficients.
7258
P = np.array([
7359
[13/3 + 7*S6/3, -23/3 - 22*S6/3, 10/3 + 5 * S6],
@@ -119,8 +105,6 @@ def solve_collocation_system(fun, t, y, h, Z0, scale, tol,
119105
The rate of convergence.
120106
"""
121107
n = y.shape[0]
122-
M_real = MU_REAL / h
123-
M_complex = MU_COMPLEX / h
124108

125109
# W = V of Fabien
126110
# A_inv = W of Fabien
@@ -145,42 +129,21 @@ def solve_collocation_system(fun, t, y, h, Z0, scale, tol,
145129
if not np.all(np.isfinite(F)):
146130
break
147131

148-
# f_real = F.T.dot(TI_REAL) - M_real * mass_matrix.dot(W[0])
149-
# f_complex = F.T.dot(TI_COMPLEX) - M_complex * mass_matrix.dot(W[1] + 1j * W[2])
150-
151-
# f_real = -h / MU_REAL * F.T.dot(TI_REAL)
152-
# f_complex = -h / MU_COMPLEX * F.T.dot(TI_COMPLEX)
153132
U = TI @ F
154-
# f_real = -h / MU_REAL * U[0]
155-
# f_complex = -h / MU_COMPLEX * (U[1] + 1j * U[2])
156133
f_real = -U[0]
157134
f_complex = -(U[1] + 1j * U[2])
158135

159-
# dW_real = solve_lu(LU_real, f_real)
160-
# dW_complex = solve_lu(LU_complex, f_complex)
161-
162-
# dW[0] = dW_real
163-
# dW[1] = dW_complex.real
164-
# dW[2] = dW_complex.imag
136+
dW_real = solve_lu(LU_real, f_real)
137+
dW_complex = solve_lu(LU_complex, f_complex)
165138

166-
dV_real = solve_lu(LU_real, f_real)
167-
dV_complex = solve_lu(LU_complex, f_complex)
168-
169-
dW[0] = dV_real
170-
dW[1] = dV_complex.real
171-
dW[2] = dV_complex.imag
172-
173-
# dW = TI @ dW
139+
dW[0] = dW_real
140+
dW[1] = dW_complex.real
141+
dW[2] = dW_complex.imag
174142

175143
dW_norm = norm(dW / scale)
176144
if dW_norm_old is not None:
177145
rate = dW_norm / dW_norm_old
178146

179-
# print(F"dW_norm: {dW_norm}")
180-
# print(F"rate: {rate}")
181-
# if rate is not None:
182-
# print(F"rate ** (NEWTON_MAXITER - k) / (1 - rate) * dW_norm: {rate ** (NEWTON_MAXITER - k) / (1 - rate) * dW_norm}")
183-
# print(F"tol: {tol}")
184147
if (rate is not None and (rate >= 1 or rate ** (NEWTON_MAXITER - k) / (1 - rate) * dW_norm > tol)):
185148
break
186149

@@ -363,7 +326,6 @@ class RadauDAE(DaeSolver):
363326
def __init__(self, fun, t0, y0, yp0, t_bound, max_step=np.inf,
364327
rtol=1e-3, atol=1e-6, jac=None, jac_sparsity=None,
365328
vectorized=False, first_step=None, **extraneous):
366-
warn("RadauDAE is currently under development and not finished. The error estimate is still flawed.")
367329
warn_extraneous(extraneous)
368330
super().__init__(fun, t0, y0, yp0, t_bound, rtol, atol, first_step, max_step, vectorized, jac, jac_sparsity)
369331
self.y_old = None
@@ -411,7 +373,6 @@ def _step_impl(self):
411373
current_jac = self.current_jac
412374
jac = self.jac
413375

414-
rejected = False
415376
step_accepted = False
416377
message = None
417378
while not step_accepted:
@@ -437,18 +398,13 @@ def _step_impl(self):
437398
converged = False
438399
while not converged:
439400
if LU_real is None or LU_complex is None:
440-
# LU_real = self.lu(h / MU_REAL * Jyp + Jy)
441-
# LU_complex = self.lu(h / MU_COMPLEX * Jyp + Jy)
442-
# LU_real = self.lu(Jyp + h / MU_REAL * Jy)
443-
# LU_complex = self.lu(Jyp + h / MU_COMPLEX * Jy)
444401
# Fabien (5.59) and (5.60)
445402
LU_real = self.lu(MU_REAL / h * Jyp + Jy)
446403
LU_complex = self.lu(MU_COMPLEX / h * Jyp + Jy)
447404

448405
converged, n_iter, Y, Yp, Z, rate = solve_collocation_system(
449406
self.fun, t, y, h, Z0, scale, self.newton_tol,
450407
LU_real, LU_complex, self.solve_lu)
451-
# print(f"converged: {converged}")
452408

453409
if not converged:
454410
if current_jac:
@@ -461,20 +417,16 @@ def _step_impl(self):
461417

462418
if not converged:
463419
h_abs *= 0.5
464-
# print(f"not converged")
465-
# print(f"h_abs: {h_abs}")
466420
LU_real = None
467421
LU_complex = None
468422
continue
469423

470424
# Hairer1996 (8.2b)
471-
y_new = y + Z[-1]
472-
# y_new = Y[-1]
425+
# y_new = y + Z[-1]
426+
y_new = Y[-1]
473427
yp_new = Yp[-1]
474428

475429
scale = atol + np.maximum(np.abs(y), np.abs(y_new)) * rtol
476-
# scale = atol + np.maximum(np.abs(yp), np.abs(yp_new)) * rtol
477-
# scale = atol + h * np.maximum(np.abs(yp), np.abs(yp_new)) * rtol
478430

479431
if True:
480432
# # ######################################################
@@ -615,70 +567,17 @@ def _step_impl(self):
615567
# print(f"error: {error}")
616568
# # error = -self.solve_lu(LU_real, self.fun(t_new, y_hat_new, yp_hat_new))
617569

618-
619-
620-
621-
622-
623-
# # TODO: These are the correct estimates for ODE Radau
624-
# err = h * MU_REAL * f + Z.T.dot(E * MU_REAL)
625-
# # error = err
626-
# error = self.solve_lu(LU_real, err) / (MU_REAL * h)
627-
628-
# # improve error estimate for stiff components
629-
# if unknown_z:
630-
# error = self.solve_lu(LU_real, err) / (h / MU_REAL)
631-
# # error = self.solve_lu(LU_real, err) * (1 / MU_REAL * h) # TODO: Why is this good?
632-
# # error = self.solve_lu(LU_real, yp + Jyp @ Z.T.dot(E) / h)# * (1 / MU_REAL * h)
633-
# # error = self.solve_lu(LU_real, err) #/ (MU_REAL * h)
634-
# # # error = self.solve_lu(LU_real, err)
635-
# # error = self.solve_lu(LU_real, err) / gamma0 * h
636-
# # # error = self.solve_lu(LU_real, h * gamma0 * yp + Z.T.dot(e))
637-
# # # error = self.solve_lu(LU_real, h * gamma0 * yp + Z.T.dot(e))
638-
# # # error = self.solve_lu(LU_real, (h * gamma0 * yp + Jyp @ Z.T.dot(e)))
639-
640-
# # print(f"err: {err}")
641-
# # print(f"error: {error}")
642-
# # print(f"h: {h}")
643-
# # error = self.solve_lu(LU_real, err / (h / gamma0))
644-
# # error = self.solve_lu(LU_real, (h * gamma0 * yp + Jyp @ Z.T.dot(e)))
645-
# # error = self.solve_lu(LU_real, err / (h * gamma0))
646-
# # error = self.solve_lu(LU_real, (yp + Z.T.dot(e) / (h * gamma0)))
647-
648-
# # D = np.eye(self.n) / (h * gamma0) + Jy
649-
# # error = np.linalg.solve(D, err / (h * gamma0))
650-
# pass
651-
# else:
652-
# # error = self.solve_lu(LU_real, err / (h * gamma0))
653-
# pass
654-
# # error = self.solve_lu(LU_real, err / gamma0h)
655-
# # error = self.solve_lu(LU_real, err * gamma0h)
656-
# # error = self.solve_lu(LU_real, (gamma0h * yp + Z.T.dot(e)) / gamma0h)
657-
# # error = self.solve_lu(LU_real, yp + Z.T.dot(e) / gamma0h)
658-
# # # error = self.solve_lu(LU_real, yp + Z.T.dot(e) / gamma0h)
659-
# # # error = self.solve_lu(LU_real, yp + Z.T.dot(E) / h)
660-
# # error = self.solve_lu(LU_real, yp + Z.T.dot(E) / h)
661-
# # error = self.solve_lu(LU_real, yp + Jyp @ Z.T.dot(E) / h)
662570

663571
error_norm = norm(error / scale)
664572

665573
safety = 0.9 * (2 * NEWTON_MAXITER + 1) / (2 * NEWTON_MAXITER + n_iter)
666574

667-
if rejected and error_norm > 1: # try with stabilised error estimate
668-
print(f"rejected")
669-
# # # error = self.solve_lu(LU_real, self.fun(t, y + error) + self.mass_matrix.dot(ZE))
670-
# # err = h * gamma0 * (yp + error) + Z.T.dot(e)
671-
# # error = self.solve_lu(LU_real, err)
672-
# error = self.solve_lu(LU_real, error)
673-
# error_norm = norm(error / scale)
674-
675575
if error_norm > 1:
676576
factor = predict_factor(h_abs, h_abs_old, error_norm, error_norm_old)
677577
h_abs *= max(MIN_FACTOR, safety * factor)
678578

679579
LU_real = None
680580
LU_complex = None
681-
rejected = True
682581
else:
683582
step_accepted = True
684583
else:
@@ -691,7 +590,6 @@ def _step_impl(self):
691590

692591
factor = predict_factor(h_abs, h_abs_old, error_norm, error_norm_old)
693592
factor = min(MAX_FACTOR, safety * factor)
694-
# print(f"factor: {factor}")
695593

696594
if not recompute_jac and factor < 1.2:
697595
factor = 1
@@ -709,14 +607,12 @@ def _step_impl(self):
709607
self.h_abs_old = self.h_abs
710608
self.error_norm_old = error_norm
711609

712-
# print(f"h_abs: {h_abs}")
713610
self.h_abs = h_abs * factor
714-
# print(f"self.h_abs: {self.h_abs}")
715611

716612
f_new = self.fun(t_new, y_new, yp_new)
717613

718614
self.y_old = y
719-
# self.yp_old = yp
615+
self.yp_old = yp
720616

721617
self.t = t_new
722618
self.y = y_new

pydaes/integrate/_dae/tests/test_dae.py

Lines changed: 30 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -139,8 +139,7 @@ def test_integration_complex(method, jac):
139139

140140
parameters_rational = product(
141141
[False], # vectorized
142-
["BDF"], # method
143-
# ["BDF", "Radau"], # method
142+
["BDF", "Radau"], # method
144143
[[5, 9], [5, 1]], # t_span
145144
[None, J_rational, J_rational_sparse] # jac
146145
)
@@ -166,7 +165,10 @@ def test_integration_rational(vectorized, method, t_span, jac):
166165
assert_(res.success)
167166
assert_equal(res.status, 0)
168167

169-
assert_(0 < res.njev < 3)
168+
if method == "BDF":
169+
assert_(0 < res.njev < 3)
170+
else: # Radau
171+
assert_(0 < res.njev < 4)
170172
assert_(0 < res.nlu < 10)
171173

172174
y_true = sol_rational(res.t)
@@ -190,14 +192,10 @@ def test_integration_rational(vectorized, method, t_span, jac):
190192
assert_allclose(res.sol(res.t), res.y, rtol=1e-15, atol=1e-15)
191193

192194

193-
parameters_stiff = product(
194-
# ["BDF", "Radau"], # method
195-
["BDF"], # method
196-
["stability", "efficiency", None], # NDF_strategy
197-
[1, 2, 3, 4, 5, 6], # max_order
198-
)
199-
@pytest.mark.parametrize("method, NDF_strategy, max_order", parameters_stiff)
200-
def test_integration_stiff(method, NDF_strategy, max_order):
195+
parameters_stiff = ["BDF", "Radau"]
196+
@pytest.mark.slow
197+
@pytest.mark.parametrize("method", parameters_stiff)
198+
def test_integration_stiff(method):
201199
def fun_robertson(t, state):
202200
x, y, z = state
203201
return [
@@ -215,31 +213,39 @@ def F_robertson(t, state, statep):
215213
yp0 = fun_robertson(0, y0)
216214
tspan = [0, 1e8]
217215

218-
with suppress_warnings() as sup:
219-
sup.filter(UserWarning,
220-
"Choosing `max_order = 6` is not recomended due to its "
221-
"poor stability properties.")
216+
if method == "BDF":
217+
for NDF_strategy, max_order in product(
218+
["stability", "efficiency", None], # NDF_strategy
219+
[1, 2, 3, 4, 5, 6], # max_order
220+
):
221+
with suppress_warnings() as sup:
222+
sup.filter(UserWarning,
223+
"Choosing `max_order = 6` is not recomended due to its "
224+
"poor stability properties.")
225+
res = solve_dae(F_robertson, tspan, y0, yp0, rtol=rtol,
226+
atol=atol, method=method, max_order=max_order,
227+
NDF_strategy=NDF_strategy)
228+
else: # Radau
222229
res = solve_dae(F_robertson, tspan, y0, yp0, rtol=rtol,
223-
atol=atol, method=method, max_order=max_order,
224-
NDF_strategy=NDF_strategy)
230+
atol=atol, method=method)
225231

226232
# If the stiff mode is not activated correctly, these numbers will be much
227233
# bigger (see max_order=1 case)
228-
if max_order == 1:
234+
if method == "BDF" and max_order == 1:
229235
assert res.nfev < 21000
230236
else:
231237
assert res.nfev < 5000
232238
assert res.njev < 200
233239

234240
if __name__ == "__main__":
235-
# for params in parameters_linear:
236-
# test_integration_const_jac(*params)
241+
for params in parameters_linear:
242+
test_integration_const_jac(*params)
237243

238-
# for params in parameters_complex:
239-
# test_integration_complex(*params)
244+
for params in parameters_complex:
245+
test_integration_complex(*params)
240246

241247
for params in parameters_rational:
242248
test_integration_rational(*params)
243249

244-
# for params in parameters_stiff:
245-
# test_integration_stiff(*params)
250+
for params in parameters_stiff:
251+
test_integration_stiff(*params)

0 commit comments

Comments
 (0)