Skip to content

Commit 1b7c135

Browse files
authored
Clarify API for GP approximations (#361)
* API docstrings for the base forms of `posterior` and `approx_log_evidence` * DTC as separate type * ExactInference for fallback forms * deprecations * bump version to 0.5.17
1 parent 3e5f0a5 commit 1b7c135

7 files changed

Lines changed: 214 additions & 153 deletions

File tree

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "AbstractGPs"
22
uuid = "99985d1d-32ba-4be9-9821-2ec096f28918"
33
authors = ["JuliaGaussianProcesses Team"]
4-
version = "0.5.16"
4+
version = "0.5.17"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/AbstractGPs.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ export rand!,
2727
mean_vector,
2828
marginals,
2929
logpdf,
30+
approx_log_evidence,
3031
elbo,
3132
dtc,
3233
posterior,

src/abstract_gp.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,3 +85,30 @@ for (m, f) in [
8585
)
8686
end
8787
end
88+
89+
"""
90+
approx_log_evidence(approx::<Approximation>, lfx::LatentFiniteGP, ys)
91+
92+
Compute an approximation to the log of the marginal likelihood (also known as
93+
"evidence") under the given `approx`imation to the posterior. The return value
94+
of `approx_log_evidence` can be used to optimise the hyperparameters of `lfx`.
95+
"""
96+
function approx_log_evidence end
97+
98+
"""
99+
posterior(fx::FiniteGP, y::AbstractVector{<:Real})
100+
posterior(approx::<Approximation>, fx::FiniteGP, y::AbstractVector{<:Real})
101+
posterior(approx::<Approximation>, lfx::LatentFiniteGP, y::AbstractVector)
102+
103+
Construct the posterior distribution over the latent Gaussian process (`fx.f`
104+
or `lfx.fx.f`), given the observations `y` corresponding to the process's
105+
finite projection (`fx` or `lfx`).
106+
107+
In the two-argument form, this describes exact GP regression with `y` observed
108+
under a Gaussian likelihood, and returns a `PosteriorGP`.
109+
110+
In the three-argument form, the first argument specifies the approximation to
111+
be used (e.g. `VFE`, or one defined in another package such as ApproximateGPs.jl),
112+
and returns an `ApproxPosteriorGP`.
113+
"""
114+
function posterior end

src/deprecations.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,6 @@
55
@deprecate sampleplot!(plt::RecipesBase.AbstractPlot, gp::FiniteGP, n::Int; kwargs...) sampleplot!(
66
plt, gp; samples=n, kwargs...
77
)
8+
9+
@deprecate elbo(dtc::DTC, fx, y) approx_log_evidence(dtc, fx, y)
10+
@deprecate dtc(vfe::Union{VFE,DTC}, fx, y) approx_log_evidence(vfe, fx, y)

src/exact_gpr_posterior.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,18 @@ struct PosteriorGP{Tprior,Tdata} <: AbstractGP
33
data::Tdata
44
end
55

6+
struct ExactInference end
7+
8+
posterior(::ExactInference, fx::FiniteGP, y::AbstractVector{<:Real}) = posterior(fx, y)
9+
10+
function approx_log_evidence(::ExactInference, fx::FiniteGP, y::AbstractVector{<:Real})
11+
return logpdf(fx, y)
12+
end
13+
614
"""
715
posterior(fx::FiniteGP, y::AbstractVector{<:Real})
816
9-
Construct the posterior distribution over `fx.f` given observations `y` at `x` made under
17+
Construct the posterior distribution over `fx.f` given observations `y` at `fx.x` made under
1018
noise `fx.Σy`. This is another `AbstractGP` object. See chapter 2 of [1] for a recap on
1119
exact inference in GPs. This posterior process has mean function
1220
```julia

src/sparse_approximations.jl

Lines changed: 44 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,14 @@ struct VFE{Tfz<:FiniteGP}
1313
fz::Tfz
1414
end
1515

16-
const DTC = VFE
16+
"""
17+
DTC(fz::FiniteGP)
18+
19+
Similar to `VFE`, but uses a different objective for `approx_log_evidence`.
20+
"""
21+
struct DTC{Tfz<:FiniteGP}
22+
fz::Tfz
23+
end
1724

1825
struct ApproxPosteriorGP{Tapprox,Tprior,Tdata} <: AbstractGP
1926
approx::Tapprox
@@ -48,7 +55,7 @@ true
4855
processes". In: Proceedings of the Twelfth International Conference on Artificial
4956
Intelligence and Statistics. 2009.
5057
"""
51-
function posterior(vfe::VFE, fx::FiniteGP, y::AbstractVector{<:Real})
58+
function posterior(vfe::Union{VFE,DTC}, fx::FiniteGP, y::AbstractVector{<:Real})
5259
@assert vfe.fz.f === fx.f
5360

5461
U_y = _cholesky(_symmetric(fx.Σy)).U
@@ -69,7 +76,7 @@ end
6976

7077
"""
7178
function update_posterior(
72-
f_post_approx::ApproxPosteriorGP{<:VFE},
79+
f_post_approx::ApproxPosteriorGP{<:Union{VFE,DTC}},
7380
fx::FiniteGP,
7481
y::AbstractVector{<:Real}
7582
)
@@ -78,7 +85,9 @@ Update the `ApproxPosteriorGP` given a new set of observations. Here, we retain
7885
set of pseudo-points.
7986
"""
8087
function update_posterior(
81-
f_post_approx::ApproxPosteriorGP{<:VFE}, fx::FiniteGP, y::AbstractVector{<:Real}
88+
f_post_approx::ApproxPosteriorGP{<:Union{VFE,DTC}},
89+
fx::FiniteGP,
90+
y::AbstractVector{<:Real},
8291
)
8392
@assert f_post_approx.prior === fx.f
8493

@@ -111,14 +120,14 @@ end
111120

112121
"""
113122
function update_posterior(
114-
f_post_approx::ApproxPosteriorGP{<:VFE},
123+
f_post_approx::ApproxPosteriorGP{<:Union{VFE,DTC}},
115124
z::FiniteGP,
116125
)
117126
118127
Update the `ApproxPosteriorGP` given a new set of pseudo-points to append to the existing
119128
set of pseudo-points.
120129
"""
121-
function update_posterior(f_post_approx::ApproxPosteriorGP{<:VFE}, fz::FiniteGP)
130+
function update_posterior(f_post_approx::ApproxPosteriorGP{<:Union{VFE,DTC}}, fz::FiniteGP)
122131
@assert f_post_approx.prior === fz.f
123132

124133
z_old = inducing_points(f_post_approx)
@@ -161,48 +170,56 @@ function update_posterior(f_post_approx::ApproxPosteriorGP{<:VFE}, fz::FiniteGP)
161170
x=f_post_approx.data.x,
162171
Σy=f_post_approx.data.Σy,
163172
)
164-
return ApproxPosteriorGP(VFE(fz_new), f_post_approx.prior, cache)
173+
return ApproxPosteriorGP(
174+
_update_approx(f_post_approx.approx, fz_new), f_post_approx.prior, cache
175+
)
165176
end
166177

178+
_update_approx(vfe::VFE, fz_new::FiniteGP) = VFE(fz_new)
179+
_update_approx(dtc::DTC, fz_new::FiniteGP) = DTC(fz_new)
180+
167181
# AbstractGP interface implementation.
168182

169-
function Statistics.mean(f::ApproxPosteriorGP{<:VFE}, x::AbstractVector)
183+
function Statistics.mean(f::ApproxPosteriorGP{<:Union{VFE,DTC}}, x::AbstractVector)
170184
return mean(f.prior, x) + cov(f.prior, x, inducing_points(f)) * f.data.α
171185
end
172186

173-
function Statistics.cov(f::ApproxPosteriorGP{<:VFE}, x::AbstractVector)
187+
function Statistics.cov(f::ApproxPosteriorGP{<:Union{VFE,DTC}}, x::AbstractVector)
174188
A = f.data.U' \ cov(f.prior, inducing_points(f), x)
175189
return cov(f.prior, x) - At_A(A) + Xt_invA_X(f.data.Λ_ε, A)
176190
end
177191

178-
function Statistics.var(f::ApproxPosteriorGP{<:VFE}, x::AbstractVector)
192+
function Statistics.var(f::ApproxPosteriorGP{<:Union{VFE,DTC}}, x::AbstractVector)
179193
A = f.data.U' \ cov(f.prior, inducing_points(f), x)
180194
return var(f.prior, x) - diag_At_A(A) + diag_Xt_invA_X(f.data.Λ_ε, A)
181195
end
182196

183-
function Statistics.cov(f::ApproxPosteriorGP{<:VFE}, x::AbstractVector, y::AbstractVector)
197+
function Statistics.cov(
198+
f::ApproxPosteriorGP{<:Union{VFE,DTC}}, x::AbstractVector, y::AbstractVector
199+
)
184200
A_zx = f.data.U' \ cov(f.prior, inducing_points(f), x)
185201
A_zy = f.data.U' \ cov(f.prior, inducing_points(f), y)
186202
return cov(f.prior, x, y) - A_zx'A_zy + Xt_invA_Y(A_zx, f.data.Λ_ε, A_zy)
187203
end
188204

189-
function StatsBase.mean_and_cov(f::ApproxPosteriorGP{<:VFE}, x::AbstractVector)
205+
function StatsBase.mean_and_cov(f::ApproxPosteriorGP{<:Union{VFE,DTC}}, x::AbstractVector)
190206
A = f.data.U' \ cov(f.prior, inducing_points(f), x)
191207
m_post = mean(f.prior, x) + A' * f.data.m_ε
192208
C_post = cov(f.prior, x) - At_A(A) + Xt_invA_X(f.data.Λ_ε, A)
193209
return m_post, C_post
194210
end
195211

196-
function StatsBase.mean_and_var(f::ApproxPosteriorGP{<:VFE}, x::AbstractVector)
212+
function StatsBase.mean_and_var(f::ApproxPosteriorGP{<:Union{VFE,DTC}}, x::AbstractVector)
197213
A = f.data.U' \ cov(f.prior, inducing_points(f), x)
198214
m_post = mean(f.prior, x) + A' * f.data.m_ε
199215
c_post = var(f.prior, x) - diag_At_A(A) + diag_Xt_invA_X(f.data.Λ_ε, A)
200216
return m_post, c_post
201217
end
202218

203-
inducing_points(f::ApproxPosteriorGP{<:VFE}) = f.approx.fz.x
219+
inducing_points(f::ApproxPosteriorGP{<:Union{VFE,DTC}}) = f.approx.fz.x
204220

205221
"""
222+
approx_log_evidence(vfe::VFE, fx::FiniteGP, y::AbstractVector{<:Real})
206223
elbo(vfe::VFE, fx::FiniteGP, y::AbstractVector{<:Real})
207224
208225
The Titsias Evidence Lower BOund (ELBO) [1]. `y` are observations of `fx`, and `vfe.fz.x`
@@ -228,14 +245,16 @@ true
228245
processes". In: Proceedings of the Twelfth International Conference on Artificial
229246
Intelligence and Statistics. 2009.
230247
"""
231-
function elbo(vfe::VFE, fx::FiniteGP, y::AbstractVector{<:Real})
248+
function approx_log_evidence(vfe::VFE, fx::FiniteGP, y::AbstractVector{<:Real})
232249
@assert vfe.fz.f === fx.f
233-
_dtc, A = _compute_intermediates(fx, y, vfe.fz)
234-
return _dtc - (tr_Cf_invΣy(fx, fx.Σy) - sum(abs2, A)) / 2
250+
dtc_objective, A = _compute_intermediates(fx, y, vfe.fz)
251+
return dtc_objective - (tr_Cf_invΣy(fx, fx.Σy) - sum(abs2, A)) / 2
235252
end
236253

254+
elbo(vfe::VFE, fx, y) = approx_log_evidence(vfe, fx, y)
255+
237256
"""
238-
dtc(v::VFE, fx::FiniteGP, y::AbstractVector{<:Real})
257+
approx_log_evidence(dtc::DTC, fx::FiniteGP, y::AbstractVector{<:Real})
239258
240259
The Deterministic Training Conditional (DTC) [1]. `y` are observations of `fx`, and `dtc.fz.x`
241260
are inducing points.
@@ -248,25 +267,25 @@ julia> x = randn(1000);
248267
249268
julia> z = range(-5.0, 5.0; length=256);
250269
251-
julia> v = VFE(f(z));
270+
julia> d = DTC(f(z));
252271
253272
julia> y = rand(f(x, 0.1));
254273
255-
julia> isapprox(dtc(v, f(x, 0.1), y), logpdf(f(x, 0.1), y); atol=1e-6, rtol=1e-6)
274+
julia> isapprox(approx_log_evidence(d, f(x, 0.1), y), logpdf(f(x, 0.1), y); atol=1e-6, rtol=1e-6)
256275
true
257276
```
258277
259278
[1] - M. Seeger, C. K. I. Williams and N. D. Lawrence. "Fast Forward Selection to Speed Up
260279
Sparse Gaussian Process Regression". In: Proceedings of the Ninth International Workshop on
261280
Artificial Intelligence and Statistics. 2003
262281
"""
263-
function dtc(vfe::VFE, fx::FiniteGP, y::AbstractVector{<:Real})
264-
@assert vfe.fz.f === fx.f
265-
_dtc, _ = _compute_intermediates(fx, y, vfe.fz)
266-
return _dtc
282+
function approx_log_evidence(dtc::DTC, fx::FiniteGP, y::AbstractVector{<:Real})
283+
@assert dtc.fz.f === fx.f
284+
dtc_objective, _ = _compute_intermediates(fx, y, dtc.fz)
285+
return dtc_objective
267286
end
268287

269-
# Factor out computations common to the `elbo` and `dtc`.
288+
# Factor out computations of `approx_log_evidence` common to `VFE` and `DTC`
270289
function _compute_intermediates(fx::FiniteGP, y::AbstractVector{<:Real}, fz::FiniteGP)
271290
length(fx) == length(y) || throw(
272291
DimensionMismatch(

0 commit comments

Comments
 (0)