Skip to content

Commit df60b9b

Browse files
psalzGagaLP
authored andcommitted
Set up graph_query testing infrastructure for task graph
1 parent db71b5c commit df60b9b

8 files changed

Lines changed: 215 additions & 83 deletions

File tree

include/recorders.h

Lines changed: 18 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -50,58 +50,51 @@ using reduction_list = std::vector<reduction_record>;
5050

5151
template <typename IdType>
5252
struct dependency_record {
53-
const IdType node;
54-
const dependency_kind kind;
55-
const dependency_origin origin;
53+
IdType predecessor;
54+
IdType successor;
55+
dependency_kind kind;
56+
dependency_origin origin;
57+
58+
dependency_record(const IdType predecessor, const IdType successor, const dependency_kind kind, const dependency_origin origin)
59+
: predecessor(predecessor), successor(successor), kind(kind), origin(origin) {}
5660
};
5761

5862
// Task recording
5963

60-
using task_dependency_list = std::vector<dependency_record<task_id>>;
64+
using task_dependency_record = dependency_record<task_id>;
6165

66+
// TODO: Switch to hierarchy like for CDAG/IDAG
6267
struct task_record {
6368
task_record(const task& tsk, const buffer_name_map& get_buffer_debug_name);
6469

65-
task_id tid;
70+
task_id id;
6671
std::string debug_name;
6772
collective_group_id cgid;
6873
task_type type;
6974
task_geometry geometry;
7075
reduction_list reductions;
7176
access_list accesses;
7277
detail::side_effect_map side_effect_map;
73-
task_dependency_list dependencies;
7478
};
7579

7680
class task_recorder {
7781
public:
78-
void record(task_record&& record) { m_recorded_tasks.push_back(std::move(record)); }
82+
void record(std::unique_ptr<task_record> record) { m_recorded_tasks.push_back(std::move(record)); }
7983

80-
const std::vector<task_record>& get_tasks() const { return m_recorded_tasks; }
84+
void record_dependency(const task_dependency_record& dependency) { m_recorded_dependencies.push_back(dependency); }
8185

82-
const task_record& get_task(const task_id tid) const {
83-
const auto it = std::find_if(m_recorded_tasks.begin(), m_recorded_tasks.end(), [tid](const task_record& rec) { return rec.tid == tid; });
84-
assert(it != m_recorded_tasks.end());
85-
return *it;
86-
}
86+
const std::vector<std::unique_ptr<task_record>>& get_graph_nodes() const { return m_recorded_tasks; }
87+
88+
const std::vector<task_dependency_record>& get_dependencies() const { return m_recorded_dependencies; }
8789

8890
private:
89-
std::vector<task_record> m_recorded_tasks;
91+
std::vector<std::unique_ptr<task_record>> m_recorded_tasks;
92+
std::vector<task_dependency_record> m_recorded_dependencies;
9093
};
9194

9295
// Command recording
9396

94-
using command_dependency_list = std::vector<dependency_record<command_id>>;
95-
96-
struct command_dependency_record {
97-
command_id predecessor;
98-
command_id successor;
99-
dependency_kind kind;
100-
dependency_origin origin;
101-
102-
command_dependency_record(const command_id predecessor, const command_id successor, const dependency_kind kind, const dependency_origin origin)
103-
: predecessor(predecessor), successor(successor), kind(kind), origin(origin) {}
104-
};
97+
using command_dependency_record = dependency_record<command_id>;
10598

10699
struct command_record : matchbox::acceptor<struct push_command_record, struct await_push_command_record, struct reduction_command_record,
107100
struct epoch_command_record, struct horizon_command_record, struct execution_command_record, struct fence_command_record> {

src/print_graph.cc

Lines changed: 42 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -85,9 +85,42 @@ void format_requirements(std::string& label, const reduction_list& reductions, c
8585
}
8686
}
8787

88+
template <typename IdType>
89+
void print_dependencies(
90+
const std::vector<dependency_record<IdType>>& dependencies, std::string& dot,
91+
const std::function<std::string(IdType)> id_transform = [](IdType id) { return std::to_string(id); }) {
92+
// Sort and deduplicate edges
93+
struct dependency_edge {
94+
IdType predecessor;
95+
IdType successor;
96+
};
97+
struct dependency_edge_order {
98+
bool operator()(const dependency_edge& lhs, const dependency_edge& rhs) const {
99+
if(lhs.predecessor < rhs.predecessor) return true;
100+
if(lhs.predecessor > rhs.predecessor) return false;
101+
return lhs.successor < rhs.successor;
102+
}
103+
};
104+
struct dependency_kind_order {
105+
bool operator()(const std::pair<dependency_kind, dependency_origin>& lhs, const std::pair<dependency_kind, dependency_origin>& rhs) const {
106+
return (lhs.first == dependency_kind::true_dep && rhs.first != dependency_kind::true_dep);
107+
}
108+
};
109+
std::map<dependency_edge, std::set<std::pair<dependency_kind, dependency_origin>, dependency_kind_order>, dependency_edge_order>
110+
dependencies_by_edge; // ordered and unique
111+
for(const auto& dep : dependencies) {
112+
dependencies_by_edge[{dep.predecessor, dep.successor}].insert(std::pair{dep.kind, dep.origin});
113+
}
114+
for(const auto& [edge, meta] : dependencies_by_edge) {
115+
// If there's at most two edges, take the first one (likely a true dependency followed by an anti-dependency). If there's more, bail (don't style).
116+
const auto style = meta.size() <= 2 ? dependency_style(meta.begin()->first, meta.begin()->second) : std::string{};
117+
fmt::format_to(std::back_inserter(dot), "{}->{}[{}];", id_transform(edge.predecessor), id_transform(edge.successor), style);
118+
}
119+
}
120+
88121
std::string get_task_label(const task_record& tsk) {
89122
std::string label;
90-
fmt::format_to(std::back_inserter(label), "T{}", tsk.tid);
123+
fmt::format_to(std::back_inserter(label), "T{}", tsk.id);
91124
if(!tsk.debug_name.empty()) { fmt::format_to(std::back_inserter(label), " \"{}\"", utils::escape_for_dot_label(tsk.debug_name)); }
92125

93126
fmt::format_to(std::back_inserter(label), "<br/><b>{}</b>", task_type_string(tsk.type));
@@ -107,16 +140,15 @@ std::string make_graph_preamble(const std::string& title) { return fmt::format("
107140
std::string print_task_graph(const task_recorder& recorder, const std::string& title) {
108141
std::string dot = make_graph_preamble(title);
109142

110-
CELERITY_DEBUG("print_task_graph, {} entries", recorder.get_tasks().size());
143+
CELERITY_DEBUG("print_task_graph, {} entries", recorder.get_graph_nodes().size());
111144

112-
for(const auto& tsk : recorder.get_tasks()) {
113-
const char* shape = tsk.type == task_type::epoch || tsk.type == task_type::horizon ? "ellipse" : "box style=rounded";
114-
fmt::format_to(std::back_inserter(dot), "{}[shape={} label=<{}>];", tsk.tid, shape, get_task_label(tsk));
115-
for(auto d : tsk.dependencies) {
116-
fmt::format_to(std::back_inserter(dot), "{}->{}[{}];", d.node, tsk.tid, dependency_style(d.kind, d.origin));
117-
}
145+
for(const auto& tsk : recorder.get_graph_nodes()) {
146+
const char* shape = tsk->type == task_type::epoch || tsk->type == task_type::horizon ? "ellipse" : "box style=rounded";
147+
fmt::format_to(std::back_inserter(dot), "{}[shape={} label=<{}>];", tsk->id, shape, get_task_label(*tsk));
118148
}
119149

150+
print_dependencies(recorder.get_dependencies(), dot);
151+
120152
dot += "}";
121153
return dot;
122154
}
@@ -135,7 +167,7 @@ std::string print_command_graph(const node_id local_nid, const command_recorder&
135167
std::string main_dot;
136168
std::map<task_id, std::string> task_subgraph_dot; // this map must be ordered!
137169

138-
const auto local_to_global_id = [local_nid](uint64_t id) {
170+
const auto local_to_global_id = [local_nid](auto id) -> std::string {
139171
// IDs in the DOT language may not start with a digit (unless the whole thing is a numeral)
140172
return fmt::format("id_{}_{}", local_nid, id);
141173
};
@@ -241,33 +273,7 @@ std::string print_command_graph(const node_id local_nid, const command_recorder&
241273
});
242274
};
243275

244-
// Sort and deduplicate edges
245-
struct dependency_edge {
246-
command_id predecessor;
247-
command_id successor;
248-
};
249-
struct dependency_edge_order {
250-
bool operator()(const dependency_edge& lhs, const dependency_edge& rhs) const {
251-
if(lhs.predecessor < rhs.predecessor) return true;
252-
if(lhs.predecessor > rhs.predecessor) return false;
253-
return lhs.successor < rhs.successor;
254-
}
255-
};
256-
struct dependency_kind_order {
257-
bool operator()(const std::pair<dependency_kind, dependency_origin>& lhs, const std::pair<dependency_kind, dependency_origin>& rhs) const {
258-
return (lhs.first == dependency_kind::true_dep && rhs.first != dependency_kind::true_dep);
259-
}
260-
};
261-
std::map<dependency_edge, std::set<std::pair<dependency_kind, dependency_origin>, dependency_kind_order>, dependency_edge_order>
262-
dependencies_by_edge; // ordered and unique
263-
for(const auto& dep : recorder.get_dependencies()) {
264-
dependencies_by_edge[{dep.predecessor, dep.successor}].insert(std::pair{dep.kind, dep.origin});
265-
}
266-
for(const auto& [edge, meta] : dependencies_by_edge) {
267-
// If there's at most two edges, take the first one (likely a true dependency followed by an anti-dependency). If there's more, bail (don't style).
268-
const auto style = meta.size() <= 2 ? dependency_style(meta.begin()->first, meta.begin()->second) : std::string{};
269-
fmt::format_to(std::back_inserter(main_dot), "{}->{}[{}];", local_to_global_id(edge.predecessor), local_to_global_id(edge.successor), style);
270-
}
276+
print_dependencies<command_id>(recorder.get_dependencies(), main_dot, local_to_global_id);
271277

272278
std::string result_dot = make_graph_preamble(title);
273279
for(auto& [_, sg_dot] : task_subgraph_dot) {

src/recorders.cc

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -39,18 +39,10 @@ reduction_list build_reduction_list(const task& tsk, const buffer_name_map& get_
3939
return ret;
4040
}
4141

42-
task_dependency_list build_task_dependency_list(const task& tsk) {
43-
task_dependency_list ret;
44-
for(const auto& dep : tsk.get_dependencies()) {
45-
ret.push_back({dep.node->get_id(), dep.kind, dep.origin});
46-
}
47-
return ret;
48-
}
49-
5042
task_record::task_record(const task& tsk, const buffer_name_map& get_buffer_debug_name)
51-
: tid(tsk.get_id()), debug_name(tsk.get_debug_name()), cgid(tsk.get_collective_group_id()), type(tsk.get_type()), geometry(tsk.get_geometry()),
43+
: id(tsk.get_id()), debug_name(tsk.get_debug_name()), cgid(tsk.get_collective_group_id()), type(tsk.get_type()), geometry(tsk.get_geometry()),
5244
reductions(build_reduction_list(tsk, get_buffer_debug_name)), accesses(build_access_list(tsk, get_buffer_debug_name)),
53-
side_effect_map(tsk.get_side_effect_map()), dependencies(build_task_dependency_list(tsk)) {}
45+
side_effect_map(tsk.get_side_effect_map()) {}
5446

5547
// Commands
5648

src/task_manager.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ namespace detail {
189189
void task_manager::invoke_callbacks(const task* tsk) const {
190190
if(m_delegate != nullptr) { m_delegate->task_created(tsk); }
191191
if(m_task_recorder != nullptr) {
192-
m_task_recorder->record(task_record(*tsk, [this](const buffer_id bid) { return m_buffers.at(bid).debug_name; }));
192+
m_task_recorder->record(std::make_unique<task_record>(*tsk, [this](const buffer_id bid) { return m_buffers.at(bid).debug_name; }));
193193
}
194194
}
195195

@@ -198,6 +198,7 @@ namespace detail {
198198
depender.add_dependency({&dependee, kind, origin});
199199
m_execution_front.erase(&dependee);
200200
m_max_pseudo_critical_path_length = std::max(m_max_pseudo_critical_path_length, depender.get_pseudo_critical_path_length());
201+
if(m_task_recorder != nullptr) { m_task_recorder->record_dependency({dependee.get_id(), depender.get_id(), kind, origin}); }
201202
}
202203

203204
bool task_manager::need_new_horizon() const {

test/graph_test_utils.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,14 @@ using namespace celerity::detail;
1111

1212
namespace celerity::test_utils {
1313

14+
class tdag_test_context;
1415
class cdag_test_context;
1516
class idag_test_context;
1617
class scheduler_test_context;
1718

1819
template <typename TestContext>
1920
class task_builder {
21+
friend class tdag_test_context;
2022
friend class cdag_test_context;
2123
friend class idag_test_context;
2224
friend class scheduler_test_context;

test/print_graph_tests.cc

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,10 @@ TEST_CASE("task-graph printing is unchanged", "[print_graph][task-graph]") {
3737
// replace the `expected` value with the new dot graph.
3838
const std::string expected =
3939
"digraph G{label=<Task Graph>;pad=0.2;0[shape=ellipse label=<T0<br/><b>epoch</b>>];1[shape=box style=rounded label=<T1<br/><b>device-compute</b> "
40-
"[0,0,0] + [64,1,1]<br/><i>discard_write</i> B1 {[0,0,0] - [1,1,1]}>];0->1[color=orchid];2[shape=box style=rounded label=<T2<br/><b>device-compute</b> "
41-
"[0,0,0] + [64,1,1]<br/><i>discard_write</i> B0 {[0,0,0] - [64,1,1]}>];0->2[color=orchid];3[shape=box style=rounded "
42-
"label=<T3<br/><b>device-compute</b> [0,0,0] + [64,1,1]<br/>(R1) <i>read_write</i> B1 {[0,0,0] - [1,1,1]}<br/><i>read</i> B0 {[0,0,0] - "
43-
"[64,1,1]}>];1->3[];2->3[];4[shape=box style=rounded label=<T4<br/><b>device-compute</b> [0,0,0] + [64,1,1]<br/><i>read</i> B1 {[0,0,0] - "
44-
"[1,1,1]}>];3->4[];}";
40+
"[0,0,0] + [64,1,1]<br/><i>discard_write</i> B1 {[0,0,0] - [1,1,1]}>];2[shape=box style=rounded label=<T2<br/><b>device-compute</b> [0,0,0] + "
41+
"[64,1,1]<br/><i>discard_write</i> B0 {[0,0,0] - [64,1,1]}>];3[shape=box style=rounded label=<T3<br/><b>device-compute</b> [0,0,0] + [64,1,1]<br/>(R1) "
42+
"<i>read_write</i> B1 {[0,0,0] - [1,1,1]}<br/><i>read</i> B0 {[0,0,0] - [64,1,1]}>];4[shape=box style=rounded label=<T4<br/><b>device-compute</b> "
43+
"[0,0,0] + [64,1,1]<br/><i>read</i> B1 {[0,0,0] - [1,1,1]}>];0->1[color=orchid];0->2[color=orchid];1->3[];2->3[];3->4[];}";
4544

4645
const auto dot = print_task_graph(tt.trec);
4746
CHECK(dot == expected);
@@ -316,12 +315,12 @@ TEST_CASE_METHOD(test_utils::runtime_fixture, "full graph is printed if CELERITY
316315
SECTION("task graph") {
317316
const auto* expected =
318317
"digraph G{label=<Task Graph>;pad=0.2;0[shape=ellipse label=<T0<br/><b>epoch</b>>];1[shape=box style=rounded label=<T1<br/><b>host-compute</b> "
319-
"[0,0,0] + [16,1,1]<br/><i>read_write</i> B0 {[0,0,0] - [16,1,1]}>];0->1[];2[shape=ellipse "
320-
"label=<T2<br/><b>horizon</b>>];1->2[color=orange];3[shape=box style=rounded label=<T3<br/><b>host-compute</b> [0,0,0] + "
321-
"[16,1,1]<br/><i>read_write</i> B0 {[0,0,0] - [16,1,1]}>];1->3[];4[shape=ellipse "
322-
"label=<T4<br/><b>horizon</b>>];3->4[color=orange];2->4[color=orange];5[shape=box style=rounded label=<T5<br/><b>host-compute</b> [0,0,0] + "
323-
"[16,1,1]<br/><i>read_write</i> B0 {[0,0,0] - [16,1,1]}>];3->5[];6[shape=ellipse "
324-
"label=<T6<br/><b>horizon</b>>];5->6[color=orange];4->6[color=orange];7[shape=ellipse label=<T7<br/><b>epoch</b>>];6->7[color=orange];}";
318+
"[0,0,0] + [16,1,1]<br/><i>read_write</i> B0 {[0,0,0] - [16,1,1]}>];2[shape=ellipse label=<T2<br/><b>horizon</b>>];3[shape=box style=rounded "
319+
"label=<T3<br/><b>host-compute</b> [0,0,0] + [16,1,1]<br/><i>read_write</i> B0 {[0,0,0] - [16,1,1]}>];4[shape=ellipse "
320+
"label=<T4<br/><b>horizon</b>>];5[shape=box style=rounded label=<T5<br/><b>host-compute</b> [0,0,0] + [16,1,1]<br/><i>read_write</i> B0 {[0,0,0] - "
321+
"[16,1,1]}>];6[shape=ellipse label=<T6<br/><b>horizon</b>>];7[shape=ellipse "
322+
"label=<T7<br/><b>epoch</"
323+
"b>>];0->1[];1->2[color=orange];1->3[];2->4[color=orange];3->4[color=orange];3->5[];4->6[color=orange];5->6[color=orange];6->7[color=orange];}";
325324

326325
const auto dot = runtime_testspy::print_task_graph(celerity::detail::runtime::get_instance());
327326
CHECK(dot == expected);

0 commit comments

Comments
 (0)