Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
283 changes: 35 additions & 248 deletions google/cloud/bigtable/batcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,12 @@
# limitations under the License.

"""User friendly container for Google Cloud Bigtable MutationBatcher."""
import threading
import queue
import concurrent.futures
import atexit


from google.api_core.exceptions import from_grpc_status
from dataclasses import dataclass
from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup
from google.cloud.bigtable.data.mutations import RowMutationEntry


FLUSH_COUNT = 100 # after this many elements, send out the batch
Expand All @@ -41,131 +39,6 @@ def __init__(self, message, exc):
super().__init__(self.message)


class _MutationsBatchQueue(object):
    """Thread-safe FIFO of rows awaiting batching, with running size totals.

    Keeps an aggregate count of mutations and total serialized byte size of
    everything currently enqueued so callers can decide when to flush.
    """

    def __init__(self, max_mutation_bytes=MAX_MUTATION_SIZE, flush_count=FLUSH_COUNT):
        """Set up an empty queue and record the flush thresholds."""
        self._queue = queue.Queue()
        # Running totals across every row currently enqueued.
        self.total_mutation_count = 0
        self.total_size = 0
        self.max_mutation_bytes = max_mutation_bytes
        self.flush_count = flush_count

    def get(self):
        """Pop one row and deduct its footprint from the running totals.

        Returns ``None`` (instead of raising) when the queue is empty.
        """
        try:
            popped = self._queue.get_nowait()
        except queue.Empty:
            return None
        self.total_mutation_count -= len(popped._get_mutations())
        self.total_size -= popped.get_mutations_size()
        return popped

    def put(self, item):
        """Enqueue a row and add its footprint to the running totals."""
        self._queue.put(item)
        self.total_mutation_count += len(item._get_mutations())
        self.total_size += item.get_mutations_size()

    def full(self):
        """Return True once either flush threshold has been reached."""
        return (
            self.total_mutation_count >= self.flush_count
            or self.total_size >= self.max_mutation_bytes
        )


@dataclass
class _BatchInfo:
"""Keeping track of size of a batch"""

mutations_count: int = 0
rows_count: int = 0
mutations_size: int = 0


class _FlowControl(object):
    """Throttle for in-flight mutation requests.

    Tracks how many mutations (and how many bytes) are currently being
    processed by the backend.  While either total is at or above its limit
    the internal event flag is cleared, so producers blocked in :meth:`wait`
    sleep until :meth:`release` frees enough capacity.
    """

    def __init__(
        self,
        max_mutations=MAX_OUTSTANDING_ELEMENTS,
        max_mutation_bytes=MAX_OUTSTANDING_BYTES,
    ):
        """Record the thresholds and start in the unblocked state."""
        self.max_mutations = max_mutations
        self.max_mutation_bytes = max_mutation_bytes
        self.inflight_mutations = 0
        self.inflight_size = 0
        # Set == flow open; cleared == producers must wait.
        self.event = threading.Event()
        self.event.set()
        self._lock = threading.Lock()

    def is_blocked(self):
        """Return True when either in-flight total has reached its limit."""
        over_count = self.inflight_mutations >= self.max_mutations
        over_bytes = self.inflight_size >= self.max_mutation_bytes
        return over_count or over_bytes

    def control_flow(self, batch_info):
        """Charge a batch's resources against the in-flight totals."""
        with self._lock:
            self.inflight_mutations += batch_info.mutations_count
            self.inflight_size += batch_info.mutations_size
            self.set_flow_control_status()

    def wait(self):
        """Block until flow-control pushback (if any) has been released."""
        self.event.wait()

    def set_flow_control_status(self):
        """Refresh the event flag from the current in-flight totals."""
        if self.is_blocked():
            # Over a limit: put waiting producers to sleep.
            self.event.clear()
        else:
            # Capacity available: wake any waiting producers.
            self.event.set()

    def release(self, batch_info):
        """Return a finished batch's resources and possibly reopen the flow."""
        with self._lock:
            self.inflight_mutations -= batch_info.mutations_count
            self.inflight_size -= batch_info.mutations_size
            self.set_flow_control_status()


class MutationsBatcher(object):
"""A MutationsBatcher is used in batch cases where the number of mutations
is large or unknown. It will store :class:`DirectRow` in memory until one of the
Expand Down Expand Up @@ -214,29 +87,41 @@ def __init__(
flush_interval=1,
batch_completed_callback=None,
):
self._rows = _MutationsBatchQueue(
max_mutation_bytes=max_row_bytes, flush_count=flush_count
)
self.table = table
self._executor = concurrent.futures.ThreadPoolExecutor()
atexit.register(self.close)
self._timer = threading.Timer(flush_interval, self.flush)
self._timer.start()
self.flow_control = _FlowControl(
max_mutations=MAX_OUTSTANDING_ELEMENTS,
max_mutation_bytes=MAX_OUTSTANDING_BYTES,
)
self.futures_mapping = {}
self.exceptions = queue.Queue()
self._flush_count = flush_count
self._max_row_bytes = max_row_bytes
self._flush_interval = flush_interval
self._user_batch_completed_callback = batch_completed_callback
self._init_batcher()
atexit.register(self.close)
self._exceptions = queue.Queue()

@property
def flush_count(self):
return self._rows.flush_count
return self._flush_count

@property
def max_row_bytes(self):
return self._rows.max_mutation_bytes
return self._max_row_bytes

def _init_batcher(self):
    """Create a fresh data-client batcher configured with this shim's limits.

    All actual batching is delegated to the data client's batcher; a new
    one is created here and again after every flush/close cycle.
    """
    table_impl = self.table._table_impl
    new_batcher = table_impl.mutations_batcher(
        flush_interval=self._flush_interval,
        flush_limit_mutation_count=self._flush_count,
        flush_limit_bytes=self._max_row_bytes,
    )
    # NOTE(review): this reaches into a private attribute of the data
    # client's batcher — fragile if that internal name ever changes.
    new_batcher._user_batch_completed_callback = self._user_batch_completed_callback
    self._batcher = new_batcher
Comment on lines +113 to +115
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Accessing the private attribute _user_batch_completed_callback of the underlying _batcher can introduce fragility. If the internal implementation of the data client's batcher changes this private attribute, it could break this shim. Consider if there's a public API or a more robust way to pass this callback to the underlying batcher, or add a comment explaining this design choice and its implications for future maintenance.


def _close_batcher(self):
    """Close the underlying batcher, capturing its failures.

    A ``MutationsExceptionGroup`` raised on close is unpacked and each
    entry's underlying cause is queued on ``self._exceptions`` so the
    legacy surface can report them later instead of raising here.
    """
    try:
        self._batcher.close()
    except MutationsExceptionGroup as exc_group:
        # Surface the cause of each FailedMutationEntryError, which is
        # closer to what callers of the legacy API expect to see.
        causes = (err.__cause__ for err in exc_group.exceptions)
        for cause in causes:
            self._exceptions.put(cause)

def __enter__(self):
"""Starting the MutationsBatcher as a context manager"""
Expand All @@ -260,10 +145,7 @@ def mutate(self, row):
* :exc:`~.table._BigtableRetryableError` if any row returned a transient error.
* :exc:`RuntimeError` if the number of responses doesn't match the number of rows that were retried
"""
self._rows.put(row)

if self._rows.full():
self._flush_async()
self._batcher.append(RowMutationEntry(row.row_key, row._get_mutations()))

def mutate_rows(self, rows):
"""Add multiple rows to the batch. If the current batch meets one of the size
Expand Down Expand Up @@ -298,102 +180,8 @@ def flush(self):
:raises:
* :exc:`.batcherMutationsBatchError` if there's any error in the mutations.
"""
rows_to_flush = []
row = self._rows.get()
while row is not None:
rows_to_flush.append(row)
row = self._rows.get()
response = self._flush_rows(rows_to_flush)
return response

def _flush_async(self):
    """Sends the current batch to Cloud Bigtable asynchronously.

    Drains the internal row queue, packing rows into batches that respect
    the flush-count, byte-size, and flow-control limits, and submits each
    batch to the thread-pool executor.

    :raises:
        * :exc:`.batcherMutationsBatchError` if there's any error in the mutations.
    """
    next_row = self._rows.get()
    while next_row is not None:
        # start a new batch
        rows_to_flush = [next_row]
        batch_info = _BatchInfo(
            mutations_count=len(next_row._get_mutations()),
            rows_count=1,
            mutations_size=next_row.get_mutations_size(),
        )
        # fill up batch with rows
        next_row = self._rows.get()
        while next_row is not None and self._row_fits_in_batch(
            next_row, batch_info
        ):
            rows_to_flush.append(next_row)
            batch_info.mutations_count += len(next_row._get_mutations())
            batch_info.rows_count += 1
            batch_info.mutations_size += next_row.get_mutations_size()
            next_row = self._rows.get()
        # NOTE: the row that did NOT fit (if any) is still held in
        # `next_row` and seeds the next batch on the following iteration.
        # send batch over network
        # wait for resources to become available
        self.flow_control.wait()
        # once unblocked, submit the batch
        # event flag will be set by control_flow to block subsequent thread, but not blocking this one
        self.flow_control.control_flow(batch_info)
        future = self._executor.submit(self._flush_rows, rows_to_flush)
        # schedule release of resources from flow control
        self.futures_mapping[future] = batch_info
        future.add_done_callback(self._batch_completed_callback)

def _batch_completed_callback(self, future):
    """Clean up after a finished batch future.

    Releases the resources the batch had reserved from the flow controller
    (waking any blocked producers) and drops the bookkeeping entry for the
    completed future.
    """
    batch_info = self.futures_mapping[future]
    self.flow_control.release(batch_info)
    del self.futures_mapping[future]

def _row_fits_in_batch(self, row, batch_info):
    """Check whether *row* can join the current batch without exceeding
    any batching or flow-control threshold.

    :type row: class
    :param row: :class:`~google.cloud.bigtable.row.DirectRow`.

    :type batch_info: :class:`_BatchInfo`
    :param batch_info: Running totals for the batch under construction.

    :rtype: bool
    :returns: True if the row can fit in the current batch.
    """
    # Project the batch totals as if the row were added.
    projected_rows = batch_info.rows_count + 1
    projected_mutations = batch_info.mutations_count + len(row._get_mutations())
    projected_size = batch_info.mutations_size + row.get_mutations_size()
    if projected_rows > self.flush_count:
        return False
    if projected_size > self.max_row_bytes:
        return False
    if projected_mutations > self.flow_control.max_mutations:
        return False
    return projected_size <= self.flow_control.max_mutation_bytes

def _flush_rows(self, rows_to_flush):
    """Mutate the specified rows and collect the per-row results.

    Invokes the user's batch-completed callback (if any) with the raw
    response, converts every non-zero gRPC status into an exception queued
    on ``self.exceptions``, and returns the list of per-row results.

    :raises:
        * :exc:`.batcherMutationsBatchError` if there's any error in the mutations.
    """
    if not rows_to_flush:
        return []
    response = self.table.mutate_rows(rows_to_flush)
    if self._user_batch_completed_callback:
        self._user_batch_completed_callback(response)
    responses = []
    for result in response:
        # Status code 0 == OK; anything else becomes a queued exception.
        if result.code != 0:
            self.exceptions.put(from_grpc_status(result.code, result.message))
        responses.append(result)
    return responses
self._close_batcher()
self._init_batcher()

def __exit__(self, exc_type, exc_value, exc_traceback):
"""Clean up resources. Flush and shutdown the ThreadPoolExecutor."""
Expand All @@ -406,9 +194,8 @@ def close(self):
:raises:
* :exc:`.batcherMutationsBatchError` if there's any error in the mutations.
"""
self.flush()
self._executor.shutdown(wait=True)
self._close_batcher()
atexit.unregister(self.close)
if self.exceptions.qsize() > 0:
exc = list(self.exceptions.queue)
if self._exceptions.qsize() > 0:
exc = list(self._exceptions.queue)
raise MutationsBatchError("Errors in batch mutations.", exc=exc)
2 changes: 0 additions & 2 deletions google/cloud/bigtable/data/_async/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from __future__ import annotations

from typing import (
Callable,
cast,
Any,
AsyncIterable,
Expand Down Expand Up @@ -116,7 +115,6 @@
if TYPE_CHECKING:
from google.cloud.bigtable.data._helpers import RowKeySamples
from google.cloud.bigtable.data._helpers import ShardedQuery
from google.rpc import status_pb2

if CrossSync.is_async:
from google.cloud.bigtable.data._async.mutations_batcher import (
Expand Down
4 changes: 3 additions & 1 deletion google/cloud/bigtable/data/_async/mutations_batcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,9 @@ def __init__(
self._newest_exceptions: deque[Exception] = deque(
maxlen=self._exception_list_limit
)
self._user_batch_completed_callback = None
self._user_batch_completed_callback: Optional[
Callable[[list[status_pb2.Status]], None]
] = None
# clean up on program exit
atexit.register(self._on_exit)

Expand Down
21 changes: 12 additions & 9 deletions google/cloud/bigtable/data/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
"""
from __future__ import annotations

from typing import Callable, Sequence, List, Optional, Tuple, TYPE_CHECKING, Union
from typing import cast, Callable, Sequence, List, Optional, Tuple, TYPE_CHECKING, Union
import time
import enum
from collections import namedtuple
Expand Down Expand Up @@ -272,14 +272,17 @@ def _get_status(exc: Optional[Exception]) -> status_pb2.Status:
Returns:
status_pb2.Status: A Status proto object.
"""
if (
isinstance(exc, core_exceptions.GoogleAPICallError)
and exc.grpc_status_code is not None
):
return status_pb2.Status( # type: ignore[unreachable]
code=exc.grpc_status_code.value[0],
message=exc.message,
details=exc.details,
if isinstance(exc, core_exceptions.GoogleAPICallError):
status_code = cast(Optional["grpc.StatusCode"], exc.grpc_status_code)
if status_code is not None:
return status_pb2.Status(
code=status_code.value[0],
message=exc.message,
details=exc.details,
)
return status_pb2.Status(
code=code_pb2.Code.UNKNOWN,
message="An unknown error has occurred",
)

return status_pb2.Status(
Expand Down
Loading
Loading