|
21 | 21 | import org.opensearch.dataprepper.plugins.ml_inference.processor.dlq.DlqPushHandler; |
22 | 22 | import org.opensearch.dataprepper.plugins.ml_inference.processor.exception.MLBatchJobException; |
23 | 23 |
|
| 24 | +import java.time.Duration; |
24 | 25 | import java.util.ArrayList; |
25 | 26 | import java.util.Arrays; |
26 | 27 | import java.util.List; |
|
36 | 37 | import static org.mockito.Mockito.times; |
37 | 38 | import static org.mockito.Mockito.verify; |
38 | 39 | import static org.mockito.Mockito.when; |
| 40 | +import static org.opensearch.dataprepper.plugins.ml_inference.processor.MLProcessorConfig.DEFAULT_RETRY_INTERVAL_SECONDS; |
39 | 41 | import static org.opensearch.dataprepper.plugins.ml_inference.processor.MLProcessorConfig.DEFAULT_RETRY_WINDOW; |
40 | 42 | import static org.opensearch.dataprepper.plugins.ml_inference.processor.common.AbstractBatchJobCreator.NUMBER_OF_FAILED_BATCH_JOBS_CREATION; |
41 | 43 | import static org.opensearch.dataprepper.plugins.ml_inference.processor.common.AbstractBatchJobCreator.NUMBER_OF_RECORDS_FAILED_IN_BATCH_JOB; |
@@ -280,4 +282,102 @@ void testCreateMLBatchJob_ThrottledThenSuccess() { |
280 | 282 | verify(bedrockBatchJobCreator, times(1)).incrementSuccessCounter(); |
281 | 283 | } |
282 | 284 | } |
| 285 | + |
| 286 | + @Test |
| 287 | + void testRetryInterval_SkipsRetryBeforeIntervalElapses() throws InterruptedException { |
| 288 | + // Mock retry interval BEFORE creating the object |
| 289 | + when(mlProcessorConfig.getRetryInterval()).thenReturn(Duration.ofSeconds(DEFAULT_RETRY_INTERVAL_SECONDS)); // 1 second for testing |
| 290 | + |
| 291 | + // Create object with mocked config |
| 292 | + bedrockBatchJobCreator = spy(new BedrockBatchJobCreator(mlProcessorConfig, awsCredentialsSupplier, pluginMetrics, dlqPushHandler)); |
| 293 | + |
| 294 | + Event event = mock(Event.class); |
| 295 | + Record<Event> record = new Record<>(event); |
| 296 | + |
| 297 | + when(event.getJsonNode()).thenReturn(OBJECT_MAPPER.createObjectNode() |
| 298 | + .put("bucket", "test-bucket") |
| 299 | + .put("key", "input.jsonl")); |
| 300 | + |
| 301 | + try (MockedStatic<RetryUtil> mockedStatic = mockStatic(RetryUtil.class)) { |
| 302 | + // First attempt - gets throttled |
| 303 | + mockedStatic.when(() -> RetryUtil.retryWithBackoffWithResult(any(Runnable.class), any())) |
| 304 | + .thenReturn(new RetryUtil.RetryResult(false, new MLBatchJobException(429, "throttled"), 1)); |
| 305 | + |
| 306 | + List<Record<Event>> resultRecords = new ArrayList<>(); |
| 307 | + |
| 308 | + // First attempt - gets throttled |
| 309 | + bedrockBatchJobCreator.createMLBatchJob(Arrays.asList(record), resultRecords); |
| 310 | + assertEquals(1, bedrockBatchJobCreator.getThrottledRecords().size()); |
| 311 | + |
| 312 | + // Try to process immediately (should skip due to retry interval) |
| 313 | + bedrockBatchJobCreator.addProcessedBatchRecordsToResults(resultRecords); |
| 314 | + |
| 315 | + // Verify record is still in queue (not processed due to retry interval) |
| 316 | + assertEquals(1, bedrockBatchJobCreator.getThrottledRecords().size()); |
| 317 | + BedrockBatchJobCreator.RetryRecord throttledRecord = bedrockBatchJobCreator.getThrottledRecords().peek(); |
| 318 | + assertNotNull(throttledRecord); |
| 319 | + assertEquals(0, throttledRecord.getRetryCount()); // Not incremented because retry was skipped |
| 320 | + assertTrue(resultRecords.isEmpty()); |
| 321 | + } |
| 322 | + } |
| 323 | + |
| 324 | + @Test |
| 325 | + void testRetryInterval_ProcessesAfterIntervalElapses() throws InterruptedException { |
| 326 | + // Mock retry interval BEFORE creating the object |
| 327 | + when(mlProcessorConfig.getRetryInterval()).thenReturn(Duration.ofSeconds(1)); // 1 second for testing |
| 328 | + |
| 329 | + // Create object with mocked config |
| 330 | + bedrockBatchJobCreator = spy(new BedrockBatchJobCreator(mlProcessorConfig, awsCredentialsSupplier, pluginMetrics, dlqPushHandler)); |
| 331 | + |
| 332 | + Event event = mock(Event.class); |
| 333 | + Record<Event> record = new Record<>(event); |
| 334 | + |
| 335 | + when(event.getJsonNode()).thenReturn(OBJECT_MAPPER.createObjectNode() |
| 336 | + .put("bucket", "test-bucket") |
| 337 | + .put("key", "input.jsonl")); |
| 338 | + |
| 339 | + try (MockedStatic<RetryUtil> mockedStatic = mockStatic(RetryUtil.class)) { |
| 340 | + // First throttled, then success |
| 341 | + mockedStatic.when(() -> RetryUtil.retryWithBackoffWithResult(any(Runnable.class), any())) |
| 342 | + .thenReturn(new RetryUtil.RetryResult(false, new MLBatchJobException(429, "throttled"), 1)) |
| 343 | + .thenReturn(new RetryUtil.RetryResult(true, null, 1)); |
| 344 | + |
| 345 | + List<Record<Event>> resultRecords = new ArrayList<>(); |
| 346 | + |
| 347 | + // First attempt - gets throttled |
| 348 | + bedrockBatchJobCreator.createMLBatchJob(Arrays.asList(record), resultRecords); |
| 349 | + assertEquals(1, bedrockBatchJobCreator.getThrottledRecords().size()); |
| 350 | + |
| 351 | + // Wait for retry interval to elapse |
| 352 | + Thread.sleep(1100); // Wait 1.1 seconds |
| 353 | + |
| 354 | + // Now retry should proceed |
| 355 | + bedrockBatchJobCreator.addProcessedBatchRecordsToResults(resultRecords); |
| 356 | + |
| 357 | + // Verify record was processed successfully |
| 358 | + assertTrue(bedrockBatchJobCreator.getThrottledRecords().isEmpty()); |
| 359 | + assertEquals(1, resultRecords.size()); |
| 360 | + verify(bedrockBatchJobCreator, times(1)).incrementSuccessCounter(); |
| 361 | + } |
| 362 | + } |
| 363 | + |
| 364 | + @Test |
| 365 | + void testRetryInterval_EmptyQueueDoesNotUpdateTimestamp() throws Exception { |
| 366 | + List<Record<Event>> resultRecords = new ArrayList<>(); |
| 367 | + |
| 368 | + // Try to process with empty queue |
| 369 | + long timestampBefore = getLastRetryTimestamp(bedrockBatchJobCreator); |
| 370 | + bedrockBatchJobCreator.addProcessedBatchRecordsToResults(resultRecords); |
| 371 | + long timestampAfter = getLastRetryTimestamp(bedrockBatchJobCreator); |
| 372 | + |
| 373 | + // Verify timestamp was not updated (queue was empty) |
| 374 | + assertEquals(timestampBefore, timestampAfter); |
| 375 | + } |
| 376 | + |
| 377 | + // Helper method to access private lastRetryTimestamp field using reflection |
| 378 | + private long getLastRetryTimestamp(BedrockBatchJobCreator creator) throws Exception { |
| 379 | + java.lang.reflect.Field field = BedrockBatchJobCreator.class.getDeclaredField("lastRetryTimestamp"); |
| 380 | + field.setAccessible(true); |
| 381 | + return (long) field.get(creator); |
| 382 | + } |
283 | 383 | } |
0 commit comments