Skip to content

Commit c1060a7

Browse files
yhmoJeri-josedivyaruhil
authored
Adding in Proxy setting for connection to milvus. (#1350) (#1358)
* Adding in Proxy setting for connection to milvus. * proxy-setting configuration into reusable method --------- Signed-off-by: jeri-jose <jerijose111@gmail.com> Co-authored-by: Jeri Jose <72429659+Jeri-jose@users.noreply.github.com> Co-authored-by: divyaruhil <divyaruhil999@gmail.com>
1 parent 28ca9ba commit c1060a7

4 files changed

Lines changed: 87 additions & 4 deletions

File tree

sdk-core/src/main/java/io/milvus/client/MilvusServiceClient.java

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,9 @@
4646
import io.milvus.param.partition.*;
4747
import io.milvus.param.resourcegroup.*;
4848
import io.milvus.param.role.*;
49+
import io.milvus.v2.utils.ClientUtils;
50+
import io.grpc.ProxiedSocketAddress;
51+
import io.grpc.ProxyDetector;
4952
import lombok.NonNull;
5053
import org.apache.commons.lang3.StringUtils;
5154

@@ -58,6 +61,9 @@
5861
import java.net.InetAddress;
5962
import java.net.UnknownHostException;
6063
import java.time.LocalDateTime;
64+
import java.net.InetSocketAddress;
65+
import java.net.SocketAddress;
66+
import io.grpc.HttpConnectProxiedSocketAddress;
6167

6268
public class MilvusServiceClient extends AbstractMilvusGrpcClient {
6369

@@ -102,7 +108,6 @@ public void start(ClientCall.Listener<RespT> responseListener, Metadata headers)
102108
SslContext sslContext = GrpcSslContexts.forClient()
103109
.trustManager(new File(connectParam.getServerPemPath()))
104110
.build();
105-
106111
NettyChannelBuilder builder = NettyChannelBuilder.forAddress(connectParam.getHost(), connectParam.getPort())
107112
.overrideAuthority(connectParam.getServerName())
108113
.sslContext(sslContext)
@@ -112,6 +117,10 @@ public void start(ClientCall.Listener<RespT> responseListener, Metadata headers)
112117
.keepAliveWithoutCalls(connectParam.isKeepAliveWithoutCalls())
113118
.idleTimeout(connectParam.getIdleTimeoutMs(), TimeUnit.MILLISECONDS)
114119
.intercept(clientInterceptors);
120+
// Add proxy configuration if proxy address is set
121+
if (StringUtils.isNotEmpty(connectParam.getProxyAddress())) {
122+
ClientUtils.configureProxy(builder, connectParam.getProxyAddress());
123+
}
115124
if(connectParam.isSecure()){
116125
builder.useTransportSecurity();
117126
}
@@ -124,7 +133,6 @@ public void start(ClientCall.Listener<RespT> responseListener, Metadata headers)
124133
.trustManager(new File(connectParam.getCaPemPath()))
125134
.keyManager(new File(connectParam.getClientPemPath()), new File(connectParam.getClientKeyPath()))
126135
.build();
127-
128136
NettyChannelBuilder builder = NettyChannelBuilder.forAddress(connectParam.getHost(), connectParam.getPort())
129137
.sslContext(sslContext)
130138
.maxInboundMessageSize(Integer.MAX_VALUE)
@@ -133,6 +141,11 @@ public void start(ClientCall.Listener<RespT> responseListener, Metadata headers)
133141
.keepAliveWithoutCalls(connectParam.isKeepAliveWithoutCalls())
134142
.idleTimeout(connectParam.getIdleTimeoutMs(), TimeUnit.MILLISECONDS)
135143
.intercept(clientInterceptors);
144+
145+
// Add proxy configuration if proxy address is set
146+
if (StringUtils.isNotEmpty(connectParam.getProxyAddress())) {
147+
ClientUtils.configureProxy(builder, connectParam.getProxyAddress());
148+
}
136149
if(connectParam.isSecure()){
137150
builder.useTransportSecurity();
138151
}
@@ -150,6 +163,9 @@ public void start(ClientCall.Listener<RespT> responseListener, Metadata headers)
150163
.keepAliveWithoutCalls(connectParam.isKeepAliveWithoutCalls())
151164
.idleTimeout(connectParam.getIdleTimeoutMs(), TimeUnit.MILLISECONDS)
152165
.intercept(clientInterceptors);
166+
if (StringUtils.isNotEmpty(connectParam.getProxyAddress())) {
167+
ClientUtils.configureProxy(builder, connectParam.getProxyAddress());
168+
}
153169
if(connectParam.isSecure()){
154170
builder.useTransportSecurity();
155171
}

sdk-core/src/main/java/io/milvus/param/ConnectParam.java

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ public class ConnectParam {
5959
private final String serverName;
6060
private final String userName;
6161
private final ThreadLocal<String> clientRequestId;
62+
private final String proxyAddress;
6263

6364
protected ConnectParam(@NonNull Builder builder) {
6465
this.host = builder.host;
@@ -81,6 +82,7 @@ protected ConnectParam(@NonNull Builder builder) {
8182
this.serverName = builder.serverName;
8283
this.userName = builder.userName;
8384
this.clientRequestId = builder.clientRequestId;
85+
this.proxyAddress = builder.proxyAddress;
8486
}
8587

8688
public static Builder newBuilder() {
@@ -120,6 +122,8 @@ public static class Builder {
120122

121123
//used to set client_request_id in the grpc header uniquely for every request
122124
private ThreadLocal<String> clientRequestId;
125+
126+
private String proxyAddress;
123127

124128
protected Builder() {
125129
}
@@ -359,6 +363,17 @@ public Builder withClientRequestId(@NonNull ThreadLocal<String> clientRequestId)
359363
this.clientRequestId = clientRequestId;
360364
return this;
361365
}
366+
367+
/**
368+
* Sets the proxy address for connections through a proxy server.
369+
*
370+
* @param proxyAddress proxy server address in format "host:port"
371+
* @return <code>Builder</code>
372+
*/
373+
public Builder withProxyAddress(String proxyAddress) {
374+
this.proxyAddress = proxyAddress;
375+
return this;
376+
}
362377

363378
/**
364379
* Verifies parameters and creates a new {@link ConnectParam} instance.
@@ -418,4 +433,4 @@ protected void verify() throws ParamException {
418433
}
419434
}
420435
}
421-
}
436+
}

sdk-core/src/main/java/io/milvus/v2/client/ConnectConfig.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ public class ConnectConfig {
5656
private String caPemPath;
5757
private String serverPemPath;
5858
private String serverName;
59+
private String proxyAddress;
5960
@Builder.Default
6061
private Boolean secure = false;
6162
@Builder.Default
@@ -97,4 +98,8 @@ public Boolean isSecure() {
9798
}
9899
return secure;
99100
}
101+
102+
public String getProxyAddress(){
103+
return proxyAddress;
104+
}
100105
}

sdk-core/src/main/java/io/milvus/v2/utils/ClientUtils.java

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@
3333
import io.milvus.client.MilvusServiceClient;
3434
import io.milvus.grpc.*;
3535
import io.milvus.v2.client.ConnectConfig;
36+
import io.grpc.HttpConnectProxiedSocketAddress;
37+
import io.grpc.ProxiedSocketAddress;
38+
import io.grpc.ProxyDetector;
3639
import org.apache.commons.lang3.StringUtils;
3740
import org.jetbrains.annotations.NotNull;
3841
import org.slf4j.Logger;
@@ -46,6 +49,8 @@
4649
import java.time.LocalDateTime;
4750
import java.util.Base64;
4851
import java.util.concurrent.TimeUnit;
52+
import java.net.InetSocketAddress;
53+
import java.net.SocketAddress;
4954

5055
public class ClientUtils {
5156
Logger logger = LoggerFactory.getLogger(ClientUtils.class);
@@ -73,6 +78,11 @@ public ManagedChannel getChannel(ConnectConfig connectConfig){
7378
.keepAliveWithoutCalls(connectConfig.isKeepAliveWithoutCalls())
7479
.idleTimeout(connectConfig.getIdleTimeoutMs(), TimeUnit.MILLISECONDS)
7580
.intercept(MetadataUtils.newAttachHeadersInterceptor(metadata));
81+
82+
if (StringUtils.isNotEmpty(connectConfig.getProxyAddress())) {
83+
configureProxy(builder, connectConfig.getProxyAddress());
84+
}
85+
7686
if(connectConfig.isSecure()) {
7787
builder.useTransportSecurity();
7888
}
@@ -95,14 +105,19 @@ public ManagedChannel getChannel(ConnectConfig connectConfig){
95105
.keepAliveWithoutCalls(connectConfig.isKeepAliveWithoutCalls())
96106
.idleTimeout(connectConfig.getIdleTimeoutMs(), TimeUnit.MILLISECONDS)
97107
.intercept(MetadataUtils.newAttachHeadersInterceptor(metadata));
108+
109+
if (StringUtils.isNotEmpty(connectConfig.getProxyAddress())) {
110+
configureProxy(builder, connectConfig.getProxyAddress());
111+
}
112+
98113
if(connectConfig.isSecure()){
99114
builder.useTransportSecurity();
100115
}
101116
channel = builder.build();
102117
} else if (StringUtils.isNotEmpty(connectConfig.getClientPemPath())
103118
&& StringUtils.isNotEmpty(connectConfig.getClientKeyPath())
104119
&& StringUtils.isNotEmpty(connectConfig.getCaPemPath())) {
105-
// tow-way tls
120+
// two-way tls
106121
SslContext sslContext = GrpcSslContexts.forClient()
107122
.trustManager(new File(connectConfig.getCaPemPath()))
108123
.keyManager(new File(connectConfig.getClientPemPath()), new File(connectConfig.getClientKeyPath()))
@@ -116,6 +131,11 @@ public ManagedChannel getChannel(ConnectConfig connectConfig){
116131
.keepAliveWithoutCalls(connectConfig.isKeepAliveWithoutCalls())
117132
.idleTimeout(connectConfig.getIdleTimeoutMs(), TimeUnit.MILLISECONDS)
118133
.intercept(MetadataUtils.newAttachHeadersInterceptor(metadata));
134+
135+
if (StringUtils.isNotEmpty(connectConfig.getProxyAddress())) {
136+
configureProxy(builder, connectConfig.getProxyAddress());
137+
}
138+
119139
if (connectConfig.getSecure()) {
120140
builder.useTransportSecurity();
121141
}
@@ -133,6 +153,9 @@ public ManagedChannel getChannel(ConnectConfig connectConfig){
133153
.keepAliveWithoutCalls(connectConfig.isKeepAliveWithoutCalls())
134154
.idleTimeout(connectConfig.getIdleTimeoutMs(), TimeUnit.MILLISECONDS)
135155
.intercept(MetadataUtils.newAttachHeadersInterceptor(metadata));
156+
if (StringUtils.isNotEmpty(connectConfig.getProxyAddress())) {
157+
configureProxy(builder, connectConfig.getProxyAddress());
158+
}
136159
if(connectConfig.isSecure()){
137160
builder.useTransportSecurity();
138161
}
@@ -145,6 +168,30 @@ public ManagedChannel getChannel(ConnectConfig connectConfig){
145168
return channel;
146169
}
147170

171+
/**
172+
* Configures the proxy settings for a NettyChannelBuilder if proxy address is specified
173+
*
174+
* @param builder NettyChannelBuilder to configure
175+
* @param connectConfig Connection configuration containing proxy settings
176+
*/
177+
public static void configureProxy(ManagedChannelBuilder builder, String proxyAddress) {
178+
String[] hostPort = proxyAddress.split(":");
179+
if (hostPort.length == 2) {
180+
String proxyHost = hostPort[0];
181+
int proxyPort = Integer.parseInt(hostPort[1]);
182+
183+
builder.proxyDetector(new ProxyDetector() {
184+
@Override
185+
public ProxiedSocketAddress proxyFor(SocketAddress targetServerAddress) {
186+
return HttpConnectProxiedSocketAddress.newBuilder()
187+
.setProxyAddress(new InetSocketAddress(proxyHost, proxyPort))
188+
.setTargetAddress((InetSocketAddress) targetServerAddress)
189+
.build();
190+
}
191+
});
192+
}
193+
}
194+
148195
private static JdkSslContext convertJavaSslContextToNetty(ConnectConfig connectConfig) {
149196
ApplicationProtocolConfig applicationProtocolConfig = new ApplicationProtocolConfig(ApplicationProtocolConfig.Protocol.NONE,
150197
ApplicationProtocolConfig.SelectorFailureBehavior.FATAL_ALERT, ApplicationProtocolConfig.SelectedListenerFailureBehavior.FATAL_ALERT);

0 commit comments

Comments
 (0)