Skip to content

Commit 6e962e3

Browse files
committed
Use the MemoryReservation as limit memory gate
1 parent 49776be commit 6e962e3

1 file changed

Lines changed: 35 additions & 3 deletions

File tree

src/worker/worker_connection_pool.rs

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ use std::sync::atomic::{AtomicUsize, Ordering};
3030
use std::sync::{Arc, OnceLock};
3131
use std::task::{Context, Poll};
3232
use std::time::{Duration, SystemTime, UNIX_EPOCH};
33+
use tokio::sync::Notify;
3334
use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
3435
use tokio_stream::StreamExt;
3536
use tokio_stream::wrappers::UnboundedReceiverStream;
@@ -98,6 +99,11 @@ impl WorkerConnectionPool {
9899

99100
type WorkerMsg = Result<(FlightData, FlightAppMetadata), Status>;
100101

102+
/// Soft byte budget the demux task will buffer in memory before pausing the gRPC
103+
/// pull. Per-partition channels are unbounded (to avoid head-of-line blocking
104+
/// between sibling partitions), so backpressure is enforced globally here instead.
105+
const PER_CONNECTION_BUFFER_BUDGET_BYTES: usize = 64 * 1024 * 1024;
106+
101107
/// Represents a connection to one [Worker]. Network boundaries will use this for streaming
102108
/// data from single partitions while the actual network communication is handling all the partitions
103109
/// under the hood.
@@ -115,6 +121,9 @@ pub(crate) struct WorkerConnection {
115121
cancel_token: CancellationToken,
116122
per_partition_rx: DashMap<usize, UnboundedReceiver<WorkerMsg>>,
117123

124+
// Signals the demux task that buffered memory has been freed by a consumer.
125+
mem_available_notify: Arc<Notify>,
126+
118127
// Metrics collection stuff.
119128
memory_reservation: Arc<MemoryReservation>,
120129
elapsed_compute: Time,
@@ -182,9 +191,11 @@ impl WorkerConnection {
182191
};
183192

184193
// The senders and receivers are unbounded queues used for multiplexing the record
185-
// batches sent through the single gRPC stream into one stream per partition.
186-
// The received record batches contain information of the partition to which they belong,
187-
// so we use that for determining where to put them.
194+
// batches sent through the single gRPC stream into one stream per partition. They
195+
// are unbounded to avoid head-of-line blocking: a single bounded queue could block
196+
// the demux task and starve all sibling partitions even though they have capacity,
197+
// which deadlocks queries with cross-partition dependencies.
198+
// Total memory is bounded globally below by a byte-budget gate; `mem_available_notify` wakes it.
188199
let mut per_partition_tx = Vec::with_capacity(target_partition_range.len());
189200
let per_partition_rx = DashMap::with_capacity(target_partition_range.len());
190201
for partition in target_partition_range.clone() {
@@ -193,6 +204,9 @@ impl WorkerConnection {
193204
per_partition_rx.insert(partition, rx);
194205
}
195206

207+
let mem_available_notify = Arc::new(Notify::new());
208+
let mem_available_notify_for_task = Arc::clone(&mem_available_notify);
209+
196210
// Cancellation token allows us to stop the background task promptly when all partition
197211
// streams are dropped (e.g., when the query is cancelled).
198212
let cancel_token = CancellationToken::new();
@@ -215,6 +229,20 @@ impl WorkerConnection {
215229
};
216230

217231
loop {
232+
// Backpressure gate. Per-partition channels are unbounded, so we cap
233+
// total in-flight buffered bytes here by pausing the gRPC pull when
234+
// consumers haven't drained enough. This propagates flow control all
235+
// the way back to the worker without coupling sibling partitions.
236+
// The gate only closes once the reservation reaches the budget, so a message
237+
// always gets through when reservation == 0 (no livelock on an oversized message).
238+
while memory_reservation.size() >= PER_CONNECTION_BUFFER_BUDGET_BYTES {
239+
tokio::select! {
240+
biased;
241+
_ = cancel.cancelled() => return,
242+
_ = mem_available_notify_for_task.notified() => {}
243+
}
244+
}
245+
218246
// Check for cancellation while waiting for the next message.
219247
let flight_data = tokio::select! {
220248
biased;
@@ -291,6 +319,7 @@ impl WorkerConnection {
291319
cancel_token,
292320
not_consumed_streams: Arc::new(AtomicUsize::new(per_partition_rx.len())),
293321
per_partition_rx,
322+
mem_available_notify,
294323

295324
// metrics stuff
296325
memory_reservation: memory_reservation_clone,
@@ -324,8 +353,11 @@ impl WorkerConnection {
324353
let stream = UnboundedReceiverStream::new(partition_receiver);
325354
let stream = stream.map_err(|err| FlightError::Tonic(Box::new(err)));
326355
let reservation = Arc::clone(&self.memory_reservation);
356+
let mem_available_notify = Arc::clone(&self.mem_available_notify);
327357
let stream = stream.map_ok(move |(data, meta)| {
328358
reservation.shrink(data.encoded_len());
359+
// Wake the demux task in case it is blocked on the byte budget.
360+
mem_available_notify.notify_one();
329361
let _ = &task; // <- keep the task that polls data from the network alive.
330362
on_metadata(meta);
331363
data

0 commit comments

Comments
 (0)