Skip to content

Commit 95eecfd

Browse files
committed
[FLINK-38543][network] Buffer migration from RecoveredInputChannel to physical channels
1 parent 4fa25ef commit 95eecfd

13 files changed

Lines changed: 203 additions & 26 deletions

File tree

flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannel.java

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,8 @@ public LocalInputChannel(
9898
int maxBackoff,
9999
Counter numBytesIn,
100100
Counter numBuffersIn,
101-
ChannelStateWriter stateWriter) {
101+
ChannelStateWriter stateWriter,
102+
@Nullable ArrayDeque<Buffer> initialRecoveredBuffers) {
102103

103104
super(
104105
inputGate,
@@ -113,6 +114,31 @@ public LocalInputChannel(
113114
this.partitionManager = checkNotNull(partitionManager);
114115
this.taskEventPublisher = checkNotNull(taskEventPublisher);
115116
this.channelStatePersister = new ChannelStatePersister(stateWriter, getChannelInfo());
117+
118+
// Migrate recovered buffers from RecoveredInputChannel if provided.
119+
// These buffers have been filtered but not yet consumed by the Task.
120+
if (initialRecoveredBuffers != null && !initialRecoveredBuffers.isEmpty()) {
121+
final int expectedCount = initialRecoveredBuffers.size();
122+
// Sequence number starts at Integer.MIN_VALUE, consistent with RecoveredInputChannel.
123+
int seqNum = Integer.MIN_VALUE;
124+
while (!initialRecoveredBuffers.isEmpty()) {
125+
Buffer buffer = initialRecoveredBuffers.poll();
126+
// Determine next data type based on the next buffer in the queue
127+
Buffer.DataType nextDataType =
128+
initialRecoveredBuffers.isEmpty()
129+
? Buffer.DataType.NONE
130+
: initialRecoveredBuffers.peek().getDataType();
131+
// buffersInBacklog is set to 0 as these are recovered buffers
132+
BufferAndBacklog bufferAndBacklog =
133+
new BufferAndBacklog(buffer, 0, nextDataType, seqNum++);
134+
toBeConsumedBuffers.add(bufferAndBacklog);
135+
}
136+
checkState(
137+
toBeConsumedBuffers.size() == expectedCount,
138+
"Buffer migration failed: expected %s buffers but got %s",
139+
expectedCount,
140+
toBeConsumedBuffers.size());
141+
}
116142
}
117143

118144
// ------------------------------------------------------------------------

flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/LocalRecoveredInputChannel.java

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,14 @@
1919
package org.apache.flink.runtime.io.network.partition.consumer;
2020

2121
import org.apache.flink.runtime.io.network.TaskEventPublisher;
22+
import org.apache.flink.runtime.io.network.buffer.Buffer;
2223
import org.apache.flink.runtime.io.network.metrics.InputChannelMetrics;
2324
import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
2425
import org.apache.flink.runtime.io.network.partition.ResultPartitionManager;
2526
import org.apache.flink.runtime.io.network.partition.ResultSubpartitionIndexSet;
2627

28+
import java.util.ArrayDeque;
29+
2730
import static org.apache.flink.util.Preconditions.checkNotNull;
2831

2932
/**
@@ -61,7 +64,7 @@ public class LocalRecoveredInputChannel extends RecoveredInputChannel {
6164
}
6265

6366
@Override
64-
protected InputChannel toInputChannelInternal() {
67+
protected InputChannel toInputChannelInternal(ArrayDeque<Buffer> remainingBuffers) {
6568
return new LocalInputChannel(
6669
inputGate,
6770
getChannelIndex(),
@@ -73,6 +76,7 @@ protected InputChannel toInputChannelInternal() {
7376
maxBackoff,
7477
numBytesIn,
7578
numBuffersIn,
76-
channelStateWriter);
79+
channelStateWriter,
80+
remainingBuffers);
7781
}
7882
}

flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RecoveredInputChannel.java

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,16 @@ public void setChannelStateWriter(ChannelStateWriter channelStateWriter) {
111111
public final InputChannel toInputChannel() throws IOException {
112112
Preconditions.checkState(
113113
stateConsumedFuture.isDone(), "recovered state is not fully consumed");
114-
final InputChannel inputChannel = toInputChannelInternal();
114+
115+
// Extract remaining buffers before conversion.
116+
// These buffers have been filtered but not yet consumed by the Task.
117+
final ArrayDeque<Buffer> remainingBuffers;
118+
synchronized (receivedBuffers) {
119+
remainingBuffers = new ArrayDeque<>(receivedBuffers);
120+
receivedBuffers.clear();
121+
}
122+
123+
final InputChannel inputChannel = toInputChannelInternal(remainingBuffers);
115124
inputChannel.checkpointStopped(lastStoppedCheckpointId);
116125
return inputChannel;
117126
}
@@ -121,7 +130,15 @@ public void checkpointStopped(long checkpointId) {
121130
this.lastStoppedCheckpointId = checkpointId;
122131
}
123132

124-
protected abstract InputChannel toInputChannelInternal() throws IOException;
133+
/**
134+
* Creates the physical InputChannel from this recovered channel.
135+
*
136+
* @param remainingBuffers buffers that have been filtered but not yet consumed by the Task.
137+
* These buffers will be migrated to the new physical channel.
138+
* @return the physical InputChannel (LocalInputChannel or RemoteInputChannel)
139+
*/
140+
protected abstract InputChannel toInputChannelInternal(ArrayDeque<Buffer> remainingBuffers)
141+
throws IOException;
125142

126143
CompletableFuture<?> getStateConsumedFuture() {
127144
return stateConsumedFuture;

flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannel.java

Lines changed: 43 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,8 @@ public RemoteInputChannel(
138138
int networkBuffersPerChannel,
139139
Counter numBytesIn,
140140
Counter numBuffersIn,
141-
ChannelStateWriter stateWriter) {
141+
ChannelStateWriter stateWriter,
142+
@Nullable ArrayDeque<Buffer> initialRecoveredBuffers) {
142143

143144
super(
144145
inputGate,
@@ -157,6 +158,29 @@ public RemoteInputChannel(
157158
this.connectionManager = checkNotNull(connectionManager);
158159
this.bufferManager = new BufferManager(inputGate.getMemorySegmentProvider(), this, 0);
159160
this.channelStatePersister = new ChannelStatePersister(stateWriter, getChannelInfo());
161+
162+
// Migrate recovered buffers from RecoveredInputChannel if provided.
163+
// These buffers have been filtered but not yet consumed by the Task.
164+
if (initialRecoveredBuffers != null && !initialRecoveredBuffers.isEmpty()) {
165+
final int expectedCount = initialRecoveredBuffers.size();
166+
// Sequence number starts at Integer.MIN_VALUE, consistent with RecoveredInputChannel.
167+
int seqNum = Integer.MIN_VALUE;
168+
for (Buffer buffer : initialRecoveredBuffers) {
169+
// subpartitionId is set to 0 for recovered buffers. This is correct because:
170+
// 1) For single-subpartition channels, the only valid subpartition is 0.
171+
// 2) For multi-subpartition channels (consumedSubpartitionIndexSet.size() > 1),
172+
// RecoveryMetadata events embedded in the recovered buffer sequence track
173+
// the actual subpartition context for proper routing.
174+
SequenceBuffer sequenceBuffer = new SequenceBuffer(buffer, seqNum++, 0);
175+
receivedBuffers.add(sequenceBuffer);
176+
totalQueueSizeInBytes += buffer.getSize();
177+
}
178+
checkState(
179+
receivedBuffers.size() == expectedCount,
180+
"Buffer migration failed: expected %s buffers but got %s",
181+
expectedCount,
182+
receivedBuffers.size());
183+
}
160184
}
161185

162186
@VisibleForTesting
@@ -239,9 +263,9 @@ protected boolean increaseBackoff() {
239263

240264
@Override
241265
protected int peekNextBufferSubpartitionIdInternal() throws IOException {
242-
checkPartitionRequestQueueInitialized();
243-
244266
synchronized (receivedBuffers) {
267+
checkReadability();
268+
245269
final SequenceBuffer next = receivedBuffers.peek();
246270

247271
if (next != null) {
@@ -254,12 +278,12 @@ protected int peekNextBufferSubpartitionIdInternal() throws IOException {
254278

255279
@Override
256280
public Optional<BufferAndAvailability> getNextBuffer() throws IOException {
257-
checkPartitionRequestQueueInitialized();
258-
259281
final SequenceBuffer next;
260282
final DataType nextDataType;
261283

262284
synchronized (receivedBuffers) {
285+
checkReadability();
286+
263287
next = receivedBuffers.poll();
264288

265289
if (next != null) {
@@ -879,6 +903,20 @@ public void onError(Throwable cause) {
879903
setError(cause);
880904
}
881905

906+
/**
907+
* When receivedBuffers contains migrated buffers from RecoveredInputChannel, they can be read
908+
* before requestSubpartitions(). In that case only check for errors. Once migrated buffers are
909+
* drained, require full client initialization check.
910+
*/
911+
private void checkReadability() throws IOException {
912+
assert Thread.holdsLock(receivedBuffers);
913+
if (receivedBuffers.isEmpty()) {
914+
checkPartitionRequestQueueInitialized();
915+
} else {
916+
checkError();
917+
}
918+
}
919+
882920
private void checkPartitionRequestQueueInitialized() throws IOException {
883921
checkError();
884922
checkState(

flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteRecoveredInputChannel.java

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,13 @@
2020

2121
import org.apache.flink.runtime.io.network.ConnectionID;
2222
import org.apache.flink.runtime.io.network.ConnectionManager;
23+
import org.apache.flink.runtime.io.network.buffer.Buffer;
2324
import org.apache.flink.runtime.io.network.metrics.InputChannelMetrics;
2425
import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
2526
import org.apache.flink.runtime.io.network.partition.ResultSubpartitionIndexSet;
2627

2728
import java.io.IOException;
29+
import java.util.ArrayDeque;
2830

2931
import static org.apache.flink.util.Preconditions.checkNotNull;
3032

@@ -66,7 +68,8 @@ public class RemoteRecoveredInputChannel extends RecoveredInputChannel {
6668
}
6769

6870
@Override
69-
protected InputChannel toInputChannelInternal() throws IOException {
71+
protected InputChannel toInputChannelInternal(ArrayDeque<Buffer> remainingBuffers)
72+
throws IOException {
7073
RemoteInputChannel remoteInputChannel =
7174
new RemoteInputChannel(
7275
inputGate,
@@ -81,7 +84,8 @@ protected InputChannel toInputChannelInternal() throws IOException {
8184
networkBuffersPerChannel,
8285
numBytesIn,
8386
numBuffersIn,
84-
channelStateWriter);
87+
channelStateWriter,
88+
remainingBuffers);
8589
remoteInputChannel.setup();
8690
return remoteInputChannel;
8791
}

flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,19 @@ public void requestPartitions() {
375375
}
376376
}
377377

378+
/**
379+
* Converts all {@link RecoveredInputChannel}s to their real channel types ({@link
380+
* LocalInputChannel} or {@link RemoteInputChannel}).
381+
*
382+
* <p><b>Lock ordering note:</b> This method acquires {@code inputChannelsWithData} and then may
383+
* indirectly acquire {@code receivedBuffers} (via {@code toInputChannel()} and {@code
384+
* releaseAllResources()}). This is the reverse order of {@link
385+
* RecoveredInputChannel#onRecoveredStateBuffer}, which acquires {@code receivedBuffers} first
386+
* and then {@code inputChannelsWithData} (via {@code notifyChannelNonEmpty()}). This is safe
387+
* because {@code convertRecoveredInputChannels()} is only called from {@link
388+
* #requestPartitions()}, which happens after all state recovery is complete (buffer filtering
389+
* future is done), so {@code onRecoveredStateBuffer()} is no longer being called concurrently.
390+
*/
378391
@VisibleForTesting
379392
public void convertRecoveredInputChannels() {
380393
LOG.debug("Converting recovered input channels ({} channels)", getNumberOfInputChannels());
@@ -384,19 +397,37 @@ public void convertRecoveredInputChannels() {
384397
new HashSet<>(inputChannelsForCurrentPartition.keySet());
385398
for (InputChannelInfo inputChannelInfo : oldInputChannelInfos) {
386399
InputChannel inputChannel = inputChannelsForCurrentPartition.get(inputChannelInfo);
387-
if (inputChannel instanceof RecoveredInputChannel) {
388-
try {
400+
if (!(inputChannel instanceof RecoveredInputChannel)) {
401+
continue;
402+
}
403+
try {
404+
synchronized (inputChannelsWithData) {
405+
// Remove old channel from queue if present
406+
if (inputChannelsWithData.contains(inputChannel)) {
407+
inputChannelsWithData.getAndRemove(ch -> ch == inputChannel);
408+
}
409+
enqueuedInputChannelsWithData.clear(inputChannel.getChannelIndex());
410+
411+
// Convert the channel
389412
InputChannel realInputChannel =
390413
((RecoveredInputChannel) inputChannel).toInputChannel();
391414
inputChannel.releaseAllResources();
415+
416+
// Update data structures
392417
inputChannelsForCurrentPartition.remove(inputChannelInfo);
393418
inputChannelsForCurrentPartition.put(
394419
realInputChannel.getChannelInfo(), realInputChannel);
395420
channels[inputChannel.getChannelIndex()] = realInputChannel;
396-
} catch (Throwable t) {
397-
inputChannel.setError(t);
398-
return;
421+
422+
// If the new channel has buffered data, enqueue it
423+
if (realInputChannel.getBuffersInUseCount() > 0) {
424+
inputChannelsWithData.add(realInputChannel);
425+
enqueuedInputChannelsWithData.set(realInputChannel.getChannelIndex());
426+
}
399427
}
428+
} catch (Throwable t) {
429+
inputChannel.setError(t);
430+
return;
400431
}
401432
}
402433
}

flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnknownInputChannel.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,8 @@ public RemoteInputChannel toRemoteInputChannel(
183183
networkBuffersPerChannel,
184184
metrics.getNumBytesInRemoteCounter(),
185185
metrics.getNumBuffersInRemoteCounter(),
186-
channelStateWriter == null ? ChannelStateWriter.NO_OP : channelStateWriter);
186+
channelStateWriter == null ? ChannelStateWriter.NO_OP : channelStateWriter,
187+
null);
187188
}
188189

189190
public LocalInputChannel toLocalInputChannel(ResultPartitionID resultPartitionID) {
@@ -198,7 +199,8 @@ public LocalInputChannel toLocalInputChannel(ResultPartitionID resultPartitionID
198199
maxBackoff,
199200
metrics.getNumBytesInLocalCounter(),
200201
metrics.getNumBuffersInLocalCounter(),
201-
channelStateWriter == null ? ChannelStateWriter.NO_OP : channelStateWriter);
202+
channelStateWriter == null ? ChannelStateWriter.NO_OP : channelStateWriter,
203+
null);
202204
}
203205

204206
@Override

flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/CreditBasedPartitionRequestClientHandlerTest.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -951,7 +951,8 @@ private static class TestRemoteInputChannelForError extends RemoteInputChannel {
951951
2,
952952
new SimpleCounter(),
953953
new SimpleCounter(),
954-
ChannelStateWriter.NO_OP);
954+
ChannelStateWriter.NO_OP,
955+
null);
955956
this.expectedMessage = expectedMessage;
956957
}
957958

flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestRegistrationTest.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,8 @@ private static class TestRemoteInputChannelForPartitionNotFound extends RemoteIn
248248
2,
249249
new SimpleCounter(),
250250
new SimpleCounter(),
251-
ChannelStateWriter.NO_OP);
251+
ChannelStateWriter.NO_OP,
252+
null);
252253
this.latch = latch;
253254
}
254255

flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/InputChannelBuilder.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,8 @@ public LocalInputChannel buildLocalChannel(SingleInputGate inputGate) {
164164
maxBackoff,
165165
metrics.getNumBytesInLocalCounter(),
166166
metrics.getNumBuffersInLocalCounter(),
167-
stateWriter);
167+
stateWriter,
168+
null);
168169
}
169170

170171
public RemoteInputChannel buildRemoteChannel(SingleInputGate inputGate) {
@@ -181,7 +182,8 @@ public RemoteInputChannel buildRemoteChannel(SingleInputGate inputGate) {
181182
networkBuffersPerChannel,
182183
metrics.getNumBytesInRemoteCounter(),
183184
metrics.getNumBuffersInRemoteCounter(),
184-
stateWriter);
185+
stateWriter,
186+
null);
185187
}
186188

187189
public LocalRecoveredInputChannel buildLocalRecoveredChannel(SingleInputGate inputGate) {

0 commit comments

Comments
 (0)