diff --git a/wool/src/wool/runtime/worker/session.py b/wool/src/wool/runtime/worker/session.py
index de603fea..76486b16 100644
--- a/wool/src/wool/runtime/worker/session.py
+++ b/wool/src/wool/runtime/worker/session.py
@@ -680,6 +680,18 @@ def _on_done(t: asyncio.Task):
                 # nudge to escape an otherwise-indefinite await on
                 # a frame that will never arrive.
                 response_queue.close()
+                # Drop the session's strong reference to the completed
+                # worker task. ``_run`` closes over ``self``, so
+                # ``session -> _worker_task -> coro -> session`` forms a
+                # cycle reclaimable only by the cyclic GC; clearing the
+                # reference here breaks it so refcounting reclaims the
+                # session, task, and per-fork contexts promptly instead
+                # of letting them accumulate between GC passes. Safe
+                # because ``_worker_task`` is already ``None`` before
+                # scheduling and on a scheduling failure, so every reader
+                # must already tolerate ``None``; clearing it here adds no
+                # state a correct reader isn't already guarding against.
+                self._worker_task = None
 
             task.add_done_callback(_on_done)
 
diff --git a/wool/tests/runtime/worker/test_session.py b/wool/tests/runtime/worker/test_session.py
index 444c1dc3..6d42fcde 100644
--- a/wool/tests/runtime/worker/test_session.py
+++ b/wool/tests/runtime/worker/test_session.py
@@ -11,7 +11,9 @@
 from __future__ import annotations
 
 import asyncio
+import gc
 import threading
+import weakref
 from uuid import uuid4
 
 import pytest
@@ -909,6 +911,189 @@ async def test___aiter___should_exit_cleanly_when_request_stream_ends(
         # otherwise.
         assert results == []
 
+    @pytest.mark.asyncio
+    async def test___aiter___should_reclaim_worker_task_on_natural_completion(
+        self, worker_loop, mock_worker_proxy_cache
+    ):
+        """Test a completed worker driver task is reclaimed by
+        refcounting alone when a coroutine routine runs to natural
+        completion.
+
+        Given:
+            A handler driving a coroutine routine, with the session
+            strongly retained for the whole test (modelling a completed
+            session a separate holder keeps alive) and automatic cyclic
+            GC disabled
+        When:
+            The dispatch runs to natural (non-cancelled) completion and
+            the last external reference to the worker driver task is
+            dropped
+        Then:
+            It should let refcounting reclaim the worker driver task — a
+            weakref to it clears without a forced ``gc.collect()`` and
+            without automatic collection — proving the completed task's
+            done-callback severed the ``session -> worker task ->
+            coroutine -> session`` cycle so the retained session no
+            longer pins the task (and its per-fork contexts) through the
+            reference cycle until the next GC pass.
+        """
+        # Arrange
+        task = _make_task(_coro_returning_default)
+        stream = _stream(_request_for(task))
+        handler = DispatchSession(stream, worker_loop)
+        await handler.__aenter__()
+        try:
+            iterator = aiter(handler)
+
+            # Capture the worker driver task as a public observable while
+            # it is suspended on the request queue, before the dispatch
+            # drives it — ``asyncio.all_tasks`` exposes the scheduled
+            # worker task and the driver runs ``_run``, so its coroutine
+            # qualname distinguishes it from any in-flight per-step task.
+            driver = None
+            for _ in range(500):
+                drivers = [
+                    t
+                    for t in asyncio.all_tasks(loop=worker_loop)
+                    if "_run" in t.get_coro().__qualname__
+                ]
+                if drivers:
+                    driver = drivers[0]
+                    break
+                await asyncio.sleep(0.01)
+            assert driver is not None, "worker driver task was never scheduled"
+            worker_task_ref = weakref.ref(driver)
+            del driver, drivers
+
+            # Act — with automatic GC disabled, only refcounting can
+            # reclaim the finished task. Drive the coroutine to natural
+            # completion, release the produced responses, and drain so the
+            # driver returns normally (its done-callback drops the
+            # session's reference to it).
+            # Record the collector state so the ``finally`` restores it
+            # instead of unconditionally re-enabling GC for later tests.
+            gc_was_enabled = gc.isenabled()
+            gc.disable()
+            try:
+                results = [r async for r in iterator]
+                assert len(results) == 1
+                assert results[0].payload == "coroutine_value"
+                del results
+                await asyncio.wait_for(handler.drain(), timeout=2.0)
+
+                # Assert — the session is still strongly held; if it kept
+                # pointing at the task (the pre-fix cycle) the weakref
+                # would survive with GC off. The fix drops that reference
+                # on natural completion, so refcounting clears the weakref.
+                for _ in range(200):
+                    if worker_task_ref() is None:
+                        break
+                    await asyncio.sleep(0.01)
+                assert worker_task_ref() is None, (
+                    "a retained session must not pin its naturally "
+                    "completed worker driver task — refcounting should "
+                    "reclaim it without a forced gc.collect() or an "
+                    "automatic collection"
+                )
+            finally:
+                if gc_was_enabled:
+                    gc.enable()
+        finally:
+            await handler.__aexit__(None, None, None)
+
+    @pytest.mark.asyncio
+    async def test___aiter___should_reclaim_worker_task_without_forced_gc_when_retained(
+        self, worker_loop, mock_worker_proxy_cache
+    ):
+        """Test a completed worker driver task is reclaimed by
+        refcounting alone even while the session is retained.
+
+        Given:
+            A handler whose worker driver task is live on the worker
+            loop, with the session strongly retained for the whole test
+            (modelling a completed session a separate holder keeps
+            alive) and automatic cyclic GC disabled
+        When:
+            The dispatch completes and the last external reference to
+            the worker driver task is dropped
+        Then:
+            It should let refcounting reclaim the worker driver task
+            immediately — a weakref to it clears without a forced
+            ``gc.collect()`` and without automatic collection — proving
+            the retained session no longer pins the task (and its
+            per-fork contexts) through the reference cycle until the
+            next GC pass.
+        """
+        # Arrange — a never-returning coroutine keeps the worker driver
+        # task live on the worker loop long enough to capture a public
+        # handle to it via ``asyncio.all_tasks``.
+        task = _make_task(_slow_coro)
+        stream = _stream(_request_for(task))
+        handler = DispatchSession(stream, worker_loop)
+        await handler.__aenter__()
+        try:
+            iterator = aiter(handler)
+            pull = asyncio.ensure_future(anext(iterator))
+
+            # Capture the worker driver task as a public observable —
+            # ``asyncio.all_tasks`` exposes the scheduled worker task and
+            # the driver runs ``_run``, so its coroutine qualname
+            # distinguishes it from any in-flight per-step task.
+            driver = None
+            for _ in range(500):
+                drivers = [
+                    t
+                    for t in asyncio.all_tasks(loop=worker_loop)
+                    if "_run" in t.get_coro().__qualname__
+                ]
+                if drivers:
+                    driver = drivers[0]
+                    break
+                await asyncio.sleep(0.01)
+            assert driver is not None, "worker driver task was never scheduled"
+            worker_task_ref = weakref.ref(driver)
+            del driver, drivers
+
+            # Act — with automatic GC disabled, only refcounting can
+            # reclaim the finished task. Complete the dispatch (cancel
+            # drives the worker driver task to completion, whose
+            # done-callback drops the session's reference to it) and
+            # release the last external handle to the iterator's pull.
+            # Record the collector state so the ``finally`` restores it
+            # instead of unconditionally re-enabling GC for later tests.
+            gc_was_enabled = gc.isenabled()
+            gc.disable()
+            try:
+                await handler.cancel()
+                with pytest.raises(asyncio.CancelledError):
+                    await pull
+                del pull
+                await asyncio.wait_for(handler.drain(), timeout=2.0)
+
+                # Assert — the session is still strongly held; if it kept
+                # pointing at the task (the pre-fix cycle) the weakref
+                # would survive with GC off. The fix drops that
+                # reference, so refcounting clears the weakref.
+                for _ in range(200):
+                    if worker_task_ref() is None:
+                        break
+                    await asyncio.sleep(0.01)
+                assert worker_task_ref() is None, (
+                    "a retained session must not pin its completed worker "
+                    "driver task — refcounting should reclaim it without a "
+                    "forced gc.collect() or an automatic collection"
+                )
+            finally:
+                if gc_was_enabled:
+                    gc.enable()
+
+            # The session was strongly referenced throughout, so the
+            # reclamation above is attributable to the severed cycle, not
+            # to the session itself being collected.
+            assert handler is not None
+        finally:
+            await handler.__aexit__(None, None, None)
+
     # -- __aexit__ --------------------------------------------------------
 
     @pytest.mark.asyncio