diff --git a/data-prepper-plugins/sqs-sink/build.gradle b/data-prepper-plugins/sqs-sink/build.gradle index ef3241d8e6..436f89f42d 100644 --- a/data-prepper-plugins/sqs-sink/build.gradle +++ b/data-prepper-plugins/sqs-sink/build.gradle @@ -46,7 +46,7 @@ jacocoTestCoverageVerification { violationRules { rule { limit { - minimum = 0.90 + minimum = 0.99 } } } diff --git a/data-prepper-plugins/sqs-sink/src/integrationTest/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSinkIT.java b/data-prepper-plugins/sqs-sink/src/integrationTest/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSinkIT.java index da7d2bddba..a30afa3257 100644 --- a/data-prepper-plugins/sqs-sink/src/integrationTest/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSinkIT.java +++ b/data-prepper-plugins/sqs-sink/src/integrationTest/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSinkIT.java @@ -12,6 +12,7 @@ import org.opensearch.dataprepper.aws.api.AwsConfig; import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; import io.micrometer.core.instrument.Counter; +import io.micrometer.core.instrument.DistributionSummary; import com.fasterxml.jackson.databind.ObjectMapper; import org.apache.commons.lang3.RandomStringUtils; import org.junit.jupiter.api.extension.ExtendWith; @@ -65,6 +66,7 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; import static org.mockito.Mockito.lenient; +import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.times; import static org.hamcrest.CoreMatchers.equalTo; @@ -125,6 +127,8 @@ public class SqsSinkIT { private Counter requestsFailedCounter; @Mock private Counter dlqSuccessCounter; + @Mock + private DistributionSummary summary; private JsonOutputCodec jsonCodec; private String bucket; @@ -165,6 +169,8 @@ void setUp() { requestsSuccessCounter = mock(Counter.class); requestsFailedCounter = mock(Counter.class); dlqSuccessCounter = mock(Counter.class); + summary = mock(DistributionSummary.class); + doNothing().when(summary).record(any(Double.class)); lenient().doAnswer((a)-> { int v = (int)(double)(a.getArgument(0)); eventsSuccessCount.addAndGet(v); @@ -213,6 +219,7 @@ void setUp() { } return null; }).when(pluginMetrics).counter(anyString()); + when(pluginMetrics.summary(anyString())).thenReturn(summary); messages = new ArrayList<>(); pluginFactory = mock(PluginFactory.class); jsonCodec = new JsonOutputCodec(new JsonOutputCodecConfig()); diff --git a/data-prepper-plugins/sqs-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSink.java b/data-prepper-plugins/sqs-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSink.java index a067358b14..e624c22e76 100644 --- a/data-prepper-plugins/sqs-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSink.java +++ b/data-prepper-plugins/sqs-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSink.java @@ -71,13 +71,9 @@ public SqsSink(final PluginSetting pluginSetting, codecPluginSettings = new PluginSetting("ndjson", Map.of()); } - final OutputCodec outputCodec = pluginFactory.loadPlugin(OutputCodec.class, codecPluginSettings); AwsConfig awsConfig = sqsSinkConfig.getAwsConfig(); - if (awsConfig == null && awsCredentialsSupplier == null) { - throw new RuntimeException("Missing awsConfig and awsCredentialsSupplier"); - } - final AwsCredentialsProvider awsCredentialsProvider = awsConfig != null ? awsCredentialsSupplier.getProvider(convertToCredentialOptions(awsConfig)) : awsCredentialsSupplier.getProvider(AwsCredentialsOptions.builder().build()); - Region region = awsConfig != null ? awsConfig.getAwsRegion() : awsCredentialsSupplier.getDefaultRegion().get(); + final AwsCredentialsProvider awsCredentialsProvider = (awsConfig != null) ? awsCredentialsSupplier.getProvider(convertToCredentialOptions(awsConfig)) : awsCredentialsSupplier.getProvider(AwsCredentialsOptions.builder().build()); + Region region = (awsConfig != null) ? awsConfig.getAwsRegion() : awsCredentialsSupplier.getDefaultRegion().get(); final SqsClient sqsClient = SqsClientFactory.createSqsClient(region, awsCredentialsProvider); DlqPushHandler dlqPushHandler = null; @@ -89,6 +85,7 @@ public SqsSink(final PluginSetting pluginSetting, String role = stsClient.getCallerIdentity().arn(); dlqPushHandler = new DlqPushHandler(pluginFactory, pluginSetting, pluginMetrics, sqsSinkConfig.getDlq(), region.toString(), role, "sqsSink"); } + final OutputCodec outputCodec = pluginFactory.loadPlugin(OutputCodec.class, codecPluginSettings); sqsSinkService = new SqsSinkService(sqsSinkConfig, sqsClient, expressionEvaluator, outputCodec, sinkContext, dlqPushHandler, pluginMetrics); } diff --git a/data-prepper-plugins/sqs-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSinkBatch.java b/data-prepper-plugins/sqs-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSinkBatch.java index a8e7a5a335..195d3fff1f 100644 --- a/data-prepper-plugins/sqs-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSinkBatch.java +++ b/data-prepper-plugins/sqs-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSinkBatch.java @@ -20,8 +20,11 @@ import java.time.Instant; +import java.io.IOException; +import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.HashMap; +import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.UUID; @@ -43,7 +46,9 @@ public class SqsSinkBatch { private final SqsClient sqsClient; private final BufferFactory bufferFactory; private final SqsSinkMetrics sinkMetrics; + private String currentId; private SqsSinkBatchEntry currentBatchEntry; + private final BiConsumer addToDLQList; public SqsSinkBatch(final BufferFactory bufferFactory, final SqsClient sqsClient, @@ -51,11 +56,12 @@ public SqsSinkBatch(final BufferFactory bufferFactory, final String queueUrl, final OutputCodec codec, final OutputCodecContext codecContext, - final long maxMessageSize, - final int maxEvents) { - this.maxMessageSize = maxMessageSize; + final SqsThresholdConfig thresholdConfig, + final BiConsumer addToDLQList) { this.bufferFactory = bufferFactory; - this.maxEvents = maxEvents; + this.maxMessageSize = thresholdConfig.getMaxMessageSizeBytes(); + this.maxEvents = thresholdConfig.getMaxEventsPerMessage(); + this.addToDLQList = addToDLQList; this.codec = codec; this.sinkMetrics = sinkMetrics; this.codecContext = codecContext; @@ -66,6 +72,7 @@ public SqsSinkBatch(final BufferFactory bufferFactory, fifoQueue = queueUrl.endsWith(SQS_FIFO_SUFFIX); entries = new HashMap<>(); currentBatchEntry = null; + currentId = null; } public String getQueueUrl() { @@ -73,7 +80,7 @@ public String getQueueUrl() { } private boolean isFull() { - return entries.size() == MAX_MESSAGES_PER_BATCH && (currentBatchEntry.getEventCount() == maxEvents || currentBatchEntry.getSize() == maxMessageSize); + return entries.size() == MAX_MESSAGES_PER_BATCH && currentBatchEntry.getEventCount() == maxEvents; } public boolean willExceedLimits(long estimatedSize) { @@ -99,7 +106,12 @@ public boolean addEntry(final Event event, String groupId, String deDupId, final currentBatchEntry.addEvent(event); return isFull(); } else { - currentBatchEntry.complete(); + try { + currentBatchEntry.complete(); + } catch (IOException ex) { + addToDLQList.accept(currentBatchEntry, ex.getMessage()); + entries.remove(currentId); + } } } if (entries.size() == MAX_MESSAGES_PER_BATCH) { @@ -115,6 +127,7 @@ public boolean addEntry(final Event event, String groupId, String deDupId, final currentBatchEntry.addEvent(event); final String id = UUID.randomUUID().toString(); + currentId = id; entries.put(id, currentBatchEntry); return isFull(); } @@ -126,7 +139,16 @@ public long getLastFlushedTime() { public long getCurrentBatchSize() { long sum = 0; for (Map.Entry entry : entries.entrySet()) { - sum += entry.getValue().getSize(); + SqsSinkBatchEntry batchEntry = entry.getValue(); + sum += batchEntry.getSize(); + if (fifoQueue) { + if (batchEntry.getGroupId() != null) { + sum += batchEntry.getGroupId().length(); + } + if (batchEntry.getDedupId() != null) { + sum += batchEntry.getDedupId().length(); + } + } } return sum; } @@ -135,9 +157,16 @@ public int getEventCount() { return entries.values().stream().mapToInt(SqsSinkBatchEntry::getEventCount).sum(); } - public void setFlushReady() throws Exception { - for (Map.Entry entry: entries.entrySet()) { - entry.getValue().complete(); + public void setFlushReady() { + Iterator> iterator = entries.entrySet().iterator(); + while (iterator.hasNext()) { + Map.Entry entry = iterator.next(); + try { + entry.getValue().complete(); + } catch (IOException ex) { + addToDLQList.accept(entry.getValue(), ex.getMessage()); + iterator.remove(); + } } flushReady = true; } @@ -162,16 +191,31 @@ private boolean isRetryableException(SqsException e) { return (e instanceof RequestThrottledException); } - public boolean flushOnce(final BiConsumer addToDLQList) { - if (!isReady()) { + private long getEntrySize(SqsSinkBatchEntry entry) { + long result = entry.getBody().getBytes(StandardCharsets.UTF_8).length; + if (fifoQueue) { + if (entry.getGroupId() != null) { + result += entry.getGroupId().getBytes(StandardCharsets.UTF_8).length; + } + if (entry.getDedupId() != null) { + result += entry.getDedupId().getBytes(StandardCharsets.UTF_8).length; + } + } + return result; + } + + public boolean flushOnce() { + if (!isReady() || entries.size() == 0) { return true; } SendMessageBatchResponse flushResponse; List requestEntries = new ArrayList<>(); + long requestSize = 0; for (Map.Entry groupEntry: entries.entrySet()) { final String id = groupEntry.getKey(); final SqsSinkBatchEntry entry = groupEntry.getValue(); requestEntries.add(getRequestEntry(id, entry)); + requestSize += getEntrySize(entry); } SendMessageBatchRequest batchRequest = SendMessageBatchRequest.builder() @@ -180,6 +224,8 @@ public boolean flushOnce(final BiConsumer addToDLQLis .build(); try { flushResponse = sqsClient.sendMessageBatch(batchRequest); + sinkMetrics.recordRequestSize((double)requestSize); + } catch (SqsException e) { sinkMetrics.incrementRequestsFailedCounter(1); sinkMetrics.incrementEventsFailedCounter(entries.size()); diff --git a/data-prepper-plugins/sqs-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSinkBatchEntry.java b/data-prepper-plugins/sqs-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSinkBatchEntry.java index 8fdee5db33..1e65d0f3a0 100644 --- a/data-prepper-plugins/sqs-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSinkBatchEntry.java +++ b/data-prepper-plugins/sqs-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSinkBatchEntry.java @@ -11,6 +11,7 @@ import org.opensearch.dataprepper.model.sink.OutputCodecContext; import org.opensearch.dataprepper.plugins.accumulator.Buffer; +import java.io.IOException; import java.util.ArrayList; import java.util.List; @@ -65,12 +66,12 @@ public long getSize() { return buffer.getSize(); } - public void complete() throws Exception { + public void complete() throws IOException { if (completed) { return; } - writer.complete(); completed = true; + writer.complete(); } diff --git a/data-prepper-plugins/sqs-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSinkExecutor.java b/data-prepper-plugins/sqs-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSinkExecutor.java index ced091c79b..de6dd85fb9 100644 --- a/data-prepper-plugins/sqs-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSinkExecutor.java +++ b/data-prepper-plugins/sqs-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSinkExecutor.java @@ -15,14 +15,11 @@ import java.time.Duration; import java.util.Collection; -import static org.opensearch.dataprepper.logging.DataPrepperMarkers.NOISY; - public abstract class SqsSinkExecutor { private static final Logger LOG = LoggerFactory.getLogger(SqsSinkExecutor.class); private static final long INITIAL_DELAY_MS = 10; private static final long MAXIMUM_DELAY_MS = Duration.ofMinutes(10).toMillis(); - public void execute(Collection> records) { if (records.isEmpty()) { lock(); @@ -33,6 +30,7 @@ public void execute(Collection> records) { } finally { unlock(); } + pushDLQList(); return; } lock(); @@ -66,23 +64,25 @@ public void flushBuffer() { Object failedStatus = null; int maxRetries = getMaxRetries(); final Backoff backoff = Backoff.exponential(INITIAL_DELAY_MS, MAXIMUM_DELAY_MS).withMaxAttempts(maxRetries); + long startTime = System.nanoTime(); while (retryCount <= maxRetries) { failedStatus = doFlushOnce(failedStatus); - if (failedStatus != null) { - final long delayMillis = backoff.nextDelayMillis(retryCount); - if (delayMillis < 0) { - break; - } - try { - Thread.sleep(delayMillis); - } catch (final InterruptedException e){ - LOG.error(NOISY, "Thread is interrupted while attempting to SQS with retry.", e); - } + if (failedStatus == null) { + break; } + final long delayMillis = backoff.nextDelayMillis(retryCount); + if (delayMillis < 0) { + break; + } + try { + Thread.sleep(delayMillis); + } catch (final InterruptedException e){} retryCount++; } if (failedStatus != null) { pushFailedObjectsToDlq(failedStatus); + } else { + recordLatency((double)System.nanoTime() - startTime); } } @@ -96,6 +96,7 @@ public void flushBuffer() { public abstract boolean willExceedMaxBatchSize(final Event event, final long estimatedSize) throws Exception; public abstract boolean exceedsMaxEventSizeThreshold(final long estimatedSize); public abstract long getEstimatedSize(final Event event) throws Exception; + public abstract void recordLatency(double latencyMillis); public abstract void lock(); public abstract void unlock(); diff --git a/data-prepper-plugins/sqs-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSinkMetrics.java b/data-prepper-plugins/sqs-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSinkMetrics.java index 7cf5a9a752..2e0d4cbd2d 100644 --- a/data-prepper-plugins/sqs-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSinkMetrics.java +++ b/data-prepper-plugins/sqs-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSinkMetrics.java @@ -6,6 +6,7 @@ package org.opensearch.dataprepper.plugins.sink.sqs; import io.micrometer.core.instrument.Counter; +import io.micrometer.core.instrument.DistributionSummary; import org.opensearch.dataprepper.metrics.PluginMetrics; public class SqsSinkMetrics { @@ -13,16 +14,22 @@ public class SqsSinkMetrics { public static final String SQS_SINK_EVENTS_SUCCEEDED = "sqsSinkEventsSucceeded"; public static final String SQS_SINK_EVENTS_FAILED = "sqsSinkEventsFailed"; public static final String SQS_SINK_REQUESTS_FAILED = "sqsSinkRequestsFailed"; + public static final String SQS_SINK_REQUEST_LATENCY = "sqsSinkRequestLatency"; + public static final String SQS_SINK_REQUEST_SIZE = "sqsSinkRequestSize"; private final Counter sqsSinkRequestsSucceeded; private final Counter sqsSinkEventsSucceeded; private final Counter sqsSinkRequestsFailed; private final Counter sqsSinkEventsFailed; + private final DistributionSummary sqsSinkRequestLatency; + private final DistributionSummary sqsSinkRequestSize; public SqsSinkMetrics(final PluginMetrics pluginMetrics) { this.sqsSinkRequestsSucceeded = pluginMetrics.counter(SQS_SINK_REQUESTS_SUCCEEDED); this.sqsSinkEventsSucceeded = pluginMetrics.counter(SQS_SINK_EVENTS_SUCCEEDED); this.sqsSinkRequestsFailed = pluginMetrics.counter(SQS_SINK_REQUESTS_FAILED); this.sqsSinkEventsFailed = pluginMetrics.counter(SQS_SINK_EVENTS_FAILED); + this.sqsSinkRequestLatency = pluginMetrics.summary(SQS_SINK_REQUEST_LATENCY); + this.sqsSinkRequestSize = pluginMetrics.summary(SQS_SINK_REQUEST_SIZE); } public void incrementEventsSuccessCounter(int value) { @@ -40,4 +47,12 @@ public void incrementEventsFailedCounter(int value) { public void incrementRequestsFailedCounter(int value) { sqsSinkRequestsFailed.increment(value); } + + public void recordRequestLatency(double value) { + sqsSinkRequestLatency.record(value); + } + + public void recordRequestSize(double value) { + sqsSinkRequestSize.record(value); + } } diff --git a/data-prepper-plugins/sqs-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSinkService.java b/data-prepper-plugins/sqs-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSinkService.java index c3d027a585..79f5e21805 100644 --- a/data-prepper-plugins/sqs-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSinkService.java +++ b/data-prepper-plugins/sqs-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSinkService.java @@ -111,13 +111,12 @@ public boolean exceedsMaxEventSizeThreshold(final long estimatedSize) { @Override public void pushDLQList() { + // If DLQ push handler is null, dlqObjects list + // would be empty if (dlqObjects.size() == 0) { return; } - boolean result = false; - if (dlqPushHandler != null) { - result = dlqPushHandler.perform(dlqObjects); - } + boolean result = dlqPushHandler.perform(dlqObjects); for (final DlqObject dlqObject : dlqObjects) { dlqObject.releaseEventHandles(result); } @@ -150,22 +149,11 @@ public boolean willExceedMaxBatchSize(final Event event, final long estimatedSiz return false; boolean result = batch.willExceedLimits(estimatedSize); if (result) { - setFlushReady(qUrl, batch); + batch.setFlushReady(); } return result; } - private boolean doFlushBatch(SqsSinkBatch batch) { - boolean flushSuccess = batch.flushOnce( - (batchEntry, exceptionMessage ) -> { - addBatchEntryToDLQ(batchEntry, exceptionMessage); - }); - - // Sending to DLQ is also considered success (because no - // retry needed) - return flushSuccess; - } - @Override public Object doFlushOnce(Object previousFailedBatches) { List failedBatches = new ArrayList<>(); @@ -173,7 +161,7 @@ public Object doFlushOnce(Object previousFailedBatches) { if (previousFailedBatches != null) { List pFailedBatches = (List) previousFailedBatches; for (SqsSinkBatch failedBatch: pFailedBatches) { - if (!doFlushBatch(failedBatch)) { + if (!failedBatch.flushOnce()) { failedBatches.add(failedBatch); } else { successQueueUrls.add(failedBatch.getQueueUrl()); @@ -185,7 +173,7 @@ public Object doFlushOnce(Object previousFailedBatches) { Map.Entry qUrlEntry = iterator.next(); SqsSinkBatch batch = qUrlEntry.getValue(); if (batch.isReady()) { - if (!doFlushBatch(batch)) { + if (!batch.flushOnce()) { failedBatches.add(batch); } else { successQueueUrls.add(batch.getQueueUrl()); @@ -205,12 +193,7 @@ private String getQueueUrl(final Event event, boolean logError) { String qUrl = queueUrl; if (isDynamicQueueUrl) { try { - Object obj = event.formatString(queueUrl, expressionEvaluator); - if (obj instanceof String) { - qUrl = (String) obj; - } else { - throw new RuntimeException("Evaluated queue url is not a string"); - } + qUrl = event.formatString(queueUrl, expressionEvaluator); } catch (Exception e) { qUrl = null; if (logError) { @@ -226,12 +209,7 @@ private String getGroupId(final Event event) { String gId = groupId; if (isDynamicGroupId) { try { - Object obj = event.formatString(groupId, expressionEvaluator); - if (obj instanceof String) { - gId = (String) obj; - } else { - throw new RuntimeException("Evaluated group id is not a string"); - } + gId = event.formatString(groupId, expressionEvaluator); } catch (Exception e) { LOG.error(NOISY, "Invalid groupId expression {}, using random groupId ", e.getMessage()); } @@ -243,12 +221,7 @@ private String getDeDupId(final Event event) { String ddId = deDupId; if (isDynamicDeDupId) { try { - Object obj = event.formatString(deDupId, expressionEvaluator); - if (obj instanceof String) { - ddId = (String) obj; - } else { - throw new RuntimeException("Evaluated deduplicate id is not a string"); - } + ddId = event.formatString(deDupId, expressionEvaluator); } catch (Exception e) { LOG.error(NOISY, "Invalid deDupId expression {}, using random deDupId ", e.getMessage()); } @@ -275,7 +248,9 @@ public boolean addToBuffer(final Event event, final long estimatedSize) throws E SqsSinkBatch batch = batchUrlMap.get(qUrl); if (batch == null) { final OutputCodecContext codecContext = OutputCodecContext.fromSinkContext(sinkContext); - batch = new SqsSinkBatch(inMemoryBufferFactory, sqsClient, sinkMetrics, qUrl, codec, codecContext, thresholdConfig.getMaxMessageSizeBytes(), thresholdConfig.getMaxEventsPerMessage()); + batch = new SqsSinkBatch(inMemoryBufferFactory, sqsClient, sinkMetrics, qUrl, codec, codecContext, thresholdConfig, (batchEntry, exceptionMessage ) -> { + addBatchEntryToDLQ(batchEntry, exceptionMessage); + }); batchUrlMap.put(qUrl, batch); } @@ -283,36 +258,23 @@ public boolean addToBuffer(final Event event, final long estimatedSize) throws E String ddId = getDeDupId(event); boolean isFull = batch.addEntry(event, gId, ddId, estimatedSize); if (isFull) { - setFlushReady(qUrl, batch); + batch.setFlushReady(); } return isFull; } - private boolean setFlushReady(final String queueUrl, final SqsSinkBatch batch) { - try { - batch.setFlushReady(); - return true; - } catch (Exception e) { - for (Map.Entry entry: batch.getEntries().entrySet()) { - addBatchEntryToDLQ(entry.getValue(), "Failed to setFlushReady for the batch"); - } - batchUrlMap.remove(queueUrl); - return false; - } - } @Override public boolean exceedsFlushTimeInterval() { long now = Instant.now().getEpochSecond(); boolean result = false; - Iterator> iterator = batchUrlMap.entrySet().iterator(); - while (iterator.hasNext()) { - Map.Entry qUrlEntry = iterator.next(); + for (Map.Entry qUrlEntry: batchUrlMap.entrySet()) { String qUrl = qUrlEntry.getKey(); SqsSinkBatch batch = qUrlEntry.getValue(); if (now - batch.getLastFlushedTime() > thresholdConfig.getFlushInterval()) { - result = result || setFlushReady(qUrl, batch); + batch.setFlushReady(); + result = true; } } return result; @@ -334,6 +296,11 @@ private void addMessageToDLQ(final String message, final List event } } + @Override + public void recordLatency(double latencyMillis) { + sinkMetrics.recordRequestLatency((double)latencyMillis); + } + @Override public void addEventToDLQList(final Event event, Throwable ex) { List eventHandles = new ArrayList<>(); diff --git a/data-prepper-plugins/sqs-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSinkBatchEntryTest.java b/data-prepper-plugins/sqs-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSinkBatchEntryTest.java index 1bd083cec1..90f0441f5a 100644 --- a/data-prepper-plugins/sqs-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSinkBatchEntryTest.java +++ b/data-prepper-plugins/sqs-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSinkBatchEntryTest.java @@ -16,6 +16,7 @@ import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.jupiter.api.Assertions.assertThrows; import static org.hamcrest.CoreMatchers.equalTo; import static org.junit.jupiter.api.Assertions.assertTrue; @@ -123,6 +124,18 @@ void TestAddingMultipleEvents(int numRecords) throws Exception { assertThat(sqsSinkBatchEntry.getSize(), equalTo(expectedSize)); assertThat(sqsSinkBatchEntry.getEventHandles().size(), equalTo(numRecords)); } + + @Test + public void TestAddingToCompletedEntry() throws Exception { + SqsSinkBatchEntry sqsSinkBatchEntry = createObjectUnderTest(); + List> records = getRecordList(1); + Event event = records.get(0).getData(); + sqsSinkBatchEntry.addEvent(event); + sqsSinkBatchEntry.complete(); + List> newRecords = getRecordList(1); + Event newEvent = records.get(0).getData(); + assertThrows(RuntimeException.class, () -> sqsSinkBatchEntry.addEvent(newEvent)); + } private List> getRecordList(int numberOfRecords) { final List> recordList = new ArrayList<>(); diff --git a/data-prepper-plugins/sqs-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSinkBatchTest.java b/data-prepper-plugins/sqs-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSinkBatchTest.java index 19db2d5fab..e20cfa9d95 100644 --- a/data-prepper-plugins/sqs-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSinkBatchTest.java +++ b/data-prepper-plugins/sqs-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSinkBatchTest.java @@ -13,9 +13,12 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; import org.mockito.Mock; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.lenient; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.times; @@ -23,6 +26,7 @@ import static org.hamcrest.CoreMatchers.equalTo; import static org.mockito.ArgumentMatchers.any; import software.amazon.awssdk.services.sqs.model.SendMessageBatchRequest; +import software.amazon.awssdk.services.sqs.model.SendMessageBatchRequestEntry; import software.amazon.awssdk.services.sqs.model.SendMessageBatchResponse; import software.amazon.awssdk.services.sqs.SqsClient; @@ -39,10 +43,12 @@ import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.hamcrest.Matchers.greaterThan; +import org.apache.commons.lang3.RandomStringUtils; import java.util.ArrayList; import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.UUID; import java.util.concurrent.atomic.AtomicInteger; @@ -55,6 +61,8 @@ public class SqsSinkBatchTest { private SqsClient sqsClient; @Mock private SqsException sqsException; + @Mock + private SqsThresholdConfig sqsThresholdConfig; private AtomicInteger eventsSuccessCount; private AtomicInteger requestsSuccessCount; @@ -77,11 +85,15 @@ public class SqsSinkBatchTest { private String queueUrl; private SqsSinkBatch createObjectUnderTest() { - return new SqsSinkBatch(bufferFactory, sqsClient, sinkMetrics, queueUrl, outputCodec, outputCodecContext, maxMessageSize, maxEvents); + return new SqsSinkBatch(bufferFactory, sqsClient, sinkMetrics, queueUrl, outputCodec, outputCodecContext, sqsThresholdConfig, (a, b) -> {}); } @BeforeEach void setup() { + sqsThresholdConfig = mock(SqsThresholdConfig.class); + maxMessageSize = 256 * 1024; + when(sqsThresholdConfig.getMaxMessageSizeBytes()).thenReturn(maxMessageSize); + when(sqsThresholdConfig.getMaxEventsPerMessage()).thenReturn(1); eventsSuccessCount = new AtomicInteger(0); requestsSuccessCount = new AtomicInteger(0); eventsFailedCount = new AtomicInteger(0); @@ -117,7 +129,6 @@ void setup() { flushResponse = mock(SendMessageBatchResponse.class); outputCodec = new JsonOutputCodec(new JsonOutputCodecConfig()); outputCodecContext = new OutputCodecContext(); - maxMessageSize = 256 * 1024; } @Test @@ -129,9 +140,10 @@ void TestBasic() { assertThat(batch.getEntries().size(), equalTo(0)); } - @Test - void TestOneBatch_WithOneEventPerMessage_WithSuccessfulSendMessage() throws Exception { - maxEvents = 1; + @ParameterizedTest + @ValueSource(strings= {"", ".fifo"}) + void TestOneBatch_WithOneEventPerMessage_WithSuccessfulSendMessage(String queueUrlSuffix) throws Exception { + queueUrl += queueUrlSuffix; batch = createObjectUnderTest(); groupId = UUID.randomUUID().toString(); String dedupId = UUID.randomUUID().toString(); @@ -163,16 +175,150 @@ void TestOneBatch_WithOneEventPerMessage_WithSuccessfulSendMessage() throws Exce assertTrue(batch.willExceedLimits(1L)); assertThat(batch.getCurrentBatchSize(), greaterThan(minSize)); assertThat(batch.getEventCount(), equalTo(SqsSinkBatch.MAX_MESSAGES_PER_BATCH)); - boolean flushResult = batch.flushOnce(null); + boolean flushResult = batch.flushOnce(); + assertTrue(flushResult); + assertThat(eventsSuccessCount.get(), equalTo(numRecords)); + assertThat(requestsSuccessCount.get(), equalTo(1)); + verify(eventHandle, times(numRecords)).release(true); + } + + @Test + void TestOneBatch_WithMultipleEventsPerMessage_WithSuccessfulSendMessage() throws Exception { + when(sqsThresholdConfig.getMaxEventsPerMessage()).thenReturn(2); + batch = createObjectUnderTest(); + groupId = UUID.randomUUID().toString(); + String dedupId = UUID.randomUUID().toString(); + + when(flushResponse.hasFailed()).thenReturn(false); + when(sqsClient.sendMessageBatch(any(SendMessageBatchRequest.class))).thenReturn(flushResponse); + final int numRecords = 2*SqsSinkBatch.MAX_MESSAGES_PER_BATCH; + List> records = getRecordList(numRecords); + long minSize = 0; + for (int i = 0; i < numRecords; i++) { + + Event event = records.get(i%numRecords).getData(); + minSize += event.toJsonString().length(); + long eSize = outputCodec.getEstimatedSize(event, new OutputCodecContext()); + boolean result = batch.addEntry(event, groupId, dedupId, eSize); + if (i < numRecords-1) { + assertFalse(result); + } else { + assertTrue(result); + } + } + assertThat(batch.getEntries().size(), equalTo(numRecords/2)); + batch.setFlushReady(); + assertTrue(batch.willExceedLimits(1L)); + assertThat(batch.getCurrentBatchSize(), greaterThan(minSize)); + assertThat(batch.getEventCount(), equalTo(2*SqsSinkBatch.MAX_MESSAGES_PER_BATCH)); + boolean flushResult = batch.flushOnce(); assertTrue(flushResult); assertThat(eventsSuccessCount.get(), equalTo(numRecords)); assertThat(requestsSuccessCount.get(), equalTo(1)); verify(eventHandle, times(numRecords)).release(true); } + @Test + void TestOneBatch_WithMultipleEventsPerMessageWithMessageSize_WithSuccessfulSendMessage() throws Exception { + when(sqsThresholdConfig.getMaxMessageSizeBytes()).thenReturn(95L); + when(sqsThresholdConfig.getMaxEventsPerMessage()).thenReturn(3); + batch = createObjectUnderTest(); + assertTrue(batch.flushOnce()); + batch.setFlushReady(); + assertTrue(batch.flushOnce()); + groupId = UUID.randomUUID().toString(); + String dedupId = UUID.randomUUID().toString(); + + when(flushResponse.hasFailed()).thenReturn(false); + when(sqsClient.sendMessageBatch(any(SendMessageBatchRequest.class))).thenReturn(flushResponse); + final int numRecords = 2*SqsSinkBatch.MAX_MESSAGES_PER_BATCH; + List> records = getRecordList(numRecords); + long minSize = 0; + for (int i = 0; i < numRecords; i++) { + + Event event = records.get(i%numRecords).getData(); + minSize += event.toJsonString().length(); + long eSize = outputCodec.getEstimatedSize(event, new OutputCodecContext()); + assertFalse(batch.willExceedLimits(45L)); + boolean result = batch.addEntry(event, groupId, dedupId, eSize); + assertFalse(result); + } + assertThat(batch.getEntries().size(), equalTo(numRecords/2)); + batch.setFlushReady(); + assertTrue(batch.willExceedLimits(45L)); + assertThat(batch.getCurrentBatchSize(), greaterThan(minSize)); + assertThat(batch.getEventCount(), equalTo(2*SqsSinkBatch.MAX_MESSAGES_PER_BATCH)); + boolean flushResult = batch.flushOnce(); + assertTrue(flushResult); + assertThat(eventsSuccessCount.get(), equalTo(numRecords)); + assertThat(requestsSuccessCount.get(), equalTo(1)); + verify(eventHandle, times(numRecords)).release(true); + } + + @Test + void TestBatchFull() throws Exception { + batch = createObjectUnderTest(); + final int numRecords = SqsSinkBatch.MAX_MESSAGES_PER_BATCH; + List> records = getRecordList(numRecords); + for (int i = 0; i < records.size(); i++) { + Event event = records.get(i).getData(); + long eSize = outputCodec.getEstimatedSize(event, new OutputCodecContext()); + boolean isFull = batch.addEntry(event, null, null, eSize); + if (i == records.size()-1) { + assertTrue(isFull); + } else { + assertFalse(isFull); + } + } + } + + private Record getLargeRecord(int size) { + final Event event = JacksonLog.builder() + .withData(Map.of("key", RandomStringUtils.randomAlphabetic(size))) + .withEventHandle(eventHandle) + .build(); + return new Record<>(event); + } + + @Test + void TestOneBatch_WithOneEventPerMessage_WithFlushFailure_NotSenderFault() throws Exception { + batch = createObjectUnderTest(); + groupId = UUID.randomUUID().toString(); + String dedupId = UUID.randomUUID().toString(); + + SendMessageBatchResponse flushResponse = mock(SendMessageBatchResponse.class); + when(flushResponse.hasFailed()).thenReturn(true); + BatchResultErrorEntry errorEntry = mock(BatchResultErrorEntry.class); + when(errorEntry.id()).thenReturn(UUID.randomUUID().toString()); + when(errorEntry.message()).thenReturn(UUID.randomUUID().toString()); + when(errorEntry.senderFault()).thenReturn(false); + List errorEntries = new ArrayList<>(); + doAnswer((a) -> { + SendMessageBatchRequest batchRequest = (SendMessageBatchRequest)a.getArgument(0); + List entries = batchRequest.entries(); + for (final SendMessageBatchRequestEntry entry: entries) { + errorEntries.add(BatchResultErrorEntry.builder().id(entry.id()).senderFault(false).message(entry.messageBody()).build()); + } + return flushResponse; + }).when(sqsClient).sendMessageBatch(any(SendMessageBatchRequest.class)); + when(flushResponse.failed()).thenReturn(errorEntries); + final int numRecords = SqsSinkBatch.MAX_MESSAGES_PER_BATCH; + List> records = getRecordList(numRecords); + for (int i = 0; i < numRecords; i++) { + + Event event = records.get(i).getData(); + long eSize = outputCodec.getEstimatedSize(event, new OutputCodecContext()); + batch.addEntry(event, groupId, dedupId, eSize); + } + assertThat(batch.getEntries().size(), equalTo(numRecords)); + batch.setFlushReady(); + assertThat(batch.getEventCount(), equalTo(SqsSinkBatch.MAX_MESSAGES_PER_BATCH)); + boolean flushResult = batch.flushOnce(); + assertFalse(flushResult); + } + @Test void TestOneBatch_WithOneEventPerMessage_WithFlushFailure() throws Exception { - maxEvents = 1; batch = createObjectUnderTest(); groupId = UUID.randomUUID().toString(); String dedupId = UUID.randomUUID().toString(); @@ -200,7 +346,7 @@ void TestOneBatch_WithOneEventPerMessage_WithFlushFailure() throws Exception { assertTrue(batch.willExceedLimits(1L)); assertThat(batch.getCurrentBatchSize(), greaterThan(minSize)); assertThat(batch.getEventCount(), equalTo(SqsSinkBatch.MAX_MESSAGES_PER_BATCH)); - boolean flushResult = batch.flushOnce((e, m) -> {}); + boolean flushResult = batch.flushOnce(); // all entries sent to DLQ assertTrue(flushResult); } @@ -208,7 +354,6 @@ void TestOneBatch_WithOneEventPerMessage_WithFlushFailure() throws Exception { @Test void TestOneBatch_WithOneEventPerMessage_WithFlushExceptionFailure() throws Exception { - maxEvents = 1; batch = createObjectUnderTest(); groupId = UUID.randomUUID().toString(); String dedupId = UUID.randomUUID().toString(); @@ -229,13 +374,12 @@ void TestOneBatch_WithOneEventPerMessage_WithFlushExceptionFailure() throws Exce assertTrue(batch.willExceedLimits(1L)); assertThat(batch.getCurrentBatchSize(), greaterThan(minSize)); assertThat(batch.getEventCount(), equalTo(SqsSinkBatch.MAX_MESSAGES_PER_BATCH)); - boolean flushResult = batch.flushOnce((e, m) -> {}); + boolean flushResult = batch.flushOnce(); assertFalse(flushResult); } @Test void TestOneBatch_WithOneEventPerMessage_WithFlushException() throws Exception { - maxEvents = 1; batch = createObjectUnderTest(); groupId = UUID.randomUUID().toString(); String dedupId = UUID.randomUUID().toString(); @@ -256,7 +400,7 @@ void TestOneBatch_WithOneEventPerMessage_WithFlushException() throws Exception { assertTrue(batch.willExceedLimits(1L)); assertThat(batch.getCurrentBatchSize(), greaterThan(minSize)); assertThat(batch.getEventCount(), equalTo(SqsSinkBatch.MAX_MESSAGES_PER_BATCH)); - boolean flushResult = batch.flushOnce((e, m) -> {}); + boolean flushResult = batch.flushOnce(); assertTrue(flushResult); assertThat(eventsFailedCount.get(), equalTo(numRecords)); assertThat(requestsFailedCount.get(), equalTo(1)); diff --git a/data-prepper-plugins/sqs-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSinkDlqDataTest.java b/data-prepper-plugins/sqs-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSinkDlqDataTest.java index 35503b01ba..6b73b2d8ac 100644 --- a/data-prepper-plugins/sqs-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSinkDlqDataTest.java +++ b/data-prepper-plugins/sqs-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSinkDlqDataTest.java @@ -9,6 +9,7 @@ import org.junit.jupiter.api.Test; import static org.hamcrest.MatcherAssert.assertThat; import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.assertFalse; import static org.hamcrest.CoreMatchers.notNullValue; import static org.hamcrest.Matchers.equalTo; @@ -43,6 +44,12 @@ void TestEquals() { sqsSinkDlqData = createObjectUnderTest(message, data); SqsSinkDlqData sqsSinkDlqData2 = createObjectUnderTest(message, data); assertTrue(sqsSinkDlqData.equals(sqsSinkDlqData2)); + assertTrue(sqsSinkDlqData.equals(sqsSinkDlqData)); + SqsSinkDlqData sqsSinkDlqData3 = createObjectUnderTest(message, data+data); + assertFalse(sqsSinkDlqData.equals(sqsSinkDlqData3)); + Integer testInteger = 5; + assertFalse(sqsSinkDlqData.equals(testInteger)); + assertFalse(sqsSinkDlqData.equals(null)); } } diff --git a/data-prepper-plugins/sqs-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSinkServiceTest.java b/data-prepper-plugins/sqs-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSinkServiceTest.java index f42c4a2d05..b051ebc50e 100644 --- a/data-prepper-plugins/sqs-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSinkServiceTest.java +++ b/data-prepper-plugins/sqs-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSinkServiceTest.java @@ -13,7 +13,9 @@ import org.opensearch.dataprepper.plugins.dlq.DlqPushHandler; import org.opensearch.dataprepper.expression.ExpressionEvaluator; +import java.io.IOException; import io.micrometer.core.instrument.Counter; +import io.micrometer.core.instrument.DistributionSummary; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; @@ -22,6 +24,7 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; import static org.mockito.Mockito.lenient; +import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.times; import static org.hamcrest.MatcherAssert.assertThat; @@ -39,6 +42,7 @@ import org.apache.commons.lang3.RandomStringUtils; import software.amazon.awssdk.services.sqs.SqsClient; import software.amazon.awssdk.services.sqs.model.RequestThrottledException; +import software.amazon.awssdk.services.sqs.model.UnsupportedOperationException; import org.opensearch.dataprepper.model.event.Event; import org.opensearch.dataprepper.model.event.EventHandle; import org.opensearch.dataprepper.model.log.JacksonLog; @@ -90,6 +94,9 @@ public class SqsSinkServiceTest { private Counter requestsFailedCounter; @Mock private Counter dlqSuccessCounter; + @Mock + private DistributionSummary summary; + private AtomicInteger eventsSuccessCount; private AtomicInteger requestsSuccessCount; private AtomicInteger eventsFailedCount; @@ -146,6 +153,8 @@ void setup() { requestsSuccessCounter = mock(Counter.class); requestsFailedCounter = mock(Counter.class); dlqSuccessCounter = mock(Counter.class); + summary = mock(DistributionSummary.class); + doNothing().when(summary).record(any(Double.class)); lenient().doAnswer((a)-> { int v = (int)(double)(a.getArgument(0)); eventsSuccessCount.addAndGet(v); @@ -194,6 +203,11 @@ void setup() { } return null; }).when(pluginMetrics).counter(anyString()); + + lenient().doAnswer(a -> { + return summary; + }).when(pluginMetrics).summary(anyString()); + } @Test @@ -204,6 +218,22 @@ void TestBasic() { assertFalse(sqsSinkService.exceedsMaxEventSizeThreshold(256*1024)); } + @Test + void TestWithInvalidQueueUrlMissingFieldInEvent() { + int numRecords = 10; + when (sqsSinkConfig.getQueueUrl()).thenReturn(queueUrl+"${/abcd}"); + when (thresholdConfig.getFlushInterval()).thenReturn(2L); + SqsSinkService sqsSinkService = createObjectUnderTest(); + List> records = getRecordList(numRecords); + sqsSinkService.execute(records); + await().atMost(Duration.ofSeconds(30)) + .untilAsserted(() -> { + sqsSinkService.execute(Collections.emptyList()); + assertThat(sqsSinkService.getBatchUrlMap().size(), equalTo(0)); + assertThat(eventsSuccessCount.get(), equalTo(0)); + }); + } + @ParameterizedTest @ValueSource(ints = {9, 29, 49, 69}) void TestExecuteWithOneBatch_FlushTimeout(int numRecords) throws Exception { @@ -237,6 +267,42 @@ void TestExecuteOneBatch_WithLargeRecords(int numRecords) throws Exception { }); } + @ParameterizedTest + @ValueSource(ints = {18, 36, 54, 72}) + void TestExecuteMultipleBatches(int numRecords) throws Exception { + OutputCodec mOutputCodec = mock(OutputCodec.class); + OutputCodec.Writer mWriter = mock(OutputCodec.Writer.class); + lenient().doAnswer((a)-> { + throw new IOException("IO Exception"); + }).when(mWriter).complete(); + when(mOutputCodec.createWriter(any(), eq(null), any(OutputCodecContext.class))).thenReturn(mWriter); + outputCodec = mOutputCodec; + when (thresholdConfig.getFlushInterval()).thenReturn(2L); + when (sqsSinkConfig.getQueueUrl()).thenReturn(queueUrl+"${/id}"); + SqsSinkService sqsSinkService = createObjectUnderTest(); + List> records = getRecordList(numRecords); + sqsSinkService.execute(records); + await().atMost(Duration.ofSeconds(30)) + .untilAsserted(() -> { + sqsSinkService.execute(Collections.emptyList()); + assertThat(sqsSinkService.getBatchUrlMap().size(), equalTo(0)); + assertThat(eventsSuccessCount.get(), equalTo(0)); + assertThat(requestsSuccessCount.get(), equalTo(0)); + verify(eventHandle, times(numRecords)).release(true); + }); + } + + @Test + void TestLargeRecordToNoDLQ() { + dlqPushHandler = null; + SqsSinkService sqsSinkService = createObjectUnderTest(); + Record record = getLargeRecord(300*1024); + sqsSinkService.execute(List.of(record)); + assertThat(sqsSinkService.getBatchUrlMap().size(), equalTo(0)); + assertThat(eventsSuccessCount.get(), equalTo(0)); + assertThat(requestsSuccessCount.get(), equalTo(0)); + } + @Test void TestLargeRecordToDLQ() { SqsSinkService sqsSinkService = createObjectUnderTest(); @@ -271,7 +337,22 @@ void TestExecuteWithOneBatch_SuccessfulFlush(int numRecords) throws Exception { assertThat(requestsSuccessCount.get(), equalTo(numRecords/10)); verify(eventHandle, times(numRecords)).release(true); } - + + @Test + void TestSendingToDLQAfterNonRetryableException() { + final int numRecords = 10; + UnsupportedOperationException unsupportedOperationException = mock(UnsupportedOperationException.class); + when(unsupportedOperationException.getMessage()).thenReturn("Unsupported operation"); + when(sqsClient.sendMessageBatch(any(SendMessageBatchRequest.class))).thenThrow(unsupportedOperationException); + SqsSinkService sqsSinkService = createObjectUnderTest(); + List> records = getRecordList(numRecords); + sqsSinkService.execute(records); + assertThat(sqsSinkService.getBatchUrlMap().size(), equalTo(0)); + assertThat(eventsSuccessCount.get(), equalTo(0)); + assertThat(requestsSuccessCount.get(), equalTo(0)); + verify(eventHandle, times(numRecords)).release(true); + } + @Test void TestSendingToDLQAfterMultipleRetries() { final int numRecords = 10; @@ -300,6 +381,23 @@ void TestExecuteWithOneBatch_MultipleRetries(int numRecords) throws Exception { verify(eventHandle, times(numRecords)).release(true); } + @Test + void TestFiFoQWithEventsWithInvalidExpression() throws Exception { + int numRecords = 10; + when(expressionEvaluator.isValidFormatExpression(anyString())).thenReturn(true); + when (sqsSinkConfig.getQueueUrl()).thenReturn(queueUrl+".fifo"); + when (sqsSinkConfig.getDeDuplicationId()).thenReturn(UUID.randomUUID().toString()+"${/ident}"); + when (sqsSinkConfig.getGroupId()).thenReturn(UUID.randomUUID().toString()+"${/ident}"); + SqsSinkService sqsSinkService = createObjectUnderTest(); + List> records = getRecordList(numRecords); + for (int i = 0; i < numRecords-1; i++) { + Event event = records.get(i).getData(); + long eSize = outputCodec.getEstimatedSize(event, new OutputCodecContext()); + boolean isFull = sqsSinkService.addToBuffer(event, eSize); + assertFalse(isFull); + } + } + @Test void TestFiFoQWithInvalidDeDupIdExpression() { when(expressionEvaluator.isValidFormatExpression(anyString())).thenReturn(false); @@ -308,6 +406,13 @@ void TestFiFoQWithInvalidDeDupIdExpression() { assertThrows(IllegalArgumentException.class, ()-> createObjectUnderTest()); } + @Test + void TestFiFoQWithInvalidQueueUrlExpression() { + when(expressionEvaluator.isValidFormatExpression(anyString())).thenReturn(false); + when (sqsSinkConfig.getQueueUrl()).thenReturn(queueUrl+"${id - }"+".fifo"); + assertThrows(IllegalArgumentException.class, ()-> createObjectUnderTest()); + } + @Test void TestFiFoQWithInvalidGroupIdExpression() { when(expressionEvaluator.isValidFormatExpression(anyString())).thenReturn(false); @@ -412,7 +517,7 @@ void TestWithOneBatch_RetryFlushes() throws Exception { assertThat(flushResult, not(equalTo(null))); assertThat(sqsSinkService.getBatchUrlMap().size(), equalTo(1)); when(sqsClient.sendMessageBatch(any(SendMessageBatchRequest.class))).thenReturn(flushResponse); - flushResult = sqsSinkService.doFlushOnce(null); + flushResult = sqsSinkService.doFlushOnce(null); assertThat(flushResult, equalTo(null)); } } @@ -454,11 +559,12 @@ private static List generateRecords(int numberOfRecords) { for (int rows = 0; rows < numberOfRecords; rows++) { - HashMap eventData = new HashMap<>(); + HashMap eventData = new HashMap<>(); eventData.put("name", "Person" + rows); eventData.put("age", Integer.toString(rows)); eventData.put("id", Integer.toString(rows%2)); + eventData.put("idx", rows); recordList.add(eventData); } diff --git a/data-prepper-plugins/sqs-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSinkTest.java b/data-prepper-plugins/sqs-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSinkTest.java index b08be2ec0f..5b5c716908 100644 --- a/data-prepper-plugins/sqs-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSinkTest.java +++ b/data-prepper-plugins/sqs-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSinkTest.java @@ -28,11 +28,17 @@ import static org.mockito.Mockito.times; import static org.mockito.Mockito.mockStatic; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; +import org.opensearch.dataprepper.plugins.dlq.DlqProvider; import org.opensearch.dataprepper.plugins.source.sqs.common.SqsClientFactory; import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.services.sqs.SqsClient; +import software.amazon.awssdk.services.sts.StsClient; +import software.amazon.awssdk.services.sts.StsClientBuilder; +import software.amazon.awssdk.services.sts.model.AssumeRoleRequest; +import software.amazon.awssdk.services.sts.model.GetCallerIdentityResponse; import org.opensearch.dataprepper.aws.api.AwsConfig; import org.opensearch.dataprepper.model.event.Event; import org.opensearch.dataprepper.model.event.JacksonEvent; @@ -46,6 +52,8 @@ import java.util.ArrayList; import java.util.Collection; import java.util.Map; +import java.util.HashMap; +import java.util.Optional; import java.util.UUID; public class SqsSinkTest { @@ -87,8 +95,10 @@ void setup() { when(sinkContext.getTagsTargetKey()).thenReturn(null); sqsClient = mock(SqsClient.class); expressionEvaluator = mock(ExpressionEvaluator.class); - awsCredentialsSupplier = mock(AwsCredentialsSupplier.class); awsCredentialsProvider = mock(AwsCredentialsProvider.class); + awsCredentialsSupplier = mock(AwsCredentialsSupplier.class); + when(awsCredentialsSupplier.getProvider(any())).thenReturn(awsCredentialsProvider); + when(awsCredentialsSupplier.getDefaultRegion()).thenReturn(Optional.of(Region.of("us-west-2"))); when(sqsSinkConfig.getDlq()).thenReturn(null); codecConfig = mock(PluginModel.class); when(codecConfig.getPluginName()).thenReturn(TEST_CODEC_PLUGIN_NAME); @@ -122,23 +132,81 @@ void TestBasic() { } @Test - void TestWithInvalidCodec() { - when(codecConfig.getPluginName()).thenReturn("badCodec"); - awsCredentialsSupplier = null; + void TestBasicWithNullAwsConfig() { when(sqsSinkConfig.getAwsConfig()).thenReturn(null); try(MockedStatic mockedStatic = mockStatic(SqsClientFactory.class)) { mockedStatic.when(() -> SqsClientFactory.createSqsClient(any(Region.class), any(AwsCredentialsProvider.class))) .thenReturn(sqsClient); - assertThrows(RuntimeException.class, ()-> createObjectUnderTest()); + SqsSink sqsSink = createObjectUnderTest(); + sqsSink.doInitialize(); + assertTrue(sqsSink.isReady()); } } @Test - void TestWithNullAwsConfig() { - awsCredentialsSupplier = null; - when(sqsSinkConfig.getAwsConfig()).thenReturn(null); + void TestBasicNdJsonCodec() { + when(codecConfig.getPluginName()).thenReturn("ndjson"); + try(MockedStatic mockedStatic = mockStatic(SqsClientFactory.class)) { + mockedStatic.when(() -> SqsClientFactory.createSqsClient(any(Region.class), + any(AwsCredentialsProvider.class))) + .thenReturn(sqsClient); + + SqsSink sqsSink = createObjectUnderTest(); + sqsSink.doInitialize(); + assertTrue(sqsSink.isReady()); + } + } + + @Test + void TestWithDLQConfig() { + awsConfig = mock(AwsConfig.class); + when(awsConfig.getAwsRegion()).thenReturn(Region.of("us-west-2")); + when(sqsSinkConfig.getAwsConfig()).thenReturn(awsConfig); + PluginModel dlqConfig = mock(PluginModel.class); + when(dlqConfig.getPluginSettings()).thenReturn(new HashMap()); + when(dlqConfig.getPluginName()).thenReturn("testDlqPlugin"); + DlqProvider dlqProvider = mock(DlqProvider.class); + + StsClient stsClient = mock(StsClient.class); + StsClientBuilder stsClientBuilder = mock(StsClientBuilder.class); + when(stsClientBuilder.build()).thenReturn(stsClient); + when(stsClientBuilder.region(any())).thenReturn(stsClientBuilder); + when(stsClientBuilder.credentialsProvider(any())).thenReturn(stsClientBuilder); + + GetCallerIdentityResponse identityResponse = mock(GetCallerIdentityResponse.class); + when(identityResponse.arn()).thenReturn("arn"); + when(stsClient.getCallerIdentity()).thenReturn(identityResponse); + + when(pluginFactory.loadPlugin(eq(DlqProvider.class), any())).thenReturn(dlqProvider); + + when(sqsSinkConfig.getDlq()).thenReturn(dlqConfig); + try (final MockedStatic stsClientMockedStatic = mockStatic(StsClient.class); + final MockedStatic assumeRoleRequestMockedStatic = mockStatic(AssumeRoleRequest.class)) { + stsClientMockedStatic.when(StsClient::builder).thenReturn(stsClientBuilder); + final AssumeRoleRequest.Builder assumeRoleRequestBuilder = mock(AssumeRoleRequest.Builder.class); + when(assumeRoleRequestBuilder.roleSessionName(anyString())) + .thenReturn(assumeRoleRequestBuilder); + when(assumeRoleRequestBuilder.roleArn(anyString())) + .thenReturn(assumeRoleRequestBuilder); + assumeRoleRequestMockedStatic.when(AssumeRoleRequest::builder).thenReturn(assumeRoleRequestBuilder); + + try(MockedStatic mockedStatic = mockStatic(SqsClientFactory.class)) { + mockedStatic.when(() -> SqsClientFactory.createSqsClient(any(Region.class), + any(AwsCredentialsProvider.class))) + .thenReturn(sqsClient); + + SqsSink sqsSink = createObjectUnderTest(); + sqsSink.doInitialize(); + assertTrue(sqsSink.isReady()); + } + } + } + + @Test + void TestWithInvalidCodec() { + when(codecConfig.getPluginName()).thenReturn("badCodec"); try(MockedStatic mockedStatic = mockStatic(SqsClientFactory.class)) { mockedStatic.when(() -> SqsClientFactory.createSqsClient(any(Region.class), any(AwsCredentialsProvider.class))) @@ -162,6 +230,20 @@ void TestForDefaultCodec() { } } + @Test + void TestForNullCodec() { + when(sqsSinkConfig.getCodec()).thenReturn(null); + try(MockedStatic mockedStatic = mockStatic(SqsClientFactory.class)) { + mockedStatic.when(() -> SqsClientFactory.createSqsClient(any(Region.class), + any(AwsCredentialsProvider.class))) + .thenReturn(sqsClient); + + SqsSink sqsSink = createObjectUnderTest(); + sqsSink.doInitialize(); + assertTrue(sqsSink.isReady()); + } + } + @Test void TestSinkOutputWithEvents() { try(MockedStatic mockedStatic = mockStatic(SqsClientFactory.class)) {