diff --git a/data-prepper-api/src/main/java/org/opensearch/dataprepper/model/failures/DlqObject.java b/data-prepper-api/src/main/java/org/opensearch/dataprepper/model/failures/DlqObject.java index 7cc3200bd4..d401c44e1c 100644 --- a/data-prepper-api/src/main/java/org/opensearch/dataprepper/model/failures/DlqObject.java +++ b/data-prepper-api/src/main/java/org/opensearch/dataprepper/model/failures/DlqObject.java @@ -15,6 +15,8 @@ import java.time.ZoneId; import java.time.format.DateTimeFormatter; import java.util.Objects; +import java.util.List; +import java.util.ArrayList; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkNotNull; @@ -43,8 +45,11 @@ public class DlqObject { @JsonIgnore private final EventHandle eventHandle; + @JsonIgnore + private final List eventHandles; + private DlqObject(final String pluginId, final String pluginName, final String pipelineName, - final String timestamp, final Object failedData, final EventHandle eventHandle) { + final String timestamp, final Object failedData, final List eventHandles) { checkNotNull(pluginId, "pluginId cannot be null"); checkArgument(!pluginId.isEmpty(), "pluginId cannot be an empty string"); @@ -58,7 +63,8 @@ private DlqObject(final String pluginId, final String pluginName, final String p this.pluginName = pluginName; this.pipelineName = pipelineName; this.failedData = failedData; - this.eventHandle = eventHandle; + this.eventHandles = eventHandles; + this.eventHandle = null; this.timestamp = StringUtils.isEmpty(timestamp) ? FORMATTER.format(Instant.now()) : timestamp; } @@ -83,12 +89,18 @@ public String getTimestamp() { return timestamp; } - public EventHandle getEventHandle() { - return eventHandle; + public List getEventHandles() { + return eventHandles; } public void releaseEventHandle(boolean result) { - if (eventHandle != null) { + if (eventHandles != null && eventHandles.size() == 1) { + eventHandles.get(0).release(result); + } + } + + public void releaseEventHandles(boolean result) { + for (final EventHandle eventHandle: eventHandles) { eventHandle.release(result); } } @@ -102,7 +114,7 @@ public boolean equals(final Object o) { && Objects.equals(pluginId, that.pluginId) && Objects.equals(pluginName, that.pluginName) && Objects.equals(pipelineName, that.pipelineName) - && Objects.equals(eventHandle, that.eventHandle) + && Objects.equals(eventHandles, that.eventHandles) && Objects.equals(timestamp, that.getTimestamp()); } @@ -122,9 +134,9 @@ public String toString() { '}'; } - public static DlqObject createDlqObject(PluginSetting pluginSetting, EventHandle eventHandle, Object failedData) { + public static DlqObject createDlqObject(PluginSetting pluginSetting, List eventHandles, Object failedData) { return DlqObject.builder() - .withEventHandle(eventHandle) + .withEventHandles(eventHandles) .withFailedData(failedData) .withPluginName(pluginSetting.getName()) .withPipelineName(pluginSetting.getPipelineName()) @@ -142,7 +154,7 @@ public static class Builder { private String pluginName; private String pipelineName; private Object failedData; - private EventHandle eventHandle; + private List eventHandles; private String timestamp; @@ -171,8 +183,14 @@ public Builder withTimestamp(final String timestamp) { return this; } + public Builder withEventHandles(final List eventHandles) { + this.eventHandles = eventHandles; + return this; + } + public Builder withEventHandle(final EventHandle eventHandle) { - this.eventHandle = eventHandle; + this.eventHandles = new ArrayList<>(); + this.eventHandles.add(eventHandle); return this; } @@ -182,7 +200,7 @@ public Builder withTimestamp(final Instant instant) { } public DlqObject build() { - return new DlqObject(this.pluginId, this.pluginName, this.pipelineName, this.timestamp, this.failedData, this.eventHandle); + return new DlqObject(this.pluginId, this.pluginName, this.pipelineName, this.timestamp, this.failedData, this.eventHandles); } } diff --git a/data-prepper-api/src/test/java/org/opensearch/dataprepper/model/failures/DlqObjectTest.java b/data-prepper-api/src/test/java/org/opensearch/dataprepper/model/failures/DlqObjectTest.java index dbeabaf5e1..094d5ff492 100644 --- a/data-prepper-api/src/test/java/org/opensearch/dataprepper/model/failures/DlqObjectTest.java +++ b/data-prepper-api/src/test/java/org/opensearch/dataprepper/model/failures/DlqObjectTest.java @@ -21,6 +21,7 @@ import java.util.HashMap; import java.util.Map; +import java.util.List; import static java.util.UUID.randomUUID; import static org.hamcrest.CoreMatchers.allOf; import static org.hamcrest.CoreMatchers.containsString; @@ -70,6 +71,22 @@ public void test_build_with_timestamp() { assertThat(testObject, is(notNullValue())); } + @Test + public void test_build_with_timestamp_with_event_handles() { + + final DlqObject testObject = DlqObject.builder() + .withPluginId(pluginId) + .withPluginName(pluginName) + .withPipelineName(pipelineName) + .withFailedData(failedData) + .withEventHandles(List.of(eventHandle)) + .withTimestamp(randomUUID().toString()) + .build(); + + assertThat(testObject, is(notNullValue())); + } + + @Test public void test_build_without_timestamp() { @@ -133,9 +150,9 @@ public void test_createDlqObject() { when(pluginSetting.getPipelineName()).thenReturn(testPipelineName); eventHandle = mock(EventHandle.class); Map data = new HashMap<>(); - DlqObject dlqObject = DlqObject.createDlqObject(pluginSetting, eventHandle, data); + DlqObject dlqObject = DlqObject.createDlqObject(pluginSetting, List.of(eventHandle), data); assertThat(dlqObject, is(notNullValue())); - assertThat(dlqObject.getEventHandle(), is(eventHandle)); + assertThat(dlqObject.getEventHandles(), is(List.of(eventHandle))); assertThat(dlqObject.getFailedData(), is(data)); assertThat(dlqObject.getPluginName(), is(testName)); assertThat(dlqObject.getPipelineName(), is(testPipelineName)); @@ -191,13 +208,23 @@ public void test_get_failedData() { @Test public void test_get_release_eventHandle() { doAnswer(a -> { return null; }).when(eventHandle).release(any(Boolean.class)); - final Object actualEventHandle = testObject.getEventHandle(); - assertThat(actualEventHandle, is(notNullValue())); - assertThat(actualEventHandle, is(eventHandle)); + final List actualEventHandles = testObject.getEventHandles(); + assertThat(actualEventHandles, is(notNullValue())); + assertThat(actualEventHandles, is(List.of(eventHandle))); testObject.releaseEventHandle(true); verify(eventHandle).release(any(Boolean.class)); } + @Test + public void test_get_release_eventHandles() { + doAnswer(a -> { return null; }).when(eventHandle).release(any(Boolean.class)); + final List actualEventHandles = testObject.getEventHandles(); + assertThat(actualEventHandles, is(notNullValue())); + assertThat(actualEventHandles, is(List.of(eventHandle))); + testObject.releaseEventHandles(true); + verify(eventHandle).release(any(Boolean.class)); + } + @Test public void test_get_timestamp() { final String string = testObject.getTimestamp(); diff --git a/data-prepper-plugins/aws-plugin-api/build.gradle b/data-prepper-plugins/aws-plugin-api/build.gradle index 1042ffadc9..1b2b723d07 100644 --- a/data-prepper-plugins/aws-plugin-api/build.gradle +++ b/data-prepper-plugins/aws-plugin-api/build.gradle @@ -3,6 +3,8 @@ dependencies { implementation 'software.amazon.awssdk:auth' implementation 'software.amazon.awssdk:apache-client' implementation 'org.apache.httpcomponents.client5:httpclient5:5.3.1' + implementation 'com.fasterxml.jackson.core:jackson-annotations' + testImplementation libs.commons.lang3 testImplementation 'org.hibernate.validator:hibernate-validator:8.0.2.Final' } diff --git a/data-prepper-plugins/aws-plugin-api/src/main/java/org/opensearch/dataprepper/aws/api/AwsConfig.java b/data-prepper-plugins/aws-plugin-api/src/main/java/org/opensearch/dataprepper/aws/api/AwsConfig.java new file mode 100644 index 0000000000..952a690121 --- /dev/null +++ b/data-prepper-plugins/aws-plugin-api/src/main/java/org/opensearch/dataprepper/aws/api/AwsConfig.java @@ -0,0 +1,51 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.dataprepper.aws.api; + +import com.fasterxml.jackson.annotation.JsonProperty; +import jakarta.validation.constraints.Size; +import software.amazon.awssdk.regions.Region; + +import java.util.Map; + +/** + * AwsConfig is based on the S3-Sink AwsAuthenticationOptions + * where the configuration allows the sink to fetch Aws credentials + * and resources. + */ +public class AwsConfig { + @JsonProperty("region") + @Size(min = 1, message = "Region cannot be empty string") + private String awsRegion; + + @JsonProperty("sts_role_arn") + @Size(min = 20, max = 2048, message = "awsStsRoleArn length should be between 1 and 2048 characters") + private String awsStsRoleArn; + + @JsonProperty("sts_header_overrides") + @Size(max = 5, message = "sts_header_overrides supports a maximum of 5 headers to override") + private Map awsStsHeaderOverrides; + + @JsonProperty("sts_external_id") + @Size(min = 2, max = 1224, message = "awsStsExternalId length should be between 2 and 1224 characters") + private String awsStsExternalId; + + public Region getAwsRegion() { + return awsRegion != null ? Region.of(awsRegion) : null; + } + + public String getAwsStsRoleArn() { + return awsStsRoleArn; + } + + public String getAwsStsExternalId() { + return awsStsExternalId; + } + + public Map getAwsStsHeaderOverrides() { + return awsStsHeaderOverrides; + } +} diff --git a/data-prepper-plugins/aws-plugin-api/src/test/java/org/opensearch/dataprepper/aws/api/AwsConfigTest.java b/data-prepper-plugins/aws-plugin-api/src/test/java/org/opensearch/dataprepper/aws/api/AwsConfigTest.java new file mode 100644 index 0000000000..0193c84a0d --- /dev/null +++ b/data-prepper-plugins/aws-plugin-api/src/test/java/org/opensearch/dataprepper/aws/api/AwsConfigTest.java @@ -0,0 +1,58 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.dataprepper.aws.api; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.lang.reflect.Field; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.MatcherAssert.assertThat; +import org.apache.commons.lang3.RandomStringUtils; +import software.amazon.awssdk.regions.Region; + +import java.util.Map; + +public class AwsConfigTest { + + private AwsConfig awsConfig; + + @BeforeEach + void setUp() { + awsConfig = new AwsConfig(); + } + + @Test + void TestConfigOptions_notNull() throws NoSuchFieldException, IllegalAccessException { + + final String testStsRoleArn = RandomStringUtils.randomAlphabetic(10); + reflectivelySetField(awsConfig, "awsStsRoleArn", testStsRoleArn); + assertThat(awsConfig.getAwsStsRoleArn(), equalTo(testStsRoleArn)); + final String testStsExternalId = RandomStringUtils.randomAlphabetic(10); + reflectivelySetField(awsConfig, "awsStsExternalId", testStsExternalId); + assertThat(awsConfig.getAwsStsExternalId(), equalTo(testStsExternalId)); + + final Map testStsHeaderOverrides = Map.of(RandomStringUtils.randomAlphabetic(5), RandomStringUtils.randomAlphabetic(10)); + reflectivelySetField(awsConfig, "awsStsHeaderOverrides", testStsHeaderOverrides); + assertThat(awsConfig.getAwsStsHeaderOverrides(), equalTo(testStsHeaderOverrides)); + + final String testRegion = RandomStringUtils.randomAlphabetic(8); + reflectivelySetField(awsConfig, "awsRegion", testRegion); + assertThat(awsConfig.getAwsRegion(), equalTo(Region.of(testRegion))); + } + + private void reflectivelySetField(final AwsConfig awsConfig, final String fieldName, final Object value) throws NoSuchFieldException, IllegalAccessException { + final Field field = AwsConfig.class.getDeclaredField(fieldName); + try { + field.setAccessible(true); + field.set(awsConfig, value); + } finally { + field.setAccessible(false); + } + } +} + diff --git a/data-prepper-plugins/cloudwatch-logs/src/main/java/org/opensearch/dataprepper/plugins/sink/cloudwatch_logs/utils/CloudWatchLogsSinkUtils.java b/data-prepper-plugins/cloudwatch-logs/src/main/java/org/opensearch/dataprepper/plugins/sink/cloudwatch_logs/utils/CloudWatchLogsSinkUtils.java index 35f97b96d9..26bfcc84dd 100644 --- a/data-prepper-plugins/cloudwatch-logs/src/main/java/org/opensearch/dataprepper/plugins/sink/cloudwatch_logs/utils/CloudWatchLogsSinkUtils.java +++ b/data-prepper-plugins/cloudwatch-logs/src/main/java/org/opensearch/dataprepper/plugins/sink/cloudwatch_logs/utils/CloudWatchLogsSinkUtils.java @@ -19,7 +19,7 @@ public class CloudWatchLogsSinkUtils { public static DlqObject createDlqObject(final int status, final EventHandle eventHandle, final String message, final String failureMessage, final DlqPushHandler dlqPushHandler) { if (dlqPushHandler != null) { CloudWatchLogsSinkDlqData cloudWatchLogsSinkDlqData = CloudWatchLogsSinkDlqData.createDlqData(status, message, failureMessage); - return DlqObject.createDlqObject(dlqPushHandler.getPluginSetting(), eventHandle, cloudWatchLogsSinkDlqData); + return DlqObject.createDlqObject(dlqPushHandler.getPluginSetting(), List.of(eventHandle), cloudWatchLogsSinkDlqData); } else { eventHandle.release(false); } @@ -35,7 +35,7 @@ public static void handleDlqObjects(List dlqObjects, final DlqPushHan result = dlqPushHandler.perform(dlqObjects); } for (final DlqObject dlqObject : dlqObjects) { - dlqObject.getEventHandle().release(result); + dlqObject.releaseEventHandles(result); } } } diff --git a/data-prepper-plugins/common/src/main/java/org/opensearch/dataprepper/plugins/accumulator/InMemoryBuffer.java b/data-prepper-plugins/common/src/main/java/org/opensearch/dataprepper/plugins/accumulator/InMemoryBuffer.java index 252dac88a9..900f200b80 100644 --- a/data-prepper-plugins/common/src/main/java/org/opensearch/dataprepper/plugins/accumulator/InMemoryBuffer.java +++ b/data-prepper-plugins/common/src/main/java/org/opensearch/dataprepper/plugins/accumulator/InMemoryBuffer.java @@ -6,6 +6,7 @@ package org.opensearch.dataprepper.plugins.accumulator; import org.apache.commons.lang3.time.StopWatch; + import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.OutputStream; @@ -16,12 +17,12 @@ */ public class InMemoryBuffer implements Buffer { - private static final ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(); + private final ByteArrayOutputStream byteArrayOutputStream; private int eventCount; private final StopWatch watch; InMemoryBuffer() { - byteArrayOutputStream.reset(); + byteArrayOutputStream = new ByteArrayOutputStream(); eventCount = 0; watch = new StopWatch(); watch.start(); @@ -59,4 +60,4 @@ public OutputStream getOutputStream() { public void setEventCount(int eventCount) { this.eventCount = eventCount; } -} \ No newline at end of file +} diff --git a/data-prepper-plugins/sqs-sink/build.gradle b/data-prepper-plugins/sqs-sink/build.gradle new file mode 100644 index 0000000000..ef3241d8e6 --- /dev/null +++ b/data-prepper-plugins/sqs-sink/build.gradle @@ -0,0 +1,70 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +sourceSets { + integrationTest { + java { + compileClasspath += main.output + test.output + runtimeClasspath += main.output + test.output + srcDir file('src/integrationTest/java') + } + } +} + +configurations { + integrationTestImplementation.extendsFrom testImplementation + integrationTestRuntime.extendsFrom testRuntime +} + +dependencies { + implementation project(':data-prepper-plugins:aws-plugin-api') + implementation project(':data-prepper-api') + implementation project(':data-prepper-plugins:common') + implementation project(':data-prepper-plugins:failures-common') + implementation project(':data-prepper-plugins:sqs-common') + implementation project(':data-prepper-plugins:parse-json-processor:') + implementation 'io.micrometer:micrometer-core' + implementation 'com.fasterxml.jackson.core:jackson-core' + implementation 'com.fasterxml.jackson.core:jackson-databind' + implementation 'com.fasterxml.jackson.dataformat:jackson-dataformat-yaml' + implementation 'software.amazon.awssdk:sqs' + implementation 'software.amazon.awssdk:arns' + implementation 'software.amazon.awssdk:s3' + implementation 'software.amazon.awssdk:sts' + implementation libs.commons.lang3 + implementation libs.armeria.core + implementation 'org.projectlombok:lombok:1.18.26' + implementation 'org.hibernate.validator:hibernate-validator:8.0.0.Final' + testImplementation project(':data-prepper-test-common') + annotationProcessor 'org.projectlombok:lombok:1.18.24' +} + +jacocoTestCoverageVerification { + dependsOn jacocoTestReport + violationRules { + rule { + limit { + minimum = 0.90 + } + } + } +} + +task integrationTest(type: Test) { + group = 'verification' + testClassesDirs = sourceSets.integrationTest.output.classesDirs + + useJUnitPlatform() + + classpath = sourceSets.integrationTest.runtimeClasspath + systemProperty 'tests.sqs.queue_url', System.getProperty('tests.sqs.queue_url') + systemProperty 'tests.s3.bucket', System.getProperty('tests.s3.bucket') + systemProperty 'tests.aws.region', System.getProperty('tests.aws.region') + systemProperty 'tests.aws.role', System.getProperty('tests.aws.role') + filter { + includeTestsMatching '*IT' + } +} + diff --git a/data-prepper-plugins/sqs-sink/src/integrationTest/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSinkIT.java b/data-prepper-plugins/sqs-sink/src/integrationTest/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSinkIT.java new file mode 100644 index 0000000000..da7d2bddba --- /dev/null +++ b/data-prepper-plugins/sqs-sink/src/integrationTest/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSinkIT.java @@ -0,0 +1,707 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.dataprepper.plugins.sink.sqs; + +import org.opensearch.dataprepper.model.configuration.PluginSetting; +import org.opensearch.dataprepper.model.configuration.PluginModel; +import org.opensearch.dataprepper.model.plugin.PluginFactory; +import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier; +import org.opensearch.dataprepper.aws.api.AwsConfig; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import io.micrometer.core.instrument.Counter; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.apache.commons.lang3.RandomStringUtils; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider; +import org.opensearch.dataprepper.metrics.PluginMetrics; +import org.opensearch.dataprepper.expression.ExpressionEvaluator; +import org.opensearch.dataprepper.plugins.source.sqs.common.SqsClientFactory; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.S3Object; +import software.amazon.awssdk.services.s3.model.ListObjectsRequest; +import software.amazon.awssdk.services.s3.model.ListObjectsResponse; +import software.amazon.awssdk.services.s3.model.DeleteObjectRequest; + +import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider; +import org.opensearch.dataprepper.plugins.dlq.DlqProvider; +import org.opensearch.dataprepper.plugins.dlq.s3.S3DlqProvider; +import org.opensearch.dataprepper.plugins.dlq.s3.S3DlqWriterConfig; +import software.amazon.awssdk.services.sqs.SqsClient; +import software.amazon.awssdk.services.sqs.model.Message; +import software.amazon.awssdk.services.sqs.model.ReceiveMessageRequest; +import software.amazon.awssdk.services.sqs.model.DeleteMessageBatchRequest; +import software.amazon.awssdk.services.sqs.model.DeleteMessageBatchRequestEntry; +import software.amazon.awssdk.services.sqs.model.DeleteMessageBatchResponse; +import software.amazon.awssdk.services.sqs.model.PurgeQueueRequest; +import org.opensearch.dataprepper.plugins.codec.json.JsonOutputCodec; +import org.opensearch.dataprepper.plugins.codec.json.JsonOutputCodecConfig; +import org.opensearch.dataprepper.model.sink.OutputCodecContext; +import org.opensearch.dataprepper.model.codec.OutputCodec; + +import org.opensearch.dataprepper.model.event.Event; +import org.opensearch.dataprepper.model.event.EventHandle; +import org.opensearch.dataprepper.model.log.JacksonLog; +import org.opensearch.dataprepper.model.record.Record; +import org.opensearch.dataprepper.model.sink.SinkContext; +import software.amazon.awssdk.regions.Region; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.RepeatedTest; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +import static org.awaitility.Awaitility.await; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.mockito.Mockito.lenient; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.times; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.Random; +import java.util.Map; +import java.util.List; +import java.util.Collection; +import java.util.Collections; +import java.util.concurrent.atomic.AtomicInteger; + +@ExtendWith(MockitoExtension.class) +public class SqsSinkIT { + static final int NUM_RECORDS = 10; + static final int MAX_SIZE = 256*1024; + static final String DLQ_PREFIX = "sqsSinkIT/"; + @Mock + private PluginFactory pluginFactory; + @Mock + private PluginSetting pluginSetting; + + @Mock + private EventHandle eventHandle; + + @Mock + private PluginMetrics pluginMetrics; + + @Mock + private AwsCredentialsSupplier awsCredentialsSupplier; + + @Mock + private AwsConfig awsConfig; + + @Mock + private SqsThresholdConfig thresholdConfig; + + @Mock + private SqsSinkConfig sqsSinkConfig; + + @Mock + private SinkContext sinkContext; + + @Mock + private PluginModel codec; + + @Mock + private Counter eventsSuccessCounter; + @Mock + private Counter requestsSuccessCounter; + @Mock + private Counter eventsFailedCounter; + @Mock + private Counter requestsFailedCounter; + @Mock + private Counter dlqSuccessCounter; + + private JsonOutputCodec jsonCodec; + private String bucket; + private String awsRegion; + private String awsRole; + private String queueUrl; + private String groupId; + private SqsSink sink; + private AtomicInteger count; + private SqsClient sqsClient; + private ObjectMapper objectMapper; + private ExpressionEvaluator expressionEvaluator; + private List messages; + private AtomicInteger eventsSuccessCount; + private AtomicInteger requestsSuccessCount; + private AtomicInteger eventsFailedCount; + private AtomicInteger requestsFailedCount; + private AtomicInteger dlqSuccessCount; + private AwsCredentialsProvider awsCredentialsProvider; + private S3Client s3Client; + private int numLargeMessages; + private Random random; + + + @BeforeEach + void setUp() { + random = new Random(); + numLargeMessages = 0; + awsCredentialsProvider = DefaultCredentialsProvider.create(); + pluginMetrics = mock(PluginMetrics.class); + eventsSuccessCount = new AtomicInteger(0); + requestsSuccessCount = new AtomicInteger(0); + eventsFailedCount = new AtomicInteger(0); + requestsFailedCount = new AtomicInteger(0); + dlqSuccessCount = new AtomicInteger(0); + eventsSuccessCounter = mock(Counter.class); + eventsFailedCounter = mock(Counter.class); + requestsSuccessCounter = mock(Counter.class); + requestsFailedCounter = mock(Counter.class); + dlqSuccessCounter = mock(Counter.class); + lenient().doAnswer((a)-> { + int v = (int)(double)(a.getArgument(0)); + eventsSuccessCount.addAndGet(v); + return null; + }).when(eventsSuccessCounter).increment(any(Double.class)); + lenient().doAnswer((a)-> { + int v = (int)(double)(a.getArgument(0)); + eventsFailedCount.addAndGet(v); + return null; + }).when(eventsFailedCounter).increment(any(Double.class)); + lenient().doAnswer((a)-> { + requestsSuccessCount.addAndGet(1); + return null; + }).when(requestsSuccessCounter).increment(); + lenient().doAnswer((a)-> { + int v = (int)(double)(a.getArgument(0)); + requestsSuccessCount.addAndGet(v); + return null; + }).when(requestsSuccessCounter).increment(any(Double.class)); + lenient().doAnswer((a)-> { + int v = (int)(double)(a.getArgument(0)); + requestsFailedCount.addAndGet(v); + return null; + }).when(requestsFailedCounter).increment(any(Double.class)); + lenient().doAnswer((a)-> { + int v = (int)(double)(a.getArgument(0)); + dlqSuccessCount.addAndGet(v); + return null; + }).when(dlqSuccessCounter).increment(any(Double.class)); + lenient().doAnswer(a -> { + String s = (String)(a.getArgument(0)); + if (s.equals(SqsSinkMetrics.SQS_SINK_REQUESTS_SUCCEEDED)) { + return requestsSuccessCounter; + } + if (s.equals(SqsSinkMetrics.SQS_SINK_EVENTS_SUCCEEDED)) { + return eventsSuccessCounter; + } + if (s.equals(SqsSinkMetrics.SQS_SINK_REQUESTS_FAILED)) { + return requestsFailedCounter; + } + if (s.equals(SqsSinkMetrics.SQS_SINK_EVENTS_FAILED)) { + return eventsFailedCounter; + } + if (s.contains("NumDlqSuccess")) { + return dlqSuccessCounter; + } + return null; + }).when(pluginMetrics).counter(anyString()); + messages = new ArrayList<>(); + pluginFactory = mock(PluginFactory.class); + jsonCodec = new JsonOutputCodec(new JsonOutputCodecConfig()); + when(pluginFactory.loadPlugin(eq(OutputCodec.class), any())).thenReturn(jsonCodec); + expressionEvaluator = mock(ExpressionEvaluator.class); + codec = mock(PluginModel.class); + when(codec.getPluginName()).thenReturn("json"); + pluginSetting = mock(PluginSetting.class); + when(pluginSetting.getName()).thenReturn("sqs"); + when(pluginSetting.getPipelineName()).thenReturn("test-pipeline"); + when(codec.getPluginSettings()).thenReturn(new HashMap()); + groupId = "testGroupId"; + count = new AtomicInteger(0); + objectMapper = new ObjectMapper(); + sinkContext = mock(SinkContext.class); + eventHandle = mock(EventHandle.class); + when(sinkContext.getExcludeKeys()).thenReturn(null); + when(sinkContext.getIncludeKeys()).thenReturn(null); + when(sinkContext.getTagsTargetKey()).thenReturn(null); + awsRegion = System.getProperty("tests.aws.region"); + awsRole = System.getProperty("tests.aws.role"); + bucket = System.getProperty("tests.s3.bucket"); + awsConfig = mock(AwsConfig.class); + when(awsConfig.getAwsRegion()).thenReturn(Region.of(awsRegion)); + when(awsConfig.getAwsStsRoleArn()).thenReturn(awsRole); + when(awsConfig.getAwsStsExternalId()).thenReturn(null); + when(awsConfig.getAwsStsHeaderOverrides()).thenReturn(null); + when(awsCredentialsSupplier.getProvider(any())).thenAnswer(options -> DefaultCredentialsProvider.create()); + sqsClient = SqsClientFactory.createSqsClient(Region.of(awsRegion), DefaultCredentialsProvider.create()); + queueUrl = System.getProperty("tests.sqs.queue_url"); + sqsSinkConfig = mock(SqsSinkConfig.class); + when(sqsSinkConfig.getQueueUrl()).thenReturn(queueUrl); + when(sqsSinkConfig.getGroupId()).thenReturn(groupId); + when(sqsSinkConfig.getCodec()).thenReturn(codec); + when(sqsSinkConfig.getAwsConfig()).thenReturn(awsConfig); + when(sqsSinkConfig.getDlq()).thenReturn(null); + + thresholdConfig = mock(SqsThresholdConfig.class); + when(sqsSinkConfig.getMaxRetries()).thenReturn(3); + when(thresholdConfig.getMaxEventsPerMessage()).thenReturn(1); + when(thresholdConfig.getMaxMessageSizeBytes()).thenReturn(250*1024L); + when(sqsSinkConfig.getThresholdConfig()).thenReturn(thresholdConfig); + try { + purgeMessages(); + } catch (Exception e){} + } + + private void purgeMessages() { + sqsClient.purgeQueue(PurgeQueueRequest.builder() + .queueUrl(queueUrl) + .build()); + } + + @AfterEach + void tearDown() { + List entries = new ArrayList<>(); + int i = 0; + for (final Message message : messages) { + entries.add(DeleteMessageBatchRequestEntry.builder() + .id(message.messageId()) + .receiptHandle(message.receiptHandle()) + .build()); + + if (++i == 10) { + DeleteMessageBatchResponse response = + sqsClient.deleteMessageBatch(DeleteMessageBatchRequest.builder().queueUrl(queueUrl).entries(entries).build()); + i = 0; + entries.clear(); + } + } + if (i > 0) { + DeleteMessageBatchResponse response = + sqsClient.deleteMessageBatch(DeleteMessageBatchRequest.builder().queueUrl(queueUrl).entries(entries).build()); + } + deleteObjectsWithPrefix(bucket, DLQ_PREFIX); + } + + private List listObjectsWithPrefix(String bucketName, String prefix) { + List objectNames = new ArrayList<>(); + ListObjectsRequest request = ListObjectsRequest.builder() + .bucket(bucketName) + .prefix(prefix).build(); + + ListObjectsResponse result = s3Client.listObjects(request); + for (final S3Object s3Object : result.contents()) { + objectNames.add(s3Object.key()); + } + return objectNames; + } + + private void deleteObjectsWithPrefix(String bucketName, String prefix) { + if (s3Client != null) { + List objectNames = listObjectsWithPrefix(bucketName, prefix); + for (final String objectName : objectNames) { + final DeleteObjectRequest deleteObjectRequest = DeleteObjectRequest.builder() + .bucket(bucket) + .key(objectName).build(); + s3Client.deleteObject(deleteObjectRequest); + } + } + } + + private SqsSink createObjectUnderTest() { + return new SqsSink(pluginSetting, pluginMetrics, pluginFactory, sqsSinkConfig, sinkContext, expressionEvaluator, awsCredentialsSupplier); + } + + private List getMessages(final String queueUrl) { + ReceiveMessageRequest request = ReceiveMessageRequest.builder() + .queueUrl(queueUrl) + .maxNumberOfMessages(10) + .waitTimeSeconds(3) + .attributeNamesWithStrings("All") + .messageAttributeNames("All") + .build(); + return sqsClient.receiveMessage(request).messages(); + } + + @ParameterizedTest + @ValueSource(ints = {10, 30, 50, 70}) + void TestSinkOperationWithBatchSize(int numRecords) throws Exception { + when(thresholdConfig.getMaxEventsPerMessage()).thenReturn(1); + sink = createObjectUnderTest(); + Collection> records = getRecordList(numRecords, false); + sink.doOutput(records); + + await().atMost(Duration.ofSeconds(60)) + .untilAsserted(() -> { + final Map expectedMap = new HashMap<>(); + for (int i = 0; i < numRecords; i++) { + expectedMap.put("Person"+i, Integer.toString(i)); + } + List msgs = getMessages(queueUrl); + messages.addAll(msgs); + assertThat(messages.size(), equalTo(numRecords)); + for (int i = 0; i < messages.size(); i++) { + String body = messages.get(i).body(); + Map events = objectMapper.readValue(body, Map.class); + List objs = (List)events.get("events"); + assertNotNull(objs); + assertThat(objs.size(), equalTo(1)); + Map event = (Map)objs.get(0); + String name = (String)event.get("name"); + assertTrue(expectedMap.containsKey(name)); + assertThat(expectedMap.get(name), equalTo(event.get("age"))); + expectedMap.remove(name); + } + }); + assertThat(eventsSuccessCount.get(), equalTo(numRecords)); + assertThat(requestsSuccessCount.get(), equalTo(numRecords/10)); + assertThat(dlqSuccessCount.get(), equalTo(0)); + verify(eventHandle, times(numRecords)).release(true); + } + + @ParameterizedTest + @ValueSource(ints = {10, 30, 50, 70}) + void TestSinkOperationWithBatchSizeWithSinkContext(int numRecords) throws Exception { + when(thresholdConfig.getMaxEventsPerMessage()).thenReturn(1); + when(sinkContext.getTagsTargetKey()).thenReturn("sqsSinkTags"); + when(sinkContext.getExcludeKeys()).thenReturn(List.of("age")); + sink = createObjectUnderTest(); + Collection> records = getRecordList(numRecords, false); + sink.doOutput(records); + + await().atMost(Duration.ofSeconds(60)) + .untilAsserted(() -> { + final Map expectedMap = new HashMap<>(); + for (int i = 0; i < numRecords; i++) { + expectedMap.put("Person"+i, Integer.toString(i)); + } + List msgs = getMessages(queueUrl); + messages.addAll(msgs); + assertThat(messages.size(), equalTo(numRecords)); + for (int i = 0; i < messages.size(); i++) { + String body = messages.get(i).body(); + Map events = objectMapper.readValue(body, Map.class); + List objs = (List)events.get("events"); + assertNotNull(objs); + assertThat(objs.size(), equalTo(1)); + Map event = (Map)objs.get(0); + String name = (String)event.get("name"); + assertTrue(expectedMap.containsKey(name)); + assertThat(event.get("age"), equalTo(null)); + assertThat(event.get("sqsSinkTags"), equalTo(List.of())); + expectedMap.remove(name); + } + }); + assertThat(eventsSuccessCount.get(), equalTo(numRecords)); + assertThat(requestsSuccessCount.get(), equalTo(numRecords/10)); + assertThat(dlqSuccessCount.get(), equalTo(0)); + verify(eventHandle, times(numRecords)).release(true); + } + + @ParameterizedTest + @ValueSource(ints = {10, 25, 40, 75}) + void TestSinkOperationWithFlushIntervalOneRequest(int numRecords) throws Exception { + when(thresholdConfig.getMaxEventsPerMessage()).thenReturn(50); + when(thresholdConfig.getFlushInterval()).thenReturn(3L); + + sink = createObjectUnderTest(); + Collection> records = getRecordList(numRecords, false); + sink.doOutput(records); + + await().atMost(Duration.ofSeconds(60)) + .untilAsserted(() -> { + sink.doOutput(Collections.emptyList()); + final Map expectedMap = new HashMap<>(); + for (int i = 0; i < numRecords; i++) { + expectedMap.put("Person"+i, Integer.toString(i)); + } + List msgs = getMessages(queueUrl); + messages.addAll(msgs); + assertThat(messages.size(), equalTo(1 + numRecords/50)); + int remainingRecords = numRecords; + for (int i = 0; i < messages.size(); i++) { + String body = messages.get(i).body(); + Map events = objectMapper.readValue(body, Map.class); + List objs = (List)events.get("events"); + assertNotNull(objs); + remainingRecords -= objs.size(); + for (int j = 0; j < objs.size(); j++) { + Map event = (Map)objs.get(j); + String name = (String)event.get("name"); + assertTrue(expectedMap.containsKey(name)); + assertThat(expectedMap.get(name), equalTo(event.get("age"))); + expectedMap.remove(name); + } + } + }); + assertThat(eventsSuccessCount.get(), equalTo(numRecords)); + assertThat(requestsSuccessCount.get(), equalTo(1)); + assertThat(dlqSuccessCount.get(), equalTo(0)); + verify(eventHandle, times(numRecords)).release(true); + } + + @ParameterizedTest + @ValueSource(ints = {10, 25, 40, 75}) + void TestSinkOperationWithFlushIntervalMultipleRequests(int numRecords) throws Exception { + when(thresholdConfig.getMaxEventsPerMessage()).thenReturn(5); + when(thresholdConfig.getFlushInterval()).thenReturn(5L); + + sink = createObjectUnderTest(); + Collection> records = getRecordList(numRecords, false); + sink.doOutput(records); + + await().atMost(Duration.ofSeconds(60)) + .untilAsserted(() -> { + sink.doOutput(Collections.emptyList()); + final Map expectedMap = new HashMap<>(); + for (int i = 0; i < numRecords; i++) { + expectedMap.put("Person"+i, Integer.toString(i)); + } + List msgs = getMessages(queueUrl); + messages.addAll(msgs); + assertThat((double)messages.size(), equalTo(Math.ceil(numRecords/5))); + int remainingRecords = numRecords; + for (int i = 0; i < messages.size(); i++) { + String body = messages.get(i).body(); + Map events = objectMapper.readValue(body, Map.class); + List objs = (List)events.get("events"); + assertNotNull(objs); + assertThat(objs.size(), equalTo(5)); + remainingRecords -= objs.size(); + for (int j = 0; j < objs.size(); j++) { + Map event = (Map)objs.get(j); + String name = (String)event.get("name"); + assertTrue(expectedMap.containsKey(name)); + assertThat(expectedMap.get(name), equalTo(event.get("age"))); + expectedMap.remove(name); + } + } + }); + assertThat(eventsSuccessCount.get(), equalTo(numRecords)); + assertThat((double)requestsSuccessCount.get(), equalTo(1.0+numRecords/50)); + assertThat(dlqSuccessCount.get(), equalTo(0)); + verify(eventHandle, times(numRecords)).release(true); + } + + @Test + void TestWithLargeSingleMessagesSentToDLQ() throws Exception { + s3Client = S3Client.builder() + .credentialsProvider(awsCredentialsProvider) + .region(Region.of(awsRegion)) + .build(); + PluginModel dlqConfig = mock(PluginModel.class); + when(dlqConfig.getPluginSettings()).thenReturn(new HashMap()); + when(dlqConfig.getPluginName()).thenReturn("s3"); + + S3DlqWriterConfig s3DlqWriterConfig = mock(S3DlqWriterConfig.class); + when(s3DlqWriterConfig.getBucket()).thenReturn(bucket); + when(s3DlqWriterConfig.getKeyPathPrefix()).thenReturn(DLQ_PREFIX); + when(s3DlqWriterConfig.getS3Client()).thenReturn(s3Client); + S3DlqProvider s3DlqProvider = new S3DlqProvider(s3DlqWriterConfig); + when(pluginFactory.loadPlugin(eq(DlqProvider.class), any())).thenReturn(s3DlqProvider); + + when(thresholdConfig.getMaxEventsPerMessage()).thenReturn(1); + when(sqsSinkConfig.getDlq()).thenReturn(dlqConfig); + + sink = createObjectUnderTest(); + Collection> records = getRecordList(NUM_RECORDS, false); + Record largeRecord = getLargeRecord(256*1024); + records.add(largeRecord); + + sink.doOutput(records); + await().atMost(Duration.ofSeconds(30)) + .untilAsserted(() -> { + final Map expectedMap = new HashMap<>(); + for (int i = 0; i < NUM_RECORDS; i++) { + expectedMap.put("Person"+i, Integer.toString(i)); + } + List msgs = getMessages(queueUrl); + messages.addAll(msgs); + assertThat(messages.size(), equalTo(NUM_RECORDS)); + for (int i = 0; i < messages.size(); i++) { + String body = messages.get(i).body(); + Map events = objectMapper.readValue(body, Map.class); + List objs = (List)events.get("events"); + assertNotNull(objs); + assertThat(objs.size(), equalTo(1)); + Map event = (Map)objs.get(0); + String name = (String)event.get("name"); + assertTrue(expectedMap.containsKey(name)); + assertThat(expectedMap.get(name), equalTo(event.get("age"))); + expectedMap.remove(name); + } + }); + assertThat(eventsSuccessCount.get(), equalTo(NUM_RECORDS)); + assertThat(requestsSuccessCount.get(), equalTo(1)); + assertThat(dlqSuccessCount.get(), equalTo(1)); + verify(eventHandle, times(NUM_RECORDS+1)).release(true); + } + + private Record getLargeRecord(int size) { + final Event event = JacksonLog.builder() + .withData(Map.of("key", RandomStringUtils.randomAlphabetic(size))) + .withEventHandle(eventHandle) + .build(); + return new Record<>(event); + } + + + @Test + void TestSinkOperationWithQueuesAsExpression() throws Exception { + s3Client = S3Client.builder() + .credentialsProvider(awsCredentialsProvider) + .region(Region.of(awsRegion)) + .build(); + PluginModel dlqConfig = mock(PluginModel.class); + when(dlqConfig.getPluginSettings()).thenReturn(new HashMap()); + when(dlqConfig.getPluginName()).thenReturn("s3"); + + S3DlqWriterConfig s3DlqWriterConfig = mock(S3DlqWriterConfig.class); + when(s3DlqWriterConfig.getBucket()).thenReturn(bucket); + when(s3DlqWriterConfig.getKeyPathPrefix()).thenReturn(DLQ_PREFIX); + when(s3DlqWriterConfig.getS3Client()).thenReturn(s3Client); + S3DlqProvider s3DlqProvider = new S3DlqProvider(s3DlqWriterConfig); + when(pluginFactory.loadPlugin(eq(DlqProvider.class), any())).thenReturn(s3DlqProvider); + + when(thresholdConfig.getMaxEventsPerMessage()).thenReturn(1); + when(sqsSinkConfig.getDlq()).thenReturn(dlqConfig); + when(expressionEvaluator.isValidFormatExpression(anyString())).thenReturn(true); + + when(sqsSinkConfig.getQueueUrl()).thenReturn(queueUrl+"${/id}"); + + sink = createObjectUnderTest(); + Collection> records = getRecordList(2*NUM_RECORDS, false); + sink.doOutput(records); + + await().atMost(Duration.ofSeconds(60)) + .untilAsserted(() -> { + final Map expectedMap = new HashMap<>(); + for (int i = 0; i < NUM_RECORDS; i++) { + expectedMap.put("Person"+i, Integer.toString(i)); + } + List msgs = getMessages(queueUrl); + messages.addAll(msgs); + assertThat(messages.size(), equalTo(NUM_RECORDS)); + for (int i = 0; i < messages.size(); i++) { + String body = messages.get(i).body(); + Map events = objectMapper.readValue(body, Map.class); + List objs = (List)events.get("events"); + assertNotNull(objs); + assertThat(objs.size(), equalTo(1)); + Map event = (Map)objs.get(0); + String name = (String)event.get("name"); + assertTrue(expectedMap.containsKey(name)); + assertThat(expectedMap.get(name), equalTo(event.get("age"))); + expectedMap.remove(name); + } + }); + assertThat(eventsSuccessCount.get(), equalTo(NUM_RECORDS)); + assertThat(requestsSuccessCount.get(), equalTo(1)); + assertThat(dlqSuccessCount.get(), equalTo(NUM_RECORDS)); + verify(eventHandle, times(2*NUM_RECORDS)).release(true); + } + + @RepeatedTest(value = 5) + void TestWithManyRecordsWithRandomSizes() throws Exception { + final int numRecords = 100 + random.nextInt(100); + s3Client = S3Client.builder() + .credentialsProvider(awsCredentialsProvider) + .region(Region.of(awsRegion)) + .build(); + PluginModel dlqConfig = mock(PluginModel.class); + when(dlqConfig.getPluginSettings()).thenReturn(new HashMap()); + when(dlqConfig.getPluginName()).thenReturn("s3"); + when(thresholdConfig.getFlushInterval()).thenReturn(3L); + + S3DlqWriterConfig s3DlqWriterConfig = mock(S3DlqWriterConfig.class); + when(s3DlqWriterConfig.getBucket()).thenReturn(bucket); + when(s3DlqWriterConfig.getKeyPathPrefix()).thenReturn(DLQ_PREFIX); + when(s3DlqWriterConfig.getS3Client()).thenReturn(s3Client); + S3DlqProvider s3DlqProvider = new S3DlqProvider(s3DlqWriterConfig); + when(pluginFactory.loadPlugin(eq(DlqProvider.class), any())).thenReturn(s3DlqProvider); + + when(thresholdConfig.getMaxEventsPerMessage()).thenReturn(20+random.nextInt(30)); + when(sqsSinkConfig.getDlq()).thenReturn(dlqConfig); + + sink = createObjectUnderTest(); + Collection> records = getRecordList(numRecords, true); + + sink.doOutput(records); + await().atMost(Duration.ofSeconds(30)) + .untilAsserted(() -> { + sink.doOutput(Collections.emptyList()); + final Map expectedMap = new HashMap<>(); + for (int i = 0; i < numRecords; i++) { + expectedMap.put("Person"+i, Integer.toString(i)); + } + List msgs = getMessages(queueUrl); + messages.addAll(msgs); + int recordsReceived = 0; + for (int i = 0; i < messages.size(); i++) { + String body = messages.get(i).body(); + Map events = objectMapper.readValue(body, Map.class); + List objs = (List)events.get("events"); + assertNotNull(objs); + recordsReceived += objs.size(); + for (int j = 0; j < objs.size(); j++) { + Map event = (Map)objs.get(j); + String name = (String)event.get("name"); + assertTrue(expectedMap.containsKey(name)); + expectedMap.remove(name); + } + } + + assertThat(recordsReceived, equalTo(numRecords)); + }); + assertThat(eventsSuccessCount.get(), equalTo(numRecords - numLargeMessages)); + assertThat(dlqSuccessCount.get(), equalTo(numLargeMessages)); + verify(eventHandle, times(numRecords)).release(true); + } + + + private Collection> getRecordList(int numberOfRecords, boolean randomSize) throws Exception { + final Collection> recordList = new ArrayList<>(); + List records = generateRecords(numberOfRecords, randomSize); + for (int i = 0; i < numberOfRecords; i++) { + final Event event = JacksonLog.builder() + .withData(records.get(i)) + .withEventHandle(eventHandle) + .build(); + if (randomSize) { + long size = jsonCodec.getEstimatedSize(event, new OutputCodecContext()); + if (size > 256*1024) { + numLargeMessages++; + } + } + recordList.add(new Record<>(event)); + } + return recordList; + } + + private List generateRecords(int numberOfRecords, boolean randomSize) { + List recordList = new ArrayList<>(); + + for (int rows = 0; rows < numberOfRecords; rows++) { + HashMap eventData = new HashMap<>(); + eventData.put("name", "Person" + rows); + if (!randomSize) { + eventData.put("age", Integer.toString(rows)); + } else { + int size = random.nextInt(MAX_SIZE); + String ageValue = RandomStringUtils.randomAlphabetic(size); + eventData.put("age", ageValue); + } + eventData.put("id", (rows < NUM_RECORDS) ? "": "10"); + recordList.add(eventData); + + } + return recordList; + } +} diff --git a/data-prepper-plugins/sqs-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSink.java b/data-prepper-plugins/sqs-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSink.java new file mode 100644 index 0000000000..a067358b14 --- /dev/null +++ b/data-prepper-plugins/sqs-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSink.java @@ -0,0 +1,122 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.dataprepper.plugins.sink.sqs; + +import org.opensearch.dataprepper.model.configuration.PluginSetting; +import org.opensearch.dataprepper.metrics.PluginMetrics; +import org.opensearch.dataprepper.model.configuration.PluginModel; +import org.opensearch.dataprepper.model.plugin.PluginFactory; +import org.opensearch.dataprepper.model.annotations.Experimental; + +import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier; +import org.opensearch.dataprepper.aws.api.AwsConfig; +import org.opensearch.dataprepper.aws.api.AwsCredentialsOptions; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import org.opensearch.dataprepper.model.annotations.DataPrepperPlugin; +import org.opensearch.dataprepper.model.annotations.DataPrepperPluginConstructor; +import org.opensearch.dataprepper.expression.ExpressionEvaluator; +import org.opensearch.dataprepper.model.record.Record; +import org.opensearch.dataprepper.model.event.Event; +import org.opensearch.dataprepper.model.sink.AbstractSink; +import org.opensearch.dataprepper.model.sink.Sink; +import org.opensearch.dataprepper.model.sink.SinkContext; +import org.opensearch.dataprepper.plugins.source.sqs.common.SqsClientFactory; +import org.opensearch.dataprepper.model.codec.OutputCodec; +import org.opensearch.dataprepper.plugins.dlq.DlqPushHandler; +import software.amazon.awssdk.services.sqs.SqsClient; +import software.amazon.awssdk.services.sts.StsClient; +import software.amazon.awssdk.regions.Region; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.time.Duration; +import java.util.Map; +import java.util.Collection; + +@Experimental +@DataPrepperPlugin(name = "sqs", pluginType = Sink.class, pluginConfigurationType = SqsSinkConfig.class) +public class SqsSink extends AbstractSink> { + + private static final Logger LOG = LoggerFactory.getLogger(SqsSink.class); + private static final Duration RETRY_FLUSH_BACKOFF = Duration.ofSeconds(5); + private final SqsSinkConfig sqsSinkConfig; + private volatile boolean sinkInitialized; + private final SqsSinkService sqsSinkService; + + @DataPrepperPluginConstructor + public SqsSink(final PluginSetting pluginSetting, + final PluginMetrics pluginMetrics, + final PluginFactory pluginFactory, + final SqsSinkConfig sqsSinkConfig, + final SinkContext sinkContext, + final ExpressionEvaluator expressionEvaluator, + final AwsCredentialsSupplier awsCredentialsSupplier) { + super(pluginSetting); + this.sqsSinkConfig = sqsSinkConfig; + sinkInitialized = false; + final PluginModel codecConfiguration = sqsSinkConfig.getCodec(); + final PluginSetting codecPluginSettings; + if (codecConfiguration != null) { + String codecPluginName = codecConfiguration.getPluginName(); + if (!codecPluginName.equals("json") && !codecPluginName.equals("ndjson")) { + throw new RuntimeException(String.format("Codec {} not supported.", codecPluginName)); + } + codecPluginSettings = new PluginSetting(codecConfiguration.getPluginName(), + codecConfiguration.getPluginSettings()); + } else { + codecPluginSettings = new PluginSetting("ndjson", Map.of()); + } + + final OutputCodec outputCodec = pluginFactory.loadPlugin(OutputCodec.class, codecPluginSettings); + AwsConfig awsConfig = sqsSinkConfig.getAwsConfig(); + if (awsConfig == null && awsCredentialsSupplier == null) { + throw new RuntimeException("Missing awsConfig and awsCredentialsSupplier"); + } + final AwsCredentialsProvider awsCredentialsProvider = awsConfig != null ? awsCredentialsSupplier.getProvider(convertToCredentialOptions(awsConfig)) : awsCredentialsSupplier.getProvider(AwsCredentialsOptions.builder().build()); + Region region = awsConfig != null ? awsConfig.getAwsRegion() : awsCredentialsSupplier.getDefaultRegion().get(); + final SqsClient sqsClient = SqsClientFactory.createSqsClient(region, awsCredentialsProvider); + + DlqPushHandler dlqPushHandler = null; + if (sqsSinkConfig.getDlq() != null) { + StsClient stsClient = StsClient.builder() + .region(region) + .credentialsProvider(awsCredentialsProvider) + .build(); + String role = stsClient.getCallerIdentity().arn(); + dlqPushHandler = new DlqPushHandler(pluginFactory, pluginSetting, pluginMetrics, sqsSinkConfig.getDlq(), region.toString(), role, "sqsSink"); + } + sqsSinkService = new SqsSinkService(sqsSinkConfig, sqsClient, expressionEvaluator, outputCodec, sinkContext, dlqPushHandler, pluginMetrics); + } + + private static AwsCredentialsOptions convertToCredentialOptions(final AwsConfig awsConfig) { + return AwsCredentialsOptions.builder() + .withRegion(awsConfig.getAwsRegion()) + .withStsRoleArn(awsConfig.getAwsStsRoleArn()) + .withStsExternalId(awsConfig.getAwsStsExternalId()) + .withStsHeaderOverrides(awsConfig.getAwsStsHeaderOverrides()) + .build(); + } + + @Override + public boolean isReady() { + return sinkInitialized; + } + + @Override + public void doInitialize() { + sinkInitialized = true; + } + + /** + * @param records Records to be output + */ + @Override + public void doOutput(final Collection> records) { + sqsSinkService.output(records); + } +} + diff --git a/data-prepper-plugins/sqs-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSinkBatch.java b/data-prepper-plugins/sqs-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSinkBatch.java new file mode 100644 index 0000000000..a8e7a5a335 --- /dev/null +++ b/data-prepper-plugins/sqs-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSinkBatch.java @@ -0,0 +1,233 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.dataprepper.plugins.sink.sqs; + +import org.opensearch.dataprepper.model.event.Event; +import software.amazon.awssdk.services.sqs.model.BatchResultErrorEntry; +import org.opensearch.dataprepper.model.codec.OutputCodec; +import org.opensearch.dataprepper.model.sink.OutputCodecContext; +import org.opensearch.dataprepper.plugins.accumulator.BufferFactory; + +import software.amazon.awssdk.services.sqs.model.SqsException; +import software.amazon.awssdk.services.sqs.model.RequestThrottledException; +import software.amazon.awssdk.services.sqs.model.SendMessageBatchRequest; +import software.amazon.awssdk.services.sqs.model.SendMessageBatchResponse; +import software.amazon.awssdk.services.sqs.model.SendMessageBatchRequestEntry; +import software.amazon.awssdk.services.sqs.SqsClient; + + +import java.time.Instant; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import java.util.function.BiConsumer; + +public class SqsSinkBatch { + public static final int MAX_MESSAGES_PER_BATCH = 10; + public static final int MAX_BATCH_SIZE_BYTES = 256*1024; + private static final String SQS_FIFO_SUFFIX = ".fifo"; + private long lastFlushedTime; + private Map entries; + private boolean flushReady; + private boolean fifoQueue; + private String queueUrl; + private final long maxMessageSize; + private final int maxEvents; + private final OutputCodecContext codecContext; + private final OutputCodec codec; + private final SqsClient sqsClient; + private final BufferFactory bufferFactory; + private final SqsSinkMetrics sinkMetrics; + private SqsSinkBatchEntry currentBatchEntry; + + public SqsSinkBatch(final BufferFactory bufferFactory, + final SqsClient sqsClient, + final SqsSinkMetrics sinkMetrics, + final String queueUrl, + final OutputCodec codec, + final OutputCodecContext codecContext, + final long maxMessageSize, + final int maxEvents) { + this.maxMessageSize = maxMessageSize; + this.bufferFactory = bufferFactory; + this.maxEvents = maxEvents; + this.codec = codec; + this.sinkMetrics = sinkMetrics; + this.codecContext = codecContext; + this.queueUrl = queueUrl; + this.sqsClient = sqsClient; + lastFlushedTime = Instant.now().getEpochSecond(); + flushReady = false; + fifoQueue = queueUrl.endsWith(SQS_FIFO_SUFFIX); + entries = new HashMap<>(); + currentBatchEntry = null; + } + + public String getQueueUrl() { + return queueUrl; + } + + private boolean isFull() { + return entries.size() == MAX_MESSAGES_PER_BATCH && (currentBatchEntry.getEventCount() == maxEvents || currentBatchEntry.getSize() == maxMessageSize); + } + + public boolean willExceedLimits(long estimatedSize) { + if (getCurrentBatchSize() + estimatedSize > MAX_BATCH_SIZE_BYTES) { + return true; + } + if (currentBatchEntry != null) { + if (currentBatchEntry.getEventCount() < maxEvents && + currentBatchEntry.getSize() + estimatedSize <= maxMessageSize) { + return false; + } + } + if (entries.size() == MAX_MESSAGES_PER_BATCH) { + return true; + } + return false; + } + + public boolean addEntry(final Event event, String groupId, String deDupId, final long estimatedSize) throws Exception { + if (currentBatchEntry != null) { + if (currentBatchEntry.getEventCount() < maxEvents && + currentBatchEntry.getSize() + estimatedSize < maxMessageSize) { + currentBatchEntry.addEvent(event); + return isFull(); + } else { + currentBatchEntry.complete(); + } + } + if (entries.size() == MAX_MESSAGES_PER_BATCH) { + throw new RuntimeException("Exceeds max messages per batch"); + } + if (groupId == null) { + groupId = UUID.randomUUID().toString(); + } + if (deDupId == null) { + deDupId = UUID.randomUUID().toString(); + } + currentBatchEntry = new SqsSinkBatchEntry(bufferFactory.getBuffer(), groupId, deDupId, codec, codecContext); + + currentBatchEntry.addEvent(event); + final String id = UUID.randomUUID().toString(); + entries.put(id, currentBatchEntry); + return isFull(); + } + + public long getLastFlushedTime() { + return lastFlushedTime; + } + + public long getCurrentBatchSize() { + long sum = 0; + for (Map.Entry entry : entries.entrySet()) { + sum += entry.getValue().getSize(); + } + return sum; + } + + public int getEventCount() { + return entries.values().stream().mapToInt(SqsSinkBatchEntry::getEventCount).sum(); + } + + public void setFlushReady() throws Exception { + for (Map.Entry entry: entries.entrySet()) { + entry.getValue().complete(); + } + flushReady = true; + } + + public boolean isReady() { + return flushReady; + } + + private SendMessageBatchRequestEntry getRequestEntry(final String id, final SqsSinkBatchEntry entry) { + SendMessageBatchRequestEntry.Builder builder = SendMessageBatchRequestEntry.builder() + .id(id) + .messageBody(entry.getBody()); + if (fifoQueue) { + builder = builder + .messageGroupId(entry.getGroupId()) + .messageDeduplicationId(entry.getDedupId()); + } + return builder.build(); + } + + private boolean isRetryableException(SqsException e) { + return (e instanceof RequestThrottledException); + } + + public boolean flushOnce(final BiConsumer addToDLQList) { + if (!isReady()) { + return true; + } + SendMessageBatchResponse flushResponse; + List requestEntries = new ArrayList<>(); + for (Map.Entry groupEntry: entries.entrySet()) { + final String id = groupEntry.getKey(); + final SqsSinkBatchEntry entry = groupEntry.getValue(); + requestEntries.add(getRequestEntry(id, entry)); + } + SendMessageBatchRequest batchRequest = + SendMessageBatchRequest.builder() + .queueUrl(queueUrl) + .entries(requestEntries) + .build(); + try { + flushResponse = sqsClient.sendMessageBatch(batchRequest); + } catch (SqsException e) { + sinkMetrics.incrementRequestsFailedCounter(1); + sinkMetrics.incrementEventsFailedCounter(entries.size()); + if (!isRetryableException(e)) { + for (Map.Entry entry: entries.entrySet()) { + addToDLQList.accept(entry.getValue(), e.getMessage()); + } + entries.clear(); + flushResponse = null; + return true; + } + return false; + } + sinkMetrics.incrementRequestsSuccessCounter(1); + + boolean flushResult = false; + if (!flushResponse.hasFailed()) { + for (SendMessageBatchRequestEntry entry: requestEntries) { + SqsSinkBatchEntry batchEntry = entries.get(entry.id()); + batchEntry.releaseEventHandles(true); + sinkMetrics.incrementEventsSuccessCounter(batchEntry.getEventCount()); + } + entries.clear(); + } else { + Map newEntries = new HashMap<>(); + sinkMetrics.incrementEventsFailedCounter(flushResponse.failed().size()); + for (BatchResultErrorEntry errorEntry : flushResponse.failed()) { + SqsSinkBatchEntry batchEntry = entries.get(errorEntry.id()); + if (!errorEntry.senderFault()) { + newEntries.put(errorEntry.id(), batchEntry); + } else { + addToDLQList.accept(batchEntry, errorEntry.message()); + } + entries.remove(errorEntry.id()); + } + sinkMetrics.incrementEventsSuccessCounter(entries.size()); + for (Map.Entry entry: entries.entrySet()) { + entry.getValue().releaseEventHandles(true); + } + entries.clear(); + entries = newEntries; + } + lastFlushedTime = Instant.now().getEpochSecond(); + return entries.size() == 0; + } + + public Map getEntries() { + return entries; + } +} + diff --git a/data-prepper-plugins/sqs-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSinkBatchEntry.java b/data-prepper-plugins/sqs-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSinkBatchEntry.java new file mode 100644 index 0000000000..8fdee5db33 --- /dev/null +++ b/data-prepper-plugins/sqs-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSinkBatchEntry.java @@ -0,0 +1,94 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.dataprepper.plugins.sink.sqs; + +import org.opensearch.dataprepper.model.event.EventHandle; +import org.opensearch.dataprepper.model.event.Event; +import org.opensearch.dataprepper.model.codec.OutputCodec; +import org.opensearch.dataprepper.model.sink.OutputCodecContext; +import org.opensearch.dataprepper.plugins.accumulator.Buffer; + +import java.util.ArrayList; +import java.util.List; + +public class SqsSinkBatchEntry { + private final List eventHandles; + private final String groupId; + private final String deDupId; + private final Buffer buffer; + private final OutputCodec codec; + private final OutputCodecContext codecContext; + private OutputCodec.Writer writer; + private int eventCount; + private int size; + private boolean completed; + + public SqsSinkBatchEntry(final Buffer buffer, final String groupId, final String deDupId, final OutputCodec codec, final OutputCodecContext codecContext) { + this.eventHandles = new ArrayList<>(); + this.buffer = buffer; + completed = false; + this.groupId = groupId; + this.deDupId = deDupId; + this.codec = codec; + this.codecContext = codecContext; + + this.eventCount = 0; + size = 0; + } + + public String getBody() { + return buffer.getOutputStream().toString(); + } + + public void releaseEventHandles(boolean result) { + for (EventHandle eventHandle: eventHandles) { + eventHandle.release(result); + } + } + + public void addEvent(final Event event) throws Exception { + if (completed) { + throw new RuntimeException("Batch is completed"); + } + if (eventCount == 0) { + writer = codec.createWriter(buffer.getOutputStream(), null, codecContext); + } + writer.writeEvent(event); + eventHandles.add(event.getEventHandle()); + eventCount++; + } + + public long getSize() { + return buffer.getSize(); + } + + public void complete() throws Exception { + if (completed) { + return; + } + writer.complete(); + completed = true; + } + + + public String getGroupId() { + return groupId; + } + + public String getDedupId() { + return deDupId; + } + + public List getEventHandles() { + return eventHandles; + } + + public int getEventCount() { + return eventCount; + } + +} + diff --git a/data-prepper-plugins/sqs-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSinkConfig.java b/data-prepper-plugins/sqs-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSinkConfig.java new file mode 100644 index 0000000000..7c013c37e1 --- /dev/null +++ b/data-prepper-plugins/sqs-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSinkConfig.java @@ -0,0 +1,77 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.dataprepper.plugins.sink.sqs; + +import org.opensearch.dataprepper.aws.api.AwsConfig; +import org.opensearch.dataprepper.model.configuration.PluginModel; + +import com.fasterxml.jackson.annotation.JsonProperty; +import jakarta.validation.constraints.AssertTrue; +import jakarta.validation.Valid; +import jakarta.validation.constraints.NotNull; +import lombok.Getter; + +@Getter +public class SqsSinkConfig { + public static int DEFAULT_MAX_RETRIES = 3; + + @JsonProperty("aws") + @NotNull + @Valid + private AwsConfig awsConfig; + + @JsonProperty("queue_url") + @NotNull + private String queueUrl; + + @JsonProperty("codec") + private PluginModel codec; + + @JsonProperty("threshold") + private SqsThresholdConfig thresholdConfig = new SqsThresholdConfig(); + + @JsonProperty("max_retries") + private int maxRetries = DEFAULT_MAX_RETRIES; + + @JsonProperty("group_id") + private String groupId; + + @JsonProperty("deduplication_id") + private String deDuplicationId; + + @JsonProperty("dlq") + private PluginModel dlq; + + @AssertTrue(message = "FIFO queues wth dynamic group id or dynamic deduplication id and more than one events per message is not valid OR standard queues do not support groupId or deduplication configuration") + boolean isValidConfig() { + String deDupId = getDeDuplicationId(); + String groupId = getGroupId(); + String queueUrl = getQueueUrl(); + boolean isDynamicDeDupId = deDupId != null && deDupId.contains("${"); + boolean isDynamicGroupId = groupId != null && groupId.contains("${"); + boolean isDynamicQueueUrl = queueUrl != null && queueUrl.contains("${"); + if (isDynamicQueueUrl) { + return true; + } + if (getQueueUrl().endsWith(".fifo")) { + if ((isDynamicGroupId || isDynamicDeDupId) && thresholdConfig.getMaxEventsPerMessage() > 1) { + return false; + } else{ + return true; + } + } else { + return (groupId == null && deDupId == null); + } + } + + @AssertTrue(message = "ndjson codec (default codec) doesn't support max events per message greater than 1") + boolean isValidCodecConfig() { + if ((codec == null || codec.getPluginName().equals("ndjson")) && thresholdConfig.getMaxEventsPerMessage() > 1) + return false; + return true; + } +} + diff --git a/data-prepper-plugins/sqs-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSinkDlqData.java b/data-prepper-plugins/sqs-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSinkDlqData.java new file mode 100644 index 0000000000..6cd8f0bc66 --- /dev/null +++ b/data-prepper-plugins/sqs-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSinkDlqData.java @@ -0,0 +1,85 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.dataprepper.plugins.sink.sqs; + +import java.util.Objects; + +public class SqsSinkDlqData { + private final String message; + private final Object data; + + private SqsSinkDlqData(final String message, final Object data) { + Objects.requireNonNull(message); + this.message = message; + Objects.requireNonNull(data); + this.data = data; + } + + public String getMessage() { + return message; + } + + public Object getData() { + return data; + } + + @Override + public boolean equals(final Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + final SqsSinkDlqData that = (SqsSinkDlqData) o; + return Objects.equals(message, that.message) && + Objects.equals(data, that.data); + } + + @Override + public int hashCode() { + return Objects.hash(message, data); + } + + @Override + public String toString() { + return "SqsSinkDlqData{" + + ", message='" + message + '\'' + + ", data=" + data + + '}'; + } + + public static SqsSinkDlqData createDlqData(final Object data, final String failureMessage) { + return SqsSinkDlqData.builder() + .withData(data) + .withMessage(failureMessage) + .build(); + } + + public static SqsSinkDlqData.Builder builder() { + return new SqsSinkDlqData.Builder(); + } + + public static class Builder { + + private String message; + private Object data; + + public Builder withMessage(final String message) { + this.message = message; + return this; + } + + public Builder withData(final Object data) { + this.data = data; + return this; + } + + public SqsSinkDlqData build() { + return new SqsSinkDlqData(message, data); + } + } +} diff --git a/data-prepper-plugins/sqs-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSinkExecutor.java b/data-prepper-plugins/sqs-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSinkExecutor.java new file mode 100644 index 0000000000..ced091c79b --- /dev/null +++ b/data-prepper-plugins/sqs-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSinkExecutor.java @@ -0,0 +1,104 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.dataprepper.plugins.sink.sqs; + +import org.opensearch.dataprepper.model.event.Event; +import org.opensearch.dataprepper.model.record.Record; +import com.linecorp.armeria.client.retry.Backoff; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.time.Duration; +import java.util.Collection; + +import static org.opensearch.dataprepper.logging.DataPrepperMarkers.NOISY; + +public abstract class SqsSinkExecutor { + private static final Logger LOG = LoggerFactory.getLogger(SqsSinkExecutor.class); + private static final long INITIAL_DELAY_MS = 10; + private static final long MAXIMUM_DELAY_MS = Duration.ofMinutes(10).toMillis(); + + + public void execute(Collection> records) { + if (records.isEmpty()) { + lock(); + try { + if (exceedsFlushTimeInterval()) { + flushBuffer(); + } + } finally { + unlock(); + } + return; + } + lock(); + try { + for (Record record : records) { + final Event event = record.getData(); + try { + long estimatedSize = getEstimatedSize(event); + if (exceedsMaxEventSizeThreshold(estimatedSize)) { + throw new RuntimeException("Event size exceeds max allowed event size"); + } + if (willExceedMaxBatchSize(event, estimatedSize)) { + flushBuffer(); + } + boolean reachedMaxEventsLimit = addToBuffer(event, estimatedSize); + if (reachedMaxEventsLimit) { + flushBuffer(); + } + } catch (Exception ex) { + addEventToDLQList(event, ex); + } + } + pushDLQList(); + } finally { + unlock(); + } + } + + public void flushBuffer() { + int retryCount = 1; + Object failedStatus = null; + int maxRetries = getMaxRetries(); + final Backoff backoff = Backoff.exponential(INITIAL_DELAY_MS, MAXIMUM_DELAY_MS).withMaxAttempts(maxRetries); + while (retryCount <= maxRetries) { + failedStatus = doFlushOnce(failedStatus); + if (failedStatus != null) { + final long delayMillis = backoff.nextDelayMillis(retryCount); + if (delayMillis < 0) { + break; + } + try { + Thread.sleep(delayMillis); + } catch (final InterruptedException e){ + LOG.error(NOISY, "Thread is interrupted while attempting to SQS with retry.", e); + } + } + retryCount++; + } + if (failedStatus != null) { + pushFailedObjectsToDlq(failedStatus); + } + } + + public abstract void pushFailedObjectsToDlq(Object failedStatus); + public abstract void pushDLQList(); + public abstract void addEventToDLQList(final Event event, Throwable ex); + public abstract Object doFlushOnce(Object failedStatus); + public abstract int getMaxRetries(); + public abstract boolean addToBuffer(final Event event, final long estimatedSize) throws Exception; + public abstract boolean exceedsFlushTimeInterval(); + public abstract boolean willExceedMaxBatchSize(final Event event, final long estimatedSize) throws Exception; + public abstract boolean exceedsMaxEventSizeThreshold(final long estimatedSize); + public abstract long getEstimatedSize(final Event event) throws Exception; + + public abstract void lock(); + public abstract void unlock(); + +} + diff --git a/data-prepper-plugins/sqs-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSinkMetrics.java b/data-prepper-plugins/sqs-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSinkMetrics.java new file mode 100644 index 0000000000..7cf5a9a752 --- /dev/null +++ b/data-prepper-plugins/sqs-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSinkMetrics.java @@ -0,0 +1,43 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.dataprepper.plugins.sink.sqs; + +import io.micrometer.core.instrument.Counter; +import org.opensearch.dataprepper.metrics.PluginMetrics; + +public class SqsSinkMetrics { + public static final String SQS_SINK_REQUESTS_SUCCEEDED = "sqsSinkRequestsSucceeded"; + public static final String SQS_SINK_EVENTS_SUCCEEDED = "sqsSinkEventsSucceeded"; + public static final String SQS_SINK_EVENTS_FAILED = "sqsSinkEventsFailed"; + public static final String SQS_SINK_REQUESTS_FAILED = "sqsSinkRequestsFailed"; + private final Counter sqsSinkRequestsSucceeded; + private final Counter sqsSinkEventsSucceeded; + private final Counter sqsSinkRequestsFailed; + private final Counter sqsSinkEventsFailed; + + public SqsSinkMetrics(final PluginMetrics pluginMetrics) { + this.sqsSinkRequestsSucceeded = pluginMetrics.counter(SQS_SINK_REQUESTS_SUCCEEDED); + this.sqsSinkEventsSucceeded = pluginMetrics.counter(SQS_SINK_EVENTS_SUCCEEDED); + this.sqsSinkRequestsFailed = pluginMetrics.counter(SQS_SINK_REQUESTS_FAILED); + this.sqsSinkEventsFailed = pluginMetrics.counter(SQS_SINK_EVENTS_FAILED); + } + + public void incrementEventsSuccessCounter(int value) { + sqsSinkEventsSucceeded.increment(value); + } + + public void incrementRequestsSuccessCounter(int value) { + sqsSinkRequestsSucceeded.increment(value); + } + + public void incrementEventsFailedCounter(int value) { + sqsSinkEventsFailed.increment(value); + } + + public void incrementRequestsFailedCounter(int value) { + sqsSinkRequestsFailed.increment(value); + } +} diff --git a/data-prepper-plugins/sqs-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSinkService.java b/data-prepper-plugins/sqs-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSinkService.java new file mode 100644 index 0000000000..c3d027a585 --- /dev/null +++ b/data-prepper-plugins/sqs-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSinkService.java @@ -0,0 +1,357 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.dataprepper.plugins.sink.sqs; + +import org.opensearch.dataprepper.metrics.PluginMetrics; +import org.opensearch.dataprepper.expression.ExpressionEvaluator; +import org.opensearch.dataprepper.model.codec.OutputCodec; +import org.opensearch.dataprepper.model.sink.OutputCodecContext; +import org.opensearch.dataprepper.plugins.accumulator.BufferFactory; +import org.opensearch.dataprepper.plugins.accumulator.InMemoryBufferFactory; +import org.opensearch.dataprepper.model.sink.SinkContext; +import software.amazon.awssdk.services.sqs.SqsClient; +import org.opensearch.dataprepper.model.event.Event; +import org.opensearch.dataprepper.model.event.EventHandle; +import org.opensearch.dataprepper.model.record.Record; +import org.opensearch.dataprepper.plugins.dlq.DlqPushHandler; +import org.opensearch.dataprepper.model.failures.DlqObject; + + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.time.Instant; +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.concurrent.locks.ReentrantLock; + +import static org.opensearch.dataprepper.logging.DataPrepperMarkers.NOISY; + +public class SqsSinkService extends SqsSinkExecutor { + private static final Logger LOG = LoggerFactory.getLogger(SqsSinkService.class); + public static final int MAX_BYTES_IN_BATCH = 256*1024; + public static final int MAX_EVENT_SIZE = 256*1024; + + private final Map batchUrlMap; + private final String queueUrl; + private final String groupId; + private final String deDupId; + private final SqsClient sqsClient; + private final boolean isDynamicGroupId; + private final boolean isDynamicDeDupId; + private final boolean isDynamicQueueUrl; + private final ExpressionEvaluator expressionEvaluator; + private final ReentrantLock reentrantLock; + private final SqsThresholdConfig thresholdConfig; + private final SqsSinkConfig sqsSinkConfig; + private final SinkContext sinkContext; + private final OutputCodec codec; + private final BufferFactory inMemoryBufferFactory; + private final SqsSinkMetrics sinkMetrics; + private final DlqPushHandler dlqPushHandler; + private final List dlqObjects; + + public SqsSinkService(final SqsSinkConfig sqsSinkConfig, + final SqsClient sqsClient, + final ExpressionEvaluator expressionEvaluator, + final OutputCodec codec, + final SinkContext sinkContext, + final DlqPushHandler dlqPushHandler, + final PluginMetrics pluginMetrics) { + batchUrlMap = new HashMap<>(); + dlqObjects = new ArrayList<>(); + inMemoryBufferFactory =new InMemoryBufferFactory(); + this.sqsClient = sqsClient; + this.dlqPushHandler = dlqPushHandler; + this.sinkContext = sinkContext; + this.expressionEvaluator = expressionEvaluator; + this.thresholdConfig = sqsSinkConfig.getThresholdConfig(); + this.codec = codec; + this.sqsSinkConfig = sqsSinkConfig; + reentrantLock = new ReentrantLock(); + this.sinkMetrics = new SqsSinkMetrics(pluginMetrics); + + queueUrl = sqsSinkConfig.getQueueUrl(); + isDynamicQueueUrl = queueUrl.contains("${"); + if (isDynamicQueueUrl) { + if (!expressionEvaluator.isValidFormatExpression(queueUrl)) { + throw new IllegalArgumentException("Invalid queue url expression"); + } + } + + groupId = sqsSinkConfig.getGroupId(); + isDynamicGroupId = groupId != null && groupId.contains("${"); + if (isDynamicGroupId) { + if (!expressionEvaluator.isValidFormatExpression(groupId)) { + throw new IllegalArgumentException("Invalid groupId expression"); + } + } + + deDupId = sqsSinkConfig.getDeDuplicationId(); + isDynamicDeDupId = deDupId != null && deDupId.contains("${"); + if (isDynamicDeDupId) { + if (!expressionEvaluator.isValidFormatExpression(deDupId)) { + throw new IllegalArgumentException("Invalid deduplicationId expression"); + } + } + + } + + @Override + public boolean exceedsMaxEventSizeThreshold(final long estimatedSize) { + return estimatedSize > MAX_EVENT_SIZE; + } + + @Override + public void pushDLQList() { + if (dlqObjects.size() == 0) { + return; + } + boolean result = false; + if (dlqPushHandler != null) { + result = dlqPushHandler.perform(dlqObjects); + } + for (final DlqObject dlqObject : dlqObjects) { + dlqObject.releaseEventHandles(result); + } + dlqObjects.clear(); + } + + @Override + public void pushFailedObjectsToDlq(Object object) { + List failedBatches = (List) object; + for (SqsSinkBatch failedBatch: failedBatches) { + for (Map.Entry entry: failedBatch.getEntries().entrySet()) { + addBatchEntryToDLQ(entry.getValue(), "Failed to write to sink after maxRetries"); + } + batchUrlMap.remove(failedBatch.getQueueUrl()); + } + } + + @Override + public long getEstimatedSize(final Event event) throws Exception { + return codec.getEstimatedSize(event, new OutputCodecContext()); + } + + @Override + public boolean willExceedMaxBatchSize(final Event event, final long estimatedSize) throws Exception { + String qUrl = getQueueUrl(event, false); + if (qUrl == null) + return false; + SqsSinkBatch batch = batchUrlMap.get(qUrl); + if (batch == null) + return false; + boolean result = batch.willExceedLimits(estimatedSize); + if (result) { + setFlushReady(qUrl, batch); + } + return result; + } + + private boolean doFlushBatch(SqsSinkBatch batch) { + boolean flushSuccess = batch.flushOnce( + (batchEntry, exceptionMessage ) -> { + addBatchEntryToDLQ(batchEntry, exceptionMessage); + }); + + // Sending to DLQ is also considered success (because no + // retry needed) + return flushSuccess; + } + + @Override + public Object doFlushOnce(Object previousFailedBatches) { + List failedBatches = new ArrayList<>(); + List successQueueUrls = new ArrayList<>(); + if (previousFailedBatches != null) { + List pFailedBatches = (List) previousFailedBatches; + for (SqsSinkBatch failedBatch: pFailedBatches) { + if (!doFlushBatch(failedBatch)) { + failedBatches.add(failedBatch); + } else { + successQueueUrls.add(failedBatch.getQueueUrl()); + } + } + } else { + Iterator> iterator = batchUrlMap.entrySet().iterator(); + while (iterator.hasNext()) { + Map.Entry qUrlEntry = iterator.next(); + SqsSinkBatch batch = qUrlEntry.getValue(); + if (batch.isReady()) { + if (!doFlushBatch(batch)) { + failedBatches.add(batch); + } else { + successQueueUrls.add(batch.getQueueUrl()); + } + } + } + } + for (final String qUrl : successQueueUrls) { + batchUrlMap.remove(qUrl); + } + if (failedBatches.size() == 0) + return null; + return failedBatches; + } + + private String getQueueUrl(final Event event, boolean logError) { + String qUrl = queueUrl; + if (isDynamicQueueUrl) { + try { + Object obj = event.formatString(queueUrl, expressionEvaluator); + if (obj instanceof String) { + qUrl = (String) obj; + } else { + throw new RuntimeException("Evaluated queue url is not a string"); + } + } catch (Exception e) { + qUrl = null; + if (logError) { + LOG.error(NOISY, "Invalid queueURL expression {} ", e.getMessage()); + addEventToDLQList(event, e); + } + } + } + return qUrl; + } + + private String getGroupId(final Event event) { + String gId = groupId; + if (isDynamicGroupId) { + try { + Object obj = event.formatString(groupId, expressionEvaluator); + if (obj instanceof String) { + gId = (String) obj; + } else { + throw new RuntimeException("Evaluated group id is not a string"); + } + } catch (Exception e) { + LOG.error(NOISY, "Invalid groupId expression {}, using random groupId ", e.getMessage()); + } + } + return gId; + } + + private String getDeDupId(final Event event) { + String ddId = deDupId; + if (isDynamicDeDupId) { + try { + Object obj = event.formatString(deDupId, expressionEvaluator); + if (obj instanceof String) { + ddId = (String) obj; + } else { + throw new RuntimeException("Evaluated deduplicate id is not a string"); + } + } catch (Exception e) { + LOG.error(NOISY, "Invalid deDupId expression {}, using random deDupId ", e.getMessage()); + } + } + return ddId; + } + + + @Override + public int getMaxRetries() { + return sqsSinkConfig.getMaxRetries(); + } + + Map getBatchUrlMap() { + return batchUrlMap; + } + + @Override + public boolean addToBuffer(final Event event, final long estimatedSize) throws Exception { + String qUrl = getQueueUrl(event, true); + if (qUrl == null) { + return false; + } + SqsSinkBatch batch = batchUrlMap.get(qUrl); + if (batch == null) { + final OutputCodecContext codecContext = OutputCodecContext.fromSinkContext(sinkContext); + batch = new SqsSinkBatch(inMemoryBufferFactory, sqsClient, sinkMetrics, qUrl, codec, codecContext, thresholdConfig.getMaxMessageSizeBytes(), thresholdConfig.getMaxEventsPerMessage()); + + batchUrlMap.put(qUrl, batch); + } + String gId = getGroupId(event); + String ddId = getDeDupId(event); + boolean isFull = batch.addEntry(event, gId, ddId, estimatedSize); + if (isFull) { + setFlushReady(qUrl, batch); + } + return isFull; + } + + private boolean setFlushReady(final String queueUrl, final SqsSinkBatch batch) { + try { + batch.setFlushReady(); + return true; + } catch (Exception e) { + for (Map.Entry entry: batch.getEntries().entrySet()) { + addBatchEntryToDLQ(entry.getValue(), "Failed to setFlushReady for the batch"); + } + batchUrlMap.remove(queueUrl); + return false; + } + } + + @Override + public boolean exceedsFlushTimeInterval() { + long now = Instant.now().getEpochSecond(); + boolean result = false; + + Iterator> iterator = batchUrlMap.entrySet().iterator(); + while (iterator.hasNext()) { + Map.Entry qUrlEntry = iterator.next(); + String qUrl = qUrlEntry.getKey(); + SqsSinkBatch batch = qUrlEntry.getValue(); + if (now - batch.getLastFlushedTime() > thresholdConfig.getFlushInterval()) { + result = result || setFlushReady(qUrl, batch); + } + } + return result; + } + + private void addBatchEntryToDLQ(final SqsSinkBatchEntry batchEntry, final String errorMessage) { + addMessageToDLQ(batchEntry.getBody(), batchEntry.getEventHandles(), errorMessage); + } + + private void addMessageToDLQ(final String message, final List eventHandles, final String errorMessage) { + if (dlqPushHandler != null) { + SqsSinkDlqData sqsSinkDlqData = SqsSinkDlqData.createDlqData(message, errorMessage); + DlqObject dlqObject = DlqObject.createDlqObject(dlqPushHandler.getPluginSetting(), eventHandles, sqsSinkDlqData); + dlqObjects.add(dlqObject); + } else { + for (final EventHandle handle: eventHandles) { + handle.release(false); + } + } + } + + @Override + public void addEventToDLQList(final Event event, Throwable ex) { + List eventHandles = new ArrayList<>(); + eventHandles.add(event.getEventHandle()); + addMessageToDLQ(event.toJsonString(), eventHandles, ex.getMessage()); + } + + @Override + public void lock() { + reentrantLock.lock(); + } + + @Override + public void unlock() { + reentrantLock.unlock(); + } + + void output(Collection> records) { + execute(records); + } +} diff --git a/data-prepper-plugins/sqs-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsThresholdConfig.java b/data-prepper-plugins/sqs-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsThresholdConfig.java new file mode 100644 index 0000000000..04c9d80a75 --- /dev/null +++ b/data-prepper-plugins/sqs-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsThresholdConfig.java @@ -0,0 +1,46 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.dataprepper.plugins.sink.sqs; + +import com.fasterxml.jackson.annotation.JsonProperty; +import jakarta.validation.constraints.Size; +import org.hibernate.validator.constraints.time.DurationMax; +import org.hibernate.validator.constraints.time.DurationMin; +import org.opensearch.dataprepper.model.types.ByteCount; + +import java.time.Duration; + +public class SqsThresholdConfig { + public static final int DEFAULT_MESSAGES_PER_EVENT = 25; + public static final ByteCount DEFAULT_MAX_MESSAGE_SIZE = ByteCount.parse("256kb"); + public static final long DEFAULT_FLUSH_INTERVAL_TIME = 30; + + @JsonProperty("max_events_per_message") + @Size(min = 1, max = 1000, message = "batch_size amount should be between 1 to 1000") + private int maxEventsPerMessage = DEFAULT_MESSAGES_PER_EVENT; + + @JsonProperty("max_message_size") + private ByteCount maxMessageSize = DEFAULT_MAX_MESSAGE_SIZE; + + @JsonProperty("flush_interval") + @DurationMin(seconds = 60) + @DurationMax(seconds = 3600) + private Duration flushInterval = Duration.ofSeconds(DEFAULT_FLUSH_INTERVAL_TIME); + + public long getMaxMessageSizeBytes() { + return maxMessageSize.getBytes(); + } + + public int getMaxEventsPerMessage() { + return maxEventsPerMessage; + } + + public long getFlushInterval() { + return flushInterval.getSeconds(); + } + +} + diff --git a/data-prepper-plugins/sqs-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSinkBatchEntryTest.java b/data-prepper-plugins/sqs-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSinkBatchEntryTest.java new file mode 100644 index 0000000000..1bd083cec1 --- /dev/null +++ b/data-prepper-plugins/sqs-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSinkBatchEntryTest.java @@ -0,0 +1,152 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.dataprepper.plugins.sink.sqs; + +import com.fasterxml.jackson.databind.ObjectMapper; +import org.opensearch.dataprepper.model.codec.OutputCodec; +import org.opensearch.dataprepper.model.sink.OutputCodecContext; +import org.opensearch.dataprepper.plugins.accumulator.Buffer; +import org.opensearch.dataprepper.plugins.accumulator.InMemoryBufferFactory; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import org.opensearch.dataprepper.model.event.Event; +import org.opensearch.dataprepper.model.log.JacksonLog; +import org.opensearch.dataprepper.model.record.Record; +import org.opensearch.dataprepper.plugins.codec.json.JsonOutputCodec; +import org.opensearch.dataprepper.plugins.codec.json.JsonOutputCodecConfig; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.UUID; + +public class SqsSinkBatchEntryTest { + private Buffer buffer; + + private OutputCodec outputCodec; + + private OutputCodecContext outputCodecContext; + + private String groupId; + private String deDupId; + private ObjectMapper objectMapper; + + private SqsSinkBatchEntry createObjectUnderTest() { + return new SqsSinkBatchEntry(buffer, groupId, deDupId, outputCodec, outputCodecContext); + } + + @BeforeEach + void setup() { + objectMapper = new ObjectMapper(); + InMemoryBufferFactory inMemoryBufferFactory = new InMemoryBufferFactory(); + buffer = inMemoryBufferFactory.getBuffer(); + outputCodec = new JsonOutputCodec(new JsonOutputCodecConfig()); + outputCodecContext = new OutputCodecContext(); + groupId = UUID.randomUUID().toString(); + deDupId = UUID.randomUUID().toString(); + } + + @Test + void TestBasic() { + SqsSinkBatchEntry sqsSinkBatchEntry = createObjectUnderTest(); + assertThat(sqsSinkBatchEntry.getBody().length(), equalTo(0)); + assertThat(sqsSinkBatchEntry.getSize(), equalTo(0L)); + assertThat(sqsSinkBatchEntry.getEventCount(), equalTo(0)); + assertThat(sqsSinkBatchEntry.getGroupId(), equalTo(groupId)); + assertThat(sqsSinkBatchEntry.getDedupId(), equalTo(deDupId)); + assertTrue(sqsSinkBatchEntry.getEventHandles().isEmpty()); + } + + @Test + void TestAddingOneEvent() throws Exception { + SqsSinkBatchEntry sqsSinkBatchEntry = createObjectUnderTest(); + List> records = getRecordList(1); + Event event = records.get(0).getData(); + sqsSinkBatchEntry.addEvent(event); + sqsSinkBatchEntry.complete(); + String expectedBody = "{\"events\":["+event.toJsonString()+"]}"; + assertThat(sqsSinkBatchEntry.getEventCount(), equalTo(1)); + assertThat(sqsSinkBatchEntry.getBody(), equalTo(expectedBody)); + assertThat(sqsSinkBatchEntry.getSize(), equalTo((long)expectedBody.length())); + assertThat(sqsSinkBatchEntry.getGroupId(), equalTo(groupId)); + assertThat(sqsSinkBatchEntry.getDedupId(), equalTo(deDupId)); + assertThat(sqsSinkBatchEntry.getEventHandles().size(), equalTo(1)); + } + + + @ParameterizedTest + @ValueSource(ints = {10, 25, 57, 73}) + void TestAddingMultipleEvents(int numRecords) throws Exception { + SqsSinkBatchEntry sqsSinkBatchEntry = createObjectUnderTest(); + List> records = getRecordList(numRecords); + long expectedSize = "{\"events\":[]}".length(); + + for (Record record: records) { + Event event = record.getData(); + sqsSinkBatchEntry.addEvent(event); + expectedSize += event.toJsonString().length(); + } + // account for commas + expectedSize += (records.size() - 1); + + sqsSinkBatchEntry.complete(); + final Map expectedMap = new HashMap<>(); + for (int i = 0; i < numRecords; i++) { + expectedMap.put("Person"+i, Integer.toString(i)); + } + Map body = objectMapper.readValue(sqsSinkBatchEntry.getBody(), Map.class); + List> events = (List>) body.get("events"); + assertThat(events.size(), equalTo(numRecords)); + + for (int i = 0; i < numRecords; i++) { + Map eventMap = (Map) events.get(i); + String name = (String)eventMap.get("name"); + assertTrue(expectedMap.containsKey(name)); + assertThat(expectedMap.get(name), equalTo(eventMap.get("age"))); + expectedMap.remove(name); + } + assertThat(expectedMap.size(), equalTo(0)); + assertThat(sqsSinkBatchEntry.getEventCount(), equalTo(numRecords)); + assertThat(sqsSinkBatchEntry.getGroupId(), equalTo(groupId)); + assertThat(sqsSinkBatchEntry.getDedupId(), equalTo(deDupId)); + assertThat(sqsSinkBatchEntry.getSize(), equalTo(expectedSize)); + assertThat(sqsSinkBatchEntry.getEventHandles().size(), equalTo(numRecords)); + } + + private List> getRecordList(int numberOfRecords) { + final List> recordList = new ArrayList<>(); + List records = generateRecords(numberOfRecords); + for (int i = 0; i < numberOfRecords; i++) { + final Event event = JacksonLog.builder().withData(records.get(i)).build(); + recordList.add(new Record<>(event)); + } + return recordList; + } + + private static List generateRecords(int numberOfRecords) { + + List recordList = new ArrayList<>(); + + for (int rows = 0; rows < numberOfRecords; rows++) { + + HashMap eventData = new HashMap<>(); + + eventData.put("name", "Person" + rows); + eventData.put("age", Integer.toString(rows)); + recordList.add(eventData); + + } + return recordList; + } +} diff --git a/data-prepper-plugins/sqs-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSinkBatchTest.java b/data-prepper-plugins/sqs-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSinkBatchTest.java new file mode 100644 index 0000000000..19db2d5fab --- /dev/null +++ b/data-prepper-plugins/sqs-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSinkBatchTest.java @@ -0,0 +1,293 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.dataprepper.plugins.sink.sqs; + +import com.fasterxml.jackson.databind.ObjectMapper; +import org.opensearch.dataprepper.model.codec.OutputCodec; +import org.opensearch.dataprepper.model.sink.OutputCodecContext; +import org.opensearch.dataprepper.plugins.accumulator.Buffer; +import org.opensearch.dataprepper.plugins.accumulator.InMemoryBufferFactory; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mock; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.mockito.Mockito.lenient; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.times; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.mockito.ArgumentMatchers.any; +import software.amazon.awssdk.services.sqs.model.SendMessageBatchRequest; +import software.amazon.awssdk.services.sqs.model.SendMessageBatchResponse; + +import software.amazon.awssdk.services.sqs.SqsClient; +import software.amazon.awssdk.services.sqs.model.SqsException; +import software.amazon.awssdk.services.sqs.model.BatchResultErrorEntry; +import software.amazon.awssdk.services.sqs.model.RequestThrottledException; +import org.opensearch.dataprepper.model.event.Event; +import org.opensearch.dataprepper.model.event.EventHandle; +import org.opensearch.dataprepper.model.log.JacksonLog; +import org.opensearch.dataprepper.model.record.Record; +import org.opensearch.dataprepper.plugins.codec.json.JsonOutputCodec; +import org.opensearch.dataprepper.plugins.codec.json.JsonOutputCodecConfig; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.hamcrest.Matchers.greaterThan; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.UUID; +import java.util.concurrent.atomic.AtomicInteger; + +public class SqsSinkBatchTest { + @Mock + private EventHandle eventHandle; + @Mock + private SendMessageBatchResponse flushResponse; + @Mock + private SqsClient sqsClient; + @Mock + private SqsException sqsException; + + private AtomicInteger eventsSuccessCount; + private AtomicInteger requestsSuccessCount; + private AtomicInteger eventsFailedCount; + private AtomicInteger requestsFailedCount; + private AtomicInteger dlqSuccessCount; + private Buffer buffer; + + private OutputCodec outputCodec; + private OutputCodecContext outputCodecContext; + + private String groupId; + private String deDupId; + private ObjectMapper objectMapper; + private SqsSinkMetrics sinkMetrics; + private InMemoryBufferFactory bufferFactory; + private long maxMessageSize; + private int maxEvents; + private SqsSinkBatch batch; + private String queueUrl; + + private SqsSinkBatch createObjectUnderTest() { + return new SqsSinkBatch(bufferFactory, sqsClient, sinkMetrics, queueUrl, outputCodec, outputCodecContext, maxMessageSize, maxEvents); + } + + @BeforeEach + void setup() { + eventsSuccessCount = new AtomicInteger(0); + requestsSuccessCount = new AtomicInteger(0); + eventsFailedCount = new AtomicInteger(0); + requestsFailedCount = new AtomicInteger(0); + dlqSuccessCount = new AtomicInteger(0); + objectMapper = new ObjectMapper(); + queueUrl = UUID.randomUUID().toString(); + bufferFactory = new InMemoryBufferFactory(); + sqsClient = mock(SqsClient.class); + sqsException = mock(SqsException.class); + sinkMetrics = mock(SqsSinkMetrics.class); + eventHandle = mock(EventHandle.class); + lenient().doAnswer((a)-> { + int v = (int)(a.getArgument(0)); + eventsSuccessCount.addAndGet(v); + return null; + }).when(sinkMetrics).incrementEventsSuccessCounter(any(Integer.class)); + lenient().doAnswer((a)-> { + int v = (int)(a.getArgument(0)); + eventsFailedCount.addAndGet(v); + return null; + }).when(sinkMetrics).incrementEventsFailedCounter(any(Integer.class)); + lenient().doAnswer((a)-> { + int v = (int)(a.getArgument(0)); + requestsSuccessCount.addAndGet(v); + return null; + }).when(sinkMetrics).incrementRequestsSuccessCounter(any(Integer.class)); + lenient().doAnswer((a)-> { + int v = (int)(a.getArgument(0)); + requestsFailedCount.addAndGet(v); + return null; + }).when(sinkMetrics).incrementRequestsFailedCounter(any(Integer.class)); + flushResponse = mock(SendMessageBatchResponse.class); + outputCodec = new JsonOutputCodec(new JsonOutputCodecConfig()); + outputCodecContext = new OutputCodecContext(); + maxMessageSize = 256 * 1024; + } + + @Test + void TestBasic() { + batch = createObjectUnderTest(); + assertThat(batch.getQueueUrl(), equalTo(queueUrl)); + assertThat(batch.getCurrentBatchSize(), equalTo(0L)); + assertThat(batch.getEventCount(), equalTo(0)); + assertThat(batch.getEntries().size(), equalTo(0)); + } + + @Test + void TestOneBatch_WithOneEventPerMessage_WithSuccessfulSendMessage() throws Exception { + maxEvents = 1; + batch = createObjectUnderTest(); + groupId = UUID.randomUUID().toString(); + String dedupId = UUID.randomUUID().toString(); + + when(flushResponse.hasFailed()).thenReturn(false); + when(sqsClient.sendMessageBatch(any(SendMessageBatchRequest.class))).thenReturn(flushResponse); + final int numRecords = SqsSinkBatch.MAX_MESSAGES_PER_BATCH; + List> records = getRecordList(numRecords); + long minSize = 0; + for (int i = 0; i <= numRecords; i++) { + + Event event = records.get(i%numRecords).getData(); + minSize += event.toJsonString().length(); + long eSize = outputCodec.getEstimatedSize(event, new OutputCodecContext()); + // Make sure trying to add more than max possible records results in an exception + if (i == numRecords) { + assertThrows(RuntimeException.class, () -> batch.addEntry(event, groupId, dedupId, eSize)); + } else { + boolean result = batch.addEntry(event, groupId, dedupId, eSize); + if (i < numRecords-1) { + assertFalse(result); + } else { + assertTrue(result); + } + } + } + assertThat(batch.getEntries().size(), equalTo(numRecords)); + batch.setFlushReady(); + assertTrue(batch.willExceedLimits(1L)); + assertThat(batch.getCurrentBatchSize(), greaterThan(minSize)); + assertThat(batch.getEventCount(), equalTo(SqsSinkBatch.MAX_MESSAGES_PER_BATCH)); + boolean flushResult = batch.flushOnce(null); + assertTrue(flushResult); + assertThat(eventsSuccessCount.get(), equalTo(numRecords)); + assertThat(requestsSuccessCount.get(), equalTo(1)); + verify(eventHandle, times(numRecords)).release(true); + } + + @Test + void TestOneBatch_WithOneEventPerMessage_WithFlushFailure() throws Exception { + maxEvents = 1; + batch = createObjectUnderTest(); + groupId = UUID.randomUUID().toString(); + String dedupId = UUID.randomUUID().toString(); + + SendMessageBatchResponse flushResponse = mock(SendMessageBatchResponse.class); + when(flushResponse.hasFailed()).thenReturn(true); + BatchResultErrorEntry errorEntry = mock(BatchResultErrorEntry.class); + when(errorEntry.id()).thenReturn(UUID.randomUUID().toString()); + when(errorEntry.message()).thenReturn(UUID.randomUUID().toString()); + when(errorEntry.senderFault()).thenReturn(true); + when(flushResponse.failed()).thenReturn(List.of(errorEntry)); + when(sqsClient.sendMessageBatch(any(SendMessageBatchRequest.class))).thenReturn(flushResponse); + final int numRecords = SqsSinkBatch.MAX_MESSAGES_PER_BATCH; + List> records = getRecordList(numRecords); + long minSize = 0; + for (int i = 0; i < numRecords; i++) { + + Event event = records.get(i).getData(); + minSize += event.toJsonString().length(); + long eSize = outputCodec.getEstimatedSize(event, new OutputCodecContext()); + batch.addEntry(event, groupId, dedupId, eSize); + } + assertThat(batch.getEntries().size(), equalTo(numRecords)); + batch.setFlushReady(); + assertTrue(batch.willExceedLimits(1L)); + assertThat(batch.getCurrentBatchSize(), greaterThan(minSize)); + assertThat(batch.getEventCount(), equalTo(SqsSinkBatch.MAX_MESSAGES_PER_BATCH)); + boolean flushResult = batch.flushOnce((e, m) -> {}); + // all entries sent to DLQ + assertTrue(flushResult); + } + + + @Test + void TestOneBatch_WithOneEventPerMessage_WithFlushExceptionFailure() throws Exception { + maxEvents = 1; + batch = createObjectUnderTest(); + groupId = UUID.randomUUID().toString(); + String dedupId = UUID.randomUUID().toString(); + + when(sqsClient.sendMessageBatch(any(SendMessageBatchRequest.class))).thenThrow(RequestThrottledException.builder().build()); + final int numRecords = SqsSinkBatch.MAX_MESSAGES_PER_BATCH; + List> records = getRecordList(numRecords); + long minSize = 0; + for (int i = 0; i < numRecords; i++) { + + Event event = records.get(i).getData(); + minSize += event.toJsonString().length(); + long eSize = outputCodec.getEstimatedSize(event, new OutputCodecContext()); + batch.addEntry(event, groupId, dedupId, eSize); + } + assertThat(batch.getEntries().size(), equalTo(numRecords)); + batch.setFlushReady(); + assertTrue(batch.willExceedLimits(1L)); + assertThat(batch.getCurrentBatchSize(), greaterThan(minSize)); + assertThat(batch.getEventCount(), equalTo(SqsSinkBatch.MAX_MESSAGES_PER_BATCH)); + boolean flushResult = batch.flushOnce((e, m) -> {}); + assertFalse(flushResult); + } + + @Test + void TestOneBatch_WithOneEventPerMessage_WithFlushException() throws Exception { + maxEvents = 1; + batch = createObjectUnderTest(); + groupId = UUID.randomUUID().toString(); + String dedupId = UUID.randomUUID().toString(); + + when(sqsClient.sendMessageBatch(any(SendMessageBatchRequest.class))).thenThrow(sqsException); + final int numRecords = SqsSinkBatch.MAX_MESSAGES_PER_BATCH; + List> records = getRecordList(numRecords); + long minSize = 0; + for (int i = 0; i < numRecords; i++) { + + Event event = records.get(i).getData(); + minSize += event.toJsonString().length(); + long eSize = outputCodec.getEstimatedSize(event, new OutputCodecContext()); + batch.addEntry(event, groupId, dedupId, eSize); + } + assertThat(batch.getEntries().size(), equalTo(numRecords)); + batch.setFlushReady(); + assertTrue(batch.willExceedLimits(1L)); + assertThat(batch.getCurrentBatchSize(), greaterThan(minSize)); + assertThat(batch.getEventCount(), equalTo(SqsSinkBatch.MAX_MESSAGES_PER_BATCH)); + boolean flushResult = batch.flushOnce((e, m) -> {}); + assertTrue(flushResult); + assertThat(eventsFailedCount.get(), equalTo(numRecords)); + assertThat(requestsFailedCount.get(), equalTo(1)); + verify(eventHandle, times(0)).release(true); + } + + private List> getRecordList(int numberOfRecords) { + final List> recordList = new ArrayList<>(); + List records = generateRecords(numberOfRecords); + for (int i = 0; i < numberOfRecords; i++) { + final Event event = JacksonLog.builder().withData(records.get(i)).withEventHandle(eventHandle).build(); + recordList.add(new Record<>(event)); + } + return recordList; + } + + private static List generateRecords(int numberOfRecords) { + + List recordList = new ArrayList<>(); + + for (int rows = 0; rows < numberOfRecords; rows++) { + + HashMap eventData = new HashMap<>(); + + eventData.put("name", "Person" + rows); + eventData.put("age", Integer.toString(rows)); + recordList.add(eventData); + + } + return recordList; + } +} + + diff --git a/data-prepper-plugins/sqs-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSinkConfigTest.java b/data-prepper-plugins/sqs-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSinkConfigTest.java new file mode 100644 index 0000000000..cf5e0a6546 --- /dev/null +++ b/data-prepper-plugins/sqs-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSinkConfigTest.java @@ -0,0 +1,144 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.dataprepper.plugins.sink.sqs; + +import org.opensearch.dataprepper.model.configuration.PluginModel; +import org.opensearch.dataprepper.aws.api.AwsConfig; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.lang.reflect.Field; + +import org.apache.commons.lang3.RandomStringUtils; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.assertFalse; + +public class SqsSinkConfigTest { + private SqsSinkConfig sqsSinkConfig; + SqsThresholdConfig sqsThresholdConfig; + + @BeforeEach + void setUp() { + sqsThresholdConfig = mock(SqsThresholdConfig.class); + sqsSinkConfig = new SqsSinkConfig(); + } + + @Test + void TestDefaultConfig() { + assertThat(sqsSinkConfig.getMaxRetries(), equalTo(SqsSinkConfig.DEFAULT_MAX_RETRIES)); + assertThat(sqsSinkConfig.getQueueUrl(), equalTo(null)); + assertThat(sqsSinkConfig.getGroupId(), equalTo(null)); + assertThat(sqsSinkConfig.getDeDuplicationId(), equalTo(null)); + assertThat(sqsSinkConfig.getDlq(), equalTo(null)); + assertThat(sqsSinkConfig.getCodec(), equalTo(null)); + assertThat(sqsSinkConfig.getAwsConfig(), equalTo(null)); + assertThat(sqsSinkConfig.getThresholdConfig().getMaxEventsPerMessage(), equalTo(SqsThresholdConfig.DEFAULT_MESSAGES_PER_EVENT)); + assertThat(sqsSinkConfig.getThresholdConfig().getMaxMessageSizeBytes(), equalTo(SqsThresholdConfig.DEFAULT_MAX_MESSAGE_SIZE.getBytes())); + assertThat(sqsSinkConfig.getThresholdConfig().getFlushInterval(), equalTo(SqsThresholdConfig.DEFAULT_FLUSH_INTERVAL_TIME)); + } + + @Test + private void TestCustomConfig() throws Exception { + AwsConfig awsConfig = mock(AwsConfig.class); + reflectivelySetField(sqsSinkConfig, "awsConfig", awsConfig); + assertThat(sqsSinkConfig.getAwsConfig(), equalTo(awsConfig)); + final int TEST_MAX_RETRIES = 10; + reflectivelySetField(sqsSinkConfig, "maxRetries", TEST_MAX_RETRIES); + assertThat(sqsSinkConfig.getMaxRetries(), equalTo(TEST_MAX_RETRIES)); + final String testQUrl = RandomStringUtils.randomAlphabetic(10); + reflectivelySetField(sqsSinkConfig, "queueUrl", testQUrl); + assertThat(sqsSinkConfig.getQueueUrl(), equalTo(testQUrl)); + final String testGroupId = RandomStringUtils.randomAlphabetic(10); + reflectivelySetField(sqsSinkConfig, "groupId", testGroupId); + assertThat(sqsSinkConfig.getGroupId(), equalTo(testGroupId)); + final String testDeDupId = RandomStringUtils.randomAlphabetic(10); + reflectivelySetField(sqsSinkConfig, "deDuplicationId", testDeDupId); + assertThat(sqsSinkConfig.getDeDuplicationId(), equalTo(testDeDupId)); + reflectivelySetField(sqsSinkConfig, "thresholdConfig", sqsThresholdConfig); + assertThat(sqsSinkConfig.getThresholdConfig(), equalTo(sqsThresholdConfig)); + PluginModel codec = mock(PluginModel.class); + reflectivelySetField(sqsSinkConfig, "codec", codec); + assertThat(sqsSinkConfig.getCodec(), equalTo(codec)); + } + + @Test + void TestValidDynamicQUrlConfigs() throws Exception { + + final String testQUrl = RandomStringUtils.randomAlphabetic(10)+"${"+RandomStringUtils.randomAlphabetic(5)+"}"; + reflectivelySetField(sqsSinkConfig, "queueUrl", testQUrl); + assertTrue(sqsSinkConfig.isValidConfig()); + } + + @Test + void TestValidConfigs() throws Exception { + final String testQUrl = RandomStringUtils.randomAlphabetic(10); + reflectivelySetField(sqsSinkConfig, "queueUrl", testQUrl); + assertTrue(sqsSinkConfig.isValidConfig()); + String testGroupId = RandomStringUtils.randomAlphabetic(10); + reflectivelySetField(sqsSinkConfig, "groupId", testGroupId); + assertFalse(sqsSinkConfig.isValidConfig()); + reflectivelySetField(sqsSinkConfig, "groupId", null); + String testDeDupId = RandomStringUtils.randomAlphabetic(10); + reflectivelySetField(sqsSinkConfig, "deDuplicationId", testDeDupId); + assertFalse(sqsSinkConfig.isValidConfig()); + } + + @Test + void TestValidFiFoQConfigs() throws Exception { + final String testQUrl = RandomStringUtils.randomAlphabetic(10)+".fifo"; + reflectivelySetField(sqsSinkConfig, "queueUrl", testQUrl); + assertTrue(sqsSinkConfig.isValidConfig()); + String testGroupId = RandomStringUtils.randomAlphabetic(10); + reflectivelySetField(sqsSinkConfig, "groupId", testGroupId); + assertTrue(sqsSinkConfig.isValidConfig()); + + testGroupId = RandomStringUtils.randomAlphabetic(10)+"${abcd}"; + reflectivelySetField(sqsSinkConfig, "groupId", testGroupId); + reflectivelySetField(sqsSinkConfig, "thresholdConfig", sqsThresholdConfig); + when(sqsThresholdConfig.getMaxEventsPerMessage()).thenReturn(1); + assertTrue(sqsSinkConfig.isValidConfig()); + when(sqsThresholdConfig.getMaxEventsPerMessage()).thenReturn(2); + assertFalse(sqsSinkConfig.isValidConfig()); + testGroupId = RandomStringUtils.randomAlphabetic(10); + String testDeDupId = RandomStringUtils.randomAlphabetic(10)+"${abcd}"; + reflectivelySetField(sqsSinkConfig, "groupId", testGroupId); + reflectivelySetField(sqsSinkConfig, "deDuplicationId", testDeDupId); + assertFalse(sqsSinkConfig.isValidConfig()); + + } + + @Test + void TestValidCodecConfig() throws Exception { + reflectivelySetField(sqsSinkConfig, "codec", null); + reflectivelySetField(sqsSinkConfig, "thresholdConfig", sqsThresholdConfig); + when(sqsThresholdConfig.getMaxEventsPerMessage()).thenReturn(2); + assertFalse(sqsSinkConfig.isValidCodecConfig()); + when(sqsThresholdConfig.getMaxEventsPerMessage()).thenReturn(1); + assertTrue(sqsSinkConfig.isValidCodecConfig()); + PluginModel codec = mock(PluginModel.class); + when(codec.getPluginName()).thenReturn("ndjson"); + reflectivelySetField(sqsSinkConfig, "codec", codec); + when(sqsThresholdConfig.getMaxEventsPerMessage()).thenReturn(2); + assertFalse(sqsSinkConfig.isValidCodecConfig()); + } + + private void reflectivelySetField(final SqsSinkConfig sqsSinkConfig, final String fieldName, final Object value) throws NoSuchFieldException, IllegalAccessException { + final Field field = SqsSinkConfig.class.getDeclaredField(fieldName); + try { + field.setAccessible(true); + field.set(sqsSinkConfig, value); + } finally { + field.setAccessible(false); + } + } + + +} + diff --git a/data-prepper-plugins/sqs-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSinkDlqDataTest.java b/data-prepper-plugins/sqs-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSinkDlqDataTest.java new file mode 100644 index 0000000000..35503b01ba --- /dev/null +++ b/data-prepper-plugins/sqs-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSinkDlqDataTest.java @@ -0,0 +1,50 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.dataprepper.plugins.sink.sqs; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.hamcrest.CoreMatchers.notNullValue; +import static org.hamcrest.Matchers.equalTo; + +import java.util.UUID; + +public class SqsSinkDlqDataTest { + private SqsSinkDlqData sqsSinkDlqData; + + @BeforeEach + void setUp() { + } + + SqsSinkDlqData createObjectUnderTest(String message, Object data) { + return SqsSinkDlqData.createDlqData(data, message); + } + + @Test + void TestBasic() { + final String message = UUID.randomUUID().toString(); + final String data = UUID.randomUUID().toString(); + sqsSinkDlqData = createObjectUnderTest(message, data); + assertThat(sqsSinkDlqData.getMessage(), equalTo(message)); + assertThat(sqsSinkDlqData.getData(), equalTo(data)); + assertThat(sqsSinkDlqData.hashCode(), notNullValue()); + assertTrue(sqsSinkDlqData.toString().contains("SqsSinkDlqData{")); + } + + @Test + void TestEquals() { + final String message = UUID.randomUUID().toString(); + final String data = UUID.randomUUID().toString(); + sqsSinkDlqData = createObjectUnderTest(message, data); + SqsSinkDlqData sqsSinkDlqData2 = createObjectUnderTest(message, data); + assertTrue(sqsSinkDlqData.equals(sqsSinkDlqData2)); + } + +} + + diff --git a/data-prepper-plugins/sqs-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSinkServiceTest.java b/data-prepper-plugins/sqs-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSinkServiceTest.java new file mode 100644 index 0000000000..f42c4a2d05 --- /dev/null +++ b/data-prepper-plugins/sqs-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSinkServiceTest.java @@ -0,0 +1,468 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.dataprepper.plugins.sink.sqs; + +import org.opensearch.dataprepper.model.codec.OutputCodec; +import org.opensearch.dataprepper.model.sink.OutputCodecContext; +import org.opensearch.dataprepper.model.plugin.PluginFactory; +import org.opensearch.dataprepper.metrics.PluginMetrics; +import org.opensearch.dataprepper.model.configuration.PluginSetting; +import org.opensearch.dataprepper.plugins.dlq.DlqPushHandler; +import org.opensearch.dataprepper.expression.ExpressionEvaluator; + +import io.micrometer.core.instrument.Counter; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.mockito.Mock; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.mockito.Mockito.lenient; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.times; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.not; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.awaitility.Awaitility.await; +import software.amazon.awssdk.services.sqs.model.SendMessageBatchRequest; +import software.amazon.awssdk.services.sqs.model.SendMessageBatchResponse; +import org.opensearch.dataprepper.model.sink.SinkContext; + +import org.apache.commons.lang3.RandomStringUtils; +import software.amazon.awssdk.services.sqs.SqsClient; +import software.amazon.awssdk.services.sqs.model.RequestThrottledException; +import org.opensearch.dataprepper.model.event.Event; +import org.opensearch.dataprepper.model.event.EventHandle; +import org.opensearch.dataprepper.model.log.JacksonLog; +import org.opensearch.dataprepper.model.record.Record; +import org.opensearch.dataprepper.plugins.codec.json.JsonOutputCodec; +import org.opensearch.dataprepper.plugins.codec.json.JsonOutputCodecConfig; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.assertFalse; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.atomic.AtomicInteger; + +public class SqsSinkServiceTest { + @Mock + private SqsClient sqsClient; + @Mock + private SqsSinkConfig sqsSinkConfig; + @Mock + private ExpressionEvaluator expressionEvaluator; + @Mock + private DlqPushHandler dlqPushHandler; + @Mock + private PluginMetrics pluginMetrics; + @Mock + private PluginFactory pluginFactory; + @Mock + private SendMessageBatchResponse flushResponse; + @Mock + private EventHandle eventHandle; + @Mock + private SqsThresholdConfig thresholdConfig; + + @Mock + private SinkContext sinkContext; + + @Mock + private Counter eventsSuccessCounter; + @Mock + private Counter requestsSuccessCounter; + @Mock + private Counter eventsFailedCounter; + @Mock + private Counter requestsFailedCounter; + @Mock + private Counter dlqSuccessCounter; + private AtomicInteger eventsSuccessCount; + private AtomicInteger requestsSuccessCount; + private AtomicInteger eventsFailedCount; + private AtomicInteger requestsFailedCount; + private AtomicInteger dlqSuccessCount; + + private OutputCodec outputCodec; + private OutputCodecContext outputCodecContext; + private String queueUrl; + + private SqsSinkService createObjectUnderTest() { + return new SqsSinkService(sqsSinkConfig, sqsClient, expressionEvaluator, outputCodec, sinkContext, dlqPushHandler, pluginMetrics); + } + + @BeforeEach + void setup() { + sinkContext = mock(SinkContext.class); + pluginFactory = mock(PluginFactory.class); + when(sinkContext.getExcludeKeys()).thenReturn(null); + when(sinkContext.getIncludeKeys()).thenReturn(null); + when(sinkContext.getTagsTargetKey()).thenReturn(null); + eventsSuccessCount = new AtomicInteger(0); + requestsSuccessCount = new AtomicInteger(0); + eventsFailedCount = new AtomicInteger(0); + requestsFailedCount = new AtomicInteger(0); + dlqSuccessCount = new AtomicInteger(0); + outputCodec = new JsonOutputCodec(new JsonOutputCodecConfig()); + when(pluginFactory.loadPlugin(eq(OutputCodec.class), any())).thenReturn(outputCodec); + eventHandle = mock(EventHandle.class); + outputCodecContext = new OutputCodecContext(); + queueUrl = UUID.randomUUID().toString(); + sqsSinkConfig = mock(SqsSinkConfig.class); + thresholdConfig = mock(SqsThresholdConfig.class); + when (sqsSinkConfig.getQueueUrl()).thenReturn(queueUrl); + when (thresholdConfig.getMaxMessageSizeBytes()).thenReturn(256*1024L); + when (thresholdConfig.getMaxEventsPerMessage()).thenReturn(1); + when (sqsSinkConfig.getThresholdConfig()).thenReturn(thresholdConfig); + when (sqsSinkConfig.getMaxRetries()).thenReturn(3); + sqsClient = mock(SqsClient.class); + flushResponse = mock(SendMessageBatchResponse.class); + when(flushResponse.hasFailed()).thenReturn(false); + when(sqsClient.sendMessageBatch(any(SendMessageBatchRequest.class))).thenReturn(flushResponse); + expressionEvaluator = mock(ExpressionEvaluator.class); + when(expressionEvaluator.isValidFormatExpression(anyString())).thenReturn(true); + dlqPushHandler = mock(DlqPushHandler.class); + when(dlqPushHandler.perform(any(List.class))).thenReturn(true); + PluginSetting pluginSetting = mock(PluginSetting.class); + when(pluginSetting.getName()).thenReturn("name"); + when(pluginSetting.getPipelineName()).thenReturn("pipeline"); + when(dlqPushHandler.getPluginSetting()).thenReturn(pluginSetting); + pluginMetrics = mock(PluginMetrics.class); + eventsSuccessCounter = mock(Counter.class); + eventsFailedCounter = mock(Counter.class); + requestsSuccessCounter = mock(Counter.class); + requestsFailedCounter = mock(Counter.class); + dlqSuccessCounter = mock(Counter.class); + lenient().doAnswer((a)-> { + int v = (int)(double)(a.getArgument(0)); + eventsSuccessCount.addAndGet(v); + return null; + }).when(eventsSuccessCounter).increment(any(Double.class)); + lenient().doAnswer((a)-> { + int v = (int)(double)(a.getArgument(0)); + eventsFailedCount.addAndGet(v); + return null; + }).when(eventsFailedCounter).increment(any(Double.class)); + lenient().doAnswer((a)-> { + requestsSuccessCount.addAndGet(1); + return null; + }).when(requestsSuccessCounter).increment(); + lenient().doAnswer((a)-> { + int v = (int)(double)(a.getArgument(0)); + requestsSuccessCount.addAndGet(v); + return null; + }).when(requestsSuccessCounter).increment(any(Double.class)); + lenient().doAnswer((a)-> { + int v = (int)(double)(a.getArgument(0)); + requestsFailedCount.addAndGet(v); + return null; + }).when(requestsFailedCounter).increment(any(Double.class)); + lenient().doAnswer((a)-> { + int v = (int)(double)(a.getArgument(0)); + dlqSuccessCount.addAndGet(v); + return null; + }).when(dlqSuccessCounter).increment(any(Double.class)); + lenient().doAnswer(a -> { + String s = (String)(a.getArgument(0)); + if (s.equals(SqsSinkMetrics.SQS_SINK_REQUESTS_SUCCEEDED)) { + return requestsSuccessCounter; + } + if (s.equals(SqsSinkMetrics.SQS_SINK_EVENTS_SUCCEEDED)) { + return eventsSuccessCounter; + } + if (s.equals(SqsSinkMetrics.SQS_SINK_REQUESTS_FAILED)) { + return requestsFailedCounter; + } + if (s.equals(SqsSinkMetrics.SQS_SINK_EVENTS_FAILED)) { + return eventsFailedCounter; + } + if (s.contains("NumDlqSuccess")) { + return dlqSuccessCounter; + } + return null; + }).when(pluginMetrics).counter(anyString()); + } + + @Test + void TestBasic() { + SqsSinkService sqsSinkService = createObjectUnderTest(); + assertTrue(sqsSinkService.exceedsMaxEventSizeThreshold(256*1024+1)); + assertFalse(sqsSinkService.exceedsMaxEventSizeThreshold(256*1024-1)); + assertFalse(sqsSinkService.exceedsMaxEventSizeThreshold(256*1024)); + } + + @ParameterizedTest + @ValueSource(ints = {9, 29, 49, 69}) + void TestExecuteWithOneBatch_FlushTimeout(int numRecords) throws Exception { + when (thresholdConfig.getFlushInterval()).thenReturn(2L); + SqsSinkService sqsSinkService = createObjectUnderTest(); + List> records = getRecordList(numRecords); + sqsSinkService.execute(records); + await().atMost(Duration.ofSeconds(30)) + .untilAsserted(() -> { + sqsSinkService.execute(Collections.emptyList()); + assertThat(sqsSinkService.getBatchUrlMap().size(), equalTo(0)); + assertThat(eventsSuccessCount.get(), equalTo(numRecords)); + assertThat(requestsSuccessCount.get(), equalTo((numRecords+1)/10)); + verify(eventHandle, times(numRecords)).release(true); + }); + } + + @ParameterizedTest + @ValueSource(ints = {10, 30, 50, 70}) + void TestExecuteOneBatch_WithLargeRecords(int numRecords) throws Exception { + SqsSinkService sqsSinkService = createObjectUnderTest(); + List> records = getLargeRecordList(numRecords); + sqsSinkService.execute(records); + await().atMost(Duration.ofSeconds(30)) + .untilAsserted(() -> { + sqsSinkService.execute(Collections.emptyList()); + assertThat(sqsSinkService.getBatchUrlMap().size(), equalTo(0)); + assertThat(eventsSuccessCount.get(), equalTo(numRecords)); + assertThat(requestsSuccessCount.get(), equalTo(numRecords)); + verify(eventHandle, times(numRecords)).release(true); + }); + } + + @Test + void TestLargeRecordToDLQ() { + SqsSinkService sqsSinkService = createObjectUnderTest(); + Record record = getLargeRecord(300*1024); + sqsSinkService.execute(List.of(record)); + assertThat(sqsSinkService.getBatchUrlMap().size(), equalTo(0)); + assertThat(eventsSuccessCount.get(), equalTo(0)); + assertThat(requestsSuccessCount.get(), equalTo(0)); + } + + @ParameterizedTest + @ValueSource(ints = {20, 40, 60, 80}) + void TestExecuteWithOneBatch_SuccessfulFlush_DynamicQUrl(int numRecords) throws Exception { + when (sqsSinkConfig.getQueueUrl()).thenReturn(queueUrl+"${/id}"); + SqsSinkService sqsSinkService = createObjectUnderTest(); + List> records = getRecordList(numRecords); + sqsSinkService.execute(records); + assertThat(sqsSinkService.getBatchUrlMap().size(), equalTo(0)); + assertThat(eventsSuccessCount.get(), equalTo(numRecords)); + assertThat(requestsSuccessCount.get(), equalTo(numRecords/10)); + verify(eventHandle, times(numRecords)).release(true); + } + + @ParameterizedTest + @ValueSource(ints = {10, 30, 50, 70}) + void TestExecuteWithOneBatch_SuccessfulFlush(int numRecords) throws Exception { + SqsSinkService sqsSinkService = createObjectUnderTest(); + List> records = getRecordList(numRecords); + sqsSinkService.execute(records); + assertThat(sqsSinkService.getBatchUrlMap().size(), equalTo(0)); + assertThat(eventsSuccessCount.get(), equalTo(numRecords)); + assertThat(requestsSuccessCount.get(), equalTo(numRecords/10)); + verify(eventHandle, times(numRecords)).release(true); + } + + @Test + void TestSendingToDLQAfterMultipleRetries() { + final int numRecords = 10; + RequestThrottledException requestThrottledException = mock(RequestThrottledException.class); + when(sqsClient.sendMessageBatch(any(SendMessageBatchRequest.class))).thenThrow(requestThrottledException); + SqsSinkService sqsSinkService = createObjectUnderTest(); + List> records = getRecordList(numRecords); + sqsSinkService.execute(records); + assertThat(sqsSinkService.getBatchUrlMap().size(), equalTo(0)); + assertThat(eventsSuccessCount.get(), equalTo(0)); + assertThat(requestsSuccessCount.get(), equalTo(0)); + verify(eventHandle, times(numRecords)).release(true); + } + + @ParameterizedTest + @ValueSource(ints = {10, 30, 50, 70}) + void TestExecuteWithOneBatch_MultipleRetries(int numRecords) throws Exception { + RequestThrottledException requestThrottledException = mock(RequestThrottledException.class); + when(sqsClient.sendMessageBatch(any(SendMessageBatchRequest.class))).thenThrow(requestThrottledException).thenReturn(flushResponse); + SqsSinkService sqsSinkService = createObjectUnderTest(); + List> records = getRecordList(numRecords); + sqsSinkService.execute(records); + assertThat(sqsSinkService.getBatchUrlMap().size(), equalTo(0)); + assertThat(eventsSuccessCount.get(), equalTo(numRecords)); + assertThat(requestsSuccessCount.get(), equalTo(numRecords/10)); + verify(eventHandle, times(numRecords)).release(true); + } + + @Test + void TestFiFoQWithInvalidDeDupIdExpression() { + when(expressionEvaluator.isValidFormatExpression(anyString())).thenReturn(false); + when (sqsSinkConfig.getQueueUrl()).thenReturn(queueUrl+".fifo"); + when (sqsSinkConfig.getDeDuplicationId()).thenReturn(UUID.randomUUID().toString()+"${/id - }"); + assertThrows(IllegalArgumentException.class, ()-> createObjectUnderTest()); + } + + @Test + void TestFiFoQWithInvalidGroupIdExpression() { + when(expressionEvaluator.isValidFormatExpression(anyString())).thenReturn(false); + when (sqsSinkConfig.getQueueUrl()).thenReturn(queueUrl+".fifo"); + when (sqsSinkConfig.getGroupId()).thenReturn(UUID.randomUUID().toString()+"${/id - }"); + assertThrows(IllegalArgumentException.class, ()-> createObjectUnderTest()); + } + + @ParameterizedTest + @ValueSource(ints = {10, 30, 50, 70}) + void TestWithOneBatch_SuccessfulFlushFiFoQDynamic(int numRecords) throws Exception { + when (sqsSinkConfig.getQueueUrl()).thenReturn(queueUrl+".fifo"); + when (sqsSinkConfig.getGroupId()).thenReturn(UUID.randomUUID().toString()+"${/id}"); + when (sqsSinkConfig.getDeDuplicationId()).thenReturn(UUID.randomUUID().toString()+"${/id}"); + SqsSinkService sqsSinkService = createObjectUnderTest(); + List> records = getRecordList(numRecords); + boolean isFull = false; + for (int i = 0; i < numRecords; i++) { + assertFalse(isFull); + Event event = records.get(i).getData(); + long eSize = outputCodec.getEstimatedSize(event, new OutputCodecContext()); + isFull = sqsSinkService.addToBuffer(event, eSize); + if (isFull) { + Object flushResult = sqsSinkService.doFlushOnce(null); + assertThat(flushResult, equalTo(null)); + isFull = false; + } + } + assertThat(sqsSinkService.getBatchUrlMap().size(), equalTo(0)); + assertThat(eventsSuccessCount.get(), equalTo(numRecords)); + assertThat(requestsSuccessCount.get(), equalTo(numRecords/10)); + verify(eventHandle, times(numRecords)).release(true); + } + + + + @ParameterizedTest + @ValueSource(ints = {10, 30, 50, 70}) + void TestWithOneBatch_SuccessfulFlushFiFoQ(int numRecords) throws Exception { + when (sqsSinkConfig.getQueueUrl()).thenReturn(queueUrl+".fifo"); + when (sqsSinkConfig.getGroupId()).thenReturn(UUID.randomUUID().toString()); + when (sqsSinkConfig.getDeDuplicationId()).thenReturn(UUID.randomUUID().toString()); + SqsSinkService sqsSinkService = createObjectUnderTest(); + List> records = getRecordList(numRecords); + boolean isFull = false; + for (int i = 0; i < numRecords; i++) { + assertFalse(isFull); + Event event = records.get(i).getData(); + long eSize = outputCodec.getEstimatedSize(event, new OutputCodecContext()); + isFull = sqsSinkService.addToBuffer(event, eSize); + if (isFull) { + Object flushResult = sqsSinkService.doFlushOnce(null); + assertThat(flushResult, equalTo(null)); + isFull = false; + } + } + assertThat(sqsSinkService.getBatchUrlMap().size(), equalTo(0)); + assertThat(eventsSuccessCount.get(), equalTo(numRecords)); + assertThat(requestsSuccessCount.get(), equalTo(numRecords/10)); + verify(eventHandle, times(numRecords)).release(true); + } + + + @ParameterizedTest + @ValueSource(ints = {10, 30, 50, 70}) + void TestWithOneBatch_SuccessfulFlush(int numRecords) throws Exception { + SqsSinkService sqsSinkService = createObjectUnderTest(); + List> records = getRecordList(numRecords); + boolean isFull = false; + for (int i = 0; i < numRecords; i++) { + assertFalse(isFull); + Event event = records.get(i).getData(); + long eSize = outputCodec.getEstimatedSize(event, new OutputCodecContext()); + isFull = sqsSinkService.addToBuffer(event, eSize); + if (isFull) { + Object flushResult = sqsSinkService.doFlushOnce(null); + assertThat(flushResult, equalTo(null)); + isFull = false; + } + } + assertThat(sqsSinkService.getBatchUrlMap().size(), equalTo(0)); + assertThat(eventsSuccessCount.get(), equalTo(numRecords)); + assertThat(requestsSuccessCount.get(), equalTo(numRecords/10)); + verify(eventHandle, times(numRecords)).release(true); + } + + @Test + void TestWithOneBatch_RetryFlushes() throws Exception { + RequestThrottledException requestThrottledException = mock(RequestThrottledException.class); + SqsSinkService sqsSinkService = createObjectUnderTest(); + int numRecords = SqsSinkBatch.MAX_MESSAGES_PER_BATCH; + List> records = getRecordList(numRecords); + boolean isFull = false; + for (int i = 0; i < numRecords; i++) { + when(sqsClient.sendMessageBatch(any(SendMessageBatchRequest.class))).thenThrow(requestThrottledException); + assertFalse(isFull); + Event event = records.get(i).getData(); + long eSize = outputCodec.getEstimatedSize(event, new OutputCodecContext()); + isFull = sqsSinkService.addToBuffer(event, eSize); + if (isFull) { + Object flushResult = sqsSinkService.doFlushOnce(null); + assertThat(flushResult, not(equalTo(null))); + assertThat(sqsSinkService.getBatchUrlMap().size(), equalTo(1)); + when(sqsClient.sendMessageBatch(any(SendMessageBatchRequest.class))).thenReturn(flushResponse); + flushResult = sqsSinkService.doFlushOnce(null); + assertThat(flushResult, equalTo(null)); + } + } + assertThat(sqsSinkService.getBatchUrlMap().size(), equalTo(0)); + assertThat(eventsSuccessCount.get(), equalTo(numRecords)); + assertThat(requestsSuccessCount.get(), equalTo(numRecords/10)); + verify(eventHandle, times(numRecords)).release(true); + } + + private List> getLargeRecordList(int numberOfRecords) { + final List> recordList = new ArrayList<>(); + for (int i = 0; i < numberOfRecords; i++) { + recordList.add(getLargeRecord(245*1024)); + } + return recordList; + } + + private Record getLargeRecord(int size) { + final Event event = JacksonLog.builder() + .withData(Map.of("key", RandomStringUtils.randomAlphabetic(size))) + .withEventHandle(eventHandle) + .build(); + return new Record<>(event); + } + + private List> getRecordList(int numberOfRecords) { + final List> recordList = new ArrayList<>(); + List records = generateRecords(numberOfRecords); + for (int i = 0; i < numberOfRecords; i++) { + final Event event = JacksonLog.builder().withData(records.get(i)).withEventHandle(eventHandle).build(); + recordList.add(new Record<>(event)); + } + return recordList; + } + + private static List generateRecords(int numberOfRecords) { + + List recordList = new ArrayList<>(); + + for (int rows = 0; rows < numberOfRecords; rows++) { + + HashMap eventData = new HashMap<>(); + + eventData.put("name", "Person" + rows); + eventData.put("age", Integer.toString(rows)); + eventData.put("id", Integer.toString(rows%2)); + recordList.add(eventData); + + } + return recordList; + } + +} diff --git a/data-prepper-plugins/sqs-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSinkTest.java b/data-prepper-plugins/sqs-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSinkTest.java new file mode 100644 index 0000000000..b08be2ec0f --- /dev/null +++ b/data-prepper-plugins/sqs-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/sqs/SqsSinkTest.java @@ -0,0 +1,212 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.dataprepper.plugins.sink.sqs; + +import org.opensearch.dataprepper.model.codec.OutputCodec; +import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import org.opensearch.dataprepper.metrics.PluginMetrics; +import org.opensearch.dataprepper.model.plugin.PluginFactory; +import org.opensearch.dataprepper.model.configuration.PluginSetting; +import org.opensearch.dataprepper.model.sink.SinkContext; +import org.opensearch.dataprepper.expression.ExpressionEvaluator; +import org.mockito.MockedStatic; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mock; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.atLeast; +import static org.mockito.Mockito.when; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.mockStatic; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; + +import org.opensearch.dataprepper.plugins.source.sqs.common.SqsClientFactory; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.sqs.SqsClient; +import org.opensearch.dataprepper.aws.api.AwsConfig; +import org.opensearch.dataprepper.model.event.Event; +import org.opensearch.dataprepper.model.event.JacksonEvent; +import org.opensearch.dataprepper.model.record.Record; +import org.opensearch.dataprepper.model.configuration.PluginModel; +import org.opensearch.dataprepper.plugins.codec.json.JsonOutputCodec; +import org.opensearch.dataprepper.plugins.codec.json.JsonOutputCodecConfig; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Map; +import java.util.UUID; + +public class SqsSinkTest { + private static final String TEST_CODEC_PLUGIN_NAME = "json"; + private static final String TEST_PLUGIN_NAME = "testPluginName"; + private static final String TEST_PIPELINE_NAME = "testPipelineName"; + @Mock + private SqsSinkConfig sqsSinkConfig; + @Mock + private SinkContext sinkContext; + @Mock + private ExpressionEvaluator expressionEvaluator; + @Mock + private PluginMetrics pluginMetrics; + @Mock + private PluginSetting pluginSetting; + @Mock + private PluginFactory pluginFactory; + @Mock + private AwsCredentialsSupplier awsCredentialsSupplier; + @Mock + private AwsCredentialsProvider awsCredentialsProvider; + @Mock + private AwsConfig awsConfig; + + private SqsClient sqsClient; + private PluginModel codecConfig; + private String queueUrl; + + @BeforeEach + void setup() { + pluginSetting = mock(PluginSetting.class); + pluginMetrics = mock(PluginMetrics.class); + pluginFactory = mock(PluginFactory.class); + sqsSinkConfig = mock(SqsSinkConfig.class); + sinkContext = mock(SinkContext.class); + when(sinkContext.getExcludeKeys()).thenReturn(null); + when(sinkContext.getIncludeKeys()).thenReturn(null); + when(sinkContext.getTagsTargetKey()).thenReturn(null); + sqsClient = mock(SqsClient.class); + expressionEvaluator = mock(ExpressionEvaluator.class); + awsCredentialsSupplier = mock(AwsCredentialsSupplier.class); + awsCredentialsProvider = mock(AwsCredentialsProvider.class); + when(sqsSinkConfig.getDlq()).thenReturn(null); + codecConfig = mock(PluginModel.class); + when(codecConfig.getPluginName()).thenReturn(TEST_CODEC_PLUGIN_NAME); + when(codecConfig.getPluginSettings()).thenReturn(Map.of()); + when(sqsSinkConfig.getCodec()).thenReturn(codecConfig); + queueUrl = UUID.randomUUID().toString(); + when(sqsSinkConfig.getQueueUrl()).thenReturn(queueUrl); + when(pluginFactory.loadPlugin(eq(OutputCodec.class), any())).thenReturn(new JsonOutputCodec(new JsonOutputCodecConfig())); + awsConfig = mock(AwsConfig.class); + when(awsConfig.getAwsRegion()).thenReturn(Region.of("us-west-2")); + when(sqsSinkConfig.getAwsConfig()).thenReturn(awsConfig); + when(pluginSetting.getName()).thenReturn(TEST_PLUGIN_NAME); + when(pluginSetting.getPipelineName()).thenReturn(TEST_PIPELINE_NAME); + + } + SqsSink createObjectUnderTest() { + return new SqsSink(pluginSetting, pluginMetrics, pluginFactory, sqsSinkConfig, sinkContext, expressionEvaluator, awsCredentialsSupplier); + } + + @Test + void TestBasic() { + try(MockedStatic mockedStatic = mockStatic(SqsClientFactory.class)) { + mockedStatic.when(() -> SqsClientFactory.createSqsClient(any(Region.class), + any(AwsCredentialsProvider.class))) + .thenReturn(sqsClient); + + SqsSink sqsSink = createObjectUnderTest(); + sqsSink.doInitialize(); + assertTrue(sqsSink.isReady()); + } + } + + @Test + void TestWithInvalidCodec() { + when(codecConfig.getPluginName()).thenReturn("badCodec"); + awsCredentialsSupplier = null; + when(sqsSinkConfig.getAwsConfig()).thenReturn(null); + try(MockedStatic mockedStatic = mockStatic(SqsClientFactory.class)) { + mockedStatic.when(() -> SqsClientFactory.createSqsClient(any(Region.class), + any(AwsCredentialsProvider.class))) + .thenReturn(sqsClient); + + assertThrows(RuntimeException.class, ()-> createObjectUnderTest()); + } + } + + @Test + void TestWithNullAwsConfig() { + awsCredentialsSupplier = null; + when(sqsSinkConfig.getAwsConfig()).thenReturn(null); + try(MockedStatic mockedStatic = mockStatic(SqsClientFactory.class)) { + mockedStatic.when(() -> SqsClientFactory.createSqsClient(any(Region.class), + any(AwsCredentialsProvider.class))) + .thenReturn(sqsClient); + + assertThrows(RuntimeException.class, ()-> createObjectUnderTest()); + } + } + + @Test + void TestForDefaultCodec() { + when(sqsSinkConfig.getCodec()).thenReturn(codecConfig); + try(MockedStatic mockedStatic = mockStatic(SqsClientFactory.class)) { + mockedStatic.when(() -> SqsClientFactory.createSqsClient(any(Region.class), + any(AwsCredentialsProvider.class))) + .thenReturn(sqsClient); + + SqsSink sqsSink = createObjectUnderTest(); + sqsSink.doInitialize(); + assertTrue(sqsSink.isReady()); + } + } + + @Test + void TestSinkOutputWithEvents() { + try(MockedStatic mockedStatic = mockStatic(SqsClientFactory.class)) { + mockedStatic.when(() -> SqsClientFactory.createSqsClient(any(Region.class), + any(AwsCredentialsProvider.class))) + .thenReturn(sqsClient); + + SqsSink sqsSink = createObjectUnderTest(); + sqsSink.doInitialize(); + Collection> spyEvents = getMockedRecords(); + + sqsSink.doOutput(spyEvents); + + for (Record spyEvent : spyEvents) { + verify(spyEvent, atLeast(1)).getData(); + } + } + } + + @Test + void TestOutputWithEmptyEvents() { + try(MockedStatic mockedStatic = mockStatic(SqsClientFactory.class)) { + mockedStatic.when(() -> SqsClientFactory.createSqsClient(any(Region.class), + any(AwsCredentialsProvider.class))) + .thenReturn(sqsClient); + + SqsSink sqsSink = createObjectUnderTest(); + sqsSink.doInitialize(); + Collection> spyEvents = spy(ArrayList.class); + + assertTrue(spyEvents.isEmpty()); + + sqsSink.doOutput(spyEvents); + verify(spyEvents, times(2)).isEmpty(); + } + } + + Collection> getMockedRecords() { + Collection> testCollection = new ArrayList<>(); + Record mockedEvent = new Record<>(JacksonEvent.fromMessage("")); + Record spyEvent = spy(mockedEvent); + testCollection.add(spyEvent); + return testCollection; + } + + +} + diff --git a/settings.gradle b/settings.gradle index 0fdfc4c88e..744852c78e 100644 --- a/settings.gradle +++ b/settings.gradle @@ -170,6 +170,7 @@ include 'data-prepper-plugins:buffer-common' include 'data-prepper-plugins:sqs-source' include 'data-prepper-plugins:sqs-common' include 'data-prepper-plugins:cloudwatch-logs' +include 'data-prepper-plugins:sqs-sink' //include 'data-prepper-plugins:http-sink' //include 'data-prepper-plugins:sns-sink' //include 'data-prepper-plugins:prometheus-sink'