diff --git a/data-prepper-plugins/aws-plugin-api/src/main/java/org/opensearch/dataprepper/aws/api/AwsCredentialsSupplier.java b/data-prepper-plugins/aws-plugin-api/src/main/java/org/opensearch/dataprepper/aws/api/AwsCredentialsSupplier.java index 0b3adf7107..9ad9d2533c 100644 --- a/data-prepper-plugins/aws-plugin-api/src/main/java/org/opensearch/dataprepper/aws/api/AwsCredentialsSupplier.java +++ b/data-prepper-plugins/aws-plugin-api/src/main/java/org/opensearch/dataprepper/aws/api/AwsCredentialsSupplier.java @@ -9,6 +9,7 @@ import software.amazon.awssdk.regions.Region; import java.util.Optional; +import java.util.Map; /** * An interface available to plugins via the AWS Plugin Extension which supplies @@ -29,6 +30,12 @@ public interface AwsCredentialsSupplier { */ Optional getDefaultRegion(); + /** + * Gets the default STS header overrides if configured. Otherwise returns empty Optional + * @return Optional containing Map of STS header overrides + */ + Optional> getDefaultStsHeaderOverrides(); + /** * Gets the default STS role ARN if it is configured. Otherwise returns empty Optional * @return Default STS role ARN as String diff --git a/data-prepper-plugins/aws-plugin/src/main/java/org/opensearch/dataprepper/plugins/aws/AwsStsConfiguration.java b/data-prepper-plugins/aws-plugin/src/main/java/org/opensearch/dataprepper/plugins/aws/AwsStsConfiguration.java index ee22244fa7..3528081c36 100644 --- a/data-prepper-plugins/aws-plugin/src/main/java/org/opensearch/dataprepper/plugins/aws/AwsStsConfiguration.java +++ b/data-prepper-plugins/aws-plugin/src/main/java/org/opensearch/dataprepper/plugins/aws/AwsStsConfiguration.java @@ -9,6 +9,8 @@ import jakarta.validation.constraints.Size; import software.amazon.awssdk.regions.Region; +import java.util.Map; + public class AwsStsConfiguration { @JsonProperty("region") @@ -19,6 +21,10 @@ public class AwsStsConfiguration { @Size(min = 20, max = 2048, message = "awsStsRoleArn length should be between 1 and 2048 characters") private String awsStsRoleArn; + @JsonProperty("sts_header_overrides") + @Size(max = 5, message = "sts_header_overrides supports a maximum of 5 headers to override") + private Map stsHeaderOverrides; + public Region getAwsRegion() { return awsRegion != null ? Region.of(awsRegion) : null; } @@ -26,4 +32,8 @@ public Region getAwsRegion() { public String getAwsStsRoleArn() { return awsStsRoleArn; } + + public Map getStsHeaderOverrides() { + return stsHeaderOverrides; + } } diff --git a/data-prepper-plugins/aws-plugin/src/main/java/org/opensearch/dataprepper/plugins/aws/CredentialsProviderFactory.java b/data-prepper-plugins/aws-plugin/src/main/java/org/opensearch/dataprepper/plugins/aws/CredentialsProviderFactory.java index 3084f92ae8..2a017475e2 100644 --- a/data-prepper-plugins/aws-plugin/src/main/java/org/opensearch/dataprepper/plugins/aws/CredentialsProviderFactory.java +++ b/data-prepper-plugins/aws-plugin/src/main/java/org/opensearch/dataprepper/plugins/aws/CredentialsProviderFactory.java @@ -47,6 +47,10 @@ Region getDefaultRegion() { return defaultStsConfiguration.getAwsRegion(); } + Map getDefaultStsHeaderOverrides() { + return defaultStsConfiguration.getStsHeaderOverrides(); + } + String getDefaultStsRoleArn() { return defaultStsConfiguration.getAwsStsRoleArn(); } @@ -85,7 +89,8 @@ private AwsCredentialsProvider createStsCredentials(final AwsCredentialsOptions assumeRoleRequestBuilder = assumeRoleRequestBuilder.externalId(credentialsOptions.getStsExternalId()); } - final Map awsStsHeaderOverrides = credentialsOptions.getStsHeaderOverrides(); + final Map awsStsHeaderOverrides = credentialsOptions.getStsHeaderOverrides() != null ? + credentialsOptions.getStsHeaderOverrides() : defaultStsConfiguration.getStsHeaderOverrides(); if(awsStsHeaderOverrides != null && !awsStsHeaderOverrides.isEmpty()) { assumeRoleRequestBuilder = assumeRoleRequestBuilder diff --git a/data-prepper-plugins/aws-plugin/src/main/java/org/opensearch/dataprepper/plugins/aws/DefaultAwsCredentialsSupplier.java b/data-prepper-plugins/aws-plugin/src/main/java/org/opensearch/dataprepper/plugins/aws/DefaultAwsCredentialsSupplier.java index 00bdb70670..76ba179afc 100644 --- a/data-prepper-plugins/aws-plugin/src/main/java/org/opensearch/dataprepper/plugins/aws/DefaultAwsCredentialsSupplier.java +++ b/data-prepper-plugins/aws-plugin/src/main/java/org/opensearch/dataprepper/plugins/aws/DefaultAwsCredentialsSupplier.java @@ -10,6 +10,7 @@ import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; import software.amazon.awssdk.regions.Region; +import java.util.Map; import java.util.Optional; class DefaultAwsCredentialsSupplier implements AwsCredentialsSupplier { @@ -32,6 +33,11 @@ public Optional getDefaultRegion() { return Optional.ofNullable(credentialsProviderFactory.getDefaultRegion()); } + @Override + public Optional> getDefaultStsHeaderOverrides() { + return Optional.ofNullable(credentialsProviderFactory.getDefaultStsHeaderOverrides()); + } + @Override public Optional getDefaultStsRoleArn() { return Optional.ofNullable(credentialsProviderFactory.getDefaultStsRoleArn()); diff --git a/data-prepper-plugins/aws-plugin/src/test/java/org/opensearch/dataprepper/plugins/aws/AwsStsConfigurationTest.java b/data-prepper-plugins/aws-plugin/src/test/java/org/opensearch/dataprepper/plugins/aws/AwsStsConfigurationTest.java index f44e4dd932..10228f0a51 100644 --- a/data-prepper-plugins/aws-plugin/src/test/java/org/opensearch/dataprepper/plugins/aws/AwsStsConfigurationTest.java +++ b/data-prepper-plugins/aws-plugin/src/test/java/org/opensearch/dataprepper/plugins/aws/AwsStsConfigurationTest.java @@ -7,6 +7,7 @@ import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; +import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.MethodSource; import software.amazon.awssdk.regions.Region; @@ -16,6 +17,7 @@ import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.notNullValue; +import static org.hamcrest.Matchers.nullValue; public class AwsStsConfigurationTest { @@ -34,6 +36,36 @@ void testStsConfiguration(final Region region) throws JsonProcessingException { assertThat(objectUnderTest.getAwsRegion(), equalTo(region)); } + @Test + void testStsConfigurationWithHeaderOverrides() throws JsonProcessingException { + final String configWithHeaderOverrides = + "{\"region\": \"us-west-2\", " + + "\"sts_role_arn\": \"arn:aws:iam::123456789012:role/test-role\", " + + "\"sts_header_overrides\": {\"header1\": \"value1\", \"header2\": \"value2\"}}"; + + final AwsStsConfiguration objectUnderTest = OBJECT_MAPPER.readValue(configWithHeaderOverrides, AwsStsConfiguration.class); + + assertThat(objectUnderTest, notNullValue()); + assertThat(objectUnderTest.getAwsStsRoleArn(), equalTo("arn:aws:iam::123456789012:role/test-role")); + assertThat(objectUnderTest.getAwsRegion(), equalTo(Region.US_WEST_2)); + assertThat(objectUnderTest.getStsHeaderOverrides(), notNullValue()); + assertThat(objectUnderTest.getStsHeaderOverrides().size(), equalTo(2)); + assertThat(objectUnderTest.getStsHeaderOverrides().get("header1"), equalTo("value1")); + assertThat(objectUnderTest.getStsHeaderOverrides().get("header2"), equalTo("value2")); + } + + @Test + void testStsConfigurationWithoutHeaderOverrides() throws JsonProcessingException { + final String configWithoutHeaderOverrides = + "{\"region\": \"us-west-2\", " + + "\"sts_role_arn\": \"arn:aws:iam::123456789012:role/test-role\"}"; + + final AwsStsConfiguration objectUnderTest = OBJECT_MAPPER.readValue(configWithoutHeaderOverrides, AwsStsConfiguration.class); + + assertThat(objectUnderTest, notNullValue()); + assertThat(objectUnderTest.getStsHeaderOverrides(), nullValue()); + } + private static List getRegions() { return Region.regions(); } diff --git a/data-prepper-plugins/aws-plugin/src/test/java/org/opensearch/dataprepper/plugins/aws/CredentialsProviderFactoryTest.java b/data-prepper-plugins/aws-plugin/src/test/java/org/opensearch/dataprepper/plugins/aws/CredentialsProviderFactoryTest.java index c53b266535..24f4de7444 100644 --- a/data-prepper-plugins/aws-plugin/src/test/java/org/opensearch/dataprepper/plugins/aws/CredentialsProviderFactoryTest.java +++ b/data-prepper-plugins/aws-plugin/src/test/java/org/opensearch/dataprepper/plugins/aws/CredentialsProviderFactoryTest.java @@ -116,6 +116,33 @@ void getDefaultRegion_returns_expected_region(final Region region) { assertThat(actualRegion, equalTo(region)); } + @Test + void getDefaultStsHeaderOverrides_returns_expected_headers() { + final Map headerOverrides = Map.of( + "header1", "value1", + "header2", "value2", + "custom-header", "custom-value" + ); + when(defaultStsConfiguration.getStsHeaderOverrides()).thenReturn(headerOverrides); + + final CredentialsProviderFactory credentialsProviderFactory = createObjectUnderTest(); + + final Map actualHeaderOverrides = credentialsProviderFactory.getDefaultStsHeaderOverrides(); + + assertThat(actualHeaderOverrides, equalTo(headerOverrides)); + } + + @Test + void getDefaultStsHeaderOverrides_returns_null_when_not_configured() { + when(defaultStsConfiguration.getStsHeaderOverrides()).thenReturn(null); + + final CredentialsProviderFactory credentialsProviderFactory = createObjectUnderTest(); + + final Map actualHeaderOverrides = credentialsProviderFactory.getDefaultStsHeaderOverrides(); + + assertThat(actualHeaderOverrides, nullValue()); + } + @Test void getDefaultStsRoleArn_returns_from_default_configuration() { final String roleArn = "arn:aws:iam::123456789012:role/test-role"; @@ -390,6 +417,47 @@ void providerFromOptions_should_create_StsClient_with_correct_backoff(final Stri assertThat(retryPolicy.throttlingBackoffStrategy(), sameInstance(backoffStrategy)); } + + @Test + void providerFromOptions_should_use_default_STS_Headers_when_credentialsOptions_HeaderOverrides_are_null() { + final String defaultHeaderName = "default-header"; + final String defaultHeaderValue = "default-value"; + final Map defaultHeaders = Map.of(defaultHeaderName, defaultHeaderValue); + + when(awsCredentialsOptions.getRegion()).thenReturn(Region.US_EAST_1); + when(awsCredentialsOptions.getStsRoleArn()).thenReturn(testStsRole); + when(awsCredentialsOptions.getStsHeaderOverrides()).thenReturn(null); + when(defaultStsConfiguration.getStsHeaderOverrides()).thenReturn(defaultHeaders); + + when(stsClientBuilder.region(Region.US_EAST_1)).thenReturn(stsClientBuilder); + + final CredentialsProviderFactory objectUnderTest = createObjectUnderTest(); + final AwsCredentialsProvider actualCredentialsProvider; + try (final MockedStatic stsClientMockedStatic = mockStatic(StsClient.class); + final MockedStatic credentialsProviderMockedStatic = mockStatic(StsAssumeRoleCredentialsProvider.class)) { + stsClientMockedStatic.when(StsClient::builder).thenReturn(stsClientBuilder); + credentialsProviderMockedStatic.when(StsAssumeRoleCredentialsProvider::builder).thenReturn(stsCredentialsProviderBuilder); + actualCredentialsProvider = objectUnderTest.providerFromOptions(awsCredentialsOptions); + } + + assertThat(actualCredentialsProvider, instanceOf(StsAssumeRoleCredentialsProvider.class)); + + final ArgumentCaptor assumeRoleRequestArgumentCaptor = ArgumentCaptor.forClass(AssumeRoleRequest.class); + verify(stsCredentialsProviderBuilder).refreshRequest(assumeRoleRequestArgumentCaptor.capture()); + + final AssumeRoleRequest actualAssumeRoleRequest = assumeRoleRequestArgumentCaptor.getValue(); + assertThat(actualAssumeRoleRequest.roleArn(), equalTo(testStsRole)); + assertThat(actualAssumeRoleRequest.roleSessionName(), startsWith("Data-Prepper-")); + assertThat(actualAssumeRoleRequest.overrideConfiguration(), notNullValue()); + assertThat(actualAssumeRoleRequest.overrideConfiguration().isPresent(), equalTo(true)); + final AwsRequestOverrideConfiguration overrideConfiguration = actualAssumeRoleRequest.overrideConfiguration().get(); + assertThat(overrideConfiguration.headers(), notNullValue()); + assertThat(overrideConfiguration.headers().size(), equalTo(1)); + assertThat(overrideConfiguration.headers(), hasKey(defaultHeaderName)); + assertThat(overrideConfiguration.headers().get(defaultHeaderName), notNullValue()); + assertThat(overrideConfiguration.headers().get(defaultHeaderName).size(), equalTo(1)); + assertThat(overrideConfiguration.headers().get(defaultHeaderName), hasItem(defaultHeaderValue)); + } } private String createStsRole() { diff --git a/data-prepper-plugins/aws-plugin/src/test/java/org/opensearch/dataprepper/plugins/aws/DefaultAwsCredentialsSupplierTest.java b/data-prepper-plugins/aws-plugin/src/test/java/org/opensearch/dataprepper/plugins/aws/DefaultAwsCredentialsSupplierTest.java index 6f9298cfe5..3b562ac85f 100644 --- a/data-prepper-plugins/aws-plugin/src/test/java/org/opensearch/dataprepper/plugins/aws/DefaultAwsCredentialsSupplierTest.java +++ b/data-prepper-plugins/aws-plugin/src/test/java/org/opensearch/dataprepper/plugins/aws/DefaultAwsCredentialsSupplierTest.java @@ -18,6 +18,7 @@ import software.amazon.awssdk.regions.Region; import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.function.Supplier; @@ -85,7 +86,21 @@ void no_default_region_returns_empty_optional() { final AwsCredentialsSupplier objectUnderTest = createObjectUnderTest(); assertThat(objectUnderTest.getDefaultRegion(), equalTo(Optional.empty())); + } + + @Test + void getDefaultStsHeaderOverrides_returns_default_sts_header_overrides() { + final Map headerOverrides = Map.of("header1", "value1", "header2", "value2"); + when(credentialsProviderFactory.getDefaultStsHeaderOverrides()).thenReturn(headerOverrides); + + assertThat(createObjectUnderTest().getDefaultStsHeaderOverrides(), equalTo(Optional.of(headerOverrides))); + } + + @Test + void no_default_sts_header_overrides_returns_empty_optional() { + when(credentialsProviderFactory.getDefaultStsHeaderOverrides()).thenReturn(null); + assertThat(createObjectUnderTest().getDefaultStsHeaderOverrides(), equalTo(Optional.empty())); } @Test