-
Notifications
You must be signed in to change notification settings - Fork 18
Expand file tree
/
Copy pathcommand_graph_generator_test_utils.h
More file actions
410 lines (332 loc) · 15.8 KB
/
command_graph_generator_test_utils.h
File metadata and controls
410 lines (332 loc) · 15.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
#pragma once
#include <memory>
#include <type_traits>
#include <vector>
#include <catch2/catch_message.hpp>
#include <catch2/internal/catch_context.hpp>
#include <fmt/format.h>
#include <fmt/ranges.h>
#include "command_graph.h"
#include "command_graph_generator.h"
#include "handler.h"
#include "print_graph.h"
#include "recorders.h"
#include "task_manager.h"
#include "types.h"
#include "graph_test_utils.h"
#include "test_utils.h"
using namespace celerity;
using namespace celerity::detail;
namespace celerity::test_utils {
/// Converts per-node push regions of arbitrary dimensionality into the 3-dimensional
/// form used throughout command-graph assertions.
///
/// @param regions (node id, region) pairs in `Dims` dimensions.
/// @returns The same pairs with each region widened via region_cast<3>.
template <int Dims>
std::vector<std::pair<node_id, region<3>>> push_regions(const std::vector<std::pair<node_id, region<Dims>>>& regions) {
	std::vector<std::pair<node_id, region<3>>> result;
	result.reserve(regions.size()); // one allocation up front
	for(const auto& [nid, reg] : regions) {
		result.emplace_back(nid, region_cast<3>(reg));
	}
	return result;
}
/// Filter-predicate set plugged into graph_query: each `matches` overload interprets one
/// filter value type (kernel debug name, task id, or buffer id) against a command record,
/// and `print_filter` renders the filter for assertion/failure messages.
template <typename Record>
struct command_matcher {
// A debug-name filter matches only execution commands whose kernel debug name is equal;
// every other record type falls through to the catch-all lambda and never matches.
static bool matches(const Record& cmd, const std::string& debug_name) {
return matchbox::match(
cmd, //
[&](const execution_command_record& ecmd) { return ecmd.debug_name == debug_name; }, //
[](const auto& /* other */) { return false; });
}
// A task-id filter matches any task-based command (dynamic_cast probes the record
// hierarchy) that was generated from task `tid`.
static bool matches(const Record& cmd, const task_id tid) {
if(const auto* tcmd = dynamic_cast<const task_command_record*>(&cmd); tcmd != nullptr) { return tcmd->tid == tid; }
return false;
}
// A buffer-id filter is only available when selecting push / await-push records.
// NOTE(review): only push_command_record is inspected below — an await_push_command_record
// is admitted by the enable_if but always yields false. Confirm this is intentional.
template <typename R = Record, std::enable_if_t<std::is_same_v<R, push_command_record> || std::is_same_v<R, await_push_command_record>, int> = 0>
static bool matches(const R& cmd, const buffer_id bid) {
return matchbox::match(
cmd, //
[&](const push_command_record& c) { return c.trid.bid == bid; }, //
[](const auto& /* other */) { return false; });
}
// Human-readable renderings of filter values, used when a selection assertion fails.
static std::string print_filter(const std::string& debug_name) { return fmt::format("\"{}\"", debug_name); }
static std::string print_filter(const task_id tid) { return fmt::format("\"T{}\"", tid); }
static std::string print_filter(const buffer_id bid) { return fmt::format("\"B{}\"", bid); }
};
using command_query = graph_query<command_record, command_record, command_recorder, command_matcher>;
/// Wrapper type around command_query that adds semantics for command graphs on multiple nodes.
///
/// Holds one graph_query per simulated node. Most members simply apply the corresponding
/// graph_query operation on every node's query and wrap the per-node results into a new
/// distributed_command_query (see the private `apply` helper).
template <typename Record = command_record>
class distributed_command_query {
public:
template <typename R>
using command_query = graph_query<R, command_record, command_recorder, command_matcher>;
explicit distributed_command_query(std::vector<command_query<Record>>&& queries) : m_queries(std::move(queries)) {}
// allow upcast
template <typename SpecificRecord, std::enable_if_t<std::is_base_of_v<Record, SpecificRecord> && !std::is_same_v<Record, SpecificRecord>, int> = 0>
distributed_command_query(const distributed_command_query<SpecificRecord>& other)
: distributed_command_query(std::vector<command_query<Record>>(other.m_queries.begin(), other.m_queries.end())) {}
// ------------------------- distributed graph interface -------------------------
// Returns the query for a single node; fails the enclosing test if `nid` is out of range.
command_query<Record> on(const node_id nid) const {
REQUIRE(nid < m_queries.size());
return m_queries[nid];
}
// Sum of matched commands across all nodes.
size_t total_count() const {
size_t sum = 0;
for(const auto& q : m_queries) {
sum += q.count();
}
return sum;
}
// Per-node count, REQUIRE-ing that every node matched the same number of commands.
size_t count_per_node() const {
const size_t expected = m_queries.at(0).count();
for(node_id nid = 1; nid < m_queries.size(); ++nid) {
REQUIRE(m_queries[nid].count() == expected);
}
return expected;
}
// Assertion helpers return *this so they can be chained fluently.
const distributed_command_query& assert_total_count(const size_t expected) const {
REQUIRE(total_count() == expected);
return *this;
}
const distributed_command_query& assert_count_per_node(const size_t expected) const {
for(node_id nid = 0; nid < m_queries.size(); ++nid) {
REQUIRE(m_queries[nid].count() == expected);
}
return *this;
}
// Access to the underlying per-node queries; the rvalue overload moves them out.
const std::vector<command_query<Record>>& iterate_nodes() const& { return m_queries; }
std::vector<command_query<Record>> iterate_nodes() && { return std::move(m_queries); }
// ---------------------------- graph_query interface ----------------------------
// The following forward to the same-named graph_query operation on every node.
template <typename SpecificRecord = Record, typename... Filters>
distributed_command_query<SpecificRecord> select_all(const Filters&... filters) const {
return apply<SpecificRecord>([&filters...](auto& q) { return q.template select_all<SpecificRecord>(filters...); });
}
template <typename SpecificRecord = Record, typename... Filters>
distributed_command_query<SpecificRecord> select_unique(const Filters&... filters) const {
return apply<SpecificRecord>([&filters...](auto& q) { return q.template select_unique<SpecificRecord>(filters...); });
}
distributed_command_query<command_record> predecessors() const {
return apply<command_record>([](auto& q) { return q.predecessors(); });
}
distributed_command_query<command_record> transitive_predecessors() const {
return apply<command_record>([](auto& q) { return q.transitive_predecessors(); });
}
template <typename SpecificRecord = Record, typename... Filters>
distributed_command_query<SpecificRecord> transitive_predecessors_across(const Filters&... filters) const {
return apply<SpecificRecord>([&filters...](auto& q) { return q.template transitive_predecessors_across<SpecificRecord>(filters...); });
}
distributed_command_query<command_record> successors() const {
return apply<command_record>([](auto& q) { return q.successors(); });
}
distributed_command_query<command_record> transitive_successors() const {
return apply<command_record>([](auto& q) { return q.transitive_successors(); });
}
template <typename SpecificRecord = Record, typename... Filters>
distributed_command_query<SpecificRecord> transitive_successors_across(const Filters&... filters) const {
return apply<SpecificRecord>([&filters...](auto& q) { return q.template transitive_successors_across<SpecificRecord>(filters...); });
}
// True only if the selections are concurrent on every node (pairs up node i with node i).
bool is_concurrent_with(const distributed_command_query& other) const {
for(node_id nid = 0; nid < m_queries.size(); ++nid) {
if(!m_queries[nid].is_concurrent_with(other.on(nid))) return false;
}
return true;
}
template <typename SpecificRecord = Record, typename... Filters>
bool all_match(const Filters&... filters) const {
return std::all_of(m_queries.begin(), m_queries.end(), [&filters...](const auto& q) { return q.template all_match<SpecificRecord>(filters...); });
}
template <typename SpecificRecord = Record, typename... Filters>
distributed_command_query<SpecificRecord> assert_all(const Filters&... filters) const {
return apply<SpecificRecord>([&filters...](auto& q) { return q.template assert_all<SpecificRecord>(filters...); });
}
// True only if each node's selection contains the corresponding node's subset selection.
bool contains(const distributed_command_query& subset) const {
for(node_id nid = 0; nid < m_queries.size(); ++nid) {
if(!m_queries[nid].contains(subset.on(nid))) return false;
}
return true;
}
bool operator==(const distributed_command_query& other) const {
if(m_queries.size() != other.m_queries.size()) return false;
for(node_id nid = 0; nid < m_queries.size(); ++nid) {
if(m_queries[nid] != other.m_queries[nid]) return false;
}
return true;
}
bool operator!=(const distributed_command_query& other) const { return !(*this == other); }
// Set operations combine queries node-wise; all operands must cover the same node count.
template <typename... DistributedCommandQueries>
friend distributed_command_query union_of(const distributed_command_query& head, const DistributedCommandQueries&... tail) {
REQUIRE(((head.m_queries.size() == tail.m_queries.size()) && ...));
std::vector<command_query<Record>> result;
for(node_id nid = 0; nid < head.m_queries.size(); ++nid) {
result.push_back(union_of(head.m_queries[nid], tail.m_queries[nid]...));
}
return distributed_command_query{std::move(result)};
}
template <typename... DistributedCommandQueries>
friend distributed_command_query intersection_of(const distributed_command_query& head, const DistributedCommandQueries&... tail) {
REQUIRE(((head.m_queries.size() == tail.m_queries.size()) && ...));
std::vector<command_query<Record>> result;
for(node_id nid = 0; nid < head.m_queries.size(); ++nid) {
result.push_back(intersection_of(head.m_queries[nid], tail.m_queries[nid]...));
}
return distributed_command_query{std::move(result)};
}
friend distributed_command_query difference_of(const distributed_command_query& lhs, const distributed_command_query& rhs) {
REQUIRE(lhs.m_queries.size() == rhs.m_queries.size());
std::vector<command_query<Record>> result;
for(node_id nid = 0; nid < lhs.m_queries.size(); ++nid) {
result.push_back(difference_of(lhs.m_queries[nid], rhs.m_queries[nid]));
}
return distributed_command_query{std::move(result)};
}
private:
// Other instantiations (for the upcast ctor) and fmt::formatter (prints m_queries directly)
// need access to m_queries.
template <typename>
friend class distributed_command_query;
template <typename, typename, typename>
friend struct fmt::formatter;
std::vector<command_query<Record>> m_queries; // one query per simulated node
// Applies `fn` to every node's query and collects the results into a new
// distributed_command_query of the (possibly narrowed) record type.
template <typename SpecificRecord, typename Function>
distributed_command_query<SpecificRecord> apply(const Function& fn) const {
std::vector<command_query<SpecificRecord>> result;
for(auto& q : m_queries) {
result.push_back(fn(q));
}
return distributed_command_query<SpecificRecord>{std::move(result)};
}
};
/// Test fixture that drives a task_manager plus one command_graph_generator per simulated
/// node: tests create buffers/host objects, submit tasks via the builder helpers, and then
/// inspect the resulting per-node command graphs through query().
class cdag_test_context final : private task_manager::delegate {
friend class task_builder<cdag_test_context>;
public:
// Policies forwarded to the task manager and each command graph generator.
struct policy_set {
task_manager::policy_set tm;
command_graph_generator::policy_set cggen;
};
// Sets up one CDAG, recorder and generator per node and submits the initial epoch.
cdag_test_context(const size_t num_nodes, const policy_set& policy = {})
: m_num_nodes(num_nodes), m_tm(num_nodes, m_tdag, &m_task_recorder, this, policy.tm) {
for(node_id nid = 0; nid < num_nodes; ++nid) {
m_cdags.emplace_back(std::make_unique<command_graph>());
m_cmd_recorders.emplace_back(std::make_unique<command_recorder>());
m_cggens.emplace_back(std::make_unique<command_graph_generator>(num_nodes, nid, *m_cdags[nid], m_cmd_recorders[nid].get(), policy.cggen));
}
m_initial_epoch_tid = m_tm.generate_epoch_task(epoch_action::init);
}
~cdag_test_context() { maybe_print_graphs(); }
// Owns graphs, generators and recorders by unique_ptr — neither copyable nor movable.
cdag_test_context(const cdag_test_context&) = delete;
cdag_test_context(cdag_test_context&&) = delete;
cdag_test_context& operator=(const cdag_test_context&) = delete;
cdag_test_context& operator=(cdag_test_context&&) = delete;
// task_manager::delegate callback: fan each new task out to every node's generator.
void task_created(const task* tsk) override {
for(auto& cggen : m_cggens) {
cggen->build_task(*tsk);
}
}
// Registers a new virtual buffer with the task manager and all generators.
template <int Dims>
test_utils::mock_buffer<Dims> create_buffer(range<Dims> size, bool mark_as_host_initialized = false) {
const buffer_id bid = m_next_buffer_id++;
const auto buf = test_utils::mock_buffer<Dims>(bid, size);
m_tm.notify_buffer_created(bid, range_cast<3>(size), mark_as_host_initialized);
for(auto& cggen : m_cggens) {
cggen->notify_buffer_created(bid, range_cast<3>(size), mark_as_host_initialized);
}
return buf;
}
// Registers a new virtual host object with the task manager and all generators.
// NOTE(review): `owns_instance` is currently unused here — confirm whether it should be
// forwarded to notify_host_object_created.
test_utils::mock_host_object create_host_object(const bool owns_instance = true) {
const host_object_id hoid = m_next_host_object_id++;
m_tm.notify_host_object_created(hoid);
for(auto& cggen : m_cggens) {
cggen->notify_host_object_created(hoid);
}
return test_utils::mock_host_object(hoid);
}
// Task-submission entry points; each returns a task_builder for fluent configuration.
template <typename Name = unnamed_kernel, int Dims>
auto device_compute(const range<Dims>& global_size, const id<Dims>& global_offset = {}) {
return task_builder(*this).template device_compute<Name>(global_size, global_offset);
}
template <typename Name = unnamed_kernel, int Dims>
auto device_compute(const nd_range<Dims>& execution_range) {
return task_builder(*this).template device_compute<Name>(execution_range);
}
template <int Dims>
auto host_task(const range<Dims>& global_size) {
return task_builder(*this).host_task(global_size);
}
auto master_node_host_task() { return task_builder(*this).master_node_host_task(); }
auto collective_host_task(experimental::collective_group group = detail::default_collective_group) {
return task_builder(*this).collective_host_task(group);
}
// Generates a fence task on a host object (sequential side-effect order).
// nullptr: tests don't supply a fence promise — presumably unused by recording; verify.
task_id fence(test_utils::mock_host_object ho) {
host_object_effect effect{ho.get_id(), experimental::side_effect_order::sequential};
return m_tm.generate_fence_task(effect, nullptr);
}
// Generates a fence task reading the given subrange of a buffer.
template <int Dims>
task_id fence(test_utils::mock_buffer<Dims> buf, subrange<Dims> sr) {
buffer_access access{buf.get_id(), access_mode::read,
std::make_unique<range_mapper<Dims, celerity::access::fixed<Dims>>>(celerity::access::fixed<Dims>(sr), buf.get_range())};
return m_tm.generate_fence_task(std::move(access), nullptr);
}
// Convenience overload: fence over the entire buffer.
template <int Dims>
task_id fence(test_utils::mock_buffer<Dims> buf) {
return fence(buf, {{}, buf.get_range()});
}
task_id epoch(epoch_action action) { return m_tm.generate_epoch_task(action); }
// Builds a distributed query over all nodes' recorded commands, pre-filtered by `filters`.
template <typename SpecificRecord = command_record, typename... Filters>
distributed_command_query<SpecificRecord> query(Filters... filters) {
std::vector<typename distributed_command_query<>::command_query<command_record>> queries;
for(auto& recorder : m_cmd_recorders) {
queries.push_back(typename distributed_command_query<>::command_query<command_record>(*recorder));
}
return distributed_command_query(std::move(queries)).template select_all<SpecificRecord>(std::forward<Filters>(filters)...);
}
void set_horizon_step(const int step) { m_tm.set_horizon_step(step); }
// Test hook forwarded to every generator.
void set_test_chunk_multiplier(const size_t multiplier) {
for(auto& cggen : m_cggens) {
cggen->test_set_chunk_multiplier(multiplier);
}
}
task_graph& get_task_graph() { return m_tdag; }
task_manager& get_task_manager() { return m_tm; }
command_graph& get_command_graph(node_id nid) { return *m_cdags.at(nid); }
command_graph_generator& get_graph_generator(node_id nid) { return *m_cggens.at(nid); }
task_id get_initial_epoch_task() const { return m_initial_epoch_tid; }
[[nodiscard]] std::string print_task_graph() { return detail::print_task_graph(m_task_recorder, make_test_graph_title("Task Graph")); }
[[nodiscard]] std::string print_command_graph(node_id nid) {
// Don't include node id in title: All CDAG printouts must have identical preambles for combine_command_graphs to work
return detail::print_command_graph(nid, *m_cmd_recorders[nid], make_test_graph_title("Command Graph"));
}
private:
size_t m_num_nodes;
buffer_id m_next_buffer_id = 0;
host_object_id m_next_host_object_id = 0;
reduction_id m_next_reduction_id = 1; // Start from 1 as rid 0 designates "no reduction" in push commands
task_graph m_tdag;
task_manager m_tm;
task_recorder m_task_recorder;
task_id m_initial_epoch_tid = 0;
// One graph / generator / recorder per simulated node, indexed by node_id.
std::vector<std::unique_ptr<command_graph>> m_cdags;
std::vector<std::unique_ptr<command_graph_generator>> m_cggens;
std::vector<std::unique_ptr<command_recorder>> m_cmd_recorders;
reduction_info create_reduction(const buffer_id bid, const bool include_current_buffer_value) {
return reduction_info{m_next_reduction_id++, bid, include_current_buffer_value};
}
// Used by task_builder (friend) to funnel command groups into the task manager.
template <typename CGF>
task_id submit_command_group(CGF cgf) {
return m_tm.generate_command_group_task(invoke_command_group_function(cgf));
}
// Dumps TDAG and a combined CDAG to stdout when graph printing is enabled for the test run.
void maybe_print_graphs() {
if(test_utils::g_print_graphs) {
fmt::print("\n{}\n", print_task_graph());
std::vector<std::string> graphs;
for(node_id nid = 0; nid < m_num_nodes; ++nid) {
graphs.push_back(print_command_graph(nid));
}
fmt::print("\n{}\n", combine_command_graphs(graphs, make_test_graph_title("Command Graph")));
}
}
};
} // namespace celerity::test_utils
/// Formats a distributed_command_query as the bracketed list of its per-node queries,
/// e.g. "[<node-0 query>, <node-1 query>]".
template <typename Record>
struct fmt::formatter<celerity::test_utils::distributed_command_query<Record>> : fmt::formatter<size_t> {
	format_context::iterator format(const celerity::test_utils::distributed_command_query<Record>& dcq, format_context& ctx) const {
		// Return the iterator produced by format_to, i.e. one past the written output, as
		// the formatter contract requires. The previous code returned the pre-write
		// ctx.out(), which only happened to work because fmt's appender shares its buffer.
		return fmt::format_to(ctx.out(), "[{}]", fmt::join(dcq.m_queries, ", "));
	}
};
/// Teaches Catch2 to stringify distributed_command_query values in assertion output by
/// delegating to the fmt::formatter specialization defined for the type.
template <typename Record>
struct Catch::StringMaker<celerity::test_utils::distributed_command_query<Record>> {
	static std::string convert(const celerity::test_utils::distributed_command_query<Record>& dcq) {
		return fmt::to_string(dcq); // equivalent to fmt::format("{}", dcq)
	}
};