diff --git a/data-prepper-plugins/ml-inference-processor/src/main/java/org/opensearch/dataprepper/plugins/ml_inference/processor/MLProcessor.java b/data-prepper-plugins/ml-inference-processor/src/main/java/org/opensearch/dataprepper/plugins/ml_inference/processor/MLProcessor.java index d93aff9036..4c096f01e4 100644 --- a/data-prepper-plugins/ml-inference-processor/src/main/java/org/opensearch/dataprepper/plugins/ml_inference/processor/MLProcessor.java +++ b/data-prepper-plugins/ml-inference-processor/src/main/java/org/opensearch/dataprepper/plugins/ml_inference/processor/MLProcessor.java @@ -18,11 +18,14 @@ import org.opensearch.dataprepper.model.processor.AbstractProcessor; import org.opensearch.dataprepper.model.processor.Processor; import org.opensearch.dataprepper.model.record.Record; -import org.opensearch.dataprepper.plugins.ml_inference.processor.common.MLBatchJobCreator; +import org.opensearch.dataprepper.plugins.ml_inference.processor.common.BatchActionExecutor; +import org.opensearch.dataprepper.plugins.ml_inference.processor.common.MLActionExecutor; import org.opensearch.dataprepper.plugins.ml_inference.processor.common.MLBatchJobCreatorFactory; +import org.opensearch.dataprepper.plugins.ml_inference.processor.common.ModelSyncInferenceExecutor; +import org.opensearch.dataprepper.plugins.ml_inference.processor.common.PredictActionExecutor; +import org.opensearch.dataprepper.plugins.ml_inference.processor.configuration.ActionType; import org.opensearch.dataprepper.plugins.ml_inference.processor.configuration.ServiceName; import org.opensearch.dataprepper.plugins.ml_inference.processor.dlq.DlqPushHandler; -import org.opensearch.dataprepper.plugins.ml_inference.processor.exception.MLBatchJobException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -40,7 +43,7 @@ public class MLProcessor extends AbstractProcessor, Record> public static final String NUMBER_OF_ML_PROCESSOR_FAILED = "BatchJobRequestsFailed"; private final String whenCondition; - 
private final MLBatchJobCreator mlBatchJobCreator; + private final MLActionExecutor actionExecutor; private final Counter numberOfMLProcessorSuccessCounter; private final Counter numberOfMLProcessorFailedCounter; private final ExpressionEvaluator expressionEvaluator; @@ -52,11 +55,8 @@ public class MLProcessor extends AbstractProcessor, Record> public MLProcessor(final MLProcessorConfig mlProcessorConfig, final PluginMetrics pluginMetrics, final PluginFactory pluginFactory, final PluginSetting pluginSetting, final AwsCredentialsSupplier awsCredentialsSupplier, final ExpressionEvaluator expressionEvaluator) { super(pluginMetrics); this.whenCondition = mlProcessorConfig.getWhenCondition(); - ServiceName serviceName = mlProcessorConfig.getServiceName(); - this.numberOfMLProcessorSuccessCounter = pluginMetrics.counter( - NUMBER_OF_ML_PROCESSOR_SUCCESS); - this.numberOfMLProcessorFailedCounter = pluginMetrics.counter( - NUMBER_OF_ML_PROCESSOR_FAILED); + this.numberOfMLProcessorSuccessCounter = pluginMetrics.counter(NUMBER_OF_ML_PROCESSOR_SUCCESS); + this.numberOfMLProcessorFailedCounter = pluginMetrics.counter(NUMBER_OF_ML_PROCESSOR_FAILED); this.expressionEvaluator = expressionEvaluator; this.pluginSetting = pluginSetting; @@ -64,75 +64,80 @@ public MLProcessor(final MLProcessorConfig mlProcessorConfig, final PluginMetric this.dlqPushHandler = new DlqPushHandler(pluginFactory, pluginSetting, mlProcessorConfig.getDlq(), mlProcessorConfig.getAwsAuthenticationOptions()); } - // Use factory to get the appropriate job creator - mlBatchJobCreator = MLBatchJobCreatorFactory.getJobCreator(serviceName, mlProcessorConfig, awsCredentialsSupplier, pluginMetrics, dlqPushHandler); + if (ActionType.PREDICT.equals(mlProcessorConfig.getActionType())) { + this.actionExecutor = new PredictActionExecutor(new ModelSyncInferenceExecutor(mlProcessorConfig, awsCredentialsSupplier, pluginMetrics)); + } else { + final ServiceName serviceName = mlProcessorConfig.getServiceName(); + 
this.actionExecutor = new BatchActionExecutor(MLBatchJobCreatorFactory.getJobCreator(serviceName, mlProcessorConfig, awsCredentialsSupplier, pluginMetrics, dlqPushHandler)); + } } @Override - public Collection> doExecute(Collection> records) { - List> resultRecords = new ArrayList<>(); - // check and process any existing batch - mlBatchJobCreator.checkAndProcessBatch(); - // Add processed records to results - mlBatchJobCreator.addProcessedBatchRecordsToResults(resultRecords); - // reads from input - S3 input - if (records.size() == 0) + public Collection> doExecute(final Collection> records) { + final List> resultRecords = new ArrayList<>(); + + actionExecutor.prepareExecution(resultRecords); + + if (records.isEmpty()) { return resultRecords; + } - // Process new records - List> recordsToMlCommons = records.stream() - .filter(record -> { - try { - boolean meetCondition = whenCondition == null || expressionEvaluator.evaluateConditional(whenCondition, record.getData()); - if (!meetCondition) { - resultRecords.add(record); - } - return meetCondition; // Include in recordsToMlCommons if true - } catch (ExpressionParsingException e) { - LOG.warn("Expression parsing failed for record: {}. Error: {}", record, e.getMessage()); - resultRecords.add(record); - return false; // Skip the record on parsing failure - } catch (ClassCastException e) { - LOG.warn("Unexpected return type when evaluating condition for record: {}. 
Error: {}", record, e.getMessage()); - resultRecords.add(record); - return false; // Skip the record on type mismatch - } catch (Exception e) { - LOG.error("Failed to evaluate conditional expression for record: {}", record, e); - resultRecords.add(record); - return false; // Skip the record if evaluation fails - } - }) - .collect(Collectors.toList()); - - if (recordsToMlCommons.isEmpty()) { + final List> filteredRecords = filterByCondition(records, resultRecords); + if (filteredRecords.isEmpty()) { return resultRecords; } try { - mlBatchJobCreator.createMLBatchJob(recordsToMlCommons, resultRecords); + actionExecutor.execute(filteredRecords, resultRecords); numberOfMLProcessorSuccessCounter.increment(); - } catch (MLBatchJobException e) { - LOG.error(NOISY, "ML Batch job creation failed: {}", e.getMessage()); - numberOfMLProcessorFailedCounter.increment(); - } catch (Exception e) { - LOG.error(NOISY, "Unexpected Error occurred while creating the batch job: {}", e.getMessage(), e); + } catch (final Exception e) { + LOG.error(NOISY, "Unexpected error during ML processing: {}", e.getMessage(), e); numberOfMLProcessorFailedCounter.increment(); } + return resultRecords; } + private List> filterByCondition(final Collection> records, + final List> resultRecords) { + return records.stream() + .filter(record -> { + try { + final boolean meetCondition = whenCondition == null + || expressionEvaluator.evaluateConditional(whenCondition, record.getData()); + if (!meetCondition) { + resultRecords.add(record); + } + return meetCondition; + } catch (ExpressionParsingException e) { + LOG.warn("Expression parsing failed for record: {}. Error: {}", record, e.getMessage()); + resultRecords.add(record); + return false; + } catch (ClassCastException e) { + LOG.warn("Unexpected return type when evaluating condition for record: {}. 
Error: {}", record, e.getMessage()); + resultRecords.add(record); + return false; + } catch (Exception e) { + LOG.error("Failed to evaluate conditional expression for record: {}", record, e); + resultRecords.add(record); + return false; + } + }) + .collect(Collectors.toList()); + } + @Override public void prepareForShutdown() { - mlBatchJobCreator.prepareForShutdown(); + actionExecutor.prepareForShutdown(); } @Override public boolean isReadyForShutdown() { - return mlBatchJobCreator.isReadyForShutdown(); + return actionExecutor.isReadyForShutdown(); } @Override public void shutdown() { - mlBatchJobCreator.shutdown(); + actionExecutor.shutdown(); } -} \ No newline at end of file +} diff --git a/data-prepper-plugins/ml-inference-processor/src/main/java/org/opensearch/dataprepper/plugins/ml_inference/processor/MLProcessorConfig.java b/data-prepper-plugins/ml-inference-processor/src/main/java/org/opensearch/dataprepper/plugins/ml_inference/processor/MLProcessorConfig.java index 811135eb12..51418b11f4 100644 --- a/data-prepper-plugins/ml-inference-processor/src/main/java/org/opensearch/dataprepper/plugins/ml_inference/processor/MLProcessorConfig.java +++ b/data-prepper-plugins/ml-inference-processor/src/main/java/org/opensearch/dataprepper/plugins/ml_inference/processor/MLProcessorConfig.java @@ -107,6 +107,17 @@ public class MLProcessorConfig { @DurationMax(seconds = 300) private Duration retryInterval = Duration.ofSeconds(DEFAULT_RETRY_INTERVAL_SECONDS); + @JsonPropertyDescription("Maps document fields to model input fields for the predict action type. " + + "Each element is a map of : .") + @JsonProperty("input_map") + private List> inputMap; + + @JsonPropertyDescription("Maps model output fields to new document fields for the predict action type. 
" + + "Each element is a map of : .") + @JsonProperty("output_map") + private List> outputMap; + + @JsonProperty("dlq") private PluginModel dlq; @@ -126,6 +137,11 @@ public String getWhenCondition() { public List getTagsOnFailure() { return tagsOnFailure; } + public List> getInputMap() { return inputMap; } + + public List> getOutputMap() { return outputMap; } + + public PluginModel getDlq() { return dlq; } diff --git a/data-prepper-plugins/ml-inference-processor/src/main/java/org/opensearch/dataprepper/plugins/ml_inference/processor/common/BatchActionExecutor.java b/data-prepper-plugins/ml-inference-processor/src/main/java/org/opensearch/dataprepper/plugins/ml_inference/processor/common/BatchActionExecutor.java new file mode 100644 index 0000000000..304380f281 --- /dev/null +++ b/data-prepper-plugins/ml-inference-processor/src/main/java/org/opensearch/dataprepper/plugins/ml_inference/processor/common/BatchActionExecutor.java @@ -0,0 +1,66 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ * + */ + +package org.opensearch.dataprepper.plugins.ml_inference.processor.common; + +import org.opensearch.dataprepper.model.event.Event; +import org.opensearch.dataprepper.model.record.Record; +import org.opensearch.dataprepper.plugins.ml_inference.processor.exception.MLBatchJobException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Collection; +import java.util.List; + +import static org.opensearch.dataprepper.logging.DataPrepperMarkers.NOISY; + +public class BatchActionExecutor implements MLActionExecutor { + + private static final Logger LOG = LoggerFactory.getLogger(BatchActionExecutor.class); + + private final MLBatchJobCreator mlBatchJobCreator; + + public BatchActionExecutor(final MLBatchJobCreator mlBatchJobCreator) { + this.mlBatchJobCreator = mlBatchJobCreator; + } + + @Override + public void prepareExecution(final List> resultRecords) { + mlBatchJobCreator.checkAndProcessBatch(); + mlBatchJobCreator.addProcessedBatchRecordsToResults(resultRecords); + } + + @Override + public Collection> execute(final List> filteredRecords, + final List> resultRecords) { + try { + mlBatchJobCreator.createMLBatchJob(filteredRecords, resultRecords); + } catch (final MLBatchJobException e) { + LOG.error(NOISY, "ML Batch job creation failed: {}", e.getMessage()); + throw e; + } + return resultRecords; + } + + @Override + public void prepareForShutdown() { + mlBatchJobCreator.prepareForShutdown(); + } + + @Override + public boolean isReadyForShutdown() { + return mlBatchJobCreator.isReadyForShutdown(); + } + + @Override + public void shutdown() { + mlBatchJobCreator.shutdown(); + } +} diff --git a/data-prepper-plugins/ml-inference-processor/src/main/java/org/opensearch/dataprepper/plugins/ml_inference/processor/common/MLActionExecutor.java b/data-prepper-plugins/ml-inference-processor/src/main/java/org/opensearch/dataprepper/plugins/ml_inference/processor/common/MLActionExecutor.java new file mode 100644 index 0000000000..f279408096 --- 
/dev/null +++ b/data-prepper-plugins/ml-inference-processor/src/main/java/org/opensearch/dataprepper/plugins/ml_inference/processor/common/MLActionExecutor.java @@ -0,0 +1,29 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + */ + +package org.opensearch.dataprepper.plugins.ml_inference.processor.common; + +import org.opensearch.dataprepper.model.event.Event; +import org.opensearch.dataprepper.model.record.Record; + +import java.util.Collection; +import java.util.List; + +public interface MLActionExecutor { + default void prepareExecution(List> resultRecords) {} + + Collection> execute(List> filteredRecords, List> resultRecords); + + default void prepareForShutdown() {} + + default boolean isReadyForShutdown() { return true; } + + default void shutdown() {} +} diff --git a/data-prepper-plugins/ml-inference-processor/src/main/java/org/opensearch/dataprepper/plugins/ml_inference/processor/common/ModelSyncInferenceExecutor.java b/data-prepper-plugins/ml-inference-processor/src/main/java/org/opensearch/dataprepper/plugins/ml_inference/processor/common/ModelSyncInferenceExecutor.java new file mode 100644 index 0000000000..3e55842569 --- /dev/null +++ b/data-prepper-plugins/ml-inference-processor/src/main/java/org/opensearch/dataprepper/plugins/ml_inference/processor/common/ModelSyncInferenceExecutor.java @@ -0,0 +1,208 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ * + */ + +package org.opensearch.dataprepper.plugins.ml_inference.processor.common; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.micrometer.core.instrument.Counter; +import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier; +import org.opensearch.dataprepper.metrics.PluginMetrics; +import org.opensearch.dataprepper.model.event.Event; +import org.opensearch.dataprepper.model.record.Record; +import org.opensearch.dataprepper.plugins.ml_inference.processor.MLProcessorConfig; +import org.opensearch.dataprepper.plugins.ml_inference.processor.connector.AbstractConnector; +import org.opensearch.dataprepper.plugins.ml_inference.processor.connector.BuiltInConnectors; +import org.opensearch.dataprepper.plugins.ml_inference.processor.connector.Connector; +import org.opensearch.dataprepper.plugins.ml_inference.processor.connector.ConnectorActionType; +import org.opensearch.dataprepper.plugins.ml_inference.processor.connector.ConnectorExecutorFactory; +import org.opensearch.dataprepper.plugins.ml_inference.processor.connector.RemoteConnectorExecutor; +import org.opensearch.dataprepper.plugins.ml_inference.processor.exception.MLBatchJobException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.net.HttpURLConnection; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * Handles synchronous PREDICT invocations. + * + *

<p>For each record: + * <ol> + *   <li>Reads model input fields from the event using {@code input_map}.</li> + *   <li>Invokes the remote model via the built-in connector's PREDICT action.</li> + *   <li>Extracts model output fields from the response using {@code output_map} and writes + *       them back into the event.</li> + * </ol>
+ */ +public class ModelSyncInferenceExecutor { + + public static final String NUMBER_OF_SYNC_INFERENCE_RECORDS_SUCCESS = "syncInferenceRecordsSucceeded"; + public static final String NUMBER_OF_SYNC_INFERENCE_RECORDS_FAILED = "syncInferenceRecordsFailed"; + + private static final Logger LOG = LoggerFactory.getLogger(ModelSyncInferenceExecutor.class); + private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); + + private final MLProcessorConfig config; + private final RemoteConnectorExecutor connectorExecutor; + private final List> inputMap; + private final List> outputMap; + private final List tagsOnFailure; + private final Counter numberOfSyncInferenceRecordsSuccessCounter; + private final Counter numberOfSyncInferenceRecordsFailedCounter; + + public ModelSyncInferenceExecutor(final MLProcessorConfig config, + final AwsCredentialsSupplier awsCredentialsSupplier, + final PluginMetrics pluginMetrics) { + this.config = config; + this.inputMap = config.getInputMap() != null ? config.getInputMap() : Collections.emptyList(); + this.outputMap = config.getOutputMap() != null ? config.getOutputMap() : Collections.emptyList(); + this.tagsOnFailure = config.getTagsOnFailure() != null ? config.getTagsOnFailure() : Collections.emptyList(); + this.connectorExecutor = buildConnectorExecutor(config, awsCredentialsSupplier); + this.numberOfSyncInferenceRecordsSuccessCounter = pluginMetrics.counter(NUMBER_OF_SYNC_INFERENCE_RECORDS_SUCCESS); + this.numberOfSyncInferenceRecordsFailedCounter = pluginMetrics.counter(NUMBER_OF_SYNC_INFERENCE_RECORDS_FAILED); + } + + /** + * Processes all records synchronously. For each record, every entry in {@code input_map} + * triggers one PREDICT invocation; the corresponding {@code output_map} entry determines + * where the result is written back into the event. 
+ */ + public Collection> execute(final Collection> records) { + final List> resultRecords = new ArrayList<>(); + for (final Record record : records) { + try { + processRecord(record); + numberOfSyncInferenceRecordsSuccessCounter.increment(); + resultRecords.add(record); + } catch (final Exception e) { + LOG.error("Failed to run PREDICT for record: {}", e.getMessage(), e); + numberOfSyncInferenceRecordsFailedCounter.increment(); + addFailureTags(record); + resultRecords.add(record); + } + } + return resultRecords; + } + + private void processRecord(final Record record) { + final Event event = record.getData(); + + for (int i = 0; i < inputMap.size(); i++) { + final Map inputEntry = inputMap.get(i); + final Map outputEntry = i < outputMap.size() ? outputMap.get(i) : Collections.emptyMap(); + + final Map runtimeParameters = buildRuntimeParameters(event, inputEntry); + LOG.debug("Invoking PREDICT with parameters: {}", runtimeParameters.keySet()); + + final String responseBody = connectorExecutor.executeActionAndGetResponse( + ConnectorActionType.PREDICT, runtimeParameters); + + writeOutputsToEvent(event, responseBody, outputEntry); + } + } + + /** + * Reads each model input field value from the event using the document field specified + * in the input map entry, and adds the region from config. 
+ */ + private Map buildRuntimeParameters(final Event event, + final Map inputEntry) { + final Map parameters = new HashMap<>(); + parameters.put("region", config.getAwsAuthenticationOptions().getAwsRegion().id()); + + for (final Map.Entry mapping : inputEntry.entrySet()) { + final String modelInputField = mapping.getKey(); + final String documentField = mapping.getValue(); + final String fieldValue = event.get(documentField, String.class); + if (fieldValue == null) { + throw new MLBatchJobException(HttpURLConnection.HTTP_BAD_REQUEST, + "input_map field '" + documentField + "' not found in event"); + } + parameters.put(modelInputField, fieldValue); + } + return parameters; + } + + /** + * Extracts values from the JSON response body using the model output JSON paths defined + * in the output map and writes them into the event under the corresponding document field names. + */ + private void writeOutputsToEvent(final Event event, + final String responseBody, + final Map outputEntry) { + if (outputEntry.isEmpty()) { + return; + } + try { + final JsonNode responseNode = OBJECT_MAPPER.readTree(responseBody); + for (final Map.Entry mapping : outputEntry.entrySet()) { + final String documentField = mapping.getKey(); + final String modelOutputPath = mapping.getValue(); + final Object value = extractFromResponse(responseNode, modelOutputPath); + event.put(documentField, value); + LOG.debug("Wrote output field '{}' from model path '{}'", documentField, modelOutputPath); + } + } catch (final MLBatchJobException e) { + throw e; + } catch (final Exception e) { + throw new MLBatchJobException(HttpURLConnection.HTTP_INTERNAL_ERROR, + "Failed to parse model response: " + e.getMessage()); + } + } + + private Object extractFromResponse(final JsonNode responseNode, final String modelOutputPath) { + try { + final String pointer = modelOutputPath.startsWith("/") ? 
modelOutputPath : "/" + modelOutputPath; + final JsonNode result = responseNode.at(pointer); + if (result.isMissingNode()) { + throw new IllegalArgumentException("Path '" + modelOutputPath + "' not found in response"); + } + return OBJECT_MAPPER.convertValue(result, Object.class); + } catch (final Exception e) { + throw new MLBatchJobException(HttpURLConnection.HTTP_INTERNAL_ERROR, + "output_map path '" + modelOutputPath + "' not found in model response: " + e.getMessage()); + } + } + + private void addFailureTags(final Record record) { + if (tagsOnFailure.isEmpty()) { + return; + } + final Event event = record.getData(); + if (event.getMetadata() != null) { + event.getMetadata().addTags(tagsOnFailure); + } + } + + private static RemoteConnectorExecutor buildConnectorExecutor(final MLProcessorConfig config, + final AwsCredentialsSupplier supplier) { + return BuiltInConnectors.findConnectorJson(config.getModelId()) + .map(json -> { + try { + final Connector connector = AbstractConnector.fromJson(json); + final RemoteConnectorExecutor executor = ConnectorExecutorFactory.create(connector, config, supplier); + LOG.info("ModelSyncInferenceExecutor using built-in connector for model: {}", config.getModelId()); + return executor; + } catch (final Exception e) { + throw new RuntimeException( + "Failed to initialize connector for model: " + config.getModelId(), e); + } + }) + .orElseThrow(() -> new IllegalArgumentException( + "No built-in connector found for model_id '" + config.getModelId() + + "'. 
The predict action_type requires a supported model_id.")); + } +} diff --git a/data-prepper-plugins/ml-inference-processor/src/main/java/org/opensearch/dataprepper/plugins/ml_inference/processor/common/PredictActionExecutor.java b/data-prepper-plugins/ml-inference-processor/src/main/java/org/opensearch/dataprepper/plugins/ml_inference/processor/common/PredictActionExecutor.java new file mode 100644 index 0000000000..cdff4be2bf --- /dev/null +++ b/data-prepper-plugins/ml-inference-processor/src/main/java/org/opensearch/dataprepper/plugins/ml_inference/processor/common/PredictActionExecutor.java @@ -0,0 +1,33 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + */ + +package org.opensearch.dataprepper.plugins.ml_inference.processor.common; + +import org.opensearch.dataprepper.model.event.Event; +import org.opensearch.dataprepper.model.record.Record; + +import java.util.Collection; +import java.util.List; + +public class PredictActionExecutor implements MLActionExecutor { + + private final ModelSyncInferenceExecutor modelSyncInferenceExecutor; + + public PredictActionExecutor(final ModelSyncInferenceExecutor modelSyncInferenceExecutor) { + this.modelSyncInferenceExecutor = modelSyncInferenceExecutor; + } + + @Override + public Collection> execute(final List> filteredRecords, + final List> resultRecords) { + resultRecords.addAll(modelSyncInferenceExecutor.execute(filteredRecords)); + return resultRecords; + } +} diff --git a/data-prepper-plugins/ml-inference-processor/src/main/java/org/opensearch/dataprepper/plugins/ml_inference/processor/connector/AbstractConnectorExecutor.java b/data-prepper-plugins/ml-inference-processor/src/main/java/org/opensearch/dataprepper/plugins/ml_inference/processor/connector/AbstractConnectorExecutor.java index d62b8c1db7..194312d72a 100644 
--- a/data-prepper-plugins/ml-inference-processor/src/main/java/org/opensearch/dataprepper/plugins/ml_inference/processor/connector/AbstractConnectorExecutor.java +++ b/data-prepper-plugins/ml-inference-processor/src/main/java/org/opensearch/dataprepper/plugins/ml_inference/processor/connector/AbstractConnectorExecutor.java @@ -26,15 +26,20 @@ public abstract class AbstractConnectorExecutor implements RemoteConnectorExecut private static final Logger LOG = LoggerFactory.getLogger(AbstractConnectorExecutor.class); - /** - * {@inheritDoc} - * - *

Merges runtime parameters with connector defaults, resolves URL and payload via - * template substitution, then calls {@link #sendRequest}. - */ @Override public void executeAction(final ConnectorActionType actionType, final Map runtimeParameters) { + resolveAndSend(actionType, runtimeParameters); + } + + @Override + public String executeActionAndGetResponse(final ConnectorActionType actionType, + final Map runtimeParameters) { + return resolveAndSend(actionType, runtimeParameters); + } + + private String resolveAndSend(final ConnectorActionType actionType, + final Map runtimeParameters) { final Connector connector = getConnector(); final String actionName = actionType.name(); @@ -50,19 +55,20 @@ public void executeAction(final ConnectorActionType actionType, final String payload = connector.createPayload(actionName, merged); LOG.debug("Sending {} request to: {}", action.getMethod(), url); - sendRequest(action, url, payload, merged); + return sendRequest(action, url, payload, merged); } /** - * Performs the actual HTTP request after URL and payload have been resolved. + * Performs the actual HTTP request and returns the raw response body. * * @param action the matching connector action (carries method, headers, etc.) 
* @param url the fully-resolved request URL * @param payload the fully-resolved request body * @param merged the merged parameter map (connector defaults + runtime overrides) + * @return the raw HTTP response body string */ - protected abstract void sendRequest(ConnectorAction action, - String url, - String payload, - Map merged); + protected abstract String sendRequest(ConnectorAction action, + String url, + String payload, + Map merged); } diff --git a/data-prepper-plugins/ml-inference-processor/src/main/java/org/opensearch/dataprepper/plugins/ml_inference/processor/connector/AwsConnectorExecutor.java b/data-prepper-plugins/ml-inference-processor/src/main/java/org/opensearch/dataprepper/plugins/ml_inference/processor/connector/AwsConnectorExecutor.java index 9131efd522..523d1c091c 100644 --- a/data-prepper-plugins/ml-inference-processor/src/main/java/org/opensearch/dataprepper/plugins/ml_inference/processor/connector/AwsConnectorExecutor.java +++ b/data-prepper-plugins/ml-inference-processor/src/main/java/org/opensearch/dataprepper/plugins/ml_inference/processor/connector/AwsConnectorExecutor.java @@ -95,10 +95,10 @@ public AwsConnector getConnector() { * {@code connector.parameters.service_name}, then executes it synchronously. 
*/ @Override - protected void sendRequest(final ConnectorAction action, - final String url, - final String payload, - final Map merged) { + protected String sendRequest(final ConnectorAction action, + final String url, + final String payload, + final Map merged) { final SdkHttpMethod method = SdkHttpMethod.fromValue(action.getMethod()); final SdkHttpFullRequest.Builder requestBuilder = SdkHttpFullRequest.builder() @@ -125,7 +125,7 @@ protected void sendRequest(final ConnectorAction action, .contentStreamProvider(signedRequest.contentStreamProvider().orElse(null)) .build(); - executeHttpRequest(executeRequest, action.getActionType()); + return executeHttpRequest(executeRequest, action.getActionType()); } private SdkHttpFullRequest signRequest(final SdkHttpFullRequest request, @@ -148,7 +148,7 @@ private SdkHttpFullRequest signRequest(final SdkHttpFullRequest request, } } - private void executeHttpRequest(final HttpExecuteRequest executeRequest, final String action) { + private String executeHttpRequest(final HttpExecuteRequest executeRequest, final String action) { final HttpExecuteResponse response; try { response = httpClientExecutor.execute(executeRequest); @@ -161,10 +161,10 @@ private void executeHttpRequest(final HttpExecuteRequest executeRequest, final S throw new MLBatchJobException(HttpURLConnection.HTTP_INTERNAL_ERROR, "Unexpected error executing " + action + " request: " + e.getMessage()); } - handleHttpResponse(response, action); + return handleHttpResponse(response, action); } - private void handleHttpResponse(final HttpExecuteResponse response, final String action) { + private String handleHttpResponse(final HttpExecuteResponse response, final String action) { final int statusCode = response.httpResponse().statusCode(); final String responseBody = response.responseBody().map(this::readStream).orElse("No response"); @@ -186,7 +186,8 @@ private void handleHttpResponse(final HttpExecuteResponse response, final String "Unexpected status code " + statusCode 
+ " on " + action); } - LOG.info("{} request succeeded: {}", action, responseBody); + LOG.info("{} request succeeded", action); + return responseBody; } private String readStream(final AbortableInputStream stream) { diff --git a/data-prepper-plugins/ml-inference-processor/src/main/java/org/opensearch/dataprepper/plugins/ml_inference/processor/connector/RemoteConnectorExecutor.java b/data-prepper-plugins/ml-inference-processor/src/main/java/org/opensearch/dataprepper/plugins/ml_inference/processor/connector/RemoteConnectorExecutor.java index 890c881ef6..43d5008e16 100644 --- a/data-prepper-plugins/ml-inference-processor/src/main/java/org/opensearch/dataprepper/plugins/ml_inference/processor/connector/RemoteConnectorExecutor.java +++ b/data-prepper-plugins/ml-inference-processor/src/main/java/org/opensearch/dataprepper/plugins/ml_inference/processor/connector/RemoteConnectorExecutor.java @@ -33,4 +33,14 @@ public interface RemoteConnectorExecutor { * @param runtimeParameters per-request parameter overrides merged with connector defaults */ void executeAction(ConnectorActionType actionType, Map runtimeParameters); + + /** + * Executes the named action and returns the raw response body string. + * Used by synchronous actions (e.g. PREDICT) where the caller needs the model output. + * + * @param actionType the action to execute (e.g. 
PREDICT) + * @param runtimeParameters per-request parameter overrides merged with connector defaults + * @return the raw HTTP response body + */ + String executeActionAndGetResponse(ConnectorActionType actionType, Map runtimeParameters); } diff --git a/data-prepper-plugins/ml-inference-processor/src/test/java/org/opensearch/dataprepper/plugins/ml_inference/processor/MLProcessorTest.java b/data-prepper-plugins/ml-inference-processor/src/test/java/org/opensearch/dataprepper/plugins/ml_inference/processor/MLProcessorTest.java index da378776e0..042ecf2a05 100644 --- a/data-prepper-plugins/ml-inference-processor/src/test/java/org/opensearch/dataprepper/plugins/ml_inference/processor/MLProcessorTest.java +++ b/data-prepper-plugins/ml-inference-processor/src/test/java/org/opensearch/dataprepper/plugins/ml_inference/processor/MLProcessorTest.java @@ -6,9 +6,10 @@ package org.opensearch.dataprepper.plugins.ml_inference.processor; import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; -import org.junit.jupiter.api.Test; import org.mockito.MockitoAnnotations; import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier; @@ -18,26 +19,31 @@ import org.opensearch.dataprepper.model.event.Event; import org.opensearch.dataprepper.model.plugin.PluginFactory; import org.opensearch.dataprepper.model.record.Record; -import org.opensearch.dataprepper.plugins.ml_inference.processor.common.MLBatchJobCreator; +import org.opensearch.dataprepper.plugins.ml_inference.processor.common.BatchActionExecutor; +import org.opensearch.dataprepper.plugins.ml_inference.processor.common.MLActionExecutor; +import org.opensearch.dataprepper.plugins.ml_inference.processor.common.ModelSyncInferenceExecutor; +import org.opensearch.dataprepper.plugins.ml_inference.processor.common.PredictActionExecutor; import 
io.micrometer.core.instrument.Counter; +import org.opensearch.dataprepper.plugins.ml_inference.processor.configuration.ActionType; import org.opensearch.dataprepper.plugins.ml_inference.processor.configuration.AwsAuthenticationOptions; import org.opensearch.dataprepper.plugins.ml_inference.processor.configuration.ServiceName; import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; import software.amazon.awssdk.regions.Region; + import java.lang.reflect.Field; -import java.util.ArrayList; import java.util.Collection; import java.util.Collections; import java.util.List; - import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.lenient; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; @@ -47,130 +53,233 @@ @ExtendWith(MockitoExtension.class) public class MLProcessorTest { - @Mock - private MLProcessorConfig mlProcessorConfig; - - @Mock - private PluginMetrics pluginMetrics; - - @Mock - private AwsCredentialsSupplier awsCredentialsSupplier; - - @Mock - private ExpressionEvaluator expressionEvaluator; - - @Mock - private MLBatchJobCreator mlBatchJobCreator; - - @Mock - private Counter successCounter; - - @Mock - private Counter failureCounter; - @Mock - private MLProcessor mlProcessor; - - @Mock - private AwsAuthenticationOptions awsAuthenticationOptions; - - @Mock - private AwsCredentialsProvider awsCredentialsProvider; - - @Mock - private PluginFactory pluginFactory; - - @Mock - private PluginSetting pluginSetting; - - @BeforeEach - void setUp() throws NoSuchFieldException, IllegalAccessException { + @Mock 
private MLProcessorConfig mlProcessorConfig; + @Mock private PluginMetrics pluginMetrics; + @Mock private AwsCredentialsSupplier awsCredentialsSupplier; + @Mock private ExpressionEvaluator expressionEvaluator; + @Mock private Counter successCounter; + @Mock private Counter failureCounter; + @Mock private AwsAuthenticationOptions awsAuthenticationOptions; + @Mock private AwsCredentialsProvider awsCredentialsProvider; + @Mock private PluginFactory pluginFactory; + @Mock private PluginSetting pluginSetting; + + private void setupCommonMocks() { MockitoAnnotations.openMocks(this); - when(mlProcessorConfig.getWhenCondition()).thenReturn("condition"); - lenient().when(expressionEvaluator.evaluateConditional(eq("condition"), any())).thenReturn(true); - lenient().when(mlProcessorConfig.getServiceName()).thenReturn(ServiceName.SAGEMAKER); lenient().when(awsAuthenticationOptions.getAwsRegion()).thenReturn(Region.US_WEST_2); lenient().when(awsCredentialsSupplier.getProvider(any())).thenReturn(awsCredentialsProvider); lenient().when(mlProcessorConfig.getAwsAuthenticationOptions()).thenReturn(awsAuthenticationOptions); lenient().when(mlProcessorConfig.getDlqPluginSetting()).thenReturn(null); lenient().when(pluginMetrics.counter(NUMBER_OF_ML_PROCESSOR_SUCCESS)).thenReturn(successCounter); lenient().when(pluginMetrics.counter(NUMBER_OF_ML_PROCESSOR_FAILED)).thenReturn(failureCounter); - - mlProcessor = new MLProcessor(mlProcessorConfig, pluginMetrics, pluginFactory, pluginSetting, awsCredentialsSupplier, expressionEvaluator); - // Inject the mocked mlBatchJobCreator using reflection - Field field = MLProcessor.class.getDeclaredField("mlBatchJobCreator"); - field.setAccessible(true); - field.set(mlProcessor, mlBatchJobCreator); + lenient().when(pluginMetrics.counter(ModelSyncInferenceExecutor.NUMBER_OF_SYNC_INFERENCE_RECORDS_SUCCESS)).thenReturn(mock(io.micrometer.core.instrument.Counter.class)); + 
lenient().when(pluginMetrics.counter(ModelSyncInferenceExecutor.NUMBER_OF_SYNC_INFERENCE_RECORDS_FAILED)).thenReturn(mock(io.micrometer.core.instrument.Counter.class)); } - @Test - void testDoExecute_WithValidRecords() throws Exception { - Event event = mock(Event.class); - Record record = new Record<>(event); - List> records = Collections.singletonList(record); - - Collection> result = mlProcessor.doExecute(records); - - verify(mlBatchJobCreator, times(1)).addProcessedBatchRecordsToResults(new ArrayList<>()); - verify(mlBatchJobCreator, times(1)).createMLBatchJob(records, new ArrayList<>()); - verify(successCounter, times(1)).increment(); - } - - @Test - void testDoExecute_WithNoRecords() { - Collection> result = mlProcessor.doExecute(Collections.emptyList()); - - verifyNoInteractions(successCounter, failureCounter); - assertTrue(result.isEmpty()); + private MLProcessor buildProcessor(final ActionType actionType) throws Exception { + setupCommonMocks(); + when(mlProcessorConfig.getWhenCondition()).thenReturn("condition"); + lenient().when(expressionEvaluator.evaluateConditional(eq("condition"), any())).thenReturn(true); + lenient().when(mlProcessorConfig.getActionType()).thenReturn(actionType); + + if (ActionType.PREDICT.equals(actionType)) { + lenient().when(mlProcessorConfig.getModelId()).thenReturn("amazon.titan-embed-text-v2:0"); + lenient().when(mlProcessorConfig.getTagsOnFailure()).thenReturn(Collections.emptyList()); + } else { + lenient().when(mlProcessorConfig.getServiceName()).thenReturn(ServiceName.SAGEMAKER); + } + return new MLProcessor(mlProcessorConfig, pluginMetrics, pluginFactory, pluginSetting, awsCredentialsSupplier, expressionEvaluator); } - @Test - void testDoExecute_WithConditionNotMet() { - // Mock event and record - Event event = mock(Event.class); - Record record = new Record<>(event); - List> records = Collections.singletonList(record); - - // Mock the expression evaluator - when(expressionEvaluator.evaluateConditional(eq("condition"), 
any())).thenReturn(false); - - // Ensure no interaction with mlBatchJobCreator - Collection> result = mlProcessor.doExecute(records); - - // Verify no interactions with mlBatchJobCreator, successCounter, or failureCounter - verify(mlBatchJobCreator, times(1)).addProcessedBatchRecordsToResults(records); - verify(mlBatchJobCreator, times(1)).checkAndProcessBatch(); - verifyNoInteractions(successCounter, failureCounter); - - // Assert that the input records are returned as output - assertEquals(records, result); + private void injectExecutor(final MLProcessor processor, final MLActionExecutor executor) throws Exception { + final Field field = MLProcessor.class.getDeclaredField("actionExecutor"); + field.setAccessible(true); + field.set(processor, executor); } - @Test - void testDoExecute_WithException() throws Exception { - Event event = mock(Event.class); - Record record = new Record<>(event); - List> records = Collections.singletonList(record); - - doThrow(new RuntimeException("Test Exception")).when(mlBatchJobCreator).createMLBatchJob(records, new ArrayList<>()); - - Collection> result = mlProcessor.doExecute(records); - - verify(failureCounter, times(1)).increment(); + @Nested + class BatchPredictMode { + + private MLProcessor mlProcessor; + private BatchActionExecutor batchActionExecutor; + + @BeforeEach + void setUp() throws Exception { + batchActionExecutor = mock(BatchActionExecutor.class); + mlProcessor = buildProcessor(ActionType.BATCH_PREDICT); + injectExecutor(mlProcessor, batchActionExecutor); + } + + @Test + void testDoExecute_WithValidRecords() { + final Event event = mock(Event.class); + final Record record = new Record<>(event); + final List> records = Collections.singletonList(record); + + mlProcessor.doExecute(records); + + verify(batchActionExecutor, times(1)).prepareExecution(any()); + verify(batchActionExecutor, times(1)).execute(eq(records), any()); + verify(successCounter, times(1)).increment(); + } + + @Test + void 
testDoExecute_WithNoRecords() { + final Collection> result = mlProcessor.doExecute(Collections.emptyList()); + + verify(batchActionExecutor, times(1)).prepareExecution(any()); + verify(batchActionExecutor, never()).execute(any(), any()); + verifyNoInteractions(successCounter, failureCounter); + assertTrue(result.isEmpty()); + } + + @Test + void testDoExecute_WithConditionNotMet() { + final Event event = mock(Event.class); + final Record record = new Record<>(event); + final List> records = Collections.singletonList(record); + when(expressionEvaluator.evaluateConditional(eq("condition"), any())).thenReturn(false); + + final Collection> result = mlProcessor.doExecute(records); + + verify(batchActionExecutor, times(1)).prepareExecution(any()); + verify(batchActionExecutor, never()).execute(any(), any()); + verifyNoInteractions(successCounter, failureCounter); + assertEquals(records, result); + } + + @Test + void testDoExecute_WithException() { + final Event event = mock(Event.class); + final Record record = new Record<>(event); + final List> records = Collections.singletonList(record); + doThrow(new RuntimeException("Test Exception")).when(batchActionExecutor).execute(any(), any()); + + mlProcessor.doExecute(records); + + verify(failureCounter, times(1)).increment(); + } + + @Test + void testShutdownMethods() { + when(batchActionExecutor.isReadyForShutdown()).thenReturn(true); + + assertTrue(mlProcessor.isReadyForShutdown()); + mlProcessor.prepareForShutdown(); + mlProcessor.shutdown(); + + verify(batchActionExecutor).isReadyForShutdown(); + verify(batchActionExecutor).prepareForShutdown(); + verify(batchActionExecutor).shutdown(); + } } - @Test - void testShutdownMethods() { - when(mlBatchJobCreator.isReadyForShutdown()).thenReturn(true); - - assertTrue(mlProcessor.isReadyForShutdown()); - mlProcessor.prepareForShutdown(); - mlProcessor.shutdown(); - - // Verify that these methods were called on the batch job creator - verify(mlBatchJobCreator).isReadyForShutdown(); 
- verify(mlBatchJobCreator).prepareForShutdown(); - verify(mlBatchJobCreator).shutdown(); - + @Nested + class PredictMode { + + private MLProcessor mlProcessor; + private PredictActionExecutor predictActionExecutor; + + @BeforeEach + void setUp() throws Exception { + predictActionExecutor = mock(PredictActionExecutor.class); + mlProcessor = buildProcessor(ActionType.PREDICT); + injectExecutor(mlProcessor, predictActionExecutor); + } + + @Test + @SuppressWarnings("unchecked") + void testDoExecute_WithValidRecords_delegatesToExecutor() { + final Event event = mock(Event.class); + final Record record = new Record<>(event); + final List> records = Collections.singletonList(record); + doAnswer(invocation -> { + final List> result = invocation.getArgument(1); + result.addAll(records); + return result; + }).when(predictActionExecutor).execute(eq(records), any()); + + final Collection> result = mlProcessor.doExecute(records); + + verify(predictActionExecutor, times(1)).execute(eq(records), any()); + verify(successCounter, times(1)).increment(); + assertEquals(records, new java.util.ArrayList<>(result)); + } + + @Test + void testDoExecute_WithNoRecords_returnsEmpty() { + final Collection> result = mlProcessor.doExecute(Collections.emptyList()); + + verify(predictActionExecutor, never()).execute(any(), any()); + verifyNoInteractions(successCounter, failureCounter); + assertTrue(result.isEmpty()); + } + + @Test + void testDoExecute_WithConditionNotMet_skipsExecutor() { + final Event event = mock(Event.class); + final Record record = new Record<>(event); + final List> records = Collections.singletonList(record); + when(expressionEvaluator.evaluateConditional(eq("condition"), any())).thenReturn(false); + + final Collection> result = mlProcessor.doExecute(records); + + verify(predictActionExecutor, never()).execute(any(), any()); + verifyNoInteractions(successCounter, failureCounter); + assertEquals(records, result); + } + + @Test + void 
testDoExecute_WithConditionMet_passesFilteredRecordsToExecutor() { + final Event matchedEvent = mock(Event.class); + final Event skippedEvent = mock(Event.class); + final Record matchedRecord = new Record<>(matchedEvent); + final Record skippedRecord = new Record<>(skippedEvent); + final List> records = List.of(matchedRecord, skippedRecord); + + when(expressionEvaluator.evaluateConditional(eq("condition"), eq(matchedEvent))).thenReturn(true); + when(expressionEvaluator.evaluateConditional(eq("condition"), eq(skippedEvent))).thenReturn(false); + + final List> filteredRecords = Collections.singletonList(matchedRecord); + doAnswer(invocation -> { + final List> result = invocation.getArgument(1); + result.addAll(filteredRecords); + return result; + }).when(predictActionExecutor).execute(eq(filteredRecords), any()); + + final Collection> result = mlProcessor.doExecute(records); + + verify(predictActionExecutor, times(1)).execute(eq(filteredRecords), any()); + assertTrue(result.contains(matchedRecord)); + assertTrue(result.contains(skippedRecord)); + assertEquals(2, result.size()); + } + + @Test + void testDoExecute_ExecutorThrows_incrementsFailureCounter() { + final Event event = mock(Event.class); + final Record record = new Record<>(event); + final List> records = Collections.singletonList(record); + doThrow(new RuntimeException("predict failed")).when(predictActionExecutor).execute(any(), any()); + + mlProcessor.doExecute(records); + + verify(failureCounter, times(1)).increment(); + verifyNoInteractions(successCounter); + } + + @Test + void testShutdownMethods_areNoOps() { + when(predictActionExecutor.isReadyForShutdown()).thenReturn(true); + + assertTrue(mlProcessor.isReadyForShutdown()); + mlProcessor.prepareForShutdown(); + mlProcessor.shutdown(); + + verify(predictActionExecutor).isReadyForShutdown(); + verify(predictActionExecutor).prepareForShutdown(); + verify(predictActionExecutor).shutdown(); + } } } diff --git 
a/data-prepper-plugins/ml-inference-processor/src/test/java/org/opensearch/dataprepper/plugins/ml_inference/processor/common/ModelSyncInferenceExecutorTest.java b/data-prepper-plugins/ml-inference-processor/src/test/java/org/opensearch/dataprepper/plugins/ml_inference/processor/common/ModelSyncInferenceExecutorTest.java new file mode 100644 index 0000000000..18079f510c --- /dev/null +++ b/data-prepper-plugins/ml-inference-processor/src/test/java/org/opensearch/dataprepper/plugins/ml_inference/processor/common/ModelSyncInferenceExecutorTest.java @@ -0,0 +1,261 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + */ + +package org.opensearch.dataprepper.plugins.ml_inference.processor.common; + +import io.micrometer.core.instrument.Counter; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.mockito.junit.jupiter.MockitoSettings; +import org.mockito.quality.Strictness; +import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier; +import org.opensearch.dataprepper.metrics.PluginMetrics; +import org.opensearch.dataprepper.model.event.Event; +import org.opensearch.dataprepper.model.record.Record; +import org.opensearch.dataprepper.plugins.ml_inference.processor.MLProcessorConfig; +import org.opensearch.dataprepper.plugins.ml_inference.processor.configuration.AwsAuthenticationOptions; +import org.opensearch.dataprepper.plugins.ml_inference.processor.connector.ConnectorActionType; +import org.opensearch.dataprepper.plugins.ml_inference.processor.connector.RemoteConnectorExecutor; +import org.opensearch.dataprepper.plugins.ml_inference.processor.exception.MLBatchJobException; +import 
software.amazon.awssdk.regions.Region; + +import java.lang.reflect.Field; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.hasSize; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.dataprepper.plugins.ml_inference.processor.common.ModelSyncInferenceExecutor.NUMBER_OF_SYNC_INFERENCE_RECORDS_FAILED; +import static org.opensearch.dataprepper.plugins.ml_inference.processor.common.ModelSyncInferenceExecutor.NUMBER_OF_SYNC_INFERENCE_RECORDS_SUCCESS; + +@ExtendWith(MockitoExtension.class) +@MockitoSettings(strictness = Strictness.LENIENT) +class ModelSyncInferenceExecutorTest { + + @Mock private MLProcessorConfig config; + @Mock private AwsCredentialsSupplier awsCredentialsSupplier; + @Mock private AwsAuthenticationOptions awsAuthOptions; + @Mock private RemoteConnectorExecutor connectorExecutor; + @Mock private Event event; + @Mock private PluginMetrics pluginMetrics; + @Mock private Counter successCounter; + @Mock private Counter failureCounter; + + private ModelSyncInferenceExecutor predictProcessor; + + @BeforeEach + void setUp() throws Exception { + when(config.getModelId()).thenReturn("amazon.titan-embed-text-v2:0"); + when(config.getAwsAuthenticationOptions()).thenReturn(awsAuthOptions); + when(awsAuthOptions.getAwsRegion()).thenReturn(Region.US_EAST_1); + when(config.getTagsOnFailure()).thenReturn(Collections.emptyList()); + when(config.getInputMap()).thenReturn(List.of(Map.of("inputText", "passage_text"))); + when(config.getOutputMap()).thenReturn(List.of(Map.of("passage_embedding", "embedding"))); + 
when(pluginMetrics.counter(NUMBER_OF_SYNC_INFERENCE_RECORDS_SUCCESS)).thenReturn(successCounter); + when(pluginMetrics.counter(NUMBER_OF_SYNC_INFERENCE_RECORDS_FAILED)).thenReturn(failureCounter); + + predictProcessor = new ModelSyncInferenceExecutor(config, awsCredentialsSupplier, pluginMetrics); + injectConnectorExecutor(predictProcessor, connectorExecutor); + } + + private ModelSyncInferenceExecutor buildAndInject() throws Exception { + final ModelSyncInferenceExecutor executor = new ModelSyncInferenceExecutor(config, awsCredentialsSupplier, pluginMetrics); + injectConnectorExecutor(executor, connectorExecutor); + return executor; + } + + @Test + void execute_success_writesEmbeddingToEvent() { + when(event.get("passage_text", String.class)).thenReturn("hello world"); + when(connectorExecutor.executeActionAndGetResponse(eq(ConnectorActionType.PREDICT), any())) + .thenReturn("{\"embedding\":[0.1,0.2,0.3]}"); + + final Collection> results = predictProcessor.execute(List.of(new Record<>(event))); + + assertThat(results, hasSize(1)); + verify(event).put(eq("passage_embedding"), any()); + verify(successCounter, times(1)).increment(); + verify(failureCounter, never()).increment(); + } + + @Test + void execute_multipleInputMappings_invokesModelForEachMapping() throws Exception { + when(config.getInputMap()).thenReturn(List.of( + Map.of("inputText", "field_a"), + Map.of("inputText", "field_b") + )); + when(config.getOutputMap()).thenReturn(List.of( + Map.of("embedding_a", "/embedding"), + Map.of("embedding_b", "/embedding") + )); + when(event.get("field_a", String.class)).thenReturn("text a"); + when(event.get("field_b", String.class)).thenReturn("text b"); + when(connectorExecutor.executeActionAndGetResponse(eq(ConnectorActionType.PREDICT), any())) + .thenReturn("{\"embedding\":[0.1,0.2]}"); + + predictProcessor = buildAndInject(); + + final Collection> results = predictProcessor.execute(List.of(new Record<>(event))); + + assertThat(results, hasSize(1)); + 
verify(event).put(eq("embedding_a"), any()); + verify(event).put(eq("embedding_b"), any()); + verify(successCounter, times(1)).increment(); + } + + @Test + void execute_missingInputField_incrementsFailureCounterAndReturnsRecord() throws Exception { + when(config.getTagsOnFailure()).thenReturn(List.of("_ml_inference_failure")); + when(event.get("passage_text", String.class)).thenReturn(null); + when(event.getMetadata()).thenReturn(null); + + predictProcessor = buildAndInject(); + + final Collection> results = predictProcessor.execute(List.of(new Record<>(event))); + + assertThat(results, hasSize(1)); + verify(connectorExecutor, never()).executeActionAndGetResponse(any(), any()); + verify(failureCounter, times(1)).increment(); + verify(successCounter, never()).increment(); + } + + @Test + void execute_modelReturnsNestedOutputPath_extractsCorrectly() throws Exception { + when(config.getOutputMap()).thenReturn(List.of(Map.of("passage_embedding", "modelOutput/embedding"))); + when(event.get("passage_text", String.class)).thenReturn("hello"); + when(connectorExecutor.executeActionAndGetResponse(eq(ConnectorActionType.PREDICT), any())) + .thenReturn("{\"modelOutput\":{\"embedding\":[0.5,0.6]}}"); + + predictProcessor = buildAndInject(); + + final Collection> results = predictProcessor.execute(List.of(new Record<>(event))); + + assertThat(results, hasSize(1)); + verify(event).put(eq("passage_embedding"), any()); + verify(successCounter, times(1)).increment(); + } + + @Test + void execute_remoteServiceThrows_incrementsFailureCounterAndReturnsRecord() throws Exception { + when(config.getTagsOnFailure()).thenReturn(List.of("_ml_failure")); + when(event.get("passage_text", String.class)).thenReturn("hello"); + when(event.getMetadata()).thenReturn(null); + when(connectorExecutor.executeActionAndGetResponse(any(), any())) + .thenThrow(new MLBatchJobException(500, "server error")); + + predictProcessor = buildAndInject(); + + final Collection> results = 
predictProcessor.execute(List.of(new Record<>(event))); + + assertThat(results, hasSize(1)); + verify(failureCounter, times(1)).increment(); + verify(successCounter, never()).increment(); + } + + @Test + void execute_emptyRecords_returnsEmpty() { + final Collection> results = predictProcessor.execute(Collections.emptyList()); + + assertThat(results, hasSize(0)); + verify(successCounter, never()).increment(); + verify(failureCounter, never()).increment(); + } + + @Test + void execute_outputPathNotFound_incrementsFailureCounterAndReturnsRecord() { + when(event.get("passage_text", String.class)).thenReturn("hello"); + when(event.getMetadata()).thenReturn(null); + when(connectorExecutor.executeActionAndGetResponse(any(), any())) + .thenReturn("{\"other_field\":\"value\"}"); + + final Collection> results = predictProcessor.execute(List.of(new Record<>(event))); + + assertThat(results, hasSize(1)); + verify(failureCounter, times(1)).increment(); + verify(successCounter, never()).increment(); + } + + @Test + void execute_noOutputMap_doesNotWriteToEvent() throws Exception { + when(config.getOutputMap()).thenReturn(Collections.emptyList()); + when(event.get("passage_text", String.class)).thenReturn("hello"); + when(connectorExecutor.executeActionAndGetResponse(any(), any())) + .thenReturn("{\"embedding\":[0.1]}"); + + predictProcessor = buildAndInject(); + + final Collection> results = predictProcessor.execute(List.of(new Record<>(event))); + + assertThat(results, hasSize(1)); + verify(event, never()).put(any(String.class), any()); + verify(successCounter, times(1)).increment(); + } + + @Test + void execute_multipleRecords_countsEachIndependently() throws Exception { + final Event successEvent = event; + final Event failureEvent = org.mockito.Mockito.mock(Event.class); + when(successEvent.get("passage_text", String.class)).thenReturn("hello"); + when(failureEvent.get("passage_text", String.class)).thenReturn(null); + when(failureEvent.getMetadata()).thenReturn(null); + 
when(connectorExecutor.executeActionAndGetResponse(any(), any())) + .thenReturn("{\"embedding\":[0.1]}"); + + final Collection> results = predictProcessor.execute( + List.of(new Record<>(successEvent), new Record<>(failureEvent))); + + assertThat(results, hasSize(2)); + verify(successCounter, times(1)).increment(); + verify(failureCounter, times(1)).increment(); + } + + @Test + void execute_passesRegionInParameters() { + when(event.get("passage_text", String.class)).thenReturn("hello"); + when(connectorExecutor.executeActionAndGetResponse(eq(ConnectorActionType.PREDICT), any())) + .thenReturn("{\"embedding\":[0.1]}"); + + predictProcessor.execute(List.of(new Record<>(event))); + + verify(connectorExecutor).executeActionAndGetResponse( + eq(ConnectorActionType.PREDICT), + argThatContainsRegion("us-east-1") + ); + } + + // --- helpers --- + + private Map argThatContainsRegion(final String region) { + return org.mockito.ArgumentMatchers.argThat( + params -> params != null && region.equals(params.get("region"))); + } + + private void injectConnectorExecutor(final ModelSyncInferenceExecutor processor, + final RemoteConnectorExecutor executor) { + try { + final Field field = ModelSyncInferenceExecutor.class.getDeclaredField("connectorExecutor"); + field.setAccessible(true); + field.set(processor, executor); + } catch (final Exception e) { + throw new RuntimeException(e); + } + } +}