From 7e7959b211353081255155da785bc13b9f9ee204 Mon Sep 17 00:00:00 2001 From: Kirtan Kakadiya <35823164+KirtanKakadiya@users.noreply.github.com> Date: Thu, 31 Jul 2025 12:40:19 -0700 Subject: [PATCH] Fix S3DBService and LocalDBService file overwrite handling during downloads (#5911) Fix S3DBService/LocalDBService file overwrite handling during downloads Signed-off-by: kirtanhk Co-authored-by: kirtanhk (cherry picked from commit b87890ad33fd50fac700e21758825d651c8f33d9) --- .../exception/DownloadFailedException.java | 5 ++ .../exception/EngineFailureException.java | 4 + .../LocalDBDownloadService.java | 7 +- .../databasedownload/S3DBService.java | 22 ++++-- .../LocalDBDownloadServiceTest.java | 36 ++++++++- .../databasedownload/S3DBServiceTest.java | 77 ++++++++++++++++++- 6 files changed, 134 insertions(+), 17 deletions(-) diff --git a/data-prepper-plugins/geoip-processor/src/main/java/org/opensearch/dataprepper/plugins/geoip/exception/DownloadFailedException.java b/data-prepper-plugins/geoip-processor/src/main/java/org/opensearch/dataprepper/plugins/geoip/exception/DownloadFailedException.java index 7aa10b9901..0ea661ddd8 100644 --- a/data-prepper-plugins/geoip-processor/src/main/java/org/opensearch/dataprepper/plugins/geoip/exception/DownloadFailedException.java +++ b/data-prepper-plugins/geoip-processor/src/main/java/org/opensearch/dataprepper/plugins/geoip/exception/DownloadFailedException.java @@ -12,4 +12,9 @@ public class DownloadFailedException extends EngineFailureException { public DownloadFailedException(final String exceptionMsg) { super(exceptionMsg); } + + public DownloadFailedException(final String exceptionMsg, Throwable cause){ + super(exceptionMsg, cause); + } + } diff --git a/data-prepper-plugins/geoip-processor/src/main/java/org/opensearch/dataprepper/plugins/geoip/exception/EngineFailureException.java b/data-prepper-plugins/geoip-processor/src/main/java/org/opensearch/dataprepper/plugins/geoip/exception/EngineFailureException.java index 5eaf01fd00..503985505b 100644 --- a/data-prepper-plugins/geoip-processor/src/main/java/org/opensearch/dataprepper/plugins/geoip/exception/EngineFailureException.java +++ b/data-prepper-plugins/geoip-processor/src/main/java/org/opensearch/dataprepper/plugins/geoip/exception/EngineFailureException.java @@ -9,4 +9,8 @@ public class EngineFailureException extends RuntimeException { public EngineFailureException(final String exceptionMsg) { super(exceptionMsg); } + + public EngineFailureException(final String exceptionMsg, Throwable cause){ + super(exceptionMsg, cause); + } } diff --git a/data-prepper-plugins/geoip-processor/src/main/java/org/opensearch/dataprepper/plugins/geoip/extension/databasedownload/LocalDBDownloadService.java b/data-prepper-plugins/geoip-processor/src/main/java/org/opensearch/dataprepper/plugins/geoip/extension/databasedownload/LocalDBDownloadService.java index 4b8ea96676..70f3a793a1 100644 --- a/data-prepper-plugins/geoip-processor/src/main/java/org/opensearch/dataprepper/plugins/geoip/extension/databasedownload/LocalDBDownloadService.java +++ b/data-prepper-plugins/geoip-processor/src/main/java/org/opensearch/dataprepper/plugins/geoip/extension/databasedownload/LocalDBDownloadService.java @@ -5,10 +5,11 @@ package org.opensearch.dataprepper.plugins.geoip.extension.databasedownload; -import com.google.common.io.Files; import org.opensearch.dataprepper.plugins.geoip.extension.MaxMindDatabaseConfig; import java.io.File; +import java.nio.file.Files; +import java.nio.file.StandardCopyOption; import java.util.Set; /** @@ -37,8 +38,8 @@ public LocalDBDownloadService(final String destinationDirectory, final MaxMindDa public void initiateDownload() throws Exception { final Set strings = maxMindDatabaseConfig.getDatabasePaths().keySet(); for (final String key: strings) { - Files.copy(new File(maxMindDatabaseConfig.getDatabasePaths().get(key)), - new File(destinationDirectory + File.separator + key + MAXMIND_DATABASE_EXTENSION)); + Files.copy(new File(maxMindDatabaseConfig.getDatabasePaths().get(key)).toPath(), + new File(destinationDirectory + File.separator + key + MAXMIND_DATABASE_EXTENSION).toPath(), StandardCopyOption.REPLACE_EXISTING); } } } diff --git a/data-prepper-plugins/geoip-processor/src/main/java/org/opensearch/dataprepper/plugins/geoip/extension/databasedownload/S3DBService.java b/data-prepper-plugins/geoip-processor/src/main/java/org/opensearch/dataprepper/plugins/geoip/extension/databasedownload/S3DBService.java index 085f1d37bb..310961a8e9 100644 --- a/data-prepper-plugins/geoip-processor/src/main/java/org/opensearch/dataprepper/plugins/geoip/extension/databasedownload/S3DBService.java +++ b/data-prepper-plugins/geoip-processor/src/main/java/org/opensearch/dataprepper/plugins/geoip/extension/databasedownload/S3DBService.java @@ -8,6 +8,10 @@ import org.opensearch.dataprepper.plugins.geoip.exception.DownloadFailedException; import org.opensearch.dataprepper.plugins.geoip.extension.AwsAuthenticationOptionsConfig; import org.opensearch.dataprepper.plugins.geoip.extension.MaxMindDatabaseConfig; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import software.amazon.awssdk.core.sync.ResponseTransformer; import software.amazon.awssdk.services.s3.S3Client; import java.io.File; @@ -19,6 +23,7 @@ * Implementation class for Download through S3 */ public class S3DBService implements DBSource { + private static final Logger LOG = LoggerFactory.getLogger(S3DBService.class); private final AwsAuthenticationOptionsConfig awsAuthenticationOptionsConfig; private final String destinationDirectory; private final MaxMindDatabaseConfig maxMindDatabaseConfig; @@ -52,7 +57,7 @@ public void initiateDownload() { final String bucketName = uri.getHost(); buildRequestAndDownloadFile(bucketName, key, database); } catch (URISyntaxException ex) { - throw new DownloadFailedException("Failed to download database from S3." + ex.getMessage()); + throw new DownloadFailedException("Failed to download database from S3." + ex.getMessage(), ex); } } } @@ -64,15 +69,16 @@ public void initiateDownload() { * @param key Name of S3 object key * @param fileName Name of the file to save */ - private void buildRequestAndDownloadFile(final String bucketName, final String key, final String fileName) { + private void buildRequestAndDownloadFile(String bucketName, String key, String fileName) { + File destination = new File(this.destinationDirectory + File.separator + fileName + ".mmdb"); + final S3Client s3Client = this.createS3Client(); try { - final S3Client s3Client = createS3Client(); - - final File destination = new File(destinationDirectory + File.separator + fileName + MAXMIND_DATABASE_EXTENSION); - - s3Client.getObject(b -> b.bucket(bucketName).key(key), destination.toPath()); + s3Client.getObject((b) -> { + b.bucket(bucketName).key(key); + }, ResponseTransformer.toFile(destination)); } catch (Exception ex) { - throw new DownloadFailedException("Failed to download database from S3." + ex.getMessage()); + LOG.error("Failed to download database '{}' from S3: {}", fileName, ex.getMessage()); + throw new DownloadFailedException("Failed to download database from S3: " + ex.getMessage(), ex); } } diff --git a/data-prepper-plugins/geoip-processor/src/test/java/org/opensearch/dataprepper/plugins/geoip/extension/databasedownload/LocalDBDownloadServiceTest.java b/data-prepper-plugins/geoip-processor/src/test/java/org/opensearch/dataprepper/plugins/geoip/extension/databasedownload/LocalDBDownloadServiceTest.java index 9fdbdcb0ee..b361d1bbb6 100644 --- a/data-prepper-plugins/geoip-processor/src/test/java/org/opensearch/dataprepper/plugins/geoip/extension/databasedownload/LocalDBDownloadServiceTest.java +++ b/data-prepper-plugins/geoip-processor/src/test/java/org/opensearch/dataprepper/plugins/geoip/extension/databasedownload/LocalDBDownloadServiceTest.java @@ -15,9 +15,11 @@ import java.io.File; import java.io.FileWriter; import java.io.IOException; +import java.nio.file.Files; import java.util.Map; import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.Mockito.when; @@ -40,6 +42,30 @@ void initiateDownloadTest() throws IOException { assertTrue(new File(destinationDirectory + File.separator + "filename.mmdb").exists()); } + @Test + void testOverwriteFunctionality() throws Exception { + downloadThroughLocalPath = createObjectUnderTest(); + + String initialContent = "Initial database content\nVersion: 1.0"; + createFileWithContent(sourceDirectory + File.separator + "SampleFile.mmdb", initialContent); + + downloadThroughLocalPath.initiateDownload(); + + File destinationFile = new File(destinationDirectory + File.separator + "filename.mmdb"); + assertTrue(destinationFile.exists()); + + String copiedContent = Files.readString(destinationFile.toPath()); + assertEquals(initialContent, copiedContent); + + String updatedContent = "Updated database content\nVersion: 2.0"; + createFileWithContent(sourceDirectory + File.separator + "SampleFile.mmdb", updatedContent); + + downloadThroughLocalPath.initiateDownload(); + + String finalContent = Files.readString(destinationFile.toPath()); + assertEquals(updatedContent, finalContent); + } + private LocalDBDownloadService createObjectUnderTest() { when(maxMindDatabaseConfig.getDatabasePaths()).thenReturn(Map.of("filename", sourceDirectory + File.separator + "SampleFile.mmdb")); createFolder(destinationDirectory); @@ -58,8 +84,12 @@ private void generateSampleFiles() throws IOException { String content = "This is sample file"; createFolder(sourceDirectory); - new File(sourceDirectory + File.separator + fileName); - try (FileWriter writer = new FileWriter(sourceDirectory + File.separator + fileName)) { + createFileWithContent(sourceDirectory + File.separator + fileName, content); + } + + private void createFileWithContent(String filePath, String content) throws IOException { + createFolder(sourceDirectory); + try (FileWriter writer = new FileWriter(filePath)) { writer.write(content); } } @@ -81,4 +111,4 @@ public void deleteDirectory(final File file) { file.delete(); } } -} \ No newline at end of file +} diff --git a/data-prepper-plugins/geoip-processor/src/test/java/org/opensearch/dataprepper/plugins/geoip/extension/databasedownload/S3DBServiceTest.java b/data-prepper-plugins/geoip-processor/src/test/java/org/opensearch/dataprepper/plugins/geoip/extension/databasedownload/S3DBServiceTest.java index c6dacfc738..1eb5fb2629 100644 --- a/data-prepper-plugins/geoip-processor/src/test/java/org/opensearch/dataprepper/plugins/geoip/extension/databasedownload/S3DBServiceTest.java +++ b/data-prepper-plugins/geoip-processor/src/test/java/org/opensearch/dataprepper/plugins/geoip/extension/databasedownload/S3DBServiceTest.java @@ -8,30 +8,64 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.api.io.TempDir; +import org.mockito.ArgumentCaptor; import org.mockito.Mock; +import org.mockito.MockedStatic; import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.dataprepper.plugins.geoip.exception.DownloadFailedException; import org.opensearch.dataprepper.plugins.geoip.extension.AwsAuthenticationOptionsConfig; import org.opensearch.dataprepper.plugins.geoip.extension.MaxMindDatabaseConfig; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.core.sync.ResponseTransformer; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.S3ClientBuilder; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import java.io.File; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Path; import java.util.Map; +import java.util.function.Consumer; +import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mockStatic; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @ExtendWith(MockitoExtension.class) class S3DBServiceTest { - private static final String S3_URI = "s3://mybucket10012023/GeoLite2/"; + private static final String S3_URI = "s3://mybucket10012023/GeoLite2/test-database.mmdb"; private static final String DATABASE_DIR = "blue_database"; + private static final String DATABASE_NAME = "test-database"; + @Mock private MaxMindDatabaseConfig maxMindDatabaseConfig; @Mock private AwsAuthenticationOptionsConfig awsAuthenticationOptionsConfig; + @Mock + private S3Client s3Client; + @Mock + private S3ClientBuilder s3ClientBuilder; + @Mock + private AwsCredentialsProvider credentialsProvider; + + @TempDir + Path tempDir; @BeforeEach void setUp() { - when(maxMindDatabaseConfig.getDatabasePaths()).thenReturn(Map.of("database-name", S3_URI)); + when(maxMindDatabaseConfig.getDatabasePaths()).thenReturn(Map.of(DATABASE_NAME, S3_URI)); + when(awsAuthenticationOptionsConfig.getAwsRegion()).thenReturn(Region.US_EAST_1); + when(awsAuthenticationOptionsConfig.authenticateAwsConfiguration()).thenReturn(credentialsProvider); } @Test @@ -40,7 +74,44 @@ void initiateDownloadTest_DownloadFailedException() { assertThrows(DownloadFailedException.class, () -> downloadThroughS3.initiateDownload()); } + @Test + void testOverwriteFunctionality() throws IOException { + String testDatabaseDir = tempDir.toString(); + + String initialContent = "Initial database content\nVersion: 1.0"; + String updatedContent = "Updated database content\nVersion: 2.0"; + + File destinationFile = new File(testDatabaseDir + File.separator + DATABASE_NAME + ".mmdb"); + Files.createDirectories(destinationFile.getParentFile().toPath()); + Files.write(destinationFile.toPath(), initialContent.getBytes(StandardCharsets.UTF_8)); + + String readContent = Files.readString(destinationFile.toPath()); + assertEquals(initialContent, readContent); + + try (MockedStatic s3ClientMockedStatic = mockStatic(S3Client.class)) { + when(s3ClientBuilder.region(any(Region.class))).thenReturn(s3ClientBuilder); + when(s3ClientBuilder.credentialsProvider(any(AwsCredentialsProvider.class))).thenReturn(s3ClientBuilder); + when(s3ClientBuilder.build()).thenReturn(s3Client); + s3ClientMockedStatic.when(S3Client::builder).thenReturn(s3ClientBuilder); + + ArgumentCaptor> requestCaptor = ArgumentCaptor.forClass(Consumer.class); + ArgumentCaptor> transformerCaptor = ArgumentCaptor.forClass(ResponseTransformer.class); + + when(s3Client.getObject(requestCaptor.capture(), transformerCaptor.capture())).thenAnswer(invocation -> { + Files.write(destinationFile.toPath(), updatedContent.getBytes(StandardCharsets.UTF_8)); + return null; + }); + + S3DBService s3DBService = createObjectUnderTest(); + s3DBService.initiateDownload(); + } + String finalContent = Files.readString(destinationFile.toPath()); + assertEquals(updatedContent, finalContent); + verify(s3Client, times(1)).getObject(any(Consumer.class), any(ResponseTransformer.class)); + + } + private S3DBService createObjectUnderTest() { return new S3DBService(awsAuthenticationOptionsConfig, DATABASE_DIR, maxMindDatabaseConfig); } -} \ No newline at end of file +}