diff --git a/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/client/BsonHelper.java b/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/client/BsonHelper.java index e9613f4a49..1250570ba1 100644 --- a/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/client/BsonHelper.java +++ b/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/client/BsonHelper.java @@ -186,7 +186,7 @@ private static Bson buildQuery(final Function function, final Stri } } - private static boolean isClassNumber(final String className) { + public static boolean isClassNumber(final String className) { return className.equals("java.lang.Integer") || className.equals("java.lang.Long") || className.equals("java.lang.Double") || className.equals("org.bson.types.Decimal128"); } diff --git a/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/export/MongoDBExportPartitionSupplier.java b/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/export/MongoDBExportPartitionSupplier.java index dfbf518318..56ba8c7147 100644 --- a/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/export/MongoDBExportPartitionSupplier.java +++ b/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/export/MongoDBExportPartitionSupplier.java @@ -6,13 +6,14 @@ package org.opensearch.dataprepper.plugins.mongo.export; import com.mongodb.MongoClientException; -import com.mongodb.client.FindIterable; import com.mongodb.client.MongoClient; import com.mongodb.client.MongoCollection; -import com.mongodb.client.MongoCursor; import com.mongodb.client.MongoDatabase; import com.mongodb.client.model.Filters; +import com.mongodb.client.model.Projections; +import com.mongodb.client.model.Sorts; import org.bson.Document; +import org.bson.conversions.Bson; import 
org.opensearch.dataprepper.model.source.coordinator.PartitionIdentifier; import org.opensearch.dataprepper.model.source.coordinator.enhanced.EnhancedSourceCoordinator; import org.opensearch.dataprepper.plugins.mongo.client.BsonHelper; @@ -37,6 +38,9 @@ public class MongoDBExportPartitionSupplier implements Function|||| private static final String COLLECTION_SPLITTER = "\\."; + private static final Bson ID_PROJECTION = Projections.include("_id"); + private static final Bson ID_ASC = Sorts.ascending("_id"); + private static final Bson ID_DESC = Sorts.descending("_id"); private final MongoDBSourceConfig sourceConfig; private final EnhancedSourceCoordinator enhancedSourceCoordinator; @@ -50,95 +54,143 @@ public MongoDBExportPartitionSupplier(final MongoDBSourceConfig sourceConfig, this.documentDBAggregateMetrics = documentDBAggregateMetrics; } + /** + * Detects whether the collection has a uniform _id type by checking the first and last documents. + * If uniform, we can use a simple Filters.gt() instead of the complex $or query across all BSON types. 
+ */ + boolean isUniformIdType(final MongoCollection col) { + final Document first = col.find().projection(ID_PROJECTION).sort(ID_ASC).limit(1).first(); + final Document last = col.find().projection(ID_PROJECTION).sort(ID_DESC).limit(1).first(); + if (first == null || last == null) { + return true; + } + final String firstType = first.get("_id").getClass().getName(); + final String lastType = last.get("_id").getClass().getName(); + if (BsonHelper.isClassNumber(firstType) && BsonHelper.isClassNumber(lastType)) { + return true; + } + return firstType.equals(lastType); + } + + private Bson buildNextStartFilter(final Object lastLteValue, final String lteClassName, final boolean uniformType) { + if (uniformType) { + return Filters.gt("_id", lastLteValue); + } + final String lteValueString = BsonHelper.getPartitionStringFromMongoDBId(lastLteValue, lteClassName); + return buildGtQuery(lteValueString, lteClassName, MAX_KEY); + } + + private void addPartition(final List partitions, final String collectionDbName, + final Object gteValue, final String gteClassName, + final Object lteValue, final String lteClassName) { + final String gteValueString = BsonHelper.getPartitionStringFromMongoDBId(gteValue, gteClassName); + final String lteValueString = BsonHelper.getPartitionStringFromMongoDBId(lteValue, lteClassName); + LOG.debug("Partition of {} : { gte: {} class: {}, lte: {} class {} }", + collectionDbName, gteValueString, gteClassName, lteValueString, lteClassName); + partitions.add(PartitionIdentifier.builder() + .withPartitionKey(String.format(MONGODB_PARTITION_KEY_FORMAT, + collectionDbName, gteValueString, lteValueString, gteClassName, lteClassName)) + .build()); + } + private PartitionIdentifierBatch buildPartitions(final ExportPartition exportPartition) { documentDBAggregateMetrics.getExportApiInvocations().increment(); final List collectionPartitions = new ArrayList<>(); final String collectionDbName = exportPartition.getCollection(); - List collection = 
List.of(collectionDbName.split(COLLECTION_SPLITTER)); + final List collection = List.of(collectionDbName.split(COLLECTION_SPLITTER)); if (collection.size() < 2) { documentDBAggregateMetrics.getExport4xxErrors().increment(); throw new IllegalArgumentException("Invalid Collection Name. Must be in db.collection format"); } - final Optional exportProgressStateOptional = exportPartition - .getProgressState(); - final Object lastEndDocId = exportProgressStateOptional.map( - ExportProgressState::getLastEndDocId).orElse(null); + + final Optional exportProgressStateOptional = exportPartition.getProgressState(); + final Object lastEndDocId = exportProgressStateOptional.map(ExportProgressState::getLastEndDocId).orElse(null); boolean isLastBatch = false; Object endDocId = lastEndDocId; + try (MongoClient mongoClient = MongoDBConnection.getMongoClient(sourceConfig)) { final MongoDatabase db = mongoClient.getDatabase(collection.get(0)); - final MongoCollection col = db.getCollection(collectionDbName.substring(collection.get(0).length()+1)); + final MongoCollection col = db.getCollection( + collectionDbName.substring(collection.get(0).length() + 1)); final int partitionSize = exportPartition.getPartitionSize(); - FindIterable startIterable; + + final boolean uniformType = isUniformIdType(col); + LOG.info("Collection {} has {} _id type. Using {} partition query strategy.", + collectionDbName, uniformType ? "uniform" : "mixed", uniformType ? 
"simple $gt" : "$or-based"); + + Bson startFilter; if (lastEndDocId != null) { - startIterable = col.find(Filters.gt("_id", lastEndDocId)) - .projection(new Document("_id", 1)) - .sort(new Document("_id", 1)) - .limit(1); + startFilter = Filters.gt("_id", lastEndDocId); } else { - startIterable = col.find() - .projection(new Document("_id", 1)) - .sort(new Document("_id", 1)) - .limit(1); + startFilter = new Document(); } + while (!Thread.currentThread().isInterrupted()) { - try (final MongoCursor startCursor = startIterable.iterator()) { - if (!startCursor.hasNext()) { - LOG.info("No records to process or has reached end of the export partition."); - isLastBatch = true; - break; - } - final Document startDoc = startCursor.next(); - final Object gteValue = startDoc.get("_id"); - final String gteClassName = gteValue.getClass().getName(); - - // Get end doc - Document endDoc = startIterable.skip(partitionSize - 1).limit(1).first(); - if (endDoc == null) { - // this means we have reached the end of the doc - endDoc = col.find() - .projection(new Document("_id", 1)) - .sort(new Document("_id", -1)) - .limit(1) - .first(); - isLastBatch = true; - } + final Document startDoc = col.find(startFilter) + .projection(ID_PROJECTION) + .sort(ID_ASC) + .limit(1) + .first(); + + if (startDoc == null) { + LOG.info("No records to process or has reached end of the export partition."); + isLastBatch = true; + break; + } - final Object lteValue = endDoc.get("_id"); - final String lteClassName = lteValue.getClass().getName(); - endDocId = lteValue; - final String gteValueString = BsonHelper.getPartitionStringFromMongoDBId(gteValue, gteClassName); - final String lteValueString = BsonHelper.getPartitionStringFromMongoDBId(lteValue, lteClassName); - LOG.debug("Partition of {} : { gte: {} class: {}, lte: {} class {} }", collectionDbName, gteValueString, gteClassName, lteValueString, lteClassName); - collectionPartitions.add( - PartitionIdentifier - .builder() - 
.withPartitionKey(String.format(MONGODB_PARTITION_KEY_FORMAT, collectionDbName, gteValueString, lteValueString, gteClassName, lteClassName)) - .build()); - documentDBAggregateMetrics.getExportPartitionQueryCount().increment(); - - if (isLastBatch) { + final Object gteValue = startDoc.get("_id"); + final String gteClassName = gteValue.getClass().getName(); + + final Document endDoc = col.find(startFilter) /* reuse the start-document filter: a plain $gte on gteValue only matches _id values in the same BSON type bracket, which would cut a mixed-type collection short and force one oversized final partition */ + .projection(ID_PROJECTION) + .sort(ID_ASC) + .skip(partitionSize - 1) + .limit(1) + .first(); + + final Object lteValue; + final String lteClassName; + + if (endDoc == null) { + final Document lastDoc = col.find() + .projection(ID_PROJECTION) + .sort(ID_DESC) + .limit(1) + .first(); + if (lastDoc == null) { + isLastBatch = true; break; } + lteValue = lastDoc.get("_id"); + lteClassName = lteValue.getClass().getName(); + isLastBatch = true; + } else { + lteValue = endDoc.get("_id"); + lteClassName = lteValue.getClass().getName(); + } - // extend the ownership of the partition - enhancedSourceCoordinator.saveProgressStateForPartition(exportPartition, null); + endDocId = lteValue; + addPartition(collectionPartitions, collectionDbName, gteValue, gteClassName, lteValue, lteClassName); + documentDBAggregateMetrics.getExportPartitionQueryCount().increment(); - startIterable = col.find(buildGtQuery(lteValueString, lteClassName, MAX_KEY)) - .projection(new Document("_id", 1)) - .sort(new Document("_id", 1)) - .limit(1); + if (isLastBatch) { + break; } + + // extend the ownership of the partition + enhancedSourceCoordinator.saveProgressStateForPartition(exportPartition, null); + + startFilter = buildNextStartFilter(lteValue, lteClassName, uniformType); } } catch (final IllegalArgumentException | MongoClientException e) { // IllegalArgumentException is thrown when database or collection name is not valid // MongoClientException is thrown for exceptions indicating a failure condition with the MongoClient
documentDBAggregateMetrics.getExport4xxErrors().increment(); - LOG.error("Client side exception while build partitions.", e); + LOG.error("Client side exception while building partitions.", e); throw new RuntimeException(e); } catch (final Exception e) { documentDBAggregateMetrics.getExport5xxErrors().increment(); - LOG.error("Server side exception while build partitions.", e); + LOG.error("Server side exception while building partitions.", e); throw new RuntimeException(e); } diff --git a/data-prepper-plugins/mongodb/src/test/java/org/opensearch/dataprepper/plugins/mongo/export/MongoDBExportPartitionSupplierIsUniformIdTypeTest.java b/data-prepper-plugins/mongodb/src/test/java/org/opensearch/dataprepper/plugins/mongo/export/MongoDBExportPartitionSupplierIsUniformIdTypeTest.java new file mode 100644 index 0000000000..0063eae13d --- /dev/null +++ b/data-prepper-plugins/mongodb/src/test/java/org/opensearch/dataprepper/plugins/mongo/export/MongoDBExportPartitionSupplierIsUniformIdTypeTest.java @@ -0,0 +1,109 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ + +package org.opensearch.dataprepper.plugins.mongo.export; + +import com.mongodb.client.FindIterable; +import com.mongodb.client.MongoCollection; +import org.bson.Document; +import org.bson.types.Decimal128; +import org.bson.types.ObjectId; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.dataprepper.model.source.coordinator.enhanced.EnhancedSourceCoordinator; +import org.opensearch.dataprepper.plugins.mongo.configuration.MongoDBSourceConfig; +import org.opensearch.dataprepper.plugins.mongo.utils.DocumentDBSourceAggregateMetrics; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.is; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.when; + +@ExtendWith(MockitoExtension.class) +public class MongoDBExportPartitionSupplierIsUniformIdTypeTest { + + @Mock + private MongoDBSourceConfig sourceConfig; + @Mock + private EnhancedSourceCoordinator sourceCoordinator; + @Mock + private DocumentDBSourceAggregateMetrics aggregateMetrics; + @Mock + private MongoCollection collection; + @Mock + private FindIterable findIterable; + + private MongoDBExportPartitionSupplier supplier; + + @BeforeEach + void setUp() { + supplier = new MongoDBExportPartitionSupplier(sourceConfig, sourceCoordinator, aggregateMetrics); + when(collection.find()).thenReturn(findIterable); + when(findIterable.projection(any())).thenReturn(findIterable); + when(findIterable.sort(any())).thenReturn(findIterable); + when(findIterable.limit(1)).thenReturn(findIterable); + } + + @Test + void isUniformIdType_emptyCollection_returnsTrue() { + when(findIterable.first()).thenReturn(null); + assertThat(supplier.isUniformIdType(collection), is(true)); + } + + @Test + void isUniformIdType_uniformObjectId_returnsTrue() { + when(findIterable.first()) + 
.thenReturn(new Document("_id", new ObjectId())) + .thenReturn(new Document("_id", new ObjectId())); + assertThat(supplier.isUniformIdType(collection), is(true)); + } + + @Test + void isUniformIdType_uniformString_returnsTrue() { + when(findIterable.first()) + .thenReturn(new Document("_id", "abc")) + .thenReturn(new Document("_id", "xyz")); + assertThat(supplier.isUniformIdType(collection), is(true)); + } + + @Test + void isUniformIdType_mixedTypes_returnsFalse() { + when(findIterable.first()) + .thenReturn(new Document("_id", 1)) + .thenReturn(new Document("_id", new ObjectId())); + assertThat(supplier.isUniformIdType(collection), is(false)); + } + + @Test + void isUniformIdType_integerAndLong_returnsTrue() { + when(findIterable.first()) + .thenReturn(new Document("_id", 42)) + .thenReturn(new Document("_id", 999999999999L)); + assertThat(supplier.isUniformIdType(collection), is(true)); + } + + @Test + void isUniformIdType_doubleAndDecimal128_returnsTrue() { + when(findIterable.first()) + .thenReturn(new Document("_id", 3.14)) + .thenReturn(new Document("_id", Decimal128.parse("99.99"))); + assertThat(supplier.isUniformIdType(collection), is(true)); + } + + @Test + void isUniformIdType_stringAndObjectId_returnsFalse() { + when(findIterable.first()) + .thenReturn(new Document("_id", "abc")) + .thenReturn(new Document("_id", new ObjectId())); + assertThat(supplier.isUniformIdType(collection), is(false)); + } +} diff --git a/data-prepper-plugins/mongodb/src/test/java/org/opensearch/dataprepper/plugins/mongo/export/MongoDBExportPartitionSupplierTest.java b/data-prepper-plugins/mongodb/src/test/java/org/opensearch/dataprepper/plugins/mongo/export/MongoDBExportPartitionSupplierTest.java index 0329cf7b72..fac1888744 100644 --- a/data-prepper-plugins/mongodb/src/test/java/org/opensearch/dataprepper/plugins/mongo/export/MongoDBExportPartitionSupplierTest.java +++ 
b/data-prepper-plugins/mongodb/src/test/java/org/opensearch/dataprepper/plugins/mongo/export/MongoDBExportPartitionSupplierTest.java @@ -9,7 +9,6 @@ import com.mongodb.client.FindIterable; import com.mongodb.client.MongoClient; import com.mongodb.client.MongoCollection; -import com.mongodb.client.MongoCursor; import com.mongodb.client.MongoDatabase; import io.micrometer.core.instrument.Counter; import org.bson.Document; @@ -92,32 +91,49 @@ public void setup() { @Test public void test_buildPartitionsCollection() { try (MockedStatic mongoDBConnectionMockedStatic = mockStatic(MongoDBConnection.class)) { - // Given a collection with 5000 items which should be split to two partitions: 0-3999 and 4000-4999 + // Given a collection with 5000 items split into two partitions: 0-3999 and 4000-4999 MongoClient mongoClient = mock(MongoClient.class); MongoDatabase mongoDatabase = mock(MongoDatabase.class); MongoCollection col = mock(MongoCollection.class); - FindIterable findIterable = mock(FindIterable.class); - MongoCursor cursor = mock(MongoCursor.class); + + // Each col.find() / col.find(Bson) call returns a fresh FindIterable mock + FindIterable uniformCheckFirst = mock(FindIterable.class); + FindIterable uniformCheckLast = mock(FindIterable.class); + FindIterable startIterable1 = mock(FindIterable.class); + FindIterable endIterable1 = mock(FindIterable.class); + FindIterable startIterable2 = mock(FindIterable.class); + FindIterable endIterable2 = mock(FindIterable.class); + FindIterable lastDocIterable = mock(FindIterable.class); + mongoDBConnectionMockedStatic.when(() -> MongoDBConnection.getMongoClient(any(MongoDBSourceConfig.class))) .thenReturn(mongoClient); when(mongoClient.getDatabase(anyString())).thenReturn(mongoDatabase); when(mongoDatabase.getCollection(anyString())).thenReturn(col); - when(col.find()).thenReturn(findIterable); - when(col.find(any(Bson.class))).thenReturn(findIterable); - when(findIterable.projection(any())).thenReturn(findIterable); - 
when(findIterable.sort(any())).thenReturn(findIterable); - when(findIterable.skip(anyInt())).thenReturn(findIterable); - when(findIterable.limit(anyInt())).thenReturn(findIterable); - when(findIterable.iterator()).thenReturn(cursor); - when(cursor.hasNext()).thenReturn(true, true, false); - // mock startDoc and endDoc returns, 0-3999, and 4000-4999 - when(cursor.next()) - .thenReturn(new Document("_id", "0")) - .thenReturn(new Document("_id", "4000")); - when(findIterable.first()) - .thenReturn(new Document("_id", "3999")) - .thenReturn(null) - .thenReturn(new Document("_id", "4999")); + + // isUniformIdType: col.find() called twice (first asc, last desc) + // then col.find() for last doc when endDoc is null + when(col.find()).thenReturn(uniformCheckFirst, uniformCheckLast, lastDocIterable); + setupFindIterable(uniformCheckFirst, new Document("_id", "0")); + setupFindIterable(uniformCheckLast, new Document("_id", "4999")); + + // buildPartitions loop: + // 1st iteration: col.find(empty doc) for start, col.find(gte) for end + // 2nd iteration: col.find(gt) for start, col.find(gte) for end -> null + // Then isLastBatch=true breaks the loop + when(col.find(any(Bson.class))).thenReturn( + startIterable1, // 1st start doc + endIterable1, // 1st end doc (gte + skip) + startIterable2, // 2nd start doc (gt filter) + endIterable2 // 2nd end doc (gte + skip) -> null + ); + + setupFindIterable(startIterable1, new Document("_id", "0")); + setupFindIterableWithSkip(endIterable1, new Document("_id", "3999")); + setupFindIterable(startIterable2, new Document("_id", "4000")); + setupFindIterableWithSkip(endIterable2, null); + // last doc query (col.find() desc) when endDoc is null + setupFindIterable(lastDocIterable, new Document("_id", "4999")); + // When Apply Partition create logics final PartitionIdentifierBatch partitionIdentifierBatch = testSupplier.apply(exportPartition); assertThat(partitionIdentifierBatch.isLastBatch(), is(true)); @@ -160,4 +176,19 @@ public void 
test_buildPartitions_dbException() { verify(export5xxErrors, never()).increment(); } } -} \ No newline at end of file + + private void setupFindIterable(FindIterable iterable, Document result) { + when(iterable.projection(any())).thenReturn(iterable); + when(iterable.sort(any())).thenReturn(iterable); + when(iterable.limit(anyInt())).thenReturn(iterable); + when(iterable.first()).thenReturn(result); + } + + private void setupFindIterableWithSkip(FindIterable iterable, Document result) { + when(iterable.projection(any())).thenReturn(iterable); + when(iterable.sort(any())).thenReturn(iterable); + when(iterable.skip(anyInt())).thenReturn(iterable); + when(iterable.limit(anyInt())).thenReturn(iterable); + when(iterable.first()).thenReturn(result); + } +}