Cache Cholesky of K_XX + noise I in DefaultPredictionStrategy #2746
Open
saitcakmak wants to merge 2 commits into
The Cholesky factor of the train-train covariance was being recomputed on every posterior call. LinearOperator.solve dispatches through Solve.apply, which reconstructs the lazy operator via representation_tree and drops the @cached(cholesky) state on AddedDiagLinearOperator. Materialize it once as a TriangularLinearOperator cached on the prediction strategy, share across mean_cache and exact_predictive_covar, and switch the covariance correction to a single triangular solve (K(t,T) K_XX^-1 K(T,t) = M^T M with M = L^-1 K(T,t)), resolving the existing TODO about using one triangular solve for dense inference. Adds TestExactPredictiveCovar checking posterior mean, covariance, and gradients w.r.t. test_x against a from-scratch math reference (float64, atol=1e-10 / 1e-9).
SGPR's train-train covariance is a LowRankRootAddedDiagLinearOperator that overrides .solve() to use a Woodbury identity with an inducing-size Cholesky (k x k instead of n x n). The Cholesky-sharing optimization in the initial commit bypassed this: _mean_cache called self.lik_train_train_chol, which routes through the default LinearOperator._cholesky and materializes a dense n x n factor — breaking test_sgpr_mean_abs_error's assertion that every Cholesky during SGPR inference is of inducing size. Extract the "ignore" nan_policy solve into a _solve_lik_train_train helper on DefaultPredictionStrategy, then override it in SGPRPredictionStrategy to dispatch to .evaluate_kernel().solve(), which uses the Woodbury path. exact_predictive_covar is unchanged because SGPRPredictionStrategy overrides that method entirely and never reaches the patched code path.
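For context, the Woodbury identity this commit message relies on can be written for a generic low-rank-plus-diagonal matrix. The symbols below are notation assumed here (an `n × k` root `R` and a positive diagonal `D`), not names taken from the linear_operator source:

```latex
\[
(R R^\top + D)^{-1}
  = D^{-1} - D^{-1} R \,\bigl(I_k + R^\top D^{-1} R\bigr)^{-1} R^\top D^{-1},
\qquad R \in \mathbb{R}^{n \times k},\quad D \in \mathbb{R}^{n \times n}\ \text{diagonal}.
\]
```

Only the `k × k` matrix `I_k + RᵀD⁻¹R` has to be factorized, which is why a dense `n × n` Cholesky showing up during SGPR inference indicates that the structured path was bypassed.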
Summary
- Cache the Cholesky factor of (`K_XX + σ²I`) on `DefaultPredictionStrategy` and share it across `mean_cache` and `exact_predictive_covar`, removing a per-posterior-call factorization.
- Switch `exact_predictive_covar` to a single triangular solve (`K(t,T) (K_XX + σ²I)⁻¹ K(T,t) = MᵀM` with `M = L⁻¹ K(T,t)`), resolving the existing `# TODO` about avoiding two triangular solves for dense inference.
- Add a regression test (`TestExactPredictiveCovar`) that checks posterior mean, covariance, and gradients w.r.t. `test_x` against a from-scratch math reference in float64.
- Override `_solve_lik_train_train` in `SGPRPredictionStrategy` so SGPR's Woodbury-based `.solve()` is preserved (see "Interaction with LinearOperators that override `.solve()`" below).

Why

The old code constructed a fresh `AddedDiagLinearOperator` inside `exact_predictive_covar` (from `self.likelihood(dist, self.train_inputs)`) and called `.solve(train_test_covar)` on it. `LinearOperator.solve` dispatches through `Solve.apply`, which reconstructs the lazy operator via `representation_tree(*args)`, so the `@cached(name="cholesky")` memoization on `AddedDiagLinearOperator` never survives. Every posterior call recomputed `psd_safe_cholesky(K_XX + σ²I)` from scratch. On batched models this is a batched factorization; on moderately sized non-batched models (`n ≥ 500`) it dominates the forward cost.

The fix materializes `L` once as a `TriangularLinearOperator` on the prediction strategy. Subsequent solves bypass `Solve.apply` (the triangular operator's `.solve` goes straight to `torch.linalg.solve_triangular`), so the Cholesky is truly reused.

Correctness
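The identity behind the single triangular solve can be checked numerically. The following is a dependency-free sketch in plain Python, not the PR's code: `ToyStrategy`, its method names, and the 3×3 example matrices are hypothetical, and the real implementation operates on `torch` tensors through a `TriangularLinearOperator`. It caches the factor once and compares the one-solve form `MᵀM` against the two-solve form `K(t,T) (K_XX + σ²I)⁻¹ K(T,t)`:

```python
import math
from functools import cached_property


def cholesky(a):
    """Lower-triangular L with L @ L.T == a, for a small SPD matrix (list of lists)."""
    n = len(a)
    L = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = a[i][j] - sum(L[i][k] * L[j][k] for k in range(j))
            L[i][j] = math.sqrt(s) if i == j else s / L[j][j]
    return L


def solve_lower(L, b):
    """Forward substitution: solve L x = b, where b is n x m."""
    n, m = len(L), len(b[0])
    x = [[0.0] * m for _ in range(n)]
    for c in range(m):
        for i in range(n):
            s = b[i][c] - sum(L[i][k] * x[k][c] for k in range(i))
            x[i][c] = s / L[i][i]
    return x


def solve_lower_transposed(L, b):
    """Back substitution: solve L.T x = b, where b is n x m."""
    n, m = len(L), len(b[0])
    x = [[0.0] * m for _ in range(n)]
    for c in range(m):
        for i in reversed(range(n)):
            s = b[i][c] - sum(L[k][i] * x[k][c] for k in range(i + 1, n))
            x[i][c] = s / L[i][i]
    return x


def matmul_t(a, b):
    """Return a.T @ b for an n x p matrix a and an n x q matrix b."""
    rows = range(len(a))
    return [[sum(a[k][i] * b[k][j] for k in rows) for j in range(len(b[0]))]
            for i in range(len(a[0]))]


class ToyStrategy:
    """Hypothetical stand-in for a prediction strategy (not gpytorch's class)."""

    def __init__(self, k_xx, noise):
        n = len(k_xx)
        self.k_hat = [[k_xx[i][j] + (noise if i == j else 0.0) for j in range(n)]
                      for i in range(n)]
        self.num_factorizations = 0

    @cached_property
    def chol(self):
        # Computed on first access only; later accesses reuse the stored factor.
        self.num_factorizations += 1
        return cholesky(self.k_hat)

    def covar_correction(self, k_Tt):
        m = solve_lower(self.chol, k_Tt)  # M = L^-1 K(T,t): one triangular solve
        return matmul_t(m, m)             # M^T M

    def covar_correction_two_solves(self, k_Tt):
        y = solve_lower(self.chol, k_Tt)
        x = solve_lower_transposed(self.chol, y)  # (K_XX + noise I)^-1 K(T,t)
        return matmul_t(k_Tt, x)                  # K(t,T) (K_XX + noise I)^-1 K(T,t)


# Illustrative inputs: 3 training points, 2 test points.
K_XX = [[1.0, 0.5, 0.2], [0.5, 1.0, 0.5], [0.2, 0.5, 1.0]]
K_Tt = [[0.9, 0.1], [0.6, 0.4], [0.3, 0.8]]
strategy = ToyStrategy(K_XX, noise=0.1)
one_solve = strategy.covar_correction(K_Tt)
two_solves = strategy.covar_correction_two_solves(K_Tt)
```

Both forms agree to floating-point precision, and `strategy.num_factorizations` stays at 1 across the two calls because the factor lives on the strategy object rather than on an operator that is rebuilt per call.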
- Matches a from-scratch `K(t,t) − K(t,T)(K(T,T)+σ²I)⁻¹K(T,t)` reference: `max|Δ| ≤ 1e-10` on values, `≤ 1e-9` on gradients (float64), on both non-batched `(10, 1)` and batched `(4, 10, 1)` test inputs. The new `TestExactPredictiveCovar.test_posterior_matches_math_reference` encodes this.
- Mutation check: flipping the sign (`alpha=-1` → `alpha=1`) breaks the new test with `cov diff ≈ 2.4` (vs `atol=1e-10`).
- `train_final_loss` matches bit-for-bit across every benchmarked config.

Interaction with LinearOperators that override `.solve()`

`LinearOperator.solve()` dispatches through `Solve.apply`, which reconstructs the operator from its `representation_tree` and always routes through `.cholesky()._cholesky_solve(rhs)` for small matrices. Subclasses that want structural efficiency therefore take one of two approaches:

1. Override `_solve` only. This path is never actually invoked through `.solve()` in today's default settings: `Solve.apply` swallows it. Example: `KroneckerProductAddedDiagLinearOperator` has sophisticated eigendecomposition and Kronecker-Woodbury branches in its `_solve`, but baseline code already materializes a dense Cholesky via `.cholesky()` instead. Confirmed empirically for the multitask tests: zero `torch.linalg.eigh` calls during posterior; dense Choleskys only. Our patch is unaffected: we do the same dense Cholesky, just one of them instead of two.
2. Override `.solve()` directly, bypassing `Solve.apply`. This is genuinely used and preserves the structural efficiency. The only class in linear_operator that does this for a PSD operator is `LowRankRootAddedDiagLinearOperator`, which uses Woodbury with an inducing-size `k × k` cap matrix. The only gpytorch code path that produces such an operator as `lik_train_train_covar` is SGPR via `InducingPointKernel`.

The initial commit broke category (2) for SGPR: `_mean_cache` started calling `self.lik_train_train_chol`, which routes through the default `LinearOperator._cholesky` and materializes a dense `n × n` factor, defeating Woodbury. The existing test `test_sgpr_mean_abs_error` asserts every Cholesky during SGPR inference is of inducing size; the initial commit made `n = 100` matrices sneak in and the test failed in CI.

Fix
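The extension-point pattern used by the fix can be illustrated with toy classes. This is a sketch under stated assumptions, not the gpytorch implementation: every class here is a hypothetical stand-in, the solves return their input unchanged, and only the size of each factorization is recorded, mirroring what `test_sgpr_mean_abs_error` asserts.

```python
class ToyCovar:
    """Hypothetical n x n train-train covariance that logs factorization sizes."""

    def __init__(self, n, log):
        self.n = n
        self.log = log  # list shared with the caller; one entry per factorization

    def cholesky_solve(self, rhs):
        self.log.append(self.n)  # stands in for a dense n x n Cholesky
        return rhs               # placeholder: no real linear algebra here

    def solve(self, rhs):
        return self.cholesky_solve(rhs)


class ToyLowRankCovar(ToyCovar):
    """Hypothetical low-rank-plus-diagonal operator with its own structured solve."""

    def __init__(self, n, k, log):
        super().__init__(n, log)
        self.k = k

    def solve(self, rhs):
        self.log.append(self.k)  # stands in for a Woodbury solve with a k x k factor
        return rhs


class ToyDefaultStrategy:
    """Hypothetical default strategy: solves through the dense Cholesky."""

    def __init__(self, covar):
        self.covar = covar

    def _solve_lik_train_train(self, rhs):
        # Extension point: subclasses may replace how the solve is performed.
        return self.covar.cholesky_solve(rhs)

    def mean_cache(self, train_labels):
        return self._solve_lik_train_train(train_labels)


class ToySGPRStrategy(ToyDefaultStrategy):
    """Hypothetical SGPR-like strategy: defer to the operator's own solve."""

    def _solve_lik_train_train(self, rhs):
        return self.covar.solve(rhs)


default_log, sgpr_log = [], []
ToyDefaultStrategy(ToyLowRankCovar(100, 10, default_log)).mean_cache([1.0])
ToySGPRStrategy(ToyLowRankCovar(100, 10, sgpr_log)).mean_cache([1.0])
```

With the default strategy the log contains the dense size `100`; with the overriding subclass it contains only the inducing size `10`.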
A follow-up commit extracts the "ignore" `nan_policy` solve in `_mean_cache` into a one-line `_solve_lik_train_train` helper on `DefaultPredictionStrategy`, then overrides it in `SGPRPredictionStrategy` to dispatch to the operator's own `.solve()`. `exact_predictive_covar` is untouched because `SGPRPredictionStrategy` overrides that method entirely and never reaches the patched default path.

Why subclass override rather than auto-detection
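The auto-detection alternative weighed in this section amounts to a shallow type check. A minimal sketch with hypothetical stand-in classes (they borrow names from linear_operator for readability but are not the real types) shows the check itself and why a wrapper defeats it:

```python
class LinearOperator:
    """Stand-in base class with a default (dense) solve."""

    def solve(self, rhs):
        return "dense"


class LowRankRootAddedDiag(LinearOperator):
    """Stand-in for an operator that overrides solve with a structured path."""

    def solve(self, rhs):
        return "woodbury"


class BatchRepeat(LinearOperator):
    """Stand-in wrapper: holds a structured operator but inherits the base solve."""

    def __init__(self, base):
        self.base = base


def overrides_solve(op):
    # Shallow check: looks only at the outermost class, never at wrapped operators.
    return type(op).solve is not LinearOperator.solve


inner = LowRankRootAddedDiag()
wrapped = BatchRepeat(inner)
```

Here `overrides_solve(inner)` is `True` but `overrides_solve(wrapped)` is `False`, even though the wrapped operator carries the structured solve. Catching that case would require recursing into the wrapper, which is more machinery than an explicit subclass override.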
I considered auto-detecting whether `lik_train_train_covar`'s class overrides `.solve()` via `type(x).solve is not LinearOperator.solve` and branching inside `DefaultPredictionStrategy`. I decided against it:

- The affected set is `{SGPR}`. I audited every `.solve()` override in linear_operator (`zero`, `triangular`, `low_rank_root_added_diag`, `chol`, `diag`, `identity`, `kronecker_product` for triangular). Of those, only `LowRankRootAddedDiagLinearOperator` is reachable as `lik_train_train_covar` through `DefaultPredictionStrategy` in current gpytorch, and that's the SGPR path. Everything else is either non-PSD (triangular, chol) or never appears as train-train covar (diag, identity, etc.).
- `BatchRepeatLinearOperator(LowRankRootAddedDiagLinearOperator)`, `ConstantMulLinearOperator(LowRankRoot)`, etc. would have a structurally special inner operator but a wrapper class that inherits `LinearOperator.solve`. The shallow type check says "no structured solve" and forces a dense Cholesky. None of these compositions occur in current gpytorch (the one that could, SGPR fantasy updates producing `BatchRepeat(LowRankRoot)`, is blocked because `SGPRPredictionStrategy.get_fantasy_strategy` raises `NotImplementedError`), but a shallow auto-detector pretends to generalize while actually missing these cases.
- `KroneckerProductAddedDiagLinearOperator` doesn't override `.solve()`, only `_solve`, and `_solve` is dead through the dispatch. Our patch and baseline produce identical Cholesky sizes for multitask models; no regression to work around.
- Anyone reading `SGPRPredictionStrategy` sees the override and knows what it does. Auto-detection plus a cached `_evaluated_lik_train_train_covar` property inflates `DefaultPredictionStrategy` to defend against a case that doesn't exist.

If a future operator overrides `.solve()` for Woodbury-like reasons and a user wires it into a `DefaultPredictionStrategy` (i.e., they don't subclass the prediction strategy themselves), they'd fall back to the dense Cholesky path with no correctness issue, just the same type of perf regression SGPR would have hit without a dedicated override. The `_solve_lik_train_train` helper is a deliberate one-line extension point for that situation, and the SGPR override is a worked example.

BO loop impact (aggregate)
[Aggregate table not recoverable from this page capture. The surviving cell fragments refer to the posterior call (`model.posterior(X)`), the pattern `posterior(X).rsample().sum().backward()`, and `rsample` internals (posterior-covar Cholesky, `L @ noise`).]

No significant regressions at 10-seed precision. The biggest wins are on the cold posterior path (first acquisition build, fresh refit) and on batched models, where the old per-call batched Cholesky was expensive.
Out of scope
`SGPRPredictionStrategy`, `LinearPredictionStrategy`, and `InterpolatedPredictionStrategy` override `exact_predictive_covar` and don't inherit the change; their existing tests all pass.

Test plan
- `python -m pytest test/models/test_exact_gp.py -q`: 59 passed
- `python -m pytest test/mlls/test_exact_marginal_log_likelihood.py test/distributions/ -q`: 48 passed
- `python -m pytest test/examples/ -q`: 101 passed (covers SGPR, all Kronecker/Hadamard/LCM/batched multitask variants)
- `TestExactPredictiveCovar.test_posterior_matches_math_reference`: passes (`atol=1e-10` on values, `1e-9` on gradients)
- Sign-flip mutation (`alpha=-1` → `alpha=1`) breaks the new test (`cov diff ≈ 2.4`)
- `test_sgpr_mean_abs_error`: passes after the follow-up commit; asserts every Cholesky during SGPR inference is of inducing (`k`) size
- `flake8` clean on touched files

Benchmark setup
- `SingleTaskGP` with `Normalize(d)` + `Standardize(m=1)` transforms.
- `botorch` is imported for side-effect, which disables linear-operator fast paths (`_fast_solves=False`, `_fast_covar_root_decomposition=False`, `_fast_log_prob=False`, `max_cholesky_size=4096`). Matches the setting BoTorch users actually see.
- Training uses `lr=5e-3` (only to produce non-init hyperparameters; `train_final_loss` is bit-identical between baseline and prototype).
- Non-batched `n ∈ {20, 100, 500, 1000}`, `d ∈ {2, 5}`; batched `B ∈ {16, 32, 64}`, `n ∈ {50, 100}`, `d ∈ {4, 5}`. Test points `m = 128`.
- Test inputs of shape `(m, d)` and an MC-batch variant. Batched models also include a test input matched to the model batch (`(B, m, d)`).
- Metrics (`*` = `|Δ| > 3%` AND exceeds combined stdev):
  - `first_posterior_ms`: invalidate `prediction_strategy` each call, then one `model.posterior(X)` under `no_grad`.
  - `repeat_posterior_ms`: reuse `prediction_strategy`, fresh `X` each call, under `no_grad`.
  - `warm_rsample_bwd_ms`: `prediction_strategy` kept warm; `p = model.posterior(X); s = p.rsample(Size([16])); s.sum().backward()` with `X.requires_grad=True`.
  - `cold_rsample_bwd_ms`: invalidate `prediction_strategy` each call, then the same rsample+backward. Primary BoTorch MC acquisition pattern.
- 10-seed rerun on `nb_n1000_d5` and `b64_n100_d5` to tighten confidence on the two regimes with the largest measurement noise.

Detailed results: `cold_rsample_bwd_ms` (primary BoTorch pattern, 5 seeds)

Detailed results: `warm_rsample_bwd_ms` (steady-state L-BFGS-B iteration, 5 seeds)

Detailed results: `first_posterior_ms` and `repeat_posterior_ms` (5 seeds)

10-seed reverification (tightening confidence on large-n and B=64 rsample paths)
The 5-seed run flagged two marginal concerns: a `+10.3%*` warm regression at `b64_n100_d5 64×128×5` and a consistent-direction ~+5% drift in `warm_rsample` at `nb_n1000_d5`. Rerunning these configs at 10 seeds resolves both: the 5-seed `+10.3%*` warm regression at `b64_n100 64×128` collapses to `+1.7%` (not significant), and n=1000 warm paths are neutral across all shapes while cold paths show large wins.
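The significance marker used throughout the benchmark (`*` = `|Δ| > 3%` AND exceeds combined stdev) can be sketched as a small helper. The PR does not define "combined stdev"; the root-sum-of-squares of the two per-arm standard deviations is an assumption here, as are the function names and the illustrative timings in the usage note.

```python
import math


def relative_delta_pct(baseline_ms, candidate_ms):
    """Signed change of the candidate relative to the baseline, in percent."""
    return 100.0 * (candidate_ms - baseline_ms) / baseline_ms


def is_flagged(baseline_ms, candidate_ms, baseline_sd, candidate_sd, threshold_pct=3.0):
    """Return True when a difference would earn the `*` marker.

    Both conditions must hold: the relative change exceeds the threshold, and the
    absolute change exceeds the combined standard deviation.
    """
    delta_ms = candidate_ms - baseline_ms
    # Assumption: "combined stdev" is the root-sum-of-squares of the two stdevs.
    combined_sd = math.sqrt(baseline_sd ** 2 + candidate_sd ** 2)
    exceeds_pct = abs(relative_delta_pct(baseline_ms, candidate_ms)) > threshold_pct
    return exceeds_pct and abs(delta_ms) > combined_sd
```

For example, with made-up timings (not the PR's measurements), `is_flagged(10.0, 11.0, 0.2, 0.2)` is flagged, while `is_flagged(10.0, 11.0, 1.0, 1.0)` is not, because the same 10% shift no longer exceeds the noise. This is the mechanism by which a 5-seed flag can disappear once more seeds tighten the estimate.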