Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,14 @@
import org.opensearch.dataprepper.model.processor.AbstractProcessor;
import org.opensearch.dataprepper.model.processor.Processor;
import org.opensearch.dataprepper.model.record.Record;
import org.opensearch.dataprepper.plugins.ml_inference.processor.common.MLBatchJobCreator;
import org.opensearch.dataprepper.plugins.ml_inference.processor.common.BatchActionExecutor;
import org.opensearch.dataprepper.plugins.ml_inference.processor.common.MLActionExecutor;
import org.opensearch.dataprepper.plugins.ml_inference.processor.common.MLBatchJobCreatorFactory;
import org.opensearch.dataprepper.plugins.ml_inference.processor.common.ModelSyncInferenceExecutor;
import org.opensearch.dataprepper.plugins.ml_inference.processor.common.PredictActionExecutor;
import org.opensearch.dataprepper.plugins.ml_inference.processor.configuration.ActionType;
import org.opensearch.dataprepper.plugins.ml_inference.processor.configuration.ServiceName;
import org.opensearch.dataprepper.plugins.ml_inference.processor.dlq.DlqPushHandler;
import org.opensearch.dataprepper.plugins.ml_inference.processor.exception.MLBatchJobException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand All @@ -40,7 +43,7 @@ public class MLProcessor extends AbstractProcessor<Record<Event>, Record<Event>>
public static final String NUMBER_OF_ML_PROCESSOR_FAILED = "BatchJobRequestsFailed";

private final String whenCondition;
private final MLBatchJobCreator mlBatchJobCreator;
private final MLActionExecutor actionExecutor;
private final Counter numberOfMLProcessorSuccessCounter;
private final Counter numberOfMLProcessorFailedCounter;
private final ExpressionEvaluator expressionEvaluator;
Expand All @@ -52,87 +55,89 @@ public class MLProcessor extends AbstractProcessor<Record<Event>, Record<Event>>
public MLProcessor(final MLProcessorConfig mlProcessorConfig, final PluginMetrics pluginMetrics, final PluginFactory pluginFactory, final PluginSetting pluginSetting, final AwsCredentialsSupplier awsCredentialsSupplier, final ExpressionEvaluator expressionEvaluator) {
super(pluginMetrics);
this.whenCondition = mlProcessorConfig.getWhenCondition();
ServiceName serviceName = mlProcessorConfig.getServiceName();
this.numberOfMLProcessorSuccessCounter = pluginMetrics.counter(
NUMBER_OF_ML_PROCESSOR_SUCCESS);
this.numberOfMLProcessorFailedCounter = pluginMetrics.counter(
NUMBER_OF_ML_PROCESSOR_FAILED);
this.numberOfMLProcessorSuccessCounter = pluginMetrics.counter(NUMBER_OF_ML_PROCESSOR_SUCCESS);
this.numberOfMLProcessorFailedCounter = pluginMetrics.counter(NUMBER_OF_ML_PROCESSOR_FAILED);
this.expressionEvaluator = expressionEvaluator;
this.pluginSetting = pluginSetting;

if (mlProcessorConfig.getDlqPluginSetting() != null) {
this.dlqPushHandler = new DlqPushHandler(pluginFactory, pluginSetting, mlProcessorConfig.getDlq(), mlProcessorConfig.getAwsAuthenticationOptions());
}

// Use factory to get the appropriate job creator
mlBatchJobCreator = MLBatchJobCreatorFactory.getJobCreator(serviceName, mlProcessorConfig, awsCredentialsSupplier, pluginMetrics, dlqPushHandler);
if (ActionType.PREDICT.equals(mlProcessorConfig.getActionType())) {
this.actionExecutor = new PredictActionExecutor(new ModelSyncInferenceExecutor(mlProcessorConfig, awsCredentialsSupplier, pluginMetrics));
} else {
final ServiceName serviceName = mlProcessorConfig.getServiceName();
this.actionExecutor = new BatchActionExecutor(MLBatchJobCreatorFactory.getJobCreator(serviceName, mlProcessorConfig, awsCredentialsSupplier, pluginMetrics, dlqPushHandler));
}
}

@Override
public Collection<Record<Event>> doExecute(Collection<Record<Event>> records) {
List<Record<Event>> resultRecords = new ArrayList<>();
// check and process any existing batch
mlBatchJobCreator.checkAndProcessBatch();
// Add processed records to results
mlBatchJobCreator.addProcessedBatchRecordsToResults(resultRecords);
// reads from input - S3 input
if (records.size() == 0)
public Collection<Record<Event>> doExecute(final Collection<Record<Event>> records) {
final List<Record<Event>> resultRecords = new ArrayList<>();

actionExecutor.prepareExecution(resultRecords);

if (records.isEmpty()) {
return resultRecords;
}

// Process new records
List<Record<Event>> recordsToMlCommons = records.stream()
.filter(record -> {
try {
boolean meetCondition = whenCondition == null || expressionEvaluator.evaluateConditional(whenCondition, record.getData());
if (!meetCondition) {
resultRecords.add(record);
}
return meetCondition; // Include in recordsToMlCommons if true
} catch (ExpressionParsingException e) {
LOG.warn("Expression parsing failed for record: {}. Error: {}", record, e.getMessage());
resultRecords.add(record);
return false; // Skip the record on parsing failure
} catch (ClassCastException e) {
LOG.warn("Unexpected return type when evaluating condition for record: {}. Error: {}", record, e.getMessage());
resultRecords.add(record);
return false; // Skip the record on type mismatch
} catch (Exception e) {
LOG.error("Failed to evaluate conditional expression for record: {}", record, e);
resultRecords.add(record);
return false; // Skip the record if evaluation fails
}
})
.collect(Collectors.toList());

if (recordsToMlCommons.isEmpty()) {
final List<Record<Event>> filteredRecords = filterByCondition(records, resultRecords);
if (filteredRecords.isEmpty()) {
return resultRecords;
}

try {
mlBatchJobCreator.createMLBatchJob(recordsToMlCommons, resultRecords);
actionExecutor.execute(filteredRecords, resultRecords);
numberOfMLProcessorSuccessCounter.increment();
} catch (MLBatchJobException e) {
LOG.error(NOISY, "ML Batch job creation failed: {}", e.getMessage());
numberOfMLProcessorFailedCounter.increment();
} catch (Exception e) {
LOG.error(NOISY, "Unexpected Error occurred while creating the batch job: {}", e.getMessage(), e);
} catch (final Exception e) {
LOG.error(NOISY, "Unexpected error during ML processing: {}", e.getMessage(), e);
numberOfMLProcessorFailedCounter.increment();
}

return resultRecords;
}

/**
 * Partitions the incoming records by the configured {@code whenCondition}.
 *
 * <p>Records that satisfy the condition (or all records when no condition is configured)
 * are returned for ML processing. Every other record, including those whose condition
 * could not be evaluated, is appended to {@code resultRecords} so it is passed through
 * unchanged. Records are visited in the iteration order of {@code records}.
 *
 * <p>An explicit loop is used instead of a stream so that appending to
 * {@code resultRecords} is not hidden inside a stream predicate.
 *
 * @param records       the records handed to the processor
 * @param resultRecords output list that receives every record not selected for ML processing
 * @return a new mutable list holding the records that met the condition
 */
private List<Record<Event>> filterByCondition(final Collection<Record<Event>> records,
                                              final List<Record<Event>> resultRecords) {
    final List<Record<Event>> matchingRecords = new ArrayList<>();
    for (final Record<Event> record : records) {
        if (meetsCondition(record)) {
            matchingRecords.add(record);
        } else {
            resultRecords.add(record);
        }
    }
    return matchingRecords;
}

/**
 * Evaluates {@code whenCondition} against a single record.
 *
 * @param record the record to test
 * @return {@code true} when no condition is configured or the condition evaluates to true;
 *         {@code false} when it evaluates to false or evaluation throws
 */
private boolean meetsCondition(final Record<Event> record) {
    if (whenCondition == null) {
        return true;
    }
    try {
        return expressionEvaluator.evaluateConditional(whenCondition, record.getData());
    } catch (ExpressionParsingException e) {
        LOG.warn("Expression parsing failed for record: {}. Error: {}", record, e.getMessage());
        return false;
    } catch (ClassCastException e) {
        LOG.warn("Unexpected return type when evaluating condition for record: {}. Error: {}", record, e.getMessage());
        return false;
    } catch (Exception e) {
        // Evaluation failures must not abort the batch; the caller passes the record through.
        LOG.error("Failed to evaluate conditional expression for record: {}", record, e);
        return false;
    }
}

@Override
public void prepareForShutdown() {
mlBatchJobCreator.prepareForShutdown();
actionExecutor.prepareForShutdown();
}

@Override
public boolean isReadyForShutdown() {
return mlBatchJobCreator.isReadyForShutdown();
return actionExecutor.isReadyForShutdown();
}

@Override
public void shutdown() {
mlBatchJobCreator.shutdown();
actionExecutor.shutdown();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,17 @@ public class MLProcessorConfig {
@DurationMax(seconds = 300)
private Duration retryInterval = Duration.ofSeconds(DEFAULT_RETRY_INTERVAL_SECONDS);

@JsonPropertyDescription("Maps document fields to model input fields for the predict action type. "
+ "Each element is a map of <model_input_field>: <document_field>.")
@JsonProperty("input_map")
private List<Map<String, String>> inputMap;

@JsonPropertyDescription("Maps model output fields to new document fields for the predict action type. "
+ "Each element is a map of <new_document_field>: <model_output_field>.")
@JsonProperty("output_map")
private List<Map<String, String>> outputMap;


@JsonProperty("dlq")
private PluginModel dlq;

Expand All @@ -126,6 +137,11 @@ public String getWhenCondition() {

public List<String> getTagsOnFailure() { return tagsOnFailure; }

/**
 * Returns the value bound to the {@code input_map} configuration property.
 *
 * @return the configured input mappings, exactly as held by this config object
 */
public List<Map<String, String>> getInputMap() {
    return this.inputMap;
}

/**
 * Returns the value bound to the {@code output_map} configuration property.
 *
 * @return the configured output mappings, exactly as held by this config object
 */
public List<Map<String, String>> getOutputMap() {
    return this.outputMap;
}


public PluginModel getDlq() {
return dlq;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
*/

package org.opensearch.dataprepper.plugins.ml_inference.processor.common;

import org.opensearch.dataprepper.model.event.Event;
import org.opensearch.dataprepper.model.record.Record;
import org.opensearch.dataprepper.plugins.ml_inference.processor.exception.MLBatchJobException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Collection;
import java.util.List;

import static org.opensearch.dataprepper.logging.DataPrepperMarkers.NOISY;

public class BatchActionExecutor implements MLActionExecutor {

private static final Logger LOG = LoggerFactory.getLogger(BatchActionExecutor.class);

private final MLBatchJobCreator mlBatchJobCreator;

public BatchActionExecutor(final MLBatchJobCreator mlBatchJobCreator) {
this.mlBatchJobCreator = mlBatchJobCreator;
}

@Override
public void prepareExecution(final List<Record<Event>> resultRecords) {
mlBatchJobCreator.checkAndProcessBatch();
mlBatchJobCreator.addProcessedBatchRecordsToResults(resultRecords);
}

@Override
public Collection<Record<Event>> execute(final List<Record<Event>> filteredRecords,
final List<Record<Event>> resultRecords) {
try {
mlBatchJobCreator.createMLBatchJob(filteredRecords, resultRecords);
} catch (final MLBatchJobException e) {
LOG.error(NOISY, "ML Batch job creation failed: {}", e.getMessage());
throw e;
}
return resultRecords;
}

@Override
public void prepareForShutdown() {
mlBatchJobCreator.prepareForShutdown();
}

@Override
public boolean isReadyForShutdown() {
return mlBatchJobCreator.isReadyForShutdown();
}

@Override
public void shutdown() {
mlBatchJobCreator.shutdown();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
*/

package org.opensearch.dataprepper.plugins.ml_inference.processor.common;

import org.opensearch.dataprepper.model.event.Event;
import org.opensearch.dataprepper.model.record.Record;

import java.util.Collection;
import java.util.List;

/**
 * Strategy interface for the action an ML processor performs on its records.
 *
 * <p>Only {@link #execute(List, List)} must be implemented; every other method has a
 * default implementation that does nothing (or reports readiness).
 */
public interface MLActionExecutor {

    /**
     * Optional hook that receives the result list. The default implementation does nothing.
     *
     * @param resultRecords the list that collects output records
     */
    default void prepareExecution(List<Record<Event>> resultRecords) {
        // no-op by default
    }

    /**
     * Performs the action on the given records.
     *
     * @param filteredRecords the records selected for processing
     * @param resultRecords   the list that collects output records
     * @return the collection of records produced by the action
     */
    Collection<Record<Event>> execute(List<Record<Event>> filteredRecords, List<Record<Event>> resultRecords);

    /** Optional shutdown-preparation hook. The default implementation does nothing. */
    default void prepareForShutdown() {
        // no-op by default
    }

    /**
     * Reports whether the executor can be shut down.
     *
     * @return {@code true} by default
     */
    default boolean isReadyForShutdown() {
        return true;
    }

    /** Optional shutdown hook. The default implementation does nothing. */
    default void shutdown() {
        // no-op by default
    }
}
Loading