diff --git a/data-prepper-api/src/main/java/org/opensearch/dataprepper/model/annotations/ExtensionDependsOn.java b/data-prepper-api/src/main/java/org/opensearch/dataprepper/model/annotations/ExtensionDependsOn.java new file mode 100644 index 0000000000..9026e42592 --- /dev/null +++ b/data-prepper-api/src/main/java/org/opensearch/dataprepper/model/annotations/ExtensionDependsOn.java @@ -0,0 +1,23 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.dataprepper.model.annotations; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +@Documented +@Retention(RetentionPolicy.RUNTIME) +@Target({ElementType.TYPE}) +public @interface ExtensionDependsOn { + /** + * The list of classes that this extension depends on + * @return Array of Class objects representing the classes this extension depends on + */ + Class[] dependentClasses() default {}; +} diff --git a/data-prepper-api/src/main/java/org/opensearch/dataprepper/model/annotations/ExtensionProvides.java b/data-prepper-api/src/main/java/org/opensearch/dataprepper/model/annotations/ExtensionProvides.java new file mode 100644 index 0000000000..aaf155dde9 --- /dev/null +++ b/data-prepper-api/src/main/java/org/opensearch/dataprepper/model/annotations/ExtensionProvides.java @@ -0,0 +1,28 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.dataprepper.model.annotations; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +/** + * Annotation to specify that a class provides an extension and its dependencies. + * This annotation can be used to declare what extension points a class provides + * and what other extension points it depends on. + */ +@Documented +@Retention(RetentionPolicy.RUNTIME) +@Target({ElementType.TYPE}) +public @interface ExtensionProvides { + /** + * The list of classes that this extension provides + * @return Array of Class objects representing the classes this extension provides + */ + Class[] providedClasses() default {}; +} diff --git a/data-prepper-api/src/main/java/org/opensearch/dataprepper/model/plugin/ExtensionPoints.java b/data-prepper-api/src/main/java/org/opensearch/dataprepper/model/plugin/ExtensionPoints.java index 7f515806b7..35a5f87ce7 100644 --- a/data-prepper-api/src/main/java/org/opensearch/dataprepper/model/plugin/ExtensionPoints.java +++ b/data-prepper-api/src/main/java/org/opensearch/dataprepper/model/plugin/ExtensionPoints.java @@ -20,4 +20,6 @@ public interface ExtensionPoints { * @since 2.3 */ void addExtensionProvider(ExtensionProvider extensionProvider); + + T getExtensionProvider(Class type); } diff --git a/data-prepper-plugin-framework/src/main/java/org/opensearch/dataprepper/plugin/DataPrepperExtensionPoints.java b/data-prepper-plugin-framework/src/main/java/org/opensearch/dataprepper/plugin/DataPrepperExtensionPoints.java index a66ec890b8..b8222e97e1 100644 --- a/data-prepper-plugin-framework/src/main/java/org/opensearch/dataprepper/plugin/DataPrepperExtensionPoints.java +++ b/data-prepper-plugin-framework/src/main/java/org/opensearch/dataprepper/plugin/DataPrepperExtensionPoints.java @@ -50,6 +50,12 @@ public void addExtensionProvider(final ExtensionProvider extensionProvider) { providerClassesSet.add(extensionProvider.supportedClass()); } + @Override + public T getExtensionProvider(final Class type) { + sharedApplicationContext.refresh(); + return sharedApplicationContext.getBean(type); + } + private static class EmptyContext implements ExtensionProvider.Context { } diff --git a/data-prepper-plugin-framework/src/main/java/org/opensearch/dataprepper/plugin/ExtensionLoader.java b/data-prepper-plugin-framework/src/main/java/org/opensearch/dataprepper/plugin/ExtensionLoader.java index d397edf94a..0435e26761 100644 --- a/data-prepper-plugin-framework/src/main/java/org/opensearch/dataprepper/plugin/ExtensionLoader.java +++ b/data-prepper-plugin-framework/src/main/java/org/opensearch/dataprepper/plugin/ExtensionLoader.java @@ -6,6 +6,8 @@ package org.opensearch.dataprepper.plugin; import org.opensearch.dataprepper.model.annotations.DataPrepperExtensionPlugin; +import org.opensearch.dataprepper.model.annotations.ExtensionDependsOn; +import org.opensearch.dataprepper.model.annotations.ExtensionProvides; import org.opensearch.dataprepper.model.configuration.PipelinesDataFlowModel; import org.opensearch.dataprepper.model.plugin.ExtensionPlugin; import org.opensearch.dataprepper.model.plugin.InvalidPluginConfigurationException; @@ -27,10 +29,17 @@ public class ExtensionLoader { public class ExtensionPluginWithContext { ExtensionPlugin extensionPlugin; boolean configured; + Class[] dependentClasses; + Class[] providedClasses; - public ExtensionPluginWithContext(final ExtensionPlugin extensionPlugin, final boolean isConfigured) { + public ExtensionPluginWithContext(final ExtensionPlugin extensionPlugin, + final boolean isConfigured, + final Class[] dependentClasses, + final Class[] providedClasses) { this.extensionPlugin = extensionPlugin; this.configured = isConfigured; + this.dependentClasses = dependentClasses; + this.providedClasses = providedClasses; } public ExtensionPlugin getExtensionPlugin() { @@ -40,6 +49,14 @@ public ExtensionPlugin getExtensionPlugin() { public boolean isConfigured() { return configured; } + + public Class[] getDependentClasses() { + return dependentClasses; + } + + public Class[] getProvidedClasses() { + return providedClasses; + } } private Comparator extensionsLoaderComparator; @@ -73,9 +90,17 @@ public List loadExtensions() { final String pluginName = convertClassToName(extensionClass); try { final PluginArgumentsContext pluginArgumentsContext = getConstructionContext(extensionClass); + final ExtensionProvides extensionProvidesAnnotation = extensionClass.getAnnotation(ExtensionProvides.class); + final ExtensionDependsOn extensionDependsOnAnnotation = extensionClass.getAnnotation(ExtensionDependsOn.class); + + final Class[] providedClasses = extensionProvidesAnnotation != null ? + extensionProvidesAnnotation.providedClasses() : new Class[]{}; + final Class[] dependentClasses = extensionDependsOnAnnotation != null ? + extensionDependsOnAnnotation.dependentClasses() : new Class[]{}; + final Object config = pluginArgumentsContext.getArgument(0); return new ExtensionPluginWithContext(extensionPluginCreator.newPluginInstance( - extensionClass, pluginArgumentsContext, pluginName), (config != null)); + extensionClass, pluginArgumentsContext, pluginName), (config != null), dependentClasses, providedClasses); } catch (Exception e) { final PluginError pluginError = PluginError.builder() .componentType(PipelinesDataFlowModel.EXTENSION_PLUGIN_TYPE) diff --git a/data-prepper-plugin-framework/src/main/java/org/opensearch/dataprepper/plugin/PluginCreatorContext.java b/data-prepper-plugin-framework/src/main/java/org/opensearch/dataprepper/plugin/PluginCreatorContext.java index 6c82a31691..b54bdd0db5 100644 --- a/data-prepper-plugin-framework/src/main/java/org/opensearch/dataprepper/plugin/PluginCreatorContext.java +++ b/data-prepper-plugin-framework/src/main/java/org/opensearch/dataprepper/plugin/PluginCreatorContext.java @@ -4,6 +4,7 @@ import javax.inject.Named; +import java.util.Arrays; import java.util.Comparator; @Named @@ -21,6 +22,35 @@ public PluginCreator pluginCreator( @Bean(name = "extensionsLoaderComparator") public Comparator extensionsLoaderComparator() { - return Comparator.comparing(ExtensionLoader.ExtensionPluginWithContext::isConfigured).reversed(); + return (extensionOne, extensionTwo) -> { + // First, compare by configuration status (configured ones first) + int configCompare = Boolean.compare(extensionTwo.isConfigured(), extensionOne.isConfigured()); + if (configCompare != 0) { + return configCompare; + } + + // Get the provided and dependent classes for both extensions + Class[] extensionOneProvidedClasses = extensionOne.getProvidedClasses(); + Class[] extensionTwoProvidedClasses = extensionTwo.getProvidedClasses(); + Class[] extensionOneDependentClasses = extensionOne.getDependentClasses(); + Class[] extensionTwoDependentClasses = extensionTwo.getDependentClasses(); + + // If extensionOne provides any classes that extensionTwo depends on, extensionOne should go first + if (containsAnyExtensionDependencies(extensionOneProvidedClasses, extensionTwoDependentClasses)) { + return -1; + } + + // If extensionTwo provides any classes that extensionOne depends on, extensionTwo should go first + if (containsAnyExtensionDependencies(extensionTwoProvidedClasses, extensionOneDependentClasses)) { + return 1; + } + + return 0; + }; + } + + private boolean containsAnyExtensionDependencies(final Class[] provided, final Class[] dependencies) { + return Arrays.stream(dependencies).anyMatch(dep -> + Arrays.asList(provided).contains(dep)); } } diff --git a/data-prepper-plugin-framework/src/test/java/org/opensearch/dataprepper/plugin/DataPrepperExtensionPointsTest.java b/data-prepper-plugin-framework/src/test/java/org/opensearch/dataprepper/plugin/DataPrepperExtensionPointsTest.java index d8b5b220ed..e47cc400ce 100644 --- a/data-prepper-plugin-framework/src/test/java/org/opensearch/dataprepper/plugin/DataPrepperExtensionPointsTest.java +++ b/data-prepper-plugin-framework/src/test/java/org/opensearch/dataprepper/plugin/DataPrepperExtensionPointsTest.java @@ -127,6 +127,20 @@ void addExtensionProvider_should_registerBean_as_prototype() { verifyRegisterBeanAsPrototype(coreApplicationContext); } + @Test + void getExtensionProvider_refreshes_shared_context_and_returns_correct_bean() { + final Class defaultPluginFactoryClass = DefaultPluginFactory.class; + final DefaultPluginFactory defaultPluginFactory = mock(DefaultPluginFactory.class); + + when(sharedApplicationContext.getBean(defaultPluginFactoryClass)).thenReturn(defaultPluginFactory); + + final DefaultPluginFactory result = createObjectUnderTest().getExtensionProvider(defaultPluginFactoryClass); + + assertThat(result, equalTo(defaultPluginFactory)); + + verify(sharedApplicationContext).refresh(); + } + private void verifyRegisterBeanWithProvideInstance(final GenericApplicationContext applicationContext) { reset(extensionProvider); final ArgumentCaptor> supplierArgumentCaptor = diff --git a/data-prepper-plugin-framework/src/test/java/org/opensearch/dataprepper/plugin/ExtensionLoaderTest.java b/data-prepper-plugin-framework/src/test/java/org/opensearch/dataprepper/plugin/ExtensionLoaderTest.java index ba07dcd83d..970d777bd6 100644 --- a/data-prepper-plugin-framework/src/test/java/org/opensearch/dataprepper/plugin/ExtensionLoaderTest.java +++ b/data-prepper-plugin-framework/src/test/java/org/opensearch/dataprepper/plugin/ExtensionLoaderTest.java @@ -18,6 +18,8 @@ import org.mockito.Captor; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.dataprepper.model.annotations.ExtensionDependsOn; +import org.opensearch.dataprepper.model.annotations.ExtensionProvides; import org.opensearch.dataprepper.model.configuration.PipelinesDataFlowModel; import org.opensearch.dataprepper.model.plugin.ExtensionPlugin; import org.opensearch.dataprepper.model.plugin.InvalidPluginConfigurationException; @@ -45,6 +47,7 @@ import static org.hamcrest.CoreMatchers.nullValue; import static org.hamcrest.MatcherAssert.assertThat; import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; @@ -224,7 +227,7 @@ void loadExtensions_throws_InvalidPluginConfigurationException_when_extensionPlu } @Test - void loadExtensions_returns_multiple_extensions_for_multiple_plugin_classes() { + void loadExtensions_returns_multiple_extensions_for_multiple_plugin_classes_in_correct_order() { final Collection> pluginClasses = new HashSet<>(); final Collection expectedPlugins = new ArrayList<>(); @@ -256,6 +259,13 @@ void loadExtensions_returns_multiple_extensions_for_multiple_plugin_classes() { for (ExtensionPlugin expectedPlugin : actualPlugins) { assertThat(actualPlugins, hasItem(expectedPlugin)); } + + assertThat(actualPlugins.get(0), instanceOf(TestExtension2.class)); + assertTrue(actualPlugins.get(1) instanceof TestExtension1 || actualPlugins.get(1) instanceof TestExtension3, + "Expected result to be either TestExtension1 or TestExtension3 but was " + actualPlugins.get(1).getClass().getName()); + + assertTrue(actualPlugins.get(2) instanceof TestExtension1 || actualPlugins.get(2) instanceof TestExtension3, + "Expected result to be either TestExtension1 or TestExtension3 but was " + actualPlugins.get(2).getClass().getName()); assertThat(pluginErrorCollector.getPluginErrors().isEmpty(), is(true)); } @@ -335,10 +345,15 @@ private static Stream validExtensionConfigs() { null); } + @ExtensionDependsOn(dependentClasses = TestExtensionConfig.class) private interface TestExtension1 extends ExtensionPlugin { } + + @ExtensionProvides(providedClasses = TestExtensionConfig.class) private interface TestExtension2 extends ExtensionPlugin { } + + @ExtensionDependsOn(dependentClasses = TestExtensionConfig.class) private interface TestExtension3 extends ExtensionPlugin { } diff --git a/data-prepper-plugin-framework/src/test/java/org/opensearch/dataprepper/plugin/PluginCreatorContextTest.java b/data-prepper-plugin-framework/src/test/java/org/opensearch/dataprepper/plugin/PluginCreatorContextTest.java index 33f14d4d98..0d44a6614f 100644 --- a/data-prepper-plugin-framework/src/test/java/org/opensearch/dataprepper/plugin/PluginCreatorContextTest.java +++ b/data-prepper-plugin-framework/src/test/java/org/opensearch/dataprepper/plugin/PluginCreatorContextTest.java @@ -40,16 +40,29 @@ public void test_pluginCreator() { @Test public void test_extensionsLoaderComparator() { + final Class[] classes = {DefaultPluginFactory.class}; + ExtensionLoader.ExtensionPluginWithContext context1 = mock(ExtensionLoader.ExtensionPluginWithContext.class); ExtensionLoader.ExtensionPluginWithContext context2 = mock(ExtensionLoader.ExtensionPluginWithContext.class); Comparator extensionsLoaderComparator = pluginCreatorContext.extensionsLoaderComparator(); assertNotNull(extensionsLoaderComparator); when(context1.isConfigured()).thenReturn(true); + when(context1.getDependentClasses()).thenReturn(classes); + when(context1.getProvidedClasses()).thenReturn(new Class[]{}); + when(context2.isConfigured()).thenReturn(true); - assertThat(extensionsLoaderComparator.compare(context1, context2), equalTo(0)); + when(context2.getProvidedClasses()).thenReturn(classes); + when(context2.getDependentClasses()).thenReturn(new Class[]{}); + assertThat(extensionsLoaderComparator.compare(context1, context2), equalTo(1)); + when(context1.isConfigured()).thenReturn(false); + when(context1.getDependentClasses()).thenReturn(new Class[]{}); + when(context1.getProvidedClasses()).thenReturn(classes); + when(context2.isConfigured()).thenReturn(false); - assertThat(extensionsLoaderComparator.compare(context1, context2), equalTo(0)); + when(context2.getProvidedClasses()).thenReturn(new Class[]{}); + when(context2.getDependentClasses()).thenReturn(classes); + assertThat(extensionsLoaderComparator.compare(context1, context2), equalTo(-1)); when(context1.isConfigured()).thenReturn(false); when(context2.isConfigured()).thenReturn(true); assertThat(extensionsLoaderComparator.compare(context1, context2), greaterThan(0)); diff --git a/data-prepper-plugins/aws-plugin/src/main/java/org/opensearch/dataprepper/plugins/aws/AwsPlugin.java b/data-prepper-plugins/aws-plugin/src/main/java/org/opensearch/dataprepper/plugins/aws/AwsPlugin.java index 44d2d22931..3d45f34bd6 100644 --- a/data-prepper-plugins/aws-plugin/src/main/java/org/opensearch/dataprepper/plugins/aws/AwsPlugin.java +++ b/data-prepper-plugins/aws-plugin/src/main/java/org/opensearch/dataprepper/plugins/aws/AwsPlugin.java @@ -5,8 +5,10 @@ package org.opensearch.dataprepper.plugins.aws; +import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier; import org.opensearch.dataprepper.model.annotations.DataPrepperExtensionPlugin; import org.opensearch.dataprepper.model.annotations.DataPrepperPluginConstructor; +import org.opensearch.dataprepper.model.annotations.ExtensionProvides; import org.opensearch.dataprepper.model.plugin.ExtensionPlugin; import org.opensearch.dataprepper.model.plugin.ExtensionPoints; @@ -15,6 +17,7 @@ * Data Prepper as an extension plugin. Everything starts from here. */ @DataPrepperExtensionPlugin(modelType = AwsPluginConfig.class, rootKeyJsonPath = "/aws/configurations") +@ExtensionProvides(providedClasses = {AwsCredentialsSupplier.class}) public class AwsPlugin implements ExtensionPlugin { private final DefaultAwsCredentialsSupplier defaultAwsCredentialsSupplier; diff --git a/data-prepper-plugins/aws-plugin/src/main/java/org/opensearch/dataprepper/plugins/aws/AwsSecretManagerConfiguration.java b/data-prepper-plugins/aws-plugin/src/main/java/org/opensearch/dataprepper/plugins/aws/AwsSecretManagerConfiguration.java index 4811719b5b..5a37b1b8a3 100644 --- a/data-prepper-plugins/aws-plugin/src/main/java/org/opensearch/dataprepper/plugins/aws/AwsSecretManagerConfiguration.java +++ b/data-prepper-plugins/aws-plugin/src/main/java/org/opensearch/dataprepper/plugins/aws/AwsSecretManagerConfiguration.java @@ -9,26 +9,19 @@ import jakarta.validation.constraints.NotNull; import jakarta.validation.constraints.Size; import org.hibernate.validator.constraints.time.DurationMin; -import software.amazon.awssdk.arns.Arn; +import org.opensearch.dataprepper.aws.api.AwsCredentialsOptions; +import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier; import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; -import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider; import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueRequest; import software.amazon.awssdk.services.secretsmanager.model.PutSecretValueRequest; -import software.amazon.awssdk.services.sts.StsClient; -import software.amazon.awssdk.services.sts.auth.StsAssumeRoleCredentialsProvider; -import software.amazon.awssdk.services.sts.model.AssumeRoleRequest; import java.time.Duration; import java.util.Map; -import java.util.Optional; -import java.util.UUID; public class AwsSecretManagerConfiguration { static final String DEFAULT_AWS_REGION = "us-east-1"; - private static final String AWS_IAM_ROLE = "role"; - private static final String AWS_IAM = "iam"; @JsonProperty("secret_id") @NotNull @@ -71,9 +64,15 @@ public boolean isDisableRefresh() { return disableRefresh; } - public SecretsManagerClient createSecretManagerClient() { + public SecretsManagerClient createSecretManagerClient(final AwsCredentialsSupplier awsCredentialsSupplier) { + final AwsCredentialsProvider awsCredentialsProvider = awsCredentialsSupplier.getProvider(AwsCredentialsOptions.builder() + .withRegion(this.awsRegion) + .withStsRoleArn(this.awsStsRoleArn) + .withStsHeaderOverrides(this.awsStsHeaderOverrides) + .build()); + return SecretsManagerClient.builder() - .credentialsProvider(authenticateAwsConfiguration()) + .credentialsProvider(awsCredentialsProvider) .region(getAwsRegion()) .build(); } @@ -90,56 +89,4 @@ public PutSecretValueRequest putSecretValueRequest(String secretKeyValueMapAsStr .secretString(secretKeyValueMapAsString) .build(); } - - private AwsCredentialsProvider authenticateAwsConfiguration() { - - final AwsCredentialsProvider awsCredentialsProvider; - if (awsStsRoleArn != null && !awsStsRoleArn.isEmpty()) { - - validateStsRoleArn(); - - final StsClient stsClient = StsClient.builder() - .region(getAwsRegion()) - .build(); - - AssumeRoleRequest.Builder assumeRoleRequestBuilder = AssumeRoleRequest.builder() - .roleSessionName("aws-secret-" + UUID.randomUUID()) - .roleArn(awsStsRoleArn); - - if (awsStsHeaderOverrides != null && !awsStsHeaderOverrides.isEmpty()) { - assumeRoleRequestBuilder = assumeRoleRequestBuilder.overrideConfiguration( - configuration -> awsStsHeaderOverrides.forEach(configuration::putHeader)); - } - - awsCredentialsProvider = StsAssumeRoleCredentialsProvider.builder() - .stsClient(stsClient) - .refreshRequest(assumeRoleRequestBuilder.build()) - .build(); - - } else { - // use default credential provider - awsCredentialsProvider = DefaultCredentialsProvider.create(); - } - - return awsCredentialsProvider; - } - - private void validateStsRoleArn() { - final Arn arn = getArn(); - if (!AWS_IAM.equals(arn.service())) { - throw new IllegalArgumentException("sts_role_arn must be an IAM Role"); - } - final Optional resourceType = arn.resource().resourceType(); - if (resourceType.isEmpty() || !resourceType.get().equals(AWS_IAM_ROLE)) { - throw new IllegalArgumentException("sts_role_arn must be an IAM Role"); - } - } - - private Arn getArn() { - try { - return Arn.fromString(awsStsRoleArn); - } catch (final Exception e) { - throw new IllegalArgumentException(String.format("Invalid ARN format for sts_role_arn. Check the format of %s", awsStsRoleArn)); - } - } } diff --git a/data-prepper-plugins/aws-plugin/src/main/java/org/opensearch/dataprepper/plugins/aws/AwsSecretPlugin.java b/data-prepper-plugins/aws-plugin/src/main/java/org/opensearch/dataprepper/plugins/aws/AwsSecretPlugin.java index 288f6d6e2b..f57e657656 100644 --- a/data-prepper-plugins/aws-plugin/src/main/java/org/opensearch/dataprepper/plugins/aws/AwsSecretPlugin.java +++ b/data-prepper-plugins/aws-plugin/src/main/java/org/opensearch/dataprepper/plugins/aws/AwsSecretPlugin.java @@ -6,9 +6,11 @@ package org.opensearch.dataprepper.plugins.aws; import com.fasterxml.jackson.databind.ObjectMapper; +import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier; import org.opensearch.dataprepper.metrics.PluginMetrics; import org.opensearch.dataprepper.model.annotations.DataPrepperExtensionPlugin; import org.opensearch.dataprepper.model.annotations.DataPrepperPluginConstructor; +import org.opensearch.dataprepper.model.annotations.ExtensionDependsOn; import org.opensearch.dataprepper.model.plugin.ExtensionPlugin; import org.opensearch.dataprepper.model.plugin.ExtensionPoints; import org.opensearch.dataprepper.model.plugin.PluginConfigPublisher; @@ -23,6 +25,7 @@ @DataPrepperExtensionPlugin(modelType = AwsSecretPluginConfig.class, rootKeyJsonPath = "/aws/secrets", allowInPipelineConfigurations = true) +@ExtensionDependsOn(dependentClasses = {AwsCredentialsSupplier.class}) public class AwsSecretPlugin implements ExtensionPlugin { static final int PERIOD_IN_SECONDS = 60; private static final Logger LOG = LoggerFactory.getLogger(AwsSecretPlugin.class); @@ -31,25 +34,20 @@ public class AwsSecretPlugin implements ExtensionPlugin { private PluginConfigPublisher pluginConfigPublisher; private SecretsSupplier secretsSupplier; private PluginMetrics pluginMetrics; - private final PluginConfigValueTranslator pluginConfigValueTranslator; + private PluginConfigValueTranslator pluginConfigValueTranslator; + + private final AwsSecretPluginConfig awsSecretPluginConfig; @DataPrepperPluginConstructor public AwsSecretPlugin(final AwsSecretPluginConfig awsSecretPluginConfig) { - if (awsSecretPluginConfig != null) { - final SecretValueDecoder secretValueDecoder = new SecretValueDecoder(); - secretsSupplier = new AwsSecretsSupplier(secretValueDecoder, awsSecretPluginConfig, OBJECT_MAPPER); - this.pluginConfigPublisher = new AwsSecretsPluginConfigPublisher(); - pluginConfigValueTranslator = new AwsSecretsPluginConfigValueTranslator(secretsSupplier); - scheduledExecutorService = Executors.newSingleThreadScheduledExecutor(); - pluginMetrics = PluginMetrics.fromNames("secrets", "aws"); - submitSecretsRefreshJobs(awsSecretPluginConfig, secretsSupplier); - } else { - pluginConfigValueTranslator = null; - } + this.awsSecretPluginConfig = awsSecretPluginConfig; } @Override public void apply(final ExtensionPoints extensionPoints) { + final AwsCredentialsSupplier awsCredentialsSupplier = extensionPoints.getExtensionProvider(AwsCredentialsSupplier.class); + initializePluginConfigValueTranslator(awsCredentialsSupplier); + extensionPoints.addExtensionProvider(new AwsSecretsPluginConfigValueTranslatorExtensionProvider(pluginConfigValueTranslator)); extensionPoints.addExtensionProvider(new AwsSecretsPluginConfigPublisherExtensionProvider( pluginConfigPublisher)); @@ -86,4 +84,18 @@ public void shutdown() { } } } + + private void initializePluginConfigValueTranslator(final AwsCredentialsSupplier awsCredentialsSupplier) { + if (awsSecretPluginConfig != null) { + final SecretValueDecoder secretValueDecoder = new SecretValueDecoder(); + secretsSupplier = new AwsSecretsSupplier(secretValueDecoder, awsSecretPluginConfig, OBJECT_MAPPER, awsCredentialsSupplier); + this.pluginConfigPublisher = new AwsSecretsPluginConfigPublisher(); + pluginConfigValueTranslator = new AwsSecretsPluginConfigValueTranslator(secretsSupplier); + scheduledExecutorService = Executors.newSingleThreadScheduledExecutor(); + pluginMetrics = PluginMetrics.fromNames("secrets", "aws"); + submitSecretsRefreshJobs(awsSecretPluginConfig, secretsSupplier); + } else { + pluginConfigValueTranslator = null; + } + } } diff --git a/data-prepper-plugins/aws-plugin/src/main/java/org/opensearch/dataprepper/plugins/aws/AwsSecretsSupplier.java b/data-prepper-plugins/aws-plugin/src/main/java/org/opensearch/dataprepper/plugins/aws/AwsSecretsSupplier.java index 8161893d29..48ca2b354f 100644 --- a/data-prepper-plugins/aws-plugin/src/main/java/org/opensearch/dataprepper/plugins/aws/AwsSecretsSupplier.java +++ b/data-prepper-plugins/aws-plugin/src/main/java/org/opensearch/dataprepper/plugins/aws/AwsSecretsSupplier.java @@ -14,6 +14,7 @@ import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; +import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier; import org.opensearch.dataprepper.model.plugin.FailedToUpdatePluginConfigValueException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -39,12 +40,14 @@ public class AwsSecretsSupplier implements SecretsSupplier { public AwsSecretsSupplier( final SecretValueDecoder secretValueDecoder, - final AwsSecretPluginConfig awsSecretPluginConfig, final ObjectMapper objectMapper) { + final AwsSecretPluginConfig awsSecretPluginConfig, + final ObjectMapper objectMapper, + final AwsCredentialsSupplier awsCredentialsSupplier) { this.secretValueDecoder = secretValueDecoder; this.objectMapper = objectMapper; awsSecretManagerConfigurationMap = awsSecretPluginConfig .getAwsSecretManagerConfigurationMap(); - secretsManagerClientMap = toSecretsManagerClientMap(awsSecretPluginConfig); + secretsManagerClientMap = toSecretsManagerClientMap(awsSecretPluginConfig, awsCredentialsSupplier); secretIdToValue = toSecretMap(awsSecretManagerConfigurationMap); } @@ -61,11 +64,12 @@ private ConcurrentMap toSecretMap( } private Map toSecretsManagerClientMap( - final AwsSecretPluginConfig awsSecretPluginConfig) { + final AwsSecretPluginConfig awsSecretPluginConfig, + final AwsCredentialsSupplier awsCredentialsSupplier) { return awsSecretPluginConfig.getAwsSecretManagerConfigurationMap().entrySet().stream() .collect(Collectors.toMap(Map.Entry::getKey, entry -> { final AwsSecretManagerConfiguration awsSecretManagerConfiguration = entry.getValue(); - return awsSecretManagerConfiguration.createSecretManagerClient(); + return awsSecretManagerConfiguration.createSecretManagerClient(awsCredentialsSupplier); })); } diff --git a/data-prepper-plugins/aws-plugin/src/test/java/org/opensearch/dataprepper/plugins/aws/AwsSecretManagerConfigurationTest.java b/data-prepper-plugins/aws-plugin/src/test/java/org/opensearch/dataprepper/plugins/aws/AwsSecretManagerConfigurationTest.java index 1cc3aa849e..c6f188c1f8 100644 --- a/data-prepper-plugins/aws-plugin/src/test/java/org/opensearch/dataprepper/plugins/aws/AwsSecretManagerConfigurationTest.java +++ b/data-prepper-plugins/aws-plugin/src/test/java/org/opensearch/dataprepper/plugins/aws/AwsSecretManagerConfigurationTest.java @@ -28,28 +28,26 @@ import org.mockito.Mock; import org.mockito.MockedStatic; import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.dataprepper.aws.api.AwsCredentialsOptions; +import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier; import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; -import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider; -import software.amazon.awssdk.awscore.AwsRequestOverrideConfiguration; import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClientBuilder; import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueRequest; import software.amazon.awssdk.services.secretsmanager.model.PutSecretValueRequest; import software.amazon.awssdk.services.sts.auth.StsAssumeRoleCredentialsProvider; -import software.amazon.awssdk.services.sts.model.AssumeRoleRequest; import java.io.IOException; import java.io.InputStream; import java.time.Duration; -import java.util.List; import java.util.Set; import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.CoreMatchers.instanceOf; import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.CoreMatchers.notNullValue; import static org.hamcrest.MatcherAssert.assertThat; -import static org.junit.jupiter.api.Assertions.assertThrows; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.mock; @@ -85,6 +83,9 @@ class AwsSecretManagerConfigurationTest { @Mock private SecretsManagerClient secretsManagerClient; + @Mock + private AwsCredentialsSupplier awsCredentialsSupplier; + @Captor private ArgumentCaptor awsCredentialsProviderArgumentCaptor; @@ -169,27 +170,6 @@ void testPutSecretValueRequest_construct_put_request(String secretValueToStore) verify(putSecretValueRequestBuilder).secretId("test-secret"); } - @Test - void testCreateSecretManagerClientWithDefaultCredential() throws IOException { - final InputStream inputStream = AwsSecretPluginConfigTest.class.getResourceAsStream( - "/test-aws-secret-manager-configuration-default.yaml"); - final AwsSecretManagerConfiguration awsSecretManagerConfiguration = objectMapper.readValue( - inputStream, AwsSecretManagerConfiguration.class); - assertThat(awsSecretManagerConfiguration.getAwsSecretId(), equalTo("test-secret")); - when(secretsManagerClientBuilder.region(any(Region.class))).thenReturn(secretsManagerClientBuilder); - when(secretsManagerClientBuilder.credentialsProvider(any(AwsCredentialsProvider.class))) - .thenReturn(secretsManagerClientBuilder); - when(secretsManagerClientBuilder.build()).thenReturn(secretsManagerClient); - try (final MockedStatic secretsManagerClientMockedStatic = mockStatic( - SecretsManagerClient.class)) { - secretsManagerClientMockedStatic.when(SecretsManagerClient::builder).thenReturn(secretsManagerClientBuilder); - assertThat(awsSecretManagerConfiguration.createSecretManagerClient(), is(secretsManagerClient)); - } - verify(secretsManagerClientBuilder).credentialsProvider(awsCredentialsProviderArgumentCaptor.capture()); - final AwsCredentialsProvider awsCredentialsProvider = awsCredentialsProviderArgumentCaptor.getValue(); - assertThat(awsCredentialsProvider, instanceOf(DefaultCredentialsProvider.class)); - } - @Test void testCreateSecretManagerClientWithStsCredential() throws IOException { final InputStream inputStream = AwsSecretPluginConfigTest.class.getResourceAsStream( @@ -198,78 +178,30 @@ void testCreateSecretManagerClientWithStsCredential() throws IOException { inputStream, AwsSecretManagerConfiguration.class); assertThat(awsSecretManagerConfiguration.getAwsSecretId(), equalTo("test-secret")); when(secretsManagerClientBuilder.region(any(Region.class))).thenReturn(secretsManagerClientBuilder); - when(secretsManagerClientBuilder.credentialsProvider(any(AwsCredentialsProvider.class))) - .thenReturn(secretsManagerClientBuilder); - when(secretsManagerClientBuilder.build()).thenReturn(secretsManagerClient); - try (final MockedStatic secretsManagerClientMockedStatic = mockStatic( - SecretsManagerClient.class)) { - secretsManagerClientMockedStatic.when(SecretsManagerClient::builder).thenReturn(secretsManagerClientBuilder); - assertThat(awsSecretManagerConfiguration.createSecretManagerClient(), is(secretsManagerClient)); - } - verify(secretsManagerClientBuilder).credentialsProvider(awsCredentialsProviderArgumentCaptor.capture()); - final AwsCredentialsProvider awsCredentialsProvider = awsCredentialsProviderArgumentCaptor.getValue(); - assertThat(awsCredentialsProvider, instanceOf(StsAssumeRoleCredentialsProvider.class)); - } - @Test - void testCreateSecretManagerClientWithStsHeaderOverrides() throws IOException { - final InputStream inputStream = AwsSecretPluginConfigTest.class.getResourceAsStream( - "/test-aws-secret-manager-configuration-with-sts-headers.yaml"); - final AwsSecretManagerConfiguration awsSecretManagerConfiguration = objectMapper.readValue( - inputStream, AwsSecretManagerConfiguration.class); - assertThat(awsSecretManagerConfiguration.getAwsSecretId(), equalTo("test-secret")); - final StsAssumeRoleCredentialsProvider.Builder stsAssumeRoleCredentialsProviderBuilder = - mock(StsAssumeRoleCredentialsProvider.Builder.class); - final StsAssumeRoleCredentialsProvider stsAssumeRoleCredentialsProvider = - mock(StsAssumeRoleCredentialsProvider.class); - when(stsAssumeRoleCredentialsProviderBuilder.stsClient(any())) - .thenReturn(stsAssumeRoleCredentialsProviderBuilder); - when(stsAssumeRoleCredentialsProviderBuilder.refreshRequest(any(AssumeRoleRequest.class))) - .thenReturn(stsAssumeRoleCredentialsProviderBuilder); - when(stsAssumeRoleCredentialsProviderBuilder.build()).thenReturn(stsAssumeRoleCredentialsProvider); - when(secretsManagerClientBuilder.region(any(Region.class))).thenReturn(secretsManagerClientBuilder); - when(secretsManagerClientBuilder.credentialsProvider(any(AwsCredentialsProvider.class))) + final StsAssumeRoleCredentialsProvider stsAssumeRoleCredentialsProvider = mock(StsAssumeRoleCredentialsProvider.class); + final ArgumentCaptor awsCredentialsOptionsArgumentCaptor = ArgumentCaptor.forClass(AwsCredentialsOptions.class); + when(awsCredentialsSupplier.getProvider(awsCredentialsOptionsArgumentCaptor.capture())) + .thenReturn(stsAssumeRoleCredentialsProvider); + + when(secretsManagerClientBuilder.credentialsProvider(stsAssumeRoleCredentialsProvider)) .thenReturn(secretsManagerClientBuilder); when(secretsManagerClientBuilder.build()).thenReturn(secretsManagerClient); try (final MockedStatic secretsManagerClientMockedStatic = mockStatic( - SecretsManagerClient.class); - final MockedStatic stsAssumeRoleCredentialsProviderMockedStatic = - mockStatic(StsAssumeRoleCredentialsProvider.class)) { + SecretsManagerClient.class)) { secretsManagerClientMockedStatic.when(SecretsManagerClient::builder).thenReturn(secretsManagerClientBuilder); - stsAssumeRoleCredentialsProviderMockedStatic.when(StsAssumeRoleCredentialsProvider::builder).thenReturn( - stsAssumeRoleCredentialsProviderBuilder); - assertThat(awsSecretManagerConfiguration.createSecretManagerClient(), is(secretsManagerClient)); + assertThat(awsSecretManagerConfiguration.createSecretManagerClient(awsCredentialsSupplier), is(secretsManagerClient)); } verify(secretsManagerClientBuilder).credentialsProvider(awsCredentialsProviderArgumentCaptor.capture()); final AwsCredentialsProvider awsCredentialsProvider = awsCredentialsProviderArgumentCaptor.getValue(); assertThat(awsCredentialsProvider, instanceOf(StsAssumeRoleCredentialsProvider.class)); - final ArgumentCaptor assumeRoleRequestArgumentCaptor = - ArgumentCaptor.forClass(AssumeRoleRequest.class); - verify(stsAssumeRoleCredentialsProviderBuilder).refreshRequest(assumeRoleRequestArgumentCaptor.capture()); - final AssumeRoleRequest assumeRoleRequest = assumeRoleRequestArgumentCaptor.getValue(); - assertThat(assumeRoleRequest.overrideConfiguration().isPresent(), is(true)); - final AwsRequestOverrideConfiguration awsRequestOverrideConfiguration = assumeRoleRequest - .overrideConfiguration().get(); - assertThat(awsRequestOverrideConfiguration.headers().size(), equalTo(1)); - assertThat(awsRequestOverrideConfiguration.headers().get("test-header"), equalTo(List.of("test-value"))); - } - @ParameterizedTest - @ValueSource(strings = { - "/test-aws-secret-manager-configuration-invalid-sts-1.yaml", - "/test-aws-secret-manager-configuration-invalid-sts-2.yaml", - "/test-aws-secret-manager-configuration-invalid-sts-3.yaml" - }) - void testCreateSecretManagerClientWithInvalidStsRoleArn(final String testFileName) throws IOException { - final InputStream inputStream = AwsSecretPluginConfigTest.class.getResourceAsStream(testFileName); - final AwsSecretManagerConfiguration awsSecretManagerConfiguration = objectMapper.readValue( - inputStream, AwsSecretManagerConfiguration.class); - try (final MockedStatic secretsManagerClientMockedStatic = mockStatic( - SecretsManagerClient.class)) { - secretsManagerClientMockedStatic.when(SecretsManagerClient::builder).thenReturn(secretsManagerClientBuilder); - assertThrows(IllegalArgumentException.class, - () -> awsSecretManagerConfiguration.createSecretManagerClient()); - } + final AwsCredentialsOptions awsCredentialsOptions = awsCredentialsOptionsArgumentCaptor.getValue(); + assertThat(awsCredentialsOptions, notNullValue()); + assertThat(awsCredentialsOptions.getRegion(), equalTo(Region.US_EAST_1)); + assertThat(awsCredentialsOptions.getStsHeaderOverrides(), notNullValue()); + assertThat(awsCredentialsOptions.getStsHeaderOverrides().get("test_key"), equalTo("test_value")); + assertThat(awsCredentialsOptions.getStsRoleArn(), equalTo("arn:aws:iam::123456789012:role/test-role")); } @Test diff --git a/data-prepper-plugins/aws-plugin/src/test/java/org/opensearch/dataprepper/plugins/aws/AwsSecretPluginIT.java b/data-prepper-plugins/aws-plugin/src/test/java/org/opensearch/dataprepper/plugins/aws/AwsSecretPluginIT.java index a9f434d3e8..508839a5fa 100644 --- a/data-prepper-plugins/aws-plugin/src/test/java/org/opensearch/dataprepper/plugins/aws/AwsSecretPluginIT.java +++ b/data-prepper-plugins/aws-plugin/src/test/java/org/opensearch/dataprepper/plugins/aws/AwsSecretPluginIT.java @@ -5,6 +5,7 @@ package org.opensearch.dataprepper.plugins.aws; +import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.ArgumentCaptor; @@ -12,6 +13,7 @@ import org.mockito.Mock; import org.mockito.MockedStatic; import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier; import org.opensearch.dataprepper.model.plugin.ExtensionPoints; import org.opensearch.dataprepper.model.plugin.ExtensionProvider; import org.opensearch.dataprepper.model.plugin.PluginConfigPublisher; @@ -71,6 +73,9 @@ class AwsSecretPluginIT { @Mock private ScheduledExecutorService scheduledExecutorService; + @Mock + private AwsCredentialsSupplier awsCredentialsSupplier; + @Captor private ArgumentCaptor initialDelayCaptor; @@ -79,13 +84,18 @@ class AwsSecretPluginIT { private AwsSecretPlugin objectUnderTest; + @BeforeEach + void setUp() { + when(extensionPoints.getExtensionProvider(AwsCredentialsSupplier.class)).thenReturn(awsCredentialsSupplier); + } + @Test void testInitializationWithNonNullConfig() { final Duration testInterval = Duration.ofHours(2); when(awsSecretPluginConfig.getAwsSecretManagerConfigurationMap()).thenReturn( Map.of(TEST_SECRET_CONFIG_ID, awsSecretManagerConfiguration)); when(awsSecretManagerConfiguration.getRefreshInterval()).thenReturn(testInterval); - when(awsSecretManagerConfiguration.createSecretManagerClient()).thenReturn(secretsManagerClient); + when(awsSecretManagerConfiguration.createSecretManagerClient(awsCredentialsSupplier)).thenReturn(secretsManagerClient); when(awsSecretManagerConfiguration.createGetSecretValueRequest()).thenReturn(getSecretValueRequest); when(secretsManagerClient.getSecretValue(eq(getSecretValueRequest))).thenReturn(getSecretValueResponse); when(getSecretValueResponse.secretString()).thenReturn(UUID.randomUUID().toString()); @@ -119,7 +129,7 @@ void testInitializationWithDisableRefresh() { when(awsSecretPluginConfig.getAwsSecretManagerConfigurationMap()).thenReturn( Map.of(TEST_SECRET_CONFIG_ID, awsSecretManagerConfiguration)); when(awsSecretManagerConfiguration.isDisableRefresh()).thenReturn(true); - when(awsSecretManagerConfiguration.createSecretManagerClient()).thenReturn(secretsManagerClient); + when(awsSecretManagerConfiguration.createSecretManagerClient(awsCredentialsSupplier)).thenReturn(secretsManagerClient); when(awsSecretManagerConfiguration.createGetSecretValueRequest()).thenReturn(getSecretValueRequest); when(secretsManagerClient.getSecretValue(eq(getSecretValueRequest))).thenReturn(getSecretValueResponse); when(getSecretValueResponse.secretString()).thenReturn(UUID.randomUUID().toString()); @@ -171,6 +181,7 @@ void testShutdownAwaitTerminationSuccess() throws InterruptedException { executorsMockedStatic.when(Executors::newSingleThreadScheduledExecutor) .thenReturn(scheduledExecutorService); objectUnderTest = new AwsSecretPlugin(awsSecretPluginConfig); + objectUnderTest.apply(extensionPoints); } when(scheduledExecutorService.awaitTermination(anyLong(), any(TimeUnit.class))).thenReturn(true); objectUnderTest.shutdown(); @@ -189,6 +200,7 @@ void testShutdownAwaitTerminationTimeout() throws InterruptedException { executorsMockedStatic.when(Executors::newSingleThreadScheduledExecutor) .thenReturn(scheduledExecutorService); objectUnderTest = new AwsSecretPlugin(awsSecretPluginConfig); + objectUnderTest.apply(extensionPoints); } when(scheduledExecutorService.awaitTermination(anyLong(), any(TimeUnit.class))).thenReturn(false); objectUnderTest.shutdown(); @@ -207,6 +219,7 @@ void testShutdownAwaitTerminationInterrupted() throws InterruptedException { executorsMockedStatic.when(Executors::newSingleThreadScheduledExecutor) .thenReturn(scheduledExecutorService); objectUnderTest = new AwsSecretPlugin(awsSecretPluginConfig); + objectUnderTest.apply(extensionPoints); } when(scheduledExecutorService.awaitTermination(anyLong(), any(TimeUnit.class))) .thenThrow(new InterruptedException()); @@ -224,6 +237,7 @@ void testShutdownWithNullScheduledExecutorService() { executorsMockedStatic.when(Executors::newSingleThreadScheduledExecutor) .thenReturn(scheduledExecutorService); objectUnderTest = new AwsSecretPlugin(null); + objectUnderTest.apply(extensionPoints); } objectUnderTest.shutdown(); verifyNoInteractions(scheduledExecutorService); diff --git a/data-prepper-plugins/aws-plugin/src/test/java/org/opensearch/dataprepper/plugins/aws/AwsSecretsSupplierTest.java b/data-prepper-plugins/aws-plugin/src/test/java/org/opensearch/dataprepper/plugins/aws/AwsSecretsSupplierTest.java index 8fc8a07d33..f72b8c48d6 100644 --- a/data-prepper-plugins/aws-plugin/src/test/java/org/opensearch/dataprepper/plugins/aws/AwsSecretsSupplierTest.java +++ b/data-prepper-plugins/aws-plugin/src/test/java/org/opensearch/dataprepper/plugins/aws/AwsSecretsSupplierTest.java @@ -21,6 +21,7 @@ import org.mockito.ArgumentMatchers; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueRequest; import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueResponse; @@ -75,6 +76,9 @@ class AwsSecretsSupplierTest { @Mock private SecretsManagerException secretsManagerException; + @Mock + private AwsCredentialsSupplier awsCredentialsSupplier; + private AwsSecretsSupplier objectUnderTest; @BeforeEach @@ -83,12 +87,16 @@ void setUp() throws JsonProcessingException { when(awsSecretPluginConfig.getAwsSecretManagerConfigurationMap()).thenReturn( Map.of(TEST_AWS_SECRET_CONFIGURATION_NAME, awsSecretManagerConfiguration) ); - when(awsSecretManagerConfiguration.createSecretManagerClient()).thenReturn(secretsManagerClient); + when(awsSecretManagerConfiguration.createSecretManagerClient(awsCredentialsSupplier)).thenReturn(secretsManagerClient); when(secretValueDecoder.decode(eq(getSecretValueResponse))).thenReturn(OBJECT_MAPPER.writeValueAsString( Map.of(TEST_KEY, TEST_VALUE) )); when(secretsManagerClient.getSecretValue(eq(getSecretValueRequest))).thenReturn(getSecretValueResponse); - objectUnderTest = new AwsSecretsSupplier(secretValueDecoder, awsSecretPluginConfig, OBJECT_MAPPER); + objectUnderTest = createObjectUnderTest(); + } + + private AwsSecretsSupplier createObjectUnderTest() { + return new AwsSecretsSupplier(secretValueDecoder, awsSecretPluginConfig, OBJECT_MAPPER, awsCredentialsSupplier); } @Test @@ -111,7 +119,7 @@ void testRetrieveValueMissingKey() { @Test void testRetrieveValueInvalidKeyValuePair() { when(secretValueDecoder.decode(eq(getSecretValueResponse))).thenReturn(TEST_VALUE); - objectUnderTest = new AwsSecretsSupplier(secretValueDecoder, awsSecretPluginConfig, OBJECT_MAPPER); + objectUnderTest = new AwsSecretsSupplier(secretValueDecoder, awsSecretPluginConfig, OBJECT_MAPPER, awsCredentialsSupplier); final Exception exception = assertThrows(IllegalArgumentException.class, () -> objectUnderTest.retrieveValue(TEST_AWS_SECRET_CONFIGURATION_NAME, TEST_KEY)); assertThat(exception.getMessage(), equalTo(String.format("The value under secretId: %s is not a valid json.", @@ -132,7 +140,7 @@ void testRetrieveValueBySecretIdOnlyNotSerializable() throws JsonProcessingExcep when(mockedObjectMapper.readValue(eq(testValue), eq(MAP_TYPE_REFERENCE))).thenReturn(Map.of("a", "b")); when(mockedObjectMapper.writeValueAsString(ArgumentMatchers.any())).thenThrow(mockedJsonProcessingException); when(secretValueDecoder.decode(eq(getSecretValueResponse))).thenReturn(testValue); - objectUnderTest = new AwsSecretsSupplier(secretValueDecoder, awsSecretPluginConfig, mockedObjectMapper); + objectUnderTest = new AwsSecretsSupplier(secretValueDecoder, awsSecretPluginConfig, mockedObjectMapper, awsCredentialsSupplier); final Exception exception = assertThrows(IllegalArgumentException.class, () -> objectUnderTest.retrieveValue(TEST_AWS_SECRET_CONFIGURATION_NAME)); assertThat(exception.getMessage(), equalTo(String.format("Unable to read the value under secretId: %s as string.", @@ -143,7 +151,7 @@ void testRetrieveValueBySecretIdOnlyNotSerializable() throws JsonProcessingExcep @ValueSource(strings = {TEST_VALUE, "{\"a\":\"b\"}"}) void testRetrieveValueWithoutKey(String testValue) { when(secretValueDecoder.decode(eq(getSecretValueResponse))).thenReturn(testValue); - objectUnderTest = new AwsSecretsSupplier(secretValueDecoder, awsSecretPluginConfig, OBJECT_MAPPER); + objectUnderTest = createObjectUnderTest(); assertThat(objectUnderTest.retrieveValue(TEST_AWS_SECRET_CONFIGURATION_NAME), equalTo(testValue)); } @@ -151,14 +159,14 @@ void testRetrieveValueWithoutKey(String testValue) { void testConstructorWithGetSecretValueFailure() { when(secretsManagerClient.getSecretValue(eq(getSecretValueRequest))).thenThrow(secretsManagerException); assertThrows(RuntimeException.class, () -> new AwsSecretsSupplier( - secretValueDecoder, awsSecretPluginConfig, OBJECT_MAPPER)); + secretValueDecoder, awsSecretPluginConfig, OBJECT_MAPPER, awsCredentialsSupplier)); } @Test void testRefreshSecretsWithKey() { final String testValue = "{\"key\":\"oldValue\"}"; when(secretValueDecoder.decode(eq(getSecretValueResponse))).thenReturn(testValue); - objectUnderTest = new AwsSecretsSupplier(secretValueDecoder, awsSecretPluginConfig, OBJECT_MAPPER); + objectUnderTest = createObjectUnderTest(); assertThat(objectUnderTest.retrieveValue(TEST_AWS_SECRET_CONFIGURATION_NAME, "key"), equalTo("oldValue")); final String newTestValue = "{\"key\":\"newValue\"}"; @@ -172,7 +180,7 @@ void testRefreshSecretsWithKey() { void testRefreshSecretsWithoutKey() { final String testValue = UUID.randomUUID().toString(); when(secretValueDecoder.decode(eq(getSecretValueResponse))).thenReturn(testValue); - objectUnderTest = new AwsSecretsSupplier(secretValueDecoder, awsSecretPluginConfig, OBJECT_MAPPER); + objectUnderTest = createObjectUnderTest(); assertThat(objectUnderTest.retrieveValue(TEST_AWS_SECRET_CONFIGURATION_NAME), equalTo(testValue)); final String newTestValue = testValue + "-mutated"; when(secretValueDecoder.decode(eq(getSecretValueResponse))).thenReturn(newTestValue); @@ -187,7 +195,7 @@ void testUpdateValue_successfully_updated(String valueToSet) { when(secretsManagerClient.putSecretValue(eq(putSecretValueRequest))).thenReturn(putSecretValueResponse); String newVersionId = UUID.randomUUID().toString(); when(putSecretValueResponse.versionId()).thenReturn(newVersionId); - objectUnderTest = new AwsSecretsSupplier(secretValueDecoder, awsSecretPluginConfig, OBJECT_MAPPER); + objectUnderTest = createObjectUnderTest(); assertThat(objectUnderTest.updateValue(TEST_AWS_SECRET_CONFIGURATION_NAME, "key", valueToSet), equalTo(newVersionId)); } @@ -195,7 +203,7 @@ void testUpdateValue_successfully_updated(String valueToSet) { @Test void testUpdateValue_null_key_throws_exception() { when(secretsManagerClient.getSecretValue(eq(getSecretValueRequest))).thenReturn(getSecretValueResponse); - objectUnderTest = new AwsSecretsSupplier(secretValueDecoder, awsSecretPluginConfig, OBJECT_MAPPER); + objectUnderTest = createObjectUnderTest(); assertThrows(IllegalArgumentException.class, () -> objectUnderTest.updateValue(TEST_AWS_SECRET_CONFIGURATION_NAME, "newValue")); } @@ -207,14 +215,14 @@ void testUpdateValue_null_key_doesnot_throws_exception_when_value_is_not_key_val when(awsSecretPluginConfig.getAwsSecretManagerConfigurationMap()).thenReturn( Map.of(TEST_AWS_SECRET_CONFIGURATION_NAME, awsSecretManagerConfiguration) ); - when(awsSecretManagerConfiguration.createSecretManagerClient()).thenReturn(secretsManagerClient); + when(awsSecretManagerConfiguration.createSecretManagerClient(awsCredentialsSupplier)).thenReturn(secretsManagerClient); when(secretValueDecoder.decode(eq(getSecretValueResponse))).thenReturn(TEST_VALUE); when(secretsManagerClient.getSecretValue(eq(getSecretValueRequest))).thenReturn(getSecretValueResponse); when(awsSecretManagerConfiguration.putSecretValueRequest(any())).thenReturn(putSecretValueRequest); when(secretsManagerClient.putSecretValue(eq(putSecretValueRequest))).thenReturn(putSecretValueResponse); String versionId = UUID.randomUUID().toString(); when(putSecretValueResponse.versionId()).thenReturn(versionId); - objectUnderTest = new AwsSecretsSupplier(secretValueDecoder, awsSecretPluginConfig, OBJECT_MAPPER); + objectUnderTest = createObjectUnderTest(); String newValue = objectUnderTest.updateValue(TEST_AWS_SECRET_CONFIGURATION_NAME, secretValueToSet); assertEquals(versionId, newValue); } @@ -226,7 +234,7 @@ void testUpdateValue_failed_to_update() { final String testValue = "{\"key\":\"oldValue\"}"; when(secretValueDecoder.decode(eq(getSecretValueResponse))).thenReturn(testValue); when(putSecretValueResponse.versionId()).thenThrow(RuntimeException.class); - objectUnderTest = new AwsSecretsSupplier(secretValueDecoder, awsSecretPluginConfig, OBJECT_MAPPER); + objectUnderTest = createObjectUnderTest(); assertThrows(RuntimeException.class, () -> objectUnderTest.updateValue(TEST_AWS_SECRET_CONFIGURATION_NAME, "key", "newValue")); } diff --git a/data-prepper-plugins/aws-plugin/src/test/resources/test-aws-secret-manager-configuration-with-sts.yaml b/data-prepper-plugins/aws-plugin/src/test/resources/test-aws-secret-manager-configuration-with-sts.yaml index 8bf929a56b..2c77a394a0 100644 --- a/data-prepper-plugins/aws-plugin/src/test/resources/test-aws-secret-manager-configuration-with-sts.yaml +++ b/data-prepper-plugins/aws-plugin/src/test/resources/test-aws-secret-manager-configuration-with-sts.yaml @@ -1,3 +1,5 @@ secret_id: test-secret region: us-east-1 -sts_role_arn: arn:aws:iam::123456789012:role/test-role \ No newline at end of file +sts_role_arn: arn:aws:iam::123456789012:role/test-role +sts_header_overrides: + test_key: "test_value" \ No newline at end of file