Skip to content

Commit f353ed4

Browse files
Recursive operator[]
1 parent ea5379b commit f353ed4

22 files changed

Lines changed: 265 additions & 465 deletions

GPU/Common/CMakeLists.txt

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,7 @@ set(HDRS_INSTALL
2727
GPUROOTSMatrixFwd.h
2828
GPUROOTCartesianFwd.h
2929
GPUDebugStreamer.h
30-
helper.h
31-
wrapper.h)
30+
MemLayout.h)
3231

3332
if(ALIGPU_BUILD_TYPE STREQUAL "O2")
3433
o2_add_library(${MODULE}

GPU/Common/MemLayout.h

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
#ifndef MEMLAYOUT_H
2+
#define MEMLAYOUT_H
3+
4+
#include "GPUCommonDefAPI.h"
5+
6+
namespace MemLayout {
7+
8+
using size_t = decltype(sizeof 0);
9+
10+
template <class T> using value = T;
11+
12+
template <class T> using reference = T&;
13+
template <class T> using reference_restrict = T& GPUrestrict();
14+
15+
template <class T> using const_reference = const T&;
16+
template <class T> using const_reference_restrict = const T& GPUrestrict();
17+
18+
template <class T> using pointer = T*;
19+
template <class T> using pointer_restrtict = T* GPUrestrict();
20+
21+
template <class T> using const_pointer = const T*;
22+
template <class T> using const_pointer_restrict = const T* GPUrestrict();
23+
24+
enum Flag { soa, aos };
25+
26+
// The types S<value>, S<reference>, and S<const_reference> need to be aggregate constructible
27+
template <template <template <class> class> class S, template <class> class F, Flag L>
28+
struct wrapper;
29+
30+
template <template <template <class> class> class S, template <class> class F>
31+
struct wrapper<S, F, Flag::aos> { using type = F<S<value>>; };
32+
33+
template <template <template <class> class> class S, template <class> class F>
34+
struct wrapper<S, F, Flag::soa> { using type = S<F>; };
35+
36+
namespace type_traits {
37+
38+
template<bool B, class T = void>
39+
struct enable_if {};
40+
41+
template<class T>
42+
struct enable_if<true, T> { typedef T type; };
43+
44+
template< bool B, class T = void >
45+
using enable_if_t = typename enable_if<B, T>::type;
46+
47+
struct false_type {
48+
static constexpr bool value = false;
49+
constexpr operator bool() const noexcept { return value; }
50+
};
51+
52+
struct true_type {
53+
static constexpr bool value = true;
54+
constexpr operator bool() const noexcept { return value; }
55+
};
56+
57+
template <class T>
58+
struct always_false : false_type {};
59+
60+
template<class T, class U>
61+
struct is_same : false_type {};
62+
63+
template<class T>
64+
struct is_same<T, T> : true_type {};
65+
66+
} // namespace type_traits
67+
68+
template<class T_left, class T_right>
69+
using enable_if_equal = type_traits::enable_if_t<type_traits::is_same<T_left, T_right>::value>;
70+
71+
template<class T_left, class T_right>
72+
using disable_if_equal = type_traits::enable_if_t<!type_traits::is_same<T_left, T_right>::value>;
73+
74+
template<class T>
75+
using disable_if_scalar = type_traits::enable_if_t<
76+
!type_traits::is_same<T, value<int>>::value &&
77+
!type_traits::is_same<T, reference<int>>::value &&
78+
!type_traits::is_same<T, reference_restrict<int>>::value &&
79+
!type_traits::is_same<T, const_reference_restrict<int>>::value
80+
>;
81+
82+
#if __cplusplus >= 202002L
83+
template<template <class> class F_left, template <class> class F_right>
84+
concept is_same = type_traits::is_same<F_left<int>, F_right<int>>::value;
85+
template<template <class> class F>
86+
concept is_value = is_same<F, value>;
87+
template<template <class> class F>
88+
concept is_reference = is_same<F, reference>;
89+
template<template <class> class F>
90+
concept is_const_reference = is_same<F, const_reference>;
91+
#endif
92+
93+
template <
94+
template <template <class> class> class S,
95+
template <class> class F_out,
96+
class... Args
97+
>
98+
constexpr S<F_out> eval_at(size_t i, Args& ...args) { return {(args[i])...}; }
99+
100+
template <
101+
template <template <class> class> class S,
102+
template <class> class F_out,
103+
class... Args
104+
>
105+
constexpr S<F_out> eval_at(size_t i, const Args& ...args) { return {(args[i])...}; }
106+
107+
} // namespace MemLayout
108+
109+
#define MEMLAYOUT_MEMBERFUNCTIONS(STRUCT, CONTAINER, ...) \
110+
template <template <class> class F_out> \
111+
constexpr operator STRUCT<F_out>() { return { __VA_ARGS__ }; } \
112+
template <template <class> class F_out> \
113+
constexpr operator STRUCT<F_out>() const { return { __VA_ARGS__ }; } \
114+
template<class T = int, class R = MemLayout::disable_if_scalar<CONTAINER<T>>> \
115+
constexpr STRUCT<MemLayout::reference> operator[] (MemLayout::size_t i) { \
116+
return MemLayout::eval_at<STRUCT, MemLayout::reference>(i, __VA_ARGS__); \
117+
} \
118+
template<class T = int, class R = MemLayout::disable_if_scalar<CONTAINER<T>>> \
119+
constexpr STRUCT<MemLayout::const_reference> operator[] (MemLayout::size_t i) const { \
120+
return MemLayout::eval_at<STRUCT, MemLayout::const_reference>(i, __VA_ARGS__); \
121+
} \
122+
123+
#endif // MEMLAYOUT_H

GPU/Common/helper.h

Lines changed: 0 additions & 132 deletions
This file was deleted.

GPU/Common/wrapper.h

Lines changed: 0 additions & 72 deletions
This file was deleted.

GPU/GPUTracking/DataTypes/GPUDataTypes.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
#include <cstddef>
2525
#endif
2626
#include "GPUTRDDef.h"
27-
#include "GPUTPCDef.h"
27+
#include "MemLayout.h"
2828

2929
struct AliHLTTPCClusterMCLabel;
3030
struct AliHLTTPCRawCluster;
@@ -226,7 +226,7 @@ struct GPUTrackingInOutPointers {
226226
const AliHLTTPCRawCluster* rawClusters[NSECTORS] = {nullptr};
227227
uint32_t nRawClusters[NSECTORS] = {0};
228228
const o2::tpc::ClusterNativeAccess* clustersNative = nullptr;
229-
const GPUTPCTrackSkeleton<wrapper::value>* sectorTracks[NSECTORS] = {nullptr}; // GPUTPCTrack
229+
const GPUTPCTrackSkeleton<MemLayout::value>* sectorTracks[NSECTORS] = {nullptr};
230230
uint32_t nSectorTracks[NSECTORS] = {0};
231231
const GPUTPCHitId* sectorClusters[NSECTORS] = {nullptr};
232232
uint32_t nSectorClusters[NSECTORS] = {0};

GPU/GPUTracking/Global/GPUChainTracking.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
#include "GPUChain.h"
1919
#include "GPUDataTypes.h"
20+
#include "MemLayout.h"
2021
#include <atomic>
2122
#include <mutex>
2223
#include <functional>
@@ -107,7 +108,7 @@ class GPUChainTracking : public GPUChain
107108
std::unique_ptr<AliHLTTPCRawCluster[]> rawClusters[NSECTORS];
108109
std::unique_ptr<o2::tpc::ClusterNative[]> clustersNative;
109110
std::unique_ptr<o2::tpc::ClusterNativeAccess> clusterNativeAccess;
110-
std::unique_ptr<GPUTPCTrackSkeleton<wrapper::value>[]> sectorTracks[NSECTORS];
111+
std::unique_ptr<GPUTPCTrackSkeleton<MemLayout::value>[]> sectorTracks[NSECTORS];
111112
std::unique_ptr<GPUTPCHitId[]> sectorClusters[NSECTORS];
112113
std::unique_ptr<AliHLTTPCClusterMCLabel[]> mcLabelsTPC;
113114
std::unique_ptr<GPUTPCMCInfo[]> mcInfosTPC;

GPU/GPUTracking/Merger/GPUTPCGMSectorTrack.cxx

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,15 @@
1818
#include "GPUO2DataTypes.h"
1919
#include "GPUTPCGMMerger.h"
2020
#include "GPUTPCConvertImpl.h"
21+
#include "MemLayout.h"
2122
#include "GPUParam.inc"
2223

2324
using namespace o2::gpu;
2425
using namespace o2::tpc;
2526

2627
GPUd() void GPUTPCGMSectorTrack::Set(const GPUTPCGMMerger* merger, const GPUTPCTrack* sectorTr, float alpha, int32_t sector)
2728
{
28-
GPUTPCBaseTrackParamSkeleton<wrapper::const_reference> t = sectorTr->Param();
29+
GPUTPCBaseTrackParamSkeleton<MemLayout::const_reference> t = sectorTr->Param();
2930
mOrigTrack = sectorTr;
3031
mParam.mX = t.GetX();
3132
mParam.mY = t.GetY();

0 commit comments

Comments
 (0)