diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/DataflowExecutionContext.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/DataflowExecutionContext.java index 6ff05b4b4452..888e954c1c9f 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/DataflowExecutionContext.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/DataflowExecutionContext.java @@ -150,6 +150,10 @@ boolean isSinkFullHintSet() { // the state size might grow unbounded. } + protected final long getBytesSinked() { + return bytesSinked; + } + /** * Sets a flag to indicate that a sink has enough data written to it. This hint is read by * upstream producers to stop producing if they can. Mainly used in streaming. diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java index 180dda153bb6..2329c718cbc4 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java @@ -181,7 +181,7 @@ public final class StreamingDataflowWorker { "windmill_bounded_queue_executor_use_fair_monitor"; // Don't use. Experiment guarding multi key bundles. The feature is work in progress and // incomplete. 
- private static final String UNSTABLE_ENABLE_MULTI_KEY_BUNDLE = "unstable_enable_multi_key_bundle"; + public static final String UNSTABLE_ENABLE_MULTI_KEY_BUNDLE = "unstable_enable_multi_key_bundle"; private final WindmillStateCache stateCache; private AtomicReference statusPages = new AtomicReference<>(); @@ -257,6 +257,7 @@ private StreamingDataflowWorker( this.streamingWorkScheduler = StreamingWorkScheduler.create( options, + DataflowRunner.hasExperiment(options, UNSTABLE_ENABLE_MULTI_KEY_BUNDLE), clock, readerCache, mapTaskExecutorFactory, @@ -1206,9 +1207,14 @@ private void onCompleteCommit(CompleteCommit completeCommit) { computationStateCache .getIfPresent(completeCommit.computationId()) .ifPresent( - state -> + state -> { + if (completeCommit.retryableFailure()) { + state.reExecuteActiveWork(completeCommit.shardedKey(), completeCommit.workId()); + } else { state.completeWorkAndScheduleNextWorkForKey( - completeCommit.shardedKey(), completeCommit.workId())); + completeCommit.shardedKey(), completeCommit.workId()); + } + }); } @AutoValue diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java index 9401cc5f8ed9..d0cd5c823601 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java @@ -35,6 +35,7 @@ import java.util.Optional; import java.util.Set; import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicLong; import javax.annotation.concurrent.NotThreadSafe; @@ -52,6 +53,7 @@ import 
org.apache.beam.runners.dataflow.worker.counters.NameContext; import org.apache.beam.runners.dataflow.worker.profiler.ScopedProfiler.ProfileScope; import org.apache.beam.runners.dataflow.worker.streaming.BoundedQueueExecutorWorkHandle; +import org.apache.beam.runners.dataflow.worker.streaming.ExecutableWork; import org.apache.beam.runners.dataflow.worker.streaming.Watermarks; import org.apache.beam.runners.dataflow.worker.streaming.Work; import org.apache.beam.runners.dataflow.worker.streaming.config.StreamingGlobalConfig; @@ -83,6 +85,8 @@ import org.apache.beam.sdk.io.UnboundedSource; import org.apache.beam.sdk.io.UnboundedSource.UnboundedReader; import org.apache.beam.sdk.metrics.MetricsContainer; +import org.apache.beam.sdk.options.ExperimentalOptions; +import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.state.TimeDomain; import org.apache.beam.sdk.transforms.DoFn.BundleFinalizer; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; @@ -121,6 +125,12 @@ public class StreamingModeExecutionContext extends DataflowExecutionContext { private static final Logger LOG = LoggerFactory.getLogger(StreamingModeExecutionContext.class); + private static final String WINDMILL_MAX_KEY_GROUP_BATCH_SIZE = + "windmill_max_key_group_batch_size"; + private static final String WINDMILL_MAX_KEY_GROUP_BATCH_TIME_MS = + "windmill_max_key_group_batch_time_ms"; + private static final String WINDMILL_MAX_KEY_GROUP_BATCH_SINK_BYTES = + "windmill_max_key_group_batch_sink_bytes"; private final String computationId; private final ImmutableMap stateNameMap; @@ -181,7 +191,7 @@ public class StreamingModeExecutionContext // Key switch listener to delegate MDC logging context and thread name updates public interface KeyTransitionListener { - void onKeyTransition(Work oldWork, Work newWork); + void onKeyTransition(@Nullable Work oldWork, Work newWork); } @SuppressWarnings("UnusedVariable") @@ -197,6 +207,13 @@ public interface KeyTransitionListener { private 
long stateBytesRead = 0; private final String sourceBytesProcessCounterName; + private final int maxKeyGroupBatchSize; + private final long maxKeyGroupBatchTimeNanos; + private final boolean multiKeyBundleEnabled; + private final long maxKeyGroupBatchSinkBytes; + private int workItemsPolled = 0; + private long bundleStartTimeNanos = 0; + public StreamingModeExecutionContext( CounterFactory counterFactory, String computationId, @@ -213,6 +230,7 @@ public StreamingModeExecutionContext( boolean hotKeyLoggingEnabled, String stepName, String sourceBytesProcessCounterName, + PipelineOptions options, SideInputStateFetcherFactory sideInputStateFetcherFactory) { super( counterFactory, @@ -232,7 +250,33 @@ public StreamingModeExecutionContext( this.hotKeyLoggingEnabled = hotKeyLoggingEnabled; this.stepName = checkNotNull(stepName); this.sourceBytesProcessCounterName = checkNotNull(sourceBytesProcessCounterName); - this.sideInputStateFetcherFactory = sideInputStateFetcherFactory; + this.sideInputStateFetcherFactory = checkNotNull(sideInputStateFetcherFactory); + + // Initialize batch limits from pipeline options + this.maxKeyGroupBatchSize = + tryParseInt( + ExperimentalOptions.getExperimentValue(options, WINDMILL_MAX_KEY_GROUP_BATCH_SIZE), + 100, + WINDMILL_MAX_KEY_GROUP_BATCH_SIZE); + + long batchTimeMs = + tryParseLong( + ExperimentalOptions.getExperimentValue(options, WINDMILL_MAX_KEY_GROUP_BATCH_TIME_MS), + 100, + WINDMILL_MAX_KEY_GROUP_BATCH_TIME_MS); + this.maxKeyGroupBatchTimeNanos = TimeUnit.MILLISECONDS.toNanos(batchTimeMs); + + this.multiKeyBundleEnabled = + ExperimentalOptions.hasExperiment( + options, StreamingDataflowWorker.UNSTABLE_ENABLE_MULTI_KEY_BUNDLE); + + this.maxKeyGroupBatchSinkBytes = + tryParseLong( + ExperimentalOptions.getExperimentValue( + options, WINDMILL_MAX_KEY_GROUP_BATCH_SINK_BYTES), + StreamingDataflowWorker.MAX_SINK_BYTES, + WINDMILL_MAX_KEY_GROUP_BATCH_SINK_BYTES); + StreamingGlobalConfig config = globalConfigHandle.getConfig(); 
this.operationalLimits = config.operationalLimits(); this.windmillTagEncoding = @@ -241,6 +285,41 @@ public StreamingModeExecutionContext( : WindmillTagEncodingV1.instance(); } + private static int tryParseInt(@Nullable String value, int defaultValue, String experimentName) { + if (value == null) { + return defaultValue; + } + try { + return Integer.parseInt(value); + } catch (NumberFormatException e) { + LOG.warn( + "Failed to parse experiment {} value '{}' as integer, falling back to default: {}", + experimentName, + value, + defaultValue, + e); + return defaultValue; + } + } + + private static long tryParseLong( + @Nullable String value, long defaultValue, String experimentName) { + if (value == null) { + return defaultValue; + } + try { + return Long.parseLong(value); + } catch (NumberFormatException e) { + LOG.warn( + "Failed to parse experiment {} value '{}' as long, falling back to default: {}", + experimentName, + value, + defaultValue, + e); + return defaultValue; + } + } + @VisibleForTesting public final long getBacklogBytes() { return backlogBytes; @@ -337,6 +416,9 @@ public void start( this.budgetHandle = budgetHandle; this.keyTransitionListener = keyTransitionListener; + this.workItemsPolled = 1; + this.bundleStartTimeNanos = System.nanoTime(); + StreamingGlobalConfig config = globalConfigHandle.getConfig(); // Snapshot the limits for entire bundle processing. 
this.operationalLimits = config.operationalLimits(); @@ -687,10 +769,47 @@ private final long computeSourceBytesProcessed(String sourceBytesCounterName) { } public boolean advance() { - // TODO: get more work from workQueueExecutor and merge into the bundle here + if (!multiKeyBundleEnabled) { + return false; + } + if (workIsFailed()) { + throw new WorkItemCancelledException(checkStateNotNull(work).getWorkItem().getShardingKey()); + } + + BoundedQueueExecutor executor = checkStateNotNull(workQueueExecutor); + BoundedQueueExecutorWorkHandle handle = checkStateNotNull(budgetHandle); + Work activeWork = checkStateNotNull(work); + + if (activeWork.getKeyGroup().equals(Work.KeyGroup.DEFAULT) || shouldStopBatching()) { + return false; + } + + @Nullable + ExecutableWork additionalWork = + executor.pollWork(computationId, activeWork.getKeyGroup(), handle); + if (additionalWork != null) { + flushStateInternal(); + Work newWork = additionalWork.work(); + ++workItemsPolled; + checkStateNotNull(keyTransitionListener).onKeyTransition(activeWork, newWork); + startForNewKey(newWork); + return true; + } + return false; } + private boolean shouldStopBatching() { + if (workItemsPolled >= maxKeyGroupBatchSize) { + return true; + } + long elapsedNanos = System.nanoTime() - bundleStartTimeNanos; + if (elapsedNanos >= maxKeyGroupBatchTimeNanos) { + return true; + } + return getBytesSinked() >= maxKeyGroupBatchSinkBytes; + } + private void startForNewKey(Work newWork, WindmillStateReader reader) throws CoderException { newWork.setState(Work.State.PROCESSING); if (keyTransitionListener != null && this.work != null && this.work != newWork) { @@ -726,8 +845,8 @@ private void startForNewKey(Work newWork, WindmillStateReader reader) throws Cod WindmillStateCache.ForKey cacheForKey = stateCache.forKey( getComputationKey(), newWork.getWorkItem().getCacheToken(), getWorkToken()); - this.activeStateReader = reader; - startStepContexts(reader, processingTime, cacheForKey, newWork.watermarks()); + 
this.activeStateReader = newWork.createWindmillStateReader(this::workIsFailed); + startStepContexts(this.activeStateReader, processingTime, cacheForKey, newWork.watermarks()); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/KeyTokenInvalidException.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkCancelingException.java similarity index 54% rename from runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/KeyTokenInvalidException.java rename to runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkCancelingException.java index 29b16b71883f..73a307641b96 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/KeyTokenInvalidException.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkCancelingException.java @@ -17,21 +17,30 @@ */ package org.apache.beam.runners.dataflow.worker; -import javax.annotation.Nullable; +import org.checkerframework.checker.nullness.qual.Nullable; -/** Indicates that the key token was invalid when data was attempted to be fetched. */ -public class KeyTokenInvalidException extends RuntimeException { - public KeyTokenInvalidException(String key) { - super("Unable to fetch data due to token mismatch for key " + key); +/** + * Indicates that the work is no longer valid and should be canceled. It is thrown as a signal for + * upper layers to mark the work as failed. + */ +public class WorkCancelingException extends RuntimeException { + + public WorkCancelingException(long sharding_key) { + super("Work canceling exception for key " + sharding_key); + } + + public WorkCancelingException(Throwable cause) { + super(cause); } - /** Returns whether an exception was caused by a {@link KeyTokenInvalidException}. 
*/ - public static boolean isKeyTokenInvalidException(@Nullable Throwable t) { - while (t != null) { - if (t instanceof KeyTokenInvalidException) { + /** Returns whether an exception was caused by a {@link WorkCancelingException}. */ + public static boolean isWorkCancelingException(Throwable t) { + @Nullable Throwable throwable = t; + while (throwable != null) { + if (throwable instanceof WorkCancelingException) { return true; } - t = t.getCause(); + throwable = throwable.getCause(); } return false; } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkItemCancelledException.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkItemCancelledException.java index a12a5075c5ee..68cbab32254c 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkItemCancelledException.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkItemCancelledException.java @@ -17,31 +17,10 @@ */ package org.apache.beam.runners.dataflow.worker; -/** Indicates that the work item was cancelled and should not be retried. */ -@SuppressWarnings({ - "nullness" // TODO(https://github.com/apache/beam/issues/20497) -}) +/** Indicates that the work item was canceled. */ public class WorkItemCancelledException extends RuntimeException { + public WorkItemCancelledException(long sharding_key) { super("Work item cancelled for key " + sharding_key); } - - public WorkItemCancelledException(String message, Throwable cause) { - super(message, cause); - } - - public WorkItemCancelledException(Throwable cause) { - super(cause); - } - - /** Returns whether an exception was caused by a {@link WorkItemCancelledException}. 
*/ - public static boolean isWorkItemCancelledException(Throwable t) { - while (t != null) { - if (t instanceof WorkItemCancelledException) { - return true; - } - t = t.getCause(); - } - return false; - } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkState.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkState.java index e430f6c8f638..f49aa31a439a 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkState.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkState.java @@ -88,6 +88,11 @@ static ActiveWorkState create(WindmillStateCache.ForComputation computationState return new ActiveWorkState(new HashMap<>(), computationStateCache); } + synchronized Optional getActiveWork(ShardedKey shardedKey, WorkId workId) { + LinkedHashMap workQueue = activeWork.get(shardedKey.shardingKey()); + return workQueue == null ? 
Optional.empty() : Optional.ofNullable(workQueue.get(workId)); + } + @VisibleForTesting static ActiveWorkState forTesting( Map> activeWork, diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/BoundedQueueExecutorWorkHandle.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/BoundedQueueExecutorWorkHandle.java index 1ca534966947..20661aae0a04 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/BoundedQueueExecutorWorkHandle.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/BoundedQueueExecutorWorkHandle.java @@ -17,8 +17,13 @@ */ package org.apache.beam.runners.dataflow.worker.streaming; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; + /** * A handle to use when requesting pulling more work from @BoundedQueueExecutor * via @BoundedQueueExecutor.pollWork */ -public interface BoundedQueueExecutorWorkHandle {} +public interface BoundedQueueExecutorWorkHandle { + // Returns all work that are tracked by the handle + ImmutableList getWorkBatch(); +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationState.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationState.java index 3886d4fbc01b..e9f6ddc55de6 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationState.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationState.java @@ -131,6 +131,10 @@ public void completeWorkAndScheduleNextWorkForKey(ShardedKey shardedKey, WorkId .ifPresent(this::forceExecute); } + public void 
reExecuteActiveWork(ShardedKey shardedKey, WorkId workId) { + activeWorkState.getActiveWork(shardedKey, workId).ifPresent(this::forceExecute); + } + public void invalidateStuckCommits(Instant stuckCommitDeadline) { activeWorkState.invalidateStuckCommits( stuckCommitDeadline, this::completeWorkAndScheduleNextWorkForKey); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java index 252a16a38bc9..f9cfec7e6807 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java @@ -36,6 +36,7 @@ import org.apache.beam.repackaged.core.org.apache.commons.lang3.tuple.Pair; import org.apache.beam.runners.dataflow.worker.ActiveMessageMetadata; import org.apache.beam.runners.dataflow.worker.DataflowExecutionStateSampler; +import org.apache.beam.runners.dataflow.worker.WorkCancelingException; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalData; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalDataRequest; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataRequest; @@ -87,6 +88,7 @@ public final class Work implements RefreshableWork { private final AtomicReference<@Nullable AtomicBoolean> onFailureListener = new AtomicReference<>(null); private final boolean drainMode; + private ImmutableList getWorkStreamLatencies; private Work( WorkItem workItem, @@ -94,7 +96,8 @@ private Work( Watermarks watermarks, ProcessingContext processingContext, boolean drainMode, - Supplier clock) { + Supplier clock, + ImmutableList getWorkStreamLatencies) { this.shardedKey = ShardedKey.create(workItem.getKey(), workItem.getShardingKey()); this.workItem = 
workItem; this.serializedWorkItemSize = serializedWorkItemSize; @@ -118,6 +121,7 @@ private Work( + Long.toHexString(workItem.getWorkToken()); this.currentState = TimedState.initialState(startTime); this.isFailed = false; + this.getWorkStreamLatencies = getWorkStreamLatencies; } public static Work create( @@ -128,7 +132,31 @@ public static Work create( boolean drainMode, Supplier clock) { return new Work( - workItem, serializedWorkItemSize, watermarks, processingContext, drainMode, clock); + workItem, + serializedWorkItemSize, + watermarks, + processingContext, + drainMode, + clock, + ImmutableList.of()); + } + + public static Work create( + WorkItem workItem, + long serializedWorkItemSize, + Watermarks watermarks, + ProcessingContext processingContext, + boolean drainMode, + Supplier clock, + ImmutableList getWorkStreamLatencies) { + return new Work( + workItem, + serializedWorkItemSize, + watermarks, + processingContext, + drainMode, + clock, + getWorkStreamLatencies); } public static ProcessingContext createProcessingContext( @@ -205,11 +233,31 @@ public ShardedKey getShardedKey() { } public Optional fetchKeyedState(KeyedGetDataRequest keyedGetDataRequest) { - return processingContext.fetchKeyedState(keyedGetDataRequest); + try { + Optional response = + processingContext.fetchKeyedState(keyedGetDataRequest); + if (response.isPresent() && response.get().getFailed()) { + // Work is not valid in backend anymore. 
+ this.setFailed(); + } + return response; + } catch (RuntimeException e) { + if (WorkCancelingException.isWorkCancelingException(e)) { + this.setFailed(); + } + throw e; + } } public GlobalData fetchSideInput(GlobalDataRequest request) { - return processingContext.getDataClient().getSideInputData(request); + try { + return processingContext.getDataClient().getSideInputData(request); + } catch (RuntimeException e) { + if (WorkCancelingException.isWorkCancelingException(e)) { + this.setFailed(); + } + throw e; + } } public String backendWorkerToken() { @@ -293,8 +341,8 @@ public Consumer workCommitter() { return processingContext.workCommitter(); } - public WindmillStateReader createWindmillStateReader() { - return WindmillStateReader.forWork(this); + public WindmillStateReader createWindmillStateReader(Supplier workIsFailed) { + return WindmillStateReader.forWork(this, workIsFailed); } @Override @@ -302,11 +350,17 @@ public WorkId id() { return id; } - public void recordGetWorkStreamLatencies( - ImmutableList getWorkStreamLatencies) { - for (LatencyAttribution latency : getWorkStreamLatencies) { - totalDurationPerState.put( - latency.getState(), Duration.millis(latency.getTotalDurationMillis())); + public ImmutableList getWorkStreamLatencies() { + return getWorkStreamLatencies; + } + + public void recordGetWorkStreamLatencies() { + if (!getWorkStreamLatencies.isEmpty()) { + for (LatencyAttribution latency : getWorkStreamLatencies) { + totalDurationPerState.put( + latency.getState(), Duration.millis(latency.getTotalDurationMillis())); + } + this.getWorkStreamLatencies = ImmutableList.of(); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutor.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutor.java index 8964246c1160..9eb9a37b1b76 100644 --- 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutor.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutor.java @@ -20,6 +20,8 @@ import static org.apache.beam.sdk.util.Preconditions.checkArgumentNotNull; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; +import java.util.ArrayList; +import java.util.List; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.ThreadFactory; @@ -30,9 +32,9 @@ import org.apache.beam.runners.dataflow.worker.streaming.BoundedQueueExecutorWorkHandle; import org.apache.beam.runners.dataflow.worker.streaming.ExecutableWork; import org.apache.beam.runners.dataflow.worker.streaming.Work; -import org.apache.beam.runners.dataflow.worker.streaming.Work.KeyGroup; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.Monitor; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.Monitor.Guard; import org.checkerframework.checker.nullness.qual.Nullable; @@ -260,7 +262,7 @@ final class BoundedQueueExecutorWorkHandleImpl implements BoundedQueueExecutorWorkHandle, AutoCloseable { @GuardedBy("this") - private int elements; + private final List workBatch; @GuardedBy("this") private long bytes; @@ -268,16 +270,17 @@ final class BoundedQueueExecutorWorkHandleImpl @GuardedBy("this") private boolean closed = false; - private BoundedQueueExecutorWorkHandleImpl(int elements, long bytes) { - checkArgument(elements >= 0 && bytes >= 0); - this.elements = 
elements; + private BoundedQueueExecutorWorkHandleImpl(Work work, long bytes) { + checkArgument(bytes >= 0); + this.workBatch = new ArrayList<>(); + this.workBatch.add(checkArgumentNotNull(work)); this.bytes = bytes; } /** * Merges the budget from another handle into this handle. * - *

This transfers the budget (elements and bytes) from the {@code other} handle to this + *

This transfers the budget (workBatch and bytes) from the {@code other} handle to this * handle, and marks the {@code other} handle as closed to prevent it from releasing the budget * again if it is closed. */ @@ -287,10 +290,10 @@ public void merge(BoundedQueueExecutorWorkHandleImpl other) { Preconditions.checkState(!closed, "Cannot merge into a closed handle"); synchronized (other) { Preconditions.checkState(!other.closed, "Cannot merge a closed handle"); - this.elements += other.elements; + this.workBatch.addAll(other.workBatch); this.bytes += other.bytes; other.closed = true; - other.elements = 0; + other.workBatch.clear(); other.bytes = 0; } } @@ -300,9 +303,9 @@ public synchronized boolean isClosed() { return closed; } - @VisibleForTesting - synchronized int elements() { - return elements; + @Override + public synchronized ImmutableList getWorkBatch() { + return ImmutableList.copyOf(workBatch); } @VisibleForTesting @@ -314,7 +317,7 @@ synchronized long bytes() { public synchronized void close() { if (closed) return; closed = true; - decrementCounters(this.elements, this.bytes); + decrementCounters(this.workBatch.size(), this.bytes); } } @@ -350,7 +353,7 @@ private void executeMonitorHeld(ExecutableWork work, long workBytes) { bytesOutstanding += workBytes; monitor.leave(); BoundedQueueExecutorWorkHandleImpl handle = - new BoundedQueueExecutorWorkHandleImpl(1, workBytes); + new BoundedQueueExecutorWorkHandleImpl(work.work(), workBytes); try { executor.execute(new QueuedWork(work, handle)); } catch (Throwable t) { @@ -379,14 +382,15 @@ private void executeMonitorHeld(Runnable work) { } @VisibleForTesting - BoundedQueueExecutorWorkHandleImpl createBudgetHandle(int elements, long bytes) { - return new BoundedQueueExecutorWorkHandleImpl(elements, bytes); + BoundedQueueExecutorWorkHandleImpl createBudgetHandle(Work work, long bytes) { + return new BoundedQueueExecutorWorkHandleImpl(work, bytes); } public @Nullable ExecutableWork pollWork( String computationId, 
Work.KeyGroup keyGroup, BoundedQueueExecutorWorkHandle handle) { + checkArgument( + computationId != null && keyGroup != null && !keyGroup.equals(Work.KeyGroup.DEFAULT)); checkArgument(handle instanceof BoundedQueueExecutorWorkHandleImpl); - checkArgument(computationId != null && keyGroup != null && !keyGroup.equals(KeyGroup.DEFAULT)); BoundedQueueExecutorWorkHandleImpl internalHandle = (BoundedQueueExecutorWorkHandleImpl) handle; if (keyGroupWorkQueue == null) { return null; diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java index 526b67890783..36001c151508 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java @@ -108,6 +108,11 @@ boolean commitWorkItem( Windmill.WorkItemCommitRequest request, Consumer onDone); + boolean commitMultiKeyWorkItem( + String computation, + Windmill.MultiKeyWorkItemCommitRequest request, + Consumer onDone); + /** Flushes any pending work items to the wire. 
*/ void flush(); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/Commit.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/Commit.java index b840d22a3434..e52a9846645f 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/Commit.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/Commit.java @@ -18,11 +18,14 @@ package org.apache.beam.runners.dataflow.worker.windmill.client.commits; import com.google.auto.value.AutoValue; +import java.util.Optional; import org.apache.beam.runners.dataflow.worker.streaming.ComputationState; import org.apache.beam.runners.dataflow.worker.streaming.Work; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItemCommitRequest; import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; /** Value class for a queued commit. 
*/ @Internal @@ -32,20 +35,43 @@ public abstract class Commit { public static Commit create( WorkItemCommitRequest request, ComputationState computationState, Work work) { Preconditions.checkArgument(request.getSerializedSize() > 0); - return new AutoValue_Commit(request, computationState, work); + return new AutoValue_Commit( + Optional.of(request), computationState, Optional.empty(), ImmutableList.of(work)); + } + + public static Commit createMultiKey( + Windmill.MultiKeyWorkItemCommitRequest multiKeyRequest, + ComputationState computationState, + ImmutableList workBatch) { + Preconditions.checkArgument(!workBatch.isEmpty()); + return new AutoValue_Commit( + Optional.empty(), computationState, Optional.of(multiKeyRequest), workBatch); } public final String computationId() { return computationState().getComputationId(); } - public abstract WorkItemCommitRequest request(); + public abstract Optional singleKeyRequest(); public abstract ComputationState computationState(); - public abstract Work work(); + public abstract Optional multiKeyRequest(); + + public abstract ImmutableList workBatch(); + + public final boolean isFailed() { + for (Work w : workBatch()) { + if (w.isFailed()) { + return true; + } + } + return false; + } public final int getSize() { - return request().getSerializedSize(); + return multiKeyRequest() + .map(Windmill.MultiKeyWorkItemCommitRequest::getSerializedSize) + .orElseGet(() -> singleKeyRequest().get().getSerializedSize()); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/CompleteCommit.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/CompleteCommit.java index e33e853d3d76..e168d92987fb 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/CompleteCommit.java +++ 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/CompleteCommit.java @@ -37,24 +37,14 @@ @AutoValue public abstract class CompleteCommit { - public static CompleteCommit create(Commit commit, CommitStatus commitStatus) { - return new AutoValue_CompleteCommit( - commit.computationId(), - ShardedKey.create(commit.request().getKey(), commit.request().getShardingKey()), - WorkId.builder() - .setWorkToken(commit.request().getWorkToken()) - .setCacheToken(commit.request().getCacheToken()) - .build(), - commitStatus); - } - public static CompleteCommit create( - String computationId, ShardedKey shardedKey, WorkId workId, CommitStatus status) { - return new AutoValue_CompleteCommit(computationId, shardedKey, workId, status); - } - - public static CompleteCommit forFailedWork(Commit commit) { - return create(commit, CommitStatus.ABORTED); + String computationId, + ShardedKey shardedKey, + WorkId workId, + CommitStatus status, + boolean retryableFailure) { + return new AutoValue_CompleteCommit( + computationId, shardedKey, workId, status, retryableFailure); } public abstract String computationId(); @@ -64,4 +54,6 @@ public static CompleteCommit forFailedWork(Commit commit) { public abstract WorkId workId(); public abstract CommitStatus status(); + + public abstract boolean retryableFailure(); } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitter.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitter.java index 20b95b0661d0..58f0dbbea242 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitter.java +++ 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitter.java @@ -17,6 +17,8 @@ */ package org.apache.beam.runners.dataflow.worker.windmill.client.commits; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState; + import java.util.HashMap; import java.util.Map; import java.util.concurrent.ExecutorService; @@ -112,7 +114,8 @@ private void commitLoop() { } while (commit != null) { ComputationState computationState = commit.computationState(); - commit.work().setState(Work.State.COMMITTING); + checkState(commit.workBatch().size() == 1); + commit.workBatch().get(0).setState(Work.State.COMMITTING); Windmill.ComputationCommitWorkRequest.Builder computationRequestBuilder = computationRequestMap.get(computationState); if (computationRequestBuilder == null) { @@ -120,7 +123,8 @@ private void commitLoop() { computationRequestBuilder.setComputationId(computationState.getComputationId()); computationRequestMap.put(computationState, computationRequestBuilder); } - computationRequestBuilder.addRequests(commit.request()); + checkState(commit.singleKeyRequest().isPresent()); + computationRequestBuilder.addRequests(commit.singleKeyRequest().get()); // Send the request if we've exceeded the bytes or there is no more // pending work. commitBytes is a long, so this cannot overflow. 
commitBytes += commit.getSize(); @@ -155,7 +159,8 @@ private void completeWork( .setCacheToken(workRequest.getCacheToken()) .setWorkToken(workRequest.getWorkToken()) .build(), - Windmill.CommitStatus.OK)); + Windmill.CommitStatus.OK, + /* retryableFailure= */ false)); } } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitter.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitter.java index b68f53121b86..72d9e5ed8d03 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitter.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitter.java @@ -30,6 +30,7 @@ import org.apache.beam.runners.dataflow.worker.streaming.WeightedBoundedQueue; import org.apache.beam.runners.dataflow.worker.streaming.WeightedSemaphore; import org.apache.beam.runners.dataflow.worker.streaming.Work; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.CommitStatus; import org.apache.beam.runners.dataflow.worker.windmill.client.CloseableStream; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.CommitWorkStream; import org.apache.beam.sdk.annotations.Internal; @@ -100,7 +101,7 @@ public void start() { @Override public void commit(Commit commit) { - if (commit.work().isFailed()) { + if (commit.isFailed()) { failCommit(commit); } else { commitQueue.put(commit); @@ -113,8 +114,8 @@ public void commit(Commit commit) { "Trying to queue commit on shutdown, failing commit=[computationId={}, shardingKey={}," + " workId={} ].", commit.computationId(), - commit.work().getShardedKey(), - commit.work().id()); + commit.workBatch().get(0).getShardedKey(), + 
commit.workBatch().get(0).id()); drainCommitQueue(); } } @@ -147,8 +148,42 @@ private void drainCommitQueue() { } private void failCommit(Commit commit) { - commit.work().setFailed(); - onCommitComplete.accept(CompleteCommit.forFailedWork(commit)); + if (!isRunning.get()) { + // Shutting down, fail everything unconditionally to prevent infinite loops + for (Work w : commit.workBatch()) { + w.setFailed(); + onCommitComplete.accept( + CompleteCommit.create( + commit.computationId(), + w.getShardedKey(), + w.id(), + CommitStatus.ABORTED, + /* retryableFailure= */ false)); + } + return; + } + + // Still running, only fail actually failed work, and request re-execution for valid ones + for (Work w : commit.workBatch()) { + if (w.isFailed()) { + onCommitComplete.accept( + CompleteCommit.create( + commit.computationId(), + w.getShardedKey(), + w.id(), + CommitStatus.ABORTED, + /* retryableFailure= */ false)); + } else { + LOG.debug("Requesting re-execution for valid work {} from failed commit", w.id()); + onCommitComplete.accept( + CompleteCommit.create( + commit.computationId(), + w.getShardedKey(), + w.id(), + CommitStatus.ABORTED, + /* retryableFailure= */ true)); + } + } } @Override @@ -173,8 +208,8 @@ private void streamingCommitLoop() { // take() blocks until a value is available in the commitQueue. Preconditions.checkNotNull(initialCommit); - if (initialCommit.work().isFailed()) { - onCommitComplete.accept(CompleteCommit.forFailedWork(initialCommit)); + if (initialCommit.isFailed()) { + failCommit(initialCommit); initialCommit = null; continue; } @@ -202,20 +237,51 @@ private void streamingCommitLoop() { /** Adds the commit to the batch if it fits, returning true if it is consumed. 
*/ private boolean tryAddToCommitBatch(Commit commit, CommitWorkStream.RequestBatcher batcher) { Preconditions.checkNotNull(commit); - commit.work().setState(Work.State.COMMITTING); + for (Work w : commit.workBatch()) { + w.setState(Work.State.COMMITTING); + } activeCommitBytes.addAndGet(commit.getSize()); - boolean isCommitAccepted = - batcher.commitWorkItem( - commit.computationId(), - commit.request(), - commitStatus -> { - onCommitComplete.accept(CompleteCommit.create(commit, commitStatus)); - activeCommitBytes.addAndGet(-commit.getSize()); - }); + boolean isCommitAccepted; + if (commit.multiKeyRequest().isPresent()) { + isCommitAccepted = + batcher.commitMultiKeyWorkItem( + commit.computationId(), + commit.multiKeyRequest().get(), + commitStatus -> { + for (Work w : commit.workBatch()) { + onCommitComplete.accept( + CompleteCommit.create( + commit.computationId(), + w.getShardedKey(), + w.id(), + commitStatus, + /* retryableFailure= */ false)); + } + activeCommitBytes.addAndGet(-commit.getSize()); + }); + } else { + isCommitAccepted = + batcher.commitWorkItem( + commit.computationId(), + commit.singleKeyRequest().get(), + commitStatus -> { + Work w = commit.workBatch().get(0); + onCommitComplete.accept( + CompleteCommit.create( + commit.computationId(), + w.getShardedKey(), + w.id(), + commitStatus, + /* retryableFailure= */ false)); + activeCommitBytes.addAndGet(-commit.getSize()); + }); + } // Since the commit was not accepted, revert the changes made above. if (!isCommitAccepted) { - commit.work().setState(Work.State.COMMIT_QUEUED); + for (Work w : commit.workBatch()) { + w.setState(Work.State.COMMIT_QUEUED); + } activeCommitBytes.addAndGet(-commit.getSize()); } @@ -246,8 +312,8 @@ private boolean tryAddToCommitBatch(Commit commit, CommitWorkStream.RequestBatch } // Drop commits for failed work. Such commits will be dropped by Windmill anyway. 
- if (commit.work().isFailed()) { - onCommitComplete.accept(CompleteCommit.forFailedWork(commit)); + if (commit.isFailed()) { + failCommit(commit); continue; } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/StreamGetDataClient.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/StreamGetDataClient.java index ab12946ad18b..d233bf091b6a 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/StreamGetDataClient.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/StreamGetDataClient.java @@ -19,6 +19,7 @@ import java.io.PrintWriter; import java.util.function.Function; +import org.apache.beam.runners.dataflow.worker.WorkCancelingException; import org.apache.beam.runners.dataflow.worker.WorkItemCancelledException; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream; @@ -62,7 +63,7 @@ public Windmill.KeyedGetDataResponse getStateData( try (AutoCloseable ignored = getDataMetricTracker.trackStateDataFetchWithThrottling()) { return getDataStream.requestKeyedData(computationId, request); } catch (WindmillStreamShutdownException e) { - throw new WorkItemCancelledException(request.getShardingKey()); + throw new WorkCancelingException(request.getShardingKey()); } catch (Exception e) { throw new GetDataException( "Error occurred fetching state for computation=" @@ -87,7 +88,7 @@ public Windmill.GlobalData getSideInputData(Windmill.GlobalDataRequest request) try (AutoCloseable ignored = getDataMetricTracker.trackSideInputFetchWithThrottling()) { return sideInputGetDataStream.requestGlobalData(request); } catch (WindmillStreamShutdownException e) { - throw new 
WorkItemCancelledException(e); + throw new WorkCancelingException(e); } catch (Exception e) { throw new GetDataException( "Error occurred fetching side input for tag=" + request.getDataId(), e); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java index 160b0cce0133..d2dad210aa23 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java @@ -34,6 +34,7 @@ import java.util.function.Function; import javax.annotation.Nullable; import org.apache.beam.repackaged.core.org.apache.commons.lang3.tuple.Pair; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.CommitStatus; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.JobHeader; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.StreamingCommitRequestChunk; @@ -308,7 +309,7 @@ private void flushInternal(Map requests) if (requests.size() == 1) { Map.Entry elem = requests.entrySet().iterator().next(); - if (elem.getValue().getRequest().getSerializedSize() + if (elem.getValue().serializedCommit().size() > AbstractWindmillStream.RPC_STREAM_CHUNK_SIZE) { issueMultiChunkRequest(elem.getKey(), elem.getValue()); } else { @@ -324,9 +325,10 @@ private void issueSingleRequest(long id, PendingRequest pendingRequest) StreamingCommitWorkRequest.Builder requestBuilder = StreamingCommitWorkRequest.newBuilder(); requestBuilder .addCommitChunkBuilder() - .setComputationId(pendingRequest.getComputationId()) + 
.setComputationId(pendingRequest.computationId()) .setRequestId(id) .setShardingKey(pendingRequest.shardingKey()) + .setCommitType(pendingRequest.commitType()) .setSerializedWorkItemCommit(pendingRequest.serializedCommit()); StreamingCommitWorkRequest chunk = requestBuilder.build(); synchronized (this) { @@ -349,14 +351,15 @@ private void issueBatchedRequest(Map requests) for (Map.Entry entry : requests.entrySet()) { PendingRequest request = entry.getValue(); StreamingCommitRequestChunk.Builder chunkBuilder = requestBuilder.addCommitChunkBuilder(); - if (lastComputation == null || !lastComputation.equals(request.getComputationId())) { - chunkBuilder.setComputationId(request.getComputationId()); - lastComputation = request.getComputationId(); + if (lastComputation == null || !lastComputation.equals(request.computationId())) { + chunkBuilder.setComputationId(request.computationId()); + lastComputation = request.computationId(); } chunkBuilder .setRequestId(entry.getKey()) .setShardingKey(request.shardingKey()) - .setSerializedWorkItemCommit(request.serializedCommit()); + .setSerializedWorkItemCommit(request.serializedCommit()) + .setCommitType(request.commitType()); } StreamingCommitWorkRequest request = requestBuilder.build(); synchronized (this) { @@ -376,7 +379,7 @@ private void issueBatchedRequest(Map requests) private void issueMultiChunkRequest(long id, PendingRequest pendingRequest) throws WindmillStreamShutdownException { - checkNotNull(pendingRequest.getComputationId(), "Cannot commit WorkItem w/o a computationId."); + checkNotNull(pendingRequest.computationId(), "Cannot commit WorkItem w/o a computationId."); ByteString serializedCommit = pendingRequest.serializedCommit(); synchronized (this) { if (isShutdown) { @@ -397,8 +400,9 @@ private void issueMultiChunkRequest(long id, PendingRequest pendingRequest) StreamingCommitRequestChunk.newBuilder() .setRequestId(id) .setSerializedWorkItemCommit(chunk) - .setComputationId(pendingRequest.getComputationId()) - 
.setShardingKey(pendingRequest.shardingKey()); + .setComputationId(pendingRequest.computationId()) + .setShardingKey(pendingRequest.shardingKey()) + .setCommitType(pendingRequest.commitType()); int remaining = serializedCommit.size() - end; if (remaining > 0) { chunkBuilder.setRemainingBytesForWorkItem(remaining); @@ -416,24 +420,44 @@ private void issueMultiChunkRequest(long id, PendingRequest pendingRequest) private static class PendingRequest { private final String computationId; - private final WorkItemCommitRequest request; + private final long shardingKey; + private final ByteString serializedCommit; + private final StreamingCommitRequestChunk.CommitType commitType; private final Consumer onDone; private final long startTimeNanos; // System.nanoTime() of when request began. private PendingRequest( - String computationId, WorkItemCommitRequest request, Consumer onDone) { + String computationId, + long shardingKey, + ByteString serializedCommit, + StreamingCommitRequestChunk.CommitType commitType, + Consumer onDone) { this.computationId = computationId; - this.request = request; + this.shardingKey = shardingKey; + this.serializedCommit = serializedCommit; + this.commitType = commitType; this.onDone = onDone; this.startTimeNanos = System.nanoTime(); } - String getComputationId() { + String computationId() { return computationId; } - WorkItemCommitRequest getRequest() { - return request; + long shardingKey() { + return shardingKey; + } + + ByteString serializedCommit() { + return serializedCommit; + } + + StreamingCommitRequestChunk.CommitType commitType() { + return commitType; + } + + Consumer onDone() { + return onDone; } long getStartTimeNanos() { @@ -441,21 +465,13 @@ long getStartTimeNanos() { } private long getBytes() { - return (long) request.getSerializedSize() + computationId.length(); - } - - private ByteString serializedCommit() { - return request.toByteString(); + return (long) serializedCommit.size() + computationId.length(); } private void 
completeWithStatus(CommitStatus commitStatus) { onDone.accept(commitStatus); } - private long shardingKey() { - return request.getShardingKey(); - } - private void abort() { completeWithStatus(CommitStatus.ABORTED); } @@ -512,7 +528,34 @@ public boolean commitWorkItem( return false; } - PendingRequest request = new PendingRequest(computation, commitRequest, onDone); + PendingRequest request = + new PendingRequest( + computation, + commitRequest.getShardingKey(), + commitRequest.toByteString(), + StreamingCommitRequestChunk.CommitType.COMMIT_TYPE_SINGLE_KEY, + onDone); + add(idGenerator.incrementAndGet(), request); + return true; + } + + @Override + public boolean commitMultiKeyWorkItem( + String computation, + Windmill.MultiKeyWorkItemCommitRequest commitRequest, + Consumer onDone) { + if (!canAccept(commitRequest.getSerializedSize() + computation.length())) { + return false; + } + Preconditions.checkArgument(commitRequest.getRequestsCount() > 0); + PendingRequest request = + new PendingRequest( + computation, + // Any key in the batch for routing + commitRequest.getRequests(0).getShardingKey(), + commitRequest.toByteString(), + StreamingCommitRequestChunk.CommitType.COMMIT_TYPE_MULTI_KEY, + onDone); add(idGenerator.incrementAndGet(), request); return true; } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateReader.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateReader.java index c609bed4eae0..6c5ae50858cc 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateReader.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateReader.java @@ -36,8 +36,8 @@ import java.util.function.Supplier; import java.util.stream.Collectors; import 
javax.annotation.Nullable; -import org.apache.beam.runners.dataflow.worker.KeyTokenInvalidException; import org.apache.beam.runners.dataflow.worker.WindmillTimeUtils; +import org.apache.beam.runners.dataflow.worker.WorkCancelingException; import org.apache.beam.runners.dataflow.worker.WorkItemCancelledException; import org.apache.beam.runners.dataflow.worker.streaming.Work; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; @@ -153,7 +153,7 @@ static WindmillStateReader forTesting( fetchStateFromWindmillFn, key, shardingKey, workToken, () -> null, () -> Boolean.FALSE); } - public static WindmillStateReader forWork(Work work) { + public static WindmillStateReader forWork(Work work, Supplier workItemIsFailed) { return new WindmillStateReader( work::fetchKeyedState, work.getWorkItem().getKey(), @@ -163,7 +163,7 @@ public static WindmillStateReader forWork(Work work) { work.setState(Work.State.READING); return () -> work.setState(Work.State.PROCESSING); }, - work::isFailed); + workItemIsFailed); } private Future stateFuture(StateTag stateTag, @Nullable Coder coder) { @@ -588,7 +588,8 @@ private KeyedGetDataRequest createRequest(Iterable> toFetch) { private void consumeResponse(KeyedGetDataResponse response, Set> toFetch) { bytesRead += response.getSerializedSize(); if (response.getFailed()) { - throw new KeyTokenInvalidException(key.toStringUtf8()); + // upper layers will fail the work on seeing this exception. 
+ throw new WorkCancelingException(shardingKey); } if (!key.equals(response.getKey())) { diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/ComputationWorkExecutorFactory.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/ComputationWorkExecutorFactory.java index 4a52d9fde771..2930244b40b4 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/ComputationWorkExecutorFactory.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/ComputationWorkExecutorFactory.java @@ -84,6 +84,7 @@ final class ComputationWorkExecutorFactory { private final SinkRegistry sinkRegistry; private final DataflowExecutionStateSampler sampler; private final CounterSet pendingDeltaCounters; + private final SideInputStateFetcherFactory sideInputStateFetcherFactory; /** * Function which converts map tasks to their network representation for execution. 
@@ -287,6 +288,7 @@ private StreamingModeExecutionContext createExecutionContext( hotKeyLoggingEnabled, stepName, computationState.sourceBytesProcessCounterName(), + options, sideInputStateFetcherFactory); } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java index e9b85d720d2b..69cf41fd1201 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java @@ -21,6 +21,7 @@ import com.google.api.services.dataflow.model.MapTask; import com.google.auto.value.AutoValue; +import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentMap; @@ -56,7 +57,6 @@ import org.apache.beam.runners.dataflow.worker.windmill.Windmill.LatencyAttribution; import org.apache.beam.runners.dataflow.worker.windmill.client.commits.Commit; import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateCache; -import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateReader; import org.apache.beam.runners.dataflow.worker.windmill.work.processing.failures.FailureTracker; import org.apache.beam.runners.dataflow.worker.windmill.work.processing.failures.WorkFailureProcessor; import org.apache.beam.sdk.annotations.Internal; @@ -89,6 +89,7 @@ public class StreamingWorkScheduler { private final DataflowExecutionStateSampler sampler; private final StreamingGlobalConfigHandle globalConfigHandle; private final BoundedQueueExecutor workExecutor; + private final boolean multiKeyExperimentEnabled; public StreamingWorkScheduler( 
Supplier clock, @@ -100,7 +101,8 @@ public StreamingWorkScheduler( StreamingCounters streamingCounters, ConcurrentMap stageInfoMap, DataflowExecutionStateSampler sampler, - StreamingGlobalConfigHandle globalConfigHandle) { + StreamingGlobalConfigHandle globalConfigHandle, + boolean multiKeyExperimentEnabled) { this.clock = clock; this.workExecutor = workExecutor; this.computationWorkExecutorFactory = computationWorkExecutorFactory; @@ -111,10 +113,12 @@ public StreamingWorkScheduler( this.stageInfoMap = stageInfoMap; this.sampler = sampler; this.globalConfigHandle = globalConfigHandle; + this.multiKeyExperimentEnabled = multiKeyExperimentEnabled; } public static StreamingWorkScheduler create( DataflowWorkerHarnessOptions options, + boolean multiKeyExperimentEnabled, Supplier clock, ReaderCache readerCache, DataflowMapTaskExecutorFactory mapTaskExecutorFactory, @@ -155,7 +159,8 @@ public static StreamingWorkScheduler create( streamingCounters, stageInfoMap, sampler, - globalConfigHandle); + globalConfigHandle, + multiKeyExperimentEnabled); } private static long computeShuffleBytesRead(Windmill.WorkItem workItem) { @@ -183,12 +188,6 @@ private static Windmill.WorkItemCommitRequest buildWorkItemTruncationRequest( return outputBuilder.build(); } - /** Sets the stage name and workId of the Thread executing the {@link Work} for logging. 
*/ - private static void setUpWorkLoggingContext(String workLatencyTrackingId, String computationId) { - setLoggingContextWorkId(workLatencyTrackingId); - setLoggingContextComputation(computationId); - } - private static void setLoggingContextComputation(@Nullable String computationId) { DataflowWorkerLoggingMDC.setStageName(computationId); } @@ -214,8 +213,14 @@ public void scheduleWork( computationState.activateWork( ExecutableWork.create( Work.create( - workItem, serializedWorkItemSize, watermarks, processingContext, drainMode, clock), - (work, handle) -> processWork(computationState, work, getWorkStreamLatencies, handle))); + workItem, + serializedWorkItemSize, + watermarks, + processingContext, + drainMode, + clock, + getWorkStreamLatencies), + (work, handle) -> processWork(computationState, work, handle))); } /** Adds any applied finalize ids to the commit finalizer to have their callbacks executed. */ @@ -229,25 +234,20 @@ public void queueAppliedFinalizeIds(ImmutableList appliedFinalizeIds) { * internally if processing fails due to uncaught {@link Exception}(s). * * @implNote This will block the calling thread during execution of user DoFns. 
- * @param handle handled to pass to BoundedQueueExecutor.pollWork, currently unused + * @param handle handle to pass to BoundedQueueExecutor.pollWork */ - private void processWork( - ComputationState computationState, - Work work, - ImmutableList getWorkStreamLatencies, - BoundedQueueExecutorWorkHandle handle) { - work.recordGetWorkStreamLatencies(getWorkStreamLatencies); - processWork(computationState, work, handle); - } - private void processWork( ComputationState computationState, Work work, BoundedQueueExecutorWorkHandle handle) { Windmill.WorkItem workItem = work.getWorkItem(); String computationId = computationState.getComputationId(); - work.setProcessingThreadName(Thread.currentThread().getName()); - work.setState(Work.State.PROCESSING); - setUpWorkLoggingContext(work.getLatencyTrackingId(), computationId); LOG.debug("Starting processing for {}:\n{}", computationId, work); + setLoggingContextComputation(computationId); + KeyTransitionListener keyTransitionListener = createKeyTransitionListener(); + keyTransitionListener.onKeyTransition(null, work); + + // Before any processing starts, call any pending OnCommit callbacks. Nothing that requires + // cleanup should be done before this, since we might exit early here. + commitFinalizer.finalizeCommits(workItem.getSourceState().getFinalizeIdsList()); if (workItem.getSourceState().getOnlyFinalize()) { handleOnlyFinalize(computationState, work, workItem); @@ -264,7 +264,8 @@ private void processWork( } // Execute the user code for the Work batch. 
- ExecuteWorkResult executeWorkResult = executeWork(work, stageInfo, computationState, handle); + ExecuteWorkResult executeWorkResult = + executeWork(work, stageInfo, computationState, handle, keyTransitionListener); workBatch = executeWorkResult.workBatch(); List workItemCommits = executeWorkResult.workItemCommits(); @@ -275,21 +276,7 @@ private void processWork( recordProcessingStats(workBatch, workItemCommits, executeWorkResult.stateBytesRead()); LOG.debug("Processing done for work batch size: {}", workBatch.size()); } catch (Throwable t) { - // OutOfMemoryError that are caught will be rethrown and trigger jvm termination. - try { - workFailureProcessor.logAndProcessFailure( - computationId, - ExecutableWork.create(work, (retry, h) -> processWork(computationState, retry, h)), - t, - invalidWork -> - computationState.completeWorkAndScheduleNextWorkForKey( - invalidWork.getShardedKey(), invalidWork.id())); - } catch (OutOfMemoryError oom) { - throw oom; - } catch (Throwable t2) { - LOG.warn("Failed to process work failure safely for work {}", work.id(), t2); - throw ExceptionUtils.safeWrapThrowableAsException(t2); - } + handleProcessWorkFailure(computationState, handle.getWorkBatch(), computationId, work, t); } finally { List processedWorkBatch = workBatch != null ? workBatch : ImmutableList.of(work); // Update total processing time counters. 
Updating in finally clause ensures that @@ -362,7 +349,8 @@ private ExecuteWorkResult executeWork( Work work, StageInfo stageInfo, ComputationState computationState, - BoundedQueueExecutorWorkHandle handle) + BoundedQueueExecutorWorkHandle handle, + KeyTransitionListener keyTransitionListener) throws Exception { ComputationWorkExecutor computationWorkExecutor = computationState @@ -373,19 +361,16 @@ private ExecuteWorkResult executeWork( stageInfo, computationState, work.getLatencyTrackingId())); try { - WindmillStateReader stateReader = work.createWindmillStateReader(); + StreamingModeExecutionContext context = computationWorkExecutor.context(); - KeyTransitionListener keyTransitionListener = createKeyTransitionListener(); + // Blocks while executing work. + computationWorkExecutor.executeWork(work, workExecutor, handle, keyTransitionListener); List workBatch; List workItemCommits; Map> finalizationCallbacks; long stateBytesRead; { - // Blocks while executing work. - StreamingModeExecutionContext context = - computationWorkExecutor.executeWork( - work, stateReader, workExecutor, handle, keyTransitionListener); if (context.workIsFailed()) { throw new WorkItemCancelledException(work.getWorkItem().getShardingKey()); } @@ -441,9 +426,50 @@ private void commitWorkBatch( ComputationState computationState, List workBatch, List workItemCommits) { - checkState(workBatch.size() == 1, "Expected single-key work batch, got: " + workBatch.size()); - checkState(workBatch.size() == workItemCommits.size()); - commitSingleKeyWork(computationState, workBatch.get(0), workItemCommits.get(0)); + if (workBatch.isEmpty()) { + return; + } + if (workBatch.size() > 1 || multiKeyExperimentEnabled) { + commitMultiKeyWorkBatch(computationState, workBatch, workItemCommits); + } else { + commitSingleKeyWork(computationState, workBatch.get(0), workItemCommits.get(0)); + } + } + + private void commitMultiKeyWorkBatch( + ComputationState computationState, + List workBatch, + List workItemCommits) { 
+ Windmill.MultiKeyWorkItemCommitRequest.Builder multiKeyBuilder = + Windmill.MultiKeyWorkItemCommitRequest.newBuilder(); + + Work primaryWork = workBatch.get(0); + Work.KeyGroup keyGroup = primaryWork.getKeyGroup(); + multiKeyBuilder.setKeyGroup( + Windmill.Uint128Proto.newBuilder().setHigh(keyGroup.high()).setLow(keyGroup.low()).build()); + + for (int i = 0; i < workBatch.size(); i++) { + // TODO: Add commit size validation + Windmill.WorkItemCommitRequest commit = workItemCommits.get(i); + Work w = workBatch.get(i); + multiKeyBuilder.addRequests( + commit + .toBuilder() + .addAllPerWorkItemLatencyAttributions(w.getLatencyAttributions(sampler)) + .build()); + } + + // Transition states of all completed works in the batch to COMMIT_QUEUED and submit + for (Work w : workBatch) { + w.setState(Work.State.COMMIT_QUEUED); + } + + // Package and submit the commit batch transactionally + primaryWork + .workCommitter() + .accept( + Commit.createMultiKey( + multiKeyBuilder.build(), computationState, ImmutableList.copyOf(workBatch))); } private void commitSingleKeyWork( @@ -461,12 +487,40 @@ private void commitSingleKeyWork( work.queueCommit(validatedCommitRequest, computationState); } + private void handleProcessWorkFailure( + ComputationState computationState, + List failedBatch, + String computationId, + Work primaryWork, + Throwable t) { + try { + List executableWorks = new ArrayList<>(); + for (Work w : failedBatch) { + executableWorks.add( + ExecutableWork.create(w, (retry, h) -> processWork(computationState, retry, h))); + } + + workFailureProcessor.logAndProcessFailureBatch( + computationId, + executableWorks, + t, + invalidWork -> + computationState.completeWorkAndScheduleNextWorkForKey( + invalidWork.getShardedKey(), invalidWork.id())); + } catch (OutOfMemoryError oom) { + throw oom; + } catch (Throwable t2) { + LOG.warn("Failed to process work failure safely for work {}", primaryWork.id(), t2); + throw ExceptionUtils.safeWrapThrowableAsException(t2); + } + } + 
private void recordProcessingTime( - StageInfo stageInfo, List worksToCleanup, long processingStartTimeNanos) { + StageInfo stageInfo, List workBatch, long processingStartTimeNanos) { long processingTimeMsecs = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - processingStartTimeNanos); stageInfo.totalProcessingMsecs().addValue(processingTimeMsecs); - if (anyWorkHasTimers(worksToCleanup)) { + if (anyWorkHasTimers(workBatch)) { // Attribute all the processing to timers if the work item contains any timers. // Tests show that work items rarely contain both timers and message bundles. It should // be a fairly close approximation. @@ -482,9 +536,15 @@ private static boolean anyWorkHasTimers(List works) { private KeyTransitionListener createKeyTransitionListener() { return (oldWork, newWork) -> { + newWork.recordGetWorkStreamLatencies(); + newWork.setState(Work.State.PROCESSING); setLoggingContextWorkId(newWork.getLatencyTrackingId()); - newWork.setProcessingThreadName(oldWork.getProcessingThreadName()); - oldWork.setProcessingThreadName(""); + if (oldWork != null) { + newWork.setProcessingThreadName(oldWork.getProcessingThreadName()); + oldWork.setProcessingThreadName(""); + } else { + newWork.setProcessingThreadName(Thread.currentThread().getName()); + } }; } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/failures/WorkFailureProcessor.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/failures/WorkFailureProcessor.java index 18c8e9b8d83c..15ec1e0c2cf3 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/failures/WorkFailureProcessor.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/failures/WorkFailureProcessor.java @@ -17,13 +17,12 @@ */ package 
org.apache.beam.runners.dataflow.worker.windmill.work.processing.failures; +import java.util.List; import java.util.concurrent.TimeUnit; import java.util.function.Consumer; import java.util.function.Supplier; import javax.annotation.Nullable; import javax.annotation.concurrent.ThreadSafe; -import org.apache.beam.runners.dataflow.worker.KeyTokenInvalidException; -import org.apache.beam.runners.dataflow.worker.WorkItemCancelledException; import org.apache.beam.runners.dataflow.worker.status.LastExceptionDataProvider; import org.apache.beam.runners.dataflow.worker.streaming.ExecutableWork; import org.apache.beam.runners.dataflow.worker.streaming.Work; @@ -99,28 +98,41 @@ private static boolean isOutOfMemoryError(@Nullable Throwable t) { return false; } - /** - * Processes failures caused by thrown exceptions that occur during execution of {@link Work}. May - * attempt to retry execution of the {@link Work} or drop it if it is invalid. - */ - public void logAndProcessFailure( + public void logAndProcessFailureBatch( String computationId, - ExecutableWork executableWork, + List executableWorks, Throwable t, Consumer onInvalidWork) throws Throwable { - switch (evaluateRetry(computationId, executableWork.work(), t)) { - case DO_NOT_RETRY: - // Consider the item invalid. It will eventually be retried by Windmill if it still needs to - // be processed. - onInvalidWork.accept(executableWork.work()); - break; - case RETRY_LOCALLY: - // Try again after some delay and at the end of the queue to avoid a tight loop. - executeWithDelay(retryLocallyDelayMs, executableWork); - break; - case RETHROW_THROWABLE: - throw t; + List worksToRetryLocally = new java.util.ArrayList<>(); + + for (ExecutableWork executableWork : executableWorks) { + switch (evaluateRetry(computationId, executableWork.work(), t)) { + case DO_NOT_RETRY: + // Consider the item invalid. It will eventually be retried by Windmill if it still needs + // to + // be processed. 
+ onInvalidWork.accept(executableWork.work()); + break; + case RETRY_LOCALLY: + // Try again after some delay and at the end of the queue to avoid a tight loop. + worksToRetryLocally.add(executableWork); + break; + case RETHROW_THROWABLE: + throw t; + } + } + + executeWithDelay(worksToRetryLocally); + } + + private void executeWithDelay(List worksToRetryLocally) { + if (!worksToRetryLocally.isEmpty()) { + // Sleep ONCE for the entire batch delay to avoid sequential thread blocks + Uninterruptibles.sleepUninterruptibly(retryLocallyDelayMs, TimeUnit.MILLISECONDS); + for (ExecutableWork ew : worksToRetryLocally) { + workUnitExecutor.forceExecute(ew, ew.work().getSerializedWorkItemSize()); + } } } @@ -131,12 +143,6 @@ private String tryToDumpHeap() { .orElseGet(() -> "not written"); } - private void executeWithDelay(long delayMs, ExecutableWork executableWork) { - Uninterruptibles.sleepUninterruptibly(delayMs, TimeUnit.MILLISECONDS); - workUnitExecutor.forceExecute( - executableWork, executableWork.work().getSerializedWorkItemSize()); - } - private enum RetryEvaluation { DO_NOT_RETRY, RETRY_LOCALLY, @@ -144,24 +150,16 @@ private enum RetryEvaluation { } private RetryEvaluation evaluateRetry(String computationId, Work work, Throwable t) { - @Nullable final Throwable cause = t.getCause(); - Throwable parsedException = (t instanceof UserCodeException && cause != null) ? cause : t; - if (KeyTokenInvalidException.isKeyTokenInvalidException(parsedException)) { - LOG.debug( - "Execution of work for computation '{}' on sharding key '{}' failed due to token expiration. " - + "Work will not be retried locally.", - computationId, - work.getWorkItem().getShardingKey()); - return RetryEvaluation.DO_NOT_RETRY; - } - if (WorkItemCancelledException.isWorkItemCancelledException(parsedException)) { + if (work.isFailed()) { LOG.debug( "Execution of work for computation '{}' on sharding key '{}' failed. 
" - + "Work will not be retried locally.", + + "Work is already marked as failed, not retrying locally.", computationId, work.getWorkItem().getShardingKey()); return RetryEvaluation.DO_NOT_RETRY; } + @Nullable final Throwable cause = t.getCause(); + Throwable parsedException = (t instanceof UserCodeException && cause != null) ? cause : t; LastExceptionDataProvider.reportException(parsedException); LOG.debug("Failed work: {}", work); diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java index 5be8ec0a6c72..eec77ccf435b 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java @@ -29,7 +29,6 @@ import java.util.ArrayList; import java.util.Collection; -import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; @@ -37,6 +36,7 @@ import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.CountDownLatch; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; @@ -89,6 +89,8 @@ public final class FakeWindmillServer extends WindmillServerStub { private final Map streamingCommitsToOffer; // Keys are work tokens. 
private final Map commitsReceived; + private final List multiKeyCommitsReceived = + new CopyOnWriteArrayList<>(); private final ArrayList statsReceived; private final LinkedBlockingQueue exceptions; private final AtomicInteger expectedExceptionCount; @@ -118,7 +120,7 @@ public FakeWindmillServer( commitsToOffer = new ResponseQueue() .returnByDefault(CommitWorkResponse.getDefaultInstance()); - streamingCommitsToOffer = new HashMap<>(); + streamingCommitsToOffer = new ConcurrentHashMap<>(); commitsReceived = new ConcurrentHashMap<>(); exceptions = new LinkedBlockingQueue<>(); expectedExceptionCount = new AtomicInteger(); @@ -400,6 +402,7 @@ public void shutdown() {} public RequestBatcher batcher() { return new RequestBatcher() { final List requests = new ArrayList<>(); + final List multiKeyRequests = new ArrayList<>(); @Override public boolean commitWorkItem( @@ -423,6 +426,18 @@ public boolean commitWorkItem( return true; } + @Override + public boolean commitMultiKeyWorkItem( + String computation, + Windmill.MultiKeyWorkItemCommitRequest request, + Consumer onDone) { + LOG.debug("commitWorkStream::commitMultiKeyWorkItem: {}", request); + if (multiKeyRequests.size() > 5) return false; + multiKeyRequests.add(new MultiKeyRequestAndDone(request, onDone)); + flush(); + return true; + } + @Override public void flush() { for (RequestAndDone elem : requests) { @@ -445,6 +460,37 @@ public void flush() { .orElse(Windmill.CommitStatus.OK)); } requests.clear(); + + for (MultiKeyRequestAndDone elem : multiKeyRequests) { + if (dropStreamingCommits) { + for (WorkItemCommitRequest workRequest : elem.request.getRequestsList()) { + droppedStreamingCommits.put(workRequest.getWorkToken(), elem.onDone); + } + continue; + } + + multiKeyCommitsReceived.add(elem.request); + for (WorkItemCommitRequest workRequest : elem.request.getRequestsList()) { + commitsReceived.put(workRequest.getWorkToken(), workRequest); + } + + // Determine status for the batch. 
+ // Default to OK, but if any of the works in the batch has an offered status, use it. + Windmill.CommitStatus status = Windmill.CommitStatus.OK; + for (WorkItemCommitRequest workRequest : elem.request.getRequestsList()) { + Windmill.CommitStatus offeredStatus = + streamingCommitsToOffer.remove( + WorkId.builder() + .setWorkToken(workRequest.getWorkToken()) + .setCacheToken(workRequest.getCacheToken()) + .build()); + if (offeredStatus != null) { + status = offeredStatus; + } + } + elem.onDone.accept(status); + } + multiKeyRequests.clear(); } class RequestAndDone { @@ -456,6 +502,18 @@ class RequestAndDone { this.onDone = onDone; } } + + class MultiKeyRequestAndDone { + final Consumer onDone; + final Windmill.MultiKeyWorkItemCommitRequest request; + + MultiKeyRequestAndDone( + Windmill.MultiKeyWorkItemCommitRequest request, + Consumer onDone) { + this.request = request; + this.onDone = onDone; + } + } }; } @@ -518,6 +576,15 @@ public Map waitForAndGetCommits(int numCommits) { public void clearCommitsReceived() { commitsRequested = 0; commitsReceived.clear(); + multiKeyCommitsReceived.clear(); + } + + public List getMultiKeyCommitsReceived() { + return multiKeyCommitsReceived; + } + + public void clearMultiKeyCommitsReceived() { + multiKeyCommitsReceived.clear(); } public ConcurrentHashMap> waitForDroppedCommits( diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/KeyTokenInvalidExceptionTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/KeyTokenInvalidExceptionTest.java deleted file mode 100644 index 1eb2871e8cd3..000000000000 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/KeyTokenInvalidExceptionTest.java +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. 
See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.runners.dataflow.worker; - -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertTrue; - -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; - -/** Tests for {@link KeyTokenInvalidException}. */ -@RunWith(JUnit4.class) -public final class KeyTokenInvalidExceptionTest { - @Test - public void testIsKeyTokenInvalidException() throws Exception { - KeyTokenInvalidException exception = new KeyTokenInvalidException("test"); - RuntimeException keyTokenCauseException = new RuntimeException("key token cause", exception); - assertTrue(KeyTokenInvalidException.isKeyTokenInvalidException(exception)); - assertTrue(KeyTokenInvalidException.isKeyTokenInvalidException(keyTokenCauseException)); - assertFalse( - KeyTokenInvalidException.isKeyTokenInvalidException(new RuntimeException("non key token"))); - } -} diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java index 23730bc57705..dbb1cc45e1b8 100644 --- 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java @@ -901,7 +901,7 @@ private ByteString addPaneTag(PaneInfo paneInfo, byte[] windowBytes) throws IOEx } private DataflowWorkerHarnessOptions createTestingPipelineOptions(String... args) { - List argsList = Lists.newArrayList(args); + List argsList = new ArrayList<>(Arrays.asList(args)); if (streamingEngine) { argsList.add("--experiments=enable_streaming_engine"); } @@ -1252,9 +1252,8 @@ public void testNumberOfWorkerHarnessThreadsIsHonored() throws Exception { } @Test - public void testKeyTokenInvalidException() throws Exception { - if (streamingEngine) { - // TODO: This test needs to be adapted to work with streamingEngine=true. + public void testMultiKeyCommit_success() throws Exception { + if (!streamingEngine) { return; } KvCoder kvCoder = KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of()); @@ -1262,30 +1261,359 @@ public void testKeyTokenInvalidException() throws Exception { List instructions = Arrays.asList( makeSourceInstruction(kvCoder), - makeDoFnInstruction(new KeyTokenInvalidFn(), 0, kvCoder), + makeDoFnInstruction(new WorkDoFn(), 0, kvCoder), makeSinkInstruction(kvCoder, 1)); + StreamingDataflowWorker worker = + makeWorker( + defaultWorkerParams( + "--experiments=unstable_enable_multi_key_bundle,windmill_max_key_group_batch_time_ms=50000", + "--numberOfWorkerHarnessThreads=1") + .setLocalRetryTimeoutMs(100) + .setInstructions(instructions) + .build()); + worker.start(); + + String batchInputText = + "work {" + + " computation_id: \"" + + DEFAULT_COMPUTATION_ID + + "\"" + + " input_data_watermark: 0" + + " work {" + + " key: \"key1\"" + + " sharding_key: 1" + + " work_token: 1" + + " cache_token: 2" + + " key_group { high: 0 low: 1 }" + + " message_bundles {" + + " source_computation_id: 
\"" + + DEFAULT_SOURCE_COMPUTATION_ID + + "\"" + + " messages {" + + " timestamp: 0" + + " data: \"data1\"" + + " }" + + " }" + + " }" + + " work {" + + " key: \"key2\"" + + " sharding_key: 2" + + " work_token: 2" + + " cache_token: 3" + + " key_group { high: 0 low: 1 }" + + " message_bundles {" + + " source_computation_id: \"" + + DEFAULT_SOURCE_COMPUTATION_ID + + "\"" + + " messages {" + + " timestamp: 0" + + " data: \"data2\"" + + " }" + + " }" + + " }" + + " work {" + + " key: \"key3\"" + + " sharding_key: 3" + + " work_token: 3" + + " cache_token: 4" + + " key_group { high: 0 low: 1 }" + + " message_bundles {" + + " source_computation_id: \"" + + DEFAULT_SOURCE_COMPUTATION_ID + + "\"" + + " messages {" + + " timestamp: 0" + + " data: \"data3\"" + + " }" + + " }" + + " }" + + "}"; + Windmill.GetWorkResponse batchInput = + buildInput( + batchInputText, + CoderUtils.encodeToByteArray( + CollectionCoder.of(IntervalWindow.getCoder()), + Collections.singletonList(DEFAULT_WINDOW))); + server - .whenGetWorkCalled() - .thenReturn(makeInput(0, 0, DEFAULT_KEY_STRING, DEFAULT_SHARDING_KEY)); + .whenGetDataCalled() + .answerByDefault( + request -> { + Windmill.GetDataResponse.Builder builder = Windmill.GetDataResponse.newBuilder(); + for (ComputationGetDataRequest compRequest : request.getRequestsList()) { + ComputationGetDataResponse.Builder compBuilder = + builder.addDataBuilder().setComputationId(compRequest.getComputationId()); + for (KeyedGetDataRequest keyRequest : compRequest.getRequestsList()) { + KeyedGetDataResponse.Builder keyBuilder = + compBuilder + .addDataBuilder() + .setKey(keyRequest.getKey()) + .setShardingKey(keyRequest.getShardingKey()); + keyBuilder.addAllValues(keyRequest.getValuesToFetchList()); + keyBuilder.addAllBags(keyRequest.getBagsToFetchList()); + keyBuilder.addAllWatermarkHolds(keyRequest.getWatermarkHoldsToFetchList()); + } + } + return builder.build(); + }); + + server.whenGetWorkCalled().thenReturn(batchInput); + + Map result = 
server.waitForAndGetCommits(3); + + assertEquals(3, result.size()); + + List multiKeyCommits = + server.getMultiKeyCommitsReceived(); + assertEquals(1, multiKeyCommits.size()); + Windmill.MultiKeyWorkItemCommitRequest multiKeyCommit = multiKeyCommits.get(0); + assertEquals(3, multiKeyCommit.getRequestsCount()); + assertEquals(1, multiKeyCommit.getRequests(0).getWorkToken()); + assertEquals(2, multiKeyCommit.getRequests(1).getWorkToken()); + assertEquals(3, multiKeyCommit.getRequests(2).getWorkToken()); + + worker.stop(); + } + + @Test + public void testMultiKeyCommit_elementFailure() throws Exception { + if (!streamingEngine) { + return; + } + KvCoder kvCoder = KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of()); + + List instructions = + Arrays.asList( + makeSourceInstruction(kvCoder), + makeDoFnInstruction(new WorkDoFn(), 0, kvCoder), + makeSinkInstruction(kvCoder, 1)); StreamingDataflowWorker worker = - makeWorker(defaultWorkerParams().setInstructions(instructions).publishCounters().build()); + makeWorker( + defaultWorkerParams( + "--experiments=unstable_enable_multi_key_bundle,windmill_max_key_group_batch_time_ms=5000", + "--numberOfWorkerHarnessThreads=1") + .setLocalRetryTimeoutMs(100) + .setInstructions(instructions) + .build()); worker.start(); - server.waitForEmptyWorkQueue(); + String batchInputText = + "work {" + + " computation_id: \"" + + DEFAULT_COMPUTATION_ID + + "\"" + + " input_data_watermark: 0" + + " work {" + + " key: \"key1\"" + + " sharding_key: 1" + + " work_token: 1" + + " cache_token: 2" + + " key_group { high: 0 low: 1 }" + + " message_bundles {" + + " source_computation_id: \"" + + DEFAULT_SOURCE_COMPUTATION_ID + + "\"" + + " messages {" + + " timestamp: 0" + + " data: \"data1\"" + + " }" + + " }" + + " }" + + " work {" + + " key: \"key2\"" + + " sharding_key: 2" + + " work_token: 2" + + " cache_token: 3" + + " key_group { high: 0 low: 1 }" + + " message_bundles {" + + " source_computation_id: \"" + + DEFAULT_SOURCE_COMPUTATION_ID + + 
"\"" + + " messages {" + + " timestamp: 0" + + " data: \"data2\"" + + " }" + + " }" + + " }" + + " work {" + + " key: \"key3\"" + + " sharding_key: 3" + + " work_token: 3" + + " cache_token: 4" + + " key_group { high: 0 low: 1 }" + + " message_bundles {" + + " source_computation_id: \"" + + DEFAULT_SOURCE_COMPUTATION_ID + + "\"" + + " messages {" + + " timestamp: 0" + + " data: \"data3\"" + + " }" + + " }" + + " }" + + "}"; + Windmill.GetWorkResponse batchInput = + buildInput( + batchInputText, + CoderUtils.encodeToByteArray( + CollectionCoder.of(IntervalWindow.getCoder()), + Collections.singletonList(DEFAULT_WINDOW))); server - .whenGetWorkCalled() - .thenReturn(makeInput(1, 0, DEFAULT_KEY_STRING, DEFAULT_SHARDING_KEY)); + .whenGetDataCalled() + .answerByDefault( + request -> { + Windmill.GetDataResponse.Builder builder = Windmill.GetDataResponse.newBuilder(); + for (ComputationGetDataRequest compRequest : request.getRequestsList()) { + ComputationGetDataResponse.Builder compBuilder = + builder.addDataBuilder().setComputationId(compRequest.getComputationId()); + for (KeyedGetDataRequest keyRequest : compRequest.getRequestsList()) { + KeyedGetDataResponse.Builder keyBuilder = + compBuilder + .addDataBuilder() + .setKey(keyRequest.getKey()) + .setShardingKey(keyRequest.getShardingKey()); + if (keyRequest.getWorkToken() == 2) { + keyBuilder.setFailed(true); + } else { + keyBuilder.addAllValues(keyRequest.getValuesToFetchList()); + keyBuilder.addAllBags(keyRequest.getBagsToFetchList()); + keyBuilder.addAllWatermarkHolds(keyRequest.getWatermarkHoldsToFetchList()); + } + } + } + return builder.build(); + }); + + server.whenGetWorkCalled().thenReturn(batchInput); + + Map result = server.waitForAndGetCommits(2); + + assertTrue(result.containsKey(1L)); + assertTrue(result.containsKey(3L)); + assertFalse(result.containsKey(2L)); + + List multiKeyCommits = + server.getMultiKeyCommitsReceived(); + assertEquals(1, multiKeyCommits.size()); + 
Windmill.MultiKeyWorkItemCommitRequest multiKeyCommit = multiKeyCommits.get(0); + assertEquals(2, multiKeyCommit.getRequestsCount()); + assertEquals(3, multiKeyCommit.getRequests(0).getWorkToken()); + assertEquals(1, multiKeyCommit.getRequests(1).getWorkToken()); + + worker.stop(); + } + + @Test + public void testCompleteCommit_retryableFailureTriggersReExecution() throws Exception { + if (!streamingEngine) { + return; + } + KvCoder kvCoder = KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of()); + + List instructions = + Arrays.asList( + makeSourceInstruction(kvCoder), + makeDoFnInstruction(new WorkDoFn(), 0, kvCoder), + makeSinkInstruction(kvCoder, 1)); + + StreamingDataflowWorker worker = + makeWorker( + defaultWorkerParams( + "--experiments=unstable_enable_multi_key_bundle,max_key_group_batch_time_ms=5000", + "--numberOfWorkerHarnessThreads=1") + .setLocalRetryTimeoutMs(100) + .setInstructions(instructions) + .build()); + worker.start(); + + String batchInputText = + "work {" + + " computation_id: \"" + + DEFAULT_COMPUTATION_ID + + "\"" + + " input_data_watermark: 0" + + " work {" + + " key: \"key1\"" + + " sharding_key: 1" + + " work_token: 1" + + " cache_token: 2" + + " key_group { high: 0 low: 1 }" + + " message_bundles {" + + " source_computation_id: \"" + + DEFAULT_SOURCE_COMPUTATION_ID + + "\"" + + " messages {" + + " timestamp: 0" + + " data: \"data1\"" + + " }" + + " }" + + " }" + + " work {" + + " key: \"key2\"" + + " sharding_key: 2" + + " work_token: 2" + + " cache_token: 3" + + " key_group { high: 0 low: 1 }" + + " message_bundles {" + + " source_computation_id: \"" + + DEFAULT_SOURCE_COMPUTATION_ID + + "\"" + + " messages {" + + " timestamp: 0" + + " data: \"data2\"" + + " }" + + " }" + + " }" + + "}"; + Windmill.GetWorkResponse batchInput = + buildInput( + batchInputText, + CoderUtils.encodeToByteArray( + CollectionCoder.of(IntervalWindow.getCoder()), + Collections.singletonList(DEFAULT_WINDOW))); + + server + .whenGetDataCalled() + 
.answerByDefault( + request -> { + Windmill.GetDataResponse.Builder builder = Windmill.GetDataResponse.newBuilder(); + for (ComputationGetDataRequest compRequest : request.getRequestsList()) { + ComputationGetDataResponse.Builder compBuilder = + builder.addDataBuilder().setComputationId(compRequest.getComputationId()); + for (KeyedGetDataRequest keyRequest : compRequest.getRequestsList()) { + KeyedGetDataResponse.Builder keyBuilder = + compBuilder + .addDataBuilder() + .setKey(keyRequest.getKey()) + .setShardingKey(keyRequest.getShardingKey()); + if (keyRequest.getWorkToken() == 2) { + keyBuilder.setFailed(true); + } else { + keyBuilder.addAllValues(keyRequest.getValuesToFetchList()); + keyBuilder.addAllBags(keyRequest.getBagsToFetchList()); + keyBuilder.addAllWatermarkHolds(keyRequest.getWatermarkHoldsToFetchList()); + } + } + } + return builder.build(); + }); + + server.whenGetWorkCalled().thenReturn(batchInput); Map result = server.waitForAndGetCommits(1); - assertEquals( - makeExpectedOutput(1, 0, DEFAULT_KEY_STRING, DEFAULT_SHARDING_KEY, DEFAULT_KEY_STRING) - .build(), - removeDynamicFields(result.get(1L))); - assertEquals(1, result.size()); + assertTrue(result.containsKey(1L)); + assertFalse(result.containsKey(2L)); + + List multiKeyCommits = + server.getMultiKeyCommitsReceived(); + assertEquals(1, multiKeyCommits.size()); + Windmill.MultiKeyWorkItemCommitRequest multiKeyCommit = multiKeyCommits.get(0); + assertEquals(1, multiKeyCommit.getRequestsCount()); + assertEquals(1, multiKeyCommit.getRequests(0).getWorkToken()); worker.stop(); } @@ -3520,8 +3848,8 @@ public void testExceptionInvalidatesCache() throws Exception { } // Ensure that the invalidated dofn had tearDown called on them. 
- assertEquals(1, TestExceptionInvalidatesCacheFn.tearDownCallCount.get()); - assertEquals(2, TestExceptionInvalidatesCacheFn.setupCallCount.get()); + assertEquals(2, TestExceptionInvalidatesCacheFn.tearDownCallCount.get()); + assertEquals(3, TestExceptionInvalidatesCacheFn.setupCallCount.get()); worker.stop(); } @@ -4543,18 +4871,19 @@ public void evaluate() throws Throwable { } } - static class KeyTokenInvalidFn extends DoFn, KV> { - - static boolean thrown = false; + static class WorkDoFn extends DoFn, KV> { + @StateId("state") + private final StateSpec> stateSpec = StateSpecs.value(StringUtf8Coder.of()); @ProcessElement - public void processElement(ProcessContext c) { - if (!thrown) { - thrown = true; - throw new KeyTokenInvalidException("key"); - } else { - c.output(c.element()); + public void processElement(ProcessContext c, @StateId("state") ValueState state) { + try { + Thread.sleep(1000); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); } + state.read(); + c.output(c.element()); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java index 561596f68d0f..4c65667833b5 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java @@ -24,6 +24,7 @@ import static org.hamcrest.Matchers.equalTo; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.mockito.Mockito.mock; @@ -58,17 +59,19 @@ import 
org.apache.beam.runners.dataflow.worker.counters.NameContext; import org.apache.beam.runners.dataflow.worker.profiler.ScopedProfiler.NoopProfileScope; import org.apache.beam.runners.dataflow.worker.profiler.ScopedProfiler.ProfileScope; +import org.apache.beam.runners.dataflow.worker.streaming.BoundedQueueExecutorWorkHandle; +import org.apache.beam.runners.dataflow.worker.streaming.ExecutableWork; import org.apache.beam.runners.dataflow.worker.streaming.Watermarks; import org.apache.beam.runners.dataflow.worker.streaming.Work; import org.apache.beam.runners.dataflow.worker.streaming.config.FakeGlobalConfigHandle; import org.apache.beam.runners.dataflow.worker.streaming.config.StreamingGlobalConfig; import org.apache.beam.runners.dataflow.worker.streaming.config.StreamingGlobalConfigHandle; import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputStateFetcherFactory; +import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor; import org.apache.beam.runners.dataflow.worker.util.common.worker.WorkExecutor; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.FakeGetDataClient; import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateCache; -import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateReader; import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillTagEncodingV1; import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillTagEncodingV2; import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender; @@ -76,6 +79,7 @@ import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.CoderException; import org.apache.beam.sdk.metrics.MetricsContainer; +import org.apache.beam.sdk.options.ExperimentalOptions; import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.state.TimeDomain; import org.apache.beam.sdk.transforms.Create; 
@@ -102,6 +106,7 @@ public class StreamingModeExecutionContextTest { @Rule public transient Timeout globalTimeout = Timeout.seconds(600); + @Mock private WindmillStateReader stateReader; @Mock private WorkExecutor workExecutor; @@ -114,7 +119,7 @@ public class StreamingModeExecutionContextTest { private FakeGlobalConfigHandle globalConfigHandle; private StreamingModeExecutionContext createExecutionContext( - StreamingGlobalConfigHandle configHandle) { + DataflowWorkerHarnessOptions options, StreamingGlobalConfigHandle configHandle) { CounterSet counterSet = new CounterSet(); ConcurrentHashMap stateNameMap = new ConcurrentHashMap<>(); stateNameMap.put(NameContextsForTests.nameContextForTest().userName(), "testStateFamily"); @@ -143,6 +148,7 @@ private StreamingModeExecutionContext createExecutionContext( /*hotKeyLoggingEnabled=*/ false, /*stepName=*/ "stepName", "sourceBytesProcessCounterName", + options, SideInputStateFetcherFactory.fromOptions(options)); } @@ -150,8 +156,11 @@ private StreamingModeExecutionContext createExecutionContext( public void setUp() { MockitoAnnotations.initMocks(this); options = PipelineOptionsFactory.as(DataflowWorkerHarnessOptions.class); + options + .as(ExperimentalOptions.class) + .setExperiments(Arrays.asList("unstable_enable_multi_key_bundle")); globalConfigHandle = new FakeGlobalConfigHandle(StreamingGlobalConfig.builder().build()); - executionContext = createExecutionContext(globalConfigHandle); + executionContext = createExecutionContext(options, globalConfigHandle); } private static Work createMockWork(Windmill.WorkItem workItem, Watermarks watermarks) { @@ -449,7 +458,7 @@ public void testStateTagEncodingBasedOnConfig() { FakeGlobalConfigHandle configHandle = new FakeGlobalConfigHandle( StreamingGlobalConfig.builder().setEnableStateTagEncodingV2(isV2Encoding).build()); - StreamingModeExecutionContext context = createExecutionContext(configHandle); + StreamingModeExecutionContext context = createExecutionContext(options, 
configHandle); assertEquals(expectedEncoding, context.getWindmillTagEncoding().getClass()); } } @@ -503,6 +512,274 @@ public void testStart_internalKeyDecoding() throws Exception { assertEquals("decodedKey", executionContext.getKey()); } + @Test + public void testAdvance_success() throws Exception { + BoundedQueueExecutor mockExecutor = mock(BoundedQueueExecutor.class); + BoundedQueueExecutorWorkHandle mockHandle = mock(BoundedQueueExecutorWorkHandle.class); + + Windmill.Uint128Proto keyGroup = + Windmill.Uint128Proto.newBuilder().setHigh(1).setLow(2).build(); + Windmill.WorkItem workItem1 = + Windmill.WorkItem.newBuilder() + .setKey(ByteString.copyFromUtf8("key1")) + .setWorkToken(1L) + .setKeyGroup(keyGroup) + .build(); + Work work1 = + createMockWork( + workItem1, Watermarks.builder().setInputDataWatermark(Instant.EPOCH).build()); + + Windmill.WorkItem workItem2 = + Windmill.WorkItem.newBuilder() + .setKey(ByteString.copyFromUtf8("key2")) + .setWorkToken(2L) + .setKeyGroup(keyGroup) + .build(); + Work work2 = + createMockWork( + workItem2, Watermarks.builder().setInputDataWatermark(Instant.EPOCH).build()); + ExecutableWork executableWork2 = ExecutableWork.create(work2, (w, h) -> {}); + + org.mockito.Mockito.when( + mockExecutor.pollWork( + org.mockito.Mockito.eq(COMPUTATION_ID), + org.mockito.Mockito.eq(work1.getKeyGroup()), + org.mockito.Mockito.eq(mockHandle))) + .thenReturn(executableWork2); + + executionContext.start( + work1, workExecutor, mockExecutor, mockHandle, null, (oldWork, newWork) -> {}); + + assertTrue(executionContext.advance()); + assertEquals("key2", executionContext.getSerializedKey().toStringUtf8()); + } + + @Test + public void testAdvance_noMoreWork() throws Exception { + BoundedQueueExecutor mockExecutor = mock(BoundedQueueExecutor.class); + BoundedQueueExecutorWorkHandle mockHandle = mock(BoundedQueueExecutorWorkHandle.class); + + Windmill.Uint128Proto keyGroup = + Windmill.Uint128Proto.newBuilder().setHigh(1).setLow(2).build(); + 
Windmill.WorkItem workItem1 = + Windmill.WorkItem.newBuilder() + .setKey(ByteString.copyFromUtf8("key1")) + .setWorkToken(1L) + .setKeyGroup(keyGroup) + .build(); + Work work1 = + createMockWork( + workItem1, Watermarks.builder().setInputDataWatermark(Instant.EPOCH).build()); + + org.mockito.Mockito.when( + mockExecutor.pollWork( + org.mockito.Mockito.eq(COMPUTATION_ID), + org.mockito.Mockito.eq(work1.getKeyGroup()), + org.mockito.Mockito.eq(mockHandle))) + .thenReturn(null); + + executionContext.start( + work1, workExecutor, mockExecutor, mockHandle, null, (oldWork, newWork) -> {}); + + assertFalse(executionContext.advance()); + } + + @Test + public void testAdvance_respectsMaxBatchSize() throws Exception { + DataflowWorkerHarnessOptions optionsWithBatchSize = + PipelineOptionsFactory.as(DataflowWorkerHarnessOptions.class); + optionsWithBatchSize + .as(ExperimentalOptions.class) + /* Enable multi-key bundling as well: with that flag off advance() already returns false without polling (see testAdvance_experimentDisabled), so this test would pass vacuously and never exercise the batch-size limit. */ .setExperiments(Arrays.asList("unstable_enable_multi_key_bundle", "windmill_max_key_group_batch_size=1")); + StreamingModeExecutionContext context = + createExecutionContext(optionsWithBatchSize, globalConfigHandle); + + BoundedQueueExecutor mockExecutor = mock(BoundedQueueExecutor.class); + BoundedQueueExecutorWorkHandle mockHandle = mock(BoundedQueueExecutorWorkHandle.class); + + Windmill.Uint128Proto keyGroup = + Windmill.Uint128Proto.newBuilder().setHigh(1).setLow(2).build(); + Windmill.WorkItem workItem1 = + Windmill.WorkItem.newBuilder() + .setKey(ByteString.copyFromUtf8("key1")) + .setWorkToken(1L) + .setKeyGroup(keyGroup) + .build(); + Work work1 = + createMockWork( + workItem1, Watermarks.builder().setInputDataWatermark(Instant.EPOCH).build()); + + context.start(work1, workExecutor, mockExecutor, mockHandle, null, (oldWork, newWork) -> {}); + + assertFalse(context.advance()); + org.mockito.Mockito.verifyNoInteractions(mockExecutor); + } + + @Test + public void testAdvance_respectsMaxBatchTime() throws Exception { + DataflowWorkerHarnessOptions optionsWithBatchTime = + 
PipelineOptionsFactory.as(DataflowWorkerHarnessOptions.class); + optionsWithBatchTime + .as(ExperimentalOptions.class) + /* Enable multi-key bundling as well: with that flag off advance() already returns false without polling (see testAdvance_experimentDisabled), so this test would pass vacuously and never exercise the batch-time limit. NOTE(review): assumes a 0 ms limit counts as already expired - confirm against the elapsed-time comparison in StreamingModeExecutionContext. */ .setExperiments(Arrays.asList("unstable_enable_multi_key_bundle", "windmill_max_key_group_batch_time_ms=0")); + StreamingModeExecutionContext context = + createExecutionContext(optionsWithBatchTime, globalConfigHandle); + + BoundedQueueExecutor mockExecutor = mock(BoundedQueueExecutor.class); + BoundedQueueExecutorWorkHandle mockHandle = mock(BoundedQueueExecutorWorkHandle.class); + + Windmill.Uint128Proto keyGroup = + Windmill.Uint128Proto.newBuilder().setHigh(1).setLow(2).build(); + Windmill.WorkItem workItem1 = + Windmill.WorkItem.newBuilder() + .setKey(ByteString.copyFromUtf8("key1")) + .setWorkToken(1L) + .setKeyGroup(keyGroup) + .build(); + Work work1 = + createMockWork( + workItem1, Watermarks.builder().setInputDataWatermark(Instant.EPOCH).build()); + + context.start(work1, workExecutor, mockExecutor, mockHandle, null, (oldWork, newWork) -> {}); + + assertFalse(context.advance()); + org.mockito.Mockito.verifyNoInteractions(mockExecutor); + } + + @Test + public void testAdvance_workFailed() throws Exception { + BoundedQueueExecutor mockExecutor = mock(BoundedQueueExecutor.class); + BoundedQueueExecutorWorkHandle mockHandle = mock(BoundedQueueExecutorWorkHandle.class); + + Windmill.Uint128Proto keyGroup = + Windmill.Uint128Proto.newBuilder().setHigh(1).setLow(2).build(); + Windmill.WorkItem workItem1 = + Windmill.WorkItem.newBuilder() + .setKey(ByteString.copyFromUtf8("key1")) + .setWorkToken(1L) + .setKeyGroup(keyGroup) + .build(); + Work work1 = + createMockWork( + workItem1, Watermarks.builder().setInputDataWatermark(Instant.EPOCH).build()); + + executionContext.start( + work1, workExecutor, mockExecutor, mockHandle, null, (oldWork, newWork) -> {}); + + work1.setFailed(); + + assertThrows(WorkItemCancelledException.class, () -> executionContext.advance()); + } + + @Test + public void testAdvance_defaultKeyGroup() throws Exception { + BoundedQueueExecutor mockExecutor = 
mock(BoundedQueueExecutor.class); + BoundedQueueExecutorWorkHandle mockHandle = mock(BoundedQueueExecutorWorkHandle.class); + + Windmill.WorkItem workItem1 = + Windmill.WorkItem.newBuilder() + .setKey(ByteString.copyFromUtf8("key1")) + .setWorkToken(1L) + .build(); + Work work1 = + createMockWork( + workItem1, Watermarks.builder().setInputDataWatermark(Instant.EPOCH).build()); + + executionContext.start( + work1, workExecutor, mockExecutor, mockHandle, null, (oldWork, newWork) -> {}); + + assertFalse(executionContext.advance()); + org.mockito.Mockito.verifyNoInteractions(mockExecutor); + } + + @Test + public void testAdvance_experimentDisabled() throws Exception { + DataflowWorkerHarnessOptions optionsDisabled = + PipelineOptionsFactory.as(DataflowWorkerHarnessOptions.class); + StreamingModeExecutionContext context = + createExecutionContext(optionsDisabled, globalConfigHandle); + + BoundedQueueExecutor mockExecutor = mock(BoundedQueueExecutor.class); + BoundedQueueExecutorWorkHandle mockHandle = mock(BoundedQueueExecutorWorkHandle.class); + + Windmill.Uint128Proto keyGroup = + Windmill.Uint128Proto.newBuilder().setHigh(1).setLow(2).build(); + Windmill.WorkItem workItem1 = + Windmill.WorkItem.newBuilder() + .setKey(ByteString.copyFromUtf8("key1")) + .setWorkToken(1L) + .setKeyGroup(keyGroup) + .build(); + Work work1 = + createMockWork( + workItem1, Watermarks.builder().setInputDataWatermark(Instant.EPOCH).build()); + + context.start(work1, workExecutor, mockExecutor, mockHandle, null, (oldWork, newWork) -> {}); + + assertFalse(context.advance()); + org.mockito.Mockito.verifyNoInteractions(mockExecutor); + } + + @Test + public void testAdvance_respectsMaxBatchSinkBytes() throws Exception { + DataflowWorkerHarnessOptions optionsWithSinkBytes = + PipelineOptionsFactory.as(DataflowWorkerHarnessOptions.class); + optionsWithSinkBytes + .as(ExperimentalOptions.class) + .setExperiments( + Arrays.asList( + "unstable_enable_multi_key_bundle", 
"windmill_max_key_group_batch_sink_bytes=100")); + StreamingModeExecutionContext context = + createExecutionContext(optionsWithSinkBytes, globalConfigHandle); + + BoundedQueueExecutor mockExecutor = mock(BoundedQueueExecutor.class); + BoundedQueueExecutorWorkHandle mockHandle = mock(BoundedQueueExecutorWorkHandle.class); + + Windmill.Uint128Proto keyGroup = + Windmill.Uint128Proto.newBuilder().setHigh(1).setLow(2).build(); + Windmill.WorkItem workItem1 = + Windmill.WorkItem.newBuilder() + .setKey(ByteString.copyFromUtf8("key1")) + .setWorkToken(1L) + .setKeyGroup(keyGroup) + .build(); + Work work1 = + createMockWork( + workItem1, Watermarks.builder().setInputDataWatermark(Instant.EPOCH).build()); + + context.start(work1, workExecutor, mockExecutor, mockHandle, null, (oldWork, newWork) -> {}); + + context.reportBytesSinked(50); + assertFalse(context.advance()); + org.mockito.Mockito.verify(mockExecutor) + .pollWork(COMPUTATION_ID, work1.getKeyGroup(), mockHandle); + + org.mockito.Mockito.reset(mockExecutor); + + context.reportBytesSinked(60); + assertFalse(context.advance()); + org.mockito.Mockito.verifyNoInteractions(mockExecutor); + } + + @Test + public void testExperimentParsingWithInvalidValues() { + DataflowWorkerHarnessOptions optionsInvalid = + PipelineOptionsFactory.as(DataflowWorkerHarnessOptions.class); + optionsInvalid + .as(ExperimentalOptions.class) + .setExperiments( + Arrays.asList( + "windmill_max_key_group_batch_size=invalid_size", + "windmill_max_key_group_batch_time_ms=invalid_time", + "windmill_max_key_group_batch_sink_bytes=invalid_bytes")); + + // This should not throw NumberFormatException + StreamingModeExecutionContext context = + createExecutionContext(optionsInvalid, globalConfigHandle); + + org.junit.Assert.assertNotNull(context); + } + @Test public void testInternalsPoisonedAfterFlushState() throws Exception { NameContext nameContext = NameContextsForTests.nameContextForTest(); diff --git 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java index 31ea1bab07af..32fb93559e38 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java @@ -102,7 +102,6 @@ import org.apache.beam.runners.dataflow.worker.windmill.Windmill; import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.FakeGetDataClient; import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateCache; -import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateReader; import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.coders.BigEndianIntegerCoder; @@ -642,6 +641,7 @@ public void testReadUnboundedReader() throws Exception { /*hotKeyLoggingEnabled=*/ false, /*stepName=*/ "stepName", "sourceBytesProcessCounterName", + options, SideInputStateFetcherFactory.fromOptions(options)); options.setNumWorkers(5); @@ -1015,6 +1015,7 @@ public void testFailedWorkItemsAbort() throws Exception { /*hotKeyLoggingEnabled=*/ false, /*stepName=*/ "stepName", "sourceBytesProcessCounterName", + options, SideInputStateFetcherFactory.fromOptions(options)); options.setNumWorkers(5); diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkStateTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkStateTest.java index 0f14efdd0c0b..60d7bb71a9de 100644 --- 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkStateTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkStateTest.java @@ -565,6 +565,31 @@ public void testFailWork_batchFail() { } } + @Test + public void testGetActiveWork() { + ShardedKey shardedKey = shardedKey("someKey", 1L); + ExecutableWork work = createWork(createWorkItem(1L, 1L, shardedKey)); + + // Initially empty + assertFalse(activeWorkState.getActiveWork(shardedKey, work.id()).isPresent()); + + // Activate work + activeWorkState.activateWorkForKey(work); + + // Should find it now + Optional activeWork = activeWorkState.getActiveWork(shardedKey, work.id()); + assertTrue(activeWork.isPresent()); + assertSame(work, activeWork.get()); + + // Should not find it with different workId + assertFalse(activeWorkState.getActiveWork(shardedKey, workId(2L, 1L)).isPresent()); + assertFalse(activeWorkState.getActiveWork(shardedKey, workId(1L, 2L)).isPresent()); + + // Should not find it with different shardedKey + ShardedKey otherShardedKey = shardedKey("otherKey", 2L); + assertFalse(activeWorkState.getActiveWork(otherShardedKey, work.id()).isPresent()); + } + private static ExecutableWork firstValue(Map map) { Iterator> iterator = map.entrySet().iterator(); if (iterator.hasNext()) { diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationStateTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationStateTest.java new file mode 100644 index 000000000000..6a6edd2b7192 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationStateTest.java @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor 
license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.streaming; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.verifyNoMoreInteractions; + +import com.google.api.services.dataflow.model.MapTask; +import java.util.Collections; +import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill; +import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.FakeGetDataClient; +import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateCache; +import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender; +import org.apache.beam.vendor.grpc.v1p69p0.com.google.protobuf.ByteString; +import org.joda.time.Instant; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class ComputationStateTest { + + private final BoundedQueueExecutor mockExecutor = mock(BoundedQueueExecutor.class); + private final WindmillStateCache.ForComputation mockStateCache = + mock(WindmillStateCache.ForComputation.class); + private final 
HeartbeatSender mockHeartbeatSender = mock(HeartbeatSender.class); + + private ComputationState computationState; + + private static ShardedKey shardedKey(String str, long shardKey) { + return ShardedKey.create(ByteString.copyFromUtf8(str), shardKey); + } + + private ExecutableWork createWork(Windmill.WorkItem workItem) { + return ExecutableWork.create( + Work.create( + workItem, + workItem.getSerializedSize(), + Watermarks.builder().setInputDataWatermark(Instant.EPOCH).build(), + Work.createProcessingContext( + "computationId", new FakeGetDataClient(), ignored -> {}, mockHeartbeatSender), + false, + Instant::now), + (work, handle) -> {}); + } + + private static Windmill.WorkItem createWorkItem( + long workToken, long cacheToken, ShardedKey shardedKey) { + return Windmill.WorkItem.newBuilder() + .setShardingKey(shardedKey.shardingKey()) + .setKey(shardedKey.key()) + .setWorkToken(workToken) + .setCacheToken(cacheToken) + .build(); + } + + @Before + public void setUp() { + MapTask mapTask = new MapTask(); + mapTask.setStageName("stage"); + mapTask.setSystemName("system"); + computationState = + new ComputationState( + "computationId", mapTask, mockExecutor, Collections.emptyMap(), mockStateCache); + } + + @Test + public void testReExecuteActiveWork_workNotActive() { + ShardedKey shardedKey = shardedKey("key", 1L); + WorkId workId = WorkId.builder().setWorkToken(1L).setCacheToken(1L).build(); + + computationState.reExecuteActiveWork(shardedKey, workId); + + verifyNoInteractions(mockExecutor); + } + + @Test + public void testReExecuteActiveWork_workActive() { + ShardedKey shardedKey = shardedKey("key", 1L); + Windmill.WorkItem workItem = createWorkItem(1L, 1L, shardedKey); + ExecutableWork work = createWork(workItem); + + // Activate work first. This will execute it once. 
+ computationState.activateWork(work); + verify(mockExecutor).execute(work, work.work().getSerializedWorkItemSize()); + + // Now re-execute + computationState.reExecuteActiveWork(shardedKey, work.id()); + verify(mockExecutor).forceExecute(work, work.work().getSerializedWorkItemSize()); + + verifyNoMoreInteractions(mockExecutor); + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutorTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutorTest.java index a98102751fb2..245d600448fe 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutorTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutorTest.java @@ -30,7 +30,10 @@ import java.util.Collection; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.BiConsumer; import java.util.function.Consumer; +import org.apache.beam.runners.dataflow.worker.streaming.BoundedQueueExecutorWorkHandle; import org.apache.beam.runners.dataflow.worker.streaming.ExecutableWork; import org.apache.beam.runners.dataflow.worker.streaming.Watermarks; import org.apache.beam.runners.dataflow.worker.streaming.Work; @@ -82,6 +85,14 @@ private static ExecutableWork createWorkWithCompId( private static ExecutableWork createWorkWithCompIdAndKeyGroup( String computationId, Work.KeyGroup keyGroup, Consumer executeWorkFn) { + return createWorkWithHandle( + computationId, keyGroup, (work, handle) -> executeWorkFn.accept(work)); + } + + private static ExecutableWork createWorkWithHandle( + String computationId, + Work.KeyGroup keyGroup, + BiConsumer executeWorkFn) { WorkItem workItem = WorkItem.newBuilder() 
.setKey(ByteString.EMPTY) @@ -103,9 +114,7 @@ private static ExecutableWork createWorkWithCompIdAndKeyGroup( computationId, new FakeGetDataClient(), ignored -> {}, mock(HeartbeatSender.class)), false, Instant::now), - (work, handle) -> { - executeWorkFn.accept(work); - }); + executeWorkFn); } private ExecutableWork createSleepProcessWork(CountDownLatch start, CountDownLatch stop) { @@ -406,18 +415,25 @@ public void testRunnableExceptionPropagationDecrementsCounters() throws Exceptio @Test public void testHandleMerge() throws Exception { - BoundedQueueExecutorWorkHandleImpl handle1 = executor.createBudgetHandle(1, 100L); - BoundedQueueExecutorWorkHandleImpl handle2 = executor.createBudgetHandle(2, 200L); + Work work1 = createWork(ignored -> {}).work(); + Work work2 = createWork(ignored -> {}).work(); + Work work3 = createWork(ignored -> {}).work(); + BoundedQueueExecutorWorkHandleImpl handle1 = executor.createBudgetHandle(work1, 100L); + BoundedQueueExecutorWorkHandleImpl handle2 = executor.createBudgetHandle(work2, 200L); + handle2.merge(executor.createBudgetHandle(work3, 0L)); handle1.merge(handle2); // Verify that handle2 has 0 budget and is closed. - assertEquals(0, handle2.elements()); + assertEquals(0, handle2.getWorkBatch().size()); assertEquals(0, handle2.bytes()); assertTrue(handle2.isClosed()); // Verify that handle1 has the combined budget and is not closed. - assertEquals(3, handle1.elements()); + assertEquals(3, handle1.getWorkBatch().size()); + assertTrue(handle1.getWorkBatch().contains(work1)); + assertTrue(handle1.getWorkBatch().contains(work2)); + assertTrue(handle1.getWorkBatch().contains(work3)); assertEquals(300L, handle1.bytes()); assertFalse(handle1.isClosed()); } @@ -449,11 +465,13 @@ public void testPollWork() throws Exception { // 1. 
Create blocker task to occupy the worker thread CountDownLatch blockerStart = new CountDownLatch(1); CountDownLatch blockerStop = new CountDownLatch(1); + AtomicReference blockerHandleRef = new AtomicReference<>(); ExecutableWork blockerWork = - createWorkWithCompIdAndKeyGroup( + createWorkWithHandle( "blockerComp", DEFAULT_KEY_GROUP, - ignored -> { + (work, handle) -> { + blockerHandleRef.set(handle); blockerStart.countDown(); try { blockerStop.await(); @@ -464,6 +482,9 @@ public void testPollWork() throws Exception { testExecutor.execute(blockerWork, 0); blockerStart.await(); + BoundedQueueExecutorWorkHandleImpl stealHandle = + (BoundedQueueExecutorWorkHandleImpl) blockerHandleRef.get(); + assertNotNull(stealHandle); // 2. Create two distinct key groups Work.KeyGroup keyGroup1 = Work.KeyGroup.create(1, 1); @@ -488,22 +509,18 @@ public void testPollWork() throws Exception { assertEquals(3, testExecutor.elementsOutstanding()); // Steal work2 using pollWork with compA and keyGroup2 - try (BoundedQueueExecutorWorkHandleImpl stealHandle = testExecutor.createBudgetHandle(0, 0L)) { - ExecutableWork stolen = testExecutor.pollWork("compA", keyGroup2, stealHandle); - assertNotNull(stolen); - assertEquals(work2, stolen); - - // Run the stolen task - stolen.run(stealHandle); - targetStart.await(); - } + ExecutableWork stolen = testExecutor.pollWork("compA", keyGroup2, stealHandle); + assertNotNull(stolen); + assertEquals(work2, stolen); + + // Run the stolen task + stolen.run(stealHandle); + targetStart.await(); // Steal work1 using pollWork with compA and keyGroup1 - try (BoundedQueueExecutorWorkHandleImpl stealHandle = testExecutor.createBudgetHandle(0, 0L)) { - ExecutableWork stolen = testExecutor.pollWork("compA", keyGroup1, stealHandle); - assertNotNull(stolen); - assertEquals(work1, stolen); - } + ExecutableWork stolen1 = testExecutor.pollWork("compA", keyGroup1, stealHandle); + assertNotNull(stolen1); + assertEquals(work1, stolen1); // Unblock the blocker and shut 
down blockerStop.countDown(); @@ -525,11 +542,13 @@ public void testPollWorkWithLinkedBlockingQueue() throws Exception { CountDownLatch blockerStart = new CountDownLatch(1); CountDownLatch blockerStop = new CountDownLatch(1); + AtomicReference blockerHandleRef = new AtomicReference<>(); ExecutableWork blockerWork = - createWorkWithCompIdAndKeyGroup( + createWorkWithHandle( "blockerComp", DEFAULT_KEY_GROUP, - ignored -> { + (work, handle) -> { + blockerHandleRef.set(handle); blockerStart.countDown(); try { blockerStop.await(); @@ -540,15 +559,16 @@ public void testPollWorkWithLinkedBlockingQueue() throws Exception { testExecutor.execute(blockerWork, 0); blockerStart.await(); + BoundedQueueExecutorWorkHandleImpl stealHandle = + (BoundedQueueExecutorWorkHandleImpl) blockerHandleRef.get(); + assertNotNull(stealHandle); Work.KeyGroup keyGroup = Work.KeyGroup.create(1, 1); ExecutableWork work = createWorkWithCompIdAndKeyGroup("compA", keyGroup, ignored -> {}); testExecutor.execute(work, 100); - try (BoundedQueueExecutorWorkHandleImpl stealHandle = testExecutor.createBudgetHandle(0, 0L)) { - ExecutableWork stolen = testExecutor.pollWork("compA", keyGroup, stealHandle); - assertNull(stolen); - } + ExecutableWork stolen = testExecutor.pollWork("compA", keyGroup, stealHandle); + assertNull(stolen); blockerStop.countDown(); testExecutor.shutdown(); diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/KeyGroupWorkQueueTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/KeyGroupWorkQueueTest.java index 994aa2030f3f..307cbde36989 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/KeyGroupWorkQueueTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/KeyGroupWorkQueueTest.java @@ -63,7 +63,6 @@ public static Iterable data() { } 
@Parameterized.Parameter public boolean fairQueue; - private BoundedQueueExecutor executor; @Before @@ -116,7 +115,7 @@ private QueuedWork createQueuedWork( false, Instant::now), (w, h) -> {}); - return new QueuedWork(work, executor.createBudgetHandle(1, workBytes)); + return new QueuedWork(work, executor.createBudgetHandle(work.work(), workBytes)); } private static class NoOpRunnable implements Runnable { @@ -312,7 +311,6 @@ public String toString() { } })); } - // Start producers for (int i = 0; i < producerThreads; i++) { futures.add( @@ -470,7 +468,6 @@ public void testPollWorkWithKeyGroup() { QueuedWork polledNotExist = queue.pollWork("compA", keyGroupNotExist); assertNull(polledNotExist); assertEquals(2, queue.size()); - // Poll with keyGroup2 first - should return workA2 QueuedWork polledA2 = queue.pollWork("compA", keyGroup2); assertNotNull(polledA2); @@ -485,7 +482,6 @@ public void testPollWorkWithKeyGroup() { assertNotNull(polledA1); assertEquals(workA1, polledA1); assertTrue(queue.isEmpty()); - polledNotExist = queue.pollWork("compA", keyGroupNotExist); assertNull(polledNotExist); assertTrue(queue.isEmpty()); diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitterTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitterTest.java index 5c3132ae471d..3da740d53361 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitterTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitterTest.java @@ -129,9 +129,9 @@ public void testCommit() { for (Commit commit : commits) { Windmill.WorkItemCommitRequest request = - 
committed.get(commit.work().getWorkItem().getWorkToken()); + committed.get(commit.workBatch().get(0).getWorkItem().getWorkToken()); assertNotNull(request); - assertThat(request).isEqualTo(commit.request()); + assertThat(request).isEqualTo(commit.singleKeyRequest().get()); } assertThat(completeCommits).hasSize(commits.size()); @@ -141,12 +141,13 @@ public void testCommit() { (CompleteCommit completeCommit, Commit commit) -> completeCommit.computationId().equals(commit.computationId()) && completeCommit.status() == Windmill.CommitStatus.OK - && completeCommit.workId().equals(commit.work().id()) + && completeCommit.workId().equals(commit.workBatch().get(0).id()) && completeCommit .shardedKey() .equals( ShardedKey.create( - commit.request().getKey(), commit.request().getShardingKey())), + commit.singleKeyRequest().get().getKey(), + commit.singleKeyRequest().get().getShardingKey())), "expected to equal")) .containsExactlyElementsIn(commits); } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java index 01197622c24d..a48159338132 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java @@ -53,6 +53,7 @@ import org.apache.beam.runners.dataflow.worker.streaming.WorkId; import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.CommitStatus; import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItem; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItemCommitRequest; import org.apache.beam.runners.dataflow.worker.windmill.client.CloseableStream; @@ -62,6 +63,7 @@ import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender; import org.apache.beam.vendor.grpc.v1p69p0.com.google.protobuf.ByteString; import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.testing.GrpcCleanupRule; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.joda.time.Duration; import org.joda.time.Instant; @@ -134,12 +136,11 @@ private static ComputationState createComputationState(String computationId) { null); } - private static CompleteCommit asCompleteCommit(Commit commit, Windmill.CommitStatus status) { - if (commit.work().isFailed()) { - return CompleteCommit.forFailedWork(commit); - } - - return CompleteCommit.create(commit, status); + private static CompleteCommit asCompleteCommit( + String computationId, Work work, Windmill.CommitStatus status) { + Windmill.CommitStatus finalStatus = work.isFailed() ? 
Windmill.CommitStatus.ABORTED : status; + return CompleteCommit.create( + computationId, work.getShardedKey(), work.id(), finalStatus, /* retryableFailure= */ false); } @Before @@ -186,10 +187,14 @@ public void testCommit_sendsCommitsToStreamingEngine() { waitForExpectedSetSize(completeCommits, 5); for (Commit commit : commits) { - WorkItemCommitRequest request = committed.get(commit.work().getWorkItem().getWorkToken()); + WorkItemCommitRequest request = + committed.get(commit.workBatch().get(0).getWorkItem().getWorkToken()); assertNotNull(request); - assertThat(request).isEqualTo(commit.request()); - assertThat(completeCommits).contains(asCompleteCommit(commit, Windmill.CommitStatus.OK)); + assertThat(request).isEqualTo(commit.singleKeyRequest().get()); + assertThat(completeCommits) + .contains( + asCompleteCommit( + commit.computationId(), commit.workBatch().get(0), Windmill.CommitStatus.OK)); } workCommitter.stop(); @@ -224,14 +229,24 @@ public void testCommit_handlesFailedCommits() { waitForExpectedSetSize(completeCommits, 10); for (Commit commit : commits) { - if (commit.work().isFailed()) { + if (commit.isFailed()) { assertThat(completeCommits) - .contains(asCompleteCommit(commit, Windmill.CommitStatus.ABORTED)); - assertThat(committed).doesNotContainKey(commit.work().getWorkItem().getWorkToken()); + .contains( + asCompleteCommit( + commit.computationId(), + commit.workBatch().get(0), + Windmill.CommitStatus.ABORTED)); + assertThat(committed) + .doesNotContainKey(commit.workBatch().get(0).getWorkItem().getWorkToken()); } else { - assertThat(completeCommits).contains(asCompleteCommit(commit, Windmill.CommitStatus.OK)); + assertThat(completeCommits) + .contains( + asCompleteCommit( + commit.computationId(), commit.workBatch().get(0), Windmill.CommitStatus.OK)); assertThat(committed) - .containsEntry(commit.work().getWorkItem().getWorkToken(), commit.request()); + .containsEntry( + commit.workBatch().get(0).getWorkItem().getWorkToken(), + 
commit.singleKeyRequest().get()); } } @@ -282,11 +297,16 @@ public void testCommit_handlesCompleteCommits_commitStatusNotOK() { waitForExpectedSetSize(completeCommits, commits.size()); for (Commit commit : commits) { - WorkItemCommitRequest request = committed.get(commit.work().getWorkItem().getWorkToken()); + WorkItemCommitRequest request = + committed.get(commit.workBatch().get(0).getWorkItem().getWorkToken()); assertNotNull(request); - assertThat(request).isEqualTo(commit.request()); + assertThat(request).isEqualTo(commit.singleKeyRequest().get()); assertThat(completeCommits) - .contains(asCompleteCommit(commit, expectedCommitStatus.get(commit.work().id()))); + .contains( + asCompleteCommit( + commit.computationId(), + commit.workBatch().get(0), + expectedCommitStatus.get(commit.workBatch().get(0).id()))); } workCommitter.stop(); @@ -313,6 +333,14 @@ public boolean commitWorkItem( return false; } + @Override + public boolean commitMultiKeyWorkItem( + String computation, + Windmill.MultiKeyWorkItemCommitRequest request, + Consumer onDone) { + return false; + } + @Override public void flush() {} }; @@ -367,10 +395,11 @@ public void shutdown() {} assertThat(commits.size()).isEqualTo(completeCommits.size()); for (CompleteCommit completeCommit : completeCommits) { assertThat(completeCommit.status()).isEqualTo(Windmill.CommitStatus.ABORTED); + assertThat(completeCommit.retryableFailure()).isFalse(); } for (Commit commit : commits) { - assertTrue(commit.work().isFailed()); + assertTrue(commit.isFailed()); } } @@ -409,10 +438,14 @@ public void testMultipleCommitSendersSingleStream() { waitForExpectedSetSize(completeCommits, commits.size()); for (Commit commit : commits) { - WorkItemCommitRequest request = committed.get(commit.work().getWorkItem().getWorkToken()); + WorkItemCommitRequest request = + committed.get(commit.workBatch().get(0).getWorkItem().getWorkToken()); assertNotNull(request); - assertThat(request).isEqualTo(commit.request()); - 
assertThat(completeCommits).contains(asCompleteCommit(commit, Windmill.CommitStatus.OK)); + assertThat(request).isEqualTo(commit.singleKeyRequest().get()); + assertThat(completeCommits) + .contains( + asCompleteCommit( + commit.computationId(), commit.workBatch().get(0), Windmill.CommitStatus.OK)); } workCommitter.stop(); @@ -474,4 +507,242 @@ public void testStop_drainsCommitQueue_concurrentCommit() waitForExpectedSetSize(completeCommits, sentCommits.intValue()); } + + @Test + public void testCommit_multiKeyCommitFailedWork() { + Set completeCommits = Collections.newSetFromMap(new ConcurrentHashMap<>()); + workCommitter = createWorkCommitter(completeCommits::add); + + Work workA = createMockWork(101L); + Work workB = createMockWork(102L); + Work workC = createMockWork(103L); + + // Mark non-primary key B as failed + workB.setFailed(); + + Windmill.MultiKeyWorkItemCommitRequest multiKeyRequest = + Windmill.MultiKeyWorkItemCommitRequest.newBuilder() + .addRequests( + Windmill.WorkItemCommitRequest.newBuilder() + .setKey(workA.getWorkItem().getKey()) + .setShardingKey(workA.getWorkItem().getShardingKey()) + .setWorkToken(workA.getWorkItem().getWorkToken()) + .setCacheToken(workA.getWorkItem().getCacheToken()) + .build()) + .addRequests( + Windmill.WorkItemCommitRequest.newBuilder() + .setKey(workB.getWorkItem().getKey()) + .setShardingKey(workB.getWorkItem().getShardingKey()) + .setWorkToken(workB.getWorkItem().getWorkToken()) + .setCacheToken(workB.getWorkItem().getCacheToken()) + .build()) + .addRequests( + Windmill.WorkItemCommitRequest.newBuilder() + .setKey(workC.getWorkItem().getKey()) + .setShardingKey(workC.getWorkItem().getShardingKey()) + .setWorkToken(workC.getWorkItem().getWorkToken()) + .setCacheToken(workC.getWorkItem().getCacheToken()) + .build()) + .build(); + + Commit commit = + Commit.createMultiKey( + multiKeyRequest, + createComputationState("computationId"), + ImmutableList.of(workA, workB, workC)); + + workCommitter.start(); + 
workCommitter.commit(commit); + + // The entire batch must be aborted immediately without making network calls + waitForExpectedSetSize(completeCommits, 3); + + // Verify all three works are aborted individually + assertThat(completeCommits) + .containsExactly( + CompleteCommit.create( + "computationId", + workA.getShardedKey(), + workA.id(), + CommitStatus.ABORTED, + /* retryableFailure= */ true), + CompleteCommit.create( + "computationId", + workB.getShardedKey(), + workB.id(), + CommitStatus.ABORTED, + /* retryableFailure= */ false), + CompleteCommit.create( + "computationId", + workC.getShardedKey(), + workC.id(), + CommitStatus.ABORTED, + /* retryableFailure= */ true)); + + // Verify that valid work was not marked failed + assertThat(workA.isFailed()).isFalse(); + assertThat(workC.isFailed()).isFalse(); + assertThat(workB.isFailed()).isTrue(); + + workCommitter.stop(); + } + + @Test + public void testCommit_multiKeyCommitSuccess() { + Set completeCommits = Collections.newSetFromMap(new ConcurrentHashMap<>()); + workCommitter = createWorkCommitter(completeCommits::add); + + Work workA = createMockWork(101L); + Work workB = createMockWork(102L); + Work workC = createMockWork(103L); + + Windmill.MultiKeyWorkItemCommitRequest multiKeyRequest = + Windmill.MultiKeyWorkItemCommitRequest.newBuilder() + .addRequests( + Windmill.WorkItemCommitRequest.newBuilder() + .setKey(workA.getWorkItem().getKey()) + .setShardingKey(workA.getWorkItem().getShardingKey()) + .setWorkToken(workA.getWorkItem().getWorkToken()) + .setCacheToken(workA.getWorkItem().getCacheToken()) + .build()) + .addRequests( + Windmill.WorkItemCommitRequest.newBuilder() + .setKey(workB.getWorkItem().getKey()) + .setShardingKey(workB.getWorkItem().getShardingKey()) + .setWorkToken(workB.getWorkItem().getWorkToken()) + .setCacheToken(workB.getWorkItem().getCacheToken()) + .build()) + .addRequests( + Windmill.WorkItemCommitRequest.newBuilder() + .setKey(workC.getWorkItem().getKey()) + 
.setShardingKey(workC.getWorkItem().getShardingKey()) + .setWorkToken(workC.getWorkItem().getWorkToken()) + .setCacheToken(workC.getWorkItem().getCacheToken()) + .build()) + .build(); + + Commit commit = + Commit.createMultiKey( + multiKeyRequest, + createComputationState("computationId"), + ImmutableList.of(workA, workB, workC)); + + workCommitter.start(); + workCommitter.commit(commit); + + // Wait for the server to receive and process the commits + fakeWindmillServer.waitForAndGetCommits(3); + waitForExpectedSetSize(completeCommits, 3); + + // Verify that FakeWindmillServer received all 3 work requests in multiKeyCommitsReceived + List multiKeyCommits = + fakeWindmillServer.getMultiKeyCommitsReceived(); + assertThat(multiKeyCommits).hasSize(1); + assertThat(multiKeyCommits.get(0)).isEqualTo(multiKeyRequest); + + // Verify all three works are completed successfully + assertThat(completeCommits) + .containsExactly( + CompleteCommit.create( + "computationId", + workA.getShardedKey(), + workA.id(), + CommitStatus.OK, + /* retryableFailure= */ false), + CompleteCommit.create( + "computationId", + workB.getShardedKey(), + workB.id(), + CommitStatus.OK, + /* retryableFailure= */ false), + CompleteCommit.create( + "computationId", + workC.getShardedKey(), + workC.id(), + CommitStatus.OK, + /* retryableFailure= */ false)); + + workCommitter.stop(); + } + + @Test + public void testCommit_multiKeyCommitStatusNotOK() { + Set completeCommits = Collections.newSetFromMap(new ConcurrentHashMap<>()); + workCommitter = createWorkCommitter(completeCommits::add); + + Work workA = createMockWork(101L); + Work workB = createMockWork(102L); + Work workC = createMockWork(103L); + + Windmill.MultiKeyWorkItemCommitRequest multiKeyRequest = + Windmill.MultiKeyWorkItemCommitRequest.newBuilder() + .addRequests( + Windmill.WorkItemCommitRequest.newBuilder() + .setKey(workA.getWorkItem().getKey()) + .setShardingKey(workA.getWorkItem().getShardingKey()) + 
.setWorkToken(workA.getWorkItem().getWorkToken()) + .setCacheToken(workA.getWorkItem().getCacheToken()) + .build()) + .addRequests( + Windmill.WorkItemCommitRequest.newBuilder() + .setKey(workB.getWorkItem().getKey()) + .setShardingKey(workB.getWorkItem().getShardingKey()) + .setWorkToken(workB.getWorkItem().getWorkToken()) + .setCacheToken(workB.getWorkItem().getCacheToken()) + .build()) + .addRequests( + Windmill.WorkItemCommitRequest.newBuilder() + .setKey(workC.getWorkItem().getKey()) + .setShardingKey(workC.getWorkItem().getShardingKey()) + .setWorkToken(workC.getWorkItem().getWorkToken()) + .setCacheToken(workC.getWorkItem().getCacheToken()) + .build()) + .build(); + + Commit commit = + Commit.createMultiKey( + multiKeyRequest, + createComputationState("computationId"), + ImmutableList.of(workA, workB, workC)); + + // Offer NOT_FOUND status for one of the works. + fakeWindmillServer.whenCommitWorkStreamCalled().put(workB.id(), CommitStatus.NOT_FOUND); + + workCommitter.start(); + workCommitter.commit(commit); + + // Wait for the server to receive and process the commits + fakeWindmillServer.waitForAndGetCommits(3); + waitForExpectedSetSize(completeCommits, 3); + + // Verify that FakeWindmillServer received the multi-key commit + List multiKeyCommits = + fakeWindmillServer.getMultiKeyCommitsReceived(); + assertThat(multiKeyCommits).hasSize(1); + assertThat(multiKeyCommits.get(0)).isEqualTo(multiKeyRequest); + + // Verify all three works in the multi-key commit are completed with NOT_FOUND status + assertThat(completeCommits) + .containsExactly( + CompleteCommit.create( + "computationId", + workA.getShardedKey(), + workA.id(), + CommitStatus.NOT_FOUND, + /* retryableFailure= */ false), + CompleteCommit.create( + "computationId", + workB.getShardedKey(), + workB.id(), + CommitStatus.NOT_FOUND, + /* retryableFailure= */ false), + CompleteCommit.create( + "computationId", + workC.getShardedKey(), + workC.id(), + CommitStatus.NOT_FOUND, + /* retryableFailure= */ 
false)); + + workCommitter.stop(); + } } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStreamTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStreamTest.java index 9c3d5c9c3ef3..6c5c87aa33a0 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStreamTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStreamTest.java @@ -1134,6 +1134,264 @@ public void testCommitWorkItem_multiplePhysicalStreams_multipleHandovers_halfClo assertTrue(commitWorkStream.awaitTermination(10, TimeUnit.SECONDS)); } + @Test + public void testCommit_multiKeyCommit() throws Exception { + GrpcCommitWorkStream commitWorkStream = createCommitWorkStream(); + FakeWindmillGrpcService.CommitStreamInfo streamInfo = waitForConnectionAndConsumeHeader(); + + CompletableFuture commitStatusFuture = new CompletableFuture<>(); + + // 1. Construct two individual WorkItemCommitRequests + long shardingKey1 = 101L; + long workToken1 = 201L; + long cacheToken1 = 301L; + long shardingKey2 = 102L; + long workToken2 = 202L; + long cacheToken2 = 302L; + Windmill.WorkItemCommitRequest request1 = + Windmill.WorkItemCommitRequest.newBuilder() + .setKey(ByteString.copyFromUtf8("key1")) + .setShardingKey(shardingKey1) + .setWorkToken(workToken1) + .setCacheToken(cacheToken1) + .build(); + Windmill.WorkItemCommitRequest request2 = + Windmill.WorkItemCommitRequest.newBuilder() + .setKey(ByteString.copyFromUtf8("key2")) + .setShardingKey(shardingKey2) + .setWorkToken(workToken2) + .setCacheToken(cacheToken2) + .build(); + + // 2. 
Wrap them into a MultiKeyWorkItemCommitRequest + Windmill.MultiKeyWorkItemCommitRequest multiKeyRequest = + Windmill.MultiKeyWorkItemCommitRequest.newBuilder() + .addRequests(request1) + .addRequests(request2) + .build(); + + // 3. Commit the multi-key work item using the request batcher + try (WindmillStream.CommitWorkStream.RequestBatcher batcher = commitWorkStream.batcher()) { + assertTrue( + batcher.commitMultiKeyWorkItem( + COMPUTATION_ID, multiKeyRequest, commitStatusFuture::complete)); + } + + // 4. Receive and assert request properties on FakeWindmillGrpcService + Windmill.StreamingCommitWorkRequest request = streamInfo.requests.take(); + assertThat(request.getCommitChunkCount()).isEqualTo(1); + + Windmill.StreamingCommitRequestChunk chunk = request.getCommitChunk(0); + + // Assert that the commit type is correctly identified as COMMIT_TYPE_MULTI_KEY + assertThat(chunk.getCommitType()) + .isEqualTo(Windmill.StreamingCommitRequestChunk.CommitType.COMMIT_TYPE_MULTI_KEY); + + // Assert that the routing sharding key is mapped to the first request's sharding key + assertThat(chunk.getShardingKey()).isEqualTo(request1.getShardingKey()); + + // Assert that the serialized payload matches the input multiKeyRequest + Windmill.MultiKeyWorkItemCommitRequest parsedRequest = + Windmill.MultiKeyWorkItemCommitRequest.parseFrom(chunk.getSerializedWorkItemCommit()); + assertThat(parsedRequest).isEqualTo(multiKeyRequest); + + // 5. Respond with the generated requestId to complete the commit + long requestId = chunk.getRequestId(); + streamInfo.responseObserver.onNext( + Windmill.StreamingCommitResponse.newBuilder().addRequestId(requestId).build()); + + // 6. 
Verify callback completed successfully with CommitStatus.OK + assertThat(commitStatusFuture.get()).isEqualTo(Windmill.CommitStatus.OK); + } + + @Test + public void testCommit_multiKeyCommit_multichunk() throws Exception { + GrpcCommitWorkStream commitWorkStream = createCommitWorkStream(); + FakeWindmillGrpcService.CommitStreamInfo streamInfo = waitForConnectionAndConsumeHeader(); + + CompletableFuture commitStatusFuture = new CompletableFuture<>(); + + long shardingKey1 = 101L; + long workToken1 = 201L; + long cacheToken1 = 301L; + long shardingKey2 = 102L; + long workToken2 = 202L; + long cacheToken2 = 302L; + + Windmill.WorkItemCommitRequest request1 = + Windmill.WorkItemCommitRequest.newBuilder() + .setKey(ByteString.copyFromUtf8("key1")) + .setShardingKey(shardingKey1) + .setWorkToken(workToken1) + .setCacheToken(cacheToken1) + .addBagUpdates(Windmill.TagBag.newBuilder().setTag(LARGE_BYTE_STRING).build()) + .build(); + + Windmill.WorkItemCommitRequest request2 = + Windmill.WorkItemCommitRequest.newBuilder() + .setKey(ByteString.copyFromUtf8("key2")) + .setShardingKey(shardingKey2) + .setWorkToken(workToken2) + .setCacheToken(cacheToken2) + .build(); + + Windmill.MultiKeyWorkItemCommitRequest multiKeyRequest = + Windmill.MultiKeyWorkItemCommitRequest.newBuilder() + .addRequests(request1) + .addRequests(request2) + .build(); + + try (WindmillStream.CommitWorkStream.RequestBatcher batcher = commitWorkStream.batcher()) { + assertTrue( + batcher.commitMultiKeyWorkItem( + COMPUTATION_ID, multiKeyRequest, commitStatusFuture::complete)); + } + + Windmill.StreamingCommitWorkRequest requestChunk1 = streamInfo.requests.take(); + assertThat(requestChunk1.getCommitChunkCount()).isEqualTo(1); + Windmill.StreamingCommitRequestChunk chunk1 = requestChunk1.getCommitChunk(0); + + assertThat(chunk1.getCommitType()) + .isEqualTo(Windmill.StreamingCommitRequestChunk.CommitType.COMMIT_TYPE_MULTI_KEY); + assertThat(chunk1.getShardingKey()).isEqualTo(request1.getShardingKey()); + 
assertThat(chunk1.getRemainingBytesForWorkItem()).isGreaterThan(0); + + Windmill.StreamingCommitWorkRequest requestChunk2 = streamInfo.requests.take(); + assertThat(requestChunk2.getCommitChunkCount()).isEqualTo(1); + Windmill.StreamingCommitRequestChunk chunk2 = requestChunk2.getCommitChunk(0); + + assertThat(chunk2.getCommitType()) + .isEqualTo(Windmill.StreamingCommitRequestChunk.CommitType.COMMIT_TYPE_MULTI_KEY); + assertThat(chunk2.getShardingKey()).isEqualTo(request1.getShardingKey()); + assertThat(chunk2.getRemainingBytesForWorkItem()).isEqualTo(0); + + ByteString reconstructedBytes = + chunk1.getSerializedWorkItemCommit().concat(chunk2.getSerializedWorkItemCommit()); + Windmill.MultiKeyWorkItemCommitRequest parsedRequest = + Windmill.MultiKeyWorkItemCommitRequest.parseFrom(reconstructedBytes); + assertThat(parsedRequest).isEqualTo(multiKeyRequest); + + long requestId = chunk1.getRequestId(); + assertThat(chunk2.getRequestId()).isEqualTo(requestId); + + streamInfo.responseObserver.onNext( + Windmill.StreamingCommitResponse.newBuilder().addRequestId(requestId).build()); + + assertThat(commitStatusFuture.get()).isEqualTo(Windmill.CommitStatus.OK); + } + + @Test + public void testCommitMultiKeyWorkItem_retryOnNewStream() throws Exception { + GrpcCommitWorkStream commitWorkStream = createCommitWorkStream(); + FakeWindmillGrpcService.CommitStreamInfo streamInfo = waitForConnectionAndConsumeHeader(); + + CompletableFuture commitStatusFuture = new CompletableFuture<>(); + + long shardingKey1 = 101L; + long workToken1 = 201L; + long cacheToken1 = 301L; + Windmill.WorkItemCommitRequest request1 = + Windmill.WorkItemCommitRequest.newBuilder() + .setKey(ByteString.copyFromUtf8("key1")) + .setShardingKey(shardingKey1) + .setWorkToken(workToken1) + .setCacheToken(cacheToken1) + .build(); + Windmill.MultiKeyWorkItemCommitRequest multiKeyRequest = + Windmill.MultiKeyWorkItemCommitRequest.newBuilder().addRequests(request1).build(); + + try 
(WindmillStream.CommitWorkStream.RequestBatcher batcher = commitWorkStream.batcher()) { + assertTrue( + batcher.commitMultiKeyWorkItem( + COMPUTATION_ID, multiKeyRequest, commitStatusFuture::complete)); + } + + Windmill.StreamingCommitWorkRequest request = streamInfo.requests.take(); + assertThat(request.getCommitChunkCount()).isEqualTo(1); + Windmill.StreamingCommitRequestChunk chunk = request.getCommitChunk(0); + assertThat(chunk.getCommitType()) + .isEqualTo(Windmill.StreamingCommitRequestChunk.CommitType.COMMIT_TYPE_MULTI_KEY); + long requestId = chunk.getRequestId(); + + streamInfo.responseObserver.onError(new IOException("test error")); + + FakeWindmillGrpcService.CommitStreamInfo reconnectStreamInfo = + waitForConnectionAndConsumeHeader(); + Windmill.StreamingCommitWorkRequest reconnectRequest = reconnectStreamInfo.requests.take(); + assertThat(reconnectRequest.getCommitChunkCount()).isEqualTo(1); + Windmill.StreamingCommitRequestChunk reconnectChunk = reconnectRequest.getCommitChunk(0); + assertThat(reconnectChunk.getCommitType()) + .isEqualTo(Windmill.StreamingCommitRequestChunk.CommitType.COMMIT_TYPE_MULTI_KEY); + assertThat(reconnectChunk.getRequestId()).isEqualTo(requestId); + + Windmill.MultiKeyWorkItemCommitRequest parsedRequest = + Windmill.MultiKeyWorkItemCommitRequest.parseFrom( + reconnectChunk.getSerializedWorkItemCommit()); + assertThat(parsedRequest).isEqualTo(multiKeyRequest); + + reconnectStreamInfo.responseObserver.onNext( + Windmill.StreamingCommitResponse.newBuilder().addRequestId(requestId).build()); + assertThat(commitStatusFuture.get()).isEqualTo(Windmill.CommitStatus.OK); + } + + @Test + public void testCommitWorkItem_retryOnNewStream_multichunk() throws Exception { + GrpcCommitWorkStream commitWorkStream = createCommitWorkStream(); + FakeWindmillGrpcService.CommitStreamInfo streamInfo = waitForConnectionAndConsumeHeader(); + + CompletableFuture commitStatusFuture = new CompletableFuture<>(); + + Windmill.WorkItemCommitRequest 
largeRequest = + workItemCommitRequest(1) + .toBuilder() + .addBagUpdates(Windmill.TagBag.newBuilder().setTag(LARGE_BYTE_STRING).build()) + .build(); + + try (WindmillStream.CommitWorkStream.RequestBatcher batcher = commitWorkStream.batcher()) { + assertTrue( + batcher.commitWorkItem(COMPUTATION_ID, largeRequest, commitStatusFuture::complete)); + } + + Windmill.StreamingCommitWorkRequest requestChunk1 = streamInfo.requests.take(); + assertThat(requestChunk1.getCommitChunkCount()).isEqualTo(1); + Windmill.StreamingCommitRequestChunk chunk1 = requestChunk1.getCommitChunk(0); + long requestId = chunk1.getRequestId(); + assertThat(chunk1.getRemainingBytesForWorkItem()).isGreaterThan(0); + + Windmill.StreamingCommitWorkRequest requestChunk2 = streamInfo.requests.take(); + assertThat(requestChunk2.getCommitChunkCount()).isEqualTo(1); + Windmill.StreamingCommitRequestChunk chunk2 = requestChunk2.getCommitChunk(0); + assertThat(chunk2.getRequestId()).isEqualTo(requestId); + assertThat(chunk2.getRemainingBytesForWorkItem()).isEqualTo(0); + + streamInfo.responseObserver.onError(new IOException("test error")); + + FakeWindmillGrpcService.CommitStreamInfo reconnectStreamInfo = + waitForConnectionAndConsumeHeader(); + + Windmill.StreamingCommitWorkRequest reconnectChunk1 = reconnectStreamInfo.requests.take(); + assertThat(reconnectChunk1.getCommitChunkCount()).isEqualTo(1); + Windmill.StreamingCommitRequestChunk reconChunk1 = reconnectChunk1.getCommitChunk(0); + assertThat(reconChunk1.getRequestId()).isEqualTo(requestId); + assertThat(reconChunk1.getRemainingBytesForWorkItem()).isGreaterThan(0); + + Windmill.StreamingCommitWorkRequest reconnectChunk2 = reconnectStreamInfo.requests.take(); + assertThat(reconnectChunk2.getCommitChunkCount()).isEqualTo(1); + Windmill.StreamingCommitRequestChunk reconChunk2 = reconnectChunk2.getCommitChunk(0); + assertThat(reconChunk2.getRequestId()).isEqualTo(requestId); + assertThat(reconChunk2.getRemainingBytesForWorkItem()).isEqualTo(0); + + 
ByteString reconstructedBytes = + reconChunk1.getSerializedWorkItemCommit().concat(reconChunk2.getSerializedWorkItemCommit()); + Windmill.WorkItemCommitRequest parsedRequest = + Windmill.WorkItemCommitRequest.parseFrom(reconstructedBytes); + assertThat(parsedRequest).isEqualTo(largeRequest); + + reconnectStreamInfo.responseObserver.onNext( + Windmill.StreamingCommitResponse.newBuilder().addRequestId(requestId).build()); + assertThat(commitStatusFuture.get()).isEqualTo(Windmill.CommitStatus.OK); + } + @Test public void testCommitWorkItem_stopsRetriesAfterDuration() throws Exception { int numCommits = 1; diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateReaderTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateReaderTest.java index 1611fdac25dc..65637437a0a0 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateReaderTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateReaderTest.java @@ -35,9 +35,9 @@ import java.util.Map; import java.util.Optional; import java.util.concurrent.Future; -import org.apache.beam.runners.dataflow.worker.KeyTokenInvalidException; import org.apache.beam.runners.dataflow.worker.WindmillStateTestUtils; import org.apache.beam.runners.dataflow.worker.WindmillTimeUtils; +import org.apache.beam.runners.dataflow.worker.WorkCancelingException; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataRequest; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.SortedListEntry; @@ -1572,16 +1572,16 @@ public void testKeyTokenInvalid() throws Exception { try { watermarkFuture.get(); - fail("Expected 
KeyTokenInvalidException"); + fail("Expected WorkCancelingException"); } catch (Exception e) { - assertTrue(KeyTokenInvalidException.isKeyTokenInvalidException(e)); + assertTrue(WorkCancelingException.isWorkCancelingException(e)); } try { bagFuture.get(); - fail("Expected KeyTokenInvalidException"); + fail("Expected WorkCancelingException"); } catch (Exception e) { - assertTrue(KeyTokenInvalidException.isKeyTokenInvalidException(e)); + assertTrue(WorkCancelingException.isWorkCancelingException(e)); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/failures/WorkFailureProcessorTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/failures/WorkFailureProcessorTest.java index 0610ed44c27f..ce9fe53f47d3 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/failures/WorkFailureProcessorTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/failures/WorkFailureProcessorTest.java @@ -21,15 +21,15 @@ import static org.junit.Assert.assertThrows; import static org.mockito.Mockito.mock; +import java.util.Arrays; import java.util.HashSet; +import java.util.List; import java.util.Optional; import java.util.Set; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.function.Consumer; import java.util.function.Supplier; -import org.apache.beam.runners.dataflow.worker.KeyTokenInvalidException; -import org.apache.beam.runners.dataflow.worker.WorkItemCancelledException; import org.apache.beam.runners.dataflow.worker.streaming.ExecutableWork; import org.apache.beam.runners.dataflow.worker.streaming.Watermarks; import org.apache.beam.runners.dataflow.worker.streaming.Work; @@ -109,38 +109,22 @@ private static 
ExecutableWork createWork(Consumer processWorkFn) { } @Test - public void logAndProcessFailure_doesNotRetryKeyTokenInvalidException() throws Throwable { + public void logAndProcessFailureBatch_doesNotRetryFailedWork() throws Throwable { Set executedWork = new HashSet<>(); ExecutableWork work = createWork(executedWork::add); + work.work().setFailed(); WorkFailureProcessor workFailureProcessor = createWorkFailureProcessor(streamingEngineFailureReporter()); Set invalidWork = new HashSet<>(); - workFailureProcessor.logAndProcessFailure( - DEFAULT_COMPUTATION_ID, work, new KeyTokenInvalidException("key"), invalidWork::add); + workFailureProcessor.logAndProcessFailureBatch( + DEFAULT_COMPUTATION_ID, List.of(work), new RuntimeException(), invalidWork::add); assertThat(executedWork).isEmpty(); assertThat(invalidWork).containsExactly(work.work()); } @Test - public void logAndProcessFailure_doesNotRetryWhenWorkItemCancelled() throws Throwable { - Set executedWork = new HashSet<>(); - ExecutableWork work = createWork(executedWork::add); - WorkFailureProcessor workFailureProcessor = - createWorkFailureProcessor(streamingEngineFailureReporter()); - Set invalidWork = new HashSet<>(); - workFailureProcessor.logAndProcessFailure( - DEFAULT_COMPUTATION_ID, - work, - new WorkItemCancelledException(work.getWorkItem().getShardingKey()), - invalidWork::add); - - assertThat(executedWork).isEmpty(); - assertThat(invalidWork).containsExactly(work.work()); - } - - @Test - public void logAndProcessFailure_doesNotRetryOOM() { + public void logAndProcessFailureBatch_doesNotRetryOOM() { Set executedWork = new HashSet<>(); ExecutableWork work = createWork(executedWork::add); WorkFailureProcessor workFailureProcessor = @@ -149,69 +133,120 @@ public void logAndProcessFailure_doesNotRetryOOM() { assertThrows( OutOfMemoryError.class, () -> - workFailureProcessor.logAndProcessFailure( - DEFAULT_COMPUTATION_ID, work, new OutOfMemoryError(), invalidWork::add)); + 
workFailureProcessor.logAndProcessFailureBatch( + DEFAULT_COMPUTATION_ID, + Arrays.asList(work), + new OutOfMemoryError(), + invalidWork::add)); assertThat(executedWork).isEmpty(); } @Test - public void logAndProcessFailure_doesNotRetryWhenFailureReporterMarksAsNonRetryable() + public void logAndProcessFailureBatch_doesNotRetryWhenFailureReporterMarksAsNonRetryable() throws Throwable { Set executedWork = new HashSet<>(); ExecutableWork work = createWork(executedWork::add); WorkFailureProcessor workFailureProcessor = createWorkFailureProcessor(streamingApplianceFailureReporter(true)); Set invalidWork = new HashSet<>(); - workFailureProcessor.logAndProcessFailure( - DEFAULT_COMPUTATION_ID, work, new RuntimeException(), invalidWork::add); + workFailureProcessor.logAndProcessFailureBatch( + DEFAULT_COMPUTATION_ID, Arrays.asList(work), new RuntimeException(), invalidWork::add); assertThat(executedWork).isEmpty(); assertThat(invalidWork).containsExactly(work.work()); } @Test - public void logAndProcessFailure_doesNotRetryAfterLocalRetryTimeout() throws Throwable { + public void logAndProcessFailureBatch_doesNotRetryAfterLocalRetryTimeout() throws Throwable { Set executedWork = new HashSet<>(); ExecutableWork veryOldWork = createWork(() -> Instant.now().minus(Duration.standardDays(30)), executedWork::add); WorkFailureProcessor workFailureProcessor = createWorkFailureProcessor(streamingEngineFailureReporter()); Set invalidWork = new HashSet<>(); - workFailureProcessor.logAndProcessFailure( - DEFAULT_COMPUTATION_ID, veryOldWork, new RuntimeException(), invalidWork::add); + workFailureProcessor.logAndProcessFailureBatch( + DEFAULT_COMPUTATION_ID, + Arrays.asList(veryOldWork), + new RuntimeException(), + invalidWork::add); assertThat(executedWork).isEmpty(); assertThat(invalidWork).contains(veryOldWork.work()); } @Test - public void logAndProcessFailure_retriesOnUncaughtUnhandledException_streamingEngine() + public void 
logAndProcessFailureBatch_retriesOnUncaughtUnhandledException_streamingEngine() throws Throwable { CountDownLatch runWork = new CountDownLatch(1); ExecutableWork work = createWork(ignored -> runWork.countDown()); WorkFailureProcessor workFailureProcessor = createWorkFailureProcessor(streamingEngineFailureReporter()); Set invalidWork = new HashSet<>(); - workFailureProcessor.logAndProcessFailure( - DEFAULT_COMPUTATION_ID, work, new RuntimeException(), invalidWork::add); + workFailureProcessor.logAndProcessFailureBatch( + DEFAULT_COMPUTATION_ID, Arrays.asList(work), new RuntimeException(), invalidWork::add); runWork.await(); assertThat(invalidWork).isEmpty(); } @Test - public void logAndProcessFailure_retriesOnUncaughtUnhandledException_streamingAppliance() + public void logAndProcessFailureBatch_retriesOnUncaughtUnhandledException_streamingAppliance() throws Throwable { CountDownLatch runWork = new CountDownLatch(1); ExecutableWork work = createWork(ignored -> runWork.countDown()); WorkFailureProcessor workFailureProcessor = createWorkFailureProcessor(streamingApplianceFailureReporter(false)); Set invalidWork = new HashSet<>(); - workFailureProcessor.logAndProcessFailure( - DEFAULT_COMPUTATION_ID, work, new RuntimeException(), invalidWork::add); + workFailureProcessor.logAndProcessFailureBatch( + DEFAULT_COMPUTATION_ID, Arrays.asList(work), new RuntimeException(), invalidWork::add); runWork.await(); assertThat(invalidWork).isEmpty(); } + + @Test + public void logAndProcessFailureBatch_retryAll() throws Throwable { + CountDownLatch runWork1 = new CountDownLatch(1); + CountDownLatch runWork2 = new CountDownLatch(1); + ExecutableWork work1 = createWork(ignored -> runWork1.countDown()); + ExecutableWork work2 = createWork(ignored -> runWork2.countDown()); + + WorkFailureProcessor workFailureProcessor = + createWorkFailureProcessor(streamingEngineFailureReporter()); + Set invalidWork = new HashSet<>(); + + workFailureProcessor.logAndProcessFailureBatch( + 
DEFAULT_COMPUTATION_ID, + Arrays.asList(work1, work2), + new RuntimeException(), + invalidWork::add); + + runWork1.await(); + runWork2.await(); + assertThat(invalidWork).isEmpty(); + } + + @Test + public void logAndProcessFailureBatch_mixRetryAndAbort() throws Throwable { + CountDownLatch runWork1 = new CountDownLatch(1); + Set executedWork2 = new HashSet<>(); + ExecutableWork work1 = createWork(ignored -> runWork1.countDown()); + ExecutableWork work2 = createWork(executedWork2::add); + work2.work().setFailed(); + + WorkFailureProcessor workFailureProcessor = + createWorkFailureProcessor(streamingEngineFailureReporter()); + Set invalidWork = new HashSet<>(); + + workFailureProcessor.logAndProcessFailureBatch( + DEFAULT_COMPUTATION_ID, + Arrays.asList(work1, work2), + new RuntimeException(), + invalidWork::add); + + runWork1.await(); + assertThat(executedWork2).isEmpty(); + assertThat(invalidWork).containsExactly(work2.work()); + } } diff --git a/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto b/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto index aaa09c105fc3..a7a99e2ca5a1 100644 --- a/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto +++ b/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto @@ -678,9 +678,24 @@ message WorkItemCommitRequest { reserved 6, 23; } +message MultiKeyWorkItemCommitRequest { + optional Uint128Proto key_group = 7; + + repeated WorkItemCommitRequest requests = 1; + + repeated OutputMessageBundle output_messages = 2; + + repeated PubSubMessageBundle pubsub_messages = 3; + + repeated int64 finalize_ids = 4 [packed = true]; + + reserved 6; +} + message ComputationCommitWorkRequest { required string computation_id = 1; repeated WorkItemCommitRequest requests = 2; + repeated MultiKeyWorkItemCommitRequest multi_key_requests = 3; } message CommitWorkRequest { @@ -906,6 +921,14 @@ message StreamingCommitRequestChunk { // 
before handing off to the WindmillHost for processing. optional int64 remaining_bytes_for_work_item = 4; optional bytes serialized_work_item_commit = 5; + + enum CommitType { + COMMIT_TYPE_UNSPECIFIED = 0; + COMMIT_TYPE_SINGLE_KEY = 1; + COMMIT_TYPE_MULTI_KEY = 2; + } + + optional CommitType commit_type = 7; } message StreamingCommitResponse {