diff --git a/spinstep/utils/__init__.py b/spinstep/utils/__init__.py index 2f41c50..052def5 100644 --- a/spinstep/utils/__init__.py +++ b/spinstep/utils/__init__.py @@ -24,11 +24,17 @@ "rotation_matrix_to_quaternion", "get_relative_spin", "get_unique_relative_spins", + "forward_vector_from_quaternion", + "direction_to_quaternion", + "angle_between_directions", ] from .array_backend import get_array_module from .quaternion_math import batch_quaternion_angle from .quaternion_utils import ( + angle_between_directions, + direction_to_quaternion, + forward_vector_from_quaternion, get_relative_spin, get_unique_relative_spins, is_within_angle_threshold, diff --git a/spinstep/utils/quaternion_utils.py b/spinstep/utils/quaternion_utils.py index 1d711c9..1d01aaa 100644 --- a/spinstep/utils/quaternion_utils.py +++ b/spinstep/utils/quaternion_utils.py @@ -16,6 +16,9 @@ "rotation_matrix_to_quaternion", "get_relative_spin", "get_unique_relative_spins", + "forward_vector_from_quaternion", + "direction_to_quaternion", + "angle_between_directions", ] from typing import List, Sequence @@ -106,6 +109,63 @@ def rotation_matrix_to_quaternion(m: ArrayLike) -> np.ndarray: return q / n if n > 1e-8 else np.array([0.0, 0.0, 0.0, 1.0]) +def forward_vector_from_quaternion(q: ArrayLike) -> np.ndarray: + """Extract the forward (look) direction from a quaternion. + + The forward direction is defined as ``[0, 0, -1]`` rotated by the + quaternion, following the convention where negative-Z is "forward". + + Args: + q: Quaternion ``[x, y, z, w]``. + + Returns: + Unit direction vector ``(3,)`` pointing forward. + """ + return R.from_quat(q).apply([0, 0, -1]) + + +def direction_to_quaternion(direction: ArrayLike) -> np.ndarray: + """Convert a 3D direction vector to an orientation quaternion. + + The returned quaternion represents the rotation that aligns the + default forward axis ``[0, 0, -1]`` with the given *direction*. + + Args: + direction: Target direction vector (does not need to be normalised). + + Returns: + Unit quaternion ``[x, y, z, w]``. + """ + d = np.asarray(direction, dtype=float) + norm = np.linalg.norm(d) + if norm < 1e-8: + return np.array([0.0, 0.0, 0.0, 1.0]) + d = d / norm + rot, _ = R.align_vectors([d], [[0, 0, -1]]) + return rot.as_quat() + + +def angle_between_directions(d1: ArrayLike, d2: ArrayLike) -> float: + """Compute the angular distance (radians) between two direction vectors. + + Args: + d1: First direction vector. + d2: Second direction vector. + + Returns: + Angle in radians in the range ``[0, π]``. + """ + v1 = np.asarray(d1, dtype=float) + v2 = np.asarray(d2, dtype=float) + n1 = np.linalg.norm(v1) + n2 = np.linalg.norm(v2) + if n1 < 1e-8 or n2 < 1e-8: + return 0.0 + cos_angle = np.dot(v1 / n1, v2 / n2) + cos_angle = np.clip(cos_angle, -1.0, 1.0) + return float(np.arccos(cos_angle)) + + def get_relative_spin(nf: object, nt: object) -> np.ndarray: """Return the relative quaternion rotation from node *nf* to node *nt*. diff --git a/tests/test_utils.py b/tests/test_utils.py index 650885a..6e6ce90 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -18,6 +18,9 @@ rotate_quaternion, rotation_matrix_to_quaternion, get_relative_spin, + forward_vector_from_quaternion, + direction_to_quaternion, + angle_between_directions, ) @@ -236,3 +239,89 @@ def test_returns_list_of_unit_quaternions(self): for q in spins: assert np.linalg.norm(q) == pytest.approx(1.0, abs=1e-6) assert q[3] >= 0 # canonical form (w >= 0) + + +# ===== forward_vector_from_quaternion tests ===== + + +class TestForwardVectorFromQuaternion: + def test_identity_forward(self): + """Identity quaternion forward is [0, 0, -1].""" + fwd = forward_vector_from_quaternion([0, 0, 0, 1]) + assert np.allclose(fwd, [0, 0, -1], atol=1e-6) + + def test_180_yaw(self): + """180° yaw flips forward to [0, 0, 1].""" + q = R.from_euler("y", 180, degrees=True).as_quat() + fwd = forward_vector_from_quaternion(q) + assert np.allclose(fwd, [0, 0, 1], atol=1e-6) + + def test_unit_length(self): + """Forward vector is always unit length.""" + q = R.random().as_quat() + fwd = forward_vector_from_quaternion(q) + assert np.linalg.norm(fwd) == pytest.approx(1.0, abs=1e-6) + + +# ===== direction_to_quaternion tests ===== + + +class TestDirectionToQuaternion: + def test_forward_direction(self): + """Direction [0, 0, -1] gives identity-like quaternion.""" + q = direction_to_quaternion([0, 0, -1]) + fwd = R.from_quat(q).apply([0, 0, -1]) + assert np.allclose(fwd, [0, 0, -1], atol=1e-6) + + def test_roundtrip(self): + """direction_to_quaternion → forward_vector_from_quaternion roundtrip.""" + direction = np.array([1.0, 2.0, -3.0]) + direction = direction / np.linalg.norm(direction) + q = direction_to_quaternion(direction) + fwd = forward_vector_from_quaternion(q) + assert np.allclose(fwd, direction, atol=1e-6) + + def test_unit_quaternion(self): + """Returned quaternion is a unit quaternion.""" + q = direction_to_quaternion([1, 0, 0]) + assert np.linalg.norm(q) == pytest.approx(1.0, abs=1e-6) + + def test_zero_vector(self): + """Zero vector returns identity quaternion.""" + q = direction_to_quaternion([0, 0, 0]) + assert np.allclose(q, [0, 0, 0, 1], atol=1e-6) + + +# ===== angle_between_directions tests ===== + + +class TestAngleBetweenDirections: + def test_same_direction(self): + """Angle between identical directions is zero.""" + d = [1, 0, 0] + assert angle_between_directions(d, d) == pytest.approx(0.0, abs=1e-7) + + def test_opposite_directions(self): + """Angle between opposite directions is π.""" + assert angle_between_directions( + [0, 0, 1], [0, 0, -1] + ) == pytest.approx(np.pi, abs=1e-6) + + def test_perpendicular_directions(self): + """Angle between perpendicular directions is π/2.""" + assert angle_between_directions( + [1, 0, 0], [0, 1, 0] + ) == pytest.approx(np.pi / 2, abs=1e-6) + + def test_unnormalized_inputs(self): + """Works with non-unit direction vectors.""" + assert angle_between_directions( + [3, 0, 0], [0, 4, 0] + ) == pytest.approx(np.pi / 2, abs=1e-6) + + def test_zero_vector(self): + """Zero vector returns 0.0.""" + assert angle_between_directions([0, 0, 0], [1, 0, 0]) == pytest.approx( + 0.0, abs=1e-7 + ) +