diff --git a/blobid/plic/central_difference.py b/blobid/plic/central_difference.py index 832a803..5414891 100644 --- a/blobid/plic/central_difference.py +++ b/blobid/plic/central_difference.py @@ -12,8 +12,8 @@ def get_normals(f: np.ndarray, normals_dtype: npt.DTypeLike) -> np.ndarray: n = np.empty((3, f.shape[0]-2, f.shape[1]-2, f.shape[2]-2), dtype=normals_dtype) - n[0] = np.sign(f[:-2, 1:-1, 1:-1] - f[2:, 1:-1, 1:-1]) - n[1] = np.sign(f[1:-1, :-2, 1:-1] - f[1:-1, 2:, 1:-1]) - n[2] = np.sign(f[1:-1, 1:-1, :-2] - f[1:-1, 1:-1, 2:]) + n[0] = (f[:-2, 1:-1, 1:-1] - f[2:, 1:-1, 1:-1])/2.0 + n[1] = (f[1:-1, :-2, 1:-1] - f[1:-1, 2:, 1:-1])/2.0 + n[2] = (f[1:-1, 1:-1, :-2] - f[1:-1, 1:-1, 2:])/2.0 return n diff --git a/blobid/plic/nivira.py b/blobid/plic/nivira.py index fe9f648..293a13c 100644 --- a/blobid/plic/nivira.py +++ b/blobid/plic/nivira.py @@ -12,6 +12,8 @@ import numpy as np import numpy.typing as npt +from . import central_difference + from .._numba_support import njit, numba_availible @@ -25,6 +27,10 @@ def get_normals(f: np.ndarray, normals_dtype: npt.DTypeLike) -> np.ndarray: warnings.warn(f"void_fraction type {f.dtype.name} not supported, using float32") f = f.astype(np.float32) + if numba_availible and (np.dtype(normals_dtype) not in (np.float32, np.float64)): + warnings.warn(f"normals_dtype {np.dtype(normals_dtype).name} not supported, using float32") + normals_dtype = np.float32 + # estimate dominant direction d_dom = _get_dominant_direction(f) @@ -39,10 +45,7 @@ def get_normals(f: np.ndarray, normals_dtype: npt.DTypeLike) -> np.ndarray: def _get_dominant_direction(f): # estimate the normal using central difference - n_approx = np.empty((3, f.shape[0]-2, f.shape[1]-2, f.shape[2]-2), dtype=f.dtype) - n_approx[0] = f[:-2, 1:-1, 1:-1] - f[2:, 1:-1, 1:-1] - n_approx[1] = f[1:-1, :-2, 1:-1] - f[1:-1, 2:, 1:-1] - n_approx[2] = f[1:-1, 1:-1, :-2] - f[1:-1, 1:-1, 2:] + n_approx = central_difference.get_normals(f, normals_dtype=f.dtype) # cells do not have a normal if empty, full, or undefined CD invalid = (f[1:-1, 1:-1, 1:-1] == 0) | (f[1:-1, 1:-1, 1:-1] == 1) | np.all(n_approx == 0, axis=0) @@ -63,7 +66,7 @@ def _normals_WY(f, d_dom, d, dtype) -> np.ndarray: block = np.swapaxes(f[i:i+3, j:j+3, k:k+3], 0, d) # use central difference - n[i, j, k] = np.sign(block[0, 1, 1] - block[2, 1, 1]) + n[i, j, k] = (block[0, 1, 1] - block[2, 1, 1])/2.0 else: # skip the dimension we don't care about @@ -74,15 +77,15 @@ def _normals_WY(f, d_dom, d, dtype) -> np.ndarray: height = np.sum(block, axis=1 if (d_dom[i, j, k] > d) else 0) # central difference of summed columns - n_tmp = (height[0] - height[2])/2 + n_tmp = (height[0] - height[2])/2.0 if abs(n_tmp) < 0.5: - n[i, j, k] = np.sign(n_tmp) + n[i, j, k] = n_tmp else: # for steep interface, use one-sided differences if n_tmp * f[i+1, j+1, k+1] >= 0: - n[i, j, k] = np.sign(height[1] - height[2]) + n[i, j, k] = (height[1] - height[2])/1.0 else: - n[i, j, k] = np.sign(height[0] - height[1]) + n[i, j, k] = (height[0] - height[1])/1.0 return n diff --git a/blobid/plic/plic.py b/blobid/plic/plic.py index 34be0f1..f568e03 100644 --- a/blobid/plic/plic.py +++ b/blobid/plic/plic.py @@ -2,42 +2,44 @@ Calculate the sign of the interface normal based on a 3x3(x3) stencil """ import numpy as np +import numpy.typing as npt from .central_difference import get_normals as get_normals_CD from .nivira import get_normals as get_normals_WY -NORMALS_TYPE = np.int8 SUPPORTED_METHODS = ['CD', 'WY'] """Supported methods for calculating interface normals""" -def get_normals(void_fraction: np.ndarray, normals_method: str = 'CD') -> np.ndarray: +def get_normals( + void_fraction: np.ndarray, + normals_method: str = 'CD', + normals_type: npt.DTypeLike | None = None +) -> np.ndarray: r""" - Calculate the sign of the interface normals using method specified by `normals_method` + Calculate the the interface normals :math:`\vec{n}=\vec{n}(f)` using the method specified by `normals_method` Parameters ---------- void_fraction : ndarray[ni+2, nj+2, nk+2] - The void fractions for each grid cell. + The void fractions :math:`f` for each grid cell. normals_method: {'CD', 'WY'}, optional Method used to calculate interface normals. Defaults to 'CD'. + normals_type: DTypeLike + Determines the data-type of `normals`. Defaults to same type as `void_fraction` Returns ------- normals : ndarray[3, ni, nj, nk] - An integer array which returns the sign of the interface normal. - If the interface in a grid cell is undefined, :math:`\langle 0, 0, 0\rangle` will be returned. - - Notes - ----- - - The signs of the interface normal :math:`\langle n_x, n_y, n_z\rangle` are calculated using the method set by - `normals_method`: + `normals[:,i,j,k]` contains :math:`\vec{n}=\langle n_x, n_y, n_z\rangle` for cell `[i,j,k]`. + :math:`\vec{n}=\langle 0, 0, 0\rangle` indicates the interface in the cell is undefined. + Methods + ------- - If `normals_method` is `CD`, central differencing is used. For example .. math:: - \mathrm{sign}(n_x) = - \mathrm{sign}(f_{i+1} - f_{i-1}) + n_x = - \frac{f_{i+1} - f_{i-1}}{2} - If `normals_method` is `WY`, normals are calculated using the No Inversion VOF Interface Reconstruction Algorithm (NIVIRA) described by Weymouth and Yue.[1]_ `WY` is more accurate than `CD` but can be slower @@ -51,14 +53,17 @@ def get_normals(void_fraction: np.ndarray, normals_method: str = 'CD') -> np.nda [10.1016/j.jcp.2009.12.018](https://doi.org/10.1016/j.jcp.2009.12.018) """ + if normals_type is None: + normals_type = void_fraction.dtype + # checks assert void_fraction.ndim == 3 assert all(dim > 2 for dim in void_fraction.shape) match normals_method: case 'CD': - return get_normals_CD(void_fraction, NORMALS_TYPE) + return get_normals_CD(void_fraction, normals_type) case 'WY': - return get_normals_WY(void_fraction, NORMALS_TYPE) + return get_normals_WY(void_fraction, normals_type) case _: raise ValueError(f"normals_method '{normals_method}' is not supported") diff --git a/tests/plic/test_get_normals.py b/tests/plic/test_get_normals.py index 9b80407..b5829f1 100644 --- a/tests/plic/test_get_normals.py +++ b/tests/plic/test_get_normals.py @@ -32,8 +32,8 @@ def assert_undefined(norm): def test_end_to_end(fs_vof): expected_result = [ - ['CD', 'c90626929711b744621692f3561700a1'], - ['WY', '9ac131e45e11e9f9457327c01ad3fb8e'] + ['CD', '599820a35d442746d95e8cd2f06fbc61'], + ['WY', '2aeb743728db4ecefda7ce7527d836b4'] ] for method, signature in expected_result: diff --git a/tests/plic/test_nivira.py b/tests/plic/test_nivira.py index 8616ae0..d2274b8 100644 --- a/tests/plic/test_nivira.py +++ b/tests/plic/test_nivira.py @@ -5,7 +5,6 @@ import blobid.plic.nivira as nivira from blobid._numba_support import numba_availible -from blobid.plic.plic import NORMALS_TYPE def test_WY_Fig2(): @@ -19,9 +18,9 @@ def test_WY_Fig2(): assert nivira._get_dominant_direction(f)[0, 0, 0] == 0 - n = nivira.get_normals(f, NORMALS_TYPE).squeeze() - assert n[0] > 0 - assert n[1] > 0 + n = nivira.get_normals(f, f.dtype).squeeze() + assert n[0] == pytest.approx((1.0-0.15)/2.0) + assert n[1] == pytest.approx(2.05-1.25) assert n[2] == 0 @@ -31,10 +30,13 @@ def test_vof_type(): # float16 should give a warning if numba_availible: with pytest.warns(): - nivira.get_normals(f.astype(np.float16), NORMALS_TYPE) + nivira.get_normals(f.astype(np.float16), np.float32) + + with pytest.warns(): + nivira.get_normals(f.astype(np.float32), np.float16) # float32 and float64 should work without warning with warnings.catch_warnings(): warnings.simplefilter("error") - nivira.get_normals(f.astype(np.float32), NORMALS_TYPE) - nivira.get_normals(f.astype(np.float64), NORMALS_TYPE) + nivira.get_normals(f.astype(np.float32), np.float32) + nivira.get_normals(f.astype(np.float64), np.float32)