diff --git a/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/configuration/AwsConfig.java b/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/configuration/AwsConfig.java index d42d894684..b27140a7ad 100644 --- a/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/configuration/AwsConfig.java +++ b/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/configuration/AwsConfig.java @@ -10,6 +10,8 @@ import jakarta.validation.Valid; import org.opensearch.dataprepper.aws.api.AwsCredentialsOptions; +import java.util.Map; + public class AwsConfig implements AwsCredentialsConfig { public static class AwsMskConfig { @@ -28,6 +30,7 @@ public String getArn() { public MskBrokerConnectionType getBrokerConnectionType() { return brokerConnectionType; } + } @JsonProperty("msk") @@ -43,6 +46,10 @@ public MskBrokerConnectionType getBrokerConnectionType() { @JsonProperty("sts_role_arn") private String stsRoleArn; + @JsonProperty("sts_header_overrides") + @Size(max = 5, message = "sts_header_overrides supports a maximum of 5 headers to override") + private Map awsStsHeaderOverrides; + @JsonProperty("role_session_name") private String stsRoleSessionName; @@ -64,11 +71,16 @@ public String getStsRoleSessionName() { return stsRoleSessionName; } + public Map getAwsStsHeaderOverrides() { + return awsStsHeaderOverrides; + } + @Override public AwsCredentialsOptions toCredentialsOptions() { return AwsCredentialsOptions.builder() .withRegion(region) .withStsRoleArn(stsRoleArn) + .withStsHeaderOverrides(awsStsHeaderOverrides) .build(); } } diff --git a/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/util/KafkaSecurityConfigurer.java b/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/util/KafkaSecurityConfigurer.java index fc07147d4d..a465df22f0 100644 --- a/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/util/KafkaSecurityConfigurer.java +++ b/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/util/KafkaSecurityConfigurer.java @@ -255,16 +255,21 @@ private static void configureMSKCredentialsProvider(final AuthConfig authConfig, .region(Region.of(awsConfig.getRegion())) .credentialsProvider(mskCredentialsProvider) .build(); + AssumeRoleRequest.Builder assumeRequestBuilder = AssumeRoleRequest + .builder() + .roleArn(awsConfig.getStsRoleArn()) + .roleSessionName(sessionName); + Map headers = awsConfig.getAwsStsHeaderOverrides(); + if (Objects.nonNull(headers)) { + assumeRequestBuilder.overrideConfiguration(configuration -> { + headers.forEach(configuration::putHeader); + }); + } mskCredentialsProvider = StsAssumeRoleCredentialsProvider .builder() .stsClient(stsClient) - .refreshRequest( - AssumeRoleRequest - .builder() - .roleArn(awsConfig.getStsRoleArn()) - .roleSessionName(sessionName) - .build() - ).build(); + .refreshRequest(assumeRequestBuilder.build()) + .build(); } } diff --git a/data-prepper-plugins/kafka-plugins/src/test/java/org/opensearch/dataprepper/plugins/kafka/configuration/AwsConfigTest.java b/data-prepper-plugins/kafka-plugins/src/test/java/org/opensearch/dataprepper/plugins/kafka/configuration/AwsConfigTest.java index b82694f4f6..db4f6b1c9d 100644 --- a/data-prepper-plugins/kafka-plugins/src/test/java/org/opensearch/dataprepper/plugins/kafka/configuration/AwsConfigTest.java +++ b/data-prepper-plugins/kafka-plugins/src/test/java/org/opensearch/dataprepper/plugins/kafka/configuration/AwsConfigTest.java @@ -9,6 +9,7 @@ import org.junit.jupiter.api.Test; import java.lang.reflect.Field; +import java.util.Map; import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.MatcherAssert.assertThat; @@ -36,6 +37,53 @@ void TestConfigOptions_notNull() throws NoSuchFieldException, IllegalAccessExcep final String testRegion = RandomStringUtils.randomAlphabetic(8); reflectivelySetField(awsConfig, "region", testRegion); assertThat(awsConfig.getRegion(), equalTo(testRegion)); + + final Map testStsHeaderOverrides = Map.of("header1", "value1", "header2", "value2"); + reflectivelySetField(awsConfig, "awsStsHeaderOverrides", testStsHeaderOverrides); + assertThat(awsConfig.getAwsStsHeaderOverrides(), equalTo(testStsHeaderOverrides)); + } + + @Test + void testStsHeaderOverridesValidation_hasMaxSizeConstraint() throws NoSuchFieldException { + // Verify that the sts_header_overrides field has the @Size(max = 5) validation annotation + final Field field = AwsConfig.class.getDeclaredField("awsStsHeaderOverrides"); + final jakarta.validation.constraints.Size sizeAnnotation = field.getAnnotation(jakarta.validation.constraints.Size.class); + + assertThat("sts_header_overrides field should have @Size annotation", sizeAnnotation != null, equalTo(true)); + assertThat("sts_header_overrides should have max size of 5", sizeAnnotation.max(), equalTo(5)); + assertThat("sts_header_overrides validation message should be correct", + sizeAnnotation.message(), equalTo("sts_header_overrides supports a maximum of 5 headers to override")); + } + + @Test + void testToCredentialsOptions_withoutStsHeaderOverrides() throws NoSuchFieldException, IllegalAccessException { + final String testRegion = "us-east-1"; + final String testStsRoleArn = "arn:aws:iam::123456789012:role/test-role"; + + reflectivelySetField(awsConfig, "region", testRegion); + reflectivelySetField(awsConfig, "stsRoleArn", testStsRoleArn); + + final org.opensearch.dataprepper.aws.api.AwsCredentialsOptions result = awsConfig.toCredentialsOptions(); + + assertThat(result.getRegion().toString(), equalTo(testRegion)); + assertThat(result.getStsRoleArn(), equalTo(testStsRoleArn)); + } + + @Test + void testToCredentialsOptions_withStsHeaderOverrides() throws NoSuchFieldException, IllegalAccessException { + final String testRegion = "us-east-1"; + final String testStsRoleArn = "arn:aws:iam::123456789012:role/test-role"; + + reflectivelySetField(awsConfig, "region", testRegion); + reflectivelySetField(awsConfig, "stsRoleArn", testStsRoleArn); + final Map testStsHeaderOverrides = Map.of("header1", "value1", "header2", "value2"); + reflectivelySetField(awsConfig, "awsStsHeaderOverrides", testStsHeaderOverrides); + + final org.opensearch.dataprepper.aws.api.AwsCredentialsOptions result = awsConfig.toCredentialsOptions(); + + assertThat(result.getRegion().toString(), equalTo(testRegion)); + assertThat(result.getStsRoleArn(), equalTo(testStsRoleArn)); + assertThat(result.getStsHeaderOverrides(), equalTo(testStsHeaderOverrides)); } private void reflectivelySetField(final AwsConfig awsConfig, final String fieldName, final Object value) throws NoSuchFieldException, IllegalAccessException { diff --git a/data-prepper-plugins/kafka-plugins/src/test/java/org/opensearch/dataprepper/plugins/kafka/util/KafkaSecurityConfigurerTest.java b/data-prepper-plugins/kafka-plugins/src/test/java/org/opensearch/dataprepper/plugins/kafka/util/KafkaSecurityConfigurerTest.java index f65b48315b..9b84ee7a76 100644 --- a/data-prepper-plugins/kafka-plugins/src/test/java/org/opensearch/dataprepper/plugins/kafka/util/KafkaSecurityConfigurerTest.java +++ b/data-prepper-plugins/kafka-plugins/src/test/java/org/opensearch/dataprepper/plugins/kafka/util/KafkaSecurityConfigurerTest.java @@ -25,12 +25,14 @@ import org.slf4j.LoggerFactory; import org.yaml.snakeyaml.Yaml; import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider; +import software.amazon.awssdk.awscore.AwsRequestOverrideConfiguration; import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.services.kafka.KafkaClient; import software.amazon.awssdk.services.kafka.KafkaClientBuilder; import software.amazon.awssdk.services.kafka.model.GetBootstrapBrokersRequest; import software.amazon.awssdk.services.kafka.model.GetBootstrapBrokersResponse; import software.amazon.awssdk.services.sts.auth.StsAssumeRoleCredentialsProvider; +import software.amazon.awssdk.services.sts.model.AssumeRoleRequest; import java.io.FileReader; import java.io.IOException; @@ -43,11 +45,13 @@ import static org.apache.kafka.common.config.SaslConfigs.SASL_CLIENT_CALLBACK_HANDLER_CLASS; import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.hasItem; import static org.hamcrest.CoreMatchers.instanceOf; import static org.hamcrest.CoreMatchers.notNullValue; import static org.hamcrest.CoreMatchers.nullValue; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.Matchers.hasKey; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mockStatic; @@ -347,6 +351,61 @@ void testSetDynamicSaslClientCallbackHandlerWithNullAuthConfig() { verifyNoInteractions(pluginConfigObservable); } + @Test + void testSetAuthPropertiesWithStsHeaderOverrides() throws IOException { + final Properties props = new Properties(); + final KafkaSourceConfig kafkaSourceConfig = createKafkaSinkConfig("kafka-pipeline-bootstrap-servers-sasl-iam-role.yaml"); + + try (MockedStatic mockedProvider = mockStatic(StsAssumeRoleCredentialsProvider.class)) { + final StsAssumeRoleCredentialsProvider.Builder mockBuilder = mock(StsAssumeRoleCredentialsProvider.Builder.class); + when(mockBuilder.stsClient(any())).thenReturn(mockBuilder); + when(mockBuilder.refreshRequest(any(AssumeRoleRequest.class))).thenReturn(mockBuilder); + when(mockBuilder.build()).thenReturn(stsAssumeRoleCredentialsProvider); + mockedProvider.when(StsAssumeRoleCredentialsProvider::builder).thenReturn(mockBuilder); + + KafkaSecurityConfigurer.setAuthProperties(props, kafkaSourceConfig, LOG); + + verify(mockBuilder).refreshRequest(any(AssumeRoleRequest.class)); + } + } + + @Test + void testSetAuthPropertiesWithStsHeaderOverridesConfigured() throws IOException { + final Properties props = new Properties(); + final KafkaSourceConfig kafkaSourceConfig = createKafkaSinkConfig("kafka-pipeline-bootstrap-servers-sasl-iam-role-with-headers.yaml"); + + try (MockedStatic mockedProvider = mockStatic(StsAssumeRoleCredentialsProvider.class)) { + final StsAssumeRoleCredentialsProvider.Builder stsCredentialsProviderBuilder = mock(StsAssumeRoleCredentialsProvider.Builder.class); + when(stsCredentialsProviderBuilder.stsClient(any())).thenReturn(stsCredentialsProviderBuilder); + when(stsCredentialsProviderBuilder.refreshRequest(any(AssumeRoleRequest.class))).thenReturn(stsCredentialsProviderBuilder); + when(stsCredentialsProviderBuilder.build()).thenReturn(stsAssumeRoleCredentialsProvider); + mockedProvider.when(StsAssumeRoleCredentialsProvider::builder).thenReturn(stsCredentialsProviderBuilder); + + KafkaSecurityConfigurer.setAuthProperties(props, kafkaSourceConfig, LOG); + + final ArgumentCaptor assumeRoleRequestArgumentCaptor = ArgumentCaptor.forClass(AssumeRoleRequest.class); + verify(stsCredentialsProviderBuilder).refreshRequest(assumeRoleRequestArgumentCaptor.capture()); + final AssumeRoleRequest actualAssumeRoleRequest = assumeRoleRequestArgumentCaptor.getValue(); + assertThat(actualAssumeRoleRequest.overrideConfiguration(), notNullValue()); + assertThat(actualAssumeRoleRequest.overrideConfiguration().isPresent(), equalTo(true)); + final AwsRequestOverrideConfiguration overrideConfiguration = actualAssumeRoleRequest.overrideConfiguration().get(); + assertThat(overrideConfiguration.headers(), notNullValue()); + assertThat(overrideConfiguration.headers().size(), equalTo(2)); + final String headerName1 = "X-Custom-Header"; + final String headerValue1 = "custom-value"; + final String headerName2 = "X-Another-Header"; + final String headerValue2 = "another-value"; + assertThat(overrideConfiguration.headers(), hasKey(headerName1)); + assertThat(overrideConfiguration.headers(), hasKey(headerName2)); + assertThat(overrideConfiguration.headers().get(headerName1), notNullValue()); + assertThat(overrideConfiguration.headers().get(headerName1).size(), equalTo(1)); + assertThat(overrideConfiguration.headers().get(headerName1), hasItem(headerValue1)); + assertThat(overrideConfiguration.headers().get(headerName2), notNullValue()); + assertThat(overrideConfiguration.headers().get(headerName2).size(), equalTo(1)); + assertThat(overrideConfiguration.headers().get(headerName2), hasItem(headerValue2)); + } + } + private KafkaSourceConfig createKafkaSinkConfig(final String fileName) throws IOException { final Yaml yaml = new Yaml(); final FileReader fileReader = new FileReader(Objects.requireNonNull(getClass().getClassLoader() diff --git a/data-prepper-plugins/kafka-plugins/src/test/resources/kafka-pipeline-bootstrap-servers-sasl-iam-role-with-headers.yaml b/data-prepper-plugins/kafka-plugins/src/test/resources/kafka-pipeline-bootstrap-servers-sasl-iam-role-with-headers.yaml new file mode 100644 index 0000000000..ca7826ed2c --- /dev/null +++ b/data-prepper-plugins/kafka-plugins/src/test/resources/kafka-pipeline-bootstrap-servers-sasl-iam-role-with-headers.yaml @@ -0,0 +1,21 @@ +log-pipeline : + source: + kafka: + bootstrap_servers: + - "localhost:9092" + encryption: + type: "SSL" + authentication: + sasl: + aws_msk_iam: role + aws: + region: us-east-2 + sts_role_arn: test_sasl_iam_sts_role + sts_header_overrides: + X-Custom-Header: custom-value + X-Another-Header: another-value + topics: + - name: "quickstart-events" + group_id: "groupdID1" + sink: + stdout: