Skip to content

Commit 1d2c14f

Browse files
authored
Merge pull request #1107 from JuliaRobotics/21Q1/maint/cf_derelative
update DERelative on new CalcFactor API
2 parents fd42d31 + 527f4a0 commit 1d2c14f

10 files changed

Lines changed: 104 additions & 242 deletions

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ name = "IncrementalInference"
22
uuid = "904591bb-b899-562f-9e6f-b8df64c7d480"
33
keywords = ["MM-iSAM", "MM-iSAMv2", "Bayes tree", "junction tree", "Bayes network", "variable elimination", "graphical models", "SLAM", "inference", "sum-product", "belief-propagation"]
44
desc = "Implements the Multimodal iSAM algorithm."
5-
version = "0.20.0"
5+
version = "0.20.1"
66

77
[deps]
88
ApproxManifoldProducts = "9bbbb610-88a1-53cd-9763-118ce10c1f89"

src/CalcFactor.jl

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -60,16 +60,6 @@ end
6060

6161

6262

63-
function getSample( cf::CalcFactor,
64-
N::Int=1)
65-
#
66-
if !hasfield(typeof(cf.factor), :specialSampler)
67-
getSample(cf.factor, N)
68-
else
69-
cf.factor.specialSampler(cf.factor, N, cf.metadata, cf.metadata.fullvariables...)
70-
end
71-
end
72-
7363
"""
7464
$SIGNATURES
7565

src/Deprecated.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,19 @@ end
208208
# see DFG #590
209209
@deprecate extractdistribution(x) convert(SamplableBelief, x)
210210

211+
212+
function getSample( cf::CalcFactor,
213+
N::Int=1)
214+
#
215+
if !hasfield(typeof(cf.factor), :specialSampler)
216+
@warn "`getSample(::MyFactor, ::Int)` API is being deprecated, use `getSample(cf::CalcFactor{<:MyFactor}, ::Int=1) = cf.factor...` instead"
217+
getSample(cf.factor, N)
218+
else
219+
@warn "`myfactor.specialSampler` API is being deprecated, use `getSample(cf::CalcFactor{<:MyFactor}, ::Int=1) = cf.metadata.___` instead"
220+
cf.factor.specialSampler(cf.factor, N, cf.metadata, cf.metadata.fullvariables...)
221+
end
222+
end
223+
211224
function freshSamples(usrfnc::T,
212225
N::Int,
213226
fmd::FactorMetadata,

src/FactorGraph.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -628,7 +628,8 @@ function prepgenericconvolution(Xi::Vector{<:DFGVariable},
628628
fldnms = fieldnames(T) # typeof(usrfnc)
629629

630630
# standard factor metadata
631-
fmd = FactorMetadata(Xi, getLabel.(Xi), ARR, :null, nothing)
631+
sflbl = 0==length(Xi) ? :null : getLabel(Xi[end])
632+
fmd = FactorMetadata(Xi, getLabel.(Xi), ARR, sflbl, nothing)
632633
cf = CalcFactor( usrfnc, fmd, 0, 1, (Matrix{Float64}(undef,0,0),), ARR)
633634

634635
zdim = calcZDim(cf)

src/ODE/DERelative.jl

Lines changed: 38 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33

44
using .DifferentialEquations
55

6+
import .DifferentialEquations: solve
7+
68
import IncrementalInference: getSample
79

810
export DERelative
@@ -56,7 +58,7 @@ end
5658
# - Can change numerical data return type using an additional first argument, `_calcTimespan(Float32, Xi)`.
5759
# _calcTimespan(Xi::AbstractVector{<:DFGVariable}) = _calcTimespan(Float64, Xi)
5860

59-
# performance helper function
61+
# performance helper function, FIXME not compatible with all multihypo cases
6062
_maketuplebeyond2args = (w1=nothing,w2=nothing,w3_...) -> (w3_...,)
6163

6264

@@ -85,7 +87,7 @@ function DERelative( Xi::AbstractVector{<:DFGVariable},
8587
end
8688

8789

88-
DERelative(dfg::AbstractDFG,
90+
DERelative( dfg::AbstractDFG,
8991
labels::AbstractVector{Symbol},
9092
domain::Type{<:InferenceVariable},
9193
f::Function,
@@ -101,88 +103,95 @@ DERelative(dfg::AbstractDFG,
101103

102104

103105
# Xtra splat are variable points (X3::Matrix, X4::Matrix,...)
104-
function _solveFactorODE!(measArr, prob, u0pts, i, Xtra...)
106+
function _solveFactorODE!(measArr, prob, u0pts, Xtra...)
105107
# should more variables be included in calculation
106108
for (xid, xtra) in enumerate(Xtra)
107109
# update the data register before ODE solver calls the function
108-
prob.p[xid+1][:] = Xtra[xid][:,i]
110+
prob.p[xid+1][:] = Xtra[xid][:]
109111
end
110112

111113
# set the initial condition
112-
prob.u0[:] = u0pts[:,i]
114+
prob.u0[:] = u0pts[:]
113115
sol = DifferentialEquations.solve(prob)
114116

115117
# extract solution from solved ode
116-
measArr[:,i] = sol.u[end]
118+
measArr[:] = sol.u[end]
117119
sol
118120
end
119121

120122
# FIXME see #1025, `multihypo=` will not work properly yet
121-
function getSample( oder::DERelative,
122-
N::Int=1,
123-
fmd_...)
123+
function getSample( cf::CalcFactor{<:DERelative},
124+
N::Int=1 )
124125
#
125126

127+
oder = cf.factor
128+
fmd_ = cf.metadata
129+
126130
# how many trajectories to propagate?
127-
meas = zeros(getDimension(fmd_[1].fullvariables[2]), N)
131+
# @show getLabel(fmd_.fullvariables[2]), getDimension(fmd_.fullvariables[2])
132+
meas = zeros(getDimension(fmd_.fullvariables[2]), N)
128133

129134
# pick forward or backward direction
130135
prob = oder.forwardProblem
131136
# buffer manifold operations for use during factor evaluation
132-
addOp, diffOp, _, _ = AMP.buildHybridManifoldCallbacks( getManifolds(fmd_[1].fullvariables[2]) )
137+
addOp, diffOp, _, _ = AMP.buildHybridManifoldCallbacks( getManifolds(fmd_.fullvariables[2]) )
133138
# set boundary condition
134-
u0pts = if fmd_[1].solvefor == DFG.getLabel(fmd_[1].fullvariables[1])
139+
u0pts = if fmd_.solvefor == DFG.getLabel(fmd_.fullvariables[1])
135140
# backward direction
136141
prob = oder.backwardProblem
137-
addOp, diffOp, _, _ = AMP.buildHybridManifoldCallbacks( getManifolds(fmd_[1].fullvariables[1]) )
138-
getBelief( fmd_[1].fullvariables[2] ) |> getPoints
142+
addOp, diffOp, _, _ = AMP.buildHybridManifoldCallbacks( getManifolds(fmd_.fullvariables[1]) )
143+
getBelief( fmd_.fullvariables[2] ) |> getPoints
139144
else
140145
# forward backward
141-
getBelief( fmd_[1].fullvariables[1] ) |> getPoints
146+
getBelief( fmd_.fullvariables[1] ) |> getPoints
142147
end
143148

144149
# solve likely elements
145150
for i in 1:N
146-
_solveFactorODE!(meas, prob, u0pts, i, _maketuplebeyond2args(fmd_[1].arrRef...)...)
151+
idxArr = view.(fmd_.arrRef,:,i)
152+
_solveFactorODE!(view(meas, :,i), prob, view(u0pts, :,i), _maketuplebeyond2args(idxArr...)...)
153+
# _solveFactorODE!(meas, prob, u0pts, i, _maketuplebeyond2args(fmd_.arrRef...)...)
147154
end
148155

149156
return (meas, diffOp)
150157
end
151158
# getDimension(oderel.domain)
152159

153160

154-
# FIXME see #1025, `multihypo=` will not work properly yet
155-
function (oderel::DERelative)( res::AbstractVector{<:Real},
156-
fmd::FactorMetadata,
157-
idx::Int,
158-
meas::Tuple,
159-
X...)
161+
# NOTE see #1025, CalcFactor should fix `multihypo=` in `cf.metadata` fields
162+
function (cf::CalcFactor{<:DERelative})(res::AbstractVector{<:Real},
163+
meas1,
164+
diffOp,
165+
X...)
160166
#
161167

168+
oderel = cf.factor
169+
162170
# work on-manifold
163-
diffOp = meas[2]
171+
# diffOp = meas[2]
164172
# if backwardSolve else forward
165173

166174
# check direction
167175
# TODO put solveforIdx in FMD?
168176
solveforIdx = 1
169-
if fmd.solvefor == DFG.getLabel(fmd.fullvariables[2])
177+
if cf.metadata.solvefor == DFG.getLabel(cf.metadata.fullvariables[2])
170178
solveforIdx = 2
171-
elseif fmd.solvefor in _maketuplebeyond2args(fmd.variablelist...)
179+
elseif cf.metadata.solvefor in _maketuplebeyond2args(cf.metadata.variablelist...)
172180
# need to recalculate new ODE (forward) for change in parameters (solving for 3rd or higher variable)
173181
solveforIdx = 2
174182
# use forward solve for all solvefor not in [1;2]
175-
u0pts = getBelief(fmd.fullvariables[1]) |> getPoints
183+
u0pts = getBelief(cf.metadata.fullvariables[1]) |> getPoints
176184
# update parameters for additional variables
177-
_solveFactorODE!(meas[1], oderel.forwardProblem, u0pts, idx, _maketuplebeyond2args(X...)...)
185+
_solveFactorODE!(meas1, oderel.forwardProblem, u0pts[:,cf._sampleIdx], _maketuplebeyond2args(X...)...)
178186
end
179187

180188
# find the difference between measured and predicted.
181-
## assuming the ODE integrated from current X1 through to predicted X2 (ie `meas[1][:,idx]`)
189+
## assuming the ODE integrated from current X1 through to predicted X2 (ie `meas1[:,idx]`)
182190
## FIXME, obviously this is not going to work for more compilcated groups/manifolds -- must fix this soon!
191+
# @show cf._sampleIdx, solveforIdx, meas1
183192
for i in 1:size(X[2],1)
184193
# diffop( test, reference ) <===> ΔX = test \ reference
185-
res[i] = diffOp[i]( X[solveforIdx][i,idx], meas[1][i,idx] )
194+
res[i] = diffOp[i]( X[solveforIdx][i], meas1[i] )
186195
end
187196
res
188197
end

test/testCSMMonitor.jl

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,18 @@ struct BrokenFactor{T<: SamplableBelief} <: AbstractRelativeRoots
66
Z::T
77
end
88

9-
IncrementalInference.getSample(s::BrokenFactor, N::Int=1) = (reshape(rand(s.Z, N),:,N), )
10-
11-
function (s::BrokenFactor{<: SamplableBelief})(res::AbstractVector{<:Real},
12-
userdata::FactorMetadata,
13-
idx::Int,
14-
meas::Tuple,
15-
wxi::AbstractArray{<:Real,2},
16-
wxj::AbstractArray{<:Real,2} )
9+
IncrementalInference.getSample(cf::CalcFactor{<:BrokenFactor}, N::Int=1) = (reshape(rand(cf.factor.Z, N),:,N), )
10+
11+
function (s::CalcFactor{<:BrokenFactor})(res::AbstractVector{<:Real},
12+
z,
13+
wxi,
14+
wxj )
1715
#
1816
error("User factor has a bug.")
1917
nothing
2018
end
2119

20+
# FIXME consolidate with CalcFactor according to #467
2221
function (s::BrokenFactor{<:IIF.ParametricTypes})(X1::AbstractArray{<:Real},
2322
X2::AbstractArray{<:Real};
2423
userdata::Union{Nothing,FactorMetadata}=nothing )

0 commit comments

Comments
 (0)