diff --git a/examples/src/main/java/io/milvus/v1/BulkWriterExample.java b/examples/src/main/java/io/milvus/v1/BulkWriterExample.java index 5307de37b..5b618cea1 100644 --- a/examples/src/main/java/io/milvus/v1/BulkWriterExample.java +++ b/examples/src/main/java/io/milvus/v1/BulkWriterExample.java @@ -448,7 +448,7 @@ private static RemoteBulkWriter buildRemoteBulkWriter(CollectionSchemaParam coll private static StorageConnectParam buildStorageConnectParam() { StorageConnectParam connectParam; - if (StorageConsts.cloudStorage == CloudStorage.AZURE) { + if (CloudStorage.isAzCloud(StorageConsts.cloudStorage.getCloudName())) { String connectionStr = "DefaultEndpointsProtocol=https;AccountName=" + StorageConsts.AZURE_ACCOUNT_NAME + ";AccountKey=" + StorageConsts.AZURE_ACCOUNT_KEY + ";EndpointSuffix=core.windows.net"; connectParam = AzureConnectParam.newBuilder() @@ -541,11 +541,11 @@ private void callBulkInsert(CollectionSchemaParam collectionSchema, List> batchFiles, String collectionName, String partitionName) throws InterruptedException { - String objectUrl = StorageConsts.cloudStorage == CloudStorage.AZURE + String objectUrl = CloudStorage.isAzCloud(StorageConsts.cloudStorage.getCloudName()) ? StorageConsts.cloudStorage.getAzureObjectUrl(StorageConsts.AZURE_ACCOUNT_NAME, StorageConsts.AZURE_CONTAINER_NAME, ImportUtils.getCommonPrefix(batchFiles)) : StorageConsts.cloudStorage.getS3ObjectUrl(StorageConsts.STORAGE_BUCKET, ImportUtils.getCommonPrefix(batchFiles), StorageConsts.STORAGE_REGION); - String accessKey = StorageConsts.cloudStorage == CloudStorage.AZURE ? StorageConsts.AZURE_ACCOUNT_NAME : StorageConsts.STORAGE_ACCESS_KEY; - String secretKey = StorageConsts.cloudStorage == CloudStorage.AZURE ? StorageConsts.AZURE_ACCOUNT_KEY : StorageConsts.STORAGE_SECRET_KEY; + String accessKey = CloudStorage.isAzCloud(StorageConsts.cloudStorage.getCloudName()) ? StorageConsts.AZURE_ACCOUNT_NAME : StorageConsts.STORAGE_ACCESS_KEY; + String secretKey = CloudStorage.isAzCloud(StorageConsts.cloudStorage.getCloudName()) ? StorageConsts.AZURE_ACCOUNT_KEY : StorageConsts.STORAGE_SECRET_KEY; System.out.println("\n===================== call cloudImport ===================="); List objectUrls = Lists.newArrayList(objectUrl); diff --git a/examples/src/main/java/io/milvus/v2/StageFileManagerExample.java b/examples/src/main/java/io/milvus/v2/StageFileManagerExample.java index 44c89bda6..c554075fd 100644 --- a/examples/src/main/java/io/milvus/v2/StageFileManagerExample.java +++ b/examples/src/main/java/io/milvus/v2/StageFileManagerExample.java @@ -21,6 +21,7 @@ import com.google.gson.Gson; import io.milvus.bulkwriter.StageFileManager; import io.milvus.bulkwriter.StageFileManagerParam; +import io.milvus.bulkwriter.common.clientenum.ConnectType; import io.milvus.bulkwriter.model.UploadFilesResult; import io.milvus.bulkwriter.request.stage.UploadFilesRequest; @@ -35,6 +36,7 @@ public class StageFileManagerExample { .withCloudEndpoint("https://api.cloud.zilliz.com") .withApiKey("_api_key_for_cluster_org_") .withStageName("_stage_name_for_project_") + .withConnectType(ConnectType.AUTO) .build(); stageFileManager = new StageFileManager(stageFileManagerParam); } diff --git a/examples/src/main/java/io/milvus/v2/bulkwriter/BulkWriterRemoteExample.java b/examples/src/main/java/io/milvus/v2/bulkwriter/BulkWriterRemoteExample.java index b041e7e0d..c21d1bd37 100644 --- a/examples/src/main/java/io/milvus/v2/bulkwriter/BulkWriterRemoteExample.java +++ b/examples/src/main/java/io/milvus/v2/bulkwriter/BulkWriterRemoteExample.java @@ -392,7 +392,7 @@ private static RemoteBulkWriter buildRemoteBulkWriter(CreateCollectionReq.Collec private static StorageConnectParam buildStorageConnectParam() { StorageConnectParam connectParam; - if (StorageConsts.cloudStorage == CloudStorage.AZURE) { + if (CloudStorage.isAzCloud(StorageConsts.cloudStorage.getCloudName())) { String connectionStr = "DefaultEndpointsProtocol=https;AccountName=" + StorageConsts.AZURE_ACCOUNT_NAME + ";AccountKey=" + StorageConsts.AZURE_ACCOUNT_KEY + ";EndpointSuffix=core.windows.net"; connectParam = AzureConnectParam.newBuilder() diff --git a/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/StageBulkWriter.java b/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/StageBulkWriter.java index 8bce60da1..3917bdceb 100644 --- a/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/StageBulkWriter.java +++ b/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/StageBulkWriter.java @@ -21,6 +21,7 @@ import com.google.common.collect.Lists; import com.google.gson.JsonObject; +import io.milvus.bulkwriter.common.clientenum.ConnectType; import io.milvus.bulkwriter.model.UploadFilesResult; import io.milvus.bulkwriter.request.stage.UploadFilesRequest; import io.milvus.common.utils.ExceptionUtils; @@ -63,7 +64,7 @@ public StageBulkWriter(StageBulkWriterParam bulkWriterParam) throws IOException private StageFileManager initStageFileManagerParams(StageBulkWriterParam bulkWriterParam) throws IOException { StageFileManagerParam stageFileManagerParam = StageFileManagerParam.newBuilder() .withCloudEndpoint(bulkWriterParam.getCloudEndpoint()).withApiKey(bulkWriterParam.getApiKey()) - .withStageName(bulkWriterParam.getStageName()) + .withStageName(bulkWriterParam.getStageName()).withConnectType(ConnectType.AUTO) .build(); return new StageFileManager(stageFileManagerParam); } diff --git a/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/StageFileManager.java b/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/StageFileManager.java index 016fb16f1..4311e7fce 100644 --- a/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/StageFileManager.java +++ b/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/StageFileManager.java @@ -20,10 +20,12 @@ package io.milvus.bulkwriter; import com.google.gson.Gson; +import io.milvus.bulkwriter.common.clientenum.ConnectType; import io.milvus.bulkwriter.common.utils.FileUtils; import io.milvus.bulkwriter.model.UploadFilesResult; import io.milvus.bulkwriter.request.stage.ApplyStageRequest; import io.milvus.bulkwriter.request.stage.UploadFilesRequest; +import io.milvus.bulkwriter.resolver.EndpointResolver; import io.milvus.bulkwriter.response.ApplyStageResponse; import io.milvus.bulkwriter.restful.DataStageUtils; import io.milvus.bulkwriter.storage.StorageClient; @@ -54,6 +56,7 @@ public class StageFileManager { private final String cloudEndpoint; private final String apiKey; private final String stageName; + private final ConnectType connectType; private final ExecutorService executor; private StorageClient storageClient; @@ -63,7 +66,8 @@ public StageFileManager(StageFileManagerParam stageWriterParam) { this.cloudEndpoint = stageWriterParam.getCloudEndpoint(); this.apiKey = stageWriterParam.getApiKey(); this.stageName = stageWriterParam.getStageName(); - this.executor = Executors.newFixedThreadPool(20); + this.connectType = stageWriterParam.getConnectType(); + this.executor = Executors.newFixedThreadPool(10); } /** @@ -138,7 +142,7 @@ public CompletableFuture uploadFilesAsync(UploadFilesRequest public void shutdownGracefully() { executor.shutdown(); try { - if (!executor.awaitTermination(10, TimeUnit.SECONDS)) { + if (!executor.awaitTermination(5, TimeUnit.SECONDS)) { logger.warn("Executor didn't terminate in time, forcing shutdown..."); executor.shutdownNow(); } @@ -168,9 +172,11 @@ private void refreshStageAndClient(String path) { applyStageResponse = new Gson().fromJson(result, ApplyStageResponse.class); logger.info("stage info refreshed"); + String endpoint = EndpointResolver.resolveEndpoint(applyStageResponse.getEndpoint(), applyStageResponse.getCloud(), + applyStageResponse.getRegion(), connectType); storageClient = MinioStorageClient.getStorageClient( applyStageResponse.getCloud(), - applyStageResponse.getEndpoint(), + endpoint, applyStageResponse.getCredentials().getTmpAK(), applyStageResponse.getCredentials().getTmpSK(), applyStageResponse.getCredentials().getSessionToken(), @@ -235,6 +241,8 @@ private T withRetry(String actionName, Callable callable, String stagePat while (attempt < maxRetries) { try { return callable.call(); + } catch (RuntimeException e) { + throw e; } catch (Exception e) { attempt++; refreshStageAndClient(stagePath); diff --git a/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/StageFileManagerParam.java b/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/StageFileManagerParam.java index 927e5d8da..8cb0b623e 100644 --- a/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/StageFileManagerParam.java +++ b/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/StageFileManagerParam.java @@ -19,6 +19,7 @@ package io.milvus.bulkwriter; +import io.milvus.bulkwriter.common.clientenum.ConnectType; import io.milvus.exception.ParamException; import io.milvus.param.ParamUtils; import lombok.Getter; @@ -35,11 +36,13 @@ public class StageFileManagerParam { private final String cloudEndpoint; private final String apiKey; private final String stageName; + private final ConnectType connectType; private StageFileManagerParam(@NonNull Builder builder) { this.cloudEndpoint = builder.cloudEndpoint; this.apiKey = builder.apiKey; this.stageName = builder.stageName; + this.connectType = builder.connectType; } public static Builder newBuilder() { @@ -56,6 +59,8 @@ public static final class Builder { private String stageName; + private ConnectType connectType = ConnectType.AUTO; + private Builder() { } @@ -79,6 +84,17 @@ public Builder withStageName(@NotNull String stageName) { return this; } + /** + * Current value is mainly for Aliyun OSS buckets, default is Auto. + * In the default case, if the OSS bucket is reachable via the internal endpoint, the internal endpoint will be used; + * otherwise, the public endpoint will be used. + * You can also force the use of either the internal or public endpoint. + */ + public Builder withConnectType(@NotNull ConnectType connectType) { + this.connectType = connectType; + return this; + } + /** * Verifies parameters and creates a new {@link StageFileManagerParam} instance. * diff --git a/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/common/clientenum/CloudStorage.java b/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/common/clientenum/CloudStorage.java index 6bd3a9118..2aba4216a 100644 --- a/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/common/clientenum/CloudStorage.java +++ b/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/common/clientenum/CloudStorage.java @@ -22,14 +22,25 @@ import io.milvus.exception.ParamException; import lombok.Getter; import org.apache.commons.lang3.StringUtils; +import org.apache.hadoop.util.Lists; + +import java.util.List; public enum CloudStorage { MINIO("minio","%s", "minioAddress"), AWS("aws","s3.amazonaws.com", null), GCP("gcp" ,"storage.googleapis.com", null), + + AZ("az" ,"%s.blob.core.windows.net", "accountName"), AZURE("azure" ,"%s.blob.core.windows.net", "accountName"), + ALI("ali","oss-%s.aliyuncs.com", "region"), - TC("tc","cos.%s.myqcloud.com", "region") + ALIYUN("aliyun","oss-%s.aliyuncs.com", "region"), + ALIBABA("alibaba","oss-%s.aliyuncs.com", "region"), + ALICLOU("alicloud","oss-%s.aliyuncs.com", "region"), + + TC("tc","cos.%s.myqcloud.com", "region"), + TENCENT("tencent","cos.%s.myqcloud.com", "region") ; @Getter @@ -45,6 +56,27 @@ public enum CloudStorage { this.replace = replace; } + public static boolean isAliCloud(String cloudName) { + List aliCloudStorages = Lists.newArrayList( + CloudStorage.ALI, CloudStorage.ALIYUN, CloudStorage.ALIBABA, CloudStorage.ALICLOU + ); + return aliCloudStorages.stream().anyMatch(e -> e.getCloudName().equalsIgnoreCase(cloudName)); + } + + public static boolean isTcCloud(String cloudName) { + List tcCloudStorages = Lists.newArrayList( + CloudStorage.TC, CloudStorage.TENCENT + ); + return tcCloudStorages.stream().anyMatch(e -> e.getCloudName().equalsIgnoreCase(cloudName)); + } + + public static boolean isAzCloud(String cloudName) { + List azCloudStorages = Lists.newArrayList( + CloudStorage.AZ, CloudStorage.AZURE + ); + return azCloudStorages.stream().anyMatch(e -> e.getCloudName().equalsIgnoreCase(cloudName)); + } + public static CloudStorage getCloudStorage(String cloudName) { for (CloudStorage cloudStorage : values()) { if (cloudStorage.getCloudName().equals(cloudName)) { @@ -71,8 +103,12 @@ public String getS3ObjectUrl(String bucketName, String commonPrefix, String regi case GCP: return String.format("https://storage.cloud.google.com/%s/%s", bucketName, commonPrefix); case TC: + case TENCENT: return String.format("https://%s.cos.%s.myqcloud.com/%s", bucketName, region, commonPrefix); case ALI: + case ALICLOU: + case ALIBABA: + case ALIYUN: return String.format("https://%s.oss-%s.aliyuncs.com/%s", bucketName, region, commonPrefix); default: throw new ParamException("no support others remote storage address"); @@ -80,7 +116,7 @@ public String getS3ObjectUrl(String bucketName, String commonPrefix, String regi } public String getAzureObjectUrl(String accountName, String containerName, String commonPrefix) { - if (this == CloudStorage.AZURE) { + if (CloudStorage.isAzCloud(this.getCloudName())) { return String.format("https://%s.blob.core.windows.net/%s/%s", accountName, containerName, commonPrefix); } throw new ParamException("no support others remote storage address"); diff --git a/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/common/clientenum/ConnectType.java b/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/common/clientenum/ConnectType.java new file mode 100644 index 000000000..993bda290 --- /dev/null +++ b/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/common/clientenum/ConnectType.java @@ -0,0 +1,7 @@ +package io.milvus.bulkwriter.common.clientenum; + +public enum ConnectType { + AUTO, + INTERNAL, + PUBLIC +} diff --git a/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/common/utils/StorageUtils.java b/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/common/utils/StorageUtils.java deleted file mode 100644 index 893bf1da1..000000000 --- a/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/common/utils/StorageUtils.java +++ /dev/null @@ -1,22 +0,0 @@ -package io.milvus.bulkwriter.common.utils; - -import io.milvus.bulkwriter.common.clientenum.CloudStorage; -import io.milvus.exception.ParamException; - -public class StorageUtils { - public static String getObjectUrl(String cloudName, String bucketName, String objectPath, String region) { - CloudStorage cloudStorage = CloudStorage.getCloudStorage(cloudName); - switch (cloudStorage) { - case AWS: - return String.format("https://s3.%s.amazonaws.com/%s/%s", region, bucketName, objectPath); - case GCP: - return String.format("https://storage.cloud.google.com/%s/%s", bucketName, objectPath); - case TC: - return String.format("https://%s.cos.%s.myqcloud.com/%s", bucketName, region, objectPath); - case ALI: - return String.format("https://%s.oss-%s.aliyuncs.com/%s", bucketName, region, objectPath); - default: - throw new ParamException("no support others remote storage address"); - } - } -} diff --git a/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/resolver/EndpointResolver.java b/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/resolver/EndpointResolver.java new file mode 100644 index 000000000..d1a83bd58 --- /dev/null +++ b/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/resolver/EndpointResolver.java @@ -0,0 +1,71 @@ +package io.milvus.bulkwriter.resolver; + +import io.milvus.bulkwriter.common.clientenum.CloudStorage; +import io.milvus.bulkwriter.common.clientenum.ConnectType; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.net.HttpURLConnection; +import java.net.URL; +import java.util.concurrent.TimeUnit; + +public class EndpointResolver { + private static final Logger logger = LoggerFactory.getLogger(EndpointResolver.class); + + public static String resolveEndpoint(String defaultEndpoint, String cloud, String region, ConnectType connectType) { + logger.info("Start resolving endpoint, cloud:{}, region:{}, connectType:{}", cloud, region, connectType); + if (CloudStorage.isAliCloud(cloud)) { + defaultEndpoint = resolveOssEndpoint(region, connectType); + } + logger.info("Resolved endpoint: {}, reachable check passed", defaultEndpoint); + return defaultEndpoint; + } + + private static String resolveOssEndpoint(String region, ConnectType connectType) { + String internalEndpoint = String.format("oss-%s-internal.aliyuncs.com", region); + String publicEndpoint = String.format("oss-%s.aliyuncs.com", region); + + switch (connectType) { + case INTERNAL: + logger.info("Forced INTERNAL endpoint selected: {}", internalEndpoint); + checkEndpointReachable(internalEndpoint, true); + return internalEndpoint; + case PUBLIC: + logger.info("Forced PUBLIC endpoint selected: {}", publicEndpoint); + checkEndpointReachable(publicEndpoint, true); + return publicEndpoint; + case AUTO: + default: + if (checkEndpointReachable(internalEndpoint, false)) { + logger.info("AUTO mode: internal endpoint reachable, using {}", internalEndpoint); + return internalEndpoint; + } else { + logger.warn("AUTO mode: internal endpoint not reachable, fallback to public endpoint {}", publicEndpoint); + checkEndpointReachable(publicEndpoint, true); + return publicEndpoint; + } + } + } + + private static boolean checkEndpointReachable(String endpoint, boolean printError) { + try { + String httpEndpoint = String.format("https://%s", endpoint); + URL url = new URL(httpEndpoint); + HttpURLConnection conn = (HttpURLConnection) url.openConnection(); + conn.setConnectTimeout((int) TimeUnit.SECONDS.toMillis(5)); + conn.setReadTimeout((int) TimeUnit.SECONDS.toMillis(5)); + conn.setRequestMethod("HEAD"); + int code = conn.getResponseCode(); + logger.debug("Checked endpoint {}, response code={}", endpoint, code); + return code >= 200 && code < 400; + } catch (Exception e) { + if (printError) { + logger.error("Endpoint {} not reachable, throwing exception", endpoint, e); + throw new RuntimeException(e.getMessage()); + } else { + logger.warn("Endpoint {} not reachable, will fallback if needed", endpoint); + return false; + } + } + } +} diff --git a/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/storage/client/MinioStorageClient.java b/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/storage/client/MinioStorageClient.java index a5938986e..2f55c2030 100644 --- a/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/storage/client/MinioStorageClient.java +++ b/sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/storage/client/MinioStorageClient.java @@ -84,7 +84,7 @@ public static MinioStorageClient getStorageClient(String cloudName, } MinioAsyncClient minioClient = minioClientBuilder.build(); - if (CloudStorage.TC.getCloudName().equals(cloudName)) { + if (CloudStorage.isTcCloud(cloudName)) { minioClient.enableVirtualStyleEndpoint(); }