Skip to content

Commit 92e67be

Browse files
mengweiericclaude
andcommitted
WIP: add unit tests for vectorSearch table function
- VectorSearchTableFunctionResolverTest: resolution, wrong arg count - VectorSearchTableFunctionImplementationTest: valueOf, type, toString, applyArguments (bracketed/unbracketed vector, complex options), missing k option, missing argument - OpenSearchStorageEngineTest: getFunctions returns vector search resolver Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 1a1be1a commit 92e67be

3 files changed

Lines changed: 227 additions & 0 deletions

File tree

opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchStorageEngineTest.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,21 @@
66
package org.opensearch.sql.opensearch.storage;
77

88
import static org.junit.jupiter.api.Assertions.assertAll;
9+
import static org.junit.jupiter.api.Assertions.assertEquals;
10+
import static org.junit.jupiter.api.Assertions.assertFalse;
911
import static org.junit.jupiter.api.Assertions.assertNotNull;
1012
import static org.junit.jupiter.api.Assertions.assertTrue;
1113
import static org.opensearch.sql.analysis.DataSourceSchemaIdentifierNameResolver.DEFAULT_DATASOURCE_NAME;
1214
import static org.opensearch.sql.utils.SystemIndexUtils.TABLE_INFO;
1315

16+
import java.util.Collection;
1417
import org.junit.jupiter.api.Test;
1518
import org.junit.jupiter.api.extension.ExtendWith;
1619
import org.mockito.Mock;
1720
import org.mockito.junit.jupiter.MockitoExtension;
1821
import org.opensearch.sql.DataSourceSchemaName;
1922
import org.opensearch.sql.common.setting.Settings;
23+
import org.opensearch.sql.expression.function.FunctionResolver;
2024
import org.opensearch.sql.opensearch.client.OpenSearchClient;
2125
import org.opensearch.sql.opensearch.storage.system.OpenSearchSystemIndex;
2226
import org.opensearch.sql.storage.Table;
@@ -43,4 +47,13 @@ public void getSystemTable() {
4347
engine.getTable(new DataSourceSchemaName(DEFAULT_DATASOURCE_NAME, "default"), TABLE_INFO);
4448
assertAll(() -> assertNotNull(table), () -> assertTrue(table instanceof OpenSearchSystemIndex));
4549
}
50+
51+
@Test
52+
public void getFunctionsReturnsVectorSearchResolver() {
53+
OpenSearchStorageEngine engine = new OpenSearchStorageEngine(client, settings);
54+
Collection<FunctionResolver> functions = engine.getFunctions();
55+
assertFalse(functions.isEmpty());
56+
assertEquals(1, functions.size());
57+
assertTrue(functions.iterator().next() instanceof VectorSearchTableFunctionResolver);
58+
}
4659
}
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.sql.opensearch.storage;
7+
8+
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
9+
import static org.junit.jupiter.api.Assertions.assertEquals;
10+
import static org.junit.jupiter.api.Assertions.assertThrows;
11+
import static org.junit.jupiter.api.Assertions.assertTrue;
12+
13+
import java.util.List;
14+
import org.junit.jupiter.api.Test;
15+
import org.junit.jupiter.api.extension.ExtendWith;
16+
import org.mockito.Mock;
17+
import org.mockito.junit.jupiter.MockitoExtension;
18+
import org.opensearch.sql.common.setting.Settings;
19+
import org.opensearch.sql.data.type.ExprCoreType;
20+
import org.opensearch.sql.exception.ExpressionEvaluationException;
21+
import org.opensearch.sql.expression.DSL;
22+
import org.opensearch.sql.expression.Expression;
23+
import org.opensearch.sql.expression.function.FunctionName;
24+
import org.opensearch.sql.opensearch.client.OpenSearchClient;
25+
import org.opensearch.sql.storage.Table;
26+
27+
@ExtendWith(MockitoExtension.class)
28+
class VectorSearchTableFunctionImplementationTest {
29+
30+
@Mock private OpenSearchClient client;
31+
32+
@Mock private Settings settings;
33+
34+
@Test
35+
void testValueOfThrows() {
36+
VectorSearchTableFunctionImplementation impl = createImpl();
37+
UnsupportedOperationException ex =
38+
assertThrows(UnsupportedOperationException.class, () -> impl.valueOf());
39+
assertTrue(ex.getMessage().contains("only supported in FROM clause"));
40+
}
41+
42+
@Test
43+
void testType() {
44+
VectorSearchTableFunctionImplementation impl = createImpl();
45+
assertEquals(ExprCoreType.STRUCT, impl.type());
46+
}
47+
48+
@Test
49+
void testToString() {
50+
VectorSearchTableFunctionImplementation impl = createImpl();
51+
String str = impl.toString();
52+
assertTrue(str.contains("vectorsearch"));
53+
assertTrue(str.contains("table="));
54+
assertTrue(str.contains("my-index"));
55+
}
56+
57+
@Test
58+
void testApplyArguments() {
59+
VectorSearchTableFunctionImplementation impl = createImpl();
60+
Table table = impl.applyArguments();
61+
assertTrue(table instanceof VectorSearchIndex);
62+
}
63+
64+
@Test
65+
void testApplyArgumentsWithBracketedVector() {
66+
VectorSearchTableFunctionImplementation impl =
67+
createImplWithArgs("my-index", "embedding", "[1.0, 2.0, 3.0]", "k=5");
68+
Table table = impl.applyArguments();
69+
assertTrue(table instanceof VectorSearchIndex);
70+
}
71+
72+
@Test
73+
void testApplyArgumentsWithUnbracketedVector() {
74+
VectorSearchTableFunctionImplementation impl =
75+
createImplWithArgs("my-index", "embedding", "1.0, 2.0, 3.0", "k=5");
76+
Table table = impl.applyArguments();
77+
assertTrue(table instanceof VectorSearchIndex);
78+
}
79+
80+
@Test
81+
void testApplyArgumentsWithComplexOptions() {
82+
VectorSearchTableFunctionImplementation impl =
83+
createImplWithArgs("my-index", "embedding", "[1.0, 2.0]", "k=10,method.ef_search=100");
84+
Table table = impl.applyArguments();
85+
assertTrue(table instanceof VectorSearchIndex);
86+
}
87+
88+
@Test
89+
void testMissingKOptionThrows() {
90+
VectorSearchTableFunctionImplementation impl =
91+
createImplWithArgs("my-index", "embedding", "[1.0, 2.0]", "method.ef_search=100");
92+
ExpressionEvaluationException ex =
93+
assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments());
94+
assertEquals("Missing required option: k", ex.getMessage());
95+
}
96+
97+
@Test
98+
void testMissingArgumentThrows() {
99+
FunctionName functionName = FunctionName.of("vectorsearch");
100+
List<Expression> args =
101+
List.of(
102+
DSL.namedArgument("table", DSL.literal("my-index")),
103+
DSL.namedArgument("field", DSL.literal("embedding")),
104+
DSL.namedArgument("vector", DSL.literal("[1.0, 2.0]")));
105+
VectorSearchTableFunctionImplementation impl =
106+
new VectorSearchTableFunctionImplementation(functionName, args, client, settings);
107+
ExpressionEvaluationException ex =
108+
assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments());
109+
assertEquals("Missing required argument: option", ex.getMessage());
110+
}
111+
112+
private VectorSearchTableFunctionImplementation createImpl() {
113+
return createImplWithArgs("my-index", "embedding", "[1.0, 2.0, 3.0]", "k=5");
114+
}
115+
116+
private VectorSearchTableFunctionImplementation createImplWithArgs(
117+
String table, String field, String vector, String option) {
118+
FunctionName functionName = FunctionName.of("vectorsearch");
119+
List<Expression> args =
120+
List.of(
121+
DSL.namedArgument("table", DSL.literal(table)),
122+
DSL.namedArgument("field", DSL.literal(field)),
123+
DSL.namedArgument("vector", DSL.literal(vector)),
124+
DSL.namedArgument("option", DSL.literal(option)));
125+
return new VectorSearchTableFunctionImplementation(functionName, args, client, settings);
126+
}
127+
}
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.sql.opensearch.storage;
7+
8+
import static org.junit.jupiter.api.Assertions.assertEquals;
9+
import static org.junit.jupiter.api.Assertions.assertThrows;
10+
import static org.junit.jupiter.api.Assertions.assertTrue;
11+
import static org.opensearch.sql.data.type.ExprCoreType.STRING;
12+
13+
import java.util.List;
14+
import java.util.stream.Collectors;
15+
import org.apache.commons.lang3.tuple.Pair;
16+
import org.junit.jupiter.api.Test;
17+
import org.junit.jupiter.api.extension.ExtendWith;
18+
import org.mockito.Mock;
19+
import org.mockito.junit.jupiter.MockitoExtension;
20+
import org.opensearch.sql.common.setting.Settings;
21+
import org.opensearch.sql.expression.DSL;
22+
import org.opensearch.sql.expression.Expression;
23+
import org.opensearch.sql.expression.function.FunctionBuilder;
24+
import org.opensearch.sql.expression.function.FunctionName;
25+
import org.opensearch.sql.expression.function.FunctionProperties;
26+
import org.opensearch.sql.expression.function.FunctionSignature;
27+
import org.opensearch.sql.expression.function.TableFunctionImplementation;
28+
import org.opensearch.sql.opensearch.client.OpenSearchClient;
29+
30+
@ExtendWith(MockitoExtension.class)
31+
class VectorSearchTableFunctionResolverTest {
32+
33+
@Mock private OpenSearchClient client;
34+
35+
@Mock private Settings settings;
36+
37+
@Mock private FunctionProperties functionProperties;
38+
39+
@Test
40+
void testResolve() {
41+
VectorSearchTableFunctionResolver resolver =
42+
new VectorSearchTableFunctionResolver(client, settings);
43+
FunctionName functionName = FunctionName.of("vectorsearch");
44+
List<Expression> expressions =
45+
List.of(
46+
DSL.namedArgument("table", DSL.literal("my-index")),
47+
DSL.namedArgument("field", DSL.literal("embedding")),
48+
DSL.namedArgument("vector", DSL.literal("[1.0, 2.0, 3.0]")),
49+
DSL.namedArgument("option", DSL.literal("k=5")));
50+
FunctionSignature functionSignature =
51+
new FunctionSignature(
52+
functionName, expressions.stream().map(Expression::type).collect(Collectors.toList()));
53+
54+
Pair<FunctionSignature, FunctionBuilder> resolution = resolver.resolve(functionSignature);
55+
56+
assertEquals(functionName, resolution.getKey().getFunctionName());
57+
assertEquals(functionName, resolver.getFunctionName());
58+
assertEquals(List.of(STRING, STRING, STRING, STRING), resolution.getKey().getParamTypeList());
59+
60+
TableFunctionImplementation impl =
61+
(TableFunctionImplementation) resolution.getValue().apply(functionProperties, expressions);
62+
assertTrue(impl instanceof VectorSearchTableFunctionImplementation);
63+
}
64+
65+
@Test
66+
void testWrongArgumentCount() {
67+
VectorSearchTableFunctionResolver resolver =
68+
new VectorSearchTableFunctionResolver(client, settings);
69+
FunctionName functionName = FunctionName.of("vectorsearch");
70+
List<Expression> expressions =
71+
List.of(
72+
DSL.namedArgument("table", DSL.literal("my-index")),
73+
DSL.namedArgument("field", DSL.literal("embedding")));
74+
FunctionSignature functionSignature =
75+
new FunctionSignature(
76+
functionName, expressions.stream().map(Expression::type).collect(Collectors.toList()));
77+
78+
Pair<FunctionSignature, FunctionBuilder> resolution = resolver.resolve(functionSignature);
79+
FunctionBuilder builder = resolution.getValue();
80+
81+
IllegalArgumentException ex =
82+
assertThrows(
83+
IllegalArgumentException.class,
84+
() -> builder.apply(functionProperties, expressions));
85+
assertTrue(ex.getMessage().contains("requires 4 arguments"));
86+
}
87+
}

0 commit comments

Comments
 (0)