diff --git a/sdk-core/src/main/java/io/milvus/v2/service/vector/VectorService.java b/sdk-core/src/main/java/io/milvus/v2/service/vector/VectorService.java index 7604f5855..bee6cfc53 100644 --- a/sdk-core/src/main/java/io/milvus/v2/service/vector/VectorService.java +++ b/sdk-core/src/main/java/io/milvus/v2/service/vector/VectorService.java @@ -332,17 +332,8 @@ public DeleteResp delete(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStu DescribeCollectionResp respR = convertUtils.convertDescCollectionResp(descResp); request.setFilter(vectorUtils.getExprById(respR.getPrimaryFieldName(), request.getIds())); } - DeleteRequest.Builder builder = DeleteRequest.newBuilder() - .setCollectionName(collectionName) - .setPartitionName(request.getPartitionName()) - .setExpr(request.getFilter()); - if (request.getFilter() != null && !request.getFilter().isEmpty()) { - Map filterTemplateValues = request.getFilterTemplateValues(); - filterTemplateValues.forEach((key, value)->{ - builder.putExprTemplateValues(key, VectorUtils.deduceAndCreateTemplateValue(value)); - }); - } - MutationResult response = blockingStub.delete(builder.build()); + DeleteRequest rpcRequest = dataUtils.ConvertToGrpcDeleteRequest(request); + MutationResult response = blockingStub.delete(rpcRequest); // if illegal data, server fails to process delete, clean the schema cache // so that the next call of dml can update the cache diff --git a/sdk-core/src/main/java/io/milvus/v2/utils/DataUtils.java b/sdk-core/src/main/java/io/milvus/v2/utils/DataUtils.java index eaa851c4d..f21a0f9f1 100644 --- a/sdk-core/src/main/java/io/milvus/v2/utils/DataUtils.java +++ b/sdk-core/src/main/java/io/milvus/v2/utils/DataUtils.java @@ -27,6 +27,7 @@ import io.milvus.v2.exception.MilvusClientException; import io.milvus.v2.service.collection.request.CreateCollectionReq; import io.milvus.v2.service.collection.response.DescribeCollectionResp; +import io.milvus.v2.service.vector.request.DeleteReq; import io.milvus.v2.service.vector.request.InsertReq; import io.milvus.v2.service.vector.request.UpsertReq; import lombok.Builder; @@ -238,4 +239,22 @@ public static class InsertDataInfo { private final CreateCollectionReq.FieldSchema field; private final LinkedList data; } + + public DeleteRequest ConvertToGrpcDeleteRequest(DeleteReq request) { + DeleteRequest.Builder builder = DeleteRequest.newBuilder() + .setCollectionName(request.getCollectionName()) + .setPartitionName(request.getPartitionName()) + .setExpr(request.getFilter()); + if (request.getFilter() != null && !request.getFilter().isEmpty()) { + Map filterTemplateValues = request.getFilterTemplateValues(); + filterTemplateValues.forEach((key, value)->{ + builder.putExprTemplateValues(key, VectorUtils.deduceAndCreateTemplateValue(value)); + }); + } + String dbName = request.getDatabaseName(); + if (StringUtils.isNotEmpty(dbName)) { + builder.setDbName(dbName); + } + return builder.build(); + } } diff --git a/sdk-core/src/test/java/io/milvus/v2/client/MilvusClientV2DockerTest.java b/sdk-core/src/test/java/io/milvus/v2/client/MilvusClientV2DockerTest.java index 3e5c3397e..f526c2a0f 100644 --- a/sdk-core/src/test/java/io/milvus/v2/client/MilvusClientV2DockerTest.java +++ b/sdk-core/src/test/java/io/milvus/v2/client/MilvusClientV2DockerTest.java @@ -283,8 +283,9 @@ private void verifyOutput(JsonObject row, Map entity) { Assertions.assertEquals(arrStrOri, arrStr); } - private long getRowCount(String collectionName) { + private long getRowCount(String dbName, String collectionName) { QueryResp queryResp = client.query(QueryReq.builder() + .databaseName(dbName) .collectionName(collectionName) .outputFields(Collections.singletonList("count(*)")) .consistencyLevel(ConsistencyLevel.STRONG) @@ -414,7 +415,7 @@ void testFloatVectors() { Assertions.assertEquals(1, upsertResp.getUpsertCnt()); // get row count - long rowCount = getRowCount(randomCollectionName); + long rowCount = getRowCount("", randomCollectionName); Assertions.assertEquals(count + 1, rowCount); // describe collection @@ -638,7 +639,7 @@ void testBinaryVectors() throws InterruptedException { Assertions.assertEquals(count, insertResp.getInsertCnt()); // get row count - long rowCount = getRowCount(randomCollectionName); + long rowCount = getRowCount("", randomCollectionName); Assertions.assertEquals(count, rowCount); // search in collection @@ -807,7 +808,7 @@ void testFloat16Vectors() { } // get row count - long rowCount = getRowCount(randomCollectionName); + long rowCount = getRowCount("", randomCollectionName); Assertions.assertEquals(count, rowCount); client.dropCollection(DropCollectionReq.builder().collectionName(randomCollectionName).build()); @@ -851,7 +852,7 @@ void testSparseVectors() { Assertions.assertEquals(count, insertResp.getInsertCnt()); // get row count - long rowCount = getRowCount(randomCollectionName); + long rowCount = getRowCount("", randomCollectionName); Assertions.assertEquals(count, rowCount); // search in collection @@ -1156,7 +1157,7 @@ void testHybridSearch() { Assertions.assertEquals(count, insertResp.getInsertCnt()); // get row count - long rowCount = getRowCount(randomCollectionName); + long rowCount = getRowCount("", randomCollectionName); Assertions.assertEquals(count, rowCount); // search again, there are results @@ -1176,6 +1177,12 @@ void testHybridSearch() { void testDeleteUpsert() { String randomCollectionName = generator.generate(10); + // create a new db + String testDbName = "test_delete_db"; + client.createDatabase(CreateDatabaseReq.builder() + .databaseName(testDbName) + .build()); + CreateCollectionReq.CollectionSchema collectionSchema = CreateCollectionReq.CollectionSchema.builder() .build(); collectionSchema.addField(AddFieldReq.builder() @@ -1196,7 +1203,9 @@ void testDeleteUpsert() { .metricType(IndexParam.MetricType.L2) .extraParams(new HashMap(){{put("nlist", 64);}}) .build()); + // create collection in the test db CreateCollectionReq requestCreate = CreateCollectionReq.builder() + .databaseName(testDbName) .collectionName(randomCollectionName) .collectionSchema(collectionSchema) .indexParams(indexParams) @@ -1213,6 +1222,7 @@ void testDeleteUpsert() { } InsertResp insertResp = client.insert(InsertReq.builder() + .databaseName(testDbName) .collectionName(randomCollectionName) .data(data) .build()); @@ -1220,13 +1230,14 @@ void testDeleteUpsert() { // delete DeleteResp deleteResp = client.delete(DeleteReq.builder() + .databaseName(testDbName) .collectionName(randomCollectionName) .ids(Arrays.asList("pk_5", "pk_8")) .build()); Assertions.assertEquals(2, deleteResp.getDeleteCnt()); // get row count - long rowCount = getRowCount(randomCollectionName); + long rowCount = getRowCount(testDbName, randomCollectionName); Assertions.assertEquals(8L, rowCount); // upsert @@ -1240,6 +1251,7 @@ void testDeleteUpsert() { row2.add("float_vector", JsonUtils.toJsonTree(new float[]{2.0f, 2.0f, 2.0f, 2.0f})); dataUpdate.add(row2); UpsertResp upsertResp = client.upsert(UpsertReq.builder() + .databaseName(testDbName) .collectionName(randomCollectionName) .data(dataUpdate) .build()); @@ -1247,11 +1259,12 @@ void testDeleteUpsert() { Assertions.assertEquals(2, upsertResp.getPrimaryKeys().size()); // get row count - rowCount = getRowCount(randomCollectionName); + rowCount = getRowCount(testDbName, randomCollectionName); Assertions.assertEquals(9L, rowCount); // verify QueryResp queryResp = client.query(QueryReq.builder() + .databaseName(testDbName) .collectionName(randomCollectionName) .ids(Arrays.asList("pk_2", "pk_5")) .outputFields(Collections.singletonList("*")) @@ -1281,7 +1294,10 @@ void testDeleteUpsert() { Assertions.assertEquals(5.0f, f); } - client.dropCollection(DropCollectionReq.builder().collectionName(randomCollectionName).build()); + client.dropCollection(DropCollectionReq.builder() + .databaseName(testDbName) + .collectionName(randomCollectionName) + .build()); } @Test @@ -1615,7 +1631,7 @@ void testCacheCollectionSchema() throws InterruptedException { String randomCollectionName = generator.generate(10); // create a new db - String testDbName = "test_database"; + String testDbName = "test_cache_db"; client.createDatabase(CreateDatabaseReq.builder() .databaseName(testDbName) .build()); @@ -1789,7 +1805,7 @@ public void testIterator() { Assertions.assertEquals(count, insertResp.getInsertCnt()); // get row count - long rowCount = getRowCount(randomCollectionName); + long rowCount = getRowCount("", randomCollectionName); Assertions.assertEquals(count, rowCount); // search iterator @@ -2113,7 +2129,7 @@ void testDatabase() { @Test void testClientPool() { // create a temp database - String dummyDb = "dummy_db"; + String dummyDb = "test_pool_db"; client.createDatabase(CreateDatabaseReq.builder() .databaseName(dummyDb) .build()); @@ -2557,7 +2573,7 @@ void testDocInOut() { Assertions.assertEquals(3, insertResp.getInsertCnt()); // get row count - long rowCount = getRowCount(randomCollectionName); + long rowCount = getRowCount("", randomCollectionName); Assertions.assertEquals(texts.size(), rowCount); // search @@ -2901,7 +2917,7 @@ void testConsistencyLevel() throws InterruptedException { String vectorName = "vector"; int dim = 4; String defaultDbName = "default"; - String tempDbName = "db_for_level"; + String tempDbName = "test_level_db"; // create a temp database client.createDatabase(CreateDatabaseReq.builder()