diff --git a/config/imu.json5 b/config/imu.json5 new file mode 100644 index 000000000..abcf06c2e --- /dev/null +++ b/config/imu.json5 @@ -0,0 +1,52 @@ +{ + version: "v1.0.0", + hertz: 10, + name: "imu_robot", + api_key: "${OM_API_KEY:-openmind_free}", + system_prompt_base: "You are a robot with an IMU sensor. Monitor your orientation and movement. If you detect a fall or impact, immediately execute the appropriate recovery action. Stay safe and protect yourself and those around you.", + system_prompt_examples: "Here are some examples of interactions you might encounter:\n\n1. If the robot falls, you might:\n fall_recovery: {{'action': 'stand_up'}}\n Speak: {{'sentence': 'I fell down, trying to stand up!'}}\n\n2. If an impact is detected, you might:\n fall_recovery: {{'action': 'emergency_stop'}}\n Speak: {{'sentence': 'Impact detected, stopping for safety.'}}\n\n3. If the robot is operating normally, you might:\n Speak: {{'sentence': 'All systems normal, monitoring orientation.'}}", + system_governance: "Here are the laws that govern your actions. Do not violate these laws.\nFirst Law: A robot cannot harm a human or allow a human to come to harm.\nSecond Law: A robot must obey orders from humans, unless those orders conflict with the First Law.\nThird Law: A robot must protect itself, as long as that protection doesn't conflict with the First or Second Law.", + agent_inputs: [ + { + type: "IMUInput", + config: { + port: "/dev/ttyUSB0", + baudrate: 115200, + timeout: 1.0, + fall_threshold: 45.0, + impact_threshold: 20.0, + poll_interval: 0.1, + }, + }, + ], + agent_backgrounds: [ + { + type: "IMUFallDetector", + }, + ], + cortex_llm: { + type: "OpenAILLM", + config: { + agent_name: "IMURobot", + history_length: 5, + }, + }, + agent_actions: [ + { + name: "fall_recovery", + llm_label: "fall_recovery", + connector: "serial", + config: { + port: "/dev/ttyUSB0", + baudrate: 115200, + timeout: 2.0, + }, + }, + { + name: "speak", + llm_label: "speak", + connector: "elevenlabs_tts", + config: {}, + }, + ], +} diff --git a/src/actions/fall_recovery/__init__.py b/src/actions/fall_recovery/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/actions/fall_recovery/connector/__init__.py b/src/actions/fall_recovery/connector/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/actions/fall_recovery/connector/serial.py b/src/actions/fall_recovery/connector/serial.py new file mode 100644 index 000000000..77aa74194 --- /dev/null +++ b/src/actions/fall_recovery/connector/serial.py @@ -0,0 +1,134 @@ +import json +import logging +from typing import Optional + +import serial as _pyserial +from pydantic import Field + +from actions.base import ActionConfig, ActionConnector +from actions.fall_recovery.interface import FallRecoveryAction, FallRecoveryInput +from providers.imu_provider import IMUProvider + + +class FallRecoverySerialConfig(ActionConfig): + """ + Configuration for Fall Recovery Serial connector. + + Parameters + ---------- + port : str + Serial port for the robot controller. + baudrate : int + Serial communication baudrate. + timeout : float + Serial write timeout in seconds. + """ + + port: str = Field( + default="/dev/ttyUSB0", + description="Serial port for robot controller", + ) + baudrate: int = Field( + default=115200, + description="Serial communication baudrate", + ) + timeout: float = Field( + default=2.0, + description="Serial write timeout in seconds", + ) + + +class FallRecoverySerialConnector( + ActionConnector[FallRecoverySerialConfig, FallRecoveryInput] +): + """ + Serial connector for fall recovery actions. + + Sends recovery commands to robot controller via serial port. + Compatible with Arduino-based controllers or any serial-capable + robot platform. + """ + + def __init__(self, config: FallRecoverySerialConfig): + """ + Initialize the FallRecoverySerialConnector. + + Parameters + ---------- + config : FallRecoverySerialConfig + Configuration for the connector. + """ + super().__init__(config) + + self.ser: Optional[_pyserial.Serial] = None + self.imu_provider = IMUProvider() + + try: + self.ser = _pyserial.Serial( + config.port, config.baudrate, timeout=config.timeout + ) + logging.info(f"FallRecoverySerialConnector: connected to {config.port}") + except Exception as e: + logging.error( + f"FallRecoverySerialConnector: failed to open serial port - {e}" + ) + + def _send_command(self, command: dict) -> bool: + """ + Send a JSON command via serial port. + + Parameters + ---------- + command : dict + Command dictionary to serialize and send. + + Returns + ------- + bool + True if sent successfully, False otherwise. + """ + if self.ser is None: + logging.error("FallRecoverySerialConnector: serial port not available") + return False + + try: + payload = json.dumps(command) + "\n" + self.ser.write(payload.encode("utf-8")) + logging.info(f"FallRecoverySerialConnector: sent command={command}") + return True + except Exception as e: + logging.error(f"FallRecoverySerialConnector: error sending command - {e}") + return False + + async def connect(self, output_interface: FallRecoveryInput) -> None: + """ + Execute a fall recovery action. + + Parameters + ---------- + output_interface : FallRecoveryInput + Input containing the recovery action to perform. + """ + action = output_interface.action + message = output_interface.message + + logging.info( + f"FallRecoverySerialConnector: executing action={action.value} " + f"message='{message}'" + ) + + if action == FallRecoveryAction.STAND_UP: + self._send_command({"action": "stand_up", "message": message}) + self.imu_provider.reset_alerts() + + elif action == FallRecoveryAction.EMERGENCY_STOP: + self._send_command({"action": "emergency_stop", "message": message}) + + elif action == FallRecoveryAction.ALERT_OPERATOR: + logging.warning(f"FallRecoverySerialConnector: operator alert - {message}") + self._send_command({"action": "alert_operator", "message": message}) + + else: + logging.warning( + f"FallRecoverySerialConnector: unknown action '{action.value}'" + ) diff --git a/src/actions/fall_recovery/interface.py b/src/actions/fall_recovery/interface.py new file mode 100644 index 000000000..c87bbb377 --- /dev/null +++ b/src/actions/fall_recovery/interface.py @@ -0,0 +1,43 @@ +from dataclasses import dataclass +from enum import Enum + +from actions.base import Interface + + +class FallRecoveryAction(str, Enum): + """Supported fall recovery actions.""" + + STAND_UP = "stand_up" + EMERGENCY_STOP = "emergency_stop" + ALERT_OPERATOR = "alert_operator" + + +@dataclass +class FallRecoveryInput: + """ + Input interface for the FallRecovery action. + + Parameters + ---------- + action : FallRecoveryAction + The recovery action to perform. + message : str + Optional message describing the situation. + """ + + action: FallRecoveryAction = FallRecoveryAction.STAND_UP + message: str = "" + + +@dataclass +class FallRecovery(Interface[FallRecoveryInput, FallRecoveryInput]): + """ + This action allows the robot to recover from a fall or impact event. + + Effect: Executes fall recovery procedures including standing up, + emergency stop, or alerting the operator. Triggered automatically + by IMU fall detection or manually via LLM decision. + """ + + input: FallRecoveryInput + output: FallRecoveryInput diff --git a/src/backgrounds/plugins/imu_fall_detector.py b/src/backgrounds/plugins/imu_fall_detector.py new file mode 100644 index 000000000..aa32fb93c --- /dev/null +++ b/src/backgrounds/plugins/imu_fall_detector.py @@ -0,0 +1,73 @@ +import logging +import threading + +from backgrounds.base import Background, BackgroundConfig +from providers.context_provider import ContextProvider +from providers.imu_provider import IMUProvider + + +class IMUFallDetector(Background[BackgroundConfig]): + """ + Background task that continuously monitors IMU data for fall + and impact events, updating the context provider when detected. + """ + + def __init__(self, config: BackgroundConfig): + """ + Initialize the IMUFallDetector background task. + + Parameters + ---------- + config : BackgroundConfig + Configuration for the background task. + """ + super().__init__(config) + + self._lock = threading.Lock() + self.imu_provider = IMUProvider() + self.context_provider = ContextProvider() + + self._fall_reported: bool = False + self._impact_reported: bool = False + + logging.info("IMUFallDetector background task initialized.") + + def run(self) -> None: + """ + Monitor IMU state and update context on fall or impact detection. + """ + state = self.imu_provider.state + + with self._lock: + if state["is_fallen"] and not self._fall_reported: + logging.warning("IMUFallDetector: fall detected, updating context.") + self.context_provider.update_context( + { + "imu_fall_detected": True, + "imu_roll": state["roll"], + "imu_pitch": state["pitch"], + } + ) + self._fall_reported = True + + elif not state["is_fallen"] and self._fall_reported: + logging.info("IMUFallDetector: fall resolved, resetting context.") + self.context_provider.update_context({"imu_fall_detected": False}) + self._fall_reported = False + + if state["impact_detected"] and not self._impact_reported: + logging.warning("IMUFallDetector: impact detected, updating context.") + self.context_provider.update_context({"imu_impact_detected": True}) + self._impact_reported = True + + elif not state["impact_detected"] and self._impact_reported: + self.context_provider.update_context({"imu_impact_detected": False}) + self._impact_reported = False + + self.sleep(0.1) + + def stop(self) -> None: + """ + Stop the IMUFallDetector background task. + """ + logging.info("Stopping IMUFallDetector background task.") diff --git a/src/inputs/plugins/imu_input.py b/src/inputs/plugins/imu_input.py new file mode 100644 index 000000000..6777faf72 --- /dev/null +++ b/src/inputs/plugins/imu_input.py @@ -0,0 +1,215 @@ +import asyncio +import json +import logging +import time +from typing import Optional + +import serial +from pydantic import Field + +from inputs.base import Message, SensorConfig +from inputs.base.loop import FuserInput +from providers.imu_provider import IMUProvider +from providers.io_provider import IOProvider + + +class IMUConfig(SensorConfig): + """Configuration for the IMU input plugin.""" + + port: str = Field( + default="/dev/ttyUSB0", + description="Serial port for IMU device (e.g., /dev/ttyUSB0 or COM3)", + ) + baudrate: int = Field( + default=115200, + description="Serial communication baudrate", + ) + timeout: float = Field( + default=1.0, + description="Read timeout in seconds", + ) + fall_threshold: float = Field( + default=45.0, + description="Roll/pitch angle threshold in degrees to detect a fall", + ) + impact_threshold: float = Field( + default=20.0, + description="Acceleration magnitude threshold (m/s^2) to detect an impact", + ) + poll_interval: float = Field( + default=0.1, + description="Polling interval in seconds", + ) + + +class IMUInput(FuserInput[IMUConfig, Optional[dict]]): + """ + IMU sensor input plugin for OM1. + + Reads accelerometer, gyroscope, and orientation data from an IMU + sensor connected via serial port (e.g., MPU6050 or BNO055 with + Arduino/serial bridge). Updates IMUProvider with latest readings + and provides fall/impact context to the LLM. + + Expected serial data format (JSON per line): + {"ax": 0.1, "ay": 0.2, "az": 9.8, + "gx": 0.0, "gy": 0.0, "gz": 0.0, + "roll": 1.2, "pitch": 0.5, "yaw": 90.0} + """ + + def __init__(self, config: IMUConfig): + """ + Initialize the IMU input plugin. + + Parameters + ---------- + config : IMUConfig + Configuration for the IMU sensor. + """ + super().__init__(config) + + self.ser = None + self.io_provider = IOProvider() + self.imu_provider = IMUProvider() + self.messages: list[Message] = [] + self.descriptor_for_LLM = "IMU Sensor (Accelerometer, Gyroscope, Orientation)" + + # Apply thresholds to provider + self.imu_provider.fall_threshold = config.fall_threshold + self.imu_provider.impact_threshold = config.impact_threshold + + try: + self.ser = serial.Serial( + config.port, config.baudrate, timeout=config.timeout + ) + logging.info( + f"IMUInput: connected to {config.port} at {config.baudrate} baud" + ) + except serial.SerialException as e: + logging.error(f"IMUInput: failed to open serial port - {e}") + + async def _poll(self) -> Optional[dict]: + """ + Poll IMU sensor for latest data. + + Returns + ------- + Optional[dict] + Parsed IMU data dictionary or None if unavailable. + """ + await asyncio.sleep(self.config.poll_interval) + + if self.ser is None: + return None + + try: + line = self.ser.readline().decode("utf-8").strip() + if not line: + return None + + data = json.loads(line) + logging.debug(f"IMUInput: raw data={data}") + return data + + except Exception as e: + logging.error(f"IMUInput: error reading data - {e}") + return None + + async def _raw_to_text(self, raw_input: Optional[dict]) -> Optional[Message]: + """ + Convert raw IMU data to human-readable message for LLM. + + Parameters + ---------- + raw_input : Optional[dict] + Raw IMU data dictionary. + + Returns + ------- + Optional[Message] + Timestamped message or None. + """ + if raw_input is None: + return None + + try: + ax = float(raw_input.get("ax", 0.0)) + ay = float(raw_input.get("ay", 0.0)) + az = float(raw_input.get("az", 0.0)) + gx = float(raw_input.get("gx", 0.0)) + gy = float(raw_input.get("gy", 0.0)) + gz = float(raw_input.get("gz", 0.0)) + roll = float(raw_input.get("roll", 0.0)) + pitch = float(raw_input.get("pitch", 0.0)) + yaw = float(raw_input.get("yaw", 0.0)) + + self.imu_provider.update(ax, ay, az, gx, gy, gz, roll, pitch, yaw) + + state = self.imu_provider.state + + if state["is_fallen"]: + message = ( + f"WARNING: Robot has fallen! " + f"Roll={roll:.1f}deg, Pitch={pitch:.1f}deg. " + f"Immediate recovery action required." + ) + elif state["impact_detected"]: + accel_mag = (ax**2 + ay**2 + az**2) ** 0.5 + message = ( + f"WARNING: Impact detected! " + f"Acceleration magnitude={accel_mag:.2f} m/s^2. " + f"Check robot integrity." + ) + else: + message = ( + f"IMU status normal. " + f"Orientation: roll={roll:.1f}deg, pitch={pitch:.1f}deg, " + f"yaw={yaw:.1f}deg." + ) + + return Message(timestamp=time.time(), message=message) + + except Exception as e: + logging.error(f"IMUInput: error processing data - {e}") + return None + + async def raw_to_text(self, raw_input: Optional[dict]): + """ + Update message buffer with processed IMU data. + + Parameters + ---------- + raw_input : Optional[dict] + Raw IMU data to process. + """ + pending_message = await self._raw_to_text(raw_input) + if pending_message is not None: + self.messages.append(pending_message) + + def formatted_latest_buffer(self) -> Optional[str]: + """ + Format and clear the latest buffer contents. + + Returns + ------- + Optional[str] + Formatted string for LLM context, or None if buffer is empty. + """ + if not self.messages: + return None + + latest_message = self.messages[-1] + + result = f""" +INPUT: {self.descriptor_for_LLM} +// START +{latest_message.message} +// END +""" + + self.io_provider.add_input( + self.__class__.__name__, latest_message.message, latest_message.timestamp + ) + self.messages = [] + + return result diff --git a/src/providers/imu_provider.py b/src/providers/imu_provider.py new file mode 100644 index 000000000..add17794d --- /dev/null +++ b/src/providers/imu_provider.py @@ -0,0 +1,142 @@ +import logging +import threading +import time + +from .singleton import singleton + + +@singleton +class IMUProvider: + """ + Singleton provider for IMU (Inertial Measurement Unit) data. + + Stores and distributes accelerometer, gyroscope, and orientation + data to all OM1 components. Also handles fall/impact detection. + """ + + def __init__(self): + """Initialize the IMUProvider.""" + logging.info("Booting IMUProvider") + + self._lock = threading.Lock() + + # Accelerometer (m/s^2) + self.accel_x: float = 0.0 + self.accel_y: float = 0.0 + self.accel_z: float = 0.0 + + # Gyroscope (deg/s) + self.gyro_x: float = 0.0 + self.gyro_y: float = 0.0 + self.gyro_z: float = 0.0 + + # Orientation (degrees) + self.roll: float = 0.0 + self.pitch: float = 0.0 + self.yaw: float = 0.0 + + # Fall/impact detection + self.is_fallen: bool = False + self.impact_detected: bool = False + + # Thresholds + self.fall_threshold: float = 45.0 + self.impact_threshold: float = 20.0 + + # Timestamps + self.last_update: float = 0.0 + + def update( + self, + accel_x: float, + accel_y: float, + accel_z: float, + gyro_x: float, + gyro_y: float, + gyro_z: float, + roll: float, + pitch: float, + yaw: float, + ) -> None: + """ + Update IMU data and evaluate fall/impact detection. + + Parameters + ---------- + accel_x, accel_y, accel_z : float + Accelerometer readings in m/s^2. + gyro_x, gyro_y, gyro_z : float + Gyroscope readings in deg/s. + roll, pitch, yaw : float + Orientation angles in degrees. + """ + with self._lock: + self.accel_x = accel_x + self.accel_y = accel_y + self.accel_z = accel_z + + self.gyro_x = gyro_x + self.gyro_y = gyro_y + self.gyro_z = gyro_z + + self.roll = roll + self.pitch = pitch + self.yaw = yaw + + self.last_update = time.time() + + # Fall detection: robot tilted beyond threshold + self.is_fallen = ( + abs(self.roll) > self.fall_threshold + or abs(self.pitch) > self.fall_threshold + ) + + # Impact detection: sudden acceleration spike + accel_magnitude = (accel_x**2 + accel_y**2 + accel_z**2) ** 0.5 + self.impact_detected = accel_magnitude > self.impact_threshold + + if self.is_fallen: + logging.warning( + f"IMUProvider: fall detected - roll={roll:.1f} pitch={pitch:.1f}" + ) + if self.impact_detected: + logging.warning( + f"IMUProvider: impact detected - magnitude={accel_magnitude:.2f}" + ) + + @property + def state(self) -> dict: + """ + Get current IMU state as dictionary. + + Returns + ------- + dict + Current IMU readings and detection flags. + """ + with self._lock: + return { + "accel_x": self.accel_x, + "accel_y": self.accel_y, + "accel_z": self.accel_z, + "gyro_x": self.gyro_x, + "gyro_y": self.gyro_y, + "gyro_z": self.gyro_z, + "roll": self.roll, + "pitch": self.pitch, + "yaw": self.yaw, + "is_fallen": self.is_fallen, + "impact_detected": self.impact_detected, + "last_update": self.last_update, + } + + def reset_alerts(self) -> None: + """Reset fall and impact detection flags.""" + with self._lock: + self.is_fallen = False + self.impact_detected = False + logging.info("IMUProvider: alerts reset") + + def stop(self) -> None: + """Stop the IMUProvider and clean up resources.""" + logging.info("IMUProvider stopped") diff --git a/tests/actions/fall_recovery/__init__.py b/tests/actions/fall_recovery/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/actions/fall_recovery/connector/__init__.py b/tests/actions/fall_recovery/connector/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/actions/fall_recovery/connector/test_fall_recovery_serial.py b/tests/actions/fall_recovery/connector/test_fall_recovery_serial.py new file mode 100644 index 000000000..5b7347129 --- /dev/null +++ b/tests/actions/fall_recovery/connector/test_fall_recovery_serial.py @@ -0,0 +1,131 @@ +import asyncio +from unittest.mock import MagicMock, patch + +import pytest + +from actions.fall_recovery.connector.serial import ( + FallRecoverySerialConfig, + FallRecoverySerialConnector, +) +from actions.fall_recovery.interface import FallRecoveryAction, FallRecoveryInput +from providers.imu_provider import IMUProvider + + +@pytest.fixture(autouse=True) +def reset_singleton(): + IMUProvider.reset() # type: ignore[attr-defined] + yield + IMUProvider.reset() # type: ignore[attr-defined] + + +@pytest.fixture +def config(): + return FallRecoverySerialConfig( + port="/dev/ttyUSB0", + baudrate=115200, + timeout=2.0, + ) + + +@pytest.fixture +def connector(config): + with patch( + "actions.fall_recovery.connector.serial._pyserial.Serial" + ) as mock_serial: + mock_serial.return_value = MagicMock() + c = FallRecoverySerialConnector(config) + return c + + +def test_init_success(config): + with patch( + "actions.fall_recovery.connector.serial._pyserial.Serial" + ) as mock_serial: + mock_serial.return_value = MagicMock() + c = FallRecoverySerialConnector(config) + assert c.ser is not None + + +def test_init_serial_failure(config): + with patch( + "actions.fall_recovery.connector.serial._pyserial.Serial", + side_effect=Exception("Port not found"), + ): + c = FallRecoverySerialConnector(config) + assert c.ser is None + + +def test_send_command_success(connector): + result = connector._send_command({"action": "stand_up"}) + assert result is True + connector.ser.write.assert_called_once() + + +def test_send_command_no_serial(config): + with patch( + "actions.fall_recovery.connector.serial._pyserial.Serial", + side_effect=Exception("fail"), + ): + c = FallRecoverySerialConnector(config) + result = c._send_command({"action": "stand_up"}) + assert result is False + + +def test_send_command_write_error(connector): + connector.ser.write.side_effect = Exception("write error") + result = connector._send_command({"action": "stand_up"}) + assert result is False + + +def test_connect_stand_up(connector): + output = FallRecoveryInput(action=FallRecoveryAction.STAND_UP, message="Robot fell") + asyncio.get_event_loop().run_until_complete(connector.connect(output)) + connector.ser.write.assert_called_once() + written = connector.ser.write.call_args[0][0].decode("utf-8") + assert "stand_up" in written + + +def test_connect_stand_up_resets_alerts(connector): + IMUProvider().update(15.0, 15.0, 0.0, 0.0, 0.0, 0.0, 50.0, 0.0, 0.0) + assert IMUProvider().state["is_fallen"] is True + output = FallRecoveryInput(action=FallRecoveryAction.STAND_UP, message="") + asyncio.get_event_loop().run_until_complete(connector.connect(output)) + assert IMUProvider().state["is_fallen"] is False + + +def test_connect_emergency_stop(connector): + output = FallRecoveryInput( + action=FallRecoveryAction.EMERGENCY_STOP, message="Critical impact" + ) + asyncio.get_event_loop().run_until_complete(connector.connect(output)) + connector.ser.write.assert_called_once() + written = connector.ser.write.call_args[0][0].decode("utf-8") + assert "emergency_stop" in written + + +def test_connect_alert_operator(connector): + output = FallRecoveryInput( + action=FallRecoveryAction.ALERT_OPERATOR, message="Need help" + ) + asyncio.get_event_loop().run_until_complete(connector.connect(output)) + connector.ser.write.assert_called_once() + written = connector.ser.write.call_args[0][0].decode("utf-8") + assert "alert_operator" in written + + +def test_connect_unknown_action(connector): + output = FallRecoveryInput(action=FallRecoveryAction.STAND_UP, message="") + output.action = MagicMock() + output.action.value = "unknown_action" + asyncio.get_event_loop().run_until_complete(connector.connect(output)) + connector.ser.write.assert_not_called() + + +def test_send_command_json_format(connector): + connector._send_command({"action": "stand_up", "message": "test"}) + written = connector.ser.write.call_args[0][0].decode("utf-8") + import json + + data = json.loads(written.strip()) + assert data["action"] == "stand_up" + assert data["message"] == "test" diff --git a/tests/backgrounds/plugins/test_imu_fall_detector.py b/tests/backgrounds/plugins/test_imu_fall_detector.py new file mode 100644 index 000000000..cebffa2ac --- /dev/null +++ b/tests/backgrounds/plugins/test_imu_fall_detector.py @@ -0,0 +1,111 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from backgrounds.base import BackgroundConfig +from backgrounds.plugins.imu_fall_detector import IMUFallDetector +from providers.imu_provider import IMUProvider + + +@pytest.fixture(autouse=True) +def reset_singleton(): + IMUProvider.reset() # type: ignore[attr-defined] + yield + IMUProvider.reset() # type: ignore[attr-defined] + + +@pytest.fixture +def config(): + return BackgroundConfig() + + +@pytest.fixture +def detector(config): + with patch("backgrounds.plugins.imu_fall_detector.ContextProvider"): + d = IMUFallDetector(config) + d.context_provider = MagicMock() + return d + + +def test_init(detector): + assert detector._fall_reported is False + assert detector._impact_reported is False + + +def test_run_normal_state(detector): + IMUProvider().update(0.0, 0.0, 9.8, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0) + detector.run() + detector.context_provider.update_context.assert_not_called() + + +def test_run_fall_detected(detector): + IMUProvider().update(0.0, 0.0, 9.8, 0.0, 0.0, 0.0, 50.0, 0.0, 0.0) + detector.run() + detector.context_provider.update_context.assert_called_once() + call_args = detector.context_provider.update_context.call_args[0][0] + assert call_args["imu_fall_detected"] is True + assert detector._fall_reported is True + + +def test_run_fall_not_reported_twice(detector): + IMUProvider().update(0.0, 0.0, 9.8, 0.0, 0.0, 0.0, 50.0, 0.0, 0.0) + detector.run() + detector.run() + assert detector.context_provider.update_context.call_count == 1 + + +def test_run_fall_resolved(detector): + IMUProvider().update(0.0, 0.0, 9.8, 0.0, 0.0, 0.0, 50.0, 0.0, 0.0) + detector.run() + detector.context_provider.update_context.reset_mock() + IMUProvider().update(0.0, 0.0, 9.8, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0) + IMUProvider().reset_alerts() + detector.run() + call_args = detector.context_provider.update_context.call_args[0][0] + assert call_args["imu_fall_detected"] is False + assert detector._fall_reported is False + + +def test_run_impact_detected(detector): + IMUProvider().update(15.0, 15.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0) + detector.run() + call_args_list = detector.context_provider.update_context.call_args_list + impact_calls = [c for c in call_args_list if "imu_impact_detected" in c[0][0]] + assert len(impact_calls) == 1 + assert impact_calls[0][0][0]["imu_impact_detected"] is True + assert detector._impact_reported is True + + +def test_run_impact_not_reported_twice(detector): + IMUProvider().update(15.0, 15.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0) + detector.run() + detector.run() + impact_calls = [ + c + for c in detector.context_provider.update_context.call_args_list + if "imu_impact_detected" in c[0][0] + ] + assert len(impact_calls) == 1 + + +def test_run_impact_resolved(detector): + IMUProvider().update(15.0, 15.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0) + detector.run() + detector.context_provider.update_context.reset_mock() + IMUProvider().reset_alerts() + detector.run() + call_args = detector.context_provider.update_context.call_args[0][0] + assert call_args["imu_impact_detected"] is False + assert detector._impact_reported is False + + +def test_stop(detector): + detector.stop() + + +def test_run_fall_and_impact_together(detector): + IMUProvider().update(15.0, 15.0, 0.0, 0.0, 0.0, 0.0, 50.0, 0.0, 0.0) + detector.run() + assert detector._fall_reported is True + assert detector._impact_reported is True + assert detector.context_provider.update_context.call_count == 2 diff --git a/tests/inputs/plugins/test_imu_input.py b/tests/inputs/plugins/test_imu_input.py new file mode 100644 index 000000000..295e5e256 --- /dev/null +++ b/tests/inputs/plugins/test_imu_input.py @@ -0,0 +1,218 @@ +import asyncio +from unittest.mock import MagicMock, patch + +import pytest +import serial + +from inputs.plugins.imu_input import IMUConfig, IMUInput +from providers.imu_provider import IMUProvider + + +@pytest.fixture(autouse=True) +def reset_singleton(): + IMUProvider.reset() # type: ignore[attr-defined] + yield + IMUProvider.reset() # type: ignore[attr-defined] + + +@pytest.fixture +def config(): + return IMUConfig( + port="/dev/ttyUSB0", + baudrate=115200, + timeout=1.0, + fall_threshold=45.0, + impact_threshold=20.0, + poll_interval=0.1, + ) + + +@pytest.fixture +def imu_input(config): + with patch("inputs.plugins.imu_input.serial.Serial") as mock_serial: + mock_serial.return_value = MagicMock() + plugin = IMUInput(config) + return plugin + + +def test_init_success(config): + with patch("inputs.plugins.imu_input.serial.Serial") as mock_serial: + mock_serial.return_value = MagicMock() + plugin = IMUInput(config) + assert plugin.ser is not None + assert ( + plugin.descriptor_for_LLM + == "IMU Sensor (Accelerometer, Gyroscope, Orientation)" + ) + + +def test_init_serial_failure(config): + with patch( + "inputs.plugins.imu_input.serial.Serial", + side_effect=serial.SerialException("Port not found"), + ): + plugin = IMUInput(config) + assert plugin.ser is None + + +def test_thresholds_applied(config): + with patch("inputs.plugins.imu_input.serial.Serial"): + plugin = IMUInput(config) + assert plugin.imu_provider.fall_threshold == 45.0 + assert plugin.imu_provider.impact_threshold == 20.0 + + +def test_poll_no_serial(config): + with patch( + "inputs.plugins.imu_input.serial.Serial", + side_effect=serial.SerialException("fail"), + ): + plugin = IMUInput(config) + result = asyncio.get_event_loop().run_until_complete(plugin._poll()) + assert result is None + + +def test_poll_empty_line(imu_input): + imu_input.ser.readline.return_value = b"" + result = asyncio.get_event_loop().run_until_complete(imu_input._poll()) + assert result is None + + +def test_poll_valid_data(imu_input): + imu_input.ser.readline.return_value = b'{"ax":0.1,"ay":0.2,"az":9.8,"gx":0.0,"gy":0.0,"gz":0.0,"roll":1.0,"pitch":2.0,"yaw":90.0}\n' + result = asyncio.get_event_loop().run_until_complete(imu_input._poll()) + assert result is not None + assert result["ax"] == 0.1 + assert result["roll"] == 1.0 + + +def test_poll_invalid_json(imu_input): + imu_input.ser.readline.return_value = b"not json\n" + result = asyncio.get_event_loop().run_until_complete(imu_input._poll()) + assert result is None + + +def test_raw_to_text_none(imu_input): + result = asyncio.get_event_loop().run_until_complete(imu_input._raw_to_text(None)) + assert result is None + + +def test_raw_to_text_normal(imu_input): + data = { + "ax": 0.1, + "ay": 0.2, + "az": 9.8, + "gx": 0.0, + "gy": 0.0, + "gz": 0.0, + "roll": 1.0, + "pitch": 2.0, + "yaw": 90.0, + } + result = asyncio.get_event_loop().run_until_complete(imu_input._raw_to_text(data)) + assert result is not None + assert "normal" in result.message.lower() + + +def test_raw_to_text_fall(imu_input): + data = { + "ax": 0.0, + "ay": 0.0, + "az": 9.8, + "gx": 0.0, + "gy": 0.0, + "gz": 0.0, + "roll": 50.0, + "pitch": 0.0, + "yaw": 0.0, + } + result = asyncio.get_event_loop().run_until_complete(imu_input._raw_to_text(data)) + assert result is not None + assert "fallen" in result.message.lower() + + +def test_raw_to_text_impact(imu_input): + data = { + "ax": 15.0, + "ay": 15.0, + "az": 0.0, + "gx": 0.0, + "gy": 0.0, + "gz": 0.0, + "roll": 0.0, + "pitch": 0.0, + "yaw": 0.0, + } + result = asyncio.get_event_loop().run_until_complete(imu_input._raw_to_text(data)) + assert result is not None + assert "impact" in result.message.lower() + + +def test_raw_to_text_invalid_data(imu_input): + result = asyncio.get_event_loop().run_until_complete( + imu_input._raw_to_text({"ax": "invalid"}) + ) + assert result is None + + +def test_raw_to_text_updates_buffer(imu_input): + data = { + "ax": 0.1, + "ay": 0.2, + "az": 9.8, + "gx": 0.0, + "gy": 0.0, + "gz": 0.0, + "roll": 1.0, + "pitch": 2.0, + "yaw": 90.0, + } + asyncio.get_event_loop().run_until_complete(imu_input.raw_to_text(data)) + assert len(imu_input.messages) == 1 + + +def test_raw_to_text_none_not_added_to_buffer(imu_input): + asyncio.get_event_loop().run_until_complete(imu_input.raw_to_text(None)) + assert len(imu_input.messages) == 0 + + +def test_formatted_latest_buffer_empty(imu_input): + result = imu_input.formatted_latest_buffer() + assert result is None + + +def test_formatted_latest_buffer_with_data(imu_input): + data = { + "ax": 0.1, + "ay": 0.2, + "az": 9.8, + "gx": 0.0, + "gy": 0.0, + "gz": 0.0, + "roll": 1.0, + "pitch": 2.0, + "yaw": 90.0, + } + asyncio.get_event_loop().run_until_complete(imu_input.raw_to_text(data)) + result = imu_input.formatted_latest_buffer() + assert result is not None + assert "IMU Sensor" in result + assert len(imu_input.messages) == 0 + + +def test_formatted_latest_buffer_clears_messages(imu_input): + data = { + "ax": 0.1, + "ay": 0.2, + "az": 9.8, + "gx": 0.0, + "gy": 0.0, + "gz": 0.0, + "roll": 1.0, + "pitch": 2.0, + "yaw": 90.0, + } + asyncio.get_event_loop().run_until_complete(imu_input.raw_to_text(data)) + asyncio.get_event_loop().run_until_complete(imu_input.raw_to_text(data)) + imu_input.formatted_latest_buffer() + assert len(imu_input.messages) == 0 diff --git a/tests/providers/test_imu_provider.py b/tests/providers/test_imu_provider.py new file mode 100644 index 000000000..4044b9033 --- /dev/null +++ b/tests/providers/test_imu_provider.py @@ -0,0 +1,121 @@ +import pytest + +from providers.imu_provider import IMUProvider + + +@pytest.fixture(autouse=True) +def reset_singleton(): + IMUProvider.reset() # type: ignore[attr-defined] + yield + IMUProvider.reset() # type: ignore[attr-defined] + + +def test_initial_state(): + provider = IMUProvider() + state = provider.state + assert state["accel_x"] == 0.0 + assert state["accel_y"] == 0.0 + assert state["accel_z"] == 0.0 + assert state["gyro_x"] == 0.0 + assert state["gyro_y"] == 0.0 + assert state["gyro_z"] == 0.0 + assert state["roll"] == 0.0 + assert state["pitch"] == 0.0 + assert state["yaw"] == 0.0 + assert state["is_fallen"] is False + assert state["impact_detected"] is False + assert state["last_update"] == 0.0 + + +def test_update_normal(): + provider = IMUProvider() + provider.update(0.1, 0.2, 9.8, 0.0, 0.0, 0.0, 1.0, 2.0, 90.0) + state = provider.state + assert state["accel_x"] == 0.1 + assert state["accel_y"] == 0.2 + assert state["accel_z"] == 9.8 + assert state["roll"] == 1.0 + assert state["pitch"] == 2.0 + assert state["yaw"] == 90.0 + assert state["is_fallen"] is False + assert state["impact_detected"] is False + assert state["last_update"] > 0.0 + + +def test_fall_detection_roll(): + provider = IMUProvider() + provider.update(0.0, 0.0, 9.8, 0.0, 0.0, 0.0, 50.0, 0.0, 0.0) + assert provider.state["is_fallen"] is True + + +def test_fall_detection_pitch(): + provider = IMUProvider() + provider.update(0.0, 0.0, 9.8, 0.0, 0.0, 0.0, 0.0, -50.0, 0.0) + assert provider.state["is_fallen"] is True + + +def test_no_fall_within_threshold(): + provider = IMUProvider() + provider.update(0.0, 0.0, 9.8, 0.0, 0.0, 0.0, 44.9, 0.0, 0.0) + assert provider.state["is_fallen"] is False + + +def test_impact_detection(): + provider = IMUProvider() + provider.update(15.0, 15.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0) + assert provider.state["impact_detected"] is True + + +def test_no_impact_within_threshold(): + provider = IMUProvider() + provider.update(0.1, 0.2, 9.8, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0) + assert provider.state["impact_detected"] is False + + +def test_reset_alerts(): + provider = IMUProvider() + provider.update(15.0, 15.0, 0.0, 0.0, 0.0, 0.0, 50.0, 0.0, 0.0) + assert provider.state["is_fallen"] is True + assert provider.state["impact_detected"] is True + provider.reset_alerts() + assert provider.state["is_fallen"] is False + assert provider.state["impact_detected"] is False + + +def test_singleton(): + p1 = IMUProvider() + p2 = IMUProvider() + p1.update(0.0, 0.0, 9.8, 0.0, 0.0, 0.0, 5.0, 3.0, 10.0) + assert p2.state["roll"] == 5.0 + + +def test_custom_thresholds(): + provider = IMUProvider() + provider.fall_threshold = 30.0 + provider.impact_threshold = 10.0 + provider.update(0.0, 0.0, 9.8, 0.0, 0.0, 0.0, 35.0, 0.0, 0.0) + assert provider.state["is_fallen"] is True + + +def test_stop(): + provider = IMUProvider() + provider.stop() + + +def test_negative_roll_fall(): + provider = IMUProvider() + provider.update(0.0, 0.0, 9.8, 0.0, 0.0, 0.0, -50.0, 0.0, 0.0) + assert provider.state["is_fallen"] is True + + +def test_gyro_data_stored(): + provider = IMUProvider() + provider.update(0.0, 0.0, 9.8, 1.1, 2.2, 3.3, 0.0, 0.0, 0.0) + state = provider.state + assert state["gyro_x"] == 1.1 + assert state["gyro_y"] == 2.2 + assert state["gyro_z"] == 3.3 + + +def test_reset_class_method_callable(): + IMUProvider.reset() # type: ignore[attr-defined]