Skip to content

Commit 56e412f

Browse files
Monitoring: Accurately capture predict failures (#4525) (#4555)
* fix: correctly capture ml failure stats on predict failures * refactor: remove redundant braces * refactor: dont track any 400 errors and add positive test case * fix: apply spotless --------- (cherry picked from commit a28d344) Signed-off-by: Pavan Yekbote <pybot@amazon.com> Co-authored-by: Pavan Yekbote <pybot@amazon.com>
1 parent 527ab99 commit 56e412f

2 files changed

Lines changed: 42 additions & 2 deletions

File tree

plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -461,7 +461,7 @@ private void runPredict(
461461
mlModelManager.trackPredictDuration(modelId, startTime);
462462
internalListener.onResponse(output);
463463
}
464-
}, e -> handlePredictFailure(mlTask, internalListener, e, false, modelId, actionName));
464+
}, e -> handlePredictFailure(mlTask, internalListener, e, shouldTrackRemoteFailure(e), modelId, actionName));
465465
predictor.asyncPredict(mlInput, trackPredictDurationListener);
466466
} else {
467467
MLOutput output = mlModelManager.trackPredictDuration(modelId, () -> predictor.predict(mlInput));
@@ -478,7 +478,7 @@ private void runPredict(
478478
return;
479479
} catch (Exception e) {
480480
log.error("Failed to predict model " + modelId, e);
481-
handlePredictFailure(mlTask, internalListener, e, false, modelId, actionName);
481+
handlePredictFailure(mlTask, internalListener, e, shouldTrackRemoteFailure(e), modelId, actionName);
482482
return;
483483
}
484484
} else if (FunctionName.needDeployFirst(algorithm)) {
@@ -604,4 +604,22 @@ public void validateOutputSchema(String modelId, ModelTensorOutput output) {
604604
}
605605
}
606606
}
607+
608+
boolean shouldTrackRemoteFailure(Exception e) {
609+
// Don't track failures for user configuration issues
610+
if (e instanceof IllegalArgumentException) {
611+
return false;
612+
}
613+
614+
// Don't track any 4xx client errors (user/configuration issues)
615+
if (e instanceof OpenSearchStatusException) {
616+
RestStatus status = ((OpenSearchStatusException) e).status();
617+
if (status.getStatus() >= 400 && status.getStatus() < 500) {
618+
return false;
619+
}
620+
}
621+
622+
// Track failures for infrastructure/service issues
623+
return true;
624+
}
607625
}

plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
import org.opensearch.core.action.ActionListener;
5252
import org.opensearch.core.common.bytes.BytesReference;
5353
import org.opensearch.core.common.transport.TransportAddress;
54+
import org.opensearch.core.rest.RestStatus;
5455
import org.opensearch.core.xcontent.ToXContent;
5556
import org.opensearch.core.xcontent.XContentBuilder;
5657
import org.opensearch.index.get.GetResult;
@@ -721,4 +722,25 @@ private void setupMocks(boolean runOnLocalNode, boolean failedToParseQueryInput,
721722
}).when(client).get(any(), any());
722723
}
723724
}
725+
726+
public void testShouldTrackRemoteFailure() {
727+
// Test IllegalArgumentException - should not track
728+
assertFalse(taskRunner.shouldTrackRemoteFailure(new IllegalArgumentException("Invalid argument")));
729+
730+
// Test OpenSearchStatusException with 4xx status codes - should not track
731+
assertFalse(taskRunner.shouldTrackRemoteFailure(new OpenSearchStatusException("Bad request", RestStatus.BAD_REQUEST)));
732+
assertFalse(taskRunner.shouldTrackRemoteFailure(new OpenSearchStatusException("Unauthorized", RestStatus.UNAUTHORIZED)));
733+
assertFalse(taskRunner.shouldTrackRemoteFailure(new OpenSearchStatusException("Forbidden", RestStatus.FORBIDDEN)));
734+
assertFalse(taskRunner.shouldTrackRemoteFailure(new OpenSearchStatusException("Not found", RestStatus.NOT_FOUND)));
735+
736+
// Test OpenSearchStatusException with 5xx status codes - should track
737+
assertTrue(taskRunner.shouldTrackRemoteFailure(new OpenSearchStatusException("Server error", RestStatus.INTERNAL_SERVER_ERROR)));
738+
assertTrue(
739+
taskRunner.shouldTrackRemoteFailure(new OpenSearchStatusException("Service unavailable", RestStatus.SERVICE_UNAVAILABLE))
740+
);
741+
742+
// Test other exceptions - should track
743+
assertTrue(taskRunner.shouldTrackRemoteFailure(new RuntimeException("Runtime error")));
744+
assertTrue(taskRunner.shouldTrackRemoteFailure(new IOException("IO error")));
745+
}
724746
}

0 commit comments

Comments
 (0)