Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import jakarta.validation.Valid;
import org.opensearch.dataprepper.aws.api.AwsCredentialsOptions;

import java.util.Map;

public class AwsConfig implements AwsCredentialsConfig {

public static class AwsMskConfig {
Expand All @@ -28,6 +30,7 @@ public String getArn() {
public MskBrokerConnectionType getBrokerConnectionType() {
return brokerConnectionType;
}

}

@JsonProperty("msk")
Expand All @@ -43,6 +46,10 @@ public MskBrokerConnectionType getBrokerConnectionType() {
@JsonProperty("sts_role_arn")
private String stsRoleArn;

@JsonProperty("sts_header_overrides")
@Size(max = 5, message = "sts_header_overrides supports a maximum of 5 headers to override")
private Map<String, String> awsStsHeaderOverrides;

@JsonProperty("role_session_name")
private String stsRoleSessionName;

Expand All @@ -64,11 +71,16 @@ public String getStsRoleSessionName() {
return stsRoleSessionName;
}

public Map<String, String> getAwsStsHeaderOverrides() {
return awsStsHeaderOverrides;
}

@Override
public AwsCredentialsOptions toCredentialsOptions() {
return AwsCredentialsOptions.builder()
.withRegion(region)
.withStsRoleArn(stsRoleArn)
.withStsHeaderOverrides(awsStsHeaderOverrides)
.build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -255,16 +255,21 @@ private static void configureMSKCredentialsProvider(final AuthConfig authConfig,
.region(Region.of(awsConfig.getRegion()))
.credentialsProvider(mskCredentialsProvider)
.build();
AssumeRoleRequest.Builder assumeRequestBuilder = AssumeRoleRequest
.builder()
.roleArn(awsConfig.getStsRoleArn())
.roleSessionName(sessionName);
Map<String, String> headers = awsConfig.getAwsStsHeaderOverrides();
if (Objects.nonNull(headers)) {
assumeRequestBuilder.overrideConfiguration(configuration -> {
headers.forEach(configuration::putHeader);
});
}
mskCredentialsProvider = StsAssumeRoleCredentialsProvider
.builder()
.stsClient(stsClient)
.refreshRequest(
AssumeRoleRequest
.builder()
.roleArn(awsConfig.getStsRoleArn())
.roleSessionName(sessionName)
.build()
).build();
.refreshRequest(assumeRequestBuilder.build())
.build();
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import org.junit.jupiter.api.Test;

import java.lang.reflect.Field;
import java.util.Map;

import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.MatcherAssert.assertThat;
Expand Down Expand Up @@ -36,6 +37,53 @@ void TestConfigOptions_notNull() throws NoSuchFieldException, IllegalAccessExcep
final String testRegion = RandomStringUtils.randomAlphabetic(8);
reflectivelySetField(awsConfig, "region", testRegion);
assertThat(awsConfig.getRegion(), equalTo(testRegion));

final Map<String, String> testStsHeaderOverrides = Map.of("header1", "value1", "header2", "value2");
reflectivelySetField(awsConfig, "awsStsHeaderOverrides", testStsHeaderOverrides);
assertThat(awsConfig.getAwsStsHeaderOverrides(), equalTo(testStsHeaderOverrides));
}

@Test
void testStsHeaderOverridesValidation_hasMaxSizeConstraint() throws NoSuchFieldException {
// Verify that the sts_header_overrides field has the @Size(max = 5) validation annotation
final Field field = AwsConfig.class.getDeclaredField("awsStsHeaderOverrides");
final jakarta.validation.constraints.Size sizeAnnotation = field.getAnnotation(jakarta.validation.constraints.Size.class);

assertThat("sts_header_overrides field should have @Size annotation", sizeAnnotation != null, equalTo(true));
assertThat("sts_header_overrides should have max size of 5", sizeAnnotation.max(), equalTo(5));
assertThat("sts_header_overrides validation message should be correct",
sizeAnnotation.message(), equalTo("sts_header_overrides supports a maximum of 5 headers to override"));
}

@Test
void testToCredentialsOptions_withoutStsHeaderOverrides() throws NoSuchFieldException, IllegalAccessException {
final String testRegion = "us-east-1";
final String testStsRoleArn = "arn:aws:iam::123456789012:role/test-role";

reflectivelySetField(awsConfig, "region", testRegion);
reflectivelySetField(awsConfig, "stsRoleArn", testStsRoleArn);

final org.opensearch.dataprepper.aws.api.AwsCredentialsOptions result = awsConfig.toCredentialsOptions();

assertThat(result.getRegion().toString(), equalTo(testRegion));
assertThat(result.getStsRoleArn(), equalTo(testStsRoleArn));
}

@Test
void testToCredentialsOptions_withStsHeaderOverrides() throws NoSuchFieldException, IllegalAccessException {
final String testRegion = "us-east-1";
final String testStsRoleArn = "arn:aws:iam::123456789012:role/test-role";

reflectivelySetField(awsConfig, "region", testRegion);
reflectivelySetField(awsConfig, "stsRoleArn", testStsRoleArn);
final Map<String, String> testStsHeaderOverrides = Map.of("header1", "value1", "header2", "value2");
reflectivelySetField(awsConfig, "awsStsHeaderOverrides", testStsHeaderOverrides);

final org.opensearch.dataprepper.aws.api.AwsCredentialsOptions result = awsConfig.toCredentialsOptions();

assertThat(result.getRegion().toString(), equalTo(testRegion));
assertThat(result.getStsRoleArn(), equalTo(testStsRoleArn));
assertThat(result.getStsHeaderOverrides(), equalTo(testStsHeaderOverrides));
}

private void reflectivelySetField(final AwsConfig awsConfig, final String fieldName, final Object value) throws NoSuchFieldException, IllegalAccessException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,14 @@
import org.slf4j.LoggerFactory;
import org.yaml.snakeyaml.Yaml;
import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider;
import software.amazon.awssdk.awscore.AwsRequestOverrideConfiguration;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.kafka.KafkaClient;
import software.amazon.awssdk.services.kafka.KafkaClientBuilder;
import software.amazon.awssdk.services.kafka.model.GetBootstrapBrokersRequest;
import software.amazon.awssdk.services.kafka.model.GetBootstrapBrokersResponse;
import software.amazon.awssdk.services.sts.auth.StsAssumeRoleCredentialsProvider;
import software.amazon.awssdk.services.sts.model.AssumeRoleRequest;

import java.io.FileReader;
import java.io.IOException;
Expand All @@ -43,11 +45,13 @@

import static org.apache.kafka.common.config.SaslConfigs.SASL_CLIENT_CALLBACK_HANDLER_CLASS;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.CoreMatchers.hasItem;
import static org.hamcrest.CoreMatchers.instanceOf;
import static org.hamcrest.CoreMatchers.notNullValue;
import static org.hamcrest.CoreMatchers.nullValue;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.Matchers.hasKey;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.mockStatic;
Expand Down Expand Up @@ -347,6 +351,61 @@ void testSetDynamicSaslClientCallbackHandlerWithNullAuthConfig() {
verifyNoInteractions(pluginConfigObservable);
}

@Test
void testSetAuthPropertiesWithStsHeaderOverrides() throws IOException {
final Properties props = new Properties();
final KafkaSourceConfig kafkaSourceConfig = createKafkaSinkConfig("kafka-pipeline-bootstrap-servers-sasl-iam-role.yaml");

try (MockedStatic<StsAssumeRoleCredentialsProvider> mockedProvider = mockStatic(StsAssumeRoleCredentialsProvider.class)) {
final StsAssumeRoleCredentialsProvider.Builder mockBuilder = mock(StsAssumeRoleCredentialsProvider.Builder.class);
when(mockBuilder.stsClient(any())).thenReturn(mockBuilder);
when(mockBuilder.refreshRequest(any(AssumeRoleRequest.class))).thenReturn(mockBuilder);
when(mockBuilder.build()).thenReturn(stsAssumeRoleCredentialsProvider);
mockedProvider.when(StsAssumeRoleCredentialsProvider::builder).thenReturn(mockBuilder);

KafkaSecurityConfigurer.setAuthProperties(props, kafkaSourceConfig, LOG);

verify(mockBuilder).refreshRequest(any(AssumeRoleRequest.class));
}
}

@Test
void testSetAuthPropertiesWithStsHeaderOverridesConfigured() throws IOException {
final Properties props = new Properties();
final KafkaSourceConfig kafkaSourceConfig = createKafkaSinkConfig("kafka-pipeline-bootstrap-servers-sasl-iam-role-with-headers.yaml");

try (MockedStatic<StsAssumeRoleCredentialsProvider> mockedProvider = mockStatic(StsAssumeRoleCredentialsProvider.class)) {
final StsAssumeRoleCredentialsProvider.Builder stsCredentialsProviderBuilder = mock(StsAssumeRoleCredentialsProvider.Builder.class);
when(stsCredentialsProviderBuilder.stsClient(any())).thenReturn(stsCredentialsProviderBuilder);
when(stsCredentialsProviderBuilder.refreshRequest(any(AssumeRoleRequest.class))).thenReturn(stsCredentialsProviderBuilder);
when(stsCredentialsProviderBuilder.build()).thenReturn(stsAssumeRoleCredentialsProvider);
mockedProvider.when(StsAssumeRoleCredentialsProvider::builder).thenReturn(stsCredentialsProviderBuilder);

KafkaSecurityConfigurer.setAuthProperties(props, kafkaSourceConfig, LOG);

final ArgumentCaptor<AssumeRoleRequest> assumeRoleRequestArgumentCaptor = ArgumentCaptor.forClass(AssumeRoleRequest.class);
verify(stsCredentialsProviderBuilder).refreshRequest(assumeRoleRequestArgumentCaptor.capture());
final AssumeRoleRequest actualAssumeRoleRequest = assumeRoleRequestArgumentCaptor.getValue();
assertThat(actualAssumeRoleRequest.overrideConfiguration(), notNullValue());
assertThat(actualAssumeRoleRequest.overrideConfiguration().isPresent(), equalTo(true));
final AwsRequestOverrideConfiguration overrideConfiguration = actualAssumeRoleRequest.overrideConfiguration().get();
assertThat(overrideConfiguration.headers(), notNullValue());
assertThat(overrideConfiguration.headers().size(), equalTo(2));
final String headerName1 = "X-Custom-Header";
final String headerValue1 = "custom-value";
final String headerName2 = "X-Another-Header";
final String headerValue2 = "another-value";
assertThat(overrideConfiguration.headers(), hasKey(headerName1));
assertThat(overrideConfiguration.headers(), hasKey(headerName2));
assertThat(overrideConfiguration.headers().get(headerName1), notNullValue());
assertThat(overrideConfiguration.headers().get(headerName1).size(), equalTo(1));
assertThat(overrideConfiguration.headers().get(headerName1), hasItem(headerValue1));
assertThat(overrideConfiguration.headers().get(headerName2), notNullValue());
assertThat(overrideConfiguration.headers().get(headerName2).size(), equalTo(1));
assertThat(overrideConfiguration.headers().get(headerName2), hasItem(headerValue2));
}
}

private KafkaSourceConfig createKafkaSinkConfig(final String fileName) throws IOException {
final Yaml yaml = new Yaml();
final FileReader fileReader = new FileReader(Objects.requireNonNull(getClass().getClassLoader()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
log-pipeline :
source:
kafka:
bootstrap_servers:
- "localhost:9092"
encryption:
type: "SSL"
authentication:
sasl:
aws_msk_iam: role
aws:
region: us-east-2
sts_role_arn: test_sasl_iam_sts_role
sts_header_overrides:
X-Custom-Header: custom-value

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see this being validated anywhere.

X-Another-Header: another-value
topics:
- name: "quickstart-events"
group_id: "groupdID1"
sink:
stdout:
Loading