Skip to content

Commit 3ecb319

Browse files
Use Macro instead of CRTP
1 parent 0c104dc commit 3ecb319

21 files changed

Lines changed: 173 additions & 375 deletions

GPU/Common/CMakeLists.txt

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,7 @@ set(HDRS_INSTALL
2727
GPUROOTSMatrixFwd.h
2828
GPUROOTCartesianFwd.h
2929
GPUDebugStreamer.h
30-
helper.h
31-
wrapper.h)
30+
MemLayout.h)
3231

3332
if(ALIGPU_BUILD_TYPE STREQUAL "O2")
3433
o2_add_library(${MODULE}

GPU/Common/MemLayout.h

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
#ifndef MEMLAYOUT_H
2+
#define MEMLAYOUT_H
3+
4+
#include "GPUCommonDefAPI.h"
5+
6+
#define EXPAND(...) __VA_ARGS__
7+
#define MEMLAYOUT_MEMBERFUNCTIONS(STRUCT_NAME, ...) \
8+
template <template <class> class F_out> \
9+
constexpr operator STRUCT_NAME<F_out>() { return { EXPAND(__VA_ARGS__) }; } \
10+
template <template <class> class F_out> \
11+
constexpr operator STRUCT_NAME<F_out>() const { return { EXPAND(__VA_ARGS__) }; } \
12+
template <template <class> class F_out, class FunctionObject> \
13+
constexpr STRUCT_NAME<F_out> invoke_on_members(FunctionObject f) { return {f(EXPAND(__VA_ARGS__))}; } \
14+
template <template <class> class F_out, class FunctionObject> \
15+
constexpr STRUCT_NAME<F_out> invoke_on_members(FunctionObject f) const { return {f(EXPAND(__VA_ARGS__))}; } \
16+
17+
namespace MemLayout {
18+
19+
using size_t = decltype(sizeof 0);
20+
21+
template <class T> using value = T;
22+
23+
template <class T> using reference = T&;
24+
template <class T> using reference_restrict = T& GPUrestrict();
25+
26+
template <class T> using const_reference = const T&;
27+
template <class T> using const_reference_restrict = const T& GPUrestrict();
28+
29+
template <class T> using pointer = T*;
30+
template <class T> using pointer_restrtict = T* GPUrestrict();
31+
32+
template <class T> using const_pointer = const T*;
33+
template <class T> using const_pointer_restrict = const T* GPUrestrict();
34+
35+
enum Flag { soa, aos };
36+
37+
template <template <template <class> class> class S, template <class> class F, Flag L>
38+
struct wrapper;
39+
40+
template <template <template <class> class> class S, template <class> class F>
41+
struct wrapper<S, F, Flag::aos> : public F<S<value>> {
42+
using Base = F<S<value>>;
43+
44+
template <template <class> class F_out>
45+
constexpr operator wrapper<S, F_out, Flag::aos>() { return {*static_cast<Base*>(this)}; };
46+
47+
template <template <class> class F_out>
48+
constexpr operator wrapper<S, F_out, Flag::aos>() const { return {*static_cast<const Base*>(this)}; };
49+
50+
constexpr S<reference> operator[](size_t i) {
51+
return static_cast<Base*>(this)->operator[](i);
52+
}
53+
54+
constexpr S<const_reference> operator[](size_t i) const {
55+
return static_cast<const Base*>(this)->operator[](i);
56+
}
57+
};
58+
59+
template <template <template <class> class> class S, template <class> class F>
60+
using AoS = wrapper<S, F, Flag::aos>;
61+
62+
// The types S<value>, S<reference>, and S<const_reference> need to be aggregate constructible
63+
template <template <template <class> class> class S, template <class> class F>
64+
struct wrapper<S, F, Flag::soa> : public S<F> {
65+
template <template <class> class F_out>
66+
constexpr operator wrapper<S, F_out, Flag::soa>() { return {*this}; };
67+
68+
template <template <class> class F_out>
69+
constexpr operator wrapper<S, F_out, Flag::soa>() const { return {*this}; };
70+
71+
constexpr S<reference> operator[](size_t i) {
72+
return this->template invoke_on_members<reference>(memberwise<reference, evaluate_at>{{i}});
73+
}
74+
75+
constexpr S<const_reference> operator[](size_t i) const {
76+
return this->template invoke_on_members<reference>(memberwise<const_reference, evaluate_at>{{i}});
77+
}
78+
79+
private:
80+
81+
struct evaluate_at {
82+
size_t i;
83+
template <template <class> class F_in, class T>
84+
constexpr reference<T> operator()(F_in<T> & t) const { return t[i]; }
85+
template <template <class> class F_in, class T>
86+
constexpr const_reference<T> operator()(const F_in<T> & t) const { return t[i]; }
87+
};
88+
89+
template <template <class> class F_out, class FunctionObject>
90+
struct memberwise {
91+
FunctionObject f;
92+
template <class... Args> // HACK: NVCC cannot deduce template parameters of f.operator() like so: { f(args)... }
93+
constexpr S<F_out> operator()(Args&... args) const { return {f.template operator()<F>(args)...}; }
94+
};
95+
};
96+
97+
template <template <template <class> class> class S, template <class> class F>
98+
using SoA = wrapper<S, F, Flag::soa>;
99+
100+
} // namespace MemLayout
101+
102+
#endif // MEMLAYOUT_H

GPU/Common/helper.h

Lines changed: 0 additions & 132 deletions
This file was deleted.

GPU/Common/wrapper.h

Lines changed: 0 additions & 72 deletions
This file was deleted.

GPU/GPUTracking/DataTypes/GPUDataTypes.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#endif
2626
#include "GPUTRDDef.h"
2727
#include "GPUTPCDef.h"
28+
#include "MemLayout.h"
2829

2930
struct AliHLTTPCClusterMCLabel;
3031
struct AliHLTTPCRawCluster;
@@ -226,7 +227,7 @@ struct GPUTrackingInOutPointers {
226227
const AliHLTTPCRawCluster* rawClusters[NSECTORS] = {nullptr};
227228
uint32_t nRawClusters[NSECTORS] = {0};
228229
const o2::tpc::ClusterNativeAccess* clustersNative = nullptr;
229-
const GPUTPCTrackSkeleton<wrapper::value>* sectorTracks[NSECTORS] = {nullptr}; // GPUTPCTrack
230+
const GPUTPCTrackSkeleton<MemLayout::value>* sectorTracks[NSECTORS] = {nullptr}; // GPUTPCTrack
230231
uint32_t nSectorTracks[NSECTORS] = {0};
231232
const GPUTPCHitId* sectorClusters[NSECTORS] = {nullptr};
232233
uint32_t nSectorClusters[NSECTORS] = {0};

GPU/GPUTracking/Global/GPUChainTracking.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
#include "GPUChain.h"
1919
#include "GPUDataTypes.h"
20+
#include "MemLayout.h"
2021
#include <atomic>
2122
#include <mutex>
2223
#include <functional>
@@ -107,7 +108,7 @@ class GPUChainTracking : public GPUChain
107108
std::unique_ptr<AliHLTTPCRawCluster[]> rawClusters[NSECTORS];
108109
std::unique_ptr<o2::tpc::ClusterNative[]> clustersNative;
109110
std::unique_ptr<o2::tpc::ClusterNativeAccess> clusterNativeAccess;
110-
std::unique_ptr<GPUTPCTrackSkeleton<wrapper::value>[]> sectorTracks[NSECTORS];
111+
std::unique_ptr<GPUTPCTrackSkeleton<MemLayout::value>[]> sectorTracks[NSECTORS];
111112
std::unique_ptr<GPUTPCHitId[]> sectorClusters[NSECTORS];
112113
std::unique_ptr<AliHLTTPCClusterMCLabel[]> mcLabelsTPC;
113114
std::unique_ptr<GPUTPCMCInfo[]> mcInfosTPC;

GPU/GPUTracking/Merger/GPUTPCGMSectorTrack.cxx

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,15 @@
1818
#include "GPUO2DataTypes.h"
1919
#include "GPUTPCGMMerger.h"
2020
#include "GPUTPCConvertImpl.h"
21+
#include "MemLayout.h"
2122
#include "GPUParam.inc"
2223

2324
using namespace o2::gpu;
2425
using namespace o2::tpc;
2526

2627
GPUd() void GPUTPCGMSectorTrack::Set(const GPUTPCGMMerger* merger, const GPUTPCTrack* sectorTr, float alpha, int32_t sector)
2728
{
28-
GPUTPCBaseTrackParamSkeleton<wrapper::const_reference> t = sectorTr->Param();
29+
GPUTPCBaseTrackParamSkeleton<MemLayout::const_reference> t = sectorTr->Param();
2930
mOrigTrack = sectorTr;
3031
mParam.mX = t.GetX();
3132
mParam.mY = t.GetY();

0 commit comments

Comments
 (0)