Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
206 changes: 42 additions & 164 deletions src/elements/mixin/beamoptic.H
Original file line number Diff line number Diff line change
Expand Up @@ -21,125 +21,10 @@

#include <type_traits>


namespace impactx
{
    /** Element-aware parallel loop launcher.
     *
     * Forwards to amrex::ParallelForSIMD when AMReX SIMD support is compiled
     * in (AMREX_USE_SIMD) and the element type T_Element is reported as
     * vectorized by amrex::simd::is_vectorized. In every other case it
     * forwards to the regular amrex::ParallelFor.
     *
     * @tparam T_Element the element type; must provide a static `simd_width`
     *                   when it is vectorized
     * @tparam F the functor type
     * @param n the number of items to iterate over
     * @param f the functor to execute (perfectly forwarded)
     */
    template <typename T_Element, typename F>
    void ParallelFor (int n, F&& f) {
#ifdef AMREX_USE_SIMD
        if constexpr (amrex::simd::is_vectorized<T_Element>) {
            // vectorized element: iterate in chunks of the element's SIMD width
            amrex::ParallelForSIMD<T_Element::simd_width>(n, std::forward<F>(f));
        } else {
            // scalar element in a SIMD-enabled build
            amrex::ParallelFor(n, std::forward<F>(f));
        }
#else
        // build without SIMD support: always take the scalar path
        amrex::ParallelFor(n, std::forward<F>(f));
#endif
    }
} // namespace impactx

namespace impactx::elements::mixin
{
namespace detail
{
/** Load particle data from array pointers
 *
 * On GPU and CPU w/o SIMD, this dereferences a particle property at the
 * index position i.
 * On CPU with SIMD, this loads a SIMD register at the IndexType::width
 * SIMD-wide index position i.
 *
 * The SIMD register type is selected from the pointee type with its
 * const/volatile qualifiers stripped, so a pointer to
 * `amrex::ParticleReal const` still loads into a real-valued register
 * instead of falling through to the id/cpu register type.
 *
 * @tparam T data type (amrex::ParticleReal or uint64_t, optionally cv-qualified)
 * @tparam IndexType int or amrex::SIMDindex<SIMD_WIDTH, int>
 * @param ptr pointer to the array data
 * @param i index or SIMD index
 * @return a reference to the data (scalar) or a SIMD register (SIMD)
 */
template <typename T, typename IndexType>
AMREX_GPU_DEVICE AMREX_FORCE_INLINE
decltype(auto) load_pdata (T* ptr, IndexType const i)
{
    if constexpr (std::is_integral_v<IndexType>) {
        // scalar index: hand back a reference into the array
        return ptr[i];
    } else {
#ifdef AMREX_USE_SIMD
        using namespace amrex::simd;
        using RealType = SIMDParticleReal<IndexType::width>;
        using IdCpuType = SIMDIdCpu<RealType>;
        // compare the unqualified pointee type: `T` may be const-qualified
        using DataType = std::conditional_t<
            std::is_same_v<std::remove_cv_t<T>, amrex::ParticleReal>,
            RealType,
            IdCpuType
        >;

        // initialize vector register
        // TODO stdx::vector_aligned needs alignment guarantees
        //      https://github.com/AMReX-Codes/amrex/issues/4592
        //      https://en.cppreference.com/w/cpp/experimental/simd/simd/copy_from
        DataType val;
        val.copy_from(&ptr[i.index], stdx::element_aligned);
        return val;

#else
        // error handling: we should never get here
        amrex::ignore_unused(ptr, i);
        amrex::Abort("SIMD index used but SIMD is not enabled");
        return ptr[0];
#endif
    }
}

/** Write particle data back to the array it was loaded from
 *
 * For scalar (integral) indices, and in builds without AMREX_USE_SIMD, this
 * function performs no store: the caller already modified the array element
 * directly through the reference it obtained when loading.
 *
 * For SIMD indices, the SIMD register `val` is copied back to the array at
 * position `i.index`, but only when the N-th argument of the push method
 * `P_Method` is non-const and thus may have been modified.
 *
 * Good optimizing compilers can eliminate writebacks of unchanged values
 * themselves, but we better help a little for robustness. Background:
 *   https://github.com/AMReX-Codes/amrex/pull/4520#issuecomment-3064064215
 *
 * @tparam P_Method pointer to the push method (for is_nth_arg_non_const)
 * @tparam N the argument index (for is_nth_arg_non_const)
 * @tparam T data type
 * @tparam IndexType int or SIMD index
 * @tparam ValType the type of the value to store
 * @param val the value to store
 * @param ptr pointer to the SoA data
 * @param i index or SIMD index
 */
template <auto P_Method, int N, typename T, typename IndexType, typename ValType>
AMREX_GPU_DEVICE AMREX_FORCE_INLINE
void store_pdata (
    ValType const & AMREX_RESTRICT val,
    T * const AMREX_RESTRICT ptr,
    IndexType const i
)
{
    // silence unused-parameter warnings on the paths that store nothing
    amrex::ignore_unused(val, ptr, i);

#ifdef AMREX_USE_SIMD
    constexpr bool has_simd_index = !std::is_integral_v<IndexType>;
    if constexpr (has_simd_index) {
        // kept nested so the argument check is only instantiated for SIMD indices
        if constexpr (amrex::simd::is_nth_arg_non_const(P_Method, N)) {
            // write back to memory
            // TODO stdx::vector_aligned needs alignment guarantees
            //      https://github.com/AMReX-Codes/amrex/issues/4592
            //      https://en.cppreference.com/w/cpp/experimental/simd/simd/copy_from
            val.copy_to(ptr + i.index, amrex::simd::stdx::element_aligned);
        }
    }
#endif
}

/** Push a single particle through an element
*
* Note: we usually would just write a C++ lambda below in ParallelFor. But, due to restrictions
Expand Down Expand Up @@ -216,40 +101,35 @@ namespace detail

// access SoA data
// note: an optimizing compiler will eliminate loads of unused parameters
decltype(auto) x = load_pdata(m_part_x, i);
decltype(auto) y = load_pdata(m_part_y, i);
decltype(auto) t = load_pdata(m_part_t, i);
decltype(auto) px = load_pdata(m_part_px, i);
decltype(auto) py = load_pdata(m_part_py, i);
decltype(auto) pt = load_pdata(m_part_pt, i);
decltype(auto) sx = load_pdata(m_part_sx, i);
decltype(auto) sy = load_pdata(m_part_sy, i);
decltype(auto) sz = load_pdata(m_part_sz, i);
decltype(auto) idcpu = load_pdata(m_part_idcpu, i);
decltype(auto) x = load_1d(m_part_x, i);
decltype(auto) y = load_1d(m_part_y, i);
decltype(auto) t = load_1d(m_part_t, i);
decltype(auto) px = load_1d(m_part_px, i);
decltype(auto) py = load_1d(m_part_py, i);
decltype(auto) pt = load_1d(m_part_pt, i);
decltype(auto) sx = load_1d(m_part_sx, i);
decltype(auto) sy = load_1d(m_part_sy, i);
decltype(auto) sz = load_1d(m_part_sz, i);
decltype(auto) idcpu = load_1d(m_part_idcpu, i);

// push spin & phase space through element
m_element.spin_and_phasespace_push(x, y, t, px, py, pt, sx, sy, sz, idcpu, m_ref_part);

// SIMD: write back to memory
#ifdef AMREX_USE_SIMD
if constexpr (amrex::simd::is_vectorized<T_Element>)
{
using RealType = std::decay_t<decltype(x)>;
using IdCpuType = std::decay_t<decltype(idcpu)>;
constexpr auto P_Method = &T_Element::template spin_and_phasespace_push<RealType, IdCpuType>;

store_pdata<P_Method, 0>(x, m_part_x, i);
store_pdata<P_Method, 1>(y, m_part_y, i);
store_pdata<P_Method, 2>(t, m_part_t, i);
store_pdata<P_Method, 3>(px, m_part_px, i);
store_pdata<P_Method, 4>(py, m_part_py, i);
store_pdata<P_Method, 5>(pt, m_part_pt, i);
store_pdata<P_Method, 6>(sx, m_part_sx, i);
store_pdata<P_Method, 7>(sy, m_part_sy, i);
store_pdata<P_Method, 8>(sz, m_part_sz, i);
store_pdata<P_Method, 9>(idcpu, m_part_idcpu, i);
}
#endif
using RealType = std::decay_t<decltype(x)>;
using IdCpuType = std::decay_t<decltype(idcpu)>;
constexpr decltype(auto) P_Method = &T_Element::template spin_and_phasespace_push<RealType, IdCpuType>;

store_1d<P_Method, 0>(x, m_part_x, i);
store_1d<P_Method, 1>(y, m_part_y, i);
store_1d<P_Method, 2>(t, m_part_t, i);
store_1d<P_Method, 3>(px, m_part_px, i);
store_1d<P_Method, 4>(py, m_part_py, i);
store_1d<P_Method, 5>(pt, m_part_pt, i);
store_1d<P_Method, 6>(sx, m_part_sx, i);
store_1d<P_Method, 7>(sy, m_part_sy, i);
store_1d<P_Method, 8>(sz, m_part_sz, i);
store_1d<P_Method, 9>(idcpu, m_part_idcpu, i);
Comment on lines +123 to +132
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These store_1d calls are not guarded anymore by

if constexpr (amrex::simd::is_vectorized<T_Element>)

while the store_1d calls below (lines 241-247) kept that guard (line 233).

Is this intentional and, if so, what's the reason for the asymmetry?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added a special constexpr check in store_1d in AMReX

        // SIMD uses special vector register types in ValType that need to be copied back to RAM array type T
        if constexpr (!std::is_same_v<ValType, T>) {

that serves the same purpose as the "is this a vectorized type" check we had before.
https://github.com/AMReX-Codes/amrex/pull/4924/changes

I should remove the guard below as well and just forgot (it does not hurt, but it is verbose).

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, but here is the catch:
Scalar-only elements like ExactCFbend have a concrete, non-template particle operator(), so the `constexpr decltype(auto) P_Method = ...` line must stay guarded for them.

The spin path does not hit that issue because it uses spin_and_phasespace_push, and currently all the spin-capable elements implement that as a templated method.

}

private:
Expand Down Expand Up @@ -336,34 +216,32 @@ namespace detail

// access SoA data
// note: an optimizing compiler will eliminate loads of unused parameters
decltype(auto) x = load_pdata(m_part_x, i);
decltype(auto) y = load_pdata(m_part_y, i);
decltype(auto) t = load_pdata(m_part_t, i);
decltype(auto) px = load_pdata(m_part_px, i);
decltype(auto) py = load_pdata(m_part_py, i);
decltype(auto) pt = load_pdata(m_part_pt, i);
decltype(auto) idcpu = load_pdata(m_part_idcpu, i);
decltype(auto) x = load_1d(m_part_x, i);
decltype(auto) y = load_1d(m_part_y, i);
decltype(auto) t = load_1d(m_part_t, i);
decltype(auto) px = load_1d(m_part_px, i);
decltype(auto) py = load_1d(m_part_py, i);
decltype(auto) pt = load_1d(m_part_pt, i);
decltype(auto) idcpu = load_1d(m_part_idcpu, i);

// push through element
m_element(x, y, t, px, py, pt, idcpu, m_ref_part);

// write back to memory
#ifdef AMREX_USE_SIMD
if constexpr (amrex::simd::is_vectorized<T_Element>)
{
using RealType = std::decay_t<decltype(x)>;
using IdCpuType = std::decay_t<decltype(idcpu)>;
constexpr auto P_Method = &T_Element::template operator()<RealType, IdCpuType>;

store_pdata<P_Method, 0>(x, m_part_x, i);
store_pdata<P_Method, 1>(y, m_part_y, i);
store_pdata<P_Method, 2>(t, m_part_t, i);
store_pdata<P_Method, 3>(px, m_part_px, i);
store_pdata<P_Method, 4>(py, m_part_py, i);
store_pdata<P_Method, 5>(pt, m_part_pt, i);
store_pdata<P_Method, 6>(idcpu, m_part_idcpu, i);
constexpr decltype(auto) P_Method = &T_Element::template operator()<RealType, IdCpuType>;

store_1d<P_Method, 0>(x, m_part_x, i);
store_1d<P_Method, 1>(y, m_part_y, i);
store_1d<P_Method, 2>(t, m_part_t, i);
store_1d<P_Method, 3>(px, m_part_px, i);
store_1d<P_Method, 4>(py, m_part_py, i);
store_1d<P_Method, 5>(pt, m_part_pt, i);
store_1d<P_Method, 6>(idcpu, m_part_idcpu, i);
}
#endif
}

private:
Expand Down Expand Up @@ -411,7 +289,7 @@ namespace detail
element, part_x, part_y, part_t, part_px, part_py, part_pt, part_sx, part_sy, part_sz, part_idcpu, ref_part);

// loop over beam particles in the box
impactx::ParallelFor<T_Element>(np, pushSingleParticle);
amrex::ParallelForSIMD<T_Element>(np, pushSingleParticle);
} else {
throw std::runtime_error("Spin transport requested but element does not implement the `SpinTransport` interface class!");
}
Expand All @@ -421,7 +299,7 @@ namespace detail
element, part_x, part_y, part_t, part_px, part_py, part_pt, part_idcpu, ref_part);

// loop over beam particles in the box
impactx::ParallelFor<T_Element>(np, pushSingleParticle);
amrex::ParallelForSIMD<T_Element>(np, pushSingleParticle);
}
}
} // namespace detail
Expand Down
Loading