Skip to content

Commit efc7ab0

Browse files
committed
[BugFix] Validate table name and reject duplicate or positional args in vectorSearch()
Three shape-level validation gaps on vectorSearch() previously crashed the server or returned a misleading 200 with zero rows:

* Validate table= with the same SAFE_FIELD_NAME regex that already guards field=. Previously a hostile table name could reach the native layer unvalidated.
* Reject duplicate named arguments in both the Resolver and the Implementation. The same name appearing twice (e.g. table=a, table=b) previously produced a 500 ArrayIndexOutOfBoundsException.
* Shape-check arguments at the Resolver (not just arity) so positional or unknown-named arguments surface as a clean 400 before planning, instead of silently returning 200 with an empty DSL. Implementation keeps the original check as defense-in-depth.

Unit tests cover all three paths. Integration tests assert that the SQL layer surfaces the three errors as user-facing 400s.

Signed-off-by: Eric Wei <mengwei.eric@gmail.com>
1 parent a33e68c commit efc7ab0

5 files changed

Lines changed: 243 additions & 2 deletions

File tree

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.sql.sql;
7+
8+
import static org.hamcrest.Matchers.containsString;
9+
10+
import java.io.IOException;
11+
import org.junit.Test;
12+
import org.opensearch.client.ResponseException;
13+
import org.opensearch.sql.legacy.SQLIntegTestCase;
14+
15+
/**
16+
* Integration tests for vectorSearch SQL table function argument validation. These tests assert
17+
* that argument-shape errors surface as a clean 400 with a user-facing message rather than a 500
18+
* crash or a silent 200 with zero rows.
19+
*/
20+
public class VectorSearchIT extends SQLIntegTestCase {
21+
22+
@Override
23+
protected void init() throws Exception {
24+
loadIndex(Index.ACCOUNT);
25+
}
26+
27+
@Test
28+
public void testInvalidTableNameRejected() throws IOException {
29+
// A slash is outside the SAFE_FIELD_NAME regex and is not a valid OpenSearch index character,
30+
// so it should be rejected at the SQL layer before any cluster call is attempted.
31+
ResponseException ex =
32+
expectThrows(
33+
ResponseException.class,
34+
() ->
35+
executeQuery(
36+
"SELECT v._id FROM vectorSearch(table='idx/evil', field='f', "
37+
+ "vector='[1.0]', option='k=5') AS v"));
38+
39+
assertThat(ex.getMessage(), containsString("Invalid table name"));
40+
}
41+
42+
@Test
43+
public void testDuplicateNamedArgRejected() throws IOException {
44+
// Previously this crashed the server with 500 ArrayIndexOutOfBoundsException. Must now
45+
// surface as a clean 400 with a user-facing message.
46+
ResponseException ex =
47+
expectThrows(
48+
ResponseException.class,
49+
() ->
50+
executeQuery(
51+
"SELECT v._id FROM vectorSearch(table='a', table='b', "
52+
+ "vector='[1.0]', option='k=5') AS v"));
53+
54+
assertThat(ex.getMessage(), containsString("Duplicate argument name"));
55+
}
56+
57+
@Test
58+
public void testPositionalArgRejected() throws IOException {
59+
// The SQL grammar already requires `ident=value` for each table function argument. That means
60+
// the literal case `vectorSearch('idx', ...)` does not even reach this resolver: it fails to
61+
// parse, then the whole statement falls back to the legacy engine, which previously returned
62+
// 200 with zero rows. Guard that path at the first engine that understands vectorSearch() by
63+
// sending a bogus named arg (grammar-legal but positional from the resolver's perspective, in
64+
// the sense that it does not map to a known parameter name). The resolver now surfaces a 400.
65+
ResponseException ex =
66+
expectThrows(
67+
ResponseException.class,
68+
() ->
69+
executeQuery(
70+
"SELECT v._id FROM vectorSearch(bogus='idx', field='f', "
71+
+ "vector='[1.0]', option='k=5') AS v"));
72+
73+
assertThat(ex.getMessage(), containsString("Unknown argument name"));
74+
}
75+
}

opensearch/src/main/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionImplementation.java

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import static org.opensearch.sql.opensearch.storage.VectorSearchTableFunctionResolver.TABLE;
1111
import static org.opensearch.sql.opensearch.storage.VectorSearchTableFunctionResolver.VECTOR;
1212

13+
import java.util.HashSet;
1314
import java.util.LinkedHashMap;
1415
import java.util.List;
1516
import java.util.Map;
@@ -38,7 +39,8 @@ public class VectorSearchTableFunctionImplementation extends FunctionExpression
3839

3940
/**
4041
* Field names must be safe for JSON interpolation: alphanumeric, dots (nested), underscores,
41-
* hyphens. Rejects characters that could corrupt the WrapperQueryBuilder JSON.
42+
* hyphens. Rejects characters that could corrupt the WrapperQueryBuilder JSON. The same regex is
43+
* reused for table names so user-supplied identifiers cannot break out of the JSON context.
4244
*/
4345
private static final Pattern SAFE_FIELD_NAME = Pattern.compile("^[a-zA-Z0-9._\\-]+$");
4446

@@ -90,6 +92,7 @@ public String toString() {
9092
public Table applyArguments() {
9193
validateNamedArgs();
9294
String tableName = getArgumentValue(TABLE);
95+
validateTableName(tableName);
9396
String fieldName = getArgumentValue(FIELD);
9497
validateFieldName(fieldName);
9598
String vectorLiteral = getArgumentValue(VECTOR);
@@ -146,8 +149,12 @@ static Map<String, String> parseOptions(String optionStr) {
146149
return options;
147150
}
148151

149-
/** Reject non-named arguments and null arg names early. */
152+
/**
153+
* Reject non-named arguments, null arg names, and duplicate named arguments early. Runs before
154+
* any list-index-based lookup so a malformed argument list can never cause an AIOOBE downstream.
155+
*/
150156
private void validateNamedArgs() {
157+
HashSet<String> seen = new HashSet<>();
151158
for (Expression arg : arguments) {
152159
if (!(arg instanceof NamedArgumentExpression)) {
153160
throw new ExpressionEvaluationException(
@@ -161,6 +168,27 @@ private void validateNamedArgs() {
161168
"vectorSearch() requires named arguments (e.g., table='index'), "
162169
+ "but received an argument with no name");
163170
}
171+
if (!seen.add(name.toLowerCase())) {
172+
throw new ExpressionEvaluationException(
173+
"Duplicate argument name '"
174+
+ name
175+
+ "' in vectorSearch(); each named argument may appear at most once");
176+
}
177+
}
178+
}
179+
180+
/**
181+
* Reject table names with characters that could corrupt the WrapperQueryBuilder JSON or escape
182+
* the target index name. Allows alphanumeric, dots, underscores, and hyphens (the characters
183+
* OpenSearch index names already permit).
184+
*/
185+
private void validateTableName(String tableName) {
186+
if (!SAFE_FIELD_NAME.matcher(tableName).matches()) {
187+
throw new ExpressionEvaluationException(
188+
String.format(
189+
"Invalid table name '%s': must contain only alphanumeric characters,"
190+
+ " dots, underscores, or hyphens",
191+
tableName));
164192
}
165193
}
166194

opensearch/src/main/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionResolver.java

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,14 @@
77

88
import static org.opensearch.sql.data.type.ExprCoreType.STRING;
99

10+
import java.util.HashSet;
1011
import java.util.List;
1112
import lombok.RequiredArgsConstructor;
1213
import org.apache.commons.lang3.tuple.Pair;
1314
import org.opensearch.sql.common.setting.Settings;
15+
import org.opensearch.sql.exception.ExpressionEvaluationException;
1416
import org.opensearch.sql.expression.Expression;
17+
import org.opensearch.sql.expression.NamedArgumentExpression;
1518
import org.opensearch.sql.expression.function.FunctionBuilder;
1619
import org.opensearch.sql.expression.function.FunctionName;
1720
import org.opensearch.sql.expression.function.FunctionResolver;
@@ -57,5 +60,35 @@ private void validateArguments(List<Expression> arguments) {
5760
"vectorSearch requires %d arguments (%s), got %d",
5861
ARGUMENT_NAMES.size(), String.join(", ", ARGUMENT_NAMES), arguments.size()));
5962
}
63+
// Shape check at the resolver so positional or unknown-named args produce a clean 400 before
64+
// planning proceeds. The Implementation layer repeats these checks as defense-in-depth.
65+
HashSet<String> seen = new HashSet<>();
66+
for (Expression arg : arguments) {
67+
if (!(arg instanceof NamedArgumentExpression)) {
68+
throw new ExpressionEvaluationException(
69+
"vectorSearch() requires named arguments (e.g., table='index'), "
70+
+ "but received: "
71+
+ arg.getClass().getSimpleName());
72+
}
73+
String name = ((NamedArgumentExpression) arg).getArgName();
74+
if (name == null || name.isEmpty()) {
75+
throw new ExpressionEvaluationException(
76+
"vectorSearch() requires named arguments (e.g., table='index'), "
77+
+ "but received an argument with no name");
78+
}
79+
String lower = name.toLowerCase();
80+
if (!ARGUMENT_NAMES.contains(lower)) {
81+
throw new ExpressionEvaluationException(
82+
String.format(
83+
"Unknown argument name '%s' in vectorSearch(); allowed names are %s",
84+
name, ARGUMENT_NAMES));
85+
}
86+
if (!seen.add(lower)) {
87+
throw new ExpressionEvaluationException(
88+
"Duplicate argument name '"
89+
+ name
90+
+ "' in vectorSearch(); each named argument may appear at most once");
91+
}
92+
}
6093
}
6194
}

opensearch/src/test/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionImplementationTest.java

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,36 @@ void testNullArgNameThrows() {
260260
assertTrue(ex.getMessage().contains("requires named arguments"));
261261
}
262262

263+
@Test
264+
void applyArguments_rejectsInvalidTableName() {
265+
VectorSearchTableFunctionImplementation impl =
266+
createImplWithArgs("idx\"; DROP", "embedding", "[1.0, 2.0]", "k=5");
267+
ExpressionEvaluationException ex =
268+
assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments());
269+
assertTrue(ex.getMessage().contains("Invalid table name"));
270+
assertTrue(
271+
ex.getMessage()
272+
.contains("must contain only alphanumeric characters, dots, underscores, or hyphens"));
273+
}
274+
275+
@Test
276+
void validateNamedArgs_rejectsDuplicateNames() {
277+
// Two occurrences of "table" reach the Implementation layer directly (bypassing the resolver).
278+
FunctionName functionName = FunctionName.of("vectorsearch");
279+
List<Expression> args =
280+
List.of(
281+
DSL.namedArgument("table", DSL.literal("a")),
282+
DSL.namedArgument("table", DSL.literal("b")),
283+
DSL.namedArgument("vector", DSL.literal("[1.0]")),
284+
DSL.namedArgument("option", DSL.literal("k=5")));
285+
VectorSearchTableFunctionImplementation impl =
286+
new VectorSearchTableFunctionImplementation(functionName, args, client, settings);
287+
ExpressionEvaluationException ex =
288+
assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments());
289+
assertTrue(ex.getMessage().contains("Duplicate argument name"));
290+
assertTrue(ex.getMessage().contains("table"));
291+
}
292+
263293
private VectorSearchTableFunctionImplementation createImpl() {
264294
return createImplWithArgs("my-index", "embedding", "[1.0, 2.0, 3.0]", "k=5");
265295
}

opensearch/src/test/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionResolverTest.java

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import org.mockito.Mock;
1919
import org.mockito.junit.jupiter.MockitoExtension;
2020
import org.opensearch.sql.common.setting.Settings;
21+
import org.opensearch.sql.exception.ExpressionEvaluationException;
2122
import org.opensearch.sql.expression.DSL;
2223
import org.opensearch.sql.expression.Expression;
2324
import org.opensearch.sql.expression.function.FunctionBuilder;
@@ -83,4 +84,78 @@ void testWrongArgumentCount() {
8384
IllegalArgumentException.class, () -> builder.apply(functionProperties, expressions));
8485
assertTrue(ex.getMessage().contains("requires 4 arguments"));
8586
}
87+
88+
@Test
89+
void resolve_rejectsPositionalArgument() {
90+
VectorSearchTableFunctionResolver resolver =
91+
new VectorSearchTableFunctionResolver(client, settings);
92+
FunctionName functionName = FunctionName.of("vectorsearch");
93+
// One positional literal mixed with three named arguments. Arity passes, but the resolver
94+
// must reject this before planning so the SQL layer returns a clean 400 rather than a 200
95+
// with zero rows.
96+
List<Expression> expressions =
97+
List.of(
98+
DSL.literal("my-index"),
99+
DSL.namedArgument("field", DSL.literal("embedding")),
100+
DSL.namedArgument("vector", DSL.literal("[1.0, 2.0]")),
101+
DSL.namedArgument("option", DSL.literal("k=5")));
102+
FunctionSignature functionSignature =
103+
new FunctionSignature(
104+
functionName, expressions.stream().map(Expression::type).collect(Collectors.toList()));
105+
FunctionBuilder builder = resolver.resolve(functionSignature).getValue();
106+
107+
ExpressionEvaluationException ex =
108+
assertThrows(
109+
ExpressionEvaluationException.class,
110+
() -> builder.apply(functionProperties, expressions));
111+
assertTrue(ex.getMessage().contains("requires named arguments"));
112+
}
113+
114+
@Test
115+
void resolve_rejectsDuplicateNamedArgument() {
116+
VectorSearchTableFunctionResolver resolver =
117+
new VectorSearchTableFunctionResolver(client, settings);
118+
FunctionName functionName = FunctionName.of("vectorsearch");
119+
List<Expression> expressions =
120+
List.of(
121+
DSL.namedArgument("table", DSL.literal("a")),
122+
DSL.namedArgument("table", DSL.literal("b")),
123+
DSL.namedArgument("vector", DSL.literal("[1.0]")),
124+
DSL.namedArgument("option", DSL.literal("k=5")));
125+
FunctionSignature functionSignature =
126+
new FunctionSignature(
127+
functionName, expressions.stream().map(Expression::type).collect(Collectors.toList()));
128+
FunctionBuilder builder = resolver.resolve(functionSignature).getValue();
129+
130+
ExpressionEvaluationException ex =
131+
assertThrows(
132+
ExpressionEvaluationException.class,
133+
() -> builder.apply(functionProperties, expressions));
134+
assertTrue(ex.getMessage().contains("Duplicate argument name"));
135+
assertTrue(ex.getMessage().contains("table"));
136+
}
137+
138+
@Test
139+
void resolve_rejectsUnknownArgumentName() {
140+
VectorSearchTableFunctionResolver resolver =
141+
new VectorSearchTableFunctionResolver(client, settings);
142+
FunctionName functionName = FunctionName.of("vectorsearch");
143+
List<Expression> expressions =
144+
List.of(
145+
DSL.namedArgument("table", DSL.literal("my-index")),
146+
DSL.namedArgument("field", DSL.literal("embedding")),
147+
DSL.namedArgument("vector", DSL.literal("[1.0, 2.0]")),
148+
DSL.namedArgument("bogus", DSL.literal("k=5")));
149+
FunctionSignature functionSignature =
150+
new FunctionSignature(
151+
functionName, expressions.stream().map(Expression::type).collect(Collectors.toList()));
152+
FunctionBuilder builder = resolver.resolve(functionSignature).getValue();
153+
154+
ExpressionEvaluationException ex =
155+
assertThrows(
156+
ExpressionEvaluationException.class,
157+
() -> builder.apply(functionProperties, expressions));
158+
assertTrue(ex.getMessage().contains("Unknown argument name"));
159+
assertTrue(ex.getMessage().contains("bogus"));
160+
}
86161
}

0 commit comments

Comments
 (0)