diff --git a/wool/src/wool/runtime/context/factory.py b/wool/src/wool/runtime/context/factory.py index 9725afe9..9b8bb1d5 100644 --- a/wool/src/wool/runtime/context/factory.py +++ b/wool/src/wool/runtime/context/factory.py @@ -60,7 +60,22 @@ def context_is_armed(context: contextvars.Context) -> bool: # the context for every task it creates and registers it here — not only # explicitly-passed ones — so re-passing a live task's own context to a # second create_task is caught. -_task_contexts: dict[int, asyncio.Future[Any]] = {} +# +# Values are held *weakly*: an entry drops the moment its task is +# garbage-collected, even if the `_release` done-callback never fires +# (e.g., the worker loop was torn down before the callback ran, so the +# `add_done_callback` was stranded). This keeps the registry bounded to +# in-flight dispatches regardless of the worker event-loop lifecycle: +# the weak values are the primary reclaim bound, and `_release` is a +# best-effort eager cleanup rather than the sole reclaim path (see its +# docstring for the backstop role it also plays). A live weak entry +# implies its task is alive, and a live task pins its context, so the id +# cannot be reused while the entry resolves. `_PENDING` is strongly +# referenced at its definition, so a reserved slot's weak value never +# evaporates from under the reservation logic. +_task_contexts: weakref.WeakValueDictionary[int, asyncio.Future[Any]] = ( + weakref.WeakValueDictionary() +) class _PendingSentinel: diff --git a/wool/tests/runtime/context/test_factory.py b/wool/tests/runtime/context/test_factory.py index 9ec99c2b..0465bfc4 100644 --- a/wool/tests/runtime/context/test_factory.py +++ b/wool/tests/runtime/context/test_factory.py @@ -3,6 +3,7 @@ import gc import logging import uuid +import weakref import pytest import pytest_asyncio @@ -534,19 +535,26 @@ async def child() -> None: ] +@pytest.mark.parametrize("collect_between", [False, True], ids=["immediate", "after_gc"]) @pytest.mark.asyncio -async def test_install_task_factory_should_raise_when_context_shared_across_live_tasks(): +async def test_install_task_factory_should_raise_when_context_shared_across_live_tasks( + collect_between: bool, +): """Test the factory rejects one contextvars.Context shared by two tasks. Given: An armed chain with Wool's task factory installed and a - live task created with an explicit contextvars.Context. + live task created with an explicit contextvars.Context, + optionally after a garbage collection while that task is in + flight. When: A second task is created with that same context object while the first is still running. Then: - It should raise wool.ChainContention — two tasks - cannot interleave on one context's chain context. + It should raise wool.ChainContention — two tasks cannot + interleave on one context's chain context, and a collection + must not evict the live task's registry entry because the live + task pins its own registry value. """ # Arrange var = ContextVar(_unique("shared_ctx")) @@ -565,6 +573,11 @@ async def body() -> None: # Act & assert try: + if collect_between: + # Let the first task start, then force a collection: the + # weak registry must keep the live task's entry. + await asyncio.sleep(0) + gc.collect() with pytest.raises(ChainContention): loop.create_task(body(), context=shared) finally: @@ -1009,6 +1022,134 @@ async def user_coro() -> None: # pragma: no cover — never awaited loop.set_task_factory(None) +def test_install_task_factory_should_evict_task_when_release_stranded_and_gc(): + """Test a registered task is reclaimed when its release callback is stranded. + + Given: + Wool's task factory installed on a dedicated event loop and an + armed task registered in the chain-contention registry whose + per-task release callback never fires — the loop is torn down + before it can run, the worker-loop teardown scenario that + strands the callback. + When: + The last strong reference to the task is dropped and a garbage + collection is forced. + Then: + It should be reclaimed and its registry entry evicted — the + registry holds tasks weakly, so worker bookkeeping stays bounded + even when the release callback never runs, rather than leaking a + done-but-pinned task per dispatch. + """ + # Arrange + loop = asyncio.new_event_loop() + install_task_factory(loop) + var = ContextVar(_unique("stranded_release")) + tracker: dict[str, object] = {} + + async def arm_and_register() -> None: + # Arm the chain, then register a child under the armed context. + # The child is never stepped: run_until_complete returns the + # moment this coroutine finishes, so the child's first step — + # and therefore its release callback — never runs, exactly as + # when a worker loop is torn down mid-dispatch. + var.set("x") + armed_context = contextvars.copy_context() + + async def child() -> None: # pragma: no cover — never stepped + return None + + async def probe() -> None: # pragma: no cover — rejected before it runs + return None + + task = loop.create_task(child(), context=armed_context) + tracker["ref"] = weakref.ref(task) + # Precondition: the chain-contention registry is private with no + # public surface, so prove it actually tracks the task before + # asserting eviction. While the task is live, re-passing its + # armed context is rejected as chain contention — the public + # proof that the entry exists. A strong-dict registry would + # additionally pin the task alive past collection. + try: + loop.create_task(probe(), context=armed_context) + except ChainContention: + tracker["registered"] = True + + try: + loop.run_until_complete(arm_and_register()) + finally: + # Tear the loop down before the stranded release callback can + # run. This drops the loop's scheduled first-step handle — the + # only strong reference to the task besides the weak registry. + loop.set_exception_handler(lambda _loop, _context: None) + loop.close() + + # Act — collect now that the loop and every strong reference to the + # task are gone. + gc.collect() + + # Assert + assert tracker.get("registered") is True + assert tracker["ref"]() is None # type: ignore[operator] + + +@pytest.mark.asyncio +async def test_install_task_factory_should_keep_pending_reservation_across_gc(): + """Test the pending-slot reservation survives a garbage collection. + + Given: + Wool's task factory composed over an inner factory that, while + it runs, forces a garbage collection and then re-enters task + creation with the same armed context whose slot Wool has just + reserved with its pending sentinel. + When: + A task is created with that armed context, driving the inner + factory through the reservation window. + Then: + It should raise wool.ChainContention — the pending reservation + is intact across the collection because the sentinel is a + strongly-held module singleton — and the original task should + still run to completion once the slot is populated. + """ + # Arrange + loop = asyncio.get_running_loop() + var = ContextVar(_unique("pending_reservation")) + var.set("x") + armed_context = contextvars.copy_context() + observed: dict[str, object] = {} + + async def probe() -> None: # pragma: no cover — rejected before it runs + return None + + def reservation_probe( + inner_loop: asyncio.AbstractEventLoop, + coro, + **kwargs, + ) -> asyncio.Task: + # Wool reserved the armed context's slot with its pending + # sentinel before delegating here. Force a collection, then try + # to create a second task on the same armed context: the pending + # reservation must still register as a live owner. + gc.collect() + try: + inner_loop.create_task(probe(), context=armed_context) + except ChainContention: + observed["contention"] = True + return asyncio.Task(coro, loop=inner_loop, **kwargs) + + loop.set_task_factory(reservation_probe) + install_task_factory() + + async def body() -> int: + return 7 + + # Act + result = await loop.create_task(body(), context=armed_context) + + # Assert + assert observed.get("contention") is True + assert result == 7 + + @pytest.mark.asyncio async def test_to_thread_should_return_result_when_positional_args(): """Test wool.to_thread runs the callable and returns its result.