diff --git a/sdk-core/src/main/java/io/milvus/client/MilvusServiceClient.java b/sdk-core/src/main/java/io/milvus/client/MilvusServiceClient.java index 0aba92ff1..7b8404125 100644 --- a/sdk-core/src/main/java/io/milvus/client/MilvusServiceClient.java +++ b/sdk-core/src/main/java/io/milvus/client/MilvusServiceClient.java @@ -46,6 +46,9 @@ import io.milvus.param.partition.*; import io.milvus.param.resourcegroup.*; import io.milvus.param.role.*; +import io.milvus.v2.utils.ClientUtils; +import io.grpc.ProxiedSocketAddress; +import io.grpc.ProxyDetector; import lombok.NonNull; import org.apache.commons.lang3.StringUtils; @@ -58,6 +61,9 @@ import java.net.InetAddress; import java.net.UnknownHostException; import java.time.LocalDateTime; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import io.grpc.HttpConnectProxiedSocketAddress; public class MilvusServiceClient extends AbstractMilvusGrpcClient { @@ -102,7 +108,6 @@ public void start(ClientCall.Listener responseListener, Metadata headers) SslContext sslContext = GrpcSslContexts.forClient() .trustManager(new File(connectParam.getServerPemPath())) .build(); - NettyChannelBuilder builder = NettyChannelBuilder.forAddress(connectParam.getHost(), connectParam.getPort()) .overrideAuthority(connectParam.getServerName()) .sslContext(sslContext) @@ -112,6 +117,10 @@ public void start(ClientCall.Listener responseListener, Metadata headers) .keepAliveWithoutCalls(connectParam.isKeepAliveWithoutCalls()) .idleTimeout(connectParam.getIdleTimeoutMs(), TimeUnit.MILLISECONDS) .intercept(clientInterceptors); + // Add proxy configuration if proxy address is set + if (StringUtils.isNotEmpty(connectParam.getProxyAddress())) { + ClientUtils.configureProxy(builder, connectParam.getProxyAddress()); + } if(connectParam.isSecure()){ builder.useTransportSecurity(); } @@ -124,7 +133,6 @@ public void start(ClientCall.Listener responseListener, Metadata headers) .trustManager(new File(connectParam.getCaPemPath())) .keyManager(new File(connectParam.getClientPemPath()), new File(connectParam.getClientKeyPath())) .build(); - NettyChannelBuilder builder = NettyChannelBuilder.forAddress(connectParam.getHost(), connectParam.getPort()) .sslContext(sslContext) .maxInboundMessageSize(Integer.MAX_VALUE) @@ -133,6 +141,11 @@ public void start(ClientCall.Listener responseListener, Metadata headers) .keepAliveWithoutCalls(connectParam.isKeepAliveWithoutCalls()) .idleTimeout(connectParam.getIdleTimeoutMs(), TimeUnit.MILLISECONDS) .intercept(clientInterceptors); + + // Add proxy configuration if proxy address is set + if (StringUtils.isNotEmpty(connectParam.getProxyAddress())) { + ClientUtils.configureProxy(builder, connectParam.getProxyAddress()); + } if(connectParam.isSecure()){ builder.useTransportSecurity(); } @@ -150,6 +163,9 @@ public void start(ClientCall.Listener responseListener, Metadata headers) .keepAliveWithoutCalls(connectParam.isKeepAliveWithoutCalls()) .idleTimeout(connectParam.getIdleTimeoutMs(), TimeUnit.MILLISECONDS) .intercept(clientInterceptors); + if (StringUtils.isNotEmpty(connectParam.getProxyAddress())) { + ClientUtils.configureProxy(builder, connectParam.getProxyAddress()); + } if(connectParam.isSecure()){ builder.useTransportSecurity(); } diff --git a/sdk-core/src/main/java/io/milvus/param/ConnectParam.java b/sdk-core/src/main/java/io/milvus/param/ConnectParam.java index 1e1c2fd5d..5eeb28067 100644 --- a/sdk-core/src/main/java/io/milvus/param/ConnectParam.java +++ b/sdk-core/src/main/java/io/milvus/param/ConnectParam.java @@ -59,6 +59,7 @@ public class ConnectParam { private final String serverName; private final String userName; private final ThreadLocal clientRequestId; + private final String proxyAddress; protected ConnectParam(@NonNull Builder builder) { this.host = builder.host; @@ -81,6 +82,7 @@ protected ConnectParam(@NonNull Builder builder) { this.serverName = builder.serverName; this.userName = builder.userName; this.clientRequestId = builder.clientRequestId; + this.proxyAddress = builder.proxyAddress; } public static Builder newBuilder() { @@ -120,6 +122,8 @@ public static class Builder { //used to set client_request_id in the grpc header uniquely for every request private ThreadLocal clientRequestId; + + private String proxyAddress; protected Builder() { } @@ -359,6 +363,17 @@ public Builder withClientRequestId(@NonNull ThreadLocal clientRequestId) this.clientRequestId = clientRequestId; return this; } + + /** + * Sets the proxy address for connections through a proxy server. + * + * @param proxyAddress proxy server address in format "host:port" + * @return Builder + */ + public Builder withProxyAddress(String proxyAddress) { + this.proxyAddress = proxyAddress; + return this; + } /** * Verifies parameters and creates a new {@link ConnectParam} instance. @@ -418,4 +433,4 @@ protected void verify() throws ParamException { } } } -} +} \ No newline at end of file diff --git a/sdk-core/src/main/java/io/milvus/v2/client/ConnectConfig.java b/sdk-core/src/main/java/io/milvus/v2/client/ConnectConfig.java index 9e401c1f9..8e1f1e971 100644 --- a/sdk-core/src/main/java/io/milvus/v2/client/ConnectConfig.java +++ b/sdk-core/src/main/java/io/milvus/v2/client/ConnectConfig.java @@ -56,6 +56,7 @@ public class ConnectConfig { private String caPemPath; private String serverPemPath; private String serverName; + private String proxyAddress; @Builder.Default private Boolean secure = false; @Builder.Default @@ -97,4 +98,8 @@ public Boolean isSecure() { } return secure; } + + public String getProxyAddress(){ + return proxyAddress; + } } diff --git a/sdk-core/src/main/java/io/milvus/v2/utils/ClientUtils.java b/sdk-core/src/main/java/io/milvus/v2/utils/ClientUtils.java index 8164c40cb..65c7f6fed 100644 --- a/sdk-core/src/main/java/io/milvus/v2/utils/ClientUtils.java +++ b/sdk-core/src/main/java/io/milvus/v2/utils/ClientUtils.java @@ -33,6 +33,9 @@ import io.milvus.client.MilvusServiceClient; import io.milvus.grpc.*; import io.milvus.v2.client.ConnectConfig; +import io.grpc.HttpConnectProxiedSocketAddress; +import io.grpc.ProxiedSocketAddress; +import io.grpc.ProxyDetector; import org.apache.commons.lang3.StringUtils; import org.jetbrains.annotations.NotNull; import org.slf4j.Logger; @@ -46,6 +49,8 @@ import java.time.LocalDateTime; import java.util.Base64; import java.util.concurrent.TimeUnit; +import java.net.InetSocketAddress; +import java.net.SocketAddress; public class ClientUtils { Logger logger = LoggerFactory.getLogger(ClientUtils.class); @@ -73,6 +78,11 @@ public ManagedChannel getChannel(ConnectConfig connectConfig){ .keepAliveWithoutCalls(connectConfig.isKeepAliveWithoutCalls()) .idleTimeout(connectConfig.getIdleTimeoutMs(), TimeUnit.MILLISECONDS) .intercept(MetadataUtils.newAttachHeadersInterceptor(metadata)); + + if (StringUtils.isNotEmpty(connectConfig.getProxyAddress())) { + configureProxy(builder, connectConfig.getProxyAddress()); + } + if(connectConfig.isSecure()) { builder.useTransportSecurity(); } @@ -95,6 +105,11 @@ public ManagedChannel getChannel(ConnectConfig connectConfig){ .keepAliveWithoutCalls(connectConfig.isKeepAliveWithoutCalls()) .idleTimeout(connectConfig.getIdleTimeoutMs(), TimeUnit.MILLISECONDS) .intercept(MetadataUtils.newAttachHeadersInterceptor(metadata)); + + if (StringUtils.isNotEmpty(connectConfig.getProxyAddress())) { + configureProxy(builder, connectConfig.getProxyAddress()); + } + if(connectConfig.isSecure()){ builder.useTransportSecurity(); } @@ -102,7 +117,7 @@ public ManagedChannel getChannel(ConnectConfig connectConfig){ } else if (StringUtils.isNotEmpty(connectConfig.getClientPemPath()) && StringUtils.isNotEmpty(connectConfig.getClientKeyPath()) && StringUtils.isNotEmpty(connectConfig.getCaPemPath())) { - // tow-way tls + // two-way tls SslContext sslContext = GrpcSslContexts.forClient() .trustManager(new File(connectConfig.getCaPemPath())) .keyManager(new File(connectConfig.getClientPemPath()), new File(connectConfig.getClientKeyPath())) @@ -116,6 +131,11 @@ public ManagedChannel getChannel(ConnectConfig connectConfig){ .keepAliveWithoutCalls(connectConfig.isKeepAliveWithoutCalls()) .idleTimeout(connectConfig.getIdleTimeoutMs(), TimeUnit.MILLISECONDS) .intercept(MetadataUtils.newAttachHeadersInterceptor(metadata)); + + if (StringUtils.isNotEmpty(connectConfig.getProxyAddress())) { + configureProxy(builder, connectConfig.getProxyAddress()); + } + if (connectConfig.getSecure()) { builder.useTransportSecurity(); } @@ -133,6 +153,9 @@ public ManagedChannel getChannel(ConnectConfig connectConfig){ .keepAliveWithoutCalls(connectConfig.isKeepAliveWithoutCalls()) .idleTimeout(connectConfig.getIdleTimeoutMs(), TimeUnit.MILLISECONDS) .intercept(MetadataUtils.newAttachHeadersInterceptor(metadata)); + if (StringUtils.isNotEmpty(connectConfig.getProxyAddress())) { + configureProxy(builder, connectConfig.getProxyAddress()); + } if(connectConfig.isSecure()){ builder.useTransportSecurity(); } @@ -145,6 +168,30 @@ public ManagedChannel getChannel(ConnectConfig connectConfig){ return channel; } + /** + * Configures the proxy settings for a NettyChannelBuilder if proxy address is specified + * + * @param builder NettyChannelBuilder to configure + * @param connectConfig Connection configuration containing proxy settings + */ + public static void configureProxy(ManagedChannelBuilder builder, String proxyAddress) { + String[] hostPort = proxyAddress.split(":"); + if (hostPort.length == 2) { + String proxyHost = hostPort[0]; + int proxyPort = Integer.parseInt(hostPort[1]); + + builder.proxyDetector(new ProxyDetector() { + @Override + public ProxiedSocketAddress proxyFor(SocketAddress targetServerAddress) { + return HttpConnectProxiedSocketAddress.newBuilder() + .setProxyAddress(new InetSocketAddress(proxyHost, proxyPort)) + .setTargetAddress((InetSocketAddress) targetServerAddress) + .build(); + } + }); + } + } + private static JdkSslContext convertJavaSslContextToNetty(ConnectConfig connectConfig) { ApplicationProtocolConfig applicationProtocolConfig = new ApplicationProtocolConfig(ApplicationProtocolConfig.Protocol.NONE, ApplicationProtocolConfig.SelectorFailureBehavior.FATAL_ALERT, ApplicationProtocolConfig.SelectedListenerFailureBehavior.FATAL_ALERT);