From c5cea5990404efc91155e5c6c92a465312693502 Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Wed, 1 Apr 2026 11:32:39 +0300 Subject: [PATCH] Topic Writer Backpressure --- AGENTS.md | 4 + docs/topic.rst | 83 ++++++++++ examples/topic/writer_example.py | 17 ++ tests/topics/test_topic_writer.py | 48 ++++++ ydb/_topic_writer/topic_writer.py | 31 ++++ ydb/_topic_writer/topic_writer_asyncio.py | 61 ++++++- .../topic_writer_asyncio_test.py | 151 ++++++++++++++++++ ydb/_topic_writer/topic_writer_test.py | 104 ++++++++++++ ydb/topic.py | 14 ++ 9 files changed, 512 insertions(+), 1 deletion(-) diff --git a/AGENTS.md b/AGENTS.md index b347890e..eecfe878 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -76,6 +76,10 @@ source .venv/bin/activate && tox -e py -- tests/path/to/test_file.py -v - Update `docs/` for any user-facing changes; create new sections if needed. - Extend `examples/` when adding new features. +- **After every change to `docs/`**, rebuild the HTML output and verify there are no new errors: + ```sh + source .venv/bin/activate && sphinx-build -b html docs docs/_build/html -q + ``` ## Auto-generated Files — Do NOT Edit diff --git a/docs/topic.rst b/docs/topic.rst index 7188986d..d06650a1 100644 --- a/docs/topic.rst +++ b/docs/topic.rst @@ -259,6 +259,89 @@ For high-throughput pipelines, buffer writes and gather futures: raise f.exception() +Writer Backpressure +^^^^^^^^^^^^^^^^^^^ + +By default the writer's internal buffer is unbounded — ``write()`` always returns immediately +regardless of how many unacknowledged messages are in flight. Enable backpressure by setting +one or both limits: + +.. code-block:: python + + writer = driver.topic_client.writer( + "/local/my-topic", + max_buffer_size_bytes=50 * 1024 * 1024, # pause when 50 MB in flight + max_buffer_messages=1000, # pause when 1000 messages in flight + ) + +A message is counted as occupying the buffer from the moment it is passed to ``write()`` +until the server acknowledges it. 
Backpressure is active when **at least one** limit is set; +setting both means either limit can trigger a wait (OR semantics). + +The limits are **soft**: ``write()`` blocks only if the buffer is *already* at or above the +limit when the call starts. Once unblocked, the entire batch is admitted regardless of its +size. This means callers that batch multiple messages in a single ``write()`` call will never +deadlock even when the batch is larger than the limit. + +**Blocking behavior (default)** + +When the buffer is at or above the limit, ``write()`` blocks until enough messages are +acknowledged by the server. There is no timeout by default — the call waits indefinitely: + +.. code-block:: python + + # Producer pauses here if the buffer is full, then proceeds once space is freed. + writer.write("message") + +**Timeout** + +Set ``buffer_wait_timeout_sec`` to raise :class:`~ydb.TopicWriterBufferFullError` if space +does not free up in time. Use a positive value to wait up to that many seconds, or ``0`` to +fail immediately without waiting (non-blocking): + +.. code-block:: python + + writer = driver.topic_client.writer( + "/local/my-topic", + max_buffer_messages=500, + buffer_wait_timeout_sec=5.0, # raise after 5 seconds; use 0 to fail immediately + ) + + try: + writer.write("message") + except ydb.TopicWriterBufferFullError: + # handle overload — log, drop, or apply back-off + ... + +**Async client** + +The async writer behaves identically — ``await writer.write()`` suspends the coroutine +instead of blocking the thread: + +.. code-block:: python + + writer = driver.topic_client.writer( + "/local/my-topic", + max_buffer_size_bytes=4 * 1024 * 1024, + buffer_wait_timeout_sec=10.0, + ) + + try: + await writer.write("message") + except ydb.TopicWriterBufferFullError: + ... + +To apply your own timeout without raising ``TopicWriterBufferFullError``, wrap the call with +``asyncio.wait_for`` and handle ``asyncio.TimeoutError``: + +..
code-block:: python + + try: + await asyncio.wait_for(writer.write("message"), timeout=2.0) + except asyncio.TimeoutError: + ... # timed out waiting for buffer space + + Reading Messages ---------------- diff --git a/examples/topic/writer_example.py b/examples/topic/writer_example.py index 99346a27..8bcacb7d 100644 --- a/examples/topic/writer_example.py +++ b/examples/topic/writer_example.py @@ -71,6 +71,23 @@ def send_message_without_block_if_internal_buffer_is_full(writer: ydb.TopicWrite return False +def writer_with_buffer_limit(db: ydb.Driver, topic_path: str): + """Writer with backpressure: waits for buffer space, raises TopicWriterBufferFullError on timeout.""" + writer = db.topic_client.writer( + topic_path, + producer_id="producer-id", + max_buffer_size_bytes=10 * 1024 * 1024, # 10 MB + buffer_wait_timeout_sec=30.0, + ) + try: + writer.write(ydb.TopicWriterMessage("data")) + except ydb.TopicWriterBufferFullError: + # Buffer did not free up within timeout (e.g. server slow or disconnected) + pass # handle: retry, drop, or back off + finally: + writer.close() + + def send_messages_with_manual_seqno(writer: ydb.TopicWriter): writer.write(ydb.TopicWriterMessage("mess")) # send text diff --git a/tests/topics/test_topic_writer.py b/tests/topics/test_topic_writer.py index df729e96..035c3b80 100644 --- a/tests/topics/test_topic_writer.py +++ b/tests/topics/test_topic_writer.py @@ -5,6 +5,7 @@ import pytest +import ydb import ydb.aio @@ -136,6 +137,53 @@ class TestException(Exception): raise TestException() +@pytest.mark.asyncio +class TestTopicWriterBackpressureAsyncIO: + async def test_write_and_read_with_backpressure_settings( + self, driver: ydb.aio.Driver, topic_path: str, topic_consumer: str + ): + messages = [b"msg-1", b"msg-2", b"msg-3"] + + async with driver.topic_client.writer( + topic_path, + producer_id="bp-test", + max_buffer_size_bytes=1024 * 1024, + max_buffer_messages=100, + buffer_wait_timeout_sec=10.0, + ) as writer: + for data in messages: + await 
writer.write(ydb.TopicWriterMessage(data=data)) + + async with driver.topic_client.reader(topic_path, consumer=topic_consumer) as reader: + for expected in messages: + msg = await asyncio.wait_for(reader.receive_message(), timeout=10) + assert msg.data == expected + reader.commit(msg) + + +class TestTopicWriterBackpressureSync: + def test_write_and_read_with_backpressure_settings( + self, driver_sync: ydb.Driver, topic_path: str, topic_consumer: str + ): + messages = [b"msg-1", b"msg-2", b"msg-3"] + + with driver_sync.topic_client.writer( + topic_path, + producer_id="bp-sync-test", + max_buffer_size_bytes=1024 * 1024, + max_buffer_messages=100, + buffer_wait_timeout_sec=10.0, + ) as writer: + for data in messages: + writer.write(ydb.TopicWriterMessage(data=data)) + + with driver_sync.topic_client.reader(topic_path, consumer=topic_consumer) as reader: + for expected in messages: + msg = reader.receive_message(timeout=10) + assert msg.data == expected + reader.commit(msg) + + class TestTopicWriterSync: def test_send_message(self, driver_sync: ydb.Driver, topic_path): writer = driver_sync.topic_client.writer(topic_path, producer_id="test") diff --git a/ydb/_topic_writer/topic_writer.py b/ydb/_topic_writer/topic_writer.py index 4ce63a91..23e4cd5a 100644 --- a/ydb/_topic_writer/topic_writer.py +++ b/ydb/_topic_writer/topic_writer.py @@ -37,10 +37,26 @@ class PublicWriterSettings: encoder_executor: Optional[concurrent.futures.Executor] = None # default shared client executor pool encoders: Optional[typing.Mapping[PublicCodec, typing.Callable[[bytes], bytes]]] = None update_token_interval: Union[int, float] = 3600 + max_buffer_size_bytes: Optional[int] = None # None = no limit + max_buffer_messages: Optional[int] = None # None = no limit + # Backpressure is enabled when at least one of the limits above is set. + # None = wait indefinitely for buffer space; positive value = raise TopicWriterBufferFullError on timeout. 
+ buffer_wait_timeout_sec: Optional[float] = None def __post_init__(self): if self.producer_id is None: self.producer_id = uuid.uuid4().hex + if self.max_buffer_size_bytes is not None and self.max_buffer_size_bytes <= 0: + raise ValueError("max_buffer_size_bytes must be a positive integer, got %d" % self.max_buffer_size_bytes) + if self.max_buffer_messages is not None and self.max_buffer_messages <= 0: + raise ValueError("max_buffer_messages must be a positive integer, got %d" % self.max_buffer_messages) + if self.buffer_wait_timeout_sec is not None and ( + self.buffer_wait_timeout_sec < 0 + or self.buffer_wait_timeout_sec != self.buffer_wait_timeout_sec # NaN check + ): + raise ValueError( + "buffer_wait_timeout_sec must be a non-negative number, got %r" % self.buffer_wait_timeout_sec + ) @dataclass @@ -218,6 +234,12 @@ def __init__(self): super(TopicWriterStopped, self).__init__("topic writer was stopped by call close") +class TopicWriterBufferFullError(TopicWriterError): + """Raised when write cannot proceed: buffer is full and timeout expired waiting for free space.""" + + pass + + def default_serializer_message_content(data: Any) -> bytes: if data is None: return bytes() @@ -299,6 +321,15 @@ def get_message_size(msg: InternalMessage): return _split_messages_by_size(messages, connection._DEFAULT_MAX_GRPC_MESSAGE_SIZE, get_message_size) +def internal_message_size_bytes(msg: InternalMessage) -> int: + """Approximate size in bytes for buffer accounting (data + metadata + overhead). + + Uses uncompressed_size so the value stays consistent before and after encoding. + """ + meta_len = sum(len(k) + len(v) for k, v in msg.metadata_items.items()) if msg.metadata_items else 0 + return msg.uncompressed_size + meta_len + 64 # 64 bytes overhead per message (seq_no, timestamps, etc.) 
+ + def _split_messages_by_size( messages: List[InternalMessage], split_size: int, diff --git a/ydb/_topic_writer/topic_writer_asyncio.py b/ydb/_topic_writer/topic_writer_asyncio.py index 1029ef29..5e55917d 100644 --- a/ydb/_topic_writer/topic_writer_asyncio.py +++ b/ydb/_topic_writer/topic_writer_asyncio.py @@ -19,6 +19,8 @@ InternalMessage, TopicWriterStopped, TopicWriterError, + TopicWriterBufferFullError, + internal_message_size_bytes, messages_to_proto_requests, PublicWriteResult, PublicWriteResultTypes, @@ -277,6 +279,9 @@ class WriterAsyncIOReconnector: else: _stop_reason: asyncio.Future _init_info: Optional[PublicWriterInitInfo] + _buffer_bytes: int + _buffer_messages: int + _buffer_updated: asyncio.Event def __init__( self, driver: SupportedDriverType, settings: WriterSettings, tx: Optional["BaseQueryTxContext"] = None @@ -317,6 +322,12 @@ def __init__( self._messages = deque() self._messages_future = deque() self._new_messages = asyncio.Queue() + self._backpressure_enabled = ( + settings.max_buffer_size_bytes is not None or settings.max_buffer_messages is not None + ) + self._buffer_bytes = 0 + self._buffer_messages = 0 + self._buffer_updated = asyncio.Event() self._stop_reason = self._loop.create_future() connection_task = asyncio.create_task(self._connection_loop()) connection_task.set_name("connection_loop") @@ -371,7 +382,6 @@ async def wait_stop(self) -> BaseException: return stop_reason async def write_with_ack_future(self, messages: List[PublicMessage]) -> List[asyncio.Future]: - # todo check internal buffer limit self._check_stop() if self._settings.auto_seqno: @@ -380,6 +390,9 @@ async def write_with_ack_future(self, messages: List[PublicMessage]) -> List[asy internal_messages = self._prepare_internal_messages(messages) messages_future = [self._loop.create_future() for _ in internal_messages] + if self._backpressure_enabled: + await self._acquire_buffer_space(internal_messages) + self._messages_future.extend(messages_future) if self._codec is not 
None and self._codec == PublicCodec.RAW: @@ -389,6 +402,46 @@ async def write_with_ack_future(self, messages: List[PublicMessage]) -> List[asy return messages_future + async def _acquire_buffer_space(self, internal_messages: List[InternalMessage]) -> None: + """Wait until the buffer is below its limit, then admit the batch (soft-limit semantics). + + Blocking starts only when the buffer is already at or above the limit at call time. + Once unblocked, the entire batch is admitted regardless of its size, so callers that + batch messages never get a permanent deadlock. + """ + max_buf = self._settings.max_buffer_size_bytes + max_msgs = self._settings.max_buffer_messages + timeout_sec = self._settings.buffer_wait_timeout_sec + deadline = self._loop.time() + timeout_sec if timeout_sec is not None else None + + while True: + self._buffer_updated.clear() + if (max_buf is None or self._buffer_bytes < max_buf) and ( + max_msgs is None or self._buffer_messages < max_msgs + ): + break + self._check_stop() + if deadline is not None: + assert timeout_sec is not None + remaining = deadline - self._loop.time() + if remaining <= 0: + raise TopicWriterBufferFullError( + "Topic writer buffer full: no free space within %.1f s" + " (buffer_bytes=%d, max_bytes=%s, buffer_msgs=%d, max_msgs=%s)" + % (timeout_sec, self._buffer_bytes, max_buf, self._buffer_messages, max_msgs) + ) + try: + await asyncio.wait_for(self._buffer_updated.wait(), timeout=min(0.5, remaining)) + except asyncio.TimeoutError: + pass + else: + await self._buffer_updated.wait() + + self._check_stop() + new_bytes = sum(internal_message_size_bytes(m) for m in internal_messages) + self._buffer_bytes += new_bytes + self._buffer_messages += len(internal_messages) + def _add_messages_to_send_queue(self, internal_messages: List[InternalMessage]): self._messages.extend(internal_messages) for m in internal_messages: @@ -648,6 +701,10 @@ def _handle_receive_ack(self, ack): "internal error - receive unexpected ack. 
Expected seqno: %s, received seqno: %s" % (current_message.seq_no, ack.seq_no) ) + if self._backpressure_enabled: + self._buffer_bytes = max(0, self._buffer_bytes - internal_message_size_bytes(current_message)) + self._buffer_messages = max(0, self._buffer_messages - 1) + self._buffer_updated.set() write_ack_msg = StreamWriteMessage.WriteResponse.WriteAck status = ack.message_write_status if isinstance(status, write_ack_msg.StatusSkipped): @@ -716,7 +773,9 @@ def _stop(self, reason: BaseException): for f in self._messages_future: f.set_exception(reason) + f.exception() # mark as retrieved so asyncio does not log "Future exception was never retrieved" + self._buffer_updated.set() # wake any tasks blocked in _acquire_buffer_space self._state_changed.set() logger.info("Stop topic writer %s: %s" % (self._id, reason)) diff --git a/ydb/_topic_writer/topic_writer_asyncio_test.py b/ydb/_topic_writer/topic_writer_asyncio_test.py index e6b34346..f92688a0 100644 --- a/ydb/_topic_writer/topic_writer_asyncio_test.py +++ b/ydb/_topic_writer/topic_writer_asyncio_test.py @@ -31,6 +31,7 @@ PublicWriterInitInfo, PublicWriteResult, TopicWriterError, + TopicWriterBufferFullError, ) from .._grpc.grpcwrapper.ydb_topic_public_types import PublicCodec from .._topic_common.test_helpers import StreamMock, wait_for_fast @@ -551,6 +552,156 @@ async def test_write_message(self, reconnector: WriterAsyncIOReconnector, get_st await reconnector.close(flush=False) + async def test_buffer_full_timeout_raises(self, default_driver, get_stream_writer): + # Soft limit: blocking starts when buffer >= limit. + # First message is 10 bytes data + 64 overhead = 74 bytes; set limit=74 so the + # second write finds buffer already at the limit and must wait. 
+ settings = WriterSettings( + PublicWriterSettings( + topic="/local/topic", + producer_id="test-producer", + auto_seqno=False, + auto_created_at=False, + codec=PublicCodec.RAW, + max_buffer_size_bytes=74, + buffer_wait_timeout_sec=0.1, + ) + ) + reconnector = WriterAsyncIOReconnector(default_driver, settings) + stream_writer = get_stream_writer() + + await reconnector.write_with_ack_future([PublicMessage(data=b"x" * 10, seqno=1)]) + await stream_writer.from_client.get() + + # buffer == limit (74) → second write blocks and times out + with pytest.raises(TopicWriterBufferFullError, match="buffer full"): + await reconnector.write_with_ack_future([PublicMessage(data=b"y" * 10, seqno=2)]) + + await reconnector.close(flush=False) + + async def test_buffer_freed_by_ack_allows_next_write(self, default_driver, get_stream_writer): + # limit=74 matches one message (10 data + 64 overhead); second write blocks + # until the first is acked and buffer drops to 0 < 74. + settings = WriterSettings( + PublicWriterSettings( + topic="/local/topic", + producer_id="test-producer", + auto_seqno=False, + auto_created_at=False, + codec=PublicCodec.RAW, + max_buffer_size_bytes=74, + buffer_wait_timeout_sec=5.0, + ) + ) + reconnector = WriterAsyncIOReconnector(default_driver, settings) + stream_writer = get_stream_writer() + + await reconnector.write_with_ack_future([PublicMessage(data=b"x" * 10, seqno=1)]) + await stream_writer.from_client.get() + + # Ack the first message to free buffer space + stream_writer.from_server.put_nowait(self.make_default_ack_message(seq_no=1)) + + # Second write must succeed once buffer is freed + await reconnector.write_with_ack_future([PublicMessage(data=b"y" * 10, seqno=2)]) + + stream_writer.from_server.put_nowait(self.make_default_ack_message(seq_no=2)) + await reconnector.close(flush=True) + + async def test_concurrent_writers_only_one_proceeds_after_ack(self, default_driver, get_stream_writer): + # Soft-limit semantics: blocking starts when buffer >= 
limit. + # limit=74 (one message: 10 data + 64 overhead). + # msg1 fills buffer to 74 >= 74 → tasks 2 and 3 both block. + # Ack msg1 → buffer=0 < 74 → event fires, both tasks wake up. + # First task to run adds 94 bytes (30+64) → buffer=94 >= 74. + # Second task checks again and finds buffer still at limit → stays blocked. + settings = WriterSettings( + PublicWriterSettings( + topic="/local/topic", + producer_id="test-producer", + auto_seqno=False, + auto_created_at=False, + codec=PublicCodec.RAW, + max_buffer_size_bytes=74, + buffer_wait_timeout_sec=5.0, + ) + ) + reconnector = WriterAsyncIOReconnector(default_driver, settings) + stream_writer = get_stream_writer() + + await reconnector.write_with_ack_future([PublicMessage(data=b"x" * 10, seqno=1)]) + await stream_writer.from_client.get() + + task2 = asyncio.create_task(reconnector.write_with_ack_future([PublicMessage(data=b"y" * 30, seqno=2)])) + task3 = asyncio.create_task(reconnector.write_with_ack_future([PublicMessage(data=b"z" * 30, seqno=3)])) + + # Let both tasks start and reach their buffer-wait await point + await asyncio.sleep(0) + await asyncio.sleep(0) + assert not task2.done() + assert not task3.done() + + # Ack msg1: buffer drops 74 → 0 < 74; one task proceeds and fills buffer again + stream_writer.from_server.put_nowait(self.make_default_ack_message(seq_no=1)) + + done, pending = await asyncio.wait([task2, task3], timeout=1.0, return_when=asyncio.FIRST_COMPLETED) + assert len(done) == 1, "exactly one write should proceed after ack" + assert len(pending) == 1, "other write should still be waiting for buffer space" + assert not next(iter(pending)).done() + + pending_task = next(iter(pending)) + pending_task.cancel() + with pytest.raises(asyncio.CancelledError): + await pending_task + await reconnector.close(flush=False) + + async def test_buffer_messages_limit_raises_on_timeout(self, default_driver, get_stream_writer): + settings = WriterSettings( + PublicWriterSettings( + topic="/local/topic", + 
producer_id="test-producer", + auto_seqno=False, + auto_created_at=False, + codec=PublicCodec.RAW, + max_buffer_messages=1, + buffer_wait_timeout_sec=0.1, + ) + ) + reconnector = WriterAsyncIOReconnector(default_driver, settings) + get_stream_writer() + + await reconnector.write_with_ack_future([PublicMessage(data=b"x", seqno=1)]) + + with pytest.raises(TopicWriterBufferFullError, match="buffer full"): + await reconnector.write_with_ack_future([PublicMessage(data=b"y", seqno=2)]) + + await reconnector.close(flush=False) + + async def test_buffer_messages_limit_freed_by_ack(self, default_driver, get_stream_writer): + settings = WriterSettings( + PublicWriterSettings( + topic="/local/topic", + producer_id="test-producer", + auto_seqno=False, + auto_created_at=False, + codec=PublicCodec.RAW, + max_buffer_messages=1, + buffer_wait_timeout_sec=5.0, + ) + ) + reconnector = WriterAsyncIOReconnector(default_driver, settings) + stream_writer = get_stream_writer() + + await reconnector.write_with_ack_future([PublicMessage(data=b"x", seqno=1)]) + await stream_writer.from_client.get() + + stream_writer.from_server.put_nowait(self.make_default_ack_message(seq_no=1)) + + await reconnector.write_with_ack_future([PublicMessage(data=b"y", seqno=2)]) + + stream_writer.from_server.put_nowait(self.make_default_ack_message(seq_no=2)) + await reconnector.close(flush=True) + async def test_auto_seq_no(self, default_driver, default_settings, get_stream_writer): last_seq_no = 100 with mock.patch.object(TestWriterAsyncIOReconnector, "init_last_seqno", last_seq_no): diff --git a/ydb/_topic_writer/topic_writer_test.py b/ydb/_topic_writer/topic_writer_test.py index 215b6c02..5b857ab7 100644 --- a/ydb/_topic_writer/topic_writer_test.py +++ b/ydb/_topic_writer/topic_writer_test.py @@ -1,4 +1,7 @@ +import asyncio +import threading from typing import List +from unittest import mock import pytest @@ -6,10 +9,14 @@ from .topic_writer import ( InternalMessage, PublicMessage, + PublicWriterSettings, + 
TopicWriterBufferFullError, _split_messages_by_size, _split_messages_for_send, messages_to_proto_requests, ) +from .topic_writer_asyncio import WriterAsyncIOReconnector +from .topic_writer_sync import WriterSync @pytest.mark.parametrize( @@ -166,3 +173,100 @@ def test_messages_order_preserved_within_request(self): assert len(requests) == 1 seq_nos = [m.seq_no for m in requests[0].value.messages] assert seq_nos == [1, 2, 3, 4] + + +@pytest.fixture +def background_loop(): + loop = asyncio.new_event_loop() + ready = threading.Event() + + def run(): + asyncio.set_event_loop(loop) + loop.call_soon(ready.set) + loop.run_forever() + + t = threading.Thread(target=run, daemon=True) + t.start() + assert ready.wait(timeout=5), "background event loop thread did not start in time" + yield loop + loop.call_soon_threadsafe(loop.stop) + t.join(timeout=5) + assert not t.is_alive(), "background event loop thread did not stop in time" + loop.close() + + +@pytest.fixture +def mock_reconnector(monkeypatch): + def factory(reconnector_instance): + monkeypatch.setattr(WriterAsyncIOReconnector, "__new__", lambda cls, *a, **kw: reconnector_instance) + return reconnector_instance + + return factory + + +class TestWriterSyncBuffer: + def _make_writer(self, background_loop, reconnector, mock_reconnector): + mock_reconnector(reconnector) + settings = PublicWriterSettings(topic="/local/topic", producer_id="test-producer") + return WriterSync(mock.Mock(), settings, eventloop=background_loop) + + def test_buffer_full_error_propagates(self, background_loop, mock_reconnector): + class ImmediateFullReconnector: + async def write_with_ack_future(self, messages): + raise TopicWriterBufferFullError("buffer full") + + async def close(self, flush): + pass + + writer = self._make_writer(background_loop, ImmediateFullReconnector(), mock_reconnector) + with pytest.raises(TopicWriterBufferFullError): + writer.write(PublicMessage(data=b"hello", seqno=1)) + writer.close(flush=False) + + def 
test_write_blocks_until_buffer_freed(self, background_loop, mock_reconnector): + write_started = threading.Event() + + class BlockingReconnector: + _release_event = None + + async def write_with_ack_future(self, messages): + self._release_event = asyncio.Event() + write_started.set() + await self._release_event.wait() + loop = asyncio.get_running_loop() + futures = [loop.create_future() for _ in messages] + for f in futures: + f.set_result(None) + return futures + + async def release(self): + if self._release_event: + self._release_event.set() + + async def close(self, flush): + pass + + reconnector = BlockingReconnector() + writer = self._make_writer(background_loop, reconnector, mock_reconnector) + + write_errors = [] + + def do_write(): + try: + writer.write(PublicMessage(data=b"hello", seqno=1)) + except Exception as e: + write_errors.append(e) + + write_thread = threading.Thread(target=do_write, daemon=True) + write_thread.start() + + assert write_started.wait(timeout=1.0), "write did not start" + + # Write thread is now blocked; release the mock to simulate buffer freed + asyncio.run_coroutine_threadsafe(reconnector.release(), background_loop).result(timeout=1.0) + + write_thread.join(timeout=1.0) + assert not write_thread.is_alive(), "write should have completed after buffer was freed" + assert not write_errors, f"unexpected error: {write_errors}" + + writer.close(flush=False) diff --git a/ydb/topic.py b/ydb/topic.py index 1faf4659..89caddc9 100644 --- a/ydb/topic.py +++ b/ydb/topic.py @@ -32,6 +32,7 @@ "TopicWriterInitInfo", "TopicWriterMessage", "TopicWriterSettings", + "TopicWriterBufferFullError", ] import concurrent.futures @@ -72,6 +73,7 @@ RetryPolicy as TopicWriterRetryPolicy, PublicWriterInitInfo as TopicWriterInitInfo, PublicWriteResult as TopicWriteResult, + TopicWriterBufferFullError, ) from ydb._topic_writer.topic_writer_asyncio import TxWriterAsyncIO as TopicTxWriterAsyncIO @@ -339,6 +341,9 @@ def writer( # custom encoder executor for call 
builtin and custom decoders. If None - use shared executor pool. # If max_worker in the executor is 1 - then encoders will be called from the thread without parallel. encoder_executor: Optional[concurrent.futures.Executor] = None, + max_buffer_size_bytes: Optional[int] = None, + max_buffer_messages: Optional[int] = None, + buffer_wait_timeout_sec: Optional[float] = None, ) -> TopicWriterAsyncIO: logger.debug("Create writer for topic=%s producer_id=%s", topic, producer_id) args = locals().copy() @@ -368,6 +373,9 @@ def tx_writer( # custom encoder executor for call builtin and custom decoders. If None - use shared executor pool. # If max_worker in the executor is 1 - then encoders will be called from the thread without parallel. encoder_executor: Optional[concurrent.futures.Executor] = None, + max_buffer_size_bytes: Optional[int] = None, + max_buffer_messages: Optional[int] = None, + buffer_wait_timeout_sec: Optional[float] = None, ) -> TopicTxWriterAsyncIO: logger.debug("Create tx writer for topic=%s tx=%s", topic, tx) args = locals().copy() @@ -666,6 +674,9 @@ def writer( # custom encoder executor for call builtin and custom decoders. If None - use shared executor pool. # If max_worker in the executor is 1 - then encoders will be called from the thread without parallel. encoder_executor: Optional[concurrent.futures.Executor] = None, # default shared client executor pool + max_buffer_size_bytes: Optional[int] = None, + max_buffer_messages: Optional[int] = None, + buffer_wait_timeout_sec: Optional[float] = None, ) -> TopicWriter: logger.debug("Create writer for topic=%s producer_id=%s", topic, producer_id) args = locals().copy() @@ -696,6 +707,9 @@ def tx_writer( # custom encoder executor for call builtin and custom decoders. If None - use shared executor pool. # If max_worker in the executor is 1 - then encoders will be called from the thread without parallel. 
encoder_executor: Optional[concurrent.futures.Executor] = None, # default shared client executor pool + max_buffer_size_bytes: Optional[int] = None, + max_buffer_messages: Optional[int] = None, + buffer_wait_timeout_sec: Optional[float] = None, ) -> TopicWriter: logger.debug("Create tx writer for topic=%s tx=%s", topic, tx) args = locals().copy()