Skip to content

Commit 4ffa6c4

Browse files
authored
refactor: neighbours are now exposed by the neighbour lists themselves (#60)
**Summary** This PR moves all neighbour-list related logic to the neighour subpackage, so that all other modules are agnostic of the underlying implementation. This is accomplished by making each neighbour list type callable (with `system` and `i` are arguments). The call to the instances return generators, that when iterated upon return the indices of all the particles adjacent to `i`. The methods to compute energies leverage these newly defined methods instead of encoding the neighbour list logic themselves. **Implementation discussion** For linked lists, I wasn't able to return a simple generator in the dedicated calling function. It looks like Julia doesn't have some of the features that would be required to create custom generators (something like the `yield` keyword in Python). Instead, I created a new struct that is used to iterate over. This adds some complexity and makes the iteration process less clear, which is opposite to the goal of this PR. I believe this complexity is inherent to the linked list algorithm, and is now fully contained in the implementation of this list (rather than spread to other parts), so I'd argue that this PR still improves overall clarity. **Test plan** CI, which contains tests that rely on these lists. Locally, tests on a polymer chain. Currently, unit tests on molecules fail - need to be fixed.
1 parent 57d0ebe commit 4ffa6c4

3 files changed

Lines changed: 103 additions & 117 deletions

File tree

src/atoms.jl

Lines changed: 3 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -76,64 +76,13 @@ end
7676

7777

7878
"""
79-
Compute the energy of particle `i` by brute force (no neighbour list).
80-
81-
`compute_energy_particle(system, i, ::EmptyList)` sums interactions of particle `i`
82-
with all other particles using `compute_energy_ij`.
79+
Compute the energy of particle `i` using the provided neighbour list.
8380
"""
84-
function compute_energy_particle(system::Atoms, i, ::EmptyList)
81+
function compute_energy_particle(system::Atoms, i, neighbour_list::NeighbourList)
8582
energy_i = zero(typeof(system.density))
8683
position_i = get_position(system, i)
87-
for (j, _) in enumerate(system)
84+
for j in neighbour_list(system, i)
8885
energy_i += compute_energy_ij(system, i, j, position_i)
8986
end
9087
return energy_i
9188
end
92-
93-
94-
"""
95-
Compute the energy of particle `i` using a `CellList` neighbour list.
96-
97-
This restricts pair evaluations to particles in neighbouring cells of `i`.
98-
"""
99-
function compute_energy_particle(system::Atoms, i, neighbour_list::CellList)
100-
energy_i = zero(typeof(system.density))
101-
position_i = get_position(system, i)
102-
c = get_cell_index(position_i, neighbour_list)
103-
neighbour_cells = neighbour_list.neighbour_cells[c]
104-
105-
# Scan the neighbourhood of cell mc (including itself)
106-
@inbounds for c2 in neighbour_cells
107-
# Scan atoms in cell c2
108-
neighbours = neighbour_list.cells[c2]
109-
@inbounds for j in neighbours
110-
energy_i += compute_energy_ij(system, i, j, position_i)
111-
end
112-
end
113-
return energy_i
114-
end
115-
116-
"""
117-
Compute the energy of particle `i` using a `LinkedList` neighbour list.
118-
119-
This variant iterates linked list heads for neighbouring cells and accumulates
120-
pair energies computed with `compute_energy_ij`.
121-
"""
122-
function compute_energy_particle(system::Atoms, i, neighbour_list::LinkedList)
123-
energy_i = zero(typeof(system.density))
124-
# Get cell of particle i
125-
position_i = get_position(system, i)
126-
c = get_cell_index(position_i, neighbour_list)
127-
neighbour_cells = neighbour_list.neighbour_cells[c]
128-
# Scan the neighbourhood of cell mc (including itself)
129-
@inbounds for c2 in neighbour_cells
130-
# Scan atoms in cell c2
131-
j = neighbour_list.head[c2]
132-
while (j != -1)
133-
energy_ij = compute_energy_ij(system, i, j, position_i)
134-
energy_i += energy_ij
135-
j = neighbour_list.list[j]
136-
end
137-
end
138-
return energy_i
139-
end

src/molecules.jl

Lines changed: 12 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -112,15 +112,15 @@ Return arrays of first indices and counts for consecutive blocks in `vec`.
112112
function get_first_and_counts(vec::Vector{Int})
113113
firsts = Int[]
114114
counts = Int[]
115-
115+
116116
# Handle empty vector case
117117
isempty(vec) && return firsts, counts
118-
118+
119119
# Initialize with first element
120120
current = vec[1]
121121
push!(firsts, 1)
122122
count = 1
123-
123+
124124
# Scan through vector
125125
@inbounds for i in 2:length(vec)
126126
if vec[i] != current
@@ -134,7 +134,7 @@ function get_first_and_counts(vec::Vector{Int})
134134
end
135135
# Add last count
136136
push!(counts, count)
137-
137+
138138
return firsts, counts
139139
end
140140

@@ -198,69 +198,18 @@ function compute_energy_ij(system::Molecules, position_i, position_j, model_ij::
198198
end
199199

200200
"""
201-
Compute particle energy by brute force (no neighbour list).
202-
203-
`compute_energy_particle(system, i, ::EmptyList)` sums interactions of particle `i`
204-
with all particles (including bonded and non-bonded contributions via helper
205-
functions). Used when no neighbour list is available.
206-
"""
207-
function compute_energy_particle(system::Molecules, i, ::EmptyList)
208-
energy = zero(typeof(system.density))
209-
position_i = system.position[i]
210-
bonds_i = system.bonds[i]
211-
@inbounds for j in eachindex(system)
212-
energy += check_compute_energy_ij(system, i, j, position_i, bonds_i)
213-
end
214-
return energy
215-
end
216-
217-
# With linked list
218-
"""
219-
Compute particle energy using a `LinkedList` neighbour list.
201+
Compute particle energy using a the provided neighbour list.
220202
221-
This variant restricts non-bonded pair evaluation to particles in neighbouring
222-
cells defined by the linked list; bonded contributions are added explicitly.
203+
Non-bonded pair energy evaluations are restricted to particles in the neighbour list;
204+
bonded contributions are added explicitly.
223205
"""
224-
function compute_energy_particle(system::Molecules, i, neighbour_list::LinkedList)
225-
energy_i = zero(typeof(system.density))
226-
# Get cell of particle i
206+
function compute_energy_particle(system::Molecules, i, neighbour_list::NeighbourList)
227207
position_i = system.position[i]
228-
c = get_cell_index(i, neighbour_list)
229-
cells = neighbour_list.neighbour_cells[c]
230-
# Scan the neighbourhood of cell mc (including itself)
231208
bonds_i = system.bonds[i]
232-
energy_i += compute_energy_bonded_i(system, i, position_i, bonds_i)
233-
@inbounds for c2 in cells
234-
# Calculate the scalar cell index of the neighbour cell (with PBC)
235-
j = neighbour_list.head[c2]
236-
while (j != -1)
237-
energy_i += check_nonbonded_compute_energy_ij(system, i, j, position_i, bonds_i)
238-
j = neighbour_list.list[j]
239-
end
240-
end
241-
return energy_i
242-
end
243209

244-
"""
245-
Compute particle energy using a `CellList` neighbour list.
246-
247-
This variant restricts non-bonded pair evaluation to particles in neighbouring
248-
cells defined by the cell list; bonded contributions are added explicitly.
249-
"""
250-
function compute_energy_particle(system::Molecules, i, neighbour_list::CellList)
251-
energy_i = zero(typeof(system.density))
252-
position_i = get_position(system, i)
253-
c = get_cell_index(i, neighbour_list)
254-
neighbour_cells = neighbour_list.neighbour_cells[c]
255-
# Scan the neighbourhood of cell mc (including itself)
256-
bonds_i = system.bonds[i]
257-
energy_i += compute_energy_bonded_i(system, i, position_i, bonds_i)
258-
@inbounds for c2 in neighbour_cells
259-
# Scan atoms in cell c2
260-
neighbours = neighbour_list.cells[c2]
261-
@inbounds for j in neighbours
262-
energy_i += check_nonbonded_compute_energy_ij(system, i, j, position_i, bonds_i)
263-
end
210+
energy_i = compute_energy_bonded_i(system, i, position_i, bonds_i)
211+
for j in neighbour_list(system, i)
212+
energy_i += check_nonbonded_compute_energy_ij(system, i, j, position_i, bonds_i)
264213
end
265214
return energy_i
266215
end
@@ -290,4 +239,4 @@ function compute_chain_correlation(system::Molecules)
290239
end
291240
end
292241
return sum(correlation_array.^2)
293-
end
242+
end

src/neighbours.jl

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,15 @@ function old_new_cell(::Particles, i, ::EmptyList)
4242
return 1, 1
4343
end
4444

45+
"""Calling an EmptyList objects return an object which can be iterated upon.
46+
47+
This iteration will return the indices of the neighbours (which for this list is all the other particles in the system).
48+
"""
49+
function (empty_list::EmptyList)(system::Particles, ::Int)
50+
return (j for j in 1:length(system))
51+
end
52+
53+
4554
"""Return the scalar cell index of particle `i` stored in `neighbour_list`.
4655
"""
4756
function get_cell_index(i::Int, neighbour_list::NeighbourList)
@@ -195,6 +204,20 @@ function old_new_cell(system::Particles, i, neighbour_list::CellList)
195204
return c, c2
196205
end
197206

207+
"""Calling a CellList objects return an object which can be iterated upon.
208+
209+
This iteration will return the indices of the neighbours of particle i.
210+
"""
211+
function (cell_list::CellList)(system::Particles, i::Int)
212+
position_i = get_position(system, i)
213+
c = get_cell_index(position_i, cell_list)
214+
neighbour_cells = cell_list.neighbour_cells[c]
215+
# Scan the neighbourhood of cell mc (including itself)
216+
# and from there scan atoms in cell c2
217+
return (j for c2 in neighbour_cells for j in @inbounds cell_list.cells[c2])
218+
end
219+
220+
198221
"""Linked-list neighbour list implementation.
199222
200223
Uses arrays `head` and `list` to store per-cell linked lists of particle indices.
@@ -281,3 +304,68 @@ function old_new_cell(system::Particles, i, neighbour_list::LinkedList)
281304
c2 = cell_index(neighbour_list, mc2)
282305
return c, c2
283306
end
307+
308+
""" This struct is used to iterate over neighbours of a Linked list
309+
"""
310+
struct LinkedIterator
311+
neighbour_cells::Vector{Int}
312+
head::Vector{Int}
313+
list::Vector{Int}
314+
end
315+
316+
# To iterate over the neighbours of a linked list, one could write the following loops
317+
#@inbounds for c in neighbour_list.neighbour_cells
318+
# j = neighbour_list.head[c]
319+
# while (j != -1)
320+
# do stuff
321+
# j = neighbour_list.list[j]
322+
# end
323+
#end
324+
# This is however impossible to rewrite as a simple generator
325+
# So we implement the following function, which uses a state to carry over the needed information
326+
function Base.iterate(neighbour_list::LinkedIterator, state=-1)
327+
# First time in
328+
if state == -1
329+
j = -1
330+
c_state = 1
331+
# The while loop is necessary, in case the first head is -1
332+
while j == -1
333+
next = iterate(neighbour_list.neighbour_cells, c_state)
334+
if next == nothing
335+
return nothing
336+
end
337+
c, c_state = next
338+
@inbounds j = neighbour_list.head[c]
339+
end
340+
else
341+
c_state, j = state
342+
@inbounds j = neighbour_list.list[j]
343+
# The while loop is necessary, in case a head is -1
344+
while j == -1
345+
next = iterate(neighbour_list.neighbour_cells, c_state)
346+
if next == nothing
347+
return nothing
348+
end
349+
c, c_state = next
350+
@inbounds j = neighbour_list.head[c]
351+
end
352+
end
353+
354+
if j == -1
355+
return nothing
356+
end
357+
state = (c_state, j)
358+
return j, state
359+
end
360+
361+
"""Calling a LinkedList objects return an object which can be iterated upon.
362+
363+
This iteration will return the indices of the neighbours of particle i.
364+
"""
365+
function (linked_list::LinkedList)(system::Particles, i::Int)
366+
position_i = get_position(system, i)
367+
c = get_cell_index(position_i, linked_list)
368+
neighbour_cells = linked_list.neighbour_cells[c]
369+
370+
return LinkedIterator(neighbour_cells, linked_list.head, linked_list.list)
371+
end

0 commit comments

Comments
 (0)