Skip to content

Commit c52b4cc

Browse files
westonpace and claude
committed
feat: byte-based row selection in StructuralBatchDecodeStream
When `batch_size_bytes` is `Some`, compute the number of rows to drain per batch from an estimated bytes-per-row instead of using `rows_per_batch`. The estimate is computed once from the schema using `estimate_bytes_per_row()`, which is exact for fixed-width types and uses rough defaults for variable-width types. Part of #6387 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 4d0bc34 commit c52b4cc

1 file changed

Lines changed: 218 additions & 1 deletion

File tree

rust/lance-encoding/src/decoder.rs

Lines changed: 218 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1682,6 +1682,35 @@ impl<T: RootDecoderType> RecordBatchReader for BatchDecodeIterator<T> {
16821682
}
16831683
}
16841684

1685+
/// Estimate the number of bytes per row for a given Arrow data type.
1686+
///
1687+
/// For fixed-width types this is exact. For variable-width types (strings,
1688+
/// binary, lists) a rough default is used. The estimate is used as a
1689+
/// starting point when `batch_size_bytes` is set; a post-decode feedback
1690+
/// loop corrects it after the first batch.
1691+
fn estimate_bytes_per_row(data_type: &DataType) -> f64 {
1692+
if let Some(w) = data_type.byte_width_opt() {
1693+
return w as f64;
1694+
}
1695+
match data_type {
1696+
DataType::Boolean => 1.0,
1697+
DataType::Utf8 | DataType::Binary | DataType::LargeUtf8 | DataType::LargeBinary => 64.0,
1698+
DataType::Struct(fields) => fields
1699+
.iter()
1700+
.map(|f| estimate_bytes_per_row(f.data_type()))
1701+
.sum(),
1702+
DataType::List(child) | DataType::LargeList(child) => {
1703+
5.0 * estimate_bytes_per_row(child.data_type())
1704+
}
1705+
DataType::FixedSizeList(child, dim) => {
1706+
*dim as f64 * estimate_bytes_per_row(child.data_type())
1707+
}
1708+
DataType::Dictionary(_, value_type) => estimate_bytes_per_row(value_type),
1709+
DataType::Map(entries, _) => 5.0 * estimate_bytes_per_row(entries.data_type()),
1710+
_ => 64.0,
1711+
}
1712+
}
1713+
16851714
/// A stream that takes scheduled jobs and generates decode tasks from them.
16861715
pub struct StructuralBatchDecodeStream {
16871716
context: DecoderContext,
@@ -1702,6 +1731,9 @@ pub struct StructuralBatchDecodeStream {
17021731
spawn_batch_decode_tasks: bool,
17031732
/// If set, target this many bytes per batch instead of `rows_per_batch` rows.
17041733
batch_size_bytes: Option<u64>,
1734+
/// Schema-based estimate of bytes per row, computed once at construction.
1735+
/// Only meaningful when `batch_size_bytes` is `Some`.
1736+
schema_bytes_per_row: f64,
17051737
}
17061738

17071739
impl StructuralBatchDecodeStream {
@@ -1722,6 +1754,11 @@ impl StructuralBatchDecodeStream {
17221754
spawn_batch_decode_tasks: bool,
17231755
batch_size_bytes: Option<u64>,
17241756
) -> Self {
1757+
let schema_bytes_per_row = if batch_size_bytes.is_some() {
1758+
estimate_bytes_per_row(root_decoder.data_type()).max(1.0)
1759+
} else {
1760+
0.0
1761+
};
17251762
Self {
17261763
context: DecoderContext::new(scheduled),
17271764
root_decoder,
@@ -1733,6 +1770,7 @@ impl StructuralBatchDecodeStream {
17331770
emitted_batch_size_warning: Arc::new(Once::new()),
17341771
spawn_batch_decode_tasks,
17351772
batch_size_bytes,
1773+
schema_bytes_per_row,
17361774
}
17371775
}
17381776

@@ -1775,7 +1813,12 @@ impl StructuralBatchDecodeStream {
17751813
return Ok(None);
17761814
}
17771815

1778-
let mut to_take = self.rows_remaining.min(self.rows_per_batch as u64);
1816+
let mut to_take = if let Some(batch_size_bytes) = self.batch_size_bytes {
1817+
let rows = (batch_size_bytes as f64 / self.schema_bytes_per_row) as u64;
1818+
self.rows_remaining.min(rows.max(1))
1819+
} else {
1820+
self.rows_remaining.min(self.rows_per_batch as u64)
1821+
};
17791822
self.rows_remaining -= to_take;
17801823

17811824
let scheduled_need = (self.rows_drained + to_take).saturating_sub(self.rows_scheduled);
@@ -2774,4 +2817,178 @@ mod tests {
27742817
let ranges = DecodeBatchScheduler::indices_to_ranges(&indices);
27752818
assert_eq!(ranges, vec![1..4, 5..8, 9..10]);
27762819
}
2820+
2821+
#[test]
fn test_estimate_bytes_per_row() {
    // Fixed-width primitives are exact; variable-width types use the
    // 64-byte default; booleans are budgeted one byte.
    let cases = [
        (DataType::Int32, 4.0),
        (DataType::Int64, 8.0),
        (DataType::Float32, 4.0),
        (DataType::Boolean, 1.0),
        (DataType::Utf8, 64.0),
        (DataType::Binary, 64.0),
    ];
    for (data_type, expected) in cases {
        assert_eq!(estimate_bytes_per_row(&data_type), expected);
    }

    // A struct sums its children: 4 x Int32 = 16 bytes
    let children: Vec<ArrowField> = ["a", "b", "c", "d"]
        .iter()
        .map(|name| ArrowField::new(*name, DataType::Int32, false))
        .collect();
    let struct_type = DataType::Struct(Fields::from(children));
    assert_eq!(estimate_bytes_per_row(&struct_type), 16.0);
}
2838+
2839+
/// Helper: encode a batch, then decode it as a stream with optional
/// `batch_size_bytes`, collecting all output batches.
///
/// Round-trips `batch` through the v2.1 encoder, schedules a full-range
/// decode, and drains the decode stream. `batch_size` is the row-based
/// batch size; when `batch_size_bytes` is `Some`, the stream sizes batches
/// by bytes instead.
async fn decode_batches_with_byte_limit(
    batch: &RecordBatch,
    batch_size: u32,
    batch_size_bytes: Option<u64>,
) -> Vec<RecordBatch> {
    use crate::encoder::{default_encoding_strategy, encode_batch, EncodingOptions};
    use crate::version::LanceFileVersion;

    // Encode the input with the default 2.1 strategy so we have real
    // encoded pages to schedule against.
    let version = LanceFileVersion::V2_1;
    let options = EncodingOptions {
        version,
        ..Default::default()
    };
    let strategy = default_encoding_strategy(version);
    let schema = Schema::try_from(batch.schema().as_ref()).unwrap();
    let encoded = encode_batch(batch, Arc::new(schema.clone()), strategy.as_ref(), &options)
        .await
        .unwrap();

    // Serve reads directly from the in-memory encoded buffer.
    let io_scheduler =
        Arc::new(BufferScheduler::new(encoded.data.clone())) as Arc<dyn EncodingsIo>;
    let cache = Arc::new(lance_core::cache::LanceCache::with_capacity(128 * 1024 * 1024));
    let decoder_plugins = Arc::new(DecoderPlugins::default());

    // Build a scheduler over every top-level column with no filtering.
    let mut decode_scheduler = DecodeBatchScheduler::try_new(
        encoded.schema.as_ref(),
        &encoded.top_level_columns,
        &encoded.page_table,
        &vec![],
        encoded.num_rows,
        decoder_plugins,
        io_scheduler.clone(),
        cache,
        &FilterExpression::no_filter(),
        &DecoderConfig::default(),
    )
    .await
    .unwrap();

    // Schedule the entire row range; decode messages arrive on `rx`.
    let (tx, rx) = unbounded_channel();
    decode_scheduler.schedule_range(
        0..encoded.num_rows,
        &FilterExpression::no_filter(),
        tx,
        io_scheduler,
    );

    // Structural decode path with validation on; `batch_size_bytes` is the
    // knob under test.
    let mut decode_stream = create_decode_stream(
        &encoded.schema,
        encoded.num_rows,
        batch_size,
        /*is_structural=*/ true,
        /*should_validate=*/ true,
        /*spawn_structural_batch_decode_tasks=*/ true,
        rx,
        batch_size_bytes,
    )
    .unwrap();

    // Drain every decode task to completion and collect its batch.
    let mut batches = Vec::new();
    while let Some(task) = decode_stream.next().await {
        batches.push(task.task.await.unwrap());
    }
    batches
}
2906+
2907+
#[tokio::test]
2908+
async fn test_byte_sized_batches_fixed_width() {
2909+
use arrow_array::Int32Array;
2910+
2911+
// 1000 rows x 4 Int32 columns = 16 bytes/row
2912+
let num_rows = 1000;
2913+
let arrays: Vec<Arc<dyn arrow_array::Array>> = (0..4)
2914+
.map(|col| {
2915+
Arc::new(Int32Array::from_iter_values(
2916+
(0..num_rows).map(|row| (row * 10 + col) as i32),
2917+
)) as _
2918+
})
2919+
.collect();
2920+
2921+
let schema = Arc::new(ArrowSchema::new(vec![
2922+
ArrowField::new("a", DataType::Int32, false),
2923+
ArrowField::new("b", DataType::Int32, false),
2924+
ArrowField::new("c", DataType::Int32, false),
2925+
ArrowField::new("d", DataType::Int32, false),
2926+
]));
2927+
let input_batch = RecordBatch::try_new(schema, arrays).unwrap();
2928+
2929+
// 16 bytes/row, batch_size_bytes=1600 => 100 rows/batch
2930+
let batches =
2931+
decode_batches_with_byte_limit(&input_batch, /*batch_size=*/ 1024, Some(1600)).await;
2932+
2933+
// Should produce 10 batches of 100 rows each
2934+
assert_eq!(batches.len(), 10);
2935+
for (i, batch) in batches.iter().enumerate() {
2936+
assert_eq!(
2937+
batch.num_rows(),
2938+
100,
2939+
"batch {i} should have 100 rows, got {}",
2940+
batch.num_rows()
2941+
);
2942+
}
2943+
2944+
// Verify roundtrip: concatenate and compare
2945+
let all_batches: Vec<&RecordBatch> = batches.iter().collect();
2946+
let concatenated = arrow_select::concat::concat_batches(
2947+
&batches[0].schema(),
2948+
all_batches.iter().copied(),
2949+
)
2950+
.unwrap();
2951+
assert_eq!(concatenated.num_rows(), num_rows as usize);
2952+
for col in 0..4 {
2953+
assert_eq!(
2954+
concatenated.column(col).as_ref(),
2955+
input_batch.column(col).as_ref(),
2956+
"column {col} roundtrip mismatch"
2957+
);
2958+
}
2959+
}
2960+
2961+
#[tokio::test]
2962+
async fn test_byte_sized_batches_none_unchanged() {
2963+
use arrow_array::Int32Array;
2964+
2965+
// Without batch_size_bytes, rows_per_batch controls batching
2966+
let num_rows = 1000;
2967+
let arrays: Vec<Arc<dyn arrow_array::Array>> = (0..2)
2968+
.map(|col| {
2969+
Arc::new(Int32Array::from_iter_values(
2970+
(0..num_rows).map(|row| (row * 10 + col) as i32),
2971+
)) as _
2972+
})
2973+
.collect();
2974+
2975+
let schema = Arc::new(ArrowSchema::new(vec![
2976+
ArrowField::new("x", DataType::Int32, false),
2977+
ArrowField::new("y", DataType::Int32, false),
2978+
]));
2979+
let input_batch = RecordBatch::try_new(schema, arrays).unwrap();
2980+
2981+
// batch_size=250, batch_size_bytes=None => 4 batches of 250 rows
2982+
let batches =
2983+
decode_batches_with_byte_limit(&input_batch, /*batch_size=*/ 250, None).await;
2984+
assert_eq!(batches.len(), 4);
2985+
for (i, batch) in batches.iter().enumerate() {
2986+
assert_eq!(
2987+
batch.num_rows(),
2988+
250,
2989+
"batch {i} should have 250 rows, got {}",
2990+
batch.num_rows()
2991+
);
2992+
}
2993+
}
27772994
}

0 commit comments

Comments
 (0)