213213//! relation to the way the data is stored.
214214
215215use std:: collections:: VecDeque ;
216+ use std:: sync:: atomic:: { AtomicU64 , Ordering } ;
216217use std:: sync:: { LazyLock , Once , OnceLock } ;
217218use std:: { ops:: Range , sync:: Arc } ;
218219
@@ -1682,6 +1683,73 @@ impl<T: RootDecoderType> RecordBatchReader for BatchDecodeIterator<T> {
16821683 }
16831684}
16841685
1686+ /// Compute the actual data size (in bytes) of a record batch,
1687+ /// accounting only for the portion of buffers that belongs to the
1688+ /// batch's row range. Unlike `get_array_memory_size()`, this does
1689+ /// not over-count when arrays share a larger underlying page buffer.
1690+ fn batch_data_size ( batch : & RecordBatch ) -> u64 {
1691+ batch
1692+ . columns ( )
1693+ . iter ( )
1694+ . map ( |c| array_data_size ( c. as_ref ( ) ) )
1695+ . sum ( )
1696+ }
1697+
1698+ fn array_data_size ( array : & dyn arrow_array:: Array ) -> u64 {
1699+ let dt = array. data_type ( ) ;
1700+ let n = array. len ( ) as u64 ;
1701+ if let Some ( w) = dt. byte_width_opt ( ) {
1702+ return n * w as u64 ;
1703+ }
1704+ match dt {
1705+ DataType :: Boolean => n. div_ceil ( 8 ) ,
1706+ DataType :: Utf8 => {
1707+ let arr = array. as_string :: < i32 > ( ) ;
1708+ let offsets = arr. value_offsets ( ) ;
1709+ ( offsets[ n as usize ] - offsets[ 0 ] ) as u64
1710+ }
1711+ DataType :: LargeUtf8 => {
1712+ let arr = array. as_string :: < i64 > ( ) ;
1713+ let offsets = arr. value_offsets ( ) ;
1714+ ( offsets[ n as usize ] - offsets[ 0 ] ) as u64
1715+ }
1716+ DataType :: Binary => {
1717+ let arr = array. as_binary :: < i32 > ( ) ;
1718+ let offsets = arr. value_offsets ( ) ;
1719+ ( offsets[ n as usize ] - offsets[ 0 ] ) as u64
1720+ }
1721+ DataType :: LargeBinary => {
1722+ let arr = array. as_binary :: < i64 > ( ) ;
1723+ let offsets = arr. value_offsets ( ) ;
1724+ ( offsets[ n as usize ] - offsets[ 0 ] ) as u64
1725+ }
1726+ DataType :: Struct ( fields) => {
1727+ let s = array. as_struct ( ) ;
1728+ fields
1729+ . iter ( )
1730+ . enumerate ( )
1731+ . map ( |( i, _) | array_data_size ( s. column ( i) . as_ref ( ) ) )
1732+ . sum ( )
1733+ }
1734+ DataType :: List ( _) => {
1735+ let list = array. as_list :: < i32 > ( ) ;
1736+ array_data_size ( list. values ( ) . as_ref ( ) )
1737+ }
1738+ DataType :: LargeList ( _) => {
1739+ let list = array. as_list :: < i64 > ( ) ;
1740+ array_data_size ( list. values ( ) . as_ref ( ) )
1741+ }
1742+ DataType :: FixedSizeList ( _, _) => {
1743+ let list = array
1744+ . as_any ( )
1745+ . downcast_ref :: < arrow_array:: FixedSizeListArray > ( )
1746+ . unwrap ( ) ;
1747+ array_data_size ( list. values ( ) . as_ref ( ) )
1748+ }
1749+ _ => n * 64 , // fallback for uncommon types
1750+ }
1751+ }
1752+
16851753/// Estimate the number of bytes per row for a given Arrow data type.
16861754///
16871755/// For fixed-width types this is exact. For variable-width types (strings,
@@ -1734,6 +1802,9 @@ pub struct StructuralBatchDecodeStream {
17341802 /// Schema-based estimate of bytes per row, computed once at construction.
17351803 /// Only meaningful when `batch_size_bytes` is `Some`.
17361804 schema_bytes_per_row : f64 ,
1805+ /// Post-decode feedback: actual bytes-per-row measured from the most
1806+ /// recently decoded batch. Zero means no feedback yet (use schema estimate).
1807+ bytes_per_row_feedback : Arc < AtomicU64 > ,
17371808}
17381809
17391810impl StructuralBatchDecodeStream {
@@ -1771,6 +1842,7 @@ impl StructuralBatchDecodeStream {
17711842 spawn_batch_decode_tasks,
17721843 batch_size_bytes,
17731844 schema_bytes_per_row,
1845+ bytes_per_row_feedback : Arc :: new ( AtomicU64 :: new ( 0 ) ) ,
17741846 }
17751847 }
17761848
@@ -1814,7 +1886,13 @@ impl StructuralBatchDecodeStream {
18141886 }
18151887
18161888 let mut to_take = if let Some ( batch_size_bytes) = self . batch_size_bytes {
1817- let rows = ( batch_size_bytes as f64 / self . schema_bytes_per_row ) as u64 ;
1889+ let feedback = self . bytes_per_row_feedback . load ( Ordering :: Relaxed ) ;
1890+ let bpr = if feedback > 0 {
1891+ feedback as f64
1892+ } else {
1893+ self . schema_bytes_per_row
1894+ } ;
1895+ let rows = ( batch_size_bytes as f64 / bpr) as u64 ;
18181896 self . rows_remaining . min ( rows. max ( 1 ) )
18191897 } else {
18201898 self . rows_remaining . min ( self . rows_per_batch as u64 )
@@ -1854,20 +1932,30 @@ impl StructuralBatchDecodeStream {
18541932 let next_task = next_task. transpose ( ) . map ( |next_task| {
18551933 let num_rows = next_task. as_ref ( ) . map ( |t| t. num_rows ) . unwrap_or ( 0 ) ;
18561934 let emitted_batch_size_warning = slf. emitted_batch_size_warning . clone ( ) ;
1935+ let bytes_per_row_feedback = slf. bytes_per_row_feedback . clone ( ) ;
18571936 // Capture the per-stream policy once so every emitted batch task follows the
18581937 // same throughput-vs-overhead choice made by the scheduler.
18591938 let spawn_batch_decode_tasks = slf. spawn_batch_decode_tasks ;
18601939 let task = async move {
18611940 let next_task = next_task?;
1862- if spawn_batch_decode_tasks {
1941+ let batch = if spawn_batch_decode_tasks {
18631942 tokio:: spawn (
18641943 async move { next_task. into_batch ( emitted_batch_size_warning) } ,
18651944 )
18661945 . await
18671946 . map_err ( |err| Error :: wrapped ( err. into ( ) ) ) ?
18681947 } else {
18691948 next_task. into_batch ( emitted_batch_size_warning)
1949+ } ;
1950+ if let Ok ( ref b) = batch {
1951+ let num_rows = b. num_rows ( ) as u64 ;
1952+ if num_rows > 0 {
1953+ let bpr = batch_data_size ( b) / num_rows;
1954+ bytes_per_row_feedback
1955+ . store ( bpr. max ( 1 ) , Ordering :: Relaxed ) ;
1956+ }
18701957 }
1958+ batch
18711959 } ;
18721960 ( task, num_rows)
18731961 } ) ;
@@ -1978,6 +2066,7 @@ fn check_scheduler_on_drop(
19782066 . boxed ( )
19792067}
19802068
2069+ #[ allow( clippy:: too_many_arguments) ]
19812070pub fn create_decode_stream (
19822071 schema : & Schema ,
19832072 num_rows : u64 ,
@@ -2909,11 +2998,11 @@ mod tests {
29092998 use arrow_array:: Int32Array ;
29102999
29113000 // 1000 rows x 4 Int32 columns = 16 bytes/row
2912- let num_rows = 1000 ;
3001+ let num_rows: i32 = 1000 ;
29133002 let arrays: Vec < Arc < dyn arrow_array:: Array > > = ( 0 ..4 )
29143003 . map ( |col| {
29153004 Arc :: new ( Int32Array :: from_iter_values (
2916- ( 0 ..num_rows) . map ( |row| ( row * 10 + col) as i32 ) ,
3005+ ( 0 ..num_rows) . map ( move |row| row * 10 + col) ,
29173006 ) ) as _
29183007 } )
29193008 . collect ( ) ;
@@ -2963,11 +3052,11 @@ mod tests {
29633052 use arrow_array:: Int32Array ;
29643053
29653054 // Without batch_size_bytes, rows_per_batch controls batching
2966- let num_rows = 1000 ;
3055+ let num_rows: i32 = 1000 ;
29673056 let arrays: Vec < Arc < dyn arrow_array:: Array > > = ( 0 ..2 )
29683057 . map ( |col| {
29693058 Arc :: new ( Int32Array :: from_iter_values (
2970- ( 0 ..num_rows) . map ( |row| ( row * 10 + col) as i32 ) ,
3059+ ( 0 ..num_rows) . map ( move |row| row * 10 + col) ,
29713060 ) ) as _
29723061 } )
29733062 . collect ( ) ;
@@ -2991,4 +3080,67 @@ mod tests {
29913080 ) ;
29923081 }
29933082 }
3083+
3084+ #[ tokio:: test]
3085+ async fn test_byte_sized_batches_feedback_convergence ( ) {
3086+ use arrow_array:: StringArray ;
3087+
3088+ // Each row has a 100-byte string. Schema estimate = 64 bytes (default
3089+ // for Utf8), so the first batch will overshoot. The feedback loop
3090+ // should correct subsequent batches toward the target.
3091+ let num_rows = 500 ;
3092+ let value: String = "x" . repeat ( 100 ) ;
3093+ let arrays: Vec < Arc < dyn arrow_array:: Array > > = vec ! [ Arc :: new( StringArray :: from(
3094+ ( 0 ..num_rows) . map( |_| value. as_str( ) ) . collect:: <Vec <_>>( ) ,
3095+ ) ) ] ;
3096+ let schema = Arc :: new ( ArrowSchema :: new ( vec ! [ ArrowField :: new(
3097+ "s" ,
3098+ DataType :: Utf8 ,
3099+ false ,
3100+ ) ] ) ) ;
3101+ let input_batch = RecordBatch :: try_new ( schema, arrays) . unwrap ( ) ;
3102+
3103+ // Target 5000 bytes/batch. At 100 bytes/row the ideal is 50 rows/batch.
3104+ // Schema estimate is 64 bytes/row → first batch ~78 rows (overshoot).
3105+ // After feedback kicks in, batches should converge to ~50 rows.
3106+ let target_bytes: u64 = 5000 ;
3107+ let batches =
3108+ decode_batches_with_byte_limit ( & input_batch, /*batch_size=*/ 1024 , Some ( target_bytes) )
3109+ . await ;
3110+
3111+ // Verify all data round-trips correctly
3112+ let all_batches: Vec < & RecordBatch > = batches. iter ( ) . collect ( ) ;
3113+ let concatenated = arrow_select:: concat:: concat_batches (
3114+ & batches[ 0 ] . schema ( ) ,
3115+ all_batches. iter ( ) . copied ( ) ,
3116+ )
3117+ . unwrap ( ) ;
3118+ assert_eq ! ( concatenated. num_rows( ) , num_rows as usize ) ;
3119+ assert_eq ! (
3120+ concatenated. column( 0 ) . as_ref( ) ,
3121+ input_batch. column( 0 ) . as_ref( )
3122+ ) ;
3123+
3124+ // After the first batch, subsequent batches should be closer to the
3125+ // target. The ideal is 50 rows/batch.
3126+ assert ! (
3127+ batches. len( ) >= 2 ,
3128+ "need at least 2 batches to test convergence"
3129+ ) ;
3130+ // The first batch uses the schema estimate (64 bytes/row) →
3131+ // ~78 rows. After feedback the rows should settle near 50.
3132+ if batches. len ( ) >= 3 {
3133+ let second_batch_rows = batches[ 1 ] . num_rows ( ) ;
3134+ let third_batch_rows = batches[ 2 ] . num_rows ( ) ;
3135+ // Both should be within 20% of the ideal (50 rows)
3136+ assert ! (
3137+ ( 40 ..=60 ) . contains( & second_batch_rows) ,
3138+ "second batch should be near 50 rows, got {second_batch_rows}"
3139+ ) ;
3140+ assert ! (
3141+ ( 40 ..=60 ) . contains( & third_batch_rows) ,
3142+ "third batch should be near 50 rows, got {third_batch_rows}"
3143+ ) ;
3144+ }
3145+ }
29943146}
0 commit comments