Skip to content

Commit 6710ae0

Browse files
authored
chore(spanner): use channel affinity (#13231)
Internal reference: go/grpc-gcp-fixes#bookmark=id.q4x3oa8l672
1 parent 502841b commit 6710ae0

19 files changed

Lines changed: 546 additions & 357 deletions

grpc-gcp-java/src/main/java/com/google/cloud/grpc/GcpManagedChannel.java

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,40 @@ public class GcpManagedChannel extends ManagedChannel {
107107
public static final CallOptions.Key<Integer> CHANNEL_ID_KEY =
108108
CallOptions.Key.create("GcpChannelId");
109109

110+
/** CallOptions key for sticky channel routing without affinity-key map state. */
111+
public static final CallOptions.Key<ChannelAffinityRef> CHANNEL_AFFINITY_REF_KEY =
112+
CallOptions.Key.create("GcpChannelAffinityRef");
113+
114+
/** Opaque sticky channel reference for callers that should not depend on {@link ChannelRef}. */
115+
public static final class ChannelAffinityRef {
116+
private static final int USE_DIFFERENT_CHANNEL_ON_NEXT_CALL_MASK = 1 << 31;
117+
private static final int CHANNEL_ID_MASK = ~USE_DIFFERENT_CHANNEL_ON_NEXT_CALL_MASK;
118+
private static final int NO_CHANNEL_ID = -1;
119+
120+
// Single allocation hot-path state:
121+
// * lower 31 bits: channel id + 1, or 0 when unset.
122+
// * high bit: use a different active channel on the next call.
123+
private final AtomicInteger state = new AtomicInteger();
124+
125+
/** Forces the next RPC to prefer a different active channel if one is available. */
126+
public void useDifferentChannelOnNextCall() {
127+
state.getAndUpdate(value -> value | USE_DIFFERENT_CHANNEL_ON_NEXT_CALL_MASK);
128+
}
129+
130+
private static int channelIdFromState(int state) {
131+
int encodedChannelId = state & CHANNEL_ID_MASK;
132+
return encodedChannelId == 0 ? NO_CHANNEL_ID : encodedChannelId - 1;
133+
}
134+
135+
private static boolean useDifferentChannelOnNextCallFromState(int state) {
136+
return (state & USE_DIFFERENT_CHANNEL_ON_NEXT_CALL_MASK) != 0;
137+
}
138+
139+
private static int stateFromChannelId(int channelId) {
140+
return (channelId + 1) & CHANNEL_ID_MASK;
141+
}
142+
}
143+
110144
@GuardedBy("this")
111145
private Integer bindingIndex = -1;
112146

@@ -140,6 +174,7 @@ public class GcpManagedChannel extends ManagedChannel {
140174

141175
// The channel pool.
142176
@VisibleForTesting final List<ChannelRef> channelRefs = new CopyOnWriteArrayList<>();
177+
private final Map<Integer, ChannelRef> channelIdToChannelRef = new ConcurrentHashMap<>();
143178
// A set of channels that we removed from the pool and wait for their RPCs to be completed before
144179
// we can shut them down.
145180
final Set<ChannelRef> removedChannelRefs = new HashSet<>();
@@ -352,6 +387,7 @@ private synchronized void checkScaleDown() {
352387
channelRef.getChannel().shutdown();
353388
// Remove channel from broken channels map.
354389
fallbackMap.remove(channelRef.getId());
390+
channelIdToChannelRef.remove(channelRef.getId());
355391
}
356392
}
357393

@@ -372,6 +408,7 @@ private void removeOldestChannels(int num) {
372408

373409
for (ChannelRef channelRef : channelsToRemove) {
374410
channelRef.resetAffinityCount();
411+
channelRef.deactivate();
375412
if (channelRef.getState() == ConnectivityState.READY) {
376413
decReadyChannels(false);
377414
}
@@ -1678,6 +1715,59 @@ protected ChannelRef getChannelRef(@Nullable String key) {
16781715
return mappedChannel;
16791716
}
16801717

1718+
/**
1719+
* Pick a {@link ChannelRef} using a caller-owned reference instead of grpc-gcp's affinity map.
1720+
*/
1721+
protected ChannelRef getChannelRefByAffinityRef(ChannelAffinityRef affinityRef) {
1722+
maybeDynamicUpscale();
1723+
// Retry if another thread updates the caller-owned affinity ref while we are picking a channel.
1724+
while (true) {
1725+
int state = affinityRef.state.get();
1726+
int channelId = ChannelAffinityRef.channelIdFromState(state);
1727+
boolean useDifferentChannel =
1728+
ChannelAffinityRef.useDifferentChannelOnNextCallFromState(state);
1729+
ChannelRef channelRef =
1730+
channelId == ChannelAffinityRef.NO_CHANNEL_ID
1731+
? null
1732+
: channelIdToChannelRef.get(channelId);
1733+
if (!useDifferentChannel && channelRef != null && channelRef.isActive()) {
1734+
return channelRef;
1735+
}
1736+
1737+
ChannelRef selectedChannelRef =
1738+
useDifferentChannel
1739+
? pickLeastBusyChannelDifferentFrom(channelRef)
1740+
: pickLeastBusyChannel(/* forFallback= */ false);
1741+
if (affinityRef.state.compareAndSet(
1742+
state, ChannelAffinityRef.stateFromChannelId(selectedChannelRef.getId()))) {
1743+
return selectedChannelRef;
1744+
}
1745+
}
1746+
}
1747+
1748+
private ChannelRef pickLeastBusyChannelDifferentFrom(@Nullable ChannelRef excludedChannelRef) {
1749+
ChannelRef channelRef = pickLeastBusyChannel(/* forFallback= */ false);
1750+
if (excludedChannelRef == null || channelRefs.size() <= 1) {
1751+
return channelRef;
1752+
}
1753+
if (channelRef != excludedChannelRef && channelRef.isActive()) {
1754+
return channelRef;
1755+
}
1756+
ChannelRef leastBusyChannelRef = null;
1757+
int leastBusyStreams = Integer.MAX_VALUE;
1758+
for (ChannelRef candidate : channelRefs) {
1759+
if (candidate == excludedChannelRef || !candidate.isActive()) {
1760+
continue;
1761+
}
1762+
int streams = candidate.getActiveStreamsCount();
1763+
if (leastBusyChannelRef == null || streams < leastBusyStreams) {
1764+
leastBusyChannelRef = candidate;
1765+
leastBusyStreams = streams;
1766+
}
1767+
}
1768+
return leastBusyChannelRef == null ? channelRef : leastBusyChannelRef;
1769+
}
1770+
16811771
// Create a new channel and add it to channelRefs.
16821772
// If we have a ready channel not in the pool that we wait for completing its RPCs,
16831773
// then re-use that channel instead.
@@ -1688,6 +1778,8 @@ ChannelRef createNewChannel() {
16881778
ChannelRef chRef = reusedChannelRef.get();
16891779
channelRefs.add(chRef);
16901780
removedChannelRefs.remove(chRef);
1781+
channelIdToChannelRef.put(chRef.getId(), chRef);
1782+
chRef.activate();
16911783
logger.finer(log("Channel %d reused.", chRef.getId()));
16921784
incReadyChannels(false);
16931785
maxChannels.accumulateAndGet(getNumberOfChannels(), Math::max);
@@ -1696,6 +1788,7 @@ ChannelRef createNewChannel() {
16961788

16971789
ChannelRef channelRef = new ChannelRef(delegateChannelBuilder.build());
16981790
channelRefs.add(channelRef);
1791+
channelIdToChannelRef.put(channelRef.getId(), channelRef);
16991792
logger.finer(log("Channel %d created.", channelRef.getId()));
17001793
maxChannels.accumulateAndGet(getNumberOfChannels(), Math::max);
17011794
return channelRef;
@@ -1961,6 +2054,12 @@ public String authority() {
19612054
@Override
19622055
public <ReqT, RespT> ClientCall<ReqT, RespT> newCall(
19632056
MethodDescriptor<ReqT, RespT> methodDescriptor, CallOptions callOptions) {
2057+
ChannelAffinityRef channelAffinityRef = callOptions.getOption(CHANNEL_AFFINITY_REF_KEY);
2058+
if (channelAffinityRef != null) {
2059+
return new GcpClientCall.SimpleGcpClientCall<>(
2060+
this, getChannelRefByAffinityRef(channelAffinityRef), methodDescriptor, callOptions);
2061+
}
2062+
19642063
if (callOptions.getOption(DISABLE_AFFINITY_KEY)
19652064
|| DISABLE_AFFINITY_CTX_KEY.get(Context.current())) {
19662065
if (logger.isLoggable(Level.FINEST)) {
@@ -2314,6 +2413,7 @@ protected class ChannelRef {
23142413
private final AtomicLong okCalls = new AtomicLong();
23152414
private final AtomicLong errCalls = new AtomicLong();
23162415
private final ChannelStateMonitor channelStateMonitor;
2416+
private volatile boolean active = true;
23172417

23182418
protected ChannelRef(ManagedChannel channel) {
23192419
this(channel, 0, 0);
@@ -2343,6 +2443,18 @@ protected int getId() {
23432443
return channelId;
23442444
}
23452445

2446+
protected boolean isActive() {
2447+
return active;
2448+
}
2449+
2450+
private void activate() {
2451+
active = true;
2452+
}
2453+
2454+
private void deactivate() {
2455+
active = false;
2456+
}
2457+
23462458
protected void affinityCountIncr() {
23472459
int count = affinityCount.incrementAndGet();
23482460
maxAffinity.accumulateAndGet(count, Math::max);

grpc-gcp-java/src/test/java/com/google/cloud/grpc/ChannelIdPropagationTest.java

Lines changed: 90 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import static com.google.common.truth.Truth.assertThat;
2020

21+
import com.google.cloud.grpc.GcpManagedChannel.ChannelAffinityRef;
2122
import com.google.cloud.grpc.GcpManagedChannelOptions.GcpChannelPoolOptions;
2223
import io.grpc.CallOptions;
2324
import io.grpc.Channel;
@@ -28,13 +29,22 @@
2829
import io.grpc.Metadata;
2930
import io.grpc.MethodDescriptor;
3031
import java.io.InputStream;
32+
import java.util.ArrayList;
33+
import java.util.List;
3134
import java.util.concurrent.atomic.AtomicInteger;
3235
import org.junit.Test;
3336
import org.junit.runner.RunWith;
3437
import org.junit.runners.JUnit4;
3538

3639
@RunWith(JUnit4.class)
3740
public class ChannelIdPropagationTest {
41+
private static final MethodDescriptor<String, String> METHOD_DESCRIPTOR =
42+
MethodDescriptor.<String, String>newBuilder()
43+
.setType(MethodDescriptor.MethodType.UNARY)
44+
.setFullMethodName("test/method")
45+
.setRequestMarshaller(new FakeMarshaller<>())
46+
.setResponseMarshaller(new FakeMarshaller<>())
47+
.build();
3848

3949
private static class FakeMarshaller<T> implements MethodDescriptor.Marshaller<T> {
4050
@Override
@@ -85,16 +95,8 @@ public void start(Listener<RespT> responseListener, Metadata headers) {
8595
.build())
8696
.build();
8797

88-
MethodDescriptor<String, String> methodDescriptor =
89-
MethodDescriptor.<String, String>newBuilder()
90-
.setType(MethodDescriptor.MethodType.UNARY)
91-
.setFullMethodName("test/method")
92-
.setRequestMarshaller(new FakeMarshaller<>())
93-
.setResponseMarshaller(new FakeMarshaller<>())
94-
.build();
95-
9698
// Use the pool directly (interceptor is already inside)
97-
ClientCall<String, String> newCall = pool.newCall(methodDescriptor, CallOptions.DEFAULT);
99+
ClientCall<String, String> newCall = pool.newCall(METHOD_DESCRIPTOR, CallOptions.DEFAULT);
98100
Metadata headers = new Metadata();
99101

100102
// First call (should initialize channel and correct ID)
@@ -105,7 +107,7 @@ public void start(Listener<RespT> responseListener, Metadata headers) {
105107
assertThat(channelId.get()).isAnyOf(0, 1, 2);
106108

107109
// Attempt 2
108-
newCall = pool.newCall(methodDescriptor, CallOptions.DEFAULT);
110+
newCall = pool.newCall(METHOD_DESCRIPTOR, CallOptions.DEFAULT);
109111
newCall.start(
110112
new ForwardingClientCall.SimpleForwardingClientCall.Listener<String>() {}, headers);
111113

@@ -114,4 +116,82 @@ public void start(Listener<RespT> responseListener, Metadata headers) {
114116

115117
pool.shutdownNow();
116118
}
119+
120+
@Test
121+
public void testChannelAffinityRefSticksToSameChannel() {
122+
List<Integer> channelIds = new ArrayList<>();
123+
GcpManagedChannel pool = newPoolWithChannelIdInterceptor(channelIds);
124+
125+
try {
126+
ChannelAffinityRef affinityRef = new ChannelAffinityRef();
127+
CallOptions callOptions =
128+
CallOptions.DEFAULT.withOption(GcpManagedChannel.CHANNEL_AFFINITY_REF_KEY, affinityRef);
129+
130+
startCall(pool, callOptions);
131+
startCall(pool, callOptions);
132+
startCall(pool, callOptions);
133+
134+
assertThat(channelIds).hasSize(3);
135+
assertThat(channelIds.get(1)).isEqualTo(channelIds.get(0));
136+
assertThat(channelIds.get(2)).isEqualTo(channelIds.get(0));
137+
assertThat(pool.affinityKeyToChannelRef).isEmpty();
138+
} finally {
139+
pool.shutdownNow();
140+
}
141+
}
142+
143+
@Test
144+
public void testChannelAffinityRefCanMoveToDifferentChannelOnNextCall() {
145+
List<Integer> channelIds = new ArrayList<>();
146+
GcpManagedChannel pool = newPoolWithChannelIdInterceptor(channelIds);
147+
148+
try {
149+
ChannelAffinityRef affinityRef = new ChannelAffinityRef();
150+
CallOptions callOptions =
151+
CallOptions.DEFAULT.withOption(GcpManagedChannel.CHANNEL_AFFINITY_REF_KEY, affinityRef);
152+
153+
startCall(pool, callOptions);
154+
affinityRef.useDifferentChannelOnNextCall();
155+
startCall(pool, callOptions);
156+
startCall(pool, callOptions);
157+
158+
assertThat(channelIds).hasSize(3);
159+
assertThat(channelIds.get(1)).isNotEqualTo(channelIds.get(0));
160+
assertThat(channelIds.get(2)).isEqualTo(channelIds.get(1));
161+
assertThat(pool.affinityKeyToChannelRef).isEmpty();
162+
} finally {
163+
pool.shutdownNow();
164+
}
165+
}
166+
167+
private static GcpManagedChannel newPoolWithChannelIdInterceptor(List<Integer> channelIds) {
168+
ManagedChannelBuilder<?> builder = ManagedChannelBuilder.forAddress("localhost", 443);
169+
builder.intercept(
170+
new ClientInterceptor() {
171+
@Override
172+
public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
173+
MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, Channel next) {
174+
Integer channelId = callOptions.getOption(GcpManagedChannel.CHANNEL_ID_KEY);
175+
if (channelId != null) {
176+
channelIds.add(channelId);
177+
}
178+
return next.newCall(method, callOptions);
179+
}
180+
});
181+
return (GcpManagedChannel)
182+
GcpManagedChannelBuilder.forDelegateBuilder(builder)
183+
.withOptions(
184+
GcpManagedChannelOptions.newBuilder()
185+
.withChannelPoolOptions(
186+
GcpChannelPoolOptions.newBuilder().setMinSize(3).setMaxSize(3).build())
187+
.build())
188+
.build();
189+
}
190+
191+
private static void startCall(GcpManagedChannel pool, CallOptions callOptions) {
192+
pool.newCall(METHOD_DESCRIPTOR, callOptions)
193+
.start(
194+
new ForwardingClientCall.SimpleForwardingClientCall.Listener<String>() {},
195+
new Metadata());
196+
}
117197
}

0 commit comments

Comments
 (0)