From b7908eba38c07b56d64f8c18d2961a1d7fa477af Mon Sep 17 00:00:00 2001 From: Shubh Date: Fri, 23 Jan 2026 14:15:16 +0530 Subject: [PATCH 1/2] enhance HyperStackClient --- python/hyperstack-sdk/examples/basic_usage.py | 41 -------- .../hyperstack-sdk/examples/custom_parser.py | 14 ++- python/hyperstack-sdk/examples/pumpfun.py | 77 +++++++++++++++ python/hyperstack-sdk/hyperstack/__init__.py | 17 +++- python/hyperstack-sdk/hyperstack/client.py | 97 +++++++++++++------ .../{websocket.py => connection.py} | 18 +++- python/hyperstack-sdk/hyperstack/store.py | 71 ++++++++++++++ python/hyperstack-sdk/hyperstack/types.py | 43 +++++++- python/hyperstack-sdk/hyperstack/views.py | 61 ++++++++++++ 9 files changed, 365 insertions(+), 74 deletions(-) delete mode 100644 python/hyperstack-sdk/examples/basic_usage.py create mode 100644 python/hyperstack-sdk/examples/pumpfun.py rename python/hyperstack-sdk/hyperstack/{websocket.py => connection.py} (95%) create mode 100644 python/hyperstack-sdk/hyperstack/views.py diff --git a/python/hyperstack-sdk/examples/basic_usage.py b/python/hyperstack-sdk/examples/basic_usage.py deleted file mode 100644 index 8d26b9a5..00000000 --- a/python/hyperstack-sdk/examples/basic_usage.py +++ /dev/null @@ -1,41 +0,0 @@ -import asyncio -from hyperstack import HyperStackClient - - -async def basic_subscribe(): - view = "SettlementGame/list" - - async with HyperStackClient("wss://flip.stack.hypertek.app") as client: - store = client.subscribe(view=view) - - print(f"Subscribed to {view}, waiting for updates...\n") - - async for update in store: - print(f"Update received for key '{update.key}':") - print(f" Data: {update.data}\n") - - -async def multiple_subscriptions(): - async with HyperStackClient("wss://flip.stack.hypertek.app") as client: - games_store = client.subscribe("SettlementGame/list") - games_store_state = client.subscribe("SettlementGame/state") - - print("Subscribed to multiple views\n") - - async def handle_games(): - async for update in games_store: - print(f"[GAME LIST] {update.key} updated") - - async def handle_games_state(): - async for update in games_store_state: - print(f"[GAME STATE] {update.key} updated") - - await asyncio.gather(handle_games(), handle_games_state()) - - -if __name__ == "__main__": - try: - asyncio.run(basic_subscribe()) - # asyncio.run(multiple_subscriptions()) - except KeyboardInterrupt: - print("Keyboard interrupt received. Exiting gracefully.") diff --git a/python/hyperstack-sdk/examples/custom_parser.py b/python/hyperstack-sdk/examples/custom_parser.py index 90802f6f..b614d05a 100644 --- a/python/hyperstack-sdk/examples/custom_parser.py +++ b/python/hyperstack-sdk/examples/custom_parser.py @@ -38,6 +38,18 @@ class Game: events: Optional[Dict[str, Any]] = None +class SettlementGame: + NAME = "SettlementGame" + + @staticmethod + def state_view() -> str: + return "SettlementGame/state" + + @staticmethod + def list_view() -> str: + return "SettlementGame/list" + + def parse_game(data: Dict[str, Any]) -> Game: id_data = data.get("id", {}) status_data = data.get("status", {}) @@ -70,7 +82,7 @@ def parse_game(data: Dict[str, Any]) -> Game: async def main(): async with HyperStackClient("ws://localhost:8080") as client: - game_store = client.subscribe(view="SettlementGame/list", parser=parse_game) + game_store = client.watch(SettlementGame, parser=parse_game) print(f"connected, watching {game_store.view}\n") diff --git a/python/hyperstack-sdk/examples/pumpfun.py b/python/hyperstack-sdk/examples/pumpfun.py new file mode 100644 index 00000000..c2fb08f3 --- /dev/null +++ b/python/hyperstack-sdk/examples/pumpfun.py @@ -0,0 +1,77 @@ +import asyncio +from dataclasses import dataclass +from typing import Any, Dict, Optional + +from hyperstack import HyperStackClient + + +@dataclass +class PumpfunTokenData: + mint: str + name: str + symbol: str + creator: Optional[str] + timestamp: Optional[int] + + +def parse_token(payload: Dict[str, Any]) -> Optional[PumpfunTokenData]: + info = payload.get("info") if isinstance(payload.get("info"), dict) else {} + token_id = payload.get("id") if isinstance(payload.get("id"), dict) else {} + events = payload.get("events") if isinstance(payload.get("events"), dict) else {} + + name = info.get("name") + symbol = info.get("symbol") + mint = token_id.get("mint") + + creator = None + timestamp = None + create_event = events.get("create") + if isinstance(create_event, dict): + name = name or create_event.get("name") + symbol = symbol or create_event.get("symbol") + mint = mint or create_event.get("mint") + creator = create_event.get("creator") + timestamp = create_event.get("timestamp") + + if not mint or not name or not symbol: + return None + + return PumpfunTokenData( + mint=mint, + name=name, + symbol=symbol, + creator=creator, + timestamp=timestamp, + ) + + +async def main() -> None: + print("Connecting to Solana via Hyperstack...\n") + async with HyperStackClient( + "wss://pumpfun-token-rfx6zp.stack.usehyperstack.com" + ) as client: + print("Connected! Streaming live pump.fun tokens:\n") + async for update in client.subscribe("PumpfunToken/list"): + if not isinstance(update.data, dict): + continue + token = parse_token(update.data) + if not token: + continue + print(f"New token: {token.name} ({token.symbol})") + print(f" Mint: {token.mint}") + if token.creator: + creator_short = ( + f"{token.creator[:8]}..." + if len(token.creator) > 8 + else token.creator + ) + print(f" Creator: {creator_short}") + print("") + + +if __name__ == "__main__": + try: + asyncio.run(main()) + except KeyboardInterrupt: + print("Keyboard interrupt received. Exiting gracefully.") + diff --git a/python/hyperstack-sdk/hyperstack/__init__.py b/python/hyperstack-sdk/hyperstack/__init__.py index a99a0272..7b96d749 100644 --- a/python/hyperstack-sdk/hyperstack/__init__.py +++ b/python/hyperstack-sdk/hyperstack/__init__.py @@ -1,14 +1,21 @@ """HyperStack Python SDK - Real-time data synchronization with authentication support.""" from hyperstack.client import HyperStackClient -from hyperstack.store import Store, Update +from hyperstack.store import Store, Update, SharedStore from hyperstack.types import ( + Entity, + StackDefinition, + ViewDef, + ViewGroup, + state_view, + list_view, SortOrder, SortConfig, SubscribedFrame, try_parse_subscribed_frame, ConnectionState, ) +from hyperstack.views import TypedViews from hyperstack.auth import ( AuthConfig, AuthToken, @@ -32,12 +39,20 @@ "HyperStackClient", "Store", "Update", + "SharedStore", # Types + "Entity", + "StackDefinition", + "ViewDef", + "ViewGroup", + "state_view", + "list_view", "SortOrder", "SortConfig", "SubscribedFrame", "try_parse_subscribed_frame", "ConnectionState", + "TypedViews", # Auth "AuthConfig", "AuthToken", diff --git a/python/hyperstack-sdk/hyperstack/client.py b/python/hyperstack-sdk/hyperstack/client.py index a5d2b580..845ca00c 100644 --- a/python/hyperstack-sdk/hyperstack/client.py +++ b/python/hyperstack-sdk/hyperstack/client.py @@ -5,9 +5,17 @@ import logging from typing import Dict, List, Optional, Callable -from hyperstack.websocket import WebSocketManager -from hyperstack.store import Store, Mode -from hyperstack.types import Subscription, Unsubscription, Frame +from hyperstack.connection import ConnectionManager +from hyperstack.store import Store, Mode, SharedStore +from hyperstack.types import ( + Subscription, + Unsubscription, + Frame, + Entity, + StackDefinition, + ConnectionState, +) +from hyperstack.views import create_typed_views, TypedViews from hyperstack.auth import AuthConfig logger = logging.getLogger(__name__) @@ -76,11 +84,11 @@ def __init__( auth: Optional authentication configuration. Required for hosted Hyperstack URLs. """ self.url = url - self._stores: Dict[str, Store] = {} + self._store = SharedStore() self._pending_subs: List[Subscription] = [] + self._active_subs: Dict[str, Subscription] = {} self._user_on_connect = on_connect - - self.ws_manager = WebSocketManager( + self.ws_manager = ConnectionManager( url=url, reconnect_intervals=reconnect_intervals, ping_interval=ping_interval, @@ -125,23 +133,58 @@ def subscribe( raise ValueError(f"Invalid view '{view}'. Expected: Entity/mode") mode = parse_mode(view) - store = Store(mode=mode, parser=parser, view=view) - - store_key = f"{view}:{key or '*'}" - self._stores[store_key] = store + store = self._store.get_store(view, mode=mode, parser=parser) sub = Subscription(view=view, key=key) - if self.ws_manager.is_running: - asyncio.create_task(self._send_sub(sub)) - else: - self._pending_subs.append(sub) + sub_key = sub.sub_key() + if sub_key not in self._active_subs: + self._active_subs[sub_key] = sub + if self.ws_manager.is_running: + asyncio.create_task(self._send_sub(sub)) + else: + self._pending_subs.append(sub) return store + async def get( + self, + entity: Entity, + key: str, + parser: Optional[Callable] = None, + timeout: Optional[float] = None, + ) -> Optional[Dict]: + view = entity.state_view() + self.subscribe(view, key=key, parser=parser) + await self._store.wait_for_view_ready(view, timeout=timeout) + return await self._store.get(entity.state_view(), key) + + async def list( + self, + entity: Entity, + parser: Optional[Callable] = None, + timeout: Optional[float] = None, + ) -> List: + view = entity.list_view() + self.subscribe(view, parser=parser) + await self._store.wait_for_view_ready(view, timeout=timeout) + return await self._store.list(entity.list_view()) + + def watch(self, entity: Entity, parser: Optional[Callable] = None) -> Store: + return self.subscribe(entity.list_view(), parser=parser) + + def watch_key( + self, entity: Entity, key: str, parser: Optional[Callable] = None + ) -> Store: + return self.subscribe(entity.list_view(), key=key, parser=parser) + + def views(self, stack: StackDefinition) -> TypedViews: + return create_typed_views(stack, self) + async def _on_connect(self) -> None: """Send queued subscriptions on connect.""" - while self._pending_subs: - await self._send_sub(self._pending_subs.pop(0)) + for sub in self._active_subs.values(): + await self._send_sub(sub) + self._pending_subs.clear() if self._user_on_connect: await self._user_on_connect() @@ -159,8 +202,10 @@ async def _send_sub(self, sub: Subscription) -> None: async def unsubscribe(self, view: str, key: Optional[str] = None) -> None: """Unsubscribe from a view.""" - store_key = f"{view}:{key or '*'}" - self._stores.pop(store_key, None) + sub = Subscription(view=view, key=key) + sub_key = sub.sub_key() + self._active_subs.pop(sub_key, None) + self._pending_subs = [s for s in self._pending_subs if s.sub_key() != sub_key] if not self.ws_manager.ws or not self.ws_manager.is_running: return @@ -191,15 +236,13 @@ async def _on_message(self, message) -> None: logger.debug( f"Frame: entity={frame.entity}, op={frame.op}, key={frame.key}" ) - - view = frame.entity - store_keys = [f"{view}:{frame.key}", f"{view}:*"] - - for store_key in store_keys: - store = self._stores.get(store_key) - if store: - logger.debug(f"Routing to: {store_key}") - await store.handle_frame(frame) + await self._store.apply_frame(frame) except Exception as e: logger.error(f"Message error: {e}", exc_info=True) + + def store(self) -> SharedStore: + return self._store + + def connection_state(self) -> ConnectionState: + return self.ws_manager.state() diff --git a/python/hyperstack-sdk/hyperstack/websocket.py b/python/hyperstack-sdk/hyperstack/connection.py similarity index 95% rename from python/hyperstack-sdk/hyperstack/websocket.py rename to python/hyperstack-sdk/hyperstack/connection.py index 7f7a05ba..7b017d11 100644 --- a/python/hyperstack-sdk/hyperstack/websocket.py +++ b/python/hyperstack-sdk/hyperstack/connection.py @@ -20,11 +20,12 @@ should_retry_error, DEFAULT_QUERY_PARAMETER, ) +from hyperstack.types import ConnectionState logger = logging.getLogger(__name__) -class WebSocketManager: +class ConnectionManager: def __init__( self, url: str, @@ -68,6 +69,7 @@ def __init__( self.ping_task: Optional[asyncio.Task] = None self.refresh_task: Optional[asyncio.Task] = None self.message_handler: Optional[Callable] = None + self._state: ConnectionState = ConnectionState.DISCONNECTED # Track if we're reconnecting for token refresh self._force_token_refresh = False @@ -228,6 +230,7 @@ async def connect(self) -> None: logger.info("Already connected") return + self._state = ConnectionState.CONNECTING attempt = 0 while attempt < len(self.reconnect_intervals): try: @@ -248,6 +251,7 @@ async def connect(self) -> None: self.is_running = True self.reconnect_attempts = 0 self._immediate_reconnect = False + self._state = ConnectionState.CONNECTED logger.info("Connected") self.receive_task = asyncio.create_task(self.receive_messages()) @@ -264,7 +268,9 @@ async def connect(self) -> None: raise except Exception as e: attempt += 1 + self._state = ConnectionState.RECONNECTING if attempt >= len(self.reconnect_intervals): + self._state = ConnectionState.ERROR raise ConnectionError( f"Connection failed after {attempt} attempts: {e}" ) @@ -278,6 +284,7 @@ async def connect(self) -> None: async def disconnect(self) -> None: """Close WebSocket connection and cleanup resources.""" self.is_running = False + self._state = ConnectionState.DISCONNECTED self._stop_ping() self._stop_token_refresh() @@ -400,7 +407,7 @@ async def receive_messages(self) -> None: self.auth.clear_token() self._force_token_refresh = True self._immediate_reconnect = True - + self._state = ConnectionState.ERROR if self.on_error: await self.on_error(e) @@ -411,6 +418,7 @@ async def receive_messages(self) -> None: except Exception as e: logger.error(f"Receive error: {e}") + self._state = ConnectionState.ERROR if self.on_error: await self.on_error(e) @@ -420,6 +428,7 @@ async def handle_reconnect(self) -> None: if self.reconnect_attempts > len(self.reconnect_intervals): logger.error("Max reconnect attempts reached") + self._state = ConnectionState.ERROR return delay = ( @@ -433,5 +442,8 @@ async def handle_reconnect(self) -> None: await asyncio.sleep(delay) else: logger.info(f"Reconnecting immediately (attempt {self.reconnect_attempts})") - + self._state = ConnectionState.RECONNECTING await self.connect() + + def state(self) -> ConnectionState: + return self._state diff --git a/python/hyperstack-sdk/hyperstack/store.py b/python/hyperstack-sdk/hyperstack/store.py index d8300555..1fba1d58 100644 --- a/python/hyperstack-sdk/hyperstack/store.py +++ b/python/hyperstack-sdk/hyperstack/store.py @@ -87,6 +87,7 @@ def __init__( self._lock = asyncio.Lock() self._callbacks: List[Callable[[Update[T]], None]] = [] self._update_queue: asyncio.Queue[Update[T]] = asyncio.Queue() + self._ready_event = asyncio.Event() if mode in (Mode.LIST, Mode.STATE): self._data: Union[OrderedDict[str, T], List[T]] = OrderedDict() @@ -298,6 +299,7 @@ async def _notify_update(self, key: str, data: T) -> None: update = Update(key=key, data=data) await self._update_queue.put(update) + self._ready_event.set() # Call all registered callbacks for callback in self._callbacks: @@ -305,3 +307,72 @@ async def _notify_update(self, key: str, data: T) -> None: callback(update) except Exception as e: logger.error(f"Callback error: {e}") + + async def wait_ready(self, timeout: Optional[float] = None) -> bool: + if self._ready_event.is_set(): + return True + try: + if timeout is None: + await self._ready_event.wait() + else: + await asyncio.wait_for(self._ready_event.wait(), timeout=timeout) + return True + except asyncio.TimeoutError: + return False + + +class SharedStore: + def __init__(self, max_entries_per_view: Optional[int] = DEFAULT_MAX_ENTRIES_PER_VIEW): + self._stores: Dict[str, Store[Any]] = {} + self._max_entries_per_view = max_entries_per_view + + def get_store( + self, + view: str, + mode: Mode = Mode.LIST, + parser: Optional[Callable[[Dict[str, Any]], Any]] = None, + ) -> Store[Any]: + if view in self._stores: + return self._stores[view] + store = Store( + mode=mode, parser=parser, view=view, max_entries=self._max_entries_per_view + ) + self._stores[view] = store + return store + + async def apply_frame(self, frame) -> None: + mode = Mode(frame.mode) if frame.mode in Mode._value2member_map_ else Mode.LIST + view = self._frame_view(frame) + store = self.get_store(view, mode=mode) + await store.handle_frame(frame) + + @staticmethod + def _frame_view(frame) -> str: + entity = getattr(frame, "entity", "") + mode = getattr(frame, "mode", "") + if "/" in entity: + return entity + if mode: + return f"{entity}/{mode}" + return entity + + async def wait_for_view_ready(self, view: str, timeout: Optional[float] = None) -> bool: + store = self._stores.get(view) + if not store: + return False + return await store.wait_ready(timeout) + + async def get(self, view: str, key: Optional[str] = None) -> Optional[Any]: + store = self._stores.get(view) + if not store: + return None + return await store.get_async(key) + + async def list(self, view: str) -> List[Any]: + store = self._stores.get(view) + if not store: + return [] + async with store._lock: + if isinstance(store._data, OrderedDict): + return list(store._data.values()) + return list(store._data) diff --git a/python/hyperstack-sdk/hyperstack/types.py b/python/hyperstack-sdk/hyperstack/types.py index 76cf5a43..7bfa0021 100644 --- a/python/hyperstack-sdk/hyperstack/types.py +++ b/python/hyperstack-sdk/hyperstack/types.py @@ -1,7 +1,7 @@ import gzip import json from enum import Enum -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Union, Protocol, TypeVar from dataclasses import dataclass, field @@ -87,6 +87,47 @@ class ConnectionState(str, Enum): RECONNECTING = "reconnecting" +T = TypeVar("T") + + +class Entity(Protocol[T]): + NAME: str + + @staticmethod + def state_view() -> str: + ... + + @staticmethod + def list_view() -> str: + ... + + +@dataclass(frozen=True) +class ViewDef: + mode: str + view: str + + +@dataclass(frozen=True) +class ViewGroup: + state: Optional[ViewDef] = None + list: Optional[ViewDef] = None + + +@dataclass(frozen=True) +class StackDefinition: + name: str + views: Dict[str, ViewGroup] + + +def state_view(view: str) -> ViewDef: + return ViewDef(mode="state", view=view) + + +def list_view(view: str) -> ViewDef: + return ViewDef(mode="list", view=view) + + def is_gzip(data: bytes) -> bool: return len(data) >= 2 and data[:2] == GZIP_MAGIC diff --git a/python/hyperstack-sdk/hyperstack/views.py b/python/hyperstack-sdk/hyperstack/views.py new file mode 100644 index 00000000..6d420ed4 --- /dev/null +++ b/python/hyperstack-sdk/hyperstack/views.py @@ -0,0 +1,61 @@ +from dataclasses import dataclass +from typing import Any, Dict, Optional + +from hyperstack.types import ViewDef, StackDefinition + + +@dataclass +class TypedStateView: + view_def: ViewDef + client: Any + + def get(self, key: str, parser=None): + return self.client.subscribe(self.view_def.view, key=key, parser=parser) + + def watch(self, key: str, parser=None): + return self.client.subscribe(self.view_def.view, key=key, parser=parser) + + +@dataclass +class TypedListView: + view_def: ViewDef + client: Any + + def get(self, parser=None): + return self.client.subscribe(self.view_def.view, parser=parser) + + def watch(self, parser=None): + return self.client.subscribe(self.view_def.view, parser=parser) + + +@dataclass +class TypedViewGroup: + state: Optional[TypedStateView] = None + list: Optional[TypedListView] = None + + +class TypedViews: + def __init__(self, groups: Dict[str, TypedViewGroup]): + self._groups = groups + + def __getattr__(self, name: str) -> TypedViewGroup: + if name in self._groups: + return self._groups[name] + raise AttributeError(name) + + def __getitem__(self, name: str) -> TypedViewGroup: + return self._groups[name] + + +def create_typed_views(stack: StackDefinition, client: Any) -> TypedViews: + groups: Dict[str, TypedViewGroup] = {} + + for name, group in stack.views.items(): + typed_group = TypedViewGroup() + if group.state: + typed_group.state = TypedStateView(group.state, client) + if group.list: + typed_group.list = TypedListView(group.list, client) + groups[name] = typed_group + + return TypedViews(groups) From 11ec0edd399db3cf1da41f09b482cc217c0c3646 Mon Sep 17 00:00:00 2001 From: shubh Date: Sun, 5 Apr 2026 13:16:44 +0530 Subject: [PATCH 2/2] update python sdk --- python/hyperstack-sdk/hyperstack/client.py | 24 +++++++++---------- .../hyperstack-sdk/hyperstack/connection.py | 21 +++++++++------- 2 files changed, 24 insertions(+), 21 deletions(-) diff --git a/python/hyperstack-sdk/hyperstack/client.py b/python/hyperstack-sdk/hyperstack/client.py index 845ca00c..3f660765 100644 --- a/python/hyperstack-sdk/hyperstack/client.py +++ b/python/hyperstack-sdk/hyperstack/client.py @@ -3,7 +3,7 @@ import asyncio import json import logging -from typing import Dict, List, Optional, Callable +from typing import Any, Awaitable, Callable, Dict, List, Optional from hyperstack.connection import ConnectionManager from hyperstack.store import Store, Mode, SharedStore @@ -63,10 +63,10 @@ def __init__( url: str, reconnect_intervals: Optional[List[int]] = None, ping_interval: int = 15, - on_connect: Optional[Callable] = None, - on_disconnect: Optional[Callable] = None, - on_error: Optional[Callable] = None, - on_socket_issue: Optional[Callable[[dict], None]] = None, + on_connect: Optional[Callable[[], Awaitable[None]]] = None, + on_disconnect: Optional[Callable[[], Awaitable[None]]] = None, + on_error: Optional[Callable[[Exception], Awaitable[None]]] = None, + on_socket_issue: Optional[Callable[[dict], Awaitable[None]]] = None, auth: Optional[AuthConfig] = None, ): """ @@ -108,15 +108,15 @@ async def disconnect(self) -> None: """Disconnect from server.""" await self.ws_manager.disconnect() - async def __aenter__(self): + async def __aenter__(self) -> "HyperStackClient": await self.connect() return self - async def __aexit__(self, exc_type, exc_val, exc_tb): + async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: await self.disconnect() def subscribe( - self, view: str, key: Optional[str] = None, parser: Optional[Callable] = None + self, view: str, key: Optional[str] = None, parser: Optional[Callable[[Dict[str, Any]], Any]] = None ) -> Store: """ Subscribe to updates for the specified view (and optional key) on the HyperStack server. @@ -150,7 +150,7 @@ async def get( self, entity: Entity, key: str, - parser: Optional[Callable] = None, + parser: Optional[Callable[[Dict[str, Any]], Any]] = None, timeout: Optional[float] = None, ) -> Optional[Dict]: view = entity.state_view() @@ -161,7 +161,7 @@ async def get( async def list( self, entity: Entity, - parser: Optional[Callable] = None, + parser: Optional[Callable[[Dict[str, Any]], Any]] = None, timeout: Optional[float] = None, ) -> List: view = entity.list_view() @@ -169,11 +169,11 @@ async def list( await self._store.wait_for_view_ready(view, timeout=timeout) return await self._store.list(entity.list_view()) - def watch(self, entity: Entity, parser: Optional[Callable] = None) -> Store: + def watch(self, entity: Entity, parser: Optional[Callable[[Dict[str, Any]], Any]] = None) -> Store: return self.subscribe(entity.list_view(), parser=parser) def watch_key( - self, entity: Entity, key: str, parser: Optional[Callable] = None + self, entity: Entity, key: str, parser: Optional[Callable[[Dict[str, Any]], Any]] = None ) -> Store: return self.subscribe(entity.list_view(), key=key, parser=parser) diff --git a/python/hyperstack-sdk/hyperstack/connection.py b/python/hyperstack-sdk/hyperstack/connection.py index 7b017d11..d1425525 100644 --- a/python/hyperstack-sdk/hyperstack/connection.py +++ b/python/hyperstack-sdk/hyperstack/connection.py @@ -3,9 +3,9 @@ import asyncio import json import logging -from typing import Optional, Callable, List, Any +from typing import Any, Awaitable, Callable, List, Optional, Union from websockets import connect as ws_connect -from websockets.client import WebSocketClientProtocol +from websockets.asyncio.client import ClientConnection from websockets.exceptions import WebSocketException from hyperstack.errors import ConnectionError, AuthError @@ -31,10 +31,10 @@ def __init__( url: str, reconnect_intervals: List[int] = None, ping_interval: int = 15, - on_connect: Optional[Callable] = None, - on_disconnect: Optional[Callable] = None, - on_error: Optional[Callable] = None, - on_socket_issue: Optional[Callable[[dict], Any]] = None, + on_connect: Optional[Callable[[], Awaitable[None]]] = None, + on_disconnect: Optional[Callable[[], Awaitable[None]]] = None, + on_error: Optional[Callable[[Exception], Awaitable[None]]] = None, + on_socket_issue: Optional[Callable[[dict], Awaitable[Any]]] = None, auth: Optional[AuthConfig] = None, ): """ @@ -62,20 +62,23 @@ def __init__( self.auth = AuthState(url, auth) self._auth_config = auth - self.ws: Optional[WebSocketClientProtocol] = None + self.ws: Optional[ClientConnection] = None self.is_running = False self.reconnect_attempts = 0 self.receive_task: Optional[asyncio.Task] = None self.ping_task: Optional[asyncio.Task] = None self.refresh_task: Optional[asyncio.Task] = None - self.message_handler: Optional[Callable] = None + self.message_handler: Optional[Callable[[Union[bytes, str]], Awaitable[None]]] = None self._state: ConnectionState = ConnectionState.DISCONNECTED + def set_message_handler(self, handler: Callable[[Union[bytes, str]], Awaitable[None]]) -> None: # Track if we're reconnecting for token refresh self._force_token_refresh = False self._immediate_reconnect = False - def set_message_handler(self, handler: Callable) -> None: + def set_message_handler( + self, handler: Callable[[Union[bytes, str]], Awaitable[None]] + ) -> None: """ Set the callback function for handling incoming WebSocket messages.