4949import static org .junit .jupiter .api .Assertions .assertThrows ;
5050import static org .mockito .ArgumentMatchers .any ;
5151import static org .mockito .ArgumentMatchers .anyString ;
52+ import static org .mockito .ArgumentMatchers .eq ;
5253import static org .mockito .BDDMockito .given ;
5354import static org .mockito .Mockito .lenient ;
5455import static org .mockito .Mockito .mock ;
@@ -347,6 +348,9 @@ void test_run_shardConsumer_catches_4xx_exception_and_increments_metric() {
347348
348349 @ Test
349350 void test_run_shardConsumer_calls_startUpdatingOwnershipForShard () throws Exception {
351+ final AcknowledgementSet finalAcknowledgementSet = mock (AcknowledgementSet .class );
352+ when (shardAcknowledgementManager .createAcknowledgmentSet (any (StreamPartition .class ), any (String .class ), any (Boolean .class )))
353+ .thenReturn (finalAcknowledgementSet );
350354 try (final MockedStatic <BufferAccumulator > bufferAccumulatorMockedStatic = mockStatic (BufferAccumulator .class )) {
351355 bufferAccumulatorMockedStatic .when (() -> BufferAccumulator .create (buffer , DEFAULT_BUFFER_BATCH_SIZE , BUFFER_TIMEOUT )).thenReturn (bufferAccumulator );
352356 ShardConsumer shardConsumer = ShardConsumer .builder (dynamoDbStreamsClient , pluginMetrics , aggregateMetrics , buffer , streamConfig )
@@ -364,6 +368,94 @@ void test_run_shardConsumer_calls_startUpdatingOwnershipForShard() throws Except
364368 verify (shardAcknowledgementManager ).startUpdatingOwnershipForShard (streamPartition );
365369 }
366370
371+ @ Test
372+ void test_shard_has_records_null_iterator () throws Exception {
373+ final AcknowledgementSet finalAcknowledgementSet = mock (AcknowledgementSet .class );
374+ when (shardAcknowledgementManager .createAcknowledgmentSet (any (StreamPartition .class ), any (String .class ), any (Boolean .class )))
375+ .thenReturn (finalAcknowledgementSet );
376+
377+ // Set up response with null nextShardIterator to trigger end of shard
378+ GetRecordsResponse response = GetRecordsResponse .builder ()
379+ .records (buildRecords (1 ))
380+ .nextShardIterator (null )
381+ .build ();
382+ when (dynamoDbStreamsClient .getRecords (any (GetRecordsRequest .class ))).thenReturn (response );
383+
384+ try (MockedStatic <ShardConsumer > shardConsumerMockedStatic = mockStatic (ShardConsumer .class , invocation -> {
385+ if (invocation .getMethod ().getName ().equals ("stopAll" )) {
386+ return null ;
387+ } else if (invocation .getMethod ().getName ().equals ("shouldStop" )) {
388+ return false ;
389+ }
390+ return invocation .callRealMethod ();
391+ })) {
392+ ShardConsumer shardConsumer ;
393+ try (final MockedStatic <BufferAccumulator > bufferAccumulatorMockedStatic = mockStatic (BufferAccumulator .class )) {
394+ bufferAccumulatorMockedStatic .when (() -> BufferAccumulator .create (buffer , DEFAULT_BUFFER_BATCH_SIZE , BUFFER_TIMEOUT )).thenReturn (bufferAccumulator );
395+ shardConsumer = ShardConsumer .builder (dynamoDbStreamsClient , pluginMetrics , aggregateMetrics , buffer , streamConfig )
396+ .shardIterator (shardIterator )
397+ .shardAcknowledgementManager (shardAcknowledgementManager )
398+ .streamPartition (streamPartition )
399+ .tableInfo (tableInfo )
400+ .startTime (null )
401+ .waitForExport (false )
402+ .build ();
403+ }
404+
405+ shardConsumer .run ();
406+
407+ // Verify acknowledgment set created for records with shardIterator == null (true)
408+ verify (shardAcknowledgementManager ).createAcknowledgmentSet (eq (streamPartition ), any (String .class ), eq (true ));
409+ // Verify final acknowledgment set created and completed when shardIterator is null
410+ verify (finalAcknowledgementSet ).complete ();
411+
412+ }
413+ }
414+
415+
416+ @ Test
417+ void test_shard_has_no_records_null_iterator () throws Exception {
418+ final AcknowledgementSet finalAcknowledgementSet = mock (AcknowledgementSet .class );
419+ when (shardAcknowledgementManager .createAcknowledgmentSet (any (StreamPartition .class ), any (String .class ), any (Boolean .class )))
420+ .thenReturn (finalAcknowledgementSet );
421+
422+ // Set up response with null nextShardIterator to trigger end of shard
423+ GetRecordsResponse response = GetRecordsResponse .builder ()
424+ .records (List .of ())
425+ .nextShardIterator (null )
426+ .build ();
427+ when (dynamoDbStreamsClient .getRecords (any (GetRecordsRequest .class ))).thenReturn (response );
428+
429+ try (MockedStatic <ShardConsumer > shardConsumerMockedStatic = mockStatic (ShardConsumer .class , invocation -> {
430+ if (invocation .getMethod ().getName ().equals ("stopAll" )) {
431+ return null ;
432+ } else if (invocation .getMethod ().getName ().equals ("shouldStop" )) {
433+ return false ;
434+ }
435+ return invocation .callRealMethod ();
436+ })) {
437+ ShardConsumer shardConsumer ;
438+ try (final MockedStatic <BufferAccumulator > bufferAccumulatorMockedStatic = mockStatic (BufferAccumulator .class )) {
439+ bufferAccumulatorMockedStatic .when (() -> BufferAccumulator .create (buffer , DEFAULT_BUFFER_BATCH_SIZE , BUFFER_TIMEOUT )).thenReturn (bufferAccumulator );
440+ shardConsumer = ShardConsumer .builder (dynamoDbStreamsClient , pluginMetrics , aggregateMetrics , buffer , streamConfig )
441+ .shardIterator (shardIterator )
442+ .shardAcknowledgementManager (shardAcknowledgementManager )
443+ .streamPartition (streamPartition )
444+ .tableInfo (tableInfo )
445+ .startTime (null )
446+ .waitForExport (false )
447+ .build ();
448+ }
449+
450+ shardConsumer .run ();
451+
452+ // Verify acknowledgment set created for records with shardIterator == null (true)
453+ verify (shardAcknowledgementManager ).createAcknowledgmentSet (eq (streamPartition ), any (String .class ), eq (true ));
454+ // Verify final acknowledgment set created and completed when shardIterator is null
455+ verify (finalAcknowledgementSet ).complete ();
456+
457+ }
458+ }
367459 private List <Record > buildRecords (int count ) {
368460 List <Record > records = new ArrayList <>();
369461 for (int i = 0 ; i < count ; i ++) {
0 commit comments