Skip to content

Commit 518c6f6

Browse files
author
Rafał Hibner
committed
Add optional backpressure to PushGenerator
1 parent a1c896f commit 518c6f6

1 file changed

Lines changed: 104 additions & 45 deletions

File tree

cpp/src/arrow/util/async_generator.h

Lines changed: 104 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include <optional>
2626
#include <queue>
2727

28+
#include "arrow/acero/backpressure_handler.h"
2829
#include "arrow/util/async_generator_fwd.h"
2930
#include "arrow/util/async_util.h"
3031
#include "arrow/util/functional.h"
@@ -854,12 +855,108 @@ class PushGenerator {
854855
struct State {
855856
State() {}
856857

857-
util::Mutex mutex;
858+
virtual bool Push(Result<T> result) {
859+
return PushUnlocked(std::move(result), mutex.Lock());
860+
}
861+
862+
bool PushUnlocked(Result<T> result, util::Mutex::Guard lock) {
863+
if (finished) {
864+
// Closed early
865+
return false;
866+
}
867+
if (consumer_fut.has_value()) {
868+
auto fut = std::move(consumer_fut.value());
869+
consumer_fut.reset();
870+
lock.Unlock(); // unlock before potentially invoking a callback
871+
fut.MarkFinished(std::move(result));
872+
} else {
873+
result_q.push_back(std::move(result));
874+
}
875+
return true;
876+
}
877+
878+
bool Close() {
879+
auto lock = mutex.Lock();
880+
if (finished) {
881+
// Already closed
882+
return false;
883+
}
884+
finished = true;
885+
if (consumer_fut.has_value()) {
886+
auto fut = std::move(consumer_fut.value());
887+
consumer_fut.reset();
888+
lock.Unlock(); // unlock before potentially invoking a callback
889+
fut.MarkFinished(IterationTraits<T>::End());
890+
}
891+
return true;
892+
}
893+
894+
bool is_closed() const {
895+
auto lock = mutex.Lock();
896+
return finished;
897+
}
898+
899+
/// Read an item from the queue
900+
virtual Future<T> Pop() {
901+
auto lock = mutex.Lock();
902+
return PopUnlocked();
903+
}
904+
905+
Future<T> PopUnlocked() {
906+
assert(!consumer_fut.has_value()); // Non-reentrant
907+
if (!result_q.empty()) {
908+
auto fut = Future<T>::MakeFinished(std::move(result_q.front()));
909+
result_q.pop_front();
910+
return fut;
911+
}
912+
if (finished) {
913+
return AsyncGeneratorEnd<T>();
914+
}
915+
auto fut = Future<T>::Make();
916+
consumer_fut = fut;
917+
return fut;
918+
}
919+
920+
mutable util::Mutex mutex;
858921
std::deque<Result<T>> result_q;
859922
std::optional<Future<T>> consumer_fut;
860923
bool finished = false;
861924
};
862925

926+
struct StateWithBackpressure : public State {
927+
explicit StateWithBackpressure(acero::BackpressureHandler handler)
928+
: handler_(handler) {}
929+
930+
struct DoHandle {
931+
explicit DoHandle(StateWithBackpressure& state)
932+
: state_(state), start_size_(state_.result_q.size()) {}
933+
934+
~DoHandle() {
935+
// unsynced access is safe since DoHandle is internally only used when the
936+
// lock is held
937+
size_t end_size = state_.result_q.size();
938+
state_.handler_.Handle(start_size_, end_size);
939+
}
940+
941+
StateWithBackpressure& state_;
942+
size_t start_size_;
943+
};
944+
945+
bool Push(Result<T> result) override {
946+
auto lock = State::mutex.Lock();
947+
DoHandle(*this);
948+
return PushUnlocked(std::move(result), std::move(lock));
949+
}
950+
951+
Future<T> Pop() override {
952+
auto lock = State::mutex.Lock();
953+
DoHandle(*this);
954+
return State::PopUnlocked();
955+
}
956+
957+
acero::BackpressureHandler handler_;
958+
};
959+
863960
public:
864961
/// Producer API for PushGenerator
865962
class Producer {
@@ -877,20 +974,7 @@ class PushGenerator {
877974
// Generator was destroyed
878975
return false;
879976
}
880-
auto lock = state->mutex.Lock();
881-
if (state->finished) {
882-
// Closed early
883-
return false;
884-
}
885-
if (state->consumer_fut.has_value()) {
886-
auto fut = std::move(state->consumer_fut.value());
887-
state->consumer_fut.reset();
888-
lock.Unlock(); // unlock before potentially invoking a callback
889-
fut.MarkFinished(std::move(result));
890-
} else {
891-
state->result_q.push_back(std::move(result));
892-
}
893-
return true;
977+
return state->Push(std::move(result));
894978
}
895979

896980
/// \brief Tell the consumer we have finished producing
@@ -907,19 +991,7 @@ class PushGenerator {
907991
// Generator was destroyed
908992
return false;
909993
}
910-
auto lock = state->mutex.Lock();
911-
if (state->finished) {
912-
// Already closed
913-
return false;
914-
}
915-
state->finished = true;
916-
if (state->consumer_fut.has_value()) {
917-
auto fut = std::move(state->consumer_fut.value());
918-
state->consumer_fut.reset();
919-
lock.Unlock(); // unlock before potentially invoking a callback
920-
fut.MarkFinished(IterationTraits<T>::End());
921-
}
922-
return true;
994+
return state->Close();
923995
}
924996

925997
/// Return whether the generator was closed or destroyed.
@@ -929,32 +1001,19 @@ class PushGenerator {
9291001
// Generator was destroyed
9301002
return true;
9311003
}
932-
auto lock = state->mutex.Lock();
933-
return state->finished;
1004+
return state->is_closed();
9341005
}
9351006

9361007
private:
9371008
const std::weak_ptr<State> weak_state_;
9381009
};
9391010

9401011
PushGenerator() : state_(std::make_shared<State>()) {}
1012+
explicit PushGenerator(acero::BackpressureHandler handler)
1013+
: state_(std::make_shared<StateWithBackpressure>(std::move(handler))) {}
9411014

9421015
/// Read an item from the queue
943-
Future<T> operator()() const {
944-
auto lock = state_->mutex.Lock();
945-
assert(!state_->consumer_fut.has_value()); // Non-reentrant
946-
if (!state_->result_q.empty()) {
947-
auto fut = Future<T>::MakeFinished(std::move(state_->result_q.front()));
948-
state_->result_q.pop_front();
949-
return fut;
950-
}
951-
if (state_->finished) {
952-
return AsyncGeneratorEnd<T>();
953-
}
954-
auto fut = Future<T>::Make();
955-
state_->consumer_fut = fut;
956-
return fut;
957-
}
1016+
Future<T> operator()() const { return state_->Pop(); }
9581017

9591018
/// \brief Return producer-side interface
9601019
///

0 commit comments

Comments
 (0)