1919 InternalMessage ,
2020 TopicWriterStopped ,
2121 TopicWriterError ,
22+ TopicWriterBufferFullError ,
23+ internal_message_size_bytes ,
2224 messages_to_proto_requests ,
2325 PublicWriteResult ,
2426 PublicWriteResultTypes ,
@@ -277,6 +279,9 @@ class WriterAsyncIOReconnector:
277279 else :
278280 _stop_reason : asyncio .Future
279281 _init_info : Optional [PublicWriterInitInfo ]
282+ _buffer_bytes : int
283+ _buffer_messages : int
284+ _buffer_updated : asyncio .Event
280285
281286 def __init__ (
282287 self , driver : SupportedDriverType , settings : WriterSettings , tx : Optional ["BaseQueryTxContext" ] = None
@@ -317,6 +322,12 @@ def __init__(
317322 self ._messages = deque ()
318323 self ._messages_future = deque ()
319324 self ._new_messages = asyncio .Queue ()
325+ self ._backpressure_enabled = (
326+ settings .max_buffer_size_bytes is not None or settings .max_buffer_messages is not None
327+ )
328+ self ._buffer_bytes = 0
329+ self ._buffer_messages = 0
330+ self ._buffer_updated = asyncio .Event ()
320331 self ._stop_reason = self ._loop .create_future ()
321332 connection_task = asyncio .create_task (self ._connection_loop ())
322333 connection_task .set_name ("connection_loop" )
@@ -371,7 +382,6 @@ async def wait_stop(self) -> BaseException:
371382 return stop_reason
372383
373384 async def write_with_ack_future (self , messages : List [PublicMessage ]) -> List [asyncio .Future ]:
374- # todo check internal buffer limit
375385 self ._check_stop ()
376386
377387 if self ._settings .auto_seqno :
@@ -380,6 +390,9 @@ async def write_with_ack_future(self, messages: List[PublicMessage]) -> List[asy
380390 internal_messages = self ._prepare_internal_messages (messages )
381391 messages_future = [self ._loop .create_future () for _ in internal_messages ]
382392
393+ if self ._backpressure_enabled :
394+ await self ._acquire_buffer_space (internal_messages )
395+
383396 self ._messages_future .extend (messages_future )
384397
385398 if self ._codec is not None and self ._codec == PublicCodec .RAW :
@@ -389,6 +402,46 @@ async def write_with_ack_future(self, messages: List[PublicMessage]) -> List[asy
389402
390403 return messages_future
391404
405+ async def _acquire_buffer_space (self , internal_messages : List [InternalMessage ]) -> None :
406+ """Wait until the buffer is below its limit, then admit the batch (soft-limit semantics).
407+
408+ Blocking starts only when the buffer is already at or above the limit at call time.
409+ Once unblocked, the entire batch is admitted regardless of its size, so callers that
410+ batch messages never get a permanent deadlock.
411+ """
412+ max_buf = self ._settings .max_buffer_size_bytes
413+ max_msgs = self ._settings .max_buffer_messages
414+ timeout_sec = self ._settings .buffer_wait_timeout_sec
415+ deadline = self ._loop .time () + timeout_sec if timeout_sec is not None else None
416+
417+ while True :
418+ self ._buffer_updated .clear ()
419+ if (max_buf is None or self ._buffer_bytes < max_buf ) and (
420+ max_msgs is None or self ._buffer_messages < max_msgs
421+ ):
422+ break
423+ self ._check_stop ()
424+ if deadline is not None :
425+ assert timeout_sec is not None
426+ remaining = deadline - self ._loop .time ()
427+ if remaining <= 0 :
428+ raise TopicWriterBufferFullError (
429+ "Topic writer buffer full: no free space within %.1f s"
430+ " (buffer_bytes=%d, max_bytes=%s, buffer_msgs=%d, max_msgs=%s)"
431+ % (timeout_sec , self ._buffer_bytes , max_buf , self ._buffer_messages , max_msgs )
432+ )
433+ try :
434+ await asyncio .wait_for (self ._buffer_updated .wait (), timeout = min (0.5 , remaining ))
435+ except asyncio .TimeoutError :
436+ pass
437+ else :
438+ await self ._buffer_updated .wait ()
439+
440+ self ._check_stop ()
441+ new_bytes = sum (internal_message_size_bytes (m ) for m in internal_messages )
442+ self ._buffer_bytes += new_bytes
443+ self ._buffer_messages += len (internal_messages )
444+
392445 def _add_messages_to_send_queue (self , internal_messages : List [InternalMessage ]):
393446 self ._messages .extend (internal_messages )
394447 for m in internal_messages :
@@ -648,6 +701,10 @@ def _handle_receive_ack(self, ack):
648701 "internal error - receive unexpected ack. Expected seqno: %s, received seqno: %s"
649702 % (current_message .seq_no , ack .seq_no )
650703 )
704+ if self ._backpressure_enabled :
705+ self ._buffer_bytes = max (0 , self ._buffer_bytes - internal_message_size_bytes (current_message ))
706+ self ._buffer_messages = max (0 , self ._buffer_messages - 1 )
707+ self ._buffer_updated .set ()
651708 write_ack_msg = StreamWriteMessage .WriteResponse .WriteAck
652709 status = ack .message_write_status
653710 if isinstance (status , write_ack_msg .StatusSkipped ):
@@ -716,7 +773,9 @@ def _stop(self, reason: BaseException):
716773
717774 for f in self ._messages_future :
718775 f .set_exception (reason )
776+ f .exception () # mark as retrieved so asyncio does not log "Future exception was never retrieved"
719777
778+ self ._buffer_updated .set () # wake any tasks blocked in _acquire_buffer_space
720779 self ._state_changed .set ()
721780 logger .info ("Stop topic writer %s: %s" % (self ._id , reason ))
722781
0 commit comments