@@ -27,7 +27,7 @@ namespace math {
2727/* *
2828 * Options for the laplace sampler
2929 */
30- struct laplace_options {
30+ struct laplace_options_base {
3131 /* Size of the blocks in block diagonal hessian*/
3232 int hessian_block_size{1 };
3333 /* *
@@ -45,11 +45,23 @@ struct laplace_options {
4545 double tolerance{1e-6 };
4646 /* Maximum number of steps*/
4747 int max_num_steps{100 };
48+ };
49+
50+ template <bool HasInitTheta>
51+ struct laplace_options ;
4852
49- /* Initial value for theta. Defaults to 0s of the correct size if nullopt */
50- std::optional<Eigen::VectorXd> theta_0{std::nullopt };
53+ template <>
54+ struct laplace_options <false > : public laplace_options_base {};
55+
56+ template <>
57+ struct laplace_options <true > : public laplace_options_base {
58+ /* Value for user supplied initial theta */
59+ Eigen::VectorXd theta_0{0 };
5160};
5261
62+
63+ using laplace_options_default = laplace_options<false >;
64+ using laplace_options_user_supplied = laplace_options<true >;
5365namespace internal {
5466
5567template <typename Covar, typename ThetaVec, typename WR, typename L_t,
@@ -452,21 +464,20 @@ inline STAN_COLD_PATH void throw_nan(NameStr&& name_str, ParamStr&& param_str,
452464 *
453465 */
454466template <typename LLFun, typename LLTupleArgs, typename CovarFun,
455- typename CovarArgs,
467+ typename CovarArgs, bool InitTheta,
456468 require_t <is_all_arithmetic_scalar<CovarArgs>>* = nullptr >
457469inline auto laplace_marginal_density_est (LLFun&& ll_fun, LLTupleArgs&& ll_args,
458470 CovarFun&& covariance_function,
459471 CovarArgs&& covar_args,
460- const laplace_options& options,
472+ const laplace_options<InitTheta> & options,
461473 std::ostream* msgs) {
462474 using Eigen::MatrixXd;
463475 using Eigen::SparseMatrix;
464476 using Eigen::VectorXd;
465- if (options. theta_0 . has_value () ) {
466- check_nonzero_size (" laplace_marginal" , " initial guess" , * options.theta_0 );
467- check_finite (" laplace_marginal" , " initial guess" , * options.theta_0 );
477+ if constexpr (InitTheta ) {
478+ check_nonzero_size (" laplace_marginal" , " initial guess" , options.theta_0 );
479+ check_finite (" laplace_marginal" , " initial guess" , options.theta_0 );
468480 }
469-
470481 check_nonnegative (" laplace_marginal" , " tolerance" , options.tolerance );
471482 check_positive (" laplace_marginal" , " max_num_steps" , options.max_num_steps );
472483 check_positive (" laplace_marginal" , " hessian_block_size" ,
@@ -510,9 +521,13 @@ inline auto laplace_marginal_density_est(LLFun&& ll_fun, LLTupleArgs&& ll_args,
510521 + std::to_string (max_num_steps) + " exceeded." );
511522 };
512523 auto ll_args_vals = value_of (ll_args);
513- Eigen::VectorXd theta = options.theta_0 .has_value ()
514- ? *options.theta_0
515- : Eigen::VectorXd::Zero (theta_size);
524+ Eigen::VectorXd theta = [theta_size, &options]() {
525+ if constexpr (InitTheta) {
526+ return options.theta_0 ;
527+ } else {
528+ return Eigen::VectorXd::Zero (theta_size);
529+ }
530+ }();
516531 double objective_old = std::numeric_limits<double >::lowest ();
517532 double objective_new = std::numeric_limits<double >::lowest () + 1 ;
518533 Eigen::VectorXd a_prev = Eigen::VectorXd::Zero (theta_size);
@@ -584,7 +599,7 @@ inline auto laplace_marginal_density_est(LLFun&& ll_fun, LLTupleArgs&& ll_args,
584599 }
585600 }
586601 } else {
587- Eigen::SparseMatrix<double > W_r (theta. rows (), theta. rows () );
602+ Eigen::SparseMatrix<double > W_r (theta_size, theta_size );
588603 Eigen::Index block_size = options.hessian_block_size ;
589604 W_r.reserve (Eigen::VectorXi::Constant (W_r.cols (), block_size));
590605 const Eigen::Index n_block = W_r.cols () / block_size;
@@ -781,12 +796,12 @@ inline auto laplace_marginal_density_est(LLFun&& ll_fun, LLTupleArgs&& ll_args,
781796 * @return the log maginal density, p(y | phi)
782797 */
783798template <
784- typename LLFun, typename LLTupleArgs, typename CovarFun, typename CovarArgs,
799+ typename LLFun, typename LLTupleArgs, typename CovarFun, typename CovarArgs, bool InitTheta,
785800 require_t <is_all_arithmetic_scalar<CovarArgs, LLTupleArgs>>* = nullptr >
786801inline double laplace_marginal_density (LLFun&& ll_fun, LLTupleArgs&& ll_args,
787802 CovarFun&& covariance_function,
788803 CovarArgs&& covar_args,
789- const laplace_options& options,
804+ const laplace_options<InitTheta> & options,
790805 std::ostream* msgs) {
791806 return internal::laplace_marginal_density_est (
792807 std::forward<LLFun>(ll_fun), std::forward<LLTupleArgs>(ll_args),
@@ -1023,12 +1038,12 @@ inline void reverse_pass_collect_adjoints(var ret, Output&& output,
10231038 * @return the log maginal density, p(y | phi)
10241039 */
10251040template <typename LLFun, typename LLTupleArgs, typename CovarFun,
1026- typename CovarArgs,
1041+ typename CovarArgs, bool InitTheta,
10271042 require_t <is_any_var_scalar<LLTupleArgs, CovarArgs>>* = nullptr >
10281043inline auto laplace_marginal_density (const LLFun& ll_fun, LLTupleArgs&& ll_args,
10291044 CovarFun&& covariance_function,
10301045 CovarArgs&& covar_args,
1031- const laplace_options& options,
1046+ const laplace_options<InitTheta> & options,
10321047 std::ostream* msgs) {
10331048 auto covar_args_refs = to_ref (std::forward<CovarArgs>(covar_args));
10341049 auto ll_args_refs = to_ref (std::forward<LLTupleArgs>(ll_args));
0 commit comments