diff --git a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Dataflow_Streaming.json b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Dataflow_Streaming.json index e623d3373a93..50d17c108f2e 100644 --- a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Dataflow_Streaming.json +++ b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Dataflow_Streaming.json @@ -1,4 +1,4 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run!", - "modification": 1, + "modification": 2, } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java index 66df22333944..9401cc5f8ed9 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java @@ -18,13 +18,15 @@ package org.apache.beam.runners.dataflow.worker; import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState; import com.google.api.services.dataflow.model.CounterUpdate; import com.google.api.services.dataflow.model.SideInputInfo; import java.io.Closeable; import java.io.IOException; -import java.util.Collection; +import java.util.ArrayList; +import java.util.Collections; import java.util.HashMap; import java.util.Iterator; import java.util.List; @@ -33,6 +35,7 @@ import java.util.Optional; import java.util.Set; import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicLong; import javax.annotation.concurrent.NotThreadSafe; import org.apache.beam.repackaged.core.org.apache.commons.lang3.tuple.Pair; @@ -45,10 +48,10 @@ import org.apache.beam.runners.core.metrics.ExecutionStateTracker; import org.apache.beam.runners.core.metrics.ExecutionStateTracker.ExecutionState; import org.apache.beam.runners.dataflow.worker.DataflowOperationContext.DataflowExecutionState; -import org.apache.beam.runners.dataflow.worker.StreamingModeExecutionContext.StepContext; import org.apache.beam.runners.dataflow.worker.counters.CounterFactory; import org.apache.beam.runners.dataflow.worker.counters.NameContext; import org.apache.beam.runners.dataflow.worker.profiler.ScopedProfiler.ProfileScope; +import org.apache.beam.runners.dataflow.worker.streaming.BoundedQueueExecutorWorkHandle; import org.apache.beam.runners.dataflow.worker.streaming.Watermarks; import org.apache.beam.runners.dataflow.worker.streaming.Work; import org.apache.beam.runners.dataflow.worker.streaming.config.StreamingGlobalConfig; @@ -56,6 +59,10 @@ import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInput; import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputState; import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputStateFetcher; +import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputStateFetcherFactory; +import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor; +import org.apache.beam.runners.dataflow.worker.util.common.worker.ElementCounter; +import org.apache.beam.runners.dataflow.worker.util.common.worker.OutputObjectAndByteCounter; import org.apache.beam.runners.dataflow.worker.util.common.worker.WorkExecutor; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalDataId; @@ -72,6 +79,7 @@ import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillTimerData; import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.CoderException; import org.apache.beam.sdk.io.UnboundedSource; import org.apache.beam.sdk.io.UnboundedSource.UnboundedReader; import org.apache.beam.sdk.metrics.MetricsContainer; @@ -109,7 +117,8 @@ @SuppressWarnings({"deprecation"}) @NotThreadSafe @Internal -public class StreamingModeExecutionContext extends DataflowExecutionContext { +public class StreamingModeExecutionContext + extends DataflowExecutionContext { private static final Logger LOG = LoggerFactory.getLogger(StreamingModeExecutionContext.class); @@ -141,6 +150,7 @@ public class StreamingModeExecutionContext extends DataflowExecutionContext activeReader; private @Nullable WorkExecutor workExecutor; private boolean finishKeyCalled = false; + @SuppressWarnings("UnusedVariable") + private @Nullable BoundedQueueExecutor workQueueExecutor; + + @SuppressWarnings("UnusedVariable") + private @Nullable BoundedQueueExecutorWorkHandle budgetHandle; + + private final HotKeyLogger hotKeyLogger; + private final boolean hotKeyLoggingEnabled; + private final String stepName; + private @Nullable Coder keyCoder; + + // Key switch listener to delegate MDC logging context and thread name updates + public interface KeyTransitionListener { + void onKeyTransition(Work oldWork, Work newWork); + } + + @SuppressWarnings("UnusedVariable") + private @Nullable KeyTransitionListener keyTransitionListener; + + private List executedWorks = Collections.emptyList(); + private List outputBuilders = Collections.emptyList(); + + // Map> + private Map> finalizationCallbacks = Collections.emptyMap(); + private AtomicBoolean workBatchFailed = new AtomicBoolean(false); + private @Nullable WindmillStateReader activeStateReader; + private long stateBytesRead = 0; + private final String sourceBytesProcessCounterName; + public StreamingModeExecutionContext( CounterFactory counterFactory, String computationId, @@ -170,7 +208,12 @@ public StreamingModeExecutionContext( StreamingModeExecutionStateRegistry executionStateRegistry, StreamingGlobalConfigHandle globalConfigHandle, long sinkByteLimit, - boolean throwExceptionOnLargeOutput) { + boolean throwExceptionOnLargeOutput, + HotKeyLogger hotKeyLogger, + boolean hotKeyLoggingEnabled, + String stepName, + String sourceBytesProcessCounterName, + SideInputStateFetcherFactory sideInputStateFetcherFactory) { super( counterFactory, metricsContainerRegistry, @@ -185,6 +228,11 @@ public StreamingModeExecutionContext( this.stateCache = stateCache; this.backlogBytes = UnboundedReader.BACKLOG_UNKNOWN; this.throwExceptionOnLargeOutput = throwExceptionOnLargeOutput; + this.hotKeyLogger = checkNotNull(hotKeyLogger); + this.hotKeyLoggingEnabled = hotKeyLoggingEnabled; + this.stepName = checkNotNull(stepName); + this.sourceBytesProcessCounterName = checkNotNull(sourceBytesProcessCounterName); + this.sideInputStateFetcherFactory = sideInputStateFetcherFactory; StreamingGlobalConfig config = globalConfigHandle.getConfig(); this.operationalLimits = config.operationalLimits(); this.windmillTagEncoding = @@ -211,7 +259,7 @@ public boolean throwExceptionsForLargeOutput() { } public boolean workIsFailed() { - return work != null && work.isFailed(); + return workBatchFailed.get(); } public boolean getDrainMode() { @@ -243,50 +291,131 @@ public byte[] getCurrentRecordOffset() { return checkStateNotNull(activeReader).getCurrentRecordOffset(); } + /** Reset context before using it on a new bundle */ + public void reset() { + // these lists and maps are returned to callers after processing + // don't clear and reuse, instead reset the reference. + this.executedWorks = Collections.emptyList(); + this.outputBuilders = Collections.emptyList(); + this.finalizationCallbacks = Collections.emptyMap(); + // Work from prior bundles might have a reference to the old workBatchFailed. + // If the work gets retried it'll get the new workBatchFailed to notify failure. + this.workBatchFailed = new AtomicBoolean(false); + this.sideInputCache.clear(); + this.activeStateReader = null; + this.activeReader = null; + this.keyCoder = null; + this.workExecutor = null; + this.workQueueExecutor = null; + this.budgetHandle = null; + this.keyTransitionListener = null; + this.work = null; + this.key = null; + this.outputBuilder = null; + this.sideInputStateFetcher = null; + this.backlogBytes = UnboundedReader.BACKLOG_UNKNOWN; + clearSinkFullHint(); + this.stateBytesRead = 0; + } + public void start( - @Nullable Object key, Work work, WindmillStateReader stateReader, - SideInputStateFetcher sideInputStateFetcher, - Windmill.WorkItemCommitRequest.Builder outputBuilder, - WorkExecutor workExecutor) { - this.key = key; - this.work = work; + WorkExecutor workExecutor, + BoundedQueueExecutor workQueueExecutor, + BoundedQueueExecutorWorkHandle budgetHandle, + @Nullable Coder keyCoder, + KeyTransitionListener keyTransitionListener) + throws CoderException { + reset(); + this.executedWorks = new ArrayList<>(); + this.outputBuilders = new ArrayList<>(); + this.finalizationCallbacks = new HashMap<>(); + this.keyCoder = keyCoder; this.workExecutor = workExecutor; - this.finishKeyCalled = false; - this.computationKey = WindmillComputationKey.create(computationId, work.getShardedKey()); - this.sideInputStateFetcher = sideInputStateFetcher; + this.workQueueExecutor = workQueueExecutor; + this.budgetHandle = budgetHandle; + this.keyTransitionListener = keyTransitionListener; + StreamingGlobalConfig config = globalConfigHandle.getConfig(); // Snapshot the limits for entire bundle processing. this.operationalLimits = config.operationalLimits(); - this.outputBuilder = outputBuilder; - this.sideInputCache.clear(); - this.backlogBytes = UnboundedReader.BACKLOG_UNKNOWN; - clearSinkFullHint(); - Instant processingTime = computeProcessingTime(work.getWorkItem().getTimers().getTimersList()); + startForNewKey(work, stateReader); + } - Collection stepContexts = getAllStepContexts(); - if (!stepContexts.isEmpty()) { - // This must be only created once for the workItem as token validation will fail if the same - // work token is reused. - WindmillStateCache.ForKey cacheForKey = - stateCache.forKey(getComputationKey(), getWorkItem().getCacheToken(), getWorkToken()); - for (StepContext stepContext : stepContexts) { - stepContext.start(stateReader, processingTime, cacheForKey, work.watermarks()); + private @Nullable Object decodeKey(Work work) throws CoderException { + // If the read output KVs, then we can decode Windmill's byte key into userland + // key object and provide it to the execution context for use with per-key state. + // Otherwise, we pass null. + // + // The coder type that will be present is: + // WindowedValueCoder(TimerOrElementCoder(KvCoder)) + if (keyCoder != null) { + try { + return keyCoder.decode(work.getWorkItem().getKey().newInput(), Coder.Context.OUTER); + } catch (CoderException e) { + throw e; + } catch (IOException e) { + throw new RuntimeException("Failed to decode key during processing", e); } } + return null; + } + + private Windmill.WorkItemCommitRequest.Builder createOutputBuilder(Work work) { + return Windmill.WorkItemCommitRequest.newBuilder() + .setKey(work.getWorkItem().getKey()) + .setShardingKey(work.getWorkItem().getShardingKey()) + .setWorkToken(work.getWorkItem().getWorkToken()) + .setCacheToken(work.getWorkItem().getCacheToken()); + } + + private void logHotKeyIfDetected(Work work, @Nullable Object decodedKey) { + if (work.getWorkItem().hasHotKeyInfo()) { + Windmill.HotKeyInfo hotKeyInfo = work.getWorkItem().getHotKeyInfo(); + Duration hotKeyAge = Duration.millis(hotKeyInfo.getHotKeyAgeUsec() / 1000); + if (decodedKey != null && hotKeyLoggingEnabled) { + hotKeyLogger.logHotKeyDetection(stepName, hotKeyAge, decodedKey); + } else { + hotKeyLogger.logHotKeyDetection(stepName, hotKeyAge); + } + } + } + + private void startStepContexts( + WindmillStateReader stateReader, + Instant processingTime, + WindmillStateCache.ForKey cacheForKey, + Watermarks watermarks) { + for (StepContext stepContext : getAllStepContexts()) { + stepContext.start(stateReader, processingTime, cacheForKey, watermarks); + } } public void finishKey() { - checkState(!finishKeyCalled, "finishKey was already called"); - checkStateNotNull(workExecutor, "workExecutor must be set before calling finishKey()"); + WorkExecutor localExecutor = + checkStateNotNull(workExecutor, "workExecutor must be set before calling finishKey()"); + if (finishKeyCalled) { + return; + } + this.finishKeyCalled = true; + if (activeStateReader != null) { + this.stateBytesRead += activeStateReader.getBytesRead(); + } + if (sideInputStateFetcher != null) { + this.stateBytesRead += sideInputStateFetcher.getBytesRead(); + } try { - workExecutor.finishKey(key); + localExecutor.finishKey(key); } catch (Exception e) { throw new RuntimeException(e); } - this.finishKeyCalled = true; + } + + public void flushState() { + checkState(finishKeyCalled, "finishKey must be called before flushState"); + flushStateInternal(); } /** @@ -400,26 +529,6 @@ private List getFiredTimers() { return getWorkItem().getTimers().getTimersList(); } - public @Nullable ByteString getSerializedKey() { - return work == null ? null : work.getWorkItem().getKey(); - } - - public WindmillComputationKey getComputationKey() { - return checkStateNotNull(computationKey); - } - - public long getWorkToken() { - return getWorkItem().getWorkToken(); - } - - public Windmill.WorkItem getWorkItem() { - return checkStateNotNull( - work, - "work is null. A call to StreamingModeExecutionContext.start(...) is required to set" - + " work for execution.") - .getWorkItem(); - } - public Windmill.WorkItemCommitRequest.Builder getOutputBuilder() { return checkStateNotNull(outputBuilder); } @@ -440,20 +549,23 @@ public void setActiveReader(UnboundedReader reader) { /** Invalidate the state and reader caches for this computation and key. */ public void invalidateCache() { - ByteString key = getSerializedKey(); - if (key != null) { - readerCache.invalidateReader(getComputationKey()); - if (activeReader != null) { - try { - activeReader.close(); - } catch (IOException e) { - LOG.warn( - "Failed to close reader for {}-{}", computationId, getWorkItem().getShardingKey(), e); - } + for (Work w : executedWorks) { + WindmillComputationKey compKey = + WindmillComputationKey.create(computationId, w.getShardedKey()); + readerCache.invalidateReader(compKey); + stateCache.invalidate(w.getShardedKey()); + } + if (activeReader != null) { + try { + activeReader.close(); + } catch (IOException e) { + Windmill.WorkItem workItem = getWorkItem(); + long shardingKey = workItem != null ? workItem.getShardingKey() : -1L; + LOG.warn("Failed to close reader for {}-{}", computationId, shardingKey, e); } - activeReader = null; - stateCache.invalidate(key, getWorkItem().getShardingKey()); } + activeReader = null; + activeStateReader = null; } public UnboundedSource.@Nullable CheckpointMark getReaderCheckpoint( @@ -469,8 +581,7 @@ public void invalidateCache() { } } - public Map> flushState() { - checkState(finishKeyCalled, "finishKey must be called before flushState"); + private void flushStateInternal() { Map> callbacks = new HashMap<>(); for (StepContext stepContext : getAllStepContexts()) { @@ -553,7 +664,128 @@ public Map> flushState() { // RestrictionTracker.getProgress() or GetSize() are not defined. getOutputBuilder().setSourceBacklogBytes(backlogBytes); } - return callbacks; + + this.finalizationCallbacks.putAll(callbacks); + + getOutputBuilder() + .setSourceBytesProcessed(computeSourceBytesProcessed(sourceBytesProcessCounterName)); + } + + private final long computeSourceBytesProcessed(String sourceBytesCounterName) { + if (!(workExecutor instanceof DataflowMapTaskExecutor)) { + return 0L; + } + HashMap counters = + ((DataflowMapTaskExecutor) workExecutor) + .getReadOperation() + .receivers[0] + .getOutputCounters(); + + return Optional.ofNullable(counters.get(sourceBytesCounterName)) + .map(counter -> ((OutputObjectAndByteCounter) counter).getByteCount().getAndReset()) + .orElse(0L); + } + + public boolean advance() { + // TODO: get more work from workQueueExecutor and merge into the bundle here + return false; + } + + private void startForNewKey(Work newWork, WindmillStateReader reader) throws CoderException { + newWork.setState(Work.State.PROCESSING); + if (keyTransitionListener != null && this.work != null && this.work != newWork) { + keyTransitionListener.onKeyTransition(this.work, newWork); + } + this.key = decodeKey(newWork); + this.work = newWork; + this.finishKeyCalled = false; + this.computationKey = WindmillComputationKey.create(computationId, newWork.getShardedKey()); + + this.outputBuilder = createOutputBuilder(newWork); + this.outputBuilders.add(this.outputBuilder); + newWork.setOnFailureListener(this.workBatchFailed); + this.executedWorks.add(newWork); + + logHotKeyIfDetected(newWork, this.key); + + this.sideInputStateFetcher = + sideInputStateFetcherFactory.createSideInputStateFetcher(newWork::fetchSideInput); + this.backlogBytes = UnboundedReader.BACKLOG_UNKNOWN; + this.activeReader = null; + + // Note: We do NOT clear sideInputCache here, allowing Key B to reuse warm side inputs! + + // Re-initialize state cache and state/timer internals across all step contexts + Instant processingTime = + computeProcessingTime(newWork.getWorkItem().getTimers().getTimersList()); + if (getAllStepContexts().isEmpty()) { + checkState(this.activeStateReader == null); + } else { + // This must be only created once for a workItem as token validation will fail if the same + // work token is reused. + WindmillStateCache.ForKey cacheForKey = + stateCache.forKey( + getComputationKey(), newWork.getWorkItem().getCacheToken(), getWorkToken()); + this.activeStateReader = reader; + startStepContexts(reader, processingTime, cacheForKey, newWork.watermarks()); + } + } + + // Returns state bytes read during the bundle execution + public long getStateBytesRead() { + return stateBytesRead; + } + + // Returns list of commit requests from the bundle + public List getWorkItemCommits() { + List commits = new ArrayList<>(outputBuilders.size()); + for (Windmill.WorkItemCommitRequest.Builder builder : outputBuilders) { + commits.add(builder.build()); + } + return commits; + } + + // Returns list of Work that was executed in the bundle + public List getExecutedWorks() { + return executedWorks; + } + + // Returns finalization callbacks recorded during the bundle execution + public Map> getFinalizationCallbacks() { + return finalizationCallbacks; + } + + // Returns the current key being processed or null if an unkeyed stage. + public @Nullable Object getKey() { + return key; + } + + // Returns the current Work being processed. + public Work getWork() { + return checkStateNotNull( + work, + "work is null. A call to StreamingModeExecutionContext.start(...) is required to set" + + " work for execution."); + } + + // Returns the serialized windmill key for the current Work + public @Nullable ByteString getSerializedKey() { + return work == null ? null : work.getWorkItem().getKey(); + } + + // Returns the serialized windmill key for the current Work + public WindmillComputationKey getComputationKey() { + return checkStateNotNull(computationKey); + } + + // Returns the windmill work token for the current Work + public long getWorkToken() { + return getWorkItem().getWorkToken(); + } + + // Returns the windmill WorkItem proto for the current Work + public Windmill.WorkItem getWorkItem() { + return getWork().getWorkItem(); } @Nullable diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillReaderIteratorBase.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillReaderIteratorBase.java index 075a1a8a4250..134655a72a54 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillReaderIteratorBase.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillReaderIteratorBase.java @@ -35,9 +35,9 @@ public abstract class WindmillReaderIteratorBase extends NativeReader.NativeReaderIterator> { private final StreamingModeExecutionContext context; - private final Windmill.WorkItem work; - private int bundleIndex = 0; - private int messageIndex = -1; + private Windmill.WorkItem work; + private int bundleIndex; + private int messageIndex; private @Nullable WindowedValue current = null; private final ValueProvider skipUndecodableElements; private static final Logger LOG = LoggerFactory.getLogger(WindmillReaderIteratorBase.class); @@ -47,6 +47,8 @@ protected WindmillReaderIteratorBase( this.context = context; this.skipUndecodableElements = skipUndecodableElements; this.work = context.getWorkItem(); + this.bundleIndex = 0; + this.messageIndex = -1; } @Override @@ -57,15 +59,25 @@ public boolean start() throws IOException { @Override public boolean advance() throws IOException { if (context.workIsFailed()) { - throw new WorkItemCancelledException(context.getWorkItem().getShardingKey()); + throw new WorkItemCancelledException(checkNotNull(context.getWorkItem()).getShardingKey()); } while (true) { if (bundleIndex >= work.getMessageBundlesCount()) { - current = null; + // If elements are exhausted, try advancing the execution context to the next key in the + // group context.finishKey(); + if (context.advance()) { + // Transition succeeded! Update iterator references to the new work item + resetWorkFromContext(); + continue; + } + + // All work items are exhausted. + current = null; return false; } + Windmill.InputMessageBundle bundle = work.getMessageBundles(bundleIndex); ++messageIndex; if (messageIndex >= bundle.getMessagesCount()) { @@ -73,6 +85,7 @@ public boolean advance() throws IOException { ++bundleIndex; continue; } + try { current = checkNotNull(decodeMessage(bundle.getMessages(messageIndex))); return true; @@ -91,6 +104,12 @@ public boolean advance() throws IOException { } } + private void resetWorkFromContext() { + this.work = context.getWorkItem(); + this.bundleIndex = 0; + this.messageIndex = -1; + } + protected abstract WindowedValue decodeMessage(Windmill.Message message) throws IOException; @Override diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindowingWindmillReader.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindowingWindmillReader.java index 488684769bd9..c258e25146e4 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindowingWindmillReader.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindowingWindmillReader.java @@ -30,7 +30,6 @@ import org.apache.beam.runners.dataflow.util.CloudObject; import org.apache.beam.runners.dataflow.worker.util.ValueInEmptyWindows; import org.apache.beam.runners.dataflow.worker.util.common.worker.NativeReader; -import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItem; import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.options.PipelineOptions; @@ -49,7 +48,6 @@ @Internal class WindowingWindmillReader extends NativeReader>> { - private final Coder keyCoder; private final Coder valueCoder; private final Coder windowCoder; private final Coder> windowsCoder; @@ -66,7 +64,6 @@ class WindowingWindmillReader extends NativeReader keyedWorkItemCoder = (WindmillKeyedWorkItem.FakeKeyedWorkItemCoder) inputCoder.getValueCoder(); - this.keyCoder = keyedWorkItemCoder.getKeyCoder(); this.valueCoder = keyedWorkItemCoder.getElementCoder(); this.context = context; this.skipUndecodableElements = skipUndecodableElements; @@ -129,73 +126,77 @@ public static WindowingWindmillReader create( return new WindowingWindmillReader<>(coder, context, skipUndecodableElements); } + private KeyedWorkItem createKeyedWorkItem() { + @SuppressWarnings("unchecked") + @Nullable + K key = (K) context.getKey(); + return new WindmillKeyedWorkItem<>( + key, + context.getWorkItem(), + windowCoder, + windowsCoder, + valueCoder, + context.getWindmillTagEncoding(), + context.getDrainMode(), + skipUndecodableElements.isAccessible() + && Boolean.TRUE.equals(skipUndecodableElements.get())); + } + + private boolean isEmpty(KeyedWorkItem keyedWorkItem) { + return Iterables.isEmpty(keyedWorkItem.timersIterable()) + && Iterables.isEmpty(keyedWorkItem.elementsIterable()); + } + @Override public NativeReaderIterator>> iterator() throws IOException { - final K key = - keyCoder.decode( - checkStateNotNull(context.getSerializedKey()).newInput(), Coder.Context.OUTER); - final WorkItem workItem = context.getWorkItem(); - KeyedWorkItem keyedWorkItem = - new WindmillKeyedWorkItem<>( - key, - workItem, - windowCoder, - windowsCoder, - valueCoder, - context.getWindmillTagEncoding(), - context.getDrainMode(), - skipUndecodableElements.isAccessible() - && Boolean.TRUE.equals(skipUndecodableElements.get())); - final boolean isEmptyWorkItem = - (Iterables.isEmpty(keyedWorkItem.timersIterable()) - && Iterables.isEmpty(keyedWorkItem.elementsIterable())); - final WindowedValue> value = new ValueInEmptyWindows<>(keyedWorkItem); - - // Return a noop iterator when current workitem is an empty workitem. - if (isEmptyWorkItem) { - return new NativeReaderIterator>>() { - @Override - public boolean start() throws IOException { - context.finishKey(); - return false; - } - - @Override - public boolean advance() throws IOException { - return false; + return new NativeReaderIterator>>() { + private @Nullable WindowedValue> current = null; + + @Override + public boolean start() throws IOException { + if (context.workIsFailed()) { + throw new WorkItemCancelledException( + checkStateNotNull(context.getWorkItem()).getShardingKey()); } - - @Override - public WindowedValue> getCurrent() { - throw new NoSuchElementException(); + KeyedWorkItem firstKeyedWorkItem = createKeyedWorkItem(); + if (isEmpty(firstKeyedWorkItem)) { + return advance(); // Try to transition immediately if the first key is empty! } - }; - } else { - return new NativeReaderIterator>>() { - private @Nullable WindowedValue> current = null; - - @Override - public boolean start() throws IOException { - current = value; - return true; + current = new ValueInEmptyWindows<>(firstKeyedWorkItem); + return true; + } + + @Override + public boolean advance() throws IOException { + if (context.workIsFailed()) { + throw new WorkItemCancelledException( + checkStateNotNull(context.getWorkItem()).getShardingKey()); } - @Override - public boolean advance() throws IOException { - current = null; + while (true) { context.finishKey(); + if (context.advance()) { + KeyedWorkItem newKeyedWorkItem = createKeyedWorkItem(); + if (isEmpty(newKeyedWorkItem)) { + continue; + } + current = new ValueInEmptyWindows<>(newKeyedWorkItem); + return true; + } + + current = null; return false; } + } - @Override - public WindowedValue> getCurrent() { - if (current == null) { - throw new NoSuchElementException(); - } - return value; + @Override + public WindowedValue> getCurrent() { + if (current == null) { + throw new NoSuchElementException(); } - }; - } + return current; + } + }; } @Override diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationWorkExecutor.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationWorkExecutor.java index b4f3a22a7f52..b8ee42c8ef88 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationWorkExecutor.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationWorkExecutor.java @@ -18,21 +18,16 @@ package org.apache.beam.runners.dataflow.worker.streaming; import com.google.auto.value.AutoValue; -import java.util.HashMap; import java.util.Optional; import javax.annotation.concurrent.NotThreadSafe; import org.apache.beam.runners.core.metrics.ExecutionStateTracker; -import org.apache.beam.runners.dataflow.worker.DataflowMapTaskExecutor; import org.apache.beam.runners.dataflow.worker.DataflowWorkExecutor; import org.apache.beam.runners.dataflow.worker.StreamingModeExecutionContext; -import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputStateFetcher; -import org.apache.beam.runners.dataflow.worker.util.common.worker.ElementCounter; -import org.apache.beam.runners.dataflow.worker.util.common.worker.OutputObjectAndByteCounter; -import org.apache.beam.runners.dataflow.worker.windmill.Windmill; +import org.apache.beam.runners.dataflow.worker.StreamingModeExecutionContext.KeyTransitionListener; +import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor; import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateReader; import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.sdk.coders.Coder; -import org.checkerframework.checker.nullness.qual.Nullable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -58,7 +53,7 @@ public static ComputationWorkExecutor.Builder builder() { public abstract DataflowWorkExecutor workExecutor(); - public abstract StreamingModeExecutionContext context(); + abstract StreamingModeExecutionContext context(); public abstract Optional> keyCoder(); @@ -67,15 +62,24 @@ public static ComputationWorkExecutor.Builder builder() { /** * Executes DoFns for the Work. Blocks the calling thread until DoFn(s) have completed execution. */ - public final void executeWork( - @Nullable Object key, + public final StreamingModeExecutionContext executeWork( Work work, WindmillStateReader stateReader, - SideInputStateFetcher sideInputStateFetcher, - Windmill.WorkItemCommitRequest.Builder outputBuilder) + BoundedQueueExecutor workQueueExecutor, + BoundedQueueExecutorWorkHandle budgetHandle, + KeyTransitionListener keyTransitionListener) throws Exception { - context().start(key, work, stateReader, sideInputStateFetcher, outputBuilder, workExecutor()); + context() + .start( + work, + stateReader, + workExecutor(), + workQueueExecutor, + budgetHandle, + keyCoder().orElse(null), + keyTransitionListener); workExecutor().execute(); + return context(); } /** @@ -84,6 +88,7 @@ public final void executeWork( */ public final void invalidate() { context().invalidateCache(); + context().reset(); try { workExecutor().close(); } catch (Exception e) { @@ -91,18 +96,6 @@ public final void invalidate() { } } - public final long computeSourceBytesProcessed(String sourceBytesCounterName) { - HashMap counters = - ((DataflowMapTaskExecutor) workExecutor()) - .getReadOperation() - .receivers[0] - .getOutputCounters(); - - return Optional.ofNullable(counters.get(sourceBytesCounterName)) - .map(counter -> ((OutputObjectAndByteCounter) counter).getByteCount().getAndReset()) - .orElse(0L); - } - @AutoValue.Builder public abstract static class Builder { public abstract Builder setWorkExecutor(DataflowWorkExecutor workExecutor); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java index 53ed30fdedbb..252a16a38bc9 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java @@ -28,6 +28,8 @@ import java.util.Objects; import java.util.Optional; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; import java.util.function.Supplier; import javax.annotation.concurrent.NotThreadSafe; @@ -82,6 +84,8 @@ public final class Work implements RefreshableWork { private volatile TimedState currentState; private volatile boolean isFailed; private volatile String processingThreadName = ""; + private final AtomicReference<@Nullable AtomicBoolean> onFailureListener = + new AtomicReference<>(null); private final boolean drainMode; private Work( @@ -191,6 +195,10 @@ public long getSerializedWorkItemSize() { return serializedWorkItemSize; } + public String getComputationId() { + return processingContext.computationId(); + } + @Override public ShardedKey getShardedKey() { return shardedKey; @@ -244,6 +252,19 @@ public void setProcessingThreadName(String processingThreadName) { @Override public void setFailed() { this.isFailed = true; + AtomicBoolean listener = onFailureListener.get(); + if (listener != null) { + listener.set(true); + } + } + + // Sets the passed in boolean to true if the work fails + // Supports registering only one boolean at a time. + public void setOnFailureListener(@Nullable AtomicBoolean listener) { + onFailureListener.set(listener); + if (isFailed && listener != null) { + listener.set(true); + } } public boolean isCommitPending() { @@ -268,6 +289,10 @@ public void queueCommit(WorkItemCommitRequest commitRequest, ComputationState co processingContext.workCommitter().accept(Commit.create(commitRequest, computationState, this)); } + public Consumer workCommitter() { + return processingContext.workCommitter(); + } + public WindmillStateReader createWindmillStateReader() { return WindmillStateReader.forWork(this); } @@ -390,10 +415,6 @@ private boolean isCommitPending() { abstract Instant startTime(); } - public String getComputationId() { - return processingContext.computationId(); - } - public KeyGroup getKeyGroup() { return keyGroup; } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/ComputationWorkExecutorFactory.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/ComputationWorkExecutorFactory.java index 269799903300..4a52d9fde771 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/ComputationWorkExecutorFactory.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/ComputationWorkExecutorFactory.java @@ -29,6 +29,7 @@ import org.apache.beam.runners.dataflow.worker.DataflowExecutionStateSampler; import org.apache.beam.runners.dataflow.worker.DataflowMapTaskExecutor; import org.apache.beam.runners.dataflow.worker.DataflowMapTaskExecutorFactory; +import org.apache.beam.runners.dataflow.worker.HotKeyLogger; import org.apache.beam.runners.dataflow.worker.IntrinsicMapTaskExecutorFactory; import org.apache.beam.runners.dataflow.worker.ReaderCache; import org.apache.beam.runners.dataflow.worker.ReaderRegistry; @@ -48,6 +49,7 @@ import org.apache.beam.runners.dataflow.worker.streaming.ComputationWorkExecutor; import org.apache.beam.runners.dataflow.worker.streaming.StageInfo; import org.apache.beam.runners.dataflow.worker.streaming.config.StreamingGlobalConfigHandle; +import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputStateFetcherFactory; import org.apache.beam.runners.dataflow.worker.util.common.worker.MapTaskExecutor; import org.apache.beam.runners.dataflow.worker.util.common.worker.OutputObjectAndByteCounter; import org.apache.beam.runners.dataflow.worker.util.common.worker.ReadOperation; @@ -97,6 +99,8 @@ final class ComputationWorkExecutorFactory { private final IdGenerator idGenerator; private final StreamingGlobalConfigHandle globalConfigHandle; private final boolean throwExceptionOnLargeOutput; + private final HotKeyLogger hotKeyLogger; + private final SideInputStateFetcherFactory sideInputStateFetcherFactory; ComputationWorkExecutorFactory( DataflowWorkerHarnessOptions options, @@ -106,7 +110,9 @@ final class ComputationWorkExecutorFactory { DataflowExecutionStateSampler sampler, CounterSet pendingDeltaCounters, IdGenerator idGenerator, - StreamingGlobalConfigHandle globalConfigHandle) { + StreamingGlobalConfigHandle globalConfigHandle, + HotKeyLogger hotKeyLogger, + SideInputStateFetcherFactory sideInputStateFetcherFactory) { this.options = options; this.mapTaskExecutorFactory = mapTaskExecutorFactory; this.readerCache = readerCache; @@ -124,6 +130,8 @@ final class ComputationWorkExecutorFactory { : StreamingDataflowWorker.MAX_SINK_BYTES; this.throwExceptionOnLargeOutput = hasExperiment(options, THROW_EXCEPTIONS_ON_LARGE_OUTPUT_EXPERIMENT); + this.hotKeyLogger = hotKeyLogger; + this.sideInputStateFetcherFactory = sideInputStateFetcherFactory; } private static Nodes.ParallelInstructionNode extractReadNode( @@ -191,8 +199,12 @@ ComputationWorkExecutor createComputationWorkExecutor( DataflowExecutionContext.DataflowExecutionStateTracker executionStateTracker = createExecutionStateTracker(stageInfo, mapTask, workLatencyTrackingId); + boolean hotKeyLoggingEnabled = + options.isHotKeyLoggingEnabled() || hasExperiment(options, "enable_hot_key_logging"); + String stepName = getShuffleTaskStepName(mapTask); StreamingModeExecutionContext context = - createExecutionContext(computationState, stageInfo, executionStateTracker); + createExecutionContext( + computationState, stageInfo, executionStateTracker, hotKeyLoggingEnabled, stepName); DataflowMapTaskExecutor mapTaskExecutor = createMapTaskExecutor(context, mapTask, mapTaskNetwork); ReadOperation readOperation = getValidatedReadOperation(mapTaskExecutor); @@ -255,7 +267,9 @@ ComputationWorkExecutor createComputationWorkExecutor( private StreamingModeExecutionContext createExecutionContext( ComputationState computationState, StageInfo stageInfo, - DataflowExecutionContext.DataflowExecutionStateTracker executionStateTracker) { + DataflowExecutionContext.DataflowExecutionStateTracker executionStateTracker, + boolean hotKeyLoggingEnabled, + String stepName) { String computationId = computationState.getComputationId(); return new StreamingModeExecutionContext( pendingDeltaCounters, @@ -268,7 +282,12 @@ private StreamingModeExecutionContext createExecutionContext( stageInfo.executionStateRegistry(), globalConfigHandle, maxSinkBytes, - throwExceptionOnLargeOutput); + throwExceptionOnLargeOutput, + hotKeyLogger, + hotKeyLoggingEnabled, + stepName, + computationState.sourceBytesProcessCounterName(), + sideInputStateFetcherFactory); } private DataflowMapTaskExecutor createMapTaskExecutor( @@ -286,6 +305,12 @@ private DataflowMapTaskExecutor createMapTaskExecutor( idGenerator); } + private static String getShuffleTaskStepName(MapTask mapTask) { + // The MapTask instruction is ordered by dependencies, such that the first element is + // always going to be the shuffle task. + return mapTask.getInstructions().get(0).getName(); + } + private DataflowExecutionContext.DataflowExecutionStateTracker createExecutionStateTracker( StageInfo stageInfo, MapTask mapTask, String workLatencyTrackingId) { return new DataflowExecutionContext.DataflowExecutionStateTracker( diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java index a3f23aebdf8f..e9b85d720d2b 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java @@ -17,22 +17,26 @@ */ package org.apache.beam.runners.dataflow.worker.windmill.work.processing; -import static org.apache.beam.sdk.options.ExperimentalOptions.hasExperiment; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState; import com.google.api.services.dataflow.model.MapTask; import com.google.auto.value.AutoValue; -import java.util.Optional; +import java.util.List; +import java.util.Map; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; import java.util.function.Function; import java.util.function.Supplier; import javax.annotation.concurrent.ThreadSafe; +import org.apache.beam.repackaged.core.org.apache.commons.lang3.tuple.Pair; import org.apache.beam.runners.dataflow.options.DataflowWorkerHarnessOptions; import org.apache.beam.runners.dataflow.worker.DataflowExecutionStateSampler; import org.apache.beam.runners.dataflow.worker.DataflowMapTaskExecutorFactory; import org.apache.beam.runners.dataflow.worker.HotKeyLogger; import org.apache.beam.runners.dataflow.worker.ReaderCache; +import org.apache.beam.runners.dataflow.worker.StreamingModeExecutionContext; +import org.apache.beam.runners.dataflow.worker.StreamingModeExecutionContext.KeyTransitionListener; import org.apache.beam.runners.dataflow.worker.WorkItemCancelledException; import org.apache.beam.runners.dataflow.worker.logging.DataflowWorkerLoggingMDC; import org.apache.beam.runners.dataflow.worker.streaming.BoundedQueueExecutorWorkHandle; @@ -45,7 +49,6 @@ import org.apache.beam.runners.dataflow.worker.streaming.Work; import org.apache.beam.runners.dataflow.worker.streaming.config.StreamingGlobalConfigHandle; import org.apache.beam.runners.dataflow.worker.streaming.harness.StreamingCounters; -import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputStateFetcher; import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputStateFetcherFactory; import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor; import org.apache.beam.runners.dataflow.worker.util.ExceptionUtils; @@ -57,12 +60,10 @@ import org.apache.beam.runners.dataflow.worker.windmill.work.processing.failures.FailureTracker; import org.apache.beam.runners.dataflow.worker.windmill.work.processing.failures.WorkFailureProcessor; import org.apache.beam.sdk.annotations.Internal; -import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.fn.IdGenerator; import org.apache.beam.vendor.grpc.v1p69p0.com.google.protobuf.ByteString; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.checkerframework.checker.nullness.qual.Nullable; -import org.joda.time.Duration; import org.joda.time.Instant; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -78,41 +79,35 @@ public class StreamingWorkScheduler { private static final Logger LOG = LoggerFactory.getLogger(StreamingWorkScheduler.class); - private final DataflowWorkerHarnessOptions options; private final Supplier clock; private final ComputationWorkExecutorFactory computationWorkExecutorFactory; - private final SideInputStateFetcherFactory sideInputStateFetcherFactory; private final FailureTracker failureTracker; private final WorkFailureProcessor workFailureProcessor; private final StreamingCommitFinalizer commitFinalizer; private final StreamingCounters streamingCounters; - private final HotKeyLogger hotKeyLogger; private final ConcurrentMap stageInfoMap; private final DataflowExecutionStateSampler sampler; private final StreamingGlobalConfigHandle globalConfigHandle; + private final BoundedQueueExecutor workExecutor; public StreamingWorkScheduler( - DataflowWorkerHarnessOptions options, Supplier clock, + BoundedQueueExecutor workExecutor, ComputationWorkExecutorFactory computationWorkExecutorFactory, - SideInputStateFetcherFactory sideInputStateFetcherFactory, FailureTracker failureTracker, WorkFailureProcessor workFailureProcessor, StreamingCommitFinalizer commitFinalizer, StreamingCounters streamingCounters, - HotKeyLogger hotKeyLogger, ConcurrentMap stageInfoMap, DataflowExecutionStateSampler sampler, StreamingGlobalConfigHandle globalConfigHandle) { - this.options = options; this.clock = clock; + this.workExecutor = workExecutor; this.computationWorkExecutorFactory = computationWorkExecutorFactory; - this.sideInputStateFetcherFactory = sideInputStateFetcherFactory; this.failureTracker = failureTracker; this.workFailureProcessor = workFailureProcessor; this.commitFinalizer = commitFinalizer; this.streamingCounters = streamingCounters; - this.hotKeyLogger = hotKeyLogger; this.stageInfoMap = stageInfoMap; this.sampler = sampler; this.globalConfigHandle = globalConfigHandle; @@ -134,6 +129,9 @@ public static StreamingWorkScheduler create( IdGenerator idGenerator, StreamingGlobalConfigHandle globalConfigHandle, ConcurrentMap stageInfoMap) { + SideInputStateFetcherFactory sideInputStateFetcherFactory = + SideInputStateFetcherFactory.fromOptions(options); + ComputationWorkExecutorFactory computationWorkExecutorFactory = new ComputationWorkExecutorFactory( options, @@ -143,18 +141,18 @@ public static StreamingWorkScheduler create( sampler, streamingCounters.pendingDeltaCounters(), idGenerator, - globalConfigHandle); + globalConfigHandle, + hotKeyLogger, + sideInputStateFetcherFactory); return new StreamingWorkScheduler( - options, clock, + workExecutor, computationWorkExecutorFactory, - SideInputStateFetcherFactory.fromOptions(options), failureTracker, workFailureProcessor, StreamingCommitFinalizer.create(workExecutor, commitFinalizerCleanupExecutor), streamingCounters, - hotKeyLogger, stageInfoMap, sampler, globalConfigHandle); @@ -187,21 +185,16 @@ private static Windmill.WorkItemCommitRequest buildWorkItemTruncationRequest( /** Sets the stage name and workId of the Thread executing the {@link Work} for logging. */ private static void setUpWorkLoggingContext(String workLatencyTrackingId, String computationId) { - DataflowWorkerLoggingMDC.setWorkId(workLatencyTrackingId); - DataflowWorkerLoggingMDC.setStageName(computationId); + setLoggingContextWorkId(workLatencyTrackingId); + setLoggingContextComputation(computationId); } - private static String getShuffleTaskStepName(MapTask mapTask) { - // The MapTask instruction is ordered by dependencies, such that the first element is - // always going to be the shuffle task. - return mapTask.getInstructions().get(0).getName(); + private static void setLoggingContextComputation(@Nullable String computationId) { + DataflowWorkerLoggingMDC.setStageName(computationId); } - /** Resets logging context of the Thread executing the {@link Work} for logging. */ - private void resetWorkLoggingContext(String workLatencyTrackingId) { - sampler.resetForWorkId(workLatencyTrackingId); - DataflowWorkerLoggingMDC.setWorkId(null); - DataflowWorkerLoggingMDC.setStageName(null); + private static void setLoggingContextWorkId(@Nullable String workLatencyTrackingId) { + DataflowWorkerLoggingMDC.setWorkId(workLatencyTrackingId); } /** @@ -248,46 +241,39 @@ private void processWork( } private void processWork( - ComputationState computationState, Work work, BoundedQueueExecutorWorkHandle unusedHandle) { + ComputationState computationState, Work work, BoundedQueueExecutorWorkHandle handle) { Windmill.WorkItem workItem = work.getWorkItem(); String computationId = computationState.getComputationId(); - ByteString key = workItem.getKey(); work.setProcessingThreadName(Thread.currentThread().getName()); work.setState(Work.State.PROCESSING); setUpWorkLoggingContext(work.getLatencyTrackingId(), computationId); LOG.debug("Starting processing for {}:\n{}", computationId, work); if (workItem.getSourceState().getOnlyFinalize()) { - Windmill.WorkItemCommitRequest.Builder outputBuilder = initializeOutputBuilder(key, workItem); - outputBuilder.setSourceStateUpdates(Windmill.SourceState.newBuilder().setOnlyFinalize(true)); - work.setState(Work.State.COMMIT_QUEUED); - work.queueCommit(outputBuilder.build(), computationState); + handleOnlyFinalize(computationState, work, workItem); return; } long processingStartTimeNanos = System.nanoTime(); - MapTask mapTask = computationState.getMapTask(); - StageInfo stageInfo = - stageInfoMap.computeIfAbsent( - mapTask.getStageName(), s -> StageInfo.create(s, mapTask.getSystemName())); + StageInfo stageInfo = getStageInfo(computationState); + @Nullable List workBatch = null; try { if (work.isFailed()) { throw new WorkItemCancelledException(workItem.getShardingKey()); } - // Execute the user code for the Work. - ExecuteWorkResult executeWorkResult = executeWork(work, stageInfo, computationState); - Windmill.WorkItemCommitRequest.Builder commitRequest = executeWorkResult.commitWorkRequest(); + // Execute the user code for the Work batch. + ExecuteWorkResult executeWorkResult = executeWork(work, stageInfo, computationState, handle); + workBatch = executeWorkResult.workBatch(); + List workItemCommits = executeWorkResult.workItemCommits(); + + commitFinalizer.cacheCommitFinalizers(executeWorkResult.finalizationCallbacks()); - // Validate the commit request, possibly requesting truncation if the commitSize is too large. - Windmill.WorkItemCommitRequest validatedCommitRequest = - validateCommitRequestSize(commitRequest.build(), computationId, workItem); + commitWorkBatch(computationState, workBatch, workItemCommits); - // Queue the commit. - work.queueCommit(validatedCommitRequest, computationState); - recordProcessingStats(commitRequest, workItem, executeWorkResult); - LOG.debug("Processing done for work token: {}", workItem.getWorkToken()); + recordProcessingStats(workBatch, workItemCommits, executeWorkResult.stateBytesRead()); + LOG.debug("Processing done for work batch size: {}", workBatch.size()); } catch (Throwable t) { // OutOfMemoryError that are caught will be rethrown and trigger jvm termination. try { @@ -305,23 +291,17 @@ private void processWork( throw ExceptionUtils.safeWrapThrowableAsException(t2); } } finally { + List processedWorkBatch = workBatch != null ? workBatch : ImmutableList.of(work); // Update total processing time counters. Updating in finally clause ensures that // work items causing exceptions are also accounted in time spent. - long processingTimeMsecs = - TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - processingStartTimeNanos); - stageInfo.totalProcessingMsecs().addValue(processingTimeMsecs); + recordProcessingTime(stageInfo, processedWorkBatch, processingStartTimeNanos); - // Attribute all the processing to timers if the work item contains any timers. - // Tests show that work items rarely contain both timers and message bundles. It should - // be a fairly close approximation. - // Another option: Derive time split between messages and timers based on recent totals. - // either here or in DFE. - if (work.getWorkItem().hasTimers()) { - stageInfo.timerProcessingMsecs().addValue(processingTimeMsecs); + setLoggingContextWorkId(null); + setLoggingContextComputation(null); + sampler.resetForWorkId(work.getLatencyTrackingId()); + for (Work w : processedWorkBatch) { + w.setProcessingThreadName(""); } - - resetWorkLoggingContext(work.getLatencyTrackingId()); - work.setProcessingThreadName(""); } } @@ -353,27 +333,37 @@ private Windmill.WorkItemCommitRequest validateCommitRequestSize( } private void recordProcessingStats( - Windmill.WorkItemCommitRequest.Builder outputBuilder, - Windmill.WorkItem workItem, - ExecuteWorkResult executeWorkResult) { - // Compute shuffle and state byte statistics these will be flushed asynchronously. - long stateBytesWritten = - outputBuilder - .clearOutputMessages() - .clearPerWorkItemLatencyAttributions() - .build() - .getSerializedSize(); - - streamingCounters.windmillShuffleBytesRead().addValue(computeShuffleBytesRead(workItem)); - streamingCounters.windmillStateBytesRead().addValue(executeWorkResult.stateBytesRead()); - streamingCounters.windmillStateBytesWritten().addValue(stateBytesWritten); + List workBatch, + List workItemCommits, + long totalStateBytesRead) { + long totalStateBytesWritten = 0; + long totalShuffleBytesRead = 0; + checkState(workBatch.size() == workItemCommits.size()); + for (int i = 0; i < workBatch.size(); i++) { + Windmill.WorkItem workItem = workBatch.get(i).getWorkItem(); + Windmill.WorkItemCommitRequest commit = workItemCommits.get(i); + // Compute shuffle and state byte statistics these will be flushed asynchronously. + long stateBytesWritten = + commit + .toBuilder() + .clearOutputMessages() + .clearPerWorkItemLatencyAttributions() + .build() + .getSerializedSize(); + totalStateBytesWritten += stateBytesWritten; + totalShuffleBytesRead += computeShuffleBytesRead(workItem); + } + streamingCounters.windmillShuffleBytesRead().addValue(totalShuffleBytesRead); + streamingCounters.windmillStateBytesRead().addValue(totalStateBytesRead); + streamingCounters.windmillStateBytesWritten().addValue(totalStateBytesWritten); } private ExecuteWorkResult executeWork( - Work work, StageInfo stageInfo, ComputationState computationState) throws Exception { - Windmill.WorkItem workItem = work.getWorkItem(); - ByteString key = workItem.getKey(); - Windmill.WorkItemCommitRequest.Builder outputBuilder = initializeOutputBuilder(key, workItem); + Work work, + StageInfo stageInfo, + ComputationState computationState, + BoundedQueueExecutorWorkHandle handle) + throws Exception { ComputationWorkExecutor computationWorkExecutor = computationState .acquireComputationWorkExecutor() @@ -384,89 +374,137 @@ private ExecuteWorkResult executeWork( try { WindmillStateReader stateReader = work.createWindmillStateReader(); - SideInputStateFetcher localSideInputStateFetcher = - sideInputStateFetcherFactory.createSideInputStateFetcher(work::fetchSideInput); - - // If the read output KVs, then we can decode Windmill's byte key into userland - // key object and provide it to the execution context for use with per-key state. - // Otherwise, we pass null. - // - // The coder type that will be present is: - // WindowedValueCoder(TimerOrElementCoder(KvCoder)) - Optional> keyCoder = computationWorkExecutor.keyCoder(); - @SuppressWarnings("deprecation") - @Nullable - final Object executionKey = - !keyCoder.isPresent() ? null : keyCoder.get().decode(key.newInput(), Coder.Context.OUTER); - - if (workItem.hasHotKeyInfo()) { - Windmill.HotKeyInfo hotKeyInfo = workItem.getHotKeyInfo(); - Duration hotKeyAge = Duration.millis(hotKeyInfo.getHotKeyAgeUsec() / 1000); - - String stepName = getShuffleTaskStepName(computationState.getMapTask()); - if (executionKey != null - && (options.isHotKeyLoggingEnabled() - || hasExperiment(options, "enable_hot_key_logging")) - && keyCoder.isPresent()) { - hotKeyLogger.logHotKeyDetection(stepName, hotKeyAge, executionKey); - } else { - hotKeyLogger.logHotKeyDetection(stepName, hotKeyAge); + + KeyTransitionListener keyTransitionListener = createKeyTransitionListener(); + + List workBatch; + List workItemCommits; + Map> finalizationCallbacks; + long stateBytesRead; + { + // Blocks while executing work. + StreamingModeExecutionContext context = + computationWorkExecutor.executeWork( + work, stateReader, workExecutor, handle, keyTransitionListener); + if (context.workIsFailed()) { + throw new WorkItemCancelledException(work.getWorkItem().getShardingKey()); } - } + context.flushState(); - // Blocks while executing work. - computationWorkExecutor.executeWork( - executionKey, work, stateReader, localSideInputStateFetcher, outputBuilder); + // Retrieve executed works, work item commits, and accumulated callbacks from execution + // context + workBatch = context.getExecutedWorks(); + workItemCommits = context.getWorkItemCommits(); + finalizationCallbacks = context.getFinalizationCallbacks(); + stateBytesRead = context.getStateBytesRead(); - if (work.isFailed()) { - throw new WorkItemCancelledException(workItem.getShardingKey()); - } - - // Reports source bytes processed to WorkItemCommitRequest if available. - try { - long sourceBytesProcessed = - computationWorkExecutor.computeSourceBytesProcessed( - computationState.sourceBytesProcessCounterName()); - outputBuilder.setSourceBytesProcessed(sourceBytesProcessed); - } catch (Exception e) { - LOG.error("{}", e.toString()); + context.reset(); // Don't use context after this. } - - commitFinalizer.cacheCommitFinalizers(computationWorkExecutor.context().flushState()); - // Release the execution state for another thread to use. computationState.releaseComputationWorkExecutor(computationWorkExecutor); computationWorkExecutor = null; - work.setState(Work.State.COMMIT_QUEUED); - outputBuilder.addAllPerWorkItemLatencyAttributions(work.getLatencyAttributions(sampler)); - return ExecuteWorkResult.create( - outputBuilder, stateReader.getBytesRead() + localSideInputStateFetcher.getBytesRead()); + workBatch, workItemCommits, finalizationCallbacks, stateBytesRead); } catch (Throwable t) { if (computationWorkExecutor != null) { // If processing failed due to a thrown exception, close the executionState. Do not // return/release the executionState back to computationState as that will lead to this // executionState instance being reused. - LOG.debug("Invalidating executor after work item {} failed", workItem.getWorkToken(), t); + LOG.debug( + "Invalidating executor after work item {} failed", + work.getWorkItem().getWorkToken(), + t); computationWorkExecutor.invalidate(); } - // Re-throw the exception, it will be caught and handled by workFailureProcessor downstream. throw t; } } + private void handleOnlyFinalize( + ComputationState computationState, Work work, Windmill.WorkItem workItem) { + Windmill.WorkItemCommitRequest.Builder outputBuilder = + initializeOutputBuilder(workItem.getKey(), workItem); + outputBuilder.setSourceStateUpdates(Windmill.SourceState.newBuilder().setOnlyFinalize(true)); + work.setState(Work.State.COMMIT_QUEUED); + work.queueCommit(outputBuilder.build(), computationState); + } + + private StageInfo getStageInfo(ComputationState computationState) { + MapTask mapTask = computationState.getMapTask(); + return stageInfoMap.computeIfAbsent( + mapTask.getStageName(), s -> StageInfo.create(s, mapTask.getSystemName())); + } + + private void commitWorkBatch( + ComputationState computationState, + List workBatch, + List workItemCommits) { + checkState(workBatch.size() == 1, "Expected single-key work batch, got: " + workBatch.size()); + checkState(workBatch.size() == workItemCommits.size()); + commitSingleKeyWork(computationState, workBatch.get(0), workItemCommits.get(0)); + } + + private void commitSingleKeyWork( + ComputationState computationState, Work work, Windmill.WorkItemCommitRequest commitRequest) { + // Validate the commit request, possibly requesting truncation if the commitSize is too large. + Windmill.WorkItemCommitRequest validatedCommitRequest = + validateCommitRequestSize( + commitRequest, computationState.getComputationId(), work.getWorkItem()); + work.setState(Work.State.COMMIT_QUEUED); + validatedCommitRequest = + validatedCommitRequest + .toBuilder() + .addAllPerWorkItemLatencyAttributions(work.getLatencyAttributions(sampler)) + .build(); + work.queueCommit(validatedCommitRequest, computationState); + } + + private void recordProcessingTime( + StageInfo stageInfo, List worksToCleanup, long processingStartTimeNanos) { + long processingTimeMsecs = + TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - processingStartTimeNanos); + stageInfo.totalProcessingMsecs().addValue(processingTimeMsecs); + if (anyWorkHasTimers(worksToCleanup)) { + // Attribute all the processing to timers if the work item contains any timers. + // Tests show that work items rarely contain both timers and message bundles. It should + // be a fairly close approximation. + // Another option: Derive time split between messages and timers based on recent totals. + // either here or in DFE. + stageInfo.timerProcessingMsecs().addValue(processingTimeMsecs); + } + } + + private static boolean anyWorkHasTimers(List works) { + return works.stream().anyMatch(w -> w.getWorkItem().hasTimers()); + } + + private KeyTransitionListener createKeyTransitionListener() { + return (oldWork, newWork) -> { + setLoggingContextWorkId(newWork.getLatencyTrackingId()); + newWork.setProcessingThreadName(oldWork.getProcessingThreadName()); + oldWork.setProcessingThreadName(""); + }; + } + @AutoValue abstract static class ExecuteWorkResult { - - private static ExecuteWorkResult create( - Windmill.WorkItemCommitRequest.Builder commitWorkRequest, long stateBytesRead) { + static ExecuteWorkResult create( + List workBatch, + List workItemCommits, + Map> finalizationCallbacks, + long stateBytesRead) { return new AutoValue_StreamingWorkScheduler_ExecuteWorkResult( - commitWorkRequest, stateBytesRead); + workBatch, workItemCommits, finalizationCallbacks, stateBytesRead); } - abstract Windmill.WorkItemCommitRequest.Builder commitWorkRequest(); + abstract List workBatch(); + + abstract List workItemCommits(); + + // Map> + abstract Map> finalizationCallbacks(); abstract long stateBytesRead(); } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java index 5bcdffcc2564..23730bc57705 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java @@ -420,6 +420,7 @@ private ParallelInstruction makeWindowingSourceInstruction(Coder coder) { CloudObjects.asCloudObject(IntervalWindowCoder.of(), /* sdkComponents= */ null))); return new ParallelInstruction() + .setName(DEFAULT_SOURCE_SYSTEM_NAME) .setSystemName(DEFAULT_SOURCE_SYSTEM_NAME) .setOriginalName(DEFAULT_SOURCE_ORIGINAL_NAME) .setRead( @@ -439,6 +440,7 @@ private ParallelInstruction makeWindowingSourceInstruction(Coder coder) { private ParallelInstruction makeSourceInstruction(Coder coder) { return new ParallelInstruction() + .setName(DEFAULT_SOURCE_SYSTEM_NAME) .setSystemName(DEFAULT_SOURCE_SYSTEM_NAME) .setOriginalName(DEFAULT_SOURCE_ORIGINAL_NAME) .setRead( @@ -527,6 +529,7 @@ private ParallelInstruction makeSinkInstruction( CloudObject spec = CloudObject.forClass(WindmillSink.class); addString(spec, "stream_id", streamId); return new ParallelInstruction() + .setName(streamId) .setSystemName(DEFAULT_SINK_SYSTEM_NAME) .setOriginalName(DEFAULT_SINK_ORIGINAL_NAME) .setWrite( @@ -571,11 +574,16 @@ private Windmill.GetWorkResponse buildInput(String input, byte[] metadata) throw Windmill.GetWorkResponse.Builder builder = Windmill.GetWorkResponse.newBuilder(); TextFormat.merge(input, builder); if (metadata != null) { - Windmill.InputMessageBundle.Builder messageBundleBuilder = - builder.getWorkBuilder(0).getWorkBuilder(0).getMessageBundlesBuilder(0); - for (Windmill.Message.Builder messageBuilder : - messageBundleBuilder.getMessagesBuilderList()) { - messageBuilder.setMetadata(addPaneTag(PaneInfo.NO_FIRING, metadata)); + for (Windmill.ComputationWorkItems.Builder compBuilder : builder.getWorkBuilderList()) { + for (Windmill.WorkItem.Builder workBuilder : compBuilder.getWorkBuilderList()) { + for (Windmill.InputMessageBundle.Builder messageBundleBuilder : + workBuilder.getMessageBundlesBuilderList()) { + for (Windmill.Message.Builder messageBuilder : + messageBundleBuilder.getMessagesBuilderList()) { + messageBuilder.setMetadata(addPaneTag(PaneInfo.NO_FIRING, metadata)); + } + } + } } } @@ -1327,7 +1335,7 @@ public void testKeyCommitTooLargeException() throws Exception { makeExpectedTruncationRequestOutput( 1, "large_key", DEFAULT_SHARDING_KEY, largeCommit.getEstimatedWorkItemCommitBytes()) .build(), - largeCommit); + removeDynamicFields(largeCommit)); // Check this explicitly since the estimated commit bytes weren't actually // checked against an expected value in the previous step @@ -2495,6 +2503,7 @@ private List makeUnboundedSourcePipeline( return Arrays.asList( new ParallelInstruction() + .setName("Read") .setSystemName("Read") .setOriginalName("OriginalReadName") .setRead( @@ -3954,11 +3963,16 @@ public void testDoFnLatencyBreakdownsReportedOnCommit() throws Exception { LatencyAttribution.newBuilder().setState(State.ACTIVE).setTotalDurationMillis(100); for (LatencyAttribution la : commit.getPerWorkItemLatencyAttributionsList()) { if (la.getState() == State.ACTIVE) { - assertThat(la.getActiveLatencyBreakdownCount(), equalTo(1)); - assertThat( - la.getActiveLatencyBreakdown(0).getUserStepName(), equalTo(DEFAULT_PARDO_USER_NAME)); - Assert.assertTrue(la.getActiveLatencyBreakdown(0).hasProcessingTimesDistribution()); - Assert.assertFalse(la.getActiveLatencyBreakdown(0).hasActiveMessageMetadata()); + LatencyAttribution.ActiveLatencyBreakdown pardoBreakdown = null; + for (LatencyAttribution.ActiveLatencyBreakdown lb : la.getActiveLatencyBreakdownList()) { + if (DEFAULT_PARDO_USER_NAME.equals(lb.getUserStepName())) { + pardoBreakdown = lb; + break; + } + } + Assert.assertNotNull("Expected breakdown for " + DEFAULT_PARDO_USER_NAME, pardoBreakdown); + Assert.assertTrue(pardoBreakdown.hasProcessingTimesDistribution()); + Assert.assertFalse(pardoBreakdown.hasActiveMessageMetadata()); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java index 056185c587f3..561596f68d0f 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java @@ -63,7 +63,7 @@ import org.apache.beam.runners.dataflow.worker.streaming.config.FakeGlobalConfigHandle; import org.apache.beam.runners.dataflow.worker.streaming.config.StreamingGlobalConfig; import org.apache.beam.runners.dataflow.worker.streaming.config.StreamingGlobalConfigHandle; -import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputStateFetcher; +import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputStateFetcherFactory; import org.apache.beam.runners.dataflow.worker.util.common.worker.WorkExecutor; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.FakeGetDataClient; @@ -73,6 +73,8 @@ import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillTagEncodingV2; import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender; import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.CoderException; import org.apache.beam.sdk.metrics.MetricsContainer; import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.state.TimeDomain; @@ -100,7 +102,6 @@ public class StreamingModeExecutionContextTest { @Rule public transient Timeout globalTimeout = Timeout.seconds(600); - @Mock private SideInputStateFetcher sideInputStateFetcher; @Mock private WindmillStateReader stateReader; @Mock private WorkExecutor workExecutor; @@ -137,7 +138,12 @@ private StreamingModeExecutionContext createExecutionContext( executionStateRegistry, configHandle, Long.MAX_VALUE, - /*throwExceptionOnLargeOutput=*/ false); + /*throwExceptionOnLargeOutput=*/ false, + new HotKeyLogger(), + /*hotKeyLoggingEnabled=*/ false, + /*stepName=*/ "stepName", + "sourceBytesProcessCounterName", + SideInputStateFetcherFactory.fromOptions(options)); } @Before @@ -159,25 +165,45 @@ COMPUTATION_ID, new FakeGetDataClient(), ignored -> {}, mock(HeartbeatSender.cla Instant::now); } + private void start(Work work) { + start(executionContext, work, null); + } + + private void start(Work work, Coder keyCoder) { + start(executionContext, work, keyCoder); + } + + private void start(StreamingModeExecutionContext context, Work work) { + start(context, work, null); + } + + private void start(StreamingModeExecutionContext context, Work work, Coder keyCoder) { + try { + context.start( + work, + stateReader, + workExecutor, + /* workQueueExecutor= */ null, + /* budgetHandle= */ null, + keyCoder, + /* keyTransitionListener= */ (k, c) -> {}); + } catch (CoderException e) { + throw new RuntimeException(e); + } + } + @Test public void testTimerInternalsSetTimer() throws Exception { - Windmill.WorkItemCommitRequest.Builder outputBuilder = - Windmill.WorkItemCommitRequest.newBuilder(); NameContext nameContext = NameContextsForTests.nameContextForTest(); DataflowOperationContext operationContext = executionContext.createOperationContext(nameContext); StreamingModeExecutionContext.StepContext stepContext = executionContext.getStepContext(operationContext); - executionContext.start( - "key", + start( createMockWork( Windmill.WorkItem.newBuilder().setKey(ByteString.EMPTY).setWorkToken(17L).build(), - Watermarks.builder().setInputDataWatermark(new Instant(1000)).build()), - stateReader, - sideInputStateFetcher, - outputBuilder, - workExecutor); + Watermarks.builder().setInputDataWatermark(new Instant(1000)).build())); TimerInternals timerInternals = stepContext.timerInternals(); @@ -191,6 +217,7 @@ public void testTimerInternalsSetTimer() throws Exception { executionContext.finishKey(); executionContext.flushState(); + Windmill.WorkItemCommitRequest.Builder outputBuilder = executionContext.getOutputBuilder(); Windmill.Timer timer = outputBuilder.buildPartial().getOutputTimers(0); assertThat(timer.getTag().toStringUtf8(), equalTo("/skey+0:5000")); assertThat(timer.getTimestamp(), equalTo(TimeUnit.MILLISECONDS.toMicros(5000))); @@ -199,9 +226,6 @@ public void testTimerInternalsSetTimer() throws Exception { @Test public void testTimerInternalsProcessingTimeSkew() { - Windmill.WorkItemCommitRequest.Builder outputBuilder = - Windmill.WorkItemCommitRequest.newBuilder(); - NameContext nameContext = NameContextsForTests.nameContextForTest(); DataflowOperationContext operationContext = executionContext.createOperationContext(nameContext); @@ -221,15 +245,10 @@ public void testTimerInternalsProcessingTimeSkew() { .setTimestamp(timerTimestamp.getMillis() * 1000) .setType(Windmill.Timer.Type.REALTIME); - executionContext.start( - "key", + start( createMockWork( workItemBuilder.build(), - Watermarks.builder().setInputDataWatermark(new Instant(1000)).build()), - stateReader, - sideInputStateFetcher, - outputBuilder, - workExecutor); + Watermarks.builder().setInputDataWatermark(new Instant(1000)).build())); TimerInternals timerInternals = stepContext.timerInternals(); assertTrue(timerTimestamp.isBefore(timerInternals.currentProcessingTime())); } @@ -437,50 +456,65 @@ public void testStateTagEncodingBasedOnConfig() { @Test public void testSetBacklogBytes() { - Windmill.WorkItemCommitRequest.Builder outputBuilder = - Windmill.WorkItemCommitRequest.newBuilder(); NameContext nameContext = NameContextsForTests.nameContextForTest(); DataflowOperationContext operationContext = executionContext.createOperationContext(nameContext); StreamingModeExecutionContext.StepContext stepContext = executionContext.getStepContext(operationContext); - executionContext.start( - "key", + start( createMockWork( Windmill.WorkItem.newBuilder().setKey(ByteString.EMPTY).setWorkToken(17L).build(), - Watermarks.builder().setInputDataWatermark(new Instant(1000)).build()), - stateReader, - sideInputStateFetcher, - outputBuilder, - workExecutor); + Watermarks.builder().setInputDataWatermark(new Instant(1000)).build())); stepContext.setBacklogBytes(1234.0); executionContext.finishKey(); executionContext.flushState(); - assertEquals(1234, outputBuilder.getSourceBacklogBytes()); + assertEquals(1234, executionContext.getOutputBuilder().getSourceBacklogBytes()); + } + + @Test + public void testFinishKeyReentrantSafety() { + start( + createMockWork( + Windmill.WorkItem.newBuilder().setKey(ByteString.EMPTY).setWorkToken(17L).build(), + Watermarks.builder().setInputDataWatermark(new Instant(1000)).build())); + + // First call + executionContext.finishKey(); + // Second call - should not throw any Exception + executionContext.finishKey(); + } + + @Test + public void testStart_internalKeyDecoding() throws Exception { + Windmill.WorkItem workItem = + Windmill.WorkItem.newBuilder() + .setKey(ByteString.copyFromUtf8("decodedKey")) + .setWorkToken(17L) + .build(); + Work work = + createMockWork( + workItem, Watermarks.builder().setInputDataWatermark(new Instant(1000)).build()); + + start(work, org.apache.beam.sdk.coders.StringUtf8Coder.of()); + + assertEquals("decodedKey", executionContext.getKey()); } @Test public void testInternalsPoisonedAfterFlushState() throws Exception { - Windmill.WorkItemCommitRequest.Builder outputBuilder = - Windmill.WorkItemCommitRequest.newBuilder(); NameContext nameContext = NameContextsForTests.nameContextForTest(); DataflowOperationContext operationContext = executionContext.createOperationContext(nameContext); StreamingModeExecutionContext.StepContext stepContext = executionContext.getStepContext(operationContext); - executionContext.start( - "key", + start( createMockWork( Windmill.WorkItem.newBuilder().setKey(ByteString.EMPTY).setWorkToken(17L).build(), - Watermarks.builder().setInputDataWatermark(new Instant(1000)).build()), - stateReader, - sideInputStateFetcher, - outputBuilder, - workExecutor); + Watermarks.builder().setInputDataWatermark(new Instant(1000)).build())); TimerInternals timerInternals = stepContext.timerInternals(); StateInternals stateInternals = stepContext.stateInternals(); diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindmillReaderIteratorBaseTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindmillReaderIteratorBaseTest.java index 539c38eeb1da..b45e0de6447c 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindmillReaderIteratorBaseTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindmillReaderIteratorBaseTest.java @@ -30,6 +30,7 @@ import java.util.Arrays; import java.util.List; import java.util.concurrent.ThreadLocalRandom; +import org.apache.beam.runners.dataflow.worker.streaming.Work; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; import org.apache.beam.sdk.coders.CoderException; import org.apache.beam.sdk.options.ValueProvider; @@ -122,6 +123,7 @@ public void testFinishKeyCalled() throws Exception { .build()) .build(); when(mockContext.getWorkItem()).thenReturn(workItem); + when(mockContext.advance()).thenReturn(false); try (TestWindmillReaderIterator iter = new TestWindmillReaderIterator(mockContext, ValueProvider.StaticValueProvider.of(false))) { @@ -131,6 +133,76 @@ public void testFinishKeyCalled() throws Exception { } } + @Test + public void testAdvanceKeyChaining() throws Exception { + StreamingModeExecutionContext mockContext = mock(StreamingModeExecutionContext.class); + when(mockContext.workIsFailed()).thenReturn(false); + + // Work item A (1 message) + Windmill.WorkItem workItemA = + Windmill.WorkItem.newBuilder() + .setKey(ByteString.copyFromUtf8("keyA")) + .setWorkToken(100L) + .addMessageBundles( + Windmill.InputMessageBundle.newBuilder() + .setSourceComputationId("foo") + .addMessages( + Windmill.Message.newBuilder() + .setTimestamp(1000) + .setData(ByteString.EMPTY) + .build()) + .build()) + .build(); + when(mockContext.getWorkItem()).thenReturn(workItemA); + + // Work item B (1 message) + Windmill.WorkItem workItemB = + Windmill.WorkItem.newBuilder() + .setKey(ByteString.copyFromUtf8("keyB")) + .setWorkToken(200L) + .addMessageBundles( + Windmill.InputMessageBundle.newBuilder() + .setSourceComputationId("foo") + .addMessages( + Windmill.Message.newBuilder() + .setTimestamp(2000) + .setData(ByteString.EMPTY) + .build()) + .build()) + .build(); + + // Set up context.advance() to mock transition + when(mockContext.advance()) + .thenAnswer( + new org.mockito.stubbing.Answer() { + private int count = 0; + + @Override + public Boolean answer(org.mockito.invocation.InvocationOnMock invocation) { + if (count == 0) { + count++; + when(mockContext.getWorkItem()).thenReturn(workItemB); + return true; + } + return false; + } + }); + + try (TestWindmillReaderIterator iter = + new TestWindmillReaderIterator(mockContext, ValueProvider.StaticValueProvider.of(false))) { + assertTrue(iter.start()); + assertEquals(1000L, iter.getCurrent().getValue().longValue()); + + // Advance should trigger context.advance(), transition to workItemB, and decode message from + // workItemB (timestamp 2000) + assertTrue(iter.advance()); + assertEquals(2000L, iter.getCurrent().getValue().longValue()); + + // Next advance should exhaust it and return false + assertFalse(iter.advance()); + } + } + private void testForMessageBundleCounts(int... messageBundleCounts) throws IOException { testForMessageBundleCounts(false, messageBundleCounts); } @@ -179,4 +251,24 @@ private void testForMessageBundleCounts(boolean skipErrors, int... messageBundle assertEquals(Arrays.toString(messageBundleCounts) + skipErrors, expected, actual); } } + + private static Work createMockWork(Windmill.WorkItem workItem) { + return Work.create( + workItem, + workItem.getSerializedSize(), + org.apache.beam.runners.dataflow.worker.streaming.Watermarks.builder() + .setInputDataWatermark(new org.joda.time.Instant(1000)) + .build(), + Work.createProcessingContext( + "computationId", + mock( + org.apache.beam.runners.dataflow.worker.windmill.client.getdata.GetDataClient + .class), + ignored -> {}, + mock( + org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender + .class)), + false, + org.joda.time.Instant::now); + } } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindowingWindmillReaderTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindowingWindmillReaderTest.java new file mode 100644 index 000000000000..2e7c80330cf0 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindowingWindmillReaderTest.java @@ -0,0 +1,275 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.io.IOException; +import java.util.List; +import org.apache.beam.runners.core.KeyedWorkItem; +import org.apache.beam.runners.dataflow.worker.streaming.Watermarks; +import org.apache.beam.runners.dataflow.worker.streaming.Work; +import org.apache.beam.runners.dataflow.worker.util.common.worker.NativeReader; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill; +import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.FakeGetDataClient; +import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillTagEncodingV1; +import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.coders.ListCoder; +import org.apache.beam.sdk.coders.StringUtf8Coder; +import org.apache.beam.sdk.coders.VarLongCoder; +import org.apache.beam.sdk.options.ValueProvider; +import org.apache.beam.sdk.transforms.windowing.IntervalWindow; +import org.apache.beam.sdk.transforms.windowing.IntervalWindow.IntervalWindowCoder; +import org.apache.beam.sdk.transforms.windowing.PaneInfo; +import org.apache.beam.sdk.transforms.windowing.PaneInfo.PaneInfoCoder; +import org.apache.beam.sdk.util.ByteStringOutputStream; +import org.apache.beam.sdk.values.WindowedValue; +import org.apache.beam.sdk.values.WindowedValues.FullWindowedValueCoder; +import org.apache.beam.vendor.grpc.v1p69p0.com.google.protobuf.ByteString; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; +import org.joda.time.Instant; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class WindowingWindmillReaderTest { + private StreamingModeExecutionContext mockContext; + private WindowingWindmillReader reader; + + @SuppressWarnings("unchecked") + @Before + public void setUp() { + mockContext = mock(StreamingModeExecutionContext.class); + when(mockContext.workIsFailed()).thenReturn(false); + when(mockContext.getWindmillTagEncoding()).thenReturn(WindmillTagEncodingV1.instance()); + when(mockContext.getDrainMode()).thenReturn(false); + + Coder keyCoder = StringUtf8Coder.of(); + Coder valueCoder = VarLongCoder.of(); + KvCoder kvCoder = KvCoder.of(keyCoder, valueCoder); + WindmillKeyedWorkItem.FakeKeyedWorkItemCoder keyedWorkItemCoder = + (WindmillKeyedWorkItem.FakeKeyedWorkItemCoder) + WindmillKeyedWorkItem.FakeKeyedWorkItemCoder.of(kvCoder); + FullWindowedValueCoder> coder = + FullWindowedValueCoder.of(keyedWorkItemCoder, IntervalWindowCoder.of()); + + reader = + WindowingWindmillReader.create( + coder, mockContext, ValueProvider.StaticValueProvider.of(false)); + } + + private static Work createMockWork(Windmill.WorkItem workItem) { + return Work.create( + workItem, + workItem.getSerializedSize(), + Watermarks.builder().setInputDataWatermark(new Instant(1000)).build(), + Work.createProcessingContext( + "computationId", new FakeGetDataClient(), ignored -> {}, mock(HeartbeatSender.class)), + false, + Instant::now); + } + + private static ByteString encodeMetadata(List windows) throws IOException { + ByteStringOutputStream stream = new ByteStringOutputStream(); + PaneInfoCoder.INSTANCE.encode(PaneInfo.NO_FIRING, stream); + ListCoder.of(IntervalWindowCoder.of()).encode(windows, stream); + return stream.toByteString(); + } + + private static ByteString encodeValue(long value) throws IOException { + ByteStringOutputStream stream = new ByteStringOutputStream(); + VarLongCoder.of().encode(value, stream); + return stream.toByteString(); + } + + @Test + public void testSingleNonEmptyKey() throws IOException { + IntervalWindow window = new IntervalWindow(new Instant(0), new Instant(1000)); + Windmill.WorkItem workItem = + Windmill.WorkItem.newBuilder() + .setKey(ByteString.copyFromUtf8("key1")) + .setWorkToken(100L) + .addMessageBundles( + Windmill.InputMessageBundle.newBuilder() + .setSourceComputationId("foo") + .addMessages( + Windmill.Message.newBuilder() + .setTimestamp(1000) + .setData(encodeValue(42L)) + .setMetadata(encodeMetadata(ImmutableList.of(window))) + .build()) + .build()) + .build(); + Work work = createMockWork(workItem); + + when(mockContext.getKey()).thenReturn("key1"); + when(mockContext.getWorkItem()).thenReturn(workItem); + when(mockContext.getWork()).thenReturn(work); + when(mockContext.advance()).thenReturn(false); + + try (NativeReader.NativeReaderIterator>> iter = + reader.iterator()) { + assertTrue(iter.start()); + WindowedValue> current = iter.getCurrent(); + assertEquals("key1", current.getValue().key()); + assertFalse(Iterables.isEmpty(current.getValue().elementsIterable())); + WindowedValue elem = Iterables.getOnlyElement(current.getValue().elementsIterable()); + assertEquals(42L, elem.getValue().longValue()); + + assertFalse(iter.advance()); + verify(mockContext).finishKey(); + } + } + + @Test + public void testSingleEmptyKey() throws IOException { + Windmill.WorkItem workItem = + Windmill.WorkItem.newBuilder() + .setKey(ByteString.copyFromUtf8("key1")) + .setWorkToken(100L) + .build(); // No message bundles or timers + Work work = createMockWork(workItem); + + when(mockContext.getKey()).thenReturn("key1"); + when(mockContext.getWorkItem()).thenReturn(workItem); + when(mockContext.getWork()).thenReturn(work); + when(mockContext.advance()).thenReturn(false); + + try (NativeReader.NativeReaderIterator>> iter = + reader.iterator()) { + assertFalse( + iter.start()); // Should skip the empty key and return false because advance returns false + verify(mockContext).finishKey(); + } + } + + @Test + public void testMultipleKeys_withEmptyAndNonEmpty() throws IOException { + IntervalWindow window = new IntervalWindow(new Instant(0), new Instant(1000)); + // Key 1: Empty + Windmill.WorkItem workItem1 = + Windmill.WorkItem.newBuilder() + .setKey(ByteString.copyFromUtf8("key1")) + .setWorkToken(100L) + .build(); + Work work1 = createMockWork(workItem1); + + // Key 2: Non-empty + Windmill.WorkItem workItem2 = + Windmill.WorkItem.newBuilder() + .setKey(ByteString.copyFromUtf8("key2")) + .setWorkToken(200L) + .addMessageBundles( + Windmill.InputMessageBundle.newBuilder() + .setSourceComputationId("foo") + .addMessages( + Windmill.Message.newBuilder() + .setTimestamp(2000) + .setData(encodeValue(84L)) + .setMetadata(encodeMetadata(ImmutableList.of(window))) + .build()) + .build()) + .build(); + Work work2 = createMockWork(workItem2); + + // Key 3: Empty + Windmill.WorkItem workItem3 = + Windmill.WorkItem.newBuilder() + .setKey(ByteString.copyFromUtf8("key3")) + .setWorkToken(300L) + .build(); + Work work3 = createMockWork(workItem3); + + // Initial state + when(mockContext.getKey()).thenReturn("key1"); + when(mockContext.getWorkItem()).thenReturn(workItem1); + when(mockContext.getWork()).thenReturn(work1); + + // Mock transition behaviour of context.advance() + when(mockContext.advance()) + .thenAnswer( + new org.mockito.stubbing.Answer() { + private int count = 0; + + @Override + public Boolean answer(org.mockito.invocation.InvocationOnMock invocation) { + if (count == 0) { + count++; + when(mockContext.getKey()).thenReturn("key2"); + when(mockContext.getWorkItem()).thenReturn(workItem2); + when(mockContext.getWork()).thenReturn(work2); + return true; + } else if (count == 1) { + count++; + when(mockContext.getKey()).thenReturn("key3"); + when(mockContext.getWorkItem()).thenReturn(workItem3); + when(mockContext.getWork()).thenReturn(work3); + return true; + } + return false; + } + }); + + try (NativeReader.NativeReaderIterator>> iter = + reader.iterator()) { + // Key 1 is empty, so start() calls advance() which calls finishKey(1) and advance() to Key 2. + // Key 2 is non-empty, so start() returns true yielding Key 2. + assertTrue(iter.start()); + assertEquals("key2", iter.getCurrent().getValue().key()); + WindowedValue elem = + Iterables.getOnlyElement(iter.getCurrent().getValue().elementsIterable()); + assertEquals(84L, elem.getValue().longValue()); + + // Next advance() calls finishKey(2), calls advance() to Key 3. + // Key 3 is empty, so it loops, calls finishKey(3), calls advance() which returns false. + // So iter.advance() should return false. + assertFalse(iter.advance()); + + verify(mockContext, times(3)) + .finishKey(); // finishKey should have been called on key1, key2, key3 + } + } + + @Test + public void testWorkItemCancelled() throws IOException { + when(mockContext.workIsFailed()).thenReturn(true); + Windmill.WorkItem workItem = + Windmill.WorkItem.newBuilder().setKey(ByteString.EMPTY).setWorkToken(0L).build(); + when(mockContext.getWorkItem()).thenReturn(workItem); + + try (NativeReader.NativeReaderIterator>> iter = + reader.iterator()) { + iter.start(); + fail("Expected WorkItemCancelledException"); + } catch (WorkItemCancelledException e) { + // Expected + } + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java index d5cf2948d928..31ea1bab07af 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java @@ -83,6 +83,7 @@ import org.apache.beam.runners.dataflow.util.CloudObject; import org.apache.beam.runners.dataflow.util.PropertyNames; import org.apache.beam.runners.dataflow.worker.DataflowExecutionContext.DataflowExecutionStateTracker; +import org.apache.beam.runners.dataflow.worker.StreamingModeExecutionContext.KeyTransitionListener; import org.apache.beam.runners.dataflow.worker.StreamingModeExecutionContext.StreamingModeExecutionStateRegistry; import org.apache.beam.runners.dataflow.worker.WorkerCustomSources.SplittableOnlyBoundedSource; import org.apache.beam.runners.dataflow.worker.counters.CounterSet; @@ -93,7 +94,7 @@ import org.apache.beam.runners.dataflow.worker.streaming.config.FixedGlobalConfigHandle; import org.apache.beam.runners.dataflow.worker.streaming.config.StreamingGlobalConfig; import org.apache.beam.runners.dataflow.worker.streaming.config.StreamingGlobalConfigHandle; -import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputStateFetcher; +import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputStateFetcherFactory; import org.apache.beam.runners.dataflow.worker.testing.TestCountingSource; import org.apache.beam.runners.dataflow.worker.util.common.worker.NativeReader; import org.apache.beam.runners.dataflow.worker.util.common.worker.NativeReader.NativeReaderIterator; @@ -106,6 +107,7 @@ import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.coders.BigEndianIntegerCoder; import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.CoderException; import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.coders.VarIntCoder; import org.apache.beam.sdk.extensions.gcp.auth.TestCredential; @@ -209,6 +211,21 @@ COMPUTATION_ID, new FakeGetDataClient(), ignored -> {}, mock(HeartbeatSender.cla Instant::now); } + private void startContext(StreamingModeExecutionContext context, Work work) { + try { + context.start( + work, + mock(WindmillStateReader.class), + mock(WorkExecutor.class), + /* workQueueExecutor= */ null, + /* budgetHandle= */ null, + /* keyCoder= */ null, + /* keyTransitionListener= */ mock(KeyTransitionListener.class)); + } catch (CoderException e) { + throw new RuntimeException(e); + } + } + private static class SourceProducingSubSourcesInSplit extends MockSource { int numDesiredBundle; int sourceObjectSize; @@ -620,7 +637,12 @@ public void testReadUnboundedReader() throws Exception { executionStateRegistry, globalConfigHandle, Long.MAX_VALUE, - /*throwExceptionOnLargeOutput=*/ false); + /*throwExceptionOnLargeOutput=*/ false, + new HotKeyLogger(), + /*hotKeyLoggingEnabled=*/ false, + /*stepName=*/ "stepName", + "sourceBytesProcessCounterName", + SideInputStateFetcherFactory.fromOptions(options)); options.setNumWorkers(5); int maxElements = 10; @@ -631,8 +653,8 @@ public void testReadUnboundedReader() throws Exception { for (int i = 0; i < 10 * maxElements; /* Incremented in inner loop */ ) { // Initialize streaming context with state from previous iteration. - context.start( - "key", + startContext( + context, createMockWork( Windmill.WorkItem.newBuilder() .setKey(ByteString.copyFromUtf8("0000000000000001")) // key is zero-padded index. @@ -641,11 +663,7 @@ public void testReadUnboundedReader() throws Exception { .setSourceState( Windmill.SourceState.newBuilder().setState(state).build()) // Source state. .build(), - Watermarks.builder().setInputDataWatermark(new Instant(0)).build()), - mock(WindmillStateReader.class), - mock(SideInputStateFetcher.class), - Windmill.WorkItemCommitRequest.newBuilder(), - mock(WorkExecutor.class)); + Watermarks.builder().setInputDataWatermark(new Instant(0)).build())); @SuppressWarnings({"unchecked", "rawtypes"}) NativeReader>>> reader = @@ -992,7 +1010,12 @@ public void testFailedWorkItemsAbort() throws Exception { executionStateRegistry, globalConfigHandle, Long.MAX_VALUE, - /*throwExceptionOnLargeOutput=*/ false); + /*throwExceptionOnLargeOutput=*/ false, + new HotKeyLogger(), + /*hotKeyLoggingEnabled=*/ false, + /*stepName=*/ "stepName", + "sourceBytesProcessCounterName", + SideInputStateFetcherFactory.fromOptions(options)); options.setNumWorkers(5); int maxElements = 100; @@ -1020,13 +1043,7 @@ public void testFailedWorkItemsAbort() throws Exception { mock(HeartbeatSender.class)), false, Instant::now); - context.start( - "key", - dummyWork, - mock(WindmillStateReader.class), - mock(SideInputStateFetcher.class), - Windmill.WorkItemCommitRequest.newBuilder(), - mock(WorkExecutor.class)); + startContext(context, dummyWork); @SuppressWarnings({"unchecked", "rawtypes"}) NativeReader>>> reader = diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/WorkTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/WorkTest.java new file mode 100644 index 000000000000..80ca91da462f --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/WorkTest.java @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.streaming; + +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.mock; + +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill; +import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender; +import org.apache.beam.vendor.grpc.v1p69p0.com.google.protobuf.ByteString; +import org.joda.time.Instant; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class WorkTest { + + private static Work createTestWork() { + Windmill.WorkItem workItem = + Windmill.WorkItem.newBuilder() + .setKey(ByteString.copyFromUtf8("key")) + .setWorkToken(1L) + .setShardingKey(2L) + .build(); + return Work.create( + workItem, + workItem.getSerializedSize(), + Watermarks.builder().setInputDataWatermark(Instant.now()).build(), + Work.createProcessingContext( + "comp", + mock( + org.apache.beam.runners.dataflow.worker.windmill.client.getdata.GetDataClient + .class), + commit -> {}, + mock(HeartbeatSender.class)), + false, + Instant::now); + } + + @Test + public void testSetFailedBeforeListener() { + Work work = createTestWork(); + assertFalse(work.isFailed()); + + work.setFailed(); + assertTrue(work.isFailed()); + + AtomicBoolean listener = new AtomicBoolean(false); + work.setOnFailureListener(listener); + assertTrue(listener.get()); + } + + @Test + public void testSetFailedAfterListener() { + Work work = createTestWork(); + AtomicBoolean listener = new AtomicBoolean(false); + work.setOnFailureListener(listener); + assertFalse(listener.get()); + assertFalse(work.isFailed()); + + work.setFailed(); + assertTrue(work.isFailed()); + assertTrue(listener.get()); + } + + @Test + public void testConcurrentSetFailedAndSetOnFailureListener() throws Exception { + int numTrials = 5000; + ExecutorService executor = Executors.newFixedThreadPool(2); + try { + for (int i = 0; i < numTrials; i++) { + Work work = createTestWork(); + AtomicBoolean listener = new AtomicBoolean(false); + CountDownLatch latch = new CountDownLatch(1); + + Future f1 = + executor.submit( + () -> { + try { + latch.await(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + work.setFailed(); + }); + + Future f2 = + executor.submit( + () -> { + try { + latch.await(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + work.setOnFailureListener(listener); + }); + + latch.countDown(); + f1.get(5, TimeUnit.SECONDS); + f2.get(5, TimeUnit.SECONDS); + + assertTrue("Trial " + i + " failed: work should be failed", work.isFailed()); + assertTrue("Trial " + i + " failed: listener should be set to true", listener.get()); + } + } finally { + executor.shutdownNow(); + } + } +}