Commit 0be5982

perf: sort-merge join (SMJ) batch deferred filtering and move mark joins to bitwise stream. Near-unique LEFT and FULL SMJ 20-50x faster (#21184)
## Which issue does this PR close?

Partially addresses #20910. Fixes #21197.

## Rationale for this change

Sort-merge join with a filter on outer joins (LEFT/RIGHT/FULL) runs `process_filtered_batches()` on every key transition in the Init state. With near-unique keys (1:1 cardinality), this means running the full deferred filtering pipeline (concat + `get_corrected_filter_mask` + `filter_record_batch_by_join_type`) once per row — making filtered LEFT/RIGHT/FULL **55x slower** than INNER for 10M unique keys.

Additionally, mark join logic in `MaterializingSortMergeJoinStream` materializes full `(streamed, buffered)` pairs only to discard most of them via `get_corrected_filter_mask()`. Mark joins are structurally identical to semi joins (one output row per outer row with a boolean result) and belong in `BitwiseSortMergeJoinStream`, which avoids pair materialization entirely using a per-outer-batch bitset.

## What changes are included in this PR?

Three areas of improvement, building on the specialized semi/anti stream from #20806:

**1. Move mark joins to `BitwiseSortMergeJoinStream`**

- Match on join type; `emit_outer_batch()` emits all rows with the match bitset as a boolean column (vs semi's filter / anti's invert-and-filter)
- Route `LeftMark`/`RightMark` from `SortMergeJoinExec::execute()` to the bitwise stream
- Remove all mark-specific logic from `MaterializingSortMergeJoinStream` (`mark_row_as_match`, `is_not_null` column generation, mark arms in filter correction)

**2. Batch filter evaluation in `freeze_streamed()`**

- Split `freeze_streamed()` into null-joined classification + `freeze_streamed_matched()` for batched materialization
- Collect indices across chunks, materialize left/right columns once using tiered Arrow kernels (`slice` → `take` → `interleave`)
- Single `RecordBatch` construction and single `expression.evaluate()` per freeze instead of per chunk
- Vectorize `append_filter_metadata()` using builder `extend()` instead of per-element loop
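The per-outer-batch bitset behind change 1 can be sketched independently of the DataFusion types. `OuterBatch` and its methods below are illustrative stand-ins, not the actual stream's API — the point is that a mark join needs only one boolean per outer row, never the materialized `(streamed, buffered)` pairs:

```rust
// Minimal sketch of the per-outer-batch bitset used for mark joins:
// keep one bool per outer row and OR in matches as buffered rows
// stream past, then emit the bitset itself as the boolean mark column.
// (A semi join would filter by the bitset; an anti join would invert
// it first.) All names here are illustrative, not DataFusion's API.
#[derive(Debug)]
struct OuterBatch {
    keys: Vec<i64>,
    matched: Vec<bool>, // the bitset: one bit per outer row
}

impl OuterBatch {
    fn new(keys: Vec<i64>) -> Self {
        let n = keys.len();
        Self { keys, matched: vec![false; n] }
    }

    /// Record that outer row `i` found at least one passing buffered row.
    fn mark(&mut self, i: usize) {
        self.matched[i] = true;
    }

    /// Mark-join output: every outer row plus its match bit.
    fn emit_mark(&self) -> Vec<(i64, bool)> {
        self.keys.iter().zip(&self.matched).map(|(k, m)| (*k, *m)).collect()
    }
}

fn main() {
    let mut outer = OuterBatch::new(vec![10, 20, 30]);
    outer.mark(0); // rows 0 and 2 matched some buffered row
    outer.mark(2);
    println!("{:?}", outer.emit_mark()); // [(10, true), (20, false), (30, true)]
}
```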
**3. Batch deferred filtering in Init state** (this is the big win for Q22 and Q23)

- Gate `process_filtered_batches()` on accumulated rows >= `batch_size` instead of running on every Init entry
- Accumulated data bounded to ~2×batch_size (one from `freeze_dequeuing_buffered`, one accumulating toward the next freeze) — does not reintroduce the unbounded buffering fixed by PR #20482
- `Exhausted` state flushes any remainder

**Cleanup:**

- Rename `SortMergeJoinStream` → `MaterializingSortMergeJoinStream` (materializes explicit row pairs for join output) and `SemiAntiMarkSortMergeJoinStream` → `BitwiseSortMergeJoinStream` (tracks matches via boolean bitset)
- Consolidate `semi_anti_mark_sort_merge_join/` into `sort_merge_join/` as `bitwise_stream.rs` / `bitwise_tests.rs`; rename `stream.rs` → `materializing_stream.rs` and `tests.rs` → `materializing_tests.rs`
- Consolidate `SpillManager` construction into `SortMergeJoinExec::execute()` (shared across both streams); move the `peak_mem_used` gauge into `BitwiseSortMergeJoinStream::try_new`
- `MaterializingSortMergeJoinStream` now handles only Inner/Left/Right/Full — all semi/anti/mark branching removed
- `get_corrected_filter_mask()`: merge identical Left/Right/Full branches; add null-metadata passthrough for already-null-joined rows
- `filter_record_batch_by_join_type()`: rewrite from `filter(true) + filter(false) + concat` to `zip()` for in-place null-joining — preserves row ordering and removes `create_null_joined_batch()` entirely; add an early return for empty batches
- `filter_record_batch_by_join_type()`: use `compute::filter()` directly on `BooleanArray` instead of wrapping it in a temporary `RecordBatch`

## Benchmarks

`cargo run --release --bin dfbench -- smj`

| Query | Join Type | Rows | Keys | Filter | Main (ms) | PR (ms) | Speedup |
|-------|-----------|------|------|--------|-----------|---------|---------|
| Q1 | INNER | 1M×1M | 1:1 | — | 16.3 | 14.4 | 1.1x |
| Q2 | INNER | 1M×10M | 1:10 | — | 117.4 | 120.1 | 1.0x |
| Q3 | INNER | 1M×1M | 1:100 | — | 74.2 | 66.6 | 1.1x |
| Q4 | INNER | 1M×10M | 1:10 | 1% | 17.1 | 15.1 | 1.1x |
| Q5 | INNER | 1M×1M | 1:100 | 10% | 18.4 | 14.4 | 1.3x |
| Q6 | LEFT | 1M×10M | 1:10 | — | 129.3 | 122.7 | 1.1x |
| Q7 | LEFT | 1M×10M | 1:10 | 50% | 150.2 | 142.2 | 1.1x |
| Q8 | FULL | 1M×1M | 1:10 | — | 16.6 | 16.7 | 1.0x |
| Q9 | FULL | 1M×10M | 1:10 | 10% | 153.5 | 136.2 | 1.1x |
| Q10 | LEFT SEMI | 1M×10M | 1:10 | — | 53.1 | 53.1 | 1.0x |
| Q11 | LEFT SEMI | 1M×10M | 1:10 | 1% | 15.5 | 14.7 | 1.1x |
| Q12 | LEFT SEMI | 1M×10M | 1:10 | 50% | 65.0 | 67.3 | 1.0x |
| Q13 | LEFT SEMI | 1M×10M | 1:10 | 90% | 105.7 | 109.8 | 1.0x |
| Q14 | LEFT ANTI | 1M×10M | 1:10 | — | 54.3 | 53.9 | 1.0x |
| Q15 | LEFT ANTI | 1M×10M | 1:10 | partial | 51.5 | 50.5 | 1.0x |
| Q16 | LEFT ANTI | 1M×1M | 1:1 | — | 10.3 | 11.3 | 0.9x |
| Q17 | INNER | 1M×50M | 1:50 | 5% | 75.9 | 79.0 | 1.0x |
| Q18 | LEFT SEMI | 1M×50M | 1:50 | 2% | 50.2 | 49.0 | 1.0x |
| Q19 | LEFT ANTI | 1M×50M | 1:50 | partial | 336.4 | 344.2 | 1.0x |
| Q20 | INNER | 1M×10M | 1:100 | GROUP BY | 763.7 | 803.9 | 1.0x |
| Q21 | INNER | 10M×10M | 1:1 | 50% | 186.1 | 187.8 | 1.0x |
| Q22 | LEFT | 10M×10M | 1:1 | 50% | 10,193.8 | 185.8 | **54.9x** |
| Q23 | FULL | 10M×10M | 1:1 | 50% | 10,194.7 | 233.6 | **43.6x** |
| Q24 | LEFT MARK | 1M×10M | 1:10 | 1% | FAILS | 15.1 | — |
| Q25 | LEFT MARK | 1M×10M | 1:10 | 50% | FAILS | 67.3 | — |
| Q26 | LEFT MARK | 1M×10M | 1:10 | 90% | FAILS | 110.0 | — |

General workload (Q1-Q20, various join types/cardinalities/selectivities): no regressions.

## Are these changes tested?

In addition to existing unit and sqllogictests:

- I ran 50 iterations of the fuzz tests (modified to test only against hash join as the baseline, because nested loop join takes too long): `cargo test -p datafusion --features extended_tests --test fuzz -- join_fuzz`
- One new sqllogictest for #21197 that fails on main
- Four new unit tests, three of which cover full join with a filter that spills
- One new fuzz test to exercise full join with a filter that spills
- New benchmark queries Q21-Q23: 10M×10M unique keys with a 50% join filter for INNER/LEFT/FULL — exercises the degenerate case this PR fixes
- New benchmark queries Q24-Q26: duplicates of Q11-Q13 but for mark joins, showing they have the same performance as other joins (`LeftSemi`) that use this stream

## Are there any user-facing changes?

No.
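The Init-state gate described in change 3 can be sketched independently of the stream machinery. This is a plain-Rust stand-in (all names below are illustrative, not DataFusion's): accumulate rows across key transitions and only run the expensive pipeline once `batch_size` rows are pending, with a final flush when input is exhausted.

```rust
// Sketch of the Init-state gate: instead of running the deferred
// filtering pipeline on every key transition, accumulate rows and
// flush only when at least `batch_size` rows are pending. Pending
// data stays bounded by roughly 2x batch_size between flushes.
struct DeferredFilter {
    batch_size: usize,
    pending: Vec<u32>, // stand-in for accumulated filtered-join rows
    flushes: usize,    // how many times the expensive pipeline ran
}

impl DeferredFilter {
    fn new(batch_size: usize) -> Self {
        Self { batch_size, pending: Vec::new(), flushes: 0 }
    }

    /// Called on each key transition in the Init state.
    fn on_key_transition(&mut self, rows: &[u32]) {
        self.pending.extend_from_slice(rows);
        // The gate: only run the concat + filter-correction pipeline
        // once enough rows have accumulated.
        if self.pending.len() >= self.batch_size {
            self.flush();
        }
    }

    /// Exhausted state: flush any remainder.
    fn finish(&mut self) {
        if !self.pending.is_empty() {
            self.flush();
        }
    }

    fn flush(&mut self) {
        self.flushes += 1; // the expensive pipeline would run here
        self.pending.clear();
    }
}

fn main() {
    let mut f = DeferredFilter::new(8192);
    // 10,000 near-unique keys: one row per key transition.
    for k in 0u32..10_000 {
        f.on_key_transition(&[k]);
    }
    f.finish();
    // One flush when 8192 rows accumulate plus one final flush,
    // instead of 10,000 pipeline runs on the ungated path.
    println!("flushes = {}", f.flushes);
}
```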
1 parent 010e5ee commit 0be5982

File tree

13 files changed: +2006 −1872 lines changed

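Before the file-by-file diff, the `zip()`-based null-joining rewrite noted in the cleanup list can be illustrated without Arrow. The old `filter(true) + filter(false) + concat` path split passing and failing rows apart and re-concatenated them, which reordered rows and required a separately built all-null batch; a zip-style select keeps every row in place and nulls the failing ones. This is a plain-Rust stand-in, not the Arrow `zip` kernel itself:

```rust
// Plain-Rust illustration of the zip()-style null-joining rewrite:
// select per row between the original value and null, guided by the
// filter mask. Row order is preserved by construction, so no concat
// or separate null-joined batch is needed.
fn null_join_zip(values: &[i64], mask: &[bool]) -> Vec<Option<i64>> {
    values
        .iter()
        .zip(mask)
        .map(|(v, keep)| if *keep { Some(*v) } else { None })
        .collect()
}

fn main() {
    let values = [1, 2, 3, 4];
    let mask = [true, false, true, false]; // filter result per joined row
    println!("{:?}", null_join_zip(&values, &mask));
    // failing rows are null-joined in place, order unchanged
}
```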

benchmarks/src/smj.rs

Lines changed: 78 additions & 2 deletions
```diff
@@ -39,7 +39,7 @@ use futures::StreamExt;
 #[derive(Debug, Args, Clone)]
 #[command(verbatim_doc_comment)]
 pub struct RunOpt {
-    /// Query number (between 1 and 23). If not specified, runs all queries
+    /// Query number (between 1 and 26). If not specified, runs all queries
     #[arg(short, long)]
     query: Option<usize>,

@@ -456,6 +456,72 @@ const SMJ_QUERIES: &[&str] = &[
     ON t1_sorted.key = t2_sorted.key
     AND t1_sorted.data + t2_sorted.data < 10000000
     "#,
+    // Q24: LEFT MARK 1M x 10M | 1:10 | 1%
+    r#"
+    WITH t1_sorted AS (
+        SELECT value % 100000 as key, value as data
+        FROM range(1000000)
+        ORDER BY key, data
+    ),
+    t2_sorted AS (
+        SELECT value % 100000 as key, value as data
+        FROM range(10000000)
+        ORDER BY key, data
+    )
+    SELECT t1_sorted.key, t1_sorted.data
+    FROM t1_sorted
+    WHERE t1_sorted.data < 0
+        OR EXISTS (
+            SELECT 1 FROM t2_sorted
+            WHERE t2_sorted.key = t1_sorted.key
+                AND t2_sorted.data <> t1_sorted.data
+                AND t2_sorted.data % 100 = 0
+        )
+    "#,
+    // Q25: LEFT MARK 1M x 10M | 1:10 | 50%
+    r#"
+    WITH t1_sorted AS (
+        SELECT value % 100000 as key, value as data
+        FROM range(1000000)
+        ORDER BY key, data
+    ),
+    t2_sorted AS (
+        SELECT value % 100000 as key, value as data
+        FROM range(10000000)
+        ORDER BY key, data
+    )
+    SELECT t1_sorted.key, t1_sorted.data
+    FROM t1_sorted
+    WHERE t1_sorted.data < 0
+        OR EXISTS (
+            SELECT 1 FROM t2_sorted
+            WHERE t2_sorted.key = t1_sorted.key
+                AND t2_sorted.data <> t1_sorted.data
+                AND t2_sorted.data % 2 = 0
+        )
+    "#,
+    // Q26: LEFT MARK 1M x 10M | 1:10 | 90%
+    r#"
+    WITH t1_sorted AS (
+        SELECT value % 100000 as key, value as data
+        FROM range(1000000)
+        ORDER BY key, data
+    ),
+    t2_sorted AS (
+        SELECT value % 100000 as key, value as data
+        FROM range(10000000)
+        ORDER BY key, data
+    )
+    SELECT t1_sorted.key, t1_sorted.data
+    FROM t1_sorted
+    WHERE t1_sorted.data < 0
+        OR EXISTS (
+            SELECT 1 FROM t2_sorted
+            WHERE t2_sorted.key = t1_sorted.key
+                AND t2_sorted.data <> t1_sorted.data
+                AND t2_sorted.data % 10 <> 0
+        )
+    "#,
 ];

 impl RunOpt {

@@ -489,7 +555,10 @@ impl RunOpt {

         let sql = SMJ_QUERIES[query_index];
         benchmark_run.start_new_case(&format!("Query {query_id}"));
-        let query_run = self.benchmark_query(sql, &query_id.to_string(), &ctx).await;
+        let expect_mark = query_id >= 24;
+        let query_run = self
+            .benchmark_query(sql, &query_id.to_string(), expect_mark, &ctx)
+            .await;
         match query_run {
             Ok(query_results) => {
                 for iter in query_results {

@@ -513,6 +582,7 @@ impl RunOpt {
         &self,
         sql: &str,
         query_name: &str,
+        expect_mark: bool,
         ctx: &SessionContext,
     ) -> Result<Vec<QueryResult>> {
         let mut query_results = vec![];

@@ -528,6 +598,12 @@ impl RunOpt {
             ));
         }

+        if expect_mark && !plan_string.contains("LeftMark") {
+            return Err(exec_datafusion_err!(
+                "Query {query_name} expected LeftMark join. Physical plan: {plan_string}"
+            ));
+        }
+
         for i in 0..self.common.iterations {
             let start = Instant::now();
```

datafusion/core/tests/fuzz_cases/join_fuzz.rs

Lines changed: 135 additions & 0 deletions
```diff
@@ -38,6 +38,9 @@ use datafusion::physical_plan::joins::{
 };
 use datafusion::prelude::{SessionConfig, SessionContext};
 use datafusion_common::{NullEquality, ScalarValue};
+use datafusion_execution::TaskContext;
+use datafusion_execution::disk_manager::{DiskManagerBuilder, DiskManagerMode};
+use datafusion_execution::runtime_env::RuntimeEnvBuilder;
 use datafusion_physical_expr::PhysicalExprRef;
 use datafusion_physical_expr::expressions::Literal;

@@ -1125,6 +1128,138 @@ impl JoinFuzzTestCase {
     }
 }

+/// Fuzz test: compare SMJ (with spilling) against HJ (no spill) for filtered
+/// outer joins under memory pressure. This exercises the deferred filtering +
+/// spill read-back path that unit tests can't easily cover with random data.
+#[tokio::test]
+async fn test_filtered_join_spill_fuzz() {
+    let join_types = [JoinType::Left, JoinType::Right, JoinType::Full];
+
+    let runtime_spill = RuntimeEnvBuilder::new()
+        .with_memory_limit(4096, 1.0)
+        .with_disk_manager_builder(
+            DiskManagerBuilder::default().with_mode(DiskManagerMode::OsTmpDirectory),
+        )
+        .build_arc()
+        .unwrap();
+
+    for join_type in &join_types {
+        for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] {
+            let input1 = make_staggered_batches_i32(1000, left_extra);
+            let input2 = make_staggered_batches_i32(1000, right_extra);
+
+            let schema1 = input1[0].schema();
+            let schema2 = input2[0].schema();
+            let filter = col_lt_col_filter(schema1.clone(), schema2.clone());
+
+            let on = vec![
+                (
+                    Arc::new(Column::new_with_schema("a", &schema1).unwrap()) as _,
+                    Arc::new(Column::new_with_schema("a", &schema2).unwrap()) as _,
+                ),
+                (
+                    Arc::new(Column::new_with_schema("b", &schema1).unwrap()) as _,
+                    Arc::new(Column::new_with_schema("b", &schema2).unwrap()) as _,
+                ),
+            ];
+
+            for batch_size in [2, 49, 100] {
+                let session_config = SessionConfig::new().with_batch_size(batch_size);
+
+                // HJ baseline (no memory limit)
+                let left_hj = MemorySourceConfig::try_new_exec(
+                    std::slice::from_ref(&input1),
+                    schema1.clone(),
+                    None,
+                )
+                .unwrap();
+                let right_hj = MemorySourceConfig::try_new_exec(
+                    std::slice::from_ref(&input2),
+                    schema2.clone(),
+                    None,
+                )
+                .unwrap();
+                let hj = Arc::new(
+                    HashJoinExec::try_new(
+                        left_hj,
+                        right_hj,
+                        on.clone(),
+                        Some(filter.clone()),
+                        join_type,
+                        None,
+                        PartitionMode::Partitioned,
+                        NullEquality::NullEqualsNothing,
+                        false,
+                    )
+                    .unwrap(),
+                );
+                let ctx_hj = SessionContext::new_with_config(session_config.clone());
+                let hj_collected = collect(hj, ctx_hj.task_ctx()).await.unwrap();
+
+                // SMJ with spilling
+                let left_smj = MemorySourceConfig::try_new_exec(
+                    std::slice::from_ref(&input1),
+                    schema1.clone(),
+                    None,
+                )
+                .unwrap();
+                let right_smj = MemorySourceConfig::try_new_exec(
+                    std::slice::from_ref(&input2),
+                    schema2.clone(),
+                    None,
+                )
+                .unwrap();
+                let smj = Arc::new(
+                    SortMergeJoinExec::try_new(
+                        left_smj,
+                        right_smj,
+                        on.clone(),
+                        Some(filter.clone()),
+                        *join_type,
+                        vec![SortOptions::default(); on.len()],
+                        NullEquality::NullEqualsNothing,
+                    )
+                    .unwrap(),
+                );
+                let task_ctx_spill = Arc::new(
+                    TaskContext::default()
+                        .with_session_config(session_config)
+                        .with_runtime(Arc::clone(&runtime_spill)),
+                );
+                let smj_collected = collect(smj, task_ctx_spill).await.unwrap();
+
+                let hj_rows: usize = hj_collected.iter().map(|b| b.num_rows()).sum();
+                let smj_rows: usize = smj_collected.iter().map(|b| b.num_rows()).sum();
+
+                assert_eq!(
+                    hj_rows, smj_rows,
+                    "Row count mismatch for {join_type:?} batch_size={batch_size} \
+                     left_extra={left_extra} right_extra={right_extra}: \
+                     HJ={hj_rows} SMJ={smj_rows}"
+                );
+
+                if hj_rows > 0 {
+                    let hj_fmt =
+                        pretty_format_batches(&hj_collected).unwrap().to_string();
+                    let smj_fmt =
+                        pretty_format_batches(&smj_collected).unwrap().to_string();
+
+                    let mut hj_sorted: Vec<&str> = hj_fmt.trim().lines().collect();
+                    hj_sorted.sort_unstable();
+                    let mut smj_sorted: Vec<&str> = smj_fmt.trim().lines().collect();
+                    smj_sorted.sort_unstable();
+
+                    assert_eq!(
+                        hj_sorted, smj_sorted,
+                        "Content mismatch for {join_type:?} batch_size={batch_size} \
+                         left_extra={left_extra} right_extra={right_extra}"
+                    );
+                }
+            }
+        }
+    }
+}
+
 /// Return randomly sized record batches with:
 /// two sorted int32 columns 'a', 'b' ranged from 0..99 as join columns
 /// two random int32 columns 'x', 'y' as other columns
```

datafusion/physical-plan/src/joins/mod.rs

Lines changed: 0 additions & 1 deletion
```diff
@@ -34,7 +34,6 @@ mod cross_join;
 mod hash_join;
 mod nested_loop_join;
 mod piecewise_merge_join;
-pub(crate) mod semi_anti_sort_merge_join;
 mod sort_merge_join;
 mod stream_join_utils;
 mod symmetric_hash_join;
```

datafusion/physical-plan/src/joins/semi_anti_sort_merge_join/mod.rs

Lines changed: 0 additions & 25 deletions
This file was deleted.
