Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
logging.getLogger("mesh_interface").setLevel(logging.WARNING)

from src.api.packet_serializer import PacketSerializer

# Now we can import the rest of our local files
from src.api.StorageAPI import StorageAPIWrapper
from src.bot import MeshflowBot
Expand Down Expand Up @@ -221,6 +220,9 @@ def main() -> None:
if RADIO_PROTOCOL == "meshcore"
else None
),
feeder_node_id_provider=(
(lambda: bot.my_nodenum) if RADIO_PROTOCOL == "meshtastic" else None
),
)

try:
Expand Down
18 changes: 14 additions & 4 deletions src/ws_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def __init__(
on_connect: Optional[Callable[[], None]] = None,
on_disconnect: Optional[Callable[[], None]] = None,
feeder_pubkey_prefix_provider: Optional[Callable[[], Optional[str]]] = None,
feeder_node_id_provider: Optional[Callable[[], Optional[int]]] = None,
):
"""
Args:
Expand All @@ -49,6 +50,7 @@ def __init__(
self.on_connect = on_connect
self.on_disconnect = on_disconnect
self._feeder_pubkey_prefix_provider = feeder_pubkey_prefix_provider
self._feeder_node_id_provider = feeder_node_id_provider

self._running = False
self._task: Optional[asyncio.Task] = None
Expand All @@ -61,6 +63,10 @@ def _get_ws_endpoint(self) -> str:
prefix = self._feeder_pubkey_prefix_provider()
if prefix:
url += f"&feeder_pubkey_prefix={quote(prefix, safe='')}"
if self._feeder_node_id_provider:
node_id = self._feeder_node_id_provider()
if node_id is not None:
url += f"&feeder_node_id={quote(str(node_id), safe='')}"
return url

def start(self):
Expand Down Expand Up @@ -148,9 +154,15 @@ async def _connect_and_receive(self):
if self._feeder_pubkey_prefix_provider
else None
)
node_id = (
self._feeder_node_id_provider()
if self._feeder_node_id_provider
else None
)
logger.info(
"MeshflowWSClient: connected (feeder_pubkey_prefix=%s)",
"MeshflowWSClient: connected (feeder_pubkey_prefix=%s, feeder_node_id=%s)",
prefix or "none",
node_id if node_id is not None else "none",
)
if self.on_connect:
try:
Expand Down Expand Up @@ -237,9 +249,7 @@ def _apply_done(t):
)
elif cmd_type == "refresh_feeder_config":
if self.on_refresh_feeder_config:
logger.info(
"MeshflowWSClient: received refresh_feeder_config"
)
logger.info("MeshflowWSClient: received refresh_feeder_config")
task = asyncio.create_task(
asyncio.to_thread(self.on_refresh_feeder_config)
)
Expand Down
Empty file added test/ws/__init__.py
Empty file.
38 changes: 38 additions & 0 deletions test/ws/test_ws_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
"""Tests for MeshflowWSClient WebSocket URL construction."""

from src.ws_client import MeshflowWSClient


def test_ws_endpoint_includes_feeder_node_id():
client = MeshflowWSClient(
ws_url="ws://localhost:8000",
api_key="test-key",
on_traceroute=lambda _target: None,
feeder_node_id_provider=lambda: 1127973616,
)
endpoint = client._get_ws_endpoint()
assert (
endpoint
== "ws://localhost:8000/ws/nodes/?api_key=test-key&feeder_node_id=1127973616"
)


def test_ws_endpoint_includes_feeder_pubkey_prefix():
client = MeshflowWSClient(
ws_url="ws://localhost:8000",
api_key="test-key",
on_traceroute=lambda _target: None,
feeder_pubkey_prefix_provider=lambda: "1a37f5aea4a1",
)
endpoint = client._get_ws_endpoint()
assert "feeder_pubkey_prefix=1a37f5aea4a1" in endpoint
assert "feeder_node_id=" not in endpoint


def test_ws_endpoint_api_key_only_when_no_providers():
client = MeshflowWSClient(
ws_url="ws://localhost:8000",
api_key="test-key",
on_traceroute=lambda _target: None,
)
assert client._get_ws_endpoint() == "ws://localhost:8000/ws/nodes/?api_key=test-key"