Skip to content

Commit 84f46d6

Browse files
committed
update laplace options to just use a default flag
1 parent c5e1fde commit 84f46d6

12 files changed

+30
-28
lines changed

stan/math/mix/functor/laplace_base_rng.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,12 @@ namespace math {
3232
* \msg_arg
3333
*/
3434
template <typename LLFunc, typename LLArgs, typename CovarFun,
35-
typename CovarArgs, typename ThetaVec, typename RNG,
35+
typename CovarArgs, bool InitTheta, typename RNG,
3636
require_t<is_all_arithmetic_scalar<CovarArgs, LLArgs>>* = nullptr>
3737
inline Eigen::VectorXd laplace_base_rng(LLFunc&& ll_fun, LLArgs&& ll_args,
3838
CovarFun&& covariance_function,
3939
CovarArgs&& covar_args,
40-
const laplace_options<ThetaVec>& options,
40+
const laplace_options<InitTheta>& options,
4141
RNG& rng, std::ostream* msgs) {
4242
auto md_est = internal::laplace_marginal_density_est(
4343
ll_fun, std::forward<LLArgs>(ll_args),

stan/math/mix/functor/laplace_marginal_density.hpp

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -47,19 +47,21 @@ struct laplace_options_base {
4747
int max_num_steps{100};
4848
};
4949

50-
template <typename Theta, typename = void>
50+
template <bool HasInitTheta>
5151
struct laplace_options;
5252

53-
template <typename Theta>
54-
struct laplace_options<Theta, require_eigen_t<Theta>> : public laplace_options_base {
53+
template <>
54+
struct laplace_options<false> : public laplace_options_base {};
55+
56+
template <>
57+
struct laplace_options<true> : public laplace_options_base {
5558
/* Value for user supplied initial theta */
56-
Theta theta_0{0};
59+
Eigen::VectorXd theta_0{0};
5760
};
5861

59-
template <typename Theta>
60-
struct laplace_options<Theta, require_not_eigen_t<Theta>> : public laplace_options_base {};
6162

62-
using laplace_options_default = laplace_options<void>;
63+
using laplace_options_default = laplace_options<false>;
64+
using laplace_options_user_supplied = laplace_options<true>;
6365
namespace internal {
6466

6567
template <typename Covar, typename ThetaVec, typename WR, typename L_t,
@@ -462,17 +464,17 @@ inline STAN_COLD_PATH void throw_nan(NameStr&& name_str, ParamStr&& param_str,
462464
*
463465
*/
464466
template <typename LLFun, typename LLTupleArgs, typename CovarFun,
465-
typename CovarArgs, typename ThetaVec,
467+
typename CovarArgs, bool InitTheta,
466468
require_t<is_all_arithmetic_scalar<CovarArgs>>* = nullptr>
467469
inline auto laplace_marginal_density_est(LLFun&& ll_fun, LLTupleArgs&& ll_args,
468470
CovarFun&& covariance_function,
469471
CovarArgs&& covar_args,
470-
const laplace_options<ThetaVec>& options,
472+
const laplace_options<InitTheta>& options,
471473
std::ostream* msgs) {
472474
using Eigen::MatrixXd;
473475
using Eigen::SparseMatrix;
474476
using Eigen::VectorXd;
475-
if constexpr (is_eigen_v<ThetaVec>) {
477+
if constexpr (InitTheta) {
476478
check_nonzero_size("laplace_marginal", "initial guess", options.theta_0);
477479
check_finite("laplace_marginal", "initial guess", options.theta_0);
478480
}
@@ -520,7 +522,7 @@ inline auto laplace_marginal_density_est(LLFun&& ll_fun, LLTupleArgs&& ll_args,
520522
};
521523
auto ll_args_vals = value_of(ll_args);
522524
Eigen::VectorXd theta = [theta_size, &options]() {
523-
if constexpr (is_eigen_v<ThetaVec>) {
525+
if constexpr (InitTheta) {
524526
return options.theta_0;
525527
} else {
526528
return Eigen::VectorXd::Zero(theta_size);
@@ -794,12 +796,12 @@ inline auto laplace_marginal_density_est(LLFun&& ll_fun, LLTupleArgs&& ll_args,
794796
* @return the log maginal density, p(y | phi)
795797
*/
796798
template <
797-
typename LLFun, typename LLTupleArgs, typename CovarFun, typename CovarArgs, typename ThetaVec,
799+
typename LLFun, typename LLTupleArgs, typename CovarFun, typename CovarArgs, bool InitTheta,
798800
require_t<is_all_arithmetic_scalar<CovarArgs, LLTupleArgs>>* = nullptr>
799801
inline double laplace_marginal_density(LLFun&& ll_fun, LLTupleArgs&& ll_args,
800802
CovarFun&& covariance_function,
801803
CovarArgs&& covar_args,
802-
const laplace_options<ThetaVec>& options,
804+
const laplace_options<InitTheta>& options,
803805
std::ostream* msgs) {
804806
return internal::laplace_marginal_density_est(
805807
std::forward<LLFun>(ll_fun), std::forward<LLTupleArgs>(ll_args),
@@ -1036,12 +1038,12 @@ inline void reverse_pass_collect_adjoints(var ret, Output&& output,
10361038
* @return the log maginal density, p(y | phi)
10371039
*/
10381040
template <typename LLFun, typename LLTupleArgs, typename CovarFun,
1039-
typename CovarArgs, typename ThetaVec,
1041+
typename CovarArgs, bool InitTheta,
10401042
require_t<is_any_var_scalar<LLTupleArgs, CovarArgs>>* = nullptr>
10411043
inline auto laplace_marginal_density(const LLFun& ll_fun, LLTupleArgs&& ll_args,
10421044
CovarFun&& covariance_function,
10431045
CovarArgs&& covar_args,
1044-
const laplace_options<ThetaVec>& options,
1046+
const laplace_options<InitTheta>& options,
10451047
std::ostream* msgs) {
10461048
auto covar_args_refs = to_ref(std::forward<CovarArgs>(covar_args));
10471049
auto ll_args_refs = to_ref(std::forward<LLTupleArgs>(ll_args));

stan/math/mix/prob/laplace_latent_bernoulli_logit_rng.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ inline Eigen::VectorXd laplace_latent_tol_bernoulli_logit_rng(
3636
const double tolerance, const int max_num_steps,
3737
const int hessian_block_size, const int solver,
3838
const int max_steps_line_search, RNG& rng, std::ostream* msgs) {
39-
laplace_options<Eigen::VectorXd> ops{hessian_block_size, solver, max_steps_line_search,
39+
laplace_options_user_supplied ops{hessian_block_size, solver, max_steps_line_search,
4040
tolerance, max_num_steps, value_of(theta_0)};
4141
return laplace_base_rng(bernoulli_logit_likelihood{},
4242
std::forward_as_tuple(to_vector(y), n_samples),

stan/math/mix/prob/laplace_latent_neg_binomial_2_log_rng.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ inline Eigen::VectorXd laplace_latent_tol_neg_binomial_2_log_rng(
4242
const double tolerance, const int max_num_steps,
4343
const int hessian_block_size, const int solver,
4444
const int max_steps_line_search, RNG& rng, std::ostream* msgs) {
45-
laplace_options<Eigen::VectorXd> ops{hessian_block_size, solver, max_steps_line_search,
45+
laplace_options_user_supplied ops{hessian_block_size, solver, max_steps_line_search,
4646
tolerance, max_num_steps, value_of(theta_0)};
4747
return laplace_base_rng(
4848
neg_binomial_2_log_likelihood{},

stan/math/mix/prob/laplace_latent_poisson_log_2_rng.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ inline auto laplace_latent_tol_poisson_2_log_rng(
3838
const double tolerance, const int max_num_steps,
3939
const int hessian_block_size, const int solver,
4040
const int max_steps_line_search, RNG& rng, std::ostream* msgs) {
41-
laplace_options<Eigen::VectorXd> ops{hessian_block_size, solver, max_steps_line_search,
41+
laplace_options_user_supplied ops{hessian_block_size, solver, max_steps_line_search,
4242
tolerance, max_num_steps, value_of(theta_0)};
4343
return laplace_base_rng(poisson_log_2_likelihood{},
4444
std::forward_as_tuple(y, y_index, ye),

stan/math/mix/prob/laplace_latent_poisson_log_rng.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ inline Eigen::VectorXd laplace_latent_tol_poisson_log_rng(
3636
const double tolerance, const int max_num_steps,
3737
const int hessian_block_size, const int solver,
3838
const int max_steps_line_search, RNG& rng, std::ostream* msgs) {
39-
laplace_options<Eigen::VectorXd> ops{hessian_block_size, solver, max_steps_line_search,
39+
laplace_options_user_supplied ops{hessian_block_size, solver, max_steps_line_search,
4040
tolerance, max_num_steps, value_of(theta_0)};
4141
return laplace_base_rng(poisson_log_likelihood{},
4242
std::forward_as_tuple(y, y_index),

stan/math/mix/prob/laplace_latent_rng.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ inline auto laplace_latent_tol_rng(
3636
CovarArgs&& covar_args, ThetaVec&& theta_0, const double tolerance,
3737
const int max_num_steps, const int hessian_block_size, const int solver,
3838
const int max_steps_line_search, RNG& rng, std::ostream* msgs) {
39-
const laplace_options<Eigen::VectorXd> ops{hessian_block_size, solver,
39+
const laplace_options_user_supplied ops{hessian_block_size, solver,
4040
max_steps_line_search, tolerance,
4141
max_num_steps, value_of(theta_0)};
4242
return laplace_base_rng(std::forward<LLFunc>(L_f),

stan/math/mix/prob/laplace_marginal.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ inline auto laplace_marginal_tol(
3333
CovarArgs&& covar_args, const ThetaVec& theta_0, double tolerance,
3434
int max_num_steps, const int hessian_block_size, const int solver,
3535
const int max_steps_line_search, std::ostream* msgs) {
36-
laplace_options<Eigen::VectorXd> ops{hessian_block_size, solver, max_steps_line_search,
36+
laplace_options_user_supplied ops{hessian_block_size, solver, max_steps_line_search,
3737
tolerance, max_num_steps, value_of(theta_0)};
3838
return laplace_marginal_density(
3939
std::forward<LFun>(L_f), std::forward<LArgs>(l_args),

stan/math/mix/prob/laplace_marginal_bernoulli_logit_lpmf.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ inline auto laplace_marginal_tol_bernoulli_logit_lpmf(
5555
const ThetaVec& theta_0, double tolerance, int max_num_steps,
5656
const int hessian_block_size, const int solver,
5757
const int max_steps_line_search, std::ostream* msgs) {
58-
laplace_options<Eigen::VectorXd> ops{hessian_block_size, solver, max_steps_line_search,
58+
laplace_options_user_supplied ops{hessian_block_size, solver, max_steps_line_search,
5959
tolerance, max_num_steps, value_of(theta_0)};
6060
return laplace_marginal_density(
6161
bernoulli_logit_likelihood{},

stan/math/mix/prob/laplace_marginal_neg_binomial_2_log_lpmf.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ inline auto laplace_marginal_tol_neg_binomial_2_log_lpmf(
7373
const ThetaVec& theta_0, double tolerance, int max_num_steps,
7474
const int hessian_block_size, const int solver,
7575
const int max_steps_line_search, std::ostream* msgs) {
76-
laplace_options<Eigen::VectorXd> ops{hessian_block_size, solver, max_steps_line_search,
76+
laplace_options_user_supplied ops{hessian_block_size, solver, max_steps_line_search,
7777
tolerance, max_num_steps, value_of(theta_0)};
7878
return laplace_marginal_density(
7979
neg_binomial_2_log_likelihood{}, std::forward_as_tuple(eta, y, y_index),
@@ -158,7 +158,7 @@ inline auto laplace_marginal_tol_neg_binomial_2_log_summary_lpmf(
158158
const ThetaVec& theta_0, double tolerance, int max_num_steps,
159159
const int hessian_block_size, const int solver,
160160
const int max_steps_line_search, std::ostream* msgs) {
161-
laplace_options<Eigen::VectorXd> ops{hessian_block_size, solver, max_steps_line_search,
161+
laplace_options_user_supplied ops{hessian_block_size, solver, max_steps_line_search,
162162
tolerance, max_num_steps, value_of(theta_0)};
163163
return laplace_marginal_density(
164164
neg_binomial_2_log_likelihood_summary{},

0 commit comments

Comments
 (0)