22
33from __future__ import annotations
44
5+ import threading
56from abc import ABC , abstractmethod
67from contextlib import contextmanager
7- from typing import Any , Iterator , List , Optional , Union
8+ from typing import Any , Dict , Iterator , List , Optional , Union
89
910
1011class Channel (ABC ):
1112 """Abstract base class for all channel implementations."""
1213
1314 def __init__ (self , name : str ):
1415 self .name = name
16+ self ._key_locks : Dict [str , threading .RLock ] = {}
17+ self ._key_locks_guard = threading .Lock () # protects _key_locks dict itself
18+
19+ def __getstate__ (self ) -> Dict [str , Any ]:
20+ """Exclude unpicklable lock objects during serialization."""
21+ state = self .__dict__ .copy ()
22+ state .pop ("_key_locks" , None )
23+ state .pop ("_key_locks_guard" , None )
24+ return state
25+
26+ def __setstate__ (self , state : Dict [str , Any ]) -> None :
27+ """Recreate lock objects after deserialization."""
28+ self .__dict__ .update (state ) # type: ignore[arg-type]
29+ self ._key_locks = {}
30+ self ._key_locks_guard = threading .Lock ()
1531
1632 @abstractmethod
1733 def set (self , key : str , value : Any , ttl : Optional [int ] = None ) -> None :
@@ -90,9 +106,20 @@ def atomic_add(self, key: str, amount: Union[int, float] = 1) -> Union[int, floa
90106 """
91107 pass
92108
109+ # -- advisory locking for compound operations --
110+
111+ def _get_key_lock (self , key : str ) -> threading .RLock :
112+ """Return (or create) the ``RLock`` associated with *key*."""
113+ if key not in self ._key_locks :
114+ with self ._key_locks_guard :
115+ # Double-checked locking
116+ if key not in self ._key_locks :
117+ self ._key_locks [key ] = threading .RLock ()
118+ return self ._key_locks [key ]
119+
93120 @contextmanager
94121 def lock (self , key : str , timeout : float = 10.0 ) -> Iterator [None ]:
95- """Acquire an advisory lock scoped to * key* for compound operations .
122+ """Acquire an advisory per- key lock for compound read-modify-write .
96123
97124 Usage::
98125
@@ -105,22 +132,22 @@ def lock(self, key: str, timeout: float = 10.0) -> Iterator[None]:
105132 protect read-modify-write sequences that cannot be expressed with
106133 ``atomic_add()``.
107134
108- .. warning::
109-
110- The default implementation is a **no-op** (yields immediately)
111- and provides **no mutual exclusion**. Subclasses that need
112- compound-operation safety **must** override this method — this
113- includes both in-process backends under multi-threading
114- (e.g. ``MemoryChannel``) and distributed backends where
115- multi-client read-modify-write sequences are racy
116- (e.g. Redis without a distributed lock).
117-
118135 Args:
119136 key: Logical key to lock on (does not need to correspond to a
120137 stored key).
121138 timeout: Maximum seconds to wait for the lock.
122139
123140 Yields:
124141 None — the lock is held for the duration of the ``with`` block.
142+
143+ Raises:
144+ TimeoutError: If the lock cannot be acquired within *timeout*.
125145 """
126- yield
146+ rlock = self ._get_key_lock (key )
147+ acquired = rlock .acquire (timeout = timeout )
148+ if not acquired :
149+ raise TimeoutError (f"Could not acquire lock for key '{ key } ' within { timeout } s" )
150+ try :
151+ yield
152+ finally :
153+ rlock .release ()
0 commit comments