@@ -51,24 +51,25 @@ since not sorting makes our implementation a lot faster (although less paralleli
5151 In: Computer Graphics Forum 30.1 (2011), pages 99–112.
5252 [doi: 10.1111/J.1467-8659.2010.01832.X](https://doi.org/10.1111/J.1467-8659.2010.01832.X)
5353"""
54- struct GridNeighborhoodSearch{NDIMS, ELTYPE, CL, CB, PB} <: AbstractNeighborhoodSearch
55- cell_list :: CL
56- search_radius :: ELTYPE
57- periodic_box :: PB
58- n_cells :: NTuple{NDIMS, Int} # Required to calculate periodic cell index
59- cell_size :: NTuple{NDIMS, ELTYPE} # Required to calculate cell index
60- cell_buffer :: CB # Multithreaded buffer for `update!`
61- cell_buffer_indices :: Vector{Int} # Store which entries of `cell_buffer` are initialized
62- threaded_update :: Bool
54+ struct GridNeighborhoodSearch{NDIMS, ELTYPE, CL, PB, UB} <: AbstractNeighborhoodSearch
55+ cell_list :: CL
56+ search_radius :: ELTYPE
57+ periodic_box :: PB
58+ n_cells :: NTuple{NDIMS, Int} # Required to calculate periodic cell index
59+ cell_size :: NTuple{NDIMS, ELTYPE} # Required to calculate cell index
60+ update_buffer :: UB # Multithreaded buffer for `update!`
61+ threaded_update :: Bool
6362
6463 function GridNeighborhoodSearch {NDIMS} (; search_radius = 0.0 , n_points = 0 ,
6564 periodic_box = nothing ,
6665 cell_list = DictionaryCellList {NDIMS} (),
6766 threaded_update = true ) where {NDIMS}
6867 ELTYPE = typeof (search_radius)
6968
70- cell_buffer = Array {index_type(cell_list), 2} (undef, n_points, Threads. nthreads ())
71- cell_buffer_indices = zeros (Int, Threads. nthreads ())
69+ # Create update buffer and initialize it with empty vectors
70+ update_buffer = DynamicVectorOfVectors {index_type(cell_list)} (max_outer_length = Threads. nthreads (),
71+ max_inner_length = n_points)
72+ push! (update_buffer, (NTuple{NDIMS, Int}[] for _ in 1 : Threads. nthreads ()). .. )
7273
7374 if search_radius < eps () || isnothing (periodic_box)
7475 # No periodicity
@@ -90,37 +91,28 @@ struct GridNeighborhoodSearch{NDIMS, ELTYPE, CL, CB, PB} <: AbstractNeighborhood
9091 end
9192 end
9293
93- new{NDIMS, ELTYPE, typeof (cell_list), typeof (cell_buffer),
94- typeof (periodic_box)}(cell_list, search_radius, periodic_box, n_cells,
95- cell_size, cell_buffer, cell_buffer_indices,
96- threaded_update)
94+ new{NDIMS, ELTYPE, typeof (cell_list), typeof (periodic_box),
95+ typeof (update_buffer)}(cell_list, search_radius, periodic_box, n_cells,
96+ cell_size, update_buffer, threaded_update)
9797 end
9898end
9999
100100@inline Base. ndims (:: GridNeighborhoodSearch{NDIMS} ) where {NDIMS} = NDIMS
101101
102- @inline function npoints (neighborhood_search:: GridNeighborhoodSearch )
103- return size (neighborhood_search. cell_buffer, 1 )
104- end
105-
106102function initialize! (neighborhood_search:: GridNeighborhoodSearch ,
107103 x:: AbstractMatrix , y:: AbstractMatrix )
108104 initialize_grid! (neighborhood_search, y)
109105end
110106
111- function initialize_grid! (neighborhood_search:: GridNeighborhoodSearch{NDIMS} ,
112- y:: AbstractMatrix ) where {NDIMS}
113- initialize_grid! (neighborhood_search, i -> extract_svector (y, Val (NDIMS), i))
114- end
115-
116- function initialize_grid! (neighborhood_search:: GridNeighborhoodSearch , coords_fun)
107+ function initialize_grid! (neighborhood_search:: GridNeighborhoodSearch , y:: AbstractMatrix )
117108 (; cell_list) = neighborhood_search
118109
119110 empty! (cell_list)
120111
121- for point in 1 : npoints (neighborhood_search )
112+ for point in axes (y, 2 )
122113 # Get cell index of the point's cell
123- cell = cell_coords (coords_fun (point), neighborhood_search)
114+ point_coords = extract_svector (y, Val (ndims (neighborhood_search)), point)
115+ cell = cell_coords (point_coords, neighborhood_search)
124116
125117 # Add point to corresponding cell
126118 push_cell! (cell_list, cell, point)
@@ -147,20 +139,19 @@ end
147139
148140# Modify the existing hash table by moving points into their new cells
149141function update_grid! (neighborhood_search:: GridNeighborhoodSearch , coords_fun)
150- (; cell_list, cell_buffer, cell_buffer_indices , threaded_update) = neighborhood_search
142+ (; cell_list, update_buffer , threaded_update) = neighborhood_search
151143
152- # Reset `cell_buffer` by moving all pointers to the beginning
153- cell_buffer_indices .= 0
144+ # Empty each thread's list
145+ for i in eachindex (update_buffer)
146+ emptyat! (update_buffer, i)
147+ end
154148
155149 # Find all cells containing points that now belong to another cell
156- mark_changed_cell! (neighborhood_search, cell_list, coords_fun,
157- Val (threaded_update))
158-
159- # Iterate over all marked cells and move the points into their new cells.
160- for thread in 1 : Threads. nthreads ()
161- # Only the entries `1:cell_buffer_indices[thread]` are initialized for `thread`.
162- for i in 1 : cell_buffer_indices[thread]
163- cell_index = cell_buffer[i, thread]
150+ mark_changed_cell! (neighborhood_search, cell_list, coords_fun, Val (threaded_update))
151+
152+ # Iterate over all marked cells and move the points into their new cells
153+ for j in eachindex (update_buffer)
154+ for cell_index in update_buffer[j]
164155 points = cell_list[cell_index]
165156
166157 # Find all points whose coordinates do not match this cell
207198# Otherwise, `@threaded` does not work here with Julia ARM on macOS.
208199# See https://github.com/JuliaSIMD/Polyester.jl/issues/88.
209200@inline function mark_changed_cell! (neighborhood_search, cell_index, coords_fun)
210- (; cell_list, cell_buffer, cell_buffer_indices ) = neighborhood_search
201+ (; cell_list, update_buffer ) = neighborhood_search
211202
212203 for point in cell_list[cell_index]
213204 cell = cell_coords (coords_fun (point), neighborhood_search)
216207 # cell list to store cells inside `cell`.
217208 # These can be identical (see `DictionaryCellList`).
218209 if ! is_correct_cell (cell_list, cell, cell_index)
219- # Mark this cell and continue with the next one.
220- #
221- # `cell_buffer` is preallocated,
222- # but only the entries 1:i are used for this thread.
223- i = cell_buffer_indices[Threads. threadid ()] += 1
224- cell_buffer[i, Threads. threadid ()] = cell_index
210+ # Mark this cell and continue with the next one
211+ pushat! (update_buffer, Threads. threadid (), cell_index)
225212 break
226213 end
227214 end
0 commit comments