Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions grpc-gcp/src/main/java/com/google/cloud/grpc/GcpClientCall.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import io.grpc.MethodDescriptor;
import io.grpc.Status;
import java.util.ArrayDeque;
import java.util.Collections;
import java.util.List;
import java.util.Queue;
import java.util.concurrent.atomic.AtomicBoolean;
Expand Down Expand Up @@ -222,17 +223,24 @@ public void onMessage(RespT message) {
*/
public static class SimpleGcpClientCall<ReqT, RespT> extends ForwardingClientCall<ReqT, RespT> {

private final GcpManagedChannel delegateChannel;
private final GcpManagedChannel.ChannelRef channelRef;
private final ClientCall<ReqT, RespT> delegateCall;
@Nullable private final String affinityKey;
private final boolean unbindOnComplete;
private long startNanos = 0;

private final AtomicBoolean decremented = new AtomicBoolean(false);

protected SimpleGcpClientCall(
GcpManagedChannel delegateChannel,
GcpManagedChannel.ChannelRef channelRef,
MethodDescriptor<ReqT, RespT> methodDescriptor,
CallOptions callOptions) {
this.delegateChannel = delegateChannel;
this.channelRef = channelRef;
this.affinityKey = callOptions.getOption(GcpManagedChannel.AFFINITY_KEY);
this.unbindOnComplete = callOptions.getOption(GcpManagedChannel.UNBIND_AFFINITY_KEY);
// Set the actual channel ID in callOptions so downstream interceptors can access it.
CallOptions callOptionsWithChannelId =
callOptions.withOption(GcpManagedChannel.CHANNEL_ID_KEY, channelRef.getId());
Expand All @@ -257,6 +265,12 @@ public void onClose(Status status, Metadata trailers) {
if (!decremented.getAndSet(true)) {
channelRef.activeStreamsCountDecr(startNanos, status, false);
}
// Unbind the affinity key when the caller explicitly requests it
// (e.g., on terminal RPCs like Commit or Rollback) to prevent
// unbounded growth of the affinity map.
if (unbindOnComplete && affinityKey != null) {
delegateChannel.unbind(Collections.singletonList(affinityKey));
}
super.onClose(status, trailers);
}

Expand All @@ -276,6 +290,10 @@ public void cancel(String message, Throwable cause) {
if (!decremented.getAndSet(true)) {
channelRef.activeStreamsCountDecr(startNanos, Status.CANCELLED, true);
}
// Always unbind on cancel — the transaction is being abandoned.
if (affinityKey != null) {
delegateChannel.unbind(Collections.singletonList(affinityKey));
}
delegateCall.cancel(message, cause);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,10 @@ public class GcpManagedChannel extends ManagedChannel {
public static final Context.Key<String> AFFINITY_CTX_KEY = Context.key("AffinityKey");
public static final CallOptions.Key<String> AFFINITY_KEY = CallOptions.Key.create("AffinityKey");

/** When set to true, the affinity key will be unbound after the call completes. */
public static final CallOptions.Key<Boolean> UNBIND_AFFINITY_KEY =
CallOptions.Key.createWithDefault("UnbindAffinityKey", false);

/**
* CallOptions key that will be set by grpc-gcp with the actual channel ID used for the call. This
* can be read by downstream interceptors to get the real channel ID after channel selection.
Expand Down Expand Up @@ -1848,7 +1852,7 @@ public <ReqT, RespT> ClientCall<ReqT, RespT> newCall(
logger.finest(log("Channel affinity is disabled via context or call options."));
}
return new GcpClientCall.SimpleGcpClientCall<>(
getChannelRef(null), methodDescriptor, callOptions);
this, getChannelRef(null), methodDescriptor, callOptions);
}

AffinityConfig affinity = methodToAffinity.get(methodDescriptor.getFullMethodName());
Expand All @@ -1858,7 +1862,7 @@ public <ReqT, RespT> ClientCall<ReqT, RespT> newCall(
}

return new GcpClientCall.SimpleGcpClientCall<>(
getChannelRef(key), methodDescriptor, callOptions);
this, getChannelRef(key), methodDescriptor, callOptions);
}

@Nullable
Expand Down
165 changes: 165 additions & 0 deletions grpc-gcp/src/test/java/com/google/cloud/grpc/GcpClientCallTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
/*
* Copyright 2026 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.google.cloud.grpc;

import static com.google.common.truth.Truth.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import io.grpc.CallOptions;
import io.grpc.ClientCall;
import io.grpc.ConnectivityState;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.Status;
import java.io.InputStream;
import java.util.Collections;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner;

@RunWith(MockitoJUnitRunner.class)
public final class GcpClientCallTest {

private static final class FakeMarshaller<T> implements MethodDescriptor.Marshaller<T> {
@Override
public InputStream stream(T value) {
return null;
}

@Override
public T parse(InputStream stream) {
return null;
}
}

private static final MethodDescriptor<String, String> METHOD_DESCRIPTOR =
MethodDescriptor.<String, String>newBuilder()
.setType(MethodDescriptor.MethodType.UNARY)
.setFullMethodName("test/method")
.setRequestMarshaller(new FakeMarshaller<>())
.setResponseMarshaller(new FakeMarshaller<>())
.build();

@Mock private ManagedChannel delegateChannel;
@Mock private ClientCall<String, String> delegateCall;

private GcpManagedChannel gcpChannel;
private GcpManagedChannel.ChannelRef channelRef;

@Before
public void setUp() {
ManagedChannelBuilder<?> builder = ManagedChannelBuilder.forAddress("localhost", 443);
gcpChannel = (GcpManagedChannel) GcpManagedChannelBuilder.forDelegateBuilder(builder).build();

when(delegateChannel.getState(anyBoolean())).thenReturn(ConnectivityState.IDLE);
when(delegateChannel.newCall(eq(METHOD_DESCRIPTOR), any(CallOptions.class)))
.thenReturn(delegateCall);

channelRef = gcpChannel.new ChannelRef(delegateChannel);
}

@After
public void tearDown() {
gcpChannel.shutdownNow();
}

@SuppressWarnings("unchecked")
@Test
public void simpleCallUnbindsAffinityKeyOnCloseWhenRequested() {
String affinityKey = "txn-1";
gcpChannel.bind(channelRef, Collections.singletonList(affinityKey));

GcpClientCall.SimpleGcpClientCall<String, String> call =
new GcpClientCall.SimpleGcpClientCall<>(
gcpChannel,
channelRef,
METHOD_DESCRIPTOR,
CallOptions.DEFAULT
.withOption(GcpManagedChannel.AFFINITY_KEY, affinityKey)
.withOption(GcpManagedChannel.UNBIND_AFFINITY_KEY, true));

call.start(new ClientCall.Listener<String>() {}, new Metadata());

ArgumentCaptor<ClientCall.Listener<String>> listenerCaptor =
(ArgumentCaptor<ClientCall.Listener<String>>)
(ArgumentCaptor<?>) ArgumentCaptor.forClass(ClientCall.Listener.class);
verify(delegateCall).start(listenerCaptor.capture(), any(Metadata.class));

assertThat(gcpChannel.affinityKeyToChannelRef).containsKey(affinityKey);

listenerCaptor.getValue().onClose(Status.OK, new Metadata());

assertThat(gcpChannel.affinityKeyToChannelRef).doesNotContainKey(affinityKey);
assertThat(channelRef.getAffinityCount()).isEqualTo(0);
}

@SuppressWarnings("unchecked")
@Test
public void simpleCallKeepsAffinityKeyOnCloseWhenUnbindNotRequested() {
String affinityKey = "txn-2";
gcpChannel.bind(channelRef, Collections.singletonList(affinityKey));

GcpClientCall.SimpleGcpClientCall<String, String> call =
new GcpClientCall.SimpleGcpClientCall<>(
gcpChannel,
channelRef,
METHOD_DESCRIPTOR,
CallOptions.DEFAULT.withOption(GcpManagedChannel.AFFINITY_KEY, affinityKey));

call.start(new ClientCall.Listener<String>() {}, new Metadata());

ArgumentCaptor<ClientCall.Listener<String>> listenerCaptor =
(ArgumentCaptor<ClientCall.Listener<String>>)
(ArgumentCaptor<?>) ArgumentCaptor.forClass(ClientCall.Listener.class);
verify(delegateCall).start(listenerCaptor.capture(), any(Metadata.class));

listenerCaptor.getValue().onClose(Status.OK, new Metadata());

assertThat(gcpChannel.affinityKeyToChannelRef).containsEntry(affinityKey, channelRef);
assertThat(channelRef.getAffinityCount()).isEqualTo(1);
}

@Test
public void simpleCallUnbindsAffinityKeyOnCancel() {
String affinityKey = "txn-3";
gcpChannel.bind(channelRef, Collections.singletonList(affinityKey));

GcpClientCall.SimpleGcpClientCall<String, String> call =
new GcpClientCall.SimpleGcpClientCall<>(
gcpChannel,
channelRef,
METHOD_DESCRIPTOR,
CallOptions.DEFAULT.withOption(GcpManagedChannel.AFFINITY_KEY, affinityKey));

call.start(new ClientCall.Listener<String>() {}, new Metadata());
call.cancel("cancelled", null);

assertThat(gcpChannel.affinityKeyToChannelRef).doesNotContainKey(affinityKey);
assertThat(channelRef.getAffinityCount()).isEqualTo(0);
verify(delegateCall).cancel("cancelled", null);
}
}