-
Notifications
You must be signed in to change notification settings - Fork 7
Expand file tree
/
Copy pathFiniteBasisModule.jl
More file actions
76 lines (59 loc) · 2.1 KB
/
FiniteBasisModule.jl
File metadata and controls
76 lines (59 loc) · 2.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
module FiniteBasisModule
using KernelFunctions,LinearAlgebra, AbstractGPs, Random
import AbstractGPs: AbstractGP, FiniteGP
import Statistics
import ChainRulesCore
"""
    FiniteBasis()

Kernel on explicit feature vectors whose value is the `KernelFunctions.DotProduct`
metric between its two inputs, passed through unchanged (`kappa` is the identity).
Combined with a feature map (see `FFApprox`), this gives a finite-rank kernel.
"""
struct FiniteBasis <: KernelFunctions.SimpleKernel
end

# A `SimpleKernel` is `kappa` applied to `metric`; here the metric already is the kernel value.
KernelFunctions.metric(::FiniteBasis) = KernelFunctions.DotProduct()

function KernelFunctions.kappa(::FiniteBasis, d::Real)
    return d
end
"""
    DegeneratePosterior(prior, w_mean, w_prec)

Posterior process for a GP with a `FiniteBasis` kernel, represented in weight space.
Instances are produced by `AbstractGPs.posterior` in this module.

# Fields
- `prior`: the GP that was conditioned.
- `w_mean`: posterior mean of the basis weights.
- `w_prec`: factorization of the posterior weight precision (a `Cholesky`, as built by
  `posterior`; its `.U` factor is used for sampling).
"""
struct DegeneratePosterior{TPrior,TMean,TPrec} <: AbstractGP
    prior::TPrior
    w_mean::TMean
    w_prec::TPrec
end
"""
    weight_form(A)

Return the design matrix of the inputs `A`: one row per input, one column per feature
dimension. For `FiniteBasis` the kernel matrix between `A` and `B` is then
`weight_form(A) * weight_form(B)'`.

Supported inputs:
- `ColVecs`: the wrapped matrix stores inputs as columns, so it is transposed (lazily).
- `RowVecs`: the wrapped matrix already stores inputs as rows.
- `AbstractVector{<:Real}`: scalar inputs, giving an `n×1` matrix.
"""
weight_form(A::KernelFunctions.ColVecs) = A.X'
weight_form(A::KernelFunctions.RowVecs) = A.X
# Scalar inputs carry a single feature each; the DotProduct of two scalars is x * y.
# No overlap with the ColVecs/RowVecs methods, whose elements are vectors.
weight_form(A::AbstractVector{<:Real}) = reshape(A, :, 1)
"""
    posterior(fx::FiniteGP, y::AbstractVector{<:Real})

Exact posterior for a `GP` whose kernel is a `FiniteBasis`, computed in weight space.

Write `X = weight_form(fx.x)` (one row per input). The model is then
`f(x) = m(x) + X * w`, with prior weights `w ~ N(0, I)`, prior mean `m`, and
observation noise covariance `fx.Σy`. Conditioning on `y` gives:

- weight precision `Λ = I + X' Σy⁻¹ X`;
- weight mean `Λ⁻¹ X' Σy⁻¹ (y - m(x))`.

Returns a `DegeneratePosterior` holding the prior GP, that weight mean and `cholesky(Λ)`.
"""
function AbstractGPs.posterior(
    fx::FiniteGP{GP{M,B}}, y::AbstractVector{<:Real}
) where {M,B<:FiniteBasis}
    X = weight_form(fx.x)
    # The basis weights only have to explain what the prior mean does not, so fit them
    # to the residual rather than to the raw observations.
    δ = y - Statistics.mean(fx)
    # X' Σy⁻¹ computed with a solve instead of an explicit inverse (Σy is symmetric).
    X_prec = (fx.Σy \ X)'
    prec = cholesky(I + Symmetric(X_prec * X))
    w = prec \ (X_prec * δ)
    return DegeneratePosterior(fx.f, w, prec)
end

"""
    mean(f::DegeneratePosterior, x::AbstractVector)

Posterior mean at `x`: the prior GP's mean at `x` plus `weight_form(x) * f.w_mean`.
"""
function Statistics.mean(f::DegeneratePosterior, x::AbstractVector)
    # `posterior` fits the weights to `y - m(x)`, so the prior mean is added back here.
    return Statistics.mean(f.prior, x) + weight_form(x) * f.w_mean
end
"""
    cov(f::DegeneratePosterior, x::AbstractVector)

Posterior covariance matrix at `x`. With `X = weight_form(x)` and `Λ` the weight
precision factorized in `f.w_prec`, this is `X Λ⁻¹ X'`.
"""
Statistics.cov(f::DegeneratePosterior, x::AbstractVector) =
    AbstractGPs.Xt_invA_X(f.w_prec, weight_form(x)')
"""
    cov(f::DegeneratePosterior, x::AbstractVector, y::AbstractVector)

Posterior cross-covariance between the inputs `x` and `y`. With `Φx = weight_form(x)`,
`Φy = weight_form(y)` and `Λ` the weight precision factorized in `f.w_prec`, this is
`Φx Λ⁻¹ Φy'`.
"""
function Statistics.cov(f::DegeneratePosterior, x::AbstractVector, y::AbstractVector)
    Φx, Φy = weight_form(x), weight_form(y)
    return AbstractGPs.Xt_invA_Y(Φx', f.w_prec, Φy')
end
"""
    var(f::DegeneratePosterior, x::AbstractVector)

Marginal posterior variances at `x`, i.e. the diagonal of `cov(f, x)`, obtained from
`AbstractGPs.diag_Xt_invA_X` with the design matrix `weight_form(x)`.
"""
Statistics.var(f::DegeneratePosterior, x::AbstractVector) =
    AbstractGPs.diag_Xt_invA_X(f.w_prec, weight_form(x)')
"""
    rand(rng::AbstractRNG, f::DegeneratePosterior, x::AbstractVector)

Draw one joint sample of the posterior process at the inputs `x`.

The sample is built in three steps:

1. Draw a weight perturbation `U⁻¹ z` with `z ~ N(0, I)`, where `U` is the upper
   Cholesky factor of the weight precision `Λ = U'U`. This perturbation has
   covariance `Λ⁻¹`.
2. Map it through the features of `x`.
3. Add the posterior mean `mean(f, x)`.
"""
function Random.rand(rng::AbstractRNG, f::DegeneratePosterior, x::AbstractVector)
    # The randomness lives in weight space: one standard normal per basis weight,
    # not one per input point.
    z = randn(rng, length(f.w_mean))
    # Centre on the posterior mean; reusing `mean` keeps sampling consistent with it.
    return Statistics.mean(f, x) + weight_form(x) * (f.w_prec.U \ z)
end
"""
    RandomFourierFeature(ws::Vector{Float64})

Random Fourier feature map defined by the sampled frequencies `ws`. Calling an instance
on an input `x` returns cosine and sine features of `ws .* x` (written for scalar `x`).
"""
struct RandomFourierFeature
    ws::Vector{Float64}  # sampled frequencies; each one yields a cos/sin feature pair
end
# Frequencies for the squared-exponential kernel exp(-d^2 / 2) are standard-normal draws,
# matching its spectral density. `kern` is only used to select this method.
RandomFourierFeature(rng::AbstractRNG, kern::SqExponentialKernel, k::Int) =
    RandomFourierFeature(randn(rng, k))
RandomFourierFeature(kern::SqExponentialKernel, k::Int) =
    RandomFourierFeature(Random.default_rng(), kern, k)

"""
    FFApprox([rng::AbstractRNG,] kern::Kernel, k::Int)

Build a kernel that approximates `kern` using `k` random Fourier frequencies. The
result is `FiniteBasis()` composed with a `FunctionTransform` of a
`RandomFourierFeature`: inputs are mapped to features, which are then compared by dot
product.
"""
FFApprox(rng::AbstractRNG, kern::Kernel, k::Int) =
    FiniteBasis() ∘ FunctionTransform(RandomFourierFeature(rng, kern, k))
FFApprox(kern::Kernel, k::Int) =
    FiniteBasis() ∘ FunctionTransform(RandomFourierFeature(kern, k))
"""
    (f::RandomFourierFeature)(x)

Evaluate the random Fourier feature map at a scalar input `x`.

Let `k = length(f.ws)`. The result is `[cos.(ws .* x); sin.(ws .* x)] ./ sqrt(k)`, so
the dot product of the features of `x` and `y` is `(1/k) Σᵢ cos(wᵢ (x - y))`. That sum
is a Monte Carlo estimate of the stationary kernel whose spectral density the
frequencies were drawn from. In particular, the features of `x` have unit squared norm,
matching `k(x, x) = 1` for the squared-exponential kernel.
"""
function (f::RandomFourierFeature)(x)
    ωx = f.ws .* x
    # Every frequency contributes a cos/sin pair with cos^2 + sin^2 = 1, so the scale is
    # 1/sqrt(k). The sqrt(2/k) factor belongs to the random-phase cos(ωx + b) form and
    # here would double the approximated kernel.
    # No forced Float64 eltype, so dual numbers / other Real types propagate (AD).
    return vcat(cos.(ωx), sin.(ωx)) ./ sqrt(length(f.ws))
end
end