diff --git a/docker-compose.yml b/docker-compose.yml index 37361e309..bc57efe22 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -3,7 +3,7 @@ version: '3.5' services: standalone: container_name: milvus-javasdk-standalone-1 - image: milvusdb/milvus:v2.6.7 + image: milvusdb/milvus:v2.6.9 command: [ "milvus", "run", "standalone" ] environment: - COMMON_STORAGETYPE=local @@ -24,7 +24,7 @@ services: standaloneslave: container_name: milvus-javasdk-standalone-2 - image: milvusdb/milvus:v2.6.7 + image: milvusdb/milvus:v2.6.9 command: [ "milvus", "run", "standalone" ] environment: - COMMON_STORAGETYPE=local diff --git a/sdk-core/src/main/java/io/milvus/v2/service/vector/request/SearchReq.java b/sdk-core/src/main/java/io/milvus/v2/service/vector/request/SearchReq.java index 0e68f8139..fc36013fa 100644 --- a/sdk-core/src/main/java/io/milvus/v2/service/vector/request/SearchReq.java +++ b/sdk-core/src/main/java/io/milvus/v2/service/vector/request/SearchReq.java @@ -40,6 +40,7 @@ public class SearchReq { private String filter; private List outputFields; private List data; + private List ids; private long offset; private long limit; private int roundDecimal; @@ -78,6 +79,7 @@ private SearchReq(SearchReqBuilder builder) { this.filter = builder.filter; this.outputFields = builder.outputFields; this.data = builder.data; + this.ids = builder.ids; this.offset = builder.offset; this.limit = builder.limit; this.roundDecimal = builder.roundDecimal; @@ -171,6 +173,10 @@ public void setData(List data) { this.data = data; } + public List getIds() { + return ids; + } + public long getOffset() { return offset; } @@ -299,7 +305,7 @@ public String toString() { ", topK=" + topK + ", filter='" + filter + '\'' + ", outputFields=" + outputFields + - ", data=" + data + + (ids == null || ids.isEmpty() ? ", data=" + data : ", ids=" + ids) + ", offset=" + offset + ", limit=" + limit + ", roundDecimal=" + roundDecimal + @@ -332,6 +338,7 @@ public static class SearchReqBuilder { private String filter; private List outputFields = new ArrayList<>(); // default value private List data = new ArrayList<>(); // default value + private List ids = new ArrayList<>(); private long offset; private long limit = 0L; // default value private int roundDecimal = -1; // default value @@ -399,6 +406,11 @@ public SearchReqBuilder data(List data) { return this; } + public SearchReqBuilder ids(List ids) { + this.ids = ids; + return this; + } + public SearchReqBuilder offset(long offset) { this.offset = offset; return this; diff --git a/sdk-core/src/main/java/io/milvus/v2/utils/VectorUtils.java b/sdk-core/src/main/java/io/milvus/v2/utils/VectorUtils.java index 7778f9700..babf79973 100644 --- a/sdk-core/src/main/java/io/milvus/v2/utils/VectorUtils.java +++ b/sdk-core/src/main/java/io/milvus/v2/utils/VectorUtils.java @@ -172,6 +172,65 @@ private static ByteString convertPlaceholder(List data, PlaceholderType } } + private static void convertSearchTarget(SearchReq request, SearchRequest.Builder builder) { + // prepare target, the input could be: + // 1. vectors or string list for doc-in-doc-out + // 2. ids list for search by primary keys + List vectors = request.getData(); + List ids = request.getIds(); + boolean vectorsIsEmpty = CollectionUtils.isEmpty(vectors); + boolean idsIsEmpty = CollectionUtils.isEmpty(ids); + if (vectorsIsEmpty && idsIsEmpty) { + throw new MilvusClientException(ErrorCode.INVALID_PARAMS, "Require either ids or vectors, but both are empty"); + } + if (!vectorsIsEmpty && !idsIsEmpty) { + throw new MilvusClientException(ErrorCode.INVALID_PARAMS, "Require either ids or vectors, but both are provided"); + } + + if (!vectorsIsEmpty) { + // the elements must be all-vector or all-string + PlaceholderType plType = vectors.get(0).getPlaceholderType(); + List data = new ArrayList<>(); + for (BaseVector vector : vectors) { + if (vector.getPlaceholderType() != plType) { + throw new MilvusClientException(ErrorCode.INVALID_PARAMS, + "Different types of target vectors in a search request is not allowed."); + } + data.add(vector.getData()); + } + + ByteString byteStr = convertPlaceholder(data, plType); + builder.setPlaceholderGroup(byteStr); + builder.setNq(vectors.size()); + } else { + Object val = ids.get(0); + if (val instanceof String) { + StringArray.Builder strBuilder = StringArray.newBuilder(); + for (Object obj : ids) { + if (!(obj instanceof String)) { + throw new MilvusClientException(ErrorCode.INVALID_PARAMS, + "All IDs must be of type String if the first ID is a String."); + } + strBuilder.addData((String) obj); + } + builder.setIds(IDs.newBuilder().setStrId(strBuilder.build()).build()); + } else if (val instanceof Long) { + LongArray.Builder longBuilder = LongArray.newBuilder(); + for (Object obj : ids) { + if (!(obj instanceof Long)) { + throw new MilvusClientException(ErrorCode.INVALID_PARAMS, + "All IDs must be of type Long if the first ID is a Long."); + } + longBuilder.addData((Long) obj); + } + builder.setIds(IDs.newBuilder().setIntId(longBuilder.build()).build()); + } else { + throw new MilvusClientException(ErrorCode.INVALID_PARAMS, "ID type must be String or Long."); + } + builder.setNq(ids.size()); + } + } + public SearchRequest ConvertToGrpcSearchRequest(SearchReq request) { String dbName = request.getDatabaseName(); String collectionName = request.getCollectionName(); @@ -185,26 +244,8 @@ public SearchRequest ConvertToGrpcSearchRequest(SearchReq request) { builder.setDbName(dbName); } - // prepare target, the input could be vectors or string list for doc-in-doc-out - List vectors = request.getData(); - if (vectors == null || vectors.isEmpty()) { - throw new MilvusClientException(ErrorCode.INVALID_PARAMS, "Target data list of search request is empty."); - } - - // the elements must be all-vector or all-string - PlaceholderType plType = vectors.get(0).getPlaceholderType(); - List data = new ArrayList<>(); - for (BaseVector vector : vectors) { - if (vector.getPlaceholderType() != plType) { - throw new MilvusClientException(ErrorCode.INVALID_PARAMS, - "Different types of target vectors in a search request is not allowed."); - } - data.add(vector.getData()); - } - - ByteString byteStr = convertPlaceholder(data, plType); - builder.setPlaceholderGroup(byteStr); - builder.setNq(vectors.size()); + // target vectors or ids + convertSearchTarget(request, builder); // search parameters // tries to fit the compatibility between v2.5.1 and older versions diff --git a/sdk-core/src/test/java/io/milvus/TestUtils.java b/sdk-core/src/test/java/io/milvus/TestUtils.java index 5e7344022..849913145 100644 --- a/sdk-core/src/test/java/io/milvus/TestUtils.java +++ b/sdk-core/src/test/java/io/milvus/TestUtils.java @@ -11,7 +11,7 @@ public class TestUtils { private int dimension = 256; private static final Random RANDOM = new Random(); - public static final String MilvusDockerImageID = "milvusdb/milvus:v2.6.7"; + public static final String MilvusDockerImageID = "milvusdb/milvus:v2.6.9"; public TestUtils(int dimension) { this.dimension = dimension; diff --git a/sdk-core/src/test/java/io/milvus/v2/client/MilvusClientV2DockerTest.java b/sdk-core/src/test/java/io/milvus/v2/client/MilvusClientV2DockerTest.java index 38d4242a5..199f433e0 100644 --- a/sdk-core/src/test/java/io/milvus/v2/client/MilvusClientV2DockerTest.java +++ b/sdk-core/src/test/java/io/milvus/v2/client/MilvusClientV2DockerTest.java @@ -827,6 +827,27 @@ void testFloat16Vectors() { // System.out.println("Output bfloat16 vector: " + outputVector); } + // search by ids + { + List ids = Arrays.asList(5L, 88L, 100L); + SearchResp searchResp = client.search(SearchReq.builder() + .collectionName(randomCollectionName) + .annsField(bfloat16Field) + .ids(ids) + .limit(topk) + .consistencyLevel(ConsistencyLevel.STRONG) + .outputFields(Collections.singletonList(bfloat16Field)) + .build()); + List> searchResults = searchResp.getSearchResults(); + Assertions.assertEquals(3, searchResults.size()); + for (int i = 0; i < searchResults.size(); i++) { + List results = searchResults.get(i); + Assertions.assertEquals(topk, results.size()); + SearchResp.SearchResult firstResult = results.get(0); + Assertions.assertEquals(ids.get(i), firstResult.getId()); + } + } + // get row count long rowCount = getRowCount("", randomCollectionName); Assertions.assertEquals(count, rowCount);