diff --git a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/ChannelEndpoint.java b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/ChannelEndpoint.java index fc82c530fc6f..6624303407a4 100644 --- a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/ChannelEndpoint.java +++ b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/ChannelEndpoint.java @@ -71,4 +71,13 @@ public interface ChannelEndpoint { * @return the managed channel for this server */ ManagedChannel getChannel(); + + /** + * Returns the latency tracker for this endpoint, or null if not supported. + * + * @return the latency tracker or null + */ + default LatencyTracker getLatencyTracker() { + return null; + } } diff --git a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/EwmaLatencyTracker.java b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/EwmaLatencyTracker.java index 0cb2331660f9..213ed500d76e 100644 --- a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/EwmaLatencyTracker.java +++ b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/EwmaLatencyTracker.java @@ -19,6 +19,8 @@ import com.google.api.core.BetaApi; import com.google.api.core.InternalApi; import com.google.common.base.Preconditions; +import com.google.spanner.v1.PartialResultSet; +import io.grpc.MethodDescriptor; import java.time.Duration; import java.util.concurrent.TimeUnit; import javax.annotation.concurrent.GuardedBy; @@ -67,8 +69,7 @@ public double getScore() { } } - @Override - public void update(Duration latency) { + void update(Duration latency) { long latencyMicros; try { latencyMicros = TimeUnit.MICROSECONDS.convert(latency.toNanos(), TimeUnit.NANOSECONDS); @@ -92,4 +93,21 @@ public void recordError(Duration penalty) { // Treat the error as a sample with high latency (penalty) update(penalty); } + + @Override + public boolean isEligible(MethodDescriptor methodDescriptor) { + String methodName = methodDescriptor.getFullMethodName(); + return KeyAwareChannel.STREAMING_READ_METHOD.equals(methodName) + || KeyAwareChannel.STREAMING_SQL_METHOD.equals(methodName); + } + + @Override + public void maybeUpdate(Object message, Duration latency) { + if (message instanceof PartialResultSet) { + PartialResultSet response = (PartialResultSet) message; + if (response.getLast()) { + update(latency); + } + } + } } diff --git a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GrpcChannelEndpointCache.java b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GrpcChannelEndpointCache.java index 98e7f83b094f..c3a1bf0b8518 100644 --- a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GrpcChannelEndpointCache.java +++ b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GrpcChannelEndpointCache.java @@ -67,7 +67,8 @@ public GrpcChannelEndpointCache(InstantiatingGrpcChannelProvider channelProvider throws IOException { this.baseProvider = channelProvider; String defaultEndpoint = channelProvider.getEndpoint(); - this.defaultEndpoint = new GrpcChannelEndpoint(defaultEndpoint, channelProvider); + this.defaultEndpoint = + new GrpcChannelEndpoint(defaultEndpoint, channelProvider, new EwmaLatencyTracker()); this.defaultAuthority = this.defaultEndpoint.getChannel().authority(); this.servers.put(defaultEndpoint, this.defaultEndpoint); } @@ -92,7 +93,8 @@ public ChannelEndpoint get(String address) { // This is thread-safe as withEndpoint() returns a new provider instance. InstantiatingGrpcChannelProvider newProvider = createProviderWithAuthorityOverride(addr); - GrpcChannelEndpoint endpoint = new GrpcChannelEndpoint(addr, newProvider); + GrpcChannelEndpoint endpoint = + new GrpcChannelEndpoint(addr, newProvider, new EwmaLatencyTracker()); logger.log(Level.FINE, "Location-aware endpoint created for address: {0}", addr); return endpoint; } catch (IOException e) { @@ -178,10 +180,10 @@ private void shutdownChannel(GrpcChannelEndpoint server, boolean awaitTerminatio } } - /** gRPC implementation of {@link ChannelEndpoint}. */ static class GrpcChannelEndpoint implements ChannelEndpoint { private final String address; private final ManagedChannel channel; + private final LatencyTracker latencyTracker; /** * Creates a server from a channel provider. @@ -190,7 +192,8 @@ static class GrpcChannelEndpoint implements ChannelEndpoint { * @param provider the channel provider (must be a gRPC provider) * @throws IOException if the channel cannot be created */ - GrpcChannelEndpoint(String address, InstantiatingGrpcChannelProvider provider) + GrpcChannelEndpoint( + String address, InstantiatingGrpcChannelProvider provider, LatencyTracker latencyTracker) throws IOException { this.address = address; // Build a raw ManagedChannel directly instead of going through getTransportChannel(), @@ -203,6 +206,7 @@ static class GrpcChannelEndpoint implements ChannelEndpoint { provider.withHeaders(java.util.Collections.emptyMap()); } this.channel = readyProvider.createDecoratedChannelBuilder().build(); + this.latencyTracker = latencyTracker; } /** @@ -212,9 +216,10 @@ static class GrpcChannelEndpoint implements ChannelEndpoint { * @param channel the managed channel */ @VisibleForTesting - GrpcChannelEndpoint(String address, ManagedChannel channel) { + GrpcChannelEndpoint(String address, ManagedChannel channel, LatencyTracker latencyTracker) { this.address = address; this.channel = channel; + this.latencyTracker = latencyTracker; } @Override @@ -267,5 +272,10 @@ public boolean isTransientFailure() { public ManagedChannel getChannel() { return channel; } + + @Override + public LatencyTracker getLatencyTracker() { + return latencyTracker; + } } } diff --git a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyAwareChannel.java b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyAwareChannel.java index d7b32f72bcd6..7201bf477a28 100644 --- a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyAwareChannel.java +++ b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyAwareChannel.java @@ -46,6 +46,7 @@ import java.io.IOException; import java.lang.ref.ReferenceQueue; import java.lang.ref.SoftReference; +import java.time.Duration; import java.util.HashSet; import java.util.Map; import java.util.Set; @@ -72,9 +73,8 @@ final class KeyAwareChannel extends ManagedChannel { private static final long MAX_TRACKED_READ_ONLY_TRANSACTIONS = 100_000L; private static final long MAX_TRACKED_EXCLUDED_LOGICAL_REQUESTS = 100_000L; private static final long EXCLUDED_LOGICAL_REQUEST_TTL_MINUTES = 10L; - private static final String STREAMING_READ_METHOD = "google.spanner.v1.Spanner/StreamingRead"; - private static final String STREAMING_SQL_METHOD = - "google.spanner.v1.Spanner/ExecuteStreamingSql"; + static final String STREAMING_READ_METHOD = "google.spanner.v1.Spanner/StreamingRead"; + static final String STREAMING_SQL_METHOD = "google.spanner.v1.Spanner/ExecuteStreamingSql"; private static final String UNARY_SQL_METHOD = "google.spanner.v1.Spanner/ExecuteSql"; private static final String BEGIN_TRANSACTION_METHOD = "google.spanner.v1.Spanner/BeginTransaction"; @@ -462,6 +462,7 @@ static final class KeyAwareClientCall private boolean isReadOnlyBegin; private boolean readOnlyIsStrong; private final Object lock = new Object(); + volatile long startTimeNanos; KeyAwareClientCall( KeyAwareChannel parentChannel, @@ -610,6 +611,7 @@ public void sendMessage(RequestT message) { } delegate.start(responseListener, headers); drainPendingRequests(); + startTimeNanos = System.nanoTime(); delegate.sendMessage(message); if (pendingHalfClose) { delegate.halfClose(); @@ -810,6 +812,7 @@ private RoutingDecision(@Nullable ChannelFinder finder, @Nullable ChannelEndpoin static final class KeyAwareClientCallListener extends SimpleForwardingClientCallListener { private final KeyAwareClientCall call; + private boolean firstMessageReceived = false; KeyAwareClientCallListener( ClientCall.Listener responseListener, KeyAwareClientCall call) { @@ -819,6 +822,18 @@ static final class KeyAwareClientCallListener @Override public void onMessage(ResponseT message) { + if (!firstMessageReceived) { + firstMessageReceived = true; + // call.selectedEndpoint will in real usage never be null when we reach this + // point. + if (call.selectedEndpoint != null) { + LatencyTracker tracker = call.selectedEndpoint.getLatencyTracker(); + if (tracker != null && tracker.isEligible(call.methodDescriptor)) { + Duration latency = Duration.ofNanos(System.nanoTime() - call.startTimeNanos); + tracker.maybeUpdate(message, latency); + } + } + } ByteString transactionId = null; if (message instanceof PartialResultSet) { PartialResultSet response = (PartialResultSet) message; diff --git a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/LatencyTracker.java b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/LatencyTracker.java index d7467853492d..c70bcc144eec 100644 --- a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/LatencyTracker.java +++ b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/LatencyTracker.java @@ -18,6 +18,7 @@ import com.google.api.core.BetaApi; import com.google.api.core.InternalApi; +import io.grpc.MethodDescriptor; import java.time.Duration; /** @@ -38,11 +39,12 @@ public interface LatencyTracker { double getScore(); /** - * Updates the latency score with a new observation. + * Potentially updates the latency score based on the response message. * - * @param latency the observed latency. + * @param message the response message. + * @param latency the measured latency. */ - void update(Duration latency); + void maybeUpdate(Object message, Duration latency); /** * Records an error and applies a latency penalty. @@ -50,4 +52,12 @@ public interface LatencyTracker { * @param penalty the penalty to apply. */ void recordError(Duration penalty); + + /** + * Returns whether a call with the given method descriptor is eligible for latency measurement. + * + * @param methodDescriptor the method descriptor of the call. + * @return true if eligible, false otherwise. + */ + boolean isEligible(MethodDescriptor methodDescriptor); } diff --git a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/KeyAwareChannelTest.java b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/KeyAwareChannelTest.java index 1ad3888b4f9d..cad6b70b5673 100644 --- a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/KeyAwareChannelTest.java +++ b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/KeyAwareChannelTest.java @@ -286,6 +286,93 @@ public void resultSetCacheUpdateRoutesSubsequentRequest() throws Exception { assertThat(harness.endpointCache.callCountForAddress("routed:1234")).isEqualTo(1); } + @Test + public void callTracksLatencyOnMessage() throws Exception { + TestHarness harness = createHarness(); + ExecuteSqlRequest request = ExecuteSqlRequest.newBuilder().setSession(SESSION).build(); + + ClientCall call = + harness.channel.newCall(SpannerGrpc.getExecuteStreamingSqlMethod(), CallOptions.DEFAULT); + CapturingListener listener = new CapturingListener<>(); + call.start(listener, new Metadata()); + call.sendMessage(request); + + @SuppressWarnings("unchecked") + RecordingClientCall delegate = + (RecordingClientCall) + harness.defaultManagedChannel.latestCall(); + + FakeEndpoint defaultEndpoint = harness.endpointCache.defaultEndpoint; + LatencyTracker tracker = defaultEndpoint.getLatencyTracker(); + + double initialScore = tracker.getScore(); + + // Emit a message with last=true to trigger onMessage and latency update. + delegate.emitOnMessage(PartialResultSet.newBuilder().setLast(true).build()); + + // Verify that the score has been updated (it should not be equal to the initial score). + double newScore = tracker.getScore(); + assertThat(newScore).isNotEqualTo(initialScore); + } + + @Test + public void callDoesNotTrackLatencyForNonEligibleRpc() throws Exception { + TestHarness harness = createHarness(); + ExecuteSqlRequest request = ExecuteSqlRequest.newBuilder().setSession(SESSION).build(); + + ClientCall call = + harness.channel.newCall(SpannerGrpc.getExecuteSqlMethod(), CallOptions.DEFAULT); + CapturingListener listener = new CapturingListener<>(); + call.start(listener, new Metadata()); + call.sendMessage(request); + + @SuppressWarnings("unchecked") + RecordingClientCall delegate = + (RecordingClientCall) + harness.defaultManagedChannel.latestCall(); + + FakeEndpoint defaultEndpoint = harness.endpointCache.defaultEndpoint; + LatencyTracker tracker = defaultEndpoint.getLatencyTracker(); + + double initialScore = tracker.getScore(); + + // Emit a message. + delegate.emitOnMessage(ResultSet.newBuilder().build()); + + // Verify that the score has not been updated. + double newScore = tracker.getScore(); + assertThat(newScore).isEqualTo(initialScore); + } + + @Test + public void callDoesNotTrackLatencyForNonLastPartialResultSet() throws Exception { + TestHarness harness = createHarness(); + ExecuteSqlRequest request = ExecuteSqlRequest.newBuilder().setSession(SESSION).build(); + + ClientCall call = + harness.channel.newCall(SpannerGrpc.getExecuteStreamingSqlMethod(), CallOptions.DEFAULT); + CapturingListener listener = new CapturingListener<>(); + call.start(listener, new Metadata()); + call.sendMessage(request); + + @SuppressWarnings("unchecked") + RecordingClientCall delegate = + (RecordingClientCall) + harness.defaultManagedChannel.latestCall(); + + FakeEndpoint defaultEndpoint = harness.endpointCache.defaultEndpoint; + LatencyTracker tracker = defaultEndpoint.getLatencyTracker(); + + double initialScore = tracker.getScore(); + + // Emit a message with last=false. + delegate.emitOnMessage(PartialResultSet.newBuilder().setLast(false).build()); + + // Verify that the score has not been updated. + double newScore = tracker.getScore(); + assertThat(newScore).isEqualTo(initialScore); + } + @Test public void beginTransactionWithMutationKeyAddsRoutingHint() throws Exception { TestHarness harness = createHarness(); @@ -1350,12 +1437,18 @@ int callCountForAddress(String address) { private static final class FakeEndpoint implements ChannelEndpoint { private final String address; private final FakeManagedChannel channel; + private final LatencyTracker latencyTracker = new EwmaLatencyTracker(); private FakeEndpoint(String address) { this.address = address; this.channel = new FakeManagedChannel(address); } + @Override + public LatencyTracker getLatencyTracker() { + return latencyTracker; + } + @Override public String getAddress() { return address;