Skip to content

Commit 13049d3

Browse files
author
Jonah Calvo
authored
Ensure shards are completed when last getRecords call has no records and no shardIterator (#5958)
Signed-off-by: Jonah Calvo <caljonah@amazon.com>
1 parent 15b32c7 commit 13049d3

2 files changed

Lines changed: 100 additions & 1 deletion

File tree

  • data-prepper-plugins/dynamodb-source/src

data-prepper-plugins/dynamodb-source/src/main/java/org/opensearch/dataprepper/plugins/source/dynamodb/stream/ShardConsumer.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -252,10 +252,14 @@ public void run() {
252252
String sequenceNumber = "";
253253
int interval;
254254
List<software.amazon.awssdk.services.dynamodb.model.Record> records;
255-
255+
boolean createdFinalAcknowledgmentSetForShard = false;
256256
while (!shouldStop) {
257257
if (shardIterator == null) {
258258
// End of Shard
259+
if (shardAcknowledgementManager != null && !createdFinalAcknowledgmentSetForShard) {
260+
final AcknowledgementSet finalAcknowledgmentSet = shardAcknowledgementManager.createAcknowledgmentSet(streamPartition, sequenceNumber, true);
261+
finalAcknowledgmentSet.complete();
262+
}
259263
LOG.debug("Reached end of shard");
260264
break;
261265
}
@@ -287,6 +291,9 @@ public void run() {
287291
AcknowledgementSet acknowledgementSet = null;
288292
if (shardAcknowledgementManager != null) {
289293
acknowledgementSet = shardAcknowledgementManager.createAcknowledgmentSet(streamPartition, sequenceNumber, shardIterator == null);
294+
if (shardIterator == null) {
295+
createdFinalAcknowledgmentSetForShard = true;
296+
}
290297
}
291298

292299
records = response.records().stream()

data-prepper-plugins/dynamodb-source/src/test/java/org/opensearch/dataprepper/plugins/source/dynamodb/stream/ShardConsumerTest.java

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
import static org.junit.jupiter.api.Assertions.assertThrows;
5050
import static org.mockito.ArgumentMatchers.any;
5151
import static org.mockito.ArgumentMatchers.anyString;
52+
import static org.mockito.ArgumentMatchers.eq;
5253
import static org.mockito.BDDMockito.given;
5354
import static org.mockito.Mockito.lenient;
5455
import static org.mockito.Mockito.mock;
@@ -347,6 +348,9 @@ void test_run_shardConsumer_catches_4xx_exception_and_increments_metric() {
347348

348349
@Test
349350
void test_run_shardConsumer_calls_startUpdatingOwnershipForShard() throws Exception {
351+
final AcknowledgementSet finalAcknowledgementSet = mock(AcknowledgementSet.class);
352+
when(shardAcknowledgementManager.createAcknowledgmentSet(any(StreamPartition.class), any(String.class), any(Boolean.class)))
353+
.thenReturn(finalAcknowledgementSet);
350354
try (final MockedStatic<BufferAccumulator> bufferAccumulatorMockedStatic = mockStatic(BufferAccumulator.class)) {
351355
bufferAccumulatorMockedStatic.when(() -> BufferAccumulator.create(buffer, DEFAULT_BUFFER_BATCH_SIZE, BUFFER_TIMEOUT)).thenReturn(bufferAccumulator);
352356
ShardConsumer shardConsumer = ShardConsumer.builder(dynamoDbStreamsClient, pluginMetrics, aggregateMetrics, buffer, streamConfig)
@@ -364,6 +368,94 @@ void test_run_shardConsumer_calls_startUpdatingOwnershipForShard() throws Except
364368
verify(shardAcknowledgementManager).startUpdatingOwnershipForShard(streamPartition);
365369
}
366370

371+
@Test
372+
void test_shard_has_records_null_iterator() throws Exception {
373+
final AcknowledgementSet finalAcknowledgementSet = mock(AcknowledgementSet.class);
374+
when(shardAcknowledgementManager.createAcknowledgmentSet(any(StreamPartition.class), any(String.class), any(Boolean.class)))
375+
.thenReturn(finalAcknowledgementSet);
376+
377+
// Set up response with null nextShardIterator to trigger end of shard
378+
GetRecordsResponse response = GetRecordsResponse.builder()
379+
.records(buildRecords(1))
380+
.nextShardIterator(null)
381+
.build();
382+
when(dynamoDbStreamsClient.getRecords(any(GetRecordsRequest.class))).thenReturn(response);
383+
384+
try (MockedStatic<ShardConsumer> shardConsumerMockedStatic = mockStatic(ShardConsumer.class, invocation -> {
385+
if (invocation.getMethod().getName().equals("stopAll")) {
386+
return null;
387+
} else if (invocation.getMethod().getName().equals("shouldStop")) {
388+
return false;
389+
}
390+
return invocation.callRealMethod();
391+
})) {
392+
ShardConsumer shardConsumer;
393+
try (final MockedStatic<BufferAccumulator> bufferAccumulatorMockedStatic = mockStatic(BufferAccumulator.class)) {
394+
bufferAccumulatorMockedStatic.when(() -> BufferAccumulator.create(buffer, DEFAULT_BUFFER_BATCH_SIZE, BUFFER_TIMEOUT)).thenReturn(bufferAccumulator);
395+
shardConsumer = ShardConsumer.builder(dynamoDbStreamsClient, pluginMetrics, aggregateMetrics, buffer, streamConfig)
396+
.shardIterator(shardIterator)
397+
.shardAcknowledgementManager(shardAcknowledgementManager)
398+
.streamPartition(streamPartition)
399+
.tableInfo(tableInfo)
400+
.startTime(null)
401+
.waitForExport(false)
402+
.build();
403+
}
404+
405+
shardConsumer.run();
406+
407+
// Verify acknowledgment set created for records with shardIterator == null (true)
408+
verify(shardAcknowledgementManager).createAcknowledgmentSet(eq(streamPartition), any(String.class), eq(true));
409+
// Verify final acknowledgment set created and completed when shardIterator is null
410+
verify(finalAcknowledgementSet).complete();
411+
412+
}
413+
}
414+
415+
416+
@Test
417+
void test_shard_has_no_records_null_iterator() throws Exception {
418+
final AcknowledgementSet finalAcknowledgementSet = mock(AcknowledgementSet.class);
419+
when(shardAcknowledgementManager.createAcknowledgmentSet(any(StreamPartition.class), any(String.class), any(Boolean.class)))
420+
.thenReturn(finalAcknowledgementSet);
421+
422+
// Set up response with null nextShardIterator to trigger end of shard
423+
GetRecordsResponse response = GetRecordsResponse.builder()
424+
.records(List.of())
425+
.nextShardIterator(null)
426+
.build();
427+
when(dynamoDbStreamsClient.getRecords(any(GetRecordsRequest.class))).thenReturn(response);
428+
429+
try (MockedStatic<ShardConsumer> shardConsumerMockedStatic = mockStatic(ShardConsumer.class, invocation -> {
430+
if (invocation.getMethod().getName().equals("stopAll")) {
431+
return null;
432+
} else if (invocation.getMethod().getName().equals("shouldStop")) {
433+
return false;
434+
}
435+
return invocation.callRealMethod();
436+
})) {
437+
ShardConsumer shardConsumer;
438+
try (final MockedStatic<BufferAccumulator> bufferAccumulatorMockedStatic = mockStatic(BufferAccumulator.class)) {
439+
bufferAccumulatorMockedStatic.when(() -> BufferAccumulator.create(buffer, DEFAULT_BUFFER_BATCH_SIZE, BUFFER_TIMEOUT)).thenReturn(bufferAccumulator);
440+
shardConsumer = ShardConsumer.builder(dynamoDbStreamsClient, pluginMetrics, aggregateMetrics, buffer, streamConfig)
441+
.shardIterator(shardIterator)
442+
.shardAcknowledgementManager(shardAcknowledgementManager)
443+
.streamPartition(streamPartition)
444+
.tableInfo(tableInfo)
445+
.startTime(null)
446+
.waitForExport(false)
447+
.build();
448+
}
449+
450+
shardConsumer.run();
451+
452+
// Verify acknowledgment set created for records with shardIterator == null (true)
453+
verify(shardAcknowledgementManager).createAcknowledgmentSet(eq(streamPartition), any(String.class), eq(true));
454+
// Verify final acknowledgment set created and completed when shardIterator is null
455+
verify(finalAcknowledgementSet).complete();
456+
457+
}
458+
}
367459
private List<Record> buildRecords(int count) {
368460
List<Record> records = new ArrayList<>();
369461
for (int i = 0; i < count; i++) {

0 commit comments

Comments
 (0)