Skip to content

Commit d4f963b

Browse files
committed
2prev
1 parent 5516faa commit d4f963b

3 files changed

Lines changed: 135 additions & 150 deletions

File tree

test/helpers/algebra/standard_basis_vector_tests.jl

Lines changed: 98 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -6,126 +6,119 @@
66

77
rng = MersenneTwister(1234)
88

9-
@testset begin
10-
@test_throws AssertionError StandardBasisVector(0, 1)
11-
@test_throws AssertionError StandardBasisVector(-10, 1)
12-
@test_throws AssertionError StandardBasisVector(10, 11)
13-
@test_throws AssertionError StandardBasisVector(10, -2)
14-
15-
for T in (Int, Float64, Float32)
16-
r = rand(rng, T)
17-
e = StandardBasisVector(2, 1, r)
18-
@test eltype(e) === T
19-
@test e[1] === r
20-
@test e[2] === zero(r)
21-
end
9+
@test_throws AssertionError StandardBasisVector(0, 1)
10+
@test_throws AssertionError StandardBasisVector(-10, 1)
11+
@test_throws AssertionError StandardBasisVector(10, 11)
12+
@test_throws AssertionError StandardBasisVector(10, -2)
13+
14+
for T in (Int, Float64, Float32)
15+
r = rand(rng, T)
16+
e = StandardBasisVector(2, 1, r)
17+
@test eltype(e) === T
18+
@test e[1] === r
19+
@test e[2] === zero(r)
2220
end
2321

2422
# Same sizes
25-
@testset begin
26-
for N in 1:8
27-
for I in 1:N
28-
for T in (Int, Float64, Float32)
29-
scale = rand(rng, T)
30-
e = StandardBasisVector(N, I, scale)
31-
e_c = zeros(T, N)
32-
e_c[I] = scale
33-
34-
m = rand(rng, T)
35-
v = rand(rng, T, N)
36-
A = rand(rng, T, N, N)
37-
a = rand(rng, T, N, 1)
38-
39-
@test m * e == m * e_c
40-
@test m * e' == m * e_c'
41-
@test e * m == e_c * m
42-
@test e' * m == e_c' * m
43-
44-
for A in (A, Diagonal(diag(A)), A')
45-
@test (A * e) == (A * e_c)
46-
@test (A' * e) == (A' * e_c)
47-
@test (e * e') == (e_c * e_c')
48-
@test (e' * e) == (e_c' * e_c)
49-
@test (v' * e) == (v' * e_c)
50-
@test (e' * v) == (e_c' * v)
51-
@test (e' * v) == (e_c' * v)
52-
@test (a' * e) == (a' * e_c)
53-
@test (a * e') == (a * e_c')
54-
55-
t = rand(rng, T)
56-
57-
@test ReactiveMP.v_a_vT(e, t) ReactiveMP.v_a_vT(e_c, t)
58-
@test ReactiveMP.v_a_vT(e, t, e) ReactiveMP.v_a_vT(e_c, t, e_c)
59-
60-
@test dot(e, A, e) === dot(e_c, A, e_c)
61-
@test dot(e, e) === dot(e_c, e_c)
62-
@test dot(e, e_c) === dot(e_c, e_c)
63-
@test dot(e_c, e) === dot(e_c, e_c)
64-
@test dot(v, e) === dot(v, e_c)
65-
@test dot(e, v) === dot(e_c, v)
66-
@test dot(v, e') === dot(v, e_c')
67-
@test dot(e', v) === dot(e_c', v)
68-
@test dot(v', e) === dot(v', e_c)
69-
@test dot(e, v') === dot(e_c, v')
70-
@test dot(v', e') === dot(v', e_c')
71-
@test dot(e', v') === dot(e_c', v')
72-
end
23+
for N in 1:8
24+
for I in 1:N
25+
for T in (Int, Float64, Float32)
26+
scale = rand(rng, T)
27+
e = StandardBasisVector(N, I, scale)
28+
e_c = zeros(T, N)
29+
e_c[I] = scale
30+
31+
m = rand(rng, T)
32+
v = rand(rng, T, N)
33+
A = rand(rng, T, N, N)
34+
a = rand(rng, T, N, 1)
35+
36+
@test m * e == m * e_c
37+
@test m * e' == m * e_c'
38+
@test e * m == e_c * m
39+
@test e' * m == e_c' * m
40+
41+
for A in (A, Diagonal(diag(A)), A')
42+
@test (A * e) == (A * e_c)
43+
@test (A' * e) == (A' * e_c)
44+
@test (e * e') == (e_c * e_c')
45+
@test (e' * e) == (e_c' * e_c)
46+
@test (v' * e) == (v' * e_c)
47+
@test (e' * v) == (e_c' * v)
48+
@test (e' * v) == (e_c' * v)
49+
@test (a' * e) == (a' * e_c)
50+
@test (a * e') == (a * e_c')
51+
52+
t = rand(rng, T)
53+
54+
@test ReactiveMP.v_a_vT(e, t) ReactiveMP.v_a_vT(e_c, t)
55+
@test ReactiveMP.v_a_vT(e, t, e) ReactiveMP.v_a_vT(e_c, t, e_c)
56+
57+
@test dot(e, A, e) === dot(e_c, A, e_c)
58+
@test dot(e, e) === dot(e_c, e_c)
59+
@test dot(e, e_c) === dot(e_c, e_c)
60+
@test dot(e_c, e) === dot(e_c, e_c)
61+
@test dot(v, e) === dot(v, e_c)
62+
@test dot(e, v) === dot(e_c, v)
63+
@test dot(v, e') === dot(v, e_c')
64+
@test dot(e', v) === dot(e_c', v)
65+
@test dot(v', e) === dot(v', e_c)
66+
@test dot(e, v') === dot(e_c, v')
67+
@test dot(v', e') === dot(v', e_c')
68+
@test dot(e', v') === dot(e_c', v')
7369
end
7470
end
7571
end
7672
end
73+
7774
# Different sizes
78-
@testset begin
79-
for N1 in 1:4, N2 in 1:4
80-
if N1 !== N2
81-
for I1 in 1:N1, I2 in 1:N2
82-
for T in (Int, Float64, Float32)
83-
scale1 = rand(rng, T)
84-
scale2 = rand(rng, T)
85-
e1 = StandardBasisVector(N1, I1, scale1)
86-
e2 = StandardBasisVector(N2, I2, scale2)
87-
e_c1 = zeros(T, N1)
88-
e_c1[I1] = scale1
89-
e_c2 = zeros(T, N2)
90-
e_c2[I2] = scale2
91-
92-
@test_throws AssertionError dot(e1, e2)
93-
@test_throws AssertionError dot(e_c1, e2)
94-
@test_throws AssertionError dot(e1, e_c2)
95-
@test_throws AssertionError dot(e2, e1)
96-
@test_throws AssertionError dot(e_c2, e1)
97-
@test_throws AssertionError dot(e2, e_c1)
98-
99-
@test e1 * e2' == e_c1 * e_c2'
100-
@test e2 * e1' == e_c2 * e_c1'
101-
@test e_c1 * e2' == e_c1 * e_c2'
102-
@test e_c2 * e1' == e_c2 * e_c1'
103-
@test e1 * e_c2' == e_c1 * e_c2'
104-
@test e2 * e_c1' == e_c2 * e_c1'
105-
@test_throws AssertionError e1' * e2
106-
@test_throws AssertionError e1' * e_c2
107-
@test_throws AssertionError e2' * e1
108-
@test_throws AssertionError e2' * e_c1
109-
end
75+
for N1 in 1:4, N2 in 1:4
76+
if N1 !== N2
77+
for I1 in 1:N1, I2 in 1:N2
78+
for T in (Int, Float64, Float32)
79+
scale1 = rand(rng, T)
80+
scale2 = rand(rng, T)
81+
e1 = StandardBasisVector(N1, I1, scale1)
82+
e2 = StandardBasisVector(N2, I2, scale2)
83+
e_c1 = zeros(T, N1)
84+
e_c1[I1] = scale1
85+
e_c2 = zeros(T, N2)
86+
e_c2[I2] = scale2
87+
88+
@test_throws AssertionError dot(e1, e2)
89+
@test_throws AssertionError dot(e_c1, e2)
90+
@test_throws AssertionError dot(e1, e_c2)
91+
@test_throws AssertionError dot(e2, e1)
92+
@test_throws AssertionError dot(e_c2, e1)
93+
@test_throws AssertionError dot(e2, e_c1)
94+
95+
@test e1 * e2' == e_c1 * e_c2'
96+
@test e2 * e1' == e_c2 * e_c1'
97+
@test e_c1 * e2' == e_c1 * e_c2'
98+
@test e_c2 * e1' == e_c2 * e_c1'
99+
@test e1 * e_c2' == e_c1 * e_c2'
100+
@test e2 * e_c1' == e_c2 * e_c1'
101+
@test_throws AssertionError e1' * e2
102+
@test_throws AssertionError e1' * e_c2
103+
@test_throws AssertionError e2' * e1
104+
@test_throws AssertionError e2' * e_c1
110105
end
111106
end
112107
end
113108
end
114109

115-
@testset begin
116-
import ReactiveMP: v_a_vT
110+
import ReactiveMP: v_a_vT
117111

118-
for i in 2:5, j in 1:i, scale in (1.0, 2.0), a in rand(5)
119-
e1 = StandardBasisVector(i, j, scale)
120-
v1 = collect(e1)
121-
@test v_a_vT(e1, a) v_a_vT(v1, a)
112+
for i in 2:5, j in 1:i, scale in (1.0, 2.0), a in rand(5)
113+
e1 = StandardBasisVector(i, j, scale)
114+
v1 = collect(e1)
115+
@test v_a_vT(e1, a) v_a_vT(v1, a)
122116

123-
e2 = StandardBasisVector(i, i - j + 1, scale)
124-
v2 = collect(e2)
117+
e2 = StandardBasisVector(i, i - j + 1, scale)
118+
v2 = collect(e2)
125119

126-
@test v_a_vT(e1, a, e2) v_a_vT(v1, a, v2)
127-
@test v_a_vT(e1, a, v2) v_a_vT(v1, a, v2)
128-
@test v_a_vT(v1, a, e2) v_a_vT(v1, a, v2)
129-
end
120+
@test v_a_vT(e1, a, e2) v_a_vT(v1, a, v2)
121+
@test v_a_vT(e1, a, v2) v_a_vT(v1, a, v2)
122+
@test v_a_vT(v1, a, e2) v_a_vT(v1, a, v2)
130123
end
131124
end

test/variables/constant_tests.jl

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -35,18 +35,16 @@ end
3535

3636
include("../testutilities.jl")
3737

38-
@testset begin
39-
# Test marginal computation
40-
for d in 1:5:100, constant in rand(10)
41-
let var = constvar(constant)
42-
marginal_expected = mgl(PointMass(constant))
43-
marginal_result = check_stream_updated_once(getmarginal(var)) do
44-
nothing
45-
end
46-
47-
@test getdata(marginal_result) === getdata(marginal_expected)
48-
@test getdata(marginal_result) === PointMass(constant)
38+
# Test marginal computation
39+
for d in 1:5:100, constant in rand(10)
40+
let var = constvar(constant)
41+
marginal_expected = mgl(PointMass(constant))
42+
marginal_result = check_stream_updated_once(getmarginal(var)) do
43+
nothing
4944
end
45+
46+
@test getdata(marginal_result) === getdata(marginal_expected)
47+
@test getdata(marginal_result) === PointMass(constant)
5048
end
5149
end
5250
end

test/variables/data_tests.jl

Lines changed: 28 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -35,37 +35,34 @@ end
3535

3636
include("../testutilities.jl")
3737

38-
@testset begin
39-
# Test marginal computation
40-
for d in 1:5:100
41-
let var = datavar()
42-
messageins = map(1:d) do _
43-
s = Subject(AbstractMessage)
44-
m, i = create_messagein!(var)
45-
connect!(m, s)
46-
return s
47-
end
38+
for d in 1:5:100
39+
let var = datavar()
40+
messageins = map(1:d) do _
41+
s = Subject(AbstractMessage)
42+
m, i = create_messagein!(var)
43+
connect!(m, s)
44+
return s
45+
end
4846

49-
activate!(var, DataVariableActivationOptions(false, false, nothing, nothing))
47+
activate!(var, DataVariableActivationOptions(false, false, nothing, nothing))
5048

51-
messages = map(msg, rand(d))
49+
messages = map(msg, rand(d))
5250

53-
@test check_stream_not_updated(getmarginal(var)) do
54-
foreach(zip(messageins, messages)) do (messagein, message)
55-
next!(messagein, message)
56-
end
51+
@test check_stream_not_updated(getmarginal(var)) do
52+
foreach(zip(messageins, messages)) do (messagein, message)
53+
next!(messagein, message)
5754
end
55+
end
5856

59-
data_point = rand()
60-
61-
marginal_expected = mgl(PointMass(data_point))
62-
marginal_result = check_stream_updated_once(getmarginal(var)) do
63-
update!(var, data_point)
64-
end
57+
data_point = rand()
6558

66-
@test getdata(marginal_result) === getdata(marginal_expected)
67-
@test getdata(marginal_result) === PointMass(data_point)
59+
marginal_expected = mgl(PointMass(data_point))
60+
marginal_result = check_stream_updated_once(getmarginal(var)) do
61+
update!(var, data_point)
6862
end
63+
64+
@test getdata(marginal_result) === getdata(marginal_expected)
65+
@test getdata(marginal_result) === PointMass(data_point)
6966
end
7067
end
7168
end
@@ -77,8 +74,7 @@ end
7774
include("../testutilities.jl")
7875

7976
for fn in (+, *), val1 in 1:3, val2 in 1:3
80-
@testset begin
81-
var = datavar()
77+
let var = datavar()
8278
options = DataVariableActivationOptions(true, true, fn, (val1, val2))
8379
activate!(var, options)
8480
marginal = check_stream_updated_once(getmarginal(var))
@@ -88,24 +84,22 @@ end
8884
end
8985

9086
# Just marginal
91-
@testset begin
92-
var = datavar()
87+
let var = datavar()
9388
options = DataVariableActivationOptions(true, true, fn, (val1, val2))
9489
activate!(var, options)
9590
marginal = check_stream_updated_once(getmarginal(var))
9691
@test getdata(marginal) === PointMass(fn(val1, val2))
9792
end
9893

9994
# Just message
100-
@testset begin
101-
var = datavar()
95+
let var = datavar()
10296
options = DataVariableActivationOptions(true, true, fn, (val1, val2))
10397
activate!(var, options)
10498
message = check_stream_updated_once(messageout(var, 1))
10599
@test getdata(message) === PointMass(fn(val1, val2))
106100
end
107101

108-
@testset begin
102+
let
109103
var1 = datavar()
110104
activate!(var1, DataVariableActivationOptions(true, false, nothing, nothing))
111105

@@ -122,7 +116,7 @@ end
122116
@test getdata(message) === PointMass(fn(val1, val2))
123117
end
124118

125-
@testset begin
119+
let
126120
var2 = datavar()
127121
activate!(var2, DataVariableActivationOptions(true, false, nothing, nothing))
128122

@@ -140,7 +134,7 @@ end
140134
@test getdata(message) === PointMass(fn(val1, val2))
141135
end
142136

143-
@testset begin
137+
let
144138
var1 = datavar()
145139
var2 = datavar()
146140
activate!(var1, DataVariableActivationOptions(true, false, nothing, nothing))
@@ -161,7 +155,7 @@ end
161155
@test getdata(message) === PointMass(fn(val1, val2))
162156
end
163157

164-
@testset begin
158+
let
165159
var1 = datavar()
166160
var2 = datavar()
167161
activate!(var1, DataVariableActivationOptions(true, false, nothing, nothing))

0 commit comments

Comments
 (0)