3535
3636public class BedrockBatchJobCreator extends AbstractBatchJobCreator {
3737 private final AwsCredentialsSupplier awsCredentialsSupplier ;
38- private static final long MAX_RETRY_WINDOW_MS = 300_000 ; // 5 minutes
3938 @ Getter
40- private final ConcurrentLinkedQueue <ThrottledRecord > throttledRecords = new ConcurrentLinkedQueue <>();
39+ private final ConcurrentLinkedQueue <RetryRecord > throttledRecords = new ConcurrentLinkedQueue <>();
4140 private final Lock processingLock ;
4241
4342 private static final String BEDROCK_PAYLOAD_TEMPLATE = "{\" parameters\" : {\" inputDataConfig\" : {\" s3InputDataConfig\" : {\" s3Uri\" : \" s3://\" }}," +
@@ -55,13 +54,13 @@ public void createMLBatchJob(List<Record<Event>> inputRecords, List<Record<Event
5554 }
5655
5756 private void processRecords (List <Record <Event >> records , List <Record <Event >> resultRecords ,
58- List <ThrottledRecord > throttledRecords ) {
57+ List <RetryRecord > throttledRecords ) {
5958 List <Record <Event >> failedRecords = new ArrayList <>();
6059 List <DlqObject > dlqObjects = new ArrayList <>();
6160
6261 for (int i = 0 ; i < records .size (); i ++) {
6362 Record <Event > record = records .get (i );
64- ThrottledRecord throttledRecord = throttledRecords != null ? throttledRecords .get (i ) : null ;
63+ RetryRecord throttledRecord = throttledRecords != null ? throttledRecords .get (i ) : null ;
6564
6665 processRecord (record , resultRecords , failedRecords , dlqObjects , throttledRecord );
6766
@@ -89,7 +88,7 @@ private void processRecords(List<Record<Event>> records, List<Record<Event>> res
8988
9089 private void processRecord (Record <Event > record , List <Record <Event >> resultRecords ,
9190 List <Record <Event >> failedRecords , List <DlqObject > dlqObjects ,
92- ThrottledRecord throttledRecord ) {
91+ RetryRecord throttledRecord ) {
9392 try {
9493 String s3Uri = generateS3Uri (record );
9594 String payload = createPayloadBedrock (s3Uri , mlProcessorConfig );
@@ -122,9 +121,9 @@ private void processRecord(Record<Event> record, List<Record<Event>> resultRecor
122121 if (e instanceof MLBatchJobException ) {
123122 MLBatchJobException mlException = (MLBatchJobException ) e ;
124123 statusCode = mlException .getStatusCode ();
125- if (statusCode == 429 ) {
126- ThrottledRecord newThrottledRecord = throttledRecord != null ?
127- throttledRecord : new ThrottledRecord (record );
124+ if (shouldRetry ( statusCode , mlException . getMessage ()) ) {
125+ RetryRecord newThrottledRecord = throttledRecord != null ?
126+ throttledRecord : new RetryRecord (record );
128127 throttledRecords .offer (newThrottledRecord );
129128 LOG .info ("Request {} throttled{}, added to retry queue: {}" ,
130129 throttledRecord != null ? "still" : "" ,
@@ -158,11 +157,29 @@ public void addProcessedBatchRecordsToResults(List<Record<Event>> resultRecords)
158157
159158 try {
160159 processThrottledRecords (resultRecords );
160+ } catch (Exception e ) {
161+ LOG .error ("Error in batch processing throttled records. Error: {}" , e .getMessage ());
161162 } finally {
162163 processingLock .unlock ();
163164 }
164165 }
165166
167+ private boolean shouldRetry (int statusCode , String errorMessage ) {
168+ if (statusCode == TOO_MANY_REQUESTS ) {
169+ return true ;
170+ }
171+
172+ if (errorMessage == null ) {
173+ return false ;
174+ }
175+
176+ // Check for quota-related messages
177+ return (statusCode == HttpURLConnection .HTTP_BAD_REQUEST ) &&
178+ (errorMessage .contains ("quota for number of concurrent invoke-model jobs" ) ||
179+ errorMessage .contains ("throttling" ) ||
180+ errorMessage .contains ("request was denied due to remote server throttling" ));
181+ }
182+
166183 private void handleFailure (Record <Event > record ,
167184 List <Record <Event >> resultRecords ,
168185 List <Record <Event >> failedRecords ,
@@ -203,11 +220,11 @@ private void pushToDlq(List<DlqObject> dlqObjects) {
203220 }
204221
205222 private void processThrottledRecords (List <Record <Event >> resultRecords ) {
206- List <ThrottledRecord > expiredRecords = new ArrayList <>();
207- List <ThrottledRecord > recordsToRetry = new ArrayList <>();
223+ List <RetryRecord > expiredRecords = new ArrayList <>();
224+ List <RetryRecord > recordsToRetry = new ArrayList <>();
208225
209226 // Process throttled records
210- ThrottledRecord throttledRecord ;
227+ RetryRecord throttledRecord ;
211228 while ((throttledRecord = throttledRecords .poll ()) != null ) {
212229 if (throttledRecord .isExpired ()) {
213230 expiredRecords .add (throttledRecord );
@@ -224,34 +241,34 @@ private void processThrottledRecords(List<Record<Event>> resultRecords) {
224241 retryThrottledRecords (recordsToRetry , resultRecords );
225242 }
226243
227- private void retryThrottledRecords (List <ThrottledRecord > recordsToRetry , List <Record <Event >> resultRecords ) {
244+ private void retryThrottledRecords (List <RetryRecord > recordsToRetry , List <Record <Event >> resultRecords ) {
228245 if (recordsToRetry .isEmpty ()) {
229246 return ;
230247 }
231248
232249 LOG .info ("Retrying {} throttled records" , recordsToRetry .size ());
233250 processRecords (
234251 recordsToRetry .stream ()
235- .map (ThrottledRecord ::getRecord )
252+ .map (RetryRecord ::getRecord )
236253 .collect (Collectors .toCollection (ArrayList ::new )),
237254 resultRecords ,
238255 recordsToRetry
239256 );
240257 }
241258
242- private void handleExpiredRecords (List <ThrottledRecord > expiredRecords , List <Record <Event >> resultRecords ) {
259+ private void handleExpiredRecords (List <RetryRecord > expiredRecords , List <Record <Event >> resultRecords ) {
243260 if (expiredRecords .isEmpty ()) {
244261 return ;
245262 }
246263
247264 List <Record <Event >> failedRecords = new ArrayList <>();
248265 List <DlqObject > dlqObjects = new ArrayList <>();
249266
250- for (ThrottledRecord expiredRecord : expiredRecords ) {
267+ for (RetryRecord expiredRecord : expiredRecords ) {
251268 String errorMessage = String .format (
252269 "Record expired after %d retries over %d minutes" ,
253270 expiredRecord .getRetryCount (),
254- MAX_RETRY_WINDOW_MS / 60000
271+ maxRetryTimeWindow / 60000
255272 );
256273
257274 LOG .error (NOISY , "Record expired from throttle queue: {}" , errorMessage );
@@ -302,32 +319,4 @@ private String createPayloadBedrock(String S3Uri, MLProcessorConfig mlProcessorC
302319 throw new RuntimeException ("Failed to create payload for BedRock batch job" , e );
303320 }
304321 }
305-
306- class ThrottledRecord {
307- private final Record <Event > record ;
308- private final long createdTime ;
309- private int retryCount ;
310-
311- ThrottledRecord (Record <Event > record ) {
312- this .record = record ;
313- this .createdTime = System .currentTimeMillis ();
314- this .retryCount = 0 ;
315- }
316-
317- boolean isExpired () {
318- return System .currentTimeMillis () - createdTime > MAX_RETRY_WINDOW_MS ;
319- }
320-
321- void incrementRetryCount () {
322- retryCount ++;
323- }
324-
325- Record <Event > getRecord () {
326- return record ;
327- }
328-
329- int getRetryCount () {
330- return retryCount ;
331- }
332- }
333322}
0 commit comments