diff --git a/examples/src/main/java/io/milvus/v1/BinaryVectorExample.java b/examples/src/main/java/io/milvus/v1/BinaryVectorExample.java index b561149f1..7a4c6f4a7 100644 --- a/examples/src/main/java/io/milvus/v1/BinaryVectorExample.java +++ b/examples/src/main/java/io/milvus/v1/BinaryVectorExample.java @@ -38,10 +38,10 @@ public class BinaryVectorExample { private static final String COLLECTION_NAME = "java_sdk_example_binary_vector_v1"; - private static final String ID_FIELD = "id"; + private static final String ID_FIELD = "pk"; private static final String VECTOR_FIELD = "vector"; - private static final Integer VECTOR_DIM = 512; + private static final Integer VECTOR_DIM = 128; public static void main(String[] args) { @@ -152,6 +152,8 @@ public static void main(String[] args) { Random ran = new Random(); int k = ran.nextInt(rowCount); ByteBuffer targetVector = vectors.get(k); + System.out.printf("\nANN search for vector ID=%d:\n", k); + CommonUtils.printBinaryVector(targetVector); R searchRet = milvusClient.search(SearchParam.newBuilder() .withCollectionName(COLLECTION_NAME) .withMetricType(MetricType.HAMMING) @@ -169,13 +171,9 @@ public static void main(String[] args) { List scores = resultsWrapper.getIDScore(0); System.out.printf("The result of No.%d target vector:\n", i); for (SearchResultsWrapper.IDScore score : scores) { - System.out.printf("ID: %d, Score: %f, Vector: ", score.getLongID(), score.getScore()); + System.out.println(score); ByteBuffer vector = (ByteBuffer)score.get(VECTOR_FIELD); - vector.rewind(); - while (vector.hasRemaining()) { - System.out.print(Integer.toBinaryString(vector.get())); - } - System.out.println(); + CommonUtils.printBinaryVector(vector); } if (scores.get(0).getLongID() != k) { throw new RuntimeException(String.format("The top1 ID %d is not equal to target vector's ID %d", @@ -188,7 +186,7 @@ public static void main(String[] args) { int n = 99; R queryR = milvusClient.query(QueryParam.newBuilder() .withCollectionName(COLLECTION_NAME) - .withExpr(String.format("id == %d", n)) + .withExpr(String.format("%s == %d", ID_FIELD, n)) .addOutField(VECTOR_FIELD) .build()); CommonUtils.handleResponseStatus(queryR); diff --git a/examples/src/main/java/io/milvus/v1/CommonUtils.java b/examples/src/main/java/io/milvus/v1/CommonUtils.java index a73efe9aa..d12ab2f06 100644 --- a/examples/src/main/java/io/milvus/v1/CommonUtils.java +++ b/examples/src/main/java/io/milvus/v1/CommonUtils.java @@ -96,6 +96,15 @@ public static List generateBinaryVectors(int dimension, int count) { return vectors; } + public static void printBinaryVector(ByteBuffer vector) { + vector.rewind(); + while (vector.hasRemaining()) { + String byteStr = String.format("%8s", Integer.toBinaryString(vector.get())).replace(' ', '0'); + System.out.print(byteStr); + } + System.out.println(); + } + ///////////////////////////////////////////////////////////////////////////////////////////////////// public static TBfloat16 genTensorflowBF16Vector(int dimension) { Random ran = new Random(); @@ -135,7 +144,7 @@ public static List encodeTensorBF16Vectors(List vectors) return buffers; } - public static TBfloat16 decodeTensorBF16Vector(ByteBuffer buf) { + public static TBfloat16 decodeBF16VectorToTensor(ByteBuffer buf) { if (buf.limit()%2 != 0) { return null; } @@ -144,6 +153,15 @@ public static TBfloat16 decodeTensorBF16Vector(ByteBuffer buf) { return Tensor.of(TBfloat16.class, Shape.of(dim), bf); } + public static List decodeBF16VectorToFloat(ByteBuffer buf) { + List vector = new ArrayList<>(); + TBfloat16 tf = decodeBF16VectorToTensor(buf); + for (long i = 0; i < tf.size(); i++) { + vector.add(tf.getFloat(i)); + } + return vector; + } + public static TFloat16 genTensorflowFP16Vector(int dimension) { Random ran = new Random(); @@ -183,7 +201,7 @@ public static List encodeTensorFP16Vectors(List vectors) { return buffers; } - public static TFloat16 decodeTensorFP16Vector(ByteBuffer buf) { + public static TFloat16 decodeFP16VectorToTensor(ByteBuffer buf) { if (buf.limit()%2 != 0) { return null; } @@ -192,6 +210,15 @@ public static TFloat16 decodeTensorFP16Vector(ByteBuffer buf) { return Tensor.of(TFloat16.class, Shape.of(dim), bf); } + public static List decodeFP16VectorToFloat(ByteBuffer buf) { + List vector = new ArrayList<>(); + TFloat16 tf = decodeFP16VectorToTensor(buf); + for (long i = 0; i < tf.size(); i++) { + vector.add(tf.getFloat(i)); + } + return vector; + } + ///////////////////////////////////////////////////////////////////////////////////////////////////// public static ByteBuffer encodeFloat16Vector(List originVector, boolean bfloat16) { if (bfloat16) { diff --git a/examples/src/main/java/io/milvus/v1/Float16VectorExample.java b/examples/src/main/java/io/milvus/v1/Float16VectorExample.java index 957f2d979..596561753 100644 --- a/examples/src/main/java/io/milvus/v1/Float16VectorExample.java +++ b/examples/src/main/java/io/milvus/v1/Float16VectorExample.java @@ -201,9 +201,6 @@ private static void testFloat16(boolean bfloat16) { SearchResultsWrapper resultsWrapper = new SearchResultsWrapper(searchRet.getData().getResults()); List scores = resultsWrapper.getIDScore(0); System.out.printf("The result of No.%d target vector:\n", i); - for (SearchResultsWrapper.IDScore score : scores) { - System.out.println(score); - } SearchResultsWrapper.IDScore firstScore = scores.get(0); if (firstScore.getLongID() != k) { @@ -223,6 +220,9 @@ private static void testFloat16(boolean bfloat16) { throw new RuntimeException(String.format("The output vector is not equal to original vector: ID %d", k)); } } + System.out.println("\nTarget vector: " + originVector); + System.out.println("Top0 result: " + firstScore); + System.out.println("Top0 result vector: " + outputVector); } System.out.println("Search result is correct"); @@ -316,19 +316,13 @@ private static void testTensorflowFloat16(boolean bfloat16) { throw new RuntimeException("The query result is incorrect"); } - List vector = new ArrayList<>(); + List outVector; if (bfloat16) { - TBfloat16 tf = CommonUtils.decodeTensorBF16Vector(outputBuf); - for (long i = 0; i < tf.size(); i++) { - vector.add(tf.getFloat(i)); - } + outVector = CommonUtils.decodeBF16VectorToFloat(outputBuf); } else { - TFloat16 tf = CommonUtils.decodeTensorFP16Vector(outputBuf); - for (long i = 0; i < tf.size(); i++) { - vector.add(tf.getFloat(i)); - } + outVector = CommonUtils.decodeFP16VectorToFloat(outputBuf); } - System.out.println(vector); + System.out.println("Output vector: " + outVector); System.out.println("Query result is correct"); // drop the collection if you don't need the collection anymore diff --git a/examples/src/main/java/io/milvus/v1/GeneralExample.java b/examples/src/main/java/io/milvus/v1/GeneralExample.java index b279fc87c..52b959838 100644 --- a/examples/src/main/java/io/milvus/v1/GeneralExample.java +++ b/examples/src/main/java/io/milvus/v1/GeneralExample.java @@ -349,9 +349,9 @@ private R searchFace(String expr) { for (int i = 0; i < vectors.size(); ++i) { System.out.println("Search result of No." + i); List scores = wrapper.getIDScore(i); - System.out.println(scores); - System.out.println("Output field data for No." + i); - System.out.println(wrapper.getFieldData(AGE_FIELD, i)); + for (SearchResultsWrapper.IDScore score : scores) { + System.out.println(score); + } } return response; diff --git a/examples/src/main/java/io/milvus/v1/SparseVectorExample.java b/examples/src/main/java/io/milvus/v1/SparseVectorExample.java index f89924968..58277678b 100644 --- a/examples/src/main/java/io/milvus/v1/SparseVectorExample.java +++ b/examples/src/main/java/io/milvus/v1/SparseVectorExample.java @@ -149,6 +149,7 @@ public static void main(String[] args) { Random ran = new Random(); int k = ran.nextInt(rowCount); SortedMap targetVector = vectors.get(k); + System.out.println("\nTarget vector: " + targetVector); R searchRet = milvusClient.search(SearchParam.newBuilder() .withCollectionName(COLLECTION_NAME) .withMetricType(MetricType.IP) diff --git a/examples/src/main/java/io/milvus/v2/BinaryVectorExample.java b/examples/src/main/java/io/milvus/v2/BinaryVectorExample.java index c962e2c75..8cb1a5e02 100644 --- a/examples/src/main/java/io/milvus/v2/BinaryVectorExample.java +++ b/examples/src/main/java/io/milvus/v2/BinaryVectorExample.java @@ -42,10 +42,10 @@ public class BinaryVectorExample { private static final String COLLECTION_NAME = "java_sdk_example_binary_vector_v2"; - private static final String ID_FIELD = "id"; + private static final String ID_FIELD = "pk"; private static final String VECTOR_FIELD = "vector"; - private static final Integer VECTOR_DIM = 512; + private static final Integer VECTOR_DIM = 128; public static void main(String[] args) { @@ -126,6 +126,8 @@ public static void main(String[] args) { Random ran = new Random(); int k = ran.nextInt(rowCount); ByteBuffer targetVector = vectors.get(k); + System.out.printf("\nANN search for vector ID=%d:\n", k); + CommonUtils.printBinaryVector(targetVector); Map params = new HashMap<>(); params.put("nprobe",16); SearchResp searchResp = client.search(SearchReq.builder() @@ -141,16 +143,11 @@ public static void main(String[] args) { // Here we only input one vector to search, get the result of No.0 vector to check List> searchResults = searchResp.getSearchResults(); List results = searchResults.get(0); - System.out.printf("The result of No.%d target vector:\n", i); + System.out.printf("The result of No.%d target vector, ID=%d:\n", i, k); for (SearchResp.SearchResult result : results) { - System.out.println(result.getEntity()); - System.out.printf("ID: %d, Score: %f, Vector: ", result.getId(), result.getScore()); + System.out.println(result); ByteBuffer vector = (ByteBuffer) result.getEntity().get(VECTOR_FIELD); - vector.rewind(); - while (vector.hasRemaining()) { - System.out.print(Integer.toBinaryString(vector.get())); - } - System.out.println(); + CommonUtils.printBinaryVector(vector); } SearchResp.SearchResult firstResult = results.get(0); @@ -165,7 +162,7 @@ public static void main(String[] args) { int n = 99; QueryResp queryResp = client.query(QueryReq.builder() .collectionName(COLLECTION_NAME) - .filter(String.format("id == %d", n)) + .filter(String.format("%s == %d", ID_FIELD, n)) .outputFields(Collections.singletonList(VECTOR_FIELD)) .build()); diff --git a/examples/src/main/java/io/milvus/v2/Float16VectorExample.java b/examples/src/main/java/io/milvus/v2/Float16VectorExample.java index 0ae38293b..cf522f4c7 100644 --- a/examples/src/main/java/io/milvus/v2/Float16VectorExample.java +++ b/examples/src/main/java/io/milvus/v2/Float16VectorExample.java @@ -162,9 +162,12 @@ private static void searchVectors(List taargetIDs, List target } Map entity = topResult.getEntity(); ByteBuffer vectorBuf = (ByteBuffer) entity.get(vectorFieldName); - if (!vectorBuf.equals(targetVectors.get(i).getData())) { + ByteBuffer targetVectorBuf = (ByteBuffer)targetVectors.get(i).getData(); + if (!vectorBuf.equals(targetVectorBuf)) { throw new RuntimeException("The top1 output vector is incorrect"); } + List decodedTargetVector = CommonUtils.decodeFloat16Vector(targetVectorBuf, + BF16_VECTOR_FIELD.equals(vectorFieldName)); // The method for converting float16 vector to float32 vector can be found in // CommonUtils. List decodedFpVector = CommonUtils.decodeFloat16Vector(vectorBuf, @@ -172,7 +175,9 @@ private static void searchVectors(List taargetIDs, List target if (decodedFpVector.size() != VECTOR_DIM) { throw new RuntimeException("The decoded vector dimension is incorrect"); } - System.out.println(results.get(0)); + System.out.println("\nTarget vector: " + decodedTargetVector); + System.out.println("Top0 result: " + topResult); + System.out.println("Top0 result vector: " + decodedFpVector); } System.out.println("Search result of " + vectorFieldName + " is correct"); } diff --git a/examples/src/main/java/io/milvus/v2/FullTextSearchExample.java b/examples/src/main/java/io/milvus/v2/FullTextSearchExample.java index 6cc622dd8..ca1a8f96c 100644 --- a/examples/src/main/java/io/milvus/v2/FullTextSearchExample.java +++ b/examples/src/main/java/io/milvus/v2/FullTextSearchExample.java @@ -38,7 +38,7 @@ private static void searchByText(MilvusClientV2 client, String text) { List> searchResults = searchResp.getSearchResults(); for (List results : searchResults) { for (SearchResp.SearchResult result : results) { - System.out.printf("ID: %d, Score: %f, %s\n", (long)result.getId(), result.getScore(), result.getEntity().toString()); + System.out.println(result); } } System.out.println("============================================================="); diff --git a/examples/src/main/java/io/milvus/v2/GeneralExample.java b/examples/src/main/java/io/milvus/v2/GeneralExample.java index b82d3f108..16ee02ab0 100644 --- a/examples/src/main/java/io/milvus/v2/GeneralExample.java +++ b/examples/src/main/java/io/milvus/v2/GeneralExample.java @@ -205,7 +205,7 @@ private static void searchFace(String filter) { for (List results : searchResults) { System.out.println("Search result of No." + i++); for (SearchResp.SearchResult result : results) { - System.out.printf("ID: %s, Score: %f, %s\n", result.getId(), result.getScore(), result.getEntity().toString()); + System.out.println(result); } } } diff --git a/examples/src/main/java/io/milvus/v2/HybridSearchExample.java b/examples/src/main/java/io/milvus/v2/HybridSearchExample.java index 7fc62f5e1..73a37977c 100644 --- a/examples/src/main/java/io/milvus/v2/HybridSearchExample.java +++ b/examples/src/main/java/io/milvus/v2/HybridSearchExample.java @@ -223,7 +223,7 @@ private void hybridSearch() { System.out.printf("============= Search result of No.%d vector =============\n", i); List results = searchResults.get(i); for (SearchResp.SearchResult result : results) { - System.out.printf("{id: %d, score: %f}%n", result.getId(), result.getScore()); + System.out.println(result); } } } diff --git a/examples/src/main/java/io/milvus/v2/Int8VectorExample.java b/examples/src/main/java/io/milvus/v2/Int8VectorExample.java index 3df579d8a..f77dc8b19 100644 --- a/examples/src/main/java/io/milvus/v2/Int8VectorExample.java +++ b/examples/src/main/java/io/milvus/v2/Int8VectorExample.java @@ -136,10 +136,9 @@ public static void main(String[] args) { List results = searchResults.get(0); System.out.printf("\nThe result of No.%d vector %s:\n", k, Arrays.toString(targetVector.array())); for (SearchResp.SearchResult result : results) { - System.out.printf("ID: %d, Score: %f, Vector: ", (long)result.getId(), result.getScore()); + System.out.println(result); ByteBuffer vector = (ByteBuffer) result.getEntity().get(VECTOR_FIELD); - System.out.print(Arrays.toString(vector.array())); - System.out.println(); + System.out.println(Arrays.toString(vector.array())); } SearchResp.SearchResult firstResult = results.get(0); diff --git a/examples/src/main/java/io/milvus/v2/SparseVectorExample.java b/examples/src/main/java/io/milvus/v2/SparseVectorExample.java index f408c85dc..de129963f 100644 --- a/examples/src/main/java/io/milvus/v2/SparseVectorExample.java +++ b/examples/src/main/java/io/milvus/v2/SparseVectorExample.java @@ -119,6 +119,7 @@ public static void main(String[] args) { Random ran = new Random(); int k = ran.nextInt(rowCount); SortedMap targetVector = vectors.get(k); + System.out.println("\nTarget vector: " + targetVector); Map params = new HashMap<>(); params.put("drop_ratio_search",0.2); SearchResp searchResp = client.search(SearchReq.builder() @@ -136,7 +137,7 @@ public static void main(String[] args) { List results = searchResults.get(0); System.out.printf("The result of No.%d target vector:\n", i); for (SearchResp.SearchResult result : results) { - System.out.println(result.getEntity()); + System.out.println(result); } SearchResp.SearchResult firstResult = results.get(0); diff --git a/examples/src/main/java/io/milvus/v2/TextMatchExample.java b/examples/src/main/java/io/milvus/v2/TextMatchExample.java index 15cdbebb9..963f701de 100644 --- a/examples/src/main/java/io/milvus/v2/TextMatchExample.java +++ b/examples/src/main/java/io/milvus/v2/TextMatchExample.java @@ -36,7 +36,7 @@ private static void queryWithFilter(MilvusClientV2 client, String filter) { System.out.println("\nQuery with filter: " + filter); List records = queryRet.getQueryResults(); for (QueryResp.QueryResult record : records) { - System.out.println(record.getEntity()); + System.out.println(record); } System.out.printf("%d items matched%n", records.size()); System.out.println("============================================================="); diff --git a/sdk-core/src/main/java/io/milvus/response/SearchResultsWrapper.java b/sdk-core/src/main/java/io/milvus/response/SearchResultsWrapper.java index 8e9bc51a7..2242b13a5 100644 --- a/sdk-core/src/main/java/io/milvus/response/SearchResultsWrapper.java +++ b/sdk-core/src/main/java/io/milvus/response/SearchResultsWrapper.java @@ -27,6 +27,7 @@ import io.milvus.response.basic.RowRecordWrapper; import lombok.Getter; import lombok.NonNull; +import org.apache.commons.lang3.StringUtils; import java.util.ArrayList; import java.util.HashMap; @@ -161,6 +162,7 @@ public List getIDScore(int indexOfTarget) throws ParamException, Illega // set id and score IDs ids = results.getIds(); + String pkName = results.getPrimaryFieldName(); if (ids.hasIntId()) { LongArray longIDs = ids.getIntId(); if (offset + k > longIDs.getDataCount()) { @@ -168,7 +170,7 @@ public List getIDScore(int indexOfTarget) throws ParamException, Illega } for (int n = 0; n < k; ++n) { - idScores.add(new IDScore("", longIDs.getData((int)offset + n), results.getScores((int)offset + n))); + idScores.add(new IDScore(pkName, "", longIDs.getData((int)offset + n), results.getScores((int)offset + n))); } } else if (ids.hasStrId()) { StringArray strIDs = ids.getStrId(); @@ -177,7 +179,7 @@ public List getIDScore(int indexOfTarget) throws ParamException, Illega } for (int n = 0; n < k; ++n) { - idScores.add(new IDScore(strIDs.getData((int)offset + n), 0, results.getScores((int)offset + n))); + idScores.add(new IDScore(pkName, strIDs.getData((int)offset + n), 0, results.getScores((int)offset + n))); } } else { // in v2.3.3, return an empty list instead of throwing exception @@ -272,12 +274,14 @@ private Position getOffsetByIndex(int indexOfTarget) { */ @Getter public static final class IDScore { + private final String primaryKey; private final String strID; private final long longID; private final float score; Map fieldValues = new HashMap<>(); - public IDScore(String strID, long longID, float score) { + public IDScore(String primaryKey, String strID, long longID, float score) { + this.primaryKey = primaryKey; this.strID = strID; this.longID = longID; this.score = score; @@ -333,16 +337,12 @@ public boolean contains(String keyName) { @Override public String toString() { - List pairs = new ArrayList<>(); - fieldValues.forEach((keyName, fieldValue) -> { - pairs.add(keyName + ":" + fieldValue); - }); - - if (strID.isEmpty()) { - return "(ID: " + getLongID() + " Score: " + getScore() + " OutputFields: " + pairs + ")"; - } else { - return "(ID: '" + getStrID() + "' Score: " + getScore()+ " OutputFields: " + pairs + ")"; + Object id = strID; + if (StringUtils.isEmpty(strID)) { + id = longID; } + + return "{" + getPrimaryKey() + ": " + id + ", Score: " + getScore() + ", OutputFields: " + fieldValues + "}"; } } } diff --git a/sdk-core/src/main/java/io/milvus/v2/service/vector/response/SearchResp.java b/sdk-core/src/main/java/io/milvus/v2/service/vector/response/SearchResp.java index c0ff3601a..193e8ce9a 100644 --- a/sdk-core/src/main/java/io/milvus/v2/service/vector/response/SearchResp.java +++ b/sdk-core/src/main/java/io/milvus/v2/service/vector/response/SearchResp.java @@ -45,5 +45,12 @@ public static class SearchResult { private Map entity = new HashMap<>(); private Float score; private Object id; + @Builder.Default + private String primaryKey = "id"; + + @Override + public String toString() { + return "{" + getPrimaryKey() + ": " + getId() + ", Score: " + getScore() + ", OutputFields: " + entity + "}"; + } } } diff --git a/sdk-core/src/main/java/io/milvus/v2/utils/ConvertUtils.java b/sdk-core/src/main/java/io/milvus/v2/utils/ConvertUtils.java index 9c7e19b92..8bac91fed 100644 --- a/sdk-core/src/main/java/io/milvus/v2/utils/ConvertUtils.java +++ b/sdk-core/src/main/java/io/milvus/v2/utils/ConvertUtils.java @@ -75,6 +75,7 @@ public List> getEntities(SearchResults response) { searchResults.add(searchResultsWrapper.getIDScore(i).stream().map(idScore -> SearchResp.SearchResult.builder() .entity(idScore.getFieldValues()) .score(idScore.getScore()) + .primaryKey(idScore.getPrimaryKey()) .id(idScore.getStrID().isEmpty() ? idScore.getLongID() : idScore.getStrID()) .build()).collect(Collectors.toList())); }