diff --git a/.env.example b/.env.example index e88d6c2..455b7ad 100644 --- a/.env.example +++ b/.env.example @@ -5,6 +5,28 @@ ADMIN_NODES='!aae8900d' # The root URL of the Meshflow API STORAGE_API_ROOT='http://localhost:8000' STORAGE_API_TOKEN=... +# Features +ENABLE_TCP_PROXY=true +# Handshake Cache size for TCP proxy (how many historical packets to cache to allow new clients to catch up quickly) +PROXY_HANDSHAKE_CACHE_SIZE=100 +# Rolling Cache size for TCP proxy (how many recent packets to keep in the rolling buffer) +PROXY_ROLLING_CACHE_SIZE=100 +ENABLE_FEATURE_NODE_TOTALS=true +FREQUENCY_OF_NODE_REPORTS=3 +CHANNEL_FOR_NODE_TOTAL_BROADCAST=2 + +# Commands +ENABLE_COMMAND_PING=true +ENABLE_COMMAND_TR=true +ENABLE_COMMAND_HELLO=true +ENABLE_COMMAND_HELP=true +ENABLE_COMMAND_NODES=true +ENABLE_COMMAND_WHOAMI=true +ENABLE_COMMAND_PREFS=true +ENABLE_COMMAND_ADMIN=true +ENABLE_COMMAND_STATUS=true + +# API Version (usually 1 or 2) STORAGE_API_VERSION=2 # Use these if you want to upload to a second API (usually used during testing) @@ -15,10 +37,10 @@ STORAGE_API_VERSION=2 # Use this if you want to receive commands from the Meshflow server (e.g. traceroute) MESHFLOW_WS_URL=ws://localhost:8000 -# Comma-separated portnums to skip when submitting packets to the API (e.g. custom or rejected ports) +# Comma-separated portnums to skip when submitting packets to the API (e.g. custom or rejected portnums) IGNORE_PORTNUMS=345,ROUTING_APP -# Traceroute config +# Traceroute config (for WebSocket commands) TR_HOPS_LIMIT=5 # Min seconds between traceroutes (firmware enforces ~30s; we rate-limit client-side) TR_MIN_INTERVAL_SEC=30 diff --git a/README.md b/README.md index 6bbee32..00359c8 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,9 @@ # Meshtastic Bot +### Although this is based on https://github.com/pskillen/meshtastic-bot I have personalised it a lot to my own setup for commands, auto replies and automations etc + +I am now working on Node-totals which I hope to be able to report the number of nodes my remote nodes can see. + Meshtastic Bot is a Python-based bot for interacting with Meshtastic devices. It listens for messages, processes commands, and responds with appropriate actions. This guide is focused on helping you run the bot as-is, with minimal setup. ## Quick Start: Run with Docker @@ -12,34 +16,39 @@ The easiest way to run Meshtastic Bot is using Docker. This method requires mini - Create a `.env` file in your project directory with the required environment variables: ``` -MESHTASTIC_NODE_IP=your_meshtastic_node_ip +MESHTASTIC_IP=your_meshtastic_node_ip ADMIN_NODES=comma_separated_admin_node_ids -STORAGE_API_ROOT=your_storage_api_url -STORAGE_API_TOKEN=your_storage_api_token +STORAGE_API_ROOT=https://meshflow.pskillen.xyz +STORAGE_API_TOKEN=your_storage_api_token from above site # Optionally, you can upload to a second API as well STORAGE_API_2_ROOT=your_storage_api_2_url STORAGE_API_2_TOKEN=your_storage_api_2_token + +# Feature Toggles +ENABLE_TCP_PROXY=true + +# Command Toggles (set to false to disable) +ENABLE_COMMAND_PING=true +ENABLE_COMMAND_TR=true +ENABLE_COMMAND_HELLO=true +ENABLE_COMMAND_HELP=true +ENABLE_COMMAND_NODES=true +ENABLE_COMMAND_WHOAMI=true +ENABLE_COMMAND_PREFS=true +ENABLE_COMMAND_ADMIN=true +ENABLE_COMMAND_STATUS=true ``` ### 2. Use This `docker-compose.yaml` ```yaml -version: '3.8' - services: bot: image: ghcr.io/pskillen/meshtastic-bot:latest container_name: meshtastic-bot restart: unless-stopped - environment: - - MESHTASTIC_IP=${MESHTASTIC_NODE_IP} - - ADMIN_NODES=${ADMIN_NODES} - - STORAGE_API_ROOT=${STORAGE_API_ROOT} - - STORAGE_API_TOKEN=${STORAGE_API_TOKEN} - - STORAGE_API_VERSION=2 - - STORAGE_API_2_ROOT=${STORAGE_API_2_ROOT} - - STORAGE_API_2_TOKEN=${STORAGE_API_2_TOKEN} - - STORAGE_API_2_VERSION=2 + env_file: + - meshtastic-bot.env volumes: - mesh_bot_data:/app/data @@ -57,6 +66,46 @@ The bot will now run in the background. Data will be persisted locally in the `m --- +## Customization + +You can enable or disable specific features and commands using environment variables in your `.env` or `meshtastic-bot.env` file. All options default to `true` if not specified. + +### Feature Toggles +- `ENABLE_TCP_PROXY`: Set to `false` to disable the internal TCP proxy. The bot will connect directly to `MESHTASTIC_IP`. +- `PROXY_HANDSHAKE_CACHE_SIZE`: Number of initial packets to cache for connecting proxy clients (default `100`). +- `PROXY_ROLLING_CACHE_SIZE`: Number of recent packets to cache in a rolling queue for connecting proxy clients (default `100`). + +### Command Toggles +Set any of the following to `false` to disable the command and hide it from the `!help` menu: +- `ENABLE_COMMAND_PING` +- `ENABLE_COMMAND_TR` +- `ENABLE_COMMAND_HELLO` +- `ENABLE_COMMAND_HELP` +- `ENABLE_COMMAND_NODES` +- `ENABLE_COMMAND_WHOAMI` +- `ENABLE_COMMAND_PREFS` +- `ENABLE_COMMAND_ADMIN` +- `ENABLE_COMMAND_STATUS` + +--- + +## Docker Compose Options + +There are two primary ways to run the bot using Docker: + +### 1. Standard (`docker-compose.yaml`) - **Recommended for local builds** +- **Purpose**: Stable use with local source control. +- **How it works**: It builds the bot locally from the source files in the repository. +- **Includes**: Integrated **Watchtower** service which automatically checks for and applies updates to the `meshtastic-bot` container every hour. +- **Environment**: Configuration is pulled from your `.env` file. + +### 2. Remote/Pre-built (`docker-compose-remote.yaml`) +- **Purpose**: Quick deployment using the official container. +- **How it works**: Pulls the pre-built image from the **GitHub Container Registry** (`ghcr.io`). +- **Configuration**: Uses `meshtastic-bot.env` for environment variables and a named Docker volume (`mesh_bot_data`) for persistence. + +--- + ## Native Installation (Advanced/Development) If you prefer to run the bot natively (e.g., for development or customization): @@ -85,17 +134,39 @@ If you prefer to run the bot natively (e.g., for development or customization): ## Usage -The bot listens for messages and responds to commands. You can interact with it via supported Meshtastic channels. +The bot listens for messages and responds to commands as a direct message. You can interact with it via supported Meshtastic channels. ### Supported Commands -| Command | Description | -|-----------|------------------------------------------------| -| `!help` | Displays a list of available commands | -| `!hello` | Displays information about the bot | -| `!ping` | Responds with "Pong!" | -| `!nodes` | Displays a list of connected nodes, stats, etc | -| `!whoami` | Displays information about the sender | +| Command | Description | +|-----------|---------------------------------------------------------------| +| `!help` | Displays a list of available commands | +| `!hello` | Displays information about the bot | +| `!ping` | Responds with "Pong!" | +| `!nodes` | Displays a list of connected nodes, stats, etc | +| `!nodes totals` | Manually triggers a node count report | +| `!whoami` | Displays information about the sender | +| `!tr` | Performs a traceroute to the sender (outbound & inbound) | +| `!status` | Displays bot status and radio connection details | + +## Features + +### Node Count Reporting +The bot monitors mesh visibility and provides automated reporting: +- **Scheduled Reports:** Every 3 hours, a status update is sent to Channel 2 (GregPrivate) with the current online node count. +- **Immediate Alerts:** If the visible node count drops to zero, the bot sends an immediate warning. +- **Manual Check:** Use `!nodes totals` to get an instant report via DM. + +### Enhanced Connectivity (TCP Proxy) +The bot now includes a built-in TCP proxy to manage the connection to the Meshtastic node. This improves stability and allows for automatic reconnection if the radio connection is lost. + +### Improved Logging +Messages received on named Group Channels (e.g., 'LongRange', 'PrivateChat') are now logged with their specific channel name, making it easier to track conversations across different mesh networks. + +### Advanced Traceroute +The `!tr` command has been upgraded to show the full path: +- **Outbound:** The route from the bot to your node. +- **Inbound:** The route back from your node to the bot (if available). --- diff --git a/docker-compose-remote.yaml b/docker-compose-remote.yaml new file mode 100644 index 0000000..d20ae94 --- /dev/null +++ b/docker-compose-remote.yaml @@ -0,0 +1,16 @@ +#----- Docker Compose.yaml ------ +services: + bot: + image: ghcr.io/pskillen/meshtastic-bot:latest + container_name: meshtastic-bot + restart: unless-stopped + ports: + - "4403:4403" + env_file: + - meshtastic-bot.env + volumes: + - ./src:/app/src + - mesh_bot_data:/app/data + +volumes: + mesh_bot_data: diff --git a/docker-compose.yaml b/docker-compose.yaml index 762a72b..91b1ca6 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -1,19 +1,32 @@ --- services: - meshtastic-bot: - image: ghcr.io/pskillen/meshtastic-bot:latest - build: - context: "./" + bot: + build: . container_name: meshtastic-bot restart: unless-stopped + ports: + - "4403:4403" environment: - - MESHTASTIC_IP=meshtastic.local - - ADMIN_NODES='!aae8900d' # Change this, unless you want me to be the admin of your bot - volumes: - - ./data:/app/data - depends_on: - - watchtower + - MESHTASTIC_IP=${MESHTASTIC_IP:-192.168.4.210} # Your Meshtastic Node IP here + - STORAGE_API_ROOT=${STORAGE_API_ROOT} + - STORAGE_API_TOKEN=${STORAGE_API_TOKEN} + - STORAGE_API_VERSION=${STORAGE_API_VERSION} + - ENABLE_TCP_PROXY=${ENABLE_TCP_PROXY:-true} + - ENABLE_FEATURE_NODE_TOTALS=${ENABLE_FEATURE_NODE_TOTALS:-true} + - FREQUENCY_OF_NODE_REPORTS=${FREQUENCY_OF_NODE_REPORTS:-3} + - CHANNEL_FOR_NODE_TOTAL_BROADCAST=${CHANNEL_FOR_NODE_TOTAL_BROADCAST:-2} + - ENABLE_COMMAND_PING=${ENABLE_COMMAND_PING:-true} + - ENABLE_COMMAND_TR=${ENABLE_COMMAND_TR:-true} + - ENABLE_COMMAND_HELLO=${ENABLE_COMMAND_HELLO:-true} + - ENABLE_COMMAND_HELP=${ENABLE_COMMAND_HELP:-true} + - ENABLE_COMMAND_NODES=${ENABLE_COMMAND_NODES:-true} + - ENABLE_COMMAND_WHOAMI=${ENABLE_COMMAND_WHOAMI:-true} + - ENABLE_COMMAND_PREFS=${ENABLE_COMMAND_PREFS:-true} + - ENABLE_COMMAND_ADMIN=${ENABLE_COMMAND_ADMIN:-true} + - ENABLE_COMMAND_STATUS=${ENABLE_COMMAND_STATUS:-true} + labels: + - "com.centurylinklabs.watchtower.enable=false" watchtower: image: containrrr/watchtower @@ -21,4 +34,6 @@ services: restart: unless-stopped volumes: - /var/run/docker.sock:/var/run/docker.sock - command: --interval 3600 meshtastic-bot # Check for updates every hour + environment: + - DOCKER_API_VERSION=1.44 + command: --interval 3600 --label-enable diff --git a/find_node.py b/find_node.py new file mode 100644 index 0000000..1f8893d --- /dev/null +++ b/find_node.py @@ -0,0 +1,15 @@ +import sqlite3 +import sys + +try: + conn = sqlite3.connect('/app/data/node_db.sqlite') + cursor = conn.cursor() + cursor.execute("SELECT long_name, short_name, id FROM nodes WHERE short_name LIKE '%mte4%' OR long_name LIKE '%mte4%'") + rows = cursor.fetchall() + if not rows: + print("No node found with 'mte4' in name.") + for row in rows: + print(f"Found: {row}") + conn.close() +except Exception as e: + print(f"Error: {e}") diff --git a/src/api/StorageAPI.py b/src/api/StorageAPI.py index fb588ed..c05d2d6 100644 --- a/src/api/StorageAPI.py +++ b/src/api/StorageAPI.py @@ -4,6 +4,7 @@ import os import traceback from datetime import datetime +from json import JSONDecodeError from pathlib import Path from typing import Union @@ -28,6 +29,7 @@ def _get_url(self, path: str, args: dict = None): if args is None: args = {} + my_nodenum = self.bot.my_nodenum if self.api_version == 1: api_paths = { 'raw_packet': '/api/raw-packet/', @@ -35,7 +37,6 @@ def _get_url(self, path: str, args: dict = None): 'node_by_id': f'/api/nodes/{args.get("node_id", "")}', } else: - my_nodenum = self.bot.my_nodenum api_paths = { 'raw_packet': f'/api/packets/{my_nodenum}/ingest/', 'nodes': f'/api/packets/{my_nodenum}/nodes/', @@ -64,6 +65,35 @@ def store_raw_packet(self, packet: dict): """ Store a raw packet in the storage API """ + if self.api_version == 2 and (self.bot.my_nodenum is None or self.bot.my_nodenum <= 0): + logging.debug("Skipping store_raw_packet: Bot node number not yet initialized.") + return + + logging.info(f"store_raw_packet called for portnum: {packet.get('decoded', {}).get('portnum')}") + # Filter out packet types that the API doesn't support or we don't want to store + ignored_ports = [345, 'TRACEROUTE_APP', 'ADMIN_APP', 'NEIGHBORINFO_APP', 'ROUTING_APP'] + portnum = packet.get('decoded', {}).get('portnum') + if portnum in ignored_ports: + return + + # Additional filtering for Telemetry packets to avoid API errors + # The API requires either 'deviceMetrics' or 'localStats' + if portnum == 'ROUTING_APP': + from_id = packet.get('from') + logging.info(f"DEBUG: ROUTING_APP Packet from {from_id}: {packet}") + + # Log all text messages + if portnum == 'TEXT_MESSAGE_APP': + from_id = packet.get('from') + logging.info(f"DEBUG: TEXT_MESSAGE_APP Packet from {from_id}: {packet}") + + if portnum == 'TELEMETRY_APP': + telemetry = packet.get('decoded', {}).get('telemetry', {}) + if 'deviceMetrics' not in telemetry and 'localStats' not in telemetry: + # Log debug instead of error/warning so we know we skipped it but it's not a failure + logging.debug("Skipping unsupported TELEMETRY packet (missing deviceMetrics/localStats)") + return + # Convert bytes to Base64-encoded strings recursively raw_packet: MeshPacket = packet.get('raw') packet = StorageAPIWrapper._sanitise_raw_packet(packet) @@ -72,35 +102,44 @@ def store_raw_packet(self, packet: dict): if raw_packet: if 'channel' not in packet: packet['channel'] = raw_packet.channel + if 'id' not in packet: + packet['id'] = raw_packet.id + if 'from' not in packet: + packet['from'] = raw_packet.from_node - logging.debug(f"Storing packet: {packet}") + logging.info(f"Storing packet: {packet}") try: response = self._post(self._get_url('raw_packet'), json=packet) - response_json = response.json() - return response_json + try: + response_json = response.json() + logging.info(f"API Response ({response.status_code}): {response_json}") + return response_json + except JSONDecodeError: + logging.info(f"API Response ({response.status_code}, not JSON): {response.text}") + return {'text': response.text} + except HTTPError as ex: logging.error(f"HTTP error storing packet: {ex.response.text}") logging.error(f"Packet: {packet}") - - # Dump the packet to a .json file if self.failed_packets_dir: self._dump_failed_packet(packet, ex) return + except Exception as ex: logging.error(f"Error storing packet: {ex}") logging.error(f"Packet: {packet}") - - # Dump the packet to a .json file if self.failed_packets_dir: self._dump_failed_packet(packet, ex) - return def list_nodes(self) -> list[MeshNode]: """ Get a list of all nodes stored in the storage API. This list generally does not include position or metrics data. """ + if self.api_version == 2 and (self.bot.my_nodenum is None or self.bot.my_nodenum <= 0): + return [] + response = self._get(self._get_url('nodes')) response_json = response.json() @@ -112,6 +151,9 @@ def store_node(self, node: MeshNode): If the node contains position or metrics data, it will be stored as well """ + if self.api_version == 2 and (self.bot.my_nodenum is None or self.bot.my_nodenum <= 0): + logging.debug("Skipping store_node: Bot node number not yet initialized.") + return node_data = MeshNodeSerializer.to_api_dict(node) diff --git a/src/api/serializers.py b/src/api/serializers.py index a5e8426..7476704 100644 --- a/src/api/serializers.py +++ b/src/api/serializers.py @@ -27,22 +27,25 @@ class PositionSerializer(AbstractModelSerializer): def to_api_dict(cls, position: MeshNode.Position) -> dict: return { "logged_time": cls.date_to_api(position.logged_time), # api v1 compatibility + "loggedTime": cls.date_to_api(position.logged_time), "reported_time": cls.date_to_api(position.reported_time), # api v2 compatibility + "reportedTime": cls.date_to_api(position.reported_time), "latitude": position.latitude, "longitude": position.longitude, "altitude": position.altitude, "location_source": position.location_source or "LOC_UNKNOWN", + "locationSource": position.location_source or "LOC_UNKNOWN", } @classmethod def from_api_dict(cls, position_data: dict) -> MeshNode.Position: return MeshNode.Position( - logged_time=cls.date_from_api(position_data['logged_time']), - reported_time=cls.date_from_api(position_data['reported_time']), + logged_time=cls.date_from_api(position_data.get('logged_time') or position_data.get('loggedTime')), + reported_time=cls.date_from_api(position_data.get('reported_time') or position_data.get('reportedTime')), latitude=position_data['latitude'], longitude=position_data['longitude'], altitude=position_data['altitude'], - location_source=position_data['location_source'] + location_source=position_data.get('location_source') or position_data.get('locationSource') ) @@ -51,23 +54,29 @@ class DeviceMetricsSerializer(AbstractModelSerializer): def to_api_dict(cls, device_metrics: MeshNode.DeviceMetrics) -> dict: return { "logged_time": cls.date_to_api(device_metrics.logged_time), # api v1 compatibility + "loggedTime": cls.date_to_api(device_metrics.logged_time), "reported_time": cls.date_to_api(device_metrics.logged_time), # api v2 compatibility + "reportedTime": cls.date_to_api(device_metrics.logged_time), "battery_level": device_metrics.battery_level, + "batteryLevel": device_metrics.battery_level, "voltage": device_metrics.voltage, "channel_utilization": device_metrics.channel_utilization, + "channelUtilization": device_metrics.channel_utilization, "air_util_tx": device_metrics.air_util_tx, - "uptime_seconds": device_metrics.uptime_seconds + "airUtilTx": device_metrics.air_util_tx, + "uptime_seconds": device_metrics.uptime_seconds, + "uptimeSeconds": device_metrics.uptime_seconds } @classmethod def from_api_dict(cls, device_metrics_data: dict) -> MeshNode.DeviceMetrics: return MeshNode.DeviceMetrics( - logged_time=cls.date_from_api(device_metrics_data['logged_time']), - battery_level=device_metrics_data['battery_level'], + logged_time=cls.date_from_api(device_metrics_data.get('logged_time') or device_metrics_data.get('loggedTime') or device_metrics_data.get('reported_time') or device_metrics_data.get('reportedTime')), + battery_level=device_metrics_data.get('battery_level') or device_metrics_data.get('batteryLevel'), voltage=device_metrics_data['voltage'], - channel_utilization=device_metrics_data['channel_utilization'], - air_util_tx=device_metrics_data['air_util_tx'], - uptime_seconds=device_metrics_data['uptime_seconds'] + channel_utilization=device_metrics_data.get('channel_utilization') or device_metrics_data.get('channelUtilization'), + air_util_tx=device_metrics_data.get('air_util_tx') or device_metrics_data.get('airUtilTx'), + uptime_seconds=device_metrics_data.get('uptime_seconds') or device_metrics_data.get('uptimeSeconds') ) @@ -80,10 +89,14 @@ def to_api_dict(cls, node: MeshNode) -> dict: "id": node.user.id, "macaddr": node.user.macaddr, "hw_model": node.user.hw_model, + "hwModel": node.user.hw_model, "public_key": node.user.public_key, + "publicKey": node.user.public_key, 'user': { "long_name": node.user.long_name, - "short_name": node.user.short_name + "longName": node.user.long_name, + "short_name": node.user.short_name, + "shortName": node.user.short_name } } @@ -94,6 +107,7 @@ def to_api_dict(cls, node: MeshNode) -> dict: if node.device_metrics: node_data['device_metrics'] = DeviceMetricsSerializer.to_api_dict(node.device_metrics) + node_data['deviceMetrics'] = DeviceMetricsSerializer.to_api_dict(node.device_metrics) return node_data @@ -103,10 +117,10 @@ def from_api_dict(cls, node_data: dict) -> MeshNode: user = MeshNode.User( node_id=node_data['id'], macaddr=node_data['macaddr'], - hw_model=node_data['hw_model'], - public_key=node_data['public_key'], - long_name=user_data['long_name'], - short_name=user_data['short_name'] + hw_model=node_data.get('hw_model') or node_data.get('hwModel'), + public_key=node_data.get('public_key') or node_data.get('publicKey'), + long_name=user_data.get('long_name') or user_data.get('longName'), + short_name=user_data.get('short_name') or user_data.get('shortName') ) position_data = node_data.get('position') @@ -114,7 +128,7 @@ def from_api_dict(cls, node_data: dict) -> MeshNode: if position_data: position = PositionSerializer.from_api_dict(position_data) - device_metrics_data = node_data.get('device_metrics') + device_metrics_data = node_data.get('device_metrics') or node_data.get('deviceMetrics') device_metrics = None if device_metrics_data: device_metrics = DeviceMetricsSerializer.from_api_dict(device_metrics_data) diff --git a/src/base_feature.py b/src/base_feature.py index 27d6879..29adeb6 100644 --- a/src/base_feature.py +++ b/src/base_feature.py @@ -1,5 +1,6 @@ import logging import os +import time from abc import ABC from meshtastic.protobuf.mesh_pb2 import MeshPacket @@ -40,18 +41,19 @@ def message_in_channel(self, channel: int, message: str, want_ack=False) -> None message, channelIndex=channel, wantAck=want_ack, hopLimit=TEXT_MESSAGE_MAX_HOPS ) - def reply_in_dm(self, packet: MeshPacket, message: str, want_ack=False) -> None: + def reply_in_dm(self, packet: MeshPacket, message: str, want_ack=True) -> None: """ Reply in a direct message to a user """ destination_id = packet['fromId'] self.message_in_dm(destination_id, message, want_ack) - def message_in_dm(self, destination_id: str, message: str, want_ack=False) -> None: + def message_in_dm(self, destination_id: str, message: str, want_ack=True) -> None: """ Reply in a direct message to a user """ logging.debug(f"Sending DM: '{message}'") + time.sleep(1) # Wait a second to let the radio settle self.bot.interface.sendText( message, destinationId=destination_id, diff --git a/src/bot.py b/src/bot.py index 6b073de..f3ab458 100644 --- a/src/bot.py +++ b/src/bot.py @@ -1,6 +1,7 @@ import logging import sys import time +import threading from datetime import datetime, timezone import schedule @@ -10,9 +11,12 @@ from src.api.StorageAPI import StorageAPIWrapper from src.commands.factory import CommandFactory -from src.traceroute import on_traceroute_command +try: + from src.traceroute import on_traceroute_command +except ImportError: + on_traceroute_command = None from src.data_classes import MeshNode -from src.helpers import pretty_print_last_heard, safe_encode_node_name +from src.helpers import pretty_print_last_heard, safe_encode_node_name, get_env_bool, get_env_int from src.persistence.commands_logger import AbstractCommandLogger from src.persistence.node_db import AbstractNodeDB from src.persistence.node_info import AbstractNodeInfoStore @@ -42,6 +46,8 @@ class MeshtasticBot: def __init__(self, address: str): self.address = address + self.start_time = datetime.now(timezone.utc) + self.proxy = None self.admin_nodes = [] self.ignore_portnums = frozenset() @@ -57,8 +63,11 @@ def __init__(self, address: str): self.user_prefs_persistence = None self.storage_apis = [] self.ws_client = None + self.pending_traces = {} + self.last_report_zero = False pub.subscribe(self.on_receive, "meshtastic.receive") + pub.subscribe(self.on_traceroute, "meshtastic.traceroute") pub.subscribe(self.on_receive_text, "meshtastic.receive.text") pub.subscribe(self.on_node_updated, "meshtastic.node.updated") pub.subscribe(self.on_connection, "meshtastic.connection.established") @@ -113,21 +122,32 @@ def disconnect(self): def on_traceroute_command(self, target_node_id: int): """Handle traceroute command from WebSocket (e.g. from Meshflow API).""" - on_traceroute_command(self, target_node_id) + if on_traceroute_command: + on_traceroute_command(self, target_node_id) + else: + logging.warning("Traceroute handling via WebSocket is not available (import failed).") def on_connection(self, interface, topic=pub.AUTO_TOPIC): self.my_nodenum = interface.localNode.nodeNum # in dec - self.my_id = f"!{hex(self.my_nodenum)[2:]}" + self.my_id = f"!{self.my_nodenum:08x}" self.init_complete = True - logging.info('Connected to Meshtastic node') + logging.info(f'Connected to Meshtastic node as {self.my_id}') self.print_nodes() + + # Send an immediate node count report upon connection + # We use a timer to delay slightly to ensure everything settles + if get_env_bool('ENABLE_FEATURE_NODE_TOTALS', True): + threading.Timer(10.0, self.report_node_count).start() if self.ws_client: self.ws_client.start() def on_receive_text(self, packet: MeshPacket, interface): """Callback function triggered when a text message is received.""" + from_id = packet.get('fromId') + text = packet.get('decoded', {}).get('text', '') + logging.info(f"on_receive_text: Incoming text from {from_id}: {text}") to_id = packet['toId'] @@ -142,25 +162,69 @@ def handle_private_message(self, packet: MeshPacket): from_id = packet['fromId'] sender = self.node_db.get_by_id(from_id) - logging.info(f"Received private message: '{message}' from {sender.long_name if sender else from_id}") + logging.info(f"✉️ [PRIVATE MSG] '{message}' from {sender.long_name if sender else from_id}") words = message.split() command_name = words[0] command_instance = CommandFactory.create_command(command_name, self) if command_instance: self.command_logger.log_command(from_id, command_instance, message) - try: - command_instance.handle_packet(packet) - except Exception as e: - logging.error(f"Error handling message: {e}") + + def run_command(): + try: + logging.info(f"🤖 [BOT CMD] Running private command {command_name} in thread for {from_id}") + command_instance.handle_packet(packet) + logging.info(f"✅ [BOT CMD] Finished private command {command_name} for {from_id}") + except Exception as e: + logging.error(f"❌ [BOT CMD] Error handling private command {command_name}: {e}", exc_info=True) + + threading.Thread(target=run_command, daemon=True).start() else: self.command_logger.log_unknown_request(from_id, message) + def get_channel_name(self, packet: MeshPacket) -> str: + """Get the name of the channel for a packet.""" + channel_index = packet.get('channel', 0) + try: + if self.interface and self.interface.localNode: + channel = self.interface.localNode.channels[channel_index] + if channel and channel.settings and channel.settings.name: + return channel.settings.name + except (AttributeError, IndexError): + pass + return "Primary" if channel_index == 0 else f"Channel {channel_index}" + def handle_public_message(self, packet: MeshPacket): - """Handle public messages.""" + """Handle public (group channel) messages.""" message = packet['decoded']['text'] from_id = packet['fromId'] sender = self.node_db.get_by_id(from_id) + sender_name = sender.long_name if sender else from_id + channel_name = self.get_channel_name(packet) + + logging.info(f"📢 [GROUP MSG] Channel '{channel_name}' from {sender_name}: {message}") + + # Allow certain commands in public channels + words = message.split() + if words: + command_name = words[0].lower() + if command_name in ["!tr", "!ping", "!hello", "!nodes", "!status", "!whoami"]: + env_var_name = f"ENABLE_COMMAND_{command_name.lstrip('!').upper()}" + if get_env_bool(env_var_name, True): + logging.info(f"🤖 [BOT CMD] Received public {command_name} from {sender_name}") + command_instance = CommandFactory.create_command(command_name, self) + if command_instance: + def run_command(): + try: + logging.info(f"🤖 [BOT CMD] Running public command {command_name} in thread for {from_id}") + # Commands by default reply via DM (reply_in_dm). + command_instance.handle_packet(packet) + logging.info(f"✅ [BOT CMD] Finished public command {command_name} for {from_id}") + except Exception as e: + logging.error(f"❌ [BOT CMD] Error handling public command {command_name}: {e}", exc_info=True) + + threading.Thread(target=run_command, daemon=True).start() + return # Stop processing responders responder = ResponderFactory.match_responder(message, self) if responder: @@ -169,21 +233,96 @@ def handle_public_message(self, packet: MeshPacket): if outcome: logging.info( - f"Handled message from {sender.long_name if sender else from_id} with responder {responder.__class__.__name__}: {message}") + f"🤖 [RESPONDER] Handled message from {sender.long_name if sender else from_id} with responder {responder.__class__.__name__}: {message}") self.command_logger.log_responder_handled(from_id, responder, message) + except (KeyError, ValueError) as e: + logging.error(f"Packet format error handling message: {e}", exc_info=True) except Exception as e: - logging.error(f"Error handling message: {e}") + logging.error(f"Error handling message: {e}", exc_info=True) + + def on_traceroute(self, packet, route): + """Callback for when a traceroute response is received.""" + logging.info(f"on_traceroute: Received signal from {packet.get('fromId') if isinstance(packet, dict) else 'obj'}") + + def process_traceroute(): + try: + target_id = packet.get('fromId') + if target_id not in self.pending_traces: + return + + requesters = self.pending_traces.pop(target_id) + if not isinstance(requesters, list): + requesters = [requesters] + + if route is None: + for ctx in requesters: + r_id = ctx[0] if isinstance(ctx, tuple) else ctx + msg = f"Traceroute response received from {target_id}, but no route data was provided." + self.interface.sendText(msg, destinationId=r_id, wantAck=True) + return + + def get_route_hops(r, key='route'): + if isinstance(r, dict): + return r.get(key, []) + return getattr(r, key, []) + + # Format compact routes + target_node = self.node_db.get_by_id(target_id) + t_name = target_node.short_name if target_node else target_id[-4:] + + my_node = self.node_db.get_by_id(self.my_id) + m_name = my_node.short_name if my_node else self.my_id[-4:] + + # Outbound + route_ids = get_route_hops(route, 'route') + hops_to = [] + for nid in route_ids: + n = self.node_db.get_by_id(f"!{nid:08x}") + hops_to.append(n.short_name if n else f"{nid:08x}"[-4:]) + route_to_str = ">".join(hops_to) + (">" if hops_to else "") + t_name + + # Inbound + route_back_ids = get_route_hops(route, 'route_back') + hops_fr = [] + for nid in route_back_ids: + n = self.node_db.get_by_id(f"!{nid:08x}") + hops_fr.append(n.short_name if n else f"{nid:08x}"[-4:]) + route_fr_str = ">".join(hops_fr) + (">" if hops_fr else "") + m_name + + # Consolidate into a single message + combined_response = f"!tr {t_name}:\nTO({len(route_ids)}h): {route_to_str}\nFR({len(route_back_ids)}h): {route_fr_str}" + + # Longer wait for radio to settle + time.sleep(8) + + for ctx in requesters: + r_id, is_pub, to_id, c_idx = ctx if isinstance(ctx, tuple) else (ctx, False, ctx, 0) + dest_id = to_id if is_pub else r_id + self.interface.sendText(combined_response, destinationId=dest_id, channelIndex=c_idx, wantAck=True) + time.sleep(2) + except Exception as e: + logging.error(f"Error in on_traceroute thread: {e}", exc_info=True) + + threading.Thread(target=process_traceroute, daemon=True).start() def on_receive(self, packet: MeshPacket, interface): # dump the packet to disk (if enabled) dump_packet(packet) portnum = packet.get("decoded", {}).get("portnum", "unknown") + # Ensure we check against both the string name and the integer ID if available portnum_key = str(portnum).upper() + has_decoded = 'decoded' in packet or 'decrypted' in packet - if self.ignore_portnums and portnum_key in self.ignore_portnums: + is_ignored = False + if self.ignore_portnums: + if portnum_key in self.ignore_portnums: + is_ignored = True + elif isinstance(portnum, int) and str(portnum) in self.ignore_portnums: + is_ignored = True + + if is_ignored: logging.info(f"Skipping API submission for packet with portnum {portnum} (in IGNORE_PORTNUMS)") - # Continue with node_info etc. below, just skip storage API elif not has_decoded: pass # Skip API submission for packets with no decoded data else: @@ -192,67 +331,56 @@ def on_receive(self, packet: MeshPacket, interface): storage_api.store_raw_packet(packet) except HTTPError as ex: logging.warning(f"Error storing packet: {ex.response.text}") - pass except Exception as ex: logging.warning(f"Error storing packet in API: {ex}") - pass sender = packet['fromId'] node = self.node_db.get_by_id(sender) if not node: - # logging.warning(f"Received packet from unknown sender {sender}") return if node: portnum = packet['decoded']['portnum'] if 'decoded' in packet else 'unknown' if sender == self.my_id and portnum == 'TELEMETRY_APP': - # Ignore telemetry packets sent by self pass else: - # Increment packets_today for this node self.node_info.node_packet_received(sender, portnum) - if sender == self.my_id: - recipient_id = packet['toId'] - recipient = self.node_db.get_by_id(recipient_id) - portnum = packet['decoded']['portnum'] - - logging.debug( - f"Received packet from self: {recipient.long_name if recipient else recipient_id} (port {portnum})") - def on_node_updated(self, node, interface): if interface.localNode and self.my_nodenum is None: self.my_nodenum = interface.localNode.nodeNum - self.my_id = f"!{hex(self.my_nodenum)[2:]}" + self.my_id = f"!{self.my_nodenum:08x}" - # Check if the node is a new user if node['user'] is not None: mesh_node = MeshNode.from_dict(node) last_heard_int = node.get('lastHeard', 0) - last_heard = datetime.fromtimestamp(last_heard_int, tz=timezone.utc) - self.node_db.store_node(mesh_node) - self.node_info.update_last_heard(mesh_node.user.id, last_heard) - - for storage_api in self.storage_apis: - try: - storage_api.store_node(mesh_node) - except HTTPError as ex: - logging.warning(f"Error storing node: {ex.response.text}") - pass - except Exception as ex: - logging.warning(f"Error storing node: {ex}") - pass - - if self.init_complete: - last_heard_str = pretty_print_last_heard(last_heard) + + if last_heard_int > 0: + last_heard = datetime.fromtimestamp(last_heard_int, tz=timezone.utc) + existing_last_heard = self.node_info.get_last_heard(mesh_node.user.id) + if not existing_last_heard or last_heard > existing_last_heard: + self.node_info.update_last_heard(mesh_node.user.id, last_heard) + + existing_user = self.node_db.get_by_id(mesh_node.user.id) + is_new = existing_user is None + + if is_new or existing_user != mesh_node.user: + self.node_db.store_node(mesh_node) + for storage_api in self.storage_apis: + try: + storage_api.store_node(mesh_node) + except Exception as ex: + logging.warning(f"Error storing node: {ex}") + + if self.init_complete and is_new: + current_last_heard = self.node_info.get_last_heard(mesh_node.user.id) + last_heard_str = pretty_print_last_heard(current_last_heard) if current_last_heard else "unknown" logging.info(f"New user: {mesh_node.user.long_name} (last heard {last_heard_str})") def print_nodes(self): - # filter nodes where last heard is more than 2 hours ago online_nodes = self.node_info.get_online_nodes() offline_nodes = self.node_info.get_offline_nodes() - # print all nodes, sorted by last heard descending logging.info(f"Online nodes: ({len(online_nodes)})") sorted_nodes = sorted(online_nodes, key=lambda x: online_nodes[x], reverse=True) for node_id in sorted_nodes: @@ -266,6 +394,41 @@ def print_nodes(self): logging.info(f"- Plus {len(offline_nodes)} offline nodes") + def report_node_count(self, destination=None, channel_index=None): + if not self.init_complete or not self.interface: + return + + if channel_index is None: + channel_index = get_env_int('CHANNEL_FOR_NODE_TOTAL_BROADCAST', 2) + + online_nodes = self.node_info.get_online_nodes() + count = len(online_nodes) + + if count == 0: + message = "Warning MTEK cant see any nodes" + self.last_report_zero = True + else: + message = f"MTEK has a node count of {count}" + self.last_report_zero = False + + logging.info(f"Reporting node count: {message}") + try: + if destination: + self.interface.sendText(message, destinationId=destination, wantAck=True) + else: + self.interface.sendText(message, channelIndex=channel_index, wantAck=True) + except Exception as e: + logging.error(f"Error reporting node count: {e}") + + def check_for_zero_nodes(self): + if not self.init_complete or not self.interface: + return + online_nodes = self.node_info.get_online_nodes() + if len(online_nodes) == 0 and not self.last_report_zero: + self.report_node_count() + elif len(online_nodes) > 0: + self.last_report_zero = False + def get_global_context(self): return { 'nodes': self.node_db.list_nodes(), @@ -275,6 +438,10 @@ def get_global_context(self): def start_scheduler(self): schedule.every().day.at("00:00").do(self.node_info.reset_packets_today) + if get_env_bool('ENABLE_FEATURE_NODE_TOTALS', True): + report_frequency = get_env_int('FREQUENCY_OF_NODE_REPORTS', 3) + schedule.every(report_frequency).hours.do(self.report_node_count) + schedule.every(1).minutes.do(self.check_for_zero_nodes) while True: schedule.run_pending() try: diff --git a/src/commands/command.py b/src/commands/command.py index 2494c72..ba17e26 100644 --- a/src/commands/command.py +++ b/src/commands/command.py @@ -20,7 +20,7 @@ def handle_packet(self, packet: MeshPacket) -> None: pass @deprecated("use reply_in_dm instead") - def reply(self, packet: MeshPacket, message: str, want_ack=False) -> None: + def reply(self, packet: MeshPacket, message: str, want_ack=True) -> None: """ Reply to a message in the same channel This is a deprecated method, use reply_in_channel instead @@ -28,7 +28,7 @@ def reply(self, packet: MeshPacket, message: str, want_ack=False) -> None: self.reply_in_dm(packet, message, want_ack) @deprecated("use message_in_dm instead") - def reply_to(self, destination_id: str, message: str, want_ack=False) -> None: + def reply_to(self, destination_id: str, message: str, want_ack=True) -> None: """ Reply in a direct message to a user This is a deprecated method, use reply_in_dm instead diff --git a/src/commands/factory.py b/src/commands/factory.py index 0aa4437..122b693 100644 --- a/src/commands/factory.py +++ b/src/commands/factory.py @@ -1,4 +1,5 @@ import importlib +from src.helpers import get_env_bool class CommandFactory: @@ -7,6 +8,10 @@ class CommandFactory: "class": "src.commands.ping.PingCommand", "args": [] }, + "!tr": { + "class": "src.commands.tr.TracerouteCommand", + "args": [] + }, "!hello": { "class": "src.commands.hello.HelloCommand", "args": [] @@ -31,6 +36,10 @@ class CommandFactory: "class": "src.commands.admin.AdminCommand", "args": [] }, + "!status": { + "class": "src.commands.status.StatusCommand", + "args": [] + }, # "!enroll": { # "class": "src.commands.enroll.EnrollCommandHandler", # "args": ["enroll"] @@ -45,6 +54,12 @@ class CommandFactory: def create_command(command_name, bot): command_info = CommandFactory.commands.get(command_name) if command_info: + # Check if command is enabled via environment variable + # e.g., !ping -> ENABLE_COMMAND_PING + env_var_name = f"ENABLE_COMMAND_{command_name.lstrip('!').upper()}" + if not get_env_bool(env_var_name, True): + return None + module_name, class_name = command_info["class"].rsplit('.', 1) module = importlib.import_module(module_name) command_class = getattr(module, class_name) diff --git a/src/commands/hello.py b/src/commands/hello.py index 6f65435..737ab27 100644 --- a/src/commands/hello.py +++ b/src/commands/hello.py @@ -13,7 +13,7 @@ def handle_packet(self, packet: MeshPacket) -> None: sender = self.bot.node_db.get_by_id(sender_id) sender_name = sender.long_name if sender else sender_id - response = f"Hello, {sender_name}! How can I help you? (tip: try !help). I'm a bot maintained by PDY4 / pskillen@gmail.com" + response = f"Hello, {sender_name}! (tip: try !help). I'm a bot maintained by MTEK original PDY4 / https://github.com/pskillen/meshtastic-bot" self.reply_to(sender_id, response) def get_command_for_logging(self, message: str) -> (str, list[str] | None, str | None): diff --git a/src/commands/help.py b/src/commands/help.py index 63c9607..0f44568 100644 --- a/src/commands/help.py +++ b/src/commands/help.py @@ -2,25 +2,56 @@ from src.bot import MeshtasticBot from src.commands.command import AbstractCommandWithSubcommands +from src.helpers import get_env_bool class HelpCommand(AbstractCommandWithSubcommands): def __init__(self, bot: MeshtasticBot): super().__init__(bot, 'help') - self.sub_commands['hello'] = self.handle_hello - self.sub_commands['ping'] = self.handle_ping - self.sub_commands['nodes'] = self.handle_nodes - self.sub_commands['whoami'] = self.handle_whoami - self.sub_commands['prefs'] = self.handle_prefs - # self.sub_commands['enroll'] = self.handle_enroll - # self.sub_commands['leave'] = self.handle_leave + if get_env_bool('ENABLE_COMMAND_HELLO', True): + self.sub_commands['hello'] = self.handle_hello + if get_env_bool('ENABLE_COMMAND_PING', True): + self.sub_commands['ping'] = self.handle_ping + if get_env_bool('ENABLE_COMMAND_TR', True): + self.sub_commands['tr'] = self.handle_tr + if get_env_bool('ENABLE_COMMAND_NODES', True): + self.sub_commands['nodes'] = self.handle_nodes + if get_env_bool('ENABLE_COMMAND_WHOAMI', True): + self.sub_commands['whoami'] = self.handle_whoami + if get_env_bool('ENABLE_COMMAND_PREFS', True): + self.sub_commands['prefs'] = self.handle_prefs + if get_env_bool('ENABLE_COMMAND_STATUS', True): + self.sub_commands['status'] = self.handle_status + if get_env_bool('ENABLE_COMMAND_ADMIN', True): + self.sub_commands['admin'] = self.handle_admin + # if get_env_bool('ENABLE_COMMAND_ENROLL', True): + # self.sub_commands['enroll'] = self.handle_enroll + # if get_env_bool('ENABLE_COMMAND_LEAVE', True): + # self.sub_commands['leave'] = self.handle_leave def handle_base_command(self, packet: MeshPacket, args: list[str]) -> None: subcmds = self.sub_commands.keys() subcmds = filter(None, subcmds) # remove empty strings subcmds = [f"!{cmd}" for cmd in subcmds] - response = f"Valid commands are: {', '.join(subcmds)}" + public_cmds = [] + if get_env_bool('ENABLE_COMMAND_TR', True): + public_cmds.append("!tr") + if get_env_bool('ENABLE_COMMAND_PING', True): + public_cmds.append("!ping") + if get_env_bool('ENABLE_COMMAND_HELLO', True): + public_cmds.append("!hello") + if get_env_bool('ENABLE_COMMAND_NODES', True): + public_cmds.append("!nodes") + if get_env_bool('ENABLE_COMMAND_STATUS', True): + public_cmds.append("!status") + if get_env_bool('ENABLE_COMMAND_WHOAMI', True): + public_cmds.append("!whoami") + + response = f"Available via Direct Message: {', '.join(subcmds)}." + if public_cmds: + response += f"\nAvailable in Public Channels: {', '.join(public_cmds)} (replies via DM)." + self.reply(packet, response) def handle_hello(self, packet: MeshPacket, args: list[str]) -> None: @@ -31,6 +62,10 @@ def handle_ping(self, packet: MeshPacket, args: list[str]) -> None: response = "!ping (+ optional correlation message): responds with a pong" self.reply(packet, response) + def handle_tr(self, packet: MeshPacket, args: list[str]) -> None: + response = "!tr: responds with the number of hops and signal strength of your message" + self.reply(packet, response) + def handle_nodes(self, packet: MeshPacket, args: list[str]) -> None: response = "!nodes: details about the nodes this device has seen" self.reply(packet, response) @@ -55,5 +90,13 @@ def handle_leave(self, packet: MeshPacket, args: list[str]) -> None: response = "!leave: bot will not respond to you on public channels" self.reply(packet, response) + def handle_status(self, packet: MeshPacket, args: list[str]) -> None: + response = "!status: show current bot and proxy health status" + self.reply(packet, response) + + def handle_admin(self, packet: MeshPacket, args: list[str]) -> None: + response = "!admin: admin commands (restricted)" + self.reply(packet, response) + def get_command_for_logging(self, message: str) -> (str, list[str] | None, str | None): return self._gcfl_base_command_and_args(message) diff --git a/src/commands/nodes.py b/src/commands/nodes.py index dcd5872..a688f6a 100644 --- a/src/commands/nodes.py +++ b/src/commands/nodes.py @@ -1,3 +1,4 @@ +from datetime import datetime, timezone from meshtastic.protobuf.mesh_pb2 import MeshPacket from src.bot import MeshtasticBot @@ -13,6 +14,7 @@ class NodesCommand(AbstractCommandWithSubcommands): def __init__(self, bot: MeshtasticBot): super().__init__(bot, 'nodes') self.sub_commands['busy'] = self.handle_busy + self.sub_commands['totals'] = self.handle_totals def get_busy_nodes(self) -> list[MeshNode.User]: return sorted(self.bot.node_db.list_nodes(), @@ -24,8 +26,8 @@ def handle_base_command(self, packet: MeshPacket, args: list[str]) -> None: online_nodes = self.bot.node_info.get_online_nodes() offline_nodes = self.bot.node_info.get_offline_nodes() - # get nodes sorted by last_head - sorted_nodes = sorted(nodes, key=lambda n: self.bot.node_info.get_last_heard(n.id), reverse=True) + # get nodes sorted by last_head, handling None values (sort them to the bottom) + sorted_nodes = sorted(nodes, key=lambda n: self.bot.node_info.get_last_heard(n.id) or datetime.min.replace(tzinfo=timezone.utc), reverse=True) response = f"{len(online_nodes)} nodes online, {len(offline_nodes)} offline." # Add up to 10 nodes with the most packets received today @@ -93,10 +95,22 @@ def send_detailed_nodeinfo(self, sender: str, node_id: str): self.reply_to(sender, response) + def handle_totals(self, packet: MeshPacket, args: list[str]) -> None: + from_id = packet['fromId'] + # If the user provides a channel index, use it to send the report there + if args and args[0].isdigit(): + channel_index = int(args[0]) + self.bot.report_node_count(channel_index=channel_index) + self.reply(packet, f"Node count report sent to channel {channel_index}.") + else: + # By default, just reply to the user with the count in a DM + self.bot.report_node_count(destination=from_id) + def show_help(self, packet: MeshPacket, args: list[str]) -> None: help_text = "!nodes: details about nodes this device has seen\n" help_text += "!nodes busy: summary of busiest nodes\n" help_text += "!nodes busy detailed: detailed info about busiest nodes\n" + help_text += "!nodes totals: report current online node count\n" self.reply(packet, help_text) def get_command_for_logging(self, message: str) -> (str, list[str] | None, str | None): diff --git a/src/commands/ping.py b/src/commands/ping.py index 4d84317..54585a2 100644 --- a/src/commands/ping.py +++ b/src/commands/ping.py @@ -1,3 +1,4 @@ +import logging from meshtastic.protobuf.mesh_pb2 import MeshPacket from src.commands.command import AbstractCommand @@ -9,9 +10,12 @@ def __init__(self, bot): def handle_packet(self, packet: MeshPacket) -> None: message = packet['decoded']['text'] - hops_away = packet['hopStart'] - packet['hopLimit'] + + hop_start = packet.get('hopStart', 0) + hop_limit = packet.get('hopLimit', 0) + hops_away = hop_start - hop_limit - self.react_in_dm(packet, "🏓") + # self.react_in_dm(packet, "🏓") # trim off the '!ping' command from the message additional = message[5:].strip() @@ -21,7 +25,8 @@ def handle_packet(self, packet: MeshPacket) -> None: response = f"!pong: {additional}" response += f" (ping took {hops_away} hops)" + self.reply_in_dm(packet, response) def get_command_for_logging(self, message: str) -> (str, list[str] | None, str | None): - return self._gcfl_base_command_and_args(message) + return self._gcfl_base_command_and_args(message) \ No newline at end of file diff --git a/src/commands/status.py b/src/commands/status.py new file mode 100644 index 0000000..644c39a --- /dev/null +++ b/src/commands/status.py @@ -0,0 +1,47 @@ +import logging +from datetime import datetime, timezone +from src.commands.command import AbstractCommand + +class StatusCommand(AbstractCommand): + def __init__(self, bot): + super().__init__(bot, "!status") + + def handle_packet(self, packet): + from_id = packet.get('fromId') + + # Calculate Bot Uptime + uptime = datetime.now(timezone.utc) - self.bot.start_time + days = uptime.days + hours, remainder = divmod(uptime.seconds, 3600) + minutes, seconds = divmod(remainder, 60) + uptime_str = f"{days}d {hours}h {minutes}m" + + # Get Proxy Status + proxy_info = "Disabled" + if self.bot.proxy: + status = self.bot.proxy.get_status() + if isinstance(status, dict): + state = "Online" if status['connected'] else "Reconnecting" + proxy_info = f"{state}, {status['clients']} clients, {status['cached_packets']} pkts cached, last radio {status['silence_secs']}s ago" + else: + proxy_info = status + + # Get Storage API status + storage_info = "Not Configured" + if self.bot.storage_apis: + # We'll just report if at least one is configured + storage_info = f"{len(self.bot.storage_apis)} API(s) active" + + response = ( + f"🤖 Bot Status:\n" + f"⏱ Uptime: {uptime_str}\n" + f"🔌 Proxy: {proxy_info}\n" + f"☁️ Storage: {storage_info}" + ) + + logging.info(f"Sending status to {from_id}") + self.reply_in_dm(packet, response) + + def get_command_for_logging(self, message: str) -> (str, list[str] | None, str | None): + return self._gcfl_just_base_command(message) + diff --git a/src/commands/tr.py b/src/commands/tr.py new file mode 100644 index 0000000..c1b3623 --- /dev/null +++ b/src/commands/tr.py @@ -0,0 +1,115 @@ +import logging +import threading +import time +from meshtastic.protobuf.mesh_pb2 import MeshPacket + +from src.commands.command import AbstractCommand + + +class TracerouteCommand(AbstractCommand): + def __init__(self, bot): + super().__init__(bot, 'tr') + + def handle_packet(self, packet: MeshPacket) -> None: + message = packet['decoded']['text'] + words = message.split() + + is_public = packet.get('toId') == '^all' or 'channel' in packet + + def send_reply(msg): + # Always reply in DM + self.reply_in_dm(packet, msg, want_ack=True) + + # Add a reaction (thumbs up for public to acknowledge without spamming, hourglass for DM) + reaction_emoji = "👍" if is_public else "⌛" + reaction_dest = packet.get('toId') if is_public else packet.get('fromId') + logging.info(f"Adding reaction {reaction_emoji} for packet {packet.get('id')} to {reaction_dest}") + self.bot.interface.sendReaction(reaction_emoji, messageId=packet['id'], destinationId=reaction_dest) + + requester_id = packet.get('fromId') + requester = self.bot.node_db.get_by_id(requester_id) + requester_name = requester.long_name if requester else requester_id + + target_node = None + if len(words) > 1: + target_short = words[1] + target_node = self.bot.get_node_by_short_name(target_short) + if not target_node: + send_reply(f"Could not find node with short name '{target_short}'") + return + target_id = target_node.id + target_long_name = target_node.long_name + else: + target_id = requester_id + target_long_name = requester_name + + if target_id == self.bot.my_id: + send_reply("I am already here! No traceroute required.") + return + + # If tracing back to requester, we can show hops_away/SNR from the incoming packet + if target_id == requester_id: + hop_start = packet.get('hopStart', 0) + hop_limit = packet.get('hopLimit', 0) + hops_away = hop_start - hop_limit + snr = packet.get('rxSnr', 0.0) + + # We can log this, but no need to send it explicitly over the radio to save airtime + logging.info(f"Detected {hops_away} hops for {target_id}. SNR: {snr}dB.") + else: + # Tracing to a different node + logging.info(f"Starting traceroute to {target_long_name} ({target_id}) for you...") + + # Store for the callback + if target_id not in self.bot.pending_traces: + self.bot.pending_traces[target_id] = [] + + # Store context: force is_public=False so bot.py always replies via DM + to_id = packet.get('toId') + channel_index = packet.get('channel', 0) + context = (requester_id, False, to_id, channel_index) + + if context not in self.bot.pending_traces[target_id]: + self.bot.pending_traces[target_id].append(context) + + # Start a timeout timer (120 seconds) + def check_timeout(): + time.sleep(120) + if target_id in self.bot.pending_traces: + # Find and remove this specific context from the pending list + self.bot.pending_traces[target_id] = [c for c in self.bot.pending_traces[target_id] if c[0] != requester_id] + # If no more requesters for this target, clean up the key + if not self.bot.pending_traces[target_id]: + del self.bot.pending_traces[target_id] + + logging.info(f"Traceroute to {target_id} (requested by {requester_id}) timed out.") + timeout_msg = f"Traceroute to {target_long_name} ({target_id}) timed out (no response from mesh)." + + # Send the timeout message in a separate thread to avoid blocking the timer/interface + def send_timeout(): + self.message_in_dm(requester_id, timeout_msg, want_ack=True) + + threading.Thread(target=send_timeout, daemon=True).start() + + threading.Thread(target=check_timeout, daemon=True).start() + + try: + # Let the reaction settle before firing the trace + time.sleep(2) + logging.info(f"Initiating traceroute to {target_id} requested by {requester_id}") + # hopLimit=7 is standard max + p = self.bot.interface.sendTraceRoute(target_id, hopLimit=7) + if p: + logging.info(f"Sent traceroute packet to {target_id}. Packet ID: {p.id}") + else: + logging.warning(f"sendTraceRoute returned None for {target_id}") + except Exception as e: + logging.error(f"Failed to send traceroute to {target_id}: {e}") + if target_id in self.bot.pending_traces and requester_id in self.bot.pending_traces[target_id]: + self.bot.pending_traces[target_id].remove(requester_id) + if not self.bot.pending_traces[target_id]: + del self.bot.pending_traces[target_id] + send_reply(f"Error starting traceroute: {e}") + + def get_command_for_logging(self, message: str) -> (str, list[str] | None, str | None): + return self._gcfl_base_command_and_args(message) diff --git a/src/data_classes.py b/src/data_classes.py index 9ed28b3..4171c7e 100644 --- a/src/data_classes.py +++ b/src/data_classes.py @@ -19,6 +19,16 @@ def __init__(self, self.hw_model = hw_model self.public_key = public_key + def __eq__(self, other): + if not isinstance(other, MeshNode.User): + return False + return (self.id == other.id and + self.long_name == other.long_name and + self.short_name == other.short_name and + self.macaddr == other.macaddr and + self.hw_model == other.hw_model and + self.public_key == other.public_key) + id: str long_name: str short_name: str diff --git a/src/helpers.py b/src/helpers.py index d3d4c81..fd5935f 100644 --- a/src/helpers.py +++ b/src/helpers.py @@ -1,9 +1,30 @@ +import os import string import urllib.parse from datetime import datetime, timezone -def pretty_print_last_heard(last_heard_timestamp: int | datetime) -> str: +def get_env_bool(name: str, default: bool = True) -> bool: + value = os.getenv(name) + if value is None: + return default + return value.lower() in ('true', '1', 't', 'y', 'yes') + + +def get_env_int(name: str, default: int) -> int: + value = os.getenv(name) + if value is None: + return default + try: + return int(value) + except (ValueError, TypeError): + return default + + +def pretty_print_last_heard(last_heard_timestamp: int | datetime | None) -> str: + if not last_heard_timestamp: + return "never" + if not isinstance(last_heard_timestamp, datetime): last_heard = datetime.fromtimestamp(last_heard_timestamp, timezone.utc) else: diff --git a/src/main.py b/src/main.py index 69c3bb9..bfb6877 100644 --- a/src/main.py +++ b/src/main.py @@ -1,6 +1,7 @@ import logging import os import sys +import time from pathlib import Path from dotenv import load_dotenv @@ -21,15 +22,24 @@ # Now we can import the rest of our local files from src.api.StorageAPI import StorageAPIWrapper from src.bot import MeshtasticBot +from src.helpers import get_env_bool from src.ws_client import MeshflowWSClient from src.persistence.commands_logger import SqliteCommandLogger from src.persistence.node_info import InMemoryNodeInfoStore from src.persistence.node_db import SqliteNodeDB from src.persistence.user_prefs import SqliteUserPrefsPersistence +from src.tcp_proxy import TcpProxy # Get the IP address and admin nodes from environment variables MESHTASTIC_IP = os.getenv("MESHTASTIC_IP") -ADMIN_NODES = os.getenv("ADMIN_NODES").split(',') +# Safely handle missing or empty ADMIN_NODES +admin_nodes_raw = os.getenv("ADMIN_NODES") or "" +ADMIN_NODES = [node.strip() for node in admin_nodes_raw.split(',') if node.strip()] + +ENABLE_TCP_PROXY = get_env_bool("ENABLE_TCP_PROXY", True) +PROXY_HANDSHAKE_CACHE_SIZE = int(os.getenv("PROXY_HANDSHAKE_CACHE_SIZE", 100)) +PROXY_ROLLING_CACHE_SIZE = int(os.getenv("PROXY_ROLLING_CACHE_SIZE", 100)) + DATA_DIR = os.getenv("DATA_DIR", "data") STORAGE_API_ROOT = os.getenv("STORAGE_API_ROOT") STORAGE_API_TOKEN = os.getenv("STORAGE_API_TOKEN", None) @@ -57,8 +67,37 @@ def main(): node_info_file = data_dir / 'node_info.json' failed_packets_dir = data_dir / 'failed_packets' - # Connect to the Meshtastic node over WiFi - bot = MeshtasticBot(MESHTASTIC_IP) + logging.info(f"--- Configuration ---") + logging.info(f"MESHTASTIC_IP: {MESHTASTIC_IP}") + logging.info(f"ENABLE_TCP_PROXY: {ENABLE_TCP_PROXY}") + logging.info(f"PROXY_HANDSHAKE_CACHE_SIZE: {PROXY_HANDSHAKE_CACHE_SIZE}") + logging.info(f"PROXY_ROLLING_CACHE_SIZE: {PROXY_ROLLING_CACHE_SIZE}") + logging.info(f"ENABLE_FEATURE_NODE_TOTALS: {get_env_bool('ENABLE_FEATURE_NODE_TOTALS', True)}") + logging.info(f"FREQUENCY_OF_NODE_REPORTS: {os.getenv('FREQUENCY_OF_NODE_REPORTS', '3')} hours") + logging.info(f"CHANNEL_FOR_NODE_TOTAL_BROADCAST: {os.getenv('CHANNEL_FOR_NODE_TOTAL_BROADCAST', '2')}") + logging.info(f"ENABLE_COMMAND_PING: {get_env_bool('ENABLE_COMMAND_PING', True)}") + logging.info(f"ENABLE_COMMAND_TR: {get_env_bool('ENABLE_COMMAND_TR', True)}") + logging.info(f"IGNORE_PORTNUMS: {list(IGNORE_PORTNUMS)}") + logging.info(f"STORAGE_API_ROOT: {STORAGE_API_ROOT}") + if STORAGE_API_2_ROOT: + logging.info(f"STORAGE_API_2_ROOT: {STORAGE_API_2_ROOT}") + logging.info(f"---------------------") + + proxy = None + if ENABLE_TCP_PROXY: + # Start the TCP Proxy + # It listens on 0.0.0.0:4403 and forwards to MESHTASTIC_IP:4403 + proxy = TcpProxy(target_host=MESHTASTIC_IP, target_port=4403, listen_host='0.0.0.0', listen_port=4403, handshake_cache_size=PROXY_HANDSHAKE_CACHE_SIZE, rolling_cache_size=PROXY_ROLLING_CACHE_SIZE) + proxy.start() + + # Give the proxy a moment to bind to the port before the bot tries to connect + time.sleep(2) + + # Connect to the Meshtastic node + # Use 'localhost' if proxy is enabled, otherwise connect directly + connection_address = 'localhost' if ENABLE_TCP_PROXY else MESHTASTIC_IP + bot = MeshtasticBot(connection_address) + bot.proxy = proxy bot.ignore_portnums = IGNORE_PORTNUMS bot.admin_nodes = ADMIN_NODES bot.user_prefs_persistence = SqliteUserPrefsPersistence(str(user_prefs_file)) diff --git a/src/persistence/__init__.py b/src/persistence/__init__.py index a7d2976..2797e70 100644 --- a/src/persistence/__init__.py +++ b/src/persistence/__init__.py @@ -1,20 +1,33 @@ -import abc -import logging -from pathlib import Path - - -class BaseSqlitePersistenceStore(abc.ABC): - db_path: Path - - def __init__(self, db_path: str): - self.db_path = Path(db_path) - self._initialize_db() - if self.db_path.is_relative_to(Path.cwd()): - path_string = self.db_path.relative_to(Path.cwd()) - else: - path_string = self.db_path - logging.info(f"Connected to {self.__class__.__name__} DB at {path_string}") - - @abc.abstractmethod - def _initialize_db(self): - pass +import abc +import logging +import sqlite3 +import threading +from contextlib import contextmanager +from pathlib import Path + + +class BaseSqlitePersistenceStore(abc.ABC): + db_path: Path + + def __init__(self, db_path: str): + self.db_path = Path(db_path) + self._lock = threading.RLock() + self._initialize_db() + if self.db_path.is_relative_to(Path.cwd()): + path_string = self.db_path.relative_to(Path.cwd()) + else: + path_string = self.db_path + logging.info(f"Connected to {self.__class__.__name__} DB at {path_string}") + + @contextmanager + def _get_connection(self): + """Returns a thread-safe sqlite3 connection and ensures it is closed.""" + conn = sqlite3.connect(self.db_path, check_same_thread=False) + try: + yield conn + finally: + conn.close() + + @abc.abstractmethod + def _initialize_db(self): + pass diff --git a/src/persistence/commands_logger.py b/src/persistence/commands_logger.py index ca57d89..09edc5e 100644 --- a/src/persistence/commands_logger.py +++ b/src/persistence/commands_logger.py @@ -37,7 +37,7 @@ def get_responder_history(self, since: datetime, sender_id: str = None) -> pd.Da class SqliteCommandLogger(AbstractCommandLogger, BaseSqlitePersistenceStore): def _initialize_db(self): - with sqlite3.connect(self.db_path) as conn: + with self._lock, self._get_connection() as conn: cursor = conn.cursor() cursor.execute(''' CREATE TABLE IF NOT EXISTS command_log ( @@ -70,7 +70,7 @@ def log_command(self, sender_id: str, command_instance, message: str) -> None: base_cmd, subcommands, args = command_instance.get_command_for_logging(message) subcommands_str = ' '.join(subcommands) if subcommands else None - with sqlite3.connect(self.db_path) as conn: + with self._lock, self._get_connection() as conn: cursor = conn.cursor() cursor.execute(''' INSERT INTO command_log (sender_id, base_command, sub_commands, args, timestamp, handler_class) @@ -80,7 +80,7 @@ def log_command(self, sender_id: str, command_instance, message: str) -> None: conn.commit() def log_responder_handled(self, sender_id: str, responder_instance, message_text: str) -> None: - with sqlite3.connect(self.db_path) as conn: + with self._lock, self._get_connection() as conn: cursor = conn.cursor() cursor.execute(''' INSERT INTO responder_log (sender_id, message, timestamp, responder_class) @@ -89,7 +89,7 @@ def log_responder_handled(self, sender_id: str, responder_instance, message_text conn.commit() def log_unknown_request(self, sender_id: str, message: str) -> None: - with sqlite3.connect(self.db_path) as conn: + with self._lock, self._get_connection() as conn: cursor = conn.cursor() cursor.execute(''' INSERT INTO unknown_requests (sender_id, message, timestamp) @@ -98,7 +98,7 @@ def log_unknown_request(self, sender_id: str, message: str) -> None: conn.commit() def get_command_history(self, since: datetime, sender_id: str = None) -> pd.DataFrame: - with sqlite3.connect(self.db_path) as conn: + with self._lock, self._get_connection() as conn: cursor = conn.cursor() if sender_id: cursor.execute(''' @@ -114,7 +114,7 @@ def get_command_history(self, since: datetime, sender_id: str = None) -> pd.Data return pd.DataFrame(rows, columns=['sender_id', 'base_command', 'timestamp']) def get_unknown_command_history(self, since: datetime, sender_id: str = None) -> pd.DataFrame: - with sqlite3.connect(self.db_path) as conn: + with self._lock, self._get_connection() as conn: cursor = conn.cursor() if sender_id: cursor.execute(''' @@ -130,7 +130,7 @@ def get_unknown_command_history(self, since: datetime, sender_id: str = None) -> return pd.DataFrame(rows, columns=['sender_id', 'message', 'timestamp']) def get_responder_history(self, since: datetime, sender_id: str = None) -> pd.DataFrame: - with sqlite3.connect(self.db_path) as conn: + with self._lock, self._get_connection() as conn: cursor = conn.cursor() if sender_id: cursor.execute(''' diff --git a/src/persistence/node_db.py b/src/persistence/node_db.py index 955cb23..e6d3726 100644 --- a/src/persistence/node_db.py +++ b/src/persistence/node_db.py @@ -117,7 +117,7 @@ def get_device_metrics_log(self, node_id: str, start: datetime, end: datetime) - class SqliteNodeDB(BaseSqlitePersistenceStore, AbstractNodeDB): def _initialize_db(self): - with sqlite3.connect(self.db_path) as conn: + with self._lock, self._get_connection() as conn: cursor = conn.cursor() cursor.execute(''' CREATE TABLE IF NOT EXISTS nodes ( @@ -156,7 +156,7 @@ def _initialize_db(self): conn.commit() def store_user(self, node_user: MeshNode.User): - with sqlite3.connect(self.db_path) as conn: + with self._lock, self._get_connection() as conn: cursor = conn.cursor() cursor.execute(''' INSERT OR REPLACE INTO nodes (id, short_name, long_name, macaddr, hw_model, public_key) @@ -166,7 +166,7 @@ def store_user(self, node_user: MeshNode.User): conn.commit() def store_position(self, node_id: str, position: MeshNode.Position): - with sqlite3.connect(self.db_path) as conn: + with self._lock, self._get_connection() as conn: cursor = conn.cursor() cursor.execute(''' INSERT INTO positions (node_id, logged_time, reported_time, latitude, longitude, altitude, location_source) @@ -176,7 +176,7 @@ def store_position(self, node_id: str, position: MeshNode.Position): conn.commit() def store_device_metrics(self, node_id: str, device_metrics: MeshNode.DeviceMetrics): - with sqlite3.connect(self.db_path) as conn: + with self._lock, self._get_connection() as conn: cursor = conn.cursor() cursor.execute(''' INSERT INTO device_metrics (node_id, logged_time, battery_level, voltage, channel_utilization, air_util_tx, uptime_seconds) @@ -186,7 +186,7 @@ def store_device_metrics(self, node_id: str, device_metrics: MeshNode.DeviceMetr conn.commit() def get_by_id(self, node_id: str) -> MeshNode.User | None: - with sqlite3.connect(self.db_path) as conn: + with self._lock, self._get_connection() as conn: cursor = conn.cursor() cursor.execute('SELECT id, short_name, long_name, macaddr, hw_model, public_key FROM nodes WHERE id = ?', (node_id,)) @@ -197,7 +197,7 @@ def get_by_id(self, node_id: str) -> MeshNode.User | None: return None def get_by_short_name(self, short_name: str) -> MeshNode.User | None: - with sqlite3.connect(self.db_path) as conn: + with self._lock, self._get_connection() as conn: cursor = conn.cursor() cursor.execute( 'SELECT id, short_name, long_name, macaddr, hw_model, public_key FROM nodes WHERE short_name = ? COLLATE NOCASE', @@ -209,7 +209,7 @@ def get_by_short_name(self, short_name: str) -> MeshNode.User | None: return None def list_nodes(self) -> list[MeshNode.User]: - with sqlite3.connect(self.db_path) as conn: + with self._lock, self._get_connection() as conn: cursor = conn.cursor() cursor.execute('SELECT id, short_name, long_name, macaddr, hw_model, public_key FROM nodes') rows = cursor.fetchall() @@ -217,7 +217,7 @@ def list_nodes(self) -> list[MeshNode.User]: hw_model=row[4], public_key=row[5]) for row in rows] def get_last_position(self, node_id: str) -> MeshNode.Position | None: - with sqlite3.connect(self.db_path) as conn: + with self._lock, self._get_connection() as conn: cursor = conn.cursor() cursor.execute(''' SELECT logged_time, reported_time, latitude, longitude, altitude, location_source @@ -234,7 +234,7 @@ def get_last_position(self, node_id: str) -> MeshNode.Position | None: def get_position_log(self, node_id: str, start: datetime, end: datetime) -> list[ MeshNode.Position]: - with sqlite3.connect(self.db_path) as conn: + with self._lock, self._get_connection() as conn: cursor = conn.cursor() cursor.execute(''' SELECT logged_time, reported_time, latitude, longitude, altitude, location_source @@ -247,7 +247,7 @@ def get_position_log(self, node_id: str, start: datetime, end: datetime) -> list altitude=row[4], location_source=row[5]) for row in rows] def get_last_device_metrics(self, node_id: str) -> MeshNode.DeviceMetrics | None: - with sqlite3.connect(self.db_path) as conn: + with self._lock, self._get_connection() as conn: cursor = conn.cursor() cursor.execute(''' SELECT logged_time, battery_level, voltage, channel_utilization, air_util_tx, uptime_seconds @@ -264,7 +264,7 @@ def get_last_device_metrics(self, node_id: str) -> MeshNode.DeviceMetrics | None def get_device_metrics_log(self, node_id: str, start: datetime, end: datetime) -> list[ MeshNode.DeviceMetrics]: - with sqlite3.connect(self.db_path) as conn: + with self._lock, self._get_connection() as conn: cursor = conn.cursor() cursor.execute(''' SELECT logged_time, battery_level, voltage, channel_utilization, air_util_tx, uptime_seconds diff --git a/src/persistence/node_info.py b/src/persistence/node_info.py index 960c7ce..9a4e2c0 100644 --- a/src/persistence/node_info.py +++ b/src/persistence/node_info.py @@ -100,11 +100,11 @@ def reset_packets_today(self) -> None: def get_online_nodes(self) -> dict[str, datetime]: return {node_id: last_heard for node_id, last_heard in self.nodes_last_heard.items() - if last_heard > datetime.now(timezone.utc) - timedelta(seconds=self.online_threshold_sec)} + if last_heard and last_heard > datetime.now(timezone.utc) - timedelta(seconds=self.online_threshold_sec)} def get_offline_nodes(self) -> dict[str, datetime]: return {node_id: last_heard for node_id, last_heard in self.nodes_last_heard.items() - if last_heard <= datetime.now(timezone.utc) - timedelta(seconds=self.online_threshold_sec)} + if not last_heard or last_heard <= datetime.now(timezone.utc) - timedelta(seconds=self.online_threshold_sec)} def get_all_nodes(self) -> dict[str, datetime]: return self.nodes_last_heard @@ -115,7 +115,7 @@ def load_from_file(self, node_info_file: str) -> None: with open(node_info_file, 'r') as file: data = json.load(file) - self.nodes_last_heard = {k: datetime.fromisoformat(v) for k, v in data['nodes_last_heard'].items()} + self.nodes_last_heard = {k: (datetime.fromisoformat(v) if v else None) for k, v in data['nodes_last_heard'].items()} self.node_packets_today = data['node_packets_today'] self.node_packets_today_breakdown = data['node_packets_today_breakdown'] diff --git a/src/persistence/user_prefs.py b/src/persistence/user_prefs.py index 5888fbd..af3bc70 100644 --- a/src/persistence/user_prefs.py +++ b/src/persistence/user_prefs.py @@ -51,7 +51,7 @@ def persist_user_prefs(self, user_id: str, user_prefs: UserPrefs): class SqliteUserPrefsPersistence(AbstractUserPrefsPersistence, BaseSqlitePersistenceStore): def _initialize_db(self): - with sqlite3.connect(self.db_path) as conn: + with self._lock, self._get_connection() as conn: cursor = conn.cursor() cursor.execute(''' CREATE TABLE IF NOT EXISTS user_prefs ( @@ -66,7 +66,7 @@ def _initialize_db(self): conn.commit() def get_user_prefs(self, user_id: str) -> UserPrefs: - with sqlite3.connect(self.db_path) as conn: + with self._lock, self._get_connection() as conn: # Fetch the data cursor = conn.cursor() cursor.execute(''' @@ -91,7 +91,7 @@ def get_user_prefs(self, user_id: str) -> UserPrefs: return user_prefs def persist_user_prefs(self, user_id: str, user_prefs: UserPrefs) -> UserPrefs: - with sqlite3.connect(self.db_path) as conn: + with self._lock, self._get_connection() as conn: cursor = conn.cursor() for key, preference in user_prefs.__dict__.items(): if key == 'user_id': diff --git a/src/tcp_interface.py b/src/tcp_interface.py index 8572d56..ab87eeb 100644 --- a/src/tcp_interface.py +++ b/src/tcp_interface.py @@ -1,5 +1,6 @@ import logging import sys +from pubsub import pub import time from queue import Queue from typing import Optional, Callable, Union @@ -29,7 +30,7 @@ def sendReaction( packet.decoded.portnum = portNum packet.decoded.payload = emoji_bytes packet.decoded.reply_id = messageId - packet.decoded.emoji = True + packet.decoded.emoji = ord(emoji) if isinstance(emoji, str) else 1 self._sendPacket(packet, destinationId, wantAck=wantAck, @@ -61,13 +62,42 @@ def __init__(self, *args, # Store packets in a queue and resend them after reconnecting # This will involve exposing the queue, and reloading the queue in bot.py since we create a new interface object + def onResponseTraceRoute(self, packet): + """ + Callback for when a traceroute response is received. + """ + try: + route_discovery = None + if isinstance(packet, dict): + decoded = packet.get('decoded', {}) + # It might be in 'routing', 'routing_app', or 'traceroute' + route_discovery = decoded.get('routing') or decoded.get('routing_app') or decoded.get('traceroute') + + if not route_discovery and 'payload' in decoded: + logging.debug(f"onResponseTraceRoute: Route not found in decoded, full packet: {packet}") + elif hasattr(packet, 'decoded'): + route_discovery = getattr(packet.decoded, 'routing', + getattr(packet.decoded, 'routing_app', + getattr(packet.decoded, 'traceroute', None))) + + logging.info(f"onResponseTraceRoute: Received traceroute response. Route data present: {route_discovery is not None}") + logging.info(f"DEBUG: Traceroute packet keys: {packet.keys() if isinstance(packet, dict) else 'not a dict'}") + + # Always call super to allow library internal processing (printing to stdout etc) + super().onResponseTraceRoute(packet) + + # Notify bot logic + pub.sendMessage("meshtastic.traceroute", packet=packet, route=route_discovery) + except Exception as e: + logging.error(f"Error in onResponseTraceRoute: {e}", exc_info=True) + def sendHeartbeat(self): try: super().sendHeartbeat() except (OSError, BrokenPipeError) as e: logging.error(f"Heartbeat failed: {e}") - # TODO: Decide if we want to handle the error on this thread - # self._reconnect_with_backoff() + # Shutdown and notify the error handler to trigger a clean restart from the main thread. + # This avoids nested reconnection attempts on the heartbeat thread. self._shutdown_and_call_error_handler() def _sendPacket( @@ -79,6 +109,8 @@ def _sendPacket( pkiEncrypted: Optional[bool] = False, publicKey: Optional[bytes] = None, ): + port_val = meshPacket.decoded.portnum + logging.info(f"_sendPacket: Attempting to send Port {port_val} to {destinationId} (wantAck={wantAck})") try: super()._sendPacket( meshPacket=meshPacket, @@ -88,11 +120,15 @@ def _sendPacket( pkiEncrypted=pkiEncrypted, publicKey=publicKey ) + logging.info(f"_sendPacket: Successfully handed Port {port_val} to {destinationId} to meshtastic library") except (OSError, BrokenPipeError) as e: - logging.error(f"sendPacket failed: {e}") + logging.error(f"_sendPacket failed (connection error): {e}") self.packet_queue.put((meshPacket, destinationId, wantAck, hopLimit, pkiEncrypted, publicKey)) - # self._reconnect_with_backoff() self._shutdown_and_call_error_handler(e) + except Exception as e: + logging.error(f"_sendPacket failed (unexpected error): {e}", exc_info=True) + # We still queue it just in case it's recoverable + self.packet_queue.put((meshPacket, destinationId, wantAck, hopLimit, pkiEncrypted, publicKey)) def _shutdown_and_call_error_handler(self, conn_error: Optional[Exception] = None): try: diff --git a/src/tcp_proxy.py b/src/tcp_proxy.py new file mode 100644 index 0000000..8a6e9ee --- /dev/null +++ b/src/tcp_proxy.py @@ -0,0 +1,270 @@ +import asyncio +import logging +import time +from collections import deque +import threading + +class TcpProxy: + def __init__(self, target_host, target_port=4403, listen_host='0.0.0.0', listen_port=4403, handshake_cache_size=100, rolling_cache_size=100): + self.target_host = target_host + self.target_port = int(target_port) + self.listen_host = listen_host + self.listen_port = int(listen_port) + + self.server = None + self.target_reader = None + self.target_writer = None + + self.clients = set() + + self.running = False + self.loop = None + self.thread = None + + self.handshake_packets = [] + self.handshake_max_count = handshake_cache_size + self.rolling_packets = deque(maxlen=rolling_cache_size) + + self.last_target_activity = time.time() + self.reconnecting = False + + def start(self): + self.running = True + self.thread = threading.Thread(target=self._run_loop, daemon=True) + self.thread.start() + + def stop(self): + self.running = False + if self.loop: + self.loop.call_soon_threadsafe(self._stop_loop) + + def _stop_loop(self): + if self.server: + self.server.close() + for writer in self.clients: + try: writer.close() + except: pass + if self.target_writer: + try: self.target_writer.close() + except: pass + + def get_status(self): + if not self.running: + return "Proxy: Offline" + + silence = time.time() - self.last_target_activity + client_count = len(self.clients) + cached_count = len(self.handshake_packets) + len(self.rolling_packets) + + state = "Reconnecting" if self.reconnecting else ("Online" if self.target_writer else "Offline") + + return { + "state": state, + "connected": self.target_writer is not None and not self.reconnecting, + "clients": client_count, + "silence_secs": int(silence), + "cached_packets": cached_count + } + + def _run_loop(self): + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) + self.loop.run_until_complete(self._async_run()) + + async def _async_run(self): + logging.info(f"Starting TCP Proxy on {self.listen_host}:{self.listen_port} -> {self.target_host}:{self.target_port}") + + try: + self.server = await asyncio.start_server( + self._handle_client, self.listen_host, self.listen_port) + except Exception as e: + logging.error(f"Failed to bind proxy port {self.listen_port}: {e}") + self.running = False + return + + asyncio.create_task(self._target_connection_manager()) + asyncio.create_task(self._watchdog()) + + try: + async with self.server: + while self.running: + await asyncio.sleep(1) + except asyncio.CancelledError: + pass + finally: + self._stop_loop() + + async def _watchdog(self): + last_heartbeat_log = time.time() + while self.running: + current_time = time.time() + if self.target_writer and not self.reconnecting: + if current_time - self.last_target_activity > 300.0: + logging.warning(f"Watchdog: No data from radio for 300s. Forcing reconnect...") + try: self.target_writer.close() + except: pass + self.target_reader = None + self.target_writer = None + + if current_time - last_heartbeat_log > 60.0: + client_count = len(self.clients) + status = "Connected" if self.target_writer and not self.reconnecting else "RECONNECTING" + silence = current_time - self.last_target_activity + logging.info(f"Proxy Heartbeat: {status}. Last radio data {silence:.1f}s ago. Clients: {client_count}") + last_heartbeat_log = current_time + + await asyncio.sleep(5) + + async def _target_connection_manager(self): + backoff_time = 5.0 + max_backoff_time = 60.0 + backoff_rate = 2.0 + + while self.running: + if self.target_writer is None or self.target_reader is None: + self.reconnecting = True + self._disconnect_all_clients() + self.handshake_packets.clear() + self.rolling_packets.clear() + + try: + logging.info(f"Proxy attempting to connect to target device at {self.target_host}:{self.target_port}...") + reader, writer = await asyncio.wait_for( + asyncio.open_connection(self.target_host, self.target_port), + timeout=5.0 + ) + self.target_reader = reader + self.target_writer = writer + self.last_target_activity = time.time() + self.reconnecting = False + backoff_time = 5.0 # Reset backoff on success + logging.info(f"Proxy successfully connected to target device at {self.target_host}:{self.target_port}") + asyncio.create_task(self._read_from_target()) + except (asyncio.TimeoutError, ConnectionError, OSError) as e: + logging.error(f"Failed to connect to target ({self.target_host}): {e}. Retrying in {backoff_time:.1f}s...") + await asyncio.sleep(backoff_time) + backoff_time = min(backoff_time * backoff_rate, max_backoff_time) + except Exception as e: + logging.error(f"Unexpected error in target connection manager: {e}", exc_info=True) + await asyncio.sleep(backoff_time) + backoff_time = min(backoff_time * backoff_rate, max_backoff_time) + else: + await asyncio.sleep(1) + + def _disconnect_all_clients(self): + for writer in list(self.clients): + try: writer.close() + except: pass + self.clients.clear() + logging.info("Disconnected all proxy clients to force re-sync.") + + async def _read_from_target(self): + reader = self.target_reader + writer = self.target_writer + + in_buffer = b'' + while self.running and self.target_reader == reader: + try: + data = await reader.read(16384) + if not data: + logging.warning("Radio closed connection. Triggering re-sync...") + break + self.last_target_activity = time.time() + + in_buffer += data + + while len(in_buffer) >= 4: + if in_buffer[0:2] != b'\x94\xc3': + idx = in_buffer.find(b'\x94\xc3') + if idx == -1: + in_buffer = b'' + break + in_buffer = in_buffer[idx:] + continue + + length = (in_buffer[2] << 8) | in_buffer[3] + total_len = length + 4 + + if len(in_buffer) < total_len: + break + + packet = in_buffer[:total_len] + in_buffer = in_buffer[total_len:] + + if len(self.handshake_packets) < self.handshake_max_count: + self.handshake_packets.append(packet) + self.rolling_packets.append(packet) + + for client_writer in list(self.clients): + try: + client_writer.write(packet) + await client_writer.drain() + except Exception as e: + logging.debug(f"Failed to forward packet to client: {e}") + self._remove_client(client_writer) + except Exception as e: + logging.error(f"Error reading from radio: {e}") + break + + if self.target_writer == writer: + try: writer.close() + except: pass + self.target_writer = None + self.target_reader = None + + async def _handle_client(self, reader, writer): + addr = writer.get_extra_info('peername') + logging.info(f"+++ PROXY: New connection accepted from {addr}") + self.clients.add(writer) + + h_snapshot = list(self.handshake_packets) + r_snapshot = list(self.rolling_packets) + + if addr[0] not in ('127.0.0.1', 'localhost'): + try: + await asyncio.sleep(2.0) + for p in h_snapshot: + writer.write(p) + await writer.drain() + await asyncio.sleep(0.05) + for p in r_snapshot: + writer.write(p) + await writer.drain() + await asyncio.sleep(0.01) + logging.info(f"Replayed {len(h_snapshot) + len(r_snapshot)} packets to {addr}") + except Exception as e: + self._remove_client(writer) + return + + while self.running: + try: + data = await reader.read(16384) + if not data: + break + if self.target_writer and not self.reconnecting: + try: + self.target_writer.write(data) + await self.target_writer.drain() + except Exception as e: + logging.error(f"Error sending to radio: {e}") + try: self.target_writer.close() + except: pass + self.target_writer = None + except Exception as e: + logging.debug(f"Error receiving from client: {e}") + break + + self._remove_client(writer) + + def _remove_client(self, writer): + addr = None + try: + addr = writer.get_extra_info('peername') + logging.info(f"--- PROXY: Removing client {addr}") + except: + logging.info("--- PROXY: Removing unknown client") + + if writer in self.clients: + self.clients.remove(writer) + try: writer.close() + except: pass diff --git a/test/commands/test_help.py b/test/commands/test_help.py index db04c36..0fac70d 100644 --- a/test/commands/test_help.py +++ b/test/commands/test_help.py @@ -19,7 +19,7 @@ def test_handle_packet_no_additional_message(self): response = self.mock_interface.sendText.call_args[0][0] - skipped_commands = ['!admin'] + skipped_commands = [] # Ensure every command in CommandFactory is mentioned in the response for command in CommandFactory.commands.keys(): diff --git a/test/commands/test_nodes.py b/test/commands/test_nodes.py index f5823d0..8f8982b 100644 --- a/test/commands/test_nodes.py +++ b/test/commands/test_nodes.py @@ -11,6 +11,7 @@ class TestNodesCommand(CommandWSCTestCase): def setUp(self): super().setUp() + self.bot.init_complete = True self.command = NodesCommand(self.bot) self.online_count = len(self.bot.node_info.get_online_nodes()) @@ -32,7 +33,7 @@ def test_handle_base_command(self): friendly_time = pretty_print_last_heard(last_heard) expected_response += f"- {node.user.short_name} ({friendly_time})\n" - self.assert_message_sent(expected_response, self.test_nodes[1]) + self.assert_message_sent(expected_response, self.test_nodes[1], want_ack=True) def test_handle_busy_command(self): packet = build_test_text_packet('!nodes busy', self.test_nodes[1].user.id, self.bot.my_id) @@ -52,7 +53,7 @@ def test_handle_busy_command(self): expected_response += f"(last reset at {last_reset_time})" - self.assert_message_sent(expected_response, self.test_nodes[1]) + self.assert_message_sent(expected_response, self.test_nodes[1], want_ack=True) def test_handle_busy_detailed_command(self): packet = build_test_text_packet('!nodes busy detailed', self.test_nodes[1].user.id, self.bot.my_id) @@ -81,7 +82,32 @@ def test_handle_busy_specific_node(self): for packet_type, count in sorted_breakdown: expected_response += f"- {packet_type}: {count}\n" - self.assert_message_sent(expected_response, self.test_nodes[1]) + self.assert_message_sent(expected_response, self.test_nodes[1], want_ack=True) + + def test_handle_totals_command(self): + packet = build_test_text_packet('!nodes totals', self.test_nodes[1].user.id, self.bot.my_id) + self.command.handle_packet(packet) + + # The command calls bot.report_node_count(destination=from_id) + # which sends "MTEK has a node count of X" + online_count = len(self.bot.node_info.get_online_nodes()) + expected_message = f"MTEK has a node count of {online_count}" + + self.assert_message_sent(expected_message, self.test_nodes[1], want_ack=True) + + def test_handle_totals_channel_command(self): + packet = build_test_text_packet('!nodes totals 3', self.test_nodes[1].user.id, self.bot.my_id) + self.command.handle_packet(packet) + + # The command calls bot.report_node_count(channel_index=3) + online_count = len(self.bot.node_info.get_online_nodes()) + expected_report = f"MTEK has a node count of {online_count}" + + # It also replies to the user + expected_reply = "Node count report sent to channel 3." + + self.mock_interface.sendText.assert_any_call(expected_report, channelIndex=3, wantAck=True) + self.mock_interface.sendText.assert_any_call(expected_reply, destinationId=self.test_nodes[1].user.id, wantAck=True) if __name__ == '__main__': diff --git a/test/commands/test_tr.py b/test/commands/test_tr.py new file mode 100644 index 0000000..2be4fb3 --- /dev/null +++ b/test/commands/test_tr.py @@ -0,0 +1,93 @@ +import unittest +from unittest.mock import MagicMock, call +from src.commands.tr import TracerouteCommand +from test.commands import CommandTestCase +from test.test_setup_data import build_test_text_packet + +class TestTracerouteCommand(CommandTestCase): + command: TracerouteCommand + + def setUp(self): + super().setUp() + self.command = TracerouteCommand(bot=self.bot) + # Mock sendTraceRoute since it's used in handle_packet + self.bot.interface.sendTraceRoute = MagicMock() + + def test_handle_packet_basic(self): + # !tr from node 1 + sender_id = self.test_nodes[1].user.id + packet = build_test_text_packet('!tr', sender_id, self.bot.my_id) + packet['hopStart'] = 3 + packet['hopLimit'] = 2 + # Ensure we know the SNR for the test + packet['rxSnr'] = 5.5 + + self.command.handle_packet(packet) + + # Check starting message sent to sender + expected_msg = f"{self.test_nodes[1].user.long_name} you are 1 hops away (Signal: 5.5 dB). Starting full traceroute..." + self.mock_interface.sendText.assert_any_call(expected_msg, destinationId=sender_id, wantAck=True) + + # Check sendTraceRoute called for sender + self.bot.interface.sendTraceRoute.assert_called_once_with(sender_id, hopLimit=7) + + # Check pending_traces entry + self.assertEqual(self.bot.pending_traces[sender_id], [sender_id]) + + def test_handle_packet_zero_hops(self): + sender_id = self.test_nodes[1].user.id + packet = build_test_text_packet('!tr', sender_id, self.bot.my_id) + packet['hopStart'] = 3 + packet['hopLimit'] = 3 + + self.command.handle_packet(packet) + + # Check zero hops message + expected_msg = f"{self.test_nodes[1].user.long_name} you are Zero Hops from me. No traceroute required!" + self.mock_interface.sendText.assert_any_call(expected_msg, destinationId=sender_id, wantAck=True) + self.bot.interface.sendTraceRoute.assert_not_called() + + def test_handle_packet_to_specific_node(self): + # Requester is node 1, Target is node 2 + requester_id = self.test_nodes[1].user.id + target_node = self.test_nodes[2] + target_short = target_node.user.short_name + + packet = build_test_text_packet(f'!tr {target_short}', requester_id, self.bot.my_id) + + self.command.handle_packet(packet) + + expected_msg = f"Starting traceroute to {target_node.user.long_name} ({target_node.user.id}) for you..." + self.mock_interface.sendText.assert_any_call(expected_msg, destinationId=requester_id, wantAck=True) + + self.bot.interface.sendTraceRoute.assert_called_once_with(target_node.user.id, hopLimit=7) + self.assertEqual(self.bot.pending_traces[target_node.user.id], [requester_id]) + + def test_handle_packet_unknown_shortname(self): + requester_id = self.test_nodes[1].user.id + packet = build_test_text_packet('!tr NONEXIST', requester_id, self.bot.my_id) + + self.command.handle_packet(packet) + + expected_msg = "Could not find node with short name 'NONEXIST'" + self.mock_interface.sendText.assert_any_call(expected_msg, destinationId=requester_id, wantAck=True) + self.bot.interface.sendTraceRoute.assert_not_called() + + def test_handle_packet_to_self(self): + # Bot's ID is typically !00000001 in test setup + requester_id = self.test_nodes[1].user.id + # We need the bot's short name if we want to test by shortname, + # but the command specifically checks against self.bot.my_id. + # Let's find a way to trigger the "I am already here" message. + + # Manually find/set a short name for the bot if needed, or just use words[1] + self.bot.get_node_by_short_name = MagicMock(return_value=MagicMock(id=self.bot.my_id, long_name="Bot")) + + packet = build_test_text_packet('!tr BOT', requester_id, self.bot.my_id) + self.command.handle_packet(packet) + + expected_msg = "I am already here! No traceroute required." + self.mock_interface.sendText.assert_any_call(expected_msg, destinationId=requester_id, wantAck=True) + +if __name__ == '__main__': + unittest.main() diff --git a/test/test_tcp_proxy.py b/test/test_tcp_proxy.py new file mode 100644 index 0000000..d8a65a6 --- /dev/null +++ b/test/test_tcp_proxy.py @@ -0,0 +1,48 @@ +import unittest +import asyncio +from unittest.mock import MagicMock, AsyncMock, patch +from src.tcp_proxy import TcpProxy + +class TestTcpProxy(unittest.TestCase): + def setUp(self): + self.proxy = TcpProxy("127.0.0.1", 4403, "127.0.0.1", 4404) + + def test_status_fields(self): + status = self.proxy.get_status() + self.assertIn("Offline", status) + + self.proxy.running = True + self.proxy.target_writer = MagicMock() + status = self.proxy.get_status() + self.assertEqual(status["state"], "Online") + self.assertEqual(status["clients"], 0) + + def test_remove_client(self): + mock_writer = MagicMock() + mock_writer.get_extra_info.return_value = ("127.0.0.1", 12345) + + self.proxy.clients.add(mock_writer) + self.proxy._remove_client(mock_writer) + + self.assertEqual(len(self.proxy.clients), 0) + mock_writer.close.assert_called_once() + + @patch('asyncio.start_server', new_callable=AsyncMock) + def test_async_run_binds_server(self, mock_start_server): + async def run_test(): + self.proxy.running = True + + # Cancel the watchdog and connection manager immediately to avoid hang + async def stop_soon(): + await asyncio.sleep(0.1) + self.proxy.running = False + + asyncio.create_task(stop_soon()) + await self.proxy._async_run() + + mock_start_server.assert_called_once() + + asyncio.run(run_test()) + +if __name__ == "__main__": + unittest.main()