Skip to content

Commit c6039f5

Browse files
use templates again
1 parent 4321723 commit c6039f5

5 files changed

Lines changed: 266 additions & 73 deletions

File tree

GPU/Common/MemLayout.h

Lines changed: 201 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@
55

66
namespace MemLayout {
77

8-
using size_t = decltype(sizeof 0);
9-
108
template <class T> using value = T;
119

1210
template <class T> using reference = T&;
@@ -21,6 +19,9 @@ template <class T> using pointer_restrtict = T* GPUrestrict();
2119
template <class T> using const_pointer = const T*;
2220
template <class T> using const_pointer_restrict = const T* GPUrestrict();
2321

22+
using size_t = decltype(sizeof 0);
23+
using ptrdiff_t = decltype(static_cast<int*>(nullptr) - static_cast<int*>(nullptr));
24+
2425
enum Flag { soa, aos };
2526

2627
// The types S<value>, S<reference>, and S<const_reference> need to be aggregate constructible
@@ -35,89 +36,222 @@ struct wrapper<S, F, Flag::soa> { using type = S<F>; };
3536

3637
namespace type_traits {
3738

38-
template<bool B, class T = void>
39-
struct enable_if {};
40-
41-
template<class T>
42-
struct enable_if<true, T> { typedef T type; };
39+
template<class T> struct remove_reference { using type = T; };
40+
template<class T> struct remove_reference<T&> { using type = T; };
41+
template<class T> struct remove_reference<T&&> { using type = T; };
4342

44-
template< bool B, class T = void >
45-
using enable_if_t = typename enable_if<B, T>::type;
43+
} // namespace type_traits
4644

47-
struct false_type {
48-
static constexpr bool value = false;
49-
constexpr operator bool() const noexcept { return value; }
50-
};
45+
template< class T >
46+
constexpr type_traits::remove_reference<T>::type&& move(T&& t) noexcept {
47+
return static_cast<typename type_traits::remove_reference<T>::type&&>(t);
48+
}
5149

52-
struct true_type {
53-
static constexpr bool value = true;
54-
constexpr operator bool() const noexcept { return value; }
50+
template <class SF>
51+
struct RandomAccessAt {
52+
size_t i;
53+
template <class... Args>
54+
constexpr SF operator()(Args& ...args) const { return {{}, args[i]...}; }
5555
};
5656

57-
template <class T>
58-
struct always_false : false_type {};
59-
60-
template<class T, class U>
61-
struct is_same : false_type {};
57+
template <class SF>
58+
struct GetPointer {
59+
template <class... Args>
60+
constexpr SF operator()(Args& ...args) const { return {{}, &args...}; }
61+
};
6262

63-
template<class T>
64-
struct is_same<T, T> : true_type {};
63+
template <class SF>
64+
struct AggregateConstructor {
65+
template <class... Args>
66+
constexpr SF operator()(Args& ...args) const { return {{}, args...}; }
67+
};
6568

66-
} // namespace type_traits
69+
template <
70+
template <class> class F_left,
71+
template <class> class F_right
72+
>
73+
struct CopyAssignment {
74+
template <class T>
75+
constexpr void operator()(F_left<T>& left, F_right<T>& right) const { left = right; }
76+
};
6777

68-
template<class T_left, class T_right>
69-
using enable_if_equal = type_traits::enable_if_t<type_traits::is_same<T_left, T_right>::value>;
70-
71-
template<class T_left, class T_right>
72-
using disable_if_equal = type_traits::enable_if_t<!type_traits::is_same<T_left, T_right>::value>;
73-
74-
template<class T>
75-
using disable_if_scalar = type_traits::enable_if_t<
76-
!type_traits::is_same<T, value<int>>::value &&
77-
!type_traits::is_same<T, reference<int>>::value &&
78-
!type_traits::is_same<T, reference_restrict<int>>::value &&
79-
!type_traits::is_same<T, const_reference_restrict<int>>::value
80-
>;
81-
82-
#if __cplusplus >= 202002L
83-
template<template <class> class F_left, template <class> class F_right>
84-
concept is_same = type_traits::is_same<F_left<int>, F_right<int>>::value;
85-
template<template <class> class F>
86-
concept is_value = is_same<F, value>;
87-
template<template <class> class F>
88-
concept is_reference = is_same<F, reference>;
89-
template<template <class> class F>
90-
concept is_const_reference = is_same<F, const_reference>;
91-
#endif
78+
template <
79+
template <class> class F_left,
80+
template <class> class F_right
81+
>
82+
struct MoveAssignment {
83+
template <class T>
84+
constexpr void operator()(F_left<T>& left, F_right<T>& right) const { left = move(right); }
85+
};
9286

9387
template <
9488
template <template <class> class> class S,
95-
template <class> class F_out,
96-
class... Args
89+
template <class> class F
9790
>
98-
constexpr S<F_out> eval_at(size_t i, Args& ...args) { return {(args[i])...}; }
91+
struct CRTP {
92+
using Derived = S<F>;
93+
template <template <class> class F_out>
94+
constexpr operator S<F_out>() {
95+
return static_cast<Derived*>(this)->apply(AggregateConstructor<S<F_out>>{});
96+
}
97+
template <template <class> class F_out>
98+
constexpr operator S<F_out>() const {
99+
return static_cast<const Derived*>(this)->apply(AggregateConstructor<S<F_out>>{});
100+
}
101+
constexpr S<reference> operator[] (size_t i) {
102+
return static_cast<Derived*>(this)->apply(RandomAccessAt<S<reference>>{i});
103+
}
104+
constexpr S<const_reference> operator[] (size_t i) const {
105+
return static_cast<const Derived*>(this)->apply(RandomAccessAt<S<const_reference>>{i});
106+
}
107+
constexpr S<pointer> operator& () {
108+
return static_cast<Derived*>(this)->apply(GetPointer<S<pointer>>{});
109+
}
110+
constexpr S<const_pointer> operator& () const {
111+
return static_cast<const Derived*>(this)->apply(GetPointer<S<const_pointer>>{});
112+
}
113+
constexpr S<reference> operator*() {
114+
return static_cast<Derived*>(this)->operator[](0);
115+
}
116+
constexpr S<const_reference> operator*() const {
117+
return static_cast<const Derived*>(this)->operator[](0);
118+
}
119+
};
120+
121+
template <template <template <class> class> class S>
122+
struct CRTP<S, value> {
123+
using Derived = S<value>;
124+
template <template <class> class F_out>
125+
constexpr operator S<F_out>() {
126+
return static_cast<Derived*>(this)->apply(AggregateConstructor<S<F_out>>{});
127+
}
128+
template <template <class> class F_out>
129+
constexpr operator S<F_out>() const {
130+
return static_cast<const Derived*>(this)->apply(AggregateConstructor<S<F_out>>{});
131+
}
132+
};
133+
134+
template <template <template <class> class> class S>
135+
struct CRTP<S, reference> {
136+
using Derived = S<reference>;
137+
template <template <class> class F_out>
138+
constexpr operator S<F_out>() {
139+
return static_cast<Derived*>(this)->apply(AggregateConstructor<S<F_out>>{});
140+
}
141+
template <template <class> class F_out>
142+
constexpr operator S<F_out>() const {
143+
return static_cast<const Derived*>(this)->apply(AggregateConstructor<S<F_out>>{});
144+
}
145+
template <template <class> class F_other>
146+
constexpr Derived& operator=(S<F_other>& other) {
147+
memberwise(*static_cast<Derived*>(this), other, CopyAssignment<reference, F_other>{});
148+
return *static_cast<Derived*>(this);
149+
}
150+
template <template <class> class F_other>
151+
constexpr Derived& operator=(S<F_other>&& other) {
152+
memberwise(*static_cast<Derived*>(this), other, MoveAssignment<reference, F_other>{});
153+
return *static_cast<Derived*>(this);
154+
}
155+
constexpr S<pointer> operator& () {
156+
return static_cast<Derived*>(this)->apply(GetPointer<S<pointer>>{});
157+
}
158+
constexpr S<const_pointer> operator& () const {
159+
return static_cast<const Derived*>(this)->apply(GetPointer<S<const_pointer>>{});
160+
}
161+
};
162+
163+
template <template <template <class> class> class S>
164+
struct CRTP<S, const_reference> {
165+
using Derived = S<const_reference>;
166+
template <template <class> class F_out>
167+
constexpr operator S<F_out>() const {
168+
return static_cast<const Derived*>(this)->apply(AggregateConstructor<S<F_out>>{});
169+
}
170+
constexpr S<const_pointer> operator& () const {
171+
return static_cast<const Derived*>(this)->apply(GetPointer<S<const_pointer>>{});
172+
}
173+
};
174+
175+
template <template <template <class> class> class S>
176+
struct CRTP<S, reference_restrict> {
177+
using Derived = S<reference_restrict>;
178+
template <template <class> class F_out>
179+
constexpr operator S<F_out>() {
180+
return static_cast<Derived*>(this)->apply(AggregateConstructor<S<F_out>>{});
181+
}
182+
template <template <class> class F_out>
183+
constexpr operator S<F_out>() const {
184+
return static_cast<const Derived*>(this)->apply(AggregateConstructor<S<F_out>>{});
185+
}
186+
template <template <class> class F_other>
187+
constexpr Derived& operator=(S<F_other>& other) {
188+
memberwise(*static_cast<Derived*>(this), other, CopyAssignment<reference_restrict, F_other>{});
189+
return *static_cast<Derived*>(this);
190+
}
191+
template <template <class> class F_other>
192+
constexpr Derived& operator=(S<F_other>&& other) {
193+
memberwise(*static_cast<Derived*>(this), other, MoveAssignment<reference_restrict, F_other>{});
194+
return *static_cast<Derived*>(this);
195+
}
196+
constexpr S<pointer_restrtict> operator& () {
197+
return static_cast<Derived*>(this)->apply(GetPointer<S<pointer_restrtict>>{});
198+
}
199+
constexpr S<const_pointer_restrict> operator& () const {
200+
return static_cast<const Derived*>(this)->apply(GetPointer<S<const_pointer_restrict>>{});
201+
}
202+
};
203+
204+
template <template <template <class> class> class S>
205+
struct CRTP<S, const_reference_restrict> {
206+
using Derived = S<const_reference_restrict>;
207+
template <template <class> class F_out>
208+
constexpr operator S<F_out>() const {
209+
return static_cast<const Derived*>(this)->apply(AggregateConstructor<S<F_out>>{});
210+
}
211+
constexpr S<const_pointer_restrict> operator& () const {
212+
return static_cast<const Derived*>(this)->apply(GetPointer<S<const_pointer_restrict>>{});
213+
}
214+
};
99215

100216
template <
101217
template <template <class> class> class S,
102-
template <class> class F_out,
103-
class... Args
218+
template <class> class F
104219
>
105-
constexpr S<F_out> eval_at(size_t i, const Args& ...args) { return {(args[i])...}; }
220+
struct iterator {
221+
//using iterator_category = std::random_access_iterator_tag;
222+
using difference_type = ptrdiff_t;
223+
using value_type = S<value>;
224+
using pointer = S<F>;
225+
using reference = S<reference>;
226+
227+
difference_type index;
228+
pointer handle;
229+
230+
constexpr bool operator==(iterator const& other) const { return index == other.index; }
231+
constexpr bool operator!=(iterator const& other) const { return index != other.index; }
232+
constexpr bool operator<(iterator const& other) const { return index < other.index; }
233+
234+
constexpr iterator operator+(difference_type i) const { return {index + i, handle}; }
235+
constexpr iterator operator-(difference_type i) const { return {index - i, handle}; }
236+
237+
constexpr difference_type operator-(iterator const& other) const {
238+
return difference_type(index) - difference_type(other.index);
239+
}
240+
241+
constexpr iterator& operator++() { ++index; return *this; }
242+
constexpr iterator& operator--() { --index; return *this; }
243+
244+
constexpr reference operator*() { return handle[index]; }
245+
};
106246

107247
} // namespace MemLayout
108248

109-
#define MEMLAYOUT_MEMBERFUNCTIONS(STRUCT, CONTAINER, ...) \
110-
template <template <class> class F_out> \
111-
constexpr operator STRUCT<F_out>() { return { __VA_ARGS__ }; } \
112-
template <template <class> class F_out> \
113-
constexpr operator STRUCT<F_out>() const { return { __VA_ARGS__ }; } \
114-
template<class T = int, class R = MemLayout::disable_if_scalar<CONTAINER<T>>> \
115-
constexpr STRUCT<MemLayout::reference> operator[] (MemLayout::size_t i) { \
116-
return MemLayout::eval_at<STRUCT, MemLayout::reference>(i, __VA_ARGS__); \
117-
} \
118-
template<class T = int, class R = MemLayout::disable_if_scalar<CONTAINER<T>>> \
119-
constexpr STRUCT<MemLayout::const_reference> operator[] (MemLayout::size_t i) const { \
120-
return MemLayout::eval_at<STRUCT, MemLayout::const_reference>(i, __VA_ARGS__); \
121-
} \
249+
#define MEMLAYOUT_MEMBERFUNCTIONS(STRUCT, CONTAINER, ...)\
250+
template <class Function>\
251+
constexpr auto apply(Function&& f) { return f(__VA_ARGS__); }\
252+
template <class Function>\
253+
constexpr auto apply(Function&& f) const { return f(__VA_ARGS__); }\
254+
template <template <class> class F_left, template <class> class F_right, class FunctionObject>\
255+
constexpr friend void memberwise(STRUCT<F_left>& left, STRUCT<F_right>& right, FunctionObject&& f);\
122256

123257
#endif // MEMLAYOUT_H

GPU/GPUTracking/SectorTracker/GPUTPCBaseTrackParam.h

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,10 @@ struct Parameters {
4545
* This class is used for transfer between tracker and merger and does not contain the covariance matrice
4646
*/
4747
template <template <class> class F>
48-
struct GPUTPCBaseTrackParamSkeleton {
48+
struct GPUTPCBaseTrackParamSkeleton : public MemLayout::CRTP<GPUTPCBaseTrackParamSkeleton, F>
49+
{
50+
using Base = MemLayout::CRTP<GPUTPCBaseTrackParamSkeleton, F>;
51+
using Base::operator=;
4952
MEMLAYOUT_MEMBERFUNCTIONS(GPUTPCBaseTrackParamSkeleton, F, mX, mC, mZOffset, mP)
5053

5154
GPUd() float X() const { return mX; }
@@ -90,14 +93,14 @@ struct GPUTPCBaseTrackParamSkeleton {
9093
GPUd() void SetZOffset(float v) { mZOffset = v; }
9194

9295
// Needed for iterators and std::sort
93-
template <template <class> class F_in>
96+
/*template <template <class> class F_in>
9497
constexpr GPUTPCBaseTrackParamSkeleton& operator=(GPUTPCBaseTrackParamSkeleton<F_in> other) {
9598
mX = other.mX;
9699
mC = other.mC;
97100
mZOffset = other.mZOffset;
98101
mP = other.mP;
99102
return *this;
100-
}
103+
}*/
101104
//GPUTPCBaseTrackParamSkeleton() = default;
102105
//GPUTPCBaseTrackParamSkeleton(F<float> X, F<Covariance> C, F<float> ZOffset, F<Parameters> P) : mX(X), mC(C), mZOffset(ZOffset), mP(P) { }
103106
/*GPUTPCBaseTrackParamSkeleton(const GPUTPCBaseTrackParamSkeleton& other) = default;
@@ -124,6 +127,18 @@ struct GPUTPCBaseTrackParamSkeleton {
124127
F<Parameters> mP; // 'active' track parameters: Y, Z, SinPhi, DzDs, q/Pt
125128
};
126129

130+
template <
131+
template <class> class F_left,
132+
template <class> class F_right,
133+
class FunctionObject
134+
>
135+
constexpr void memberwise(GPUTPCBaseTrackParamSkeleton<F_left>& left, GPUTPCBaseTrackParamSkeleton<F_right>& right, FunctionObject&& f) {
136+
f(left.mX, right.mX);
137+
f(left.mC, right.mC);
138+
f(left.mZOffset, right.mZOffset);
139+
f(left.mP, right.mP);
140+
}
141+
127142
using GPUTPCBaseTrackParam = GPUTPCBaseTrackParamSkeleton<MemLayout::value>;
128143

129144
} // namespace o2::gpu

GPU/GPUTracking/SectorTracker/GPUTPCTrack.h

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,13 @@ namespace o2::gpu
2929
* The track parameters at both ends are stored separately in the GPUTPCEndPoint class
3030
*/
3131
template <template <class> class F>
32-
class GPUTPCTrackSkeleton
32+
class GPUTPCTrackSkeleton : public MemLayout::CRTP<GPUTPCTrackSkeleton, F>
3333
{
3434
public:
35+
using Base = MemLayout::CRTP<GPUTPCTrackSkeleton, F>;
36+
using Base::operator=;
3537
MEMLAYOUT_MEMBERFUNCTIONS(GPUTPCTrackSkeleton, F, mFirstHitID, mNHits, mLocalTrackId, mParam)
38+
3639
#if !defined(GPUCA_GPUCODE)
3740
constexpr GPUTPCTrackSkeleton() : mFirstHitID(0), mNHits(0), mLocalTrackId(-1), mParam() { }
3841
constexpr GPUTPCTrackSkeleton(
@@ -61,6 +64,18 @@ class GPUTPCTrackSkeleton
6164
GPUTPCBaseTrackParamSkeleton<F> mParam; // track parameters
6265
};
6366

67+
template <
68+
template <class> class F_left,
69+
template <class> class F_right,
70+
class FunctionObject
71+
>
72+
constexpr void memberwise(GPUTPCTrackSkeleton<F_left>& left, GPUTPCTrackSkeleton<F_right>& right, FunctionObject&& f) {
73+
f(left.mFirstHitID, right.mFirstHitID);
74+
f(left.mNHits, right.mNHits);
75+
f(left.mLocalTrackId, right.mLocalTrackId);
76+
f(left.mParam, right.mParam);
77+
}
78+
6479
using GPUTPCTrack = GPUTPCTrackSkeleton<MemLayout::value>;
6580

6681
} // namespace o2::gpu

0 commit comments

Comments
 (0)