Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions sdk-core/src/main/java/io/milvus/client/MilvusServiceClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@
import io.milvus.param.partition.*;
import io.milvus.param.resourcegroup.*;
import io.milvus.param.role.*;
import io.milvus.v2.utils.ClientUtils;
import io.grpc.ProxiedSocketAddress;
import io.grpc.ProxyDetector;
import lombok.NonNull;
import org.apache.commons.lang3.StringUtils;

Expand All @@ -58,6 +61,9 @@
import java.net.InetAddress;
import java.net.UnknownHostException;
import java.time.LocalDateTime;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import io.grpc.HttpConnectProxiedSocketAddress;

public class MilvusServiceClient extends AbstractMilvusGrpcClient {

Expand Down Expand Up @@ -102,7 +108,6 @@ public void start(ClientCall.Listener<RespT> responseListener, Metadata headers)
SslContext sslContext = GrpcSslContexts.forClient()
.trustManager(new File(connectParam.getServerPemPath()))
.build();

NettyChannelBuilder builder = NettyChannelBuilder.forAddress(connectParam.getHost(), connectParam.getPort())
.overrideAuthority(connectParam.getServerName())
.sslContext(sslContext)
Expand All @@ -112,6 +117,10 @@ public void start(ClientCall.Listener<RespT> responseListener, Metadata headers)
.keepAliveWithoutCalls(connectParam.isKeepAliveWithoutCalls())
.idleTimeout(connectParam.getIdleTimeoutMs(), TimeUnit.MILLISECONDS)
.intercept(clientInterceptors);
// Add proxy configuration if proxy address is set
if (StringUtils.isNotEmpty(connectParam.getProxyAddress())) {
ClientUtils.configureProxy(builder, connectParam.getProxyAddress());
}
if(connectParam.isSecure()){
builder.useTransportSecurity();
}
Expand All @@ -124,7 +133,6 @@ public void start(ClientCall.Listener<RespT> responseListener, Metadata headers)
.trustManager(new File(connectParam.getCaPemPath()))
.keyManager(new File(connectParam.getClientPemPath()), new File(connectParam.getClientKeyPath()))
.build();

NettyChannelBuilder builder = NettyChannelBuilder.forAddress(connectParam.getHost(), connectParam.getPort())
.sslContext(sslContext)
.maxInboundMessageSize(Integer.MAX_VALUE)
Expand All @@ -133,6 +141,11 @@ public void start(ClientCall.Listener<RespT> responseListener, Metadata headers)
.keepAliveWithoutCalls(connectParam.isKeepAliveWithoutCalls())
.idleTimeout(connectParam.getIdleTimeoutMs(), TimeUnit.MILLISECONDS)
.intercept(clientInterceptors);

// Add proxy configuration if proxy address is set
if (StringUtils.isNotEmpty(connectParam.getProxyAddress())) {
ClientUtils.configureProxy(builder, connectParam.getProxyAddress());
}
if(connectParam.isSecure()){
builder.useTransportSecurity();
}
Expand All @@ -150,6 +163,9 @@ public void start(ClientCall.Listener<RespT> responseListener, Metadata headers)
.keepAliveWithoutCalls(connectParam.isKeepAliveWithoutCalls())
.idleTimeout(connectParam.getIdleTimeoutMs(), TimeUnit.MILLISECONDS)
.intercept(clientInterceptors);
if (StringUtils.isNotEmpty(connectParam.getProxyAddress())) {
ClientUtils.configureProxy(builder, connectParam.getProxyAddress());
}
if(connectParam.isSecure()){
builder.useTransportSecurity();
}
Expand Down
17 changes: 16 additions & 1 deletion sdk-core/src/main/java/io/milvus/param/ConnectParam.java
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ public class ConnectParam {
private final String serverName;
private final String userName;
private final ThreadLocal<String> clientRequestId;
private final String proxyAddress;

protected ConnectParam(@NonNull Builder builder) {
this.host = builder.host;
Expand All @@ -81,6 +82,7 @@ protected ConnectParam(@NonNull Builder builder) {
this.serverName = builder.serverName;
this.userName = builder.userName;
this.clientRequestId = builder.clientRequestId;
this.proxyAddress = builder.proxyAddress;
}

public static Builder newBuilder() {
Expand Down Expand Up @@ -120,6 +122,8 @@ public static class Builder {

//used to set client_request_id in the grpc header uniquely for every request
private ThreadLocal<String> clientRequestId;

private String proxyAddress;

protected Builder() {
}
Expand Down Expand Up @@ -359,6 +363,17 @@ public Builder withClientRequestId(@NonNull ThreadLocal<String> clientRequestId)
this.clientRequestId = clientRequestId;
return this;
}

/**
* Sets the proxy address for connections through a proxy server.
*
* @param proxyAddress proxy server address in format "host:port"
* @return <code>Builder</code>
*/
public Builder withProxyAddress(String proxyAddress) {
this.proxyAddress = proxyAddress;
return this;
}

/**
* Verifies parameters and creates a new {@link ConnectParam} instance.
Expand Down Expand Up @@ -418,4 +433,4 @@ protected void verify() throws ParamException {
}
}
}
}
}
5 changes: 5 additions & 0 deletions sdk-core/src/main/java/io/milvus/v2/client/ConnectConfig.java
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ public class ConnectConfig {
private String caPemPath;
private String serverPemPath;
private String serverName;
private String proxyAddress;
@Builder.Default
private Boolean secure = false;
@Builder.Default
Expand Down Expand Up @@ -97,4 +98,8 @@ public Boolean isSecure() {
}
return secure;
}

public String getProxyAddress(){
return proxyAddress;
}
}
49 changes: 48 additions & 1 deletion sdk-core/src/main/java/io/milvus/v2/utils/ClientUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@
import io.milvus.client.MilvusServiceClient;
import io.milvus.grpc.*;
import io.milvus.v2.client.ConnectConfig;
import io.grpc.HttpConnectProxiedSocketAddress;
import io.grpc.ProxiedSocketAddress;
import io.grpc.ProxyDetector;
import org.apache.commons.lang3.StringUtils;
import org.jetbrains.annotations.NotNull;
import org.slf4j.Logger;
Expand All @@ -46,6 +49,8 @@
import java.time.LocalDateTime;
import java.util.Base64;
import java.util.concurrent.TimeUnit;
import java.net.InetSocketAddress;
import java.net.SocketAddress;

public class ClientUtils {
Logger logger = LoggerFactory.getLogger(ClientUtils.class);
Expand Down Expand Up @@ -73,6 +78,11 @@ public ManagedChannel getChannel(ConnectConfig connectConfig){
.keepAliveWithoutCalls(connectConfig.isKeepAliveWithoutCalls())
.idleTimeout(connectConfig.getIdleTimeoutMs(), TimeUnit.MILLISECONDS)
.intercept(MetadataUtils.newAttachHeadersInterceptor(metadata));

if (StringUtils.isNotEmpty(connectConfig.getProxyAddress())) {
configureProxy(builder, connectConfig.getProxyAddress());
}

if(connectConfig.isSecure()) {
builder.useTransportSecurity();
}
Expand All @@ -95,14 +105,19 @@ public ManagedChannel getChannel(ConnectConfig connectConfig){
.keepAliveWithoutCalls(connectConfig.isKeepAliveWithoutCalls())
.idleTimeout(connectConfig.getIdleTimeoutMs(), TimeUnit.MILLISECONDS)
.intercept(MetadataUtils.newAttachHeadersInterceptor(metadata));

if (StringUtils.isNotEmpty(connectConfig.getProxyAddress())) {
configureProxy(builder, connectConfig.getProxyAddress());
}

if(connectConfig.isSecure()){
builder.useTransportSecurity();
}
channel = builder.build();
} else if (StringUtils.isNotEmpty(connectConfig.getClientPemPath())
&& StringUtils.isNotEmpty(connectConfig.getClientKeyPath())
&& StringUtils.isNotEmpty(connectConfig.getCaPemPath())) {
// tow-way tls
// two-way tls
SslContext sslContext = GrpcSslContexts.forClient()
.trustManager(new File(connectConfig.getCaPemPath()))
.keyManager(new File(connectConfig.getClientPemPath()), new File(connectConfig.getClientKeyPath()))
Expand All @@ -116,6 +131,11 @@ public ManagedChannel getChannel(ConnectConfig connectConfig){
.keepAliveWithoutCalls(connectConfig.isKeepAliveWithoutCalls())
.idleTimeout(connectConfig.getIdleTimeoutMs(), TimeUnit.MILLISECONDS)
.intercept(MetadataUtils.newAttachHeadersInterceptor(metadata));

if (StringUtils.isNotEmpty(connectConfig.getProxyAddress())) {
configureProxy(builder, connectConfig.getProxyAddress());
}

if (connectConfig.getSecure()) {
builder.useTransportSecurity();
}
Expand All @@ -133,6 +153,9 @@ public ManagedChannel getChannel(ConnectConfig connectConfig){
.keepAliveWithoutCalls(connectConfig.isKeepAliveWithoutCalls())
.idleTimeout(connectConfig.getIdleTimeoutMs(), TimeUnit.MILLISECONDS)
.intercept(MetadataUtils.newAttachHeadersInterceptor(metadata));
if (StringUtils.isNotEmpty(connectConfig.getProxyAddress())) {
configureProxy(builder, connectConfig.getProxyAddress());
}
if(connectConfig.isSecure()){
builder.useTransportSecurity();
}
Expand All @@ -145,6 +168,30 @@ public ManagedChannel getChannel(ConnectConfig connectConfig){
return channel;
}

/**
* Configures the proxy settings for a NettyChannelBuilder if proxy address is specified
*
* @param builder NettyChannelBuilder to configure
* @param connectConfig Connection configuration containing proxy settings
*/
public static void configureProxy(ManagedChannelBuilder builder, String proxyAddress) {
String[] hostPort = proxyAddress.split(":");
if (hostPort.length == 2) {
String proxyHost = hostPort[0];
int proxyPort = Integer.parseInt(hostPort[1]);

builder.proxyDetector(new ProxyDetector() {
@Override
public ProxiedSocketAddress proxyFor(SocketAddress targetServerAddress) {
return HttpConnectProxiedSocketAddress.newBuilder()
.setProxyAddress(new InetSocketAddress(proxyHost, proxyPort))
.setTargetAddress((InetSocketAddress) targetServerAddress)
.build();
}
});
}
}

private static JdkSslContext convertJavaSslContextToNetty(ConnectConfig connectConfig) {
ApplicationProtocolConfig applicationProtocolConfig = new ApplicationProtocolConfig(ApplicationProtocolConfig.Protocol.NONE,
ApplicationProtocolConfig.SelectorFailureBehavior.FATAL_ALERT, ApplicationProtocolConfig.SelectedListenerFailureBehavior.FATAL_ALERT);
Expand Down
Loading