Skip to content

Commit 6eb0a2a

Browse files
committed
feat: add REST client and WebSocket gateway
1 parent d688cc7 commit 6eb0a2a

3 files changed

Lines changed: 380 additions & 1 deletion

File tree

stackcoin/stackcoin/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
"""StackCoin Python SDK."""
22

3+
from .client import Client
34
from .errors import StackCoinError
5+
from .gateway import Event, Gateway
46

5-
__all__ = ["StackCoinError"]
7+
__all__ = ["Client", "Event", "Gateway", "StackCoinError"]

stackcoin/stackcoin/client.py

Lines changed: 227 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,227 @@
1+
"""Async REST client for the StackCoin API."""
2+
3+
from __future__ import annotations
4+
5+
from typing import Any
6+
7+
import httpx
8+
9+
from .errors import StackCoinError
10+
from .models import (
11+
CreateRequestResponse,
12+
DiscordGuild,
13+
DiscordGuildsResponse,
14+
Request,
15+
RequestActionResponse,
16+
RequestsResponse,
17+
SendStkResponse,
18+
Transaction,
19+
TransactionsResponse,
20+
User,
21+
UsersResponse,
22+
)
23+
24+
25+
class Client:
26+
"""Async client for the StackCoin REST API.
27+
28+
Usage::
29+
30+
async with Client("https://stackcoin.example.com", token="sk-...") as client:
31+
me = await client.get_me()
32+
print(me.username, me.balance)
33+
"""
34+
35+
def __init__(
36+
self,
37+
base_url: str,
38+
token: str,
39+
*,
40+
timeout: float = 10.0,
41+
) -> None:
42+
self._http = httpx.AsyncClient(
43+
base_url=base_url,
44+
headers={
45+
"Authorization": f"Bearer {token}",
46+
"Accept": "application/json",
47+
},
48+
timeout=timeout,
49+
)
50+
51+
# -- context manager -------------------------------------------------- #
52+
53+
async def __aenter__(self) -> Client:
54+
return self
55+
56+
async def __aexit__(
57+
self,
58+
exc_type: type[BaseException] | None,
59+
exc_val: BaseException | None,
60+
exc_tb: Any,
61+
) -> None:
62+
await self.close()
63+
64+
async def close(self) -> None:
65+
"""Close the underlying HTTP connection pool."""
66+
await self._http.aclose()
67+
68+
# -- shared helpers --------------------------------------------------- #
69+
70+
@staticmethod
71+
def _raise_for_error(resp: httpx.Response) -> None:
72+
"""Raise :class:`StackCoinError` on any 4xx/5xx response."""
73+
if resp.status_code >= 400:
74+
try:
75+
body = resp.json()
76+
except Exception:
77+
body = {}
78+
error = body.get("error", f"http_{resp.status_code}")
79+
message = body.get("message")
80+
raise StackCoinError(resp.status_code, error, message)
81+
82+
# -- users ------------------------------------------------------------ #
83+
84+
async def get_me(self) -> User:
85+
"""Return the authenticated user's profile."""
86+
resp = await self._http.get("/api/user/me")
87+
self._raise_for_error(resp)
88+
return User.model_validate(resp.json())
89+
90+
async def get_user(self, user_id: int) -> User:
91+
"""Return a user by their ID."""
92+
resp = await self._http.get(f"/api/user/{user_id}")
93+
self._raise_for_error(resp)
94+
return User.model_validate(resp.json())
95+
96+
async def get_users(self, *, discord_id: str | None = None) -> list[User]:
97+
"""Return a list of users, optionally filtered by Discord ID."""
98+
params: dict[str, Any] = {}
99+
if discord_id is not None:
100+
params["discord_id"] = discord_id
101+
resp = await self._http.get("/api/users", params=params)
102+
self._raise_for_error(resp)
103+
wrapper = UsersResponse.model_validate(resp.json())
104+
return wrapper.users or []
105+
106+
# -- send ------------------------------------------------------------- #
107+
108+
async def send(
109+
self,
110+
to_user_id: int,
111+
amount: int,
112+
*,
113+
label: str | None = None,
114+
idempotency_key: str | None = None,
115+
) -> SendStkResponse:
116+
"""Send STK to another user."""
117+
body: dict[str, Any] = {"amount": amount}
118+
if label is not None:
119+
body["label"] = label
120+
headers: dict[str, str] = {}
121+
if idempotency_key is not None:
122+
headers["Idempotency-Key"] = idempotency_key
123+
resp = await self._http.post(
124+
f"/api/user/{to_user_id}/send",
125+
json=body,
126+
headers=headers,
127+
)
128+
self._raise_for_error(resp)
129+
return SendStkResponse.model_validate(resp.json())
130+
131+
# -- requests --------------------------------------------------------- #
132+
133+
async def create_request(
134+
self,
135+
to_user_id: int,
136+
amount: int,
137+
*,
138+
label: str | None = None,
139+
idempotency_key: str | None = None,
140+
) -> CreateRequestResponse:
141+
"""Create a STK request to another user."""
142+
body: dict[str, Any] = {"amount": amount}
143+
if label is not None:
144+
body["label"] = label
145+
headers: dict[str, str] = {}
146+
if idempotency_key is not None:
147+
headers["Idempotency-Key"] = idempotency_key
148+
resp = await self._http.post(
149+
f"/api/user/{to_user_id}/request",
150+
json=body,
151+
headers=headers,
152+
)
153+
self._raise_for_error(resp)
154+
return CreateRequestResponse.model_validate(resp.json())
155+
156+
async def get_request(self, request_id: int) -> Request:
157+
"""Return a single request by its ID."""
158+
resp = await self._http.get(f"/api/request/{request_id}")
159+
self._raise_for_error(resp)
160+
return Request.model_validate(resp.json())
161+
162+
async def get_requests(self, *, status: str | None = None) -> list[Request]:
163+
"""Return requests for the authenticated user, optionally filtered by status."""
164+
params: dict[str, Any] = {}
165+
if status is not None:
166+
params["status"] = status
167+
resp = await self._http.get("/api/requests", params=params)
168+
self._raise_for_error(resp)
169+
wrapper = RequestsResponse.model_validate(resp.json())
170+
return wrapper.requests or []
171+
172+
async def accept_request(self, request_id: int) -> RequestActionResponse:
173+
"""Accept a pending STK request."""
174+
resp = await self._http.post(f"/api/requests/{request_id}/accept")
175+
self._raise_for_error(resp)
176+
return RequestActionResponse.model_validate(resp.json())
177+
178+
async def deny_request(self, request_id: int) -> RequestActionResponse:
179+
"""Deny a pending STK request."""
180+
resp = await self._http.post(f"/api/requests/{request_id}/deny")
181+
self._raise_for_error(resp)
182+
return RequestActionResponse.model_validate(resp.json())
183+
184+
# -- transactions ----------------------------------------------------- #
185+
186+
async def get_transactions(self) -> list[Transaction]:
187+
"""Return transactions for the authenticated user."""
188+
resp = await self._http.get("/api/transactions")
189+
self._raise_for_error(resp)
190+
wrapper = TransactionsResponse.model_validate(resp.json())
191+
return wrapper.transactions or []
192+
193+
async def get_transaction(self, transaction_id: int) -> Transaction:
194+
"""Return a single transaction by its ID."""
195+
resp = await self._http.get(f"/api/transaction/{transaction_id}")
196+
self._raise_for_error(resp)
197+
return Transaction.model_validate(resp.json())
198+
199+
# -- events ----------------------------------------------------------- #
200+
201+
async def get_events(self, *, since_id: int = 0) -> list[dict[str, Any]]:
202+
"""Return events since the given ID.
203+
204+
Events are not yet in the OpenAPI spec, so this returns raw dicts.
205+
"""
206+
params: dict[str, Any] = {}
207+
if since_id:
208+
params["since_id"] = since_id
209+
resp = await self._http.get("/api/events", params=params)
210+
self._raise_for_error(resp)
211+
data = resp.json()
212+
return data.get("events", data) if isinstance(data, dict) else data
213+
214+
# -- discord guilds --------------------------------------------------- #
215+
216+
async def get_discord_guilds(self) -> list[DiscordGuild]:
217+
"""Return all Discord guilds."""
218+
resp = await self._http.get("/api/discord/guilds")
219+
self._raise_for_error(resp)
220+
wrapper = DiscordGuildsResponse.model_validate(resp.json())
221+
return wrapper.guilds or []
222+
223+
async def get_discord_guild(self, snowflake: str) -> DiscordGuild:
224+
"""Return a single Discord guild by its snowflake ID."""
225+
resp = await self._http.get(f"/api/discord/guild/{snowflake}")
226+
self._raise_for_error(resp)
227+
return DiscordGuild.model_validate(resp.json())

stackcoin/stackcoin/gateway.py

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
"""StackCoin WebSocket Gateway client."""
2+
3+
import asyncio
4+
import json
5+
from typing import Any, Callable, Awaitable
6+
7+
from pydantic import BaseModel
8+
9+
10+
EventHandler = Callable[["Event"], Awaitable[None]]
11+
12+
13+
class Event(BaseModel):
14+
"""A StackCoin event received via the gateway."""
15+
id: int
16+
type: str
17+
data: dict[str, Any]
18+
inserted_at: str
19+
20+
21+
class Gateway:
22+
"""WebSocket gateway for receiving real-time StackCoin events.
23+
24+
Usage::
25+
26+
gateway = stackcoin.Gateway(
27+
ws_url="ws://localhost:4000/bot/websocket",
28+
token="...",
29+
)
30+
31+
@gateway.on("request.accepted")
32+
async def handle_accepted(event: stackcoin.Event):
33+
print(event.data["request_id"])
34+
35+
await gateway.connect()
36+
"""
37+
38+
def __init__(
39+
self,
40+
ws_url: str,
41+
token: str,
42+
last_event_id: int = 0,
43+
on_event_id: Callable[[int], None] | None = None,
44+
):
45+
# ws_url should be the full websocket URL like "ws://localhost:4000/bot/websocket"
46+
self._ws_url = ws_url.rstrip("/")
47+
self._token = token
48+
self._handlers: dict[str, list[EventHandler]] = {}
49+
self._last_event_id = last_event_id
50+
self._on_event_id = on_event_id # callback to persist cursor position
51+
self._ws = None
52+
self._running = False
53+
self._ref_counter = 0
54+
55+
@property
56+
def last_event_id(self) -> int:
57+
return self._last_event_id
58+
59+
def on(self, event_type: str) -> Callable[[EventHandler], EventHandler]:
60+
"""Decorator to register an event handler."""
61+
def decorator(func: EventHandler) -> EventHandler:
62+
self.register_handler(event_type, func)
63+
return func
64+
return decorator
65+
66+
def register_handler(self, event_type: str, handler: EventHandler) -> None:
67+
"""Register an event handler programmatically."""
68+
if event_type not in self._handlers:
69+
self._handlers[event_type] = []
70+
self._handlers[event_type].append(handler)
71+
72+
async def connect(self) -> None:
73+
"""Connect and listen for events. Reconnects automatically on failure."""
74+
import websockets
75+
76+
self._running = True
77+
78+
while self._running:
79+
try:
80+
url = f"{self._ws_url}?token={self._token}&vsn=2.0.0"
81+
82+
async with websockets.connect(url) as ws:
83+
self._ws = ws
84+
await self._join_channel(ws)
85+
86+
heartbeat_task = asyncio.create_task(self._heartbeat(ws))
87+
try:
88+
async for raw_msg in ws:
89+
msg = json.loads(raw_msg)
90+
await self._handle_message(msg)
91+
finally:
92+
heartbeat_task.cancel()
93+
94+
except Exception:
95+
if self._running:
96+
await asyncio.sleep(5)
97+
98+
async def _join_channel(self, ws: Any) -> None:
99+
"""Join the user:self channel with event replay."""
100+
self._ref_counter += 1
101+
join_msg = json.dumps([
102+
None,
103+
str(self._ref_counter),
104+
"user:self",
105+
"phx_join",
106+
{"last_event_id": self._last_event_id},
107+
])
108+
await ws.send(join_msg)
109+
110+
reply = json.loads(await asyncio.wait_for(ws.recv(), timeout=10))
111+
if not (reply[3] == "phx_reply" and reply[4].get("status") == "ok"):
112+
raise ConnectionError(f"Failed to join channel: {reply}")
113+
114+
async def _heartbeat(self, ws: Any) -> None:
115+
"""Send periodic heartbeats."""
116+
while True:
117+
await asyncio.sleep(30)
118+
self._ref_counter += 1
119+
hb = json.dumps([None, str(self._ref_counter), "phoenix", "heartbeat", {}])
120+
await ws.send(hb)
121+
122+
async def _handle_message(self, msg: list[Any]) -> None:
123+
"""Dispatch incoming message to registered handlers."""
124+
if len(msg) < 5:
125+
return
126+
127+
event_name = msg[3]
128+
payload = msg[4]
129+
130+
if event_name == "event":
131+
event = Event.model_validate(payload)
132+
133+
if event.id > self._last_event_id:
134+
self._last_event_id = event.id
135+
136+
for handler in self._handlers.get(event.type, []):
137+
try:
138+
await handler(event)
139+
except Exception:
140+
pass
141+
142+
if event.id > 0 and self._on_event_id:
143+
try:
144+
self._on_event_id(event.id)
145+
except Exception:
146+
pass
147+
148+
def stop(self) -> None:
149+
"""Signal the gateway to stop."""
150+
self._running = False

0 commit comments

Comments
 (0)