Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6"
KernelFunctions = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
KernelSpectralDensities = "027d52a2-76e5-4228-9bfe-bc7e0f5a8348"
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you move this into a package extension?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Apologies, it has been a minute.

The short answer is that I am not sure. On the one hand, it would make sense to keep everything clean, but

  • KernelSpectralDensities is a tiny package
  • The main reason it exists is to enable decoupled sampling with Fourier basis functions, which is (AFAIK) still the state of the art for sampling GPs
  • (the main uncertainty for me) I am not sure where to put some of the objects, since they only make sense when using the combination of both packages.

LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand All @@ -24,6 +25,7 @@ Distributions = "0.19, 0.20, 0.21, 0.22, 0.23, 0.24, 0.25"
FillArrays = "0.7, 0.8, 0.9, 0.10, 0.11, 0.12, 0.13, 1"
IrrationalConstants = "0.1, 0.2"
KernelFunctions = "0.9, 0.10"
KernelSpectralDensities = "0.2.0"
LinearAlgebra = "1"
PDMats = "0.11"
Random = "1"
Expand Down
5 changes: 5 additions & 0 deletions src/AbstractGPs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ using RecipesBase
using IrrationalConstants: log2π

using KernelFunctions: ColVecs, RowVecs
using KernelSpectralDensities

using ChainRulesCore: ChainRulesCore

Expand All @@ -33,6 +34,7 @@ export rand!,
posterior,
update_posterior
export ColVecs, RowVecs
export GPSampler, CholeskySampling, Conditional, Independent, RFFSampling, PathwiseSampling

# Various bits of utility functionality.
include("util/common_covmat_ops.jl")
Expand All @@ -56,6 +58,9 @@ include("sparse_approximations.jl")
# LatentGP and LatentFiniteGP objects to accommodate GPs with non-Gaussian likelihoods.
include("latent_gp.jl")

# Different sampling methods
include("sampling.jl")

# Plotting utilities.
include("util/plotting.jl")

Expand Down
316 changes: 316 additions & 0 deletions src/sampling.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,316 @@
abstract type AbstractGPSamplingMethod end

# RNG types that are constructed here directly from an integer seed.
# Declared `const` so the alias is not an untyped, reassignable global.
const SeedableRNG = Union{Xoshiro,MersenneTwister}

# Generic fallback: draw from `d` using `rng`.
_rand(rng, d) = Random.rand(rng, d)

# Given a seedable RNG *type* `T`, return a new generator of that type whose seed is
# drawn from `rng` (uniformly from `1:typemax(Int)`).
function _rand(rng::AbstractRNG, ::Type{T}) where {T<:SeedableRNG}
    seed = Random.rand(rng, 1:typemax(Int))
    return T(seed)
end

# ## Interface

"""
    GPSample(fun, sample)

A single draw produced by a `GPSampler`.

Calling the object at an input forwards `fun`, `sample` and the input to `eval_at`.
"""
struct GPSample{FunT,SampleT}
    fun::FunT
    sample::SampleT
end

# Evaluate the sample at `x` by handing the stored `fun` and `sample` to `eval_at`.
function (gs::GPSample)(x)
    return eval_at(gs.fun, gs.sample, x)
end

# Scalar input: wrap it in a one-element vector and unwrap the single result.
# This may become more challenging once we extend to multi-input GPs
function (gs::GPSample)(x::Number)
    values = gs([x])
    return only(values)
end

# `(value, Int)` tuple input: evaluate directly and unwrap the single result.
# NOTE(review): presumably the `Int` is an output index for multi-output GPs — confirm.
function (gs::GPSample)(x::Tuple{T,Int}) where {T}
    value = eval_at(gs.fun, gs.sample, x)
    return only(value)
end

"""
    GPSampler(gp::AbstractGPs.AbstractGP, method::AbstractGPSamplingMethod; dims=Val(:auto))

Create a sampler for `gp` that draws function samples using `method`.

`method` is called as `method(gp, dims)` and must return a `(fun, sampler)` pair; both
are stored in the fields of the same name. `dims` is forwarded to `method` unchanged.

```jldoctest
julia> f = GP(Matern32Kernel());

julia> gps = GPSampler(f, CholeskySampling());

julia> rand(gps);
```
"""
struct GPSampler{FunT,SamplerT} <: Random.Sampler{GPSample}
    fun::FunT
    sampler::SamplerT

    # Argument types are annotated because this constructor is the "public" interface.
    function GPSampler(
        gp::AbstractGPs.AbstractGP, method::AbstractGPSamplingMethod; dims=Val(:auto)
    )
        parts = method(gp, dims)
        fun, sampler = parts
        return new{typeof(fun),typeof(sampler)}(fun, sampler)
    end
end

# Draw one `GPSample` from the sampler.
# The stored `fun` is deep-copied so each sample owns its own instance: objects are
# passed by sharing, and the `Conditional` object used by `CholeskySampling` is a
# mutable struct. The deepcopy is not ideal, but avoids samples sharing state.
function Random.rand(rng::AbstractRNG, gs::GPSampler)
    fun = deepcopy(gs.fun)
    sample = _rand(rng, gs.sampler)
    return GPSample(fun, sample)
end

# ## Utils

# Return the prior of a GP object: a plain `GP` is returned as is.
function _get_prior(gp::AbstractGPs.GP)
    return gp
end

# For a posterior, return its stored `prior` field.
function _get_prior(pgp::AbstractGPs.PosteriorGP)
    return pgp.prior
end

# Recover per-observation variances from a posterior.
# Takes the diagonal of `L * U` of the stored Cholesky factorisation, subtracts the prior
# kernel evaluated at the first stored input with itself, and clamps the result from
# below at `default_σ²`.
# NOTE(review): only the first input is used for the kernel value, which presumably
# assumes a constant prior variance across inputs — confirm.
function get_obs_variance(pgp::AbstractGPs.PosteriorGP)
    x1 = pgp.data.x[1]
    prior_var = pgp.prior.kernel(x1, x1)
    chol = pgp.data.C
    noise = diag(chol.L * chol.U) .- prior_var
    return max.(noise, default_σ²)
end

# Build the `MvNormal` passed on as target distribution: its first argument is the
# posterior's stored `δ`, its second the element-wise square root of the variances
# returned by `get_obs_variance`.
function get_target_prior(pgp::AbstractGPs.PosteriorGP)
    noise_var = get_obs_variance(pgp)
    return MvNormal(pgp.data.δ, sqrt.(noise_var))
end

#########################
# Function Space/ Cholesky

"""
    CholeskySampling(s=Conditional, generator=Xoshiro)

Sampling method that uses the standard approach via Cholesky decomposition.

# Arguments
- `s`: Sampling type, either `Conditional` or `Independent`. Default is `Conditional`.
- `generator`: Random number generator type to use in each sample. Default is `Xoshiro`.
"""
struct CholeskySampling{M,G} <: AbstractGPSamplingMethod
    # Both choices are stored purely as type parameters; the struct has no fields.
    CholeskySampling(s=Conditional, generator=Xoshiro) = new{s,generator}()
end

# Applying the method to a GP wraps the GP in the sampling type `M` and returns it
# together with the generator type `G`. `dims` is accepted but not used here.
(::CholeskySampling{M,G})(gp, dims) where {M,G} = (M(gp), G)

"""
    Conditional

Sampling type for which every new function evaluation is conditioned on all values
drawn previously. Mutable, because the stored GP is replaced after each evaluation.
"""
mutable struct Conditional{G<:AbstractGPs.AbstractGP}
    gp::G
end

# Build a `Conditional` from a prior by wrapping it in a `PosteriorGP` whose data
# (`α`, `C`, `x`, `δ`) are all empty `Float64` containers.
function Conditional(gp::AbstractGPs.GP)
    empty_chol = Cholesky(UpperTriangular(zeros(Float64, 0, 0)))
    placeholder = (α=Float64[], C=empty_chol, x=Float64[], δ=Float64[])
    return Conditional(AbstractGPs.PosteriorGP(gp, placeholder))
end

# Draw function values at `x` and condition the stored GP on them.
# While no inputs have been recorded yet (first element of `data.x` unassigned) the
# prior is used; afterwards the stored posterior is used. The stored GP is then replaced
# by the posterior given the new draw, so later calls are consistent with this one.
function eval_at(s::Conditional, rng, x::AbstractArray)
    current = isassigned(s.gp.data.x, 1) ? s.gp : s.gp.prior
    fx = current(x)
    draw = rand(rng, fx)
    s.gp = posterior(fx, draw)
    return draw
end

"""
    Independent

Sampling type for which every function evaluation is drawn independently of all
previous evaluations.
"""
struct Independent{GPT<:AbstractGPs.AbstractGP}
    gp::GPT

    Independent(gp) = new{typeof(gp)}(gp)
end

# Draw function values at `x` from the stored GP; nothing is remembered between calls.
eval_at(s::Independent, rng, x::AbstractArray) = rand(rng, s.gp(x))

# ## WeightSpace

# ### Utils

# Draw `n` values from a `Normal` distribution given as a `(distribution, n)` tuple.
function _rand(rng, t::Tuple{Normal,Int})
    dist, n = t
    return rand(rng, dist, n)
end

# For a prior GP the weights are described by a standard `Normal` paired with the
# number of weights, `length(rff)`.
function get_weight_distribution(::AbstractGPs.GP, rff)
    return (Normal(), length(rff))
end

# Weight distribution for a posterior GP.
# Evaluates `rff` at every stored input, stacks the results column-wise into `Φt`, and
# combines them with the target distribution from `get_target_prior`:
#   precision = Φt * Σ⁻¹ * Φt' + I,   mean = precision⁻¹ * Φt * Σ⁻¹ * μ,
# where `Σ` and `μ` are taken from that target distribution. Returns an `MvNormal` with
# this mean and the inverse precision as covariance.
# NOTE(review): presumably the Gaussian posterior over feature weights under a standard
# normal weight prior — confirm against the reference derivation.
function get_weight_distribution(pgp::AbstractGPs.PosteriorGP, rff)
    target = get_target_prior(pgp)

    features = rff.(pgp.data.x)
    Φt = reduce(hcat, features)
    precision = Symmetric(Φt * (target.Σ \ Φt') + I)
    precision_chol = cholesky(precision)

    weight_mean = precision_chol \ (Φt * (target.Σ \ target.μ))
    weight_cov = precision_chol \ I
    return MvNormal(weight_mean, Symmetric(weight_cov))
end

# ### Main

"""
    RFFSampling([rng,] l; rff_type::Type{<:KernelSpectralDensities.AbstractRFF}=DoubleRFF)

Sampling by using Random Fourier Features.

# Arguments
- `rng`: Random number generator stored in the method and used when the features are
  sampled. When omitted, `Random.default_rng()` is used.
- `l`: Number of random Fourier features to use.

# Keywords
- `rff_type`: Type of random Fourier features to use. Default is `DoubleRFF`.
"""
struct RFFSampling{RFF,RNG} <: AbstractGPSamplingMethod
    l::Int
    rng::RNG

    # The feature type is only recorded as the first type parameter.
    function RFFSampling(
        rng, l; rff_type::Type{<:KernelSpectralDensities.AbstractRFF}=DoubleRFF
    )
        RNGT = typeof(rng)
        return new{rff_type,RNGT}(l, rng)
    end
end

# Convenience constructor: same as the two-argument form with `Random.default_rng()`.
function RFFSampling(l; rff_type::Type{<:KernelSpectralDensities.AbstractRFF}=DoubleRFF)
    default = Random.default_rng()
    return RFFSampling(default, l; rff_type=rff_type)
end

# Vector of tuples: the first dimension is the length of the first component of the
# first element, the second is read from the container's `out_dim` property.
function _extract_dims(x::AbstractVector{<:Tuple})
    first_input = x[1][1]
    return (length(first_input), x.out_dim)
end

# Any other vector: the first dimension is the length of the first element, the second
# is fixed to 1.
function _extract_dims(x::AbstractVector)
    return (length(x[1]), 1)
end

# With `Val(:auto)` the dimensions of a posterior are inferred from its stored inputs.
function determine_dims(pgp::AbstractGPs.PosteriorGP, ::Val{:auto})
    in_dim, out_dim = _extract_dims(pgp.data.x)
    return (in_dim, out_dim)
end

# Check explicitly specified `dims` against the dimensions inferred from the posterior's
# stored inputs. Returns the inferred dimensions when both agree and throws an
# `ArgumentError` otherwise.
# (Fixes a misspelled variable that made every call throw `UndefVarError`.)
function determine_dims(pgp::AbstractGPs.PosteriorGP, dims)
    det_dims = determine_dims(pgp, Val(:auto))
    if det_dims != dims
        throw(
            ArgumentError(
                "Specified dims $dims do not match dimensions inferred from data $(det_dims).",
            ),
        )
    end
    return det_dims
end

# A GP that is not a posterior carries no data to infer dimensions from, so `Val(:auto)`
# is rejected with an `ArgumentError`.
function determine_dims(::AbstractGPs.AbstractGP, ::Val{:auto})
    msg = "Cannot determine input/output dimensions for a non-posterior GP. Please specify dims explicitly."
    throw(ArgumentError(msg))
end

# Explicitly specified dims are passed through unchanged.
function determine_dims(::AbstractGPs.AbstractGP, dims)
    return dims
end

# Applying the method to a GP samples a feature map for the prior kernel and pairs it
# with the matching weight distribution.
# There is currently no way to infer the input domain of a prior GP; additional optional
# arguments for the `GPSampler` may be an option.
function (rffs::RFFSampling{RFF})(gp::AbstractGPs.AbstractGP, dims) where {RFF}
    kernel = _get_prior(gp).kernel
    resolved = determine_dims(gp, dims)

    rff = sample_rff(rffs.rng, kernel, rffs.l, resolved...; rff_type=RFF)
    weights = get_weight_distribution(gp, rff)

    return rff, weights
end

# Evaluate a weight-space sample at a single input: inner product of the features at
# `x` with the weights `w`.
eval_at(rff::KernelSpectralDensities.AbstractRFF, w, x) = dot(rff(x), w)

# Same evaluation for `AbstractMORFF` features: the result of `rff(x)` is multiplied
# with the weights `w`.
eval_at(rff::KernelSpectralDensities.AbstractMORFF, w, x) = rff(x) * w

# ## PathwiseSampler

# ### utils
# Basis functions built from a kernel `ker` and a set of stored inputs `x`.
# The container type of `x` is a type parameter, so the field is concretely typed.
# `KernelBasis{<:SomeKernel}` still matches for dispatch on the kernel type.
struct KernelBasis{K,X<:AbstractArray}
    ker::K
    x::X
end

# Evaluate the kernel between the query `x` and every stored input.
# `Ref` stops the query from being iterated by the broadcast.
function (kb::KernelBasis)(x)
    query = Ref(x)
    return broadcast(kb.ker, query, kb.x)
end

# For a Cholesky-based update the basis is the prior kernel paired with the posterior's
# stored inputs. The method object itself is only used for dispatch.
function update_basis(pgp, cs::CholeskySampling)
    return KernelBasis(pgp.prior.kernel, pgp.data.x)
end

# For an RFF-based update the basis is the feature map produced by applying the method
# to the posterior; the accompanying weight distribution is discarded.
# The method functor requires a `dims` argument; `Val(:auto)` requests that the
# dimensions be inferred from the posterior's data. (Previously it was called with a
# single argument, for which no method exists.)
function update_basis(pgp, rffs::RFFSampling)
    rff, _ = rffs(pgp, Val(:auto))
    return rff
end

# ### Main

"""
    PathwiseSampling(l::Int)
    PathwiseSampling(prior, update)

Sampling by using pathwise sampling, which combines a sampling method for the prior with
an update rule based on the kernel.

The one-argument form uses RFF sampling with `l` random Fourier features for the prior
and `CholeskySampling()` for the update.
"""
struct PathwiseSampling{PriorT,UpdateT} <: AbstractGPSamplingMethod
    prior::PriorT
    update::UpdateT
end

# Default combination: RFF prior with `l` features and a Cholesky-based update.
PathwiseSampling(l::Int) = PathwiseSampling(RFFSampling(l), CholeskySampling())

# Bundles what is needed to draw one pathwise sample: a sampler for the prior, a sampler
# for the target values, and a `data` object from which `C` and `x` are read.
struct PathwiseSampler{PriorS,TargetS,DataT}
    prior_sampler::PriorS
    target_sampler::TargetS
    data::DataT
end

# Applying the method to a posterior returns the update basis together with a
# `PathwiseSampler` holding a sampler for the prior, the target distribution, and the
# posterior's `C` and `x`.
# The update basis is built before the prior sampler; the order is kept because both
# steps may consume random numbers.
function (ps::PathwiseSampling)(pgp::AbstractGPs.PosteriorGP, dims)
    basis = update_basis(pgp, ps.update)

    resolved = determine_dims(pgp, dims)
    prior_sampler = GPSampler(pgp.prior, ps.prior; dims=resolved)

    target = get_target_prior(pgp)
    stored = (C=pgp.data.C, x=pgp.data.x)

    return basis, PathwiseSampler(prior_sampler, target, stored)
end

# Draw one pathwise sample: a prior function sample plus the weights `v` of the update.
# `v` solves `C * v = target draw - prior sample evaluated at the stored inputs`.
function _rand(rng::AbstractRNG, ps::PathwiseSampler)
    prior_draw = rand(rng, ps.prior_sampler)
    prior_at_data = prior_draw.(ps.data.x)

    target_draw = rand(rng, ps.target_sampler)
    residual = target_draw - prior_at_data

    return (prior=prior_draw, v=ps.data.C \ residual)
end

# Evaluate a pathwise sample at `x`: the prior sample's value plus the inner product of
# the kernel basis at `x` with the update weights `s.v`.
function eval_at(basis::KernelBasis, s, x)
    prior_value = s.prior(x)
    correction = dot(basis(x), s.v)
    return prior_value + correction
end

# ToDo: only to get results asap
# Multi-output kernels: wrap `x` as the single element of an `MOInput` and evaluate the
# sample at each of its elements.
# NOTE(review): `dims` is read from the `pr` field of the prior sample's feature map and
# its second entry is passed as the second `MOInput` argument — presumably the number of
# outputs; confirm.
function (gs::GPSample{<:KernelBasis{<:MOKernel}})(x::AbstractVector)
    dims = gs.sample.prior.fun.pr
    inputs = MOInput([x], dims[2])
    return broadcast(eval_at, Ref(gs.fun), Ref(gs.sample), inputs)
end
4 changes: 4 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,10 @@ include("test_util.jl")
println(" ")
@info "Ran latent_gp tests"

include("sampling.jl")
println(" ")
@info "Ran sampling tests"

include("deprecations.jl")
println(" ")
@info "Ran deprecation tests"
Expand Down
Loading
Loading