@@ -1682,6 +1682,35 @@ impl<T: RootDecoderType> RecordBatchReader for BatchDecodeIterator<T> {
     }
 }
 
+/// Estimate the number of bytes per row for a given Arrow data type.
+///
+/// For fixed-width types this is exact. For variable-width types (strings,
+/// binary, lists) a rough default is used. The estimate is used as a
+/// starting point when `batch_size_bytes` is set; a post-decode feedback
+/// loop corrects it after the first batch.
+fn estimate_bytes_per_row(data_type: &DataType) -> f64 {
+    if let Some(w) = data_type.byte_width_opt() {
+        return w as f64;
+    }
+    match data_type {
+        DataType::Boolean => 1.0,
+        DataType::Utf8 | DataType::Binary | DataType::LargeUtf8 | DataType::LargeBinary => 64.0,
+        DataType::Struct(fields) => fields
+            .iter()
+            .map(|f| estimate_bytes_per_row(f.data_type()))
+            .sum(),
+        DataType::List(child) | DataType::LargeList(child) => {
+            5.0 * estimate_bytes_per_row(child.data_type())
+        }
+        DataType::FixedSizeList(child, dim) => {
+            *dim as f64 * estimate_bytes_per_row(child.data_type())
+        }
+        DataType::Dictionary(_, value_type) => estimate_bytes_per_row(value_type),
+        DataType::Map(entries, _) => 5.0 * estimate_bytes_per_row(entries.data_type()),
+        _ => 64.0,
+    }
+}
+
 /// A stream that takes scheduled jobs and generates decode tasks from them.
 pub struct StructuralBatchDecodeStream {
     context: DecoderContext,
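Note (reviewer sketch, not part of the diff): the 64-byte string default and the 5x list multiplier above are heuristics, and nested estimates simply compose them. The sketch below shows how, assuming it sits in the same module as `estimate_bytes_per_row` and that `byte_width_opt` returns `None` for these nested types; the test name is illustrative only.

```rust
use std::sync::Arc;

use arrow_schema::{DataType, Field};

#[test]
fn nested_estimates_compose() {
    // List<Int32>: heuristic ~5 items per list x 4 bytes => 20 bytes/row
    let list = DataType::List(Arc::new(Field::new("item", DataType::Int32, true)));
    assert_eq!(estimate_bytes_per_row(&list), 20.0);

    // FixedSizeList<Float32, 3>: exact, 3 x 4 bytes => 12 bytes/row
    let fsl = DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 3);
    assert_eq!(estimate_bytes_per_row(&fsl), 12.0);

    // Dictionary<Int32, Utf8>: estimated from the value type => 64 bytes/row
    let dict = DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8));
    assert_eq!(estimate_bytes_per_row(&dict), 64.0);
}
```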
@@ -1702,6 +1731,9 @@ pub struct StructuralBatchDecodeStream {
     spawn_batch_decode_tasks: bool,
     /// If set, target this many bytes per batch instead of `rows_per_batch` rows.
     batch_size_bytes: Option<u64>,
+    /// Schema-based estimate of bytes per row, computed once at construction.
+    /// Only meaningful when `batch_size_bytes` is `Some`.
+    schema_bytes_per_row: f64,
 }
 
 impl StructuralBatchDecodeStream {
@@ -1722,6 +1754,11 @@ impl StructuralBatchDecodeStream {
         spawn_batch_decode_tasks: bool,
         batch_size_bytes: Option<u64>,
     ) -> Self {
+        let schema_bytes_per_row = if batch_size_bytes.is_some() {
+            estimate_bytes_per_row(root_decoder.data_type()).max(1.0)
+        } else {
+            0.0
+        };
         Self {
             context: DecoderContext::new(scheduled),
             root_decoder,
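Note (reviewer sketch, not part of the diff): the `.max(1.0)` clamp matters because Rust's float-to-int casts saturate, so an unclamped zero estimate would turn the later `batch_size_bytes / schema_bytes_per_row` division into `u64::MAX` rows. A standalone illustration of that failure mode:

```rust
fn main() {
    // Unclamped: 1600.0 / 0.0 = +inf, and `as u64` saturates to u64::MAX.
    let estimate = 0.0_f64;
    assert_eq!((1600.0 / estimate) as u64, u64::MAX);

    // Clamped: at least one byte per row, so at most `batch_size_bytes` rows.
    let clamped = estimate.max(1.0);
    assert_eq!((1600.0 / clamped) as u64, 1600);
}
```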
@@ -1733,6 +1770,7 @@ impl StructuralBatchDecodeStream {
             emitted_batch_size_warning: Arc::new(Once::new()),
             spawn_batch_decode_tasks,
             batch_size_bytes,
+            schema_bytes_per_row,
         }
     }
 
@@ -1775,7 +1813,12 @@ impl StructuralBatchDecodeStream {
             return Ok(None);
         }
 
-        let mut to_take = self.rows_remaining.min(self.rows_per_batch as u64);
+        let mut to_take = if let Some(batch_size_bytes) = self.batch_size_bytes {
+            let rows = (batch_size_bytes as f64 / self.schema_bytes_per_row) as u64;
+            self.rows_remaining.min(rows.max(1))
+        } else {
+            self.rows_remaining.min(self.rows_per_batch as u64)
+        };
         self.rows_remaining -= to_take;
 
         let scheduled_need = (self.rows_drained + to_take).saturating_sub(self.rows_scheduled);
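Note (reviewer sketch, not part of the diff): the new branch is a straight division with two guards, `rows.max(1)` so a single row larger than the byte budget still makes progress, and `rows_remaining.min(..)` so the final batch is capped. For the fixed-width test added below (4 x Int32 = 16 bytes/row, 1600-byte budget) the arithmetic works out to 100-row batches; the post-decode feedback correction mentioned in the `estimate_bytes_per_row` docs is not part of these hunks. A sketch of the computation:

```rust
fn main() {
    // Mirrors the `to_take` computation for a 4 x Int32 schema.
    let batch_size_bytes: u64 = 1600;
    let schema_bytes_per_row = 4.0 * 4.0; // four Int32 columns, 16 bytes/row
    let mut rows_remaining: u64 = 1000;

    let rows = (batch_size_bytes as f64 / schema_bytes_per_row) as u64;
    let to_take = rows_remaining.min(rows.max(1));
    assert_eq!(to_take, 100); // 1000 rows decode as ten 100-row batches

    rows_remaining -= to_take;
    assert_eq!(rows_remaining, 900);
}
```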
@@ -2774,4 +2817,178 @@ mod tests {
         let ranges = DecodeBatchScheduler::indices_to_ranges(&indices);
         assert_eq!(ranges, vec![1..4, 5..8, 9..10]);
     }
+
+    #[test]
+    fn test_estimate_bytes_per_row() {
+        assert_eq!(estimate_bytes_per_row(&DataType::Int32), 4.0);
+        assert_eq!(estimate_bytes_per_row(&DataType::Int64), 8.0);
+        assert_eq!(estimate_bytes_per_row(&DataType::Float32), 4.0);
+        assert_eq!(estimate_bytes_per_row(&DataType::Boolean), 1.0);
+        assert_eq!(estimate_bytes_per_row(&DataType::Utf8), 64.0);
+        assert_eq!(estimate_bytes_per_row(&DataType::Binary), 64.0);
+        // Struct of 4 x Int32 = 16 bytes
+        let struct_type = DataType::Struct(Fields::from(vec![
+            ArrowField::new("a", DataType::Int32, false),
+            ArrowField::new("b", DataType::Int32, false),
+            ArrowField::new("c", DataType::Int32, false),
+            ArrowField::new("d", DataType::Int32, false),
+        ]));
+        assert_eq!(estimate_bytes_per_row(&struct_type), 16.0);
+    }
+
+    /// Helper: encode a batch, then decode it as a stream with optional
+    /// `batch_size_bytes`, collecting all output batches.
+    async fn decode_batches_with_byte_limit(
+        batch: &RecordBatch,
+        batch_size: u32,
+        batch_size_bytes: Option<u64>,
+    ) -> Vec<RecordBatch> {
+        use crate::encoder::{default_encoding_strategy, encode_batch, EncodingOptions};
+        use crate::version::LanceFileVersion;
+
+        let version = LanceFileVersion::V2_1;
+        let options = EncodingOptions {
+            version,
+            ..Default::default()
+        };
+        let strategy = default_encoding_strategy(version);
+        let schema = Schema::try_from(batch.schema().as_ref()).unwrap();
+        let encoded = encode_batch(batch, Arc::new(schema.clone()), strategy.as_ref(), &options)
+            .await
+            .unwrap();
+
+        let io_scheduler =
+            Arc::new(BufferScheduler::new(encoded.data.clone())) as Arc<dyn EncodingsIo>;
+        let cache = Arc::new(lance_core::cache::LanceCache::with_capacity(128 * 1024 * 1024));
+        let decoder_plugins = Arc::new(DecoderPlugins::default());
+
+        let mut decode_scheduler = DecodeBatchScheduler::try_new(
+            encoded.schema.as_ref(),
+            &encoded.top_level_columns,
+            &encoded.page_table,
+            &vec![],
+            encoded.num_rows,
+            decoder_plugins,
+            io_scheduler.clone(),
+            cache,
+            &FilterExpression::no_filter(),
+            &DecoderConfig::default(),
+        )
+        .await
+        .unwrap();
+
+        let (tx, rx) = unbounded_channel();
+        decode_scheduler.schedule_range(
+            0..encoded.num_rows,
+            &FilterExpression::no_filter(),
+            tx,
+            io_scheduler,
+        );
+
+        let mut decode_stream = create_decode_stream(
+            &encoded.schema,
+            encoded.num_rows,
+            batch_size,
+            /*is_structural=*/ true,
+            /*should_validate=*/ true,
+            /*spawn_structural_batch_decode_tasks=*/ true,
+            rx,
+            batch_size_bytes,
+        )
+        .unwrap();
+
+        let mut batches = Vec::new();
+        while let Some(task) = decode_stream.next().await {
+            batches.push(task.task.await.unwrap());
+        }
+        batches
+    }
+
+    #[tokio::test]
+    async fn test_byte_sized_batches_fixed_width() {
+        use arrow_array::Int32Array;
+
+        // 1000 rows x 4 Int32 columns = 16 bytes/row
+        let num_rows = 1000;
+        let arrays: Vec<Arc<dyn arrow_array::Array>> = (0..4)
+            .map(|col| {
+                Arc::new(Int32Array::from_iter_values(
+                    (0..num_rows).map(|row| (row * 10 + col) as i32),
+                )) as _
+            })
+            .collect();
+
+        let schema = Arc::new(ArrowSchema::new(vec![
+            ArrowField::new("a", DataType::Int32, false),
+            ArrowField::new("b", DataType::Int32, false),
+            ArrowField::new("c", DataType::Int32, false),
+            ArrowField::new("d", DataType::Int32, false),
+        ]));
+        let input_batch = RecordBatch::try_new(schema, arrays).unwrap();
+
+        // 16 bytes/row, batch_size_bytes=1600 => 100 rows/batch
+        let batches =
+            decode_batches_with_byte_limit(&input_batch, /*batch_size=*/ 1024, Some(1600)).await;
+
+        // Should produce 10 batches of 100 rows each
+        assert_eq!(batches.len(), 10);
+        for (i, batch) in batches.iter().enumerate() {
+            assert_eq!(
+                batch.num_rows(),
+                100,
+                "batch {i} should have 100 rows, got {}",
+                batch.num_rows()
+            );
+        }
+
+        // Verify roundtrip: concatenate and compare
+        let all_batches: Vec<&RecordBatch> = batches.iter().collect();
+        let concatenated = arrow_select::concat::concat_batches(
+            &batches[0].schema(),
+            all_batches.iter().copied(),
+        )
+        .unwrap();
+        assert_eq!(concatenated.num_rows(), num_rows as usize);
+        for col in 0..4 {
+            assert_eq!(
+                concatenated.column(col).as_ref(),
+                input_batch.column(col).as_ref(),
+                "column {col} roundtrip mismatch"
+            );
+        }
+    }
+
+    #[tokio::test]
+    async fn test_byte_sized_batches_none_unchanged() {
+        use arrow_array::Int32Array;
+
+        // Without batch_size_bytes, rows_per_batch controls batching
+        let num_rows = 1000;
+        let arrays: Vec<Arc<dyn arrow_array::Array>> = (0..2)
+            .map(|col| {
+                Arc::new(Int32Array::from_iter_values(
+                    (0..num_rows).map(|row| (row * 10 + col) as i32),
+                )) as _
+            })
+            .collect();
+
+        let schema = Arc::new(ArrowSchema::new(vec![
+            ArrowField::new("x", DataType::Int32, false),
+            ArrowField::new("y", DataType::Int32, false),
+        ]));
+        let input_batch = RecordBatch::try_new(schema, arrays).unwrap();
+
+        // batch_size=250, batch_size_bytes=None => 4 batches of 250 rows
+        let batches =
+            decode_batches_with_byte_limit(&input_batch, /*batch_size=*/ 250, None).await;
+        assert_eq!(batches.len(), 4);
+        for (i, batch) in batches.iter().enumerate() {
+            assert_eq!(
+                batch.num_rows(),
+                250,
+                "batch {i} should have 250 rows, got {}",
+                batch.num_rows()
+            );
+        }
+    }
 }