Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 7 additions & 9 deletions examples/src/main/java/io/milvus/v1/BinaryVectorExample.java
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,10 @@

public class BinaryVectorExample {
private static final String COLLECTION_NAME = "java_sdk_example_binary_vector_v1";
private static final String ID_FIELD = "id";
private static final String ID_FIELD = "pk";
private static final String VECTOR_FIELD = "vector";

private static final Integer VECTOR_DIM = 512;
private static final Integer VECTOR_DIM = 128;


public static void main(String[] args) {
Expand Down Expand Up @@ -152,6 +152,8 @@ public static void main(String[] args) {
Random ran = new Random();
int k = ran.nextInt(rowCount);
ByteBuffer targetVector = vectors.get(k);
System.out.printf("\nANN search for vector ID=%d:\n", k);
CommonUtils.printBinaryVector(targetVector);
R<SearchResults> searchRet = milvusClient.search(SearchParam.newBuilder()
.withCollectionName(COLLECTION_NAME)
.withMetricType(MetricType.HAMMING)
Expand All @@ -169,13 +171,9 @@ public static void main(String[] args) {
List<SearchResultsWrapper.IDScore> scores = resultsWrapper.getIDScore(0);
System.out.printf("The result of No.%d target vector:\n", i);
for (SearchResultsWrapper.IDScore score : scores) {
System.out.printf("ID: %d, Score: %f, Vector: ", score.getLongID(), score.getScore());
System.out.println(score);
ByteBuffer vector = (ByteBuffer)score.get(VECTOR_FIELD);
vector.rewind();
while (vector.hasRemaining()) {
System.out.print(Integer.toBinaryString(vector.get()));
}
System.out.println();
CommonUtils.printBinaryVector(vector);
}
if (scores.get(0).getLongID() != k) {
throw new RuntimeException(String.format("The top1 ID %d is not equal to target vector's ID %d",
Expand All @@ -188,7 +186,7 @@ public static void main(String[] args) {
int n = 99;
R<QueryResults> queryR = milvusClient.query(QueryParam.newBuilder()
.withCollectionName(COLLECTION_NAME)
.withExpr(String.format("id == %d", n))
.withExpr(String.format("%s == %d", ID_FIELD, n))
.addOutField(VECTOR_FIELD)
.build());
CommonUtils.handleResponseStatus(queryR);
Expand Down
31 changes: 29 additions & 2 deletions examples/src/main/java/io/milvus/v1/CommonUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,15 @@ public static List<ByteBuffer> generateBinaryVectors(int dimension, int count) {
return vectors;
}

public static void printBinaryVector(ByteBuffer vector) {
vector.rewind();
while (vector.hasRemaining()) {
String byteStr = String.format("%8s", Integer.toBinaryString(vector.get())).replace(' ', '0');
System.out.print(byteStr);
}
System.out.println();
}

/////////////////////////////////////////////////////////////////////////////////////////////////////
public static TBfloat16 genTensorflowBF16Vector(int dimension) {
Random ran = new Random();
Expand Down Expand Up @@ -135,7 +144,7 @@ public static List<ByteBuffer> encodeTensorBF16Vectors(List<TBfloat16> vectors)
return buffers;
}

public static TBfloat16 decodeTensorBF16Vector(ByteBuffer buf) {
public static TBfloat16 decodeBF16VectorToTensor(ByteBuffer buf) {
if (buf.limit()%2 != 0) {
return null;
}
Expand All @@ -144,6 +153,15 @@ public static TBfloat16 decodeTensorBF16Vector(ByteBuffer buf) {
return Tensor.of(TBfloat16.class, Shape.of(dim), bf);
}

public static List<Float> decodeBF16VectorToFloat(ByteBuffer buf) {
List<Float> vector = new ArrayList<>();
TBfloat16 tf = decodeBF16VectorToTensor(buf);
for (long i = 0; i < tf.size(); i++) {
vector.add(tf.getFloat(i));
}
return vector;
}


public static TFloat16 genTensorflowFP16Vector(int dimension) {
Random ran = new Random();
Expand Down Expand Up @@ -183,7 +201,7 @@ public static List<ByteBuffer> encodeTensorFP16Vectors(List<TFloat16> vectors) {
return buffers;
}

public static TFloat16 decodeTensorFP16Vector(ByteBuffer buf) {
public static TFloat16 decodeFP16VectorToTensor(ByteBuffer buf) {
if (buf.limit()%2 != 0) {
return null;
}
Expand All @@ -192,6 +210,15 @@ public static TFloat16 decodeTensorFP16Vector(ByteBuffer buf) {
return Tensor.of(TFloat16.class, Shape.of(dim), bf);
}

public static List<Float> decodeFP16VectorToFloat(ByteBuffer buf) {
List<Float> vector = new ArrayList<>();
TFloat16 tf = decodeFP16VectorToTensor(buf);
for (long i = 0; i < tf.size(); i++) {
vector.add(tf.getFloat(i));
}
return vector;
}

/////////////////////////////////////////////////////////////////////////////////////////////////////
public static ByteBuffer encodeFloat16Vector(List<Float> originVector, boolean bfloat16) {
if (bfloat16) {
Expand Down
20 changes: 7 additions & 13 deletions examples/src/main/java/io/milvus/v1/Float16VectorExample.java
Original file line number Diff line number Diff line change
Expand Up @@ -201,9 +201,6 @@ private static void testFloat16(boolean bfloat16) {
SearchResultsWrapper resultsWrapper = new SearchResultsWrapper(searchRet.getData().getResults());
List<SearchResultsWrapper.IDScore> scores = resultsWrapper.getIDScore(0);
System.out.printf("The result of No.%d target vector:\n", i);
for (SearchResultsWrapper.IDScore score : scores) {
System.out.println(score);
}

SearchResultsWrapper.IDScore firstScore = scores.get(0);
if (firstScore.getLongID() != k) {
Expand All @@ -223,6 +220,9 @@ private static void testFloat16(boolean bfloat16) {
throw new RuntimeException(String.format("The output vector is not equal to original vector: ID %d", k));
}
}
System.out.println("\nTarget vector: " + originVector);
System.out.println("Top0 result: " + firstScore);
System.out.println("Top0 result vector: " + outputVector);
}
System.out.println("Search result is correct");

Expand Down Expand Up @@ -316,19 +316,13 @@ private static void testTensorflowFloat16(boolean bfloat16) {
throw new RuntimeException("The query result is incorrect");
}

List<Float> vector = new ArrayList<>();
List<Float> outVector;
if (bfloat16) {
TBfloat16 tf = CommonUtils.decodeTensorBF16Vector(outputBuf);
for (long i = 0; i < tf.size(); i++) {
vector.add(tf.getFloat(i));
}
outVector = CommonUtils.decodeBF16VectorToFloat(outputBuf);
} else {
TFloat16 tf = CommonUtils.decodeTensorFP16Vector(outputBuf);
for (long i = 0; i < tf.size(); i++) {
vector.add(tf.getFloat(i));
}
outVector = CommonUtils.decodeFP16VectorToFloat(outputBuf);
}
System.out.println(vector);
System.out.println("Output vector: " + outVector);
System.out.println("Query result is correct");

// drop the collection if you don't need the collection anymore
Expand Down
6 changes: 3 additions & 3 deletions examples/src/main/java/io/milvus/v1/GeneralExample.java
Original file line number Diff line number Diff line change
Expand Up @@ -349,9 +349,9 @@ private R<SearchResults> searchFace(String expr) {
for (int i = 0; i < vectors.size(); ++i) {
System.out.println("Search result of No." + i);
List<SearchResultsWrapper.IDScore> scores = wrapper.getIDScore(i);
System.out.println(scores);
System.out.println("Output field data for No." + i);
System.out.println(wrapper.getFieldData(AGE_FIELD, i));
for (SearchResultsWrapper.IDScore score : scores) {
System.out.println(score);
}
}

return response;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ public static void main(String[] args) {
Random ran = new Random();
int k = ran.nextInt(rowCount);
SortedMap<Long, Float> targetVector = vectors.get(k);
System.out.println("\nTarget vector: " + targetVector);
R<SearchResults> searchRet = milvusClient.search(SearchParam.newBuilder()
.withCollectionName(COLLECTION_NAME)
.withMetricType(MetricType.IP)
Expand Down
19 changes: 8 additions & 11 deletions examples/src/main/java/io/milvus/v2/BinaryVectorExample.java
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,10 @@

public class BinaryVectorExample {
private static final String COLLECTION_NAME = "java_sdk_example_binary_vector_v2";
private static final String ID_FIELD = "id";
private static final String ID_FIELD = "pk";
private static final String VECTOR_FIELD = "vector";

private static final Integer VECTOR_DIM = 512;
private static final Integer VECTOR_DIM = 128;


public static void main(String[] args) {
Expand Down Expand Up @@ -126,6 +126,8 @@ public static void main(String[] args) {
Random ran = new Random();
int k = ran.nextInt(rowCount);
ByteBuffer targetVector = vectors.get(k);
System.out.printf("\nANN search for vector ID=%d:\n", k);
CommonUtils.printBinaryVector(targetVector);
Map<String,Object> params = new HashMap<>();
params.put("nprobe",16);
SearchResp searchResp = client.search(SearchReq.builder()
Expand All @@ -141,16 +143,11 @@ public static void main(String[] args) {
// Here we only input one vector to search, get the result of No.0 vector to check
List<List<SearchResp.SearchResult>> searchResults = searchResp.getSearchResults();
List<SearchResp.SearchResult> results = searchResults.get(0);
System.out.printf("The result of No.%d target vector:\n", i);
System.out.printf("The result of No.%d target vector, ID=%d:\n", i, k);
for (SearchResp.SearchResult result : results) {
System.out.println(result.getEntity());
System.out.printf("ID: %d, Score: %f, Vector: ", result.getId(), result.getScore());
System.out.println(result);
ByteBuffer vector = (ByteBuffer) result.getEntity().get(VECTOR_FIELD);
vector.rewind();
while (vector.hasRemaining()) {
System.out.print(Integer.toBinaryString(vector.get()));
}
System.out.println();
CommonUtils.printBinaryVector(vector);
}

SearchResp.SearchResult firstResult = results.get(0);
Expand All @@ -165,7 +162,7 @@ public static void main(String[] args) {
int n = 99;
QueryResp queryResp = client.query(QueryReq.builder()
.collectionName(COLLECTION_NAME)
.filter(String.format("id == %d", n))
.filter(String.format("%s == %d", ID_FIELD, n))
.outputFields(Collections.singletonList(VECTOR_FIELD))
.build());

Expand Down
9 changes: 7 additions & 2 deletions examples/src/main/java/io/milvus/v2/Float16VectorExample.java
Original file line number Diff line number Diff line change
Expand Up @@ -162,17 +162,22 @@ private static void searchVectors(List<Long> taargetIDs, List<BaseVector> target
}
Map<String, Object> entity = topResult.getEntity();
ByteBuffer vectorBuf = (ByteBuffer) entity.get(vectorFieldName);
if (!vectorBuf.equals(targetVectors.get(i).getData())) {
ByteBuffer targetVectorBuf = (ByteBuffer)targetVectors.get(i).getData();
if (!vectorBuf.equals(targetVectorBuf)) {
throw new RuntimeException("The top1 output vector is incorrect");
}
List<Float> decodedTargetVector = CommonUtils.decodeFloat16Vector(targetVectorBuf,
BF16_VECTOR_FIELD.equals(vectorFieldName));
// The method for converting float16 vector to float32 vector can be found in
// CommonUtils.
List<Float> decodedFpVector = CommonUtils.decodeFloat16Vector(vectorBuf,
BF16_VECTOR_FIELD.equals(vectorFieldName));
if (decodedFpVector.size() != VECTOR_DIM) {
throw new RuntimeException("The decoded vector dimension is incorrect");
}
System.out.println(results.get(0));
System.out.println("\nTarget vector: " + decodedTargetVector);
System.out.println("Top0 result: " + topResult);
System.out.println("Top0 result vector: " + decodedFpVector);
}
System.out.println("Search result of " + vectorFieldName + " is correct");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ private static void searchByText(MilvusClientV2 client, String text) {
List<List<SearchResp.SearchResult>> searchResults = searchResp.getSearchResults();
for (List<SearchResp.SearchResult> results : searchResults) {
for (SearchResp.SearchResult result : results) {
System.out.printf("ID: %d, Score: %f, %s\n", (long)result.getId(), result.getScore(), result.getEntity().toString());
System.out.println(result);
}
}
System.out.println("=============================================================");
Expand Down
2 changes: 1 addition & 1 deletion examples/src/main/java/io/milvus/v2/GeneralExample.java
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ private static void searchFace(String filter) {
for (List<SearchResp.SearchResult> results : searchResults) {
System.out.println("Search result of No." + i++);
for (SearchResp.SearchResult result : results) {
System.out.printf("ID: %s, Score: %f, %s\n", result.getId(), result.getScore(), result.getEntity().toString());
System.out.println(result);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ private void hybridSearch() {
System.out.printf("============= Search result of No.%d vector =============\n", i);
List<SearchResp.SearchResult> results = searchResults.get(i);
for (SearchResp.SearchResult result : results) {
System.out.printf("{id: %d, score: %f}%n", result.getId(), result.getScore());
System.out.println(result);
}
}
}
Expand Down
5 changes: 2 additions & 3 deletions examples/src/main/java/io/milvus/v2/Int8VectorExample.java
Original file line number Diff line number Diff line change
Expand Up @@ -136,10 +136,9 @@ public static void main(String[] args) {
List<SearchResp.SearchResult> results = searchResults.get(0);
System.out.printf("\nThe result of No.%d vector %s:\n", k, Arrays.toString(targetVector.array()));
for (SearchResp.SearchResult result : results) {
System.out.printf("ID: %d, Score: %f, Vector: ", (long)result.getId(), result.getScore());
System.out.println(result);
ByteBuffer vector = (ByteBuffer) result.getEntity().get(VECTOR_FIELD);
System.out.print(Arrays.toString(vector.array()));
System.out.println();
System.out.println(Arrays.toString(vector.array()));
}

SearchResp.SearchResult firstResult = results.get(0);
Expand Down
3 changes: 2 additions & 1 deletion examples/src/main/java/io/milvus/v2/SparseVectorExample.java
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ public static void main(String[] args) {
Random ran = new Random();
int k = ran.nextInt(rowCount);
SortedMap<Long, Float> targetVector = vectors.get(k);
System.out.println("\nTarget vector: " + targetVector);
Map<String,Object> params = new HashMap<>();
params.put("drop_ratio_search",0.2);
SearchResp searchResp = client.search(SearchReq.builder()
Expand All @@ -136,7 +137,7 @@ public static void main(String[] args) {
List<SearchResp.SearchResult> results = searchResults.get(0);
System.out.printf("The result of No.%d target vector:\n", i);
for (SearchResp.SearchResult result : results) {
System.out.println(result.getEntity());
System.out.println(result);
}

SearchResp.SearchResult firstResult = results.get(0);
Expand Down
2 changes: 1 addition & 1 deletion examples/src/main/java/io/milvus/v2/TextMatchExample.java
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ private static void queryWithFilter(MilvusClientV2 client, String filter) {
System.out.println("\nQuery with filter: " + filter);
List<QueryResp.QueryResult> records = queryRet.getQueryResults();
for (QueryResp.QueryResult record : records) {
System.out.println(record.getEntity());
System.out.println(record);
}
System.out.printf("%d items matched%n", records.size());
System.out.println("=============================================================");
Expand Down
24 changes: 12 additions & 12 deletions sdk-core/src/main/java/io/milvus/response/SearchResultsWrapper.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import io.milvus.response.basic.RowRecordWrapper;
import lombok.Getter;
import lombok.NonNull;
import org.apache.commons.lang3.StringUtils;

import java.util.ArrayList;
import java.util.HashMap;
Expand Down Expand Up @@ -161,14 +162,15 @@ public List<IDScore> getIDScore(int indexOfTarget) throws ParamException, Illega

// set id and score
IDs ids = results.getIds();
String pkName = results.getPrimaryFieldName();
if (ids.hasIntId()) {
LongArray longIDs = ids.getIntId();
if (offset + k > longIDs.getDataCount()) {
throw new IllegalResponseException("Result ids count is wrong");
}

for (int n = 0; n < k; ++n) {
idScores.add(new IDScore("", longIDs.getData((int)offset + n), results.getScores((int)offset + n)));
idScores.add(new IDScore(pkName, "", longIDs.getData((int)offset + n), results.getScores((int)offset + n)));
}
} else if (ids.hasStrId()) {
StringArray strIDs = ids.getStrId();
Expand All @@ -177,7 +179,7 @@ public List<IDScore> getIDScore(int indexOfTarget) throws ParamException, Illega
}

for (int n = 0; n < k; ++n) {
idScores.add(new IDScore(strIDs.getData((int)offset + n), 0, results.getScores((int)offset + n)));
idScores.add(new IDScore(pkName, strIDs.getData((int)offset + n), 0, results.getScores((int)offset + n)));
}
} else {
// in v2.3.3, return an empty list instead of throwing exception
Expand Down Expand Up @@ -272,12 +274,14 @@ private Position getOffsetByIndex(int indexOfTarget) {
*/
@Getter
public static final class IDScore {
private final String primaryKey;
private final String strID;
private final long longID;
private final float score;
Map<String, Object> fieldValues = new HashMap<>();

public IDScore(String strID, long longID, float score) {
public IDScore(String primaryKey, String strID, long longID, float score) {
this.primaryKey = primaryKey;
this.strID = strID;
this.longID = longID;
this.score = score;
Expand Down Expand Up @@ -333,16 +337,12 @@ public boolean contains(String keyName) {

@Override
public String toString() {
List<String> pairs = new ArrayList<>();
fieldValues.forEach((keyName, fieldValue) -> {
pairs.add(keyName + ":" + fieldValue);
});

if (strID.isEmpty()) {
return "(ID: " + getLongID() + " Score: " + getScore() + " OutputFields: " + pairs + ")";
} else {
return "(ID: '" + getStrID() + "' Score: " + getScore()+ " OutputFields: " + pairs + ")";
Object id = strID;
if (StringUtils.isEmpty(strID)) {
id = longID;
}

return "{" + getPrimaryKey() + ": " + id + ", Score: " + getScore() + ", OutputFields: " + fieldValues + "}";
}
}
}
Loading
Loading