Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ private static Bson buildQuery(final Function<Object, Bson> function, final Stri
}
}

private static boolean isClassNumber(final String className) {
public static boolean isClassNumber(final String className) {
return className.equals("java.lang.Integer") || className.equals("java.lang.Long") || className.equals("java.lang.Double")
|| className.equals("org.bson.types.Decimal128");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,14 @@
package org.opensearch.dataprepper.plugins.mongo.export;

import com.mongodb.MongoClientException;
import com.mongodb.client.FindIterable;
import com.mongodb.client.MongoClient;
import com.mongodb.client.MongoCollection;
import com.mongodb.client.MongoCursor;
import com.mongodb.client.MongoDatabase;
import com.mongodb.client.model.Filters;
import com.mongodb.client.model.Projections;
import com.mongodb.client.model.Sorts;
import org.bson.Document;
import org.bson.conversions.Bson;
import org.opensearch.dataprepper.model.source.coordinator.PartitionIdentifier;
import org.opensearch.dataprepper.model.source.coordinator.enhanced.EnhancedSourceCoordinator;
import org.opensearch.dataprepper.plugins.mongo.client.BsonHelper;
Expand All @@ -37,6 +38,9 @@ public class MongoDBExportPartitionSupplier implements Function<ExportPartition,
private static final Logger LOG = LoggerFactory.getLogger(MongoDBExportPartitionSupplier.class);
private static final String MONGODB_PARTITION_KEY_FORMAT = "%s|%s|%s|%s|%s"; // partition format: <db.collection>|<gte>|<lt>|<gteClassName>|<lteClassName>
private static final String COLLECTION_SPLITTER = "\\.";
private static final Bson ID_PROJECTION = Projections.include("_id");
private static final Bson ID_ASC = Sorts.ascending("_id");
private static final Bson ID_DESC = Sorts.descending("_id");

private final MongoDBSourceConfig sourceConfig;
private final EnhancedSourceCoordinator enhancedSourceCoordinator;
Expand All @@ -50,95 +54,143 @@ public MongoDBExportPartitionSupplier(final MongoDBSourceConfig sourceConfig,
this.documentDBAggregateMetrics = documentDBAggregateMetrics;
}

/**
* Detects whether the collection has a uniform _id type by checking the first and last documents.
* If uniform, we can use a simple Filters.gt() instead of the complex $or query across all BSON types.
*/
boolean isUniformIdType(final MongoCollection<Document> col) {
final Document first = col.find().projection(ID_PROJECTION).sort(ID_ASC).limit(1).first();
final Document last = col.find().projection(ID_PROJECTION).sort(ID_DESC).limit(1).first();
if (first == null || last == null) {
return true;
}
final String firstType = first.get("_id").getClass().getName();
final String lastType = last.get("_id").getClass().getName();
if (BsonHelper.isClassNumber(firstType) && BsonHelper.isClassNumber(lastType)) {
return true;
}
return firstType.equals(lastType);
}

private Bson buildNextStartFilter(final Object lastLteValue, final String lteClassName, final boolean uniformType) {
if (uniformType) {
return Filters.gt("_id", lastLteValue);
}
final String lteValueString = BsonHelper.getPartitionStringFromMongoDBId(lastLteValue, lteClassName);
return buildGtQuery(lteValueString, lteClassName, MAX_KEY);
}

private void addPartition(final List<PartitionIdentifier> partitions, final String collectionDbName,
final Object gteValue, final String gteClassName,
final Object lteValue, final String lteClassName) {
final String gteValueString = BsonHelper.getPartitionStringFromMongoDBId(gteValue, gteClassName);
final String lteValueString = BsonHelper.getPartitionStringFromMongoDBId(lteValue, lteClassName);
LOG.debug("Partition of {} : { gte: {} class: {}, lte: {} class {} }",
collectionDbName, gteValueString, gteClassName, lteValueString, lteClassName);
partitions.add(PartitionIdentifier.builder()
.withPartitionKey(String.format(MONGODB_PARTITION_KEY_FORMAT,
collectionDbName, gteValueString, lteValueString, gteClassName, lteClassName))
.build());
}

private PartitionIdentifierBatch buildPartitions(final ExportPartition exportPartition) {
documentDBAggregateMetrics.getExportApiInvocations().increment();
final List<PartitionIdentifier> collectionPartitions = new ArrayList<>();
final String collectionDbName = exportPartition.getCollection();
List<String> collection = List.of(collectionDbName.split(COLLECTION_SPLITTER));
final List<String> collection = List.of(collectionDbName.split(COLLECTION_SPLITTER));
if (collection.size() < 2) {
documentDBAggregateMetrics.getExport4xxErrors().increment();
throw new IllegalArgumentException("Invalid Collection Name. Must be in db.collection format");
}
final Optional<ExportProgressState> exportProgressStateOptional = exportPartition
.getProgressState();
final Object lastEndDocId = exportProgressStateOptional.map(
ExportProgressState::getLastEndDocId).orElse(null);

final Optional<ExportProgressState> exportProgressStateOptional = exportPartition.getProgressState();
final Object lastEndDocId = exportProgressStateOptional.map(ExportProgressState::getLastEndDocId).orElse(null);
boolean isLastBatch = false;
Object endDocId = lastEndDocId;

try (MongoClient mongoClient = MongoDBConnection.getMongoClient(sourceConfig)) {
final MongoDatabase db = mongoClient.getDatabase(collection.get(0));
final MongoCollection<Document> col = db.getCollection(collectionDbName.substring(collection.get(0).length()+1));
final MongoCollection<Document> col = db.getCollection(
collectionDbName.substring(collection.get(0).length() + 1));
final int partitionSize = exportPartition.getPartitionSize();
FindIterable<Document> startIterable;

final boolean uniformType = isUniformIdType(col);
LOG.info("Collection {} has {} _id type. Using {} partition query strategy.",
collectionDbName, uniformType ? "uniform" : "mixed", uniformType ? "simple $gt" : "$or-based");

Bson startFilter;
if (lastEndDocId != null) {
startIterable = col.find(Filters.gt("_id", lastEndDocId))
.projection(new Document("_id", 1))
.sort(new Document("_id", 1))
.limit(1);
startFilter = Filters.gt("_id", lastEndDocId);
} else {
startIterable = col.find()
.projection(new Document("_id", 1))
.sort(new Document("_id", 1))
.limit(1);
startFilter = new Document();
}

while (!Thread.currentThread().isInterrupted()) {
try (final MongoCursor<Document> startCursor = startIterable.iterator()) {
if (!startCursor.hasNext()) {
LOG.info("No records to process or has reached end of the export partition.");
isLastBatch = true;
break;
}
final Document startDoc = startCursor.next();
final Object gteValue = startDoc.get("_id");
final String gteClassName = gteValue.getClass().getName();

// Get end doc
Document endDoc = startIterable.skip(partitionSize - 1).limit(1).first();
if (endDoc == null) {
// this means we have reached the end of the doc
endDoc = col.find()
.projection(new Document("_id", 1))
.sort(new Document("_id", -1))
.limit(1)
.first();
isLastBatch = true;
}
final Document startDoc = col.find(startFilter)
.projection(ID_PROJECTION)
.sort(ID_ASC)
.limit(1)
.first();

if (startDoc == null) {
LOG.info("No records to process or has reached end of the export partition.");
isLastBatch = true;
break;
}

final Object lteValue = endDoc.get("_id");
final String lteClassName = lteValue.getClass().getName();
endDocId = lteValue;
final String gteValueString = BsonHelper.getPartitionStringFromMongoDBId(gteValue, gteClassName);
final String lteValueString = BsonHelper.getPartitionStringFromMongoDBId(lteValue, lteClassName);
LOG.debug("Partition of {} : { gte: {} class: {}, lte: {} class {} }", collectionDbName, gteValueString, gteClassName, lteValueString, lteClassName);
collectionPartitions.add(
PartitionIdentifier
.builder()
.withPartitionKey(String.format(MONGODB_PARTITION_KEY_FORMAT, collectionDbName, gteValueString, lteValueString, gteClassName, lteClassName))
.build());
documentDBAggregateMetrics.getExportPartitionQueryCount().increment();

if (isLastBatch) {
final Object gteValue = startDoc.get("_id");
final String gteClassName = gteValue.getClass().getName();

final Document endDoc = col.find(Filters.gte("_id", gteValue))
.projection(ID_PROJECTION)
.sort(ID_ASC)
.skip(partitionSize - 1)
.limit(1)
.first();

final Object lteValue;
final String lteClassName;

if (endDoc == null) {
final Document lastDoc = col.find()
.projection(ID_PROJECTION)
.sort(ID_DESC)
.limit(1)
.first();
if (lastDoc == null) {
isLastBatch = true;
break;
}
lteValue = lastDoc.get("_id");
lteClassName = lteValue.getClass().getName();
isLastBatch = true;
} else {
lteValue = endDoc.get("_id");
lteClassName = lteValue.getClass().getName();
}

// extend the ownership of the partition
enhancedSourceCoordinator.saveProgressStateForPartition(exportPartition, null);
endDocId = lteValue;
addPartition(collectionPartitions, collectionDbName, gteValue, gteClassName, lteValue, lteClassName);
documentDBAggregateMetrics.getExportPartitionQueryCount().increment();

startIterable = col.find(buildGtQuery(lteValueString, lteClassName, MAX_KEY))
.projection(new Document("_id", 1))
.sort(new Document("_id", 1))
.limit(1);
if (isLastBatch) {
break;
}

// extend the ownership of the partition
enhancedSourceCoordinator.saveProgressStateForPartition(exportPartition, null);

startFilter = buildNextStartFilter(lteValue, lteClassName, uniformType);
}
} catch (final IllegalArgumentException | MongoClientException e) {
// IllegalArgumentException is thrown when database or collection name is not valid
// MongoClientException is thrown for exceptions indicating a failure condition with the MongoClient
documentDBAggregateMetrics.getExport4xxErrors().increment();
LOG.error("Client side exception while build partitions.", e);
LOG.error("Client side exception while building partitions.", e);
throw new RuntimeException(e);
} catch (final Exception e) {
documentDBAggregateMetrics.getExport5xxErrors().increment();
LOG.error("Server side exception while build partitions.", e);
LOG.error("Server side exception while building partitions.", e);
throw new RuntimeException(e);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.dataprepper.plugins.mongo.export;

import com.mongodb.client.FindIterable;
import com.mongodb.client.MongoCollection;
import org.bson.Document;
import org.bson.types.Decimal128;
import org.bson.types.ObjectId;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.opensearch.dataprepper.model.source.coordinator.enhanced.EnhancedSourceCoordinator;
import org.opensearch.dataprepper.plugins.mongo.configuration.MongoDBSourceConfig;
import org.opensearch.dataprepper.plugins.mongo.utils.DocumentDBSourceAggregateMetrics;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.is;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.when;

@ExtendWith(MockitoExtension.class)
public class MongoDBExportPartitionSupplierIsUniformIdTypeTest {

@Mock
private MongoDBSourceConfig sourceConfig;
@Mock
private EnhancedSourceCoordinator sourceCoordinator;
@Mock
private DocumentDBSourceAggregateMetrics aggregateMetrics;
@Mock
private MongoCollection<Document> collection;
@Mock
private FindIterable<Document> findIterable;

private MongoDBExportPartitionSupplier supplier;

@BeforeEach
void setUp() {
supplier = new MongoDBExportPartitionSupplier(sourceConfig, sourceCoordinator, aggregateMetrics);
when(collection.find()).thenReturn(findIterable);
when(findIterable.projection(any())).thenReturn(findIterable);
when(findIterable.sort(any())).thenReturn(findIterable);
when(findIterable.limit(1)).thenReturn(findIterable);
}

@Test
void isUniformIdType_emptyCollection_returnsTrue() {
when(findIterable.first()).thenReturn(null);
assertThat(supplier.isUniformIdType(collection), is(true));
}

@Test
void isUniformIdType_uniformObjectId_returnsTrue() {
when(findIterable.first())
.thenReturn(new Document("_id", new ObjectId()))
.thenReturn(new Document("_id", new ObjectId()));
assertThat(supplier.isUniformIdType(collection), is(true));
}

@Test
void isUniformIdType_uniformString_returnsTrue() {
when(findIterable.first())
.thenReturn(new Document("_id", "abc"))
.thenReturn(new Document("_id", "xyz"));
assertThat(supplier.isUniformIdType(collection), is(true));
}

@Test
void isUniformIdType_mixedTypes_returnsFalse() {
when(findIterable.first())
.thenReturn(new Document("_id", 1))
.thenReturn(new Document("_id", new ObjectId()));
assertThat(supplier.isUniformIdType(collection), is(false));
}

@Test
void isUniformIdType_integerAndLong_returnsTrue() {
when(findIterable.first())
.thenReturn(new Document("_id", 42))
.thenReturn(new Document("_id", 999999999999L));
assertThat(supplier.isUniformIdType(collection), is(true));
}

@Test
void isUniformIdType_doubleAndDecimal128_returnsTrue() {
when(findIterable.first())
.thenReturn(new Document("_id", 3.14))
.thenReturn(new Document("_id", Decimal128.parse("99.99")));
assertThat(supplier.isUniformIdType(collection), is(true));
}

@Test
void isUniformIdType_stringAndObjectId_returnsFalse() {
when(findIterable.first())
.thenReturn(new Document("_id", "abc"))
.thenReturn(new Document("_id", new ObjectId()));
assertThat(supplier.isUniformIdType(collection), is(false));
}
}
Loading
Loading