Skip to content

Commit f13bf03

Browse files
authored
Merge pull request #1109 from JuliaRobotics/21Q1/maint/cf_derelative
update multihypo for CalcFactor
2 parents 1d2c14f + d60e139 commit f13bf03

2 files changed

Lines changed: 113 additions & 84 deletions

File tree

test/testMultihypoFMD.jl

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,31 +9,37 @@ import IncrementalInference: getSample
99

1010
##
1111

12-
@testset "test FactorMetadata is properly populated" begin
12+
1313

1414
struct MyFactor{T <: SamplableBelief} <: IIF.AbstractRelativeRoots
1515
Z::T
1616
# specialSampler approach will be deprecated
17-
specialSampler::Function
17+
# specialSampler::Function
1818
end
1919

20-
##
2120

22-
function getSample(mf::MyFactor, N::Int=1, fmd_...)
23-
@warn "specialSampler (mf::MyFactor,N) does not get hypo sub-selected FMD data"
24-
@show DFG.getLabel.(fmd_[1].fullvariables)
21+
function getSample(cf::CalcFactor{<:MyFactor}, N::Int=1)
22+
@warn "getSample(cf::CalcFactor{<:MyFactor},::Int) does not get hypo sub-selected FMD data"
23+
@show DFG.getLabel.(cf.metadata.fullvariables)
2524
# @assert DFG.getLabel.(fmd_[1].fullvariables) |> length < 3 "this factor is only between two variables"
26-
return (reshape(rand(mf.Z, N),1,N),)
25+
return (reshape(rand(cf.factor.Z, N),1,N),)
2726
end
2827

2928

30-
function (mf_::MyFactor)(res,fmd,i,meas, X1,X2)
31-
@assert DFG.getLabel.(fmd.fullvariables) |> length < 3 "this factor is only between two variables"
29+
function (cf::CalcFactor{<:MyFactor})(res, z, X1, X2)
30+
@assert DFG.getLabel.(cf.metadata.fullvariables) |> length < 3 "this factor is only between two variables"
3231

3332
# just a linear difference to complete the test
34-
res .= X2[:,i] - (X1[:,i] + meas[1][:,1])
33+
res .= X2 - (X1 + z)
34+
nothing
3535
end
3636

37+
##
38+
39+
40+
41+
@testset "test FactorMetadata is properly populated" begin
42+
3743

3844
##
3945

@@ -47,11 +53,19 @@ addVariable!(fg, :x1_b, ContinuousScalar)
4753
addFactor!(fg, [:x0], Prior(Normal()))
4854

4955
# create the object and add it to the graph
50-
mf = MyFactor(Normal(10,1),getSample)
56+
mf = MyFactor( Normal(10,1) )
57+
58+
##
5159

5260
# this sampling might error
5361
addFactor!(fg, [:x0;:x1_a;:x1_b], mf, multihypo=[1;1/2;1/2])
5462

63+
##
64+
65+
meas = freshSamples(fg, :x0x1_ax1_bf1, 10)
66+
67+
##
68+
5569
solveTree!(fg);
5670

5771
##

test/testmultihypothesisapi.jl

Lines changed: 88 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -11,66 +11,63 @@ import IncrementalInference: getSample
1111

1212
##
1313

14-
mutable struct DevelopPrior <: AbstractPrior
15-
x::Distribution
14+
mutable struct DevelopPrior{T <: SamplableBelief} <: AbstractPrior
15+
x::T
1616
end
17-
getSample(dpl::DevelopPrior, N::Int=1) = (reshape(rand(dpl.x, N),1,N), )
17+
getSample(cf::CalcFactor{<:DevelopPrior}, N::Int=1) = (reshape(rand(cf.factor.x, N),1,N), )
1818

19-
mutable struct DevelopLikelihood <: AbstractRelativeRoots
20-
x::Distribution
19+
mutable struct DevelopLikelihood{T <: SamplableBelief} <: AbstractRelativeRoots
20+
x::T
2121
end
22-
getSample(dpl::DevelopLikelihood, N::Int=1) = (reshape(rand(dpl.x, N),1,N), )
23-
function (vv::DevelopLikelihood)(res::AbstractArray{<:Real},
24-
userdata::FactorMetadata,
25-
idx::Int,
26-
meas::Tuple,
27-
wXi::AbstractArray{<:Real,2},
28-
wXj::AbstractArray{<:Real,2} )::Nothing
22+
23+
getSample(cf::CalcFactor{<:DevelopLikelihood}, N::Int=1) = (reshape(rand(cf.factor.x, N),1,N), )
24+
function (cf::CalcFactor{<:DevelopLikelihood})( res::AbstractVector{<:Real},
25+
meas,
26+
wXi,
27+
wXj )
2928
#
30-
res[1] = meas[1][idx] - (wXj[1,idx] - wXi[1,idx])
29+
res .= meas - (wXj - wXi)
3130
nothing
3231
end
3332

3433

35-
global N = 100
36-
global fg = initfg()
34+
N = 100
35+
fg = initfg()
3736

37+
##
3838

3939
@testset "test populate factor graph with a multi-hypothesis factor..." begin
4040

41-
global v1 = addVariable!(fg, :x1, ContinuousScalar, N=N)
41+
##
4242

43-
global pr = DevelopPrior(Normal(10.0,1.0))
44-
global f1 = addFactor!(fg,[:x1],pr)
43+
v1 = addVariable!(fg, :x1, ContinuousScalar, N=N)
4544

45+
pr = DevelopPrior(Normal(10.0,1.0))
46+
f1 = addFactor!(fg,[:x1],pr)
4647

47-
ensureAllInitialized!(fg)
4848

49-
# Juno.breakpoint("/home/dehann/.julia/v0.5/IncrementalInference/src/ApproxConv.jl",121)
49+
ensureAllInitialized!(fg)
5050

51-
global pts = evalFactor(fg, f1, v1.label, N=N)
51+
pts = evalFactor(fg, f1, v1.label, N=N)
5252

5353
@test sum(abs.(pts .- 1.0) .< 5) < 30
5454
@test sum(abs.(pts .- 10.0) .< 5) > 30
5555

5656

57-
58-
global v2 = addVariable!(fg, :x2, ContinuousScalar, N=N)
59-
global pp = DevelopLikelihood(Normal(100.0,1.0))
60-
global f2 = addFactor!(fg, [:x1;:x2], pp)
57+
v2 = addVariable!(fg, :x2, ContinuousScalar, N=N)
58+
pp = DevelopLikelihood(Normal(100.0,1.0))
59+
f2 = addFactor!(fg, [:x1;:x2], pp)
6160

6261
ensureAllInitialized!(fg)
6362

6463
@test abs(Statistics.mean(getVal(fg, :x2))-110.0) < 10.0
6564

6665

66+
v3 = addVariable!(fg, :x3, ContinuousScalar, N=N)
67+
v4 = addVariable!(fg, :x4, ContinuousScalar, N=N)
6768

68-
69-
global v3 = addVariable!(fg, :x3, ContinuousScalar, N=N)
70-
global v4 = addVariable!(fg, :x4, ContinuousScalar, N=N)
71-
72-
global ppMH = DevelopLikelihood(Normal(90.0,1.0))
73-
global f3 = addFactor!(fg, [:x2;:x3;:x4], ppMH, multihypo=[1.0;0.5;0.5])
69+
ppMH = DevelopLikelihood(Normal(90.0,1.0))
70+
f3 = addFactor!(fg, [:x2;:x3;:x4], ppMH, multihypo=[1.0;0.5;0.5])
7471

7572

7673
# @test IIIF._getCCW(f3).hypoverts == [:x3, :x4]
@@ -82,27 +79,33 @@ initManual!(fg, :x2, 1*ones(1,100))
8279
initManual!(fg, :x3, 2*ones(1,100))
8380
initManual!(fg, :x4, 3*ones(1,100))
8481

82+
##
83+
8584
end
8685

8786

8887

8988
@testset "Test multi-hypothesis factor convolution exploration" begin
9089

91-
global pts = approxConv(fg, :x2x3x4f1, :x2, N=N)
90+
##
91+
92+
pts = approxConv(fg, :x2x3x4f1, :x2, N=N)
9293

9394
@test 99 < sum(pts .<= -70.0)
9495

95-
global pts = approxConv(fg, :x2x3x4f1, :x3, N=N)
96+
pts = approxConv(fg, :x2x3x4f1, :x3, N=N)
9697

9798
@test 15 < sum(pts .== 3.0) < 75
9899

99-
global pts = approxConv(fg, :x2x3x4f1, :x4, N=N)
100+
pts = approxConv(fg, :x2x3x4f1, :x4, N=N)
100101

101102
@test 15 < sum(pts .== 2.0) < 75
102103

103-
end
104+
##
104105

106+
end
105107

108+
##
106109

107110
println("Packing converters")
108111

@@ -113,10 +116,10 @@ mutable struct PackedDevelopPrior <: PackedInferenceType
113116
PackedDevelopPrior(x) = new(x)
114117
end
115118
function convert(::Type{PackedDevelopPrior}, d::DevelopPrior)
116-
PackedDevelopPrior(string(d.x))
119+
PackedDevelopPrior(convert(PackedSamplableBelief, d.x))
117120
end
118121
function convert(::Type{DevelopPrior}, d::PackedDevelopPrior)
119-
DevelopPrior(IncrementalInference.normalfromstring(d.x))
122+
DevelopPrior(convert(SamplableBelief, d.x))
120123
end
121124

122125
mutable struct PackedDevelopLikelihood <: PackedInferenceType
@@ -125,80 +128,89 @@ mutable struct PackedDevelopLikelihood <: PackedInferenceType
125128
PackedDevelopLikelihood(x) = new(x)
126129
end
127130
function convert(::Type{PackedDevelopLikelihood}, d::DevelopLikelihood)
128-
PackedDevelopLikelihood(string(d.x))
131+
PackedDevelopLikelihood(convert(PackedSamplableBelief, d.x))
129132
end
130133
function convert(::Type{DevelopLikelihood}, d::PackedDevelopLikelihood)
131-
DevelopLikelihood(IncrementalInference.extractdistribution(d.x))
134+
DevelopLikelihood(convert(SamplableBelief, d.x))
132135
end
133136

134137

138+
##
139+
135140
@testset "test packing and unpacking the data structure" begin
136141

137-
global topack = getSolverData(f1)
138-
global dd = convert(PackedFunctionNodeData{PackedDevelopPrior},topack)
139-
global unpacked = convert(FunctionNodeData{CommonConvWrapper{DevelopPrior}},dd)
142+
##
140143

141-
@test abs(IIF._getCCW(unpacked).usrfnc!.x.μ - 10.0) < 1e-10
142-
@test abs(IIF._getCCW(unpacked).usrfnc!.x.σ - 1.0) < 1e-10
144+
topack = getSolverData(getFactor(fg,:x1f1))
145+
dd = convert(PackedFunctionNodeData{PackedDevelopPrior},topack)
146+
unpacked = convert(FunctionNodeData{CommonConvWrapper{DevelopPrior}},dd)
143147

148+
@test abs(IIF._getCCW(unpacked).usrfnc!.x.μ - 10.0) < 1e-10
149+
@test abs(IIF._getCCW(unpacked).usrfnc!.x.σ - 1.0) < 1e-10
144150

145151

146-
global topack = getSolverData(f3)
147-
global dd = convert(PackedFunctionNodeData{PackedDevelopLikelihood},topack)
148-
global unpacked = convert(FunctionNodeData{CommonConvWrapper{DevelopLikelihood}},dd)
152+
fct = getFactor(fg, :x2x3x4f1)
153+
@show typeof(fct)
154+
topack = getSolverData(fct) # f3
155+
dd = convert(PackedFunctionNodeData{PackedDevelopLikelihood},topack)
156+
unpacked = convert(FunctionNodeData{CommonConvWrapper{DevelopLikelihood}},dd)
149157

150-
# @test IIF._getCCW(unpacked).hypoverts == Symbol[:x3; :x4]
151-
@test sum(abs.(IIF._getCCW(unpacked).hypotheses.p[1] .- 0.0)) < 0.1
152-
@test sum(abs.(IIF._getCCW(unpacked).hypotheses.p[2:3] .- 0.5)) < 0.1
153-
# str = "Symbol[:x3, :x4];[0.5, 0.5]"
154-
# IncrementalInference.parsemultihypostr(str)
158+
# @test IIF._getCCW(unpacked).hypoverts == Symbol[:x3; :x4]
159+
@test sum(abs.(IIF._getCCW(unpacked).hypotheses.p[1] .- 0.0)) < 0.1
160+
@test sum(abs.(IIF._getCCW(unpacked).hypotheses.p[2:3] .- 0.5)) < 0.1
155161

156-
end
157162

163+
##
158164

165+
end
166+
167+
##
159168

160169
# start a new factor graph
161-
global N = 200
162-
global fg = initfg()
170+
N = 200
171+
fg = initfg()
172+
173+
##
163174

164175
@testset "test tri-modal factor..." begin
165176

177+
##
166178

167-
global v1 = addVariable!(fg, :x1, ContinuousScalar, N=N)
179+
v1 = addVariable!(fg, :x1, ContinuousScalar, N=N)
168180

169-
global pr = DevelopPrior(Normal(10.0,1.0))
170-
global f1 = addFactor!(fg,[:x1],pr)
181+
pr = DevelopPrior(Normal(10.0,1.0))
182+
f1 = addFactor!(fg,[:x1],pr)
171183

172184

173185
ensureAllInitialized!(fg)
174186

175187
# Juno.breakpoint("/home/dehann/.julia/v0.5/IncrementalInference/src/ApproxConv.jl",121)
176188

177-
global pts = approxConv(fg, Symbol(f1.label), :x1, N=N)
189+
pts = approxConv(fg, Symbol(f1.label), :x1, N=N)
178190

179191

180192
@test sum(abs.(pts .- 1.0) .< 5) < 30
181193
@test sum(abs.(pts .- 10.0) .< 5) > 30
182194

183195

184196

185-
global v2 = addVariable!(fg, :x2, ContinuousScalar, N=N)
186-
global pp = DevelopLikelihood(Normal(100.0,1.0))
187-
global f2 = addFactor!(fg, [:x1;:x2], pp)
197+
v2 = addVariable!(fg, :x2, ContinuousScalar, N=N)
198+
pp = DevelopLikelihood(Normal(100.0,1.0))
199+
f2 = addFactor!(fg, [:x1;:x2], pp)
188200

189201
ensureAllInitialized!(fg)
190202

191203
@test abs(Statistics.mean(getVal(fg, :x2))-110.0) < 10.0
192204

193205

194206

195-
global v3 = addVariable!(fg, :x3, ContinuousScalar, N=N)
196-
global v4 = addVariable!(fg, :x4, ContinuousScalar, N=N)
197-
global v5 = addVariable!(fg, :x5, ContinuousScalar, N=N)
207+
v3 = addVariable!(fg, :x3, ContinuousScalar, N=N)
208+
v4 = addVariable!(fg, :x4, ContinuousScalar, N=N)
209+
v5 = addVariable!(fg, :x5, ContinuousScalar, N=N)
198210

199211

200-
global ppMH = DevelopLikelihood(Normal(90.0,1.0))
201-
global f3 = addFactor!(fg, [:x2;:x3;:x4;:x5], ppMH, multihypo=[1.0,0.333,0.333,0.334])
212+
ppMH = DevelopLikelihood(Normal(90.0,1.0))
213+
f3 = addFactor!(fg, [:x2;:x3;:x4;:x5], ppMH, multihypo=[1.0,0.333,0.333,0.334])
202214

203215

204216

@@ -216,13 +228,13 @@ initManual!(fg, :x5 ,4*ones(1,100))
216228

217229

218230
# solve for certain idx
219-
global pts = approxConv(fg, :x2x3x4x5f1, :x2, N=N)
231+
pts = approxConv(fg, :x2x3x4x5f1, :x2, N=N)
220232

221233
@test 0.95*N < sum(pts .<= -70.0)
222234

223235

224236
# solve for one of uncertain variables
225-
global pts = approxConv(fg, :x2x3x4x5f1, :x3, N=N)
237+
pts = approxConv(fg, :x2x3x4x5f1, :x3, N=N)
226238

227239
@test 0.1*N < sum(80 .< pts .< 100.0) < 0.5*N
228240
@test 0.1*N < sum(pts .== 3.0) < 0.5*N
@@ -233,7 +245,7 @@ global pts = approxConv(fg, :x2x3x4x5f1, :x3, N=N)
233245

234246

235247
# solve for one of uncertain variables
236-
global pts = approxConv(fg, :x2x3x4x5f1, :x4, N=N)
248+
pts = approxConv(fg, :x2x3x4x5f1, :x4, N=N)
237249

238250
@test 0.1*N < sum(80 .< pts .< 100.0) < 0.5*N
239251
@test 0.1*N < sum(pts .== 2.0) < 0.5*N
@@ -243,22 +255,23 @@ global pts = approxConv(fg, :x2x3x4x5f1, :x4, N=N)
243255

244256

245257
# solve for one of uncertain variables
246-
global pts = approxConv(fg, :x2x3x4x5f1, :x5, N=N)
258+
pts = approxConv(fg, :x2x3x4x5f1, :x5, N=N)
247259

248260
@test 0.1*N < sum(80 .< pts .< 100.0) < 0.5*N
249261
@test 0.1*N < sum(pts .== 2.0) < 0.5*N
250262
@test 0.1*N < sum(pts .== 3.0) < 0.5*N
251263

252264
@test 0.5*N <= sum(80 .< pts .< 100.0) + sum(pts .== 2.0) + sum(pts .== 3.0)
253265

254-
266+
##
255267

256268
end
257269

258-
##
259270

260271
@testset "test multihypo api numerical tolerance, #1086" begin
261272

273+
##
274+
262275
fg = initfg()
263276

264277
addVariable!(fg, :x0, ContinuousEuclid{1})
@@ -268,6 +281,8 @@ addFactor!(fg, [:x0;:x1a;:x1b], LinearRelative(Normal()), multihypo=[1; 0.5;0.49
268281
addFactor!(fg, [:x0;:x1a;:x1b], LinearRelative(Normal()), multihypo=[1; 0.5;0.5000000000001])
269282

270283

284+
##
285+
271286
end
272287

273288

0 commit comments

Comments
 (0)