2121
2222#include < type_traits>
2323
24-
25- namespace impactx
26- {
27- /* * A generalized ParallelFor that dispatches to amrex::ParallelFor or amrex::ParallelForSIMD
28- * depending on whether the element type T_Element is vectorized.
29- *
30- * @tparam T_Element the element type
31- * @tparam F the functor type
32- * @param n the number of items to iterate over
33- * @param f the functor to execute
34- */
35- template <typename T_Element, typename F>
36- void ParallelFor (int n, F&& f) {
37- #ifdef AMREX_USE_SIMD
38- if constexpr (amrex::simd::is_vectorized<T_Element>) {
39- amrex::ParallelForSIMD<T_Element::simd_width>(n, std::forward<F>(f));
40- } else
41- #endif
42- {
43- amrex::ParallelFor (n, std::forward<F>(f));
44- }
45- }
46- } // namespace impactx
47-
4824namespace impactx ::elements::mixin
4925{
5026namespace detail
5127{
52- /* * Load particle data from array pointers
53- *
54- * On GPU and CPU w/o SIMD, this dereferences a particle property at the
55- * index position i.
56- * On CPU with SIMD, this loads a SIMD register at the IndexType::width
57- * SIMD-wide index position i.
58- *
59- * @tparam T data type (amrex::ParticleReal or uint64_t)
60- * @tparam IndexType int or amrex::SIMDindex<SIMD_WIDTH, int>
61- * @param ptr pointer to the array data
62- * @param i index or SIMD index
63- * @return a reference to the data (scalar) or a SIMD register (SIMD)
64- */
65- template <typename T, typename IndexType>
66- AMREX_GPU_DEVICE AMREX_FORCE_INLINE
67- decltype (auto ) load_pdata (T* ptr, IndexType const i)
68- {
69- if constexpr (std::is_integral_v<IndexType>) {
70- return ptr[i];
71- } else {
72- #ifdef AMREX_USE_SIMD
73- using namespace amrex ::simd;
74- using RealType = SIMDParticleReal<IndexType::width>;
75- using IdCpuType = SIMDIdCpu<RealType>;
76- using DataType = std::conditional_t <
77- std::is_same_v<T, amrex::ParticleReal>,
78- RealType,
79- IdCpuType
80- >;
81-
82- // initialize vector register
83- // TODO stdx::vector_aligned needs alignment guarantees
84- // https://github.com/AMReX-Codes/amrex/issues/4592
85- // https://en.cppreference.com/w/cpp/experimental/simd/simd/copy_from
86- DataType val;
87- val.copy_from (&ptr[i.index ], stdx::element_aligned);
88- return val;
89-
90- #else
91- // error handling: we should never get here
92- amrex::ignore_unused (ptr, i);
93- amrex::Abort (" SIMD index used but SIMD is not enabled" );
94- return ptr[0 ];
95- #endif
96- }
97- }
98-
99- /* * Store particle data back to array pointers
100- *
101- * On GPU and CPU without SIMD, this does nothing because we already
102- * modified the (global) RAM directly via pointer.
103- *
104- * On CPU with SIMD, this performs a conditional writeback of a SIMD register
105- * to RAM (index in pointer array), but only if the argument was not passed
106- * as const and thus was likely changed.
107- *
108- * Good optimizing compilers can eliminate writebacks of unchanged values
109- * themselves, but we better help a little for robustness. Background:
110- * https://github.com/AMReX-Codes/amrex/pull/4520#issuecomment-3064064215
111- *
112- * @tparam P_Method pointer to the push method (for is_nth_arg_non_const)
113- * @tparam N the argument index (for is_nth_arg_non_const)
114- * @tparam T data type
115- * @tparam IndexType int or SIMD index
116- * @tparam ValType the type of the value to store
117- * @param val the value to store
118- * @param ptr pointer to the SoA data
119- * @param i index or SIMD index
120- */
121- template <auto P_Method, int N, typename T, typename IndexType, typename ValType>
122- AMREX_GPU_DEVICE AMREX_FORCE_INLINE
123- void store_pdata (
124- ValType const & AMREX_RESTRICT val,
125- T * const AMREX_RESTRICT ptr,
126- IndexType const i
127- )
128- {
129- #ifdef AMREX_USE_SIMD
130- if constexpr (!std::is_integral_v<IndexType>) {
131- if constexpr (amrex::simd::is_nth_arg_non_const (P_Method, N)) {
132- // write back to memory
133- // TODO stdx::vector_aligned needs alignment guarantees
134- // https://github.com/AMReX-Codes/amrex/issues/4592
135- // https://en.cppreference.com/w/cpp/experimental/simd/simd/copy_from
136- val.copy_to (&ptr[i.index ], amrex::simd::stdx::element_aligned);
137- }
138- }
139- #endif
140- amrex::ignore_unused (val, ptr, i);
141- }
142-
14328 /* * Push a single particle through an element
14429 *
14530 * Note: we usually would just write a C++ lambda below in ParallelFor. But, due to restrictions
@@ -216,40 +101,37 @@ namespace detail
216101
217102 // access SoA data
218103 // note: an optimizing compiler will eliminate loads of unused parameters
219- decltype (auto ) x = load_pdata (m_part_x, i);
220- decltype (auto ) y = load_pdata (m_part_y, i);
221- decltype (auto ) t = load_pdata (m_part_t , i);
222- decltype (auto ) px = load_pdata (m_part_px, i);
223- decltype (auto ) py = load_pdata (m_part_py, i);
224- decltype (auto ) pt = load_pdata (m_part_pt, i);
225- decltype (auto ) sx = load_pdata (m_part_sx, i);
226- decltype (auto ) sy = load_pdata (m_part_sy, i);
227- decltype (auto ) sz = load_pdata (m_part_sz, i);
228- decltype (auto ) idcpu = load_pdata (m_part_idcpu, i);
104+ decltype (auto ) x = load_1d (m_part_x, i);
105+ decltype (auto ) y = load_1d (m_part_y, i);
106+ decltype (auto ) t = load_1d (m_part_t , i);
107+ decltype (auto ) px = load_1d (m_part_px, i);
108+ decltype (auto ) py = load_1d (m_part_py, i);
109+ decltype (auto ) pt = load_1d (m_part_pt, i);
110+ decltype (auto ) sx = load_1d (m_part_sx, i);
111+ decltype (auto ) sy = load_1d (m_part_sy, i);
112+ decltype (auto ) sz = load_1d (m_part_sz, i);
113+ decltype (auto ) idcpu = load_1d (m_part_idcpu, i);
229114
230115 // push spin & phase space through element
231116 m_element.spin_and_phasespace_push (x, y, t, px, py, pt, sx, sy, sz, idcpu, m_ref_part);
232117
233118 // SIMD: write back to memory
234- #ifdef AMREX_USE_SIMD
235- if constexpr (amrex::simd::is_vectorized<T_Element>)
236- {
237- using RealType = std::decay_t <decltype (x)>;
238- using IdCpuType = std::decay_t <decltype (idcpu)>;
239- constexpr auto P_Method = &T_Element::template spin_and_phasespace_push<RealType, IdCpuType>;
240-
241- store_pdata<P_Method, 0 >(x, m_part_x, i);
242- store_pdata<P_Method, 1 >(y, m_part_y, i);
243- store_pdata<P_Method, 2 >(t, m_part_t , i);
244- store_pdata<P_Method, 3 >(px, m_part_px, i);
245- store_pdata<P_Method, 4 >(py, m_part_py, i);
246- store_pdata<P_Method, 5 >(pt, m_part_pt, i);
247- store_pdata<P_Method, 6 >(sx, m_part_sx, i);
248- store_pdata<P_Method, 7 >(sy, m_part_sy, i);
249- store_pdata<P_Method, 8 >(sz, m_part_sz, i);
250- store_pdata<P_Method, 9 >(idcpu, m_part_idcpu, i);
251- }
252- #endif
119+ using RealType = std::decay_t <decltype (x)>;
120+ using IdCpuType = std::decay_t <decltype (idcpu)>;
121+ // static_assert(std::is_same_v<RealType, amrex::ParticleReal>, "SIMD push requires ParticleReal particle data!");
122+ // static_assert(std::is_same_v<IdCpuType, uint64_t>, "SIMD push requires uint64_t particle data!");
123+ constexpr decltype (auto ) P_Method = &T_Element::template spin_and_phasespace_push<RealType, IdCpuType>;
124+
125+ store_1d<P_Method, 0 >(x, m_part_x, i);
126+ store_1d<P_Method, 1 >(y, m_part_y, i);
127+ store_1d<P_Method, 2 >(t, m_part_t , i);
128+ store_1d<P_Method, 3 >(px, m_part_px, i);
129+ store_1d<P_Method, 4 >(py, m_part_py, i);
130+ store_1d<P_Method, 5 >(pt, m_part_pt, i);
131+ store_1d<P_Method, 6 >(sx, m_part_sx, i);
132+ store_1d<P_Method, 7 >(sy, m_part_sy, i);
133+ store_1d<P_Method, 8 >(sz, m_part_sz, i);
134+ store_1d<P_Method, 9 >(idcpu, m_part_idcpu, i);
253135 }
254136
255137 private:
@@ -336,34 +218,34 @@ namespace detail
336218
337219 // access SoA data
338220 // note: an optimizing compiler will eliminate loads of unused parameters
339- decltype (auto ) x = load_pdata (m_part_x, i);
340- decltype (auto ) y = load_pdata (m_part_y, i);
341- decltype (auto ) t = load_pdata (m_part_t , i);
342- decltype (auto ) px = load_pdata (m_part_px, i);
343- decltype (auto ) py = load_pdata (m_part_py, i);
344- decltype (auto ) pt = load_pdata (m_part_pt, i);
345- decltype (auto ) idcpu = load_pdata (m_part_idcpu, i);
221+ decltype (auto ) x = load_1d (m_part_x, i);
222+ decltype (auto ) y = load_1d (m_part_y, i);
223+ decltype (auto ) t = load_1d (m_part_t , i);
224+ decltype (auto ) px = load_1d (m_part_px, i);
225+ decltype (auto ) py = load_1d (m_part_py, i);
226+ decltype (auto ) pt = load_1d (m_part_pt, i);
227+ decltype (auto ) idcpu = load_1d (m_part_idcpu, i);
346228
347229 // push through element
348230 m_element (x, y, t, px, py, pt, idcpu, m_ref_part);
349231
350232 // write back to memory
351- #ifdef AMREX_USE_SIMD
352233 if constexpr (amrex::simd::is_vectorized<T_Element>)
353234 {
354235 using RealType = std::decay_t <decltype (x)>;
355236 using IdCpuType = std::decay_t <decltype (idcpu)>;
356- constexpr auto P_Method = &T_Element::template operator ()<RealType, IdCpuType>;
357-
358- store_pdata<P_Method, 0 >(x, m_part_x, i);
359- store_pdata<P_Method, 1 >(y, m_part_y, i);
360- store_pdata<P_Method, 2 >(t, m_part_t , i);
361- store_pdata<P_Method, 3 >(px, m_part_px, i);
362- store_pdata<P_Method, 4 >(py, m_part_py, i);
363- store_pdata<P_Method, 5 >(pt, m_part_pt, i);
364- store_pdata<P_Method, 6 >(idcpu, m_part_idcpu, i);
237+ // static_assert(std::is_same_v<RealType, amrex::ParticleReal>, "SIMD push requires ParticleReal particle data!");
238+ // static_assert(std::is_same_v<IdCpuType, uint64_t>, "SIMD push requires uint64_t particle data!");
239+ constexpr decltype (auto ) P_Method = &T_Element::template operator ()<RealType, IdCpuType>;
240+
241+ store_1d<P_Method, 0 >(x, m_part_x, i);
242+ store_1d<P_Method, 1 >(y, m_part_y, i);
243+ store_1d<P_Method, 2 >(t, m_part_t , i);
244+ store_1d<P_Method, 3 >(px, m_part_px, i);
245+ store_1d<P_Method, 4 >(py, m_part_py, i);
246+ store_1d<P_Method, 5 >(pt, m_part_pt, i);
247+ store_1d<P_Method, 6 >(idcpu, m_part_idcpu, i);
365248 }
366- #endif
367249 }
368250
369251 private:
@@ -411,7 +293,7 @@ namespace detail
411293 element, part_x, part_y, part_t , part_px, part_py, part_pt, part_sx, part_sy, part_sz, part_idcpu, ref_part);
412294
413295 // loop over beam particles in the box
414- impactx::ParallelFor <T_Element>(np, pushSingleParticle);
296+ amrex::ParallelForSIMD <T_Element>(np, pushSingleParticle);
415297 } else {
416298 throw std::runtime_error (" Spin transport requested but element does not implement the `SpinTransport` interface class!" );
417299 }
@@ -421,7 +303,7 @@ namespace detail
421303 element, part_x, part_y, part_t , part_px, part_py, part_pt, part_idcpu, ref_part);
422304
423305 // loop over beam particles in the box
424- impactx::ParallelFor <T_Element>(np, pushSingleParticle);
306+ amrex::ParallelForSIMD <T_Element>(np, pushSingleParticle);
425307 }
426308 }
427309} // namespace detail
0 commit comments