Skip to content

Commit 131d6ff

Browse files
committed
feat(gax): implement dynamic channel refreshing on 401 retries
1 parent 39e93fe commit 131d6ff

10 files changed

Lines changed: 183 additions & 21 deletions

File tree

sdk-platform-java/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ class ChannelPool extends ManagedChannel {
8282
private ScheduledFuture<?> resizeFuture = null;
8383

8484
private final Object entryWriteLock = new Object();
85+
private long lastRefreshTimeNanos = 0;
8586
@VisibleForTesting final AtomicReference<ImmutableList<Entry>> entries = new AtomicReference<>();
8687
private final AtomicInteger indexTicker = new AtomicInteger();
8788
private final String authority;
@@ -441,6 +442,13 @@ void refresh() {
441442
// - then thread2 will shut down channel that thread1 will put back into circulation (after it
442443
// replaces the list)
443444
synchronized (entryWriteLock) {
445+
long now = System.nanoTime();
446+
if (now - lastRefreshTimeNanos < TimeUnit.SECONDS.toNanos(5)) {
447+
LOG.fine("Channel pool was refreshed recently, skipping duplicate refresh");
448+
return;
449+
}
450+
lastRefreshTimeNanos = now;
451+
444452
LOG.fine("Refreshing all channels");
445453
ArrayList<Entry> newEntries = new ArrayList<>(entries.get());
446454

sdk-platform-java/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/GrpcCallContext.java

Lines changed: 41 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ public final class GrpcCallContext implements ApiCallContext {
9797
private final ApiCallContextOptions options;
9898
private final EndpointContext endpointContext;
9999
private final boolean isDirectPath;
100+
@Nullable private final TransportChannel transportChannel;
100101

101102
/** Returns an empty instance with a null channel and default {@link CallOptions}. */
102103
public static GrpcCallContext createDefault() {
@@ -113,7 +114,8 @@ public static GrpcCallContext createDefault() {
113114
null,
114115
null,
115116
null,
116-
false);
117+
false,
118+
null);
117119
}
118120

119121
/** Returns an instance with the given channel and {@link CallOptions}. */
@@ -131,7 +133,8 @@ public static GrpcCallContext of(Channel channel, CallOptions callOptions) {
131133
null,
132134
null,
133135
null,
134-
false);
136+
false,
137+
null);
135138
}
136139

137140
private GrpcCallContext(
@@ -147,7 +150,8 @@ private GrpcCallContext(
147150
@Nullable RetrySettings retrySettings,
148151
@Nullable Set<StatusCode.Code> retryableCodes,
149152
@Nullable EndpointContext endpointContext,
150-
boolean isDirectPath) {
153+
boolean isDirectPath,
154+
@Nullable TransportChannel transportChannel) {
151155
this.channel = channel;
152156
this.credentials = credentials;
153157
Preconditions.checkNotNull(callOptions);
@@ -167,6 +171,7 @@ private GrpcCallContext(
167171
this.endpointContext =
168172
endpointContext == null ? EndpointContext.getDefaultInstance() : endpointContext;
169173
this.isDirectPath = isDirectPath;
174+
this.transportChannel = transportChannel;
170175
}
171176

172177
/**
@@ -208,7 +213,13 @@ public GrpcCallContext withCredentials(Credentials newCredentials) {
208213
retrySettings,
209214
retryableCodes,
210215
endpointContext,
211-
isDirectPath);
216+
isDirectPath,
217+
transportChannel);
218+
}
219+
220+
@Override
221+
public TransportChannel getTransportChannel() {
222+
return transportChannel;
212223
}
213224

214225
@Override
@@ -232,7 +243,8 @@ public GrpcCallContext withTransportChannel(TransportChannel inputChannel) {
232243
retrySettings,
233244
retryableCodes,
234245
endpointContext,
235-
transportChannel.isDirectPath());
246+
transportChannel.isDirectPath(),
247+
inputChannel);
236248
}
237249

238250
@Override
@@ -251,7 +263,8 @@ public GrpcCallContext withEndpointContext(EndpointContext endpointContext) {
251263
retrySettings,
252264
retryableCodes,
253265
endpointContext,
254-
isDirectPath);
266+
isDirectPath,
267+
transportChannel);
255268
}
256269

257270
/** This method is obsolete. Use {@link #withTimeoutDuration(java.time.Duration)} instead. */
@@ -286,7 +299,8 @@ public GrpcCallContext withTimeoutDuration(@Nullable java.time.Duration timeout)
286299
retrySettings,
287300
retryableCodes,
288301
endpointContext,
289-
isDirectPath);
302+
isDirectPath,
303+
transportChannel);
290304
}
291305

292306
/** This method is obsolete. Use {@link #getTimeoutDuration()} instead. */
@@ -335,7 +349,8 @@ public GrpcCallContext withStreamWaitTimeoutDuration(
335349
retrySettings,
336350
retryableCodes,
337351
endpointContext,
338-
isDirectPath);
352+
isDirectPath,
353+
transportChannel);
339354
}
340355

341356
/**
@@ -370,7 +385,8 @@ public GrpcCallContext withStreamIdleTimeoutDuration(
370385
retrySettings,
371386
retryableCodes,
372387
endpointContext,
373-
isDirectPath);
388+
isDirectPath,
389+
transportChannel);
374390
}
375391

376392
@BetaApi("The surface for channel affinity is not stable yet and may change in the future.")
@@ -388,7 +404,8 @@ public GrpcCallContext withChannelAffinity(@Nullable Integer affinity) {
388404
retrySettings,
389405
retryableCodes,
390406
endpointContext,
391-
isDirectPath);
407+
isDirectPath,
408+
transportChannel);
392409
}
393410

394411
@BetaApi("The surface for extra headers is not stable yet and may change in the future.")
@@ -410,7 +427,8 @@ public GrpcCallContext withExtraHeaders(Map<String, List<String>> extraHeaders)
410427
retrySettings,
411428
retryableCodes,
412429
endpointContext,
413-
isDirectPath);
430+
isDirectPath,
431+
transportChannel);
414432
}
415433

416434
@Override
@@ -433,7 +451,8 @@ public GrpcCallContext withRetrySettings(RetrySettings retrySettings) {
433451
retrySettings,
434452
retryableCodes,
435453
endpointContext,
436-
isDirectPath);
454+
isDirectPath,
455+
transportChannel);
437456
}
438457

439458
@Override
@@ -456,7 +475,8 @@ public GrpcCallContext withRetryableCodes(Set<StatusCode.Code> retryableCodes) {
456475
retrySettings,
457476
retryableCodes,
458477
endpointContext,
459-
isDirectPath);
478+
isDirectPath,
479+
transportChannel);
460480
}
461481

462482
@Override
@@ -558,7 +578,8 @@ public ApiCallContext merge(ApiCallContext inputCallContext) {
558578
newRetrySettings,
559579
newRetryableCodes,
560580
endpointContext,
561-
newIsDirectPath);
581+
newIsDirectPath,
582+
transportChannel);
562583
}
563584

564585
/** The {@link Channel} set on this context. */
@@ -641,7 +662,8 @@ public GrpcCallContext withChannel(Channel newChannel) {
641662
retrySettings,
642663
retryableCodes,
643664
endpointContext,
644-
isDirectPath);
665+
isDirectPath,
666+
transportChannel);
645667
}
646668

647669
/** Returns a new instance with the call options set to the given call options. */
@@ -659,7 +681,8 @@ public GrpcCallContext withCallOptions(CallOptions newCallOptions) {
659681
retrySettings,
660682
retryableCodes,
661683
endpointContext,
662-
isDirectPath);
684+
isDirectPath,
685+
transportChannel);
663686
}
664687

665688
public GrpcCallContext withRequestParamsDynamicHeaderOption(String requestParams) {
@@ -704,7 +727,8 @@ public <T> GrpcCallContext withOption(Key<T> key, T value) {
704727
retrySettings,
705728
retryableCodes,
706729
endpointContext,
707-
isDirectPath);
730+
isDirectPath,
731+
transportChannel);
708732
}
709733

710734
/** {@inheritDoc} */

sdk-platform-java/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/GrpcTransportChannel.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,14 @@ public Channel getChannel() {
6666
return getManagedChannel();
6767
}
6868

69+
@Override
70+
public void refresh() {
71+
Channel channel = getChannel();
72+
if (channel instanceof ChannelPool) {
73+
((ChannelPool) channel).refresh();
74+
}
75+
}
76+
6977
@Override
7078
public void shutdown() {
7179
getManagedChannel().shutdown();

sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/ApiCallContext.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,14 @@ public interface ApiCallContext extends RetryingContext {
6363
/** Returns a new ApiCallContext with the given channel set. */
6464
ApiCallContext withTransportChannel(TransportChannel channel);
6565

66+
/**
67+
* Returns the {@link TransportChannel} associated with this call context, or {@code null} if none
68+
* is set.
69+
*/
70+
default TransportChannel getTransportChannel() {
71+
return null;
72+
}
73+
6674
/** Returns a new ApiCallContext with the given Endpoint Context. */
6775
ApiCallContext withEndpointContext(EndpointContext endpointContext);
6876

sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/ApiResultRetryAlgorithm.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,10 @@ class ApiResultRetryAlgorithm<ResponseT> extends BasicResultRetryAlgorithm<Respo
3838
/** Returns true if previousThrowable is an {@link ApiException} that is retryable. */
3939
@Override
4040
public boolean shouldRetry(Throwable previousThrowable, ResponseT previousResponse) {
41+
if ("true".equalsIgnoreCase(System.getenv("isMwlidEnvironment"))
42+
&& previousThrowable instanceof UnauthenticatedException) {
43+
return true;
44+
}
4145
return (previousThrowable instanceof ApiException)
4246
&& ((ApiException) previousThrowable).isRetryable();
4347
}
@@ -51,6 +55,10 @@ public boolean shouldRetry(Throwable previousThrowable, ResponseT previousRespon
5155
@Override
5256
public boolean shouldRetry(
5357
RetryingContext context, Throwable previousThrowable, ResponseT previousResponse) {
58+
if ("true".equalsIgnoreCase(System.getenv("isMwlidEnvironment"))
59+
&& previousThrowable instanceof UnauthenticatedException) {
60+
return true;
61+
}
5462
if (context.getRetryableCodes() != null) {
5563
// Ignore the isRetryable() value of the throwable if the RetryingContext has a specific list
5664
// of codes that should be retried.

sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/AttemptCallable.java

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,28 @@ public ResponseT call() {
8484
.attemptStarted(request, externalFuture.getAttemptSettings().getOverallAttemptCount());
8585

8686
ApiFuture<ResponseT> internalFuture = callable.futureCall(request, callContext);
87+
88+
if ("true".equalsIgnoreCase(System.getenv("isMwlidEnvironment"))) {
89+
final ApiCallContext finalContext = callContext;
90+
ApiFutures.addCallback(
91+
internalFuture,
92+
new com.google.api.core.ApiFutureCallback<ResponseT>() {
93+
@Override
94+
public void onFailure(Throwable t) {
95+
if (t instanceof UnauthenticatedException) {
96+
TransportChannel transportChannel = finalContext.getTransportChannel();
97+
if (transportChannel != null) {
98+
transportChannel.refresh();
99+
}
100+
}
101+
}
102+
103+
@Override
104+
public void onSuccess(ResponseT result) {}
105+
},
106+
com.google.common.util.concurrent.MoreExecutors.directExecutor());
107+
}
108+
87109
externalFuture.setAttemptFuture(internalFuture);
88110
} catch (Throwable e) {
89111
externalFuture.setAttemptFuture(ApiFutures.<ResponseT>immediateFailedFuture(e));

sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/BidiStreamingCallable.java

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -236,11 +236,45 @@ public BidiStreamingCallable<RequestT, ResponseT> withDefaultCallContext(
236236
return new BidiStreamingCallable<RequestT, ResponseT>() {
237237
@Override
238238
public ClientStream<RequestT> internalCall(
239-
ResponseObserver<ResponseT> responseObserver,
239+
final ResponseObserver<ResponseT> responseObserver,
240240
ClientStreamReadyObserver<RequestT> onReady,
241241
ApiCallContext thisCallContext) {
242+
final ApiCallContext mergedContext = defaultCallContext.merge(thisCallContext);
243+
ResponseObserver<ResponseT> refreshingObserver = responseObserver;
244+
245+
if ("true".equalsIgnoreCase(System.getenv("isMwlidEnvironment"))) {
246+
refreshingObserver =
247+
new ResponseObserver<ResponseT>() {
248+
@Override
249+
public void onStart(StreamController controller) {
250+
responseObserver.onStart(controller);
251+
}
252+
253+
@Override
254+
public void onResponse(ResponseT response) {
255+
responseObserver.onResponse(response);
256+
}
257+
258+
@Override
259+
public void onError(Throwable t) {
260+
if (t instanceof UnauthenticatedException) {
261+
TransportChannel transportChannel = mergedContext.getTransportChannel();
262+
if (transportChannel != null) {
263+
transportChannel.refresh();
264+
}
265+
}
266+
responseObserver.onError(t);
267+
}
268+
269+
@Override
270+
public void onComplete() {
271+
responseObserver.onComplete();
272+
}
273+
};
274+
}
275+
242276
return BidiStreamingCallable.this.internalCall(
243-
responseObserver, onReady, defaultCallContext.merge(thisCallContext));
277+
refreshingObserver, onReady, mergedContext);
244278
}
245279
};
246280
}

sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/ClientStreamingCallable.java

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,38 @@ public ClientStreamingCallable<RequestT, ResponseT> withDefaultCallContext(
7373
return new ClientStreamingCallable<RequestT, ResponseT>() {
7474
@Override
7575
public ApiStreamObserver<RequestT> clientStreamingCall(
76-
ApiStreamObserver<ResponseT> responseObserver, ApiCallContext thisCallContext) {
76+
final ApiStreamObserver<ResponseT> responseObserver, ApiCallContext thisCallContext) {
77+
final ApiCallContext mergedContext = defaultCallContext.merge(thisCallContext);
78+
ApiStreamObserver<ResponseT> refreshingObserver = responseObserver;
79+
80+
if ("true".equalsIgnoreCase(System.getenv("isMwlidEnvironment"))) {
81+
refreshingObserver =
82+
new ApiStreamObserver<ResponseT>() {
83+
@Override
84+
public void onNext(ResponseT response) {
85+
responseObserver.onNext(response);
86+
}
87+
88+
@Override
89+
public void onError(Throwable t) {
90+
if (t instanceof UnauthenticatedException) {
91+
TransportChannel transportChannel = mergedContext.getTransportChannel();
92+
if (transportChannel != null) {
93+
transportChannel.refresh();
94+
}
95+
}
96+
responseObserver.onError(t);
97+
}
98+
99+
@Override
100+
public void onCompleted() {
101+
responseObserver.onCompleted();
102+
}
103+
};
104+
}
105+
77106
return ClientStreamingCallable.this.clientStreamingCall(
78-
responseObserver, defaultCallContext.merge(thisCallContext));
107+
refreshingObserver, mergedContext);
79108
}
80109
};
81110
}

0 commit comments

Comments
 (0)