Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions spinstep/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,17 @@
"rotation_matrix_to_quaternion",
"get_relative_spin",
"get_unique_relative_spins",
"forward_vector_from_quaternion",
"direction_to_quaternion",
"angle_between_directions",
]

from .array_backend import get_array_module
from .quaternion_math import batch_quaternion_angle
from .quaternion_utils import (
angle_between_directions,
direction_to_quaternion,
forward_vector_from_quaternion,
get_relative_spin,
get_unique_relative_spins,
is_within_angle_threshold,
Expand Down
60 changes: 60 additions & 0 deletions spinstep/utils/quaternion_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
"rotation_matrix_to_quaternion",
"get_relative_spin",
"get_unique_relative_spins",
"forward_vector_from_quaternion",
"direction_to_quaternion",
"angle_between_directions",
]

from typing import List, Sequence
Expand Down Expand Up @@ -106,6 +109,63 @@ def rotation_matrix_to_quaternion(m: ArrayLike) -> np.ndarray:
return q / n if n > 1e-8 else np.array([0.0, 0.0, 0.0, 1.0])


def forward_vector_from_quaternion(q: ArrayLike) -> np.ndarray:
"""Extract the forward (look) direction from a quaternion.

The forward direction is defined as ``[0, 0, -1]`` rotated by the
quaternion, following the convention where negative-Z is "forward".

Args:
q: Quaternion ``[x, y, z, w]``.

Returns:
Unit direction vector ``(3,)`` pointing forward.
"""
return R.from_quat(q).apply([0, 0, -1])


def direction_to_quaternion(direction: ArrayLike) -> np.ndarray:
"""Convert a 3D direction vector to an orientation quaternion.

The returned quaternion represents the rotation that aligns the
default forward axis ``[0, 0, -1]`` with the given *direction*.

Args:
direction: Target direction vector (does not need to be normalised).

Returns:
Unit quaternion ``[x, y, z, w]``.
"""
d = np.asarray(direction, dtype=float)
norm = np.linalg.norm(d)
if norm < 1e-8:
return np.array([0.0, 0.0, 0.0, 1.0])
d = d / norm
rot, _ = R.align_vectors([d], [[0, 0, -1]])
return rot.as_quat()


def angle_between_directions(d1: ArrayLike, d2: ArrayLike) -> float:
"""Compute the angular distance (radians) between two direction vectors.

Args:
d1: First direction vector.
d2: Second direction vector.

Returns:
Angle in radians in the range ``[0, π]``.
"""
v1 = np.asarray(d1, dtype=float)
v2 = np.asarray(d2, dtype=float)
n1 = np.linalg.norm(v1)
n2 = np.linalg.norm(v2)
if n1 < 1e-8 or n2 < 1e-8:
return 0.0
cos_angle = np.dot(v1 / n1, v2 / n2)
cos_angle = np.clip(cos_angle, -1.0, 1.0)
return float(np.arccos(cos_angle))


def get_relative_spin(nf: object, nt: object) -> np.ndarray:
"""Return the relative quaternion rotation from node *nf* to node *nt*.

Expand Down
89 changes: 89 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
rotate_quaternion,
rotation_matrix_to_quaternion,
get_relative_spin,
forward_vector_from_quaternion,
direction_to_quaternion,
angle_between_directions,
)


Expand Down Expand Up @@ -236,3 +239,89 @@ def test_returns_list_of_unit_quaternions(self):
for q in spins:
assert np.linalg.norm(q) == pytest.approx(1.0, abs=1e-6)
assert q[3] >= 0 # canonical form (w >= 0)


# ===== forward_vector_from_quaternion tests =====


class TestForwardVectorFromQuaternion:
def test_identity_forward(self):
"""Identity quaternion forward is [0, 0, -1]."""
fwd = forward_vector_from_quaternion([0, 0, 0, 1])
assert np.allclose(fwd, [0, 0, -1], atol=1e-6)

def test_180_yaw(self):
"""180° yaw flips forward to [0, 0, 1]."""
q = R.from_euler("y", 180, degrees=True).as_quat()
fwd = forward_vector_from_quaternion(q)
assert np.allclose(fwd, [0, 0, 1], atol=1e-6)

def test_unit_length(self):
"""Forward vector is always unit length."""
q = R.random().as_quat()
fwd = forward_vector_from_quaternion(q)
assert np.linalg.norm(fwd) == pytest.approx(1.0, abs=1e-6)


# ===== direction_to_quaternion tests =====


class TestDirectionToQuaternion:
def test_forward_direction(self):
"""Direction [0, 0, -1] gives identity-like quaternion."""
q = direction_to_quaternion([0, 0, -1])
fwd = R.from_quat(q).apply([0, 0, -1])
assert np.allclose(fwd, [0, 0, -1], atol=1e-6)

def test_roundtrip(self):
"""direction_to_quaternion → forward_vector_from_quaternion roundtrip."""
direction = np.array([1.0, 2.0, -3.0])
direction = direction / np.linalg.norm(direction)
q = direction_to_quaternion(direction)
fwd = forward_vector_from_quaternion(q)
assert np.allclose(fwd, direction, atol=1e-6)

def test_unit_quaternion(self):
"""Returned quaternion is a unit quaternion."""
q = direction_to_quaternion([1, 0, 0])
assert np.linalg.norm(q) == pytest.approx(1.0, abs=1e-6)

def test_zero_vector(self):
"""Zero vector returns identity quaternion."""
q = direction_to_quaternion([0, 0, 0])
assert np.allclose(q, [0, 0, 0, 1], atol=1e-6)


# ===== angle_between_directions tests =====


class TestAngleBetweenDirections:
def test_same_direction(self):
"""Angle between identical directions is zero."""
d = [1, 0, 0]
assert angle_between_directions(d, d) == pytest.approx(0.0, abs=1e-7)

def test_opposite_directions(self):
"""Angle between opposite directions is π."""
assert angle_between_directions(
[0, 0, 1], [0, 0, -1]
) == pytest.approx(np.pi, abs=1e-6)

def test_perpendicular_directions(self):
"""Angle between perpendicular directions is π/2."""
assert angle_between_directions(
[1, 0, 0], [0, 1, 0]
) == pytest.approx(np.pi / 2, abs=1e-6)

def test_unnormalized_inputs(self):
"""Works with non-unit direction vectors."""
assert angle_between_directions(
[3, 0, 0], [0, 4, 0]
) == pytest.approx(np.pi / 2, abs=1e-6)

def test_zero_vector(self):
"""Zero vector returns 0.0."""
assert angle_between_directions([0, 0, 0], [1, 0, 0]) == pytest.approx(
0.0, abs=1e-7
)

Loading