Skip to content

Commit f6764d1

Browse files
psalz authored and PeterTh committed
Generate single await-push command per buffer for all local chunks
This brings await-pushes in line with pushes, where we already compute the union of all regions required by remote chunks executed on the same node.
1 parent db71b5c commit f6764d1

File tree

3 files changed

+52
-12
lines changed

3 files changed

+52
-12
lines changed

include/grid.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,15 @@ class region_builder {
273273
m_boxes.push_back(box);
274274
}
275275

276+
// Adds a set of boxes to the region builder, skipping empty boxes,
277+
// by calling `add` for each element instead of a single `insert(end)`.
278+
// Overload taking a whole vector of boxes. The trailing `&` ref-qualifier restricts
// this member function to lvalue builders.
void add(const box_vector<Dims>& boxes) & {
279+
// Reserve capacity for the worst case in which every incoming box is kept.
m_boxes.reserve(m_boxes.size() + boxes.size());
280+
for(const auto& b : boxes) {
281+
// Forward each element to another `add` overload (presumably the single-box one - confirm).
add(b);
282+
}
283+
}
284+
276285
void add(const region<Dims>& region) & {
277286
if(region.empty()) return;
278287
m_normalized = m_boxes.empty();

src/command_graph_generator.cc

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -387,33 +387,40 @@ void command_graph_generator::generate_pushes(batch& current_batch, const task&
387387
}
388388
}
389389

390-
// TODO: We currently generate an await push command for each local chunk, whereas we only generate a single push command for all remote chunks
391390
// Generates await-push commands for data that local chunks consume but that is not
// marked fresh in the buffer's local_last_writer map.
// NOTE: this is a diff view. Lines following a `NNN-` marker belong to the removed
// implementation (one await push per local chunk); lines following a `NNN+` marker belong
// to the new implementation (one await push per buffer for all local chunks).
void command_graph_generator::generate_await_pushes(batch& current_batch, const task& tsk, const assigned_chunks_with_requirements& chunks_with_requirements) {
391+
// Boxes to await, keyed by buffer and accumulated across all local chunks.
std::unordered_map<buffer_id, region_builder<3>> per_buffer_required_boxes;
392+
392393
for(auto& [_, requirements] : chunks_with_requirements.local_chunks) {
393394
for(auto& [bid, consumed, _] : requirements) {
394395
if(consumed.empty()) continue;
395396
auto& buffer = m_buffers.at(bid);
396397

397398
const auto local_sources = buffer.local_last_writer.get_region_values(consumed);
398-
region_builder<3> missing_part_boxes;
399+
box_vector<3> missing_parts_boxes;
399400
for(const auto& [box, wcs] : local_sources) {
400401
// Note that we initialize all buffers as fresh, so this doesn't trigger for uninitialized reads
401-
if(!box.empty() && !wcs.is_fresh()) { missing_part_boxes.add(box); }
402+
if(!box.empty() && !wcs.is_fresh()) { missing_parts_boxes.push_back(box); }
402403
}
403404

404-
// There is data we don't yet have locally. Generate an await push command for it.
405-
if(!missing_part_boxes.empty()) {
406-
const auto missing_parts = std::move(missing_part_boxes).into_region();
405+
if(!missing_parts_boxes.empty()) {
407406
assert(m_num_nodes > 1);
408-
auto* const ap_cmd = create_command<await_push_command>(current_batch, transfer_id(tsk.get_id(), bid, no_reduction_id), missing_parts,
409-
[&](const auto& record_debug_info) { record_debug_info(buffer.debug_name); });
410-
generate_anti_dependencies(tsk, bid, buffer.local_last_writer, missing_parts, ap_cmd);
411-
generate_epoch_dependencies(ap_cmd);
412-
// Remember that we have this data now
413-
buffer.local_last_writer.update_region(missing_parts, {ap_cmd, true /* is_replicated */});
407+
// New version: only collect this chunk's missing boxes here; the command itself is
// created once per buffer in the loop after all chunks have been visited.
auto& required_boxes = per_buffer_required_boxes[bid]; // allow default-insert
408+
required_boxes.add(missing_parts_boxes);
414409
}
415410
}
416411
}
412+
413+
// There is data we don't yet have locally. Generate an await push command for it.
414+
// One iteration, and therefore one await push command, per buffer that has missing boxes.
for(auto& [bid, boxes] : per_buffer_required_boxes) {
415+
auto& buffer = m_buffers.at(bid);
416+
auto region = std::move(boxes).into_region(); // moved-from after next line!
417+
auto* const ap_cmd = create_command<await_push_command>(current_batch, transfer_id(tsk.get_id(), bid, no_reduction_id), std::move(region),
418+
[&](const auto& record_debug_info) { record_debug_info(buffer.debug_name); });
419+
// `region` was moved into the command above, so the region is read back via ap_cmd->get_region().
generate_anti_dependencies(tsk, bid, buffer.local_last_writer, ap_cmd->get_region(), ap_cmd);
420+
generate_epoch_dependencies(ap_cmd);
421+
// Remember that we have this data now
422+
buffer.local_last_writer.update_region(ap_cmd->get_region(), {ap_cmd, true /* is_replicated */});
423+
}
417424
}
418425

419426
void command_graph_generator::update_local_buffer_fresh_regions(const task& tsk, const std::unordered_map<buffer_id, region<3>>& per_buffer_local_writes) {

test/command_graph_transfer_tests.cc

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,30 @@ TEST_CASE("command_graph_generator generates a single push command per buffer an
115115
CHECK(cctx.query<push_command_record>(buf1.get_id()).on(1)[1]->target_regions == push_regions<1>({{0, region<1>{{box<1>{96, 128}}}}}));
116116
}
117117

118+
// Checks that a task split into several chunks per node still produces only one push and
// one await-push command per buffer on each node, and that each await-push covers the
// union of what the local chunks need.
TEST_CASE("command_graph_generator generates a single await_push command per buffer and task", "[command_graph_generator][command-graph]") { //
119+
// presumably the argument is the number of nodes (nodes 0 and 1 are queried below) - confirm
cdag_test_context cctx(2);
120+
121+
const range<1> test_range = {128};
122+
auto buf0 = cctx.create_buffer(test_range);
123+
auto buf1 = cctx.create_buffer(test_range);
124+
125+
// Initialize buffers across both nodes
126+
cctx.device_compute(test_range).name("init").discard_write(buf0, acc::one_to_one{}).discard_write(buf1, acc::one_to_one{}).submit();
127+
128+
// Read in reverse order, but split task into 4 chunks each
129+
cctx.set_test_chunk_multiplier(4);
130+
cctx.device_compute(test_range).read(buf0, test_utils::access::reverse_one_to_one{}).read(buf1, test_utils::access::reverse_one_to_one{}).submit();
131+
132+
// Two buffers are read, so exactly two pushes and two await-pushes are expected per node
// despite the chunk multiplier of 4.
CHECK(cctx.query<push_command_record>().count_per_node() == 2);
133+
CHECK(cctx.query<await_push_command_record>().count_per_node() == 2);
134+
135+
// The union of the required regions is just the full other half
136+
// Node 0 awaits [64, 128) and node 1 awaits [0, 64); both await-push records on a node
// are expected to have the same region.
CHECK(cctx.query<await_push_command_record>().on(0).iterate()[0]->await_region == region_cast<3>(region<1>{box<1>{64, 128}}));
137+
CHECK(cctx.query<await_push_command_record>().on(0).iterate()[1]->await_region == region_cast<3>(region<1>{box<1>{64, 128}}));
138+
CHECK(cctx.query<await_push_command_record>().on(1).iterate()[0]->await_region == region_cast<3>(region<1>{box<1>{0, 64}}));
139+
CHECK(cctx.query<await_push_command_record>().on(1).iterate()[1]->await_region == region_cast<3>(region<1>{box<1>{0, 64}}));
140+
}
141+
118142
TEST_CASE("command_graph_generator doesn't generate data transfer commands for the same buffer and range more than once",
119143
"[command_graph_generator][command-graph]") {
120144
cdag_test_context cctx(2);

0 commit comments

Comments
 (0)