diff --git a/docs/topic.rst b/docs/topic.rst index d06650a1..d6d5c97a 100644 --- a/docs/topic.rst +++ b/docs/topic.rst @@ -377,6 +377,30 @@ Reader Parameters topic="/local/my-topic", # str, TopicReaderSelector, or a list of these consumer="my-consumer", buffer_size_bytes=50 * 1024 * 1024, # client-side buffer (default: 50 MB) + buffer_release_threshold=0.5, # see below (default: 0.5) + ) + +``buffer_size_bytes`` controls how many bytes the server is allowed to send before the client +signals that it is ready for more. The server will not exceed this limit. + +``buffer_release_threshold`` (float in ``[0.0, 1.0]``) controls when the client sends a new +``ReadRequest`` to the server after consuming messages from the local buffer: + +* ``0.0`` — send a ``ReadRequest`` immediately after every batch is consumed. + Produces more round-trips when many small batches arrive. +* ``> 0.0`` — accumulate freed bytes until they reach + ``threshold × buffer_size_bytes``, then send a single ``ReadRequest`` covering the + accumulated amount. This reduces network round-trips. The default is ``0.5``. + +Example — reduce round-trips for a high-throughput reader with many small messages: + +.. 
code-block:: python + + reader = driver.topic_client.reader( + "/local/my-topic", + consumer="my-consumer", + buffer_size_bytes=50 * 1024 * 1024, + buffer_release_threshold=0.2, # send ReadRequest after freeing 10 MiB ) To read from multiple topics at once, pass a list: diff --git a/ydb/_topic_reader/topic_reader.py b/ydb/_topic_reader/topic_reader.py index 38ee1be6..e4545f3a 100644 --- a/ydb/_topic_reader/topic_reader.py +++ b/ydb/_topic_reader/topic_reader.py @@ -57,7 +57,12 @@ class PublicReaderSettings: update_token_interval: Union[int, float] = 3600 event_handler: Optional[EventHandler] = None + buffer_release_threshold: float = 0.5 + """Min fraction of buffer_size_bytes to accumulate before sending a new ReadRequest (0.0 = immediately after every batch).""" + def __post_init__(self): + if not (0.0 <= self.buffer_release_threshold <= 1.0): + raise ValueError("buffer_release_threshold must be in [0.0, 1.0], got %s" % self.buffer_release_threshold) # check possible create init message _ = self._init_message() diff --git a/ydb/_topic_reader/topic_reader_asyncio.py b/ydb/_topic_reader/topic_reader_asyncio.py index 5f5ac7a4..f4e5a4f8 100644 --- a/ydb/_topic_reader/topic_reader_asyncio.py +++ b/ydb/_topic_reader/topic_reader_asyncio.py @@ -3,6 +3,7 @@ import asyncio import concurrent.futures import gzip +import math import typing from asyncio import Task from collections import defaultdict, OrderedDict @@ -438,6 +439,8 @@ class ReaderStream: _background_tasks: Set[asyncio.Task] _partition_sessions: Dict[int, datatypes.PartitionSession] _buffer_size_bytes: int # use for init request, then for debug purposes only + _min_buffer_release_bytes: int + _pending_buffer_release_bytes: int _decode_executor: Optional[concurrent.futures.Executor] _decoders: Dict[int, typing.Callable[[bytes], bytes]] # dict[codec_code] func(encoded_bytes)->decoded_bytes @@ -471,6 +474,8 @@ def __init__( self._background_tasks = set() self._partition_sessions = dict() self._buffer_size_bytes = 
settings.buffer_size_bytes + self._min_buffer_release_bytes = math.ceil(settings.buffer_size_bytes * settings.buffer_release_threshold) + self._pending_buffer_release_bytes = 0 self._decode_executor = settings.decoder_executor self._decoders = {Codec.CODEC_GZIP: gzip.decompress} @@ -844,14 +849,17 @@ def _buffer_consume_bytes(self, bytes_size): self._buffer_size_bytes -= bytes_size def _buffer_release_bytes(self, bytes_size): - self._buffer_size_bytes += bytes_size - self._stream.write( - StreamReadMessage.FromClient( - client_message=StreamReadMessage.ReadRequest( - bytes_size=bytes_size, + self._pending_buffer_release_bytes += bytes_size + if self._pending_buffer_release_bytes >= self._min_buffer_release_bytes: + self._buffer_size_bytes += self._pending_buffer_release_bytes + self._stream.write( + StreamReadMessage.FromClient( + client_message=StreamReadMessage.ReadRequest( + bytes_size=self._pending_buffer_release_bytes, + ) ) ) - ) + self._pending_buffer_release_bytes = 0 def _read_response_to_batches(self, message: StreamReadMessage.ReadResponse) -> typing.List[datatypes.PublicBatch]: batches: typing.List[datatypes.PublicBatch] = [] diff --git a/ydb/_topic_reader/topic_reader_asyncio_test.py b/ydb/_topic_reader/topic_reader_asyncio_test.py index cedc5e47..cb7ce408 100644 --- a/ydb/_topic_reader/topic_reader_asyncio_test.py +++ b/ydb/_topic_reader/topic_reader_asyncio_test.py @@ -1388,11 +1388,7 @@ async def test_receive_batch_nowait(self, stream, stream_reader, partition_sessi _codec=Codec.CODEC_RAW, ) - assert stream_reader._buffer_size_bytes == initial_buffer_size - - assert ( - StreamReadMessage.ReadRequest(self.default_batch_size * 2) == stream.from_client.get_nowait().client_message - ) + assert stream_reader._buffer_size_bytes == initial_buffer_size - 2 * self.default_batch_size with pytest.raises(asyncio.QueueEmpty): stream.from_client.get_nowait() @@ -1423,7 +1419,7 @@ async def test_receive_message_nowait(self, stream, stream_reader, partition_ses 
mess = stream_reader.receive_message_nowait() assert mess.seqno == expected_seqno - assert stream_reader._buffer_size_bytes == initial_buffer_size + assert stream_reader._buffer_size_bytes == initial_buffer_size - 2 * self.default_batch_size async def test_update_token(self, stream): settings = PublicReaderSettings( @@ -1603,3 +1599,185 @@ async def receive(timeout=None): pass # any error is fine, we just need wait_error() to not hang await reader.close(False) + + +@pytest.mark.asyncio +class TestReaderStreamBufferReleaseThreshold: + default_reader_reconnector_id = 4 + + @pytest.fixture() + def stream(self): + return StreamMock() + + async def _get_started_reader(self, stream, threshold, buffer_size_bytes, default_executor) -> ReaderStream: + settings = PublicReaderSettings( + consumer="test-consumer", + topic="test-topic", + buffer_size_bytes=buffer_size_bytes, + buffer_release_threshold=threshold, + decoder_executor=default_executor, + ) + reader = ReaderStream(self.default_reader_reconnector_id, settings) + init_message = object() + start = asyncio.create_task(reader._start(stream, init_message)) + + stream.from_server.put_nowait( + StreamReadMessage.FromServer( + server_status=ServerStatus(ydb_status_codes_pb2.StatusIds.SUCCESS, []), + server_message=StreamReadMessage.InitResponse(session_id="test-session"), + ) + ) + + await wait_for_fast(stream.from_client.get()) # init request + initial_read_req = await wait_for_fast(stream.from_client.get()) + assert isinstance(initial_read_req.client_message, StreamReadMessage.ReadRequest) + assert initial_read_req.client_message.bytes_size == buffer_size_bytes + + await start + return reader + + def _make_partition_session(self, reader: ReaderStream, session_id: int) -> datatypes.PartitionSession: + ps = datatypes.PartitionSession( + id=session_id, + state=datatypes.PartitionSession.State.Active, + topic_path="test-topic", + partition_id=session_id, + committed_offset=0, + 
reader_reconnector_id=self.default_reader_reconnector_id, + reader_stream_id=reader._id, + ) + reader._partition_sessions[ps.id] = ps + return ps + + def _make_batch(self, partition_session: datatypes.PartitionSession, bytes_size: int) -> datatypes.PublicBatch: + return datatypes.PublicBatch( + messages=[stub_message(1)], + _partition_session=partition_session, + _bytes_size=bytes_size, + _codec=Codec.CODEC_RAW, + ) + + async def test_threshold_zero_sends_immediately(self, stream, default_executor): + """threshold=0.0: every release sends a ReadRequest immediately.""" + reader = await self._get_started_reader( + stream, threshold=0.0, buffer_size_bytes=1000, default_executor=default_executor + ) + ps = self._make_partition_session(reader, session_id=1) + + reader._message_batches[ps.id] = self._make_batch(ps, 200) + + batch = reader.receive_batch_nowait() + assert batch is not None + + msg = stream.from_client.get_nowait() + assert isinstance(msg.client_message, StreamReadMessage.ReadRequest) + assert msg.client_message.bytes_size == 200 + + with pytest.raises(asyncio.QueueEmpty): + stream.from_client.get_nowait() + + await reader.close(False) + + async def test_threshold_delays_release_while_messages_pending(self, stream, default_executor): + """threshold=0.5: releases below threshold are not flushed, even when queue becomes empty.""" + reader = await self._get_started_reader( + stream, threshold=0.5, buffer_size_bytes=1000, default_executor=default_executor + ) + # min_bytes_to_flush = ceil(1000 * 0.5) = 500 + ps1 = self._make_partition_session(reader, session_id=1) + ps2 = self._make_partition_session(reader, session_id=2) + + reader._message_batches[ps1.id] = self._make_batch(ps1, 200) + reader._message_batches[ps2.id] = self._make_batch(ps2, 200) + + # Read first batch: 200 bytes freed, 200 < 500 threshold → no flush + batch1 = reader.receive_batch_nowait() + assert batch1 is not None + with pytest.raises(asyncio.QueueEmpty): + stream.from_client.get_nowait() 
+ + # Read second batch: 400 bytes freed total, 400 < 500 threshold → no flush even with empty queue + batch2 = reader.receive_batch_nowait() + assert batch2 is not None + with pytest.raises(asyncio.QueueEmpty): + stream.from_client.get_nowait() + + await reader.close(False) + + async def test_threshold_flushes_when_bytes_reach_threshold(self, stream, default_executor): + """threshold=0.3: flush when accumulated bytes reach threshold, even with items still queued.""" + reader = await self._get_started_reader( + stream, threshold=0.3, buffer_size_bytes=1000, default_executor=default_executor + ) + # min_bytes_to_flush = ceil(1000 * 0.3) = 300 + ps1 = self._make_partition_session(reader, session_id=1) + ps2 = self._make_partition_session(reader, session_id=2) + ps3 = self._make_partition_session(reader, session_id=3) + + reader._message_batches[ps1.id] = self._make_batch(ps1, 150) + reader._message_batches[ps2.id] = self._make_batch(ps2, 150) + reader._message_batches[ps3.id] = self._make_batch(ps3, 150) + + # Read first batch: 150 bytes, 150 < 300 threshold, queue not empty → no flush + batch1 = reader.receive_batch_nowait() + assert batch1 is not None + with pytest.raises(asyncio.QueueEmpty): + stream.from_client.get_nowait() + + # Read second batch: 300 bytes total, 300 >= 300 threshold → flush with 300 bytes + batch2 = reader.receive_batch_nowait() + assert batch2 is not None + + msg = stream.from_client.get_nowait() + assert isinstance(msg.client_message, StreamReadMessage.ReadRequest) + assert msg.client_message.bytes_size == 300 + + # pending is reset; read third batch: 150 bytes, 150 < 300 threshold → no flush + batch3 = reader.receive_batch_nowait() + assert batch3 is not None + + with pytest.raises(asyncio.QueueEmpty): + stream.from_client.get_nowait() + + await reader.close(False) + + async def test_threshold_validation_rejects_invalid_values(self, default_executor): + with pytest.raises(ValueError): + PublicReaderSettings( + consumer="test", + 
topic="test-topic", + buffer_release_threshold=1.1, + decoder_executor=default_executor, + ) + + with pytest.raises(ValueError): + PublicReaderSettings( + consumer="test", + topic="test-topic", + buffer_release_threshold=-0.1, + decoder_executor=default_executor, + ) + + async def test_threshold_one_flushes_when_bytes_match_buffer_size(self, stream, default_executor): + """threshold=1.0: flush only when accumulated bytes reach the full buffer size.""" + reader = await self._get_started_reader( + stream, threshold=1.0, buffer_size_bytes=1000, default_executor=default_executor + ) + ps1 = self._make_partition_session(reader, session_id=1) + ps2 = self._make_partition_session(reader, session_id=2) + + reader._message_batches[ps1.id] = self._make_batch(ps1, 500) + reader._message_batches[ps2.id] = self._make_batch(ps2, 500) + + # Read first batch: 500 bytes freed, 500 < 1000 threshold → no flush + reader.receive_batch_nowait() + with pytest.raises(asyncio.QueueEmpty): + stream.from_client.get_nowait() + + # Read second batch: 1000 bytes freed, 1000 >= 1000 threshold → flush + reader.receive_batch_nowait() + msg = stream.from_client.get_nowait() + assert isinstance(msg.client_message, StreamReadMessage.ReadRequest) + assert msg.client_message.bytes_size == 1000 + + await reader.close(False) diff --git a/ydb/topic.py b/ydb/topic.py index 89caddc9..98859293 100644 --- a/ydb/topic.py +++ b/ydb/topic.py @@ -294,6 +294,7 @@ def reader( decoder_executor: Optional[concurrent.futures.Executor] = None, auto_partitioning_support: Optional[bool] = True, # Auto partitioning feature flag. Default - True. 
event_handler: Optional[TopicReaderEvents.EventHandler] = None, + buffer_release_threshold: float = 0.5, ) -> TopicReaderAsyncIO: logger.debug("Create reader for topic=%s consumer=%s", topic, consumer) @@ -629,6 +630,7 @@ def reader( decoder_executor: Optional[concurrent.futures.Executor] = None, # default shared client executor pool auto_partitioning_support: Optional[bool] = True, # Auto partitioning feature flag. Default - True. event_handler: Optional[TopicReaderEvents.EventHandler] = None, + buffer_release_threshold: float = 0.5, ) -> TopicReader: logger.debug("Create reader for topic=%s consumer=%s", topic, consumer) if not decoder_executor: