Skip to content

Commit 3ac87fb

Browse files
committed
Fix double buffered kernel
1 parent e6d87a8 commit 3ac87fb

1 file changed

Lines changed: 20 additions & 11 deletions

File tree

src/nhs_grid.jl

Lines changed: 20 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -404,7 +404,7 @@ end
404404

405405
local_points = @localmem Int32 MAX
406406
local_neighbor_coords = @localmem eltype(system_coords) (ndims(neighborhood_search), MAX)
407-
407+
408408
next_local_points = @localmem Int32 MAX
409409
next_local_neighbor_coords = @localmem eltype(system_coords) (ndims(neighborhood_search), MAX)
410410

@@ -421,28 +421,36 @@ end
421421

422422
@inline function stage!(local_points, local_neighbor_coords, neighbor_cell)
423423
points_view = points_in_cell(neighbor_cell, neighborhood_search)
424-
n_particles_in_neighbor_cell = length(points_view)
424+
n_particles_in_neighbor_cell_ = length(points_view)
425425

426426
# First use all threads to load the neighbors into local memory in parallel
427-
if particleidx <= n_particles_in_neighbor_cell
427+
if particleidx <= n_particles_in_neighbor_cell_
428428
@inbounds p = local_points[particleidx] = points_view[particleidx]
429429
for d in 1:ndims(neighborhood_search)
430430
@inbounds local_neighbor_coords[d, particleidx] = neighbor_system_coords[d, p]
431431
end
432432
end
433-
return n_particles_in_neighbor_cell
433+
return n_particles_in_neighbor_cell_
434434
end
435435

436-
neighborhood = neighboring_cells(cell, neighbor_search)
437-
(neighbor_cell, state) = iterate(neighborhood)
436+
neighborhood = neighboring_cells(cell, neighborhood_search)
437+
# (neighbor_cell, state) = iterate(neighborhood)
438+
neighbor_cell = first(neighborhood)
438439

439440
n_particles_in_neighbor_cell = stage!(local_points, local_neighbor_coords, Tuple(neighbor_cell))
440441
@synchronize()
441442

442-
while true
443-
next = iterate(neighborhood, state)
444-
if next !== nothing
445-
(next_neighbor_cell, state) = next
443+
for neighbor_ in 1:length(neighborhood)
444+
neighbor_cell = @inbounds neighborhood[neighbor_]
445+
446+
# while true
447+
# next = iterate(neighborhood, state)
448+
# if next !== nothing
449+
# n_particles_in_neighbor_cell = stage!(local_points, local_neighbor_coords, Tuple(neighbor_cell))
450+
# @synchronize
451+
if neighbor_ < length(neighborhood)
452+
next_neighbor_cell = neighborhood[neighbor_ + 1]
453+
# (next_neighbor_cell, state) = next
446454
next_n_particles_in_neighbor_cell = stage!(next_local_points, next_local_neighbor_coords, Tuple(next_neighbor_cell))
447455
end
448456

@@ -468,7 +476,8 @@ end
468476
end
469477
end
470478
end
471-
next === nothing && break
479+
# next === nothing && break
480+
neighbor_ >= length(neighborhood) && break
472481
@synchronize()
473482
# swap variables
474483
n_particles_in_neighbor_cell = next_n_particles_in_neighbor_cell

0 commit comments

Comments
 (0)