diff --git a/kinesumer.go b/kinesumer.go index 934a431..bf2b5bc 100644 --- a/kinesumer.go +++ b/kinesumer.go @@ -311,11 +311,10 @@ func (k *Kinesumer) listShards(stream string) (Shards, error) { } var shards []*Shard for _, shard := range output.Shards { - // TODO(mingrammer): handle CLOSED shards. + // Closed shards will be handled while consuming the records. if shard.SequenceNumberRange.EndingSequenceNumber == nil { shards = append(shards, &Shard{ - ID: *shard.ShardId, - Closed: shard.SequenceNumberRange.EndingSequenceNumber != nil, + ID: *shard.ShardId, }) } } @@ -330,11 +329,10 @@ func (k *Kinesumer) listShards(stream string) (Shards, error) { return nil, errors.WithStack(err) } for _, shard := range output.Shards { - // Skip CLOSED shards. + // Closed shards will be handled while consuming the records. if shard.SequenceNumberRange.EndingSequenceNumber == nil { shards = append(shards, &Shard{ - ID: *shard.ShardId, - Closed: shard.SequenceNumberRange.EndingSequenceNumber != nil, + ID: *shard.ShardId, }) } } @@ -345,7 +343,8 @@ func (k *Kinesumer) listShards(stream string) (Shards, error) { // Consume consumes messages from Kinesis. func (k *Kinesumer) Consume( - streams []string) (<-chan *Record, error) { + streams []string, +) (<-chan *Record, error) { k.streams = streams ctx := context.Background() @@ -635,15 +634,8 @@ func (k *Kinesumer) consumeLoop(stream string, shard *Shard) { default: time.Sleep(k.scanInterval) records, closed := k.consumeOnce(stream, shard) - if closed { - k.cleanupOffsets(stream, shard) - return // Close consume loop if shard is CLOSED and has no data. - } n := len(records) - if n == 0 { - continue - } for i, record := range records { r := &Record{ @@ -656,6 +648,13 @@ func (k *Kinesumer) consumeLoop(stream string, shard *Shard) { if k.autoCommit && i == n-1 { k.MarkRecord(r) } + + // Closed shard may have remaining data, + // so clean up is executed, once the records are pushed to the channel. 
+ if closed { + k.cleanupOffsets(stream, shard) + return // Close consume loop if shard is CLOSED and has no data. + } } } } @@ -694,17 +693,24 @@ func (k *Kinesumer) consumeOnce(stream string, shard *Shard) ([]*kinesis.Record, } defer k.nextIters[stream].Store(shard.ID, output.NextShardIterator) // Update iter. - n := len(output.Records) - // We no longer care about shards that have no records left and are in the "CLOSED" state. - if n == 0 { - return nil, shard.Closed - } + shard.Closed = getShardStatus(output) + + // The caller's loop already copes with an empty record slice, + // so there is no need to check for it here. + return output.Records, shard.Closed +} - return output.Records, false +// getShardStatus returns whether the shard is closed or not. +func getShardStatus(output *kinesis.GetRecordsOutput) bool { + // A nil NextShardIterator is what marks the shard as closed. + // Once a closed shard has been read to its end, NextShardIterator will be nil. + // Reference: https://docs.aws.amazon.com/cli/latest/reference/kinesis/get-records.html#output + return output.NextShardIterator == nil } func (k *Kinesumer) getNextShardIterator( - ctx context.Context, stream, shardID string) (*string, error) { + ctx context.Context, stream, shardID string, +) (*string, error) { if iter, ok := k.nextIters[stream].Load(shardID); ok { return iter.(*string), nil } @@ -729,7 +735,7 @@ func (k *Kinesumer) getNextShardIterator( } func (k *Kinesumer) commitPeriodically() { - var checkPointTicker = time.NewTicker(k.commitInterval) + checkPointTicker := time.NewTicker(k.commitInterval) for { select { diff --git a/kinesumer_test.go b/kinesumer_test.go index 4c264ad..bc1a4a9 100644 --- a/kinesumer_test.go +++ b/kinesumer_test.go @@ -682,3 +682,43 @@ func TestKinesumer_cleanupOffsetsWorksFine(t *testing.T) { }) } } + +func Test_getShardStatus(t *testing.T) { + tests := []struct { + name string // description of this test case + input struct { + output *kinesis.GetRecordsOutput + } + want bool + }{ + { + name: "when shard is closed", 
+ input: struct { + output *kinesis.GetRecordsOutput + }{ + output: &kinesis.GetRecordsOutput{ + NextShardIterator: nil, + }, + }, + want: true, + }, + + { + name: "when shard is open", + input: struct { + output *kinesis.GetRecordsOutput + }{ + output: &kinesis.GetRecordsOutput{ + NextShardIterator: aws.String("shardIterator"), + }, + }, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := getShardStatus(tt.input.output) + assert.Equal(t, tt.want, got, "Should be equal") + }) + } +}