Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -71,4 +71,13 @@ public interface ChannelEndpoint {
* @return the managed channel for this server
*/
ManagedChannel getChannel();

/**
* Returns the latency tracker for this endpoint, or null if not supported.
*
* @return the latency tracker or null
*/
default LatencyTracker getLatencyTracker() {
return null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import com.google.api.core.BetaApi;
import com.google.api.core.InternalApi;
import com.google.common.base.Preconditions;
import com.google.spanner.v1.PartialResultSet;
import io.grpc.MethodDescriptor;
import java.time.Duration;
import java.util.concurrent.TimeUnit;
import javax.annotation.concurrent.GuardedBy;
Expand Down Expand Up @@ -67,8 +69,7 @@ public double getScore() {
}
}

@Override
public void update(Duration latency) {
void update(Duration latency) {
long latencyMicros;
try {
latencyMicros = TimeUnit.MICROSECONDS.convert(latency.toNanos(), TimeUnit.NANOSECONDS);
Expand All @@ -92,4 +93,21 @@ public void recordError(Duration penalty) {
// Treat the error as a sample with high latency (penalty)
update(penalty);
}

@Override
public boolean isEligible(MethodDescriptor<?, ?> methodDescriptor) {
String methodName = methodDescriptor.getFullMethodName();
return KeyAwareChannel.STREAMING_READ_METHOD.equals(methodName)
|| KeyAwareChannel.STREAMING_SQL_METHOD.equals(methodName);
}

@Override
public void maybeUpdate(Object message, Duration latency) {
if (message instanceof PartialResultSet) {
PartialResultSet response = (PartialResultSet) message;
if (response.getLast()) {
update(latency);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,8 @@ public GrpcChannelEndpointCache(InstantiatingGrpcChannelProvider channelProvider
throws IOException {
this.baseProvider = channelProvider;
String defaultEndpoint = channelProvider.getEndpoint();
this.defaultEndpoint = new GrpcChannelEndpoint(defaultEndpoint, channelProvider);
this.defaultEndpoint =
new GrpcChannelEndpoint(defaultEndpoint, channelProvider, new EwmaLatencyTracker());
this.defaultAuthority = this.defaultEndpoint.getChannel().authority();
this.servers.put(defaultEndpoint, this.defaultEndpoint);
}
Expand All @@ -92,7 +93,8 @@ public ChannelEndpoint get(String address) {
// This is thread-safe as withEndpoint() returns a new provider instance.
InstantiatingGrpcChannelProvider newProvider =
createProviderWithAuthorityOverride(addr);
GrpcChannelEndpoint endpoint = new GrpcChannelEndpoint(addr, newProvider);
GrpcChannelEndpoint endpoint =
new GrpcChannelEndpoint(addr, newProvider, new EwmaLatencyTracker());
logger.log(Level.FINE, "Location-aware endpoint created for address: {0}", addr);
return endpoint;
} catch (IOException e) {
Expand Down Expand Up @@ -178,10 +180,10 @@ private void shutdownChannel(GrpcChannelEndpoint server, boolean awaitTerminatio
}
}

/** gRPC implementation of {@link ChannelEndpoint}. */
static class GrpcChannelEndpoint implements ChannelEndpoint {
private final String address;
private final ManagedChannel channel;
private final LatencyTracker latencyTracker;

/**
* Creates a server from a channel provider.
Expand All @@ -190,7 +192,8 @@ static class GrpcChannelEndpoint implements ChannelEndpoint {
* @param provider the channel provider (must be a gRPC provider)
* @throws IOException if the channel cannot be created
*/
GrpcChannelEndpoint(String address, InstantiatingGrpcChannelProvider provider)
GrpcChannelEndpoint(
String address, InstantiatingGrpcChannelProvider provider, LatencyTracker latencyTracker)
throws IOException {
this.address = address;
// Build a raw ManagedChannel directly instead of going through getTransportChannel(),
Expand All @@ -203,6 +206,7 @@ static class GrpcChannelEndpoint implements ChannelEndpoint {
provider.withHeaders(java.util.Collections.emptyMap());
}
this.channel = readyProvider.createDecoratedChannelBuilder().build();
this.latencyTracker = latencyTracker;
}

/**
Expand All @@ -212,9 +216,10 @@ static class GrpcChannelEndpoint implements ChannelEndpoint {
* @param channel the managed channel
*/
@VisibleForTesting
GrpcChannelEndpoint(String address, ManagedChannel channel) {
GrpcChannelEndpoint(String address, ManagedChannel channel, LatencyTracker latencyTracker) {
this.address = address;
this.channel = channel;
this.latencyTracker = latencyTracker;
}

@Override
Expand Down Expand Up @@ -267,5 +272,10 @@ public boolean isTransientFailure() {
public ManagedChannel getChannel() {
return channel;
}

@Override
public LatencyTracker getLatencyTracker() {
return latencyTracker;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import java.io.IOException;
import java.lang.ref.ReferenceQueue;
import java.lang.ref.SoftReference;
import java.time.Duration;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
Expand All @@ -72,9 +73,8 @@ final class KeyAwareChannel extends ManagedChannel {
private static final long MAX_TRACKED_READ_ONLY_TRANSACTIONS = 100_000L;
private static final long MAX_TRACKED_EXCLUDED_LOGICAL_REQUESTS = 100_000L;
private static final long EXCLUDED_LOGICAL_REQUEST_TTL_MINUTES = 10L;
private static final String STREAMING_READ_METHOD = "google.spanner.v1.Spanner/StreamingRead";
private static final String STREAMING_SQL_METHOD =
"google.spanner.v1.Spanner/ExecuteStreamingSql";
static final String STREAMING_READ_METHOD = "google.spanner.v1.Spanner/StreamingRead";
static final String STREAMING_SQL_METHOD = "google.spanner.v1.Spanner/ExecuteStreamingSql";
private static final String UNARY_SQL_METHOD = "google.spanner.v1.Spanner/ExecuteSql";
private static final String BEGIN_TRANSACTION_METHOD =
"google.spanner.v1.Spanner/BeginTransaction";
Expand Down Expand Up @@ -462,6 +462,7 @@ static final class KeyAwareClientCall<RequestT, ResponseT>
private boolean isReadOnlyBegin;
private boolean readOnlyIsStrong;
private final Object lock = new Object();
volatile long startTimeNanos;

KeyAwareClientCall(
KeyAwareChannel parentChannel,
Expand Down Expand Up @@ -610,6 +611,7 @@ public void sendMessage(RequestT message) {
}
delegate.start(responseListener, headers);
drainPendingRequests();
startTimeNanos = System.nanoTime();
delegate.sendMessage(message);
if (pendingHalfClose) {
delegate.halfClose();
Expand Down Expand Up @@ -810,6 +812,7 @@ private RoutingDecision(@Nullable ChannelFinder finder, @Nullable ChannelEndpoin
static final class KeyAwareClientCallListener<ResponseT>
extends SimpleForwardingClientCallListener<ResponseT> {
private final KeyAwareClientCall<?, ResponseT> call;
private boolean firstMessageReceived = false;

KeyAwareClientCallListener(
ClientCall.Listener<ResponseT> responseListener, KeyAwareClientCall<?, ResponseT> call) {
Expand All @@ -819,6 +822,18 @@ static final class KeyAwareClientCallListener<ResponseT>

@Override
public void onMessage(ResponseT message) {
if (!firstMessageReceived) {
firstMessageReceived = true;
// call.selectedEndpoint will in real usage never be null when we reach this
// point.
if (call.selectedEndpoint != null) {
LatencyTracker tracker = call.selectedEndpoint.getLatencyTracker();
if (tracker != null && tracker.isEligible(call.methodDescriptor)) {
Duration latency = Duration.ofNanos(System.nanoTime() - call.startTimeNanos);
tracker.maybeUpdate(message, latency);
}
}
}
ByteString transactionId = null;
if (message instanceof PartialResultSet) {
PartialResultSet response = (PartialResultSet) message;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import com.google.api.core.BetaApi;
import com.google.api.core.InternalApi;
import io.grpc.MethodDescriptor;
import java.time.Duration;

/**
Expand All @@ -38,16 +39,25 @@ public interface LatencyTracker {
double getScore();

/**
* Updates the latency score with a new observation.
* Potentially updates the latency score based on the response message.
*
* @param latency the observed latency.
* @param message the response message.
* @param latency the measured latency.
*/
void update(Duration latency);
void maybeUpdate(Object message, Duration latency);

/**
* Records an error and applies a latency penalty.
*
* @param penalty the penalty to apply.
*/
void recordError(Duration penalty);

/**
* Returns whether a call with the given method descriptor is eligible for latency measurement.
*
* @param methodDescriptor the method descriptor of the call.
* @return true if eligible, false otherwise.
*/
boolean isEligible(MethodDescriptor<?, ?> methodDescriptor);
}
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,93 @@ public void resultSetCacheUpdateRoutesSubsequentRequest() throws Exception {
assertThat(harness.endpointCache.callCountForAddress("routed:1234")).isEqualTo(1);
}

@Test
public void callTracksLatencyOnMessage() throws Exception {
TestHarness harness = createHarness();
ExecuteSqlRequest request = ExecuteSqlRequest.newBuilder().setSession(SESSION).build();

ClientCall<ExecuteSqlRequest, PartialResultSet> call =
harness.channel.newCall(SpannerGrpc.getExecuteStreamingSqlMethod(), CallOptions.DEFAULT);
CapturingListener<PartialResultSet> listener = new CapturingListener<>();
call.start(listener, new Metadata());
call.sendMessage(request);

@SuppressWarnings("unchecked")
RecordingClientCall<ExecuteSqlRequest, PartialResultSet> delegate =
(RecordingClientCall<ExecuteSqlRequest, PartialResultSet>)
harness.defaultManagedChannel.latestCall();

FakeEndpoint defaultEndpoint = harness.endpointCache.defaultEndpoint;
LatencyTracker tracker = defaultEndpoint.getLatencyTracker();

double initialScore = tracker.getScore();

// Emit a message with last=true to trigger onMessage and latency update.
delegate.emitOnMessage(PartialResultSet.newBuilder().setLast(true).build());

// Verify that the score has been updated (it should not be equal to the initial score).
double newScore = tracker.getScore();
assertThat(newScore).isNotEqualTo(initialScore);
}

@Test
public void callDoesNotTrackLatencyForNonEligibleRpc() throws Exception {
TestHarness harness = createHarness();
ExecuteSqlRequest request = ExecuteSqlRequest.newBuilder().setSession(SESSION).build();

ClientCall<ExecuteSqlRequest, ResultSet> call =
harness.channel.newCall(SpannerGrpc.getExecuteSqlMethod(), CallOptions.DEFAULT);
CapturingListener<ResultSet> listener = new CapturingListener<>();
call.start(listener, new Metadata());
call.sendMessage(request);

@SuppressWarnings("unchecked")
RecordingClientCall<ExecuteSqlRequest, ResultSet> delegate =
(RecordingClientCall<ExecuteSqlRequest, ResultSet>)
harness.defaultManagedChannel.latestCall();

FakeEndpoint defaultEndpoint = harness.endpointCache.defaultEndpoint;
LatencyTracker tracker = defaultEndpoint.getLatencyTracker();

double initialScore = tracker.getScore();

// Emit a message.
delegate.emitOnMessage(ResultSet.newBuilder().build());

// Verify that the score has not been updated.
double newScore = tracker.getScore();
assertThat(newScore).isEqualTo(initialScore);
}

@Test
public void callDoesNotTrackLatencyForNonLastPartialResultSet() throws Exception {
TestHarness harness = createHarness();
ExecuteSqlRequest request = ExecuteSqlRequest.newBuilder().setSession(SESSION).build();

ClientCall<ExecuteSqlRequest, PartialResultSet> call =
harness.channel.newCall(SpannerGrpc.getExecuteStreamingSqlMethod(), CallOptions.DEFAULT);
CapturingListener<PartialResultSet> listener = new CapturingListener<>();
call.start(listener, new Metadata());
call.sendMessage(request);

@SuppressWarnings("unchecked")
RecordingClientCall<ExecuteSqlRequest, PartialResultSet> delegate =
(RecordingClientCall<ExecuteSqlRequest, PartialResultSet>)
harness.defaultManagedChannel.latestCall();

FakeEndpoint defaultEndpoint = harness.endpointCache.defaultEndpoint;
LatencyTracker tracker = defaultEndpoint.getLatencyTracker();

double initialScore = tracker.getScore();

// Emit a message with last=false.
delegate.emitOnMessage(PartialResultSet.newBuilder().setLast(false).build());

// Verify that the score has not been updated.
double newScore = tracker.getScore();
assertThat(newScore).isEqualTo(initialScore);
}

@Test
public void beginTransactionWithMutationKeyAddsRoutingHint() throws Exception {
TestHarness harness = createHarness();
Expand Down Expand Up @@ -1350,12 +1437,18 @@ int callCountForAddress(String address) {
private static final class FakeEndpoint implements ChannelEndpoint {
private final String address;
private final FakeManagedChannel channel;
private final LatencyTracker latencyTracker = new EwmaLatencyTracker();

private FakeEndpoint(String address) {
this.address = address;
this.channel = new FakeManagedChannel(address);
}

@Override
public LatencyTracker getLatencyTracker() {
return latencyTracker;
}

@Override
public String getAddress() {
return address;
Expand Down
Loading