diff --git a/data-prepper-plugins/dynamodb-source/build.gradle b/data-prepper-plugins/dynamodb-source/build.gradle index 3b3046434a..b5bd6fd9cc 100644 --- a/data-prepper-plugins/dynamodb-source/build.gradle +++ b/data-prepper-plugins/dynamodb-source/build.gradle @@ -23,7 +23,9 @@ dependencies { implementation project(path: ':data-prepper-plugins:aws-plugin-api') implementation project(path: ':data-prepper-plugins:buffer-common') + implementation project(path: ':data-prepper-plugins:common') testImplementation 'com.fasterxml.jackson.dataformat:jackson-dataformat-yaml' + testImplementation project(':data-prepper-test:test-common') } \ No newline at end of file diff --git a/data-prepper-plugins/dynamodb-source/src/main/java/org/opensearch/dataprepper/plugins/source/dynamodb/DynamoDBSourceConfig.java b/data-prepper-plugins/dynamodb-source/src/main/java/org/opensearch/dataprepper/plugins/source/dynamodb/DynamoDBSourceConfig.java index 9d6d3fd358..1f8ace0b54 100644 --- a/data-prepper-plugins/dynamodb-source/src/main/java/org/opensearch/dataprepper/plugins/source/dynamodb/DynamoDBSourceConfig.java +++ b/data-prepper-plugins/dynamodb-source/src/main/java/org/opensearch/dataprepper/plugins/source/dynamodb/DynamoDBSourceConfig.java @@ -37,7 +37,7 @@ public class DynamoDBSourceConfig { private boolean acknowledgments = false; @JsonProperty("shard_acknowledgment_timeout") - private Duration shardAcknowledgmentTimeout = Duration.ofMinutes(10); + private Duration shardAcknowledgmentTimeout = Duration.ofMinutes(30); @JsonProperty("s3_data_file_acknowledgment_timeout") private Duration dataFileAcknowledgmentTimeout = Duration.ofMinutes(15); diff --git a/data-prepper-plugins/dynamodb-source/src/main/java/org/opensearch/dataprepper/plugins/source/dynamodb/model/ShardCheckpointStatus.java b/data-prepper-plugins/dynamodb-source/src/main/java/org/opensearch/dataprepper/plugins/source/dynamodb/model/ShardCheckpointStatus.java new file mode 100644 index 0000000000..b9489fbcef --- /dev/null +++ b/data-prepper-plugins/dynamodb-source/src/main/java/org/opensearch/dataprepper/plugins/source/dynamodb/model/ShardCheckpointStatus.java @@ -0,0 +1,65 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + */ + +package org.opensearch.dataprepper.plugins.source.dynamodb.model; + +import java.time.Duration; +import java.time.Instant; + +public class ShardCheckpointStatus { + private final String sequenceNumber; + + private final boolean isFinalAcknowledgmentForPartition; + private AcknowledgmentStatus acknowledgeStatus; + private final long createTimestamp; + private Long acknowledgedTimestamp; + + public enum AcknowledgmentStatus { + POSITIVE_ACK, + NEGATIVE_ACK, + NO_ACK + } + + public ShardCheckpointStatus(final String sequenceNumber, final long createTimestamp, final boolean isFinalAcknowledgmentForPartition) { + this.sequenceNumber = sequenceNumber; + this.acknowledgeStatus = AcknowledgmentStatus.NO_ACK; + this.createTimestamp = createTimestamp; + this.isFinalAcknowledgmentForPartition = isFinalAcknowledgmentForPartition; + } + + public void setAcknowledgedTimestamp(final Long acknowledgedTimestamp) { + this.acknowledgedTimestamp = acknowledgedTimestamp; + } + + public void setAcknowledged(final AcknowledgmentStatus acknowledgmentStatus) { + this.acknowledgeStatus = acknowledgmentStatus; + } + + public String getSequenceNumber() { + return sequenceNumber; + } + + public boolean isPositiveAcknowledgement() { + return this.acknowledgeStatus == AcknowledgmentStatus.POSITIVE_ACK; + } + + public boolean isNegativeAcknowledgement() { + return this.acknowledgeStatus == AcknowledgmentStatus.NEGATIVE_ACK; + } + + public boolean isFinalAcknowledgmentForPartition() { + return isFinalAcknowledgmentForPartition; + } + + public boolean isExpired(final Duration expiredDuration) { + return Duration.between(Instant.ofEpochMilli(createTimestamp), Instant.now()).compareTo(expiredDuration) > 0; + } + +} \ No newline at end of file diff --git a/data-prepper-plugins/dynamodb-source/src/main/java/org/opensearch/dataprepper/plugins/source/dynamodb/stream/ShardAcknowledgementManager.java b/data-prepper-plugins/dynamodb-source/src/main/java/org/opensearch/dataprepper/plugins/source/dynamodb/stream/ShardAcknowledgementManager.java new file mode 100644 index 0000000000..f4afc5243e --- /dev/null +++ b/data-prepper-plugins/dynamodb-source/src/main/java/org/opensearch/dataprepper/plugins/source/dynamodb/stream/ShardAcknowledgementManager.java @@ -0,0 +1,259 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + */ + +package org.opensearch.dataprepper.plugins.source.dynamodb.stream; + +import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSet; +import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSetManager; +import org.opensearch.dataprepper.model.source.coordinator.enhanced.EnhancedSourceCoordinator; +import org.opensearch.dataprepper.common.concurrent.BackgroundThreadFactory; +import org.opensearch.dataprepper.model.source.coordinator.enhanced.EnhancedSourcePartition; +import org.opensearch.dataprepper.plugins.source.dynamodb.DynamoDBSourceConfig; +import org.opensearch.dataprepper.plugins.source.dynamodb.coordination.partition.StreamPartition; +import org.opensearch.dataprepper.plugins.source.dynamodb.coordination.state.StreamProgressState; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.time.Duration; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.function.Consumer; + +import org.opensearch.dataprepper.plugins.source.dynamodb.model.ShardCheckpointStatus; + +public class ShardAcknowledgementManager { + private static final Logger LOG = LoggerFactory.getLogger(ShardAcknowledgementManager.class); + + private static final String NULL_SEQUENCE_NUMBER = "null"; + + private static final long WAIT_FOR_ACKNOWLEDGMENTS_TIMEOUT = 10L; + + static final Duration CHECKPOINT_INTERVAL = Duration.ofMinutes(2); + + private final DynamoDBSourceConfig dynamoDBSourceConfig; + private final Map> checkpoints = new ConcurrentHashMap<>(); + private final ConcurrentHashMap> ackStatuses = new ConcurrentHashMap<>(); + + private final AcknowledgementSetManager acknowledgementSetManager; + + private final EnhancedSourceCoordinator sourceCoordinator; + + private final ExecutorService executorService; + private final List partitionsToRemove; + private final List partitionsToGiveUp; + private boolean shutdownTriggered; + + private Instant lastCheckpointTime; + + public ShardAcknowledgementManager(final AcknowledgementSetManager acknowledgementSetManager, + final EnhancedSourceCoordinator sourceCoordinator, + final DynamoDBSourceConfig dynamoDBSourceConfig, + final Consumer stopWorkerConsumer + ) { + this.acknowledgementSetManager = acknowledgementSetManager; + this.sourceCoordinator = sourceCoordinator; + this.dynamoDBSourceConfig = dynamoDBSourceConfig; + this.executorService = Executors.newSingleThreadExecutor(BackgroundThreadFactory.defaultExecutorThreadFactory("dynamodb-shard-ack-monitor")); + this.partitionsToRemove = Collections.synchronizedList(new ArrayList<>()); + this.partitionsToGiveUp = Collections.synchronizedList(new ArrayList<>()); + this.lastCheckpointTime = Instant.now(); + + executorService.submit(() -> monitorAcknowledgments(stopWorkerConsumer)); + } + + void monitorAcknowledgments(final Consumer stopWorkerConsumer) { + while (!Thread.currentThread().isInterrupted()) { + boolean exit = runMonitorAcknowledgmentLoop(stopWorkerConsumer); + if (exit) { + break; + } + } + + LOG.info("Exiting acknowledgment manager"); + } + + boolean runMonitorAcknowledgmentLoop(final Consumer stopWorkerConsumer) { + removePartitions(); + if (shutdownTriggered) { + LOG.info("Shutdown was triggered giving up partitions and exiting cleanly"); + for (final StreamPartition streamPartition : checkpoints.keySet()) { + sourceCoordinator.giveUpPartition(streamPartition); + } + return true; + } + + for (final StreamPartition streamPartition : checkpoints.keySet()) { + try { + final StreamProgressState streamProgressState = streamPartition.getProgressState().orElseThrow(); + final ConcurrentLinkedQueue checkpointStatuses = checkpoints.get(streamPartition); + ShardCheckpointStatus latestCheckpointForShard = null; + boolean gaveUpPartition = false; + while (!checkpointStatuses.isEmpty()) { + updateOwnershipForAllShardPartitions(); + + if (checkpointStatuses.peek().isPositiveAcknowledgement()) { + latestCheckpointForShard = checkpointStatuses.poll(); + } else if (checkpointStatuses.peek().isNegativeAcknowledgement() + || checkpointStatuses.peek().isExpired(dynamoDBSourceConfig.getShardAcknowledgmentTimeout())) { + handleFailure(streamPartition, streamProgressState, latestCheckpointForShard); + gaveUpPartition = true; + + if (checkpointStatuses.peek().isNegativeAcknowledgement()) { + LOG.warn("Received negative acknowledgment for partition {} with sequence number {}, giving up partition", + streamPartition.getPartitionKey(), checkpointStatuses.peek().getSequenceNumber()); + } else { + LOG.warn("Acknowledgment timed out for partition {} with sequence number {}, giving up partition", + streamPartition.getPartitionKey(), checkpointStatuses.peek().getSequenceNumber()); + } + + stopWorkerConsumer.accept(streamPartition); + break; + } else { + break; + } + } + + if (!gaveUpPartition) { + updateOwnershipForAllShardPartitions(); + } + + if (gaveUpPartition || latestCheckpointForShard == null) { + continue; + } + + if (latestCheckpointForShard.isFinalAcknowledgmentForPartition()) { + handleCompletedShard(streamPartition); + } else { + streamProgressState.setSequenceNumber(Objects.equals(latestCheckpointForShard.getSequenceNumber(), NULL_SEQUENCE_NUMBER) ? null : latestCheckpointForShard.getSequenceNumber()); + sourceCoordinator.saveProgressStateForPartition(streamPartition, dynamoDBSourceConfig.getShardAcknowledgmentTimeout()); + LOG.debug("Checkpointed shard {} with latest sequence number acknowledged {}", streamPartition.getShardId(), latestCheckpointForShard.getSequenceNumber()); + } + if (partitionsToGiveUp.contains(streamPartition)) { + partitionsToRemove.add(streamPartition); + sourceCoordinator.giveUpPartition(streamPartition); + } + + } catch (final Exception e) { + LOG.error("Received exception while monitoring acknowledgments for stream partition {}", streamPartition.getPartitionKey(), e); + } + } + + return false; + } + + public AcknowledgementSet createAcknowledgmentSet( + final StreamPartition streamPartition, + final String sequenceNumber, + final boolean isFinalSetForPartition) { + final String sequenceNumberNoNull = sequenceNumber == null ? NULL_SEQUENCE_NUMBER : sequenceNumber; + final ShardCheckpointStatus shardCheckpointStatus = new ShardCheckpointStatus(sequenceNumber, Instant.now().toEpochMilli(), isFinalSetForPartition); + checkpoints.computeIfAbsent(streamPartition, segment -> new ConcurrentLinkedQueue<>()).add(shardCheckpointStatus); + ackStatuses.computeIfAbsent(streamPartition, segment -> new ConcurrentHashMap<>()); + ackStatuses.get(streamPartition).put(sequenceNumberNoNull, shardCheckpointStatus); + + return acknowledgementSetManager.create((result) -> { + if (ackStatuses.containsKey(streamPartition) && ackStatuses.get(streamPartition).containsKey(sequenceNumberNoNull)) { + final ShardCheckpointStatus ackCheckpointStatus = ackStatuses.get(streamPartition).get(sequenceNumberNoNull); + + ackCheckpointStatus.setAcknowledgedTimestamp(Instant.now().toEpochMilli()); + + if (result) { + LOG.debug("Received acknowledgment of completion from sink for partition {} with sequence number {}", + streamPartition.getPartitionKey(), sequenceNumberNoNull); + ackCheckpointStatus.setAcknowledged(ShardCheckpointStatus.AcknowledgmentStatus.POSITIVE_ACK); + } else { + LOG.warn("Negative acknowledgment received for partition {} with sequence number {}", + streamPartition.getPartitionKey(), sequenceNumberNoNull); + ackCheckpointStatus.setAcknowledged(ShardCheckpointStatus.AcknowledgmentStatus.NEGATIVE_ACK); + } + } + }, dynamoDBSourceConfig.getShardAcknowledgmentTimeout()); + } + + void updateOwnershipForAllShardPartitions() { + if (Duration.between(lastCheckpointTime, Instant.now()).compareTo(CHECKPOINT_INTERVAL) > 0) { + for (final StreamPartition streamPartition : checkpoints.keySet()) { + if (!partitionsToRemove.contains(streamPartition)) { + sourceCoordinator.saveProgressStateForPartition(streamPartition, dynamoDBSourceConfig.getShardAcknowledgmentTimeout()); + } + } + + lastCheckpointTime = Instant.now(); + } + } + + private void handleFailure(final StreamPartition streamPartition, + final StreamProgressState streamProgressState, + final ShardCheckpointStatus latestCheckpointForShard) { + if (latestCheckpointForShard != null) { + streamProgressState.setSequenceNumber(latestCheckpointForShard.getSequenceNumber()); + sourceCoordinator.saveProgressStateForPartition(streamPartition, dynamoDBSourceConfig.getShardAcknowledgmentTimeout()); + } + partitionsToRemove.add(streamPartition); + sourceCoordinator.giveUpPartition(streamPartition); + partitionsToGiveUp.remove(streamPartition); + } + + private void handleCompletedShard(final StreamPartition streamPartition) { + sourceCoordinator.completePartition(streamPartition); + partitionsToRemove.add(streamPartition); + partitionsToGiveUp.remove(streamPartition); + LOG.info("Received all acknowledgments for partition {}, marking partition as completed", streamPartition.getPartitionKey()); + } + + public void shutdown() { + shutdownTriggered = true; + executorService.shutdown(); + try { + if (!executorService.awaitTermination(WAIT_FOR_ACKNOWLEDGMENTS_TIMEOUT, TimeUnit.MINUTES)) { + executorService.shutdownNow(); + } + } catch (InterruptedException e) { + executorService.shutdownNow(); + Thread.currentThread().interrupt(); + } + } + + private void removePartitions() { + partitionsToRemove.forEach(streamPartition -> { + checkpoints.remove(streamPartition); + ackStatuses.remove(streamPartition); + }); + + partitionsToRemove.clear(); + } + + public void giveUpPartition(final StreamPartition streamPartition) { + if (!partitionsToGiveUp.contains(streamPartition)) { + LOG.debug("Adding partition {} to give up list", streamPartition.getPartitionKey()); + partitionsToGiveUp.add(streamPartition); + } + } + + public boolean isExportDone(StreamPartition streamPartition) { + Optional globalPartition = sourceCoordinator.getPartition(streamPartition.getStreamArn()); + return globalPartition.isPresent(); + } + + public void startUpdatingOwnershipForShard(final StreamPartition streamPartition) { + checkpoints.computeIfAbsent(streamPartition, segment -> new ConcurrentLinkedQueue<>()); + } +} diff --git a/data-prepper-plugins/dynamodb-source/src/main/java/org/opensearch/dataprepper/plugins/source/dynamodb/stream/ShardConsumer.java b/data-prepper-plugins/dynamodb-source/src/main/java/org/opensearch/dataprepper/plugins/source/dynamodb/stream/ShardConsumer.java index f3297671d8..08f8576802 100644 --- a/data-prepper-plugins/dynamodb-source/src/main/java/org/opensearch/dataprepper/plugins/source/dynamodb/stream/ShardConsumer.java +++ b/data-prepper-plugins/dynamodb-source/src/main/java/org/opensearch/dataprepper/plugins/source/dynamodb/stream/ShardConsumer.java @@ -13,6 +13,7 @@ import org.opensearch.dataprepper.model.event.Event; import org.opensearch.dataprepper.model.record.Record; import org.opensearch.dataprepper.plugins.source.dynamodb.configuration.StreamConfig; +import org.opensearch.dataprepper.plugins.source.dynamodb.coordination.partition.StreamPartition; import org.opensearch.dataprepper.plugins.source.dynamodb.converter.StreamRecordConverter; import org.opensearch.dataprepper.plugins.source.dynamodb.model.TableInfo; import org.opensearch.dataprepper.plugins.source.dynamodb.utils.DynamoDBSourceAggregateMetrics; @@ -35,10 +36,6 @@ public class ShardConsumer implements Runnable { private static final Logger LOG = LoggerFactory.getLogger(ShardConsumer.class); - private static final Duration ACKNOWLEDGMENT_EXPIRY_INCREASE_TIME = Duration.ofMinutes(10); - - private static final Duration ACKNOWLEDGMENT_PROGRESS_CHECK_INTERVAL = Duration.ofMinutes(3); - /** * A flag to interrupt the process */ @@ -96,6 +93,10 @@ public class ShardConsumer implements Runnable { private final StreamCheckpointer checkpointer; + private final ShardAcknowledgementManager shardAcknowledgementManager; + + private final StreamPartition streamPartition; + private String shardIterator; private final String lastShardIterator; @@ -104,10 +105,6 @@ public class ShardConsumer implements Runnable { private boolean waitForExport; - private final AcknowledgementSet acknowledgementSet; - - private final Duration shardAcknowledgmentTimeout; - private final String shardId; private final DynamoDBSourceAggregateMetrics dynamoDBSourceAggregateMetrics; @@ -120,6 +117,8 @@ private ShardConsumer(Builder builder) { this.shardProgress = builder.pluginMetrics.counter(SHARD_PROGRESS); this.dynamoDbStreamsClient = builder.dynamoDbStreamsClient; this.checkpointer = builder.checkpointer; + this.shardAcknowledgementManager = builder.shardAcknowledgementManager; + this.streamPartition = builder.streamPartition; this.shardIterator = builder.shardIterator; this.lastShardIterator = builder.lastShardIterator; // Introduce an overlap @@ -127,8 +126,6 @@ private ShardConsumer(Builder builder) { this.waitForExport = builder.waitForExport; final BufferAccumulator> bufferAccumulator = BufferAccumulator.create(builder.buffer, DEFAULT_BUFFER_BATCH_SIZE, BUFFER_TIMEOUT); recordConverter = new StreamRecordConverter(bufferAccumulator, builder.tableInfo, builder.pluginMetrics, builder.streamConfig); - this.acknowledgementSet = builder.acknowledgementSet; - this.shardAcknowledgmentTimeout = builder.dataFileAcknowledgmentTimeout; this.shardId = builder.shardId; this.recordsWrittenToBuffer = 0; this.dynamoDBSourceAggregateMetrics = builder.dynamoDBSourceAggregateMetrics; @@ -157,6 +154,10 @@ static class Builder { private StreamCheckpointer checkpointer; + private ShardAcknowledgementManager shardAcknowledgementManager; + + private StreamPartition streamPartition; + private String shardIterator; private String lastShardIterator; @@ -167,9 +168,6 @@ static class Builder { private String shardId; - private AcknowledgementSet acknowledgementSet; - private Duration dataFileAcknowledgmentTimeout; - private StreamConfig streamConfig; public Builder(final DynamoDbStreamsClient dynamoDbStreamsClient, @@ -199,6 +197,16 @@ public Builder checkpointer(StreamCheckpointer checkpointer) { return this; } + public Builder shardAcknowledgementManager(ShardAcknowledgementManager shardAcknowledgementManager) { + this.shardAcknowledgementManager = shardAcknowledgementManager; + return this; + } + + public Builder streamPartition(StreamPartition streamPartition) { + this.streamPartition = streamPartition; + return this; + } + public Builder shardIterator(String shardIterator) { this.shardIterator = shardIterator; return this; @@ -219,16 +227,6 @@ public Builder waitForExport(boolean waitForExport) { return this; } - public Builder acknowledgmentSet(AcknowledgementSet acknowledgementSet) { - this.acknowledgementSet = acknowledgementSet; - return this; - } - - public Builder acknowledgmentSetTimeout(Duration dataFileAcknowledgmentTimeout) { - this.dataFileAcknowledgmentTimeout = dataFileAcknowledgmentTimeout; - return this; - } - public ShardConsumer build() { return new ShardConsumer(this); } @@ -242,100 +240,91 @@ public void run() { // Check should skip processing or not. if (shouldSkip()) { shardProgress.increment(); - if (acknowledgementSet != null) { - checkpointer.updateShardForAcknowledgmentWait(shardAcknowledgmentTimeout); - acknowledgementSet.complete(); + if (shardAcknowledgementManager != null) { + checkpointer.completePartition(); } return; } - - if (acknowledgementSet != null) { - addProgressCheck(acknowledgementSet); + if (shardAcknowledgementManager != null) { + shardAcknowledgementManager.startUpdatingOwnershipForShard(streamPartition); } - long lastCheckpointTime = System.currentTimeMillis(); String sequenceNumber = ""; int interval; List records; - try { - while (!shouldStop) { - if (shardIterator == null) { - // End of Shard - LOG.debug("Reached end of shard"); + while (!shouldStop) { + if (shardIterator == null) { + // End of Shard + LOG.debug("Reached end of shard"); + break; + } + + if (System.currentTimeMillis() - lastCheckpointTime > DEFAULT_CHECKPOINT_INTERVAL_MILLS) { + LOG.debug("{} records written to buffer for shard {}", recordsWrittenToBuffer, shardId); + if (shardAcknowledgementManager == null) { checkpointer.checkpoint(sequenceNumber); - break; } + lastCheckpointTime = System.currentTimeMillis(); + } - if (System.currentTimeMillis() - lastCheckpointTime > DEFAULT_CHECKPOINT_INTERVAL_MILLS) { - LOG.debug("{} records written to buffer for shard {}", recordsWrittenToBuffer, shardId); - if (acknowledgementSet != null) { - checkpointer.updateShardForAcknowledgmentWait(shardAcknowledgmentTimeout); - } else { - checkpointer.checkpoint(sequenceNumber); - } - lastCheckpointTime = System.currentTimeMillis(); - } + GetRecordsResponse response = callGetRecords(shardIterator); + shardIterator = response.nextShardIterator(); + if (!response.records().isEmpty()) { + // Always use the last sequence number for checkpoint + sequenceNumber = response.records().get(response.records().size() - 1).dynamodb().sequenceNumber(); + Instant lastEventTime = response.records().get(response.records().size() - 1).dynamodb().approximateCreationDateTime(); - GetRecordsResponse response = callGetRecords(shardIterator); - shardIterator = response.nextShardIterator(); - if (!response.records().isEmpty()) { - // Always use the last sequence number for checkpoint - sequenceNumber = response.records().get(response.records().size() - 1).dynamodb().sequenceNumber(); - Instant lastEventTime = response.records().get(response.records().size() - 1).dynamodb().approximateCreationDateTime(); + if (lastEventTime.isBefore(startTime)) { + LOG.debug("Get {} events before start time, ignore...", response.records().size()); + continue; + } + if (waitForExport) { + waitForExport(); + waitForExport = false; + } - if (lastEventTime.isBefore(startTime)) { - LOG.debug("Get {} events before start time, ignore...", response.records().size()); - continue; - } - if (waitForExport) { - checkpointer.checkpoint(sequenceNumber); - waitForExport(); - waitForExport = false; - } - records = response.records().stream() - .filter(record -> record.dynamodb().approximateCreationDateTime().isAfter(startTime)) - .collect(Collectors.toList()); - recordConverter.writeToBuffer(acknowledgementSet, records); - shardProgress.increment(); - recordsWrittenToBuffer += records.size(); - long delay = System.currentTimeMillis() - lastEventTime.toEpochMilli(); - interval = delay > GET_RECORD_DELAY_THRESHOLD_MILLS ? MINIMUM_GET_RECORD_INTERVAL_MILLS : GET_RECORD_INTERVAL_MILLS; - - } else { - interval = GET_RECORD_INTERVAL_MILLS; - shardProgress.increment(); + AcknowledgementSet acknowledgementSet = null; + if (shardAcknowledgementManager != null) { + acknowledgementSet = shardAcknowledgementManager.createAcknowledgmentSet(streamPartition, sequenceNumber, shardIterator == null); } - try { - // Idle between get records call. - Thread.sleep(interval); - } catch (InterruptedException e) { - throw new RuntimeException(e); + records = response.records().stream() + .filter(record -> record.dynamodb().approximateCreationDateTime().isAfter(startTime)) + .collect(Collectors.toList()); + + recordConverter.writeToBuffer(acknowledgementSet, records); + if (acknowledgementSet != null) { + acknowledgementSet.complete(); } - } - // interrupted - if (shouldStop) { - // Do last checkpoint and then quit - LOG.warn("Processing for shard {} was interrupted by a shutdown signal, giving up shard", shardId); - checkpointer.checkpoint(sequenceNumber); - throw new RuntimeException("Consuming shard was interrupted from shutdown"); - } + shardProgress.increment(); + recordsWrittenToBuffer += records.size(); + long delay = System.currentTimeMillis() - lastEventTime.toEpochMilli(); + interval = delay > GET_RECORD_DELAY_THRESHOLD_MILLS ? MINIMUM_GET_RECORD_INTERVAL_MILLS : GET_RECORD_INTERVAL_MILLS; - if (acknowledgementSet != null) { - checkpointer.updateShardForAcknowledgmentWait(shardAcknowledgmentTimeout); - acknowledgementSet.complete(); + } else { + interval = GET_RECORD_INTERVAL_MILLS; + shardProgress.increment(); } - if (waitForExport) { - waitForExport(); - } - } catch (final Exception exc) { - if (acknowledgementSet != null) { - acknowledgementSet.cancel(); + try { + // Idle between get records call. + Thread.sleep(interval); + } catch (InterruptedException e) { + throw new RuntimeException(e); } - throw exc; + } + + // interrupted + if (shouldStop) { + // Do last checkpoint and then quit + LOG.warn("Processing for shard {} was interrupted by a shutdown signal, giving up shard", shardId); + throw new RuntimeException("Consuming shard was interrupted from shutdown"); + } + + if (waitForExport) { + waitForExport(); } } @@ -377,7 +366,9 @@ private void waitForExport() { numberOfWaits++; if (numberOfWaits % DEFAULT_WAIT_COUNT_TO_CHECKPOINT == 0) { // To extend the timeout of lease - checkpointer.checkpoint(null); + if (shardAcknowledgementManager == null) { + checkpointer.checkpoint(null); + } } } catch (InterruptedException e) { LOG.error("Wait for export is interrupted ({})", e.getMessage()); @@ -423,10 +414,4 @@ public static void stopAll() { shouldStop = true; } - private void addProgressCheck(final AcknowledgementSet acknowledgementSet) { - acknowledgementSet.addProgressCheck( - (ignored) -> { - acknowledgementSet.increaseExpiry(ACKNOWLEDGMENT_EXPIRY_INCREASE_TIME); - }, ACKNOWLEDGMENT_PROGRESS_CHECK_INTERVAL); - } } \ No newline at end of file diff --git a/data-prepper-plugins/dynamodb-source/src/main/java/org/opensearch/dataprepper/plugins/source/dynamodb/stream/ShardConsumerFactory.java b/data-prepper-plugins/dynamodb-source/src/main/java/org/opensearch/dataprepper/plugins/source/dynamodb/stream/ShardConsumerFactory.java index 0a72133032..eb101a5638 100644 --- a/data-prepper-plugins/dynamodb-source/src/main/java/org/opensearch/dataprepper/plugins/source/dynamodb/stream/ShardConsumerFactory.java +++ b/data-prepper-plugins/dynamodb-source/src/main/java/org/opensearch/dataprepper/plugins/source/dynamodb/stream/ShardConsumerFactory.java @@ -6,7 +6,6 @@ package org.opensearch.dataprepper.plugins.source.dynamodb.stream; import org.opensearch.dataprepper.metrics.PluginMetrics; -import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSet; import org.opensearch.dataprepper.model.buffer.Buffer; import org.opensearch.dataprepper.model.event.Event; import org.opensearch.dataprepper.model.record.Record; @@ -64,8 +63,8 @@ public ShardConsumerFactory(final EnhancedSourceCoordinator enhancedSourceCoordi } public Runnable createConsumer(final StreamPartition streamPartition, - final AcknowledgementSet acknowledgementSet, - final Duration shardAcknowledgmentTimeout) { + final Duration shardAcknowledgmentTimeout, + final ShardAcknowledgementManager shardAcknowledgementManager) { LOG.info("Starting to consume shard " + streamPartition.getShardId()); @@ -76,8 +75,7 @@ public Runnable createConsumer(final StreamPartition streamPartition, Instant startTime = null; boolean waitForExport = false; if (progressState.isPresent()) { - // We can't checkpoint with acks yet - sequenceNumber = acknowledgementSet == null ? null : progressState.get().getSequenceNumber(); + sequenceNumber = shardAcknowledgementManager == null ? null : progressState.get().getSequenceNumber(); waitForExport = progressState.get().shouldWaitForExport(); if (progressState.get().getStartTime() != 0) { startTime = Instant.ofEpochMilli(progressState.get().getStartTime()); @@ -104,13 +102,13 @@ public Runnable createConsumer(final StreamPartition streamPartition, ShardConsumer shardConsumer = ShardConsumer.builder(streamsClient, pluginMetrics, dynamoDBSourceAggregateMetrics, buffer, streamConfig) .tableInfo(tableInfo) .checkpointer(checkpointer) + .shardAcknowledgementManager(shardAcknowledgementManager) + .streamPartition(streamPartition) .shardIterator(shardIterator) .shardId(streamPartition.getShardId()) .lastShardIterator(lastShardIterator) .startTime(startTime) .waitForExport(waitForExport) - .acknowledgmentSet(acknowledgementSet) - .acknowledgmentSetTimeout(shardAcknowledgmentTimeout) .build(); return shardConsumer; } diff --git a/data-prepper-plugins/dynamodb-source/src/main/java/org/opensearch/dataprepper/plugins/source/dynamodb/stream/StreamCheckpointer.java b/data-prepper-plugins/dynamodb-source/src/main/java/org/opensearch/dataprepper/plugins/source/dynamodb/stream/StreamCheckpointer.java index e7ba18688d..62d5559295 100644 --- a/data-prepper-plugins/dynamodb-source/src/main/java/org/opensearch/dataprepper/plugins/source/dynamodb/stream/StreamCheckpointer.java +++ b/data-prepper-plugins/dynamodb-source/src/main/java/org/opensearch/dataprepper/plugins/source/dynamodb/stream/StreamCheckpointer.java @@ -62,7 +62,7 @@ public boolean isExportDone() { return globalPartition.isPresent(); } - public void updateShardForAcknowledgmentWait(final Duration acknowledgmentSetTimeout) { - coordinator.saveProgressStateForPartition(streamPartition, acknowledgmentSetTimeout); + public void completePartition() { + coordinator.completePartition(streamPartition); } } diff --git a/data-prepper-plugins/dynamodb-source/src/main/java/org/opensearch/dataprepper/plugins/source/dynamodb/stream/StreamScheduler.java b/data-prepper-plugins/dynamodb-source/src/main/java/org/opensearch/dataprepper/plugins/source/dynamodb/stream/StreamScheduler.java index bd58e9e0d5..976965e432 100644 --- a/data-prepper-plugins/dynamodb-source/src/main/java/org/opensearch/dataprepper/plugins/source/dynamodb/stream/StreamScheduler.java +++ b/data-prepper-plugins/dynamodb-source/src/main/java/org/opensearch/dataprepper/plugins/source/dynamodb/stream/StreamScheduler.java @@ -6,7 +6,6 @@ package org.opensearch.dataprepper.plugins.source.dynamodb.stream; import org.opensearch.dataprepper.metrics.PluginMetrics; -import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSet; import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSetManager; import org.opensearch.dataprepper.model.source.coordinator.enhanced.EnhancedSourceCoordinator; import org.opensearch.dataprepper.model.source.coordinator.enhanced.EnhancedSourcePartition; @@ -24,8 +23,6 @@ import java.util.concurrent.atomic.AtomicLong; import java.util.function.BiConsumer; -import static org.opensearch.dataprepper.logging.DataPrepperMarkers.NOISY; - /** * A scheduler to manage all the stream related work in one place */ @@ -50,14 +47,14 @@ public class StreamScheduler implements Runnable { private final EnhancedSourceCoordinator coordinator; private final ShardConsumerFactory consumerFactory; private final ExecutorService executor; - private final PluginMetrics pluginMetrics; private final AtomicLong activeChangeEventConsumers; private final AtomicLong shardsInProcessing; - private final AcknowledgementSetManager acknowledgementSetManager; private final DynamoDBSourceConfig dynamoDBSourceConfig; private final BackoffCalculator backoffCalculator; private int noAvailableShardsCount = 0; + private final ShardAcknowledgementManager shardAcknowledgementManager; + public StreamScheduler(final EnhancedSourceCoordinator coordinator, final ShardConsumerFactory consumerFactory, @@ -67,10 +64,10 @@ public StreamScheduler(final EnhancedSourceCoordinator coordinator, final BackoffCalculator backoffCalculator) { this.coordinator = coordinator; this.consumerFactory = consumerFactory; - this.pluginMetrics = pluginMetrics; - this.acknowledgementSetManager = acknowledgementSetManager; this.dynamoDBSourceConfig = dynamoDBSourceConfig; this.backoffCalculator = backoffCalculator; + this.shardAcknowledgementManager = dynamoDBSourceConfig.isAcknowledgmentsEnabled() ? + new ShardAcknowledgementManager(acknowledgementSetManager, coordinator, dynamoDBSourceConfig, coordinator::giveUpPartition) : null; executor = Executors.newFixedThreadPool(MAX_JOB_COUNT); activeChangeEventConsumers = pluginMetrics.gauge(ACTIVE_CHANGE_EVENT_CONSUMERS, new AtomicLong()); @@ -78,42 +75,12 @@ public StreamScheduler(final EnhancedSourceCoordinator coordinator, } private void processStreamPartition(StreamPartition streamPartition) { - final boolean acknowledgmentsEnabled = dynamoDBSourceConfig.isAcknowledgmentsEnabled(); - AcknowledgementSet acknowledgementSet = null; - - if (acknowledgmentsEnabled) { - acknowledgementSet = acknowledgementSetManager.create((result) -> { - if (result) { - LOG.info("Received acknowledgment of completion from sink for shard {}", streamPartition.getShardId()); - completeConsumer(streamPartition).accept(null, null); - } else { - LOG.warn("Negative acknowledgment received for shard {}, it will be retried", streamPartition.getShardId()); - coordinator.giveUpPartition(streamPartition); - } - }, dynamoDBSourceConfig.getShardAcknowledgmentTimeout()); - } - Runnable shardConsumer = consumerFactory.createConsumer(streamPartition, acknowledgementSet, dynamoDBSourceConfig.getShardAcknowledgmentTimeout()); + Runnable shardConsumer = consumerFactory.createConsumer(streamPartition, dynamoDBSourceConfig.getShardAcknowledgmentTimeout(), shardAcknowledgementManager); if (shardConsumer != null) { CompletableFuture runConsumer = CompletableFuture.runAsync(shardConsumer, executor); - - if (acknowledgmentsEnabled) { - runConsumer.whenComplete((v, ex) -> { - numOfWorkers.decrementAndGet(); - if (ex != null) { - LOG.error(NOISY, "Received exception while processing shard {}, giving up this shard for reprocessing: {}", - streamPartition.getShardId(), ex); - coordinator.giveUpPartition(streamPartition); - } - if (numOfWorkers.get() == 0) { - activeChangeEventConsumers.decrementAndGet(); - } - shardsInProcessing.decrementAndGet(); - }); - } else { - runConsumer.whenComplete(completeConsumer(streamPartition)); - } + runConsumer.whenComplete(completeConsumer(streamPartition)); numOfWorkers.incrementAndGet(); if (numOfWorkers.get() % 10 == 0) { SHARD_COUNT_LOGGER.info("Actively processing {} shards", numOfWorkers.get()); @@ -164,6 +131,11 @@ public void run() { // Should Stop LOG.warn("Stream Scheduler is interrupted, looks like shutdown has triggered"); + // Shutdown acknowledgment manager if it exists + if (shardAcknowledgementManager != null) { + shardAcknowledgementManager.shutdown(); + } + // Cannot call executor.shutdownNow() here // Otherwise the final checkpoint will fail due to SDK interruption. ShardConsumer.stopAll(); @@ -172,23 +144,24 @@ public void run() { private BiConsumer completeConsumer(StreamPartition streamPartition) { return (v, ex) -> { - if (!dynamoDBSourceConfig.isAcknowledgmentsEnabled()) { - numOfWorkers.decrementAndGet(); - if (numOfWorkers.get() == 0) { - activeChangeEventConsumers.decrementAndGet(); - } - shardsInProcessing.decrementAndGet(); + numOfWorkers.decrementAndGet(); + if (numOfWorkers.get() == 0) { + activeChangeEventConsumers.decrementAndGet(); } + shardsInProcessing.decrementAndGet(); if (ex == null) { LOG.info("Shard consumer for {} is completed", streamPartition.getShardId()); - coordinator.completePartition(streamPartition); + if (!dynamoDBSourceConfig.isAcknowledgmentsEnabled()) { + coordinator.completePartition(streamPartition); + } } else { - // Do nothing - // The consumer must have already done one last checkpointing. LOG.error("Received an exception while processing shard {}, giving up shard: {}", streamPartition.getShardId(), ex); - coordinator.giveUpPartition(streamPartition); + if (dynamoDBSourceConfig.isAcknowledgmentsEnabled()) { + shardAcknowledgementManager.giveUpPartition(streamPartition); + } else { + coordinator.giveUpPartition(streamPartition); + } } }; } - } \ No newline at end of file diff --git a/data-prepper-plugins/dynamodb-source/src/test/java/org/opensearch/dataprepper/plugins/source/dynamodb/stream/ShardAcknowledgementManagerTest.java b/data-prepper-plugins/dynamodb-source/src/test/java/org/opensearch/dataprepper/plugins/source/dynamodb/stream/ShardAcknowledgementManagerTest.java new file mode 100644 index 0000000000..9b13fe807a --- /dev/null +++ b/data-prepper-plugins/dynamodb-source/src/test/java/org/opensearch/dataprepper/plugins/source/dynamodb/stream/ShardAcknowledgementManagerTest.java @@ -0,0 +1,118 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + */ + +package org.opensearch.dataprepper.plugins.source.dynamodb.stream; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSet; +import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSetManager; +import org.opensearch.dataprepper.model.source.coordinator.enhanced.EnhancedSourceCoordinator; +import org.opensearch.dataprepper.plugins.source.dynamodb.DynamoDBSourceConfig; +import org.opensearch.dataprepper.plugins.source.dynamodb.coordination.partition.StreamPartition; +import org.opensearch.dataprepper.plugins.source.dynamodb.coordination.state.StreamProgressState; + +import java.time.Duration; +import java.time.Instant; +import java.util.Optional; +import java.util.function.Consumer; + +import static org.opensearch.dataprepper.test.helper.ReflectivelySetField.setField; + +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.when; +import static org.mockito.Mockito.verify; + +@ExtendWith(MockitoExtension.class) +class ShardAcknowledgementManagerTest { + + @Mock + private AcknowledgementSetManager acknowledgementSetManager; + + @Mock + private EnhancedSourceCoordinator sourceCoordinator; + + @Mock + private DynamoDBSourceConfig dynamoDBSourceConfig; + + @Mock + private StreamPartition streamPartition; + + @Mock + private StreamProgressState streamProgressState; + + @Mock + private AcknowledgementSet acknowledgementSet; + + @Mock + private Consumer stopWorkerConsumer; + + private ShardAcknowledgementManager shardAcknowledgementManager; + + @BeforeEach + void setUp() { + shardAcknowledgementManager = new ShardAcknowledgementManager( + acknowledgementSetManager, sourceCoordinator, dynamoDBSourceConfig, stopWorkerConsumer); + } + + @Test + void testCreateAcknowledgmentSet() { + when(dynamoDBSourceConfig.getShardAcknowledgmentTimeout()).thenReturn(Duration.ofMinutes(15)); + when(acknowledgementSetManager.create(any(Consumer.class), any(Duration.class))) + .thenReturn(acknowledgementSet); + + AcknowledgementSet result = shardAcknowledgementManager.createAcknowledgmentSet( + streamPartition, "seq123", false); + + assertNotNull(result); + verify(acknowledgementSetManager).create(any(Consumer.class), eq(Duration.ofMinutes(15))); + } + + @Test + void testIsExportDone() { + when(streamPartition.getStreamArn()).thenReturn("stream-arn"); + when(sourceCoordinator.getPartition("stream-arn")).thenReturn(Optional.of(streamPartition)); + + boolean result = shardAcknowledgementManager.isExportDone(streamPartition); + + assertTrue(result); + } + + @Test + void testShutdown() { + assertDoesNotThrow(() -> shardAcknowledgementManager.shutdown()); + } + + @Test + void testUpdateOwnershipForAllShardPartitions() throws Exception { + when(dynamoDBSourceConfig.getShardAcknowledgmentTimeout()).thenReturn(Duration.ofMinutes(15)); + when(acknowledgementSetManager.create(any(Consumer.class), any(Duration.class))).thenReturn(acknowledgementSet); + + // Create acknowledgment set to add partition to checkpoints + shardAcknowledgementManager.createAcknowledgmentSet(streamPartition, "seq123", false); + + // Set lastCheckpointTime to past to trigger checkpoint interval + setField(ShardAcknowledgementManager.class, shardAcknowledgementManager, + "lastCheckpointTime", Instant.now().minus(Duration.ofMinutes(5))); + + // Call updateOwnershipForAllShardPartitions directly + shardAcknowledgementManager.updateOwnershipForAllShardPartitions(); + + // Verify that saveProgressStateForPartition is called + verify(sourceCoordinator).saveProgressStateForPartition(eq(streamPartition), any(Duration.class)); + } +} \ No newline at end of file diff --git a/data-prepper-plugins/dynamodb-source/src/test/java/org/opensearch/dataprepper/plugins/source/dynamodb/stream/ShardConsumerFactoryTest.java b/data-prepper-plugins/dynamodb-source/src/test/java/org/opensearch/dataprepper/plugins/source/dynamodb/stream/ShardConsumerFactoryTest.java index 6d503f7a19..d9627ad329 100644 --- a/data-prepper-plugins/dynamodb-source/src/test/java/org/opensearch/dataprepper/plugins/source/dynamodb/stream/ShardConsumerFactoryTest.java +++ b/data-prepper-plugins/dynamodb-source/src/test/java/org/opensearch/dataprepper/plugins/source/dynamodb/stream/ShardConsumerFactoryTest.java @@ -28,6 +28,7 @@ import software.amazon.awssdk.services.dynamodb.model.InternalServerErrorException; import software.amazon.awssdk.services.dynamodb.streams.DynamoDbStreamsClient; +import java.time.Duration; import java.time.Instant; import java.util.Optional; import java.util.UUID; @@ -116,7 +117,8 @@ public void test_create_shardConsumer_correctly() { streamPartition = new StreamPartition(streamArn, shardId, Optional.of(state)); ShardConsumerFactory consumerFactory = new ShardConsumerFactory(coordinator, dynamoDbStreamsClient, pluginMetrics, dynamoDBSourceAggregateMetrics, buffer, streamConfig); - Runnable consumer = consumerFactory.createConsumer(streamPartition, null, null); + ShardAcknowledgementManager shardAcknowledgementManager = mock(ShardAcknowledgementManager.class); + Runnable consumer = consumerFactory.createConsumer(streamPartition, Duration.ofMinutes(1), shardAcknowledgementManager); assertThat(consumer, notNullValue()); verify(dynamoDbStreamsClient).getShardIterator(any(GetShardIteratorRequest.class)); @@ -133,7 +135,8 @@ public void test_create_shardConsumer_for_closedShards() { streamPartition = new StreamPartition(streamArn, shardId, Optional.of(state)); ShardConsumerFactory consumerFactory = new ShardConsumerFactory(coordinator, dynamoDbStreamsClient, pluginMetrics, dynamoDBSourceAggregateMetrics, buffer, streamConfig); - Runnable consumer = consumerFactory.createConsumer(streamPartition, null, null); + ShardAcknowledgementManager shardAcknowledgementManager = mock(ShardAcknowledgementManager.class); + Runnable consumer = consumerFactory.createConsumer(streamPartition, Duration.ofMinutes(1), shardAcknowledgementManager); assertThat(consumer, notNullValue()); // Should get iterators twice verify(dynamoDbStreamsClient, times(2)).getShardIterator(any(GetShardIteratorRequest.class)); @@ -154,7 +157,8 @@ void stream5xxErrors_is_incremented_when_get_shard_iterator_throws_internal_exce when(dynamoDBSourceAggregateMetrics.getStream5xxErrors()).thenReturn(stream5xxErrors); ShardConsumerFactory consumerFactory = new ShardConsumerFactory(coordinator, dynamoDbStreamsClient, pluginMetrics, dynamoDBSourceAggregateMetrics, buffer, streamConfig); - Runnable consumer = consumerFactory.createConsumer(streamPartition, null, null); + ShardAcknowledgementManager shardAcknowledgementManager = mock(ShardAcknowledgementManager.class); + Runnable consumer = consumerFactory.createConsumer(streamPartition, Duration.ofMinutes(1), shardAcknowledgementManager); assertThat(consumer, nullValue()); verify(stream5xxErrors).increment(); verify(streamApiInvocations).increment(); @@ -172,7 +176,8 @@ void stream4xxErrors_is_incremented_when_get_shard_iterator_throws_dynamodb_exce when(dynamoDBSourceAggregateMetrics.getStream4xxErrors()).thenReturn(stream4xxErrors); ShardConsumerFactory consumerFactory = new ShardConsumerFactory(coordinator, dynamoDbStreamsClient, pluginMetrics, dynamoDBSourceAggregateMetrics, buffer, streamConfig); - Runnable consumer = consumerFactory.createConsumer(streamPartition, null, null); + ShardAcknowledgementManager shardAcknowledgementManager = mock(ShardAcknowledgementManager.class); + Runnable consumer = consumerFactory.createConsumer(streamPartition, Duration.ofMinutes(1), shardAcknowledgementManager); assertThat(consumer, nullValue()); verify(stream4xxErrors).increment(); verify(streamApiInvocations).increment(); diff --git a/data-prepper-plugins/dynamodb-source/src/test/java/org/opensearch/dataprepper/plugins/source/dynamodb/stream/ShardConsumerTest.java b/data-prepper-plugins/dynamodb-source/src/test/java/org/opensearch/dataprepper/plugins/source/dynamodb/stream/ShardConsumerTest.java index 8d7fd97e0d..034405898e 100644 --- a/data-prepper-plugins/dynamodb-source/src/test/java/org/opensearch/dataprepper/plugins/source/dynamodb/stream/ShardConsumerTest.java +++ b/data-prepper-plugins/dynamodb-source/src/test/java/org/opensearch/dataprepper/plugins/source/dynamodb/stream/ShardConsumerTest.java @@ -10,14 +10,12 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; -import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockedStatic; import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.dataprepper.buffer.common.BufferAccumulator; import org.opensearch.dataprepper.metrics.PluginMetrics; import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSet; -import org.opensearch.dataprepper.model.acknowledgements.ProgressCheck; import org.opensearch.dataprepper.model.buffer.Buffer; import org.opensearch.dataprepper.model.event.Event; import org.opensearch.dataprepper.model.source.coordinator.enhanced.EnhancedSourceCoordinator; @@ -47,12 +45,10 @@ import java.util.Optional; import java.util.Random; import java.util.UUID; -import java.util.function.Consumer; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; -import static org.mockito.ArgumentMatchers.eq; import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.lenient; import static org.mockito.Mockito.mock; @@ -63,7 +59,6 @@ import static org.opensearch.dataprepper.plugins.source.dynamodb.stream.ShardConsumer.BUFFER_TIMEOUT; import static org.opensearch.dataprepper.plugins.source.dynamodb.stream.ShardConsumer.DEFAULT_BUFFER_BATCH_SIZE; import static org.opensearch.dataprepper.plugins.source.dynamodb.stream.ShardConsumer.SHARD_PROGRESS; -import static org.opensearch.dataprepper.plugins.source.dynamodb.stream.StreamCheckpointer.CHECKPOINT_OWNERSHIP_TIMEOUT_INCREASE; @ExtendWith(MockitoExtension.class) class ShardConsumerTest { @@ -110,14 +105,13 @@ class ShardConsumerTest { @Mock private StreamConfig streamConfig; - - private StreamCheckpointer checkpointer; + @Mock + private ShardAcknowledgementManager shardAcknowledgementManager; private StreamPartition streamPartition; private TableInfo tableInfo; - private final String tableName = UUID.randomUUID().toString(); private final String tableArn = "arn:aws:dynamodb:us-west-2:123456789012:table/" + tableName; @@ -134,7 +128,6 @@ class ShardConsumerTest { private final int total = random.nextInt(10) + 1; - @BeforeEach void setup() throws Exception { @@ -143,7 +136,6 @@ void setup() throws Exception { state.setStartTime(Instant.now().toEpochMilli()); streamPartition = new StreamPartition(streamArn, shardId, Optional.of(state)); - // Mock Global Table Info lenient().when(coordinator.getPartition(tableArn)).thenReturn(Optional.of(tableInfoGlobalState)); TableMetadata metadata = TableMetadata.builder() @@ -157,14 +149,12 @@ void setup() throws Exception { lenient().when(coordinator.createPartition(any(EnhancedSourcePartition.class))).thenReturn(true); lenient().doNothing().when(coordinator).completePartition(any(EnhancedSourcePartition.class)); - lenient().doNothing().when(coordinator).saveProgressStateForPartition(any(EnhancedSourcePartition.class), eq(null)); + lenient().doNothing().when(coordinator).saveProgressStateForPartition(any(EnhancedSourcePartition.class), any(Duration.class)); lenient().doNothing().when(coordinator).giveUpPartition(any(EnhancedSourcePartition.class)); lenient().doNothing().when(bufferAccumulator).add(any(org.opensearch.dataprepper.model.record.Record.class)); lenient().doNothing().when(bufferAccumulator).flush(); - checkpointer = new StreamCheckpointer(coordinator, streamPartition); - List records = buildRecords(total); GetRecordsResponse response = GetRecordsResponse.builder() .records(records) @@ -177,176 +167,203 @@ void setup() throws Exception { given(pluginMetrics.counter("changeEventsProcessingErrors")).willReturn(testCounter); given(pluginMetrics.summary(anyString())).willReturn(testSummary); - when(aggregateMetrics.getStreamApiInvocations()).thenReturn(streamApiInvocations); - } + lenient().when(aggregateMetrics.getStreamApiInvocations()).thenReturn(streamApiInvocations); + lenient().when(shardAcknowledgementManager.isExportDone(any(StreamPartition.class))).thenReturn(true); + } + @Test void test_run_shardConsumer_correctly() throws Exception { - ShardConsumer shardConsumer; - try ( - final MockedStatic bufferAccumulatorMockedStatic = mockStatic(BufferAccumulator.class) - ) { - bufferAccumulatorMockedStatic.when(() -> BufferAccumulator.create(buffer, DEFAULT_BUFFER_BATCH_SIZE, BUFFER_TIMEOUT)).thenReturn(bufferAccumulator); - shardConsumer = ShardConsumer.builder(dynamoDbStreamsClient, pluginMetrics, aggregateMetrics, buffer, streamConfig) - .shardIterator(shardIterator) - .checkpointer(checkpointer) - .tableInfo(tableInfo) - .startTime(null) - .waitForExport(false) - .build(); + // Disable the static shouldStop flag to prevent early exit + try (MockedStatic shardConsumerMockedStatic = mockStatic(ShardConsumer.class, invocation -> { + if (invocation.getMethod().getName().equals("stopAll")) { + return null; + } else if (invocation.getMethod().getName().equals("shouldStop")) { + return false; + } + return invocation.callRealMethod(); + })) { + ShardConsumer shardConsumer; + try (final MockedStatic bufferAccumulatorMockedStatic = mockStatic(BufferAccumulator.class)) { + bufferAccumulatorMockedStatic.when(() -> BufferAccumulator.create(buffer, DEFAULT_BUFFER_BATCH_SIZE, BUFFER_TIMEOUT)).thenReturn(bufferAccumulator); + shardConsumer = ShardConsumer.builder(dynamoDbStreamsClient, pluginMetrics, aggregateMetrics, buffer, streamConfig) + .shardIterator(shardIterator) + .shardAcknowledgementManager(shardAcknowledgementManager) + .streamPartition(streamPartition) + .tableInfo(tableInfo) + .startTime(null) + .waitForExport(false) + .build(); + } + + shardConsumer.run(); + + verify(dynamoDbStreamsClient).getRecords(any(GetRecordsRequest.class)); + verify(bufferAccumulator, times(total)).add(any(org.opensearch.dataprepper.model.record.Record.class)); + verify(bufferAccumulator).flush(); + verify(streamApiInvocations).increment(); + verify(shardProgress).increment(); } - - shardConsumer.run(); - - // Should call GetRecords - verify(dynamoDbStreamsClient).getRecords(any(GetRecordsRequest.class)); - - // Should write to buffer - verify(bufferAccumulator, times(total)).add(any(org.opensearch.dataprepper.model.record.Record.class)); - verify(bufferAccumulator).flush(); - // Should complete the consumer as reach to end of shard - verify(coordinator).saveProgressStateForPartition(any(StreamPartition.class), eq(CHECKPOINT_OWNERSHIP_TIMEOUT_INCREASE)); - - verify(streamApiInvocations).increment(); - verify(shardProgress).increment(); } @Test void test_run_shardConsumer_with_acknowledgments_correctly() throws Exception { final AcknowledgementSet acknowledgementSet = mock(AcknowledgementSet.class); - final Duration acknowledgmentTimeout = Duration.ofSeconds(30); - - ShardConsumer shardConsumer; - try ( - final MockedStatic bufferAccumulatorMockedStatic = mockStatic(BufferAccumulator.class) - ) { - bufferAccumulatorMockedStatic.when(() -> BufferAccumulator.create(buffer, DEFAULT_BUFFER_BATCH_SIZE, BUFFER_TIMEOUT)).thenReturn(bufferAccumulator); - shardConsumer = ShardConsumer.builder(dynamoDbStreamsClient, pluginMetrics, aggregateMetrics, buffer, streamConfig) - .shardIterator(shardIterator) - .checkpointer(checkpointer) - .tableInfo(tableInfo) - .startTime(null) - .acknowledgmentSetTimeout(acknowledgmentTimeout) - .acknowledgmentSet(acknowledgementSet) - .waitForExport(false) - .build(); + + // Mock the shardAcknowledgementManager to return our mock acknowledgementSet + lenient().when(shardAcknowledgementManager.createAcknowledgmentSet(any(StreamPartition.class), any(String.class), any(Boolean.class))) + .thenReturn(acknowledgementSet); + + // Disable the static shouldStop flag to prevent early exit + try (MockedStatic shardConsumerMockedStatic = mockStatic(ShardConsumer.class, invocation -> { + if (invocation.getMethod().getName().equals("stopAll")) { + return null; + } else if (invocation.getMethod().getName().equals("shouldStop")) { + return false; + } + return invocation.callRealMethod(); + })) { + ShardConsumer shardConsumer; + try (final MockedStatic bufferAccumulatorMockedStatic = mockStatic(BufferAccumulator.class)) { + bufferAccumulatorMockedStatic.when(() -> BufferAccumulator.create(buffer, DEFAULT_BUFFER_BATCH_SIZE, BUFFER_TIMEOUT)).thenReturn(bufferAccumulator); + shardConsumer = ShardConsumer.builder(dynamoDbStreamsClient, pluginMetrics, aggregateMetrics, buffer, streamConfig) + .shardIterator(shardIterator) + .shardAcknowledgementManager(shardAcknowledgementManager) + .streamPartition(streamPartition) + .tableInfo(tableInfo) + .startTime(null) + .waitForExport(false) + .build(); + } + + shardConsumer.run(); + + verify(dynamoDbStreamsClient).getRecords(any(GetRecordsRequest.class)); + verify(bufferAccumulator, times(total)).add(any(org.opensearch.dataprepper.model.record.Record.class)); + verify(bufferAccumulator).flush(); + verify(streamApiInvocations).increment(); + verify(shardProgress).increment(); } - - shardConsumer.run(); - - final ArgumentCaptor progressCheckConsumerArgumentCaptor = ArgumentCaptor.forClass(Consumer.class); - verify(acknowledgementSet).addProgressCheck(progressCheckConsumerArgumentCaptor.capture(), any(Duration.class)); - - final Consumer progressCheckConsumer = progressCheckConsumerArgumentCaptor.getValue(); - progressCheckConsumer.accept(mock(ProgressCheck.class)); - - verify(acknowledgementSet).increaseExpiry(any(Duration.class)); - - // Should call GetRecords - verify(dynamoDbStreamsClient).getRecords(any(GetRecordsRequest.class)); - - // Should write to buffer - verify(bufferAccumulator, times(total)).add(any(org.opensearch.dataprepper.model.record.Record.class)); - verify(bufferAccumulator).flush(); - - // Should complete the consumer as reach to end of shard - verify(coordinator).saveProgressStateForPartition(any(StreamPartition.class), eq(CHECKPOINT_OWNERSHIP_TIMEOUT_INCREASE)); - - verify(acknowledgementSet).complete(); - - verify(streamApiInvocations).increment(); - verify(shardProgress).increment(); } @Test void test_run_shardConsumer_with_acknowledgments_and_error_cancels_acknowledgment_set() throws Exception { - final AcknowledgementSet acknowledgementSet = mock(AcknowledgementSet.class); - final Duration acknowledgmentTimeout = Duration.ofSeconds(30); - - when(aggregateMetrics.getStream5xxErrors()).thenReturn(stream5xxErrors); when(dynamoDbStreamsClient.getRecords(any(GetRecordsRequest.class))).thenThrow(InternalServerErrorException.class); + when(aggregateMetrics.getStream5xxErrors()).thenReturn(stream5xxErrors); - ShardConsumer shardConsumer; - try ( - final MockedStatic bufferAccumulatorMockedStatic = mockStatic(BufferAccumulator.class) - ) { - bufferAccumulatorMockedStatic.when(() -> BufferAccumulator.create(buffer, DEFAULT_BUFFER_BATCH_SIZE, BUFFER_TIMEOUT)).thenReturn(bufferAccumulator); - shardConsumer = ShardConsumer.builder(dynamoDbStreamsClient, pluginMetrics, aggregateMetrics, buffer, streamConfig) - .shardIterator(shardIterator) - .checkpointer(checkpointer) - .tableInfo(tableInfo) - .startTime(null) - .acknowledgmentSetTimeout(acknowledgmentTimeout) - .acknowledgmentSet(acknowledgementSet) - .waitForExport(false) - .build(); + // Disable the static shouldStop flag to prevent early exit + try (MockedStatic shardConsumerMockedStatic = mockStatic(ShardConsumer.class, invocation -> { + if (invocation.getMethod().getName().equals("stopAll")) { + return null; + } + return invocation.callRealMethod(); + })) { + ShardConsumer shardConsumer; + try (final MockedStatic bufferAccumulatorMockedStatic = mockStatic(BufferAccumulator.class)) { + bufferAccumulatorMockedStatic.when(() -> BufferAccumulator.create(buffer, DEFAULT_BUFFER_BATCH_SIZE, BUFFER_TIMEOUT)).thenReturn(bufferAccumulator); + shardConsumer = ShardConsumer.builder(dynamoDbStreamsClient, pluginMetrics, aggregateMetrics, buffer, streamConfig) + .shardIterator(shardIterator) + .shardAcknowledgementManager(shardAcknowledgementManager) + .streamPartition(streamPartition) + .tableInfo(tableInfo) + .startTime(null) + .waitForExport(false) + .build(); + } + + assertThrows(RuntimeException.class, shardConsumer::run); + + verify(stream5xxErrors).increment(); + verify(streamApiInvocations).increment(); } - - assertThrows(RuntimeException.class, shardConsumer::run); - - final ArgumentCaptor progressCheckConsumerArgumentCaptor = ArgumentCaptor.forClass(Consumer.class); - verify(acknowledgementSet).addProgressCheck(progressCheckConsumerArgumentCaptor.capture(), any(Duration.class)); - - final Consumer progressCheckConsumer = progressCheckConsumerArgumentCaptor.getValue(); - progressCheckConsumer.accept(mock(ProgressCheck.class)); - - verify(acknowledgementSet).increaseExpiry(any(Duration.class)); - - verify(acknowledgementSet).cancel(); } @Test void test_run_shardConsumer_catches_5xx_exception_and_increments_metric() { - ShardConsumer shardConsumer; + // First set up the mocks for the exception case + when(dynamoDbStreamsClient.getRecords(any(GetRecordsRequest.class))).thenThrow(InternalServerErrorException.class); when(aggregateMetrics.getStream5xxErrors()).thenReturn(stream5xxErrors); - try ( - final MockedStatic bufferAccumulatorMockedStatic = mockStatic(BufferAccumulator.class)) { - bufferAccumulatorMockedStatic.when(() -> BufferAccumulator.create(buffer, DEFAULT_BUFFER_BATCH_SIZE, BUFFER_TIMEOUT)).thenReturn(bufferAccumulator); - shardConsumer = ShardConsumer.builder(dynamoDbStreamsClient, pluginMetrics, aggregateMetrics, buffer, streamConfig) - .shardIterator(shardIterator) - .checkpointer(checkpointer) - .tableInfo(tableInfo) - .startTime(null) - .waitForExport(false) - .build(); + + // Disable the static shouldStop flag to prevent early exit + try (MockedStatic shardConsumerMockedStatic = mockStatic(ShardConsumer.class, invocation -> { + if (invocation.getMethod().getName().equals("stopAll")) { + return null; + } + return invocation.callRealMethod(); + })) { + ShardConsumer shardConsumer; + try (final MockedStatic bufferAccumulatorMockedStatic = mockStatic(BufferAccumulator.class)) { + bufferAccumulatorMockedStatic.when(() -> BufferAccumulator.create(buffer, DEFAULT_BUFFER_BATCH_SIZE, BUFFER_TIMEOUT)).thenReturn(bufferAccumulator); + shardConsumer = ShardConsumer.builder(dynamoDbStreamsClient, pluginMetrics, aggregateMetrics, buffer, streamConfig) + .shardIterator(shardIterator) + .shardAcknowledgementManager(shardAcknowledgementManager) + .streamPartition(streamPartition) + .tableInfo(tableInfo) + .startTime(null) + .waitForExport(false) + .build(); + } + + assertThrows(RuntimeException.class, shardConsumer::run); + + verify(stream5xxErrors).increment(); + verify(streamApiInvocations).increment(); } - - when(dynamoDbStreamsClient.getRecords(any(GetRecordsRequest.class))).thenThrow(InternalServerErrorException.class); - - assertThrows(RuntimeException.class, shardConsumer::run); - - verify(stream5xxErrors).increment(); - verify(streamApiInvocations).increment(); } @Test void test_run_shardConsumer_catches_4xx_exception_and_increments_metric() { - ShardConsumer shardConsumer; + // First set up the mocks for the exception case + when(dynamoDbStreamsClient.getRecords(any(GetRecordsRequest.class))).thenThrow(DynamoDbException.class); when(aggregateMetrics.getStream4xxErrors()).thenReturn(stream4xxErrors); - try ( - final MockedStatic bufferAccumulatorMockedStatic = mockStatic(BufferAccumulator.class)) { + + // Disable the static shouldStop flag to prevent early exit + try (MockedStatic shardConsumerMockedStatic = mockStatic(ShardConsumer.class, invocation -> { + if (invocation.getMethod().getName().equals("stopAll")) { + return null; + } + return invocation.callRealMethod(); + })) { + ShardConsumer shardConsumer; + try (final MockedStatic bufferAccumulatorMockedStatic = mockStatic(BufferAccumulator.class)) { + bufferAccumulatorMockedStatic.when(() -> BufferAccumulator.create(buffer, DEFAULT_BUFFER_BATCH_SIZE, BUFFER_TIMEOUT)).thenReturn(bufferAccumulator); + shardConsumer = ShardConsumer.builder(dynamoDbStreamsClient, pluginMetrics, aggregateMetrics, buffer, streamConfig) + .shardIterator(shardIterator) + .shardAcknowledgementManager(shardAcknowledgementManager) + .streamPartition(streamPartition) + .tableInfo(tableInfo) + .startTime(null) + .waitForExport(false) + .build(); + } + + assertThrows(RuntimeException.class, shardConsumer::run); + + verify(stream4xxErrors).increment(); + verify(streamApiInvocations).increment(); + } + } + + @Test + void test_run_shardConsumer_calls_startUpdatingOwnershipForShard() throws Exception { + try (final MockedStatic bufferAccumulatorMockedStatic = mockStatic(BufferAccumulator.class)) { bufferAccumulatorMockedStatic.when(() -> BufferAccumulator.create(buffer, DEFAULT_BUFFER_BATCH_SIZE, BUFFER_TIMEOUT)).thenReturn(bufferAccumulator); - shardConsumer = ShardConsumer.builder(dynamoDbStreamsClient, pluginMetrics, aggregateMetrics, buffer, streamConfig) - .shardIterator(shardIterator) - .checkpointer(checkpointer) + ShardConsumer shardConsumer = ShardConsumer.builder(dynamoDbStreamsClient, pluginMetrics, aggregateMetrics, buffer, streamConfig) + .shardIterator(null) + .shardAcknowledgementManager(shardAcknowledgementManager) + .streamPartition(streamPartition) .tableInfo(tableInfo) .startTime(null) .waitForExport(false) .build(); - } - when(dynamoDbStreamsClient.getRecords(any(GetRecordsRequest.class))).thenThrow(DynamoDbException.class); - - assertThrows(RuntimeException.class, shardConsumer::run); - - verify(stream4xxErrors).increment(); - verify(streamApiInvocations).increment(); + shardConsumer.run(); + } + // Verify that startUpdatingOwnershipForShard is called + verify(shardAcknowledgementManager).startUpdatingOwnershipForShard(streamPartition); } - /** - * Helper function to generate some data. - */ private List buildRecords(int count) { List records = new ArrayList<>(); for (int i = 0; i < count; i++) { diff --git a/data-prepper-plugins/dynamodb-source/src/test/java/org/opensearch/dataprepper/plugins/source/dynamodb/stream/StreamSchedulerTest.java b/data-prepper-plugins/dynamodb-source/src/test/java/org/opensearch/dataprepper/plugins/source/dynamodb/stream/StreamSchedulerTest.java index 5928442dbb..d1ff6d382d 100644 --- a/data-prepper-plugins/dynamodb-source/src/test/java/org/opensearch/dataprepper/plugins/source/dynamodb/stream/StreamSchedulerTest.java +++ b/data-prepper-plugins/dynamodb-source/src/test/java/org/opensearch/dataprepper/plugins/source/dynamodb/stream/StreamSchedulerTest.java @@ -126,7 +126,13 @@ public void test_normal_run() throws InterruptedException { when(backoffCalculator.calculateBackoffToAcquireNextShard(eq(1), any(AtomicInteger.class))) .thenReturn(10000L); - when(consumerFactory.createConsumer(any(StreamPartition.class), eq(null), any(Duration.class))).thenReturn(() -> LOG.info("Hello")); + // Set up the mock for getShardAcknowledgmentTimeout + Duration timeout = Duration.ofMinutes(1); + when(dynamoDBSourceConfig.getShardAcknowledgmentTimeout()).thenReturn(timeout); + when(dynamoDBSourceConfig.isAcknowledgmentsEnabled()).thenReturn(false); + + // Set up the mock for createConsumer with the specific timeout + when(consumerFactory.createConsumer(any(StreamPartition.class), eq(timeout), eq(null))).thenReturn(() -> LOG.info("Hello")); when(coordinator.acquireAvailablePartition(StreamPartition.PARTITION_TYPE)).thenReturn(Optional.of(streamPartition)).thenReturn(Optional.empty()); scheduler = new StreamScheduler(coordinator, consumerFactory, pluginMetrics, acknowledgementSetManager, dynamoDBSourceConfig, backoffCalculator); @@ -140,8 +146,8 @@ public void test_normal_run() throws InterruptedException { // Should acquire the stream partition verify(coordinator).acquireAvailablePartition(StreamPartition.PARTITION_TYPE); - // Should start a new consumer - verify(consumerFactory).createConsumer(any(StreamPartition.class), eq(null), any(Duration.class)); + // Should start a new consumer with the specific timeout + verify(consumerFactory).createConsumer(any(StreamPartition.class), eq(timeout), eq(null)); // Should mask the stream partition as completed. verify(coordinator).completePartition(any(StreamPartition.class)); @@ -173,7 +179,7 @@ public void test_normal_run_with_acknowledgments() throws InterruptedException { return acknowledgementSet; }).when(acknowledgementSetManager).create(any(Consumer.class), eq(shardAcknowledgmentTimeout)); - when(consumerFactory.createConsumer(any(StreamPartition.class), eq(acknowledgementSet), eq(shardAcknowledgmentTimeout))).thenReturn(() -> LOG.info("Hello")); + when(consumerFactory.createConsumer(any(StreamPartition.class), eq(shardAcknowledgmentTimeout), any(ShardAcknowledgementManager.class))).thenReturn(() -> LOG.info("Hello")); scheduler = new StreamScheduler(coordinator, consumerFactory, pluginMetrics, acknowledgementSetManager, dynamoDBSourceConfig, backoffCalculator); @@ -188,10 +194,7 @@ public void test_normal_run_with_acknowledgments() throws InterruptedException { // Should acquire the stream partition verify(coordinator).acquireAvailablePartition(StreamPartition.PARTITION_TYPE); // Should start a new consumer - verify(consumerFactory).createConsumer(any(StreamPartition.class), any(AcknowledgementSet.class), any(Duration.class)); - - // Should mask the stream partition as completed. - verify(coordinator).completePartition(any(StreamPartition.class)); + verify(consumerFactory).createConsumer(any(StreamPartition.class), any(Duration.class), any(ShardAcknowledgementManager.class)); verify(activeShardsInProcessing).incrementAndGet(); verify(activeShardsInProcessing).decrementAndGet();