diff --git a/examples/src/main/java/io/milvus/v1/BulkWriterExample.java b/examples/src/main/java/io/milvus/v1/BulkWriterExample.java
index c2b5afb2a..bf00f6862 100644
--- a/examples/src/main/java/io/milvus/v1/BulkWriterExample.java
+++ b/examples/src/main/java/io/milvus/v1/BulkWriterExample.java
@@ -18,15 +18,12 @@
  */
 package io.milvus.v1;
 
-import com.fasterxml.jackson.annotation.JsonProperty;
 import com.fasterxml.jackson.dataformat.csv.CsvMapper;
 import com.fasterxml.jackson.dataformat.csv.CsvSchema;
 import com.google.common.collect.Lists;
 import com.google.gson.Gson;
 import com.google.gson.JsonElement;
 import com.google.gson.JsonObject;
-import com.google.gson.reflect.TypeToken;
-import io.milvus.bulkwriter.BulkImport;
 import io.milvus.bulkwriter.BulkWriter;
 import io.milvus.bulkwriter.LocalBulkWriter;
 import io.milvus.bulkwriter.LocalBulkWriterParam;
@@ -46,6 +43,7 @@ import io.milvus.bulkwriter.request.import_.MilvusImportRequest;
 import io.milvus.bulkwriter.request.list.CloudListImportJobsRequest;
 import io.milvus.bulkwriter.request.list.MilvusListImportJobsRequest;
+import io.milvus.bulkwriter.restful.BulkImportUtils;
 import io.milvus.client.MilvusClient;
 import io.milvus.client.MilvusServiceClient;
 import io.milvus.common.utils.ExceptionUtils;
@@ -69,13 +67,13 @@ import io.milvus.param.index.CreateIndexParam;
 import io.milvus.response.GetCollStatResponseWrapper;
 import io.milvus.response.QueryResultsWrapper;
+import io.milvus.v2.bulkwriter.CsvDataObject;
 import org.apache.avro.generic.GenericData;
 import org.apache.http.util.Asserts;
 
 import java.io.File;
 import java.io.IOException;
 import java.net.URL;
-import java.nio.ByteBuffer;
 import java.util.ArrayList;
 import java.util.Iterator;
 import java.util.List;
@@ -491,29 +489,6 @@ private static void readCsvSampleData(String filePath, BulkWriter writer) throws
         }
     }
 
-    private static class CsvDataObject {
-        @JsonProperty
-        private String vector;
-        @JsonProperty
-        private String path;
-        @JsonProperty
-        private String label;
-
-        public String getVector() {
-            return vector;
-        }
-        public String getPath() {
-            return path;
-        }
-        public String getLabel() {
-            return label;
-        }
-        public List<Float> toFloatArray() {
-            return GSON_INSTANCE.fromJson(vector, new TypeToken<List<Float>>() {
-            }.getType());
-        }
-    }
-
     private void callBulkInsert(CollectionSchemaParam collectionSchema, List<List<String>> batchFiles) throws InterruptedException {
         createCollection(ALL_TYPES_COLLECTION_NAME, collectionSchema, true);
@@ -524,7 +499,7 @@ private void callCloudImport(List<List<String>> batchFiles, String collectionNam
                 .clusterId(CloudImportConsts.CLUSTER_ID).collectionName(collectionName).partitionName(partitionName)
                 .apiKey(CloudImportConsts.API_KEY)
                 .build();
-        String bulkImportResult = BulkImport.bulkImport(CloudImportConsts.CLOUD_ENDPOINT, bulkImportRequest);
+        String bulkImportResult = BulkImportUtils.bulkImport(CloudImportConsts.CLOUD_ENDPOINT, bulkImportRequest);
         JsonObject bulkImportObject = convertJsonObject(bulkImportResult);
 
         String jobId = bulkImportObject.getAsJsonObject("data").get("jobId").getAsString();
@@ -585,7 +560,7 @@ private void callCloudImport(List<List<String>> batchFiles, String collectionNam
 
         System.out.println("\n===================== call cloudListImportJobs ====================");
         CloudListImportJobsRequest listImportJobsRequest = CloudListImportJobsRequest.builder().clusterId(CloudImportConsts.CLUSTER_ID).currentPage(1).pageSize(10).apiKey(CloudImportConsts.API_KEY).build();
-        String listImportJobsResult = BulkImport.listImportJobs(CloudImportConsts.CLOUD_ENDPOINT, listImportJobsRequest);
+        String listImportJobsResult = BulkImportUtils.listImportJobs(CloudImportConsts.CLOUD_ENDPOINT, listImportJobsRequest);
         System.out.println(listImportJobsResult);
         while (true) {
             System.out.println("Wait 5 second to check bulkInsert job state...");
@@ -593,7 +568,7 @@ private void callCloudImport(List<List<String>> batchFiles, String collectionNam
 
             System.out.println("\n===================== call cloudGetProgress ====================");
             CloudDescribeImportRequest request = CloudDescribeImportRequest.builder().clusterId(CloudImportConsts.CLUSTER_ID).jobId(jobId).apiKey(CloudImportConsts.API_KEY).build();
-            String getImportProgressResult = BulkImport.getImportProgress(CloudImportConsts.CLOUD_ENDPOINT, request);
+            String getImportProgressResult = BulkImportUtils.getImportProgress(CloudImportConsts.CLOUD_ENDPOINT, request);
             JsonObject getImportProgressObject = convertJsonObject(getImportProgressResult);
             String importProgressState = getImportProgressObject.getAsJsonObject("data").get("state").getAsString();
             String progress = getImportProgressObject.getAsJsonObject("data").get("progress").getAsString();
@@ -740,7 +715,7 @@ private static void exampleCloudImport() {
                 .clusterId(CloudImportConsts.CLUSTER_ID).collectionName(CloudImportConsts.COLLECTION_NAME).partitionName(CloudImportConsts.PARTITION_NAME)
                 .apiKey(CloudImportConsts.API_KEY)
                 .build();
-        String bulkImportResult = BulkImport.bulkImport(CloudImportConsts.CLOUD_ENDPOINT, request);
+        String bulkImportResult = BulkImportUtils.bulkImport(CloudImportConsts.CLOUD_ENDPOINT, request);
         System.out.println(bulkImportResult);
 
         System.out.println("\n===================== get import job progress ====================");
@@ -748,12 +723,12 @@ private static void exampleCloudImport() {
         JsonObject bulkImportObject = convertJsonObject(bulkImportResult);
         String jobId = bulkImportObject.getAsJsonObject("data").get("jobId").getAsString();
         CloudDescribeImportRequest getImportProgressRequest = CloudDescribeImportRequest.builder().clusterId(CloudImportConsts.CLUSTER_ID).jobId(jobId).apiKey(CloudImportConsts.API_KEY).build();
-        String getImportProgressResult = BulkImport.getImportProgress(CloudImportConsts.CLOUD_ENDPOINT, getImportProgressRequest);
+        String getImportProgressResult = BulkImportUtils.getImportProgress(CloudImportConsts.CLOUD_ENDPOINT, getImportProgressRequest);
         System.out.println(getImportProgressResult);
 
         System.out.println("\n===================== list import jobs ====================");
         CloudListImportJobsRequest listImportJobsRequest = CloudListImportJobsRequest.builder().clusterId(CloudImportConsts.CLUSTER_ID).currentPage(1).pageSize(10).apiKey(CloudImportConsts.API_KEY).build();
-        String listImportJobsResult = BulkImport.listImportJobs(CloudImportConsts.CLOUD_ENDPOINT, listImportJobsRequest);
+        String listImportJobsResult = BulkImportUtils.listImportJobs(CloudImportConsts.CLOUD_ENDPOINT, listImportJobsRequest);
         System.out.println(listImportJobsResult);
     }
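The only behavioral change in this file is the swap of the removed BulkImport entry point for BulkImportUtils in io.milvus.bulkwriter.restful (the shared CsvDataObject now lives in io.milvus.v2.bulkwriter). For reference, a minimal migration sketch; the caller class and method names here are illustrative, only the BulkImportUtils calls come from the diff above:

    // Minimal sketch of the BulkImport -> BulkImportUtils migration.
    // Only the entry-point class moves; request builders and endpoints are unchanged.
    import io.milvus.bulkwriter.restful.BulkImportUtils;
    import io.milvus.bulkwriter.request.import_.CloudImportRequest;

    class BulkImportMigrationSketch {
        static String runCloudImport(String endpoint, CloudImportRequest request) {
            // before this change: BulkImport.bulkImport(endpoint, request);
            return BulkImportUtils.bulkImport(endpoint, request);
        }
    }
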
diff --git a/examples/src/main/java/io/milvus/v1/GeneralExample.java b/examples/src/main/java/io/milvus/v1/GeneralExample.java
index 4f00401a7..ee24f7514 100644
--- a/examples/src/main/java/io/milvus/v1/GeneralExample.java
+++ b/examples/src/main/java/io/milvus/v1/GeneralExample.java
@@ -66,7 +66,7 @@ public class GeneralExample {
 
     private static final Long SEARCH_K = 5L;
     private static final String SEARCH_PARAM = "{\"nprobe\":10}";
-    
+
     private R<RpcStatus> createCollection(long timeoutMilliseconds) {
         System.out.println("========== createCollection() ==========");
diff --git a/examples/src/main/java/io/milvus/v2/GeneralExample.java b/examples/src/main/java/io/milvus/v2/GeneralExample.java
index 90a91620f..dc7dcb79c 100644
--- a/examples/src/main/java/io/milvus/v2/GeneralExample.java
+++ b/examples/src/main/java/io/milvus/v2/GeneralExample.java
@@ -231,4 +231,4 @@ public static void main(String[] args) {
 
         releaseCollection();
     }
-}
+}
\ No newline at end of file
diff --git a/examples/src/main/java/io/milvus/v2/StageExample.java b/examples/src/main/java/io/milvus/v2/StageExample.java
new file mode 100644
index 000000000..5b1d15593
--- /dev/null
+++ b/examples/src/main/java/io/milvus/v2/StageExample.java
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package io.milvus.v2;
+
+import io.milvus.bulkwriter.StageOperation;
+import io.milvus.bulkwriter.StageOperationParam;
+import io.milvus.bulkwriter.model.StageUploadResult;
+
+
+/**
+ * If you don't have your own bucket but still want to upload data to a bucket
+ * and import it into Milvus, you can use this example.
+ */
+public class StageExample {
+    /**
+     * The local file path or folder path to upload for import.
+     */
+    public static final String LOCAL_DIR_OR_FILE_PATH = "/Users/zilliz/Desktop/1.parquet";
+
+    /**
+     * The value of the URL is fixed.
+     * For overseas regions, it is: https://api.cloud.zilliz.com
+     * For regions in China, it is: https://api.cloud.zilliz.com.cn
+     */
+    public static final String CLOUD_ENDPOINT = "https://api.cloud.zilliz.com";
+    public static final String API_KEY = "_api_key_for_cluster_org_";
+    /**
+     * This is currently a private preview feature. If you need to use it, please submit a request and contact us.
+     * Before using this feature, you need to create a stage using the stage API.
+     */
+    public static final String STAGE_NAME = "_stage_name_for_project_";
+    public static final String PATH = "_path_for_stage";
+
+    public static void main(String[] args) throws Exception {
+        uploadFileToStage();
+    }
+
+    /**
+     * Upload a local file or folder to the stage; you can then use the cloud import API
+     * to merge the data into a collection.
+     */
+    private static void uploadFileToStage() throws Exception {
+        StageOperationParam stageOperationParam = StageOperationParam.newBuilder()
+                .withCloudEndpoint(CLOUD_ENDPOINT).withApiKey(API_KEY)
+                .withStageName(STAGE_NAME).withPath(PATH)
+                .build();
+        StageOperation stageOperation = new StageOperation(stageOperationParam);
+        StageUploadResult result = stageOperation.uploadFileToStage(LOCAL_DIR_OR_FILE_PATH);
+        System.out.println("\nuploadFileToStage results: " + result);
+    }
+}
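StageUploadResult carries the stage name and path back from the upload; BulkWriterStageExample later in this diff feeds those two values into a StageImportRequest. A condensed sketch of that hand-off, with cluster and collection identifiers as placeholders:

    // Sketch: chain the upload result into a stage import job
    // (see BulkWriterStageExample below for the full flow).
    StageUploadResult result = stageOperation.uploadFileToStage(LOCAL_DIR_OR_FILE_PATH);
    StageImportRequest importRequest = StageImportRequest.builder()
            .apiKey(API_KEY)
            .stageName(result.getStageName())
            .dataPaths(Lists.newArrayList(Collections.singleton(Lists.newArrayList(result.getPath()))))
            .clusterId("_your_cloud_cluster_id_")
            .collectionName("_collection_name_on_the_db_")
            .build();
    String jobJson = BulkImportUtils.bulkImport(CLOUD_ENDPOINT, importRequest);
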
diff --git a/examples/src/main/java/io/milvus/v2/bulkwriter/BulkWriterLocalExample.java b/examples/src/main/java/io/milvus/v2/bulkwriter/BulkWriterLocalExample.java
new file mode 100644
index 000000000..563e86308
--- /dev/null
+++ b/examples/src/main/java/io/milvus/v2/bulkwriter/BulkWriterLocalExample.java
@@ -0,0 +1,293 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package io.milvus.v2.bulkwriter;
+
+import com.fasterxml.jackson.dataformat.csv.CsvMapper;
+import com.fasterxml.jackson.dataformat.csv.CsvSchema;
+import com.google.common.collect.Lists;
+import com.google.gson.Gson;
+import com.google.gson.JsonObject;
+import io.milvus.bulkwriter.BulkWriter;
+import io.milvus.bulkwriter.LocalBulkWriter;
+import io.milvus.bulkwriter.LocalBulkWriterParam;
+import io.milvus.bulkwriter.common.clientenum.BulkFileType;
+import io.milvus.bulkwriter.common.utils.GeneratorUtils;
+import io.milvus.bulkwriter.common.utils.ParquetReaderUtils;
+import io.milvus.v2.client.ConnectConfig;
+import io.milvus.v2.client.MilvusClientV2;
+import io.milvus.v2.common.ConsistencyLevel;
+import io.milvus.v2.common.DataType;
+import io.milvus.v2.service.collection.request.AddFieldReq;
+import io.milvus.v2.service.collection.request.CreateCollectionReq;
+import io.milvus.v2.service.collection.request.DropCollectionReq;
+import io.milvus.v2.service.collection.request.HasCollectionReq;
+import org.apache.avro.generic.GenericData;
+import org.apache.http.util.Asserts;
+
+import java.io.File;
+import java.io.IOException;
+import java.net.URL;
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
+
+
+public class BulkWriterLocalExample {
+    // milvus
+    public static final String HOST = "127.0.0.1";
+    public static final Integer PORT = 19530;
+    public static final String USER_NAME = "user.name";
+    public static final String PASSWORD = "password";
+
+    private static final Gson GSON_INSTANCE = new Gson();
+    private static final String SIMPLE_COLLECTION_NAME = "java_sdk_bulkwriter_simple_v2";
+    private static final Integer DIM = 512;
+    private static MilvusClientV2 milvusClient;
+
+    public static void main(String[] args) throws Exception {
+        createConnection();
+        List<BulkFileType> fileTypes = Lists.newArrayList(
+                BulkFileType.PARQUET,
+                BulkFileType.JSON,
+                BulkFileType.CSV
+        );
+
+        exampleSimpleCollection(fileTypes);
+    }
+
+    private static void createConnection() {
+        System.out.println("\nCreate connection...");
+        String url = String.format("http://%s:%s", HOST, PORT);
+        milvusClient = new MilvusClientV2(ConnectConfig.builder()
+                .uri(url)
+                .username(USER_NAME)
+                .password(PASSWORD)
+                .build());
+        System.out.println("\nConnected");
+    }
+
+    private static void exampleSimpleCollection(List<BulkFileType> fileTypes) throws Exception {
+        CreateCollectionReq.CollectionSchema collectionSchema = buildSimpleSchema();
+        createCollection(SIMPLE_COLLECTION_NAME, collectionSchema, false);
+
+        for (BulkFileType fileType : fileTypes) {
+            localWriter(collectionSchema, fileType);
+        }
+
+        // parallel append
+        parallelAppend(collectionSchema);
+    }
row.addProperty("label", "label_" + i); + + localBulkWriter.appendRow(row); + } + + System.out.printf("%s rows appends%n", localBulkWriter.getTotalRowCount()); + + localBulkWriter.commit(false); + List> batchFiles = localBulkWriter.getBatchFiles(); + System.out.printf("Local writer done! output local files: %s%n", batchFiles); + } catch (Exception e) { + System.out.println("Local writer catch exception: " + e); + throw e; + } + } + + private static void parallelAppend(CreateCollectionReq.CollectionSchema collectionSchema) throws Exception { + System.out.print("\n===================== parallel append ===================="); + LocalBulkWriterParam bulkWriterParam = LocalBulkWriterParam.newBuilder() + .withCollectionSchema(collectionSchema) + .withLocalPath("/tmp/bulk_writer") + .withFileType(BulkFileType.PARQUET) + .withChunkSize(128 * 1024 * 1024) // 128MB + .build(); + + try (LocalBulkWriter localBulkWriter = new LocalBulkWriter(bulkWriterParam)) { + List threads = new ArrayList<>(); + int threadCount = 10; + int rowsPerThread = 1000; + for (int i = 0; i < threadCount; ++i) { + int current = i; + Thread thread = new Thread(() -> appendRow(localBulkWriter, current * rowsPerThread, (current + 1) * rowsPerThread)); + threads.add(thread); + thread.start(); + System.out.printf("Thread %s started%n", thread.getName()); + } + + for (Thread thread : threads) { + thread.join(); + System.out.printf("Thread %s finished%n", thread.getName()); + } + + System.out.println(localBulkWriter.getTotalRowCount() + " rows appends"); + localBulkWriter.commit(false); + System.out.printf("Append finished, %s rows%n", threadCount * rowsPerThread); + + long rowCount = 0; + List> batchFiles = localBulkWriter.getBatchFiles(); + for (List batch : batchFiles) { + for (String filePath : batch) { + rowCount += readParquet(filePath); + } + } + + Asserts.check(rowCount == threadCount * rowsPerThread, String.format("rowCount %s not equals expected %s", rowCount, threadCount * rowsPerThread)); + System.out.println("Data is correct"); + } catch (Exception e) { + System.out.println("parallelAppend catch exception: " + e); + throw e; + } + } + + private static long readParquet(String localFilePath) throws Exception { + final long[] rowCount = {0}; + new ParquetReaderUtils() { + @Override + public void readRecord(GenericData.Record record) { + rowCount[0]++; + String pathValue = record.get("path").toString(); + String labelValue = record.get("label").toString(); + Asserts.check(pathValue.replace("path_", "").equals(labelValue.replace("label_", "")), String.format("the suffix of %s not equals the suffix of %s", pathValue, labelValue)); + } + }.readParquet(localFilePath); + System.out.printf("The file %s contains %s rows. 
+
+    private static long readParquet(String localFilePath) throws Exception {
+        final long[] rowCount = {0};
+        new ParquetReaderUtils() {
+            @Override
+            public void readRecord(GenericData.Record record) {
+                rowCount[0]++;
+                String pathValue = record.get("path").toString();
+                String labelValue = record.get("label").toString();
+                Asserts.check(pathValue.replace("path_", "").equals(labelValue.replace("label_", "")), String.format("the suffix of %s does not equal the suffix of %s", pathValue, labelValue));
+            }
+        }.readParquet(localFilePath);
+        System.out.printf("The file %s contains %s rows. Verify the content...%n", localFilePath, rowCount[0]);
+        return rowCount[0];
+    }
+
+    private static void appendRow(LocalBulkWriter writer, int begin, int end) {
+        try {
+            for (int i = begin; i < end; ++i) {
+                JsonObject row = new JsonObject();
+                row.addProperty("path", "path_" + i);
+                row.add("vector", GSON_INSTANCE.toJsonTree(GeneratorUtils.genFloatVector(DIM)));
+                row.addProperty("label", "label_" + i);
+
+                writer.appendRow(row);
+                if (i % 100 == 0) {
+                    System.out.printf("%s inserted %s items%n", Thread.currentThread().getName(), i - begin);
+                }
+            }
+        } catch (Exception e) {
+            System.out.println("failed to append row!");
+        }
+    }
+
+    private static void readCsvSampleData(String filePath, BulkWriter writer) throws IOException, InterruptedException {
+        ClassLoader classLoader = BulkWriterLocalExample.class.getClassLoader();
+        URL resourceUrl = classLoader.getResource(filePath);
+        filePath = new File(resourceUrl.getFile()).getAbsolutePath();
+
+        CsvMapper csvMapper = new CsvMapper();
+
+        File csvFile = new File(filePath);
+        CsvSchema csvSchema = CsvSchema.builder().setUseHeader(true).build();
+        Iterator<CsvDataObject> iterator = csvMapper.readerFor(CsvDataObject.class).with(csvSchema).readValues(csvFile);
+        while (iterator.hasNext()) {
+            CsvDataObject dataObject = iterator.next();
+            JsonObject row = new JsonObject();
+
+            row.add("vector", GSON_INSTANCE.toJsonTree(dataObject.toFloatArray()));
+            row.addProperty("label", dataObject.getLabel());
+            row.addProperty("path", dataObject.getPath());
+
+            writer.appendRow(row);
+        }
+    }
+
+    /**
+     * @param collectionSchema collection info
+     * @param dropIfExist if the collection already exists, drop it first and then create it again
+     */
+    private static void createCollection(String collectionName, CreateCollectionReq.CollectionSchema collectionSchema, boolean dropIfExist) {
+        System.out.println("\n===================== create collection ====================");
+        checkMilvusClientIfExist();
+
+        CreateCollectionReq requestCreate = CreateCollectionReq.builder()
+                .collectionName(collectionName)
+                .collectionSchema(collectionSchema)
+                .consistencyLevel(ConsistencyLevel.BOUNDED)
+                .build();
+
+        Boolean has = milvusClient.hasCollection(HasCollectionReq.builder().collectionName(collectionName).build());
+        if (has) {
+            if (dropIfExist) {
+                milvusClient.dropCollection(DropCollectionReq.builder().collectionName(collectionName).build());
+                milvusClient.createCollection(requestCreate);
+            }
+        } else {
+            milvusClient.createCollection(requestCreate);
+        }
+
+        System.out.printf("Collection %s created%n", collectionName);
+    }
+
+    private static CreateCollectionReq.CollectionSchema buildSimpleSchema() {
+        CreateCollectionReq.CollectionSchema schemaV2 = CreateCollectionReq.CollectionSchema.builder()
+                .build();
+        schemaV2.addField(AddFieldReq.builder()
+                .fieldName("id")
+                .dataType(DataType.Int64)
+                .isPrimaryKey(Boolean.TRUE)
+                .autoID(true)
+                .build());
+        schemaV2.addField(AddFieldReq.builder()
+                .fieldName("path")
+                .dataType(DataType.VarChar)
+                .maxLength(512)
+                .build());
+        schemaV2.addField(AddFieldReq.builder()
+                .fieldName("label")
+                .dataType(DataType.VarChar)
+                .maxLength(512)
+                .build());
+        schemaV2.addField(AddFieldReq.builder()
+                .fieldName("vector")
+                .dataType(DataType.FloatVector)
+                .dimension(DIM)
+                .build());
+
+        return schemaV2;
+    }
+
+    private static void checkMilvusClientIfExist() {
+        if (milvusClient == null) {
+            String msg = "milvusClient is null. Please initialize it by calling createConnection() before use.";
+            throw new RuntimeException(msg);
+        }
+    }
+}
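The files that LocalBulkWriter leaves under /tmp/bulk_writer can be handed to the RESTful bulk-import API, provided the Milvus server can read those paths; BulkWriterRemoteExample below does the same against remote storage. A minimal sketch using the request types from this diff:

    // Sketch: import LocalBulkWriter output via the v2 RESTful helper.
    // Assumes the server can access the local paths in the batch file list.
    MilvusImportRequest request = MilvusImportRequest.builder()
            .collectionName(SIMPLE_COLLECTION_NAME)
            .files(localBulkWriter.getBatchFiles())
            .build();
    String result = BulkImportUtils.bulkImport("http://127.0.0.1:19530", request);
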
diff --git a/examples/src/main/java/io/milvus/v2/BulkWriterExample.java b/examples/src/main/java/io/milvus/v2/bulkwriter/BulkWriterRemoteExample.java
similarity index 79%
rename from examples/src/main/java/io/milvus/v2/BulkWriterExample.java
rename to examples/src/main/java/io/milvus/v2/bulkwriter/BulkWriterRemoteExample.java
index 5320eecfc..ff28b9a75 100644
--- a/examples/src/main/java/io/milvus/v2/BulkWriterExample.java
+++ b/examples/src/main/java/io/milvus/v2/bulkwriter/BulkWriterRemoteExample.java
@@ -16,9 +16,8 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-package io.milvus.v2;
+package io.milvus.v2.bulkwriter;
 
-import com.fasterxml.jackson.annotation.JsonProperty;
 import com.fasterxml.jackson.dataformat.csv.CsvMapper;
 import com.fasterxml.jackson.dataformat.csv.CsvSchema;
 import com.google.common.collect.Lists;
@@ -26,18 +25,12 @@
 import com.google.gson.JsonElement;
 import com.google.gson.JsonNull;
 import com.google.gson.JsonObject;
-import com.google.gson.reflect.TypeToken;
-import io.milvus.bulkwriter.BulkImport;
 import io.milvus.bulkwriter.BulkWriter;
-import io.milvus.bulkwriter.LocalBulkWriter;
-import io.milvus.bulkwriter.LocalBulkWriterParam;
 import io.milvus.bulkwriter.RemoteBulkWriter;
 import io.milvus.bulkwriter.RemoteBulkWriterParam;
 import io.milvus.bulkwriter.common.clientenum.BulkFileType;
 import io.milvus.bulkwriter.common.clientenum.CloudStorage;
 import io.milvus.bulkwriter.common.utils.GeneratorUtils;
-import io.milvus.bulkwriter.common.utils.ImportUtils;
-import io.milvus.bulkwriter.common.utils.ParquetReaderUtils;
 import io.milvus.bulkwriter.connect.AzureConnectParam;
 import io.milvus.bulkwriter.connect.S3ConnectParam;
 import io.milvus.bulkwriter.connect.StorageConnectParam;
@@ -47,28 +40,38 @@
 import io.milvus.bulkwriter.request.import_.MilvusImportRequest;
 import io.milvus.bulkwriter.request.list.CloudListImportJobsRequest;
 import io.milvus.bulkwriter.request.list.MilvusListImportJobsRequest;
+import io.milvus.bulkwriter.restful.BulkImportUtils;
 import io.milvus.v1.CommonUtils;
 import io.milvus.v2.client.ConnectConfig;
 import io.milvus.v2.client.MilvusClientV2;
 import io.milvus.v2.common.ConsistencyLevel;
 import io.milvus.v2.common.DataType;
 import io.milvus.v2.common.IndexParam;
-import io.milvus.v2.service.collection.request.*;
+import io.milvus.v2.service.collection.request.AddFieldReq;
+import io.milvus.v2.service.collection.request.CreateCollectionReq;
+import io.milvus.v2.service.collection.request.DropCollectionReq;
+import io.milvus.v2.service.collection.request.HasCollectionReq;
+import io.milvus.v2.service.collection.request.LoadCollectionReq;
+import io.milvus.v2.service.collection.request.RefreshLoadReq;
 import io.milvus.v2.service.index.request.CreateIndexReq;
 import io.milvus.v2.service.vector.request.QueryReq;
 import io.milvus.v2.service.vector.response.QueryResp;
-import org.apache.avro.generic.GenericData;
-import org.apache.http.util.Asserts;
 
 import java.io.File;
 import java.io.IOException;
 import java.net.URL;
 import java.nio.ByteBuffer;
-import java.util.*;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
 import java.util.concurrent.TimeUnit;
 
-public class BulkWriterExample {
+public class BulkWriterRemoteExample {
     // milvus
     public static final String HOST = "127.0.0.1";
     public static final Integer PORT = 19530;
@@ -177,16 +180,9 @@ private static void exampleSimpleCollection(List<BulkFileType> fileTypes) throws
         CreateCollectionReq.CollectionSchema collectionSchema = buildSimpleSchema();
         createCollection(SIMPLE_COLLECTION_NAME, collectionSchema, false);
 
-        for (BulkFileType fileType : fileTypes) {
-            localWriter(collectionSchema, fileType);
-        }
-
         for (BulkFileType fileType : fileTypes) {
             remoteWriter(collectionSchema, fileType);
         }
-
-        // parallel append
-        parallelAppend(collectionSchema);
     }
 
     private static void exampleAllTypesCollectionRemote(List<BulkFileType> fileTypes) throws Exception {
@@ -197,7 +193,7 @@ private static void exampleAllTypesCollectionRemote(List<BulkFileType> fileTypes
         for (BulkFileType fileType : fileTypes) {
             CreateCollectionReq.CollectionSchema collectionSchema = buildAllTypesSchema();
             List<List<String>> batchFiles = allTypesRemoteWriter(collectionSchema, fileType, rows);
-            createCollection(ALL_TYPES_COLLECTION_NAME, collectionSchema, true);
+            createCollection(ALL_TYPES_COLLECTION_NAME, collectionSchema, false);
             callBulkInsert(batchFiles);
             verifyImportData(collectionSchema, originalData);
         }
@@ -207,46 +203,12 @@ private static void exampleAllTypesCollectionRemote(List<BulkFileType> fileTypes
 //        for (BulkFileType fileType : fileTypes) {
 //            CreateCollectionReq.CollectionSchema collectionSchema = buildAllTypesSchema();
 //            List<List<String>> batchFiles = allTypesRemoteWriter(collectionSchema, fileType, rows);
-//            createCollection(ALL_TYPES_COLLECTION_NAME, collectionSchema, true);
+//            createCollection(ALL_TYPES_COLLECTION_NAME, collectionSchema, false);
 //            callCloudImport(batchFiles, ALL_TYPES_COLLECTION_NAME, "");
 //            verifyImportData(collectionSchema, originalData);
 //        }
     }
-
-    private static void localWriter(CreateCollectionReq.CollectionSchema collectionSchema, BulkFileType fileType) throws Exception {
-        System.out.printf("\n===================== local writer (%s) ====================%n", fileType.name());
-        LocalBulkWriterParam bulkWriterParam = LocalBulkWriterParam.newBuilder()
-                .withCollectionSchema(collectionSchema)
-                .withLocalPath("/tmp/bulk_writer")
-                .withFileType(fileType)
-                .withChunkSize(128 * 1024 * 1024)
-                .build();
-
-        try (LocalBulkWriter localBulkWriter = new LocalBulkWriter(bulkWriterParam)) {
-            // read data from csv
-            readCsvSampleData("data/train_embeddings.csv", localBulkWriter);
-
-            // append rows
-            for (int i = 0; i < 100000; i++) {
-                JsonObject row = new JsonObject();
-                row.addProperty("path", "path_" + i);
-                row.add("vector", GSON_INSTANCE.toJsonTree(GeneratorUtils.genFloatVector(DIM)));
-                row.addProperty("label", "label_" + i);
-
-                localBulkWriter.appendRow(row);
-            }
-
-            System.out.printf("%s rows appends%n", localBulkWriter.getTotalRowCount());
-
-            localBulkWriter.commit(false);
-            List<List<String>> batchFiles = localBulkWriter.getBatchFiles();
-            System.out.printf("Local writer done! output local files: %s%n", batchFiles);
-        } catch (Exception e) {
-            System.out.println("Local writer catch exception: " + e);
-            throw e;
-        }
-    }
-
     private static void remoteWriter(CreateCollectionReq.CollectionSchema collectionSchema, BulkFileType fileType) throws Exception {
         System.out.printf("\n===================== remote writer (%s) ====================%n", fileType.name());
 
@@ -276,85 +238,6 @@ private static void remoteWriter(CreateCollectionReq.CollectionSchema collection
         }
     }
 
-    private static void parallelAppend(CreateCollectionReq.CollectionSchema collectionSchema) throws Exception {
-        System.out.print("\n===================== parallel append ====================");
-        LocalBulkWriterParam bulkWriterParam = LocalBulkWriterParam.newBuilder()
-                .withCollectionSchema(collectionSchema)
-                .withLocalPath("/tmp/bulk_writer")
-                .withFileType(BulkFileType.PARQUET)
-                .withChunkSize(128 * 1024 * 1024) // 128MB
-                .build();
-
-        try (LocalBulkWriter localBulkWriter = new LocalBulkWriter(bulkWriterParam)) {
-            List<Thread> threads = new ArrayList<>();
-            int threadCount = 10;
-            int rowsPerThread = 1000;
-            for (int i = 0; i < threadCount; ++i) {
-                int current = i;
-                Thread thread = new Thread(() -> appendRow(localBulkWriter, current * rowsPerThread, (current + 1) * rowsPerThread));
-                threads.add(thread);
-                thread.start();
-                System.out.printf("Thread %s started%n", thread.getName());
-            }
-
-            for (Thread thread : threads) {
-                thread.join();
-                System.out.printf("Thread %s finished%n", thread.getName());
-            }
-
-            System.out.println(localBulkWriter.getTotalRowCount() + " rows appends");
-            localBulkWriter.commit(false);
-            System.out.printf("Append finished, %s rows%n", threadCount * rowsPerThread);
-
-            long rowCount = 0;
-            List<List<String>> batchFiles = localBulkWriter.getBatchFiles();
-            for (List<String> batch : batchFiles) {
-                for (String filePath : batch) {
-                    rowCount += readParquet(filePath);
-                }
-            }
-
-            Asserts.check(rowCount == threadCount * rowsPerThread, String.format("rowCount %s not equals expected %s", rowCount, threadCount * rowsPerThread));
-            System.out.println("Data is correct");
-        } catch (Exception e) {
-            System.out.println("parallelAppend catch exception: " + e);
-            throw e;
-        }
-    }
-
-    private static long readParquet(String localFilePath) throws Exception {
-        final long[] rowCount = {0};
-        new ParquetReaderUtils() {
-            @Override
-            public void readRecord(GenericData.Record record) {
-                rowCount[0]++;
-                String pathValue = record.get("path").toString();
-                String labelValue = record.get("label").toString();
-                Asserts.check(pathValue.replace("path_", "").equals(labelValue.replace("label_", "")), String.format("the suffix of %s not equals the suffix of %s", pathValue, labelValue));
-            }
-        }.readParquet(localFilePath);
-        System.out.printf("The file %s contains %s rows. Verify the content...%n", localFilePath, rowCount[0]);
-        return rowCount[0];
-    }
-
-    private static void appendRow(LocalBulkWriter writer, int begin, int end) {
-        try {
-            for (int i = begin; i < end; ++i) {
-                JsonObject row = new JsonObject();
-                row.addProperty("path", "path_" + i);
-                row.add("vector", GSON_INSTANCE.toJsonTree(GeneratorUtils.genFloatVector(DIM)));
-                row.addProperty("label", "label_" + i);
-
-                writer.appendRow(row);
-                if (i % 100 == 0) {
-                    System.out.printf("%s inserted %s items%n", Thread.currentThread().getName(), i - begin);
-                }
-            }
-        } catch (Exception e) {
-            System.out.println("failed to append row!");
-        }
-    }
-
     private static List<Map<String, Object>> genOriginalData(int count) {
         List<Map<String, Object>> data = new ArrayList<>();
         for (int i = 0; i < count; ++i) {
@@ -530,7 +413,7 @@ private static StorageConnectParam buildStorageConnectParam() {
     }
 
     private static void readCsvSampleData(String filePath, BulkWriter writer) throws IOException, InterruptedException {
-        ClassLoader classLoader = BulkWriterExample.class.getClassLoader();
+        ClassLoader classLoader = BulkWriterRemoteExample.class.getClassLoader();
         URL resourceUrl = classLoader.getResource(filePath);
         filePath = new File(resourceUrl.getFile()).getAbsolutePath();
 
@@ -551,29 +434,6 @@ private static void readCsvSampleData(String filePath, BulkWriter writer) throws
         }
     }
 
-    private static class CsvDataObject {
-        @JsonProperty
-        private String vector;
-        @JsonProperty
-        private String path;
-        @JsonProperty
-        private String label;
-
-        public String getVector() {
-            return vector;
-        }
-        public String getPath() {
-            return path;
-        }
-        public String getLabel() {
-            return label;
-        }
-        public List<Float> toFloatArray() {
-            return GSON_INSTANCE.fromJson(vector, new TypeToken<List<Float>>() {
-            }.getType());
-        }
-    }
-
     private static void callBulkInsert(List<List<String>> batchFiles) throws InterruptedException {
         String url = String.format("http://%s:%s", HOST, PORT);
         System.out.println("\n===================== import files to milvus ====================");
@@ -584,7 +444,7 @@ private static void callBulkInsert(List<List<String>> batchFiles) throws Interru
                 .files(batchFiles)
                 .options(options)
                 .build();
-        String bulkImportResult = BulkImport.bulkImport(url, milvusImportRequest);
+        String bulkImportResult = BulkImportUtils.bulkImport(url, milvusImportRequest);
         System.out.println(bulkImportResult);
 
         JsonObject bulkImportObject = convertJsonObject(bulkImportResult);
@@ -593,7 +453,7 @@ private static void callBulkInsert(List<List<String>> batchFiles) throws Interru
 
         System.out.println("\n===================== listBulkInsertJobs() ====================");
         MilvusListImportJobsRequest listImportJobsRequest = MilvusListImportJobsRequest.builder().collectionName(ALL_TYPES_COLLECTION_NAME).build();
-        String listImportJobsResult = BulkImport.listImportJobs(url, listImportJobsRequest);
+        String listImportJobsResult = BulkImportUtils.listImportJobs(url, listImportJobsRequest);
         System.out.println(listImportJobsResult);
         while (true) {
             System.out.println("Wait 5 second to check bulkInsert job state...");
@@ -603,7 +463,7 @@ private static void callBulkInsert(List<List<String>> batchFiles) throws Interru
             MilvusDescribeImportRequest request = MilvusDescribeImportRequest.builder()
                     .jobId(jobId)
                     .build();
-            String getImportProgressResult = BulkImport.getImportProgress(url, request);
+            String getImportProgressResult = BulkImportUtils.getImportProgress(url, request);
             System.out.println(getImportProgressResult);
 
             JsonObject getImportProgressObject = convertJsonObject(getImportProgressResult);
@@ -622,53 +482,6 @@ private static void callBulkInsert(List<List<String>> batchFiles) throws Interru
         }
     }
 
-    private static void callCloudImport(List<List<String>> batchFiles, String collectionName, String partitionName) throws InterruptedException {
-        String objectUrl = StorageConsts.cloudStorage == CloudStorage.AZURE
-                ? StorageConsts.cloudStorage.getAzureObjectUrl(StorageConsts.AZURE_ACCOUNT_NAME, StorageConsts.AZURE_CONTAINER_NAME, ImportUtils.getCommonPrefix(batchFiles))
-                : StorageConsts.cloudStorage.getS3ObjectUrl(StorageConsts.STORAGE_BUCKET, ImportUtils.getCommonPrefix(batchFiles), StorageConsts.STORAGE_REGION);
-        String accessKey = StorageConsts.cloudStorage == CloudStorage.AZURE ? StorageConsts.AZURE_ACCOUNT_NAME : StorageConsts.STORAGE_ACCESS_KEY;
-        String secretKey = StorageConsts.cloudStorage == CloudStorage.AZURE ? StorageConsts.AZURE_ACCOUNT_KEY : StorageConsts.STORAGE_SECRET_KEY;
-
-        System.out.println("\n===================== call cloudImport ====================");
-        CloudImportRequest bulkImportRequest = CloudImportRequest.builder()
-                .objectUrl(objectUrl).accessKey(accessKey).secretKey(secretKey)
-                .clusterId(CloudImportConsts.CLUSTER_ID).collectionName(collectionName).partitionName(partitionName)
-                .apiKey(CloudImportConsts.API_KEY)
-                .build();
-        String bulkImportResult = BulkImport.bulkImport(CloudImportConsts.CLOUD_ENDPOINT, bulkImportRequest);
-        JsonObject bulkImportObject = convertJsonObject(bulkImportResult);
-
-        String jobId = bulkImportObject.getAsJsonObject("data").get("jobId").getAsString();
-        System.out.println("Create a cloudImport job, job id: " + jobId);
-
-        System.out.println("\n===================== call cloudListImportJobs ====================");
-        CloudListImportJobsRequest listImportJobsRequest = CloudListImportJobsRequest.builder().clusterId(CloudImportConsts.CLUSTER_ID).currentPage(1).pageSize(10).apiKey(CloudImportConsts.API_KEY).build();
-        String listImportJobsResult = BulkImport.listImportJobs(CloudImportConsts.CLOUD_ENDPOINT, listImportJobsRequest);
-        System.out.println(listImportJobsResult);
-        while (true) {
-            System.out.println("Wait 5 second to check bulkInsert job state...");
-            TimeUnit.SECONDS.sleep(5);
-
-            System.out.println("\n===================== call cloudGetProgress ====================");
-            CloudDescribeImportRequest request = CloudDescribeImportRequest.builder().clusterId(CloudImportConsts.CLUSTER_ID).jobId(jobId).apiKey(CloudImportConsts.API_KEY).build();
-            String getImportProgressResult = BulkImport.getImportProgress(CloudImportConsts.CLOUD_ENDPOINT, request);
-            JsonObject getImportProgressObject = convertJsonObject(getImportProgressResult);
-            String importProgressState = getImportProgressObject.getAsJsonObject("data").get("state").getAsString();
-            String progress = getImportProgressObject.getAsJsonObject("data").get("progress").getAsString();
-
-            if ("Failed".equals(importProgressState)) {
-                String reason = getImportProgressObject.getAsJsonObject("data").get("reason").getAsString();
-                System.out.printf("The job %s failed, reason: %s%n", jobId, reason);
-                break;
-            } else if ("Completed".equals(importProgressState)) {
-                System.out.printf("The job %s completed%n", jobId);
-                break;
-            } else {
-                System.out.printf("The job %s is running, state:%s progress:%s%n", jobId, importProgressState, progress);
-            }
-        }
-    }
-
     /**
      * @param collectionSchema collection info
      * @param dropIfExist if collection already exist, will drop firstly and then create again
     */
@@ -880,7 +693,7 @@ private static void exampleCloudImport() {
                 .clusterId(CloudImportConsts.CLUSTER_ID).collectionName(CloudImportConsts.COLLECTION_NAME).partitionName(CloudImportConsts.PARTITION_NAME)
                 .apiKey(CloudImportConsts.API_KEY)
                 .build();
-        String bulkImportResult = BulkImport.bulkImport(CloudImportConsts.CLOUD_ENDPOINT, request);
+        String bulkImportResult = BulkImportUtils.bulkImport(CloudImportConsts.CLOUD_ENDPOINT, request);
         System.out.println(bulkImportResult);
 
         System.out.println("\n===================== get import job progress ====================");
@@ -888,12 +701,12 @@ private static void exampleCloudImport() {
         JsonObject bulkImportObject = convertJsonObject(bulkImportResult);
         String jobId = bulkImportObject.getAsJsonObject("data").get("jobId").getAsString();
         CloudDescribeImportRequest getImportProgressRequest = CloudDescribeImportRequest.builder().clusterId(CloudImportConsts.CLUSTER_ID).jobId(jobId).apiKey(CloudImportConsts.API_KEY).build();
-        String getImportProgressResult = BulkImport.getImportProgress(CloudImportConsts.CLOUD_ENDPOINT, getImportProgressRequest);
+        String getImportProgressResult = BulkImportUtils.getImportProgress(CloudImportConsts.CLOUD_ENDPOINT, getImportProgressRequest);
         System.out.println(getImportProgressResult);
 
         System.out.println("\n===================== list import jobs ====================");
         CloudListImportJobsRequest listImportJobsRequest = CloudListImportJobsRequest.builder().clusterId(CloudImportConsts.CLUSTER_ID).currentPage(1).pageSize(10).apiKey(CloudImportConsts.API_KEY).build();
-        String listImportJobsResult = BulkImport.listImportJobs(CloudImportConsts.CLOUD_ENDPOINT, listImportJobsRequest);
+        String listImportJobsResult = BulkImportUtils.listImportJobs(CloudImportConsts.CLOUD_ENDPOINT, listImportJobsRequest);
         System.out.println(listImportJobsResult);
     }
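Every import path in this file repeats the same describe-and-poll loop. For reference, a sketch of that loop factored into one helper; the method name is hypothetical, while convertJsonObject, BulkImportUtils, and MilvusDescribeImportRequest come from the code above:

    // Sketch: the describe-and-poll pattern used throughout these examples.
    static void waitForImport(String url, String jobId) throws InterruptedException {
        while (true) {
            TimeUnit.SECONDS.sleep(5);
            MilvusDescribeImportRequest req = MilvusDescribeImportRequest.builder().jobId(jobId).build();
            JsonObject data = convertJsonObject(BulkImportUtils.getImportProgress(url, req))
                    .getAsJsonObject("data");
            String state = data.get("state").getAsString();
            if ("Failed".equals(state) || "Completed".equals(state)) {
                System.out.printf("Job %s finished with state %s%n", jobId, state);
                return;
            }
        }
    }
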
diff --git a/examples/src/main/java/io/milvus/v2/bulkwriter/BulkWriterStageExample.java b/examples/src/main/java/io/milvus/v2/bulkwriter/BulkWriterStageExample.java
new file mode 100644
index 000000000..546fc249c
--- /dev/null
+++ b/examples/src/main/java/io/milvus/v2/bulkwriter/BulkWriterStageExample.java
@@ -0,0 +1,669 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package io.milvus.v2.bulkwriter;
+
+import com.google.common.collect.Lists;
+import com.google.gson.Gson;
+import com.google.gson.JsonElement;
+import com.google.gson.JsonNull;
+import com.google.gson.JsonObject;
+import io.milvus.bulkwriter.StageBulkWriter;
+import io.milvus.bulkwriter.StageBulkWriterParam;
+import io.milvus.bulkwriter.common.clientenum.BulkFileType;
+import io.milvus.bulkwriter.common.utils.GeneratorUtils;
+import io.milvus.bulkwriter.model.StageUploadResult;
+import io.milvus.bulkwriter.request.describe.CloudDescribeImportRequest;
+import io.milvus.bulkwriter.request.import_.StageImportRequest;
+import io.milvus.bulkwriter.request.list.CloudListImportJobsRequest;
+import io.milvus.bulkwriter.restful.BulkImportUtils;
+import io.milvus.v1.CommonUtils;
+import io.milvus.v2.client.ConnectConfig;
+import io.milvus.v2.client.MilvusClientV2;
+import io.milvus.v2.common.ConsistencyLevel;
+import io.milvus.v2.common.DataType;
+import io.milvus.v2.common.IndexParam;
+import io.milvus.v2.service.collection.request.AddFieldReq;
+import io.milvus.v2.service.collection.request.CreateCollectionReq;
+import io.milvus.v2.service.collection.request.DropCollectionReq;
+import io.milvus.v2.service.collection.request.HasCollectionReq;
+import io.milvus.v2.service.collection.request.LoadCollectionReq;
+import io.milvus.v2.service.collection.request.RefreshLoadReq;
+import io.milvus.v2.service.index.request.CreateIndexReq;
+import io.milvus.v2.service.vector.request.QueryReq;
+import io.milvus.v2.service.vector.response.QueryResp;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.TimeUnit;
+
+
+public class BulkWriterStageExample {
+    private static final Gson GSON_INSTANCE = new Gson();
+
+    // milvus
+    public static final String HOST = "127.0.0.1";
+    public static final Integer PORT = 19530;
+    public static final String USER_NAME = "user.name";
+    public static final String PASSWORD = "password";
+
+    /**
+     * The value of the URL is fixed.
+     * For overseas regions, it is: https://api.cloud.zilliz.com
+     * For regions in China, it is: https://api.cloud.zilliz.com.cn
+     */
+    public static final String CLOUD_ENDPOINT = "https://api.cloud.zilliz.com";
+    public static final String API_KEY = "_api_key_for_cluster_org_";
+
+
+    /**
+     * This is currently a private preview feature. If you need to use it, please submit a request and contact us.
+     */
+    public static final String STAGE_NAME = "_stage_name_for_project_";
+
+    public static final String CLUSTER_ID = "_your_cloud_cluster_id_";
+    // If db_name is not specified, use ""
+    public static final String DB_NAME = "";
+    public static final String COLLECTION_NAME = "_collection_name_on_the_db_";
+    // If partition_name is not specified, use ""
+    public static final String PARTITION_NAME = "_partition_name_on_the_collection_";
+
+    private static final Integer DIM = 512;
+    private static final Integer ARRAY_CAPACITY = 10;
+    private static MilvusClientV2 milvusClient;
+
+    public static void main(String[] args) throws Exception {
+        createConnection();
+        exampleCollectionRemoteStage(BulkFileType.PARQUET);
+    }
+
+    private static void createConnection() {
+        System.out.println("\nCreate connection...");
+        String url = String.format("http://%s:%s", HOST, PORT);
+        milvusClient = new MilvusClientV2(ConnectConfig.builder()
+                .uri(url)
+                .username(USER_NAME)
+                .password(PASSWORD)
+                .build());
+        System.out.println("\nConnected");
+    }
+
+    private static void exampleCollectionRemoteStage(BulkFileType fileType) throws Exception {
+        List<Map<String, Object>> originalData = genOriginalData(5);
+        List<JsonObject> rows = genImportData(originalData, true);
+
+        // 4 vector types + all scalar types + dynamic field enabled, using the cloud import API.
+        // You need to apply for a cloud service from Zilliz Cloud (https://zilliz.com/cloud)
+        CreateCollectionReq.CollectionSchema collectionSchema = buildAllTypesSchema();
+        createCollection(COLLECTION_NAME, collectionSchema, false);
+
+        StageUploadResult stageUploadResult = stageRemoteWriter(collectionSchema, fileType, rows);
+        callStageImport(stageUploadResult.getStageName(), stageUploadResult.getPath());
+        verifyImportData(collectionSchema, originalData);
+    }
+
+    private static void callStageImport(String stageName, String path) throws InterruptedException {
+        List<String> importDataPath = Lists.newArrayList(path);
+        StageImportRequest stageImportRequest = StageImportRequest.builder()
+                .apiKey(API_KEY)
+                .stageName(stageName).dataPaths(Lists.newArrayList(Collections.singleton(importDataPath)))
+                .clusterId(CLUSTER_ID).dbName(DB_NAME).collectionName(COLLECTION_NAME).partitionName(PARTITION_NAME)
+                .build();
+        String bulkImportResult = BulkImportUtils.bulkImport(CLOUD_ENDPOINT, stageImportRequest);
+        System.out.println(bulkImportResult);
+
+        JsonObject bulkImportObject = convertJsonObject(bulkImportResult);
+
+        String jobId = bulkImportObject.getAsJsonObject("data").get("jobId").getAsString();
+        System.out.println("Create a cloudImport job, job id: " + jobId);
+
+        System.out.println("\n===================== call cloudListImportJobs ====================");
+        CloudListImportJobsRequest listImportJobsRequest = CloudListImportJobsRequest.builder().clusterId(CLUSTER_ID).currentPage(1).pageSize(10).apiKey(API_KEY).build();
+        String listImportJobsResult = BulkImportUtils.listImportJobs(CLOUD_ENDPOINT, listImportJobsRequest);
+        System.out.println(listImportJobsResult);
+        while (true) {
+            System.out.println("Wait 5 seconds to check bulkInsert job state...");
+            TimeUnit.SECONDS.sleep(5);
+
+            System.out.println("\n===================== call cloudGetProgress ====================");
+            CloudDescribeImportRequest request = CloudDescribeImportRequest.builder().clusterId(CLUSTER_ID).jobId(jobId).apiKey(API_KEY).build();
+            String getImportProgressResult = BulkImportUtils.getImportProgress(CLOUD_ENDPOINT, request);
+            JsonObject getImportProgressObject = convertJsonObject(getImportProgressResult);
+            String importProgressState = getImportProgressObject.getAsJsonObject("data").get("state").getAsString();
+            String progress = getImportProgressObject.getAsJsonObject("data").get("progress").getAsString();
+
+            if ("Failed".equals(importProgressState)) {
+                String reason = getImportProgressObject.getAsJsonObject("data").get("reason").getAsString();
+                System.out.printf("The job %s failed, reason: %s%n", jobId, reason);
+                break;
+            } else if ("Completed".equals(importProgressState)) {
+                System.out.printf("The job %s completed%n", jobId);
+                break;
+            } else {
+                System.out.printf("The job %s is running, state:%s progress:%s%n", jobId, importProgressState, progress);
+            }
+        }
+    }
+
+    private static List<Map<String, Object>> genOriginalData(int count) {
+        List<Map<String, Object>> data = new ArrayList<>();
+        for (int i = 0; i < count; ++i) {
+            Map<String, Object> row = new HashMap<>();
+            // scalar field
+            row.put("id", (long)i);
+            row.put("bool", i % 5 == 0);
+            row.put("int8", i % 128);
+            row.put("int16", i % 1000);
+            row.put("int32", i % 100000);
+            row.put("float", (float)i / 3);
+            row.put("double", (double)i / 7);
+            row.put("varchar", "varchar_" + i);
+            row.put("json", String.format("{\"dummy\": %s, \"ok\": \"name_%s\"}", i, i));
+
+            // vector field
+            row.put("float_vector", CommonUtils.generateFloatVector(DIM));
+            row.put("binary_vector", CommonUtils.generateBinaryVector(DIM).array());
+            row.put("float16_vector", CommonUtils.generateFloat16Vector(DIM, false).array());
+            row.put("sparse_vector", CommonUtils.generateSparseVector());
+
+            // array field
+            row.put("array_bool", GeneratorUtils.generatorBoolValue(3));
+            row.put("array_int8", GeneratorUtils.generatorInt8Value(4));
+            row.put("array_int16", GeneratorUtils.generatorInt16Value(5));
+            row.put("array_int32", GeneratorUtils.generatorInt32Value(6));
+            row.put("array_int64", GeneratorUtils.generatorLongValue(7));
+            row.put("array_varchar", GeneratorUtils.generatorVarcharValue(8, 10));
+            row.put("array_float", GeneratorUtils.generatorFloatValue(9));
+            row.put("array_double", GeneratorUtils.generatorDoubleValue(10));
+
+            data.add(row);
+        }
+        // a special record with null/default values
+        {
+            Map<String, Object> row = new HashMap<>();
+            // scalar field
+            row.put("id", (long)data.size());
+            row.put("bool", null);
+            row.put("int8", null);
+            row.put("int16", 16);
+            row.put("int32", null);
+            row.put("float", null);
+            row.put("double", null);
+            row.put("varchar", null);
+            row.put("json", null);
+
+            // vector field
+            row.put("float_vector", CommonUtils.generateFloatVector(DIM));
+            row.put("binary_vector", CommonUtils.generateBinaryVector(DIM).array());
+            row.put("float16_vector", CommonUtils.generateFloat16Vector(DIM, false).array());
+            row.put("sparse_vector", CommonUtils.generateSparseVector());
+
+            // array field
+            row.put("array_bool", GeneratorUtils.generatorBoolValue(10));
+            row.put("array_int8", GeneratorUtils.generatorInt8Value(9));
+            row.put("array_int16", null);
+            row.put("array_int32", GeneratorUtils.generatorInt32Value(7));
+            row.put("array_int64", GeneratorUtils.generatorLongValue(6));
+            row.put("array_varchar", GeneratorUtils.generatorVarcharValue(5, 10));
+            row.put("array_float", GeneratorUtils.generatorFloatValue(4));
+            row.put("array_double", null);
+
+            data.add(row);
+        }
+        return data;
+    }
rowObject.addProperty("bool", (Boolean) row.get("bool")); + } + rowObject.addProperty("int8", row.get("int8") == null ? null : (Number) row.get("int8")); + rowObject.addProperty("int16", row.get("int16") == null ? null : (Number) row.get("int16")); + rowObject.addProperty("int32", row.get("int32") == null ? null : (Number) row.get("int32")); + rowObject.addProperty("float", row.get("float") == null ? null : (Number) row.get("float")); + if (row.get("double") != null) { // nullable value can be missed + rowObject.addProperty("double", (Number) row.get("double")); + } + rowObject.addProperty("varchar", row.get("varchar") == null ? null : (String) row.get("varchar")); + + // Note: for JSON field, use gson.fromJson() to construct a real JsonObject + // don't use rowObject.addProperty("json", jsonContent) since the value is treated as a string, not a JsonObject + Object jsonContent = row.get("json"); + rowObject.add("json", jsonContent == null ? null : GSON_INSTANCE.fromJson((String)jsonContent, JsonElement.class)); + + // vector field + rowObject.add("float_vector", GSON_INSTANCE.toJsonTree(row.get("float_vector"))); + rowObject.add("binary_vector", GSON_INSTANCE.toJsonTree(row.get("binary_vector"))); + rowObject.add("float16_vector", GSON_INSTANCE.toJsonTree(row.get("float16_vector"))); + rowObject.add("sparse_vector", GSON_INSTANCE.toJsonTree(row.get("sparse_vector"))); + + // array field + rowObject.add("array_bool", GSON_INSTANCE.toJsonTree(row.get("array_bool"))); + rowObject.add("array_int8", GSON_INSTANCE.toJsonTree(row.get("array_int8"))); + rowObject.add("array_int16", GSON_INSTANCE.toJsonTree(row.get("array_int16"))); + rowObject.add("array_int32", GSON_INSTANCE.toJsonTree(row.get("array_int32"))); + rowObject.add("array_int64", GSON_INSTANCE.toJsonTree(row.get("array_int64"))); + rowObject.add("array_varchar", GSON_INSTANCE.toJsonTree(row.get("array_varchar"))); + rowObject.add("array_float", GSON_INSTANCE.toJsonTree(row.get("array_float"))); + rowObject.add("array_double", GSON_INSTANCE.toJsonTree(row.get("array_double"))); + + // dynamic fields + if (isEnableDynamicField) { + rowObject.addProperty("dynamic", "dynamic_" + row.get("id")); + } + + data.add(rowObject); + } + return data; + } + + private static StageUploadResult stageRemoteWriter(CreateCollectionReq.CollectionSchema collectionSchema, + BulkFileType fileType, + List data) throws Exception { + System.out.printf("\n===================== all field types (%s) ====================%n", fileType.name()); + + try (StageBulkWriter stageBulkWriter = buildStageBulkWriter(collectionSchema, fileType)) { + for (JsonObject rowObject : data) { + stageBulkWriter.appendRow(rowObject); + } + System.out.printf("%s rows appends%n", stageBulkWriter.getTotalRowCount()); + System.out.println("Generate data files..."); + stageBulkWriter.commit(false); + + StageUploadResult stageUploadResult = stageBulkWriter.getStageUploadResult(); + System.out.printf("Data files have been uploaded: %s%n", stageUploadResult); + return stageUploadResult; + } catch (Exception e) { + System.out.println("allTypesRemoteWriter catch exception: " + e); + throw e; + } + } + + private static StageBulkWriter buildStageBulkWriter(CreateCollectionReq.CollectionSchema collectionSchema, BulkFileType fileType) throws IOException { + StageBulkWriterParam bulkWriterParam = StageBulkWriterParam.newBuilder() + .withCollectionSchema(collectionSchema) + .withRemotePath("bulk_data") + .withFileType(fileType) + .withChunkSize(512 * 1024 * 1024) + .withConfig("sep", "|") // only take 
+
+    private static StageUploadResult stageRemoteWriter(CreateCollectionReq.CollectionSchema collectionSchema,
+                                                       BulkFileType fileType,
+                                                       List<JsonObject> data) throws Exception {
+        System.out.printf("\n===================== all field types (%s) ====================%n", fileType.name());
+
+        try (StageBulkWriter stageBulkWriter = buildStageBulkWriter(collectionSchema, fileType)) {
+            for (JsonObject rowObject : data) {
+                stageBulkWriter.appendRow(rowObject);
+            }
+            System.out.printf("%s rows appended%n", stageBulkWriter.getTotalRowCount());
+            System.out.println("Generate data files...");
+            stageBulkWriter.commit(false);
+
+            StageUploadResult stageUploadResult = stageBulkWriter.getStageUploadResult();
+            System.out.printf("Data files have been uploaded: %s%n", stageUploadResult);
+            return stageUploadResult;
+        } catch (Exception e) {
+            System.out.println("stageRemoteWriter catch exception: " + e);
+            throw e;
+        }
+    }
+
+    private static StageBulkWriter buildStageBulkWriter(CreateCollectionReq.CollectionSchema collectionSchema, BulkFileType fileType) throws IOException {
+        StageBulkWriterParam bulkWriterParam = StageBulkWriterParam.newBuilder()
+                .withCollectionSchema(collectionSchema)
+                .withRemotePath("bulk_data")
+                .withFileType(fileType)
+                .withChunkSize(512 * 1024 * 1024)
+                .withConfig("sep", "|") // only takes effect for CSV files
+                .withCloudEndpoint(CLOUD_ENDPOINT)
+                .withApiKey(API_KEY)
+                .withStageName(STAGE_NAME)
+                .build();
+        return new StageBulkWriter(bulkWriterParam);
+    }
+
+    /**
+     * @param collectionSchema collection info
+     * @param dropIfExist if the collection already exists, drop it first and then create it again
+     */
+    private static void createCollection(String collectionName, CreateCollectionReq.CollectionSchema collectionSchema, boolean dropIfExist) {
+        System.out.println("\n===================== create collection ====================");
+        checkMilvusClientIfExist();
+
+        CreateCollectionReq requestCreate = CreateCollectionReq.builder()
+                .collectionName(collectionName)
+                .collectionSchema(collectionSchema)
+                .consistencyLevel(ConsistencyLevel.BOUNDED)
+                .build();
+
+        Boolean has = milvusClient.hasCollection(HasCollectionReq.builder().collectionName(collectionName).build());
+        if (has) {
+            if (dropIfExist) {
+                milvusClient.dropCollection(DropCollectionReq.builder().collectionName(collectionName).build());
+                milvusClient.createCollection(requestCreate);
+            }
+        } else {
+            milvusClient.createCollection(requestCreate);
+        }
+
+        System.out.printf("Collection %s created%n", collectionName);
+    }
+
+    private static void comparePrint(CreateCollectionReq.CollectionSchema collectionSchema,
+                                     Map<String, Object> expectedData, Map<String, Object> fetchedData,
+                                     String fieldName) {
+        CreateCollectionReq.FieldSchema field = collectionSchema.getField(fieldName);
+        Object expectedValue = expectedData.get(fieldName);
+        if (expectedValue == null) {
+            if (field.getDefaultValue() != null) {
+                expectedValue = field.getDefaultValue();
+                // for an Int8/Int16 value, the default value is Short type, the returned value is Integer type
+                if (expectedValue instanceof Short) {
+                    expectedValue = ((Short)expectedValue).intValue();
+                }
+            }
+        }
+
+        Object fetchedValue = fetchedData.get(fieldName);
+        if (fetchedValue == null || fetchedValue instanceof JsonNull) {
+            if (!field.getIsNullable()) {
+                throw new RuntimeException("Field is not nullable but fetched data is null");
+            }
+            if (expectedValue != null) {
+                throw new RuntimeException("Expected value is not null but fetched data is null");
+            }
+            return; // both fetchedValue and expectedValue are null
+        }
+
+        boolean matched;
+        if (fetchedValue instanceof Float) {
+            matched = Math.abs((Float)fetchedValue - (Float)expectedValue) < 1e-4;
+        } else if (fetchedValue instanceof Double) {
+            matched = Math.abs((Double)fetchedValue - (Double)expectedValue) < 1e-8;
+        } else if (fetchedValue instanceof JsonElement) {
+            JsonElement expectedJson = GSON_INSTANCE.fromJson((String)expectedValue, JsonElement.class);
+            matched = fetchedValue.equals(expectedJson);
+        } else if (fetchedValue instanceof ByteBuffer) {
+            byte[] bb = ((ByteBuffer)fetchedValue).array();
+            matched = Arrays.equals(bb, (byte[])expectedValue);
+        } else if (fetchedValue instanceof List) {
+            matched = fetchedValue.equals(expectedValue);
+            // currently, for an array field with a null value, the server returns an empty list
+            if (((List<?>) fetchedValue).isEmpty() && expectedValue == null) {
+                matched = true;
+            }
+        } else {
+            matched = fetchedValue.equals(expectedValue);
+        }
+
+        if (!matched) {
+            System.out.print("Fetched value:");
+            System.out.println(fetchedValue);
+            System.out.print("Expected value:");
+            System.out.println(expectedValue);
+            throw new RuntimeException("Fetched data is unmatched");
+        }
+    }
Lists.newArrayList(1L, (long)rows.get(rows.size()-1).get("id")); + System.out.printf("Load collection and query items %s%n", QUERY_IDS); + loadCollection(); + + String expr = String.format("id in %s", QUERY_IDS); + System.out.println(expr); + + List<QueryResp.QueryResult> results = query(expr, Lists.newArrayList("*")); + System.out.println("Verify data..."); + if (results.size() != QUERY_IDS.size()) { + throw new RuntimeException("Result count is incorrect"); + } + for (QueryResp.QueryResult result : results) { + Map<String, Object> fetchedEntity = result.getEntity(); + long id = (Long)fetchedEntity.get("id"); + Map<String, Object> originalEntity = rows.get((int)id); + comparePrint(collectionSchema, originalEntity, fetchedEntity, "bool"); + comparePrint(collectionSchema, originalEntity, fetchedEntity, "int8"); + comparePrint(collectionSchema, originalEntity, fetchedEntity, "int16"); + comparePrint(collectionSchema, originalEntity, fetchedEntity, "int32"); + comparePrint(collectionSchema, originalEntity, fetchedEntity, "float"); + comparePrint(collectionSchema, originalEntity, fetchedEntity, "double"); + comparePrint(collectionSchema, originalEntity, fetchedEntity, "varchar"); + comparePrint(collectionSchema, originalEntity, fetchedEntity, "json"); + + comparePrint(collectionSchema, originalEntity, fetchedEntity, "array_bool"); + comparePrint(collectionSchema, originalEntity, fetchedEntity, "array_int8"); + comparePrint(collectionSchema, originalEntity, fetchedEntity, "array_int16"); + comparePrint(collectionSchema, originalEntity, fetchedEntity, "array_int32"); + comparePrint(collectionSchema, originalEntity, fetchedEntity, "array_int64"); + comparePrint(collectionSchema, originalEntity, fetchedEntity, "array_varchar"); + comparePrint(collectionSchema, originalEntity, fetchedEntity, "array_float"); + comparePrint(collectionSchema, originalEntity, fetchedEntity, "array_double"); + + comparePrint(collectionSchema, originalEntity, fetchedEntity, "float_vector"); + comparePrint(collectionSchema, originalEntity, fetchedEntity, "binary_vector"); + comparePrint(collectionSchema, originalEntity, fetchedEntity, "float16_vector"); + comparePrint(collectionSchema, originalEntity, fetchedEntity, "sparse_vector"); + + System.out.println(fetchedEntity); + } + System.out.println("Result is correct!"); + } + + private static void createIndex() { + System.out.println("Create index..."); + checkMilvusClientIfExist(); + + List<IndexParam> indexes = new ArrayList<>(); + indexes.add(IndexParam.builder() + .fieldName("float_vector") + .indexType(IndexParam.IndexType.FLAT) + .metricType(IndexParam.MetricType.L2) + .build()); + indexes.add(IndexParam.builder() + ..fieldName("binary_vector") + .indexType(IndexParam.IndexType.BIN_FLAT) + .metricType(IndexParam.MetricType.HAMMING) + .build()); + indexes.add(IndexParam.builder() + .fieldName("float16_vector") + .indexType(IndexParam.IndexType.FLAT) + .metricType(IndexParam.MetricType.IP) + .build()); + indexes.add(IndexParam.builder() + .fieldName("sparse_vector") + .indexType(IndexParam.IndexType.SPARSE_WAND) + .metricType(IndexParam.MetricType.IP) + .build()); + + milvusClient.createIndex(CreateIndexReq.builder() + .collectionName(COLLECTION_NAME) + .indexParams(indexes) + .build()); + + milvusClient.loadCollection(LoadCollectionReq.builder() + .collectionName(COLLECTION_NAME) + .build()); + } + + private static void loadCollection() { + System.out.println("Refresh load collection..."); + checkMilvusClientIfExist(); + // RefreshLoad is a new interface introduced in v2.5.3, + // mainly used when new segments have been generated by a bulkinsert request, + // to force the new segments to be loaded into memory. + milvusClient.refreshLoad(RefreshLoadReq.builder() + .collectionName(COLLECTION_NAME) + .build()); + System.out.println("Collection row number: " + getCollectionRowCount()); + } + + private static List<QueryResp.QueryResult> query(String expr, List<String> outputFields) { + System.out.println("========== query() =========="); + checkMilvusClientIfExist(); + QueryReq queryReq = QueryReq.builder() + .collectionName(COLLECTION_NAME) + .filter(expr) + .outputFields(outputFields) + .build(); + QueryResp response = milvusClient.query(queryReq); + return response.getQueryResults(); + } + + private static Long getCollectionRowCount() { + System.out.println("========== getCollectionRowCount() =========="); + checkMilvusClientIfExist(); + + // Get row count, set ConsistencyLevel.STRONG to sync the data to the query node so that the data is visible + QueryResp countR = milvusClient.query(QueryReq.builder() + .collectionName(COLLECTION_NAME) + .filter("") + .outputFields(Collections.singletonList("count(*)")) + .consistencyLevel(ConsistencyLevel.STRONG) + .build()); + return (long)countR.getQueryResults().get(0).getEntity().get("count(*)"); + } + + private static CreateCollectionReq.CollectionSchema buildAllTypesSchema() { + CreateCollectionReq.CollectionSchema schemaV2 = CreateCollectionReq.CollectionSchema.builder() + .enableDynamicField(true) + .build(); + // scalar fields + schemaV2.addField(AddFieldReq.builder() + .fieldName("id") + .dataType(DataType.Int64) + .isPrimaryKey(Boolean.TRUE) + .autoID(false) + .build()); + schemaV2.addField(AddFieldReq.builder() + .fieldName("bool") + .dataType(DataType.Bool) + .isNullable(true) + .build()); + schemaV2.addField(AddFieldReq.builder() + .fieldName("int8") + .dataType(DataType.Int8) + .defaultValue((short)88) + .build()); + schemaV2.addField(AddFieldReq.builder() + .fieldName("int16") + .dataType(DataType.Int16) + .build()); + schemaV2.addField(AddFieldReq.builder() + .fieldName("int32") + .dataType(DataType.Int32) + .isNullable(true) + .defaultValue(999999) + .build()); + schemaV2.addField(AddFieldReq.builder() + .fieldName("float") + .dataType(DataType.Float) + .isNullable(true) + .defaultValue((float)3.14159) + .build()); + schemaV2.addField(AddFieldReq.builder() + .fieldName("double") + .dataType(DataType.Double) + .isNullable(true) + .build()); + schemaV2.addField(AddFieldReq.builder() + .fieldName("varchar") + .dataType(DataType.VarChar) + .maxLength(512) + .isNullable(true) + .defaultValue("this is default value") + .build()); + schemaV2.addField(AddFieldReq.builder() + .fieldName("json") + .dataType(DataType.JSON) + .isNullable(true) + .build()); + + // vector fields + schemaV2.addField(AddFieldReq.builder() + .fieldName("float_vector") + .dataType(DataType.FloatVector) + .dimension(DIM) + .build()); + schemaV2.addField(AddFieldReq.builder() + .fieldName("binary_vector") + .dataType(DataType.BinaryVector) + .dimension(DIM) + .build()); + schemaV2.addField(AddFieldReq.builder() + .fieldName("float16_vector") + .dataType(DataType.Float16Vector) + .dimension(DIM) + .build()); + schemaV2.addField(AddFieldReq.builder() + .fieldName("sparse_vector") + .dataType(DataType.SparseFloatVector) + .build()); + + // array fields + schemaV2.addField(AddFieldReq.builder() + .fieldName("array_bool") + .dataType(DataType.Array) + .maxCapacity(ARRAY_CAPACITY) + .elementType(DataType.Bool) + .build()); + schemaV2.addField(AddFieldReq.builder() + .fieldName("array_int8") + .dataType(DataType.Array) + .maxCapacity(ARRAY_CAPACITY) + 
.elementType(DataType.Int8) + .build()); + schemaV2.addField(AddFieldReq.builder() + .fieldName("array_int16") + .dataType(DataType.Array) + .maxCapacity(ARRAY_CAPACITY) + .elementType(DataType.Int16) + .isNullable(true) + .build()); + schemaV2.addField(AddFieldReq.builder() + .fieldName("array_int32") + .dataType(DataType.Array) + .maxCapacity(ARRAY_CAPACITY) + .elementType(DataType.Int32) + .build()); + schemaV2.addField(AddFieldReq.builder() + .fieldName("array_int64") + .dataType(DataType.Array) + .maxCapacity(ARRAY_CAPACITY) + .elementType(DataType.Int64) + .build()); + schemaV2.addField(AddFieldReq.builder() + .fieldName("array_varchar") + .dataType(DataType.Array) + .maxCapacity(ARRAY_CAPACITY) + .elementType(DataType.VarChar) + .maxLength(512) + .build()); + schemaV2.addField(AddFieldReq.builder() + .fieldName("array_float") + .dataType(DataType.Array) + .maxCapacity(ARRAY_CAPACITY) + .elementType(DataType.Float) + .build()); + schemaV2.addField(AddFieldReq.builder() + .fieldName("array_double") + .dataType(DataType.Array) + .maxCapacity(ARRAY_CAPACITY) + .elementType(DataType.Double) + .isNullable(true) + .build()); + + return schemaV2; + } + + private static void checkMilvusClientIfExist() { + if (milvusClient == null) { + String msg = "milvusClient is null. Please initialize it by calling createConnection() first before use."; + throw new RuntimeException(msg); + } + } + + private static JsonObject convertJsonObject(String result) { + return GSON_INSTANCE.fromJson(result, JsonObject.class); + } +} diff --git a/examples/src/main/java/io/milvus/v2/bulkwriter/CsvDataObject.java b/examples/src/main/java/io/milvus/v2/bulkwriter/CsvDataObject.java new file mode 100644 index 000000000..9a2dd5fda --- /dev/null +++ b/examples/src/main/java/io/milvus/v2/bulkwriter/CsvDataObject.java @@ -0,0 +1,32 @@ +package io.milvus.v2.bulkwriter; + +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.gson.Gson; +import com.google.gson.reflect.TypeToken; + +import java.util.List; + +public class CsvDataObject { + private static final Gson GSON_INSTANCE = new Gson(); + + @JsonProperty + private String vector; + @JsonProperty + private String path; + @JsonProperty + private String label; + + public String getVector() { + return vector; + } + public String getPath() { + return path; + } + public String getLabel() { + return label; + } + public List<Float> toFloatArray() { + return GSON_INSTANCE.fromJson(vector, new TypeToken<List<Float>>() { + }.getType()); + } +} diff --git a/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/RemoteBulkWriter.java b/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/RemoteBulkWriter.java index 22b0739a8..98a418037 100644 --- a/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/RemoteBulkWriter.java +++ b/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/RemoteBulkWriter.java @@ -35,7 +35,6 @@ import org.slf4j.LoggerFactory; import java.io.File; -import java.io.FileInputStream; import java.io.IOException; import java.nio.file.FileVisitOption; import java.nio.file.Files; @@ -43,6 +42,7 @@ import java.nio.file.Paths; import java.util.ArrayList; import java.util.List; +import java.util.concurrent.ExecutionException; public class RemoteBulkWriter extends LocalBulkWriter { private static final Logger logger = LoggerFactory.getLogger(RemoteBulkWriter.class); @@ -222,6 +222,13 @@ private boolean objectExists(String objectName) throws Exception { } String msg = String.format("Failed to stat Azure object %s, error: %s", objectName, e.getServiceMessage()); 
ExceptionUtils.throwUnExpectedException(msg); + } catch (ExecutionException e) { + if (e.getCause().getCause() instanceof ErrorResponseException + && "NoSuchKey".equals(((ErrorResponseException) e.getCause().getCause()).errorResponse().code())) { + return false; + } + String msg = String.format("Failed to stat MinIO/S3 object %s, error: %s", objectName, e.getCause().getMessage()); + ExceptionUtils.throwUnExpectedException(msg); } return true; } @@ -243,13 +250,12 @@ private void uploadObject(String filePath, String objectName) throws Exception { logger.info(String.format("Prepare to upload %s to %s", filePath, objectName)); File file = new File(filePath); - FileInputStream fileInputStream = new FileInputStream(file); if (connectParam instanceof S3ConnectParam) { S3ConnectParam s3ConnectParam = (S3ConnectParam) connectParam; - storageClient.putObjectStream(fileInputStream, file.length(), s3ConnectParam.getBucketName(), objectName); + storageClient.putObject(file, s3ConnectParam.getBucketName(), objectName); } else if (connectParam instanceof AzureConnectParam) { AzureConnectParam azureConnectParam = (AzureConnectParam) connectParam; - storageClient.putObjectStream(fileInputStream, file.length(), azureConnectParam.getContainerName(), objectName); + storageClient.putObject(file, azureConnectParam.getContainerName(), objectName); } else { ExceptionUtils.throwUnExpectedException("Blob storage client is not initialized"); } diff --git a/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/StageBulkWriter.java b/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/StageBulkWriter.java new file mode 100644 index 000000000..ee8d01e09 --- /dev/null +++ b/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/StageBulkWriter.java @@ -0,0 +1,197 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + + package io.milvus.bulkwriter; + + import com.google.common.collect.Lists; + import com.google.gson.JsonObject; + import io.milvus.bulkwriter.model.StageUploadResult; + import io.milvus.common.utils.ExceptionUtils; + import org.slf4j.Logger; + import org.slf4j.LoggerFactory; + + import java.io.IOException; + import java.nio.file.FileVisitOption; + import java.nio.file.Files; + import java.nio.file.Path; + import java.nio.file.Paths; + import java.util.ArrayList; + import java.util.List; + + public class StageBulkWriter extends LocalBulkWriter { + private static final Logger logger = LoggerFactory.getLogger(StageBulkWriter.class); + + private String remotePath; + private List<List<String>> remoteFiles; + private StageOperation stageWriter; + private StageBulkWriterParam stageBulkWriterParam; + + public StageBulkWriter(StageBulkWriterParam bulkWriterParam) throws IOException { + super(bulkWriterParam.getCollectionSchema(), + bulkWriterParam.getChunkSize(), + bulkWriterParam.getFileType(), + generatorLocalPath(), + bulkWriterParam.getConfig()); + Path path = Paths.get(bulkWriterParam.getRemotePath()); + Path remoteDirPath = path.resolve(getUUID()); + this.remotePath = remoteDirPath + "/"; + this.stageWriter = initStageWriterParams(bulkWriterParam); + this.stageBulkWriterParam = bulkWriterParam; + + this.remoteFiles = Lists.newArrayList(); + logger.info("Stage bulk writer initialized, target path: {}", remotePath); + + } + + private StageOperation initStageWriterParams(StageBulkWriterParam bulkWriterParam) throws IOException { + StageOperationParam stageWriterParam = StageOperationParam.newBuilder() + .withCloudEndpoint(bulkWriterParam.getCloudEndpoint()).withApiKey(bulkWriterParam.getApiKey()) + .withStageName(bulkWriterParam.getStageName()).withPath(remotePath) + .build(); + return new StageOperation(stageWriterParam); + } + + @Override + public void appendRow(JsonObject rowData) throws IOException, InterruptedException { + super.appendRow(rowData); + } + + @Override + public void commit(boolean async) throws InterruptedException { + super.commit(async); + } + + @Override + protected String getDataPath() { + return remotePath; + } + + @Override + public List<List<String>> getBatchFiles() { + return remoteFiles; + } + + public StageUploadResult getStageUploadResult() { + return StageUploadResult.builder() + .stageName(stageBulkWriterParam.getStageName()) + .path(remotePath) + .build(); + } + + @Override + protected void exit() throws InterruptedException { + super.exit(); + // remove the temp folder "bulk_writer" + Path parentPath = Paths.get(localPath).getParent(); + if (parentPath.toFile().exists() && isEmptyDirectory(parentPath)) { + try { + Files.delete(parentPath); + logger.info("Deleted empty directory: " + parentPath); + } catch (IOException e) { + e.printStackTrace(); + } + } + } + + private static boolean isEmptyDirectory(Path directory) { + try { + return !Files.walk(directory, 1, FileVisitOption.FOLLOW_LINKS) + .skip(1) // Skip the root directory itself + .findFirst() + .isPresent(); + } catch (IOException e) { + e.printStackTrace(); + } + return false; + } + + private void rmLocal(String file) { + try { + Path filePath = Paths.get(file); + filePath.toFile().delete(); + + Path parentDir = filePath.getParent(); + if (parentDir != null && !parentDir.toString().equals(localPath)) { + try { + Files.delete(parentDir); + logger.info("Deleted empty directory: " + parentDir); + } catch (IOException ex) { + logger.warn("Failed to delete empty directory: " + parentDir); + } + } + } catch (Exception ex) { + 
logger.warn("Failed to delete local file: " + file); + } + } + + @Override + protected void callBack(List fileList) { + serialImportData(fileList); + } + + @Override + public void close() throws Exception { + logger.info("execute remaining actions to prevent loss of memory data or residual empty directories."); + exit(); + logger.info(String.format("RemoteBulkWriter done! output remote files: %s", getBatchFiles())); + } + + private void serialImportData(List fileList) { + List remoteFileList = new ArrayList<>(); + try { + for (String filePath : fileList) { + String relativeFilePath = filePath.replace(super.getDataPath(), ""); + String minioFilePath = getMinioFilePath(remotePath, relativeFilePath); + + uploadObject(filePath, minioFilePath); + remoteFileList.add(minioFilePath); + rmLocal(filePath); + } + + } catch (Exception e) { + ExceptionUtils.throwUnExpectedException(String.format("Failed to upload files, error: %s", e)); + } + + logger.info("Successfully upload files: " + fileList); + remoteFiles.add(remoteFileList); + } + + private void uploadObject(String filePath, String objectName) throws Exception { + logger.info(String.format("Prepare to upload %s to %s", filePath, objectName)); + + stageWriter.uploadFileToStage(filePath); + logger.info(String.format("Upload file %s to %s", filePath, objectName)); + } + + private static String generatorLocalPath() { + Path currentWorkingDirectory = Paths.get("").toAbsolutePath(); + Path currentScriptPath = currentWorkingDirectory.resolve("bulk_writer"); + return currentScriptPath.toString(); + } + + private static String getMinioFilePath(String remotePath, String relativeFilePath) { + remotePath = remotePath.startsWith("/") ? remotePath.substring(1) : remotePath; + Path remote = Paths.get(remotePath); + + relativeFilePath = relativeFilePath.startsWith("/") ? relativeFilePath.substring(1) : relativeFilePath; + Path relative = Paths.get(relativeFilePath); + Path joinedPath = remote.resolve(relative); + return joinedPath.toString(); + } +} diff --git a/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/StageBulkWriterParam.java b/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/StageBulkWriterParam.java new file mode 100644 index 000000000..8425fda56 --- /dev/null +++ b/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/StageBulkWriterParam.java @@ -0,0 +1,165 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + + package io.milvus.bulkwriter; + + import io.milvus.bulkwriter.common.clientenum.BulkFileType; + import io.milvus.bulkwriter.common.utils.V2AdapterUtils; + import io.milvus.exception.ParamException; + import io.milvus.param.ParamUtils; + import io.milvus.param.collection.CollectionSchemaParam; + import io.milvus.v2.service.collection.request.CreateCollectionReq; + import lombok.Getter; + import lombok.NonNull; + import lombok.ToString; + import org.jetbrains.annotations.NotNull; + + import java.util.HashMap; + import java.util.Map; + + /** + * Parameters for the StageBulkWriter interface. + */ + @Getter + @ToString + public class StageBulkWriterParam { + private final CreateCollectionReq.CollectionSchema collectionSchema; + private final String remotePath; + private final long chunkSize; + private final BulkFileType fileType; + private final Map<String, Object> config; + + private final String cloudEndpoint; + private final String apiKey; + private final String stageName; + + private StageBulkWriterParam(@NonNull Builder builder) { + this.collectionSchema = builder.collectionSchema; + this.remotePath = builder.remotePath; + this.chunkSize = builder.chunkSize; + this.fileType = builder.fileType; + this.config = builder.config; + + this.cloudEndpoint = builder.cloudEndpoint; + this.apiKey = builder.apiKey; + this.stageName = builder.stageName; + } + + public static Builder newBuilder() { + return new Builder(); + } + + /** + * Builder for {@link StageBulkWriterParam} class. + */ + public static final class Builder { + private CreateCollectionReq.CollectionSchema collectionSchema; + private String remotePath; + private long chunkSize = 128 * 1024 * 1024; + private BulkFileType fileType = BulkFileType.PARQUET; + private Map<String, Object> config = new HashMap<>(); + + private String cloudEndpoint; + private String apiKey; + + private String stageName; + + private Builder() { + } + + /** + * Sets the collection info. + * + * @param collectionSchema collection info + * @return Builder + */ + public Builder withCollectionSchema(@NonNull CollectionSchemaParam collectionSchema) { + this.collectionSchema = V2AdapterUtils.convertV1Schema(collectionSchema); + return this; + } + + /** + * Sets the collection schema by the V2 API. + * + * @param collectionSchema collection schema + * @return Builder + */ + public Builder withCollectionSchema(@NonNull CreateCollectionReq.CollectionSchema collectionSchema) { + this.collectionSchema = collectionSchema; + return this; + } + + /** + * Sets the remotePath. + * + * @param remotePath remote path + * @return Builder + */ + public Builder withRemotePath(@NonNull String remotePath) { + this.remotePath = remotePath; + return this; + } + + public Builder withChunkSize(long chunkSize) { + this.chunkSize = chunkSize; + return this; + } + + public Builder withFileType(@NonNull BulkFileType fileType) { + this.fileType = fileType; + return this; + } + + public Builder withConfig(String key, Object val) { + this.config.put(key, val); + return this; + } + + public Builder withCloudEndpoint(@NotNull String cloudEndpoint) { + this.cloudEndpoint = cloudEndpoint; + return this; + } + + public Builder withApiKey(@NotNull String apiKey) { + this.apiKey = apiKey; + return this; + } + + public Builder withStageName(@NotNull String stageName) { + this.stageName = stageName; + return this; + } + + /** + * Verifies parameters and creates a new {@link StageBulkWriterParam} instance. 
+ * + * @return {@link StageBulkWriterParam} + */ + public StageBulkWriterParam build() throws ParamException { + ParamUtils.CheckNullEmptyString(remotePath, "remotePath"); + + if (collectionSchema == null) { + throw new ParamException("collectionSchema cannot be null"); + } + + return new StageBulkWriterParam(this); + } + } + +} diff --git a/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/StageOperation.java b/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/StageOperation.java new file mode 100644 index 000000000..a08692e6e --- /dev/null +++ b/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/StageOperation.java @@ -0,0 +1,190 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + + package io.milvus.bulkwriter; + + import com.google.gson.Gson; + import io.milvus.bulkwriter.common.utils.FileUtils; + import io.milvus.bulkwriter.model.StageUploadResult; + import io.milvus.bulkwriter.request.stage.ApplyStageRequest; + import io.milvus.bulkwriter.response.ApplyStageResponse; + import io.milvus.bulkwriter.restful.DataStageUtils; + import io.milvus.bulkwriter.storage.StorageClient; + import io.milvus.bulkwriter.storage.client.MinioStorageClient; + import io.milvus.exception.ParamException; + import org.apache.commons.lang3.StringUtils; + import org.apache.commons.lang3.tuple.Pair; + import org.slf4j.Logger; + import org.slf4j.LoggerFactory; + + import java.io.File; + import java.io.IOException; + import java.time.Instant; + import java.util.ArrayList; + import java.util.Date; + import java.util.List; + import java.util.concurrent.ExecutorService; + import java.util.concurrent.Executors; + import java.util.concurrent.Future; + import java.util.concurrent.atomic.AtomicInteger; + import java.util.concurrent.atomic.AtomicLong; + + public class StageOperation { + private static final Logger logger = LoggerFactory.getLogger(StageOperation.class); + private final String cloudEndpoint; + private final String apiKey; + private final String stageName; + private Pair<List<String>, Long> localPathPair; + private final String path; + + private StorageClient storageClient; + private ApplyStageResponse applyStageResponse; + + public StageOperation(StageOperationParam stageWriterParam) throws IOException { + cloudEndpoint = stageWriterParam.getCloudEndpoint(); + apiKey = stageWriterParam.getApiKey(); + stageName = stageWriterParam.getStageName(); + path = convertDirPath(stageWriterParam.getPath()); + + refreshStageAndClient(); + } + + public StageUploadResult uploadFileToStage(String localDirOrFilePath) throws Exception { + localPathPair = FileUtils.processLocalPath(localDirOrFilePath); + initValidator(); + + logger.info("begin to upload files to stage, localDirOrFilePath:{}, fileCount:{}, stageName:{}, stagePath:{}", localDirOrFilePath, localPathPair.getKey().size(), 
applyStageResponse.getStageName(), path); + long startTime = System.currentTimeMillis(); + + int concurrency = 20; // number of concurrent upload threads + ExecutorService executor = Executors.newFixedThreadPool(concurrency); + AtomicInteger currentFileCount = new AtomicInteger(0); + long totalFiles = localPathPair.getKey().size(); + AtomicLong processedBytes = new AtomicLong(0); + long totalBytes = localPathPair.getValue(); + + List<Future<?>> futures = new ArrayList<>(); + for (String localFilePath : localPathPair.getKey()) { + futures.add(executor.submit(() -> { + long current = currentFileCount.incrementAndGet(); + File file = new File(localFilePath); + long fileStartTime = System.currentTimeMillis(); + try { + uploadLocalFileToStage(localFilePath); + long bytes = processedBytes.addAndGet(file.length()); + long elapsed = System.currentTimeMillis() - fileStartTime; + double percent = totalBytes == 0 ? 100.0 : (bytes * 100.0 / totalBytes); + logger.info("Uploaded file {}/{}: {} ({} bytes) elapsed:{} ms, progress(total bytes): {}/{} bytes, progress(total percentage):{}%", + current, totalFiles, localFilePath, file.length(), elapsed, bytes, totalBytes, String.format("%.2f", percent)); + } catch (Exception e) { + logger.error("Upload failed for file: {}", localFilePath, e); + } + })); + } + + for (Future<?> f : futures) { + f.get(); + } + executor.shutdown(); + + long totalElapsed = (System.currentTimeMillis() - startTime) / 1000; + logger.info("all files in {} have been uploaded to the stage, stageName:{}, stagePath:{}, totalFileCount:{}, totalFileSize:{}, cost time:{} s", + localDirOrFilePath, applyStageResponse.getStageName(), path, localPathPair.getKey().size(), localPathPair.getValue(), totalElapsed); + return StageUploadResult.builder().stageName(applyStageResponse.getStageName()).path(path).build(); + } + + private void initValidator() { + if (localPathPair.getValue() > applyStageResponse.getCondition().getMaxContentLength()) { + String msg = String.format("localFileTotalSize %s exceeds the maximum contentLength limit %s defined in the condition. 
If you want to upload larger files, please contact us to lift the restriction", localPathPair.getValue(), applyStageResponse.getCondition().getMaxContentLength()); + logger.error(msg); + throw new ParamException(msg); + } + } + + private void refreshStageAndClient() { + logger.info("refreshing Stage info..."); + ApplyStageRequest applyStageRequest = ApplyStageRequest.builder() + .apiKey(apiKey) + .stageName(stageName) + .path(path) + .build(); + String result = DataStageUtils.applyStage(cloudEndpoint, applyStageRequest); + applyStageResponse = new Gson().fromJson(result, ApplyStageResponse.class); + logger.info("stage info refreshed"); + + storageClient = MinioStorageClient.getStorageClient( + applyStageResponse.getCloud(), + applyStageResponse.getEndpoint(), + applyStageResponse.getCredentials().getTmpAK(), + applyStageResponse.getCredentials().getTmpSK(), + applyStageResponse.getCredentials().getSessionToken(), + applyStageResponse.getRegion(), null); + logger.info("storage client refreshed"); + } + + private String convertDirPath(String inputPath) { + if (StringUtils.isEmpty(inputPath) || inputPath.endsWith("/")) { + return inputPath; + } + return inputPath + "/"; + } + + private String uploadLocalFileToStage(String localFilePath) throws Exception { + File file = new File(localFilePath); + String fileName = file.getName(); + String remoteFilePath = applyStageResponse.getUploadPath() + fileName; + putObject(file, remoteFilePath); + return remoteFilePath; + } + + private void putObject(File file, String remoteFilePath) throws Exception { + Instant instant = Instant.parse(applyStageResponse.getCredentials().getExpireTime()); + Date expireTime = Date.from(instant); + if (new Date().after(expireTime)) { + synchronized (this) { + if (new Date().after(expireTime)) { + refreshStageAndClient(); + } + } + } + uploadWithRetry(file, remoteFilePath); + } + + private void uploadWithRetry(File file, String remoteFilePath) { + final int maxRetries = 3; + int attempt = 0; + while (attempt < maxRetries) { + try { + storageClient.putObject(file, applyStageResponse.getBucketName(), remoteFilePath); + return; + } catch (Exception e) { + attempt++; + refreshStageAndClient(); + logger.warn("Attempt {} failed to upload {}", attempt, file.getAbsolutePath(), e); + if (attempt == maxRetries) { + throw new RuntimeException("Upload failed after " + maxRetries + " attempts", e); + } + try { + Thread.sleep(5000L); + } catch (InterruptedException ignored) { + } + } + } + } +} \ No newline at end of file diff --git a/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/StageOperationParam.java b/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/StageOperationParam.java new file mode 100644 index 000000000..2d390d899 --- /dev/null +++ b/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/StageOperationParam.java @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + + package io.milvus.bulkwriter; + + import io.milvus.exception.ParamException; + import io.milvus.param.ParamUtils; + import lombok.Getter; + import lombok.NonNull; + import lombok.ToString; + import org.jetbrains.annotations.NotNull; + + /** + * Parameters for the StageOperation interface. + */ + @Getter + @ToString + public class StageOperationParam { + private final String cloudEndpoint; + private final String apiKey; + private final String stageName; + private final String path; + + private StageOperationParam(@NonNull Builder builder) { + this.cloudEndpoint = builder.cloudEndpoint; + this.apiKey = builder.apiKey; + this.stageName = builder.stageName; + this.path = builder.path; + } + + public static Builder newBuilder() { + return new Builder(); + } + + /** + * Builder for {@link StageOperationParam} class. + */ + public static final class Builder { + private String cloudEndpoint; + + private String apiKey; + + private String stageName; + + private String path; + + private Builder() { + } + + public Builder withCloudEndpoint(@NotNull String cloudEndpoint) { + this.cloudEndpoint = cloudEndpoint; + return this; + } + + public Builder withApiKey(@NotNull String apiKey) { + this.apiKey = apiKey; + return this; + } + + public Builder withStageName(@NotNull String stageName) { + this.stageName = stageName; + return this; + } + + /** + * Sets the path. + * If a value is specified, it will be used as the upload path for files. + */ + public Builder withPath(String path) { + this.path = path; + return this; + } + + /** + * Verifies parameters and creates a new {@link StageOperationParam} instance. + * + * @return {@link StageOperationParam} + */ + public StageOperationParam build() throws ParamException { + ParamUtils.CheckNullEmptyString(cloudEndpoint, "cloudEndpoint"); + ParamUtils.CheckNullEmptyString(apiKey, "apiKey"); + ParamUtils.CheckNullEmptyString(stageName, "stageName"); + + return new StageOperationParam(this); + } + } + +} diff --git a/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/common/clientenum/BulkFileType.java b/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/common/clientenum/BulkFileType.java index aba2864b8..728fb1871 100644 --- a/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/common/clientenum/BulkFileType.java +++ b/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/common/clientenum/BulkFileType.java @@ -19,15 +19,20 @@ package io.milvus.bulkwriter.common.clientenum; +import lombok.Getter; + +@Getter public enum BulkFileType { - PARQUET(1), - JSON(2), - CSV(3), + PARQUET(1, ".parquet"), + JSON(2, ".json"), + CSV(3, ".csv"), ; private Integer code; + private String suffix; - BulkFileType(Integer code) { + BulkFileType(Integer code, String suffix) { this.code = code; + this.suffix = suffix; } } diff --git a/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/common/clientenum/CloudStorage.java b/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/common/clientenum/CloudStorage.java index 51a06ab7f..6bd3a9118 100644 --- a/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/common/clientenum/CloudStorage.java +++ b/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/common/clientenum/CloudStorage.java @@ -45,6 +45,15 @@ this.replace = replace; } + public static CloudStorage getCloudStorage(String cloudName) { + for (CloudStorage cloudStorage : values()) { + if (cloudStorage.getCloudName().equals(cloudName)) { + return cloudStorage; + } + } + throw new 
ParamException("no support others cloudName"); + } + public String getEndpoint(String... replaceParams) { if (StringUtils.isEmpty(replace)) { return endpoint; diff --git a/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/common/utils/FileUtils.java b/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/common/utils/FileUtils.java new file mode 100644 index 000000000..b0db69037 --- /dev/null +++ b/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/common/utils/FileUtils.java @@ -0,0 +1,53 @@ +package io.milvus.bulkwriter.common.utils; + +import com.google.common.collect.Lists; +import org.apache.commons.lang3.tuple.Pair; + +import java.io.File; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.List; + +public class FileUtils { + // Get all filePath with the inputFileSuffix in the localPath + public static Pair, Long> processLocalPath(String localPath) { + Path path = Paths.get(localPath); + if (Files.notExists(path)) { + throw new IllegalArgumentException("Path does not exist: " + localPath); + } + if (Files.isRegularFile(path)) { + return Pair.of(Lists.newArrayList(path.toString()), path.toFile().length()); + } else if (Files.isDirectory(path)) { + return FileUtils.findFilesRecursively(path.toFile()); + } + return Pair.of(new ArrayList<>(), 0L); + } + + /** + * Finds files with the given suffix in the first level subdirectories of the folder. + */ + public static Pair, Long> findFilesRecursively(File folder) { + List result = new ArrayList<>(); + long totalSize = 0L; + + File[] entries = folder.listFiles(); + if (entries == null) { + return Pair.of(result, 0L); + } + + for (File entry : entries) { + if (entry.isFile()) { + result.add(entry.getAbsolutePath()); + totalSize += entry.length(); + } else if (entry.isDirectory()) { + Pair, Long> subResult = findFilesRecursively(entry); + result.addAll(subResult.getLeft()); + totalSize += subResult.getRight(); + } + } + + return Pair.of(result, totalSize); + } +} diff --git a/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/common/utils/StorageUtils.java b/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/common/utils/StorageUtils.java new file mode 100644 index 000000000..893bf1da1 --- /dev/null +++ b/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/common/utils/StorageUtils.java @@ -0,0 +1,22 @@ +package io.milvus.bulkwriter.common.utils; + +import io.milvus.bulkwriter.common.clientenum.CloudStorage; +import io.milvus.exception.ParamException; + +public class StorageUtils { + public static String getObjectUrl(String cloudName, String bucketName, String objectPath, String region) { + CloudStorage cloudStorage = CloudStorage.getCloudStorage(cloudName); + switch (cloudStorage) { + case AWS: + return String.format("https://s3.%s.amazonaws.com/%s/%s", region, bucketName, objectPath); + case GCP: + return String.format("https://storage.cloud.google.com/%s/%s", bucketName, objectPath); + case TC: + return String.format("https://%s.cos.%s.myqcloud.com/%s", bucketName, region, objectPath); + case ALI: + return String.format("https://%s.oss-%s.aliyuncs.com/%s", bucketName, region, objectPath); + default: + throw new ParamException("no support others remote storage address"); + } + } +} diff --git a/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/model/CompleteMultipartUploadOutputModel.java b/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/model/CompleteMultipartUploadOutputModel.java new file mode 100644 index 000000000..654f3c8eb --- /dev/null +++ 
b/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/model/CompleteMultipartUploadOutputModel.java @@ -0,0 +1,39 @@ +package io.milvus.bulkwriter.model; + +import org.simpleframework.xml.Element; +import org.simpleframework.xml.Namespace; +import org.simpleframework.xml.Root; + +@Root(name = "CompleteMultipartUploadOutput", strict = false) +@Namespace(reference = "http://s3.amazonaws.com/doc/2006-03-01/") +public class CompleteMultipartUploadOutputModel { + @Element(name = "Location") + private String location; + + @Element(name = "Bucket") + private String bucket; + + @Element(name = "Key") + private String object; + + @Element(name = "ETag") + private String etag; + + public CompleteMultipartUploadOutputModel() {} + + public String location() { + return location; + } + + public String bucket() { + return bucket; + } + + public String object() { + return object; + } + + public String etag() { + return etag; + } +} diff --git a/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/model/StageUploadResult.java b/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/model/StageUploadResult.java new file mode 100644 index 000000000..006a42e06 --- /dev/null +++ b/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/model/StageUploadResult.java @@ -0,0 +1,15 @@ +package io.milvus.bulkwriter.model; + +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; + +@Data +@Builder +@AllArgsConstructor +@NoArgsConstructor +public class StageUploadResult { + private String stageName; + private String path; +} diff --git a/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/request/import_/CloudImportRequest.java b/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/request/import_/CloudImportRequest.java index 3d976d661..7c302ff20 100644 --- a/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/request/import_/CloudImportRequest.java +++ b/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/request/import_/CloudImportRequest.java @@ -33,6 +33,7 @@ public class CloudImportRequest extends BaseImportRequest { private String objectUrl; private String accessKey; private String secretKey; + private String token; private String clusterId; private String dbName; private String collectionName; diff --git a/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/request/import_/StageImportRequest.java b/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/request/import_/StageImportRequest.java new file mode 100644 index 000000000..2584f2b6f --- /dev/null +++ b/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/request/import_/StageImportRequest.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + + package io.milvus.bulkwriter.request.import_; + + import lombok.AllArgsConstructor; + import lombok.Data; + import lombok.NoArgsConstructor; + import lombok.experimental.SuperBuilder; + + import java.util.List; + + @Data + @SuperBuilder + @AllArgsConstructor + @NoArgsConstructor + public class StageImportRequest extends BaseImportRequest { + private String stageName; + private List<List<String>> dataPaths; + + private String clusterId; + private String dbName; + private String collectionName; + private String partitionName; +} diff --git a/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/request/stage/ApplyStageRequest.java b/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/request/stage/ApplyStageRequest.java new file mode 100644 index 000000000..e036a8eca --- /dev/null +++ b/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/request/stage/ApplyStageRequest.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + + package io.milvus.bulkwriter.request.stage; + + import lombok.AllArgsConstructor; + import lombok.Data; + import lombok.NoArgsConstructor; + import lombok.experimental.SuperBuilder; + + @Data + @SuperBuilder + @AllArgsConstructor + @NoArgsConstructor + public class ApplyStageRequest extends BaseStageRequest { + private String stageName; + + private String path; +} diff --git a/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/request/stage/BaseStageRequest.java b/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/request/stage/BaseStageRequest.java new file mode 100644 index 000000000..a0f732eac --- /dev/null +++ b/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/request/stage/BaseStageRequest.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + + package io.milvus.bulkwriter.request.stage; + + import lombok.AllArgsConstructor; + import lombok.Data; + import lombok.NoArgsConstructor; + import lombok.experimental.SuperBuilder; + + import java.io.Serializable; + import java.util.Map; + + @Data + @SuperBuilder(toBuilder = true) + @AllArgsConstructor + @NoArgsConstructor + public class BaseStageRequest implements Serializable { + private static final long serialVersionUID = 8192049841043084620L; + /** + * If you are calling the cloud API, this parameter needs to be filled in; otherwise, you can ignore it. + */ + private String apiKey; + + private Map<String, Object> options; +} diff --git a/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/response/ApplyStageResponse.java b/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/response/ApplyStageResponse.java new file mode 100644 index 000000000..532c84200 --- /dev/null +++ b/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/response/ApplyStageResponse.java @@ -0,0 +1,53 @@ +package io.milvus.bulkwriter.response; + +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; +import lombok.experimental.SuperBuilder; + +import java.io.Serializable; + + +@Data +@SuperBuilder +@AllArgsConstructor +@NoArgsConstructor +public class ApplyStageResponse implements Serializable { + private String endpoint; + + private String cloud; + + private String region; + + private String bucketName; + + private String uploadPath; + + private Credentials credentials; + + private Condition condition; + + private String stageName; + + @AllArgsConstructor + @NoArgsConstructor + @Data + @Builder + public static class Credentials implements Serializable { + private static final long serialVersionUID = 623702599895113789L; + private String tmpAK; + private String tmpSK; + private String sessionToken; + private String expireTime; + } + + @AllArgsConstructor + @NoArgsConstructor + @Data + @Builder + public static class Condition implements Serializable { + private static final long serialVersionUID = -2613029991242322109L; + private Long maxContentLength; + } +} diff --git a/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/BaseBulkImport.java b/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/restful/BaseRestful.java similarity index 98% rename from sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/BaseBulkImport.java rename to sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/restful/BaseRestful.java index 7e7097549..32efd3f50 100644 --- a/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/BaseBulkImport.java +++ b/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/restful/BaseRestful.java @@ -17,7 +17,7 @@ * under the License. 
*/ -package io.milvus.bulkwriter; +package io.milvus.bulkwriter.restful; import io.milvus.bulkwriter.response.RestfulResponse; import io.milvus.common.utils.ExceptionUtils; @@ -27,7 +27,7 @@ import java.util.HashMap; import java.util.Map; -public class BaseBulkImport { +public class BaseRestful { protected static String postRequest(String url, String apiKey, Map params, int timeout) { try { setDefaultOptionsIfCallCloud(params, apiKey); diff --git a/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/BulkImport.java b/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/restful/BulkImportUtils.java similarity index 97% rename from sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/BulkImport.java rename to sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/restful/BulkImportUtils.java index aa830671f..d7c81c1cd 100644 --- a/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/BulkImport.java +++ b/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/restful/BulkImportUtils.java @@ -17,7 +17,7 @@ * under the License. */ -package io.milvus.bulkwriter; +package io.milvus.bulkwriter.restful; import com.google.gson.reflect.TypeToken; import io.milvus.bulkwriter.request.describe.BaseDescribeImportRequest; @@ -28,7 +28,7 @@ import java.util.Map; -public class BulkImport extends BaseBulkImport { +public class BulkImportUtils extends BaseRestful { public static String bulkImport(String url, BaseImportRequest request) { String requestURL = url + "/v2/vectordb/jobs/import/create"; diff --git a/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/restful/DataStageUtils.java b/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/restful/DataStageUtils.java new file mode 100644 index 000000000..f92e8ecb6 --- /dev/null +++ b/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/restful/DataStageUtils.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + + package io.milvus.bulkwriter.restful; + + import com.google.gson.Gson; + import com.google.gson.reflect.TypeToken; + import io.milvus.bulkwriter.request.stage.BaseStageRequest; + import io.milvus.bulkwriter.response.RestfulResponse; + import io.milvus.common.utils.JsonUtils; + + import java.util.Map; + + public class DataStageUtils extends BaseRestful { + public static String applyStage(String url, BaseStageRequest request) { + String requestURL = url + "/v2/stages/apply"; + + Map<String, Object> params = JsonUtils.fromJson(JsonUtils.toJson(request), new TypeToken<Map<String, Object>>() {}.getType()); + String body = postRequest(requestURL, request.getApiKey(), params, 60 * 1000); + RestfulResponse<Object> response = JsonUtils.fromJson(body, new TypeToken<RestfulResponse<Object>>(){}.getType()); + handleResponse(requestURL, response); + return new Gson().toJson(response.getData()); + } +} diff --git a/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/storage/StorageClient.java b/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/storage/StorageClient.java index 007e12b47..67e91631a 100644 --- a/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/storage/StorageClient.java +++ b/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/storage/StorageClient.java @@ -20,10 +20,10 @@ package io.milvus.bulkwriter.storage; -import java.io.InputStream; +import java.io.File; public interface StorageClient { Long getObjectEntity(String bucketName, String objectKey) throws Exception; boolean checkBucketExist(String bucketName) throws Exception; - void putObjectStream(InputStream inputStream, long contentLength, String bucketName, String objectKey) throws Exception; + void putObject(File file, String bucketName, String objectKey) throws Exception; } diff --git a/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/storage/client/AzureStorageClient.java b/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/storage/client/AzureStorageClient.java index dfd9d9f11..82e9dc93f 100644 --- a/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/storage/client/AzureStorageClient.java +++ b/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/storage/client/AzureStorageClient.java @@ -30,7 +30,9 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.io.InputStream; +import java.io.File; +import java.io.FileInputStream; +import java.io.FileNotFoundException; public class AzureStorageClient implements StorageClient { private static final Logger logger = LoggerFactory.getLogger(AzureStorageClient.class); @@ -66,12 +68,12 @@ public Long getObjectEntity(String bucketName, String objectKey) { return blobClient.getProperties().getBlobSize(); } - public void putObjectStream(InputStream inputStream, long contentLength, String bucketName, String objectKey) { + public void putObject(File file, String bucketName, String objectKey) throws FileNotFoundException { + FileInputStream fileInputStream = new FileInputStream(file); BlobClient blobClient = blobServiceClient.getBlobContainerClient(bucketName).getBlobClient(objectKey); - blobClient.upload(inputStream, contentLength); + blobClient.upload(fileInputStream, file.length()); } - public boolean checkBucketExist(String bucketName) { BlobContainerClient blobContainerClient = blobServiceClient.getBlobContainerClient(bucketName); return blobContainerClient.exists(); diff --git a/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/storage/client/MinioStorageClient.java b/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/storage/client/MinioStorageClient.java index 1bf98c824..a5938986e 100644 --- 
a/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/storage/client/MinioStorageClient.java +++ b/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/storage/client/MinioStorageClient.java @@ -19,27 +19,48 @@ package io.milvus.bulkwriter.storage.client; +import com.google.common.collect.Multimap; import io.milvus.bulkwriter.common.clientenum.CloudStorage; +import io.milvus.bulkwriter.model.CompleteMultipartUploadOutputModel; import io.milvus.bulkwriter.storage.StorageClient; import io.minio.BucketExistsArgs; -import io.minio.MinioClient; +import io.minio.MinioAsyncClient; +import io.minio.ObjectWriteResponse; import io.minio.PutObjectArgs; +import io.minio.S3Base; import io.minio.StatObjectArgs; import io.minio.StatObjectResponse; +import io.minio.Xml; import io.minio.credentials.StaticProvider; +import io.minio.errors.ErrorResponseException; +import io.minio.errors.InsufficientDataException; +import io.minio.errors.InternalException; +import io.minio.errors.XmlParserException; +import io.minio.http.Method; +import io.minio.messages.CompleteMultipartUpload; +import io.minio.messages.ErrorResponse; +import io.minio.messages.Part; import okhttp3.OkHttpClient; import org.apache.commons.lang3.StringUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.io.InputStream; +import java.io.File; +import java.io.FileInputStream; +import java.io.IOException; +import java.security.InvalidKeyException; +import java.security.NoSuchAlgorithmException; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; import static com.amazonaws.services.s3.internal.Constants.MB; -public class MinioStorageClient extends MinioClient implements StorageClient { +public class MinioStorageClient extends MinioAsyncClient implements StorageClient { private static final Logger logger = LoggerFactory.getLogger(MinioStorageClient.class); + private static final String UPLOAD_ID = "uploadId"; - protected MinioStorageClient(MinioClient client) { + + protected MinioStorageClient(MinioAsyncClient client) { super(client); } @@ -50,7 +71,7 @@ public static MinioStorageClient getStorageClient(String cloudName, String sessionToken, String region, OkHttpClient httpClient) { - MinioClient.Builder minioClientBuilder = MinioClient.builder() + MinioAsyncClient.Builder minioClientBuilder = MinioAsyncClient.builder() .endpoint(endpoint) .credentialsProvider(new StaticProvider(accessKey, secretKey, sessionToken)); @@ -62,7 +83,7 @@ public static MinioStorageClient getStorageClient(String cloudName, minioClientBuilder.httpClient(httpClient); } - MinioClient minioClient = minioClientBuilder.build(); + MinioAsyncClient minioClient = minioClientBuilder.build(); if (CloudStorage.TC.getCloudName().equals(cloudName)) { minioClient.enableVirtualStyleEndpoint(); } @@ -75,23 +96,103 @@ public Long getObjectEntity(String bucketName, String objectKey) throws Exceptio .bucket(bucketName) .object(objectKey) .build(); - StatObjectResponse statObject = statObject(statObjectArgs); + StatObjectResponse statObject = statObject(statObjectArgs).get(); return statObject.size(); } - public void putObjectStream(InputStream inputStream, long contentLength, String bucketName, String objectKey) throws Exception { + public void putObject(File file, String bucketName, String objectKey) throws Exception { + logger.info("uploading file, fileName:{}, size:{} bytes", file.getAbsolutePath(), file.length()); + FileInputStream fileInputStream = new FileInputStream(file); PutObjectArgs putObjectArgs = 
PutObjectArgs.builder() .bucket(bucketName) .object(objectKey) - .stream(inputStream, contentLength, 5 * MB) + .stream(fileInputStream, file.length(), 5 * MB) .build(); - putObject(putObjectArgs); + putObject(putObjectArgs).get(); } public boolean checkBucketExist(String bucketName) throws Exception { BucketExistsArgs bucketExistsArgs = BucketExistsArgs.builder() .bucket(bucketName) .build(); - return bucketExists(bucketExistsArgs); + return bucketExists(bucketExistsArgs).get(); + } + + @Override + // Considering MinIO's compatibility with S3, some adjustments have been made here. + protected CompletableFuture<ObjectWriteResponse> completeMultipartUploadAsync(String bucketName, String region, String objectName, String uploadId, Part[] parts, Multimap<String, String> extraHeaders, Multimap<String, String> extraQueryParams) throws InsufficientDataException, InternalException, InvalidKeyException, IOException, NoSuchAlgorithmException, XmlParserException { + Multimap<String, String> queryParams = newMultimap(extraQueryParams); + queryParams.put(UPLOAD_ID, uploadId); + return getRegionAsync(bucketName, region) + .thenCompose( + location -> { + try { + return executeAsync( + Method.POST, + bucketName, + objectName, + location, + httpHeaders(extraHeaders), + queryParams, + new CompleteMultipartUpload(parts), + 0); + } catch (InsufficientDataException + | InternalException + | InvalidKeyException + | IOException + | NoSuchAlgorithmException + | XmlParserException e) { + throw new CompletionException(e); + } + }) + .thenApply( + response -> { + try { + String bodyContent = response.body().string(); + bodyContent = bodyContent.trim(); + if (!bodyContent.isEmpty()) { + try { + if (Xml.validate(ErrorResponse.class, bodyContent)) { + ErrorResponse errorResponse = Xml.unmarshal(ErrorResponse.class, bodyContent); + throw new CompletionException( + new ErrorResponseException(errorResponse, response, null)); + } + } catch (XmlParserException e) { + // As it is not an <Error> message, fall back to parsing CompleteMultipartUploadOutput + // XML. + } + + try { + CompleteMultipartUploadOutputModel result = + Xml.unmarshal(CompleteMultipartUploadOutputModel.class, bodyContent); + return new ObjectWriteResponse( + response.headers(), + result.bucket(), + result.location(), + result.object(), + result.etag(), + response.header("x-amz-version-id")); + } catch (XmlParserException e) { + // As this CompleteMultipartUpload REST call succeeded, just log it. + java.util.logging.Logger.getLogger(S3Base.class.getName()) + .warning( + "S3 service returned unknown XML for CompleteMultipartUpload REST API. " + + bodyContent); + } + } + + return new ObjectWriteResponse( + response.headers(), + bucketName, + region, + objectName, + null, + response.header("x-amz-version-id")); + } catch (IOException e) { + throw new CompletionException(e); + } finally { + response.close(); + } + }); } }
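
Usage note: taken together, the changes above form a stage-based import pipeline. StageBulkWriter chunks rows locally and uploads each committed chunk to a Zilliz stage through StageOperation; the staged files are then imported with a StageImportRequest via BulkImportUtils. The sketch below shows the intended call sequence. It is a minimal illustration, not part of the patch itself: CLOUD_ENDPOINT, API_KEY, STAGE_NAME, CLUSTER_ID, a prepared CreateCollectionReq.CollectionSchema named schema, rows as a List<JsonObject>, and the collection name are all assumed, caller-supplied values.

// Minimal sketch of the stage-based import flow; all constants below are hypothetical.
StageBulkWriterParam writerParam = StageBulkWriterParam.newBuilder()
        .withCollectionSchema(schema)               // CreateCollectionReq.CollectionSchema
        .withRemotePath("bulk_data")
        .withFileType(BulkFileType.PARQUET)
        .withCloudEndpoint(CLOUD_ENDPOINT)          // assumed caller-supplied constants
        .withApiKey(API_KEY)
        .withStageName(STAGE_NAME)
        .build();

StageUploadResult uploaded;
try (StageBulkWriter writer = new StageBulkWriter(writerParam)) {
    for (JsonObject row : rows) {                   // rows: List<JsonObject>
        writer.appendRow(row);
    }
    writer.commit(false);                           // flush chunks; each chunk is uploaded to the stage
    uploaded = writer.getStageUploadResult();
}

// dataPaths mirrors the List<List<String>> shape of getBatchFiles()
StageImportRequest importRequest = StageImportRequest.builder()
        .apiKey(API_KEY)
        .stageName(uploaded.getStageName())
        .dataPaths(Collections.singletonList(Collections.singletonList(uploaded.getPath())))
        .clusterId(CLUSTER_ID)
        .collectionName("all_types_collection")     // illustrative collection name
        .build();
String importResult = BulkImportUtils.bulkImport(CLOUD_ENDPOINT, importRequest);

From here, job progress can be polled with BulkImportUtils.getImportProgress(...) in the same way the existing cloud-import path does.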