diff --git a/diffractio/scalar_fields_XZ.py b/diffractio/scalar_fields_XZ.py index 48f48e1c..44c9a8f6 100755 --- a/diffractio/scalar_fields_XZ.py +++ b/diffractio/scalar_fields_XZ.py @@ -94,11 +94,19 @@ kernelRS, kernelRSinverse) from .scalar_masks_X import Scalar_mask_X from .scalar_sources_X import Scalar_source_X +from numba import njit copyreg.pickle(types.MethodType, _pickle_method, _unpickle_method) percentage_intensity_config = CONF_DRAWING['percentage_intensity'] +@njit(parallel=True, nogil=True) +def _rotate_numba(X, Z, angle, x0, z0): + """Numba-accelerated rotation computation""" + Xrot = x0 + (X - x0) * np.cos(angle) + (Z - z0) * np.sin(angle) + Zrot = z0 - (X - x0) * np.sin(angle) + (Z - z0) * np.cos(angle) + return Xrot, Zrot + class Scalar_field_XZ(): """Class for working with XZ scalar fields. @@ -300,11 +308,7 @@ def __rotate__(self, angle: float, position=None): # Definicion de la rotation x0, z0 = position - Xrot = x0 + (self.X - x0) * np.cos(angle) + (self.Z - - z0) * np.sin(angle) - Zrot = z0 - (self.X - x0) * np.sin(angle) + (self.Z - - z0) * np.cos(angle) - return Xrot, Zrot + return _rotate_numba(self.X, self.Z, angle, x0, z0) def size(self, verbose: bool = False):