Skip to content

Commit a9563bc

Browse files
graytaylor0simonelbaz
authored andcommitted
Remove usage of buffer accumulator from Kafka custom consumer (opensearch-project#6357)
Signed-off-by: Taylor Gray <tylgry@amazon.com> Signed-off-by: Simon ELBAZ <elbazsimon9@gmail.com>
1 parent 70190c8 commit a9563bc

2 files changed

Lines changed: 30 additions & 25 deletions

File tree

data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/consumer/KafkaCustomConsumer.java

Lines changed: 14 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,12 @@
1717
import org.apache.kafka.clients.consumer.ConsumerRecords;
1818
import org.apache.kafka.clients.consumer.KafkaConsumer;
1919
import org.apache.kafka.clients.consumer.OffsetAndMetadata;
20-
import org.apache.kafka.common.errors.RebalanceInProgressException;
21-
import org.apache.kafka.common.header.Header;
22-
import org.apache.kafka.common.header.Headers;
2320
import org.apache.kafka.common.TopicPartition;
2421
import org.apache.kafka.common.errors.AuthenticationException;
22+
import org.apache.kafka.common.errors.RebalanceInProgressException;
2523
import org.apache.kafka.common.errors.RecordDeserializationException;
26-
import org.opensearch.dataprepper.buffer.common.BufferAccumulator;
24+
import org.apache.kafka.common.header.Header;
25+
import org.apache.kafka.common.header.Headers;
2726
import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSet;
2827
import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSetManager;
2928
import org.opensearch.dataprepper.model.buffer.Buffer;
@@ -70,8 +69,8 @@ public class KafkaCustomConsumer implements Runnable, ConsumerRebalanceListener
7069

7170
private static final Logger LOG = LoggerFactory.getLogger(KafkaCustomConsumer.class);
7271
private static final Long COMMIT_OFFSET_INTERVAL_MS = 300000L;
73-
private static final int DEFAULT_NUMBER_OF_RECORDS_TO_ACCUMULATE = 1;
7472
private static final int RETRY_ON_EXCEPTION_SLEEP_MS = 1000;
73+
private static final int BUFFER_WRITE_TIMEOUT = 2000;
7574
static final String DEFAULT_KEY = "message";
7675

7776
private volatile long lastCommitTime;
@@ -81,7 +80,6 @@ public class KafkaCustomConsumer implements Runnable, ConsumerRebalanceListener
8180
private final TopicConsumerConfig topicConfig;
8281
private MessageFormat schema;
8382
private boolean paused;
84-
private final BufferAccumulator<Record<Event>> bufferAccumulator;
8583
private final Buffer<Record<Event>> buffer;
8684
private static final ObjectMapper objectMapper = new ObjectMapper();
8785
private final JsonFactory jsonFactory = new JsonFactory();
@@ -123,7 +121,7 @@ public KafkaCustomConsumer(final KafkaConsumer consumer,
123121
this.paused = false;
124122
this.byteDecoder = byteDecoder;
125123
this.topicMetrics = topicMetrics;
126-
this.maxRetriesOnException = topicConfig.getMaxPollInterval().toMillis() / (2 * RETRY_ON_EXCEPTION_SLEEP_MS);
124+
this.maxRetriesOnException = topicConfig.getMaxPollInterval().toMillis() / (2 * (RETRY_ON_EXCEPTION_SLEEP_MS + BUFFER_WRITE_TIMEOUT));
127125
this.pauseConsumePredicate = pauseConsumePredicate;
128126
this.topicMetrics.register(consumer);
129127
this.offsetsToCommit = new HashMap<>();
@@ -137,8 +135,6 @@ public KafkaCustomConsumer(final KafkaConsumer consumer,
137135
this.partitionCommitTrackerMap = new HashMap<>();
138136
this.partitionsToReset = Collections.synchronizedSet(new HashSet<>());
139137
this.schema = MessageFormat.getByMessageFormatByName(schemaType);
140-
Duration bufferTimeout = Duration.ofSeconds(1);
141-
this.bufferAccumulator = BufferAccumulator.create(buffer, DEFAULT_NUMBER_OF_RECORDS_TO_ACCUMULATE, bufferTimeout);
142138
this.lastCommitTime = System.currentTimeMillis();
143139
this.numberOfAcksPending = new AtomicInteger(0);
144140
this.errLogRateLimiter = new LogRateLimiter(2, System.currentTimeMillis());
@@ -492,23 +488,19 @@ private <T> Record<Event> getRecord(ConsumerRecord<String, T> consumerRecord, in
492488
return new Record<Event>(event);
493489
}
494490

495-
private void processRecord(final AcknowledgementSet acknowledgementSet, final Record<Event> record) {
491+
private void processRecords(final AcknowledgementSet acknowledgementSet, final List<Record<Event>> eventRecords) {
496492
// Always add record to acknowledgementSet before adding to
497493
// buffer because another thread may take and process
498494
// buffer contents before the event record is added
499495
// to acknowledgement set
500496
if (acknowledgementSet != null) {
501-
acknowledgementSet.add(record.getData());
497+
eventRecords.forEach(record -> acknowledgementSet.add(record.getData()));
502498
}
503499
long numRetries = 0;
504500
while (true) {
505501
LOG.debug("In while loop for processing records, paused = {}", paused);
506502
try {
507-
if (numRetries == 0) {
508-
bufferAccumulator.add(record);
509-
} else {
510-
bufferAccumulator.flush();
511-
}
503+
buffer.writeAll(eventRecords, BUFFER_WRITE_TIMEOUT);
512504
break;
513505
} catch (Exception e) {
514506
if (!paused && numRetries++ > maxRetriesOnException) {
@@ -559,6 +551,7 @@ private <T> void iterateRecordPartitions(ConsumerRecords<String, T> records, fin
559551
}
560552

561553
List<ConsumerRecord<String, T>> partitionRecords = records.records(topicPartition);
554+
final List<Record<Event>> eventRecords = new ArrayList<>();
562555
for (ConsumerRecord<String, T> consumerRecord : partitionRecords) {
563556
if (schema == MessageFormat.BYTES) {
564557
InputStream byteInputStream = new ByteArrayInputStream((byte[])consumerRecord.value());
@@ -567,24 +560,24 @@ private <T> void iterateRecordPartitions(ConsumerRecords<String, T> records, fin
567560
if(byteDecoder != null) {
568561
final long receivedTimeStamp = getRecordTimeStamp(consumerRecord, Instant.now().toEpochMilli());
569562

570-
byteDecoder.parse(decompressedInputStream, Instant.ofEpochMilli(receivedTimeStamp), (record) -> {
571-
processRecord(acknowledgementSet, record);
572-
});
563+
byteDecoder.parse(decompressedInputStream, Instant.ofEpochMilli(receivedTimeStamp), eventRecords::add);
573564
} else {
574565
JsonNode jsonNode = objectMapper.readValue(decompressedInputStream, JsonNode.class);
575566

576567
Event event = JacksonLog.builder().withData(jsonNode).build();
577568
Record<Event> record = new Record<>(event);
578-
processRecord(acknowledgementSet, record);
569+
eventRecords.add(record);
579570
}
580571
} else {
581572
Record<Event> record = getRecord(consumerRecord, topicPartition.partition());
582573
if (record != null) {
583-
processRecord(acknowledgementSet, record);
574+
eventRecords.add(record);
584575
}
585576
}
586577
}
587578

579+
processRecords(acknowledgementSet, eventRecords);
580+
588581
long lastOffset = partitionRecords.get(partitionRecords.size() - 1).offset();
589582
long firstOffset = partitionRecords.get(0).offset();
590583
Range<Long> offsetRange = Range.between(firstOffset, lastOffset);

data-prepper-plugins/kafka-plugins/src/test/java/org/opensearch/dataprepper/plugins/kafka/consumer/KafkaCustomConsumerTest.java

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@
2222
import org.junit.jupiter.api.BeforeEach;
2323
import org.junit.jupiter.api.Test;
2424
import org.junit.jupiter.api.extension.ExtendWith;
25+
import org.junit.jupiter.params.ParameterizedTest;
26+
import org.junit.jupiter.params.provider.Arguments;
27+
import org.junit.jupiter.params.provider.MethodSource;
2528
import org.mockito.Mock;
2629
import org.mockito.junit.jupiter.MockitoExtension;
2730
import org.mockito.junit.jupiter.MockitoSettings;
@@ -52,7 +55,9 @@
5255
import java.util.Map;
5356
import java.util.concurrent.Executors;
5457
import java.util.concurrent.ScheduledExecutorService;
58+
import java.util.concurrent.TimeoutException;
5559
import java.util.concurrent.atomic.AtomicBoolean;
60+
import java.util.stream.Stream;
5661

5762
import static org.awaitility.Awaitility.await;
5863
import static org.hamcrest.CoreMatchers.equalTo;
@@ -62,8 +67,8 @@
6267
import static org.mockito.ArgumentMatchers.any;
6368
import static org.mockito.ArgumentMatchers.anyInt;
6469
import static org.mockito.ArgumentMatchers.anyMap;
65-
import static org.mockito.Mockito.doThrow;
6670
import static org.mockito.Mockito.doAnswer;
71+
import static org.mockito.Mockito.doThrow;
6772
import static org.mockito.Mockito.mock;
6873
import static org.mockito.Mockito.when;
6974

@@ -230,14 +235,15 @@ public void testGetRecordTimeStamp() {
230235
assertThat(consumer.getRecordTimeStamp(consumerRecord3, nowMs), equalTo(nowMs));
231236
}
232237

233-
@Test
234-
public void testBufferOverflowPauseResume() throws InterruptedException, Exception {
238+
@ParameterizedTest
239+
@MethodSource("provideExceptionsFromBufferWrite")
240+
public void testBufferOverflowPauseResume(final Exception bufferException) throws InterruptedException, Exception {
235241
when(topicConfig.getMaxPollInterval()).thenReturn(Duration.ofMillis(4000));
236242
String topic = topicConfig.getName();
237243
consumerRecords = createPlainTextRecords(topic, 0L);
238244
doAnswer((i)-> {
239245
if (!paused && !resumed)
240-
throw new SizeOverflowException("size overflow");
246+
throw bufferException;
241247
buffer.writeAll(i.getArgument(0), i.getArgument(1));
242248
return null;
243249
}).when(mockBuffer).writeAll(any(), anyInt());
@@ -690,6 +696,12 @@ private ConsumerRecords createJsonRecords(String topic) throws Exception {
690696
records.put(new TopicPartition(topic, testJsonPartition), Arrays.asList(record1, record2));
691697
return new ConsumerRecords(records);
692698
}
699+
700+
private static Stream<Arguments> provideExceptionsFromBufferWrite() {
701+
return Stream.of(
702+
Arguments.of(new SizeOverflowException("size overflow")),
703+
Arguments.of(new TimeoutException()));
704+
}
693705
}
694706

695707

0 commit comments

Comments
 (0)