From 73faa68ba985efbd8906d4069b34cb566b9439a0 Mon Sep 17 00:00:00 2001 From: Arun Pandian Date: Thu, 4 Jun 2026 11:00:30 +0000 Subject: [PATCH 01/21] [Dataflow Streaming] [Multi Key] StreamingModeExecutionContext refactoring for multi-key execution. --- .../worker/StreamingModeExecutionContext.java | 271 +++++++++++++--- .../worker/WindmillReaderIteratorBase.java | 19 +- .../worker/WindowingWindmillReader.java | 86 ++--- .../streaming/ComputationWorkExecutor.java | 36 +-- .../dataflow/worker/streaming/Work.java | 23 +- .../ComputationWorkExecutorFactory.java | 28 +- .../processing/StreamingWorkScheduler.java | 295 ++++++++++-------- .../worker/StreamingDataflowWorkerTest.java | 21 +- .../StreamingModeExecutionContextTest.java | 140 ++++++--- .../WindmillReaderIteratorBaseTest.java | 94 ++++++ .../worker/WorkerCustomSourcesTest.java | 42 ++- 11 files changed, 759 insertions(+), 296 deletions(-) diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java index 25ce299adf7a..a669fb7ff361 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java @@ -25,6 +25,7 @@ import com.google.api.services.dataflow.model.SideInputInfo; import java.io.Closeable; import java.io.IOException; +import java.util.ArrayList; import java.util.Collection; import java.util.HashMap; import java.util.Iterator; @@ -46,10 +47,10 @@ import org.apache.beam.runners.core.metrics.ExecutionStateTracker; import org.apache.beam.runners.core.metrics.ExecutionStateTracker.ExecutionState; import org.apache.beam.runners.dataflow.worker.DataflowOperationContext.DataflowExecutionState; -import org.apache.beam.runners.dataflow.worker.StreamingModeExecutionContext.StepContext; import org.apache.beam.runners.dataflow.worker.counters.CounterFactory; import org.apache.beam.runners.dataflow.worker.counters.NameContext; import org.apache.beam.runners.dataflow.worker.profiler.ScopedProfiler.ProfileScope; +import org.apache.beam.runners.dataflow.worker.streaming.BoundedQueueExecutorWorkHandle; import org.apache.beam.runners.dataflow.worker.streaming.Watermarks; import org.apache.beam.runners.dataflow.worker.streaming.Work; import org.apache.beam.runners.dataflow.worker.streaming.config.StreamingGlobalConfig; @@ -57,6 +58,9 @@ import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInput; import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputState; import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputStateFetcher; +import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor; +import org.apache.beam.runners.dataflow.worker.util.common.worker.ElementCounter; +import org.apache.beam.runners.dataflow.worker.util.common.worker.OutputObjectAndByteCounter; import org.apache.beam.runners.dataflow.worker.util.common.worker.WorkExecutor; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalDataId; @@ -112,7 +116,8 @@ // TODO(m-trieu) fix nullability issues in StreamingModeExecutionContext.java @NotThreadSafe @Internal -public class StreamingModeExecutionContext extends DataflowExecutionContext { +public class StreamingModeExecutionContext + extends DataflowExecutionContext { private static final Logger LOG = LoggerFactory.getLogger(StreamingModeExecutionContext.class); @@ -162,6 +167,33 @@ public class StreamingModeExecutionContext extends DataflowExecutionContext keyCoder; + + // Key switch listener to delegate MDC logging context and thread name updates + public interface KeySwitchListener { + void onKeySwitch(Work oldWork, Work newWork); + } + + @SuppressWarnings("UnusedVariable") + private @Nullable KeySwitchListener keySwitchListener; + + private List executedWorks = new ArrayList<>(); + private List outputBuilders = new ArrayList<>(); + private Map> accumulatedCallbacks = new HashMap<>(); + private volatile boolean workIsFailed = false; + private @Nullable WindmillStateReader activeStateReader; + private long stateBytesRead = 0; + private final String sourceBytesProcessCounterName; + public StreamingModeExecutionContext( CounterFactory counterFactory, String computationId, @@ -173,7 +205,11 @@ public StreamingModeExecutionContext( StreamingModeExecutionStateRegistry executionStateRegistry, StreamingGlobalConfigHandle globalConfigHandle, long sinkByteLimit, - boolean throwExceptionOnLargeOutput) { + boolean throwExceptionOnLargeOutput, + HotKeyLogger hotKeyLogger, + boolean hotKeyLoggingEnabled, + String stepName, + String sourceBytesProcessCounterName) { super( counterFactory, metricsContainerRegistry, @@ -188,6 +224,10 @@ public StreamingModeExecutionContext( this.stateCache = stateCache; this.backlogBytes = UnboundedReader.BACKLOG_UNKNOWN; this.throwExceptionOnLargeOutput = throwExceptionOnLargeOutput; + this.hotKeyLogger = checkNotNull(hotKeyLogger); + this.hotKeyLoggingEnabled = hotKeyLoggingEnabled; + this.stepName = checkNotNull(stepName); + this.sourceBytesProcessCounterName = checkNotNull(sourceBytesProcessCounterName); } @VisibleForTesting @@ -208,7 +248,7 @@ public boolean throwExceptionsForLargeOutput() { } public boolean workIsFailed() { - return work != null && work.isFailed(); + return workIsFailed; } public boolean getDrainMode() { @@ -240,19 +280,44 @@ public byte[] getCurrentRecordOffset() { return activeReader.getCurrentRecordOffset(); } + public void clear() { + for (Work w : executedWorks) { + w.setOnFailureListener(null); + } + this.executedWorks = new ArrayList<>(); + this.outputBuilders = new ArrayList<>(); + this.accumulatedCallbacks = new HashMap<>(); + this.workIsFailed = false; + this.sideInputCache.clear(); + this.activeStateReader = null; + this.activeReader = null; + this.keyCoder = null; + this.workExecutor = null; + this.workQueueExecutor = null; + this.budgetHandle = null; + this.keySwitchListener = null; + } + public void start( - @Nullable Object key, Work work, WindmillStateReader stateReader, SideInputStateFetcher sideInputStateFetcher, - Windmill.WorkItemCommitRequest.Builder outputBuilder, - WorkExecutor workExecutor) { - this.key = key; - this.work = work; + WorkExecutor workExecutor, + BoundedQueueExecutor workQueueExecutor, + BoundedQueueExecutorWorkHandle budgetHandle, + @Nullable Coder keyCoder, + KeySwitchListener keySwitchListener) { + clear(); + this.keyCoder = keyCoder; this.workExecutor = workExecutor; - this.finishKeyCalled = false; - this.computationKey = WindmillComputationKey.create(computationId, work.getShardedKey()); - this.sideInputStateFetcher = sideInputStateFetcher; + this.workQueueExecutor = workQueueExecutor; + this.budgetHandle = budgetHandle; + this.keySwitchListener = keySwitchListener; + + this.backlogBytes = UnboundedReader.BACKLOG_UNKNOWN; + clearSinkFullHint(); + this.stateBytesRead = 0; + StreamingGlobalConfig config = globalConfigHandle.getConfig(); // Snapshot the limits for entire bundle processing. this.operationalLimits = config.operationalLimits(); @@ -260,27 +325,66 @@ public void start( config.enableStateTagEncodingV2() ? WindmillTagEncodingV2.instance() : WindmillTagEncodingV1.instance(); - this.outputBuilder = outputBuilder; - this.sideInputCache.clear(); - this.backlogBytes = UnboundedReader.BACKLOG_UNKNOWN; - clearSinkFullHint(); + this.sideInputStateFetcher = sideInputStateFetcher; - Instant processingTime = computeProcessingTime(work.getWorkItem().getTimers().getTimersList()); + startForNewKey(work, stateReader); + } - Collection stepContexts = getAllStepContexts(); - if (!stepContexts.isEmpty()) { - // This must be only created once for the workItem as token validation will fail if the same - // work token is reused. - WindmillStateCache.ForKey cacheForKey = - stateCache.forKey(getComputationKey(), getWorkItem().getCacheToken(), getWorkToken()); - for (StepContext stepContext : stepContexts) { - stepContext.start(stateReader, processingTime, cacheForKey, work.watermarks()); + private @Nullable Object decodeKey(Work work) { + // If the read output KVs, then we can decode Windmill's byte key into userland + // key object and provide it to the execution context for use with per-key state. + // Otherwise, we pass null. + // + // The coder type that will be present is: + // WindowedValueCoder(TimerOrElementCoder(KvCoder)) + if (keyCoder != null) { + try { + return keyCoder.decode(work.getWorkItem().getKey().newInput(), Coder.Context.OUTER); + } catch (IOException e) { + throw new RuntimeException("Failed to decode key during processing", e); } } + return null; + } + + private Windmill.WorkItemCommitRequest.Builder createOutputBuilder(Work work) { + return Windmill.WorkItemCommitRequest.newBuilder() + .setKey(work.getWorkItem().getKey()) + .setShardingKey(work.getWorkItem().getShardingKey()) + .setWorkToken(work.getWorkItem().getWorkToken()) + .setCacheToken(work.getWorkItem().getCacheToken()); + } + + private void logHotKeyIfDetected(Work work, @Nullable Object decodedKey) { + if (work.getWorkItem().hasHotKeyInfo()) { + Windmill.HotKeyInfo hotKeyInfo = work.getWorkItem().getHotKeyInfo(); + Duration hotKeyAge = Duration.millis(hotKeyInfo.getHotKeyAgeUsec() / 1000); + if (decodedKey != null && hotKeyLoggingEnabled) { + hotKeyLogger.logHotKeyDetection(stepName, hotKeyAge, decodedKey); + } else { + hotKeyLogger.logHotKeyDetection(stepName, hotKeyAge); + } + } + } + + private void startStepContexts( + WindmillStateReader stateReader, + Instant processingTime, + WindmillStateCache.ForKey cacheForKey, + Watermarks watermarks) { + Collection stepContexts = getAllStepContexts(); + for (StepContext stepContext : stepContexts) { + stepContext.start(stateReader, processingTime, cacheForKey, watermarks); + } } public void finishKey() { - checkState(!finishKeyCalled, "finishKey was already called"); + if (finishKeyCalled) { + return; + } + if (activeStateReader != null) { + this.stateBytesRead += activeStateReader.getBytesRead(); + } checkStateNotNull(workExecutor, "workExecutor must be set before calling finishKey()"); try { workExecutor.finishKey(); @@ -288,6 +392,8 @@ public void finishKey() { throw new RuntimeException(e); } this.finishKeyCalled = true; + + flushStateInternal(); } /** @@ -441,20 +547,22 @@ public void setActiveReader(UnboundedReader reader) { /** Invalidate the state and reader caches for this computation and key. */ public void invalidateCache() { - ByteString key = getSerializedKey(); - if (key != null) { - readerCache.invalidateReader(getComputationKey()); - if (activeReader != null) { - try { - activeReader.close(); - } catch (IOException e) { - LOG.warn( - "Failed to close reader for {}-{}", computationId, getWorkItem().getShardingKey(), e); - } + for (Work w : executedWorks) { + WindmillComputationKey compKey = + WindmillComputationKey.create(computationId, w.getShardedKey()); + readerCache.invalidateReader(compKey); + stateCache.invalidate(w.getShardedKey()); + } + if (activeReader != null) { + try { + activeReader.close(); + } catch (IOException e) { + LOG.warn( + "Failed to close reader for {}-{}", computationId, getWorkItem().getShardingKey(), e); } - activeReader = null; - stateCache.invalidate(key, getWorkItem().getShardingKey()); } + activeReader = null; + activeStateReader = null; } public UnboundedSource.@Nullable CheckpointMark getReaderCheckpoint( @@ -470,8 +578,7 @@ public void invalidateCache() { } } - public Map> flushState() { - checkState(finishKeyCalled, "finishKey must be called before flushState"); + private void flushStateInternal() { Map> callbacks = new HashMap<>(); for (StepContext stepContext : getAllStepContexts()) { @@ -555,7 +662,89 @@ public Map> flushState() { // RestrictionTracker.getProgress() or GetSize() are not defined. outputBuilder.setSourceBacklogBytes(backlogBytes); } - return callbacks; + + this.accumulatedCallbacks.putAll(callbacks); + + outputBuilder.setSourceBytesProcessed( + computeSourceBytesProcessed(sourceBytesProcessCounterName)); + } + + private final long computeSourceBytesProcessed(String sourceBytesCounterName) { + if (!(workExecutor instanceof DataflowMapTaskExecutor)) { + return 0L; + } + HashMap counters = + ((DataflowMapTaskExecutor) workExecutor) + .getReadOperation() + .receivers[0] + .getOutputCounters(); + + return Optional.ofNullable(counters.get(sourceBytesCounterName)) + .map(counter -> ((OutputObjectAndByteCounter) counter).getByteCount().getAndReset()) + .orElse(0L); + } + + public Map> flushState() { + return accumulatedCallbacks; + } + + public boolean advance() { + return false; + } + + private void startForNewKey(Work newWork, WindmillStateReader reader) { + this.key = decodeKey(newWork); + this.work = newWork; + this.finishKeyCalled = false; + this.computationKey = WindmillComputationKey.create(computationId, newWork.getShardedKey()); + + this.outputBuilder = createOutputBuilder(newWork); + this.outputBuilders.add(this.outputBuilder); + newWork.setOnFailureListener(() -> this.workIsFailed = true); + this.executedWorks.add(newWork); + + logHotKeyIfDetected(newWork, this.key); + + // Note: We do NOT clear sideInputCache here, allowing Key B to reuse warm side inputs! + + // Re-initialize state cache and state/timer internals across all step contexts + Instant processingTime = + computeProcessingTime(newWork.getWorkItem().getTimers().getTimersList()); + if (!getAllStepContexts().isEmpty()) { + // This must be only created once for a workItem as token validation will fail if the same + // work token is reused. + WindmillStateCache.ForKey cacheForKey = + stateCache.forKey( + getComputationKey(), newWork.getWorkItem().getCacheToken(), getWorkToken()); + this.activeStateReader = reader; + startStepContexts(reader, processingTime, cacheForKey, newWork.watermarks()); + } else { + this.activeStateReader = null; + } + } + + public List getExecutedWorks() { + return executedWorks; + } + + public long getStateBytesRead() { + return stateBytesRead; + } + + public List getOutputBuilders() { + return outputBuilders; + } + + public Map> getAccumulatedCallbacks() { + return accumulatedCallbacks; + } + + public @Nullable Object getKey() { + return key; + } + + public Work getWork() { + return work; } String getStateFamily(NameContext nameContext) { diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillReaderIteratorBase.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillReaderIteratorBase.java index 075a1a8a4250..b142cc38d365 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillReaderIteratorBase.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillReaderIteratorBase.java @@ -35,7 +35,7 @@ public abstract class WindmillReaderIteratorBase extends NativeReader.NativeReaderIterator> { private final StreamingModeExecutionContext context; - private final Windmill.WorkItem work; + private Windmill.WorkItem work; private int bundleIndex = 0; private int messageIndex = -1; private @Nullable WindowedValue current = null; @@ -57,15 +57,27 @@ public boolean start() throws IOException { @Override public boolean advance() throws IOException { if (context.workIsFailed()) { - throw new WorkItemCancelledException(context.getWorkItem().getShardingKey()); + throw new WorkItemCancelledException(checkNotNull(context.getWorkItem()).getShardingKey()); } while (true) { if (bundleIndex >= work.getMessageBundlesCount()) { - current = null; + // If elements are exhausted, try advancing the execution context to the next key in the + // group context.finishKey(); + if (context.advance()) { + // Transition succeeded! Update iterator references to the new work item + this.work = context.getWork().getWorkItem(); + this.bundleIndex = 0; + this.messageIndex = -1; + continue; + } + + // All work items are exhausted. Iterator returns false. + current = null; return false; } + Windmill.InputMessageBundle bundle = work.getMessageBundles(bundleIndex); ++messageIndex; if (messageIndex >= bundle.getMessagesCount()) { @@ -73,6 +85,7 @@ public boolean advance() throws IOException { ++bundleIndex; continue; } + try { current = checkNotNull(decodeMessage(bundle.getMessages(messageIndex))); return true; diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindowingWindmillReader.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindowingWindmillReader.java index 488684769bd9..916920518f0b 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindowingWindmillReader.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindowingWindmillReader.java @@ -151,51 +151,65 @@ public NativeReaderIterator>> iterator() throw && Iterables.isEmpty(keyedWorkItem.elementsIterable())); final WindowedValue> value = new ValueInEmptyWindows<>(keyedWorkItem); - // Return a noop iterator when current workitem is an empty workitem. - if (isEmptyWorkItem) { - return new NativeReaderIterator>>() { - @Override - public boolean start() throws IOException { - context.finishKey(); - return false; + return new NativeReaderIterator>>() { + private @Nullable WindowedValue> current = null; + private boolean started = false; + + @Override + public boolean start() throws IOException { + if (context.workIsFailed()) { + throw new WorkItemCancelledException( + checkStateNotNull(context.getWorkItem()).getShardingKey()); } - - @Override - public boolean advance() throws IOException { + if (started) { return false; } - - @Override - public WindowedValue> getCurrent() { - throw new NoSuchElementException(); + started = true; + if (isEmptyWorkItem) { + return advance(); // Try to transition immediately if the first key is empty! } - }; - } else { - return new NativeReaderIterator>>() { - private @Nullable WindowedValue> current = null; - - @Override - public boolean start() throws IOException { - current = value; - return true; + current = value; + return true; + } + + @Override + public boolean advance() throws IOException { + if (context.workIsFailed()) { + throw new WorkItemCancelledException( + checkStateNotNull(context.getWorkItem()).getShardingKey()); } - @Override - public boolean advance() throws IOException { - current = null; - context.finishKey(); - return false; + context.finishKey(); + if (context.advance()) { + @SuppressWarnings("unchecked") + K newKey = (K) context.getKey(); + KeyedWorkItem newKeyedWorkItem = + new WindmillKeyedWorkItem<>( + newKey, + context.getWork().getWorkItem(), + windowCoder, + windowsCoder, + valueCoder, + context.getWindmillTagEncoding(), + context.getDrainMode(), + skipUndecodableElements.isAccessible() + && Boolean.TRUE.equals(skipUndecodableElements.get())); + current = new ValueInEmptyWindows<>(newKeyedWorkItem); + return true; } - @Override - public WindowedValue> getCurrent() { - if (current == null) { - throw new NoSuchElementException(); - } - return value; + current = null; + return false; + } + + @Override + public WindowedValue> getCurrent() { + if (current == null) { + throw new NoSuchElementException(); } - }; - } + return current; + } + }; } @Override diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationWorkExecutor.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationWorkExecutor.java index b4f3a22a7f52..ed86d58b9bb0 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationWorkExecutor.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationWorkExecutor.java @@ -18,21 +18,16 @@ package org.apache.beam.runners.dataflow.worker.streaming; import com.google.auto.value.AutoValue; -import java.util.HashMap; import java.util.Optional; import javax.annotation.concurrent.NotThreadSafe; import org.apache.beam.runners.core.metrics.ExecutionStateTracker; -import org.apache.beam.runners.dataflow.worker.DataflowMapTaskExecutor; import org.apache.beam.runners.dataflow.worker.DataflowWorkExecutor; import org.apache.beam.runners.dataflow.worker.StreamingModeExecutionContext; import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputStateFetcher; -import org.apache.beam.runners.dataflow.worker.util.common.worker.ElementCounter; -import org.apache.beam.runners.dataflow.worker.util.common.worker.OutputObjectAndByteCounter; -import org.apache.beam.runners.dataflow.worker.windmill.Windmill; +import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor; import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateReader; import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.sdk.coders.Coder; -import org.checkerframework.checker.nullness.qual.Nullable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -68,13 +63,23 @@ public static ComputationWorkExecutor.Builder builder() { * Executes DoFns for the Work. Blocks the calling thread until DoFn(s) have completed execution. */ public final void executeWork( - @Nullable Object key, Work work, WindmillStateReader stateReader, SideInputStateFetcher sideInputStateFetcher, - Windmill.WorkItemCommitRequest.Builder outputBuilder) + BoundedQueueExecutor workQueueExecutor, + BoundedQueueExecutorWorkHandle budgetHandle, + StreamingModeExecutionContext.KeySwitchListener keySwitchListener) throws Exception { - context().start(key, work, stateReader, sideInputStateFetcher, outputBuilder, workExecutor()); + context() + .start( + work, + stateReader, + sideInputStateFetcher, + workExecutor(), + workQueueExecutor, + budgetHandle, + keyCoder().orElse(null), + keySwitchListener); workExecutor().execute(); } @@ -84,6 +89,7 @@ public final void executeWork( */ public final void invalidate() { context().invalidateCache(); + context().clear(); try { workExecutor().close(); } catch (Exception e) { @@ -91,18 +97,6 @@ public final void invalidate() { } } - public final long computeSourceBytesProcessed(String sourceBytesCounterName) { - HashMap counters = - ((DataflowMapTaskExecutor) workExecutor()) - .getReadOperation() - .receivers[0] - .getOutputCounters(); - - return Optional.ofNullable(counters.get(sourceBytesCounterName)) - .map(counter -> ((OutputObjectAndByteCounter) counter).getByteCount().getAndReset()) - .orElse(0L); - } - @AutoValue.Builder public abstract static class Builder { public abstract Builder setWorkExecutor(DataflowWorkExecutor workExecutor); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java index cb01e1e508ce..668657228dfd 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java @@ -52,6 +52,7 @@ import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; +import org.checkerframework.checker.nullness.qual.Nullable; import org.joda.time.Duration; import org.joda.time.Instant; @@ -79,6 +80,7 @@ public final class Work implements RefreshableWork { private volatile TimedState currentState; private volatile boolean isFailed; private volatile String processingThreadName = ""; + private volatile @Nullable Runnable onFailureListener = null; private final boolean drainMode; private Work( @@ -184,6 +186,10 @@ public long getSerializedWorkItemSize() { return serializedWorkItemSize; } + public String getComputationId() { + return processingContext.computationId(); + } + @Override public ShardedKey getShardedKey() { return shardedKey; @@ -235,8 +241,19 @@ public void setProcessingThreadName(String processingThreadName) { } @Override - public void setFailed() { + public synchronized void setFailed() { this.isFailed = true; + Runnable listener = onFailureListener; + if (listener != null) { + listener.run(); + } + } + + public synchronized void setOnFailureListener(@Nullable Runnable listener) { + this.onFailureListener = listener; + if (isFailed && listener != null) { + listener.run(); + } } public boolean isCommitPending() { @@ -261,6 +278,10 @@ public void queueCommit(WorkItemCommitRequest commitRequest, ComputationState co processingContext.workCommitter().accept(Commit.create(commitRequest, computationState, this)); } + public Consumer workCommitter() { + return processingContext.workCommitter(); + } + public WindmillStateReader createWindmillStateReader() { return WindmillStateReader.forWork(this); } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/ComputationWorkExecutorFactory.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/ComputationWorkExecutorFactory.java index 269799903300..fcc6d6bbb743 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/ComputationWorkExecutorFactory.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/ComputationWorkExecutorFactory.java @@ -29,6 +29,7 @@ import org.apache.beam.runners.dataflow.worker.DataflowExecutionStateSampler; import org.apache.beam.runners.dataflow.worker.DataflowMapTaskExecutor; import org.apache.beam.runners.dataflow.worker.DataflowMapTaskExecutorFactory; +import org.apache.beam.runners.dataflow.worker.HotKeyLogger; import org.apache.beam.runners.dataflow.worker.IntrinsicMapTaskExecutorFactory; import org.apache.beam.runners.dataflow.worker.ReaderCache; import org.apache.beam.runners.dataflow.worker.ReaderRegistry; @@ -97,6 +98,7 @@ final class ComputationWorkExecutorFactory { private final IdGenerator idGenerator; private final StreamingGlobalConfigHandle globalConfigHandle; private final boolean throwExceptionOnLargeOutput; + private final HotKeyLogger hotKeyLogger; ComputationWorkExecutorFactory( DataflowWorkerHarnessOptions options, @@ -106,7 +108,8 @@ final class ComputationWorkExecutorFactory { DataflowExecutionStateSampler sampler, CounterSet pendingDeltaCounters, IdGenerator idGenerator, - StreamingGlobalConfigHandle globalConfigHandle) { + StreamingGlobalConfigHandle globalConfigHandle, + HotKeyLogger hotKeyLogger) { this.options = options; this.mapTaskExecutorFactory = mapTaskExecutorFactory; this.readerCache = readerCache; @@ -124,6 +127,7 @@ final class ComputationWorkExecutorFactory { : StreamingDataflowWorker.MAX_SINK_BYTES; this.throwExceptionOnLargeOutput = hasExperiment(options, THROW_EXCEPTIONS_ON_LARGE_OUTPUT_EXPERIMENT); + this.hotKeyLogger = hotKeyLogger; } private static Nodes.ParallelInstructionNode extractReadNode( @@ -191,8 +195,12 @@ ComputationWorkExecutor createComputationWorkExecutor( DataflowExecutionContext.DataflowExecutionStateTracker executionStateTracker = createExecutionStateTracker(stageInfo, mapTask, workLatencyTrackingId); + boolean hotKeyLoggingEnabled = + options.isHotKeyLoggingEnabled() || hasExperiment(options, "enable_hot_key_logging"); + String stepName = getShuffleTaskStepName(mapTask); StreamingModeExecutionContext context = - createExecutionContext(computationState, stageInfo, executionStateTracker); + createExecutionContext( + computationState, stageInfo, executionStateTracker, hotKeyLoggingEnabled, stepName); DataflowMapTaskExecutor mapTaskExecutor = createMapTaskExecutor(context, mapTask, mapTaskNetwork); ReadOperation readOperation = getValidatedReadOperation(mapTaskExecutor); @@ -255,7 +263,9 @@ ComputationWorkExecutor createComputationWorkExecutor( private StreamingModeExecutionContext createExecutionContext( ComputationState computationState, StageInfo stageInfo, - DataflowExecutionContext.DataflowExecutionStateTracker executionStateTracker) { + DataflowExecutionContext.DataflowExecutionStateTracker executionStateTracker, + boolean hotKeyLoggingEnabled, + String stepName) { String computationId = computationState.getComputationId(); return new StreamingModeExecutionContext( pendingDeltaCounters, @@ -268,7 +278,11 @@ private StreamingModeExecutionContext createExecutionContext( stageInfo.executionStateRegistry(), globalConfigHandle, maxSinkBytes, - throwExceptionOnLargeOutput); + throwExceptionOnLargeOutput, + hotKeyLogger, + hotKeyLoggingEnabled, + stepName, + computationState.sourceBytesProcessCounterName()); } private DataflowMapTaskExecutor createMapTaskExecutor( @@ -286,6 +300,12 @@ private DataflowMapTaskExecutor createMapTaskExecutor( idGenerator); } + private static String getShuffleTaskStepName(MapTask mapTask) { + // The MapTask instruction is ordered by dependencies, such that the first element is + // always going to be the shuffle task. + return mapTask.getInstructions().get(0).getName(); + } + private DataflowExecutionContext.DataflowExecutionStateTracker createExecutionStateTracker( StageInfo stageInfo, MapTask mapTask, String workLatencyTrackingId) { return new DataflowExecutionContext.DataflowExecutionStateTracker( diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java index 364608be82ca..9ee2192b09d8 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java @@ -17,22 +17,23 @@ */ package org.apache.beam.runners.dataflow.worker.windmill.work.processing; -import static org.apache.beam.sdk.options.ExperimentalOptions.hasExperiment; - import com.google.api.services.dataflow.model.MapTask; import com.google.auto.value.AutoValue; -import java.util.Optional; +import java.util.List; +import java.util.Map; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; import java.util.function.Function; import java.util.function.Supplier; import javax.annotation.concurrent.ThreadSafe; +import org.apache.beam.repackaged.core.org.apache.commons.lang3.tuple.Pair; import org.apache.beam.runners.dataflow.options.DataflowWorkerHarnessOptions; import org.apache.beam.runners.dataflow.worker.DataflowExecutionStateSampler; import org.apache.beam.runners.dataflow.worker.DataflowMapTaskExecutorFactory; import org.apache.beam.runners.dataflow.worker.HotKeyLogger; import org.apache.beam.runners.dataflow.worker.ReaderCache; +import org.apache.beam.runners.dataflow.worker.StreamingModeExecutionContext; import org.apache.beam.runners.dataflow.worker.WorkItemCancelledException; import org.apache.beam.runners.dataflow.worker.logging.DataflowWorkerLoggingMDC; import org.apache.beam.runners.dataflow.worker.streaming.BoundedQueueExecutorWorkHandle; @@ -57,12 +58,11 @@ import org.apache.beam.runners.dataflow.worker.windmill.work.processing.failures.FailureTracker; import org.apache.beam.runners.dataflow.worker.windmill.work.processing.failures.WorkFailureProcessor; import org.apache.beam.sdk.annotations.Internal; -import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.fn.IdGenerator; import org.apache.beam.vendor.grpc.v1p69p0.com.google.protobuf.ByteString; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.checkerframework.checker.nullness.qual.Nullable; -import org.joda.time.Duration; import org.joda.time.Instant; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -78,7 +78,6 @@ public class StreamingWorkScheduler { private static final Logger LOG = LoggerFactory.getLogger(StreamingWorkScheduler.class); - private final DataflowWorkerHarnessOptions options; private final Supplier clock; private final ComputationWorkExecutorFactory computationWorkExecutorFactory; private final SideInputStateFetcherFactory sideInputStateFetcherFactory; @@ -86,33 +85,31 @@ public class StreamingWorkScheduler { private final WorkFailureProcessor workFailureProcessor; private final StreamingCommitFinalizer commitFinalizer; private final StreamingCounters streamingCounters; - private final HotKeyLogger hotKeyLogger; private final ConcurrentMap stageInfoMap; private final DataflowExecutionStateSampler sampler; private final StreamingGlobalConfigHandle globalConfigHandle; + private final BoundedQueueExecutor workExecutor; public StreamingWorkScheduler( - DataflowWorkerHarnessOptions options, Supplier clock, + BoundedQueueExecutor workExecutor, ComputationWorkExecutorFactory computationWorkExecutorFactory, SideInputStateFetcherFactory sideInputStateFetcherFactory, FailureTracker failureTracker, WorkFailureProcessor workFailureProcessor, StreamingCommitFinalizer commitFinalizer, StreamingCounters streamingCounters, - HotKeyLogger hotKeyLogger, ConcurrentMap stageInfoMap, DataflowExecutionStateSampler sampler, StreamingGlobalConfigHandle globalConfigHandle) { - this.options = options; this.clock = clock; + this.workExecutor = workExecutor; this.computationWorkExecutorFactory = computationWorkExecutorFactory; this.sideInputStateFetcherFactory = sideInputStateFetcherFactory; this.failureTracker = failureTracker; this.workFailureProcessor = workFailureProcessor; this.commitFinalizer = commitFinalizer; this.streamingCounters = streamingCounters; - this.hotKeyLogger = hotKeyLogger; this.stageInfoMap = stageInfoMap; this.sampler = sampler; this.globalConfigHandle = globalConfigHandle; @@ -143,18 +140,18 @@ public static StreamingWorkScheduler create( sampler, streamingCounters.pendingDeltaCounters(), idGenerator, - globalConfigHandle); + globalConfigHandle, + hotKeyLogger); return new StreamingWorkScheduler( - options, clock, + workExecutor, computationWorkExecutorFactory, SideInputStateFetcherFactory.fromOptions(options), failureTracker, workFailureProcessor, StreamingCommitFinalizer.create(workExecutor, commitFinalizerCleanupExecutor), streamingCounters, - hotKeyLogger, stageInfoMap, sampler, globalConfigHandle); @@ -191,15 +188,8 @@ private static void setUpWorkLoggingContext(String workLatencyTrackingId, String DataflowWorkerLoggingMDC.setStageName(computationId); } - private static String getShuffleTaskStepName(MapTask mapTask) { - // The MapTask instruction is ordered by dependencies, such that the first element is - // always going to be the shuffle task. - return mapTask.getInstructions().get(0).getName(); - } - /** Resets logging context of the Thread executing the {@link Work} for logging. */ - private void resetWorkLoggingContext(String workLatencyTrackingId) { - sampler.resetForWorkId(workLatencyTrackingId); + private void resetWorkLoggingContext() { DataflowWorkerLoggingMDC.setWorkId(null); DataflowWorkerLoggingMDC.setStageName(null); } @@ -246,10 +236,9 @@ private void processWork( } private void processWork( - ComputationState computationState, Work work, BoundedQueueExecutorWorkHandle unusedHandle) { + ComputationState computationState, Work work, BoundedQueueExecutorWorkHandle handle) { Windmill.WorkItem workItem = work.getWorkItem(); String computationId = computationState.getComputationId(); - ByteString key = workItem.getKey(); work.setProcessingThreadName(Thread.currentThread().getName()); work.setState(Work.State.PROCESSING); setUpWorkLoggingContext(work.getLatencyTrackingId(), computationId); @@ -258,37 +247,36 @@ private void processWork( // Before any processing starts, call any pending OnCommit callbacks. Nothing that requires // cleanup should be done before this, since we might exit early here. commitFinalizer.finalizeCommits(workItem.getSourceState().getFinalizeIdsList()); + if (workItem.getSourceState().getOnlyFinalize()) { - Windmill.WorkItemCommitRequest.Builder outputBuilder = initializeOutputBuilder(key, workItem); - outputBuilder.setSourceStateUpdates(Windmill.SourceState.newBuilder().setOnlyFinalize(true)); - work.setState(Work.State.COMMIT_QUEUED); - work.queueCommit(outputBuilder.build(), computationState); + handleOnlyFinalize(computationState, work, workItem); return; } long processingStartTimeNanos = System.nanoTime(); - MapTask mapTask = computationState.getMapTask(); - StageInfo stageInfo = - stageInfoMap.computeIfAbsent( - mapTask.getStageName(), s -> StageInfo.create(s, mapTask.getSystemName())); + StageInfo stageInfo = getStageInfo(computationState); + List worksToCleanup = null; try { if (work.isFailed()) { throw new WorkItemCancelledException(workItem.getShardingKey()); } - // Execute the user code for the Work. - ExecuteWorkResult executeWorkResult = executeWork(work, stageInfo, computationState); - Windmill.WorkItemCommitRequest.Builder commitRequest = executeWorkResult.commitWorkRequest(); + // Execute the user code for the Work batch. + ExecuteWorkResult executeWorkResult = executeWork(work, stageInfo, computationState, handle); + List workBatch = executeWorkResult.workBatch(); + worksToCleanup = workBatch; + List outputBuilders = + executeWorkResult.outputBuilders(); + Map> accumulatedCallbacks = + executeWorkResult.accumulatedCallbacks(); - // Validate the commit request, possibly requesting truncation if the commitSize is too large. - Windmill.WorkItemCommitRequest validatedCommitRequest = - validateCommitRequestSize(commitRequest.build(), computationId, workItem); + commitFinalizer.cacheCommitFinalizers(accumulatedCallbacks); - // Queue the commit. - work.queueCommit(validatedCommitRequest, computationState); - recordProcessingStats(commitRequest, workItem, executeWorkResult); - LOG.debug("Processing done for work token: {}", workItem.getWorkToken()); + commitWorkBatch(computationState, workBatch, outputBuilders); + + recordProcessingStats(workBatch, outputBuilders, executeWorkResult.stateBytesRead()); + LOG.debug("Processing done for work batch size: {}", workBatch.size()); } catch (Throwable t) { // OutOfMemoryError that are caught will be rethrown and trigger jvm termination. try { @@ -306,22 +294,10 @@ private void processWork( throw ExceptionUtils.safeWrapThrowableAsException(t2); } } finally { - // Update total processing time counters. Updating in finally clause ensures that - // work items causing exceptions are also accounted in time spent. - long processingTimeMsecs = - TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - processingStartTimeNanos); - stageInfo.totalProcessingMsecs().addValue(processingTimeMsecs); - - // Attribute all the processing to timers if the work item contains any timers. - // Tests show that work items rarely contain both timers and message bundles. It should - // be a fairly close approximation. - // Another option: Derive time split between messages and timers based on recent totals. - // either here or in DFE. - if (work.getWorkItem().hasTimers()) { - stageInfo.timerProcessingMsecs().addValue(processingTimeMsecs); - } + recordProcessingTime(stageInfo, worksToCleanup, work, processingStartTimeNanos); - resetWorkLoggingContext(work.getLatencyTrackingId()); + resetWorkLoggingContext(); + sampler.resetForWorkId(work.getLatencyTrackingId()); work.setProcessingThreadName(""); } } @@ -354,27 +330,35 @@ private Windmill.WorkItemCommitRequest validateCommitRequestSize( } private void recordProcessingStats( - Windmill.WorkItemCommitRequest.Builder outputBuilder, - Windmill.WorkItem workItem, - ExecuteWorkResult executeWorkResult) { - // Compute shuffle and state byte statistics these will be flushed asynchronously. - long stateBytesWritten = - outputBuilder - .clearOutputMessages() - .clearPerWorkItemLatencyAttributions() - .build() - .getSerializedSize(); - - streamingCounters.windmillShuffleBytesRead().addValue(computeShuffleBytesRead(workItem)); - streamingCounters.windmillStateBytesRead().addValue(executeWorkResult.stateBytesRead()); - streamingCounters.windmillStateBytesWritten().addValue(stateBytesWritten); + List workBatch, + List outputBuilders, + long totalStateBytesRead) { + long totalStateBytesWritten = 0; + long totalShuffleBytesRead = 0; + for (int i = 0; i < workBatch.size(); i++) { + Windmill.WorkItem workItem = workBatch.get(i).getWorkItem(); + Windmill.WorkItemCommitRequest.Builder outputBuilder = outputBuilders.get(i); + // Compute shuffle and state byte statistics these will be flushed asynchronously. + long stateBytesWritten = + outputBuilder + .clearOutputMessages() + .clearPerWorkItemLatencyAttributions() + .build() + .getSerializedSize(); + totalStateBytesWritten += stateBytesWritten; + totalShuffleBytesRead += computeShuffleBytesRead(workItem); + } + streamingCounters.windmillShuffleBytesRead().addValue(totalShuffleBytesRead); + streamingCounters.windmillStateBytesRead().addValue(totalStateBytesRead); + streamingCounters.windmillStateBytesWritten().addValue(totalStateBytesWritten); } private ExecuteWorkResult executeWork( - Work work, StageInfo stageInfo, ComputationState computationState) throws Exception { - Windmill.WorkItem workItem = work.getWorkItem(); - ByteString key = workItem.getKey(); - Windmill.WorkItemCommitRequest.Builder outputBuilder = initializeOutputBuilder(key, workItem); + Work work, + StageInfo stageInfo, + ComputationState computationState, + BoundedQueueExecutorWorkHandle handle) + throws Exception { ComputationWorkExecutor computationWorkExecutor = computationState .acquireComputationWorkExecutor() @@ -388,86 +372,143 @@ private ExecuteWorkResult executeWork( SideInputStateFetcher localSideInputStateFetcher = sideInputStateFetcherFactory.createSideInputStateFetcher(work::fetchSideInput); - // If the read output KVs, then we can decode Windmill's byte key into userland - // key object and provide it to the execution context for use with per-key state. - // Otherwise, we pass null. - // - // The coder type that will be present is: - // WindowedValueCoder(TimerOrElementCoder(KvCoder)) - Optional> keyCoder = computationWorkExecutor.keyCoder(); - @SuppressWarnings("deprecation") - @Nullable - final Object executionKey = - !keyCoder.isPresent() ? null : keyCoder.get().decode(key.newInput(), Coder.Context.OUTER); - - if (workItem.hasHotKeyInfo()) { - Windmill.HotKeyInfo hotKeyInfo = workItem.getHotKeyInfo(); - Duration hotKeyAge = Duration.millis(hotKeyInfo.getHotKeyAgeUsec() / 1000); - - String stepName = getShuffleTaskStepName(computationState.getMapTask()); - if (executionKey != null - && (options.isHotKeyLoggingEnabled() - || hasExperiment(options, "enable_hot_key_logging")) - && keyCoder.isPresent()) { - hotKeyLogger.logHotKeyDetection(stepName, hotKeyAge, executionKey); - } else { - hotKeyLogger.logHotKeyDetection(stepName, hotKeyAge); - } - } + StreamingModeExecutionContext.KeySwitchListener keySwitchListener = + createKeySwitchListener(computationState); // Blocks while executing work. computationWorkExecutor.executeWork( - executionKey, work, stateReader, localSideInputStateFetcher, outputBuilder); + work, stateReader, localSideInputStateFetcher, workExecutor, handle, keySwitchListener); - if (work.isFailed()) { - throw new WorkItemCancelledException(workItem.getShardingKey()); + StreamingModeExecutionContext context = computationWorkExecutor.context(); + if (context.workIsFailed()) { + throw new WorkItemCancelledException( + Preconditions.checkNotNull(context.getWorkItem()).getShardingKey()); } - // Reports source bytes processed to WorkItemCommitRequest if available. - try { - long sourceBytesProcessed = - computationWorkExecutor.computeSourceBytesProcessed( - computationState.sourceBytesProcessCounterName()); - outputBuilder.setSourceBytesProcessed(sourceBytesProcessed); - } catch (Exception e) { - LOG.error("{}", e.toString()); - } - - commitFinalizer.cacheCommitFinalizers(computationWorkExecutor.context().flushState()); + // Retrieve executed works, output builders, and accumulated callbacks from execution context + List workBatch = context.getExecutedWorks(); + List outputBuilders = context.getOutputBuilders(); + Map> accumulatedCallbacks = context.getAccumulatedCallbacks(); + context.clear(); // Release the execution state for another thread to use. computationState.releaseComputationWorkExecutor(computationWorkExecutor); computationWorkExecutor = null; - work.setState(Work.State.COMMIT_QUEUED); - outputBuilder.addAllPerWorkItemLatencyAttributions(work.getLatencyAttributions(sampler)); - return ExecuteWorkResult.create( - outputBuilder, stateReader.getBytesRead() + localSideInputStateFetcher.getBytesRead()); + workBatch, + outputBuilders, + accumulatedCallbacks, + context.getStateBytesRead() + localSideInputStateFetcher.getBytesRead()); } catch (Throwable t) { if (computationWorkExecutor != null) { // If processing failed due to a thrown exception, close the executionState. Do not // return/release the executionState back to computationState as that will lead to this // executionState instance being reused. - LOG.debug("Invalidating executor after work item {} failed", workItem.getWorkToken(), t); + LOG.debug( + "Invalidating executor after work item {} failed", + work.getWorkItem().getWorkToken(), + t); computationWorkExecutor.invalidate(); } - - // Re-throw the exception, it will be caught and handled by workFailureProcessor downstream. throw t; } } + private void handleOnlyFinalize( + ComputationState computationState, Work work, Windmill.WorkItem workItem) { + Windmill.WorkItemCommitRequest.Builder outputBuilder = + initializeOutputBuilder(workItem.getKey(), workItem); + outputBuilder.setSourceStateUpdates(Windmill.SourceState.newBuilder().setOnlyFinalize(true)); + work.setState(Work.State.COMMIT_QUEUED); + work.queueCommit(outputBuilder.build(), computationState); + } + + private StageInfo getStageInfo(ComputationState computationState) { + MapTask mapTask = computationState.getMapTask(); + return stageInfoMap.computeIfAbsent( + mapTask.getStageName(), s -> StageInfo.create(s, mapTask.getSystemName())); + } + + private void commitWorkBatch( + ComputationState computationState, + List workBatch, + List outputBuilders) { + Preconditions.checkState( + workBatch.size() == 1, "Expected single-key work batch, got: " + workBatch.size()); + commitSingleKeyWork(computationState, workBatch.get(0), outputBuilders.get(0)); + } + + private void commitSingleKeyWork( + ComputationState computationState, + Work work, + Windmill.WorkItemCommitRequest.Builder commitRequestBuilder) { + // Validate the commit request, possibly requesting truncation if the commitSize is too large. + Windmill.WorkItemCommitRequest validatedCommitRequest = + validateCommitRequestSize( + commitRequestBuilder.build(), computationState.getComputationId(), work.getWorkItem()); + work.setState(Work.State.COMMIT_QUEUED); + validatedCommitRequest = + validatedCommitRequest + .toBuilder() + .addAllPerWorkItemLatencyAttributions(work.getLatencyAttributions(sampler)) + .build(); + work.queueCommit(validatedCommitRequest, computationState); + } + + private void recordProcessingTime( + StageInfo stageInfo, + @Nullable List worksToCleanup, + Work work, + long processingStartTimeNanos) { + // Update total processing time counters. Updating in finally clause ensures that + // work items causing exceptions are also accounted in time spent. + long processingTimeMsecs = + TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - processingStartTimeNanos); + stageInfo.totalProcessingMsecs().addValue(processingTimeMsecs); + if (anyWorkHasTimers(worksToCleanup, work)) { + // Attribute all the processing to timers if the work item contains any timers. + // Tests show that work items rarely contain both timers and message bundles. It should + // be a fairly close approximation. + // Another option: Derive time split between messages and timers based on recent totals. + // either here or in DFE. + stageInfo.timerProcessingMsecs().addValue(processingTimeMsecs); + } + } + + private static boolean anyWorkHasTimers(@Nullable List works, Work primaryWork) { + if (works != null && !works.isEmpty()) { + return works.stream().anyMatch(w -> w.getWorkItem().hasTimers()); + } + return primaryWork.getWorkItem().hasTimers(); + } + + private StreamingModeExecutionContext.KeySwitchListener createKeySwitchListener( + ComputationState computationState) { + return (oldWork, newWork) -> { + resetWorkLoggingContext(); + setUpWorkLoggingContext(newWork.getLatencyTrackingId(), computationState.getComputationId()); + newWork.setProcessingThreadName(Thread.currentThread().getName()); + oldWork.setProcessingThreadName(""); + }; + } + @AutoValue abstract static class ExecuteWorkResult { - - private static ExecuteWorkResult create( - Windmill.WorkItemCommitRequest.Builder commitWorkRequest, long stateBytesRead) { + static ExecuteWorkResult create( + List workBatch, + List outputBuilders, + Map> accumulatedCallbacks, + long stateBytesRead) { return new AutoValue_StreamingWorkScheduler_ExecuteWorkResult( - commitWorkRequest, stateBytesRead); + workBatch, outputBuilders, accumulatedCallbacks, stateBytesRead); } - abstract Windmill.WorkItemCommitRequest.Builder commitWorkRequest(); + abstract List workBatch(); + + abstract List outputBuilders(); + + abstract Map> accumulatedCallbacks(); abstract long stateBytesRead(); } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java index d58f20076994..f7511305bf0f 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java @@ -571,11 +571,16 @@ private Windmill.GetWorkResponse buildInput(String input, byte[] metadata) throw Windmill.GetWorkResponse.Builder builder = Windmill.GetWorkResponse.newBuilder(); TextFormat.merge(input, builder); if (metadata != null) { - Windmill.InputMessageBundle.Builder messageBundleBuilder = - builder.getWorkBuilder(0).getWorkBuilder(0).getMessageBundlesBuilder(0); - for (Windmill.Message.Builder messageBuilder : - messageBundleBuilder.getMessagesBuilderList()) { - messageBuilder.setMetadata(addPaneTag(PaneInfo.NO_FIRING, metadata)); + for (Windmill.ComputationWorkItems.Builder compBuilder : builder.getWorkBuilderList()) { + for (Windmill.WorkItem.Builder workBuilder : compBuilder.getWorkBuilderList()) { + for (Windmill.InputMessageBundle.Builder messageBundleBuilder : + workBuilder.getMessageBundlesBuilderList()) { + for (Windmill.Message.Builder messageBuilder : + messageBundleBuilder.getMessagesBuilderList()) { + messageBuilder.setMetadata(addPaneTag(PaneInfo.NO_FIRING, metadata)); + } + } + } } } @@ -1327,7 +1332,7 @@ public void testKeyCommitTooLargeException() throws Exception { makeExpectedTruncationRequestOutput( 1, "large_key", DEFAULT_SHARDING_KEY, largeCommit.getEstimatedWorkItemCommitBytes()) .build(), - largeCommit); + removeDynamicFields(largeCommit)); // Check this explicitly since the estimated commit bytes weren't actually // checked against an expected value in the previous step @@ -3507,8 +3512,8 @@ public void testExceptionInvalidatesCache() throws Exception { } // Ensure that the invalidated dofn had tearDown called on them. - assertEquals(1, TestExceptionInvalidatesCacheFn.tearDownCallCount.get()); - assertEquals(2, TestExceptionInvalidatesCacheFn.setupCallCount.get()); + assertEquals(2, TestExceptionInvalidatesCacheFn.tearDownCallCount.get()); + assertEquals(3, TestExceptionInvalidatesCacheFn.setupCallCount.get()); worker.stop(); } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java index 216ca5386675..9d4ef999707c 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java @@ -71,6 +71,7 @@ import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillTagEncodingV2; import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender; import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.metrics.MetricsContainer; import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.state.TimeDomain; @@ -139,7 +140,43 @@ public void setUp() { executionStateRegistry, globalConfigHandle, Long.MAX_VALUE, - /*throwExceptionOnLargeOutput=*/ false); + /*throwExceptionOnLargeOutput=*/ false, + new HotKeyLogger(), + /*hotKeyLoggingEnabled=*/ false, + /*stepName=*/ "stepName", + "sourceBytesProcessCounterName"); + } + + private StreamingModeExecutionContext createTestExecutionContext( + DataflowWorkerHarnessOptions options) { + CounterSet counterSet = new CounterSet(); + ConcurrentHashMap stateNameMap = new ConcurrentHashMap<>(); + stateNameMap.put(NameContextsForTests.nameContextForTest().userName(), "testStateFamily"); + return new StreamingModeExecutionContext( + counterSet, + COMPUTATION_ID, + new ReaderCache(Duration.standardMinutes(1), Executors.newCachedThreadPool()), + stateNameMap, + WindmillStateCache.builder() + .setSizeMb(options.getWorkerCacheMb()) + .build() + .forComputation("comp"), + StreamingStepMetricsContainer.createRegistry(), + new DataflowExecutionStateTracker( + ExecutionStateSampler.newForTest(), + executionStateRegistry.getState( + NameContext.forStage("stage"), "other", null, NoopProfileScope.NOOP), + counterSet, + PipelineOptionsFactory.create(), + "test-work-item-id"), + executionStateRegistry, + globalConfigHandle, + Long.MAX_VALUE, + /*throwExceptionOnLargeOutput=*/ false, + new HotKeyLogger(), + /*hotKeyLoggingEnabled=*/ false, + /*stepName=*/ "stepName", + "sourceBytesProcessCounterName"); } private static Work createMockWork(Windmill.WorkItem workItem, Watermarks watermarks) { @@ -153,25 +190,42 @@ COMPUTATION_ID, new FakeGetDataClient(), ignored -> {}, mock(HeartbeatSender.cla Instant::now); } + private void start(Work work) { + start(executionContext, work, null); + } + + private void start(Work work, Coder keyCoder) { + start(executionContext, work, keyCoder); + } + + private void start(StreamingModeExecutionContext context, Work work) { + start(context, work, null); + } + + private void start(StreamingModeExecutionContext context, Work work, Coder keyCoder) { + context.start( + work, + stateReader, + sideInputStateFetcher, + workExecutor, + /* workQueueExecutor= */ null, + /* budgetHandle= */ null, + keyCoder, + /* keySwitchListener= */ (k, c) -> {}); + } + @Test public void testTimerInternalsSetTimer() throws Exception { - Windmill.WorkItemCommitRequest.Builder outputBuilder = - Windmill.WorkItemCommitRequest.newBuilder(); NameContext nameContext = NameContextsForTests.nameContextForTest(); DataflowOperationContext operationContext = executionContext.createOperationContext(nameContext); StreamingModeExecutionContext.StepContext stepContext = executionContext.getStepContext(operationContext); - executionContext.start( - "key", + start( createMockWork( Windmill.WorkItem.newBuilder().setKey(ByteString.EMPTY).setWorkToken(17L).build(), - Watermarks.builder().setInputDataWatermark(new Instant(1000)).build()), - stateReader, - sideInputStateFetcher, - outputBuilder, - workExecutor); + Watermarks.builder().setInputDataWatermark(new Instant(1000)).build())); TimerInternals timerInternals = stepContext.timerInternals(); @@ -185,6 +239,7 @@ public void testTimerInternalsSetTimer() throws Exception { executionContext.finishKey(); executionContext.flushState(); + Windmill.WorkItemCommitRequest.Builder outputBuilder = executionContext.getOutputBuilder(); Windmill.Timer timer = outputBuilder.buildPartial().getOutputTimers(0); assertThat(timer.getTag().toStringUtf8(), equalTo("/skey+0:5000")); assertThat(timer.getTimestamp(), equalTo(TimeUnit.MILLISECONDS.toMicros(5000))); @@ -193,9 +248,6 @@ public void testTimerInternalsSetTimer() throws Exception { @Test public void testTimerInternalsProcessingTimeSkew() { - Windmill.WorkItemCommitRequest.Builder outputBuilder = - Windmill.WorkItemCommitRequest.newBuilder(); - NameContext nameContext = NameContextsForTests.nameContextForTest(); DataflowOperationContext operationContext = executionContext.createOperationContext(nameContext); @@ -215,15 +267,10 @@ public void testTimerInternalsProcessingTimeSkew() { .setTimestamp(timerTimestamp.getMillis() * 1000) .setType(Windmill.Timer.Type.REALTIME); - executionContext.start( - "key", + start( createMockWork( workItemBuilder.build(), - Watermarks.builder().setInputDataWatermark(new Instant(1000)).build()), - stateReader, - sideInputStateFetcher, - outputBuilder, - workExecutor); + Watermarks.builder().setInputDataWatermark(new Instant(1000)).build())); TimerInternals timerInternals = stepContext.timerInternals(); assertTrue(timerTimestamp.isBefore(timerInternals.currentProcessingTime())); } @@ -421,47 +468,62 @@ public void testStateTagEncodingBasedOnConfig() { for (Boolean isV2Encoding : Lists.newArrayList(Boolean.TRUE, Boolean.FALSE)) { Class expectedEncoding = isV2Encoding ? WindmillTagEncodingV2.class : WindmillTagEncodingV1.class; - Windmill.WorkItemCommitRequest.Builder outputBuilder = - Windmill.WorkItemCommitRequest.newBuilder(); globalConfigHandle.setConfig( StreamingGlobalConfig.builder().setEnableStateTagEncodingV2(isV2Encoding).build()); - executionContext.start( - "key", + start( createMockWork( Windmill.WorkItem.newBuilder().setKey(ByteString.EMPTY).setWorkToken(17L).build(), - Watermarks.builder().setInputDataWatermark(new Instant(1000)).build()), - stateReader, - sideInputStateFetcher, - outputBuilder, - workExecutor); + Watermarks.builder().setInputDataWatermark(new Instant(1000)).build())); assertEquals(expectedEncoding, executionContext.getWindmillTagEncoding().getClass()); } } @Test public void testSetBacklogBytes() { - Windmill.WorkItemCommitRequest.Builder outputBuilder = - Windmill.WorkItemCommitRequest.newBuilder(); NameContext nameContext = NameContextsForTests.nameContextForTest(); DataflowOperationContext operationContext = executionContext.createOperationContext(nameContext); StreamingModeExecutionContext.StepContext stepContext = executionContext.getStepContext(operationContext); - executionContext.start( - "key", + start( createMockWork( Windmill.WorkItem.newBuilder().setKey(ByteString.EMPTY).setWorkToken(17L).build(), - Watermarks.builder().setInputDataWatermark(new Instant(1000)).build()), - stateReader, - sideInputStateFetcher, - outputBuilder, - workExecutor); + Watermarks.builder().setInputDataWatermark(new Instant(1000)).build())); stepContext.setBacklogBytes(1234.0); executionContext.finishKey(); executionContext.flushState(); - assertEquals(1234, outputBuilder.getSourceBacklogBytes()); + assertEquals(1234, executionContext.getOutputBuilder().getSourceBacklogBytes()); + } + + @Test + public void testFinishKeyReentrantSafety() { + start( + createMockWork( + Windmill.WorkItem.newBuilder().setKey(ByteString.EMPTY).setWorkToken(17L).build(), + Watermarks.builder().setInputDataWatermark(new Instant(1000)).build())); + + // First call + executionContext.finishKey(); + // Second call - should not throw any Exception + executionContext.finishKey(); + } + + @Test + public void testStart_internalKeyDecoding() throws Exception { + Windmill.WorkItem workItem = + Windmill.WorkItem.newBuilder() + .setKey(ByteString.copyFromUtf8("decodedKey")) + .setWorkToken(17L) + .build(); + Work work = + createMockWork( + workItem, Watermarks.builder().setInputDataWatermark(new Instant(1000)).build()); + + start(work, org.apache.beam.sdk.coders.StringUtf8Coder.of()); + + assertEquals("decodedKey", executionContext.getKey()); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindmillReaderIteratorBaseTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindmillReaderIteratorBaseTest.java index 539c38eeb1da..a56343e3dfb3 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindmillReaderIteratorBaseTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindmillReaderIteratorBaseTest.java @@ -30,6 +30,7 @@ import java.util.Arrays; import java.util.List; import java.util.concurrent.ThreadLocalRandom; +import org.apache.beam.runners.dataflow.worker.streaming.Work; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; import org.apache.beam.sdk.coders.CoderException; import org.apache.beam.sdk.options.ValueProvider; @@ -122,6 +123,7 @@ public void testFinishKeyCalled() throws Exception { .build()) .build(); when(mockContext.getWorkItem()).thenReturn(workItem); + when(mockContext.advance()).thenReturn(false); try (TestWindmillReaderIterator iter = new TestWindmillReaderIterator(mockContext, ValueProvider.StaticValueProvider.of(false))) { @@ -131,6 +133,78 @@ public void testFinishKeyCalled() throws Exception { } } + @Test + public void testAdvanceKeyChaining() throws Exception { + StreamingModeExecutionContext mockContext = mock(StreamingModeExecutionContext.class); + when(mockContext.workIsFailed()).thenReturn(false); + + // Work item A (1 message) + Windmill.WorkItem workItemA = + Windmill.WorkItem.newBuilder() + .setKey(ByteString.copyFromUtf8("keyA")) + .setWorkToken(100L) + .addMessageBundles( + Windmill.InputMessageBundle.newBuilder() + .setSourceComputationId("foo") + .addMessages( + Windmill.Message.newBuilder() + .setTimestamp(1000) + .setData(ByteString.EMPTY) + .build()) + .build()) + .build(); + when(mockContext.getWorkItem()).thenReturn(workItemA); + + // Work item B (1 message) + Windmill.WorkItem workItemB = + Windmill.WorkItem.newBuilder() + .setKey(ByteString.copyFromUtf8("keyB")) + .setWorkToken(200L) + .addMessageBundles( + Windmill.InputMessageBundle.newBuilder() + .setSourceComputationId("foo") + .addMessages( + Windmill.Message.newBuilder() + .setTimestamp(2000) + .setData(ByteString.EMPTY) + .build()) + .build()) + .build(); + + Work mockWorkB = createMockWork(workItemB); + + // Set up context.advance() to mock transition + when(mockContext.advance()) + .thenAnswer( + new org.mockito.stubbing.Answer() { + private int count = 0; + + @Override + public Boolean answer(org.mockito.invocation.InvocationOnMock invocation) { + if (count == 0) { + count++; + when(mockContext.getWork()).thenReturn(mockWorkB); + return true; + } + return false; + } + }); + + try (TestWindmillReaderIterator iter = + new TestWindmillReaderIterator(mockContext, ValueProvider.StaticValueProvider.of(false))) { + assertTrue(iter.start()); + assertEquals(1000L, iter.getCurrent().getValue().longValue()); + + // Advance should trigger context.advance(), transition to workItemB, and decode message from + // workItemB (timestamp 2000) + assertTrue(iter.advance()); + assertEquals(2000L, iter.getCurrent().getValue().longValue()); + + // Next advance should exhaust it and return false + assertFalse(iter.advance()); + } + } + private void testForMessageBundleCounts(int... messageBundleCounts) throws IOException { testForMessageBundleCounts(false, messageBundleCounts); } @@ -179,4 +253,24 @@ private void testForMessageBundleCounts(boolean skipErrors, int... messageBundle assertEquals(Arrays.toString(messageBundleCounts) + skipErrors, expected, actual); } } + + private static Work createMockWork(Windmill.WorkItem workItem) { + return Work.create( + workItem, + workItem.getSerializedSize(), + org.apache.beam.runners.dataflow.worker.streaming.Watermarks.builder() + .setInputDataWatermark(new org.joda.time.Instant(1000)) + .build(), + Work.createProcessingContext( + "computationId", + mock( + org.apache.beam.runners.dataflow.worker.windmill.client.getdata.GetDataClient + .class), + ignored -> {}, + mock( + org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender + .class)), + false, + org.joda.time.Instant::now); + } } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java index d5cf2948d928..4175b47bfe4f 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java @@ -209,6 +209,18 @@ COMPUTATION_ID, new FakeGetDataClient(), ignored -> {}, mock(HeartbeatSender.cla Instant::now); } + private void startContext(StreamingModeExecutionContext context, Work work) { + context.start( + work, + mock(WindmillStateReader.class), + mock(SideInputStateFetcher.class), + mock(WorkExecutor.class), + /* workQueueExecutor= */ null, + /* budgetHandle= */ null, + /* keyCoder= */ null, + /* keySwitchListener= */ mock(StreamingModeExecutionContext.KeySwitchListener.class)); + } + private static class SourceProducingSubSourcesInSplit extends MockSource { int numDesiredBundle; int sourceObjectSize; @@ -620,7 +632,11 @@ public void testReadUnboundedReader() throws Exception { executionStateRegistry, globalConfigHandle, Long.MAX_VALUE, - /*throwExceptionOnLargeOutput=*/ false); + /*throwExceptionOnLargeOutput=*/ false, + new HotKeyLogger(), + /*hotKeyLoggingEnabled=*/ false, + /*stepName=*/ "stepName", + "sourceBytesProcessCounterName"); options.setNumWorkers(5); int maxElements = 10; @@ -631,8 +647,8 @@ public void testReadUnboundedReader() throws Exception { for (int i = 0; i < 10 * maxElements; /* Incremented in inner loop */ ) { // Initialize streaming context with state from previous iteration. - context.start( - "key", + startContext( + context, createMockWork( Windmill.WorkItem.newBuilder() .setKey(ByteString.copyFromUtf8("0000000000000001")) // key is zero-padded index. @@ -641,11 +657,7 @@ public void testReadUnboundedReader() throws Exception { .setSourceState( Windmill.SourceState.newBuilder().setState(state).build()) // Source state. .build(), - Watermarks.builder().setInputDataWatermark(new Instant(0)).build()), - mock(WindmillStateReader.class), - mock(SideInputStateFetcher.class), - Windmill.WorkItemCommitRequest.newBuilder(), - mock(WorkExecutor.class)); + Watermarks.builder().setInputDataWatermark(new Instant(0)).build())); @SuppressWarnings({"unchecked", "rawtypes"}) NativeReader>>> reader = @@ -992,7 +1004,11 @@ public void testFailedWorkItemsAbort() throws Exception { executionStateRegistry, globalConfigHandle, Long.MAX_VALUE, - /*throwExceptionOnLargeOutput=*/ false); + /*throwExceptionOnLargeOutput=*/ false, + new HotKeyLogger(), + /*hotKeyLoggingEnabled=*/ false, + /*stepName=*/ "stepName", + "sourceBytesProcessCounterName"); options.setNumWorkers(5); int maxElements = 100; @@ -1020,13 +1036,7 @@ public void testFailedWorkItemsAbort() throws Exception { mock(HeartbeatSender.class)), false, Instant::now); - context.start( - "key", - dummyWork, - mock(WindmillStateReader.class), - mock(SideInputStateFetcher.class), - Windmill.WorkItemCommitRequest.newBuilder(), - mock(WorkExecutor.class)); + startContext(context, dummyWork); @SuppressWarnings({"unchecked", "rawtypes"}) NativeReader>>> reader = From 53bc9a690b6fa311d1b3f21e95c41399e7487091 Mon Sep 17 00:00:00 2001 From: Arun Pandian Date: Thu, 4 Jun 2026 11:05:57 +0000 Subject: [PATCH 02/21] trigger postsubmit tests --- ...beam_PostCommit_Java_ValidatesRunner_Dataflow_Streaming.json | 2 +- ...stCommit_Java_ValidatesRunner_Dataflow_Streaming_Engine.json | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Dataflow_Streaming.json b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Dataflow_Streaming.json index e623d3373a93..50d17c108f2e 100644 --- a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Dataflow_Streaming.json +++ b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Dataflow_Streaming.json @@ -1,4 +1,4 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run!", - "modification": 1, + "modification": 2, } diff --git a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Dataflow_Streaming_Engine.json b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Dataflow_Streaming_Engine.json index e623d3373a93..50d17c108f2e 100644 --- a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Dataflow_Streaming_Engine.json +++ b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Dataflow_Streaming_Engine.json @@ -1,4 +1,4 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run!", - "modification": 1, + "modification": 2, } From 7720cf8fc61ebae6a8da5317d3324d4bf307c182 Mon Sep 17 00:00:00 2001 From: Arun Pandian Date: Thu, 4 Jun 2026 18:53:19 +0000 Subject: [PATCH 03/21] fix tests --- .../worker/StreamingModeExecutionContext.java | 6 ++++++ .../worker/StreamingDataflowWorkerTest.java | 17 ++++++++++++----- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java index a669fb7ff361..3f62dcbd038f 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java @@ -296,6 +296,9 @@ public void clear() { this.workQueueExecutor = null; this.budgetHandle = null; this.keySwitchListener = null; + this.work = null; + this.key = null; + this.outputBuilder = null; } public void start( @@ -693,6 +696,9 @@ public boolean advance() { } private void startForNewKey(Work newWork, WindmillStateReader reader) { + if (keySwitchListener != null && this.work != null && this.work != newWork) { + keySwitchListener.onKeySwitch(this.work, newWork); + } this.key = decodeKey(newWork); this.work = newWork; this.finishKeyCalled = false; diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java index f7511305bf0f..2591db19ec00 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java @@ -420,6 +420,7 @@ private ParallelInstruction makeWindowingSourceInstruction(Coder coder) { CloudObjects.asCloudObject(IntervalWindowCoder.of(), /* sdkComponents= */ null))); return new ParallelInstruction() + .setName(DEFAULT_SOURCE_SYSTEM_NAME) .setSystemName(DEFAULT_SOURCE_SYSTEM_NAME) .setOriginalName(DEFAULT_SOURCE_ORIGINAL_NAME) .setRead( @@ -439,6 +440,7 @@ private ParallelInstruction makeWindowingSourceInstruction(Coder coder) { private ParallelInstruction makeSourceInstruction(Coder coder) { return new ParallelInstruction() + .setName(DEFAULT_SOURCE_SYSTEM_NAME) .setSystemName(DEFAULT_SOURCE_SYSTEM_NAME) .setOriginalName(DEFAULT_SOURCE_ORIGINAL_NAME) .setRead( @@ -3955,11 +3957,16 @@ public void testDoFnLatencyBreakdownsReportedOnCommit() throws Exception { LatencyAttribution.newBuilder().setState(State.ACTIVE).setTotalDurationMillis(100); for (LatencyAttribution la : commit.getPerWorkItemLatencyAttributionsList()) { if (la.getState() == State.ACTIVE) { - assertThat(la.getActiveLatencyBreakdownCount(), equalTo(1)); - assertThat( - la.getActiveLatencyBreakdown(0).getUserStepName(), equalTo(DEFAULT_PARDO_USER_NAME)); - Assert.assertTrue(la.getActiveLatencyBreakdown(0).hasProcessingTimesDistribution()); - Assert.assertFalse(la.getActiveLatencyBreakdown(0).hasActiveMessageMetadata()); + LatencyAttribution.ActiveLatencyBreakdown pardoBreakdown = null; + for (LatencyAttribution.ActiveLatencyBreakdown lb : la.getActiveLatencyBreakdownList()) { + if (DEFAULT_PARDO_USER_NAME.equals(lb.getUserStepName())) { + pardoBreakdown = lb; + break; + } + } + Assert.assertNotNull("Expected breakdown for " + DEFAULT_PARDO_USER_NAME, pardoBreakdown); + Assert.assertTrue(pardoBreakdown.hasProcessingTimesDistribution()); + Assert.assertFalse(pardoBreakdown.hasActiveMessageMetadata()); } } From 72740d25cb74c00ca35e933122f847978619a732 Mon Sep 17 00:00:00 2001 From: Arun Pandian Date: Thu, 4 Jun 2026 21:36:27 +0000 Subject: [PATCH 04/21] fix tests --- .../runners/dataflow/worker/StreamingDataflowWorkerTest.java | 2 ++ 1 file changed, 2 insertions(+) diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java index 2591db19ec00..22350c525ab2 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java @@ -529,6 +529,7 @@ private ParallelInstruction makeSinkInstruction( CloudObject spec = CloudObject.forClass(WindmillSink.class); addString(spec, "stream_id", streamId); return new ParallelInstruction() + .setName(streamId) .setSystemName(DEFAULT_SINK_SYSTEM_NAME) .setOriginalName(DEFAULT_SINK_ORIGINAL_NAME) .setWrite( @@ -2502,6 +2503,7 @@ private List makeUnboundedSourcePipeline( return Arrays.asList( new ParallelInstruction() + .setName("Read") .setSystemName("Read") .setOriginalName("OriginalReadName") .setRead( From 0f96e723ce9524f68f51319e7e9a68fce05f088b Mon Sep 17 00:00:00 2001 From: Arun Pandian Date: Thu, 4 Jun 2026 22:02:56 +0000 Subject: [PATCH 05/21] improve work synchronization --- .../apache/beam/runners/dataflow/worker/streaming/Work.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java index 668657228dfd..78cb54b3575b 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java @@ -241,7 +241,7 @@ public void setProcessingThreadName(String processingThreadName) { } @Override - public synchronized void setFailed() { + public void setFailed() { this.isFailed = true; Runnable listener = onFailureListener; if (listener != null) { @@ -249,7 +249,7 @@ public synchronized void setFailed() { } } - public synchronized void setOnFailureListener(@Nullable Runnable listener) { + public void setOnFailureListener(@Nullable Runnable listener) { this.onFailureListener = listener; if (isFailed && listener != null) { listener.run(); From 51b9257c54260d9dbc188ce70121aef9fefda054 Mon Sep 17 00:00:00 2001 From: Arun Pandian Date: Fri, 5 Jun 2026 01:03:21 +0000 Subject: [PATCH 06/21] cleanup logic --- .../windmill/work/processing/StreamingWorkScheduler.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java index 9ee2192b09d8..5e890ef3d635 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java @@ -381,8 +381,7 @@ private ExecuteWorkResult executeWork( StreamingModeExecutionContext context = computationWorkExecutor.context(); if (context.workIsFailed()) { - throw new WorkItemCancelledException( - Preconditions.checkNotNull(context.getWorkItem()).getShardingKey()); + throw new WorkItemCancelledException(work.getWorkItem().getShardingKey()); } // Retrieve executed works, output builders, and accumulated callbacks from execution context @@ -411,6 +410,7 @@ private ExecuteWorkResult executeWork( t); computationWorkExecutor.invalidate(); } + // Re-throw the exception, it will be caught and handled by workFailureProcessor downstream. throw t; } } From 9a4e7bea7894a9b38c19f6db813c256809da83bc Mon Sep 17 00:00:00 2001 From: Arun Pandian Date: Fri, 5 Jun 2026 01:11:34 +0000 Subject: [PATCH 07/21] cleanup logic --- .../dataflow/worker/StreamingModeExecutionContext.java | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java index 3f62dcbd038f..770093dbcce4 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java @@ -560,8 +560,9 @@ public void invalidateCache() { try { activeReader.close(); } catch (IOException e) { - LOG.warn( - "Failed to close reader for {}-{}", computationId, getWorkItem().getShardingKey(), e); + Windmill.WorkItem workItem = getWorkItem(); + long shardingKey = workItem != null ? workItem.getShardingKey() : -1L; + LOG.warn("Failed to close reader for {}-{}", computationId, shardingKey, e); } } activeReader = null; From 3f36afd72f5529da7389b2da2e47d87d363f21fd Mon Sep 17 00:00:00 2001 From: Arun Pandian Date: Mon, 8 Jun 2026 03:40:43 +0000 Subject: [PATCH 08/21] address comments --- .../worker/StreamingModeExecutionContext.java | 61 +++++--- .../streaming/ComputationWorkExecutor.java | 11 +- .../dataflow/worker/streaming/Work.java | 17 ++- .../processing/StreamingWorkScheduler.java | 126 +++++++++-------- .../StreamingModeExecutionContextTest.java | 11 +- .../worker/WorkerCustomSourcesTest.java | 11 +- .../dataflow/worker/streaming/WorkTest.java | 132 ++++++++++++++++++ 7 files changed, 275 insertions(+), 94 deletions(-) create mode 100644 runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/WorkTest.java diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java index 770093dbcce4..af9f29c7b9ba 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java @@ -35,6 +35,7 @@ import java.util.Optional; import java.util.Set; import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicLong; import javax.annotation.concurrent.NotThreadSafe; import org.apache.beam.repackaged.core.org.apache.commons.lang3.tuple.Pair; @@ -58,6 +59,7 @@ import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInput; import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputState; import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputStateFetcher; +import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputStateFetcherFactory; import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor; import org.apache.beam.runners.dataflow.worker.util.common.worker.ElementCounter; import org.apache.beam.runners.dataflow.worker.util.common.worker.OutputObjectAndByteCounter; @@ -149,6 +151,7 @@ public class StreamingModeExecutionContext private @Nullable Work work; private WindmillComputationKey computationKey; + private SideInputStateFetcherFactory sideInputStateFetcherFactory; private SideInputStateFetcher sideInputStateFetcher; // OperationalLimits is updated in start() because a StreamingModeExecutionContext can // be used for processing many work items and these values can change during the context's @@ -174,22 +177,24 @@ public class StreamingModeExecutionContext private @Nullable BoundedQueueExecutorWorkHandle budgetHandle; private final HotKeyLogger hotKeyLogger; - private boolean hotKeyLoggingEnabled = false; + private final boolean hotKeyLoggingEnabled; private final String stepName; private @Nullable Coder keyCoder; // Key switch listener to delegate MDC logging context and thread name updates - public interface KeySwitchListener { - void onKeySwitch(Work oldWork, Work newWork); + public interface KeyTransitionListener { + void onKeyTransition(Work oldWork, Work newWork); } @SuppressWarnings("UnusedVariable") - private @Nullable KeySwitchListener keySwitchListener; + private @Nullable KeyTransitionListener keyTransitionListener; private List executedWorks = new ArrayList<>(); private List outputBuilders = new ArrayList<>(); + + // Map> private Map> accumulatedCallbacks = new HashMap<>(); - private volatile boolean workIsFailed = false; + private final AtomicBoolean workIsFailed = new AtomicBoolean(false); private @Nullable WindmillStateReader activeStateReader; private long stateBytesRead = 0; private final String sourceBytesProcessCounterName; @@ -248,7 +253,7 @@ public boolean throwExceptionsForLargeOutput() { } public boolean workIsFailed() { - return workIsFailed; + return workIsFailed.get(); } public boolean getDrainMode() { @@ -287,7 +292,7 @@ public void clear() { this.executedWorks = new ArrayList<>(); this.outputBuilders = new ArrayList<>(); this.accumulatedCallbacks = new HashMap<>(); - this.workIsFailed = false; + this.workIsFailed.set(false); this.sideInputCache.clear(); this.activeStateReader = null; this.activeReader = null; @@ -295,31 +300,32 @@ public void clear() { this.workExecutor = null; this.workQueueExecutor = null; this.budgetHandle = null; - this.keySwitchListener = null; + this.keyTransitionListener = null; this.work = null; this.key = null; this.outputBuilder = null; + this.sideInputStateFetcherFactory = null; + this.sideInputStateFetcher = null; + this.backlogBytes = UnboundedReader.BACKLOG_UNKNOWN; + clearSinkFullHint(); + this.stateBytesRead = 0; } public void start( Work work, WindmillStateReader stateReader, - SideInputStateFetcher sideInputStateFetcher, + SideInputStateFetcherFactory sideInputStateFetcherFactory, WorkExecutor workExecutor, BoundedQueueExecutor workQueueExecutor, BoundedQueueExecutorWorkHandle budgetHandle, @Nullable Coder keyCoder, - KeySwitchListener keySwitchListener) { + KeyTransitionListener keyTransitionListener) { clear(); this.keyCoder = keyCoder; this.workExecutor = workExecutor; this.workQueueExecutor = workQueueExecutor; this.budgetHandle = budgetHandle; - this.keySwitchListener = keySwitchListener; - - this.backlogBytes = UnboundedReader.BACKLOG_UNKNOWN; - clearSinkFullHint(); - this.stateBytesRead = 0; + this.keyTransitionListener = keyTransitionListener; StreamingGlobalConfig config = globalConfigHandle.getConfig(); // Snapshot the limits for entire bundle processing. @@ -328,7 +334,7 @@ public void start( config.enableStateTagEncodingV2() ? WindmillTagEncodingV2.instance() : WindmillTagEncodingV1.instance(); - this.sideInputStateFetcher = sideInputStateFetcher; + this.sideInputStateFetcherFactory = sideInputStateFetcherFactory; startForNewKey(work, stateReader); } @@ -388,6 +394,9 @@ public void finishKey() { if (activeStateReader != null) { this.stateBytesRead += activeStateReader.getBytesRead(); } + if (sideInputStateFetcher != null) { + this.stateBytesRead += sideInputStateFetcher.getBytesRead(); + } checkStateNotNull(workExecutor, "workExecutor must be set before calling finishKey()"); try { workExecutor.finishKey(); @@ -697,8 +706,9 @@ public boolean advance() { } private void startForNewKey(Work newWork, WindmillStateReader reader) { - if (keySwitchListener != null && this.work != null && this.work != newWork) { - keySwitchListener.onKeySwitch(this.work, newWork); + newWork.setState(Work.State.PROCESSING); + if (keyTransitionListener != null && this.work != null && this.work != newWork) { + keyTransitionListener.onKeyTransition(this.work, newWork); } this.key = decodeKey(newWork); this.work = newWork; @@ -707,11 +717,16 @@ private void startForNewKey(Work newWork, WindmillStateReader reader) { this.outputBuilder = createOutputBuilder(newWork); this.outputBuilders.add(this.outputBuilder); - newWork.setOnFailureListener(() -> this.workIsFailed = true); + newWork.setOnFailureListener(this.workIsFailed); this.executedWorks.add(newWork); logHotKeyIfDetected(newWork, this.key); + this.sideInputStateFetcher = + sideInputStateFetcherFactory.createSideInputStateFetcher(newWork::fetchSideInput); + this.backlogBytes = UnboundedReader.BACKLOG_UNKNOWN; + this.activeReader = null; + // Note: We do NOT clear sideInputCache here, allowing Key B to reuse warm side inputs! // Re-initialize state cache and state/timer internals across all step contexts @@ -738,8 +753,12 @@ public long getStateBytesRead() { return stateBytesRead; } - public List getOutputBuilders() { - return outputBuilders; + public List getWorkItemCommits() { + List commits = new ArrayList<>(outputBuilders.size()); + for (Windmill.WorkItemCommitRequest.Builder builder : outputBuilders) { + commits.add(builder.build()); + } + return commits; } public Map> getAccumulatedCallbacks() { diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationWorkExecutor.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationWorkExecutor.java index ed86d58b9bb0..31420b212c31 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationWorkExecutor.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationWorkExecutor.java @@ -23,7 +23,8 @@ import org.apache.beam.runners.core.metrics.ExecutionStateTracker; import org.apache.beam.runners.dataflow.worker.DataflowWorkExecutor; import org.apache.beam.runners.dataflow.worker.StreamingModeExecutionContext; -import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputStateFetcher; +import org.apache.beam.runners.dataflow.worker.StreamingModeExecutionContext.KeyTransitionListener; +import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputStateFetcherFactory; import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor; import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateReader; import org.apache.beam.sdk.annotations.Internal; @@ -65,21 +66,21 @@ public static ComputationWorkExecutor.Builder builder() { public final void executeWork( Work work, WindmillStateReader stateReader, - SideInputStateFetcher sideInputStateFetcher, + SideInputStateFetcherFactory sideInputStateFetcherFactory, BoundedQueueExecutor workQueueExecutor, BoundedQueueExecutorWorkHandle budgetHandle, - StreamingModeExecutionContext.KeySwitchListener keySwitchListener) + KeyTransitionListener keyTransitionListener) throws Exception { context() .start( work, stateReader, - sideInputStateFetcher, + sideInputStateFetcherFactory, workExecutor(), workQueueExecutor, budgetHandle, keyCoder().orElse(null), - keySwitchListener); + keyTransitionListener); workExecutor().execute(); } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java index 78cb54b3575b..44c1805e221f 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java @@ -27,6 +27,8 @@ import java.util.Map.Entry; import java.util.Optional; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; import java.util.function.Supplier; import javax.annotation.concurrent.NotThreadSafe; @@ -80,7 +82,8 @@ public final class Work implements RefreshableWork { private volatile TimedState currentState; private volatile boolean isFailed; private volatile String processingThreadName = ""; - private volatile @Nullable Runnable onFailureListener = null; + private final AtomicReference<@Nullable AtomicBoolean> onFailureListener = + new AtomicReference<>(null); private final boolean drainMode; private Work( @@ -243,16 +246,18 @@ public void setProcessingThreadName(String processingThreadName) { @Override public void setFailed() { this.isFailed = true; - Runnable listener = onFailureListener; + AtomicBoolean listener = onFailureListener.get(); if (listener != null) { - listener.run(); + listener.set(true); } } - public void setOnFailureListener(@Nullable Runnable listener) { - this.onFailureListener = listener; + // Sets the passed in boolean to true if the work fails + // Supports registering only one boolean at a time. + public void setOnFailureListener(@Nullable AtomicBoolean listener) { + onFailureListener.set(listener); if (isFailed && listener != null) { - listener.run(); + listener.set(true); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java index 5e890ef3d635..dc1fd4791fcd 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java @@ -34,6 +34,7 @@ import org.apache.beam.runners.dataflow.worker.HotKeyLogger; import org.apache.beam.runners.dataflow.worker.ReaderCache; import org.apache.beam.runners.dataflow.worker.StreamingModeExecutionContext; +import org.apache.beam.runners.dataflow.worker.StreamingModeExecutionContext.KeyTransitionListener; import org.apache.beam.runners.dataflow.worker.WorkItemCancelledException; import org.apache.beam.runners.dataflow.worker.logging.DataflowWorkerLoggingMDC; import org.apache.beam.runners.dataflow.worker.streaming.BoundedQueueExecutorWorkHandle; @@ -46,7 +47,6 @@ import org.apache.beam.runners.dataflow.worker.streaming.Work; import org.apache.beam.runners.dataflow.worker.streaming.config.StreamingGlobalConfigHandle; import org.apache.beam.runners.dataflow.worker.streaming.harness.StreamingCounters; -import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputStateFetcher; import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputStateFetcherFactory; import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor; import org.apache.beam.runners.dataflow.worker.util.ExceptionUtils; @@ -184,14 +184,22 @@ private static Windmill.WorkItemCommitRequest buildWorkItemTruncationRequest( /** Sets the stage name and workId of the Thread executing the {@link Work} for logging. */ private static void setUpWorkLoggingContext(String workLatencyTrackingId, String computationId) { - DataflowWorkerLoggingMDC.setWorkId(workLatencyTrackingId); + setLoggingContextWorkId(workLatencyTrackingId); + setLoggingContextComputation(computationId); + } + + private static void setLoggingContextComputation(@Nullable String computationId) { DataflowWorkerLoggingMDC.setStageName(computationId); } + private static void setLoggingContextWorkId(@Nullable String workLatencyTrackingId) { + DataflowWorkerLoggingMDC.setWorkId(workLatencyTrackingId); + } + /** Resets logging context of the Thread executing the {@link Work} for logging. */ private void resetWorkLoggingContext() { - DataflowWorkerLoggingMDC.setWorkId(null); - DataflowWorkerLoggingMDC.setStageName(null); + setLoggingContextWorkId(null); + setLoggingContextComputation(null); } /** @@ -256,7 +264,7 @@ private void processWork( long processingStartTimeNanos = System.nanoTime(); StageInfo stageInfo = getStageInfo(computationState); - List worksToCleanup = null; + List workBatch = null; try { if (work.isFailed()) { throw new WorkItemCancelledException(workItem.getShardingKey()); @@ -264,18 +272,14 @@ private void processWork( // Execute the user code for the Work batch. ExecuteWorkResult executeWorkResult = executeWork(work, stageInfo, computationState, handle); - List workBatch = executeWorkResult.workBatch(); - worksToCleanup = workBatch; - List outputBuilders = - executeWorkResult.outputBuilders(); - Map> accumulatedCallbacks = - executeWorkResult.accumulatedCallbacks(); + workBatch = executeWorkResult.workBatch(); + List workItemCommits = executeWorkResult.workItemCommits(); - commitFinalizer.cacheCommitFinalizers(accumulatedCallbacks); + commitFinalizer.cacheCommitFinalizers(executeWorkResult.accumulatedCallbacks()); - commitWorkBatch(computationState, workBatch, outputBuilders); + commitWorkBatch(computationState, workBatch, workItemCommits); - recordProcessingStats(workBatch, outputBuilders, executeWorkResult.stateBytesRead()); + recordProcessingStats(workBatch, workItemCommits, executeWorkResult.stateBytesRead()); LOG.debug("Processing done for work batch size: {}", workBatch.size()); } catch (Throwable t) { // OutOfMemoryError that are caught will be rethrown and trigger jvm termination. @@ -294,11 +298,19 @@ private void processWork( throw ExceptionUtils.safeWrapThrowableAsException(t2); } } finally { - recordProcessingTime(stageInfo, worksToCleanup, work, processingStartTimeNanos); + // Update total processing time counters. Updating in finally clause ensures that + // work items causing exceptions are also accounted in time spent. + recordProcessingTime(stageInfo, workBatch, work, processingStartTimeNanos); resetWorkLoggingContext(); sampler.resetForWorkId(work.getLatencyTrackingId()); - work.setProcessingThreadName(""); + if (workBatch != null) { + for (Work w : workBatch) { + w.setProcessingThreadName(""); + } + } else { + work.setProcessingThreadName(""); + } } } @@ -331,16 +343,18 @@ private Windmill.WorkItemCommitRequest validateCommitRequestSize( private void recordProcessingStats( List workBatch, - List outputBuilders, + List workItemCommits, long totalStateBytesRead) { long totalStateBytesWritten = 0; long totalShuffleBytesRead = 0; + Preconditions.checkState(workBatch.size() == workItemCommits.size()); for (int i = 0; i < workBatch.size(); i++) { Windmill.WorkItem workItem = workBatch.get(i).getWorkItem(); - Windmill.WorkItemCommitRequest.Builder outputBuilder = outputBuilders.get(i); + Windmill.WorkItemCommitRequest commit = workItemCommits.get(i); // Compute shuffle and state byte statistics these will be flushed asynchronously. long stateBytesWritten = - outputBuilder + commit + .toBuilder() .clearOutputMessages() .clearPerWorkItemLatencyAttributions() .build() @@ -369,36 +383,43 @@ private ExecuteWorkResult executeWork( try { WindmillStateReader stateReader = work.createWindmillStateReader(); - SideInputStateFetcher localSideInputStateFetcher = - sideInputStateFetcherFactory.createSideInputStateFetcher(work::fetchSideInput); - StreamingModeExecutionContext.KeySwitchListener keySwitchListener = - createKeySwitchListener(computationState); + KeyTransitionListener keyTransitionListener = createKeyTransitionListener(); // Blocks while executing work. computationWorkExecutor.executeWork( - work, stateReader, localSideInputStateFetcher, workExecutor, handle, keySwitchListener); - - StreamingModeExecutionContext context = computationWorkExecutor.context(); - if (context.workIsFailed()) { - throw new WorkItemCancelledException(work.getWorkItem().getShardingKey()); + work, + stateReader, + sideInputStateFetcherFactory, + workExecutor, + handle, + keyTransitionListener); + + List workBatch; + List workItemCommits; + Map> accumulatedCallbacks; + long stateBytesRead; + { + StreamingModeExecutionContext context = computationWorkExecutor.context(); + if (context.workIsFailed()) { + throw new WorkItemCancelledException(work.getWorkItem().getShardingKey()); + } + + // Retrieve executed works, work item commits, and accumulated callbacks from execution + // context + workBatch = context.getExecutedWorks(); + workItemCommits = context.getWorkItemCommits(); + accumulatedCallbacks = context.getAccumulatedCallbacks(); + stateBytesRead = context.getStateBytesRead(); + + context.clear(); // Don't use context after this. } - - // Retrieve executed works, output builders, and accumulated callbacks from execution context - List workBatch = context.getExecutedWorks(); - List outputBuilders = context.getOutputBuilders(); - Map> accumulatedCallbacks = context.getAccumulatedCallbacks(); - - context.clear(); // Release the execution state for another thread to use. computationState.releaseComputationWorkExecutor(computationWorkExecutor); computationWorkExecutor = null; return ExecuteWorkResult.create( - workBatch, - outputBuilders, - accumulatedCallbacks, - context.getStateBytesRead() + localSideInputStateFetcher.getBytesRead()); + workBatch, workItemCommits, accumulatedCallbacks, stateBytesRead); } catch (Throwable t) { if (computationWorkExecutor != null) { // If processing failed due to a thrown exception, close the executionState. Do not @@ -433,20 +454,18 @@ private StageInfo getStageInfo(ComputationState computationState) { private void commitWorkBatch( ComputationState computationState, List workBatch, - List outputBuilders) { + List workItemCommits) { Preconditions.checkState( workBatch.size() == 1, "Expected single-key work batch, got: " + workBatch.size()); - commitSingleKeyWork(computationState, workBatch.get(0), outputBuilders.get(0)); + commitSingleKeyWork(computationState, workBatch.get(0), workItemCommits.get(0)); } private void commitSingleKeyWork( - ComputationState computationState, - Work work, - Windmill.WorkItemCommitRequest.Builder commitRequestBuilder) { + ComputationState computationState, Work work, Windmill.WorkItemCommitRequest commitRequest) { // Validate the commit request, possibly requesting truncation if the commitSize is too large. Windmill.WorkItemCommitRequest validatedCommitRequest = validateCommitRequestSize( - commitRequestBuilder.build(), computationState.getComputationId(), work.getWorkItem()); + commitRequest, computationState.getComputationId(), work.getWorkItem()); work.setState(Work.State.COMMIT_QUEUED); validatedCommitRequest = validatedCommitRequest @@ -461,8 +480,6 @@ private void recordProcessingTime( @Nullable List worksToCleanup, Work work, long processingStartTimeNanos) { - // Update total processing time counters. Updating in finally clause ensures that - // work items causing exceptions are also accounted in time spent. long processingTimeMsecs = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - processingStartTimeNanos); stageInfo.totalProcessingMsecs().addValue(processingTimeMsecs); @@ -483,12 +500,10 @@ private static boolean anyWorkHasTimers(@Nullable List works, Work primary return primaryWork.getWorkItem().hasTimers(); } - private StreamingModeExecutionContext.KeySwitchListener createKeySwitchListener( - ComputationState computationState) { + private KeyTransitionListener createKeyTransitionListener() { return (oldWork, newWork) -> { - resetWorkLoggingContext(); - setUpWorkLoggingContext(newWork.getLatencyTrackingId(), computationState.getComputationId()); - newWork.setProcessingThreadName(Thread.currentThread().getName()); + setLoggingContextWorkId(newWork.getLatencyTrackingId()); + newWork.setProcessingThreadName(oldWork.getProcessingThreadName()); oldWork.setProcessingThreadName(""); }; } @@ -497,17 +512,18 @@ private StreamingModeExecutionContext.KeySwitchListener createKeySwitchListener( abstract static class ExecuteWorkResult { static ExecuteWorkResult create( List workBatch, - List outputBuilders, + List workItemCommits, Map> accumulatedCallbacks, long stateBytesRead) { return new AutoValue_StreamingWorkScheduler_ExecuteWorkResult( - workBatch, outputBuilders, accumulatedCallbacks, stateBytesRead); + workBatch, workItemCommits, accumulatedCallbacks, stateBytesRead); } abstract List workBatch(); - abstract List outputBuilders(); + abstract List workItemCommits(); + // Map> abstract Map> accumulatedCallbacks(); abstract long stateBytesRead(); diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java index 9d4ef999707c..6d84e9b4b0bf 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java @@ -48,6 +48,7 @@ import org.apache.beam.runners.core.metrics.ExecutionStateSampler; import org.apache.beam.runners.core.metrics.ExecutionStateTracker; import org.apache.beam.runners.core.metrics.ExecutionStateTracker.ExecutionState; +import org.apache.beam.runners.dataflow.options.DataflowStreamingPipelineOptions; import org.apache.beam.runners.dataflow.options.DataflowWorkerHarnessOptions; import org.apache.beam.runners.dataflow.worker.DataflowExecutionContext.DataflowExecutionStateTracker; import org.apache.beam.runners.dataflow.worker.MetricsToCounterUpdateConverter.Kind; @@ -61,7 +62,7 @@ import org.apache.beam.runners.dataflow.worker.streaming.Work; import org.apache.beam.runners.dataflow.worker.streaming.config.FakeGlobalConfigHandle; import org.apache.beam.runners.dataflow.worker.streaming.config.StreamingGlobalConfig; -import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputStateFetcher; +import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputStateFetcherFactory; import org.apache.beam.runners.dataflow.worker.util.common.worker.WorkExecutor; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.FakeGetDataClient; @@ -99,7 +100,6 @@ public class StreamingModeExecutionContextTest { @Rule public transient Timeout globalTimeout = Timeout.seconds(600); - @Mock private SideInputStateFetcher sideInputStateFetcher; @Mock private WindmillStateReader stateReader; @Mock private WorkExecutor workExecutor; @@ -203,15 +203,18 @@ private void start(StreamingModeExecutionContext context, Work work) { } private void start(StreamingModeExecutionContext context, Work work, Coder keyCoder) { + SideInputStateFetcherFactory sideInputStateFetcherFactory = + SideInputStateFetcherFactory.fromOptions( + options.as(DataflowStreamingPipelineOptions.class)); context.start( work, stateReader, - sideInputStateFetcher, + sideInputStateFetcherFactory, workExecutor, /* workQueueExecutor= */ null, /* budgetHandle= */ null, keyCoder, - /* keySwitchListener= */ (k, c) -> {}); + /* keyTransitionListener= */ (k, c) -> {}); } @Test diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java index 4175b47bfe4f..bd4e40d6570a 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java @@ -80,9 +80,11 @@ import org.apache.beam.runners.dataflow.DataflowRunner; import org.apache.beam.runners.dataflow.options.DataflowPipelineDebugOptions; import org.apache.beam.runners.dataflow.options.DataflowPipelineOptions; +import org.apache.beam.runners.dataflow.options.DataflowStreamingPipelineOptions; import org.apache.beam.runners.dataflow.util.CloudObject; import org.apache.beam.runners.dataflow.util.PropertyNames; import org.apache.beam.runners.dataflow.worker.DataflowExecutionContext.DataflowExecutionStateTracker; +import org.apache.beam.runners.dataflow.worker.StreamingModeExecutionContext.KeyTransitionListener; import org.apache.beam.runners.dataflow.worker.StreamingModeExecutionContext.StreamingModeExecutionStateRegistry; import org.apache.beam.runners.dataflow.worker.WorkerCustomSources.SplittableOnlyBoundedSource; import org.apache.beam.runners.dataflow.worker.counters.CounterSet; @@ -93,7 +95,7 @@ import org.apache.beam.runners.dataflow.worker.streaming.config.FixedGlobalConfigHandle; import org.apache.beam.runners.dataflow.worker.streaming.config.StreamingGlobalConfig; import org.apache.beam.runners.dataflow.worker.streaming.config.StreamingGlobalConfigHandle; -import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputStateFetcher; +import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputStateFetcherFactory; import org.apache.beam.runners.dataflow.worker.testing.TestCountingSource; import org.apache.beam.runners.dataflow.worker.util.common.worker.NativeReader; import org.apache.beam.runners.dataflow.worker.util.common.worker.NativeReader.NativeReaderIterator; @@ -210,15 +212,18 @@ COMPUTATION_ID, new FakeGetDataClient(), ignored -> {}, mock(HeartbeatSender.cla } private void startContext(StreamingModeExecutionContext context, Work work) { + SideInputStateFetcherFactory sideInputStateFetcherFactory = + SideInputStateFetcherFactory.fromOptions( + options.as(DataflowStreamingPipelineOptions.class)); context.start( work, mock(WindmillStateReader.class), - mock(SideInputStateFetcher.class), + sideInputStateFetcherFactory, mock(WorkExecutor.class), /* workQueueExecutor= */ null, /* budgetHandle= */ null, /* keyCoder= */ null, - /* keySwitchListener= */ mock(StreamingModeExecutionContext.KeySwitchListener.class)); + /* keyTransitionListener= */ mock(KeyTransitionListener.class)); } private static class SourceProducingSubSourcesInSplit extends MockSource { diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/WorkTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/WorkTest.java new file mode 100644 index 000000000000..80ca91da462f --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/WorkTest.java @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.streaming; + +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.mock; + +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill; +import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender; +import org.apache.beam.vendor.grpc.v1p69p0.com.google.protobuf.ByteString; +import org.joda.time.Instant; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class WorkTest { + + private static Work createTestWork() { + Windmill.WorkItem workItem = + Windmill.WorkItem.newBuilder() + .setKey(ByteString.copyFromUtf8("key")) + .setWorkToken(1L) + .setShardingKey(2L) + .build(); + return Work.create( + workItem, + workItem.getSerializedSize(), + Watermarks.builder().setInputDataWatermark(Instant.now()).build(), + Work.createProcessingContext( + "comp", + mock( + org.apache.beam.runners.dataflow.worker.windmill.client.getdata.GetDataClient + .class), + commit -> {}, + mock(HeartbeatSender.class)), + false, + Instant::now); + } + + @Test + public void testSetFailedBeforeListener() { + Work work = createTestWork(); + assertFalse(work.isFailed()); + + work.setFailed(); + assertTrue(work.isFailed()); + + AtomicBoolean listener = new AtomicBoolean(false); + work.setOnFailureListener(listener); + assertTrue(listener.get()); + } + + @Test + public void testSetFailedAfterListener() { + Work work = createTestWork(); + AtomicBoolean listener = new AtomicBoolean(false); + work.setOnFailureListener(listener); + assertFalse(listener.get()); + assertFalse(work.isFailed()); + + work.setFailed(); + assertTrue(work.isFailed()); + assertTrue(listener.get()); + } + + @Test + public void testConcurrentSetFailedAndSetOnFailureListener() throws Exception { + int numTrials = 5000; + ExecutorService executor = Executors.newFixedThreadPool(2); + try { + for (int i = 0; i < numTrials; i++) { + Work work = createTestWork(); + AtomicBoolean listener = new AtomicBoolean(false); + CountDownLatch latch = new CountDownLatch(1); + + Future f1 = + executor.submit( + () -> { + try { + latch.await(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + work.setFailed(); + }); + + Future f2 = + executor.submit( + () -> { + try { + latch.await(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + work.setOnFailureListener(listener); + }); + + latch.countDown(); + f1.get(5, TimeUnit.SECONDS); + f2.get(5, TimeUnit.SECONDS); + + assertTrue("Trial " + i + " failed: work should be failed", work.isFailed()); + assertTrue("Trial " + i + " failed: listener should be set to true", listener.get()); + } + } finally { + executor.shutdownNow(); + } + } +} From f3cc6284fb97df1bc3762647c5c2bd7452509efe Mon Sep 17 00:00:00 2001 From: Arun Pandian Date: Mon, 8 Jun 2026 05:37:04 +0000 Subject: [PATCH 09/21] improve WindowingWindmillReader --- .../worker/WindmillReaderIteratorBase.java | 2 +- .../worker/WindowingWindmillReader.java | 83 +++--- .../worker/WindowingWindmillReaderTest.java | 275 ++++++++++++++++++ 3 files changed, 315 insertions(+), 45 deletions(-) create mode 100644 runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindowingWindmillReaderTest.java diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillReaderIteratorBase.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillReaderIteratorBase.java index b142cc38d365..20d0c40ae4a3 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillReaderIteratorBase.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillReaderIteratorBase.java @@ -73,7 +73,7 @@ public boolean advance() throws IOException { continue; } - // All work items are exhausted. Iterator returns false. + // All work items are exhausted. current = null; return false; } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindowingWindmillReader.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindowingWindmillReader.java index 916920518f0b..fc11ff8dca76 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindowingWindmillReader.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindowingWindmillReader.java @@ -30,7 +30,6 @@ import org.apache.beam.runners.dataflow.util.CloudObject; import org.apache.beam.runners.dataflow.worker.util.ValueInEmptyWindows; import org.apache.beam.runners.dataflow.worker.util.common.worker.NativeReader; -import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItem; import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.options.PipelineOptions; @@ -49,7 +48,6 @@ @Internal class WindowingWindmillReader extends NativeReader>> { - private final Coder keyCoder; private final Coder valueCoder; private final Coder windowCoder; private final Coder> windowsCoder; @@ -66,7 +64,6 @@ class WindowingWindmillReader extends NativeReader keyedWorkItemCoder = (WindmillKeyedWorkItem.FakeKeyedWorkItemCoder) inputCoder.getValueCoder(); - this.keyCoder = keyedWorkItemCoder.getKeyCoder(); this.valueCoder = keyedWorkItemCoder.getElementCoder(); this.context = context; this.skipUndecodableElements = skipUndecodableElements; @@ -129,27 +126,32 @@ public static WindowingWindmillReader create( return new WindowingWindmillReader<>(coder, context, skipUndecodableElements); } + private KeyedWorkItem createKeyedWorkItem() { + @SuppressWarnings("unchecked") + @Nullable K key = (K) context.getKey(); + return new WindmillKeyedWorkItem<>( + key, + context.getWorkItem(), + windowCoder, + windowsCoder, + valueCoder, + context.getWindmillTagEncoding(), + context.getDrainMode(), + skipUndecodableElements.isAccessible() + && Boolean.TRUE.equals(skipUndecodableElements.get())); + } + + private boolean isEmpty(KeyedWorkItem keyedWorkItem) { + return Iterables.isEmpty(keyedWorkItem.timersIterable()) + && Iterables.isEmpty(keyedWorkItem.elementsIterable()); + } + @Override public NativeReaderIterator>> iterator() throws IOException { - final K key = - keyCoder.decode( - checkStateNotNull(context.getSerializedKey()).newInput(), Coder.Context.OUTER); - final WorkItem workItem = context.getWorkItem(); - KeyedWorkItem keyedWorkItem = - new WindmillKeyedWorkItem<>( - key, - workItem, - windowCoder, - windowsCoder, - valueCoder, - context.getWindmillTagEncoding(), - context.getDrainMode(), - skipUndecodableElements.isAccessible() - && Boolean.TRUE.equals(skipUndecodableElements.get())); - final boolean isEmptyWorkItem = - (Iterables.isEmpty(keyedWorkItem.timersIterable()) - && Iterables.isEmpty(keyedWorkItem.elementsIterable())); - final WindowedValue> value = new ValueInEmptyWindows<>(keyedWorkItem); + final KeyedWorkItem firstKeyedWorkItem = createKeyedWorkItem(); + final boolean firstKeyIsEmpty = isEmpty(firstKeyedWorkItem); + final WindowedValue> firstValue = + new ValueInEmptyWindows<>(firstKeyedWorkItem); return new NativeReaderIterator>>() { private @Nullable WindowedValue> current = null; @@ -165,10 +167,10 @@ public boolean start() throws IOException { return false; } started = true; - if (isEmptyWorkItem) { + if (firstKeyIsEmpty) { return advance(); // Try to transition immediately if the first key is empty! } - current = value; + current = firstValue; return true; } @@ -179,27 +181,20 @@ public boolean advance() throws IOException { checkStateNotNull(context.getWorkItem()).getShardingKey()); } - context.finishKey(); - if (context.advance()) { - @SuppressWarnings("unchecked") - K newKey = (K) context.getKey(); - KeyedWorkItem newKeyedWorkItem = - new WindmillKeyedWorkItem<>( - newKey, - context.getWork().getWorkItem(), - windowCoder, - windowsCoder, - valueCoder, - context.getWindmillTagEncoding(), - context.getDrainMode(), - skipUndecodableElements.isAccessible() - && Boolean.TRUE.equals(skipUndecodableElements.get())); - current = new ValueInEmptyWindows<>(newKeyedWorkItem); - return true; + while (true) { + context.finishKey(); + if (context.advance()) { + KeyedWorkItem newKeyedWorkItem = createKeyedWorkItem(); + if (isEmpty(newKeyedWorkItem)) { + continue; + } + current = new ValueInEmptyWindows<>(newKeyedWorkItem); + return true; + } + + current = null; + return false; } - - current = null; - return false; } @Override diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindowingWindmillReaderTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindowingWindmillReaderTest.java new file mode 100644 index 000000000000..2e7c80330cf0 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindowingWindmillReaderTest.java @@ -0,0 +1,275 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.io.IOException; +import java.util.List; +import org.apache.beam.runners.core.KeyedWorkItem; +import org.apache.beam.runners.dataflow.worker.streaming.Watermarks; +import org.apache.beam.runners.dataflow.worker.streaming.Work; +import org.apache.beam.runners.dataflow.worker.util.common.worker.NativeReader; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill; +import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.FakeGetDataClient; +import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillTagEncodingV1; +import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.coders.ListCoder; +import org.apache.beam.sdk.coders.StringUtf8Coder; +import org.apache.beam.sdk.coders.VarLongCoder; +import org.apache.beam.sdk.options.ValueProvider; +import org.apache.beam.sdk.transforms.windowing.IntervalWindow; +import org.apache.beam.sdk.transforms.windowing.IntervalWindow.IntervalWindowCoder; +import org.apache.beam.sdk.transforms.windowing.PaneInfo; +import org.apache.beam.sdk.transforms.windowing.PaneInfo.PaneInfoCoder; +import org.apache.beam.sdk.util.ByteStringOutputStream; +import org.apache.beam.sdk.values.WindowedValue; +import org.apache.beam.sdk.values.WindowedValues.FullWindowedValueCoder; +import org.apache.beam.vendor.grpc.v1p69p0.com.google.protobuf.ByteString; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; +import org.joda.time.Instant; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class WindowingWindmillReaderTest { + private StreamingModeExecutionContext mockContext; + private WindowingWindmillReader reader; + + @SuppressWarnings("unchecked") + @Before + public void setUp() { + mockContext = mock(StreamingModeExecutionContext.class); + when(mockContext.workIsFailed()).thenReturn(false); + when(mockContext.getWindmillTagEncoding()).thenReturn(WindmillTagEncodingV1.instance()); + when(mockContext.getDrainMode()).thenReturn(false); + + Coder keyCoder = StringUtf8Coder.of(); + Coder valueCoder = VarLongCoder.of(); + KvCoder kvCoder = KvCoder.of(keyCoder, valueCoder); + WindmillKeyedWorkItem.FakeKeyedWorkItemCoder keyedWorkItemCoder = + (WindmillKeyedWorkItem.FakeKeyedWorkItemCoder) + WindmillKeyedWorkItem.FakeKeyedWorkItemCoder.of(kvCoder); + FullWindowedValueCoder> coder = + FullWindowedValueCoder.of(keyedWorkItemCoder, IntervalWindowCoder.of()); + + reader = + WindowingWindmillReader.create( + coder, mockContext, ValueProvider.StaticValueProvider.of(false)); + } + + private static Work createMockWork(Windmill.WorkItem workItem) { + return Work.create( + workItem, + workItem.getSerializedSize(), + Watermarks.builder().setInputDataWatermark(new Instant(1000)).build(), + Work.createProcessingContext( + "computationId", new FakeGetDataClient(), ignored -> {}, mock(HeartbeatSender.class)), + false, + Instant::now); + } + + private static ByteString encodeMetadata(List windows) throws IOException { + ByteStringOutputStream stream = new ByteStringOutputStream(); + PaneInfoCoder.INSTANCE.encode(PaneInfo.NO_FIRING, stream); + ListCoder.of(IntervalWindowCoder.of()).encode(windows, stream); + return stream.toByteString(); + } + + private static ByteString encodeValue(long value) throws IOException { + ByteStringOutputStream stream = new ByteStringOutputStream(); + VarLongCoder.of().encode(value, stream); + return stream.toByteString(); + } + + @Test + public void testSingleNonEmptyKey() throws IOException { + IntervalWindow window = new IntervalWindow(new Instant(0), new Instant(1000)); + Windmill.WorkItem workItem = + Windmill.WorkItem.newBuilder() + .setKey(ByteString.copyFromUtf8("key1")) + .setWorkToken(100L) + .addMessageBundles( + Windmill.InputMessageBundle.newBuilder() + .setSourceComputationId("foo") + .addMessages( + Windmill.Message.newBuilder() + .setTimestamp(1000) + .setData(encodeValue(42L)) + .setMetadata(encodeMetadata(ImmutableList.of(window))) + .build()) + .build()) + .build(); + Work work = createMockWork(workItem); + + when(mockContext.getKey()).thenReturn("key1"); + when(mockContext.getWorkItem()).thenReturn(workItem); + when(mockContext.getWork()).thenReturn(work); + when(mockContext.advance()).thenReturn(false); + + try (NativeReader.NativeReaderIterator>> iter = + reader.iterator()) { + assertTrue(iter.start()); + WindowedValue> current = iter.getCurrent(); + assertEquals("key1", current.getValue().key()); + assertFalse(Iterables.isEmpty(current.getValue().elementsIterable())); + WindowedValue elem = Iterables.getOnlyElement(current.getValue().elementsIterable()); + assertEquals(42L, elem.getValue().longValue()); + + assertFalse(iter.advance()); + verify(mockContext).finishKey(); + } + } + + @Test + public void testSingleEmptyKey() throws IOException { + Windmill.WorkItem workItem = + Windmill.WorkItem.newBuilder() + .setKey(ByteString.copyFromUtf8("key1")) + .setWorkToken(100L) + .build(); // No message bundles or timers + Work work = createMockWork(workItem); + + when(mockContext.getKey()).thenReturn("key1"); + when(mockContext.getWorkItem()).thenReturn(workItem); + when(mockContext.getWork()).thenReturn(work); + when(mockContext.advance()).thenReturn(false); + + try (NativeReader.NativeReaderIterator>> iter = + reader.iterator()) { + assertFalse( + iter.start()); // Should skip the empty key and return false because advance returns false + verify(mockContext).finishKey(); + } + } + + @Test + public void testMultipleKeys_withEmptyAndNonEmpty() throws IOException { + IntervalWindow window = new IntervalWindow(new Instant(0), new Instant(1000)); + // Key 1: Empty + Windmill.WorkItem workItem1 = + Windmill.WorkItem.newBuilder() + .setKey(ByteString.copyFromUtf8("key1")) + .setWorkToken(100L) + .build(); + Work work1 = createMockWork(workItem1); + + // Key 2: Non-empty + Windmill.WorkItem workItem2 = + Windmill.WorkItem.newBuilder() + .setKey(ByteString.copyFromUtf8("key2")) + .setWorkToken(200L) + .addMessageBundles( + Windmill.InputMessageBundle.newBuilder() + .setSourceComputationId("foo") + .addMessages( + Windmill.Message.newBuilder() + .setTimestamp(2000) + .setData(encodeValue(84L)) + .setMetadata(encodeMetadata(ImmutableList.of(window))) + .build()) + .build()) + .build(); + Work work2 = createMockWork(workItem2); + + // Key 3: Empty + Windmill.WorkItem workItem3 = + Windmill.WorkItem.newBuilder() + .setKey(ByteString.copyFromUtf8("key3")) + .setWorkToken(300L) + .build(); + Work work3 = createMockWork(workItem3); + + // Initial state + when(mockContext.getKey()).thenReturn("key1"); + when(mockContext.getWorkItem()).thenReturn(workItem1); + when(mockContext.getWork()).thenReturn(work1); + + // Mock transition behaviour of context.advance() + when(mockContext.advance()) + .thenAnswer( + new org.mockito.stubbing.Answer() { + private int count = 0; + + @Override + public Boolean answer(org.mockito.invocation.InvocationOnMock invocation) { + if (count == 0) { + count++; + when(mockContext.getKey()).thenReturn("key2"); + when(mockContext.getWorkItem()).thenReturn(workItem2); + when(mockContext.getWork()).thenReturn(work2); + return true; + } else if (count == 1) { + count++; + when(mockContext.getKey()).thenReturn("key3"); + when(mockContext.getWorkItem()).thenReturn(workItem3); + when(mockContext.getWork()).thenReturn(work3); + return true; + } + return false; + } + }); + + try (NativeReader.NativeReaderIterator>> iter = + reader.iterator()) { + // Key 1 is empty, so start() calls advance() which calls finishKey(1) and advance() to Key 2. + // Key 2 is non-empty, so start() returns true yielding Key 2. + assertTrue(iter.start()); + assertEquals("key2", iter.getCurrent().getValue().key()); + WindowedValue elem = + Iterables.getOnlyElement(iter.getCurrent().getValue().elementsIterable()); + assertEquals(84L, elem.getValue().longValue()); + + // Next advance() calls finishKey(2), calls advance() to Key 3. + // Key 3 is empty, so it loops, calls finishKey(3), calls advance() which returns false. + // So iter.advance() should return false. + assertFalse(iter.advance()); + + verify(mockContext, times(3)) + .finishKey(); // finishKey should have been called on key1, key2, key3 + } + } + + @Test + public void testWorkItemCancelled() throws IOException { + when(mockContext.workIsFailed()).thenReturn(true); + Windmill.WorkItem workItem = + Windmill.WorkItem.newBuilder().setKey(ByteString.EMPTY).setWorkToken(0L).build(); + when(mockContext.getWorkItem()).thenReturn(workItem); + + try (NativeReader.NativeReaderIterator>> iter = + reader.iterator()) { + iter.start(); + fail("Expected WorkItemCancelledException"); + } catch (WorkItemCancelledException e) { + // Expected + } + } +} From 58e0ef9393d45860b61df4c10257351c425bf15f Mon Sep 17 00:00:00 2001 From: Arun Pandian Date: Mon, 8 Jun 2026 05:49:24 +0000 Subject: [PATCH 10/21] spotless fix --- .../beam/runners/dataflow/worker/WindowingWindmillReader.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindowingWindmillReader.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindowingWindmillReader.java index fc11ff8dca76..2003ec001a55 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindowingWindmillReader.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindowingWindmillReader.java @@ -128,7 +128,8 @@ public static WindowingWindmillReader create( private KeyedWorkItem createKeyedWorkItem() { @SuppressWarnings("unchecked") - @Nullable K key = (K) context.getKey(); + @Nullable + K key = (K) context.getKey(); return new WindmillKeyedWorkItem<>( key, context.getWorkItem(), From 3dceab0470d5228bb5346e9d8ac92f528a714ff0 Mon Sep 17 00:00:00 2001 From: Arun Pandian Date: Mon, 8 Jun 2026 07:57:33 +0000 Subject: [PATCH 11/21] [Dataflow Streaming] Fix nullness supression in StreamingModeExecutionContext --- .../worker/StreamingModeExecutionContext.java | 262 +++++++++--------- 1 file changed, 136 insertions(+), 126 deletions(-) diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java index 25ce299adf7a..89ccb576051f 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java @@ -18,7 +18,6 @@ package org.apache.beam.runners.dataflow.worker; import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull; -import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState; import com.google.api.services.dataflow.model.CounterUpdate; @@ -62,6 +61,7 @@ import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalDataId; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalDataRequest; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.Timer; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItemCommitRequest; import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateCache; import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateCache.ForComputation; import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateInternals; @@ -105,10 +105,7 @@ * state pertaining to a processing its owning computation. Can be reused across processing * different WorkItems for the same computation. */ -@SuppressWarnings({ - "deprecation", - "nullness" // TODO(https://github.com/apache/beam/issues/20497) -}) +@SuppressWarnings({"deprecation"}) // TODO(m-trieu) fix nullability issues in StreamingModeExecutionContext.java @NotThreadSafe @Internal @@ -143,13 +140,13 @@ public class StreamingModeExecutionContext extends DataflowExecutionContext SideInput fetchSideInput( return fetchSideInputFromWindmill( view, sideInputWindow, - checkNotNull(stateFamily), + checkStateNotNull(stateFamily), state, - checkNotNull(scopedReadStateSupplier), + checkStateNotNull(scopedReadStateSupplier), tagCache); } @@ -383,8 +386,8 @@ private SideInput fetchSideInputFromWindmill( Supplier scopedReadStateSupplier, Map> tagCache) { SideInput fetched = - sideInputStateFetcher.fetchSideInput( - view, sideInputWindow, stateFamily, state, scopedReadStateSupplier); + checkStateNotNull(sideInputStateFetcher) + .fetchSideInput(view, sideInputWindow, stateFamily, state, scopedReadStateSupplier); if (fetched.isReady()) { tagCache.put(sideInputWindow, fetched); @@ -406,7 +409,7 @@ private List getFiredTimers() { } public WindmillComputationKey getComputationKey() { - return computationKey; + return checkStateNotNull(computationKey); } public long getWorkToken() { @@ -414,7 +417,7 @@ public long getWorkToken() { } public Windmill.WorkItem getWorkItem() { - return checkNotNull( + return checkStateNotNull( work, "work is null. A call to StreamingModeExecutionContext.start(...) is required to set" + " work for execution.") @@ -422,7 +425,7 @@ public Windmill.WorkItem getWorkItem() { } public Windmill.WorkItemCommitRequest.Builder getOutputBuilder() { - return outputBuilder; + return checkStateNotNull(outputBuilder); } /** @@ -490,15 +493,16 @@ public Map> flushState() { throw new RuntimeException("Exception while running bundle finalizer", e); } })); - outputBuilder.addFinalizeIds(id); + getOutputBuilder().addFinalizeIds(id); } } - if (activeReader != null) { - Windmill.SourceState.Builder sourceStateBuilder = - outputBuilder.getSourceStateUpdatesBuilder(); - final UnboundedSource.CheckpointMark checkpointMark = activeReader.getCheckpointMark(); - final Instant watermark = activeReader.getWatermark(); + UnboundedReader reader = activeReader; + if (reader != null) { + Windmill.WorkItemCommitRequest.Builder builder = getOutputBuilder(); + Windmill.SourceState.Builder sourceStateBuilder = builder.getSourceStateUpdatesBuilder(); + final UnboundedSource.CheckpointMark checkpointMark = reader.getCheckpointMark(); + final Instant watermark = reader.getWatermark(); long id = ThreadLocalRandom.current().nextLong(); sourceStateBuilder.addFinalizeIds(id); callbacks.put( @@ -515,7 +519,7 @@ public Map> flushState() { @SuppressWarnings("unchecked") Coder checkpointCoder = - ((UnboundedSource) activeReader.getCurrentSource()) + ((UnboundedSource) reader.getCurrentSource()) .getCheckpointMarkCoder(); if (checkpointCoder != null) { ByteStringOutputStream stream = new ByteStringOutputStream(); @@ -525,7 +529,7 @@ public Map> flushState() { throw new RuntimeException("Exception while encoding checkpoint", e); } sourceStateBuilder.setState(stream.toByteString()); - if (activeReader.getCurrentSource().offsetBasedDeduplicationSupported()) { + if (reader.getCurrentSource().offsetBasedDeduplicationSupported()) { byte[] offsetLimit = checkpointMark.getOffsetLimit(); if (offsetLimit.length == 0) { throw new RuntimeException("Checkpoint offset limit must be non-empty."); @@ -533,31 +537,30 @@ public Map> flushState() { sourceStateBuilder.setOffsetLimit(ByteString.copyFrom(offsetLimit)); } } - outputBuilder.setSourceWatermark(WindmillTimeUtils.harnessToWindmillTimestamp(watermark)); + builder.setSourceWatermark(WindmillTimeUtils.harnessToWindmillTimestamp(watermark)); - backlogBytes = activeReader.getSplitBacklogBytes(); + backlogBytes = reader.getSplitBacklogBytes(); + ByteString serializedKey = checkStateNotNull(getSerializedKey()); if (backlogBytes == UnboundedReader.BACKLOG_UNKNOWN - && WorkerCustomSources.isFirstUnboundedSourceSplit(getSerializedKey())) { + && WorkerCustomSources.isFirstUnboundedSourceSplit(serializedKey)) { // Only call getTotalBacklogBytes() on the first split. - backlogBytes = activeReader.getTotalBacklogBytes(); + backlogBytes = reader.getTotalBacklogBytes(); } - outputBuilder.setSourceBacklogBytes(backlogBytes); + builder.setSourceBacklogBytes(backlogBytes); readerCache.cacheReader( - getComputationKey(), - getWorkItem().getCacheToken(), - getWorkItem().getWorkToken(), - activeReader); + getComputationKey(), getWorkItem().getCacheToken(), getWorkItem().getWorkToken(), reader); activeReader = null; } else if (backlogBytes != UnboundedReader.BACKLOG_UNKNOWN && backlogBytes != 1L) { // If activeReader is null, we might still have backlogBytes from an SDF. We ignore a reported // backlogBytes of 1 since older versions of the Java SDK use this value as a default when // RestrictionTracker.getProgress() or GetSize() are not defined. - outputBuilder.setSourceBacklogBytes(backlogBytes); + getOutputBuilder().setSourceBacklogBytes(backlogBytes); } return callbacks; } + @Nullable String getStateFamily(NameContext nameContext) { return nameContext.userName() == null ? null : stateNameMap.get(nameContext.userName()); } @@ -599,7 +602,7 @@ public static class StreamingModeExecutionState extends DataflowExecutionState { public StreamingModeExecutionState( NameContext nameContext, String stateName, - MetricsContainer metricsContainer, + @Nullable MetricsContainer metricsContainer, ProfileScope profileScope) { // TODO: Take in the requesting step name and side input index for streaming. super(nameContext, stateName, null, null, metricsContainer, profileScope); @@ -642,14 +645,16 @@ public static class StreamingModeExecutionStateRegistry extends DataflowExecutio protected DataflowExecutionState createState( NameContext nameContext, String stateName, - String requestingStepName, - Integer inputIndex, - MetricsContainer container, + @Nullable String requestingStepName, + @Nullable Integer inputIndex, + @Nullable MetricsContainer container, ProfileScope profileScope) { return new StreamingModeExecutionState(nameContext, stateName, container, profileScope); } } + private static final Closeable NO_OP_CLOSEABLE = () -> {}; + private static class ScopedReadStateSupplier implements Supplier { private final ExecutionState readState; @@ -662,9 +667,9 @@ private ScopedReadStateSupplier( } @Override - public @Nullable Closeable get() { + public Closeable get() { if (stateTracker == null) { - return null; + return NO_OP_CLOSEABLE; } return stateTracker.enterState(readState); } @@ -725,7 +730,7 @@ public TimerInternals timerInternals() { } @Override - public TimerData getNextFiredTimer(Coder windowCoder) { + public @Nullable TimerData getNextFiredTimer(Coder windowCoder) { return wrapped.getNextFiredUserTimer(windowCoder); } @@ -777,7 +782,7 @@ public static StreamingModeSideInputReader of( } @Override - public T get(PCollectionView view, BoundedWindow window) { + public @Nullable T get(PCollectionView view, BoundedWindow window) { if (!contains(view)) { throw new RuntimeException("get() called with unknown view"); } @@ -810,31 +815,32 @@ public boolean isEmpty() { class StepContext extends DataflowExecutionContext.DataflowStepContext implements StreamingModeStepContext { - private final String stateFamily; + private final @Nullable String stateFamily; private final Supplier scopedReadStateSupplier; - private WindmillStateInternals stateInternals; - private WindmillTimerInternals systemTimerInternals; - private WindmillTimerInternals userTimerInternals; + private @Nullable WindmillStateInternals stateInternals; + private @Nullable WindmillTimerInternals systemTimerInternals; + private @Nullable WindmillTimerInternals userTimerInternals; // Lazily initialized - private Iterator cachedFiredSystemTimers = null; + private @Nullable Iterator cachedFiredSystemTimers = null; // Lazily initialized - private PeekingIterator cachedFiredUserTimers = null; + private @Nullable PeekingIterator cachedFiredUserTimers = null; // An ordered list of any timers that were set or modified by user processing earlier in this // bundle. // We use a NavigableSet instead of a priority queue to prevent duplicate elements from ending // up in the queue. - private NavigableSet modifiedUserEventTimersOrdered = null; - private NavigableSet modifiedUserProcessingTimersOrdered = null; - private NavigableSet modifiedUserSynchronizedProcessingTimersOrdered = null; + private final NavigableSet modifiedUserEventTimersOrdered = Sets.newTreeSet(); + private final NavigableSet modifiedUserProcessingTimersOrdered = Sets.newTreeSet(); + private final NavigableSet modifiedUserSynchronizedProcessingTimersOrdered = + Sets.newTreeSet(); // A list of timer keys that were modified by user processing earlier in this bundle. This // serves a tombstone, so that we know not to fire any bundle timers that were modified. - private Table modifiedUserTimerKeys = null; + private final Table modifiedUserTimerKeys = + HashBasedTable.create(); private final WindmillBundleFinalizer bundleFinalizer = new WindmillBundleFinalizer(); public StepContext(DataflowOperationContext operationContext) { super(operationContext.nameContext()); this.stateFamily = getStateFamily(operationContext.nameContext()); - this.scopedReadStateSupplier = new ScopedReadStateSupplier(operationContext, getExecutionStateTracker()); } @@ -845,46 +851,50 @@ public void start( Instant processingTime, WindmillStateCache.ForKey cacheForKey, Watermarks watermarks) { - this.stateInternals = - new WindmillStateInternals<>( - key, - stateFamily, - stateReader, - getWorkItem().getIsNewKey(), - cacheForKey.forFamily(stateFamily), - windmillTagEncoding, - scopedReadStateSupplier); - - this.systemTimerInternals = - new WindmillTimerInternals( - stateFamily, - WindmillTimerType.SYSTEM_TIMER, - processingTime, - watermarks, - windmillTagEncoding, - td -> {}); - - this.userTimerInternals = - new WindmillTimerInternals( - stateFamily, - WindmillTimerType.USER_TIMER, - processingTime, - watermarks, - windmillTagEncoding, - this::onUserTimerModified); - + if (stateFamily != null) { + this.stateInternals = + new WindmillStateInternals<>( + key, + stateFamily, + stateReader, + getWorkItem().getIsNewKey(), + cacheForKey.forFamily(stateFamily), + windmillTagEncoding, + scopedReadStateSupplier); + + this.systemTimerInternals = + new WindmillTimerInternals( + stateFamily, + WindmillTimerType.SYSTEM_TIMER, + processingTime, + watermarks, + windmillTagEncoding, + td -> {}); + + this.userTimerInternals = + new WindmillTimerInternals( + stateFamily, + WindmillTimerType.USER_TIMER, + processingTime, + watermarks, + windmillTagEncoding, + this::onUserTimerModified); + } this.cachedFiredSystemTimers = null; this.cachedFiredUserTimers = null; - modifiedUserEventTimersOrdered = Sets.newTreeSet(); - modifiedUserProcessingTimersOrdered = Sets.newTreeSet(); - modifiedUserSynchronizedProcessingTimersOrdered = Sets.newTreeSet(); - modifiedUserTimerKeys = HashBasedTable.create(); + this.modifiedUserEventTimersOrdered.clear(); + this.modifiedUserProcessingTimersOrdered.clear(); + this.modifiedUserSynchronizedProcessingTimersOrdered.clear(); + this.modifiedUserTimerKeys.clear(); } public void flushState() { - stateInternals.persist(outputBuilder); - systemTimerInternals.persistTo(outputBuilder); - userTimerInternals.persistTo(outputBuilder); + if (stateFamily != null) { + WorkItemCommitRequest.Builder builder = getOutputBuilder(); + checkStateNotNull(stateInternals).persist(builder); + checkStateNotNull(systemTimerInternals).persistTo(builder); + checkStateNotNull(userTimerInternals).persistTo(builder); + } } @Override @@ -893,9 +903,10 @@ public void setBacklogBytes(double backlogBytes) { } @Override - public TimerData getNextFiredTimer(Coder windowCoder) { - if (cachedFiredSystemTimers == null) { - cachedFiredSystemTimers = + public @Nullable TimerData getNextFiredTimer(Coder windowCoder) { + Iterator firedSystemTimers = cachedFiredSystemTimers; + if (firedSystemTimers == null) { + firedSystemTimers = FluentIterable.from(StreamingModeExecutionContext.this.getFiredTimers()) .filter(timer -> timer.getStateFamily().equals(stateFamily)) .transform( @@ -907,16 +918,17 @@ timer, windowCoder, getDrainMode())) windmillTimerData.getWindmillTimerType() == WindmillTimerType.SYSTEM_TIMER) .transform(WindmillTimerData::getTimerData) .iterator(); + cachedFiredSystemTimers = firedSystemTimers; } - if (!cachedFiredSystemTimers.hasNext()) { + if (!firedSystemTimers.hasNext()) { return null; } - TimerData nextTimer = cachedFiredSystemTimers.next(); + TimerData nextTimer = firedSystemTimers.next(); // system timers ( GC timer) must be explicitly deleted if only there is a hold. // if timestamp is not equals to outputTimestamp then there should be a hold if (!nextTimer.getTimestamp().equals(nextTimer.getOutputTimestamp())) { - systemTimerInternals.deleteTimer(nextTimer); + checkStateNotNull(systemTimerInternals).deleteTimer(nextTimer); } return nextTimer; } @@ -950,12 +962,14 @@ private boolean isTimerUnmodified(TimerData timerData) { return updatedTimer == null || updatedTimer.equals(timerData); } - public TimerData getNextFiredUserTimer(Coder windowCoder) { - if (cachedFiredUserTimers == null) { + public @Nullable TimerData getNextFiredUserTimer( + Coder windowCoder) { + PeekingIterator firedUserTimers = cachedFiredUserTimers; + if (firedUserTimers == null) { // This is the first call to getNextFiredUserTimer in this bundle. Extract any user timers // from the bundle // and cache the list for the rest of this bundle processing. - cachedFiredUserTimers = + firedUserTimers = Iterators.peekingIterator( FluentIterable.from(StreamingModeExecutionContext.this.getFiredTimers()) .filter(timer -> timer.getStateFamily().equals(stateFamily)) @@ -969,17 +983,20 @@ timer, windowCoder, getDrainMode())) == WindmillTimerType.USER_TIMER) .transform(WindmillTimerData::getTimerData) .iterator()); + cachedFiredUserTimers = firedUserTimers; } - while (cachedFiredUserTimers.hasNext()) { - TimerData nextInBundle = cachedFiredUserTimers.peek(); + WindmillTimerInternals nonNullUserTimerInternals = checkStateNotNull(this.userTimerInternals); + + while (firedUserTimers.hasNext()) { + TimerData nextInBundle = firedUserTimers.peek(); NavigableSet modifiedUserTimersOrdered = getModifiedUserTimersOrdered(nextInBundle.getDomain()); // If there is a modified timer that is earlier than the next timer in the bundle, try and // fire that first. while (!modifiedUserTimersOrdered.isEmpty() && modifiedUserTimersOrdered.first().compareTo(nextInBundle) <= 0) { - TimerData earlierTimer = modifiedUserTimersOrdered.pollFirst(); + TimerData earlierTimer = checkStateNotNull(modifiedUserTimersOrdered.pollFirst()); if (isTimerUnmodified(earlierTimer)) { // We must delete the timer. This prevents it from being committed to the backing store. // It also handles the @@ -987,15 +1004,15 @@ timer, windowCoder, getDrainMode())) // without deleting the // timer, the runner will still have that future timer stored, and would fire it // spuriously. - userTimerInternals.deleteTimer(earlierTimer); + nonNullUserTimerInternals.deleteTimer(earlierTimer); return earlierTimer; } } // There is no earlier timer to fire, so return the next timer in the bundle. - nextInBundle = cachedFiredUserTimers.next(); + nextInBundle = firedUserTimers.next(); if (isTimerUnmodified(nextInBundle)) { // User timers must be explicitly deleted when delivered, to release the implied hold. - userTimerInternals.deleteTimer(nextInBundle); + nonNullUserTimerInternals.deleteTimer(nextInBundle); return nextInBundle; } } @@ -1029,12 +1046,6 @@ public Iterable getSideInputNotifications() { return StreamingModeExecutionContext.this.getSideInputNotifications(); } - private void ensureStateful(String errorPrefix) { - if (stateFamily == null) { - throw new IllegalStateException(errorPrefix + " for stateless step: " + getNameContext()); - } - } - @Override public void writePCollectionViewData( TupleTag tag, @@ -1043,7 +1054,8 @@ public void writePCollectionViewData( W window, Coder windowCoder) throws IOException { - if (getSerializedKey().size() != 0) { + ByteString serializedKey = checkStateNotNull(getSerializedKey()); + if (serializedKey.size() != 0) { throw new IllegalStateException("writePCollectionViewData must follow a Combine.globally"); } @@ -1053,7 +1065,7 @@ public void writePCollectionViewData( ByteStringOutputStream windowStream = new ByteStringOutputStream(); windowCoder.encode(window, windowStream, Coder.Context.OUTER); - ensureStateful("Tried to write view data"); + String stateFamily = checkStateNotNull(this.stateFamily, "Tried to write view data"); Windmill.GlobalData.Builder builder = Windmill.GlobalData.newBuilder() @@ -1065,7 +1077,7 @@ public void writePCollectionViewData( .setData(dataStream.toByteString()) .setStateFamily(stateFamily); - outputBuilder.addGlobalDataUpdates(builder.build()); + getOutputBuilder().addGlobalDataUpdates(builder.build()); } /** Fetch the given side input asynchronously and return true if it is present. */ @@ -1080,11 +1092,12 @@ public boolean issueSideInputFetch( /** Note that there is data on the current key that is blocked on the given side input. */ @Override public void addBlockingSideInput(Windmill.GlobalDataRequest sideInput) { - ensureStateful("Tried to set global data request"); + String stateFamily = checkStateNotNull(this.stateFamily, "Tried to set global data request"); sideInput = Windmill.GlobalDataRequest.newBuilder(sideInput).setStateFamily(stateFamily).build(); - outputBuilder.addGlobalDataRequests(sideInput); - outputBuilder.addGlobalDataIdRequests(sideInput.getDataId()); + WorkItemCommitRequest.Builder builder = getOutputBuilder(); + builder.addGlobalDataRequests(sideInput); + builder.addGlobalDataIdRequests(sideInput.getDataId()); } /** Note that there is data on the current key that is blocked on the given side inputs. */ @@ -1097,14 +1110,12 @@ public void addBlockingSideInputs(Iterable sideInput @Override public StateInternals stateInternals() { - ensureStateful("Tried to access state"); - return checkNotNull(stateInternals); + return checkStateNotNull(stateInternals, "Tried to access state"); } @Override public TimerInternals timerInternals() { - ensureStateful("Tried to access timers"); - return checkNotNull(systemTimerInternals); + return checkStateNotNull(systemTimerInternals, "Tried to access timers"); } @Override @@ -1113,8 +1124,7 @@ public BundleFinalizer bundleFinalizer() { } public TimerInternals userTimerInternals() { - ensureStateful("Tried to access user timers"); - return checkNotNull(userTimerInternals); + return checkStateNotNull(userTimerInternals, "Tried to access user timers"); } public ImmutableList> flushBundleFinalizerCallbacks() { From e19943808e3d024cbcb95faa8fc9ee15c88c579b Mon Sep 17 00:00:00 2001 From: Arun Pandian Date: Mon, 8 Jun 2026 08:19:41 +0000 Subject: [PATCH 12/21] make windmillTagEncoding final --- .../worker/StreamingModeExecutionContext.java | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java index 89ccb576051f..9008bf23f3af 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java @@ -127,7 +127,7 @@ public class StreamingModeExecutionContext extends DataflowExecutionContext, Map>> sideInputCache; - private WindmillTagEncoding windmillTagEncoding; + private final WindmillTagEncoding windmillTagEncoding; /** * The current user-facing key for this execution context. * @@ -187,13 +187,10 @@ public StreamingModeExecutionContext( this.throwExceptionOnLargeOutput = throwExceptionOnLargeOutput; StreamingGlobalConfig config = globalConfigHandle.getConfig(); this.operationalLimits = config.operationalLimits(); - this.windmillTagEncoding = getWindmillTagEncoding(config); - } - - private static WindmillTagEncoding getWindmillTagEncoding(StreamingGlobalConfig config) { - return config.enableStateTagEncodingV2() - ? WindmillTagEncodingV2.instance() - : WindmillTagEncodingV1.instance(); + this.windmillTagEncoding = + config.enableStateTagEncodingV2() + ? WindmillTagEncodingV2.instance() + : WindmillTagEncodingV1.instance(); } @VisibleForTesting @@ -262,7 +259,6 @@ public void start( StreamingGlobalConfig config = globalConfigHandle.getConfig(); // Snapshot the limits for entire bundle processing. this.operationalLimits = config.operationalLimits(); - this.windmillTagEncoding = getWindmillTagEncoding(config); this.outputBuilder = outputBuilder; this.sideInputCache.clear(); this.backlogBytes = UnboundedReader.BACKLOG_UNKNOWN; From 700dfbc8af35b89b70c921bbbd5a1676f6a2e513 Mon Sep 17 00:00:00 2001 From: Arun Pandian Date: Mon, 8 Jun 2026 09:00:25 +0000 Subject: [PATCH 13/21] address comments --- .../worker/StreamingModeExecutionContext.java | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java index 9008bf23f3af..00fdf67b8d02 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java @@ -94,6 +94,7 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.PeekingIterator; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Sets; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Table; +import org.checkerframework.checker.nullness.qual.MonotonicNonNull; import org.checkerframework.checker.nullness.qual.Nullable; import org.joda.time.Duration; import org.joda.time.Instant; @@ -106,7 +107,6 @@ * different WorkItems for the same computation. */ @SuppressWarnings({"deprecation"}) -// TODO(m-trieu) fix nullability issues in StreamingModeExecutionContext.java @NotThreadSafe @Internal public class StreamingModeExecutionContext extends DataflowExecutionContext { @@ -813,9 +813,9 @@ class StepContext extends DataflowExecutionContext.DataflowStepContext private final @Nullable String stateFamily; private final Supplier scopedReadStateSupplier; - private @Nullable WindmillStateInternals stateInternals; - private @Nullable WindmillTimerInternals systemTimerInternals; - private @Nullable WindmillTimerInternals userTimerInternals; + private @MonotonicNonNull WindmillStateInternals stateInternals; + private @MonotonicNonNull WindmillTimerInternals systemTimerInternals; + private @MonotonicNonNull WindmillTimerInternals userTimerInternals; // Lazily initialized private @Nullable Iterator cachedFiredSystemTimers = null; // Lazily initialized @@ -900,6 +900,10 @@ public void setBacklogBytes(double backlogBytes) { @Override public @Nullable TimerData getNextFiredTimer(Coder windowCoder) { + if (stateFamily == null) { + // no timers on stateless stages + return null; + } Iterator firedSystemTimers = cachedFiredSystemTimers; if (firedSystemTimers == null) { firedSystemTimers = @@ -960,6 +964,11 @@ private boolean isTimerUnmodified(TimerData timerData) { public @Nullable TimerData getNextFiredUserTimer( Coder windowCoder) { + if (stateFamily == null) { + // no timers on stateless stages + return null; + } + PeekingIterator firedUserTimers = cachedFiredUserTimers; if (firedUserTimers == null) { // This is the first call to getNextFiredUserTimer in this bundle. Extract any user timers From bc5bee2db78b0672cb58f37af9cefdf290739ce0 Mon Sep 17 00:00:00 2001 From: Arun Pandian Date: Mon, 8 Jun 2026 09:09:15 +0000 Subject: [PATCH 14/21] Move SideInputStateFetcherFactory from start to constructor --- .../worker/StreamingModeExecutionContext.java | 9 ++++----- .../streaming/ComputationWorkExecutor.java | 3 --- .../ComputationWorkExecutorFactory.java | 9 +++++++-- .../work/processing/StreamingWorkScheduler.java | 17 ++++++----------- .../StreamingModeExecutionContextTest.java | 11 ++++------- .../worker/WorkerCustomSourcesTest.java | 11 ++++------- 6 files changed, 25 insertions(+), 35 deletions(-) diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java index af9f29c7b9ba..fce50fc6ac54 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java @@ -151,7 +151,7 @@ public class StreamingModeExecutionContext private @Nullable Work work; private WindmillComputationKey computationKey; - private SideInputStateFetcherFactory sideInputStateFetcherFactory; + private final SideInputStateFetcherFactory sideInputStateFetcherFactory; private SideInputStateFetcher sideInputStateFetcher; // OperationalLimits is updated in start() because a StreamingModeExecutionContext can // be used for processing many work items and these values can change during the context's @@ -214,7 +214,8 @@ public StreamingModeExecutionContext( HotKeyLogger hotKeyLogger, boolean hotKeyLoggingEnabled, String stepName, - String sourceBytesProcessCounterName) { + String sourceBytesProcessCounterName, + SideInputStateFetcherFactory sideInputStateFetcherFactory) { super( counterFactory, metricsContainerRegistry, @@ -233,6 +234,7 @@ public StreamingModeExecutionContext( this.hotKeyLoggingEnabled = hotKeyLoggingEnabled; this.stepName = checkNotNull(stepName); this.sourceBytesProcessCounterName = checkNotNull(sourceBytesProcessCounterName); + this.sideInputStateFetcherFactory = sideInputStateFetcherFactory; } @VisibleForTesting @@ -304,7 +306,6 @@ public void clear() { this.work = null; this.key = null; this.outputBuilder = null; - this.sideInputStateFetcherFactory = null; this.sideInputStateFetcher = null; this.backlogBytes = UnboundedReader.BACKLOG_UNKNOWN; clearSinkFullHint(); @@ -314,7 +315,6 @@ public void clear() { public void start( Work work, WindmillStateReader stateReader, - SideInputStateFetcherFactory sideInputStateFetcherFactory, WorkExecutor workExecutor, BoundedQueueExecutor workQueueExecutor, BoundedQueueExecutorWorkHandle budgetHandle, @@ -334,7 +334,6 @@ public void start( config.enableStateTagEncodingV2() ? WindmillTagEncodingV2.instance() : WindmillTagEncodingV1.instance(); - this.sideInputStateFetcherFactory = sideInputStateFetcherFactory; startForNewKey(work, stateReader); } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationWorkExecutor.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationWorkExecutor.java index 31420b212c31..56a1a06362d2 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationWorkExecutor.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationWorkExecutor.java @@ -24,7 +24,6 @@ import org.apache.beam.runners.dataflow.worker.DataflowWorkExecutor; import org.apache.beam.runners.dataflow.worker.StreamingModeExecutionContext; import org.apache.beam.runners.dataflow.worker.StreamingModeExecutionContext.KeyTransitionListener; -import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputStateFetcherFactory; import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor; import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateReader; import org.apache.beam.sdk.annotations.Internal; @@ -66,7 +65,6 @@ public static ComputationWorkExecutor.Builder builder() { public final void executeWork( Work work, WindmillStateReader stateReader, - SideInputStateFetcherFactory sideInputStateFetcherFactory, BoundedQueueExecutor workQueueExecutor, BoundedQueueExecutorWorkHandle budgetHandle, KeyTransitionListener keyTransitionListener) @@ -75,7 +73,6 @@ public final void executeWork( .start( work, stateReader, - sideInputStateFetcherFactory, workExecutor(), workQueueExecutor, budgetHandle, diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/ComputationWorkExecutorFactory.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/ComputationWorkExecutorFactory.java index fcc6d6bbb743..4a52d9fde771 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/ComputationWorkExecutorFactory.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/ComputationWorkExecutorFactory.java @@ -49,6 +49,7 @@ import org.apache.beam.runners.dataflow.worker.streaming.ComputationWorkExecutor; import org.apache.beam.runners.dataflow.worker.streaming.StageInfo; import org.apache.beam.runners.dataflow.worker.streaming.config.StreamingGlobalConfigHandle; +import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputStateFetcherFactory; import org.apache.beam.runners.dataflow.worker.util.common.worker.MapTaskExecutor; import org.apache.beam.runners.dataflow.worker.util.common.worker.OutputObjectAndByteCounter; import org.apache.beam.runners.dataflow.worker.util.common.worker.ReadOperation; @@ -99,6 +100,7 @@ final class ComputationWorkExecutorFactory { private final StreamingGlobalConfigHandle globalConfigHandle; private final boolean throwExceptionOnLargeOutput; private final HotKeyLogger hotKeyLogger; + private final SideInputStateFetcherFactory sideInputStateFetcherFactory; ComputationWorkExecutorFactory( DataflowWorkerHarnessOptions options, @@ -109,7 +111,8 @@ final class ComputationWorkExecutorFactory { CounterSet pendingDeltaCounters, IdGenerator idGenerator, StreamingGlobalConfigHandle globalConfigHandle, - HotKeyLogger hotKeyLogger) { + HotKeyLogger hotKeyLogger, + SideInputStateFetcherFactory sideInputStateFetcherFactory) { this.options = options; this.mapTaskExecutorFactory = mapTaskExecutorFactory; this.readerCache = readerCache; @@ -128,6 +131,7 @@ final class ComputationWorkExecutorFactory { this.throwExceptionOnLargeOutput = hasExperiment(options, THROW_EXCEPTIONS_ON_LARGE_OUTPUT_EXPERIMENT); this.hotKeyLogger = hotKeyLogger; + this.sideInputStateFetcherFactory = sideInputStateFetcherFactory; } private static Nodes.ParallelInstructionNode extractReadNode( @@ -282,7 +286,8 @@ private StreamingModeExecutionContext createExecutionContext( hotKeyLogger, hotKeyLoggingEnabled, stepName, - computationState.sourceBytesProcessCounterName()); + computationState.sourceBytesProcessCounterName(), + sideInputStateFetcherFactory); } private DataflowMapTaskExecutor createMapTaskExecutor( diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java index dc1fd4791fcd..9e28c64b7860 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java @@ -80,7 +80,6 @@ public class StreamingWorkScheduler { private final Supplier clock; private final ComputationWorkExecutorFactory computationWorkExecutorFactory; - private final SideInputStateFetcherFactory sideInputStateFetcherFactory; private final FailureTracker failureTracker; private final WorkFailureProcessor workFailureProcessor; private final StreamingCommitFinalizer commitFinalizer; @@ -94,7 +93,6 @@ public StreamingWorkScheduler( Supplier clock, BoundedQueueExecutor workExecutor, ComputationWorkExecutorFactory computationWorkExecutorFactory, - SideInputStateFetcherFactory sideInputStateFetcherFactory, FailureTracker failureTracker, WorkFailureProcessor workFailureProcessor, StreamingCommitFinalizer commitFinalizer, @@ -105,7 +103,6 @@ public StreamingWorkScheduler( this.clock = clock; this.workExecutor = workExecutor; this.computationWorkExecutorFactory = computationWorkExecutorFactory; - this.sideInputStateFetcherFactory = sideInputStateFetcherFactory; this.failureTracker = failureTracker; this.workFailureProcessor = workFailureProcessor; this.commitFinalizer = commitFinalizer; @@ -131,6 +128,9 @@ public static StreamingWorkScheduler create( IdGenerator idGenerator, StreamingGlobalConfigHandle globalConfigHandle, ConcurrentMap stageInfoMap) { + SideInputStateFetcherFactory sideInputStateFetcherFactory = + SideInputStateFetcherFactory.fromOptions(options); + ComputationWorkExecutorFactory computationWorkExecutorFactory = new ComputationWorkExecutorFactory( options, @@ -141,13 +141,13 @@ public static StreamingWorkScheduler create( streamingCounters.pendingDeltaCounters(), idGenerator, globalConfigHandle, - hotKeyLogger); + hotKeyLogger, + sideInputStateFetcherFactory); return new StreamingWorkScheduler( clock, workExecutor, computationWorkExecutorFactory, - SideInputStateFetcherFactory.fromOptions(options), failureTracker, workFailureProcessor, StreamingCommitFinalizer.create(workExecutor, commitFinalizerCleanupExecutor), @@ -388,12 +388,7 @@ private ExecuteWorkResult executeWork( // Blocks while executing work. computationWorkExecutor.executeWork( - work, - stateReader, - sideInputStateFetcherFactory, - workExecutor, - handle, - keyTransitionListener); + work, stateReader, workExecutor, handle, keyTransitionListener); List workBatch; List workItemCommits; diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java index 6d84e9b4b0bf..c1193afeff6b 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java @@ -48,7 +48,6 @@ import org.apache.beam.runners.core.metrics.ExecutionStateSampler; import org.apache.beam.runners.core.metrics.ExecutionStateTracker; import org.apache.beam.runners.core.metrics.ExecutionStateTracker.ExecutionState; -import org.apache.beam.runners.dataflow.options.DataflowStreamingPipelineOptions; import org.apache.beam.runners.dataflow.options.DataflowWorkerHarnessOptions; import org.apache.beam.runners.dataflow.worker.DataflowExecutionContext.DataflowExecutionStateTracker; import org.apache.beam.runners.dataflow.worker.MetricsToCounterUpdateConverter.Kind; @@ -144,7 +143,8 @@ public void setUp() { new HotKeyLogger(), /*hotKeyLoggingEnabled=*/ false, /*stepName=*/ "stepName", - "sourceBytesProcessCounterName"); + "sourceBytesProcessCounterName", + SideInputStateFetcherFactory.fromOptions(options)); } private StreamingModeExecutionContext createTestExecutionContext( @@ -176,7 +176,8 @@ private StreamingModeExecutionContext createTestExecutionContext( new HotKeyLogger(), /*hotKeyLoggingEnabled=*/ false, /*stepName=*/ "stepName", - "sourceBytesProcessCounterName"); + "sourceBytesProcessCounterName", + SideInputStateFetcherFactory.fromOptions(options)); } private static Work createMockWork(Windmill.WorkItem workItem, Watermarks watermarks) { @@ -203,13 +204,9 @@ private void start(StreamingModeExecutionContext context, Work work) { } private void start(StreamingModeExecutionContext context, Work work, Coder keyCoder) { - SideInputStateFetcherFactory sideInputStateFetcherFactory = - SideInputStateFetcherFactory.fromOptions( - options.as(DataflowStreamingPipelineOptions.class)); context.start( work, stateReader, - sideInputStateFetcherFactory, workExecutor, /* workQueueExecutor= */ null, /* budgetHandle= */ null, diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java index bd4e40d6570a..0af802ec6760 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java @@ -80,7 +80,6 @@ import org.apache.beam.runners.dataflow.DataflowRunner; import org.apache.beam.runners.dataflow.options.DataflowPipelineDebugOptions; import org.apache.beam.runners.dataflow.options.DataflowPipelineOptions; -import org.apache.beam.runners.dataflow.options.DataflowStreamingPipelineOptions; import org.apache.beam.runners.dataflow.util.CloudObject; import org.apache.beam.runners.dataflow.util.PropertyNames; import org.apache.beam.runners.dataflow.worker.DataflowExecutionContext.DataflowExecutionStateTracker; @@ -212,13 +211,9 @@ COMPUTATION_ID, new FakeGetDataClient(), ignored -> {}, mock(HeartbeatSender.cla } private void startContext(StreamingModeExecutionContext context, Work work) { - SideInputStateFetcherFactory sideInputStateFetcherFactory = - SideInputStateFetcherFactory.fromOptions( - options.as(DataflowStreamingPipelineOptions.class)); context.start( work, mock(WindmillStateReader.class), - sideInputStateFetcherFactory, mock(WorkExecutor.class), /* workQueueExecutor= */ null, /* budgetHandle= */ null, @@ -641,7 +636,8 @@ public void testReadUnboundedReader() throws Exception { new HotKeyLogger(), /*hotKeyLoggingEnabled=*/ false, /*stepName=*/ "stepName", - "sourceBytesProcessCounterName"); + "sourceBytesProcessCounterName", + SideInputStateFetcherFactory.fromOptions(options)); options.setNumWorkers(5); int maxElements = 10; @@ -1013,7 +1009,8 @@ public void testFailedWorkItemsAbort() throws Exception { new HotKeyLogger(), /*hotKeyLoggingEnabled=*/ false, /*stepName=*/ "stepName", - "sourceBytesProcessCounterName"); + "sourceBytesProcessCounterName", + SideInputStateFetcherFactory.fromOptions(options)); options.setNumWorkers(5); int maxElements = 100; From 47eb7d661ebee256807f5674065ce58c698bf668 Mon Sep 17 00:00:00 2001 From: Arun Pandian Date: Mon, 8 Jun 2026 22:23:22 +0000 Subject: [PATCH 15/21] Address comment --- .../worker/WindmillReaderIteratorBase.java | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillReaderIteratorBase.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillReaderIteratorBase.java index 20d0c40ae4a3..d0f9eafbcd4c 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillReaderIteratorBase.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillReaderIteratorBase.java @@ -36,8 +36,8 @@ public abstract class WindmillReaderIteratorBase extends NativeReader.NativeReaderIterator> { private final StreamingModeExecutionContext context; private Windmill.WorkItem work; - private int bundleIndex = 0; - private int messageIndex = -1; + private int bundleIndex; + private int messageIndex; private @Nullable WindowedValue current = null; private final ValueProvider skipUndecodableElements; private static final Logger LOG = LoggerFactory.getLogger(WindmillReaderIteratorBase.class); @@ -46,7 +46,7 @@ protected WindmillReaderIteratorBase( StreamingModeExecutionContext context, ValueProvider skipUndecodableElements) { this.context = context; this.skipUndecodableElements = skipUndecodableElements; - this.work = context.getWorkItem(); + resetWorkFromContext(); } @Override @@ -67,9 +67,7 @@ public boolean advance() throws IOException { context.finishKey(); if (context.advance()) { // Transition succeeded! Update iterator references to the new work item - this.work = context.getWork().getWorkItem(); - this.bundleIndex = 0; - this.messageIndex = -1; + resetWorkFromContext(); continue; } @@ -104,6 +102,12 @@ public boolean advance() throws IOException { } } + private void resetWorkFromContext() { + this.work = context.getWork().getWorkItem(); + this.bundleIndex = 0; + this.messageIndex = -1; + } + protected abstract WindowedValue decodeMessage(Windmill.Message message) throws IOException; @Override From 24505dd7e47ee1270742504ccb397031582e8c2c Mon Sep 17 00:00:00 2001 From: Arun Pandian Date: Mon, 8 Jun 2026 22:27:39 +0000 Subject: [PATCH 16/21] Address comment --- .../dataflow/worker/StreamingModeExecutionContext.java | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java index ef27e38f9ca2..c2cf2f9f7940 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java @@ -192,7 +192,7 @@ public interface KeyTransitionListener { // Map> private Map> accumulatedCallbacks = new HashMap<>(); - private final AtomicBoolean workIsFailed = new AtomicBoolean(false); + private final AtomicBoolean workBatchFailed = new AtomicBoolean(false); private @Nullable WindmillStateReader activeStateReader; private long stateBytesRead = 0; private final String sourceBytesProcessCounterName; @@ -259,7 +259,7 @@ public boolean throwExceptionsForLargeOutput() { } public boolean workIsFailed() { - return workIsFailed.get(); + return workBatchFailed.get(); } public boolean getDrainMode() { @@ -298,7 +298,7 @@ public void clear() { this.executedWorks = new ArrayList<>(); this.outputBuilders = new ArrayList<>(); this.accumulatedCallbacks = new HashMap<>(); - this.workIsFailed.set(false); + this.workBatchFailed.set(false); this.sideInputCache.clear(); this.activeStateReader = null; this.activeReader = null; @@ -715,7 +715,7 @@ private void startForNewKey(Work newWork, WindmillStateReader reader) { this.outputBuilder = createOutputBuilder(newWork); this.outputBuilders.add(this.outputBuilder); - newWork.setOnFailureListener(this.workIsFailed); + newWork.setOnFailureListener(this.workBatchFailed); this.executedWorks.add(newWork); logHotKeyIfDetected(newWork, this.key); From 4e0d17448960650e31ac15ed768f25d7acd7bb35 Mon Sep 17 00:00:00 2001 From: Arun Pandian Date: Tue, 9 Jun 2026 16:34:43 +0000 Subject: [PATCH 17/21] Fix UnderInitialization --- .../runners/dataflow/worker/WindmillReaderIteratorBase.java | 6 ++++-- .../dataflow/worker/WindmillReaderIteratorBaseTest.java | 4 +--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillReaderIteratorBase.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillReaderIteratorBase.java index d0f9eafbcd4c..134655a72a54 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillReaderIteratorBase.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillReaderIteratorBase.java @@ -46,7 +46,9 @@ protected WindmillReaderIteratorBase( StreamingModeExecutionContext context, ValueProvider skipUndecodableElements) { this.context = context; this.skipUndecodableElements = skipUndecodableElements; - resetWorkFromContext(); + this.work = context.getWorkItem(); + this.bundleIndex = 0; + this.messageIndex = -1; } @Override @@ -103,7 +105,7 @@ public boolean advance() throws IOException { } private void resetWorkFromContext() { - this.work = context.getWork().getWorkItem(); + this.work = context.getWorkItem(); this.bundleIndex = 0; this.messageIndex = -1; } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindmillReaderIteratorBaseTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindmillReaderIteratorBaseTest.java index a56343e3dfb3..b45e0de6447c 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindmillReaderIteratorBaseTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindmillReaderIteratorBaseTest.java @@ -171,8 +171,6 @@ public void testAdvanceKeyChaining() throws Exception { .build()) .build(); - Work mockWorkB = createMockWork(workItemB); - // Set up context.advance() to mock transition when(mockContext.advance()) .thenAnswer( @@ -183,7 +181,7 @@ public void testAdvanceKeyChaining() throws Exception { public Boolean answer(org.mockito.invocation.InvocationOnMock invocation) { if (count == 0) { count++; - when(mockContext.getWork()).thenReturn(mockWorkB); + when(mockContext.getWorkItem()).thenReturn(workItemB); return true; } return false; From 15e8d0a0e3e5e929493d5e5fae9c1af61191382e Mon Sep 17 00:00:00 2001 From: Arun Pandian Date: Wed, 17 Jun 2026 21:07:12 +0000 Subject: [PATCH 18/21] address comments --- .../worker/StreamingModeExecutionContext.java | 14 ++++++------ .../worker/WindowingWindmillReader.java | 15 +++---------- .../streaming/ComputationWorkExecutor.java | 7 +++--- .../processing/StreamingWorkScheduler.java | 22 +++++++------------ 4 files changed, 22 insertions(+), 36 deletions(-) diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java index 9fb013e285b9..54b870e1b06a 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java @@ -192,7 +192,7 @@ public interface KeyTransitionListener { // Map> private Map> accumulatedCallbacks = new HashMap<>(); - private final AtomicBoolean workBatchFailed = new AtomicBoolean(false); + private AtomicBoolean workBatchFailed = new AtomicBoolean(false); private @Nullable WindmillStateReader activeStateReader; private long stateBytesRead = 0; private final String sourceBytesProcessCounterName; @@ -291,14 +291,14 @@ public byte[] getCurrentRecordOffset() { return checkStateNotNull(activeReader).getCurrentRecordOffset(); } - public void clear() { - for (Work w : executedWorks) { - w.setOnFailureListener(null); - } + /** Reset context before using it on a new bundle */ + public void reset() { this.executedWorks = new ArrayList<>(); this.outputBuilders = new ArrayList<>(); this.accumulatedCallbacks = new HashMap<>(); - this.workBatchFailed.set(false); + // Work from prior bundles might have a reference to the old workBatchFailed. + // If the work gets retried it'll get the new workBatchFailed to notify failure. + this.workBatchFailed = new AtomicBoolean(false); this.sideInputCache.clear(); this.activeStateReader = null; this.activeReader = null; @@ -324,7 +324,7 @@ public void start( BoundedQueueExecutorWorkHandle budgetHandle, @Nullable Coder keyCoder, KeyTransitionListener keyTransitionListener) { - clear(); + reset(); this.keyCoder = keyCoder; this.workExecutor = workExecutor; this.workQueueExecutor = workQueueExecutor; diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindowingWindmillReader.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindowingWindmillReader.java index 2003ec001a55..c258e25146e4 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindowingWindmillReader.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindowingWindmillReader.java @@ -149,14 +149,8 @@ private boolean isEmpty(KeyedWorkItem keyedWorkItem) { @Override public NativeReaderIterator>> iterator() throws IOException { - final KeyedWorkItem firstKeyedWorkItem = createKeyedWorkItem(); - final boolean firstKeyIsEmpty = isEmpty(firstKeyedWorkItem); - final WindowedValue> firstValue = - new ValueInEmptyWindows<>(firstKeyedWorkItem); - return new NativeReaderIterator>>() { private @Nullable WindowedValue> current = null; - private boolean started = false; @Override public boolean start() throws IOException { @@ -164,14 +158,11 @@ public boolean start() throws IOException { throw new WorkItemCancelledException( checkStateNotNull(context.getWorkItem()).getShardingKey()); } - if (started) { - return false; - } - started = true; - if (firstKeyIsEmpty) { + KeyedWorkItem firstKeyedWorkItem = createKeyedWorkItem(); + if (isEmpty(firstKeyedWorkItem)) { return advance(); // Try to transition immediately if the first key is empty! } - current = firstValue; + current = new ValueInEmptyWindows<>(firstKeyedWorkItem); return true; } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationWorkExecutor.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationWorkExecutor.java index 56a1a06362d2..b8ee42c8ef88 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationWorkExecutor.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationWorkExecutor.java @@ -53,7 +53,7 @@ public static ComputationWorkExecutor.Builder builder() { public abstract DataflowWorkExecutor workExecutor(); - public abstract StreamingModeExecutionContext context(); + abstract StreamingModeExecutionContext context(); public abstract Optional> keyCoder(); @@ -62,7 +62,7 @@ public static ComputationWorkExecutor.Builder builder() { /** * Executes DoFns for the Work. Blocks the calling thread until DoFn(s) have completed execution. */ - public final void executeWork( + public final StreamingModeExecutionContext executeWork( Work work, WindmillStateReader stateReader, BoundedQueueExecutor workQueueExecutor, @@ -79,6 +79,7 @@ public final void executeWork( keyCoder().orElse(null), keyTransitionListener); workExecutor().execute(); + return context(); } /** @@ -87,7 +88,7 @@ public final void executeWork( */ public final void invalidate() { context().invalidateCache(); - context().clear(); + context().reset(); try { workExecutor().close(); } catch (Exception e) { diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java index 3171f540157e..1ab0c4cb8a5f 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java @@ -196,12 +196,6 @@ private static void setLoggingContextWorkId(@Nullable String workLatencyTracking DataflowWorkerLoggingMDC.setWorkId(workLatencyTrackingId); } - /** Resets logging context of the Thread executing the {@link Work} for logging. */ - private void resetWorkLoggingContext() { - setLoggingContextWorkId(null); - setLoggingContextComputation(null); - } - /** * Schedule work for execution. Work may be executed immediately, or queued and executed in the * future. Only one work may be "active" (currently executing) per key at a time. @@ -262,7 +256,7 @@ private void processWork( long processingStartTimeNanos = System.nanoTime(); StageInfo stageInfo = getStageInfo(computationState); - List workBatch = null; + @Nullable List workBatch = null; try { if (work.isFailed()) { throw new WorkItemCancelledException(workItem.getShardingKey()); @@ -300,7 +294,8 @@ private void processWork( // work items causing exceptions are also accounted in time spent. recordProcessingTime(stageInfo, workBatch, work, processingStartTimeNanos); - resetWorkLoggingContext(); + setLoggingContextWorkId(null); + setLoggingContextComputation(null); sampler.resetForWorkId(work.getLatencyTrackingId()); if (workBatch != null) { for (Work w : workBatch) { @@ -384,16 +379,15 @@ private ExecuteWorkResult executeWork( KeyTransitionListener keyTransitionListener = createKeyTransitionListener(); - // Blocks while executing work. - computationWorkExecutor.executeWork( - work, stateReader, workExecutor, handle, keyTransitionListener); - List workBatch; List workItemCommits; Map> accumulatedCallbacks; long stateBytesRead; { - StreamingModeExecutionContext context = computationWorkExecutor.context(); + // Blocks while executing work. + StreamingModeExecutionContext context = + computationWorkExecutor.executeWork( + work, stateReader, workExecutor, handle, keyTransitionListener); if (context.workIsFailed()) { throw new WorkItemCancelledException(work.getWorkItem().getShardingKey()); } @@ -405,7 +399,7 @@ private ExecuteWorkResult executeWork( accumulatedCallbacks = context.getAccumulatedCallbacks(); stateBytesRead = context.getStateBytesRead(); - context.clear(); // Don't use context after this. + context.reset(); // Don't use context after this. } // Release the execution state for another thread to use. computationState.releaseComputationWorkExecutor(computationWorkExecutor); From 996f561e2895dd3455dede203132623bbb3def1c Mon Sep 17 00:00:00 2001 From: Arun Pandian Date: Tue, 30 Jun 2026 07:39:28 +0000 Subject: [PATCH 19/21] address comments --- .../worker/StreamingModeExecutionContext.java | 120 ++++++++++-------- .../processing/StreamingWorkScheduler.java | 48 +++---- .../StreamingModeExecutionContextTest.java | 23 ++-- .../worker/WorkerCustomSourcesTest.java | 22 ++-- 4 files changed, 113 insertions(+), 100 deletions(-) diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java index 54b870e1b06a..12970bef66c8 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java @@ -26,7 +26,7 @@ import java.io.Closeable; import java.io.IOException; import java.util.ArrayList; -import java.util.Collection; +import java.util.Collections; import java.util.HashMap; import java.util.Iterator; import java.util.List; @@ -79,6 +79,7 @@ import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillTimerData; import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.CoderException; import org.apache.beam.sdk.io.UnboundedSource; import org.apache.beam.sdk.io.UnboundedSource.UnboundedReader; import org.apache.beam.sdk.metrics.MetricsContainer; @@ -159,9 +160,8 @@ public class StreamingModeExecutionContext /** * Current reader used for processing {@link Work}. Set by calling {@link - * #setActiveReader(UnboundedReader)}, reset to null and cached when state is persisted {@link - * #flushState()}, or set to null and closed when {@link StreamingModeExecutionContext} is - * invalidated. + * #setActiveReader(UnboundedReader)}, reset to null and cached when state is persisted or set to + * null and closed when {@link StreamingModeExecutionContext} is invalidated. */ private @Nullable UnboundedReader activeReader; @@ -187,11 +187,11 @@ public interface KeyTransitionListener { @SuppressWarnings("UnusedVariable") private @Nullable KeyTransitionListener keyTransitionListener; - private List executedWorks = new ArrayList<>(); - private List outputBuilders = new ArrayList<>(); + private List executedWorks = Collections.emptyList(); + private List outputBuilders = Collections.emptyList(); // Map> - private Map> accumulatedCallbacks = new HashMap<>(); + private Map> finalizationCallbacks = Collections.emptyMap(); private AtomicBoolean workBatchFailed = new AtomicBoolean(false); private @Nullable WindmillStateReader activeStateReader; private long stateBytesRead = 0; @@ -293,9 +293,11 @@ public byte[] getCurrentRecordOffset() { /** Reset context before using it on a new bundle */ public void reset() { - this.executedWorks = new ArrayList<>(); - this.outputBuilders = new ArrayList<>(); - this.accumulatedCallbacks = new HashMap<>(); + // these lists and maps are returned to callers after processing + // don't clear and reuse, instead reset the reference. + this.executedWorks = Collections.emptyList(); + this.outputBuilders = Collections.emptyList(); + this.finalizationCallbacks = Collections.emptyMap(); // Work from prior bundles might have a reference to the old workBatchFailed. // If the work gets retried it'll get the new workBatchFailed to notify failure. this.workBatchFailed = new AtomicBoolean(false); @@ -323,8 +325,12 @@ public void start( BoundedQueueExecutor workQueueExecutor, BoundedQueueExecutorWorkHandle budgetHandle, @Nullable Coder keyCoder, - KeyTransitionListener keyTransitionListener) { + KeyTransitionListener keyTransitionListener) + throws CoderException { reset(); + this.executedWorks = new ArrayList<>(); + this.outputBuilders = new ArrayList<>(); + this.finalizationCallbacks = new HashMap<>(); this.keyCoder = keyCoder; this.workExecutor = workExecutor; this.workQueueExecutor = workQueueExecutor; @@ -338,7 +344,7 @@ public void start( startForNewKey(work, stateReader); } - private @Nullable Object decodeKey(Work work) { + private @Nullable Object decodeKey(Work work) throws CoderException { // If the read output KVs, then we can decode Windmill's byte key into userland // key object and provide it to the execution context for use with per-key state. // Otherwise, we pass null. @@ -348,6 +354,8 @@ public void start( if (keyCoder != null) { try { return keyCoder.decode(work.getWorkItem().getKey().newInput(), Coder.Context.OUTER); + } catch (CoderException e) { + throw e; } catch (IOException e) { throw new RuntimeException("Failed to decode key during processing", e); } @@ -380,30 +388,29 @@ private void startStepContexts( Instant processingTime, WindmillStateCache.ForKey cacheForKey, Watermarks watermarks) { - Collection stepContexts = getAllStepContexts(); - for (StepContext stepContext : stepContexts) { + for (StepContext stepContext : getAllStepContexts()) { stepContext.start(stateReader, processingTime, cacheForKey, watermarks); } } public void finishKey() { + WorkExecutor localExecutor = + checkStateNotNull(workExecutor, "workExecutor must be set before calling finishKey()"); if (finishKeyCalled) { return; } + this.finishKeyCalled = true; if (activeStateReader != null) { this.stateBytesRead += activeStateReader.getBytesRead(); } if (sideInputStateFetcher != null) { this.stateBytesRead += sideInputStateFetcher.getBytesRead(); } - checkStateNotNull(workExecutor, "workExecutor must be set before calling finishKey()"); try { - workExecutor.finishKey(key); + localExecutor.finishKey(key); } catch (Exception e) { throw new RuntimeException(e); } - this.finishKeyCalled = true; - flushStateInternal(); } @@ -518,26 +525,6 @@ private List getFiredTimers() { return getWorkItem().getTimers().getTimersList(); } - public @Nullable ByteString getSerializedKey() { - return work == null ? null : work.getWorkItem().getKey(); - } - - public WindmillComputationKey getComputationKey() { - return checkStateNotNull(computationKey); - } - - public long getWorkToken() { - return getWorkItem().getWorkToken(); - } - - public Windmill.WorkItem getWorkItem() { - return checkStateNotNull( - work, - "work is null. A call to StreamingModeExecutionContext.start(...) is required to set" - + " work for execution.") - .getWorkItem(); - } - public Windmill.WorkItemCommitRequest.Builder getOutputBuilder() { return checkStateNotNull(outputBuilder); } @@ -674,7 +661,7 @@ private void flushStateInternal() { getOutputBuilder().setSourceBacklogBytes(backlogBytes); } - this.accumulatedCallbacks.putAll(callbacks); + this.finalizationCallbacks.putAll(callbacks); getOutputBuilder() .setSourceBytesProcessed(computeSourceBytesProcessed(sourceBytesProcessCounterName)); @@ -695,15 +682,12 @@ private final long computeSourceBytesProcessed(String sourceBytesCounterName) { .orElse(0L); } - public Map> flushState() { - return accumulatedCallbacks; - } - public boolean advance() { + // TODO: get more work from workQueueExecutor and merge into the bundle here return false; } - private void startForNewKey(Work newWork, WindmillStateReader reader) { + private void startForNewKey(Work newWork, WindmillStateReader reader) throws CoderException { newWork.setState(Work.State.PROCESSING); if (keyTransitionListener != null && this.work != null && this.work != newWork) { keyTransitionListener.onKeyTransition(this.work, newWork); @@ -730,7 +714,9 @@ private void startForNewKey(Work newWork, WindmillStateReader reader) { // Re-initialize state cache and state/timer internals across all step contexts Instant processingTime = computeProcessingTime(newWork.getWorkItem().getTimers().getTimersList()); - if (!getAllStepContexts().isEmpty()) { + if (getAllStepContexts().isEmpty()) { + checkState(this.activeStateReader == null); + } else { // This must be only created once for a workItem as token validation will fail if the same // work token is reused. WindmillStateCache.ForKey cacheForKey = @@ -738,19 +724,15 @@ private void startForNewKey(Work newWork, WindmillStateReader reader) { getComputationKey(), newWork.getWorkItem().getCacheToken(), getWorkToken()); this.activeStateReader = reader; startStepContexts(reader, processingTime, cacheForKey, newWork.watermarks()); - } else { - this.activeStateReader = null; } } - public List getExecutedWorks() { - return executedWorks; - } - + // Returns state bytes read during the bundle execution public long getStateBytesRead() { return stateBytesRead; } + // Returns list of commit requests from the bundle public List getWorkItemCommits() { List commits = new ArrayList<>(outputBuilders.size()); for (Windmill.WorkItemCommitRequest.Builder builder : outputBuilders) { @@ -759,18 +741,50 @@ public List getWorkItemCommits() { return commits; } - public Map> getAccumulatedCallbacks() { - return accumulatedCallbacks; + // Returns list of Work that was executed in the bundle + public List getExecutedWorks() { + return executedWorks; } + // Returns finalization callbacks recorded during the bundle execution + public Map> getFinalizationCallbacks() { + return finalizationCallbacks; + } + + // Returns the current key being processed or null if an unkeyed stage. public @Nullable Object getKey() { return key; } + // Returns the current Work being processed. public Work getWork() { return checkStateNotNull(work); } + // Returns the serialized windmill key for the current Work + public @Nullable ByteString getSerializedKey() { + return work == null ? null : work.getWorkItem().getKey(); + } + + // Returns the serialized windmill key for the current Work + public WindmillComputationKey getComputationKey() { + return checkStateNotNull(computationKey); + } + + // Returns the windmill work token for the current Work + public long getWorkToken() { + return getWorkItem().getWorkToken(); + } + + // Returns the windmill WorkItem proto for the current Work + public Windmill.WorkItem getWorkItem() { + return checkStateNotNull( + work, + "work is null. A call to StreamingModeExecutionContext.start(...) is required to set" + + " work for execution.") + .getWorkItem(); + } + @Nullable String getStateFamily(NameContext nameContext) { return nameContext.userName() == null ? null : stateNameMap.get(nameContext.userName()); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java index 1ab0c4cb8a5f..f97a1748ef08 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java @@ -17,6 +17,8 @@ */ package org.apache.beam.runners.dataflow.worker.windmill.work.processing; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState; + import com.google.api.services.dataflow.model.MapTask; import com.google.auto.value.AutoValue; import java.util.List; @@ -60,7 +62,6 @@ import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.sdk.fn.IdGenerator; import org.apache.beam.vendor.grpc.v1p69p0.com.google.protobuf.ByteString; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.checkerframework.checker.nullness.qual.Nullable; import org.joda.time.Instant; @@ -267,7 +268,7 @@ private void processWork( workBatch = executeWorkResult.workBatch(); List workItemCommits = executeWorkResult.workItemCommits(); - commitFinalizer.cacheCommitFinalizers(executeWorkResult.accumulatedCallbacks()); + commitFinalizer.cacheCommitFinalizers(executeWorkResult.finalizationCallbacks()); commitWorkBatch(computationState, workBatch, workItemCommits); @@ -290,19 +291,16 @@ private void processWork( throw ExceptionUtils.safeWrapThrowableAsException(t2); } } finally { + List processedWorkBatch = workBatch != null ? workBatch : ImmutableList.of(work); // Update total processing time counters. Updating in finally clause ensures that // work items causing exceptions are also accounted in time spent. - recordProcessingTime(stageInfo, workBatch, work, processingStartTimeNanos); + recordProcessingTime(stageInfo, processedWorkBatch, processingStartTimeNanos); setLoggingContextWorkId(null); setLoggingContextComputation(null); sampler.resetForWorkId(work.getLatencyTrackingId()); - if (workBatch != null) { - for (Work w : workBatch) { - w.setProcessingThreadName(""); - } - } else { - work.setProcessingThreadName(""); + for (Work w : processedWorkBatch) { + w.setProcessingThreadName(""); } } } @@ -340,7 +338,7 @@ private void recordProcessingStats( long totalStateBytesRead) { long totalStateBytesWritten = 0; long totalShuffleBytesRead = 0; - Preconditions.checkState(workBatch.size() == workItemCommits.size()); + checkState(workBatch.size() == workItemCommits.size()); for (int i = 0; i < workBatch.size(); i++) { Windmill.WorkItem workItem = workBatch.get(i).getWorkItem(); Windmill.WorkItemCommitRequest commit = workItemCommits.get(i); @@ -381,7 +379,7 @@ private ExecuteWorkResult executeWork( List workBatch; List workItemCommits; - Map> accumulatedCallbacks; + Map> finalizationCallbacks; long stateBytesRead; { // Blocks while executing work. @@ -396,7 +394,7 @@ private ExecuteWorkResult executeWork( // context workBatch = context.getExecutedWorks(); workItemCommits = context.getWorkItemCommits(); - accumulatedCallbacks = context.getAccumulatedCallbacks(); + finalizationCallbacks = context.getFinalizationCallbacks(); stateBytesRead = context.getStateBytesRead(); context.reset(); // Don't use context after this. @@ -406,7 +404,7 @@ private ExecuteWorkResult executeWork( computationWorkExecutor = null; return ExecuteWorkResult.create( - workBatch, workItemCommits, accumulatedCallbacks, stateBytesRead); + workBatch, workItemCommits, finalizationCallbacks, stateBytesRead); } catch (Throwable t) { if (computationWorkExecutor != null) { // If processing failed due to a thrown exception, close the executionState. Do not @@ -442,8 +440,8 @@ private void commitWorkBatch( ComputationState computationState, List workBatch, List workItemCommits) { - Preconditions.checkState( - workBatch.size() == 1, "Expected single-key work batch, got: " + workBatch.size()); + checkState(workBatch.size() == 1, "Expected single-key work batch, got: " + workBatch.size()); + checkState(workBatch.size() == workItemCommits.size()); commitSingleKeyWork(computationState, workBatch.get(0), workItemCommits.get(0)); } @@ -463,14 +461,11 @@ private void commitSingleKeyWork( } private void recordProcessingTime( - StageInfo stageInfo, - @Nullable List worksToCleanup, - Work work, - long processingStartTimeNanos) { + StageInfo stageInfo, List worksToCleanup, long processingStartTimeNanos) { long processingTimeMsecs = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - processingStartTimeNanos); stageInfo.totalProcessingMsecs().addValue(processingTimeMsecs); - if (anyWorkHasTimers(worksToCleanup, work)) { + if (anyWorkHasTimers(worksToCleanup)) { // Attribute all the processing to timers if the work item contains any timers. // Tests show that work items rarely contain both timers and message bundles. It should // be a fairly close approximation. @@ -480,11 +475,8 @@ private void recordProcessingTime( } } - private static boolean anyWorkHasTimers(@Nullable List works, Work primaryWork) { - if (works != null && !works.isEmpty()) { - return works.stream().anyMatch(w -> w.getWorkItem().hasTimers()); - } - return primaryWork.getWorkItem().hasTimers(); + private static boolean anyWorkHasTimers(List works) { + return works.stream().anyMatch(w -> w.getWorkItem().hasTimers()); } private KeyTransitionListener createKeyTransitionListener() { @@ -500,10 +492,10 @@ abstract static class ExecuteWorkResult { static ExecuteWorkResult create( List workBatch, List workItemCommits, - Map> accumulatedCallbacks, + Map> finalizationCallbacks, long stateBytesRead) { return new AutoValue_StreamingWorkScheduler_ExecuteWorkResult( - workBatch, workItemCommits, accumulatedCallbacks, stateBytesRead); + workBatch, workItemCommits, finalizationCallbacks, stateBytesRead); } abstract List workBatch(); @@ -511,7 +503,7 @@ static ExecuteWorkResult create( abstract List workItemCommits(); // Map> - abstract Map> accumulatedCallbacks(); + abstract Map> finalizationCallbacks(); abstract long stateBytesRead(); } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java index 534d51e2b88c..ed1d6534d76c 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java @@ -73,6 +73,7 @@ import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.CoderException; import org.apache.beam.sdk.metrics.MetricsContainer; import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.state.TimeDomain; @@ -176,14 +177,18 @@ private void start(StreamingModeExecutionContext context, Work work) { } private void start(StreamingModeExecutionContext context, Work work, Coder keyCoder) { - context.start( - work, - stateReader, - workExecutor, - /* workQueueExecutor= */ null, - /* budgetHandle= */ null, - keyCoder, - /* keyTransitionListener= */ (k, c) -> {}); + try { + context.start( + work, + stateReader, + workExecutor, + /* workQueueExecutor= */ null, + /* budgetHandle= */ null, + keyCoder, + /* keyTransitionListener= */ (k, c) -> {}); + } catch (CoderException e) { + throw new RuntimeException(e); + } } @Test @@ -209,7 +214,6 @@ public void testTimerInternalsSetTimer() throws Exception { TimeDomain.EVENT_TIME, CausedByDrain.NORMAL)); executionContext.finishKey(); - executionContext.flushState(); Windmill.WorkItemCommitRequest.Builder outputBuilder = executionContext.getOutputBuilder(); Windmill.Timer timer = outputBuilder.buildPartial().getOutputTimers(0); @@ -463,7 +467,6 @@ public void testSetBacklogBytes() { stepContext.setBacklogBytes(1234.0); executionContext.finishKey(); - executionContext.flushState(); assertEquals(1234, executionContext.getOutputBuilder().getSourceBacklogBytes()); } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java index 0af802ec6760..f402ffc97800 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java @@ -107,6 +107,7 @@ import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.coders.BigEndianIntegerCoder; import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.CoderException; import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.coders.VarIntCoder; import org.apache.beam.sdk.extensions.gcp.auth.TestCredential; @@ -211,14 +212,18 @@ COMPUTATION_ID, new FakeGetDataClient(), ignored -> {}, mock(HeartbeatSender.cla } private void startContext(StreamingModeExecutionContext context, Work work) { - context.start( - work, - mock(WindmillStateReader.class), - mock(WorkExecutor.class), - /* workQueueExecutor= */ null, - /* budgetHandle= */ null, - /* keyCoder= */ null, - /* keyTransitionListener= */ mock(KeyTransitionListener.class)); + try { + context.start( + work, + mock(WindmillStateReader.class), + mock(WorkExecutor.class), + /* workQueueExecutor= */ null, + /* budgetHandle= */ null, + /* keyCoder= */ null, + /* keyTransitionListener= */ mock(KeyTransitionListener.class)); + } catch (CoderException e) { + throw new RuntimeException(e); + } } private static class SourceProducingSubSourcesInSplit extends MockSource { @@ -693,7 +698,6 @@ public void testReadUnboundedReader() throws Exception { numReadOnThisIteration, lessThanOrEqualTo(debugOptions.getUnboundedReaderMaxElements())); // Extract and verify state modifications. - context.flushState(); state = context.getOutputBuilder().getSourceStateUpdates().getState(); // CountingSource's watermark is the last record + 1. i is now one past the last record, // so the expected watermark is i millis. From a51c2f7e4c9b50c64bd5599bafba626578962a0f Mon Sep 17 00:00:00 2001 From: Arun Pandian Date: Tue, 30 Jun 2026 09:36:16 +0000 Subject: [PATCH 20/21] fix test --- .../worker/StreamingModeExecutionContextTest.java | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java index 31500d7c8212..89ec3e36be16 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java @@ -503,29 +503,21 @@ public void testStart_internalKeyDecoding() throws Exception { @Test public void testInternalsPoisonedAfterFlushState() throws Exception { - Windmill.WorkItemCommitRequest.Builder outputBuilder = - Windmill.WorkItemCommitRequest.newBuilder(); NameContext nameContext = NameContextsForTests.nameContextForTest(); DataflowOperationContext operationContext = executionContext.createOperationContext(nameContext); StreamingModeExecutionContext.StepContext stepContext = executionContext.getStepContext(operationContext); - executionContext.start( - "key", + start( createMockWork( Windmill.WorkItem.newBuilder().setKey(ByteString.EMPTY).setWorkToken(17L).build(), - Watermarks.builder().setInputDataWatermark(new Instant(1000)).build()), - stateReader, - sideInputStateFetcher, - outputBuilder, - workExecutor); + Watermarks.builder().setInputDataWatermark(new Instant(1000)).build())); TimerInternals timerInternals = stepContext.timerInternals(); StateInternals stateInternals = stepContext.stateInternals(); executionContext.finishKey(); - executionContext.flushState(); // Verify timerInternals is poisoned try { From 3036b54649cd79c6092993a6a5c7b2e087b31c0e Mon Sep 17 00:00:00 2001 From: Arun Pandian Date: Tue, 30 Jun 2026 10:25:36 +0000 Subject: [PATCH 21/21] address comments --- .../worker/StreamingModeExecutionContext.java | 15 +++++++++------ .../work/processing/StreamingWorkScheduler.java | 1 + .../worker/StreamingDataflowWorkerTest.java | 4 ++-- .../worker/StreamingModeExecutionContextTest.java | 3 +++ .../dataflow/worker/WorkerCustomSourcesTest.java | 1 + 5 files changed, 16 insertions(+), 8 deletions(-) diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java index 76ce7d7c9907..9401cc5f8ed9 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java @@ -411,6 +411,10 @@ public void finishKey() { } catch (Exception e) { throw new RuntimeException(e); } + } + + public void flushState() { + checkState(finishKeyCalled, "finishKey must be called before flushState"); flushStateInternal(); } @@ -758,7 +762,10 @@ public Map> getFinalizationCallbacks() { // Returns the current Work being processed. public Work getWork() { - return checkStateNotNull(work); + return checkStateNotNull( + work, + "work is null. A call to StreamingModeExecutionContext.start(...) is required to set" + + " work for execution."); } // Returns the serialized windmill key for the current Work @@ -778,11 +785,7 @@ public long getWorkToken() { // Returns the windmill WorkItem proto for the current Work public Windmill.WorkItem getWorkItem() { - return checkStateNotNull( - work, - "work is null. A call to StreamingModeExecutionContext.start(...) is required to set" - + " work for execution.") - .getWorkItem(); + return getWork().getWorkItem(); } @Nullable diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java index f97a1748ef08..e9b85d720d2b 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java @@ -389,6 +389,7 @@ private ExecuteWorkResult executeWork( if (context.workIsFailed()) { throw new WorkItemCancelledException(work.getWorkItem().getShardingKey()); } + context.flushState(); // Retrieve executed works, work item commits, and accumulated callbacks from execution // context diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java index 7eaa048204ff..23730bc57705 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java @@ -3520,8 +3520,8 @@ public void testExceptionInvalidatesCache() throws Exception { } // Ensure that the invalidated dofn had tearDown called on them. - assertEquals(2, TestExceptionInvalidatesCacheFn.tearDownCallCount.get()); - assertEquals(3, TestExceptionInvalidatesCacheFn.setupCallCount.get()); + assertEquals(1, TestExceptionInvalidatesCacheFn.tearDownCallCount.get()); + assertEquals(2, TestExceptionInvalidatesCacheFn.setupCallCount.get()); worker.stop(); } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java index 89ec3e36be16..561596f68d0f 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java @@ -215,6 +215,7 @@ public void testTimerInternalsSetTimer() throws Exception { TimeDomain.EVENT_TIME, CausedByDrain.NORMAL)); executionContext.finishKey(); + executionContext.flushState(); Windmill.WorkItemCommitRequest.Builder outputBuilder = executionContext.getOutputBuilder(); Windmill.Timer timer = outputBuilder.buildPartial().getOutputTimers(0); @@ -468,6 +469,7 @@ public void testSetBacklogBytes() { stepContext.setBacklogBytes(1234.0); executionContext.finishKey(); + executionContext.flushState(); assertEquals(1234, executionContext.getOutputBuilder().getSourceBacklogBytes()); } @@ -518,6 +520,7 @@ public void testInternalsPoisonedAfterFlushState() throws Exception { StateInternals stateInternals = stepContext.stateInternals(); executionContext.finishKey(); + executionContext.flushState(); // Verify timerInternals is poisoned try { diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java index f402ffc97800..31ea1bab07af 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java @@ -698,6 +698,7 @@ public void testReadUnboundedReader() throws Exception { numReadOnThisIteration, lessThanOrEqualTo(debugOptions.getUnboundedReaderMaxElements())); // Extract and verify state modifications. + context.flushState(); state = context.getOutputBuilder().getSourceStateUpdates().getState(); // CountingSource's watermark is the last record + 1. i is now one past the last record, // so the expected watermark is i millis.