@@ -2,30 +2,144 @@ using SciMLBase
22
33export IndependentlyLinearizedSolution
44
5+
"""
    CachePool(T, alloc; thread_safe = true)

A small pool of reusable values of type `T` (in our case, typically `u` vectors).
Values handed back to the pool are kept and handed out again, so that the same
pieces of memory can be reused until the solve is finished; `alloc()` is called
to create a new value only when the pool has nothing to hand out.

By default the pool is thread-safe: every acquire and release takes the pool's
lock. Pass `thread_safe = false` to the constructor to skip the locking, which
is faster but must only be used from a single thread at a time.

`acquire!()` and `release!()` may be called by hand, but most users will want
`@with_cache`, which pairs the two calls lexically. Example:

```julia
us = CachePool(Vector{S}, () -> Vector{S}(undef, num_us); thread_safe=false)
@with_cache us u_prev begin
    @with_cache us u_next begin
        # perform tasks with these two `u` vectors
    end
end
```

!!! warning "Escaping values"
    A value must not be used after it has been released; the memory may be
    handed to another consumer of the pool immediately. Do not let an acquired
    value escape the `@with_cache` block, or outlive its `release!()`.
"""
mutable struct CachePool{T, THREAD_SAFE}
    # Values that are currently idle and ready to be handed out again
    const pool::Vector{T}
    # Zero-argument callable that creates a fresh value
    const alloc::Function
    # Only taken by the thread-safe methods
    lock::ReentrantLock
    # Bookkeeping: values created so far, and values currently handed out
    num_alloced::Int
    num_acquired::Int

    function CachePool(T, alloc::F; thread_safe::Bool = true) where {F}
        idle_values = Vector{T}()
        # The thread-safety flag is lifted into the type so methods can dispatch on it
        safety_tag = Val{thread_safe}
        return new{T, safety_tag}(idle_values, alloc, ReentrantLock(), 0, 0)
    end
end

# Aliases selecting the locking / non-locking flavour of the pool
const ThreadSafeCachePool{T} = CachePool{T, Val{true}}
const ThreadUnsafeCachePool{T} = CachePool{T, Val{false}}
48+
"""
    acquire!(cache::CachePool)

Return an element of the cache pool, reusing a previously released one when the
pool is non-empty and calling `cache.alloc()` otherwise. The result of
`cache.alloc()` is asserted to be of the pool's element type `T`.

The optional `_dummy` argument carries no information; it exists so that the
thread-safe methods can forward to this lock-free implementation by passing an
extra `nothing`.

If `cache.alloc()` (or the type assertion on its result) throws, the pool's
counters are left untouched, so a failed allocation does not make the pool look
as if a value were still checked out.
"""
Base.@inline function acquire!(cache::CachePool{T}, _dummy = nothing) where {T}
    if isempty(cache.pool)
        # Allocate first and only then update the bookkeeping: the allocator is
        # the only step that can throw, and nothing has been handed out if it does.
        val = cache.alloc()::T
        cache.num_alloced += 1
    else
        val = pop!(cache.pool)
    end
    cache.num_acquired += 1
    return val
end
63+
"""
    release!(cache::CachePool, val)

Hand `val` back to the cache pool so that a later `acquire!` can reuse it, and
decrement the count of outstanding values. The updated count is returned.

The optional `_dummy` argument is ignored; the thread-safe method passes it to
reach this lock-free implementation.
"""
Base.@inline function release!(cache::CachePool, val, _dummy = nothing)
    push!(cache.pool, val)
    outstanding = cache.num_acquired - 1
    cache.num_acquired = outstanding
    return outstanding
end
73+
# `true` when every value handed out by `acquire!` has been given back with
# `release!`, i.e. the outstanding-value counter is zero. `_dummy` is ignored.
is_fully_released(cache::CachePool, _dummy = nothing) = iszero(cache.num_acquired)
77+
# Thread-safe pools hold `cache.lock` for the duration of each operation and then
# defer to the generic methods; the extra `nothing` argument (`_dummy`) makes the
# call dispatch to the lock-free implementation instead of recursing into these.
function acquire!(cache::ThreadSafeCachePool)
    return lock(cache.lock) do
        acquire!(cache, nothing)
    end
end

function release!(cache::ThreadSafeCachePool, val)
    return lock(cache.lock) do
        release!(cache, val, nothing)
    end
end

function is_fully_released(cache::ThreadSafeCachePool)
    return lock(cache.lock) do
        is_fully_released(cache, nothing)
    end
end
82+
"""
    @with_cache cache name body

Acquire a value from `cache`, bind it to `name` in the caller's scope, run
`body`, and release the value again in a `finally` clause so that it is handed
back even when `body` throws.

Note that the `cache` expression appears twice in the expansion (once for the
acquire and once for the release), and that whatever `name` is bound to when
`body` finishes is what gets released.
"""
macro with_cache(cache, name, body)
    # All three arguments belong to the caller, so each is escaped; `acquire!`
    # and `release!` are left unescaped on purpose.
    pool_expr = esc(cache)
    binding = esc(name)
    user_code = esc(body)
    return quote
        $binding = acquire!($pool_expr)
        try
            $user_code
        finally
            release!($pool_expr, $binding)
        end
    end
end
93+
94+
# Bundle of the three (thread-unsafe) cache pools that back an
# `IndependentlyLinearizedSolutionChunks`: one for time chunks, one for `u`
# chunks and one for time masks. All chunks produced by a given cache share the
# dimensions fixed at construction time.
struct IndependentlyLinearizedSolutionChunksCache{T,S}
    t_chunks::ThreadUnsafeCachePool{Vector{T}}
    u_chunks::ThreadUnsafeCachePool{Matrix{S}}
    time_masks::ThreadUnsafeCachePool{BitMatrix}

    function IndependentlyLinearizedSolutionChunksCache{T,S}(num_us::Int, num_derivatives::Int, chunk_size::Int) where {T,S}
        # Allocators for fresh, uninitialized chunks; invoked by the pools on demand
        new_t_chunk() = Vector{T}(undef, chunk_size)
        # One row for the primal plus one per derivative
        new_u_chunk() = Matrix{S}(undef, num_derivatives+1, chunk_size)
        new_time_mask() = BitMatrix(undef, num_us, chunk_size)

        t_pool = CachePool(Vector{T}, new_t_chunk; thread_safe=false)
        u_pool = CachePool(Matrix{S}, new_u_chunk; thread_safe=false)
        mask_pool = CachePool(BitMatrix, new_time_mask; thread_safe=false)
        return new(t_pool, u_pool, mask_pool)
    end
end
111+
"""
    IndependentlyLinearizedSolutionChunks

When constructing an `IndependentlyLinearizedSolution` via the `IndependentlyLinearizingCallback`,
we use this intermediate structure to reduce allocations and collect the unknown number of
timesteps that the solve will generate.

Chunk storage is acquired from `cache` rather than allocated directly, and the type
parameter `N` records the `num_derivatives` value given to the constructor.
"""
mutable struct IndependentlyLinearizedSolutionChunks{T, S, N}
    t_chunks::Vector{Vector{T}}
    u_chunks::Vector{Vector{Matrix{S}}}
    time_masks::Vector{BitMatrix}

    # Temporary array that gets used by `get_chunks`
    last_chunks::Vector{Matrix{S}}

    # Index of next write into the last chunk
    u_offsets::Vector{Int}
    t_offset::Int

    # Pools that the chunks above were acquired from
    cache::IndependentlyLinearizedSolutionChunksCache

    function IndependentlyLinearizedSolutionChunks{T, S}(num_us::Int, num_derivatives::Int = 0,
                                                         chunk_size::Int = 512,
                                                         cache::IndependentlyLinearizedSolutionChunksCache = IndependentlyLinearizedSolutionChunksCache{T,S}(num_us, num_derivatives, chunk_size)) where {T, S}
        # Start with a single chunk of each kind, taken from the cache
        first_t_chunks = [acquire!(cache.t_chunks)]
        per_u_chunks = [[acquire!(cache.u_chunks)] for _ in 1:num_us]
        first_time_masks = [acquire!(cache.time_masks)]

        # The "last" chunk of every `u` is initially its one and only chunk
        newest_chunks = [per_u_chunks[u_idx][1] for u_idx in 1:num_us]

        # Every write position starts at the beginning of its chunk
        write_positions = fill(1, num_us)
        return new{T,S,num_derivatives}(
            first_t_chunks,
            per_u_chunks,
            first_time_masks,
            newest_chunks,
            write_positions,
            1,
            cache,
        )
    end
end
31145
@@ -44,14 +158,8 @@ function num_us(ilsc::IndependentlyLinearizedSolutionChunks)
44158 end
45159 return length (ilsc. u_chunks)
46160end
# The derivative count is stored in the type parameter `N`, so it is read off
# the type rather than from the (possibly already released) chunks.
function num_derivatives(::IndependentlyLinearizedSolutionChunks{T,S,N}) where {T,S,N}
    return N
end
47162
48- function num_derivatives (ilsc:: IndependentlyLinearizedSolutionChunks )
49- # If we've been finalized, just return `0` (which means only the primal)
50- if isempty (ilsc. t_chunks)
51- return 0
52- end
53- return size (first (first (ilsc. u_chunks)), 1 ) - 1
54- end
55163
56164function Base. isempty (ilsc:: IndependentlyLinearizedSolutionChunks )
57165 return length (ilsc. t_chunks) == 1 && ilsc. t_offset == 1
@@ -61,24 +169,25 @@ function get_chunks(ilsc::IndependentlyLinearizedSolutionChunks{T, S}) where {T,
61169 # Check if we need to allocate new `t` chunk
62170 chunksize = chunk_size (ilsc)
63171 if ilsc. t_offset > chunksize
64- push! (ilsc. t_chunks, Vector {T} (undef, chunksize ))
65- push! (ilsc. time_masks, BitMatrix (undef, length ( ilsc. u_offsets), chunksize ))
172+ push! (ilsc. t_chunks, acquire! (ilsc . cache . t_chunks ))
173+ push! (ilsc. time_masks, acquire! ( ilsc. cache . time_masks ))
66174 ilsc. t_offset = 1
67175 end
68176
69177 # Check if we need to allocate any new `u` chunks (but only for those with `u_mask`)
70178 for (u_idx, u_chunks) in enumerate (ilsc. u_chunks)
71179 if ilsc. u_offsets[u_idx] > chunksize
72- push! (u_chunks, Matrix {S} (undef, num_derivatives ( ilsc) + 1 , chunksize ))
180+ push! (u_chunks, acquire! ( ilsc. cache . u_chunks ))
73181 ilsc. u_offsets[u_idx] = 1
74182 end
183+ ilsc. last_chunks[u_idx] = u_chunks[end ]
75184 end
76185
77186 # return the last chunk for each
78187 return (
79188 ilsc. t_chunks[end ],
80189 ilsc. time_masks[end ],
81- [u_chunks[ end ] for u_chunks in ilsc. u_chunks] ,
190+ ilsc. last_chunks ,
82191 )
83192end
84193
@@ -135,16 +244,18 @@ function store!(ilsc::IndependentlyLinearizedSolutionChunks{T, S},
135244 ts, time_mask, us = get_chunks (ilsc)
136245
137246 # Store into the chunks, gated by `u_mask`
138- for u_idx in 1 : size (u, 2 )
247+ @inbounds for u_idx in 1 : size (u, 2 )
139248 if u_mask[u_idx]
140249 for deriv_idx in 1 : size (u, 1 )
141250 us[u_idx][deriv_idx, ilsc. u_offsets[u_idx]] = u[deriv_idx, u_idx]
142251 end
143252 ilsc. u_offsets[u_idx] += 1
144253 end
254+
255+ # Update our `time_mask` while we're at it
256+ time_mask[u_idx, ilsc. t_offset] = u_mask[u_idx]
145257 end
146258 ts[ilsc. t_offset] = t
147- time_mask[:, ilsc. t_offset] .= u_mask
148259 ilsc. t_offset += 1
149260end
150261
@@ -161,7 +272,7 @@ efficient `iterate()` method that can be used to reconstruct coherent views
161272of the state variables at all timepoints, as well as an efficient `sample!()`
162273method that can sample at arbitrary timesteps.
163274"""
164- mutable struct IndependentlyLinearizedSolution{T, S}
275+ mutable struct IndependentlyLinearizedSolution{T, S, N }
165276 # All timepoints, shared by all `us`
166277 ts:: Vector{T}
167278
@@ -173,33 +284,42 @@ mutable struct IndependentlyLinearizedSolution{T, S}
173284 time_mask:: BitMatrix
174285
175286 # Temporary object used during construction, will be set to `nothing` at the end.
176- ilsc:: Union{Nothing,IndependentlyLinearizedSolutionChunks{T,S}}
287+ ilsc:: Union{Nothing,IndependentlyLinearizedSolutionChunks{T,S,N}}
288+ ilsc_cache_pool:: Union{Nothing,ThreadSafeCachePool{IndependentlyLinearizedSolutionChunksCache{T,S}}}
177289end
# Helper function to create an ILS wrapped around an in-progress ILSC.
# The solution starts out with empty `ts`, `us` and `time_mask`; `cache_pool`
# (default `nothing`) is stored alongside the chunks.
function IndependentlyLinearizedSolution(ilsc::IndependentlyLinearizedSolutionChunks{T,S,N}, cache_pool = nothing) where {T,S,N}
    empty_ts = Vector{T}()
    empty_us = Vector{Matrix{S}}()
    empty_mask = BitMatrix(undef, 0, 0)
    return IndependentlyLinearizedSolution{T,S,N}(empty_ts, empty_us, empty_mask, ilsc, cache_pool)
end
# Automatically create an ILS wrapped around an ILSC from a `prob`.
#
# - `num_derivatives`: forwarded to the chunk storage (default `0`).
# - `cache_pool`: when `nothing`, a private chunk cache is created for this
#   solution. Otherwise a chunk cache is acquired from the pool, and the pool is
#   stored in the returned solution so that `finish!` can hand the cache back.
# - `chunk_size`: forwarded to the chunk storage; it sizes the chunks only when
#   a fresh cache is created here (a cache taken from `cache_pool` already has
#   its allocators fixed).
#
# The time type is `eltype(prob.tspan)`. The state element type is
# `eltype(prob.u0)`, or `Float64` with zero states when `prob.u0` is `nothing`;
# the same element type is used for both the cache and the chunks.
function IndependentlyLinearizedSolution(prob::SciMLBase.AbstractDEProblem, num_derivatives = 0;
                                         cache_pool = nothing,
                                         chunk_size::Int = 512)
    T = eltype(prob.tspan)
    has_u0 = !isnothing(prob.u0)
    U = has_u0 ? eltype(prob.u0) : Float64
    num_us = has_u0 ? length(prob.u0) : 0
    cache = if cache_pool === nothing
        IndependentlyLinearizedSolutionChunksCache{T,U}(num_us, num_derivatives, chunk_size)
    else
        acquire!(cache_pool)
    end
    chunks = IndependentlyLinearizedSolutionChunks{T,U}(num_us, num_derivatives, chunk_size, cache)
    # Keep hold of the pool: dropping it here would mean the acquired cache is
    # never released back when the solution is finished.
    return IndependentlyLinearizedSolution(chunks, cache_pool)
end
196316
197- num_derivatives (ils :: IndependentlyLinearizedSolution ) = ! isempty (ils . us) ? size ( first (ils . us), 1 ) : 0
# Number of derivatives, read from the type parameter `N`
function num_derivatives(::IndependentlyLinearizedSolution{T,S,N}) where {T,S,N}
    return N
end

# Number of state variables: one entry of `us` per state
function num_us(ils::IndependentlyLinearizedSolution)
    return length(ils.us)
end

# Size of the solution is the size of its time mask
function Base.size(ils::IndependentlyLinearizedSolution)
    return size(ils.time_mask)
end

# Length of the solution is the number of stored timepoints
function Base.length(ils::IndependentlyLinearizedSolution)
    return length(ils.ts)
end
201321
202- function finish! (ils:: IndependentlyLinearizedSolution )
322+ function finish! (ils:: IndependentlyLinearizedSolution{T,S} ) where {T,S}
203323 function trim_chunk (chunks:: Vector , offset)
204324 chunks = [chunk for chunk in chunks]
205325 if eltype (chunks) <: AbstractVector
@@ -216,10 +336,52 @@ function finish!(ils::IndependentlyLinearizedSolution)
216336 end
217337
218338 ilsc = ils. ilsc:: IndependentlyLinearizedSolutionChunks
219- ts = vcat (trim_chunk (ilsc. t_chunks, ilsc. t_offset)... )
220- time_mask = hcat (trim_chunk (ilsc. time_masks, ilsc. t_offset)... )
221- us = [hcat (trim_chunk (ilsc. u_chunks[u_idx], ilsc. u_offsets[u_idx])... )
222- for u_idx in 1 : length (ilsc. u_chunks)]
339+
340+ chunk_len (chunk) = size (chunk, ndims (chunk))
341+ function chunks_len (chunks:: Vector , offset)
342+ len = 0
343+ for chunk_idx in 1 : length (chunks)- 1
344+ len += chunk_len (chunks[chunk_idx])
345+ end
346+ return len + offset - 1
347+ end
348+
349+ function copy_chunk! (out:: Vector , in:: Vector , out_offset:: Int , len= chunk_len (in))
350+ for idx in 1 : len
351+ out[idx+ out_offset] = in[idx]
352+ end
353+ end
354+ function copy_chunk! (out:: AbstractMatrix , in:: AbstractMatrix , out_offset:: Int , len= chunk_len (in))
355+ for zdx in 1 : size (in, 1 )
356+ for idx in 1 : len
357+ out[zdx, idx+ out_offset] = in[zdx, idx]
358+ end
359+ end
360+ end
361+
362+ function collapse_chunks! (out, chunks, offset:: Int )
363+ write_offset = 0
364+ for chunk_idx in 1 : (length (chunks)- 1 )
365+ chunk = chunks[chunk_idx]
366+ copy_chunk! (out, chunk, write_offset)
367+ write_offset += chunk_len (chunk)
368+ end
369+ copy_chunk! (out, chunks[end ], write_offset, offset- 1 )
370+ end
371+
372+ # Collapse t_chunks
373+ ts = Vector {T} (undef, chunks_len (ilsc. t_chunks, ilsc. t_offset))
374+ collapse_chunks! (ts, ilsc. t_chunks, ilsc. t_offset)
375+
376+ # Collapse u_chunks
377+ us = Vector {Matrix{S}} (undef, length (ilsc. u_chunks))
378+ for u_idx in 1 : length (ilsc. u_chunks)
379+ us[u_idx] = Matrix {S} (undef, size (ilsc. u_chunks[u_idx][1 ],1 ), chunks_len (ilsc. u_chunks[u_idx], ilsc. u_offsets[u_idx]))
380+ collapse_chunks! (us[u_idx], ilsc. u_chunks[u_idx], ilsc. u_offsets[u_idx])
381+ end
382+
383+ time_mask = BitMatrix (undef, size (ilsc. time_masks[1 ], 1 ), chunks_len (ilsc. time_masks, ilsc. t_offset))
384+ collapse_chunks! (time_mask, ilsc. time_masks, ilsc. t_offset)
223385
224386 # Sanity-check lengths
225387 if length (ts) != size (time_mask, 2 )
@@ -238,7 +400,24 @@ function finish!(ils::IndependentlyLinearizedSolution)
238400 throw (ArgumentError (" Time mask must indicate same length as `us` ($(time_mask_lens) != $(us_lens) )" ))
239401 end
240402
241- # Update our struct, release the `ilsc`
403+ # Update our struct, release the `ilsc` and its caches
404+ for t_chunk in ilsc. t_chunks
405+ release! (ilsc. cache. t_chunks, t_chunk)
406+ end
407+ @assert is_fully_released (ilsc. cache. t_chunks)
408+ for u_idx in 1 : length (ilsc. u_chunks)
409+ for u_chunk in ilsc. u_chunks[u_idx]
410+ release! (ilsc. cache. u_chunks, u_chunk)
411+ end
412+ end
413+ @assert is_fully_released (ilsc. cache. u_chunks)
414+ for time_mask in ilsc. time_masks
415+ release! (ilsc. cache. time_masks, time_mask)
416+ end
417+ @assert is_fully_released (ilsc. cache. time_masks)
418+ if ils. ilsc_cache_pool != = nothing
419+ release! (ils. ilsc_cache_pool, ilsc. cache)
420+ end
242421 ils. ilsc = nothing
243422 ils. ts = ts
244423 ils. us = us
0 commit comments