Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ public enum MappingType {
ScaledFloat("scaled_float", ExprCoreType.DOUBLE),
Double("double", ExprCoreType.DOUBLE),
Boolean("boolean", ExprCoreType.BOOLEAN),
Alias("alias", ExprCoreType.UNKNOWN);
Alias("alias", ExprCoreType.UNKNOWN),
KnnVector("knn_vector", ExprCoreType.ARRAY);
// TODO: ranges, geo shape, point, shape

private final String name;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,13 @@

import static org.opensearch.sql.utils.SystemIndexUtils.isSystemIndex;

import java.util.Collection;
import java.util.List;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import org.opensearch.sql.DataSourceSchemaName;
import org.opensearch.sql.common.setting.Settings;
import org.opensearch.sql.expression.function.FunctionResolver;
import org.opensearch.sql.opensearch.client.OpenSearchClient;
import org.opensearch.sql.opensearch.storage.system.OpenSearchSystemIndex;
import org.opensearch.sql.storage.StorageEngine;
Expand All @@ -25,6 +28,11 @@ public class OpenSearchStorageEngine implements StorageEngine {

@Getter private final Settings settings;

/**
 * Lists the table functions contributed by this storage engine.
 *
 * @return a single-element collection holding a {@code VectorSearchTableFunctionResolver} built
 *     from this engine's client and settings
 */
@Override
public Collection<FunctionResolver> getFunctions() {
  FunctionResolver vectorSearchResolver = new VectorSearchTableFunctionResolver(client, settings);
  return List.of(vectorSearchResolver);
}

@Override
public Table getTable(DataSourceSchemaName dataSourceSchemaName, String name) {
if (isSystemIndex(name)) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.opensearch.storage;

import java.util.Map;
import java.util.function.Function;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.WrapperQueryBuilder;
import org.opensearch.sql.common.setting.Settings;
import org.opensearch.sql.opensearch.client.OpenSearchClient;
import org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder;
import org.opensearch.sql.opensearch.storage.scan.OpenSearchIndexScan;
import org.opensearch.sql.opensearch.storage.scan.OpenSearchIndexScanBuilder;
import org.opensearch.sql.opensearch.storage.scan.VectorSearchQueryBuilder;
import org.opensearch.sql.storage.read.TableScanBuilder;

/**
* Vector-search-aware OpenSearch index. Seeds the scan with a knn query and enables score tracking.
*/
public class VectorSearchIndex extends OpenSearchIndex {

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of having separate VectorSearchQueryBuilder and VectorSearchIndex classes, I wonder whether we could abstract the vector search translation logic behind the LuceneQuery interface?

@mengweieric mengweieric Apr 8, 2026

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. I split it this way because the two classes currently own different seams: VectorSearchIndex seeds the scan/request state (knn query + track_scores + size=k default) while VectorSearchQueryBuilder handles score-preserving filter pushdown. I agree a shared abstraction could make sense if we see a second consumer or want a cleaner bridge for the unified SQL path, but I'd prefer not to turn PR-2 into a broader refactor unless there's an existing LuceneQuery contract that fits both responsibilities cleanly.


private final String field;
private final float[] vector;
private final Map<String, String> options;

/**
 * Creates a vector-search view over an OpenSearch index.
 *
 * <p>{@code client}, {@code settings} and {@code indexName} are forwarded to the superclass
 * constructor. The remaining arguments are stored by reference (no defensive copy is made), so
 * later changes to the array or map by the caller are visible to this instance.
 *
 * @param client client passed through to the parent index
 * @param settings settings passed through to the parent index
 * @param indexName name of the index passed through to the parent index
 * @param field name of the field used as the key of the knn query object
 * @param vector query vector serialized into the knn query
 * @param options knn options serialized into the knn query; when it contains {@code "k"}, that
 *     value is also used as the default request size by {@link #createScanBuilder()}
 */
public VectorSearchIndex(
OpenSearchClient client,
Settings settings,
String indexName,
String field,
float[] vector,
Map<String, String> options) {
super(client, settings, indexName);
this.field = field;
this.vector = vector;
this.options = options;
}

/**
 * Creates the scan builder for a vector search.
 *
 * <p>The request builder is seeded with the knn query through a {@link VectorSearchQueryBuilder},
 * score tracking is switched on, and, when a {@code k} option is present, the request total is
 * limited to {@code k} so a query without LIMIT returns {@code k} results.
 *
 * @return a scan builder that wraps the vector-search query builder and a scan factory
 */
@Override
public TableScanBuilder createScanBuilder() {
  final TimeValue keepAlive = getSettings().getSettingValue(Settings.Key.SQL_CURSOR_KEEP_ALIVE);
  var builder = createRequestBuilder();

  // The knn query stays in must (scoring) context; WHERE filters are placed in filter
  // (non-scoring) context by VectorSearchQueryBuilder.
  var vectorQueryBuilder = new VectorSearchQueryBuilder(builder, buildKnnQuery());
  builder.pushDownTrackedScore(true);

  // Top-k mode: without this default, a query that has no LIMIT would fall into the generic
  // large-scan path. A LIMIT pushdown will further reduce the size if present.
  if (options.containsKey("k")) {
    int topK = Integer.parseInt(options.get("k"));
    builder.pushDownLimitToRequestTotal(topK, 0);
  }

  Function<OpenSearchRequestBuilder, OpenSearchIndexScan> scanFactory =
      request ->
          new OpenSearchIndexScan(
              getClient(),
              request.getMaxResponseSize(),
              request.build(getIndexName(), keepAlive, getClient(), getFieldTypes().isEmpty()));
  return new OpenSearchIndexScanBuilder(vectorQueryBuilder, scanFactory);
}

/**
 * JSON number grammar (RFC 8259): optional minus, integer part without leading zeros, optional
 * fraction, optional exponent. The output of {@code Integer.toString} and of
 * {@code Double.toString} for finite values matches this pattern.
 */
private static final java.util.regex.Pattern JSON_NUMBER =
    java.util.regex.Pattern.compile("-?(?:0|[1-9][0-9]*)(?:\\.[0-9]+)?(?:[eE][+-]?[0-9]+)?");

/**
 * Builds the knn query that seeds the scan.
 *
 * <p>The query is assembled as a JSON string of the form {@code
 * {"knn":{"<field>":{"vector":[...],"<option>":<value>,...}}}} and wrapped in a {@link
 * WrapperQueryBuilder}. Options are written in the iteration order of the options map.
 *
 * <p>An option value is written verbatim only when it is a valid JSON number literal. Any other
 * value is written as a JSON string. The field name, option keys and string values are escaped
 * so that they cannot terminate the surrounding JSON string.
 *
 * @return a wrapper query holding the serialized knn query
 */
private QueryBuilder buildKnnQuery() {
  StringBuilder vectorJson = new StringBuilder("[");
  for (int i = 0; i < vector.length; i++) {
    if (i > 0) {
      vectorJson.append(",");
    }
    vectorJson.append(vector[i]);
  }
  vectorJson.append("]");

  StringBuilder optionsJson = new StringBuilder();
  for (Map.Entry<String, String> entry : options.entrySet()) {
    String key = escapeJson(entry.getKey());
    String value = entry.getValue();
    optionsJson.append(",");
    if (isNumeric(value)) {
      // Safe to emit raw: the value matched the JSON number grammar.
      optionsJson.append(String.format("\"%s\":%s", key, value));
    } else {
      // Quoted fallback, kept for forward compatibility with non-numeric options.
      optionsJson.append(String.format("\"%s\":\"%s\"", key, escapeJson(value)));
    }
  }

  String knnQueryJson =
      String.format(
          "{\"knn\":{\"%s\":{\"vector\":%s%s}}}", escapeJson(field), vectorJson, optionsJson);
  return new WrapperQueryBuilder(knnQueryJson);
}

/**
 * Tells whether a string can be written into JSON as an unquoted number.
 *
 * <p>{@code Double.parseDouble} is deliberately not used here: it also accepts tokens such as
 * {@code NaN}, {@code Infinity}, {@code 1d}, {@code 0x1p3}, {@code .5} or {@code +5}, none of
 * which are valid JSON numbers.
 *
 * @param str candidate value
 * @return {@code true} if {@code str} is a JSON number literal
 */
private static boolean isNumeric(String str) {
  return JSON_NUMBER.matcher(str).matches();
}

/**
 * Escapes text for use between double quotes in JSON: backslash, double quote and control
 * characters below U+0020 are replaced by escape sequences.
 *
 * @param raw text to escape; {@code null} is rendered as the text {@code null}, the same way
 *     {@code String.format("%s", ...)} renders it
 * @return the escaped text, without surrounding quotes
 */
private static String escapeJson(String raw) {
  String text = String.valueOf(raw);
  StringBuilder escaped = new StringBuilder(text.length());
  for (int i = 0; i < text.length(); i++) {
    char c = text.charAt(i);
    if (c == '"' || c == '\\') {
      escaped.append('\\').append(c);
    } else if (c < 0x20) {
      escaped.append(String.format("\\u%04x", (int) c));
    } else {
      escaped.append(c);
    }
  }
  return escaped.toString();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,246 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.opensearch.storage;

import static org.opensearch.sql.opensearch.storage.VectorSearchTableFunctionResolver.FIELD;
import static org.opensearch.sql.opensearch.storage.VectorSearchTableFunctionResolver.OPTION;
import static org.opensearch.sql.opensearch.storage.VectorSearchTableFunctionResolver.TABLE;
import static org.opensearch.sql.opensearch.storage.VectorSearchTableFunctionResolver.VECTOR;

import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import org.opensearch.sql.common.setting.Settings;
import org.opensearch.sql.data.model.ExprValue;
import org.opensearch.sql.data.type.ExprCoreType;
import org.opensearch.sql.data.type.ExprType;
import org.opensearch.sql.exception.ExpressionEvaluationException;
import org.opensearch.sql.expression.Expression;
import org.opensearch.sql.expression.FunctionExpression;
import org.opensearch.sql.expression.NamedArgumentExpression;
import org.opensearch.sql.expression.env.Environment;
import org.opensearch.sql.expression.function.FunctionName;
import org.opensearch.sql.expression.function.TableFunctionImplementation;
import org.opensearch.sql.opensearch.client.OpenSearchClient;
import org.opensearch.sql.storage.Table;

public class VectorSearchTableFunctionImplementation extends FunctionExpression
implements TableFunctionImplementation {

/** P0 allowed option keys. Rejects unknown/future keys to prevent unvalidated DSL injection. */
static final Set<String> ALLOWED_OPTION_KEYS = Set.of("k", "max_distance", "min_score");

/**
* Field names must be safe for JSON interpolation: alphanumeric, dots (nested), underscores,
* hyphens. Rejects characters that could corrupt the WrapperQueryBuilder JSON.
*/
private static final Pattern SAFE_FIELD_NAME = Pattern.compile("^[a-zA-Z0-9._\\-]+$");

private final FunctionName functionName;
private final List<Expression> arguments;
private final OpenSearchClient client;
private final Settings settings;

/**
 * Creates the table function implementation.
 *
 * <p>{@code functionName} and {@code arguments} are passed to the superclass constructor and are
 * also kept in fields of this class, together with the client and settings that are later handed
 * to the {@code VectorSearchIndex} created by {@link #applyArguments()}.
 *
 * @param functionName name of the function, used in {@link #toString()} and in error messages
 * @param arguments function arguments; {@link #applyArguments()} requires each of them to be a
 *     {@code NamedArgumentExpression}
 * @param client client forwarded to the created index
 * @param settings settings forwarded to the created index
 */
public VectorSearchTableFunctionImplementation(
FunctionName functionName,
List<Expression> arguments,
OpenSearchClient client,
Settings settings) {
super(functionName, arguments);
this.functionName = functionName;
this.arguments = arguments;
this.client = client;
this.settings = settings;
}

/**
 * Always fails: this function cannot be evaluated as a scalar expression.
 *
 * @param valueEnv ignored
 * @return never returns normally
 * @throws UnsupportedOperationException on every call
 */
@Override
public ExprValue valueOf(Environment<Expression, ExprValue> valueEnv) {
  String reason =
      String.format("vectorSearch function [%s] is only supported in FROM clause", functionName);
  throw new UnsupportedOperationException(reason);
}

/**
 * Reports the type of this function expression.
 *
 * @return always {@code ExprCoreType.STRUCT}
 */
@Override
public ExprType type() {
return ExprCoreType.STRUCT;
}

/**
 * Renders the call as {@code name(arg, ...)}.
 *
 * <p>Named arguments are rendered as {@code argName=value}; any other argument is rendered with
 * its own {@code toString()}. Arguments are separated by a comma and a space.
 *
 * @return the textual form of this function call
 */
@Override
public String toString() {
  String renderedArgs =
      arguments.stream()
          .map(
              arg ->
                  arg instanceof NamedArgumentExpression
                      ? String.format(
                          "%s=%s",
                          ((NamedArgumentExpression) arg).getArgName(),
                          ((NamedArgumentExpression) arg).getValue().toString())
                      : arg.toString())
          .collect(Collectors.joining(", "));
  return String.format("%s(%s)", functionName, renderedArgs);
}

/**
 * Validates the function arguments and builds the table to scan.
 *
 * <p>The steps run in a fixed order, and the first failing step decides which error the caller
 * sees: all arguments must be named; the table and field arguments are read and the field name
 * is checked; the vector and option arguments are read; the vector literal is parsed; the option
 * string is parsed and then validated (which also rewrites the option values in canonical
 * numeric form).
 *
 * @return a {@code VectorSearchIndex} for the requested table, field, vector and options
 * @throws ExpressionEvaluationException if an argument is unnamed or missing, or if the field
 *     name, vector literal or options are invalid
 */
@Override
public Table applyArguments() {
// Must run first: getArgumentValue() casts every argument to NamedArgumentExpression.
validateNamedArgs();
String tableName = getArgumentValue(TABLE);
String fieldName = getArgumentValue(FIELD);
validateFieldName(fieldName);
String vectorLiteral = getArgumentValue(VECTOR);
String optionStr = getArgumentValue(OPTION);

float[] vector = parseVector(vectorLiteral);
Map<String, String> options = parseOptions(optionStr);
// Mutates the map: values are replaced by their canonical numeric text.
validateOptions(options);

return new VectorSearchIndex(client, settings, tableName, fieldName, vector, options);
}

/**
 * Parses a vector literal such as {@code [1.0, 2.0, 3.0]} into a float array.
 *
 * <p>A single enclosing pair of square brackets is optional and is removed. Brackets anywhere
 * else are not stripped, so a malformed literal such as {@code [1][2]} is rejected instead of
 * being silently read as the single component {@code 12}. Empty components, including one
 * produced by a trailing comma, are rejected as well.
 *
 * @param vectorLiteral comma-separated numbers, optionally wrapped in {@code [} and {@code ]}
 * @return the parsed components, in order
 * @throws ExpressionEvaluationException if the literal is empty, or a component is not a number
 *     or is not finite
 */
private float[] parseVector(String vectorLiteral) {
  String body = vectorLiteral.trim();
  if (body.length() >= 2 && body.startsWith("[") && body.endsWith("]")) {
    body = body.substring(1, body.length() - 1).trim();
  }
  if (body.isEmpty()) {
    throw new ExpressionEvaluationException("Vector literal must not be empty");
  }
  // Limit -1 keeps trailing empty segments so that "1,2," fails instead of parsing as [1, 2].
  String[] parts = body.split(",", -1);
  float[] vector = new float[parts.length];
  for (int i = 0; i < parts.length; i++) {
    String component = parts[i].trim();
    try {
      vector[i] = Float.parseFloat(component);
    } catch (NumberFormatException e) {
      throw new ExpressionEvaluationException(
          String.format("Invalid vector component '%s': must be a number", component));
    }
    // Float.parseFloat accepts NaN/Infinity and overflows to Infinity; none can be serialized.
    if (!Float.isFinite(vector[i])) {
      throw new ExpressionEvaluationException(
          String.format("Invalid vector component '%s': must be a finite number", component));
    }
  }
  return vector;
}

/**
 * Parses an option string of the form {@code key=value,key=value} into a map.
 *
 * <p>The string is split on commas; each segment is trimmed and empty segments are skipped. Each
 * remaining segment is split on its first {@code =}, and key and value are trimmed. Keys and
 * values are not checked against the allowed option set here.
 *
 * @param optionStr raw option string
 * @return a mutable map of the options in the order they appear in {@code optionStr}
 * @throws ExpressionEvaluationException if a segment has no {@code =}, has an empty key or value,
 *     or repeats a key that was already seen
 */
static Map<String, String> parseOptions(String optionStr) {

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any idea to simplify the parse/validate logic? Probably double check if some is already guarded by grammar or if it can be done by existing visit method.

@mengweieric mengweieric Apr 8, 2026

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I checked the grammar/visitor angle. We can lean on parser/resolver for the outer function shape and arity, but option still reaches this implementation as a single STRING argument, so the inner k=...,max_distance=... content is not visible to grammar/visit methods. The remaining checks like malformed segments, duplicate keys, unknown keys, and numeric parsing are semantic validation at this layer.

I can simplify the code shape, though: for example by extracting the option parsing/validation into a small helper/value object like VectorSearchOptions.parse(optionStr), so this class reads more like orchestration.

// LinkedHashMap keeps the options in the order they were written.
Map<String, String> options = new LinkedHashMap<>();
for (String pair : optionStr.split(",")) {
String trimmed = pair.trim();
if (trimmed.isEmpty()) {
continue;
}
// Limit 2: only the first '=' separates key from value; later ones stay in the value.
String[] kv = trimmed.split("=", 2);
if (kv.length != 2 || kv[0].trim().isEmpty() || kv[1].trim().isEmpty()) {
throw new ExpressionEvaluationException(
String.format("Malformed option segment '%s': expected key=value", trimmed));
}
String key = kv[0].trim();
if (options.containsKey(key)) {
throw new ExpressionEvaluationException(String.format("Duplicate option key '%s'", key));
}
options.put(key, kv[1].trim());
}
return options;
}

/**
 * Ensures every argument is a named argument with a non-empty name.
 *
 * @throws ExpressionEvaluationException on the first argument that is not a {@code
 *     NamedArgumentExpression} or whose name is null or empty
 */
private void validateNamedArgs() {
  final String prefix = "vectorSearch() requires named arguments (e.g., table='index'), ";
  for (Expression argument : arguments) {
    boolean isNamed = argument instanceof NamedArgumentExpression;
    if (!isNamed) {
      throw new ExpressionEvaluationException(
          prefix + "but received: " + argument.getClass().getSimpleName());
    }
    String argName = ((NamedArgumentExpression) argument).getArgName();
    if (argName != null && !argName.isEmpty()) {
      continue;
    }
    throw new ExpressionEvaluationException(prefix + "but received an argument with no name");
  }
}

/**
 * Checks a field name against {@code SAFE_FIELD_NAME} before it is interpolated into the query
 * JSON. Only letters, digits, dots (nested fields), underscores and hyphens pass.
 *
 * @param fieldName field name taken from the function arguments
 * @throws ExpressionEvaluationException if the name contains any other character or is empty
 */
private void validateFieldName(String fieldName) {
  if (SAFE_FIELD_NAME.matcher(fieldName).matches()) {
    return;
  }
  String message =
      String.format(
          "Invalid field name '%s': must contain only alphanumeric characters, dots,"
              + " underscores, or hyphens",
          fieldName);
  throw new ExpressionEvaluationException(message);
}

/**
 * Validates the parsed options and rewrites their values in canonical numeric form.
 *
 * <p>Every key must be in {@code ALLOWED_OPTION_KEYS}, and at least one of {@code k}, {@code
 * max_distance} or {@code min_score} must be present. {@code k} is parsed as an integer, the
 * other two as finite doubles; each value in the map is replaced by the text of the parsed
 * number. Parsing here keeps non-numeric strings away from the raw JSON construction in
 * buildKnnQuery().
 *
 * @param options mutable option map; modified in place
 * @throws ExpressionEvaluationException if a key is unknown, none of the three options is given,
 *     or a value is not a number of the expected kind
 */
private void validateOptions(Map<String, String> options) {
// Reject unknown option keys — only P0 keys are allowed
for (String key : options.keySet()) {
if (!ALLOWED_OPTION_KEYS.contains(key)) {
throw new ExpressionEvaluationException(
String.format("Unknown option key '%s'. Supported keys: %s", key, ALLOWED_OPTION_KEYS));
}
}
boolean hasK = options.containsKey("k");
boolean hasMaxDistance = options.containsKey("max_distance");
boolean hasMinScore = options.containsKey("min_score");
if (!hasK && !hasMaxDistance && !hasMinScore) {
Comment thread
penghuo marked this conversation as resolved.
throw new ExpressionEvaluationException(
"Missing required option: one of k, max_distance, or min_score");
}
// Parse and canonicalize numeric values — closes JSON injection via option values
if (hasK) {
parseIntOption(options, "k");
}
if (hasMaxDistance) {
parseDoubleOption(options, "max_distance");
}
if (hasMinScore) {
parseDoubleOption(options, "min_score");
}
}

/**
 * Parses the option stored under {@code key} as an integer and writes the parsed number back
 * into the map as text.
 *
 * @param options mutable option map; the entry for {@code key} is replaced
 * @param key option key to parse
 * @throws ExpressionEvaluationException if the stored value is not an integer
 */
private void parseIntOption(Map<String, String> options, String key) {
  String raw = options.get(key);
  int parsed;
  try {
    parsed = Integer.parseInt(raw);
  } catch (NumberFormatException notAnInteger) {
    throw new ExpressionEvaluationException(
        String.format("Option '%s' must be an integer, got '%s'", key, raw));
  }
  options.put(key, Integer.toString(parsed));
}

/**
 * Parses the option stored under {@code key} as a finite double and writes the parsed number
 * back into the map as text.
 *
 * @param options mutable option map; the entry for {@code key} is replaced
 * @param key option key to parse
 * @throws ExpressionEvaluationException if the stored value is not a number, or is NaN or
 *     infinite
 */
private void parseDoubleOption(Map<String, String> options, String key) {
  String raw = options.get(key);
  double parsed;
  try {
    parsed = Double.parseDouble(raw);
  } catch (NumberFormatException notANumber) {
    throw new ExpressionEvaluationException(
        String.format("Option '%s' must be a number, got '%s'", key, raw));
  }
  if (!Double.isFinite(parsed)) {
    throw new ExpressionEvaluationException(
        String.format("Option '%s' must be a finite number, got '%s'", key, raw));
  }
  options.put(key, Double.toString(parsed));
}

/**
 * Returns the string value of the first argument whose name equals {@code name}, ignoring case.
 *
 * <p>Every argument is cast to {@code NamedArgumentExpression}, so this must only be called
 * after the arguments have been checked to be named.
 *
 * @param name argument name to look up
 * @return the string value of the matching argument
 * @throws ExpressionEvaluationException if no argument has that name
 */
private String getArgumentValue(String name) {
  final String missingMessage = String.format("Missing required argument: %s", name);
  return arguments.stream()
      .filter(
          arg -> {
            NamedArgumentExpression named = (NamedArgumentExpression) arg;
            return named.getArgName().equalsIgnoreCase(name);
          })
      .map(
          arg -> {
            NamedArgumentExpression named = (NamedArgumentExpression) arg;
            return named.getValue().valueOf().stringValue();
          })
      .findFirst()
      .orElseThrow(() -> new ExpressionEvaluationException(missingMessage));
}
}
Loading
Loading