Skip to content

Commit 5162b53

Browse files
committed
support connectType if use oss-bucket
1 parent 3694749 commit 5162b53

11 files changed

Lines changed: 153 additions & 34 deletions

File tree

examples/src/main/java/io/milvus/v1/BulkWriterExample.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -448,7 +448,7 @@ private static RemoteBulkWriter buildRemoteBulkWriter(CollectionSchemaParam coll
448448

449449
private static StorageConnectParam buildStorageConnectParam() {
450450
StorageConnectParam connectParam;
451-
if (StorageConsts.cloudStorage == CloudStorage.AZURE) {
451+
if (CloudStorage.isAzCloud(StorageConsts.cloudStorage.getCloudName())) {
452452
String connectionStr = "DefaultEndpointsProtocol=https;AccountName=" + StorageConsts.AZURE_ACCOUNT_NAME +
453453
";AccountKey=" + StorageConsts.AZURE_ACCOUNT_KEY + ";EndpointSuffix=core.windows.net";
454454
connectParam = AzureConnectParam.newBuilder()
@@ -541,11 +541,11 @@ private void callBulkInsert(CollectionSchemaParam collectionSchema, List<List<St
541541
}
542542

543543
private void callCloudImport(List<List<String>> batchFiles, String collectionName, String partitionName) throws InterruptedException {
544-
String objectUrl = StorageConsts.cloudStorage == CloudStorage.AZURE
544+
String objectUrl = CloudStorage.isAzCloud(StorageConsts.cloudStorage.getCloudName())
545545
? StorageConsts.cloudStorage.getAzureObjectUrl(StorageConsts.AZURE_ACCOUNT_NAME, StorageConsts.AZURE_CONTAINER_NAME, ImportUtils.getCommonPrefix(batchFiles))
546546
: StorageConsts.cloudStorage.getS3ObjectUrl(StorageConsts.STORAGE_BUCKET, ImportUtils.getCommonPrefix(batchFiles), StorageConsts.STORAGE_REGION);
547-
String accessKey = StorageConsts.cloudStorage == CloudStorage.AZURE ? StorageConsts.AZURE_ACCOUNT_NAME : StorageConsts.STORAGE_ACCESS_KEY;
548-
String secretKey = StorageConsts.cloudStorage == CloudStorage.AZURE ? StorageConsts.AZURE_ACCOUNT_KEY : StorageConsts.STORAGE_SECRET_KEY;
547+
String accessKey = CloudStorage.isAzCloud(StorageConsts.cloudStorage.getCloudName()) ? StorageConsts.AZURE_ACCOUNT_NAME : StorageConsts.STORAGE_ACCESS_KEY;
548+
String secretKey = CloudStorage.isAzCloud(StorageConsts.cloudStorage.getCloudName()) ? StorageConsts.AZURE_ACCOUNT_KEY : StorageConsts.STORAGE_SECRET_KEY;
549549

550550
System.out.println("\n===================== call cloudImport ====================");
551551
List<String> objectUrls = Lists.newArrayList(objectUrl);

examples/src/main/java/io/milvus/v2/StageFileManagerExample.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import com.google.gson.Gson;
2222
import io.milvus.bulkwriter.StageFileManager;
2323
import io.milvus.bulkwriter.StageFileManagerParam;
24+
import io.milvus.bulkwriter.common.clientenum.ConnectType;
2425
import io.milvus.bulkwriter.model.UploadFilesResult;
2526
import io.milvus.bulkwriter.request.stage.UploadFilesRequest;
2627

@@ -35,6 +36,7 @@ public class StageFileManagerExample {
3536
.withCloudEndpoint("https://api.cloud.zilliz.com")
3637
.withApiKey("_api_key_for_cluster_org_")
3738
.withStageName("_stage_name_for_project_")
39+
.withConnectType(ConnectType.AUTO)
3840
.build();
3941
stageFileManager = new StageFileManager(stageFileManagerParam);
4042
}

examples/src/main/java/io/milvus/v2/bulkwriter/BulkWriterRemoteExample.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -392,7 +392,7 @@ private static RemoteBulkWriter buildRemoteBulkWriter(CreateCollectionReq.Collec
392392

393393
private static StorageConnectParam buildStorageConnectParam() {
394394
StorageConnectParam connectParam;
395-
if (StorageConsts.cloudStorage == CloudStorage.AZURE) {
395+
if (CloudStorage.isAzCloud(StorageConsts.cloudStorage.getCloudName())) {
396396
String connectionStr = "DefaultEndpointsProtocol=https;AccountName=" + StorageConsts.AZURE_ACCOUNT_NAME +
397397
";AccountKey=" + StorageConsts.AZURE_ACCOUNT_KEY + ";EndpointSuffix=core.windows.net";
398398
connectParam = AzureConnectParam.newBuilder()

sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/StageBulkWriter.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
import com.google.common.collect.Lists;
2323
import com.google.gson.JsonObject;
24+
import io.milvus.bulkwriter.common.clientenum.ConnectType;
2425
import io.milvus.bulkwriter.model.UploadFilesResult;
2526
import io.milvus.bulkwriter.request.stage.UploadFilesRequest;
2627
import io.milvus.common.utils.ExceptionUtils;
@@ -63,7 +64,7 @@ public StageBulkWriter(StageBulkWriterParam bulkWriterParam) throws IOException
6364
private StageFileManager initStageFileManagerParams(StageBulkWriterParam bulkWriterParam) throws IOException {
6465
StageFileManagerParam stageFileManagerParam = StageFileManagerParam.newBuilder()
6566
.withCloudEndpoint(bulkWriterParam.getCloudEndpoint()).withApiKey(bulkWriterParam.getApiKey())
66-
.withStageName(bulkWriterParam.getStageName())
67+
.withStageName(bulkWriterParam.getStageName()).withConnectType(ConnectType.AUTO)
6768
.build();
6869
return new StageFileManager(stageFileManagerParam);
6970
}

sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/StageFileManager.java

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,12 @@
2020
package io.milvus.bulkwriter;
2121

2222
import com.google.gson.Gson;
23+
import io.milvus.bulkwriter.common.clientenum.ConnectType;
2324
import io.milvus.bulkwriter.common.utils.FileUtils;
2425
import io.milvus.bulkwriter.model.UploadFilesResult;
2526
import io.milvus.bulkwriter.request.stage.ApplyStageRequest;
2627
import io.milvus.bulkwriter.request.stage.UploadFilesRequest;
28+
import io.milvus.bulkwriter.resolver.EndpointResolver;
2729
import io.milvus.bulkwriter.response.ApplyStageResponse;
2830
import io.milvus.bulkwriter.restful.DataStageUtils;
2931
import io.milvus.bulkwriter.storage.StorageClient;
@@ -54,6 +56,7 @@ public class StageFileManager {
5456
private final String cloudEndpoint;
5557
private final String apiKey;
5658
private final String stageName;
59+
private final ConnectType connectType;
5760
private final ExecutorService executor;
5861

5962
private StorageClient storageClient;
@@ -63,7 +66,8 @@ public StageFileManager(StageFileManagerParam stageWriterParam) {
6366
this.cloudEndpoint = stageWriterParam.getCloudEndpoint();
6467
this.apiKey = stageWriterParam.getApiKey();
6568
this.stageName = stageWriterParam.getStageName();
66-
this.executor = Executors.newFixedThreadPool(20);
69+
this.connectType = stageWriterParam.getConnectType();
70+
this.executor = Executors.newFixedThreadPool(10);
6771
}
6872

6973
/**
@@ -138,7 +142,7 @@ public CompletableFuture<UploadFilesResult> uploadFilesAsync(UploadFilesRequest
138142
public void shutdownGracefully() {
139143
executor.shutdown();
140144
try {
141-
if (!executor.awaitTermination(10, TimeUnit.SECONDS)) {
145+
if (!executor.awaitTermination(5, TimeUnit.SECONDS)) {
142146
logger.warn("Executor didn't terminate in time, forcing shutdown...");
143147
executor.shutdownNow();
144148
}
@@ -168,9 +172,11 @@ private void refreshStageAndClient(String path) {
168172
applyStageResponse = new Gson().fromJson(result, ApplyStageResponse.class);
169173
logger.info("stage info refreshed");
170174

175+
String endpoint = EndpointResolver.resolveEndpoint(applyStageResponse.getEndpoint(), applyStageResponse.getCloud(),
176+
applyStageResponse.getRegion(), connectType);
171177
storageClient = MinioStorageClient.getStorageClient(
172178
applyStageResponse.getCloud(),
173-
applyStageResponse.getEndpoint(),
179+
endpoint,
174180
applyStageResponse.getCredentials().getTmpAK(),
175181
applyStageResponse.getCredentials().getTmpSK(),
176182
applyStageResponse.getCredentials().getSessionToken(),
@@ -235,6 +241,8 @@ private <T> T withRetry(String actionName, Callable<T> callable, String stagePat
235241
while (attempt < maxRetries) {
236242
try {
237243
return callable.call();
244+
} catch (RuntimeException e) {
245+
throw e;
238246
} catch (Exception e) {
239247
attempt++;
240248
refreshStageAndClient(stagePath);

sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/StageFileManagerParam.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
package io.milvus.bulkwriter;
2121

22+
import io.milvus.bulkwriter.common.clientenum.ConnectType;
2223
import io.milvus.exception.ParamException;
2324
import io.milvus.param.ParamUtils;
2425
import lombok.Getter;
@@ -35,11 +36,13 @@ public class StageFileManagerParam {
3536
private final String cloudEndpoint;
3637
private final String apiKey;
3738
private final String stageName;
39+
private final ConnectType connectType;
3840

3941
private StageFileManagerParam(@NonNull Builder builder) {
4042
this.cloudEndpoint = builder.cloudEndpoint;
4143
this.apiKey = builder.apiKey;
4244
this.stageName = builder.stageName;
45+
this.connectType = builder.connectType;
4346
}
4447

4548
public static Builder newBuilder() {
@@ -56,6 +59,8 @@ public static final class Builder {
5659

5760
private String stageName;
5861

62+
private ConnectType connectType = ConnectType.AUTO;
63+
5964
private Builder() {
6065
}
6166

@@ -79,6 +84,17 @@ public Builder withStageName(@NotNull String stageName) {
7984
return this;
8085
}
8186

87+
/**
88+
* Current value is mainly for Aliyun OSS buckets, default is Auto.
89+
* In the default case, if the OSS bucket is reachable via the internal endpoint, the internal endpoint will be used;
90+
* otherwise, the public endpoint will be used.
91+
* You can also force the use of either the internal or public endpoint.
92+
*/
93+
public Builder withConnectType(@NotNull ConnectType connectType) {
94+
this.connectType = connectType;
95+
return this;
96+
}
97+
8298
/**
8399
* Verifies parameters and creates a new {@link StageFileManagerParam} instance.
84100
*

sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/common/clientenum/CloudStorage.java

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,25 @@
2222
import io.milvus.exception.ParamException;
2323
import lombok.Getter;
2424
import org.apache.commons.lang3.StringUtils;
25+
import org.apache.hadoop.util.Lists;
26+
27+
import java.util.List;
2528

2629
public enum CloudStorage {
2730
MINIO("minio","%s", "minioAddress"),
2831
AWS("aws","s3.amazonaws.com", null),
2932
GCP("gcp" ,"storage.googleapis.com", null),
33+
34+
AZ("az" ,"%s.blob.core.windows.net", "accountName"),
3035
AZURE("azure" ,"%s.blob.core.windows.net", "accountName"),
36+
3137
ALI("ali","oss-%s.aliyuncs.com", "region"),
32-
TC("tc","cos.%s.myqcloud.com", "region")
38+
ALIYUN("aliyun","oss-%s.aliyuncs.com", "region"),
39+
ALIBABA("alibaba","oss-%s.aliyuncs.com", "region"),
40+
ALICLOU("alicloud","oss-%s.aliyuncs.com", "region"),
41+
42+
TC("tc","cos.%s.myqcloud.com", "region"),
43+
TENCENT("tencent","cos.%s.myqcloud.com", "region")
3344
;
3445

3546
@Getter
@@ -45,6 +56,27 @@ public enum CloudStorage {
4556
this.replace = replace;
4657
}
4758

59+
public static boolean isAliCloud(String cloudName) {
60+
List<CloudStorage> aliCloudStorages = Lists.newArrayList(
61+
CloudStorage.ALI, CloudStorage.ALIYUN, CloudStorage.ALIBABA, CloudStorage.ALICLOU
62+
);
63+
return aliCloudStorages.stream().anyMatch(e -> e.getCloudName().equalsIgnoreCase(cloudName));
64+
}
65+
66+
public static boolean isTcCloud(String cloudName) {
67+
List<CloudStorage> tcCloudStorages = Lists.newArrayList(
68+
CloudStorage.TC, CloudStorage.TENCENT
69+
);
70+
return tcCloudStorages.stream().anyMatch(e -> e.getCloudName().equalsIgnoreCase(cloudName));
71+
}
72+
73+
public static boolean isAzCloud(String cloudName) {
74+
List<CloudStorage> azCloudStorages = Lists.newArrayList(
75+
CloudStorage.AZ, CloudStorage.AZURE
76+
);
77+
return azCloudStorages.stream().anyMatch(e -> e.getCloudName().equalsIgnoreCase(cloudName));
78+
}
79+
4880
public static CloudStorage getCloudStorage(String cloudName) {
4981
for (CloudStorage cloudStorage : values()) {
5082
if (cloudStorage.getCloudName().equals(cloudName)) {
@@ -71,16 +103,20 @@ public String getS3ObjectUrl(String bucketName, String commonPrefix, String regi
71103
case GCP:
72104
return String.format("https://storage.cloud.google.com/%s/%s", bucketName, commonPrefix);
73105
case TC:
106+
case TENCENT:
74107
return String.format("https://%s.cos.%s.myqcloud.com/%s", bucketName, region, commonPrefix);
75108
case ALI:
109+
case ALICLOU:
110+
case ALIBABA:
111+
case ALIYUN:
76112
return String.format("https://%s.oss-%s.aliyuncs.com/%s", bucketName, region, commonPrefix);
77113
default:
78114
throw new ParamException("no support others remote storage address");
79115
}
80116
}
81117

82118
public String getAzureObjectUrl(String accountName, String containerName, String commonPrefix) {
83-
if (this == CloudStorage.AZURE) {
119+
if (CloudStorage.isAzCloud(this.getCloudName())) {
84120
return String.format("https://%s.blob.core.windows.net/%s/%s", accountName, containerName, commonPrefix);
85121
}
86122
throw new ParamException("no support others remote storage address");
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
package io.milvus.bulkwriter.common.clientenum;
2+
3+
public enum ConnectType {
4+
AUTO,
5+
INTERNAL,
6+
PUBLIC
7+
}

sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/common/utils/StorageUtils.java

Lines changed: 0 additions & 22 deletions
This file was deleted.
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
package io.milvus.bulkwriter.resolver;
2+
3+
import io.milvus.bulkwriter.common.clientenum.CloudStorage;
4+
import io.milvus.bulkwriter.common.clientenum.ConnectType;
5+
import org.slf4j.Logger;
6+
import org.slf4j.LoggerFactory;
7+
8+
import java.net.HttpURLConnection;
9+
import java.net.URL;
10+
import java.util.concurrent.TimeUnit;
11+
12+
public class EndpointResolver {
13+
private static final Logger logger = LoggerFactory.getLogger(EndpointResolver.class);
14+
15+
public static String resolveEndpoint(String defaultEndpoint, String cloud, String region, ConnectType connectType) {
16+
logger.info("Start resolving endpoint, cloud:{}, region:{}, connectType:{}", cloud, region, connectType);
17+
if (CloudStorage.isAliCloud(cloud)) {
18+
defaultEndpoint = resolveOssEndpoint(region, connectType);
19+
}
20+
logger.info("Resolved endpoint: {}, reachable check passed", defaultEndpoint);
21+
return defaultEndpoint;
22+
}
23+
24+
private static String resolveOssEndpoint(String region, ConnectType connectType) {
25+
String internalEndpoint = String.format("oss-%s-internal.aliyuncs.com", region);
26+
String publicEndpoint = String.format("oss-%s.aliyuncs.com", region);
27+
28+
switch (connectType) {
29+
case INTERNAL:
30+
logger.info("Forced INTERNAL endpoint selected: {}", internalEndpoint);
31+
checkEndpointReachable(internalEndpoint, true);
32+
return internalEndpoint;
33+
case PUBLIC:
34+
logger.info("Forced PUBLIC endpoint selected: {}", publicEndpoint);
35+
checkEndpointReachable(publicEndpoint, true);
36+
return publicEndpoint;
37+
case AUTO:
38+
default:
39+
if (checkEndpointReachable(internalEndpoint, false)) {
40+
logger.info("AUTO mode: internal endpoint reachable, using {}", internalEndpoint);
41+
return internalEndpoint;
42+
} else {
43+
logger.warn("AUTO mode: internal endpoint not reachable, fallback to public endpoint {}", publicEndpoint);
44+
checkEndpointReachable(publicEndpoint, true);
45+
return publicEndpoint;
46+
}
47+
}
48+
}
49+
50+
private static boolean checkEndpointReachable(String endpoint, boolean printError) {
51+
try {
52+
String httpEndpoint = String.format("https://%s", endpoint);
53+
URL url = new URL(httpEndpoint);
54+
HttpURLConnection conn = (HttpURLConnection) url.openConnection();
55+
conn.setConnectTimeout((int) TimeUnit.SECONDS.toMillis(5));
56+
conn.setReadTimeout((int) TimeUnit.SECONDS.toMillis(5));
57+
conn.setRequestMethod("HEAD");
58+
int code = conn.getResponseCode();
59+
logger.debug("Checked endpoint {}, response code={}", endpoint, code);
60+
return code >= 200 && code < 400;
61+
} catch (Exception e) {
62+
if (printError) {
63+
logger.error("Endpoint {} not reachable, throwing exception", endpoint, e);
64+
throw new RuntimeException(e.getMessage());
65+
} else {
66+
logger.warn("Endpoint {} not reachable, will fallback if needed", endpoint);
67+
return false;
68+
}
69+
}
70+
}
71+
}

0 commit comments

Comments
 (0)