@@ -387,33 +387,40 @@ void command_graph_generator::generate_pushes(batch& current_batch, const task&
387387 }
388388}
389389
390- // TODO: We currently generate an await push command for each local chunk, whereas we only generate a single push command for all remote chunks
391390void command_graph_generator::generate_await_pushes (batch& current_batch, const task& tsk, const assigned_chunks_with_requirements& chunks_with_requirements) {
391+ std::unordered_map<buffer_id, region_builder<3 >> per_buffer_required_boxes;
392+
392393 for (auto & [_, requirements] : chunks_with_requirements.local_chunks ) {
393394 for (auto & [bid, consumed, _] : requirements) {
394395 if (consumed.empty ()) continue ;
395396 auto & buffer = m_buffers.at (bid);
396397
397398 const auto local_sources = buffer.local_last_writer .get_region_values (consumed);
398- region_builder <3 > missing_part_boxes ;
399+ box_vector <3 > missing_parts_boxes ;
399400 for (const auto & [box, wcs] : local_sources) {
400401 // Note that we initialize all buffers as fresh, so this doesn't trigger for uninitialized reads
401- if (!box.empty () && !wcs.is_fresh ()) { missing_part_boxes. add (box); }
402+ if (!box.empty () && !wcs.is_fresh ()) { missing_parts_boxes. push_back (box); }
402403 }
403404
404- // There is data we don't yet have locally. Generate an await push command for it.
405- if (!missing_part_boxes.empty ()) {
406- const auto missing_parts = std::move (missing_part_boxes).into_region ();
405+ if (!missing_parts_boxes.empty ()) {
407406 assert (m_num_nodes > 1 );
408- auto * const ap_cmd = create_command<await_push_command>(current_batch, transfer_id (tsk.get_id (), bid, no_reduction_id), missing_parts,
409- [&](const auto & record_debug_info) { record_debug_info (buffer.debug_name ); });
410- generate_anti_dependencies (tsk, bid, buffer.local_last_writer , missing_parts, ap_cmd);
411- generate_epoch_dependencies (ap_cmd);
412- // Remember that we have this data now
413- buffer.local_last_writer .update_region (missing_parts, {ap_cmd, true /* is_replicated */ });
407+ auto & required_boxes = per_buffer_required_boxes[bid]; // allow default-insert
408+ required_boxes.add (missing_parts_boxes);
414409 }
415410 }
416411 }
412+
413+ // There is data we don't yet have locally. Generate an await push command for it.
414+ for (auto & [bid, boxes] : per_buffer_required_boxes) {
415+ auto & buffer = m_buffers.at (bid);
416+ auto region = std::move (boxes).into_region (); // moved-from after next line!
417+ auto * const ap_cmd = create_command<await_push_command>(current_batch, transfer_id (tsk.get_id (), bid, no_reduction_id), std::move (region),
418+ [&](const auto & record_debug_info) { record_debug_info (buffer.debug_name ); });
419+ generate_anti_dependencies (tsk, bid, buffer.local_last_writer , ap_cmd->get_region (), ap_cmd);
420+ generate_epoch_dependencies (ap_cmd);
421+ // Remember that we have this data now
422+ buffer.local_last_writer .update_region (ap_cmd->get_region (), {ap_cmd, true /* is_replicated */ });
423+ }
417424}
418425
419426void command_graph_generator::update_local_buffer_fresh_regions (const task& tsk, const std::unordered_map<buffer_id, region<3 >>& per_buffer_local_writes) {
0 commit comments