Skip to content

Commit ee9ff7c

Browse files
authored
Merge pull request #3297 from stan-dev/fix-laplace_marginal
fix laplace_marginal optimization
2 parents 8ee3d96 + 520d494 commit ee9ff7c

File tree

4 files changed

+294
-47
lines changed

4 files changed

+294
-47
lines changed

stan/math/mix/functor/laplace_marginal_density_estimator.hpp

Lines changed: 83 additions & 41 deletions
Original file line number | Diff line number | Diff line change
@@ -342,6 +342,9 @@ struct NewtonState {
342342
/** @brief Status of the most recent Wolfe line search */
343343
WolfeStatus wolfe_status;
344344

345+
/** @brief Cached proposal evaluated before the Wolfe line search. */
346+
WolfeData proposal;
347+
345348
/** @brief Workspace vector: b = W * theta + grad(log_lik) */
346349
Eigen::VectorXd b;
347350

@@ -360,20 +363,31 @@ struct NewtonState {
360363
bool final_loop = false;
361364

362365
/**
363-
* @brief Constructs Newton state with given dimensions and functors.
366+
* @brief Constructs Newton state with a consistent (a_init, theta_init) pair.
364367
*
365-
* @tparam ThetaInitializer Type of the initial theta provider
368+
* When the caller supplies a non-zero theta_init, a_init = Sigma^{-1} *
369+
* theta_init must be provided to maintain the invariant theta = Sigma * a.
370+
* @tparam ObjFun A callable type for the objective function
371+
* @tparam ThetaGradFun A callable type for the theta gradient function
372+
* @tparam CovarianceT A matrix type for the covariance (must support LLT
373+
* solve)
374+
* @tparam ThetaInitializer A type for the initial theta (e.g., Eigen vector)
366375
* @param theta_size Dimension of the latent space
367376
* @param obj_fun Objective function: (a, theta) -> double
368377
* @param theta_grad_f Gradient function: theta -> grad
369-
* @param theta_init Initial theta value or provider
378+
* @param covariance Covariance matrix for the latent variables
379+
* @param a_init Initial a value consistent with theta_init
380+
* @param theta_init Initial theta value
370381
*/
371-
template <typename ObjFun, typename ThetaGradFun, typename ThetaInitializer>
382+
template <typename ObjFun, typename ThetaGradFun, typename CovarianceT,
383+
typename ThetaInitializer>
372384
NewtonState(int theta_size, ObjFun&& obj_fun, ThetaGradFun&& theta_grad_f,
373-
ThetaInitializer&& theta_init)
374-
: wolfe_info(std::forward<ObjFun>(obj_fun), theta_size,
385+
CovarianceT&& covariance, ThetaInitializer&& theta_init)
386+
: wolfe_info(std::forward<ObjFun>(obj_fun),
387+
covariance.llt().solve(theta_init),
375388
std::forward<ThetaInitializer>(theta_init),
376389
std::forward<ThetaGradFun>(theta_grad_f)),
390+
proposal(theta_size),
377391
b(theta_size),
378392
B(theta_size, theta_size),
379393
prev_g(theta_size) {
@@ -404,9 +418,12 @@ struct NewtonState {
404418
*/
405419
const auto& prev() const& { return wolfe_info.prev_; }
406420
auto&& prev() && { return std::move(wolfe_info).prev(); }
421+
auto& proposal_step() & { return proposal; }
422+
const auto& proposal_step() const& { return proposal; }
423+
auto&& proposal_step() && { return std::move(proposal); }
407424
template <typename Options>
408425
inline void update_next_step(const Options& options) {
409-
this->prev().update(this->curr());
426+
this->prev().swap(this->curr());
410427
this->curr().alpha()
411428
= std::clamp(this->curr().alpha(), 0.0, options.line_search.max_alpha);
412429
}
@@ -426,9 +443,13 @@ inline void llt_with_jitter(LLT& llt_B, B_t& B, double min_jitter = 1e-10,
426443
double max_jitter = 1e-5) {
427444
llt_B.compute(B);
428445
if (llt_B.info() != Eigen::Success) {
446+
double prev_jitter = 0.0;
429447
double jitter_try = min_jitter;
430448
for (; jitter_try < max_jitter; jitter_try *= 10) {
431-
B.diagonal().array() += jitter_try;
449+
// Remove previously added jitter before adding the new (larger) amount,
450+
// so that the total diagonal perturbation is exactly jitter_try.
451+
B.diagonal().array() += (jitter_try - prev_jitter);
452+
prev_jitter = jitter_try;
432453
llt_B.compute(B);
433454
if (llt_B.info() == Eigen::Success) {
434455
break;
@@ -478,7 +499,8 @@ struct CholeskyWSolverDiag {
478499
* @tparam LLFun Type of the log-likelihood functor
479500
* @tparam LLTupleArgs Type of the likelihood arguments tuple
480501
* @tparam CovarMat Type of the covariance matrix
481-
* @param[in,out] state Shared Newton state (modified: B, b, curr().a())
502+
* @param[in,out] state Shared Newton state (modified: B, b,
503+
* proposal_step().a())
482504
* @param[in] ll_fun Log-likelihood functor
483505
* @param[in,out] ll_args Additional arguments for the likelihood
484506
* @param[in] covariance Prior covariance matrix Sigma
@@ -514,12 +536,12 @@ struct CholeskyWSolverDiag {
514536

515537
// 3. Factorize B with jittering fallback
516538
llt_with_jitter(llt_B, state.B);
517-
// 4. Solve for curr.a
539+
// 4. Solve for the raw Newton proposal in a-space.
518540
state.b.noalias() = (W_diag.array() * state.prev().theta().array()).matrix()
519541
+ state.prev().theta_grad();
520542
auto L = llt_B.matrixL();
521543
auto LT = llt_B.matrixU();
522-
state.curr().a().noalias()
544+
state.proposal_step().a().noalias()
523545
= state.b
524546
- W_r_diag.asDiagonal()
525547
* LT.solve(
@@ -608,7 +630,8 @@ struct CholeskyWSolverBlock {
608630
* @tparam LLFun Type of the log-likelihood functor
609631
* @tparam LLTupleArgs Type of the likelihood arguments tuple
610632
* @tparam CovarMat Type of the covariance matrix
611-
* @param[in,out] state Shared Newton state (modified: B, b, curr().a())
633+
* @param[in,out] state Shared Newton state (modified: B, b,
634+
* proposal_step().a())
612635
* @param[in] ll_fun Log-likelihood functor
613636
* @param[in,out] ll_args Additional arguments for the likelihood
614637
* @param[in] covariance Prior covariance matrix Sigma
@@ -646,12 +669,12 @@ struct CholeskyWSolverBlock {
646669
// 4. Factorize B with jittering fallback
647670
llt_with_jitter(llt_B, state.B);
648671

649-
// 5. Solve for curr.a
672+
// 5. Solve for the raw Newton proposal in a-space.
650673
state.b.noalias()
651674
= W_block * state.prev().theta() + state.prev().theta_grad();
652675
auto L = llt_B.matrixL();
653676
auto LT = llt_B.matrixU();
654-
state.curr().a().noalias()
677+
state.proposal_step().a().noalias()
655678
= state.b - W_r * LT.solve(L.solve(W_r * (covariance * state.b)));
656679
}
657680

@@ -729,7 +752,7 @@ struct CholeskyKSolver {
729752
* @tparam LLFun Type of the log-likelihood functor
730753
* @tparam LLTupleArgs Type of the likelihood arguments tuple
731754
* @tparam CovarMat Type of the covariance matrix
732-
* @param[in] state Shared Newton state (modified: B, b, curr().a())
755+
* @param[in] state Shared Newton state (modified: B, b, proposal_step().a())
733756
* @param[in] ll_fun Log-likelihood functor
734757
* @param[in] ll_args Additional arguments for the likelihood
735758
* @param[in] covariance Prior covariance matrix Sigma
@@ -756,12 +779,12 @@ struct CholeskyKSolver {
756779
// 3. Factorize B with jittering fallback
757780
llt_with_jitter(llt_B, state.B);
758781

759-
// 4. Solve for curr.a
782+
// 4. Solve for the raw Newton proposal in a-space.
760783
state.b.noalias()
761784
= W_full * state.prev().theta() + state.prev().theta_grad();
762785
auto L = llt_B.matrixL();
763786
auto LT = llt_B.matrixU();
764-
state.curr().a().noalias()
787+
state.proposal_step().a().noalias()
765788
= K_root.transpose().template triangularView<Eigen::Upper>().solve(
766789
LT.solve(L.solve(K_root.transpose() * state.b)));
767790
}
@@ -826,7 +849,7 @@ struct LUSolver {
826849
* @tparam LLFun Type of the log-likelihood functor
827850
* @tparam LLTupleArgs Type of the likelihood arguments tuple
828851
* @tparam CovarMat Type of the covariance matrix
829-
* @param[in,out] state Shared Newton state (modified: b, curr().a())
852+
* @param[in,out] state Shared Newton state (modified: b, proposal_step().a())
830853
* @param[in] ll_fun Log-likelihood functor
831854
* @param[in,out] ll_args Additional arguments for the likelihood
832855
* @param[in] covariance Prior covariance matrix Sigma
@@ -848,10 +871,10 @@ struct LUSolver {
848871
lu.compute(Eigen::MatrixXd::Identity(theta_size, theta_size)
849872
+ covariance * W_full);
850873

851-
// 3. Solve for curr.a
874+
// 3. Solve for the raw Newton proposal in a-space.
852875
state.b.noalias()
853876
= W_full * state.prev().theta() + state.prev().theta_grad();
854-
state.curr().a().noalias()
877+
state.proposal_step().a().noalias()
855878
= state.b - W_full * lu.solve(covariance * state.b);
856879
}
857880

@@ -925,26 +948,32 @@ inline auto run_newton_loop(SolverPolicy& solver, NewtonStateT& state,
925948
solver.solve_step(state, ll_fun, ll_args, covariance,
926949
options.hessian_block_size, msgs);
927950
if (!state.final_loop) {
928-
state.wolfe_info.p_ = state.curr().a() - state.prev().a();
951+
auto&& proposal = state.proposal_step();
952+
state.wolfe_info.p_ = proposal.a() - state.prev().a();
929953
state.prev_g.noalias() = -covariance * state.prev().a()
930954
+ covariance * state.prev().theta_grad();
931955
state.wolfe_info.init_dir_ = state.prev_g.dot(state.wolfe_info.p_);
932956
// Flip direction if not ascending
933957
state.wolfe_info.flip_direction();
934958
auto&& scratch = state.wolfe_info.scratch_;
935-
scratch.alpha() = 1.0;
936-
update_fun(scratch, state.curr(), state.prev(), scratch.eval_,
937-
state.wolfe_info.p_);
938-
if (scratch.alpha() <= options.line_search.min_alpha) {
939-
state.wolfe_status.accept_ = false;
940-
finish_update = true;
959+
proposal.eval_.alpha() = 1.0;
960+
const bool proposal_valid
961+
= update_fun(proposal, state.curr(), state.prev(), proposal.eval_,
962+
state.wolfe_info.p_);
963+
const bool cached_proposal_ok
964+
= proposal_valid && std::isfinite(proposal.obj())
965+
&& std::isfinite(proposal.dir())
966+
&& proposal.alpha() > options.line_search.min_alpha;
967+
if (!cached_proposal_ok) {
968+
state.wolfe_status
969+
= WolfeStatus{WolfeReturn::StepTooSmall, 1, 0, false};
941970
} else if (options.line_search.max_iterations == 0) {
942-
state.curr().update(scratch);
943-
state.wolfe_status.accept_ = true;
971+
state.curr().update(proposal);
972+
state.wolfe_status = WolfeStatus{WolfeReturn::Continue, 1, 0, true};
944973
} else {
945-
Eigen::VectorXd s = scratch.a() - state.prev().a();
974+
Eigen::VectorXd s = proposal.a() - state.prev().a();
946975
auto full_step_grad
947-
= (-covariance * scratch.a() + covariance * scratch.theta_grad())
976+
= (-covariance * proposal.a() + covariance * proposal.theta_grad())
948977
.eval();
949978
state.curr().alpha() = barzilai_borwein_step_size(
950979
s, full_step_grad, state.prev_g, state.prev().alpha(),
@@ -953,22 +982,29 @@ inline auto run_newton_loop(SolverPolicy& solver, NewtonStateT& state,
953982
state.wolfe_status = internal::wolfe_line_search(
954983
state.wolfe_info, update_fun, options.line_search, msgs);
955984
}
956-
/**
957-
* Stop when objective change is small, or when a rejected Wolfe step
958-
* fails to improve; finish_update then exits the Newton loop.
959-
*/
985+
bool search_failed = !state.wolfe_status.accept_;
986+
const bool proposal_armijo_ok
987+
= cached_proposal_ok
988+
&& internal::check_armijo(
989+
proposal.obj(), state.prev().obj(), proposal.alpha(),
990+
state.wolfe_info.init_dir_, options.line_search);
991+
if (search_failed && proposal_armijo_ok) {
992+
state.curr().update(proposal);
993+
state.wolfe_status
994+
= WolfeStatus{WolfeReturn::Armijo, state.wolfe_status.num_evals_,
995+
state.wolfe_status.num_backtracks_, true};
996+
search_failed = false;
997+
}
960998
bool objective_converged
961-
= std::abs(state.curr().obj() - state.prev().obj())
962-
< options.tolerance;
963-
bool search_failed = (!state.wolfe_status.accept_
964-
&& state.curr().obj() <= state.prev().obj());
999+
= state.wolfe_status.accept_
1000+
&& std::abs(state.curr().obj() - state.prev().obj())
1001+
< options.tolerance;
9651002
finish_update = objective_converged || search_failed;
9661003
}
9671004
if (finish_update) {
9681005
if (!state.final_loop && state.wolfe_status.accept_) {
9691006
// Do one final loop with exact wolfe conditions
9701007
state.final_loop = true;
971-
// NOTE: Swapping here so we need to swap prev and curr later
9721008
state.update_next_step(options);
9731009
continue;
9741010
}
@@ -1152,7 +1188,13 @@ inline auto laplace_marginal_density_est(
11521188
return laplace_likelihood::theta_grad(ll_fun, theta_val, ll_args, msgs);
11531189
};
11541190
decltype(auto) theta_init = theta_init_impl<InitTheta>(theta_size, options);
1155-
internal::NewtonState state(theta_size, obj_fun, theta_grad_f, theta_init);
1191+
// When the user supplies a non-zero theta_init, we must initialise a
1192+
// consistently so that the invariant theta = Sigma * a holds. Otherwise
1193+
// the prior term -0.5 * a'*theta vanishes (a=0 while theta!=0), inflating
1194+
// the initial objective and causing the Wolfe line search to reject the
1195+
// first Newton step.
1196+
auto state
1197+
= NewtonState(theta_size, obj_fun, theta_grad_f, covariance, theta_init);
11561198
// Start with safe step size
11571199
auto update_fun = create_update_fun(
11581200
std::move(obj_fun), std::move(theta_grad_f), covariance, options);

stan/math/mix/functor/wolfe_line_search.hpp

Lines changed: 25 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -461,6 +461,12 @@ struct WolfeData {
461461
a_.swap(other.a_);
462462
eval_ = other.eval_;
463463
}
464+
void swap(WolfeData& other) {
465+
theta_.swap(other.theta_);
466+
theta_grad_.swap(other.theta_grad_);
467+
a_.swap(other.a_);
468+
std::swap(eval_, other.eval_);
469+
}
464470
void update(WolfeData& other, const Eval& eval) {
465471
theta_.swap(other.theta_);
466472
a_.swap(other.a_);
@@ -499,13 +505,25 @@ struct WolfeInfo {
499505
Eigen::VectorXd p_;
500506
// Initial directional derivative
501507
double init_dir_;
502-
template <typename ObjFun, typename Theta0, typename ThetaGradF>
503-
WolfeInfo(ObjFun&& obj_fun, Eigen::Index n, Theta0&& theta0,
508+
509+
/**
510+
* Construct WolfeInfo with a consistent (a_init, theta_init) pair.
511+
*
512+
* When the caller supplies a non-zero theta_init, the corresponding
513+
* a_init = Sigma^{-1} * theta_init must be provided so that the
514+
* invariant theta = Sigma * a holds at initialization. This avoids
515+
* an inflated initial objective (the prior term -0.5 * a'*theta would
516+
* otherwise vanish when a is zero but theta is not).
517+
*/
518+
template <typename ObjFun, typename Theta0, typename AInit,
519+
typename ThetaGradF>
520+
WolfeInfo(ObjFun&& obj_fun, AInit&& a_init, Theta0&& theta0,
504521
ThetaGradF&& theta_grad_f)
505-
: curr_(std::forward<ObjFun>(obj_fun), n, std::forward<Theta0>(theta0),
522+
: curr_(std::forward<ObjFun>(obj_fun), std::forward<AInit>(a_init),
523+
std::forward<Theta0>(theta0),
506524
std::forward<ThetaGradF>(theta_grad_f)),
507525
prev_(curr_),
508-
scratch_(n) {
526+
scratch_(a_init.size()) {
509527
if (!std::isfinite(curr_.obj())) {
510528
throw std::domain_error(
511529
"laplace_marginal_density: log likelihood is not finite at initial "
@@ -902,9 +920,10 @@ inline WolfeStatus wolfe_line_search(Info& wolfe_info, UpdateFun&& update_fun,
902920
} else { // [3]
903921
high = mid;
904922
}
923+
} else {
924+
// [4]
925+
high = mid;
905926
}
906-
// [4]
907-
high = mid;
908927
} else {
909928
// [5]
910929
high = mid;

0 commit comments

Comments (0)