Skip to content

Commit c89d5a0

Browse files
committed
Add access functions and fix default integrate
1 parent 8c9fa06 commit c89d5a0

1 file changed

Lines changed: 125 additions & 64 deletions

File tree

src/problem_types/ODERelaxProb.jl

Lines changed: 125 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -211,9 +211,10 @@ mutable struct ODELocalIntegrator{N}
211211
integrator
212212
p::Vector{Float64}
213213
pduals::Vector{Dual{Nothing,Float64,N}}
214-
x::ElasticArray{Float64,2}
215-
dxdp::Vector{ElasticArray{Float64,2}}
216-
xduals::Vector{Dual{Nothing,Float64,N}}
214+
x0::Vector{Float64}
215+
x::Matrix{Float64}
216+
dxdp::Vector{Matrix{Float64}}
217+
x0duals::Vector{Dual{Nothing,Float64,N}}
217218
user_t::Vector{Float64}
218219
integrator_t::Vector{Float64}
219220
local_t_dict_flt::Dict{Float64,Int64}
@@ -232,19 +233,21 @@ mutable struct ODELocalIntegrator{N}
232233
zeros(Float64, prob.nx),
233234
prob.tspan,
234235
prob.p)
235-
d.x = zeros(Float64, prob.nx, length(prob.tsupports))
236-
dxdp = ElasticArray{Float64,2}[]
236+
d.x0 = zeros(Float64, prob.nx)
237+
dxdp = Matrix{Float64}[]
237238
for i = 1:prob.np
238-
push!(dxdp, ElasticArray(zeros(Float64, prob.nx, length(prob.tsupports))))
239+
push!(dxdp, zeros(Float64, prob.nx, length(prob.support_set.s)))
239240
end
241+
d.x = zeros(prob.nx, length(prob.support_set.s))
240242
d.dxdp = dxdp
241243
d.p = copy(prob.p)
242244
d.pduals = seed_duals(prob.p)
243-
d.xduals = fill(Dual{Nothing}(0.0,
245+
d.x0duals = fill(Dual{Nothing}(0.0,
244246
single_seed(Partials{N, Float64}, Val(1))),
245247
(prob.nx,))
246-
d.user_t = prob.support_set.s
247-
d.integrator_t = prob.support_set.s
248+
support_set = get(prob, SupportSet())
249+
d.user_t = copy(support_set.s)
250+
d.integrator_t = copy(support_set.s)
248251
d.local_t_dict_flt = Dict{Float64,Int64}()
249252
d.local_t_dict_indx = Dict{Int64,Int64}()
250253
for (i,s) in enumerate(d.user_t)
@@ -258,92 +261,116 @@ function ODELocalIntegrator(prob::ODERelaxProb, integrator)
258261
ODELocalIntegrator{prob.np}(prob, integrator)
259262
end
260263

261-
function integrate!(d::AbstractODERelaxIntegrator, p::ODERelaxProb)
264+
function integrate!(::Val{true}, d::AbstractODERelaxIntegrator, p::ODERelaxProb)
262265

263-
local_ode_storage = DBB.get(d, DBB.LocalIntegrator())
266+
local_prob_storage = get(d, LocalIntegrator())
267+
np = get(d, ParameterNumber())
268+
nx = get(d, StateNumber())
269+
if size(local_prob_storage.x0, 1) != nx + nx*np
270+
local_prob_storage.x0 = zeros(nx + nx*np)
271+
end
272+
for i = 1:nx
273+
local_prob_storage.x0[i] = local_prob_storage.x0duals[i].value
274+
for j = 1:np
275+
local_prob_storage.x0[(nx + j + (i-1)*np)] = local_prob_storage.x0duals[i].partials[j]
276+
end
277+
end
264278

265-
np = DBB.get(d, DBB.ParameterNumber())
266-
nx = DBB.get(d, DBB.StateNumber)
267-
DBB.getall!(local_prob_storage.p, d, DBB.ParameterValue())
268-
local_prob_storage.pduals .= seed_duals(local_prob_storage.p, 1:np)
279+
local_prob_storage.sensitivity_problem = remake(local_prob_storage.sensitivity_problem,
280+
u0 = local_prob_storage.x0,
281+
p = local_prob_storage.p)
269282

270-
initial_condition!(local_prob_storage.x0duals, d, d.local_problem_storage.pduals)
283+
solution = solve(local_prob_storage.sensitivity_problem,
284+
local_prob_storage.integrator,
285+
saveat = local_prob_storage.user_t,
286+
abstol = local_prob_storage.abs_tol,
287+
tstops = local_prob_storage.user_t,
288+
reltol = local_prob_storage.rel_tol)
271289

272-
if !DBB.get(t.integrator, DBB.LocalSensitivityOn())
273-
if length(local_prob_storage.x0local) != d.nx
274-
resize!(local_prob_storage.x0local, d.nx)
275-
end
290+
x, dxdp = extract_local_sensitivities(solution)
291+
292+
new_length = size(x, 2)
293+
prior_length = length(local_prob_storage.integrator_t)
294+
if new_length == prior_length
295+
local_prob_storage.integrator_t .= solution.t
276296
else
277-
if length(local_prob_storage.x0local) != nx*(np + 1)
278-
resize!(local_prob_storage.x0local, nx*(np + 1))
297+
local_prob_storage.x = zeros(nx, new_length)
298+
for i = 1:np
299+
local_prob_storage.dxdp[i] = zeros(nx, new_length)
279300
end
301+
local_prob_storage.integrator_t = solution.t
302+
end
303+
local_prob_storage.x .= x
304+
for i = 1:np
305+
local_prob_storage.dxdp[i] .= dxdp[i]
280306
end
281307

282-
for i = 1:nx
283-
local_prob_storage.x0local[i] = d.local_prob_storage.x0duals[i].value
284-
if DBB.get(t.integrator, DBB.LocalSensitivityOn())
285-
for j = 1:np
286-
local_prob_storage.x0local[(nx + j + (i-1)*np)] = local_prob_storage.x0duals[i].partials[j]
287-
end
288-
end
308+
return solution.t
309+
end
310+
311+
function integrate!(::Val{false}, d::AbstractODERelaxIntegrator, p::ODERelaxProb)
312+
313+
local_prob_storage = get(d, LocalIntegrator())
314+
np = get(d, ParameterNumber())
315+
nx = get(d, StateNumber())
316+
317+
if size(local_prob_storage.x0, 1) != nx
318+
local_prob_storage.x0 = zeros(nx)
289319
end
290-
local_prob_storage.problem = remake(local_problem_storage.problem,
291-
u0 = local_problem_storage.x0local,
292-
p = local_prob_storage.p)
293-
294-
if ~isempty(local_prob_storage.user_t)
295-
solution = solve(local_prob_storage.problem,
296-
local_prob_storage.integrator,
297-
saveat = local_prob_storage.user_t,
298-
abstol = local_prob_storage.abs_tol,
299-
tstops = local_prob_storage.user_t,
300-
adaptive = false,
301-
reltol = local_prob_storage.rel_tol)
302-
else
303-
solution = solve(local_prob_storage.problem,
304-
local_prob_storage.integrator,
305-
abstol = local_prob_storage.abs_tol,
306-
reltol = local_prob_storage.rel_tol)
320+
for i = 1:nx
321+
local_prob_storage.x0[i] = local_prob_storage.x0duals[i].value
307322
end
308323

309-
new_length = length(solution.t)
310-
if DBB.get(d, DBB.LocalSensitivityOn())
311-
x, dxdp = extract_local_sensitivities(solution)
312-
else
313-
x = solution.u
314-
end
324+
local_prob_storage.ode_problem = remake(local_prob_storage.ode_problem,
325+
u0 = local_prob_storage.x0,
326+
p = local_prob_storage.p)
315327

316-
resize!(local_prob_storage.pode_x, nx, new_length)
317-
resize!(local_prob_storage.integrator_t, new_length)
328+
solution = solve(local_prob_storage.ode_problem,
329+
local_prob_storage.integrator,
330+
saveat = local_prob_storage.user_t,
331+
abstol = local_prob_storage.abs_tol,
332+
tstops = local_prob_storage.user_t,
333+
reltol = local_prob_storage.rel_tol)
334+
335+
x = solution.u
336+
337+
new_length = length(x)
318338
prior_length = length(local_prob_storage.integrator_t)
319339
if new_length == prior_length
320340
local_prob_storage.integrator_t .= solution.t
321341
else
342+
local_prob_storage.x = zeros(nx, new_length)
322343
local_prob_storage.integrator_t = solution.t
323344
end
324-
325345
for i = 1:new_length
326-
local_prob_storage.pode_x[:,i] .= x[i]
327-
end
328-
if DBB.get(d, DBB.LocalSensitivityOn())
329-
for i = 1:np
330-
resize!(local_prob_storage.pode_dxdp[i], nx, new_length)
331-
local_prob_storage.pode_dxdp[i] .= dxdp[i]
332-
end
346+
local_prob_storage.x[:,i] .= x[i]
333347
end
334348

349+
return solution.t
350+
end
351+
352+
function integrate!(d::AbstractODERelaxIntegrator, p::ODERelaxProb)
353+
354+
local_prob_storage = get(d, LocalIntegrator())::ODELocalIntegrator
355+
356+
getall!(local_prob_storage.p, d, ParameterValue())
357+
local_prob_storage.pduals .= seed_duals(local_prob_storage.p)
358+
local_prob_storage.x0duals = p.x0(d.local_problem_storage.pduals)
359+
360+
solution_t = integrate!(Val(get(d, LocalSensitivityOn())), d, p)
361+
335362
empty!(local_prob_storage.local_t_dict_flt)
336363
empty!(local_prob_storage.local_t_dict_indx)
337364

338-
for (tindx, t) in enumerate(solution.t)
365+
for (tindx, t) in enumerate(solution_t)
339366
local_prob_storage.local_t_dict_flt[t] = tindx
340367
end
341368

342369
if !isempty(local_prob_storage.user_t)
343370
next_support_time = local_prob_storage.user_t[1]
344371
supports_left = length(local_prob_storage.user_t)
345372
loc_count = 1
346-
for (tindx, t) in enumerate(solution.t)
373+
for (tindx, t) in enumerate(solution_t)
347374
if t == next_support_time
348375
local_prob_storage.local_t_dict_indx[loc_count] = tindx
349376
loc_count += 1
@@ -357,3 +384,37 @@ function integrate!(d::AbstractODERelaxIntegrator, p::ODERelaxProb)
357384

358385
return nothing
359386
end
387+
388+
389+
function get_val_loc_local(t::AbstractODERelaxIntegrator, index::Int64, time::Float64)
390+
391+
local_prob_storage = get(t, LocalIntegrator())::ODELocalIntegrator
392+
393+
(index <= 0 && time == -Inf) && error("Must set either index or time.")
394+
if index > 0
395+
return local_prob_storage.local_t_dict_indx[index]
396+
end
397+
local_prob_storage.local_t_dict_flt[time]
398+
end
399+
400+
function get(out::Vector{Float64}, t::AbstractODERelaxIntegrator, v::Value)
401+
local_prob_storage = get(t, LocalIntegrator())::ODELocalIntegrator
402+
val_loc = get_val_loc_local(t, v.index, v.time)
403+
out .= local_prob_storage.x[:, val_loc]
404+
return
405+
end
406+
407+
408+
function getall!(out::Array{Float64,2}, t::AbstractODERelaxIntegrator, v::Value)
409+
local_prob_storage = get(t, LocalIntegrator())::ODELocalIntegrator
410+
copyto!(out, local_prob_storage.x)
411+
return
412+
end
413+
414+
function getall!(out::Vector{Array{Float64,2}}, t::AbstractODERelaxIntegrator, g::Gradient{Nominal})
415+
local_prob_storage = get(t, LocalIntegrator())::ODELocalIntegrator
416+
for i = 1:get(t, ParameterNumber())
417+
copyto!(out[i], local_prob_storage.dxdp[i])
418+
end
419+
return
420+
end

0 commit comments

Comments
 (0)