Skip to content

Commit 8ce6aa1

Browse files
committed
Improve readability
1 parent 3ac87fb commit 8ce6aa1

1 file changed

Lines changed: 132 additions & 29 deletions

File tree

src/nhs_grid.jl

Lines changed: 132 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -395,79 +395,184 @@ end
395395
return nothing
396396
end
397397

398+
@inline function copy_to_localmem!(local_points, local_neighbor_coords,
399+
neighbor_cell, neighbor_system_coords,
400+
neighborhood_search, particleidx)
401+
points_view = points_in_cell(neighbor_cell, neighborhood_search)
402+
n_particles_in_neighbor_cell = length(points_view)
403+
404+
# First use all threads to load the neighbors into local memory in parallel
405+
if particleidx <= n_particles_in_neighbor_cell
406+
@inbounds p = local_points[particleidx] = points_view[particleidx]
407+
for d in 1:ndims(neighborhood_search)
408+
@inbounds local_neighbor_coords[d, particleidx] = neighbor_system_coords[d, p]
409+
end
410+
end
411+
return n_particles_in_neighbor_cell
412+
end
413+
414+
# @parallel(block) for cell in cells
415+
# for neighbor_cell in neighboring_cells
416+
# @parallel(thread) for neighbor in neighbor_cell
417+
# copy_coordinates_to_localmem(neighbor)
418+
#
419+
# # Make sure all threads finished the copying
420+
# @synchronize
421+
#
422+
# @parallel(thread) for particle in cell
423+
# for neighbor in neighbor_cell
424+
# # This uses the neighbor coordinates from the local memory
425+
# compute(point, neighbor)
426+
#
427+
# # Make sure all threads finished computing before we continue with copying
428+
# @synchronize
398429
@kernel cpu=false function foreach_neighbor_localmem(f::F, system_coords, neighbor_system_coords,
399430
neighborhood_search, cells, ::Val{MAX}, search_radius) where {F, MAX}
400431
cell_ = @index(Group)
401432
cell = @inbounds Tuple(cells[cell_])
402433
particleidx = @index(Local)
403434
@assert 1 <= particleidx <= MAX
404435

436+
# Coordinate buffer in local memory
405437
local_points = @localmem Int32 MAX
406438
local_neighbor_coords = @localmem eltype(system_coords) (ndims(neighborhood_search), MAX)
407439

408-
next_local_points = @localmem Int32 MAX
409-
next_local_neighbor_coords = @localmem eltype(system_coords) (ndims(neighborhood_search), MAX)
440+
points = points_in_cell(cell, neighborhood_search)
441+
n_particles_in_current_cell = length(points)
410442

411-
pv = points_in_cell(cell, neighborhood_search)
412-
n_particles_in_current_cell = length(pv)
443+
# Extract point coordinates if a point lies on this thread
413444
if particleidx <= n_particles_in_current_cell
414-
point = @inbounds pv[particleidx]
445+
point = @inbounds points[particleidx]
415446
point_coords = @inbounds extract_svector(system_coords, Val(ndims(neighborhood_search)),
416447
point)
417448
else
418449
point = zero(Int32)
419450
point_coords = zero(SVector{ndims(neighborhood_search), eltype(system_coords)})
420451
end
421452

422-
@inline function stage!(local_points, local_neighbor_coords, neighbor_cell)
423-
points_view = points_in_cell(neighbor_cell, neighborhood_search)
424-
n_particles_in_neighbor_cell_ = length(points_view)
453+
for neighbor_cell_ in neighboring_cells(cell, neighborhood_search)
454+
neighbor_cell = Tuple(neighbor_cell_)
455+
456+
n_particles_in_neighbor_cell = copy_to_localmem!(local_points, local_neighbor_coords,
457+
neighbor_cell, neighbor_system_coords,
458+
neighborhood_search, particleidx)
459+
460+
# Make sure all threads finished the copying
461+
@synchronize
425462

426-
# First use all threads to load the neighbors into local memory in parallel
427-
if particleidx <= n_particles_in_neighbor_cell_
428-
@inbounds p = local_points[particleidx] = points_view[particleidx]
429-
for d in 1:ndims(neighborhood_search)
430-
@inbounds local_neighbor_coords[d, particleidx] = neighbor_system_coords[d, p]
463+
# Now each thread works on one point again
464+
if particleidx <= n_particles_in_current_cell
465+
for local_neighbor in 1:n_particles_in_neighbor_cell
466+
@inbounds neighbor = local_points[local_neighbor]
467+
@inbounds neighbor_coords = extract_svector(local_neighbor_coords,
468+
Val(ndims(neighborhood_search)),
469+
local_neighbor)
470+
471+
pos_diff = point_coords - neighbor_coords
472+
distance2 = dot(pos_diff, pos_diff)
473+
474+
# TODO periodic
475+
476+
if distance2 <= search_radius^2
477+
distance = sqrt(distance2) # TODO: eventuell fastmath
478+
479+
# Inline to avoid loss of performance
480+
# compared to not using `foreach_point_neighbor`.
481+
@inline f(point, neighbor, pos_diff, distance)
482+
end
431483
end
432484
end
433-
return n_particles_in_neighbor_cell_
485+
486+
# Make sure all threads finished computing before we continue with copying
487+
@synchronize()
488+
end
489+
end
490+
491+
# @parallel(block) for cell in cells
492+
# @parallel(thread) for neighbor in first_neighbor_cell
493+
# copy_coordinates_to_localmem!(local_coords, neighbor)
494+
#
495+
# for neighbor_cell in neighboring_cells
496+
# @parallel(thread) for neighbor in neighbor_cell + 1
497+
# copy_coordinates_to_localmem!(next_local_coords, neighbor)
498+
#
499+
# # No synchronize needed. The following loop works on `local_coords`.
500+
#
501+
# @parallel(thread) for particle in cell
502+
# for neighbor in neighbor_cell
503+
# # This uses the neighbor coordinates from the local memory
504+
# compute(point, neighbor)
505+
#
506+
# # Make sure all threads finished computing before we switch variables
507+
# @synchronize
508+
# local_coords, next_local_coords = next_local_coords, local_coords
509+
@kernel cpu=false function foreach_neighbor_double_buffer(f::F, system_coords, neighbor_system_coords,
510+
neighborhood_search, cells, ::Val{MAX}, search_radius) where {F, MAX}
511+
cell_ = @index(Group)
512+
cell = @inbounds Tuple(cells[cell_])
513+
particleidx = @index(Local)
514+
@assert 1 <= particleidx <= MAX
515+
516+
# Coordinate buffer in local memory
517+
local_points = @localmem Int32 MAX
518+
local_neighbor_coords = @localmem eltype(system_coords) (ndims(neighborhood_search), MAX)
519+
520+
# Next coordinate buffer in local memory
521+
next_local_points = @localmem Int32 MAX
522+
next_local_neighbor_coords = @localmem eltype(system_coords) (ndims(neighborhood_search), MAX)
523+
524+
points = points_in_cell(cell, neighborhood_search)
525+
n_particles_in_current_cell = length(points)
526+
527+
# Extract point coordinates if a point lies on this thread
528+
if particleidx <= n_particles_in_current_cell
529+
point = @inbounds points[particleidx]
530+
point_coords = @inbounds extract_svector(system_coords, Val(ndims(neighborhood_search)),
531+
point)
532+
else
533+
point = zero(Int32)
534+
point_coords = zero(SVector{ndims(neighborhood_search), eltype(system_coords)})
434535
end
435536

436537
neighborhood = neighboring_cells(cell, neighborhood_search)
437538
# (neighbor_cell, state) = iterate(neighborhood)
438-
neighbor_cell = first(neighborhood)
539+
neighbor_cell = Tuple(first(neighborhood))
540+
541+
n_particles_in_neighbor_cell = copy_to_localmem!(local_points, local_neighbor_coords,
542+
neighbor_cell, neighbor_system_coords,
543+
neighborhood_search, particleidx)
439544

440-
n_particles_in_neighbor_cell = stage!(local_points, local_neighbor_coords, Tuple(neighbor_cell))
441545
@synchronize()
442546

443547
for neighbor_ in 1:length(neighborhood)
444-
neighbor_cell = @inbounds neighborhood[neighbor_]
548+
neighbor_cell = @inbounds Tuple(neighborhood[neighbor_])
445549

446550
# while true
447551
# next = iterate(neighborhood, state)
448552
# if next !== nothing
449-
# n_particles_in_neighbor_cell = stage!(local_points, local_neighbor_coords, Tuple(neighbor_cell))
450-
# @synchronize
553+
451554
if neighbor_ < length(neighborhood)
452-
next_neighbor_cell = neighborhood[neighbor_ + 1]
555+
next_neighbor_cell = @inbounds Tuple(neighborhood[neighbor_ + 1])
453556
# (next_neighbor_cell, state) = next
454-
next_n_particles_in_neighbor_cell = stage!(next_local_points, next_local_neighbor_coords, Tuple(next_neighbor_cell))
557+
next_n_particles_in_neighbor_cell = copy_to_localmem!(next_local_points, next_local_neighbor_coords,
558+
next_neighbor_cell, neighbor_system_coords,
559+
neighborhood_search, particleidx)
455560
end
456561

457562
# Now each thread works on one point again
458563
if particleidx <= n_particles_in_current_cell
459564
for local_neighbor in 1:n_particles_in_neighbor_cell
460565
@inbounds neighbor = local_points[local_neighbor]
461566
@inbounds neighbor_coords = extract_svector(local_neighbor_coords,
462-
Val(ndims(neighborhood_search)), local_neighbor)
567+
Val(ndims(neighborhood_search)),
568+
local_neighbor)
463569

464570
pos_diff = point_coords - neighbor_coords
465571
distance2 = dot(pos_diff, pos_diff)
466572

467573
# TODO periodic
468574

469575
if distance2 <= search_radius^2
470-
# KernelAbstractions.@print("Point $point, neighbor $neighbor with distance2 $distance2\n")
471576
distance = sqrt(distance2) # TODO: eventuell fastmath
472577

473578
# Inline to avoid loss of performance
@@ -476,17 +581,15 @@ end
476581
end
477582
end
478583
end
584+
479585
# next === nothing && break
480586
neighbor_ >= length(neighborhood) && break
481587
@synchronize()
588+
482589
# swap variables
483590
n_particles_in_neighbor_cell = next_n_particles_in_neighbor_cell
484-
temp = local_points
485-
local_points = next_local_points
486-
next_local_points = temp
487-
temp = local_neighbor_coords
488-
local_neighbor_coords = next_local_neighbor_coords
489-
next_local_neighbor_coords = temp
591+
local_points, next_local_points = next_local_points, local_points
592+
local_neighbor_coords, next_local_neighbor_coords = next_local_neighbor_coords, local_neighbor_coords
490593
end
491594
end
492595

0 commit comments

Comments
 (0)