Skip to content

Cache Cholesky of K_XX + noise I in DefaultPredictionStrategy#2746

Open
saitcakmak wants to merge 2 commits into
cornellius-gp:mainfrom
saitcakmak:cache-chol-exact-prediction
Open

Cache Cholesky of K_XX + noise I in DefaultPredictionStrategy#2746
saitcakmak wants to merge 2 commits into
cornellius-gp:mainfrom
saitcakmak:cache-chol-exact-prediction

Conversation

@saitcakmak
Copy link
Copy Markdown
Collaborator

@saitcakmak saitcakmak commented Apr 17, 2026

Summary

  • Cache the Cholesky factor of the train-train covariance (K_XX + σ²I) on DefaultPredictionStrategy and share it across mean_cache and exact_predictive_covar, removing a per-posterior-call factorization.
  • Switch the covariance correction in exact_predictive_covar to a single triangular solve (K(t,T) K_XX^{-1} K(T,t) = MᵀM with M = L⁻¹ K(T,t)), resolving the existing # TODO about avoiding two triangular solves for dense inference.
  • Add a parity test (TestExactPredictiveCovar) that checks posterior mean, covariance, and gradients w.r.t. test_x against a from-scratch math reference in float64.
  • Override _solve_lik_train_train in SGPRPredictionStrategy so SGPR's Woodbury-based .solve() is preserved (see "Interaction with structured-solve operators" below).

Why

The old code constructed a fresh AddedDiagLinearOperator inside exact_predictive_covar (from self.likelihood(dist, self.train_inputs)) and called .solve(train_test_covar) on it. LinearOperator.solve dispatches through Solve.apply, which reconstructs the lazy operator via representation_tree(*args) — so the @cached(name="cholesky") memoization on AddedDiagLinearOperator never survives. Every posterior call ate a fresh psd_safe_cholesky(K_XX + σ²I). On batched models this is a batched factorization; on moderately-sized non-batched models (n≥500) it dominates the forward cost.

The fix materializes L once as a TriangularLinearOperator on the prediction strategy. Subsequent solves bypass Solve.apply (the triangular operator's .solve goes straight to torch.linalg.solve_triangular), so the Cholesky is truly reused.

Correctness

  • Parity against a from-scratch K(t,t) − K(t,T)(K(T,T)+σ²I)⁻¹K(T,t) reference: max|Δ| ≤ 1e-10 on values, ≤ 1e-9 on gradients (float64), on both non-batched (10, 1) and batched (4, 10, 1) test inputs. The new TestExactPredictiveCovar.test_posterior_matches_math_reference encodes this.
  • Teeth check: flipping the sign of the correction term (alpha=-1alpha=1) breaks the new test with cov diff ≈ 2.4 (vs atol=1e-10).
  • Training path unchanged — train_final_loss matches bit-for-bit across every benchmarked config.

Interaction with LinearOperators that override .solve()

LinearOperator.solve() dispatches through Solve.apply, which reconstructs the operator from its representation_tree and always routes through .cholesky()._cholesky_solve(rhs) for small matrices. Subclasses that want structural efficiency therefore take one of two approaches:

  1. Override _solve only. This path is never actually invoked through .solve() in today's default settings — Solve.apply swallows it. Example: KroneckerProductAddedDiagLinearOperator has sophisticated eigendecomposition and Kronecker-Woodbury branches in its _solve, but baseline code already materializes a dense Cholesky via .cholesky() instead. Confirmed empirically for the multitask tests: zero torch.linalg.eigh calls during posterior; dense Choleskys only. Our patch is unaffected — we do the same dense Cholesky, just one of them instead of two.

  2. Override .solve() directly, bypassing Solve.apply. This is genuinely used and preserves the structural efficiency. The only class in linear_operator that does this for a PSD operator is LowRankRootAddedDiagLinearOperator, which uses Woodbury with an inducing-size k × k cap matrix. The only gpytorch code path that produces such an operator as lik_train_train_covar is SGPR via InducingPointKernel.

The initial commit broke category (2) for SGPR: _mean_cache started calling self.lik_train_train_chol, which routes through the default LinearOperator._cholesky and materializes a dense n × n factor, defeating Woodbury. The existing test test_sgpr_mean_abs_error asserts every Cholesky during SGPR inference is of inducing size; the initial commit made n = 100 matrices sneak in and the test failed in CI.

Fix

A follow-up commit extracts the "ignore" nan_policy solve in _mean_cache into a one-line _solve_lik_train_train helper on DefaultPredictionStrategy, then overrides it in SGPRPredictionStrategy to dispatch to the operator's own .solve(). exact_predictive_covar is untouched because SGPRPredictionStrategy overrides that method entirely and never reaches the patched default path.

Why subclass override rather than auto-detection

I considered auto-detecting whether lik_train_train_covar's class overrides .solve() via type(x).solve is not LinearOperator.solve and branching inside DefaultPredictionStrategy. I decided against it:

  • The set of "real cases today" is exactly {SGPR}. I audited every .solve() override in linear_operator (zero, triangular, low_rank_root_added_diag, chol, diag, identity, kronecker_product for triangular). Of those, only LowRankRootAddedDiagLinearOperator is reachable as lik_train_train_covar through DefaultPredictionStrategy in current gpytorch, and that's the SGPR path. Everything else is either non-PSD (triangular, chol) or never appears as train-train covar (diag, identity, etc.).
  • Auto-detection has wrapper-composition holes. BatchRepeatLinearOperator(LowRankRootAddedDiagLinearOperator), ConstantMulLinearOperator(LowRankRoot), etc. would have a structurally-special inner but a wrapper class that inherits LinearOperator.solve. The shallow type check says "no structured solve" and forces a dense Cholesky. None of these compositions occur in current gpytorch — the one that could (SGPR fantasy updates producing BatchRepeat(LowRankRoot)) is blocked because SGPRPredictionStrategy.get_fantasy_strategy raises NotImplementedError — but a shallow auto-detector pretends to generalize while actually missing these cases.
  • Kronecker doesn't need defending. KroneckerProductAddedDiagLinearOperator doesn't override .solve(), only _solve, and _solve is dead through the dispatch. Our patch and baseline produce identical Cholesky sizes for multitask models; no regression to work around.
  • Explicit is safer than clever. A reader looking at SGPRPredictionStrategy sees the override and knows what it does. Auto-detection + a cached _evaluated_lik_train_train_covar property inflates DefaultPredictionStrategy to defend against a case that doesn't exist.

If a future operator overrides .solve() for Woodbury-like reasons and a user wires it into a DefaultPredictionStrategy (i.e., they don't subclass the prediction strategy themselves), they'd fall back to the dense Cholesky path with no correctness issue — just the same type of perf regression SGPR would have hit without a dedicated override. The _solve_lik_train_train helper is a deliberate one-line extension point for that situation, and the SGPR override is a worked example.

BO loop impact (aggregate)

Step Dominant op touched? Verified impact
Fit model No 0%
First acquisition build (model.posterior(X)) Yes −13% to −30%
Analytic acqf opt (EI/UCB/PI) — warm L-BFGS-B Yes −10% to −24% at n≤100; neutral at n≥500
MC acqf opt — cold posterior(X).rsample().sum().backward() Yes −10% to −27% across all configs
MC acqf opt — warm inner loop Partially −12% to −24% at n≤100; neutral at large n
rsample internals (posterior-covar Cholesky, L @ noise) No unchanged

No significant regressions at 10-seed precision. Biggest wins on the cold posterior path (first acquisition build, fresh refit) and on batched models where the old per-call batched Cholesky was expensive.

Out of scope

  • CPU-only benchmarks. GPU behavior may differ but the identity is dtype/device-agnostic.
  • SGPRPredictionStrategy, LinearPredictionStrategy, and InterpolatedPredictionStrategy override exact_predictive_covar and don't inherit the change; their existing tests all pass.

Test plan

  • python -m pytest test/models/test_exact_gp.py -q59 passed
  • python -m pytest test/mlls/test_exact_marginal_log_likelihood.py test/distributions/ -q48 passed
  • python -m pytest test/examples/ -q101 passed (covers SGPR, all Kronecker/Hadamard/LCM/batched multitask variants)
  • TestExactPredictiveCovar.test_posterior_matches_math_reference — passes (atol=1e-10 on values, 1e-9 on gradients)
  • Mutation test: flipping sign of the correction term correctly fails the new parity test (cov diff ≈ 2.4)
  • test_sgpr_mean_abs_error — passes after the follow-up commit; asserts every Cholesky during SGPR inference is of inducing (k) size
  • flake8 clean on touched files

Benchmark setup

  • Hardware/dtype: CPU, float64.
  • Model: BoTorch SingleTaskGP with Normalize(d) + Standardize(m=1) transforms. botorch is imported for side-effect, which disables linear-operator fast paths (_fast_solves=False, _fast_covar_root_decomposition=False, _fast_log_prob=False, max_cholesky_size=4096). Matches the setting BoTorch users actually see.
  • Training: 30 Adam steps at lr=5e-3 (only to produce non-init hyperparameters — train_final_loss is bit-identical between baseline and prototype).
  • Configs: non-batched n ∈ {20, 100, 500, 1000}, d ∈ {2, 5}; batched B ∈ {16, 32, 64}, n ∈ {50, 100}, d ∈ {4, 5}. Test points m = 128.
  • Test shapes: for each config, a non-batched test input (m, d) and an MC-batch variant. Batched models also include a test input matched to the model batch ((B, m, d)).
  • Measurements (averaged over 5 seeds at 15 iters each, with 3 warmup iters; * = |Δ| > 3% AND exceeds combined stdev):
    • first_posterior_ms — invalidate prediction_strategy each call, then one model.posterior(X) under no_grad.
    • repeat_posterior_ms — reuse prediction_strategy, fresh X each call, under no_grad.
    • warm_rsample_bwd_msprediction_strategy kept warm; p = model.posterior(X); s = p.rsample(Size([16])); s.sum().backward() with X.requires_grad=True.
    • cold_rsample_bwd_ms — invalidate prediction_strategy each call, then the same rsample+backward. Primary BoTorch MC acquisition pattern.
  • Additional 10-seed verification on nb_n1000_d5 and b64_n100_d5 to tighten confidence on the two regimes with the largest measurement noise.

Detailed results — cold_rsample_bwd_ms (primary BoTorch pattern, 5 seeds)

config test shape base (ms) proto (ms) Δ
nb_n20_d2 128×2 2.00 ± 0.19 1.63 ± 0.15 −18.7%*
nb_n20_d2 5×128×2 4.76 ± 0.47 3.82 ± 0.29 −19.6%*
nb_n20_d2 32×128×2 16.74 ± 1.17 14.45 ± 0.55 −13.6%*
nb_n100_d5 128×5 2.74 ± 0.23 2.10 ± 0.11 −23.3%*
nb_n100_d5 5×128×5 6.86 ± 0.40 5.46 ± 0.46 −20.4%*
nb_n100_d5 32×128×5 20.79 ± 1.13 17.82 ± 0.40 −14.3%*
nb_n500_d5 128×5 6.58 ± 0.88 5.34 ± 0.45 −18.8%
nb_n500_d5 5×128×5 11.59 ± 1.36 10.66 ± 0.63 −8.0%
nb_n500_d5 32×128×5 32.73 ± 1.50 31.77 ± 2.09 −2.9%
nb_n1000_d5 128×5 19.85 ± 1.88 14.53 ± 0.37 −26.8%*
nb_n1000_d5 5×128×5 28.38 ± 1.48 23.80 ± 1.14 −16.1%*
nb_n1000_d5 32×128×5 86.46 ± 2.29 77.65 ± 1.94 −10.2%*
b16_n50_d4 128×4 14.81 ± 1.28 12.05 ± 1.63 −18.6%
b16_n50_d4 16×128×4 14.69 ± 1.16 11.85 ± 1.20 −19.4%*
b32_n50_d4 128×4 21.60 ± 1.00 20.00 ± 0.65 −7.4%
b32_n50_d4 32×128×4 21.99 ± 1.40 20.52 ± 0.57 −6.7%
b64_n50_d4 128×4 37.84 ± 1.81 34.15 ± 1.54 −9.8%*
b64_n50_d4 64×128×4 38.02 ± 2.14 34.90 ± 0.85 −8.2%*
b16_n100_d5 128×5 17.79 ± 1.29 15.58 ± 1.47 −12.4%
b16_n100_d5 16×128×5 18.06 ± 0.18 15.32 ± 0.84 −15.2%*
b32_n100_d5 128×5 25.89 ± 2.41 24.92 ± 0.73 −3.7%
b32_n100_d5 32×128×5 25.09 ± 2.12 25.32 ± 0.97 +0.9%
b64_n100_d5 128×5 47.85 ± 2.13 45.94 ± 1.62 −4.0%
b64_n100_d5 64×128×5 43.74 ± 3.44 45.50 ± 1.01 +4.0%
Detailed results — warm_rsample_bwd_ms (steady-state L-BFGS-B iteration, 5 seeds)
config test shape base (ms) proto (ms) Δ
nb_n20_d2 128×2 1.64 ± 0.32 1.30 ± 0.12 −20.5%
nb_n20_d2 5×128×2 3.87 ± 0.44 3.47 ± 0.39 −10.3%
nb_n20_d2 32×128×2 15.76 ± 1.59 13.99 ± 0.73 −11.2%
nb_n100_d5 128×5 1.87 ± 0.19 1.43 ± 0.12 −23.7%*
nb_n100_d5 5×128×5 5.58 ± 0.42 4.44 ± 0.11 −20.4%*
nb_n100_d5 32×128×5 19.90 ± 1.22 16.80 ± 0.37 −15.6%*
nb_n500_d5 128×5 2.33 ± 0.38 2.29 ± 0.25 −1.7%
nb_n500_d5 5×128×5 7.32 ± 0.72 7.63 ± 0.76 +4.2%
nb_n500_d5 32×128×5 27.78 ± 2.08 27.73 ± 2.31 −0.2%
nb_n1000_d5 128×5 3.96 ± 0.71 3.95 ± 0.49 −0.2%
nb_n1000_d5 5×128×5 14.17 ± 1.67 14.07 ± 0.83 −0.7%
nb_n1000_d5 32×128×5 69.84 ± 2.98 66.49 ± 2.32 −4.8%
b16_n50_d4 128×4 12.14 ± 1.31 10.39 ± 1.36 −14.4%
b16_n50_d4 16×128×4 12.48 ± 1.20 10.15 ± 1.09 −18.6%*
b32_n50_d4 128×4 18.39 ± 0.92 18.01 ± 0.64 −2.1%
b32_n50_d4 32×128×4 19.26 ± 1.25 18.07 ± 0.66 −6.1%
b64_n50_d4 128×4 32.61 ± 1.06 31.16 ± 0.72 −4.5%
b64_n50_d4 64×128×4 32.91 ± 1.65 31.46 ± 1.49 −4.4%
b16_n100_d5 128×5 14.10 ± 1.87 12.28 ± 0.69 −12.9%
b16_n100_d5 16×128×5 14.98 ± 1.88 12.83 ± 0.75 −14.3%
b32_n100_d5 128×5 20.48 ± 1.74 19.95 ± 0.80 −2.6%
b32_n100_d5 32×128×5 20.07 ± 2.21 20.91 ± 1.01 +4.2%
b64_n100_d5 128×5 38.01 ± 0.95 38.90 ± 1.53 +2.3%
b64_n100_d5 64×128×5 34.75 ± 2.59 38.31 ± 0.94 +10.3%* (5-seed; not reproduced at 10 seeds, see below)
Detailed results — first_posterior_ms and repeat_posterior_ms (5 seeds)
config test shape first: base first: proto first: Δ repeat: base repeat: proto repeat: Δ
nb_n20_d2 128×2 1.03 ± 0.15 0.74 ± 0.05 −27.7%* 0.47 ± 0.07 0.39 ± 0.02 −17.4%
nb_n20_d2 5×128×2 1.35 ± 0.06 1.13 ± 0.04 −16.0%* 0.85 ± 0.05 0.78 ± 0.04 −8.6%
nb_n20_d2 32×128×2 3.95 ± 0.16 3.38 ± 0.15 −14.4%* 3.33 ± 0.04 2.93 ± 0.14 −12.2%*
nb_n100_d5 128×5 1.57 ± 0.30 0.95 ± 0.08 −39.5%* 0.55 ± 0.08 0.41 ± 0.02 −24.5%*
nb_n100_d5 5×128×5 2.25 ± 0.20 1.58 ± 0.06 −30.0%* 1.08 ± 0.04 0.96 ± 0.01 −11.3%*
nb_n100_d5 32×128×5 5.78 ± 0.38 4.69 ± 0.23 −18.7%* 4.56 ± 0.28 3.78 ± 0.11 −17.1%*
nb_n500_d5 128×5 4.55 ± 0.54 3.40 ± 0.34 −25.3%* 0.68 ± 0.08 0.65 ± 0.07 −3.9%
nb_n500_d5 5×128×5 6.12 ± 0.51 4.79 ± 0.32 −21.8%* 1.98 ± 0.11 1.90 ± 0.08 −3.8%
nb_n500_d5 32×128×5 12.56 ± 0.62 11.72 ± 0.77 −6.6% 7.75 ± 0.96 8.51 ± 1.49 +9.8%
nb_n1000_d5 128×5 16.42 ± 1.28 10.92 ± 0.48 −33.5%* 1.11 ± 0.20 1.09 ± 0.11 −2.4%
nb_n1000_d5 5×128×5 19.58 ± 1.52 14.11 ± 0.42 −27.9%* 4.46 ± 0.41 4.31 ± 0.35 −3.3%
nb_n1000_d5 32×128×5 37.86 ± 1.17 32.30 ± 1.41 −14.7%* 22.50 ± 0.73 21.84 ± 0.96 −2.9%
b16_n50_d4 128×4 4.24 ± 0.51 3.21 ± 0.18 −24.3%* 2.57 ± 0.42 2.16 ± 0.19 −16.0%
b16_n50_d4 16×128×4 4.66 ± 0.56 3.47 ± 0.62 −25.5%* 2.82 ± 0.24 2.23 ± 0.32 −20.9%*
b32_n50_d4 128×4 6.78 ± 0.50 5.45 ± 0.35 −19.7%* 3.78 ± 0.37 3.70 ± 0.25 −2.0%
b32_n50_d4 32×128×4 6.51 ± 0.70 5.59 ± 0.30 −14.1% 4.12 ± 0.34 3.86 ± 0.16 −6.4%
b64_n50_d4 128×4 10.78 ± 0.63 9.22 ± 0.21 −14.5%* 7.38 ± 0.66 6.85 ± 0.28 −7.2%
b64_n50_d4 64×128×4 11.35 ± 0.37 9.60 ± 1.07 −15.5%* 7.50 ± 0.26 7.20 ± 0.40 −3.9%
b16_n100_d5 128×5 6.31 ± 0.66 4.40 ± 0.39 −30.2%* 2.97 ± 0.40 2.70 ± 0.20 −9.0%
b16_n100_d5 16×128×5 6.33 ± 0.51 4.78 ± 0.37 −24.5%* 3.06 ± 0.25 2.73 ± 0.27 −11.0%
b32_n100_d5 128×5 10.07 ± 1.24 8.74 ± 0.58 −13.3% 4.76 ± 0.75 4.53 ± 0.20 −4.7%
b32_n100_d5 32×128×5 10.12 ± 0.98 8.63 ± 0.66 −14.7% 4.65 ± 0.42 4.54 ± 0.25 −2.4%
b64_n100_d5 128×5 17.77 ± 1.31 15.99 ± 0.65 −10.1% 8.79 ± 0.40 8.97 ± 0.63 +2.1%
b64_n100_d5 64×128×5 19.57 ± 0.54 16.48 ± 1.12 −15.8%* 9.16 ± 0.45 9.20 ± 0.54 +0.4%
10-seed reverification (tightening confidence on large-n and B=64 rsample paths)

The 5-seed run flagged two marginal concerns: a +10.3%* warm regression at b64_n100_d5 64×128×5 and a consistent-direction ~+5% drift in warm_rsample at nb_n1000_d5. Rerunning these configs at 10 seeds:

config test shape metric base (ms) proto (ms) Δ
nb_n1000_d5 128×5 first_posterior 15.15 ± 0.84 11.12 ± 0.55 −26.6%*
nb_n1000_d5 128×5 warm_rsample_bwd 3.75 ± 0.34 3.70 ± 0.35 −1.2%
nb_n1000_d5 128×5 cold_rsample_bwd 17.87 ± 0.70 14.36 ± 0.69 −19.6%*
nb_n1000_d5 5×128×5 first_posterior 18.12 ± 0.59 13.93 ± 0.63 −23.1%*
nb_n1000_d5 5×128×5 warm_rsample_bwd 13.38 ± 0.83 13.40 ± 0.52 +0.2%
nb_n1000_d5 5×128×5 cold_rsample_bwd 27.38 ± 0.78 23.43 ± 0.80 −14.4%*
nb_n1000_d5 32×128×5 first_posterior 34.75 ± 0.81 31.97 ± 1.08 −8.0%*
nb_n1000_d5 32×128×5 warm_rsample_bwd 61.53 ± 2.12 63.29 ± 2.36 +2.9%
nb_n1000_d5 32×128×5 cold_rsample_bwd 75.27 ± 2.28 72.79 ± 2.51 −3.3%
b64_n100_d5 128×5 warm_rsample_bwd 36.93 ± 1.44 37.30 ± 1.37 +1.0%
b64_n100_d5 128×5 cold_rsample_bwd 47.12 ± 2.34 44.78 ± 1.63 −5.0%
b64_n100_d5 64×128×5 warm_rsample_bwd 37.13 ± 1.59 37.77 ± 0.81 +1.7% (was +10.3%* at 5 seeds — noise)
b64_n100_d5 64×128×5 cold_rsample_bwd 45.51 ± 2.51 44.02 ± 1.32 −3.3%

Both flagged concerns resolve: the 5-seed +10.3%* warm regression at b64_n100 64×128 collapses to +1.7% (not significant), and n=1000 warm paths are neutral across all shapes while cold paths show large wins.

The Cholesky factor of the train-train covariance was being recomputed
on every posterior call. LinearOperator.solve dispatches through
Solve.apply, which reconstructs the lazy operator via representation_tree
and drops the @cached(cholesky) state on AddedDiagLinearOperator.

Materialize it once as a TriangularLinearOperator cached on the prediction
strategy, share across mean_cache and exact_predictive_covar, and switch
the covariance correction to a single triangular solve
(K(t,T) K_XX^-1 K(T,t) = M^T M with M = L^-1 K(T,t)), resolving the
existing TODO about using one triangular solve for dense inference.

Adds TestExactPredictiveCovar checking posterior mean, covariance, and
gradients w.r.t. test_x against a from-scratch math reference
(float64, atol=1e-10 / 1e-9).
SGPR's train-train covariance is a LowRankRootAddedDiagLinearOperator that
overrides .solve() to use a Woodbury identity with an inducing-size Cholesky
(k x k instead of n x n). The Cholesky-sharing optimization in the initial
commit bypassed this: _mean_cache called self.lik_train_train_chol, which
routes through the default LinearOperator._cholesky and materializes a
dense n x n factor — breaking test_sgpr_mean_abs_error's assertion that
every Cholesky during SGPR inference is of inducing size.

Extract the "ignore" nan_policy solve into a _solve_lik_train_train helper
on DefaultPredictionStrategy, then override it in SGPRPredictionStrategy to
dispatch to .evaluate_kernel().solve(), which uses the Woodbury path.
exact_predictive_covar is unchanged because SGPRPredictionStrategy overrides
that method entirely and never reaches the patched code path.
@saitcakmak saitcakmak requested review from Balandat and gpleiss April 20, 2026 18:16
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant