From 066df1bbe64985543e8b8e6dd7039ae64b755665 Mon Sep 17 00:00:00 2001 From: Sam Anklesaria Date: Sun, 18 Feb 2024 22:02:15 -0600 Subject: [PATCH 1/3] Add finite basis approximation --- Project.toml | 2 ++ src/ApproximateGPs.jl | 4 +++ src/FiniteBasisModule.jl | 60 +++++++++++++++++++++++++++++++++++++++ test/FiniteBasisModule.jl | 35 +++++++++++++++++++++++ test/runtests.jl | 4 +++ 5 files changed, 105 insertions(+) create mode 100644 src/FiniteBasisModule.jl create mode 100644 test/FiniteBasisModule.jl diff --git a/Project.toml b/Project.toml index 02a02f68..bd707eec 100644 --- a/Project.toml +++ b/Project.toml @@ -5,6 +5,7 @@ version = "0.4.5" [deps] AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918" +ArraysOfArrays = "65a8f2f4-9b39-5baf-92e2-a9cc46fdf018" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" FastGaussQuadrature = "442a2c76-b920-505d-bb47-c5924d526838" @@ -12,6 +13,7 @@ FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" GPLikelihoods = "6031954c-0455-49d7-b3b9-3e1c99afaf40" IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6" +KernelFunctions = "ec8451be-7e33-11e9-00cf-bbf324bd1392" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" diff --git a/src/ApproximateGPs.jl b/src/ApproximateGPs.jl index 029be317..a3bc4de9 100644 --- a/src/ApproximateGPs.jl +++ b/src/ApproximateGPs.jl @@ -19,6 +19,10 @@ include("LaplaceApproximationModule.jl") @reexport using .LaplaceApproximationModule: build_laplace_objective, build_laplace_objective! +include("FiniteBasisModule.jl") +@reexport using .FiniteBasisModule: + RandomFourierFeature, FiniteBasis, DegeneratePosterior + include("deprecations.jl") include("TestUtils.jl") diff --git a/src/FiniteBasisModule.jl b/src/FiniteBasisModule.jl new file mode 100644 index 00000000..8062b296 --- /dev/null +++ b/src/FiniteBasisModule.jl @@ -0,0 +1,60 @@ +module FiniteBasisModule + +using KernelFunctions,LinearAlgebra, AbstractGPs, ArraysOfArrays +import AbstractGPs: AbstractGP, FiniteGP +import Statistics + +struct FiniteBasis{T} <: Kernel + ϕ::T +end + +(k::FiniteBasis)(x, y) = dot(k.ϕ(x), k.ϕ(y)) + +struct DegeneratePosterior{P,T,C} <: AbstractGP + prior::P + w_mean::T + w_prec::C +end + +weight_form(ϕ, x) = flatview(ArrayOfSimilarArrays(ϕ.(x)))' + +function AbstractGPs.posterior(fx::FiniteGP{GP{M, B}}, y::AbstractVector{<:Real}) where {M, B <: FiniteBasis} + kern = fx.f.kernel + δ = y - mean(fx) + X = weight_form(kern.ϕ, fx.x) + X_prec = X' * inv(fx.Σy) + Λμ = X_prec * y + prec = cholesky(I + Symmetric(X_prec * X)) + w = prec \ Λμ + DegeneratePosterior(fx.f, w, prec) +end + +function Statistics.mean(f::DegeneratePosterior, x::AbstractVector) + w = f.w_mean + X = weight_form(f.prior.kernel.ϕ, x) + X * w +end + +function Statistics.cov(f::DegeneratePosterior, x::AbstractVector) + X = weight_form(f.prior.kernel.ϕ, x) + AbstractGPs.Xt_invA_X(f.w_prec, X') +end + +function Statistics.var(f::DegeneratePosterior, x::AbstractVector) + X = weight_form(f.prior.kernel.ϕ, x) + AbstractGPs.diag_Xt_invA_X(f.w_prec, X') +end + +struct RandomFourierFeature + ws::Vector{Float64} +end + +RandomFourierFeature(kern::SqExponentialKernel, k::Int) = RandomFourierFeature(randn(k)) +RandomFourierFeature(rng, kern::SqExponentialKernel, k::Int) = RandomFourierFeature(randn(rng, k)) + + +function (f::RandomFourierFeature)(x) + Float64[cos.(f.ws .* x); sin.(f.ws .* x)] .* sqrt(2/length(f.ws)) +end + +end \ No newline at end of file diff --git a/test/FiniteBasisModule.jl b/test/FiniteBasisModule.jl new file mode 100644 index 00000000..1c6b47f7 --- /dev/null +++ b/test/FiniteBasisModule.jl @@ -0,0 +1,35 @@ +@testset "finite_basis" begin + rng = MersenneTwister(123456) + N = 50 + x = rand(rng, 2, N); + y = sin.(norm.(eachcol(x))) + + @testset "Verify equivalence of weight space and function space posteriors" begin + kern = FiniteBasis(identity) + x2 = ColVecs(rand(2, N)) + + # Predict mean and covariance using weight space view + f = GP(kern) + fx = f(x, 0.001) + opt_pred = mean_and_cov(posterior(fx, y)(x2)) + + # Predict mean and covariance as normal + fx2 = GP(kern + ZeroKernel())(x, 0.001) + pred = mean_and_cov(posterior(fx2, y)(x2)) + + # The two approaches should be the same + @test all(opt_pred .≈ pred) + end + + @testset "Verify that the RFF approximation matches the exact posterior" begin + rng = MersenneTwister(12345) + rbf = SqExponentialKernel() + flat_x = rand(rng, N) + flat_x2 = rand(rng, N) + ffkern = FiniteBasis(RandomFourierFeature(rng, rbf, 200)) + + opt_pred = mean_and_cov(posterior(GP(ffkern)(flat_x, 0.001), y)(flat_x2)) + pred = mean_and_cov(posterior(GP(rbf)(flat_x, 0.001), y)(flat_x2)) + @test all(isapprox.(opt_pred, pred; atol=1e-2)) + end +end diff --git a/test/runtests.jl b/test/runtests.jl index fa26a951..46009ddd 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -61,6 +61,10 @@ include("test_utils.jl") include("LaplaceApproximationModule.jl") println(" ") @info "Ran laplace tests" + + include("FiniteBasisModule.jl") + println(" ") + @info "Ran finite basis tests" end if GROUP == "All" || GROUP == "CUDA" From 136d938325f9d6aa60fcf471fcdf0b96bc842292 Mon Sep 17 00:00:00 2001 From: Sam Anklesaria Date: Tue, 27 Feb 2024 13:19:30 -0600 Subject: [PATCH 2/3] Add rand method --- src/FiniteBasisModule.jl | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/src/FiniteBasisModule.jl b/src/FiniteBasisModule.jl index 8062b296..97903c18 100644 --- a/src/FiniteBasisModule.jl +++ b/src/FiniteBasisModule.jl @@ -1,6 +1,6 @@ module FiniteBasisModule -using KernelFunctions,LinearAlgebra, AbstractGPs, ArraysOfArrays +using KernelFunctions,LinearAlgebra, AbstractGPs, ArraysOfArrays, Random import AbstractGPs: AbstractGP, FiniteGP import Statistics @@ -40,17 +40,29 @@ function Statistics.cov(f::DegeneratePosterior, x::AbstractVector) AbstractGPs.Xt_invA_X(f.w_prec, X') end +function Statistics.cov(f::DegeneratePosterior, x::AbstractVector, y::AbstractVector) + X = weight_form(f.prior.kernel.ϕ, x) + Y = weight_form(f.prior.kernel.ϕ, y) + AbstractGPs.Xt_invA_Y(X', f.w_prec, Y') +end + function Statistics.var(f::DegeneratePosterior, x::AbstractVector) X = weight_form(f.prior.kernel.ϕ, x) AbstractGPs.diag_Xt_invA_X(f.w_prec, X') end +function Statistics.rand(rng::AbstractRNG, f::DegeneratePosterior, x::AbstractVector) + w = f.w_mean + X = weight_form(f.prior.kernel.ϕ, x) + X * (f.w_prec.U \ randn(rng, length(x))) +end + struct RandomFourierFeature ws::Vector{Float64} end RandomFourierFeature(kern::SqExponentialKernel, k::Int) = RandomFourierFeature(randn(k)) -RandomFourierFeature(rng, kern::SqExponentialKernel, k::Int) = RandomFourierFeature(randn(rng, k)) +RandomFourierFeature(rng::AbstractRNG, kern::SqExponentialKernel, k::Int) = RandomFourierFeature(randn(rng, k)) function (f::RandomFourierFeature)(x) From 68eb6128e57ca8f56235184eba1aa1de37f638aa Mon Sep 17 00:00:00 2001 From: Sam Anklesaria Date: Tue, 27 Feb 2024 15:42:01 -0600 Subject: [PATCH 3/3] Use SimpleKernel in FiniteBasis --- Project.toml | 1 - src/ApproximateGPs.jl | 2 +- src/FiniteBasisModule.jl | 30 +++++++++++++++++------------- test/FiniteBasisModule.jl | 4 ++-- 4 files changed, 20 insertions(+), 17 deletions(-) diff --git a/Project.toml b/Project.toml index bd707eec..6ce50610 100644 --- a/Project.toml +++ b/Project.toml @@ -5,7 +5,6 @@ version = "0.4.5" [deps] AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918" -ArraysOfArrays = "65a8f2f4-9b39-5baf-92e2-a9cc46fdf018" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" FastGaussQuadrature = "442a2c76-b920-505d-bb47-c5924d526838" diff --git a/src/ApproximateGPs.jl b/src/ApproximateGPs.jl index a3bc4de9..41b7c116 100644 --- a/src/ApproximateGPs.jl +++ b/src/ApproximateGPs.jl @@ -21,7 +21,7 @@ include("LaplaceApproximationModule.jl") include("FiniteBasisModule.jl") @reexport using .FiniteBasisModule: - RandomFourierFeature, FiniteBasis, DegeneratePosterior + FFApprox, FiniteBasis, DegeneratePosterior include("deprecations.jl") diff --git a/src/FiniteBasisModule.jl b/src/FiniteBasisModule.jl index 97903c18..a1d777b5 100644 --- a/src/FiniteBasisModule.jl +++ b/src/FiniteBasisModule.jl @@ -1,14 +1,14 @@ module FiniteBasisModule -using KernelFunctions,LinearAlgebra, AbstractGPs, ArraysOfArrays, Random +using KernelFunctions,LinearAlgebra, AbstractGPs, Random import AbstractGPs: AbstractGP, FiniteGP import Statistics +import ChainRulesCore -struct FiniteBasis{T} <: Kernel - ϕ::T -end +struct FiniteBasis <: KernelFunctions.SimpleKernel end -(k::FiniteBasis)(x, y) = dot(k.ϕ(x), k.ϕ(y)) +KernelFunctions.kappa(::FiniteBasis, d::Real) = d +KernelFunctions.metric(::FiniteBasis) = KernelFunctions.DotProduct() struct DegeneratePosterior{P,T,C} <: AbstractGP prior::P @@ -16,12 +16,13 @@ struct DegeneratePosterior{P,T,C} <: AbstractGP w_prec::C end -weight_form(ϕ, x) = flatview(ArrayOfSimilarArrays(ϕ.(x)))' +weight_form(A::KernelFunctions.ColVecs) = A.X' +weight_form(A::KernelFunctions.RowVecs) = A.X function AbstractGPs.posterior(fx::FiniteGP{GP{M, B}}, y::AbstractVector{<:Real}) where {M, B <: FiniteBasis} kern = fx.f.kernel δ = y - mean(fx) - X = weight_form(kern.ϕ, fx.x) + X = weight_form(fx.x) X_prec = X' * inv(fx.Σy) Λμ = X_prec * y prec = cholesky(I + Symmetric(X_prec * X)) @@ -31,29 +32,29 @@ end function Statistics.mean(f::DegeneratePosterior, x::AbstractVector) w = f.w_mean - X = weight_form(f.prior.kernel.ϕ, x) + X = weight_form(x) X * w end function Statistics.cov(f::DegeneratePosterior, x::AbstractVector) - X = weight_form(f.prior.kernel.ϕ, x) + X = weight_form(x) AbstractGPs.Xt_invA_X(f.w_prec, X') end function Statistics.cov(f::DegeneratePosterior, x::AbstractVector, y::AbstractVector) - X = weight_form(f.prior.kernel.ϕ, x) - Y = weight_form(f.prior.kernel.ϕ, y) + X = weight_form(x) + Y = weight_form(y) AbstractGPs.Xt_invA_Y(X', f.w_prec, Y') end function Statistics.var(f::DegeneratePosterior, x::AbstractVector) - X = weight_form(f.prior.kernel.ϕ, x) + X = weight_form(x) AbstractGPs.diag_Xt_invA_X(f.w_prec, X') end function Statistics.rand(rng::AbstractRNG, f::DegeneratePosterior, x::AbstractVector) w = f.w_mean - X = weight_form(f.prior.kernel.ϕ, x) + X = weight_form(x) X * (f.w_prec.U \ randn(rng, length(x))) end @@ -64,6 +65,9 @@ end RandomFourierFeature(kern::SqExponentialKernel, k::Int) = RandomFourierFeature(randn(k)) RandomFourierFeature(rng::AbstractRNG, kern::SqExponentialKernel, k::Int) = RandomFourierFeature(randn(rng, k)) +FFApprox(kern::Kernel, k::Int) = FiniteBasis() ∘ FunctionTransform(RandomFourierFeature(kern, k)) +FFApprox(rng::AbstractRNG, kern::Kernel, k::Int) = FiniteBasis() ∘ FunctionTransform(RandomFourierFeature(rng, kern, k)) + function (f::RandomFourierFeature)(x) Float64[cos.(f.ws .* x); sin.(f.ws .* x)] .* sqrt(2/length(f.ws)) diff --git a/test/FiniteBasisModule.jl b/test/FiniteBasisModule.jl index 1c6b47f7..2dc3e867 100644 --- a/test/FiniteBasisModule.jl +++ b/test/FiniteBasisModule.jl @@ -5,7 +5,7 @@ y = sin.(norm.(eachcol(x))) @testset "Verify equivalence of weight space and function space posteriors" begin - kern = FiniteBasis(identity) + kern = FiniteBasis() x2 = ColVecs(rand(2, N)) # Predict mean and covariance using weight space view @@ -26,7 +26,7 @@ rbf = SqExponentialKernel() flat_x = rand(rng, N) flat_x2 = rand(rng, N) - ffkern = FiniteBasis(RandomFourierFeature(rng, rbf, 200)) + ffkern = FFApprox(rng, rbf, 200) opt_pred = mean_and_cov(posterior(GP(ffkern)(flat_x, 0.001), y)(flat_x2)) pred = mean_and_cov(posterior(GP(rbf)(flat_x, 0.001), y)(flat_x2))