diff --git a/sdk-core/src/main/java/io/milvus/v2/client/ConnectConfig.java b/sdk-core/src/main/java/io/milvus/v2/client/ConnectConfig.java index 3178d7157..4fd3bb985 100644 --- a/sdk-core/src/main/java/io/milvus/v2/client/ConnectConfig.java +++ b/sdk-core/src/main/java/io/milvus/v2/client/ConnectConfig.java @@ -61,6 +61,9 @@ public class ConnectConfig { private long idleTimeoutMs = TimeUnit.MILLISECONDS.convert(24, TimeUnit.HOURS); private SSLContext sslContext; + // clientRequestId maintains a map for different threads, each thread can assign a specific id. + // the specific id is passed to the server, from the access log we can know which client calls the interface + private ThreadLocal clientRequestId; public String getHost() { io.milvus.utils.URLParser urlParser = new io.milvus.utils.URLParser(this.uri); diff --git a/sdk-core/src/main/java/io/milvus/v2/utils/ClientUtils.java b/sdk-core/src/main/java/io/milvus/v2/utils/ClientUtils.java index 65c7f6fed..9aee732f6 100644 --- a/sdk-core/src/main/java/io/milvus/v2/utils/ClientUtils.java +++ b/sdk-core/src/main/java/io/milvus/v2/utils/ClientUtils.java @@ -19,9 +19,7 @@ package io.milvus.v2.utils; -import io.grpc.ManagedChannel; -import io.grpc.ManagedChannelBuilder; -import io.grpc.Metadata; +import io.grpc.*; import io.grpc.netty.shaded.io.grpc.netty.GrpcSslContexts; import io.grpc.netty.shaded.io.grpc.netty.NettyChannelBuilder; import io.grpc.netty.shaded.io.netty.handler.ssl.ApplicationProtocolConfig; @@ -33,11 +31,7 @@ import io.milvus.client.MilvusServiceClient; import io.milvus.grpc.*; import io.milvus.v2.client.ConnectConfig; -import io.grpc.HttpConnectProxiedSocketAddress; -import io.grpc.ProxiedSocketAddress; -import io.grpc.ProxyDetector; import org.apache.commons.lang3.StringUtils; -import org.jetbrains.annotations.NotNull; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -47,7 +41,9 @@ import java.net.UnknownHostException; import java.nio.charset.StandardCharsets; import java.time.LocalDateTime; +import java.util.ArrayList; import java.util.Base64; +import java.util.List; import java.util.concurrent.TimeUnit; import java.net.InetSocketAddress; import java.net.SocketAddress; @@ -66,6 +62,30 @@ public ManagedChannel getChannel(ConnectConfig connectConfig){ metadata.put(Metadata.Key.of("dbname", Metadata.ASCII_STRING_MARSHALLER), connectConfig.getDbName()); } + List clientInterceptors = new ArrayList<>(); + clientInterceptors.add(MetadataUtils.newAttachHeadersInterceptor(metadata)); + //client interceptor used to fetch client_request_id from threadlocal variable and set it for every grpc request + clientInterceptors.add(new ClientInterceptor() { + @Override + public ClientCall interceptCall(MethodDescriptor method, CallOptions callOptions, Channel next) { + return new ForwardingClientCall + .SimpleForwardingClientCall(next.newCall(method, callOptions)) { + @Override + public void start(ClientCall.Listener responseListener, Metadata headers) { + String currentMs = String.valueOf(System.currentTimeMillis()); + headers.put(Metadata.Key.of("client-request-unixmsec", Metadata.ASCII_STRING_MARSHALLER), currentMs); + if(connectConfig.getClientRequestId() != null) { + String clientID = connectConfig.getClientRequestId().get(); + if (!StringUtils.isEmpty(clientID)) { + headers.put(Metadata.Key.of("client_request_id", Metadata.ASCII_STRING_MARSHALLER), clientID); + } + } + super.start(responseListener, headers); + } + }; + } + }); + try { if (connectConfig.getSslContext() != null) { // sslContext from connect config @@ -77,7 +97,7 @@ public ManagedChannel getChannel(ConnectConfig connectConfig){ .keepAliveTimeout(connectConfig.getKeepAliveTimeoutMs(), TimeUnit.MILLISECONDS) .keepAliveWithoutCalls(connectConfig.isKeepAliveWithoutCalls()) .idleTimeout(connectConfig.getIdleTimeoutMs(), TimeUnit.MILLISECONDS) - .intercept(MetadataUtils.newAttachHeadersInterceptor(metadata)); + .intercept(clientInterceptors); if (StringUtils.isNotEmpty(connectConfig.getProxyAddress())) { configureProxy(builder, connectConfig.getProxyAddress()); @@ -104,7 +124,7 @@ public ManagedChannel getChannel(ConnectConfig connectConfig){ .keepAliveTimeout(connectConfig.getKeepAliveTimeoutMs(), TimeUnit.MILLISECONDS) .keepAliveWithoutCalls(connectConfig.isKeepAliveWithoutCalls()) .idleTimeout(connectConfig.getIdleTimeoutMs(), TimeUnit.MILLISECONDS) - .intercept(MetadataUtils.newAttachHeadersInterceptor(metadata)); + .intercept(clientInterceptors); if (StringUtils.isNotEmpty(connectConfig.getProxyAddress())) { configureProxy(builder, connectConfig.getProxyAddress()); @@ -130,7 +150,7 @@ public ManagedChannel getChannel(ConnectConfig connectConfig){ .keepAliveTimeout(connectConfig.getKeepAliveTimeoutMs(), TimeUnit.MILLISECONDS) .keepAliveWithoutCalls(connectConfig.isKeepAliveWithoutCalls()) .idleTimeout(connectConfig.getIdleTimeoutMs(), TimeUnit.MILLISECONDS) - .intercept(MetadataUtils.newAttachHeadersInterceptor(metadata)); + .intercept(clientInterceptors); if (StringUtils.isNotEmpty(connectConfig.getProxyAddress())) { configureProxy(builder, connectConfig.getProxyAddress()); @@ -152,7 +172,7 @@ public ManagedChannel getChannel(ConnectConfig connectConfig){ .keepAliveTimeout(connectConfig.getKeepAliveTimeoutMs(), TimeUnit.MILLISECONDS) .keepAliveWithoutCalls(connectConfig.isKeepAliveWithoutCalls()) .idleTimeout(connectConfig.getIdleTimeoutMs(), TimeUnit.MILLISECONDS) - .intercept(MetadataUtils.newAttachHeadersInterceptor(metadata)); + .intercept(clientInterceptors); if (StringUtils.isNotEmpty(connectConfig.getProxyAddress())) { configureProxy(builder, connectConfig.getProxyAddress()); }