Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions wool/src/wool/runtime/worker/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -680,6 +680,18 @@ def _on_done(t: asyncio.Task):
# nudge to escape an otherwise-indefinite await on
# a frame that will never arrive.
response_queue.close()
# Drop the session's strong reference to the completed
# worker task. ``_run`` closes over ``self``, so
# ``session -> _worker_task -> coro -> session`` forms a
# cycle reclaimable only by the cyclic GC; clearing the
# reference here breaks it so refcounting reclaims the
# session, task, and per-fork contexts promptly instead
# of letting them accumulate between GC passes. Safe
# because ``_worker_task`` is already ``None`` before
# scheduling and on a scheduling failure, so every reader
# must already tolerate ``None``; clearing it here adds no
# state a correct reader isn't already guarding against.
self._worker_task = None

task.add_done_callback(_on_done)

Expand Down
177 changes: 177 additions & 0 deletions wool/tests/runtime/worker/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
from __future__ import annotations

import asyncio
import gc
import threading
import weakref
from uuid import uuid4

import pytest
Expand Down Expand Up @@ -909,6 +911,181 @@ async def test___aiter___should_exit_cleanly_when_request_stream_ends(
# otherwise.
assert results == []

@pytest.mark.asyncio
async def test___aiter___should_reclaim_worker_task_on_natural_completion(
    self, worker_loop, mock_worker_proxy_cache
):
    """Test a completed worker driver task is reclaimed by
    refcounting alone when a coroutine routine runs to natural
    completion.

    Given:
        A handler driving a coroutine routine, with the session
        strongly retained for the whole test (modelling a completed
        session a separate holder keeps alive) and automatic cyclic
        GC disabled
    When:
        The dispatch runs to natural (non-cancelled) completion and
        the last external reference to the worker driver task is
        dropped
    Then:
        It should let refcounting reclaim the worker driver task — a
        weakref to it clears without a forced ``gc.collect()`` and
        without automatic collection — proving the completed task's
        done-callback severed the ``session -> worker task ->
        coroutine -> session`` cycle so the retained session no
        longer pins the task (and its per-fork contexts) through the
        reference cycle until the next GC pass.
    """
    # Arrange
    task = _make_task(_coro_returning_default)
    stream = _stream(_request_for(task))
    handler = DispatchSession(stream, worker_loop)
    await handler.__aenter__()
    try:
        iterator = aiter(handler)

        # Capture the worker driver task as a public observable while
        # it is suspended on the request queue, before the dispatch
        # drives it — ``asyncio.all_tasks`` exposes the scheduled
        # worker task and the driver runs ``_run``, so its coroutine
        # qualname distinguishes it from any in-flight per-step task.
        driver = None
        for _ in range(500):
            drivers = [
                t
                for t in asyncio.all_tasks(loop=worker_loop)
                if "_run" in t.get_coro().__qualname__
            ]
            if drivers:
                driver = drivers[0]
                break
            await asyncio.sleep(0.01)
        assert driver is not None, "worker driver task was never scheduled"
        worker_task_ref = weakref.ref(driver)
        # Drop the test's own strong references so that only the
        # session (and the event loop machinery) can keep the task
        # alive from here on.
        del driver, drivers

        # Act — with automatic GC disabled, only refcounting can
        # reclaim the finished task. Drive the coroutine to natural
        # completion, release the produced responses, and drain so the
        # driver returns normally (its done-callback drops the
        # session's reference to it).
        #
        # Snapshot the collector state first so the ``finally`` below
        # restores whatever the test inherited instead of
        # unconditionally switching automatic GC back on.
        gc_was_enabled = gc.isenabled()
        gc.disable()
        try:
            results = [r async for r in iterator]
            assert len(results) == 1
            assert results[0].payload == "coroutine_value"
            del results
            await asyncio.wait_for(handler.drain(), timeout=2.0)

            # Assert — the session is still strongly held; if it kept
            # pointing at the task (the pre-fix cycle) the weakref
            # would survive with GC off. The fix drops that reference
            # on natural completion, so refcounting clears the weakref.
            for _ in range(200):
                if worker_task_ref() is None:
                    break
                await asyncio.sleep(0.01)
            assert worker_task_ref() is None, (
                "a retained session must not pin its naturally "
                "completed worker driver task — refcounting should "
                "reclaim it without a forced gc.collect() or an "
                "automatic collection"
            )
        finally:
            if gc_was_enabled:
                gc.enable()
    finally:
        await handler.__aexit__(None, None, None)

@pytest.mark.asyncio
async def test___aiter___should_reclaim_worker_task_without_forced_gc_when_retained(
    self, worker_loop, mock_worker_proxy_cache
):
    """Test a completed worker driver task is reclaimed by
    refcounting alone even while the session is retained.

    Given:
        A handler whose worker driver task is live on the worker
        loop, with the session strongly retained for the whole test
        (modelling a completed session a separate holder keeps
        alive) and automatic cyclic GC disabled
    When:
        The dispatch completes and the last external reference to
        the worker driver task is dropped
    Then:
        It should let refcounting reclaim the worker driver task
        immediately — a weakref to it clears without a forced
        ``gc.collect()`` and without automatic collection — proving
        the retained session no longer pins the task (and its
        per-fork contexts) through the reference cycle until the
        next GC pass.
    """
    # Arrange — a never-returning coroutine keeps the worker driver
    # task live on the worker loop long enough to capture a public
    # handle to it via ``asyncio.all_tasks``.
    task = _make_task(_slow_coro)
    stream = _stream(_request_for(task))
    handler = DispatchSession(stream, worker_loop)
    await handler.__aenter__()
    try:
        iterator = aiter(handler)
        pull = asyncio.ensure_future(anext(iterator))

        # Capture the worker driver task as a public observable —
        # ``asyncio.all_tasks`` exposes the scheduled worker task and
        # the driver runs ``_run``, so its coroutine qualname
        # distinguishes it from any in-flight per-step task.
        driver = None
        for _ in range(500):
            drivers = [
                t
                for t in asyncio.all_tasks(loop=worker_loop)
                if "_run" in t.get_coro().__qualname__
            ]
            if drivers:
                driver = drivers[0]
                break
            await asyncio.sleep(0.01)
        assert driver is not None, "worker driver task was never scheduled"
        worker_task_ref = weakref.ref(driver)
        # Drop the test's own strong references so that only the
        # session (and the event loop machinery) can keep the task
        # alive from here on.
        del driver, drivers

        # Act — with automatic GC disabled, only refcounting can
        # reclaim the finished task. Complete the dispatch (cancel
        # drives the worker driver task to completion, whose
        # done-callback drops the session's reference to it) and
        # release the last external handle to the iterator's pull.
        #
        # Snapshot the collector state first so the ``finally`` below
        # restores whatever the test inherited instead of
        # unconditionally switching automatic GC back on.
        gc_was_enabled = gc.isenabled()
        gc.disable()
        try:
            await handler.cancel()
            with pytest.raises(asyncio.CancelledError):
                await pull
            del pull
            await asyncio.wait_for(handler.drain(), timeout=2.0)

            # Assert — the session is still strongly held; if it kept
            # pointing at the task (the pre-fix cycle) the weakref
            # would survive with GC off. The fix drops that
            # reference, so refcounting clears the weakref.
            for _ in range(200):
                if worker_task_ref() is None:
                    break
                await asyncio.sleep(0.01)
            assert worker_task_ref() is None, (
                "a retained session must not pin its completed worker "
                "driver task — refcounting should reclaim it without a "
                "forced gc.collect() or an automatic collection"
            )
        finally:
            if gc_was_enabled:
                gc.enable()

        # The session was strongly referenced throughout, so the
        # reclamation above is attributable to the severed cycle, not
        # to the session itself being collected.
        assert handler is not None
    finally:
        await handler.__aexit__(None, None, None)

# -- __aexit__ --------------------------------------------------------

@pytest.mark.asyncio
Expand Down
Loading