@@ -2,30 +2,144 @@ using SciMLBase
22
33export IndependentlyLinearizedSolution
44
5+
"""
    CachePool(T, alloc; thread_safe = true)

A pool of reusable objects of type `T`.  The pool starts out empty; objects are
created on demand by calling the zero-argument function `alloc`, and are kept
for re-use once they are handed back (in our case these are typically `u`
vectors that live until the solve is finished).

The `thread_safe` keyword is recorded in the type as `Val{true}` / `Val{false}`
(see `ThreadSafeCachePool` and `ThreadUnsafeCachePool`).  The thread-safe
flavour takes `lock` around every acquire and release; pass
`thread_safe = false` to skip the locking when the pool is only ever touched
from a single task.

`acquire!()` and `release!()` may be called by hand, but most code should use
`@with_cache`, which pairs the two calls lexically:

```julia
us = CachePool(Vector{S}, () -> Vector{S}(undef, num_us); thread_safe=false)
@with_cache us u_prev begin
    @with_cache us u_next begin
        # perform tasks with these two `u` vectors
    end
end
```

!!! warning "Escaping values"
    A value must not be used after it has been released: the pool is free to
    hand the very same memory to the next consumer.  Never let an acquired
    value escape its `@with_cache` block or outlive its `release!()`.
"""
mutable struct CachePool{T, THREAD_SAFE}
    # Objects that are currently idle and ready to be handed out again
    const pool::Vector{T}
    # Zero-argument factory used when `pool` is empty
    const alloc::Function
    lock::ReentrantLock
    # Bookkeeping counters, both start at zero
    num_allocated::Int
    num_acquired::Int

    function CachePool(T, alloc::F; thread_safe::Bool = true) where {F}
        safety_tag = Val{thread_safe}
        idle = T[]
        return new{T, safety_tag}(idle, alloc, ReentrantLock(), 0, 0)
    end
end
const ThreadSafeCachePool{T} = CachePool{T, Val{true}}
const ThreadUnsafeCachePool{T} = CachePool{T, Val{false}}
48+
"""
    acquire!(cache::CachePool)

Return an element from the cache pool, re-using an idle one when available and
calling `cache.alloc()` otherwise.  The result of `cache.alloc()` is asserted
to be of the pool's element type `T`.

The optional `_dummy` argument carries no information; it exists so that the
locking methods can dispatch to this lock-free implementation.

The counters are only updated once a value has actually been obtained, so an
allocator that throws leaves `num_allocated` and `num_acquired` untouched.
"""
Base.@inline function acquire!(cache::CachePool{T}, _dummy = nothing) where {T}
    if isempty(cache.pool)
        # Allocate first: if `alloc` throws, the bookkeeping must not claim
        # that something was handed out.
        val = cache.alloc()::T
        cache.num_allocated += 1
    else
        val = pop!(cache.pool)
    end
    cache.num_acquired += 1
    return val
end
63+
"""
    release!(cache::CachePool, val)

Hand `val` back to the cache pool so that a later `acquire!` can re-use it, and
decrement the count of outstanding acquisitions.  The updated count is returned.

The optional `_dummy` argument carries no information; it only exists to steer
dispatch.
"""
Base.@inline function release!(cache::CachePool, val, _dummy = nothing)
    push!(cache.pool, val)
    outstanding = cache.num_acquired - 1
    cache.num_acquired = outstanding
    return outstanding
end
73+
# `true` when every value handed out by `acquire!` has been given back, i.e. the
# outstanding-acquisition counter is back at zero.  `_dummy` only steers dispatch.
is_fully_released(cache::CachePool, _dummy = nothing) = iszero(cache.num_acquired)
77+
# Locking front-ends for thread-safe pools.  Each one holds `cache.lock` while it
# forwards to the lock-free method; passing the extra `nothing` argument is what
# selects that method instead of recursing into the one-argument form.
function acquire!(cache::ThreadSafeCachePool)
    return lock(cache.lock) do
        acquire!(cache, nothing)
    end
end
function release!(cache::ThreadSafeCachePool, val)
    return lock(cache.lock) do
        release!(cache, val, nothing)
    end
end
function is_fully_released(cache::ThreadSafeCachePool)
    return lock(cache.lock) do
        is_fully_released(cache, nothing)
    end
end
82+
"""
    @with_cache cache name body

Acquire a value from `cache`, bind it to `name` in the caller's scope, run
`body`, and release the value again when `body` finishes, whether normally or
by throwing.  The macro evaluates to the value of `body`.

The `cache` expression is evaluated exactly once, and the object that is
released is always the one that was acquired, even if `body` rebinds `name`.
"""
macro with_cache(cache, name, body)
    return quote
        # Hygienic locals: evaluate the pool expression once and remember the
        # acquired object independently of the user-visible binding.
        local pool = $(esc(cache))
        local acquired = acquire!(pool)
        $(esc(name)) = acquired
        try
            $(esc(body))
        finally
            release!(pool, acquired)
        end
    end
end
93+
94+
# Bundle of three thread-unsafe cache pools, one per kind of chunk:
#   * `t_chunks`   hands out `Vector{T}` of length `chunk_size`
#   * `u_chunks`   hands out `Matrix{S}` of size `(num_derivatives + 1, chunk_size)`
#   * `time_masks` hands out `BitMatrix` of size `(num_us, chunk_size)`
# All chunks are created uninitialized (`undef`).
struct IndependentlyLinearizedSolutionChunksCache{T,S}
    t_chunks::ThreadUnsafeCachePool{Vector{T}}
    u_chunks::ThreadUnsafeCachePool{Matrix{S}}
    time_masks::ThreadUnsafeCachePool{BitMatrix}

    function IndependentlyLinearizedSolutionChunksCache{T,S}(num_us::Int, num_derivatives::Int, chunk_size::Int) where {T,S}
        # One allocator per pool; each captures the sizes given above
        new_t_chunk() = Vector{T}(undef, chunk_size)
        new_u_chunk() = Matrix{S}(undef, num_derivatives + 1, chunk_size)
        new_time_mask() = BitMatrix(undef, num_us, chunk_size)

        t_pool = CachePool(Vector{T}, new_t_chunk; thread_safe = false)
        u_pool = CachePool(Matrix{S}, new_u_chunk; thread_safe = false)
        mask_pool = CachePool(BitMatrix, new_time_mask; thread_safe = false)
        return new(t_pool, u_pool, mask_pool)
    end
end
111+
"""
    IndependentlyLinearizedSolutionChunks

Intermediate, chunked storage used while an `IndependentlyLinearizedSolution`
is being built via the `IndependentlyLinearizingCallback`.  Because the number
of timesteps a solve will generate is not known up front, data is collected in
fixed-size chunks; the chunks themselves come from a
`IndependentlyLinearizedSolutionChunksCache` so that allocations are reduced.

The type parameter `N` records the `num_derivatives` the object was built with.
"""
mutable struct IndependentlyLinearizedSolutionChunks{T, S, N}
    t_chunks::Vector{Vector{T}}
    u_chunks::Vector{Vector{Matrix{S}}}
    time_masks::Vector{BitMatrix}

    # Temporary array that gets used by `get_chunks`
    last_chunks::Vector{Matrix{S}}

    # Index of next write into the last chunk
    u_offsets::Vector{Int}
    t_offset::Int

    cache::IndependentlyLinearizedSolutionChunksCache

    function IndependentlyLinearizedSolutionChunks{T, S}(num_us::Int, num_derivatives::Int = 0,
                                                         chunk_size::Int = 512,
                                                         cache::IndependentlyLinearizedSolutionChunksCache = IndependentlyLinearizedSolutionChunksCache{T,S}(num_us, num_derivatives, chunk_size)) where {T, S}
        # Start with exactly one chunk of each kind: one `t` chunk, one `u`
        # chunk per state variable, and one time mask.
        first_t_chunk = acquire!(cache.t_chunks)
        per_u_chunks = map(_ -> [acquire!(cache.u_chunks)], 1:num_us)
        first_time_mask = acquire!(cache.time_masks)

        # `last_chunks` initially points at the first (and only) chunk of each `u`
        newest_chunks = map(first, per_u_chunks)

        return new{T,S,num_derivatives}(
            [first_t_chunk],
            per_u_chunks,
            [first_time_mask],
            newest_chunks,
            ones(Int, num_us),
            1,
            cache,
        )
    end
end
31145
@@ -44,14 +158,8 @@ function num_us(ilsc::IndependentlyLinearizedSolutionChunks)
44158 end
45159 return length (ilsc. u_chunks)
46160end
# The number of derivatives is carried in the third type parameter, so it can be
# read off the type without inspecting any chunk.
function num_derivatives(::IndependentlyLinearizedSolutionChunks{T,S,N}) where {T,S,N}
    return N
end
55163
56164function Base. isempty (ilsc:: IndependentlyLinearizedSolutionChunks )
57165 return length (ilsc. t_chunks) == 1 && ilsc. t_offset == 1
@@ -61,24 +169,25 @@ function get_chunks(ilsc::IndependentlyLinearizedSolutionChunks{T, S}) where {T,
61169 # Check if we need to allocate new `t` chunk
62170 chunksize = chunk_size (ilsc)
63171 if ilsc. t_offset > chunksize
64- push! (ilsc. t_chunks, Vector {T} (undef, chunksize ))
65- push! (ilsc. time_masks, BitMatrix (undef, length ( ilsc. u_offsets), chunksize ))
172+ push! (ilsc. t_chunks, acquire! (ilsc . cache . t_chunks ))
173+ push! (ilsc. time_masks, acquire! ( ilsc. cache . time_masks ))
66174 ilsc. t_offset = 1
67175 end
68176
69177 # Check if we need to allocate any new `u` chunks (but only for those with `u_mask`)
70178 for (u_idx, u_chunks) in enumerate (ilsc. u_chunks)
71179 if ilsc. u_offsets[u_idx] > chunksize
72- push! (u_chunks, Matrix {S} (undef, num_derivatives ( ilsc) + 1 , chunksize ))
180+ push! (u_chunks, acquire! ( ilsc. cache . u_chunks ))
73181 ilsc. u_offsets[u_idx] = 1
74182 end
183+ ilsc. last_chunks[u_idx] = u_chunks[end ]
75184 end
76185
77186 # return the last chunk for each
78187 return (
79188 ilsc. t_chunks[end ],
80189 ilsc. time_masks[end ],
81- [u_chunks[ end ] for u_chunks in ilsc. u_chunks] ,
190+ ilsc. last_chunks ,
82191 )
83192end
84193
@@ -135,16 +244,18 @@ function store!(ilsc::IndependentlyLinearizedSolutionChunks{T, S},
135244 ts, time_mask, us = get_chunks (ilsc)
136245
137246 # Store into the chunks, gated by `u_mask`
138- for u_idx in 1 : size (u, 2 )
247+ @inbounds for u_idx in 1 : size (u, 2 )
139248 if u_mask[u_idx]
140249 for deriv_idx in 1 : size (u, 1 )
141250 us[u_idx][deriv_idx, ilsc. u_offsets[u_idx]] = u[deriv_idx, u_idx]
142251 end
143252 ilsc. u_offsets[u_idx] += 1
144253 end
254+
255+ # Update our `time_mask` while we're at it
256+ time_mask[u_idx, ilsc. t_offset] = u_mask[u_idx]
145257 end
146258 ts[ilsc. t_offset] = t
147- time_mask[:, ilsc. t_offset] .= u_mask
148259 ilsc. t_offset += 1
149260end
150261
@@ -161,7 +272,7 @@ efficient `iterate()` method that can be used to reconstruct coherent views
161272of the state variables at all timepoints, as well as an efficient `sample!()`
162273method that can sample at arbitrary timesteps.
163274"""
164- mutable struct IndependentlyLinearizedSolution{T, S}
275+ mutable struct IndependentlyLinearizedSolution{T, S, N }
165276 # All timepoints, shared by all `us`
166277 ts:: Vector{T}
167278
@@ -173,28 +284,37 @@ mutable struct IndependentlyLinearizedSolution{T, S}
173284 time_mask:: BitMatrix
174285
175286 # Temporary object used during construction, will be set to `nothing` at the end.
176- ilsc:: Union{Nothing,IndependentlyLinearizedSolutionChunks{T,S}}
287+ ilsc:: Union{Nothing,IndependentlyLinearizedSolutionChunks{T,S,N}}
288+ ilsc_cache_pool:: Union{Nothing,ThreadSafeCachePool{IndependentlyLinearizedSolutionChunksCache{T,S}}}
177289end
# Helper function to create an ILS wrapped around an in-progress ILSC.
# The result starts with no timepoints, no `us` and a 0x0 time mask; `ilsc` and
# the optional `cache_pool` are stored as given.
function IndependentlyLinearizedSolution(ilsc::IndependentlyLinearizedSolutionChunks{T,S,N}, cache_pool = nothing) where {T,S,N}
    no_ts = T[]
    no_us = Matrix{S}[]
    no_mask = BitMatrix(undef, 0, 0)
    return IndependentlyLinearizedSolution{T,S,N}(no_ts, no_us, no_mask, ilsc, cache_pool)
end
# Automatically create an ILS wrapped around an ILSC from a `prob`.
#
# * `T` is the element type of `prob.tspan`.
# * `U` is the element type of `prob.u0`, or `Float64` when `prob.u0` is `nothing`;
#   in that case the number of state variables is taken to be zero.
# * When no `cache_pool` is given a fresh chunk cache is built from `chunk_size`;
#   otherwise a cache is acquired from the pool and the pool is remembered on
#   the returned solution.
#
# The chunk cache and the chunks are parameterized with the same element type
# `U`, so the two always agree, including when `prob.u0` is `nothing`.
function IndependentlyLinearizedSolution(prob::SciMLBase.AbstractDEProblem, num_derivatives = 0;
                                         cache_pool = nothing,
                                         chunk_size::Int = 512)
    T = eltype(prob.tspan)
    U = isnothing(prob.u0) ? Float64 : eltype(prob.u0)
    num_us = isnothing(prob.u0) ? 0 : length(prob.u0)
    if cache_pool === nothing
        cache = IndependentlyLinearizedSolutionChunksCache{T,U}(num_us, num_derivatives, chunk_size)
    else
        cache = acquire!(cache_pool)
    end
    chunks = IndependentlyLinearizedSolutionChunks{T,U}(num_us, num_derivatives, chunk_size, cache)
    return IndependentlyLinearizedSolution(chunks, cache_pool)
end
196316
# The number of derivatives is the third type parameter of the solution type.
function num_derivatives(::IndependentlyLinearizedSolution{T,S,N}) where {T,S,N}
    return N
end
# Number of entries in `ils.us`
function num_us(ils::IndependentlyLinearizedSolution)
    return length(ils.us)
end

# The size of a solution is the size of its time mask
function Base.size(ils::IndependentlyLinearizedSolution)
    return size(ils.time_mask)
end

# The length of a solution is its number of stored timepoints
function Base.length(ils::IndependentlyLinearizedSolution)
    return length(ils.ts)
end
@@ -222,10 +342,51 @@ function finish!(ils::IndependentlyLinearizedSolution{T, S}, return_code) where
222342 us = Vector {Matrix{S}} ()
223343 time_mask = BitMatrix (undef, 0 , 0 )
224344 else
225- ts = vcat (trim_chunk (ilsc. t_chunks, ilsc. t_offset)... )
226- time_mask = hcat (trim_chunk (ilsc. time_masks, ilsc. t_offset)... )
227- us = [hcat (trim_chunk (ilsc. u_chunks[u_idx], ilsc. u_offsets[u_idx])... )
228- for u_idx in 1 : length (ilsc. u_chunks)]
# Length of a chunk along its last dimension (elements for a vector, columns
# for a matrix).
chunk_len(chunk) = size(chunk, ndims(chunk))

# Total number of entries stored in `chunks`: every chunk but the last counts
# in full, the last one contributes `offset - 1` entries (`offset` being the
# index of the next write).
function chunks_len(chunks::Vector, offset)
    full_chunks = view(chunks, 1:length(chunks)-1)
    total = sum(chunk_len, full_chunks; init = 0)
    return total + offset - 1
end

# Copy the first `len` entries of `in` into `out`, starting right after
# position `out_offset`.
function copy_chunk!(out::Vector, in::Vector, out_offset::Int, len=chunk_len(in))
    for src_idx in 1:len
        dst_idx = src_idx + out_offset
        out[dst_idx] = in[src_idx]
    end
end

# Matrix flavour: copy the first `len` columns of `in` into `out`, starting
# right after column `out_offset`.
function copy_chunk!(out::AbstractMatrix, in::AbstractMatrix, out_offset::Int, len=chunk_len(in))
    for row in 1:size(in, 1), col in 1:len
        out[row, col + out_offset] = in[row, col]
    end
end

# Concatenate `chunks` into `out`: all chunks except the last are copied whole,
# the last one only up to (not including) `offset`.
function collapse_chunks!(out, chunks, offset::Int)
    filled = 0
    for chunk in view(chunks, 1:(length(chunks)-1))
        copy_chunk!(out, chunk, filled)
        filled += chunk_len(chunk)
    end
    return copy_chunk!(out, chunks[end], filled, offset - 1)
end
376+
377+ # Collapse t_chunks
378+ ts = Vector {T} (undef, chunks_len (ilsc. t_chunks, ilsc. t_offset))
379+ collapse_chunks! (ts, ilsc. t_chunks, ilsc. t_offset)
380+
381+ # Collapse u_chunks
382+ us = Vector {Matrix{S}} (undef, length (ilsc. u_chunks))
383+ for u_idx in 1 : length (ilsc. u_chunks)
384+ us[u_idx] = Matrix {S} (undef, size (ilsc. u_chunks[u_idx][1 ],1 ), chunks_len (ilsc. u_chunks[u_idx], ilsc. u_offsets[u_idx]))
385+ collapse_chunks! (us[u_idx], ilsc. u_chunks[u_idx], ilsc. u_offsets[u_idx])
386+ end
387+
388+ time_mask = BitMatrix (undef, size (ilsc. time_masks[1 ], 1 ), chunks_len (ilsc. time_masks, ilsc. t_offset))
389+ collapse_chunks! (time_mask, ilsc. time_masks, ilsc. t_offset)
229390 end
230391
231392 # Sanity-check lengths
@@ -245,7 +406,24 @@ function finish!(ils::IndependentlyLinearizedSolution{T, S}, return_code) where
245406 throw (ArgumentError (" Time mask must indicate same length as `us` ($(time_mask_lens) != $(us_lens) )" ))
246407 end
247408
248- # Update our struct, release the `ilsc`
409+ # Update our struct, release the `ilsc` and its caches
410+ for t_chunk in ilsc. t_chunks
411+ release! (ilsc. cache. t_chunks, t_chunk)
412+ end
413+ @assert is_fully_released (ilsc. cache. t_chunks)
414+ for u_idx in 1 : length (ilsc. u_chunks)
415+ for u_chunk in ilsc. u_chunks[u_idx]
416+ release! (ilsc. cache. u_chunks, u_chunk)
417+ end
418+ end
419+ @assert is_fully_released (ilsc. cache. u_chunks)
420+ for time_mask in ilsc. time_masks
421+ release! (ilsc. cache. time_masks, time_mask)
422+ end
423+ @assert is_fully_released (ilsc. cache. time_masks)
424+ if ils. ilsc_cache_pool != = nothing
425+ release! (ils. ilsc_cache_pool, ilsc. cache)
426+ end
249427 ils. ilsc = nothing
250428 ils. ts = ts
251429 ils. us = us
0 commit comments