diff --git a/examples/src/main/java/io/milvus/v1/UpsertExample.java b/examples/src/main/java/io/milvus/v1/UpsertExample.java
new file mode 100644
index 000000000..4eb33d234
--- /dev/null
+++ b/examples/src/main/java/io/milvus/v1/UpsertExample.java
@@ -0,0 +1,223 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package io.milvus.v1;
+
+import com.google.gson.Gson;
+import com.google.gson.JsonObject;
+import io.milvus.client.MilvusClient;
+import io.milvus.client.MilvusServiceClient;
+import io.milvus.common.clientenum.ConsistencyLevelEnum;
+import io.milvus.grpc.DataType;
+import io.milvus.grpc.MutationResult;
+import io.milvus.grpc.QueryResults;
+import io.milvus.param.*;
+import io.milvus.param.collection.*;
+import io.milvus.param.dml.InsertParam;
+import io.milvus.param.dml.QueryParam;
+import io.milvus.param.dml.UpsertParam;
+import io.milvus.param.index.CreateIndexParam;
+import io.milvus.response.QueryResultsWrapper;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+/**
+ * Example of the upsert interface with the V1 client ({@link MilvusServiceClient}).
+ *
+ * <p>Runs the same scenario twice: against a collection whose primary key is generated by
+ * the server and against one whose primary key is supplied by the client. Each run prints
+ * the affected row before and after the upsert.
+ */
+public class UpsertExample {
+    private static final MilvusClient client;
+
+    static {
+        ConnectParam connectParam = ConnectParam.newBuilder()
+                .withHost("localhost")
+                .withPort(19530)
+                .build();
+        client = new MilvusServiceClient(connectParam);
+    }
+
+    private static final String COLLECTION_NAME = "java_sdk_example_upsert_v1";
+    private static final String ID_FIELD = "pk";
+    private static final String VECTOR_FIELD = "vector";
+    private static final String TEXT_FIELD = "text";
+    private static final Integer VECTOR_DIM = 128;
+
+    /**
+     * Queries the example collection with STRONG consistency and prints every matching row,
+     * limited to the primary key and the text field.
+     *
+     * @param expr boolean filter expression, e.g. {@code "pk == 1"}
+     */
+    private static void queryWithExpr(String expr) {
+        R<QueryResults> queryRet = client.query(QueryParam.newBuilder()
+                .withCollectionName(COLLECTION_NAME)
+                .withExpr(expr)
+                .withOutFields(Arrays.asList(ID_FIELD, TEXT_FIELD))
+                .withConsistencyLevel(ConsistencyLevelEnum.STRONG)
+                .build());
+        // check the status before reading the payload, like the other calls in this example
+        CommonUtils.handleResponseStatus(queryRet);
+        QueryResultsWrapper queryWrapper = new QueryResultsWrapper(queryRet.getData());
+        System.out.println("\nQuery with expression: " + expr);
+        List<QueryResultsWrapper.RowRecord> records = queryWrapper.getRowRecords();
+        for (QueryResultsWrapper.RowRecord record : records) {
+            System.out.println(record);
+        }
+    }
+
+    /**
+     * Re-creates the example collection, indexes and loads it, then inserts 100 rows.
+     *
+     * @param autoID whether the primary key is generated by the server; when {@code false}
+     *               the rows are inserted with the primary keys 0..99
+     * @return the primary keys of the inserted rows as reported by the server
+     */
+    private static List<Long> createCollection(boolean autoID) {
+        // Define fields
+        List<FieldType> fieldsSchema = Arrays.asList(
+                FieldType.newBuilder()
+                        .withName(ID_FIELD)
+                        .withDataType(DataType.Int64)
+                        .withPrimaryKey(true)
+                        .withAutoID(autoID)
+                        .build(),
+                FieldType.newBuilder()
+                        .withName(VECTOR_FIELD)
+                        .withDataType(DataType.FloatVector)
+                        .withDimension(VECTOR_DIM)
+                        .build(),
+                FieldType.newBuilder()
+                        .withName(TEXT_FIELD)
+                        .withDataType(DataType.VarChar)
+                        .withMaxLength(100)
+                        .build()
+        );
+
+        CollectionSchemaParam collectionSchemaParam = CollectionSchemaParam.newBuilder()
+                .withFieldTypes(fieldsSchema)
+                .build();
+
+        // Drop the collection if exists
+        client.dropCollection(DropCollectionParam.newBuilder()
+                .withCollectionName(COLLECTION_NAME)
+                .build());
+
+        // Create the collection with 3 fields
+        R<RpcStatus> ret = client.createCollection(CreateCollectionParam.newBuilder()
+                .withCollectionName(COLLECTION_NAME)
+                .withSchema(collectionSchemaParam)
+                .build());
+        CommonUtils.handleResponseStatus(ret);
+
+        // Specify an index type on the vector field.
+        ret = client.createIndex(CreateIndexParam.newBuilder()
+                .withCollectionName(COLLECTION_NAME)
+                .withFieldName(VECTOR_FIELD)
+                .withIndexType(IndexType.FLAT)
+                .withMetricType(MetricType.L2)
+                .build());
+        CommonUtils.handleResponseStatus(ret);
+
+        // Call loadCollection() to enable automatically loading data into memory for searching
+        ret = client.loadCollection(LoadCollectionParam.newBuilder()
+                .withCollectionName(COLLECTION_NAME)
+                .build());
+        CommonUtils.handleResponseStatus(ret);
+        System.out.println("\nCollection created with autoID = " + autoID);
+
+        // insert rows
+        Gson gson = new Gson();
+        List<JsonObject> rows = new ArrayList<>();
+        for (int i = 0; i < 100; i++) {
+            JsonObject row = new JsonObject();
+            if (!autoID) {
+                row.addProperty(ID_FIELD, i);
+            }
+            List<Float> vector = CommonUtils.generateFloatVector(VECTOR_DIM);
+            row.add(VECTOR_FIELD, gson.toJsonTree(vector));
+            row.addProperty(TEXT_FIELD, String.format("text_%d", i));
+            rows.add(row);
+        }
+        R<MutationResult> resp = client.insert(InsertParam.newBuilder()
+                .withCollectionName(COLLECTION_NAME)
+                .withRows(rows)
+                .build());
+        CommonUtils.handleResponseStatus(resp);
+        return resp.getData().getIDs().getIntId().getDataList();
+    }
+
+    /**
+     * Upserts one existing row and prints it before and after the upsert.
+     *
+     * @param autoID whether the collection is created with a server-generated primary key
+     */
+    private static void doUpsert(boolean autoID) {
+        // if autoID is true, the collection primary key is auto-generated by server
+        List<Long> ids = createCollection(autoID);
+
+        // query before upsert
+        Long testID = ids.get(1);
+        String filter = String.format("%s == %d", ID_FIELD, testID);
+        queryWithExpr(filter);
+
+        // upsert
+        // when autoID is true the server deletes the old entity and creates a new one under a
+        // newly generated primary key; that key is returned in the response and used below
+        // (presumably, with autoID false the returned key equals the one supplied - TODO confirm)
+        Gson gson = new Gson();
+        JsonObject row = new JsonObject();
+        row.addProperty(ID_FIELD, testID);
+        List<Float> vector = CommonUtils.generateFloatVector(VECTOR_DIM);
+        row.add(VECTOR_FIELD, gson.toJsonTree(vector));
+        row.addProperty(TEXT_FIELD, "this field has been updated");
+        R<MutationResult> upsertResp = client.upsert(UpsertParam.newBuilder()
+                .withCollectionName(COLLECTION_NAME)
+                .withRows(Collections.singletonList(row))
+                .build());
+        CommonUtils.handleResponseStatus(upsertResp);
+        List<Long> newIds = upsertResp.getData().getIDs().getIntId().getDataList();
+        Long newID = newIds.get(0);
+        System.out.println("\nUpsert done");
+
+        // query after upsert
+        filter = String.format("%s == %d", ID_FIELD, newID);
+        queryWithExpr(filter);
+    }
+
+    /**
+     * Runs the upsert scenario for both primary-key modes and closes the client afterwards,
+     * even when a run fails.
+     *
+     * @param args unused
+     */
+    public static void main(String[] args) {
+        try {
+            doUpsert(true);
+            doUpsert(false);
+        } finally {
+            client.close();
+        }
+    }
+}
diff --git a/examples/src/main/java/io/milvus/v2/UpsertExample.java b/examples/src/main/java/io/milvus/v2/UpsertExample.java
new file mode 100644
index 000000000..d73e3c129
--- /dev/null
+++ b/examples/src/main/java/io/milvus/v2/UpsertExample.java
@@ -0,0 +1,203 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package io.milvus.v2;
+
+import com.google.gson.Gson;
+import com.google.gson.JsonObject;
+import io.milvus.v1.CommonUtils;
+import io.milvus.v2.client.ConnectConfig;
+import io.milvus.v2.client.MilvusClientV2;
+import io.milvus.v2.common.ConsistencyLevel;
+import io.milvus.v2.common.DataType;
+import io.milvus.v2.common.IndexParam;
+import io.milvus.v2.service.collection.request.AddFieldReq;
+import io.milvus.v2.service.collection.request.CreateCollectionReq;
+import io.milvus.v2.service.collection.request.DropCollectionReq;
+import io.milvus.v2.service.vector.request.InsertReq;
+import io.milvus.v2.service.vector.request.QueryReq;
+import io.milvus.v2.service.vector.request.UpsertReq;
+import io.milvus.v2.service.vector.response.InsertResp;
+import io.milvus.v2.service.vector.response.QueryResp;
+import io.milvus.v2.service.vector.response.UpsertResp;
+
+import java.util.*;
+
+/**
+ * Example of the upsert interface with the V2 client ({@link MilvusClientV2}).
+ *
+ * <p>Runs the same scenario twice: against a collection whose primary key is generated by
+ * the server and against one whose primary key is supplied by the client. Each run prints
+ * the affected row before and after the upsert.
+ */
+public class UpsertExample {
+    private static final MilvusClientV2 client;
+    static {
+        client = new MilvusClientV2(ConnectConfig.builder()
+                .uri("http://localhost:19530")
+                .build());
+    }
+
+    private static final String COLLECTION_NAME = "java_sdk_example_upsert_v2";
+    private static final String ID_FIELD = "pk";
+    private static final String VECTOR_FIELD = "vector";
+    private static final String TEXT_FIELD = "text";
+    private static final Integer VECTOR_DIM = 128;
+
+    /**
+     * Re-creates the example collection together with its index, then inserts 100 rows.
+     *
+     * @param autoID whether the primary key is generated by the server; when {@code false}
+     *               the rows are inserted with the primary keys 0..99
+     * @return the primary keys of the inserted rows as reported by the server
+     */
+    private static List<Object> createCollection(boolean autoID) {
+        // Drop collection if exists
+        client.dropCollection(DropCollectionReq.builder()
+                .collectionName(COLLECTION_NAME)
+                .build());
+
+        // Create collection
+        CreateCollectionReq.CollectionSchema collectionSchema = CreateCollectionReq.CollectionSchema.builder()
+                .build();
+        collectionSchema.addField(AddFieldReq.builder()
+                .fieldName(ID_FIELD)
+                .dataType(DataType.Int64)
+                .isPrimaryKey(Boolean.TRUE)
+                .autoID(autoID)
+                .build());
+        collectionSchema.addField(AddFieldReq.builder()
+                .fieldName(VECTOR_FIELD)
+                .dataType(DataType.FloatVector)
+                .dimension(VECTOR_DIM)
+                .build());
+        collectionSchema.addField(AddFieldReq.builder()
+                .fieldName(TEXT_FIELD)
+                .dataType(DataType.VarChar)
+                .maxLength(100)
+                .build());
+
+        List<IndexParam> indexes = new ArrayList<>();
+        indexes.add(IndexParam.builder()
+                .fieldName(VECTOR_FIELD)
+                .indexType(IndexParam.IndexType.FLAT)
+                .metricType(IndexParam.MetricType.COSINE)
+                .build());
+
+        CreateCollectionReq requestCreate = CreateCollectionReq.builder()
+                .collectionName(COLLECTION_NAME)
+                .collectionSchema(collectionSchema)
+                .indexParams(indexes)
+                .consistencyLevel(ConsistencyLevel.BOUNDED)
+                .build();
+        client.createCollection(requestCreate);
+        System.out.println("\nCollection created with autoID = " + autoID);
+
+        // Insert rows
+        Gson gson = new Gson();
+        List<JsonObject> rows = new ArrayList<>();
+        for (int i = 0; i < 100; i++) {
+            JsonObject row = new JsonObject();
+            if (!autoID) {
+                row.addProperty(ID_FIELD, i);
+            }
+            List<Float> vector = CommonUtils.generateFloatVector(VECTOR_DIM);
+            row.add(VECTOR_FIELD, gson.toJsonTree(vector));
+            row.addProperty(TEXT_FIELD, String.format("text_%d", i));
+            rows.add(row);
+        }
+        InsertResp resp = client.insert(InsertReq.builder()
+                .collectionName(COLLECTION_NAME)
+                .data(rows)
+                .build());
+        return resp.getPrimaryKeys();
+    }
+
+    /**
+     * Queries the example collection with STRONG consistency and prints the entity of every
+     * matching row, limited to the primary key and the text field.
+     *
+     * @param expr boolean filter expression, e.g. {@code "pk == 1"}
+     */
+    private static void queryWithExpr(String expr) {
+        QueryResp queryRet = client.query(QueryReq.builder()
+                .collectionName(COLLECTION_NAME)
+                .filter(expr)
+                .outputFields(Arrays.asList(ID_FIELD, TEXT_FIELD))
+                .consistencyLevel(ConsistencyLevel.STRONG)
+                .build());
+        System.out.println("\nQuery with expression: " + expr);
+        List<QueryResp.QueryResult> records = queryRet.getQueryResults();
+        for (QueryResp.QueryResult record : records) {
+            System.out.println(record.getEntity());
+        }
+    }
+
+    /**
+     * Upserts one existing row and prints it before and after the upsert.
+     *
+     * @param autoID whether the collection is created with a server-generated primary key
+     */
+    private static void doUpsert(boolean autoID) {
+        // if autoID is true, the collection primary key is auto-generated by server
+        List<Object> ids = createCollection(autoID);
+
+        // query before upsert
+        Long testID = (Long) ids.get(1);
+        String filter = String.format("%s == %d", ID_FIELD, testID);
+        queryWithExpr(filter);
+
+        // upsert
+        // when autoID is true the server deletes the old entity and creates a new one under a
+        // newly generated primary key; that key is returned in the response and used below
+        // (presumably, with autoID false the returned key equals the one supplied - TODO confirm)
+        Gson gson = new Gson();
+        JsonObject row = new JsonObject();
+        row.addProperty(ID_FIELD, testID);
+        List<Float> vector = CommonUtils.generateFloatVector(VECTOR_DIM);
+        row.add(VECTOR_FIELD, gson.toJsonTree(vector));
+        row.addProperty(TEXT_FIELD, "this field has been updated");
+        UpsertResp upsertResp = client.upsert(UpsertReq.builder()
+                .collectionName(COLLECTION_NAME)
+                .data(Collections.singletonList(row))
+                .build());
+        List<Object> newIds = upsertResp.getPrimaryKeys();
+        Long newID = (Long) newIds.get(0);
+        System.out.println("\nUpsert done");
+
+        // query after upsert
+        filter = String.format("%s == %d", ID_FIELD, newID);
+        queryWithExpr(filter);
+    }
+
+    /**
+     * Runs the upsert scenario for both primary-key modes and closes the client afterwards,
+     * even when a run fails.
+     *
+     * @param args unused
+     */
+    public static void main(String[] args) {
+        try {
+            doUpsert(true);
+            doUpsert(false);
+        } finally {
+            client.close();
+        }
+    }
+}
diff --git a/sdk-core/src/main/java/io/milvus/param/ParamUtils.java b/sdk-core/src/main/java/io/milvus/param/ParamUtils.java index c88f2579b..3618c0854 100644 --- a/sdk-core/src/main/java/io/milvus/param/ParamUtils.java +++ b/sdk-core/src/main/java/io/milvus/param/ParamUtils.java @@ -526,13 +526,6 @@ public InsertBuilderWrapper(@NonNull UpsertParam requestParam, DescCollResponseWrapper wrapper) { String collectionName = requestParam.getCollectionName(); - // 
currently, not allow to upsert for collection whose primary key is auto-generated - FieldType pk = wrapper.getPrimaryField(); - if (pk.isAutoID()) { - throw new ParamException(String.format("Upsert don't support autoID==True, collection: %s", - requestParam.getCollectionName())); - } - // generate upsert request builder MsgBase msgBase = MsgBase.newBuilder().setMsgType(MsgType.Insert).build(); upsertBuilder = UpsertRequest.newBuilder() @@ -601,7 +594,8 @@ private void checkAndSetColumnData(DescCollResponseWrapper wrapper, List ids = new ArrayList<>(); if (response.getIDs().hasIntId()) { - List ids = new ArrayList<>(response.getIDs().getIntId().getDataList()); - return InsertResp.builder() - .InsertCnt(response.getInsertCnt()) - .primaryKeys(ids) - .build(); - } else { - List ids = new ArrayList<>(response.getIDs().getStrId().getDataList()); - return InsertResp.builder() - .InsertCnt(response.getInsertCnt()) - .primaryKeys(ids) - .build(); + ids = new ArrayList<>(response.getIDs().getIntId().getDataList()); + } else if (response.getIDs().hasStrId()) { + ids = new ArrayList<>(response.getIDs().getStrId().getDataList()); } + return InsertResp.builder() + .InsertCnt(response.getInsertCnt()) + .primaryKeys(ids) + .build(); } private UpsertRequest buildUpsertRequest(UpsertReq request, DescribeCollectionResponse descResp) { @@ -207,8 +205,17 @@ public UpsertResp upsert(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStu // update the last write timestamp for SESSION consistency String key = GTsDict.CombineCollectionName(actualDbName(dbName), collectionName); GTsDict.getInstance().updateCollectionTs(key, response.getTimestamp()); + + // handle integer pk or string pk + List ids = new ArrayList<>(); + if (response.getIDs().hasIntId()) { + ids = new ArrayList<>(response.getIDs().getIntId().getDataList()); + } else if (response.getIDs().hasStrId()) { + ids = new ArrayList<>(response.getIDs().getStrId().getDataList()); + } return UpsertResp.builder() 
.upsertCnt(response.getUpsertCnt()) + .primaryKeys(ids) .build(); } diff --git a/sdk-core/src/main/java/io/milvus/v2/service/vector/response/InsertResp.java b/sdk-core/src/main/java/io/milvus/v2/service/vector/response/InsertResp.java index 6b191d3f7..3f1ca9788 100644 --- a/sdk-core/src/main/java/io/milvus/v2/service/vector/response/InsertResp.java +++ b/sdk-core/src/main/java/io/milvus/v2/service/vector/response/InsertResp.java @@ -29,6 +29,7 @@ @Data @SuperBuilder public class InsertResp { + // TODO: the first character should be lower case, add a new member and deprecate the old member private long InsertCnt; @Builder.Default private List primaryKeys = new ArrayList<>(); diff --git a/sdk-core/src/main/java/io/milvus/v2/service/vector/response/UpsertResp.java b/sdk-core/src/main/java/io/milvus/v2/service/vector/response/UpsertResp.java index 5c168b08c..b54c7e33c 100644 --- a/sdk-core/src/main/java/io/milvus/v2/service/vector/response/UpsertResp.java +++ b/sdk-core/src/main/java/io/milvus/v2/service/vector/response/UpsertResp.java @@ -19,11 +19,21 @@ package io.milvus.v2.service.vector.response; +import lombok.Builder; import lombok.Data; import lombok.experimental.SuperBuilder; +import java.util.ArrayList; +import java.util.List; + @Data @SuperBuilder public class UpsertResp { private long upsertCnt; + + // From v2.4.10, Milvus allows upsert on a collection whose primary key is auto-generated (autoID=true). + // In that case the server deletes the original entity and creates a new one under a new primary key, + // which differs from the original one. The primary keys reported by the server are returned here. 
+ @Builder.Default + private List primaryKeys = new ArrayList<>(); } diff --git a/sdk-core/src/test/java/io/milvus/v2/client/MilvusClientV2DockerTest.java b/sdk-core/src/test/java/io/milvus/v2/client/MilvusClientV2DockerTest.java index 2ab4ee7ec..993a3f0cc 100644 --- a/sdk-core/src/test/java/io/milvus/v2/client/MilvusClientV2DockerTest.java +++ b/sdk-core/src/test/java/io/milvus/v2/client/MilvusClientV2DockerTest.java @@ -1164,6 +1164,7 @@ void testDeleteUpsert() { .data(dataUpdate) .build()); Assertions.assertEquals(2, upsertResp.getUpsertCnt()); + Assertions.assertEquals(2, upsertResp.getPrimaryKeys().size()); // get row count rowCount = getRowCount(randomCollectionName);