diff --git a/src/main.py b/src/main.py index e52743b..7667994 100644 --- a/src/main.py +++ b/src/main.py @@ -28,7 +28,6 @@ logging.getLogger("mesh_interface").setLevel(logging.WARNING) from src.api.packet_serializer import PacketSerializer - # Now we can import the rest of our local files from src.api.StorageAPI import StorageAPIWrapper from src.bot import MeshflowBot @@ -221,6 +220,9 @@ def main() -> None: if RADIO_PROTOCOL == "meshcore" else None ), + feeder_node_id_provider=( + (lambda: bot.my_nodenum) if RADIO_PROTOCOL == "meshtastic" else None + ), ) try: diff --git a/src/ws_client.py b/src/ws_client.py index cefea35..08760f0 100644 --- a/src/ws_client.py +++ b/src/ws_client.py @@ -32,6 +32,7 @@ def __init__( on_connect: Optional[Callable[[], None]] = None, on_disconnect: Optional[Callable[[], None]] = None, feeder_pubkey_prefix_provider: Optional[Callable[[], Optional[str]]] = None, + feeder_node_id_provider: Optional[Callable[[], Optional[int]]] = None, ): """ Args: @@ -49,6 +50,7 @@ def __init__( self.on_connect = on_connect self.on_disconnect = on_disconnect self._feeder_pubkey_prefix_provider = feeder_pubkey_prefix_provider + self._feeder_node_id_provider = feeder_node_id_provider self._running = False self._task: Optional[asyncio.Task] = None @@ -61,6 +63,10 @@ def _get_ws_endpoint(self) -> str: prefix = self._feeder_pubkey_prefix_provider() if prefix: url += f"&feeder_pubkey_prefix={quote(prefix, safe='')}" + if self._feeder_node_id_provider: + node_id = self._feeder_node_id_provider() + if node_id is not None: + url += f"&feeder_node_id={quote(str(node_id), safe='')}" return url def start(self): @@ -148,9 +154,15 @@ async def _connect_and_receive(self): if self._feeder_pubkey_prefix_provider else None ) + node_id = ( + self._feeder_node_id_provider() + if self._feeder_node_id_provider + else None + ) logger.info( - "MeshflowWSClient: connected (feeder_pubkey_prefix=%s)", + "MeshflowWSClient: connected (feeder_pubkey_prefix=%s, feeder_node_id=%s)", prefix or "none", + node_id if node_id is not None else "none", ) if self.on_connect: try: @@ -237,9 +249,7 @@ def _apply_done(t): ) elif cmd_type == "refresh_feeder_config": if self.on_refresh_feeder_config: - logger.info( - "MeshflowWSClient: received refresh_feeder_config" - ) + logger.info("MeshflowWSClient: received refresh_feeder_config") task = asyncio.create_task( asyncio.to_thread(self.on_refresh_feeder_config) ) diff --git a/test/ws/__init__.py b/test/ws/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/ws/test_ws_client.py b/test/ws/test_ws_client.py new file mode 100644 index 0000000..063fb1d --- /dev/null +++ b/test/ws/test_ws_client.py @@ -0,0 +1,38 @@ +"""Tests for MeshflowWSClient WebSocket URL construction.""" + +from src.ws_client import MeshflowWSClient + + +def test_ws_endpoint_includes_feeder_node_id(): + client = MeshflowWSClient( + ws_url="ws://localhost:8000", + api_key="test-key", + on_traceroute=lambda _target: None, + feeder_node_id_provider=lambda: 1127973616, + ) + endpoint = client._get_ws_endpoint() + assert ( + endpoint + == "ws://localhost:8000/ws/nodes/?api_key=test-key&feeder_node_id=1127973616" + ) + + +def test_ws_endpoint_includes_feeder_pubkey_prefix(): + client = MeshflowWSClient( + ws_url="ws://localhost:8000", + api_key="test-key", + on_traceroute=lambda _target: None, + feeder_pubkey_prefix_provider=lambda: "1a37f5aea4a1", + ) + endpoint = client._get_ws_endpoint() + assert "feeder_pubkey_prefix=1a37f5aea4a1" in endpoint + assert "feeder_node_id=" not in endpoint + + +def test_ws_endpoint_api_key_only_when_no_providers(): + client = MeshflowWSClient( + ws_url="ws://localhost:8000", + api_key="test-key", + on_traceroute=lambda _target: None, + ) + assert client._get_ws_endpoint() == "ws://localhost:8000/ws/nodes/?api_key=test-key"