Skip to content

Commit 97158ab

Browse files
authored
Adds sts_header_overrides to the AWS plugin extension configuration. Resolves #6078. (#6080)
Signed-off-by: David Venable <dlv@amazon.com>
1 parent 118c303 commit 97158ab

4 files changed

Lines changed: 121 additions & 6 deletions

File tree

data-prepper-plugins/aws-plugin/src/main/java/org/opensearch/dataprepper/plugins/aws/AwsStsConfiguration.java

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,9 @@
99
import jakarta.validation.constraints.Size;
1010
import software.amazon.awssdk.regions.Region;
1111

12-
public class AwsStsConfiguration {
12+
import java.util.Map;
1313

14+
public class AwsStsConfiguration {
1415
@JsonProperty("region")
1516
@Size(min = 1, message = "Region cannot be empty string")
1617
private String awsRegion;
@@ -19,11 +20,19 @@ public class AwsStsConfiguration {
1920
@Size(min = 20, max = 2048, message = "awsStsRoleArn length should be between 1 and 2048 characters")
2021
private String awsStsRoleArn;
2122

23+
@JsonProperty("sts_header_overrides")
24+
@Size(max = 5, message = "sts_header_overrides supports a maximum of 5 headers to override")
25+
private Map<String, String> awsStsHeaderOverrides;
26+
2227
public Region getAwsRegion() {
2328
return awsRegion != null ? Region.of(awsRegion) : null;
2429
}
2530

2631
public String getAwsStsRoleArn() {
2732
return awsStsRoleArn;
2833
}
34+
35+
public Map<String, String> getStsHeaderOverrides() {
36+
return awsStsHeaderOverrides;
37+
}
2938
}

data-prepper-plugins/aws-plugin/src/main/java/org/opensearch/dataprepper/plugins/aws/CredentialsProviderFactory.java

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,16 @@ AwsCredentialsProvider providerFromOptions(final AwsCredentialsOptions credentia
6666
}
6767

6868
private AwsCredentialsProvider createStsCredentials(final AwsCredentialsOptions credentialsOptions) {
69-
70-
final String stsRoleArn = credentialsOptions.getStsRoleArn() == null ? defaultStsConfiguration.getAwsStsRoleArn() : credentialsOptions.getStsRoleArn();
69+
final boolean useDefaultStsRoleArn;
70+
final String stsRoleArn;
71+
if(credentialsOptions.getStsRoleArn() != null) {
72+
stsRoleArn = credentialsOptions.getStsRoleArn();
73+
useDefaultStsRoleArn = false;
74+
}
75+
else {
76+
stsRoleArn = defaultStsConfiguration.getAwsStsRoleArn();
77+
useDefaultStsRoleArn = true;
78+
}
7179

7280
validateStsRoleArn(stsRoleArn);
7381

@@ -85,7 +93,12 @@ private AwsCredentialsProvider createStsCredentials(final AwsCredentialsOptions
8593
assumeRoleRequestBuilder = assumeRoleRequestBuilder.externalId(credentialsOptions.getStsExternalId());
8694
}
8795

88-
final Map<String, String> awsStsHeaderOverrides = credentialsOptions.getStsHeaderOverrides();
96+
final Map<String, String> awsStsHeaderOverrides;
97+
if(useDefaultStsRoleArn) {
98+
awsStsHeaderOverrides = defaultStsConfiguration.getStsHeaderOverrides();
99+
} else {
100+
awsStsHeaderOverrides = credentialsOptions.getStsHeaderOverrides();
101+
}
89102

90103
if(awsStsHeaderOverrides != null && !awsStsHeaderOverrides.isEmpty()) {
91104
assumeRoleRequestBuilder = assumeRoleRequestBuilder

data-prepper-plugins/aws-plugin/src/test/java/org/opensearch/dataprepper/plugins/aws/AwsStsConfigurationTest.java

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,13 @@
77

88
import com.fasterxml.jackson.core.JsonProcessingException;
99
import com.fasterxml.jackson.databind.ObjectMapper;
10+
import org.junit.jupiter.api.Test;
1011
import org.junit.jupiter.params.ParameterizedTest;
1112
import org.junit.jupiter.params.provider.MethodSource;
1213
import software.amazon.awssdk.regions.Region;
1314

1415
import java.util.List;
16+
import java.util.Map;
1517

1618
import static org.hamcrest.MatcherAssert.assertThat;
1719
import static org.hamcrest.Matchers.equalTo;
@@ -24,7 +26,6 @@ public class AwsStsConfigurationTest {
2426
@ParameterizedTest
2527
@MethodSource("getRegions")
2628
void testStsConfiguration(final Region region) throws JsonProcessingException {
27-
2829
final String defaultConfigurationAsString = "{\"region\": \"" + region.toString() + "\", \"sts_role_arn\": \"arn:aws:iam::123456789012:role/test-role\"}";
2930

3031
final AwsStsConfiguration objectUnderTest = OBJECT_MAPPER.readValue(defaultConfigurationAsString, AwsStsConfiguration.class);
@@ -34,6 +35,16 @@ void testStsConfiguration(final Region region) throws JsonProcessingException {
3435
assertThat(objectUnderTest.getAwsRegion(), equalTo(region));
3536
}
3637

38+
@Test
39+
void getStsHeaderOverrides() throws JsonProcessingException {
40+
final String jsonConfiguration = "{\"sts_role_arn\": \"arn:aws:iam::123456789012:role/test-role\", \"sts_header_overrides\": {\"abc\": \"123\", \"def\": \"456\"}}";
41+
42+
final AwsStsConfiguration objectUnderTest = OBJECT_MAPPER.readValue(jsonConfiguration, AwsStsConfiguration.class);
43+
assertThat(objectUnderTest, notNullValue());
44+
assertThat(objectUnderTest.getAwsStsRoleArn(), equalTo("arn:aws:iam::123456789012:role/test-role"));
45+
assertThat(objectUnderTest.getStsHeaderOverrides(), equalTo(Map.of("abc", "123", "def", "456")));
46+
}
47+
3748
private static List<Region> getRegions() {
3849
return Region.regions();
3950
}

data-prepper-plugins/aws-plugin/src/test/java/org/opensearch/dataprepper/plugins/aws/CredentialsProviderFactoryTest.java

Lines changed: 83 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ void providerFromOptions_should_return_StsCredentialsProvider_with_sts_role_arn(
194194
}
195195

196196
@Test
197-
void providerFromOptions_should_return_s3Client_with_sts_role_arn_when_no_region() {
197+
void providerFromOptions_should_return_stsClient_with_sts_role_arn_when_no_region() {
198198
when(awsCredentialsOptions.getRegion()).thenReturn(null);
199199
when(awsCredentialsOptions.getStsRoleArn()).thenReturn(testStsRole);
200200

@@ -210,6 +210,88 @@ void providerFromOptions_should_return_s3Client_with_sts_role_arn_when_no_region
210210
verify(stsClientBuilder, never()).region(any(Region.class));
211211
}
212212

213+
@Test
214+
void providerFromOptions_should_override_STS_Headers_when_default_HeaderOverrides_when_set_and_using_default_STS_role_ARN() {
215+
final String headerName1 = UUID.randomUUID().toString();
216+
final String headerValue1 = UUID.randomUUID().toString();
217+
final String headerName2 = UUID.randomUUID().toString();
218+
final String headerValue2 = UUID.randomUUID().toString();
219+
final Map<String, String> overrideHeaders = Map.of(headerName1, headerValue1, headerName2, headerValue2);
220+
221+
final String defaultStsRole = createStsRole();
222+
when(defaultStsConfiguration.getAwsStsRoleArn()).thenReturn(defaultStsRole);
223+
when(defaultStsConfiguration.getStsHeaderOverrides()).thenReturn(overrideHeaders);
224+
225+
when(awsCredentialsOptions.getRegion()).thenReturn(Region.US_EAST_1);
226+
227+
when(stsClientBuilder.region(Region.US_EAST_1)).thenReturn(stsClientBuilder);
228+
229+
final CredentialsProviderFactory objectUnderTest = createObjectUnderTest();
230+
final AwsCredentialsProvider actualCredentialsProvider;
231+
try (final MockedStatic<StsClient> stsClientMockedStatic = mockStatic(StsClient.class);
232+
final MockedStatic<StsAssumeRoleCredentialsProvider> credentialsProviderMockedStatic = mockStatic(StsAssumeRoleCredentialsProvider.class)) {
233+
stsClientMockedStatic.when(StsClient::builder).thenReturn(stsClientBuilder);
234+
credentialsProviderMockedStatic.when(StsAssumeRoleCredentialsProvider::builder).thenReturn(stsCredentialsProviderBuilder);
235+
actualCredentialsProvider = objectUnderTest.providerFromOptions(awsCredentialsOptions);
236+
}
237+
238+
assertThat(actualCredentialsProvider, instanceOf(StsAssumeRoleCredentialsProvider.class));
239+
240+
final ArgumentCaptor<AssumeRoleRequest> assumeRoleRequestArgumentCaptor = ArgumentCaptor.forClass(AssumeRoleRequest.class);
241+
verify(stsCredentialsProviderBuilder).refreshRequest(assumeRoleRequestArgumentCaptor.capture());
242+
243+
final AssumeRoleRequest actualAssumeRoleRequest = assumeRoleRequestArgumentCaptor.getValue();
244+
assertThat(actualAssumeRoleRequest.roleArn(), equalTo(defaultStsRole));
245+
assertThat(actualAssumeRoleRequest.roleSessionName(), startsWith("Data-Prepper-"));
246+
assertThat(actualAssumeRoleRequest.roleSessionName().length(), lessThanOrEqualTo(MAXIMUM_ROLE_SESSION_LENGTH));
247+
assertThat(actualAssumeRoleRequest.overrideConfiguration(), notNullValue());
248+
assertThat(actualAssumeRoleRequest.overrideConfiguration().isPresent(), equalTo(true));
249+
final AwsRequestOverrideConfiguration overrideConfiguration = actualAssumeRoleRequest.overrideConfiguration().get();
250+
assertThat(overrideConfiguration.headers(), notNullValue());
251+
assertThat(overrideConfiguration.headers().size(), equalTo(2));
252+
assertThat(overrideConfiguration.headers(), hasKey(headerName1));
253+
assertThat(overrideConfiguration.headers(), hasKey(headerName2));
254+
assertThat(overrideConfiguration.headers().get(headerName1), notNullValue());
255+
assertThat(overrideConfiguration.headers().get(headerName1).size(), equalTo(1));
256+
assertThat(overrideConfiguration.headers().get(headerName1), hasItem(headerValue1));
257+
assertThat(overrideConfiguration.headers().get(headerName2), notNullValue());
258+
assertThat(overrideConfiguration.headers().get(headerName2).size(), equalTo(1));
259+
assertThat(overrideConfiguration.headers().get(headerName2), hasItem(headerValue2));
260+
261+
verify(awsCredentialsOptions, never()).getStsHeaderOverrides();
262+
}
263+
264+
@Test
265+
void providerFromOptions_should_not_override_STS_Headers_when_HeaderOverrides_are_empty_and_using_default_STS_role_ARN() {
266+
final String defaultStsRole = createStsRole();
267+
when(defaultStsConfiguration.getAwsStsRoleArn()).thenReturn(defaultStsRole);
268+
when(awsCredentialsOptions.getRegion()).thenReturn(Region.US_EAST_1);
269+
270+
when(stsClientBuilder.region(Region.US_EAST_1)).thenReturn(stsClientBuilder);
271+
272+
final CredentialsProviderFactory objectUnderTest = createObjectUnderTest();
273+
final AwsCredentialsProvider actualCredentialsProvider;
274+
275+
try (final MockedStatic<StsClient> stsClientMockedStatic = mockStatic(StsClient.class);
276+
final MockedStatic<StsAssumeRoleCredentialsProvider> credentialsProviderMockedStatic = mockStatic(StsAssumeRoleCredentialsProvider.class)) {
277+
stsClientMockedStatic.when(StsClient::builder).thenReturn(stsClientBuilder);
278+
credentialsProviderMockedStatic.when(StsAssumeRoleCredentialsProvider::builder).thenReturn(stsCredentialsProviderBuilder);
279+
actualCredentialsProvider = objectUnderTest.providerFromOptions(awsCredentialsOptions);
280+
}
281+
282+
assertThat(actualCredentialsProvider, instanceOf(StsAssumeRoleCredentialsProvider.class));
283+
284+
final ArgumentCaptor<AssumeRoleRequest> assumeRoleRequestArgumentCaptor = ArgumentCaptor.forClass(AssumeRoleRequest.class);
285+
verify(stsCredentialsProviderBuilder).refreshRequest(assumeRoleRequestArgumentCaptor.capture());
286+
287+
final AssumeRoleRequest actualAssumeRoleRequest = assumeRoleRequestArgumentCaptor.getValue();
288+
assertThat(actualAssumeRoleRequest.roleArn(), equalTo(defaultStsRole));
289+
assertThat(actualAssumeRoleRequest.roleSessionName(), startsWith("Data-Prepper-"));
290+
assertThat(actualAssumeRoleRequest.roleSessionName().length(), lessThanOrEqualTo(MAXIMUM_ROLE_SESSION_LENGTH));
291+
assertThat(actualAssumeRoleRequest.overrideConfiguration(), notNullValue());
292+
assertThat(actualAssumeRoleRequest.overrideConfiguration().isPresent(), equalTo(false));
293+
}
294+
213295
@Test
214296
void providerFromOptions_should_override_STS_Headers_when_HeaderOverrides_when_set() {
215297
final String headerName1 = UUID.randomUUID().toString();

0 commit comments

Comments
 (0)