@@ -211,9 +211,10 @@ mutable struct ODELocalIntegrator{N}
211211 integrator
212212 p:: Vector{Float64}
213213 pduals:: Vector{Dual{Nothing,Float64,N}}
214- x:: ElasticArray{Float64,2}
215- dxdp:: Vector{ElasticArray{Float64,2}}
216- xduals:: Vector{Dual{Nothing,Float64,N}}
214+ x0:: Vector{Float64}
215+ x:: Matrix{Float64}
216+ dxdp:: Vector{Matrix{Float64}}
217+ x0duals:: Vector{Dual{Nothing,Float64,N}}
217218 user_t:: Vector{Float64}
218219 integrator_t:: Vector{Float64}
219220 local_t_dict_flt:: Dict{Float64,Int64}
@@ -232,19 +233,21 @@ mutable struct ODELocalIntegrator{N}
232233 zeros (Float64, prob. nx),
233234 prob. tspan,
234235 prob. p)
235- d. x = zeros (Float64, prob. nx, length (prob . tsupports) )
236- dxdp = ElasticArray {Float64, 2 }[]
236+ d. x0 = zeros (Float64, prob. nx)
237+ dxdp = Matrix {Float64}[]
237238 for i = 1 : prob. np
238- push! (dxdp, ElasticArray ( zeros (Float64, prob. nx, length (prob. tsupports) )))
239+ push! (dxdp, zeros (Float64, prob. nx, length (prob. support_set . s )))
239240 end
241+ d. x = zeros (prob. nx, length (prob. support_set. s))
240242 d. dxdp = dxdp
241243 d. p = copy (prob. p)
242244 d. pduals = seed_duals (prob. p)
243- d. xduals = fill (Dual {Nothing} (0.0 ,
245+ d. x0duals = fill (Dual {Nothing} (0.0 ,
244246 single_seed (Partials{N, Float64}, Val (1 ))),
245247 (prob. nx,))
246- d. user_t = prob. support_set. s
247- d. integrator_t = prob. support_set. s
248+ support_set = get (prob, SupportSet ())
249+ d. user_t = copy (support_set. s)
250+ d. integrator_t = copy (support_set. s)
248251 d. local_t_dict_flt = Dict {Float64,Int64} ()
249252 d. local_t_dict_indx = Dict {Int64,Int64} ()
250253 for (i,s) in enumerate (d. user_t)
@@ -258,92 +261,116 @@ function ODELocalIntegrator(prob::ODERelaxProb, integrator)
258261 ODELocalIntegrator {prob.np} (prob, integrator)
259262end
260263
261- function integrate! (d:: AbstractODERelaxIntegrator , p:: ODERelaxProb )
264+ function integrate! (:: Val{true} , d:: AbstractODERelaxIntegrator , p:: ODERelaxProb )
262265
263- local_ode_storage = DBB. get (d, DBB. LocalIntegrator ())
266+ local_prob_storage = get (d, LocalIntegrator ())
267+ np = get (d, ParameterNumber ())
268+ nx = get (d, StateNumber ())
269+ if size (local_prob_storage. x0, 1 ) != nx + nx* np
270+ local_prob_storage. x0 = zeros (nx + nx* np)
271+ end
272+ for i = 1 : nx
273+ local_prob_storage. x0[i] = local_prob_storage. x0duals[i]. value
274+ for j = 1 : np
275+ local_prob_storage. x0[(nx + j + (i- 1 )* np)] = local_prob_storage. x0duals[i]. partials[j]
276+ end
277+ end
264278
265- np = DBB. get (d, DBB. ParameterNumber ())
266- nx = DBB. get (d, DBB. StateNumber)
267- DBB. getall! (local_prob_storage. p, d, DBB. ParameterValue ())
268- local_prob_storage. pduals .= seed_duals (local_prob_storage. p, 1 : np)
279+ local_prob_storage. sensitivity_problem = remake (local_prob_storage. sensitivity_problem,
280+ u0 = local_prob_storage. x0,
281+ p = local_prob_storage. p)
269282
270- initial_condition! (local_prob_storage. x0duals, d, d. local_problem_storage. pduals)
283+ solution = solve (local_prob_storage. sensitivity_problem,
284+ local_prob_storage. integrator,
285+ saveat = local_prob_storage. user_t,
286+ abstol = local_prob_storage. abs_tol,
287+ tstops = local_prob_storage. user_t,
288+ reltol = local_prob_storage. rel_tol)
271289
272- if ! DBB. get (t. integrator, DBB. LocalSensitivityOn ())
273- if length (local_prob_storage. x0local) != d. nx
274- resize! (local_prob_storage. x0local, d. nx)
275- end
290+ x, dxdp = extract_local_sensitivities (solution)
291+
292+ new_length = size (x, 2 )
293+ prior_length = length (local_prob_storage. integrator_t)
294+ if new_length == prior_length
295+ local_prob_storage. integrator_t .= solution. t
276296 else
277- if length (local_prob_storage. x0local) != nx* (np + 1 )
278- resize! (local_prob_storage. x0local, nx* (np + 1 ))
297+ local_prob_storage. x = zeros (nx, new_length)
298+ for i = 1 : np
299+ local_prob_storage. dxdp[i] = zeros (nx, new_length)
279300 end
301+ local_prob_storage. integrator_t = solution. t
302+ end
303+ local_prob_storage. x .= x
304+ for i = 1 : np
305+ local_prob_storage. dxdp[i] .= dxdp[i]
280306 end
281307
282- for i = 1 : nx
283- local_prob_storage. x0local[i] = d. local_prob_storage. x0duals[i]. value
284- if DBB. get (t. integrator, DBB. LocalSensitivityOn ())
285- for j = 1 : np
286- local_prob_storage. x0local[(nx + j + (i- 1 )* np)] = local_prob_storage. x0duals[i]. partials[j]
287- end
288- end
308+ return solution. t
309+ end
310+
311+ function integrate! (:: Val{false} , d:: AbstractODERelaxIntegrator , p:: ODERelaxProb )
312+
313+ local_prob_storage = get (d, LocalIntegrator ())
314+ np = get (d, ParameterNumber ())
315+ nx = get (d, StateNumber ())
316+
317+ if size (local_prob_storage. x0, 1 ) != nx
318+ local_prob_storage. x0 = zeros (nx)
289319 end
290- local_prob_storage. problem = remake (local_problem_storage. problem,
291- u0 = local_problem_storage. x0local,
292- p = local_prob_storage. p)
293-
294- if ~ isempty (local_prob_storage. user_t)
295- solution = solve (local_prob_storage. problem,
296- local_prob_storage. integrator,
297- saveat = local_prob_storage. user_t,
298- abstol = local_prob_storage. abs_tol,
299- tstops = local_prob_storage. user_t,
300- adaptive = false ,
301- reltol = local_prob_storage. rel_tol)
302- else
303- solution = solve (local_prob_storage. problem,
304- local_prob_storage. integrator,
305- abstol = local_prob_storage. abs_tol,
306- reltol = local_prob_storage. rel_tol)
320+ for i = 1 : nx
321+ local_prob_storage. x0[i] = local_prob_storage. x0duals[i]. value
307322 end
308323
309- new_length = length (solution. t)
310- if DBB. get (d, DBB. LocalSensitivityOn ())
311- x, dxdp = extract_local_sensitivities (solution)
312- else
313- x = solution. u
314- end
324+ local_prob_storage. ode_problem = remake (local_prob_storage. ode_problem,
325+ u0 = local_prob_storage. x0,
326+ p = local_prob_storage. p)
315327
316- resize! (local_prob_storage. pode_x, nx, new_length)
317- resize! (local_prob_storage. integrator_t, new_length)
328+ solution = solve (local_prob_storage. ode_problem,
329+ local_prob_storage. integrator,
330+ saveat = local_prob_storage. user_t,
331+ abstol = local_prob_storage. abs_tol,
332+ tstops = local_prob_storage. user_t,
333+ reltol = local_prob_storage. rel_tol)
334+
335+ x = solution. u
336+
337+ new_length = length (x)
318338 prior_length = length (local_prob_storage. integrator_t)
319339 if new_length == prior_length
320340 local_prob_storage. integrator_t .= solution. t
321341 else
342+ local_prob_storage. x = zeros (nx, new_length)
322343 local_prob_storage. integrator_t = solution. t
323344 end
324-
325345 for i = 1 : new_length
326- local_prob_storage. pode_x[:,i] .= x[i]
327- end
328- if DBB. get (d, DBB. LocalSensitivityOn ())
329- for i = 1 : np
330- resize! (local_prob_storage. pode_dxdp[i], nx, new_length)
331- local_prob_storage. pode_dxdp[i] .= dxdp[i]
332- end
346+ local_prob_storage. x[:,i] .= x[i]
333347 end
334348
349+ return solution. t
350+ end
351+
352+ function integrate! (d:: AbstractODERelaxIntegrator , p:: ODERelaxProb )
353+
354+ local_prob_storage = get (d, LocalIntegrator ()):: ODELocalIntegrator
355+
356+ getall! (local_prob_storage. p, d, ParameterValue ())
357+ local_prob_storage. pduals .= seed_duals (local_prob_storage. p)
358+ local_prob_storage. x0duals = p. x0 (d. local_problem_storage. pduals)
359+
360+ solution_t = integrate! (Val (get (d, LocalSensitivityOn ())), d, p)
361+
335362 empty! (local_prob_storage. local_t_dict_flt)
336363 empty! (local_prob_storage. local_t_dict_indx)
337364
338- for (tindx, t) in enumerate (solution . t )
365+ for (tindx, t) in enumerate (solution_t )
339366 local_prob_storage. local_t_dict_flt[t] = tindx
340367 end
341368
342369 if ! isempty (local_prob_storage. user_t)
343370 next_support_time = local_prob_storage. user_t[1 ]
344371 supports_left = length (local_prob_storage. user_t)
345372 loc_count = 1
346- for (tindx, t) in enumerate (solution . t )
373+ for (tindx, t) in enumerate (solution_t )
347374 if t == next_support_time
348375 local_prob_storage. local_t_dict_indx[loc_count] = tindx
349376 loc_count += 1
@@ -357,3 +384,37 @@ function integrate!(d::AbstractODERelaxIntegrator, p::ODERelaxProb)
357384
358385 return nothing
359386end
387+
388+
389+ function get_val_loc_local (t:: AbstractODERelaxIntegrator , index:: Int64 , time:: Float64 )
390+
391+ local_prob_storage = get (t, LocalIntegrator ()):: ODELocalIntegrator
392+
393+ (index <= 0 && time == - Inf ) && error (" Must set either index or time." )
394+ if index > 0
395+ return local_prob_storage. local_t_dict_indx[index]
396+ end
397+ local_prob_storage. local_t_dict_flt[time]
398+ end
399+
400+ function get (out:: Vector{Float64} , t:: AbstractODERelaxIntegrator , v:: Value )
401+ local_prob_storage = get (t, LocalIntegrator ()):: ODELocalIntegrator
402+ val_loc = get_val_loc_local (t, v. index, v. time)
403+ out .= local_prob_storage. x[:, val_loc]
404+ return
405+ end
406+
407+
408+ function getall! (out:: Array{Float64,2} , t:: AbstractODERelaxIntegrator , v:: Value )
409+ local_prob_storage = get (t, LocalIntegrator ()):: ODELocalIntegrator
410+ copyto! (out, local_prob_storage. x)
411+ return
412+ end
413+
414+ function getall! (out:: Vector{Array{Float64,2}} , t:: AbstractODERelaxIntegrator , g:: Gradient{Nominal} )
415+ local_prob_storage = get (t, LocalIntegrator ()):: ODELocalIntegrator
416+ for i = 1 : get (t, ParameterNumber ())
417+ copyto! (out[i], local_prob_storage. dxdp[i])
418+ end
419+ return
420+ end
0 commit comments