Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
dab0d4c
fix(rab): detach async transport session for background worker
nbayati Jun 12, 2026
fab83f1
test(auth): add unit tests for async transport clone and socket race …
nbayati Jun 12, 2026
3240f44
fix(auth): address gemini PR review comment and fix async rab refresh…
nbayati Jun 12, 2026
4312df8
fix(auth): address review comments for async RAB transport cloning an…
nbayati Jun 12, 2026
b36e583
fix(auth): preserve enterprise connection state in async transport cl…
nbayati Jun 12, 2026
d11fd64
make clone a private method
nbayati Jun 12, 2026
d203e75
refactor: extract RAB async request unwrapping into helpers and add u…
nbayati Jun 12, 2026
9eaed0b
refactor: add type annotations and suppress type checks for aiohttp s…
nbayati Jun 12, 2026
31404d7
address review comments
nbayati Jun 12, 2026
4a79bdd
fix mocking in failing unit test
nbayati Jun 12, 2026
74fcfac
add unit tests to regional access boundary utils
nbayati Jun 12, 2026
00ffcdb
fix(transport): fix resolver leak, Windows crash, and proxy errors in…
nbayati Jun 13, 2026
8153425
add unit tests to get full coverage on the new code paths
nbayati Jun 13, 2026
5e2fbc2
fix formatting issue
nbayati Jun 13, 2026
cef4bbc
fix: synchronously clone request transport to prevent ClientSession c…
nbayati Jun 15, 2026
6d0fd3a
docs: document limitations and errors in transport _clone docstrings
nbayati Jun 15, 2026
ebaa1d5
fix: support generic awaitables like Future in close_cloned_request c…
nbayati Jun 15, 2026
cae43cf
test: fix start_refresh_suppresses_request_clone_exception by asserti…
nbayati Jun 15, 2026
05ffb8a
fix: raise TransportError on closed transport calls in legacy aiohttp…
nbayati Jun 15, 2026
2d17e65
fix lint failure
nbayati Jun 15, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from google.auth import _helpers
from google.auth import environment_vars

if TYPE_CHECKING:
if TYPE_CHECKING: # pragma: NO COVER
import google.auth.credentials
import google.auth.transport

Expand Down Expand Up @@ -455,6 +455,61 @@ def start_refresh(self, credentials, request, rab_manager):
self._worker.start()


def _prepare_async_lookup_callable(request):
"""Unwraps a request callable, clones the transport, and returns the new callable.

Args:
request: The original request callable (e.g. functools.partial or raw Request).

Returns:
Tuple[Callable, Any, bool]: A tuple containing the new lookup callable, the
underlying request object, and a boolean indicating if it was cloned.
"""
is_partial = isinstance(request, functools.partial)
base_callable = request.func if is_partial else request

if not hasattr(base_callable, "_clone"):
return request, base_callable, False

cloned_callable = base_callable._clone()
is_cloned = cloned_callable is not base_callable

if is_partial:
new_request = functools.partial(
cloned_callable, *request.args, **request.keywords
)
else:
new_request = cloned_callable

return new_request, cloned_callable, is_cloned


async def _close_cloned_request(lookup_request, is_cloned):

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: It seems like _prepare_async_lookup_callable and _close_cloned_request would be a good fit for a context manager. That would let us encapsulate these three variables, and enforce automatic closing.

Gemini put this together:

@contextlib.asynccontextmanager
async def _managed_lookup_callable(request):
    """An async context manager that prepares a cloned lookup callable 
    and guarantees its transport is closed on exit.
    """
    lookup_callable, lookup_request, is_cloned = _prepare_async_lookup_callable(request)
    try:
        yield lookup_callable
    finally:
        await _close_cloned_request(lookup_request, is_cloned)


# ... Inside your class/function where _worker is defined:

async def _worker():
    try:
        async with _managed_lookup_callable(request) as lookup_callable:
            regional_access_boundary_info = (
                await credentials._lookup_regional_access_boundary(lookup_callable)
            )
    except Exception as e:
        if _helpers.is_logging_enabled(_LOGGER):
            _LOGGER.warning(
                "Failed regional access boundary lookup: %s", 
                e, 
                exc_info=True
            )
        regional_access_boundary_info = None

But this is just a suggestion that came to mind, I think it's fine to merge as-is too.

"""Safely closes the underlying cloned request transport, if applicable.

Args:
lookup_request (Any): The request object/transport to close.
is_cloned (bool): Whether the request was actually cloned.
"""
if not is_cloned or not hasattr(lookup_request, "close"):
return

is_async = False
try:
maybe_coro = lookup_request.close()
if is_async := inspect.isawaitable(maybe_coro):
await maybe_coro
except Exception as e:
if _helpers.is_logging_enabled(_LOGGER):
adapter_type = " asynchronous " if is_async else " "
_LOGGER.warning(
"Failed to cleanly close cloned%srequest transport: %s",
adapter_type,
e,
exc_info=True,
)


class _AsyncRegionalAccessBoundaryRefreshManager(object):
"""Manages a task for background refreshing of the Regional Access Boundary in async flows."""

Expand Down Expand Up @@ -491,11 +546,28 @@ def start_refresh(self, credentials, request, rab_manager):
# A refresh is already in progress.
return

try:
(
lookup_callable,
lookup_request,
is_cloned,
) = _prepare_async_lookup_callable(request)
except Exception as e:
if _helpers.is_logging_enabled(_LOGGER):
_LOGGER.warning(
"Synchronous cloning of request for Regional Access Boundary lookup failed: %s",
e,
exc_info=True,
)
rab_manager.process_regional_access_boundary_info(None)
return

async def _worker():
try:
# credentials._lookup_regional_access_boundary should be async in the async creds class
regional_access_boundary_info = (
await credentials._lookup_regional_access_boundary(request)
await credentials._lookup_regional_access_boundary(
lookup_callable
)
)
except Exception as e:
if _helpers.is_logging_enabled(_LOGGER):
Expand All @@ -505,6 +577,8 @@ async def _worker():
exc_info=True,
)
regional_access_boundary_info = None
finally:
await _close_cloned_request(lookup_request, is_cloned)

rab_manager.process_regional_access_boundary_info(
regional_access_boundary_info
Expand All @@ -514,7 +588,15 @@ async def _worker():
try:
self._worker_task = asyncio.create_task(coro)
except Exception:
# Clean up cloned request if task creation fails
coro.close()
try:
asyncio.get_running_loop().create_task(
_close_cloned_request(lookup_request, is_cloned)
)
except RuntimeError:
pass
rab_manager.process_regional_access_boundary_info(None)
raise


Expand Down
10 changes: 10 additions & 0 deletions packages/google-auth/google/auth/aio/transport/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,3 +142,13 @@ async def close(self) -> None:
Close the underlying session.
"""
raise NotImplementedError("close must be implemented.")

def _clone(self) -> "Request":
"""Creates a copy of this request adapter.

The base implementation returns `self` (an identical shared instance).

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I still think a name other than clone should be considered. Gemini suggests _isolate() or _branch(). But this doesn't matter too much if it's internal

Transport adapters that maintain internal connection pools or stateful
sessions must override this method to return an independent, detached
adapter instance.
"""
return self
Comment thread
nbayati marked this conversation as resolved.
82 changes: 81 additions & 1 deletion packages/google-auth/google/auth/aio/transport/aiohttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
else:
try:
from aiohttp import ClientTimeout
except (ImportError, AttributeError):
except (ImportError, AttributeError): # pragma: NO COVER
ClientTimeout = None

_LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -203,3 +203,83 @@ async def close(self) -> None:
if not self._closed and self._session:
await self._session.close()
self._closed = True

def _clone(self) -> "Request":
"""Creates an independent copy of this request adapter.

Clones the connection settings, trace configurations, and session defaults
(headers, cookies, basic auth, and timeouts).

Only standard `aiohttp.TCPConnector` and `aiohttp.UnixConnector` connectors
are supported. The DNS resolver is not copied to avoid closing shared resolver
resources.

Returns:
google.auth.aio.transport.aiohttp.Request: A new request adapter.

Raises:
google.auth.exceptions.TransportError: If the transport is closed, or if the
session uses an unsupported connector.
"""
if self._closed:
raise exceptions.TransportError("Cannot clone a closed transport.")

if not self._session:
new_session = aiohttp.ClientSession(
auto_decompress=False,
trust_env=True,
)
return Request(session=new_session)

session_kwargs: dict = {
"auto_decompress": False,
"trust_env": getattr(self._session, "_trust_env", True),
}

# Copy underlying connection pool settings (SSL context, IP bindings, limits).
orig_connector = getattr(self._session, "_connector", None)
if orig_connector and not orig_connector.closed:
if isinstance(orig_connector, aiohttp.TCPConnector):

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The _clone() implementation explicitly checks only for standard TCPConnector and UnixConnector instances. If the original session is configured with a custom, proxy, or subclassed connector (such as corporate SOCKS or tunneling proxies), the check falls through and the cloned session is created with a default, direct-connection TCPConnector.

This silently drops the proxy/custom configuration and routes traffic directly over the public internet, which will fail or violate security constraints in enterprise/isolated cloud environments. We should either explicitly support proxy preservation or raise a clear transport exception if an unsupported custom connector is detected.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a really good point.

We can't really support proxy preservation because third-party aiohttp connectors have arbitrary, unknown constructor signatures (meaning we have no way to instantiate a fresh detached copy of them dynamically), and simply shallow-copying the existing connector is unsafe due to shared socket pools. This leaves us two options: fallback to re-using the customer transport and hope that we don't encounter the bug this PR is trying to fix, or raise the exception as you suggested and accept this as a limitation of RAB.

I've decided not to fallback to re-using the customer's transport if we can't clone it, because it's not just that the RAB call would fail, but also there's another risk: if the foreground task closes the session while the background worker is actively reading from it, the forceful socket truncation mid-flight can leave complex corporate proxy connections in a hung or corrupted state, which means that the affects won't be limited to our RAB calls. So I've added the else: raise exceptions.TransportError(...) block, as raising the error here is the safest path. The exception will trigger the 15-minute cooldown and allow the user's main request to proceed safely.

I thought about disabling RAB permanently if we can't clone the transport (thinking what's the point of entering cooldown if we're going to keep trying to clone it and fail), but decided against it. I realized that because credentials objects are frequently instantiated globally and shared across multiple different clients and API surfaces, there's a chance that the next call would be executed over entirely different transports, making the RAB call possible.

# We explicitly do not copy the resolver. The connector
# owns the resolver, and closing the cloned session would
# close the shared resolver, breaking the original session.
session_kwargs["connector"] = aiohttp.TCPConnector(
ssl=getattr(orig_connector, "_ssl", None), # type: ignore
limit=getattr(orig_connector, "_limit", 100),
limit_per_host=getattr(orig_connector, "_limit_per_host", 0),
force_close=getattr(orig_connector, "_force_close", False),
local_addr=getattr(orig_connector, "_local_addr", None),
)
elif getattr(aiohttp, "UnixConnector", None) and isinstance(
orig_connector, getattr(aiohttp, "UnixConnector")
):
path = getattr(orig_connector, "_path", None)
if path:
session_kwargs["connector"] = aiohttp.UnixConnector(
path=path,
limit=getattr(orig_connector, "_limit", 100),
force_close=getattr(orig_connector, "_force_close", False),
)
else:
raise exceptions.TransportError(
f"Unsupported connector type for cloning: {type(orig_connector)}"
)

# Preserve distributed tracing configurations.
trace_configs = getattr(self._session, "_trace_configs", None)
if trace_configs:
session_kwargs["trace_configs"] = list(trace_configs)

# Copy session-level defaults (headers, cookies, auth, timeout).
for attr_name, kwarg_name in [
("_default_headers", "headers"),
("_cookie_jar", "cookie_jar"),
("_default_auth", "auth"),
("_timeout", "timeout"),
("_json_serialize", "json_serialize"),
]:
val = getattr(self._session, attr_name, None)
if val is not None:
session_kwargs[kwarg_name] = val

return Request(session=aiohttp.ClientSession(**session_kwargs)) # type: ignore
90 changes: 90 additions & 0 deletions packages/google-auth/google/auth/transport/_aiohttp_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ def __init__(self, session=None):
"Client sessions with auto_decompress=True are not supported."
)
self.session = session
self._closed = False

async def __call__(
self,
Expand Down Expand Up @@ -184,6 +185,9 @@ async def __call__(
"""

try:
if getattr(self, "_closed", False):
raise exceptions.TransportError("session is closed.")

if self.session is None: # pragma: NO COVER
self.session = aiohttp.ClientSession(
auto_decompress=False
Expand All @@ -203,6 +207,92 @@ async def __call__(
new_exc = exceptions.TransportError(caught_exc)
raise new_exc from caught_exc

def _clone(self):
"""Creates an independent copy of this request adapter.

Clones the connection settings, trace configurations, and session defaults
(headers, cookies, basic auth, and timeouts).

Only standard `aiohttp.TCPConnector` and `aiohttp.UnixConnector` connectors
are supported. The DNS resolver is not copied to avoid closing shared resolver
resources.

Returns:
google.auth.transport._aiohttp_requests.Request: A new request adapter.

Raises:
google.auth.exceptions.TransportError: If the transport is closed, or if the
session uses an unsupported connector.
"""
if getattr(self, "_closed", False):
raise exceptions.TransportError("Cannot clone a closed transport.")

if not self.session:
new_session = aiohttp.ClientSession(
auto_decompress=False,
trust_env=True,
)
return Request(session=new_session)

session_kwargs: dict = {
"auto_decompress": False,
"trust_env": getattr(self.session, "_trust_env", True),
}

# Copy underlying connection pool settings (SSL context, IP bindings, limits).
orig_connector = getattr(self.session, "_connector", None)
if orig_connector and not getattr(orig_connector, "closed", True):
if isinstance(orig_connector, aiohttp.TCPConnector):
# We explicitly do not copy the resolver. The connector
# owns the resolver, and closing the cloned session would
# close the shared resolver, breaking the original session.
session_kwargs["connector"] = aiohttp.TCPConnector(
ssl=getattr(orig_connector, "_ssl", None), # type: ignore
limit=getattr(orig_connector, "_limit", 100),
limit_per_host=getattr(orig_connector, "_limit_per_host", 0),
force_close=getattr(orig_connector, "_force_close", False),
local_addr=getattr(orig_connector, "_local_addr", None),
)
elif getattr(aiohttp, "UnixConnector", None) and isinstance(
orig_connector, getattr(aiohttp, "UnixConnector")
):
path = getattr(orig_connector, "_path", None)
if path:
session_kwargs["connector"] = aiohttp.UnixConnector(
path=path,
limit=getattr(orig_connector, "_limit", 100),
force_close=getattr(orig_connector, "_force_close", False),
)
else:
raise exceptions.TransportError(
f"Unsupported connector type for cloning: {type(orig_connector)}"
)

# Preserve distributed tracing configurations.
trace_configs = getattr(self.session, "_trace_configs", None)
if trace_configs:
session_kwargs["trace_configs"] = list(trace_configs)

# Copy session-level defaults (headers, cookies, auth, timeout).
for attr_name, kwarg_name in [
("_default_headers", "headers"),
("_cookie_jar", "cookie_jar"),
("_default_auth", "auth"),
("_timeout", "timeout"),
("_json_serialize", "json_serialize"),
]:
val = getattr(self.session, attr_name, None)
if val is not None:
session_kwargs[kwarg_name] = val

return Request(session=aiohttp.ClientSession(**session_kwargs)) # type: ignore

async def close(self):

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the legacy _aiohttp_requests transport, calling call on an already-closed instance propagates a raw aiohttp RuntimeError rather than a wrapped google.auth.exceptions.TransportError because call does not inspect self._closed (unlike the modern successor in aio/transport/aiohttp.py). While this legacy adapter is deprecated/internal and closed reuse is non-standard, checking self._closed in call and raising TransportError directly would align exception behaviors.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. Updated Request.__call__ in the legacy _aiohttp_requests transport to check if the adapter is closed and raise a wrapped TransportError directly, matching the exception behavior of the modern adapter.

"""Cleanly release the underlying aiohttp ClientSession resources."""
if not getattr(self, "_closed", False) and self.session:
await self.session.close()
self._closed = True


class AuthorizedSession(aiohttp.ClientSession):
"""This is an async implementation of the Authorized Session class. We utilize an
Expand Down
Loading