-
Notifications
You must be signed in to change notification settings - Fork 18
Expand file tree
/
Copy pathtest_utils.h
More file actions
628 lines (514 loc) · 24.5 KB
/
test_utils.h
File metadata and controls
628 lines (514 loc) · 24.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
#pragma once
#include <chrono>
#include <memory>
#include <optional>
#include <string>
#include <thread>
#include <type_traits>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#ifdef _WIN32
#define WIN32_LEAN_AND_MEAN
#define NOMINMAX
#include <windows.h>
#endif
#include <catch2/benchmark/catch_optimizer.hpp> // for keep_memory()
#include <catch2/catch_test_macros.hpp>
#include <catch2/generators/catch_generators.hpp>
#include <celerity.h>
#include "affinity.h"
#include "async_event.h"
#include "backend/sycl_backend.h"
#include "command_graph.h"
#include "command_graph_generator.h"
#include "named_threads.h"
#include "print_graph.h"
#include "print_utils.h"
#include "print_utils_internal.h"
#include "range_mapper.h"
#include "region_map.h"
#include "scheduler.h"
#include "system_info.h"
#include "task_manager.h"
#include "testspy/runtime_testspy.h"
#include "testspy/scheduler_testspy.h"
#include "types.h"
// To avoid having to come up with tons of unique kernel names, we simply use the __COUNTER__ preprocessor macro.
// This is non-standard but widely supported.
// The concatenation goes through two macro levels so that __COUNTER__ is expanded before token pasting. The helper names deliberately do not
// begin with an underscore followed by an uppercase letter, because such identifiers are reserved for the implementation.
#define CELERITY_DETAIL_UKN_CONCAT2(x, y) x##_##y
#define CELERITY_DETAIL_UKN_CONCAT(x, y) CELERITY_DETAIL_UKN_CONCAT2(x, y)
/// Expands to a fresh identifier of the form `name_N`, where N is the current value of __COUNTER__.
#define UKN(name) CELERITY_DETAIL_UKN_CONCAT(name, __COUNTER__)
/**
 * REQUIRE_LOOP is a utility macro for performing Catch2 REQUIRE assertions inside of loops.
 * The advantage over using a regular REQUIRE is that the number of reported assertions is much lower:
 * only the first evaluation at a given source location (file + line) is actually passed on to Catch2
 * (useful when showing successful assertions with `-s`).
 * On every later evaluation the expression is checked locally, and it is forwarded to Catch2 only if its result is false.
 *
 * NOTE: When a later evaluation yields false, the checked expression is evaluated a second time (inside REQUIRE), so it must be idempotent!
 */
#define REQUIRE_LOOP(...) CELERITY_DETAIL_REQUIRE_LOOP(__VA_ARGS__)
namespace celerity {
namespace detail {
/// Key/value setting that turns on graph printing (CELERITY_PRINT_GRAPHS=1), presumably applied as an environment override by tests.
/// Declared `inline` so every translation unit including this header shares one instance. A plain namespace-scope `const` would give each
/// TU its own internal-linkage copy, constructed separately at static-initialization time.
inline const std::unordered_map<std::string, std::string> print_graphs_env_setting{{"CELERITY_PRINT_GRAPHS", "1"}};
/// Test-only accessor for the private node storage of `graph<Node>`, which keeps its nodes in per-epoch lists (`m_epochs[..].nodes`).
struct graph_testspy {
	/// Counts the nodes currently stored in `dag` for which `p(node)` holds.
	template <GraphNode Node, typename Predicate>
	static size_t count_nodes_if(const graph<Node>& dag, const Predicate& p) {
		size_t num_matches = 0;
		for(const auto& ep : dag.m_epochs) {
			for(const auto& n : ep.nodes) {
				if(!p(*n)) continue;
				++num_matches;
			}
		}
		return num_matches;
	}

	/// Returns the first stored node (visiting epochs, then their nodes, in storage order) for which `p(node)` holds, or nullptr if none does.
	template <GraphNode Node, typename Predicate>
	static const Node* find_node_if(const graph<Node>& dag, const Predicate& p) {
		for(const auto& ep : dag.m_epochs) {
			for(const auto& n : ep.nodes) {
				if(p(*n)) return n.get();
			}
		}
		return nullptr;
	}

	/// Returns the total number of nodes currently stored in `dag`, summed over all epochs.
	template <GraphNode Node>
	static size_t get_live_node_count(const graph<Node>& dag) {
		size_t total = 0;
		for(const auto& ep : dag.m_epochs) total += ep.nodes.size();
		return total;
	}
};
struct task_manager_testspy {
inline static constexpr task_id initial_epoch_task = task_manager::initial_epoch_task;
static const task* get_epoch_for_new_tasks(const task_manager& tm) { return tm.m_epoch_for_new_tasks; }
static const task* get_current_horizon(const task_manager& tm) { return tm.m_current_horizon; }
static const region_map<task*>& get_last_writer(const task_manager& tm, const buffer_id bid) { return tm.m_buffers.at(bid).last_writers; }
static int get_max_pseudo_critical_path_length(const task_manager& tm) { return tm.m_max_pseudo_critical_path_length; }
static const std::unordered_set<task*>& get_execution_front(const task_manager& tm) { return tm.m_execution_front; }
};
/// Test-only accessor for comparing private state of built-in range mappers.
struct range_mapper_testspy {
	/// Two neighborhood mappers are considered equal when both their extent and their shape are equal.
	template <int Dims>
	static bool neighborhood_equals(const celerity::access::neighborhood<Dims>& lhs, const celerity::access::neighborhood<Dims>& rhs) {
		if(!(lhs.m_extent == rhs.m_extent)) return false;
		return lhs.m_shape == rhs.m_shape;
	}
};
} // namespace detail
namespace test_utils {
// Pin the benchmark threads (even in absence of a runtime) for more consistent results
struct benchmark_thread_pinner {
benchmark_thread_pinner() {
const detail::thread_pinning::runtime_configuration cfg{
.enabled = true,
.use_backend_device_submission_threads = false,
};
m_thread_pinner.emplace(cfg);
detail::thread_pinning::pin_this_thread(detail::named_threads::thread_type::application);
}
std::optional<detail::thread_pinning::thread_pinner> m_thread_pinner;
};
/// Returns the task with id `tid` if `tdag` currently stores it, nullptr otherwise.
inline const detail::task* find_task(const detail::task_graph& tdag, const detail::task_id tid) {
	const auto has_requested_id = [tid](const detail::task& candidate) { return candidate.get_id() == tid; };
	return detail::graph_testspy::find_node_if(tdag, has_requested_id);
}

/// Returns whether `tdag` currently stores a task with id `tid`.
inline bool has_task(const detail::task_graph& tdag, const detail::task_id tid) {
	return find_task(tdag, tid) != nullptr;
}

/// Like find_task(), but fails the current test (Catch2 REQUIRE) when the task is not present.
inline const detail::task* get_task(const detail::task_graph& tdag, const detail::task_id tid) {
	const detail::task* const found = find_task(tdag, tid);
	REQUIRE(found != nullptr);
	return found;
}

/// Returns how many horizon tasks `tdag` currently stores.
inline size_t get_num_live_horizons(const detail::task_graph& tdag) {
	const auto is_horizon = [](const detail::task& candidate) { return candidate.get_type() == detail::task_type::horizon; };
	return detail::graph_testspy::count_nodes_if(tdag, is_horizon);
}
/// Returns whether task `dependent` has a dependency of the given `kind` on task `dependency`.
/// Fails the current test (via get_task) if `dependent` is not stored in `tdag`.
inline bool has_dependency(const detail::task_graph& tdag, detail::task_id dependent, detail::task_id dependency,
    detail::dependency_kind kind = detail::dependency_kind::true_dep) {
	const auto* const dependent_task = get_task(tdag, dependent);
	for(const auto& dep : dependent_task->get_dependencies()) {
		const bool is_match = dep.node->get_id() == dependency && dep.kind == kind;
		if(is_match) return true;
	}
	return false;
}

/// Like has_dependency(), but accepts a dependency of any kind.
inline bool has_any_dependency(const detail::task_graph& tdag, detail::task_id dependent, detail::task_id dependency) {
	const auto* const dependent_task = get_task(tdag, dependent);
	for(const auto& dep : dependent_task->get_dependencies()) {
		if(dep.node->get_id() != dependency) continue;
		return true;
	}
	return false;
}
/// Registry backing REQUIRE_LOOP: remembers which source-location keys have already had their assertion reported to Catch2.
class require_loop_assertion_registry {
public:
	/// Returns the single registry instance, creating it on first use.
	static require_loop_assertion_registry& get_instance() {
		if(!instance) instance = std::make_unique<require_loop_assertion_registry>();
		return *instance;
	}

	/// Forgets every key recorded so far.
	void reset() { m_logged_lines.clear(); }

	/// Records `line_info`; returns true exactly when the key had not been recorded before.
	bool should_log(std::string line_info) {
		const bool inserted = m_logged_lines.insert(std::move(line_info)).second;
		return inserted;
	}

private:
	inline static std::unique_ptr<require_loop_assertion_registry> instance;
	std::unordered_set<std::string> m_logged_lines{};
};
// Implementation of REQUIRE_LOOP.
// - The first evaluation at a source location is always forwarded to Catch2 via REQUIRE.
// - Later evaluations at the same location check the expression locally and only call REQUIRE (evaluating it again) when it is false.
// Locations are keyed as "<file>:<line>"; the ':' separator keeps distinct (file, line) pairs from ever producing the same key.
// The body is wrapped in `do { ... } while(false)` so that the macro expands to exactly one statement. It therefore composes safely with
// unbraced if/else (no dangling-else) and must be followed by a semicolon.
#define CELERITY_DETAIL_REQUIRE_LOOP(...) \
	do { \
		if(celerity::test_utils::require_loop_assertion_registry::get_instance().should_log( \
		       std::string(__FILE__) + ":" + std::to_string(__LINE__))) { \
			REQUIRE(__VA_ARGS__); \
		} else if(!(__VA_ARGS__)) { \
			REQUIRE(__VA_ARGS__); \
		} \
	} while(false)
/// By default, tests fail if their log contains a warning, error or critical message. This function allows tests to pass when such higher-severity
/// messages (up to `level`) are expected. The setting is reset at the beginning of each test-case run (even when the test case is re-entered due to a
/// generator or section).
void allow_max_log_level(detail::log_level level);
/// Like allow_max_log_level(), but only applies to messages that match the regex `text_regex`. This is used in test fixtures to allow common
/// system-dependent messages.
void allow_higher_level_log_messages(detail::log_level level, const std::string& text_regex);
/// Returns whether the log of the current test so far contains a message with exactly the given log level whose text exactly equals `text`.
/// Time stamps and the log level are not part of the text, but any active log_context is.
bool log_contains_exact(detail::log_level level, const std::string& text);
/// Returns whether the log of the current test so far contains a message with exactly the given log level whose text contains `substring`.
bool log_contains_substring(detail::log_level level, const std::string& substring);
/// Returns whether the log of the current test so far contains a message with exactly the given log level whose text matches the regex `regex`.
bool log_matches(const detail::log_level level, const std::string& regex);
/// Calls `f(offset + idx)` for every index `idx` within `range`, iterating the (3D-promoted) dimensions in nested order
/// with the last dimension varying fastest.
template <int Dims, typename F>
void for_each_in_range(range<Dims> range, id<Dims> offset, F&& f) {
	const auto extent = detail::range_cast<3>(range);
	for(size_t i0 = 0; i0 < extent[0]; ++i0) {
		for(size_t i1 = 0; i1 < extent[1]; ++i1) {
			for(size_t i2 = 0; i2 < extent[2]; ++i2) {
				f(offset + detail::id_cast<Dims>(id<3>(i0, i1, i2)));
			}
		}
	}
}

/// Same as above, starting at the zero offset.
template <int Dims, typename F>
void for_each_in_range(range<Dims> range, F&& f) {
	for_each_in_range(range, id<Dims>{}, f);
}
/// Polls `pred` until it returns true, failing the current test (Catch2 FAIL) once more than `timeout` has elapsed.
/// The thread yields between polls instead of spinning at full speed. Without the yield, a pinned test thread could monopolize the core that
/// the threads it is waiting on need.
template <typename Predicate>
void wait_until(const Predicate& pred, const std::chrono::milliseconds timeout = std::chrono::seconds(5)) {
	const auto start = std::chrono::steady_clock::now();
	while(!pred()) {
		if(std::chrono::steady_clock::now() - start > timeout) { FAIL("Timeout reached"); }
		std::this_thread::yield();
	}
}
// Forward declarations of the factory and test-context classes that the mock types below declare as friends
// (e.g. to reach mock_buffer's private constructor).
class mock_buffer_factory;
class mock_host_object_factory;
class cdag_test_context;
class idag_test_context;
template <int Dims>
class mock_buffer {
public:
template <access_mode Mode, typename Functor>
void get_access(handler& cgh, Functor rmfn) {
(void)detail::add_requirement(cgh, m_id, Mode, std::make_unique<detail::range_mapper<Dims, Functor>>(rmfn, m_size));
}
detail::buffer_id get_id() const { return m_id; }
range<Dims> get_range() const { return m_size; }
private:
friend class mock_buffer_factory;
friend class cdag_test_context;
friend class idag_test_context;
friend class scheduler_test_context;
detail::buffer_id m_id;
range<Dims> m_size;
mock_buffer(detail::buffer_id id, range<Dims> size) : m_id(id), m_size(size) {}
};
class mock_host_object {
public:
void add_side_effect(handler& cgh, const experimental::side_effect_order order) { (void)detail::add_requirement(cgh, m_id, order, true); }
detail::host_object_id get_id() const { return m_id; }
private:
friend class mock_host_object_factory;
friend class cdag_test_context;
friend class idag_test_context;
friend class scheduler_test_context;
detail::host_object_id m_id;
public:
explicit mock_host_object(detail::host_object_id id) : m_id(id) {}
};
/// Creates mock_buffers with consecutive ids (starting at 0) and announces each new buffer to every component this factory was constructed
/// with: task manager, scheduler, command graph generator and/or instruction graph generator.
class mock_buffer_factory {
public:
	explicit mock_buffer_factory() = default;
	explicit mock_buffer_factory(detail::task_manager& tm) : m_task_mngr(&tm) {}
	explicit mock_buffer_factory(detail::task_manager& tm, detail::command_graph_generator& cggen) : m_task_mngr(&tm), m_cggen(&cggen) {}
	explicit mock_buffer_factory(detail::task_manager& tm, detail::command_graph_generator& cggen, detail::instruction_graph_generator& iggen)
	    : m_task_mngr(&tm), m_cggen(&cggen), m_iggen(&iggen) {}
	explicit mock_buffer_factory(detail::task_manager& tm, detail::scheduler& schdlr) : m_task_mngr(&tm), m_schdlr(&schdlr) {}

	/// Creates a buffer of the given `size` under the next free buffer id.
	/// When `mark_as_host_initialized` is set, the buffer also gets a fresh user allocation id (otherwise null_allocation_id).
	/// Scheduler and instruction graph generator are told that elements have the size and alignment of `int`.
	template <int Dims>
	mock_buffer<Dims> create_buffer(range<Dims> size, bool mark_as_host_initialized = false) {
		const detail::buffer_id bid = m_next_buffer_id++;
		const auto user_allocation_id =
		    mark_as_host_initialized ? detail::allocation_id(detail::user_memory_id, m_next_user_allocation_id++) : detail::null_allocation_id;
		const auto size3 = detail::range_cast<3>(size);
		constexpr size_t elem_size = sizeof(int);
		constexpr size_t elem_align = alignof(int);

		if(m_task_mngr) { m_task_mngr->notify_buffer_created(bid, size3, mark_as_host_initialized); }
		if(m_schdlr) { m_schdlr->notify_buffer_created(bid, size3, elem_size, elem_align, user_allocation_id); }
		if(m_cggen) { m_cggen->notify_buffer_created(bid, size3, mark_as_host_initialized); }
		if(m_iggen) { m_iggen->notify_buffer_created(bid, size3, elem_size, elem_align, user_allocation_id); }

		return mock_buffer<Dims>(bid, size);
	}

private:
	detail::task_manager* m_task_mngr = nullptr;
	detail::scheduler* m_schdlr = nullptr;
	detail::command_graph_generator* m_cggen = nullptr;
	detail::instruction_graph_generator* m_iggen = nullptr;
	detail::buffer_id m_next_buffer_id = 0;
	detail::raw_allocation_id m_next_user_allocation_id = 1;
};
/// Creates mock_host_objects with consecutive ids (starting at 0) and announces each one to the task manager and/or scheduler
/// this factory was constructed with.
class mock_host_object_factory {
public:
	explicit mock_host_object_factory() = default;
	explicit mock_host_object_factory(detail::task_manager& tm) : m_task_mngr(&tm) {}
	explicit mock_host_object_factory(detail::task_manager& tm, detail::scheduler& schdlr) : m_task_mngr(&tm), m_schdlr(&schdlr) {}

	/// Creates a host object under the next free id; `owns_instance` is forwarded only to the scheduler.
	mock_host_object create_host_object(bool owns_instance = true) {
		const detail::host_object_id hoid = m_next_id++;
		if(m_task_mngr) m_task_mngr->notify_host_object_created(hoid);
		if(m_schdlr) m_schdlr->notify_host_object_created(hoid, owns_instance);
		return mock_host_object{hoid};
	}

private:
	detail::task_manager* m_task_mngr = nullptr;
	detail::scheduler* m_schdlr = nullptr;
	detail::host_object_id m_next_id = 0;
};
/// Generates a compute task in `tm` and returns its id. The command group runs `cgf` and then launches an empty parallel_for named `KernelName`
/// over `global_size`, starting at `global_offset`.
template <typename KernelName = detail::unnamed_kernel, typename CGF, int KernelDims = 2>
detail::task_id add_compute_task(detail::task_manager& tm, CGF cgf, range<KernelDims> global_size = {1, 1}, id<KernelDims> global_offset = {}) {
	auto command_group = [&cgf, kernel_range = global_size, kernel_offset = global_offset](handler& cgh) {
		cgf(cgh);
		cgh.parallel_for<KernelName>(kernel_range, kernel_offset, [](id<KernelDims>) {});
	};
	return tm.generate_command_group_task(detail::invoke_command_group_function(std::move(command_group)));
}

/// Like add_compute_task(), but launches an empty nd_range parallel_for over `execution_range`.
template <typename KernelName = detail::unnamed_kernel, typename CGF, int KernelDims = 2>
detail::task_id add_nd_range_compute_task(detail::task_manager& tm, CGF cgf, celerity::nd_range<KernelDims> execution_range = {{1, 1}, {1, 1}}) {
	auto command_group = [&cgf, kernel_nd_range = execution_range](handler& cgh) {
		cgf(cgh);
		cgh.parallel_for<KernelName>(kernel_nd_range, [](nd_item<KernelDims>) {});
	};
	return tm.generate_command_group_task(detail::invoke_command_group_function(std::move(command_group)));
}

/// Generates a host task in `tm` and returns its id. The command group runs `cgf` and then submits an empty host_task with the given `spec`.
template <typename Spec, typename CGF>
detail::task_id add_host_task(detail::task_manager& tm, Spec spec, CGF cgf) {
	auto command_group = [&cgf, &spec](handler& cgh) {
		cgf(cgh);
		cgh.host_task(spec, [](auto...) {});
	};
	return tm.generate_command_group_task(detail::invoke_command_group_function(std::move(command_group)));
}
/// Generates a fence task in `tm` on host object `ho` (sequential side-effect order), passing nullptr as the second argument, and returns its id.
inline detail::task_id add_fence_task(detail::task_manager& tm, mock_host_object ho) {
	const detail::host_object_effect fence_effect{ho.get_id(), experimental::side_effect_order::sequential};
	return tm.generate_fence_task(fence_effect, nullptr);
}

/// Generates a fence task in `tm` that reads the subrange `sr` of buffer `buf` (via a fixed range mapper), and returns its id.
template <int Dims>
inline detail::task_id add_fence_task(detail::task_manager& tm, mock_buffer<Dims> buf, subrange<Dims> sr) {
	auto fixed_mapper = std::make_unique<detail::range_mapper<Dims, celerity::access::fixed<Dims>>>(celerity::access::fixed<Dims>(sr), buf.get_range());
	detail::buffer_access fence_access{buf.get_id(), access_mode::read, std::move(fixed_mapper)};
	return tm.generate_fence_task(std::move(fence_access), nullptr);
}

/// Generates a fence task in `tm` covering all of buffer `buf`, and returns its id.
template <int Dims>
inline detail::task_id add_fence_task(detail::task_manager& tm, mock_buffer<Dims> buf) {
	const subrange<Dims> whole_buffer{{}, buf.get_range()};
	return add_fence_task(tm, buf, whole_buffer);
}
/// Hands out reduction_info records with consecutive reduction ids, starting at 1.
class mock_reduction_factory {
public:
	detail::reduction_info create_reduction(const detail::buffer_id bid, const bool include_current_buffer_value) {
		const auto rid = m_next_id++;
		return detail::reduction_info{rid, bid, include_current_buffer_value};
	}

private:
	detail::reduction_id m_next_id = 1;
};

/// Creates a new mock reduction targeting buffer `vars` and registers it on the command group handler `cgh`.
template <int Dims>
void add_reduction(handler& cgh, mock_reduction_factory& mrf, const mock_buffer<Dims>& vars, bool include_current_buffer_value) {
	const detail::buffer_id target = vars.get_id();
	detail::add_reduction(cgh, mrf.create_reduction(target, include_current_buffer_value));
}
// Builds a detail::system_info for tests with `num_devices` devices; `supports_d2d_copies` presumably controls whether device-to-device copies
// are reported as possible. Defined out of line (not visible here).
detail::system_info make_system_info(const size_t num_devices, const bool supports_d2d_copies);
// This fixture (or a subclass) must be used by all tests that transitively use MPI.
// Construction calls runtime_testspy::test_require_mpi() (presumably making sure MPI is available for the test; TODO confirm).
// The fixture is neither copyable nor movable.
class mpi_fixture {
public:
mpi_fixture() { detail::runtime_testspy::test_require_mpi(); }
mpi_fixture(const mpi_fixture&) = delete;
mpi_fixture(mpi_fixture&&) = delete;
mpi_fixture& operator=(const mpi_fixture&) = delete;
mpi_fixture& operator=(mpi_fixture&&) = delete;
~mpi_fixture() = default;
};
// Allows "falling back to generic backend" warnings to appear in the test log without failing the test.
void allow_backend_fallback_warnings();
// Allows the "fence in dry run" warning to appear in the test log without failing the test.
void allow_dry_run_executor_warnings();
// This fixture (or a subclass) must be used by all tests that transitively instantiate the runtime.
// It derives from mpi_fixture, so that fixture's MPI requirement applies as well. The fixture is neither copyable nor movable;
// constructor and destructor are defined out of line (not visible here).
class runtime_fixture : public mpi_fixture {
public:
runtime_fixture();
runtime_fixture(const runtime_fixture&) = delete;
runtime_fixture(runtime_fixture&&) = delete;
runtime_fixture& operator=(const runtime_fixture&) = delete;
runtime_fixture& operator=(runtime_fixture&&) = delete;
~runtime_fixture();
/// Destroys the runtime immediately, instead of waiting for the destructor to do it.
/// Only required when testing shutdown behavior.
void destroy_runtime_now();
private:
// Presumably set by destroy_runtime_now() so the destructor does not tear the runtime down a second time (TODO: confirm).
bool m_runtime_manually_destroyed = false;
};
// runtime_fixture parameterized on an (unused) int, presumably so that dimension-templated test cases can use it as their fixture.
template <int>
struct runtime_fixture_dims : test_utils::runtime_fixture {};
class sycl_queue_fixture {
public:
sycl_queue_fixture() {
try {
m_queue = sycl::queue(sycl::gpu_selector_v, sycl::property::queue::in_order{});
} catch(sycl::exception&) { SKIP("no GPUs available"); }
}
sycl::queue& get_sycl_queue() { return m_queue; }
// Convenience function for submitting parallel_for with global offset without having to create a CGF
template <int Dims, typename KernelFn>
void parallel_for(const range<Dims>& global_range, const id<Dims>& global_offset, KernelFn fn) {
m_queue.submit([=](sycl::handler& cgh) {
cgh.parallel_for(sycl::range<Dims>{global_range}, detail::bind_simple_kernel(fn, global_range, global_offset, global_offset));
});
m_queue.wait_and_throw();
}
private:
sycl::queue m_queue;
};
// Printing of graphs can be enabled using the "--print-graphs" command line flag; this global reflects that setting.
extern bool g_print_graphs;
// Build the title used when printing a test graph of the given `type`. The additional overloads presumably include the number of nodes, the local
// node id and the number of devices per node in the title (definitions not visible here).
std::string make_test_graph_title(const std::string& type);
std::string make_test_graph_title(const std::string& type, size_t num_nodes, detail::node_id local_nid);
std::string make_test_graph_title(const std::string& type, size_t num_nodes, detail::node_id local_nid, size_t num_devices_per_node);
// Bundles a task graph, a task recorder and a task manager with mock buffer / host object factories bound to that task manager and a mock
// reduction factory, for tests of task generation. The constructor generates an initial (init) epoch task.
struct task_test_context {
detail::task_graph tdag;
detail::task_recorder trec;
// NOTE: tm is constructed from tdag and &trec in the initializer list, so it must stay declared after both (members are initialized in
// declaration order).
detail::task_manager tm;
mock_buffer_factory mbf;
mock_host_object_factory mhof;
mock_reduction_factory mrf;
// Id of the init epoch task generated by the constructor.
detail::task_id initial_epoch_task;
// The leading `1` passed to task_manager is presumably the number of collective nodes (TODO: confirm against task_manager's constructor).
explicit task_test_context(const detail::task_manager::policy_set& policy = {})
: tm(1, tdag, &trec, nullptr /* delegate */, policy), mbf(tm), mhof(tm), initial_epoch_task(tm.generate_epoch_task(detail::epoch_action::init)) {}
task_test_context(const task_test_context&) = delete;
task_test_context(task_test_context&&) = delete;
task_test_context& operator=(const task_test_context&) = delete;
task_test_context& operator=(task_test_context&&) = delete;
// Defined out of line (not visible here).
~task_test_context();
};
// Explicitly invokes the copy constructor without repeating the type at the call site, e.g. `auto b = copy(a);`.
template <typename Value>
Value copy(const Value& original) {
	return Value(original);
}
// Passes the address of `v` to Catch2's keep_memory() so that benchmarks can keep a computed value from being optimized away.
template <typename T>
void black_hole(T&& v) {
Catch::Benchmark::keep_memory(&v);
}
// truncate_*(): unchecked versions of *_cast() with signatures friendly to parameter type inference.
// Each keeps the first Dims components of a 3-dimensional value and silently drops the rest (no check that they are trivial).

/// Returns the first Dims extents of `r3`.
template <int Dims>
constexpr range<Dims> truncate_range(const range<3>& r3) {
	static_assert(Dims <= 3);
	range<Dims> result = detail::zeros;
	for(int dim = 0; dim < Dims; ++dim) {
		result[dim] = r3[dim];
	}
	return result;
}

/// Returns the first Dims coordinates of `i3`.
template <int Dims>
constexpr id<Dims> truncate_id(const id<3>& i3) {
	static_assert(Dims <= 3);
	id<Dims> result;
	for(int dim = 0; dim < Dims; ++dim) {
		result[dim] = i3[dim];
	}
	return result;
}

/// Truncates both offset and range of `sr3`.
template <int Dims>
subrange<Dims> truncate_subrange(const subrange<3>& sr3) {
	const auto truncated_offset = truncate_id<Dims>(sr3.offset);
	const auto truncated_extent = truncate_range<Dims>(sr3.range);
	return subrange<Dims>(truncated_offset, truncated_extent);
}

/// Truncates offset, range and global size of `ck3`.
template <int Dims>
chunk<Dims> truncate_chunk(const chunk<3>& ck3) {
	const auto truncated_offset = truncate_id<Dims>(ck3.offset);
	const auto truncated_extent = truncate_range<Dims>(ck3.range);
	const auto truncated_global_size = truncate_range<Dims>(ck3.global_size);
	return chunk<Dims>(truncated_offset, truncated_extent, truncated_global_size);
}

/// Truncates both corners of `b3`.
template <int Dims>
detail::box<Dims> truncate_box(const detail::box<3>& b3) {
	const auto truncated_min = truncate_id<Dims>(b3.get_min());
	const auto truncated_max = truncate_id<Dims>(b3.get_max());
	return detail::box<Dims>(truncated_min, truncated_max);
}
/// Catch2 generator that yields the elements of a std::vector in order; usually constructed through from_vector().
template <typename T>
class vector_generator final : public Catch::Generators::IGenerator<T> {
	// get() hands out a `const T&` into the vector, but std::vector<bool>::operator[] const returns bool by value, so that reference would dangle.
	// Reject bool at compile time (Catch2's own IteratorGenerator has the same restriction).
	static_assert(!std::is_same_v<T, bool>, "vector_generator does not support bool because of the std::vector<bool> specialization");

public:
	/// Takes ownership of `values`. get() reads m_values[0] before next() has ever been called, so an empty vector would be read out of bounds.
	/// Empty input is therefore rejected with a Catch2 generator exception.
	explicit vector_generator(std::vector<T>&& values) : m_values(std::move(values)) {
		if(m_values.empty()) { Catch::Generators::Detail::throw_generator_exception("vector_generator received no values"); }
	}

	/// Returns the current element.
	const T& get() const override { return m_values[m_idx]; }

	/// Advances to the next element; returns false once every element has been produced.
	bool next() override { return ++m_idx < m_values.size(); }

private:
	std::vector<T> m_values;
	size_t m_idx = 0;
};
/// Wraps `values` in a Catch2 generator that yields them in order (e.g. for use inside GENERATE()).
template <typename T>
Catch::Generators::GeneratorWrapper<T> from_vector(std::vector<T> values) {
	auto generator = Catch::Detail::make_unique<vector_generator<T>>(std::move(values));
	return Catch::Generators::GeneratorWrapper<T>(std::move(generator));
}
/// Blocks the calling thread until `evt` reports completion, then returns its result pointer.
/// The thread yields between polls instead of spinning at full speed, so a busy-waiting test thread does not monopolize its core.
inline void* await(const celerity::detail::async_event& evt) {
	while(!evt.is_complete()) {
		std::this_thread::yield();
	}
	return evt.get_result();
}
} // namespace test_utils
} // namespace celerity
namespace celerity::test_utils::access {

/// Range mapper mapping each chunk to its mirror image: along every dimension, the chunk interval [o, o + r) within global size g is mapped
/// to [g - o - r, g - o).
struct reverse_one_to_one {
	template <int Dims>
	subrange<Dims> operator()(chunk<Dims> ck) const {
		subrange<Dims> mirrored;
		for(int d = 0; d < Dims; ++d) {
			const auto chunk_end = ck.offset[d] + ck.range[d];
			mirrored.offset[d] = ck.global_size[d] - chunk_end;
			mirrored.range[d] = ck.range[d];
		}
		return mirrored;
	}
};

} // namespace celerity::test_utils::access
namespace Catch {
template <typename A, typename B>
struct StringMaker<std::pair<A, B>> {
	/// Renders a pair as "(first, second)", stringifying each element through Catch2.
	static std::string convert(const std::pair<A, B>& v) {
		const std::string first = Catch::Detail::stringify(v.first);
		const std::string second = Catch::Detail::stringify(v.second);
		return fmt::format("({}, {})", first, second);
	}
};
template <typename T>
struct StringMaker<std::optional<T>> {
	/// Renders the contained value through Catch2, or "null" for an empty optional.
	static std::string convert(const std::optional<T>& v) {
		if(!v.has_value()) return "null";
		return Catch::Detail::stringify(*v);
	}
};
// Defines a Catch2 StringMaker specialization for `Type` that renders values with fmt::format("{}", v), i.e. via the type's fmt formatter.
// (Comments cannot go inside the macro body: a `//` comment would swallow the line-continuation backslash.)
#define CELERITY_TEST_UTILS_IMPLEMENT_CATCH_STRING_MAKER(Type) \
template <> \
struct StringMaker<Type> { \
static std::string convert(const Type& v) { return fmt::format("{}", v); } \
};
CELERITY_TEST_UTILS_IMPLEMENT_CATCH_STRING_MAKER(celerity::detail::allocation_id)
CELERITY_TEST_UTILS_IMPLEMENT_CATCH_STRING_MAKER(celerity::detail::transfer_id)
CELERITY_TEST_UTILS_IMPLEMENT_CATCH_STRING_MAKER(celerity::detail::sycl_backend_type)
// Like CELERITY_TEST_UTILS_IMPLEMENT_CATCH_STRING_MAKER, but defines a partial specialization for a class template `Type<Dims>` covering every
// dimensionality. Values are rendered with fmt::format("{}", v).
#define CELERITY_TEST_UTILS_IMPLEMENT_CATCH_STRING_MAKER_FOR_DIMS(Type) \
template <int Dims> \
struct StringMaker<Type<Dims>> { \
static std::string convert(const Type<Dims>& v) { return fmt::format("{}", v); } \
};
CELERITY_TEST_UTILS_IMPLEMENT_CATCH_STRING_MAKER_FOR_DIMS(celerity::id)
CELERITY_TEST_UTILS_IMPLEMENT_CATCH_STRING_MAKER_FOR_DIMS(celerity::range)
CELERITY_TEST_UTILS_IMPLEMENT_CATCH_STRING_MAKER_FOR_DIMS(celerity::subrange)
CELERITY_TEST_UTILS_IMPLEMENT_CATCH_STRING_MAKER_FOR_DIMS(celerity::chunk)
CELERITY_TEST_UTILS_IMPLEMENT_CATCH_STRING_MAKER_FOR_DIMS(celerity::detail::box)
CELERITY_TEST_UTILS_IMPLEMENT_CATCH_STRING_MAKER_FOR_DIMS(celerity::detail::region)
/// Renders a SYCL device by its vendor id and name.
template <>
struct StringMaker<sycl::device> {
	static std::string convert(const sycl::device& d) {
		const auto vendor_id = d.get_info<sycl::info::device::vendor_id>();
		const auto name = d.get_info<sycl::info::device::name>();
		return fmt::format("sycl::device(vendor_id={}, name=\"{}\")", vendor_id, name);
	}
};

/// Renders a SYCL platform by its vendor and name.
template <>
struct StringMaker<sycl::platform> {
	static std::string convert(const sycl::platform& p) {
		const auto vendor = p.get_info<sycl::info::platform::vendor>();
		const auto name = p.get_info<sycl::info::platform::name>();
		return fmt::format("sycl::platform(vendor=\"{}\", name=\"{}\")", vendor, name);
	}
};
/// Renders a linearized_layout by its byte offset.
template <>
struct StringMaker<celerity::detail::linearized_layout> {
	static std::string convert(const celerity::detail::linearized_layout& layout) {
		return fmt::format("linearized_layout({})", layout.offset_bytes);
	}
};

/// Renders a strided_layout by its allocation.
template <>
struct StringMaker<celerity::detail::strided_layout> {
	static std::string convert(const celerity::detail::strided_layout& layout) {
		return fmt::format("strided_layout({})", layout.allocation);
	}
};

/// Renders a region_layout by dispatching to the StringMaker of whichever alternative it currently holds.
template <>
struct StringMaker<celerity::detail::region_layout> {
	static std::string convert(const celerity::detail::region_layout& v) {
		return matchbox::match(v, [](const auto& alternative) {
			using alternative_type = std::decay_t<decltype(alternative)>;
			return StringMaker<alternative_type>::convert(alternative);
		});
	}
};
} // namespace Catch