Skip to content

Commit e1195c1

Browse files
authored
allow retry sagemaker batch job creation for longer time window (#6082)
Signed-off-by: Xun Zhang <xunzh@amazon.com>
1 parent 5a5cce7 commit e1195c1

6 files changed

Lines changed: 559 additions & 71 deletions

File tree

data-prepper-plugins/ml-inference-processor/src/main/java/org/opensearch/dataprepper/plugins/ml_inference/processor/MLProcessorConfig.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import org.opensearch.dataprepper.plugins.ml_inference.processor.configuration.AwsAuthenticationOptions;
2222
import org.opensearch.dataprepper.plugins.ml_inference.processor.configuration.ServiceName;
2323

24+
import java.time.Duration;
2425
import java.util.Collections;
2526
import java.util.List;
2627
import java.util.Map;
@@ -31,6 +32,7 @@
3132
"It supports both synchronous and asynchronous invocations based on your use case.")
3233
public class MLProcessorConfig {
3334
private static final int DEFAULT_MAX_BATCH_SIZE = 100;
35+
public static final Duration DEFAULT_RETRY_WINDOW = Duration.ofMinutes(10);
3436

3537
@JsonProperty("aws")
3638
@NotNull
@@ -82,6 +84,11 @@ public class MLProcessorConfig {
8284
@JsonProperty("max_batch_size")
8385
private int maxBatchSize = DEFAULT_MAX_BATCH_SIZE;
8486

87+
@JsonPropertyDescription("The time duration for which the ml_inference processor retains events for retry attempts."
88+
+ "Supports ISO_8601 notation Strings (\"PT20.345S\", \"PT15M\", etc.) as well as simple notation Strings for seconds (\"60s\") and milliseconds (\"1500ms\")")
89+
@JsonProperty("retry_time_window")
90+
private Duration retryTimeWindow = DEFAULT_RETRY_WINDOW;
91+
8592
@JsonProperty("dlq")
8693
private PluginModel dlq;
8794

data-prepper-plugins/ml-inference-processor/src/main/java/org/opensearch/dataprepper/plugins/ml_inference/processor/common/AbstractBatchJobCreator.java

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ public abstract class AbstractBatchJobCreator implements MLBatchJobCreator {
3131
public static final String NUMBER_OF_RECORDS_FAILED_IN_BATCH_JOB = "recordsFailedInBatchJobCreation";
3232
public static final String NUMBER_OF_RECORDS_SUCCEEDED_IN_BATCH_JOB = "recordsSucceededInBatchJobCreation";
3333
protected static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
34+
protected static final int TOO_MANY_REQUESTS = 429;
3435
protected final MLProcessorConfig mlProcessorConfig;
3536
protected final AwsCredentialsSupplier awsCredentialsSupplier;
3637
protected final Counter numberOfBatchJobsSuccessCounter;
@@ -40,7 +41,7 @@ public abstract class AbstractBatchJobCreator implements MLBatchJobCreator {
4041
protected final List<String> tagsOnFailure;
4142
protected final MlCommonRequester mlCommonRequester;
4243
protected DlqPushHandler dlqPushHandler = null;
43-
44+
protected final long maxRetryTimeWindow;
4445
private static final Aws4Signer signer;
4546
static {
4647
signer = Aws4Signer.create();
@@ -60,6 +61,7 @@ public AbstractBatchJobCreator(MLProcessorConfig mlProcessorConfig,
6061
this.tagsOnFailure = mlProcessorConfig.getTagsOnFailure();
6162
this.mlCommonRequester = new MlCommonRequester(signer, mlProcessorConfig, awsCredentialsSupplier);
6263
this.dlqPushHandler = dlqPushHandler;
64+
this.maxRetryTimeWindow = mlProcessorConfig.getRetryTimeWindow().toMillis();
6365
}
6466

6567
// Add common logic here that both subclasses can share
@@ -119,4 +121,32 @@ protected DlqObject createDlqObjectFromEvent(final Event event,
119121
.withPluginId(dlqPushHandler.getDlqPluginSetting().getName())
120122
.build();
121123
}
124+
125+
class RetryRecord {
126+
private final Record<Event> record;
127+
private final long createdTime;
128+
private int retryCount;
129+
130+
RetryRecord(Record<Event> record) {
131+
this.record = record;
132+
this.createdTime = System.currentTimeMillis();
133+
this.retryCount = 0;
134+
}
135+
136+
boolean isExpired() {
137+
return System.currentTimeMillis() - createdTime > maxRetryTimeWindow;
138+
}
139+
140+
void incrementRetryCount() {
141+
retryCount++;
142+
}
143+
144+
Record<Event> getRecord() {
145+
return record;
146+
}
147+
148+
int getRetryCount() {
149+
return retryCount;
150+
}
151+
}
122152
}

data-prepper-plugins/ml-inference-processor/src/main/java/org/opensearch/dataprepper/plugins/ml_inference/processor/common/BedrockBatchJobCreator.java

Lines changed: 33 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,8 @@
3535

3636
public class BedrockBatchJobCreator extends AbstractBatchJobCreator {
3737
private final AwsCredentialsSupplier awsCredentialsSupplier;
38-
private static final long MAX_RETRY_WINDOW_MS = 300_000; // 5 minutes
3938
@Getter
40-
private final ConcurrentLinkedQueue<ThrottledRecord> throttledRecords = new ConcurrentLinkedQueue<>();
39+
private final ConcurrentLinkedQueue<RetryRecord> throttledRecords = new ConcurrentLinkedQueue<>();
4140
private final Lock processingLock;
4241

4342
private static final String BEDROCK_PAYLOAD_TEMPLATE = "{\"parameters\": {\"inputDataConfig\": {\"s3InputDataConfig\": {\"s3Uri\": \"s3://\"}}," +
@@ -55,13 +54,13 @@ public void createMLBatchJob(List<Record<Event>> inputRecords, List<Record<Event
5554
}
5655

5756
private void processRecords(List<Record<Event>> records, List<Record<Event>> resultRecords,
58-
List<ThrottledRecord> throttledRecords) {
57+
List<RetryRecord> throttledRecords) {
5958
List<Record<Event>> failedRecords = new ArrayList<>();
6059
List<DlqObject> dlqObjects = new ArrayList<>();
6160

6261
for (int i = 0; i < records.size(); i++) {
6362
Record<Event> record = records.get(i);
64-
ThrottledRecord throttledRecord = throttledRecords != null ? throttledRecords.get(i) : null;
63+
RetryRecord throttledRecord = throttledRecords != null ? throttledRecords.get(i) : null;
6564

6665
processRecord(record, resultRecords, failedRecords, dlqObjects, throttledRecord);
6766

@@ -89,7 +88,7 @@ private void processRecords(List<Record<Event>> records, List<Record<Event>> res
8988

9089
private void processRecord(Record<Event> record, List<Record<Event>> resultRecords,
9190
List<Record<Event>> failedRecords, List<DlqObject> dlqObjects,
92-
ThrottledRecord throttledRecord) {
91+
RetryRecord throttledRecord) {
9392
try {
9493
String s3Uri = generateS3Uri(record);
9594
String payload = createPayloadBedrock(s3Uri, mlProcessorConfig);
@@ -122,9 +121,9 @@ private void processRecord(Record<Event> record, List<Record<Event>> resultRecor
122121
if (e instanceof MLBatchJobException) {
123122
MLBatchJobException mlException = (MLBatchJobException) e;
124123
statusCode = mlException.getStatusCode();
125-
if (statusCode == 429) {
126-
ThrottledRecord newThrottledRecord = throttledRecord != null ?
127-
throttledRecord : new ThrottledRecord(record);
124+
if (shouldRetry(statusCode, mlException.getMessage())) {
125+
RetryRecord newThrottledRecord = throttledRecord != null ?
126+
throttledRecord : new RetryRecord(record);
128127
throttledRecords.offer(newThrottledRecord);
129128
LOG.info("Request {} throttled{}, added to retry queue: {}",
130129
throttledRecord != null ? "still" : "",
@@ -158,11 +157,29 @@ public void addProcessedBatchRecordsToResults(List<Record<Event>> resultRecords)
158157

159158
try {
160159
processThrottledRecords(resultRecords);
160+
} catch (Exception e) {
161+
LOG.error("Error in batch processing throttled records. Error: {}", e.getMessage());
161162
} finally {
162163
processingLock.unlock();
163164
}
164165
}
165166

167+
private boolean shouldRetry(int statusCode, String errorMessage) {
168+
if (statusCode == TOO_MANY_REQUESTS) {
169+
return true;
170+
}
171+
172+
if (errorMessage == null) {
173+
return false;
174+
}
175+
176+
// Check for quota-related messages
177+
return (statusCode == HttpURLConnection.HTTP_BAD_REQUEST) &&
178+
(errorMessage.contains("quota for number of concurrent invoke-model jobs") ||
179+
errorMessage.contains("throttling") ||
180+
errorMessage.contains("request was denied due to remote server throttling"));
181+
}
182+
166183
private void handleFailure(Record<Event> record,
167184
List<Record<Event>> resultRecords,
168185
List<Record<Event>> failedRecords,
@@ -203,11 +220,11 @@ private void pushToDlq(List<DlqObject> dlqObjects) {
203220
}
204221

205222
private void processThrottledRecords(List<Record<Event>> resultRecords) {
206-
List<ThrottledRecord> expiredRecords = new ArrayList<>();
207-
List<ThrottledRecord> recordsToRetry = new ArrayList<>();
223+
List<RetryRecord> expiredRecords = new ArrayList<>();
224+
List<RetryRecord> recordsToRetry = new ArrayList<>();
208225

209226
// Process throttled records
210-
ThrottledRecord throttledRecord;
227+
RetryRecord throttledRecord;
211228
while ((throttledRecord = throttledRecords.poll()) != null) {
212229
if (throttledRecord.isExpired()) {
213230
expiredRecords.add(throttledRecord);
@@ -224,34 +241,34 @@ private void processThrottledRecords(List<Record<Event>> resultRecords) {
224241
retryThrottledRecords(recordsToRetry, resultRecords);
225242
}
226243

227-
private void retryThrottledRecords(List<ThrottledRecord> recordsToRetry, List<Record<Event>> resultRecords) {
244+
private void retryThrottledRecords(List<RetryRecord> recordsToRetry, List<Record<Event>> resultRecords) {
228245
if (recordsToRetry.isEmpty()) {
229246
return;
230247
}
231248

232249
LOG.info("Retrying {} throttled records", recordsToRetry.size());
233250
processRecords(
234251
recordsToRetry.stream()
235-
.map(ThrottledRecord::getRecord)
252+
.map(RetryRecord::getRecord)
236253
.collect(Collectors.toCollection(ArrayList::new)),
237254
resultRecords,
238255
recordsToRetry
239256
);
240257
}
241258

242-
private void handleExpiredRecords(List<ThrottledRecord> expiredRecords, List<Record<Event>> resultRecords) {
259+
private void handleExpiredRecords(List<RetryRecord> expiredRecords, List<Record<Event>> resultRecords) {
243260
if (expiredRecords.isEmpty()) {
244261
return;
245262
}
246263

247264
List<Record<Event>> failedRecords = new ArrayList<>();
248265
List<DlqObject> dlqObjects = new ArrayList<>();
249266

250-
for (ThrottledRecord expiredRecord : expiredRecords) {
267+
for (RetryRecord expiredRecord : expiredRecords) {
251268
String errorMessage = String.format(
252269
"Record expired after %d retries over %d minutes",
253270
expiredRecord.getRetryCount(),
254-
MAX_RETRY_WINDOW_MS / 60000
271+
maxRetryTimeWindow / 60000
255272
);
256273

257274
LOG.error(NOISY, "Record expired from throttle queue: {}", errorMessage);
@@ -302,32 +319,4 @@ private String createPayloadBedrock(String S3Uri, MLProcessorConfig mlProcessorC
302319
throw new RuntimeException("Failed to create payload for BedRock batch job", e);
303320
}
304321
}
305-
306-
class ThrottledRecord {
307-
private final Record<Event> record;
308-
private final long createdTime;
309-
private int retryCount;
310-
311-
ThrottledRecord(Record<Event> record) {
312-
this.record = record;
313-
this.createdTime = System.currentTimeMillis();
314-
this.retryCount = 0;
315-
}
316-
317-
boolean isExpired() {
318-
return System.currentTimeMillis() - createdTime > MAX_RETRY_WINDOW_MS;
319-
}
320-
321-
void incrementRetryCount() {
322-
retryCount++;
323-
}
324-
325-
Record<Event> getRecord() {
326-
return record;
327-
}
328-
329-
int getRetryCount() {
330-
return retryCount;
331-
}
332-
}
333322
}

0 commit comments

Comments
 (0)