@@ -27,7 +27,7 @@ use std::sync::Arc;
2727use ahash:: RandomState ;
2828use arrow:: array:: UInt32Array ;
2929use arrow:: compute:: take;
30- use arrow:: datatypes:: SchemaRef ;
30+ use arrow:: datatypes:: { DataType , Schema , SchemaRef } ;
3131use arrow:: record_batch:: RecordBatch ;
3232use datafusion:: common:: hash_utils:: create_hashes;
3333use datafusion:: common:: Result as DFResult ;
@@ -242,8 +242,8 @@ pub(super) async fn buffer_build_optimistic(
242242 reservation : & mut MemoryReservation ,
243243 metrics : & GraceHashJoinMetrics ,
244244) -> DFResult < BuildBufferResult > {
245+ let schema = input. schema ( ) ;
245246 let mut batches = Vec :: new ( ) ;
246- let mut total_bytes = 0usize ;
247247
248248 while let Some ( batch) = input. next ( ) . await {
249249 let batch = batch?;
@@ -254,6 +254,10 @@ pub(super) async fn buffer_build_optimistic(
254254 metrics. build_input_batches . add ( 1 ) ;
255255 metrics. build_input_rows . add ( batch. num_rows ( ) ) ;
256256
257+ // Per-batch `get_array_memory_size` is safe to use for `try_grow`
258+ // because overestimating just makes us more conservative with memory
259+ // pressure — it can only force us into the fallback path, never into
260+ // a spurious OOM.
257261 let batch_size = batch. get_array_memory_size ( ) ;
258262
259263 if reservation. try_grow ( batch_size) . is_err ( ) {
@@ -263,11 +267,94 @@ pub(super) async fn buffer_build_optimistic(
263267 return Ok ( BuildBufferResult :: NeedPartition ( batches, input) ) ;
264268 }
265269
266- total_bytes += batch_size;
267270 batches. push ( batch) ;
268271 }
269272
270- Ok ( BuildBufferResult :: Complete ( batches, total_bytes) )
273+ // Compute a size estimate for the fast-path threshold check from schema +
274+ // row count instead of `get_array_memory_size`. The latter reports the
275+ // full underlying buffer for every zero-copy slice (common after shuffle),
276+ // so a 49 MB build can look like 97 MB and spuriously fail the threshold.
277+ let actual_bytes = approximate_memory_size ( & batches, & schema) ;
278+ Ok ( BuildBufferResult :: Complete ( batches, actual_bytes) )
279+ }
280+
281+ /// Approximate in-memory size of a collection of record batches using the
282+ /// schema's per-column byte widths and a row count.
283+ ///
284+ /// Used instead of `batch.get_array_memory_size()` for the fast-path threshold
285+ /// decision because the Arrow helper reports the full underlying buffer size
286+ /// for every zero-copy slice, inflating the number by the number of slices
287+ /// when batches come out of a shuffle read. A row-count × row-width estimate
288+ /// has no such cross-slice double-counting. It is approximate for
289+ /// variable-width columns (strings, binary) — we pick a conservative 32 bytes
290+ /// per row — but good enough to gate the coarse threshold check.
291+ fn approximate_memory_size ( batches : & [ RecordBatch ] , schema : & Schema ) -> usize {
292+ let row_size = approximate_row_size ( schema) ;
293+ let total_rows: usize = batches. iter ( ) . map ( |b| b. num_rows ( ) ) . sum ( ) ;
294+ total_rows * row_size
295+ }
296+
297+ fn approximate_row_size ( schema : & Schema ) -> usize {
298+ schema
299+ . fields ( )
300+ . iter ( )
301+ . map ( |f| approximate_type_size ( f. data_type ( ) ) )
302+ . sum ( )
303+ }
304+
305+ fn approximate_type_size ( dt : & DataType ) -> usize {
306+ match dt {
307+ DataType :: Null => 0 ,
308+ DataType :: Boolean => 1 ,
309+ DataType :: Int8 | DataType :: UInt8 => 1 ,
310+ DataType :: Int16 | DataType :: UInt16 | DataType :: Float16 => 2 ,
311+ DataType :: Int32
312+ | DataType :: UInt32
313+ | DataType :: Float32
314+ | DataType :: Date32
315+ | DataType :: Time32 ( _) => 4 ,
316+ DataType :: Int64
317+ | DataType :: UInt64
318+ | DataType :: Float64
319+ | DataType :: Date64
320+ | DataType :: Time64 ( _)
321+ | DataType :: Timestamp ( _, _)
322+ | DataType :: Duration ( _)
323+ | DataType :: Interval ( _) => 8 ,
324+ DataType :: Decimal32 ( _, _) => 4 ,
325+ DataType :: Decimal64 ( _, _) => 8 ,
326+ DataType :: Decimal128 ( _, _) => 16 ,
327+ DataType :: Decimal256 ( _, _) => 32 ,
328+ DataType :: FixedSizeBinary ( n) => * n as usize ,
329+ // Variable-width: pick a conservative average. Exact strings would
330+ // need a scan over the offset buffer; good enough for a threshold
331+ // gate that is itself a heuristic.
332+ DataType :: Binary
333+ | DataType :: LargeBinary
334+ | DataType :: BinaryView
335+ | DataType :: Utf8
336+ | DataType :: LargeUtf8
337+ | DataType :: Utf8View => 32 ,
338+ DataType :: List ( f)
339+ | DataType :: LargeList ( f)
340+ | DataType :: ListView ( f)
341+ | DataType :: LargeListView ( f)
342+ | DataType :: FixedSizeList ( f, _) => 4 + approximate_type_size ( f. data_type ( ) ) ,
343+ DataType :: Struct ( fields) => fields
344+ . iter ( )
345+ . map ( |f| approximate_type_size ( f. data_type ( ) ) )
346+ . sum ( ) ,
347+ DataType :: Map ( f, _) => 8 + approximate_type_size ( f. data_type ( ) ) ,
348+ DataType :: Dictionary ( key, value) => {
349+ approximate_type_size ( key) + approximate_type_size ( value)
350+ }
351+ DataType :: Union ( fields, _) => fields
352+ . iter ( )
353+ . map ( |( _, f) | approximate_type_size ( f. data_type ( ) ) )
354+ . max ( )
355+ . unwrap_or ( 8 ) ,
356+ DataType :: RunEndEncoded ( _, values) => approximate_type_size ( values. data_type ( ) ) ,
357+ }
271358}
272359
273360/// Partition already-buffered build batches into the partition structure.
@@ -815,3 +902,70 @@ pub(super) fn sub_partition_batches(
815902 }
816903 Ok ( result)
817904}
905+
906+ #[ cfg( test) ]
907+ mod tests {
908+ use super :: * ;
909+ use arrow:: array:: { Int32Array , StringArray } ;
910+ use arrow:: datatypes:: { DataType , Field , Schema } ;
911+
912+ /// approximate_memory_size must be insensitive to zero-copy slicing -
913+ /// a batch sliced into N pieces should report the same total as the
914+ /// unsliced parent. A naive sum of get_array_memory_size would
915+ /// inflate the number by N because each slice reports the full buffer.
916+ #[ test]
917+ fn approximate_memory_size_is_slice_invariant ( ) {
918+ let schema = Arc :: new ( Schema :: new ( vec ! [ Field :: new( "id" , DataType :: Int32 , false ) ] ) ) ;
919+ let values: Vec < i32 > = ( 0 ..1000 ) . collect ( ) ;
920+ let parent = RecordBatch :: try_new (
921+ Arc :: clone ( & schema) ,
922+ vec ! [ Arc :: new( Int32Array :: from( values) ) ] ,
923+ )
924+ . unwrap ( ) ;
925+
926+ // 1000 rows * 4 bytes/row = 4000
927+ let parent_est = approximate_memory_size ( std:: slice:: from_ref ( & parent) , & schema) ;
928+ assert_eq ! ( parent_est, 4000 ) ;
929+
930+ let slices = vec ! [
931+ parent. slice( 0 , 250 ) ,
932+ parent. slice( 250 , 250 ) ,
933+ parent. slice( 500 , 250 ) ,
934+ parent. slice( 750 , 250 ) ,
935+ ] ;
936+ let sliced_est = approximate_memory_size ( & slices, & schema) ;
937+ assert_eq ! ( sliced_est, parent_est) ;
938+
939+ // Show the contrast with the naive per-batch get_array_memory_size sum.
940+ let naive: usize = slices
941+ . iter ( )
942+ . flat_map ( |b| b. columns ( ) . iter ( ) )
943+ . map ( |c| c. to_data ( ) . get_array_memory_size ( ) )
944+ . sum ( ) ;
945+ assert ! (
946+ naive > parent_est * 2 ,
947+ "naive sum inflates with slices (got {naive}, parent {parent_est})"
948+ ) ;
949+ }
950+
951+ #[ test]
952+ fn approximate_memory_size_sums_independent_batches ( ) {
953+ let schema = Arc :: new ( Schema :: new ( vec ! [ Field :: new( "id" , DataType :: Int32 , false ) ] ) ) ;
954+ let mk = |start : i32 | {
955+ let arr = Int32Array :: from ( ( start..start + 100 ) . collect :: < Vec < _ > > ( ) ) ;
956+ RecordBatch :: try_new ( Arc :: clone ( & schema) , vec ! [ Arc :: new( arr) ] ) . unwrap ( )
957+ } ;
958+ let batches = vec ! [ mk( 0 ) , mk( 100 ) , mk( 200 ) ] ;
959+ // 3 * 100 rows * 4 bytes = 1200
960+ assert_eq ! ( approximate_memory_size( & batches, & schema) , 1200 ) ;
961+ }
962+
963+ #[ test]
964+ fn approximate_memory_size_handles_strings ( ) {
965+ let schema = Arc :: new ( Schema :: new ( vec ! [ Field :: new( "s" , DataType :: Utf8 , false ) ] ) ) ;
966+ let arr = StringArray :: from ( vec ! [ "a" ; 100 ] ) ;
967+ let batch = RecordBatch :: try_new ( Arc :: clone ( & schema) , vec ! [ Arc :: new( arr) ] ) . unwrap ( ) ;
968+ // 100 rows * 32 bytes/row (heuristic) = 3200
969+ assert_eq ! ( approximate_memory_size( & [ batch] , & schema) , 3200 ) ;
970+ }
971+ }
0 commit comments