Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions data-prepper-plugins/mongodb/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ dependencies {

implementation 'com.fasterxml.jackson.core:jackson-core'
implementation 'com.fasterxml.jackson.core:jackson-databind'
implementation 'com.fasterxml.jackson.dataformat:jackson-dataformat-yaml'
implementation 'software.amazon.awssdk:s3'

implementation project(path: ':data-prepper-plugins:aws-plugin-api')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import com.mongodb.ConnectionString;
import com.mongodb.MongoClientSettings;
import com.mongodb.MongoCredential;
import com.mongodb.client.MongoClient;
import com.mongodb.client.MongoClients;
import org.opensearch.dataprepper.plugins.mongo.configuration.MongoDBSourceConfig;
Expand All @@ -13,14 +14,23 @@
import java.util.Objects;

public class MongoDBConnection {
private static final String MONGO_CONNECTION_STRING_TEMPLATE = "mongodb://%s:%s@%s:%s/?replicaSet=rs0&readpreference=%s&ssl=%s&tlsAllowInvalidHostnames=%s&directConnection=%s";
private static final String IAM_AUTH_SOURCE = "$external";
private static final String IAM_AUTH_MECHANISM = "MONGODB-AWS";
private static final String MONGO_PASSWORD_CONNECTION_STRING_TEMPLATE = "mongodb://%s:%s@%s:%s/?replicaSet=rs0&readpreference=%s&ssl=%s&tlsAllowInvalidHostnames=%s&directConnection=%s";
private static final String MONGO_IAM_CONNECTION_STRING_TEMPLATE = "mongodb://%s:%s/?replicaSet=rs0&readpreference=%s&ssl=%s&tlsAllowInvalidHostnames=%s&directConnection=%s&authSource=%s&authMechanism=%s";

public static MongoClient getMongoClient(final MongoDBSourceConfig sourceConfig) {

final String connectionString = getConnectionString(sourceConfig);
final boolean usesIAMAuthentication = usesIAMAuthentication(sourceConfig);
final String connectionString = getConnectionString(sourceConfig, usesIAMAuthentication);

final MongoClientSettings.Builder settingBuilder = MongoClientSettings.builder()
.applyConnectionString(new ConnectionString(connectionString));
if (usesIAMAuthentication) {
// Create an empty credential. This triggers mongo to use the underlying IAM role.
final MongoCredential credential = MongoCredential.createAwsCredential(null, null);
settingBuilder.credential(credential);
}

if (Objects.nonNull(sourceConfig.getTrustStoreFilePath())) {
final File truststoreFilePath = new File(sourceConfig.getTrustStoreFilePath());
Expand All @@ -39,7 +49,23 @@ private static String encodeString(final String input) {
return URLEncoder.encode(input, StandardCharsets.UTF_8);
}

private static String getConnectionString(final MongoDBSourceConfig sourceConfig) {
private static String getConnectionString(final MongoDBSourceConfig sourceConfig, final boolean usesIamAuth) {
// Support for only single host
final String hostname = sourceConfig.getHost();
final int port = sourceConfig.getPort();
final String tls = sourceConfig.getTls().toString();
final String invalidHostAllowed = sourceConfig.getSslInsecureDisableVerification().toString();
final String readPreference = sourceConfig.getReadPreference();
final String directionConnection = sourceConfig.getDirectConnection().toString();

if (sourceConfig.getHost() == null || sourceConfig.getHost().isBlank()) {
throw new RuntimeException("The host should not be null or empty.");
}

if (usesIamAuth) {
return String.format(MONGO_IAM_CONNECTION_STRING_TEMPLATE, hostname, port, readPreference, tls, invalidHostAllowed, directionConnection, encodeString(IAM_AUTH_SOURCE), encodeString(IAM_AUTH_MECHANISM));
}

final String username;
try {
username = encodeString(sourceConfig.getAuthenticationConfig().getUsername());
Expand All @@ -54,18 +80,19 @@ private static String getConnectionString(final MongoDBSourceConfig sourceConfig
throw new RuntimeException("Unsupported characters in password.");
}

if (sourceConfig.getHost() == null || sourceConfig.getHost().isBlank()) {
throw new RuntimeException("The host should not be null or empty.");
}
return String.format(MONGO_PASSWORD_CONNECTION_STRING_TEMPLATE, username, password, hostname, port, readPreference, tls, invalidHostAllowed, directionConnection);
}

// Support for only single host
final String hostname = sourceConfig.getHost();
final int port = sourceConfig.getPort();
final String tls = sourceConfig.getTls().toString();
final String invalidHostAllowed = sourceConfig.getSslInsecureDisableVerification().toString();
final String readPreference = sourceConfig.getReadPreference();
final String directionConnection = sourceConfig.getDirectConnection().toString();
return String.format(MONGO_CONNECTION_STRING_TEMPLATE, username, password, hostname, port,
readPreference, tls, invalidHostAllowed, directionConnection);
private static boolean usesIAMAuthentication(final MongoDBSourceConfig sourceConfig) {
final boolean hasUsernamePassword = Objects.nonNull(sourceConfig.getAuthenticationConfig()) &&
(Objects.nonNull(sourceConfig.getAuthenticationConfig().getUsername()) ||
Objects.nonNull(sourceConfig.getAuthenticationConfig().getPassword()));

if (hasUsernamePassword) {
return false;
}

return Objects.nonNull(sourceConfig.getAwsConfig()) &&
Objects.nonNull(sourceConfig.getAwsConfig().getAwsStsRoleArn());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;

public class MongoDBSourceConfig {
private static final int DEFAULT_PORT = 27017;
Expand Down Expand Up @@ -163,4 +164,18 @@ public String getPassword() {
}

}

public void validateAwsConfigWithUsernameAndPassword() {
final boolean hasUsernamePassword = Objects.nonNull(authenticationConfig) &&
(Objects.nonNull(authenticationConfig.getUsername()) || Objects.nonNull(authenticationConfig.getPassword()));
final boolean hasAwsAuth = Objects.nonNull(awsConfig) && Objects.nonNull(awsConfig.getAwsStsRoleArn());

if (hasUsernamePassword && hasAwsAuth) {
throw new IllegalArgumentException("Either username and password, or aws sts_role_arn must be specified. Both cannot be set at once.");
}

if (!hasUsernamePassword && !hasAwsAuth) {
throw new IllegalArgumentException("Either username and password, or aws sts_role_arn must be specified.");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ public DocumentDBSource(final PluginMetrics pluginMetrics,
this.acknowledgementSetManager = acknowledgementSetManager;
this.pluginConfigObservable = pluginConfigObservable;
this.acknowledgementsEnabled = sourceConfig.isAcknowledgmentsEnabled();

sourceConfig.validateAwsConfigWithUsernameAndPassword();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import org.mockito.Mock;
import org.mockito.MockedStatic;
import org.mockito.junit.jupiter.MockitoExtension;
import org.opensearch.dataprepper.plugins.mongo.configuration.AwsConfig;
import org.opensearch.dataprepper.plugins.mongo.configuration.MongoDBSourceConfig;
import org.opensearch.dataprepper.plugins.truststore.TrustStoreProvider;

Expand All @@ -30,12 +31,13 @@ public class MongoDBConnectionTest {
@Mock
private MongoDBSourceConfig.AuthenticationConfig authenticationConfig;

@Mock
private AwsConfig awsConfig;

private final Random random = new Random();

void setUp() {
when(mongoDBSourceConfig.getAuthenticationConfig()).thenReturn(authenticationConfig);
when(authenticationConfig.getUsername()).thenReturn("\uD800\uD800" + UUID.randomUUID());
when(authenticationConfig.getPassword()).thenReturn("aЯ ⾀sd?q=%%l€0£.lo" + UUID.randomUUID());
when(mongoDBSourceConfig.getHost()).thenReturn(UUID.randomUUID().toString());
when(mongoDBSourceConfig.getPort()).thenReturn(getRandomInteger());
when(mongoDBSourceConfig.getTls()).thenReturn(getRandomBoolean());
Expand All @@ -44,15 +46,19 @@ void setUp() {
}

@Test
public void getMongoClient() {
public void getMongoClientWithUsernamePassword() {
setUp();
when(authenticationConfig.getUsername()).thenReturn("\uD800\uD800" + UUID.randomUUID());
when(authenticationConfig.getPassword()).thenReturn("aЯ ⾀sd?q=%%l€0£.lo" + UUID.randomUUID());
final MongoClient mongoClient = MongoDBConnection.getMongoClient(mongoDBSourceConfig);
assertThat(mongoClient, is(notNullValue()));
}

@Test
public void getMongoClientWithTLS() {
setUp();
when(authenticationConfig.getUsername()).thenReturn("\uD800\uD800" + UUID.randomUUID());
when(authenticationConfig.getPassword()).thenReturn("aЯ ⾀sd?q=%%l€0£.lo" + UUID.randomUUID());
when(mongoDBSourceConfig.getTrustStoreFilePath()).thenReturn(UUID.randomUUID().toString());
when(mongoDBSourceConfig.getTrustStorePassword()).thenReturn(UUID.randomUUID().toString());
final Path path = mock(Path.class);
Expand All @@ -68,22 +74,26 @@ public void getMongoClientWithTLS() {

@Test
public void getMongoClientNullHost() {
when(mongoDBSourceConfig.getAuthenticationConfig()).thenReturn(authenticationConfig);
when(authenticationConfig.getUsername()).thenReturn("\uD800\uD800" + UUID.randomUUID());
when(authenticationConfig.getPassword()).thenReturn("aЯ ⾀sd?q=%%l€0£.lo" + UUID.randomUUID());
when(mongoDBSourceConfig.getHost()).thenReturn(null);
assertThrows(RuntimeException.class, () -> MongoDBConnection.getMongoClient(mongoDBSourceConfig));
}

@Test
public void getMongoClientEmptyHost() {
when(mongoDBSourceConfig.getAuthenticationConfig()).thenReturn(authenticationConfig);
when(authenticationConfig.getUsername()).thenReturn("\uD800\uD800" + UUID.randomUUID());
when(authenticationConfig.getPassword()).thenReturn("aЯ ⾀sd?q=%%l€0£.lo" + UUID.randomUUID());
when(mongoDBSourceConfig.getHost()).thenReturn(" ");
assertThrows(RuntimeException.class, () -> MongoDBConnection.getMongoClient(mongoDBSourceConfig));
}

@Test
public void getMongoClientWithIAMAuth() {
setUp();
when(mongoDBSourceConfig.getAuthenticationConfig()).thenReturn(null);
when(mongoDBSourceConfig.getAwsConfig()).thenReturn(awsConfig);
when(awsConfig.getAwsStsRoleArn()).thenReturn("arn:aws:iam::123456789012:role/testRole");
final MongoClient mongoClient = MongoDBConnection.getMongoClient(mongoDBSourceConfig);
assertThat(mongoClient, is(notNullValue()));
}

private Boolean getRandomBoolean() {
return random.nextBoolean();
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
*/
package org.opensearch.dataprepper.plugins.mongo.configuration;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.dataformat.yaml.YAMLFactory;
import com.fasterxml.jackson.dataformat.yaml.YAMLGenerator;
import org.junit.jupiter.api.Test;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.notNullValue;
import static org.hamcrest.Matchers.nullValue;
import static org.junit.jupiter.api.Assertions.assertThrows;

public class MongoDBSourceConfigTest {

private final ObjectMapper objectMapper = new ObjectMapper(new YAMLFactory().enable(YAMLGenerator.Feature.USE_PLATFORM_LINE_BREAKS));

@Test
void username_password_only() throws JsonProcessingException {
final String configYaml =
"host: \"localhost\"\n" +
"authentication:\n" +
" username: test\n" +
" password: test\n" +
"collections:\n" +
" - collection: test\n";

final MongoDBSourceConfig config = objectMapper.readValue(configYaml, MongoDBSourceConfig.class);

config.validateAwsConfigWithUsernameAndPassword();
assertThat(config.getAuthenticationConfig(), notNullValue());
assertThat(config.getAuthenticationConfig().getUsername(), equalTo("test"));
assertThat(config.getAuthenticationConfig().getPassword(), equalTo("test"));
assertThat(config.getAwsConfig(), nullValue());
}

@Test
void aws_sts_role_arn_only() throws JsonProcessingException {
final String configYaml =
"host: \"localhost\"\n" +
"aws:\n" +
" sts_role_arn: \"arn:aws:iam::123456789012:role/test-role\"\n" +
"collections:\n" +
" - collection: test\n";

final MongoDBSourceConfig config = objectMapper.readValue(configYaml, MongoDBSourceConfig.class);

config.validateAwsConfigWithUsernameAndPassword();
assertThat(config.getAwsConfig(), notNullValue());
assertThat(config.getAwsConfig().getAwsStsRoleArn(), equalTo("arn:aws:iam::123456789012:role/test-role"));
assertThat(config.getAuthenticationConfig(), nullValue());
}

@Test
void both_username_password_and_aws_is_invalid() throws JsonProcessingException {
final String configYaml =
"host: \"localhost\"\n" +
"authentication:\n" +
" username: test\n" +
" password: test\n" +
"aws:\n" +
" sts_role_arn: \"arn:aws:iam::123456789012:role/test-role\"\n" +
"collections:\n" +
" - collection: test\n";

final MongoDBSourceConfig config = objectMapper.readValue(configYaml, MongoDBSourceConfig.class);
assertThrows(IllegalArgumentException.class, config::validateAwsConfigWithUsernameAndPassword);
}

@Test
void neither_username_password_nor_aws_is_invalid() throws JsonProcessingException {
final String configYaml =
"host: \"localhost\"\n" +
"collections:\n" +
" - collection: test\n";

final MongoDBSourceConfig config = objectMapper.readValue(configYaml, MongoDBSourceConfig.class);
assertThrows(IllegalArgumentException.class, config::validateAwsConfigWithUsernameAndPassword);
}
}