From 51546796aa39c55732e0f5de41d76f3e15273f9f Mon Sep 17 00:00:00 2001 From: Arun Pandian Date: Tue, 2 Jun 2026 05:35:05 +0000 Subject: [PATCH 01/21] [Dataflow Streaming][Multikey] Support MultiKey commits in windmill clients - Add MultiKeyWorkItemCommitRequest to windmill.proto. - Support MultiKey commits in Commit model and StreamingEngineWorkCommitter. - Update GrpcCommitWorkStream to batch and stream MultiKey commit requests. --- .../windmill/client/WindmillStream.java | 5 + .../windmill/client/commits/Commit.java | 34 ++- .../client/commits/CompleteCommit.java | 15 - .../StreamingApplianceWorkCommitter.java | 8 +- .../commits/StreamingEngineWorkCommitter.java | 66 +++-- .../client/grpc/GrpcCommitWorkStream.java | 67 +++-- .../dataflow/worker/FakeWindmillServer.java | 71 ++++- .../StreamingApplianceWorkCommitterTest.java | 9 +- .../StreamingEngineWorkCommitterTest.java | 270 ++++++++++++++++-- .../client/grpc/GrpcCommitWorkStreamTest.java | 70 +++++ .../windmill/src/main/proto/windmill.proto | 28 ++ 11 files changed, 558 insertions(+), 85 deletions(-) diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java index 526b67890783..36001c151508 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java @@ -108,6 +108,11 @@ boolean commitWorkItem( Windmill.WorkItemCommitRequest request, Consumer onDone); + boolean commitMultiKeyWorkItem( + String computation, + Windmill.MultiKeyWorkItemCommitRequest request, + Consumer onDone); + /** Flushes any pending work items to the wire. 
*/ void flush(); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/Commit.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/Commit.java index b840d22a3434..e52a9846645f 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/Commit.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/Commit.java @@ -18,11 +18,14 @@ package org.apache.beam.runners.dataflow.worker.windmill.client.commits; import com.google.auto.value.AutoValue; +import java.util.Optional; import org.apache.beam.runners.dataflow.worker.streaming.ComputationState; import org.apache.beam.runners.dataflow.worker.streaming.Work; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItemCommitRequest; import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; /** Value class for a queued commit. 
*/ @Internal @@ -32,20 +35,43 @@ public abstract class Commit { public static Commit create( WorkItemCommitRequest request, ComputationState computationState, Work work) { Preconditions.checkArgument(request.getSerializedSize() > 0); - return new AutoValue_Commit(request, computationState, work); + return new AutoValue_Commit( + Optional.of(request), computationState, Optional.empty(), ImmutableList.of(work)); + } + + public static Commit createMultiKey( + Windmill.MultiKeyWorkItemCommitRequest multiKeyRequest, + ComputationState computationState, + ImmutableList workBatch) { + Preconditions.checkArgument(!workBatch.isEmpty()); + return new AutoValue_Commit( + Optional.empty(), computationState, Optional.of(multiKeyRequest), workBatch); } public final String computationId() { return computationState().getComputationId(); } - public abstract WorkItemCommitRequest request(); + public abstract Optional singleKeyRequest(); public abstract ComputationState computationState(); - public abstract Work work(); + public abstract Optional multiKeyRequest(); + + public abstract ImmutableList workBatch(); + + public final boolean isFailed() { + for (Work w : workBatch()) { + if (w.isFailed()) { + return true; + } + } + return false; + } public final int getSize() { - return request().getSerializedSize(); + return multiKeyRequest() + .map(Windmill.MultiKeyWorkItemCommitRequest::getSerializedSize) + .orElseGet(() -> singleKeyRequest().get().getSerializedSize()); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/CompleteCommit.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/CompleteCommit.java index e33e853d3d76..6c0a5a98e2ab 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/CompleteCommit.java +++ 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/CompleteCommit.java @@ -37,26 +37,11 @@ @AutoValue public abstract class CompleteCommit { - public static CompleteCommit create(Commit commit, CommitStatus commitStatus) { - return new AutoValue_CompleteCommit( - commit.computationId(), - ShardedKey.create(commit.request().getKey(), commit.request().getShardingKey()), - WorkId.builder() - .setWorkToken(commit.request().getWorkToken()) - .setCacheToken(commit.request().getCacheToken()) - .build(), - commitStatus); - } - public static CompleteCommit create( String computationId, ShardedKey shardedKey, WorkId workId, CommitStatus status) { return new AutoValue_CompleteCommit(computationId, shardedKey, workId, status); } - public static CompleteCommit forFailedWork(Commit commit) { - return create(commit, CommitStatus.ABORTED); - } - public abstract String computationId(); public abstract ShardedKey shardedKey(); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitter.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitter.java index 20b95b0661d0..40e82c4ca368 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitter.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitter.java @@ -17,6 +17,8 @@ */ package org.apache.beam.runners.dataflow.worker.windmill.client.commits; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState; + import java.util.HashMap; import java.util.Map; import java.util.concurrent.ExecutorService; @@ -112,7 +114,8 
@@ private void commitLoop() { } while (commit != null) { ComputationState computationState = commit.computationState(); - commit.work().setState(Work.State.COMMITTING); + checkState(commit.workBatch().size() == 1); + commit.workBatch().get(0).setState(Work.State.COMMITTING); Windmill.ComputationCommitWorkRequest.Builder computationRequestBuilder = computationRequestMap.get(computationState); if (computationRequestBuilder == null) { @@ -120,7 +123,8 @@ private void commitLoop() { computationRequestBuilder.setComputationId(computationState.getComputationId()); computationRequestMap.put(computationState, computationRequestBuilder); } - computationRequestBuilder.addRequests(commit.request()); + checkState(commit.singleKeyRequest().isPresent()); + computationRequestBuilder.addRequests(commit.singleKeyRequest().get()); // Send the request if we've exceeded the bytes or there is no more // pending work. commitBytes is a long, so this cannot overflow. commitBytes += commit.getSize(); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitter.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitter.java index b68f53121b86..cb8e6d26d089 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitter.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitter.java @@ -30,6 +30,7 @@ import org.apache.beam.runners.dataflow.worker.streaming.WeightedBoundedQueue; import org.apache.beam.runners.dataflow.worker.streaming.WeightedSemaphore; import org.apache.beam.runners.dataflow.worker.streaming.Work; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.CommitStatus; import 
org.apache.beam.runners.dataflow.worker.windmill.client.CloseableStream; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.CommitWorkStream; import org.apache.beam.sdk.annotations.Internal; @@ -100,7 +101,7 @@ public void start() { @Override public void commit(Commit commit) { - if (commit.work().isFailed()) { + if (commit.isFailed()) { failCommit(commit); } else { commitQueue.put(commit); @@ -113,8 +114,8 @@ public void commit(Commit commit) { "Trying to queue commit on shutdown, failing commit=[computationId={}, shardingKey={}," + " workId={} ].", commit.computationId(), - commit.work().getShardedKey(), - commit.work().id()); + commit.workBatch().get(0).getShardedKey(), + commit.workBatch().get(0).id()); drainCommitQueue(); } } @@ -147,8 +148,12 @@ private void drainCommitQueue() { } private void failCommit(Commit commit) { - commit.work().setFailed(); - onCommitComplete.accept(CompleteCommit.forFailedWork(commit)); + for (Work w : commit.workBatch()) { + w.setFailed(); + onCommitComplete.accept( + CompleteCommit.create( + commit.computationId(), w.getShardedKey(), w.id(), CommitStatus.ABORTED)); + } } @Override @@ -173,8 +178,8 @@ private void streamingCommitLoop() { // take() blocks until a value is available in the commitQueue. Preconditions.checkNotNull(initialCommit); - if (initialCommit.work().isFailed()) { - onCommitComplete.accept(CompleteCommit.forFailedWork(initialCommit)); + if (initialCommit.isFailed()) { + failCommit(initialCommit); initialCommit = null; continue; } @@ -202,20 +207,43 @@ private void streamingCommitLoop() { /** Adds the commit to the batch if it fits, returning true if it is consumed. 
*/ private boolean tryAddToCommitBatch(Commit commit, CommitWorkStream.RequestBatcher batcher) { Preconditions.checkNotNull(commit); - commit.work().setState(Work.State.COMMITTING); + for (Work w : commit.workBatch()) { + w.setState(Work.State.COMMITTING); + } activeCommitBytes.addAndGet(commit.getSize()); - boolean isCommitAccepted = - batcher.commitWorkItem( - commit.computationId(), - commit.request(), - commitStatus -> { - onCommitComplete.accept(CompleteCommit.create(commit, commitStatus)); - activeCommitBytes.addAndGet(-commit.getSize()); - }); + boolean isCommitAccepted; + if (commit.multiKeyRequest().isPresent()) { + isCommitAccepted = + batcher.commitMultiKeyWorkItem( + commit.computationId(), + commit.multiKeyRequest().get(), + commitStatus -> { + for (Work w : commit.workBatch()) { + onCommitComplete.accept( + CompleteCommit.create( + commit.computationId(), w.getShardedKey(), w.id(), commitStatus)); + } + activeCommitBytes.addAndGet(-commit.getSize()); + }); + } else { + isCommitAccepted = + batcher.commitWorkItem( + commit.computationId(), + commit.singleKeyRequest().get(), + commitStatus -> { + Work w = commit.workBatch().get(0); + onCommitComplete.accept( + CompleteCommit.create( + commit.computationId(), w.getShardedKey(), w.id(), commitStatus)); + activeCommitBytes.addAndGet(-commit.getSize()); + }); + } // Since the commit was not accepted, revert the changes made above. if (!isCommitAccepted) { - commit.work().setState(Work.State.COMMIT_QUEUED); + for (Work w : commit.workBatch()) { + w.setState(Work.State.COMMIT_QUEUED); + } activeCommitBytes.addAndGet(-commit.getSize()); } @@ -246,8 +274,8 @@ private boolean tryAddToCommitBatch(Commit commit, CommitWorkStream.RequestBatch } // Drop commits for failed work. Such commits will be dropped by Windmill anyway. 
- if (commit.work().isFailed()) { - onCommitComplete.accept(CompleteCommit.forFailedWork(commit)); + if (commit.isFailed()) { + failCommit(commit); continue; } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java index d24676652fd8..afa736d7c3ad 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java @@ -35,6 +35,7 @@ import java.util.function.Function; import javax.annotation.Nullable; import org.apache.beam.repackaged.core.org.apache.commons.lang3.tuple.Pair; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.CommitStatus; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.JobHeader; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.StreamingCommitRequestChunk; @@ -270,7 +271,7 @@ private void flushInternal(Map requests) if (requests.size() == 1) { Map.Entry elem = requests.entrySet().iterator().next(); - if (elem.getValue().request().getSerializedSize() + if (elem.getValue().serializedCommit().size() > AbstractWindmillStream.RPC_STREAM_CHUNK_SIZE) { issueMultiChunkRequest(elem.getKey(), elem.getValue()); } else { @@ -289,6 +290,7 @@ private void issueSingleRequest(long id, PendingRequest pendingRequest) .setComputationId(pendingRequest.computationId()) .setRequestId(id) .setShardingKey(pendingRequest.shardingKey()) + .setCommitType(pendingRequest.commitType()) .setSerializedWorkItemCommit(pendingRequest.serializedCommit()); StreamingCommitWorkRequest chunk = 
requestBuilder.build(); synchronized (this) { @@ -318,7 +320,8 @@ private void issueBatchedRequest(Map requests) chunkBuilder .setRequestId(entry.getKey()) .setShardingKey(request.shardingKey()) - .setSerializedWorkItemCommit(request.serializedCommit()); + .setSerializedWorkItemCommit(request.serializedCommit()) + .setCommitType(request.commitType()); } StreamingCommitWorkRequest request = requestBuilder.build(); synchronized (this) { @@ -360,7 +363,8 @@ private void issueMultiChunkRequest(long id, PendingRequest pendingRequest) .setRequestId(id) .setSerializedWorkItemCommit(chunk) .setComputationId(pendingRequest.computationId()) - .setShardingKey(pendingRequest.shardingKey()); + .setShardingKey(pendingRequest.shardingKey()) + .setCommitType(pendingRequest.commitType()); int remaining = serializedCommit.size() - end; if (remaining > 0) { chunkBuilder.setRemainingBytesForWorkItem(remaining); @@ -378,34 +382,34 @@ private void issueMultiChunkRequest(long id, PendingRequest pendingRequest) @AutoValue abstract static class PendingRequest { - - private static PendingRequest create( - String computationId, WorkItemCommitRequest request, Consumer onDone) { - return new AutoValue_GrpcCommitWorkStream_PendingRequest(computationId, request, onDone); + static PendingRequest create( + String computationId, + long shardingKey, + ByteString serializedCommit, + StreamingCommitRequestChunk.CommitType commitType, + Consumer onDone) { + return new AutoValue_GrpcCommitWorkStream_PendingRequest( + computationId, shardingKey, serializedCommit, commitType, onDone); } abstract String computationId(); - abstract WorkItemCommitRequest request(); + abstract long shardingKey(); + + abstract ByteString serializedCommit(); + + abstract StreamingCommitRequestChunk.CommitType commitType(); abstract Consumer onDone(); private long getBytes() { - return (long) request().getSerializedSize() + computationId().length(); - } - - private ByteString serializedCommit() { - return 
request().toByteString(); + return (long) serializedCommit().size() + computationId().length(); } private void completeWithStatus(CommitStatus commitStatus) { onDone().accept(commitStatus); } - private long shardingKey() { - return request().getShardingKey(); - } - private void abort() { completeWithStatus(CommitStatus.ABORTED); } @@ -462,7 +466,34 @@ public boolean commitWorkItem( return false; } - PendingRequest request = PendingRequest.create(computation, commitRequest, onDone); + PendingRequest request = + PendingRequest.create( + computation, + commitRequest.getShardingKey(), + commitRequest.toByteString(), + StreamingCommitRequestChunk.CommitType.COMMIT_TYPE_SINGLE_KEY, + onDone); + add(idGenerator.incrementAndGet(), request); + return true; + } + + @Override + public boolean commitMultiKeyWorkItem( + String computation, + Windmill.MultiKeyWorkItemCommitRequest commitRequest, + Consumer onDone) { + if (!canAccept(commitRequest.getSerializedSize() + computation.length())) { + return false; + } + Preconditions.checkArgument(commitRequest.getRequestsCount() > 0); + PendingRequest request = + PendingRequest.create( + computation, + // Any key in the batch for routing + commitRequest.getRequests(0).getShardingKey(), + commitRequest.toByteString(), + StreamingCommitRequestChunk.CommitType.COMMIT_TYPE_MULTI_KEY, + onDone); add(idGenerator.incrementAndGet(), request); return true; } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java index 5be8ec0a6c72..eec77ccf435b 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java @@ -29,7 +29,6 @@ import java.util.ArrayList; 
import java.util.Collection; -import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; @@ -37,6 +36,7 @@ import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.CountDownLatch; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; @@ -89,6 +89,8 @@ public final class FakeWindmillServer extends WindmillServerStub { private final Map streamingCommitsToOffer; // Keys are work tokens. private final Map commitsReceived; + private final List multiKeyCommitsReceived = + new CopyOnWriteArrayList<>(); private final ArrayList statsReceived; private final LinkedBlockingQueue exceptions; private final AtomicInteger expectedExceptionCount; @@ -118,7 +120,7 @@ public FakeWindmillServer( commitsToOffer = new ResponseQueue() .returnByDefault(CommitWorkResponse.getDefaultInstance()); - streamingCommitsToOffer = new HashMap<>(); + streamingCommitsToOffer = new ConcurrentHashMap<>(); commitsReceived = new ConcurrentHashMap<>(); exceptions = new LinkedBlockingQueue<>(); expectedExceptionCount = new AtomicInteger(); @@ -400,6 +402,7 @@ public void shutdown() {} public RequestBatcher batcher() { return new RequestBatcher() { final List requests = new ArrayList<>(); + final List multiKeyRequests = new ArrayList<>(); @Override public boolean commitWorkItem( @@ -423,6 +426,18 @@ public boolean commitWorkItem( return true; } + @Override + public boolean commitMultiKeyWorkItem( + String computation, + Windmill.MultiKeyWorkItemCommitRequest request, + Consumer onDone) { + LOG.debug("commitWorkStream::commitMultiKeyWorkItem: {}", request); + if (multiKeyRequests.size() > 5) return false; + multiKeyRequests.add(new MultiKeyRequestAndDone(request, onDone)); + flush(); + return true; + } + @Override public void flush() { for (RequestAndDone elem : requests) { @@ -445,6 +460,37 @@ 
public void flush() { .orElse(Windmill.CommitStatus.OK)); } requests.clear(); + + for (MultiKeyRequestAndDone elem : multiKeyRequests) { + if (dropStreamingCommits) { + for (WorkItemCommitRequest workRequest : elem.request.getRequestsList()) { + droppedStreamingCommits.put(workRequest.getWorkToken(), elem.onDone); + } + continue; + } + + multiKeyCommitsReceived.add(elem.request); + for (WorkItemCommitRequest workRequest : elem.request.getRequestsList()) { + commitsReceived.put(workRequest.getWorkToken(), workRequest); + } + + // Determine status for the batch. + // Default to OK, but if any of the works in the batch has an offered status, use it. + Windmill.CommitStatus status = Windmill.CommitStatus.OK; + for (WorkItemCommitRequest workRequest : elem.request.getRequestsList()) { + Windmill.CommitStatus offeredStatus = + streamingCommitsToOffer.remove( + WorkId.builder() + .setWorkToken(workRequest.getWorkToken()) + .setCacheToken(workRequest.getCacheToken()) + .build()); + if (offeredStatus != null) { + status = offeredStatus; + } + } + elem.onDone.accept(status); + } + multiKeyRequests.clear(); } class RequestAndDone { @@ -456,6 +502,18 @@ class RequestAndDone { this.onDone = onDone; } } + + class MultiKeyRequestAndDone { + final Consumer onDone; + final Windmill.MultiKeyWorkItemCommitRequest request; + + MultiKeyRequestAndDone( + Windmill.MultiKeyWorkItemCommitRequest request, + Consumer onDone) { + this.request = request; + this.onDone = onDone; + } + } }; } @@ -518,6 +576,15 @@ public Map waitForAndGetCommits(int numCommits) { public void clearCommitsReceived() { commitsRequested = 0; commitsReceived.clear(); + multiKeyCommitsReceived.clear(); + } + + public List getMultiKeyCommitsReceived() { + return multiKeyCommitsReceived; + } + + public void clearMultiKeyCommitsReceived() { + multiKeyCommitsReceived.clear(); } public ConcurrentHashMap> waitForDroppedCommits( diff --git 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitterTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitterTest.java index 5c3132ae471d..3da740d53361 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitterTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitterTest.java @@ -129,9 +129,9 @@ public void testCommit() { for (Commit commit : commits) { Windmill.WorkItemCommitRequest request = - committed.get(commit.work().getWorkItem().getWorkToken()); + committed.get(commit.workBatch().get(0).getWorkItem().getWorkToken()); assertNotNull(request); - assertThat(request).isEqualTo(commit.request()); + assertThat(request).isEqualTo(commit.singleKeyRequest().get()); } assertThat(completeCommits).hasSize(commits.size()); @@ -141,12 +141,13 @@ public void testCommit() { (CompleteCommit completeCommit, Commit commit) -> completeCommit.computationId().equals(commit.computationId()) && completeCommit.status() == Windmill.CommitStatus.OK - && completeCommit.workId().equals(commit.work().id()) + && completeCommit.workId().equals(commit.workBatch().get(0).id()) && completeCommit .shardedKey() .equals( ShardedKey.create( - commit.request().getKey(), commit.request().getShardingKey())), + commit.singleKeyRequest().get().getKey(), + commit.singleKeyRequest().get().getShardingKey())), "expected to equal")) .containsExactlyElementsIn(commits); } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java 
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java index 01197622c24d..5e5fd9ce6420 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java @@ -53,6 +53,7 @@ import org.apache.beam.runners.dataflow.worker.streaming.WorkId; import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.CommitStatus; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItem; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItemCommitRequest; import org.apache.beam.runners.dataflow.worker.windmill.client.CloseableStream; @@ -62,6 +63,7 @@ import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender; import org.apache.beam.vendor.grpc.v1p69p0.com.google.protobuf.ByteString; import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.testing.GrpcCleanupRule; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.joda.time.Duration; import org.joda.time.Instant; @@ -134,12 +136,10 @@ private static ComputationState createComputationState(String computationId) { null); } - private static CompleteCommit asCompleteCommit(Commit commit, Windmill.CommitStatus status) { - if (commit.work().isFailed()) { - return CompleteCommit.forFailedWork(commit); - } - - return CompleteCommit.create(commit, status); + private static CompleteCommit asCompleteCommit( + String computationId, Work work, 
Windmill.CommitStatus status) { + Windmill.CommitStatus finalStatus = work.isFailed() ? Windmill.CommitStatus.ABORTED : status; + return CompleteCommit.create(computationId, work.getShardedKey(), work.id(), finalStatus); } @Before @@ -186,10 +186,14 @@ public void testCommit_sendsCommitsToStreamingEngine() { waitForExpectedSetSize(completeCommits, 5); for (Commit commit : commits) { - WorkItemCommitRequest request = committed.get(commit.work().getWorkItem().getWorkToken()); + WorkItemCommitRequest request = + committed.get(commit.workBatch().get(0).getWorkItem().getWorkToken()); assertNotNull(request); - assertThat(request).isEqualTo(commit.request()); - assertThat(completeCommits).contains(asCompleteCommit(commit, Windmill.CommitStatus.OK)); + assertThat(request).isEqualTo(commit.singleKeyRequest().get()); + assertThat(completeCommits) + .contains( + asCompleteCommit( + commit.computationId(), commit.workBatch().get(0), Windmill.CommitStatus.OK)); } workCommitter.stop(); @@ -224,14 +228,24 @@ public void testCommit_handlesFailedCommits() { waitForExpectedSetSize(completeCommits, 10); for (Commit commit : commits) { - if (commit.work().isFailed()) { + if (commit.isFailed()) { assertThat(completeCommits) - .contains(asCompleteCommit(commit, Windmill.CommitStatus.ABORTED)); - assertThat(committed).doesNotContainKey(commit.work().getWorkItem().getWorkToken()); + .contains( + asCompleteCommit( + commit.computationId(), + commit.workBatch().get(0), + Windmill.CommitStatus.ABORTED)); + assertThat(committed) + .doesNotContainKey(commit.workBatch().get(0).getWorkItem().getWorkToken()); } else { - assertThat(completeCommits).contains(asCompleteCommit(commit, Windmill.CommitStatus.OK)); + assertThat(completeCommits) + .contains( + asCompleteCommit( + commit.computationId(), commit.workBatch().get(0), Windmill.CommitStatus.OK)); assertThat(committed) - .containsEntry(commit.work().getWorkItem().getWorkToken(), commit.request()); + .containsEntry( + 
commit.workBatch().get(0).getWorkItem().getWorkToken(), + commit.singleKeyRequest().get()); } } @@ -282,11 +296,16 @@ public void testCommit_handlesCompleteCommits_commitStatusNotOK() { waitForExpectedSetSize(completeCommits, commits.size()); for (Commit commit : commits) { - WorkItemCommitRequest request = committed.get(commit.work().getWorkItem().getWorkToken()); + WorkItemCommitRequest request = + committed.get(commit.workBatch().get(0).getWorkItem().getWorkToken()); assertNotNull(request); - assertThat(request).isEqualTo(commit.request()); + assertThat(request).isEqualTo(commit.singleKeyRequest().get()); assertThat(completeCommits) - .contains(asCompleteCommit(commit, expectedCommitStatus.get(commit.work().id()))); + .contains( + asCompleteCommit( + commit.computationId(), + commit.workBatch().get(0), + expectedCommitStatus.get(commit.workBatch().get(0).id()))); } workCommitter.stop(); @@ -313,6 +332,14 @@ public boolean commitWorkItem( return false; } + @Override + public boolean commitMultiKeyWorkItem( + String computation, + Windmill.MultiKeyWorkItemCommitRequest request, + Consumer onDone) { + return false; + } + @Override public void flush() {} }; @@ -370,7 +397,7 @@ public void shutdown() {} } for (Commit commit : commits) { - assertTrue(commit.work().isFailed()); + assertTrue(commit.isFailed()); } } @@ -409,10 +436,14 @@ public void testMultipleCommitSendersSingleStream() { waitForExpectedSetSize(completeCommits, commits.size()); for (Commit commit : commits) { - WorkItemCommitRequest request = committed.get(commit.work().getWorkItem().getWorkToken()); + WorkItemCommitRequest request = + committed.get(commit.workBatch().get(0).getWorkItem().getWorkToken()); assertNotNull(request); - assertThat(request).isEqualTo(commit.request()); - assertThat(completeCommits).contains(asCompleteCommit(commit, Windmill.CommitStatus.OK)); + assertThat(request).isEqualTo(commit.singleKeyRequest().get()); + assertThat(completeCommits) + .contains( + asCompleteCommit( + 
commit.computationId(), commit.workBatch().get(0), Windmill.CommitStatus.OK)); } workCommitter.stop(); @@ -474,4 +505,201 @@ public void testStop_drainsCommitQueue_concurrentCommit() waitForExpectedSetSize(completeCommits, sentCommits.intValue()); } + + @Test + public void testCommit_multiKeyCommitFailedWork() { + Set completeCommits = Collections.newSetFromMap(new ConcurrentHashMap<>()); + workCommitter = createWorkCommitter(completeCommits::add); + + Work workA = createMockWork(101L); + Work workB = createMockWork(102L); + Work workC = createMockWork(103L); + + // Mark non-primary key B as failed + workB.setFailed(); + + Windmill.MultiKeyWorkItemCommitRequest multiKeyRequest = + Windmill.MultiKeyWorkItemCommitRequest.newBuilder() + .addRequests( + Windmill.WorkItemCommitRequest.newBuilder() + .setKey(workA.getWorkItem().getKey()) + .setShardingKey(workA.getWorkItem().getShardingKey()) + .setWorkToken(workA.getWorkItem().getWorkToken()) + .setCacheToken(workA.getWorkItem().getCacheToken()) + .build()) + .addRequests( + Windmill.WorkItemCommitRequest.newBuilder() + .setKey(workB.getWorkItem().getKey()) + .setShardingKey(workB.getWorkItem().getShardingKey()) + .setWorkToken(workB.getWorkItem().getWorkToken()) + .setCacheToken(workB.getWorkItem().getCacheToken()) + .build()) + .addRequests( + Windmill.WorkItemCommitRequest.newBuilder() + .setKey(workC.getWorkItem().getKey()) + .setShardingKey(workC.getWorkItem().getShardingKey()) + .setWorkToken(workC.getWorkItem().getWorkToken()) + .setCacheToken(workC.getWorkItem().getCacheToken()) + .build()) + .build(); + + Commit commit = + Commit.createMultiKey( + multiKeyRequest, + createComputationState("computationId"), + ImmutableList.of(workA, workB, workC)); + + workCommitter.start(); + workCommitter.commit(commit); + + // The entire batch must be aborted immediately without making network calls + waitForExpectedSetSize(completeCommits, 3); + + // Verify all three works are aborted individually + 
assertThat(completeCommits) + .containsExactly( + CompleteCommit.create( + "computationId", workA.getShardedKey(), workA.id(), CommitStatus.ABORTED), + CompleteCommit.create( + "computationId", workB.getShardedKey(), workB.id(), CommitStatus.ABORTED), + CompleteCommit.create( + "computationId", workC.getShardedKey(), workC.id(), CommitStatus.ABORTED)); + + workCommitter.stop(); + } + + @Test + public void testCommit_multiKeyCommitSuccess() { + Set completeCommits = Collections.newSetFromMap(new ConcurrentHashMap<>()); + workCommitter = createWorkCommitter(completeCommits::add); + + Work workA = createMockWork(101L); + Work workB = createMockWork(102L); + Work workC = createMockWork(103L); + + Windmill.MultiKeyWorkItemCommitRequest multiKeyRequest = + Windmill.MultiKeyWorkItemCommitRequest.newBuilder() + .addRequests( + Windmill.WorkItemCommitRequest.newBuilder() + .setKey(workA.getWorkItem().getKey()) + .setShardingKey(workA.getWorkItem().getShardingKey()) + .setWorkToken(workA.getWorkItem().getWorkToken()) + .setCacheToken(workA.getWorkItem().getCacheToken()) + .build()) + .addRequests( + Windmill.WorkItemCommitRequest.newBuilder() + .setKey(workB.getWorkItem().getKey()) + .setShardingKey(workB.getWorkItem().getShardingKey()) + .setWorkToken(workB.getWorkItem().getWorkToken()) + .setCacheToken(workB.getWorkItem().getCacheToken()) + .build()) + .addRequests( + Windmill.WorkItemCommitRequest.newBuilder() + .setKey(workC.getWorkItem().getKey()) + .setShardingKey(workC.getWorkItem().getShardingKey()) + .setWorkToken(workC.getWorkItem().getWorkToken()) + .setCacheToken(workC.getWorkItem().getCacheToken()) + .build()) + .build(); + + Commit commit = + Commit.createMultiKey( + multiKeyRequest, + createComputationState("computationId"), + ImmutableList.of(workA, workB, workC)); + + workCommitter.start(); + workCommitter.commit(commit); + + // Wait for the server to receive and process the commits + fakeWindmillServer.waitForAndGetCommits(3); + 
waitForExpectedSetSize(completeCommits, 3); + + // Verify that FakeWindmillServer received all 3 work requests in multiKeyCommitsReceived + List multiKeyCommits = + fakeWindmillServer.getMultiKeyCommitsReceived(); + assertThat(multiKeyCommits).hasSize(1); + assertThat(multiKeyCommits.get(0)).isEqualTo(multiKeyRequest); + + // Verify all three works are completed successfully + assertThat(completeCommits) + .containsExactly( + CompleteCommit.create( + "computationId", workA.getShardedKey(), workA.id(), CommitStatus.OK), + CompleteCommit.create( + "computationId", workB.getShardedKey(), workB.id(), CommitStatus.OK), + CompleteCommit.create( + "computationId", workC.getShardedKey(), workC.id(), CommitStatus.OK)); + + workCommitter.stop(); + } + + @Test + public void testCommit_multiKeyCommitStatusNotOK() { + Set completeCommits = Collections.newSetFromMap(new ConcurrentHashMap<>()); + workCommitter = createWorkCommitter(completeCommits::add); + + Work workA = createMockWork(101L); + Work workB = createMockWork(102L); + Work workC = createMockWork(103L); + + Windmill.MultiKeyWorkItemCommitRequest multiKeyRequest = + Windmill.MultiKeyWorkItemCommitRequest.newBuilder() + .addRequests( + Windmill.WorkItemCommitRequest.newBuilder() + .setKey(workA.getWorkItem().getKey()) + .setShardingKey(workA.getWorkItem().getShardingKey()) + .setWorkToken(workA.getWorkItem().getWorkToken()) + .setCacheToken(workA.getWorkItem().getCacheToken()) + .build()) + .addRequests( + Windmill.WorkItemCommitRequest.newBuilder() + .setKey(workB.getWorkItem().getKey()) + .setShardingKey(workB.getWorkItem().getShardingKey()) + .setWorkToken(workB.getWorkItem().getWorkToken()) + .setCacheToken(workB.getWorkItem().getCacheToken()) + .build()) + .addRequests( + Windmill.WorkItemCommitRequest.newBuilder() + .setKey(workC.getWorkItem().getKey()) + .setShardingKey(workC.getWorkItem().getShardingKey()) + .setWorkToken(workC.getWorkItem().getWorkToken()) + .setCacheToken(workC.getWorkItem().getCacheToken()) + 
.build()) + .build(); + + Commit commit = + Commit.createMultiKey( + multiKeyRequest, + createComputationState("computationId"), + ImmutableList.of(workA, workB, workC)); + + // Offer NOT_FOUND status for one of the works. + fakeWindmillServer.whenCommitWorkStreamCalled().put(workB.id(), CommitStatus.NOT_FOUND); + + workCommitter.start(); + workCommitter.commit(commit); + + // Wait for the server to receive and process the commits + fakeWindmillServer.waitForAndGetCommits(3); + waitForExpectedSetSize(completeCommits, 3); + + // Verify that FakeWindmillServer received the multi-key commit + List multiKeyCommits = + fakeWindmillServer.getMultiKeyCommitsReceived(); + assertThat(multiKeyCommits).hasSize(1); + assertThat(multiKeyCommits.get(0)).isEqualTo(multiKeyRequest); + + // Verify all three works in the multi-key commit are completed with NOT_FOUND status + assertThat(completeCommits) + .containsExactly( + CompleteCommit.create( + "computationId", workA.getShardedKey(), workA.id(), CommitStatus.NOT_FOUND), + CompleteCommit.create( + "computationId", workB.getShardedKey(), workB.id(), CommitStatus.NOT_FOUND), + CompleteCommit.create( + "computationId", workC.getShardedKey(), workC.id(), CommitStatus.NOT_FOUND)); + + workCommitter.stop(); + } } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStreamTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStreamTest.java index e9fd55fa5668..b83890c1dbdd 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStreamTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStreamTest.java @@ -1133,6 +1133,76 @@ public void 
testCommitWorkItem_multiplePhysicalStreams_multipleHandovers_halfClo assertTrue(commitWorkStream.awaitTermination(10, TimeUnit.SECONDS)); } + @Test + public void testCommit_multiKeyCommit() throws Exception { + GrpcCommitWorkStream commitWorkStream = createCommitWorkStream(); + FakeWindmillGrpcService.CommitStreamInfo streamInfo = waitForConnectionAndConsumeHeader(); + + CompletableFuture commitStatusFuture = new CompletableFuture<>(); + + // 1. Construct two individual WorkItemCommitRequests + long shardingKey1 = 101L; + long workToken1 = 201L; + long cacheToken1 = 301L; + long shardingKey2 = 102L; + long workToken2 = 202L; + long cacheToken2 = 302L; + Windmill.WorkItemCommitRequest request1 = + Windmill.WorkItemCommitRequest.newBuilder() + .setKey(ByteString.copyFromUtf8("key1")) + .setShardingKey(shardingKey1) + .setWorkToken(workToken1) + .setCacheToken(cacheToken1) + .build(); + Windmill.WorkItemCommitRequest request2 = + Windmill.WorkItemCommitRequest.newBuilder() + .setKey(ByteString.copyFromUtf8("key2")) + .setShardingKey(shardingKey2) + .setWorkToken(workToken2) + .setCacheToken(cacheToken2) + .build(); + + // 2. Wrap them into a MultiKeyWorkItemCommitRequest + Windmill.MultiKeyWorkItemCommitRequest multiKeyRequest = + Windmill.MultiKeyWorkItemCommitRequest.newBuilder() + .addRequests(request1) + .addRequests(request2) + .build(); + + // 3. Commit the multi-key work item using the request batcher + try (WindmillStream.CommitWorkStream.RequestBatcher batcher = commitWorkStream.batcher()) { + assertTrue( + batcher.commitMultiKeyWorkItem( + COMPUTATION_ID, multiKeyRequest, commitStatusFuture::complete)); + } + + // 4. 
Receive and assert request properties on FakeWindmillGrpcService + Windmill.StreamingCommitWorkRequest request = streamInfo.requests.take(); + assertThat(request.getCommitChunkCount()).isEqualTo(1); + + Windmill.StreamingCommitRequestChunk chunk = request.getCommitChunk(0); + + // Assert that the commit type is correctly identified as COMMIT_TYPE_MULTI_KEY + assertThat(chunk.getCommitType()) + .isEqualTo(Windmill.StreamingCommitRequestChunk.CommitType.COMMIT_TYPE_MULTI_KEY); + + // Assert that the routing sharding key is mapped to the first request's sharding key + assertThat(chunk.getShardingKey()).isEqualTo(request1.getShardingKey()); + + // Assert that the serialized payload matches the input multiKeyRequest + Windmill.MultiKeyWorkItemCommitRequest parsedRequest = + Windmill.MultiKeyWorkItemCommitRequest.parseFrom(chunk.getSerializedWorkItemCommit()); + assertThat(parsedRequest).isEqualTo(multiKeyRequest); + + // 5. Respond with the generated requestId to complete the commit + long requestId = chunk.getRequestId(); + streamInfo.responseObserver.onNext( + Windmill.StreamingCommitResponse.newBuilder().addRequestId(requestId).build()); + + // 6. 
Verify callback completed successfully with CommitStatus.OK + assertThat(commitStatusFuture.get()).isEqualTo(Windmill.CommitStatus.OK); + } + private FakeWindmillGrpcService.CommitStreamInfo waitForConnectionAndConsumeHeader() { try { FakeWindmillGrpcService.CommitStreamInfo info = fakeService.waitForConnectedCommitStream(); diff --git a/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto b/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto index 1da7ef9be8bb..9abe23f58c89 100644 --- a/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto +++ b/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto @@ -421,6 +421,11 @@ message WatermarkHold { optional string state_family = 4; } +message Uint128Proto { + required fixed64 high = 1; + required fixed64 low = 2; +} + // Proto describing a hot key detected on a given WorkItem. message HotKeyInfo { // The age of the hot key measured from when it was first detected. @@ -671,9 +676,24 @@ message WorkItemCommitRequest { reserved 6, 23; } +message MultiKeyWorkItemCommitRequest { + optional Uint128Proto key_group = 7; + + repeated WorkItemCommitRequest requests = 1; + + repeated OutputMessageBundle output_messages = 2; + + repeated PubSubMessageBundle pubsub_messages = 3; + + repeated int64 finalize_ids = 4 [packed = true]; + + reserved 6; +} + message ComputationCommitWorkRequest { required string computation_id = 1; repeated WorkItemCommitRequest requests = 2; + repeated MultiKeyWorkItemCommitRequest multi_key_requests = 3; } message CommitWorkRequest { @@ -899,6 +919,14 @@ message StreamingCommitRequestChunk { // before handing off to the WindmillHost for processing. 
optional int64 remaining_bytes_for_work_item = 4; optional bytes serialized_work_item_commit = 5; + + enum CommitType { + COMMIT_TYPE_UNSPECIFIED = 0; + COMMIT_TYPE_SINGLE_KEY = 1; + COMMIT_TYPE_MULTI_KEY = 2; + } + + optional CommitType commit_type = 7; } message StreamingCommitResponse { From 73faa68ba985efbd8906d4069b34cb566b9439a0 Mon Sep 17 00:00:00 2001 From: Arun Pandian Date: Thu, 4 Jun 2026 11:00:30 +0000 Subject: [PATCH 02/21] [Dataflow Streaming] [Multi Key] StreamingModeExecutionContext refactoring for multi-key execution. --- .../worker/StreamingModeExecutionContext.java | 271 +++++++++++++--- .../worker/WindmillReaderIteratorBase.java | 19 +- .../worker/WindowingWindmillReader.java | 86 ++--- .../streaming/ComputationWorkExecutor.java | 36 +-- .../dataflow/worker/streaming/Work.java | 23 +- .../ComputationWorkExecutorFactory.java | 28 +- .../processing/StreamingWorkScheduler.java | 295 ++++++++++-------- .../worker/StreamingDataflowWorkerTest.java | 21 +- .../StreamingModeExecutionContextTest.java | 140 ++++++--- .../WindmillReaderIteratorBaseTest.java | 94 ++++++ .../worker/WorkerCustomSourcesTest.java | 42 ++- 11 files changed, 759 insertions(+), 296 deletions(-) diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java index 25ce299adf7a..a669fb7ff361 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java @@ -25,6 +25,7 @@ import com.google.api.services.dataflow.model.SideInputInfo; import java.io.Closeable; import java.io.IOException; +import java.util.ArrayList; import java.util.Collection; import 
java.util.HashMap; import java.util.Iterator; @@ -46,10 +47,10 @@ import org.apache.beam.runners.core.metrics.ExecutionStateTracker; import org.apache.beam.runners.core.metrics.ExecutionStateTracker.ExecutionState; import org.apache.beam.runners.dataflow.worker.DataflowOperationContext.DataflowExecutionState; -import org.apache.beam.runners.dataflow.worker.StreamingModeExecutionContext.StepContext; import org.apache.beam.runners.dataflow.worker.counters.CounterFactory; import org.apache.beam.runners.dataflow.worker.counters.NameContext; import org.apache.beam.runners.dataflow.worker.profiler.ScopedProfiler.ProfileScope; +import org.apache.beam.runners.dataflow.worker.streaming.BoundedQueueExecutorWorkHandle; import org.apache.beam.runners.dataflow.worker.streaming.Watermarks; import org.apache.beam.runners.dataflow.worker.streaming.Work; import org.apache.beam.runners.dataflow.worker.streaming.config.StreamingGlobalConfig; @@ -57,6 +58,9 @@ import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInput; import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputState; import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputStateFetcher; +import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor; +import org.apache.beam.runners.dataflow.worker.util.common.worker.ElementCounter; +import org.apache.beam.runners.dataflow.worker.util.common.worker.OutputObjectAndByteCounter; import org.apache.beam.runners.dataflow.worker.util.common.worker.WorkExecutor; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalDataId; @@ -112,7 +116,8 @@ // TODO(m-trieu) fix nullability issues in StreamingModeExecutionContext.java @NotThreadSafe @Internal -public class StreamingModeExecutionContext extends DataflowExecutionContext { +public class StreamingModeExecutionContext + extends DataflowExecutionContext { private static final Logger LOG 
= LoggerFactory.getLogger(StreamingModeExecutionContext.class); @@ -162,6 +167,33 @@ public class StreamingModeExecutionContext extends DataflowExecutionContext keyCoder; + + // Key switch listener to delegate MDC logging context and thread name updates + public interface KeySwitchListener { + void onKeySwitch(Work oldWork, Work newWork); + } + + @SuppressWarnings("UnusedVariable") + private @Nullable KeySwitchListener keySwitchListener; + + private List executedWorks = new ArrayList<>(); + private List outputBuilders = new ArrayList<>(); + private Map> accumulatedCallbacks = new HashMap<>(); + private volatile boolean workIsFailed = false; + private @Nullable WindmillStateReader activeStateReader; + private long stateBytesRead = 0; + private final String sourceBytesProcessCounterName; + public StreamingModeExecutionContext( CounterFactory counterFactory, String computationId, @@ -173,7 +205,11 @@ public StreamingModeExecutionContext( StreamingModeExecutionStateRegistry executionStateRegistry, StreamingGlobalConfigHandle globalConfigHandle, long sinkByteLimit, - boolean throwExceptionOnLargeOutput) { + boolean throwExceptionOnLargeOutput, + HotKeyLogger hotKeyLogger, + boolean hotKeyLoggingEnabled, + String stepName, + String sourceBytesProcessCounterName) { super( counterFactory, metricsContainerRegistry, @@ -188,6 +224,10 @@ public StreamingModeExecutionContext( this.stateCache = stateCache; this.backlogBytes = UnboundedReader.BACKLOG_UNKNOWN; this.throwExceptionOnLargeOutput = throwExceptionOnLargeOutput; + this.hotKeyLogger = checkNotNull(hotKeyLogger); + this.hotKeyLoggingEnabled = hotKeyLoggingEnabled; + this.stepName = checkNotNull(stepName); + this.sourceBytesProcessCounterName = checkNotNull(sourceBytesProcessCounterName); } @VisibleForTesting @@ -208,7 +248,7 @@ public boolean throwExceptionsForLargeOutput() { } public boolean workIsFailed() { - return work != null && work.isFailed(); + return workIsFailed; } public boolean getDrainMode() { @@ -240,19 
+280,44 @@ public byte[] getCurrentRecordOffset() { return activeReader.getCurrentRecordOffset(); } + public void clear() { + for (Work w : executedWorks) { + w.setOnFailureListener(null); + } + this.executedWorks = new ArrayList<>(); + this.outputBuilders = new ArrayList<>(); + this.accumulatedCallbacks = new HashMap<>(); + this.workIsFailed = false; + this.sideInputCache.clear(); + this.activeStateReader = null; + this.activeReader = null; + this.keyCoder = null; + this.workExecutor = null; + this.workQueueExecutor = null; + this.budgetHandle = null; + this.keySwitchListener = null; + } + public void start( - @Nullable Object key, Work work, WindmillStateReader stateReader, SideInputStateFetcher sideInputStateFetcher, - Windmill.WorkItemCommitRequest.Builder outputBuilder, - WorkExecutor workExecutor) { - this.key = key; - this.work = work; + WorkExecutor workExecutor, + BoundedQueueExecutor workQueueExecutor, + BoundedQueueExecutorWorkHandle budgetHandle, + @Nullable Coder keyCoder, + KeySwitchListener keySwitchListener) { + clear(); + this.keyCoder = keyCoder; this.workExecutor = workExecutor; - this.finishKeyCalled = false; - this.computationKey = WindmillComputationKey.create(computationId, work.getShardedKey()); - this.sideInputStateFetcher = sideInputStateFetcher; + this.workQueueExecutor = workQueueExecutor; + this.budgetHandle = budgetHandle; + this.keySwitchListener = keySwitchListener; + + this.backlogBytes = UnboundedReader.BACKLOG_UNKNOWN; + clearSinkFullHint(); + this.stateBytesRead = 0; + StreamingGlobalConfig config = globalConfigHandle.getConfig(); // Snapshot the limits for entire bundle processing. this.operationalLimits = config.operationalLimits(); @@ -260,27 +325,66 @@ public void start( config.enableStateTagEncodingV2() ? 
WindmillTagEncodingV2.instance() : WindmillTagEncodingV1.instance(); - this.outputBuilder = outputBuilder; - this.sideInputCache.clear(); - this.backlogBytes = UnboundedReader.BACKLOG_UNKNOWN; - clearSinkFullHint(); + this.sideInputStateFetcher = sideInputStateFetcher; - Instant processingTime = computeProcessingTime(work.getWorkItem().getTimers().getTimersList()); + startForNewKey(work, stateReader); + } - Collection stepContexts = getAllStepContexts(); - if (!stepContexts.isEmpty()) { - // This must be only created once for the workItem as token validation will fail if the same - // work token is reused. - WindmillStateCache.ForKey cacheForKey = - stateCache.forKey(getComputationKey(), getWorkItem().getCacheToken(), getWorkToken()); - for (StepContext stepContext : stepContexts) { - stepContext.start(stateReader, processingTime, cacheForKey, work.watermarks()); + private @Nullable Object decodeKey(Work work) { + // If the read output KVs, then we can decode Windmill's byte key into userland + // key object and provide it to the execution context for use with per-key state. + // Otherwise, we pass null. 
+ // + // The coder type that will be present is: + // WindowedValueCoder(TimerOrElementCoder(KvCoder)) + if (keyCoder != null) { + try { + return keyCoder.decode(work.getWorkItem().getKey().newInput(), Coder.Context.OUTER); + } catch (IOException e) { + throw new RuntimeException("Failed to decode key during processing", e); } } + return null; + } + + private Windmill.WorkItemCommitRequest.Builder createOutputBuilder(Work work) { + return Windmill.WorkItemCommitRequest.newBuilder() + .setKey(work.getWorkItem().getKey()) + .setShardingKey(work.getWorkItem().getShardingKey()) + .setWorkToken(work.getWorkItem().getWorkToken()) + .setCacheToken(work.getWorkItem().getCacheToken()); + } + + private void logHotKeyIfDetected(Work work, @Nullable Object decodedKey) { + if (work.getWorkItem().hasHotKeyInfo()) { + Windmill.HotKeyInfo hotKeyInfo = work.getWorkItem().getHotKeyInfo(); + Duration hotKeyAge = Duration.millis(hotKeyInfo.getHotKeyAgeUsec() / 1000); + if (decodedKey != null && hotKeyLoggingEnabled) { + hotKeyLogger.logHotKeyDetection(stepName, hotKeyAge, decodedKey); + } else { + hotKeyLogger.logHotKeyDetection(stepName, hotKeyAge); + } + } + } + + private void startStepContexts( + WindmillStateReader stateReader, + Instant processingTime, + WindmillStateCache.ForKey cacheForKey, + Watermarks watermarks) { + Collection stepContexts = getAllStepContexts(); + for (StepContext stepContext : stepContexts) { + stepContext.start(stateReader, processingTime, cacheForKey, watermarks); + } } public void finishKey() { - checkState(!finishKeyCalled, "finishKey was already called"); + if (finishKeyCalled) { + return; + } + if (activeStateReader != null) { + this.stateBytesRead += activeStateReader.getBytesRead(); + } checkStateNotNull(workExecutor, "workExecutor must be set before calling finishKey()"); try { workExecutor.finishKey(); @@ -288,6 +392,8 @@ public void finishKey() { throw new RuntimeException(e); } this.finishKeyCalled = true; + + flushStateInternal(); } /** @@ 
-441,20 +547,22 @@ public void setActiveReader(UnboundedReader reader) { /** Invalidate the state and reader caches for this computation and key. */ public void invalidateCache() { - ByteString key = getSerializedKey(); - if (key != null) { - readerCache.invalidateReader(getComputationKey()); - if (activeReader != null) { - try { - activeReader.close(); - } catch (IOException e) { - LOG.warn( - "Failed to close reader for {}-{}", computationId, getWorkItem().getShardingKey(), e); - } + for (Work w : executedWorks) { + WindmillComputationKey compKey = + WindmillComputationKey.create(computationId, w.getShardedKey()); + readerCache.invalidateReader(compKey); + stateCache.invalidate(w.getShardedKey()); + } + if (activeReader != null) { + try { + activeReader.close(); + } catch (IOException e) { + LOG.warn( + "Failed to close reader for {}-{}", computationId, getWorkItem().getShardingKey(), e); } - activeReader = null; - stateCache.invalidate(key, getWorkItem().getShardingKey()); } + activeReader = null; + activeStateReader = null; } public UnboundedSource.@Nullable CheckpointMark getReaderCheckpoint( @@ -470,8 +578,7 @@ public void invalidateCache() { } } - public Map> flushState() { - checkState(finishKeyCalled, "finishKey must be called before flushState"); + private void flushStateInternal() { Map> callbacks = new HashMap<>(); for (StepContext stepContext : getAllStepContexts()) { @@ -555,7 +662,89 @@ public Map> flushState() { // RestrictionTracker.getProgress() or GetSize() are not defined. 
outputBuilder.setSourceBacklogBytes(backlogBytes); } - return callbacks; + + this.accumulatedCallbacks.putAll(callbacks); + + outputBuilder.setSourceBytesProcessed( + computeSourceBytesProcessed(sourceBytesProcessCounterName)); + } + + private final long computeSourceBytesProcessed(String sourceBytesCounterName) { + if (!(workExecutor instanceof DataflowMapTaskExecutor)) { + return 0L; + } + HashMap counters = + ((DataflowMapTaskExecutor) workExecutor) + .getReadOperation() + .receivers[0] + .getOutputCounters(); + + return Optional.ofNullable(counters.get(sourceBytesCounterName)) + .map(counter -> ((OutputObjectAndByteCounter) counter).getByteCount().getAndReset()) + .orElse(0L); + } + + public Map> flushState() { + return accumulatedCallbacks; + } + + public boolean advance() { + return false; + } + + private void startForNewKey(Work newWork, WindmillStateReader reader) { + this.key = decodeKey(newWork); + this.work = newWork; + this.finishKeyCalled = false; + this.computationKey = WindmillComputationKey.create(computationId, newWork.getShardedKey()); + + this.outputBuilder = createOutputBuilder(newWork); + this.outputBuilders.add(this.outputBuilder); + newWork.setOnFailureListener(() -> this.workIsFailed = true); + this.executedWorks.add(newWork); + + logHotKeyIfDetected(newWork, this.key); + + // Note: We do NOT clear sideInputCache here, allowing Key B to reuse warm side inputs! + + // Re-initialize state cache and state/timer internals across all step contexts + Instant processingTime = + computeProcessingTime(newWork.getWorkItem().getTimers().getTimersList()); + if (!getAllStepContexts().isEmpty()) { + // This must be only created once for a workItem as token validation will fail if the same + // work token is reused. 
+ WindmillStateCache.ForKey cacheForKey = + stateCache.forKey( + getComputationKey(), newWork.getWorkItem().getCacheToken(), getWorkToken()); + this.activeStateReader = reader; + startStepContexts(reader, processingTime, cacheForKey, newWork.watermarks()); + } else { + this.activeStateReader = null; + } + } + + public List getExecutedWorks() { + return executedWorks; + } + + public long getStateBytesRead() { + return stateBytesRead; + } + + public List getOutputBuilders() { + return outputBuilders; + } + + public Map> getAccumulatedCallbacks() { + return accumulatedCallbacks; + } + + public @Nullable Object getKey() { + return key; + } + + public Work getWork() { + return work; } String getStateFamily(NameContext nameContext) { diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillReaderIteratorBase.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillReaderIteratorBase.java index 075a1a8a4250..b142cc38d365 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillReaderIteratorBase.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillReaderIteratorBase.java @@ -35,7 +35,7 @@ public abstract class WindmillReaderIteratorBase extends NativeReader.NativeReaderIterator> { private final StreamingModeExecutionContext context; - private final Windmill.WorkItem work; + private Windmill.WorkItem work; private int bundleIndex = 0; private int messageIndex = -1; private @Nullable WindowedValue current = null; @@ -57,15 +57,27 @@ public boolean start() throws IOException { @Override public boolean advance() throws IOException { if (context.workIsFailed()) { - throw new WorkItemCancelledException(context.getWorkItem().getShardingKey()); + throw new WorkItemCancelledException(checkNotNull(context.getWorkItem()).getShardingKey()); } while 
(true) { if (bundleIndex >= work.getMessageBundlesCount()) { - current = null; + // If elements are exhausted, try advancing the execution context to the next key in the + // group context.finishKey(); + if (context.advance()) { + // Transition succeeded! Update iterator references to the new work item + this.work = context.getWork().getWorkItem(); + this.bundleIndex = 0; + this.messageIndex = -1; + continue; + } + + // All work items are exhausted. Iterator returns false. + current = null; return false; } + Windmill.InputMessageBundle bundle = work.getMessageBundles(bundleIndex); ++messageIndex; if (messageIndex >= bundle.getMessagesCount()) { @@ -73,6 +85,7 @@ public boolean advance() throws IOException { ++bundleIndex; continue; } + try { current = checkNotNull(decodeMessage(bundle.getMessages(messageIndex))); return true; diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindowingWindmillReader.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindowingWindmillReader.java index 488684769bd9..916920518f0b 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindowingWindmillReader.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindowingWindmillReader.java @@ -151,51 +151,65 @@ public NativeReaderIterator>> iterator() throw && Iterables.isEmpty(keyedWorkItem.elementsIterable())); final WindowedValue> value = new ValueInEmptyWindows<>(keyedWorkItem); - // Return a noop iterator when current workitem is an empty workitem. 
- if (isEmptyWorkItem) { - return new NativeReaderIterator>>() { - @Override - public boolean start() throws IOException { - context.finishKey(); - return false; + return new NativeReaderIterator>>() { + private @Nullable WindowedValue> current = null; + private boolean started = false; + + @Override + public boolean start() throws IOException { + if (context.workIsFailed()) { + throw new WorkItemCancelledException( + checkStateNotNull(context.getWorkItem()).getShardingKey()); } - - @Override - public boolean advance() throws IOException { + if (started) { return false; } - - @Override - public WindowedValue> getCurrent() { - throw new NoSuchElementException(); + started = true; + if (isEmptyWorkItem) { + return advance(); // Try to transition immediately if the first key is empty! } - }; - } else { - return new NativeReaderIterator>>() { - private @Nullable WindowedValue> current = null; - - @Override - public boolean start() throws IOException { - current = value; - return true; + current = value; + return true; + } + + @Override + public boolean advance() throws IOException { + if (context.workIsFailed()) { + throw new WorkItemCancelledException( + checkStateNotNull(context.getWorkItem()).getShardingKey()); } - @Override - public boolean advance() throws IOException { - current = null; - context.finishKey(); - return false; + context.finishKey(); + if (context.advance()) { + @SuppressWarnings("unchecked") + K newKey = (K) context.getKey(); + KeyedWorkItem newKeyedWorkItem = + new WindmillKeyedWorkItem<>( + newKey, + context.getWork().getWorkItem(), + windowCoder, + windowsCoder, + valueCoder, + context.getWindmillTagEncoding(), + context.getDrainMode(), + skipUndecodableElements.isAccessible() + && Boolean.TRUE.equals(skipUndecodableElements.get())); + current = new ValueInEmptyWindows<>(newKeyedWorkItem); + return true; } - @Override - public WindowedValue> getCurrent() { - if (current == null) { - throw new NoSuchElementException(); - } - return value; + 
current = null; + return false; + } + + @Override + public WindowedValue> getCurrent() { + if (current == null) { + throw new NoSuchElementException(); } - }; - } + return current; + } + }; } @Override diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationWorkExecutor.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationWorkExecutor.java index b4f3a22a7f52..ed86d58b9bb0 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationWorkExecutor.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationWorkExecutor.java @@ -18,21 +18,16 @@ package org.apache.beam.runners.dataflow.worker.streaming; import com.google.auto.value.AutoValue; -import java.util.HashMap; import java.util.Optional; import javax.annotation.concurrent.NotThreadSafe; import org.apache.beam.runners.core.metrics.ExecutionStateTracker; -import org.apache.beam.runners.dataflow.worker.DataflowMapTaskExecutor; import org.apache.beam.runners.dataflow.worker.DataflowWorkExecutor; import org.apache.beam.runners.dataflow.worker.StreamingModeExecutionContext; import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputStateFetcher; -import org.apache.beam.runners.dataflow.worker.util.common.worker.ElementCounter; -import org.apache.beam.runners.dataflow.worker.util.common.worker.OutputObjectAndByteCounter; -import org.apache.beam.runners.dataflow.worker.windmill.Windmill; +import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor; import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateReader; import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.sdk.coders.Coder; -import org.checkerframework.checker.nullness.qual.Nullable; import org.slf4j.Logger; import 
org.slf4j.LoggerFactory; @@ -68,13 +63,23 @@ public static ComputationWorkExecutor.Builder builder() { * Executes DoFns for the Work. Blocks the calling thread until DoFn(s) have completed execution. */ public final void executeWork( - @Nullable Object key, Work work, WindmillStateReader stateReader, SideInputStateFetcher sideInputStateFetcher, - Windmill.WorkItemCommitRequest.Builder outputBuilder) + BoundedQueueExecutor workQueueExecutor, + BoundedQueueExecutorWorkHandle budgetHandle, + StreamingModeExecutionContext.KeySwitchListener keySwitchListener) throws Exception { - context().start(key, work, stateReader, sideInputStateFetcher, outputBuilder, workExecutor()); + context() + .start( + work, + stateReader, + sideInputStateFetcher, + workExecutor(), + workQueueExecutor, + budgetHandle, + keyCoder().orElse(null), + keySwitchListener); workExecutor().execute(); } @@ -84,6 +89,7 @@ public final void executeWork( */ public final void invalidate() { context().invalidateCache(); + context().clear(); try { workExecutor().close(); } catch (Exception e) { @@ -91,18 +97,6 @@ public final void invalidate() { } } - public final long computeSourceBytesProcessed(String sourceBytesCounterName) { - HashMap counters = - ((DataflowMapTaskExecutor) workExecutor()) - .getReadOperation() - .receivers[0] - .getOutputCounters(); - - return Optional.ofNullable(counters.get(sourceBytesCounterName)) - .map(counter -> ((OutputObjectAndByteCounter) counter).getByteCount().getAndReset()) - .orElse(0L); - } - @AutoValue.Builder public abstract static class Builder { public abstract Builder setWorkExecutor(DataflowWorkExecutor workExecutor); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java index cb01e1e508ce..668657228dfd 100644 --- 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java @@ -52,6 +52,7 @@ import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; +import org.checkerframework.checker.nullness.qual.Nullable; import org.joda.time.Duration; import org.joda.time.Instant; @@ -79,6 +80,7 @@ public final class Work implements RefreshableWork { private volatile TimedState currentState; private volatile boolean isFailed; private volatile String processingThreadName = ""; + private volatile @Nullable Runnable onFailureListener = null; private final boolean drainMode; private Work( @@ -184,6 +186,10 @@ public long getSerializedWorkItemSize() { return serializedWorkItemSize; } + public String getComputationId() { + return processingContext.computationId(); + } + @Override public ShardedKey getShardedKey() { return shardedKey; @@ -235,8 +241,19 @@ public void setProcessingThreadName(String processingThreadName) { } @Override - public void setFailed() { + public synchronized void setFailed() { this.isFailed = true; + Runnable listener = onFailureListener; + if (listener != null) { + listener.run(); + } + } + + public synchronized void setOnFailureListener(@Nullable Runnable listener) { + this.onFailureListener = listener; + if (isFailed && listener != null) { + listener.run(); + } } public boolean isCommitPending() { @@ -261,6 +278,10 @@ public void queueCommit(WorkItemCommitRequest commitRequest, ComputationState co processingContext.workCommitter().accept(Commit.create(commitRequest, computationState, this)); } + public Consumer workCommitter() { + return processingContext.workCommitter(); + } + public WindmillStateReader 
createWindmillStateReader() { return WindmillStateReader.forWork(this); } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/ComputationWorkExecutorFactory.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/ComputationWorkExecutorFactory.java index 269799903300..fcc6d6bbb743 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/ComputationWorkExecutorFactory.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/ComputationWorkExecutorFactory.java @@ -29,6 +29,7 @@ import org.apache.beam.runners.dataflow.worker.DataflowExecutionStateSampler; import org.apache.beam.runners.dataflow.worker.DataflowMapTaskExecutor; import org.apache.beam.runners.dataflow.worker.DataflowMapTaskExecutorFactory; +import org.apache.beam.runners.dataflow.worker.HotKeyLogger; import org.apache.beam.runners.dataflow.worker.IntrinsicMapTaskExecutorFactory; import org.apache.beam.runners.dataflow.worker.ReaderCache; import org.apache.beam.runners.dataflow.worker.ReaderRegistry; @@ -97,6 +98,7 @@ final class ComputationWorkExecutorFactory { private final IdGenerator idGenerator; private final StreamingGlobalConfigHandle globalConfigHandle; private final boolean throwExceptionOnLargeOutput; + private final HotKeyLogger hotKeyLogger; ComputationWorkExecutorFactory( DataflowWorkerHarnessOptions options, @@ -106,7 +108,8 @@ final class ComputationWorkExecutorFactory { DataflowExecutionStateSampler sampler, CounterSet pendingDeltaCounters, IdGenerator idGenerator, - StreamingGlobalConfigHandle globalConfigHandle) { + StreamingGlobalConfigHandle globalConfigHandle, + HotKeyLogger hotKeyLogger) { this.options = options; this.mapTaskExecutorFactory = mapTaskExecutorFactory; this.readerCache 
= readerCache; @@ -124,6 +127,7 @@ final class ComputationWorkExecutorFactory { : StreamingDataflowWorker.MAX_SINK_BYTES; this.throwExceptionOnLargeOutput = hasExperiment(options, THROW_EXCEPTIONS_ON_LARGE_OUTPUT_EXPERIMENT); + this.hotKeyLogger = hotKeyLogger; } private static Nodes.ParallelInstructionNode extractReadNode( @@ -191,8 +195,12 @@ ComputationWorkExecutor createComputationWorkExecutor( DataflowExecutionContext.DataflowExecutionStateTracker executionStateTracker = createExecutionStateTracker(stageInfo, mapTask, workLatencyTrackingId); + boolean hotKeyLoggingEnabled = + options.isHotKeyLoggingEnabled() || hasExperiment(options, "enable_hot_key_logging"); + String stepName = getShuffleTaskStepName(mapTask); StreamingModeExecutionContext context = - createExecutionContext(computationState, stageInfo, executionStateTracker); + createExecutionContext( + computationState, stageInfo, executionStateTracker, hotKeyLoggingEnabled, stepName); DataflowMapTaskExecutor mapTaskExecutor = createMapTaskExecutor(context, mapTask, mapTaskNetwork); ReadOperation readOperation = getValidatedReadOperation(mapTaskExecutor); @@ -255,7 +263,9 @@ ComputationWorkExecutor createComputationWorkExecutor( private StreamingModeExecutionContext createExecutionContext( ComputationState computationState, StageInfo stageInfo, - DataflowExecutionContext.DataflowExecutionStateTracker executionStateTracker) { + DataflowExecutionContext.DataflowExecutionStateTracker executionStateTracker, + boolean hotKeyLoggingEnabled, + String stepName) { String computationId = computationState.getComputationId(); return new StreamingModeExecutionContext( pendingDeltaCounters, @@ -268,7 +278,11 @@ private StreamingModeExecutionContext createExecutionContext( stageInfo.executionStateRegistry(), globalConfigHandle, maxSinkBytes, - throwExceptionOnLargeOutput); + throwExceptionOnLargeOutput, + hotKeyLogger, + hotKeyLoggingEnabled, + stepName, + computationState.sourceBytesProcessCounterName()); } private 
DataflowMapTaskExecutor createMapTaskExecutor( @@ -286,6 +300,12 @@ private DataflowMapTaskExecutor createMapTaskExecutor( idGenerator); } + private static String getShuffleTaskStepName(MapTask mapTask) { + // The MapTask instruction is ordered by dependencies, such that the first element is + // always going to be the shuffle task. + return mapTask.getInstructions().get(0).getName(); + } + private DataflowExecutionContext.DataflowExecutionStateTracker createExecutionStateTracker( StageInfo stageInfo, MapTask mapTask, String workLatencyTrackingId) { return new DataflowExecutionContext.DataflowExecutionStateTracker( diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java index 364608be82ca..9ee2192b09d8 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java @@ -17,22 +17,23 @@ */ package org.apache.beam.runners.dataflow.worker.windmill.work.processing; -import static org.apache.beam.sdk.options.ExperimentalOptions.hasExperiment; - import com.google.api.services.dataflow.model.MapTask; import com.google.auto.value.AutoValue; -import java.util.Optional; +import java.util.List; +import java.util.Map; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; import java.util.function.Function; import java.util.function.Supplier; import javax.annotation.concurrent.ThreadSafe; +import org.apache.beam.repackaged.core.org.apache.commons.lang3.tuple.Pair; import 
org.apache.beam.runners.dataflow.options.DataflowWorkerHarnessOptions; import org.apache.beam.runners.dataflow.worker.DataflowExecutionStateSampler; import org.apache.beam.runners.dataflow.worker.DataflowMapTaskExecutorFactory; import org.apache.beam.runners.dataflow.worker.HotKeyLogger; import org.apache.beam.runners.dataflow.worker.ReaderCache; +import org.apache.beam.runners.dataflow.worker.StreamingModeExecutionContext; import org.apache.beam.runners.dataflow.worker.WorkItemCancelledException; import org.apache.beam.runners.dataflow.worker.logging.DataflowWorkerLoggingMDC; import org.apache.beam.runners.dataflow.worker.streaming.BoundedQueueExecutorWorkHandle; @@ -57,12 +58,11 @@ import org.apache.beam.runners.dataflow.worker.windmill.work.processing.failures.FailureTracker; import org.apache.beam.runners.dataflow.worker.windmill.work.processing.failures.WorkFailureProcessor; import org.apache.beam.sdk.annotations.Internal; -import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.fn.IdGenerator; import org.apache.beam.vendor.grpc.v1p69p0.com.google.protobuf.ByteString; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.checkerframework.checker.nullness.qual.Nullable; -import org.joda.time.Duration; import org.joda.time.Instant; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -78,7 +78,6 @@ public class StreamingWorkScheduler { private static final Logger LOG = LoggerFactory.getLogger(StreamingWorkScheduler.class); - private final DataflowWorkerHarnessOptions options; private final Supplier clock; private final ComputationWorkExecutorFactory computationWorkExecutorFactory; private final SideInputStateFetcherFactory sideInputStateFetcherFactory; @@ -86,33 +85,31 @@ public class StreamingWorkScheduler { private final WorkFailureProcessor workFailureProcessor; private final StreamingCommitFinalizer 
commitFinalizer; private final StreamingCounters streamingCounters; - private final HotKeyLogger hotKeyLogger; private final ConcurrentMap stageInfoMap; private final DataflowExecutionStateSampler sampler; private final StreamingGlobalConfigHandle globalConfigHandle; + private final BoundedQueueExecutor workExecutor; public StreamingWorkScheduler( - DataflowWorkerHarnessOptions options, Supplier clock, + BoundedQueueExecutor workExecutor, ComputationWorkExecutorFactory computationWorkExecutorFactory, SideInputStateFetcherFactory sideInputStateFetcherFactory, FailureTracker failureTracker, WorkFailureProcessor workFailureProcessor, StreamingCommitFinalizer commitFinalizer, StreamingCounters streamingCounters, - HotKeyLogger hotKeyLogger, ConcurrentMap stageInfoMap, DataflowExecutionStateSampler sampler, StreamingGlobalConfigHandle globalConfigHandle) { - this.options = options; this.clock = clock; + this.workExecutor = workExecutor; this.computationWorkExecutorFactory = computationWorkExecutorFactory; this.sideInputStateFetcherFactory = sideInputStateFetcherFactory; this.failureTracker = failureTracker; this.workFailureProcessor = workFailureProcessor; this.commitFinalizer = commitFinalizer; this.streamingCounters = streamingCounters; - this.hotKeyLogger = hotKeyLogger; this.stageInfoMap = stageInfoMap; this.sampler = sampler; this.globalConfigHandle = globalConfigHandle; @@ -143,18 +140,18 @@ public static StreamingWorkScheduler create( sampler, streamingCounters.pendingDeltaCounters(), idGenerator, - globalConfigHandle); + globalConfigHandle, + hotKeyLogger); return new StreamingWorkScheduler( - options, clock, + workExecutor, computationWorkExecutorFactory, SideInputStateFetcherFactory.fromOptions(options), failureTracker, workFailureProcessor, StreamingCommitFinalizer.create(workExecutor, commitFinalizerCleanupExecutor), streamingCounters, - hotKeyLogger, stageInfoMap, sampler, globalConfigHandle); @@ -191,15 +188,8 @@ private static void 
setUpWorkLoggingContext(String workLatencyTrackingId, String DataflowWorkerLoggingMDC.setStageName(computationId); } - private static String getShuffleTaskStepName(MapTask mapTask) { - // The MapTask instruction is ordered by dependencies, such that the first element is - // always going to be the shuffle task. - return mapTask.getInstructions().get(0).getName(); - } - /** Resets logging context of the Thread executing the {@link Work} for logging. */ - private void resetWorkLoggingContext(String workLatencyTrackingId) { - sampler.resetForWorkId(workLatencyTrackingId); + private void resetWorkLoggingContext() { DataflowWorkerLoggingMDC.setWorkId(null); DataflowWorkerLoggingMDC.setStageName(null); } @@ -246,10 +236,9 @@ private void processWork( } private void processWork( - ComputationState computationState, Work work, BoundedQueueExecutorWorkHandle unusedHandle) { + ComputationState computationState, Work work, BoundedQueueExecutorWorkHandle handle) { Windmill.WorkItem workItem = work.getWorkItem(); String computationId = computationState.getComputationId(); - ByteString key = workItem.getKey(); work.setProcessingThreadName(Thread.currentThread().getName()); work.setState(Work.State.PROCESSING); setUpWorkLoggingContext(work.getLatencyTrackingId(), computationId); @@ -258,37 +247,36 @@ private void processWork( // Before any processing starts, call any pending OnCommit callbacks. Nothing that requires // cleanup should be done before this, since we might exit early here. 
commitFinalizer.finalizeCommits(workItem.getSourceState().getFinalizeIdsList()); + if (workItem.getSourceState().getOnlyFinalize()) { - Windmill.WorkItemCommitRequest.Builder outputBuilder = initializeOutputBuilder(key, workItem); - outputBuilder.setSourceStateUpdates(Windmill.SourceState.newBuilder().setOnlyFinalize(true)); - work.setState(Work.State.COMMIT_QUEUED); - work.queueCommit(outputBuilder.build(), computationState); + handleOnlyFinalize(computationState, work, workItem); return; } long processingStartTimeNanos = System.nanoTime(); - MapTask mapTask = computationState.getMapTask(); - StageInfo stageInfo = - stageInfoMap.computeIfAbsent( - mapTask.getStageName(), s -> StageInfo.create(s, mapTask.getSystemName())); + StageInfo stageInfo = getStageInfo(computationState); + List worksToCleanup = null; try { if (work.isFailed()) { throw new WorkItemCancelledException(workItem.getShardingKey()); } - // Execute the user code for the Work. - ExecuteWorkResult executeWorkResult = executeWork(work, stageInfo, computationState); - Windmill.WorkItemCommitRequest.Builder commitRequest = executeWorkResult.commitWorkRequest(); + // Execute the user code for the Work batch. + ExecuteWorkResult executeWorkResult = executeWork(work, stageInfo, computationState, handle); + List workBatch = executeWorkResult.workBatch(); + worksToCleanup = workBatch; + List outputBuilders = + executeWorkResult.outputBuilders(); + Map> accumulatedCallbacks = + executeWorkResult.accumulatedCallbacks(); - // Validate the commit request, possibly requesting truncation if the commitSize is too large. - Windmill.WorkItemCommitRequest validatedCommitRequest = - validateCommitRequestSize(commitRequest.build(), computationId, workItem); + commitFinalizer.cacheCommitFinalizers(accumulatedCallbacks); - // Queue the commit. 
- work.queueCommit(validatedCommitRequest, computationState); - recordProcessingStats(commitRequest, workItem, executeWorkResult); - LOG.debug("Processing done for work token: {}", workItem.getWorkToken()); + commitWorkBatch(computationState, workBatch, outputBuilders); + + recordProcessingStats(workBatch, outputBuilders, executeWorkResult.stateBytesRead()); + LOG.debug("Processing done for work batch size: {}", workBatch.size()); } catch (Throwable t) { // OutOfMemoryError that are caught will be rethrown and trigger jvm termination. try { @@ -306,22 +294,10 @@ private void processWork( throw ExceptionUtils.safeWrapThrowableAsException(t2); } } finally { - // Update total processing time counters. Updating in finally clause ensures that - // work items causing exceptions are also accounted in time spent. - long processingTimeMsecs = - TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - processingStartTimeNanos); - stageInfo.totalProcessingMsecs().addValue(processingTimeMsecs); - - // Attribute all the processing to timers if the work item contains any timers. - // Tests show that work items rarely contain both timers and message bundles. It should - // be a fairly close approximation. - // Another option: Derive time split between messages and timers based on recent totals. - // either here or in DFE. 
- if (work.getWorkItem().hasTimers()) { - stageInfo.timerProcessingMsecs().addValue(processingTimeMsecs); - } + recordProcessingTime(stageInfo, worksToCleanup, work, processingStartTimeNanos); - resetWorkLoggingContext(work.getLatencyTrackingId()); + resetWorkLoggingContext(); + sampler.resetForWorkId(work.getLatencyTrackingId()); work.setProcessingThreadName(""); } } @@ -354,27 +330,35 @@ private Windmill.WorkItemCommitRequest validateCommitRequestSize( } private void recordProcessingStats( - Windmill.WorkItemCommitRequest.Builder outputBuilder, - Windmill.WorkItem workItem, - ExecuteWorkResult executeWorkResult) { - // Compute shuffle and state byte statistics these will be flushed asynchronously. - long stateBytesWritten = - outputBuilder - .clearOutputMessages() - .clearPerWorkItemLatencyAttributions() - .build() - .getSerializedSize(); - - streamingCounters.windmillShuffleBytesRead().addValue(computeShuffleBytesRead(workItem)); - streamingCounters.windmillStateBytesRead().addValue(executeWorkResult.stateBytesRead()); - streamingCounters.windmillStateBytesWritten().addValue(stateBytesWritten); + List workBatch, + List outputBuilders, + long totalStateBytesRead) { + long totalStateBytesWritten = 0; + long totalShuffleBytesRead = 0; + for (int i = 0; i < workBatch.size(); i++) { + Windmill.WorkItem workItem = workBatch.get(i).getWorkItem(); + Windmill.WorkItemCommitRequest.Builder outputBuilder = outputBuilders.get(i); + // Compute shuffle and state byte statistics these will be flushed asynchronously. 
+ long stateBytesWritten = + outputBuilder + .clearOutputMessages() + .clearPerWorkItemLatencyAttributions() + .build() + .getSerializedSize(); + totalStateBytesWritten += stateBytesWritten; + totalShuffleBytesRead += computeShuffleBytesRead(workItem); + } + streamingCounters.windmillShuffleBytesRead().addValue(totalShuffleBytesRead); + streamingCounters.windmillStateBytesRead().addValue(totalStateBytesRead); + streamingCounters.windmillStateBytesWritten().addValue(totalStateBytesWritten); } private ExecuteWorkResult executeWork( - Work work, StageInfo stageInfo, ComputationState computationState) throws Exception { - Windmill.WorkItem workItem = work.getWorkItem(); - ByteString key = workItem.getKey(); - Windmill.WorkItemCommitRequest.Builder outputBuilder = initializeOutputBuilder(key, workItem); + Work work, + StageInfo stageInfo, + ComputationState computationState, + BoundedQueueExecutorWorkHandle handle) + throws Exception { ComputationWorkExecutor computationWorkExecutor = computationState .acquireComputationWorkExecutor() @@ -388,86 +372,143 @@ private ExecuteWorkResult executeWork( SideInputStateFetcher localSideInputStateFetcher = sideInputStateFetcherFactory.createSideInputStateFetcher(work::fetchSideInput); - // If the read output KVs, then we can decode Windmill's byte key into userland - // key object and provide it to the execution context for use with per-key state. - // Otherwise, we pass null. - // - // The coder type that will be present is: - // WindowedValueCoder(TimerOrElementCoder(KvCoder)) - Optional> keyCoder = computationWorkExecutor.keyCoder(); - @SuppressWarnings("deprecation") - @Nullable - final Object executionKey = - !keyCoder.isPresent() ? 
null : keyCoder.get().decode(key.newInput(), Coder.Context.OUTER); - - if (workItem.hasHotKeyInfo()) { - Windmill.HotKeyInfo hotKeyInfo = workItem.getHotKeyInfo(); - Duration hotKeyAge = Duration.millis(hotKeyInfo.getHotKeyAgeUsec() / 1000); - - String stepName = getShuffleTaskStepName(computationState.getMapTask()); - if (executionKey != null - && (options.isHotKeyLoggingEnabled() - || hasExperiment(options, "enable_hot_key_logging")) - && keyCoder.isPresent()) { - hotKeyLogger.logHotKeyDetection(stepName, hotKeyAge, executionKey); - } else { - hotKeyLogger.logHotKeyDetection(stepName, hotKeyAge); - } - } + StreamingModeExecutionContext.KeySwitchListener keySwitchListener = + createKeySwitchListener(computationState); // Blocks while executing work. computationWorkExecutor.executeWork( - executionKey, work, stateReader, localSideInputStateFetcher, outputBuilder); + work, stateReader, localSideInputStateFetcher, workExecutor, handle, keySwitchListener); - if (work.isFailed()) { - throw new WorkItemCancelledException(workItem.getShardingKey()); + StreamingModeExecutionContext context = computationWorkExecutor.context(); + if (context.workIsFailed()) { + throw new WorkItemCancelledException( + Preconditions.checkNotNull(context.getWorkItem()).getShardingKey()); } - // Reports source bytes processed to WorkItemCommitRequest if available. 
- try { - long sourceBytesProcessed = - computationWorkExecutor.computeSourceBytesProcessed( - computationState.sourceBytesProcessCounterName()); - outputBuilder.setSourceBytesProcessed(sourceBytesProcessed); - } catch (Exception e) { - LOG.error("{}", e.toString()); - } - - commitFinalizer.cacheCommitFinalizers(computationWorkExecutor.context().flushState()); + // Retrieve executed works, output builders, and accumulated callbacks from execution context + List workBatch = context.getExecutedWorks(); + List outputBuilders = context.getOutputBuilders(); + Map> accumulatedCallbacks = context.getAccumulatedCallbacks(); + context.clear(); // Release the execution state for another thread to use. computationState.releaseComputationWorkExecutor(computationWorkExecutor); computationWorkExecutor = null; - work.setState(Work.State.COMMIT_QUEUED); - outputBuilder.addAllPerWorkItemLatencyAttributions(work.getLatencyAttributions(sampler)); - return ExecuteWorkResult.create( - outputBuilder, stateReader.getBytesRead() + localSideInputStateFetcher.getBytesRead()); + workBatch, + outputBuilders, + accumulatedCallbacks, + context.getStateBytesRead() + localSideInputStateFetcher.getBytesRead()); } catch (Throwable t) { if (computationWorkExecutor != null) { // If processing failed due to a thrown exception, close the executionState. Do not // return/release the executionState back to computationState as that will lead to this // executionState instance being reused. - LOG.debug("Invalidating executor after work item {} failed", workItem.getWorkToken(), t); + LOG.debug( + "Invalidating executor after work item {} failed", + work.getWorkItem().getWorkToken(), + t); computationWorkExecutor.invalidate(); } - - // Re-throw the exception, it will be caught and handled by workFailureProcessor downstream. 
throw t; } } + private void handleOnlyFinalize( + ComputationState computationState, Work work, Windmill.WorkItem workItem) { + Windmill.WorkItemCommitRequest.Builder outputBuilder = + initializeOutputBuilder(workItem.getKey(), workItem); + outputBuilder.setSourceStateUpdates(Windmill.SourceState.newBuilder().setOnlyFinalize(true)); + work.setState(Work.State.COMMIT_QUEUED); + work.queueCommit(outputBuilder.build(), computationState); + } + + private StageInfo getStageInfo(ComputationState computationState) { + MapTask mapTask = computationState.getMapTask(); + return stageInfoMap.computeIfAbsent( + mapTask.getStageName(), s -> StageInfo.create(s, mapTask.getSystemName())); + } + + private void commitWorkBatch( + ComputationState computationState, + List workBatch, + List outputBuilders) { + Preconditions.checkState( + workBatch.size() == 1, "Expected single-key work batch, got: " + workBatch.size()); + commitSingleKeyWork(computationState, workBatch.get(0), outputBuilders.get(0)); + } + + private void commitSingleKeyWork( + ComputationState computationState, + Work work, + Windmill.WorkItemCommitRequest.Builder commitRequestBuilder) { + // Validate the commit request, possibly requesting truncation if the commitSize is too large. + Windmill.WorkItemCommitRequest validatedCommitRequest = + validateCommitRequestSize( + commitRequestBuilder.build(), computationState.getComputationId(), work.getWorkItem()); + work.setState(Work.State.COMMIT_QUEUED); + validatedCommitRequest = + validatedCommitRequest + .toBuilder() + .addAllPerWorkItemLatencyAttributions(work.getLatencyAttributions(sampler)) + .build(); + work.queueCommit(validatedCommitRequest, computationState); + } + + private void recordProcessingTime( + StageInfo stageInfo, + @Nullable List worksToCleanup, + Work work, + long processingStartTimeNanos) { + // Update total processing time counters. Updating in finally clause ensures that + // work items causing exceptions are also accounted in time spent. 
+ long processingTimeMsecs = + TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - processingStartTimeNanos); + stageInfo.totalProcessingMsecs().addValue(processingTimeMsecs); + if (anyWorkHasTimers(worksToCleanup, work)) { + // Attribute all the processing to timers if the work item contains any timers. + // Tests show that work items rarely contain both timers and message bundles. It should + // be a fairly close approximation. + // Another option: Derive time split between messages and timers based on recent totals. + // either here or in DFE. + stageInfo.timerProcessingMsecs().addValue(processingTimeMsecs); + } + } + + private static boolean anyWorkHasTimers(@Nullable List works, Work primaryWork) { + if (works != null && !works.isEmpty()) { + return works.stream().anyMatch(w -> w.getWorkItem().hasTimers()); + } + return primaryWork.getWorkItem().hasTimers(); + } + + private StreamingModeExecutionContext.KeySwitchListener createKeySwitchListener( + ComputationState computationState) { + return (oldWork, newWork) -> { + resetWorkLoggingContext(); + setUpWorkLoggingContext(newWork.getLatencyTrackingId(), computationState.getComputationId()); + newWork.setProcessingThreadName(Thread.currentThread().getName()); + oldWork.setProcessingThreadName(""); + }; + } + @AutoValue abstract static class ExecuteWorkResult { - - private static ExecuteWorkResult create( - Windmill.WorkItemCommitRequest.Builder commitWorkRequest, long stateBytesRead) { + static ExecuteWorkResult create( + List workBatch, + List outputBuilders, + Map> accumulatedCallbacks, + long stateBytesRead) { return new AutoValue_StreamingWorkScheduler_ExecuteWorkResult( - commitWorkRequest, stateBytesRead); + workBatch, outputBuilders, accumulatedCallbacks, stateBytesRead); } - abstract Windmill.WorkItemCommitRequest.Builder commitWorkRequest(); + abstract List workBatch(); + + abstract List outputBuilders(); + + abstract Map> accumulatedCallbacks(); abstract long stateBytesRead(); } diff --git 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java index d58f20076994..f7511305bf0f 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java @@ -571,11 +571,16 @@ private Windmill.GetWorkResponse buildInput(String input, byte[] metadata) throw Windmill.GetWorkResponse.Builder builder = Windmill.GetWorkResponse.newBuilder(); TextFormat.merge(input, builder); if (metadata != null) { - Windmill.InputMessageBundle.Builder messageBundleBuilder = - builder.getWorkBuilder(0).getWorkBuilder(0).getMessageBundlesBuilder(0); - for (Windmill.Message.Builder messageBuilder : - messageBundleBuilder.getMessagesBuilderList()) { - messageBuilder.setMetadata(addPaneTag(PaneInfo.NO_FIRING, metadata)); + for (Windmill.ComputationWorkItems.Builder compBuilder : builder.getWorkBuilderList()) { + for (Windmill.WorkItem.Builder workBuilder : compBuilder.getWorkBuilderList()) { + for (Windmill.InputMessageBundle.Builder messageBundleBuilder : + workBuilder.getMessageBundlesBuilderList()) { + for (Windmill.Message.Builder messageBuilder : + messageBundleBuilder.getMessagesBuilderList()) { + messageBuilder.setMetadata(addPaneTag(PaneInfo.NO_FIRING, metadata)); + } + } + } } } @@ -1327,7 +1332,7 @@ public void testKeyCommitTooLargeException() throws Exception { makeExpectedTruncationRequestOutput( 1, "large_key", DEFAULT_SHARDING_KEY, largeCommit.getEstimatedWorkItemCommitBytes()) .build(), - largeCommit); + removeDynamicFields(largeCommit)); // Check this explicitly since the estimated commit bytes weren't actually // checked against an expected value in the previous 
step @@ -3507,8 +3512,8 @@ public void testExceptionInvalidatesCache() throws Exception { } // Ensure that the invalidated dofn had tearDown called on them. - assertEquals(1, TestExceptionInvalidatesCacheFn.tearDownCallCount.get()); - assertEquals(2, TestExceptionInvalidatesCacheFn.setupCallCount.get()); + assertEquals(2, TestExceptionInvalidatesCacheFn.tearDownCallCount.get()); + assertEquals(3, TestExceptionInvalidatesCacheFn.setupCallCount.get()); worker.stop(); } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java index 216ca5386675..9d4ef999707c 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java @@ -71,6 +71,7 @@ import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillTagEncodingV2; import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender; import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.metrics.MetricsContainer; import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.state.TimeDomain; @@ -139,7 +140,43 @@ public void setUp() { executionStateRegistry, globalConfigHandle, Long.MAX_VALUE, - /*throwExceptionOnLargeOutput=*/ false); + /*throwExceptionOnLargeOutput=*/ false, + new HotKeyLogger(), + /*hotKeyLoggingEnabled=*/ false, + /*stepName=*/ "stepName", + "sourceBytesProcessCounterName"); + } + + private StreamingModeExecutionContext createTestExecutionContext( + DataflowWorkerHarnessOptions options) { + CounterSet counterSet = new CounterSet(); + ConcurrentHashMap 
stateNameMap = new ConcurrentHashMap<>(); + stateNameMap.put(NameContextsForTests.nameContextForTest().userName(), "testStateFamily"); + return new StreamingModeExecutionContext( + counterSet, + COMPUTATION_ID, + new ReaderCache(Duration.standardMinutes(1), Executors.newCachedThreadPool()), + stateNameMap, + WindmillStateCache.builder() + .setSizeMb(options.getWorkerCacheMb()) + .build() + .forComputation("comp"), + StreamingStepMetricsContainer.createRegistry(), + new DataflowExecutionStateTracker( + ExecutionStateSampler.newForTest(), + executionStateRegistry.getState( + NameContext.forStage("stage"), "other", null, NoopProfileScope.NOOP), + counterSet, + PipelineOptionsFactory.create(), + "test-work-item-id"), + executionStateRegistry, + globalConfigHandle, + Long.MAX_VALUE, + /*throwExceptionOnLargeOutput=*/ false, + new HotKeyLogger(), + /*hotKeyLoggingEnabled=*/ false, + /*stepName=*/ "stepName", + "sourceBytesProcessCounterName"); } private static Work createMockWork(Windmill.WorkItem workItem, Watermarks watermarks) { @@ -153,25 +190,42 @@ COMPUTATION_ID, new FakeGetDataClient(), ignored -> {}, mock(HeartbeatSender.cla Instant::now); } + private void start(Work work) { + start(executionContext, work, null); + } + + private void start(Work work, Coder keyCoder) { + start(executionContext, work, keyCoder); + } + + private void start(StreamingModeExecutionContext context, Work work) { + start(context, work, null); + } + + private void start(StreamingModeExecutionContext context, Work work, Coder keyCoder) { + context.start( + work, + stateReader, + sideInputStateFetcher, + workExecutor, + /* workQueueExecutor= */ null, + /* budgetHandle= */ null, + keyCoder, + /* keySwitchListener= */ (k, c) -> {}); + } + @Test public void testTimerInternalsSetTimer() throws Exception { - Windmill.WorkItemCommitRequest.Builder outputBuilder = - Windmill.WorkItemCommitRequest.newBuilder(); NameContext nameContext = NameContextsForTests.nameContextForTest(); 
DataflowOperationContext operationContext = executionContext.createOperationContext(nameContext); StreamingModeExecutionContext.StepContext stepContext = executionContext.getStepContext(operationContext); - executionContext.start( - "key", + start( createMockWork( Windmill.WorkItem.newBuilder().setKey(ByteString.EMPTY).setWorkToken(17L).build(), - Watermarks.builder().setInputDataWatermark(new Instant(1000)).build()), - stateReader, - sideInputStateFetcher, - outputBuilder, - workExecutor); + Watermarks.builder().setInputDataWatermark(new Instant(1000)).build())); TimerInternals timerInternals = stepContext.timerInternals(); @@ -185,6 +239,7 @@ public void testTimerInternalsSetTimer() throws Exception { executionContext.finishKey(); executionContext.flushState(); + Windmill.WorkItemCommitRequest.Builder outputBuilder = executionContext.getOutputBuilder(); Windmill.Timer timer = outputBuilder.buildPartial().getOutputTimers(0); assertThat(timer.getTag().toStringUtf8(), equalTo("/skey+0:5000")); assertThat(timer.getTimestamp(), equalTo(TimeUnit.MILLISECONDS.toMicros(5000))); @@ -193,9 +248,6 @@ public void testTimerInternalsSetTimer() throws Exception { @Test public void testTimerInternalsProcessingTimeSkew() { - Windmill.WorkItemCommitRequest.Builder outputBuilder = - Windmill.WorkItemCommitRequest.newBuilder(); - NameContext nameContext = NameContextsForTests.nameContextForTest(); DataflowOperationContext operationContext = executionContext.createOperationContext(nameContext); @@ -215,15 +267,10 @@ public void testTimerInternalsProcessingTimeSkew() { .setTimestamp(timerTimestamp.getMillis() * 1000) .setType(Windmill.Timer.Type.REALTIME); - executionContext.start( - "key", + start( createMockWork( workItemBuilder.build(), - Watermarks.builder().setInputDataWatermark(new Instant(1000)).build()), - stateReader, - sideInputStateFetcher, - outputBuilder, - workExecutor); + Watermarks.builder().setInputDataWatermark(new Instant(1000)).build())); TimerInternals 
timerInternals = stepContext.timerInternals(); assertTrue(timerTimestamp.isBefore(timerInternals.currentProcessingTime())); } @@ -421,47 +468,62 @@ public void testStateTagEncodingBasedOnConfig() { for (Boolean isV2Encoding : Lists.newArrayList(Boolean.TRUE, Boolean.FALSE)) { Class expectedEncoding = isV2Encoding ? WindmillTagEncodingV2.class : WindmillTagEncodingV1.class; - Windmill.WorkItemCommitRequest.Builder outputBuilder = - Windmill.WorkItemCommitRequest.newBuilder(); globalConfigHandle.setConfig( StreamingGlobalConfig.builder().setEnableStateTagEncodingV2(isV2Encoding).build()); - executionContext.start( - "key", + start( createMockWork( Windmill.WorkItem.newBuilder().setKey(ByteString.EMPTY).setWorkToken(17L).build(), - Watermarks.builder().setInputDataWatermark(new Instant(1000)).build()), - stateReader, - sideInputStateFetcher, - outputBuilder, - workExecutor); + Watermarks.builder().setInputDataWatermark(new Instant(1000)).build())); assertEquals(expectedEncoding, executionContext.getWindmillTagEncoding().getClass()); } } @Test public void testSetBacklogBytes() { - Windmill.WorkItemCommitRequest.Builder outputBuilder = - Windmill.WorkItemCommitRequest.newBuilder(); NameContext nameContext = NameContextsForTests.nameContextForTest(); DataflowOperationContext operationContext = executionContext.createOperationContext(nameContext); StreamingModeExecutionContext.StepContext stepContext = executionContext.getStepContext(operationContext); - executionContext.start( - "key", + start( createMockWork( Windmill.WorkItem.newBuilder().setKey(ByteString.EMPTY).setWorkToken(17L).build(), - Watermarks.builder().setInputDataWatermark(new Instant(1000)).build()), - stateReader, - sideInputStateFetcher, - outputBuilder, - workExecutor); + Watermarks.builder().setInputDataWatermark(new Instant(1000)).build())); stepContext.setBacklogBytes(1234.0); executionContext.finishKey(); executionContext.flushState(); - assertEquals(1234, outputBuilder.getSourceBacklogBytes()); + 
assertEquals(1234, executionContext.getOutputBuilder().getSourceBacklogBytes()); + } + + @Test + public void testFinishKeyReentrantSafety() { + start( + createMockWork( + Windmill.WorkItem.newBuilder().setKey(ByteString.EMPTY).setWorkToken(17L).build(), + Watermarks.builder().setInputDataWatermark(new Instant(1000)).build())); + + // First call + executionContext.finishKey(); + // Second call - should not throw any Exception + executionContext.finishKey(); + } + + @Test + public void testStart_internalKeyDecoding() throws Exception { + Windmill.WorkItem workItem = + Windmill.WorkItem.newBuilder() + .setKey(ByteString.copyFromUtf8("decodedKey")) + .setWorkToken(17L) + .build(); + Work work = + createMockWork( + workItem, Watermarks.builder().setInputDataWatermark(new Instant(1000)).build()); + + start(work, org.apache.beam.sdk.coders.StringUtf8Coder.of()); + + assertEquals("decodedKey", executionContext.getKey()); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindmillReaderIteratorBaseTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindmillReaderIteratorBaseTest.java index 539c38eeb1da..a56343e3dfb3 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindmillReaderIteratorBaseTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindmillReaderIteratorBaseTest.java @@ -30,6 +30,7 @@ import java.util.Arrays; import java.util.List; import java.util.concurrent.ThreadLocalRandom; +import org.apache.beam.runners.dataflow.worker.streaming.Work; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; import org.apache.beam.sdk.coders.CoderException; import org.apache.beam.sdk.options.ValueProvider; @@ -122,6 +123,7 @@ public void testFinishKeyCalled() throws Exception { .build()) .build(); 
when(mockContext.getWorkItem()).thenReturn(workItem); + when(mockContext.advance()).thenReturn(false); try (TestWindmillReaderIterator iter = new TestWindmillReaderIterator(mockContext, ValueProvider.StaticValueProvider.of(false))) { @@ -131,6 +133,78 @@ public void testFinishKeyCalled() throws Exception { } } + @Test + public void testAdvanceKeyChaining() throws Exception { + StreamingModeExecutionContext mockContext = mock(StreamingModeExecutionContext.class); + when(mockContext.workIsFailed()).thenReturn(false); + + // Work item A (1 message) + Windmill.WorkItem workItemA = + Windmill.WorkItem.newBuilder() + .setKey(ByteString.copyFromUtf8("keyA")) + .setWorkToken(100L) + .addMessageBundles( + Windmill.InputMessageBundle.newBuilder() + .setSourceComputationId("foo") + .addMessages( + Windmill.Message.newBuilder() + .setTimestamp(1000) + .setData(ByteString.EMPTY) + .build()) + .build()) + .build(); + when(mockContext.getWorkItem()).thenReturn(workItemA); + + // Work item B (1 message) + Windmill.WorkItem workItemB = + Windmill.WorkItem.newBuilder() + .setKey(ByteString.copyFromUtf8("keyB")) + .setWorkToken(200L) + .addMessageBundles( + Windmill.InputMessageBundle.newBuilder() + .setSourceComputationId("foo") + .addMessages( + Windmill.Message.newBuilder() + .setTimestamp(2000) + .setData(ByteString.EMPTY) + .build()) + .build()) + .build(); + + Work mockWorkB = createMockWork(workItemB); + + // Set up context.advance() to mock transition + when(mockContext.advance()) + .thenAnswer( + new org.mockito.stubbing.Answer() { + private int count = 0; + + @Override + public Boolean answer(org.mockito.invocation.InvocationOnMock invocation) { + if (count == 0) { + count++; + when(mockContext.getWork()).thenReturn(mockWorkB); + return true; + } + return false; + } + }); + + try (TestWindmillReaderIterator iter = + new TestWindmillReaderIterator(mockContext, ValueProvider.StaticValueProvider.of(false))) { + assertTrue(iter.start()); + assertEquals(1000L, 
iter.getCurrent().getValue().longValue()); + + // Advance should trigger context.advance(), transition to workItemB, and decode message from + // workItemB (timestamp 2000) + assertTrue(iter.advance()); + assertEquals(2000L, iter.getCurrent().getValue().longValue()); + + // Next advance should exhaust it and return false + assertFalse(iter.advance()); + } + } + private void testForMessageBundleCounts(int... messageBundleCounts) throws IOException { testForMessageBundleCounts(false, messageBundleCounts); } @@ -179,4 +253,24 @@ private void testForMessageBundleCounts(boolean skipErrors, int... messageBundle assertEquals(Arrays.toString(messageBundleCounts) + skipErrors, expected, actual); } } + + private static Work createMockWork(Windmill.WorkItem workItem) { + return Work.create( + workItem, + workItem.getSerializedSize(), + org.apache.beam.runners.dataflow.worker.streaming.Watermarks.builder() + .setInputDataWatermark(new org.joda.time.Instant(1000)) + .build(), + Work.createProcessingContext( + "computationId", + mock( + org.apache.beam.runners.dataflow.worker.windmill.client.getdata.GetDataClient + .class), + ignored -> {}, + mock( + org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender + .class)), + false, + org.joda.time.Instant::now); + } } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java index d5cf2948d928..4175b47bfe4f 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java @@ -209,6 +209,18 @@ COMPUTATION_ID, new FakeGetDataClient(), ignored -> {}, mock(HeartbeatSender.cla Instant::now); } + private void 
startContext(StreamingModeExecutionContext context, Work work) { + context.start( + work, + mock(WindmillStateReader.class), + mock(SideInputStateFetcher.class), + mock(WorkExecutor.class), + /* workQueueExecutor= */ null, + /* budgetHandle= */ null, + /* keyCoder= */ null, + /* keySwitchListener= */ mock(StreamingModeExecutionContext.KeySwitchListener.class)); + } + private static class SourceProducingSubSourcesInSplit extends MockSource { int numDesiredBundle; int sourceObjectSize; @@ -620,7 +632,11 @@ public void testReadUnboundedReader() throws Exception { executionStateRegistry, globalConfigHandle, Long.MAX_VALUE, - /*throwExceptionOnLargeOutput=*/ false); + /*throwExceptionOnLargeOutput=*/ false, + new HotKeyLogger(), + /*hotKeyLoggingEnabled=*/ false, + /*stepName=*/ "stepName", + "sourceBytesProcessCounterName"); options.setNumWorkers(5); int maxElements = 10; @@ -631,8 +647,8 @@ public void testReadUnboundedReader() throws Exception { for (int i = 0; i < 10 * maxElements; /* Incremented in inner loop */ ) { // Initialize streaming context with state from previous iteration. - context.start( - "key", + startContext( + context, createMockWork( Windmill.WorkItem.newBuilder() .setKey(ByteString.copyFromUtf8("0000000000000001")) // key is zero-padded index. @@ -641,11 +657,7 @@ public void testReadUnboundedReader() throws Exception { .setSourceState( Windmill.SourceState.newBuilder().setState(state).build()) // Source state. 
.build(), - Watermarks.builder().setInputDataWatermark(new Instant(0)).build()), - mock(WindmillStateReader.class), - mock(SideInputStateFetcher.class), - Windmill.WorkItemCommitRequest.newBuilder(), - mock(WorkExecutor.class)); + Watermarks.builder().setInputDataWatermark(new Instant(0)).build())); @SuppressWarnings({"unchecked", "rawtypes"}) NativeReader>>> reader = @@ -992,7 +1004,11 @@ public void testFailedWorkItemsAbort() throws Exception { executionStateRegistry, globalConfigHandle, Long.MAX_VALUE, - /*throwExceptionOnLargeOutput=*/ false); + /*throwExceptionOnLargeOutput=*/ false, + new HotKeyLogger(), + /*hotKeyLoggingEnabled=*/ false, + /*stepName=*/ "stepName", + "sourceBytesProcessCounterName"); options.setNumWorkers(5); int maxElements = 100; @@ -1020,13 +1036,7 @@ public void testFailedWorkItemsAbort() throws Exception { mock(HeartbeatSender.class)), false, Instant::now); - context.start( - "key", - dummyWork, - mock(WindmillStateReader.class), - mock(SideInputStateFetcher.class), - Windmill.WorkItemCommitRequest.newBuilder(), - mock(WorkExecutor.class)); + startContext(context, dummyWork); @SuppressWarnings({"unchecked", "rawtypes"}) NativeReader>>> reader = From 53bc9a690b6fa311d1b3f21e95c41399e7487091 Mon Sep 17 00:00:00 2001 From: Arun Pandian Date: Thu, 4 Jun 2026 11:05:57 +0000 Subject: [PATCH 03/21] trigger postsubmit tests --- ...beam_PostCommit_Java_ValidatesRunner_Dataflow_Streaming.json | 2 +- ...stCommit_Java_ValidatesRunner_Dataflow_Streaming_Engine.json | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Dataflow_Streaming.json b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Dataflow_Streaming.json index e623d3373a93..50d17c108f2e 100644 --- a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Dataflow_Streaming.json +++ b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Dataflow_Streaming.json @@ -1,4 +1,4 @@ { "comment": "Modify 
this file in a trivial way to cause this test suite to run!", - "modification": 1, + "modification": 2, } diff --git a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Dataflow_Streaming_Engine.json b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Dataflow_Streaming_Engine.json index e623d3373a93..50d17c108f2e 100644 --- a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Dataflow_Streaming_Engine.json +++ b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Dataflow_Streaming_Engine.json @@ -1,4 +1,4 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run!", - "modification": 1, + "modification": 2, } From 7720cf8fc61ebae6a8da5317d3324d4bf307c182 Mon Sep 17 00:00:00 2001 From: Arun Pandian Date: Thu, 4 Jun 2026 18:53:19 +0000 Subject: [PATCH 04/21] fix tests --- .../worker/StreamingModeExecutionContext.java | 6 ++++++ .../worker/StreamingDataflowWorkerTest.java | 17 ++++++++++++----- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java index a669fb7ff361..3f62dcbd038f 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java @@ -296,6 +296,9 @@ public void clear() { this.workQueueExecutor = null; this.budgetHandle = null; this.keySwitchListener = null; + this.work = null; + this.key = null; + this.outputBuilder = null; } public void start( @@ -693,6 +696,9 @@ public boolean advance() { } private void startForNewKey(Work newWork, WindmillStateReader reader) { + if (keySwitchListener != null && this.work != null 
&& this.work != newWork) { + keySwitchListener.onKeySwitch(this.work, newWork); + } this.key = decodeKey(newWork); this.work = newWork; this.finishKeyCalled = false; diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java index f7511305bf0f..2591db19ec00 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java @@ -420,6 +420,7 @@ private ParallelInstruction makeWindowingSourceInstruction(Coder coder) { CloudObjects.asCloudObject(IntervalWindowCoder.of(), /* sdkComponents= */ null))); return new ParallelInstruction() + .setName(DEFAULT_SOURCE_SYSTEM_NAME) .setSystemName(DEFAULT_SOURCE_SYSTEM_NAME) .setOriginalName(DEFAULT_SOURCE_ORIGINAL_NAME) .setRead( @@ -439,6 +440,7 @@ private ParallelInstruction makeWindowingSourceInstruction(Coder coder) { private ParallelInstruction makeSourceInstruction(Coder coder) { return new ParallelInstruction() + .setName(DEFAULT_SOURCE_SYSTEM_NAME) .setSystemName(DEFAULT_SOURCE_SYSTEM_NAME) .setOriginalName(DEFAULT_SOURCE_ORIGINAL_NAME) .setRead( @@ -3955,11 +3957,16 @@ public void testDoFnLatencyBreakdownsReportedOnCommit() throws Exception { LatencyAttribution.newBuilder().setState(State.ACTIVE).setTotalDurationMillis(100); for (LatencyAttribution la : commit.getPerWorkItemLatencyAttributionsList()) { if (la.getState() == State.ACTIVE) { - assertThat(la.getActiveLatencyBreakdownCount(), equalTo(1)); - assertThat( - la.getActiveLatencyBreakdown(0).getUserStepName(), equalTo(DEFAULT_PARDO_USER_NAME)); - Assert.assertTrue(la.getActiveLatencyBreakdown(0).hasProcessingTimesDistribution()); - 
Assert.assertFalse(la.getActiveLatencyBreakdown(0).hasActiveMessageMetadata()); + LatencyAttribution.ActiveLatencyBreakdown pardoBreakdown = null; + for (LatencyAttribution.ActiveLatencyBreakdown lb : la.getActiveLatencyBreakdownList()) { + if (DEFAULT_PARDO_USER_NAME.equals(lb.getUserStepName())) { + pardoBreakdown = lb; + break; + } + } + Assert.assertNotNull("Expected breakdown for " + DEFAULT_PARDO_USER_NAME, pardoBreakdown); + Assert.assertTrue(pardoBreakdown.hasProcessingTimesDistribution()); + Assert.assertFalse(pardoBreakdown.hasActiveMessageMetadata()); } } From 72740d25cb74c00ca35e933122f847978619a732 Mon Sep 17 00:00:00 2001 From: Arun Pandian Date: Thu, 4 Jun 2026 21:36:27 +0000 Subject: [PATCH 05/21] fix tests --- .../runners/dataflow/worker/StreamingDataflowWorkerTest.java | 2 ++ 1 file changed, 2 insertions(+) diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java index 2591db19ec00..22350c525ab2 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java @@ -529,6 +529,7 @@ private ParallelInstruction makeSinkInstruction( CloudObject spec = CloudObject.forClass(WindmillSink.class); addString(spec, "stream_id", streamId); return new ParallelInstruction() + .setName(streamId) .setSystemName(DEFAULT_SINK_SYSTEM_NAME) .setOriginalName(DEFAULT_SINK_ORIGINAL_NAME) .setWrite( @@ -2502,6 +2503,7 @@ private List makeUnboundedSourcePipeline( return Arrays.asList( new ParallelInstruction() + .setName("Read") .setSystemName("Read") .setOriginalName("OriginalReadName") .setRead( From 0f96e723ce9524f68f51319e7e9a68fce05f088b Mon Sep 17 
00:00:00 2001 From: Arun Pandian Date: Thu, 4 Jun 2026 22:02:56 +0000 Subject: [PATCH 06/21] improve work synchronization --- .../apache/beam/runners/dataflow/worker/streaming/Work.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java index 668657228dfd..78cb54b3575b 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java @@ -241,7 +241,7 @@ public void setProcessingThreadName(String processingThreadName) { } @Override - public synchronized void setFailed() { + public void setFailed() { this.isFailed = true; Runnable listener = onFailureListener; if (listener != null) { @@ -249,7 +249,7 @@ public synchronized void setFailed() { } } - public synchronized void setOnFailureListener(@Nullable Runnable listener) { + public void setOnFailureListener(@Nullable Runnable listener) { this.onFailureListener = listener; if (isFailed && listener != null) { listener.run(); From 51b9257c54260d9dbc188ce70121aef9fefda054 Mon Sep 17 00:00:00 2001 From: Arun Pandian Date: Fri, 5 Jun 2026 01:03:21 +0000 Subject: [PATCH 07/21] cleanup logic --- .../windmill/work/processing/StreamingWorkScheduler.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java index 9ee2192b09d8..5e890ef3d635 100644 --- 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java @@ -381,8 +381,7 @@ private ExecuteWorkResult executeWork( StreamingModeExecutionContext context = computationWorkExecutor.context(); if (context.workIsFailed()) { - throw new WorkItemCancelledException( - Preconditions.checkNotNull(context.getWorkItem()).getShardingKey()); + throw new WorkItemCancelledException(work.getWorkItem().getShardingKey()); } // Retrieve executed works, output builders, and accumulated callbacks from execution context @@ -411,6 +410,7 @@ private ExecuteWorkResult executeWork( t); computationWorkExecutor.invalidate(); } + // Re-throw the exception, it will be caught and handled by workFailureProcessor downstream. throw t; } } From 9a4e7bea7894a9b38c19f6db813c256809da83bc Mon Sep 17 00:00:00 2001 From: Arun Pandian Date: Fri, 5 Jun 2026 01:11:34 +0000 Subject: [PATCH 08/21] cleanup logic --- .../dataflow/worker/StreamingModeExecutionContext.java | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java index 3f62dcbd038f..770093dbcce4 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java @@ -560,8 +560,9 @@ public void invalidateCache() { try { activeReader.close(); } catch (IOException e) { - LOG.warn( - "Failed to close reader for {}-{}", 
computationId, getWorkItem().getShardingKey(), e); + Windmill.WorkItem workItem = getWorkItem(); + long shardingKey = workItem != null ? workItem.getShardingKey() : -1L; + LOG.warn("Failed to close reader for {}-{}", computationId, shardingKey, e); } } activeReader = null; From 3f36afd72f5529da7389b2da2e47d87d363f21fd Mon Sep 17 00:00:00 2001 From: Arun Pandian Date: Mon, 8 Jun 2026 03:40:43 +0000 Subject: [PATCH 09/21] address comments --- .../worker/StreamingModeExecutionContext.java | 61 +++++--- .../streaming/ComputationWorkExecutor.java | 11 +- .../dataflow/worker/streaming/Work.java | 17 ++- .../processing/StreamingWorkScheduler.java | 126 +++++++++-------- .../StreamingModeExecutionContextTest.java | 11 +- .../worker/WorkerCustomSourcesTest.java | 11 +- .../dataflow/worker/streaming/WorkTest.java | 132 ++++++++++++++++++ 7 files changed, 275 insertions(+), 94 deletions(-) create mode 100644 runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/WorkTest.java diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java index 770093dbcce4..af9f29c7b9ba 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java @@ -35,6 +35,7 @@ import java.util.Optional; import java.util.Set; import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicLong; import javax.annotation.concurrent.NotThreadSafe; import org.apache.beam.repackaged.core.org.apache.commons.lang3.tuple.Pair; @@ -58,6 +59,7 @@ import 
org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInput; import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputState; import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputStateFetcher; +import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputStateFetcherFactory; import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor; import org.apache.beam.runners.dataflow.worker.util.common.worker.ElementCounter; import org.apache.beam.runners.dataflow.worker.util.common.worker.OutputObjectAndByteCounter; @@ -149,6 +151,7 @@ public class StreamingModeExecutionContext private @Nullable Work work; private WindmillComputationKey computationKey; + private SideInputStateFetcherFactory sideInputStateFetcherFactory; private SideInputStateFetcher sideInputStateFetcher; // OperationalLimits is updated in start() because a StreamingModeExecutionContext can // be used for processing many work items and these values can change during the context's @@ -174,22 +177,24 @@ public class StreamingModeExecutionContext private @Nullable BoundedQueueExecutorWorkHandle budgetHandle; private final HotKeyLogger hotKeyLogger; - private boolean hotKeyLoggingEnabled = false; + private final boolean hotKeyLoggingEnabled; private final String stepName; private @Nullable Coder keyCoder; // Key switch listener to delegate MDC logging context and thread name updates - public interface KeySwitchListener { - void onKeySwitch(Work oldWork, Work newWork); + public interface KeyTransitionListener { + void onKeyTransition(Work oldWork, Work newWork); } @SuppressWarnings("UnusedVariable") - private @Nullable KeySwitchListener keySwitchListener; + private @Nullable KeyTransitionListener keyTransitionListener; private List executedWorks = new ArrayList<>(); private List outputBuilders = new ArrayList<>(); + + // Map> private Map> accumulatedCallbacks = new HashMap<>(); - private volatile boolean workIsFailed = false; + 
private final AtomicBoolean workIsFailed = new AtomicBoolean(false); private @Nullable WindmillStateReader activeStateReader; private long stateBytesRead = 0; private final String sourceBytesProcessCounterName; @@ -248,7 +253,7 @@ public boolean throwExceptionsForLargeOutput() { } public boolean workIsFailed() { - return workIsFailed; + return workIsFailed.get(); } public boolean getDrainMode() { @@ -287,7 +292,7 @@ public void clear() { this.executedWorks = new ArrayList<>(); this.outputBuilders = new ArrayList<>(); this.accumulatedCallbacks = new HashMap<>(); - this.workIsFailed = false; + this.workIsFailed.set(false); this.sideInputCache.clear(); this.activeStateReader = null; this.activeReader = null; @@ -295,31 +300,32 @@ public void clear() { this.workExecutor = null; this.workQueueExecutor = null; this.budgetHandle = null; - this.keySwitchListener = null; + this.keyTransitionListener = null; this.work = null; this.key = null; this.outputBuilder = null; + this.sideInputStateFetcherFactory = null; + this.sideInputStateFetcher = null; + this.backlogBytes = UnboundedReader.BACKLOG_UNKNOWN; + clearSinkFullHint(); + this.stateBytesRead = 0; } public void start( Work work, WindmillStateReader stateReader, - SideInputStateFetcher sideInputStateFetcher, + SideInputStateFetcherFactory sideInputStateFetcherFactory, WorkExecutor workExecutor, BoundedQueueExecutor workQueueExecutor, BoundedQueueExecutorWorkHandle budgetHandle, @Nullable Coder keyCoder, - KeySwitchListener keySwitchListener) { + KeyTransitionListener keyTransitionListener) { clear(); this.keyCoder = keyCoder; this.workExecutor = workExecutor; this.workQueueExecutor = workQueueExecutor; this.budgetHandle = budgetHandle; - this.keySwitchListener = keySwitchListener; - - this.backlogBytes = UnboundedReader.BACKLOG_UNKNOWN; - clearSinkFullHint(); - this.stateBytesRead = 0; + this.keyTransitionListener = keyTransitionListener; StreamingGlobalConfig config = globalConfigHandle.getConfig(); // Snapshot the 
limits for entire bundle processing. @@ -328,7 +334,7 @@ public void start( config.enableStateTagEncodingV2() ? WindmillTagEncodingV2.instance() : WindmillTagEncodingV1.instance(); - this.sideInputStateFetcher = sideInputStateFetcher; + this.sideInputStateFetcherFactory = sideInputStateFetcherFactory; startForNewKey(work, stateReader); } @@ -388,6 +394,9 @@ public void finishKey() { if (activeStateReader != null) { this.stateBytesRead += activeStateReader.getBytesRead(); } + if (sideInputStateFetcher != null) { + this.stateBytesRead += sideInputStateFetcher.getBytesRead(); + } checkStateNotNull(workExecutor, "workExecutor must be set before calling finishKey()"); try { workExecutor.finishKey(); @@ -697,8 +706,9 @@ public boolean advance() { } private void startForNewKey(Work newWork, WindmillStateReader reader) { - if (keySwitchListener != null && this.work != null && this.work != newWork) { - keySwitchListener.onKeySwitch(this.work, newWork); + newWork.setState(Work.State.PROCESSING); + if (keyTransitionListener != null && this.work != null && this.work != newWork) { + keyTransitionListener.onKeyTransition(this.work, newWork); } this.key = decodeKey(newWork); this.work = newWork; @@ -707,11 +717,16 @@ private void startForNewKey(Work newWork, WindmillStateReader reader) { this.outputBuilder = createOutputBuilder(newWork); this.outputBuilders.add(this.outputBuilder); - newWork.setOnFailureListener(() -> this.workIsFailed = true); + newWork.setOnFailureListener(this.workIsFailed); this.executedWorks.add(newWork); logHotKeyIfDetected(newWork, this.key); + this.sideInputStateFetcher = + sideInputStateFetcherFactory.createSideInputStateFetcher(newWork::fetchSideInput); + this.backlogBytes = UnboundedReader.BACKLOG_UNKNOWN; + this.activeReader = null; + // Note: We do NOT clear sideInputCache here, allowing Key B to reuse warm side inputs! 
// Re-initialize state cache and state/timer internals across all step contexts @@ -738,8 +753,12 @@ public long getStateBytesRead() { return stateBytesRead; } - public List getOutputBuilders() { - return outputBuilders; + public List getWorkItemCommits() { + List commits = new ArrayList<>(outputBuilders.size()); + for (Windmill.WorkItemCommitRequest.Builder builder : outputBuilders) { + commits.add(builder.build()); + } + return commits; } public Map> getAccumulatedCallbacks() { diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationWorkExecutor.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationWorkExecutor.java index ed86d58b9bb0..31420b212c31 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationWorkExecutor.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationWorkExecutor.java @@ -23,7 +23,8 @@ import org.apache.beam.runners.core.metrics.ExecutionStateTracker; import org.apache.beam.runners.dataflow.worker.DataflowWorkExecutor; import org.apache.beam.runners.dataflow.worker.StreamingModeExecutionContext; -import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputStateFetcher; +import org.apache.beam.runners.dataflow.worker.StreamingModeExecutionContext.KeyTransitionListener; +import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputStateFetcherFactory; import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor; import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateReader; import org.apache.beam.sdk.annotations.Internal; @@ -65,21 +66,21 @@ public static ComputationWorkExecutor.Builder builder() { public final void executeWork( Work work, WindmillStateReader stateReader, - SideInputStateFetcher 
sideInputStateFetcher, + SideInputStateFetcherFactory sideInputStateFetcherFactory, BoundedQueueExecutor workQueueExecutor, BoundedQueueExecutorWorkHandle budgetHandle, - StreamingModeExecutionContext.KeySwitchListener keySwitchListener) + KeyTransitionListener keyTransitionListener) throws Exception { context() .start( work, stateReader, - sideInputStateFetcher, + sideInputStateFetcherFactory, workExecutor(), workQueueExecutor, budgetHandle, keyCoder().orElse(null), - keySwitchListener); + keyTransitionListener); workExecutor().execute(); } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java index 78cb54b3575b..44c1805e221f 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java @@ -27,6 +27,8 @@ import java.util.Map.Entry; import java.util.Optional; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; import java.util.function.Supplier; import javax.annotation.concurrent.NotThreadSafe; @@ -80,7 +82,8 @@ public final class Work implements RefreshableWork { private volatile TimedState currentState; private volatile boolean isFailed; private volatile String processingThreadName = ""; - private volatile @Nullable Runnable onFailureListener = null; + private final AtomicReference<@Nullable AtomicBoolean> onFailureListener = + new AtomicReference<>(null); private final boolean drainMode; private Work( @@ -243,16 +246,18 @@ public void setProcessingThreadName(String processingThreadName) { @Override public void setFailed() { this.isFailed = true; - Runnable 
listener = onFailureListener; + AtomicBoolean listener = onFailureListener.get(); if (listener != null) { - listener.run(); + listener.set(true); } } - public void setOnFailureListener(@Nullable Runnable listener) { - this.onFailureListener = listener; + // Sets the passed in boolean to true if the work fails + // Supports registering only one boolean at a time. + public void setOnFailureListener(@Nullable AtomicBoolean listener) { + onFailureListener.set(listener); if (isFailed && listener != null) { - listener.run(); + listener.set(true); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java index 5e890ef3d635..dc1fd4791fcd 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java @@ -34,6 +34,7 @@ import org.apache.beam.runners.dataflow.worker.HotKeyLogger; import org.apache.beam.runners.dataflow.worker.ReaderCache; import org.apache.beam.runners.dataflow.worker.StreamingModeExecutionContext; +import org.apache.beam.runners.dataflow.worker.StreamingModeExecutionContext.KeyTransitionListener; import org.apache.beam.runners.dataflow.worker.WorkItemCancelledException; import org.apache.beam.runners.dataflow.worker.logging.DataflowWorkerLoggingMDC; import org.apache.beam.runners.dataflow.worker.streaming.BoundedQueueExecutorWorkHandle; @@ -46,7 +47,6 @@ import org.apache.beam.runners.dataflow.worker.streaming.Work; import org.apache.beam.runners.dataflow.worker.streaming.config.StreamingGlobalConfigHandle; import 
org.apache.beam.runners.dataflow.worker.streaming.harness.StreamingCounters; -import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputStateFetcher; import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputStateFetcherFactory; import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor; import org.apache.beam.runners.dataflow.worker.util.ExceptionUtils; @@ -184,14 +184,22 @@ private static Windmill.WorkItemCommitRequest buildWorkItemTruncationRequest( /** Sets the stage name and workId of the Thread executing the {@link Work} for logging. */ private static void setUpWorkLoggingContext(String workLatencyTrackingId, String computationId) { - DataflowWorkerLoggingMDC.setWorkId(workLatencyTrackingId); + setLoggingContextWorkId(workLatencyTrackingId); + setLoggingContextComputation(computationId); + } + + private static void setLoggingContextComputation(@Nullable String computationId) { DataflowWorkerLoggingMDC.setStageName(computationId); } + private static void setLoggingContextWorkId(@Nullable String workLatencyTrackingId) { + DataflowWorkerLoggingMDC.setWorkId(workLatencyTrackingId); + } + /** Resets logging context of the Thread executing the {@link Work} for logging. */ private void resetWorkLoggingContext() { - DataflowWorkerLoggingMDC.setWorkId(null); - DataflowWorkerLoggingMDC.setStageName(null); + setLoggingContextWorkId(null); + setLoggingContextComputation(null); } /** @@ -256,7 +264,7 @@ private void processWork( long processingStartTimeNanos = System.nanoTime(); StageInfo stageInfo = getStageInfo(computationState); - List worksToCleanup = null; + List workBatch = null; try { if (work.isFailed()) { throw new WorkItemCancelledException(workItem.getShardingKey()); @@ -264,18 +272,14 @@ private void processWork( // Execute the user code for the Work batch. 
ExecuteWorkResult executeWorkResult = executeWork(work, stageInfo, computationState, handle); - List workBatch = executeWorkResult.workBatch(); - worksToCleanup = workBatch; - List outputBuilders = - executeWorkResult.outputBuilders(); - Map> accumulatedCallbacks = - executeWorkResult.accumulatedCallbacks(); + workBatch = executeWorkResult.workBatch(); + List workItemCommits = executeWorkResult.workItemCommits(); - commitFinalizer.cacheCommitFinalizers(accumulatedCallbacks); + commitFinalizer.cacheCommitFinalizers(executeWorkResult.accumulatedCallbacks()); - commitWorkBatch(computationState, workBatch, outputBuilders); + commitWorkBatch(computationState, workBatch, workItemCommits); - recordProcessingStats(workBatch, outputBuilders, executeWorkResult.stateBytesRead()); + recordProcessingStats(workBatch, workItemCommits, executeWorkResult.stateBytesRead()); LOG.debug("Processing done for work batch size: {}", workBatch.size()); } catch (Throwable t) { // OutOfMemoryError that are caught will be rethrown and trigger jvm termination. @@ -294,11 +298,19 @@ private void processWork( throw ExceptionUtils.safeWrapThrowableAsException(t2); } } finally { - recordProcessingTime(stageInfo, worksToCleanup, work, processingStartTimeNanos); + // Update total processing time counters. Updating in finally clause ensures that + // work items causing exceptions are also accounted in time spent. 
+ recordProcessingTime(stageInfo, workBatch, work, processingStartTimeNanos); resetWorkLoggingContext(); sampler.resetForWorkId(work.getLatencyTrackingId()); - work.setProcessingThreadName(""); + if (workBatch != null) { + for (Work w : workBatch) { + w.setProcessingThreadName(""); + } + } else { + work.setProcessingThreadName(""); + } } } @@ -331,16 +343,18 @@ private Windmill.WorkItemCommitRequest validateCommitRequestSize( private void recordProcessingStats( List workBatch, - List outputBuilders, + List workItemCommits, long totalStateBytesRead) { long totalStateBytesWritten = 0; long totalShuffleBytesRead = 0; + Preconditions.checkState(workBatch.size() == workItemCommits.size()); for (int i = 0; i < workBatch.size(); i++) { Windmill.WorkItem workItem = workBatch.get(i).getWorkItem(); - Windmill.WorkItemCommitRequest.Builder outputBuilder = outputBuilders.get(i); + Windmill.WorkItemCommitRequest commit = workItemCommits.get(i); // Compute shuffle and state byte statistics these will be flushed asynchronously. long stateBytesWritten = - outputBuilder + commit + .toBuilder() .clearOutputMessages() .clearPerWorkItemLatencyAttributions() .build() @@ -369,36 +383,43 @@ private ExecuteWorkResult executeWork( try { WindmillStateReader stateReader = work.createWindmillStateReader(); - SideInputStateFetcher localSideInputStateFetcher = - sideInputStateFetcherFactory.createSideInputStateFetcher(work::fetchSideInput); - StreamingModeExecutionContext.KeySwitchListener keySwitchListener = - createKeySwitchListener(computationState); + KeyTransitionListener keyTransitionListener = createKeyTransitionListener(); // Blocks while executing work. 
computationWorkExecutor.executeWork( - work, stateReader, localSideInputStateFetcher, workExecutor, handle, keySwitchListener); - - StreamingModeExecutionContext context = computationWorkExecutor.context(); - if (context.workIsFailed()) { - throw new WorkItemCancelledException(work.getWorkItem().getShardingKey()); + work, + stateReader, + sideInputStateFetcherFactory, + workExecutor, + handle, + keyTransitionListener); + + List workBatch; + List workItemCommits; + Map> accumulatedCallbacks; + long stateBytesRead; + { + StreamingModeExecutionContext context = computationWorkExecutor.context(); + if (context.workIsFailed()) { + throw new WorkItemCancelledException(work.getWorkItem().getShardingKey()); + } + + // Retrieve executed works, work item commits, and accumulated callbacks from execution + // context + workBatch = context.getExecutedWorks(); + workItemCommits = context.getWorkItemCommits(); + accumulatedCallbacks = context.getAccumulatedCallbacks(); + stateBytesRead = context.getStateBytesRead(); + + context.clear(); // Don't use context after this. } - - // Retrieve executed works, output builders, and accumulated callbacks from execution context - List workBatch = context.getExecutedWorks(); - List outputBuilders = context.getOutputBuilders(); - Map> accumulatedCallbacks = context.getAccumulatedCallbacks(); - - context.clear(); // Release the execution state for another thread to use. computationState.releaseComputationWorkExecutor(computationWorkExecutor); computationWorkExecutor = null; return ExecuteWorkResult.create( - workBatch, - outputBuilders, - accumulatedCallbacks, - context.getStateBytesRead() + localSideInputStateFetcher.getBytesRead()); + workBatch, workItemCommits, accumulatedCallbacks, stateBytesRead); } catch (Throwable t) { if (computationWorkExecutor != null) { // If processing failed due to a thrown exception, close the executionState. 
Do not @@ -433,20 +454,18 @@ private StageInfo getStageInfo(ComputationState computationState) { private void commitWorkBatch( ComputationState computationState, List workBatch, - List outputBuilders) { + List workItemCommits) { Preconditions.checkState( workBatch.size() == 1, "Expected single-key work batch, got: " + workBatch.size()); - commitSingleKeyWork(computationState, workBatch.get(0), outputBuilders.get(0)); + commitSingleKeyWork(computationState, workBatch.get(0), workItemCommits.get(0)); } private void commitSingleKeyWork( - ComputationState computationState, - Work work, - Windmill.WorkItemCommitRequest.Builder commitRequestBuilder) { + ComputationState computationState, Work work, Windmill.WorkItemCommitRequest commitRequest) { // Validate the commit request, possibly requesting truncation if the commitSize is too large. Windmill.WorkItemCommitRequest validatedCommitRequest = validateCommitRequestSize( - commitRequestBuilder.build(), computationState.getComputationId(), work.getWorkItem()); + commitRequest, computationState.getComputationId(), work.getWorkItem()); work.setState(Work.State.COMMIT_QUEUED); validatedCommitRequest = validatedCommitRequest @@ -461,8 +480,6 @@ private void recordProcessingTime( @Nullable List worksToCleanup, Work work, long processingStartTimeNanos) { - // Update total processing time counters. Updating in finally clause ensures that - // work items causing exceptions are also accounted in time spent. 
long processingTimeMsecs = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - processingStartTimeNanos); stageInfo.totalProcessingMsecs().addValue(processingTimeMsecs); @@ -483,12 +500,10 @@ private static boolean anyWorkHasTimers(@Nullable List works, Work primary return primaryWork.getWorkItem().hasTimers(); } - private StreamingModeExecutionContext.KeySwitchListener createKeySwitchListener( - ComputationState computationState) { + private KeyTransitionListener createKeyTransitionListener() { return (oldWork, newWork) -> { - resetWorkLoggingContext(); - setUpWorkLoggingContext(newWork.getLatencyTrackingId(), computationState.getComputationId()); - newWork.setProcessingThreadName(Thread.currentThread().getName()); + setLoggingContextWorkId(newWork.getLatencyTrackingId()); + newWork.setProcessingThreadName(oldWork.getProcessingThreadName()); oldWork.setProcessingThreadName(""); }; } @@ -497,17 +512,18 @@ private StreamingModeExecutionContext.KeySwitchListener createKeySwitchListener( abstract static class ExecuteWorkResult { static ExecuteWorkResult create( List workBatch, - List outputBuilders, + List workItemCommits, Map> accumulatedCallbacks, long stateBytesRead) { return new AutoValue_StreamingWorkScheduler_ExecuteWorkResult( - workBatch, outputBuilders, accumulatedCallbacks, stateBytesRead); + workBatch, workItemCommits, accumulatedCallbacks, stateBytesRead); } abstract List workBatch(); - abstract List outputBuilders(); + abstract List workItemCommits(); + // Map> abstract Map> accumulatedCallbacks(); abstract long stateBytesRead(); diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java index 9d4ef999707c..6d84e9b4b0bf 100644 --- 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java @@ -48,6 +48,7 @@ import org.apache.beam.runners.core.metrics.ExecutionStateSampler; import org.apache.beam.runners.core.metrics.ExecutionStateTracker; import org.apache.beam.runners.core.metrics.ExecutionStateTracker.ExecutionState; +import org.apache.beam.runners.dataflow.options.DataflowStreamingPipelineOptions; import org.apache.beam.runners.dataflow.options.DataflowWorkerHarnessOptions; import org.apache.beam.runners.dataflow.worker.DataflowExecutionContext.DataflowExecutionStateTracker; import org.apache.beam.runners.dataflow.worker.MetricsToCounterUpdateConverter.Kind; @@ -61,7 +62,7 @@ import org.apache.beam.runners.dataflow.worker.streaming.Work; import org.apache.beam.runners.dataflow.worker.streaming.config.FakeGlobalConfigHandle; import org.apache.beam.runners.dataflow.worker.streaming.config.StreamingGlobalConfig; -import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputStateFetcher; +import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputStateFetcherFactory; import org.apache.beam.runners.dataflow.worker.util.common.worker.WorkExecutor; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.FakeGetDataClient; @@ -99,7 +100,6 @@ public class StreamingModeExecutionContextTest { @Rule public transient Timeout globalTimeout = Timeout.seconds(600); - @Mock private SideInputStateFetcher sideInputStateFetcher; @Mock private WindmillStateReader stateReader; @Mock private WorkExecutor workExecutor; @@ -203,15 +203,18 @@ private void start(StreamingModeExecutionContext context, Work work) { } private void start(StreamingModeExecutionContext context, Work work, Coder 
keyCoder) { + SideInputStateFetcherFactory sideInputStateFetcherFactory = + SideInputStateFetcherFactory.fromOptions( + options.as(DataflowStreamingPipelineOptions.class)); context.start( work, stateReader, - sideInputStateFetcher, + sideInputStateFetcherFactory, workExecutor, /* workQueueExecutor= */ null, /* budgetHandle= */ null, keyCoder, - /* keySwitchListener= */ (k, c) -> {}); + /* keyTransitionListener= */ (k, c) -> {}); } @Test diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java index 4175b47bfe4f..bd4e40d6570a 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java @@ -80,9 +80,11 @@ import org.apache.beam.runners.dataflow.DataflowRunner; import org.apache.beam.runners.dataflow.options.DataflowPipelineDebugOptions; import org.apache.beam.runners.dataflow.options.DataflowPipelineOptions; +import org.apache.beam.runners.dataflow.options.DataflowStreamingPipelineOptions; import org.apache.beam.runners.dataflow.util.CloudObject; import org.apache.beam.runners.dataflow.util.PropertyNames; import org.apache.beam.runners.dataflow.worker.DataflowExecutionContext.DataflowExecutionStateTracker; +import org.apache.beam.runners.dataflow.worker.StreamingModeExecutionContext.KeyTransitionListener; import org.apache.beam.runners.dataflow.worker.StreamingModeExecutionContext.StreamingModeExecutionStateRegistry; import org.apache.beam.runners.dataflow.worker.WorkerCustomSources.SplittableOnlyBoundedSource; import org.apache.beam.runners.dataflow.worker.counters.CounterSet; @@ -93,7 +95,7 @@ import 
org.apache.beam.runners.dataflow.worker.streaming.config.FixedGlobalConfigHandle; import org.apache.beam.runners.dataflow.worker.streaming.config.StreamingGlobalConfig; import org.apache.beam.runners.dataflow.worker.streaming.config.StreamingGlobalConfigHandle; -import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputStateFetcher; +import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputStateFetcherFactory; import org.apache.beam.runners.dataflow.worker.testing.TestCountingSource; import org.apache.beam.runners.dataflow.worker.util.common.worker.NativeReader; import org.apache.beam.runners.dataflow.worker.util.common.worker.NativeReader.NativeReaderIterator; @@ -210,15 +212,18 @@ COMPUTATION_ID, new FakeGetDataClient(), ignored -> {}, mock(HeartbeatSender.cla } private void startContext(StreamingModeExecutionContext context, Work work) { + SideInputStateFetcherFactory sideInputStateFetcherFactory = + SideInputStateFetcherFactory.fromOptions( + options.as(DataflowStreamingPipelineOptions.class)); context.start( work, mock(WindmillStateReader.class), - mock(SideInputStateFetcher.class), + sideInputStateFetcherFactory, mock(WorkExecutor.class), /* workQueueExecutor= */ null, /* budgetHandle= */ null, /* keyCoder= */ null, - /* keySwitchListener= */ mock(StreamingModeExecutionContext.KeySwitchListener.class)); + /* keyTransitionListener= */ mock(KeyTransitionListener.class)); } private static class SourceProducingSubSourcesInSplit extends MockSource { diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/WorkTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/WorkTest.java new file mode 100644 index 000000000000..80ca91da462f --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/WorkTest.java @@ -0,0 +1,132 @@ +/* + * Licensed to 
the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.streaming; + +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.mock; + +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill; +import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender; +import org.apache.beam.vendor.grpc.v1p69p0.com.google.protobuf.ByteString; +import org.joda.time.Instant; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class WorkTest { + + private static Work createTestWork() { + Windmill.WorkItem workItem = + Windmill.WorkItem.newBuilder() + .setKey(ByteString.copyFromUtf8("key")) + .setWorkToken(1L) + .setShardingKey(2L) + .build(); + return Work.create( + workItem, + workItem.getSerializedSize(), + Watermarks.builder().setInputDataWatermark(Instant.now()).build(), + 
Work.createProcessingContext( + "comp", + mock( + org.apache.beam.runners.dataflow.worker.windmill.client.getdata.GetDataClient + .class), + commit -> {}, + mock(HeartbeatSender.class)), + false, + Instant::now); + } + + @Test + public void testSetFailedBeforeListener() { + Work work = createTestWork(); + assertFalse(work.isFailed()); + + work.setFailed(); + assertTrue(work.isFailed()); + + AtomicBoolean listener = new AtomicBoolean(false); + work.setOnFailureListener(listener); + assertTrue(listener.get()); + } + + @Test + public void testSetFailedAfterListener() { + Work work = createTestWork(); + AtomicBoolean listener = new AtomicBoolean(false); + work.setOnFailureListener(listener); + assertFalse(listener.get()); + assertFalse(work.isFailed()); + + work.setFailed(); + assertTrue(work.isFailed()); + assertTrue(listener.get()); + } + + @Test + public void testConcurrentSetFailedAndSetOnFailureListener() throws Exception { + int numTrials = 5000; + ExecutorService executor = Executors.newFixedThreadPool(2); + try { + for (int i = 0; i < numTrials; i++) { + Work work = createTestWork(); + AtomicBoolean listener = new AtomicBoolean(false); + CountDownLatch latch = new CountDownLatch(1); + + Future f1 = + executor.submit( + () -> { + try { + latch.await(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + work.setFailed(); + }); + + Future f2 = + executor.submit( + () -> { + try { + latch.await(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + work.setOnFailureListener(listener); + }); + + latch.countDown(); + f1.get(5, TimeUnit.SECONDS); + f2.get(5, TimeUnit.SECONDS); + + assertTrue("Trial " + i + " failed: work should be failed", work.isFailed()); + assertTrue("Trial " + i + " failed: listener should be set to true", listener.get()); + } + } finally { + executor.shutdownNow(); + } + } +} From f3cc6284fb97df1bc3762647c5c2bd7452509efe Mon Sep 17 00:00:00 2001 From: Arun Pandian Date: Mon, 8 Jun 2026 
05:37:04 +0000 Subject: [PATCH 10/21] improve WindowingWindmillReader --- .../worker/WindmillReaderIteratorBase.java | 2 +- .../worker/WindowingWindmillReader.java | 83 +++--- .../worker/WindowingWindmillReaderTest.java | 275 ++++++++++++++++++ 3 files changed, 315 insertions(+), 45 deletions(-) create mode 100644 runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindowingWindmillReaderTest.java diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillReaderIteratorBase.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillReaderIteratorBase.java index b142cc38d365..20d0c40ae4a3 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillReaderIteratorBase.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillReaderIteratorBase.java @@ -73,7 +73,7 @@ public boolean advance() throws IOException { continue; } - // All work items are exhausted. Iterator returns false. + // All work items are exhausted. 
current = null; return false; } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindowingWindmillReader.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindowingWindmillReader.java index 916920518f0b..fc11ff8dca76 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindowingWindmillReader.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindowingWindmillReader.java @@ -30,7 +30,6 @@ import org.apache.beam.runners.dataflow.util.CloudObject; import org.apache.beam.runners.dataflow.worker.util.ValueInEmptyWindows; import org.apache.beam.runners.dataflow.worker.util.common.worker.NativeReader; -import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItem; import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.options.PipelineOptions; @@ -49,7 +48,6 @@ @Internal class WindowingWindmillReader extends NativeReader>> { - private final Coder keyCoder; private final Coder valueCoder; private final Coder windowCoder; private final Coder> windowsCoder; @@ -66,7 +64,6 @@ class WindowingWindmillReader extends NativeReader keyedWorkItemCoder = (WindmillKeyedWorkItem.FakeKeyedWorkItemCoder) inputCoder.getValueCoder(); - this.keyCoder = keyedWorkItemCoder.getKeyCoder(); this.valueCoder = keyedWorkItemCoder.getElementCoder(); this.context = context; this.skipUndecodableElements = skipUndecodableElements; @@ -129,27 +126,32 @@ public static WindowingWindmillReader create( return new WindowingWindmillReader<>(coder, context, skipUndecodableElements); } + private KeyedWorkItem createKeyedWorkItem() { + @SuppressWarnings("unchecked") + @Nullable K key = (K) context.getKey(); + return new WindmillKeyedWorkItem<>( + key, + context.getWorkItem(), + windowCoder, + windowsCoder, + 
valueCoder, + context.getWindmillTagEncoding(), + context.getDrainMode(), + skipUndecodableElements.isAccessible() + && Boolean.TRUE.equals(skipUndecodableElements.get())); + } + + private boolean isEmpty(KeyedWorkItem keyedWorkItem) { + return Iterables.isEmpty(keyedWorkItem.timersIterable()) + && Iterables.isEmpty(keyedWorkItem.elementsIterable()); + } + @Override public NativeReaderIterator>> iterator() throws IOException { - final K key = - keyCoder.decode( - checkStateNotNull(context.getSerializedKey()).newInput(), Coder.Context.OUTER); - final WorkItem workItem = context.getWorkItem(); - KeyedWorkItem keyedWorkItem = - new WindmillKeyedWorkItem<>( - key, - workItem, - windowCoder, - windowsCoder, - valueCoder, - context.getWindmillTagEncoding(), - context.getDrainMode(), - skipUndecodableElements.isAccessible() - && Boolean.TRUE.equals(skipUndecodableElements.get())); - final boolean isEmptyWorkItem = - (Iterables.isEmpty(keyedWorkItem.timersIterable()) - && Iterables.isEmpty(keyedWorkItem.elementsIterable())); - final WindowedValue> value = new ValueInEmptyWindows<>(keyedWorkItem); + final KeyedWorkItem firstKeyedWorkItem = createKeyedWorkItem(); + final boolean firstKeyIsEmpty = isEmpty(firstKeyedWorkItem); + final WindowedValue> firstValue = + new ValueInEmptyWindows<>(firstKeyedWorkItem); return new NativeReaderIterator>>() { private @Nullable WindowedValue> current = null; @@ -165,10 +167,10 @@ public boolean start() throws IOException { return false; } started = true; - if (isEmptyWorkItem) { + if (firstKeyIsEmpty) { return advance(); // Try to transition immediately if the first key is empty! 
} - current = value; + current = firstValue; return true; } @@ -179,27 +181,20 @@ public boolean advance() throws IOException { checkStateNotNull(context.getWorkItem()).getShardingKey()); } - context.finishKey(); - if (context.advance()) { - @SuppressWarnings("unchecked") - K newKey = (K) context.getKey(); - KeyedWorkItem newKeyedWorkItem = - new WindmillKeyedWorkItem<>( - newKey, - context.getWork().getWorkItem(), - windowCoder, - windowsCoder, - valueCoder, - context.getWindmillTagEncoding(), - context.getDrainMode(), - skipUndecodableElements.isAccessible() - && Boolean.TRUE.equals(skipUndecodableElements.get())); - current = new ValueInEmptyWindows<>(newKeyedWorkItem); - return true; + while (true) { + context.finishKey(); + if (context.advance()) { + KeyedWorkItem newKeyedWorkItem = createKeyedWorkItem(); + if (isEmpty(newKeyedWorkItem)) { + continue; + } + current = new ValueInEmptyWindows<>(newKeyedWorkItem); + return true; + } + + current = null; + return false; } - - current = null; - return false; } @Override diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindowingWindmillReaderTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindowingWindmillReaderTest.java new file mode 100644 index 000000000000..2e7c80330cf0 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindowingWindmillReaderTest.java @@ -0,0 +1,275 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.io.IOException; +import java.util.List; +import org.apache.beam.runners.core.KeyedWorkItem; +import org.apache.beam.runners.dataflow.worker.streaming.Watermarks; +import org.apache.beam.runners.dataflow.worker.streaming.Work; +import org.apache.beam.runners.dataflow.worker.util.common.worker.NativeReader; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill; +import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.FakeGetDataClient; +import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillTagEncodingV1; +import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.coders.ListCoder; +import org.apache.beam.sdk.coders.StringUtf8Coder; +import org.apache.beam.sdk.coders.VarLongCoder; +import org.apache.beam.sdk.options.ValueProvider; +import org.apache.beam.sdk.transforms.windowing.IntervalWindow; +import org.apache.beam.sdk.transforms.windowing.IntervalWindow.IntervalWindowCoder; +import org.apache.beam.sdk.transforms.windowing.PaneInfo; +import 
org.apache.beam.sdk.transforms.windowing.PaneInfo.PaneInfoCoder; +import org.apache.beam.sdk.util.ByteStringOutputStream; +import org.apache.beam.sdk.values.WindowedValue; +import org.apache.beam.sdk.values.WindowedValues.FullWindowedValueCoder; +import org.apache.beam.vendor.grpc.v1p69p0.com.google.protobuf.ByteString; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; +import org.joda.time.Instant; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class WindowingWindmillReaderTest { + private StreamingModeExecutionContext mockContext; + private WindowingWindmillReader reader; + + @SuppressWarnings("unchecked") + @Before + public void setUp() { + mockContext = mock(StreamingModeExecutionContext.class); + when(mockContext.workIsFailed()).thenReturn(false); + when(mockContext.getWindmillTagEncoding()).thenReturn(WindmillTagEncodingV1.instance()); + when(mockContext.getDrainMode()).thenReturn(false); + + Coder keyCoder = StringUtf8Coder.of(); + Coder valueCoder = VarLongCoder.of(); + KvCoder kvCoder = KvCoder.of(keyCoder, valueCoder); + WindmillKeyedWorkItem.FakeKeyedWorkItemCoder keyedWorkItemCoder = + (WindmillKeyedWorkItem.FakeKeyedWorkItemCoder) + WindmillKeyedWorkItem.FakeKeyedWorkItemCoder.of(kvCoder); + FullWindowedValueCoder> coder = + FullWindowedValueCoder.of(keyedWorkItemCoder, IntervalWindowCoder.of()); + + reader = + WindowingWindmillReader.create( + coder, mockContext, ValueProvider.StaticValueProvider.of(false)); + } + + private static Work createMockWork(Windmill.WorkItem workItem) { + return Work.create( + workItem, + workItem.getSerializedSize(), + Watermarks.builder().setInputDataWatermark(new Instant(1000)).build(), + Work.createProcessingContext( + "computationId", new FakeGetDataClient(), ignored -> {}, 
mock(HeartbeatSender.class)), + false, + Instant::now); + } + + private static ByteString encodeMetadata(List windows) throws IOException { + ByteStringOutputStream stream = new ByteStringOutputStream(); + PaneInfoCoder.INSTANCE.encode(PaneInfo.NO_FIRING, stream); + ListCoder.of(IntervalWindowCoder.of()).encode(windows, stream); + return stream.toByteString(); + } + + private static ByteString encodeValue(long value) throws IOException { + ByteStringOutputStream stream = new ByteStringOutputStream(); + VarLongCoder.of().encode(value, stream); + return stream.toByteString(); + } + + @Test + public void testSingleNonEmptyKey() throws IOException { + IntervalWindow window = new IntervalWindow(new Instant(0), new Instant(1000)); + Windmill.WorkItem workItem = + Windmill.WorkItem.newBuilder() + .setKey(ByteString.copyFromUtf8("key1")) + .setWorkToken(100L) + .addMessageBundles( + Windmill.InputMessageBundle.newBuilder() + .setSourceComputationId("foo") + .addMessages( + Windmill.Message.newBuilder() + .setTimestamp(1000) + .setData(encodeValue(42L)) + .setMetadata(encodeMetadata(ImmutableList.of(window))) + .build()) + .build()) + .build(); + Work work = createMockWork(workItem); + + when(mockContext.getKey()).thenReturn("key1"); + when(mockContext.getWorkItem()).thenReturn(workItem); + when(mockContext.getWork()).thenReturn(work); + when(mockContext.advance()).thenReturn(false); + + try (NativeReader.NativeReaderIterator>> iter = + reader.iterator()) { + assertTrue(iter.start()); + WindowedValue> current = iter.getCurrent(); + assertEquals("key1", current.getValue().key()); + assertFalse(Iterables.isEmpty(current.getValue().elementsIterable())); + WindowedValue elem = Iterables.getOnlyElement(current.getValue().elementsIterable()); + assertEquals(42L, elem.getValue().longValue()); + + assertFalse(iter.advance()); + verify(mockContext).finishKey(); + } + } + + @Test + public void testSingleEmptyKey() throws IOException { + Windmill.WorkItem workItem = + 
Windmill.WorkItem.newBuilder() + .setKey(ByteString.copyFromUtf8("key1")) + .setWorkToken(100L) + .build(); // No message bundles or timers + Work work = createMockWork(workItem); + + when(mockContext.getKey()).thenReturn("key1"); + when(mockContext.getWorkItem()).thenReturn(workItem); + when(mockContext.getWork()).thenReturn(work); + when(mockContext.advance()).thenReturn(false); + + try (NativeReader.NativeReaderIterator>> iter = + reader.iterator()) { + assertFalse( + iter.start()); // Should skip the empty key and return false because advance returns false + verify(mockContext).finishKey(); + } + } + + @Test + public void testMultipleKeys_withEmptyAndNonEmpty() throws IOException { + IntervalWindow window = new IntervalWindow(new Instant(0), new Instant(1000)); + // Key 1: Empty + Windmill.WorkItem workItem1 = + Windmill.WorkItem.newBuilder() + .setKey(ByteString.copyFromUtf8("key1")) + .setWorkToken(100L) + .build(); + Work work1 = createMockWork(workItem1); + + // Key 2: Non-empty + Windmill.WorkItem workItem2 = + Windmill.WorkItem.newBuilder() + .setKey(ByteString.copyFromUtf8("key2")) + .setWorkToken(200L) + .addMessageBundles( + Windmill.InputMessageBundle.newBuilder() + .setSourceComputationId("foo") + .addMessages( + Windmill.Message.newBuilder() + .setTimestamp(2000) + .setData(encodeValue(84L)) + .setMetadata(encodeMetadata(ImmutableList.of(window))) + .build()) + .build()) + .build(); + Work work2 = createMockWork(workItem2); + + // Key 3: Empty + Windmill.WorkItem workItem3 = + Windmill.WorkItem.newBuilder() + .setKey(ByteString.copyFromUtf8("key3")) + .setWorkToken(300L) + .build(); + Work work3 = createMockWork(workItem3); + + // Initial state + when(mockContext.getKey()).thenReturn("key1"); + when(mockContext.getWorkItem()).thenReturn(workItem1); + when(mockContext.getWork()).thenReturn(work1); + + // Mock transition behaviour of context.advance() + when(mockContext.advance()) + .thenAnswer( + new org.mockito.stubbing.Answer() { + private int 
count = 0; + + @Override + public Boolean answer(org.mockito.invocation.InvocationOnMock invocation) { + if (count == 0) { + count++; + when(mockContext.getKey()).thenReturn("key2"); + when(mockContext.getWorkItem()).thenReturn(workItem2); + when(mockContext.getWork()).thenReturn(work2); + return true; + } else if (count == 1) { + count++; + when(mockContext.getKey()).thenReturn("key3"); + when(mockContext.getWorkItem()).thenReturn(workItem3); + when(mockContext.getWork()).thenReturn(work3); + return true; + } + return false; + } + }); + + try (NativeReader.NativeReaderIterator>> iter = + reader.iterator()) { + // Key 1 is empty, so start() calls advance() which calls finishKey(1) and advance() to Key 2. + // Key 2 is non-empty, so start() returns true yielding Key 2. + assertTrue(iter.start()); + assertEquals("key2", iter.getCurrent().getValue().key()); + WindowedValue elem = + Iterables.getOnlyElement(iter.getCurrent().getValue().elementsIterable()); + assertEquals(84L, elem.getValue().longValue()); + + // Next advance() calls finishKey(2), calls advance() to Key 3. + // Key 3 is empty, so it loops, calls finishKey(3), calls advance() which returns false. + // So iter.advance() should return false. 
+ assertFalse(iter.advance()); + + verify(mockContext, times(3)) + .finishKey(); // finishKey should have been called on key1, key2, key3 + } + } + + @Test + public void testWorkItemCancelled() throws IOException { + when(mockContext.workIsFailed()).thenReturn(true); + Windmill.WorkItem workItem = + Windmill.WorkItem.newBuilder().setKey(ByteString.EMPTY).setWorkToken(0L).build(); + when(mockContext.getWorkItem()).thenReturn(workItem); + + try (NativeReader.NativeReaderIterator>> iter = + reader.iterator()) { + iter.start(); + fail("Expected WorkItemCancelledException"); + } catch (WorkItemCancelledException e) { + // Expected + } + } +} From 58e0ef9393d45860b61df4c10257351c425bf15f Mon Sep 17 00:00:00 2001 From: Arun Pandian Date: Mon, 8 Jun 2026 05:49:24 +0000 Subject: [PATCH 11/21] spotless fix --- .../beam/runners/dataflow/worker/WindowingWindmillReader.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindowingWindmillReader.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindowingWindmillReader.java index fc11ff8dca76..2003ec001a55 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindowingWindmillReader.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindowingWindmillReader.java @@ -128,7 +128,8 @@ public static WindowingWindmillReader create( private KeyedWorkItem createKeyedWorkItem() { @SuppressWarnings("unchecked") - @Nullable K key = (K) context.getKey(); + @Nullable + K key = (K) context.getKey(); return new WindmillKeyedWorkItem<>( key, context.getWorkItem(), From 3dceab0470d5228bb5346e9d8ac92f528a714ff0 Mon Sep 17 00:00:00 2001 From: Arun Pandian Date: Mon, 8 Jun 2026 07:57:33 +0000 Subject: [PATCH 12/21] [Dataflow Streaming] Fix nullness suppression in
StreamingModeExecutionContext --- .../worker/StreamingModeExecutionContext.java | 262 +++++++++--------- 1 file changed, 136 insertions(+), 126 deletions(-) diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java index 25ce299adf7a..89ccb576051f 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java @@ -18,7 +18,6 @@ package org.apache.beam.runners.dataflow.worker; import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull; -import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState; import com.google.api.services.dataflow.model.CounterUpdate; @@ -62,6 +61,7 @@ import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalDataId; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalDataRequest; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.Timer; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItemCommitRequest; import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateCache; import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateCache.ForComputation; import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateInternals; @@ -105,10 +105,7 @@ * state pertaining to a processing its owning computation. Can be reused across processing * different WorkItems for the same computation. 
*/ -@SuppressWarnings({ - "deprecation", - "nullness" // TODO(https://github.com/apache/beam/issues/20497) -}) +@SuppressWarnings({"deprecation"}) // TODO(m-trieu) fix nullability issues in StreamingModeExecutionContext.java @NotThreadSafe @Internal @@ -143,13 +140,13 @@ public class StreamingModeExecutionContext extends DataflowExecutionContext SideInput fetchSideInput( return fetchSideInputFromWindmill( view, sideInputWindow, - checkNotNull(stateFamily), + checkStateNotNull(stateFamily), state, - checkNotNull(scopedReadStateSupplier), + checkStateNotNull(scopedReadStateSupplier), tagCache); } @@ -383,8 +386,8 @@ private SideInput fetchSideInputFromWindmill( Supplier scopedReadStateSupplier, Map> tagCache) { SideInput fetched = - sideInputStateFetcher.fetchSideInput( - view, sideInputWindow, stateFamily, state, scopedReadStateSupplier); + checkStateNotNull(sideInputStateFetcher) + .fetchSideInput(view, sideInputWindow, stateFamily, state, scopedReadStateSupplier); if (fetched.isReady()) { tagCache.put(sideInputWindow, fetched); @@ -406,7 +409,7 @@ private List getFiredTimers() { } public WindmillComputationKey getComputationKey() { - return computationKey; + return checkStateNotNull(computationKey); } public long getWorkToken() { @@ -414,7 +417,7 @@ public long getWorkToken() { } public Windmill.WorkItem getWorkItem() { - return checkNotNull( + return checkStateNotNull( work, "work is null. A call to StreamingModeExecutionContext.start(...) 
is required to set" + " work for execution.") @@ -422,7 +425,7 @@ public Windmill.WorkItem getWorkItem() { } public Windmill.WorkItemCommitRequest.Builder getOutputBuilder() { - return outputBuilder; + return checkStateNotNull(outputBuilder); } /** @@ -490,15 +493,16 @@ public Map> flushState() { throw new RuntimeException("Exception while running bundle finalizer", e); } })); - outputBuilder.addFinalizeIds(id); + getOutputBuilder().addFinalizeIds(id); } } - if (activeReader != null) { - Windmill.SourceState.Builder sourceStateBuilder = - outputBuilder.getSourceStateUpdatesBuilder(); - final UnboundedSource.CheckpointMark checkpointMark = activeReader.getCheckpointMark(); - final Instant watermark = activeReader.getWatermark(); + UnboundedReader reader = activeReader; + if (reader != null) { + Windmill.WorkItemCommitRequest.Builder builder = getOutputBuilder(); + Windmill.SourceState.Builder sourceStateBuilder = builder.getSourceStateUpdatesBuilder(); + final UnboundedSource.CheckpointMark checkpointMark = reader.getCheckpointMark(); + final Instant watermark = reader.getWatermark(); long id = ThreadLocalRandom.current().nextLong(); sourceStateBuilder.addFinalizeIds(id); callbacks.put( @@ -515,7 +519,7 @@ public Map> flushState() { @SuppressWarnings("unchecked") Coder checkpointCoder = - ((UnboundedSource) activeReader.getCurrentSource()) + ((UnboundedSource) reader.getCurrentSource()) .getCheckpointMarkCoder(); if (checkpointCoder != null) { ByteStringOutputStream stream = new ByteStringOutputStream(); @@ -525,7 +529,7 @@ public Map> flushState() { throw new RuntimeException("Exception while encoding checkpoint", e); } sourceStateBuilder.setState(stream.toByteString()); - if (activeReader.getCurrentSource().offsetBasedDeduplicationSupported()) { + if (reader.getCurrentSource().offsetBasedDeduplicationSupported()) { byte[] offsetLimit = checkpointMark.getOffsetLimit(); if (offsetLimit.length == 0) { throw new RuntimeException("Checkpoint offset limit must be 
non-empty."); @@ -533,31 +537,30 @@ public Map> flushState() { sourceStateBuilder.setOffsetLimit(ByteString.copyFrom(offsetLimit)); } } - outputBuilder.setSourceWatermark(WindmillTimeUtils.harnessToWindmillTimestamp(watermark)); + builder.setSourceWatermark(WindmillTimeUtils.harnessToWindmillTimestamp(watermark)); - backlogBytes = activeReader.getSplitBacklogBytes(); + backlogBytes = reader.getSplitBacklogBytes(); + ByteString serializedKey = checkStateNotNull(getSerializedKey()); if (backlogBytes == UnboundedReader.BACKLOG_UNKNOWN - && WorkerCustomSources.isFirstUnboundedSourceSplit(getSerializedKey())) { + && WorkerCustomSources.isFirstUnboundedSourceSplit(serializedKey)) { // Only call getTotalBacklogBytes() on the first split. - backlogBytes = activeReader.getTotalBacklogBytes(); + backlogBytes = reader.getTotalBacklogBytes(); } - outputBuilder.setSourceBacklogBytes(backlogBytes); + builder.setSourceBacklogBytes(backlogBytes); readerCache.cacheReader( - getComputationKey(), - getWorkItem().getCacheToken(), - getWorkItem().getWorkToken(), - activeReader); + getComputationKey(), getWorkItem().getCacheToken(), getWorkItem().getWorkToken(), reader); activeReader = null; } else if (backlogBytes != UnboundedReader.BACKLOG_UNKNOWN && backlogBytes != 1L) { // If activeReader is null, we might still have backlogBytes from an SDF. We ignore a reported // backlogBytes of 1 since older versions of the Java SDK use this value as a default when // RestrictionTracker.getProgress() or GetSize() are not defined. - outputBuilder.setSourceBacklogBytes(backlogBytes); + getOutputBuilder().setSourceBacklogBytes(backlogBytes); } return callbacks; } + @Nullable String getStateFamily(NameContext nameContext) { return nameContext.userName() == null ? 
null : stateNameMap.get(nameContext.userName()); } @@ -599,7 +602,7 @@ public static class StreamingModeExecutionState extends DataflowExecutionState { public StreamingModeExecutionState( NameContext nameContext, String stateName, - MetricsContainer metricsContainer, + @Nullable MetricsContainer metricsContainer, ProfileScope profileScope) { // TODO: Take in the requesting step name and side input index for streaming. super(nameContext, stateName, null, null, metricsContainer, profileScope); @@ -642,14 +645,16 @@ public static class StreamingModeExecutionStateRegistry extends DataflowExecutio protected DataflowExecutionState createState( NameContext nameContext, String stateName, - String requestingStepName, - Integer inputIndex, - MetricsContainer container, + @Nullable String requestingStepName, + @Nullable Integer inputIndex, + @Nullable MetricsContainer container, ProfileScope profileScope) { return new StreamingModeExecutionState(nameContext, stateName, container, profileScope); } } + private static final Closeable NO_OP_CLOSEABLE = () -> {}; + private static class ScopedReadStateSupplier implements Supplier { private final ExecutionState readState; @@ -662,9 +667,9 @@ private ScopedReadStateSupplier( } @Override - public @Nullable Closeable get() { + public Closeable get() { if (stateTracker == null) { - return null; + return NO_OP_CLOSEABLE; } return stateTracker.enterState(readState); } @@ -725,7 +730,7 @@ public TimerInternals timerInternals() { } @Override - public TimerData getNextFiredTimer(Coder windowCoder) { + public @Nullable TimerData getNextFiredTimer(Coder windowCoder) { return wrapped.getNextFiredUserTimer(windowCoder); } @@ -777,7 +782,7 @@ public static StreamingModeSideInputReader of( } @Override - public T get(PCollectionView view, BoundedWindow window) { + public @Nullable T get(PCollectionView view, BoundedWindow window) { if (!contains(view)) { throw new RuntimeException("get() called with unknown view"); } @@ -810,31 +815,32 @@ public 
boolean isEmpty() { class StepContext extends DataflowExecutionContext.DataflowStepContext implements StreamingModeStepContext { - private final String stateFamily; + private final @Nullable String stateFamily; private final Supplier scopedReadStateSupplier; - private WindmillStateInternals stateInternals; - private WindmillTimerInternals systemTimerInternals; - private WindmillTimerInternals userTimerInternals; + private @Nullable WindmillStateInternals stateInternals; + private @Nullable WindmillTimerInternals systemTimerInternals; + private @Nullable WindmillTimerInternals userTimerInternals; // Lazily initialized - private Iterator cachedFiredSystemTimers = null; + private @Nullable Iterator cachedFiredSystemTimers = null; // Lazily initialized - private PeekingIterator cachedFiredUserTimers = null; + private @Nullable PeekingIterator cachedFiredUserTimers = null; // An ordered list of any timers that were set or modified by user processing earlier in this // bundle. // We use a NavigableSet instead of a priority queue to prevent duplicate elements from ending // up in the queue. - private NavigableSet modifiedUserEventTimersOrdered = null; - private NavigableSet modifiedUserProcessingTimersOrdered = null; - private NavigableSet modifiedUserSynchronizedProcessingTimersOrdered = null; + private final NavigableSet modifiedUserEventTimersOrdered = Sets.newTreeSet(); + private final NavigableSet modifiedUserProcessingTimersOrdered = Sets.newTreeSet(); + private final NavigableSet modifiedUserSynchronizedProcessingTimersOrdered = + Sets.newTreeSet(); // A list of timer keys that were modified by user processing earlier in this bundle. This // serves a tombstone, so that we know not to fire any bundle timers that were modified. 
- private Table modifiedUserTimerKeys = null; + private final Table modifiedUserTimerKeys = + HashBasedTable.create(); private final WindmillBundleFinalizer bundleFinalizer = new WindmillBundleFinalizer(); public StepContext(DataflowOperationContext operationContext) { super(operationContext.nameContext()); this.stateFamily = getStateFamily(operationContext.nameContext()); - this.scopedReadStateSupplier = new ScopedReadStateSupplier(operationContext, getExecutionStateTracker()); } @@ -845,46 +851,50 @@ public void start( Instant processingTime, WindmillStateCache.ForKey cacheForKey, Watermarks watermarks) { - this.stateInternals = - new WindmillStateInternals<>( - key, - stateFamily, - stateReader, - getWorkItem().getIsNewKey(), - cacheForKey.forFamily(stateFamily), - windmillTagEncoding, - scopedReadStateSupplier); - - this.systemTimerInternals = - new WindmillTimerInternals( - stateFamily, - WindmillTimerType.SYSTEM_TIMER, - processingTime, - watermarks, - windmillTagEncoding, - td -> {}); - - this.userTimerInternals = - new WindmillTimerInternals( - stateFamily, - WindmillTimerType.USER_TIMER, - processingTime, - watermarks, - windmillTagEncoding, - this::onUserTimerModified); - + if (stateFamily != null) { + this.stateInternals = + new WindmillStateInternals<>( + key, + stateFamily, + stateReader, + getWorkItem().getIsNewKey(), + cacheForKey.forFamily(stateFamily), + windmillTagEncoding, + scopedReadStateSupplier); + + this.systemTimerInternals = + new WindmillTimerInternals( + stateFamily, + WindmillTimerType.SYSTEM_TIMER, + processingTime, + watermarks, + windmillTagEncoding, + td -> {}); + + this.userTimerInternals = + new WindmillTimerInternals( + stateFamily, + WindmillTimerType.USER_TIMER, + processingTime, + watermarks, + windmillTagEncoding, + this::onUserTimerModified); + } this.cachedFiredSystemTimers = null; this.cachedFiredUserTimers = null; - modifiedUserEventTimersOrdered = Sets.newTreeSet(); - modifiedUserProcessingTimersOrdered = 
Sets.newTreeSet(); - modifiedUserSynchronizedProcessingTimersOrdered = Sets.newTreeSet(); - modifiedUserTimerKeys = HashBasedTable.create(); + this.modifiedUserEventTimersOrdered.clear(); + this.modifiedUserProcessingTimersOrdered.clear(); + this.modifiedUserSynchronizedProcessingTimersOrdered.clear(); + this.modifiedUserTimerKeys.clear(); } public void flushState() { - stateInternals.persist(outputBuilder); - systemTimerInternals.persistTo(outputBuilder); - userTimerInternals.persistTo(outputBuilder); + if (stateFamily != null) { + WorkItemCommitRequest.Builder builder = getOutputBuilder(); + checkStateNotNull(stateInternals).persist(builder); + checkStateNotNull(systemTimerInternals).persistTo(builder); + checkStateNotNull(userTimerInternals).persistTo(builder); + } } @Override @@ -893,9 +903,10 @@ public void setBacklogBytes(double backlogBytes) { } @Override - public TimerData getNextFiredTimer(Coder windowCoder) { - if (cachedFiredSystemTimers == null) { - cachedFiredSystemTimers = + public @Nullable TimerData getNextFiredTimer(Coder windowCoder) { + Iterator firedSystemTimers = cachedFiredSystemTimers; + if (firedSystemTimers == null) { + firedSystemTimers = FluentIterable.from(StreamingModeExecutionContext.this.getFiredTimers()) .filter(timer -> timer.getStateFamily().equals(stateFamily)) .transform( @@ -907,16 +918,17 @@ timer, windowCoder, getDrainMode())) windmillTimerData.getWindmillTimerType() == WindmillTimerType.SYSTEM_TIMER) .transform(WindmillTimerData::getTimerData) .iterator(); + cachedFiredSystemTimers = firedSystemTimers; } - if (!cachedFiredSystemTimers.hasNext()) { + if (!firedSystemTimers.hasNext()) { return null; } - TimerData nextTimer = cachedFiredSystemTimers.next(); + TimerData nextTimer = firedSystemTimers.next(); // system timers ( GC timer) must be explicitly deleted if only there is a hold. 
// if timestamp is not equals to outputTimestamp then there should be a hold if (!nextTimer.getTimestamp().equals(nextTimer.getOutputTimestamp())) { - systemTimerInternals.deleteTimer(nextTimer); + checkStateNotNull(systemTimerInternals).deleteTimer(nextTimer); } return nextTimer; } @@ -950,12 +962,14 @@ private boolean isTimerUnmodified(TimerData timerData) { return updatedTimer == null || updatedTimer.equals(timerData); } - public TimerData getNextFiredUserTimer(Coder windowCoder) { - if (cachedFiredUserTimers == null) { + public @Nullable TimerData getNextFiredUserTimer( + Coder windowCoder) { + PeekingIterator firedUserTimers = cachedFiredUserTimers; + if (firedUserTimers == null) { // This is the first call to getNextFiredUserTimer in this bundle. Extract any user timers // from the bundle // and cache the list for the rest of this bundle processing. - cachedFiredUserTimers = + firedUserTimers = Iterators.peekingIterator( FluentIterable.from(StreamingModeExecutionContext.this.getFiredTimers()) .filter(timer -> timer.getStateFamily().equals(stateFamily)) @@ -969,17 +983,20 @@ timer, windowCoder, getDrainMode())) == WindmillTimerType.USER_TIMER) .transform(WindmillTimerData::getTimerData) .iterator()); + cachedFiredUserTimers = firedUserTimers; } - while (cachedFiredUserTimers.hasNext()) { - TimerData nextInBundle = cachedFiredUserTimers.peek(); + WindmillTimerInternals nonNullUserTimerInternals = checkStateNotNull(this.userTimerInternals); + + while (firedUserTimers.hasNext()) { + TimerData nextInBundle = firedUserTimers.peek(); NavigableSet modifiedUserTimersOrdered = getModifiedUserTimersOrdered(nextInBundle.getDomain()); // If there is a modified timer that is earlier than the next timer in the bundle, try and // fire that first. 
while (!modifiedUserTimersOrdered.isEmpty() && modifiedUserTimersOrdered.first().compareTo(nextInBundle) <= 0) { - TimerData earlierTimer = modifiedUserTimersOrdered.pollFirst(); + TimerData earlierTimer = checkStateNotNull(modifiedUserTimersOrdered.pollFirst()); if (isTimerUnmodified(earlierTimer)) { // We must delete the timer. This prevents it from being committed to the backing store. // It also handles the @@ -987,15 +1004,15 @@ timer, windowCoder, getDrainMode())) // without deleting the // timer, the runner will still have that future timer stored, and would fire it // spuriously. - userTimerInternals.deleteTimer(earlierTimer); + nonNullUserTimerInternals.deleteTimer(earlierTimer); return earlierTimer; } } // There is no earlier timer to fire, so return the next timer in the bundle. - nextInBundle = cachedFiredUserTimers.next(); + nextInBundle = firedUserTimers.next(); if (isTimerUnmodified(nextInBundle)) { // User timers must be explicitly deleted when delivered, to release the implied hold. 
- userTimerInternals.deleteTimer(nextInBundle); + nonNullUserTimerInternals.deleteTimer(nextInBundle); return nextInBundle; } } @@ -1029,12 +1046,6 @@ public Iterable getSideInputNotifications() { return StreamingModeExecutionContext.this.getSideInputNotifications(); } - private void ensureStateful(String errorPrefix) { - if (stateFamily == null) { - throw new IllegalStateException(errorPrefix + " for stateless step: " + getNameContext()); - } - } - @Override public void writePCollectionViewData( TupleTag tag, @@ -1043,7 +1054,8 @@ public void writePCollectionViewData( W window, Coder windowCoder) throws IOException { - if (getSerializedKey().size() != 0) { + ByteString serializedKey = checkStateNotNull(getSerializedKey()); + if (serializedKey.size() != 0) { throw new IllegalStateException("writePCollectionViewData must follow a Combine.globally"); } @@ -1053,7 +1065,7 @@ public void writePCollectionViewData( ByteStringOutputStream windowStream = new ByteStringOutputStream(); windowCoder.encode(window, windowStream, Coder.Context.OUTER); - ensureStateful("Tried to write view data"); + String stateFamily = checkStateNotNull(this.stateFamily, "Tried to write view data"); Windmill.GlobalData.Builder builder = Windmill.GlobalData.newBuilder() @@ -1065,7 +1077,7 @@ public void writePCollectionViewData( .setData(dataStream.toByteString()) .setStateFamily(stateFamily); - outputBuilder.addGlobalDataUpdates(builder.build()); + getOutputBuilder().addGlobalDataUpdates(builder.build()); } /** Fetch the given side input asynchronously and return true if it is present. */ @@ -1080,11 +1092,12 @@ public boolean issueSideInputFetch( /** Note that there is data on the current key that is blocked on the given side input. 
*/ @Override public void addBlockingSideInput(Windmill.GlobalDataRequest sideInput) { - ensureStateful("Tried to set global data request"); + String stateFamily = checkStateNotNull(this.stateFamily, "Tried to set global data request"); sideInput = Windmill.GlobalDataRequest.newBuilder(sideInput).setStateFamily(stateFamily).build(); - outputBuilder.addGlobalDataRequests(sideInput); - outputBuilder.addGlobalDataIdRequests(sideInput.getDataId()); + WorkItemCommitRequest.Builder builder = getOutputBuilder(); + builder.addGlobalDataRequests(sideInput); + builder.addGlobalDataIdRequests(sideInput.getDataId()); } /** Note that there is data on the current key that is blocked on the given side inputs. */ @@ -1097,14 +1110,12 @@ public void addBlockingSideInputs(Iterable sideInput @Override public StateInternals stateInternals() { - ensureStateful("Tried to access state"); - return checkNotNull(stateInternals); + return checkStateNotNull(stateInternals, "Tried to access state"); } @Override public TimerInternals timerInternals() { - ensureStateful("Tried to access timers"); - return checkNotNull(systemTimerInternals); + return checkStateNotNull(systemTimerInternals, "Tried to access timers"); } @Override @@ -1113,8 +1124,7 @@ public BundleFinalizer bundleFinalizer() { } public TimerInternals userTimerInternals() { - ensureStateful("Tried to access user timers"); - return checkNotNull(userTimerInternals); + return checkStateNotNull(userTimerInternals, "Tried to access user timers"); } public ImmutableList> flushBundleFinalizerCallbacks() { From e19943808e3d024cbcb95faa8fc9ee15c88c579b Mon Sep 17 00:00:00 2001 From: Arun Pandian Date: Mon, 8 Jun 2026 08:19:41 +0000 Subject: [PATCH 13/21] make windmillTagEncoding final --- .../worker/StreamingModeExecutionContext.java | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java index 89ccb576051f..9008bf23f3af 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java @@ -127,7 +127,7 @@ public class StreamingModeExecutionContext extends DataflowExecutionContext, Map>> sideInputCache; - private WindmillTagEncoding windmillTagEncoding; + private final WindmillTagEncoding windmillTagEncoding; /** * The current user-facing key for this execution context. * @@ -187,13 +187,10 @@ public StreamingModeExecutionContext( this.throwExceptionOnLargeOutput = throwExceptionOnLargeOutput; StreamingGlobalConfig config = globalConfigHandle.getConfig(); this.operationalLimits = config.operationalLimits(); - this.windmillTagEncoding = getWindmillTagEncoding(config); - } - - private static WindmillTagEncoding getWindmillTagEncoding(StreamingGlobalConfig config) { - return config.enableStateTagEncodingV2() - ? WindmillTagEncodingV2.instance() - : WindmillTagEncodingV1.instance(); + this.windmillTagEncoding = + config.enableStateTagEncodingV2() + ? WindmillTagEncodingV2.instance() + : WindmillTagEncodingV1.instance(); } @VisibleForTesting @@ -262,7 +259,6 @@ public void start( StreamingGlobalConfig config = globalConfigHandle.getConfig(); // Snapshot the limits for entire bundle processing. 
this.operationalLimits = config.operationalLimits(); - this.windmillTagEncoding = getWindmillTagEncoding(config); this.outputBuilder = outputBuilder; this.sideInputCache.clear(); this.backlogBytes = UnboundedReader.BACKLOG_UNKNOWN; From 700dfbc8af35b89b70c921bbbd5a1676f6a2e513 Mon Sep 17 00:00:00 2001 From: Arun Pandian Date: Mon, 8 Jun 2026 09:00:25 +0000 Subject: [PATCH 14/21] address comments --- .../worker/StreamingModeExecutionContext.java | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java index 9008bf23f3af..00fdf67b8d02 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java @@ -94,6 +94,7 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.PeekingIterator; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Sets; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Table; +import org.checkerframework.checker.nullness.qual.MonotonicNonNull; import org.checkerframework.checker.nullness.qual.Nullable; import org.joda.time.Duration; import org.joda.time.Instant; @@ -106,7 +107,6 @@ * different WorkItems for the same computation. 
*/ @SuppressWarnings({"deprecation"}) -// TODO(m-trieu) fix nullability issues in StreamingModeExecutionContext.java @NotThreadSafe @Internal public class StreamingModeExecutionContext extends DataflowExecutionContext { @@ -813,9 +813,9 @@ class StepContext extends DataflowExecutionContext.DataflowStepContext private final @Nullable String stateFamily; private final Supplier scopedReadStateSupplier; - private @Nullable WindmillStateInternals stateInternals; - private @Nullable WindmillTimerInternals systemTimerInternals; - private @Nullable WindmillTimerInternals userTimerInternals; + private @MonotonicNonNull WindmillStateInternals stateInternals; + private @MonotonicNonNull WindmillTimerInternals systemTimerInternals; + private @MonotonicNonNull WindmillTimerInternals userTimerInternals; // Lazily initialized private @Nullable Iterator cachedFiredSystemTimers = null; // Lazily initialized @@ -900,6 +900,10 @@ public void setBacklogBytes(double backlogBytes) { @Override public @Nullable TimerData getNextFiredTimer(Coder windowCoder) { + if (stateFamily == null) { + // no timers on stateless stages + return null; + } Iterator firedSystemTimers = cachedFiredSystemTimers; if (firedSystemTimers == null) { firedSystemTimers = @@ -960,6 +964,11 @@ private boolean isTimerUnmodified(TimerData timerData) { public @Nullable TimerData getNextFiredUserTimer( Coder windowCoder) { + if (stateFamily == null) { + // no timers on stateless stages + return null; + } + PeekingIterator firedUserTimers = cachedFiredUserTimers; if (firedUserTimers == null) { // This is the first call to getNextFiredUserTimer in this bundle. 
Extract any user timers From bc5bee2db78b0672cb58f37af9cefdf290739ce0 Mon Sep 17 00:00:00 2001 From: Arun Pandian Date: Mon, 8 Jun 2026 09:09:15 +0000 Subject: [PATCH 15/21] Move SideInputStateFetcherFactory from start to constructor --- .../worker/StreamingModeExecutionContext.java | 9 ++++----- .../streaming/ComputationWorkExecutor.java | 3 --- .../ComputationWorkExecutorFactory.java | 9 +++++++-- .../work/processing/StreamingWorkScheduler.java | 17 ++++++----------- .../StreamingModeExecutionContextTest.java | 11 ++++------- .../worker/WorkerCustomSourcesTest.java | 11 ++++------- 6 files changed, 25 insertions(+), 35 deletions(-) diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java index af9f29c7b9ba..fce50fc6ac54 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java @@ -151,7 +151,7 @@ public class StreamingModeExecutionContext private @Nullable Work work; private WindmillComputationKey computationKey; - private SideInputStateFetcherFactory sideInputStateFetcherFactory; + private final SideInputStateFetcherFactory sideInputStateFetcherFactory; private SideInputStateFetcher sideInputStateFetcher; // OperationalLimits is updated in start() because a StreamingModeExecutionContext can // be used for processing many work items and these values can change during the context's @@ -214,7 +214,8 @@ public StreamingModeExecutionContext( HotKeyLogger hotKeyLogger, boolean hotKeyLoggingEnabled, String stepName, - String sourceBytesProcessCounterName) { + String sourceBytesProcessCounterName, + SideInputStateFetcherFactory 
sideInputStateFetcherFactory) { super( counterFactory, metricsContainerRegistry, @@ -233,6 +234,7 @@ public StreamingModeExecutionContext( this.hotKeyLoggingEnabled = hotKeyLoggingEnabled; this.stepName = checkNotNull(stepName); this.sourceBytesProcessCounterName = checkNotNull(sourceBytesProcessCounterName); + this.sideInputStateFetcherFactory = sideInputStateFetcherFactory; } @VisibleForTesting @@ -304,7 +306,6 @@ public void clear() { this.work = null; this.key = null; this.outputBuilder = null; - this.sideInputStateFetcherFactory = null; this.sideInputStateFetcher = null; this.backlogBytes = UnboundedReader.BACKLOG_UNKNOWN; clearSinkFullHint(); @@ -314,7 +315,6 @@ public void clear() { public void start( Work work, WindmillStateReader stateReader, - SideInputStateFetcherFactory sideInputStateFetcherFactory, WorkExecutor workExecutor, BoundedQueueExecutor workQueueExecutor, BoundedQueueExecutorWorkHandle budgetHandle, @@ -334,7 +334,6 @@ public void start( config.enableStateTagEncodingV2() ? 
WindmillTagEncodingV2.instance() : WindmillTagEncodingV1.instance(); - this.sideInputStateFetcherFactory = sideInputStateFetcherFactory; startForNewKey(work, stateReader); } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationWorkExecutor.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationWorkExecutor.java index 31420b212c31..56a1a06362d2 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationWorkExecutor.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationWorkExecutor.java @@ -24,7 +24,6 @@ import org.apache.beam.runners.dataflow.worker.DataflowWorkExecutor; import org.apache.beam.runners.dataflow.worker.StreamingModeExecutionContext; import org.apache.beam.runners.dataflow.worker.StreamingModeExecutionContext.KeyTransitionListener; -import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputStateFetcherFactory; import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor; import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateReader; import org.apache.beam.sdk.annotations.Internal; @@ -66,7 +65,6 @@ public static ComputationWorkExecutor.Builder builder() { public final void executeWork( Work work, WindmillStateReader stateReader, - SideInputStateFetcherFactory sideInputStateFetcherFactory, BoundedQueueExecutor workQueueExecutor, BoundedQueueExecutorWorkHandle budgetHandle, KeyTransitionListener keyTransitionListener) @@ -75,7 +73,6 @@ public final void executeWork( .start( work, stateReader, - sideInputStateFetcherFactory, workExecutor(), workQueueExecutor, budgetHandle, diff --git 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/ComputationWorkExecutorFactory.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/ComputationWorkExecutorFactory.java index fcc6d6bbb743..4a52d9fde771 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/ComputationWorkExecutorFactory.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/ComputationWorkExecutorFactory.java @@ -49,6 +49,7 @@ import org.apache.beam.runners.dataflow.worker.streaming.ComputationWorkExecutor; import org.apache.beam.runners.dataflow.worker.streaming.StageInfo; import org.apache.beam.runners.dataflow.worker.streaming.config.StreamingGlobalConfigHandle; +import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputStateFetcherFactory; import org.apache.beam.runners.dataflow.worker.util.common.worker.MapTaskExecutor; import org.apache.beam.runners.dataflow.worker.util.common.worker.OutputObjectAndByteCounter; import org.apache.beam.runners.dataflow.worker.util.common.worker.ReadOperation; @@ -99,6 +100,7 @@ final class ComputationWorkExecutorFactory { private final StreamingGlobalConfigHandle globalConfigHandle; private final boolean throwExceptionOnLargeOutput; private final HotKeyLogger hotKeyLogger; + private final SideInputStateFetcherFactory sideInputStateFetcherFactory; ComputationWorkExecutorFactory( DataflowWorkerHarnessOptions options, @@ -109,7 +111,8 @@ final class ComputationWorkExecutorFactory { CounterSet pendingDeltaCounters, IdGenerator idGenerator, StreamingGlobalConfigHandle globalConfigHandle, - HotKeyLogger hotKeyLogger) { + HotKeyLogger hotKeyLogger, + SideInputStateFetcherFactory sideInputStateFetcherFactory) { this.options = options; 
this.mapTaskExecutorFactory = mapTaskExecutorFactory; this.readerCache = readerCache; @@ -128,6 +131,7 @@ final class ComputationWorkExecutorFactory { this.throwExceptionOnLargeOutput = hasExperiment(options, THROW_EXCEPTIONS_ON_LARGE_OUTPUT_EXPERIMENT); this.hotKeyLogger = hotKeyLogger; + this.sideInputStateFetcherFactory = sideInputStateFetcherFactory; } private static Nodes.ParallelInstructionNode extractReadNode( @@ -282,7 +286,8 @@ private StreamingModeExecutionContext createExecutionContext( hotKeyLogger, hotKeyLoggingEnabled, stepName, - computationState.sourceBytesProcessCounterName()); + computationState.sourceBytesProcessCounterName(), + sideInputStateFetcherFactory); } private DataflowMapTaskExecutor createMapTaskExecutor( diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java index dc1fd4791fcd..9e28c64b7860 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java @@ -80,7 +80,6 @@ public class StreamingWorkScheduler { private final Supplier clock; private final ComputationWorkExecutorFactory computationWorkExecutorFactory; - private final SideInputStateFetcherFactory sideInputStateFetcherFactory; private final FailureTracker failureTracker; private final WorkFailureProcessor workFailureProcessor; private final StreamingCommitFinalizer commitFinalizer; @@ -94,7 +93,6 @@ public StreamingWorkScheduler( Supplier clock, BoundedQueueExecutor workExecutor, ComputationWorkExecutorFactory computationWorkExecutorFactory, - SideInputStateFetcherFactory 
sideInputStateFetcherFactory, FailureTracker failureTracker, WorkFailureProcessor workFailureProcessor, StreamingCommitFinalizer commitFinalizer, @@ -105,7 +103,6 @@ public StreamingWorkScheduler( this.clock = clock; this.workExecutor = workExecutor; this.computationWorkExecutorFactory = computationWorkExecutorFactory; - this.sideInputStateFetcherFactory = sideInputStateFetcherFactory; this.failureTracker = failureTracker; this.workFailureProcessor = workFailureProcessor; this.commitFinalizer = commitFinalizer; @@ -131,6 +128,9 @@ public static StreamingWorkScheduler create( IdGenerator idGenerator, StreamingGlobalConfigHandle globalConfigHandle, ConcurrentMap stageInfoMap) { + SideInputStateFetcherFactory sideInputStateFetcherFactory = + SideInputStateFetcherFactory.fromOptions(options); + ComputationWorkExecutorFactory computationWorkExecutorFactory = new ComputationWorkExecutorFactory( options, @@ -141,13 +141,13 @@ public static StreamingWorkScheduler create( streamingCounters.pendingDeltaCounters(), idGenerator, globalConfigHandle, - hotKeyLogger); + hotKeyLogger, + sideInputStateFetcherFactory); return new StreamingWorkScheduler( clock, workExecutor, computationWorkExecutorFactory, - SideInputStateFetcherFactory.fromOptions(options), failureTracker, workFailureProcessor, StreamingCommitFinalizer.create(workExecutor, commitFinalizerCleanupExecutor), @@ -388,12 +388,7 @@ private ExecuteWorkResult executeWork( // Blocks while executing work. 
computationWorkExecutor.executeWork( - work, - stateReader, - sideInputStateFetcherFactory, - workExecutor, - handle, - keyTransitionListener); + work, stateReader, workExecutor, handle, keyTransitionListener); List workBatch; List workItemCommits; diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java index 6d84e9b4b0bf..c1193afeff6b 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java @@ -48,7 +48,6 @@ import org.apache.beam.runners.core.metrics.ExecutionStateSampler; import org.apache.beam.runners.core.metrics.ExecutionStateTracker; import org.apache.beam.runners.core.metrics.ExecutionStateTracker.ExecutionState; -import org.apache.beam.runners.dataflow.options.DataflowStreamingPipelineOptions; import org.apache.beam.runners.dataflow.options.DataflowWorkerHarnessOptions; import org.apache.beam.runners.dataflow.worker.DataflowExecutionContext.DataflowExecutionStateTracker; import org.apache.beam.runners.dataflow.worker.MetricsToCounterUpdateConverter.Kind; @@ -144,7 +143,8 @@ public void setUp() { new HotKeyLogger(), /*hotKeyLoggingEnabled=*/ false, /*stepName=*/ "stepName", - "sourceBytesProcessCounterName"); + "sourceBytesProcessCounterName", + SideInputStateFetcherFactory.fromOptions(options)); } private StreamingModeExecutionContext createTestExecutionContext( @@ -176,7 +176,8 @@ private StreamingModeExecutionContext createTestExecutionContext( new HotKeyLogger(), /*hotKeyLoggingEnabled=*/ false, /*stepName=*/ "stepName", - "sourceBytesProcessCounterName"); + "sourceBytesProcessCounterName", + 
SideInputStateFetcherFactory.fromOptions(options)); } private static Work createMockWork(Windmill.WorkItem workItem, Watermarks watermarks) { @@ -203,13 +204,9 @@ private void start(StreamingModeExecutionContext context, Work work) { } private void start(StreamingModeExecutionContext context, Work work, Coder keyCoder) { - SideInputStateFetcherFactory sideInputStateFetcherFactory = - SideInputStateFetcherFactory.fromOptions( - options.as(DataflowStreamingPipelineOptions.class)); context.start( work, stateReader, - sideInputStateFetcherFactory, workExecutor, /* workQueueExecutor= */ null, /* budgetHandle= */ null, diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java index bd4e40d6570a..0af802ec6760 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java @@ -80,7 +80,6 @@ import org.apache.beam.runners.dataflow.DataflowRunner; import org.apache.beam.runners.dataflow.options.DataflowPipelineDebugOptions; import org.apache.beam.runners.dataflow.options.DataflowPipelineOptions; -import org.apache.beam.runners.dataflow.options.DataflowStreamingPipelineOptions; import org.apache.beam.runners.dataflow.util.CloudObject; import org.apache.beam.runners.dataflow.util.PropertyNames; import org.apache.beam.runners.dataflow.worker.DataflowExecutionContext.DataflowExecutionStateTracker; @@ -212,13 +211,9 @@ COMPUTATION_ID, new FakeGetDataClient(), ignored -> {}, mock(HeartbeatSender.cla } private void startContext(StreamingModeExecutionContext context, Work work) { - SideInputStateFetcherFactory sideInputStateFetcherFactory = - 
SideInputStateFetcherFactory.fromOptions( - options.as(DataflowStreamingPipelineOptions.class)); context.start( work, mock(WindmillStateReader.class), - sideInputStateFetcherFactory, mock(WorkExecutor.class), /* workQueueExecutor= */ null, /* budgetHandle= */ null, @@ -641,7 +636,8 @@ public void testReadUnboundedReader() throws Exception { new HotKeyLogger(), /*hotKeyLoggingEnabled=*/ false, /*stepName=*/ "stepName", - "sourceBytesProcessCounterName"); + "sourceBytesProcessCounterName", + SideInputStateFetcherFactory.fromOptions(options)); options.setNumWorkers(5); int maxElements = 10; @@ -1013,7 +1009,8 @@ public void testFailedWorkItemsAbort() throws Exception { new HotKeyLogger(), /*hotKeyLoggingEnabled=*/ false, /*stepName=*/ "stepName", - "sourceBytesProcessCounterName"); + "sourceBytesProcessCounterName", + SideInputStateFetcherFactory.fromOptions(options)); options.setNumWorkers(5); int maxElements = 100; From 47eb7d661ebee256807f5674065ce58c698bf668 Mon Sep 17 00:00:00 2001 From: Arun Pandian Date: Mon, 8 Jun 2026 22:23:22 +0000 Subject: [PATCH 16/21] Address comment --- .../worker/WindmillReaderIteratorBase.java | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillReaderIteratorBase.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillReaderIteratorBase.java index 20d0c40ae4a3..d0f9eafbcd4c 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillReaderIteratorBase.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillReaderIteratorBase.java @@ -36,8 +36,8 @@ public abstract class WindmillReaderIteratorBase extends NativeReader.NativeReaderIterator> { private final StreamingModeExecutionContext context; private Windmill.WorkItem work; - private 
int bundleIndex = 0; - private int messageIndex = -1; + private int bundleIndex; + private int messageIndex; private @Nullable WindowedValue current = null; private final ValueProvider skipUndecodableElements; private static final Logger LOG = LoggerFactory.getLogger(WindmillReaderIteratorBase.class); @@ -46,7 +46,7 @@ protected WindmillReaderIteratorBase( StreamingModeExecutionContext context, ValueProvider skipUndecodableElements) { this.context = context; this.skipUndecodableElements = skipUndecodableElements; - this.work = context.getWorkItem(); + resetWorkFromContext(); } @Override @@ -67,9 +67,7 @@ public boolean advance() throws IOException { context.finishKey(); if (context.advance()) { // Transition succeeded! Update iterator references to the new work item - this.work = context.getWork().getWorkItem(); - this.bundleIndex = 0; - this.messageIndex = -1; + resetWorkFromContext(); continue; } @@ -104,6 +102,12 @@ public boolean advance() throws IOException { } } + private void resetWorkFromContext() { + this.work = context.getWork().getWorkItem(); + this.bundleIndex = 0; + this.messageIndex = -1; + } + protected abstract WindowedValue decodeMessage(Windmill.Message message) throws IOException; @Override From 24505dd7e47ee1270742504ccb397031582e8c2c Mon Sep 17 00:00:00 2001 From: Arun Pandian Date: Mon, 8 Jun 2026 22:27:39 +0000 Subject: [PATCH 17/21] Address comment --- .../dataflow/worker/StreamingModeExecutionContext.java | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java index ef27e38f9ca2..c2cf2f9f7940 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java +++ 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java @@ -192,7 +192,7 @@ public interface KeyTransitionListener { // Map> private Map> accumulatedCallbacks = new HashMap<>(); - private final AtomicBoolean workIsFailed = new AtomicBoolean(false); + private final AtomicBoolean workBatchFailed = new AtomicBoolean(false); private @Nullable WindmillStateReader activeStateReader; private long stateBytesRead = 0; private final String sourceBytesProcessCounterName; @@ -259,7 +259,7 @@ public boolean throwExceptionsForLargeOutput() { } public boolean workIsFailed() { - return workIsFailed.get(); + return workBatchFailed.get(); } public boolean getDrainMode() { @@ -298,7 +298,7 @@ public void clear() { this.executedWorks = new ArrayList<>(); this.outputBuilders = new ArrayList<>(); this.accumulatedCallbacks = new HashMap<>(); - this.workIsFailed.set(false); + this.workBatchFailed.set(false); this.sideInputCache.clear(); this.activeStateReader = null; this.activeReader = null; @@ -715,7 +715,7 @@ private void startForNewKey(Work newWork, WindmillStateReader reader) { this.outputBuilder = createOutputBuilder(newWork); this.outputBuilders.add(this.outputBuilder); - newWork.setOnFailureListener(this.workIsFailed); + newWork.setOnFailureListener(this.workBatchFailed); this.executedWorks.add(newWork); logHotKeyIfDetected(newWork, this.key); From 4e0d17448960650e31ac15ed768f25d7acd7bb35 Mon Sep 17 00:00:00 2001 From: Arun Pandian Date: Tue, 9 Jun 2026 16:34:43 +0000 Subject: [PATCH 18/21] Fix UnderInitialization --- .../runners/dataflow/worker/WindmillReaderIteratorBase.java | 6 ++++-- .../dataflow/worker/WindmillReaderIteratorBaseTest.java | 4 +--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillReaderIteratorBase.java 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillReaderIteratorBase.java index d0f9eafbcd4c..134655a72a54 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillReaderIteratorBase.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillReaderIteratorBase.java @@ -46,7 +46,9 @@ protected WindmillReaderIteratorBase( StreamingModeExecutionContext context, ValueProvider skipUndecodableElements) { this.context = context; this.skipUndecodableElements = skipUndecodableElements; - resetWorkFromContext(); + this.work = context.getWorkItem(); + this.bundleIndex = 0; + this.messageIndex = -1; } @Override @@ -103,7 +105,7 @@ public boolean advance() throws IOException { } private void resetWorkFromContext() { - this.work = context.getWork().getWorkItem(); + this.work = context.getWorkItem(); this.bundleIndex = 0; this.messageIndex = -1; } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindmillReaderIteratorBaseTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindmillReaderIteratorBaseTest.java index a56343e3dfb3..b45e0de6447c 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindmillReaderIteratorBaseTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindmillReaderIteratorBaseTest.java @@ -171,8 +171,6 @@ public void testAdvanceKeyChaining() throws Exception { .build()) .build(); - Work mockWorkB = createMockWork(workItemB); - // Set up context.advance() to mock transition when(mockContext.advance()) .thenAnswer( @@ -183,7 +181,7 @@ public void testAdvanceKeyChaining() throws Exception { public Boolean answer(org.mockito.invocation.InvocationOnMock invocation) { if 
(count == 0) { count++; - when(mockContext.getWork()).thenReturn(mockWorkB); + when(mockContext.getWorkItem()).thenReturn(workItemB); return true; } return false; From 93de23afa95f183283a0ced1b997d7a143e712a9 Mon Sep 17 00:00:00 2001 From: Arun Pandian Date: Thu, 11 Jun 2026 11:34:18 +0000 Subject: [PATCH 19/21] Multikey commit failure handling and integration --- .../worker/StreamingDataflowWorker.java | 10 +- .../worker/StreamingModeExecutionContext.java | 75 +++- ...ption.java => WorkCancelingException.java} | 29 +- .../worker/WorkItemCancelledException.java | 25 +- .../worker/streaming/ActiveWorkState.java | 5 + .../BoundedQueueExecutorWorkHandle.java | 7 +- .../worker/streaming/ComputationState.java | 4 + .../streaming/ComputationWorkExecutor.java | 3 - .../dataflow/worker/streaming/Work.java | 76 +++- .../worker/util/BoundedQueueExecutor.java | 36 +- .../client/commits/CompleteCommit.java | 11 +- .../StreamingApplianceWorkCommitter.java | 3 +- .../commits/StreamingEngineWorkCommitter.java | 50 ++- .../client/getdata/StreamGetDataClient.java | 5 +- .../windmill/state/WindmillStateReader.java | 9 +- .../ComputationWorkExecutorFactory.java | 3 +- .../processing/StreamingWorkScheduler.java | 178 +++++--- .../failures/WorkFailureProcessor.java | 74 ++-- .../worker/KeyTokenInvalidExceptionTest.java | 39 -- .../worker/StreamingDataflowWorkerTest.java | 379 ++++++++++++++++-- .../StreamingModeExecutionContextTest.java | 196 ++++++++- .../worker/WorkerCustomSourcesTest.java | 4 +- .../worker/streaming/ActiveWorkStateTest.java | 25 ++ .../streaming/ComputationStateTest.java | 112 ++++++ .../worker/util/BoundedQueueExecutorTest.java | 78 ++-- .../worker/util/KeyGroupWorkQueueTest.java | 6 +- .../StreamingEngineWorkCommitterTest.java | 63 ++- .../client/grpc/GrpcCommitWorkStreamTest.java | 188 +++++++++ .../state/WindmillStateReaderTest.java | 10 +- .../failures/WorkFailureProcessorTest.java | 109 +++-- 30 files changed, 1461 insertions(+), 351 deletions(-) rename 
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/{KeyTokenInvalidException.java => WorkCancelingException.java} (54%) delete mode 100644 runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/KeyTokenInvalidExceptionTest.java create mode 100644 runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationStateTest.java diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java index 9e82343474c6..71a2547cc602 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java @@ -257,6 +257,7 @@ private StreamingDataflowWorker( this.streamingWorkScheduler = StreamingWorkScheduler.create( options, + DataflowRunner.hasExperiment(options, UNSTABLE_ENABLE_MULTI_KEY_BUNDLE), clock, readerCache, mapTaskExecutorFactory, @@ -1198,9 +1199,14 @@ private void onCompleteCommit(CompleteCommit completeCommit) { computationStateCache .getIfPresent(completeCommit.computationId()) .ifPresent( - state -> + state -> { + if (completeCommit.retryableFailure()) { + state.reExecuteActiveWork(completeCommit.shardedKey(), completeCommit.workId()); + } else { state.completeWorkAndScheduleNextWorkForKey( - completeCommit.shardedKey(), completeCommit.workId())); + completeCommit.shardedKey(), completeCommit.workId()); + } + }); } @AutoValue diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java index c2cf2f9f7940..f855946cc075 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java @@ -35,6 +35,7 @@ import java.util.Optional; import java.util.Set; import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicLong; import javax.annotation.concurrent.NotThreadSafe; @@ -52,6 +53,7 @@ import org.apache.beam.runners.dataflow.worker.counters.NameContext; import org.apache.beam.runners.dataflow.worker.profiler.ScopedProfiler.ProfileScope; import org.apache.beam.runners.dataflow.worker.streaming.BoundedQueueExecutorWorkHandle; +import org.apache.beam.runners.dataflow.worker.streaming.ExecutableWork; import org.apache.beam.runners.dataflow.worker.streaming.Watermarks; import org.apache.beam.runners.dataflow.worker.streaming.Work; import org.apache.beam.runners.dataflow.worker.streaming.config.StreamingGlobalConfig; @@ -82,6 +84,8 @@ import org.apache.beam.sdk.io.UnboundedSource; import org.apache.beam.sdk.io.UnboundedSource.UnboundedReader; import org.apache.beam.sdk.metrics.MetricsContainer; +import org.apache.beam.sdk.options.ExperimentalOptions; +import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.state.TimeDomain; import org.apache.beam.sdk.transforms.DoFn.BundleFinalizer; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; @@ -120,6 +124,10 @@ public class StreamingModeExecutionContext extends DataflowExecutionContext { private static final Logger LOG = LoggerFactory.getLogger(StreamingModeExecutionContext.class); + private static final String 
WINDMILL_MAX_KEY_GROUP_BATCH_SIZE = + "windmill_max_key_group_batch_size"; + private static final String WINDMILL_MAX_KEY_GROUP_BATCH_TIME_MS = + "windmill_max_key_group_batch_time_ms"; private final String computationId; private final ImmutableMap stateNameMap; @@ -181,7 +189,7 @@ public class StreamingModeExecutionContext // Key switch listener to delegate MDC logging context and thread name updates public interface KeyTransitionListener { - void onKeyTransition(Work oldWork, Work newWork); + void onKeyTransition(@Nullable Work oldWork, Work newWork); } @SuppressWarnings("UnusedVariable") @@ -197,6 +205,11 @@ public interface KeyTransitionListener { private long stateBytesRead = 0; private final String sourceBytesProcessCounterName; + private final int maxKeyGroupBatchSize; + private final long maxKeyGroupBatchTimeNanos; + private int workItemsPolled = 0; + private long bundleStartTimeNanos = 0; + public StreamingModeExecutionContext( CounterFactory counterFactory, String computationId, @@ -213,6 +226,7 @@ public StreamingModeExecutionContext( boolean hotKeyLoggingEnabled, String stepName, String sourceBytesProcessCounterName, + PipelineOptions options, SideInputStateFetcherFactory sideInputStateFetcherFactory) { super( counterFactory, @@ -232,7 +246,18 @@ public StreamingModeExecutionContext( this.hotKeyLoggingEnabled = hotKeyLoggingEnabled; this.stepName = checkNotNull(stepName); this.sourceBytesProcessCounterName = checkNotNull(sourceBytesProcessCounterName); - this.sideInputStateFetcherFactory = sideInputStateFetcherFactory; + this.sideInputStateFetcherFactory = checkNotNull(sideInputStateFetcherFactory); + + // Initialize batch limits from pipeline options + String batchSizeStr = + ExperimentalOptions.getExperimentValue(options, WINDMILL_MAX_KEY_GROUP_BATCH_SIZE); + this.maxKeyGroupBatchSize = batchSizeStr != null ? 
Integer.parseInt(batchSizeStr) : 100; + + String batchTimeStr = + ExperimentalOptions.getExperimentValue(options, WINDMILL_MAX_KEY_GROUP_BATCH_TIME_MS); + this.maxKeyGroupBatchTimeNanos = + TimeUnit.MILLISECONDS.toNanos(batchTimeStr != null ? Long.parseLong(batchTimeStr) : 100); + StreamingGlobalConfig config = globalConfigHandle.getConfig(); this.operationalLimits = config.operationalLimits(); this.windmillTagEncoding = @@ -318,7 +343,6 @@ public void clear() { public void start( Work work, - WindmillStateReader stateReader, WorkExecutor workExecutor, BoundedQueueExecutor workQueueExecutor, BoundedQueueExecutorWorkHandle budgetHandle, @@ -331,11 +355,14 @@ public void start( this.budgetHandle = budgetHandle; this.keyTransitionListener = keyTransitionListener; + this.workItemsPolled = 1; + this.bundleStartTimeNanos = System.nanoTime(); + StreamingGlobalConfig config = globalConfigHandle.getConfig(); // Snapshot the limits for entire bundle processing. this.operationalLimits = config.operationalLimits(); - startForNewKey(work, stateReader); + startForNewKey(work); } private @Nullable Object decodeKey(Work work) { @@ -700,14 +727,42 @@ public Map> flushState() { } public boolean advance() { + if (workIsFailed()) { + throw new WorkItemCancelledException(checkStateNotNull(work).getWorkItem().getShardingKey()); + } + + BoundedQueueExecutor executor = checkStateNotNull(workQueueExecutor); + BoundedQueueExecutorWorkHandle handle = checkStateNotNull(budgetHandle); + Work activeWork = checkStateNotNull(work); + + if (activeWork.getKeyGroup().equals(Work.KeyGroup.DEFAULT) || shouldStopBatching()) { + return false; + } + + @Nullable + ExecutableWork additionalWork = + executor.pollWork(computationId, activeWork.getKeyGroup(), handle); + if (additionalWork != null) { + Work newWork = additionalWork.work(); + ++workItemsPolled; + checkStateNotNull(keyTransitionListener).onKeyTransition(activeWork, newWork); + startForNewKey(newWork); + return true; + } + return false; } - 
private void startForNewKey(Work newWork, WindmillStateReader reader) { - newWork.setState(Work.State.PROCESSING); - if (keyTransitionListener != null && this.work != null && this.work != newWork) { - keyTransitionListener.onKeyTransition(this.work, newWork); + private boolean shouldStopBatching() { + if (workItemsPolled >= maxKeyGroupBatchSize) { + return true; } + long elapsedNanos = System.nanoTime() - bundleStartTimeNanos; + return elapsedNanos >= maxKeyGroupBatchTimeNanos; + } + + private void startForNewKey(Work newWork) { + newWork.setState(Work.State.PROCESSING); this.key = decodeKey(newWork); this.work = newWork; this.finishKeyCalled = false; @@ -736,8 +791,8 @@ private void startForNewKey(Work newWork, WindmillStateReader reader) { WindmillStateCache.ForKey cacheForKey = stateCache.forKey( getComputationKey(), newWork.getWorkItem().getCacheToken(), getWorkToken()); - this.activeStateReader = reader; - startStepContexts(reader, processingTime, cacheForKey, newWork.watermarks()); + this.activeStateReader = newWork.createWindmillStateReader(this::workIsFailed); + startStepContexts(this.activeStateReader, processingTime, cacheForKey, newWork.watermarks()); } else { this.activeStateReader = null; } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/KeyTokenInvalidException.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkCancelingException.java similarity index 54% rename from runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/KeyTokenInvalidException.java rename to runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkCancelingException.java index 29b16b71883f..73a307641b96 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/KeyTokenInvalidException.java +++ 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkCancelingException.java @@ -17,21 +17,30 @@ */ package org.apache.beam.runners.dataflow.worker; -import javax.annotation.Nullable; +import org.checkerframework.checker.nullness.qual.Nullable; -/** Indicates that the key token was invalid when data was attempted to be fetched. */ -public class KeyTokenInvalidException extends RuntimeException { - public KeyTokenInvalidException(String key) { - super("Unable to fetch data due to token mismatch for key " + key); +/** + * Indicates that the work is no longer valid and should be canceled. It is thrown as a signal for + * upper layers to mark the work as failed. + */ +public class WorkCancelingException extends RuntimeException { + + public WorkCancelingException(long sharding_key) { + super("Work canceling exception for key " + sharding_key); + } + + public WorkCancelingException(Throwable cause) { + super(cause); } - /** Returns whether an exception was caused by a {@link KeyTokenInvalidException}. */ - public static boolean isKeyTokenInvalidException(@Nullable Throwable t) { - while (t != null) { - if (t instanceof KeyTokenInvalidException) { + /** Returns whether an exception was caused by a {@link WorkCancelingException}. 
*/ + public static boolean isWorkCancelingException(Throwable t) { + @Nullable Throwable throwable = t; + while (throwable != null) { + if (throwable instanceof WorkCancelingException) { return true; } - t = t.getCause(); + throwable = throwable.getCause(); } return false; } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkItemCancelledException.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkItemCancelledException.java index a12a5075c5ee..68cbab32254c 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkItemCancelledException.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkItemCancelledException.java @@ -17,31 +17,10 @@ */ package org.apache.beam.runners.dataflow.worker; -/** Indicates that the work item was cancelled and should not be retried. */ -@SuppressWarnings({ - "nullness" // TODO(https://github.com/apache/beam/issues/20497) -}) +/** Indicates that the work item was canceled. */ public class WorkItemCancelledException extends RuntimeException { + public WorkItemCancelledException(long sharding_key) { super("Work item cancelled for key " + sharding_key); } - - public WorkItemCancelledException(String message, Throwable cause) { - super(message, cause); - } - - public WorkItemCancelledException(Throwable cause) { - super(cause); - } - - /** Returns whether an exception was caused by a {@link WorkItemCancelledException}. 
*/ - public static boolean isWorkItemCancelledException(Throwable t) { - while (t != null) { - if (t instanceof WorkItemCancelledException) { - return true; - } - t = t.getCause(); - } - return false; - } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkState.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkState.java index e430f6c8f638..f49aa31a439a 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkState.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkState.java @@ -88,6 +88,11 @@ static ActiveWorkState create(WindmillStateCache.ForComputation computationState return new ActiveWorkState(new HashMap<>(), computationStateCache); } + synchronized Optional getActiveWork(ShardedKey shardedKey, WorkId workId) { + LinkedHashMap workQueue = activeWork.get(shardedKey.shardingKey()); + return workQueue == null ? 
Optional.empty() : Optional.ofNullable(workQueue.get(workId)); + } + @VisibleForTesting static ActiveWorkState forTesting( Map> activeWork, diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/BoundedQueueExecutorWorkHandle.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/BoundedQueueExecutorWorkHandle.java index 1ca534966947..20661aae0a04 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/BoundedQueueExecutorWorkHandle.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/BoundedQueueExecutorWorkHandle.java @@ -17,8 +17,13 @@ */ package org.apache.beam.runners.dataflow.worker.streaming; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; + /** * A handle to use when requesting pulling more work from @BoundedQueueExecutor * via @BoundedQueueExecutor.pollWork */ -public interface BoundedQueueExecutorWorkHandle {} +public interface BoundedQueueExecutorWorkHandle { + // Returns all work that are tracked by the handle + ImmutableList getWorkBatch(); +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationState.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationState.java index 3886d4fbc01b..e9f6ddc55de6 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationState.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationState.java @@ -131,6 +131,10 @@ public void completeWorkAndScheduleNextWorkForKey(ShardedKey shardedKey, WorkId .ifPresent(this::forceExecute); } + public void 
reExecuteActiveWork(ShardedKey shardedKey, WorkId workId) { + activeWorkState.getActiveWork(shardedKey, workId).ifPresent(this::forceExecute); + } + public void invalidateStuckCommits(Instant stuckCommitDeadline) { activeWorkState.invalidateStuckCommits( stuckCommitDeadline, this::completeWorkAndScheduleNextWorkForKey); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationWorkExecutor.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationWorkExecutor.java index 56a1a06362d2..5208ee475f47 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationWorkExecutor.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationWorkExecutor.java @@ -25,7 +25,6 @@ import org.apache.beam.runners.dataflow.worker.StreamingModeExecutionContext; import org.apache.beam.runners.dataflow.worker.StreamingModeExecutionContext.KeyTransitionListener; import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor; -import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateReader; import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.sdk.coders.Coder; import org.slf4j.Logger; @@ -64,7 +63,6 @@ public static ComputationWorkExecutor.Builder builder() { */ public final void executeWork( Work work, - WindmillStateReader stateReader, BoundedQueueExecutor workQueueExecutor, BoundedQueueExecutorWorkHandle budgetHandle, KeyTransitionListener keyTransitionListener) @@ -72,7 +70,6 @@ public final void executeWork( context() .start( work, - stateReader, workExecutor(), workQueueExecutor, budgetHandle, diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java index 252a16a38bc9..f9cfec7e6807 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java @@ -36,6 +36,7 @@ import org.apache.beam.repackaged.core.org.apache.commons.lang3.tuple.Pair; import org.apache.beam.runners.dataflow.worker.ActiveMessageMetadata; import org.apache.beam.runners.dataflow.worker.DataflowExecutionStateSampler; +import org.apache.beam.runners.dataflow.worker.WorkCancelingException; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalData; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalDataRequest; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataRequest; @@ -87,6 +88,7 @@ public final class Work implements RefreshableWork { private final AtomicReference<@Nullable AtomicBoolean> onFailureListener = new AtomicReference<>(null); private final boolean drainMode; + private ImmutableList getWorkStreamLatencies; private Work( WorkItem workItem, @@ -94,7 +96,8 @@ private Work( Watermarks watermarks, ProcessingContext processingContext, boolean drainMode, - Supplier clock) { + Supplier clock, + ImmutableList getWorkStreamLatencies) { this.shardedKey = ShardedKey.create(workItem.getKey(), workItem.getShardingKey()); this.workItem = workItem; this.serializedWorkItemSize = serializedWorkItemSize; @@ -118,6 +121,7 @@ private Work( + Long.toHexString(workItem.getWorkToken()); this.currentState = TimedState.initialState(startTime); this.isFailed = false; + this.getWorkStreamLatencies = getWorkStreamLatencies; } public static Work create( @@ -128,7 +132,31 @@ public static Work create( boolean drainMode, Supplier clock) { return new Work( - workItem, serializedWorkItemSize, watermarks, 
processingContext, drainMode, clock); + workItem, + serializedWorkItemSize, + watermarks, + processingContext, + drainMode, + clock, + ImmutableList.of()); + } + + public static Work create( + WorkItem workItem, + long serializedWorkItemSize, + Watermarks watermarks, + ProcessingContext processingContext, + boolean drainMode, + Supplier clock, + ImmutableList getWorkStreamLatencies) { + return new Work( + workItem, + serializedWorkItemSize, + watermarks, + processingContext, + drainMode, + clock, + getWorkStreamLatencies); } public static ProcessingContext createProcessingContext( @@ -205,11 +233,31 @@ public ShardedKey getShardedKey() { } public Optional fetchKeyedState(KeyedGetDataRequest keyedGetDataRequest) { - return processingContext.fetchKeyedState(keyedGetDataRequest); + try { + Optional response = + processingContext.fetchKeyedState(keyedGetDataRequest); + if (response.isPresent() && response.get().getFailed()) { + // Work is not valid in backend anymore. + this.setFailed(); + } + return response; + } catch (RuntimeException e) { + if (WorkCancelingException.isWorkCancelingException(e)) { + this.setFailed(); + } + throw e; + } } public GlobalData fetchSideInput(GlobalDataRequest request) { - return processingContext.getDataClient().getSideInputData(request); + try { + return processingContext.getDataClient().getSideInputData(request); + } catch (RuntimeException e) { + if (WorkCancelingException.isWorkCancelingException(e)) { + this.setFailed(); + } + throw e; + } } public String backendWorkerToken() { @@ -293,8 +341,8 @@ public Consumer workCommitter() { return processingContext.workCommitter(); } - public WindmillStateReader createWindmillStateReader() { - return WindmillStateReader.forWork(this); + public WindmillStateReader createWindmillStateReader(Supplier workIsFailed) { + return WindmillStateReader.forWork(this, workIsFailed); } @Override @@ -302,11 +350,17 @@ public WorkId id() { return id; } - public void recordGetWorkStreamLatencies( - 
ImmutableList getWorkStreamLatencies) { - for (LatencyAttribution latency : getWorkStreamLatencies) { - totalDurationPerState.put( - latency.getState(), Duration.millis(latency.getTotalDurationMillis())); + public ImmutableList getWorkStreamLatencies() { + return getWorkStreamLatencies; + } + + public void recordGetWorkStreamLatencies() { + if (!getWorkStreamLatencies.isEmpty()) { + for (LatencyAttribution latency : getWorkStreamLatencies) { + totalDurationPerState.put( + latency.getState(), Duration.millis(latency.getTotalDurationMillis())); + } + this.getWorkStreamLatencies = ImmutableList.of(); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutor.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutor.java index 8964246c1160..9eb9a37b1b76 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutor.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutor.java @@ -20,6 +20,8 @@ import static org.apache.beam.sdk.util.Preconditions.checkArgumentNotNull; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; +import java.util.ArrayList; +import java.util.List; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.ThreadFactory; @@ -30,9 +32,9 @@ import org.apache.beam.runners.dataflow.worker.streaming.BoundedQueueExecutorWorkHandle; import org.apache.beam.runners.dataflow.worker.streaming.ExecutableWork; import org.apache.beam.runners.dataflow.worker.streaming.Work; -import org.apache.beam.runners.dataflow.worker.streaming.Work.KeyGroup; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.Monitor; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.Monitor.Guard; import org.checkerframework.checker.nullness.qual.Nullable; @@ -260,7 +262,7 @@ final class BoundedQueueExecutorWorkHandleImpl implements BoundedQueueExecutorWorkHandle, AutoCloseable { @GuardedBy("this") - private int elements; + private final List workBatch; @GuardedBy("this") private long bytes; @@ -268,16 +270,17 @@ final class BoundedQueueExecutorWorkHandleImpl @GuardedBy("this") private boolean closed = false; - private BoundedQueueExecutorWorkHandleImpl(int elements, long bytes) { - checkArgument(elements >= 0 && bytes >= 0); - this.elements = elements; + private BoundedQueueExecutorWorkHandleImpl(Work work, long bytes) { + checkArgument(bytes >= 0); + this.workBatch = new ArrayList<>(); + this.workBatch.add(checkArgumentNotNull(work)); this.bytes = bytes; } /** * Merges the budget from another handle into this handle. * - *

This transfers the budget (elements and bytes) from the {@code other} handle to this + *

This transfers the budget (workBatch and bytes) from the {@code other} handle to this * handle, and marks the {@code other} handle as closed to prevent it from releasing the budget * again if it is closed. */ @@ -287,10 +290,10 @@ public void merge(BoundedQueueExecutorWorkHandleImpl other) { Preconditions.checkState(!closed, "Cannot merge into a closed handle"); synchronized (other) { Preconditions.checkState(!other.closed, "Cannot merge a closed handle"); - this.elements += other.elements; + this.workBatch.addAll(other.workBatch); this.bytes += other.bytes; other.closed = true; - other.elements = 0; + other.workBatch.clear(); other.bytes = 0; } } @@ -300,9 +303,9 @@ public synchronized boolean isClosed() { return closed; } - @VisibleForTesting - synchronized int elements() { - return elements; + @Override + public synchronized ImmutableList getWorkBatch() { + return ImmutableList.copyOf(workBatch); } @VisibleForTesting @@ -314,7 +317,7 @@ synchronized long bytes() { public synchronized void close() { if (closed) return; closed = true; - decrementCounters(this.elements, this.bytes); + decrementCounters(this.workBatch.size(), this.bytes); } } @@ -350,7 +353,7 @@ private void executeMonitorHeld(ExecutableWork work, long workBytes) { bytesOutstanding += workBytes; monitor.leave(); BoundedQueueExecutorWorkHandleImpl handle = - new BoundedQueueExecutorWorkHandleImpl(1, workBytes); + new BoundedQueueExecutorWorkHandleImpl(work.work(), workBytes); try { executor.execute(new QueuedWork(work, handle)); } catch (Throwable t) { @@ -379,14 +382,15 @@ private void executeMonitorHeld(Runnable work) { } @VisibleForTesting - BoundedQueueExecutorWorkHandleImpl createBudgetHandle(int elements, long bytes) { - return new BoundedQueueExecutorWorkHandleImpl(elements, bytes); + BoundedQueueExecutorWorkHandleImpl createBudgetHandle(Work work, long bytes) { + return new BoundedQueueExecutorWorkHandleImpl(work, bytes); } public @Nullable ExecutableWork pollWork( String computationId, 
Work.KeyGroup keyGroup, BoundedQueueExecutorWorkHandle handle) { + checkArgument( + computationId != null && keyGroup != null && !keyGroup.equals(Work.KeyGroup.DEFAULT)); checkArgument(handle instanceof BoundedQueueExecutorWorkHandleImpl); - checkArgument(computationId != null && keyGroup != null && !keyGroup.equals(KeyGroup.DEFAULT)); BoundedQueueExecutorWorkHandleImpl internalHandle = (BoundedQueueExecutorWorkHandleImpl) handle; if (keyGroupWorkQueue == null) { return null; diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/CompleteCommit.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/CompleteCommit.java index 6c0a5a98e2ab..e168d92987fb 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/CompleteCommit.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/CompleteCommit.java @@ -38,8 +38,13 @@ public abstract class CompleteCommit { public static CompleteCommit create( - String computationId, ShardedKey shardedKey, WorkId workId, CommitStatus status) { - return new AutoValue_CompleteCommit(computationId, shardedKey, workId, status); + String computationId, + ShardedKey shardedKey, + WorkId workId, + CommitStatus status, + boolean retryableFailure) { + return new AutoValue_CompleteCommit( + computationId, shardedKey, workId, status, retryableFailure); } public abstract String computationId(); @@ -49,4 +54,6 @@ public static CompleteCommit create( public abstract WorkId workId(); public abstract CommitStatus status(); + + public abstract boolean retryableFailure(); } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitter.java 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitter.java index 40e82c4ca368..58f0dbbea242 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitter.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitter.java @@ -159,7 +159,8 @@ private void completeWork( .setCacheToken(workRequest.getCacheToken()) .setWorkToken(workRequest.getWorkToken()) .build(), - Windmill.CommitStatus.OK)); + Windmill.CommitStatus.OK, + /* retryableFailure= */ false)); } } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitter.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitter.java index cb8e6d26d089..72d9e5ed8d03 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitter.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitter.java @@ -148,11 +148,41 @@ private void drainCommitQueue() { } private void failCommit(Commit commit) { + if (!isRunning.get()) { + // Shutting down, fail everything unconditionally to prevent infinite loops + for (Work w : commit.workBatch()) { + w.setFailed(); + onCommitComplete.accept( + CompleteCommit.create( + commit.computationId(), + w.getShardedKey(), + w.id(), + CommitStatus.ABORTED, + /* retryableFailure= */ false)); + } + return; + } + + // Still running, only fail actually failed work, and request re-execution for valid ones for (Work w 
: commit.workBatch()) { - w.setFailed(); - onCommitComplete.accept( - CompleteCommit.create( - commit.computationId(), w.getShardedKey(), w.id(), CommitStatus.ABORTED)); + if (w.isFailed()) { + onCommitComplete.accept( + CompleteCommit.create( + commit.computationId(), + w.getShardedKey(), + w.id(), + CommitStatus.ABORTED, + /* retryableFailure= */ false)); + } else { + LOG.debug("Requesting re-execution for valid work {} from failed commit", w.id()); + onCommitComplete.accept( + CompleteCommit.create( + commit.computationId(), + w.getShardedKey(), + w.id(), + CommitStatus.ABORTED, + /* retryableFailure= */ true)); + } } } @@ -221,7 +251,11 @@ private boolean tryAddToCommitBatch(Commit commit, CommitWorkStream.RequestBatch for (Work w : commit.workBatch()) { onCommitComplete.accept( CompleteCommit.create( - commit.computationId(), w.getShardedKey(), w.id(), commitStatus)); + commit.computationId(), + w.getShardedKey(), + w.id(), + commitStatus, + /* retryableFailure= */ false)); } activeCommitBytes.addAndGet(-commit.getSize()); }); @@ -234,7 +268,11 @@ private boolean tryAddToCommitBatch(Commit commit, CommitWorkStream.RequestBatch Work w = commit.workBatch().get(0); onCommitComplete.accept( CompleteCommit.create( - commit.computationId(), w.getShardedKey(), w.id(), commitStatus)); + commit.computationId(), + w.getShardedKey(), + w.id(), + commitStatus, + /* retryableFailure= */ false)); activeCommitBytes.addAndGet(-commit.getSize()); }); } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/StreamGetDataClient.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/StreamGetDataClient.java index ab12946ad18b..d233bf091b6a 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/StreamGetDataClient.java +++ 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/StreamGetDataClient.java @@ -19,6 +19,7 @@ import java.io.PrintWriter; import java.util.function.Function; +import org.apache.beam.runners.dataflow.worker.WorkCancelingException; import org.apache.beam.runners.dataflow.worker.WorkItemCancelledException; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream; @@ -62,7 +63,7 @@ public Windmill.KeyedGetDataResponse getStateData( try (AutoCloseable ignored = getDataMetricTracker.trackStateDataFetchWithThrottling()) { return getDataStream.requestKeyedData(computationId, request); } catch (WindmillStreamShutdownException e) { - throw new WorkItemCancelledException(request.getShardingKey()); + throw new WorkCancelingException(request.getShardingKey()); } catch (Exception e) { throw new GetDataException( "Error occurred fetching state for computation=" @@ -87,7 +88,7 @@ public Windmill.GlobalData getSideInputData(Windmill.GlobalDataRequest request) try (AutoCloseable ignored = getDataMetricTracker.trackSideInputFetchWithThrottling()) { return sideInputGetDataStream.requestGlobalData(request); } catch (WindmillStreamShutdownException e) { - throw new WorkItemCancelledException(e); + throw new WorkCancelingException(e); } catch (Exception e) { throw new GetDataException( "Error occurred fetching side input for tag=" + request.getDataId(), e); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateReader.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateReader.java index c609bed4eae0..6c5ae50858cc 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateReader.java 
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateReader.java @@ -36,8 +36,8 @@ import java.util.function.Supplier; import java.util.stream.Collectors; import javax.annotation.Nullable; -import org.apache.beam.runners.dataflow.worker.KeyTokenInvalidException; import org.apache.beam.runners.dataflow.worker.WindmillTimeUtils; +import org.apache.beam.runners.dataflow.worker.WorkCancelingException; import org.apache.beam.runners.dataflow.worker.WorkItemCancelledException; import org.apache.beam.runners.dataflow.worker.streaming.Work; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; @@ -153,7 +153,7 @@ static WindmillStateReader forTesting( fetchStateFromWindmillFn, key, shardingKey, workToken, () -> null, () -> Boolean.FALSE); } - public static WindmillStateReader forWork(Work work) { + public static WindmillStateReader forWork(Work work, Supplier workItemIsFailed) { return new WindmillStateReader( work::fetchKeyedState, work.getWorkItem().getKey(), @@ -163,7 +163,7 @@ public static WindmillStateReader forWork(Work work) { work.setState(Work.State.READING); return () -> work.setState(Work.State.PROCESSING); }, - work::isFailed); + workItemIsFailed); } private Future stateFuture(StateTag stateTag, @Nullable Coder coder) { @@ -588,7 +588,8 @@ private KeyedGetDataRequest createRequest(Iterable> toFetch) { private void consumeResponse(KeyedGetDataResponse response, Set> toFetch) { bytesRead += response.getSerializedSize(); if (response.getFailed()) { - throw new KeyTokenInvalidException(key.toStringUtf8()); + // upper layers will fail the work on seeing this exception. 
+ throw new WorkCancelingException(shardingKey); } if (!key.equals(response.getKey())) { diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/ComputationWorkExecutorFactory.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/ComputationWorkExecutorFactory.java index 4a52d9fde771..86449b1c2bb1 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/ComputationWorkExecutorFactory.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/ComputationWorkExecutorFactory.java @@ -84,6 +84,7 @@ final class ComputationWorkExecutorFactory { private final SinkRegistry sinkRegistry; private final DataflowExecutionStateSampler sampler; private final CounterSet pendingDeltaCounters; + private final SideInputStateFetcherFactory sideInputStateFetcherFactory; /** * Function which converts map tasks to their network representation for execution. 
@@ -100,7 +101,6 @@ final class ComputationWorkExecutorFactory { private final StreamingGlobalConfigHandle globalConfigHandle; private final boolean throwExceptionOnLargeOutput; private final HotKeyLogger hotKeyLogger; - private final SideInputStateFetcherFactory sideInputStateFetcherFactory; ComputationWorkExecutorFactory( DataflowWorkerHarnessOptions options, @@ -287,6 +287,7 @@ private StreamingModeExecutionContext createExecutionContext( hotKeyLoggingEnabled, stepName, computationState.sourceBytesProcessCounterName(), + options, sideInputStateFetcherFactory); } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java index 9e28c64b7860..664999f0d864 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java @@ -19,6 +19,7 @@ import com.google.api.services.dataflow.model.MapTask; import com.google.auto.value.AutoValue; +import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentMap; @@ -54,7 +55,6 @@ import org.apache.beam.runners.dataflow.worker.windmill.Windmill.LatencyAttribution; import org.apache.beam.runners.dataflow.worker.windmill.client.commits.Commit; import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateCache; -import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateReader; import org.apache.beam.runners.dataflow.worker.windmill.work.processing.failures.FailureTracker; import 
org.apache.beam.runners.dataflow.worker.windmill.work.processing.failures.WorkFailureProcessor; import org.apache.beam.sdk.annotations.Internal; @@ -88,6 +88,7 @@ public class StreamingWorkScheduler { private final DataflowExecutionStateSampler sampler; private final StreamingGlobalConfigHandle globalConfigHandle; private final BoundedQueueExecutor workExecutor; + private final boolean multiKeyExperimentEnabled; public StreamingWorkScheduler( Supplier clock, @@ -99,7 +100,8 @@ public StreamingWorkScheduler( StreamingCounters streamingCounters, ConcurrentMap stageInfoMap, DataflowExecutionStateSampler sampler, - StreamingGlobalConfigHandle globalConfigHandle) { + StreamingGlobalConfigHandle globalConfigHandle, + boolean multiKeyExperimentEnabled) { this.clock = clock; this.workExecutor = workExecutor; this.computationWorkExecutorFactory = computationWorkExecutorFactory; @@ -110,10 +112,12 @@ public StreamingWorkScheduler( this.stageInfoMap = stageInfoMap; this.sampler = sampler; this.globalConfigHandle = globalConfigHandle; + this.multiKeyExperimentEnabled = multiKeyExperimentEnabled; } public static StreamingWorkScheduler create( DataflowWorkerHarnessOptions options, + boolean multiKeyExperimentEnabled, Supplier clock, ReaderCache readerCache, DataflowMapTaskExecutorFactory mapTaskExecutorFactory, @@ -154,7 +158,8 @@ public static StreamingWorkScheduler create( streamingCounters, stageInfoMap, sampler, - globalConfigHandle); + globalConfigHandle, + multiKeyExperimentEnabled); } private static long computeShuffleBytesRead(Windmill.WorkItem workItem) { @@ -182,12 +187,6 @@ private static Windmill.WorkItemCommitRequest buildWorkItemTruncationRequest( return outputBuilder.build(); } - /** Sets the stage name and workId of the Thread executing the {@link Work} for logging. 
*/ - private static void setUpWorkLoggingContext(String workLatencyTrackingId, String computationId) { - setLoggingContextWorkId(workLatencyTrackingId); - setLoggingContextComputation(computationId); - } - private static void setLoggingContextComputation(@Nullable String computationId) { DataflowWorkerLoggingMDC.setStageName(computationId); } @@ -217,8 +216,14 @@ public void scheduleWork( computationState.activateWork( ExecutableWork.create( Work.create( - workItem, serializedWorkItemSize, watermarks, processingContext, drainMode, clock), - (work, handle) -> processWork(computationState, work, getWorkStreamLatencies, handle))); + workItem, + serializedWorkItemSize, + watermarks, + processingContext, + drainMode, + clock, + getWorkStreamLatencies), + (work, handle) -> processWork(computationState, work, handle))); } /** Adds any applied finalize ids to the commit finalizer to have their callbacks executed. */ @@ -232,25 +237,16 @@ public void queueAppliedFinalizeIds(ImmutableList appliedFinalizeIds) { * internally if processing fails due to uncaught {@link Exception}(s). * * @implNote This will block the calling thread during execution of user DoFns. 
- * @param handle handled to pass to BoundedQueueExecutor.pollWork, currently unused + * @param handle handle to pass to BoundedQueueExecutor.pollWork */ - private void processWork( - ComputationState computationState, - Work work, - ImmutableList getWorkStreamLatencies, - BoundedQueueExecutorWorkHandle handle) { - work.recordGetWorkStreamLatencies(getWorkStreamLatencies); - processWork(computationState, work, handle); - } - private void processWork( ComputationState computationState, Work work, BoundedQueueExecutorWorkHandle handle) { Windmill.WorkItem workItem = work.getWorkItem(); String computationId = computationState.getComputationId(); - work.setProcessingThreadName(Thread.currentThread().getName()); - work.setState(Work.State.PROCESSING); - setUpWorkLoggingContext(work.getLatencyTrackingId(), computationId); LOG.debug("Starting processing for {}:\n{}", computationId, work); + setLoggingContextComputation(computationId); + KeyTransitionListener keyTransitionListener = createKeyTransitionListener(); + keyTransitionListener.onKeyTransition(null, work); // Before any processing starts, call any pending OnCommit callbacks. Nothing that requires // cleanup should be done before this, since we might exit early here. @@ -271,7 +267,8 @@ private void processWork( } // Execute the user code for the Work batch. - ExecuteWorkResult executeWorkResult = executeWork(work, stageInfo, computationState, handle); + ExecuteWorkResult executeWorkResult = + executeWork(work, stageInfo, computationState, handle, keyTransitionListener); workBatch = executeWorkResult.workBatch(); List workItemCommits = executeWorkResult.workItemCommits(); @@ -282,25 +279,14 @@ private void processWork( recordProcessingStats(workBatch, workItemCommits, executeWorkResult.stateBytesRead()); LOG.debug("Processing done for work batch size: {}", workBatch.size()); } catch (Throwable t) { - // OutOfMemoryError that are caught will be rethrown and trigger jvm termination.
- try { - workFailureProcessor.logAndProcessFailure( - computationId, - ExecutableWork.create(work, (retry, h) -> processWork(computationState, retry, h)), - t, - invalidWork -> - computationState.completeWorkAndScheduleNextWorkForKey( - invalidWork.getShardedKey(), invalidWork.id())); - } catch (OutOfMemoryError oom) { - throw oom; - } catch (Throwable t2) { - LOG.warn("Failed to process work failure safely for work {}", work.id(), t2); - throw ExceptionUtils.safeWrapThrowableAsException(t2); - } + handleProcessWorkFailure(computationState, handle.getWorkBatch(), computationId, work, t); } finally { // Update total processing time counters. Updating in finally clause ensures that // work items causing exceptions are also accounted in time spent. - recordProcessingTime(stageInfo, workBatch, work, processingStartTimeNanos); + recordProcessingTime( + stageInfo, + workBatch != null ? workBatch : ImmutableList.of(work), + processingStartTimeNanos); resetWorkLoggingContext(); sampler.resetForWorkId(work.getLatencyTrackingId()); @@ -371,7 +357,8 @@ private ExecuteWorkResult executeWork( Work work, StageInfo stageInfo, ComputationState computationState, - BoundedQueueExecutorWorkHandle handle) + BoundedQueueExecutorWorkHandle handle, + KeyTransitionListener keyTransitionListener) throws Exception { ComputationWorkExecutor computationWorkExecutor = computationState @@ -382,20 +369,16 @@ private ExecuteWorkResult executeWork( stageInfo, computationState, work.getLatencyTrackingId())); try { - WindmillStateReader stateReader = work.createWindmillStateReader(); - - KeyTransitionListener keyTransitionListener = createKeyTransitionListener(); + StreamingModeExecutionContext context = computationWorkExecutor.context(); // Blocks while executing work. 
- computationWorkExecutor.executeWork( - work, stateReader, workExecutor, handle, keyTransitionListener); + computationWorkExecutor.executeWork(work, workExecutor, handle, keyTransitionListener); List workBatch; List workItemCommits; Map> accumulatedCallbacks; long stateBytesRead; { - StreamingModeExecutionContext context = computationWorkExecutor.context(); if (context.workIsFailed()) { throw new WorkItemCancelledException(work.getWorkItem().getShardingKey()); } @@ -450,9 +433,50 @@ private void commitWorkBatch( ComputationState computationState, List workBatch, List workItemCommits) { - Preconditions.checkState( - workBatch.size() == 1, "Expected single-key work batch, got: " + workBatch.size()); - commitSingleKeyWork(computationState, workBatch.get(0), workItemCommits.get(0)); + if (workBatch.isEmpty()) { + return; + } + if (workBatch.size() > 1 || multiKeyExperimentEnabled) { + commitMultiKeyWorkBatch(computationState, workBatch, workItemCommits); + } else { + commitSingleKeyWork(computationState, workBatch.get(0), workItemCommits.get(0)); + } + } + + private void commitMultiKeyWorkBatch( + ComputationState computationState, + List workBatch, + List workItemCommits) { + Windmill.MultiKeyWorkItemCommitRequest.Builder multiKeyBuilder = + Windmill.MultiKeyWorkItemCommitRequest.newBuilder(); + + Work primaryWork = workBatch.get(0); + Work.KeyGroup keyGroup = primaryWork.getKeyGroup(); + multiKeyBuilder.setKeyGroup( + Windmill.Uint128Proto.newBuilder().setHigh(keyGroup.high()).setLow(keyGroup.low()).build()); + + for (int i = 0; i < workBatch.size(); i++) { + // TODO: Add commit size validation + Windmill.WorkItemCommitRequest commit = workItemCommits.get(i); + Work w = workBatch.get(i); + multiKeyBuilder.addRequests( + commit + .toBuilder() + .addAllPerWorkItemLatencyAttributions(w.getLatencyAttributions(sampler)) + .build()); + } + + // Transition states of all completed works in the batch to COMMIT_QUEUED and submit + for (Work w : workBatch) { + 
w.setState(Work.State.COMMIT_QUEUED); + } + + // Package and submit the commit batch transactionally + primaryWork + .workCommitter() + .accept( + Commit.createMultiKey( + multiKeyBuilder.build(), computationState, ImmutableList.copyOf(workBatch))); } private void commitSingleKeyWork( @@ -470,15 +494,40 @@ private void commitSingleKeyWork( work.queueCommit(validatedCommitRequest, computationState); } + private void handleProcessWorkFailure( + ComputationState computationState, + List failedBatch, + String computationId, + Work primaryWork, + Throwable t) { + try { + List executableWorks = new ArrayList<>(); + for (Work w : failedBatch) { + executableWorks.add( + ExecutableWork.create(w, (retry, h) -> processWork(computationState, retry, h))); + } + + workFailureProcessor.logAndProcessFailureBatch( + computationId, + executableWorks, + t, + invalidWork -> + computationState.completeWorkAndScheduleNextWorkForKey( + invalidWork.getShardedKey(), invalidWork.id())); + } catch (OutOfMemoryError oom) { + throw oom; + } catch (Throwable t2) { + LOG.warn("Failed to process work failure safely for work {}", primaryWork.id(), t2); + throw ExceptionUtils.safeWrapThrowableAsException(t2); + } + } + private void recordProcessingTime( - StageInfo stageInfo, - @Nullable List worksToCleanup, - Work work, - long processingStartTimeNanos) { + StageInfo stageInfo, List workBatch, long processingStartTimeNanos) { long processingTimeMsecs = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - processingStartTimeNanos); stageInfo.totalProcessingMsecs().addValue(processingTimeMsecs); - if (anyWorkHasTimers(worksToCleanup, work)) { + if (anyWorkHasTimers(workBatch)) { // Attribute all the processing to timers if the work item contains any timers. // Tests show that work items rarely contain both timers and message bundles. It should // be a fairly close approximation. 
@@ -488,18 +537,21 @@ private void recordProcessingTime( } } - private static boolean anyWorkHasTimers(@Nullable List works, Work primaryWork) { - if (works != null && !works.isEmpty()) { - return works.stream().anyMatch(w -> w.getWorkItem().hasTimers()); - } - return primaryWork.getWorkItem().hasTimers(); + private static boolean anyWorkHasTimers(List works) { + return works.stream().anyMatch(w -> w.getWorkItem().hasTimers()); } private KeyTransitionListener createKeyTransitionListener() { return (oldWork, newWork) -> { + newWork.recordGetWorkStreamLatencies(); + newWork.setState(Work.State.PROCESSING); setLoggingContextWorkId(newWork.getLatencyTrackingId()); - newWork.setProcessingThreadName(oldWork.getProcessingThreadName()); - oldWork.setProcessingThreadName(""); + if (oldWork != null) { + newWork.setProcessingThreadName(oldWork.getProcessingThreadName()); + oldWork.setProcessingThreadName(""); + } else { + newWork.setProcessingThreadName(Thread.currentThread().getName()); + } }; } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/failures/WorkFailureProcessor.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/failures/WorkFailureProcessor.java index 18c8e9b8d83c..15ec1e0c2cf3 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/failures/WorkFailureProcessor.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/failures/WorkFailureProcessor.java @@ -17,13 +17,12 @@ */ package org.apache.beam.runners.dataflow.worker.windmill.work.processing.failures; +import java.util.List; import java.util.concurrent.TimeUnit; import java.util.function.Consumer; import java.util.function.Supplier; import javax.annotation.Nullable; import 
javax.annotation.concurrent.ThreadSafe; -import org.apache.beam.runners.dataflow.worker.KeyTokenInvalidException; -import org.apache.beam.runners.dataflow.worker.WorkItemCancelledException; import org.apache.beam.runners.dataflow.worker.status.LastExceptionDataProvider; import org.apache.beam.runners.dataflow.worker.streaming.ExecutableWork; import org.apache.beam.runners.dataflow.worker.streaming.Work; @@ -99,28 +98,41 @@ private static boolean isOutOfMemoryError(@Nullable Throwable t) { return false; } - /** - * Processes failures caused by thrown exceptions that occur during execution of {@link Work}. May - * attempt to retry execution of the {@link Work} or drop it if it is invalid. - */ - public void logAndProcessFailure( + public void logAndProcessFailureBatch( String computationId, - ExecutableWork executableWork, + List executableWorks, Throwable t, Consumer onInvalidWork) throws Throwable { - switch (evaluateRetry(computationId, executableWork.work(), t)) { - case DO_NOT_RETRY: - // Consider the item invalid. It will eventually be retried by Windmill if it still needs to - // be processed. - onInvalidWork.accept(executableWork.work()); - break; - case RETRY_LOCALLY: - // Try again after some delay and at the end of the queue to avoid a tight loop. - executeWithDelay(retryLocallyDelayMs, executableWork); - break; - case RETHROW_THROWABLE: - throw t; + List worksToRetryLocally = new java.util.ArrayList<>(); + + for (ExecutableWork executableWork : executableWorks) { + switch (evaluateRetry(computationId, executableWork.work(), t)) { + case DO_NOT_RETRY: + // Consider the item invalid. It will eventually be retried by Windmill if it still needs + // to + // be processed. + onInvalidWork.accept(executableWork.work()); + break; + case RETRY_LOCALLY: + // Try again after some delay and at the end of the queue to avoid a tight loop. 
+ worksToRetryLocally.add(executableWork); + break; + case RETHROW_THROWABLE: + throw t; + } + } + + executeWithDelay(worksToRetryLocally); + } + + private void executeWithDelay(List worksToRetryLocally) { + if (!worksToRetryLocally.isEmpty()) { + // Sleep ONCE for the entire batch delay to avoid sequential thread blocks + Uninterruptibles.sleepUninterruptibly(retryLocallyDelayMs, TimeUnit.MILLISECONDS); + for (ExecutableWork ew : worksToRetryLocally) { + workUnitExecutor.forceExecute(ew, ew.work().getSerializedWorkItemSize()); + } } } @@ -131,12 +143,6 @@ private String tryToDumpHeap() { .orElseGet(() -> "not written"); } - private void executeWithDelay(long delayMs, ExecutableWork executableWork) { - Uninterruptibles.sleepUninterruptibly(delayMs, TimeUnit.MILLISECONDS); - workUnitExecutor.forceExecute( - executableWork, executableWork.work().getSerializedWorkItemSize()); - } - private enum RetryEvaluation { DO_NOT_RETRY, RETRY_LOCALLY, @@ -144,24 +150,16 @@ private enum RetryEvaluation { } private RetryEvaluation evaluateRetry(String computationId, Work work, Throwable t) { - @Nullable final Throwable cause = t.getCause(); - Throwable parsedException = (t instanceof UserCodeException && cause != null) ? cause : t; - if (KeyTokenInvalidException.isKeyTokenInvalidException(parsedException)) { - LOG.debug( - "Execution of work for computation '{}' on sharding key '{}' failed due to token expiration. " - + "Work will not be retried locally.", - computationId, - work.getWorkItem().getShardingKey()); - return RetryEvaluation.DO_NOT_RETRY; - } - if (WorkItemCancelledException.isWorkItemCancelledException(parsedException)) { + if (work.isFailed()) { LOG.debug( "Execution of work for computation '{}' on sharding key '{}' failed. 
" - + "Work will not be retried locally.", + + "Work is already marked as failed, not retrying locally.", computationId, work.getWorkItem().getShardingKey()); return RetryEvaluation.DO_NOT_RETRY; } + @Nullable final Throwable cause = t.getCause(); + Throwable parsedException = (t instanceof UserCodeException && cause != null) ? cause : t; LastExceptionDataProvider.reportException(parsedException); LOG.debug("Failed work: {}", work); diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/KeyTokenInvalidExceptionTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/KeyTokenInvalidExceptionTest.java deleted file mode 100644 index 1eb2871e8cd3..000000000000 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/KeyTokenInvalidExceptionTest.java +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.beam.runners.dataflow.worker; - -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertTrue; - -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; - -/** Tests for {@link KeyTokenInvalidException}. */ -@RunWith(JUnit4.class) -public final class KeyTokenInvalidExceptionTest { - @Test - public void testIsKeyTokenInvalidException() throws Exception { - KeyTokenInvalidException exception = new KeyTokenInvalidException("test"); - RuntimeException keyTokenCauseException = new RuntimeException("key token cause", exception); - assertTrue(KeyTokenInvalidException.isKeyTokenInvalidException(exception)); - assertTrue(KeyTokenInvalidException.isKeyTokenInvalidException(keyTokenCauseException)); - assertFalse( - KeyTokenInvalidException.isKeyTokenInvalidException(new RuntimeException("non key token"))); - } -} diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java index 7eaa048204ff..dbb1cc45e1b8 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java @@ -901,7 +901,7 @@ private ByteString addPaneTag(PaneInfo paneInfo, byte[] windowBytes) throws IOEx } private DataflowWorkerHarnessOptions createTestingPipelineOptions(String... 
args) { - List argsList = Lists.newArrayList(args); + List argsList = new ArrayList<>(Arrays.asList(args)); if (streamingEngine) { argsList.add("--experiments=enable_streaming_engine"); } @@ -1252,9 +1252,8 @@ public void testNumberOfWorkerHarnessThreadsIsHonored() throws Exception { } @Test - public void testKeyTokenInvalidException() throws Exception { - if (streamingEngine) { - // TODO: This test needs to be adapted to work with streamingEngine=true. + public void testMultiKeyCommit_success() throws Exception { + if (!streamingEngine) { return; } KvCoder kvCoder = KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of()); @@ -1262,30 +1261,359 @@ public void testKeyTokenInvalidException() throws Exception { List instructions = Arrays.asList( makeSourceInstruction(kvCoder), - makeDoFnInstruction(new KeyTokenInvalidFn(), 0, kvCoder), + makeDoFnInstruction(new WorkDoFn(), 0, kvCoder), makeSinkInstruction(kvCoder, 1)); + StreamingDataflowWorker worker = + makeWorker( + defaultWorkerParams( + "--experiments=unstable_enable_multi_key_bundle,windmill_max_key_group_batch_time_ms=50000", + "--numberOfWorkerHarnessThreads=1") + .setLocalRetryTimeoutMs(100) + .setInstructions(instructions) + .build()); + worker.start(); + + String batchInputText = + "work {" + + " computation_id: \"" + + DEFAULT_COMPUTATION_ID + + "\"" + + " input_data_watermark: 0" + + " work {" + + " key: \"key1\"" + + " sharding_key: 1" + + " work_token: 1" + + " cache_token: 2" + + " key_group { high: 0 low: 1 }" + + " message_bundles {" + + " source_computation_id: \"" + + DEFAULT_SOURCE_COMPUTATION_ID + + "\"" + + " messages {" + + " timestamp: 0" + + " data: \"data1\"" + + " }" + + " }" + + " }" + + " work {" + + " key: \"key2\"" + + " sharding_key: 2" + + " work_token: 2" + + " cache_token: 3" + + " key_group { high: 0 low: 1 }" + + " message_bundles {" + + " source_computation_id: \"" + + DEFAULT_SOURCE_COMPUTATION_ID + + "\"" + + " messages {" + + " timestamp: 0" + + " data: \"data2\"" + + " }" + + 
" }" + + " }" + + " work {" + + " key: \"key3\"" + + " sharding_key: 3" + + " work_token: 3" + + " cache_token: 4" + + " key_group { high: 0 low: 1 }" + + " message_bundles {" + + " source_computation_id: \"" + + DEFAULT_SOURCE_COMPUTATION_ID + + "\"" + + " messages {" + + " timestamp: 0" + + " data: \"data3\"" + + " }" + + " }" + + " }" + + "}"; + Windmill.GetWorkResponse batchInput = + buildInput( + batchInputText, + CoderUtils.encodeToByteArray( + CollectionCoder.of(IntervalWindow.getCoder()), + Collections.singletonList(DEFAULT_WINDOW))); + server - .whenGetWorkCalled() - .thenReturn(makeInput(0, 0, DEFAULT_KEY_STRING, DEFAULT_SHARDING_KEY)); + .whenGetDataCalled() + .answerByDefault( + request -> { + Windmill.GetDataResponse.Builder builder = Windmill.GetDataResponse.newBuilder(); + for (ComputationGetDataRequest compRequest : request.getRequestsList()) { + ComputationGetDataResponse.Builder compBuilder = + builder.addDataBuilder().setComputationId(compRequest.getComputationId()); + for (KeyedGetDataRequest keyRequest : compRequest.getRequestsList()) { + KeyedGetDataResponse.Builder keyBuilder = + compBuilder + .addDataBuilder() + .setKey(keyRequest.getKey()) + .setShardingKey(keyRequest.getShardingKey()); + keyBuilder.addAllValues(keyRequest.getValuesToFetchList()); + keyBuilder.addAllBags(keyRequest.getBagsToFetchList()); + keyBuilder.addAllWatermarkHolds(keyRequest.getWatermarkHoldsToFetchList()); + } + } + return builder.build(); + }); + + server.whenGetWorkCalled().thenReturn(batchInput); + + Map result = server.waitForAndGetCommits(3); + + assertEquals(3, result.size()); + + List multiKeyCommits = + server.getMultiKeyCommitsReceived(); + assertEquals(1, multiKeyCommits.size()); + Windmill.MultiKeyWorkItemCommitRequest multiKeyCommit = multiKeyCommits.get(0); + assertEquals(3, multiKeyCommit.getRequestsCount()); + assertEquals(1, multiKeyCommit.getRequests(0).getWorkToken()); + assertEquals(2, multiKeyCommit.getRequests(1).getWorkToken()); + 
assertEquals(3, multiKeyCommit.getRequests(2).getWorkToken()); + + worker.stop(); + } + + @Test + public void testMultiKeyCommit_elementFailure() throws Exception { + if (!streamingEngine) { + return; + } + KvCoder kvCoder = KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of()); + + List instructions = + Arrays.asList( + makeSourceInstruction(kvCoder), + makeDoFnInstruction(new WorkDoFn(), 0, kvCoder), + makeSinkInstruction(kvCoder, 1)); StreamingDataflowWorker worker = - makeWorker(defaultWorkerParams().setInstructions(instructions).publishCounters().build()); + makeWorker( + defaultWorkerParams( + "--experiments=unstable_enable_multi_key_bundle,windmill_max_key_group_batch_time_ms=5000", + "--numberOfWorkerHarnessThreads=1") + .setLocalRetryTimeoutMs(100) + .setInstructions(instructions) + .build()); worker.start(); - server.waitForEmptyWorkQueue(); + String batchInputText = + "work {" + + " computation_id: \"" + + DEFAULT_COMPUTATION_ID + + "\"" + + " input_data_watermark: 0" + + " work {" + + " key: \"key1\"" + + " sharding_key: 1" + + " work_token: 1" + + " cache_token: 2" + + " key_group { high: 0 low: 1 }" + + " message_bundles {" + + " source_computation_id: \"" + + DEFAULT_SOURCE_COMPUTATION_ID + + "\"" + + " messages {" + + " timestamp: 0" + + " data: \"data1\"" + + " }" + + " }" + + " }" + + " work {" + + " key: \"key2\"" + + " sharding_key: 2" + + " work_token: 2" + + " cache_token: 3" + + " key_group { high: 0 low: 1 }" + + " message_bundles {" + + " source_computation_id: \"" + + DEFAULT_SOURCE_COMPUTATION_ID + + "\"" + + " messages {" + + " timestamp: 0" + + " data: \"data2\"" + + " }" + + " }" + + " }" + + " work {" + + " key: \"key3\"" + + " sharding_key: 3" + + " work_token: 3" + + " cache_token: 4" + + " key_group { high: 0 low: 1 }" + + " message_bundles {" + + " source_computation_id: \"" + + DEFAULT_SOURCE_COMPUTATION_ID + + "\"" + + " messages {" + + " timestamp: 0" + + " data: \"data3\"" + + " }" + + " }" + + " }" + + "}"; + 
Windmill.GetWorkResponse batchInput = + buildInput( + batchInputText, + CoderUtils.encodeToByteArray( + CollectionCoder.of(IntervalWindow.getCoder()), + Collections.singletonList(DEFAULT_WINDOW))); server - .whenGetWorkCalled() - .thenReturn(makeInput(1, 0, DEFAULT_KEY_STRING, DEFAULT_SHARDING_KEY)); + .whenGetDataCalled() + .answerByDefault( + request -> { + Windmill.GetDataResponse.Builder builder = Windmill.GetDataResponse.newBuilder(); + for (ComputationGetDataRequest compRequest : request.getRequestsList()) { + ComputationGetDataResponse.Builder compBuilder = + builder.addDataBuilder().setComputationId(compRequest.getComputationId()); + for (KeyedGetDataRequest keyRequest : compRequest.getRequestsList()) { + KeyedGetDataResponse.Builder keyBuilder = + compBuilder + .addDataBuilder() + .setKey(keyRequest.getKey()) + .setShardingKey(keyRequest.getShardingKey()); + if (keyRequest.getWorkToken() == 2) { + keyBuilder.setFailed(true); + } else { + keyBuilder.addAllValues(keyRequest.getValuesToFetchList()); + keyBuilder.addAllBags(keyRequest.getBagsToFetchList()); + keyBuilder.addAllWatermarkHolds(keyRequest.getWatermarkHoldsToFetchList()); + } + } + } + return builder.build(); + }); + + server.whenGetWorkCalled().thenReturn(batchInput); + + Map result = server.waitForAndGetCommits(2); + + assertTrue(result.containsKey(1L)); + assertTrue(result.containsKey(3L)); + assertFalse(result.containsKey(2L)); + + List multiKeyCommits = + server.getMultiKeyCommitsReceived(); + assertEquals(1, multiKeyCommits.size()); + Windmill.MultiKeyWorkItemCommitRequest multiKeyCommit = multiKeyCommits.get(0); + assertEquals(2, multiKeyCommit.getRequestsCount()); + assertEquals(3, multiKeyCommit.getRequests(0).getWorkToken()); + assertEquals(1, multiKeyCommit.getRequests(1).getWorkToken()); + + worker.stop(); + } + + @Test + public void testCompleteCommit_retryableFailureTriggersReExecution() throws Exception { + if (!streamingEngine) { + return; + } + KvCoder kvCoder = 
KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of()); + + List instructions = + Arrays.asList( + makeSourceInstruction(kvCoder), + makeDoFnInstruction(new WorkDoFn(), 0, kvCoder), + makeSinkInstruction(kvCoder, 1)); + + StreamingDataflowWorker worker = + makeWorker( + defaultWorkerParams( + "--experiments=unstable_enable_multi_key_bundle,max_key_group_batch_time_ms=5000", + "--numberOfWorkerHarnessThreads=1") + .setLocalRetryTimeoutMs(100) + .setInstructions(instructions) + .build()); + worker.start(); + + String batchInputText = + "work {" + + " computation_id: \"" + + DEFAULT_COMPUTATION_ID + + "\"" + + " input_data_watermark: 0" + + " work {" + + " key: \"key1\"" + + " sharding_key: 1" + + " work_token: 1" + + " cache_token: 2" + + " key_group { high: 0 low: 1 }" + + " message_bundles {" + + " source_computation_id: \"" + + DEFAULT_SOURCE_COMPUTATION_ID + + "\"" + + " messages {" + + " timestamp: 0" + + " data: \"data1\"" + + " }" + + " }" + + " }" + + " work {" + + " key: \"key2\"" + + " sharding_key: 2" + + " work_token: 2" + + " cache_token: 3" + + " key_group { high: 0 low: 1 }" + + " message_bundles {" + + " source_computation_id: \"" + + DEFAULT_SOURCE_COMPUTATION_ID + + "\"" + + " messages {" + + " timestamp: 0" + + " data: \"data2\"" + + " }" + + " }" + + " }" + + "}"; + Windmill.GetWorkResponse batchInput = + buildInput( + batchInputText, + CoderUtils.encodeToByteArray( + CollectionCoder.of(IntervalWindow.getCoder()), + Collections.singletonList(DEFAULT_WINDOW))); + + server + .whenGetDataCalled() + .answerByDefault( + request -> { + Windmill.GetDataResponse.Builder builder = Windmill.GetDataResponse.newBuilder(); + for (ComputationGetDataRequest compRequest : request.getRequestsList()) { + ComputationGetDataResponse.Builder compBuilder = + builder.addDataBuilder().setComputationId(compRequest.getComputationId()); + for (KeyedGetDataRequest keyRequest : compRequest.getRequestsList()) { + KeyedGetDataResponse.Builder keyBuilder = + compBuilder + 
.addDataBuilder() + .setKey(keyRequest.getKey()) + .setShardingKey(keyRequest.getShardingKey()); + if (keyRequest.getWorkToken() == 2) { + keyBuilder.setFailed(true); + } else { + keyBuilder.addAllValues(keyRequest.getValuesToFetchList()); + keyBuilder.addAllBags(keyRequest.getBagsToFetchList()); + keyBuilder.addAllWatermarkHolds(keyRequest.getWatermarkHoldsToFetchList()); + } + } + } + return builder.build(); + }); + + server.whenGetWorkCalled().thenReturn(batchInput); Map result = server.waitForAndGetCommits(1); - assertEquals( - makeExpectedOutput(1, 0, DEFAULT_KEY_STRING, DEFAULT_SHARDING_KEY, DEFAULT_KEY_STRING) - .build(), - removeDynamicFields(result.get(1L))); - assertEquals(1, result.size()); + assertTrue(result.containsKey(1L)); + assertFalse(result.containsKey(2L)); + + List multiKeyCommits = + server.getMultiKeyCommitsReceived(); + assertEquals(1, multiKeyCommits.size()); + Windmill.MultiKeyWorkItemCommitRequest multiKeyCommit = multiKeyCommits.get(0); + assertEquals(1, multiKeyCommit.getRequestsCount()); + assertEquals(1, multiKeyCommit.getRequests(0).getWorkToken()); worker.stop(); } @@ -4543,18 +4871,19 @@ public void evaluate() throws Throwable { } } - static class KeyTokenInvalidFn extends DoFn, KV> { - - static boolean thrown = false; + static class WorkDoFn extends DoFn, KV> { + @StateId("state") + private final StateSpec> stateSpec = StateSpecs.value(StringUtf8Coder.of()); @ProcessElement - public void processElement(ProcessContext c) { - if (!thrown) { - thrown = true; - throw new KeyTokenInvalidException("key"); - } else { - c.output(c.element()); + public void processElement(ProcessContext c, @StateId("state") ValueState state) { + try { + Thread.sleep(1000); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); } + state.read(); + c.output(c.element()); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java 
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java index 534d51e2b88c..ce5d68b1a526 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java @@ -24,6 +24,7 @@ import static org.hamcrest.Matchers.equalTo; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.mockito.Mockito.mock; @@ -57,23 +58,26 @@ import org.apache.beam.runners.dataflow.worker.counters.NameContext; import org.apache.beam.runners.dataflow.worker.profiler.ScopedProfiler.NoopProfileScope; import org.apache.beam.runners.dataflow.worker.profiler.ScopedProfiler.ProfileScope; +import org.apache.beam.runners.dataflow.worker.streaming.BoundedQueueExecutorWorkHandle; +import org.apache.beam.runners.dataflow.worker.streaming.ExecutableWork; import org.apache.beam.runners.dataflow.worker.streaming.Watermarks; import org.apache.beam.runners.dataflow.worker.streaming.Work; import org.apache.beam.runners.dataflow.worker.streaming.config.FakeGlobalConfigHandle; import org.apache.beam.runners.dataflow.worker.streaming.config.StreamingGlobalConfig; import org.apache.beam.runners.dataflow.worker.streaming.config.StreamingGlobalConfigHandle; import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputStateFetcherFactory; +import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor; import org.apache.beam.runners.dataflow.worker.util.common.worker.WorkExecutor; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.FakeGetDataClient; import 
org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateCache; -import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateReader; import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillTagEncodingV1; import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillTagEncodingV2; import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.metrics.MetricsContainer; +import org.apache.beam.sdk.options.ExperimentalOptions; import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.state.TimeDomain; import org.apache.beam.sdk.transforms.Create; @@ -100,7 +104,7 @@ public class StreamingModeExecutionContextTest { @Rule public transient Timeout globalTimeout = Timeout.seconds(600); - @Mock private WindmillStateReader stateReader; + @Mock private WorkExecutor workExecutor; private static final String COMPUTATION_ID = "computationId"; @@ -112,7 +116,7 @@ public class StreamingModeExecutionContextTest { private FakeGlobalConfigHandle globalConfigHandle; private StreamingModeExecutionContext createExecutionContext( - StreamingGlobalConfigHandle configHandle) { + DataflowWorkerHarnessOptions options, StreamingGlobalConfigHandle configHandle) { CounterSet counterSet = new CounterSet(); ConcurrentHashMap stateNameMap = new ConcurrentHashMap<>(); stateNameMap.put(NameContextsForTests.nameContextForTest().userName(), "testStateFamily"); @@ -141,6 +145,7 @@ private StreamingModeExecutionContext createExecutionContext( /*hotKeyLoggingEnabled=*/ false, /*stepName=*/ "stepName", "sourceBytesProcessCounterName", + options, SideInputStateFetcherFactory.fromOptions(options)); } @@ -149,7 +154,7 @@ public void setUp() { MockitoAnnotations.initMocks(this); options = PipelineOptionsFactory.as(DataflowWorkerHarnessOptions.class); globalConfigHandle = new 
FakeGlobalConfigHandle(StreamingGlobalConfig.builder().build()); - executionContext = createExecutionContext(globalConfigHandle); + executionContext = createExecutionContext(options, globalConfigHandle); } private static Work createMockWork(Windmill.WorkItem workItem, Watermarks watermarks) { @@ -178,7 +183,6 @@ private void start(StreamingModeExecutionContext context, Work work) { private void start(StreamingModeExecutionContext context, Work work, Coder keyCoder) { context.start( work, - stateReader, workExecutor, /* workQueueExecutor= */ null, /* budgetHandle= */ null, @@ -443,7 +447,7 @@ public void testStateTagEncodingBasedOnConfig() { FakeGlobalConfigHandle configHandle = new FakeGlobalConfigHandle( StreamingGlobalConfig.builder().setEnableStateTagEncodingV2(isV2Encoding).build()); - StreamingModeExecutionContext context = createExecutionContext(configHandle); + StreamingModeExecutionContext context = createExecutionContext(options, configHandle); assertEquals(expectedEncoding, context.getWindmillTagEncoding().getClass()); } } @@ -496,4 +500,184 @@ public void testStart_internalKeyDecoding() throws Exception { assertEquals("decodedKey", executionContext.getKey()); } + + @Test + public void testAdvance_success() throws Exception { + BoundedQueueExecutor mockExecutor = mock(BoundedQueueExecutor.class); + BoundedQueueExecutorWorkHandle mockHandle = mock(BoundedQueueExecutorWorkHandle.class); + + Windmill.Uint128Proto keyGroup = + Windmill.Uint128Proto.newBuilder().setHigh(1).setLow(2).build(); + Windmill.WorkItem workItem1 = + Windmill.WorkItem.newBuilder() + .setKey(ByteString.copyFromUtf8("key1")) + .setWorkToken(1L) + .setKeyGroup(keyGroup) + .build(); + Work work1 = + createMockWork( + workItem1, Watermarks.builder().setInputDataWatermark(Instant.EPOCH).build()); + + Windmill.WorkItem workItem2 = + Windmill.WorkItem.newBuilder() + .setKey(ByteString.copyFromUtf8("key2")) + .setWorkToken(2L) + .setKeyGroup(keyGroup) + .build(); + Work work2 = + 
createMockWork( + workItem2, Watermarks.builder().setInputDataWatermark(Instant.EPOCH).build()); + ExecutableWork executableWork2 = ExecutableWork.create(work2, (w, h) -> {}); + + org.mockito.Mockito.when( + mockExecutor.pollWork( + org.mockito.Mockito.eq(COMPUTATION_ID), + org.mockito.Mockito.eq(work1.getKeyGroup()), + org.mockito.Mockito.eq(mockHandle))) + .thenReturn(executableWork2); + + executionContext.start( + work1, workExecutor, mockExecutor, mockHandle, null, (oldWork, newWork) -> {}); + + assertTrue(executionContext.advance()); + assertEquals("key2", executionContext.getSerializedKey().toStringUtf8()); + } + + @Test + public void testAdvance_noMoreWork() throws Exception { + BoundedQueueExecutor mockExecutor = mock(BoundedQueueExecutor.class); + BoundedQueueExecutorWorkHandle mockHandle = mock(BoundedQueueExecutorWorkHandle.class); + + Windmill.Uint128Proto keyGroup = + Windmill.Uint128Proto.newBuilder().setHigh(1).setLow(2).build(); + Windmill.WorkItem workItem1 = + Windmill.WorkItem.newBuilder() + .setKey(ByteString.copyFromUtf8("key1")) + .setWorkToken(1L) + .setKeyGroup(keyGroup) + .build(); + Work work1 = + createMockWork( + workItem1, Watermarks.builder().setInputDataWatermark(Instant.EPOCH).build()); + + org.mockito.Mockito.when( + mockExecutor.pollWork( + org.mockito.Mockito.eq(COMPUTATION_ID), + org.mockito.Mockito.eq(work1.getKeyGroup()), + org.mockito.Mockito.eq(mockHandle))) + .thenReturn(null); + + executionContext.start( + work1, workExecutor, mockExecutor, mockHandle, null, (oldWork, newWork) -> {}); + + assertFalse(executionContext.advance()); + } + + @Test + public void testAdvance_respectsMaxBatchSize() throws Exception { + DataflowWorkerHarnessOptions optionsWithBatchSize = + PipelineOptionsFactory.as(DataflowWorkerHarnessOptions.class); + optionsWithBatchSize + .as(ExperimentalOptions.class) + .setExperiments(Arrays.asList("max_key_group_batch_size=1")); + StreamingModeExecutionContext context = + 
createExecutionContext(optionsWithBatchSize, globalConfigHandle); + + BoundedQueueExecutor mockExecutor = mock(BoundedQueueExecutor.class); + BoundedQueueExecutorWorkHandle mockHandle = mock(BoundedQueueExecutorWorkHandle.class); + + Windmill.Uint128Proto keyGroup = + Windmill.Uint128Proto.newBuilder().setHigh(1).setLow(2).build(); + Windmill.WorkItem workItem1 = + Windmill.WorkItem.newBuilder() + .setKey(ByteString.copyFromUtf8("key1")) + .setWorkToken(1L) + .setKeyGroup(keyGroup) + .build(); + Work work1 = + createMockWork( + workItem1, Watermarks.builder().setInputDataWatermark(Instant.EPOCH).build()); + + context.start(work1, workExecutor, mockExecutor, mockHandle, null, (oldWork, newWork) -> {}); + + assertFalse(context.advance()); + org.mockito.Mockito.verifyNoInteractions(mockExecutor); + } + + @Test + public void testAdvance_respectsMaxBatchTime() throws Exception { + DataflowWorkerHarnessOptions optionsWithBatchTime = + PipelineOptionsFactory.as(DataflowWorkerHarnessOptions.class); + optionsWithBatchTime + .as(ExperimentalOptions.class) + .setExperiments(Arrays.asList("max_key_group_batch_time_ms=0")); + StreamingModeExecutionContext context = + createExecutionContext(optionsWithBatchTime, globalConfigHandle); + + BoundedQueueExecutor mockExecutor = mock(BoundedQueueExecutor.class); + BoundedQueueExecutorWorkHandle mockHandle = mock(BoundedQueueExecutorWorkHandle.class); + + Windmill.Uint128Proto keyGroup = + Windmill.Uint128Proto.newBuilder().setHigh(1).setLow(2).build(); + Windmill.WorkItem workItem1 = + Windmill.WorkItem.newBuilder() + .setKey(ByteString.copyFromUtf8("key1")) + .setWorkToken(1L) + .setKeyGroup(keyGroup) + .build(); + Work work1 = + createMockWork( + workItem1, Watermarks.builder().setInputDataWatermark(Instant.EPOCH).build()); + + context.start(work1, workExecutor, mockExecutor, mockHandle, null, (oldWork, newWork) -> {}); + + assertFalse(context.advance()); + org.mockito.Mockito.verifyNoInteractions(mockExecutor); + } + + @Test + 
public void testAdvance_workFailed() throws Exception { + BoundedQueueExecutor mockExecutor = mock(BoundedQueueExecutor.class); + BoundedQueueExecutorWorkHandle mockHandle = mock(BoundedQueueExecutorWorkHandle.class); + + Windmill.Uint128Proto keyGroup = + Windmill.Uint128Proto.newBuilder().setHigh(1).setLow(2).build(); + Windmill.WorkItem workItem1 = + Windmill.WorkItem.newBuilder() + .setKey(ByteString.copyFromUtf8("key1")) + .setWorkToken(1L) + .setKeyGroup(keyGroup) + .build(); + Work work1 = + createMockWork( + workItem1, Watermarks.builder().setInputDataWatermark(Instant.EPOCH).build()); + + executionContext.start( + work1, workExecutor, mockExecutor, mockHandle, null, (oldWork, newWork) -> {}); + + work1.setFailed(); + + assertThrows(WorkItemCancelledException.class, () -> executionContext.advance()); + } + + @Test + public void testAdvance_defaultKeyGroup() throws Exception { + BoundedQueueExecutor mockExecutor = mock(BoundedQueueExecutor.class); + BoundedQueueExecutorWorkHandle mockHandle = mock(BoundedQueueExecutorWorkHandle.class); + + Windmill.WorkItem workItem1 = + Windmill.WorkItem.newBuilder() + .setKey(ByteString.copyFromUtf8("key1")) + .setWorkToken(1L) + .build(); + Work work1 = + createMockWork( + workItem1, Watermarks.builder().setInputDataWatermark(Instant.EPOCH).build()); + + executionContext.start( + work1, workExecutor, mockExecutor, mockHandle, null, (oldWork, newWork) -> {}); + + assertFalse(executionContext.advance()); + org.mockito.Mockito.verifyNoInteractions(mockExecutor); + } } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java index 0af802ec6760..9d43b62d7c38 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java +++ 
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java @@ -102,7 +102,6 @@ import org.apache.beam.runners.dataflow.worker.windmill.Windmill; import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.FakeGetDataClient; import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateCache; -import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateReader; import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.coders.BigEndianIntegerCoder; @@ -213,7 +212,6 @@ COMPUTATION_ID, new FakeGetDataClient(), ignored -> {}, mock(HeartbeatSender.cla private void startContext(StreamingModeExecutionContext context, Work work) { context.start( work, - mock(WindmillStateReader.class), mock(WorkExecutor.class), /* workQueueExecutor= */ null, /* budgetHandle= */ null, @@ -637,6 +635,7 @@ public void testReadUnboundedReader() throws Exception { /*hotKeyLoggingEnabled=*/ false, /*stepName=*/ "stepName", "sourceBytesProcessCounterName", + options, SideInputStateFetcherFactory.fromOptions(options)); options.setNumWorkers(5); @@ -1010,6 +1009,7 @@ public void testFailedWorkItemsAbort() throws Exception { /*hotKeyLoggingEnabled=*/ false, /*stepName=*/ "stepName", "sourceBytesProcessCounterName", + options, SideInputStateFetcherFactory.fromOptions(options)); options.setNumWorkers(5); diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkStateTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkStateTest.java index 0f14efdd0c0b..60d7bb71a9de 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkStateTest.java +++ 
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkStateTest.java @@ -565,6 +565,31 @@ public void testFailWork_batchFail() { } } + @Test + public void testGetActiveWork() { + ShardedKey shardedKey = shardedKey("someKey", 1L); + ExecutableWork work = createWork(createWorkItem(1L, 1L, shardedKey)); + + // Initially empty + assertFalse(activeWorkState.getActiveWork(shardedKey, work.id()).isPresent()); + + // Activate work + activeWorkState.activateWorkForKey(work); + + // Should find it now + Optional activeWork = activeWorkState.getActiveWork(shardedKey, work.id()); + assertTrue(activeWork.isPresent()); + assertSame(work, activeWork.get()); + + // Should not find it with different workId + assertFalse(activeWorkState.getActiveWork(shardedKey, workId(2L, 1L)).isPresent()); + assertFalse(activeWorkState.getActiveWork(shardedKey, workId(1L, 2L)).isPresent()); + + // Should not find it with different shardedKey + ShardedKey otherShardedKey = shardedKey("otherKey", 2L); + assertFalse(activeWorkState.getActiveWork(otherShardedKey, work.id()).isPresent()); + } + private static ExecutableWork firstValue(Map map) { Iterator> iterator = map.entrySet().iterator(); if (iterator.hasNext()) { diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationStateTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationStateTest.java new file mode 100644 index 000000000000..6a6edd2b7192 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationStateTest.java @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.streaming; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.verifyNoMoreInteractions; + +import com.google.api.services.dataflow.model.MapTask; +import java.util.Collections; +import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill; +import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.FakeGetDataClient; +import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateCache; +import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender; +import org.apache.beam.vendor.grpc.v1p69p0.com.google.protobuf.ByteString; +import org.joda.time.Instant; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class ComputationStateTest { + + private final BoundedQueueExecutor mockExecutor = mock(BoundedQueueExecutor.class); + private final WindmillStateCache.ForComputation mockStateCache = + mock(WindmillStateCache.ForComputation.class); + private final HeartbeatSender mockHeartbeatSender = mock(HeartbeatSender.class); + + private ComputationState computationState; + + private static ShardedKey 
shardedKey(String str, long shardKey) { + return ShardedKey.create(ByteString.copyFromUtf8(str), shardKey); + } + + private ExecutableWork createWork(Windmill.WorkItem workItem) { + return ExecutableWork.create( + Work.create( + workItem, + workItem.getSerializedSize(), + Watermarks.builder().setInputDataWatermark(Instant.EPOCH).build(), + Work.createProcessingContext( + "computationId", new FakeGetDataClient(), ignored -> {}, mockHeartbeatSender), + false, + Instant::now), + (work, handle) -> {}); + } + + private static Windmill.WorkItem createWorkItem( + long workToken, long cacheToken, ShardedKey shardedKey) { + return Windmill.WorkItem.newBuilder() + .setShardingKey(shardedKey.shardingKey()) + .setKey(shardedKey.key()) + .setWorkToken(workToken) + .setCacheToken(cacheToken) + .build(); + } + + @Before + public void setUp() { + MapTask mapTask = new MapTask(); + mapTask.setStageName("stage"); + mapTask.setSystemName("system"); + computationState = + new ComputationState( + "computationId", mapTask, mockExecutor, Collections.emptyMap(), mockStateCache); + } + + @Test + public void testReExecuteActiveWork_workNotActive() { + ShardedKey shardedKey = shardedKey("key", 1L); + WorkId workId = WorkId.builder().setWorkToken(1L).setCacheToken(1L).build(); + + computationState.reExecuteActiveWork(shardedKey, workId); + + verifyNoInteractions(mockExecutor); + } + + @Test + public void testReExecuteActiveWork_workActive() { + ShardedKey shardedKey = shardedKey("key", 1L); + Windmill.WorkItem workItem = createWorkItem(1L, 1L, shardedKey); + ExecutableWork work = createWork(workItem); + + // Activate work first. This will execute it once. 
+ computationState.activateWork(work); + verify(mockExecutor).execute(work, work.work().getSerializedWorkItemSize()); + + // Now re-execute + computationState.reExecuteActiveWork(shardedKey, work.id()); + verify(mockExecutor).forceExecute(work, work.work().getSerializedWorkItemSize()); + + verifyNoMoreInteractions(mockExecutor); + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutorTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutorTest.java index a98102751fb2..245d600448fe 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutorTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutorTest.java @@ -30,7 +30,10 @@ import java.util.Collection; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.BiConsumer; import java.util.function.Consumer; +import org.apache.beam.runners.dataflow.worker.streaming.BoundedQueueExecutorWorkHandle; import org.apache.beam.runners.dataflow.worker.streaming.ExecutableWork; import org.apache.beam.runners.dataflow.worker.streaming.Watermarks; import org.apache.beam.runners.dataflow.worker.streaming.Work; @@ -82,6 +85,14 @@ private static ExecutableWork createWorkWithCompId( private static ExecutableWork createWorkWithCompIdAndKeyGroup( String computationId, Work.KeyGroup keyGroup, Consumer executeWorkFn) { + return createWorkWithHandle( + computationId, keyGroup, (work, handle) -> executeWorkFn.accept(work)); + } + + private static ExecutableWork createWorkWithHandle( + String computationId, + Work.KeyGroup keyGroup, + BiConsumer executeWorkFn) { WorkItem workItem = WorkItem.newBuilder() 
.setKey(ByteString.EMPTY) @@ -103,9 +114,7 @@ private static ExecutableWork createWorkWithCompIdAndKeyGroup( computationId, new FakeGetDataClient(), ignored -> {}, mock(HeartbeatSender.class)), false, Instant::now), - (work, handle) -> { - executeWorkFn.accept(work); - }); + executeWorkFn); } private ExecutableWork createSleepProcessWork(CountDownLatch start, CountDownLatch stop) { @@ -406,18 +415,25 @@ public void testRunnableExceptionPropagationDecrementsCounters() throws Exceptio @Test public void testHandleMerge() throws Exception { - BoundedQueueExecutorWorkHandleImpl handle1 = executor.createBudgetHandle(1, 100L); - BoundedQueueExecutorWorkHandleImpl handle2 = executor.createBudgetHandle(2, 200L); + Work work1 = createWork(ignored -> {}).work(); + Work work2 = createWork(ignored -> {}).work(); + Work work3 = createWork(ignored -> {}).work(); + BoundedQueueExecutorWorkHandleImpl handle1 = executor.createBudgetHandle(work1, 100L); + BoundedQueueExecutorWorkHandleImpl handle2 = executor.createBudgetHandle(work2, 200L); + handle2.merge(executor.createBudgetHandle(work3, 0L)); handle1.merge(handle2); // Verify that handle2 has 0 budget and is closed. - assertEquals(0, handle2.elements()); + assertEquals(0, handle2.getWorkBatch().size()); assertEquals(0, handle2.bytes()); assertTrue(handle2.isClosed()); // Verify that handle1 has the combined budget and is not closed. - assertEquals(3, handle1.elements()); + assertEquals(3, handle1.getWorkBatch().size()); + assertTrue(handle1.getWorkBatch().contains(work1)); + assertTrue(handle1.getWorkBatch().contains(work2)); + assertTrue(handle1.getWorkBatch().contains(work3)); assertEquals(300L, handle1.bytes()); assertFalse(handle1.isClosed()); } @@ -449,11 +465,13 @@ public void testPollWork() throws Exception { // 1. 
Create blocker task to occupy the worker thread CountDownLatch blockerStart = new CountDownLatch(1); CountDownLatch blockerStop = new CountDownLatch(1); + AtomicReference blockerHandleRef = new AtomicReference<>(); ExecutableWork blockerWork = - createWorkWithCompIdAndKeyGroup( + createWorkWithHandle( "blockerComp", DEFAULT_KEY_GROUP, - ignored -> { + (work, handle) -> { + blockerHandleRef.set(handle); blockerStart.countDown(); try { blockerStop.await(); @@ -464,6 +482,9 @@ public void testPollWork() throws Exception { testExecutor.execute(blockerWork, 0); blockerStart.await(); + BoundedQueueExecutorWorkHandleImpl stealHandle = + (BoundedQueueExecutorWorkHandleImpl) blockerHandleRef.get(); + assertNotNull(stealHandle); // 2. Create two distinct key groups Work.KeyGroup keyGroup1 = Work.KeyGroup.create(1, 1); @@ -488,22 +509,18 @@ public void testPollWork() throws Exception { assertEquals(3, testExecutor.elementsOutstanding()); // Steal work2 using pollWork with compA and keyGroup2 - try (BoundedQueueExecutorWorkHandleImpl stealHandle = testExecutor.createBudgetHandle(0, 0L)) { - ExecutableWork stolen = testExecutor.pollWork("compA", keyGroup2, stealHandle); - assertNotNull(stolen); - assertEquals(work2, stolen); - - // Run the stolen task - stolen.run(stealHandle); - targetStart.await(); - } + ExecutableWork stolen = testExecutor.pollWork("compA", keyGroup2, stealHandle); + assertNotNull(stolen); + assertEquals(work2, stolen); + + // Run the stolen task + stolen.run(stealHandle); + targetStart.await(); // Steal work1 using pollWork with compA and keyGroup1 - try (BoundedQueueExecutorWorkHandleImpl stealHandle = testExecutor.createBudgetHandle(0, 0L)) { - ExecutableWork stolen = testExecutor.pollWork("compA", keyGroup1, stealHandle); - assertNotNull(stolen); - assertEquals(work1, stolen); - } + ExecutableWork stolen1 = testExecutor.pollWork("compA", keyGroup1, stealHandle); + assertNotNull(stolen1); + assertEquals(work1, stolen1); // Unblock the blocker and shut 
down blockerStop.countDown(); @@ -525,11 +542,13 @@ public void testPollWorkWithLinkedBlockingQueue() throws Exception { CountDownLatch blockerStart = new CountDownLatch(1); CountDownLatch blockerStop = new CountDownLatch(1); + AtomicReference blockerHandleRef = new AtomicReference<>(); ExecutableWork blockerWork = - createWorkWithCompIdAndKeyGroup( + createWorkWithHandle( "blockerComp", DEFAULT_KEY_GROUP, - ignored -> { + (work, handle) -> { + blockerHandleRef.set(handle); blockerStart.countDown(); try { blockerStop.await(); @@ -540,15 +559,16 @@ public void testPollWorkWithLinkedBlockingQueue() throws Exception { testExecutor.execute(blockerWork, 0); blockerStart.await(); + BoundedQueueExecutorWorkHandleImpl stealHandle = + (BoundedQueueExecutorWorkHandleImpl) blockerHandleRef.get(); + assertNotNull(stealHandle); Work.KeyGroup keyGroup = Work.KeyGroup.create(1, 1); ExecutableWork work = createWorkWithCompIdAndKeyGroup("compA", keyGroup, ignored -> {}); testExecutor.execute(work, 100); - try (BoundedQueueExecutorWorkHandleImpl stealHandle = testExecutor.createBudgetHandle(0, 0L)) { - ExecutableWork stolen = testExecutor.pollWork("compA", keyGroup, stealHandle); - assertNull(stolen); - } + ExecutableWork stolen = testExecutor.pollWork("compA", keyGroup, stealHandle); + assertNull(stolen); blockerStop.countDown(); testExecutor.shutdown(); diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/KeyGroupWorkQueueTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/KeyGroupWorkQueueTest.java index 994aa2030f3f..307cbde36989 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/KeyGroupWorkQueueTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/KeyGroupWorkQueueTest.java @@ -63,7 +63,6 @@ public static Iterable data() { } 
@Parameterized.Parameter public boolean fairQueue; - private BoundedQueueExecutor executor; @Before @@ -116,7 +115,7 @@ private QueuedWork createQueuedWork( false, Instant::now), (w, h) -> {}); - return new QueuedWork(work, executor.createBudgetHandle(1, workBytes)); + return new QueuedWork(work, executor.createBudgetHandle(work.work(), workBytes)); } private static class NoOpRunnable implements Runnable { @@ -312,7 +311,6 @@ public String toString() { } })); } - // Start producers for (int i = 0; i < producerThreads; i++) { futures.add( @@ -470,7 +468,6 @@ public void testPollWorkWithKeyGroup() { QueuedWork polledNotExist = queue.pollWork("compA", keyGroupNotExist); assertNull(polledNotExist); assertEquals(2, queue.size()); - // Poll with keyGroup2 first - should return workA2 QueuedWork polledA2 = queue.pollWork("compA", keyGroup2); assertNotNull(polledA2); @@ -485,7 +482,6 @@ public void testPollWorkWithKeyGroup() { assertNotNull(polledA1); assertEquals(workA1, polledA1); assertTrue(queue.isEmpty()); - polledNotExist = queue.pollWork("compA", keyGroupNotExist); assertNull(polledNotExist); assertTrue(queue.isEmpty()); diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java index 5e5fd9ce6420..a48159338132 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java @@ -139,7 +139,8 @@ private static ComputationState createComputationState(String computationId) { private static CompleteCommit asCompleteCommit( String 
computationId, Work work, Windmill.CommitStatus status) { Windmill.CommitStatus finalStatus = work.isFailed() ? Windmill.CommitStatus.ABORTED : status; - return CompleteCommit.create(computationId, work.getShardedKey(), work.id(), finalStatus); + return CompleteCommit.create( + computationId, work.getShardedKey(), work.id(), finalStatus, /* retryableFailure= */ false); } @Before @@ -394,6 +395,7 @@ public void shutdown() {} assertThat(commits.size()).isEqualTo(completeCommits.size()); for (CompleteCommit completeCommit : completeCommits) { assertThat(completeCommit.status()).isEqualTo(Windmill.CommitStatus.ABORTED); + assertThat(completeCommit.retryableFailure()).isFalse(); } for (Commit commit : commits) { @@ -559,11 +561,28 @@ public void testCommit_multiKeyCommitFailedWork() { assertThat(completeCommits) .containsExactly( CompleteCommit.create( - "computationId", workA.getShardedKey(), workA.id(), CommitStatus.ABORTED), + "computationId", + workA.getShardedKey(), + workA.id(), + CommitStatus.ABORTED, + /* retryableFailure= */ true), CompleteCommit.create( - "computationId", workB.getShardedKey(), workB.id(), CommitStatus.ABORTED), + "computationId", + workB.getShardedKey(), + workB.id(), + CommitStatus.ABORTED, + /* retryableFailure= */ false), CompleteCommit.create( - "computationId", workC.getShardedKey(), workC.id(), CommitStatus.ABORTED)); + "computationId", + workC.getShardedKey(), + workC.id(), + CommitStatus.ABORTED, + /* retryableFailure= */ true)); + + // Verify that valid work was not marked failed + assertThat(workA.isFailed()).isFalse(); + assertThat(workC.isFailed()).isFalse(); + assertThat(workB.isFailed()).isTrue(); workCommitter.stop(); } @@ -625,11 +644,23 @@ public void testCommit_multiKeyCommitSuccess() { assertThat(completeCommits) .containsExactly( CompleteCommit.create( - "computationId", workA.getShardedKey(), workA.id(), CommitStatus.OK), + "computationId", + workA.getShardedKey(), + workA.id(), + CommitStatus.OK, + /* retryableFailure= 
*/ false), CompleteCommit.create( - "computationId", workB.getShardedKey(), workB.id(), CommitStatus.OK), + "computationId", + workB.getShardedKey(), + workB.id(), + CommitStatus.OK, + /* retryableFailure= */ false), CompleteCommit.create( - "computationId", workC.getShardedKey(), workC.id(), CommitStatus.OK)); + "computationId", + workC.getShardedKey(), + workC.id(), + CommitStatus.OK, + /* retryableFailure= */ false)); workCommitter.stop(); } @@ -694,11 +725,23 @@ public void testCommit_multiKeyCommitStatusNotOK() { assertThat(completeCommits) .containsExactly( CompleteCommit.create( - "computationId", workA.getShardedKey(), workA.id(), CommitStatus.NOT_FOUND), + "computationId", + workA.getShardedKey(), + workA.id(), + CommitStatus.NOT_FOUND, + /* retryableFailure= */ false), CompleteCommit.create( - "computationId", workB.getShardedKey(), workB.id(), CommitStatus.NOT_FOUND), + "computationId", + workB.getShardedKey(), + workB.id(), + CommitStatus.NOT_FOUND, + /* retryableFailure= */ false), CompleteCommit.create( - "computationId", workC.getShardedKey(), workC.id(), CommitStatus.NOT_FOUND)); + "computationId", + workC.getShardedKey(), + workC.id(), + CommitStatus.NOT_FOUND, + /* retryableFailure= */ false)); workCommitter.stop(); } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStreamTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStreamTest.java index b83890c1dbdd..fc8348a68ce6 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStreamTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStreamTest.java @@ -1203,6 +1203,194 @@ public void testCommit_multiKeyCommit() throws Exception { 
assertThat(commitStatusFuture.get()).isEqualTo(Windmill.CommitStatus.OK); } + @Test + public void testCommit_multiKeyCommit_multichunk() throws Exception { + GrpcCommitWorkStream commitWorkStream = createCommitWorkStream(); + FakeWindmillGrpcService.CommitStreamInfo streamInfo = waitForConnectionAndConsumeHeader(); + + CompletableFuture commitStatusFuture = new CompletableFuture<>(); + + long shardingKey1 = 101L; + long workToken1 = 201L; + long cacheToken1 = 301L; + long shardingKey2 = 102L; + long workToken2 = 202L; + long cacheToken2 = 302L; + + Windmill.WorkItemCommitRequest request1 = + Windmill.WorkItemCommitRequest.newBuilder() + .setKey(ByteString.copyFromUtf8("key1")) + .setShardingKey(shardingKey1) + .setWorkToken(workToken1) + .setCacheToken(cacheToken1) + .addBagUpdates(Windmill.TagBag.newBuilder().setTag(LARGE_BYTE_STRING).build()) + .build(); + + Windmill.WorkItemCommitRequest request2 = + Windmill.WorkItemCommitRequest.newBuilder() + .setKey(ByteString.copyFromUtf8("key2")) + .setShardingKey(shardingKey2) + .setWorkToken(workToken2) + .setCacheToken(cacheToken2) + .build(); + + Windmill.MultiKeyWorkItemCommitRequest multiKeyRequest = + Windmill.MultiKeyWorkItemCommitRequest.newBuilder() + .addRequests(request1) + .addRequests(request2) + .build(); + + try (WindmillStream.CommitWorkStream.RequestBatcher batcher = commitWorkStream.batcher()) { + assertTrue( + batcher.commitMultiKeyWorkItem( + COMPUTATION_ID, multiKeyRequest, commitStatusFuture::complete)); + } + + Windmill.StreamingCommitWorkRequest requestChunk1 = streamInfo.requests.take(); + assertThat(requestChunk1.getCommitChunkCount()).isEqualTo(1); + Windmill.StreamingCommitRequestChunk chunk1 = requestChunk1.getCommitChunk(0); + + assertThat(chunk1.getCommitType()) + .isEqualTo(Windmill.StreamingCommitRequestChunk.CommitType.COMMIT_TYPE_MULTI_KEY); + assertThat(chunk1.getShardingKey()).isEqualTo(request1.getShardingKey()); + assertThat(chunk1.getRemainingBytesForWorkItem()).isGreaterThan(0); + 
+ Windmill.StreamingCommitWorkRequest requestChunk2 = streamInfo.requests.take(); + assertThat(requestChunk2.getCommitChunkCount()).isEqualTo(1); + Windmill.StreamingCommitRequestChunk chunk2 = requestChunk2.getCommitChunk(0); + + assertThat(chunk2.getCommitType()) + .isEqualTo(Windmill.StreamingCommitRequestChunk.CommitType.COMMIT_TYPE_MULTI_KEY); + assertThat(chunk2.getShardingKey()).isEqualTo(request1.getShardingKey()); + assertThat(chunk2.getRemainingBytesForWorkItem()).isEqualTo(0); + + ByteString reconstructedBytes = + chunk1.getSerializedWorkItemCommit().concat(chunk2.getSerializedWorkItemCommit()); + Windmill.MultiKeyWorkItemCommitRequest parsedRequest = + Windmill.MultiKeyWorkItemCommitRequest.parseFrom(reconstructedBytes); + assertThat(parsedRequest).isEqualTo(multiKeyRequest); + + long requestId = chunk1.getRequestId(); + assertThat(chunk2.getRequestId()).isEqualTo(requestId); + + streamInfo.responseObserver.onNext( + Windmill.StreamingCommitResponse.newBuilder().addRequestId(requestId).build()); + + assertThat(commitStatusFuture.get()).isEqualTo(Windmill.CommitStatus.OK); + } + + @Test + public void testCommitMultiKeyWorkItem_retryOnNewStream() throws Exception { + GrpcCommitWorkStream commitWorkStream = createCommitWorkStream(); + FakeWindmillGrpcService.CommitStreamInfo streamInfo = waitForConnectionAndConsumeHeader(); + + CompletableFuture commitStatusFuture = new CompletableFuture<>(); + + long shardingKey1 = 101L; + long workToken1 = 201L; + long cacheToken1 = 301L; + Windmill.WorkItemCommitRequest request1 = + Windmill.WorkItemCommitRequest.newBuilder() + .setKey(ByteString.copyFromUtf8("key1")) + .setShardingKey(shardingKey1) + .setWorkToken(workToken1) + .setCacheToken(cacheToken1) + .build(); + Windmill.MultiKeyWorkItemCommitRequest multiKeyRequest = + Windmill.MultiKeyWorkItemCommitRequest.newBuilder().addRequests(request1).build(); + + try (WindmillStream.CommitWorkStream.RequestBatcher batcher = commitWorkStream.batcher()) { + assertTrue( + 
batcher.commitMultiKeyWorkItem( + COMPUTATION_ID, multiKeyRequest, commitStatusFuture::complete)); + } + + Windmill.StreamingCommitWorkRequest request = streamInfo.requests.take(); + assertThat(request.getCommitChunkCount()).isEqualTo(1); + Windmill.StreamingCommitRequestChunk chunk = request.getCommitChunk(0); + assertThat(chunk.getCommitType()) + .isEqualTo(Windmill.StreamingCommitRequestChunk.CommitType.COMMIT_TYPE_MULTI_KEY); + long requestId = chunk.getRequestId(); + + streamInfo.responseObserver.onError(new IOException("test error")); + + FakeWindmillGrpcService.CommitStreamInfo reconnectStreamInfo = + waitForConnectionAndConsumeHeader(); + Windmill.StreamingCommitWorkRequest reconnectRequest = reconnectStreamInfo.requests.take(); + assertThat(reconnectRequest.getCommitChunkCount()).isEqualTo(1); + Windmill.StreamingCommitRequestChunk reconnectChunk = reconnectRequest.getCommitChunk(0); + assertThat(reconnectChunk.getCommitType()) + .isEqualTo(Windmill.StreamingCommitRequestChunk.CommitType.COMMIT_TYPE_MULTI_KEY); + assertThat(reconnectChunk.getRequestId()).isEqualTo(requestId); + + Windmill.MultiKeyWorkItemCommitRequest parsedRequest = + Windmill.MultiKeyWorkItemCommitRequest.parseFrom( + reconnectChunk.getSerializedWorkItemCommit()); + assertThat(parsedRequest).isEqualTo(multiKeyRequest); + + reconnectStreamInfo.responseObserver.onNext( + Windmill.StreamingCommitResponse.newBuilder().addRequestId(requestId).build()); + assertThat(commitStatusFuture.get()).isEqualTo(Windmill.CommitStatus.OK); + } + + @Test + public void testCommitWorkItem_retryOnNewStream_multichunk() throws Exception { + GrpcCommitWorkStream commitWorkStream = createCommitWorkStream(); + FakeWindmillGrpcService.CommitStreamInfo streamInfo = waitForConnectionAndConsumeHeader(); + + CompletableFuture commitStatusFuture = new CompletableFuture<>(); + + Windmill.WorkItemCommitRequest largeRequest = + workItemCommitRequest(1) + .toBuilder() + 
.addBagUpdates(Windmill.TagBag.newBuilder().setTag(LARGE_BYTE_STRING).build()) + .build(); + + try (WindmillStream.CommitWorkStream.RequestBatcher batcher = commitWorkStream.batcher()) { + assertTrue( + batcher.commitWorkItem(COMPUTATION_ID, largeRequest, commitStatusFuture::complete)); + } + + Windmill.StreamingCommitWorkRequest requestChunk1 = streamInfo.requests.take(); + assertThat(requestChunk1.getCommitChunkCount()).isEqualTo(1); + Windmill.StreamingCommitRequestChunk chunk1 = requestChunk1.getCommitChunk(0); + long requestId = chunk1.getRequestId(); + assertThat(chunk1.getRemainingBytesForWorkItem()).isGreaterThan(0); + + Windmill.StreamingCommitWorkRequest requestChunk2 = streamInfo.requests.take(); + assertThat(requestChunk2.getCommitChunkCount()).isEqualTo(1); + Windmill.StreamingCommitRequestChunk chunk2 = requestChunk2.getCommitChunk(0); + assertThat(chunk2.getRequestId()).isEqualTo(requestId); + assertThat(chunk2.getRemainingBytesForWorkItem()).isEqualTo(0); + + streamInfo.responseObserver.onError(new IOException("test error")); + + FakeWindmillGrpcService.CommitStreamInfo reconnectStreamInfo = + waitForConnectionAndConsumeHeader(); + + Windmill.StreamingCommitWorkRequest reconnectChunk1 = reconnectStreamInfo.requests.take(); + assertThat(reconnectChunk1.getCommitChunkCount()).isEqualTo(1); + Windmill.StreamingCommitRequestChunk reconChunk1 = reconnectChunk1.getCommitChunk(0); + assertThat(reconChunk1.getRequestId()).isEqualTo(requestId); + assertThat(reconChunk1.getRemainingBytesForWorkItem()).isGreaterThan(0); + + Windmill.StreamingCommitWorkRequest reconnectChunk2 = reconnectStreamInfo.requests.take(); + assertThat(reconnectChunk2.getCommitChunkCount()).isEqualTo(1); + Windmill.StreamingCommitRequestChunk reconChunk2 = reconnectChunk2.getCommitChunk(0); + assertThat(reconChunk2.getRequestId()).isEqualTo(requestId); + assertThat(reconChunk2.getRemainingBytesForWorkItem()).isEqualTo(0); + + ByteString reconstructedBytes = + 
reconChunk1.getSerializedWorkItemCommit().concat(reconChunk2.getSerializedWorkItemCommit()); + Windmill.WorkItemCommitRequest parsedRequest = + Windmill.WorkItemCommitRequest.parseFrom(reconstructedBytes); + assertThat(parsedRequest).isEqualTo(largeRequest); + + reconnectStreamInfo.responseObserver.onNext( + Windmill.StreamingCommitResponse.newBuilder().addRequestId(requestId).build()); + assertThat(commitStatusFuture.get()).isEqualTo(Windmill.CommitStatus.OK); + } + private FakeWindmillGrpcService.CommitStreamInfo waitForConnectionAndConsumeHeader() { try { FakeWindmillGrpcService.CommitStreamInfo info = fakeService.waitForConnectedCommitStream(); diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateReaderTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateReaderTest.java index 1611fdac25dc..65637437a0a0 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateReaderTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateReaderTest.java @@ -35,9 +35,9 @@ import java.util.Map; import java.util.Optional; import java.util.concurrent.Future; -import org.apache.beam.runners.dataflow.worker.KeyTokenInvalidException; import org.apache.beam.runners.dataflow.worker.WindmillStateTestUtils; import org.apache.beam.runners.dataflow.worker.WindmillTimeUtils; +import org.apache.beam.runners.dataflow.worker.WorkCancelingException; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataRequest; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.SortedListEntry; @@ -1572,16 +1572,16 @@ public void testKeyTokenInvalid() throws Exception { try { 
watermarkFuture.get(); - fail("Expected KeyTokenInvalidException"); + fail("Expected WorkCancelingException"); } catch (Exception e) { - assertTrue(KeyTokenInvalidException.isKeyTokenInvalidException(e)); + assertTrue(WorkCancelingException.isWorkCancelingException(e)); } try { bagFuture.get(); - fail("Expected KeyTokenInvalidException"); + fail("Expected WorkCancelingException"); } catch (Exception e) { - assertTrue(KeyTokenInvalidException.isKeyTokenInvalidException(e)); + assertTrue(WorkCancelingException.isWorkCancelingException(e)); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/failures/WorkFailureProcessorTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/failures/WorkFailureProcessorTest.java index 0610ed44c27f..ce9fe53f47d3 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/failures/WorkFailureProcessorTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/failures/WorkFailureProcessorTest.java @@ -21,15 +21,15 @@ import static org.junit.Assert.assertThrows; import static org.mockito.Mockito.mock; +import java.util.Arrays; import java.util.HashSet; +import java.util.List; import java.util.Optional; import java.util.Set; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.function.Consumer; import java.util.function.Supplier; -import org.apache.beam.runners.dataflow.worker.KeyTokenInvalidException; -import org.apache.beam.runners.dataflow.worker.WorkItemCancelledException; import org.apache.beam.runners.dataflow.worker.streaming.ExecutableWork; import org.apache.beam.runners.dataflow.worker.streaming.Watermarks; import org.apache.beam.runners.dataflow.worker.streaming.Work; @@ 
-109,38 +109,22 @@ private static ExecutableWork createWork(Consumer processWorkFn) { } @Test - public void logAndProcessFailure_doesNotRetryKeyTokenInvalidException() throws Throwable { + public void logAndProcessFailureBatch_doesNotRetryFailedWork() throws Throwable { Set executedWork = new HashSet<>(); ExecutableWork work = createWork(executedWork::add); + work.work().setFailed(); WorkFailureProcessor workFailureProcessor = createWorkFailureProcessor(streamingEngineFailureReporter()); Set invalidWork = new HashSet<>(); - workFailureProcessor.logAndProcessFailure( - DEFAULT_COMPUTATION_ID, work, new KeyTokenInvalidException("key"), invalidWork::add); + workFailureProcessor.logAndProcessFailureBatch( + DEFAULT_COMPUTATION_ID, List.of(work), new RuntimeException(), invalidWork::add); assertThat(executedWork).isEmpty(); assertThat(invalidWork).containsExactly(work.work()); } @Test - public void logAndProcessFailure_doesNotRetryWhenWorkItemCancelled() throws Throwable { - Set executedWork = new HashSet<>(); - ExecutableWork work = createWork(executedWork::add); - WorkFailureProcessor workFailureProcessor = - createWorkFailureProcessor(streamingEngineFailureReporter()); - Set invalidWork = new HashSet<>(); - workFailureProcessor.logAndProcessFailure( - DEFAULT_COMPUTATION_ID, - work, - new WorkItemCancelledException(work.getWorkItem().getShardingKey()), - invalidWork::add); - - assertThat(executedWork).isEmpty(); - assertThat(invalidWork).containsExactly(work.work()); - } - - @Test - public void logAndProcessFailure_doesNotRetryOOM() { + public void logAndProcessFailureBatch_doesNotRetryOOM() { Set executedWork = new HashSet<>(); ExecutableWork work = createWork(executedWork::add); WorkFailureProcessor workFailureProcessor = @@ -149,69 +133,120 @@ public void logAndProcessFailure_doesNotRetryOOM() { assertThrows( OutOfMemoryError.class, () -> - workFailureProcessor.logAndProcessFailure( - DEFAULT_COMPUTATION_ID, work, new OutOfMemoryError(), invalidWork::add)); + 
workFailureProcessor.logAndProcessFailureBatch( + DEFAULT_COMPUTATION_ID, + Arrays.asList(work), + new OutOfMemoryError(), + invalidWork::add)); assertThat(executedWork).isEmpty(); } @Test - public void logAndProcessFailure_doesNotRetryWhenFailureReporterMarksAsNonRetryable() + public void logAndProcessFailureBatch_doesNotRetryWhenFailureReporterMarksAsNonRetryable() throws Throwable { Set executedWork = new HashSet<>(); ExecutableWork work = createWork(executedWork::add); WorkFailureProcessor workFailureProcessor = createWorkFailureProcessor(streamingApplianceFailureReporter(true)); Set invalidWork = new HashSet<>(); - workFailureProcessor.logAndProcessFailure( - DEFAULT_COMPUTATION_ID, work, new RuntimeException(), invalidWork::add); + workFailureProcessor.logAndProcessFailureBatch( + DEFAULT_COMPUTATION_ID, Arrays.asList(work), new RuntimeException(), invalidWork::add); assertThat(executedWork).isEmpty(); assertThat(invalidWork).containsExactly(work.work()); } @Test - public void logAndProcessFailure_doesNotRetryAfterLocalRetryTimeout() throws Throwable { + public void logAndProcessFailureBatch_doesNotRetryAfterLocalRetryTimeout() throws Throwable { Set executedWork = new HashSet<>(); ExecutableWork veryOldWork = createWork(() -> Instant.now().minus(Duration.standardDays(30)), executedWork::add); WorkFailureProcessor workFailureProcessor = createWorkFailureProcessor(streamingEngineFailureReporter()); Set invalidWork = new HashSet<>(); - workFailureProcessor.logAndProcessFailure( - DEFAULT_COMPUTATION_ID, veryOldWork, new RuntimeException(), invalidWork::add); + workFailureProcessor.logAndProcessFailureBatch( + DEFAULT_COMPUTATION_ID, + Arrays.asList(veryOldWork), + new RuntimeException(), + invalidWork::add); assertThat(executedWork).isEmpty(); assertThat(invalidWork).contains(veryOldWork.work()); } @Test - public void logAndProcessFailure_retriesOnUncaughtUnhandledException_streamingEngine() + public void 
logAndProcessFailureBatch_retriesOnUncaughtUnhandledException_streamingEngine() throws Throwable { CountDownLatch runWork = new CountDownLatch(1); ExecutableWork work = createWork(ignored -> runWork.countDown()); WorkFailureProcessor workFailureProcessor = createWorkFailureProcessor(streamingEngineFailureReporter()); Set invalidWork = new HashSet<>(); - workFailureProcessor.logAndProcessFailure( - DEFAULT_COMPUTATION_ID, work, new RuntimeException(), invalidWork::add); + workFailureProcessor.logAndProcessFailureBatch( + DEFAULT_COMPUTATION_ID, Arrays.asList(work), new RuntimeException(), invalidWork::add); runWork.await(); assertThat(invalidWork).isEmpty(); } @Test - public void logAndProcessFailure_retriesOnUncaughtUnhandledException_streamingAppliance() + public void logAndProcessFailureBatch_retriesOnUncaughtUnhandledException_streamingAppliance() throws Throwable { CountDownLatch runWork = new CountDownLatch(1); ExecutableWork work = createWork(ignored -> runWork.countDown()); WorkFailureProcessor workFailureProcessor = createWorkFailureProcessor(streamingApplianceFailureReporter(false)); Set invalidWork = new HashSet<>(); - workFailureProcessor.logAndProcessFailure( - DEFAULT_COMPUTATION_ID, work, new RuntimeException(), invalidWork::add); + workFailureProcessor.logAndProcessFailureBatch( + DEFAULT_COMPUTATION_ID, Arrays.asList(work), new RuntimeException(), invalidWork::add); runWork.await(); assertThat(invalidWork).isEmpty(); } + + @Test + public void logAndProcessFailureBatch_retryAll() throws Throwable { + CountDownLatch runWork1 = new CountDownLatch(1); + CountDownLatch runWork2 = new CountDownLatch(1); + ExecutableWork work1 = createWork(ignored -> runWork1.countDown()); + ExecutableWork work2 = createWork(ignored -> runWork2.countDown()); + + WorkFailureProcessor workFailureProcessor = + createWorkFailureProcessor(streamingEngineFailureReporter()); + Set invalidWork = new HashSet<>(); + + workFailureProcessor.logAndProcessFailureBatch( + 
DEFAULT_COMPUTATION_ID, + Arrays.asList(work1, work2), + new RuntimeException(), + invalidWork::add); + + runWork1.await(); + runWork2.await(); + assertThat(invalidWork).isEmpty(); + } + + @Test + public void logAndProcessFailureBatch_mixRetryAndAbort() throws Throwable { + CountDownLatch runWork1 = new CountDownLatch(1); + Set executedWork2 = new HashSet<>(); + ExecutableWork work1 = createWork(ignored -> runWork1.countDown()); + ExecutableWork work2 = createWork(executedWork2::add); + work2.work().setFailed(); + + WorkFailureProcessor workFailureProcessor = + createWorkFailureProcessor(streamingEngineFailureReporter()); + Set invalidWork = new HashSet<>(); + + workFailureProcessor.logAndProcessFailureBatch( + DEFAULT_COMPUTATION_ID, + Arrays.asList(work1, work2), + new RuntimeException(), + invalidWork::add); + + runWork1.await(); + assertThat(executedWork2).isEmpty(); + assertThat(invalidWork).containsExactly(work2.work()); + } } From 9de3d846676ae0221082f8bf872641d3746e5a69 Mon Sep 17 00:00:00 2001 From: Arun Pandian Date: Thu, 11 Jun 2026 13:04:40 +0000 Subject: [PATCH 20/21] Fix tests and add sink byte limit for batching --- .../worker/DataflowExecutionContext.java | 4 + .../worker/StreamingDataflowWorker.java | 2 +- .../worker/StreamingModeExecutionContext.java | 77 +++++++++++++-- .../StreamingModeExecutionContextTest.java | 95 ++++++++++++++++++- 4 files changed, 166 insertions(+), 12 deletions(-) diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/DataflowExecutionContext.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/DataflowExecutionContext.java index 6ff05b4b4452..888e954c1c9f 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/DataflowExecutionContext.java +++ 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/DataflowExecutionContext.java @@ -150,6 +150,10 @@ boolean isSinkFullHintSet() { // the state size might grow unbounded. } + protected final long getBytesSinked() { + return bytesSinked; + } + /** * Sets a flag to indicate that a sink has enough data written to it. This hint is read by * upstream producers to stop producing if they can. Mainly used in streaming. diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java index 71a2547cc602..d03167540a88 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java @@ -181,7 +181,7 @@ public final class StreamingDataflowWorker { "windmill_bounded_queue_executor_use_fair_monitor"; // Don't use. Experiment guarding multi key bundles. The feature is work in progress and // incomplete. 
- private static final String UNSTABLE_ENABLE_MULTI_KEY_BUNDLE = "unstable_enable_multi_key_bundle"; + public static final String UNSTABLE_ENABLE_MULTI_KEY_BUNDLE = "unstable_enable_multi_key_bundle"; private final WindmillStateCache stateCache; private AtomicReference statusPages = new AtomicReference<>(); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java index f855946cc075..0618a6e17b16 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java @@ -128,6 +128,8 @@ public class StreamingModeExecutionContext "windmill_max_key_group_batch_size"; private static final String WINDMILL_MAX_KEY_GROUP_BATCH_TIME_MS = "windmill_max_key_group_batch_time_ms"; + private static final String WINDMILL_MAX_KEY_GROUP_BATCH_SINK_BYTES = + "windmill_max_key_group_batch_sink_bytes"; private final String computationId; private final ImmutableMap stateNameMap; @@ -207,6 +209,8 @@ public interface KeyTransitionListener { private final int maxKeyGroupBatchSize; private final long maxKeyGroupBatchTimeNanos; + private final boolean multiKeyBundleEnabled; + private final long maxKeyGroupBatchSinkBytes; private int workItemsPolled = 0; private long bundleStartTimeNanos = 0; @@ -249,14 +253,28 @@ public StreamingModeExecutionContext( this.sideInputStateFetcherFactory = checkNotNull(sideInputStateFetcherFactory); // Initialize batch limits from pipeline options - String batchSizeStr = - ExperimentalOptions.getExperimentValue(options, WINDMILL_MAX_KEY_GROUP_BATCH_SIZE); - this.maxKeyGroupBatchSize = batchSizeStr != null ? 
Integer.parseInt(batchSizeStr) : 100; - - String batchTimeStr = - ExperimentalOptions.getExperimentValue(options, WINDMILL_MAX_KEY_GROUP_BATCH_TIME_MS); - this.maxKeyGroupBatchTimeNanos = - TimeUnit.MILLISECONDS.toNanos(batchTimeStr != null ? Long.parseLong(batchTimeStr) : 100); + this.maxKeyGroupBatchSize = + tryParseInt( + ExperimentalOptions.getExperimentValue(options, WINDMILL_MAX_KEY_GROUP_BATCH_SIZE), + 100, + WINDMILL_MAX_KEY_GROUP_BATCH_SIZE); + + long batchTimeMs = + tryParseLong( + ExperimentalOptions.getExperimentValue(options, WINDMILL_MAX_KEY_GROUP_BATCH_TIME_MS), + 100, + WINDMILL_MAX_KEY_GROUP_BATCH_TIME_MS); + this.maxKeyGroupBatchTimeNanos = TimeUnit.MILLISECONDS.toNanos(batchTimeMs); + + this.multiKeyBundleEnabled = + ExperimentalOptions.hasExperiment(options, StreamingDataflowWorker.UNSTABLE_ENABLE_MULTI_KEY_BUNDLE); + + this.maxKeyGroupBatchSinkBytes = + tryParseLong( + ExperimentalOptions.getExperimentValue( + options, WINDMILL_MAX_KEY_GROUP_BATCH_SINK_BYTES), + StreamingDataflowWorker.MAX_SINK_BYTES, + WINDMILL_MAX_KEY_GROUP_BATCH_SINK_BYTES); StreamingGlobalConfig config = globalConfigHandle.getConfig(); this.operationalLimits = config.operationalLimits(); @@ -266,6 +284,41 @@ public StreamingModeExecutionContext( : WindmillTagEncodingV1.instance(); } + private static int tryParseInt(@Nullable String value, int defaultValue, String experimentName) { + if (value == null) { + return defaultValue; + } + try { + return Integer.parseInt(value); + } catch (NumberFormatException e) { + LOG.warn( + "Failed to parse experiment {} value '{}' as integer, falling back to default: {}", + experimentName, + value, + defaultValue, + e); + return defaultValue; + } + } + + private static long tryParseLong( + @Nullable String value, long defaultValue, String experimentName) { + if (value == null) { + return defaultValue; + } + try { + return Long.parseLong(value); + } catch (NumberFormatException e) { + LOG.warn( + "Failed to parse experiment {} value '{}' as 
long, falling back to default: {}", + experimentName, + value, + defaultValue, + e); + return defaultValue; + } + } + @VisibleForTesting public final long getBacklogBytes() { return backlogBytes; @@ -727,6 +780,9 @@ public Map> flushState() { } public boolean advance() { + if (!multiKeyBundleEnabled) { + return false; + } if (workIsFailed()) { throw new WorkItemCancelledException(checkStateNotNull(work).getWorkItem().getShardingKey()); } @@ -758,7 +814,10 @@ private boolean shouldStopBatching() { return true; } long elapsedNanos = System.nanoTime() - bundleStartTimeNanos; - return elapsedNanos >= maxKeyGroupBatchTimeNanos; + if (elapsedNanos >= maxKeyGroupBatchTimeNanos) { + return true; + } + return getBytesSinked() >= maxKeyGroupBatchSinkBytes; } private void startForNewKey(Work newWork) { diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java index ce5d68b1a526..539a17e97508 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java @@ -153,6 +153,9 @@ private StreamingModeExecutionContext createExecutionContext( public void setUp() { MockitoAnnotations.initMocks(this); options = PipelineOptionsFactory.as(DataflowWorkerHarnessOptions.class); + options + .as(ExperimentalOptions.class) + .setExperiments(Arrays.asList("unstable_enable_multi_key_bundle")); globalConfigHandle = new FakeGlobalConfigHandle(StreamingGlobalConfig.builder().build()); executionContext = createExecutionContext(options, globalConfigHandle); } @@ -579,7 +582,7 @@ public void testAdvance_respectsMaxBatchSize() throws Exception { 
PipelineOptionsFactory.as(DataflowWorkerHarnessOptions.class); optionsWithBatchSize .as(ExperimentalOptions.class) - .setExperiments(Arrays.asList("max_key_group_batch_size=1")); + .setExperiments(Arrays.asList("windmill_max_key_group_batch_size=1")); StreamingModeExecutionContext context = createExecutionContext(optionsWithBatchSize, globalConfigHandle); @@ -610,7 +613,7 @@ public void testAdvance_respectsMaxBatchTime() throws Exception { PipelineOptionsFactory.as(DataflowWorkerHarnessOptions.class); optionsWithBatchTime .as(ExperimentalOptions.class) - .setExperiments(Arrays.asList("max_key_group_batch_time_ms=0")); + .setExperiments(Arrays.asList("windmill_max_key_group_batch_time_ms=0")); StreamingModeExecutionContext context = createExecutionContext(optionsWithBatchTime, globalConfigHandle); @@ -680,4 +683,92 @@ public void testAdvance_defaultKeyGroup() throws Exception { assertFalse(executionContext.advance()); org.mockito.Mockito.verifyNoInteractions(mockExecutor); } + + @Test + public void testAdvance_experimentDisabled() throws Exception { + DataflowWorkerHarnessOptions optionsDisabled = + PipelineOptionsFactory.as(DataflowWorkerHarnessOptions.class); + StreamingModeExecutionContext context = + createExecutionContext(optionsDisabled, globalConfigHandle); + + BoundedQueueExecutor mockExecutor = mock(BoundedQueueExecutor.class); + BoundedQueueExecutorWorkHandle mockHandle = mock(BoundedQueueExecutorWorkHandle.class); + + Windmill.Uint128Proto keyGroup = + Windmill.Uint128Proto.newBuilder().setHigh(1).setLow(2).build(); + Windmill.WorkItem workItem1 = + Windmill.WorkItem.newBuilder() + .setKey(ByteString.copyFromUtf8("key1")) + .setWorkToken(1L) + .setKeyGroup(keyGroup) + .build(); + Work work1 = + createMockWork( + workItem1, Watermarks.builder().setInputDataWatermark(Instant.EPOCH).build()); + + context.start(work1, workExecutor, mockExecutor, mockHandle, null, (oldWork, newWork) -> {}); + + assertFalse(context.advance()); + 
org.mockito.Mockito.verifyNoInteractions(mockExecutor); + } + + @Test + public void testAdvance_respectsMaxBatchSinkBytes() throws Exception { + DataflowWorkerHarnessOptions optionsWithSinkBytes = + PipelineOptionsFactory.as(DataflowWorkerHarnessOptions.class); + optionsWithSinkBytes + .as(ExperimentalOptions.class) + .setExperiments( + Arrays.asList( + "unstable_enable_multi_key_bundle", "windmill_max_key_group_batch_sink_bytes=100")); + StreamingModeExecutionContext context = + createExecutionContext(optionsWithSinkBytes, globalConfigHandle); + + BoundedQueueExecutor mockExecutor = mock(BoundedQueueExecutor.class); + BoundedQueueExecutorWorkHandle mockHandle = mock(BoundedQueueExecutorWorkHandle.class); + + Windmill.Uint128Proto keyGroup = + Windmill.Uint128Proto.newBuilder().setHigh(1).setLow(2).build(); + Windmill.WorkItem workItem1 = + Windmill.WorkItem.newBuilder() + .setKey(ByteString.copyFromUtf8("key1")) + .setWorkToken(1L) + .setKeyGroup(keyGroup) + .build(); + Work work1 = + createMockWork( + workItem1, Watermarks.builder().setInputDataWatermark(Instant.EPOCH).build()); + + context.start(work1, workExecutor, mockExecutor, mockHandle, null, (oldWork, newWork) -> {}); + + context.reportBytesSinked(50); + assertFalse(context.advance()); + org.mockito.Mockito.verify(mockExecutor) + .pollWork(COMPUTATION_ID, work1.getKeyGroup(), mockHandle); + + org.mockito.Mockito.reset(mockExecutor); + + context.reportBytesSinked(60); + assertFalse(context.advance()); + org.mockito.Mockito.verifyNoInteractions(mockExecutor); + } + + @Test + public void testExperimentParsingWithInvalidValues() { + DataflowWorkerHarnessOptions optionsInvalid = + PipelineOptionsFactory.as(DataflowWorkerHarnessOptions.class); + optionsInvalid + .as(ExperimentalOptions.class) + .setExperiments( + Arrays.asList( + "windmill_max_key_group_batch_size=invalid_size", + "windmill_max_key_group_batch_time_ms=invalid_time", + "windmill_max_key_group_batch_sink_bytes=invalid_bytes")); + + // This should 
not throw NumberFormatException + StreamingModeExecutionContext context = + createExecutionContext(optionsInvalid, globalConfigHandle); + + org.junit.Assert.assertNotNull(context); + } } From 0df9f546a257d2b1084c30f51145da715b5e471a Mon Sep 17 00:00:00 2001 From: Arun Pandian Date: Thu, 11 Jun 2026 14:26:12 +0000 Subject: [PATCH 21/21] spotless --- .../runners/dataflow/worker/StreamingModeExecutionContext.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java index 0618a6e17b16..c40eed196a10 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java @@ -267,7 +267,8 @@ public StreamingModeExecutionContext( this.maxKeyGroupBatchTimeNanos = TimeUnit.MILLISECONDS.toNanos(batchTimeMs); this.multiKeyBundleEnabled = - ExperimentalOptions.hasExperiment(options, StreamingDataflowWorker.UNSTABLE_ENABLE_MULTI_KEY_BUNDLE); + ExperimentalOptions.hasExperiment( + options, StreamingDataflowWorker.UNSTABLE_ENABLE_MULTI_KEY_BUNDLE); this.maxKeyGroupBatchSinkBytes = tryParseLong(