Skip to content

Commit e250fbe

Browse files
author
Rafał Hibner
committed
Merge branch 'PushGeneratorWithBackpressure' into combined2
2 parents 5271903 + 518c6f6 commit e250fbe

8 files changed

Lines changed: 129 additions & 77 deletions

cpp/src/arrow/acero/asof_join_node.cc

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -514,9 +514,9 @@ class InputState : public util::SerialSequencingQueue::Processor {
514514
std::unique_ptr<BackpressureControl> backpressure_control =
515515
std::make_unique<BackpressureController>(
516516
/*node=*/asof_input, /*output=*/asof_node, backpressure_counter);
517-
ARROW_ASSIGN_OR_RAISE(
518-
auto handler, BackpressureHandler::Make(asof_input, low_threshold, high_threshold,
519-
std::move(backpressure_control)));
517+
ARROW_ASSIGN_OR_RAISE(auto handler,
518+
BackpressureHandler::Make(low_threshold, high_threshold,
519+
std::move(backpressure_control)));
520520
return std::make_unique<InputState>(index, tolerance, must_hash, may_rehash,
521521
key_hasher, asof_node, std::move(handler), schema,
522522
time_col_index, key_col_index);
@@ -763,10 +763,10 @@ class InputState : public util::SerialSequencingQueue::Processor {
763763
total_batches_ = n;
764764
}
765765

766-
Status ForceShutdown() {
766+
void ForceShutdown() {
767767
// Force the upstream input node to unpause. Necessary to avoid deadlock when we
768768
// terminate the process thread
769-
return queue_.ForceShutdown();
769+
queue_.ForceShutdown();
770770
}
771771

772772
private:
@@ -1048,7 +1048,7 @@ class AsofJoinNode : public ExecNode {
10481048
st = output_->InputFinished(this, batches_produced_);
10491049
}
10501050
for (const auto& s : state_) {
1051-
st &= s->ForceShutdown();
1051+
s->ForceShutdown();
10521052
}
10531053
}));
10541054
}

cpp/src/arrow/acero/backpressure_handler.h

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -25,16 +25,15 @@ namespace arrow::acero {
2525

2626
class BackpressureHandler {
2727
private:
28-
BackpressureHandler(ExecNode* input, size_t low_threshold, size_t high_threshold,
28+
BackpressureHandler(size_t low_threshold, size_t high_threshold,
2929
std::unique_ptr<BackpressureControl> backpressure_control)
30-
: input_(input),
31-
low_threshold_(low_threshold),
30+
: low_threshold_(low_threshold),
3231
high_threshold_(high_threshold),
3332
backpressure_control_(std::move(backpressure_control)) {}
3433

3534
public:
3635
static Result<BackpressureHandler> Make(
37-
ExecNode* input, size_t low_threshold, size_t high_threshold,
36+
size_t low_threshold, size_t high_threshold,
3837
std::unique_ptr<BackpressureControl> backpressure_control) {
3938
if (low_threshold >= high_threshold) {
4039
return Status::Invalid("low threshold (", low_threshold,
@@ -43,7 +42,7 @@ class BackpressureHandler {
4342
if (backpressure_control == NULLPTR) {
4443
return Status::Invalid("null backpressure control parameter");
4544
}
46-
BackpressureHandler backpressure_handler(input, low_threshold, high_threshold,
45+
BackpressureHandler backpressure_handler(low_threshold, high_threshold,
4746
std::move(backpressure_control));
4847
return backpressure_handler;
4948
}
@@ -56,16 +55,7 @@ class BackpressureHandler {
5655
}
5756
}
5857

59-
Status ForceShutdown() {
60-
// It may be unintuitive to call Resume() here, but this is to avoid a deadlock.
61-
// Since acero's executor won't terminate if any one node is paused, we need to
62-
// force resume the node before stopping production.
63-
backpressure_control_->Resume();
64-
return input_->StopProducing();
65-
}
66-
6758
private:
68-
ExecNode* input_;
6959
size_t low_threshold_;
7060
size_t high_threshold_;
7161
std::unique_ptr<BackpressureControl> backpressure_control_;

cpp/src/arrow/acero/concurrent_queue.h

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ class ConcurrentQueue {
113113
};
114114

115115
template <typename T>
116-
class BackpressureConcurrentQueue : public ConcurrentQueue<T> {
116+
class BackpressureConcurrentQueue : private ConcurrentQueue<T> {
117117
private:
118118
struct DoHandle {
119119
explicit DoHandle(BackpressureConcurrentQueue& queue)
@@ -134,6 +134,9 @@ class BackpressureConcurrentQueue : public ConcurrentQueue<T> {
134134
explicit BackpressureConcurrentQueue(BackpressureHandler handler)
135135
: handler_(std::move(handler)) {}
136136

137+
using ConcurrentQueue<T>::Empty;
138+
using ConcurrentQueue<T>::Front;
139+
137140
// Pops the last item from the queue but waits if the queue is empty until new items are
138141
// pushed.
139142
T WaitAndPop() {
@@ -152,6 +155,7 @@ class BackpressureConcurrentQueue : public ConcurrentQueue<T> {
152155

153156
// Pushes an item to the queue
154157
void Push(const T& item) {
158+
if (shutdown_) return;
155159
std::unique_lock<std::mutex> lock(ConcurrentQueue<T>::GetMutex());
156160
DoHandle do_handle(*this);
157161
ConcurrentQueue<T>::PushUnlocked(item);
@@ -164,10 +168,14 @@ class BackpressureConcurrentQueue : public ConcurrentQueue<T> {
164168
ConcurrentQueue<T>::ClearUnlocked();
165169
}
166170

167-
Status ForceShutdown() { return handler_.ForceShutdown(); }
171+
// Marks the queue as shut down and discards everything currently queued by
// calling Clear(). Once the flag is set, Push() returns early without
// enqueuing (it tests shutdown_ before doing anything else).
//
// NOTE(review): shutdown_ is a plain bool that is written here and read in
// Push() before the queue mutex is taken. If ForceShutdown() and Push() can be
// called from different threads this looks like a data race; consider
// std::atomic<bool> or reading/writing the flag under the mutex. TODO confirm
// the threading contract with the callers.
void ForceShutdown() {
172+
shutdown_ = true;
173+
Clear();
174+
}
168175

169176
private:
170177
BackpressureHandler handler_;
178+
bool shutdown_{false};
171179
};
172180

173181
} // namespace arrow::acero

cpp/src/arrow/acero/groupby_aggregate_node.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ Status GroupByNode::Init() {
7070
std::unique_ptr<arrow::acero::BackpressureControl> backpressure_control =
7171
std::make_unique<BackpressureController>(inputs_[0], this);
7272
ARROW_ASSIGN_OR_RAISE(auto handler,
73-
BackpressureHandler::Make(this, low_threshold, high_threshold,
73+
BackpressureHandler::Make(low_threshold, high_threshold,
7474
std::move(backpressure_control)));
7575

7676
processor_ = acero::util::SerialSequencingQueue::Processor::MakeBackpressureWrapper(

cpp/src/arrow/acero/scalar_aggregate_node.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@ Status ScalarAggregateNode::Init() {
235235
std::unique_ptr<arrow::acero::BackpressureControl> backpressure_control =
236236
std::make_unique<BackpressureController>(inputs_[0], this);
237237
ARROW_ASSIGN_OR_RAISE(auto handler,
238-
BackpressureHandler::Make(this, low_threshold, high_threshold,
238+
BackpressureHandler::Make(low_threshold, high_threshold,
239239
std::move(backpressure_control)));
240240

241241
processor_ = acero::util::SerialSequencingQueue::Processor::MakeBackpressureWrapper(

cpp/src/arrow/acero/sorted_merge_node.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ class InputState : public util::SerialSequencingQueue::Processor {
125125
std::unique_ptr<arrow::acero::BackpressureControl> backpressure_control =
126126
std::make_unique<BackpressureController>(input, output, backpressure_counter);
127127
ARROW_ASSIGN_OR_RAISE(auto handler,
128-
BackpressureHandler::Make(input, low_threshold, high_threshold,
128+
BackpressureHandler::Make(low_threshold, high_threshold,
129129
std::move(backpressure_control)));
130130
return PtrType(new InputState(index, std::move(handler), schema, time_col_index));
131131
}

cpp/src/arrow/acero/util_test.cc

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -263,8 +263,7 @@ class TestBackpressureControl : public BackpressureControl {
263263
TEST(BackpressureConcurrentQueue, BasicTest) {
264264
BackpressureTestExecNode dummy_node;
265265
auto ctrl = std::make_unique<TestBackpressureControl>(&dummy_node);
266-
ASSERT_OK_AND_ASSIGN(auto handler,
267-
BackpressureHandler::Make(&dummy_node, 2, 4, std::move(ctrl)));
266+
ASSERT_OK_AND_ASSIGN(auto handler, BackpressureHandler::Make(2, 4, std::move(ctrl)));
268267
BackpressureConcurrentQueue<int> queue(std::move(handler));
269268

270269
ConcurrentQueueBasicTest(queue);
@@ -275,8 +274,7 @@ TEST(BackpressureConcurrentQueue, BasicTest) {
275274
TEST(BackpressureConcurrentQueue, BackpressureTest) {
276275
BackpressureTestExecNode dummy_node;
277276
auto ctrl = std::make_unique<TestBackpressureControl>(&dummy_node);
278-
ASSERT_OK_AND_ASSIGN(auto handler,
279-
BackpressureHandler::Make(&dummy_node, 2, 4, std::move(ctrl)));
277+
ASSERT_OK_AND_ASSIGN(auto handler, BackpressureHandler::Make(2, 4, std::move(ctrl)));
280278
BackpressureConcurrentQueue<int> queue(std::move(handler));
281279

282280
queue.Push(6);
@@ -299,9 +297,6 @@ TEST(BackpressureConcurrentQueue, BackpressureTest) {
299297
queue.Push(11);
300298
ASSERT_TRUE(dummy_node.paused);
301299
ASSERT_FALSE(dummy_node.stopped);
302-
ASSERT_OK(queue.ForceShutdown());
303-
ASSERT_FALSE(dummy_node.paused);
304-
ASSERT_TRUE(dummy_node.stopped);
305300
}
306301

307302
} // namespace acero

cpp/src/arrow/util/async_generator.h

Lines changed: 104 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include <optional>
2626
#include <queue>
2727

28+
#include "arrow/acero/backpressure_handler.h"
2829
#include "arrow/util/async_generator_fwd.h"
2930
#include "arrow/util/async_util.h"
3031
#include "arrow/util/functional.h"
@@ -854,12 +855,108 @@ class PushGenerator {
854855
struct State {
855856
State() {}
856857

857-
util::Mutex mutex;
858+
virtual bool Push(Result<T> result) {
859+
return PushUnlocked(std::move(result), mutex.Lock());
860+
}
861+
862+
bool PushUnlocked(Result<T> result, util::Mutex::Guard lock) {
863+
if (finished) {
864+
// Closed early
865+
return false;
866+
}
867+
if (consumer_fut.has_value()) {
868+
auto fut = std::move(consumer_fut.value());
869+
consumer_fut.reset();
870+
lock.Unlock(); // unlock before potentially invoking a callback
871+
fut.MarkFinished(std::move(result));
872+
} else {
873+
result_q.push_back(std::move(result));
874+
}
875+
return true;
876+
}
877+
878+
bool Close() {
879+
auto lock = mutex.Lock();
880+
if (finished) {
881+
// Already closed
882+
return false;
883+
}
884+
finished = true;
885+
if (consumer_fut.has_value()) {
886+
auto fut = std::move(consumer_fut.value());
887+
consumer_fut.reset();
888+
lock.Unlock(); // unlock before potentially invoking a callback
889+
fut.MarkFinished(IterationTraits<T>::End());
890+
}
891+
return true;
892+
}
893+
894+
bool is_closed() const {
895+
auto lock = mutex.Lock();
896+
return finished;
897+
}
898+
899+
/// Read an item from the queue
900+
virtual Future<T> Pop() {
901+
auto lock = mutex.Lock();
902+
return PopUnlocked();
903+
}
904+
905+
Future<T> PopUnlocked() {
906+
assert(!consumer_fut.has_value()); // Non-reentrant
907+
if (!result_q.empty()) {
908+
auto fut = Future<T>::MakeFinished(std::move(result_q.front()));
909+
result_q.pop_front();
910+
return fut;
911+
}
912+
if (finished) {
913+
return AsyncGeneratorEnd<T>();
914+
}
915+
auto fut = Future<T>::Make();
916+
consumer_fut = fut;
917+
return fut;
918+
}
919+
920+
mutable util::Mutex mutex;
858921
std::deque<Result<T>> result_q;
859922
std::optional<Future<T>> consumer_fut;
860923
bool finished = false;
861924
};
862925

926+
struct StateWithBackpressure : public State {
927+
explicit StateWithBackpressure(acero::BackpressureHandler handler)
928+
: handler_(handler) {}
929+
930+
struct DoHandle {
931+
explicit DoHandle(StateWithBackpressure& state)
932+
: state_(state), start_size_(state_.result_q.size()) {}
933+
934+
~DoHandle() {
935+
// unsynced access is safe since DoHandle is internally only used when the
936+
// lock is held
937+
size_t end_size = state_.result_q.size();
938+
state_.handler_.Handle(start_size_, end_size);
939+
}
940+
941+
StateWithBackpressure& state_;
942+
size_t start_size_;
943+
};
944+
945+
bool Push(Result<T> result) override {
946+
auto lock = State::mutex.Lock();
947+
DoHandle(*this);
948+
return PushUnlocked(std::move(result), std::move(lock));
949+
}
950+
951+
Future<T> Pop() override {
952+
auto lock = State::mutex.Lock();
953+
DoHandle(*this);
954+
return State::PopUnlocked();
955+
}
956+
957+
acero::BackpressureHandler handler_;
958+
};
959+
863960
public:
864961
/// Producer API for PushGenerator
865962
class Producer {
@@ -877,20 +974,7 @@ class PushGenerator {
877974
// Generator was destroyed
878975
return false;
879976
}
880-
auto lock = state->mutex.Lock();
881-
if (state->finished) {
882-
// Closed early
883-
return false;
884-
}
885-
if (state->consumer_fut.has_value()) {
886-
auto fut = std::move(state->consumer_fut.value());
887-
state->consumer_fut.reset();
888-
lock.Unlock(); // unlock before potentially invoking a callback
889-
fut.MarkFinished(std::move(result));
890-
} else {
891-
state->result_q.push_back(std::move(result));
892-
}
893-
return true;
977+
return state->Push(std::move(result));
894978
}
895979

896980
/// \brief Tell the consumer we have finished producing
@@ -907,19 +991,7 @@ class PushGenerator {
907991
// Generator was destroyed
908992
return false;
909993
}
910-
auto lock = state->mutex.Lock();
911-
if (state->finished) {
912-
// Already closed
913-
return false;
914-
}
915-
state->finished = true;
916-
if (state->consumer_fut.has_value()) {
917-
auto fut = std::move(state->consumer_fut.value());
918-
state->consumer_fut.reset();
919-
lock.Unlock(); // unlock before potentially invoking a callback
920-
fut.MarkFinished(IterationTraits<T>::End());
921-
}
922-
return true;
994+
return state->Close();
923995
}
924996

925997
/// Return whether the generator was closed or destroyed.
@@ -929,32 +1001,19 @@ class PushGenerator {
9291001
// Generator was destroyed
9301002
return true;
9311003
}
932-
auto lock = state->mutex.Lock();
933-
return state->finished;
1004+
return state->is_closed();
9341005
}
9351006

9361007
private:
9371008
const std::weak_ptr<State> weak_state_;
9381009
};
9391010

9401011
PushGenerator() : state_(std::make_shared<State>()) {}
1012+
explicit PushGenerator(acero::BackpressureHandler handler)
1013+
: state_(std::make_shared<StateWithBackpressure>(std::move(handler))) {}
9411014

9421015
/// Read an item from the queue
943-
Future<T> operator()() const {
944-
auto lock = state_->mutex.Lock();
945-
assert(!state_->consumer_fut.has_value()); // Non-reentrant
946-
if (!state_->result_q.empty()) {
947-
auto fut = Future<T>::MakeFinished(std::move(state_->result_q.front()));
948-
state_->result_q.pop_front();
949-
return fut;
950-
}
951-
if (state_->finished) {
952-
return AsyncGeneratorEnd<T>();
953-
}
954-
auto fut = Future<T>::Make();
955-
state_->consumer_fut = fut;
956-
return fut;
957-
}
1016+
Future<T> operator()() const { return state_->Pop(); }
9581017

9591018
/// \brief Return producer-side interface
9601019
///

0 commit comments

Comments
 (0)