diff --git a/blobid/__init__.py b/blobid/__init__.py index 80e88e8..f13abce 100644 --- a/blobid/__init__.py +++ b/blobid/__init__.py @@ -1,3 +1,4 @@ from .solver import get_labels +from . import labeling -__all__ = ['get_labels',] +__all__ = ['get_labels', 'labeling'] diff --git a/blobid/utils/domain.py b/blobid/_domain.py similarity index 100% rename from blobid/utils/domain.py rename to blobid/_domain.py diff --git a/blobid/utils/numba_support.py b/blobid/_numba_support.py similarity index 74% rename from blobid/utils/numba_support.py rename to blobid/_numba_support.py index 764e471..e3d32d3 100644 --- a/blobid/utils/numba_support.py +++ b/blobid/_numba_support.py @@ -1,10 +1,10 @@ """Setup Numba if available""" -_numba_availible = True +numba_availible = True try: from numba import njit except ImportError: - _numba_availible = False + numba_availible = False # create a dummy njit decorator def njit(func): return func diff --git a/blobid/labeling/__init__.py b/blobid/labeling/__init__.py new file mode 100644 index 0000000..e17d2eb --- /dev/null +++ b/blobid/labeling/__init__.py @@ -0,0 +1,3 @@ +from .labeling import apply_ccl + +__all__ = ['apply_ccl'] diff --git a/blobid/utils/ccl.py b/blobid/labeling/ccl.py similarity index 74% rename from blobid/utils/ccl.py rename to blobid/labeling/ccl.py index 14caff2..c094ba8 100644 --- a/blobid/utils/ccl.py +++ b/blobid/labeling/ccl.py @@ -1,11 +1,10 @@ """ Tools for running connected component labeling """ -from typing import Tuple import numpy as np -from blobid.utils.numba_support import njit -from blobid.utils.labeling import LabelDatabase, _DatabaseStorage, _merge +from .._numba_support import njit +from .database import LabelDatabase, _DatabaseStorage, _merge def get_temporary_labels( @@ -39,28 +38,6 @@ def get_temporary_labels( return (labels, label_database) -def stitch_boundaries( - labels: np.ndarray, - sets: LabelDatabase, - periodic: Tuple[bool, bool, bool] - ) -> LabelDatabase: - """Stitch together periodic boundaries and remove padding""" - - if periodic[0]: - for a, b in zip(labels[0, :, :].flat, labels[-2, :, :].flat): - sets.merge(a, b) - - if periodic[1]: - for a, b in zip(labels[:, 0, :].flat, labels[:, -2, :].flat): - sets.merge(a, b) - - if periodic[2]: - for a, b in zip(labels[:, :, 0].flat, labels[:, :, -2].flat): - sets.merge(a, b) - - return sets - - @njit def _tail_pass( labels: np.ndarray, diff --git a/blobid/utils/labeling.py b/blobid/labeling/database.py similarity index 98% rename from blobid/utils/labeling.py rename to blobid/labeling/database.py index 03de62c..27b7a52 100644 --- a/blobid/utils/labeling.py +++ b/blobid/labeling/database.py @@ -5,7 +5,7 @@ import numpy as np -from blobid.utils.numba_support import njit +from .._numba_support import njit _DatabaseStorage = namedtuple('_DatabaseStorage', 'r, n, t') _END_OF_SET = 0 diff --git a/blobid/labeling/labeling.py b/blobid/labeling/labeling.py new file mode 100644 index 0000000..18f462f --- /dev/null +++ b/blobid/labeling/labeling.py @@ -0,0 +1,76 @@ +from typing import Tuple + +import numpy as np + +from .database import LabelDatabase +from .ccl import get_temporary_labels + + +def apply_ccl( + is_object: np.ndarray, + is_connected: np.ndarray, + periodic: Tuple[bool, bool, bool], + label_type +) -> np.ndarray: + r""" + Given connectedness, calculate unique labels for each connected region of object cells. + + Parameters + ---------- + is_object : ndarray[ni, nj, nk] + `is_object[i, j, k]` is true if cell `[i,j,k]` is an object cell + is_connected : ndarray[ni, nj, nk, 3] + `is_connected[i, j, k, d]` is true if cell `[i,j,k]` us connected to the neighbor in the negative d direction. + For example, `is_connected[i, j, k, 0]` is true if cell `[i,j,k]` is connected to cell `[i-1,j,k]` + periodic: (bool, bool, bool) + If `periodic[d]`, then cells on each edge in direction `d` are considered the same cell. + For example, cell `[0,j,k]` is the same as cell `[-1,j,k]` if `periodic[0]` is true. + label_type : dtype + Determines the integer data-type of `labels` + + Returns + ------- + labels : ndarray[ni, nj, nk] + An array of type `label_type` with the same shape as `is_object`. + + """ + # checks + assert is_object.ndim == 3 + assert is_connected.ndim == 4 + assert np.all(is_object.shape == is_connected.shape[:-1]) + assert not np.any(is_connected[0, :, :, 0]) + assert not np.any(is_connected[:, 0, :, 1]) + assert not np.any(is_connected[:, :, 0, 2]) + + # do initial labeling + (labels, label_database) = get_temporary_labels(is_object, is_connected, label_type) + + # stitch together periodic boundaries + label_database = _stitch_boundaries(labels, label_database, periodic) + + # do final labeling with sequential labels + labels = label_database.get_sequential_lookup_table()[labels] + + return labels + + +def _stitch_boundaries( + labels: np.ndarray, + sets: LabelDatabase, + periodic: Tuple[bool, bool, bool] + ) -> LabelDatabase: + """Stitch together periodic boundaries and remove padding""" + + if periodic[0]: + for a, b in zip(labels[0, :, :].flat, labels[-2, :, :].flat): + sets.merge(a, b) + + if periodic[1]: + for a, b in zip(labels[:, 0, :].flat, labels[:, -2, :].flat): + sets.merge(a, b) + + if periodic[2]: + for a, b in zip(labels[:, :, 0].flat, labels[:, :, -2].flat): + sets.merge(a, b) + + return sets diff --git a/blobid/reconstruction/__init__.py b/blobid/reconstruction/__init__.py new file mode 100644 index 0000000..bd984a5 --- /dev/null +++ b/blobid/reconstruction/__init__.py @@ -0,0 +1,3 @@ +from .reconstruction import normals + +__all__ = ['normals'] diff --git a/blobid/utils/reconstruction.py b/blobid/reconstruction/reconstruction.py similarity index 96% rename from blobid/utils/reconstruction.py rename to blobid/reconstruction/reconstruction.py index fb99302..46269a8 100644 --- a/blobid/utils/reconstruction.py +++ b/blobid/reconstruction/reconstruction.py @@ -4,7 +4,7 @@ import warnings import numpy as np -from blobid.utils.numba_support import njit, _numba_availible +from .._numba_support import njit, numba_availible NORMALS_TYPE = np.int8 @@ -38,7 +38,7 @@ def normals_CD(f: np.ndarray) -> np.ndarray: def normals_WY(f: np.ndarray) -> np.ndarray: # Numba only supports float32 and float64 - if _numba_availible and (f.dtype not in (np.float32, np.float64)): + if numba_availible and (f.dtype not in (np.float32, np.float64)): warnings.warn(f"void_fraction type {f.dtype.name} not supported, using float32") f = f.astype(np.float32) diff --git a/blobid/solver.py b/blobid/solver.py index f9ed2a9..662e436 100644 --- a/blobid/solver.py +++ b/blobid/solver.py @@ -5,10 +5,10 @@ import numpy as np -from blobid.utils.domain import VOFDomain +from ._domain import VOFDomain -from blobid.utils.ccl import get_temporary_labels, stitch_boundaries -from blobid.utils.reconstruction import normals +from . import labeling +from . import reconstruction def get_labels( @@ -155,21 +155,24 @@ def get_labels( periodic_padding=1, extra_padding=1 if (use_normals or (cutoff_method == 'neighbors')) else 0) - # calculate object cells, removing the extra padding + # calculate object cells is_object = _calc_object_cells(domain, cutoff, cutoff_method) # calculate connectivity - is_connected = _calc_connections(is_object, - norm=normals(domain.vof(padding=1), normals_method) if use_normals else None) - - # do initial labeling - (labels, label_database) = get_temporary_labels(is_object, is_connected, label_type) - - # stitch together periodic boundaries - label_database = stitch_boundaries(labels, label_database, domain.periodic) - - # do final labeling with sequential labels - labels = label_database.get_sequential_lookup_table()[labels] + if use_normals: + normals = reconstruction.normals(domain.vof(padding=1), normals_method) + else: + normals = None + + is_connected = _calc_connections(is_object, norm=normals) + + # do the labeling + labels = labeling.apply_ccl( + is_object=is_object, + is_connected=is_connected, + periodic=domain.periodic, + label_type=label_type + ) # reshape to original dimensions (removes padding in periodic directions) return domain.convert_to_original_shape(labels) diff --git a/tests/test_labeling.py b/tests/labeling/test_database.py similarity index 98% rename from tests/test_labeling.py rename to tests/labeling/test_database.py index 7a1c2c5..31f4dcb 100644 --- a/tests/test_labeling.py +++ b/tests/labeling/test_database.py @@ -1,7 +1,7 @@ import pytest as pt import numpy as np -from blobid.utils.labeling import LabelDatabase +from blobid.labeling.database import LabelDatabase def test_label_sets(): diff --git a/tests/test_reconstruction.py b/tests/reconstruction/test_reconstruction.py similarity index 91% rename from tests/test_reconstruction.py rename to tests/reconstruction/test_reconstruction.py index abddef1..d1f03fc 100644 --- a/tests/test_reconstruction.py +++ b/tests/reconstruction/test_reconstruction.py @@ -3,8 +3,8 @@ import pytest import numpy as np -from blobid.utils.numba_support import _numba_availible -from blobid.utils.reconstruction import normals, NORMALS_TYPE, _get_dominant_direction +from blobid._numba_support import numba_availible +from blobid.reconstruction.reconstruction import normals, NORMALS_TYPE, _get_dominant_direction def test_reconstruction_CD(): @@ -71,7 +71,7 @@ def test_vof_type(): f = np.random.rand(3, 3, 3) # float16 should give a warning - if _numba_availible: + if numba_availible: with pytest.warns(): normals(f.astype(np.float16), 'WY') diff --git a/tests/test_boundaries.py b/tests/test_boundaries.py index 2e1b099..3f4c47a 100644 --- a/tests/test_boundaries.py +++ b/tests/test_boundaries.py @@ -1,10 +1,10 @@ import pytest import numpy as np -from blobid.utils.ccl import stitch_boundaries -from blobid.utils.domain import _pad_array -from blobid.utils.labeling import LabelDatabase -from blobid.utils.reconstruction import normals +from blobid._domain import _pad_array +from blobid.labeling.database import LabelDatabase +from blobid.labeling.labeling import _stitch_boundaries +from blobid.reconstruction import normals @pytest.fixture @@ -122,7 +122,7 @@ def check_periodicity(labels, sets, periodicity): for a, b in zip(labels[:, :, -1].flat, labels[:, :, -2].flat): sets.merge(a, b) - sets = stitch_boundaries(labels, sets, per) + sets = _stitch_boundaries(labels, sets, per) # make sure the size is right assert labels.shape[0] == 5 + 2 * per[0]