Skip to content
This repository was archived by the owner on Apr 1, 2026. It is now read-only.

Commit 4621e30

Browse files
authored
feat: Added a batch completed callback to the data client mutations batcher (#1308)
**Changes made:** - Refactored the logic in `Table.mutate_rows` that produces a list of `Status` protos from a `MutationsExceptionGroup` into a shared helper. - Added a private keyword argument for a batch completion callback in the MutationsBatcher. - Added unit tests/system tests.
1 parent 9ee4032 commit 4621e30

10 files changed

Lines changed: 420 additions & 46 deletions

File tree

google/cloud/bigtable/data/_async/client.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from __future__ import annotations
1717

1818
from typing import (
19+
Callable,
1920
cast,
2021
Any,
2122
AsyncIterable,
@@ -115,6 +116,7 @@
115116
if TYPE_CHECKING:
116117
from google.cloud.bigtable.data._helpers import RowKeySamples
117118
from google.cloud.bigtable.data._helpers import ShardedQuery
119+
from google.rpc import status_pb2
118120

119121
if CrossSync.is_async:
120122
from google.cloud.bigtable.data._async.mutations_batcher import (

google/cloud/bigtable/data/_async/mutations_batcher.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
#
1515
from __future__ import annotations
1616

17-
from typing import Sequence, TYPE_CHECKING, cast
17+
from typing import Callable, Optional, Sequence, TYPE_CHECKING, cast
1818
import atexit
1919
import warnings
2020
from collections import deque
@@ -24,6 +24,10 @@
2424
from google.cloud.bigtable.data.exceptions import FailedMutationEntryError
2525
from google.cloud.bigtable.data._helpers import _get_retryable_errors
2626
from google.cloud.bigtable.data._helpers import _get_timeouts
27+
from google.cloud.bigtable.data._helpers import (
28+
_get_statuses_from_mutations_exception_group,
29+
)
30+
2731
from google.cloud.bigtable.data._helpers import TABLE_DEFAULT
2832

2933
from google.cloud.bigtable.data.mutations import (
@@ -33,6 +37,9 @@
3337

3438
from google.cloud.bigtable.data._cross_sync import CrossSync
3539

40+
from google.rpc import code_pb2
41+
from google.rpc import status_pb2
42+
3643
if TYPE_CHECKING:
3744
from google.cloud.bigtable.data.mutations import RowMutationEntry
3845

@@ -269,6 +276,7 @@ def __init__(
269276
self._newest_exceptions: deque[Exception] = deque(
270277
maxlen=self._exception_list_limit
271278
)
279+
self._user_batch_completed_callback = None
272280
# clean up on program exit
273281
atexit.register(self._on_exit)
274282

@@ -380,6 +388,7 @@ async def _execute_mutate_rows(
380388
list of FailedMutationEntryError objects for mutations that failed.
381389
FailedMutationEntryError objects will not contain index information
382390
"""
391+
statuses = [status_pb2.Status(code=code_pb2.Code.UNKNOWN)] * len(batch)
383392
try:
384393
operation = CrossSync._MutateRowsOperation(
385394
self._target.client._gapic_client,
@@ -391,13 +400,21 @@ async def _execute_mutate_rows(
391400
)
392401
await operation.start()
393402
except MutationsExceptionGroup as e:
403+
statuses = _get_statuses_from_mutations_exception_group(e, len(batch))
404+
394405
# strip index information from exceptions, since it is not useful in a batch context
395406
for subexc in e.exceptions:
396407
subexc.index = None
397408
return list(e.exceptions)
409+
else:
410+
statuses = [status_pb2.Status(code=code_pb2.Code.OK)] * len(batch)
398411
finally:
399412
# mark batch as complete in flow control
400413
await self._flow_control.remove_from_flow(batch)
414+
415+
# Call batch done callback with list of statuses.
416+
if self._user_batch_completed_callback:
417+
self._user_batch_completed_callback(statuses)
401418
return []
402419

403420
def _add_exceptions(self, excs: list[Exception]):

google/cloud/bigtable/data/_helpers.py

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
"""
1717
from __future__ import annotations
1818

19-
from typing import Callable, Sequence, List, Tuple, TYPE_CHECKING, Union
19+
from typing import Callable, Sequence, List, Optional, Tuple, TYPE_CHECKING, Union
2020
import time
2121
import enum
2222
from collections import namedtuple
@@ -26,6 +26,10 @@
2626
from google.api_core import retry as retries
2727
from google.api_core.retry import RetryFailureReason
2828
from google.cloud.bigtable.data.exceptions import RetryExceptionGroup
29+
from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup
30+
from google.rpc import code_pb2
31+
from google.rpc import status_pb2
32+
2933

3034
if TYPE_CHECKING:
3135
import grpc
@@ -224,6 +228,66 @@ def _align_timeouts(operation: float, attempt: float | None) -> tuple[float, flo
224228
return operation, final_attempt
225229

226230

231+
def _get_statuses_from_mutations_exception_group(
232+
exc_group: MutationsExceptionGroup, batch_size: int
233+
) -> list[status_pb2.Status]:
234+
"""
235+
Helper function that populates a list of Status objects with exception information from
236+
the exception group.
237+
238+
Args:
239+
exc_group: The exception group from a mutate rows operation
240+
batch_size: How many RowMutationEntry objects were provided to the batch
241+
Returns:
242+
list[status_pb2.Status]: A list of Status proto objects
243+
"""
244+
# We handle exceptions as follows:
245+
#
246+
# 1. Each exception in the error group is a FailedMutationEntryError, and its
247+
# cause is either a singular exception or a RetryExceptionGroup consisting of
248+
# multiple exceptions.
249+
#
250+
# 2. In the case of a singular exception, if the error does not have a gRPC status
251+
# code, we return a status code of UNKNOWN.
252+
#
253+
# 3. In the case of a RetryExceptionGroup, we use the terminal (last) exception in the exception
254+
# group and process that.
255+
statuses = [status_pb2.Status(code=code_pb2.OK)] * batch_size
256+
for error in exc_group.exceptions:
257+
if isinstance(error.index, int) and 0 <= error.index < len(statuses):
258+
cause = error.__cause__
259+
if isinstance(cause, RetryExceptionGroup):
260+
statuses[error.index] = _get_status(cause.exceptions[-1])
261+
else:
262+
statuses[error.index] = _get_status(cause)
263+
return statuses
264+
265+
266+
def _get_status(exc: Optional[Exception]) -> status_pb2.Status:
267+
"""
268+
Helper function that returns a Status object corresponding to the given exception.
269+
270+
Args:
271+
exc: An exception to be converted into a Status.
272+
Returns:
273+
status_pb2.Status: A Status proto object.
274+
"""
275+
if (
276+
isinstance(exc, core_exceptions.GoogleAPICallError)
277+
and exc.grpc_status_code is not None
278+
):
279+
return status_pb2.Status( # type: ignore[unreachable]
280+
code=exc.grpc_status_code.value[0],
281+
message=exc.message,
282+
details=exc.details,
283+
)
284+
285+
return status_pb2.Status(
286+
code=code_pb2.Code.UNKNOWN,
287+
message=str(exc) if exc else "An unknown error has occurred",
288+
)
289+
290+
227291
def _validate_timeouts(
228292
operation_timeout: float, attempt_timeout: float | None, allow_none: bool = False
229293
):

google/cloud/bigtable/data/_sync_autogen/mutations_batcher.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,15 @@
2525
from google.cloud.bigtable.data.exceptions import FailedMutationEntryError
2626
from google.cloud.bigtable.data._helpers import _get_retryable_errors
2727
from google.cloud.bigtable.data._helpers import _get_timeouts
28+
from google.cloud.bigtable.data._helpers import (
29+
_get_statuses_from_mutations_exception_group,
30+
)
2831
from google.cloud.bigtable.data._helpers import TABLE_DEFAULT
2932
from google.cloud.bigtable.data.mutations import _MUTATE_ROWS_REQUEST_MUTATION_LIMIT
3033
from google.cloud.bigtable.data.mutations import Mutation
3134
from google.cloud.bigtable.data._cross_sync import CrossSync
35+
from google.rpc import code_pb2
36+
from google.rpc import status_pb2
3237

3338
if TYPE_CHECKING:
3439
from google.cloud.bigtable.data.mutations import RowMutationEntry
@@ -233,6 +238,7 @@ def __init__(
233238
self._newest_exceptions: deque[Exception] = deque(
234239
maxlen=self._exception_list_limit
235240
)
241+
self._user_batch_completed_callback = None
236242
atexit.register(self._on_exit)
237243

238244
def _timer_routine(self, interval: float | None) -> None:
@@ -324,6 +330,7 @@ def _execute_mutate_rows(
324330
list[FailedMutationEntryError]:
325331
list of FailedMutationEntryError objects for mutations that failed.
326332
FailedMutationEntryError objects will not contain index information"""
333+
statuses = [status_pb2.Status(code=code_pb2.Code.UNKNOWN)] * len(batch)
327334
try:
328335
operation = CrossSync._Sync_Impl._MutateRowsOperation(
329336
self._target.client._gapic_client,
@@ -335,11 +342,16 @@ def _execute_mutate_rows(
335342
)
336343
operation.start()
337344
except MutationsExceptionGroup as e:
345+
statuses = _get_statuses_from_mutations_exception_group(e, len(batch))
338346
for subexc in e.exceptions:
339347
subexc.index = None
340348
return list(e.exceptions)
349+
else:
350+
statuses = [status_pb2.Status(code=code_pb2.Code.OK)] * len(batch)
341351
finally:
342352
self._flow_control.remove_from_flow(batch)
353+
if self._user_batch_completed_callback:
354+
self._user_batch_completed_callback(statuses)
343355
return []
344356

345357
def _add_exceptions(self, excs: list[Exception]):

google/cloud/bigtable/table.py

Lines changed: 13 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from typing import Set
1818
import warnings
1919

20-
from google.api_core.exceptions import GoogleAPICallError
2120
from google.api_core.exceptions import Aborted
2221
from google.api_core.exceptions import DeadlineExceeded
2322
from google.api_core.exceptions import NotFound
@@ -31,12 +30,12 @@
3130
from google.cloud.bigtable.column_family import _gc_rule_from_pb
3231
from google.cloud.bigtable.column_family import ColumnFamily
3332
from google.cloud.bigtable.data._helpers import TABLE_DEFAULT
34-
from google.cloud.bigtable.data.exceptions import (
35-
RetryExceptionGroup,
36-
MutationsExceptionGroup,
33+
from google.cloud.bigtable.data._helpers import (
34+
_get_statuses_from_mutations_exception_group,
3735
)
38-
from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery
36+
from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup
3937
from google.cloud.bigtable.data.mutations import RowMutationEntry
38+
from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery
4039
from google.cloud.bigtable.batcher import MutationsBatcher
4140
from google.cloud.bigtable.batcher import FLUSH_COUNT, MAX_MUTATION_SIZE
4241
from google.cloud.bigtable.encryption_info import EncryptionInfo
@@ -767,9 +766,9 @@ def mutate_rows(self, rows, retry=DEFAULT_RETRY, timeout=DEFAULT):
767766
mutation_entries = [
768767
RowMutationEntry(row.row_key, row._get_mutations()) for row in rows
769768
]
770-
return_statuses = [status_pb2.Status(code=code_pb2.Code.OK)] * len(
769+
return_statuses = [status_pb2.Status(code=code_pb2.Code.UNKNOWN)] * len(
771770
mutation_entries
772-
) # By default, return status OKs for everything
771+
)
773772

774773
try:
775774
self._table_impl.bulk_mutate_rows(
@@ -779,41 +778,15 @@ def mutate_rows(self, rows, retry=DEFAULT_RETRY, timeout=DEFAULT):
779778
retryable_errors=retryable_errors,
780779
)
781780
except MutationsExceptionGroup as mut_exc_group:
782-
# We exception handle as follows:
783-
#
784-
# 1. Each exception in the error group is a FailedMutationEntryError, and its
785-
# cause is either a singular exception or a RetryExceptionGroup consisting of
786-
# multiple exceptions.
787-
#
788-
# 2. In the case of a singular exception, if the error does not have a gRPC status
789-
# code, we return a status code of UNKNOWN.
790-
#
791-
# 3. In the case of a RetryExceptionGroup, we use terminal exception in the exception
792-
# group and process that.
793-
for error in mut_exc_group.exceptions:
794-
cause = error.__cause__
795-
if isinstance(cause, RetryExceptionGroup):
796-
return_statuses[error.index] = self._get_status(
797-
cause.exceptions[-1]
798-
)
799-
else:
800-
return_statuses[error.index] = self._get_status(cause)
801-
802-
return return_statuses
803-
804-
@staticmethod
805-
def _get_status(error):
806-
if isinstance(error, GoogleAPICallError) and error.grpc_status_code is not None:
807-
return status_pb2.Status(
808-
code=error.grpc_status_code.value[0],
809-
message=error.message,
810-
details=error.details,
781+
return_statuses = _get_statuses_from_mutations_exception_group(
782+
mut_exc_group, len(mutation_entries)
783+
)
784+
else:
785+
return_statuses = [status_pb2.Status(code=code_pb2.Code.OK)] * len(
786+
mutation_entries
811787
)
812788

813-
return status_pb2.Status(
814-
code=code_pb2.Code.UNKNOWN,
815-
message=str(error),
816-
)
789+
return return_statuses
817790

818791
def sample_row_keys(self):
819792
"""Read a sample of row keys in the table.

tests/system/data/test_system_async.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -482,6 +482,42 @@ async def test_mutations_batcher_timer_flush(self, client, target, temp_rows):
482482
# ensure cell is updated
483483
assert (await self._retrieve_cell_value(target, row_key)) == new_value
484484

485+
@pytest.mark.usefixtures("client")
486+
@pytest.mark.usefixtures("target")
487+
@CrossSync.Retry(
488+
predicate=retry.if_exception_type(ClientError), initial=1, maximum=5
489+
)
490+
@CrossSync.pytest
491+
async def test_mutations_batcher_completed_callback(
492+
self, client, target, temp_rows
493+
):
494+
"""
495+
test batcher with batch completed callback. It should be called when the batcher flushes.
496+
"""
497+
from google.cloud.bigtable.data.mutations import RowMutationEntry
498+
from google.rpc import code_pb2, status_pb2
499+
500+
import mock
501+
502+
callback = mock.Mock()
503+
504+
new_value = uuid.uuid4().hex.encode()
505+
row_key, mutation = await self._create_row_and_mutation(
506+
target, temp_rows, new_value=new_value
507+
)
508+
bulk_mutation = RowMutationEntry(row_key, [mutation])
509+
flush_interval = 0.1
510+
async with target.mutations_batcher(flush_interval=flush_interval) as batcher:
511+
batcher._user_batch_completed_callback = callback
512+
await batcher.append(bulk_mutation)
513+
await CrossSync.yield_to_event_loop()
514+
assert len(batcher._staged_entries) == 1
515+
await CrossSync.sleep(flush_interval + 0.1)
516+
assert len(batcher._staged_entries) == 0
517+
callback.assert_called_once_with([status_pb2.Status(code=code_pb2.OK)])
518+
# ensure cell is updated
519+
assert (await self._retrieve_cell_value(target, row_key)) == new_value
520+
485521
@pytest.mark.usefixtures("client")
486522
@pytest.mark.usefixtures("target")
487523
@CrossSync.Retry(

tests/system/data/test_system_autogen.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -385,6 +385,34 @@ def test_mutations_batcher_timer_flush(self, client, target, temp_rows):
385385
assert len(batcher._staged_entries) == 0
386386
assert self._retrieve_cell_value(target, row_key) == new_value
387387

388+
@pytest.mark.usefixtures("client")
389+
@pytest.mark.usefixtures("target")
390+
@CrossSync._Sync_Impl.Retry(
391+
predicate=retry.if_exception_type(ClientError), initial=1, maximum=5
392+
)
393+
def test_mutations_batcher_completed_callback(self, client, target, temp_rows):
394+
"""test batcher with batch completed callback. It should be called when the batcher flushes."""
395+
from google.cloud.bigtable.data.mutations import RowMutationEntry
396+
from google.rpc import code_pb2, status_pb2
397+
import mock
398+
399+
callback = mock.Mock()
400+
new_value = uuid.uuid4().hex.encode()
401+
(row_key, mutation) = self._create_row_and_mutation(
402+
target, temp_rows, new_value=new_value
403+
)
404+
bulk_mutation = RowMutationEntry(row_key, [mutation])
405+
flush_interval = 0.1
406+
with target.mutations_batcher(flush_interval=flush_interval) as batcher:
407+
batcher._user_batch_completed_callback = callback
408+
batcher.append(bulk_mutation)
409+
CrossSync._Sync_Impl.yield_to_event_loop()
410+
assert len(batcher._staged_entries) == 1
411+
CrossSync._Sync_Impl.sleep(flush_interval + 0.1)
412+
assert len(batcher._staged_entries) == 0
413+
callback.assert_called_once_with([status_pb2.Status(code=code_pb2.OK)])
414+
assert self._retrieve_cell_value(target, row_key) == new_value
415+
388416
@pytest.mark.usefixtures("client")
389417
@pytest.mark.usefixtures("target")
390418
@CrossSync._Sync_Impl.Retry(

0 commit comments

Comments
 (0)