|
1 | 1 | @testitem "nodes:MixtureNode" begin |
2 | 2 | using ReactiveMP, BayesBase, ExponentialFamily, Test |
3 | 3 | import ReactiveMP: |
| 4 | + Mixture, |
4 | 5 | MixtureNode, |
5 | 6 | MixtureNodeFactorisation, |
6 | 7 | MixtureNodeFunctionalDependencies, |
7 | 8 | RequireMarginalFunctionalDependencies, |
8 | 9 | NodeInterface, |
9 | 10 | IndexedNodeInterface, |
10 | 11 | interfaceindex, |
| 12 | + interfaceindices, |
| 13 | + collect_factorisation, |
11 | 14 | collect_functional_dependencies, |
12 | 15 | collect_latest_marginals, |
13 | 16 | collect_latest_messages, |
14 | | - factornode, |
15 | 17 | functional_dependencies, |
| 18 | + alias_interface, |
| 19 | + is_predefined_node, |
| 20 | + PredefinedNodeFunctionalForm, |
| 21 | + interfaces, |
| 22 | + sdtype, |
| 23 | + factornode, |
16 | 24 | FactorNodeActivationOptions, |
17 | | - activate!, |
18 | | - Mixture, |
19 | | - MixtureNode |
| 25 | + activate! |
20 | 26 |
|
21 | | - interfaces = [(:out, datavar()), (:switch, randomvar()), (:inputs, randomvar()), (:inputs, randomvar())] |
| 27 | + # Common interfaces and factorizations used by both test groups |
| 28 | + interfaces_list = [(:out, datavar()), (:switch, randomvar()), (:inputs, randomvar()), (:inputs, randomvar())] |
22 | 29 | factorizations = [[:out], [:switch], [:inputs1], [:inputs2]] |
23 | 30 |
|
| 31 | + @testset "Type-level definitions" begin |
| 32 | + @test ReactiveMP.as_node_symbol(Mixture{2}) === :Mixture |
| 33 | + @test interfaces(Mixture{2}) == Val((:out, :switch, :inputs)) |
| 34 | + @test alias_interface(Mixture{2}, 1, :foo) === :foo |
| 35 | + @test is_predefined_node(Mixture{2}) isa PredefinedNodeFunctionalForm |
| 36 | + @test sdtype(Mixture{2}) === Stochastic() |
| 37 | + @test collect_factorisation(Mixture{2}, nothing) isa MixtureNodeFactorisation |
| 38 | + end |
| 39 | + |
24 | 40 | @testset "Construction and interface structure" begin |
25 | | - node = factornode(Mixture, interfaces, factorizations) |
| 41 | + node = factornode(Mixture, interfaces_list, factorizations) |
26 | 42 |
|
27 | 43 | @test node isa MixtureNode{2} |
28 | 44 | @test sdtype(node) == Stochastic() |
29 | 45 | @test functionalform(node) == Mixture{2} |
30 | 46 | @test getinterfaces(node) isa Tuple |
31 | | - @test length(getinterfaces(node)) == 4 # out, switch, inputs... |
| 47 | + @test length(getinterfaces(node)) == 4 |
32 | 48 |
|
33 | 49 | @test node.out isa NodeInterface |
34 | 50 | @test node.switch isa NodeInterface |
35 | 51 | @test all(i -> i isa IndexedNodeInterface, node.inputs) |
36 | 52 |
|
37 | | - # index mapping |
38 | 53 | @test interfaceindex(node, :out) == 1 |
39 | 54 | @test interfaceindex(node, :switch) == 2 |
40 | 55 | @test interfaceindex(node, :inputs) == 3 |
41 | 56 | @test_throws ErrorException interfaceindex(node, :invalid) |
42 | 57 | end |
43 | 58 |
|
| 59 | + @testset "Interface indices" begin |
| 60 | + node = factornode(Mixture, interfaces_list, factorizations) |
| 61 | + |
| 62 | + @test interfaceindices(node, :out) == (1,) |
| 63 | + @test interfaceindices(node, (:out, :switch)) == (1, 2) |
| 64 | + end |
| 65 | + |
44 | 66 | @testset "Collect functional dependencies" begin |
45 | 67 | node = MixtureNode( |
46 | | - NodeInterface(interfaces[1]...), |
47 | | - NodeInterface(interfaces[2]...), |
48 | | - (IndexedNodeInterface(1, NodeInterface(interfaces[3]...)), IndexedNodeInterface(2, NodeInterface(interfaces[4]...))) |
| 68 | + NodeInterface(interfaces_list[1]...), |
| 69 | + NodeInterface(interfaces_list[2]...), |
| 70 | + (IndexedNodeInterface(1, NodeInterface(interfaces_list[3]...)), IndexedNodeInterface(2, NodeInterface(interfaces_list[4]...))) |
49 | 71 | ) |
50 | 72 |
|
51 | 73 | @test collect_functional_dependencies(node, nothing) isa MixtureNodeFunctionalDependencies |
|
55 | 77 | end |
56 | 78 |
|
57 | 79 | @testset "Functional dependencies" begin |
58 | | - node = factornode(Mixture, interfaces, factorizations) |
| 80 | + node = factornode(Mixture, interfaces_list, factorizations) |
59 | 81 | deps = MixtureNodeFunctionalDependencies() |
60 | 82 |
|
61 | | - # out |
62 | 83 | msg, marg = functional_dependencies(deps, node, node.out, 1) |
63 | 84 | @test msg == (node.switch, node.inputs) |
64 | 85 | @test marg == () |
65 | 86 |
|
66 | | - # switch |
67 | 87 | msg, marg = functional_dependencies(deps, node, node.switch, 2) |
68 | 88 | @test msg == (node.out, node.inputs) |
69 | 89 | @test marg == () |
70 | 90 |
|
71 | | - # inputs |
72 | 91 | msg, marg = functional_dependencies(deps, node, node.inputs[1], 3) |
73 | 92 | @test msg == (node.out, node.switch) |
74 | 93 | @test marg == () |
75 | 94 |
|
76 | | - # invalid |
77 | 95 | @test_throws ErrorException functional_dependencies(deps, node, node.out, 99) |
78 | 96 | end |
79 | 97 |
|
80 | 98 | @testset "RequireMarginalFunctionalDependencies variant" begin |
81 | | - node = factornode(Mixture, interfaces, factorizations) |
| 99 | + node = factornode(Mixture, interfaces_list, factorizations) |
82 | 100 | deps = RequireMarginalFunctionalDependencies() |
83 | 101 |
|
84 | | - # out depends on inputs + marginal on switch |
85 | 102 | msg, marg = functional_dependencies(deps, node, 1) |
86 | 103 | @test length(msg) == 1 |
87 | 104 | @test length(marg) == 1 |
88 | 105 |
|
89 | | - # switch depends on out, inputs + no marginal |
90 | 106 | msg, marg = functional_dependencies(deps, node, 2) |
91 | 107 | @test length(msg) == 2 |
92 | 108 | @test isempty(marg) |
93 | 109 |
|
94 | | - # input depends on out + marginal on switch |
95 | 110 | msg, marg = functional_dependencies(deps, node, 3) |
96 | 111 | @test length(msg) == 1 |
97 | 112 | @test length(marg) == 1 |
98 | 113 | end |
99 | | -end |
100 | | - |
101 | | -@testitem "nodes:MixtureNode:Extended" begin |
102 | | - using ReactiveMP, BayesBase, ExponentialFamily, Test |
103 | | - import ReactiveMP: |
104 | | - Mixture, |
105 | | - MixtureNode, |
106 | | - MixtureNodeFactorisation, |
107 | | - MixtureNodeFunctionalDependencies, |
108 | | - RequireMarginalFunctionalDependencies, |
109 | | - NodeInterface, |
110 | | - IndexedNodeInterface, |
111 | | - interfaceindex, |
112 | | - interfaceindices, |
113 | | - collect_factorisation, |
114 | | - collect_latest_marginals, |
115 | | - collect_latest_messages, |
116 | | - interfaces, |
117 | | - alias_interface, |
118 | | - is_predefined_node, |
119 | | - PredefinedNodeFunctionalForm |
120 | | - |
121 | | - @testset "Type-level definitions" begin |
122 | | - @test ReactiveMP.as_node_symbol(Mixture{2}) === :Mixture |
123 | | - @test interfaces(Mixture{2}) == Val((:out, :switch, :inputs)) |
124 | | - @test alias_interface(Mixture{2}, 1, :foo) === :foo |
125 | | - @test is_predefined_node(Mixture{2}) isa PredefinedNodeFunctionalForm |
126 | | - @test sdtype(Mixture{2}) === Stochastic() |
127 | | - @test collect_factorisation(Mixture{2}, nothing) isa MixtureNodeFactorisation |
128 | | - end |
129 | | - |
130 | | - # Construct a simple MixtureNode for reuse |
131 | | - vinterfaces = [(:out, datavar()), (:switch, randomvar()), (:inputs, randomvar()), (:inputs, randomvar())] |
132 | | - factorizations = [[:out], [:switch], [:inputs1], [:inputs2]] |
133 | | - node = factornode(Mixture, vinterfaces, factorizations) |
134 | | - |
135 | | - @testset "Interface indices" begin |
136 | | - # Single symbol |
137 | | - @test interfaceindices(node, :out) == (1,) |
138 | | - # Multiple symbols |
139 | | - res = interfaceindices(node, (:out, :switch)) |
140 | | - @test res == (1, 2) |
141 | | - end |
142 | | - |
143 | | - # TODO: collect_latest_messages |
144 | 114 |
|
145 | 115 | @testset "Collect latest marginals" begin |
| 116 | + node = factornode(Mixture, interfaces_list, factorizations) |
146 | 117 | deps1 = MixtureNodeFunctionalDependencies() |
147 | 118 | deps2 = RequireMarginalFunctionalDependencies() |
148 | 119 |
|
149 | | - # Variant 1: no marginals |
150 | 120 | val1, obs1 = collect_latest_marginals(deps1, node, ()) |
151 | 121 | @test val1 === nothing |
152 | 122 | @test obs1 !== nothing |
153 | 123 |
|
154 | | - # Variant 2: with switch marginal |
155 | 124 | switchiface = NodeInterface(:switch, randomvar()) |
156 | 125 | val2, obs2 = collect_latest_marginals(deps2, node, (switchiface,)) |
157 | 126 | @test val2 isa Val |
|
0 commit comments