Skip to content

Commit 9d6e760

Browse files
committed
feat: implement advisory locking in Channel base class and enhance serialization support
1 parent 9cdbcc7 commit 9d6e760

3 files changed

Lines changed: 43 additions & 66 deletions

File tree

graflow/channels/base.py

Lines changed: 40 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,32 @@
22

33
from __future__ import annotations
44

5+
import threading
56
from abc import ABC, abstractmethod
67
from contextlib import contextmanager
7-
from typing import Any, Iterator, List, Optional, Union
8+
from typing import Any, Dict, Iterator, List, Optional, Union
89

910

1011
class Channel(ABC):
1112
"""Abstract base class for all channel implementations."""
1213

1314
def __init__(self, name: str):
1415
self.name = name
16+
self._key_locks: Dict[str, threading.RLock] = {}
17+
self._key_locks_guard = threading.Lock() # protects _key_locks dict itself
18+
19+
def __getstate__(self) -> Dict[str, Any]:
20+
"""Exclude unpicklable lock objects during serialization."""
21+
state = self.__dict__.copy()
22+
state.pop("_key_locks", None)
23+
state.pop("_key_locks_guard", None)
24+
return state
25+
26+
def __setstate__(self, state: Dict[str, Any]) -> None:
27+
"""Recreate lock objects after deserialization."""
28+
self.__dict__.update(state) # type: ignore[arg-type]
29+
self._key_locks = {}
30+
self._key_locks_guard = threading.Lock()
1531

1632
@abstractmethod
1733
def set(self, key: str, value: Any, ttl: Optional[int] = None) -> None:
@@ -90,9 +106,20 @@ def atomic_add(self, key: str, amount: Union[int, float] = 1) -> Union[int, floa
90106
"""
91107
pass
92108

109+
# -- advisory locking for compound operations --
110+
111+
def _get_key_lock(self, key: str) -> threading.RLock:
112+
"""Return (or create) the ``RLock`` associated with *key*."""
113+
if key not in self._key_locks:
114+
with self._key_locks_guard:
115+
# Double-checked locking
116+
if key not in self._key_locks:
117+
self._key_locks[key] = threading.RLock()
118+
return self._key_locks[key]
119+
93120
@contextmanager
94121
def lock(self, key: str, timeout: float = 10.0) -> Iterator[None]:
95-
"""Acquire an advisory lock scoped to *key* for compound operations.
122+
"""Acquire an advisory per-key lock for compound read-modify-write.
96123
97124
Usage::
98125
@@ -105,22 +132,22 @@ def lock(self, key: str, timeout: float = 10.0) -> Iterator[None]:
105132
protect read-modify-write sequences that cannot be expressed with
106133
``atomic_add()``.
107134
108-
.. warning::
109-
110-
The default implementation is a **no-op** (yields immediately)
111-
and provides **no mutual exclusion**. Subclasses that need
112-
compound-operation safety **must** override this method — this
113-
includes both in-process backends under multi-threading
114-
(e.g. ``MemoryChannel``) and distributed backends where
115-
multi-client read-modify-write sequences are racy
116-
(e.g. Redis without a distributed lock).
117-
118135
Args:
119136
key: Logical key to lock on (does not need to correspond to a
120137
stored key).
121138
timeout: Maximum seconds to wait for the lock.
122139
123140
Yields:
124141
None — the lock is held for the duration of the ``with`` block.
142+
143+
Raises:
144+
TimeoutError: If the lock cannot be acquired within *timeout*.
125145
"""
126-
yield
146+
rlock = self._get_key_lock(key)
147+
acquired = rlock.acquire(timeout=timeout)
148+
if not acquired:
149+
raise TimeoutError(f"Could not acquire lock for key '{key}' within {timeout}s")
150+
try:
151+
yield
152+
finally:
153+
rlock.release()

graflow/channels/memory_channel.py

Lines changed: 1 addition & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,8 @@
22

33
from __future__ import annotations
44

5-
import threading
65
import time
7-
from contextlib import contextmanager
8-
from typing import Any, Dict, Iterator, List, Optional, Union
6+
from typing import Any, Dict, List, Optional, Union
97

108
from graflow.channels.base import Channel
119

@@ -18,21 +16,6 @@ def __init__(self, name: str, **kwargs):
1816
super().__init__(name)
1917
self.data: Dict[str, Any] = {}
2018
self.ttl_data: Dict[str, float] = {}
21-
self._key_locks: Dict[str, threading.RLock] = {}
22-
self._key_locks_guard = threading.Lock() # protects _key_locks dict itself
23-
24-
def __getstate__(self) -> Dict[str, Any]:
25-
"""Exclude unpicklable lock objects during serialization."""
26-
state = self.__dict__.copy()
27-
state.pop("_key_locks", None)
28-
state.pop("_key_locks_guard", None)
29-
return state
30-
31-
def __setstate__(self, state: Dict[str, Any]) -> None:
32-
"""Recreate lock objects after deserialization."""
33-
self.__dict__.update(state) # type: ignore[assignment]
34-
self._key_locks = {}
35-
self._key_locks_guard = threading.Lock()
3619

3720
def set(self, key: str, value: Any, ttl: Optional[int] = None) -> None:
3821
"""Store data in the channel."""
@@ -164,34 +147,3 @@ def atomic_add(self, key: str, amount: Union[int, float] = 1) -> Union[int, floa
164147
new_value = current + amount
165148
self.data[key] = new_value
166149
return new_value
167-
168-
# -- advisory locking for compound operations --
169-
170-
def _get_key_lock(self, key: str) -> threading.RLock:
171-
"""Return (or create) the RLock associated with *key*."""
172-
if key not in self._key_locks:
173-
with self._key_locks_guard:
174-
# Double-checked locking
175-
if key not in self._key_locks:
176-
self._key_locks[key] = threading.RLock()
177-
return self._key_locks[key]
178-
179-
@contextmanager
180-
def lock(self, key: str, timeout: float = 10.0) -> Iterator[None]:
181-
"""Acquire an advisory per-key lock for compound read-modify-write.
182-
183-
Args:
184-
key: Logical key to lock on.
185-
timeout: Maximum seconds to wait for the lock.
186-
187-
Raises:
188-
TimeoutError: If the lock cannot be acquired within *timeout*.
189-
"""
190-
rlock = self._get_key_lock(key)
191-
acquired = rlock.acquire(timeout=timeout)
192-
if not acquired:
193-
raise TimeoutError(f"Could not acquire lock for key '{key}' within {timeout}s")
194-
try:
195-
yield
196-
finally:
197-
rlock.release()

graflow/channels/redis_channel.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -240,15 +240,13 @@ def atomic_add(self, key: str, amount: Union[int, float] = 1) -> Union[int, floa
240240

241241
def __getstate__(self):
242242
"""Support for pickle serialization."""
243-
state = self.__dict__.copy()
244-
# Remove the unpicklable Redis client
243+
state = super().__getstate__()
245244
del state["redis_client"]
246245
return state
247246

248247
def __setstate__(self, state):
249248
"""Support for pickle deserialization."""
250-
self.__dict__.update(state)
251-
# Recreate the Redis client
249+
super().__setstate__(state)
252250
assert redis is not None, "redis package is required for RedisChannel"
253251
self.redis_client = redis.Redis(
254252
host=self._host, port=self._port, db=self._db, decode_responses=True, **self._kwargs

0 commit comments

Comments
 (0)