Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ public DynamoDBService(final EnhancedSourceCoordinator coordinator,
s3Client = clientFactory.buildS3Client();

// A shard manager is responsible to retrieve the shard information from streams.
shardManager = new ShardManager(dynamoDbStreamsClient, dynamoDBSourceAggregateMetrics);
shardManager = new ShardManager(dynamoDbStreamsClient, dynamoDBSourceAggregateMetrics, pluginMetrics);
tableConfigs = sourceConfig.getTableConfigs();
executor = Executors.newFixedThreadPool(4);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package org.opensearch.dataprepper.plugins.source.dynamodb.leader;

import io.micrometer.core.instrument.DistributionSummary;
import org.opensearch.dataprepper.metrics.PluginMetrics;
import org.opensearch.dataprepper.plugins.source.dynamodb.utils.DynamoDBSourceAggregateMetrics;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand All @@ -23,6 +25,7 @@
public class ShardManager {

private static final Logger LOG = LoggerFactory.getLogger(ShardManager.class);
static final String TOTAL_OPEN_SHARDS = "totalOpenShards";

/**
* Max number of shards to return in the DescribeStream API call, maximum 100.
Expand All @@ -47,14 +50,17 @@ public class ShardManager {

private final DynamoDbStreamsClient streamsClient;
private final DynamoDBSourceAggregateMetrics dynamoDBSourceAggregateMetrics;
private final DistributionSummary totalOpenShardCountDistributionSummary;


public ShardManager(final DynamoDbStreamsClient streamsClient,
final DynamoDBSourceAggregateMetrics dynamoDBSourceAggregateMetrics) {
final DynamoDBSourceAggregateMetrics dynamoDBSourceAggregateMetrics,
final PluginMetrics pluginMetrics) {
this.streamsClient = streamsClient;
this.dynamoDBSourceAggregateMetrics = dynamoDBSourceAggregateMetrics;
streamMap = new HashMap<>();
endingSequenceNumberMap = new HashMap<>();
this.totalOpenShardCountDistributionSummary = pluginMetrics.summary(TOTAL_OPEN_SHARDS);
}

/**
Expand Down Expand Up @@ -100,12 +106,16 @@ public List<Shard> runDiscovery(String streamArn) {
});

if (streamInfo.getLastEvaluatedShardId() == null) {
endingSequenceNumberMap = shards.stream()
final List<Shard> closedShards = shards.stream()
.filter(shard -> shard.sequenceNumberRange().endingSequenceNumber() != null)
.collect(Collectors.toList());
endingSequenceNumberMap = closedShards.stream()
.collect(Collectors.toMap(
shard -> shard.shardId(),
shard -> shard.sequenceNumberRange().endingSequenceNumber()
));

totalOpenShardCountDistributionSummary.record(shards.size() - closedShards.size());
}
LOG.debug("New last evaluated shard ID is " + shards.get(shards.size() - 1).shardId());
streamInfo.setLastEvaluatedShardId(shards.get(shards.size() - 1).shardId());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -406,9 +406,9 @@ private boolean shouldSkip() {
if (lastShardIterator != null && !lastShardIterator.isEmpty()) {
GetRecordsResponse response = callGetRecords(lastShardIterator);
if (response.records().isEmpty()) {
// Empty shard
LOG.info("LastShardIterator is provided, but there is no Last Event Time, skip processing");
return true;
// There is no guarantee that the shard is empty just because there is no record at the endingSequenceNumber
LOG.info("LastShardIterator is provided, but there is no Last Event Time, paginating through for documents");
return false;
}

Instant lastEventTime = response.records().get(response.records().size() - 1).dynamodb().approximateCreationDateTime();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@
package org.opensearch.dataprepper.plugins.source.dynamodb.leader;

import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.DistributionSummary;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.opensearch.dataprepper.metrics.PluginMetrics;
import org.opensearch.dataprepper.plugins.source.dynamodb.utils.DynamoDBSourceAggregateMetrics;
import software.amazon.awssdk.services.dynamodb.model.DescribeStreamRequest;
import software.amazon.awssdk.services.dynamodb.model.DescribeStreamResponse;
Expand All @@ -33,6 +35,7 @@
import static org.mockito.Mockito.lenient;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.opensearch.dataprepper.plugins.source.dynamodb.leader.ShardManager.TOTAL_OPEN_SHARDS;


@ExtendWith(MockitoExtension.class)
Expand All @@ -52,6 +55,12 @@ class ShardManagerTest {
@Mock
private Counter streamApiInvocations;

@Mock
private PluginMetrics pluginMetrics;

@Mock
private DistributionSummary distributionSummary;

private ShardManager shardManager;


Expand Down Expand Up @@ -98,7 +107,9 @@ void setup() {
.build();

lenient().when(dynamoDbStreamsClient.describeStream(any(DescribeStreamRequest.class))).thenReturn(response);
shardManager = new ShardManager(dynamoDbStreamsClient, dynamoDBSourceAggregateMetrics);

when(pluginMetrics.summary(TOTAL_OPEN_SHARDS)).thenReturn(distributionSummary);
shardManager = new ShardManager(dynamoDbStreamsClient, dynamoDBSourceAggregateMetrics, pluginMetrics);

when(dynamoDBSourceAggregateMetrics.getStreamApiInvocations()).thenReturn(streamApiInvocations);
}
Expand All @@ -109,6 +120,8 @@ void test_getChildShardIds_should_return_child_shards() {
assertThat(childShards, notNullValue());
assertThat(childShards.size(), equalTo(6));

verify(distributionSummary).record(2);

List<String> childShardIds1 = shardManager.findChildShardIds(streamArn, "shardId-001");
assertThat(childShardIds1, notNullValue());
assertThat(childShardIds1.size(), equalTo(1));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.MockedStatic;
import org.mockito.junit.jupiter.MockitoExtension;
Expand Down Expand Up @@ -46,6 +47,9 @@
import java.util.Random;
import java.util.UUID;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.notNullValue;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
Expand Down Expand Up @@ -480,6 +484,59 @@ void test_shard_has_no_records_null_iterator() throws Exception {

}
}

@Test
void test_shard_has_no_records_null_with_last_shard_iterator_paginates_through_shard() throws Exception {
final AcknowledgementSet finalAcknowledgementSet = mock(AcknowledgementSet.class);
when(shardAcknowledgementManager.createAcknowledgmentSet(any(StreamPartition.class), any(String.class), any(Boolean.class)))
.thenReturn(finalAcknowledgementSet);

final String lastShardIterator = UUID.randomUUID().toString();

// Set up response with null nextShardIterator to trigger end of shard
GetRecordsResponse response = GetRecordsResponse.builder()
.records(List.of())
.nextShardIterator(null)
.build();
final ArgumentCaptor<GetRecordsRequest> getRecordsRequest = ArgumentCaptor.forClass(GetRecordsRequest.class);
when(dynamoDbStreamsClient.getRecords(getRecordsRequest.capture())).thenReturn(response);

try (MockedStatic<ShardConsumer> shardConsumerMockedStatic = mockStatic(ShardConsumer.class, invocation -> {
if (invocation.getMethod().getName().equals("stopAll")) {
return null;
} else if (invocation.getMethod().getName().equals("shouldStop")) {
return false;
}
return invocation.callRealMethod();
})) {
ShardConsumer shardConsumer;
try (final MockedStatic<BufferAccumulator> bufferAccumulatorMockedStatic = mockStatic(BufferAccumulator.class)) {
bufferAccumulatorMockedStatic.when(() -> BufferAccumulator.create(buffer, DEFAULT_BUFFER_BATCH_SIZE, BUFFER_TIMEOUT)).thenReturn(bufferAccumulator);
shardConsumer = ShardConsumer.builder(dynamoDbStreamsClient, pluginMetrics, aggregateMetrics, buffer, streamConfig)
.shardIterator(shardIterator)
.lastShardIterator(lastShardIterator)
.shardAcknowledgementManager(shardAcknowledgementManager)
.streamPartition(streamPartition)
.tableInfo(tableInfo)
.startTime(null)
.waitForExport(false)
.build();
}

shardConsumer.run();

// Verify acknowledgment set created for records with shardIterator == null (true)
verify(shardAcknowledgementManager).createAcknowledgmentSet(eq(streamPartition), eq(END_OF_SHARD), eq(true));
// Verify final acknowledgment set created and completed when shardIterator is null
verify(finalAcknowledgementSet).complete();

final List<GetRecordsRequest> requestWithLastShardIterator = getRecordsRequest.getAllValues();
assertThat(requestWithLastShardIterator, notNullValue());
assertThat(requestWithLastShardIterator.size(), equalTo(2));
assertThat(requestWithLastShardIterator.get(0).shardIterator(), equalTo(lastShardIterator));
assertThat(requestWithLastShardIterator.get(1).shardIterator(), equalTo(shardIterator));
}
}
private List<Record> buildRecords(int count) {
List<Record> records = new ArrayList<>();
for (int i = 0; i < count; i++) {
Expand Down
Loading