diff --git a/flink-core/src/test/java/org/apache/flink/configuration/CheckpointingOptionsTest.java b/flink-core/src/test/java/org/apache/flink/configuration/CheckpointingOptionsTest.java index 75c4ec07d234c..7f895ef1e516c 100644 --- a/flink-core/src/test/java/org/apache/flink/configuration/CheckpointingOptionsTest.java +++ b/flink-core/src/test/java/org/apache/flink/configuration/CheckpointingOptionsTest.java @@ -330,7 +330,7 @@ void testIsUnalignedCheckpointInterruptibleTimersEnabled() { } @Test - void testIsUnalignedDuringRecoveryEnabled() { + void testIsCheckpointingDuringRecoveryEnabled() { // Test when both options are disabled (default) - should return false Configuration defaultConfig = new Configuration(); assertThat(CheckpointingOptions.isCheckpointingDuringRecoveryEnabled(defaultConfig)) diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/InputGate.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/InputGate.java index 11d22a8df4ae6..dd744bae330ff 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/InputGate.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/InputGate.java @@ -192,5 +192,12 @@ public String toString() { public abstract CompletableFuture getStateConsumedFuture(); + /** + * Returns a future that completes when buffer filtering is complete for all channels. This + * future completes before {@link #getStateConsumedFuture()}, enabling earlier RUNNING state + * transition when unaligned checkpoint during recovery is enabled. 
+ */ + public abstract CompletableFuture getBufferFilteringCompleteFuture(); + public abstract void finishReadRecoveredState() throws IOException; } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RecoveredInputChannel.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RecoveredInputChannel.java index d2a7a07137df5..d9b7885815bd1 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RecoveredInputChannel.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RecoveredInputChannel.java @@ -62,6 +62,13 @@ public abstract class RecoveredInputChannel extends InputChannel implements Chan private final CompletableFuture stateConsumedFuture = new CompletableFuture<>(); protected final BufferManager bufferManager; + /** + * Future that completes when recovered buffers have been filtered for this channel. This + * completes before stateConsumedFuture, enabling earlier RUNNING state transition when + * unaligned checkpoint during recovery is enabled. + */ + private final CompletableFuture bufferFilteringCompleteFuture = new CompletableFuture<>(); + @GuardedBy("receivedBuffers") private boolean isReleased; @@ -110,7 +117,11 @@ public void setChannelStateWriter(ChannelStateWriter channelStateWriter) { public final InputChannel toInputChannel() throws IOException { Preconditions.checkState( - stateConsumedFuture.isDone(), "recovered state is not fully consumed"); + bufferFilteringCompleteFuture.isDone(), "buffer filtering is not complete"); + if (!inputGate.isCheckpointingDuringRecoveryEnabled()) { + Preconditions.checkState( + stateConsumedFuture.isDone(), "recovered state is not fully consumed"); + } // Extract remaining buffers before conversion. // These buffers have been filtered but not yet consumed by the Task. 
@@ -140,6 +151,14 @@ public void checkpointStopped(long checkpointId) { protected abstract InputChannel toInputChannelInternal(ArrayDeque remainingBuffers) throws IOException; + /** + * Returns the future that completes when buffer filtering is complete. This future completes + * before stateConsumedFuture, at the point when finishReadRecoveredState() is called. + */ + CompletableFuture getBufferFilteringCompleteFuture() { + return bufferFilteringCompleteFuture; + } + CompletableFuture getStateConsumedFuture() { return stateConsumedFuture; } @@ -176,8 +195,22 @@ public void onRecoveredStateBuffer(Buffer buffer) { } public void finishReadRecoveredState() throws IOException { - onRecoveredStateBuffer( - EventSerializer.toBuffer(EndOfInputChannelStateEvent.INSTANCE, false)); + // Adding the event and completing the future must be atomic under receivedBuffers lock. + // Without this, either ordering has a race: + // - event first: task thread consumes EndOfInputChannelStateEvent, which completes + // stateConsumedFuture. When checkpointing during recovery is disabled, + // stateConsumedFuture triggers requestPartitions -> toInputChannel(), which + // fails because bufferFilteringCompleteFuture is not yet done. + // - future first: toInputChannel() extracts buffers before the event is added, + // losing the EndOfInputChannelStateEvent. + // Both toInputChannel() and getNextRecoveredStateBuffer() synchronize on + // receivedBuffers, so holding the same lock here guarantees + // bufferFilteringCompleteFuture is always done before stateConsumedFuture. 
+ synchronized (receivedBuffers) { + onRecoveredStateBuffer( + EventSerializer.toBuffer(EndOfInputChannelStateEvent.INSTANCE, false)); + bufferFilteringCompleteFuture.complete(null); + } bufferManager.releaseFloatingBuffers(); LOG.debug("{}/{} finished recovering input.", inputGate.getOwningTaskName(), channelInfo); } @@ -196,6 +229,8 @@ private BufferAndAvailability getNextRecoveredStateBuffer() throws IOException { if (next == null) { return null; } else if (isEndOfInputChannelStateEvent(next)) { + Preconditions.checkState( + bufferFilteringCompleteFuture.isDone(), "buffer filtering is not complete"); stateConsumedFuture.complete(null); return null; } else { diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java index 2847e36fcc2b8..438efa2f58bd5 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java @@ -341,6 +341,21 @@ public boolean isCheckpointingDuringRecoveryEnabled() { return checkpointingDuringRecoveryEnabled; } + @Override + public CompletableFuture getBufferFilteringCompleteFuture() { + synchronized (requestLock) { + List> futures = new ArrayList<>(numberOfInputChannels); + for (InputChannel inputChannel : inputChannels()) { + if (inputChannel instanceof RecoveredInputChannel) { + futures.add( + ((RecoveredInputChannel) inputChannel) + .getBufferFilteringCompleteFuture()); + } + } + return CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])); + } + } + @Override public void requestPartitions() { synchronized (requestLock) { diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java index 6c7c765938ba7..dda71c63be38f 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java @@ -350,6 +350,15 @@ public CompletableFuture getStateConsumedFuture() { .toArray(new CompletableFuture[] {})); } + @Override + public CompletableFuture getBufferFilteringCompleteFuture() { + return CompletableFuture.allOf( + inputGatesByGateIndex.values().stream() + .map(InputGate::getBufferFilteringCompleteFuture) + .collect(Collectors.toList()) + .toArray(new CompletableFuture[] {})); + } + @Override public void requestPartitions() throws IOException { for (InputGate inputGate : inputGatesByGateIndex.values()) { diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/InputGateWithMetrics.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/InputGateWithMetrics.java index 31775fb21deda..bff412f53b330 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/InputGateWithMetrics.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/InputGateWithMetrics.java @@ -120,6 +120,11 @@ public CompletableFuture getStateConsumedFuture() { return inputGate.getStateConsumedFuture(); } + @Override + public CompletableFuture getBufferFilteringCompleteFuture() { + return inputGate.getBufferFilteringCompleteFuture(); + } + @Override public void requestPartitions() throws IOException { inputGate.requestPartitions(); diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java index 9ec03137842dd..7938a1ef278b7 100644 --- a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java +++ 
b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java @@ -883,6 +883,9 @@ private CompletableFuture restoreStateAndGates( boolean checkpointingDuringRecoveryEnabled = CheckpointingOptions.isCheckpointingDuringRecoveryEnabled(getJobConfiguration()); + + // Must set the flag on input gates BEFORE starting the async read task, because + // toInputChannel() checks it to decide whether stateConsumedFuture must also be done. for (IndexedInputGate inputGate : inputGates) { inputGate.setCheckpointingDuringRecoveryEnabled(checkpointingDuringRecoveryEnabled); } @@ -899,22 +902,37 @@ private CompletableFuture restoreStateAndGates( // We wait for all input channel state to recover before we go into RUNNING state, and thus // start checkpointing. If we implement incremental checkpointing of input channel state - // we must make sure it supports CheckpointType#FULL_CHECKPOINT + // we must make sure it supports CheckpointType#FULL_CHECKPOINT. List> recoveredFutures = new ArrayList<>(inputGates.length); for (InputGate inputGate : inputGates) { - recoveredFutures.add(inputGate.getStateConsumedFuture()); - - inputGate - .getStateConsumedFuture() - .thenRun( - () -> - mainMailboxExecutor.execute( - inputGate::requestPartitions, - "Input gate request partitions")); + CompletableFuture requestPartitionsTrigger = + checkpointingDuringRecoveryEnabled + ? inputGate.getBufferFilteringCompleteFuture() + : inputGate.getStateConsumedFuture(); + + recoveredFutures.add(requestPartitionsTrigger); + + requestPartitionsTrigger.thenRun( + () -> + mainMailboxExecutor.execute( + inputGate::requestPartitions, "Input gate request partitions")); } - return CompletableFuture.allOf(recoveredFutures.toArray(new CompletableFuture[0])) - .thenRun(mailboxProcessor::suspend); + // Return allOf future instead of thenRun future. thenRun() returns a NEW future that + // completes only after the callback finishes. 
CompletableFuture executes thenRun callbacks + // synchronously on the thread that calls complete(). When recoveredFutures contains + // bufferFilteringCompleteFuture (checkpointingDuringRecovery enabled), complete() is called + // on channelIOExecutor (in finishReadRecoveredState), so thenRun(suspend) also runs on + // channelIOExecutor. suspend() sends a poison mail, and the mailbox thread can pick it up + // and exit runMailboxLoop() before the thenRun future completes — causing + // checkState(isDone) to fail. With stateConsumedFuture (the default), complete() runs on + // the mailbox thread itself, so thenRun(suspend) blocks the loop from processing the poison + // mail until the future completes — no race. Returning allOf future avoids the issue + // entirely. + CompletableFuture allRecoveredFuture = + CompletableFuture.allOf(recoveredFutures.toArray(new CompletableFuture[0])); + allRecoveredFuture.thenRun(mailboxProcessor::suspend); + return allRecoveredFuture; } private void ensureNotCanceled() { diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RecoveredInputChannelTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RecoveredInputChannelTest.java index 5985a81e8ca86..f40fd09702ede 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RecoveredInputChannelTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RecoveredInputChannelTest.java @@ -28,24 +28,30 @@ import org.junit.jupiter.api.Test; +import java.io.IOException; import java.util.ArrayDeque; import static org.apache.flink.runtime.checkpoint.CheckpointOptions.unaligned; import static org.apache.flink.runtime.state.CheckpointStorageLocationReference.getDefault; +import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; /** Tests for {@link RecoveredInputChannel}. 
*/ class RecoveredInputChannelTest { @Test - void testConversionOnlyPossibleAfterConsumed() { - assertThatThrownBy(() -> buildChannel().toInputChannel()) - .isInstanceOf(IllegalStateException.class); + void testConversionOnlyPossibleAfterBufferFilteringComplete() { + // toInputChannel() always checks bufferFilteringCompleteFuture regardless of config + for (boolean configEnabled : new boolean[] {true, false}) { + assertThatThrownBy(() -> buildChannel(configEnabled).toInputChannel()) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("buffer filtering is not complete"); + } } @Test void testRequestPartitionsImpossible() { - assertThatThrownBy(() -> buildChannel().requestSubpartitions()) + assertThatThrownBy(() -> buildChannel(false).requestSubpartitions()) .isInstanceOf(UnsupportedOperationException.class); } @@ -53,7 +59,7 @@ void testRequestPartitionsImpossible() { void testCheckpointStartImpossible() { assertThatThrownBy( () -> - buildChannel() + buildChannel(false) .checkpointStarted( new CheckpointBarrier( 0L, @@ -64,10 +70,96 @@ void testCheckpointStartImpossible() { .isInstanceOf(CheckpointException.class); } - private RecoveredInputChannel buildChannel() { + @Test + void testToInputChannelAllowedWhenBufferFilteringCompleteAndConfigEnabled() throws IOException { + // When config is enabled, conversion is allowed when bufferFilteringCompleteFuture is done + TestableRecoveredInputChannel channel = buildTestableChannel(true); + + // Initially, conversion should fail + assertThatThrownBy(() -> channel.toInputChannel()) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("buffer filtering is not complete"); + + // After finishReadRecoveredState(), bufferFilteringCompleteFuture should be done + channel.finishReadRecoveredState(); + assertThat(channel.getBufferFilteringCompleteFuture()).isDone(); + assertThat(channel.getStateConsumedFuture()).isNotDone(); + + // Conversion should now succeed (no exception) + InputChannel converted 
= channel.toInputChannel(); + assertThat(converted).isNotNull(); + } + + @Test + void testToInputChannelAllowedWhenStateConsumedAndConfigDisabled() throws IOException { + // When config is disabled, conversion requires both bufferFilteringCompleteFuture + // and stateConsumedFuture to be done + TestableRecoveredInputChannel channel = buildTestableChannel(false); + + // Initially, conversion should fail (buffer filtering not complete) + assertThatThrownBy(() -> channel.toInputChannel()) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("buffer filtering is not complete"); + + // After finishReadRecoveredState(), bufferFilteringCompleteFuture is done + // but stateConsumedFuture is not + channel.finishReadRecoveredState(); + assertThat(channel.getBufferFilteringCompleteFuture()).isDone(); + assertThat(channel.getStateConsumedFuture()).isNotDone(); + + // Conversion should still fail because stateConsumedFuture is not done + assertThatThrownBy(() -> channel.toInputChannel()) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("recovered state is not fully consumed"); + + // Consume the EndOfInputChannelStateEvent to complete stateConsumedFuture + assertThat(channel.getNextBuffer()).isNotPresent(); + assertThat(channel.getStateConsumedFuture()).isDone(); + + // Now conversion should succeed + InputChannel converted = channel.toInputChannel(); + assertThat(converted).isNotNull(); + } + + @Test + void testBufferFilteringCompleteFutureAlwaysCompletes() throws IOException { + // finishReadRecoveredState() unconditionally completes bufferFilteringCompleteFuture + for (boolean configEnabled : new boolean[] {true, false}) { + RecoveredInputChannel channel = buildChannel(configEnabled); + assertThat(channel.getBufferFilteringCompleteFuture()).isNotDone(); + channel.finishReadRecoveredState(); + assertThat(channel.getBufferFilteringCompleteFuture()).isDone(); + } + } + + @Test + void testStateConsumedFutureCompletesAfterConsumingAllBuffers() 
throws IOException { + // This test verifies that stateConsumedFuture completes after consuming + // EndOfInputChannelStateEvent regardless of the config setting + for (boolean configEnabled : new boolean[] {true, false}) { + RecoveredInputChannel channel = buildChannel(configEnabled); + + assertThat(channel.getStateConsumedFuture()).isNotDone(); + + channel.finishReadRecoveredState(); + assertThat(channel.getStateConsumedFuture()).isNotDone(); + + // Consuming the EndOfInputChannelStateEvent should complete the future. + // getNextBuffer() returns empty when it encounters the event internally. + assertThat(channel.getNextBuffer()).isNotPresent(); + assertThat(channel.getStateConsumedFuture()).isDone(); + } + } + + private RecoveredInputChannel buildChannel(boolean checkpointingDuringRecoveryEnabled) { try { + SingleInputGate inputGate = + new SingleInputGateBuilder() + .setCheckpointingDuringRecoveryEnabled( + checkpointingDuringRecoveryEnabled) + .build(); return new RecoveredInputChannel( - new SingleInputGateBuilder().build(), + inputGate, 0, new ResultPartitionID(), new ResultSubpartitionIndexSet(0), @@ -85,4 +177,41 @@ protected InputChannel toInputChannelInternal(ArrayDeque remainingBuffer throw new AssertionError("channel creation failed", e); } } + + private TestableRecoveredInputChannel buildTestableChannel( + boolean checkpointingDuringRecoveryEnabled) { + try { + SingleInputGate inputGate = + new SingleInputGateBuilder() + .setCheckpointingDuringRecoveryEnabled( + checkpointingDuringRecoveryEnabled) + .build(); + return new TestableRecoveredInputChannel(inputGate); + } catch (Exception e) { + throw new AssertionError("channel creation failed", e); + } + } + + /** + * A RecoveredInputChannel that returns a TestInputChannel when converted, for testing purposes. 
+ */ + private static class TestableRecoveredInputChannel extends RecoveredInputChannel { + TestableRecoveredInputChannel(SingleInputGate inputGate) { + super( + inputGate, + 0, + new ResultPartitionID(), + new ResultSubpartitionIndexSet(0), + 0, + 0, + new SimpleCounter(), + new SimpleCounter(), + 10); + } + + @Override + protected InputChannel toInputChannelInternal(ArrayDeque remainingBuffers) { + return new TestInputChannel(inputGate, 0); + } + } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateBuilder.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateBuilder.java index e4a4c289dc6e8..a4da811f8a32a 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateBuilder.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateBuilder.java @@ -83,6 +83,8 @@ public class SingleInputGateBuilder { private TieredStorageConsumerClient tieredStorageConsumerClient = null; + private boolean isCheckpointingDuringRecoveryEnabled = false; + public SingleInputGateBuilder setPartitionProducerStateProvider( PartitionProducerStateProvider partitionProducerStateProvider) { @@ -167,6 +169,11 @@ public SingleInputGateBuilder setTieredStorageConsumerClient( return this; } + public SingleInputGateBuilder setCheckpointingDuringRecoveryEnabled(boolean enabled) { + this.isCheckpointingDuringRecoveryEnabled = enabled; + return this; + } + public SingleInputGate build() { SingleInputGate gate = new SingleInputGate( @@ -195,6 +202,7 @@ public SingleInputGate build() { .toArray(InputChannel[]::new)); } gate.setTieredStorageService(null, tieredStorageConsumerClient, null); + gate.setCheckpointingDuringRecoveryEnabled(isCheckpointingDuringRecoveryEnabled); return gate; } diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java index f7f0b744fb9fd..b2cc9d7ce3c9c 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java @@ -142,6 +142,36 @@ void testCheckpointsDeclinedUnlessStateConsumed() { .isInstanceOf(CheckpointException.class); } + @Test + void testBufferFilteringCompleteFutureAggregation() throws Exception { + final NettyShuffleEnvironment environment = createNettyShuffleEnvironment(); + final SingleInputGate inputGate = createInputGate(environment); + try (Closer closer = Closer.create()) { + closer.register(environment::close); + closer.register(inputGate::close); + + // Enable checkpointing during recovery for this test. NOTE(review): the channel-level + // finishReadRecoveredState() completes bufferFilteringCompleteFuture unconditionally. + inputGate.setCheckpointingDuringRecoveryEnabled(true); + inputGate.setup(); + + // Initially, the aggregated future should not be completed + assertThat(inputGate.getBufferFilteringCompleteFuture()).isNotDone(); + + // After finishing read recovered state, bufferFilteringCompleteFuture should be + // completed (RecoveredInputChannel completes it regardless of the config) + inputGate.finishReadRecoveredState(); + assertThat(inputGate.getBufferFilteringCompleteFuture()).isDone(); + + // stateConsumedFuture should not be completed until data is consumed + assertThat(inputGate.getStateConsumedFuture()).isNotDone(); + + // Consuming the EndOfInputChannelStateEvent should complete stateConsumedFuture + inputGate.pollNext(); + assertThat(inputGate.getStateConsumedFuture()).isDone(); + } + } + /** * Tests {@link InputGate#setup()} should create the respective {@link BufferPool} and assign * exclusive buffers for {@link 
RemoteInputChannel}s, but should not request partitions. diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGateTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGateTest.java index 419246137e8c2..1ed1a42a66ea0 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGateTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGateTest.java @@ -18,12 +18,14 @@ package org.apache.flink.runtime.io.network.partition.consumer; +import org.apache.flink.metrics.SimpleCounter; import org.apache.flink.runtime.clusterframework.types.ResourceID; import org.apache.flink.runtime.io.PullingAsyncDataInput; import org.apache.flink.runtime.io.network.api.StopMode; import org.apache.flink.runtime.io.network.buffer.BufferBuilderTestUtils; import org.apache.flink.runtime.io.network.partition.NoOpResultSubpartitionView; import org.apache.flink.runtime.io.network.partition.ResultPartitionID; +import org.apache.flink.runtime.io.network.partition.ResultSubpartitionIndexSet; import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGateTest.TestingResultPartitionManager; import org.junit.jupiter.api.Test; @@ -275,6 +277,62 @@ void testGetChannelWithShiftedGateIndexes() { assertThat(unionInputGate.getChannel(1)).isEqualTo(inputChannel2); } + @Test + void testBufferFilteringCompleteFutureAggregation() throws IOException { + // Create 2 SingleInputGates, each with 1 RecoveredInputChannel + SingleInputGate ig1 = + new SingleInputGateBuilder().setCheckpointingDuringRecoveryEnabled(true).build(); + RecoveredInputChannel channel1 = buildRecoveredChannel(ig1); + ig1.setInputChannels(channel1); + + SingleInputGate ig2 = + new SingleInputGateBuilder() + .setSingleInputGateIndex(1) + .setCheckpointingDuringRecoveryEnabled(true) + .build(); + RecoveredInputChannel channel2 = 
buildRecoveredChannel(ig2); + ig2.setInputChannels(channel2); + + UnionInputGate union = new UnionInputGate(ig1, ig2); + + // Initially, bufferFilteringCompleteFuture should not be done + assertThat(union.getBufferFilteringCompleteFuture()).isNotDone(); + assertThat(union.getStateConsumedFuture()).isNotDone(); + + // Complete buffer filtering on first gate only + channel1.finishReadRecoveredState(); + assertThat(ig1.getBufferFilteringCompleteFuture()).isDone(); + assertThat(union.getBufferFilteringCompleteFuture()).isNotDone(); + + // Complete buffer filtering on second gate + channel2.finishReadRecoveredState(); + assertThat(ig2.getBufferFilteringCompleteFuture()).isDone(); + assertThat(union.getBufferFilteringCompleteFuture()).isDone(); + + // State consumed futures should still NOT be done (state not consumed yet) + assertThat(union.getStateConsumedFuture()).isNotDone(); + } + + private static RecoveredInputChannel buildRecoveredChannel(SingleInputGate inputGate) { + return new RecoveredInputChannel( + inputGate, + 0, + new ResultPartitionID(), + new ResultSubpartitionIndexSet(0), + 0, + 0, + new SimpleCounter(), + new SimpleCounter(), + 10) { + @Override + protected InputChannel toInputChannelInternal( + java.util.ArrayDeque + remainingBuffers) { + throw new UnsupportedOperationException(); + } + }; + } + @Test void testEmptyPull() throws IOException, InterruptedException { final SingleInputGate inputGate1 = createInputGate(1); diff --git a/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/MockIndexedInputGate.java b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/MockIndexedInputGate.java index 618f31e518f8b..584aeb7eb9089 100644 --- a/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/MockIndexedInputGate.java +++ b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/MockIndexedInputGate.java @@ -57,6 +57,11 @@ public CompletableFuture getStateConsumedFuture() { return 
CompletableFuture.completedFuture(null); } + @Override + public CompletableFuture getBufferFilteringCompleteFuture() { + return CompletableFuture.completedFuture(null); + } + @Override public void finishReadRecoveredState() {} diff --git a/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/MockInputGate.java b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/MockInputGate.java index a35c8995d2a29..71b2c43f3306a 100644 --- a/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/MockInputGate.java +++ b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/MockInputGate.java @@ -80,6 +80,11 @@ public CompletableFuture getStateConsumedFuture() { return CompletableFuture.completedFuture(null); } + @Override + public CompletableFuture getBufferFilteringCompleteFuture() { + return CompletableFuture.completedFuture(null); + } + @Override public void finishReadRecoveredState() {} diff --git a/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/checkpointing/AlignedCheckpointsMassiveRandomTest.java b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/checkpointing/AlignedCheckpointsMassiveRandomTest.java index da3d25519345d..619873c387d08 100644 --- a/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/checkpointing/AlignedCheckpointsMassiveRandomTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/checkpointing/AlignedCheckpointsMassiveRandomTest.java @@ -263,6 +263,11 @@ public CompletableFuture getStateConsumedFuture() { return CompletableFuture.completedFuture(null); } + @Override + public CompletableFuture getBufferFilteringCompleteFuture() { + return CompletableFuture.completedFuture(null); + } + @Override public void finishReadRecoveredState() {}