Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import software.amazon.awssdk.regions.Region;

import java.util.Optional;
import java.util.Map;

/**
* An interface available to plugins via the AWS Plugin Extension which supplies
Expand All @@ -29,6 +30,12 @@ public interface AwsCredentialsSupplier {
*/
Optional<Region> getDefaultRegion();

/**
* Gets the default STS header overrides if configured. Otherwise returns empty Optional
* @return Optional containing Map of STS header overrides
*/
Optional<Map<String, String>> getDefaultStsHeaderOverrides();

/**
* Gets the default STS role ARN if it is configured. Otherwise returns empty Optional
* @return Default STS role ARN as String
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import jakarta.validation.constraints.Size;
import software.amazon.awssdk.regions.Region;

import java.util.Map;

public class AwsStsConfiguration {

@JsonProperty("region")
Expand All @@ -19,11 +21,19 @@ public class AwsStsConfiguration {
@Size(min = 20, max = 2048, message = "awsStsRoleArn length should be between 1 and 2048 characters")
private String awsStsRoleArn;

@JsonProperty("sts_header_overrides")
@Size(max = 5, message = "sts_header_overrides supports a maximum of 5 headers to override")
private Map<String, String> stsHeaderOverrides;

public Region getAwsRegion() {
return awsRegion != null ? Region.of(awsRegion) : null;
}

public String getAwsStsRoleArn() {
return awsStsRoleArn;
}

public Map<String, String> getStsHeaderOverrides() {
return stsHeaderOverrides;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ Region getDefaultRegion() {
return defaultStsConfiguration.getAwsRegion();
}

Map<String, String> getDefaultStsHeaderOverrides() {
return defaultStsConfiguration.getStsHeaderOverrides();
}

String getDefaultStsRoleArn() {
return defaultStsConfiguration.getAwsStsRoleArn();
}
Expand Down Expand Up @@ -85,7 +89,8 @@ private AwsCredentialsProvider createStsCredentials(final AwsCredentialsOptions
assumeRoleRequestBuilder = assumeRoleRequestBuilder.externalId(credentialsOptions.getStsExternalId());
}

final Map<String, String> awsStsHeaderOverrides = credentialsOptions.getStsHeaderOverrides();
final Map<String, String> awsStsHeaderOverrides = credentialsOptions.getStsHeaderOverrides() != null ?
credentialsOptions.getStsHeaderOverrides() : defaultStsConfiguration.getStsHeaderOverrides();

if(awsStsHeaderOverrides != null && !awsStsHeaderOverrides.isEmpty()) {
assumeRoleRequestBuilder = assumeRoleRequestBuilder
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.regions.Region;

import java.util.Map;
import java.util.Optional;

class DefaultAwsCredentialsSupplier implements AwsCredentialsSupplier {
Expand All @@ -32,6 +33,11 @@ public Optional<Region> getDefaultRegion() {
return Optional.ofNullable(credentialsProviderFactory.getDefaultRegion());
}

@Override
public Optional<Map<String, String>> getDefaultStsHeaderOverrides() {
return Optional.ofNullable(credentialsProviderFactory.getDefaultStsHeaderOverrides());
}

@Override
public Optional<String> getDefaultStsRoleArn() {
return Optional.ofNullable(credentialsProviderFactory.getDefaultStsRoleArn());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import software.amazon.awssdk.regions.Region;
Expand All @@ -16,6 +17,7 @@
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.notNullValue;
import static org.hamcrest.Matchers.nullValue;

public class AwsStsConfigurationTest {

Expand All @@ -34,6 +36,36 @@ void testStsConfiguration(final Region region) throws JsonProcessingException {
assertThat(objectUnderTest.getAwsRegion(), equalTo(region));
}

@Test
void testStsConfigurationWithHeaderOverrides() throws JsonProcessingException {
final String configWithHeaderOverrides =
"{\"region\": \"us-west-2\", " +
"\"sts_role_arn\": \"arn:aws:iam::123456789012:role/test-role\", " +
"\"sts_header_overrides\": {\"header1\": \"value1\", \"header2\": \"value2\"}}";

final AwsStsConfiguration objectUnderTest = OBJECT_MAPPER.readValue(configWithHeaderOverrides, AwsStsConfiguration.class);

assertThat(objectUnderTest, notNullValue());
assertThat(objectUnderTest.getAwsStsRoleArn(), equalTo("arn:aws:iam::123456789012:role/test-role"));
assertThat(objectUnderTest.getAwsRegion(), equalTo(Region.US_WEST_2));
assertThat(objectUnderTest.getStsHeaderOverrides(), notNullValue());
assertThat(objectUnderTest.getStsHeaderOverrides().size(), equalTo(2));
assertThat(objectUnderTest.getStsHeaderOverrides().get("header1"), equalTo("value1"));
assertThat(objectUnderTest.getStsHeaderOverrides().get("header2"), equalTo("value2"));
}

@Test
void testStsConfigurationWithoutHeaderOverrides() throws JsonProcessingException {
final String configWithoutHeaderOverrides =
"{\"region\": \"us-west-2\", " +
"\"sts_role_arn\": \"arn:aws:iam::123456789012:role/test-role\"}";

final AwsStsConfiguration objectUnderTest = OBJECT_MAPPER.readValue(configWithoutHeaderOverrides, AwsStsConfiguration.class);

assertThat(objectUnderTest, notNullValue());
assertThat(objectUnderTest.getStsHeaderOverrides(), nullValue());
}

private static List<Region> getRegions() {
return Region.regions();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,33 @@ void getDefaultRegion_returns_expected_region(final Region region) {
assertThat(actualRegion, equalTo(region));
}

@Test
void getDefaultStsHeaderOverrides_returns_expected_headers() {
final Map<String, String> headerOverrides = Map.of(
"header1", "value1",
"header2", "value2",
"custom-header", "custom-value"
);
when(defaultStsConfiguration.getStsHeaderOverrides()).thenReturn(headerOverrides);

final CredentialsProviderFactory credentialsProviderFactory = createObjectUnderTest();

final Map<String, String> actualHeaderOverrides = credentialsProviderFactory.getDefaultStsHeaderOverrides();

assertThat(actualHeaderOverrides, equalTo(headerOverrides));
}

@Test
void getDefaultStsHeaderOverrides_returns_null_when_not_configured() {
when(defaultStsConfiguration.getStsHeaderOverrides()).thenReturn(null);

final CredentialsProviderFactory credentialsProviderFactory = createObjectUnderTest();

final Map<String, String> actualHeaderOverrides = credentialsProviderFactory.getDefaultStsHeaderOverrides();

assertThat(actualHeaderOverrides, nullValue());
}

@Test
void getDefaultStsRoleArn_returns_from_default_configuration() {
final String roleArn = "arn:aws:iam::123456789012:role/test-role";
Expand Down Expand Up @@ -390,6 +417,47 @@ void providerFromOptions_should_create_StsClient_with_correct_backoff(final Stri

assertThat(retryPolicy.throttlingBackoffStrategy(), sameInstance(backoffStrategy));
}

@Test
void providerFromOptions_should_use_default_STS_Headers_when_credentialsOptions_HeaderOverrides_are_null() {
final String defaultHeaderName = "default-header";
final String defaultHeaderValue = "default-value";
final Map<String, String> defaultHeaders = Map.of(defaultHeaderName, defaultHeaderValue);

when(awsCredentialsOptions.getRegion()).thenReturn(Region.US_EAST_1);
when(awsCredentialsOptions.getStsRoleArn()).thenReturn(testStsRole);
when(awsCredentialsOptions.getStsHeaderOverrides()).thenReturn(null);
when(defaultStsConfiguration.getStsHeaderOverrides()).thenReturn(defaultHeaders);

when(stsClientBuilder.region(Region.US_EAST_1)).thenReturn(stsClientBuilder);

final CredentialsProviderFactory objectUnderTest = createObjectUnderTest();
final AwsCredentialsProvider actualCredentialsProvider;
try (final MockedStatic<StsClient> stsClientMockedStatic = mockStatic(StsClient.class);
final MockedStatic<StsAssumeRoleCredentialsProvider> credentialsProviderMockedStatic = mockStatic(StsAssumeRoleCredentialsProvider.class)) {
stsClientMockedStatic.when(StsClient::builder).thenReturn(stsClientBuilder);
credentialsProviderMockedStatic.when(StsAssumeRoleCredentialsProvider::builder).thenReturn(stsCredentialsProviderBuilder);
actualCredentialsProvider = objectUnderTest.providerFromOptions(awsCredentialsOptions);
}

assertThat(actualCredentialsProvider, instanceOf(StsAssumeRoleCredentialsProvider.class));

final ArgumentCaptor<AssumeRoleRequest> assumeRoleRequestArgumentCaptor = ArgumentCaptor.forClass(AssumeRoleRequest.class);
verify(stsCredentialsProviderBuilder).refreshRequest(assumeRoleRequestArgumentCaptor.capture());

final AssumeRoleRequest actualAssumeRoleRequest = assumeRoleRequestArgumentCaptor.getValue();
assertThat(actualAssumeRoleRequest.roleArn(), equalTo(testStsRole));
assertThat(actualAssumeRoleRequest.roleSessionName(), startsWith("Data-Prepper-"));
assertThat(actualAssumeRoleRequest.overrideConfiguration(), notNullValue());
assertThat(actualAssumeRoleRequest.overrideConfiguration().isPresent(), equalTo(true));
final AwsRequestOverrideConfiguration overrideConfiguration = actualAssumeRoleRequest.overrideConfiguration().get();
assertThat(overrideConfiguration.headers(), notNullValue());
assertThat(overrideConfiguration.headers().size(), equalTo(1));
assertThat(overrideConfiguration.headers(), hasKey(defaultHeaderName));
assertThat(overrideConfiguration.headers().get(defaultHeaderName), notNullValue());
assertThat(overrideConfiguration.headers().get(defaultHeaderName).size(), equalTo(1));
assertThat(overrideConfiguration.headers().get(defaultHeaderName), hasItem(defaultHeaderValue));
}
}

private String createStsRole() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import software.amazon.awssdk.regions.Region;

import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Supplier;

Expand Down Expand Up @@ -85,7 +86,21 @@ void no_default_region_returns_empty_optional() {

final AwsCredentialsSupplier objectUnderTest = createObjectUnderTest();
assertThat(objectUnderTest.getDefaultRegion(), equalTo(Optional.empty()));
}

@Test
void getDefaultStsHeaderOverrides_returns_default_sts_header_overrides() {
final Map<String, String> headerOverrides = Map.of("header1", "value1", "header2", "value2");
when(credentialsProviderFactory.getDefaultStsHeaderOverrides()).thenReturn(headerOverrides);

assertThat(createObjectUnderTest().getDefaultStsHeaderOverrides(), equalTo(Optional.of(headerOverrides)));
}

@Test
void no_default_sts_header_overrides_returns_empty_optional() {
when(credentialsProviderFactory.getDefaultStsHeaderOverrides()).thenReturn(null);

assertThat(createObjectUnderTest().getDefaultStsHeaderOverrides(), equalTo(Optional.empty()));
}

@Test
Expand Down
Loading