Skip to content

Commit 3a080d7

Browse files
authored
Merge pull request #2620 from stan-dev/cvodes_err_msg_issue_1039
use CVODES error string instead of numeric flag in error messages
2 parents a5e40e7 + 74c37fd commit 3a080d7

11 files changed

Lines changed: 316 additions & 231 deletions

stan/math/prim/err/check_flag_sundials.hpp

Lines changed: 206 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,177 @@
22
#define STAN_MATH_PRIM_ERR_CHECK_FLAG_SUNDIALS_HPP
33

44
#include <stan/math/prim/meta.hpp>
5-
#include <stan/math/prim/err/throw_domain_error.hpp>
5+
#include <stan/math/prim/err/domain_error.hpp>
6+
#include <kinsol/kinsol.h>
7+
#include <cvodes/cvodes.h>
68

79
namespace stan {
810
namespace math {
911

12+
#define CHECK_CVODES_CALL(call) cvodes_check(call, #call)
13+
#define CHECK_KINSOL_CALL(call) kinsol_check(call, #call)
14+
15+
/**
16+
* Map cvodes error flag to acutally error msg. The most frequent
17+
* errors are put at the top. An alternative would be to use std::map
18+
* but in our case the difference would be negligible. Note that we
19+
* don't use CVGetReturnFlagName function to retrieve the constant
20+
* because sanitizer indicates it contains mem leak.
21+
*
22+
* @param flag
23+
*
24+
* @return error msg string constant and actuall informative msg
25+
*/
26+
inline std::array<std::string, 2> cvodes_flag_msg(int flag) {
27+
std::array<std::string, 2> msg;
28+
switch (flag) {
29+
case -1:
30+
msg = {"CV_TOO_MUCH_WORK",
31+
"The solver took mxstep internal steps but could not reach tout"};
32+
break; // NOLINT
33+
case -2:
34+
msg = {"CV_TOO_MUCH_ACC",
35+
"The solver could not satisfy the accuracy demanded by the user "
36+
"for some internal step"};
37+
break; // NOLINT
38+
case -3:
39+
msg = {"CV_ERR_FAILURE",
40+
"Error test failures occurred too many times during one internal "
41+
"time step or minimum step size was reached"};
42+
break; // NOLINT
43+
case -4:
44+
msg = {"CV_CONV_FAILURE",
45+
"Convergence test failures occurred too many times during one "
46+
"internal time step or minimum step size was reached"};
47+
break; // NOLINT
48+
case -8:
49+
msg = {"CV_RHSFUNC_FAIL",
50+
"The right-hand side function failed in an unrecoverable manner"};
51+
break; // NOLINT
52+
case -9:
53+
msg = {"CV_FIRST_RHSFUNC_ERR",
54+
"The right-hand side function failed at the first call"};
55+
break; // NOLINT
56+
case -10:
57+
msg = {"CV_REPTD_RHSFUNC_ERR",
58+
"The right-hand side function had repetead recoverable errors"};
59+
break; // NOLINT
60+
case -11:
61+
msg = {"CV_UNREC_RHSFUNC_ERR",
62+
"The right-hand side function had a recoverable error, but no "
63+
"recovery is possible"};
64+
break; // NOLINT
65+
case -27:
66+
msg = {"CV_TOO_CLOSE",
67+
"The output and initial times are too close to each other"};
68+
break; // NOLINT
69+
default:
70+
switch (flag) {
71+
case -5:
72+
msg = {"CV_LINIT_FAIL",
73+
"The linear solver's initialization function failed"};
74+
break; // NOLINT
75+
case -6:
76+
msg = {"CV_LSETUP_FAIL",
77+
"The linear solver's setup function failed in an "
78+
"unrecoverable manner"};
79+
break; // NOLINT
80+
case -7:
81+
msg = {"CV_LSOLVE_FAIL",
82+
"The linear solver's solve function failed in an "
83+
"unrecoverable manner"};
84+
break; // NOLINT
85+
case -20:
86+
msg = {"CV_MEM_FAIL", "A memory allocation failed"};
87+
break; // NOLINT
88+
case -21:
89+
msg = {"CV_MEM_NULL", "The cvode_mem argument was NULL"};
90+
break; // NOLINT
91+
case -22:
92+
msg = {"CV_ILL_INPUT", "One of the function inputs is illegal"};
93+
break; // NOLINT
94+
case -23:
95+
msg = {"CV_NO_MALLOC",
96+
"The CVODE memory block was not allocated by a call to "
97+
"CVodeMalloc"};
98+
break; // NOLINT
99+
case -24:
100+
msg = {"CV_BAD_K",
101+
"The derivative order k is larger than the order used"};
102+
break; // NOLINT
103+
case -25:
104+
msg = {"CV_BAD_T", "The time t s outside the last step taken"};
105+
break; // NOLINT
106+
case -26:
107+
msg = {"CV_BAD_DKY", "The output derivative vector is NULL"};
108+
break; // NOLINT
109+
case -40:
110+
msg = {"CV_BAD_IS",
111+
"The sensitivity index is larger than the number of "
112+
"sensitivities computed"};
113+
break; // NOLINT
114+
case -41:
115+
msg = {"CV_NO_SENS",
116+
"Forward sensitivity integration was not activated"};
117+
break; // NOLINT
118+
case -42:
119+
msg = {"CV_SRHSFUNC_FAIL",
120+
"The sensitivity right-hand side function failed in an "
121+
"unrecoverable manner"};
122+
break; // NOLINT
123+
case -43:
124+
msg = {"CV_FIRST_SRHSFUNC_ER",
125+
"The sensitivity right-hand side function failed at the first "
126+
"call"};
127+
break; // NOLINT
128+
case -44:
129+
msg = {"CV_REPTD_SRHSFUNC_ER",
130+
"The sensitivity ight-hand side function had repetead "
131+
"recoverable errors"};
132+
break; // NOLINT
133+
case -45:
134+
msg = {"CV_UNREC_SRHSFUNC_ER",
135+
"The sensitivity right-hand side function had a recoverable "
136+
"error, but no recovery is possible"};
137+
break; // NOLINT
138+
case -101:
139+
msg = {"CV_ADJMEM_NULL", "The cvadj_mem argument was NULL"};
140+
break; // NOLINT
141+
case -103:
142+
msg = {"CV_BAD_TB0",
143+
"The final time for the adjoint problem is outside the "
144+
"interval over which the forward problem was solved"};
145+
break; // NOLINT
146+
case -104:
147+
msg = {"CV_BCKMEM_NULL",
148+
"The cvodes memory for the backward problem was not created"};
149+
break; // NOLINT
150+
case -105:
151+
msg = {"CV_REIFWD_FAIL",
152+
"Reinitialization of the forward problem failed at the first "
153+
"checkpoint"};
154+
break; // NOLINT
155+
case -106:
156+
msg = {
157+
"CV_FWD_FAIL",
158+
"An error occured during the integration of the forward problem"};
159+
break; // NOLINT
160+
case -107:
161+
msg = {"CV_BAD_ITASK", "Wrong task for backward integration"};
162+
break; // NOLINT
163+
case -108:
164+
msg = {"CV_BAD_TBOUT",
165+
"The desired output time is outside the interval over which "
166+
"the forward problem was solved"};
167+
break; // NOLINT
168+
case -109:
169+
msg = {"CV_GETY_BADT", "Wrong time in interpolation function"};
170+
break; // NOLINT
171+
}
172+
}
173+
return msg;
174+
}
175+
10176
/**
11177
* Throws a std::runtime_error exception when a Sundial function fails
12178
* (i.e. returns a negative flag)
@@ -15,34 +181,59 @@ namespace math {
15181
* @param func_name Name of the function that returned the flag
16182
* @throw <code>std::runtime_error</code> if the flag is negative
17183
*/
18-
inline void check_flag_sundials(int flag, const char* func_name) {
184+
inline void cvodes_check(int flag, const char* func_name) {
19185
if (flag < 0) {
20186
std::ostringstream ss;
21-
ss << func_name << " failed with error flag " << flag << ".";
187+
ss << func_name << " failed with error flag " << flag << ": \n"
188+
<< cvodes_flag_msg(flag).at(1) << ".";
189+
if (flag == -1 || flag == -4) {
190+
throw std::domain_error(ss.str());
191+
} else {
192+
throw std::runtime_error(ss.str());
193+
}
194+
}
195+
}
196+
197+
/**
198+
* Throws an exception message when the functions in KINSOL
199+
* fails. "KINGetReturnFlagName()" from SUNDIALS has a mem leak bug so
200+
* until it's fixed we cannot use it to extract flag error string.
201+
*
202+
* @param flag Error flag
203+
* @param func_name calling function name
204+
* @throw <code>std::runtime_error</code> if the flag is negative.
205+
*/
206+
inline void kinsol_check(int flag, const char* func_name) {
207+
if (flag < 0) {
208+
std::ostringstream ss;
209+
ss << "algebra_solver failed with error flag " << flag << ".";
22210
throw std::runtime_error(ss.str());
23211
}
24212
}
25213

26214
/**
27-
* Throws an exception message when the function KINSol()
28-
* (call to the solver) fails. When the exception is caused
215+
* Throws an exception message when the KINSol() call fails.
216+
* When the exception is caused
29217
* by a tuning parameter the user controls, gives a specific
30218
* error.
31219
*
32220
* @param flag Error flag
33-
* @param max_num_steps Maximum number of iterations the algebra solver
34-
* should take before throwing an error
35-
* @throw <code>std::domain_error</code> if flag means maximum number of
36-
* iterations exceeded in the algebra solver.
37-
* @throw <code>std::runtime_error</code> if the flag is negative for
38-
* any other reason.
221+
* @param func_name calling function name
222+
* @param max_num_steps max number of nonlinear iterations.
223+
* @throw <code>std::runtime_error</code> if the flag is negative.
224+
* @throw <code>std::domain_error</code> if the flag indicates max
225+
* number of steps is exceeded.
39226
*/
40-
inline void check_flag_kinsol(int flag,
41-
long int max_num_steps) { // NOLINT(runtime/int)
227+
inline void kinsol_check(int flag, const char* func_name,
228+
long int max_num_steps) { // NOLINT(runtime/int)
42229
std::ostringstream ss;
43230
if (flag == -6) {
44-
throw_domain_error("algebra_solver", "maximum number of iterations",
45-
max_num_steps, "(", ") was exceeded in the solve.");
231+
domain_error("algebra_solver", "maximum number of iterations",
232+
max_num_steps, "(", ") was exceeded in the solve.");
233+
} else if (flag == -11) {
234+
ss << "The linear solver’s setup function failed "
235+
<< "in an unrecoverable manner.";
236+
throw std::runtime_error(ss.str());
46237
} else if (flag < 0) {
47238
ss << "algebra_solver failed with error flag " << flag << ".";
48239
throw std::runtime_error(ss.str());

stan/math/rev/functor/algebra_solver_fp.hpp

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -239,18 +239,13 @@ struct FixedPointSolver<KinsolFixedPointEnv<F>, fp_jac_type> {
239239
const int default_anderson_depth = 4;
240240
int anderson_depth = std::min(N, default_anderson_depth);
241241

242-
check_flag_sundials(KINSetNumMaxIters(mem, max_num_steps),
243-
"KINSetNumMaxIters");
244-
check_flag_sundials(KINSetMAA(mem, anderson_depth), "KINSetMAA");
245-
check_flag_sundials(KINInit(mem, &env.kinsol_f_system, env.nv_x_),
246-
"KINInit");
247-
check_flag_sundials(KINSetFuncNormTol(mem, f_tol), "KINSetFuncNormTol");
248-
check_flag_sundials(KINSetUserData(mem, static_cast<void*>(&env)),
249-
"KINSetUserData");
250-
251-
check_flag_kinsol(
252-
KINSol(mem, env.nv_x_, KIN_FP, env.nv_u_scal_, env.nv_f_scal_),
253-
max_num_steps);
242+
CHECK_KINSOL_CALL(KINSetNumMaxIters(mem, max_num_steps));
243+
CHECK_KINSOL_CALL(KINSetMAA(mem, anderson_depth));
244+
CHECK_KINSOL_CALL(KINInit(mem, &env.kinsol_f_system, env.nv_x_));
245+
CHECK_KINSOL_CALL(KINSetFuncNormTol(mem, f_tol));
246+
CHECK_KINSOL_CALL(KINSetUserData(mem, static_cast<void*>(&env)));
247+
kinsol_check(KINSol(mem, env.nv_x_, KIN_FP, env.nv_u_scal_, env.nv_f_scal_),
248+
"KINSol", max_num_steps);
254249

255250
for (int i = 0; i < N; ++i) {
256251
x(i) = NV_Ith_S(env.nv_x_, i);

stan/math/rev/functor/cvodes_integrator.hpp

Lines changed: 18 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -270,62 +270,44 @@ class cvodes_integrator {
270270
}
271271

272272
try {
273-
check_flag_sundials(CVodeInit(cvodes_mem, &cvodes_integrator::cv_rhs,
274-
value_of(t0_), nv_state_),
275-
"CVodeInit");
273+
CHECK_CVODES_CALL(CVodeInit(cvodes_mem, &cvodes_integrator::cv_rhs,
274+
value_of(t0_), nv_state_));
276275

277276
// Assign pointer to this as user data
278-
check_flag_sundials(
279-
CVodeSetUserData(cvodes_mem, reinterpret_cast<void*>(this)),
280-
"CVodeSetUserData");
277+
CHECK_CVODES_CALL(
278+
CVodeSetUserData(cvodes_mem, reinterpret_cast<void*>(this)));
281279

282280
cvodes_set_options(cvodes_mem, max_num_steps_);
283281

284-
check_flag_sundials(CVodeSStolerances(cvodes_mem, relative_tolerance_,
285-
absolute_tolerance_),
286-
"CVodeSStolerances");
282+
CHECK_CVODES_CALL(CVodeSStolerances(cvodes_mem, relative_tolerance_,
283+
absolute_tolerance_));
287284

288-
check_flag_sundials(CVodeSetLinearSolver(cvodes_mem, LS_, A_),
289-
"CVodeSetLinearSolver");
290-
check_flag_sundials(
291-
CVodeSetJacFn(cvodes_mem, &cvodes_integrator::cv_jacobian_states),
292-
"CVodeSetJacFn");
285+
CHECK_CVODES_CALL(CVodeSetLinearSolver(cvodes_mem, LS_, A_));
286+
CHECK_CVODES_CALL(
287+
CVodeSetJacFn(cvodes_mem, &cvodes_integrator::cv_jacobian_states));
293288

294289
// initialize forward sensitivity system of CVODES as needed
295290
if (num_y0_vars_ + num_args_vars_ > 0) {
296-
check_flag_sundials(
297-
CVodeSensInit(
298-
cvodes_mem, static_cast<int>(num_y0_vars_ + num_args_vars_),
299-
CV_STAGGERED, &cvodes_integrator::cv_rhs_sens, nv_state_sens_),
300-
"CVodeSensInit");
291+
CHECK_CVODES_CALL(CVodeSensInit(
292+
cvodes_mem, static_cast<int>(num_y0_vars_ + num_args_vars_),
293+
CV_STAGGERED, &cvodes_integrator::cv_rhs_sens, nv_state_sens_));
301294

302-
check_flag_sundials(CVodeSetSensErrCon(cvodes_mem, SUNTRUE),
303-
"CVodeSetSensErrCon");
295+
CHECK_CVODES_CALL(CVodeSetSensErrCon(cvodes_mem, SUNTRUE));
304296

305-
check_flag_sundials(CVodeSensEEtolerances(cvodes_mem),
306-
"CVodeSensEEtolerances");
297+
CHECK_CVODES_CALL(CVodeSensEEtolerances(cvodes_mem));
307298
}
308299

309300
double t_init = value_of(t0_);
310301
for (size_t n = 0; n < ts_.size(); ++n) {
311302
double t_final = value_of(ts_[n]);
312303

313304
if (t_final != t_init) {
314-
int error_code
315-
= CVode(cvodes_mem, t_final, nv_state_, &t_init, CV_NORMAL);
316-
317-
if (error_code == CV_TOO_MUCH_WORK) {
318-
throw_domain_error(function_name_, "", t_final,
319-
"Failed to integrate to next output time (",
320-
") in less than max_num_steps steps");
321-
} else {
322-
check_flag_sundials(error_code, "CVode");
323-
}
305+
CHECK_CVODES_CALL(
306+
CVode(cvodes_mem, t_final, nv_state_, &t_init, CV_NORMAL));
324307

325308
if (num_y0_vars_ + num_args_vars_ > 0) {
326-
check_flag_sundials(
327-
CVodeGetSens(cvodes_mem, &t_init, nv_state_sens_),
328-
"CVodeGetSens");
309+
CHECK_CVODES_CALL(
310+
CVodeGetSens(cvodes_mem, &t_init, nv_state_sens_));
329311
}
330312
}
331313

0 commit comments

Comments
 (0)