From 7d1006eb065c8f4c8f72cf338dc91cfac816fdd5 Mon Sep 17 00:00:00 2001 From: Rahul Yadav Date: Thu, 19 Mar 2026 01:20:19 +0530 Subject: [PATCH 1/2] fix: unbind manual affinity keys after terminal calls --- .../com/google/cloud/grpc/GcpClientCall.java | 18 ++ .../google/cloud/grpc/GcpManagedChannel.java | 7 +- .../google/cloud/grpc/GcpClientCallTest.java | 165 ++++++++++++++++++ 3 files changed, 188 insertions(+), 2 deletions(-) create mode 100644 grpc-gcp/src/test/java/com/google/cloud/grpc/GcpClientCallTest.java diff --git a/grpc-gcp/src/main/java/com/google/cloud/grpc/GcpClientCall.java b/grpc-gcp/src/main/java/com/google/cloud/grpc/GcpClientCall.java index 847c0ff..a8946f6 100644 --- a/grpc-gcp/src/main/java/com/google/cloud/grpc/GcpClientCall.java +++ b/grpc-gcp/src/main/java/com/google/cloud/grpc/GcpClientCall.java @@ -27,6 +27,7 @@ import io.grpc.MethodDescriptor; import io.grpc.Status; import java.util.ArrayDeque; +import java.util.Collections; import java.util.List; import java.util.Queue; import java.util.concurrent.atomic.AtomicBoolean; @@ -222,17 +223,24 @@ public void onMessage(RespT message) { */ public static class SimpleGcpClientCall extends ForwardingClientCall { + private final GcpManagedChannel delegateChannel; private final GcpManagedChannel.ChannelRef channelRef; private final ClientCall delegateCall; + @Nullable private final String affinityKey; + private final boolean unbindOnComplete; private long startNanos = 0; private final AtomicBoolean decremented = new AtomicBoolean(false); protected SimpleGcpClientCall( + GcpManagedChannel delegateChannel, GcpManagedChannel.ChannelRef channelRef, MethodDescriptor methodDescriptor, CallOptions callOptions) { + this.delegateChannel = delegateChannel; this.channelRef = channelRef; + this.affinityKey = callOptions.getOption(GcpManagedChannel.AFFINITY_KEY); + this.unbindOnComplete = callOptions.getOption(GcpManagedChannel.UNBIND_AFFINITY_KEY); // Set the actual channel ID in callOptions so downstream interceptors can access it. CallOptions callOptionsWithChannelId = callOptions.withOption(GcpManagedChannel.CHANNEL_ID_KEY, channelRef.getId()); @@ -257,6 +265,12 @@ public void onClose(Status status, Metadata trailers) { if (!decremented.getAndSet(true)) { channelRef.activeStreamsCountDecr(startNanos, status, false); } + // Unbind the affinity key when the caller explicitly requests it + // (e.g., on terminal RPCs like Commit or Rollback) to prevent + // unbounded growth of the affinity map. + if (unbindOnComplete && affinityKey != null) { + delegateChannel.unbind(Collections.singletonList(affinityKey)); + } super.onClose(status, trailers); } @@ -276,6 +290,10 @@ public void cancel(String message, Throwable cause) { if (!decremented.getAndSet(true)) { channelRef.activeStreamsCountDecr(startNanos, Status.CANCELLED, true); } + // Always unbind on cancel — the transaction is being abandoned. + if (affinityKey != null) { + delegateChannel.unbind(Collections.singletonList(affinityKey)); + } delegateCall.cancel(message, cause); } } diff --git a/grpc-gcp/src/main/java/com/google/cloud/grpc/GcpManagedChannel.java b/grpc-gcp/src/main/java/com/google/cloud/grpc/GcpManagedChannel.java index f24f2b8..819890a 100644 --- a/grpc-gcp/src/main/java/com/google/cloud/grpc/GcpManagedChannel.java +++ b/grpc-gcp/src/main/java/com/google/cloud/grpc/GcpManagedChannel.java @@ -93,6 +93,9 @@ public class GcpManagedChannel extends ManagedChannel { CallOptions.Key.createWithDefault("DisableAffinity", false); public static final Context.Key AFFINITY_CTX_KEY = Context.key("AffinityKey"); public static final CallOptions.Key AFFINITY_KEY = CallOptions.Key.create("AffinityKey"); + /** When set to true, the affinity key will be unbound after the call completes. */ + public static final CallOptions.Key UNBIND_AFFINITY_KEY = + CallOptions.Key.createWithDefault("UnbindAffinityKey", false); /** * CallOptions key that will be set by grpc-gcp with the actual channel ID used for the call. This @@ -1848,7 +1851,7 @@ public ClientCall newCall( logger.finest(log("Channel affinity is disabled via context or call options.")); } return new GcpClientCall.SimpleGcpClientCall<>( - getChannelRef(null), methodDescriptor, callOptions); + this, getChannelRef(null), methodDescriptor, callOptions); } AffinityConfig affinity = methodToAffinity.get(methodDescriptor.getFullMethodName()); @@ -1858,7 +1861,7 @@ public ClientCall newCall( } return new GcpClientCall.SimpleGcpClientCall<>( - getChannelRef(key), methodDescriptor, callOptions); + this, getChannelRef(key), methodDescriptor, callOptions); } @Nullable diff --git a/grpc-gcp/src/test/java/com/google/cloud/grpc/GcpClientCallTest.java b/grpc-gcp/src/test/java/com/google/cloud/grpc/GcpClientCallTest.java new file mode 100644 index 0000000..19e96ac --- /dev/null +++ b/grpc-gcp/src/test/java/com/google/cloud/grpc/GcpClientCallTest.java @@ -0,0 +1,165 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.grpc; + +import static com.google.common.truth.Truth.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import io.grpc.CallOptions; +import io.grpc.ClientCall; +import io.grpc.ConnectivityState; +import io.grpc.ManagedChannel; +import io.grpc.ManagedChannelBuilder; +import io.grpc.Metadata; +import io.grpc.MethodDescriptor; +import io.grpc.Status; +import java.io.InputStream; +import java.util.Collections; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnitRunner; + +@RunWith(MockitoJUnitRunner.class) +public final class GcpClientCallTest { + + private static final class FakeMarshaller implements MethodDescriptor.Marshaller { + @Override + public InputStream stream(T value) { + return null; + } + + @Override + public T parse(InputStream stream) { + return null; + } + } + + private static final MethodDescriptor METHOD_DESCRIPTOR = + MethodDescriptor.newBuilder() + .setType(MethodDescriptor.MethodType.UNARY) + .setFullMethodName("test/method") + .setRequestMarshaller(new FakeMarshaller<>()) + .setResponseMarshaller(new FakeMarshaller<>()) + .build(); + + @Mock private ManagedChannel delegateChannel; + @Mock private ClientCall delegateCall; + + private GcpManagedChannel gcpChannel; + private GcpManagedChannel.ChannelRef channelRef; + + @Before + public void setUp() { + ManagedChannelBuilder builder = ManagedChannelBuilder.forAddress("localhost", 443); + gcpChannel = (GcpManagedChannel) GcpManagedChannelBuilder.forDelegateBuilder(builder).build(); + + when(delegateChannel.getState(anyBoolean())).thenReturn(ConnectivityState.IDLE); + when(delegateChannel.newCall(eq(METHOD_DESCRIPTOR), any(CallOptions.class))) + .thenReturn(delegateCall); + + channelRef = gcpChannel.new ChannelRef(delegateChannel); + } + + @After + public void tearDown() { + gcpChannel.shutdownNow(); + } + + @SuppressWarnings("unchecked") + @Test + public void simpleCallUnbindsAffinityKeyOnCloseWhenRequested() { + String affinityKey = "txn-1"; + gcpChannel.bind(channelRef, Collections.singletonList(affinityKey)); + + GcpClientCall.SimpleGcpClientCall call = + new GcpClientCall.SimpleGcpClientCall<>( + gcpChannel, + channelRef, + METHOD_DESCRIPTOR, + CallOptions.DEFAULT + .withOption(GcpManagedChannel.AFFINITY_KEY, affinityKey) + .withOption(GcpManagedChannel.UNBIND_AFFINITY_KEY, true)); + + call.start(new ClientCall.Listener() {}, new Metadata()); + + ArgumentCaptor> listenerCaptor = + (ArgumentCaptor>) (ArgumentCaptor) + ArgumentCaptor.forClass(ClientCall.Listener.class); + verify(delegateCall).start(listenerCaptor.capture(), any(Metadata.class)); + + assertThat(gcpChannel.affinityKeyToChannelRef).containsKey(affinityKey); + + listenerCaptor.getValue().onClose(Status.OK, new Metadata()); + + assertThat(gcpChannel.affinityKeyToChannelRef).doesNotContainKey(affinityKey); + assertThat(channelRef.getAffinityCount()).isEqualTo(0); + } + + @SuppressWarnings("unchecked") + @Test + public void simpleCallKeepsAffinityKeyOnCloseWhenUnbindNotRequested() { + String affinityKey = "txn-2"; + gcpChannel.bind(channelRef, Collections.singletonList(affinityKey)); + + GcpClientCall.SimpleGcpClientCall call = + new GcpClientCall.SimpleGcpClientCall<>( + gcpChannel, + channelRef, + METHOD_DESCRIPTOR, + CallOptions.DEFAULT.withOption(GcpManagedChannel.AFFINITY_KEY, affinityKey)); + + call.start(new ClientCall.Listener() {}, new Metadata()); + + ArgumentCaptor> listenerCaptor = + (ArgumentCaptor>) (ArgumentCaptor) + ArgumentCaptor.forClass(ClientCall.Listener.class); + verify(delegateCall).start(listenerCaptor.capture(), any(Metadata.class)); + + listenerCaptor.getValue().onClose(Status.OK, new Metadata()); + + assertThat(gcpChannel.affinityKeyToChannelRef).containsEntry(affinityKey, channelRef); + assertThat(channelRef.getAffinityCount()).isEqualTo(1); + } + + @Test + public void simpleCallUnbindsAffinityKeyOnCancel() { + String affinityKey = "txn-3"; + gcpChannel.bind(channelRef, Collections.singletonList(affinityKey)); + + GcpClientCall.SimpleGcpClientCall call = + new GcpClientCall.SimpleGcpClientCall<>( + gcpChannel, + channelRef, + METHOD_DESCRIPTOR, + CallOptions.DEFAULT.withOption(GcpManagedChannel.AFFINITY_KEY, affinityKey)); + + call.start(new ClientCall.Listener() {}, new Metadata()); + call.cancel("cancelled", null); + + assertThat(gcpChannel.affinityKeyToChannelRef).doesNotContainKey(affinityKey); + assertThat(channelRef.getAffinityCount()).isEqualTo(0); + verify(delegateCall).cancel("cancelled", null); + } +} From 2c900cb04fc27b2f212f9a3a521b27c77bd4a8e2 Mon Sep 17 00:00:00 2001 From: Rahul Yadav Date: Fri, 20 Mar 2026 13:50:14 +0530 Subject: [PATCH 2/2] run fomat --- .../java/com/google/cloud/grpc/GcpManagedChannel.java | 1 + .../java/com/google/cloud/grpc/GcpClientCallTest.java | 8 ++++---- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/grpc-gcp/src/main/java/com/google/cloud/grpc/GcpManagedChannel.java b/grpc-gcp/src/main/java/com/google/cloud/grpc/GcpManagedChannel.java index 819890a..09f1dd0 100644 --- a/grpc-gcp/src/main/java/com/google/cloud/grpc/GcpManagedChannel.java +++ b/grpc-gcp/src/main/java/com/google/cloud/grpc/GcpManagedChannel.java @@ -93,6 +93,7 @@ public class GcpManagedChannel extends ManagedChannel { CallOptions.Key.createWithDefault("DisableAffinity", false); public static final Context.Key AFFINITY_CTX_KEY = Context.key("AffinityKey"); public static final CallOptions.Key AFFINITY_KEY = CallOptions.Key.create("AffinityKey"); + /** When set to true, the affinity key will be unbound after the call completes. */ public static final CallOptions.Key UNBIND_AFFINITY_KEY = CallOptions.Key.createWithDefault("UnbindAffinityKey", false); diff --git a/grpc-gcp/src/test/java/com/google/cloud/grpc/GcpClientCallTest.java b/grpc-gcp/src/test/java/com/google/cloud/grpc/GcpClientCallTest.java index 19e96ac..ce113e9 100644 --- a/grpc-gcp/src/test/java/com/google/cloud/grpc/GcpClientCallTest.java +++ b/grpc-gcp/src/test/java/com/google/cloud/grpc/GcpClientCallTest.java @@ -105,8 +105,8 @@ public void simpleCallUnbindsAffinityKeyOnCloseWhenRequested() { call.start(new ClientCall.Listener() {}, new Metadata()); ArgumentCaptor> listenerCaptor = - (ArgumentCaptor>) (ArgumentCaptor) - ArgumentCaptor.forClass(ClientCall.Listener.class); + (ArgumentCaptor>) + (ArgumentCaptor) ArgumentCaptor.forClass(ClientCall.Listener.class); verify(delegateCall).start(listenerCaptor.capture(), any(Metadata.class)); assertThat(gcpChannel.affinityKeyToChannelRef).containsKey(affinityKey); @@ -133,8 +133,8 @@ public void simpleCallKeepsAffinityKeyOnCloseWhenUnbindNotRequested() { call.start(new ClientCall.Listener() {}, new Metadata()); ArgumentCaptor> listenerCaptor = - (ArgumentCaptor>) (ArgumentCaptor) - ArgumentCaptor.forClass(ClientCall.Listener.class); + (ArgumentCaptor>) + (ArgumentCaptor) ArgumentCaptor.forClass(ClientCall.Listener.class); verify(delegateCall).start(listenerCaptor.capture(), any(Metadata.class)); listenerCaptor.getValue().onClose(Status.OK, new Metadata());