Skip to content

Commit a1ade1e

Browse files
committed
SIMD: Use AMReX versions
1 parent daef9a8 commit a1ade1e

1 file changed

Lines changed: 46 additions & 164 deletions

File tree

src/elements/mixin/beamoptic.H

Lines changed: 46 additions & 164 deletions
Original file line number | Diff line number | Diff line change
@@ -21,125 +21,10 @@
2121

2222
#include <type_traits>
2323

24-
25-
namespace impactx
26-
{
27-
/** A generalized ParallelFor that dispatches to amrex::ParallelFor or amrex::ParallelForSIMD
28-
* depending on whether the element type T_Element is vectorized.
29-
*
30-
* @tparam T_Element the element type
31-
* @tparam F the functor type
32-
* @param n the number of items to iterate over
33-
* @param f the functor to execute
34-
*/
35-
template <typename T_Element, typename F>
36-
void ParallelFor (int n, F&& f) {
37-
#ifdef AMREX_USE_SIMD
38-
if constexpr (amrex::simd::is_vectorized<T_Element>) {
39-
amrex::ParallelForSIMD<T_Element::simd_width>(n, std::forward<F>(f));
40-
} else
41-
#endif
42-
{
43-
amrex::ParallelFor(n, std::forward<F>(f));
44-
}
45-
}
46-
} // namespace impactx
47-
4824
namespace impactx::elements::mixin
4925
{
5026
namespace detail
5127
{
52-
/** Load particle data from array pointers
53-
*
54-
* On GPU and CPU w/o SIMD, this dereferences a particle property at the
55-
* index position i.
56-
* On CPU with SIMD, this loads a SIMD register at the IndexType::width
57-
* SIMD-wide index position i.
58-
*
59-
* @tparam T data type (amrex::ParticleReal or uint64_t)
60-
* @tparam IndexType int or amrex::SIMDindex<SIMD_WIDTH, int>
61-
* @param ptr pointer to the array data
62-
* @param i index or SIMD index
63-
* @return a reference to the data (scalar) or a SIMD register (SIMD)
64-
*/
65-
template <typename T, typename IndexType>
66-
AMREX_GPU_DEVICE AMREX_FORCE_INLINE
67-
decltype(auto) load_pdata (T* ptr, IndexType const i)
68-
{
69-
if constexpr (std::is_integral_v<IndexType>) {
70-
return ptr[i];
71-
} else {
72-
#ifdef AMREX_USE_SIMD
73-
using namespace amrex::simd;
74-
using RealType = SIMDParticleReal<IndexType::width>;
75-
using IdCpuType = SIMDIdCpu<RealType>;
76-
using DataType = std::conditional_t<
77-
std::is_same_v<T, amrex::ParticleReal>,
78-
RealType,
79-
IdCpuType
80-
>;
81-
82-
// initialize vector register
83-
// TODO stdx::vector_aligned needs alignment guarantees
84-
// https://github.com/AMReX-Codes/amrex/issues/4592
85-
// https://en.cppreference.com/w/cpp/experimental/simd/simd/copy_from
86-
DataType val;
87-
val.copy_from(&ptr[i.index], stdx::element_aligned);
88-
return val;
89-
90-
#else
91-
// error handling: we should never get here
92-
amrex::ignore_unused(ptr, i);
93-
amrex::Abort("SIMD index used but SIMD is not enabled");
94-
return ptr[0];
95-
#endif
96-
}
97-
}
98-
99-
/** Store particle data back to array pointers
100-
*
101-
* On GPU and CPU without SIMD, this does nothing because we already
102-
* modified the (global) RAM directly via pointer.
103-
*
104-
* On CPU with SIMD, this performs a conditional writeback of a SIMD register
105-
* to RAM (index in pointer array), but only if the argument was not passed
106-
* as const and thus was likely changed.
107-
*
108-
* Good optimizing compilers can eliminate writebacks of unchanged values
109-
* themselves, but we better help a little for robustness. Background:
110-
* https://github.com/AMReX-Codes/amrex/pull/4520#issuecomment-3064064215
111-
*
112-
* @tparam P_Method pointer to the push method (for is_nth_arg_non_const)
113-
* @tparam N the argument index (for is_nth_arg_non_const)
114-
* @tparam T data type
115-
* @tparam IndexType int or SIMD index
116-
* @tparam ValType the type of the value to store
117-
* @param val the value to store
118-
* @param ptr pointer to the SoA data
119-
* @param i index or SIMD index
120-
*/
121-
template <auto P_Method, int N, typename T, typename IndexType, typename ValType>
122-
AMREX_GPU_DEVICE AMREX_FORCE_INLINE
123-
void store_pdata (
124-
ValType const & AMREX_RESTRICT val,
125-
T * const AMREX_RESTRICT ptr,
126-
IndexType const i
127-
)
128-
{
129-
#ifdef AMREX_USE_SIMD
130-
if constexpr (!std::is_integral_v<IndexType>) {
131-
if constexpr (amrex::simd::is_nth_arg_non_const(P_Method, N)) {
132-
// write back to memory
133-
// TODO stdx::vector_aligned needs alignment guarantees
134-
// https://github.com/AMReX-Codes/amrex/issues/4592
135-
// https://en.cppreference.com/w/cpp/experimental/simd/simd/copy_from
136-
val.copy_to(&ptr[i.index], amrex::simd::stdx::element_aligned);
137-
}
138-
}
139-
#endif
140-
amrex::ignore_unused(val, ptr, i);
141-
}
142-
14328
/** Push a single particle through an element
14429
*
14530
* Note: we usually would just write a C++ lambda below in ParallelFor. But, due to restrictions
@@ -216,40 +101,37 @@ namespace detail
216101

217102
// access SoA data
218103
// note: an optimizing compiler will eliminate loads of unused parameters
219-
decltype(auto) x = load_pdata(m_part_x, i);
220-
decltype(auto) y = load_pdata(m_part_y, i);
221-
decltype(auto) t = load_pdata(m_part_t, i);
222-
decltype(auto) px = load_pdata(m_part_px, i);
223-
decltype(auto) py = load_pdata(m_part_py, i);
224-
decltype(auto) pt = load_pdata(m_part_pt, i);
225-
decltype(auto) sx = load_pdata(m_part_sx, i);
226-
decltype(auto) sy = load_pdata(m_part_sy, i);
227-
decltype(auto) sz = load_pdata(m_part_sz, i);
228-
decltype(auto) idcpu = load_pdata(m_part_idcpu, i);
104+
decltype(auto) x = load_1d(m_part_x, i);
105+
decltype(auto) y = load_1d(m_part_y, i);
106+
decltype(auto) t = load_1d(m_part_t, i);
107+
decltype(auto) px = load_1d(m_part_px, i);
108+
decltype(auto) py = load_1d(m_part_py, i);
109+
decltype(auto) pt = load_1d(m_part_pt, i);
110+
decltype(auto) sx = load_1d(m_part_sx, i);
111+
decltype(auto) sy = load_1d(m_part_sy, i);
112+
decltype(auto) sz = load_1d(m_part_sz, i);
113+
decltype(auto) idcpu = load_1d(m_part_idcpu, i);
229114

230115
// push spin & phase space through element
231116
m_element.spin_and_phasespace_push(x, y, t, px, py, pt, sx, sy, sz, idcpu, m_ref_part);
232117

233118
// SIMD: write back to memory
234-
#ifdef AMREX_USE_SIMD
235-
if constexpr (amrex::simd::is_vectorized<T_Element>)
236-
{
237-
using RealType = std::decay_t<decltype(x)>;
238-
using IdCpuType = std::decay_t<decltype(idcpu)>;
239-
constexpr auto P_Method = &T_Element::template spin_and_phasespace_push<RealType, IdCpuType>;
240-
241-
store_pdata<P_Method, 0>(x, m_part_x, i);
242-
store_pdata<P_Method, 1>(y, m_part_y, i);
243-
store_pdata<P_Method, 2>(t, m_part_t, i);
244-
store_pdata<P_Method, 3>(px, m_part_px, i);
245-
store_pdata<P_Method, 4>(py, m_part_py, i);
246-
store_pdata<P_Method, 5>(pt, m_part_pt, i);
247-
store_pdata<P_Method, 6>(sx, m_part_sx, i);
248-
store_pdata<P_Method, 7>(sy, m_part_sy, i);
249-
store_pdata<P_Method, 8>(sz, m_part_sz, i);
250-
store_pdata<P_Method, 9>(idcpu, m_part_idcpu, i);
251-
}
252-
#endif
119+
using RealType = std::decay_t<decltype(x)>;
120+
using IdCpuType = std::decay_t<decltype(idcpu)>;
121+
//static_assert(std::is_same_v<RealType, amrex::ParticleReal>, "SIMD push requires ParticleReal particle data!");
122+
//static_assert(std::is_same_v<IdCpuType, uint64_t>, "SIMD push requires uint64_t particle data!");
123+
constexpr decltype(auto) P_Method = &T_Element::template spin_and_phasespace_push<RealType, IdCpuType>;
124+
125+
store_1d<P_Method, 0>(x, m_part_x, i);
126+
store_1d<P_Method, 1>(y, m_part_y, i);
127+
store_1d<P_Method, 2>(t, m_part_t, i);
128+
store_1d<P_Method, 3>(px, m_part_px, i);
129+
store_1d<P_Method, 4>(py, m_part_py, i);
130+
store_1d<P_Method, 5>(pt, m_part_pt, i);
131+
store_1d<P_Method, 6>(sx, m_part_sx, i);
132+
store_1d<P_Method, 7>(sy, m_part_sy, i);
133+
store_1d<P_Method, 8>(sz, m_part_sz, i);
134+
store_1d<P_Method, 9>(idcpu, m_part_idcpu, i);
253135
}
254136

255137
private:
@@ -336,34 +218,34 @@ namespace detail
336218

337219
// access SoA data
338220
// note: an optimizing compiler will eliminate loads of unused parameters
339-
decltype(auto) x = load_pdata(m_part_x, i);
340-
decltype(auto) y = load_pdata(m_part_y, i);
341-
decltype(auto) t = load_pdata(m_part_t, i);
342-
decltype(auto) px = load_pdata(m_part_px, i);
343-
decltype(auto) py = load_pdata(m_part_py, i);
344-
decltype(auto) pt = load_pdata(m_part_pt, i);
345-
decltype(auto) idcpu = load_pdata(m_part_idcpu, i);
221+
decltype(auto) x = load_1d(m_part_x, i);
222+
decltype(auto) y = load_1d(m_part_y, i);
223+
decltype(auto) t = load_1d(m_part_t, i);
224+
decltype(auto) px = load_1d(m_part_px, i);
225+
decltype(auto) py = load_1d(m_part_py, i);
226+
decltype(auto) pt = load_1d(m_part_pt, i);
227+
decltype(auto) idcpu = load_1d(m_part_idcpu, i);
346228

347229
// push through element
348230
m_element(x, y, t, px, py, pt, idcpu, m_ref_part);
349231

350232
// write back to memory
351-
#ifdef AMREX_USE_SIMD
352233
if constexpr (amrex::simd::is_vectorized<T_Element>)
353234
{
354235
using RealType = std::decay_t<decltype(x)>;
355236
using IdCpuType = std::decay_t<decltype(idcpu)>;
356-
constexpr auto P_Method = &T_Element::template operator()<RealType, IdCpuType>;
357-
358-
store_pdata<P_Method, 0>(x, m_part_x, i);
359-
store_pdata<P_Method, 1>(y, m_part_y, i);
360-
store_pdata<P_Method, 2>(t, m_part_t, i);
361-
store_pdata<P_Method, 3>(px, m_part_px, i);
362-
store_pdata<P_Method, 4>(py, m_part_py, i);
363-
store_pdata<P_Method, 5>(pt, m_part_pt, i);
364-
store_pdata<P_Method, 6>(idcpu, m_part_idcpu, i);
237+
//static_assert(std::is_same_v<RealType, amrex::ParticleReal>, "SIMD push requires ParticleReal particle data!");
238+
//static_assert(std::is_same_v<IdCpuType, uint64_t>, "SIMD push requires uint64_t particle data!");
239+
constexpr decltype(auto) P_Method = &T_Element::template operator()<RealType, IdCpuType>;
240+
241+
store_1d<P_Method, 0>(x, m_part_x, i);
242+
store_1d<P_Method, 1>(y, m_part_y, i);
243+
store_1d<P_Method, 2>(t, m_part_t, i);
244+
store_1d<P_Method, 3>(px, m_part_px, i);
245+
store_1d<P_Method, 4>(py, m_part_py, i);
246+
store_1d<P_Method, 5>(pt, m_part_pt, i);
247+
store_1d<P_Method, 6>(idcpu, m_part_idcpu, i);
365248
}
366-
#endif
367249
}
368250

369251
private:
@@ -411,7 +293,7 @@ namespace detail
411293
element, part_x, part_y, part_t, part_px, part_py, part_pt, part_sx, part_sy, part_sz, part_idcpu, ref_part);
412294

413295
// loop over beam particles in the box
414-
impactx::ParallelFor<T_Element>(np, pushSingleParticle);
296+
amrex::ParallelForSIMD<T_Element>(np, pushSingleParticle);
415297
} else {
416298
throw std::runtime_error("Spin transport requested but element does not implement the `SpinTransport` interface class!");
417299
}
@@ -421,7 +303,7 @@ namespace detail
421303
element, part_x, part_y, part_t, part_px, part_py, part_pt, part_idcpu, ref_part);
422304

423305
// loop over beam particles in the box
424-
impactx::ParallelFor<T_Element>(np, pushSingleParticle);
306+
amrex::ParallelForSIMD<T_Element>(np, pushSingleParticle);
425307
}
426308
}
427309
} // namespace detail

0 commit comments

Comments
 (0)