1010import static org .opensearch .sql .util .TestUtils .performRequest ;
1111
1212import java .io .IOException ;
13+ import java .util .Arrays ;
14+ import java .util .HashSet ;
15+ import java .util .Set ;
1316import org .json .JSONArray ;
1417import org .json .JSONObject ;
1518import org .junit .Assume ;
2528 *
2629 * <p>The k-NN plugin is not provisioned by the default integ-test cluster — each test calls {@link
2730 * Assume#assumeTrue} on {@link #isKnnPluginInstalled()} so the class is silently skipped when k-NN
28- * is absent. Run locally after {@code scripts/setup -knn-local.sh} has wired k-NN into the test
29- * cluster. Provisioning k-NN in CI is a separate follow-up.
31+ * is absent. Run locally against a cluster that has opensearch -knn installed. Provisioning k-NN in
32+ * CI is a separate follow-up.
3033 */
3134public class VectorSearchExecutionIT extends SQLIntegTestCase {
3235
@@ -35,12 +38,23 @@ public class VectorSearchExecutionIT extends SQLIntegTestCase {
3538 // 6 docs in 2D — two clusters so filter/radial tests have distinguishable results.
3639 // Cluster A near [1, 1]: docs 1-3 (state=TX, ages 25/30/40).
3740 // Cluster B near [9, 9]: docs 4-6 (state=CA, ages 28/35/45).
41+ // Pin Lucene HNSW + L2 so efficient filtering is deterministic (k-NN supports efficient
42+ // filtering only on lucene+hnsw and faiss+hnsw/ivf) and the L2 → 1/(1+d) scoring used by the
43+ // radial min_score test is well-defined.
3844 private static final String MAPPING =
3945 "{"
4046 + " \" settings\" : {\" index\" : {\" knn\" : true}},"
4147 + " \" mappings\" : {"
4248 + " \" properties\" : {"
43- + " \" embedding\" : {\" type\" : \" knn_vector\" , \" dimension\" : 2},"
49+ + " \" embedding\" : {"
50+ + " \" type\" : \" knn_vector\" ,"
51+ + " \" dimension\" : 2,"
52+ + " \" method\" : {"
53+ + " \" name\" : \" hnsw\" ,"
54+ + " \" engine\" : \" lucene\" ,"
55+ + " \" space_type\" : \" l2\" "
56+ + " }"
57+ + " },"
4458 + " \" state\" : {\" type\" : \" keyword\" },"
4559 + " \" age\" : {\" type\" : \" integer\" }"
4660 + " }"
@@ -119,8 +133,9 @@ public void testTopKReturnsNearestSortedByScore() throws IOException {
119133
120134 @ Test
121135 public void testPostFilterReturnsOnlyMatchingDocs () throws IOException {
122- // Query from cluster B with WHERE state='TX' should force the scan to find TX docs
123- // (cluster A) even though the vector is closer to cluster B. Proves filter is applied.
136+ // Query from cluster B with WHERE state='TX' forces POST filtering to surface TX docs
137+ // (cluster A) even though the vector is closer to cluster B. k=10 covers all 6 docs so
138+ // post-filtering to state='TX' deterministically yields exactly {1,2,3}.
124139 JSONObject result =
125140 executeJdbcRequest (
126141 "SELECT v._id, v._score "
@@ -131,46 +146,37 @@ public void testPostFilterReturnsOnlyMatchingDocs() throws IOException {
131146 + "WHERE v.state = 'TX' "
132147 + "LIMIT 10" );
133148
134- JSONArray rows = result .getJSONArray ("datarows" );
135- assertTrue ("Expected at least one row:\n " + result , rows .length () > 0 );
136- for (int i = 0 ; i < rows .length (); i ++) {
137- String id = rows .getJSONArray (i ).getString (0 );
138- assertTrue (
139- "Row " + i + " id=" + id + " should be from TX cluster (1,2,3):\n " + result ,
140- id .equals ("1" ) || id .equals ("2" ) || id .equals ("3" ));
141- }
149+ assertRowIdsEqual (result , "1" , "2" , "3" );
142150 }
143151
144152 // ── EFFICIENT filter happy path ─────────────────────────────────────
145153
146154 @ Test
147155 public void testEfficientFilterReturnsOnlyMatchingDocs () throws IOException {
156+ // Query vector sits on cluster A (TX) but WHERE state='CA' forces EFFICIENT filtering to
157+ // navigate HNSW toward CA docs. With k=3, a POST-filter implementation would return 0 rows
158+ // (the 3 nearest candidates are all TX, which get filtered out); an efficient-filter
159+ // implementation returns exactly the 3 CA docs {4,5,6}. This asymmetry makes the test
160+ // discriminate between the two filter modes.
148161 JSONObject result =
149162 executeJdbcRequest (
150163 "SELECT v._id, v._score "
151164 + "FROM vectorSearch(table='"
152165 + TEST_INDEX
153166 + "', field='embedding', "
154- + "vector='[1.0, 1.0]', option='k=5 ,filter_type=efficient') AS v "
167+ + "vector='[1.0, 1.0]', option='k=3 ,filter_type=efficient') AS v "
155168 + "WHERE v.state = 'CA' "
156169 + "LIMIT 5" );
157170
158- JSONArray rows = result .getJSONArray ("datarows" );
159- assertTrue ("Expected at least one row:\n " + result , rows .length () > 0 );
160- for (int i = 0 ; i < rows .length (); i ++) {
161- String id = rows .getJSONArray (i ).getString (0 );
162- assertTrue (
163- "Row " + i + " id=" + id + " should be from CA cluster (4,5,6):\n " + result ,
164- id .equals ("4" ) || id .equals ("5" ) || id .equals ("6" ));
165- }
171+ assertRowIdsEqual (result , "4" , "5" , "6" );
166172 }
167173
168174 // ── Radial happy paths ──────────────────────────────────────────────
169175
170176 @ Test
171177 public void testRadialMaxDistanceReturnsOnlyNearDocs () throws IOException {
172- // max_distance=1.0 (L2) centered on [1,1] should pick up cluster A docs and exclude
173- // cluster B which is ~11 units away.
178+ // max_distance=1.0 (L2) centered on [1,1] includes all 3 cluster A docs (max L2 ≈ 0.22)
179+ // and excludes cluster B which is ~11 units away.
174180 JSONObject result =
175181 executeJdbcRequest (
176182 "SELECT v._id "
@@ -180,20 +186,13 @@ public void testRadialMaxDistanceReturnsOnlyNearDocs() throws IOException {
180186 + "vector='[1.0, 1.0]', option='max_distance=1.0') AS v "
181187 + "LIMIT 10" );
182188
183- JSONArray rows = result .getJSONArray ("datarows" );
184- assertTrue ("Expected at least one row:\n " + result , rows .length () > 0 );
185- for (int i = 0 ; i < rows .length (); i ++) {
186- String id = rows .getJSONArray (i ).getString (0 );
187- assertTrue (
188- "Row " + i + " id=" + id + " should be within max_distance of cluster A:\n " + result ,
189- id .equals ("1" ) || id .equals ("2" ) || id .equals ("3" ));
190- }
189+ assertRowIdsEqual (result , "1" , "2" , "3" );
191190 }
192191
193192 @ Test
194193 public void testRadialMinScoreReturnsOnlyHighScoreDocs () throws IOException {
195194 // For L2 space, OpenSearch score = 1/(1+distance). Centered on [1,1], cluster A docs
196- // score ~0.8 -1.0 and cluster B scores ~0.08. min_score=0.5 should exclude cluster B .
195+ // score ~0.82 -1.0 and cluster B scores ~0.08. min_score=0.5 yields exactly {1,2,3} .
197196 JSONObject result =
198197 executeJdbcRequest (
199198 "SELECT v._id, v._score "
@@ -204,16 +203,23 @@ public void testRadialMinScoreReturnsOnlyHighScoreDocs() throws IOException {
204203 + "LIMIT 10" );
205204
206205 JSONArray rows = result .getJSONArray ("datarows" );
207- assertTrue ("Expected at least one row:\n " + result , rows .length () > 0 );
208206 for (int i = 0 ; i < rows .length (); i ++) {
209- String id = rows .getJSONArray (i ).getString (0 );
210207 double score = rows .getJSONArray (i ).getDouble (1 );
211- assertTrue (
212- "Row " + i + " id=" + id + " score=" + score + " should be >= 0.5:\n " + result ,
213- score >= 0.5 );
214- assertTrue (
215- "Row " + i + " id=" + id + " should be from cluster A:\n " + result ,
216- id .equals ("1" ) || id .equals ("2" ) || id .equals ("3" ));
208+ assertTrue ("Row " + i + " score=" + score + " should be >= 0.5:\n " + result , score >= 0.5 );
209+ }
210+ assertRowIdsEqual (result , "1" , "2" , "3" );
211+ }
212+
213+ /** Asserts the result's datarows column 0 contains exactly the given ids (as a set). */
214+ private static void assertRowIdsEqual (JSONObject result , String ... expectedIds ) {
215+ JSONArray rows = result .getJSONArray ("datarows" );
216+ assertEquals (
217+ "Expected " + expectedIds .length + " rows:\n " + result , expectedIds .length , rows .length ());
218+ Set <String > expected = new HashSet <>(Arrays .asList (expectedIds ));
219+ Set <String > actual = new HashSet <>();
220+ for (int i = 0 ; i < rows .length (); i ++) {
221+ actual .add (rows .getJSONArray (i ).getString (0 ));
217222 }
223+ assertEquals ("Row id set mismatch:\n " + result , expected , actual );
218224 }
219225}
0 commit comments