@@ -47,19 +47,21 @@ struct laplace_options_base {
4747 int max_num_steps{100 };
4848};
4949
50- template <typename Theta, typename = void >
50+ template <bool HasInitTheta >
5151struct laplace_options ;
5252
53- template <typename Theta>
54- struct laplace_options <Theta, require_eigen_t <Theta>> : public laplace_options_base {
53+ template <>
54+ struct laplace_options <false > : public laplace_options_base {};
55+
56+ template <>
57+ struct laplace_options <true > : public laplace_options_base {
5558 /* Value for user supplied initial theta */
56- Theta theta_0{0 };
59+ Eigen::VectorXd theta_0{0 };
5760};
5861
59- template <typename Theta>
60- struct laplace_options <Theta, require_not_eigen_t <Theta>> : public laplace_options_base {};
6162
62- using laplace_options_default = laplace_options<void >;
63+ using laplace_options_default = laplace_options<false >;
64+ using laplace_options_user_supplied = laplace_options<true >;
6365namespace internal {
6466
6567template <typename Covar, typename ThetaVec, typename WR, typename L_t,
@@ -462,17 +464,17 @@ inline STAN_COLD_PATH void throw_nan(NameStr&& name_str, ParamStr&& param_str,
462464 *
463465 */
464466template <typename LLFun, typename LLTupleArgs, typename CovarFun,
465- typename CovarArgs, typename ThetaVec ,
467+ typename CovarArgs, bool InitTheta ,
466468 require_t <is_all_arithmetic_scalar<CovarArgs>>* = nullptr >
467469inline auto laplace_marginal_density_est (LLFun&& ll_fun, LLTupleArgs&& ll_args,
468470 CovarFun&& covariance_function,
469471 CovarArgs&& covar_args,
470- const laplace_options<ThetaVec >& options,
472+ const laplace_options<InitTheta >& options,
471473 std::ostream* msgs) {
472474 using Eigen::MatrixXd;
473475 using Eigen::SparseMatrix;
474476 using Eigen::VectorXd;
475- if constexpr (is_eigen_v<ThetaVec> ) {
477+ if constexpr (InitTheta ) {
476478 check_nonzero_size (" laplace_marginal" , " initial guess" , options.theta_0 );
477479 check_finite (" laplace_marginal" , " initial guess" , options.theta_0 );
478480 }
@@ -520,7 +522,7 @@ inline auto laplace_marginal_density_est(LLFun&& ll_fun, LLTupleArgs&& ll_args,
520522 };
521523 auto ll_args_vals = value_of (ll_args);
522524 Eigen::VectorXd theta = [theta_size, &options]() {
523- if constexpr (is_eigen_v<ThetaVec> ) {
525+ if constexpr (InitTheta ) {
524526 return options.theta_0 ;
525527 } else {
526528 return Eigen::VectorXd::Zero (theta_size);
@@ -794,12 +796,12 @@ inline auto laplace_marginal_density_est(LLFun&& ll_fun, LLTupleArgs&& ll_args,
794796 * @return the log maginal density, p(y | phi)
795797 */
796798template <
797- typename LLFun, typename LLTupleArgs, typename CovarFun, typename CovarArgs, typename ThetaVec ,
799+ typename LLFun, typename LLTupleArgs, typename CovarFun, typename CovarArgs, bool InitTheta ,
798800 require_t <is_all_arithmetic_scalar<CovarArgs, LLTupleArgs>>* = nullptr >
799801inline double laplace_marginal_density (LLFun&& ll_fun, LLTupleArgs&& ll_args,
800802 CovarFun&& covariance_function,
801803 CovarArgs&& covar_args,
802- const laplace_options<ThetaVec >& options,
804+ const laplace_options<InitTheta >& options,
803805 std::ostream* msgs) {
804806 return internal::laplace_marginal_density_est (
805807 std::forward<LLFun>(ll_fun), std::forward<LLTupleArgs>(ll_args),
@@ -1036,12 +1038,12 @@ inline void reverse_pass_collect_adjoints(var ret, Output&& output,
10361038 * @return the log maginal density, p(y | phi)
10371039 */
10381040template <typename LLFun, typename LLTupleArgs, typename CovarFun,
1039- typename CovarArgs, typename ThetaVec ,
1041+ typename CovarArgs, bool InitTheta ,
10401042 require_t <is_any_var_scalar<LLTupleArgs, CovarArgs>>* = nullptr >
10411043inline auto laplace_marginal_density (const LLFun& ll_fun, LLTupleArgs&& ll_args,
10421044 CovarFun&& covariance_function,
10431045 CovarArgs&& covar_args,
1044- const laplace_options<ThetaVec >& options,
1046+ const laplace_options<InitTheta >& options,
10451047 std::ostream* msgs) {
10461048 auto covar_args_refs = to_ref (std::forward<CovarArgs>(covar_args));
10471049 auto ll_args_refs = to_ref (std::forward<LLTupleArgs>(ll_args));
0 commit comments