@@ -85,47 +85,47 @@ class cvodes_integrator_memory : public chainable_alloc {
8585 SUNLinearSolver LS_;
8686
8787 cvodes_integrator_memory (const F& f,
88- const Eigen::Matrix<T_y0, Eigen::Dynamic, 1 >& y0,
89- const T_t0& t0, const std::vector<T_ts>& ts,
90- const T_Args&... args) :
91- N_ (y0.size()),
92- f_ (f),
93- y0_ (y0),
94- t0_ (t0),
95- ts_ (ts),
96- args_tuple_ (std::make_tuple(args...)),
97- value_of_args_tuple_ (std::make_tuple(value_of(args)...)),
98- y_ (ts_.size()),
99- cvodes_mem_ (nullptr ),
100- state (value_of(y0)) {
101- if (N_ > 0 ) {
88+ const Eigen::Matrix<T_y0, Eigen::Dynamic, 1 >& y0,
89+ const T_t0& t0, const std::vector<T_ts>& ts,
90+ const T_Args&... args)
91+ : N_(y0.size()),
92+ f_ (f),
93+ y0_(y0),
94+ t0_(t0),
95+ ts_(ts),
96+ args_tuple_(std::make_tuple(args...)),
97+ value_of_args_tuple_(std::make_tuple(value_of(args)...)),
98+ y_(ts_.size()),
99+ cvodes_mem_(nullptr ),
100+ state(value_of(y0)) {
101+ if (N_ > 0 ) {
102102 nv_state_ = N_VMake_Serial (N_, state.data ());
103103 A_ = SUNDenseMatrix (N_, N_);
104104 LS_ = SUNDenseLinearSolver (nv_state_, A_);
105105
106106 cvodes_mem_ = CVodeCreate (Lmm);
107107 if (cvodes_mem_ == nullptr ) {
108- throw std::runtime_error (" CVodeCreate failed to allocate memory" );
108+ throw std::runtime_error (" CVodeCreate failed to allocate memory" );
109109 }
110110 }
111111 }
112112
113113 ~cvodes_integrator_memory () {
114- if (N_ > 0 ) {
114+ if (N_ > 0 ) {
115115 SUNLinSolFree (LS_);
116116 SUNMatDestroy (A_);
117117
118118 N_VDestroy_Serial (nv_state_);
119119
120- if (cvodes_mem_) {
121- CVodeFree (&cvodes_mem_);
120+ if (cvodes_mem_) {
121+ CVodeFree (&cvodes_mem_);
122122 }
123123 }
124124 }
125125
126126 friend class cvodes_integrator_vari <Lmm, F, T_y0, T_t0, T_ts, T_Args...>;
127127};
128-
128+
129129/* *
130130 * Integrator interface for CVODES' ODE solvers (Adams & BDF
131131 * methods).
@@ -233,9 +233,9 @@ class cvodes_integrator_vari : public vari {
233233 inline void rhs (double t, const double y[], double dy_dt[]) const {
234234 const Eigen::VectorXd y_vec = Eigen::Map<const Eigen::VectorXd>(y, N_);
235235
236- Eigen::VectorXd dy_dt_vec
237- = apply ( [&](auto &&... args) { return memory->f_ (t, y_vec, msgs_, args...); },
238- memory->value_of_args_tuple_ );
236+ Eigen::VectorXd dy_dt_vec = apply (
237+ [&](auto &&... args) { return memory->f_ (t, y_vec, msgs_, args...); },
238+ memory->value_of_args_tuple_ );
239239
240240 check_size_match (" cvodes_integrator::rhs" , " dy_dt" , dy_dt_vec.size (),
241241 " states" , N_);
@@ -266,9 +266,9 @@ class cvodes_integrator_vari : public vari {
266266 for (size_t i = 0 ; i < y_vars.size (); ++i)
267267 y_vars (i) = new vari (y_vec (i));
268268
269- Eigen::Matrix<var, Eigen::Dynamic, 1 > f_y_t_vars
270- = apply ( [&](auto &&... args) { return memory->f_ (t, y_vars, msgs_, args...); },
271- memory->value_of_args_tuple_ );
269+ Eigen::Matrix<var, Eigen::Dynamic, 1 > f_y_t_vars = apply (
270+ [&](auto &&... args) { return memory->f_ (t, y_vars, msgs_, args...); },
271+ memory->value_of_args_tuple_ );
272272
273273 check_size_match (" coupled_ode_system1" , " dy_dt" , f_y_t_vars.size (),
274274 " states" , N_);
@@ -303,15 +303,15 @@ class cvodes_integrator_vari : public vari {
303303 nested_rev_autodiff nested;
304304
305305 auto local_args_tuple = apply (
306- [&](auto &&... args) {
306+ [&](auto &&... args) {
307307 return std::tuple<decltype (deep_copy_vars (args))...>(
308308 deep_copy_vars (args)...);
309309 },
310310 memory->args_tuple_ );
311311
312- Eigen::Matrix<var, Eigen::Dynamic, 1 > f_y_t_vars
313- = apply ( [&](auto &&... args) { return memory->f_ (t, y_vec, msgs_, args...); },
314- local_args_tuple);
312+ Eigen::Matrix<var, Eigen::Dynamic, 1 > f_y_t_vars = apply (
313+ [&](auto &&... args) { return memory->f_ (t, y_vec, msgs_, args...); },
314+ local_args_tuple);
315315
316316 check_size_match (" coupled_ode_system2" , " dy_dt" , f_y_t_vars.size (),
317317 " states" , N_);
@@ -335,8 +335,9 @@ class cvodes_integrator_vari : public vari {
335335 Eigen::MatrixXd Jfy;
336336
337337 auto f_wrapped = [&](const Eigen::Matrix<var, Eigen::Dynamic, 1 >& y) {
338- return apply ([&](auto &&... args) { return memory->f_ (t, y, msgs_, args...); },
339- memory->value_of_args_tuple_ );
338+ return apply (
339+ [&](auto &&... args) { return memory->f_ (t, y, msgs_, args...); },
340+ memory->value_of_args_tuple_ );
340341 };
341342
342343 jacobian (f_wrapped, Eigen::Map<const Eigen::VectorXd>(NV_DATA_S (y), N_), fy,
@@ -362,8 +363,9 @@ class cvodes_integrator_vari : public vari {
362363 Eigen::MatrixXd Jfy;
363364
364365 auto f_wrapped = [&](const Eigen::Matrix<var, Eigen::Dynamic, 1 >& y) {
365- return apply ([&](auto &&... args) { return memory->f_ (t, y, msgs_, args...); },
366- memory->value_of_args_tuple_ );
366+ return apply (
367+ [&](auto &&... args) { return memory->f_ (t, y, msgs_, args...); },
368+ memory->value_of_args_tuple_ );
367369 };
368370
369371 jacobian (f_wrapped, Eigen::Map<const Eigen::VectorXd>(NV_DATA_S (y), N_), fy,
@@ -404,8 +406,8 @@ class cvodes_integrator_vari : public vari {
404406 const T_Args&... args)
405407 : vari(NOT_A_NUMBER),
406408 N_ (y0.size()),
407- returned_(false ),
408- memory(NULL ),
409+ returned_(false ),
410+ memory(NULL ),
409411 msgs_(msgs),
410412 relative_tolerance_(relative_tolerance),
411413 absolute_tolerance_(absolute_tolerance),
@@ -424,9 +426,9 @@ class cvodes_integrator_vari : public vari {
424426 args_vars_)) {
425427 const char * fun = " cvodes_integrator::integrate" ;
426428
427- memory = new cvodes_integrator_memory
428- <Lmm, F, T_y0, T_t0, T_ts, T_Args...>( f, y0, t0, ts, args...);
429-
429+ memory = new cvodes_integrator_memory<Lmm, F, T_y0, T_t0, T_ts, T_Args...>(
430+ f, y0, t0, ts, args...);
431+
430432 save_varis (t0_varis_, t0);
431433 save_varis (ts_varis_, ts);
432434 save_varis (y0_varis_, y0);
@@ -454,8 +456,7 @@ class cvodes_integrator_vari : public vari {
454456 check_positive (fun, " max_num_steps" , max_num_steps_);
455457 }
456458
457- ~cvodes_integrator_vari () {
458- }
459+ ~cvodes_integrator_vari () {}
459460
460461 /* *
461462 * Solve the ODE initial value problem y' = f(t, y), y(t0) = y0 at a set of
@@ -469,51 +470,54 @@ class cvodes_integrator_vari : public vari {
469470 const double t0_dbl = value_of (memory->t0_ );
470471 const std::vector<double > ts_dbl = value_of (memory->ts_ );
471472
472- check_flag_sundials (CVodeInit (memory->cvodes_mem_ , &cvodes_integrator_vari::cv_rhs, t0_dbl,
473- memory->nv_state_ ),
474- " CVodeInit" );
473+ check_flag_sundials (
474+ CVodeInit (memory->cvodes_mem_ , &cvodes_integrator_vari::cv_rhs, t0_dbl,
475+ memory->nv_state_ ),
476+ " CVodeInit" );
475477
476478 // Assign pointer to this as user data
477479 check_flag_sundials (
478- CVodeSetUserData (memory->cvodes_mem_ , reinterpret_cast <void *>(this )),
479- " CVodeSetUserData" );
480+ CVodeSetUserData (memory->cvodes_mem_ , reinterpret_cast <void *>(this )),
481+ " CVodeSetUserData" );
480482
481- cvodes_set_options (memory->cvodes_mem_ , relative_tolerance_, absolute_tolerance_,
482- max_num_steps_);
483+ cvodes_set_options (memory->cvodes_mem_ , relative_tolerance_,
484+ absolute_tolerance_, max_num_steps_);
483485
484486 // for the stiff solvers we need to reserve additional memory
485487 // and provide a Jacobian function call. new API since 3.0.0:
486488 // create matrix object and linear solver object; resource
487489 // (de-)allocation is handled in the cvodes_ode_data
488- check_flag_sundials (CVodeSetLinearSolver (memory->cvodes_mem_ , memory->LS_ , memory->A_ ),
489- " CVodeSetLinearSolver" );
490+ check_flag_sundials (
491+ CVodeSetLinearSolver (memory->cvodes_mem_ , memory->LS_ , memory->A_ ),
492+ " CVodeSetLinearSolver" );
490493
491494 check_flag_sundials (
492- CVodeSetJacFn (memory->cvodes_mem_ ,
493- &cvodes_integrator_vari::cv_jacobian_states),
494- " CVodeSetJacFn" );
495+ CVodeSetJacFn (memory->cvodes_mem_ ,
496+ &cvodes_integrator_vari::cv_jacobian_states),
497+ " CVodeSetJacFn" );
495498
496499 // initialize forward sensitivity system of CVODES as needed
497500 if (t0_vars_ + ts_vars_ + y0_vars_ + args_vars_ > 0 ) {
498501 check_flag_sundials (CVodeAdjInit (memory->cvodes_mem_ , 25 , CV_HERMITE),
499- " CVodeAdjInit" );
502+ " CVodeAdjInit" );
500503 }
501504
502505 double t_init = t0_dbl;
503506 for (size_t n = 0 ; n < ts_dbl.size (); ++n) {
504507 double t_final = ts_dbl[n];
505508
506509 if (t_final != t_init) {
507- if (t0_vars_ + ts_vars_ + y0_vars_ + args_vars_ > 0 ) {
508- int ncheck;
509- check_flag_sundials (CVodeF (memory->cvodes_mem_ , t_final, memory->nv_state_ , &t_init,
510- CV_NORMAL, &ncheck),
511- " CVodeF" );
512- } else {
513- check_flag_sundials (
514- CVode (memory->cvodes_mem_ , t_final, memory->nv_state_ , &t_init, CV_NORMAL),
515- " CVode" );
516- }
510+ if (t0_vars_ + ts_vars_ + y0_vars_ + args_vars_ > 0 ) {
511+ int ncheck;
512+ check_flag_sundials (
513+ CVodeF (memory->cvodes_mem_ , t_final, memory->nv_state_ , &t_init,
514+ CV_NORMAL, &ncheck),
515+ " CVodeF" );
516+ } else {
517+ check_flag_sundials (CVode (memory->cvodes_mem_ , t_final,
518+ memory->nv_state_ , &t_init, CV_NORMAL),
519+ " CVode" );
520+ }
517521 }
518522
519523 memory->y_ [n] = memory->state ;
@@ -528,13 +532,13 @@ class cvodes_integrator_vari : public vari {
528532 virtual void chain () {
529533 // std::cout << "chain" << std::endl; <-- Good way to verify it's only
530534 // being called once
531- if (memory == NULL )
535+ if (memory == NULL )
532536 return ;
533537
534- if (memory->cvodes_mem_ == NULL )
538+ if (memory->cvodes_mem_ == NULL )
535539 return ;
536-
537- if (returned_ == false )
540+
541+ if (returned_ == false )
538542 return ;
539543
540544 if (t0_vars_ + ts_vars_ + y0_vars_ + args_vars_ == 0 ) {
@@ -559,25 +563,27 @@ class cvodes_integrator_vari : public vari {
559563 check_flag_sundials (CVodeCreateB (memory->cvodes_mem_ , Lmm, &indexB),
560564 " CVodeCreateB" );
561565
562- check_flag_sundials (
563- CVodeSetUserDataB (memory-> cvodes_mem_ , indexB, reinterpret_cast <void *>(this )),
564- " CVodeSetUserDataB" );
566+ check_flag_sundials (CVodeSetUserDataB (memory-> cvodes_mem_ , indexB,
567+ reinterpret_cast <void *>(this )),
568+ " CVodeSetUserDataB" );
565569
566570 // The ode_rhs_adj_sense functions passed in here cause problems with
567571 // the autodiff stack (they can cause reallocations of the internal
568572 // vectors and cause segfaults)
569- check_flag_sundials (CVodeInitB (memory->cvodes_mem_ , indexB,
570- &cvodes_integrator_vari::cv_rhs_adj_sens,
571- value_of (memory->ts_ .back ()), nv_state_sens),
572- " CVodeInitB" );
573+ check_flag_sundials (
574+ CVodeInitB (memory->cvodes_mem_ , indexB,
575+ &cvodes_integrator_vari::cv_rhs_adj_sens,
576+ value_of (memory->ts_ .back ()), nv_state_sens),
577+ " CVodeInitB" );
573578
574579 check_flag_sundials (
575580 CVodeSStolerancesB (memory->cvodes_mem_ , indexB, relative_tolerance_,
576581 absolute_tolerance_),
577582 " CVodeSStolerancesB" );
578583
579- check_flag_sundials (CVodeSetLinearSolverB (memory->cvodes_mem_ , indexB, LSB_, AB_),
580- " CVodeSetLinearSolverB" );
584+ check_flag_sundials (
585+ CVodeSetLinearSolverB (memory->cvodes_mem_ , indexB, LSB_, AB_),
586+ " CVodeSetLinearSolverB" );
581587
582588 // The same autodiff issue that applies to ode_rhs_adj_sense applies
583589 // here
@@ -594,12 +600,13 @@ class cvodes_integrator_vari : public vari {
594600 " CVodeQuadInitB" );
595601
596602 check_flag_sundials (
597- CVodeQuadSStolerancesB (memory->cvodes_mem_ , indexB, relative_tolerance_,
598- absolute_tolerance_),
603+ CVodeQuadSStolerancesB (memory->cvodes_mem_ , indexB,
604+ relative_tolerance_, absolute_tolerance_),
599605 " CVodeQuadSStolerancesB" );
600606
601- check_flag_sundials (CVodeSetQuadErrConB (memory->cvodes_mem_ , indexB, SUNTRUE),
602- " CVodeSetQuadErrConB" );
607+ check_flag_sundials (
608+ CVodeSetQuadErrConB (memory->cvodes_mem_ , indexB, SUNTRUE),
609+ " CVodeSetQuadErrConB" );
603610 }
604611
605612 // At every time step, collect the adjoints from the output
@@ -608,18 +615,19 @@ class cvodes_integrator_vari : public vari {
608615 for (int i = memory->ts_ .size () - 1 ; i >= 0 ; --i) {
609616 // Take in the adjoints from all the output variables at this point
610617 // in time
611- Eigen::VectorXd step_sens = Eigen::VectorXd::Zero (N_);
618+ Eigen::VectorXd step_sens = Eigen::VectorXd::Zero (N_);
612619 for (int j = 0 ; j < N_; j++) {
613- // std::cout << "i: " << i << ", j: " << j << std::endl;
620+ // std::cout << "i: " << i << ", j: " << j << std::endl;
614621 state_sens (j) += non_chaining_varis_[i * N_ + j]->adj_ ;
615622 step_sens (j) += non_chaining_varis_[i * N_ + j]->adj_ ;
616623 }
617624
618625 if (ts_vars_ > 0 && i >= 0 ) {
619626 ts_varis_[i]->adj_ += apply (
620627 [&](auto &&... args) {
621- double adj = step_sens.dot (memory->f_ (t_init, memory->y_ [i], msgs_, args...));
622- // std::cout << "adj: " << adj << ", i: " << i << std::endl;
628+ double adj = step_sens.dot (
629+ memory->f_ (t_init, memory->y_ [i], msgs_, args...));
630+ // std::cout << "adj: " << adj << ", i: " << i << std::endl;
623631 return adj;
624632 },
625633 memory->value_of_args_tuple_ );
@@ -631,27 +639,29 @@ class cvodes_integrator_vari : public vari {
631639 CVodeReInitB (memory->cvodes_mem_ , indexB, t_init, nv_state_sens),
632640 " CVodeReInitB" );
633641
634- if (args_vars_ > 0 ) {
635- check_flag_sundials (CVodeQuadReInitB (memory->cvodes_mem_ , indexB, nv_quad),
636- " CVodeQuadReInitB" );
637- }
642+ if (args_vars_ > 0 ) {
643+ check_flag_sundials (
644+ CVodeQuadReInitB (memory->cvodes_mem_ , indexB, nv_quad),
645+ " CVodeQuadReInitB" );
646+ }
638647
639648 check_flag_sundials (CVodeB (memory->cvodes_mem_ , t_final, CV_NORMAL),
640649 " CVodeB" );
641650
642- check_flag_sundials (
643- CVodeGetB (memory->cvodes_mem_ , indexB, &t_init, nv_state_sens),
644- " CVodeGetB" );
645-
646- if (args_vars_ > 0 ) {
647- check_flag_sundials (CVodeGetQuadB (memory->cvodes_mem_ , indexB, &t_init, nv_quad),
648- " CVodeGetQuadB" );
649- }
651+ check_flag_sundials (
652+ CVodeGetB (memory->cvodes_mem_ , indexB, &t_init, nv_state_sens),
653+ " CVodeGetB" );
654+
655+ if (args_vars_ > 0 ) {
656+ check_flag_sundials (
657+ CVodeGetQuadB (memory->cvodes_mem_ , indexB, &t_init, nv_quad),
658+ " CVodeGetQuadB" );
659+ }
650660 }
651661 }
652662
653663 if (t0_vars_ > 0 ) {
654- Eigen::VectorXd y0d = value_of (memory->y0_ );
664+ Eigen::VectorXd y0d = value_of (memory->y0_ );
655665 t0_varis_[0 ]->adj_ += apply (
656666 [&](auto &&... args) {
657667 return -state_sens.dot (memory->f_ (t_init, y0d, msgs_, args...));
0 commit comments