1+ #ifndef MEMLAYOUT_H
2+ #define MEMLAYOUT_H
3+
4+ #include " GPUCommonDefAPI.h"
5+
6+ #define EXPAND (...) __VA_ARGS__
7+ #define MEMLAYOUT_MEMBERFUNCTIONS (STRUCT_NAME , ...) \
8+ template <template <class > class F_out > \
9+ constexpr operator STRUCT_NAME <F_out>() { return { EXPAND (__VA_ARGS__) }; } \
10+ template <template <class > class F_out > \
11+ constexpr operator STRUCT_NAME <F_out>() const { return { EXPAND (__VA_ARGS__) }; } \
12+ template <template <class > class F_out , class FunctionObject > \
13+ constexpr STRUCT_NAME <F_out> invoke_on_members (FunctionObject f) { return {f (EXPAND (__VA_ARGS__))}; } \
14+ template <template <class > class F_out , class FunctionObject > \
15+ constexpr STRUCT_NAME <F_out> invoke_on_members (FunctionObject f) const { return {f (EXPAND (__VA_ARGS__))}; } \
16+
17+ namespace MemLayout {
18+
19+ using size_t = decltype (sizeof 0 );
20+
21+ template <class T > using value = T;
22+
23+ template <class T > using reference = T&;
24+ template <class T > using reference_restrict = T& GPUrestrict ();
25+
26+ template <class T > using const_reference = const T&;
27+ template <class T > using const_reference_restrict = const T& GPUrestrict ();
28+
29+ template <class T > using pointer = T*;
30+ template <class T > using pointer_restrtict = T* GPUrestrict ();
31+
32+ template <class T > using const_pointer = const T*;
33+ template <class T > using const_pointer_restrict = const T* GPUrestrict ();
34+
35+ enum Flag { soa, aos };
36+
37+ template <template <template <class > class > class S , template <class > class F , Flag L>
38+ struct wrapper ;
39+
40+ template <template <template <class > class > class S , template <class > class F >
41+ struct wrapper <S, F, Flag::aos> : public F<S<value>> {
42+ using Base = F<S<value>>;
43+
44+ template <template <class > class F_out >
45+ constexpr operator wrapper<S, F_out, Flag::aos>() { return {*static_cast <Base*>(this )}; };
46+
47+ template <template <class > class F_out >
48+ constexpr operator wrapper<S, F_out, Flag::aos>() const { return {*static_cast <const Base*>(this )}; };
49+
50+ constexpr S<reference> operator [](size_t i) {
51+ return static_cast <Base*>(this )->operator [](i);
52+ }
53+
54+ constexpr S<const_reference> operator [](size_t i) const {
55+ return static_cast <const Base*>(this )->operator [](i);
56+ }
57+ };
58+
59+ template <template <template <class > class > class S , template <class > class F >
60+ using AoS = wrapper<S, F, Flag::aos>;
61+
62+ // The types S<value>, S<reference>, and S<const_reference> need to be aggregate constructible
63+ template <template <template <class > class > class S , template <class > class F >
64+ struct wrapper <S, F, Flag::soa> : public S<F> {
65+ template <template <class > class F_out >
66+ constexpr operator wrapper<S, F_out, Flag::soa>() { return {*this }; };
67+
68+ template <template <class > class F_out >
69+ constexpr operator wrapper<S, F_out, Flag::soa>() const { return {*this }; };
70+
71+ constexpr S<reference> operator [](size_t i) {
72+ return this ->template invoke_on_members <reference>(memberwise<reference, evaluate_at>{{i}});
73+ }
74+
75+ constexpr S<const_reference> operator [](size_t i) const {
76+ return this ->template invoke_on_members <reference>(memberwise<const_reference, evaluate_at>{{i}});
77+ }
78+
79+ private:
80+
81+ struct evaluate_at {
82+ size_t i;
83+ template <template <class > class F_in , class T >
84+ constexpr reference<T> operator ()(F_in<T> & t) const { return t[i]; }
85+ template <template <class > class F_in , class T >
86+ constexpr const_reference<T> operator ()(const F_in<T> & t) const { return t[i]; }
87+ };
88+
89+ template <template <class > class F_out , class FunctionObject >
90+ struct memberwise {
91+ FunctionObject f;
92+ template <class ... Args> // HACK: NVCC cannot deduce template parameters of f.operator() like so: { f(args)... }
93+ constexpr S<F_out> operator ()(Args&... args) const { return {f.template operator ()<F>(args)...}; }
94+ };
95+ };
96+
97+ template <template <template <class > class > class S , template <class > class F >
98+ using SoA = wrapper<S, F, Flag::soa>;
99+
100+ } // namespace MemLayout
101+
102+ #endif // MEMLAYOUT_H
0 commit comments