Skip to content

Commit 85944a0

Browse files
Add wrapper.h and helper.h
1 parent ae80673 commit 85944a0

4 files changed

Lines changed: 211 additions & 6 deletions

File tree

GPU/Common/CMakeLists.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,9 @@ set(HDRS_INSTALL
2626
GPUCommonTransform3D.h
2727
GPUROOTSMatrixFwd.h
2828
GPUROOTCartesianFwd.h
29-
GPUDebugStreamer.h)
29+
GPUDebugStreamer.h
30+
helper.h
31+
wrapper.h)
3032

3133
if(ALIGPU_BUILD_TYPE STREQUAL "O2")
3234
o2_add_library(${MODULE}

GPU/Common/helper.h

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
#ifndef HELPER_H
2+
#define HELPER_H
3+
4+
namespace helper {
5+
6+
using size_t = decltype(sizeof 0);
7+
8+
namespace detail {
9+
10+
struct UniversalType {
11+
template<class T>
12+
operator T() const;
13+
};
14+
15+
template <typename T, typename Is, typename=void>
16+
struct is_aggregate_constructible_from_n_impl : std::false_type {};
17+
18+
template <typename T, size_t...Is>
19+
struct is_aggregate_constructible_from_n_impl<T, std::index_sequence<Is...>, std::void_t<decltype(T{(void(Is), UniversalType{})...})>> : std::true_type {};
20+
21+
template <typename T, size_t N>
22+
using is_aggregate_constructible_from_n_helper = is_aggregate_constructible_from_n_impl<T, std::make_index_sequence<N>>;
23+
24+
template <typename T, size_t N>
25+
struct is_aggregate_constructible_from_n {
26+
constexpr static bool value = is_aggregate_constructible_from_n_helper<T, N>::value && !is_aggregate_constructible_from_n_helper<T, N+1>::value;
27+
};
28+
29+
template <typename T>
30+
constexpr bool false_type = false;
31+
32+
} // namespace detail
33+
34+
template <class Argument>
35+
constexpr size_t CountMembers() {
36+
if constexpr (detail::is_aggregate_constructible_from_n<Argument, 0>::value) return 0;
37+
else if constexpr (detail::is_aggregate_constructible_from_n<Argument, 1>::value) return 1;
38+
else if constexpr (detail::is_aggregate_constructible_from_n<Argument, 2>::value) return 2;
39+
else if constexpr (detail::is_aggregate_constructible_from_n<Argument, 3>::value) return 3;
40+
else if constexpr (detail::is_aggregate_constructible_from_n<Argument, 4>::value) return 4;
41+
else if constexpr (detail::is_aggregate_constructible_from_n<Argument, 5>::value) return 5;
42+
else if constexpr (detail::is_aggregate_constructible_from_n<Argument, 6>::value) return 6;
43+
else if constexpr (detail::is_aggregate_constructible_from_n<Argument, 7>::value) return 7;
44+
else if constexpr (detail::is_aggregate_constructible_from_n<Argument, 8>::value) return 8;
45+
else if constexpr (detail::is_aggregate_constructible_from_n<Argument, 9>::value) return 9;
46+
else if constexpr (detail::is_aggregate_constructible_from_n<Argument, 10>::value) return 10;
47+
else {
48+
static_assert(detail::false_type<Argument>, "Unsupported number of members.");
49+
return 100; // Silence warnings about missing return value
50+
}
51+
}
52+
53+
template <
54+
class Argument,
55+
class FunctionObject
56+
>
57+
constexpr auto invoke(Argument & arg, FunctionObject&& f) {
58+
constexpr size_t M = helper::CountMembers<Argument>();
59+
if constexpr (M == 0) {
60+
return f();
61+
} else if constexpr (M == 1) {
62+
auto& [m00] = arg;
63+
return f(m00);
64+
} else if constexpr (M == 2) {
65+
auto& [m00, m01] = arg;
66+
return f(m00, m01);
67+
} else if constexpr (M == 3) {
68+
auto& [m00, m01, m02] = arg;
69+
return f(m00, m01, m02);
70+
} else if constexpr (M == 4) {
71+
auto& [m00, m01, m02, m03] = arg;
72+
return f(m00, m01, m02, m03);
73+
} else if constexpr (M == 5) {
74+
auto& [m00, m01, m02, m03, m04] = arg;
75+
return f(m00, m01, m02, m03, m04);
76+
} else if constexpr (M == 6) {
77+
auto& [m00, m01, m02, m03, m04, m05] = arg;
78+
return f(m00, m01, m02, m03, m04, m05);
79+
} else if constexpr (M == 7) {
80+
auto& [m00, m01, m02, m03, m04, m05, m06] = arg;
81+
return f(m00, m01, m02, m03, m04, m05, m06);
82+
} else if constexpr (M == 8) {
83+
auto& [m00, m01, m02, m03, m04, m05, m06, m07] = arg;
84+
return f(m00, m01, m02, m03, m04, m05, m06, m07);
85+
} else if constexpr (M == 9) {
86+
auto& [m00, m01, m02, m03, m04, m05, m06, m07, m08] = arg;
87+
return f(m00, m01, m02, m03, m04, m05, m06, m07, m08);
88+
} else if constexpr (M == 10) {
89+
auto& [m00, m01, m02, m03, m04, m05, m06, m07, m08, m09] = arg;
90+
return f(m00, m01, m02, m03, m04, m05, m06, m07, m08, m09);
91+
} else {
92+
static_assert(detail::false_type<Argument>, "Unsupported number of members.");
93+
return void(); // Silence warnings about missing return value
94+
}
95+
}
96+
97+
template <
98+
template <class> class F_out,
99+
template <class> class F_in,
100+
template <template <class> class> class S,
101+
class FunctionObject
102+
>
103+
struct memberwise {
104+
FunctionObject f;
105+
106+
template <class... Args> // HACK: NVCC cannot deduce template parameters of f.operator() like so: { f(args)... }
107+
constexpr S<F_out> operator()(Args&... args) const { return {f.template operator()<F_in>(args)...}; }
108+
};
109+
110+
template <
111+
template <class> class F_out,
112+
template <class> class F_in,
113+
template <template <class> class> class S,
114+
class FunctionObject
115+
>
116+
constexpr S<F_out> invoke_on_members(S<F_in> & s, FunctionObject&& f) {
117+
return invoke(s, memberwise<F_out, F_in, S, FunctionObject>{f});
118+
}
119+
120+
template <
121+
template <class> class F_out,
122+
template <class> class F_in,
123+
template <template <class> class> class S,
124+
class FunctionObject
125+
>
126+
constexpr S<F_out> invoke_on_members(const S<F_in> & s, FunctionObject&& f) {
127+
return invoke(s, memberwise<F_out, F_in, S, FunctionObject>{f});
128+
}
129+
130+
} // namespace helper
131+
132+
#endif // HELPER_H

GPU/Common/wrapper.h

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
#ifndef WRAPPER_H
2+
#define WRAPPER_H
3+
4+
#include "helper.h"
5+
6+
namespace wrapper {
7+
8+
using size_t = decltype(sizeof 0);
9+
10+
enum class layout { aos = 0, soa = 1 };
11+
12+
template <class T>
13+
using value = T;
14+
15+
template <class T>
16+
using reference = T&;
17+
18+
template <class T>
19+
using const_reference = const T&;
20+
21+
template<
22+
template <template <class> class> class S,
23+
template <class> class F,
24+
layout L
25+
>
26+
struct wrapper;
27+
28+
template <template <template <class> class> class S, template <class> class F>
29+
struct wrapper<S, F, layout::aos> {
30+
F<S<value>> data;
31+
32+
template <template <class> class F_out>
33+
constexpr operator wrapper<S, F_out, layout::aos>() { return {data}; };
34+
35+
constexpr S<reference> operator[](size_t i) { return data[i]; }
36+
constexpr S<const_reference> operator[](size_t i) const { return data[i]; }
37+
};
38+
39+
template <template <template <class> class> class S, template <class> class F>
40+
struct wrapper<S, F, layout::soa> : S<F> {
41+
template <template <class> class F_out>
42+
constexpr operator wrapper<S, F_out, layout::soa>() { return {*this}; };
43+
44+
constexpr S<reference> operator[](size_t i) {
45+
return helper::invoke_on_members<reference, F>(*this, evaluate_at{i});
46+
}
47+
constexpr S<const_reference> operator[](size_t i) const {
48+
return helper::invoke_on_members<const_reference, F>(*this, evaluate_at{i});
49+
}
50+
51+
private:
52+
53+
struct evaluate_at {
54+
size_t i;
55+
56+
template <template <class> class F_in, class T>
57+
constexpr reference<T> operator()(F_in<T> & t) const { return t[i]; }
58+
59+
template <template <class> class F_in, class T>
60+
constexpr const_reference<T> operator()(const F_in<T> & t) const { return t[i]; }
61+
};
62+
};
63+
64+
} // namespace wrapper
65+
66+
#endif // WRAPPER_H

GPU/GPUTracking/SectorTracker/GPUTPCBaseTrackParam.h

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#define GPUTPCBASETRACKPARAM_H
1717

1818
#include "GPUTPCDef.h"
19+
#include "wrapper.h"
1920

2021
namespace o2::gpu
2122
{
@@ -28,8 +29,9 @@ class GPUTPCTrackParam;
2829
* used in output of the GPUTPCTracker sector tracker.
2930
* This class is used for transfer between tracker and merger and does not contain the covariance matrice
3031
*/
31-
struct GPUTPCBaseTrackParam {
32-
GPUd() float X() const { return mX; }
32+
template <template <class> class F>
33+
struct GPUTPCBaseTrackParamSkeleton {
34+
GPUd() F<float> X() const { return mX; }
3335
GPUd() float Y() const { return mP[0]; }
3436
GPUd() float Z() const { return mP[1]; }
3537
GPUd() float SinPhi() const { return mP[2]; }
@@ -46,7 +48,7 @@ struct GPUTPCBaseTrackParam {
4648
GPUd() float GetCov(int32_t i) const { return mC[i]; }
4749
GPUhd() void SetCov(int32_t i, float v) { mC[i] = v; }
4850

49-
GPUhd() float GetX() const { return mX; }
51+
GPUhd() F<float> GetX() const { return mX; }
5052
GPUhd() float GetY() const { return mP[0]; }
5153
GPUhd() float GetZ() const { return mP[1]; }
5254
GPUhd() float GetSinPhi() const { return mP[2]; }
@@ -62,7 +64,7 @@ struct GPUTPCBaseTrackParam {
6264

6365
GPUhd() void SetPar(int32_t i, float v) { mP[i] = v; }
6466

65-
GPUd() void SetX(float v) { mX = v; }
67+
GPUd() void SetX(F<float> v) { mX = v; }
6668
GPUd() void SetY(float v) { mP[0] = v; }
6769
GPUd() void SetZ(float v) { mP[1] = v; }
6870
GPUd() void SetSinPhi(float v) { mP[2] = v; }
@@ -73,11 +75,14 @@ struct GPUTPCBaseTrackParam {
7375
// WARNING, Track Param Data is copied in the GPU Tracklet Constructor element by element instead of using copy constructor!!!
7476
// This is neccessary for performance reasons!!!
7577
// Changes to Elements of this class therefore must also be applied to TrackletConstructor!!!
76-
float mX; // x position
78+
F<float> mX; // x position
7779
float mC[15]; // the covariance matrix for Y,Z,SinPhi,..
7880
float mZOffset; // z offset
7981
float mP[5]; // 'active' track parameters: Y, Z, SinPhi, DzDs, q/Pt
8082
};
83+
84+
using GPUTPCBaseTrackParam = GPUTPCBaseTrackParamSkeleton<wrapper::value>;
85+
8186
} // namespace o2::gpu
8287

8388
#endif

0 commit comments

Comments
 (0)