diff --git a/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/consumer/KafkaCustomConsumer.java b/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/consumer/KafkaCustomConsumer.java index 7324abed11..f51c5cec2a 100644 --- a/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/consumer/KafkaCustomConsumer.java +++ b/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/consumer/KafkaCustomConsumer.java @@ -17,6 +17,7 @@ import org.apache.kafka.clients.consumer.ConsumerRecords; import org.apache.kafka.clients.consumer.KafkaConsumer; import org.apache.kafka.clients.consumer.OffsetAndMetadata; +import org.apache.kafka.common.errors.RebalanceInProgressException; import org.apache.kafka.common.header.Header; import org.apache.kafka.common.header.Headers; import org.apache.kafka.common.TopicPartition; @@ -354,11 +355,15 @@ private void commitOffsets(boolean forceCommit) { offsetsToCommit.forEach(((partition, offset) -> updateCommitCountMetric(partition, offset))); try { consumer.commitSync(offsetsToCommit); + lastCommitTime = currentTimeMillis; + } catch (final RebalanceInProgressException ex) { + LOG.error("Failed to commit offsets in topic {} due to rebalance in progress", topicName, ex); + return; } catch (Exception e) { LOG.error("Failed to commit offsets in topic {}", topicName, e); } + offsetsToCommit.clear(); - lastCommitTime = currentTimeMillis; } } diff --git a/data-prepper-plugins/kafka-plugins/src/test/java/org/opensearch/dataprepper/plugins/kafka/consumer/KafkaCustomConsumerTest.java b/data-prepper-plugins/kafka-plugins/src/test/java/org/opensearch/dataprepper/plugins/kafka/consumer/KafkaCustomConsumerTest.java index 1bbc60ecdb..702fd26849 100644 --- a/data-prepper-plugins/kafka-plugins/src/test/java/org/opensearch/dataprepper/plugins/kafka/consumer/KafkaCustomConsumerTest.java +++ 
b/data-prepper-plugins/kafka-plugins/src/test/java/org/opensearch/dataprepper/plugins/kafka/consumer/KafkaCustomConsumerTest.java
@@ -16,6 +16,7 @@
 import org.apache.kafka.clients.consumer.KafkaConsumer;
 import org.apache.kafka.clients.consumer.OffsetAndMetadata;
 import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.errors.RebalanceInProgressException;
 import org.apache.kafka.common.errors.RecordDeserializationException;
 import org.junit.jupiter.api.Assertions;
 import org.junit.jupiter.api.BeforeEach;
@@ -60,6 +61,8 @@
 import static org.junit.jupiter.api.Assertions.assertTrue;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.anyInt;
+import static org.mockito.ArgumentMatchers.anyMap;
+import static org.mockito.Mockito.doThrow;
 import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.when;
@@ -589,6 +592,103 @@ public void testAwsGlueErrorWithAcknowledgements() throws Exception {
         });
     }
 
+    /**
+     * A commit rejected with {@link RebalanceInProgressException} must leave the pending
+     * offsets in place so that they can be committed again later.
+     */
+    @Test
+    public void testCommitOffsets_RebalanceInProgressException_DoesNotClearOffsets() throws Exception {
+        final String topic = topicConfig.getName();
+        final TopicPartition topicPartition = new TopicPartition(topic, testPartition);
+
+        when(topicConfig.getSerdeFormat()).thenReturn(MessageFormat.PLAINTEXT);
+        when(topicConfig.getAutoCommit()).thenReturn(false);
+        when(topicConfig.getCommitInterval()).thenReturn(Duration.ofMillis(0));
+
+        consumer = createObjectUnderTest("plaintext", false);
+        consumer.onPartitionsAssigned(List.of(topicPartition));
+
+        consumerRecords = createPlainTextRecords(topic, 100L);
+        when(kafkaConsumer.poll(any(Duration.class))).thenReturn(consumerRecords);
+
+        doThrow(new RebalanceInProgressException("Rebalance in progress"))
+                .when(kafkaConsumer).commitSync(anyMap());
+
+        consumer.consumeRecords();
+
+        // Snapshot the pending offsets so the comparison below is against a stable copy.
+        final Map<TopicPartition, OffsetAndMetadata> offsetsBeforeCommit = new HashMap<>(consumer.getOffsetsToCommit());
+        Assertions.assertFalse(offsetsBeforeCommit.isEmpty(), "Offsets should be populated after consuming records");
+        Assertions.assertEquals(102L, offsetsBeforeCommit.get(topicPartition).offset());
+
+        invokeForcedCommitOnSeparateThread();
+
+        final Map<TopicPartition, OffsetAndMetadata> offsetsAfterFailedCommit = consumer.getOffsetsToCommit();
+        Assertions.assertFalse(offsetsAfterFailedCommit.isEmpty(),
+                "Offsets should NOT be cleared after RebalanceInProgressException");
+        Assertions.assertEquals(offsetsBeforeCommit.get(topicPartition).offset(),
+                offsetsAfterFailedCommit.get(topicPartition).offset(),
+                "Offset value should remain unchanged for retry after rebalance completes");
+    }
+
+    /**
+     * A commit that fails with any other exception still clears the pending offsets.
+     */
+    @Test
+    public void testCommitOffsets_OtherException_ClearsOffsets() throws Exception {
+        final String topic = topicConfig.getName();
+        final TopicPartition topicPartition = new TopicPartition(topic, testPartition);
+
+        when(topicConfig.getAutoCommit()).thenReturn(false);
+        when(topicConfig.getCommitInterval()).thenReturn(Duration.ofMillis(0));
+
+        consumer = createObjectUnderTest("plaintext", false);
+        consumer.onPartitionsAssigned(List.of(topicPartition));
+
+        consumerRecords = createPlainTextRecords(topic, 100L);
+        when(kafkaConsumer.poll(any(Duration.class))).thenReturn(consumerRecords);
+
+        doThrow(new RuntimeException("Generic commit failure"))
+                .when(kafkaConsumer).commitSync(anyMap());
+
+        consumer.consumeRecords();
+
+        Assertions.assertFalse(consumer.getOffsetsToCommit().isEmpty(),
+                "Offsets should be populated after consuming records");
+
+        invokeForcedCommitOnSeparateThread();
+
+        final Map<TopicPartition, OffsetAndMetadata> offsetsAfterFailedCommit = consumer.getOffsetsToCommit();
+        Assertions.assertTrue(offsetsAfterFailedCommit.isEmpty(),
+                "Offsets should be cleared after non-rebalance exception");
+    }
+
+    /**
+     * Reflectively calls the private {@code commitOffsets(true)} on {@link #consumer} from a
+     * dedicated thread and waits up to five seconds for that thread to finish.
+     * Fails the calling test if the thread is still running afterwards or if the reflective
+     * call threw, so that neither case can let an assertion on the offsets pass vacuously.
+     */
+    private void invokeForcedCommitOnSeparateThread() throws InterruptedException {
+        final java.util.concurrent.atomic.AtomicReference<Throwable> failure =
+                new java.util.concurrent.atomic.AtomicReference<>();
+        final Thread commitThread = new Thread(() -> {
+            try {
+                final java.lang.reflect.Method method =
+                        consumer.getClass().getDeclaredMethod("commitOffsets", boolean.class);
+                method.setAccessible(true);
+                method.invoke(consumer, true);
+            } catch (final Exception e) {
+                failure.set(e);
+            }
+        });
+        commitThread.start();
+        commitThread.join(5000);
+
+        Assertions.assertFalse(commitThread.isAlive(), "commitOffsets did not finish within 5 seconds");
+        Assertions.assertNull(failure.get(), "Reflective call to commitOffsets failed: " + failure.get());
+    }
+
     private ConsumerRecords createPlainTextRecords(String topic, final long startOffset) {
         Map> records = new HashMap<>();
         ConsumerRecord record1 = new ConsumerRecord<>(topic, testPartition, startOffset, testKey1, testValue1);