Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ void testIsUnalignedCheckpointInterruptibleTimersEnabled() {
}

@Test
void testIsUnalignedDuringRecoveryEnabled() {
void testIsCheckpointingDuringRecoveryEnabled() {
// Test when both options are disabled (default) - should return false
Configuration defaultConfig = new Configuration();
assertThat(CheckpointingOptions.isCheckpointingDuringRecoveryEnabled(defaultConfig))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -192,5 +192,12 @@ public String toString() {

public abstract CompletableFuture<Void> getStateConsumedFuture();

/**
 * Returns a future that completes when buffer filtering is complete for all channels. This
 * future completes before {@link #getStateConsumedFuture()}, enabling an earlier RUNNING state
 * transition when checkpointing during recovery is enabled.
 *
 * @return a future signalling that buffer filtering has finished
 */
public abstract CompletableFuture<Void> getBufferFilteringCompleteFuture();
Comment thread
pnowojski marked this conversation as resolved.

public abstract void finishReadRecoveredState() throws IOException;
}
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,13 @@ public abstract class RecoveredInputChannel extends InputChannel implements Chan
private final CompletableFuture<?> stateConsumedFuture = new CompletableFuture<>();
protected final BufferManager bufferManager;

/**
 * Future that completes when recovered buffers have been filtered for this channel. It is
 * completed in {@link #finishReadRecoveredState()} and therefore before
 * {@code stateConsumedFuture}, enabling an earlier RUNNING state transition when checkpointing
 * during recovery is enabled.
 */
private final CompletableFuture<Void> bufferFilteringCompleteFuture = new CompletableFuture<>();

@GuardedBy("receivedBuffers")
private boolean isReleased;

Expand Down Expand Up @@ -110,7 +117,11 @@ public void setChannelStateWriter(ChannelStateWriter channelStateWriter) {

public final InputChannel toInputChannel() throws IOException {
Preconditions.checkState(
stateConsumedFuture.isDone(), "recovered state is not fully consumed");
bufferFilteringCompleteFuture.isDone(), "buffer filtering is not complete");
if (!inputGate.isCheckpointingDuringRecoveryEnabled()) {
Preconditions.checkState(
stateConsumedFuture.isDone(), "recovered state is not fully consumed");
}

// Extract remaining buffers before conversion.
// These buffers have been filtered but not yet consumed by the Task.
Expand Down Expand Up @@ -140,6 +151,14 @@ public void checkpointStopped(long checkpointId) {
protected abstract InputChannel toInputChannelInternal(ArrayDeque<Buffer> remainingBuffers)
throws IOException;

/**
 * Returns the future that completes when buffer filtering is complete for this channel. The
 * future is completed by {@link #finishReadRecoveredState()}, i.e. before
 * {@code stateConsumedFuture}.
 *
 * @return the per-channel future signalling that buffer filtering has finished
 */
CompletableFuture<Void> getBufferFilteringCompleteFuture() {
    return bufferFilteringCompleteFuture;
}

/**
 * Returns the future that completes once the {@code EndOfInputChannelStateEvent} has been
 * taken from this channel's recovered buffers.
 *
 * @return the per-channel future signalling that the recovered state has been consumed
 */
CompletableFuture<?> getStateConsumedFuture() {
    return stateConsumedFuture;
}
Expand Down Expand Up @@ -176,8 +195,22 @@ public void onRecoveredStateBuffer(Buffer buffer) {
}

public void finishReadRecoveredState() throws IOException {
onRecoveredStateBuffer(
EventSerializer.toBuffer(EndOfInputChannelStateEvent.INSTANCE, false));
// Adding the event and completing the future must be atomic under receivedBuffers lock.
// Without this, either ordering has a race:
// - event first: task thread consumes EndOfInputChannelStateEvent, which completes
// stateConsumedFuture. When checkpointing during recovery is disabled,
// stateConsumedFuture triggers requestPartitions -> toInputChannel(), which
// fails because bufferFilteringCompleteFuture is not yet done.
// - future first: toInputChannel() extracts buffers before the event is added,
// losing the EndOfInputChannelStateEvent.
// Both toInputChannel() and getNextRecoveredStateBuffer() synchronize on
// receivedBuffers, so holding the same lock here guarantees
// bufferFilteringCompleteFuture is always done before stateConsumedFuture.
synchronized (receivedBuffers) {
onRecoveredStateBuffer(
EventSerializer.toBuffer(EndOfInputChannelStateEvent.INSTANCE, false));
bufferFilteringCompleteFuture.complete(null);
}
bufferManager.releaseFloatingBuffers();
LOG.debug("{}/{} finished recovering input.", inputGate.getOwningTaskName(), channelInfo);
}
Expand All @@ -196,6 +229,8 @@ private BufferAndAvailability getNextRecoveredStateBuffer() throws IOException {
if (next == null) {
return null;
} else if (isEndOfInputChannelStateEvent(next)) {
Preconditions.checkState(
bufferFilteringCompleteFuture.isDone(), "buffer filtering is not complete");
stateConsumedFuture.complete(null);
return null;
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,21 @@ public boolean isCheckpointingDuringRecoveryEnabled() {
return checkpointingDuringRecoveryEnabled;
}

/**
 * Combines the buffer-filtering futures of every {@link RecoveredInputChannel} currently
 * registered in this gate. Channels of any other type are ignored.
 *
 * @return a future that completes once all collected per-channel futures have completed
 */
@Override
public CompletableFuture<Void> getBufferFilteringCompleteFuture() {
    synchronized (requestLock) {
        final List<CompletableFuture<?>> filteringFutures =
                new ArrayList<>(numberOfInputChannels);
        for (InputChannel channel : inputChannels()) {
            if (!(channel instanceof RecoveredInputChannel)) {
                // Only recovered channels expose a buffer-filtering future.
                continue;
            }
            final RecoveredInputChannel recovered = (RecoveredInputChannel) channel;
            filteringFutures.add(recovered.getBufferFilteringCompleteFuture());
        }
        final CompletableFuture<?>[] pending =
                filteringFutures.toArray(new CompletableFuture[0]);
        return CompletableFuture.allOf(pending);
    }
}

@Override
public void requestPartitions() {
synchronized (requestLock) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,15 @@ public CompletableFuture<Void> getStateConsumedFuture() {
.toArray(new CompletableFuture[] {}));
}

/**
 * Combines the buffer-filtering futures of all gates held in {@code inputGatesByGateIndex}.
 *
 * @return a future that completes once every per-gate future has completed
 */
@Override
public CompletableFuture<Void> getBufferFilteringCompleteFuture() {
    // Stream the per-gate futures straight into an array; no intermediate list is needed.
    return CompletableFuture.allOf(
            inputGatesByGateIndex.values().stream()
                    .map(InputGate::getBufferFilteringCompleteFuture)
                    .toArray(CompletableFuture[]::new));
}

@Override
public void requestPartitions() throws IOException {
for (InputGate inputGate : inputGatesByGateIndex.values()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,11 @@ public CompletableFuture<Void> getStateConsumedFuture() {
return inputGate.getStateConsumedFuture();
}

/** Delegates to {@code inputGate} and returns its buffer-filtering-complete future. */
@Override
public CompletableFuture<Void> getBufferFilteringCompleteFuture() {
    return inputGate.getBufferFilteringCompleteFuture();
}

@Override
public void requestPartitions() throws IOException {
inputGate.requestPartitions();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -883,6 +883,9 @@ private CompletableFuture<Void> restoreStateAndGates(

boolean checkpointingDuringRecoveryEnabled =
CheckpointingOptions.isCheckpointingDuringRecoveryEnabled(getJobConfiguration());

// Must set the flag on input gates BEFORE starting the async read task: once that task calls
// finishReadRecoveredState(), bufferFilteringCompleteFuture completes and can trigger
// requestPartitions() -> toInputChannel(), which reads this flag.
for (IndexedInputGate inputGate : inputGates) {
inputGate.setCheckpointingDuringRecoveryEnabled(checkpointingDuringRecoveryEnabled);
}
Expand All @@ -899,22 +902,37 @@ private CompletableFuture<Void> restoreStateAndGates(

// We wait for all input channel state to recover before we go into RUNNING state, and thus
// start checkpointing. If we implement incremental checkpointing of input channel state
// we must make sure it supports CheckpointType#FULL_CHECKPOINT
// we must make sure it supports CheckpointType#FULL_CHECKPOINT.
List<CompletableFuture<?>> recoveredFutures = new ArrayList<>(inputGates.length);
for (InputGate inputGate : inputGates) {
recoveredFutures.add(inputGate.getStateConsumedFuture());

inputGate
.getStateConsumedFuture()
.thenRun(
() ->
mainMailboxExecutor.execute(
inputGate::requestPartitions,
"Input gate request partitions"));
CompletableFuture<?> requestPartitionsTrigger =
checkpointingDuringRecoveryEnabled
? inputGate.getBufferFilteringCompleteFuture()
: inputGate.getStateConsumedFuture();

recoveredFutures.add(requestPartitionsTrigger);

requestPartitionsTrigger.thenRun(
() ->
mainMailboxExecutor.execute(
inputGate::requestPartitions, "Input gate request partitions"));
}

return CompletableFuture.allOf(recoveredFutures.toArray(new CompletableFuture[0]))
.thenRun(mailboxProcessor::suspend);
// Return allOf future instead of thenRun future. thenRun() returns a NEW future that
// completes only after the callback finishes. CompletableFuture executes thenRun callbacks
// synchronously on the thread that calls complete(). When recoveredFutures contains
// bufferFilteringCompleteFuture (checkpointingDuringRecovery enabled), complete() is called
// on channelIOExecutor (in finishReadRecoveredState), so thenRun(suspend) also runs on
// channelIOExecutor. suspend() sends a poison mail, and the mailbox thread can pick it up
// and exit runMailboxLoop() before the thenRun future completes — causing
// checkState(isDone) to fail. With stateConsumedFuture (the default), complete() runs on
// the mailbox thread itself, so thenRun(suspend) blocks the loop from processing the poison
// mail until the future completes — no race. Returning allOf future avoids the issue
// entirely.
CompletableFuture<Void> allRecoveredFuture =
CompletableFuture.allOf(recoveredFutures.toArray(new CompletableFuture[0]));
allRecoveredFuture.thenRun(mailboxProcessor::suspend);
return allRecoveredFuture;
}

private void ensureNotCanceled() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,32 +28,38 @@

import org.junit.jupiter.api.Test;

import java.io.IOException;
import java.util.ArrayDeque;

import static org.apache.flink.runtime.checkpoint.CheckpointOptions.unaligned;
import static org.apache.flink.runtime.state.CheckpointStorageLocationReference.getDefault;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

/** Tests for {@link RecoveredInputChannel}. */
class RecoveredInputChannelTest {

@Test
void testConversionOnlyPossibleAfterConsumed() {
assertThatThrownBy(() -> buildChannel().toInputChannel())
.isInstanceOf(IllegalStateException.class);
void testConversionOnlyPossibleAfterBufferFilteringComplete() {
// toInputChannel() always checks bufferFilteringCompleteFuture regardless of config
for (boolean configEnabled : new boolean[] {true, false}) {
assertThatThrownBy(() -> buildChannel(configEnabled).toInputChannel())
.isInstanceOf(IllegalStateException.class)
.hasMessageContaining("buffer filtering is not complete");
}
}

@Test
void testRequestPartitionsImpossible() {
assertThatThrownBy(() -> buildChannel().requestSubpartitions())
assertThatThrownBy(() -> buildChannel(false).requestSubpartitions())
.isInstanceOf(UnsupportedOperationException.class);
}

@Test
void testCheckpointStartImpossible() {
assertThatThrownBy(
() ->
buildChannel()
buildChannel(false)
.checkpointStarted(
new CheckpointBarrier(
0L,
Expand All @@ -64,10 +70,96 @@ void testCheckpointStartImpossible() {
.isInstanceOf(CheckpointException.class);
}

private RecoveredInputChannel buildChannel() {
@Test
void testToInputChannelAllowedWhenBufferFilteringCompleteAndConfigEnabled() throws IOException {
    // With the config enabled, a completed bufferFilteringCompleteFuture is sufficient for
    // conversion; stateConsumedFuture does not have to be done.
    final TestableRecoveredInputChannel recoveredChannel = buildTestableChannel(true);

    // Before filtering has finished, the conversion must be rejected.
    assertThatThrownBy(recoveredChannel::toInputChannel)
            .isInstanceOf(IllegalStateException.class)
            .hasMessageContaining("buffer filtering is not complete");

    // finishReadRecoveredState() completes only the buffer-filtering future.
    recoveredChannel.finishReadRecoveredState();
    assertThat(recoveredChannel.getBufferFilteringCompleteFuture()).isDone();
    assertThat(recoveredChannel.getStateConsumedFuture()).isNotDone();

    // The conversion now goes through without an exception.
    final InputChannel result = recoveredChannel.toInputChannel();
    assertThat(result).isNotNull();
}

@Test
void testToInputChannelAllowedWhenStateConsumedAndConfigDisabled() throws IOException {
    // With the config disabled, conversion needs bufferFilteringCompleteFuture AND
    // stateConsumedFuture to be done.
    final TestableRecoveredInputChannel recoveredChannel = buildTestableChannel(false);

    // Step 1: nothing has finished yet, so the buffer-filtering check fails first.
    assertThatThrownBy(recoveredChannel::toInputChannel)
            .isInstanceOf(IllegalStateException.class)
            .hasMessageContaining("buffer filtering is not complete");

    // Step 2: finishing the read completes the filtering future only.
    recoveredChannel.finishReadRecoveredState();
    assertThat(recoveredChannel.getBufferFilteringCompleteFuture()).isDone();
    assertThat(recoveredChannel.getStateConsumedFuture()).isNotDone();

    // Step 3: conversion is still rejected, now by the state-consumed check.
    assertThatThrownBy(recoveredChannel::toInputChannel)
            .isInstanceOf(IllegalStateException.class)
            .hasMessageContaining("recovered state is not fully consumed");

    // Step 4: consuming the EndOfInputChannelStateEvent completes stateConsumedFuture.
    assertThat(recoveredChannel.getNextBuffer()).isNotPresent();
    assertThat(recoveredChannel.getStateConsumedFuture()).isDone();

    // Step 5: both preconditions hold, so the conversion succeeds.
    final InputChannel result = recoveredChannel.toInputChannel();
    assertThat(result).isNotNull();
}

@Test
void testBufferFilteringCompleteFutureAlwaysCompletes() throws IOException {
    // bufferFilteringCompleteFuture is completed by finishReadRecoveredState() for both
    // values of the config flag.
    final boolean[] flagValues = {true, false};
    for (boolean flag : flagValues) {
        final RecoveredInputChannel recoveredChannel = buildChannel(flag);

        assertThat(recoveredChannel.getBufferFilteringCompleteFuture()).isNotDone();

        recoveredChannel.finishReadRecoveredState();

        assertThat(recoveredChannel.getBufferFilteringCompleteFuture()).isDone();
    }
}

@Test
void testStateConsumedFutureCompletesAfterConsumingAllBuffers() throws IOException {
    // stateConsumedFuture must complete once the EndOfInputChannelStateEvent has been
    // consumed, independent of the config flag.
    final boolean[] flagValues = {true, false};
    for (boolean flag : flagValues) {
        final RecoveredInputChannel recoveredChannel = buildChannel(flag);
        assertThat(recoveredChannel.getStateConsumedFuture()).isNotDone();

        // Finishing the read alone does not complete the future.
        recoveredChannel.finishReadRecoveredState();
        assertThat(recoveredChannel.getStateConsumedFuture()).isNotDone();

        // getNextBuffer() yields nothing when it hits the event internally; that
        // consumption is what completes the future.
        assertThat(recoveredChannel.getNextBuffer()).isNotPresent();
        assertThat(recoveredChannel.getStateConsumedFuture()).isDone();
    }
}

private RecoveredInputChannel buildChannel(boolean checkpointingDuringRecoveryEnabled) {
try {
SingleInputGate inputGate =
new SingleInputGateBuilder()
.setCheckpointingDuringRecoveryEnabled(
checkpointingDuringRecoveryEnabled)
.build();
return new RecoveredInputChannel(
new SingleInputGateBuilder().build(),
inputGate,
0,
new ResultPartitionID(),
new ResultSubpartitionIndexSet(0),
Expand All @@ -85,4 +177,41 @@ protected InputChannel toInputChannelInternal(ArrayDeque<Buffer> remainingBuffer
throw new AssertionError("channel creation failed", e);
}
}

/**
 * Builds a {@link TestableRecoveredInputChannel} on top of a freshly built
 * {@link SingleInputGate}.
 *
 * @param checkpointingDuringRecoveryEnabled value passed to the gate builder's
 *     {@code setCheckpointingDuringRecoveryEnabled}
 * @return the new channel
 * @throws AssertionError wrapping any exception raised while building the gate or channel
 */
private TestableRecoveredInputChannel buildTestableChannel(
        boolean checkpointingDuringRecoveryEnabled) {
    try {
        SingleInputGate inputGate =
                new SingleInputGateBuilder()
                        .setCheckpointingDuringRecoveryEnabled(
                                checkpointingDuringRecoveryEnabled)
                        .build();
        return new TestableRecoveredInputChannel(inputGate);
    } catch (Exception e) {
        // Surface setup problems as a test failure while keeping the original cause.
        throw new AssertionError("channel creation failed", e);
    }
}

/**
* A RecoveredInputChannel that returns a TestInputChannel when converted, for testing purposes.
*/
private static class TestableRecoveredInputChannel extends RecoveredInputChannel {
    /**
     * Creates the channel on the given gate with fixed test values for the remaining
     * superclass constructor arguments.
     */
    TestableRecoveredInputChannel(SingleInputGate inputGate) {
        super(
                inputGate,
                0,
                new ResultPartitionID(),
                new ResultSubpartitionIndexSet(0),
                0,
                0,
                new SimpleCounter(),
                new SimpleCounter(),
                10);
    }

    /**
     * Returns a new {@code TestInputChannel} created from {@code inputGate} and {@code 0}.
     * The {@code remainingBuffers} argument is not used.
     */
    @Override
    protected InputChannel toInputChannelInternal(ArrayDeque<Buffer> remainingBuffers) {
        return new TestInputChannel(inputGate, 0);
    }
}
}
Loading