Skip to content

Commit c539d3c

Browse files
mengweiericclaude
andcommitted
WIP: generalize option parsing to support k, max_distance, and min_score
Replace hardcoded k parameter with generic options map. All knn query options (k, max_distance, min_score, method.ef_search, etc.) are now passed through from the option string to the knn query JSON. Validates that at least one search mode (k, max_distance, or min_score) is specified. Adds tests for radial search modes. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 92e67be commit c539d3c

3 files changed

Lines changed: 76 additions & 14 deletions

File tree

opensearch/src/main/java/org/opensearch/sql/opensearch/storage/VectorSearchIndex.java

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
package org.opensearch.sql.opensearch.storage;
77

8+
import java.util.Map;
89
import java.util.function.Function;
910
import org.opensearch.common.unit.TimeValue;
1011
import org.opensearch.index.query.QueryBuilder;
@@ -21,21 +22,23 @@
2122
*/
2223
public class VectorSearchIndex extends OpenSearchIndex {
2324

25+
private static final String VECTOR_OPTION = "vector";
26+
2427
private final String field;
2528
private final float[] vector;
26-
private final int k;
29+
private final Map<String, String> options;
2730

2831
public VectorSearchIndex(
2932
OpenSearchClient client,
3033
Settings settings,
3134
String indexName,
3235
String field,
3336
float[] vector,
34-
int k) {
37+
Map<String, String> options) {
3538
super(client, settings, indexName);
3639
this.field = field;
3740
this.vector = vector;
38-
this.k = k;
41+
this.options = options;
3942
}
4043

4144
@Override
@@ -66,9 +69,31 @@ private QueryBuilder buildKnnQuery() {
6669
}
6770
vectorJson.append("]");
6871

72+
StringBuilder optionsJson = new StringBuilder();
73+
for (Map.Entry<String, String> entry : options.entrySet()) {
74+
optionsJson.append(",");
75+
String value = entry.getValue();
76+
// Numeric values go unquoted, everything else quoted
77+
if (isNumeric(value)) {
78+
optionsJson.append(String.format("\"%s\":%s", entry.getKey(), value));
79+
} else {
80+
optionsJson.append(String.format("\"%s\":\"%s\"", entry.getKey(), value));
81+
}
82+
}
83+
6984
String knnQueryJson =
7085
String.format(
71-
"{\"knn\":{\"%s\":{\"vector\":%s,\"k\":%d}}}", field, vectorJson.toString(), k);
86+
"{\"knn\":{\"%s\":{\"vector\":%s%s}}}",
87+
field, vectorJson.toString(), optionsJson.toString());
7288
return new WrapperQueryBuilder(knnQueryJson);
7389
}
90+
91+
private static boolean isNumeric(String str) {
92+
try {
93+
Double.parseDouble(str);
94+
return true;
95+
} catch (NumberFormatException e) {
96+
return false;
97+
}
98+
}
7499
}

opensearch/src/main/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionImplementation.java

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@
1010
import static org.opensearch.sql.opensearch.storage.VectorSearchTableFunctionResolver.TABLE;
1111
import static org.opensearch.sql.opensearch.storage.VectorSearchTableFunctionResolver.VECTOR;
1212

13+
import java.util.LinkedHashMap;
1314
import java.util.List;
15+
import java.util.Map;
1416
import java.util.stream.Collectors;
1517
import org.opensearch.sql.common.setting.Settings;
1618
import org.opensearch.sql.data.model.ExprValue;
@@ -80,9 +82,10 @@ public Table applyArguments() {
8082
String optionStr = getArgumentValue(OPTION);
8183

8284
float[] vector = parseVector(vectorLiteral);
83-
int k = parseK(optionStr);
85+
Map<String, String> options = parseOptions(optionStr);
86+
validateOptions(options);
8487

85-
return new VectorSearchIndex(client, settings, tableName, fieldName, vector, k);
88+
return new VectorSearchIndex(client, settings, tableName, fieldName, vector, options);
8689
}
8790

8891
private float[] parseVector(String vectorLiteral) {
@@ -96,15 +99,25 @@ private float[] parseVector(String vectorLiteral) {
9699
return vector;
97100
}
98101

99-
private int parseK(String optionStr) {
100-
// Parse "k=10" or "k=10,method.ef_search=100"
102+
static Map<String, String> parseOptions(String optionStr) {
103+
Map<String, String> options = new LinkedHashMap<>();
101104
for (String pair : optionStr.split(",")) {
102105
String[] kv = pair.trim().split("=", 2);
103-
if (kv.length == 2 && kv[0].trim().equals("k")) {
104-
return Integer.parseInt(kv[1].trim());
106+
if (kv.length == 2) {
107+
options.put(kv[0].trim(), kv[1].trim());
105108
}
106109
}
107-
throw new ExpressionEvaluationException("Missing required option: k");
110+
return options;
111+
}
112+
113+
private void validateOptions(Map<String, String> options) {
114+
boolean hasK = options.containsKey("k");
115+
boolean hasMaxDistance = options.containsKey("max_distance");
116+
boolean hasMinScore = options.containsKey("min_score");
117+
if (!hasK && !hasMaxDistance && !hasMinScore) {
118+
throw new ExpressionEvaluationException(
119+
"Missing required option: one of k, max_distance, or min_score");
120+
}
108121
}
109122

110123
/** Extract a named argument's string value. */

opensearch/src/test/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionImplementationTest.java

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@
55

66
package org.opensearch.sql.opensearch.storage;
77

8-
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
98
import static org.junit.jupiter.api.Assertions.assertEquals;
109
import static org.junit.jupiter.api.Assertions.assertThrows;
1110
import static org.junit.jupiter.api.Assertions.assertTrue;
1211

1312
import java.util.List;
13+
import java.util.Map;
1414
import org.junit.jupiter.api.Test;
1515
import org.junit.jupiter.api.extension.ExtendWith;
1616
import org.mockito.Mock;
@@ -86,12 +86,36 @@ void testApplyArgumentsWithComplexOptions() {
8686
}
8787

8888
@Test
89-
void testMissingKOptionThrows() {
89+
void testApplyArgumentsWithMaxDistance() {
90+
VectorSearchTableFunctionImplementation impl =
91+
createImplWithArgs("my-index", "embedding", "[1.0, 2.0]", "max_distance=10.0");
92+
Table table = impl.applyArguments();
93+
assertTrue(table instanceof VectorSearchIndex);
94+
}
95+
96+
@Test
97+
void testApplyArgumentsWithMinScore() {
98+
VectorSearchTableFunctionImplementation impl =
99+
createImplWithArgs("my-index", "embedding", "[1.0, 2.0]", "min_score=0.5");
100+
Table table = impl.applyArguments();
101+
assertTrue(table instanceof VectorSearchIndex);
102+
}
103+
104+
@Test
105+
void testMissingSearchModeOptionThrows() {
90106
VectorSearchTableFunctionImplementation impl =
91107
createImplWithArgs("my-index", "embedding", "[1.0, 2.0]", "method.ef_search=100");
92108
ExpressionEvaluationException ex =
93109
assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments());
94-
assertEquals("Missing required option: k", ex.getMessage());
110+
assertTrue(ex.getMessage().contains("one of k, max_distance, or min_score"));
111+
}
112+
113+
@Test
114+
void testParseOptionsMultiple() {
115+
Map<String, String> opts =
116+
VectorSearchTableFunctionImplementation.parseOptions("k=5,method.ef_search=100");
117+
assertEquals("5", opts.get("k"));
118+
assertEquals("100", opts.get("method.ef_search"));
95119
}
96120

97121
@Test

0 commit comments

Comments
 (0)