diff --git a/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/documentdb/MongoTasksRefresher.java b/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/documentdb/MongoTasksRefresher.java index 3fea4680e8..0f4d7b4fde 100644 --- a/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/documentdb/MongoTasksRefresher.java +++ b/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/documentdb/MongoTasksRefresher.java @@ -44,9 +44,13 @@ public class MongoTasksRefresher implements PluginConfigObserver> buffer, final EnhancedSourceCoordinator sourceCoordinator, @@ -81,6 +85,8 @@ public void update(MongoDBSourceConfig pluginConfig) { currentExecutor.shutdownNow(); refreshJobs(pluginConfig); currentMongoDBSourceConfig = pluginConfig; + forceRefreshAttempts = 0; + lastForceRefreshTime = 0; } catch (Exception e) { executorRefreshErrorsCounter.increment(); LOG.error("Refreshing executor failed.", e); @@ -88,6 +94,31 @@ public void update(MongoDBSourceConfig pluginConfig) { } } + public void forceRefresh() { + if (forceRefreshAttempts >= MAX_FORCE_REFRESH_ATTEMPTS) { + LOG.warn("Max force refresh attempts ({}) reached. Waiting for next scheduled credential refresh.", + MAX_FORCE_REFRESH_ATTEMPTS); + return; + } + final long now = System.currentTimeMillis(); + final long backoff = BASE_BACKOFF_MS * (1L << forceRefreshAttempts); + if (now - lastForceRefreshTime < backoff) { + return; + } + lastForceRefreshTime = now; + forceRefreshAttempts++; + LOG.info("Forcing credential refresh due to authentication failure (attempt {}/{})", + forceRefreshAttempts, MAX_FORCE_REFRESH_ATTEMPTS); + try { + currentExecutor.shutdownNow(); + refreshJobs(currentMongoDBSourceConfig); + credentialsChangeCounter.increment(); + } catch (final Exception e) { + executorRefreshErrorsCounter.increment(); + LOG.error("Forced refresh failed.", e); + } + } + private void refreshJobs(MongoDBSourceConfig pluginConfig) { final List runnables = new ArrayList<>(); if (pluginConfig.getCollections().stream().anyMatch(CollectionConfig::isExport)) { @@ -98,7 +129,7 @@ private void refreshJobs(MongoDBSourceConfig pluginConfig) { } if (pluginConfig.getCollections().stream().anyMatch(CollectionConfig::isStream)) { runnables.add(new StreamScheduler( - sourceCoordinator, buffer, acknowledgementSetManager, pluginConfig, s3PathPrefix, pluginMetrics, documentDBAggregateMetrics)); + sourceCoordinator, buffer, acknowledgementSetManager, pluginConfig, s3PathPrefix, pluginMetrics, documentDBAggregateMetrics, this)); } this.currentExecutor = executorServiceFunction.apply(runnables.size()); runnables.forEach(currentExecutor::submit); diff --git a/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/stream/StreamScheduler.java b/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/stream/StreamScheduler.java index 60bac24d18..303d4ae282 100644 --- a/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/stream/StreamScheduler.java +++ b/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/stream/StreamScheduler.java @@ -1,5 +1,6 @@ package org.opensearch.dataprepper.plugins.mongo.stream; +import com.mongodb.MongoSecurityException; import org.opensearch.dataprepper.buffer.common.BufferAccumulator; import org.opensearch.dataprepper.metrics.PluginMetrics; import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSetManager; @@ -13,6 +14,7 @@ import org.opensearch.dataprepper.plugins.mongo.configuration.MongoDBSourceConfig; import org.opensearch.dataprepper.plugins.mongo.converter.PartitionKeyRecordConverter; import org.opensearch.dataprepper.plugins.mongo.coordination.partition.StreamPartition; +import org.opensearch.dataprepper.plugins.mongo.documentdb.MongoTasksRefresher; import org.opensearch.dataprepper.plugins.mongo.utils.DocumentDBSourceAggregateMetrics; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -46,13 +48,15 @@ public class StreamScheduler implements Runnable { private final String s3PathPrefix; private final PluginMetrics pluginMetrics; private final DocumentDBSourceAggregateMetrics documentDBAggregateMetrics; + private final MongoTasksRefresher mongoTasksRefresher; public StreamScheduler(final EnhancedSourceCoordinator sourceCoordinator, final Buffer> buffer, final AcknowledgementSetManager acknowledgementSetManager, final MongoDBSourceConfig sourceConfig, final String s3PathPrefix, final PluginMetrics pluginMetrics, - final DocumentDBSourceAggregateMetrics documentDBAggregateMetrics) { + final DocumentDBSourceAggregateMetrics documentDBAggregateMetrics, + final MongoTasksRefresher mongoTasksRefresher) { this.sourceCoordinator = sourceCoordinator; final BufferAccumulator> bufferAccumulator = BufferAccumulator.create(buffer, DEFAULT_BUFFER_BATCH_SIZE, BUFFER_TIMEOUT); recordBufferWriter = RecordBufferWriter.create(bufferAccumulator, pluginMetrics); @@ -62,6 +66,7 @@ public StreamScheduler(final EnhancedSourceCoordinator sourceCoordinator, this.s3PathPrefix = s3PathPrefix; this.pluginMetrics = pluginMetrics; this.documentDBAggregateMetrics = documentDBAggregateMetrics; + this.mongoTasksRefresher = mongoTasksRefresher; } @Override @@ -89,6 +94,9 @@ public void run() { } } catch (final Exception e) { LOG.error("Received an exception during stream processing from DocumentDB, backing off and retrying", e); + if (isCausedByMongoSecurityException(e) && mongoTasksRefresher != null) { + mongoTasksRefresher.forceRefresh(); + } if (streamPartition != null) { if (sourceConfig.isDisableS3ReadForLeader()) { System.clearProperty(STOP_S3_SCAN_PROCESSING_PROPERTY); @@ -131,4 +139,15 @@ private PartitionKeyRecordConverter getPartitionKeyRecordConverter(final StreamP return new PartitionKeyRecordConverter(streamPartition.getCollection(), StreamPartition.PARTITION_TYPE, s3Prefix); } + + private boolean isCausedByMongoSecurityException(final Throwable throwable) { + Throwable cause = throwable; + while (cause != null) { + if (cause instanceof MongoSecurityException) { + return true; + } + cause = cause.getCause(); + } + return false; + } } diff --git a/data-prepper-plugins/mongodb/src/test/java/org/opensearch/dataprepper/plugins/mongo/documentdb/MongoTasksRefresherTest.java b/data-prepper-plugins/mongodb/src/test/java/org/opensearch/dataprepper/plugins/mongo/documentdb/MongoTasksRefresherTest.java index 9ce93c8ded..dafdba81ec 100644 --- a/data-prepper-plugins/mongodb/src/test/java/org/opensearch/dataprepper/plugins/mongo/documentdb/MongoTasksRefresherTest.java +++ b/data-prepper-plugins/mongodb/src/test/java/org/opensearch/dataprepper/plugins/mongo/documentdb/MongoTasksRefresherTest.java @@ -24,6 +24,7 @@ import org.opensearch.dataprepper.plugins.mongo.stream.StreamScheduler; import org.opensearch.dataprepper.plugins.mongo.utils.DocumentDBSourceAggregateMetrics; +import java.lang.reflect.Field; import java.util.List; import java.util.UUID; import java.util.concurrent.ExecutorService; @@ -266,6 +267,81 @@ void testTaskRefreshShutdown() { verify(executorService).submit(any(StreamScheduler.class)); verify(executorService).shutdownNow(); verifyNoMoreInteractions(executorServiceFunction); + } + + @Test + void testForceRefreshRestartsExecutor() { + when(pluginMetrics.counter(CREDENTIALS_CHANGED)).thenReturn(credentialsChangeCounter); + final MongoTasksRefresher objectUnderTest = createObjectUnderTest(); + objectUnderTest.initialize(sourceConfig); + final ExecutorService newExecutorService = mock(ExecutorService.class); + when(executorServiceFunction.apply(anyInt())).thenReturn(newExecutorService); + objectUnderTest.forceRefresh(); + verify(executorService).shutdownNow(); + verify(credentialsChangeCounter).increment(); + verify(executorServiceFunction, times(2)).apply(eq(3)); + } + + @Test + void testForceRefreshStopsAfterMaxAttempts() throws Exception { + when(pluginMetrics.counter(CREDENTIALS_CHANGED)).thenReturn(credentialsChangeCounter); + final MongoTasksRefresher objectUnderTest = createObjectUnderTest(); + objectUnderTest.initialize(sourceConfig); + when(executorServiceFunction.apply(anyInt())).thenReturn(mock(ExecutorService.class)); + objectUnderTest.forceRefresh(); + resetLastForceRefreshTime(objectUnderTest); + objectUnderTest.forceRefresh(); + resetLastForceRefreshTime(objectUnderTest); + objectUnderTest.forceRefresh(); + resetLastForceRefreshTime(objectUnderTest); + // 4th attempt should be ignored (max reached) + objectUnderTest.forceRefresh(); + verify(credentialsChangeCounter, times(3)).increment(); + } + + @Test + void testForceRefreshCounterResetsOnCredentialChange() throws Exception { + when(pluginMetrics.counter(CREDENTIALS_CHANGED)).thenReturn(credentialsChangeCounter); + final MongoTasksRefresher objectUnderTest = createObjectUnderTest(); + objectUnderTest.initialize(sourceConfig); + when(executorServiceFunction.apply(anyInt())).thenReturn(mock(ExecutorService.class)); + objectUnderTest.forceRefresh(); + resetLastForceRefreshTime(objectUnderTest); + objectUnderTest.forceRefresh(); + resetLastForceRefreshTime(objectUnderTest); + objectUnderTest.forceRefresh(); + // Simulate credential change via update() + when(sourceConfig.getAuthenticationConfig()).thenReturn(credentialsConfig); + when(credentialsConfig.getUsername()).thenReturn(TEST_USERNAME); + when(credentialsConfig.getPassword()).thenReturn(TEST_PASSWORD); + final MongoDBSourceConfig newSourceConfig = mock(MongoDBSourceConfig.class); + when(newSourceConfig.getCollections()).thenReturn(List.of(collectionConfig)); + final MongoDBSourceConfig.AuthenticationConfig newCredentialsConfig = mock( + MongoDBSourceConfig.AuthenticationConfig.class); + when(newSourceConfig.getAuthenticationConfig()).thenReturn(newCredentialsConfig); + when(newCredentialsConfig.getUsername()).thenReturn(TEST_USERNAME); + when(newCredentialsConfig.getPassword()).thenReturn(TEST_PASSWORD + "_changed"); + objectUnderTest.update(newSourceConfig); + // Force refresh should work again after counter reset + objectUnderTest.forceRefresh(); + // 3 from forceRefresh + 1 from update + 1 from forceRefresh after reset + verify(credentialsChangeCounter, times(5)).increment(); + } + + @Test + void testForceRefreshHandlesException() { + when(pluginMetrics.counter(CREDENTIALS_CHANGED)).thenReturn(credentialsChangeCounter); + when(pluginMetrics.counter(EXECUTOR_REFRESH_ERRORS)).thenReturn(executorRefreshErrorsCounter); + final MongoTasksRefresher objectUnderTest = createObjectUnderTest(); + objectUnderTest.initialize(sourceConfig); + doThrow(RuntimeException.class).when(executorService).shutdownNow(); + objectUnderTest.forceRefresh(); + verify(executorRefreshErrorsCounter).increment(); + } + private void resetLastForceRefreshTime(final MongoTasksRefresher refresher) throws Exception { + final Field field = MongoTasksRefresher.class.getDeclaredField("lastForceRefreshTime"); + field.setAccessible(true); + field.setLong(refresher, 0); } } \ No newline at end of file diff --git a/data-prepper-plugins/mongodb/src/test/java/org/opensearch/dataprepper/plugins/mongo/stream/StreamSchedulerTest.java b/data-prepper-plugins/mongodb/src/test/java/org/opensearch/dataprepper/plugins/mongo/stream/StreamSchedulerTest.java index 138cd4e7f7..d729ee1430 100644 --- a/data-prepper-plugins/mongodb/src/test/java/org/opensearch/dataprepper/plugins/mongo/stream/StreamSchedulerTest.java +++ b/data-prepper-plugins/mongodb/src/test/java/org/opensearch/dataprepper/plugins/mongo/stream/StreamSchedulerTest.java @@ -1,5 +1,6 @@ package org.opensearch.dataprepper.plugins.mongo.stream; +import com.mongodb.MongoSecurityException; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; @@ -17,6 +18,7 @@ import org.opensearch.dataprepper.plugins.mongo.configuration.MongoDBSourceConfig; import org.opensearch.dataprepper.plugins.mongo.converter.PartitionKeyRecordConverter; import org.opensearch.dataprepper.plugins.mongo.coordination.partition.StreamPartition; +import org.opensearch.dataprepper.plugins.mongo.documentdb.MongoTasksRefresher; import org.opensearch.dataprepper.plugins.mongo.utils.DocumentDBSourceAggregateMetrics; import java.time.Duration; @@ -71,13 +73,16 @@ public class StreamSchedulerTest { @Mock private StreamWorker streamWorker; + @Mock + private MongoTasksRefresher mongoTasksRefresher; + private StreamScheduler streamScheduler; @BeforeEach void setup() { lenient().when(sourceConfig.getCollections()).thenReturn(List.of(collectionConfig)); - streamScheduler = new StreamScheduler(sourceCoordinator, buffer, acknowledgementSetManager, sourceConfig, S3_PATH_PREFIX, pluginMetrics, documentDBSourceAggregateMetrics); + streamScheduler = new StreamScheduler(sourceCoordinator, buffer, acknowledgementSetManager, sourceConfig, S3_PATH_PREFIX, pluginMetrics, documentDBSourceAggregateMetrics, mongoTasksRefresher); } @@ -205,6 +210,66 @@ void test_stream_sourceCoordinatorThrowsException() { @Test void test_stream_withNullS3PathPrefix() { - assertThrows(IllegalArgumentException.class, () -> new StreamScheduler(sourceCoordinator, buffer, acknowledgementSetManager, sourceConfig, null, pluginMetrics, documentDBSourceAggregateMetrics)); + assertThrows(IllegalArgumentException.class, () -> new StreamScheduler(sourceCoordinator, buffer, acknowledgementSetManager, sourceConfig, null, pluginMetrics, documentDBSourceAggregateMetrics, mongoTasksRefresher)); + } + + @Test + void test_stream_mongoSecurityException_triggersForceRefresh() { + final String collection = UUID.randomUUID().toString(); + final StreamPartition streamPartition = new StreamPartition(collection, null); + given(sourceCoordinator.acquireAvailablePartition(StreamPartition.PARTITION_TYPE)).willReturn(Optional.of(streamPartition)); + given(collectionConfig.getCollection()).willReturn(collection); + final int streamBatchSize = 1000; + given(collectionConfig.getStreamBatchSize()).willReturn(streamBatchSize); + + final ExecutorService executorService = Executors.newSingleThreadExecutor(); + final Future future = executorService.submit(() -> { + try (MockedStatic streamWorkerMockedStatic = mockStatic(StreamWorker.class)) { + final MongoSecurityException securityException = new MongoSecurityException( + null, "auth failed", new RuntimeException("credential revoked")); + streamWorkerMockedStatic.when(() -> StreamWorker.create(any(RecordBufferWriter.class), any(PartitionKeyRecordConverter.class), eq(sourceConfig), + any(StreamAcknowledgementManager.class), any(DataStreamPartitionCheckpoint.class), eq(pluginMetrics), eq(DEFAULT_RECORD_FLUSH_BATCH_SIZE), + eq(DEFAULT_CHECKPOINT_INTERVAL_MILLS), eq(DEFAULT_BUFFER_WRITE_INTERVAL_MILLS), eq(streamBatchSize), any(DocumentDBSourceAggregateMetrics.class))) + .thenThrow(new RuntimeException(securityException)); + streamScheduler.run(); + } + }); + + await() + .atMost(Duration.ofSeconds(5)) + .untilAsserted(() -> verify(mongoTasksRefresher).forceRefresh()); + + future.cancel(true); + executorService.shutdownNow(); + } + + @Test + void test_stream_nonSecurityException_doesNotTriggerForceRefresh() { + final String collection = UUID.randomUUID().toString(); + final StreamPartition streamPartition = new StreamPartition(collection, null); + given(sourceCoordinator.acquireAvailablePartition(StreamPartition.PARTITION_TYPE)).willReturn(Optional.of(streamPartition)); + given(collectionConfig.getCollection()).willReturn(collection); + final int streamBatchSize = 1000; + given(collectionConfig.getStreamBatchSize()).willReturn(streamBatchSize); + + final ExecutorService executorService = Executors.newSingleThreadExecutor(); + final Future future = executorService.submit(() -> { + try (MockedStatic streamWorkerMockedStatic = mockStatic(StreamWorker.class)) { + streamWorkerMockedStatic.when(() -> StreamWorker.create(any(RecordBufferWriter.class), any(PartitionKeyRecordConverter.class), eq(sourceConfig), + any(StreamAcknowledgementManager.class), any(DataStreamPartitionCheckpoint.class), eq(pluginMetrics), eq(DEFAULT_RECORD_FLUSH_BATCH_SIZE), + eq(DEFAULT_CHECKPOINT_INTERVAL_MILLS), eq(DEFAULT_BUFFER_WRITE_INTERVAL_MILLS), eq(streamBatchSize), any(DocumentDBSourceAggregateMetrics.class))) + .thenThrow(RuntimeException.class); + streamScheduler.run(); + } + }); + + await() + .atMost(Duration.ofSeconds(5)) + .untilAsserted(() -> verify(sourceCoordinator).giveUpPartition(streamPartition)); + + verify(mongoTasksRefresher, never()).forceRefresh(); + + future.cancel(true); + executorService.shutdownNow(); } }