Skip to content

Commit 69eb534

Browse files
opensearch-trigger-bot[bot]KirtanKakadiyakirtanhk
authored
Fix S3DBService and LocalDBService file overwrite handling during downloads (#5911) (#5934)
Fix S3DBService/LocalDBService file overwrite handling during downloads (cherry picked from commit b87890a) Signed-off-by: kirtanhk <kirtanhk@amazon.com> Co-authored-by: Kirtan Kakadiya <35823164+KirtanKakadiya@users.noreply.github.com> Co-authored-by: kirtanhk <kirtanhk@amazon.com>
1 parent 779825a commit 69eb534

6 files changed

Lines changed: 134 additions & 17 deletions

File tree

data-prepper-plugins/geoip-processor/src/main/java/org/opensearch/dataprepper/plugins/geoip/exception/DownloadFailedException.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,9 @@ public class DownloadFailedException extends EngineFailureException {
1212
public DownloadFailedException(final String exceptionMsg) {
1313
super(exceptionMsg);
1414
}
15+
16+
public DownloadFailedException(final String exceptionMsg, Throwable cause){
17+
super(exceptionMsg, cause);
18+
}
19+
1520
}

data-prepper-plugins/geoip-processor/src/main/java/org/opensearch/dataprepper/plugins/geoip/exception/EngineFailureException.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,8 @@ public class EngineFailureException extends RuntimeException {
99
public EngineFailureException(final String exceptionMsg) {
1010
super(exceptionMsg);
1111
}
12+
13+
public EngineFailureException(final String exceptionMsg, Throwable cause){
14+
super(exceptionMsg, cause);
15+
}
1216
}

data-prepper-plugins/geoip-processor/src/main/java/org/opensearch/dataprepper/plugins/geoip/extension/databasedownload/LocalDBDownloadService.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,11 @@
55

66
package org.opensearch.dataprepper.plugins.geoip.extension.databasedownload;
77

8-
import com.google.common.io.Files;
98
import org.opensearch.dataprepper.plugins.geoip.extension.MaxMindDatabaseConfig;
109

1110
import java.io.File;
11+
import java.nio.file.Files;
12+
import java.nio.file.StandardCopyOption;
1213
import java.util.Set;
1314

1415
/**
@@ -37,8 +38,8 @@ public LocalDBDownloadService(final String destinationDirectory, final MaxMindDa
3738
public void initiateDownload() throws Exception {
3839
final Set<String> strings = maxMindDatabaseConfig.getDatabasePaths().keySet();
3940
for (final String key: strings) {
40-
Files.copy(new File(maxMindDatabaseConfig.getDatabasePaths().get(key)),
41-
new File(destinationDirectory + File.separator + key + MAXMIND_DATABASE_EXTENSION));
41+
Files.copy(new File(maxMindDatabaseConfig.getDatabasePaths().get(key)).toPath(),
42+
new File(destinationDirectory + File.separator + key + MAXMIND_DATABASE_EXTENSION).toPath(), StandardCopyOption.REPLACE_EXISTING);
4243
}
4344
}
4445
}

data-prepper-plugins/geoip-processor/src/main/java/org/opensearch/dataprepper/plugins/geoip/extension/databasedownload/S3DBService.java

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,10 @@
88
import org.opensearch.dataprepper.plugins.geoip.exception.DownloadFailedException;
99
import org.opensearch.dataprepper.plugins.geoip.extension.AwsAuthenticationOptionsConfig;
1010
import org.opensearch.dataprepper.plugins.geoip.extension.MaxMindDatabaseConfig;
11+
import org.slf4j.Logger;
12+
import org.slf4j.LoggerFactory;
13+
14+
import software.amazon.awssdk.core.sync.ResponseTransformer;
1115
import software.amazon.awssdk.services.s3.S3Client;
1216

1317
import java.io.File;
@@ -19,6 +23,7 @@
1923
* Implementation class for Download through S3
2024
*/
2125
public class S3DBService implements DBSource {
26+
private static final Logger LOG = LoggerFactory.getLogger(S3DBService.class);
2227
private final AwsAuthenticationOptionsConfig awsAuthenticationOptionsConfig;
2328
private final String destinationDirectory;
2429
private final MaxMindDatabaseConfig maxMindDatabaseConfig;
@@ -52,7 +57,7 @@ public void initiateDownload() {
5257
final String bucketName = uri.getHost();
5358
buildRequestAndDownloadFile(bucketName, key, database);
5459
} catch (URISyntaxException ex) {
55-
throw new DownloadFailedException("Failed to download database from S3." + ex.getMessage());
60+
throw new DownloadFailedException("Failed to download database from S3." + ex.getMessage(), ex);
5661
}
5762
}
5863
}
@@ -64,15 +69,16 @@ public void initiateDownload() {
6469
* @param key Name of S3 object key
6570
* @param fileName Name of the file to save
6671
*/
67-
private void buildRequestAndDownloadFile(final String bucketName, final String key, final String fileName) {
72+
private void buildRequestAndDownloadFile(String bucketName, String key, String fileName) {
73+
File destination = new File(this.destinationDirectory + File.separator + fileName + ".mmdb");
74+
final S3Client s3Client = this.createS3Client();
6875
try {
69-
final S3Client s3Client = createS3Client();
70-
71-
final File destination = new File(destinationDirectory + File.separator + fileName + MAXMIND_DATABASE_EXTENSION);
72-
73-
s3Client.getObject(b -> b.bucket(bucketName).key(key), destination.toPath());
76+
s3Client.getObject((b) -> {
77+
b.bucket(bucketName).key(key);
78+
}, ResponseTransformer.toFile(destination));
7479
} catch (Exception ex) {
75-
throw new DownloadFailedException("Failed to download database from S3." + ex.getMessage());
80+
LOG.error("Failed to download database '{}' from S3: {}", fileName, ex.getMessage());
81+
throw new DownloadFailedException("Failed to download database from S3: " + ex.getMessage(), ex);
7682
}
7783
}
7884

data-prepper-plugins/geoip-processor/src/test/java/org/opensearch/dataprepper/plugins/geoip/extension/databasedownload/LocalDBDownloadServiceTest.java

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,11 @@
1515
import java.io.File;
1616
import java.io.FileWriter;
1717
import java.io.IOException;
18+
import java.nio.file.Files;
1819
import java.util.Map;
1920

2021
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
22+
import static org.junit.jupiter.api.Assertions.assertEquals;
2123
import static org.junit.jupiter.api.Assertions.assertTrue;
2224
import static org.mockito.Mockito.when;
2325

@@ -40,6 +42,30 @@ void initiateDownloadTest() throws IOException {
4042
assertTrue(new File(destinationDirectory + File.separator + "filename.mmdb").exists());
4143
}
4244

45+
@Test
46+
void testOverwriteFunctionality() throws Exception {
47+
downloadThroughLocalPath = createObjectUnderTest();
48+
49+
String initialContent = "Initial database content\nVersion: 1.0";
50+
createFileWithContent(sourceDirectory + File.separator + "SampleFile.mmdb", initialContent);
51+
52+
downloadThroughLocalPath.initiateDownload();
53+
54+
File destinationFile = new File(destinationDirectory + File.separator + "filename.mmdb");
55+
assertTrue(destinationFile.exists());
56+
57+
String copiedContent = Files.readString(destinationFile.toPath());
58+
assertEquals(initialContent, copiedContent);
59+
60+
String updatedContent = "Updated database content\nVersion: 2.0";
61+
createFileWithContent(sourceDirectory + File.separator + "SampleFile.mmdb", updatedContent);
62+
63+
downloadThroughLocalPath.initiateDownload();
64+
65+
String finalContent = Files.readString(destinationFile.toPath());
66+
assertEquals(updatedContent, finalContent);
67+
}
68+
4369
private LocalDBDownloadService createObjectUnderTest() {
4470
when(maxMindDatabaseConfig.getDatabasePaths()).thenReturn(Map.of("filename", sourceDirectory + File.separator + "SampleFile.mmdb"));
4571
createFolder(destinationDirectory);
@@ -58,8 +84,12 @@ private void generateSampleFiles() throws IOException {
5884
String content = "This is sample file";
5985

6086
createFolder(sourceDirectory);
61-
new File(sourceDirectory + File.separator + fileName);
62-
try (FileWriter writer = new FileWriter(sourceDirectory + File.separator + fileName)) {
87+
createFileWithContent(sourceDirectory + File.separator + fileName, content);
88+
}
89+
90+
private void createFileWithContent(String filePath, String content) throws IOException {
91+
createFolder(sourceDirectory);
92+
try (FileWriter writer = new FileWriter(filePath)) {
6393
writer.write(content);
6494
}
6595
}
@@ -81,4 +111,4 @@ public void deleteDirectory(final File file) {
81111
file.delete();
82112
}
83113
}
84-
}
114+
}

data-prepper-plugins/geoip-processor/src/test/java/org/opensearch/dataprepper/plugins/geoip/extension/databasedownload/S3DBServiceTest.java

Lines changed: 74 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,30 +8,64 @@
88
import org.junit.jupiter.api.BeforeEach;
99
import org.junit.jupiter.api.Test;
1010
import org.junit.jupiter.api.extension.ExtendWith;
11+
import org.junit.jupiter.api.io.TempDir;
12+
import org.mockito.ArgumentCaptor;
1113
import org.mockito.Mock;
14+
import org.mockito.MockedStatic;
1215
import org.mockito.junit.jupiter.MockitoExtension;
1316
import org.opensearch.dataprepper.plugins.geoip.exception.DownloadFailedException;
1417
import org.opensearch.dataprepper.plugins.geoip.extension.AwsAuthenticationOptionsConfig;
1518
import org.opensearch.dataprepper.plugins.geoip.extension.MaxMindDatabaseConfig;
19+
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
20+
import software.amazon.awssdk.core.sync.ResponseTransformer;
21+
import software.amazon.awssdk.regions.Region;
22+
import software.amazon.awssdk.services.s3.S3Client;
23+
import software.amazon.awssdk.services.s3.S3ClientBuilder;
24+
import software.amazon.awssdk.services.s3.model.GetObjectRequest;
25+
import software.amazon.awssdk.services.s3.model.GetObjectResponse;
1626

27+
import java.io.File;
28+
import java.io.IOException;
29+
import java.nio.charset.StandardCharsets;
30+
import java.nio.file.Files;
31+
import java.nio.file.Path;
1732
import java.util.Map;
33+
import java.util.function.Consumer;
1834

35+
import static org.junit.jupiter.api.Assertions.assertEquals;
1936
import static org.junit.jupiter.api.Assertions.assertThrows;
37+
import static org.mockito.ArgumentMatchers.any;
38+
import static org.mockito.Mockito.mockStatic;
39+
import static org.mockito.Mockito.times;
40+
import static org.mockito.Mockito.verify;
2041
import static org.mockito.Mockito.when;
2142

2243
@ExtendWith(MockitoExtension.class)
2344
class S3DBServiceTest {
2445

25-
private static final String S3_URI = "s3://mybucket10012023/GeoLite2/";
46+
private static final String S3_URI = "s3://mybucket10012023/GeoLite2/test-database.mmdb";
2647
private static final String DATABASE_DIR = "blue_database";
48+
private static final String DATABASE_NAME = "test-database";
49+
2750
@Mock
2851
private MaxMindDatabaseConfig maxMindDatabaseConfig;
2952
@Mock
3053
private AwsAuthenticationOptionsConfig awsAuthenticationOptionsConfig;
54+
@Mock
55+
private S3Client s3Client;
56+
@Mock
57+
private S3ClientBuilder s3ClientBuilder;
58+
@Mock
59+
private AwsCredentialsProvider credentialsProvider;
60+
61+
@TempDir
62+
Path tempDir;
3163

3264
@BeforeEach
3365
void setUp() {
34-
when(maxMindDatabaseConfig.getDatabasePaths()).thenReturn(Map.of("database-name", S3_URI));
66+
when(maxMindDatabaseConfig.getDatabasePaths()).thenReturn(Map.of(DATABASE_NAME, S3_URI));
67+
when(awsAuthenticationOptionsConfig.getAwsRegion()).thenReturn(Region.US_EAST_1);
68+
when(awsAuthenticationOptionsConfig.authenticateAwsConfiguration()).thenReturn(credentialsProvider);
3569
}
3670

3771
@Test
@@ -40,7 +74,44 @@ void initiateDownloadTest_DownloadFailedException() {
4074
assertThrows(DownloadFailedException.class, () -> downloadThroughS3.initiateDownload());
4175
}
4276

77+
@Test
78+
void testOverwriteFunctionality() throws IOException {
79+
String testDatabaseDir = tempDir.toString();
80+
81+
String initialContent = "Initial database content\nVersion: 1.0";
82+
String updatedContent = "Updated database content\nVersion: 2.0";
83+
84+
File destinationFile = new File(testDatabaseDir + File.separator + DATABASE_NAME + ".mmdb");
85+
Files.createDirectories(destinationFile.getParentFile().toPath());
86+
Files.write(destinationFile.toPath(), initialContent.getBytes(StandardCharsets.UTF_8));
87+
88+
String readContent = Files.readString(destinationFile.toPath());
89+
assertEquals(initialContent, readContent);
90+
91+
try (MockedStatic<S3Client> s3ClientMockedStatic = mockStatic(S3Client.class)) {
92+
when(s3ClientBuilder.region(any(Region.class))).thenReturn(s3ClientBuilder);
93+
when(s3ClientBuilder.credentialsProvider(any(AwsCredentialsProvider.class))).thenReturn(s3ClientBuilder);
94+
when(s3ClientBuilder.build()).thenReturn(s3Client);
95+
s3ClientMockedStatic.when(S3Client::builder).thenReturn(s3ClientBuilder);
96+
97+
ArgumentCaptor<Consumer<GetObjectRequest.Builder>> requestCaptor = ArgumentCaptor.forClass(Consumer.class);
98+
ArgumentCaptor<ResponseTransformer<GetObjectResponse, ?>> transformerCaptor = ArgumentCaptor.forClass(ResponseTransformer.class);
99+
100+
when(s3Client.getObject(requestCaptor.capture(), transformerCaptor.capture())).thenAnswer(invocation -> {
101+
Files.write(destinationFile.toPath(), updatedContent.getBytes(StandardCharsets.UTF_8));
102+
return null;
103+
});
104+
105+
S3DBService s3DBService = createObjectUnderTest();
106+
s3DBService.initiateDownload();
107+
}
108+
String finalContent = Files.readString(destinationFile.toPath());
109+
assertEquals(updatedContent, finalContent);
110+
verify(s3Client, times(1)).getObject(any(Consumer.class), any(ResponseTransformer.class));
111+
112+
}
113+
43114
private S3DBService createObjectUnderTest() {
44115
return new S3DBService(awsAuthenticationOptionsConfig, DATABASE_DIR, maxMindDatabaseConfig);
45116
}
46-
}
117+
}

0 commit comments

Comments
 (0)