11#ifndef STAN_MATH_MIX_FUNCTOR_LAPLACE_MARGINAL_DENSITY_ESTIMATOR_HPP
22#define STAN_MATH_MIX_FUNCTOR_LAPLACE_MARGINAL_DENSITY_ESTIMATOR_HPP
33#include < stan/math/prim/fun/Eigen.hpp>
4+ #include < stan/math/prim/fun/generate_laplace_options.hpp>
45#include < stan/math/mix/functor/laplace_likelihood.hpp>
56#include < stan/math/mix/functor/wolfe_line_search.hpp>
67#include < stan/math/rev/meta.hpp>
1011#include < stan/math/mix/functor/barzilai_borwein_step_size.hpp>
1112#include < stan/math/prim/fun/to_ref.hpp>
1213#include < stan/math/prim/fun/quad_form_diag.hpp>
14+ #include < stan/math/prim/fun/value_of.hpp>
1315#include < stan/math/prim/functor/iter_tuple_nested.hpp>
1416#include < unsupported/Eigen/MatrixFunctions>
1517#include < cmath>
@@ -30,7 +32,7 @@ namespace math {
3032 */
3133struct laplace_options_base {
3234 /* Size of the blocks in block diagonal hessian*/
33- int hessian_block_size{1 }; // 0
35+ int hessian_block_size{internal::laplace_default_hessian_block_size }; // 0
3436 /* *
3537 * Which linear solver to use inside the Newton step.
3638 *
@@ -46,19 +48,20 @@ struct laplace_options_base {
4648 * `Sigma = K_root * K_root^T` and form `B = I + K_root^T * W * K_root`.
4749 * 3. General LU: form `B = I + Sigma * W` and factorize with LU.
4850 */
49- int solver{1 }; // 1
51+ int solver{internal::laplace_default_solver }; // 1
5052 /* *
5153 * Iterations end when the absolute change in the optimization objective
5254 * is less than this tolerance.
5355 *
5456 * Note: the objective used for convergence is the one optimized by the
5557 * Newton/Wolfe loop (not the final Laplace-corrected log marginal density).
5658 */
57- double tolerance{1.49012e-08 }; // 2
59+ double tolerance{internal::laplace_default_tolerance }; // 2
5860 /* Maximum number of steps*/
59- int max_num_steps{500 }; // 3
60- int allow_fallthrough{true }; // 4
61- laplace_line_search_options line_search; // 5
61+ int max_num_steps{internal::laplace_default_max_num_steps}; // 3
62+ int allow_fallthrough{internal::laplace_default_allow_fallthrough}; // 4
63+ laplace_line_search_options line_search{
64+ internal::laplace_default_max_steps_line_search}; // 5
6265 laplace_options_base () = default ;
6366 laplace_options_base (int hessian_block_size_, int solver_, double tolerance_,
6467 int max_num_steps_, bool allow_fallthrough_,
@@ -75,7 +78,13 @@ template <bool HasInitTheta>
7578struct laplace_options ;
7679
7780template <>
78- struct laplace_options <false > : public laplace_options_base {};
81+ struct laplace_options <false > : public laplace_options_base {
82+ laplace_options () = default ;
83+
84+ explicit laplace_options (int hessian_block_size_) {
85+ hessian_block_size = hessian_block_size_;
86+ }
87+ };
7988
8089template <>
8190struct laplace_options <true > : public laplace_options_base {
@@ -89,25 +98,12 @@ struct laplace_options<true> : public laplace_options_base {
8998 : laplace_options_base(hessian_block_size_, solver_, tolerance_,
9099 max_num_steps_, allow_fallthrough_,
91100 max_steps_line_search_),
92- theta_0 (std::forward<ThetaVec>(theta_0_)) {}
101+ theta_0 (value_of( std::forward<ThetaVec>(theta_0_) )) {}
93102};
94103
95104using laplace_options_default = laplace_options<false >;
96105using laplace_options_user_supplied = laplace_options<true >;
97106
98- /* *
99- * User function for generating laplace options tuple
100- * @param theta_0_size Size of user supplied initial theta
101- * @return tuple representing laplace options exposed to user.
102- */
103- inline auto generate_laplace_options (int theta_0_size) {
104- auto ops = laplace_options_default{};
105- return std::make_tuple (
106- Eigen::VectorXd::Zero (theta_0_size).eval (), // 0 -> 6
107- ops.tolerance , ops.max_num_steps , ops.hessian_block_size , ops.solver ,
108- ops.line_search .max_iterations , static_cast <int >(ops.allow_fallthrough ));
109- }
110-
111107namespace internal {
112108
113109template <typename Options>
@@ -135,41 +131,36 @@ inline constexpr auto tuple_to_laplace_options(Options&& ops) {
135131 " the laplace approximation." );
136132 }
137133 if constexpr (!stan::is_inner_tuple_type_v<3 , Ops, int >) {
138- static_assert (
139- sizeof (std::decay_t <Ops>*) == 0 ,
140- " ERROR:(laplace_marginal_lpdf) The fifth laplace argument is "
141- " expected to be an int representing the hessian block size." );
142- }
143- if constexpr (!stan::is_inner_tuple_type_v<4 , Ops, int >) {
144134 static_assert (
145135 sizeof (std::decay_t <Ops>*) == 0 ,
146136 " ERROR:(laplace_marginal_lpdf) The fourth laplace argument is "
147137 " expected to be an int representing the solver." );
148138 }
149- if constexpr (!stan::is_inner_tuple_type_v<5 , Ops, int >) {
139+ if constexpr (!stan::is_inner_tuple_type_v<4 , Ops, int >) {
150140 static_assert (
151141 sizeof (std::decay_t <Ops>*) == 0 ,
152- " ERROR:(laplace_marginal_lpdf) The sixth laplace argument is "
142+ " ERROR:(laplace_marginal_lpdf) The fifth laplace argument is "
153143 " expected to be an int representing the max steps for the laplace "
154144 " approximaton's wolfe line search." );
155145 }
156146 constexpr bool is_fallthrough
157147 = stan::is_inner_tuple_type_v<
158- 6 , Ops, int > || stan::is_inner_tuple_type_v<6 , Ops, bool >;
148+ 5 , Ops, int > || stan::is_inner_tuple_type_v<5 , Ops, bool >;
159149 if constexpr (!is_fallthrough) {
160150 static_assert (
161151 sizeof (std::decay_t <Ops>*) == 0 ,
162- " ERROR:(laplace_marginal_lpdf) The seventh laplace argument is "
152+ " ERROR:(laplace_marginal_lpdf) The sixth laplace argument is "
163153 " expected to be an int representing allow fallthrough (0/1)." );
164154 }
155+ auto defaults = laplace_options_default{};
165156 return laplace_options_user_supplied{
166157 value_of (std::get<0 >(std::forward<Ops>(ops))),
167158 std::get<1 >(ops),
168159 std::get<2 >(ops),
160+ defaults.hessian_block_size ,
169161 std::get<3 >(ops),
170162 std::get<4 >(ops),
171- std::get<5 >(ops),
172- (std::get<6 >(ops) > 0 ) ? true : false ,
163+ (std::get<5 >(ops) > 0 ) ? true : false ,
173164 };
174165 } else {
175166 return std::forward<Ops>(ops);
0 commit comments