Skip to content

Commit eab35ec

Browse files
committed
update with last review
1 parent 6789f32 commit eab35ec

2 files changed

Lines changed: 79 additions & 78 deletions

File tree

stan/math/mix/functor/laplace_marginal_density_estimator.hpp

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include <stan/math/prim/functor/iter_tuple_nested.hpp>
1414
#include <unsupported/Eigen/MatrixFunctions>
1515
#include <cmath>
16+
#include <sstream>
1617

1718
/**
1819
* @file
@@ -443,8 +444,12 @@ inline void llt_with_jitter(LLT& llt_B, B_t& B, double min_jitter = 1e-10,
443444
}
444445
}
445446
if (llt_B.info() != Eigen::Success) {
446-
throw std::domain_error(
447-
"laplace_marginal_density: Cholesky (Diag) failed");
447+
std::stringstream msg;
448+
msg << "laplace_marginal_density: Cholesky decomposition failed on "
449+
<< "Hessian matrix B after attempting jitter values from "
450+
<< min_jitter << " to " << max_jitter
451+
<< ". Matrix may not be positive definite.";
452+
throw std::domain_error(msg.str());
448453
}
449454
}
450455
}
@@ -942,16 +947,13 @@ inline auto run_newton_loop(SolverPolicy& solver, NewtonStateT& state,
942947
scratch.alpha() = 1.0;
943948
update_fun(scratch, state.curr(), state.prev(), scratch.eval_,
944949
state.wolfe_info.p_);
945-
bool run_convergence_check = true;
950+
bool force_finish = false;
946951
if (scratch.alpha() <= options.line_search.min_alpha) {
947952
state.wolfe_status.accept_ = false;
948-
finish_update = true;
949-
run_convergence_check = false;
953+
force_finish = true;
950954
} else if (options.line_search.max_iterations == 0) {
951955
state.curr().update(scratch);
952956
state.wolfe_status.accept_ = true;
953-
finish_update = false;
954-
run_convergence_check = false;
955957
} else {
956958
Eigen::VectorXd s = scratch.a() - state.prev().a();
957959
auto full_step_grad
@@ -964,16 +966,15 @@ inline auto run_newton_loop(SolverPolicy& solver, NewtonStateT& state,
964966
state.wolfe_status = internal::wolfe_line_search(
965967
state.wolfe_info, update_fun, options.line_search, msgs);
966968
}
967-
if (run_convergence_check) {
968-
/**
969-
* Stop when objective change is small, or when a rejected Wolfe step
970-
* fails to improve; finish_update then exits the Newton loop.
971-
*/
972-
finish_update = std::abs(state.curr().obj() - state.prev().obj())
973-
< options.tolerance
974-
|| (!state.wolfe_status.accept_
975-
&& state.curr().obj() <= state.prev().obj());
976-
}
969+
/**
970+
* Stop when objective change is small, or when a rejected Wolfe step
971+
* fails to improve; finish_update then exits the Newton loop.
972+
*/
973+
const bool obj_below_tol = std::abs(state.curr().obj() - state.prev().obj()) <
974+
options.tolerance;
975+
const bool wolfe_failed = !state.wolfe_status.accept_
976+
&& state.curr().obj() <= state.prev().obj();
977+
finish_update = force_finish || obj_below_tol || wolfe_failed;
977978
}
978979
if (finish_update) {
979980
if (!state.final_loop && state.wolfe_status.accept_) {

stan/math/mix/functor/wolfe_line_search.hpp

Lines changed: 61 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ namespace internal {
156156
* (x_left + x_right) / 2 is returned instead.
157157
*/
158158
template <typename Scalar>
159-
[[nodiscard]] inline Scalar cubic_or_bisect_max(Scalar x_left, Scalar f_left,
159+
[[nodiscard]] inline Scalar cubic_interpolation(Scalar x_left, Scalar f_left,
160160
Scalar df_left, Scalar x_right,
161161
Scalar f_right,
162162
Scalar df_right) noexcept {
@@ -283,8 +283,8 @@ template <typename Scalar>
283283
}
284284

285285
template <typename Eval, typename Options>
286-
inline auto cubic_or_bisect_max(Eval&& low, Eval&& high, Options&& opt) {
287-
auto alpha = cubic_or_bisect_max(low.alpha(), low.obj(), low.dir(),
286+
inline auto cubic_interpolation(Eval&& low, Eval&& high, Options&& opt) {
287+
auto alpha = cubic_interpolation(low.alpha(), low.obj(), low.dir(),
288288
high.alpha(), high.obj(), high.dir());
289289
const double width = high.alpha() - low.alpha();
290290
const double guard = 1e-3 * width; // or make this an option
@@ -714,7 +714,7 @@ inline auto retry_evaluate(Update&& update, Proposal&& proposal, Curr&& curr,
714714
*
715715
* - If `low.dir()` and `high.dir()` have opposite signs and the right
716716
* endpoint `high` satisfies Armijo, a cubic interpolation of the endpoints
717-
* is used (`cubic_or_bisect_max(low, high, opt)`).
717+
* is used (`cubic_interpolation(low, high, opt)`).
718718
* - Otherwise the trial is the simple bisection midpoint
719719
* \f$\tfrac{1}{2}(\alpha_\text{low} + \alpha_\text{high})\f$.
720720
*
@@ -864,10 +864,6 @@ inline WolfeStatus wolfe_line_search(Info& wolfe_info, UpdateFun&& update_fun,
864864
Eval low{0.0, prev.obj(), dir_deriv_init};
865865
prev.dir() = dir_deriv_init;
866866
int total_updates = 0;
867-
auto eval_finite = [](const Eval& e, const WolfeData& state) {
868-
return std::isfinite(e.obj()) && std::isfinite(e.dir())
869-
&& state.theta().allFinite() && state.theta_grad().allFinite();
870-
};
871867
Eval best = low; // keep the best Armijo-OK in case strong-Wolfe fails
872868
auto update_with_tick = [&total_updates, &opt, &best, &update_fun](
873869
auto&& proposal, auto&& curr, auto&& prev,
@@ -895,7 +891,6 @@ inline WolfeStatus wolfe_line_search(Info& wolfe_info, UpdateFun&& update_fun,
895891
= std::clamp(curr.alpha() * opt.scale_up, opt.min_alpha, opt.max_alpha);
896892
Eval high{alpha_start, curr.obj(), dir_deriv_init};
897893
WolfeStatus wolfe_check{WolfeReturn::Continue, 0, 0, false};
898-
bool high_has_eval = true;
899894
// Initial check for numerical trouble
900895
{
901896
wolfe_check = update_with_tick(scratch, curr, prev, high, p);
@@ -920,7 +915,6 @@ inline WolfeStatus wolfe_line_search(Info& wolfe_info, UpdateFun&& update_fun,
920915
if (wolfe_check.stop_ != WolfeReturn::Continue) {
921916
return wolfe_check;
922917
}
923-
high_has_eval = true;
924918
}
925919
wolfe_check = update_with_tick(scratch, curr, prev, best, p);
926920
if (wolfe_check.stop_ != WolfeReturn::Continue) {
@@ -935,55 +929,50 @@ inline WolfeStatus wolfe_line_search(Info& wolfe_info, UpdateFun&& update_fun,
935929
}
936930
}
937931
}
938-
bool found_right = false;
939932
int num_backtracks = 0;
940933
/**
941-
* For each case
934+
* From Nocedal–Wright (2006), Algorithm 3.5:
935+
* https://www.math.uci.edu/~qnie/Publications/NumericalOptimization.pdf
942936
* | armijo | wolfe | sign(g) | Action
943937
* -------+-------+---------+--------------------------------
944938
* | [1] T | T | | Accept alpha
945939
* | [2] T | F | > 0 | set low=high, expand high
946-
* | [3] T | F | < 0 | Set alpha_high <- alpha, stop
947-
* | [4] F | T | | Set alpha_high <- alpha, stop
948-
* | [5] F | F | | Set alpha_high <- alpha, stop
940+
* | [3] T | F | < 0 | Bracket found: stop
941+
* | [4] F | T | | Bracket found: stop
942+
* | [5] F | F | | Bracket found: stop
943+
* NOTE: In an ideal case we would end up with a positive low directional gradient and
944+
* negative high directional gradient. Cubic interpolation requires a bracket with directional
945+
* shape like /\. This scheme does not guarantee a bracket with that shape will be found.
946+
* So in the zoom we will have to check if we can do cubic or have to fall back to bisection.
949947
**/
950-
while (!found_right && high.alpha() < opt.max_alpha) {
948+
while (high.alpha() < opt.max_alpha) {
951949
num_backtracks++;
952-
// 1. Evaluate f(alpha) and g(alpha)
953950
wolfe_check = update_with_tick(scratch, curr, prev, high, p);
954951
if (wolfe_check.stop_ != WolfeReturn::Continue) {
955952
return wolfe_check;
956953
}
957-
high_has_eval = true;
958-
const bool finite_ok = eval_finite(high, scratch);
959-
// 2. Handle numerical trouble first
960-
if (!finite_ok) { // f or g is NaN/Inf → shrink
961-
high.alpha() *= 0.5;
962-
high_has_eval = false;
963-
if (high.alpha() < opt.min_alpha) {
964-
break;
965-
}
966-
continue;
967-
}
968954
const bool armijo = check_armijo(high, prev, opt);
969955
const bool wolfe = check_wolfe(high, prev, opt);
970-
if (armijo && wolfe) { // [1]
956+
// [1]
957+
if (armijo && wolfe) {
971958
curr.update(scratch, high);
972959
return WolfeStatus{WolfeReturn::Wolfe, total_updates, num_backtracks,
973960
true};
961+
} else if (armijo) {
962+
if (best.obj() < high.obj()) {
963+
best = high;
964+
}
965+
// [2]
966+
if (high.dir() > 0) {
967+
low = high;
968+
high.alpha() *= opt.scale_up;
969+
continue;
970+
}
971+
// [3]
972+
break;
974973
}
975-
if (armijo && best.obj() < high.obj()) {
976-
best = high;
977-
}
978-
const bool dir_pos = high.dir() > 0;
979-
if (armijo && !wolfe && dir_pos) { // [2]
980-
low = high;
981-
high.alpha() *= opt.scale_up;
982-
high_has_eval = false;
983-
continue;
984-
}
985-
// [3,4,5]
986-
found_right = true;
974+
// [3, 4, 5]
975+
break;
987976
}
988977
const double grad_tol
989978
= std::max(opt.abs_grad_threshold,
@@ -1018,13 +1007,6 @@ inline WolfeStatus wolfe_line_search(Info& wolfe_info, UpdateFun&& update_fun,
10181007
return WolfeStatus{WolfeReturn::Continue, total_updates, num_backtracks,
10191008
false};
10201009
};
1021-
if (!high_has_eval) {
1022-
wolfe_check = update_with_tick(scratch, curr, prev, high, p);
1023-
if (wolfe_check.stop_ != WolfeReturn::Continue) {
1024-
return wolfe_check;
1025-
}
1026-
high_has_eval = true;
1027-
}
10281010
auto check_b = check_bounds(high);
10291011
if (check_b.stop_ != WolfeReturn::Continue) {
10301012
if (check_b.accept_) {
@@ -1036,7 +1018,19 @@ inline WolfeStatus wolfe_line_search(Info& wolfe_info, UpdateFun&& update_fun,
10361018
if (wolfe_check.stop_ != WolfeReturn::Continue) {
10371019
return wolfe_check;
10381020
}
1039-
// Zoom phase
1021+
/**
1022+
* Zoom Step: (Alg 3.6, adapted to maximization via phi=-obj)
1023+
*
1024+
* Exit/update table (evaluated at `mid`, with `low` = best Armijo endpoint):
1025+
* | Armijo? | obj(mid) >= obj(low)? | Wolfe? | dir(mid) >= 0? | Action
1026+
* |---------|-----------------------|--------|----------------|--------------------------|
1027+
* | T | F | T | * | accept mid [1] |
1028+
* | T | T | * | * | high = mid [2] |
1029+
* | T | F | F | T | high = low; low = mid [3]|
1030+
* | T | F | F | F | low = mid [4] |
1031+
* | F | * | * | * | high = mid [5] |
1032+
* ----------------------------------------------------------------------------------------
1033+
**/
10401034
while ((high.alpha() - low.alpha() > opt.min_alpha)
10411035
&& high.alpha() > opt.min_alpha) {
10421036
num_backtracks++;
@@ -1046,9 +1040,12 @@ inline WolfeStatus wolfe_line_search(Info& wolfe_info, UpdateFun&& update_fun,
10461040
const bool use_cubic = have_sign_change && high_armijo_ok;
10471041

10481042
// Choose trial alpha: cubic when bracket is good, else bisection.
1049-
double alpha_mid = use_cubic ? cubic_or_bisect_max(low, high, opt)
1050-
: 0.5 * (low.alpha() + high.alpha());
1051-
1043+
double alpha_mid{0};
1044+
if (use_cubic) {
1045+
alpha_mid = cubic_interpolation(low, high, opt);
1046+
} else {
1047+
alpha_mid = 0.5 * (low.alpha() + high.alpha());
1048+
}
10521049
if (alpha_mid <= opt.min_alpha) {
10531050
break;
10541051
}
@@ -1063,6 +1060,7 @@ inline WolfeStatus wolfe_line_search(Info& wolfe_info, UpdateFun&& update_fun,
10631060
}
10641061
if (check_armijo(mid, prev, opt)) {
10651062
if (check_wolfe(mid, prev, opt)) {
1063+
// [1]
10661064
curr.update(scratch, mid);
10671065
return WolfeStatus{WolfeReturn::Wolfe, total_updates, num_backtracks,
10681066
true};
@@ -1071,17 +1069,17 @@ inline WolfeStatus wolfe_line_search(Info& wolfe_info, UpdateFun&& update_fun,
10711069
if (mid.obj() > best.obj()) {
10721070
best = mid;
10731071
}
1074-
}
1075-
1076-
// Update bracket based on derivative sign
1077-
if (mid.dir() * low.dir() < 0) {
1078-
// sign change between low and mid -> [low, mid]
1079-
high = mid;
1080-
} else {
1081-
// otherwise shift left endpoint -> [mid, high]
1072+
if (mid.obj() >= low.obj()) {
1073+
// [2]
1074+
high = mid;
1075+
} else if (mid.dir() >= 0) {
1076+
// [3]
1077+
high = low;
1078+
low = mid;
1079+
}
1080+
// [4]
10821081
low = mid;
10831082
}
1084-
10851083
// Convergence/guard-rail checks (uses prev/grad_tol/obj_tol etc.)
10861084
auto bounds_check = check_bounds(mid);
10871085
if (bounds_check.stop_ != WolfeReturn::Continue) {
@@ -1090,6 +1088,8 @@ inline WolfeStatus wolfe_line_search(Info& wolfe_info, UpdateFun&& update_fun,
10901088
}
10911089
return bounds_check;
10921090
}
1091+
// [5]
1092+
high = mid;
10931093
}
10941094
// On failure, use the best point we have found so far that at least satisfies
10951095
// armijo

0 commit comments

Comments
 (0)