Skip to content

Commit 7cb72ca

Browse files
authored
fix: mitigate STS assume role throttling in Kafka buffer (#6634)
Prevent excessive STS AssumeRole calls when customers delete their IAM role or misconfigure the trust policy. Previously, one pipeline could generate 12,000 STS calls in 4 minutes due to unbounded retries of non-retryable AccessDeniedException errors. Changes: - KafkaSecurityConfigurer: Fail fast on STS 403 (AccessDenied) in getBootStrapServersForMsk() instead of retrying 360 times - KafkaSecurityConfigurer: Replace fixed 10s retry sleep with exponential backoff (10s to 10min max) for retryable STS and Kafka errors - KafkaCustomConsumer: Replace fixed 10s retry with exponential backoff using Kafka's ExponentialBackoff (10s to 10min max) for AuthenticationException errors - KafkaCustomConsumer: Use Duration constants for backoff readability - KafkaCustomConsumer: Reset backoff counter on successful poll to handle transient errors gracefully Add exponential backoff to outer run() exception handler Signed-off-by: Dinu John <86094133+dinujoh@users.noreply.github.com> * chore: remove unused imports from KafkaSecurityConfigurerTest Signed-off-by: Dinu John <86094133+dinujoh@users.noreply.github.com> * refactor: address review comments - use Kafka ExponentialBackoff, Duration constants, remove silent shutdown Signed-off-by: Dinu John <86094133+dinujoh@users.noreply.github.com> --------- Signed-off-by: Dinu John <86094133+dinujoh@users.noreply.github.com>
1 parent 3c95779 commit 7cb72ca

4 files changed

Lines changed: 209 additions & 8 deletions

File tree

data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/consumer/KafkaCustomConsumer.java

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import org.apache.kafka.common.TopicPartition;
2626
import org.apache.kafka.common.errors.AuthenticationException;
2727
import org.apache.kafka.common.errors.RebalanceInProgressException;
28+
import org.apache.kafka.common.utils.ExponentialBackoff;
2829
import org.apache.kafka.common.errors.RecordDeserializationException;
2930
import org.apache.kafka.common.header.Header;
3031
import org.apache.kafka.common.header.Headers;
@@ -75,6 +76,8 @@ public class KafkaCustomConsumer implements Runnable, ConsumerRebalanceListener
7576
private static final Logger LOG = LoggerFactory.getLogger(KafkaCustomConsumer.class);
7677
private static final Long COMMIT_OFFSET_INTERVAL_MS = 300000L;
7778
private static final int RETRY_ON_EXCEPTION_SLEEP_MS = 1000;
79+
static final Duration INITIAL_BACKOFF = Duration.ofSeconds(10);
80+
static final Duration MAX_BACKOFF = Duration.ofMinutes(10);
7881
private static final int BUFFER_WRITE_TIMEOUT = 2000;
7982
static final String DEFAULT_KEY = "message";
8083

@@ -104,6 +107,8 @@ public class KafkaCustomConsumer implements Runnable, ConsumerRebalanceListener
104107
private final LogRateLimiter errLogRateLimiter;
105108
private final ByteDecoder byteDecoder;
106109
private final long maxRetriesOnException;
110+
private final ExponentialBackoff exponentialBackoff;
111+
private long authFailureAttempts;
107112
private final Map<Integer, Long> partitionToLastReceivedTimestampMillis;
108113
private final CompressionOption compressionConfig;
109114
private final boolean invokeCallbackOnExpiry;
@@ -146,6 +151,8 @@ public KafkaCustomConsumer(final KafkaConsumer consumer,
146151
this.lastCommitTime = System.currentTimeMillis();
147152
this.numberOfAcksPending = new AtomicInteger(0);
148153
this.errLogRateLimiter = new LogRateLimiter(2, System.currentTimeMillis());
154+
this.exponentialBackoff = new ExponentialBackoff(INITIAL_BACKOFF.toMillis(), 2, MAX_BACKOFF.toMillis(), 0);
155+
this.authFailureAttempts = 0;
149156
this.compressionConfig = (compressionConfig == null) ? CompressionOption.NONE : compressionConfig;
150157
}
151158

@@ -227,6 +234,7 @@ <T> ConsumerRecords<String, T> doPoll() throws Exception {
227234
<T> void consumeRecords() throws Exception {
228235
try {
229236
ConsumerRecords<String, T> records = doPoll();
237+
resetAuthBackoff();
230238
LOG.debug("Consumed records with count {}", records.count());
231239
if (Objects.nonNull(records) && !records.isEmpty() && records.count() > 0) {
232240
Map<TopicPartition, CommitOffsetRange> offsets = new HashMap<>();
@@ -246,9 +254,13 @@ <T> void consumeRecords() throws Exception {
246254
}
247255
}
248256
} catch (AuthenticationException e) {
249-
LOG.warn("Authentication error while doing poll(). Will retry after 10 seconds", e);
257+
authFailureAttempts++;
258+
long backoffMs = exponentialBackoff.backoff(authFailureAttempts - 1);
250259
topicMetrics.getNumberOfPollAuthErrors().increment();
251-
Thread.sleep(10000);
260+
LOG.warn("Authentication error while doing poll() for topic {} (failure count: {}). " +
261+
"Will retry after {} ms. Verify that the IAM role exists and trust policy is correct.",
262+
topicName, authFailureAttempts, backoffMs, e);
263+
sleepMillis(backoffMs);
252264
} catch (RecordDeserializationException e) {
253265
LOG.warn("Deserialization error - topic {} partition {} offset {}. Error message: {}",
254266
e.topicPartition().topic(), e.topicPartition().partition(), e.offset(), e.getMessage());
@@ -272,6 +284,15 @@ <T> void consumeRecords() throws Exception {
272284
}
273285
}
274286

287+
private void resetAuthBackoff() {
288+
authFailureAttempts = 0;
289+
}
290+
291+
@VisibleForTesting
292+
void sleepMillis(long millis) throws InterruptedException {
293+
Thread.sleep(millis);
294+
}
295+
275296
private void addAcknowledgedOffsets(final TopicPartition topicPartition, final Range<Long> offsetRange) {
276297
final int partitionId = topicPartition.partition();
277298
final TopicPartitionCommitTracker commitTracker = partitionCommitTrackerMap.get(partitionId);
@@ -394,12 +415,14 @@ public void run() {
394415
});
395416

396417
boolean retryingAfterException = false;
418+
long outerRetryAttempts = 0;
397419
while (!shutdownInProgress.get()) {
398420
LOG.debug("Still running Kafka consumer in start of loop");
399421
try {
400422
if (retryingAfterException) {
401-
LOG.debug("Pause consuming from Kafka topic due a previous exception.");
402-
Thread.sleep(10000);
423+
long backoffMs = exponentialBackoff.backoff(outerRetryAttempts - 1);
424+
LOG.debug("Pause consuming from Kafka topic due a previous exception. Backoff: {} ms", backoffMs);
425+
Thread.sleep(backoffMs);
403426
} else if (pauseConsumePredicate.pauseConsuming()) {
404427
LOG.debug("Pause and skip consuming from Kafka topic due to an external condition: {}", pauseConsumePredicate);
405428
paused = true;
@@ -421,9 +444,12 @@ public void run() {
421444
topicMetrics.update(consumer);
422445
LOG.debug("Updated consumer metrics");
423446
retryingAfterException = false;
447+
outerRetryAttempts = 0;
424448
} catch (Exception exp) {
425-
LOG.error("Error while reading the records from the topic {}. Retry after 10 seconds", topicName, exp);
449+
long backoffMs = exponentialBackoff.backoff(outerRetryAttempts);
450+
LOG.error("Error while reading the records from the topic {}. Retry after {} ms", topicName, backoffMs, exp);
426451
retryingAfterException = true;
452+
outerRetryAttempts++;
427453
}
428454
}
429455
LOG.info("Shutting down, number of acks pending = {}", numberOfAcksPending.get());

data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/util/KafkaSecurityConfigurer.java

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@
2323
import org.opensearch.dataprepper.plugins.kafka.configuration.ScramAuthConfig;
2424
import org.apache.kafka.clients.consumer.ConsumerConfig;
2525

26+
import org.apache.kafka.common.utils.ExponentialBackoff;
27+
import java.time.Duration;
28+
2629
import software.amazon.awssdk.services.kafka.KafkaClient;
2730
import software.amazon.awssdk.services.kafka.model.GetBootstrapBrokersRequest;
2831
import software.amazon.awssdk.services.kafka.model.GetBootstrapBrokersResponse;
@@ -287,17 +290,33 @@ public static String getBootStrapServersForMsk(final AwsConfig awsConfig,
287290
.clusterArn(awsMskConfig.getArn())
288291
.build();
289292

293+
final ExponentialBackoff backoff = new ExponentialBackoff(
294+
Duration.ofSeconds(10).toMillis(), 2, Duration.ofMinutes(10).toMillis(), 0);
290295
int numRetries = 0;
291296
boolean retryable;
292297
GetBootstrapBrokersResponse result = null;
293298
do {
294299
retryable = false;
295300
try {
296301
result = kafkaClient.getBootstrapBrokers(request);
297-
} catch (KafkaException | StsException e) {
298-
log.info("Failed to get bootstrap server information from MSK. Will try every 10 seconds for {} seconds", 10*MAX_KAFKA_CLIENT_RETRIES, e);
302+
} catch (StsException e) {
303+
if (e.statusCode() == 403) {
304+
throw new RuntimeException("Access denied when calling STS to get bootstrap server information from MSK. " +
305+
"Verify that the role exists and the trust policy is correctly configured.", e);
306+
}
307+
long backoffMs = backoff.backoff(numRetries);
308+
log.info("Failed to get bootstrap server information from MSK due to STS error. Retrying after {} ms (attempt {}/{})",
309+
backoffMs, numRetries + 1, MAX_KAFKA_CLIENT_RETRIES, e);
310+
try {
311+
Thread.sleep(backoffMs);
312+
} catch (InterruptedException exp) {}
313+
retryable = true;
314+
} catch (KafkaException e) {
315+
long backoffMs = backoff.backoff(numRetries);
316+
log.info("Failed to get bootstrap server information from MSK due to Kafka error. Retrying after {} ms (attempt {}/{})",
317+
backoffMs, numRetries + 1, MAX_KAFKA_CLIENT_RETRIES, e);
299318
try {
300-
Thread.sleep(10000);
319+
Thread.sleep(backoffMs);
301320
} catch (InterruptedException exp) {}
302321
retryable = true;
303322
} catch (Exception e) {

data-prepper-plugins/kafka-plugins/src/test/java/org/opensearch/dataprepper/plugins/kafka/consumer/KafkaCustomConsumerTest.java

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import org.apache.kafka.clients.consumer.OffsetAndMetadata;
2222
import org.apache.kafka.common.TopicPartition;
2323
import org.apache.kafka.common.errors.RebalanceInProgressException;
24+
import org.apache.kafka.common.errors.AuthenticationException;
2425
import org.apache.kafka.common.errors.RecordDeserializationException;
2526
import org.junit.jupiter.api.Assertions;
2627
import org.junit.jupiter.api.BeforeEach;
@@ -32,6 +33,7 @@
3233
import org.mockito.Mock;
3334
import org.mockito.junit.jupiter.MockitoExtension;
3435
import org.mockito.junit.jupiter.MockitoSettings;
36+
import org.mockito.ArgumentCaptor;
3537
import org.mockito.quality.Strictness;
3638
import org.opensearch.dataprepper.core.acknowledgements.DefaultAcknowledgementSetManager;
3739
import org.opensearch.dataprepper.model.CheckpointState;
@@ -65,15 +67,20 @@
6567

6668
import static org.awaitility.Awaitility.await;
6769
import static org.hamcrest.CoreMatchers.equalTo;
70+
import static org.hamcrest.CoreMatchers.is;
6871
import static org.hamcrest.MatcherAssert.assertThat;
6972
import static org.hamcrest.Matchers.hasEntry;
7073
import static org.junit.jupiter.api.Assertions.assertTrue;
7174
import static org.mockito.ArgumentMatchers.any;
7275
import static org.mockito.ArgumentMatchers.anyInt;
76+
import static org.mockito.ArgumentMatchers.anyLong;
7377
import static org.mockito.ArgumentMatchers.anyMap;
7478
import static org.mockito.Mockito.doAnswer;
7579
import static org.mockito.Mockito.doThrow;
80+
import static org.mockito.Mockito.times;
7681
import static org.mockito.Mockito.mock;
82+
import static org.mockito.Mockito.spy;
83+
import static org.mockito.Mockito.doNothing;
7784
import static org.mockito.Mockito.verify;
7885
import static org.mockito.Mockito.when;
7986

@@ -159,6 +166,7 @@ public void setUp() throws JsonProcessingException {
159166
when(topicMetrics.getNumberOfRecordsCommitted()).thenReturn(counter);
160167
when(topicMetrics.getNumberOfDeserializationErrors()).thenReturn(counter);
161168
when(topicMetrics.getNumberOfInvalidTimeStamps()).thenReturn(counter);
169+
when(topicMetrics.getNumberOfPollAuthErrors()).thenReturn(counter);
162170
when(topicConfig.getThreadWaitingTime()).thenReturn(Duration.ofSeconds(1));
163171
when(topicConfig.getSerdeFormat()).thenReturn(MessageFormat.PLAINTEXT);
164172
when(topicConfig.getAutoCommit()).thenReturn(false);
@@ -694,6 +702,87 @@ public void testCommitOffsets_OtherException_ClearsOffsets() throws Exception {
694702
"Offsets should be cleared after non-rebalance exception");
695703
}
696704

705+
@Test
706+
public void testConsumeRecords_AuthenticationException_IncrementsCounterAndUsesBackoff() throws Exception {
707+
String topic = topicConfig.getName();
708+
Counter authErrorCounter = mock(Counter.class);
709+
when(topicMetrics.getNumberOfPollAuthErrors()).thenReturn(authErrorCounter);
710+
when(topicConfig.getMaxPollInterval()).thenReturn(Duration.ofMillis(60000));
711+
712+
AuthenticationException authException = new AuthenticationException("Auth failed");
713+
when(kafkaConsumer.poll(any(Duration.class))).thenThrow(authException);
714+
715+
consumer = spy(createObjectUnderTest("plaintext", false));
716+
doNothing().when(consumer).sleepMillis(anyLong());
717+
consumer.onPartitionsAssigned(List.of(new TopicPartition(topic, testPartition)));
718+
719+
consumer.consumeRecords();
720+
721+
verify(authErrorCounter, times(1)).increment();
722+
verify(consumer).sleepMillis(KafkaCustomConsumer.INITIAL_BACKOFF.toMillis());
723+
}
724+
725+
@Test
726+
public void testConsumeRecords_MultipleAuthFailures_UsesExponentialBackoff() throws Exception {
727+
String topic = topicConfig.getName();
728+
Counter authErrorCounter = mock(Counter.class);
729+
when(topicMetrics.getNumberOfPollAuthErrors()).thenReturn(authErrorCounter);
730+
when(topicConfig.getMaxPollInterval()).thenReturn(Duration.ofMillis(60000));
731+
732+
AuthenticationException authException = new AuthenticationException("Auth failed");
733+
when(kafkaConsumer.poll(any(Duration.class))).thenThrow(authException);
734+
735+
consumer = spy(createObjectUnderTest("plaintext", false));
736+
doNothing().when(consumer).sleepMillis(anyLong());
737+
consumer.onPartitionsAssigned(List.of(new TopicPartition(topic, testPartition)));
738+
739+
consumer.consumeRecords();
740+
consumer.consumeRecords();
741+
consumer.consumeRecords();
742+
743+
Assertions.assertFalse(shutdownInProgress.get(),
744+
"Consumer should not shut down on auth failures");
745+
ArgumentCaptor<Long> sleepCaptor = ArgumentCaptor.forClass(Long.class);
746+
verify(consumer, times(3)).sleepMillis(sleepCaptor.capture());
747+
List<Long> sleepValues = sleepCaptor.getAllValues();
748+
assertThat(sleepValues.get(0), is(KafkaCustomConsumer.INITIAL_BACKOFF.toMillis()));
749+
assertThat(sleepValues.get(1), is(KafkaCustomConsumer.INITIAL_BACKOFF.toMillis() * 2));
750+
assertThat(sleepValues.get(2), is(KafkaCustomConsumer.INITIAL_BACKOFF.toMillis() * 4));
751+
}
752+
753+
@Test
754+
public void testConsumeRecords_SuccessfulPollResetsAuthFailureCounter() throws Exception {
755+
String topic = topicConfig.getName();
756+
Counter authErrorCounter = mock(Counter.class);
757+
when(topicMetrics.getNumberOfPollAuthErrors()).thenReturn(authErrorCounter);
758+
when(topicConfig.getMaxPollInterval()).thenReturn(Duration.ofMillis(60000));
759+
760+
AuthenticationException authException = new AuthenticationException("Auth failed");
761+
consumerRecords = createPlainTextRecords(topic, 0L);
762+
763+
when(kafkaConsumer.poll(any(Duration.class)))
764+
.thenThrow(authException)
765+
.thenThrow(authException)
766+
.thenReturn(consumerRecords);
767+
768+
consumer = spy(createObjectUnderTest("plaintext", false));
769+
doNothing().when(consumer).sleepMillis(anyLong());
770+
consumer.onPartitionsAssigned(List.of(new TopicPartition(topic, testPartition)));
771+
772+
consumer.consumeRecords(); // auth failure 1
773+
consumer.consumeRecords(); // auth failure 2
774+
consumer.consumeRecords(); // success - resets backoff
775+
776+
// Next auth failure should use initial backoff again
777+
when(kafkaConsumer.poll(any(Duration.class))).thenThrow(authException);
778+
consumer.consumeRecords(); // auth failure after reset
779+
780+
ArgumentCaptor<Long> sleepCaptor = ArgumentCaptor.forClass(Long.class);
781+
verify(consumer, times(3)).sleepMillis(sleepCaptor.capture());
782+
List<Long> sleepValues = sleepCaptor.getAllValues();
783+
assertThat(sleepValues.get(2), is(KafkaCustomConsumer.INITIAL_BACKOFF.toMillis()));
784+
}
785+
697786
private ConsumerRecords createPlainTextRecords(String topic, final long startOffset) {
698787
Map<TopicPartition, List<ConsumerRecord>> records = new HashMap<>();
699788
ConsumerRecord<String, String> record1 = new ConsumerRecord<>(topic, testPartition, startOffset, testKey1, testValue1);

data-prepper-plugins/kafka-plugins/src/test/java/org/opensearch/dataprepper/plugins/kafka/util/KafkaSecurityConfigurerTest.java

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
import software.amazon.awssdk.services.kafka.model.GetBootstrapBrokersResponse;
3434
import software.amazon.awssdk.services.sts.auth.StsAssumeRoleCredentialsProvider;
3535
import software.amazon.awssdk.services.sts.model.AssumeRoleRequest;
36+
import software.amazon.awssdk.services.sts.model.StsException;
3637

3738
import java.io.FileReader;
3839
import java.io.IOException;
@@ -55,6 +56,7 @@
5556
import static org.mockito.ArgumentMatchers.any;
5657
import static org.mockito.Mockito.mock;
5758
import static org.mockito.Mockito.mockStatic;
59+
import static org.mockito.Mockito.times;
5860
import static org.mockito.Mockito.verify;
5961
import static org.mockito.Mockito.verifyNoInteractions;
6062
import static org.mockito.Mockito.when;
@@ -218,6 +220,71 @@ public void testSetAuthPropertiesBootstrapServersOverrideByMSK() throws IOExcept
218220
is("software.amazon.msk.auth.iam.IAMClientCallbackHandler"));
219221
}
220222

223+
@Test
224+
public void testGetBootStrapServersForMsk_StsException403_ThrowsImmediately() throws IOException {
225+
final Properties props = new Properties();
226+
final KafkaSourceConfig kafkaSourceConfig = createKafkaSinkConfig("kafka-pipeline-bootstrap-servers-override-by-msk.yaml");
227+
final KafkaClientBuilder kafkaClientBuilder = mock(KafkaClientBuilder.class);
228+
final KafkaClient kafkaClient = mock(KafkaClient.class);
229+
when(kafkaClientBuilder.credentialsProvider(any())).thenReturn(kafkaClientBuilder);
230+
when(kafkaClientBuilder.region(any(Region.class))).thenReturn(kafkaClientBuilder);
231+
when(kafkaClientBuilder.build()).thenReturn(kafkaClient);
232+
233+
final StsException stsException = (StsException) StsException.builder()
234+
.statusCode(403)
235+
.message("Access Denied")
236+
.build();
237+
238+
when(kafkaClient.getBootstrapBrokers(any(GetBootstrapBrokersRequest.class)))
239+
.thenThrow(stsException);
240+
241+
try (MockedStatic<KafkaClient> mockedKafkaClient = mockStatic(KafkaClient.class)) {
242+
mockedKafkaClient.when(KafkaClient::builder).thenReturn(kafkaClientBuilder);
243+
244+
RuntimeException thrown = org.junit.jupiter.api.Assertions.assertThrows(
245+
RuntimeException.class,
246+
() -> KafkaSecurityConfigurer.setAuthProperties(props, kafkaSourceConfig, LOG)
247+
);
248+
249+
assertThat(thrown.getMessage(), is("Access denied when calling STS to get bootstrap server information from MSK. " +
250+
"Verify that the role exists and the trust policy is correctly configured."));
251+
252+
verify(kafkaClient, times(1)).getBootstrapBrokers(any(GetBootstrapBrokersRequest.class));
253+
}
254+
}
255+
256+
@Test
257+
public void testGetBootStrapServersForMsk_StsExceptionNon403_Retries() throws IOException {
258+
final String testMSKEndpoint = "test-endpoint:9098";
259+
final Properties props = new Properties();
260+
final KafkaSourceConfig kafkaSourceConfig = createKafkaSinkConfig("kafka-pipeline-bootstrap-servers-override-by-msk.yaml");
261+
final KafkaClientBuilder kafkaClientBuilder = mock(KafkaClientBuilder.class);
262+
final KafkaClient kafkaClient = mock(KafkaClient.class);
263+
when(kafkaClientBuilder.credentialsProvider(any())).thenReturn(kafkaClientBuilder);
264+
when(kafkaClientBuilder.region(any(Region.class))).thenReturn(kafkaClientBuilder);
265+
when(kafkaClientBuilder.build()).thenReturn(kafkaClient);
266+
267+
final StsException stsException = (StsException) StsException.builder()
268+
.statusCode(500)
269+
.message("Internal Server Error")
270+
.build();
271+
272+
final GetBootstrapBrokersResponse response = mock(GetBootstrapBrokersResponse.class);
273+
when(response.bootstrapBrokerStringSaslIam()).thenReturn(testMSKEndpoint);
274+
275+
when(kafkaClient.getBootstrapBrokers(any(GetBootstrapBrokersRequest.class)))
276+
.thenThrow(stsException)
277+
.thenReturn(response);
278+
279+
try (MockedStatic<KafkaClient> mockedKafkaClient = mockStatic(KafkaClient.class)) {
280+
mockedKafkaClient.when(KafkaClient::builder).thenReturn(kafkaClientBuilder);
281+
KafkaSecurityConfigurer.setAuthProperties(props, kafkaSourceConfig, LOG);
282+
}
283+
284+
assertThat(props.getProperty("bootstrap.servers"), is(testMSKEndpoint));
285+
verify(kafkaClient, times(2)).getBootstrapBrokers(any(GetBootstrapBrokersRequest.class));
286+
}
287+
221288
@Test
222289
public void testSetAuthPropertiesMskWithSaslPlain() throws IOException {
223290
final String testMSKEndpoint = UUID.randomUUID().toString();

0 commit comments

Comments
 (0)