Skip to content

Commit 4d43486

Browse files
authored
Fix a bug of SearchResultsWrapper.getRowRecords() that returns wrong data for output fields (#1444)
Signed-off-by: yhmo <yihua.mo@zilliz.com>
1 parent 2ca46bb commit 4d43486

5 files changed

Lines changed: 155 additions & 15 deletions

File tree

examples/src/main/java/io/milvus/v1/CommonUtils.java

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,18 @@ public static List<List<Float>> generateFixFloatVectors(int dimension, int count
7575
return vectors;
7676
}
7777

78+
public static void compareFloatVectors(List<Float> vec1, List<Float> vec2) {
79+
if (vec1.size() != vec2.size()) {
80+
throw new RuntimeException(String.format("Vector dimension mismatch: %d vs %d", vec1.size(), vec2.size()));
81+
}
82+
for (int i = 0; i < vec1.size(); i++) {
83+
if (Math.abs(vec1.get(i) - vec2.get(i)) > 0.001f) {
84+
throw new RuntimeException(String.format("Vector value mismatch: %f vs %f at No.%d value",
85+
vec1.get(i), vec2.get(i), i));
86+
}
87+
}
88+
}
89+
7890
/////////////////////////////////////////////////////////////////////////////////////////////////////
7991
public static ByteBuffer generateBinaryVector(int dimension) {
8092
Random ran = new Random();
@@ -281,5 +293,4 @@ public static List<SortedMap<Long, Float>> generateSparseVectors(int count) {
281293
}
282294
return vectors;
283295
}
284-
285296
}

examples/src/main/java/io/milvus/v1/JsonFieldExample.java

Lines changed: 73 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,25 +26,30 @@
2626
import io.milvus.common.clientenum.ConsistencyLevelEnum;
2727
import io.milvus.grpc.DataType;
2828
import io.milvus.grpc.QueryResults;
29+
import io.milvus.grpc.SearchResults;
2930
import io.milvus.param.*;
3031
import io.milvus.param.collection.*;
3132
import io.milvus.param.dml.InsertParam;
3233
import io.milvus.param.dml.QueryParam;
34+
import io.milvus.param.dml.SearchParam;
3335
import io.milvus.param.index.CreateIndexParam;
3436
import io.milvus.response.QueryResultsWrapper;
37+
import io.milvus.response.SearchResultsWrapper;
3538

39+
import java.util.ArrayList;
3640
import java.util.Arrays;
3741
import java.util.Collections;
3842
import java.util.List;
3943

4044
public class JsonFieldExample {
4145
private static final String COLLECTION_NAME = "java_sdk_example_json_v1";
42-
private static final String ID_FIELD = "id";
46+
private static final String ID_FIELD = "key";
4347
private static final String VECTOR_FIELD = "vector";
4448
private static final String JSON_FIELD = "metadata";
4549
private static final Integer VECTOR_DIM = 128;
4650

4751
private static void queryWithExpr(MilvusClient client, String expr) {
52+
System.out.printf("%n=============================Query with expr: '%s'================================%n", expr);
4853
R<QueryResults> queryRet = client.query(QueryParam.newBuilder()
4954
.withCollectionName(COLLECTION_NAME)
5055
.withExpr(expr)
@@ -56,7 +61,6 @@ private static void queryWithExpr(MilvusClient client, String expr) {
5661
for (QueryResultsWrapper.RowRecord record : records) {
5762
System.out.println(record);
5863
}
59-
System.out.println("=============================================================");
6064
}
6165

6266
public static void main(String[] args) {
@@ -123,22 +127,28 @@ public static void main(String[] args) {
123127
System.out.println("Collection created");
124128

125129
// insert rows
130+
List<List<Float>> vectors = new ArrayList<>();
131+
List<JsonObject> metadatas = new ArrayList<>();
126132
Gson gson = new Gson();
127133
for (int i = 0; i < 100; i++) {
128134
JsonObject row = new JsonObject();
129135
row.addProperty(ID_FIELD, i);
130-
row.add(VECTOR_FIELD, gson.toJsonTree(CommonUtils.generateFloatVector(VECTOR_DIM)));
136+
List<Float> vector = CommonUtils.generateFloatVector(VECTOR_DIM);
137+
row.add(VECTOR_FIELD, gson.toJsonTree(vector));
138+
vectors.add(vector);
131139

132140
// Note: for JSON field, always construct a real JsonObject
133141
// don't use row.addProperty(JSON_FIELD, strContent) since the value is treated as a string, not a JsonObject
134142
JsonObject metadata = new JsonObject();
135-
metadata.addProperty("path", String.format("\\root/abc/path%d", i));
143+
metadata.addProperty("path", String.format("\\root/abc/path_%d", i));
136144
metadata.addProperty("size", i);
137145
if (i%7 == 0) {
138146
metadata.addProperty("special", true);
139147
}
148+
140149
metadata.add("flags", gson.toJsonTree(Arrays.asList(i, i + 1, i + 2)));
141150
row.add(JSON_FIELD, metadata);
151+
metadatas.add(metadata);
142152
// System.out.println(metadata);
143153

144154
// dynamic fields
@@ -165,6 +175,65 @@ public static void main(String[] args) {
165175
long rowCount = (long)queryWrapper.getFieldWrapper("count(*)").getFieldData().get(0);
166176
System.out.printf("%d rows persisted\n", rowCount);
167177

178+
// search and output JSON field
179+
List<List<Float>> searchVectors = new ArrayList<>();
180+
List<JsonObject> expectedMetadatas = new ArrayList<>();
181+
for (int i = 0; i < 10; i++) {
182+
List<Float> targetVector = vectors.get(i);
183+
searchVectors.add(targetVector);
184+
expectedMetadatas.add(metadatas.get(i));
185+
}
186+
R<SearchResults> searchRet = client.search(SearchParam.newBuilder()
187+
.withCollectionName(COLLECTION_NAME)
188+
.withLimit(3L)
189+
.withFloatVectors(searchVectors)
190+
.withVectorFieldName(VECTOR_FIELD)
191+
.addOutField(ID_FIELD)
192+
.addOutField(VECTOR_FIELD)
193+
.addOutField(JSON_FIELD)
194+
.build());
195+
CommonUtils.handleResponseStatus(searchRet);
196+
197+
SearchResultsWrapper resultsWrapper = new SearchResultsWrapper(searchRet.getData().getResults());
198+
System.out.println("\n=============================Search result with IDScore================================");
199+
for (int i = 0; i < 10; i++) {
200+
List<SearchResultsWrapper.IDScore> scores = resultsWrapper.getIDScore(i);
201+
System.out.printf("\nThe result of No.%d target vector:\n", i);
202+
for (SearchResultsWrapper.IDScore score : scores) {
203+
System.out.println(score);
204+
}
205+
long pk = scores.get(0).getLongID();
206+
if (pk != i) {
207+
throw new RuntimeException(String.format("The top1 ID %d is not equal to target vector's ID %d", pk, i));
208+
}
209+
JsonObject metadata = (JsonObject) scores.get(0).get(JSON_FIELD);
210+
if (!metadata.equals(expectedMetadatas.get(i))) {
211+
throw new RuntimeException(String.format("The top1 metadata %s is not equal to target metadata %s",
212+
metadata, expectedMetadatas.get(i)));
213+
}
214+
List<Float> vector = (List<Float>) scores.get(0).get(VECTOR_FIELD);
215+
CommonUtils.compareFloatVectors(vector, searchVectors.get(i));
216+
}
217+
System.out.println("\n=============================Search result with RowRecord================================");
218+
for (int i = 0; i < 10; i++) {
219+
List<QueryResultsWrapper.RowRecord> records = resultsWrapper.getRowRecords(i);
220+
System.out.printf("\nThe result of No.%d target vector:\n", i);
221+
for (QueryResultsWrapper.RowRecord record : records) {
222+
System.out.println(record);
223+
}
224+
long pk = (long)records.get(0).get(ID_FIELD);
225+
if (pk != i) {
226+
throw new RuntimeException(String.format("The top1 ID %d is not equal to target vector's ID %d", pk, i));
227+
}
228+
JsonObject metadata = (JsonObject) records.get(0).get(JSON_FIELD);
229+
if (!metadata.equals(expectedMetadatas.get(i))) {
230+
throw new RuntimeException(String.format("The top1 metadata %s is not equal to target metadata %s",
231+
metadata, expectedMetadatas.get(i)));
232+
}
233+
List<Float> vector = (List<Float>) records.get(0).get(VECTOR_FIELD);
234+
CommonUtils.compareFloatVectors(vector, searchVectors.get(i));
235+
}
236+
168237
// query by filtering JSON
169238
queryWithExpr(client, "exists metadata[\"special\"]");
170239
queryWithExpr(client, "metadata[\"size\"] < 5");

examples/src/main/java/io/milvus/v2/JsonFieldExample.java

Lines changed: 51 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,18 +32,23 @@
3232
import io.milvus.v2.service.collection.request.DropCollectionReq;
3333
import io.milvus.v2.service.vector.request.InsertReq;
3434
import io.milvus.v2.service.vector.request.QueryReq;
35+
import io.milvus.v2.service.vector.request.SearchReq;
36+
import io.milvus.v2.service.vector.request.data.BaseVector;
37+
import io.milvus.v2.service.vector.request.data.FloatVec;
3538
import io.milvus.v2.service.vector.response.QueryResp;
39+
import io.milvus.v2.service.vector.response.SearchResp;
3640

3741
import java.util.*;
3842

3943
public class JsonFieldExample {
4044
private static final String COLLECTION_NAME = "java_sdk_example_json_v2";
41-
private static final String ID_FIELD = "id";
45+
private static final String ID_FIELD = "key";
4246
private static final String VECTOR_FIELD = "vector";
4347
private static final String JSON_FIELD = "metadata";
4448
private static final Integer VECTOR_DIM = 128;
4549

4650
private static void queryWithExpr(MilvusClientV2 client, String expr) {
51+
System.out.printf("%n=============================Query with expr: '%s'================================%n", expr);
4752
QueryResp queryRet = client.query(QueryReq.builder()
4853
.collectionName(COLLECTION_NAME)
4954
.filter(expr)
@@ -54,7 +59,6 @@ private static void queryWithExpr(MilvusClientV2 client, String expr) {
5459
for (QueryResp.QueryResult record : records) {
5560
System.out.println(record.getEntity());
5661
}
57-
System.out.println("=============================================================");
5862
}
5963

6064
public static void main(String[] args) {
@@ -104,22 +108,27 @@ public static void main(String[] args) {
104108
System.out.println("Collection created");
105109

106110
// Insert rows
111+
List<List<Float>> vectors = new ArrayList<>();
112+
List<JsonObject> metadatas = new ArrayList<>();
107113
Gson gson = new Gson();
108114
for (int i = 0; i < 100; i++) {
109115
JsonObject row = new JsonObject();
110116
row.addProperty(ID_FIELD, i);
111-
row.add(VECTOR_FIELD, gson.toJsonTree(CommonUtils.generateFloatVector(VECTOR_DIM)));
117+
List<Float> vector = CommonUtils.generateFloatVector(VECTOR_DIM);
118+
row.add(VECTOR_FIELD, gson.toJsonTree(vector));
119+
vectors.add(vector);
112120

113121
// Note: for JSON field, always construct a real JsonObject
114122
// don't use row.addProperty(JSON_FIELD, strContent) since the value is treated as a string, not a JsonObject
115123
JsonObject metadata = new JsonObject();
116-
metadata.addProperty("path", String.format("\\root/abc/path%d", i));
124+
metadata.addProperty("path", String.format("\\root/abc/path_%d", i));
117125
metadata.addProperty("size", i);
118126
if (i%7 == 0) {
119127
metadata.addProperty("special", true);
120128
}
121129
metadata.add("flags", gson.toJsonTree(Arrays.asList(i, i + 1, i + 2)));
122130
row.add(JSON_FIELD, metadata);
131+
metadatas.add(metadata);
123132
// System.out.println(metadata);
124133

125134
// dynamic fields
@@ -144,6 +153,44 @@ public static void main(String[] args) {
144153
.build());
145154
System.out.printf("%d rows persisted\n", (long)countR.getQueryResults().get(0).getEntity().get("count(*)"));
146155

156+
// Search and output JSON field
157+
List<BaseVector> searchVectors = new ArrayList<>();
158+
List<JsonObject> expectedMetadatas = new ArrayList<>();
159+
for (int i = 0; i < 10; i++) {
160+
List<Float> targetVector = vectors.get(i);
161+
searchVectors.add(new FloatVec(targetVector));
162+
expectedMetadatas.add(metadatas.get(i));
163+
}
164+
SearchResp searchRet = client.search(SearchReq.builder()
165+
.collectionName(COLLECTION_NAME)
166+
.data(searchVectors)
167+
.limit(3L)
168+
.annsField(VECTOR_FIELD)
169+
.outputFields(Arrays.asList(ID_FIELD, VECTOR_FIELD, JSON_FIELD))
170+
.build());
171+
172+
System.out.println("\n=============================Search result================================");
173+
List<List<SearchResp.SearchResult>> searchResults = searchRet.getSearchResults();
174+
for (int i = 0; i < 10; i++) {
175+
List<SearchResp.SearchResult> results = searchResults.get(i);
176+
System.out.printf("\nThe result of No.%d target vector:\n", i);
177+
for (SearchResp.SearchResult result : results) {
178+
System.out.println(result);
179+
}
180+
181+
long pk = (long)results.get(0).getId();
182+
if (pk != i) {
183+
throw new RuntimeException(String.format("The top1 ID %d is not equal to target vector's ID %d", pk, i));
184+
}
185+
JsonObject metadata = (JsonObject) results.get(0).getEntity().get(JSON_FIELD);
186+
if (!metadata.equals(expectedMetadatas.get(i))) {
187+
throw new RuntimeException(String.format("The top1 metadata %s is not equal to target metadata %s",
188+
metadata, expectedMetadatas.get(i)));
189+
}
190+
List<Float> vector = (List<Float>) results.get(0).getEntity().get(VECTOR_FIELD);
191+
CommonUtils.compareFloatVectors(vector, (List<Float>)searchVectors.get(i).getData());
192+
}
193+
147194
// Query by filtering JSON
148195
queryWithExpr(client, "exists metadata[\"special\"]");
149196
queryWithExpr(client, "metadata[\"size\"] < 5");

sdk-core/src/main/java/io/milvus/response/SearchResultsWrapper.java

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,11 @@
4040
public class SearchResultsWrapper extends RowRecordWrapper {
4141
private final SearchResultData results;
4242

43+
private String primaryKey = "id";
44+
4345
public SearchResultsWrapper(@NonNull SearchResultData results) {
4446
this.results = results;
47+
this.primaryKey = results.getPrimaryFieldName();
4548
}
4649

4750
/**
@@ -86,13 +89,13 @@ public List<QueryResultsWrapper.RowRecord> getRowRecords(int indexOfTarget) {
8689
IDScore score = idScore.get(i);
8790
QueryResultsWrapper.RowRecord record = new QueryResultsWrapper.RowRecord();
8891
if (score.getStrID().isEmpty()) {
89-
record.put("id", score.getLongID());
92+
record.put(primaryKey, score.getLongID());
9093
} else {
91-
record.put("id", score.getStrID());
94+
record.put(primaryKey, score.getStrID());
9295
}
9396

9497
record.put("score", score.getScore()); // use score instead
95-
buildRowRecord(record, i);
98+
buildRowRecord(record, indexOfTarget*topK + (long)i);
9699
records.add(record);
97100
}
98101
return records;
@@ -162,15 +165,14 @@ public List<IDScore> getIDScore(int indexOfTarget) throws ParamException, Illega
162165

163166
// set id and score
164167
IDs ids = results.getIds();
165-
String pkName = results.getPrimaryFieldName();
166168
if (ids.hasIntId()) {
167169
LongArray longIDs = ids.getIntId();
168170
if (offset + k > longIDs.getDataCount()) {
169171
throw new IllegalResponseException("Result ids count is wrong");
170172
}
171173

172174
for (int n = 0; n < k; ++n) {
173-
idScores.add(new IDScore(pkName, "", longIDs.getData((int)offset + n), results.getScores((int)offset + n)));
175+
idScores.add(new IDScore(primaryKey, "", longIDs.getData((int)offset + n), results.getScores((int)offset + n)));
174176
}
175177
} else if (ids.hasStrId()) {
176178
StringArray strIDs = ids.getStrId();
@@ -179,7 +181,7 @@ public List<IDScore> getIDScore(int indexOfTarget) throws ParamException, Illega
179181
}
180182

181183
for (int n = 0; n < k; ++n) {
182-
idScores.add(new IDScore(pkName, strIDs.getData((int)offset + n), 0, results.getScores((int)offset + n)));
184+
idScores.add(new IDScore(primaryKey, strIDs.getData((int)offset + n), 0, results.getScores((int)offset + n)));
183185
}
184186
} else {
185187
// in v2.3.3, return an empty list instead of throwing exception

sdk-core/src/test/java/io/milvus/client/MilvusClientDockerTest.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -678,6 +678,17 @@ void testFloatVectors() {
678678
for (int k = 0; k < outputVec.size(); k++) {
679679
Assertions.assertEquals(targetVectors.get(i).get(k), outputVec.get(k));
680680
}
681+
682+
// verify the old way
683+
List<QueryResultsWrapper.RowRecord> records = results.getRowRecords(i);
684+
obj = records.get(0).get(DataType.FloatVector.name());
685+
outputVec = (List<Float>)obj;
686+
Assertions.assertEquals(targetVectors.get(i).size(), outputVec.size());
687+
for (int k = 0; k < outputVec.size(); k++) {
688+
Assertions.assertEquals(targetVectors.get(i).get(k), outputVec.get(k));
689+
}
690+
double d = (double)records.get(0).get(DataType.Double.name());
691+
Assertions.assertEquals(d, compareWeights.get(i));
681692
}
682693

683694
List<?> fieldData = results.getFieldData(DataType.Double.name(), 0);

0 commit comments

Comments
 (0)