@@ -395,79 +395,184 @@ end
395395 return nothing
396396end
397397
# Stage the points of `neighbor_cell` in the given buffers.
#
# The calling thread copies at most one point: if `particleidx` does not exceed the
# number of points in `neighbor_cell`, the index of the `particleidx`-th point is written
# to `local_points[particleidx]` and its coordinates (read from `neighbor_system_coords`)
# to column `particleidx` of `local_neighbor_coords`. Otherwise nothing is written.
#
# Returns the number of points in `neighbor_cell`, regardless of whether this thread
# copied anything. This function does not synchronize; callers must do so before reading
# entries written by other threads.
@inline function copy_to_localmem!(local_points, local_neighbor_coords,
                                   neighbor_cell, neighbor_system_coords,
                                   neighborhood_search, particleidx)
    cell_points = points_in_cell(neighbor_cell, neighborhood_search)
    n_points = length(cell_points)

    # Threads beyond the number of points in this cell have nothing to load
    particleidx > n_points && return n_points

    # All remaining threads load one neighbor each into the buffers in parallel
    @inbounds begin
        neighbor = cell_points[particleidx]
        local_points[particleidx] = neighbor
        for dim in 1:ndims(neighborhood_search)
            local_neighbor_coords[dim, particleidx] = neighbor_system_coords[dim, neighbor]
        end
    end

    return n_points
end
413+
# Neighbor loop that stages one neighbor cell at a time in a local memory buffer.
# Each workgroup handles one entry of `cells`, and each thread of the group handles
# at most one point of that cell.
#
# @parallel(block) for cell in cells
#     for neighbor_cell in neighboring_cells
#         @parallel(thread) for neighbor in neighbor_cell
#             copy_coordinates_to_localmem(neighbor)
#
#         # Make sure all threads finished the copying
#         @synchronize
#
#         @parallel(thread) for point in cell
#             for neighbor in neighbor_cell
#                 # This uses the neighbor coordinates from the local memory
#                 compute(point, neighbor)
#
#         # Make sure all threads finished computing before we continue with copying
#         @synchronize
#
# `f(point, neighbor, pos_diff, distance)` is called for every pair whose squared
# distance does not exceed `search_radius^2`, where `pos_diff` is the point coordinates
# minus the neighbor coordinates.
#
# NOTE(review): the buffers hold `MAX` entries and are read with `@inbounds`, so this
# presumably requires that no cell contains more than `MAX` points -- TODO confirm.
@kernel cpu=false function foreach_neighbor_localmem(f::F, system_coords,
                                                     neighbor_system_coords,
                                                     neighborhood_search, cells,
                                                     ::Val{MAX},
                                                     search_radius) where {F, MAX}
    cell_ = @index(Group)
    cell = @inbounds Tuple(cells[cell_])
    particleidx = @index(Local)
    @assert 1 <= particleidx <= MAX

    # Coordinate buffer in local memory
    local_points = @localmem Int32 MAX
    local_neighbor_coords = @localmem eltype(system_coords) (ndims(neighborhood_search),
                                                             MAX)

    points = points_in_cell(cell, neighborhood_search)
    n_particles_in_current_cell = length(points)

    # Extract point coordinates if a point lies on this thread
    if particleidx <= n_particles_in_current_cell
        point = @inbounds points[particleidx]
        point_coords = @inbounds extract_svector(system_coords,
                                                 Val(ndims(neighborhood_search)), point)
    else
        # Idle thread: placeholder values that are never passed to `f`
        point = zero(Int32)
        point_coords = zero(SVector{ndims(neighborhood_search), eltype(system_coords)})
    end

    # Loop-invariant, so compute it once instead of once per point-neighbor pair
    search_radius2 = search_radius^2

    for neighbor_cell_ in neighboring_cells(cell, neighborhood_search)
        neighbor_cell = Tuple(neighbor_cell_)

        n_particles_in_neighbor_cell = copy_to_localmem!(local_points,
                                                         local_neighbor_coords,
                                                         neighbor_cell,
                                                         neighbor_system_coords,
                                                         neighborhood_search,
                                                         particleidx)

        # Make sure all threads finished the copying
        @synchronize()

        # Now each thread works on one point again
        if particleidx <= n_particles_in_current_cell
            for local_neighbor in 1:n_particles_in_neighbor_cell
                @inbounds neighbor = local_points[local_neighbor]
                @inbounds neighbor_coords = extract_svector(local_neighbor_coords,
                                                            Val(ndims(neighborhood_search)),
                                                            local_neighbor)

                pos_diff = point_coords - neighbor_coords
                distance2 = dot(pos_diff, pos_diff)

                # TODO periodic

                if distance2 <= search_radius2
                    distance = sqrt(distance2) # TODO: possibly use fastmath

                    # Inline to avoid loss of performance
                    # compared to not using `foreach_point_neighbor`.
                    @inline f(point, neighbor, pos_diff, distance)
                end
            end
        end

        # Make sure all threads finished computing before we continue with copying
        @synchronize()
    end
end
490+
# Neighbor loop with two local memory buffers: while the threads compute on the
# neighbor cell stored in one buffer, the next neighbor cell is already being copied
# into the other one, which saves one barrier per neighbor cell.
# Each workgroup handles one entry of `cells`, and each thread of the group handles
# at most one point of that cell.
#
# @parallel(block) for cell in cells
#     @parallel(thread) for neighbor in first_neighbor_cell
#         copy_coordinates_to_localmem!(local_coords, neighbor)
#
#     # Make sure all threads finished the copying
#     @synchronize
#
#     for neighbor_cell in neighboring_cells
#         @parallel(thread) for neighbor in next_neighbor_cell
#             copy_coordinates_to_localmem!(next_local_coords, neighbor)
#
#         # No synchronize needed. The following loop works on `local_coords`.
#
#         @parallel(thread) for point in cell
#             for neighbor in neighbor_cell
#                 # This uses the neighbor coordinates from the local memory
#                 compute(point, neighbor)
#
#         # Make sure all threads finished computing before we switch variables
#         @synchronize
#         local_coords, next_local_coords = next_local_coords, local_coords
#
# `f(point, neighbor, pos_diff, distance)` is called for every pair whose squared
# distance does not exceed `search_radius^2`, where `pos_diff` is the point coordinates
# minus the neighbor coordinates.
#
# NOTE(review): the buffers hold `MAX` entries and are read with `@inbounds`, so this
# presumably requires that no cell contains more than `MAX` points -- TODO confirm.
@kernel cpu=false function foreach_neighbor_double_buffer(f::F, system_coords,
                                                          neighbor_system_coords,
                                                          neighborhood_search, cells,
                                                          ::Val{MAX},
                                                          search_radius) where {F, MAX}
    cell_ = @index(Group)
    cell = @inbounds Tuple(cells[cell_])
    particleidx = @index(Local)
    @assert 1 <= particleidx <= MAX

    # Coordinate buffer in local memory
    local_points = @localmem Int32 MAX
    local_neighbor_coords = @localmem eltype(system_coords) (ndims(neighborhood_search),
                                                             MAX)

    # Next coordinate buffer in local memory
    next_local_points = @localmem Int32 MAX
    next_local_neighbor_coords = @localmem eltype(system_coords) (ndims(neighborhood_search),
                                                                  MAX)

    points = points_in_cell(cell, neighborhood_search)
    n_particles_in_current_cell = length(points)

    # Extract point coordinates if a point lies on this thread
    if particleidx <= n_particles_in_current_cell
        point = @inbounds points[particleidx]
        point_coords = @inbounds extract_svector(system_coords,
                                                 Val(ndims(neighborhood_search)), point)
    else
        # Idle thread: placeholder values that are never passed to `f`
        point = zero(Int32)
        point_coords = zero(SVector{ndims(neighborhood_search), eltype(system_coords)})
    end

    # Loop-invariant, so compute it once instead of once per point-neighbor pair
    search_radius2 = search_radius^2

    neighborhood = neighboring_cells(cell, neighborhood_search)
    n_neighbor_cells = length(neighborhood)

    # Fill the first buffer with the first neighbor cell before entering the loop
    neighbor_cell = Tuple(first(neighborhood))
    n_particles_in_neighbor_cell = copy_to_localmem!(local_points,
                                                     local_neighbor_coords,
                                                     neighbor_cell,
                                                     neighbor_system_coords,
                                                     neighborhood_search,
                                                     particleidx)

    # Initialized here so that the variable is assigned on every path before it is
    # read at the buffer swap below. It is overwritten whenever a next cell exists.
    next_n_particles_in_neighbor_cell = n_particles_in_neighbor_cell

    # Make sure all threads finished the copying
    @synchronize()

    for neighbor_ in 1:n_neighbor_cells
        # Start copying the next neighbor cell into the other buffer.
        # The current cell is already stored in `local_points`/`local_neighbor_coords`.
        if neighbor_ < n_neighbor_cells
            next_neighbor_cell = @inbounds Tuple(neighborhood[neighbor_ + 1])
            next_n_particles_in_neighbor_cell = copy_to_localmem!(next_local_points,
                                                                  next_local_neighbor_coords,
                                                                  next_neighbor_cell,
                                                                  neighbor_system_coords,
                                                                  neighborhood_search,
                                                                  particleidx)
        end

        # Now each thread works on one point again
        if particleidx <= n_particles_in_current_cell
            for local_neighbor in 1:n_particles_in_neighbor_cell
                @inbounds neighbor = local_points[local_neighbor]
                @inbounds neighbor_coords = extract_svector(local_neighbor_coords,
                                                            Val(ndims(neighborhood_search)),
                                                            local_neighbor)

                pos_diff = point_coords - neighbor_coords
                distance2 = dot(pos_diff, pos_diff)

                # TODO periodic

                if distance2 <= search_radius2
                    distance = sqrt(distance2) # TODO: possibly use fastmath

                    # Inline to avoid loss of performance
                    # compared to not using `foreach_point_neighbor`.
                    @inline f(point, neighbor, pos_diff, distance)
                end
            end
        end

        # Nothing was prefetched after the last neighbor cell, so we are done
        neighbor_ >= n_neighbor_cells && break

        # Make sure all threads finished computing (and copying the next cell)
        # before we switch variables
        @synchronize()

        # Swap the buffers: the prefetched cell becomes the current one
        n_particles_in_neighbor_cell = next_n_particles_in_neighbor_cell
        local_points, next_local_points = next_local_points, local_points
        local_neighbor_coords, next_local_neighbor_coords = next_local_neighbor_coords,
                                                            local_neighbor_coords
    end
end
492595
0 commit comments