Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion wool/src/wool/runtime/context/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,22 @@ def context_is_armed(context: contextvars.Context) -> bool:
# the context for every task it creates and registers it here — not only
# explicitly-passed ones — so re-passing a live task's own context to a
# second create_task is caught.
_task_contexts: dict[int, asyncio.Future[Any]] = {}
#
# Values are held *weakly*: an entry drops the moment its task is
# garbage-collected, even if the `_release` done-callback never fires
# (e.g., the worker loop was torn down before the callback ran, so the
# `add_done_callback` was stranded). This keeps the registry bounded to
# in-flight dispatches regardless of the worker event-loop lifecycle:
# the weak values are the primary reclaim bound, and `_release` is a
# best-effort eager cleanup rather than the sole reclaim path (see its
# docstring for the backstop role it also plays). A live weak entry
# implies its task is alive, and a live task pins its context, so the id
# cannot be reused while the entry resolves. `_PENDING` is strongly
# referenced at its definition, so a reserved slot's weak value never
# evaporates from under the reservation logic.
_task_contexts: weakref.WeakValueDictionary[int, asyncio.Future[Any]] = (
weakref.WeakValueDictionary()
)


class _PendingSentinel:
Expand Down
149 changes: 145 additions & 4 deletions wool/tests/runtime/context/test_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import gc
import logging
import uuid
import weakref

import pytest
import pytest_asyncio
Expand Down Expand Up @@ -534,19 +535,26 @@ async def child() -> None:
]


@pytest.mark.parametrize("collect_between", [False, True], ids=["immediate", "after_gc"])
@pytest.mark.asyncio
async def test_install_task_factory_should_raise_when_context_shared_across_live_tasks():
async def test_install_task_factory_should_raise_when_context_shared_across_live_tasks(
collect_between: bool,
):
"""Test the factory rejects one contextvars.Context shared by two tasks.

Given:
An armed chain with Wool's task factory installed and a
live task created with an explicit contextvars.Context.
live task created with an explicit contextvars.Context,
optionally after a garbage collection while that task is in
flight.
When:
A second task is created with that same context object
while the first is still running.
Then:
It should raise wool.ChainContention — two tasks
cannot interleave on one context's chain context.
It should raise wool.ChainContention — two tasks cannot
interleave on one context's chain context, and a collection
must not evict the live task's registry entry because the live
task pins its own registry value.
"""
# Arrange
var = ContextVar(_unique("shared_ctx"))
Expand All @@ -565,6 +573,11 @@ async def body() -> None:

# Act & assert
try:
if collect_between:
# Let the first task start, then force a collection: the
# weak registry must keep the live task's entry.
await asyncio.sleep(0)
gc.collect()
with pytest.raises(ChainContention):
loop.create_task(body(), context=shared)
finally:
Expand Down Expand Up @@ -1009,6 +1022,134 @@ async def user_coro() -> None: # pragma: no cover — never awaited
loop.set_task_factory(None)


def test_install_task_factory_should_evict_task_when_release_stranded_and_gc():
"""Test a registered task is reclaimed when its release callback is stranded.

Given:
Wool's task factory installed on a dedicated event loop and an
armed task registered in the chain-contention registry whose
per-task release callback never fires — the loop is torn down
before it can run, the worker-loop teardown scenario that
strands the callback.
When:
The last strong reference to the task is dropped and a garbage
collection is forced.
Then:
It should be reclaimed and its registry entry evicted — the
registry holds tasks weakly, so worker bookkeeping stays bounded
even when the release callback never runs, rather than leaking a
done-but-pinned task per dispatch.
"""
# Arrange
loop = asyncio.new_event_loop()
install_task_factory(loop)
var = ContextVar(_unique("stranded_release"))
tracker: dict[str, object] = {}

async def arm_and_register() -> None:
# Arm the chain, then register a child under the armed context.
# The child is never stepped: run_until_complete returns the
# moment this coroutine finishes, so the child's first step —
# and therefore its release callback — never runs, exactly as
# when a worker loop is torn down mid-dispatch.
var.set("x")
armed_context = contextvars.copy_context()

async def child() -> None: # pragma: no cover — never stepped
return None

async def probe() -> None: # pragma: no cover — rejected before it runs
return None

task = loop.create_task(child(), context=armed_context)
tracker["ref"] = weakref.ref(task)
# Precondition: the chain-contention registry is private with no
# public surface, so prove it actually tracks the task before
# asserting eviction. While the task is live, re-passing its
# armed context is rejected as chain contention — the public
# proof that the entry exists. A strong-dict registry would
# additionally pin the task alive past collection.
try:
loop.create_task(probe(), context=armed_context)
except ChainContention:
tracker["registered"] = True

try:
loop.run_until_complete(arm_and_register())
finally:
# Tear the loop down before the stranded release callback can
# run. This drops the loop's scheduled first-step handle — the
# only strong reference to the task besides the weak registry.
loop.set_exception_handler(lambda _loop, _context: None)
loop.close()

# Act — collect now that the loop and every strong reference to the
# task are gone.
gc.collect()

# Assert
assert tracker.get("registered") is True
assert tracker["ref"]() is None # type: ignore[operator]


@pytest.mark.asyncio
async def test_install_task_factory_should_keep_pending_reservation_across_gc():
"""Test the pending-slot reservation survives a garbage collection.

Given:
Wool's task factory composed over an inner factory that, while
it runs, forces a garbage collection and then re-enters task
creation with the same armed context whose slot Wool has just
reserved with its pending sentinel.
When:
A task is created with that armed context, driving the inner
factory through the reservation window.
Then:
It should raise wool.ChainContention — the pending reservation
is intact across the collection because the sentinel is a
strongly-held module singleton — and the original task should
still run to completion once the slot is populated.
"""
# Arrange
loop = asyncio.get_running_loop()
var = ContextVar(_unique("pending_reservation"))
var.set("x")
armed_context = contextvars.copy_context()
observed: dict[str, object] = {}

async def probe() -> None: # pragma: no cover — rejected before it runs
return None

def reservation_probe(
inner_loop: asyncio.AbstractEventLoop,
coro,
**kwargs,
) -> asyncio.Task:
# Wool reserved the armed context's slot with its pending
# sentinel before delegating here. Force a collection, then try
# to create a second task on the same armed context: the pending
# reservation must still register as a live owner.
gc.collect()
try:
inner_loop.create_task(probe(), context=armed_context)
except ChainContention:
observed["contention"] = True
return asyncio.Task(coro, loop=inner_loop, **kwargs)

loop.set_task_factory(reservation_probe)
install_task_factory()

async def body() -> int:
return 7

# Act
result = await loop.create_task(body(), context=armed_context)

# Assert
assert observed.get("contention") is True
assert result == 7


@pytest.mark.asyncio
async def test_to_thread_should_return_result_when_positional_args():
"""Test wool.to_thread runs the callable and returns its result.
Expand Down
Loading