
Commit 18a7b66

westonpace and claude committed
feat: post-decode feedback loop for byte-sized batches
After each batch is decoded, measure the actual data bytes per row and feed it back so that the next `next_batch_task()` call uses the measured value instead of the schema-based estimate. This corrects for inaccurate initial estimates on variable-width data (strings, binary), where the schema default of 64 bytes may be far off.

The measurement uses `batch_data_size()`, a new helper that computes the actual data contribution of a batch by walking column types and reading offsets for variable-width arrays. This avoids the over-counting of `get_array_memory_size()`, which reports the full shared page-buffer capacity rather than per-batch data.

Part of #6387

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent c52b4cc commit 18a7b66

1 file changed: 158 additions & 6 deletions

rust/lance-encoding/src/decoder.rs
@@ -213,6 +213,7 @@
 //! relation to the way the data is stored.
 
 use std::collections::VecDeque;
+use std::sync::atomic::{AtomicU64, Ordering};
 use std::sync::{LazyLock, Once, OnceLock};
 use std::{ops::Range, sync::Arc};
 
@@ -1682,6 +1683,73 @@ impl<T: RootDecoderType> RecordBatchReader for BatchDecodeIterator<T> {
     }
 }
 
+/// Compute the actual data size (in bytes) of a record batch,
+/// accounting only for the portion of buffers that belongs to the
+/// batch's row range. Unlike `get_array_memory_size()`, this does
+/// not over-count when arrays share a larger underlying page buffer.
+fn batch_data_size(batch: &RecordBatch) -> u64 {
+    batch
+        .columns()
+        .iter()
+        .map(|c| array_data_size(c.as_ref()))
+        .sum()
+}
+
+fn array_data_size(array: &dyn arrow_array::Array) -> u64 {
+    let dt = array.data_type();
+    let n = array.len() as u64;
+    if let Some(w) = dt.primitive_width() {
+        return n * w as u64;
+    }
+    match dt {
+        DataType::Boolean => n.div_ceil(8),
+        DataType::Utf8 => {
+            let arr = array.as_string::<i32>();
+            let offsets = arr.value_offsets();
+            (offsets[n as usize] - offsets[0]) as u64
+        }
+        DataType::LargeUtf8 => {
+            let arr = array.as_string::<i64>();
+            let offsets = arr.value_offsets();
+            (offsets[n as usize] - offsets[0]) as u64
+        }
+        DataType::Binary => {
+            let arr = array.as_binary::<i32>();
+            let offsets = arr.value_offsets();
+            (offsets[n as usize] - offsets[0]) as u64
+        }
+        DataType::LargeBinary => {
+            let arr = array.as_binary::<i64>();
+            let offsets = arr.value_offsets();
+            (offsets[n as usize] - offsets[0]) as u64
+        }
+        DataType::Struct(fields) => {
+            let s = array.as_struct();
+            fields
+                .iter()
+                .enumerate()
+                .map(|(i, _)| array_data_size(s.column(i).as_ref()))
+                .sum()
+        }
+        DataType::List(_) => {
+            let list = array.as_list::<i32>();
+            array_data_size(list.values().as_ref())
+        }
+        DataType::LargeList(_) => {
+            let list = array.as_list::<i64>();
+            array_data_size(list.values().as_ref())
+        }
+        DataType::FixedSizeList(_, _) => {
+            let list = array
+                .as_any()
+                .downcast_ref::<arrow_array::FixedSizeListArray>()
+                .unwrap();
+            array_data_size(list.values().as_ref())
+        }
+        _ => n * 64, // fallback for uncommon types
+    }
+}
+
 /// Estimate the number of bytes per row for a given Arrow data type.
 ///
 /// For fixed-width types this is exact. For variable-width types (strings,
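
To see the over-counting that `batch_data_size()` avoids, consider a sliced string array. A minimal standalone sketch against the arrow-array crate (the names here are illustrative, not part of the patch):

use arrow_array::{Array, StringArray};

fn main() {
    // 1000 rows of 4-byte strings: ~4000 bytes of string data in one shared buffer.
    let parent = StringArray::from(vec!["aaaa"; 1000]);
    // A 10-row slice keeps a reference to the parent's full buffers.
    let slice = parent.slice(0, 10);

    // Reports the capacity of the shared buffers, not the 40 bytes the slice covers.
    println!("get_array_memory_size = {}", slice.get_array_memory_size());

    // Offsets-based measurement, as batch_data_size does for Utf8:
    let offsets = slice.value_offsets();
    let data_bytes = (offsets[slice.len()] - offsets[0]) as u64;
    assert_eq!(data_bytes, 40);
}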
@@ -1734,6 +1802,9 @@ pub struct StructuralBatchDecodeStream {
     /// Schema-based estimate of bytes per row, computed once at construction.
     /// Only meaningful when `batch_size_bytes` is `Some`.
     schema_bytes_per_row: f64,
+    /// Post-decode feedback: actual bytes-per-row measured from the most
+    /// recently decoded batch. Zero means no feedback yet (use schema estimate).
+    bytes_per_row_feedback: Arc<AtomicU64>,
 }
 
 impl StructuralBatchDecodeStream {
@@ -1771,6 +1842,7 @@ impl StructuralBatchDecodeStream {
             spawn_batch_decode_tasks,
             batch_size_bytes,
             schema_bytes_per_row,
+            bytes_per_row_feedback: Arc::new(AtomicU64::new(0)),
         }
     }
 
@@ -1814,7 +1886,13 @@
         }
 
         let mut to_take = if let Some(batch_size_bytes) = self.batch_size_bytes {
-            let rows = (batch_size_bytes as f64 / self.schema_bytes_per_row) as u64;
+            let feedback = self.bytes_per_row_feedback.load(Ordering::Relaxed);
+            let bpr = if feedback > 0 {
+                feedback as f64
+            } else {
+                self.schema_bytes_per_row
+            };
+            let rows = (batch_size_bytes as f64 / bpr) as u64;
             self.rows_remaining.min(rows.max(1))
         } else {
             self.rows_remaining.min(self.rows_per_batch as u64)
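
Worked through with the numbers the convergence test below uses (a standalone sketch with hypothetical values, mirroring the selection logic above):

fn main() {
    // A 5000-byte target over rows that actually hold 100-byte strings.
    let batch_size_bytes: u64 = 5000;
    let schema_bytes_per_row: f64 = 64.0; // default estimate for Utf8

    // First batch: no feedback yet, so the schema estimate is used.
    let rows = (batch_size_bytes as f64 / schema_bytes_per_row) as u64;
    assert_eq!(rows, 78); // overshoots the 50-row ideal

    // Later batches: the feedback cell holds the measured 100 bytes/row.
    let rows = (batch_size_bytes as f64 / 100.0) as u64;
    assert_eq!(rows, 50); // converges to the target
}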
@@ -1854,20 +1932,30 @@
         let next_task = next_task.transpose().map(|next_task| {
             let num_rows = next_task.as_ref().map(|t| t.num_rows).unwrap_or(0);
             let emitted_batch_size_warning = slf.emitted_batch_size_warning.clone();
+            let bytes_per_row_feedback = slf.bytes_per_row_feedback.clone();
             // Capture the per-stream policy once so every emitted batch task follows the
             // same throughput-vs-overhead choice made by the scheduler.
             let spawn_batch_decode_tasks = slf.spawn_batch_decode_tasks;
             let task = async move {
                 let next_task = next_task?;
-                if spawn_batch_decode_tasks {
+                let batch = if spawn_batch_decode_tasks {
                     tokio::spawn(
                         async move { next_task.into_batch(emitted_batch_size_warning) },
                     )
                     .await
                     .map_err(|err| Error::wrapped(err.into()))?
                 } else {
                     next_task.into_batch(emitted_batch_size_warning)
+                };
+                if let Ok(ref b) = batch {
+                    let num_rows = b.num_rows() as u64;
+                    if num_rows > 0 {
+                        let bpr = batch_data_size(b) / num_rows;
+                        bytes_per_row_feedback
+                            .store(bpr.max(1), Ordering::Relaxed);
+                    }
                 }
+                batch
             };
             (task, num_rows)
         });
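
The feedback cell is a plain `Arc<AtomicU64>` written and read with `Relaxed` ordering; a stale value only mis-sizes a single batch, so no stronger synchronization is needed. A minimal standalone sketch of the store/load pattern (not the patch code itself):

use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;

fn main() {
    // 0 is the sentinel for "no measurement yet".
    let feedback = Arc::new(AtomicU64::new(0));

    // Decode side: store the measured bytes-per-row, clamped to >= 1.
    let writer = Arc::clone(&feedback);
    std::thread::spawn(move || {
        let measured: u64 = 100;
        writer.store(measured.max(1), Ordering::Relaxed);
    })
    .join()
    .unwrap();

    // Scheduling side: prefer the measurement over the schema estimate.
    let bpr = match feedback.load(Ordering::Relaxed) {
        0 => 64.0, // fall back to the schema estimate
        measured => measured as f64,
    };
    assert_eq!(bpr, 100.0);
}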
@@ -1978,6 +2066,7 @@ fn check_scheduler_on_drop(
         .boxed()
 }
 
+#[allow(clippy::too_many_arguments)]
 pub fn create_decode_stream(
     schema: &Schema,
     num_rows: u64,
@@ -2909,11 +2998,11 @@ mod tests {
         use arrow_array::Int32Array;
 
         // 1000 rows x 4 Int32 columns = 16 bytes/row
-        let num_rows = 1000;
+        let num_rows: i32 = 1000;
         let arrays: Vec<Arc<dyn arrow_array::Array>> = (0..4)
             .map(|col| {
                 Arc::new(Int32Array::from_iter_values(
-                    (0..num_rows).map(|row| (row * 10 + col) as i32),
+                    (0..num_rows).map(move |row| row * 10 + col),
                 )) as _
             })
             .collect();
@@ -2963,11 +3052,11 @@ mod tests {
        use arrow_array::Int32Array;
 
         // Without batch_size_bytes, rows_per_batch controls batching
-        let num_rows = 1000;
+        let num_rows: i32 = 1000;
         let arrays: Vec<Arc<dyn arrow_array::Array>> = (0..2)
             .map(|col| {
                 Arc::new(Int32Array::from_iter_values(
-                    (0..num_rows).map(|row| (row * 10 + col) as i32),
+                    (0..num_rows).map(move |row| row * 10 + col),
                 )) as _
             })
             .collect();
@@ -2991,4 +3080,67 @@ mod tests {
             );
         }
     }
+
+    #[tokio::test]
+    async fn test_byte_sized_batches_feedback_convergence() {
+        use arrow_array::StringArray;
+
+        // Each row has a 100-byte string. Schema estimate = 64 bytes (default
+        // for Utf8), so the first batch will overshoot. The feedback loop
+        // should correct subsequent batches toward the target.
+        let num_rows = 500;
+        let value: String = "x".repeat(100);
+        let arrays: Vec<Arc<dyn arrow_array::Array>> = vec![Arc::new(StringArray::from(
+            (0..num_rows).map(|_| value.as_str()).collect::<Vec<_>>(),
+        ))];
+        let schema = Arc::new(ArrowSchema::new(vec![ArrowField::new(
+            "s",
+            DataType::Utf8,
+            false,
+        )]));
+        let input_batch = RecordBatch::try_new(schema, arrays).unwrap();
+
+        // Target 5000 bytes/batch. At 100 bytes/row the ideal is 50 rows/batch.
+        // Schema estimate is 64 bytes/row → first batch ~78 rows (overshoot).
+        // After feedback kicks in, batches should converge to ~50 rows.
+        let target_bytes: u64 = 5000;
+        let batches =
+            decode_batches_with_byte_limit(&input_batch, /*batch_size=*/ 1024, Some(target_bytes))
+                .await;
+
+        // Verify all data round-trips correctly
+        let all_batches: Vec<&RecordBatch> = batches.iter().collect();
+        let concatenated = arrow_select::concat::concat_batches(
+            &batches[0].schema(),
+            all_batches.iter().copied(),
+        )
+        .unwrap();
+        assert_eq!(concatenated.num_rows(), num_rows as usize);
+        assert_eq!(
+            concatenated.column(0).as_ref(),
+            input_batch.column(0).as_ref()
+        );
+
+        // After the first batch, subsequent batches should be closer to the
+        // target. The ideal is 50 rows/batch.
+        assert!(
+            batches.len() >= 2,
+            "need at least 2 batches to test convergence"
+        );
+        // The first batch uses the schema estimate (64 bytes/row) →
+        // ~78 rows. After feedback the rows should settle near 50.
+        if batches.len() >= 3 {
+            let second_batch_rows = batches[1].num_rows();
+            let third_batch_rows = batches[2].num_rows();
+            // Both should be within 20% of the ideal (50 rows)
+            assert!(
+                (40..=60).contains(&second_batch_rows),
+                "second batch should be near 50 rows, got {second_batch_rows}"
+            );
+            assert!(
+                (40..=60).contains(&third_batch_rows),
+                "third batch should be near 50 rows, got {third_batch_rows}"
+            );
+        }
+    }
 }
