diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index fca6f8f..1930765 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -2,9 +2,9 @@ name: SpinStep CI on: push: - branches: [main, dev, 'features/cuda'] + branches: [main, dev, 'features/cuda', 'feature/*'] pull_request: - branches: [main, dev] + branches: [main, dev, 'feature/*'] permissions: contents: read diff --git a/CHANGELOG.md b/CHANGELOG.md index a5b1689..f4719ce 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,55 @@ and uses [Semantic Versioning](https://semver.org/). --- +## [0.5.0a0] – 2026-04-10 + +### Added +- **Control subsystem** (`spinstep.control`): + - `OrientationState` — observer-centered spherical state (quaternion + distance) + - `ControlCommand` — angular + radial velocity commands + - `ProportionalOrientationController` — P-controller with rate limiting + - `PIDOrientationController` — PID controller with anti-windup + - `OrientationTrajectory`, `TrajectoryInterpolator`, `TrajectoryController` — + waypoint-based trajectory tracking with SLERP interpolation + - `integrate_state()`, `compute_orientation_error()` — state utilities +- **Math library** (`spinstep.math`): + - `core` — `quaternion_multiply`, `quaternion_conjugate`, `quaternion_normalize`, + `quaternion_inverse` + - `interpolation` — `slerp` (spherical linear), `squad` (spherical cubic) + - `geometry` — `quaternion_distance`, `rotate_quaternion`, + `forward_vector_from_quaternion`, `direction_to_quaternion`, + `angle_between_directions`, `is_within_angle_threshold` + - `conversions` — `quaternion_from_euler`, `rotation_matrix_to_quaternion`, + `quaternion_from_rotvec`, `quaternion_to_rotvec` + - `analysis` — `batch_quaternion_angle`, `angular_velocity_from_quaternions`, + `get_relative_spin`, `get_unique_relative_spins`, `NodeProtocol` + - `constraints` — `clamp_rotation_angle` +- `NodeProtocol` — `typing.Protocol` for any object with `.orientation` attribute, + replacing `object` type hints in `get_relative_spin()` and + `get_unique_relative_spins()` +- `Node.add_child()` method for ergonomic tree building +- Comprehensive API stability tests (`test_api.py`) — 97 parametrized tests + covering all subpackage exports and backward compatibility +- `Typing :: Typed` classifier in `pyproject.toml` +- `Changelog` and `Documentation` project URLs +- `[[tool.mypy.overrides]]` for `scipy.*`, `sklearn.*`, `healpy.*`, `cupy.*` +- CI workflow now triggers on `feature/*` branches + +### Changed +- **Traversal classes moved** to `spinstep.traversal` subpackage + (`node.py`, `continuous.py`, `discrete.py`, `discrete_iterator.py`). + Top-level imports via `from spinstep import Node` remain backward-compatible. +- **`spinstep.utils` is now a backward-compatible re-export layer**. + All 13 quaternion functions now delegate to `spinstep.math` — the single + source of truth. Direct `from spinstep.utils import …` continues to work. +- Version bumped to `0.5.0a0` + +### Deprecated +- Direct imports from `spinstep.utils.quaternion_utils` and + `spinstep.utils.quaternion_math`. Use `spinstep.math` instead. + +--- + ## [0.3.0a0] – 2026-03-26 ### Added diff --git a/pyproject.toml b/pyproject.toml index 782bf98..b972696 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,8 +4,8 @@ build-backend = "setuptools.build_meta" [project] name = "spinstep" -version = "0.3.1a1" -description = "Quaternion-based tree traversal for orientation-aware structures." +version = "0.5.0a0" +description = "Quaternion-based orientation control and traversal in observer-centered spherical coordinates." authors = [{ name = "Eraldo Marques", email = "eraldo.bernardo@gmail.com" }] readme = "README.md" license = "MIT" @@ -20,6 +20,7 @@ classifiers = [ "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", "Topic :: Scientific/Engineering :: Mathematics", + "Typing :: Typed", ] dependencies = [ "numpy>=1.22", @@ -43,11 +44,16 @@ dev = [ [project.urls] Repository = "https://github.com/VoxleOne/SpinStep" +Changelog = "https://github.com/VoxleOne/SpinStep/blob/main/CHANGELOG.md" +Documentation = "https://github.com/VoxleOne/SpinStep/tree/main/docs" [tool.setuptools.packages.find] include = ["spinstep*"] exclude = ["benchmark*", "demos*", "examples*", "tests*", "docs*"] +[tool.setuptools.package-data] +spinstep = ["py.typed"] + [tool.pytest.ini_options] testpaths = ["tests"] @@ -57,5 +63,14 @@ line-length = 88 [tool.mypy] strict = true +[[tool.mypy.overrides]] +module = [ + "scipy.*", + "sklearn.*", + "healpy.*", + "cupy.*", +] +ignore_missing_imports = true + [tool.ruff] line-length = 88 diff --git a/spinstep/__init__.py b/spinstep/__init__.py index 7b2d1e6..2936b73 100644 --- a/spinstep/__init__.py +++ b/spinstep/__init__.py @@ -2,32 +2,78 @@ # Author: Eraldo B. Marques — Created: 2025-05-14 # See LICENSE.txt for full terms. This header must be retained in redistributions. -"""SpinStep: A quaternion-driven traversal framework. +"""SpinStep: Quaternion-based orientation control and traversal. -Provides quaternion-based tree traversal for orientation-aware structures, -supporting both continuous and discrete rotation stepping. +SpinStep uses an observer-centered spherical model where every guided +vehicle (node) is located by a quaternion (direction from the observer) +and a radial distance (layer). -Example usage:: +Control layer (primary API):: - from spinstep import Node, QuaternionDepthIterator + from spinstep import ( + OrientationState, ControlCommand, + ProportionalOrientationController, PIDOrientationController, + OrientationTrajectory, TrajectoryController, + integrate_state, compute_orientation_error, + slerp, + ) - root = Node("root", [0, 0, 0, 1]) - child = Node("child", [0, 0, 0.1, 0.995]) - root.children.append(child) +Traversal layer (original tree-walking API):: - step = [0, 0, 0.05, 0.9987] # small rotation about Z - for node in QuaternionDepthIterator(root, step): - print(node.name) + from spinstep.traversal import ( + Node, QuaternionDepthIterator, + DiscreteOrientationSet, DiscreteQuaternionIterator, + ) + +Math layer:: + + from spinstep.math import quaternion_multiply, quaternion_distance, slerp """ -__version__ = "0.3.0a0" +__version__ = "0.5.0a0" + +# --- control layer (primary API) --- +from .control.state import ( + ControlCommand, + OrientationState, + compute_orientation_error, + integrate_state, +) +from .control.controllers import ( + OrientationController, + PIDOrientationController, + ProportionalOrientationController, +) +from .control.trajectory import ( + OrientationTrajectory, + TrajectoryController, + TrajectoryInterpolator, +) + +# --- key math utilities at top level --- +from .math.interpolation import slerp -from .node import Node -from .traversal import QuaternionDepthIterator -from .discrete import DiscreteOrientationSet -from .discrete_iterator import DiscreteQuaternionIterator +# --- backward-compatible traversal re-exports --- +from .traversal.node import Node +from .traversal.continuous import QuaternionDepthIterator +from .traversal.discrete import DiscreteOrientationSet +from .traversal.discrete_iterator import DiscreteQuaternionIterator __all__ = [ + # control + "OrientationState", + "ControlCommand", + "integrate_state", + "compute_orientation_error", + "OrientationController", + "ProportionalOrientationController", + "PIDOrientationController", + "OrientationTrajectory", + "TrajectoryInterpolator", + "TrajectoryController", + # math + "slerp", + # traversal (backward compat) "Node", "QuaternionDepthIterator", "DiscreteOrientationSet", diff --git a/spinstep/control/__init__.py b/spinstep/control/__init__.py new file mode 100644 index 0000000..57ee8a7 --- /dev/null +++ b/spinstep/control/__init__.py @@ -0,0 +1,51 @@ +# control/__init__.py — MIT License +# Author: Eraldo B. Marques — Created: 2025-05-14 +# See LICENSE.txt for full terms. This header must be retained in redistributions. + +"""Observer-centered orientation control: state, controllers, and trajectories. + +The SpinStep control model places the observer at the origin. Every +guided vehicle (node) is located by a quaternion (direction from observer) +and a radial distance (layer). Controllers produce commands with both +angular and radial velocity components. + +Sub-modules: + +- :mod:`~.state` — :class:`OrientationState`, :class:`ControlCommand`, + integration, error computation +- :mod:`~.controllers` — proportional and PID orientation controllers +- :mod:`~.trajectory` — waypoint trajectories and trajectory tracking +""" + +__all__ = [ + # state + "OrientationState", + "ControlCommand", + "integrate_state", + "compute_orientation_error", + # controllers + "OrientationController", + "ProportionalOrientationController", + "PIDOrientationController", + # trajectory + "OrientationTrajectory", + "TrajectoryInterpolator", + "TrajectoryController", +] + +from .state import ( + ControlCommand, + OrientationState, + compute_orientation_error, + integrate_state, +) +from .controllers import ( + OrientationController, + PIDOrientationController, + ProportionalOrientationController, +) +from .trajectory import ( + OrientationTrajectory, + TrajectoryController, + TrajectoryInterpolator, +) diff --git a/spinstep/control/controllers.py b/spinstep/control/controllers.py new file mode 100644 index 0000000..b1327dd --- /dev/null +++ b/spinstep/control/controllers.py @@ -0,0 +1,378 @@ +# control/controllers.py — MIT License +# Author: Eraldo B. Marques — Created: 2025-05-14 +# See LICENSE.txt for full terms. This header must be retained in redistributions. + +"""Orientation controllers: proportional and PID with rate limiting. + +All controllers operate in the observer-centered spherical model and +produce a :class:`~.state.ControlCommand` containing both angular +velocity and radial velocity components. +""" + +from __future__ import annotations + +__all__ = [ + "OrientationController", + "ProportionalOrientationController", + "PIDOrientationController", +] + +from abc import ABC, abstractmethod +from typing import Optional + +import numpy as np +from numpy.typing import ArrayLike + +from .state import ControlCommand, compute_orientation_error + + +class OrientationController(ABC): + """Abstract base class for orientation controllers. + + All controllers share the :meth:`update` interface, which takes current + and target poses (quaternion + distance) plus a time step, and returns + a :class:`ControlCommand` with angular and radial velocity components. + + Args: + max_angular_velocity: Maximum angular velocity magnitude + (rad/s). ``None`` means unlimited. + max_angular_acceleration: Maximum angular acceleration magnitude + (rad/s²). ``None`` means unlimited. + max_radial_velocity: Maximum radial speed (units/s). + ``None`` means unlimited. + max_radial_acceleration: Maximum radial acceleration (units/s²). + ``None`` means unlimited. + """ + + def __init__( + self, + max_angular_velocity: Optional[float] = None, + max_angular_acceleration: Optional[float] = None, + max_radial_velocity: Optional[float] = None, + max_radial_acceleration: Optional[float] = None, + ) -> None: + self.max_angular_velocity = max_angular_velocity + self.max_angular_acceleration = max_angular_acceleration + self.max_radial_velocity = max_radial_velocity + self.max_radial_acceleration = max_radial_acceleration + self._prev_angular_cmd: Optional[np.ndarray] = None + self._prev_radial_cmd: Optional[float] = None + + @abstractmethod + def compute_raw_command( + self, + current_q: ArrayLike, + target_q: ArrayLike, + dt: float, + current_distance: float = 0.0, + target_distance: float = 0.0, + ) -> ControlCommand: + """Compute the raw (unclamped) control command. + + Subclasses must implement this. + + Args: + current_q: Current orientation ``[x, y, z, w]``. + target_q: Target orientation ``[x, y, z, w]``. + dt: Time step in seconds. + current_distance: Current radial distance from observer. + target_distance: Target radial distance from observer. + + Returns: + Raw :class:`ControlCommand`. + """ + ... + + def update( + self, + current_q: ArrayLike, + target_q: ArrayLike, + dt: float, + current_distance: float = 0.0, + target_distance: float = 0.0, + ) -> ControlCommand: + """Compute a rate-limited control command. + + Calls :meth:`compute_raw_command` and applies velocity and + acceleration limits to both angular and radial components. + + Args: + current_q: Current orientation ``[x, y, z, w]``. + target_q: Target orientation ``[x, y, z, w]``. + dt: Time step in seconds. Must be positive. + current_distance: Current radial distance from observer. + target_distance: Target radial distance from observer. + + Returns: + Rate-limited :class:`ControlCommand`. + + Raises: + ValueError: If *dt* is not positive. + """ + if dt <= 0: + raise ValueError(f"dt must be positive, got {dt}") + + raw = self.compute_raw_command( + current_q, target_q, dt, current_distance, target_distance + ) + angular = raw.angular_velocity.copy() + radial = raw.radial_velocity + + # --- angular velocity limit --- + if self.max_angular_velocity is not None: + speed = np.linalg.norm(angular) + if speed > self.max_angular_velocity: + angular = angular * (self.max_angular_velocity / speed) + + # --- angular acceleration limit --- + if ( + self.max_angular_acceleration is not None + and self._prev_angular_cmd is not None + ): + delta = angular - self._prev_angular_cmd + delta_norm = np.linalg.norm(delta) + if delta_norm / dt > self.max_angular_acceleration: + max_delta = ( + delta / delta_norm * self.max_angular_acceleration * dt + ) + angular = self._prev_angular_cmd + max_delta + + # --- radial velocity limit --- + if self.max_radial_velocity is not None: + if abs(radial) > self.max_radial_velocity: + radial = np.sign(radial) * self.max_radial_velocity + + # --- radial acceleration limit --- + if ( + self.max_radial_acceleration is not None + and self._prev_radial_cmd is not None + ): + delta_r = radial - self._prev_radial_cmd + if abs(delta_r) / dt > self.max_radial_acceleration: + max_delta_r = np.sign(delta_r) * self.max_radial_acceleration * dt + radial = self._prev_radial_cmd + max_delta_r + + self._prev_angular_cmd = angular.copy() + self._prev_radial_cmd = float(radial) + return ControlCommand(angular_velocity=angular, radial_velocity=radial) + + def reset(self) -> None: + """Reset the controller's internal state.""" + self._prev_angular_cmd = None + self._prev_radial_cmd = None + + +class ProportionalOrientationController(OrientationController): + """Proportional (P) orientation controller. + + Computes angular velocity as ``kp × angular_error`` and radial + velocity as ``kp_radial × radial_error``. + + Args: + kp: Proportional gain for angular error. Defaults to ``1.0``. + kp_radial: Proportional gain for radial error. Defaults to ``1.0``. + max_angular_velocity: Maximum angular velocity (rad/s). + max_angular_acceleration: Maximum angular acceleration (rad/s²). + max_radial_velocity: Maximum radial speed (units/s). + max_radial_acceleration: Maximum radial acceleration (units/s²). + + Example:: + + from spinstep.control import ProportionalOrientationController + + ctrl = ProportionalOrientationController(kp=2.0, kp_radial=1.5) + cmd = ctrl.update( + [0, 0, 0, 1], [0, 0, 0.383, 0.924], dt=0.01, + current_distance=3.0, target_distance=5.0, + ) + print(cmd.angular_velocity, cmd.radial_velocity) + """ + + def __init__( + self, + kp: float = 1.0, + kp_radial: float = 1.0, + max_angular_velocity: Optional[float] = None, + max_angular_acceleration: Optional[float] = None, + max_radial_velocity: Optional[float] = None, + max_radial_acceleration: Optional[float] = None, + ) -> None: + super().__init__( + max_angular_velocity, + max_angular_acceleration, + max_radial_velocity, + max_radial_acceleration, + ) + self.kp = kp + self.kp_radial = kp_radial + + def compute_raw_command( + self, + current_q: ArrayLike, + target_q: ArrayLike, + dt: float, + current_distance: float = 0.0, + target_distance: float = 0.0, + ) -> ControlCommand: + """Compute ``kp × error`` for angular and radial components. + + Args: + current_q: Current orientation ``[x, y, z, w]``. + target_q: Target orientation ``[x, y, z, w]``. + dt: Time step in seconds (unused by P controller). + current_distance: Current radial distance. + target_distance: Target radial distance. + + Returns: + :class:`ControlCommand`. + """ + ang_error, rad_error = compute_orientation_error( + current_q, target_q, current_distance, target_distance + ) + return ControlCommand( + angular_velocity=self.kp * ang_error, + radial_velocity=self.kp_radial * rad_error, + ) + + def reset(self) -> None: + """Reset the controller's internal state.""" + super().reset() + + +class PIDOrientationController(OrientationController): + """PID orientation controller with anti-windup. + + Computes ``kp × e + ki × ∫e·dt + kd × de/dt`` for both the angular + and radial error channels independently. Integral windup is prevented + by clamping integrated error magnitudes. + + Args: + kp: Proportional gain (angular). + ki: Integral gain (angular). + kd: Derivative gain (angular). + kp_radial: Proportional gain (radial). + ki_radial: Integral gain (radial). + kd_radial: Derivative gain (radial). + max_integral: Maximum angular integral magnitude. + max_integral_radial: Maximum radial integral magnitude. + max_angular_velocity: Maximum angular velocity (rad/s). + max_angular_acceleration: Maximum angular acceleration (rad/s²). + max_radial_velocity: Maximum radial speed (units/s). + max_radial_acceleration: Maximum radial acceleration (units/s²). + + Example:: + + from spinstep.control import PIDOrientationController + + ctrl = PIDOrientationController(kp=2.0, ki=0.1, kd=0.5, + kp_radial=1.0, ki_radial=0.05) + cmd = ctrl.update( + [0, 0, 0, 1], [0, 0, 0.383, 0.924], dt=0.01, + current_distance=3.0, target_distance=5.0, + ) + """ + + def __init__( + self, + kp: float = 1.0, + ki: float = 0.0, + kd: float = 0.0, + kp_radial: float = 1.0, + ki_radial: float = 0.0, + kd_radial: float = 0.0, + max_integral: float = 10.0, + max_integral_radial: float = 10.0, + max_angular_velocity: Optional[float] = None, + max_angular_acceleration: Optional[float] = None, + max_radial_velocity: Optional[float] = None, + max_radial_acceleration: Optional[float] = None, + ) -> None: + super().__init__( + max_angular_velocity, + max_angular_acceleration, + max_radial_velocity, + max_radial_acceleration, + ) + self.kp = kp + self.ki = ki + self.kd = kd + self.kp_radial = kp_radial + self.ki_radial = ki_radial + self.kd_radial = kd_radial + self.max_integral = max_integral + self.max_integral_radial = max_integral_radial + + self._ang_integral: np.ndarray = np.zeros(3) + self._rad_integral: float = 0.0 + self._prev_ang_error: Optional[np.ndarray] = None + self._prev_rad_error: Optional[float] = None + + def compute_raw_command( + self, + current_q: ArrayLike, + target_q: ArrayLike, + dt: float, + current_distance: float = 0.0, + target_distance: float = 0.0, + ) -> ControlCommand: + """Compute the PID command for angular and radial channels. + + Args: + current_q: Current orientation ``[x, y, z, w]``. + target_q: Target orientation ``[x, y, z, w]``. + dt: Time step in seconds. + current_distance: Current radial distance. + target_distance: Target radial distance. + + Returns: + :class:`ControlCommand`. + """ + ang_error, rad_error = compute_orientation_error( + current_q, target_q, current_distance, target_distance + ) + + # --- angular PID --- + p_ang = self.kp * ang_error + + self._ang_integral += ang_error * dt + mag = np.linalg.norm(self._ang_integral) + if mag > self.max_integral: + self._ang_integral *= self.max_integral / mag + i_ang = self.ki * self._ang_integral + + if self._prev_ang_error is not None: + d_ang = self.kd * (ang_error - self._prev_ang_error) / dt + else: + d_ang = np.zeros(3) + self._prev_ang_error = ang_error.copy() + + angular_cmd = p_ang + i_ang + d_ang + + # --- radial PID --- + p_rad = self.kp_radial * rad_error + + self._rad_integral += rad_error * dt + if abs(self._rad_integral) > self.max_integral_radial: + self._rad_integral = np.sign(self._rad_integral) * self.max_integral_radial + i_rad = self.ki_radial * self._rad_integral + + if self._prev_rad_error is not None: + d_rad = self.kd_radial * (rad_error - self._prev_rad_error) / dt + else: + d_rad = 0.0 + self._prev_rad_error = rad_error + + radial_cmd = p_rad + i_rad + d_rad + + return ControlCommand( + angular_velocity=angular_cmd, + radial_velocity=radial_cmd, + ) + + def reset(self) -> None: + """Reset the controller's internal state (integral, derivative, etc.).""" + super().reset() + self._ang_integral = np.zeros(3) + self._rad_integral = 0.0 + self._prev_ang_error = None + self._prev_rad_error = None diff --git a/spinstep/control/state.py b/spinstep/control/state.py new file mode 100644 index 0000000..4e06b41 --- /dev/null +++ b/spinstep/control/state.py @@ -0,0 +1,261 @@ +# control/state.py — MIT License +# Author: Eraldo B. Marques — Created: 2025-05-14 +# See LICENSE.txt for full terms. This header must be retained in redistributions. + +"""Observer-centered spherical state model. + +In the SpinStep control model the observer sits at the origin. +Every guided vehicle (node) is located by: + +- **orientation** — a unit quaternion giving the direction from the observer +- **distance** — the radial distance (layer) from the observer + +Velocities follow the same decomposition: angular velocity (rad/s) for +the tangential component and radial velocity (units/s) for the range +component. +""" + +from __future__ import annotations + +__all__ = [ + "OrientationState", + "ControlCommand", + "integrate_state", + "compute_orientation_error", +] + +from dataclasses import dataclass, field + +import numpy as np +from numpy.typing import ArrayLike +from scipy.spatial.transform import Rotation as R + +from ..math.core import quaternion_multiply, quaternion_normalize + + +@dataclass +class OrientationState: + """Observer-centered state: direction, distance, velocities, timestamp. + + The state describes a guided vehicle (node) in the observer's spherical + frame. The quaternion gives the direction from the observer; the + distance gives the radial layer. + + All quaternions use ``[x, y, z, w]`` convention. + + Args: + quaternion: Unit quaternion ``[x, y, z, w]`` — direction from observer. + distance: Radial distance from observer. Defaults to ``0.0``. + angular_velocity: Angular velocity ``[ωx, ωy, ωz]`` in rad/s. + radial_velocity: Radial velocity (units/s). Positive = moving away + from observer. + timestamp: Time in seconds. + + Attributes: + quaternion: Normalised quaternion ``(4,)``. + distance: Radial distance (≥ 0). + angular_velocity: Angular velocity ``(3,)``. + radial_velocity: Radial velocity scalar. + timestamp: Timestamp in seconds. + + Example:: + + from spinstep.control import OrientationState + + # A vehicle at distance 5.0 looking along +Z + state = OrientationState([0, 0, 0, 1], distance=5.0) + """ + + quaternion: np.ndarray = field( + default_factory=lambda: np.array([0.0, 0.0, 0.0, 1.0]) + ) + distance: float = 0.0 + angular_velocity: np.ndarray = field(default_factory=lambda: np.zeros(3)) + radial_velocity: float = 0.0 + timestamp: float = 0.0 + + def __init__( + self, + quaternion: ArrayLike = (0.0, 0.0, 0.0, 1.0), + distance: float = 0.0, + angular_velocity: ArrayLike = (0.0, 0.0, 0.0), + radial_velocity: float = 0.0, + timestamp: float = 0.0, + ) -> None: + q = np.asarray(quaternion, dtype=float) + if q.shape != (4,): + raise ValueError( + f"quaternion must have shape (4,), got {q.shape}" + ) + norm = np.linalg.norm(q) + if norm < 1e-8: + raise ValueError("quaternion must be non-zero") + self.quaternion = q / norm + + if distance < 0: + raise ValueError(f"distance must be non-negative, got {distance}") + self.distance = float(distance) + + omega = np.asarray(angular_velocity, dtype=float) + if omega.shape != (3,): + raise ValueError( + f"angular_velocity must have shape (3,), got {omega.shape}" + ) + self.angular_velocity = omega + self.radial_velocity = float(radial_velocity) + self.timestamp = float(timestamp) + + def __repr__(self) -> str: + return ( + f"OrientationState(" + f"q={self.quaternion.tolist()}, " + f"d={self.distance}, " + f"ω={self.angular_velocity.tolist()}, " + f"ṙ={self.radial_velocity}, " + f"t={self.timestamp})" + ) + + +@dataclass +class ControlCommand: + """Command output from a controller: angular + radial velocity. + + Separates the tangential (angular) and radial components of the + velocity command so they can be applied independently to actuators. + + Args: + angular_velocity: Desired angular velocity ``[ωx, ωy, ωz]`` in + rad/s. + radial_velocity: Desired radial velocity in units/s. Positive + means moving away from the observer. + + Example:: + + from spinstep.control.state import ControlCommand + + cmd = ControlCommand(angular_velocity=[0, 0, 1.0], radial_velocity=0.5) + """ + + angular_velocity: np.ndarray = field(default_factory=lambda: np.zeros(3)) + radial_velocity: float = 0.0 + + def __init__( + self, + angular_velocity: ArrayLike = (0.0, 0.0, 0.0), + radial_velocity: float = 0.0, + ) -> None: + omega = np.asarray(angular_velocity, dtype=float) + if omega.shape != (3,): + raise ValueError( + f"angular_velocity must have shape (3,), got {omega.shape}" + ) + self.angular_velocity = omega + self.radial_velocity = float(radial_velocity) + + def __repr__(self) -> str: + return ( + f"ControlCommand(" + f"ω={self.angular_velocity.tolist()}, " + f"ṙ={self.radial_velocity})" + ) + + +def integrate_state(state: OrientationState, dt: float) -> OrientationState: + """Integrate the full spherical state forward by *dt* seconds. + + Orientation is integrated via the exponential map: + ``q(t+dt) = q(t) * exp(ω · dt / 2)``. + Distance is integrated linearly: + ``d(t+dt) = max(0, d(t) + ṙ · dt)``. + + Args: + state: Current state. + dt: Time step in seconds. Must be positive. + + Returns: + New :class:`OrientationState` with updated quaternion, distance, + and timestamp. Velocities are carried forward unchanged. + + Raises: + ValueError: If *dt* is not positive. + + Example:: + + from spinstep.control import OrientationState, integrate_state + + state = OrientationState([0, 0, 0, 1], distance=5.0, + angular_velocity=[0, 0, 1.0], + radial_velocity=0.5) + new = integrate_state(state, dt=0.01) + # new.distance ≈ 5.005 + """ + if dt <= 0: + raise ValueError(f"dt must be positive, got {dt}") + + # --- angular integration --- + omega = state.angular_velocity + angle = np.linalg.norm(omega) + + if angle < 1e-10: + new_q = state.quaternion.copy() + else: + half_angle = angle * dt / 2.0 + axis = omega / angle + delta_q = np.array([ + *(axis * np.sin(half_angle)), + np.cos(half_angle), + ]) + new_q = quaternion_multiply(state.quaternion, delta_q) + new_q = quaternion_normalize(new_q) + + # --- radial integration --- + new_distance = max(0.0, state.distance + state.radial_velocity * dt) + + return OrientationState( + quaternion=new_q, + distance=new_distance, + angular_velocity=state.angular_velocity.copy(), + radial_velocity=state.radial_velocity, + timestamp=state.timestamp + dt, + ) + + +def compute_orientation_error( + current_q: ArrayLike, + target_q: ArrayLike, + current_distance: float = 0.0, + target_distance: float = 0.0, +) -> tuple[np.ndarray, float]: + """Compute the full spherical error: angular + radial. + + The angular error is expressed in the body frame of *current_q* as a + rotation vector (axis × angle, in radians). + + The radial error is ``target_distance − current_distance`` (positive + means the target is farther from the observer). + + Args: + current_q: Current orientation quaternion ``[x, y, z, w]``. + target_q: Target orientation quaternion ``[x, y, z, w]``. + current_distance: Current radial distance. + target_distance: Target radial distance. + + Returns: + A tuple ``(angular_error, radial_error)`` where *angular_error* + is a rotation vector ``(3,)`` and *radial_error* is a float. + + Example:: + + from spinstep.control import compute_orientation_error + + ang_err, rad_err = compute_orientation_error( + [0, 0, 0, 1], [0, 0, 0.383, 0.924], + current_distance=3.0, target_distance=5.0, + ) + """ + r_current = R.from_quat(current_q) + r_target = R.from_quat(target_q) + r_error = r_current.inv() * r_target + angular_error: np.ndarray = r_error.as_rotvec() + radial_error = float(target_distance - current_distance) + return angular_error, radial_error diff --git a/spinstep/control/trajectory.py b/spinstep/control/trajectory.py new file mode 100644 index 0000000..7d009ee --- /dev/null +++ b/spinstep/control/trajectory.py @@ -0,0 +1,272 @@ +# control/trajectory.py — MIT License +# Author: Eraldo B. Marques — Created: 2025-05-14 +# See LICENSE.txt for full terms. This header must be retained in redistributions. + +"""Orientation trajectories: waypoints, interpolation, and tracking. + +Trajectories in SpinStep are sequences of ``(quaternion, distance, time)`` +waypoints in the observer-centered spherical frame. The quaternion gives +the direction from the observer and the distance gives the radial layer. +""" + +from __future__ import annotations + +__all__ = [ + "OrientationTrajectory", + "TrajectoryInterpolator", + "TrajectoryController", +] + +from typing import List, Sequence, Tuple, Union + +import numpy as np +from numpy.typing import ArrayLike + +from ..math.interpolation import slerp +from .controllers import OrientationController +from .state import ControlCommand + + +class OrientationTrajectory: + """A sequence of spherical waypoints with timestamps. + + Each waypoint is ``(quaternion, distance, time)`` or, for backward + compatibility, ``(quaternion, time)`` (distance defaults to ``0.0``). + + Waypoints must be in ascending time order. + + Args: + waypoints: Sequence of waypoint tuples. Accepted forms: + - ``(quaternion, distance, time)`` + - ``(quaternion, time)`` — distance defaults to ``0.0`` + + Raises: + ValueError: If fewer than two waypoints are provided, times are + not strictly increasing, or quaternions are invalid. + + Attributes: + quaternions: Array of shape ``(N, 4)`` — waypoint quaternions. + distances: Array of shape ``(N,)`` — waypoint distances. + times: Array of shape ``(N,)`` — waypoint times in seconds. + + Example:: + + from spinstep.control import OrientationTrajectory + + traj = OrientationTrajectory([ + ([0, 0, 0, 1], 5.0, 0.0), + ([0, 0, 0.383, 0.924], 7.5, 1.0), + ([0, 0, 0.707, 0.707], 10.0, 2.0), + ]) + """ + + quaternions: np.ndarray + distances: np.ndarray + times: np.ndarray + + def __init__( + self, + waypoints: Sequence[Union[Tuple[ArrayLike, float, float], Tuple[ArrayLike, float]]], + ) -> None: + if len(waypoints) < 2: + raise ValueError( + f"At least 2 waypoints are required, got {len(waypoints)}" + ) + + quats: List[np.ndarray] = [] + dists: List[float] = [] + times: List[float] = [] + + for wp in waypoints: + if len(wp) == 3: + q_raw, dist, t = wp # type: ignore[misc] + elif len(wp) == 2: + q_raw, t = wp # type: ignore[misc] + dist = 0.0 + else: + raise ValueError( + f"Each waypoint must be (quaternion, distance, time) " + f"or (quaternion, time), got tuple of length {len(wp)}" + ) + + arr = np.asarray(q_raw, dtype=float) + if arr.shape != (4,): + raise ValueError( + f"Each waypoint quaternion must have shape (4,), got {arr.shape}" + ) + norm = np.linalg.norm(arr) + if norm < 1e-8: + raise ValueError("Waypoint quaternion must be non-zero") + quats.append(arr / norm) + dists.append(float(dist)) + times.append(float(t)) + + for i in range(1, len(times)): + if times[i] <= times[i - 1]: + raise ValueError( + f"Waypoint times must be strictly increasing: " + f"t[{i-1}]={times[i-1]}, t[{i}]={times[i]}" + ) + + self.quaternions = np.array(quats) + self.distances = np.array(dists) + self.times = np.array(times) + + @property + def duration(self) -> float: + """Total duration of the trajectory in seconds.""" + return float(self.times[-1] - self.times[0]) + + @property + def start_time(self) -> float: + """Start time of the trajectory.""" + return float(self.times[0]) + + @property + def end_time(self) -> float: + """End time of the trajectory.""" + return float(self.times[-1]) + + def __len__(self) -> int: + return len(self.times) + + def __repr__(self) -> str: + return ( + f"OrientationTrajectory({len(self)} waypoints, " + f"t=[{self.start_time}, {self.end_time}])" + ) + + +class TrajectoryInterpolator: + """SLERP + linear interpolator for an :class:`OrientationTrajectory`. + + Orientation is interpolated via SLERP; distance is linearly + interpolated between adjacent waypoints. + + Args: + trajectory: The trajectory to interpolate. + + Example:: + + from spinstep.control import OrientationTrajectory, TrajectoryInterpolator + + traj = OrientationTrajectory([ + ([0, 0, 0, 1], 5.0, 0.0), + ([0, 0, 0.383, 0.924], 10.0, 1.0), + ]) + interp = TrajectoryInterpolator(traj) + q, d = interp.evaluate(0.5) + # q ≈ slerp midpoint, d ≈ 7.5 + """ + + def __init__(self, trajectory: OrientationTrajectory) -> None: + self.trajectory = trajectory + + def evaluate(self, t: float) -> Tuple[np.ndarray, float]: + """Return the interpolated quaternion and distance at time *t*. + + Times before the first waypoint return the first pose; times + after the last return the last pose. + + Args: + t: Query time in seconds. + + Returns: + A tuple ``(quaternion, distance)``. + """ + traj = self.trajectory + + if t <= traj.times[0]: + return traj.quaternions[0].copy(), float(traj.distances[0]) + if t >= traj.times[-1]: + return traj.quaternions[-1].copy(), float(traj.distances[-1]) + + # Find the segment + idx = int(np.searchsorted(traj.times, t, side="right") - 1) + idx = min(idx, len(traj.times) - 2) + + t0 = traj.times[idx] + t1 = traj.times[idx + 1] + alpha = (t - t0) / (t1 - t0) + + q = slerp(traj.quaternions[idx], traj.quaternions[idx + 1], alpha) + d = float(traj.distances[idx] + alpha * (traj.distances[idx + 1] - traj.distances[idx])) + return q, d + + @property + def duration(self) -> float: + """Total duration of the underlying trajectory.""" + return self.trajectory.duration + + +class TrajectoryController: + """Controller that tracks a spherical trajectory over time. + + Wraps a base :class:`OrientationController` and a + :class:`TrajectoryInterpolator`. At each step the controller queries + the interpolator for the desired pose and computes the + :class:`~.state.ControlCommand` to drive the vehicle towards it. + + Args: + controller: Base controller instance. + trajectory: The trajectory to follow. + + Attributes: + interpolator: Internal :class:`TrajectoryInterpolator`. + controller: The wrapped base controller. + is_complete: Whether the trajectory end time has been reached. + + Example:: + + from spinstep.control import ( + OrientationTrajectory, + ProportionalOrientationController, + TrajectoryController, + ) + + traj = OrientationTrajectory([ + ([0, 0, 0, 1], 5.0, 0.0), + ([0, 0, 0.383, 0.924], 10.0, 1.0), + ]) + ctrl = ProportionalOrientationController(kp=2.0, kp_radial=1.0) + traj_ctrl = TrajectoryController(ctrl, traj) + cmd = traj_ctrl.update([0, 0, 0, 1], current_distance=5.0, t=0.5, dt=0.01) + """ + + def __init__( + self, + controller: OrientationController, + trajectory: OrientationTrajectory, + ) -> None: + self.controller = controller + self.interpolator = TrajectoryInterpolator(trajectory) + self.is_complete: bool = False + + def update( + self, + current_q: ArrayLike, + t: float, + dt: float, + current_distance: float = 0.0, + ) -> ControlCommand: + """Compute the control command to track the trajectory at time *t*. + + Args: + current_q: Current orientation ``[x, y, z, w]``. + t: Current time in seconds. + dt: Time step in seconds. + current_distance: Current radial distance from observer. + + Returns: + :class:`~.state.ControlCommand` with angular and radial components. + """ + target_q, target_distance = self.interpolator.evaluate(t) + self.is_complete = t >= self.interpolator.trajectory.end_time + return self.controller.update( + current_q, target_q, dt, current_distance, target_distance + ) + + def reset(self) -> None: + """Reset the controller state.""" + self.controller.reset() + self.is_complete = False diff --git a/spinstep/math/__init__.py b/spinstep/math/__init__.py new file mode 100644 index 0000000..88661c5 --- /dev/null +++ b/spinstep/math/__init__.py @@ -0,0 +1,77 @@ +# math/__init__.py — MIT License +# Author: Eraldo B. Marques — Created: 2025-05-14 +# See LICENSE.txt for full terms. This header must be retained in redistributions. + +"""Quaternion mathematics library. + +Sub-modules: + +- :mod:`~.core` — multiply, conjugate, normalize, inverse +- :mod:`~.interpolation` — slerp, squad +- :mod:`~.geometry` — distance, angle, direction conversions +- :mod:`~.conversions` — Euler ↔ quaternion, matrix ↔ quaternion +- :mod:`~.analysis` — batch distances, angular velocity, relative spins +- :mod:`~.constraints` — rotation angle clamping +""" + +__all__ = [ + # core + "quaternion_multiply", + "quaternion_conjugate", + "quaternion_normalize", + "quaternion_inverse", + # interpolation + "slerp", + "squad", + # geometry + "quaternion_distance", + "is_within_angle_threshold", + "forward_vector_from_quaternion", + "direction_to_quaternion", + "angle_between_directions", + "rotate_quaternion", + # conversions + "quaternion_from_euler", + "rotation_matrix_to_quaternion", + "quaternion_from_rotvec", + "quaternion_to_rotvec", + # analysis + "batch_quaternion_angle", + "angular_velocity_from_quaternions", + "get_relative_spin", + "get_unique_relative_spins", + # constraints + "clamp_rotation_angle", + # protocols + "NodeProtocol", +] + +from .core import ( + quaternion_conjugate, + quaternion_inverse, + quaternion_multiply, + quaternion_normalize, +) +from .interpolation import slerp, squad +from .geometry import ( + angle_between_directions, + direction_to_quaternion, + forward_vector_from_quaternion, + is_within_angle_threshold, + quaternion_distance, + rotate_quaternion, +) +from .conversions import ( + quaternion_from_euler, + quaternion_from_rotvec, + quaternion_to_rotvec, + rotation_matrix_to_quaternion, +) +from .analysis import ( + NodeProtocol, + angular_velocity_from_quaternions, + batch_quaternion_angle, + get_relative_spin, + get_unique_relative_spins, +) +from .constraints import clamp_rotation_angle diff --git a/spinstep/math/analysis.py b/spinstep/math/analysis.py new file mode 100644 index 0000000..05e7ea6 --- /dev/null +++ b/spinstep/math/analysis.py @@ -0,0 +1,152 @@ +# math/analysis.py — MIT License +# Author: Eraldo B. Marques — Created: 2025-05-14 +# See LICENSE.txt for full terms. This header must be retained in redistributions. + +"""Quaternion analysis: batch distances, angular velocity, relative spins.""" + +from __future__ import annotations + +__all__ = [ + "batch_quaternion_angle", + "angular_velocity_from_quaternions", + "get_relative_spin", + "get_unique_relative_spins", +] + +from types import ModuleType +from typing import Any, List, Protocol, Sequence, runtime_checkable + +import numpy as np +from numpy.typing import ArrayLike + +from .core import quaternion_conjugate, quaternion_multiply + + +@runtime_checkable +class NodeProtocol(Protocol): + """Structural type for objects accepted by :func:`get_relative_spin`. + + Any object with an ``orientation`` attribute holding a quaternion + ``[x, y, z, w]`` satisfies this protocol. + """ + + orientation: np.ndarray + + +def batch_quaternion_angle(qs1: Any, qs2: Any, xp: ModuleType) -> Any: + """Compute pairwise angular distances between two sets of quaternions. + + Args: + qs1: Array of shape ``(N, 4)`` — first set of quaternions. + qs2: Array of shape ``(M, 4)`` — second set of quaternions. + xp: Array module (:mod:`numpy` or :mod:`cupy`). + + Returns: + ``(N, M)`` array of angular distances in radians. + """ + dots = xp.abs(xp.dot(qs1, qs2.T)) + dots = xp.clip(dots, -1.0, 1.0) + angles = 2 * xp.arccos(dots) + return angles + + +def angular_velocity_from_quaternions( + q1: ArrayLike, q2: ArrayLike, dt: float +) -> np.ndarray: + """Estimate angular velocity from two quaternions separated by *dt* seconds. + + Computes the rotation from *q1* to *q2*, converts to a rotation vector + (axis × angle), and divides by *dt* to obtain angular velocity in rad/s. + + Args: + q1: Start quaternion ``[x, y, z, w]``. + q2: End quaternion ``[x, y, z, w]``. + dt: Time step in seconds. Must be positive. + + Returns: + Angular velocity vector ``(3,)`` in radians per second. + + Raises: + ValueError: If *dt* is not positive. + """ + if dt <= 0: + raise ValueError(f"dt must be positive, got {dt}") + from scipy.spatial.transform import Rotation as R + + r1 = R.from_quat(q1) + r2 = R.from_quat(q2) + delta = r1.inv() * r2 + rotvec = delta.as_rotvec() + return rotvec / dt + + +def get_relative_spin(nf: NodeProtocol, nt: NodeProtocol) -> np.ndarray: + """Return the relative quaternion rotation from node *nf* to node *nt*. + + Both nodes must have an ``.orientation`` attribute storing a quaternion + ``[x, y, z, w]``. + + Args: + nf: Source node with ``.orientation`` attribute. + nt: Target node with ``.orientation`` attribute. + + Returns: + Unit quaternion representing the relative rotation. + """ + qfc = quaternion_conjugate(nf.orientation) + qr = quaternion_multiply(qfc, nt.orientation) + n = np.linalg.norm(qr) + return qr / n if n > 1e-8 else np.array([0.0, 0.0, 0.0, 1.0]) + + +def get_unique_relative_spins( + nodes: Sequence[NodeProtocol], + nside: int, + nest: bool, + threshold: float = 1e-3, +) -> List[np.ndarray]: + """Compute unique relative rotations between HEALPix neighbours. + + Requires the ``healpy`` package. + + Args: + nodes: Sequence of node objects with ``.orientation`` attributes. + nside: HEALPix *nside* parameter. + nest: Whether to use the NESTED pixel ordering. + threshold: Angular threshold (radians) for considering two + rotations identical. + + Returns: + List of unique unit quaternions representing relative rotations. + + Raises: + ImportError: If ``healpy`` is not installed. + """ + try: + import healpy as hp + except ImportError: + raise ImportError( + "healpy is required for get_unique_relative_spins(). " + "Install it with: pip install healpy" + ) + spins: List[np.ndarray] = [] + NPIX = hp.nside2npix(nside) + for i in range(NPIX): + nf = nodes[i] + nidx = hp.get_all_neighbours(nside, i, nest=nest) + for idx in nidx: + if idx != -1: + q = get_relative_spin(nf, nodes[idx]) + if q[3] < 0: + q = -q # Canonical form (w >= 0) + is_uniq = True + for s_q in spins: + dot = np.abs(np.dot(q, s_q)) + dot = np.clip(dot, -1, 1) + angle = 2 * np.arccos(dot) + if angle < threshold: + is_uniq = False + break + if is_uniq: + spins.append(q) + return spins diff --git a/spinstep/math/constraints.py b/spinstep/math/constraints.py new file mode 100644 index 0000000..4147f97 --- /dev/null +++ b/spinstep/math/constraints.py @@ -0,0 +1,52 @@ +# math/constraints.py — MIT License +# Author: Eraldo B. Marques — Created: 2025-05-14 +# See LICENSE.txt for full terms. This header must be retained in redistributions. + +"""Quaternion constraints: clamping, limiting rotation magnitude.""" + +from __future__ import annotations + +__all__ = [ + "clamp_rotation_angle", +] + +import numpy as np +from numpy.typing import ArrayLike +from scipy.spatial.transform import Rotation as R + + +def clamp_rotation_angle(q: ArrayLike, max_angle: float) -> np.ndarray: + """Clamp a rotation quaternion so its angle does not exceed *max_angle*. + + If the rotation represented by *q* has an angle larger than *max_angle*, + the quaternion is scaled to represent exactly *max_angle* around the + same axis. + + Args: + q: Rotation quaternion ``[x, y, z, w]``. + max_angle: Maximum allowed rotation angle in radians. Must be + non-negative. + + Returns: + Clamped unit quaternion ``[x, y, z, w]``. + + Raises: + ValueError: If *max_angle* is negative. + """ + if max_angle < 0: + raise ValueError(f"max_angle must be non-negative, got {max_angle}") + + rot = R.from_quat(q) + angle = rot.magnitude() + + if angle <= max_angle: + return np.asarray(q, dtype=float) + + # Scale the rotation vector to max_angle + rotvec = rot.as_rotvec() + if angle < 1e-10: + return np.array([0.0, 0.0, 0.0, 1.0]) + + axis = rotvec / angle + clamped_rotvec = axis * max_angle + return R.from_rotvec(clamped_rotvec).as_quat() diff --git a/spinstep/math/conversions.py b/spinstep/math/conversions.py new file mode 100644 index 0000000..f1ea394 --- /dev/null +++ b/spinstep/math/conversions.py @@ -0,0 +1,102 @@ +# math/conversions.py — MIT License +# Author: Eraldo B. Marques — Created: 2025-05-14 +# See LICENSE.txt for full terms. This header must be retained in redistributions. + +"""Quaternion conversions: Euler, rotation matrix, rotation vector.""" + +from __future__ import annotations + +__all__ = [ + "quaternion_from_euler", + "rotation_matrix_to_quaternion", + "quaternion_from_rotvec", + "quaternion_to_rotvec", +] + +from typing import Sequence + +import numpy as np +from numpy.typing import ArrayLike +from scipy.spatial.transform import Rotation as R + + +def quaternion_from_euler( + angles: Sequence[float], + order: str = "zyx", + degrees: bool = True, +) -> np.ndarray: + """Convert Euler angles to a quaternion ``[x, y, z, w]``. + + Args: + angles: Euler angles as a 3-element sequence. + order: Rotation order string (e.g. ``"zyx"``). + degrees: If ``True``, angles are in degrees; otherwise radians. + + Returns: + Unit quaternion ``[x, y, z, w]``. + """ + return R.from_euler(order, angles, degrees=degrees).as_quat() + + +def rotation_matrix_to_quaternion(m: ArrayLike) -> np.ndarray: + """Convert a 3×3 rotation matrix to a unit quaternion ``[x, y, z, w]``. + + Args: + m: A 3×3 rotation matrix. + + Returns: + Unit quaternion ``[x, y, z, w]``. + """ + mat = np.asarray(m, dtype=float) + t = np.trace(mat) + if t > 0: + s = np.sqrt(t + 1) * 2 + qw = 0.25 * s + qx = (mat[2, 1] - mat[1, 2]) / s + qy = (mat[0, 2] - mat[2, 0]) / s + qz = (mat[1, 0] - mat[0, 1]) / s + elif (mat[0, 0] > mat[1, 1]) and (mat[0, 0] > mat[2, 2]): + s = np.sqrt(1 + mat[0, 0] - mat[1, 1] - mat[2, 2]) * 2 + qx = 0.25 * s + qw = (mat[2, 1] - mat[1, 2]) / s + qy = (mat[0, 1] + mat[1, 0]) / s + qz = (mat[0, 2] + mat[2, 0]) / s + elif mat[1, 1] > mat[2, 2]: + s = np.sqrt(1 + mat[1, 1] - mat[0, 0] - mat[2, 2]) * 2 + qy = 0.25 * s + qw = (mat[0, 2] - mat[2, 0]) / s + qx = (mat[0, 1] + mat[1, 0]) / s + qz = (mat[1, 2] + mat[2, 1]) / s + else: + s = np.sqrt(1 + mat[2, 2] - mat[0, 0] - mat[1, 1]) * 2 + qz = 0.25 * s + qw = (mat[1, 0] - mat[0, 1]) / s + qx = (mat[0, 2] + mat[2, 0]) / s + qy = (mat[1, 2] + mat[2, 1]) / s + q = np.array([qx, qy, qz, qw]) + n = np.linalg.norm(q) + return q / n if n > 1e-8 else np.array([0.0, 0.0, 0.0, 1.0]) + + +def quaternion_from_rotvec(rotvec: ArrayLike) -> np.ndarray: + """Convert a rotation vector (axis × angle) to a quaternion. + + Args: + rotvec: Rotation vector of shape ``(3,)``. + + Returns: + Unit quaternion ``[x, y, z, w]``. + """ + return R.from_rotvec(np.asarray(rotvec, dtype=float)).as_quat() + + +def quaternion_to_rotvec(q: ArrayLike) -> np.ndarray: + """Convert a quaternion to a rotation vector (axis × angle). + + Args: + q: Quaternion ``[x, y, z, w]``. + + Returns: + Rotation vector of shape ``(3,)``. + """ + return R.from_quat(np.asarray(q, dtype=float)).as_rotvec() diff --git a/spinstep/math/core.py b/spinstep/math/core.py new file mode 100644 index 0000000..8dd0fa1 --- /dev/null +++ b/spinstep/math/core.py @@ -0,0 +1,90 @@ +# math/core.py — MIT License +# Author: Eraldo B. Marques — Created: 2025-05-14 +# See LICENSE.txt for full terms. This header must be retained in redistributions. + +"""Core quaternion operations: multiply, conjugate, normalize, inverse.""" + +from __future__ import annotations + +__all__ = [ + "quaternion_multiply", + "quaternion_conjugate", + "quaternion_normalize", + "quaternion_inverse", +] + +import numpy as np +from numpy.typing import ArrayLike + + +def quaternion_multiply(q1: ArrayLike, q2: ArrayLike) -> np.ndarray: + """Hamilton product of two quaternions in ``[x, y, z, w]`` order. + + Args: + q1: First quaternion ``[x, y, z, w]``. + q2: Second quaternion ``[x, y, z, w]``. + + Returns: + Product quaternion ``[x, y, z, w]`` as a NumPy array of shape ``(4,)``. + """ + a1 = np.asarray(q1, dtype=float) + a2 = np.asarray(q2, dtype=float) + x1, y1, z1, w1 = a1 + x2, y2, z2, w2 = a2 + return np.array([ + w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2, + w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2, + w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2, + w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2, + ]) + + +def quaternion_conjugate(q: ArrayLike) -> np.ndarray: + """Return the conjugate of quaternion *q* ``[x, y, z, w]``. + + The conjugate negates the vector part while keeping the scalar part. + + Args: + q: Quaternion ``[x, y, z, w]``. + + Returns: + Conjugate quaternion ``[-x, -y, -z, w]``. + """ + a = np.asarray(q, dtype=float) + return np.array([-a[0], -a[1], -a[2], a[3]]) + + +def quaternion_normalize(q: ArrayLike) -> np.ndarray: + """Normalize a quaternion to unit length. + + Args: + q: Quaternion ``[x, y, z, w]``. + + Returns: + Unit quaternion. Returns ``[0, 0, 0, 1]`` if the input has + near-zero norm. + """ + a = np.asarray(q, dtype=float) + n = np.linalg.norm(a) + if n < 1e-8: + return np.array([0.0, 0.0, 0.0, 1.0]) + return a / n + + +def quaternion_inverse(q: ArrayLike) -> np.ndarray: + """Return the inverse of a unit quaternion. + + For unit quaternions the inverse equals the conjugate. + + Args: + q: Unit quaternion ``[x, y, z, w]``. + + Returns: + Inverse quaternion ``[x, y, z, w]``. + """ + a = np.asarray(q, dtype=float) + norm_sq = np.dot(a, a) + if norm_sq < 1e-16: + return np.array([0.0, 0.0, 0.0, 1.0]) + conj = np.array([-a[0], -a[1], -a[2], a[3]]) + return conj / norm_sq diff --git a/spinstep/math/geometry.py b/spinstep/math/geometry.py new file mode 100644 index 0000000..ce2cb71 --- /dev/null +++ b/spinstep/math/geometry.py @@ -0,0 +1,125 @@ +# math/geometry.py — MIT License +# Author: Eraldo B. Marques — Created: 2025-05-14 +# See LICENSE.txt for full terms. This header must be retained in redistributions. + +"""Quaternion geometry: distance, angle, direction conversions.""" + +from __future__ import annotations + +__all__ = [ + "quaternion_distance", + "is_within_angle_threshold", + "forward_vector_from_quaternion", + "direction_to_quaternion", + "angle_between_directions", + "rotate_quaternion", +] + +import numpy as np +from numpy.typing import ArrayLike +from scipy.spatial.transform import Rotation as R + + +def quaternion_distance(q1: ArrayLike, q2: ArrayLike) -> float: + """Return the angular distance (radians) between two quaternions. + + Args: + q1: First quaternion ``[x, y, z, w]``. + q2: Second quaternion ``[x, y, z, w]``. + + Returns: + Angular distance in radians. + """ + r1 = R.from_quat(q1) + r2 = R.from_quat(q2) + return float((r1.inv() * r2).magnitude()) + + +def is_within_angle_threshold( + q_current: ArrayLike, + q_target: ArrayLike, + threshold_rad: float, +) -> bool: + """Check whether two quaternions are within *threshold_rad* of each other. + + Args: + q_current: Current quaternion ``[x, y, z, w]``. + q_target: Target quaternion ``[x, y, z, w]``. + threshold_rad: Maximum angular distance in radians. + + Returns: + ``True`` if the angular distance is less than *threshold_rad*. + """ + return quaternion_distance(q_current, q_target) < threshold_rad + + +def rotate_quaternion(q: ArrayLike, rotation_step: ArrayLike) -> np.ndarray: + """Apply *rotation_step* to quaternion *q* and return the result. + + Args: + q: Base quaternion ``[x, y, z, w]``. + rotation_step: Rotation to apply, as quaternion ``[x, y, z, w]``. + + Returns: + Composed quaternion ``[x, y, z, w]``. + """ + r1 = R.from_quat(q) + step = R.from_quat(rotation_step) + return (r1 * step).as_quat() + + +def forward_vector_from_quaternion(q: ArrayLike) -> np.ndarray: + """Extract the forward (look) direction from a quaternion. + + The forward direction is defined as ``[0, 0, -1]`` rotated by the + quaternion, following the convention where negative-Z is "forward". + + Args: + q: Quaternion ``[x, y, z, w]``. + + Returns: + Unit direction vector ``(3,)`` pointing forward. + """ + return R.from_quat(q).apply([0, 0, -1]) + + +def direction_to_quaternion(direction: ArrayLike) -> np.ndarray: + """Convert a 3D direction vector to an orientation quaternion. + + The returned quaternion represents the rotation that aligns the + default forward axis ``[0, 0, -1]`` with the given *direction*. + + Args: + direction: Target direction vector (does not need to be normalised). + + Returns: + Unit quaternion ``[x, y, z, w]``. + """ + d = np.asarray(direction, dtype=float) + norm = np.linalg.norm(d) + if norm < 1e-8: + return np.array([0.0, 0.0, 0.0, 1.0]) + d = d / norm + rot, _ = R.align_vectors([d], [[0, 0, -1]]) + return rot.as_quat() + + +def angle_between_directions(d1: ArrayLike, d2: ArrayLike) -> float: + """Compute the angular distance (radians) between two direction vectors. + + Args: + d1: First direction vector. + d2: Second direction vector. + + Returns: + Angle in radians in the range ``[0, π]``. + """ + v1 = np.asarray(d1, dtype=float) + v2 = np.asarray(d2, dtype=float) + n1 = np.linalg.norm(v1) + n2 = np.linalg.norm(v2) + if n1 < 1e-8 or n2 < 1e-8: + return 0.0 + cos_angle = np.dot(v1 / n1, v2 / n2) + cos_angle = np.clip(cos_angle, -1.0, 1.0) + return float(np.arccos(cos_angle)) diff --git a/spinstep/math/interpolation.py b/spinstep/math/interpolation.py new file mode 100644 index 0000000..22fba9d --- /dev/null +++ b/spinstep/math/interpolation.py @@ -0,0 +1,133 @@ +# math/interpolation.py — MIT License +# Author: Eraldo B. Marques — Created: 2025-05-14 +# See LICENSE.txt for full terms. This header must be retained in redistributions. + +"""Quaternion interpolation: SLERP and SQUAD.""" + +from __future__ import annotations + +__all__ = [ + "slerp", + "squad", +] + +import numpy as np +from numpy.typing import ArrayLike + + +def slerp(q0: ArrayLike, q1: ArrayLike, t: float) -> np.ndarray: + """Spherical linear interpolation between two quaternions. + + Interpolates along the shortest arc on the unit quaternion hypersphere. + + Args: + q0: Start quaternion ``[x, y, z, w]``. + q1: End quaternion ``[x, y, z, w]``. + t: Interpolation parameter in ``[0, 1]``. + + Returns: + Interpolated unit quaternion ``[x, y, z, w]``. + """ + a = np.asarray(q0, dtype=float) + b = np.asarray(q1, dtype=float) + + # Normalize inputs + a = a / np.linalg.norm(a) + b = b / np.linalg.norm(b) + + # Ensure shortest path + dot = np.dot(a, b) + if dot < 0.0: + b = -b + dot = -dot + + dot = np.clip(dot, -1.0, 1.0) + + # If quaternions are very close, use linear interpolation + if dot > 0.9995: + result = a + t * (b - a) + return result / np.linalg.norm(result) + + theta_0 = np.arccos(dot) + theta = theta_0 * t + sin_theta = np.sin(theta) + sin_theta_0 = np.sin(theta_0) + + s0 = np.cos(theta) - dot * sin_theta / sin_theta_0 + s1 = sin_theta / sin_theta_0 + + result = s0 * a + s1 * b + return result / np.linalg.norm(result) + + +def squad( + q0: ArrayLike, + q1: ArrayLike, + q2: ArrayLike, + q3: ArrayLike, + t: float, +) -> np.ndarray: + """Spherical cubic interpolation (SQUAD) between quaternion waypoints. + + Produces a smooth C¹-continuous curve through a sequence of orientations. + *q0* and *q3* are the neighboring control points; the interpolation is + between *q1* (at *t* = 0) and *q2* (at *t* = 1). + + Args: + q0: Control quaternion before *q1*. + q1: Start quaternion for this segment. + q2: End quaternion for this segment. + q3: Control quaternion after *q2*. + t: Interpolation parameter in ``[0, 1]``. + + Returns: + Interpolated unit quaternion ``[x, y, z, w]``. + """ + a1 = np.asarray(q1, dtype=float) + a2 = np.asarray(q2, dtype=float) + + s1 = _squad_intermediate(np.asarray(q0, dtype=float), a1, a2) + s2 = _squad_intermediate(a1, a2, np.asarray(q3, dtype=float)) + + slerp_q1_q2 = slerp(a1, a2, t) + slerp_s1_s2 = slerp(s1, s2, t) + return slerp(slerp_q1_q2, slerp_s1_s2, 2.0 * t * (1.0 - t)) + + +def _squad_intermediate( + q_prev: np.ndarray, q_curr: np.ndarray, q_next: np.ndarray +) -> np.ndarray: + """Compute the SQUAD intermediate control point for *q_curr*.""" + from .core import quaternion_conjugate, quaternion_multiply + + q_curr = q_curr / np.linalg.norm(q_curr) + inv_curr = quaternion_conjugate(q_curr) + + log_prev = _quat_log(quaternion_multiply(inv_curr, q_prev / np.linalg.norm(q_prev))) + log_next = _quat_log(quaternion_multiply(inv_curr, q_next / np.linalg.norm(q_next))) + + avg = -(log_prev + log_next) / 4.0 + result = quaternion_multiply(q_curr, _quat_exp(avg)) + return result / np.linalg.norm(result) + + +def _quat_log(q: np.ndarray) -> np.ndarray: + """Quaternion logarithm (returns pure-quaternion vector part).""" + q = q / np.linalg.norm(q) + vec = q[:3] + w = q[3] + vec_norm = np.linalg.norm(vec) + if vec_norm < 1e-10: + return np.array([0.0, 0.0, 0.0, 0.0]) + theta = np.arctan2(vec_norm, w) + return np.array([*(vec / vec_norm * theta), 0.0]) + + +def _quat_exp(v: np.ndarray) -> np.ndarray: + """Quaternion exponential (from pure-quaternion vector).""" + vec = v[:3] + theta = np.linalg.norm(vec) + if theta < 1e-10: + return np.array([0.0, 0.0, 0.0, 1.0]) + axis = vec / theta + return np.array([*(axis * np.sin(theta)), np.cos(theta)]) diff --git a/spinstep/traversal/__init__.py b/spinstep/traversal/__init__.py new file mode 100644 index 0000000..d480ffa --- /dev/null +++ b/spinstep/traversal/__init__.py @@ -0,0 +1,25 @@ +# traversal/__init__.py — MIT License +# Author: Eraldo B. Marques — Created: 2025-05-14 +# See LICENSE.txt for full terms. This header must be retained in redistributions. + +"""Tree traversal using quaternion orientation. + +This sub-package contains the traversal classes: + +- :class:`Node` — tree node with quaternion orientation +- :class:`QuaternionDepthIterator` — continuous rotation-step depth-first traversal +- :class:`DiscreteOrientationSet` — queryable set of discrete orientations +- :class:`DiscreteQuaternionIterator` — discrete rotation-step depth-first traversal +""" + +__all__ = [ + "Node", + "QuaternionDepthIterator", + "DiscreteOrientationSet", + "DiscreteQuaternionIterator", +] + +from .node import Node +from .continuous import QuaternionDepthIterator +from .discrete import DiscreteOrientationSet +from .discrete_iterator import DiscreteQuaternionIterator diff --git a/spinstep/traversal.py b/spinstep/traversal/continuous.py similarity index 100% rename from spinstep/traversal.py rename to spinstep/traversal/continuous.py diff --git a/spinstep/discrete.py b/spinstep/traversal/discrete.py similarity index 99% rename from spinstep/discrete.py rename to spinstep/traversal/discrete.py index b4fae53..be04773 100644 --- a/spinstep/discrete.py +++ b/spinstep/traversal/discrete.py @@ -14,7 +14,7 @@ from numpy.typing import ArrayLike from scipy.spatial.transform import Rotation as R -from spinstep.utils.array_backend import get_array_module +from ..utils.array_backend import get_array_module class DiscreteOrientationSet: diff --git a/spinstep/discrete_iterator.py b/spinstep/traversal/discrete_iterator.py similarity index 100% rename from spinstep/discrete_iterator.py rename to spinstep/traversal/discrete_iterator.py diff --git a/spinstep/node.py b/spinstep/traversal/node.py similarity index 84% rename from spinstep/node.py rename to spinstep/traversal/node.py index d7d9e20..b7d1c78 100644 --- a/spinstep/node.py +++ b/spinstep/traversal/node.py @@ -40,8 +40,7 @@ class Node: from spinstep import Node root = Node("root", [0, 0, 0, 1]) - child = Node("child", [0.2588, 0, 0, 0.9659]) - root.children.append(child) + child = root.add_child(Node("child", [0.2588, 0, 0, 0.9659])) """ name: str @@ -64,5 +63,17 @@ def __init__( self.name = name self.children = list(children) if children else [] + def add_child(self, child: "Node") -> "Node": + """Append *child* to this node's children and return it. + + Args: + child: The child node to add. + + Returns: + The same *child* node, for convenience. + """ + self.children.append(child) + return child + def __repr__(self) -> str: return f"Node({self.name!r}, orientation={self.orientation.tolist()})" diff --git a/spinstep/utils/__init__.py b/spinstep/utils/__init__.py index 052def5..9ff0f8e 100644 --- a/spinstep/utils/__init__.py +++ b/spinstep/utils/__init__.py @@ -2,14 +2,18 @@ # Author: Eraldo B. Marques — Created: 2025-05-14 # See LICENSE.txt for full terms. This header must be retained in redistributions. -"""Utilities for quaternion math and array backend selection. +"""Utilities: array backend selection and backward-compatible quaternion re-exports. -This sub-package provides: +The ``get_array_module`` function is the primary utility provided here. +All quaternion math functions have moved to :mod:`spinstep.math` and are +re-exported here only for backward compatibility. -- :func:`~.array_backend.get_array_module` — NumPy / CuPy backend selection. -- :func:`~.quaternion_math.batch_quaternion_angle` — batch angular distances. -- Quaternion helpers in :mod:`~.quaternion_utils` (conversion, distance, - multiplication, etc.). +.. deprecated:: + For quaternion operations, import from :mod:`spinstep.math` instead. + +Example (preferred):: + + from spinstep.math import quaternion_multiply, quaternion_distance """ __all__ = [ diff --git a/spinstep/utils/quaternion_math.py b/spinstep/utils/quaternion_math.py index fbba4bf..347e42d 100644 --- a/spinstep/utils/quaternion_math.py +++ b/spinstep/utils/quaternion_math.py @@ -2,34 +2,17 @@ # Author: Eraldo B. Marques — Created: 2025-05-14 # See LICENSE.txt for full terms. This header must be retained in redistributions. -"""Batch quaternion angular distance computation.""" +"""Backward-compatible re-export of batch quaternion angle computation. -from __future__ import annotations - -from types import ModuleType -from typing import Any - -__all__ = ["batch_quaternion_angle"] +The implementation has moved to :mod:`spinstep.math.analysis`. +.. deprecated:: + Import from :mod:`spinstep.math` instead. +""" -def batch_quaternion_angle(qs1: Any, qs2: Any, xp: ModuleType) -> Any: - """Compute pairwise angular distances between two sets of quaternions. +from __future__ import annotations - Parameters - ---------- - qs1: - Array of shape ``(N, 4)`` — first set of quaternions. - qs2: - Array of shape ``(M, 4)`` — second set of quaternions. - xp: - Array module (:mod:`numpy` or :mod:`cupy`). +__all__ = ["batch_quaternion_angle"] - Returns - ------- - array - ``(N, M)`` array of angular distances in radians. - """ - dots = xp.abs(xp.dot(qs1, qs2.T)) - dots = xp.clip(dots, -1.0, 1.0) - angles = 2 * xp.arccos(dots) - return angles +# Re-export from canonical location +from spinstep.math.analysis import batch_quaternion_angle diff --git a/spinstep/utils/quaternion_utils.py b/spinstep/utils/quaternion_utils.py index 1d01aaa..d0bf24e 100644 --- a/spinstep/utils/quaternion_utils.py +++ b/spinstep/utils/quaternion_utils.py @@ -2,7 +2,15 @@ # Author: Eraldo Marques — Created: 2025-05-14 # See LICENSE.txt for full terms. This header must be retained in redistributions. -"""Quaternion conversion, distance, and manipulation utilities.""" +"""Backward-compatible re-exports of quaternion utilities. + +All quaternion functions have moved to :mod:`spinstep.math`. This module +re-exports them so that existing ``from spinstep.utils.quaternion_utils import …`` +statements continue to work. + +.. deprecated:: + Import from :mod:`spinstep.math` instead. +""" from __future__ import annotations @@ -21,209 +29,21 @@ "angle_between_directions", ] -from typing import List, Sequence - -import numpy as np -from numpy.typing import ArrayLike -from scipy.spatial.transform import Rotation as R - - -def quaternion_from_euler( - angles: Sequence[float], - order: str = "zyx", - degrees: bool = True, -) -> np.ndarray: - """Convert Euler angles to a quaternion ``[x, y, z, w]``.""" - return R.from_euler(order, angles, degrees=degrees).as_quat() - - -def quaternion_distance(q1: ArrayLike, q2: ArrayLike) -> float: - """Return the angular distance (radians) between two quaternions.""" - r1 = R.from_quat(q1) - r2 = R.from_quat(q2) - return float((r1.inv() * r2).magnitude()) - - -def rotate_quaternion(q: ArrayLike, rotation_step: ArrayLike) -> np.ndarray: - """Apply *rotation_step* to quaternion *q* and return the result.""" - r1 = R.from_quat(q) - step = R.from_quat(rotation_step) - return (r1 * step).as_quat() - - -def is_within_angle_threshold( - q_current: ArrayLike, - q_target: ArrayLike, - threshold_rad: float, -) -> bool: - """Check whether two quaternions are within *threshold_rad* of each other.""" - return quaternion_distance(q_current, q_target) < threshold_rad - - -def quaternion_conjugate(q: ArrayLike) -> np.ndarray: - """Return the conjugate of quaternion *q* ``[x, y, z, w]``.""" - return np.array([-q[0], -q[1], -q[2], q[3]]) - - -def quaternion_multiply(q1: ArrayLike, q2: ArrayLike) -> np.ndarray: - """Hamilton product of two quaternions in ``[x, y, z, w]`` order.""" - x1, y1, z1, w1 = q1 - x2, y2, z2, w2 = q2 - return np.array([ - w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2, - w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2, - w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2, - w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2, - ]) - - -def rotation_matrix_to_quaternion(m: ArrayLike) -> np.ndarray: - """Convert a 3×3 rotation matrix to a unit quaternion ``[x, y, z, w]``.""" - t = np.trace(m) - if t > 0: - s = np.sqrt(t + 1) * 2 - qw = 0.25 * s - qx = (m[2, 1] - m[1, 2]) / s - qy = (m[0, 2] - m[2, 0]) / s - qz = (m[1, 0] - m[0, 1]) / s - elif (m[0, 0] > m[1, 1]) and (m[0, 0] > m[2, 2]): - s = np.sqrt(1 + m[0, 0] - m[1, 1] - m[2, 2]) * 2 - qx = 0.25 * s - qw = (m[2, 1] - m[1, 2]) / s - qy = (m[0, 1] + m[1, 0]) / s - qz = (m[0, 2] + m[2, 0]) / s - elif m[1, 1] > m[2, 2]: - s = np.sqrt(1 + m[1, 1] - m[0, 0] - m[2, 2]) * 2 - qy = 0.25 * s - qw = (m[0, 2] - m[2, 0]) / s - qx = (m[0, 1] + m[1, 0]) / s - qz = (m[1, 2] + m[2, 1]) / s - else: - s = np.sqrt(1 + m[2, 2] - m[0, 0] - m[1, 1]) * 2 - qz = 0.25 * s - qw = (m[1, 0] - m[0, 1]) / s - qx = (m[0, 2] + m[2, 0]) / s - qy = (m[1, 2] + m[2, 1]) / s - q = np.array([qx, qy, qz, qw]) - n = np.linalg.norm(q) - return q / n if n > 1e-8 else np.array([0.0, 0.0, 0.0, 1.0]) - - -def forward_vector_from_quaternion(q: ArrayLike) -> np.ndarray: - """Extract the forward (look) direction from a quaternion. - - The forward direction is defined as ``[0, 0, -1]`` rotated by the - quaternion, following the convention where negative-Z is "forward". - - Args: - q: Quaternion ``[x, y, z, w]``. - - Returns: - Unit direction vector ``(3,)`` pointing forward. - """ - return R.from_quat(q).apply([0, 0, -1]) - - -def direction_to_quaternion(direction: ArrayLike) -> np.ndarray: - """Convert a 3D direction vector to an orientation quaternion. - - The returned quaternion represents the rotation that aligns the - default forward axis ``[0, 0, -1]`` with the given *direction*. - - Args: - direction: Target direction vector (does not need to be normalised). - - Returns: - Unit quaternion ``[x, y, z, w]``. - """ - d = np.asarray(direction, dtype=float) - norm = np.linalg.norm(d) - if norm < 1e-8: - return np.array([0.0, 0.0, 0.0, 1.0]) - d = d / norm - rot, _ = R.align_vectors([d], [[0, 0, -1]]) - return rot.as_quat() - - -def angle_between_directions(d1: ArrayLike, d2: ArrayLike) -> float: - """Compute the angular distance (radians) between two direction vectors. - - Args: - d1: First direction vector. - d2: Second direction vector. - - Returns: - Angle in radians in the range ``[0, π]``. - """ - v1 = np.asarray(d1, dtype=float) - v2 = np.asarray(d2, dtype=float) - n1 = np.linalg.norm(v1) - n2 = np.linalg.norm(v2) - if n1 < 1e-8 or n2 < 1e-8: - return 0.0 - cos_angle = np.dot(v1 / n1, v2 / n2) - cos_angle = np.clip(cos_angle, -1.0, 1.0) - return float(np.arccos(cos_angle)) - - -def get_relative_spin(nf: object, nt: object) -> np.ndarray: - """Return the relative quaternion rotation from node *nf* to node *nt*. - - Both nodes must have an ``.orientation`` attribute storing a quaternion - ``[x, y, z, w]``. - """ - qfc = quaternion_conjugate(nf.orientation) # type: ignore[union-attr] - qr = quaternion_multiply(qfc, nt.orientation) # type: ignore[union-attr] - n = np.linalg.norm(qr) - return qr / n if n > 1e-8 else np.array([0.0, 0.0, 0.0, 1.0]) - - -def get_unique_relative_spins( - nodes: Sequence[object], - nside: int, - nest: bool, - threshold: float = 1e-3, -) -> List[np.ndarray]: - """Compute unique relative rotations between HEALPix neighbours. - - Requires the ``healpy`` package. - - Parameters - ---------- - nodes: - Sequence of node objects with ``.orientation`` attributes. - nside: - HEALPix *nside* parameter. - nest: - Whether to use the NESTED pixel ordering. - threshold: - Angular threshold (radians) for considering two rotations identical. - """ - try: - import healpy as hp - except ImportError: - raise ImportError( - "healpy is required for get_unique_relative_spins(). " - "Install it with: pip install healpy" - ) - spins: List[np.ndarray] = [] - NPIX = hp.nside2npix(nside) - for i in range(NPIX): - nf = nodes[i] - nidx = hp.get_all_neighbours(nside, i, nest=nest) - for idx in nidx: - if idx != -1: - q = get_relative_spin(nf, nodes[idx]) - if q[3] < 0: - q = -q # Canonical form (w >= 0) - is_uniq = True - for s_q in spins: - dot = np.abs(np.dot(q, s_q)) - dot = np.clip(dot, -1, 1) - angle = 2 * np.arccos(dot) - if angle < threshold: - is_uniq = False - break - if is_uniq: - spins.append(q) - return spins +# Re-export from canonical locations in spinstep.math +from spinstep.math.core import quaternion_conjugate, quaternion_multiply +from spinstep.math.geometry import ( + angle_between_directions, + direction_to_quaternion, + forward_vector_from_quaternion, + is_within_angle_threshold, + quaternion_distance, + rotate_quaternion, +) +from spinstep.math.conversions import ( + quaternion_from_euler, + rotation_matrix_to_quaternion, +) +from spinstep.math.analysis import ( + get_relative_spin, + get_unique_relative_spins, +) diff --git a/tests/test_api.py b/tests/test_api.py new file mode 100644 index 0000000..940ff16 --- /dev/null +++ b/tests/test_api.py @@ -0,0 +1,299 @@ +# test_api.py — SpinStep Test Suite — MIT License +# Author: SpinStep Contributors +# See LICENSE.txt for full terms. + +"""Tests for public API stability, exports, and packaging markers.""" + +import importlib +import re + +import pytest + + +class TestPackageMetadata: + """Verify package-level metadata and markers.""" + + def test_version_format(self) -> None: + """Version string follows PEP 440.""" + import spinstep + + assert hasattr(spinstep, "__version__") + # PEP 440: N.N.NaN, N.N.N, etc. + assert re.match( + r"^\d+\.\d+\.\d+(a\d+|b\d+|rc\d+)?$", spinstep.__version__ + ), f"Invalid version format: {spinstep.__version__}" + + def test_py_typed_marker(self) -> None: + """PEP 561 py.typed marker exists.""" + import pathlib + + import spinstep + + pkg_dir = pathlib.Path(spinstep.__file__).parent + assert (pkg_dir / "py.typed").exists(), "Missing py.typed marker" + + +class TestTopLevelExports: + """Verify spinstep.__all__ contains expected symbols.""" + + EXPECTED_EXPORTS = [ + # control + "OrientationState", + "ControlCommand", + "integrate_state", + "compute_orientation_error", + "OrientationController", + "ProportionalOrientationController", + "PIDOrientationController", + "OrientationTrajectory", + "TrajectoryInterpolator", + "TrajectoryController", + # math + "slerp", + # traversal + "Node", + "QuaternionDepthIterator", + "DiscreteOrientationSet", + "DiscreteQuaternionIterator", + ] + + def test_all_defined(self) -> None: + import spinstep + + assert hasattr(spinstep, "__all__") + assert isinstance(spinstep.__all__, list) + + @pytest.mark.parametrize("name", EXPECTED_EXPORTS) + def test_export_importable(self, name: str) -> None: + """Every name in __all__ is importable from the top-level package.""" + import spinstep + + assert name in spinstep.__all__, f"{name} missing from __all__" + assert hasattr(spinstep, name), f"{name} not accessible on spinstep" + + +class TestMathSubpackageExports: + """Verify spinstep.math.__all__ contains expected symbols.""" + + EXPECTED_EXPORTS = [ + "quaternion_multiply", + "quaternion_conjugate", + "quaternion_normalize", + "quaternion_inverse", + "slerp", + "squad", + "quaternion_distance", + "is_within_angle_threshold", + "forward_vector_from_quaternion", + "direction_to_quaternion", + "angle_between_directions", + "rotate_quaternion", + "quaternion_from_euler", + "rotation_matrix_to_quaternion", + "quaternion_from_rotvec", + "quaternion_to_rotvec", + "batch_quaternion_angle", + "angular_velocity_from_quaternions", + "get_relative_spin", + "get_unique_relative_spins", + "clamp_rotation_angle", + "NodeProtocol", + ] + + def test_all_defined(self) -> None: + from spinstep import math + + assert hasattr(math, "__all__") + + @pytest.mark.parametrize("name", EXPECTED_EXPORTS) + def test_export_importable(self, name: str) -> None: + from spinstep import math + + assert name in math.__all__, f"{name} missing from math.__all__" + assert hasattr(math, name), f"{name} not accessible on spinstep.math" + + +class TestControlSubpackageExports: + """Verify spinstep.control.__all__ contains expected symbols.""" + + EXPECTED_EXPORTS = [ + "OrientationState", + "ControlCommand", + "integrate_state", + "compute_orientation_error", + "OrientationController", + "ProportionalOrientationController", + "PIDOrientationController", + "OrientationTrajectory", + "TrajectoryInterpolator", + "TrajectoryController", + ] + + def test_all_defined(self) -> None: + from spinstep import control + + assert hasattr(control, "__all__") + + @pytest.mark.parametrize("name", EXPECTED_EXPORTS) + def test_export_importable(self, name: str) -> None: + from spinstep import control + + assert name in control.__all__, f"{name} missing from control.__all__" + assert hasattr(control, name), f"{name} not accessible on spinstep.control" + + +class TestTraversalSubpackageExports: + """Verify spinstep.traversal.__all__ contains expected symbols.""" + + EXPECTED_EXPORTS = [ + "Node", + "QuaternionDepthIterator", + "DiscreteOrientationSet", + "DiscreteQuaternionIterator", + ] + + def test_all_defined(self) -> None: + from spinstep import traversal + + assert hasattr(traversal, "__all__") + + @pytest.mark.parametrize("name", EXPECTED_EXPORTS) + def test_export_importable(self, name: str) -> None: + from spinstep import traversal + + assert name in traversal.__all__, f"{name} missing from traversal.__all__" + assert hasattr( + traversal, name + ), f"{name} not accessible on spinstep.traversal" + + +class TestUtilsBackwardCompat: + """Verify that utils/ re-exports still work for backward compatibility.""" + + COMPAT_NAMES = [ + "get_array_module", + "batch_quaternion_angle", + "quaternion_from_euler", + "quaternion_distance", + "rotate_quaternion", + "is_within_angle_threshold", + "quaternion_conjugate", + "quaternion_multiply", + "rotation_matrix_to_quaternion", + "get_relative_spin", + "get_unique_relative_spins", + "forward_vector_from_quaternion", + "direction_to_quaternion", + "angle_between_directions", + ] + + @pytest.mark.parametrize("name", COMPAT_NAMES) + def test_utils_reexport(self, name: str) -> None: + """Functions remain importable from spinstep.utils.""" + from spinstep import utils + + assert hasattr(utils, name), f"{name} not accessible on spinstep.utils" + + def test_utils_quaternion_functions_are_math_functions(self) -> None: + """Verify that utils re-exports point to the same objects as math.""" + from spinstep import math as sp_math + from spinstep import utils as sp_utils + + shared = [ + "quaternion_multiply", + "quaternion_conjugate", + "quaternion_distance", + "quaternion_from_euler", + "rotation_matrix_to_quaternion", + "rotate_quaternion", + "is_within_angle_threshold", + "forward_vector_from_quaternion", + "direction_to_quaternion", + "angle_between_directions", + "batch_quaternion_angle", + "get_relative_spin", + "get_unique_relative_spins", + ] + for name in shared: + assert getattr(sp_utils, name) is getattr(sp_math, name), ( + f"utils.{name} is not the same object as math.{name}" + ) + + +class TestSubpackagesImportable: + """Verify all subpackages can be imported.""" + + @pytest.mark.parametrize( + "module", + [ + "spinstep", + "spinstep.math", + "spinstep.math.core", + "spinstep.math.geometry", + "spinstep.math.conversions", + "spinstep.math.interpolation", + "spinstep.math.analysis", + "spinstep.math.constraints", + "spinstep.control", + "spinstep.control.state", + "spinstep.control.controllers", + "spinstep.control.trajectory", + "spinstep.traversal", + "spinstep.traversal.node", + "spinstep.traversal.continuous", + "spinstep.traversal.discrete", + "spinstep.traversal.discrete_iterator", + "spinstep.utils", + "spinstep.utils.array_backend", + "spinstep.utils.quaternion_math", + "spinstep.utils.quaternion_utils", + ], + ) + def test_importable(self, module: str) -> None: + """Every listed module can be imported without error.""" + importlib.import_module(module) + + +class TestNodeProtocol: + """Verify NodeProtocol structural typing works correctly.""" + + def test_node_satisfies_protocol(self) -> None: + """spinstep.traversal.Node satisfies NodeProtocol.""" + from spinstep.math.analysis import NodeProtocol + from spinstep.traversal.node import Node + + node = Node("test", [0, 0, 0, 1]) + assert isinstance(node, NodeProtocol) + + def test_custom_class_satisfies_protocol(self) -> None: + """Any class with .orientation satisfies NodeProtocol.""" + import numpy as np + + from spinstep.math.analysis import NodeProtocol + + class MyNode: + def __init__(self) -> None: + self.orientation = np.array([0.0, 0.0, 0.0, 1.0]) + + assert isinstance(MyNode(), NodeProtocol) + + +class TestNodeAddChild: + """Verify Node.add_child() convenience method.""" + + def test_add_child_appends(self) -> None: + from spinstep.traversal.node import Node + + root = Node("root", [0, 0, 0, 1]) + child = Node("child", [1, 0, 0, 0]) + result = root.add_child(child) + assert child in root.children + assert result is child + + def test_add_child_returns_child(self) -> None: + from spinstep.traversal.node import Node + + root = Node("root", [0, 0, 0, 1]) + child = root.add_child(Node("child", [0, 1, 0, 0])) + assert child.name == "child" + assert len(root.children) == 1 diff --git a/tests/test_control.py b/tests/test_control.py new file mode 100644 index 0000000..123e866 --- /dev/null +++ b/tests/test_control.py @@ -0,0 +1,515 @@ +# test_control.py — SpinStep Test Suite — MIT License +# Tests for spinstep.control subpackage (state, controllers, trajectory) + +import numpy as np +import pytest +from scipy.spatial.transform import Rotation as R + +from spinstep.control.state import ( + ControlCommand, + OrientationState, + compute_orientation_error, + integrate_state, +) +from spinstep.control.controllers import ( + PIDOrientationController, + ProportionalOrientationController, +) +from spinstep.control.trajectory import ( + OrientationTrajectory, + TrajectoryController, + TrajectoryInterpolator, +) + + +# ===== OrientationState ===== + + +class TestOrientationState: + def test_defaults(self): + state = OrientationState() + assert np.allclose(state.quaternion, [0, 0, 0, 1]) + assert state.distance == 0.0 + assert np.allclose(state.angular_velocity, [0, 0, 0]) + assert state.radial_velocity == 0.0 + assert state.timestamp == 0.0 + + def test_with_distance(self): + state = OrientationState([0, 0, 0, 1], distance=5.0) + assert state.distance == 5.0 + + def test_normalizes_quaternion(self): + state = OrientationState([0, 0, 0, 2]) + assert np.allclose(state.quaternion, [0, 0, 0, 1]) + + def test_full_state(self): + state = OrientationState( + [0, 0, 0, 1], + distance=10.0, + angular_velocity=[0, 0, 1.0], + radial_velocity=0.5, + timestamp=1.0, + ) + assert state.distance == 10.0 + assert state.radial_velocity == 0.5 + assert state.timestamp == 1.0 + + def test_invalid_quaternion_shape(self): + with pytest.raises(ValueError, match="shape"): + OrientationState([1, 0, 0]) + + def test_zero_quaternion(self): + with pytest.raises(ValueError, match="non-zero"): + OrientationState([0, 0, 0, 0]) + + def test_negative_distance(self): + with pytest.raises(ValueError, match="non-negative"): + OrientationState([0, 0, 0, 1], distance=-1.0) + + def test_repr(self): + state = OrientationState([0, 0, 0, 1], distance=5.0) + r = repr(state) + assert "OrientationState" in r + assert "d=5.0" in r + + +# ===== ControlCommand ===== + + +class TestControlCommand: + def test_defaults(self): + cmd = ControlCommand() + assert np.allclose(cmd.angular_velocity, [0, 0, 0]) + assert cmd.radial_velocity == 0.0 + + def test_custom(self): + cmd = ControlCommand([1, 0, 0], radial_velocity=2.5) + assert np.allclose(cmd.angular_velocity, [1, 0, 0]) + assert cmd.radial_velocity == 2.5 + + def test_repr(self): + cmd = ControlCommand([0, 0, 1.0], radial_velocity=0.5) + r = repr(cmd) + assert "ControlCommand" in r + + +# ===== integrate_state ===== + + +class TestIntegrateState: + def test_stationary(self): + state = OrientationState([0, 0, 0, 1], distance=5.0) + new = integrate_state(state, dt=0.1) + assert np.allclose(new.quaternion, [0, 0, 0, 1]) + assert new.distance == 5.0 + assert new.timestamp == pytest.approx(0.1) + + def test_radial_integration(self): + state = OrientationState( + [0, 0, 0, 1], distance=5.0, radial_velocity=2.0 + ) + new = integrate_state(state, dt=1.0) + assert new.distance == pytest.approx(7.0) + + def test_radial_clamp_to_zero(self): + """Distance cannot go negative.""" + state = OrientationState( + [0, 0, 0, 1], distance=1.0, radial_velocity=-5.0 + ) + new = integrate_state(state, dt=1.0) + assert new.distance == 0.0 + + def test_angular_integration(self): + """Rotating about Z axis.""" + omega = np.pi # rad/s → 180° per second + state = OrientationState( + [0, 0, 0, 1], angular_velocity=[0, 0, omega] + ) + new = integrate_state(state, dt=0.5) + # After 0.5s at π rad/s → 90° rotation about Z + angle = R.from_quat(new.quaternion).magnitude() + assert angle == pytest.approx(np.pi / 2, abs=0.01) + + def test_combined_integration(self): + state = OrientationState( + [0, 0, 0, 1], + distance=5.0, + angular_velocity=[0, 0, 1.0], + radial_velocity=0.5, + timestamp=1.0, + ) + new = integrate_state(state, dt=0.1) + assert new.timestamp == pytest.approx(1.1) + assert new.distance == pytest.approx(5.05) + assert not np.allclose(new.quaternion, [0, 0, 0, 1]) + + def test_invalid_dt(self): + state = OrientationState() + with pytest.raises(ValueError): + integrate_state(state, dt=0) + with pytest.raises(ValueError): + integrate_state(state, dt=-0.1) + + +# ===== compute_orientation_error ===== + + +class TestComputeOrientationError: + def test_no_error(self): + ang, rad = compute_orientation_error( + [0, 0, 0, 1], [0, 0, 0, 1], 5.0, 5.0 + ) + assert np.allclose(ang, [0, 0, 0], atol=1e-6) + assert rad == pytest.approx(0.0) + + def test_angular_error_only(self): + q_target = R.from_euler("z", 45, degrees=True).as_quat() + ang, rad = compute_orientation_error( + [0, 0, 0, 1], q_target, 5.0, 5.0 + ) + assert np.linalg.norm(ang) == pytest.approx(np.deg2rad(45), abs=0.01) + assert rad == pytest.approx(0.0) + + def test_radial_error_only(self): + ang, rad = compute_orientation_error( + [0, 0, 0, 1], [0, 0, 0, 1], 3.0, 7.0 + ) + assert np.allclose(ang, [0, 0, 0], atol=1e-6) + assert rad == pytest.approx(4.0) + + def test_combined_error(self): + q_target = R.from_euler("z", 90, degrees=True).as_quat() + ang, rad = compute_orientation_error( + [0, 0, 0, 1], q_target, 2.0, 8.0 + ) + assert np.linalg.norm(ang) == pytest.approx(np.pi / 2, abs=0.01) + assert rad == pytest.approx(6.0) + + def test_backward_compat_no_distance(self): + """Works without distance arguments (defaults to 0).""" + ang, rad = compute_orientation_error([0, 0, 0, 1], [0, 0, 0, 1]) + assert np.allclose(ang, [0, 0, 0], atol=1e-6) + assert rad == pytest.approx(0.0) + + +# ===== ProportionalOrientationController ===== + + +class TestProportionalController: + def test_zero_error(self): + ctrl = ProportionalOrientationController(kp=2.0) + cmd = ctrl.update([0, 0, 0, 1], [0, 0, 0, 1], dt=0.01) + assert np.allclose(cmd.angular_velocity, [0, 0, 0], atol=1e-6) + assert cmd.radial_velocity == pytest.approx(0.0, abs=1e-6) + + def test_angular_error(self): + ctrl = ProportionalOrientationController(kp=1.0) + q_target = R.from_euler("z", 45, degrees=True).as_quat() + cmd = ctrl.update([0, 0, 0, 1], q_target, dt=0.01) + assert np.linalg.norm(cmd.angular_velocity) > 0 + + def test_radial_error(self): + ctrl = ProportionalOrientationController(kp=1.0, kp_radial=2.0) + cmd = ctrl.update( + [0, 0, 0, 1], [0, 0, 0, 1], dt=0.01, + current_distance=3.0, target_distance=5.0, + ) + assert cmd.radial_velocity == pytest.approx(4.0) # 2.0 × 2.0 + + def test_velocity_limit(self): + ctrl = ProportionalOrientationController( + kp=100.0, max_angular_velocity=1.0 + ) + q_target = R.from_euler("z", 90, degrees=True).as_quat() + cmd = ctrl.update([0, 0, 0, 1], q_target, dt=0.01) + assert np.linalg.norm(cmd.angular_velocity) <= 1.0 + 1e-6 + + def test_radial_velocity_limit(self): + ctrl = ProportionalOrientationController( + kp_radial=100.0, max_radial_velocity=2.0 + ) + cmd = ctrl.update( + [0, 0, 0, 1], [0, 0, 0, 1], dt=0.01, + current_distance=0.0, target_distance=10.0, + ) + assert abs(cmd.radial_velocity) <= 2.0 + 1e-6 + + def test_invalid_dt(self): + ctrl = ProportionalOrientationController() + with pytest.raises(ValueError): + ctrl.update([0, 0, 0, 1], [0, 0, 0, 1], dt=0) + + def test_reset(self): + ctrl = ProportionalOrientationController(kp=1.0) + ctrl.update([0, 0, 0, 1], [0, 0, 0, 1], dt=0.01) + ctrl.reset() + assert ctrl._prev_angular_cmd is None + + +# ===== PIDOrientationController ===== + + +class TestPIDController: + def test_zero_error(self): + ctrl = PIDOrientationController(kp=1.0, ki=0.1, kd=0.5) + cmd = ctrl.update([0, 0, 0, 1], [0, 0, 0, 1], dt=0.01) + assert np.allclose(cmd.angular_velocity, [0, 0, 0], atol=1e-6) + + def test_integral_accumulates(self): + ctrl = PIDOrientationController(kp=0.0, ki=1.0, kd=0.0) + q_target = R.from_euler("z", 10, degrees=True).as_quat() + # First step — integral starts accumulating + ctrl.update([0, 0, 0, 1], q_target, dt=0.1) + # Second step — integral grows + cmd = ctrl.update([0, 0, 0, 1], q_target, dt=0.1) + assert np.linalg.norm(cmd.angular_velocity) > 0 + + def test_derivative_term(self): + ctrl = PIDOrientationController(kp=0.0, ki=0.0, kd=1.0) + q1 = R.from_euler("z", 10, degrees=True).as_quat() + q2 = R.from_euler("z", 20, degrees=True).as_quat() + # First call sets _prev_error + ctrl.update([0, 0, 0, 1], q1, dt=0.01) + # Second call with different target should produce d_term + cmd = ctrl.update([0, 0, 0, 1], q2, dt=0.01) + assert np.linalg.norm(cmd.angular_velocity) > 0 + + def test_radial_pid(self): + ctrl = PIDOrientationController( + kp=0.0, ki=0.0, kd=0.0, + kp_radial=2.0, ki_radial=0.5, kd_radial=0.0, + ) + cmd = ctrl.update( + [0, 0, 0, 1], [0, 0, 0, 1], dt=0.1, + current_distance=3.0, target_distance=5.0, + ) + assert cmd.radial_velocity == pytest.approx(2.0 * 2.0 + 0.5 * 2.0 * 0.1) + + def test_reset_clears_state(self): + ctrl = PIDOrientationController(kp=1.0, ki=1.0, kd=1.0) + ctrl.update([0, 0, 0, 1], [0, 0, 0, 1], dt=0.01) + ctrl.reset() + assert np.allclose(ctrl._ang_integral, [0, 0, 0]) + assert ctrl._rad_integral == 0.0 + assert ctrl._prev_ang_error is None + assert ctrl._prev_rad_error is None + + def test_anti_windup(self): + """Integral should be clamped.""" + ctrl = PIDOrientationController(kp=0.0, ki=1.0, max_integral=0.01) + q_target = R.from_euler("z", 90, degrees=True).as_quat() + for _ in range(100): + ctrl.update([0, 0, 0, 1], q_target, dt=1.0) + assert np.linalg.norm(ctrl._ang_integral) <= 0.01 + 1e-6 + + +# ===== OrientationTrajectory ===== + + +class TestOrientationTrajectory: + def test_basic_3_tuple(self): + traj = OrientationTrajectory([ + ([0, 0, 0, 1], 5.0, 0.0), + ([0, 0, 0.383, 0.924], 10.0, 1.0), + ]) + assert len(traj) == 2 + assert traj.duration == pytest.approx(1.0) + assert np.allclose(traj.distances, [5.0, 10.0]) + + def test_basic_2_tuple_backward_compat(self): + traj = OrientationTrajectory([ + ([0, 0, 0, 1], 0.0), + ([0, 0, 0.383, 0.924], 1.0), + ]) + assert len(traj) == 2 + assert np.allclose(traj.distances, [0.0, 0.0]) + + def test_mixed_tuple_lengths(self): + traj = OrientationTrajectory([ + ([0, 0, 0, 1], 5.0, 0.0), + ([0, 0, 0.383, 0.924], 1.0), # 2-tuple, distance=0 + ]) + assert traj.distances[0] == 5.0 + assert traj.distances[1] == 0.0 + + def test_too_few_waypoints(self): + with pytest.raises(ValueError, match="At least 2"): + OrientationTrajectory([([0, 0, 0, 1], 0.0)]) + + def test_non_increasing_times(self): + with pytest.raises(ValueError, match="strictly increasing"): + OrientationTrajectory([ + ([0, 0, 0, 1], 5.0, 1.0), + ([0, 0, 0, 1], 5.0, 0.5), + ]) + + def test_properties(self): + traj = OrientationTrajectory([ + ([0, 0, 0, 1], 5.0, 1.0), + ([0, 0, 0, 1], 10.0, 3.0), + ]) + assert traj.start_time == 1.0 + assert traj.end_time == 3.0 + assert traj.duration == 2.0 + + def test_repr(self): + traj = OrientationTrajectory([ + ([0, 0, 0, 1], 5.0, 0.0), + ([0, 0, 0, 1], 10.0, 1.0), + ]) + assert "OrientationTrajectory" in repr(traj) + + +# ===== TrajectoryInterpolator ===== + + +class TestTrajectoryInterpolator: + def test_at_waypoints(self): + q0 = [0, 0, 0, 1] + q1 = R.from_euler("z", 90, degrees=True).as_quat().tolist() + traj = OrientationTrajectory([ + (q0, 5.0, 0.0), + (q1, 10.0, 1.0), + ]) + interp = TrajectoryInterpolator(traj) + q_start, d_start = interp.evaluate(0.0) + q_end, d_end = interp.evaluate(1.0) + assert np.allclose(q_start, q0, atol=1e-6) + assert d_start == pytest.approx(5.0) + assert d_end == pytest.approx(10.0) + + def test_midpoint_distance(self): + traj = OrientationTrajectory([ + ([0, 0, 0, 1], 4.0, 0.0), + ([0, 0, 0, 1], 8.0, 1.0), + ]) + interp = TrajectoryInterpolator(traj) + _, d_mid = interp.evaluate(0.5) + assert d_mid == pytest.approx(6.0) + + def test_before_start(self): + traj = OrientationTrajectory([ + ([0, 0, 0, 1], 5.0, 1.0), + ([0, 0, 0, 1], 10.0, 2.0), + ]) + interp = TrajectoryInterpolator(traj) + q, d = interp.evaluate(0.0) + assert d == pytest.approx(5.0) + + def test_after_end(self): + traj = OrientationTrajectory([ + ([0, 0, 0, 1], 5.0, 0.0), + ([0, 0, 0, 1], 10.0, 1.0), + ]) + interp = TrajectoryInterpolator(traj) + q, d = interp.evaluate(99.0) + assert d == pytest.approx(10.0) + + +# ===== TrajectoryController ===== + + +class TestTrajectoryController: + def test_basic_tracking(self): + q0 = [0, 0, 0, 1] + q1 = R.from_euler("z", 45, degrees=True).as_quat().tolist() + traj = OrientationTrajectory([ + (q0, 5.0, 0.0), + (q1, 10.0, 1.0), + ]) + ctrl = ProportionalOrientationController(kp=1.0, kp_radial=1.0) + tc = TrajectoryController(ctrl, traj) + + cmd = tc.update(q0, t=0.5, dt=0.01, current_distance=7.0) + # At t=0.5 target distance ≈ 7.5, so radial_velocity > 0 + assert cmd.radial_velocity > 0 + assert not tc.is_complete + + def test_complete_flag(self): + traj = OrientationTrajectory([ + ([0, 0, 0, 1], 5.0, 0.0), + ([0, 0, 0, 1], 10.0, 1.0), + ]) + ctrl = ProportionalOrientationController() + tc = TrajectoryController(ctrl, traj) + tc.update([0, 0, 0, 1], t=1.5, dt=0.01, current_distance=10.0) + assert tc.is_complete + + def test_reset(self): + traj = OrientationTrajectory([ + ([0, 0, 0, 1], 5.0, 0.0), + ([0, 0, 0, 1], 10.0, 1.0), + ]) + ctrl = ProportionalOrientationController() + tc = TrajectoryController(ctrl, traj) + tc.update([0, 0, 0, 1], t=2.0, dt=0.01) + assert tc.is_complete + tc.reset() + assert not tc.is_complete + + +# ===== Integration: full control loop ===== + + +class TestControlLoop: + def test_converges_to_target(self): + """P controller should drive orientation toward target over many steps.""" + ctrl = ProportionalOrientationController(kp=5.0, kp_radial=3.0) + current_q = np.array([0, 0, 0, 1.0]) + target_q = R.from_euler("z", 45, degrees=True).as_quat() + current_dist = 5.0 + target_dist = 10.0 + dt = 0.01 + + for _ in range(500): + cmd = ctrl.update( + current_q, target_q, dt, + current_distance=current_dist, target_distance=target_dist, + ) + state = OrientationState( + current_q, distance=current_dist, + angular_velocity=cmd.angular_velocity, + radial_velocity=cmd.radial_velocity, + ) + new_state = integrate_state(state, dt) + current_q = new_state.quaternion + current_dist = new_state.distance + + # Should be close to target + ang_err, rad_err = compute_orientation_error( + current_q, target_q, current_dist, target_dist + ) + assert np.linalg.norm(ang_err) < 0.01 + assert abs(rad_err) < 0.1 + + def test_trajectory_tracking_loop(self): + """TrajectoryController should follow a trajectory.""" + q0 = [0, 0, 0, 1] + q1 = R.from_euler("z", 90, degrees=True).as_quat().tolist() + traj = OrientationTrajectory([ + (q0, 5.0, 0.0), + (q1, 10.0, 2.0), + ]) + ctrl = ProportionalOrientationController(kp=5.0, kp_radial=3.0) + tc = TrajectoryController(ctrl, traj) + + current_q = np.array(q0, dtype=float) + current_dist = 5.0 + dt = 0.01 + t = 0.0 + + for _ in range(200): + cmd = tc.update(current_q, t=t, dt=dt, current_distance=current_dist) + state = OrientationState( + current_q, distance=current_dist, + angular_velocity=cmd.angular_velocity, + radial_velocity=cmd.radial_velocity, + ) + new_state = integrate_state(state, dt) + current_q = new_state.quaternion + current_dist = new_state.distance + t += dt + + # At t=2.0 should be near q1 / distance=10 + # At t=2.0 we're only at t=2.0 so check trajectory makes progress + assert current_dist > 5.0 # moved outward diff --git a/tests/test_discrete_traversal.py b/tests/test_discrete_traversal.py index f55df45..9cd9872 100644 --- a/tests/test_discrete_traversal.py +++ b/tests/test_discrete_traversal.py @@ -11,11 +11,10 @@ cuda_available = False import numpy as np -from scipy.spatial.transform import Rotation as R # Import the modules under test -from spinstep.discrete import DiscreteOrientationSet -from spinstep.discrete_iterator import DiscreteQuaternionIterator +from spinstep.traversal.discrete import DiscreteOrientationSet +from spinstep.traversal.discrete_iterator import DiscreteQuaternionIterator # Simple node class for testing class Node: @@ -106,7 +105,7 @@ def test_cuda_support(self): try: cuda_set = DiscreteOrientationSet([[0, 0, 0, 1]], use_cuda=True) # If this succeeds, basic CUDA import worked - assert cuda_set.use_cuda == True + assert cuda_set.use_cuda is True # Test as_numpy() method for GPU->CPU transfer cpu_array = cuda_set.as_numpy() @@ -261,8 +260,8 @@ def test_full_pipeline(): # Create a simple tree root = Node("root", [0, 0, 0, 1]) - node_a = root.add_child(Node("A", [1, 0, 0, 0])) - node_b = root.add_child(Node("B", [0, 1, 0, 0])) + root.add_child(Node("A", [1, 0, 0, 0])) + root.add_child(Node("B", [0, 1, 0, 0])) # Create iterator and traverse iterator = DiscreteQuaternionIterator( diff --git a/tests/test_math.py b/tests/test_math.py new file mode 100644 index 0000000..7a611dc --- /dev/null +++ b/tests/test_math.py @@ -0,0 +1,191 @@ +# test_math.py — SpinStep Test Suite — MIT License +# Tests for spinstep.math subpackage + +import numpy as np +import pytest +from scipy.spatial.transform import Rotation as R + +from spinstep.math.core import ( + quaternion_inverse, + quaternion_multiply, + quaternion_normalize, +) +from spinstep.math.interpolation import slerp, squad +from spinstep.math.geometry import quaternion_distance +from spinstep.math.conversions import ( + quaternion_from_rotvec, + quaternion_to_rotvec, +) +from spinstep.math.analysis import angular_velocity_from_quaternions +from spinstep.math.constraints import clamp_rotation_angle + + +# ===== core ===== + + +class TestQuaternionNormalize: + def test_unit_quaternion(self): + q = quaternion_normalize([0.5, 0.5, 0.5, 0.5]) + assert np.linalg.norm(q) == pytest.approx(1.0, abs=1e-10) + + def test_zero_quaternion(self): + q = quaternion_normalize([0, 0, 0, 0]) + assert np.allclose(q, [0, 0, 0, 1]) + + def test_non_unit(self): + q = quaternion_normalize([0, 0, 0, 2]) + assert np.allclose(q, [0, 0, 0, 1]) + + +class TestQuaternionInverse: + def test_unit_quaternion_inverse(self): + q = np.array([0.5, 0.5, 0.5, 0.5]) + inv = quaternion_inverse(q) + product = quaternion_multiply(q, inv) + # Should be close to identity [0, 0, 0, 1] + assert np.allclose(product, [0, 0, 0, 1], atol=1e-10) or np.allclose( + product, [0, 0, 0, -1], atol=1e-10 + ) + + def test_zero_quaternion(self): + inv = quaternion_inverse([0, 0, 0, 0]) + assert np.allclose(inv, [0, 0, 0, 1]) + + +# ===== interpolation ===== + + +class TestSlerp: + def test_endpoints(self): + q0 = np.array([0, 0, 0, 1.0]) + q1 = np.array([0, 0, np.sin(np.pi / 4), np.cos(np.pi / 4)]) + assert np.allclose(slerp(q0, q1, 0.0), q0, atol=1e-6) + assert np.allclose(slerp(q0, q1, 1.0), q1 / np.linalg.norm(q1), atol=1e-6) + + def test_midpoint_unit(self): + q0 = np.array([0, 0, 0, 1.0]) + q1 = np.array([0, 0, np.sin(np.pi / 4), np.cos(np.pi / 4)]) + mid = slerp(q0, q1, 0.5) + assert np.linalg.norm(mid) == pytest.approx(1.0, abs=1e-10) + + def test_same_quaternion(self): + q = np.array([0, 0, 0, 1.0]) + result = slerp(q, q, 0.5) + assert np.allclose(result, q, atol=1e-6) + + def test_shortest_path(self): + """SLERP takes the shortest arc even when quaternions are antipodal representatives.""" + q0 = np.array([0, 0, 0, 1.0]) + q1 = np.array([0, 0, 0, -1.0]) # same rotation, opposite sign + mid = slerp(q0, q1, 0.5) + assert np.linalg.norm(mid) == pytest.approx(1.0, abs=1e-6) + + def test_interpolation_angle(self): + """Midpoint of 0° and 90° should be ~45°.""" + q0 = np.array([0, 0, 0, 1.0]) + q1 = R.from_euler("z", 90, degrees=True).as_quat() + mid = slerp(q0, q1, 0.5) + angle = R.from_quat(mid).magnitude() + assert angle == pytest.approx(np.deg2rad(45), abs=0.01) + + +class TestSquad: + def test_returns_unit_quaternion(self): + q0 = R.from_euler("z", 0, degrees=True).as_quat() + q1 = R.from_euler("z", 30, degrees=True).as_quat() + q2 = R.from_euler("z", 60, degrees=True).as_quat() + q3 = R.from_euler("z", 90, degrees=True).as_quat() + result = squad(q0, q1, q2, q3, 0.5) + assert np.linalg.norm(result) == pytest.approx(1.0, abs=1e-6) + + def test_endpoints(self): + q0 = R.from_euler("z", 0, degrees=True).as_quat() + q1 = R.from_euler("z", 30, degrees=True).as_quat() + q2 = R.from_euler("z", 60, degrees=True).as_quat() + q3 = R.from_euler("z", 90, degrees=True).as_quat() + start = squad(q0, q1, q2, q3, 0.0) + end = squad(q0, q1, q2, q3, 1.0) + # At t=0 should be close to q1, at t=1 close to q2 + angle_start = quaternion_distance(start, q1) + angle_end = quaternion_distance(end, q2) + assert angle_start < 0.1 + assert angle_end < 0.1 + + +# ===== analysis ===== + + +class TestAngularVelocityFromQuaternions: + def test_no_rotation(self): + q = [0, 0, 0, 1] + omega = angular_velocity_from_quaternions(q, q, dt=0.1) + assert np.allclose(omega, [0, 0, 0], atol=1e-6) + + def test_known_rotation(self): + """90° about Z in 1 second → ω ≈ [0, 0, π/2].""" + q1 = [0, 0, 0, 1] + q2 = R.from_euler("z", 90, degrees=True).as_quat() + omega = angular_velocity_from_quaternions(q1, q2, dt=1.0) + assert np.allclose(omega, [0, 0, np.pi / 2], atol=1e-6) + + def test_invalid_dt(self): + with pytest.raises(ValueError): + angular_velocity_from_quaternions([0, 0, 0, 1], [0, 0, 0, 1], dt=0) + + +# ===== constraints ===== + + +class TestClampRotationAngle: + def test_within_limit(self): + """Small rotation should not be changed.""" + q = R.from_euler("z", 10, degrees=True).as_quat() + clamped = clamp_rotation_angle(q, max_angle=np.pi) + assert np.allclose(clamped, q, atol=1e-6) + + def test_clamped(self): + """Large rotation should be clamped to max_angle.""" + q = R.from_euler("z", 90, degrees=True).as_quat() + max_angle = np.deg2rad(45) + clamped = clamp_rotation_angle(q, max_angle=max_angle) + angle = R.from_quat(clamped).magnitude() + assert angle == pytest.approx(max_angle, abs=1e-6) + + def test_preserves_axis(self): + """Clamping preserves the rotation axis.""" + q = R.from_euler("z", 90, degrees=True).as_quat() + max_angle = np.deg2rad(45) + clamped = clamp_rotation_angle(q, max_angle=max_angle) + original_axis = R.from_quat(q).as_rotvec() + clamped_axis = R.from_quat(clamped).as_rotvec() + # Axes should be parallel + original_dir = original_axis / np.linalg.norm(original_axis) + clamped_dir = clamped_axis / np.linalg.norm(clamped_axis) + assert np.allclose(original_dir, clamped_dir, atol=1e-6) + + def test_negative_max_angle_raises(self): + with pytest.raises(ValueError): + clamp_rotation_angle([0, 0, 0, 1], max_angle=-1.0) + + +# ===== conversions ===== + + +class TestQuaternionFromRotvec: + def test_identity(self): + q = quaternion_from_rotvec([0, 0, 0]) + assert np.allclose(q, [0, 0, 0, 1], atol=1e-6) + + def test_90_about_z(self): + rotvec = [0, 0, np.pi / 2] + q = quaternion_from_rotvec(rotvec) + expected = R.from_rotvec(rotvec).as_quat() + assert np.allclose(q, expected, atol=1e-6) + + +class TestQuaternionToRotvec: + def test_roundtrip(self): + rotvec = [0.3, -0.5, 0.7] + q = quaternion_from_rotvec(rotvec) + result = quaternion_to_rotvec(q) + assert np.allclose(result, rotvec, atol=1e-6) diff --git a/tests/test_spinstep.py b/tests/test_spinstep.py index ee5b828..f2e61b5 100644 --- a/tests/test_spinstep.py +++ b/tests/test_spinstep.py @@ -6,15 +6,15 @@ import numpy as np try: - import cupy as cp + import cupy as cp # noqa: F401 HAS_CUPY = True except ImportError: HAS_CUPY = False -from spinstep.discrete import DiscreteOrientationSet +from spinstep.traversal.discrete import DiscreteOrientationSet # If you have continuous traversal classes, import them here -# from spinstep.continuous import QuaternionDepthIterator -from spinstep.node import Node +# from spinstep.traversal.continuous import QuaternionDepthIterator +from spinstep.traversal.node import Node @pytest.fixture def simple_tree(): diff --git a/tests/test_traversal.py b/tests/test_traversal.py index db8053e..b8b42a5 100644 --- a/tests/test_traversal.py +++ b/tests/test_traversal.py @@ -4,10 +4,9 @@ import pytest import numpy as np -from scipy.spatial.transform import Rotation as R -from spinstep.node import Node -from spinstep.traversal import QuaternionDepthIterator +from spinstep.traversal.node import Node +from spinstep.traversal.continuous import QuaternionDepthIterator class TestQuaternionDepthIterator: diff --git a/tests/test_utils.py b/tests/test_utils.py index 6e6ce90..21effca 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -6,7 +6,7 @@ import pytest from scipy.spatial.transform import Rotation as R -from spinstep.node import Node +from spinstep.traversal.node import Node from spinstep.utils.array_backend import get_array_module from spinstep.utils.quaternion_math import batch_quaternion_angle from spinstep.utils.quaternion_utils import (