@@ -156,7 +156,7 @@ namespace internal {
156156 * (x_left + x_right) / 2 is returned instead.
157157 */
158158template <typename Scalar>
159- [[nodiscard]] inline Scalar cubic_or_bisect_max (Scalar x_left, Scalar f_left,
159+ [[nodiscard]] inline Scalar cubic_interpolation (Scalar x_left, Scalar f_left,
160160 Scalar df_left, Scalar x_right,
161161 Scalar f_right,
162162 Scalar df_right) noexcept {
@@ -283,8 +283,8 @@ template <typename Scalar>
283283}
284284
285285template <typename Eval, typename Options>
286- inline auto cubic_or_bisect_max (Eval&& low, Eval&& high, Options&& opt) {
287- auto alpha = cubic_or_bisect_max (low.alpha (), low.obj (), low.dir (),
286+ inline auto cubic_interpolation (Eval&& low, Eval&& high, Options&& opt) {
287+ auto alpha = cubic_interpolation (low.alpha (), low.obj (), low.dir (),
288288 high.alpha (), high.obj (), high.dir ());
289289 const double width = high.alpha () - low.alpha ();
290290 const double guard = 1e-3 * width; // or make this an option
@@ -714,7 +714,7 @@ inline auto retry_evaluate(Update&& update, Proposal&& proposal, Curr&& curr,
714714 *
715715 * - If `low.dir()` and `high.dir()` have opposite signs and the right
716716 * endpoint `high` satisfies Armijo, a cubic interpolation of the endpoints
717- * is used (`cubic_or_bisect_max (low, high, opt)`).
717+ * is used (`cubic_interpolation (low, high, opt)`).
718718 * - Otherwise the trial is the simple bisection midpoint
719719 * \f$\tfrac{1}{2}(\alpha_\text{low} + \alpha_\text{high})\f$.
720720 *
@@ -864,10 +864,6 @@ inline WolfeStatus wolfe_line_search(Info& wolfe_info, UpdateFun&& update_fun,
864864 Eval low{0.0 , prev.obj (), dir_deriv_init};
865865 prev.dir () = dir_deriv_init;
866866 int total_updates = 0 ;
867- auto eval_finite = [](const Eval& e, const WolfeData& state) {
868- return std::isfinite (e.obj ()) && std::isfinite (e.dir ())
869- && state.theta ().allFinite () && state.theta_grad ().allFinite ();
870- };
871867 Eval best = low; // keep the best Armijo-OK in case strong-Wolfe fails
872868 auto update_with_tick = [&total_updates, &opt, &best, &update_fun](
873869 auto && proposal, auto && curr, auto && prev,
@@ -895,7 +891,6 @@ inline WolfeStatus wolfe_line_search(Info& wolfe_info, UpdateFun&& update_fun,
895891 = std::clamp (curr.alpha () * opt.scale_up , opt.min_alpha , opt.max_alpha );
896892 Eval high{alpha_start, curr.obj (), dir_deriv_init};
897893 WolfeStatus wolfe_check{WolfeReturn::Continue, 0 , 0 , false };
898- bool high_has_eval = true ;
899894 // Initial check for numerical trouble
900895 {
901896 wolfe_check = update_with_tick (scratch, curr, prev, high, p);
@@ -920,7 +915,6 @@ inline WolfeStatus wolfe_line_search(Info& wolfe_info, UpdateFun&& update_fun,
920915 if (wolfe_check.stop_ != WolfeReturn::Continue) {
921916 return wolfe_check;
922917 }
923- high_has_eval = true ;
924918 }
925919 wolfe_check = update_with_tick (scratch, curr, prev, best, p);
926920 if (wolfe_check.stop_ != WolfeReturn::Continue) {
@@ -935,55 +929,50 @@ inline WolfeStatus wolfe_line_search(Info& wolfe_info, UpdateFun&& update_fun,
935929 }
936930 }
937931 }
938- bool found_right = false ;
939932 int num_backtracks = 0 ;
940933 /* *
941- * For each case
934+ * From Nocedal–Wright (2006), Algorithm 3.5:
935+ * https://www.math.uci.edu/~qnie/Publications/NumericalOptimization.pdf
942936 * | armijo | wolfe | sign(g) | Action
943937 * -------+-------+---------+--------------------------------
944938 * | [1] T | T | | Accept alpha
945939 * | [2] T | F | > 0 | set low=high, expand high
946- * | [3] T | F | < 0 | Set alpha_high <- alpha, stop
947- * | [4] F | T | | Set alpha_high <- alpha, stop
948- * | [5] F | F | | Set alpha_high <- alpha, stop
940+ * | [3] T | F | < 0 | Bracket found: stop
941+ * | [4] F | T | | Bracket found: stop
942+ * | [5] F | F | | Bracket found: stop
943+ * NOTE: In an ideal case we would end up with a positive low directional gradient and
944+ * negative high directional gradient. Cubic interpolation requires a bracket with directional
945+ * shape like /\. This scheme does not gurantee a bracket with that shape will be found.
946+ * So in the zoom we will have to check if we can do cubic or have to fallback to bisection.
949947 **/
950- while (!found_right && high.alpha () < opt.max_alpha ) {
948+ while (high.alpha () < opt.max_alpha ) {
951949 num_backtracks++;
952- // 1. Evaluate f(alpha) and g(alpha)
953950 wolfe_check = update_with_tick (scratch, curr, prev, high, p);
954951 if (wolfe_check.stop_ != WolfeReturn::Continue) {
955952 return wolfe_check;
956953 }
957- high_has_eval = true ;
958- const bool finite_ok = eval_finite (high, scratch);
959- // 2. Handle numerical trouble first
960- if (!finite_ok) { // f or g is NaN/Inf → shrink
961- high.alpha () *= 0.5 ;
962- high_has_eval = false ;
963- if (high.alpha () < opt.min_alpha ) {
964- break ;
965- }
966- continue ;
967- }
968954 const bool armijo = check_armijo (high, prev, opt);
969955 const bool wolfe = check_wolfe (high, prev, opt);
970- if (armijo && wolfe) { // [1]
956+ // [1]
957+ if (armijo && wolfe) {
971958 curr.update (scratch, high);
972959 return WolfeStatus{WolfeReturn::Wolfe, total_updates, num_backtracks,
973960 true };
961+ } else if (armijo) {
962+ if (best.obj () < high.obj ()) {
963+ best = high;
964+ }
965+ // [2]
966+ if (high.dir () > 0 ) {
967+ low = high;
968+ high.alpha () *= opt.scale_up ;
969+ continue ;
970+ }
971+ // [3]
972+ break ;
974973 }
975- if (armijo && best.obj () < high.obj ()) {
976- best = high;
977- }
978- const bool dir_pos = high.dir () > 0 ;
979- if (armijo && !wolfe && dir_pos) { // [2]
980- low = high;
981- high.alpha () *= opt.scale_up ;
982- high_has_eval = false ;
983- continue ;
984- }
985- // [3,4,5]
986- found_right = true ;
974+ // [3, 4, 5]
975+ break ;
987976 }
988977 const double grad_tol
989978 = std::max (opt.abs_grad_threshold ,
@@ -1018,13 +1007,6 @@ inline WolfeStatus wolfe_line_search(Info& wolfe_info, UpdateFun&& update_fun,
10181007 return WolfeStatus{WolfeReturn::Continue, total_updates, num_backtracks,
10191008 false };
10201009 };
1021- if (!high_has_eval) {
1022- wolfe_check = update_with_tick (scratch, curr, prev, high, p);
1023- if (wolfe_check.stop_ != WolfeReturn::Continue) {
1024- return wolfe_check;
1025- }
1026- high_has_eval = true ;
1027- }
10281010 auto check_b = check_bounds (high);
10291011 if (check_b.stop_ != WolfeReturn::Continue) {
10301012 if (check_b.accept_ ) {
@@ -1036,7 +1018,19 @@ inline WolfeStatus wolfe_line_search(Info& wolfe_info, UpdateFun&& update_fun,
10361018 if (wolfe_check.stop_ != WolfeReturn::Continue) {
10371019 return wolfe_check;
10381020 }
1039- // Zoom phase
1021+ /* *
1022+ * Zoom Step: (Alg 3.6, adapted to maximization via phi=-obj)
1023+ *
1024+ * Exit/update table (evaluated at `mid`, with `low` = best Armijo endpoint):
1025+ * | Armijo? | obj(mid) >= obj(low)? | Wolfe? | dir(mid) >= 0? | Action
1026+ * |---------|-----------------------|--------|----------------|--------------------------|
1027+ * | T | F | T | * | accept mid [1] |
1028+ * | T | T | * | * | high = mid [2] |
1029+ * | T | F | F | T | high = low; low = mid [3]|
1030+ * | T | F | F | F | low = mid [4] |
1031+ * | F | * | * | * | high = mid [5] |
1032+ * ----------------------------------------------------------------------------------------
1033+ **/
10401034 while ((high.alpha () - low.alpha () > opt.min_alpha )
10411035 && high.alpha () > opt.min_alpha ) {
10421036 num_backtracks++;
@@ -1046,9 +1040,12 @@ inline WolfeStatus wolfe_line_search(Info& wolfe_info, UpdateFun&& update_fun,
10461040 const bool use_cubic = have_sign_change && high_armijo_ok;
10471041
10481042 // Choose trial alpha: cubic when bracket is good, else bisection.
1049- double alpha_mid = use_cubic ? cubic_or_bisect_max (low, high, opt)
1050- : 0.5 * (low.alpha () + high.alpha ());
1051-
1043+ double alpha_mid{0 };
1044+ if (use_cubic) {
1045+ alpha_mid = cubic_interpolation (low, high, opt);
1046+ } else {
1047+ alpha_mid = 0.5 * (low.alpha () + high.alpha ());
1048+ }
10521049 if (alpha_mid <= opt.min_alpha ) {
10531050 break ;
10541051 }
@@ -1063,6 +1060,7 @@ inline WolfeStatus wolfe_line_search(Info& wolfe_info, UpdateFun&& update_fun,
10631060 }
10641061 if (check_armijo (mid, prev, opt)) {
10651062 if (check_wolfe (mid, prev, opt)) {
1063+ // [1]
10661064 curr.update (scratch, mid);
10671065 return WolfeStatus{WolfeReturn::Wolfe, total_updates, num_backtracks,
10681066 true };
@@ -1071,17 +1069,17 @@ inline WolfeStatus wolfe_line_search(Info& wolfe_info, UpdateFun&& update_fun,
10711069 if (mid.obj () > best.obj ()) {
10721070 best = mid;
10731071 }
1074- }
1075-
1076- // Update bracket based on derivative sign
1077- if (mid.dir () * low.dir () < 0 ) {
1078- // sign change between low and mid -> [low, mid]
1079- high = mid;
1080- } else {
1081- // otherwise shift left endpoint -> [mid, high]
1072+ if (mid.obj () >= low.obj ()) {
1073+ // [2]
1074+ high = mid;
1075+ } else if (mid.dir () >= 0 ) {
1076+ // [3]
1077+ high = low;
1078+ low = mid;
1079+ }
1080+ // [4]
10821081 low = mid;
10831082 }
1084-
10851083 // Convergence/guard-rail checks (uses prev/grad_tol/obj_tol etc.)
10861084 auto bounds_check = check_bounds (mid);
10871085 if (bounds_check.stop_ != WolfeReturn::Continue) {
@@ -1090,6 +1088,8 @@ inline WolfeStatus wolfe_line_search(Info& wolfe_info, UpdateFun&& update_fun,
10901088 }
10911089 return bounds_check;
10921090 }
1091+ // [5]
1092+ high = mid;
10931093 }
10941094 // On failure, use the best point we have found so far that at least satisfies
10951095 // armijo
0 commit comments