Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,12 @@
import org.apache.kafka.clients.consumer.ConsumerRecords;
import org.apache.kafka.clients.consumer.KafkaConsumer;
import org.apache.kafka.clients.consumer.OffsetAndMetadata;
import org.apache.kafka.common.errors.RebalanceInProgressException;
import org.apache.kafka.common.header.Header;
import org.apache.kafka.common.header.Headers;
import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.common.errors.AuthenticationException;
import org.apache.kafka.common.errors.RebalanceInProgressException;
import org.apache.kafka.common.errors.RecordDeserializationException;
import org.opensearch.dataprepper.buffer.common.BufferAccumulator;
import org.apache.kafka.common.header.Header;
import org.apache.kafka.common.header.Headers;
import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSet;
import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSetManager;
import org.opensearch.dataprepper.model.buffer.Buffer;
Expand Down Expand Up @@ -70,8 +69,8 @@ public class KafkaCustomConsumer implements Runnable, ConsumerRebalanceListener

private static final Logger LOG = LoggerFactory.getLogger(KafkaCustomConsumer.class);
private static final Long COMMIT_OFFSET_INTERVAL_MS = 300000L;
private static final int DEFAULT_NUMBER_OF_RECORDS_TO_ACCUMULATE = 1;
private static final int RETRY_ON_EXCEPTION_SLEEP_MS = 1000;
private static final int BUFFER_WRITE_TIMEOUT = 2000;
static final String DEFAULT_KEY = "message";

private volatile long lastCommitTime;
Expand All @@ -81,7 +80,6 @@ public class KafkaCustomConsumer implements Runnable, ConsumerRebalanceListener
private final TopicConsumerConfig topicConfig;
private MessageFormat schema;
private boolean paused;
private final BufferAccumulator<Record<Event>> bufferAccumulator;
private final Buffer<Record<Event>> buffer;
private static final ObjectMapper objectMapper = new ObjectMapper();
private final JsonFactory jsonFactory = new JsonFactory();
Expand Down Expand Up @@ -123,7 +121,7 @@ public KafkaCustomConsumer(final KafkaConsumer consumer,
this.paused = false;
this.byteDecoder = byteDecoder;
this.topicMetrics = topicMetrics;
this.maxRetriesOnException = topicConfig.getMaxPollInterval().toMillis() / (2 * RETRY_ON_EXCEPTION_SLEEP_MS);
this.maxRetriesOnException = topicConfig.getMaxPollInterval().toMillis() / (2 * (RETRY_ON_EXCEPTION_SLEEP_MS + BUFFER_WRITE_TIMEOUT));
this.pauseConsumePredicate = pauseConsumePredicate;
this.topicMetrics.register(consumer);
this.offsetsToCommit = new HashMap<>();
Expand All @@ -137,8 +135,6 @@ public KafkaCustomConsumer(final KafkaConsumer consumer,
this.partitionCommitTrackerMap = new HashMap<>();
this.partitionsToReset = Collections.synchronizedSet(new HashSet<>());
this.schema = MessageFormat.getByMessageFormatByName(schemaType);
Duration bufferTimeout = Duration.ofSeconds(1);
this.bufferAccumulator = BufferAccumulator.create(buffer, DEFAULT_NUMBER_OF_RECORDS_TO_ACCUMULATE, bufferTimeout);
this.lastCommitTime = System.currentTimeMillis();
this.numberOfAcksPending = new AtomicInteger(0);
this.errLogRateLimiter = new LogRateLimiter(2, System.currentTimeMillis());
Expand Down Expand Up @@ -492,23 +488,19 @@ private <T> Record<Event> getRecord(ConsumerRecord<String, T> consumerRecord, in
return new Record<Event>(event);
}

private void processRecord(final AcknowledgementSet acknowledgementSet, final Record<Event> record) {
private void processRecords(final AcknowledgementSet acknowledgementSet, final List<Record<Event>> eventRecords) {
// Always add record to acknowledgementSet before adding to
// buffer because another thread may take and process
// buffer contents before the event record is added
// to acknowledgement set
if (acknowledgementSet != null) {
acknowledgementSet.add(record.getData());
eventRecords.forEach(record -> acknowledgementSet.add(record.getData()));
}
long numRetries = 0;
while (true) {
LOG.debug("In while loop for processing records, paused = {}", paused);
try {
if (numRetries == 0) {
bufferAccumulator.add(record);
} else {
bufferAccumulator.flush();
}
buffer.writeAll(eventRecords, BUFFER_WRITE_TIMEOUT);

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the "topicMetrics.getNumberOfBufferSizeOverflows().increment();" needs to be modified as well to keep track of how times we got this timeout exception. Maybe check for "SizeOverflowException" can be removed as well?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well we are still tracking SizeOverflowException only in the code below where we check the exception type. If you look at the buffer code you can see that we already have a metric for TimeoutException here (

). Under the hood the BufferAccumulator was just calling buffer.writeAll() so it's the same code path as far as metrics go

break;
} catch (Exception e) {
if (!paused && numRetries++ > maxRetriesOnException) {
Expand Down Expand Up @@ -559,6 +551,7 @@ private <T> void iterateRecordPartitions(ConsumerRecords<String, T> records, fin
}

List<ConsumerRecord<String, T>> partitionRecords = records.records(topicPartition);
final List<Record<Event>> eventRecords = new ArrayList<>();
for (ConsumerRecord<String, T> consumerRecord : partitionRecords) {
if (schema == MessageFormat.BYTES) {
InputStream byteInputStream = new ByteArrayInputStream((byte[])consumerRecord.value());
Expand All @@ -567,24 +560,24 @@ private <T> void iterateRecordPartitions(ConsumerRecords<String, T> records, fin
if(byteDecoder != null) {
final long receivedTimeStamp = getRecordTimeStamp(consumerRecord, Instant.now().toEpochMilli());

byteDecoder.parse(decompressedInputStream, Instant.ofEpochMilli(receivedTimeStamp), (record) -> {
processRecord(acknowledgementSet, record);
});
byteDecoder.parse(decompressedInputStream, Instant.ofEpochMilli(receivedTimeStamp), eventRecords::add);
} else {
JsonNode jsonNode = objectMapper.readValue(decompressedInputStream, JsonNode.class);

Event event = JacksonLog.builder().withData(jsonNode).build();
Record<Event> record = new Record<>(event);
processRecord(acknowledgementSet, record);
eventRecords.add(record);
}
} else {
Record<Event> record = getRecord(consumerRecord, topicPartition.partition());
if (record != null) {
processRecord(acknowledgementSet, record);
eventRecords.add(record);
}
}
}

processRecords(acknowledgementSet, eventRecords);

long lastOffset = partitionRecords.get(partitionRecords.size() - 1).offset();
long firstOffset = partitionRecords.get(0).offset();
Range<Long> offsetRange = Range.between(firstOffset, lastOffset);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.mockito.junit.jupiter.MockitoSettings;
Expand Down Expand Up @@ -52,7 +55,9 @@
import java.util.Map;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.Stream;

import static org.awaitility.Awaitility.await;
import static org.hamcrest.CoreMatchers.equalTo;
Expand All @@ -62,8 +67,8 @@
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.ArgumentMatchers.anyMap;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

Expand Down Expand Up @@ -230,14 +235,15 @@ public void testGetRecordTimeStamp() {
assertThat(consumer.getRecordTimeStamp(consumerRecord3, nowMs), equalTo(nowMs));
}

@Test
public void testBufferOverflowPauseResume() throws InterruptedException, Exception {
@ParameterizedTest
@MethodSource("provideExceptionsFromBufferWrite")
public void testBufferOverflowPauseResume(final Exception bufferException) throws InterruptedException, Exception {
when(topicConfig.getMaxPollInterval()).thenReturn(Duration.ofMillis(4000));
String topic = topicConfig.getName();
consumerRecords = createPlainTextRecords(topic, 0L);
doAnswer((i)-> {
if (!paused && !resumed)
throw new SizeOverflowException("size overflow");
throw bufferException;
buffer.writeAll(i.getArgument(0), i.getArgument(1));
return null;
}).when(mockBuffer).writeAll(any(), anyInt());
Expand Down Expand Up @@ -690,6 +696,12 @@ private ConsumerRecords createJsonRecords(String topic) throws Exception {
records.put(new TopicPartition(topic, testJsonPartition), Arrays.asList(record1, record2));
return new ConsumerRecords(records);
}

private static Stream<Arguments> provideExceptionsFromBufferWrite() {
return Stream.of(
Arguments.of(new SizeOverflowException("size overflow")),
Arguments.of(new TimeoutException()));
}
}


Loading