Skip to content

Commit cfa513e

Browse files
committed
#539 Incorporate feedback
1 parent 881dd41 commit cfa513e

4 files changed

Lines changed: 40 additions & 78 deletions

File tree

test/nodes/predefined/autoregressive_tests.jl

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -84,22 +84,16 @@ end
8484
@test ar_unit(Univariate, 1) == 1.0
8585
@test ar_unit(Multivariate, 3)[1] 1.0
8686
@test eltype(ar_unit(Float32, Univariate, 1)) === Float32
87-
@test ar_unit(Float64, Multivariate, 2) isa ReactiveMP.StandardBasisVector
8887

8988
# --- ar_precision and ARPrecisionMatrix ---
9089
γ = 2.5
9190
order = 3
9291
pm = ar_precision(Multivariate, order, γ)
93-
@test pm isa ARPrecisionMatrix
9492
@test size(pm) == (3, 3)
9593
@test pm[1, 1] == γ
9694
@test pm[2, 2] convert(Float64, huge)
9795
@test pm[1, 2] == 0.0
9896

99-
# convert(::Type{AbstractArray{T}})
100-
pm32 = convert(AbstractArray{Float32}, pm)
101-
@test eltype(pm32.γ) == Float32
102-
10397
# add_precision & add_precision!
10498
A = zeros(3, 3)
10599
B = copy(A)

test/nodes/predefined/gamma_mixture_tests.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -232,9 +232,8 @@ end
232232

233233
# support should be (0.0, Inf)
234234
s = support(d)
235-
@test s isa Distributions.RealInterval
236-
@test s.lb == 0.0
237-
@test s.ub == Inf
235+
@test minimum(s) == 0.0
236+
@test maximum(s) == Inf
238237

239238
# default_prod_rule dispatch
240239
rule = BayesBase.default_prod_rule(GammaShapeLikelihood, GammaShapeLikelihood)

test/nodes/predefined/mixture_tests.jl

Lines changed: 36 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -1,51 +1,73 @@
11
@testitem "nodes:MixtureNode" begin
22
using ReactiveMP, BayesBase, ExponentialFamily, Test
33
import ReactiveMP:
4+
Mixture,
45
MixtureNode,
56
MixtureNodeFactorisation,
67
MixtureNodeFunctionalDependencies,
78
RequireMarginalFunctionalDependencies,
89
NodeInterface,
910
IndexedNodeInterface,
1011
interfaceindex,
12+
interfaceindices,
13+
collect_factorisation,
1114
collect_functional_dependencies,
1215
collect_latest_marginals,
1316
collect_latest_messages,
14-
factornode,
1517
functional_dependencies,
18+
alias_interface,
19+
is_predefined_node,
20+
PredefinedNodeFunctionalForm,
21+
interfaces,
22+
sdtype,
23+
factornode,
1624
FactorNodeActivationOptions,
17-
activate!,
18-
Mixture,
19-
MixtureNode
25+
activate!
2026

21-
interfaces = [(:out, datavar()), (:switch, randomvar()), (:inputs, randomvar()), (:inputs, randomvar())]
27+
# Common interfaces and factorizations used by both test groups
28+
interfaces_list = [(:out, datavar()), (:switch, randomvar()), (:inputs, randomvar()), (:inputs, randomvar())]
2229
factorizations = [[:out], [:switch], [:inputs1], [:inputs2]]
2330

31+
@testset "Type-level definitions" begin
32+
@test ReactiveMP.as_node_symbol(Mixture{2}) === :Mixture
33+
@test interfaces(Mixture{2}) == Val((:out, :switch, :inputs))
34+
@test alias_interface(Mixture{2}, 1, :foo) === :foo
35+
@test is_predefined_node(Mixture{2}) isa PredefinedNodeFunctionalForm
36+
@test sdtype(Mixture{2}) === Stochastic()
37+
@test collect_factorisation(Mixture{2}, nothing) isa MixtureNodeFactorisation
38+
end
39+
2440
@testset "Construction and interface structure" begin
25-
node = factornode(Mixture, interfaces, factorizations)
41+
node = factornode(Mixture, interfaces_list, factorizations)
2642

2743
@test node isa MixtureNode{2}
2844
@test sdtype(node) == Stochastic()
2945
@test functionalform(node) == Mixture{2}
3046
@test getinterfaces(node) isa Tuple
31-
@test length(getinterfaces(node)) == 4 # out, switch, inputs...
47+
@test length(getinterfaces(node)) == 4
3248

3349
@test node.out isa NodeInterface
3450
@test node.switch isa NodeInterface
3551
@test all(i -> i isa IndexedNodeInterface, node.inputs)
3652

37-
# index mapping
3853
@test interfaceindex(node, :out) == 1
3954
@test interfaceindex(node, :switch) == 2
4055
@test interfaceindex(node, :inputs) == 3
4156
@test_throws ErrorException interfaceindex(node, :invalid)
4257
end
4358

59+
@testset "Interface indices" begin
60+
node = factornode(Mixture, interfaces_list, factorizations)
61+
62+
@test interfaceindices(node, :out) == (1,)
63+
@test interfaceindices(node, (:out, :switch)) == (1, 2)
64+
end
65+
4466
@testset "Collect functional dependencies" begin
4567
node = MixtureNode(
46-
NodeInterface(interfaces[1]...),
47-
NodeInterface(interfaces[2]...),
48-
(IndexedNodeInterface(1, NodeInterface(interfaces[3]...)), IndexedNodeInterface(2, NodeInterface(interfaces[4]...)))
68+
NodeInterface(interfaces_list[1]...),
69+
NodeInterface(interfaces_list[2]...),
70+
(IndexedNodeInterface(1, NodeInterface(interfaces_list[3]...)), IndexedNodeInterface(2, NodeInterface(interfaces_list[4]...)))
4971
)
5072

5173
@test collect_functional_dependencies(node, nothing) isa MixtureNodeFunctionalDependencies
@@ -55,103 +77,50 @@
5577
end
5678

5779
@testset "Functional dependencies" begin
58-
node = factornode(Mixture, interfaces, factorizations)
80+
node = factornode(Mixture, interfaces_list, factorizations)
5981
deps = MixtureNodeFunctionalDependencies()
6082

61-
# out
6283
msg, marg = functional_dependencies(deps, node, node.out, 1)
6384
@test msg == (node.switch, node.inputs)
6485
@test marg == ()
6586

66-
# switch
6787
msg, marg = functional_dependencies(deps, node, node.switch, 2)
6888
@test msg == (node.out, node.inputs)
6989
@test marg == ()
7090

71-
# inputs
7291
msg, marg = functional_dependencies(deps, node, node.inputs[1], 3)
7392
@test msg == (node.out, node.switch)
7493
@test marg == ()
7594

76-
# invalid
7795
@test_throws ErrorException functional_dependencies(deps, node, node.out, 99)
7896
end
7997

8098
@testset "RequireMarginalFunctionalDependencies variant" begin
81-
node = factornode(Mixture, interfaces, factorizations)
99+
node = factornode(Mixture, interfaces_list, factorizations)
82100
deps = RequireMarginalFunctionalDependencies()
83101

84-
# out depends on inputs + marginal on switch
85102
msg, marg = functional_dependencies(deps, node, 1)
86103
@test length(msg) == 1
87104
@test length(marg) == 1
88105

89-
# switch depends on out, inputs + no marginal
90106
msg, marg = functional_dependencies(deps, node, 2)
91107
@test length(msg) == 2
92108
@test isempty(marg)
93109

94-
# input depends on out + marginal on switch
95110
msg, marg = functional_dependencies(deps, node, 3)
96111
@test length(msg) == 1
97112
@test length(marg) == 1
98113
end
99-
end
100-
101-
@testitem "nodes:MixtureNode:Extended" begin
102-
using ReactiveMP, BayesBase, ExponentialFamily, Test
103-
import ReactiveMP:
104-
Mixture,
105-
MixtureNode,
106-
MixtureNodeFactorisation,
107-
MixtureNodeFunctionalDependencies,
108-
RequireMarginalFunctionalDependencies,
109-
NodeInterface,
110-
IndexedNodeInterface,
111-
interfaceindex,
112-
interfaceindices,
113-
collect_factorisation,
114-
collect_latest_marginals,
115-
collect_latest_messages,
116-
interfaces,
117-
alias_interface,
118-
is_predefined_node,
119-
PredefinedNodeFunctionalForm
120-
121-
@testset "Type-level definitions" begin
122-
@test ReactiveMP.as_node_symbol(Mixture{2}) === :Mixture
123-
@test interfaces(Mixture{2}) == Val((:out, :switch, :inputs))
124-
@test alias_interface(Mixture{2}, 1, :foo) === :foo
125-
@test is_predefined_node(Mixture{2}) isa PredefinedNodeFunctionalForm
126-
@test sdtype(Mixture{2}) === Stochastic()
127-
@test collect_factorisation(Mixture{2}, nothing) isa MixtureNodeFactorisation
128-
end
129-
130-
# Construct a simple MixtureNode for reuse
131-
vinterfaces = [(:out, datavar()), (:switch, randomvar()), (:inputs, randomvar()), (:inputs, randomvar())]
132-
factorizations = [[:out], [:switch], [:inputs1], [:inputs2]]
133-
node = factornode(Mixture, vinterfaces, factorizations)
134-
135-
@testset "Interface indices" begin
136-
# Single symbol
137-
@test interfaceindices(node, :out) == (1,)
138-
# Multiple symbols
139-
res = interfaceindices(node, (:out, :switch))
140-
@test res == (1, 2)
141-
end
142-
143-
# TODO: collect_latest_messages
144114

145115
@testset "Collect latest marginals" begin
116+
node = factornode(Mixture, interfaces_list, factorizations)
146117
deps1 = MixtureNodeFunctionalDependencies()
147118
deps2 = RequireMarginalFunctionalDependencies()
148119

149-
# Variant 1: no marginals
150120
val1, obs1 = collect_latest_marginals(deps1, node, ())
151121
@test val1 === nothing
152122
@test obs1 !== nothing
153123

154-
# Variant 2: with switch marginal
155124
switchiface = NodeInterface(:switch, randomvar())
156125
val2, obs2 = collect_latest_marginals(deps2, node, (switchiface,))
157126
@test val2 isa Val

test/nodes/predefined/wishart_inverse_tests.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
q_S = PointMass([2.0 0.0; 0.0 2.0])
1212

1313
marginals = (Marginal(q_out, false, false, nothing), Marginal(q_ν, false, false, nothing), Marginal(q_S, false, false, nothing))
14-
@test score(AverageEnergy(), InverseWishart, Val{(:out, :ν, :S)}(), marginals, nothing) 9.496544113156787
14+
@test score(AverageEnergy(), InverseWishart, Val{(:out, :ν, :S)}(), marginals, nothing) 9.496544113156787 rtol=1e-8
1515
end
1616

1717
begin
@@ -22,7 +22,7 @@
2222
q_S = PointMass(S)
2323

2424
marginals = (Marginal(q_out, false, false, nothing), Marginal(q_ν, false, false, nothing), Marginal(q_S, false, false, nothing))
25-
@test score(AverageEnergy(), InverseWishart, Val{(:out, :ν, :S)}(), marginals, nothing) 1.1299587008097587
25+
@test score(AverageEnergy(), InverseWishart, Val{(:out, :ν, :S)}(), marginals, nothing) 1.1299587008097587 rtol=1e-8
2626
end
2727
end
2828

0 commit comments

Comments
 (0)