Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,9 @@ public class DownloadFailedException extends EngineFailureException {
public DownloadFailedException(final String exceptionMsg) {
super(exceptionMsg);
}

public DownloadFailedException(final String exceptionMsg, Throwable cause){
super(exceptionMsg, cause);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,8 @@ public class EngineFailureException extends RuntimeException {
public EngineFailureException(final String exceptionMsg) {
super(exceptionMsg);
}

public EngineFailureException(final String exceptionMsg, Throwable cause){
super(exceptionMsg, cause);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@

package org.opensearch.dataprepper.plugins.geoip.extension.databasedownload;

import com.google.common.io.Files;
import org.opensearch.dataprepper.plugins.geoip.extension.MaxMindDatabaseConfig;

import java.io.File;
import java.nio.file.Files;
import java.nio.file.StandardCopyOption;
import java.util.Set;

/**
Expand Down Expand Up @@ -37,8 +38,8 @@ public LocalDBDownloadService(final String destinationDirectory, final MaxMindDa
public void initiateDownload() throws Exception {
final Set<String> strings = maxMindDatabaseConfig.getDatabasePaths().keySet();
for (final String key: strings) {
Files.copy(new File(maxMindDatabaseConfig.getDatabasePaths().get(key)),
new File(destinationDirectory + File.separator + key + MAXMIND_DATABASE_EXTENSION));
Files.copy(new File(maxMindDatabaseConfig.getDatabasePaths().get(key)).toPath(),
new File(destinationDirectory + File.separator + key + MAXMIND_DATABASE_EXTENSION).toPath(), StandardCopyOption.REPLACE_EXISTING);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@
import org.opensearch.dataprepper.plugins.geoip.exception.DownloadFailedException;
import org.opensearch.dataprepper.plugins.geoip.extension.AwsAuthenticationOptionsConfig;
import org.opensearch.dataprepper.plugins.geoip.extension.MaxMindDatabaseConfig;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import software.amazon.awssdk.core.sync.ResponseTransformer;
import software.amazon.awssdk.services.s3.S3Client;

import java.io.File;
Expand All @@ -19,6 +23,7 @@
* Implementation class for Download through S3
*/
public class S3DBService implements DBSource {
private static final Logger LOG = LoggerFactory.getLogger(S3DBService.class);
private final AwsAuthenticationOptionsConfig awsAuthenticationOptionsConfig;
private final String destinationDirectory;
private final MaxMindDatabaseConfig maxMindDatabaseConfig;
Expand Down Expand Up @@ -52,7 +57,7 @@ public void initiateDownload() {
final String bucketName = uri.getHost();
buildRequestAndDownloadFile(bucketName, key, database);
} catch (URISyntaxException ex) {
throw new DownloadFailedException("Failed to download database from S3." + ex.getMessage());
throw new DownloadFailedException("Failed to download database from S3." + ex.getMessage(), ex);
}
}
}
Expand All @@ -64,15 +69,16 @@ public void initiateDownload() {
* @param key Name of S3 object key
* @param fileName Name of the file to save
*/
private void buildRequestAndDownloadFile(final String bucketName, final String key, final String fileName) {
private void buildRequestAndDownloadFile(String bucketName, String key, String fileName) {
File destination = new File(this.destinationDirectory + File.separator + fileName + ".mmdb");
final S3Client s3Client = this.createS3Client();
try {
final S3Client s3Client = createS3Client();

final File destination = new File(destinationDirectory + File.separator + fileName + MAXMIND_DATABASE_EXTENSION);

s3Client.getObject(b -> b.bucket(bucketName).key(key), destination.toPath());
s3Client.getObject((b) -> {
b.bucket(bucketName).key(key);
}, ResponseTransformer.toFile(destination));
} catch (Exception ex) {
throw new DownloadFailedException("Failed to download database from S3." + ex.getMessage());
LOG.error("Failed to download database '{}' from S3: {}", fileName, ex.getMessage());
throw new DownloadFailedException("Failed to download database from S3: " + ex.getMessage(), ex);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.nio.file.Files;
import java.util.Map;

import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.Mockito.when;

Expand All @@ -40,6 +42,30 @@ void initiateDownloadTest() throws IOException {
assertTrue(new File(destinationDirectory + File.separator + "filename.mmdb").exists());
}

@Test
void testOverwriteFunctionality() throws Exception {
downloadThroughLocalPath = createObjectUnderTest();

String initialContent = "Initial database content\nVersion: 1.0";
createFileWithContent(sourceDirectory + File.separator + "SampleFile.mmdb", initialContent);

downloadThroughLocalPath.initiateDownload();

File destinationFile = new File(destinationDirectory + File.separator + "filename.mmdb");
assertTrue(destinationFile.exists());

String copiedContent = Files.readString(destinationFile.toPath());
assertEquals(initialContent, copiedContent);

String updatedContent = "Updated database content\nVersion: 2.0";
createFileWithContent(sourceDirectory + File.separator + "SampleFile.mmdb", updatedContent);

downloadThroughLocalPath.initiateDownload();

String finalContent = Files.readString(destinationFile.toPath());
assertEquals(updatedContent, finalContent);
}

private LocalDBDownloadService createObjectUnderTest() {
when(maxMindDatabaseConfig.getDatabasePaths()).thenReturn(Map.of("filename", sourceDirectory + File.separator + "SampleFile.mmdb"));
createFolder(destinationDirectory);
Expand All @@ -58,8 +84,12 @@ private void generateSampleFiles() throws IOException {
String content = "This is sample file";

createFolder(sourceDirectory);
new File(sourceDirectory + File.separator + fileName);
try (FileWriter writer = new FileWriter(sourceDirectory + File.separator + fileName)) {
createFileWithContent(sourceDirectory + File.separator + fileName, content);
}

private void createFileWithContent(String filePath, String content) throws IOException {
createFolder(sourceDirectory);
try (FileWriter writer = new FileWriter(filePath)) {
writer.write(content);
}
}
Expand All @@ -81,4 +111,4 @@ public void deleteDirectory(final File file) {
file.delete();
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,30 +8,64 @@
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.api.io.TempDir;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.MockedStatic;
import org.mockito.junit.jupiter.MockitoExtension;
import org.opensearch.dataprepper.plugins.geoip.exception.DownloadFailedException;
import org.opensearch.dataprepper.plugins.geoip.extension.AwsAuthenticationOptionsConfig;
import org.opensearch.dataprepper.plugins.geoip.extension.MaxMindDatabaseConfig;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.core.sync.ResponseTransformer;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.s3.S3ClientBuilder;
import software.amazon.awssdk.services.s3.model.GetObjectRequest;
import software.amazon.awssdk.services.s3.model.GetObjectResponse;

import java.io.File;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Map;
import java.util.function.Consumer;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mockStatic;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

@ExtendWith(MockitoExtension.class)
class S3DBServiceTest {

private static final String S3_URI = "s3://mybucket10012023/GeoLite2/";
private static final String S3_URI = "s3://mybucket10012023/GeoLite2/test-database.mmdb";
private static final String DATABASE_DIR = "blue_database";
private static final String DATABASE_NAME = "test-database";

@Mock
private MaxMindDatabaseConfig maxMindDatabaseConfig;
@Mock
private AwsAuthenticationOptionsConfig awsAuthenticationOptionsConfig;
@Mock
private S3Client s3Client;
@Mock
private S3ClientBuilder s3ClientBuilder;
@Mock
private AwsCredentialsProvider credentialsProvider;

@TempDir
Path tempDir;

@BeforeEach
void setUp() {
when(maxMindDatabaseConfig.getDatabasePaths()).thenReturn(Map.of("database-name", S3_URI));
when(maxMindDatabaseConfig.getDatabasePaths()).thenReturn(Map.of(DATABASE_NAME, S3_URI));
when(awsAuthenticationOptionsConfig.getAwsRegion()).thenReturn(Region.US_EAST_1);
when(awsAuthenticationOptionsConfig.authenticateAwsConfiguration()).thenReturn(credentialsProvider);
}

@Test
Expand All @@ -40,7 +74,44 @@ void initiateDownloadTest_DownloadFailedException() {
assertThrows(DownloadFailedException.class, () -> downloadThroughS3.initiateDownload());
}

@Test
void testOverwriteFunctionality() throws IOException {
String testDatabaseDir = tempDir.toString();

String initialContent = "Initial database content\nVersion: 1.0";
String updatedContent = "Updated database content\nVersion: 2.0";

File destinationFile = new File(testDatabaseDir + File.separator + DATABASE_NAME + ".mmdb");
Files.createDirectories(destinationFile.getParentFile().toPath());
Files.write(destinationFile.toPath(), initialContent.getBytes(StandardCharsets.UTF_8));

String readContent = Files.readString(destinationFile.toPath());
assertEquals(initialContent, readContent);

try (MockedStatic<S3Client> s3ClientMockedStatic = mockStatic(S3Client.class)) {
when(s3ClientBuilder.region(any(Region.class))).thenReturn(s3ClientBuilder);
when(s3ClientBuilder.credentialsProvider(any(AwsCredentialsProvider.class))).thenReturn(s3ClientBuilder);
when(s3ClientBuilder.build()).thenReturn(s3Client);
s3ClientMockedStatic.when(S3Client::builder).thenReturn(s3ClientBuilder);

ArgumentCaptor<Consumer<GetObjectRequest.Builder>> requestCaptor = ArgumentCaptor.forClass(Consumer.class);
ArgumentCaptor<ResponseTransformer<GetObjectResponse, ?>> transformerCaptor = ArgumentCaptor.forClass(ResponseTransformer.class);

when(s3Client.getObject(requestCaptor.capture(), transformerCaptor.capture())).thenAnswer(invocation -> {
Files.write(destinationFile.toPath(), updatedContent.getBytes(StandardCharsets.UTF_8));
return null;
});

S3DBService s3DBService = createObjectUnderTest();
s3DBService.initiateDownload();
}
String finalContent = Files.readString(destinationFile.toPath());
assertEquals(updatedContent, finalContent);
verify(s3Client, times(1)).getObject(any(Consumer.class), any(ResponseTransformer.class));

}

private S3DBService createObjectUnderTest() {
return new S3DBService(awsAuthenticationOptionsConfig, DATABASE_DIR, maxMindDatabaseConfig);
}
}
}
Loading