From 48cd71904a8c12067cba5f301bf5290a69dd401a Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 9 Apr 2026 15:44:04 +0000 Subject: [PATCH 1/5] Add initial math/, control/, and traversal/ subpackages (pre-spherical model) Agent-Logs-Url: https://github.com/VoxleOne/SpinStep/sessions/dc0f5515-1df0-43c9-8b5e-28e75d27e046 Co-authored-by: VoxleOne <119956342+VoxleOne@users.noreply.github.com> --- spinstep/control/__init__.py | 43 ++++++ spinstep/control/controllers.py | 245 ++++++++++++++++++++++++++++++++ spinstep/control/state.py | 170 ++++++++++++++++++++++ spinstep/control/trajectory.py | 240 +++++++++++++++++++++++++++++++ spinstep/math/__init__.py | 74 ++++++++++ spinstep/math/analysis.py | 141 ++++++++++++++++++ spinstep/math/constraints.py | 52 +++++++ spinstep/math/conversions.py | 102 +++++++++++++ spinstep/math/core.py | 90 ++++++++++++ spinstep/math/geometry.py | 125 ++++++++++++++++ spinstep/math/interpolation.py | 135 ++++++++++++++++++ spinstep/traversal/__init__.py | 25 ++++ 12 files changed, 1442 insertions(+) create mode 100644 spinstep/control/__init__.py create mode 100644 spinstep/control/controllers.py create mode 100644 spinstep/control/state.py create mode 100644 spinstep/control/trajectory.py create mode 100644 spinstep/math/__init__.py create mode 100644 spinstep/math/analysis.py create mode 100644 spinstep/math/constraints.py create mode 100644 spinstep/math/conversions.py create mode 100644 spinstep/math/core.py create mode 100644 spinstep/math/geometry.py create mode 100644 spinstep/math/interpolation.py create mode 100644 spinstep/traversal/__init__.py diff --git a/spinstep/control/__init__.py b/spinstep/control/__init__.py new file mode 100644 index 0000000..bdc4c54 --- /dev/null +++ b/spinstep/control/__init__.py @@ -0,0 +1,43 @@ +# control/__init__.py — MIT License +# Author: Eraldo B. Marques — Created: 2025-05-14 +# See LICENSE.txt for full terms. This header must be retained in redistributions. + +"""Orientation control: state, controllers, and trajectory tracking. + +Sub-modules: + +- :mod:`~.state` — :class:`OrientationState`, integration, error computation +- :mod:`~.controllers` — proportional and PID orientation controllers +- :mod:`~.trajectory` — waypoint trajectories and trajectory tracking +""" + +__all__ = [ + # state + "OrientationState", + "integrate_orientation", + "compute_orientation_error", + # controllers + "OrientationController", + "ProportionalOrientationController", + "PIDOrientationController", + # trajectory + "OrientationTrajectory", + "TrajectoryInterpolator", + "TrajectoryController", +] + +from .state import ( + OrientationState, + compute_orientation_error, + integrate_orientation, +) +from .controllers import ( + OrientationController, + PIDOrientationController, + ProportionalOrientationController, +) +from .trajectory import ( + OrientationTrajectory, + TrajectoryController, + TrajectoryInterpolator, +) diff --git a/spinstep/control/controllers.py b/spinstep/control/controllers.py new file mode 100644 index 0000000..501c071 --- /dev/null +++ b/spinstep/control/controllers.py @@ -0,0 +1,245 @@ +# control/controllers.py — MIT License +# Author: Eraldo B. Marques — Created: 2025-05-14 +# See LICENSE.txt for full terms. This header must be retained in redistributions. + +"""Orientation controllers: proportional and PID with rate limiting.""" + +from __future__ import annotations + +__all__ = [ + "OrientationController", + "ProportionalOrientationController", + "PIDOrientationController", +] + +from abc import ABC, abstractmethod +from typing import Optional + +import numpy as np +from numpy.typing import ArrayLike + +from .state import compute_orientation_error + + +class OrientationController(ABC): + """Abstract base class for orientation controllers. + + All controllers share the ``update`` interface, which takes the current + and target quaternions plus a time step and returns an angular velocity + command vector. + + Args: + max_angular_velocity: Maximum angular velocity magnitude + (rad/s). ``None`` means unlimited. + max_angular_acceleration: Maximum angular acceleration magnitude + (rad/s²). ``None`` means unlimited. + """ + + def __init__( + self, + max_angular_velocity: Optional[float] = None, + max_angular_acceleration: Optional[float] = None, + ) -> None: + self.max_angular_velocity = max_angular_velocity + self.max_angular_acceleration = max_angular_acceleration + self._prev_command: Optional[np.ndarray] = None + + @abstractmethod + def compute_raw_command( + self, current_q: ArrayLike, target_q: ArrayLike, dt: float + ) -> np.ndarray: + """Compute the raw (unclamped) angular velocity command. + + Subclasses must implement this. + + Args: + current_q: Current orientation ``[x, y, z, w]``. + target_q: Target orientation ``[x, y, z, w]``. + dt: Time step in seconds. + + Returns: + Raw angular velocity command ``(3,)`` in rad/s. + """ + ... + + def update( + self, current_q: ArrayLike, target_q: ArrayLike, dt: float + ) -> np.ndarray: + """Compute a rate-limited angular velocity command. + + Calls :meth:`compute_raw_command` and then applies + :attr:`max_angular_velocity` and :attr:`max_angular_acceleration` + limits. + + Args: + current_q: Current orientation ``[x, y, z, w]``. + target_q: Target orientation ``[x, y, z, w]``. + dt: Time step in seconds. Must be positive. + + Returns: + Angular velocity command ``(3,)`` in rad/s. + + Raises: + ValueError: If *dt* is not positive. + """ + if dt <= 0: + raise ValueError(f"dt must be positive, got {dt}") + + command = self.compute_raw_command(current_q, target_q, dt) + + # Apply velocity limit + if self.max_angular_velocity is not None: + speed = np.linalg.norm(command) + if speed > self.max_angular_velocity: + command = command * (self.max_angular_velocity / speed) + + # Apply acceleration limit + if self.max_angular_acceleration is not None and self._prev_command is not None: + delta = command - self._prev_command + accel = np.linalg.norm(delta) / dt + if accel > self.max_angular_acceleration: + max_delta = ( + delta / np.linalg.norm(delta) + * self.max_angular_acceleration + * dt + ) + command = self._prev_command + max_delta + + self._prev_command = command.copy() + return command + + def reset(self) -> None: + """Reset the controller's internal state.""" + self._prev_command = None + + +class ProportionalOrientationController(OrientationController): + """Proportional (P) orientation controller. + + Computes the angular velocity command as ``kp × error``, where the + error is the rotation vector from the current to the target orientation. + + Args: + kp: Proportional gain. Defaults to ``1.0``. + max_angular_velocity: Maximum angular velocity magnitude (rad/s). + max_angular_acceleration: Maximum angular acceleration (rad/s²). + + Example:: + + from spinstep.control import ProportionalOrientationController + + ctrl = ProportionalOrientationController(kp=2.0, max_angular_velocity=3.14) + cmd = ctrl.update([0, 0, 0, 1], [0, 0, 0.383, 0.924], dt=0.01) + """ + + def __init__( + self, + kp: float = 1.0, + max_angular_velocity: Optional[float] = None, + max_angular_acceleration: Optional[float] = None, + ) -> None: + super().__init__(max_angular_velocity, max_angular_acceleration) + self.kp = kp + + def compute_raw_command( + self, current_q: ArrayLike, target_q: ArrayLike, dt: float + ) -> np.ndarray: + """Compute ``kp × orientation_error``. + + Args: + current_q: Current orientation ``[x, y, z, w]``. + target_q: Target orientation ``[x, y, z, w]``. + dt: Time step in seconds (unused by P controller). + + Returns: + Angular velocity command ``(3,)`` in rad/s. + """ + error = compute_orientation_error(current_q, target_q) + return self.kp * error + + def reset(self) -> None: + """Reset the controller's internal state.""" + super().reset() + + +class PIDOrientationController(OrientationController): + """PID orientation controller with anti-windup. + + Computes the angular velocity command as + ``kp × error + ki × ∫error·dt + kd × d(error)/dt``. + Integral windup is prevented by clamping the integrated error magnitude + to *max_integral*. + + Args: + kp: Proportional gain. Defaults to ``1.0``. + ki: Integral gain. Defaults to ``0.0``. + kd: Derivative gain. Defaults to ``0.0``. + max_integral: Maximum magnitude of the integrated error vector. + Defaults to ``10.0``. + max_angular_velocity: Maximum angular velocity magnitude (rad/s). + max_angular_acceleration: Maximum angular acceleration (rad/s²). + + Example:: + + from spinstep.control import PIDOrientationController + + ctrl = PIDOrientationController(kp=2.0, ki=0.1, kd=0.5) + cmd = ctrl.update([0, 0, 0, 1], [0, 0, 0.383, 0.924], dt=0.01) + """ + + def __init__( + self, + kp: float = 1.0, + ki: float = 0.0, + kd: float = 0.0, + max_integral: float = 10.0, + max_angular_velocity: Optional[float] = None, + max_angular_acceleration: Optional[float] = None, + ) -> None: + super().__init__(max_angular_velocity, max_angular_acceleration) + self.kp = kp + self.ki = ki + self.kd = kd + self.max_integral = max_integral + self._integral: np.ndarray = np.zeros(3) + self._prev_error: Optional[np.ndarray] = None + + def compute_raw_command( + self, current_q: ArrayLike, target_q: ArrayLike, dt: float + ) -> np.ndarray: + """Compute the PID angular velocity command. + + Args: + current_q: Current orientation ``[x, y, z, w]``. + target_q: Target orientation ``[x, y, z, w]``. + dt: Time step in seconds. + + Returns: + Angular velocity command ``(3,)`` in rad/s. + """ + error = compute_orientation_error(current_q, target_q) + + # Proportional + p_term = self.kp * error + + # Integral with anti-windup + self._integral += error * dt + integral_mag = np.linalg.norm(self._integral) + if integral_mag > self.max_integral: + self._integral = self._integral * (self.max_integral / integral_mag) + i_term = self.ki * self._integral + + # Derivative + if self._prev_error is not None: + d_term = self.kd * (error - self._prev_error) / dt + else: + d_term = np.zeros(3) + self._prev_error = error.copy() + + return p_term + i_term + d_term + + def reset(self) -> None: + """Reset the controller's internal state (integral, derivative, etc.).""" + super().reset() + self._integral = np.zeros(3) + self._prev_error = None diff --git a/spinstep/control/state.py b/spinstep/control/state.py new file mode 100644 index 0000000..d6c1fdc --- /dev/null +++ b/spinstep/control/state.py @@ -0,0 +1,170 @@ +# control/state.py — MIT License +# Author: Eraldo B. Marques — Created: 2025-05-14 +# See LICENSE.txt for full terms. This header must be retained in redistributions. + +"""Orientation state model: dataclass, integration, and error computation.""" + +from __future__ import annotations + +__all__ = [ + "OrientationState", + "integrate_orientation", + "compute_orientation_error", +] + +from dataclasses import dataclass, field + +import numpy as np +from numpy.typing import ArrayLike +from scipy.spatial.transform import Rotation as R + +from ..math.core import quaternion_multiply, quaternion_normalize + + +@dataclass +class OrientationState: + """Immutable orientation state: pose, angular velocity, and timestamp. + + All quaternions use ``[x, y, z, w]`` convention. + + Args: + quaternion: Unit quaternion ``[x, y, z, w]`` representing the + current orientation. + angular_velocity: Angular velocity vector ``[ωx, ωy, ωz]`` in + radians per second. Defaults to zero. + timestamp: Time in seconds. Defaults to ``0.0``. + + Attributes: + quaternion: Normalised quaternion as a NumPy array of shape ``(4,)``. + angular_velocity: Angular velocity as a NumPy array of shape ``(3,)``. + timestamp: Timestamp in seconds. + + Example:: + + from spinstep.control import OrientationState + + state = OrientationState([0, 0, 0, 1]) + print(state.quaternion) # [0. 0. 0. 1.] + """ + + quaternion: np.ndarray = field(default_factory=lambda: np.array([0.0, 0.0, 0.0, 1.0])) + angular_velocity: np.ndarray = field(default_factory=lambda: np.zeros(3)) + timestamp: float = 0.0 + + def __init__( + self, + quaternion: ArrayLike = (0.0, 0.0, 0.0, 1.0), + angular_velocity: ArrayLike = (0.0, 0.0, 0.0), + timestamp: float = 0.0, + ) -> None: + q = np.asarray(quaternion, dtype=float) + if q.shape != (4,): + raise ValueError( + f"quaternion must have shape (4,), got {q.shape}" + ) + norm = np.linalg.norm(q) + if norm < 1e-8: + raise ValueError("quaternion must be non-zero") + self.quaternion = q / norm + + omega = np.asarray(angular_velocity, dtype=float) + if omega.shape != (3,): + raise ValueError( + f"angular_velocity must have shape (3,), got {omega.shape}" + ) + self.angular_velocity = omega + self.timestamp = float(timestamp) + + def __repr__(self) -> str: + return ( + f"OrientationState(" + f"q={self.quaternion.tolist()}, " + f"ω={self.angular_velocity.tolist()}, " + f"t={self.timestamp})" + ) + + +def integrate_orientation(state: OrientationState, dt: float) -> OrientationState: + """Integrate orientation forward by *dt* seconds using current angular velocity. + + Uses the exponential map: ``q(t+dt) = q(t) * exp(ω · dt / 2)``, + which is the standard first-order quaternion integration. + + Args: + state: Current orientation state. + dt: Time step in seconds. Must be positive. + + Returns: + New :class:`OrientationState` with updated quaternion and timestamp. + Angular velocity is carried forward unchanged. + + Raises: + ValueError: If *dt* is not positive. + + Example:: + + from spinstep.control import OrientationState, integrate_orientation + + state = OrientationState([0, 0, 0, 1], [0, 0, 1.0]) + new_state = integrate_orientation(state, dt=0.01) + """ + if dt <= 0: + raise ValueError(f"dt must be positive, got {dt}") + + omega = state.angular_velocity + angle = np.linalg.norm(omega) + + if angle < 1e-10: + # No rotation — return state with updated timestamp + return OrientationState( + quaternion=state.quaternion.copy(), + angular_velocity=state.angular_velocity.copy(), + timestamp=state.timestamp + dt, + ) + + # Compute the incremental rotation quaternion: exp(ω·dt/2) + half_angle = angle * dt / 2.0 + axis = omega / angle + delta_q = np.array([ + *(axis * np.sin(half_angle)), + np.cos(half_angle), + ]) + + new_q = quaternion_multiply(state.quaternion, delta_q) + new_q = quaternion_normalize(new_q) + + return OrientationState( + quaternion=new_q, + angular_velocity=state.angular_velocity.copy(), + timestamp=state.timestamp + dt, + ) + + +def compute_orientation_error( + current_q: ArrayLike, target_q: ArrayLike +) -> np.ndarray: + """Compute the orientation error as an axis-angle vector from current to target. + + The error is expressed in the body frame of *current_q*. Its direction is the + rotation axis and its magnitude is the rotation angle in radians. + + Args: + current_q: Current orientation quaternion ``[x, y, z, w]``. + target_q: Target orientation quaternion ``[x, y, z, w]``. + + Returns: + Error rotation vector ``(3,)`` in radians. Zero vector when + the orientations are identical. + + Example:: + + from spinstep.control import compute_orientation_error + + error = compute_orientation_error([0, 0, 0, 1], [0, 0, 0.383, 0.924]) + print(error) # approximately [0, 0, 0.785] + """ + r_current = R.from_quat(current_q) + r_target = R.from_quat(target_q) + # Error rotation in the body frame of current + r_error = r_current.inv() * r_target + return r_error.as_rotvec() diff --git a/spinstep/control/trajectory.py b/spinstep/control/trajectory.py new file mode 100644 index 0000000..9c75bf6 --- /dev/null +++ b/spinstep/control/trajectory.py @@ -0,0 +1,240 @@ +# control/trajectory.py — MIT License +# Author: Eraldo B. Marques — Created: 2025-05-14 +# See LICENSE.txt for full terms. This header must be retained in redistributions. + +"""Orientation trajectories: waypoints, interpolation, and tracking.""" + +from __future__ import annotations + +__all__ = [ + "OrientationTrajectory", + "TrajectoryInterpolator", + "TrajectoryController", +] + +from typing import List, Optional, Sequence, Tuple + +import numpy as np +from numpy.typing import ArrayLike + +from ..math.interpolation import slerp +from .controllers import OrientationController + + +class OrientationTrajectory: + """A sequence of quaternion waypoints with associated timestamps. + + Waypoints must be in ascending time order. + + Args: + waypoints: Sequence of ``(quaternion, time)`` pairs where each + quaternion is ``[x, y, z, w]`` and time is in seconds. + + Raises: + ValueError: If fewer than two waypoints are provided or times + are not strictly increasing. + + Attributes: + quaternions: Array of shape ``(N, 4)`` — waypoint quaternions. + times: Array of shape ``(N,)`` — waypoint times in seconds. + + Example:: + + from spinstep.control import OrientationTrajectory + + traj = OrientationTrajectory([ + ([0, 0, 0, 1], 0.0), + ([0, 0, 0.383, 0.924], 1.0), + ([0, 0, 0.707, 0.707], 2.0), + ]) + """ + + quaternions: np.ndarray + times: np.ndarray + + def __init__( + self, + waypoints: Sequence[Tuple[ArrayLike, float]], + ) -> None: + if len(waypoints) < 2: + raise ValueError( + f"At least 2 waypoints are required, got {len(waypoints)}" + ) + + quats: List[np.ndarray] = [] + times: List[float] = [] + for q, t in waypoints: + arr = np.asarray(q, dtype=float) + if arr.shape != (4,): + raise ValueError( + f"Each waypoint quaternion must have shape (4,), got {arr.shape}" + ) + norm = np.linalg.norm(arr) + if norm < 1e-8: + raise ValueError("Waypoint quaternion must be non-zero") + quats.append(arr / norm) + times.append(float(t)) + + for i in range(1, len(times)): + if times[i] <= times[i - 1]: + raise ValueError( + f"Waypoint times must be strictly increasing: " + f"t[{i-1}]={times[i-1]}, t[{i}]={times[i]}" + ) + + self.quaternions = np.array(quats) + self.times = np.array(times) + + @property + def duration(self) -> float: + """Total duration of the trajectory in seconds.""" + return float(self.times[-1] - self.times[0]) + + @property + def start_time(self) -> float: + """Start time of the trajectory.""" + return float(self.times[0]) + + @property + def end_time(self) -> float: + """End time of the trajectory.""" + return float(self.times[-1]) + + def __len__(self) -> int: + return len(self.times) + + def __repr__(self) -> str: + return ( + f"OrientationTrajectory({len(self)} waypoints, " + f"t=[{self.start_time}, {self.end_time}])" + ) + + +class TrajectoryInterpolator: + """SLERP-based interpolator for an :class:`OrientationTrajectory`. + + Evaluates the orientation at any time within the trajectory's time span + using spherical linear interpolation between adjacent waypoints. + + Args: + trajectory: The trajectory to interpolate. + + Example:: + + from spinstep.control import OrientationTrajectory, TrajectoryInterpolator + + traj = OrientationTrajectory([ + ([0, 0, 0, 1], 0.0), + ([0, 0, 0.383, 0.924], 1.0), + ]) + interp = TrajectoryInterpolator(traj) + q = interp.evaluate(0.5) + """ + + def __init__(self, trajectory: OrientationTrajectory) -> None: + self.trajectory = trajectory + + def evaluate(self, t: float) -> np.ndarray: + """Return the interpolated quaternion at time *t*. + + Times before the first waypoint return the first quaternion. + Times after the last waypoint return the last quaternion. + + Args: + t: Query time in seconds. + + Returns: + Interpolated unit quaternion ``[x, y, z, w]``. + """ + traj = self.trajectory + + if t <= traj.times[0]: + return traj.quaternions[0].copy() + if t >= traj.times[-1]: + return traj.quaternions[-1].copy() + + # Find the segment + idx = int(np.searchsorted(traj.times, t, side="right") - 1) + idx = min(idx, len(traj.times) - 2) + + t0 = traj.times[idx] + t1 = traj.times[idx + 1] + alpha = (t - t0) / (t1 - t0) + + return slerp(traj.quaternions[idx], traj.quaternions[idx + 1], alpha) + + @property + def duration(self) -> float: + """Total duration of the underlying trajectory.""" + return self.trajectory.duration + + +class TrajectoryController: + """Controller that tracks an orientation trajectory over time. + + Wraps a base :class:`OrientationController` and a + :class:`TrajectoryInterpolator`. At each time step the controller + queries the interpolator for the desired orientation and computes the + angular velocity command to drive the system towards it. + + Args: + controller: An :class:`OrientationController` instance (e.g. + :class:`ProportionalOrientationController` or + :class:`PIDOrientationController`). + trajectory: The trajectory to follow. + + Attributes: + interpolator: The :class:`TrajectoryInterpolator` used internally. + controller: The wrapped base controller. + is_complete: Whether the trajectory end time has been reached. + + Example:: + + from spinstep.control import ( + OrientationTrajectory, + ProportionalOrientationController, + TrajectoryController, + ) + + traj = OrientationTrajectory([ + ([0, 0, 0, 1], 0.0), + ([0, 0, 0.383, 0.924], 1.0), + ]) + ctrl = ProportionalOrientationController(kp=2.0) + traj_ctrl = TrajectoryController(ctrl, traj) + cmd = traj_ctrl.update([0, 0, 0, 1], t=0.5, dt=0.01) + """ + + def __init__( + self, + controller: OrientationController, + trajectory: OrientationTrajectory, + ) -> None: + self.controller = controller + self.interpolator = TrajectoryInterpolator(trajectory) + self.is_complete: bool = False + + def update( + self, + current_q: ArrayLike, + t: float, + dt: float, + ) -> np.ndarray: + """Compute angular velocity command to track the trajectory at time *t*. + + Args: + current_q: Current orientation ``[x, y, z, w]``. + t: Current time in seconds. + dt: Time step in seconds. + + Returns: + Angular velocity command ``(3,)`` in rad/s. + """ + target_q = self.interpolator.evaluate(t) + self.is_complete = t >= self.interpolator.trajectory.end_time + return self.controller.update(current_q, target_q, dt) + + def reset(self) -> None: + """Reset the controller state.""" + self.controller.reset() + self.is_complete = False diff --git a/spinstep/math/__init__.py b/spinstep/math/__init__.py new file mode 100644 index 0000000..a4b30f2 --- /dev/null +++ b/spinstep/math/__init__.py @@ -0,0 +1,74 @@ +# math/__init__.py — MIT License +# Author: Eraldo B. Marques — Created: 2025-05-14 +# See LICENSE.txt for full terms. This header must be retained in redistributions. + +"""Quaternion mathematics library. + +Sub-modules: + +- :mod:`~.core` — multiply, conjugate, normalize, inverse +- :mod:`~.interpolation` — slerp, squad +- :mod:`~.geometry` — distance, angle, direction conversions +- :mod:`~.conversions` — Euler ↔ quaternion, matrix ↔ quaternion +- :mod:`~.analysis` — batch distances, angular velocity, relative spins +- :mod:`~.constraints` — rotation angle clamping +""" + +__all__ = [ + # core + "quaternion_multiply", + "quaternion_conjugate", + "quaternion_normalize", + "quaternion_inverse", + # interpolation + "slerp", + "squad", + # geometry + "quaternion_distance", + "is_within_angle_threshold", + "forward_vector_from_quaternion", + "direction_to_quaternion", + "angle_between_directions", + "rotate_quaternion", + # conversions + "quaternion_from_euler", + "rotation_matrix_to_quaternion", + "quaternion_from_rotvec", + "quaternion_to_rotvec", + # analysis + "batch_quaternion_angle", + "angular_velocity_from_quaternions", + "get_relative_spin", + "get_unique_relative_spins", + # constraints + "clamp_rotation_angle", +] + +from .core import ( + quaternion_conjugate, + quaternion_inverse, + quaternion_multiply, + quaternion_normalize, +) +from .interpolation import slerp, squad +from .geometry import ( + angle_between_directions, + direction_to_quaternion, + forward_vector_from_quaternion, + is_within_angle_threshold, + quaternion_distance, + rotate_quaternion, +) +from .conversions import ( + quaternion_from_euler, + quaternion_from_rotvec, + quaternion_to_rotvec, + rotation_matrix_to_quaternion, +) +from .analysis import ( + angular_velocity_from_quaternions, + batch_quaternion_angle, + get_relative_spin, + get_unique_relative_spins, +) +from .constraints import clamp_rotation_angle diff --git a/spinstep/math/analysis.py b/spinstep/math/analysis.py new file mode 100644 index 0000000..cd573db --- /dev/null +++ b/spinstep/math/analysis.py @@ -0,0 +1,141 @@ +# math/analysis.py — MIT License +# Author: Eraldo B. Marques — Created: 2025-05-14 +# See LICENSE.txt for full terms. This header must be retained in redistributions. + +"""Quaternion analysis: batch distances, angular velocity, relative spins.""" + +from __future__ import annotations + +__all__ = [ + "batch_quaternion_angle", + "angular_velocity_from_quaternions", + "get_relative_spin", + "get_unique_relative_spins", +] + +from types import ModuleType +from typing import Any, List, Sequence + +import numpy as np +from numpy.typing import ArrayLike + +from .core import quaternion_conjugate, quaternion_multiply + + +def batch_quaternion_angle(qs1: Any, qs2: Any, xp: ModuleType) -> Any: + """Compute pairwise angular distances between two sets of quaternions. + + Args: + qs1: Array of shape ``(N, 4)`` — first set of quaternions. + qs2: Array of shape ``(M, 4)`` — second set of quaternions. + xp: Array module (:mod:`numpy` or :mod:`cupy`). + + Returns: + ``(N, M)`` array of angular distances in radians. + """ + dots = xp.abs(xp.dot(qs1, qs2.T)) + dots = xp.clip(dots, -1.0, 1.0) + angles = 2 * xp.arccos(dots) + return angles + + +def angular_velocity_from_quaternions( + q1: ArrayLike, q2: ArrayLike, dt: float +) -> np.ndarray: + """Estimate angular velocity from two quaternions separated by *dt* seconds. + + Computes the rotation from *q1* to *q2*, converts to a rotation vector + (axis × angle), and divides by *dt* to obtain angular velocity in rad/s. + + Args: + q1: Start quaternion ``[x, y, z, w]``. + q2: End quaternion ``[x, y, z, w]``. + dt: Time step in seconds. Must be positive. + + Returns: + Angular velocity vector ``(3,)`` in radians per second. + + Raises: + ValueError: If *dt* is not positive. + """ + if dt <= 0: + raise ValueError(f"dt must be positive, got {dt}") + from scipy.spatial.transform import Rotation as R + + r1 = R.from_quat(q1) + r2 = R.from_quat(q2) + delta = r1.inv() * r2 + rotvec = delta.as_rotvec() + return rotvec / dt + + +def get_relative_spin(nf: object, nt: object) -> np.ndarray: + """Return the relative quaternion rotation from node *nf* to node *nt*. + + Both nodes must have an ``.orientation`` attribute storing a quaternion + ``[x, y, z, w]``. + + Args: + nf: Source node with ``.orientation`` attribute. + nt: Target node with ``.orientation`` attribute. + + Returns: + Unit quaternion representing the relative rotation. + """ + qfc = quaternion_conjugate(nf.orientation) # type: ignore[union-attr] + qr = quaternion_multiply(qfc, nt.orientation) # type: ignore[union-attr] + n = np.linalg.norm(qr) + return qr / n if n > 1e-8 else np.array([0.0, 0.0, 0.0, 1.0]) + + +def get_unique_relative_spins( + nodes: Sequence[object], + nside: int, + nest: bool, + threshold: float = 1e-3, +) -> List[np.ndarray]: + """Compute unique relative rotations between HEALPix neighbours. + + Requires the ``healpy`` package. + + Args: + nodes: Sequence of node objects with ``.orientation`` attributes. + nside: HEALPix *nside* parameter. + nest: Whether to use the NESTED pixel ordering. + threshold: Angular threshold (radians) for considering two + rotations identical. + + Returns: + List of unique unit quaternions representing relative rotations. + + Raises: + ImportError: If ``healpy`` is not installed. + """ + try: + import healpy as hp + except ImportError: + raise ImportError( + "healpy is required for get_unique_relative_spins(). " + "Install it with: pip install healpy" + ) + spins: List[np.ndarray] = [] + NPIX = hp.nside2npix(nside) + for i in range(NPIX): + nf = nodes[i] + nidx = hp.get_all_neighbours(nside, i, nest=nest) + for idx in nidx: + if idx != -1: + q = get_relative_spin(nf, nodes[idx]) + if q[3] < 0: + q = -q # Canonical form (w >= 0) + is_uniq = True + for s_q in spins: + dot = np.abs(np.dot(q, s_q)) + dot = np.clip(dot, -1, 1) + angle = 2 * np.arccos(dot) + if angle < threshold: + is_uniq = False + break + if is_uniq: + spins.append(q) + return spins diff --git a/spinstep/math/constraints.py b/spinstep/math/constraints.py new file mode 100644 index 0000000..4147f97 --- /dev/null +++ b/spinstep/math/constraints.py @@ -0,0 +1,52 @@ +# math/constraints.py — MIT License +# Author: Eraldo B. Marques — Created: 2025-05-14 +# See LICENSE.txt for full terms. This header must be retained in redistributions. + +"""Quaternion constraints: clamping, limiting rotation magnitude.""" + +from __future__ import annotations + +__all__ = [ + "clamp_rotation_angle", +] + +import numpy as np +from numpy.typing import ArrayLike +from scipy.spatial.transform import Rotation as R + + +def clamp_rotation_angle(q: ArrayLike, max_angle: float) -> np.ndarray: + """Clamp a rotation quaternion so its angle does not exceed *max_angle*. + + If the rotation represented by *q* has an angle larger than *max_angle*, + the quaternion is scaled to represent exactly *max_angle* around the + same axis. + + Args: + q: Rotation quaternion ``[x, y, z, w]``. + max_angle: Maximum allowed rotation angle in radians. Must be + non-negative. + + Returns: + Clamped unit quaternion ``[x, y, z, w]``. + + Raises: + ValueError: If *max_angle* is negative. + """ + if max_angle < 0: + raise ValueError(f"max_angle must be non-negative, got {max_angle}") + + rot = R.from_quat(q) + angle = rot.magnitude() + + if angle <= max_angle: + return np.asarray(q, dtype=float) + + # Scale the rotation vector to max_angle + rotvec = rot.as_rotvec() + if angle < 1e-10: + return np.array([0.0, 0.0, 0.0, 1.0]) + + axis = rotvec / angle + clamped_rotvec = axis * max_angle + return R.from_rotvec(clamped_rotvec).as_quat() diff --git a/spinstep/math/conversions.py b/spinstep/math/conversions.py new file mode 100644 index 0000000..f1ea394 --- /dev/null +++ b/spinstep/math/conversions.py @@ -0,0 +1,102 @@ +# math/conversions.py — MIT License +# Author: Eraldo B. Marques — Created: 2025-05-14 +# See LICENSE.txt for full terms. This header must be retained in redistributions. + +"""Quaternion conversions: Euler, rotation matrix, rotation vector.""" + +from __future__ import annotations + +__all__ = [ + "quaternion_from_euler", + "rotation_matrix_to_quaternion", + "quaternion_from_rotvec", + "quaternion_to_rotvec", +] + +from typing import Sequence + +import numpy as np +from numpy.typing import ArrayLike +from scipy.spatial.transform import Rotation as R + + +def quaternion_from_euler( + angles: Sequence[float], + order: str = "zyx", + degrees: bool = True, +) -> np.ndarray: + """Convert Euler angles to a quaternion ``[x, y, z, w]``. + + Args: + angles: Euler angles as a 3-element sequence. + order: Rotation order string (e.g. ``"zyx"``). + degrees: If ``True``, angles are in degrees; otherwise radians. + + Returns: + Unit quaternion ``[x, y, z, w]``. + """ + return R.from_euler(order, angles, degrees=degrees).as_quat() + + +def rotation_matrix_to_quaternion(m: ArrayLike) -> np.ndarray: + """Convert a 3×3 rotation matrix to a unit quaternion ``[x, y, z, w]``. + + Args: + m: A 3×3 rotation matrix. + + Returns: + Unit quaternion ``[x, y, z, w]``. + """ + mat = np.asarray(m, dtype=float) + t = np.trace(mat) + if t > 0: + s = np.sqrt(t + 1) * 2 + qw = 0.25 * s + qx = (mat[2, 1] - mat[1, 2]) / s + qy = (mat[0, 2] - mat[2, 0]) / s + qz = (mat[1, 0] - mat[0, 1]) / s + elif (mat[0, 0] > mat[1, 1]) and (mat[0, 0] > mat[2, 2]): + s = np.sqrt(1 + mat[0, 0] - mat[1, 1] - mat[2, 2]) * 2 + qx = 0.25 * s + qw = (mat[2, 1] - mat[1, 2]) / s + qy = (mat[0, 1] + mat[1, 0]) / s + qz = (mat[0, 2] + mat[2, 0]) / s + elif mat[1, 1] > mat[2, 2]: + s = np.sqrt(1 + mat[1, 1] - mat[0, 0] - mat[2, 2]) * 2 + qy = 0.25 * s + qw = (mat[0, 2] - mat[2, 0]) / s + qx = (mat[0, 1] + mat[1, 0]) / s + qz = (mat[1, 2] + mat[2, 1]) / s + else: + s = np.sqrt(1 + mat[2, 2] - mat[0, 0] - mat[1, 1]) * 2 + qz = 0.25 * s + qw = (mat[1, 0] - mat[0, 1]) / s + qx = (mat[0, 2] + mat[2, 0]) / s + qy = (mat[1, 2] + mat[2, 1]) / s + q = np.array([qx, qy, qz, qw]) + n = np.linalg.norm(q) + return q / n if n > 1e-8 else np.array([0.0, 0.0, 0.0, 1.0]) + + +def quaternion_from_rotvec(rotvec: ArrayLike) -> np.ndarray: + """Convert a rotation vector (axis × angle) to a quaternion. + + Args: + rotvec: Rotation vector of shape ``(3,)``. + + Returns: + Unit quaternion ``[x, y, z, w]``. + """ + return R.from_rotvec(np.asarray(rotvec, dtype=float)).as_quat() + + +def quaternion_to_rotvec(q: ArrayLike) -> np.ndarray: + """Convert a quaternion to a rotation vector (axis × angle). + + Args: + q: Quaternion ``[x, y, z, w]``. + + Returns: + Rotation vector of shape ``(3,)``. + """ + return R.from_quat(np.asarray(q, dtype=float)).as_rotvec() diff --git a/spinstep/math/core.py b/spinstep/math/core.py new file mode 100644 index 0000000..8dd0fa1 --- /dev/null +++ b/spinstep/math/core.py @@ -0,0 +1,90 @@ +# math/core.py — MIT License +# Author: Eraldo B. Marques — Created: 2025-05-14 +# See LICENSE.txt for full terms. This header must be retained in redistributions. + +"""Core quaternion operations: multiply, conjugate, normalize, inverse.""" + +from __future__ import annotations + +__all__ = [ + "quaternion_multiply", + "quaternion_conjugate", + "quaternion_normalize", + "quaternion_inverse", +] + +import numpy as np +from numpy.typing import ArrayLike + + +def quaternion_multiply(q1: ArrayLike, q2: ArrayLike) -> np.ndarray: + """Hamilton product of two quaternions in ``[x, y, z, w]`` order. + + Args: + q1: First quaternion ``[x, y, z, w]``. + q2: Second quaternion ``[x, y, z, w]``. + + Returns: + Product quaternion ``[x, y, z, w]`` as a NumPy array of shape ``(4,)``. + """ + a1 = np.asarray(q1, dtype=float) + a2 = np.asarray(q2, dtype=float) + x1, y1, z1, w1 = a1 + x2, y2, z2, w2 = a2 + return np.array([ + w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2, + w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2, + w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2, + w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2, + ]) + + +def quaternion_conjugate(q: ArrayLike) -> np.ndarray: + """Return the conjugate of quaternion *q* ``[x, y, z, w]``. + + The conjugate negates the vector part while keeping the scalar part. + + Args: + q: Quaternion ``[x, y, z, w]``. + + Returns: + Conjugate quaternion ``[-x, -y, -z, w]``. + """ + a = np.asarray(q, dtype=float) + return np.array([-a[0], -a[1], -a[2], a[3]]) + + +def quaternion_normalize(q: ArrayLike) -> np.ndarray: + """Normalize a quaternion to unit length. + + Args: + q: Quaternion ``[x, y, z, w]``. + + Returns: + Unit quaternion. Returns ``[0, 0, 0, 1]`` if the input has + near-zero norm. + """ + a = np.asarray(q, dtype=float) + n = np.linalg.norm(a) + if n < 1e-8: + return np.array([0.0, 0.0, 0.0, 1.0]) + return a / n + + +def quaternion_inverse(q: ArrayLike) -> np.ndarray: + """Return the inverse of a unit quaternion. + + For unit quaternions the inverse equals the conjugate. + + Args: + q: Unit quaternion ``[x, y, z, w]``. + + Returns: + Inverse quaternion ``[x, y, z, w]``. + """ + a = np.asarray(q, dtype=float) + norm_sq = np.dot(a, a) + if norm_sq < 1e-16: + return np.array([0.0, 0.0, 0.0, 1.0]) + conj = np.array([-a[0], -a[1], -a[2], a[3]]) + return conj / norm_sq diff --git a/spinstep/math/geometry.py b/spinstep/math/geometry.py new file mode 100644 index 0000000..ce2cb71 --- /dev/null +++ b/spinstep/math/geometry.py @@ -0,0 +1,125 @@ +# math/geometry.py — MIT License +# Author: Eraldo B. Marques — Created: 2025-05-14 +# See LICENSE.txt for full terms. This header must be retained in redistributions. + +"""Quaternion geometry: distance, angle, direction conversions.""" + +from __future__ import annotations + +__all__ = [ + "quaternion_distance", + "is_within_angle_threshold", + "forward_vector_from_quaternion", + "direction_to_quaternion", + "angle_between_directions", + "rotate_quaternion", +] + +import numpy as np +from numpy.typing import ArrayLike +from scipy.spatial.transform import Rotation as R + + +def quaternion_distance(q1: ArrayLike, q2: ArrayLike) -> float: + """Return the angular distance (radians) between two quaternions. + + Args: + q1: First quaternion ``[x, y, z, w]``. + q2: Second quaternion ``[x, y, z, w]``. + + Returns: + Angular distance in radians. + """ + r1 = R.from_quat(q1) + r2 = R.from_quat(q2) + return float((r1.inv() * r2).magnitude()) + + +def is_within_angle_threshold( + q_current: ArrayLike, + q_target: ArrayLike, + threshold_rad: float, +) -> bool: + """Check whether two quaternions are within *threshold_rad* of each other. + + Args: + q_current: Current quaternion ``[x, y, z, w]``. + q_target: Target quaternion ``[x, y, z, w]``. + threshold_rad: Maximum angular distance in radians. + + Returns: + ``True`` if the angular distance is less than *threshold_rad*. + """ + return quaternion_distance(q_current, q_target) < threshold_rad + + +def rotate_quaternion(q: ArrayLike, rotation_step: ArrayLike) -> np.ndarray: + """Apply *rotation_step* to quaternion *q* and return the result. + + Args: + q: Base quaternion ``[x, y, z, w]``. + rotation_step: Rotation to apply, as quaternion ``[x, y, z, w]``. + + Returns: + Composed quaternion ``[x, y, z, w]``. + """ + r1 = R.from_quat(q) + step = R.from_quat(rotation_step) + return (r1 * step).as_quat() + + +def forward_vector_from_quaternion(q: ArrayLike) -> np.ndarray: + """Extract the forward (look) direction from a quaternion. + + The forward direction is defined as ``[0, 0, -1]`` rotated by the + quaternion, following the convention where negative-Z is "forward". + + Args: + q: Quaternion ``[x, y, z, w]``. + + Returns: + Unit direction vector ``(3,)`` pointing forward. + """ + return R.from_quat(q).apply([0, 0, -1]) + + +def direction_to_quaternion(direction: ArrayLike) -> np.ndarray: + """Convert a 3D direction vector to an orientation quaternion. + + The returned quaternion represents the rotation that aligns the + default forward axis ``[0, 0, -1]`` with the given *direction*. + + Args: + direction: Target direction vector (does not need to be normalised). + + Returns: + Unit quaternion ``[x, y, z, w]``. + """ + d = np.asarray(direction, dtype=float) + norm = np.linalg.norm(d) + if norm < 1e-8: + return np.array([0.0, 0.0, 0.0, 1.0]) + d = d / norm + rot, _ = R.align_vectors([d], [[0, 0, -1]]) + return rot.as_quat() + + +def angle_between_directions(d1: ArrayLike, d2: ArrayLike) -> float: + """Compute the angular distance (radians) between two direction vectors. + + Args: + d1: First direction vector. + d2: Second direction vector. + + Returns: + Angle in radians in the range ``[0, π]``. + """ + v1 = np.asarray(d1, dtype=float) + v2 = np.asarray(d2, dtype=float) + n1 = np.linalg.norm(v1) + n2 = np.linalg.norm(v2) + if n1 < 1e-8 or n2 < 1e-8: + return 0.0 + cos_angle = np.dot(v1 / n1, v2 / n2) + cos_angle = np.clip(cos_angle, -1.0, 1.0) + return float(np.arccos(cos_angle)) diff --git a/spinstep/math/interpolation.py b/spinstep/math/interpolation.py new file mode 100644 index 0000000..db3b149 --- /dev/null +++ b/spinstep/math/interpolation.py @@ -0,0 +1,135 @@ +# math/interpolation.py — MIT License +# Author: Eraldo B. Marques — Created: 2025-05-14 +# See LICENSE.txt for full terms. This header must be retained in redistributions. + +"""Quaternion interpolation: SLERP and SQUAD.""" + +from __future__ import annotations + +__all__ = [ + "slerp", + "squad", +] + +from typing import Sequence + +import numpy as np +from numpy.typing import ArrayLike + + +def slerp(q0: ArrayLike, q1: ArrayLike, t: float) -> np.ndarray: + """Spherical linear interpolation between two quaternions. + + Interpolates along the shortest arc on the unit quaternion hypersphere. + + Args: + q0: Start quaternion ``[x, y, z, w]``. + q1: End quaternion ``[x, y, z, w]``. + t: Interpolation parameter in ``[0, 1]``. + + Returns: + Interpolated unit quaternion ``[x, y, z, w]``. + """ + a = np.asarray(q0, dtype=float) + b = np.asarray(q1, dtype=float) + + # Normalize inputs + a = a / np.linalg.norm(a) + b = b / np.linalg.norm(b) + + # Ensure shortest path + dot = np.dot(a, b) + if dot < 0.0: + b = -b + dot = -dot + + dot = np.clip(dot, -1.0, 1.0) + + # If quaternions are very close, use linear interpolation + if dot > 0.9995: + result = a + t * (b - a) + return result / np.linalg.norm(result) + + theta_0 = np.arccos(dot) + theta = theta_0 * t + sin_theta = np.sin(theta) + sin_theta_0 = np.sin(theta_0) + + s0 = np.cos(theta) - dot * sin_theta / sin_theta_0 + s1 = sin_theta / sin_theta_0 + + result = s0 * a + s1 * b + return result / np.linalg.norm(result) + + +def squad( + q0: ArrayLike, + q1: ArrayLike, + q2: ArrayLike, + q3: ArrayLike, + t: float, +) -> np.ndarray: + """Spherical cubic interpolation (SQUAD) between quaternion waypoints. + + Produces a smooth C¹-continuous curve through a sequence of orientations. + *q0* and *q3* are the neighboring control points; the interpolation is + between *q1* (at *t* = 0) and *q2* (at *t* = 1). + + Args: + q0: Control quaternion before *q1*. + q1: Start quaternion for this segment. + q2: End quaternion for this segment. + q3: Control quaternion after *q2*. + t: Interpolation parameter in ``[0, 1]``. + + Returns: + Interpolated unit quaternion ``[x, y, z, w]``. + """ + a1 = np.asarray(q1, dtype=float) + a2 = np.asarray(q2, dtype=float) + + s1 = _squad_intermediate(np.asarray(q0, dtype=float), a1, a2) + s2 = _squad_intermediate(a1, a2, np.asarray(q3, dtype=float)) + + slerp_q1_q2 = slerp(a1, a2, t) + slerp_s1_s2 = slerp(s1, s2, t) + return slerp(slerp_q1_q2, slerp_s1_s2, 2.0 * t * (1.0 - t)) + + +def _squad_intermediate( + q_prev: np.ndarray, q_curr: np.ndarray, q_next: np.ndarray +) -> np.ndarray: + """Compute the SQUAD intermediate control point for *q_curr*.""" + from .core import quaternion_conjugate, quaternion_multiply + + q_curr = q_curr / np.linalg.norm(q_curr) + inv_curr = quaternion_conjugate(q_curr) + + log_prev = _quat_log(quaternion_multiply(inv_curr, q_prev / np.linalg.norm(q_prev))) + log_next = _quat_log(quaternion_multiply(inv_curr, q_next / np.linalg.norm(q_next))) + + avg = -(log_prev + log_next) / 4.0 + result = quaternion_multiply(q_curr, _quat_exp(avg)) + return result / np.linalg.norm(result) + + +def _quat_log(q: np.ndarray) -> np.ndarray: + """Quaternion logarithm (returns pure-quaternion vector part).""" + q = q / np.linalg.norm(q) + vec = q[:3] + w = q[3] + vec_norm = np.linalg.norm(vec) + if vec_norm < 1e-10: + return np.array([0.0, 0.0, 0.0, 0.0]) + theta = np.arctan2(vec_norm, w) + return np.array([*(vec / vec_norm * theta), 0.0]) + + +def _quat_exp(v: np.ndarray) -> np.ndarray: + """Quaternion exponential (from pure-quaternion vector).""" + vec = v[:3] + theta = np.linalg.norm(vec) + if theta < 1e-10: + return np.array([0.0, 0.0, 0.0, 1.0]) + axis = vec / theta + return np.array([*(axis * np.sin(theta)), np.cos(theta)]) diff --git a/spinstep/traversal/__init__.py b/spinstep/traversal/__init__.py new file mode 100644 index 0000000..3617c00 --- /dev/null +++ b/spinstep/traversal/__init__.py @@ -0,0 +1,25 @@ +# traversal/__init__.py — MIT License +# Author: Eraldo B. Marques — Created: 2025-05-14 +# See LICENSE.txt for full terms. This header must be retained in redistributions. + +"""Tree traversal using quaternion orientation. + +This sub-package contains the original traversal classes: + +- :class:`Node` — tree node with quaternion orientation +- :class:`QuaternionDepthIterator` — continuous rotation-step depth-first traversal +- :class:`DiscreteOrientationSet` — queryable set of discrete orientations +- :class:`DiscreteQuaternionIterator` — discrete rotation-step depth-first traversal +""" + +__all__ = [ + "Node", + "QuaternionDepthIterator", + "DiscreteOrientationSet", + "DiscreteQuaternionIterator", +] + +from ..node import Node +from ..traversal import QuaternionDepthIterator +from ..discrete import DiscreteOrientationSet +from ..discrete_iterator import DiscreteQuaternionIterator From ec5c5e40865206a7bf63de69ac7a840112c78992 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 9 Apr 2026 15:53:14 +0000 Subject: [PATCH 2/5] Implement full control library with observer-centered spherical model MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - spinstep/math/: core, interpolation (slerp/squad), geometry, conversions, analysis, constraints — 6 modules with 20+ public functions - spinstep/control/: OrientationState (quaternion+distance), ControlCommand, P and PID controllers with angular+radial channels, trajectory tracking - spinstep/traversal/: relocated node, continuous, discrete, discrete_iterator - Updated __init__.py with new top-level API exports - 73 new tests (131 total pass, 5 skipped CUDA/healpy) - All ruff lint clean on new code Agent-Logs-Url: https://github.com/VoxleOne/SpinStep/sessions/dc0f5515-1df0-43c9-8b5e-28e75d27e046 Co-authored-by: VoxleOne <119956342+VoxleOne@users.noreply.github.com> --- pyproject.toml | 4 +- spinstep/__init__.py | 78 ++- spinstep/control/__init__.py | 16 +- spinstep/control/controllers.py | 301 +++++++--- spinstep/control/state.py | 193 +++++-- spinstep/control/trajectory.py | 120 ++-- spinstep/math/interpolation.py | 2 - spinstep/traversal/__init__.py | 10 +- .../{traversal.py => traversal/continuous.py} | 0 spinstep/{ => traversal}/discrete.py | 2 +- spinstep/{ => traversal}/discrete_iterator.py | 0 spinstep/{ => traversal}/node.py | 0 tests/test_control.py | 515 ++++++++++++++++++ tests/test_discrete_traversal.py | 4 +- tests/test_math.py | 191 +++++++ tests/test_spinstep.py | 6 +- tests/test_traversal.py | 4 +- tests/test_utils.py | 2 +- 18 files changed, 1231 insertions(+), 217 deletions(-) rename spinstep/{traversal.py => traversal/continuous.py} (100%) rename spinstep/{ => traversal}/discrete.py (99%) rename spinstep/{ => traversal}/discrete_iterator.py (100%) rename spinstep/{ => traversal}/node.py (100%) create mode 100644 tests/test_control.py create mode 100644 tests/test_math.py diff --git a/pyproject.toml b/pyproject.toml index 782bf98..87b50ba 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,8 +4,8 @@ build-backend = "setuptools.build_meta" [project] name = "spinstep" -version = "0.3.1a1" -description = "Quaternion-based tree traversal for orientation-aware structures." +version = "0.4.0a0" +description = "Quaternion-based orientation control and traversal in observer-centered spherical coordinates." authors = [{ name = "Eraldo Marques", email = "eraldo.bernardo@gmail.com" }] readme = "README.md" license = "MIT" diff --git a/spinstep/__init__.py b/spinstep/__init__.py index 7b2d1e6..709e057 100644 --- a/spinstep/__init__.py +++ b/spinstep/__init__.py @@ -2,32 +2,78 @@ # Author: Eraldo B. Marques — Created: 2025-05-14 # See LICENSE.txt for full terms. This header must be retained in redistributions. -"""SpinStep: A quaternion-driven traversal framework. +"""SpinStep: Quaternion-based orientation control and traversal. -Provides quaternion-based tree traversal for orientation-aware structures, -supporting both continuous and discrete rotation stepping. +SpinStep uses an observer-centered spherical model where every guided +vehicle (node) is located by a quaternion (direction from the observer) +and a radial distance (layer). -Example usage:: +Control layer (primary API):: - from spinstep import Node, QuaternionDepthIterator + from spinstep import ( + OrientationState, ControlCommand, + ProportionalOrientationController, PIDOrientationController, + OrientationTrajectory, TrajectoryController, + integrate_state, compute_orientation_error, + slerp, + ) - root = Node("root", [0, 0, 0, 1]) - child = Node("child", [0, 0, 0.1, 0.995]) - root.children.append(child) +Traversal layer (original tree-walking API):: - step = [0, 0, 0.05, 0.9987] # small rotation about Z - for node in QuaternionDepthIterator(root, step): - print(node.name) + from spinstep.traversal import ( + Node, QuaternionDepthIterator, + DiscreteOrientationSet, DiscreteQuaternionIterator, + ) + +Math layer:: + + from spinstep.math import quaternion_multiply, quaternion_distance, slerp """ -__version__ = "0.3.0a0" +__version__ = "0.4.0a0" + +# --- control layer (primary API) --- +from .control.state import ( + ControlCommand, + OrientationState, + compute_orientation_error, + integrate_state, +) +from .control.controllers import ( + OrientationController, + PIDOrientationController, + ProportionalOrientationController, +) +from .control.trajectory import ( + OrientationTrajectory, + TrajectoryController, + TrajectoryInterpolator, +) + +# --- key math utilities at top level --- +from .math.interpolation import slerp -from .node import Node -from .traversal import QuaternionDepthIterator -from .discrete import DiscreteOrientationSet -from .discrete_iterator import DiscreteQuaternionIterator +# --- backward-compatible traversal re-exports --- +from .traversal.node import Node +from .traversal.continuous import QuaternionDepthIterator +from .traversal.discrete import DiscreteOrientationSet +from .traversal.discrete_iterator import DiscreteQuaternionIterator __all__ = [ + # control + "OrientationState", + "ControlCommand", + "integrate_state", + "compute_orientation_error", + "OrientationController", + "ProportionalOrientationController", + "PIDOrientationController", + "OrientationTrajectory", + "TrajectoryInterpolator", + "TrajectoryController", + # math + "slerp", + # traversal (backward compat) "Node", "QuaternionDepthIterator", "DiscreteOrientationSet", diff --git a/spinstep/control/__init__.py b/spinstep/control/__init__.py index bdc4c54..57ee8a7 100644 --- a/spinstep/control/__init__.py +++ b/spinstep/control/__init__.py @@ -2,11 +2,17 @@ # Author: Eraldo B. Marques — Created: 2025-05-14 # See LICENSE.txt for full terms. This header must be retained in redistributions. -"""Orientation control: state, controllers, and trajectory tracking. +"""Observer-centered orientation control: state, controllers, and trajectories. + +The SpinStep control model places the observer at the origin. Every +guided vehicle (node) is located by a quaternion (direction from observer) +and a radial distance (layer). Controllers produce commands with both +angular and radial velocity components. Sub-modules: -- :mod:`~.state` — :class:`OrientationState`, integration, error computation +- :mod:`~.state` — :class:`OrientationState`, :class:`ControlCommand`, + integration, error computation - :mod:`~.controllers` — proportional and PID orientation controllers - :mod:`~.trajectory` — waypoint trajectories and trajectory tracking """ @@ -14,7 +20,8 @@ __all__ = [ # state "OrientationState", - "integrate_orientation", + "ControlCommand", + "integrate_state", "compute_orientation_error", # controllers "OrientationController", @@ -27,9 +34,10 @@ ] from .state import ( + ControlCommand, OrientationState, compute_orientation_error, - integrate_orientation, + integrate_state, ) from .controllers import ( OrientationController, diff --git a/spinstep/control/controllers.py b/spinstep/control/controllers.py index 501c071..b1327dd 100644 --- a/spinstep/control/controllers.py +++ b/spinstep/control/controllers.py @@ -2,7 +2,12 @@ # Author: Eraldo B. Marques — Created: 2025-05-14 # See LICENSE.txt for full terms. This header must be retained in redistributions. -"""Orientation controllers: proportional and PID with rate limiting.""" +"""Orientation controllers: proportional and PID with rate limiting. + +All controllers operate in the observer-centered spherical model and +produce a :class:`~.state.ControlCommand` containing both angular +velocity and radial velocity components. +""" from __future__ import annotations @@ -18,37 +23,51 @@ import numpy as np from numpy.typing import ArrayLike -from .state import compute_orientation_error +from .state import ControlCommand, compute_orientation_error class OrientationController(ABC): """Abstract base class for orientation controllers. - All controllers share the ``update`` interface, which takes the current - and target quaternions plus a time step and returns an angular velocity - command vector. + All controllers share the :meth:`update` interface, which takes current + and target poses (quaternion + distance) plus a time step, and returns + a :class:`ControlCommand` with angular and radial velocity components. Args: max_angular_velocity: Maximum angular velocity magnitude (rad/s). ``None`` means unlimited. max_angular_acceleration: Maximum angular acceleration magnitude (rad/s²). ``None`` means unlimited. + max_radial_velocity: Maximum radial speed (units/s). + ``None`` means unlimited. + max_radial_acceleration: Maximum radial acceleration (units/s²). + ``None`` means unlimited. """ def __init__( self, max_angular_velocity: Optional[float] = None, max_angular_acceleration: Optional[float] = None, + max_radial_velocity: Optional[float] = None, + max_radial_acceleration: Optional[float] = None, ) -> None: self.max_angular_velocity = max_angular_velocity self.max_angular_acceleration = max_angular_acceleration - self._prev_command: Optional[np.ndarray] = None + self.max_radial_velocity = max_radial_velocity + self.max_radial_acceleration = max_radial_acceleration + self._prev_angular_cmd: Optional[np.ndarray] = None + self._prev_radial_cmd: Optional[float] = None @abstractmethod def compute_raw_command( - self, current_q: ArrayLike, target_q: ArrayLike, dt: float - ) -> np.ndarray: - """Compute the raw (unclamped) angular velocity command. + self, + current_q: ArrayLike, + target_q: ArrayLike, + dt: float, + current_distance: float = 0.0, + target_distance: float = 0.0, + ) -> ControlCommand: + """Compute the raw (unclamped) control command. Subclasses must implement this. @@ -56,28 +75,36 @@ def compute_raw_command( current_q: Current orientation ``[x, y, z, w]``. target_q: Target orientation ``[x, y, z, w]``. dt: Time step in seconds. + current_distance: Current radial distance from observer. + target_distance: Target radial distance from observer. Returns: - Raw angular velocity command ``(3,)`` in rad/s. + Raw :class:`ControlCommand`. """ ... def update( - self, current_q: ArrayLike, target_q: ArrayLike, dt: float - ) -> np.ndarray: - """Compute a rate-limited angular velocity command. + self, + current_q: ArrayLike, + target_q: ArrayLike, + dt: float, + current_distance: float = 0.0, + target_distance: float = 0.0, + ) -> ControlCommand: + """Compute a rate-limited control command. - Calls :meth:`compute_raw_command` and then applies - :attr:`max_angular_velocity` and :attr:`max_angular_acceleration` - limits. + Calls :meth:`compute_raw_command` and applies velocity and + acceleration limits to both angular and radial components. Args: current_q: Current orientation ``[x, y, z, w]``. target_q: Target orientation ``[x, y, z, w]``. dt: Time step in seconds. Must be positive. + current_distance: Current radial distance from observer. + target_distance: Target radial distance from observer. Returns: - Angular velocity command ``(3,)`` in rad/s. + Rate-limited :class:`ControlCommand`. Raises: ValueError: If *dt* is not positive. @@ -85,77 +112,127 @@ def update( if dt <= 0: raise ValueError(f"dt must be positive, got {dt}") - command = self.compute_raw_command(current_q, target_q, dt) + raw = self.compute_raw_command( + current_q, target_q, dt, current_distance, target_distance + ) + angular = raw.angular_velocity.copy() + radial = raw.radial_velocity - # Apply velocity limit + # --- angular velocity limit --- if self.max_angular_velocity is not None: - speed = np.linalg.norm(command) + speed = np.linalg.norm(angular) if speed > self.max_angular_velocity: - command = command * (self.max_angular_velocity / speed) - - # Apply acceleration limit - if self.max_angular_acceleration is not None and self._prev_command is not None: - delta = command - self._prev_command - accel = np.linalg.norm(delta) / dt - if accel > self.max_angular_acceleration: + angular = angular * (self.max_angular_velocity / speed) + + # --- angular acceleration limit --- + if ( + self.max_angular_acceleration is not None + and self._prev_angular_cmd is not None + ): + delta = angular - self._prev_angular_cmd + delta_norm = np.linalg.norm(delta) + if delta_norm / dt > self.max_angular_acceleration: max_delta = ( - delta / np.linalg.norm(delta) - * self.max_angular_acceleration - * dt + delta / delta_norm * self.max_angular_acceleration * dt ) - command = self._prev_command + max_delta - - self._prev_command = command.copy() - return command + angular = self._prev_angular_cmd + max_delta + + # --- radial velocity limit --- + if self.max_radial_velocity is not None: + if abs(radial) > self.max_radial_velocity: + radial = np.sign(radial) * self.max_radial_velocity + + # --- radial acceleration limit --- + if ( + self.max_radial_acceleration is not None + and self._prev_radial_cmd is not None + ): + delta_r = radial - self._prev_radial_cmd + if abs(delta_r) / dt > self.max_radial_acceleration: + max_delta_r = np.sign(delta_r) * self.max_radial_acceleration * dt + radial = self._prev_radial_cmd + max_delta_r + + self._prev_angular_cmd = angular.copy() + self._prev_radial_cmd = float(radial) + return ControlCommand(angular_velocity=angular, radial_velocity=radial) def reset(self) -> None: """Reset the controller's internal state.""" - self._prev_command = None + self._prev_angular_cmd = None + self._prev_radial_cmd = None class ProportionalOrientationController(OrientationController): """Proportional (P) orientation controller. - Computes the angular velocity command as ``kp × error``, where the - error is the rotation vector from the current to the target orientation. + Computes angular velocity as ``kp × angular_error`` and radial + velocity as ``kp_radial × radial_error``. Args: - kp: Proportional gain. Defaults to ``1.0``. - max_angular_velocity: Maximum angular velocity magnitude (rad/s). + kp: Proportional gain for angular error. Defaults to ``1.0``. + kp_radial: Proportional gain for radial error. Defaults to ``1.0``. + max_angular_velocity: Maximum angular velocity (rad/s). max_angular_acceleration: Maximum angular acceleration (rad/s²). + max_radial_velocity: Maximum radial speed (units/s). + max_radial_acceleration: Maximum radial acceleration (units/s²). Example:: from spinstep.control import ProportionalOrientationController - ctrl = ProportionalOrientationController(kp=2.0, max_angular_velocity=3.14) - cmd = ctrl.update([0, 0, 0, 1], [0, 0, 0.383, 0.924], dt=0.01) + ctrl = ProportionalOrientationController(kp=2.0, kp_radial=1.5) + cmd = ctrl.update( + [0, 0, 0, 1], [0, 0, 0.383, 0.924], dt=0.01, + current_distance=3.0, target_distance=5.0, + ) + print(cmd.angular_velocity, cmd.radial_velocity) """ def __init__( self, kp: float = 1.0, + kp_radial: float = 1.0, max_angular_velocity: Optional[float] = None, max_angular_acceleration: Optional[float] = None, + max_radial_velocity: Optional[float] = None, + max_radial_acceleration: Optional[float] = None, ) -> None: - super().__init__(max_angular_velocity, max_angular_acceleration) + super().__init__( + max_angular_velocity, + max_angular_acceleration, + max_radial_velocity, + max_radial_acceleration, + ) self.kp = kp + self.kp_radial = kp_radial def compute_raw_command( - self, current_q: ArrayLike, target_q: ArrayLike, dt: float - ) -> np.ndarray: - """Compute ``kp × orientation_error``. + self, + current_q: ArrayLike, + target_q: ArrayLike, + dt: float, + current_distance: float = 0.0, + target_distance: float = 0.0, + ) -> ControlCommand: + """Compute ``kp × error`` for angular and radial components. Args: current_q: Current orientation ``[x, y, z, w]``. target_q: Target orientation ``[x, y, z, w]``. dt: Time step in seconds (unused by P controller). + current_distance: Current radial distance. + target_distance: Target radial distance. Returns: - Angular velocity command ``(3,)`` in rad/s. + :class:`ControlCommand`. """ - error = compute_orientation_error(current_q, target_q) - return self.kp * error + ang_error, rad_error = compute_orientation_error( + current_q, target_q, current_distance, target_distance + ) + return ControlCommand( + angular_velocity=self.kp * ang_error, + radial_velocity=self.kp_radial * rad_error, + ) def reset(self) -> None: """Reset the controller's internal state.""" @@ -165,26 +242,34 @@ def reset(self) -> None: class PIDOrientationController(OrientationController): """PID orientation controller with anti-windup. - Computes the angular velocity command as - ``kp × error + ki × ∫error·dt + kd × d(error)/dt``. - Integral windup is prevented by clamping the integrated error magnitude - to *max_integral*. + Computes ``kp × e + ki × ∫e·dt + kd × de/dt`` for both the angular + and radial error channels independently. Integral windup is prevented + by clamping integrated error magnitudes. Args: - kp: Proportional gain. Defaults to ``1.0``. - ki: Integral gain. Defaults to ``0.0``. - kd: Derivative gain. Defaults to ``0.0``. - max_integral: Maximum magnitude of the integrated error vector. - Defaults to ``10.0``. - max_angular_velocity: Maximum angular velocity magnitude (rad/s). + kp: Proportional gain (angular). + ki: Integral gain (angular). + kd: Derivative gain (angular). + kp_radial: Proportional gain (radial). + ki_radial: Integral gain (radial). + kd_radial: Derivative gain (radial). + max_integral: Maximum angular integral magnitude. + max_integral_radial: Maximum radial integral magnitude. + max_angular_velocity: Maximum angular velocity (rad/s). max_angular_acceleration: Maximum angular acceleration (rad/s²). + max_radial_velocity: Maximum radial speed (units/s). + max_radial_acceleration: Maximum radial acceleration (units/s²). Example:: from spinstep.control import PIDOrientationController - ctrl = PIDOrientationController(kp=2.0, ki=0.1, kd=0.5) - cmd = ctrl.update([0, 0, 0, 1], [0, 0, 0.383, 0.924], dt=0.01) + ctrl = PIDOrientationController(kp=2.0, ki=0.1, kd=0.5, + kp_radial=1.0, ki_radial=0.05) + cmd = ctrl.update( + [0, 0, 0, 1], [0, 0, 0.383, 0.924], dt=0.01, + current_distance=3.0, target_distance=5.0, + ) """ def __init__( @@ -192,54 +277,102 @@ def __init__( kp: float = 1.0, ki: float = 0.0, kd: float = 0.0, + kp_radial: float = 1.0, + ki_radial: float = 0.0, + kd_radial: float = 0.0, max_integral: float = 10.0, + max_integral_radial: float = 10.0, max_angular_velocity: Optional[float] = None, max_angular_acceleration: Optional[float] = None, + max_radial_velocity: Optional[float] = None, + max_radial_acceleration: Optional[float] = None, ) -> None: - super().__init__(max_angular_velocity, max_angular_acceleration) + super().__init__( + max_angular_velocity, + max_angular_acceleration, + max_radial_velocity, + max_radial_acceleration, + ) self.kp = kp self.ki = ki self.kd = kd + self.kp_radial = kp_radial + self.ki_radial = ki_radial + self.kd_radial = kd_radial self.max_integral = max_integral - self._integral: np.ndarray = np.zeros(3) - self._prev_error: Optional[np.ndarray] = None + self.max_integral_radial = max_integral_radial + + self._ang_integral: np.ndarray = np.zeros(3) + self._rad_integral: float = 0.0 + self._prev_ang_error: Optional[np.ndarray] = None + self._prev_rad_error: Optional[float] = None def compute_raw_command( - self, current_q: ArrayLike, target_q: ArrayLike, dt: float - ) -> np.ndarray: - """Compute the PID angular velocity command. + self, + current_q: ArrayLike, + target_q: ArrayLike, + dt: float, + current_distance: float = 0.0, + target_distance: float = 0.0, + ) -> ControlCommand: + """Compute the PID command for angular and radial channels. Args: current_q: Current orientation ``[x, y, z, w]``. target_q: Target orientation ``[x, y, z, w]``. dt: Time step in seconds. + current_distance: Current radial distance. + target_distance: Target radial distance. Returns: - Angular velocity command ``(3,)`` in rad/s. + :class:`ControlCommand`. """ - error = compute_orientation_error(current_q, target_q) + ang_error, rad_error = compute_orientation_error( + current_q, target_q, current_distance, target_distance + ) - # Proportional - p_term = self.kp * error + # --- angular PID --- + p_ang = self.kp * ang_error - # Integral with anti-windup - self._integral += error * dt - integral_mag = np.linalg.norm(self._integral) - if integral_mag > self.max_integral: - self._integral = self._integral * (self.max_integral / integral_mag) - i_term = self.ki * self._integral + self._ang_integral += ang_error * dt + mag = np.linalg.norm(self._ang_integral) + if mag > self.max_integral: + self._ang_integral *= self.max_integral / mag + i_ang = self.ki * self._ang_integral - # Derivative - if self._prev_error is not None: - d_term = self.kd * (error - self._prev_error) / dt + if self._prev_ang_error is not None: + d_ang = self.kd * (ang_error - self._prev_ang_error) / dt else: - d_term = np.zeros(3) - self._prev_error = error.copy() + d_ang = np.zeros(3) + self._prev_ang_error = ang_error.copy() + + angular_cmd = p_ang + i_ang + d_ang + + # --- radial PID --- + p_rad = self.kp_radial * rad_error + + self._rad_integral += rad_error * dt + if abs(self._rad_integral) > self.max_integral_radial: + self._rad_integral = np.sign(self._rad_integral) * self.max_integral_radial + i_rad = self.ki_radial * self._rad_integral + + if self._prev_rad_error is not None: + d_rad = self.kd_radial * (rad_error - self._prev_rad_error) / dt + else: + d_rad = 0.0 + self._prev_rad_error = rad_error + + radial_cmd = p_rad + i_rad + d_rad - return p_term + i_term + d_term + return ControlCommand( + angular_velocity=angular_cmd, + radial_velocity=radial_cmd, + ) def reset(self) -> None: """Reset the controller's internal state (integral, derivative, etc.).""" super().reset() - self._integral = np.zeros(3) - self._prev_error = None + self._ang_integral = np.zeros(3) + self._rad_integral = 0.0 + self._prev_ang_error = None + self._prev_rad_error = None diff --git a/spinstep/control/state.py b/spinstep/control/state.py index d6c1fdc..4e06b41 100644 --- a/spinstep/control/state.py +++ b/spinstep/control/state.py @@ -2,13 +2,25 @@ # Author: Eraldo B. Marques — Created: 2025-05-14 # See LICENSE.txt for full terms. This header must be retained in redistributions. -"""Orientation state model: dataclass, integration, and error computation.""" +"""Observer-centered spherical state model. + +In the SpinStep control model the observer sits at the origin. +Every guided vehicle (node) is located by: + +- **orientation** — a unit quaternion giving the direction from the observer +- **distance** — the radial distance (layer) from the observer + +Velocities follow the same decomposition: angular velocity (rad/s) for +the tangential component and radial velocity (units/s) for the range +component. +""" from __future__ import annotations __all__ = [ "OrientationState", - "integrate_orientation", + "ControlCommand", + "integrate_state", "compute_orientation_error", ] @@ -23,38 +35,51 @@ @dataclass class OrientationState: - """Immutable orientation state: pose, angular velocity, and timestamp. + """Observer-centered state: direction, distance, velocities, timestamp. + + The state describes a guided vehicle (node) in the observer's spherical + frame. The quaternion gives the direction from the observer; the + distance gives the radial layer. All quaternions use ``[x, y, z, w]`` convention. Args: - quaternion: Unit quaternion ``[x, y, z, w]`` representing the - current orientation. - angular_velocity: Angular velocity vector ``[ωx, ωy, ωz]`` in - radians per second. Defaults to zero. - timestamp: Time in seconds. Defaults to ``0.0``. + quaternion: Unit quaternion ``[x, y, z, w]`` — direction from observer. + distance: Radial distance from observer. Defaults to ``0.0``. + angular_velocity: Angular velocity ``[ωx, ωy, ωz]`` in rad/s. + radial_velocity: Radial velocity (units/s). Positive = moving away + from observer. + timestamp: Time in seconds. Attributes: - quaternion: Normalised quaternion as a NumPy array of shape ``(4,)``. - angular_velocity: Angular velocity as a NumPy array of shape ``(3,)``. + quaternion: Normalised quaternion ``(4,)``. + distance: Radial distance (≥ 0). + angular_velocity: Angular velocity ``(3,)``. + radial_velocity: Radial velocity scalar. timestamp: Timestamp in seconds. Example:: from spinstep.control import OrientationState - state = OrientationState([0, 0, 0, 1]) - print(state.quaternion) # [0. 0. 0. 1.] + # A vehicle at distance 5.0 looking along +Z + state = OrientationState([0, 0, 0, 1], distance=5.0) """ - quaternion: np.ndarray = field(default_factory=lambda: np.array([0.0, 0.0, 0.0, 1.0])) + quaternion: np.ndarray = field( + default_factory=lambda: np.array([0.0, 0.0, 0.0, 1.0]) + ) + distance: float = 0.0 angular_velocity: np.ndarray = field(default_factory=lambda: np.zeros(3)) + radial_velocity: float = 0.0 timestamp: float = 0.0 def __init__( self, quaternion: ArrayLike = (0.0, 0.0, 0.0, 1.0), + distance: float = 0.0, angular_velocity: ArrayLike = (0.0, 0.0, 0.0), + radial_velocity: float = 0.0, timestamp: float = 0.0, ) -> None: q = np.asarray(quaternion, dtype=float) @@ -67,104 +92,170 @@ def __init__( raise ValueError("quaternion must be non-zero") self.quaternion = q / norm + if distance < 0: + raise ValueError(f"distance must be non-negative, got {distance}") + self.distance = float(distance) + omega = np.asarray(angular_velocity, dtype=float) if omega.shape != (3,): raise ValueError( f"angular_velocity must have shape (3,), got {omega.shape}" ) self.angular_velocity = omega + self.radial_velocity = float(radial_velocity) self.timestamp = float(timestamp) def __repr__(self) -> str: return ( f"OrientationState(" f"q={self.quaternion.tolist()}, " + f"d={self.distance}, " f"ω={self.angular_velocity.tolist()}, " + f"ṙ={self.radial_velocity}, " f"t={self.timestamp})" ) -def integrate_orientation(state: OrientationState, dt: float) -> OrientationState: - """Integrate orientation forward by *dt* seconds using current angular velocity. +@dataclass +class ControlCommand: + """Command output from a controller: angular + radial velocity. + + Separates the tangential (angular) and radial components of the + velocity command so they can be applied independently to actuators. + + Args: + angular_velocity: Desired angular velocity ``[ωx, ωy, ωz]`` in + rad/s. + radial_velocity: Desired radial velocity in units/s. Positive + means moving away from the observer. + + Example:: + + from spinstep.control.state import ControlCommand + + cmd = ControlCommand(angular_velocity=[0, 0, 1.0], radial_velocity=0.5) + """ + + angular_velocity: np.ndarray = field(default_factory=lambda: np.zeros(3)) + radial_velocity: float = 0.0 + + def __init__( + self, + angular_velocity: ArrayLike = (0.0, 0.0, 0.0), + radial_velocity: float = 0.0, + ) -> None: + omega = np.asarray(angular_velocity, dtype=float) + if omega.shape != (3,): + raise ValueError( + f"angular_velocity must have shape (3,), got {omega.shape}" + ) + self.angular_velocity = omega + self.radial_velocity = float(radial_velocity) + + def __repr__(self) -> str: + return ( + f"ControlCommand(" + f"ω={self.angular_velocity.tolist()}, " + f"ṙ={self.radial_velocity})" + ) + + +def integrate_state(state: OrientationState, dt: float) -> OrientationState: + """Integrate the full spherical state forward by *dt* seconds. - Uses the exponential map: ``q(t+dt) = q(t) * exp(ω · dt / 2)``, - which is the standard first-order quaternion integration. + Orientation is integrated via the exponential map: + ``q(t+dt) = q(t) * exp(ω · dt / 2)``. + Distance is integrated linearly: + ``d(t+dt) = max(0, d(t) + ṙ · dt)``. Args: - state: Current orientation state. + state: Current state. dt: Time step in seconds. Must be positive. Returns: - New :class:`OrientationState` with updated quaternion and timestamp. - Angular velocity is carried forward unchanged. + New :class:`OrientationState` with updated quaternion, distance, + and timestamp. Velocities are carried forward unchanged. Raises: ValueError: If *dt* is not positive. Example:: - from spinstep.control import OrientationState, integrate_orientation + from spinstep.control import OrientationState, integrate_state - state = OrientationState([0, 0, 0, 1], [0, 0, 1.0]) - new_state = integrate_orientation(state, dt=0.01) + state = OrientationState([0, 0, 0, 1], distance=5.0, + angular_velocity=[0, 0, 1.0], + radial_velocity=0.5) + new = integrate_state(state, dt=0.01) + # new.distance ≈ 5.005 """ if dt <= 0: raise ValueError(f"dt must be positive, got {dt}") + # --- angular integration --- omega = state.angular_velocity angle = np.linalg.norm(omega) if angle < 1e-10: - # No rotation — return state with updated timestamp - return OrientationState( - quaternion=state.quaternion.copy(), - angular_velocity=state.angular_velocity.copy(), - timestamp=state.timestamp + dt, - ) - - # Compute the incremental rotation quaternion: exp(ω·dt/2) - half_angle = angle * dt / 2.0 - axis = omega / angle - delta_q = np.array([ - *(axis * np.sin(half_angle)), - np.cos(half_angle), - ]) - - new_q = quaternion_multiply(state.quaternion, delta_q) - new_q = quaternion_normalize(new_q) + new_q = state.quaternion.copy() + else: + half_angle = angle * dt / 2.0 + axis = omega / angle + delta_q = np.array([ + *(axis * np.sin(half_angle)), + np.cos(half_angle), + ]) + new_q = quaternion_multiply(state.quaternion, delta_q) + new_q = quaternion_normalize(new_q) + + # --- radial integration --- + new_distance = max(0.0, state.distance + state.radial_velocity * dt) return OrientationState( quaternion=new_q, + distance=new_distance, angular_velocity=state.angular_velocity.copy(), + radial_velocity=state.radial_velocity, timestamp=state.timestamp + dt, ) def compute_orientation_error( - current_q: ArrayLike, target_q: ArrayLike -) -> np.ndarray: - """Compute the orientation error as an axis-angle vector from current to target. + current_q: ArrayLike, + target_q: ArrayLike, + current_distance: float = 0.0, + target_distance: float = 0.0, +) -> tuple[np.ndarray, float]: + """Compute the full spherical error: angular + radial. - The error is expressed in the body frame of *current_q*. Its direction is the - rotation axis and its magnitude is the rotation angle in radians. + The angular error is expressed in the body frame of *current_q* as a + rotation vector (axis × angle, in radians). + + The radial error is ``target_distance − current_distance`` (positive + means the target is farther from the observer). Args: current_q: Current orientation quaternion ``[x, y, z, w]``. target_q: Target orientation quaternion ``[x, y, z, w]``. + current_distance: Current radial distance. + target_distance: Target radial distance. Returns: - Error rotation vector ``(3,)`` in radians. Zero vector when - the orientations are identical. + A tuple ``(angular_error, radial_error)`` where *angular_error* + is a rotation vector ``(3,)`` and *radial_error* is a float. Example:: from spinstep.control import compute_orientation_error - error = compute_orientation_error([0, 0, 0, 1], [0, 0, 0.383, 0.924]) - print(error) # approximately [0, 0, 0.785] + ang_err, rad_err = compute_orientation_error( + [0, 0, 0, 1], [0, 0, 0.383, 0.924], + current_distance=3.0, target_distance=5.0, + ) """ r_current = R.from_quat(current_q) r_target = R.from_quat(target_q) - # Error rotation in the body frame of current r_error = r_current.inv() * r_target - return r_error.as_rotvec() + angular_error: np.ndarray = r_error.as_rotvec() + radial_error = float(target_distance - current_distance) + return angular_error, radial_error diff --git a/spinstep/control/trajectory.py b/spinstep/control/trajectory.py index 9c75bf6..7d009ee 100644 --- a/spinstep/control/trajectory.py +++ b/spinstep/control/trajectory.py @@ -2,7 +2,12 @@ # Author: Eraldo B. Marques — Created: 2025-05-14 # See LICENSE.txt for full terms. This header must be retained in redistributions. -"""Orientation trajectories: waypoints, interpolation, and tracking.""" +"""Orientation trajectories: waypoints, interpolation, and tracking. + +Trajectories in SpinStep are sequences of ``(quaternion, distance, time)`` +waypoints in the observer-centered spherical frame. The quaternion gives +the direction from the observer and the distance gives the radial layer. +""" from __future__ import annotations @@ -12,30 +17,36 @@ "TrajectoryController", ] -from typing import List, Optional, Sequence, Tuple +from typing import List, Sequence, Tuple, Union import numpy as np from numpy.typing import ArrayLike from ..math.interpolation import slerp from .controllers import OrientationController +from .state import ControlCommand class OrientationTrajectory: - """A sequence of quaternion waypoints with associated timestamps. + """A sequence of spherical waypoints with timestamps. + + Each waypoint is ``(quaternion, distance, time)`` or, for backward + compatibility, ``(quaternion, time)`` (distance defaults to ``0.0``). Waypoints must be in ascending time order. Args: - waypoints: Sequence of ``(quaternion, time)`` pairs where each - quaternion is ``[x, y, z, w]`` and time is in seconds. + waypoints: Sequence of waypoint tuples. Accepted forms: + - ``(quaternion, distance, time)`` + - ``(quaternion, time)`` — distance defaults to ``0.0`` Raises: - ValueError: If fewer than two waypoints are provided or times - are not strictly increasing. + ValueError: If fewer than two waypoints are provided, times are + not strictly increasing, or quaternions are invalid. Attributes: quaternions: Array of shape ``(N, 4)`` — waypoint quaternions. + distances: Array of shape ``(N,)`` — waypoint distances. times: Array of shape ``(N,)`` — waypoint times in seconds. Example:: @@ -43,18 +54,19 @@ class OrientationTrajectory: from spinstep.control import OrientationTrajectory traj = OrientationTrajectory([ - ([0, 0, 0, 1], 0.0), - ([0, 0, 0.383, 0.924], 1.0), - ([0, 0, 0.707, 0.707], 2.0), + ([0, 0, 0, 1], 5.0, 0.0), + ([0, 0, 0.383, 0.924], 7.5, 1.0), + ([0, 0, 0.707, 0.707], 10.0, 2.0), ]) """ quaternions: np.ndarray + distances: np.ndarray times: np.ndarray def __init__( self, - waypoints: Sequence[Tuple[ArrayLike, float]], + waypoints: Sequence[Union[Tuple[ArrayLike, float, float], Tuple[ArrayLike, float]]], ) -> None: if len(waypoints) < 2: raise ValueError( @@ -62,9 +74,22 @@ def __init__( ) quats: List[np.ndarray] = [] + dists: List[float] = [] times: List[float] = [] - for q, t in waypoints: - arr = np.asarray(q, dtype=float) + + for wp in waypoints: + if len(wp) == 3: + q_raw, dist, t = wp # type: ignore[misc] + elif len(wp) == 2: + q_raw, t = wp # type: ignore[misc] + dist = 0.0 + else: + raise ValueError( + f"Each waypoint must be (quaternion, distance, time) " + f"or (quaternion, time), got tuple of length {len(wp)}" + ) + + arr = np.asarray(q_raw, dtype=float) if arr.shape != (4,): raise ValueError( f"Each waypoint quaternion must have shape (4,), got {arr.shape}" @@ -73,6 +98,7 @@ def __init__( if norm < 1e-8: raise ValueError("Waypoint quaternion must be non-zero") quats.append(arr / norm) + dists.append(float(dist)) times.append(float(t)) for i in range(1, len(times)): @@ -83,6 +109,7 @@ def __init__( ) self.quaternions = np.array(quats) + self.distances = np.array(dists) self.times = np.array(times) @property @@ -111,10 +138,10 @@ def __repr__(self) -> str: class TrajectoryInterpolator: - """SLERP-based interpolator for an :class:`OrientationTrajectory`. + """SLERP + linear interpolator for an :class:`OrientationTrajectory`. - Evaluates the orientation at any time within the trajectory's time span - using spherical linear interpolation between adjacent waypoints. + Orientation is interpolated via SLERP; distance is linearly + interpolated between adjacent waypoints. Args: trajectory: The trajectory to interpolate. @@ -124,34 +151,35 @@ class TrajectoryInterpolator: from spinstep.control import OrientationTrajectory, TrajectoryInterpolator traj = OrientationTrajectory([ - ([0, 0, 0, 1], 0.0), - ([0, 0, 0.383, 0.924], 1.0), + ([0, 0, 0, 1], 5.0, 0.0), + ([0, 0, 0.383, 0.924], 10.0, 1.0), ]) interp = TrajectoryInterpolator(traj) - q = interp.evaluate(0.5) + q, d = interp.evaluate(0.5) + # q ≈ slerp midpoint, d ≈ 7.5 """ def __init__(self, trajectory: OrientationTrajectory) -> None: self.trajectory = trajectory - def evaluate(self, t: float) -> np.ndarray: - """Return the interpolated quaternion at time *t*. + def evaluate(self, t: float) -> Tuple[np.ndarray, float]: + """Return the interpolated quaternion and distance at time *t*. - Times before the first waypoint return the first quaternion. - Times after the last waypoint return the last quaternion. + Times before the first waypoint return the first pose; times + after the last return the last pose. Args: t: Query time in seconds. Returns: - Interpolated unit quaternion ``[x, y, z, w]``. + A tuple ``(quaternion, distance)``. """ traj = self.trajectory if t <= traj.times[0]: - return traj.quaternions[0].copy() + return traj.quaternions[0].copy(), float(traj.distances[0]) if t >= traj.times[-1]: - return traj.quaternions[-1].copy() + return traj.quaternions[-1].copy(), float(traj.distances[-1]) # Find the segment idx = int(np.searchsorted(traj.times, t, side="right") - 1) @@ -161,7 +189,9 @@ def evaluate(self, t: float) -> np.ndarray: t1 = traj.times[idx + 1] alpha = (t - t0) / (t1 - t0) - return slerp(traj.quaternions[idx], traj.quaternions[idx + 1], alpha) + q = slerp(traj.quaternions[idx], traj.quaternions[idx + 1], alpha) + d = float(traj.distances[idx] + alpha * (traj.distances[idx + 1] - traj.distances[idx])) + return q, d @property def duration(self) -> float: @@ -170,21 +200,19 @@ def duration(self) -> float: class TrajectoryController: - """Controller that tracks an orientation trajectory over time. + """Controller that tracks a spherical trajectory over time. Wraps a base :class:`OrientationController` and a - :class:`TrajectoryInterpolator`. At each time step the controller - queries the interpolator for the desired orientation and computes the - angular velocity command to drive the system towards it. + :class:`TrajectoryInterpolator`. At each step the controller queries + the interpolator for the desired pose and computes the + :class:`~.state.ControlCommand` to drive the vehicle towards it. Args: - controller: An :class:`OrientationController` instance (e.g. - :class:`ProportionalOrientationController` or - :class:`PIDOrientationController`). + controller: Base controller instance. trajectory: The trajectory to follow. Attributes: - interpolator: The :class:`TrajectoryInterpolator` used internally. + interpolator: Internal :class:`TrajectoryInterpolator`. controller: The wrapped base controller. is_complete: Whether the trajectory end time has been reached. @@ -197,12 +225,12 @@ class TrajectoryController: ) traj = OrientationTrajectory([ - ([0, 0, 0, 1], 0.0), - ([0, 0, 0.383, 0.924], 1.0), + ([0, 0, 0, 1], 5.0, 0.0), + ([0, 0, 0.383, 0.924], 10.0, 1.0), ]) - ctrl = ProportionalOrientationController(kp=2.0) + ctrl = ProportionalOrientationController(kp=2.0, kp_radial=1.0) traj_ctrl = TrajectoryController(ctrl, traj) - cmd = traj_ctrl.update([0, 0, 0, 1], t=0.5, dt=0.01) + cmd = traj_ctrl.update([0, 0, 0, 1], current_distance=5.0, t=0.5, dt=0.01) """ def __init__( @@ -219,20 +247,24 @@ def update( current_q: ArrayLike, t: float, dt: float, - ) -> np.ndarray: - """Compute angular velocity command to track the trajectory at time *t*. + current_distance: float = 0.0, + ) -> ControlCommand: + """Compute the control command to track the trajectory at time *t*. Args: current_q: Current orientation ``[x, y, z, w]``. t: Current time in seconds. dt: Time step in seconds. + current_distance: Current radial distance from observer. Returns: - Angular velocity command ``(3,)`` in rad/s. + :class:`~.state.ControlCommand` with angular and radial components. """ - target_q = self.interpolator.evaluate(t) + target_q, target_distance = self.interpolator.evaluate(t) self.is_complete = t >= self.interpolator.trajectory.end_time - return self.controller.update(current_q, target_q, dt) + return self.controller.update( + current_q, target_q, dt, current_distance, target_distance + ) def reset(self) -> None: """Reset the controller state.""" diff --git a/spinstep/math/interpolation.py b/spinstep/math/interpolation.py index db3b149..22fba9d 100644 --- a/spinstep/math/interpolation.py +++ b/spinstep/math/interpolation.py @@ -11,8 +11,6 @@ "squad", ] -from typing import Sequence - import numpy as np from numpy.typing import ArrayLike diff --git a/spinstep/traversal/__init__.py b/spinstep/traversal/__init__.py index 3617c00..d480ffa 100644 --- a/spinstep/traversal/__init__.py +++ b/spinstep/traversal/__init__.py @@ -4,7 +4,7 @@ """Tree traversal using quaternion orientation. -This sub-package contains the original traversal classes: +This sub-package contains the traversal classes: - :class:`Node` — tree node with quaternion orientation - :class:`QuaternionDepthIterator` — continuous rotation-step depth-first traversal @@ -19,7 +19,7 @@ "DiscreteQuaternionIterator", ] -from ..node import Node -from ..traversal import QuaternionDepthIterator -from ..discrete import DiscreteOrientationSet -from ..discrete_iterator import DiscreteQuaternionIterator +from .node import Node +from .continuous import QuaternionDepthIterator +from .discrete import DiscreteOrientationSet +from .discrete_iterator import DiscreteQuaternionIterator diff --git a/spinstep/traversal.py b/spinstep/traversal/continuous.py similarity index 100% rename from spinstep/traversal.py rename to spinstep/traversal/continuous.py diff --git a/spinstep/discrete.py b/spinstep/traversal/discrete.py similarity index 99% rename from spinstep/discrete.py rename to spinstep/traversal/discrete.py index b4fae53..be04773 100644 --- a/spinstep/discrete.py +++ b/spinstep/traversal/discrete.py @@ -14,7 +14,7 @@ from numpy.typing import ArrayLike from scipy.spatial.transform import Rotation as R -from spinstep.utils.array_backend import get_array_module +from ..utils.array_backend import get_array_module class DiscreteOrientationSet: diff --git a/spinstep/discrete_iterator.py b/spinstep/traversal/discrete_iterator.py similarity index 100% rename from spinstep/discrete_iterator.py rename to spinstep/traversal/discrete_iterator.py diff --git a/spinstep/node.py b/spinstep/traversal/node.py similarity index 100% rename from spinstep/node.py rename to spinstep/traversal/node.py diff --git a/tests/test_control.py b/tests/test_control.py new file mode 100644 index 0000000..123e866 --- /dev/null +++ b/tests/test_control.py @@ -0,0 +1,515 @@ +# test_control.py — SpinStep Test Suite — MIT License +# Tests for spinstep.control subpackage (state, controllers, trajectory) + +import numpy as np +import pytest +from scipy.spatial.transform import Rotation as R + +from spinstep.control.state import ( + ControlCommand, + OrientationState, + compute_orientation_error, + integrate_state, +) +from spinstep.control.controllers import ( + PIDOrientationController, + ProportionalOrientationController, +) +from spinstep.control.trajectory import ( + OrientationTrajectory, + TrajectoryController, + TrajectoryInterpolator, +) + + +# ===== OrientationState ===== + + +class TestOrientationState: + def test_defaults(self): + state = OrientationState() + assert np.allclose(state.quaternion, [0, 0, 0, 1]) + assert state.distance == 0.0 + assert np.allclose(state.angular_velocity, [0, 0, 0]) + assert state.radial_velocity == 0.0 + assert state.timestamp == 0.0 + + def test_with_distance(self): + state = OrientationState([0, 0, 0, 1], distance=5.0) + assert state.distance == 5.0 + + def test_normalizes_quaternion(self): + state = OrientationState([0, 0, 0, 2]) + assert np.allclose(state.quaternion, [0, 0, 0, 1]) + + def test_full_state(self): + state = OrientationState( + [0, 0, 0, 1], + distance=10.0, + angular_velocity=[0, 0, 1.0], + radial_velocity=0.5, + timestamp=1.0, + ) + assert state.distance == 10.0 + assert state.radial_velocity == 0.5 + assert state.timestamp == 1.0 + + def test_invalid_quaternion_shape(self): + with pytest.raises(ValueError, match="shape"): + OrientationState([1, 0, 0]) + + def test_zero_quaternion(self): + with pytest.raises(ValueError, match="non-zero"): + OrientationState([0, 0, 0, 0]) + + def test_negative_distance(self): + with pytest.raises(ValueError, match="non-negative"): + OrientationState([0, 0, 0, 1], distance=-1.0) + + def test_repr(self): + state = OrientationState([0, 0, 0, 1], distance=5.0) + r = repr(state) + assert "OrientationState" in r + assert "d=5.0" in r + + +# ===== ControlCommand ===== + + +class TestControlCommand: + def test_defaults(self): + cmd = ControlCommand() + assert np.allclose(cmd.angular_velocity, [0, 0, 0]) + assert cmd.radial_velocity == 0.0 + + def test_custom(self): + cmd = ControlCommand([1, 0, 0], radial_velocity=2.5) + assert np.allclose(cmd.angular_velocity, [1, 0, 0]) + assert cmd.radial_velocity == 2.5 + + def test_repr(self): + cmd = ControlCommand([0, 0, 1.0], radial_velocity=0.5) + r = repr(cmd) + assert "ControlCommand" in r + + +# ===== integrate_state ===== + + +class TestIntegrateState: + def test_stationary(self): + state = OrientationState([0, 0, 0, 1], distance=5.0) + new = integrate_state(state, dt=0.1) + assert np.allclose(new.quaternion, [0, 0, 0, 1]) + assert new.distance == 5.0 + assert new.timestamp == pytest.approx(0.1) + + def test_radial_integration(self): + state = OrientationState( + [0, 0, 0, 1], distance=5.0, radial_velocity=2.0 + ) + new = integrate_state(state, dt=1.0) + assert new.distance == pytest.approx(7.0) + + def test_radial_clamp_to_zero(self): + """Distance cannot go negative.""" + state = OrientationState( + [0, 0, 0, 1], distance=1.0, radial_velocity=-5.0 + ) + new = integrate_state(state, dt=1.0) + assert new.distance == 0.0 + + def test_angular_integration(self): + """Rotating about Z axis.""" + omega = np.pi # rad/s → 180° per second + state = OrientationState( + [0, 0, 0, 1], angular_velocity=[0, 0, omega] + ) + new = integrate_state(state, dt=0.5) + # After 0.5s at π rad/s → 90° rotation about Z + angle = R.from_quat(new.quaternion).magnitude() + assert angle == pytest.approx(np.pi / 2, abs=0.01) + + def test_combined_integration(self): + state = OrientationState( + [0, 0, 0, 1], + distance=5.0, + angular_velocity=[0, 0, 1.0], + radial_velocity=0.5, + timestamp=1.0, + ) + new = integrate_state(state, dt=0.1) + assert new.timestamp == pytest.approx(1.1) + assert new.distance == pytest.approx(5.05) + assert not np.allclose(new.quaternion, [0, 0, 0, 1]) + + def test_invalid_dt(self): + state = OrientationState() + with pytest.raises(ValueError): + integrate_state(state, dt=0) + with pytest.raises(ValueError): + integrate_state(state, dt=-0.1) + + +# ===== compute_orientation_error ===== + + +class TestComputeOrientationError: + def test_no_error(self): + ang, rad = compute_orientation_error( + [0, 0, 0, 1], [0, 0, 0, 1], 5.0, 5.0 + ) + assert np.allclose(ang, [0, 0, 0], atol=1e-6) + assert rad == pytest.approx(0.0) + + def test_angular_error_only(self): + q_target = R.from_euler("z", 45, degrees=True).as_quat() + ang, rad = compute_orientation_error( + [0, 0, 0, 1], q_target, 5.0, 5.0 + ) + assert np.linalg.norm(ang) == pytest.approx(np.deg2rad(45), abs=0.01) + assert rad == pytest.approx(0.0) + + def test_radial_error_only(self): + ang, rad = compute_orientation_error( + [0, 0, 0, 1], [0, 0, 0, 1], 3.0, 7.0 + ) + assert np.allclose(ang, [0, 0, 0], atol=1e-6) + assert rad == pytest.approx(4.0) + + def test_combined_error(self): + q_target = R.from_euler("z", 90, degrees=True).as_quat() + ang, rad = compute_orientation_error( + [0, 0, 0, 1], q_target, 2.0, 8.0 + ) + assert np.linalg.norm(ang) == pytest.approx(np.pi / 2, abs=0.01) + assert rad == pytest.approx(6.0) + + def test_backward_compat_no_distance(self): + """Works without distance arguments (defaults to 0).""" + ang, rad = compute_orientation_error([0, 0, 0, 1], [0, 0, 0, 1]) + assert np.allclose(ang, [0, 0, 0], atol=1e-6) + assert rad == pytest.approx(0.0) + + +# ===== ProportionalOrientationController ===== + + +class TestProportionalController: + def test_zero_error(self): + ctrl = ProportionalOrientationController(kp=2.0) + cmd = ctrl.update([0, 0, 0, 1], [0, 0, 0, 1], dt=0.01) + assert np.allclose(cmd.angular_velocity, [0, 0, 0], atol=1e-6) + assert cmd.radial_velocity == pytest.approx(0.0, abs=1e-6) + + def test_angular_error(self): + ctrl = ProportionalOrientationController(kp=1.0) + q_target = R.from_euler("z", 45, degrees=True).as_quat() + cmd = ctrl.update([0, 0, 0, 1], q_target, dt=0.01) + assert np.linalg.norm(cmd.angular_velocity) > 0 + + def test_radial_error(self): + ctrl = ProportionalOrientationController(kp=1.0, kp_radial=2.0) + cmd = ctrl.update( + [0, 0, 0, 1], [0, 0, 0, 1], dt=0.01, + current_distance=3.0, target_distance=5.0, + ) + assert cmd.radial_velocity == pytest.approx(4.0) # 2.0 × 2.0 + + def test_velocity_limit(self): + ctrl = ProportionalOrientationController( + kp=100.0, max_angular_velocity=1.0 + ) + q_target = R.from_euler("z", 90, degrees=True).as_quat() + cmd = ctrl.update([0, 0, 0, 1], q_target, dt=0.01) + assert np.linalg.norm(cmd.angular_velocity) <= 1.0 + 1e-6 + + def test_radial_velocity_limit(self): + ctrl = ProportionalOrientationController( + kp_radial=100.0, max_radial_velocity=2.0 + ) + cmd = ctrl.update( + [0, 0, 0, 1], [0, 0, 0, 1], dt=0.01, + current_distance=0.0, target_distance=10.0, + ) + assert abs(cmd.radial_velocity) <= 2.0 + 1e-6 + + def test_invalid_dt(self): + ctrl = ProportionalOrientationController() + with pytest.raises(ValueError): + ctrl.update([0, 0, 0, 1], [0, 0, 0, 1], dt=0) + + def test_reset(self): + ctrl = ProportionalOrientationController(kp=1.0) + ctrl.update([0, 0, 0, 1], [0, 0, 0, 1], dt=0.01) + ctrl.reset() + assert ctrl._prev_angular_cmd is None + + +# ===== PIDOrientationController ===== + + +class TestPIDController: + def test_zero_error(self): + ctrl = PIDOrientationController(kp=1.0, ki=0.1, kd=0.5) + cmd = ctrl.update([0, 0, 0, 1], [0, 0, 0, 1], dt=0.01) + assert np.allclose(cmd.angular_velocity, [0, 0, 0], atol=1e-6) + + def test_integral_accumulates(self): + ctrl = PIDOrientationController(kp=0.0, ki=1.0, kd=0.0) + q_target = R.from_euler("z", 10, degrees=True).as_quat() + # First step — integral starts accumulating + ctrl.update([0, 0, 0, 1], q_target, dt=0.1) + # Second step — integral grows + cmd = ctrl.update([0, 0, 0, 1], q_target, dt=0.1) + assert np.linalg.norm(cmd.angular_velocity) > 0 + + def test_derivative_term(self): + ctrl = PIDOrientationController(kp=0.0, ki=0.0, kd=1.0) + q1 = R.from_euler("z", 10, degrees=True).as_quat() + q2 = R.from_euler("z", 20, degrees=True).as_quat() + # First call sets _prev_error + ctrl.update([0, 0, 0, 1], q1, dt=0.01) + # Second call with different target should produce d_term + cmd = ctrl.update([0, 0, 0, 1], q2, dt=0.01) + assert np.linalg.norm(cmd.angular_velocity) > 0 + + def test_radial_pid(self): + ctrl = PIDOrientationController( + kp=0.0, ki=0.0, kd=0.0, + kp_radial=2.0, ki_radial=0.5, kd_radial=0.0, + ) + cmd = ctrl.update( + [0, 0, 0, 1], [0, 0, 0, 1], dt=0.1, + current_distance=3.0, target_distance=5.0, + ) + assert cmd.radial_velocity == pytest.approx(2.0 * 2.0 + 0.5 * 2.0 * 0.1) + + def test_reset_clears_state(self): + ctrl = PIDOrientationController(kp=1.0, ki=1.0, kd=1.0) + ctrl.update([0, 0, 0, 1], [0, 0, 0, 1], dt=0.01) + ctrl.reset() + assert np.allclose(ctrl._ang_integral, [0, 0, 0]) + assert ctrl._rad_integral == 0.0 + assert ctrl._prev_ang_error is None + assert ctrl._prev_rad_error is None + + def test_anti_windup(self): + """Integral should be clamped.""" + ctrl = PIDOrientationController(kp=0.0, ki=1.0, max_integral=0.01) + q_target = R.from_euler("z", 90, degrees=True).as_quat() + for _ in range(100): + ctrl.update([0, 0, 0, 1], q_target, dt=1.0) + assert np.linalg.norm(ctrl._ang_integral) <= 0.01 + 1e-6 + + +# ===== OrientationTrajectory ===== + + +class TestOrientationTrajectory: + def test_basic_3_tuple(self): + traj = OrientationTrajectory([ + ([0, 0, 0, 1], 5.0, 0.0), + ([0, 0, 0.383, 0.924], 10.0, 1.0), + ]) + assert len(traj) == 2 + assert traj.duration == pytest.approx(1.0) + assert np.allclose(traj.distances, [5.0, 10.0]) + + def test_basic_2_tuple_backward_compat(self): + traj = OrientationTrajectory([ + ([0, 0, 0, 1], 0.0), + ([0, 0, 0.383, 0.924], 1.0), + ]) + assert len(traj) == 2 + assert np.allclose(traj.distances, [0.0, 0.0]) + + def test_mixed_tuple_lengths(self): + traj = OrientationTrajectory([ + ([0, 0, 0, 1], 5.0, 0.0), + ([0, 0, 0.383, 0.924], 1.0), # 2-tuple, distance=0 + ]) + assert traj.distances[0] == 5.0 + assert traj.distances[1] == 0.0 + + def test_too_few_waypoints(self): + with pytest.raises(ValueError, match="At least 2"): + OrientationTrajectory([([0, 0, 0, 1], 0.0)]) + + def test_non_increasing_times(self): + with pytest.raises(ValueError, match="strictly increasing"): + OrientationTrajectory([ + ([0, 0, 0, 1], 5.0, 1.0), + ([0, 0, 0, 1], 5.0, 0.5), + ]) + + def test_properties(self): + traj = OrientationTrajectory([ + ([0, 0, 0, 1], 5.0, 1.0), + ([0, 0, 0, 1], 10.0, 3.0), + ]) + assert traj.start_time == 1.0 + assert traj.end_time == 3.0 + assert traj.duration == 2.0 + + def test_repr(self): + traj = OrientationTrajectory([ + ([0, 0, 0, 1], 5.0, 0.0), + ([0, 0, 0, 1], 10.0, 1.0), + ]) + assert "OrientationTrajectory" in repr(traj) + + +# ===== TrajectoryInterpolator ===== + + +class TestTrajectoryInterpolator: + def test_at_waypoints(self): + q0 = [0, 0, 0, 1] + q1 = R.from_euler("z", 90, degrees=True).as_quat().tolist() + traj = OrientationTrajectory([ + (q0, 5.0, 0.0), + (q1, 10.0, 1.0), + ]) + interp = TrajectoryInterpolator(traj) + q_start, d_start = interp.evaluate(0.0) + q_end, d_end = interp.evaluate(1.0) + assert np.allclose(q_start, q0, atol=1e-6) + assert d_start == pytest.approx(5.0) + assert d_end == pytest.approx(10.0) + + def test_midpoint_distance(self): + traj = OrientationTrajectory([ + ([0, 0, 0, 1], 4.0, 0.0), + ([0, 0, 0, 1], 8.0, 1.0), + ]) + interp = TrajectoryInterpolator(traj) + _, d_mid = interp.evaluate(0.5) + assert d_mid == pytest.approx(6.0) + + def test_before_start(self): + traj = OrientationTrajectory([ + ([0, 0, 0, 1], 5.0, 1.0), + ([0, 0, 0, 1], 10.0, 2.0), + ]) + interp = TrajectoryInterpolator(traj) + q, d = interp.evaluate(0.0) + assert d == pytest.approx(5.0) + + def test_after_end(self): + traj = OrientationTrajectory([ + ([0, 0, 0, 1], 5.0, 0.0), + ([0, 0, 0, 1], 10.0, 1.0), + ]) + interp = TrajectoryInterpolator(traj) + q, d = interp.evaluate(99.0) + assert d == pytest.approx(10.0) + + +# ===== TrajectoryController ===== + + +class TestTrajectoryController: + def test_basic_tracking(self): + q0 = [0, 0, 0, 1] + q1 = R.from_euler("z", 45, degrees=True).as_quat().tolist() + traj = OrientationTrajectory([ + (q0, 5.0, 0.0), + (q1, 10.0, 1.0), + ]) + ctrl = ProportionalOrientationController(kp=1.0, kp_radial=1.0) + tc = TrajectoryController(ctrl, traj) + + cmd = tc.update(q0, t=0.5, dt=0.01, current_distance=7.0) + # At t=0.5 target distance ≈ 7.5, so radial_velocity > 0 + assert cmd.radial_velocity > 0 + assert not tc.is_complete + + def test_complete_flag(self): + traj = OrientationTrajectory([ + ([0, 0, 0, 1], 5.0, 0.0), + ([0, 0, 0, 1], 10.0, 1.0), + ]) + ctrl = ProportionalOrientationController() + tc = TrajectoryController(ctrl, traj) + tc.update([0, 0, 0, 1], t=1.5, dt=0.01, current_distance=10.0) + assert tc.is_complete + + def test_reset(self): + traj = OrientationTrajectory([ + ([0, 0, 0, 1], 5.0, 0.0), + ([0, 0, 0, 1], 10.0, 1.0), + ]) + ctrl = ProportionalOrientationController() + tc = TrajectoryController(ctrl, traj) + tc.update([0, 0, 0, 1], t=2.0, dt=0.01) + assert tc.is_complete + tc.reset() + assert not tc.is_complete + + +# ===== Integration: full control loop ===== + + +class TestControlLoop: + def test_converges_to_target(self): + """P controller should drive orientation toward target over many steps.""" + ctrl = ProportionalOrientationController(kp=5.0, kp_radial=3.0) + current_q = np.array([0, 0, 0, 1.0]) + target_q = R.from_euler("z", 45, degrees=True).as_quat() + current_dist = 5.0 + target_dist = 10.0 + dt = 0.01 + + for _ in range(500): + cmd = ctrl.update( + current_q, target_q, dt, + current_distance=current_dist, target_distance=target_dist, + ) + state = OrientationState( + current_q, distance=current_dist, + angular_velocity=cmd.angular_velocity, + radial_velocity=cmd.radial_velocity, + ) + new_state = integrate_state(state, dt) + current_q = new_state.quaternion + current_dist = new_state.distance + + # Should be close to target + ang_err, rad_err = compute_orientation_error( + current_q, target_q, current_dist, target_dist + ) + assert np.linalg.norm(ang_err) < 0.01 + assert abs(rad_err) < 0.1 + + def test_trajectory_tracking_loop(self): + """TrajectoryController should follow a trajectory.""" + q0 = [0, 0, 0, 1] + q1 = R.from_euler("z", 90, degrees=True).as_quat().tolist() + traj = OrientationTrajectory([ + (q0, 5.0, 0.0), + (q1, 10.0, 2.0), + ]) + ctrl = ProportionalOrientationController(kp=5.0, kp_radial=3.0) + tc = TrajectoryController(ctrl, traj) + + current_q = np.array(q0, dtype=float) + current_dist = 5.0 + dt = 0.01 + t = 0.0 + + for _ in range(200): + cmd = tc.update(current_q, t=t, dt=dt, current_distance=current_dist) + state = OrientationState( + current_q, distance=current_dist, + angular_velocity=cmd.angular_velocity, + radial_velocity=cmd.radial_velocity, + ) + new_state = integrate_state(state, dt) + current_q = new_state.quaternion + current_dist = new_state.distance + t += dt + + # At t=2.0 should be near q1 / distance=10 + # At t=2.0 we're only at t=2.0 so check trajectory makes progress + assert current_dist > 5.0 # moved outward diff --git a/tests/test_discrete_traversal.py b/tests/test_discrete_traversal.py index f55df45..40a775a 100644 --- a/tests/test_discrete_traversal.py +++ b/tests/test_discrete_traversal.py @@ -14,8 +14,8 @@ from scipy.spatial.transform import Rotation as R # Import the modules under test -from spinstep.discrete import DiscreteOrientationSet -from spinstep.discrete_iterator import DiscreteQuaternionIterator +from spinstep.traversal.discrete import DiscreteOrientationSet +from spinstep.traversal.discrete_iterator import DiscreteQuaternionIterator # Simple node class for testing class Node: diff --git a/tests/test_math.py b/tests/test_math.py new file mode 100644 index 0000000..7a611dc --- /dev/null +++ b/tests/test_math.py @@ -0,0 +1,191 @@ +# test_math.py — SpinStep Test Suite — MIT License +# Tests for spinstep.math subpackage + +import numpy as np +import pytest +from scipy.spatial.transform import Rotation as R + +from spinstep.math.core import ( + quaternion_inverse, + quaternion_multiply, + quaternion_normalize, +) +from spinstep.math.interpolation import slerp, squad +from spinstep.math.geometry import quaternion_distance +from spinstep.math.conversions import ( + quaternion_from_rotvec, + quaternion_to_rotvec, +) +from spinstep.math.analysis import angular_velocity_from_quaternions +from spinstep.math.constraints import clamp_rotation_angle + + +# ===== core ===== + + +class TestQuaternionNormalize: + def test_unit_quaternion(self): + q = quaternion_normalize([0.5, 0.5, 0.5, 0.5]) + assert np.linalg.norm(q) == pytest.approx(1.0, abs=1e-10) + + def test_zero_quaternion(self): + q = quaternion_normalize([0, 0, 0, 0]) + assert np.allclose(q, [0, 0, 0, 1]) + + def test_non_unit(self): + q = quaternion_normalize([0, 0, 0, 2]) + assert np.allclose(q, [0, 0, 0, 1]) + + +class TestQuaternionInverse: + def test_unit_quaternion_inverse(self): + q = np.array([0.5, 0.5, 0.5, 0.5]) + inv = quaternion_inverse(q) + product = quaternion_multiply(q, inv) + # Should be close to identity [0, 0, 0, 1] + assert np.allclose(product, [0, 0, 0, 1], atol=1e-10) or np.allclose( + product, [0, 0, 0, -1], atol=1e-10 + ) + + def test_zero_quaternion(self): + inv = quaternion_inverse([0, 0, 0, 0]) + assert np.allclose(inv, [0, 0, 0, 1]) + + +# ===== interpolation ===== + + +class TestSlerp: + def test_endpoints(self): + q0 = np.array([0, 0, 0, 1.0]) + q1 = np.array([0, 0, np.sin(np.pi / 4), np.cos(np.pi / 4)]) + assert np.allclose(slerp(q0, q1, 0.0), q0, atol=1e-6) + assert np.allclose(slerp(q0, q1, 1.0), q1 / np.linalg.norm(q1), atol=1e-6) + + def test_midpoint_unit(self): + q0 = np.array([0, 0, 0, 1.0]) + q1 = np.array([0, 0, np.sin(np.pi / 4), np.cos(np.pi / 4)]) + mid = slerp(q0, q1, 0.5) + assert np.linalg.norm(mid) == pytest.approx(1.0, abs=1e-10) + + def test_same_quaternion(self): + q = np.array([0, 0, 0, 1.0]) + result = slerp(q, q, 0.5) + assert np.allclose(result, q, atol=1e-6) + + def test_shortest_path(self): + """SLERP takes the shortest arc even when quaternions are antipodal representatives.""" + q0 = np.array([0, 0, 0, 1.0]) + q1 = np.array([0, 0, 0, -1.0]) # same rotation, opposite sign + mid = slerp(q0, q1, 0.5) + assert np.linalg.norm(mid) == pytest.approx(1.0, abs=1e-6) + + def test_interpolation_angle(self): + """Midpoint of 0° and 90° should be ~45°.""" + q0 = np.array([0, 0, 0, 1.0]) + q1 = R.from_euler("z", 90, degrees=True).as_quat() + mid = slerp(q0, q1, 0.5) + angle = R.from_quat(mid).magnitude() + assert angle == pytest.approx(np.deg2rad(45), abs=0.01) + + +class TestSquad: + def test_returns_unit_quaternion(self): + q0 = R.from_euler("z", 0, degrees=True).as_quat() + q1 = R.from_euler("z", 30, degrees=True).as_quat() + q2 = R.from_euler("z", 60, degrees=True).as_quat() + q3 = R.from_euler("z", 90, degrees=True).as_quat() + result = squad(q0, q1, q2, q3, 0.5) + assert np.linalg.norm(result) == pytest.approx(1.0, abs=1e-6) + + def test_endpoints(self): + q0 = R.from_euler("z", 0, degrees=True).as_quat() + q1 = R.from_euler("z", 30, degrees=True).as_quat() + q2 = R.from_euler("z", 60, degrees=True).as_quat() + q3 = R.from_euler("z", 90, degrees=True).as_quat() + start = squad(q0, q1, q2, q3, 0.0) + end = squad(q0, q1, q2, q3, 1.0) + # At t=0 should be close to q1, at t=1 close to q2 + angle_start = quaternion_distance(start, q1) + angle_end = quaternion_distance(end, q2) + assert angle_start < 0.1 + assert angle_end < 0.1 + + +# ===== analysis ===== + + +class TestAngularVelocityFromQuaternions: + def test_no_rotation(self): + q = [0, 0, 0, 1] + omega = angular_velocity_from_quaternions(q, q, dt=0.1) + assert np.allclose(omega, [0, 0, 0], atol=1e-6) + + def test_known_rotation(self): + """90° about Z in 1 second → ω ≈ [0, 0, π/2].""" + q1 = [0, 0, 0, 1] + q2 = R.from_euler("z", 90, degrees=True).as_quat() + omega = angular_velocity_from_quaternions(q1, q2, dt=1.0) + assert np.allclose(omega, [0, 0, np.pi / 2], atol=1e-6) + + def test_invalid_dt(self): + with pytest.raises(ValueError): + angular_velocity_from_quaternions([0, 0, 0, 1], [0, 0, 0, 1], dt=0) + + +# ===== constraints ===== + + +class TestClampRotationAngle: + def test_within_limit(self): + """Small rotation should not be changed.""" + q = R.from_euler("z", 10, degrees=True).as_quat() + clamped = clamp_rotation_angle(q, max_angle=np.pi) + assert np.allclose(clamped, q, atol=1e-6) + + def test_clamped(self): + """Large rotation should be clamped to max_angle.""" + q = R.from_euler("z", 90, degrees=True).as_quat() + max_angle = np.deg2rad(45) + clamped = clamp_rotation_angle(q, max_angle=max_angle) + angle = R.from_quat(clamped).magnitude() + assert angle == pytest.approx(max_angle, abs=1e-6) + + def test_preserves_axis(self): + """Clamping preserves the rotation axis.""" + q = R.from_euler("z", 90, degrees=True).as_quat() + max_angle = np.deg2rad(45) + clamped = clamp_rotation_angle(q, max_angle=max_angle) + original_axis = R.from_quat(q).as_rotvec() + clamped_axis = R.from_quat(clamped).as_rotvec() + # Axes should be parallel + original_dir = original_axis / np.linalg.norm(original_axis) + clamped_dir = clamped_axis / np.linalg.norm(clamped_axis) + assert np.allclose(original_dir, clamped_dir, atol=1e-6) + + def test_negative_max_angle_raises(self): + with pytest.raises(ValueError): + clamp_rotation_angle([0, 0, 0, 1], max_angle=-1.0) + + +# ===== conversions ===== + + +class TestQuaternionFromRotvec: + def test_identity(self): + q = quaternion_from_rotvec([0, 0, 0]) + assert np.allclose(q, [0, 0, 0, 1], atol=1e-6) + + def test_90_about_z(self): + rotvec = [0, 0, np.pi / 2] + q = quaternion_from_rotvec(rotvec) + expected = R.from_rotvec(rotvec).as_quat() + assert np.allclose(q, expected, atol=1e-6) + + +class TestQuaternionToRotvec: + def test_roundtrip(self): + rotvec = [0.3, -0.5, 0.7] + q = quaternion_from_rotvec(rotvec) + result = quaternion_to_rotvec(q) + assert np.allclose(result, rotvec, atol=1e-6) diff --git a/tests/test_spinstep.py b/tests/test_spinstep.py index ee5b828..affc816 100644 --- a/tests/test_spinstep.py +++ b/tests/test_spinstep.py @@ -11,10 +11,10 @@ except ImportError: HAS_CUPY = False -from spinstep.discrete import DiscreteOrientationSet +from spinstep.traversal.discrete import DiscreteOrientationSet # If you have continuous traversal classes, import them here -# from spinstep.continuous import QuaternionDepthIterator -from spinstep.node import Node +# from spinstep.traversal.continuous import QuaternionDepthIterator +from spinstep.traversal.node import Node @pytest.fixture def simple_tree(): diff --git a/tests/test_traversal.py b/tests/test_traversal.py index db8053e..c7fd841 100644 --- a/tests/test_traversal.py +++ b/tests/test_traversal.py @@ -6,8 +6,8 @@ import numpy as np from scipy.spatial.transform import Rotation as R -from spinstep.node import Node -from spinstep.traversal import QuaternionDepthIterator +from spinstep.traversal.node import Node +from spinstep.traversal.continuous import QuaternionDepthIterator class TestQuaternionDepthIterator: diff --git a/tests/test_utils.py b/tests/test_utils.py index 6e6ce90..21effca 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -6,7 +6,7 @@ import pytest from scipy.spatial.transform import Rotation as R -from spinstep.node import Node +from spinstep.traversal.node import Node from spinstep.utils.array_backend import get_array_module from spinstep.utils.quaternion_math import batch_quaternion_angle from spinstep.utils.quaternion_utils import ( From 087d8b99357022381b86ca1a4fa89720a03b0a2e Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 10 Apr 2026 17:43:11 +0000 Subject: [PATCH 3/5] Quaternion library audit: resolve utils/math duplication, add API tests, improve packaging - Replace utils/quaternion_utils.py with re-exports from spinstep.math - Replace utils/quaternion_math.py with re-export from spinstep.math.analysis - Update utils/__init__.py with deprecation documentation - Add test_api.py with comprehensive API stability tests (92 new tests) - Bump version to 0.5.0a0 - Update pyproject.toml: add Typing::Typed classifier, Changelog/Docs URLs, py.typed package-data, mypy overrides for third-party stubs - Fix lint issues in test_discrete_traversal.py, test_spinstep.py, test_traversal.py Agent-Logs-Url: https://github.com/VoxleOne/SpinStep/sessions/378ceebc-cb82-487b-8c8a-abb3c5806f7c Co-authored-by: VoxleOne <119956342+VoxleOne@users.noreply.github.com> --- pyproject.toml | 17 +- spinstep/__init__.py | 2 +- spinstep/utils/__init__.py | 16 +- spinstep/utils/quaternion_math.py | 35 +--- spinstep/utils/quaternion_utils.py | 234 +++----------------------- tests/test_api.py | 253 +++++++++++++++++++++++++++++ tests/test_discrete_traversal.py | 7 +- tests/test_spinstep.py | 2 +- tests/test_traversal.py | 1 - 9 files changed, 320 insertions(+), 247 deletions(-) create mode 100644 tests/test_api.py diff --git a/pyproject.toml b/pyproject.toml index 87b50ba..b972696 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "spinstep" -version = "0.4.0a0" +version = "0.5.0a0" description = "Quaternion-based orientation control and traversal in observer-centered spherical coordinates." authors = [{ name = "Eraldo Marques", email = "eraldo.bernardo@gmail.com" }] readme = "README.md" @@ -20,6 +20,7 @@ classifiers = [ "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", "Topic :: Scientific/Engineering :: Mathematics", + "Typing :: Typed", ] dependencies = [ "numpy>=1.22", @@ -43,11 +44,16 @@ dev = [ [project.urls] Repository = "https://github.com/VoxleOne/SpinStep" +Changelog = "https://github.com/VoxleOne/SpinStep/blob/main/CHANGELOG.md" +Documentation = "https://github.com/VoxleOne/SpinStep/tree/main/docs" [tool.setuptools.packages.find] include = ["spinstep*"] exclude = ["benchmark*", "demos*", "examples*", "tests*", "docs*"] +[tool.setuptools.package-data] +spinstep = ["py.typed"] + [tool.pytest.ini_options] testpaths = ["tests"] @@ -57,5 +63,14 @@ line-length = 88 [tool.mypy] strict = true +[[tool.mypy.overrides]] +module = [ + "scipy.*", + "sklearn.*", + "healpy.*", + "cupy.*", +] +ignore_missing_imports = true + [tool.ruff] line-length = 88 diff --git a/spinstep/__init__.py b/spinstep/__init__.py index 709e057..2936b73 100644 --- a/spinstep/__init__.py +++ b/spinstep/__init__.py @@ -30,7 +30,7 @@ from spinstep.math import quaternion_multiply, quaternion_distance, slerp """ -__version__ = "0.4.0a0" +__version__ = "0.5.0a0" # --- control layer (primary API) --- from .control.state import ( diff --git a/spinstep/utils/__init__.py b/spinstep/utils/__init__.py index 052def5..9ff0f8e 100644 --- a/spinstep/utils/__init__.py +++ b/spinstep/utils/__init__.py @@ -2,14 +2,18 @@ # Author: Eraldo B. Marques — Created: 2025-05-14 # See LICENSE.txt for full terms. This header must be retained in redistributions. -"""Utilities for quaternion math and array backend selection. +"""Utilities: array backend selection and backward-compatible quaternion re-exports. -This sub-package provides: +The ``get_array_module`` function is the primary utility provided here. +All quaternion math functions have moved to :mod:`spinstep.math` and are +re-exported here only for backward compatibility. -- :func:`~.array_backend.get_array_module` — NumPy / CuPy backend selection. -- :func:`~.quaternion_math.batch_quaternion_angle` — batch angular distances. -- Quaternion helpers in :mod:`~.quaternion_utils` (conversion, distance, - multiplication, etc.). +.. deprecated:: + For quaternion operations, import from :mod:`spinstep.math` instead. + +Example (preferred):: + + from spinstep.math import quaternion_multiply, quaternion_distance """ __all__ = [ diff --git a/spinstep/utils/quaternion_math.py b/spinstep/utils/quaternion_math.py index fbba4bf..347e42d 100644 --- a/spinstep/utils/quaternion_math.py +++ b/spinstep/utils/quaternion_math.py @@ -2,34 +2,17 @@ # Author: Eraldo B. Marques — Created: 2025-05-14 # See LICENSE.txt for full terms. This header must be retained in redistributions. -"""Batch quaternion angular distance computation.""" +"""Backward-compatible re-export of batch quaternion angle computation. -from __future__ import annotations - -from types import ModuleType -from typing import Any - -__all__ = ["batch_quaternion_angle"] +The implementation has moved to :mod:`spinstep.math.analysis`. +.. deprecated:: + Import from :mod:`spinstep.math` instead. +""" -def batch_quaternion_angle(qs1: Any, qs2: Any, xp: ModuleType) -> Any: - """Compute pairwise angular distances between two sets of quaternions. +from __future__ import annotations - Parameters - ---------- - qs1: - Array of shape ``(N, 4)`` — first set of quaternions. - qs2: - Array of shape ``(M, 4)`` — second set of quaternions. - xp: - Array module (:mod:`numpy` or :mod:`cupy`). +__all__ = ["batch_quaternion_angle"] - Returns - ------- - array - ``(N, M)`` array of angular distances in radians. - """ - dots = xp.abs(xp.dot(qs1, qs2.T)) - dots = xp.clip(dots, -1.0, 1.0) - angles = 2 * xp.arccos(dots) - return angles +# Re-export from canonical location +from spinstep.math.analysis import batch_quaternion_angle diff --git a/spinstep/utils/quaternion_utils.py b/spinstep/utils/quaternion_utils.py index 1d01aaa..d0bf24e 100644 --- a/spinstep/utils/quaternion_utils.py +++ b/spinstep/utils/quaternion_utils.py @@ -2,7 +2,15 @@ # Author: Eraldo Marques — Created: 2025-05-14 # See LICENSE.txt for full terms. This header must be retained in redistributions. -"""Quaternion conversion, distance, and manipulation utilities.""" +"""Backward-compatible re-exports of quaternion utilities. + +All quaternion functions have moved to :mod:`spinstep.math`. This module +re-exports them so that existing ``from spinstep.utils.quaternion_utils import …`` +statements continue to work. + +.. deprecated:: + Import from :mod:`spinstep.math` instead. +""" from __future__ import annotations @@ -21,209 +29,21 @@ "angle_between_directions", ] -from typing import List, Sequence - -import numpy as np -from numpy.typing import ArrayLike -from scipy.spatial.transform import Rotation as R - - -def quaternion_from_euler( - angles: Sequence[float], - order: str = "zyx", - degrees: bool = True, -) -> np.ndarray: - """Convert Euler angles to a quaternion ``[x, y, z, w]``.""" - return R.from_euler(order, angles, degrees=degrees).as_quat() - - -def quaternion_distance(q1: ArrayLike, q2: ArrayLike) -> float: - """Return the angular distance (radians) between two quaternions.""" - r1 = R.from_quat(q1) - r2 = R.from_quat(q2) - return float((r1.inv() * r2).magnitude()) - - -def rotate_quaternion(q: ArrayLike, rotation_step: ArrayLike) -> np.ndarray: - """Apply *rotation_step* to quaternion *q* and return the result.""" - r1 = R.from_quat(q) - step = R.from_quat(rotation_step) - return (r1 * step).as_quat() - - -def is_within_angle_threshold( - q_current: ArrayLike, - q_target: ArrayLike, - threshold_rad: float, -) -> bool: - """Check whether two quaternions are within *threshold_rad* of each other.""" - return quaternion_distance(q_current, q_target) < threshold_rad - - -def quaternion_conjugate(q: ArrayLike) -> np.ndarray: - """Return the conjugate of quaternion *q* ``[x, y, z, w]``.""" - return np.array([-q[0], -q[1], -q[2], q[3]]) - - -def quaternion_multiply(q1: ArrayLike, q2: ArrayLike) -> np.ndarray: - """Hamilton product of two quaternions in ``[x, y, z, w]`` order.""" - x1, y1, z1, w1 = q1 - x2, y2, z2, w2 = q2 - return np.array([ - w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2, - w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2, - w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2, - w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2, - ]) - - -def rotation_matrix_to_quaternion(m: ArrayLike) -> np.ndarray: - """Convert a 3×3 rotation matrix to a unit quaternion ``[x, y, z, w]``.""" - t = np.trace(m) - if t > 0: - s = np.sqrt(t + 1) * 2 - qw = 0.25 * s - qx = (m[2, 1] - m[1, 2]) / s - qy = (m[0, 2] - m[2, 0]) / s - qz = (m[1, 0] - m[0, 1]) / s - elif (m[0, 0] > m[1, 1]) and (m[0, 0] > m[2, 2]): - s = np.sqrt(1 + m[0, 0] - m[1, 1] - m[2, 2]) * 2 - qx = 0.25 * s - qw = (m[2, 1] - m[1, 2]) / s - qy = (m[0, 1] + m[1, 0]) / s - qz = (m[0, 2] + m[2, 0]) / s - elif m[1, 1] > m[2, 2]: - s = np.sqrt(1 + m[1, 1] - m[0, 0] - m[2, 2]) * 2 - qy = 0.25 * s - qw = (m[0, 2] - m[2, 0]) / s - qx = (m[0, 1] + m[1, 0]) / s - qz = (m[1, 2] + m[2, 1]) / s - else: - s = np.sqrt(1 + m[2, 2] - m[0, 0] - m[1, 1]) * 2 - qz = 0.25 * s - qw = (m[1, 0] - m[0, 1]) / s - qx = (m[0, 2] + m[2, 0]) / s - qy = (m[1, 2] + m[2, 1]) / s - q = np.array([qx, qy, qz, qw]) - n = np.linalg.norm(q) - return q / n if n > 1e-8 else np.array([0.0, 0.0, 0.0, 1.0]) - - -def forward_vector_from_quaternion(q: ArrayLike) -> np.ndarray: - """Extract the forward (look) direction from a quaternion. - - The forward direction is defined as ``[0, 0, -1]`` rotated by the - quaternion, following the convention where negative-Z is "forward". - - Args: - q: Quaternion ``[x, y, z, w]``. - - Returns: - Unit direction vector ``(3,)`` pointing forward. - """ - return R.from_quat(q).apply([0, 0, -1]) - - -def direction_to_quaternion(direction: ArrayLike) -> np.ndarray: - """Convert a 3D direction vector to an orientation quaternion. - - The returned quaternion represents the rotation that aligns the - default forward axis ``[0, 0, -1]`` with the given *direction*. - - Args: - direction: Target direction vector (does not need to be normalised). - - Returns: - Unit quaternion ``[x, y, z, w]``. - """ - d = np.asarray(direction, dtype=float) - norm = np.linalg.norm(d) - if norm < 1e-8: - return np.array([0.0, 0.0, 0.0, 1.0]) - d = d / norm - rot, _ = R.align_vectors([d], [[0, 0, -1]]) - return rot.as_quat() - - -def angle_between_directions(d1: ArrayLike, d2: ArrayLike) -> float: - """Compute the angular distance (radians) between two direction vectors. - - Args: - d1: First direction vector. - d2: Second direction vector. - - Returns: - Angle in radians in the range ``[0, π]``. - """ - v1 = np.asarray(d1, dtype=float) - v2 = np.asarray(d2, dtype=float) - n1 = np.linalg.norm(v1) - n2 = np.linalg.norm(v2) - if n1 < 1e-8 or n2 < 1e-8: - return 0.0 - cos_angle = np.dot(v1 / n1, v2 / n2) - cos_angle = np.clip(cos_angle, -1.0, 1.0) - return float(np.arccos(cos_angle)) - - -def get_relative_spin(nf: object, nt: object) -> np.ndarray: - """Return the relative quaternion rotation from node *nf* to node *nt*. - - Both nodes must have an ``.orientation`` attribute storing a quaternion - ``[x, y, z, w]``. - """ - qfc = quaternion_conjugate(nf.orientation) # type: ignore[union-attr] - qr = quaternion_multiply(qfc, nt.orientation) # type: ignore[union-attr] - n = np.linalg.norm(qr) - return qr / n if n > 1e-8 else np.array([0.0, 0.0, 0.0, 1.0]) - - -def get_unique_relative_spins( - nodes: Sequence[object], - nside: int, - nest: bool, - threshold: float = 1e-3, -) -> List[np.ndarray]: - """Compute unique relative rotations between HEALPix neighbours. - - Requires the ``healpy`` package. - - Parameters - ---------- - nodes: - Sequence of node objects with ``.orientation`` attributes. - nside: - HEALPix *nside* parameter. - nest: - Whether to use the NESTED pixel ordering. - threshold: - Angular threshold (radians) for considering two rotations identical. - """ - try: - import healpy as hp - except ImportError: - raise ImportError( - "healpy is required for get_unique_relative_spins(). " - "Install it with: pip install healpy" - ) - spins: List[np.ndarray] = [] - NPIX = hp.nside2npix(nside) - for i in range(NPIX): - nf = nodes[i] - nidx = hp.get_all_neighbours(nside, i, nest=nest) - for idx in nidx: - if idx != -1: - q = get_relative_spin(nf, nodes[idx]) - if q[3] < 0: - q = -q # Canonical form (w >= 0) - is_uniq = True - for s_q in spins: - dot = np.abs(np.dot(q, s_q)) - dot = np.clip(dot, -1, 1) - angle = 2 * np.arccos(dot) - if angle < threshold: - is_uniq = False - break - if is_uniq: - spins.append(q) - return spins +# Re-export from canonical locations in spinstep.math +from spinstep.math.core import quaternion_conjugate, quaternion_multiply +from spinstep.math.geometry import ( + angle_between_directions, + direction_to_quaternion, + forward_vector_from_quaternion, + is_within_angle_threshold, + quaternion_distance, + rotate_quaternion, +) +from spinstep.math.conversions import ( + quaternion_from_euler, + rotation_matrix_to_quaternion, +) +from spinstep.math.analysis import ( + get_relative_spin, + get_unique_relative_spins, +) diff --git a/tests/test_api.py b/tests/test_api.py new file mode 100644 index 0000000..f02540c --- /dev/null +++ b/tests/test_api.py @@ -0,0 +1,253 @@ +# test_api.py — SpinStep Test Suite — MIT License +# Author: SpinStep Contributors +# See LICENSE.txt for full terms. + +"""Tests for public API stability, exports, and packaging markers.""" + +import importlib +import re + +import pytest + + +class TestPackageMetadata: + """Verify package-level metadata and markers.""" + + def test_version_format(self) -> None: + """Version string follows PEP 440.""" + import spinstep + + assert hasattr(spinstep, "__version__") + # PEP 440: N.N.NaN, N.N.N, etc. + assert re.match( + r"^\d+\.\d+\.\d+(a\d+|b\d+|rc\d+)?$", spinstep.__version__ + ), f"Invalid version format: {spinstep.__version__}" + + def test_py_typed_marker(self) -> None: + """PEP 561 py.typed marker exists.""" + import pathlib + + import spinstep + + pkg_dir = pathlib.Path(spinstep.__file__).parent + assert (pkg_dir / "py.typed").exists(), "Missing py.typed marker" + + +class TestTopLevelExports: + """Verify spinstep.__all__ contains expected symbols.""" + + EXPECTED_EXPORTS = [ + # control + "OrientationState", + "ControlCommand", + "integrate_state", + "compute_orientation_error", + "OrientationController", + "ProportionalOrientationController", + "PIDOrientationController", + "OrientationTrajectory", + "TrajectoryInterpolator", + "TrajectoryController", + # math + "slerp", + # traversal + "Node", + "QuaternionDepthIterator", + "DiscreteOrientationSet", + "DiscreteQuaternionIterator", + ] + + def test_all_defined(self) -> None: + import spinstep + + assert hasattr(spinstep, "__all__") + assert isinstance(spinstep.__all__, list) + + @pytest.mark.parametrize("name", EXPECTED_EXPORTS) + def test_export_importable(self, name: str) -> None: + """Every name in __all__ is importable from the top-level package.""" + import spinstep + + assert name in spinstep.__all__, f"{name} missing from __all__" + assert hasattr(spinstep, name), f"{name} not accessible on spinstep" + + +class TestMathSubpackageExports: + """Verify spinstep.math.__all__ contains expected symbols.""" + + EXPECTED_EXPORTS = [ + "quaternion_multiply", + "quaternion_conjugate", + "quaternion_normalize", + "quaternion_inverse", + "slerp", + "squad", + "quaternion_distance", + "is_within_angle_threshold", + "forward_vector_from_quaternion", + "direction_to_quaternion", + "angle_between_directions", + "rotate_quaternion", + "quaternion_from_euler", + "rotation_matrix_to_quaternion", + "quaternion_from_rotvec", + "quaternion_to_rotvec", + "batch_quaternion_angle", + "angular_velocity_from_quaternions", + "get_relative_spin", + "get_unique_relative_spins", + "clamp_rotation_angle", + ] + + def test_all_defined(self) -> None: + from spinstep import math + + assert hasattr(math, "__all__") + + @pytest.mark.parametrize("name", EXPECTED_EXPORTS) + def test_export_importable(self, name: str) -> None: + from spinstep import math + + assert name in math.__all__, f"{name} missing from math.__all__" + assert hasattr(math, name), f"{name} not accessible on spinstep.math" + + +class TestControlSubpackageExports: + """Verify spinstep.control.__all__ contains expected symbols.""" + + EXPECTED_EXPORTS = [ + "OrientationState", + "ControlCommand", + "integrate_state", + "compute_orientation_error", + "OrientationController", + "ProportionalOrientationController", + "PIDOrientationController", + "OrientationTrajectory", + "TrajectoryInterpolator", + "TrajectoryController", + ] + + def test_all_defined(self) -> None: + from spinstep import control + + assert hasattr(control, "__all__") + + @pytest.mark.parametrize("name", EXPECTED_EXPORTS) + def test_export_importable(self, name: str) -> None: + from spinstep import control + + assert name in control.__all__, f"{name} missing from control.__all__" + assert hasattr(control, name), f"{name} not accessible on spinstep.control" + + +class TestTraversalSubpackageExports: + """Verify spinstep.traversal.__all__ contains expected symbols.""" + + EXPECTED_EXPORTS = [ + "Node", + "QuaternionDepthIterator", + "DiscreteOrientationSet", + "DiscreteQuaternionIterator", + ] + + def test_all_defined(self) -> None: + from spinstep import traversal + + assert hasattr(traversal, "__all__") + + @pytest.mark.parametrize("name", EXPECTED_EXPORTS) + def test_export_importable(self, name: str) -> None: + from spinstep import traversal + + assert name in traversal.__all__, f"{name} missing from traversal.__all__" + assert hasattr( + traversal, name + ), f"{name} not accessible on spinstep.traversal" + + +class TestUtilsBackwardCompat: + """Verify that utils/ re-exports still work for backward compatibility.""" + + COMPAT_NAMES = [ + "get_array_module", + "batch_quaternion_angle", + "quaternion_from_euler", + "quaternion_distance", + "rotate_quaternion", + "is_within_angle_threshold", + "quaternion_conjugate", + "quaternion_multiply", + "rotation_matrix_to_quaternion", + "get_relative_spin", + "get_unique_relative_spins", + "forward_vector_from_quaternion", + "direction_to_quaternion", + "angle_between_directions", + ] + + @pytest.mark.parametrize("name", COMPAT_NAMES) + def test_utils_reexport(self, name: str) -> None: + """Functions remain importable from spinstep.utils.""" + from spinstep import utils + + assert hasattr(utils, name), f"{name} not accessible on spinstep.utils" + + def test_utils_quaternion_functions_are_math_functions(self) -> None: + """Verify that utils re-exports point to the same objects as math.""" + from spinstep import math as sp_math + from spinstep import utils as sp_utils + + shared = [ + "quaternion_multiply", + "quaternion_conjugate", + "quaternion_distance", + "quaternion_from_euler", + "rotation_matrix_to_quaternion", + "rotate_quaternion", + "is_within_angle_threshold", + "forward_vector_from_quaternion", + "direction_to_quaternion", + "angle_between_directions", + "batch_quaternion_angle", + "get_relative_spin", + "get_unique_relative_spins", + ] + for name in shared: + assert getattr(sp_utils, name) is getattr(sp_math, name), ( + f"utils.{name} is not the same object as math.{name}" + ) + + +class TestSubpackagesImportable: + """Verify all subpackages can be imported.""" + + @pytest.mark.parametrize( + "module", + [ + "spinstep", + "spinstep.math", + "spinstep.math.core", + "spinstep.math.geometry", + "spinstep.math.conversions", + "spinstep.math.interpolation", + "spinstep.math.analysis", + "spinstep.math.constraints", + "spinstep.control", + "spinstep.control.state", + "spinstep.control.controllers", + "spinstep.control.trajectory", + "spinstep.traversal", + "spinstep.traversal.node", + "spinstep.traversal.continuous", + "spinstep.traversal.discrete", + "spinstep.traversal.discrete_iterator", + "spinstep.utils", + "spinstep.utils.array_backend", + "spinstep.utils.quaternion_math", + "spinstep.utils.quaternion_utils", + ], + ) + def test_importable(self, module: str) -> None: + """Every listed module can be imported without error.""" + importlib.import_module(module) diff --git a/tests/test_discrete_traversal.py b/tests/test_discrete_traversal.py index 40a775a..9cd9872 100644 --- a/tests/test_discrete_traversal.py +++ b/tests/test_discrete_traversal.py @@ -11,7 +11,6 @@ cuda_available = False import numpy as np -from scipy.spatial.transform import Rotation as R # Import the modules under test from spinstep.traversal.discrete import DiscreteOrientationSet @@ -106,7 +105,7 @@ def test_cuda_support(self): try: cuda_set = DiscreteOrientationSet([[0, 0, 0, 1]], use_cuda=True) # If this succeeds, basic CUDA import worked - assert cuda_set.use_cuda == True + assert cuda_set.use_cuda is True # Test as_numpy() method for GPU->CPU transfer cpu_array = cuda_set.as_numpy() @@ -261,8 +260,8 @@ def test_full_pipeline(): # Create a simple tree root = Node("root", [0, 0, 0, 1]) - node_a = root.add_child(Node("A", [1, 0, 0, 0])) - node_b = root.add_child(Node("B", [0, 1, 0, 0])) + root.add_child(Node("A", [1, 0, 0, 0])) + root.add_child(Node("B", [0, 1, 0, 0])) # Create iterator and traverse iterator = DiscreteQuaternionIterator( diff --git a/tests/test_spinstep.py b/tests/test_spinstep.py index affc816..f2e61b5 100644 --- a/tests/test_spinstep.py +++ b/tests/test_spinstep.py @@ -6,7 +6,7 @@ import numpy as np try: - import cupy as cp + import cupy as cp # noqa: F401 HAS_CUPY = True except ImportError: HAS_CUPY = False diff --git a/tests/test_traversal.py b/tests/test_traversal.py index c7fd841..b8b42a5 100644 --- a/tests/test_traversal.py +++ b/tests/test_traversal.py @@ -4,7 +4,6 @@ import pytest import numpy as np -from scipy.spatial.transform import Rotation as R from spinstep.traversal.node import Node from spinstep.traversal.continuous import QuaternionDepthIterator From 15c9523e6b88074bd9b5361c6e0538db91838e2d Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 10 Apr 2026 18:22:17 +0000 Subject: [PATCH 4/5] Add NodeProtocol typing, Node.add_child(), update CI triggers and CHANGELOG - Add NodeProtocol (typing.Protocol) in spinstep.math.analysis, replacing `object` type hints in get_relative_spin() and get_unique_relative_spins() - Add Node.add_child() convenience method - Update CI workflow to trigger on feature/* branches - Update CHANGELOG.md with 0.5.0a0 release notes - Add tests for NodeProtocol and Node.add_child() - 228 tests pass, 5 skipped Agent-Logs-Url: https://github.com/VoxleOne/SpinStep/sessions/b6207769-cd62-43d2-83dd-af4b5a36fdba Co-authored-by: VoxleOne <119956342+VoxleOne@users.noreply.github.com> --- .github/workflows/ci.yml | 4 ++-- CHANGELOG.md | 49 ++++++++++++++++++++++++++++++++++++++ spinstep/math/__init__.py | 3 +++ spinstep/math/analysis.py | 21 ++++++++++++---- spinstep/traversal/node.py | 15 ++++++++++-- tests/test_api.py | 46 +++++++++++++++++++++++++++++++++++ 6 files changed, 129 insertions(+), 9 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index fca6f8f..1930765 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -2,9 +2,9 @@ name: SpinStep CI on: push: - branches: [main, dev, 'features/cuda'] + branches: [main, dev, 'features/cuda', 'feature/*'] pull_request: - branches: [main, dev] + branches: [main, dev, 'feature/*'] permissions: contents: read diff --git a/CHANGELOG.md b/CHANGELOG.md index a5b1689..1105c75 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,55 @@ and uses [Semantic Versioning](https://semver.org/). --- +## [0.5.0a0] – 2026-04-10 + +### Added +- **Control subsystem** (`spinstep.control`): + - `OrientationState` — observer-centered spherical state (quaternion + distance) + - `ControlCommand` — angular + radial velocity commands + - `ProportionalOrientationController` — P-controller with rate limiting + - `PIDOrientationController` — PID controller with anti-windup + - `OrientationTrajectory`, `TrajectoryInterpolator`, `TrajectoryController` — + waypoint-based trajectory tracking with SLERP interpolation + - `integrate_state()`, `compute_orientation_error()` — state utilities +- **Math library** (`spinstep.math`): + - `core` — `quaternion_multiply`, `quaternion_conjugate`, `quaternion_normalize`, + `quaternion_inverse` + - `interpolation` — `slerp` (spherical linear), `squad` (spherical cubic) + - `geometry` — `quaternion_distance`, `rotate_quaternion`, + `forward_vector_from_quaternion`, `direction_to_quaternion`, + `angle_between_directions`, `is_within_angle_threshold` + - `conversions` — `quaternion_from_euler`, `rotation_matrix_to_quaternion`, + `quaternion_from_rotvec`, `quaternion_to_rotvec` + - `analysis` — `batch_quaternion_angle`, `angular_velocity_from_quaternions`, + `get_relative_spin`, `get_unique_relative_spins`, `NodeProtocol` + - `constraints` — `clamp_rotation_angle` +- `NodeProtocol` — `typing.Protocol` for any object with `.orientation` attribute, + replacing `object` type hints in `get_relative_spin()` and + `get_unique_relative_spins()` +- `Node.add_child()` method for ergonomic tree building +- Comprehensive API stability tests (`test_api.py`) — 92+ parametrized tests + covering all subpackage exports and backward compatibility +- `Typing :: Typed` classifier in `pyproject.toml` +- `Changelog` and `Documentation` project URLs +- `[[tool.mypy.overrides]]` for `scipy.*`, `sklearn.*`, `healpy.*`, `cupy.*` +- CI workflow now triggers on `feature/*` branches + +### Changed +- **Traversal classes moved** to `spinstep.traversal` subpackage + (`node.py`, `continuous.py`, `discrete.py`, `discrete_iterator.py`). + Top-level imports via `from spinstep import Node` remain backward-compatible. +- **`spinstep.utils` is now a backward-compatible re-export layer**. + All 13 quaternion functions now delegate to `spinstep.math` — the single + source of truth. Direct `from spinstep.utils import …` continues to work. +- Version bumped to `0.5.0a0` + +### Deprecated +- Direct imports from `spinstep.utils.quaternion_utils` and + `spinstep.utils.quaternion_math`. Use `spinstep.math` instead. + +--- + ## [0.3.0a0] – 2026-03-26 ### Added diff --git a/spinstep/math/__init__.py b/spinstep/math/__init__.py index a4b30f2..88661c5 100644 --- a/spinstep/math/__init__.py +++ b/spinstep/math/__init__.py @@ -42,6 +42,8 @@ "get_unique_relative_spins", # constraints "clamp_rotation_angle", + # protocols + "NodeProtocol", ] from .core import ( @@ -66,6 +68,7 @@ rotation_matrix_to_quaternion, ) from .analysis import ( + NodeProtocol, angular_velocity_from_quaternions, batch_quaternion_angle, get_relative_spin, diff --git a/spinstep/math/analysis.py b/spinstep/math/analysis.py index cd573db..05e7ea6 100644 --- a/spinstep/math/analysis.py +++ b/spinstep/math/analysis.py @@ -14,7 +14,7 @@ ] from types import ModuleType -from typing import Any, List, Sequence +from typing import Any, List, Protocol, Sequence, runtime_checkable import numpy as np from numpy.typing import ArrayLike @@ -22,6 +22,17 @@ from .core import quaternion_conjugate, quaternion_multiply +@runtime_checkable +class NodeProtocol(Protocol): + """Structural type for objects accepted by :func:`get_relative_spin`. + + Any object with an ``orientation`` attribute holding a quaternion + ``[x, y, z, w]`` satisfies this protocol. + """ + + orientation: np.ndarray + + def batch_quaternion_angle(qs1: Any, qs2: Any, xp: ModuleType) -> Any: """Compute pairwise angular distances between two sets of quaternions. @@ -69,7 +80,7 @@ def angular_velocity_from_quaternions( return rotvec / dt -def get_relative_spin(nf: object, nt: object) -> np.ndarray: +def get_relative_spin(nf: NodeProtocol, nt: NodeProtocol) -> np.ndarray: """Return the relative quaternion rotation from node *nf* to node *nt*. Both nodes must have an ``.orientation`` attribute storing a quaternion @@ -82,14 +93,14 @@ def get_relative_spin(nf: object, nt: object) -> np.ndarray: Returns: Unit quaternion representing the relative rotation. """ - qfc = quaternion_conjugate(nf.orientation) # type: ignore[union-attr] - qr = quaternion_multiply(qfc, nt.orientation) # type: ignore[union-attr] + qfc = quaternion_conjugate(nf.orientation) + qr = quaternion_multiply(qfc, nt.orientation) n = np.linalg.norm(qr) return qr / n if n > 1e-8 else np.array([0.0, 0.0, 0.0, 1.0]) def get_unique_relative_spins( - nodes: Sequence[object], + nodes: Sequence[NodeProtocol], nside: int, nest: bool, threshold: float = 1e-3, diff --git a/spinstep/traversal/node.py b/spinstep/traversal/node.py index d7d9e20..b7d1c78 100644 --- a/spinstep/traversal/node.py +++ b/spinstep/traversal/node.py @@ -40,8 +40,7 @@ class Node: from spinstep import Node root = Node("root", [0, 0, 0, 1]) - child = Node("child", [0.2588, 0, 0, 0.9659]) - root.children.append(child) + child = root.add_child(Node("child", [0.2588, 0, 0, 0.9659])) """ name: str @@ -64,5 +63,17 @@ def __init__( self.name = name self.children = list(children) if children else [] + def add_child(self, child: "Node") -> "Node": + """Append *child* to this node's children and return it. + + Args: + child: The child node to add. + + Returns: + The same *child* node, for convenience. + """ + self.children.append(child) + return child + def __repr__(self) -> str: return f"Node({self.name!r}, orientation={self.orientation.tolist()})" diff --git a/tests/test_api.py b/tests/test_api.py index f02540c..940ff16 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -97,6 +97,7 @@ class TestMathSubpackageExports: "get_relative_spin", "get_unique_relative_spins", "clamp_rotation_angle", + "NodeProtocol", ] def test_all_defined(self) -> None: @@ -251,3 +252,48 @@ class TestSubpackagesImportable: def test_importable(self, module: str) -> None: """Every listed module can be imported without error.""" importlib.import_module(module) + + +class TestNodeProtocol: + """Verify NodeProtocol structural typing works correctly.""" + + def test_node_satisfies_protocol(self) -> None: + """spinstep.traversal.Node satisfies NodeProtocol.""" + from spinstep.math.analysis import NodeProtocol + from spinstep.traversal.node import Node + + node = Node("test", [0, 0, 0, 1]) + assert isinstance(node, NodeProtocol) + + def test_custom_class_satisfies_protocol(self) -> None: + """Any class with .orientation satisfies NodeProtocol.""" + import numpy as np + + from spinstep.math.analysis import NodeProtocol + + class MyNode: + def __init__(self) -> None: + self.orientation = np.array([0.0, 0.0, 0.0, 1.0]) + + assert isinstance(MyNode(), NodeProtocol) + + +class TestNodeAddChild: + """Verify Node.add_child() convenience method.""" + + def test_add_child_appends(self) -> None: + from spinstep.traversal.node import Node + + root = Node("root", [0, 0, 0, 1]) + child = Node("child", [1, 0, 0, 0]) + result = root.add_child(child) + assert child in root.children + assert result is child + + def test_add_child_returns_child(self) -> None: + from spinstep.traversal.node import Node + + root = Node("root", [0, 0, 0, 1]) + child = root.add_child(Node("child", [0, 1, 0, 0])) + assert child.name == "child" + assert len(root.children) == 1 From c8c603389c72760686b8c18e7aaee0828025d1c2 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 10 Apr 2026 18:23:42 +0000 Subject: [PATCH 5/5] =?UTF-8?q?Fix=20CHANGELOG=20test=20count:=2092+=20?= =?UTF-8?q?=E2=86=92=2097=20(actual=20count)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Agent-Logs-Url: https://github.com/VoxleOne/SpinStep/sessions/b6207769-cd62-43d2-83dd-af4b5a36fdba Co-authored-by: VoxleOne <119956342+VoxleOne@users.noreply.github.com> --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1105c75..f4719ce 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -38,7 +38,7 @@ and uses [Semantic Versioning](https://semver.org/). replacing `object` type hints in `get_relative_spin()` and `get_unique_relative_spins()` - `Node.add_child()` method for ergonomic tree building -- Comprehensive API stability tests (`test_api.py`) — 92+ parametrized tests +- Comprehensive API stability tests (`test_api.py`) — 97 parametrized tests covering all subpackage exports and backward compatibility - `Typing :: Typed` classifier in `pyproject.toml` - `Changelog` and `Documentation` project URLs