diff --git a/docker-compose.yml b/docker-compose.yml index 803b035a8..d882236ba 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -32,7 +32,7 @@ services: standalone: container_name: milvus-javasdk-test-standalone - image: milvusdb/milvus:v2.5.11 + image: milvusdb/milvus:master-20250610-9439eaef-amd64 command: ["milvus", "run", "standalone"] environment: ETCD_ENDPOINTS: etcd:2379 @@ -77,7 +77,7 @@ services: standaloneslave: container_name: milvus-javasdk-test-slave-standalone - image: milvusdb/milvus:v2.5.11 + image: milvusdb/milvus:master-20250610-9439eaef-amd64 command: ["milvus", "run", "standalone"] environment: ETCD_ENDPOINTS: etcdslave:2379 diff --git a/sdk-core/src/main/java/io/milvus/client/AbstractMilvusGrpcClient.java b/sdk-core/src/main/java/io/milvus/client/AbstractMilvusGrpcClient.java index 7a9bac829..865b663c5 100644 --- a/sdk-core/src/main/java/io/milvus/client/AbstractMilvusGrpcClient.java +++ b/sdk-core/src/main/java/io/milvus/client/AbstractMilvusGrpcClient.java @@ -255,6 +255,8 @@ private void waitForFlush(FlushResponse flushResponse, long waitingInterval, lon // If waiting time exceed timeout, exist the circle long tsBegin = System.currentTimeMillis(); Map collectionSegIDs = flushResponse.getCollSegIDsMap(); + Map flushTsMap = flushResponse.getCollFlushTsMap(); + String dbName = flushResponse.getDbName(); collectionSegIDs.forEach((collectionName, segmentIDs) -> { while (segmentIDs.getDataCount() > 0) { long tsNow = System.currentTimeMillis(); @@ -263,10 +265,15 @@ private void waitForFlush(FlushResponse flushResponse, long waitingInterval, lon break; } - GetFlushStateRequest getFlushStateRequest = GetFlushStateRequest.newBuilder() + GetFlushStateRequest.Builder builder = GetFlushStateRequest.newBuilder() .addAllSegmentIDs(segmentIDs.getDataList()) - .build(); - GetFlushStateResponse response = blockingStub().getFlushState(getFlushStateRequest); + .setCollectionName(collectionName) + .setFlushTs(flushTsMap.get(collectionName)); + if (StringUtils.isNotEmpty(dbName)) { + builder.setDbName(dbName); + } + + GetFlushStateResponse response = blockingStub().getFlushState(builder.build()); if (response.getFlushed()) { // if all segment of this collection has been flushed, break this circle and check next collection String msg = segmentIDs.getDataCount() + " segments of " + collectionName + " has been flushed"; diff --git a/sdk-core/src/main/java/io/milvus/v2/client/MilvusClientV2.java b/sdk-core/src/main/java/io/milvus/v2/client/MilvusClientV2.java index fa6a1198e..97a388011 100644 --- a/sdk-core/src/main/java/io/milvus/v2/client/MilvusClientV2.java +++ b/sdk-core/src/main/java/io/milvus/v2/client/MilvusClientV2.java @@ -303,6 +303,14 @@ public void alterCollection(AlterCollectionReq request) { public void alterCollectionProperties(AlterCollectionPropertiesReq request) { rpcUtils.retry(()-> collectionService.alterCollectionProperties(this.getRpcStub(), request)); } + /** + * Add a new field to collection. + * + * @param request add new field request + */ + public void addCollectionField(AddCollectionFieldReq request) { + rpcUtils.retry(()-> collectionService.addCollectionField(this.getRpcStub(), request)); + } /** * Alter a field's properties. * @@ -920,7 +928,7 @@ public void flush(FlushReq request) { if (request.getWaitFlushedTimeoutMs() > 0L) { tempBlockingStub = tempBlockingStub.withDeadlineAfter(request.getWaitFlushedTimeoutMs(), TimeUnit.MILLISECONDS); } - utilityService.waitFlush(tempBlockingStub, response.getCollectionSegmentIDs(), response.getCollectionFlushTs()); + utilityService.waitFlush(tempBlockingStub, response); } /** diff --git a/sdk-core/src/main/java/io/milvus/v2/service/collection/CollectionService.java b/sdk-core/src/main/java/io/milvus/v2/service/collection/CollectionService.java index 71fcf072b..61b0d8cc6 100644 --- a/sdk-core/src/main/java/io/milvus/v2/service/collection/CollectionService.java +++ b/sdk-core/src/main/java/io/milvus/v2/service/collection/CollectionService.java @@ -216,6 +216,24 @@ public Void alterCollectionProperties(MilvusServiceGrpc.MilvusServiceBlockingStu return null; } + public Void addCollectionField(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub, AddCollectionFieldReq request) { + String title = String.format("AddCollectionFieldReq fieldName:%s", request.getFieldName()); + AddCollectionFieldRequest.Builder builder = AddCollectionFieldRequest.newBuilder() + .setCollectionName(request.getCollectionName()); + if (StringUtils.isNotEmpty(request.getDatabaseName())) { + builder.setDbName(request.getDatabaseName()); + } + + CreateCollectionReq.FieldSchema fieldSchema = SchemaUtils.convertFieldReqToFieldSchema(request); + FieldSchema grpcFieldSchema = SchemaUtils.convertToGrpcFieldSchema(fieldSchema); + builder.setSchema(grpcFieldSchema.toByteString()); + + Status response = blockingStub.addCollectionField(builder.build()); + rpcUtils.handleResponse(title, response); + + return null; + } + public Void alterCollectionField(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub, AlterCollectionFieldReq request) { String title = String.format("AlterCollectionFieldReq collectionName:%s", request.getCollectionName()); AlterCollectionFieldRequest.Builder builder = AlterCollectionFieldRequest.newBuilder() diff --git a/sdk-core/src/main/java/io/milvus/v2/service/collection/request/AddCollectionFieldReq.java b/sdk-core/src/main/java/io/milvus/v2/service/collection/request/AddCollectionFieldReq.java new file mode 100644 index 000000000..130ae61e2 --- /dev/null +++ b/sdk-core/src/main/java/io/milvus/v2/service/collection/request/AddCollectionFieldReq.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.milvus.v2.service.collection.request; + +import lombok.Data; +import lombok.experimental.SuperBuilder; + +@Data +@SuperBuilder +public class AddCollectionFieldReq extends AddFieldReq{ + private String collectionName; + private String databaseName; +} diff --git a/sdk-core/src/main/java/io/milvus/v2/service/collection/request/CreateCollectionReq.java b/sdk-core/src/main/java/io/milvus/v2/service/collection/request/CreateCollectionReq.java index 989ddb7b7..cb8f04181 100644 --- a/sdk-core/src/main/java/io/milvus/v2/service/collection/request/CreateCollectionReq.java +++ b/sdk-core/src/main/java/io/milvus/v2/service/collection/request/CreateCollectionReq.java @@ -20,12 +20,12 @@ package io.milvus.v2.service.collection.request; import io.milvus.common.clientenum.FunctionType; -import io.milvus.param.ParamUtils; import io.milvus.v2.common.ConsistencyLevel; import io.milvus.v2.common.DataType; import io.milvus.v2.common.IndexParam; import io.milvus.v2.exception.ErrorCode; import io.milvus.v2.exception.MilvusClientException; +import io.milvus.v2.utils.SchemaUtils; import lombok.Builder; import lombok.Data; import lombok.NonNull; @@ -136,44 +136,7 @@ public static class CollectionSchema { private List functionList = new ArrayList<>(); public CollectionSchema addField(AddFieldReq addFieldReq) { - // check the input here to pop error messages earlier - if (addFieldReq.isEnableDefaultValue() && addFieldReq.getDefaultValue() == null - && addFieldReq.getIsNullable() == Boolean.FALSE) { - String msg = String.format("Default value cannot be null for field '%s' that is defined as nullable == false.", addFieldReq.getFieldName()); - throw new MilvusClientException(ErrorCode.INVALID_PARAMS, msg); - } - - CreateCollectionReq.FieldSchema fieldSchema = FieldSchema.builder() - .name(addFieldReq.getFieldName()) - .dataType(addFieldReq.getDataType()) - .description(addFieldReq.getDescription()) - .isPrimaryKey(addFieldReq.getIsPrimaryKey()) - .isPartitionKey(addFieldReq.getIsPartitionKey()) - .isClusteringKey(addFieldReq.getIsClusteringKey()) - .autoID(addFieldReq.getAutoID()) - .isNullable(addFieldReq.getIsNullable()) - .defaultValue(addFieldReq.getDefaultValue()) - .enableAnalyzer(addFieldReq.getEnableAnalyzer()) - .enableMatch(addFieldReq.getEnableMatch()) - .analyzerParams(addFieldReq.getAnalyzerParams()) - .typeParams(addFieldReq.getTypeParams()) - .multiAnalyzerParams(addFieldReq.getMultiAnalyzerParams()) - .build(); - if (addFieldReq.getDataType().equals(DataType.Array)) { - if (addFieldReq.getElementType() == null) { - throw new MilvusClientException(ErrorCode.INVALID_PARAMS, "Element type, maxCapacity are required for array field"); - } - fieldSchema.setElementType(addFieldReq.getElementType()); - fieldSchema.setMaxCapacity(addFieldReq.getMaxCapacity()); - } else if (addFieldReq.getDataType().equals(DataType.VarChar)) { - fieldSchema.setMaxLength(addFieldReq.getMaxLength()); - } else if (ParamUtils.isDenseVectorDataType(io.milvus.grpc.DataType.valueOf(addFieldReq.getDataType().name()))) { - if (addFieldReq.getDimension() == null) { - throw new MilvusClientException(ErrorCode.INVALID_PARAMS, "Dimension is required for vector field"); - } - fieldSchema.setDimension(addFieldReq.getDimension()); - } - fieldSchemaList.add(fieldSchema); + fieldSchemaList.add(SchemaUtils.convertFieldReqToFieldSchema(addFieldReq)); return this; } diff --git a/sdk-core/src/main/java/io/milvus/v2/service/utility/UtilityService.java b/sdk-core/src/main/java/io/milvus/v2/service/utility/UtilityService.java index 341478e29..6d6599953 100644 --- a/sdk-core/src/main/java/io/milvus/v2/service/utility/UtilityService.java +++ b/sdk-core/src/main/java/io/milvus/v2/service/utility/UtilityService.java @@ -52,21 +52,23 @@ public FlushResp flush(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub, }); Map collectionFlushTs = response.getCollFlushTsMap(); return FlushResp.builder() + .databaseName(response.getDbName()) .collectionSegmentIDs(collectionSegmentIDs) .collectionFlushTs(collectionFlushTs) .build(); } // this method is internal use, not expose to user - public Void waitFlush(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub, - Map> collectionSegmentIDs, - Map collectionFlushTs) { + public Void waitFlush(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub, FlushResp flushResp) { + Map> collectionSegmentIDs = flushResp.getCollectionSegmentIDs(); + Map collectionFlushTs = flushResp.getCollectionFlushTs(); collectionSegmentIDs.forEach((collectionName, segmentIDs)->{ if (collectionFlushTs.containsKey(collectionName)) { Long flushTs = collectionFlushTs.get(collectionName); boolean flushed = false; while (!flushed) { GetFlushStateResponse flushResponse = blockingStub.getFlushState(GetFlushStateRequest.newBuilder() + .setDbName(flushResp.getDatabaseName()) .addAllSegmentIDs(segmentIDs) .setFlushTs(flushTs) .build()); diff --git a/sdk-core/src/main/java/io/milvus/v2/service/utility/response/FlushResp.java b/sdk-core/src/main/java/io/milvus/v2/service/utility/response/FlushResp.java index 25dd566be..56b90df41 100644 --- a/sdk-core/src/main/java/io/milvus/v2/service/utility/response/FlushResp.java +++ b/sdk-core/src/main/java/io/milvus/v2/service/utility/response/FlushResp.java @@ -28,6 +28,8 @@ @Data @SuperBuilder public class FlushResp { + @Builder.Default + String databaseName = ""; @Builder.Default Map> collectionSegmentIDs = new HashMap<>(); @Builder.Default diff --git a/sdk-core/src/main/java/io/milvus/v2/utils/SchemaUtils.java b/sdk-core/src/main/java/io/milvus/v2/utils/SchemaUtils.java index becf8cd47..7059dd462 100644 --- a/sdk-core/src/main/java/io/milvus/v2/utils/SchemaUtils.java +++ b/sdk-core/src/main/java/io/milvus/v2/utils/SchemaUtils.java @@ -31,6 +31,7 @@ import io.milvus.param.ParamUtils; import io.milvus.v2.exception.ErrorCode; import io.milvus.v2.exception.MilvusClientException; +import io.milvus.v2.service.collection.request.AddFieldReq; import io.milvus.v2.service.collection.request.CreateCollectionReq; import org.apache.commons.collections4.CollectionUtils; import org.apache.commons.lang3.StringUtils; @@ -215,4 +216,46 @@ public static CreateCollectionReq.Function convertFromGrpcFunction(FunctionSchem .build(); return function; } + + public static CreateCollectionReq.FieldSchema convertFieldReqToFieldSchema(AddFieldReq addFieldReq) { + // check the input here to pop error messages earlier + if (addFieldReq.isEnableDefaultValue() && addFieldReq.getDefaultValue() == null + && addFieldReq.getIsNullable() == Boolean.FALSE) { + String msg = String.format("Default value cannot be null for field '%s' that is defined as nullable == false.", addFieldReq.getFieldName()); + throw new MilvusClientException(ErrorCode.INVALID_PARAMS, msg); + } + + CreateCollectionReq.FieldSchema fieldSchema = CreateCollectionReq.FieldSchema.builder() + .name(addFieldReq.getFieldName()) + .dataType(addFieldReq.getDataType()) + .description(addFieldReq.getDescription()) + .isPrimaryKey(addFieldReq.getIsPrimaryKey()) + .isPartitionKey(addFieldReq.getIsPartitionKey()) + .isClusteringKey(addFieldReq.getIsClusteringKey()) + .autoID(addFieldReq.getAutoID()) + .isNullable(addFieldReq.getIsNullable()) + .defaultValue(addFieldReq.getDefaultValue()) + .enableAnalyzer(addFieldReq.getEnableAnalyzer()) + .enableMatch(addFieldReq.getEnableMatch()) + .analyzerParams(addFieldReq.getAnalyzerParams()) + .typeParams(addFieldReq.getTypeParams()) + .multiAnalyzerParams(addFieldReq.getMultiAnalyzerParams()) + .build(); + if (addFieldReq.getDataType().equals(io.milvus.v2.common.DataType.Array)) { + if (addFieldReq.getElementType() == null) { + throw new MilvusClientException(ErrorCode.INVALID_PARAMS, "Element type, maxCapacity are required for array field"); + } + fieldSchema.setElementType(addFieldReq.getElementType()); + fieldSchema.setMaxCapacity(addFieldReq.getMaxCapacity()); + } else if (addFieldReq.getDataType().equals(io.milvus.v2.common.DataType.VarChar)) { + fieldSchema.setMaxLength(addFieldReq.getMaxLength()); + } else if (ParamUtils.isDenseVectorDataType(io.milvus.grpc.DataType.valueOf(addFieldReq.getDataType().name()))) { + if (addFieldReq.getDimension() == null) { + throw new MilvusClientException(ErrorCode.INVALID_PARAMS, "Dimension is required for vector field"); + } + fieldSchema.setDimension(addFieldReq.getDimension()); + } + + return fieldSchema; + } } diff --git a/sdk-core/src/test/java/io/milvus/TestUtils.java b/sdk-core/src/test/java/io/milvus/TestUtils.java index 1948eec41..4edf5d30c 100644 --- a/sdk-core/src/test/java/io/milvus/TestUtils.java +++ b/sdk-core/src/test/java/io/milvus/TestUtils.java @@ -11,7 +11,7 @@ public class TestUtils { private int dimension = 256; private static final Random RANDOM = new Random(); - public static final String MilvusDockerImageID = "milvusdb/milvus:v2.5.11"; + public static final String MilvusDockerImageID = "milvusdb/milvus:master-20250610-9439eaef-amd64"; public TestUtils(int dimension) { this.dimension = dimension; diff --git a/sdk-core/src/test/java/io/milvus/client/MilvusClientDockerTest.java b/sdk-core/src/test/java/io/milvus/client/MilvusClientDockerTest.java index 48809212e..2c942ad6e 100644 --- a/sdk-core/src/test/java/io/milvus/client/MilvusClientDockerTest.java +++ b/sdk-core/src/test/java/io/milvus/client/MilvusClientDockerTest.java @@ -75,10 +75,16 @@ class MilvusClientDockerTest { private static final TestUtils utils = new TestUtils(DIMENSION); @Container - private static final MilvusContainer milvus = new MilvusContainer(TestUtils.MilvusDockerImageID); + private static final MilvusContainer milvus = new MilvusContainer(TestUtils.MilvusDockerImageID) + .withEnv("DEPLOY_MODE", "STANDALONE"); @BeforeAll public static void setUp() { + try { + Thread.sleep(3000); // Sleep for few seconds since the master branch milvus healthz check is bug + } catch (InterruptedException ignored) { + } + ConnectParam connectParam = connectParamBuilder() .withAuthorization("root", "Milvus") .build(); @@ -2021,6 +2027,7 @@ void testDynamicField() { for (int i = 0; i < targetVectors.size(); ++i) { List scores = results.getIDScore(i); System.out.println("The result of No." + i + " target vector:"); + Assertions.assertFalse(scores.isEmpty()); SearchResultsWrapper.IDScore score = scores.get(0); System.out.println(score); Object extraMeta = score.get("dynamic"); diff --git a/sdk-core/src/test/java/io/milvus/client/MilvusServiceClientTest.java b/sdk-core/src/test/java/io/milvus/client/MilvusServiceClientTest.java index d10d497d7..d36a01d9c 100644 --- a/sdk-core/src/test/java/io/milvus/client/MilvusServiceClientTest.java +++ b/sdk-core/src/test/java/io/milvus/client/MilvusServiceClientTest.java @@ -575,6 +575,7 @@ void getCollectionStatistics() { final long segmentID = 2021L; mockServerImpl.setFlushResponse(FlushResponse.newBuilder() .putCollSegIDs(collectionName, LongArray.newBuilder().addData(segmentID).build()) + .putCollFlushTs(collectionName, 200L) .build()); mockServerImpl.setGetFlushStateResponse(GetFlushStateResponse.newBuilder() .setFlushed(false) diff --git a/sdk-core/src/test/java/io/milvus/v2/client/MilvusClientV2DockerTest.java b/sdk-core/src/test/java/io/milvus/v2/client/MilvusClientV2DockerTest.java index 64ee2754c..3b23fb072 100644 --- a/sdk-core/src/test/java/io/milvus/v2/client/MilvusClientV2DockerTest.java +++ b/sdk-core/src/test/java/io/milvus/v2/client/MilvusClientV2DockerTest.java @@ -81,10 +81,16 @@ class MilvusClientV2DockerTest { private static final TestUtils utils = new TestUtils(DIMENSION); @Container - private static final MilvusContainer milvus = new MilvusContainer(TestUtils.MilvusDockerImageID); + private static final MilvusContainer milvus = new MilvusContainer(TestUtils.MilvusDockerImageID) + .withEnv("DEPLOY_MODE", "STANDALONE"); @BeforeAll public static void setUp() { + try { + Thread.sleep(3000); // Sleep for few seconds since the master branch milvus healthz check is bug + } catch (InterruptedException ignored) { + } + ConnectConfig config = ConnectConfig.builder() .uri(milvus.getEndpoint()) .build(); @@ -324,19 +330,26 @@ void testFloatVectors() { .collectionNames(Collections.singletonList(randomCollectionName)) .build()); - // get persistent segment info - GetPersistentSegmentInfoResp pSegInfo = client.getPersistentSegmentInfo(GetPersistentSegmentInfoReq.builder() - .collectionName(randomCollectionName) - .build()); - Assertions.assertEquals(1, pSegInfo.getSegmentInfos().size()); - GetPersistentSegmentInfoResp.PersistentSegmentInfo pInfo = pSegInfo.getSegmentInfos().get(0); - Assertions.assertTrue(pInfo.getSegmentID() > 0L); - Assertions.assertTrue(pInfo.getCollectionID() > 0L); - Assertions.assertTrue(pInfo.getPartitionID() > 0L); - Assertions.assertEquals(count, pInfo.getNumOfRows()); - Assertions.assertEquals("Flushed", pInfo.getState()); - Assertions.assertEquals("L1", pInfo.getLevel()); - Assertions.assertFalse(pInfo.getIsSorted()); + // master branch, getPersistentSegmentInfo cannot ensure the segment is returned after flush() + while(true) { + // get persistent segment info + GetPersistentSegmentInfoResp pSegInfo = client.getPersistentSegmentInfo(GetPersistentSegmentInfoReq.builder() + .collectionName(randomCollectionName) + .build()); + if (pSegInfo.getSegmentInfos().size() == 0) { + continue; + } + Assertions.assertEquals(1, pSegInfo.getSegmentInfos().size()); + GetPersistentSegmentInfoResp.PersistentSegmentInfo pInfo = pSegInfo.getSegmentInfos().get(0); + Assertions.assertTrue(pInfo.getSegmentID() > 0L); + Assertions.assertTrue(pInfo.getCollectionID() > 0L); + Assertions.assertTrue(pInfo.getPartitionID() > 0L); + Assertions.assertEquals(count, pInfo.getNumOfRows()); + Assertions.assertEquals("Flushed", pInfo.getState()); + Assertions.assertEquals("L1", pInfo.getLevel()); +// Assertions.assertFalse(pInfo.getIsSorted()); + break; + } // compact CompactResp compactResp = client.compact(CompactReq.builder() @@ -796,142 +809,142 @@ void testSparseVectors() { client.dropCollection(DropCollectionReq.builder().collectionName(randomCollectionName).build()); } -// @Test -// void testInt8Vectors() { -// String randomCollectionName = generator.generate(10); -// String vectorFieldName = "int8_vector"; -// int dimension = 8; -// CreateCollectionReq.CollectionSchema collectionSchema = CreateCollectionReq.CollectionSchema.builder() -// .build(); -// collectionSchema.addField(AddFieldReq.builder() -// .fieldName("id") -// .dataType(DataType.Int64) -// .isPrimaryKey(Boolean.TRUE) -// .build()); -// collectionSchema.addField(AddFieldReq.builder() -// .fieldName(vectorFieldName) -// .dataType(DataType.Int8Vector) -// .dimension(dimension) -// .build()); -// -// client.dropCollection(DropCollectionReq.builder() -// .collectionName(randomCollectionName) -// .build()); -// CreateCollectionReq requestCreate = CreateCollectionReq.builder() -// .collectionName(randomCollectionName) -// .collectionSchema(collectionSchema) -// .build(); -// client.createCollection(requestCreate); -// -// // insert rows -// Gson gson = new Gson(); -// Random RANDOM = new Random(); -// long count = 10; -// List vectors = new ArrayList<>(); -// List data = new ArrayList<>(); -// for (int i = 0; i < count; i++) { -// JsonObject row = new JsonObject(); -// row.addProperty("id", i); -// -// ByteBuffer vector = ByteBuffer.allocate(dimension); -// for (int k = 0; k < dimension; ++k) { -// vector.put((byte) (RANDOM.nextInt(256) - 128)); -// } -// vectors.add(vector); -// row.add(vectorFieldName, gson.toJsonTree(vector.array())); -// data.add(row); -// } -// -// InsertResp insertResp = client.insert(InsertReq.builder() -// .collectionName(randomCollectionName) -// .data(data) -// .build()); -// Assertions.assertEquals(count, insertResp.getInsertCnt()); -// -// // flush -// client.flush(FlushReq.builder() -// .collectionNames(Collections.singletonList(randomCollectionName)) -// .build()); -// -// // create index -// Map extraParams = new HashMap<>(); -// extraParams.put("M", 64); -// extraParams.put("efConstruction", 200); -// IndexParam indexParam = IndexParam.builder() -// .fieldName(vectorFieldName) -// .indexType(IndexParam.IndexType.HNSW) -// .metricType(IndexParam.MetricType.COSINE) -// .extraParams(extraParams) -// .build(); -// client.createIndex(CreateIndexReq.builder() -// .collectionName(randomCollectionName) -// .indexParams(Collections.singletonList(indexParam)) -// .build()); -// -// client.loadCollection(LoadCollectionReq.builder() -// .collectionName(randomCollectionName) -// .build()); -// -// // describe collection -// DescribeCollectionResp descResp = client.describeCollection(DescribeCollectionReq.builder() -// .collectionName(randomCollectionName) -// .build()); -// Assertions.assertEquals(randomCollectionName, descResp.getCollectionName()); -// -// List fieldNames = descResp.getFieldNames(); -// Assertions.assertEquals(collectionSchema.getFieldSchemaList().size(), fieldNames.size()); -// CreateCollectionReq.CollectionSchema schema = descResp.getCollectionSchema(); -// for (String name : fieldNames) { -// CreateCollectionReq.FieldSchema f1 = collectionSchema.getField(name); -// CreateCollectionReq.FieldSchema f2 = schema.getField(name); -// Assertions.assertNotNull(f1); -// Assertions.assertNotNull(f2); -// Assertions.assertEquals(f1.getName(), f2.getName()); -// Assertions.assertEquals(f1.getDataType(), f2.getDataType()); -// Assertions.assertEquals(f1.getDimension(), f2.getDimension()); -// } -// -// // search in collection -// int topK = 3; -// List targetVectors = Arrays.asList(new Int8Vec(vectors.get(5)), new Int8Vec(vectors.get(0))); -// SearchResp searchResp = client.search(SearchReq.builder() -// .collectionName(randomCollectionName) -// .annsField(vectorFieldName) -// .data(targetVectors) -// .topK(topK) -// .outputFields(Collections.singletonList("*")) -// .consistencyLevel(ConsistencyLevel.STRONG) -// .build()); -// List> searchResults = searchResp.getSearchResults(); -// Assertions.assertEquals(targetVectors.size(), searchResults.size()); -// -// for (List results : searchResults) { -// Assertions.assertEquals(topK, results.size()); -// for (int i = 0; i < results.size(); i++) { -// SearchResp.SearchResult result = results.get(i); -// Map entity = result.getEntity(); -// long id = (long) entity.get("id"); -// ByteBuffer originVec = vectors.get((int) id); -// ByteBuffer getVec = (ByteBuffer) entity.get(vectorFieldName); -// Assertions.assertEquals(originVec, getVec); -// } -// } -// -// // query -// QueryResp queryResp = client.query(QueryReq.builder() -// .collectionName(randomCollectionName) -// .filter("id == 5") -// .build()); -// List queryResults = queryResp.getQueryResults(); -// Assertions.assertEquals(1, queryResults.size()); -// { -// QueryResp.QueryResult result = queryResults.get(0); -// Map entity = result.getEntity(); -// ByteBuffer originVec = vectors.get(5); -// ByteBuffer getVec = (ByteBuffer)entity.get(vectorFieldName); -// Assertions.assertEquals(originVec, getVec); -// } -// } + @Test + void testInt8Vectors() { + String randomCollectionName = generator.generate(10); + String vectorFieldName = "int8_vector"; + int dimension = 8; + CreateCollectionReq.CollectionSchema collectionSchema = CreateCollectionReq.CollectionSchema.builder() + .build(); + collectionSchema.addField(AddFieldReq.builder() + .fieldName("id") + .dataType(DataType.Int64) + .isPrimaryKey(Boolean.TRUE) + .build()); + collectionSchema.addField(AddFieldReq.builder() + .fieldName(vectorFieldName) + .dataType(DataType.Int8Vector) + .dimension(dimension) + .build()); + + client.dropCollection(DropCollectionReq.builder() + .collectionName(randomCollectionName) + .build()); + CreateCollectionReq requestCreate = CreateCollectionReq.builder() + .collectionName(randomCollectionName) + .collectionSchema(collectionSchema) + .build(); + client.createCollection(requestCreate); + + // insert rows + Gson gson = new Gson(); + Random RANDOM = new Random(); + long count = 10; + List vectors = new ArrayList<>(); + List data = new ArrayList<>(); + for (int i = 0; i < count; i++) { + JsonObject row = new JsonObject(); + row.addProperty("id", i); + + ByteBuffer vector = ByteBuffer.allocate(dimension); + for (int k = 0; k < dimension; ++k) { + vector.put((byte) (RANDOM.nextInt(256) - 128)); + } + vectors.add(vector); + row.add(vectorFieldName, gson.toJsonTree(vector.array())); + data.add(row); + } + + InsertResp insertResp = client.insert(InsertReq.builder() + .collectionName(randomCollectionName) + .data(data) + .build()); + Assertions.assertEquals(count, insertResp.getInsertCnt()); + + // flush + client.flush(FlushReq.builder() + .collectionNames(Collections.singletonList(randomCollectionName)) + .build()); + + // create index + Map extraParams = new HashMap<>(); + extraParams.put("M", 64); + extraParams.put("efConstruction", 200); + IndexParam indexParam = IndexParam.builder() + .fieldName(vectorFieldName) + .indexType(IndexParam.IndexType.HNSW) + .metricType(IndexParam.MetricType.COSINE) + .extraParams(extraParams) + .build(); + client.createIndex(CreateIndexReq.builder() + .collectionName(randomCollectionName) + .indexParams(Collections.singletonList(indexParam)) + .build()); + + client.loadCollection(LoadCollectionReq.builder() + .collectionName(randomCollectionName) + .build()); + + // describe collection + DescribeCollectionResp descResp = client.describeCollection(DescribeCollectionReq.builder() + .collectionName(randomCollectionName) + .build()); + Assertions.assertEquals(randomCollectionName, descResp.getCollectionName()); + + List fieldNames = descResp.getFieldNames(); + Assertions.assertEquals(collectionSchema.getFieldSchemaList().size(), fieldNames.size()); + CreateCollectionReq.CollectionSchema schema = descResp.getCollectionSchema(); + for (String name : fieldNames) { + CreateCollectionReq.FieldSchema f1 = collectionSchema.getField(name); + CreateCollectionReq.FieldSchema f2 = schema.getField(name); + Assertions.assertNotNull(f1); + Assertions.assertNotNull(f2); + Assertions.assertEquals(f1.getName(), f2.getName()); + Assertions.assertEquals(f1.getDataType(), f2.getDataType()); + Assertions.assertEquals(f1.getDimension(), f2.getDimension()); + } + + // search in collection + int topK = 3; + List targetVectors = Arrays.asList(new Int8Vec(vectors.get(5)), new Int8Vec(vectors.get(0))); + SearchResp searchResp = client.search(SearchReq.builder() + .collectionName(randomCollectionName) + .annsField(vectorFieldName) + .data(targetVectors) + .topK(topK) + .outputFields(Collections.singletonList("*")) + .consistencyLevel(ConsistencyLevel.STRONG) + .build()); + List> searchResults = searchResp.getSearchResults(); + Assertions.assertEquals(targetVectors.size(), searchResults.size()); + + for (List results : searchResults) { + Assertions.assertEquals(topK, results.size()); + for (int i = 0; i < results.size(); i++) { + SearchResp.SearchResult result = results.get(i); + Map entity = result.getEntity(); + long id = (long) entity.get("id"); + ByteBuffer originVec = vectors.get((int) id); + ByteBuffer getVec = (ByteBuffer) entity.get(vectorFieldName); + Assertions.assertEquals(originVec, getVec); + } + } + + // query + QueryResp queryResp = client.query(QueryReq.builder() + .collectionName(randomCollectionName) + .filter("id == 5") + .build()); + List queryResults = queryResp.getQueryResults(); + Assertions.assertEquals(1, queryResults.size()); + { + QueryResp.QueryResult result = queryResults.get(0); + Map entity = result.getEntity(); + ByteBuffer originVec = vectors.get(5); + ByteBuffer getVec = (ByteBuffer)entity.get(vectorFieldName); + Assertions.assertEquals(originVec, getVec); + } + } @Test void testHybridSearch() { @@ -2402,6 +2415,36 @@ void testDynamicField() { SearchResp.SearchResult r = searchResults.get(0).get(0); Assertions.assertTrue(r.getEntity().containsKey("dynamic_10")); Assertions.assertEquals("this is dynamic value", r.getEntity().get("dynamic_10")); + + // add new field + client.addCollectionField(AddCollectionFieldReq.builder() + .collectionName(collectionName) + .fieldName("text") + .dataType(DataType.VarChar) + .maxLength(100) + .isNullable(true) // must be nullable + .build()); + client.addCollectionField(AddCollectionFieldReq.builder() + .collectionName(collectionName) + .fieldName("flag") + .dataType(DataType.Int32) + .defaultValue(100) + .isNullable(true) // must be nullable + .build()); + + DescribeCollectionResp descResp = client.describeCollection(DescribeCollectionReq.builder() + .collectionName(collectionName) + .build()); + Assertions.assertEquals(4, descResp.getFieldNames().size()); + List fieldNames = descResp.getFieldNames(); + Assertions.assertTrue(fieldNames.contains("text")); + Assertions.assertTrue(fieldNames.contains("flag")); + CreateCollectionReq.CollectionSchema schema = descResp.getCollectionSchema(); + + CreateCollectionReq.FieldSchema field = schema.getField("text"); + Assertions.assertEquals(DataType.VarChar, field.getDataType()); + Assertions.assertEquals(100, field.getMaxLength()); + Assertions.assertTrue(field.getIsNullable()); } @Test