diff --git a/blobid/solver.py b/blobid/solver.py index 2715987..f9ed2a9 100644 --- a/blobid/solver.py +++ b/blobid/solver.py @@ -5,9 +5,10 @@ import numpy as np -from blobid.utils.ccl import get_temporary_labels +from blobid.utils.domain import VOFDomain + +from blobid.utils.ccl import get_temporary_labels, stitch_boundaries from blobid.utils.reconstruction import normals -from blobid.utils.boundaries import pad_array, stitch_boundaries def get_labels( @@ -145,58 +146,43 @@ def get_labels( """ # Defaults - if periodic is None: - periodic = [False] * void_fraction.ndim if label_type is None: label_type = np.uint32 - # Checks - assert len(periodic) == void_fraction.ndim - - # copy void fraction and convert to 3D - vof = np.atleast_3d(void_fraction.copy()) - - # Update periodicity with how numpy adds dimensions - match void_fraction.ndim: - case 1: periodic_3d = (False, periodic[0], False) - case 2: periodic_3d = (periodic[0], periodic[1], False) - case 3: periodic_3d = (periodic[0], periodic[1], periodic[2]) - case _: - raise ValueError("Unexpected void_fraction.ndim: " + str(void_fraction.ndim)) - - # add padding for boundary conditions - extra_padding = use_normals or (cutoff_method == 'neighbors') - vof = pad_array(vof, periodic_3d, extra=extra_padding) + # Setup to domain + domain = VOFDomain(void_fraction=void_fraction, + periodic=periodic if periodic is not None else [False] * void_fraction.ndim, + periodic_padding=1, + extra_padding=1 if (use_normals or (cutoff_method == 'neighbors')) else 0) # calculate object cells, removing the extra padding - is_object = _calc_object_cells(vof, cutoff, cutoff_method) - if extra_padding: - is_object = is_object[1:-1, 1:-1, 1:-1] + is_object = _calc_object_cells(domain, cutoff, cutoff_method) # calculate connectivity - is_connected = _calc_connections(is_object, norm=(normals(vof, normals_method) if use_normals else None)) + is_connected = _calc_connections(is_object, + norm=normals(domain.vof(padding=1), normals_method) if use_normals else None) # do initial labeling (labels, label_database) = get_temporary_labels(is_object, is_connected, label_type) - # stitch together periodic boundaries (removes padding from labels) - (labels, label_database) = stitch_boundaries(labels, label_database, periodic_3d) + # stitch together periodic boundaries + label_database = stitch_boundaries(labels, label_database, domain.periodic) # do final labeling with sequential labels labels = label_database.get_sequential_lookup_table()[labels] - # reshape to original dimensions - return labels.reshape(void_fraction.shape) + # reshape to original dimensions (removes padding in periodic directions) + return domain.convert_to_original_shape(labels) -def _calc_object_cells(vof: np.ndarray, cutoff: float, cutoff_method: str) -> np.ndarray: +def _calc_object_cells(domain: VOFDomain, cutoff: float, cutoff_method: str) -> np.ndarray: """returns an array that is true if cell is an object cell""" match cutoff_method: case 'local': - return vof > cutoff + return domain.vof() > cutoff case 'neighbors': - large = np.pad(vof > cutoff, 1, constant_values=False) + large = domain.vof(padding=1) > cutoff large_neighbor = np.logical_or.reduce([ large[1:-1, 1:-1, 1:-1], # center @@ -208,7 +194,7 @@ def _calc_object_cells(vof: np.ndarray, cutoff: float, cutoff_method: str) -> np large[1:-1, 1:-1, :-2], # k-1 ]) - return np.logical_and(large_neighbor, vof > 0) + return np.logical_and(large_neighbor, domain.vof() > 0) case _: raise ValueError(f"cutoff_method '{cutoff_method}' is not supported") diff --git a/blobid/utils/boundaries.py b/blobid/utils/boundaries.py deleted file mode 100644 index 2d5de83..0000000 --- a/blobid/utils/boundaries.py +++ /dev/null @@ -1,50 +0,0 @@ -""" -Routines for handling periodic and symmetric boundary conditions -""" - -from typing import Tuple -import numpy as np - -from blobid.utils.labeling import LabelDatabase - - -def pad_array( - arr: np.ndarray, - periodic: Tuple[bool, bool, bool], - extra: int = 0 - ) -> np.ndarray: - """Add appropriate padding to the array""" - - # padding in periodic directions - out = np.pad(arr, [(1+extra, 1+extra) if p else (0, 0) for p in periodic], 'wrap') - - # padding in non-periodic directions - if extra != 0: - out = np.pad(out, [(0, 0) if p else (int(extra), int(extra)) for p in periodic], 'symmetric') - - return out - - -def stitch_boundaries( - labels: np.ndarray, - sets, - periodic: Tuple[bool, bool, bool] - ) -> tuple[np.ndarray, LabelDatabase]: - """Stitch together periodic boundaries and remove padding""" - - if periodic[0]: - for a, b in zip(labels[0, :, :].flat, labels[-2, :, :].flat): - sets.merge(a, b) - labels = labels[1:-1, :, :] - - if periodic[1]: - for a, b in zip(labels[:, 0, :].flat, labels[:, -2, :].flat): - sets.merge(a, b) - labels = labels[:, 1:-1, :] - - if periodic[2]: - for a, b in zip(labels[:, :, 0].flat, labels[:, :, -2].flat): - sets.merge(a, b) - labels = labels[:, :, 1:-1] - - return (labels, sets) diff --git a/blobid/utils/ccl.py b/blobid/utils/ccl.py index 311c5ef..14caff2 100644 --- a/blobid/utils/ccl.py +++ b/blobid/utils/ccl.py @@ -1,6 +1,7 @@ """ Tools for running connected component labeling """ +from typing import Tuple import numpy as np from blobid.utils.numba_support import njit @@ -38,6 +39,28 @@ def get_temporary_labels( return (labels, label_database) +def stitch_boundaries( + labels: np.ndarray, + sets: LabelDatabase, + periodic: Tuple[bool, bool, bool] + ) -> LabelDatabase: + """Stitch together periodic boundaries and remove padding""" + + if periodic[0]: + for a, b in zip(labels[0, :, :].flat, labels[-2, :, :].flat): + sets.merge(a, b) + + if periodic[1]: + for a, b in zip(labels[:, 0, :].flat, labels[:, -2, :].flat): + sets.merge(a, b) + + if periodic[2]: + for a, b in zip(labels[:, :, 0].flat, labels[:, :, -2].flat): + sets.merge(a, b) + + return sets + + @njit def _tail_pass( labels: np.ndarray, diff --git a/blobid/utils/domain.py b/blobid/utils/domain.py new file mode 100644 index 0000000..abba78c --- /dev/null +++ b/blobid/utils/domain.py @@ -0,0 +1,85 @@ +""" +Holds information about the input domain +""" +from typing import List, Tuple + +import numpy as np + + +def _pad_array( + arr: np.ndarray, + periodic: Tuple[bool, bool, bool], + periodic_padding: int, + extra_padding: int + ) -> np.ndarray: + """Add appropriate padding to the array""" + + # padding in periodic directions + width = periodic_padding + extra_padding + if width != 0: + arr = np.pad(arr, [(width, width) if p else (0, 0) for p in periodic], 'wrap') + + # padding in non-periodic directions + width = extra_padding + if width != 0: + arr = np.pad(arr, [(0, 0) if p else (width, width) for p in periodic], 'symmetric') + + return arr + + +class VOFDomain: + """Holds the vof field and information about periodicity""" + + def __init__(self, + void_fraction: np.ndarray, + periodic: List[bool], + periodic_padding: int, + extra_padding: int + ): + # Checks + assert len(periodic) == void_fraction.ndim + + # Convert original VOF felid into a 3D field + self.original_shape = void_fraction.shape + match void_fraction.ndim: + case 1: self.periodic = (False, periodic[0], False) + case 2: self.periodic = (periodic[0], periodic[1], False) + case 3: self.periodic = (periodic[0], periodic[1], periodic[2]) + case _: + raise ValueError("Unexpected void_fraction.ndim: " + str(void_fraction.ndim)) + + self._vof_storage = np.atleast_3d(void_fraction.copy()) + + # Add padding + self.periodic_padding = periodic_padding + self.extra_padding = extra_padding + self._vof_storage = _pad_array(self._vof_storage, self.periodic, + periodic_padding=self.periodic_padding, + extra_padding=self.extra_padding) + + def vof(self, padding: int = 0) -> np.ndarray: + """ + Padding is `periodic_padding + padding` in periodic directions and `padding` in non-periodic directions + """ + skip = self.extra_padding - padding + assert skip >= 0, f"requested pad {padding} larger than domain's extra_padding {self.extra_padding}" + + if skip == 0: + return self._vof_storage + else: + return self._vof_storage[skip:-(skip), skip:-(skip), skip:-(skip)] + + def convert_to_original_shape(self, arr: np.ndarray) -> np.ndarray: + """ + Convert back to input shape of void_fraction + """ + # remove padding in periodic directions + if self.periodic_padding != 0: + if self.periodic[0]: + arr = arr[self.periodic_padding:-self.periodic_padding, :, :] + if self.periodic[1]: + arr = arr[:, self.periodic_padding:-self.periodic_padding, :] + if self.periodic[2]: + arr = arr[:, :, self.periodic_padding:-self.periodic_padding] + + return arr.reshape(self.original_shape) diff --git a/tests/test_boundaries.py b/tests/test_boundaries.py index 5db35fb..2e1b099 100644 --- a/tests/test_boundaries.py +++ b/tests/test_boundaries.py @@ -1,7 +1,8 @@ import pytest import numpy as np -from blobid.utils.boundaries import pad_array, stitch_boundaries +from blobid.utils.ccl import stitch_boundaries +from blobid.utils.domain import _pad_array from blobid.utils.labeling import LabelDatabase from blobid.utils.reconstruction import normals @@ -13,56 +14,57 @@ def vof_3d() -> np.ndarray: def check_periodicity(arr, periodicity, depth): - if periodicity[0]: + if periodicity[0] and depth != 0: assert np.all(arr[:depth, :, :] == arr[-2*depth:-depth, :, :]) assert np.all(arr[depth:2*depth, :, :] == arr[-depth:, :, :]) - if periodicity[1]: + if periodicity[1] and depth != 0: assert np.all(arr[:, :depth, :] == arr[:, -2*depth:-depth, :]) assert np.all(arr[:, depth:2*depth, :] == arr[:, -depth:, :]) - if periodicity[2]: + if periodicity[2] and depth != 0: assert np.all(arr[:, :, :depth] == arr[:, :, -2*depth:-depth]) assert np.all(arr[:, :, depth:2*depth] == arr[:, :, -depth:]) def check_symmetry(arr, symmetry, depth): - if symmetry[0]: + if symmetry[0] and depth != 0: assert np.all(arr[:depth, :, :] == np.flip(arr[depth:2*depth, :, :], axis=0)) assert np.all(arr[-depth:, :, :] == np.flip(arr[-2*depth:-depth, :, :], axis=0)) - if symmetry[1]: + if symmetry[1] and depth != 0: assert np.all(arr[:, :depth, :] == np.flip(arr[:, depth:2*depth, :], axis=1)) assert np.all(arr[:, -depth:, :] == np.flip(arr[:, -2*depth:-depth, :], axis=1)) - if symmetry[2]: + if symmetry[2] and depth != 0: assert np.all(arr[:, :, :depth] == np.flip(arr[:, :, depth:2*depth], axis=2)) assert np.all(arr[:, :, -depth:] == np.flip(arr[:, :, -2*depth:-depth], axis=2)) def test_pad_array(vof_3d): - for extra in range(3): - for n in range(8): - per = [bool(n % 2), bool((n//2) % 2), bool((n//4) % 2)] + for periodic_size in range(3): + for extra in range(3): + for n in range(8): + per = [bool(n % 2), bool((n//2) % 2), bool((n//4) % 2)] - arr = pad_array(vof_3d.copy(), per, extra) + arr = _pad_array(vof_3d.copy(), per, periodic_size, extra) - # make sure dimensions are right - for d in range(3): - assert arr.shape[d] == vof_3d.shape[d]+2*per[d]+2*extra + # make sure dimensions are right + for d in range(3): + assert arr.shape[d] == vof_3d.shape[d]+2*per[d]*periodic_size+2*extra - # make sure nothing else has changed - unchanged_range = arr[ - (per[0]+extra):arr.shape[0]-(per[0]+extra), - (per[1]+extra):arr.shape[1]-(per[1]+extra), - (per[2]+extra):arr.shape[2]-(per[2]+extra) - ] - assert np.all(unchanged_range == vof_3d) + # make sure nothing else has changed + unchanged_range = arr[ + (per[0]*periodic_size+extra):arr.shape[0]-(per[0]*periodic_size+extra), + (per[1]*periodic_size+extra):arr.shape[1]-(per[1]*periodic_size+extra), + (per[2]*periodic_size+extra):arr.shape[2]-(per[2]*periodic_size+extra) + ] + assert np.all(unchanged_range == vof_3d) - # check edges - check_periodicity(arr, per, extra+1) - if extra != 0: - check_symmetry(arr, [not p for p in per], extra) + # check edges + check_periodicity(arr, per, extra+periodic_size) + if extra != 0: + check_symmetry(arr, [not p for p in per], extra) def test_stitch_boundaries(): # noqa: C901 @@ -76,21 +78,21 @@ def create_label_field(n): def check_periodicity(labels, sets, periodicity): if periodicity[0]: - for a, b in zip(labels[0, :, :].flat, labels[-1, :, :].flat): + for a, b in zip(labels[1, :, :].flat, labels[-2, :, :].flat): assert sets.root(a) == sets.root(b) else: for a, b in zip(labels[0, :, :].flat, labels[-1, :, :].flat): assert sets.root(a) != sets.root(b) if periodicity[1]: - for a, b in zip(labels[:, 0, :].flat, labels[:, -1, :].flat): + for a, b in zip(labels[:, 1, :].flat, labels[:, -2, :].flat): assert sets.root(a) == sets.root(b) else: for a, b in zip(labels[:, 0, :].flat, labels[:, -1, :].flat): assert sets.root(a) != sets.root(b) if periodicity[2]: - for a, b in zip(labels[:, :, 0].flat, labels[:, :, -1].flat): + for a, b in zip(labels[:, :, 1].flat, labels[:, :, -2].flat): assert sets.root(a) == sets.root(b) else: for a, b in zip(labels[:, :, 0].flat, labels[:, :, -1].flat): @@ -120,12 +122,12 @@ def check_periodicity(labels, sets, periodicity): for a, b in zip(labels[:, :, -1].flat, labels[:, :, -2].flat): sets.merge(a, b) - (labels, sets) = stitch_boundaries(labels, sets, per) + sets = stitch_boundaries(labels, sets, per) # make sure the size is right - assert labels.shape[0] == 5 - assert labels.shape[1] == 3 - assert labels.shape[2] == 4 + assert labels.shape[0] == 5 + 2 * per[0] + assert labels.shape[1] == 3 + 2 * per[1] + assert labels.shape[2] == 4 + 2 * per[2] check_periodicity(labels, sets, per) @@ -133,7 +135,7 @@ def check_periodicity(labels, sets, periodicity): def test_normal_calculation_at_boundaries(vof_3d): for n in range(8): per = [bool(n % 2), bool((n//2) % 2), bool((n//4) % 2)] - f = pad_array(vof_3d.copy(), per, 1) + f = _pad_array(vof_3d.copy(), per, 1, 1) # normals should be the same on periodic sides n = normals(f, normals_method='CD')