Skip to content

Commit 0aaed66

Browse files
committed
cleanup and testing different line searches
1 parent 5159c28 commit 0aaed66

54 files changed

Lines changed: 100 additions & 455 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

stan/math/mix/functor/hessian_block_diag.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
#ifndef STAN_MATH_MIX_FUNCTOR_HESSIAN_BLOCK_DIAG_HPP
22
#define STAN_MATH_MIX_FUNCTOR_HESSIAN_BLOCK_DIAG_HPP
33

4+
#include <stan/math/prim/fun/Eigen.hpp>
45
#include <stan/math/mix/functor/hessian_times_vector.hpp>
5-
#include <Eigen/Sparse>
66

77
namespace stan {
88
namespace math {

stan/math/mix/functor/laplace_base_rng.hpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
#ifndef STAN_MATH_MIX_FUNCTOR_LAPLACE_BASE_RNG_HPP
22
#define STAN_MATH_MIX_FUNCTOR_LAPLACE_BASE_RNG_HPP
33

4+
#include <stan/math/prim/fun/Eigen.hpp>
45
#include <stan/math/mix/functor/laplace_marginal_density.hpp>
56
#include <stan/math/prim/prob/multi_normal_cholesky_rng.hpp>
67

7-
#include <Eigen/Sparse>
8+
89

910
namespace stan {
1011
namespace math {
@@ -73,7 +74,7 @@ inline Eigen::VectorXd laplace_base_rng(LLFunc&& ll_fun, LLArgs&& ll_args,
7374
- covariance_train
7475
* (md_est.W_r
7576
- md_est.W_r
76-
* md_est.LU.solve(md_est.covariance * md_est.W_r))
77+
* md_est.LU.solve(covariance_train * md_est.W_r))
7778
* covariance_train;
7879
return multi_normal_rng(std::move(mean_train), std::move(Sigma), rng);
7980
}

stan/math/mix/functor/laplace_likelihood.hpp

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,6 @@
88

99
namespace stan {
1010
namespace math {
11-
inline std::basic_ostream<char>* value_of(std::basic_ostream<char>*& pstream) {
12-
return pstream;
13-
}
1411

1512
/**
1613
* functions to compute the log density, first, second,
@@ -55,22 +52,26 @@ template <template <typename...> class Filter,
5552
inline auto conditional_copy_and_promote(Args&&... args) {
5653
return map_if<Filter>(
5754
[](auto&& arg) {
58-
if constexpr (is_tuple<std::decay_t<decltype(arg)>>::value) {
55+
if constexpr (is_tuple_v<decltype(arg)>) {
5956
return stan::math::apply(
60-
[](auto&&... args) {
57+
[](auto&&... inner_args) {
6158
return partially_forward_as_tuple(
6259
conditional_copy_and_promote<Filter, PromotedType,
6360
CopyType>(
64-
std::forward<decltype(args)>(args))...);
61+
std::forward<decltype(inner_args)>(inner_args))...);
6562
},
6663
std::forward<decltype(arg)>(arg));
6764
} else {
6865
if constexpr (CopyType == COPY_TYPE::DEEP) {
6966
return stan::math::eval(promote_scalar<PromotedType>(
7067
value_of_rec(std::forward<decltype(arg)>(arg))));
7168
} else if (CopyType == COPY_TYPE::SHALLOW) {
72-
return stan::math::eval(
69+
if constexpr (std::is_same_v<PromotedType, scalar_type_t<decltype(arg)>>) {
70+
return std::forward<decltype(arg)>(arg);
71+
} else {
72+
return stan::math::eval(
7373
promote_scalar<PromotedType>(std::forward<decltype(arg)>(arg)));
74+
}
7475
}
7576
}
7677
},
@@ -92,6 +93,7 @@ inline auto shallow_copy_vargs(Args&&... args) {
9293
}
9394

9495
/**
96+
* @note If `Args` contains \ref var types then their adjoints will be calculated as a side effect.
9597
* @tparam F A functor with `operator()(Args&&...)` returning a scalar
9698
* @tparam Theta A class assignable to an Eigen vector type
9799
* @tparam Stream Type of stream for messages.
@@ -139,6 +141,7 @@ inline auto diff(F&& f, Theta&& theta, const Eigen::Index hessian_block_size,
139141
}
140142

141143
/**
144+
* @note If `Args` contains \ref var types then their adjoints will be calculated as a side effect.
142145
* @tparam F A functor with `operator()(Args&&...)` returning a scalar
143146
* @tparam Theta A class assignable to an Eigen vector type
144147
* @tparam Stream Type of stream for messages.
@@ -165,6 +168,7 @@ inline Eigen::VectorXd third_diff(F&& f, Theta&& theta, Stream&& msgs,
165168
}
166169

167170
/**
171+
* @note If `Args` contains \ref var types then their adjoints will be calculated as a side effect.
168172
* @tparam F A functor with `operator()(Args&&...)` returning a scalar
169173
* @tparam Theta An Eigen Matrix
170174
* @tparam AMat An Eigen Matrix
@@ -225,6 +229,7 @@ inline auto compute_s2(F&& f, Theta&& theta, AMat&& A,
225229
}
226230

227231
/**
232+
* @note If `Args` contains \ref var types then their adjoints will be calculated as a side effect.
228233
* @tparam F A functor with `operator()(Args&&...)` returning a scalar
229234
* @tparam V_t A type assignable to an Eigen vector type
230235
* @tparam Theta A type assignable to an Eigen vector type
@@ -270,6 +275,7 @@ inline auto diff_eta_implicit(F&& f, V_t&& v, Theta&& theta, Stream* msgs,
270275
} // namespace internal
271276

272277
/**
278+
* A wrapper that accepts a tuple as arguments.
273279
* @tparam F A functor with `operator()(Args&&...)` returning a scalar
274280
* @tparam Theta A class assignable to an Eigen vector type
275281
* @tparam TupleArgs Type of arguments for covariance function.
@@ -295,6 +301,7 @@ inline auto log_likelihood(F&& f, Theta&& theta, TupleArgs&& ll_tup,
295301
}
296302

297303
/**
304+
* A wrapper that accepts a tuple as arguments.
298305
* @tparam F A functor with `operator()(Args&&...)` returning a scalar
299306
* @tparam Theta A class assignable to an Eigen vector type
300307
* @tparam TupleArgs Type of arguments for covariance function.
@@ -323,6 +330,7 @@ inline auto diff(F&& f, Theta&& theta, const Eigen::Index hessian_block_size,
323330
}
324331

325332
/**
333+
* A wrapper that accepts a tuple as arguments.
326334
* @tparam F Type of log likelihood function.
327335
* @tparam Theta A class assignable to an Eigen vector type
328336
* @tparam TupleArgs Type of arguments for covariance function.
@@ -348,6 +356,7 @@ inline Eigen::VectorXd third_diff(F&& f, Theta&& theta, TupleArgs&& ll_args,
348356
}
349357

350358
/**
359+
* A wrapper that accepts a tuple as arguments.
351360
* @tparam F Type of log likelihood function.
352361
* @tparam Theta Type of latent Gaussian ba
353362
* @tparam TupleArgs Type of arguments for covariance function.
@@ -380,6 +389,7 @@ inline auto compute_s2(F&& f, Theta&& theta, AMat&& A, int hessian_block_size,
380389
}
381390

382391
/**
392+
* A wrapper that accepts a tuple as arguments.
383393
* @tparam F A functor with `operator()(Args&&...)` returning a scalar
384394
* @tparam V_t Type of initial tangent.
385395
* @tparam Theta A class assignable to an Eigen vector type

stan/math/mix/functor/laplace_marginal_density.hpp

Lines changed: 70 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#ifndef STAN_MATH_MIX_FUNCTOR_LAPLACE_MARGINAL_DENSITY_HPP
22
#define STAN_MATH_MIX_FUNCTOR_LAPLACE_MARGINAL_DENSITY_HPP
3+
#include <stan/math/prim/fun/Eigen.hpp>
34
#include <stan/math/mix/functor/laplace_likelihood.hpp>
45
#include <stan/math/rev/meta.hpp>
56
#include <stan/math/rev/core.hpp>
@@ -8,7 +9,6 @@
89
#include <stan/math/rev/functor.hpp>
910
#include <stan/math/prim/fun/to_ref.hpp>
1011
#include <stan/math/prim/fun/quad_form_diag.hpp>
11-
#include <Eigen/Sparse>
1212
#include <Eigen/LU>
1313
#include <unsupported/Eigen/MatrixFunctions>
1414

@@ -83,66 +83,6 @@ struct laplace_density_estimates {
8383
K_root(std::move(K_root_)) {}
8484
};
8585

86-
/**
87-
* Function to compute the pseudo target, $\tilde Z$,
88-
* with a custom derivative method
89-
* NOTE: we actually don't need to compute the pseudo-target, only its
90-
* derivative
91-
* @tparam Kmat Type inheriting from `Eigen::EigenBase` with dynamic rows and
92-
* columns
93-
* @tparam AVec Type of matrix of initial tangents
94-
* @tparam RMat Type of the stable R matrix
95-
* @tparam LGradVec Type of the gradient of the log likelihood
96-
* @tparam S2Vec Type of the s2 vector
97-
*/
98-
template <
99-
typename KMat, typename AVec, typename RMat, typename LGradVec,
100-
typename S2Vec,
101-
require_eigen_matrix_dynamic_vt<std::is_floating_point, KMat>* = nullptr>
102-
inline constexpr double laplace_pseudo_target(KMat&& /* K */, AVec&& /* a */,
103-
RMat&& /* R */,
104-
LGradVec&& /* l_grad */,
105-
S2Vec&& /* s2 */) {
106-
return static_cast<double>(0.0);
107-
}
108-
109-
/**
110-
* Overload function for case where K is passed as a matrix of var
111-
* @tparam Kmat Type inheriting from `Eigen::EigenBase` with dynamic rows and
112-
* columns
113-
* @tparam AVec Type inheriting from `Eigen::EigenBase` with dynamic columns and
114-
* a single row
115-
* @tparam RMat Type inheriting from `Eigen::EigenBase` with dynamic rows and
116-
* columns
117-
* @tparam LGradVec Type inheriting from `Eigen::EigenBase` with dynamic rows
118-
* and a single column
119-
* @tparam S2Vec Type of s2 vector
120-
* @param K Covariance matrix
121-
* @param a Saved a vector from Newton solver
122-
* @param R Stable R matrix
123-
* @param l_grad Saved gradient of log likelihood
124-
* @param s2 Gradient of log determinant w.r.t latent Gaussian variable
125-
*/
126-
template <typename KMat, typename AVec, typename RMat, typename LGradVec,
127-
typename S2Vec,
128-
require_eigen_matrix_dynamic_vt<is_var, KMat>* = nullptr>
129-
inline auto laplace_pseudo_target(KMat&& K, AVec&& a, RMat&& R,
130-
LGradVec&& l_grad, S2Vec&& s2) {
131-
const Eigen::Index dim_theta = K.rows();
132-
auto K_arena = to_arena(std::forward<KMat>(K));
133-
auto&& a_ref = to_ref(std::forward<AVec>(a));
134-
auto&& R_ref = to_ref(std::forward<RMat>(R));
135-
auto&& s2_ref = to_ref(std::forward<S2Vec>(s2));
136-
auto&& l_grad_ref = to_ref(std::forward<LGradVec>(l_grad));
137-
arena_matrix<Eigen::MatrixXd> K_adj_arena
138-
= 0.5 * a_ref * a_ref.transpose() - 0.5 * R_ref
139-
+ s2_ref * l_grad_ref.transpose()
140-
- (R_ref * (value_of(K_arena) * s2_ref)) * l_grad_ref.transpose();
141-
return make_callback_var(0.0, [K_arena, K_adj_arena](auto&& vi) mutable {
142-
K_arena.adj().array() += vi.adj() * K_adj_arena.array();
143-
});
144-
}
145-
14686
template <typename WRootMat>
14787
inline void block_matrix_sqrt(WRootMat& W_root,
14888
const Eigen::SparseMatrix<double>& W,
@@ -194,39 +134,67 @@ inline void block_matrix_sqrt(WRootMat& W_root,
194134
}
195135
}
196136
}
197-
template <typename AVec, typename APrev, typename ThetaVec, typename LLFun,
198-
typename LLArgs, typename Covar, typename Msgs>
199-
inline auto line_search(double& objective_new, AVec&& a, APrev& a_prev,
200-
ThetaVec&& theta, LLFun&& ll_fun, LLArgs&& ll_args,
201-
Covar&& covariance, const int max_steps_line_search,
202-
const double objective_old, Msgs* msgs) {
203-
Eigen::VectorXd a_tmp(a.size());
204-
double objective_new_tmp = 0.0;
205-
double objective_old_tmp = objective_old;
206-
Eigen::VectorXd theta_tmp(covariance.rows());
207-
int j = 0;
208-
for (; j < max_steps_line_search && (objective_new < objective_old_tmp);
209-
++j) {
210-
a_tmp.noalias() = a_prev + 0.5 * (a - a_prev);
211-
theta_tmp.noalias() = covariance * a_tmp;
212-
if (!theta_tmp.allFinite()) {
213-
break;
214-
} else {
215-
objective_new_tmp = -0.5 * a_tmp.dot(theta_tmp)
216-
+ laplace_likelihood::log_likelihood(
217-
ll_fun, theta_tmp, ll_args, msgs);
218-
if (objective_new_tmp < objective_new) {
219-
a_prev.swap(a);
220-
a.swap(a_tmp);
221-
theta.swap(theta_tmp);
222-
objective_old_tmp = objective_new;
223-
objective_new = objective_new_tmp;
224-
} else {
225-
break;
226-
}
227-
}
228-
}
229-
return std::make_tuple(objective_new, std::move(a), std::move(theta));
137+
138+
/**
139+
* @brief Performs a simple line search
140+
*
141+
* @tparam AVec Type of the parameter update vector (`a`), e.g. Eigen::VectorXd.
142+
* @tparam APrev Type of the previous parameter vector (`a_prev`), same shape as AVec.
143+
* @tparam ThetaVec Type of the transformed vector (`theta`), e.g. Σ·a.
144+
* @tparam LLFun Functor type for computing the log‐likelihood.
145+
* @tparam LLArgs Tuple or pack type forwarded to `ll_fun`.
146+
* @tparam Covar Matrix type for the covariance Σ, e.g. Eigen::MatrixXd.
147+
* @tparam Msgs Diagnostics container type for capturing warnings/errors.
148+
*
149+
* @param[in,out] objective_new On entry: objective at the full‐step `a` (must satisfy objective_new < objective_old). On exit: best objective found.
150+
* @param[in,out] a On entry: candidate parameter vector. On exit: updated to the step achieving the lowest objective.
151+
* @param[in,out] theta On entry: Σ·a for the initial candidate. On exit: Σ·a for the accepted best step.
152+
* @param[in,out] a_prev On entry: previous parameter vector, with objective `objective_old`. On exit: rolled forward to each newly accepted step.
153+
* @param[in] ll_fun Callable that computes the log‐likelihood given `(theta, ll_args, msgs)`.
154+
* @param[in] ll_args Arguments forwarded to `ll_fun` at each evaluation.
155+
* @param[in] covariance Covariance matrix Σ used to compute `theta = Σ·a`.
156+
* @param[in] max_steps_line_search Maximum number of iterations.
157+
* @param[in] objective_old Objective value at the initial `a_prev` (used as f₀ for the first pass).
158+
* @param[in,out] msgs Pointer to a diagnostics container; may be used by `ll_fun` to record warnings.
159+
*/
160+
template <typename AVec, typename APrev, typename ThetaVec,
161+
typename LLFun, typename LLArgs, typename Covar, typename Msgs>
162+
inline void line_search(double& objective_new,
163+
AVec& a,
164+
ThetaVec& theta,
165+
APrev& a_prev,
166+
LLFun&& ll_fun,
167+
LLArgs&& ll_args,
168+
Covar&& covariance,
169+
const int max_steps_line_search,
170+
const double objective_old,
171+
double tolerance,
172+
Msgs* msgs) {
173+
Eigen::VectorXd a_tmp(a.size());
174+
double objective_new_tmp = 0.0;
175+
double objective_old_tmp = objective_old;
176+
Eigen::VectorXd theta_tmp(covariance.rows());
177+
for (int j = 0; j < max_steps_line_search && (objective_new < objective_old_tmp);
178+
++j) {
179+
a_tmp.noalias() = a_prev + 0.5 * (a - a_prev);
180+
theta_tmp.noalias() = covariance * a_tmp;
181+
if (!theta_tmp.allFinite()) {
182+
break;
183+
} else {
184+
objective_new_tmp = -0.5 * a_tmp.dot(theta_tmp)
185+
+ laplace_likelihood::log_likelihood(
186+
ll_fun, theta_tmp, ll_args, msgs);
187+
if (objective_new_tmp < objective_new) {
188+
a_prev.swap(a);
189+
a.swap(a_tmp);
190+
theta.swap(theta_tmp);
191+
objective_old_tmp = objective_new;
192+
objective_new = objective_new_tmp;
193+
} else {
194+
break;
195+
}
196+
}
197+
}
230198
}
231199

232200
// iter_tuple_n
@@ -479,10 +447,9 @@ inline auto laplace_marginal_density_est(LLFun&& ll_fun, LLTupleArgs&& ll_args,
479447
+ laplace_likelihood::log_likelihood(ll_fun, theta,
480448
ll_args_vals, msgs);
481449
if (options.max_steps_line_search) {
482-
std::tie(objective_new, a, theta)
483-
= line_search(objective_new, std::move(a), a_prev, std::move(theta),
450+
line_search(objective_new, a, theta, a_prev,
484451
ll_fun, ll_args_vals, covariance,
485-
options.max_steps_line_search, objective_old, msgs);
452+
options.max_steps_line_search, objective_old, options.tolerance, msgs);
486453
}
487454
// Check for convergence
488455
if (abs(objective_new - objective_old) < options.tolerance) {
@@ -547,10 +514,9 @@ inline auto laplace_marginal_density_est(LLFun&& ll_fun, LLTupleArgs&& ll_args,
547514
+ laplace_likelihood::log_likelihood(
548515
ll_fun, value_of(theta), ll_args_vals, msgs);
549516
if (options.max_steps_line_search > 0) {
550-
std::tie(objective_new, a, theta)
551-
= line_search(objective_new, std::move(a), a_prev, std::move(theta),
517+
line_search(objective_new, a, theta, a_prev,
552518
ll_fun, ll_args_vals, covariance,
553-
options.max_steps_line_search, objective_old, msgs);
519+
options.max_steps_line_search, objective_old, options.tolerance, msgs);
554520
}
555521
// Check for convergence
556522
if (abs(objective_new - objective_old) < options.tolerance) {
@@ -600,10 +566,9 @@ inline auto laplace_marginal_density_est(LLFun&& ll_fun, LLTupleArgs&& ll_args,
600566
ll_args_vals, msgs);
601567
// linesearch
602568
if (options.max_steps_line_search > 0) {
603-
std::tie(objective_new, a, theta)
604-
= line_search(objective_new, std::move(a), a_prev, std::move(theta),
569+
line_search(objective_new, a, theta, a_prev,
605570
ll_fun, ll_args_vals, covariance,
606-
options.max_steps_line_search, objective_old, msgs);
571+
options.max_steps_line_search, objective_old, options.tolerance, msgs);
607572
}
608573
// Check for convergence
609574
if (abs(objective_new - objective_old) < options.tolerance) {
@@ -633,7 +598,6 @@ inline auto laplace_marginal_density_est(LLFun&& ll_fun, LLTupleArgs&& ll_args,
633598
MatrixXd::Identity(theta_size, theta_size) + covariance * W);
634599
// L on upper and U on lower triangular
635600
b.noalias() = W * theta + theta_grad;
636-
637601
a.noalias() = b - W * LU.solve(covariance * b);
638602
// Simple Newton step
639603
theta.noalias() = covariance * a;
@@ -647,13 +611,10 @@ inline auto laplace_marginal_density_est(LLFun&& ll_fun, LLTupleArgs&& ll_args,
647611
ll_fun, value_of(theta), ll_args_vals, msgs);
648612

649613
// TODO(Charles): How do we handle NA values in theta?
650-
// linesearch
651-
// CHECK -- does linesearch work for options.solver 2?
652614
if (options.max_steps_line_search > 0) {
653-
std::tie(objective_new, a, theta)
654-
= line_search(objective_new, std::move(a), a_prev, std::move(theta),
615+
line_search(objective_new, a, theta, a_prev,
655616
ll_fun, ll_args_vals, covariance,
656-
options.max_steps_line_search, objective_old, msgs);
617+
options.max_steps_line_search, objective_old, options.tolerance, msgs);
657618
}
658619
if (abs(objective_new - objective_old) < options.tolerance) {
659620
// TODO(Charles): There has to be a simple trick for this
@@ -1046,8 +1007,6 @@ inline auto laplace_marginal_density(const LLFun& ll_fun, LLTupleArgs&& ll_args,
10461007
return covariance_function(args..., msgs);
10471008
},
10481009
covar_args_copy));
1049-
// var Z = laplace_pseudo_target(K_var, md_est.a, R,
1050-
// md_est.theta_grad, s2);
10511010
arena_t<Eigen::MatrixXd> K_adj_arena
10521011
= 0.5 * md_est.a * md_est.a.transpose() - 0.5 * R
10531012
+ s2 * md_est.theta_grad.transpose()

0 commit comments

Comments
 (0)