Skip to content

Commit 43141bf

Browse files
Switch SoA and AoS
1 parent 0486cf9 commit 43141bf

3 files changed

Lines changed: 35 additions & 56 deletions

File tree

GPU/Common/MemLayout.h

Lines changed: 9 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -23,53 +23,15 @@ template <class T> using const_pointer_restrict = const T* GPUrestrict();
2323

2424
enum Flag { soa, aos };
2525

26+
// The types S<value>, S<reference>, and S<const_reference> need to be aggregate constructible
2627
template <template <template <class> class> class S, template <class> class F, Flag L>
2728
struct wrapper;
2829

2930
template <template <template <class> class> class S, template <class> class F>
30-
struct wrapper<S, F, Flag::aos> : public F<S<value>> {
31-
using Base = F<S<value>>;
32-
33-
template <template <class> class F_out>
34-
constexpr operator wrapper<S, F_out, Flag::aos>() { return {*static_cast<Base*>(this)}; };
35-
36-
template <template <class> class F_out>
37-
constexpr operator wrapper<S, F_out, Flag::aos>() const { return {*static_cast<const Base*>(this)}; };
31+
struct wrapper<S, F, Flag::aos> { using type = F<S<value>>; };
3832

39-
constexpr S<reference> operator[](size_t i) {
40-
return static_cast<Base*>(this)->operator[](i);
41-
}
42-
43-
constexpr S<const_reference> operator[](size_t i) const {
44-
return static_cast<const Base*>(this)->operator[](i);
45-
}
46-
};
47-
48-
template <template <template <class> class> class S, template <class> class F>
49-
using AoS = wrapper<S, F, Flag::aos>;
50-
51-
// The types S<value>, S<reference>, and S<const_reference> need to be aggregate constructible
5233
template <template <template <class> class> class S, template <class> class F>
53-
struct wrapper<S, F, Flag::soa> : public S<F> {
54-
using Base = S<F>;
55-
56-
template <template <class> class F_out>
57-
constexpr operator wrapper<S, F_out, Flag::soa>() { return {*this}; };
58-
59-
template <template <class> class F_out>
60-
constexpr operator wrapper<S, F_out, Flag::soa>() const { return {*this}; };
61-
62-
constexpr S<reference> operator[](size_t i) {
63-
return static_cast<Base*>(this)->operator[](i);
64-
}
65-
66-
constexpr S<const_reference> operator[](size_t i) const {
67-
return static_cast<const Base*>(this)->operator[](i);
68-
}
69-
};
70-
71-
template <template <template <class> class> class S, template <class> class F>
72-
using SoA = wrapper<S, F, Flag::soa>;
34+
struct wrapper<S, F, Flag::soa> { using type = S<F>; };
7335

7436
namespace type_traits {
7537

@@ -92,6 +54,9 @@ struct true_type {
9254
constexpr operator bool() const noexcept { return value; }
9355
};
9456

57+
template <class T>
58+
struct always_false : false_type {};
59+
9560
template<class T, class U>
9661
struct is_same : false_type {};
9762

@@ -143,11 +108,11 @@ constexpr S<F_out> eval_at(size_t i, const Args& ...args) { return {(args[i])...
143108

144109
#define MEMLAYOUT_MEMBERFUNCTIONS(STRUCT, CONTAINER, ...) \
145110
template <template <class> class F_out> \
146-
constexpr operator STRUCT<F_out>() { return { __VA_ARGS__ }; } \
111+
constexpr operator STRUCT<F_out>() { return { __VA_ARGS__ }; } \
147112
template <template <class> class F_out> \
148-
constexpr operator STRUCT<F_out>() const { return { __VA_ARGS__ }; } \
113+
constexpr operator STRUCT<F_out>() const { return { __VA_ARGS__ }; } \
149114
template<class T = int, class R = MemLayout::disable_if_scalar<CONTAINER<T>>> \
150-
constexpr STRUCT<MemLayout::reference> operator[] (MemLayout::size_t i) { \
115+
constexpr STRUCT<MemLayout::reference> operator[] (MemLayout::size_t i) { \
151116
return MemLayout::eval_at<STRUCT, MemLayout::reference>(i, __VA_ARGS__); \
152117
} \
153118
template<class T = int, class R = MemLayout::disable_if_scalar<CONTAINER<T>>> \

GPU/GPUTracking/SectorTracker/GPUTPCTracker.cxx

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -108,18 +108,26 @@ void GPUTPCTracker::RegisterMemoryAllocation()
108108
mMemoryResOutput = mRec->RegisterMemoryAllocation(this, &GPUTPCTracker::SetPointersOutput, type, "TPCTrackerTracks");
109109
}
110110

111-
GPUhd() void* GPUTPCTracker::SetPointersTracklets(void* mem)
112-
{
113-
computePointerWithAlignment(mem, mTracklets.mFirstRow, mNMaxTracklets);
114-
computePointerWithAlignment(mem, mTracklets.mLastRow, mNMaxTracklets);
111+
GPUhd() void GPUTPCTracker::SetPointersTrackletsHelper(void* & mem, GPUTPCTracker::TrackletArrayType<MemLayout::Flag::aos>& tracklets) {
112+
computePointerWithAlignment(mem, tracklets, mNMaxTracklets);
113+
}
114+
115+
GPUhd() void GPUTPCTracker::SetPointersTrackletsHelper(void* & mem, GPUTPCTracker::TrackletArrayType<MemLayout::Flag::soa>& tracklets) {
116+
computePointerWithAlignment(mem, tracklets.mFirstRow, mNMaxTracklets);
117+
computePointerWithAlignment(mem, tracklets.mLastRow, mNMaxTracklets);
115118

116-
computePointerWithAlignment(mem, mTracklets.mParam.mX, mNMaxTracklets);
117-
computePointerWithAlignment(mem, mTracklets.mParam.mC, mNMaxTracklets);
118-
computePointerWithAlignment(mem, mTracklets.mParam.mZOffset, mNMaxTracklets);
119-
computePointerWithAlignment(mem, mTracklets.mParam.mP, mNMaxTracklets);
119+
computePointerWithAlignment(mem, tracklets.mParam.mX, mNMaxTracklets);
120+
computePointerWithAlignment(mem, tracklets.mParam.mC, mNMaxTracklets);
121+
computePointerWithAlignment(mem, tracklets.mParam.mZOffset, mNMaxTracklets);
122+
computePointerWithAlignment(mem, tracklets.mParam.mP, mNMaxTracklets);
120123

121-
computePointerWithAlignment(mem, mTracklets.mHitWeight, mNMaxTracklets);
122-
computePointerWithAlignment(mem, mTracklets.mFirstHit, mNMaxTracklets);
124+
computePointerWithAlignment(mem, tracklets.mHitWeight, mNMaxTracklets);
125+
computePointerWithAlignment(mem, tracklets.mFirstHit, mNMaxTracklets);
126+
}
127+
128+
GPUhd() void* GPUTPCTracker::SetPointersTracklets(void* mem)
129+
{
130+
SetPointersTrackletsHelper(mem, mTracklets);
123131
computePointerWithAlignment(mem, mTrackletRowHits, mNMaxRowHits);
124132
return mem;
125133
}

GPU/GPUTracking/SectorTracker/GPUTPCTracker.h

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,9 @@ class GPUTPCTracker : public GPUProcessor
6161
void DumpTrackletHits(std::ostream& out); // Same for Track Hits
6262
#endif
6363

64+
template <MemLayout::Flag layout>
65+
using TrackletArrayType = MemLayout::wrapper<GPUTPCTrackletSkeleton, MemLayout::pointer, layout>::type;
66+
6467
struct commonMemoryStruct {
6568
GPUAtomic(uint32_t) nStartHits = 0; // number of start hits
6669
GPUAtomic(uint32_t) nTracklets = 0; // number of tracklets
@@ -106,6 +109,8 @@ class GPUTPCTracker : public GPUProcessor
106109
void* SetPointersScratch(void* mem);
107110
void* SetPointersScratchHost(void* mem);
108111
void* SetPointersCommon(void* mem);
112+
void SetPointersTrackletsHelper(void* & mem, TrackletArrayType<MemLayout::Flag::aos>& tracklets);
113+
void SetPointersTrackletsHelper(void* & mem, TrackletArrayType<MemLayout::Flag::soa>& tracklets);
109114
void* SetPointersTracklets(void* mem);
110115
void* SetPointersOutput(void* mem);
111116
void RegisterMemoryAllocation();
@@ -189,7 +194,8 @@ class GPUTPCTracker : public GPUProcessor
189194

190195
GPUhd() GPUglobalref() GPUTPCTrackletSkeleton<MemLayout::reference_restrict> Tracklet(int32_t i) { return mTracklets[i]; }
191196

192-
GPUhd() GPUglobalref() GPUTPCTrackletSkeleton<MemLayout::pointer> Tracklets() const { return mTracklets; }
197+
template <MemLayout::Flag layout>
198+
GPUhd() GPUglobalref() TrackletArrayType<layout> Tracklets() const { return mTracklets; }
193199
GPUhd() GPUglobalref() calink* TrackletRowHits() const { return mTrackletRowHits; }
194200

195201
GPUhd() GPUglobalref() GPUAtomic(uint32_t) * NTracks() const { return &mCommonMem->nTracks; }
@@ -243,7 +249,7 @@ class GPUTPCTracker : public GPUProcessor
243249
// event
244250
GPUglobalref() commonMemoryStruct* mCommonMem = nullptr; // common event memory
245251
GPUglobalref() GPUTPCHitId* mTrackletStartHits = nullptr; // start hits for the tracklets
246-
GPUglobalref() GPUTPCTrackletSkeleton<MemLayout::pointer> mTracklets; // tracklets
252+
GPUglobalref() TrackletArrayType<MemLayout::Flag::soa> mTracklets; // tracklets
247253
GPUglobalref() calink* mTrackletRowHits = nullptr; // Hits for each Tracklet in each row
248254
GPUglobalref() GPUTPCTrackSkeleton<MemLayout::value>* mTracks = nullptr; // reconstructed tracks
249255
GPUglobalref() GPUTPCHitId* mTrackHits = nullptr; // array of track hit numbers

0 commit comments

Comments
 (0)