Commit 78033fa

refactor: introduce ProbeEnd state in NestedLoopJoinExec (#22865)
## Which issue does this PR close?

- Closes #22808.

## Rationale for this change

Follow-up to #22791, as suggested in review by @2010YOUY01.

That PR fixed a double-decrement bug where `EmitLeftUnmatched` did two jobs at once: deciding whether a partition emits unmatched-left rows (which decrements the shared `probe_threads_counter`) and performing the emit. Because the state is re-enterable (a ready batch can be flushed before the state advances to `Done`), the counter could be decremented twice, driving it to zero before all partitions finished probing and emitting spurious NULL-padded rows. #22791 patched this with a `probe_completed_reported` guard flag.

This refactor makes "decrement exactly once per probe stream" a structural property of the state graph rather than a runtime guard, so the inner logic is easier to follow and the bug is harder to reintroduce.

## What changes are included in this PR?

Restructures the state machine from `FetchingRight → EmitLeftUnmatched` to `FetchingRight → ProbeEnd → EmitLeftUnmatched`:

- Adds a dedicated `ProbeEnd` state, entered exactly once per left chunk when the right side is exhausted. It owns the single `report_probe_completed()` call and records whether this stream is the unmatched-left emitter.
- Replaces the `probe_completed_reported` guard flag with an `is_unmatched_left_emitter` field that `EmitLeftUnmatched` only reads.
- Removes the per-chunk flag reset in the memory-limited path (the decision is recomputed in `ProbeEnd` for each chunk) and reverts the `Arc::clone` workaround #22791 needed in `process_left_unmatched`.
- Updates the state-transition doc graph and arm comments.

No behavior change is expected.

## Are these changes tested?

Yes, covered by existing tests:

- All 42 `nested_loop_join` unit tests and the full `datafusion-physical-plan` suite pass.
- `joins.slt` sqllogictests pass (including the multi-partition LEFT JOIN regression test added in #22791).
- 41 `join_fuzz` tests (`cargo test --features extended_tests`) comparing `NestedLoopJoinExec` against `HashJoinExec` across every join type, filtered and unfiltered, with a multi-partition probe side (the exact scenario class of the original bug) pass.
- `cargo fmt` and `cargo clippy --all-targets --all-features -- -D warnings` are clean.

## Are there any user-facing changes?

No.
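
To make the failure mode in the rationale concrete, here is a minimal sketch of a "last one to finish" counter. `ProbeCounter` and `double_decrement_elects_early` are hypothetical names used only for illustration, and the `fetch_sub(..) == 1` body is an assumption about how such a counter is commonly written, not a copy of `JoinLeftData::report_probe_completed`.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};

/// Hypothetical stand-in for the shared probe-threads counter.
pub struct ProbeCounter {
    remaining: AtomicUsize,
}

impl ProbeCounter {
    /// `partitions` is the number of probe streams sharing the left side.
    pub fn new(partitions: usize) -> Self {
        Self {
            remaining: AtomicUsize::new(partitions),
        }
    }

    /// Decrements the counter and returns `true` only for the call that
    /// drives it to zero, i.e. the caller that should emit unmatched-left
    /// rows. Assumed shape; each stream must call this exactly once.
    pub fn report_probe_completed(&self) -> bool {
        self.remaining.fetch_sub(1, Ordering::AcqRel) == 1
    }
}

/// Simulates the bug class: with two partitions, partition A reports
/// completion twice while partition B is still probing.
pub fn double_decrement_elects_early() -> bool {
    let counter = ProbeCounter::new(2);
    // First (legitimate) report from A: counter goes 2 -> 1, not the emitter.
    let _first = counter.report_probe_completed();
    // Second (spurious) report from A: counter goes 1 -> 0, so A is elected
    // emitter even though B has not finished probing.
    counter.report_probe_completed()
}
```

In this sketch the second, spurious call is indistinguishable from a genuine "last stream finished" report, which is why the refactor aims to make a second call structurally impossible rather than guarding against it at runtime.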
1 parent a7280b8 commit 78033fa

1 file changed

Lines changed: 96 additions & 45 deletions

datafusion/physical-plan/src/joins/nested_loop_join.rs

@@ -874,6 +874,16 @@ enum NLJState {
     FetchingRight,
     ProbeRight,
     EmitRightUnmatched,
+    /// Entered exactly once per left chunk, when the probe (right) side is
+    /// exhausted and probing for the current chunk is finished. This state
+    /// owns the single [`JoinLeftData::report_probe_completed`] call that
+    /// decrements the shared probe-threads counter, and records in
+    /// `is_unmatched_left_emitter` whether this stream is the one responsible
+    /// for emitting unmatched-left rows. Splitting this decision out of
+    /// `EmitLeftUnmatched` makes "decrement exactly once" a structural
+    /// property of the state graph, so the (re-enterable) emit state no longer
+    /// has to guard against decrementing twice.
+    ProbeEnd,
     EmitLeftUnmatched,
     /// Emit unmatched right rows using the global bitmap accumulated across
     /// all left chunks. Only used in memory-limited mode for join types that
@@ -1065,16 +1075,17 @@ pub(crate) struct NestedLoopJoinStream {
     /// Memory-limited spill fallback state. See [`SpillState`] for details.
     spill_state: SpillState,
 
-    /// Whether this stream has already reported probe completion for the current
-    /// left chunk via [`JoinLeftData::report_probe_completed`]. The shared
-    /// probe-threads counter must be decremented exactly once per probe stream;
-    /// without this guard a stream that yields a ready batch while finishing the
-    /// `EmitLeftUnmatched` state (and is then re-polled with `left_emit_idx`
-    /// still 0) would decrement the counter twice, driving it to zero
-    /// prematurely and causing a sibling partition to emit unmatched-left rows
-    /// before all partitions finished probing (spurious NULL-padded rows).
-    /// Reset to `false` when starting a new left chunk in memory-limited mode.
-    probe_completed_reported: bool,
+    /// Whether this stream is the one responsible for emitting unmatched-left
+    /// rows for the current left chunk. Set in the [`NLJState::ProbeEnd`] state,
+    /// which is entered exactly once per chunk and owns the single
+    /// [`JoinLeftData::report_probe_completed`] call: the stream that drives the
+    /// shared probe-threads counter to zero (the last to finish probing) becomes
+    /// the emitter. Because the decrement happens once in `ProbeEnd` rather than
+    /// in the re-enterable `EmitLeftUnmatched` state, the counter can never be
+    /// decremented twice, so it cannot reach zero before all partitions finish
+    /// probing (which would otherwise let a partition emit spurious NULL-padded
+    /// unmatched-left rows early).
+    is_unmatched_left_emitter: bool,
 }
 
 pub(crate) struct NestedLoopJoinMetrics {
@@ -1118,14 +1129,17 @@ impl Stream for NestedLoopJoinStream {
     /// BufferingLeft → FetchingRight
     ///
     /// FetchingRight → ProbeRight (if right batch available)
-    /// FetchingRight → EmitLeftUnmatched (if right exhausted)
+    /// FetchingRight → ProbeEnd (if right exhausted)
     ///
     /// ProbeRight → ProbeRight (next left row or after yielding output)
     /// ProbeRight → EmitRightUnmatched (for special join types like right join)
     /// ProbeRight → FetchingRight (done with the current right batch)
     ///
     /// EmitRightUnmatched → FetchingRight
     ///
+    /// ProbeEnd → EmitLeftUnmatched (records whether this stream is the
+    /// unmatched-left emitter, then always continues to EmitLeftUnmatched)
+    ///
     /// EmitLeftUnmatched → EmitLeftUnmatched (only process 1 chunk for each
     /// iteration)
     /// EmitLeftUnmatched → Done (if finished)
@@ -1161,8 +1175,8 @@ impl Stream for NestedLoopJoinStream {
                 // 1. --> ProbeRight
                 // Start processing the join for the newly fetched right
                 // batch.
-                // 2. --> EmitLeftUnmatched: When the right side input is exhausted, (maybe) emit
-                // unmatched left side rows.
+                // 2. --> ProbeEnd: When the right side input is exhausted,
+                // probing for the current left chunk is finished.
                 //
                 // After fetching a new batch from the right side, it will
                 // process all rows from the buffered left data:
@@ -1176,9 +1190,10 @@ impl Stream for NestedLoopJoinStream {
                 // at once in memory.
                 //
                 // So after the right side input is exhausted, the join phase
-                // for the current buffered left data is finished. We can go to
-                // the next `EmitLeftUnmatched` phase to check if there is any
-                // special handling (e.g., in cases like left join).
+                // for the current buffered left data is finished. We go to the
+                // `ProbeEnd` state, which records probe completion before the
+                // `EmitLeftUnmatched` phase checks if there is any special
+                // handling (e.g., in cases like left join).
                 NLJState::FetchingRight => {
                     debug!("[NLJState] Entering: {:?}", self.state);
                     // stop on drop
@@ -1241,6 +1256,28 @@ impl Stream for NestedLoopJoinStream {
                     }
                 }
 
+                // NLJState transitions:
+                // 1. --> EmitLeftUnmatched
+                // Probing for the current left chunk is finished. Report
+                // probe completion exactly once (decrementing the shared
+                // probe-threads counter) and record whether this stream is
+                // the unmatched-left emitter, then always advance to
+                // `EmitLeftUnmatched`.
+                NLJState::ProbeEnd => {
+                    debug!("[NLJState] Entering: {:?}", self.state);
+
+                    // stop on drop
+                    let join_metric = self.metrics.join_metrics.join_time.clone();
+                    let _join_timer = join_metric.timer();
+
+                    match self.handle_probe_end() {
+                        ControlFlow::Continue(()) => continue,
+                        ControlFlow::Break(poll) => {
+                            return self.metrics.join_metrics.baseline.record_poll(poll);
+                        }
+                    }
+                }
+
                 // NLJState transitions:
                 // 1. --> EmitLeftUnmatched(1)
                 // If we have already buffered enough output to yield, it
@@ -1348,7 +1385,7 @@ impl NestedLoopJoinStream {
             handled_empty_output: false,
             should_track_unmatched_right: need_produce_right_in_final(join_type),
             spill_state,
-            probe_completed_reported: false,
+            is_unmatched_left_emitter: false,
         }
     }
 
@@ -1724,7 +1761,10 @@ impl NestedLoopJoinStream {
                 }
                 Some(Err(e)) => ControlFlow::Break(Poll::Ready(Some(Err(e)))),
                 None => {
-                    self.state = NLJState::EmitLeftUnmatched;
+                    // Right side exhausted: probing for the current left chunk
+                    // is finished. `ProbeEnd` reports probe completion before
+                    // emitting unmatched-left rows.
+                    self.state = NLJState::ProbeEnd;
                     ControlFlow::Continue(())
                 }
             },
@@ -1837,6 +1877,34 @@ impl NestedLoopJoinStream {
         }
     }
 
+    /// Handle ProbeEnd state - record probe completion for the current chunk.
+    ///
+    /// Entered exactly once per left chunk, when the right side is exhausted.
+    /// This is the single place that decrements the shared probe-threads counter
+    /// via [`JoinLeftData::report_probe_completed`]: the stream that drives the
+    /// counter to zero (the last to finish probing) is the one responsible for
+    /// emitting unmatched-left rows, recorded in `is_unmatched_left_emitter`.
+    ///
+    /// Owning the decrement here — rather than in the re-enterable
+    /// `EmitLeftUnmatched` state — makes "decrement exactly once per stream" a
+    /// structural property of the state graph, so the counter cannot reach zero
+    /// before all partitions finish probing (which would let a partition emit
+    /// spurious NULL-padded unmatched-left rows early).
+    ///
+    /// Always transitions to `EmitLeftUnmatched`.
+    fn handle_probe_end(&mut self) -> ControlFlow<Poll<Option<Result<RecordBatch>>>> {
+        // Decrement the shared counter exactly once for this stream/chunk. The
+        // last stream to finish probing (the one that drives the counter to
+        // zero) becomes the unmatched-left emitter.
+        let is_emitter = match self.get_left_data() {
+            Ok(left_data) => left_data.report_probe_completed(),
+            Err(e) => return ControlFlow::Break(Poll::Ready(Some(Err(e)))),
+        };
+        self.is_unmatched_left_emitter = is_emitter;
+        self.state = NLJState::EmitLeftUnmatched;
+        ControlFlow::Continue(())
+    }
+
     /// Handle EmitLeftUnmatched state - emit unmatched left rows.
     ///
     /// In memory-limited mode, after processing all unmatched rows for the
@@ -1876,9 +1944,9 @@ impl NestedLoopJoinStream {
                     self.left_probe_idx = 0;
                     self.left_emit_idx = 0;
                     // Each memory-limited chunk gets a fresh per-chunk
-                    // `JoinLeftData`/counter, so allow this stream to report
-                    // completion again for the next chunk.
-                    self.probe_completed_reported = false;
+                    // `JoinLeftData`/counter; `is_unmatched_left_emitter` is
+                    // recomputed when `ProbeEnd` is re-entered for the next
+                    // chunk, so it does not need to be reset here.
                     self.state = NLJState::BufferingLeft;
                 } else if self.is_memory_limited()
                     && self.should_track_unmatched_right
@@ -2357,9 +2425,7 @@ impl NestedLoopJoinStream {
     /// true -> continue in the same EmitLeftUnmatched state
     /// false -> next state (Done)
     fn process_left_unmatched(&mut self) -> Result<bool> {
-        // Clone the shared `Arc<JoinLeftData>` so the immutable borrow of `self`
-        // ends here and we can update `self.probe_completed_reported` below.
-        let left_data = Arc::clone(self.get_left_data()?);
+        let left_data = self.get_left_data()?;
         let left_batch = left_data.batch();
 
         // ========
@@ -2368,29 +2434,14 @@ impl NestedLoopJoinStream {
 
         // Early return if join type can't have unmatched rows
         let join_type_no_produce_left = !need_produce_result_in_final(self.join_type);
-        // Early return if another thread is already processing unmatched rows.
-        //
-        // The shared probe-threads counter must be decremented exactly once per
-        // probe stream. This function can be re-entered with `left_emit_idx`
-        // still 0 (e.g. when a ready batch was flushed via an early return in
-        // `handle_emit_left_unmatched` before the state advanced), so guard the
-        // decrement with `probe_completed_reported` instead of relying solely on
-        // `left_emit_idx == 0`. Decrementing twice would drive the counter to
-        // zero prematurely and let a partition emit unmatched-left rows before
-        // all partitions finished probing, producing spurious NULL-padded rows.
-        let handled_by_other_partition = if self.probe_completed_reported {
-            // Already counted this stream's completion; if we're the designated
-            // emitter we have `left_emit_idx > 0` (or are mid-emit) and continue,
-            // otherwise another partition is handling emission.
-            self.left_emit_idx == 0
-        } else {
-            self.probe_completed_reported = true;
-            self.left_emit_idx == 0 && !left_data.report_probe_completed()
-        };
         // Stop processing unmatched rows, the caller will go to the next state
         let finished = self.left_emit_idx >= left_batch.num_rows();
 
-        if join_type_no_produce_left || handled_by_other_partition || finished {
+        // `ProbeEnd` already recorded whether this stream emits unmatched-left
+        // rows. Every probe partition passes through this state, but only the
+        // one that finished probing last is the emitter, so this flag is false
+        // for the others.
+        if join_type_no_produce_left || !self.is_unmatched_left_emitter || finished {
             return Ok(false);
         }
 
@@ -2402,7 +2453,7 @@ impl NestedLoopJoinStream {
         let end_idx = std::cmp::min(start_idx + self.batch_size, left_batch.num_rows());
 
         if let Some(batch) =
-            self.process_left_unmatched_range(&left_data, start_idx, end_idx)?
+            self.process_left_unmatched_range(left_data, start_idx, end_idx)?
         {
             self.output_buffer.push_batch(batch)?;
         }
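
The restructured state graph in this diff can be illustrated with a toy model. This is a sketch under stated assumptions, not DataFusion code: `MiniStream`, `State`, and the `decrements` field are hypothetical, and the real stream also handles buffering, probing, spilling, and errors. The sketch only shows that a one-shot `ProbeEnd` state decrements the shared counter once, however many times `EmitLeftUnmatched` is re-entered after yielding a batch.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

/// Hypothetical, reduced version of the state graph in the diff.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum State {
    FetchingRight,
    ProbeEnd,
    EmitLeftUnmatched,
    Done,
}

/// Toy probe stream; all names here are illustrative assumptions.
pub struct MiniStream {
    state: State,
    counter: Arc<AtomicUsize>, // shared across all probe streams
    is_emitter: bool,          // written only in ProbeEnd
    emit_idx: usize,           // next unmatched-left row to emit
    rows: usize,               // unmatched-left rows available
    batch: usize,              // rows yielded per poll
    pub decrements: usize,     // how often this stream touched the counter
}

impl MiniStream {
    pub fn new(counter: Arc<AtomicUsize>, rows: usize, batch: usize) -> Self {
        Self {
            state: State::FetchingRight,
            counter,
            is_emitter: false,
            emit_idx: 0,
            rows,
            batch: batch.max(1), // avoid yielding empty batches forever
            decrements: 0,
        }
    }

    /// Runs until a batch is yielded (`Some(row_count)`) or the stream ends.
    pub fn poll(&mut self) -> Option<usize> {
        loop {
            match self.state {
                // The right side is treated as already exhausted here.
                State::FetchingRight => self.state = State::ProbeEnd,
                // Entered once: the only place the counter is decremented.
                State::ProbeEnd => {
                    self.decrements += 1;
                    self.is_emitter =
                        self.counter.fetch_sub(1, Ordering::AcqRel) == 1;
                    self.state = State::EmitLeftUnmatched;
                }
                // Re-entered on every poll; it only reads `is_emitter`.
                State::EmitLeftUnmatched => {
                    if !self.is_emitter || self.emit_idx >= self.rows {
                        self.state = State::Done;
                        continue;
                    }
                    let end = (self.emit_idx + self.batch).min(self.rows);
                    let yielded = end - self.emit_idx;
                    self.emit_idx = end;
                    return Some(yielded);
                }
                State::Done => return None,
            }
        }
    }
}
```

With two streams sharing a counter initialised to 2, the first stream to finish ends without emitting, and the second yields its rows over several polls while its `decrements` stays at 1, since no transition in this toy graph leads back into `ProbeEnd`.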
