Skip to content

Commit a5c1e99

Browse files
committed
[BugFix] Validate table name and reject duplicate or positional args in vectorSearch()
Three shape-level validation gaps on vectorSearch() previously crashed the server or returned a misleading 200 with zero rows:

* Validate table= with the same SAFE_FIELD_NAME regex that already guards field=. Previously a hostile table name could reach the native layer unvalidated.
* Reject duplicate named arguments in both the Resolver and the Implementation. The same name appearing twice (e.g. table=a, table=b) previously produced a 500 ArrayIndexOutOfBoundsException.
* Shape-check arguments at the Resolver (not just arity) so positional or unknown-named arguments surface as a clean 400 before planning, instead of silently returning 200 with an empty DSL. The Implementation keeps its original check as defense-in-depth.

Unit tests cover all three paths. Integration tests assert that the SQL layer surfaces the three errors as user-facing 400s.

Signed-off-by: Eric Wei <mengwei.eric@gmail.com>
1 parent fa444fe commit a5c1e99

5 files changed

Lines changed: 220 additions & 2 deletions

File tree

integ-test/src/test/java/org/opensearch/sql/sql/VectorSearchIT.java

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,4 +300,55 @@ public void testExplainWithoutKnnPluginStillWorks() throws IOException {
300300

301301
assertThat(explain, containsString("wrapper"));
302302
}
303+
304+
// ── Argument shape validation (PR B) ──────────────────────────────────
305+
306+
@Test
307+
public void testInvalidTableNameRejected() throws IOException {
308+
// A slash is outside the SAFE_FIELD_NAME regex and is not a valid OpenSearch index character,
309+
// so it should be rejected at the SQL layer before any cluster call is attempted.
310+
ResponseException ex =
311+
expectThrows(
312+
ResponseException.class,
313+
() ->
314+
executeQuery(
315+
"SELECT v._id FROM vectorSearch(table='idx/evil', field='f', "
316+
+ "vector='[1.0]', option='k=5') AS v"));
317+
318+
assertThat(ex.getMessage(), containsString("Invalid table name"));
319+
}
320+
321+
@Test
322+
public void testDuplicateNamedArgRejected() throws IOException {
323+
// Previously this crashed the server with 500 ArrayIndexOutOfBoundsException. Must now
324+
// surface as a clean 400 with a user-facing message.
325+
ResponseException ex =
326+
expectThrows(
327+
ResponseException.class,
328+
() ->
329+
executeQuery(
330+
"SELECT v._id FROM vectorSearch(table='a', table='b', "
331+
+ "vector='[1.0]', option='k=5') AS v"));
332+
333+
assertThat(ex.getMessage(), containsString("Duplicate argument name"));
334+
}
335+
336+
@Test
337+
public void testPositionalArgRejected() throws IOException {
338+
// The SQL grammar requires `ident=value` for every table function argument, so a truly
338+
// positional call such as `vectorSearch('idx', ...)` never reaches the resolver: it fails to
339+
// parse and the statement falls back to the legacy engine, which previously returned 200 with
340+
// zero rows. This test therefore exercises the closest grammar-legal case: a named argument
341+
// whose name ('bogus') does not match any known parameter. The resolver must reject it with a
342+
// 400 instead of letting the query proceed.
344+
ResponseException ex =
345+
expectThrows(
346+
ResponseException.class,
347+
() ->
348+
executeQuery(
349+
"SELECT v._id FROM vectorSearch(bogus='idx', field='f', "
350+
+ "vector='[1.0]', option='k=5') AS v"));
351+
352+
assertThat(ex.getMessage(), containsString("Unknown argument name"));
353+
}
303354
}

opensearch/src/main/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionImplementation.java

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import static org.opensearch.sql.opensearch.storage.VectorSearchTableFunctionResolver.TABLE;
1111
import static org.opensearch.sql.opensearch.storage.VectorSearchTableFunctionResolver.VECTOR;
1212

13+
import java.util.HashSet;
1314
import java.util.LinkedHashMap;
1415
import java.util.List;
1516
import java.util.Map;
@@ -40,7 +41,8 @@ public class VectorSearchTableFunctionImplementation extends FunctionExpression
4041

4142
/**
4243
* Field names must be safe for JSON interpolation: alphanumeric, dots (nested), underscores,
43-
* hyphens. Rejects characters that could corrupt the WrapperQueryBuilder JSON.
44+
* hyphens. Rejects characters that could corrupt the WrapperQueryBuilder JSON. The same regex is
45+
* reused for table names so user-supplied identifiers cannot break out of the JSON context.
4446
*/
4547
private static final Pattern SAFE_FIELD_NAME = Pattern.compile("^[a-zA-Z0-9._\\-]+$");
4648

@@ -99,6 +101,7 @@ public Table applyArguments() {
99101
// clusters without k-NN.
100102
validateNamedArgs();
101103
String tableName = getArgumentValue(TABLE);
104+
validateTableName(tableName);
102105
String fieldName = getArgumentValue(FIELD);
103106
validateFieldName(fieldName);
104107
String vectorLiteral = getArgumentValue(VECTOR);
@@ -162,8 +165,12 @@ static Map<String, String> parseOptions(String optionStr) {
162165
return options;
163166
}
164167

165-
/** Reject non-named arguments and null arg names early. */
168+
/**
169+
* Reject non-named arguments, null arg names, and duplicate named arguments early. Runs before
170+
* any list-index-based lookup so a malformed argument list can never cause an AIOOBE downstream.
171+
*/
166172
private void validateNamedArgs() {
173+
HashSet<String> seen = new HashSet<>();
167174
for (Expression arg : arguments) {
168175
if (!(arg instanceof NamedArgumentExpression)) {
169176
throw new ExpressionEvaluationException(
@@ -177,6 +184,27 @@ private void validateNamedArgs() {
177184
"vectorSearch() requires named arguments (e.g., table='index'), "
178185
+ "but received an argument with no name");
179186
}
187+
if (!seen.add(name.toLowerCase())) {
188+
throw new ExpressionEvaluationException(
189+
"Duplicate argument name '"
190+
+ name
191+
+ "' in vectorSearch(); each named argument may appear at most once");
192+
}
193+
}
194+
}
195+
196+
/**
197+
* Reject table names with characters that could corrupt the WrapperQueryBuilder JSON or escape
198+
* the target index name. Allows alphanumeric, dots, underscores, and hyphens (the characters
199+
* OpenSearch index names already permit).
200+
*/
201+
private void validateTableName(String tableName) {
202+
if (!SAFE_FIELD_NAME.matcher(tableName).matches()) {
203+
throw new ExpressionEvaluationException(
204+
String.format(
205+
"Invalid table name '%s': must contain only alphanumeric characters,"
206+
+ " dots, underscores, or hyphens",
207+
tableName));
180208
}
181209
}
182210

opensearch/src/main/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionResolver.java

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,13 @@
77

88
import static org.opensearch.sql.data.type.ExprCoreType.STRING;
99

10+
import java.util.HashSet;
1011
import java.util.List;
1112
import org.apache.commons.lang3.tuple.Pair;
1213
import org.opensearch.sql.common.setting.Settings;
14+
import org.opensearch.sql.exception.ExpressionEvaluationException;
1315
import org.opensearch.sql.expression.Expression;
16+
import org.opensearch.sql.expression.NamedArgumentExpression;
1417
import org.opensearch.sql.expression.function.FunctionBuilder;
1518
import org.opensearch.sql.expression.function.FunctionName;
1619
import org.opensearch.sql.expression.function.FunctionResolver;
@@ -68,5 +71,35 @@ private void validateArguments(List<Expression> arguments) {
6871
"vectorSearch requires %d arguments (%s), got %d",
6972
ARGUMENT_NAMES.size(), String.join(", ", ARGUMENT_NAMES), arguments.size()));
7073
}
74+
// Shape check at the resolver so positional or unknown-named args produce a clean 400 before
75+
// planning proceeds. The Implementation layer repeats these checks as defense-in-depth.
76+
HashSet<String> seen = new HashSet<>();
77+
for (Expression arg : arguments) {
78+
if (!(arg instanceof NamedArgumentExpression)) {
79+
throw new ExpressionEvaluationException(
80+
"vectorSearch() requires named arguments (e.g., table='index'), "
81+
+ "but received: "
82+
+ arg.getClass().getSimpleName());
83+
}
84+
String name = ((NamedArgumentExpression) arg).getArgName();
85+
if (name == null || name.isEmpty()) {
86+
throw new ExpressionEvaluationException(
87+
"vectorSearch() requires named arguments (e.g., table='index'), "
88+
+ "but received an argument with no name");
89+
}
90+
String lower = name.toLowerCase();
91+
if (!ARGUMENT_NAMES.contains(lower)) {
92+
throw new ExpressionEvaluationException(
93+
String.format(
94+
"Unknown argument name '%s' in vectorSearch(); allowed names are %s",
95+
name, ARGUMENT_NAMES));
96+
}
97+
if (!seen.add(lower)) {
98+
throw new ExpressionEvaluationException(
99+
"Duplicate argument name '"
100+
+ name
101+
+ "' in vectorSearch(); each named argument may appear at most once");
102+
}
103+
}
71104
}
72105
}

opensearch/src/test/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionImplementationTest.java

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -458,6 +458,37 @@ void testParseOptionsPreservesFilterTypeValue() {
458458
assertEquals("post", options.get("filter_type"));
459459
}
460460

461+
@Test
462+
void applyArguments_rejectsInvalidTableName() {
463+
VectorSearchTableFunctionImplementation impl =
464+
createImplWithArgs("idx\"; DROP", "embedding", "[1.0, 2.0]", "k=5");
465+
ExpressionEvaluationException ex =
466+
assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments());
467+
assertTrue(ex.getMessage().contains("Invalid table name"));
468+
assertTrue(
469+
ex.getMessage()
470+
.contains("must contain only alphanumeric characters, dots, underscores, or hyphens"));
471+
}
472+
473+
@Test
474+
void validateNamedArgs_rejectsDuplicateNames() {
475+
// Two occurrences of "table" reach the Implementation layer directly (bypassing the resolver).
476+
FunctionName functionName = FunctionName.of("vectorsearch");
477+
List<Expression> args =
478+
List.of(
479+
DSL.namedArgument("table", DSL.literal("a")),
480+
DSL.namedArgument("table", DSL.literal("b")),
481+
DSL.namedArgument("vector", DSL.literal("[1.0]")),
482+
DSL.namedArgument("option", DSL.literal("k=5")));
483+
VectorSearchTableFunctionImplementation impl =
484+
new VectorSearchTableFunctionImplementation(
485+
functionName, args, client, settings, knnCapability);
486+
ExpressionEvaluationException ex =
487+
assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments());
488+
assertTrue(ex.getMessage().contains("Duplicate argument name"));
489+
assertTrue(ex.getMessage().contains("table"));
490+
}
491+
461492
private VectorSearchTableFunctionImplementation createImpl() {
462493
return createImplWithArgs("my-index", "embedding", "[1.0, 2.0, 3.0]", "k=5");
463494
}

opensearch/src/test/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionResolverTest.java

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import org.mockito.Mock;
1919
import org.mockito.junit.jupiter.MockitoExtension;
2020
import org.opensearch.sql.common.setting.Settings;
21+
import org.opensearch.sql.exception.ExpressionEvaluationException;
2122
import org.opensearch.sql.expression.DSL;
2223
import org.opensearch.sql.expression.Expression;
2324
import org.opensearch.sql.expression.function.FunctionBuilder;
@@ -127,4 +128,78 @@ void testZeroArguments() {
127128
IllegalArgumentException.class, () -> builder.apply(functionProperties, expressions));
128129
assertTrue(ex.getMessage().contains("requires 4 arguments"));
129130
}
131+
132+
@Test
133+
void resolve_rejectsPositionalArgument() {
134+
VectorSearchTableFunctionResolver resolver =
135+
new VectorSearchTableFunctionResolver(client, settings);
136+
FunctionName functionName = FunctionName.of("vectorsearch");
137+
// One positional literal mixed with three named arguments. Arity passes, but the resolver
138+
// must reject this before planning so the SQL layer returns a clean 400 rather than a 200
139+
// with zero rows.
140+
List<Expression> expressions =
141+
List.of(
142+
DSL.literal("my-index"),
143+
DSL.namedArgument("field", DSL.literal("embedding")),
144+
DSL.namedArgument("vector", DSL.literal("[1.0, 2.0]")),
145+
DSL.namedArgument("option", DSL.literal("k=5")));
146+
FunctionSignature functionSignature =
147+
new FunctionSignature(
148+
functionName, expressions.stream().map(Expression::type).collect(Collectors.toList()));
149+
FunctionBuilder builder = resolver.resolve(functionSignature).getValue();
150+
151+
ExpressionEvaluationException ex =
152+
assertThrows(
153+
ExpressionEvaluationException.class,
154+
() -> builder.apply(functionProperties, expressions));
155+
assertTrue(ex.getMessage().contains("requires named arguments"));
156+
}
157+
158+
@Test
159+
void resolve_rejectsDuplicateNamedArgument() {
160+
VectorSearchTableFunctionResolver resolver =
161+
new VectorSearchTableFunctionResolver(client, settings);
162+
FunctionName functionName = FunctionName.of("vectorsearch");
163+
List<Expression> expressions =
164+
List.of(
165+
DSL.namedArgument("table", DSL.literal("a")),
166+
DSL.namedArgument("table", DSL.literal("b")),
167+
DSL.namedArgument("vector", DSL.literal("[1.0]")),
168+
DSL.namedArgument("option", DSL.literal("k=5")));
169+
FunctionSignature functionSignature =
170+
new FunctionSignature(
171+
functionName, expressions.stream().map(Expression::type).collect(Collectors.toList()));
172+
FunctionBuilder builder = resolver.resolve(functionSignature).getValue();
173+
174+
ExpressionEvaluationException ex =
175+
assertThrows(
176+
ExpressionEvaluationException.class,
177+
() -> builder.apply(functionProperties, expressions));
178+
assertTrue(ex.getMessage().contains("Duplicate argument name"));
179+
assertTrue(ex.getMessage().contains("table"));
180+
}
181+
182+
@Test
183+
void resolve_rejectsUnknownArgumentName() {
184+
VectorSearchTableFunctionResolver resolver =
185+
new VectorSearchTableFunctionResolver(client, settings);
186+
FunctionName functionName = FunctionName.of("vectorsearch");
187+
List<Expression> expressions =
188+
List.of(
189+
DSL.namedArgument("table", DSL.literal("my-index")),
190+
DSL.namedArgument("field", DSL.literal("embedding")),
191+
DSL.namedArgument("vector", DSL.literal("[1.0, 2.0]")),
192+
DSL.namedArgument("bogus", DSL.literal("k=5")));
193+
FunctionSignature functionSignature =
194+
new FunctionSignature(
195+
functionName, expressions.stream().map(Expression::type).collect(Collectors.toList()));
196+
FunctionBuilder builder = resolver.resolve(functionSignature).getValue();
197+
198+
ExpressionEvaluationException ex =
199+
assertThrows(
200+
ExpressionEvaluationException.class,
201+
() -> builder.apply(functionProperties, expressions));
202+
assertTrue(ex.getMessage().contains("Unknown argument name"));
203+
assertTrue(ex.getMessage().contains("bogus"));
204+
}
130205
}

0 commit comments

Comments
 (0)