2525#include < optional>
2626#include < queue>
2727
28+ #include " arrow/acero/backpressure_handler.h"
2829#include " arrow/util/async_generator_fwd.h"
2930#include " arrow/util/async_util.h"
3031#include " arrow/util/functional.h"
@@ -854,12 +855,108 @@ class PushGenerator {
854855 struct State {
855856 State () {}
856857
857- util::Mutex mutex;
858+ virtual bool Push (Result<T> result) {
859+ return PushUnlocked (std::move (result), mutex.Lock ());
860+ }
861+
862+ bool PushUnlocked (Result<T> result, util::Mutex::Guard lock) {
863+ if (finished) {
864+ // Closed early
865+ return false ;
866+ }
867+ if (consumer_fut.has_value ()) {
868+ auto fut = std::move (consumer_fut.value ());
869+ consumer_fut.reset ();
870+ lock.Unlock (); // unlock before potentially invoking a callback
871+ fut.MarkFinished (std::move (result));
872+ } else {
873+ result_q.push_back (std::move (result));
874+ }
875+ return true ;
876+ }
877+
878+ bool Close () {
879+ auto lock = mutex.Lock ();
880+ if (finished) {
881+ // Already closed
882+ return false ;
883+ }
884+ finished = true ;
885+ if (consumer_fut.has_value ()) {
886+ auto fut = std::move (consumer_fut.value ());
887+ consumer_fut.reset ();
888+ lock.Unlock (); // unlock before potentially invoking a callback
889+ fut.MarkFinished (IterationTraits<T>::End ());
890+ }
891+ return true ;
892+ }
893+
894+ bool is_closed () const {
895+ auto lock = mutex.Lock ();
896+ return finished;
897+ }
898+
899+ // / Read an item from the queue
900+ virtual Future<T> Pop () {
901+ auto lock = mutex.Lock ();
902+ return PopUnlocked ();
903+ }
904+
905+ Future<T> PopUnlocked () {
906+ assert (!consumer_fut.has_value ()); // Non-reentrant
907+ if (!result_q.empty ()) {
908+ auto fut = Future<T>::MakeFinished (std::move (result_q.front ()));
909+ result_q.pop_front ();
910+ return fut;
911+ }
912+ if (finished) {
913+ return AsyncGeneratorEnd<T>();
914+ }
915+ auto fut = Future<T>::Make ();
916+ consumer_fut = fut;
917+ return fut;
918+ }
919+
920+ mutable util::Mutex mutex;
858921 std::deque<Result<T>> result_q;
859922 std::optional<Future<T>> consumer_fut;
860923 bool finished = false ;
861924 };
862925
926+ struct StateWithBackpressure : public State {
927+ explicit StateWithBackpressure (acero::BackpressureHandler handler)
928+ : handler_(handler) {}
929+
930+ struct DoHandle {
931+ explicit DoHandle (StateWithBackpressure& state)
932+ : state_(state), start_size_(state_.result_q.size()) {}
933+
934+ ~DoHandle () {
935+ // unsynced access is safe since DoHandle is internally only used when the
936+ // lock is held
937+ size_t end_size = state_.result_q .size ();
938+ state_.handler_ .Handle (start_size_, end_size);
939+ }
940+
941+ StateWithBackpressure& state_;
942+ size_t start_size_;
943+ };
944+
945+ bool Push (Result<T> result) override {
946+ auto lock = State::mutex.Lock ();
947+ DoHandle (*this );
948+ return PushUnlocked (std::move (result), std::move (lock));
949+ }
950+
951+ Future<T> Pop () override {
952+ auto lock = State::mutex.Lock ();
953+ DoHandle (*this );
954+ return State::PopUnlocked ();
955+ }
956+
957+ acero::BackpressureHandler handler_;
958+ };
959+
863960 public:
864961 // / Producer API for PushGenerator
865962 class Producer {
@@ -877,20 +974,7 @@ class PushGenerator {
877974 // Generator was destroyed
878975 return false ;
879976 }
880- auto lock = state->mutex .Lock ();
881- if (state->finished ) {
882- // Closed early
883- return false ;
884- }
885- if (state->consumer_fut .has_value ()) {
886- auto fut = std::move (state->consumer_fut .value ());
887- state->consumer_fut .reset ();
888- lock.Unlock (); // unlock before potentially invoking a callback
889- fut.MarkFinished (std::move (result));
890- } else {
891- state->result_q .push_back (std::move (result));
892- }
893- return true ;
977+ return state->Push (std::move (result));
894978 }
895979
896980 // / \brief Tell the consumer we have finished producing
@@ -907,19 +991,7 @@ class PushGenerator {
907991 // Generator was destroyed
908992 return false ;
909993 }
910- auto lock = state->mutex .Lock ();
911- if (state->finished ) {
912- // Already closed
913- return false ;
914- }
915- state->finished = true ;
916- if (state->consumer_fut .has_value ()) {
917- auto fut = std::move (state->consumer_fut .value ());
918- state->consumer_fut .reset ();
919- lock.Unlock (); // unlock before potentially invoking a callback
920- fut.MarkFinished (IterationTraits<T>::End ());
921- }
922- return true ;
994+ return state->Close ();
923995 }
924996
925997 // / Return whether the generator was closed or destroyed.
@@ -929,32 +1001,19 @@ class PushGenerator {
9291001 // Generator was destroyed
9301002 return true ;
9311003 }
932- auto lock = state->mutex .Lock ();
933- return state->finished ;
1004+ return state->is_closed ();
9341005 }
9351006
9361007 private:
9371008 const std::weak_ptr<State> weak_state_;
9381009 };
9391010
9401011 PushGenerator () : state_(std::make_shared<State>()) {}
1012+ explicit PushGenerator (acero::BackpressureHandler handler)
1013+ : state_(std::make_shared<StateWithBackpressure>(std::move(handler))) {}
9411014
9421015 // / Read an item from the queue
943- Future<T> operator ()() const {
944- auto lock = state_->mutex .Lock ();
945- assert (!state_->consumer_fut .has_value ()); // Non-reentrant
946- if (!state_->result_q .empty ()) {
947- auto fut = Future<T>::MakeFinished (std::move (state_->result_q .front ()));
948- state_->result_q .pop_front ();
949- return fut;
950- }
951- if (state_->finished ) {
952- return AsyncGeneratorEnd<T>();
953- }
954- auto fut = Future<T>::Make ();
955- state_->consumer_fut = fut;
956- return fut;
957- }
1016+ Future<T> operator ()() const { return state_->Pop (); }
9581017
9591018 // / \brief Return producer-side interface
9601019 // /
0 commit comments