33
44#include " GPUCommonDefAPI.h"
55
6- #define EXPAND (...) __VA_ARGS__
7- #define MEMLAYOUT_MEMBERFUNCTIONS (STRUCT_NAME , ...) \
8- template <template <class > class F_out > \
9- constexpr operator STRUCT_NAME <F_out>() { return { EXPAND (__VA_ARGS__) }; } \
10- template <template <class > class F_out > \
11- constexpr operator STRUCT_NAME <F_out>() const { return { EXPAND (__VA_ARGS__) }; } \
12- template <template <class > class F_out , class FunctionObject > \
13- constexpr STRUCT_NAME <F_out> invoke_on_members (FunctionObject f) { return {f (EXPAND (__VA_ARGS__))}; } \
14- template <template <class > class F_out , class FunctionObject > \
15- constexpr STRUCT_NAME <F_out> invoke_on_members (FunctionObject f) const { return {f (EXPAND (__VA_ARGS__))}; } \
16-
176namespace MemLayout {
187
198using size_t = decltype (sizeof 0 );
@@ -62,36 +51,21 @@ using AoS = wrapper<S, F, Flag::aos>;
6251// The types S<value>, S<reference>, and S<const_reference> need to be aggregate constructible
6352template <template <template <class > class > class S , template <class > class F >
6453struct wrapper <S, F, Flag::soa> : public S<F> {
54+ using Base = S<F>;
55+
6556 template <template <class > class F_out >
6657 constexpr operator wrapper<S, F_out, Flag::soa>() { return {*this }; };
6758
6859 template <template <class > class F_out >
6960 constexpr operator wrapper<S, F_out, Flag::soa>() const { return {*this }; };
7061
7162 constexpr S<reference> operator [](size_t i) {
72- return this -> template invoke_on_members <reference>(memberwise<reference, evaluate_at>{{i}} );
63+ return static_cast <Base*>( this )-> operator [](i );
7364 }
7465
7566 constexpr S<const_reference> operator [](size_t i) const {
76- return this -> template invoke_on_members <reference>(memberwise<const_reference, evaluate_at>{{i}} );
67+ return static_cast < const Base*>( this )-> operator [](i );
7768 }
78-
79- private:
80-
81- struct evaluate_at {
82- size_t i;
83- template <template <class > class F_in , class T >
84- constexpr reference<T> operator ()(F_in<T> & t) const { return t[i]; }
85- template <template <class > class F_in , class T >
86- constexpr const_reference<T> operator ()(const F_in<T> & t) const { return t[i]; }
87- };
88-
89- template <template <class > class F_out , class FunctionObject >
90- struct memberwise {
91- FunctionObject f;
92- template <class ... Args> // HACK: NVCC cannot deduce template parameters of f.operator() like so: { f(args)... }
93- constexpr S<F_out> operator ()(Args&... args) const { return {f.template operator ()<F>(args)...}; }
94- };
9569};
9670
9771template <template <template <class > class > class S , template <class > class F >
@@ -132,6 +106,14 @@ using enable_if_equal = type_traits::enable_if_t<type_traits::is_same<T_left, T_
132106template <class T_left , class T_right >
133107using disable_if_equal = type_traits::enable_if_t <!type_traits::is_same<T_left, T_right>::value>;
134108
109+ template <class T >
110+ using disable_if_scalar = type_traits::enable_if_t <
111+ !type_traits::is_same<T, value<int >>::value &&
112+ !type_traits::is_same<T, reference<int >>::value &&
113+ !type_traits::is_same<T, reference_restrict<int >>::value &&
114+ !type_traits::is_same<T, const_reference_restrict<int >>::value
115+ >;
116+
135117#if __cplusplus >= 202002L
136118template <template <class > class F_left , template <class > class F_right >
137119concept is_same = type_traits::is_same<F_left<int >, F_right<int >>::value;
@@ -143,6 +125,34 @@ template<template <class> class F>
143125concept is_const_reference = is_same<F, const_reference>;
144126#endif
145127
128+ template <
129+ template <template <class > class > class S ,
130+ template <class > class F_out ,
131+ class ... Args
132+ >
133+ constexpr S<F_out> eval_at (size_t i, Args& ...args) { return {(args[i])...}; }
134+
135+ template <
136+ template <template <class > class > class S ,
137+ template <class > class F_out ,
138+ class ... Args
139+ >
140+ constexpr S<F_out> eval_at (size_t i, const Args& ...args) { return {(args[i])...}; }
141+
146142} // namespace MemLayout
147143
144+ #define MEMLAYOUT_MEMBERFUNCTIONS (STRUCT , CONTAINER , ...) \
145+ template <template <class > class F_out > \
146+ constexpr operator STRUCT <F_out>() { return { __VA_ARGS__ }; } \
147+ template <template <class > class F_out > \
148+ constexpr operator STRUCT <F_out>() const { return { __VA_ARGS__ }; } \
149+ template <class T = int , class R = MemLayout::disable_if_scalar<CONTAINER <T>>> \
150+ constexpr STRUCT <MemLayout::reference> operator [] (MemLayout::size_t i) { \
151+ return MemLayout::eval_at<STRUCT , MemLayout::reference>(i, __VA_ARGS__); \
152+ } \
153+ template <class T = int , class R = MemLayout::disable_if_scalar<CONTAINER <T>>> \
154+ constexpr STRUCT <MemLayout::const_reference> operator [] (MemLayout::size_t i) const { \
155+ return MemLayout::eval_at<STRUCT , MemLayout::const_reference>(i, __VA_ARGS__); \
156+ } \
157+
148158#endif // MEMLAYOUT_H
0 commit comments