Skip to content

Commit bb85c2b

Browse files
authored
Merge pull request #584 from ReactiveBayes/test-item-runner
use TestItemRunner instead of ReTestItems
2 parents f9b1c1f + d4f963b commit bb85c2b

18 files changed

Lines changed: 212 additions & 269 deletions

File tree

Project.toml

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ authors = ["Dmitry Bagaev <d.v.bagaev@tue.nl>", "Albert Podusenko <a.podusenko@t
55

66
[deps]
77
BayesBase = "b4ee3484-f114-42fe-b91c-797d54a0c67e"
8-
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
98
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
109
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
1110
DomainIntegrals = "cc6bae93-f070-4015-88fd-838f9505a86c"
@@ -26,13 +25,11 @@ PositiveFactorizations = "85a6dd25-e78a-55b7-8502-1745935b8125"
2625
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2726
Rocket = "df971d30-c9d6-4b37-b8ff-e965b2cb3a40"
2827
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
29-
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
3028
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
3129
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
3230
TinyHugeNumbers = "783c9a47-75a3-44ac-a16b-f1ab7b3acf04"
3331
Tullio = "bc48ee85-29a4-5162-ae0b-a64e1601d4bc"
3432
TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6"
35-
Unrolled = "9602ed7d-8fef-5bc8-8597-8f21381861e8"
3633

3734
[weakdeps]
3835
ExponentialFamilyProjection = "17f509fa-9a96-44ba-99b2-1c5f01f0931b"
@@ -44,7 +41,6 @@ ReactiveMPProjectionExt = "ExponentialFamilyProjection"
4441

4542
[compat]
4643
BayesBase = "1.5"
47-
DataStructures = "0.17, 0.18"
4844
DiffResults = "1.1.0"
4945
Distributions = "0.24, 0.25"
5046
DomainIntegrals = "0.3.2, 0.4, 0.5"
@@ -67,13 +63,11 @@ PositiveFactorizations = "0.2"
6763
Random = "1.9"
6864
Rocket = "1.8.0"
6965
SpecialFunctions = "2.3"
70-
StaticArrays = "1.9.0"
7166
StatsBase = "0.34.0"
7267
StatsFuns = "1.3.0"
7368
TinyHugeNumbers = "1.0.0"
7469
Tullio = "0.3"
7570
TupleTools = "1.2.0"
76-
Unrolled = "0.1.3"
7771
julia = "1.10"
7872

7973
[extras]
@@ -85,7 +79,6 @@ DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
8579
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
8680
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
8781
ExponentialFamilyProjection = "17f509fa-9a96-44ba-99b2-1c5f01f0931b"
88-
Hwloc = "0e44f5e4-bd66-52a0-8798-143a42290a1d"
8982
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
9083
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
9184
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
@@ -94,10 +87,9 @@ Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
9487
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
9588
PkgBenchmark = "32113eaa-f34f-5b0d-bd6c-c81e245fc73d"
9689
REPL = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
97-
ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
90+
TestItemRunner = "f8b46487-2199-4994-9208-9a1283c18c0a"
9891
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
9992
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
100-
TestSetExtensions = "98d24dd4-01ad-11ea-1b02-c9a08f80db04"
10193

10294
[targets]
103-
test = ["Aqua", "Hwloc", "ReTestItems", "Test", "Pkg", "Logging", "InteractiveUtils", "TestSetExtensions", "Coverage", "Dates", "Distributed", "Documenter", "BenchmarkTools", "JET", "PkgBenchmark", "StableRNGs", "Optimisers", "DiffResults", "ExponentialFamilyProjection", "REPL", "Manopt"]
95+
test = ["Aqua", "TestItemRunner", "Test", "Pkg", "Logging", "InteractiveUtils", "Coverage", "Dates", "Distributed", "Documenter", "BenchmarkTools", "JET", "PkgBenchmark", "StableRNGs", "Optimisers", "DiffResults", "ExponentialFamilyProjection", "REPL", "Manopt"]

src/nodes/predefined/mixture.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ function functional_dependencies(::RequireMarginalFunctionalDependencies, factor
136136
else
137137
error("Bad index in function_dependencies for MixtureNode")
138138
end
139-
# println(marginal_dependencies)
139+
140140
return message_dependencies, marginal_dependencies
141141
end
142142

src/rules/mixture/switch.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
import StaticArrays: SVector
2-
31
@rule Mixture(:switch, Marginalisation) (m_out::Any, m_inputs::ManyOf{N, Any}) where {N} = begin
42

53
# compute logscales of different products

test/approximations/unscented_tests.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,6 @@
8787
@test all(approximate(Unscented(), (x) -> [x^2, x], (2.0,), (0.0,)) .≈ ([4.0, 2.0], [0.0 0.0; 0.0 0.0]))
8888
@test all(x -> all(isnan, x), approximate(Unscented(), (x) -> x + 1.0, (NaN,), (1.0,)))
8989
@test all(x -> all(isnan, x), approximate(Unscented(), (x) -> [x^2, x], (NaN,), (1.0,)))
90-
@show approximate(Unscented(), (x) -> [x^2, x], (NaN,), (1.0,))
9190
@test all(x -> all(isnan, x), approximate(Unscented(), (x) -> x + 1.0, (1.0,), (NaN,)))
9291
end
9392
end

test/ext/ReactiveMPProjectionExt/rules/out_tests.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ end
6161

6262
prj = ProjectedTo(ExponentialFamily.exponential_family_typetag(q_out), size(q_out)...)
6363

64-
q_out_projected = project_to(prj, (x) -> logpdf(msg, x) + logpdf(m_out_incoming, x); initial_point = q_out)
64+
q_out_projected = project_to(prj, (x) -> logpdf(msg, x) + logpdf(m_out_incoming, x); initialpoint = q_out)
6565
@test mean(q_out_projected) mean(q_out) rtol = 5e-1
6666
@test var(q_out_projected) var(q_out) rtol = 1.0
6767
@test mode(q_out_projected) mode(q_out) rtol = 5e-1
@@ -167,7 +167,6 @@ end
167167
@test isa(result_dist, LogNormal)
168168

169169
# For the standard normal, E[exp(x)] = exp(1/2)
170-
@show result_dist
171170
@test mean(result_dist) exp(1 / 2) rtol = 0.05
172171
end
173172

test/helpers/algebra/standard_basis_vector_tests.jl

Lines changed: 98 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -6,126 +6,119 @@
66

77
rng = MersenneTwister(1234)
88

9-
@testset begin
10-
@test_throws AssertionError StandardBasisVector(0, 1)
11-
@test_throws AssertionError StandardBasisVector(-10, 1)
12-
@test_throws AssertionError StandardBasisVector(10, 11)
13-
@test_throws AssertionError StandardBasisVector(10, -2)
14-
15-
for T in (Int, Float64, Float32)
16-
r = rand(rng, T)
17-
e = StandardBasisVector(2, 1, r)
18-
@test eltype(e) === T
19-
@test e[1] === r
20-
@test e[2] === zero(r)
21-
end
9+
@test_throws AssertionError StandardBasisVector(0, 1)
10+
@test_throws AssertionError StandardBasisVector(-10, 1)
11+
@test_throws AssertionError StandardBasisVector(10, 11)
12+
@test_throws AssertionError StandardBasisVector(10, -2)
13+
14+
for T in (Int, Float64, Float32)
15+
r = rand(rng, T)
16+
e = StandardBasisVector(2, 1, r)
17+
@test eltype(e) === T
18+
@test e[1] === r
19+
@test e[2] === zero(r)
2220
end
2321

2422
# Same sizes
25-
@testset begin
26-
for N in 1:8
27-
for I in 1:N
28-
for T in (Int, Float64, Float32)
29-
scale = rand(rng, T)
30-
e = StandardBasisVector(N, I, scale)
31-
e_c = zeros(T, N)
32-
e_c[I] = scale
33-
34-
m = rand(rng, T)
35-
v = rand(rng, T, N)
36-
A = rand(rng, T, N, N)
37-
a = rand(rng, T, N, 1)
38-
39-
@test m * e == m * e_c
40-
@test m * e' == m * e_c'
41-
@test e * m == e_c * m
42-
@test e' * m == e_c' * m
43-
44-
for A in (A, Diagonal(diag(A)), A')
45-
@test (A * e) == (A * e_c)
46-
@test (A' * e) == (A' * e_c)
47-
@test (e * e') == (e_c * e_c')
48-
@test (e' * e) == (e_c' * e_c)
49-
@test (v' * e) == (v' * e_c)
50-
@test (e' * v) == (e_c' * v)
51-
@test (e' * v) == (e_c' * v)
52-
@test (a' * e) == (a' * e_c)
53-
@test (a * e') == (a * e_c')
54-
55-
t = rand(rng, T)
56-
57-
@test ReactiveMP.v_a_vT(e, t) ReactiveMP.v_a_vT(e_c, t)
58-
@test ReactiveMP.v_a_vT(e, t, e) ReactiveMP.v_a_vT(e_c, t, e_c)
59-
60-
@test dot(e, A, e) === dot(e_c, A, e_c)
61-
@test dot(e, e) === dot(e_c, e_c)
62-
@test dot(e, e_c) === dot(e_c, e_c)
63-
@test dot(e_c, e) === dot(e_c, e_c)
64-
@test dot(v, e) === dot(v, e_c)
65-
@test dot(e, v) === dot(e_c, v)
66-
@test dot(v, e') === dot(v, e_c')
67-
@test dot(e', v) === dot(e_c', v)
68-
@test dot(v', e) === dot(v', e_c)
69-
@test dot(e, v') === dot(e_c, v')
70-
@test dot(v', e') === dot(v', e_c')
71-
@test dot(e', v') === dot(e_c', v')
72-
end
23+
for N in 1:8
24+
for I in 1:N
25+
for T in (Int, Float64, Float32)
26+
scale = rand(rng, T)
27+
e = StandardBasisVector(N, I, scale)
28+
e_c = zeros(T, N)
29+
e_c[I] = scale
30+
31+
m = rand(rng, T)
32+
v = rand(rng, T, N)
33+
A = rand(rng, T, N, N)
34+
a = rand(rng, T, N, 1)
35+
36+
@test m * e == m * e_c
37+
@test m * e' == m * e_c'
38+
@test e * m == e_c * m
39+
@test e' * m == e_c' * m
40+
41+
for A in (A, Diagonal(diag(A)), A')
42+
@test (A * e) == (A * e_c)
43+
@test (A' * e) == (A' * e_c)
44+
@test (e * e') == (e_c * e_c')
45+
@test (e' * e) == (e_c' * e_c)
46+
@test (v' * e) == (v' * e_c)
47+
@test (e' * v) == (e_c' * v)
48+
@test (e' * v) == (e_c' * v)
49+
@test (a' * e) == (a' * e_c)
50+
@test (a * e') == (a * e_c')
51+
52+
t = rand(rng, T)
53+
54+
@test ReactiveMP.v_a_vT(e, t) ReactiveMP.v_a_vT(e_c, t)
55+
@test ReactiveMP.v_a_vT(e, t, e) ReactiveMP.v_a_vT(e_c, t, e_c)
56+
57+
@test dot(e, A, e) === dot(e_c, A, e_c)
58+
@test dot(e, e) === dot(e_c, e_c)
59+
@test dot(e, e_c) === dot(e_c, e_c)
60+
@test dot(e_c, e) === dot(e_c, e_c)
61+
@test dot(v, e) === dot(v, e_c)
62+
@test dot(e, v) === dot(e_c, v)
63+
@test dot(v, e') === dot(v, e_c')
64+
@test dot(e', v) === dot(e_c', v)
65+
@test dot(v', e) === dot(v', e_c)
66+
@test dot(e, v') === dot(e_c, v')
67+
@test dot(v', e') === dot(v', e_c')
68+
@test dot(e', v') === dot(e_c', v')
7369
end
7470
end
7571
end
7672
end
73+
7774
# Different sizes
78-
@testset begin
79-
for N1 in 1:4, N2 in 1:4
80-
if N1 !== N2
81-
for I1 in 1:N1, I2 in 1:N2
82-
for T in (Int, Float64, Float32)
83-
scale1 = rand(rng, T)
84-
scale2 = rand(rng, T)
85-
e1 = StandardBasisVector(N1, I1, scale1)
86-
e2 = StandardBasisVector(N2, I2, scale2)
87-
e_c1 = zeros(T, N1)
88-
e_c1[I1] = scale1
89-
e_c2 = zeros(T, N2)
90-
e_c2[I2] = scale2
91-
92-
@test_throws AssertionError dot(e1, e2)
93-
@test_throws AssertionError dot(e_c1, e2)
94-
@test_throws AssertionError dot(e1, e_c2)
95-
@test_throws AssertionError dot(e2, e1)
96-
@test_throws AssertionError dot(e_c2, e1)
97-
@test_throws AssertionError dot(e2, e_c1)
98-
99-
@test e1 * e2' == e_c1 * e_c2'
100-
@test e2 * e1' == e_c2 * e_c1'
101-
@test e_c1 * e2' == e_c1 * e_c2'
102-
@test e_c2 * e1' == e_c2 * e_c1'
103-
@test e1 * e_c2' == e_c1 * e_c2'
104-
@test e2 * e_c1' == e_c2 * e_c1'
105-
@test_throws AssertionError e1' * e2
106-
@test_throws AssertionError e1' * e_c2
107-
@test_throws AssertionError e2' * e1
108-
@test_throws AssertionError e2' * e_c1
109-
end
75+
for N1 in 1:4, N2 in 1:4
76+
if N1 !== N2
77+
for I1 in 1:N1, I2 in 1:N2
78+
for T in (Int, Float64, Float32)
79+
scale1 = rand(rng, T)
80+
scale2 = rand(rng, T)
81+
e1 = StandardBasisVector(N1, I1, scale1)
82+
e2 = StandardBasisVector(N2, I2, scale2)
83+
e_c1 = zeros(T, N1)
84+
e_c1[I1] = scale1
85+
e_c2 = zeros(T, N2)
86+
e_c2[I2] = scale2
87+
88+
@test_throws AssertionError dot(e1, e2)
89+
@test_throws AssertionError dot(e_c1, e2)
90+
@test_throws AssertionError dot(e1, e_c2)
91+
@test_throws AssertionError dot(e2, e1)
92+
@test_throws AssertionError dot(e_c2, e1)
93+
@test_throws AssertionError dot(e2, e_c1)
94+
95+
@test e1 * e2' == e_c1 * e_c2'
96+
@test e2 * e1' == e_c2 * e_c1'
97+
@test e_c1 * e2' == e_c1 * e_c2'
98+
@test e_c2 * e1' == e_c2 * e_c1'
99+
@test e1 * e_c2' == e_c1 * e_c2'
100+
@test e2 * e_c1' == e_c2 * e_c1'
101+
@test_throws AssertionError e1' * e2
102+
@test_throws AssertionError e1' * e_c2
103+
@test_throws AssertionError e2' * e1
104+
@test_throws AssertionError e2' * e_c1
110105
end
111106
end
112107
end
113108
end
114109

115-
@testset begin
116-
import ReactiveMP: v_a_vT
110+
import ReactiveMP: v_a_vT
117111

118-
for i in 2:5, j in 1:i, scale in (1.0, 2.0), a in rand(5)
119-
e1 = StandardBasisVector(i, j, scale)
120-
v1 = collect(e1)
121-
@test v_a_vT(e1, a) v_a_vT(v1, a)
112+
for i in 2:5, j in 1:i, scale in (1.0, 2.0), a in rand(5)
113+
e1 = StandardBasisVector(i, j, scale)
114+
v1 = collect(e1)
115+
@test v_a_vT(e1, a) v_a_vT(v1, a)
122116

123-
e2 = StandardBasisVector(i, i - j + 1, scale)
124-
v2 = collect(e2)
117+
e2 = StandardBasisVector(i, i - j + 1, scale)
118+
v2 = collect(e2)
125119

126-
@test v_a_vT(e1, a, e2) v_a_vT(v1, a, v2)
127-
@test v_a_vT(e1, a, v2) v_a_vT(v1, a, v2)
128-
@test v_a_vT(v1, a, e2) v_a_vT(v1, a, v2)
129-
end
120+
@test v_a_vT(e1, a, e2) v_a_vT(v1, a, v2)
121+
@test v_a_vT(e1, a, v2) v_a_vT(v1, a, v2)
122+
@test v_a_vT(v1, a, e2) v_a_vT(v1, a, v2)
130123
end
131124
end

test/helpers/helpers_tests.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
import ReactiveMP: SkipIndexIterator, skipindex
55
import ReactiveMP: CountingReal
6-
import ReactiveMP: FunctionalIndex
76

87
@testset "SkipIndexIterator" begin
98
s = skipindex(1:3, 2)

test/nodes/dependencies_tests.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,16 @@
44

55
import ReactiveMP: NodeInterface, collect_latest_messages, getdata, getrecent, default_functional_dependencies
66

7-
struct ArbitraryNode end
7+
struct ArbitraryNodeForCollectLatestMessage end
88

9-
@node ArbitraryNode Stochastic [a, b, c]
9+
@node ArbitraryNodeForCollectLatestMessage Stochastic [a, b, c]
1010

1111
a_v = ConstVariable(1)
1212
b_v = ConstVariable(2)
1313
c_v = ConstVariable(3)
1414

15-
node = factornode(ArbitraryNode, [(:a, a_v), (:b, b_v), (:c, c_v)], ((1, 2, 3),))
16-
dependencies = default_functional_dependencies(ArbitraryNode)
15+
node = factornode(ArbitraryNodeForCollectLatestMessage, [(:a, a_v), (:b, b_v), (:c, c_v)], ((1, 2, 3),))
16+
dependencies = default_functional_dependencies(ArbitraryNodeForCollectLatestMessage)
1717

1818
a, b, c = getinterfaces(node)
1919

@@ -50,16 +50,16 @@ end
5050
import ReactiveMP:
5151
NodeInterface, FactorNodeLocalMarginal, getmarginal, collect_latest_marginals, getdata, getrecent, default_functional_dependencies, getlocalclusters, getmarginals
5252

53-
struct ArbitraryNode end
53+
struct ArbitraryNodeForCollectLatestMarginals end
5454

55-
@node ArbitraryNode Stochastic [a, b, c]
55+
@node ArbitraryNodeForCollectLatestMarginals Stochastic [a, b, c]
5656

5757
a_v = ConstVariable(1)
5858
b_v = ConstVariable(2)
5959
c_v = ConstVariable(3)
6060

61-
node = factornode(ArbitraryNode, [(:a, a_v), (:b, b_v), (:c, c_v)], ((1,), (2,), (3,)))
62-
dependencies = default_functional_dependencies(ArbitraryNode)
61+
node = factornode(ArbitraryNodeForCollectLatestMarginals, [(:a, a_v), (:b, b_v), (:c, c_v)], ((1,), (2,), (3,)))
62+
dependencies = default_functional_dependencies(ArbitraryNodeForCollectLatestMarginals)
6363

6464
a, b, c = getmarginals(getlocalclusters(node))
6565

0 commit comments

Comments
 (0)