Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions examples/src/main/java/io/milvus/v1/BulkWriterExample.java
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,7 @@ private static RemoteBulkWriter buildRemoteBulkWriter(CollectionSchemaParam coll

private static StorageConnectParam buildStorageConnectParam() {
StorageConnectParam connectParam;
if (StorageConsts.cloudStorage == CloudStorage.AZURE) {
if (CloudStorage.isAzCloud(StorageConsts.cloudStorage.getCloudName())) {
String connectionStr = "DefaultEndpointsProtocol=https;AccountName=" + StorageConsts.AZURE_ACCOUNT_NAME +
";AccountKey=" + StorageConsts.AZURE_ACCOUNT_KEY + ";EndpointSuffix=core.windows.net";
connectParam = AzureConnectParam.newBuilder()
Expand Down Expand Up @@ -541,11 +541,11 @@ private void callBulkInsert(CollectionSchemaParam collectionSchema, List<List<St
}

private void callCloudImport(List<List<String>> batchFiles, String collectionName, String partitionName) throws InterruptedException {
String objectUrl = StorageConsts.cloudStorage == CloudStorage.AZURE
String objectUrl = CloudStorage.isAzCloud(StorageConsts.cloudStorage.getCloudName())
? StorageConsts.cloudStorage.getAzureObjectUrl(StorageConsts.AZURE_ACCOUNT_NAME, StorageConsts.AZURE_CONTAINER_NAME, ImportUtils.getCommonPrefix(batchFiles))
: StorageConsts.cloudStorage.getS3ObjectUrl(StorageConsts.STORAGE_BUCKET, ImportUtils.getCommonPrefix(batchFiles), StorageConsts.STORAGE_REGION);
String accessKey = StorageConsts.cloudStorage == CloudStorage.AZURE ? StorageConsts.AZURE_ACCOUNT_NAME : StorageConsts.STORAGE_ACCESS_KEY;
String secretKey = StorageConsts.cloudStorage == CloudStorage.AZURE ? StorageConsts.AZURE_ACCOUNT_KEY : StorageConsts.STORAGE_SECRET_KEY;
String accessKey = CloudStorage.isAzCloud(StorageConsts.cloudStorage.getCloudName()) ? StorageConsts.AZURE_ACCOUNT_NAME : StorageConsts.STORAGE_ACCESS_KEY;
String secretKey = CloudStorage.isAzCloud(StorageConsts.cloudStorage.getCloudName()) ? StorageConsts.AZURE_ACCOUNT_KEY : StorageConsts.STORAGE_SECRET_KEY;

System.out.println("\n===================== call cloudImport ====================");
List<String> objectUrls = Lists.newArrayList(objectUrl);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import com.google.gson.Gson;
import io.milvus.bulkwriter.StageFileManager;
import io.milvus.bulkwriter.StageFileManagerParam;
import io.milvus.bulkwriter.common.clientenum.ConnectType;
import io.milvus.bulkwriter.model.UploadFilesResult;
import io.milvus.bulkwriter.request.stage.UploadFilesRequest;

Expand All @@ -35,6 +36,7 @@ public class StageFileManagerExample {
.withCloudEndpoint("https://api.cloud.zilliz.com")
.withApiKey("_api_key_for_cluster_org_")
.withStageName("_stage_name_for_project_")
.withConnectType(ConnectType.AUTO)
.build();
stageFileManager = new StageFileManager(stageFileManagerParam);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,7 @@ private static RemoteBulkWriter buildRemoteBulkWriter(CreateCollectionReq.Collec

private static StorageConnectParam buildStorageConnectParam() {
StorageConnectParam connectParam;
if (StorageConsts.cloudStorage == CloudStorage.AZURE) {
if (CloudStorage.isAzCloud(StorageConsts.cloudStorage.getCloudName())) {
String connectionStr = "DefaultEndpointsProtocol=https;AccountName=" + StorageConsts.AZURE_ACCOUNT_NAME +
";AccountKey=" + StorageConsts.AZURE_ACCOUNT_KEY + ";EndpointSuffix=core.windows.net";
connectParam = AzureConnectParam.newBuilder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import com.google.common.collect.Lists;
import com.google.gson.JsonObject;
import io.milvus.bulkwriter.common.clientenum.ConnectType;
import io.milvus.bulkwriter.model.UploadFilesResult;
import io.milvus.bulkwriter.request.stage.UploadFilesRequest;
import io.milvus.common.utils.ExceptionUtils;
Expand Down Expand Up @@ -63,7 +64,7 @@ public StageBulkWriter(StageBulkWriterParam bulkWriterParam) throws IOException
private StageFileManager initStageFileManagerParams(StageBulkWriterParam bulkWriterParam) throws IOException {
StageFileManagerParam stageFileManagerParam = StageFileManagerParam.newBuilder()
.withCloudEndpoint(bulkWriterParam.getCloudEndpoint()).withApiKey(bulkWriterParam.getApiKey())
.withStageName(bulkWriterParam.getStageName())
.withStageName(bulkWriterParam.getStageName()).withConnectType(ConnectType.AUTO)
.build();
return new StageFileManager(stageFileManagerParam);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@
package io.milvus.bulkwriter;

import com.google.gson.Gson;
import io.milvus.bulkwriter.common.clientenum.ConnectType;
import io.milvus.bulkwriter.common.utils.FileUtils;
import io.milvus.bulkwriter.model.UploadFilesResult;
import io.milvus.bulkwriter.request.stage.ApplyStageRequest;
import io.milvus.bulkwriter.request.stage.UploadFilesRequest;
import io.milvus.bulkwriter.resolver.EndpointResolver;
import io.milvus.bulkwriter.response.ApplyStageResponse;
import io.milvus.bulkwriter.restful.DataStageUtils;
import io.milvus.bulkwriter.storage.StorageClient;
Expand Down Expand Up @@ -54,6 +56,7 @@ public class StageFileManager {
private final String cloudEndpoint;
private final String apiKey;
private final String stageName;
private final ConnectType connectType;
private final ExecutorService executor;

private StorageClient storageClient;
Expand All @@ -63,7 +66,8 @@ public StageFileManager(StageFileManagerParam stageWriterParam) {
this.cloudEndpoint = stageWriterParam.getCloudEndpoint();
this.apiKey = stageWriterParam.getApiKey();
this.stageName = stageWriterParam.getStageName();
this.executor = Executors.newFixedThreadPool(20);
this.connectType = stageWriterParam.getConnectType();
this.executor = Executors.newFixedThreadPool(10);
}

/**
Expand Down Expand Up @@ -138,7 +142,7 @@ public CompletableFuture<UploadFilesResult> uploadFilesAsync(UploadFilesRequest
public void shutdownGracefully() {
executor.shutdown();
try {
if (!executor.awaitTermination(10, TimeUnit.SECONDS)) {
if (!executor.awaitTermination(5, TimeUnit.SECONDS)) {
logger.warn("Executor didn't terminate in time, forcing shutdown...");
executor.shutdownNow();
}
Expand Down Expand Up @@ -168,9 +172,11 @@ private void refreshStageAndClient(String path) {
applyStageResponse = new Gson().fromJson(result, ApplyStageResponse.class);
logger.info("stage info refreshed");

String endpoint = EndpointResolver.resolveEndpoint(applyStageResponse.getEndpoint(), applyStageResponse.getCloud(),
applyStageResponse.getRegion(), connectType);
storageClient = MinioStorageClient.getStorageClient(
applyStageResponse.getCloud(),
applyStageResponse.getEndpoint(),
endpoint,
applyStageResponse.getCredentials().getTmpAK(),
applyStageResponse.getCredentials().getTmpSK(),
applyStageResponse.getCredentials().getSessionToken(),
Expand Down Expand Up @@ -235,6 +241,8 @@ private <T> T withRetry(String actionName, Callable<T> callable, String stagePat
while (attempt < maxRetries) {
try {
return callable.call();
} catch (RuntimeException e) {
throw e;
} catch (Exception e) {
attempt++;
refreshStageAndClient(stagePath);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

package io.milvus.bulkwriter;

import io.milvus.bulkwriter.common.clientenum.ConnectType;
import io.milvus.exception.ParamException;
import io.milvus.param.ParamUtils;
import lombok.Getter;
Expand All @@ -35,11 +36,13 @@ public class StageFileManagerParam {
private final String cloudEndpoint;
private final String apiKey;
private final String stageName;
private final ConnectType connectType;

private StageFileManagerParam(@NonNull Builder builder) {
this.cloudEndpoint = builder.cloudEndpoint;
this.apiKey = builder.apiKey;
this.stageName = builder.stageName;
this.connectType = builder.connectType;
}

public static Builder newBuilder() {
Expand All @@ -56,6 +59,8 @@ public static final class Builder {

private String stageName;

private ConnectType connectType = ConnectType.AUTO;

private Builder() {
}

Expand All @@ -79,6 +84,17 @@ public Builder withStageName(@NotNull String stageName) {
return this;
}

/**
* Current value is mainly for Aliyun OSS buckets, default is Auto.
* In the default case, if the OSS bucket is reachable via the internal endpoint, the internal endpoint will be used;
* otherwise, the public endpoint will be used.
* You can also force the use of either the internal or public endpoint.
*/
public Builder withConnectType(@NotNull ConnectType connectType) {
this.connectType = connectType;
return this;
}

/**
* Verifies parameters and creates a new {@link StageFileManagerParam} instance.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,25 @@
import io.milvus.exception.ParamException;
import lombok.Getter;
import org.apache.commons.lang3.StringUtils;
import org.apache.hadoop.util.Lists;

import java.util.List;

public enum CloudStorage {
MINIO("minio","%s", "minioAddress"),
AWS("aws","s3.amazonaws.com", null),
GCP("gcp" ,"storage.googleapis.com", null),

AZ("az" ,"%s.blob.core.windows.net", "accountName"),
AZURE("azure" ,"%s.blob.core.windows.net", "accountName"),

ALI("ali","oss-%s.aliyuncs.com", "region"),
TC("tc","cos.%s.myqcloud.com", "region")
ALIYUN("aliyun","oss-%s.aliyuncs.com", "region"),
ALIBABA("alibaba","oss-%s.aliyuncs.com", "region"),
ALICLOU("alicloud","oss-%s.aliyuncs.com", "region"),

TC("tc","cos.%s.myqcloud.com", "region"),
TENCENT("tencent","cos.%s.myqcloud.com", "region")
;

@Getter
Expand All @@ -45,6 +56,27 @@ public enum CloudStorage {
this.replace = replace;
}

public static boolean isAliCloud(String cloudName) {
List<CloudStorage> aliCloudStorages = Lists.newArrayList(
CloudStorage.ALI, CloudStorage.ALIYUN, CloudStorage.ALIBABA, CloudStorage.ALICLOU
);
return aliCloudStorages.stream().anyMatch(e -> e.getCloudName().equalsIgnoreCase(cloudName));
}

public static boolean isTcCloud(String cloudName) {
List<CloudStorage> tcCloudStorages = Lists.newArrayList(
CloudStorage.TC, CloudStorage.TENCENT
);
return tcCloudStorages.stream().anyMatch(e -> e.getCloudName().equalsIgnoreCase(cloudName));
}

public static boolean isAzCloud(String cloudName) {
List<CloudStorage> azCloudStorages = Lists.newArrayList(
CloudStorage.AZ, CloudStorage.AZURE
);
return azCloudStorages.stream().anyMatch(e -> e.getCloudName().equalsIgnoreCase(cloudName));
}

public static CloudStorage getCloudStorage(String cloudName) {
for (CloudStorage cloudStorage : values()) {
if (cloudStorage.getCloudName().equals(cloudName)) {
Expand All @@ -71,16 +103,20 @@ public String getS3ObjectUrl(String bucketName, String commonPrefix, String regi
case GCP:
return String.format("https://storage.cloud.google.com/%s/%s", bucketName, commonPrefix);
case TC:
case TENCENT:
return String.format("https://%s.cos.%s.myqcloud.com/%s", bucketName, region, commonPrefix);
case ALI:
case ALICLOU:
case ALIBABA:
case ALIYUN:
return String.format("https://%s.oss-%s.aliyuncs.com/%s", bucketName, region, commonPrefix);
default:
throw new ParamException("no support others remote storage address");
}
}

public String getAzureObjectUrl(String accountName, String containerName, String commonPrefix) {
if (this == CloudStorage.AZURE) {
if (CloudStorage.isAzCloud(this.getCloudName())) {
return String.format("https://%s.blob.core.windows.net/%s/%s", accountName, containerName, commonPrefix);
}
throw new ParamException("no support others remote storage address");
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package io.milvus.bulkwriter.common.clientenum;

public enum ConnectType {
AUTO,
INTERNAL,
PUBLIC
}

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package io.milvus.bulkwriter.resolver;

import io.milvus.bulkwriter.common.clientenum.CloudStorage;
import io.milvus.bulkwriter.common.clientenum.ConnectType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.net.HttpURLConnection;
import java.net.URL;
import java.util.concurrent.TimeUnit;

public class EndpointResolver {
private static final Logger logger = LoggerFactory.getLogger(EndpointResolver.class);

public static String resolveEndpoint(String defaultEndpoint, String cloud, String region, ConnectType connectType) {
logger.info("Start resolving endpoint, cloud:{}, region:{}, connectType:{}", cloud, region, connectType);
if (CloudStorage.isAliCloud(cloud)) {
defaultEndpoint = resolveOssEndpoint(region, connectType);
}
logger.info("Resolved endpoint: {}, reachable check passed", defaultEndpoint);
return defaultEndpoint;
}

private static String resolveOssEndpoint(String region, ConnectType connectType) {
String internalEndpoint = String.format("oss-%s-internal.aliyuncs.com", region);
String publicEndpoint = String.format("oss-%s.aliyuncs.com", region);

switch (connectType) {
case INTERNAL:
logger.info("Forced INTERNAL endpoint selected: {}", internalEndpoint);
checkEndpointReachable(internalEndpoint, true);
return internalEndpoint;
case PUBLIC:
logger.info("Forced PUBLIC endpoint selected: {}", publicEndpoint);
checkEndpointReachable(publicEndpoint, true);
return publicEndpoint;
case AUTO:
default:
if (checkEndpointReachable(internalEndpoint, false)) {
logger.info("AUTO mode: internal endpoint reachable, using {}", internalEndpoint);
return internalEndpoint;
} else {
logger.warn("AUTO mode: internal endpoint not reachable, fallback to public endpoint {}", publicEndpoint);
checkEndpointReachable(publicEndpoint, true);
return publicEndpoint;
}
}
}

private static boolean checkEndpointReachable(String endpoint, boolean printError) {
try {
String httpEndpoint = String.format("https://%s", endpoint);
URL url = new URL(httpEndpoint);
HttpURLConnection conn = (HttpURLConnection) url.openConnection();
conn.setConnectTimeout((int) TimeUnit.SECONDS.toMillis(5));
conn.setReadTimeout((int) TimeUnit.SECONDS.toMillis(5));
conn.setRequestMethod("HEAD");
int code = conn.getResponseCode();
logger.debug("Checked endpoint {}, response code={}", endpoint, code);
return code >= 200 && code < 400;
} catch (Exception e) {
if (printError) {
logger.error("Endpoint {} not reachable, throwing exception", endpoint, e);
throw new RuntimeException(e.getMessage());
} else {
logger.warn("Endpoint {} not reachable, will fallback if needed", endpoint);
return false;
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ public static MinioStorageClient getStorageClient(String cloudName,
}

MinioAsyncClient minioClient = minioClientBuilder.build();
if (CloudStorage.TC.getCloudName().equals(cloudName)) {
if (CloudStorage.isTcCloud(cloudName)) {
minioClient.enableVirtualStyleEndpoint();
}

Expand Down
Loading