From 131d6ffd4af36db92efe55c6d18cacae2c800f17 Mon Sep 17 00:00:00 2001 From: Blake Li Date: Fri, 15 May 2026 21:49:07 +0000 Subject: [PATCH] feat(gax): implement dynamic channel refreshing on 401 retries --- .../com/google/api/gax/grpc/ChannelPool.java | 8 +++ .../google/api/gax/grpc/GrpcCallContext.java | 58 +++++++++++++------ .../api/gax/grpc/GrpcTransportChannel.java | 8 +++ .../google/api/gax/rpc/ApiCallContext.java | 8 +++ .../api/gax/rpc/ApiResultRetryAlgorithm.java | 8 +++ .../google/api/gax/rpc/AttemptCallable.java | 22 +++++++ .../api/gax/rpc/BidiStreamingCallable.java | 38 +++++++++++- .../api/gax/rpc/ClientStreamingCallable.java | 33 ++++++++++- .../rpc/ServerStreamingAttemptCallable.java | 13 +++++ .../google/api/gax/rpc/TransportChannel.java | 8 +++ 10 files changed, 183 insertions(+), 21 deletions(-) diff --git a/sdk-platform-java/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java b/sdk-platform-java/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java index d611c96ff4c8..d35dbc8d12ca 100644 --- a/sdk-platform-java/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java +++ b/sdk-platform-java/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java @@ -82,6 +82,7 @@ class ChannelPool extends ManagedChannel { private ScheduledFuture resizeFuture = null; private final Object entryWriteLock = new Object(); + private long lastRefreshTimeNanos = 0; @VisibleForTesting final AtomicReference> entries = new AtomicReference<>(); private final AtomicInteger indexTicker = new AtomicInteger(); private final String authority; @@ -441,6 +442,13 @@ void refresh() { // - then thread2 will shut down channel that thread1 will put back into circulation (after it // replaces the list) synchronized (entryWriteLock) { + long now = System.nanoTime(); + if (now - lastRefreshTimeNanos < TimeUnit.SECONDS.toNanos(5)) { + LOG.fine("Channel pool was refreshed recently, skipping duplicate refresh"); + return; + } + lastRefreshTimeNanos = now; + LOG.fine("Refreshing all channels"); ArrayList newEntries = new ArrayList<>(entries.get()); diff --git a/sdk-platform-java/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/GrpcCallContext.java b/sdk-platform-java/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/GrpcCallContext.java index 7ff7c54de6f0..fb5e2edb0d07 100644 --- a/sdk-platform-java/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/GrpcCallContext.java +++ b/sdk-platform-java/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/GrpcCallContext.java @@ -97,6 +97,7 @@ public final class GrpcCallContext implements ApiCallContext { private final ApiCallContextOptions options; private final EndpointContext endpointContext; private final boolean isDirectPath; + @Nullable private final TransportChannel transportChannel; /** Returns an empty instance with a null channel and default {@link CallOptions}. */ public static GrpcCallContext createDefault() { @@ -113,7 +114,8 @@ public static GrpcCallContext createDefault() { null, null, null, - false); + false, + null); } /** Returns an instance with the given channel and {@link CallOptions}. */ @@ -131,7 +133,8 @@ public static GrpcCallContext of(Channel channel, CallOptions callOptions) { null, null, null, - false); + false, + null); } private GrpcCallContext( @@ -147,7 +150,8 @@ private GrpcCallContext( @Nullable RetrySettings retrySettings, @Nullable Set retryableCodes, @Nullable EndpointContext endpointContext, - boolean isDirectPath) { + boolean isDirectPath, + @Nullable TransportChannel transportChannel) { this.channel = channel; this.credentials = credentials; Preconditions.checkNotNull(callOptions); @@ -167,6 +171,7 @@ private GrpcCallContext( this.endpointContext = endpointContext == null ? EndpointContext.getDefaultInstance() : endpointContext; this.isDirectPath = isDirectPath; + this.transportChannel = transportChannel; } /** @@ -208,7 +213,13 @@ public GrpcCallContext withCredentials(Credentials newCredentials) { retrySettings, retryableCodes, endpointContext, - isDirectPath); + isDirectPath, + transportChannel); + } + + @Override + public TransportChannel getTransportChannel() { + return transportChannel; } @Override @@ -232,7 +243,8 @@ public GrpcCallContext withTransportChannel(TransportChannel inputChannel) { retrySettings, retryableCodes, endpointContext, - transportChannel.isDirectPath()); + transportChannel.isDirectPath(), + inputChannel); } @Override @@ -251,7 +263,8 @@ public GrpcCallContext withEndpointContext(EndpointContext endpointContext) { retrySettings, retryableCodes, endpointContext, - isDirectPath); + isDirectPath, + transportChannel); } /** This method is obsolete. Use {@link #withTimeoutDuration(java.time.Duration)} instead. */ @@ -286,7 +299,8 @@ public GrpcCallContext withTimeoutDuration(@Nullable java.time.Duration timeout) retrySettings, retryableCodes, endpointContext, - isDirectPath); + isDirectPath, + transportChannel); } /** This method is obsolete. Use {@link #getTimeoutDuration()} instead. */ @@ -335,7 +349,8 @@ public GrpcCallContext withStreamWaitTimeoutDuration( retrySettings, retryableCodes, endpointContext, - isDirectPath); + isDirectPath, + transportChannel); } /** @@ -370,7 +385,8 @@ public GrpcCallContext withStreamIdleTimeoutDuration( retrySettings, retryableCodes, endpointContext, - isDirectPath); + isDirectPath, + transportChannel); } @BetaApi("The surface for channel affinity is not stable yet and may change in the future.") @@ -388,7 +404,8 @@ public GrpcCallContext withChannelAffinity(@Nullable Integer affinity) { retrySettings, retryableCodes, endpointContext, - isDirectPath); + isDirectPath, + transportChannel); } @BetaApi("The surface for extra headers is not stable yet and may change in the future.") @@ -410,7 +427,8 @@ public GrpcCallContext withExtraHeaders(Map> extraHeaders) retrySettings, retryableCodes, endpointContext, - isDirectPath); + isDirectPath, + transportChannel); } @Override @@ -433,7 +451,8 @@ public GrpcCallContext withRetrySettings(RetrySettings retrySettings) { retrySettings, retryableCodes, endpointContext, - isDirectPath); + isDirectPath, + transportChannel); } @Override @@ -456,7 +475,8 @@ public GrpcCallContext withRetryableCodes(Set retryableCodes) { retrySettings, retryableCodes, endpointContext, - isDirectPath); + isDirectPath, + transportChannel); } @Override @@ -558,7 +578,8 @@ public ApiCallContext merge(ApiCallContext inputCallContext) { newRetrySettings, newRetryableCodes, endpointContext, - newIsDirectPath); + newIsDirectPath, + transportChannel); } /** The {@link Channel} set on this context. */ @@ -641,7 +662,8 @@ public GrpcCallContext withChannel(Channel newChannel) { retrySettings, retryableCodes, endpointContext, - isDirectPath); + isDirectPath, + transportChannel); } /** Returns a new instance with the call options set to the given call options. */ @@ -659,7 +681,8 @@ public GrpcCallContext withCallOptions(CallOptions newCallOptions) { retrySettings, retryableCodes, endpointContext, - isDirectPath); + isDirectPath, + transportChannel); } public GrpcCallContext withRequestParamsDynamicHeaderOption(String requestParams) { @@ -704,7 +727,8 @@ public GrpcCallContext withOption(Key key, T value) { retrySettings, retryableCodes, endpointContext, - isDirectPath); + isDirectPath, + transportChannel); } /** {@inheritDoc} */ diff --git a/sdk-platform-java/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/GrpcTransportChannel.java b/sdk-platform-java/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/GrpcTransportChannel.java index 2fa0908f17bc..80d471701d5a 100644 --- a/sdk-platform-java/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/GrpcTransportChannel.java +++ b/sdk-platform-java/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/GrpcTransportChannel.java @@ -66,6 +66,14 @@ public Channel getChannel() { return getManagedChannel(); } + @Override + public void refresh() { + Channel channel = getChannel(); + if (channel instanceof ChannelPool) { + ((ChannelPool) channel).refresh(); + } + } + @Override public void shutdown() { getManagedChannel().shutdown(); diff --git a/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/ApiCallContext.java b/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/ApiCallContext.java index 09af475e4833..fc7fb5e989fe 100644 --- a/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/ApiCallContext.java +++ b/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/ApiCallContext.java @@ -63,6 +63,14 @@ public interface ApiCallContext extends RetryingContext { /** Returns a new ApiCallContext with the given channel set. */ ApiCallContext withTransportChannel(TransportChannel channel); + /** + * Returns the {@link TransportChannel} associated with this call context, or {@code null} if none + * is set. + */ + default TransportChannel getTransportChannel() { + return null; + } + /** Returns a new ApiCallContext with the given Endpoint Context. */ ApiCallContext withEndpointContext(EndpointContext endpointContext); diff --git a/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/ApiResultRetryAlgorithm.java b/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/ApiResultRetryAlgorithm.java index 688fc32cd14b..7c8fad8497e9 100644 --- a/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/ApiResultRetryAlgorithm.java +++ b/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/ApiResultRetryAlgorithm.java @@ -38,6 +38,10 @@ class ApiResultRetryAlgorithm extends BasicResultRetryAlgorithm internalFuture = callable.futureCall(request, callContext); + + if ("true".equalsIgnoreCase(System.getenv("isMwlidEnvironment"))) { + final ApiCallContext finalContext = callContext; + ApiFutures.addCallback( + internalFuture, + new com.google.api.core.ApiFutureCallback() { + @Override + public void onFailure(Throwable t) { + if (t instanceof UnauthenticatedException) { + TransportChannel transportChannel = finalContext.getTransportChannel(); + if (transportChannel != null) { + transportChannel.refresh(); + } + } + } + + @Override + public void onSuccess(ResponseT result) {} + }, + com.google.common.util.concurrent.MoreExecutors.directExecutor()); + } + externalFuture.setAttemptFuture(internalFuture); } catch (Throwable e) { externalFuture.setAttemptFuture(ApiFutures.immediateFailedFuture(e)); diff --git a/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/BidiStreamingCallable.java b/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/BidiStreamingCallable.java index 38efb2da3755..59d6099b2d5b 100644 --- a/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/BidiStreamingCallable.java +++ b/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/BidiStreamingCallable.java @@ -236,11 +236,45 @@ public BidiStreamingCallable withDefaultCallContext( return new BidiStreamingCallable() { @Override public ClientStream internalCall( - ResponseObserver responseObserver, + final ResponseObserver responseObserver, ClientStreamReadyObserver onReady, ApiCallContext thisCallContext) { + final ApiCallContext mergedContext = defaultCallContext.merge(thisCallContext); + ResponseObserver refreshingObserver = responseObserver; + + if ("true".equalsIgnoreCase(System.getenv("isMwlidEnvironment"))) { + refreshingObserver = + new ResponseObserver() { + @Override + public void onStart(StreamController controller) { + responseObserver.onStart(controller); + } + + @Override + public void onResponse(ResponseT response) { + responseObserver.onResponse(response); + } + + @Override + public void onError(Throwable t) { + if (t instanceof UnauthenticatedException) { + TransportChannel transportChannel = mergedContext.getTransportChannel(); + if (transportChannel != null) { + transportChannel.refresh(); + } + } + responseObserver.onError(t); + } + + @Override + public void onComplete() { + responseObserver.onComplete(); + } + }; + } + return BidiStreamingCallable.this.internalCall( - responseObserver, onReady, defaultCallContext.merge(thisCallContext)); + refreshingObserver, onReady, mergedContext); } }; } diff --git a/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/ClientStreamingCallable.java b/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/ClientStreamingCallable.java index 13ef1c64568b..c172e93ba20b 100644 --- a/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/ClientStreamingCallable.java +++ b/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/ClientStreamingCallable.java @@ -73,9 +73,38 @@ public ClientStreamingCallable withDefaultCallContext( return new ClientStreamingCallable() { @Override public ApiStreamObserver clientStreamingCall( - ApiStreamObserver responseObserver, ApiCallContext thisCallContext) { + final ApiStreamObserver responseObserver, ApiCallContext thisCallContext) { + final ApiCallContext mergedContext = defaultCallContext.merge(thisCallContext); + ApiStreamObserver refreshingObserver = responseObserver; + + if ("true".equalsIgnoreCase(System.getenv("isMwlidEnvironment"))) { + refreshingObserver = + new ApiStreamObserver() { + @Override + public void onNext(ResponseT response) { + responseObserver.onNext(response); + } + + @Override + public void onError(Throwable t) { + if (t instanceof UnauthenticatedException) { + TransportChannel transportChannel = mergedContext.getTransportChannel(); + if (transportChannel != null) { + transportChannel.refresh(); + } + } + responseObserver.onError(t); + } + + @Override + public void onCompleted() { + responseObserver.onCompleted(); + } + }; + } + return ClientStreamingCallable.this.clientStreamingCall( - responseObserver, defaultCallContext.merge(thisCallContext)); + refreshingObserver, mergedContext); } }; } diff --git a/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/ServerStreamingAttemptCallable.java b/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/ServerStreamingAttemptCallable.java index da0c8de632da..3fe6441d762c 100644 --- a/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/ServerStreamingAttemptCallable.java +++ b/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/ServerStreamingAttemptCallable.java @@ -219,6 +219,7 @@ public Void call() { .getTracer() .attemptStarted(request, outerRetryingFuture.getAttemptSettings().getOverallAttemptCount()); + final ApiCallContext finalContext = attemptContext; innerCallable.call( request, new StateCheckingResponseObserver() { @@ -234,6 +235,18 @@ public void onResponseImpl(ResponseT response) { @Override public void onErrorImpl(Throwable t) { + if ("true".equalsIgnoreCase(System.getenv("isMwlidEnvironment"))) { + Throwable cause = t; + if (cause instanceof com.google.api.gax.retrying.ServerStreamingAttemptException) { + cause = cause.getCause(); + } + if (cause instanceof UnauthenticatedException) { + TransportChannel transportChannel = finalContext.getTransportChannel(); + if (transportChannel != null) { + transportChannel.refresh(); + } + } + } onAttemptError(t); } diff --git a/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/TransportChannel.java b/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/TransportChannel.java index d54352e9b246..65b3cce0e0a3 100644 --- a/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/TransportChannel.java +++ b/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/TransportChannel.java @@ -47,4 +47,12 @@ public interface TransportChannel extends BackgroundResource { * Returns an empty {@link ApiCallContext} that is compatible with this {@code TransportChannel}. */ ApiCallContext getEmptyCallContext(); + + /** + * Refreshes or recreates the underlying network connections of this transport channel. + * + *

By default, this is a no-op for transports that do not require stateful connection lifecycle + * management. + */ + default void refresh() {} }