|
18 | 18 | Σ_empirical = (Y .- mean(Y; dims=2)) * (Y .- mean(Y; dims=2))' ./ samples |
19 | 19 | @test mean(f(X, Σy)) ≈ m_empirical atol = 1e-2 rtol = 1e-2 |
20 | 20 | @test cov(f(X, Σy)) ≈ Σ_empirical atol = 1e-2 rtol = 1e-2 |
21 | | - |
22 | | - @testset "Zygote (everything dense)" begin |
23 | | - function rand_blr(X, A_Σy, mw, A_Λw) |
24 | | - Σy, Λw = Symmetric(A_Σy * A_Σy' + I), Symmetric(A_Λw * A_Λw' + I) |
25 | | - f = BayesianLinearRegressor(mw, Λw) |
26 | | - return rand(MersenneTwister(123456), f(X, Σy), 3) |
27 | | - end |
28 | | - mw, A_Σy, A_Λw = f.mw, 0.1 .* randn(rng, N, N), 0.1 .* randn(rng, D, D) |
29 | | - |
30 | | - # Run the model forwards and check that output agrees with non-Zygote output |
31 | | - z, back = Zygote.pullback(rand_blr, X, A_Σy, mw, A_Λw) |
32 | | - @test z == rand_blr(X, A_Σy, mw, A_Λw) |
33 | | - |
34 | | - # Compute adjoints using Zygote. |
35 | | - z̄ = randn(rng, size(z)) |
36 | | - dX, dA_Σy, dmw, dA_Λw = back(z̄) |
37 | | - |
38 | | - # Verify adjoints via finite differencing. |
39 | | - fdm = central_fdm(5, 1) |
40 | | - @test dX ≈ first(j′vp(fdm, X -> rand_blr(X, A_Σy, mw, A_Λw), z̄, X)) |
41 | | - @test dA_Σy ≈ |
42 | | - first(j′vp(fdm, A_Σy -> rand_blr(X, A_Σy, mw, A_Λw), z̄, A_Σy)) |
43 | | - @test dmw ≈ first(j′vp(fdm, mw -> rand_blr(X, A_Σy, mw, A_Λw), z̄, mw)) |
44 | | - @test dA_Λw ≈ |
45 | | - first(j′vp(fdm, A_Λw -> rand_blr(X, A_Σy, mw, A_Λw), z̄, A_Λw)) |
46 | | - end |
47 | 21 | end |
48 | 22 | @testset "logpdf" begin |
49 | 23 | rng, N, D = MersenneTwister(123456), 13, 7 |
50 | 24 | X, f, Σy = generate_toy_problem(rng, N, D, Tx) |
51 | 25 | y = rand(rng, f(X, Σy)) |
52 | 26 |
|
53 | | - # Construct MvNormal using a naive but simple computation for the mean / cov. |
| 27 | + # Compute the mean / cov using a naive but simple computation; the reference logpdf is built from them below. |
54 | 28 | function naive_normal_stats(X::Matrix) |
55 | 29 | return (X' * f.mw, Symmetric(X' * (cholesky(f.Λw) \ X) + Σy)) |
56 | 30 | end |
|
59 | 33 | m, Σ = naive_normal_stats(X) |
60 | 34 |
|
61 | 35 | # Check that the BLR logpdf agrees with the closed-form multivariate normal log-density. |
62 | | - @test logpdf(f(X, Σy), y) ≈ logpdf(MvNormal(m, Σ), y) |
63 | | - |
64 | | - @testset "Zygote (everything dense)" begin |
65 | | - function logpdf_blr(X, A_Σy, y, mw, A_Λw) |
66 | | - Σy, Λw = Symmetric(A_Σy * A_Σy' + I), Symmetric(A_Λw * A_Λw' + I) |
67 | | - f = BayesianLinearRegressor(mw, Λw) |
68 | | - return logpdf(f(X, Σy), y) |
69 | | - end |
70 | | - mw, A_Σy, A_Λw = f.mw, 0.1 .* randn(rng, N, N), 0.1 .* randn(rng, D, D) |
71 | | - |
72 | | - z, back = Zygote.pullback(logpdf_blr, X, A_Σy, y, mw, A_Λw) |
73 | | - @test z == logpdf_blr(X, A_Σy, y, mw, A_Λw) |
74 | | - |
75 | | - # Compute gradients using Zygote. |
76 | | - z̄ = randn(rng) |
77 | | - dX, dA_Σy, dy, dmw, dA_Λw = back(z̄) |
78 | | - |
79 | | - # Check correctness via finite differencing. |
80 | | - fdm = central_fdm(5, 1) |
81 | | - @test dX ≈ first(j′vp(fdm, X -> logpdf_blr(X, A_Σy, y, mw, A_Λw), z̄, X)) |
82 | | - @test dA_Σy ≈ |
83 | | - first(j′vp(fdm, A_Σy -> logpdf_blr(X, A_Σy, y, mw, A_Λw), z̄, A_Σy)) |
84 | | - @test dy ≈ first(j′vp(fdm, y -> logpdf_blr(X, A_Σy, y, mw, A_Λw), z̄, y)) |
85 | | - @test dmw ≈ first(j′vp(fdm, mw -> logpdf_blr(X, A_Σy, y, mw, A_Λw), z̄, mw)) |
86 | | - @test dA_Λw ≈ |
87 | | - first(j′vp(fdm, A_Λw -> logpdf_blr(X, A_Σy, y, mw, A_Λw), z̄, A_Λw)) |
88 | | - end |
| 36 | + δ = y - m |
| 37 | + @test logpdf(f(X, Σy), y) ≈ -(N * log(2π) + logdet(Σ) + δ' * (Σ \ δ)) / 2 |
89 | 38 | end |
90 | 39 | @testset "posterior" begin |
91 | 40 | @testset "low noise" begin |
|
0 commit comments