Skip to content

Commit 966aab4

Browse files
committed
Remove need to static_cast to call apply
1 parent 131a3df commit 966aab4

1 file changed

Lines changed: 96 additions & 41 deletions

File tree

GPU/Common/MemLayout.h

Lines changed: 96 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
#include <meta>
88
#endif
99

10+
#include <type_traits>
11+
1012
namespace MemLayout {
1113

1214
using size_t = decltype(sizeof(0));
@@ -25,7 +27,7 @@ template <class T> using const_pointer_restrict = const T* __restrict__;
2527

2628
//////////////// Reflection utilities
2729

28-
template <class S>
30+
template <typename S>
2931
constexpr std::size_t count_members() {
3032
return nonstatic_data_members_of(^^S, std::meta::access_context::current()).size();
3133
}
@@ -92,24 +94,77 @@ struct CopyAssignment {
9294
constexpr Left& operator()(Left& left, const Right& right) const { return left = right; }
9395
};
9496

95-
template <class FunctionObject, class Self>
96-
constexpr auto apply(Self &self, FunctionObject&& f) {
97+
//////////////// apply to members methods
98+
99+
template <class Self, class FunctionObject>
100+
constexpr auto apply_unary(Self &self, FunctionObject&& f) {
97101
auto construct_output = [&]<size_t... Is>(std::index_sequence<Is...>) {
98102
return f(self.[:nsdms(^^Self)[Is]:]...);
99103
};
100104
constexpr auto indices = std::make_index_sequence<count_members<Self>()>{};
101105
return construct_output(indices);
102106
}
103107

104-
template <class FunctionObject, class Self, class Other>
105-
constexpr auto apply(Self &self, Other &other, FunctionObject&& f) {
108+
// apply on skeleton struct S<F>
109+
template <class FunctionObject, template <template <class> class> class S, template <class> class F>
110+
constexpr auto apply(S<F> &self, FunctionObject&& f) {
111+
return apply_unary(self, std::forward<FunctionObject&&>(f));
112+
}
113+
114+
template <class FunctionObject, template <template <class> class> class S, template <class> class F>
115+
constexpr auto apply(const S<F> &self, FunctionObject&& f) {
116+
return apply_unary(self, std::forward<FunctionObject&&>(f));
117+
}
118+
119+
// apply on wrappers, forwarding to the base type
120+
template <class FunctionObject, class Self>
121+
requires requires { typename Self::Base; }
122+
constexpr auto apply(Self &self, FunctionObject&& f) {
123+
return apply_unary<typename Self::Base>(self, std::forward<FunctionObject&&>(f));
124+
}
125+
126+
template <class FunctionObject, class Self>
127+
requires requires { typename Self::Base; }
128+
constexpr auto apply(const Self &self, FunctionObject&& f) {
129+
return apply_unary<const typename Self::Base>(self, std::forward<FunctionObject&&>(f));
130+
}
131+
132+
133+
// template <class FunctionObject, class Self, class Other>
134+
// constexpr auto apply(Self &self, Other &other, FunctionObject&& f) {
135+
template <class Self, class Other, class FunctionObject>
136+
constexpr auto apply_binary(Self &self, Other &other, FunctionObject&& f) {
106137
auto construct_output = [&]<size_t... Is>(std::index_sequence<Is...>) -> Self {
107138
return {f(
108139
self.[:nsdms(^^Self)[Is]:], other.[:nsdms(^^Other)[Is]:])...};
109140
};
110141
constexpr auto indices = std::make_index_sequence<count_members<Self>()>{};
111142
return construct_output(indices);
112143
}
144+
145+
template <class FunctionObject, template <template <class> class> class S, template <class> class F_self, template <class> class F_other>
146+
constexpr auto apply(S<F_self> &self, S<F_other> &other, FunctionObject&& f) {
147+
return apply_binary(self, other, std::forward<FunctionObject&&>(f));
148+
}
149+
150+
template <class FunctionObject, template <template <class> class> class S, template <class> class F_self, template <class> class F_other>
151+
constexpr auto apply(S<F_self> &self, const S<F_other> &other, FunctionObject&& f) {
152+
return apply_binary(self, other, std::forward<FunctionObject&&>(f));
153+
}
154+
155+
template <class Self, class Other, class FunctionObject>
156+
requires requires { typename Self::Base; typename Other::Base; }
157+
constexpr auto apply(Self &self, Other &other, FunctionObject&& f) {
158+
return apply_binary<typename Self::Base, typename Other::Base>(self, other, std::forward<FunctionObject&&>(f));
159+
}
160+
161+
template <class Self, class Other, class FunctionObject>
162+
requires requires { typename Self::Base; typename Other::Base; }
163+
constexpr auto apply(Self &self, const Other &other, FunctionObject&& f) {
164+
static_assert(count_members<typename Self::Base>() == 4);
165+
return apply_binary<typename Self::Base, const typename Other::Base>(self, other, std::forward<FunctionObject&&>(f));
166+
}
167+
113168
//////////////// wrapper
114169

115170
template <
@@ -127,9 +182,9 @@ struct wrapper : public S<F> {
127182
constexpr wrapper(const S<F_other>& other) : Base{apply(other, AggregateConstructor<Base>{})} {}
128183

129184
constexpr wrapper<S, reference> operator[] (size_t i) {
130-
return apply(static_cast<Base &>(*this), RandomAccessAt<S<reference>>{i}); }
185+
return apply(*this, RandomAccessAt<S<reference>>{i}); }
131186
constexpr wrapper<S, const_reference> operator[] (size_t i) const {
132-
return apply(static_cast<const Base &>(*this), RandomAccessAt<S<const_reference>>{i}); }
187+
return apply(*this, RandomAccessAt<S<const_reference>>{i}); }
133188

134189
constexpr wrapper<S, reference> operator*() { return operator[](0); }
135190
constexpr wrapper<S, const_reference> operator*(ptrdiff_t) const { return operator[](0); }
@@ -160,32 +215,32 @@ struct wrapper<S, reference> : public S<reference> {
160215
constexpr wrapper(const wrapper& other) = default;
161216

162217
constexpr wrapper& operator=(const wrapper<S, value>& other) {
163-
apply(static_cast<Base &>(*this), static_cast<const S<value>&>(other), CopyAssignment{});
218+
apply(*this, other, CopyAssignment{});
164219
return *this;
165220
}
166221
constexpr wrapper& operator=(const wrapper& other) {
167-
apply(static_cast<Base &>(*this), static_cast<const Base&>(other), CopyAssignment{});
222+
apply(*this, other, CopyAssignment{});
168223
return *this;
169224
}
170225
constexpr wrapper& operator=(const wrapper<S, const_reference>& other) {
171-
apply(static_cast<Base &>(*this), static_cast<const S<const_reference>&>(other), CopyAssignment{});
226+
apply(*this, other, CopyAssignment{});
172227
return *this;
173228
}
174229
constexpr wrapper& operator=(const wrapper<S, reference_restrict>& other) {
175-
apply(static_cast<Base &>(*this), static_cast<const S<reference_restrict>&>(other), CopyAssignment{});
230+
apply(*this, other, CopyAssignment{});
176231
return *this;
177232
}
178233
constexpr wrapper& operator=(const wrapper<S, const_reference_restrict>& other) {
179-
apply(static_cast<Base &>(*this), static_cast<const S<const_reference_restrict>&>(other), CopyAssignment{});
234+
apply(*this, other, CopyAssignment{});
180235
return *this;
181236
}
182237

183238
constexpr wrapper(wrapper&& other) = default;
184239

185240
constexpr wrapper& operator=(wrapper&& other) { return operator=(other); }
186241

187-
constexpr wrapper<S, pointer> operator&() { return apply(static_cast<Base&>(*this), GetPointer<S<pointer>>{}); }
188-
//constexpr wrapper<S, const_pointer> operator&() const { return apply(static_cast<const Base&>(*this), GetPointer<S<const_pointer>>{}); }
242+
constexpr wrapper<S, pointer> operator&() { return apply(*this, GetPointer<S<pointer>>{}); }
243+
//constexpr wrapper<S, const_pointer> operator&() const { return apply(*this, GetPointer<S<const_pointer>>{}); }
189244
constexpr pointer<wrapper<S, reference>> operator->() { return this; }
190245
};
191246

@@ -201,32 +256,32 @@ struct wrapper<S, reference_restrict> : public S<reference_restrict> {
201256
constexpr wrapper(const wrapper& other) = default;
202257

203258
constexpr wrapper& operator=(const wrapper<S, value>& other) {
204-
apply(static_cast<Base &>(*this), static_cast<const S<value>&>(other), CopyAssignment{});
259+
apply(*this, other, CopyAssignment{});
205260
return *this;
206261
}
207262
constexpr wrapper& operator=(const wrapper& other) {
208-
apply(static_cast<Base &>(*this), static_cast<const Base&>(other), CopyAssignment{});
263+
apply(*this, other, CopyAssignment{});
209264
return *this;
210265
}
211266
constexpr wrapper& operator=(const wrapper<S, reference>& other) {
212-
apply(static_cast<Base &>(*this), static_cast<const S<reference>&>(other), CopyAssignment{});
267+
apply(*this, other, CopyAssignment{});
213268
return *this;
214269
}
215270
constexpr wrapper& operator=(const wrapper<S, const_reference>& other) {
216-
apply(static_cast<Base &>(*this), static_cast<const S<const_reference>&>(other), CopyAssignment{});
271+
apply(*this, other, CopyAssignment{});
217272
return *this;
218273
}
219274
constexpr wrapper& operator=(const wrapper<S, const_reference_restrict>& other) {
220-
apply(static_cast<Base &>(*this), static_cast<const S<const_reference_restrict>&>(other), CopyAssignment{});
275+
apply(*this, other, CopyAssignment{});
221276
return *this;
222277
}
223278

224279
constexpr wrapper(wrapper&& other) = default;
225280

226281
constexpr wrapper& operator=(wrapper&& other) { return operator=(other); }
227282

228-
constexpr wrapper<S, pointer> operator&() { return apply(static_cast<Base&>(*this), GetPointer<S<pointer>>{}); }
229-
//constexpr wrapper<S, const_pointer> operator&() const { return apply(static_cast<const Base&>(*this), GetPointer<S<const_pointer>>{}); }
283+
constexpr wrapper<S, pointer> operator&() { return apply(*this, GetPointer<S<pointer>>{}); }
284+
//constexpr wrapper<S, const_pointer> operator&() const { return apply(*this, GetPointer<S<const_pointer>>{}); }
230285
constexpr pointer<wrapper<S, reference>> operator->() { return this; }
231286
};
232287

@@ -241,7 +296,7 @@ struct wrapper<S, const_reference> : public S<const_reference> {
241296
constexpr wrapper(const S<reference_restrict>& other) : Base(apply(other, AggregateConstructor<Base>{})) {}
242297
constexpr wrapper(const S<const_reference_restrict>& other) : Base(apply(other, AggregateConstructor<Base>{})) {}
243298

244-
constexpr wrapper<S, const_pointer> operator&() const { return apply(static_cast<const Base&>(*this), GetPointer<S<const_pointer>>{}); }
299+
constexpr wrapper<S, const_pointer> operator&() const { return apply(*this, GetPointer<S<const_pointer>>{}); }
245300
constexpr const_pointer<wrapper<S, const_reference>> operator->() const { return this; }
246301
};
247302

@@ -256,7 +311,7 @@ struct wrapper<S, const_reference_restrict> : public S<const_reference_restrict>
256311
constexpr wrapper(const S<reference_restrict>& other) : Base(apply(other, AggregateConstructor<Base>{})) {}
257312
constexpr wrapper(const S<const_reference>& other) : Base(apply(other, AggregateConstructor<Base>{})) {}
258313

259-
constexpr wrapper<S, const_pointer> operator&() const { return apply(static_cast<const Base&>(*this), GetPointer<S<const_pointer>>{}); }
314+
constexpr wrapper<S, const_pointer> operator&() const { return apply(*this, GetPointer<S<const_pointer>>{}); }
260315
constexpr const_pointer<wrapper<S, const_reference>> operator->() const { return this; }
261316
};
262317

@@ -268,36 +323,36 @@ struct wrapper<S, pointer> : public S<pointer> {
268323
constexpr wrapper(Base b) : Base{static_cast<Base&&>(b)} {}
269324

270325
constexpr wrapper<S, reference> operator[] (size_t i) {
271-
return apply(static_cast<Base&>(*this), RandomAccessAt<S<reference>>{i}); }
326+
return apply(*this, RandomAccessAt<S<reference>>{i}); }
272327
constexpr const wrapper<S, const_reference> operator[] (size_t i) const {
273-
return apply(static_cast<const Base&>(*this), RandomAccessAt<S<const_reference>>{i}); }
328+
return apply(*this, RandomAccessAt<S<const_reference>>{i}); }
274329

275330
constexpr wrapper<S, reference> operator*() { return operator[](0); }
276331
constexpr wrapper<S, const_reference> operator*() const { return operator[](0); }
277332
constexpr wrapper<S, reference> operator->() { return operator[](0); }
278333
constexpr wrapper<S, const_reference> operator->() const { return operator[](0); }
279334

280335
constexpr bool operator==(const wrapper& other) const {
281-
return apply(static_cast<const Base&>(*this), FirstMember{}) == apply(static_cast<const Base&>(other), FirstMember{}); }
336+
return apply(*this, FirstMember{}) == apply(other, FirstMember{}); }
282337
constexpr bool operator!=(const wrapper& other) const {
283338
return !this->operator==(other); }
284339
constexpr bool operator<(const wrapper& other) const {
285-
return apply(static_cast<const Base&>(*this), FirstMember{}) < apply(static_cast<const Base&>(other), FirstMember{}); }
340+
return apply(*this, FirstMember{}) < apply(other, FirstMember{}); }
286341
constexpr bool operator<=(const wrapper& other) const {
287-
return apply(static_cast<const Base&>(*this), FirstMember{}) <= apply(static_cast<const Base&>(other), FirstMember{}); }
342+
return apply(*this, FirstMember{}) <= apply(other, FirstMember{}); }
288343
constexpr bool operator>(const wrapper& other) const {
289-
return apply(static_cast<const Base&>(*this), FirstMember{}) > apply(static_cast<const Base&>(other), FirstMember{}); }
344+
return apply(*this, FirstMember{}) > apply(other, FirstMember{}); }
290345
constexpr bool operator>=(const wrapper& other) const {
291-
return apply(static_cast<const Base&>(*this), FirstMember{}) >= apply(static_cast<const Base&>(other), FirstMember{}); }
346+
return apply(*this, FirstMember{}) >= apply(other, FirstMember{}); }
292347

293-
constexpr wrapper operator+(ptrdiff_t i) const { return apply(static_cast<const Base&>(*this), Advance<Base>{i}); }
348+
constexpr wrapper operator+(ptrdiff_t i) const { return apply(*this, Advance<Base>{i}); }
294349
constexpr wrapper operator-(ptrdiff_t i) const { return operator+(-i); }
295350
constexpr ptrdiff_t operator-(const wrapper& other) const {
296-
return apply(static_cast<const Base&>(*this), FirstMember{}) - apply(static_cast<const Base&>(other), FirstMember{}); }
351+
return apply(*this, FirstMember{}) - apply(other, FirstMember{}); }
297352

298-
constexpr wrapper& operator++() { apply(static_cast<Base&>(*this), PreIncrement<Base>{}); return *this; }
353+
constexpr wrapper& operator++() { apply(*this, PreIncrement<Base>{}); return *this; }
299354
constexpr wrapper& operator+=(ptrdiff_t i) { return *this = *this + i; }
300-
constexpr wrapper& operator--() { apply(static_cast<Base&>(*this), PreDecrement<Base>{}); return *this; }
355+
constexpr wrapper& operator--() { apply(*this, PreDecrement<Base>{}); return *this; }
301356
constexpr wrapper& operator-=(ptrdiff_t i) { return *this = *this - i; }
302357
};
303358

@@ -310,26 +365,26 @@ struct wrapper<S, const_pointer> : public S<const_pointer> {
310365
constexpr wrapper(const S<pointer>& other) : Base(apply(other, AggregateConstructor<Base>{})) {}
311366

312367
constexpr wrapper<S, const_reference> operator[] (size_t i) const {
313-
return apply(static_cast<const Base&>(*this), RandomAccessAt<S<const_reference>>{i}); }
368+
return apply(*this, RandomAccessAt<S<const_reference>>{i}); }
314369
constexpr wrapper<S, const_reference> operator*() const { return operator[](0); }
315370
constexpr wrapper<S, const_reference> operator->() const { return operator[](0); }
316371

317372
constexpr bool operator==(const wrapper& other) const {
318-
return apply(static_cast<const Base&>(*this), FirstMember{}) == apply(static_cast<const Base&>(other), FirstMember{}); }
373+
return apply(*this, FirstMember{}) == apply(other, FirstMember{}); }
319374
constexpr bool operator!=(const wrapper& other) const {
320375
return !this->operator==(other); }
321376
constexpr bool operator<(const wrapper& other) const {
322-
return apply(static_cast<const Base&>(*this), FirstMember{}) < apply(static_cast<const Base&>(other), FirstMember{}); }
377+
return apply(*this, FirstMember{}) < apply(other, FirstMember{}); }
323378

324379
constexpr wrapper operator+(ptrdiff_t i) const {
325-
return apply(static_cast<const Base&>(*this), Advance<Base>{i}); }
380+
return apply(*this, Advance<Base>{i}); }
326381
constexpr wrapper operator-(ptrdiff_t i) const { return operator+(-i); }
327382
constexpr ptrdiff_t operator-(const wrapper& other) const {
328-
return apply(static_cast<const Base&>(*this), FirstMember{}) - apply(static_cast<const Base&>(other), FirstMember{}); }
383+
return apply(*this, FirstMember{}) - apply(other, FirstMember{}); }
329384

330-
constexpr wrapper& operator++() { apply(static_cast<Base&>(*this), PreIncrement<Base>{}); return *this; }
385+
constexpr wrapper& operator++() { apply(*this, PreIncrement<Base>{}); return *this; }
331386
constexpr wrapper& operator+=(ptrdiff_t i) { return *this = *this + i; }
332-
constexpr wrapper& operator--() { apply(static_cast<Base&>(*this), PreDecrement<Base>{}); return *this; }
387+
constexpr wrapper& operator--() { apply(*this, PreDecrement<Base>{}); return *this; }
333388
constexpr wrapper& operator-=(ptrdiff_t i) { return *this = *this - i; }
334389
};
335390

0 commit comments

Comments
 (0)