Skip to content

Commit 4feba22

Browse files
Dandandan and claude committed
Add batch pass-through optimization to SortPreservingMergeExec
When the loser tree winner's entire remaining batch is strictly less than every other stream's current value, skip per-row loser-tree comparisons and emit the batch directly.

Two fast paths:
- Zero-copy: when the in_progress buffer is empty and the full batch qualifies, slice and return the RecordBatch without interleave.
- Bulk-push: otherwise append all qualifying rows at once, avoiding O(remaining × log K) loser-tree work.

The runner-up is found by walking the winner's loser-tree path (O(log K)), and the check is only performed at the start of each new batch to amortise cost.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 1e93a67 commit 4feba22

4 files changed

Lines changed: 332 additions & 0 deletions

File tree

datafusion/physical-plan/src/sorts/builder.rs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,34 @@ impl BatchBuilder {
116116
self.indices.push((cursor.batch_idx, row_idx));
117117
}
118118

119+
/// Append `count` consecutive rows from `stream_idx`
120+
pub fn push_rows(&mut self, stream_idx: usize, count: usize) {
121+
let cursor = &mut self.cursors[stream_idx];
122+
let batch_idx = cursor.batch_idx;
123+
let start_row = cursor.row_idx;
124+
self.indices
125+
.extend((0..count).map(|i| (batch_idx, start_row + i)));
126+
cursor.row_idx += count;
127+
}
128+
129+
/// Slice the current batch for `stream_idx` starting at its cursor
130+
/// position, returning `num_rows` rows as a zero-copy [`RecordBatch`].
131+
///
132+
/// Advances the builder's cursor but does **not** touch `self.indices`,
133+
/// so the caller must not also call `push_row`/`push_rows` for these
134+
/// rows.
135+
pub fn take_batch_slice(
136+
&mut self,
137+
stream_idx: usize,
138+
num_rows: usize,
139+
) -> RecordBatch {
140+
let cursor = &mut self.cursors[stream_idx];
141+
let (_, batch) = &self.batches[cursor.batch_idx];
142+
let sliced = batch.slice(cursor.row_idx, num_rows);
143+
cursor.row_idx += num_rows;
144+
sliced
145+
}
146+
119147
/// Returns the number of in-progress rows in this [`BatchBuilder`]
120148
pub fn len(&self) -> usize {
121149
self.indices.len()

datafusion/physical-plan/src/sorts/cursor.rs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,13 +93,41 @@ impl<T: CursorValues> Cursor<T> {
9393
self.offset == self.values.len()
9494
}
9595

96+
/// Returns true if the cursor is at the start (offset 0)
97+
pub fn is_at_start(&self) -> bool {
98+
self.offset == 0
99+
}
100+
101+
/// Returns the number of remaining rows (including current position)
102+
pub fn remaining(&self) -> usize {
103+
self.values.len() - self.offset
104+
}
105+
96106
/// Advance the cursor, returning the previous row index
97107
pub fn advance(&mut self) -> usize {
98108
let t = self.offset;
99109
self.offset += 1;
100110
t
101111
}
102112

113+
/// Advance the cursor by `n` positions
114+
pub fn advance_by(&mut self, n: usize) {
115+
self.offset += n;
116+
}
117+
118+
/// Compare the last value in this cursor with the current value of `other`.
119+
///
120+
/// Returns [`Ordering::Less`] if the last value of this cursor comes
121+
/// before `other`'s current value in sort order.
122+
pub fn last_cmp(&self, other: &Self) -> Ordering {
123+
T::compare(
124+
&self.values,
125+
self.values.len() - 1,
126+
&other.values,
127+
other.offset,
128+
)
129+
}
130+
103131
pub fn is_eq_to_prev_one(&self, prev_cursor: Option<&Cursor<T>>) -> bool {
104132
if self.offset > 0 {
105133
self.is_eq_to_prev_row()

datafusion/physical-plan/src/sorts/merge.rs

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,72 @@ impl<C: CursorValues> SortPreservingMergeStream<C> {
302302
}
303303

304304
let stream_idx = self.loser_tree[0];
305+
306+
// Batch pass-through: when the winner's entire remaining
307+
// batch is strictly less than every other stream's current
308+
// value we can skip per-row loser-tree comparisons.
309+
// Only check at the start of a new batch to amortise the
310+
// O(log K) runner-up lookup.
311+
if self.cursors[stream_idx]
312+
.as_ref()
313+
.is_some_and(|c| c.is_at_start())
314+
&& self.can_batch_pass_through(stream_idx)
315+
{
316+
let remaining = self.cursors[stream_idx].as_ref().unwrap().remaining();
317+
let space_in_batch =
318+
self.batch_size.saturating_sub(self.in_progress.len());
319+
let fetch_remaining = self
320+
.fetch
321+
.map(|f| f.saturating_sub(self.produced + self.in_progress.len()))
322+
.unwrap_or(usize::MAX);
323+
let rows_to_add = remaining.min(space_in_batch).min(fetch_remaining);
324+
325+
if rows_to_add > 0 {
326+
// Zero-copy fast path: emit a batch slice directly when
327+
// the in-progress buffer is empty and we can take the
328+
// entire remaining batch.
329+
if self.in_progress.is_empty() && rows_to_add == remaining {
330+
let batch =
331+
self.in_progress.take_batch_slice(stream_idx, rows_to_add);
332+
self.produced += rows_to_add;
333+
334+
let cursor = self.cursors[stream_idx].as_mut().unwrap();
335+
cursor.advance_by(rows_to_add);
336+
if cursor.is_finished() {
337+
self.prev_cursors[stream_idx] =
338+
self.cursors[stream_idx].take();
339+
}
340+
self.loser_tree_adjusted = false;
341+
342+
if self.fetch_reached() {
343+
self.done = true;
344+
}
345+
return Poll::Ready(Some(Ok(batch)));
346+
}
347+
348+
// Bulk-push path: append all qualifying rows at once,
349+
// avoiding per-row loser-tree work.
350+
self.in_progress.push_rows(stream_idx, rows_to_add);
351+
352+
let cursor = self.cursors[stream_idx].as_mut().unwrap();
353+
cursor.advance_by(rows_to_add);
354+
if cursor.is_finished() {
355+
self.prev_cursors[stream_idx] = self.cursors[stream_idx].take();
356+
}
357+
self.loser_tree_adjusted = false;
358+
359+
if self.fetch_reached() {
360+
self.done = true;
361+
self.drain_in_progress_on_done = true;
362+
} else if self.in_progress.len() < self.batch_size {
363+
continue;
364+
}
365+
366+
return Poll::Ready(self.emit_in_progress_batch().transpose());
367+
}
368+
}
369+
370+
// Normal row-by-row path
305371
if self.advance_cursors(stream_idx) {
306372
self.loser_tree_adjusted = false;
307373
self.in_progress.push_row(stream_idx);
@@ -341,6 +407,48 @@ impl<C: CursorValues> SortPreservingMergeStream<C> {
341407
}
342408
}
343409

410+
/// Walk the loser tree to find the runner-up (second-smallest current
411+
/// value). This is the minimum of the losers along the winner's path
412+
/// from leaf to root. Cost: O(log K).
413+
fn find_runner_up(&self) -> Option<usize> {
414+
let winner = self.loser_tree[0];
415+
let num_streams = self.cursors.len();
416+
let mut runner_up: Option<usize> = None;
417+
418+
let mut node = self.lt_leaf_node_index(winner);
419+
while node != 0 {
420+
let loser = self.loser_tree[node];
421+
if loser < num_streams && self.cursors[loser].is_some() {
422+
runner_up = Some(match runner_up {
423+
None => loser,
424+
Some(current) if self.is_gt(current, loser) => loser,
425+
Some(current) => current,
426+
});
427+
}
428+
node = self.lt_parent_node_index(node);
429+
}
430+
runner_up
431+
}
432+
433+
/// Returns `true` when the winner's entire remaining batch is strictly
434+
/// less than every other stream's current value, meaning those rows can
435+
/// be emitted without per-row loser-tree comparisons.
436+
fn can_batch_pass_through(&self, winner: usize) -> bool {
437+
let winner_cursor = match &self.cursors[winner] {
438+
Some(c) if c.remaining() > 1 => c,
439+
_ => return false,
440+
};
441+
442+
match self.find_runner_up() {
443+
// All other streams exhausted — pass through is safe
444+
None => true,
445+
Some(runner_up) => {
446+
let runner_up_cursor = self.cursors[runner_up].as_ref().unwrap();
447+
winner_cursor.last_cmp(runner_up_cursor).is_lt()
448+
}
449+
}
450+
}
451+
344452
fn fetch_reached(&mut self) -> bool {
345453
self.fetch
346454
.map(|fetch| self.produced + self.in_progress.len() >= fetch)

datafusion/physical-plan/src/sorts/sort_preserving_merge.rs

Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1595,4 +1595,172 @@ mod tests {
15951595

15961596
Ok(())
15971597
}
1598+
1599+
async fn _test_merge_sort_by_b(
1600+
partitions: &[Vec<RecordBatch>],
1601+
exp: &[&str],
1602+
context: Arc<TaskContext>,
1603+
) {
1604+
let schema = partitions[0][0].schema();
1605+
let sort: LexOrdering = [PhysicalSortExpr {
1606+
expr: col("b", &schema).unwrap(),
1607+
options: Default::default(),
1608+
}]
1609+
.into();
1610+
let exec = TestMemoryExec::try_new_exec(partitions, schema, None).unwrap();
1611+
let merge = Arc::new(SortPreservingMergeExec::new(sort, exec));
1612+
let collected = collect(merge, context).await.unwrap();
1613+
assert_batches_eq!(exp, collected.as_slice());
1614+
}
1615+
1616+
/// Test batch pass-through with multiple non-overlapping batches per
1617+
/// partition, ensuring cursor advancement and batch cleanup work.
1618+
#[tokio::test]
1619+
async fn test_batch_pass_through_multi_batch() {
1620+
let task_ctx = Arc::new(TaskContext::default());
1621+
1622+
// Partition 0: two batches [a, b] then [c, d]
1623+
let b0a = RecordBatch::try_from_iter(vec![
1624+
("a", Arc::new(Int32Array::from(vec![1, 2])) as ArrayRef),
1625+
("b", Arc::new(StringArray::from(vec!["a", "b"])) as ArrayRef),
1626+
])
1627+
.unwrap();
1628+
let b0b = RecordBatch::try_from_iter(vec![
1629+
("a", Arc::new(Int32Array::from(vec![3, 4])) as ArrayRef),
1630+
("b", Arc::new(StringArray::from(vec!["c", "d"])) as ArrayRef),
1631+
])
1632+
.unwrap();
1633+
1634+
// Partition 1: two batches [e, f] then [g, h]
1635+
let b1a = RecordBatch::try_from_iter(vec![
1636+
("a", Arc::new(Int32Array::from(vec![5, 6])) as ArrayRef),
1637+
("b", Arc::new(StringArray::from(vec!["e", "f"])) as ArrayRef),
1638+
])
1639+
.unwrap();
1640+
let b1b = RecordBatch::try_from_iter(vec![
1641+
("a", Arc::new(Int32Array::from(vec![7, 8])) as ArrayRef),
1642+
("b", Arc::new(StringArray::from(vec!["g", "h"])) as ArrayRef),
1643+
])
1644+
.unwrap();
1645+
1646+
_test_merge_sort_by_b(
1647+
&[vec![b0a, b0b], vec![b1a, b1b]],
1648+
&[
1649+
"+---+---+",
1650+
"| a | b |",
1651+
"+---+---+",
1652+
"| 1 | a |",
1653+
"| 2 | b |",
1654+
"| 3 | c |",
1655+
"| 4 | d |",
1656+
"| 5 | e |",
1657+
"| 6 | f |",
1658+
"| 7 | g |",
1659+
"| 8 | h |",
1660+
"+---+---+",
1661+
],
1662+
task_ctx,
1663+
)
1664+
.await;
1665+
}
1666+
1667+
/// Test batch pass-through with a fetch limit that cuts through a
1668+
/// pass-through batch.
1669+
#[tokio::test]
1670+
async fn test_batch_pass_through_with_fetch() -> Result<()> {
1671+
let schema = Arc::new(Schema::new(vec![
1672+
Field::new("a", DataType::Int32, false),
1673+
Field::new("b", DataType::Utf8, false),
1674+
]));
1675+
1676+
// Partition 0: [a, b, c]
1677+
let b0 = RecordBatch::try_new(
1678+
Arc::clone(&schema),
1679+
vec![
1680+
Arc::new(Int32Array::from(vec![1, 2, 3])),
1681+
Arc::new(StringArray::from(vec!["a", "b", "c"])),
1682+
],
1683+
)?;
1684+
1685+
// Partition 1: [x, y, z] — completely non-overlapping
1686+
let b1 = RecordBatch::try_new(
1687+
Arc::clone(&schema),
1688+
vec![
1689+
Arc::new(Int32Array::from(vec![4, 5, 6])),
1690+
Arc::new(StringArray::from(vec!["x", "y", "z"])),
1691+
],
1692+
)?;
1693+
1694+
let task_ctx = Arc::new(TaskContext::default());
1695+
let sort: LexOrdering = [PhysicalSortExpr {
1696+
expr: col("b", &schema)?,
1697+
options: Default::default(),
1698+
}]
1699+
.into();
1700+
let exec = TestMemoryExec::try_new_exec(&[vec![b0], vec![b1]], schema, None)?;
1701+
let merge =
1702+
Arc::new(SortPreservingMergeExec::new(sort, exec).with_fetch(Some(4)));
1703+
1704+
let collected = collect(merge, task_ctx).await?;
1705+
assert_batches_eq!(
1706+
[
1707+
"+---+---+",
1708+
"| a | b |",
1709+
"+---+---+",
1710+
"| 1 | a |",
1711+
"| 2 | b |",
1712+
"| 3 | c |",
1713+
"| 4 | x |",
1714+
"+---+---+",
1715+
],
1716+
collected.as_slice()
1717+
);
1718+
Ok(())
1719+
}
1720+
1721+
/// Test that the merge is still correct when batches partially overlap
1722+
/// (only some partitions qualify for pass-through).
1723+
#[tokio::test]
1724+
async fn test_batch_pass_through_partial_overlap() {
1725+
let task_ctx = Arc::new(TaskContext::default());
1726+
1727+
// Partition 0: [a, b] — non-overlapping with partition 2
1728+
let b0 = RecordBatch::try_from_iter(vec![
1729+
("a", Arc::new(Int32Array::from(vec![1, 2])) as ArrayRef),
1730+
("b", Arc::new(StringArray::from(vec!["a", "b"])) as ArrayRef),
1731+
])
1732+
.unwrap();
1733+
1734+
// Partition 1: [b, d] — overlaps with partition 0 at "b"
1735+
let b1 = RecordBatch::try_from_iter(vec![
1736+
("a", Arc::new(Int32Array::from(vec![3, 4])) as ArrayRef),
1737+
("b", Arc::new(StringArray::from(vec!["b", "d"])) as ArrayRef),
1738+
])
1739+
.unwrap();
1740+
1741+
// Partition 2: [f, g] — non-overlapping with everything
1742+
let b2 = RecordBatch::try_from_iter(vec![
1743+
("a", Arc::new(Int32Array::from(vec![5, 6])) as ArrayRef),
1744+
("b", Arc::new(StringArray::from(vec!["f", "g"])) as ArrayRef),
1745+
])
1746+
.unwrap();
1747+
1748+
_test_merge_sort_by_b(
1749+
&[vec![b0], vec![b1], vec![b2]],
1750+
&[
1751+
"+---+---+",
1752+
"| a | b |",
1753+
"+---+---+",
1754+
"| 1 | a |",
1755+
"| 2 | b |",
1756+
"| 3 | b |",
1757+
"| 4 | d |",
1758+
"| 5 | f |",
1759+
"| 6 | g |",
1760+
"+---+---+",
1761+
],
1762+
task_ctx,
1763+
)
1764+
.await;
1765+
}
15981766
}

0 commit comments

Comments
 (0)