From 30ed823f1cab8ad049156bfe64392cbb0ed15ba7 Mon Sep 17 00:00:00 2001 From: Rui Fan <1996fanrui@gmail.com> Date: Wed, 18 Feb 2026 21:26:40 +0100 Subject: [PATCH 1/3] [FLINK-38543][checkpoint] Fix Mailbox loop interrupted before recovery finished MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Return allOf future instead of thenRun future. thenRun() returns a NEW future that completes only after the callback finishes. CompletableFuture executes thenRun callbacks synchronously on the thread that calls complete(). When recoveredFutures contains bufferFilteringCompleteFuture (checkpointingDuringRecovery enabled), complete() is called on channelIOExecutor (in finishReadRecoveredState), so thenRun(suspend) also runs on channelIOExecutor. suspend() sends a poison mail, and the mailbox thread can pick it up and exit runMailboxLoop() before the thenRun future completes — causing checkState(isDone) to fail. With stateConsumedFuture (the default), complete() runs on the mailbox thread itself, so thenRun(suspend) blocks the loop from processing the poison mail until the future completes — no race. Returning allOf future avoids the issue entirely. 
--- .../streaming/runtime/tasks/StreamTask.java | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java index 9ec03137842dd..5ca9a5662e30e 100644 --- a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java +++ b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java @@ -913,8 +913,21 @@ private CompletableFuture restoreStateAndGates( "Input gate request partitions")); } - return CompletableFuture.allOf(recoveredFutures.toArray(new CompletableFuture[0])) - .thenRun(mailboxProcessor::suspend); + // Return allOf future instead of thenRun future. thenRun() returns a NEW future that + // completes only after the callback finishes. CompletableFuture executes thenRun callbacks + // synchronously on the thread that calls complete(). When recoveredFutures contains + // bufferFilteringCompleteFuture (checkpointingDuringRecovery enabled), complete() is called + // on channelIOExecutor (in finishReadRecoveredState), so thenRun(suspend) also runs on + // channelIOExecutor. suspend() sends a poison mail, and the mailbox thread can pick it up + // and exit runMailboxLoop() before the thenRun future completes — causing + // checkState(isDone) to fail. With stateConsumedFuture (the default), complete() runs on + // the mailbox thread itself, so thenRun(suspend) blocks the loop from processing the poison + // mail until the future completes — no race. Returning allOf future avoids the issue + // entirely. 
+ CompletableFuture allRecoveredFuture = + CompletableFuture.allOf(recoveredFutures.toArray(new CompletableFuture[0])); + allRecoveredFuture.thenRun(mailboxProcessor::suspend); + return allRecoveredFuture; } private void ensureNotCanceled() { From d832f84a3879b1b2e1303d80637b928efc245db6 Mon Sep 17 00:00:00 2001 From: Rui Fan <1996fanrui@gmail.com> Date: Wed, 18 Feb 2026 21:26:40 +0100 Subject: [PATCH 2/3] [FLINK-38543][checkpoint] Introduce bufferFilteringCompleteFuture for earlier RUNNING state transition --- .../network/partition/consumer/InputGate.java | 7 + .../consumer/RecoveredInputChannel.java | 41 ++++- .../partition/consumer/SingleInputGate.java | 15 ++ .../partition/consumer/UnionInputGate.java | 9 ++ .../taskmanager/InputGateWithMetrics.java | 5 + .../consumer/RecoveredInputChannelTest.java | 143 +++++++++++++++++- .../runtime/io/MockIndexedInputGate.java | 5 + .../streaming/runtime/io/MockInputGate.java | 5 + .../AlignedCheckpointsMassiveRandomTest.java | 5 + 9 files changed, 225 insertions(+), 10 deletions(-) diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/InputGate.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/InputGate.java index 11d22a8df4ae6..dd744bae330ff 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/InputGate.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/InputGate.java @@ -192,5 +192,12 @@ public String toString() { public abstract CompletableFuture getStateConsumedFuture(); + /** + * Returns a future that completes when buffer filtering is complete for all channels. This + * future completes before {@link #getStateConsumedFuture()}, enabling earlier RUNNING state + * transition when unaligned checkpoint during recovery is enabled. 
+ */ + public abstract CompletableFuture getBufferFilteringCompleteFuture(); + public abstract void finishReadRecoveredState() throws IOException; } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RecoveredInputChannel.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RecoveredInputChannel.java index d2a7a07137df5..d9b7885815bd1 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RecoveredInputChannel.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RecoveredInputChannel.java @@ -62,6 +62,13 @@ public abstract class RecoveredInputChannel extends InputChannel implements Chan private final CompletableFuture stateConsumedFuture = new CompletableFuture<>(); protected final BufferManager bufferManager; + /** + * Future that completes when recovered buffers have been filtered for this channel. This + * completes before stateConsumedFuture, enabling earlier RUNNING state transition when + * unaligned checkpoint during recovery is enabled. + */ + private final CompletableFuture bufferFilteringCompleteFuture = new CompletableFuture<>(); + @GuardedBy("receivedBuffers") private boolean isReleased; @@ -110,7 +117,11 @@ public void setChannelStateWriter(ChannelStateWriter channelStateWriter) { public final InputChannel toInputChannel() throws IOException { Preconditions.checkState( - stateConsumedFuture.isDone(), "recovered state is not fully consumed"); + bufferFilteringCompleteFuture.isDone(), "buffer filtering is not complete"); + if (!inputGate.isCheckpointingDuringRecoveryEnabled()) { + Preconditions.checkState( + stateConsumedFuture.isDone(), "recovered state is not fully consumed"); + } // Extract remaining buffers before conversion. // These buffers have been filtered but not yet consumed by the Task. 
@@ -140,6 +151,14 @@ public void checkpointStopped(long checkpointId) { protected abstract InputChannel toInputChannelInternal(ArrayDeque remainingBuffers) throws IOException; + /** + * Returns the future that completes when buffer filtering is complete. This future completes + * before stateConsumedFuture, at the point when finishReadRecoveredState() is called. + */ + CompletableFuture getBufferFilteringCompleteFuture() { + return bufferFilteringCompleteFuture; + } + CompletableFuture getStateConsumedFuture() { return stateConsumedFuture; } @@ -176,8 +195,22 @@ public void onRecoveredStateBuffer(Buffer buffer) { } public void finishReadRecoveredState() throws IOException { - onRecoveredStateBuffer( - EventSerializer.toBuffer(EndOfInputChannelStateEvent.INSTANCE, false)); + // Adding the event and completing the future must be atomic under receivedBuffers lock. + // Without this, either ordering has a race: + // - event first: task thread consumes EndOfInputChannelStateEvent, which completes + // stateConsumedFuture. When checkpointing during recovery is disabled, + // stateConsumedFuture triggers requestPartitions -> toInputChannel(), which + // fails because bufferFilteringCompleteFuture is not yet done. + // - future first: toInputChannel() extracts buffers before the event is added, + // losing the EndOfInputChannelStateEvent. + // Both toInputChannel() and getNextRecoveredStateBuffer() synchronize on + // receivedBuffers, so holding the same lock here guarantees + // bufferFilteringCompleteFuture is always done before stateConsumedFuture. 
+ synchronized (receivedBuffers) { + onRecoveredStateBuffer( + EventSerializer.toBuffer(EndOfInputChannelStateEvent.INSTANCE, false)); + bufferFilteringCompleteFuture.complete(null); + } bufferManager.releaseFloatingBuffers(); LOG.debug("{}/{} finished recovering input.", inputGate.getOwningTaskName(), channelInfo); } @@ -196,6 +229,8 @@ private BufferAndAvailability getNextRecoveredStateBuffer() throws IOException { if (next == null) { return null; } else if (isEndOfInputChannelStateEvent(next)) { + Preconditions.checkState( + bufferFilteringCompleteFuture.isDone(), "buffer filtering is not complete"); stateConsumedFuture.complete(null); return null; } else { diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java index 2847e36fcc2b8..438efa2f58bd5 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java @@ -341,6 +341,21 @@ public boolean isCheckpointingDuringRecoveryEnabled() { return checkpointingDuringRecoveryEnabled; } + @Override + public CompletableFuture getBufferFilteringCompleteFuture() { + synchronized (requestLock) { + List> futures = new ArrayList<>(numberOfInputChannels); + for (InputChannel inputChannel : inputChannels()) { + if (inputChannel instanceof RecoveredInputChannel) { + futures.add( + ((RecoveredInputChannel) inputChannel) + .getBufferFilteringCompleteFuture()); + } + } + return CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])); + } + } + @Override public void requestPartitions() { synchronized (requestLock) { diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java index 6c7c765938ba7..dda71c63be38f 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java @@ -350,6 +350,15 @@ public CompletableFuture getStateConsumedFuture() { .toArray(new CompletableFuture[] {})); } + @Override + public CompletableFuture getBufferFilteringCompleteFuture() { + return CompletableFuture.allOf( + inputGatesByGateIndex.values().stream() + .map(InputGate::getBufferFilteringCompleteFuture) + .collect(Collectors.toList()) + .toArray(new CompletableFuture[] {})); + } + @Override public void requestPartitions() throws IOException { for (InputGate inputGate : inputGatesByGateIndex.values()) { diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/InputGateWithMetrics.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/InputGateWithMetrics.java index 31775fb21deda..bff412f53b330 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/InputGateWithMetrics.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/InputGateWithMetrics.java @@ -120,6 +120,11 @@ public CompletableFuture getStateConsumedFuture() { return inputGate.getStateConsumedFuture(); } + @Override + public CompletableFuture getBufferFilteringCompleteFuture() { + return inputGate.getBufferFilteringCompleteFuture(); + } + @Override public void requestPartitions() throws IOException { inputGate.requestPartitions(); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RecoveredInputChannelTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RecoveredInputChannelTest.java index 5985a81e8ca86..f40fd09702ede 100644 --- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RecoveredInputChannelTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RecoveredInputChannelTest.java @@ -28,24 +28,30 @@ import org.junit.jupiter.api.Test; +import java.io.IOException; import java.util.ArrayDeque; import static org.apache.flink.runtime.checkpoint.CheckpointOptions.unaligned; import static org.apache.flink.runtime.state.CheckpointStorageLocationReference.getDefault; +import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; /** Tests for {@link RecoveredInputChannel}. */ class RecoveredInputChannelTest { @Test - void testConversionOnlyPossibleAfterConsumed() { - assertThatThrownBy(() -> buildChannel().toInputChannel()) - .isInstanceOf(IllegalStateException.class); + void testConversionOnlyPossibleAfterBufferFilteringComplete() { + // toInputChannel() always checks bufferFilteringCompleteFuture regardless of config + for (boolean configEnabled : new boolean[] {true, false}) { + assertThatThrownBy(() -> buildChannel(configEnabled).toInputChannel()) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("buffer filtering is not complete"); + } } @Test void testRequestPartitionsImpossible() { - assertThatThrownBy(() -> buildChannel().requestSubpartitions()) + assertThatThrownBy(() -> buildChannel(false).requestSubpartitions()) .isInstanceOf(UnsupportedOperationException.class); } @@ -53,7 +59,7 @@ void testRequestPartitionsImpossible() { void testCheckpointStartImpossible() { assertThatThrownBy( () -> - buildChannel() + buildChannel(false) .checkpointStarted( new CheckpointBarrier( 0L, @@ -64,10 +70,96 @@ void testCheckpointStartImpossible() { .isInstanceOf(CheckpointException.class); } - private RecoveredInputChannel buildChannel() { + @Test + void testToInputChannelAllowedWhenBufferFilteringCompleteAndConfigEnabled() throws IOException { + 
// When config is enabled, conversion is allowed when bufferFilteringCompleteFuture is done + TestableRecoveredInputChannel channel = buildTestableChannel(true); + + // Initially, conversion should fail + assertThatThrownBy(() -> channel.toInputChannel()) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("buffer filtering is not complete"); + + // After finishReadRecoveredState(), bufferFilteringCompleteFuture should be done + channel.finishReadRecoveredState(); + assertThat(channel.getBufferFilteringCompleteFuture()).isDone(); + assertThat(channel.getStateConsumedFuture()).isNotDone(); + + // Conversion should now succeed (no exception) + InputChannel converted = channel.toInputChannel(); + assertThat(converted).isNotNull(); + } + + @Test + void testToInputChannelAllowedWhenStateConsumedAndConfigDisabled() throws IOException { + // When config is disabled, conversion requires both bufferFilteringCompleteFuture + // and stateConsumedFuture to be done + TestableRecoveredInputChannel channel = buildTestableChannel(false); + + // Initially, conversion should fail (buffer filtering not complete) + assertThatThrownBy(() -> channel.toInputChannel()) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("buffer filtering is not complete"); + + // After finishReadRecoveredState(), bufferFilteringCompleteFuture is done + // but stateConsumedFuture is not + channel.finishReadRecoveredState(); + assertThat(channel.getBufferFilteringCompleteFuture()).isDone(); + assertThat(channel.getStateConsumedFuture()).isNotDone(); + + // Conversion should still fail because stateConsumedFuture is not done + assertThatThrownBy(() -> channel.toInputChannel()) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("recovered state is not fully consumed"); + + // Consume the EndOfInputChannelStateEvent to complete stateConsumedFuture + assertThat(channel.getNextBuffer()).isNotPresent(); + assertThat(channel.getStateConsumedFuture()).isDone(); + 
+ // Now conversion should succeed + InputChannel converted = channel.toInputChannel(); + assertThat(converted).isNotNull(); + } + + @Test + void testBufferFilteringCompleteFutureAlwaysCompletes() throws IOException { + // finishReadRecoveredState() unconditionally completes bufferFilteringCompleteFuture + for (boolean configEnabled : new boolean[] {true, false}) { + RecoveredInputChannel channel = buildChannel(configEnabled); + assertThat(channel.getBufferFilteringCompleteFuture()).isNotDone(); + channel.finishReadRecoveredState(); + assertThat(channel.getBufferFilteringCompleteFuture()).isDone(); + } + } + + @Test + void testStateConsumedFutureCompletesAfterConsumingAllBuffers() throws IOException { + // This test verifies that stateConsumedFuture completes after consuming + // EndOfInputChannelStateEvent regardless of the config setting + for (boolean configEnabled : new boolean[] {true, false}) { + RecoveredInputChannel channel = buildChannel(configEnabled); + + assertThat(channel.getStateConsumedFuture()).isNotDone(); + + channel.finishReadRecoveredState(); + assertThat(channel.getStateConsumedFuture()).isNotDone(); + + // Consuming the EndOfInputChannelStateEvent should complete the future. + // getNextBuffer() returns empty when it encounters the event internally. 
+ assertThat(channel.getNextBuffer()).isNotPresent(); + assertThat(channel.getStateConsumedFuture()).isDone(); + } + } + + private RecoveredInputChannel buildChannel(boolean checkpointingDuringRecoveryEnabled) { try { + SingleInputGate inputGate = + new SingleInputGateBuilder() + .setCheckpointingDuringRecoveryEnabled( + checkpointingDuringRecoveryEnabled) + .build(); return new RecoveredInputChannel( - new SingleInputGateBuilder().build(), + inputGate, 0, new ResultPartitionID(), new ResultSubpartitionIndexSet(0), @@ -85,4 +177,41 @@ protected InputChannel toInputChannelInternal(ArrayDeque remainingBuffer throw new AssertionError("channel creation failed", e); } } + + private TestableRecoveredInputChannel buildTestableChannel( + boolean checkpointingDuringRecoveryEnabled) { + try { + SingleInputGate inputGate = + new SingleInputGateBuilder() + .setCheckpointingDuringRecoveryEnabled( + checkpointingDuringRecoveryEnabled) + .build(); + return new TestableRecoveredInputChannel(inputGate); + } catch (Exception e) { + throw new AssertionError("channel creation failed", e); + } + } + + /** + * A RecoveredInputChannel that returns a TestInputChannel when converted, for testing purposes. 
+ */ + private static class TestableRecoveredInputChannel extends RecoveredInputChannel { + TestableRecoveredInputChannel(SingleInputGate inputGate) { + super( + inputGate, + 0, + new ResultPartitionID(), + new ResultSubpartitionIndexSet(0), + 0, + 0, + new SimpleCounter(), + new SimpleCounter(), + 10); + } + + @Override + protected InputChannel toInputChannelInternal(ArrayDeque remainingBuffers) { + return new TestInputChannel(inputGate, 0); + } + } } diff --git a/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/MockIndexedInputGate.java b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/MockIndexedInputGate.java index 618f31e518f8b..584aeb7eb9089 100644 --- a/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/MockIndexedInputGate.java +++ b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/MockIndexedInputGate.java @@ -57,6 +57,11 @@ public CompletableFuture getStateConsumedFuture() { return CompletableFuture.completedFuture(null); } + @Override + public CompletableFuture getBufferFilteringCompleteFuture() { + return CompletableFuture.completedFuture(null); + } + @Override public void finishReadRecoveredState() {} diff --git a/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/MockInputGate.java b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/MockInputGate.java index a35c8995d2a29..71b2c43f3306a 100644 --- a/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/MockInputGate.java +++ b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/MockInputGate.java @@ -80,6 +80,11 @@ public CompletableFuture getStateConsumedFuture() { return CompletableFuture.completedFuture(null); } + @Override + public CompletableFuture getBufferFilteringCompleteFuture() { + return CompletableFuture.completedFuture(null); + } + @Override public void finishReadRecoveredState() {} diff --git 
a/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/checkpointing/AlignedCheckpointsMassiveRandomTest.java b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/checkpointing/AlignedCheckpointsMassiveRandomTest.java index da3d25519345d..619873c387d08 100644 --- a/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/checkpointing/AlignedCheckpointsMassiveRandomTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/checkpointing/AlignedCheckpointsMassiveRandomTest.java @@ -263,6 +263,11 @@ public CompletableFuture getStateConsumedFuture() { return CompletableFuture.completedFuture(null); } + @Override + public CompletableFuture getBufferFilteringCompleteFuture() { + return CompletableFuture.completedFuture(null); + } + @Override public void finishReadRecoveredState() {} From 6db933f4520ab16da01f664d2dd149e59412d25d Mon Sep 17 00:00:00 2001 From: Rui Fan <1996fanrui@gmail.com> Date: Wed, 18 Feb 2026 21:26:40 +0100 Subject: [PATCH 3/3] [FLINK-38543][checkpoint] Change overall UC restore process for checkpoint during recovery --- .../CheckpointingOptionsTest.java | 2 +- .../streaming/runtime/tasks/StreamTask.java | 25 ++++---- .../consumer/SingleInputGateBuilder.java | 8 +++ .../consumer/SingleInputGateTest.java | 30 ++++++++++ .../consumer/UnionInputGateTest.java | 58 +++++++++++++++++++ 5 files changed, 112 insertions(+), 11 deletions(-) diff --git a/flink-core/src/test/java/org/apache/flink/configuration/CheckpointingOptionsTest.java b/flink-core/src/test/java/org/apache/flink/configuration/CheckpointingOptionsTest.java index 75c4ec07d234c..7f895ef1e516c 100644 --- a/flink-core/src/test/java/org/apache/flink/configuration/CheckpointingOptionsTest.java +++ b/flink-core/src/test/java/org/apache/flink/configuration/CheckpointingOptionsTest.java @@ -330,7 +330,7 @@ void testIsUnalignedCheckpointInterruptibleTimersEnabled() { } @Test - void testIsUnalignedDuringRecoveryEnabled() { + void 
testIsCheckpointingDuringRecoveryEnabled() { // Test when both options are disabled (default) - should return false Configuration defaultConfig = new Configuration(); assertThat(CheckpointingOptions.isCheckpointingDuringRecoveryEnabled(defaultConfig)) diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java index 5ca9a5662e30e..7938a1ef278b7 100644 --- a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java +++ b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java @@ -883,6 +883,9 @@ private CompletableFuture restoreStateAndGates( boolean checkpointingDuringRecoveryEnabled = CheckpointingOptions.isCheckpointingDuringRecoveryEnabled(getJobConfiguration()); + + // Must set the flag on input gates BEFORE starting the async read task, because + // toInputChannel() reads it to decide whether stateConsumedFuture must also be done. for (IndexedInputGate inputGate : inputGates) { inputGate.setCheckpointingDuringRecoveryEnabled(checkpointingDuringRecoveryEnabled); } @@ -899,18 +902,20 @@ private CompletableFuture restoreStateAndGates( // We wait for all input channel state to recover before we go into RUNNING state, and thus // start checkpointing. If we implement incremental checkpointing of input channel state - // we must make sure it supports CheckpointType#FULL_CHECKPOINT + // we must make sure it supports CheckpointType#FULL_CHECKPOINT. List> recoveredFutures = new ArrayList<>(inputGates.length); for (InputGate inputGate : inputGates) { - recoveredFutures.add(inputGate.getStateConsumedFuture()); - - inputGate - .getStateConsumedFuture() - .thenRun( - () -> - mainMailboxExecutor.execute( - inputGate::requestPartitions, - "Input gate request partitions")); + CompletableFuture requestPartitionsTrigger = + checkpointingDuringRecoveryEnabled + ? 
inputGate.getBufferFilteringCompleteFuture() + : inputGate.getStateConsumedFuture(); + + recoveredFutures.add(requestPartitionsTrigger); + + requestPartitionsTrigger.thenRun( + () -> + mainMailboxExecutor.execute( + inputGate::requestPartitions, "Input gate request partitions")); } // Return allOf future instead of thenRun future. thenRun() returns a NEW future that diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateBuilder.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateBuilder.java index e4a4c289dc6e8..a4da811f8a32a 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateBuilder.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateBuilder.java @@ -83,6 +83,8 @@ public class SingleInputGateBuilder { private TieredStorageConsumerClient tieredStorageConsumerClient = null; + private boolean isCheckpointingDuringRecoveryEnabled = false; + public SingleInputGateBuilder setPartitionProducerStateProvider( PartitionProducerStateProvider partitionProducerStateProvider) { @@ -167,6 +169,11 @@ public SingleInputGateBuilder setTieredStorageConsumerClient( return this; } + public SingleInputGateBuilder setCheckpointingDuringRecoveryEnabled(boolean enabled) { + this.isCheckpointingDuringRecoveryEnabled = enabled; + return this; + } + public SingleInputGate build() { SingleInputGate gate = new SingleInputGate( @@ -195,6 +202,7 @@ public SingleInputGate build() { .toArray(InputChannel[]::new)); } gate.setTieredStorageService(null, tieredStorageConsumerClient, null); + gate.setCheckpointingDuringRecoveryEnabled(isCheckpointingDuringRecoveryEnabled); return gate; } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java index f7f0b744fb9fd..b2cc9d7ce3c9c 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java @@ -142,6 +142,36 @@ void testCheckpointsDeclinedUnlessStateConsumed() { .isInstanceOf(CheckpointException.class); } + @Test + void testBufferFilteringCompleteFutureAggregation() throws Exception { + final NettyShuffleEnvironment environment = createNettyShuffleEnvironment(); + final SingleInputGate inputGate = createInputGate(environment); + try (Closer closer = Closer.create()) { + closer.register(environment::close); + closer.register(inputGate::close); + + // Enable checkpointing during recovery for this test, the mode in which + // bufferFilteringCompleteFuture is used to trigger requestPartitions() + inputGate.setCheckpointingDuringRecoveryEnabled(true); + inputGate.setup(); + + // Initially, the aggregated future should not be completed + assertThat(inputGate.getBufferFilteringCompleteFuture()).isNotDone(); + + // After finishing read recovered state, bufferFilteringCompleteFuture should be + // completed (RecoveredInputChannel completes it regardless of the config) + inputGate.finishReadRecoveredState(); + assertThat(inputGate.getBufferFilteringCompleteFuture()).isDone(); + + // stateConsumedFuture should not be completed until data is consumed + assertThat(inputGate.getStateConsumedFuture()).isNotDone(); + + // Consuming the EndOfInputChannelStateEvent should complete stateConsumedFuture + inputGate.pollNext(); + assertThat(inputGate.getStateConsumedFuture()).isDone(); + } + } + /** * Tests {@link InputGate#setup()} should create the respective {@link BufferPool} and assign * exclusive buffers for {@link RemoteInputChannel}s, but should not request partitions. 
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGateTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGateTest.java index 419246137e8c2..1ed1a42a66ea0 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGateTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGateTest.java @@ -18,12 +18,14 @@ package org.apache.flink.runtime.io.network.partition.consumer; +import org.apache.flink.metrics.SimpleCounter; import org.apache.flink.runtime.clusterframework.types.ResourceID; import org.apache.flink.runtime.io.PullingAsyncDataInput; import org.apache.flink.runtime.io.network.api.StopMode; import org.apache.flink.runtime.io.network.buffer.BufferBuilderTestUtils; import org.apache.flink.runtime.io.network.partition.NoOpResultSubpartitionView; import org.apache.flink.runtime.io.network.partition.ResultPartitionID; +import org.apache.flink.runtime.io.network.partition.ResultSubpartitionIndexSet; import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGateTest.TestingResultPartitionManager; import org.junit.jupiter.api.Test; @@ -275,6 +277,62 @@ void testGetChannelWithShiftedGateIndexes() { assertThat(unionInputGate.getChannel(1)).isEqualTo(inputChannel2); } + @Test + void testBufferFilteringCompleteFutureAggregation() throws IOException { + // Create 2 SingleInputGates, each with 1 RecoveredInputChannel + SingleInputGate ig1 = + new SingleInputGateBuilder().setCheckpointingDuringRecoveryEnabled(true).build(); + RecoveredInputChannel channel1 = buildRecoveredChannel(ig1); + ig1.setInputChannels(channel1); + + SingleInputGate ig2 = + new SingleInputGateBuilder() + .setSingleInputGateIndex(1) + .setCheckpointingDuringRecoveryEnabled(true) + .build(); + RecoveredInputChannel channel2 = buildRecoveredChannel(ig2); + 
ig2.setInputChannels(channel2); + + UnionInputGate union = new UnionInputGate(ig1, ig2); + + // Initially, bufferFilteringCompleteFuture should not be done + assertThat(union.getBufferFilteringCompleteFuture()).isNotDone(); + assertThat(union.getStateConsumedFuture()).isNotDone(); + + // Complete buffer filtering on first gate only + channel1.finishReadRecoveredState(); + assertThat(ig1.getBufferFilteringCompleteFuture()).isDone(); + assertThat(union.getBufferFilteringCompleteFuture()).isNotDone(); + + // Complete buffer filtering on second gate + channel2.finishReadRecoveredState(); + assertThat(ig2.getBufferFilteringCompleteFuture()).isDone(); + assertThat(union.getBufferFilteringCompleteFuture()).isDone(); + + // State consumed futures should still NOT be done (state not consumed yet) + assertThat(union.getStateConsumedFuture()).isNotDone(); + } + + private static RecoveredInputChannel buildRecoveredChannel(SingleInputGate inputGate) { + return new RecoveredInputChannel( + inputGate, + 0, + new ResultPartitionID(), + new ResultSubpartitionIndexSet(0), + 0, + 0, + new SimpleCounter(), + new SimpleCounter(), + 10) { + @Override + protected InputChannel toInputChannelInternal( + java.util.ArrayDeque + remainingBuffers) { + throw new UnsupportedOperationException(); + } + }; + } + @Test void testEmptyPull() throws IOException, InterruptedException { final SingleInputGate inputGate1 = createInputGate(1);