Skip to content

Commit 0486cf9

Browse files
Recursive operator[]
1 parent 9f4b216 commit 0486cf9

6 files changed

Lines changed: 54 additions & 49 deletions

File tree

GPU/Common/MemLayout.h

Lines changed: 40 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,6 @@
33

44
#include "GPUCommonDefAPI.h"
55

6-
#define EXPAND(...) __VA_ARGS__
7-
#define MEMLAYOUT_MEMBERFUNCTIONS(STRUCT_NAME, ...) \
8-
template <template <class> class F_out> \
9-
constexpr operator STRUCT_NAME<F_out>() { return { EXPAND(__VA_ARGS__) }; } \
10-
template <template <class> class F_out> \
11-
constexpr operator STRUCT_NAME<F_out>() const { return { EXPAND(__VA_ARGS__) }; } \
12-
template <template <class> class F_out, class FunctionObject> \
13-
constexpr STRUCT_NAME<F_out> invoke_on_members(FunctionObject f) { return {f(EXPAND(__VA_ARGS__))}; } \
14-
template <template <class> class F_out, class FunctionObject> \
15-
constexpr STRUCT_NAME<F_out> invoke_on_members(FunctionObject f) const { return {f(EXPAND(__VA_ARGS__))}; } \
16-
176
namespace MemLayout {
187

198
using size_t = decltype(sizeof 0);
@@ -62,36 +51,21 @@ using AoS = wrapper<S, F, Flag::aos>;
6251
// The types S<value>, S<reference>, and S<const_reference> need to be aggregate constructible
6352
template <template <template <class> class> class S, template <class> class F>
6453
struct wrapper<S, F, Flag::soa> : public S<F> {
54+
using Base = S<F>;
55+
6556
template <template <class> class F_out>
6657
constexpr operator wrapper<S, F_out, Flag::soa>() { return {*this}; };
6758

6859
template <template <class> class F_out>
6960
constexpr operator wrapper<S, F_out, Flag::soa>() const { return {*this}; };
7061

7162
constexpr S<reference> operator[](size_t i) {
72-
return this->template invoke_on_members<reference>(memberwise<reference, evaluate_at>{{i}});
63+
return static_cast<Base*>(this)->operator[](i);
7364
}
7465

7566
constexpr S<const_reference> operator[](size_t i) const {
76-
return this->template invoke_on_members<reference>(memberwise<const_reference, evaluate_at>{{i}});
67+
return static_cast<const Base*>(this)->operator[](i);
7768
}
78-
79-
private:
80-
81-
struct evaluate_at {
82-
size_t i;
83-
template <template <class> class F_in, class T>
84-
constexpr reference<T> operator()(F_in<T> & t) const { return t[i]; }
85-
template <template <class> class F_in, class T>
86-
constexpr const_reference<T> operator()(const F_in<T> & t) const { return t[i]; }
87-
};
88-
89-
template <template <class> class F_out, class FunctionObject>
90-
struct memberwise {
91-
FunctionObject f;
92-
template <class... Args> // HACK: NVCC cannot deduce template parameters of f.operator() like so: { f(args)... }
93-
constexpr S<F_out> operator()(Args&... args) const { return {f.template operator()<F>(args)...}; }
94-
};
9569
};
9670

9771
template <template <template <class> class> class S, template <class> class F>
@@ -132,6 +106,14 @@ using enable_if_equal = type_traits::enable_if_t<type_traits::is_same<T_left, T_
132106
template<class T_left, class T_right>
133107
using disable_if_equal = type_traits::enable_if_t<!type_traits::is_same<T_left, T_right>::value>;
134108

109+
template<class T>
110+
using disable_if_scalar = type_traits::enable_if_t<
111+
!type_traits::is_same<T, value<int>>::value &&
112+
!type_traits::is_same<T, reference<int>>::value &&
113+
!type_traits::is_same<T, reference_restrict<int>>::value &&
114+
!type_traits::is_same<T, const_reference_restrict<int>>::value
115+
>;
116+
135117
#if __cplusplus >= 202002L
136118
template<template <class> class F_left, template <class> class F_right>
137119
concept is_same = type_traits::is_same<F_left<int>, F_right<int>>::value;
@@ -143,6 +125,34 @@ template<template <class> class F>
143125
concept is_const_reference = is_same<F, const_reference>;
144126
#endif
145127

128+
template <
129+
template <template <class> class> class S,
130+
template <class> class F_out,
131+
class... Args
132+
>
133+
constexpr S<F_out> eval_at(size_t i, Args& ...args) { return {(args[i])...}; }
134+
135+
template <
136+
template <template <class> class> class S,
137+
template <class> class F_out,
138+
class... Args
139+
>
140+
constexpr S<F_out> eval_at(size_t i, const Args& ...args) { return {(args[i])...}; }
141+
146142
} // namespace MemLayout
147143

144+
#define MEMLAYOUT_MEMBERFUNCTIONS(STRUCT, CONTAINER, ...) \
145+
template <template <class> class F_out> \
146+
constexpr operator STRUCT<F_out>() { return { __VA_ARGS__ }; } \
147+
template <template <class> class F_out> \
148+
constexpr operator STRUCT<F_out>() const { return { __VA_ARGS__ }; } \
149+
template<class T = int, class R = MemLayout::disable_if_scalar<CONTAINER<T>>> \
150+
constexpr STRUCT<MemLayout::reference> operator[] (MemLayout::size_t i) { \
151+
return MemLayout::eval_at<STRUCT, MemLayout::reference>(i, __VA_ARGS__); \
152+
} \
153+
template<class T = int, class R = MemLayout::disable_if_scalar<CONTAINER<T>>> \
154+
constexpr STRUCT<MemLayout::const_reference> operator[] (MemLayout::size_t i) const { \
155+
return MemLayout::eval_at<STRUCT, MemLayout::const_reference>(i, __VA_ARGS__); \
156+
} \
157+
148158
#endif // MEMLAYOUT_H

GPU/GPUTracking/SectorTracker/GPUTPCBaseTrackParam.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ namespace o2::gpu
3030
*/
3131
template <template <class> class F>
3232
struct GPUTPCBaseTrackParamSkeleton {
33-
MEMLAYOUT_MEMBERFUNCTIONS(GPUTPCBaseTrackParamSkeleton, mX, mC, mZOffset, mP)
33+
MEMLAYOUT_MEMBERFUNCTIONS(GPUTPCBaseTrackParamSkeleton, F, mX, mC, mZOffset, mP)
3434

3535
GPUhd() void ElementwiseAssignment(GPUTPCBaseTrackParamSkeleton<MemLayout::const_reference> v) {
3636
mX = v.mX;

GPU/GPUTracking/SectorTracker/GPUTPCTrack.h

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,15 @@ template <template <class> class F>
3232
class GPUTPCTrackSkeleton
3333
{
3434
public:
35-
MEMLAYOUT_MEMBERFUNCTIONS(GPUTPCTrackSkeleton, mFirstHitID, mNHits, mLocalTrackId, mParam)
35+
MEMLAYOUT_MEMBERFUNCTIONS(GPUTPCTrackSkeleton, F, mFirstHitID, mNHits, mLocalTrackId, mParam)
3636
#if !defined(GPUCA_GPUCODE)
37-
GPUTPCTrackSkeleton() : mFirstHitID(0), mNHits(0), mLocalTrackId(-1), mParam()
38-
{
39-
}
37+
constexpr GPUTPCTrackSkeleton() : mFirstHitID(0), mNHits(0), mLocalTrackId(-1), mParam() { }
38+
constexpr GPUTPCTrackSkeleton(
39+
F<int32_t> FirstHitID,
40+
F<int32_t> NHits,
41+
F<int32_t> LocalTrackId,
42+
GPUTPCBaseTrackParamSkeleton<F> Param
43+
) : mFirstHitID(FirstHitID), mNHits(NHits), mLocalTrackId(LocalTrackId), mParam(Param) { }
4044
~GPUTPCTrackSkeleton() = default;
4145
#endif //! GPUCA_GPUCODE
4246

@@ -50,7 +54,7 @@ class GPUTPCTrackSkeleton
5054
GPUhd() void SetFirstHitID(int32_t v) { mFirstHitID = v; }
5155
GPUhd() void SetParam(GPUTPCBaseTrackParamSkeleton<MemLayout::const_reference> v) { mParam.ElementwiseAssignment(v); }
5256

53-
private:
57+
//private:
5458
F<int32_t> mFirstHitID; // index of the first track cell in the track->cell pointer array
5559
F<int32_t> mNHits; // number of track cells
5660
F<int32_t> mLocalTrackId; // Id of local track this extrapolated track belongs to, index of this track itself if it is a local track

GPU/GPUTracking/SectorTracker/GPUTPCTrackParam.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ template <template <class> class F>
3636
class GPUTPCTrackParamSkeleton
3737
{
3838
public:
39-
MEMLAYOUT_MEMBERFUNCTIONS(GPUTPCTrackParamSkeleton, mParam, mSignCosPhi, mChi2, mNDF)
39+
MEMLAYOUT_MEMBERFUNCTIONS(GPUTPCTrackParamSkeleton, F, mParam, mSignCosPhi, mChi2, mNDF)
4040

4141
struct GPUTPCTrackFitParam {
4242
float bethe, e, theta2, EP2, sigmadE2, k22, k33, k43, k44; // parameters
@@ -188,7 +188,6 @@ GPUd() void GPUTPCTrackParamSkeleton<F>::InitParam()
188188
SetCov(14, 1000.f);
189189
SetZOffset(0);
190190
}
191-
192191
} // namespace o2::gpu
193192

194193
#endif // GPUTPCTRACKPARAM_H

GPU/GPUTracking/SectorTracker/GPUTPCTracker.h

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -187,15 +187,7 @@ class GPUTPCTracker : public GPUProcessor
187187
GPUhd() GPUglobalref() GPUTPCHitId* TrackletStartHits() { return mTrackletStartHits; }
188188
GPUhd() GPUglobalref() GPUTPCHitId* TrackletTmpStartHits() const { return mTrackletTmpStartHits; }
189189

190-
GPUhd() GPUglobalref() GPUTPCTrackletSkeleton<MemLayout::reference_restrict> Tracklet(int32_t i) {
191-
return {
192-
mTracklets.mFirstRow[i],
193-
mTracklets.mLastRow[i],
194-
{mTracklets.mParam.mX[i], mTracklets.mParam.mC[i], mTracklets.mParam.mZOffset[i], mTracklets.mParam.mP[i]},
195-
mTracklets.mHitWeight[i],
196-
mTracklets.mFirstHit[i]
197-
};
198-
}
190+
GPUhd() GPUglobalref() GPUTPCTrackletSkeleton<MemLayout::reference_restrict> Tracklet(int32_t i) { return mTracklets[i]; }
199191

200192
GPUhd() GPUglobalref() GPUTPCTrackletSkeleton<MemLayout::pointer> Tracklets() const { return mTracklets; }
201193
GPUhd() GPUglobalref() calink* TrackletRowHits() const { return mTrackletRowHits; }
@@ -253,7 +245,7 @@ class GPUTPCTracker : public GPUProcessor
253245
GPUglobalref() GPUTPCHitId* mTrackletStartHits = nullptr; // start hits for the tracklets
254246
GPUglobalref() GPUTPCTrackletSkeleton<MemLayout::pointer> mTracklets; // tracklets
255247
GPUglobalref() calink* mTrackletRowHits = nullptr; // Hits for each Tracklet in each row
256-
GPUglobalref() GPUTPCTrackSkeleton<MemLayout::value>* mTracks = nullptr; // reconstructed tracks
248+
GPUglobalref() GPUTPCTrackSkeleton<MemLayout::value>* mTracks = nullptr; // reconstructed tracks
257249
GPUglobalref() GPUTPCHitId* mTrackHits = nullptr; // array of track hit numbers
258250

259251
static int32_t StarthitSortComparison(const void* a, const void* b);

GPU/GPUTracking/SectorTracker/GPUTPCTracklet.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ template <template <class> class F>
3131
class GPUTPCTrackletSkeleton
3232
{
3333
public:
34-
MEMLAYOUT_MEMBERFUNCTIONS(GPUTPCTrackletSkeleton, mFirstRow, mLastRow, mParam, mHitWeight, mFirstHit)
34+
MEMLAYOUT_MEMBERFUNCTIONS(GPUTPCTrackletSkeleton, F, mFirstRow, mLastRow, mParam, mHitWeight, mFirstHit)
3535
#if !defined(GPUCA_GPUCODE)
3636
//GPUTPCTrackletSkeleton() : mFirstRow(0), mLastRow(0), mParam(), mHitWeight(0), mFirstHit(0) {};
3737
#endif //! GPUCA_GPUCODE

0 commit comments

Comments
 (0)