diff --git a/src/lib/websocket/websocketService.ts b/src/lib/websocket/websocketService.ts index ec1b386..bc48e24 100644 --- a/src/lib/websocket/websocketService.ts +++ b/src/lib/websocket/websocketService.ts @@ -27,6 +27,7 @@ class WebSocketService { private reconnectDelay = 2000; // Start with 2 seconds private connectionState: ConnectionState = ConnectionState.DISCONNECTED; private baseUrl: string = ''; + private authListenersRegistered = false; /** * Initialize the WebSocket service with the base URL @@ -36,6 +37,11 @@ class WebSocketService { // Convert http/https to ws/wss this.baseUrl = baseUrl.replace(/^http/, 'ws'); + if (this.authListenersRegistered) { + return; + } + this.authListenersRegistered = true; + // Listen for auth events to reconnect when token changes eventService.subscribe(AuthEventType.AUTH_TOKEN_REFRESHED, () => { if (this.connectionState === ConnectionState.CONNECTED) { diff --git a/src/providers/WebSocketProvider.test.tsx b/src/providers/WebSocketProvider.test.tsx new file mode 100644 index 0000000..2cc2130 --- /dev/null +++ b/src/providers/WebSocketProvider.test.tsx @@ -0,0 +1,131 @@ +import { describe, it, expect, vi, beforeEach } from 'vitest'; +import { render, act } from '@testing-library/react'; +import { MemoryRouter, Routes, Route, useNavigate } from 'react-router-dom'; +import { WebSocketProvider } from './WebSocketProvider'; +import { eventService } from '@/lib/events/eventService'; +import { WebSocketEventType } from '@/lib/websocket/websocketService'; +import type { TextMessage } from '@/lib/models'; + +const mockConfig = { apis: { meshBot: { baseUrl: 'http://127.0.0.1:8000' } } }; + +const { connect, disconnect, initialize } = vi.hoisted(() => ({ + connect: vi.fn(), + disconnect: vi.fn(), + initialize: vi.fn(), +})); + +vi.mock('@/providers/ConfigProvider', () => ({ + useConfig: () => mockConfig, +})); + +vi.mock('@/lib/websocket/websocketService', () => ({ + websocketService: { + initialize, + connect, + disconnect, + }, + WebSocketEventType: { + CONNECTED: 'websocket:connected', + DISCONNECTED: 'websocket:disconnected', + MESSAGE_RECEIVED: 'websocket:message_received', + ERROR: 'websocket:error', + }, + ConnectionState: { + CONNECTING: 'connecting', + CONNECTED: 'connected', + DISCONNECTED: 'disconnected', + ERROR: 'error', + }, +})); + +const { toastMock } = vi.hoisted(() => ({ + toastMock: vi.fn(), +})); + +vi.mock('@/hooks/use-toast', () => ({ + toast: (...args: unknown[]) => toastMock(...args), + useToast: () => ({ toast: toastMock }), +})); + +function NavigationHarness({ onNavigate }: { onNavigate: (navigate: ReturnType) => void }) { + const navigate = useNavigate(); + onNavigate(navigate); + return null; +} + +function renderWithRoutes(initialPath: string) { + let navigateFn: ReturnType | null = null; + + const utils = render( + + + + (navigateFn = n)} />} /> + (navigateFn = n)} />} /> + (navigateFn = n)} />} /> + (navigateFn = n)} />} /> + + + + ); + + return { + ...utils, + navigate: (path: string) => { + if (!navigateFn) throw new Error('navigate not ready'); + act(() => navigateFn!(path)); + }, + }; +} + +const sampleMessage = { + id: 1, + message_text: 'hello', + protocol: 'meshtastic', + channel: 1, + sender: { node_id_str: '!aabbccdd', short_name: 'AB' }, +} as unknown as TextMessage; + +describe('WebSocketProvider', () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + it('connects once on mount and does not disconnect on route changes', () => { + const { navigate } = renderWithRoutes('/'); + + expect(initialize).toHaveBeenCalledTimes(1); + expect(initialize).toHaveBeenCalledWith('http://127.0.0.1:8000'); + expect(connect).toHaveBeenCalledTimes(1); + expect(disconnect).not.toHaveBeenCalled(); + + navigate('/nodes'); + navigate('/messages'); + navigate('/meshcore/messages'); + navigate('/'); + + expect(connect).toHaveBeenCalledTimes(1); + expect(disconnect).not.toHaveBeenCalled(); + expect(initialize).toHaveBeenCalledTimes(1); + }); + + it('shows toast for messages when not on the matching messages page', () => { + renderWithRoutes('/'); + + act(() => { + eventService.emit(WebSocketEventType.MESSAGE_RECEIVED, sampleMessage); + }); + + expect(toastMock).toHaveBeenCalled(); + }); + + it('suppresses toast when on the matching messages page', () => { + renderWithRoutes('/messages'); + + act(() => { + eventService.emit(WebSocketEventType.MESSAGE_RECEIVED, sampleMessage); + }); + + expect(toastMock).not.toHaveBeenCalled(); + }); +}); diff --git a/src/providers/WebSocketProvider.tsx b/src/providers/WebSocketProvider.tsx index 01aa644..5a9ebbe 100644 --- a/src/providers/WebSocketProvider.tsx +++ b/src/providers/WebSocketProvider.tsx @@ -1,7 +1,7 @@ -import React, { createContext, useContext, useEffect, useState, useCallback, useMemo } from 'react'; +import React, { createContext, useContext, useEffect, useRef, useState, useCallback, useMemo } from 'react'; import { useConfig } from './ConfigProvider'; import { useLocation } from 'react-router-dom'; -import { useToast } from '@/hooks/use-toast'; +import { toast } from '@/hooks/use-toast'; import { websocketService, WebSocketEventType, ConnectionState } from '@/lib/websocket/websocketService'; import { TextMessage } from '@/lib/models'; import { eventService } from '@/lib/events/eventService'; @@ -36,11 +36,15 @@ export function useWebSocket() { export function WebSocketProvider({ children }: { children: React.ReactNode }) { const config = useConfig(); const location = useLocation(); - const { toast } = useToast(); + const pathnameRef = useRef(location.pathname); const [connectionState, setConnectionState] = useState(ConnectionState.DISCONNECTED); const [unreadMessages, setUnreadMessages] = useState([]); + useEffect(() => { + pathnameRef.current = location.pathname; + }, [location.pathname]); + const markAsReadForProtocol = useCallback((protocol: MessageProtocolSlug) => { setUnreadMessages((prev) => prev.filter((m) => messageProtocol(m) !== protocol)); }, []); @@ -77,7 +81,7 @@ export function WebSocketProvider({ children }: { children: React.ReactNode }) { const messageHandler = (message: TextMessage) => { const proto = messageProtocol(message); - if (isOnMessagesPage(location.pathname, proto)) { + if (isOnMessagesPage(pathnameRef.current, proto)) { return; } @@ -104,7 +108,7 @@ export function WebSocketProvider({ children }: { children: React.ReactNode }) { eventService.unsubscribe(WebSocketEventType.MESSAGE_RECEIVED, messageHandler); websocketService.disconnect(); }; - }, [config.apis.meshBot.baseUrl, toast, location.pathname]); + }, [config.apis.meshBot.baseUrl]); useEffect(() => { if (isOnMessagesPage(location.pathname, 'meshtastic')) {