diff --git a/datafusion/execution/src/memory_pool/mod.rs b/datafusion/execution/src/memory_pool/mod.rs index 0b4eb3786f555..41fdd109de782 100644 --- a/datafusion/execution/src/memory_pool/mod.rs +++ b/datafusion/execution/src/memory_pool/mod.rs @@ -205,6 +205,19 @@ pub trait MemoryPool: Send + Sync + std::fmt::Debug { /// On error the `allocation` will not be increased in size fn try_grow(&self, reservation: &MemoryReservation, additional: usize) -> Result<()>; + /// Attempt to reclaim `target_bytes` from existing spillable consumers already registered + /// with this pool. + /// + /// `exclude_consumer_id`, when provided, identifies the current requester and should not be + /// reclaimed from to avoid re-entering the same operator while it is mid-allocation. + fn reclaim( + &self, + _target_bytes: usize, + _exclude_consumer_id: Option, + ) -> Result { + Ok(0) + } + /// Return the total amount of memory reserved fn reserved(&self) -> usize; @@ -240,11 +253,22 @@ pub enum MemoryLimit { /// For help with allocation accounting, see the [`proxy`] module. /// /// [proxy]: datafusion_common::utils::proxy -#[derive(Debug)] pub struct MemoryConsumer { name: String, can_spill: bool, id: usize, + reclaimer: Option>, +} + +impl std::fmt::Debug for MemoryConsumer { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("MemoryConsumer") + .field("name", &self.name) + .field("can_spill", &self.can_spill) + .field("id", &self.id) + .field("has_reclaimer", &self.reclaimer.is_some()) + .finish() + } } impl PartialEq for MemoryConsumer { @@ -283,6 +307,7 @@ impl MemoryConsumer { name: name.into(), can_spill: false, id: Self::new_unique_id(), + reclaimer: None, } } @@ -294,6 +319,7 @@ impl MemoryConsumer { name: self.name.clone(), can_spill: self.can_spill, id: Self::new_unique_id(), + reclaimer: self.reclaimer.clone(), } } @@ -307,6 +333,15 @@ impl MemoryConsumer { Self { can_spill, ..self } } + /// Configure a callback that can reclaim memory from this consumer when another consumer in + /// the same pool is under pressure. + pub fn with_reclaimer(self, reclaimer: Arc) -> Self { + Self { + reclaimer: Some(reclaimer), + ..self + } + } + /// Returns true if this allocation can spill to disk pub fn can_spill(&self) -> bool { self.can_spill @@ -317,6 +352,11 @@ impl MemoryConsumer { &self.name } + /// Returns the reclaim callback registered for this consumer, if any. + pub fn reclaimer(&self) -> Option> { + self.reclaimer.clone() + } + /// Registers this [`MemoryConsumer`] with the provided [`MemoryPool`] returning /// a [`MemoryReservation`] that can be used to grow or shrink the memory reservation pub fn register(self, pool: &Arc) -> MemoryReservation { @@ -331,6 +371,12 @@ impl MemoryConsumer { } } +/// Callback implemented by spillable operators that can synchronously reclaim existing +/// reservations when another consumer in the same pool is under pressure. +pub trait MemoryReclaimer: Send + Sync { + fn reclaim(&self, target_bytes: usize) -> Result; +} + /// A registration of a [`MemoryConsumer`] with a [`MemoryPool`]. /// /// Calls [`MemoryPool::unregister`] on drop to return any memory to diff --git a/datafusion/execution/src/memory_pool/pool.rs b/datafusion/execution/src/memory_pool/pool.rs index b10270851cc06..10299a4d52b7f 100644 --- a/datafusion/execution/src/memory_pool/pool.rs +++ b/datafusion/execution/src/memory_pool/pool.rs @@ -16,7 +16,8 @@ // under the License. use crate::memory_pool::{ - MemoryConsumer, MemoryLimit, MemoryPool, MemoryReservation, human_readable_size, + MemoryConsumer, MemoryLimit, MemoryPool, MemoryReclaimer, MemoryReservation, + human_readable_size, }; use datafusion_common::HashMap; use datafusion_common::{DataFusionError, Result, resources_datafusion_err}; @@ -24,6 +25,7 @@ use log::debug; use parking_lot::Mutex; use std::{ num::NonZeroUsize, + sync::Arc, sync::atomic::{AtomicUsize, Ordering}, }; @@ -269,12 +271,24 @@ fn insufficient_capacity_err( ) } -#[derive(Debug)] struct TrackedConsumer { name: String, can_spill: bool, reserved: AtomicUsize, peak: AtomicUsize, + reclaimer: Option>, +} + +impl std::fmt::Debug for TrackedConsumer { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("TrackedConsumer") + .field("name", &self.name) + .field("can_spill", &self.can_spill) + .field("reserved", &self.reserved()) + .field("peak", &self.peak()) + .field("has_reclaimer", &self.reclaimer.is_some()) + .finish() + } } impl TrackedConsumer { @@ -428,6 +442,7 @@ impl MemoryPool for TrackConsumersPool { can_spill: consumer.can_spill(), reserved: Default::default(), peak: Default::default(), + reclaimer: consumer.reclaimer(), }, ); @@ -488,6 +503,50 @@ impl MemoryPool for TrackConsumersPool { Ok(()) } + fn reclaim( + &self, + target_bytes: usize, + exclude_consumer_id: Option, + ) -> Result { + if target_bytes == 0 { + return Ok(0); + } + + let mut candidates = self + .tracked_consumers + .lock() + .iter() + .filter_map(|(consumer_id, tracked_consumer)| { + let reserved = tracked_consumer.reserved(); + let reclaimer = tracked_consumer.reclaimer.as_ref()?; + if exclude_consumer_id == Some(*consumer_id) + || !tracked_consumer.can_spill + || reserved == 0 + { + return None; + } + + Some((*consumer_id, reserved, Arc::clone(reclaimer))) + }) + .collect::>(); + candidates.sort_by( + |(left_id, left_reserved, _), (right_id, right_reserved, _)| { + right_reserved + .cmp(left_reserved) + .then_with(|| left_id.cmp(right_id)) + }, + ); + + let mut reclaimed = 0; + for (_, _, reclaimer) in candidates { + if reclaimed >= target_bytes { + break; + } + reclaimed += reclaimer.reclaim(target_bytes - reclaimed)?; + } + Ok(reclaimed) + } + fn reserved(&self) -> usize { self.inner.reserved() } @@ -513,6 +572,24 @@ mod tests { use insta::{Settings, allow_duplicates, assert_snapshot}; use std::sync::Arc; + #[derive(Debug)] + struct TestReclaimer { + reservation: Arc>>>, + } + + impl MemoryReclaimer for TestReclaimer { + fn reclaim(&self, target_bytes: usize) -> Result { + let Some(reservation) = self.reservation.lock().clone() else { + return Ok(0); + }; + let reclaimed = reservation.size().min(target_bytes); + if reclaimed > 0 { + reservation.shrink(reclaimed); + } + Ok(reclaimed) + } + } + fn make_settings() -> Settings { let mut settings = Settings::clone_current(); settings.add_filter( @@ -811,4 +888,80 @@ mod tests { r1#[ID](can spill: false) consumed 20.0 B, peak 20.0 B. "); } + + #[test] + fn test_tracked_consumers_pool_reclaim_prefers_largest_consumer() { + let pool = Arc::new(TrackConsumersPool::new( + GreedyMemoryPool::new(200), + NonZeroUsize::new(3).unwrap(), + )) as Arc; + + let first_reservation_handle = Arc::new(Mutex::new(None)); + let first = Arc::new( + MemoryConsumer::new("spillable-1") + .with_can_spill(true) + .with_reclaimer(Arc::new(TestReclaimer { + reservation: Arc::clone(&first_reservation_handle), + })) + .register(&pool), + ); + *first_reservation_handle.lock() = Some(Arc::clone(&first)); + first.grow(100); + + let second_reservation_handle = Arc::new(Mutex::new(None)); + let second = Arc::new( + MemoryConsumer::new("spillable-2") + .with_can_spill(true) + .with_reclaimer(Arc::new(TestReclaimer { + reservation: Arc::clone(&second_reservation_handle), + })) + .register(&pool), + ); + *second_reservation_handle.lock() = Some(Arc::clone(&second)); + second.grow(60); + + let reclaimed = pool.reclaim(80, None).unwrap(); + + assert_eq!(reclaimed, 80); + assert_eq!(first.size(), 20); + assert_eq!(second.size(), 60); + } + + #[test] + fn test_tracked_consumers_pool_reclaim_excludes_requester() { + let pool = Arc::new(TrackConsumersPool::new( + GreedyMemoryPool::new(200), + NonZeroUsize::new(3).unwrap(), + )) as Arc; + + let first_reservation_handle = Arc::new(Mutex::new(None)); + let first = Arc::new( + MemoryConsumer::new("spillable-1") + .with_can_spill(true) + .with_reclaimer(Arc::new(TestReclaimer { + reservation: Arc::clone(&first_reservation_handle), + })) + .register(&pool), + ); + *first_reservation_handle.lock() = Some(Arc::clone(&first)); + first.grow(100); + + let second_reservation_handle = Arc::new(Mutex::new(None)); + let second = Arc::new( + MemoryConsumer::new("spillable-2") + .with_can_spill(true) + .with_reclaimer(Arc::new(TestReclaimer { + reservation: Arc::clone(&second_reservation_handle), + })) + .register(&pool), + ); + *second_reservation_handle.lock() = Some(Arc::clone(&second)); + second.grow(60); + + let reclaimed = pool.reclaim(80, Some(first.consumer().id())).unwrap(); + + assert_eq!(reclaimed, 60); + assert_eq!(first.size(), 100); + assert_eq!(second.size(), 0); + } }