33
44using . DifferentialEquations
55
6+ import . DifferentialEquations: solve
7+
68import IncrementalInference: getSample
79
810export DERelative
5658# - Can change numerical data return type using an additional first argument, `_calcTimespan(Float32, Xi)`.
5759# _calcTimespan(Xi::AbstractVector{<:DFGVariable}) = _calcTimespan(Float64, Xi)
5860
59- # performance helper function
61+ # performance helper function, FIXME not compatible with all multihypo cases
6062_maketuplebeyond2args = (w1= nothing ,w2= nothing ,w3_... ) -> (w3_... ,)
6163
6264
@@ -85,7 +87,7 @@ function DERelative( Xi::AbstractVector{<:DFGVariable},
8587end
8688
8789
88- DERelative (dfg:: AbstractDFG ,
90+ DERelative ( dfg:: AbstractDFG ,
8991 labels:: AbstractVector{Symbol} ,
9092 domain:: Type{<:InferenceVariable} ,
9193 f:: Function ,
@@ -101,88 +103,95 @@ DERelative(dfg::AbstractDFG,
101103
102104
103105# Xtra splat are variable points (X3::Matrix, X4::Matrix,...)
104- function _solveFactorODE! (measArr, prob, u0pts, i, Xtra... )
106+ function _solveFactorODE! (measArr, prob, u0pts, Xtra... )
105107 # should more variables be included in calculation
106108 for (xid, xtra) in enumerate (Xtra)
107109 # update the data register before ODE solver calls the function
108- prob. p[xid+ 1 ][:] = Xtra[xid][:,i ]
110+ prob. p[xid+ 1 ][:] = Xtra[xid][:]
109111 end
110112
111113 # set the initial condition
112- prob. u0[:] = u0pts[:,i ]
114+ prob. u0[:] = u0pts[:]
113115 sol = DifferentialEquations. solve (prob)
114116
115117 # extract solution from solved ode
116- measArr[:,i ] = sol. u[end ]
118+ measArr[:] = sol. u[end ]
117119 sol
118120end
119121
120122# FIXME see #1025, `multihypo=` will not work properly yet
121- function getSample ( oder:: DERelative ,
122- N:: Int = 1 ,
123- fmd_... )
123+ function getSample ( cf:: CalcFactor{<:DERelative} ,
124+ N:: Int = 1 )
124125 #
125126
127+ oder = cf. factor
128+ fmd_ = cf. metadata
129+
126130 # how many trajectories to propagate?
127- meas = zeros (getDimension (fmd_[1 ]. fullvariables[2 ]), N)
131+ # @show getLabel(fmd_.fullvariables[2]), getDimension(fmd_.fullvariables[2])
132+ meas = zeros (getDimension (fmd_. fullvariables[2 ]), N)
128133
129134 # pick forward or backward direction
130135 prob = oder. forwardProblem
131136 # buffer manifold operations for use during factor evaluation
132- addOp, diffOp, _, _ = AMP. buildHybridManifoldCallbacks ( getManifolds (fmd_[ 1 ] . fullvariables[2 ]) )
137+ addOp, diffOp, _, _ = AMP. buildHybridManifoldCallbacks ( getManifolds (fmd_. fullvariables[2 ]) )
133138 # set boundary condition
134- u0pts = if fmd_[ 1 ] . solvefor == DFG. getLabel (fmd_[ 1 ] . fullvariables[1 ])
139+ u0pts = if fmd_. solvefor == DFG. getLabel (fmd_. fullvariables[1 ])
135140 # backward direction
136141 prob = oder. backwardProblem
137- addOp, diffOp, _, _ = AMP. buildHybridManifoldCallbacks ( getManifolds (fmd_[ 1 ] . fullvariables[1 ]) )
138- getBelief ( fmd_[ 1 ] . fullvariables[2 ] ) |> getPoints
142+ addOp, diffOp, _, _ = AMP. buildHybridManifoldCallbacks ( getManifolds (fmd_. fullvariables[1 ]) )
143+ getBelief ( fmd_. fullvariables[2 ] ) |> getPoints
139144 else
140145 # forward backward
141- getBelief ( fmd_[ 1 ] . fullvariables[1 ] ) |> getPoints
146+ getBelief ( fmd_. fullvariables[1 ] ) |> getPoints
142147 end
143148
144149 # solve likely elements
145150 for i in 1 : N
146- _solveFactorODE! (meas, prob, u0pts, i, _maketuplebeyond2args (fmd_[1 ]. arrRef... )... )
151+ idxArr = view .(fmd_. arrRef,:,i)
152+ _solveFactorODE! (view (meas, :,i), prob, view (u0pts, :,i), _maketuplebeyond2args (idxArr... )... )
153+ # _solveFactorODE!(meas, prob, u0pts, i, _maketuplebeyond2args(fmd_.arrRef...)...)
147154 end
148155
149156 return (meas, diffOp)
150157end
151158# getDimension(oderel.domain)
152159
153160
154- # FIXME see #1025, `multihypo=` will not work properly yet
155- function (oderel:: DERelative )( res:: AbstractVector{<:Real} ,
156- fmd:: FactorMetadata ,
157- idx:: Int ,
158- meas:: Tuple ,
159- X... )
161+ # NOTE see #1025, CalcFactor should fix `multihypo=` in `cf.metadata` fields
162+ function (cf:: CalcFactor{<:DERelative} )(res:: AbstractVector{<:Real} ,
163+ meas1,
164+ diffOp,
165+ X... )
160166 #
161167
168+ oderel = cf. factor
169+
162170 # work on-manifold
163- diffOp = meas[2 ]
171+ # diffOp = meas[2]
164172 # if backwardSolve else forward
165173
166174 # check direction
167175 # TODO put solveforIdx in FMD?
168176 solveforIdx = 1
169- if fmd . solvefor == DFG. getLabel (fmd . fullvariables[2 ])
177+ if cf . metadata . solvefor == DFG. getLabel (cf . metadata . fullvariables[2 ])
170178 solveforIdx = 2
171- elseif fmd . solvefor in _maketuplebeyond2args (fmd . variablelist... )
179+ elseif cf . metadata . solvefor in _maketuplebeyond2args (cf . metadata . variablelist... )
172180 # need to recalculate new ODE (forward) for change in parameters (solving for 3rd or higher variable)
173181 solveforIdx = 2
174182 # use forward solve for all solvefor not in [1;2]
175- u0pts = getBelief (fmd . fullvariables[1 ]) |> getPoints
183+ u0pts = getBelief (cf . metadata . fullvariables[1 ]) |> getPoints
176184 # update parameters for additional variables
177- _solveFactorODE! (meas[ 1 ] , oderel. forwardProblem, u0pts, idx , _maketuplebeyond2args (X... )... )
185+ _solveFactorODE! (meas1 , oderel. forwardProblem, u0pts[:,cf . _sampleIdx] , _maketuplebeyond2args (X... )... )
178186 end
179187
180188 # find the difference between measured and predicted.
181- # # assuming the ODE integrated from current X1 through to predicted X2 (ie `meas[1] [:,idx]`)
189+ # # assuming the ODE integrated from current X1 through to predicted X2 (ie `meas1 [:,idx]`)
182190 # # FIXME , obviously this is not going to work for more compilcated groups/manifolds -- must fix this soon!
191+ # @show cf._sampleIdx, solveforIdx, meas1
183192 for i in 1 : size (X[2 ],1 )
184193 # diffop( test, reference ) <===> ΔX = test \ reference
185- res[i] = diffOp[i]( X[solveforIdx][i,idx ], meas[ 1 ][i,idx ] )
194+ res[i] = diffOp[i]( X[solveforIdx][i], meas1[i ] )
186195 end
187196 res
188197end
0 commit comments