Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion blobid/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .solver import get_labels
from . import labeling

__all__ = ['get_labels',]
__all__ = ['get_labels', 'labeling']
File renamed without changes.
4 changes: 2 additions & 2 deletions blobid/utils/numba_support.py → blobid/_numba_support.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
"""Setup Numba if available"""

_numba_availible = True
numba_availible = True
try:
from numba import njit
except ImportError:
_numba_availible = False
numba_availible = False

# create a dummy njit decorator
def njit(func): return func
3 changes: 3 additions & 0 deletions blobid/labeling/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .labeling import apply_ccl

__all__ = ['apply_ccl']
27 changes: 2 additions & 25 deletions blobid/utils/ccl.py → blobid/labeling/ccl.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
"""
Tools for running connected component labeling
"""
from typing import Tuple

import numpy as np
from blobid.utils.numba_support import njit
from blobid.utils.labeling import LabelDatabase, _DatabaseStorage, _merge
from .._numba_support import njit
from .database import LabelDatabase, _DatabaseStorage, _merge


def get_temporary_labels(
Expand Down Expand Up @@ -39,28 +38,6 @@ def get_temporary_labels(
return (labels, label_database)


def stitch_boundaries(
labels: np.ndarray,
sets: LabelDatabase,
periodic: Tuple[bool, bool, bool]
) -> LabelDatabase:
"""Stitch together periodic boundaries and remove padding"""

if periodic[0]:
for a, b in zip(labels[0, :, :].flat, labels[-2, :, :].flat):
sets.merge(a, b)

if periodic[1]:
for a, b in zip(labels[:, 0, :].flat, labels[:, -2, :].flat):
sets.merge(a, b)

if periodic[2]:
for a, b in zip(labels[:, :, 0].flat, labels[:, :, -2].flat):
sets.merge(a, b)

return sets


@njit
def _tail_pass(
labels: np.ndarray,
Expand Down
2 changes: 1 addition & 1 deletion blobid/utils/labeling.py → blobid/labeling/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import numpy as np

from blobid.utils.numba_support import njit
from .._numba_support import njit

_DatabaseStorage = namedtuple('_DatabaseStorage', 'r, n, t')
_END_OF_SET = 0
Expand Down
76 changes: 76 additions & 0 deletions blobid/labeling/labeling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
from typing import Tuple

import numpy as np

from .database import LabelDatabase
from .ccl import get_temporary_labels


def apply_ccl(
is_object: np.ndarray,
is_connected: np.ndarray,
periodic: Tuple[bool, bool, bool],
label_type
) -> np.ndarray:
r"""
Given connectedness, calculate unique labels for each connected region of object cells.

Parameters
----------
is_object : ndarray[ni, nj, nk]
`is_object[i, j, k]` is true if cell `[i,j,k]` is an object cell
is_connected : ndarray[ni, nj, nk, 3]
`is_connected[i, j, k, d]` is true if cell `[i,j,k]` us connected to the neighbor in the negative d direction.
For example, `is_connected[i, j, k, 0]` is true if cell `[i,j,k]` is connected to cell `[i-1,j,k]`
periodic: (bool, bool, bool)
If `periodic[d]`, then cells on each edge in direction `d` are considered the same cell.
For example, cell `[0,j,k]` is the same as cell `[-1,j,k]` if `periodic[0]` is true.
label_type : dtype
Determines the integer data-type of `labels`

Returns
-------
labels : ndarray[ni, nj, nk]
An array of type `label_type` with the same shape as `is_object`.

"""
# checks
assert is_object.ndim == 3
assert is_connected.ndim == 4
assert np.all(is_object.shape == is_connected.shape[:-1])
assert not np.any(is_connected[0, :, :, 0])
assert not np.any(is_connected[:, 0, :, 1])
assert not np.any(is_connected[:, :, 0, 2])

# do initial labeling
(labels, label_database) = get_temporary_labels(is_object, is_connected, label_type)

# stitch together periodic boundaries
label_database = _stitch_boundaries(labels, label_database, periodic)

# do final labeling with sequential labels
labels = label_database.get_sequential_lookup_table()[labels]

return labels


def _stitch_boundaries(
labels: np.ndarray,
sets: LabelDatabase,
periodic: Tuple[bool, bool, bool]
) -> LabelDatabase:
"""Stitch together periodic boundaries and remove padding"""

if periodic[0]:
for a, b in zip(labels[0, :, :].flat, labels[-2, :, :].flat):
sets.merge(a, b)

if periodic[1]:
for a, b in zip(labels[:, 0, :].flat, labels[:, -2, :].flat):
sets.merge(a, b)

if periodic[2]:
for a, b in zip(labels[:, :, 0].flat, labels[:, :, -2].flat):
sets.merge(a, b)

return sets
3 changes: 3 additions & 0 deletions blobid/reconstruction/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .reconstruction import normals

__all__ = ['normals']
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import warnings
import numpy as np

from blobid.utils.numba_support import njit, _numba_availible
from .._numba_support import njit, numba_availible

NORMALS_TYPE = np.int8

Expand Down Expand Up @@ -38,7 +38,7 @@ def normals_CD(f: np.ndarray) -> np.ndarray:

def normals_WY(f: np.ndarray) -> np.ndarray:
# Numba only supports float32 and float64
if _numba_availible and (f.dtype not in (np.float32, np.float64)):
if numba_availible and (f.dtype not in (np.float32, np.float64)):
warnings.warn(f"void_fraction type {f.dtype.name} not supported, using float32")
f = f.astype(np.float32)

Expand Down
33 changes: 18 additions & 15 deletions blobid/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@

import numpy as np

from blobid.utils.domain import VOFDomain
from ._domain import VOFDomain

from blobid.utils.ccl import get_temporary_labels, stitch_boundaries
from blobid.utils.reconstruction import normals
from . import labeling
from . import reconstruction


def get_labels(
Expand Down Expand Up @@ -155,21 +155,24 @@ def get_labels(
periodic_padding=1,
extra_padding=1 if (use_normals or (cutoff_method == 'neighbors')) else 0)

# calculate object cells, removing the extra padding
# calculate object cells
is_object = _calc_object_cells(domain, cutoff, cutoff_method)

# calculate connectivity
is_connected = _calc_connections(is_object,
norm=normals(domain.vof(padding=1), normals_method) if use_normals else None)

# do initial labeling
(labels, label_database) = get_temporary_labels(is_object, is_connected, label_type)

# stitch together periodic boundaries
label_database = stitch_boundaries(labels, label_database, domain.periodic)

# do final labeling with sequential labels
labels = label_database.get_sequential_lookup_table()[labels]
if use_normals:
normals = reconstruction.normals(domain.vof(padding=1), normals_method)
else:
normals = None

is_connected = _calc_connections(is_object, norm=normals)

# do the labeling
labels = labeling.apply_ccl(
is_object=is_object,
is_connected=is_connected,
periodic=domain.periodic,
label_type=label_type
)

# reshape to original dimensions (removes padding in periodic directions)
return domain.convert_to_original_shape(labels)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_labeling.py → tests/labeling/test_database.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest as pt
import numpy as np

from blobid.utils.labeling import LabelDatabase
from blobid.labeling.database import LabelDatabase


def test_label_sets():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import pytest
import numpy as np

from blobid.utils.numba_support import _numba_availible
from blobid.utils.reconstruction import normals, NORMALS_TYPE, _get_dominant_direction
from blobid._numba_support import numba_availible
from blobid.reconstruction.reconstruction import normals, NORMALS_TYPE, _get_dominant_direction


def test_reconstruction_CD():
Expand Down Expand Up @@ -71,7 +71,7 @@ def test_vof_type():
f = np.random.rand(3, 3, 3)

# float16 should give a warning
if _numba_availible:
if numba_availible:
with pytest.warns():
normals(f.astype(np.float16), 'WY')

Expand Down
10 changes: 5 additions & 5 deletions tests/test_boundaries.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import pytest
import numpy as np

from blobid.utils.ccl import stitch_boundaries
from blobid.utils.domain import _pad_array
from blobid.utils.labeling import LabelDatabase
from blobid.utils.reconstruction import normals
from blobid._domain import _pad_array
from blobid.labeling.database import LabelDatabase
from blobid.labeling.labeling import _stitch_boundaries
from blobid.reconstruction import normals


@pytest.fixture
Expand Down Expand Up @@ -122,7 +122,7 @@ def check_periodicity(labels, sets, periodicity):
for a, b in zip(labels[:, :, -1].flat, labels[:, :, -2].flat):
sets.merge(a, b)

sets = stitch_boundaries(labels, sets, per)
sets = _stitch_boundaries(labels, sets, per)

# make sure the size is right
assert labels.shape[0] == 5 + 2 * per[0]
Expand Down