From 3595c73fb6e9fc222569d56ab08989ee54065340 Mon Sep 17 00:00:00 2001 From: taras Date: Fri, 1 May 2026 22:07:16 +0200 Subject: [PATCH 01/57] Experimental --- examples/echo_client_websockets.py | 14 + picows/websockets/__init__.py | 40 ++ picows/websockets/asyncio/__init__.py | 8 + picows/websockets/asyncio/client.py | 832 ++++++++++++++++++++++++++ picows/websockets/exceptions.py | 88 +++ tests/test_websockets_compat.py | 69 +++ 6 files changed, 1051 insertions(+) create mode 100644 examples/echo_client_websockets.py create mode 100644 picows/websockets/__init__.py create mode 100644 picows/websockets/asyncio/__init__.py create mode 100644 picows/websockets/asyncio/client.py create mode 100644 picows/websockets/exceptions.py create mode 100644 tests/test_websockets_compat.py diff --git a/examples/echo_client_websockets.py b/examples/echo_client_websockets.py new file mode 100644 index 0000000..d43c6df --- /dev/null +++ b/examples/echo_client_websockets.py @@ -0,0 +1,14 @@ +import asyncio + +from picows import websockets + + +async def main(): + async with websockets.connect("ws://127.0.0.1:9001") as websocket: + await websocket.send("Hello world") + reply = await websocket.recv() + print(f"Echo reply: {reply}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/picows/websockets/__init__.py b/picows/websockets/__init__.py new file mode 100644 index 0000000..aeb99e0 --- /dev/null +++ b/picows/websockets/__init__.py @@ -0,0 +1,40 @@ +from . import exceptions +from .asyncio.client import ClientConnection, State, connect, process_exception +from .exceptions import ( + ConcurrencyError, + ConnectionClosed, + ConnectionClosedError, + ConnectionClosedOK, + InvalidHandshake, + InvalidHeader, + InvalidMessage, + InvalidState, + InvalidStatus, + InvalidUpgrade, + InvalidURI, + PayloadTooBig, + ProtocolError, + WebSocketException, +) + +__all__ = [ + "ClientConnection", + "ConcurrencyError", + "ConnectionClosed", + "ConnectionClosedError", + "ConnectionClosedOK", + "InvalidHandshake", + "InvalidHeader", + "InvalidMessage", + "InvalidState", + "InvalidStatus", + "InvalidUpgrade", + "InvalidURI", + "PayloadTooBig", + "ProtocolError", + "State", + "WebSocketException", + "connect", + "exceptions", + "process_exception", +] diff --git a/picows/websockets/asyncio/__init__.py b/picows/websockets/asyncio/__init__.py new file mode 100644 index 0000000..9869987 --- /dev/null +++ b/picows/websockets/asyncio/__init__.py @@ -0,0 +1,8 @@ +from .client import ClientConnection, State, connect, process_exception + +__all__ = [ + "ClientConnection", + "State", + "connect", + "process_exception", +] diff --git a/picows/websockets/asyncio/client.py b/picows/websockets/asyncio/client.py new file mode 100644 index 0000000..bfa3df7 --- /dev/null +++ b/picows/websockets/asyncio/client.py @@ -0,0 +1,832 @@ +from __future__ import annotations + +import asyncio +import sys +import logging +import os +import socket +import uuid +import warnings +from collections.abc import AsyncIterable, Generator, Iterable +from dataclasses import dataclass +from enum import IntEnum +from ssl import SSLContext +from time import monotonic +from typing import Any, AsyncIterator, Awaitable, Callable, Optional, Sequence, Union, cast +from urllib.request import getproxies + +import picows +from picows.types import WSHeadersLike +from picows.url import parse_url + +from ..exceptions import ( + ConcurrencyError, + ConnectionClosed, + ConnectionClosedError, + ConnectionClosedOK, + InvalidHandshake, + InvalidHeader, + InvalidMessage, + InvalidState, + InvalidStatus, + InvalidUpgrade, + InvalidURI, + PayloadTooBig, + ProtocolError, +) + + +Data = Union[str, bytes, bytearray, memoryview] +HeadersLike = WSHeadersLike +CloseCodeT = Union[int, picows.WSCloseCode] +LoggerLike = Union[str, logging.Logger, logging.LoggerAdapter[Any], None] + + +OK_CLOSE_CODES = {0, 1000, 1001} + + +class State(IntEnum): + CONNECTING = 0 + OPEN = 1 + CLOSING = 2 + CLOSED = 3 + + +@dataclass(slots=True) +class _BufferedFrame: + msg_type: picows.WSMsgType + payload: bytes + fin: bool + + +def _coerce_close_code(code: Optional[picows.WSCloseCode]) -> Optional[int]: + return None if code is None else int(code.value) + + +def _coerce_close_reason(reason: Optional[str]) -> Optional[str]: + return reason if reason is not None else None + + +def _header_items(headers: Any) -> list[tuple[str, str]]: + return [] if headers is None else list(headers.items()) + + +def _resolve_subprotocol(subprotocols: Optional[Sequence[str]], response: Any) -> Optional[str]: + if response is None: + return None + value = response.headers.get("Sec-WebSocket-Protocol") + if value is None: + return None + if subprotocols is not None and value not in subprotocols: + raise InvalidHandshake(f"unsupported subprotocol negotiated by server: {value}") + return cast(str, value) + + +def _default_user_agent() -> str: + return f"Python/{sys.version_info.major}.{sys.version_info.minor} picows-websockets/0" + + +def _process_proxy(proxy: Union[str, bool, None], secure: bool) -> Optional[str]: + if proxy is None: + return None + if isinstance(proxy, str): + return proxy + if proxy is True: + proxies = getproxies() + return ( + proxies.get("wss" if secure else "ws") + or proxies.get("https" if secure else "http") + ) + raise InvalidURI(str(proxy), "proxy must be None, True, or a proxy URL") + + +def process_exception(exc: Exception) -> Optional[Exception]: + if isinstance(exc, (EOFError, OSError, asyncio.TimeoutError)): + return None + if isinstance(exc, InvalidStatus): + status = getattr(getattr(exc, "response", None), "status", None) + if status is not None and int(status) in {500, 502, 503, 504}: + return None + return exc + + +class _ConnectionListener(picows.WSListener): + def __init__(self, holder: dict[str, Any]): + self.holder = holder + + def on_ws_connected(self, transport: picows.WSTransport) -> None: + connection = self.holder.get("connection") + if connection is None: + self.holder["pending"].append(("connected", transport)) + else: + connection._on_connected(transport) + + def on_ws_frame(self, transport: picows.WSTransport, frame: picows.WSFrame) -> None: + del transport + event = _BufferedFrame(frame.msg_type, frame.get_payload_as_bytes(), frame.fin) + connection = self.holder.get("connection") + if connection is None: + self.holder["pending"].append(("frame", event)) + else: + connection._on_frame(event) + + def on_ws_disconnected(self, transport: picows.WSTransport) -> None: + del transport + connection = self.holder.get("connection") + if connection is None: + self.holder["pending"].append(("disconnected", None)) + else: + connection._on_disconnected() + + def pause_writing(self) -> None: + connection = self.holder.get("connection") + if connection is not None: + connection._pause_writing() + + def resume_writing(self) -> None: + connection = self.holder.get("connection") + if connection is not None: + connection._resume_writing() + + +class ClientConnection: + def __init__( + self, + transport: picows.WSTransport, + *, + ping_interval: Optional[float] = 20, + ping_timeout: Optional[float] = 20, + close_timeout: Optional[float] = 10, + max_queue: Union[int, tuple[Optional[int], Optional[int]], None] = 16, + write_limit: Union[int, tuple[int, Optional[int]]] = 32768, + max_message_size: Optional[int] = 1024 * 1024, + max_fragment_size: Optional[int] = 1024 * 1024, + logger: LoggerLike = None, + subprotocols: Optional[Sequence[str]] = None, + ): + self.transport = transport + self.request = transport.request + self.response = transport.response + self.id = uuid.uuid4() + self.logger = self._resolve_logger(logger) + self._subprotocol = _resolve_subprotocol(subprotocols, self.response) + self._state = State.OPEN + self._closed_event = asyncio.Event() + self._frames: asyncio.Queue[Optional[_BufferedFrame]] = asyncio.Queue() + self._close_exc: Optional[ConnectionClosed] = None + self._disconnect_waiter = asyncio.create_task(self._watch_disconnect()) + self._recv_lock = asyncio.Lock() + self._send_lock = asyncio.Lock() + self._read_closed = False + self._write_paused = False + self._recv_streaming_in_progress = False + self._recv_streaming_broken = False + self._pending_pings: dict[bytes, tuple[asyncio.Future[float], float]] = {} + self._ping_interval = ping_interval + self._ping_timeout = ping_timeout + self._close_timeout = close_timeout + self._keepalive_task: Optional[asyncio.Task[None]] = None + self._latency = 0.0 + self._max_message_size = max_message_size + self._max_fragment_size = max_fragment_size + self._max_queue_high, self._max_queue_low = self._normalize_watermarks(max_queue) + self._set_write_limits(write_limit) + self._paused_reading = False + if ping_interval is not None: + self._keepalive_task = asyncio.create_task(self._keepalive_loop()) + + def _resolve_logger(self, logger: LoggerLike) -> Union[logging.Logger, logging.LoggerAdapter[Any]]: + if logger is None: + return logging.getLogger("websockets.client") + if isinstance(logger, str): + return logging.getLogger(logger) + return logger + + def _normalize_watermarks( + self, + max_queue: Union[int, tuple[Optional[int], Optional[int]], None], + ) -> tuple[Optional[int], Optional[int]]: + if max_queue is None: + return None, None + if isinstance(max_queue, tuple): + high, low = max_queue + if high is None: + return None, None + return high, high // 4 if low is None else low + return max_queue, max_queue // 4 + + def _set_write_limits(self, write_limit: Union[int, tuple[int, Optional[int]]]) -> None: + if isinstance(write_limit, tuple): + high, low = write_limit + else: + high, low = write_limit, None + self.transport.underlying_transport.set_write_buffer_limits(high=high, low=low) + + def _on_connected(self, transport: picows.WSTransport) -> None: + del transport + + def _pause_writing(self) -> None: + self._write_paused = True + + def _resume_writing(self) -> None: + self._write_paused = False + + def _pause_reading_if_needed(self) -> None: + if self._max_queue_high is None: + return + if not self._paused_reading and self._frames.qsize() >= self._max_queue_high: + self.transport.underlying_transport.pause_reading() + self._paused_reading = True + + def _resume_reading_if_needed(self) -> None: + if not self._paused_reading: + return + if self._max_queue_low is None or self._frames.qsize() <= self._max_queue_low: + self.transport.underlying_transport.resume_reading() + self._paused_reading = False + + def _set_close_exception(self) -> None: + handshake = self.transport.close_handshake + rcvd = getattr(handshake, "recv", None) if handshake is not None else None + sent = getattr(handshake, "sent", None) if handshake is not None else None + rcvd_then_sent = getattr(handshake, "recv_then_sent", None) if handshake is not None else None + rcvd_code = _coerce_close_code(getattr(rcvd, "code", None)) + sent_code = _coerce_close_code(getattr(sent, "code", None)) + ok = ( + (rcvd_code in OK_CLOSE_CODES or rcvd_code is None) + and (sent_code in OK_CLOSE_CODES or sent_code is None) + and handshake is not None + ) + exc_type = ConnectionClosedOK if ok else ConnectionClosedError + self._close_exc = exc_type(rcvd, sent, rcvd_then_sent) + + async def _watch_disconnect(self) -> None: + try: + await self.transport.wait_disconnected() + except Exception: + self._state = State.CLOSED + self._set_close_exception() + self._frames.put_nowait(None) + self._closed_event.set() + else: + self._state = State.CLOSED + self._set_close_exception() + self._frames.put_nowait(None) + self._closed_event.set() + finally: + if self._keepalive_task is not None: + self._keepalive_task.cancel() + for waiter, _ in self._pending_pings.values(): + if not waiter.done(): + waiter.set_exception(self._close_exc or ConnectionClosedError(None, None, None)) + self._pending_pings.clear() + + def _on_disconnected(self) -> None: + self._state = State.CLOSED + + def _fail_message_too_big(self, message: str) -> None: + if self._state is State.CLOSED: + return + self.transport.send_close(picows.WSCloseCode.MESSAGE_TOO_BIG, message.encode("utf-8")) + self.transport.disconnect(False) + + def _on_frame(self, frame: _BufferedFrame) -> None: + if frame.msg_type == picows.WSMsgType.PING: + self.transport.send_pong(frame.payload) + return + + if frame.msg_type == picows.WSMsgType.PONG: + payload = frame.payload + ping = self._pending_pings.pop(payload, None) + if ping is not None: + waiter, sent_at = ping + self._latency = monotonic() - sent_at + if not waiter.done(): + waiter.set_result(self._latency) + return + + if frame.msg_type == picows.WSMsgType.CLOSE: + close_code: CloseCodeT = picows.WSCloseCode.NO_INFO + close_message = b"" + if len(frame.payload) >= 2: + close_code = int.from_bytes(frame.payload[:2], "big") + close_message = frame.payload[2:] + if not self.transport.is_close_frame_sent: + self.transport.send_close(cast(picows.WSCloseCode, close_code), close_message) + self._state = State.CLOSING + self.transport.disconnect() + return + + payload = frame.payload + if self._max_fragment_size is not None and len(payload) > self._max_fragment_size: + self._fail_message_too_big("fragment too big") + return + + self._frames.put_nowait(frame) + self._pause_reading_if_needed() + + async def _next_frame(self) -> _BufferedFrame: + frame = await self._frames.get() + self._resume_reading_if_needed() + if frame is None: + raise self._connection_closed() + return frame + + def _connection_closed(self) -> ConnectionClosed: + if self._close_exc is None: + self._set_close_exception() + return self._close_exc or ConnectionClosedError(None, None, None) + + def _ensure_recv_available(self) -> None: + if self._recv_streaming_broken: + raise ConcurrencyError("recv_streaming() wasn't fully consumed") + if self._recv_streaming_in_progress: + raise ConcurrencyError("cannot call recv() while recv_streaming() is active") + + async def recv(self, decode: Optional[bool] = None) -> Union[str, bytes]: + self._ensure_recv_available() + if self._recv_lock.locked(): + raise ConcurrencyError("cannot call recv() concurrently") + async with self._recv_lock: + first = await self._next_frame() + if first.msg_type not in (picows.WSMsgType.TEXT, picows.WSMsgType.BINARY): + raise ProtocolError(f"unexpected opcode while receiving message: {first.msg_type}") + msg_type = first.msg_type + + chunks = [first.payload] + total = len(first.payload) + while not first.fin: + first = await self._next_frame() + if first.msg_type != picows.WSMsgType.CONTINUATION: + raise ProtocolError("expected continuation frame") + chunks.append(first.payload) + total += len(first.payload) + if self._max_message_size is not None and total > self._max_message_size: + self._fail_message_too_big("message too big") + raise PayloadTooBig("message too big") + + payload = b"".join(chunks) + return self._decode_payload(payload, msg_type, decode) + + def _decode_payload( + self, + payload: bytes, + msg_type: picows.WSMsgType, + decode: Optional[bool], + ) -> Union[str, bytes]: + if msg_type == picows.WSMsgType.TEXT: + if decode is False: + return payload + return payload.decode("utf-8") + if decode is True: + return payload.decode("utf-8") + return payload + + def recv_streaming(self, decode: Optional[bool] = None) -> AsyncIterator[Union[str, bytes]]: + self._ensure_recv_available() + if self._recv_lock.locked(): + raise ConcurrencyError("cannot call recv_streaming() concurrently") + self._recv_streaming_in_progress = True + started = False + finished = False + + async def iterator() -> AsyncIterator[Union[str, bytes]]: + nonlocal started, finished + try: + async with self._recv_lock: + first = await self._next_frame() + if first.msg_type not in (picows.WSMsgType.TEXT, picows.WSMsgType.BINARY): + raise ProtocolError(f"unexpected opcode while receiving message: {first.msg_type}") + msg_type = first.msg_type + started = True + yield self._decode_fragment(first.payload, msg_type, decode) + total = len(first.payload) + frame = first + while not frame.fin: + frame = await self._next_frame() + if frame.msg_type != picows.WSMsgType.CONTINUATION: + raise ProtocolError("expected continuation frame") + total += len(frame.payload) + if self._max_message_size is not None and total > self._max_message_size: + self._fail_message_too_big("message too big") + raise PayloadTooBig("message too big") + yield self._decode_fragment(frame.payload, msg_type, decode) + finished = True + finally: + if started and not finished: + self._recv_streaming_broken = True + elif finished: + self._recv_streaming_broken = False + self._recv_streaming_in_progress = False + + return iterator() + + def _decode_fragment( + self, + payload: bytes, + msg_type: picows.WSMsgType, + decode: Optional[bool], + ) -> Union[str, bytes]: + if msg_type == picows.WSMsgType.TEXT: + if decode is False: + return payload + return payload.decode("utf-8") + if decode is True: + return payload.decode("utf-8") + return payload + + async def send( + self, + message: Union[Data, Iterable[Data], AsyncIterator[Data]], + text: Optional[bool] = None, + ) -> None: + if self.state is State.CLOSED: + raise self._connection_closed() + + async with self._send_lock: + fragments = await self._collect_fragments(message) + if not fragments: + raise TypeError("message iterable cannot be empty") + + first = fragments[0] + if isinstance(first, str): + msg_type = picows.WSMsgType.TEXT + def encode(item: Data) -> bytes: + if not isinstance(item, str): + raise TypeError("all fragments must be of the same type") + return item.encode("utf-8") + elif isinstance(first, (bytes, bytearray, memoryview)): + msg_type = picows.WSMsgType.TEXT if text else picows.WSMsgType.BINARY + def encode(item: Data) -> bytes: + if not isinstance(item, (bytes, bytearray, memoryview)): + raise TypeError("all fragments must be of the same type") + return bytes(item) + else: + raise TypeError(f"message must contain str or bytes-like objects, got {type(first).__name__}") + + if len(fragments) == 1: + payload = encode(first) + self.transport.send(msg_type, payload) + return + + for index, fragment in enumerate(fragments): + if isinstance(first, str) and not isinstance(fragment, str): + raise TypeError("all fragments must be of the same type") + if not isinstance(first, str) and not isinstance(fragment, (bytes, bytearray, memoryview)): + raise TypeError("all fragments must be of the same type") + opcode = msg_type if index == 0 else picows.WSMsgType.CONTINUATION + self.transport.send(opcode, encode(fragment), fin=index == len(fragments) - 1) + + async def _collect_fragments( + self, + message: Union[Data, Iterable[Data], AsyncIterator[Data]], + ) -> list[Data]: + if isinstance(message, (str, bytes, bytearray, memoryview)): + return [message] + if isinstance(message, AsyncIterable): + result: list[Data] = [] + async for item in message: + result.append(item) + return result + if isinstance(message, Iterable): + return list(message) + raise TypeError(f"message has unsupported type {type(message).__name__}") + + async def close(self, code: CloseCodeT = 1000, reason: str = "") -> None: + if self.state is State.CLOSED: + return + if self.state is State.OPEN: + self._state = State.CLOSING + close_code = code if isinstance(code, picows.WSCloseCode) else picows.WSCloseCode(code) + self.transport.send_close(close_code, reason.encode("utf-8")) + try: + if self._close_timeout is None: + await self.wait_closed() + else: + await asyncio.wait_for(self.wait_closed(), self._close_timeout) + except asyncio.TimeoutError: + self.transport.disconnect(False) + await self.wait_closed() + + async def wait_closed(self) -> None: + await self._closed_event.wait() + + async def ping(self, data: Optional[Union[str, bytes]] = None) -> Awaitable[float]: + if self.state is State.CLOSED: + raise self._connection_closed() + if data is None: + while True: + payload = os.urandom(4) + if payload not in self._pending_pings: + break + elif isinstance(data, str): + payload = data.encode("utf-8") + elif isinstance(data, bytes): + payload = data + else: + raise TypeError("ping payload must be str, bytes, or None") + + if payload in self._pending_pings: + raise ConcurrencyError("another ping was sent with the same data") + + waiter: asyncio.Future[float] = asyncio.get_running_loop().create_future() + self._pending_pings[payload] = (waiter, monotonic()) + self.transport.send_ping(payload) + return waiter + + async def pong(self, data: Union[str, bytes] = b"") -> None: + if self.state is State.CLOSED: + raise self._connection_closed() + payload = data.encode("utf-8") if isinstance(data, str) else data + self.transport.send_pong(payload) + + async def _keepalive_loop(self) -> None: + try: + while True: + assert self._ping_interval is not None + await asyncio.sleep(self._ping_interval) + waiter = await self.ping() + if self._ping_timeout is None: + continue + await asyncio.wait_for(waiter, self._ping_timeout) + except asyncio.CancelledError: + raise + except Exception: + if self.state is not State.CLOSED: + await self.close(code=1011, reason="keepalive ping timeout") + + async def __aenter__(self) -> "ClientConnection": + return self + + async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None: + del exc_type, exc, tb + await self.close() + + def __aiter__(self) -> AsyncIterator[Union[str, bytes]]: + return self._iterate_messages() + + async def _iterate_messages(self) -> AsyncIterator[Union[str, bytes]]: + while True: + try: + yield await self.recv() + except ConnectionClosedOK: + return + + @property + def state(self) -> State: + return self._state + + @property + def local_address(self) -> Any: + return self.transport.underlying_transport.get_extra_info("sockname") + + @property + def remote_address(self) -> Any: + return self.transport.underlying_transport.get_extra_info("peername") + + @property + def latency(self) -> float: + return self._latency + + @property + def subprotocol(self) -> Optional[str]: + return self._subprotocol + + @property + def close_code(self) -> Optional[int]: + handshake = self.transport.close_handshake + if handshake is None: + return None + if handshake.recv is not None: + return _coerce_close_code(handshake.recv.code) + if handshake.sent is not None: + return _coerce_close_code(handshake.sent.code) + return None + + @property + def close_reason(self) -> Optional[str]: + handshake = self.transport.close_handshake + if handshake is None: + return None + if handshake.recv is not None: + return _coerce_close_reason(handshake.recv.reason) + if handshake.sent is not None: + return _coerce_close_reason(handshake.sent.reason) + return None + + +class _Connect: + def __init__( + self, + uri: str, + *, + origin: Optional[str] = None, + extensions: Optional[Sequence[Any]] = None, + subprotocols: Optional[Sequence[str]] = None, + compression: Optional[str] = "deflate", + additional_headers: Optional[HeadersLike] = None, + user_agent_header: Optional[str] = _default_user_agent(), + proxy: Union[str, bool, None] = True, + process_exception: Callable[[Exception], Optional[Exception]] = process_exception, + open_timeout: Optional[float] = 10, + ping_interval: Optional[float] = 20, + ping_timeout: Optional[float] = 20, + close_timeout: Optional[float] = 10, + max_size: Union[int, tuple[Optional[int], Optional[int]], None] = 1024 * 1024, + max_queue: Union[int, tuple[Optional[int], Optional[int]], None] = 16, + write_limit: Union[int, tuple[int, Optional[int]]] = 32768, + logger: LoggerLike = None, + create_connection: Optional[type[ClientConnection]] = None, + **kwargs: Any, + ): + self.uri = uri + self.origin = origin + self.extensions = extensions + self.subprotocols = subprotocols + self.compression = compression + self.additional_headers = additional_headers + self.user_agent_header = user_agent_header + self.proxy = proxy + self.process_exception = process_exception + self.open_timeout = open_timeout + self.ping_interval = ping_interval + self.ping_timeout = ping_timeout + self.close_timeout = close_timeout + self.max_size = max_size + self.max_queue = max_queue + self.write_limit = write_limit + self.logger = logger + self.connection_factory = create_connection or ClientConnection + self.kwargs = kwargs + self._connection: Optional[ClientConnection] = None + self._backoff = 1.0 + + def __await__(self) -> Generator[Any, None, ClientConnection]: + return self._connect().__await__() + + async def __aenter__(self) -> ClientConnection: + self._connection = await self._connect() + return self._connection + + async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None: + del exc_type, exc, tb + if self._connection is not None: + await self._connection.close() + self._connection = None + + def __aiter__(self) -> "_Connect": + return self + + async def __anext__(self) -> ClientConnection: + if self._connection is not None: + await self._connection.close() + self._connection = None + while True: + try: + connection = await self._connect() + except Exception as exc: + processed = self.process_exception(exc) + if processed is not None: + raise processed + await asyncio.sleep(self._backoff) + self._backoff = min(self._backoff * 2, 60.0) + continue + self._connection = connection + self._backoff = 1.0 + return connection + + async def _connect(self) -> ClientConnection: + parsed = parse_url(self.uri) + proxy = _process_proxy(self.proxy, parsed.is_secure) + extra_headers = self._build_headers() + max_message_size, max_fragment_size = self._normalize_max_size(self.max_size) + + if self.extensions is not None: + raise NotImplementedError("custom extensions aren't supported by picows.websockets") + if self.compression not in (None, "deflate"): + raise NotImplementedError("only compression=None or 'deflate' are accepted") + if self.compression == "deflate": + warnings.warn( + "picows.websockets doesn't implement permessage-deflate; connecting without compression", + RuntimeWarning, + stacklevel=2, + ) + + conn_kwargs = dict(self.kwargs) + ssl_context = conn_kwargs.pop("ssl", None) + host_override = conn_kwargs.pop("host", None) + port_override = conn_kwargs.pop("port", None) + preexisting_sock = conn_kwargs.pop("sock", None) + + socket_factory = conn_kwargs.pop("socket_factory", None) + if preexisting_sock is not None: + if socket_factory is not None: + raise TypeError("cannot pass both sock and socket_factory") + + provided_sock = cast(socket.socket, preexisting_sock) + + def provided_socket(_: Any) -> socket.socket: + return provided_sock + + socket_factory = provided_socket + elif host_override is not None or port_override is not None: + if socket_factory is not None: + raise TypeError("cannot pass both host/port override and socket_factory") + + async def connect_override(_: Any) -> socket.socket: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.setblocking(False) + await asyncio.get_running_loop().sock_connect( + sock, + (host_override or parsed.host, port_override or parsed.port), + ) + return sock + + socket_factory = connect_override + holder: dict[str, Any] = {"pending": []} + + def listener_factory() -> _ConnectionListener: + return _ConnectionListener(holder) + + try: + transport, _listener = await picows.ws_connect( + listener_factory, + self.uri, + ssl_context=self._coerce_ssl_context(ssl_context), + websocket_handshake_timeout=self.open_timeout, + enable_auto_ping=False, + enable_auto_pong=False, + max_frame_size=max_fragment_size if max_fragment_size is not None else 2 ** 31 - 1, + extra_headers=extra_headers, + proxy=proxy, + socket_factory=socket_factory, + logger_name=self.logger if self.logger is not None else "websockets.client", + **conn_kwargs, + ) + except picows.WSInvalidURL as exc: + raise InvalidURI(exc.args[0], exc.args[1] if len(exc.args) > 1 else str(exc)) from exc + except picows.WSInvalidStatusError as exc: + raise InvalidStatus(exc.response) from exc + except picows.WSInvalidUpgradeError as exc: + raise InvalidUpgrade(exc.name, exc.value) from exc + except picows.WSInvalidHeaderError as exc: + raise InvalidHeader(exc.name, exc.value) from exc + except picows.WSInvalidMessageError as exc: + raise InvalidMessage(str(exc)) from exc + except picows.WSHandshakeError as exc: + raise InvalidHandshake(str(exc)) from exc + + connection = self.connection_factory( + transport, + ping_interval=self.ping_interval, + ping_timeout=self.ping_timeout, + close_timeout=self.close_timeout, + max_queue=self.max_queue, + write_limit=self.write_limit, + max_message_size=max_message_size, + max_fragment_size=max_fragment_size, + logger=self.logger, + subprotocols=self.subprotocols, + ) + holder["connection"] = connection + for kind, event in holder["pending"]: + if kind == "connected": + connection._on_connected(event) + elif kind == "frame": + connection._on_frame(event) + else: + connection._on_disconnected() + return connection + + def _normalize_max_size( + self, + max_size: Union[int, tuple[Optional[int], Optional[int]], None], + ) -> tuple[Optional[int], Optional[int]]: + if max_size is None: + return None, None + if isinstance(max_size, tuple): + return max_size + return max_size, max_size + + def _build_headers(self) -> list[tuple[str, str]]: + headers = _header_items(self.additional_headers) + if self.origin is not None: + headers.append(("Origin", self.origin)) + if self.user_agent_header is not None: + headers.append(("User-Agent", self.user_agent_header)) + if self.subprotocols: + headers.append(("Sec-WebSocket-Protocol", ", ".join(self.subprotocols))) + return headers + + def _coerce_ssl_context(self, value: Any) -> Optional[SSLContext]: + if value in (None, True): + return None + if value is False: + raise NotImplementedError("ssl=False isn't supported for wss:// URIs") + if not isinstance(value, SSLContext): + raise TypeError("ssl must be an SSLContext, True, False, or None") + return value + + +def connect(uri: str, **kwargs: Any) -> _Connect: + return _Connect(uri, **kwargs) diff --git a/picows/websockets/exceptions.py b/picows/websockets/exceptions.py new file mode 100644 index 0000000..3ae7794 --- /dev/null +++ b/picows/websockets/exceptions.py @@ -0,0 +1,88 @@ +from __future__ import annotations + +from typing import Any, Optional + + +class WebSocketException(Exception): + """Base class for exceptions defined by picows.websockets.""" + + +class ConnectionClosed(WebSocketException): + def __init__(self, rcvd: Any, sent: Any, rcvd_then_sent: Optional[bool] = None): + super().__init__() + self.rcvd = rcvd + self.sent = sent + self.rcvd_then_sent = rcvd_then_sent + + def __str__(self) -> str: + if self.rcvd is None and self.sent is None: + return "no close frame received or sent" + if self.rcvd is None: + return f"sent {self.sent.code} ({self.sent.reason})" + if self.sent is None: + return f"received {self.rcvd.code} ({self.rcvd.reason})" + order = "received then sent" if self.rcvd_then_sent else "sent then received" + return ( + f"{order} close frames: " + f"received {self.rcvd.code} ({self.rcvd.reason}), " + f"sent {self.sent.code} ({self.sent.reason})" + ) + + +class ConnectionClosedOK(ConnectionClosed): + pass + + +class ConnectionClosedError(ConnectionClosed): + pass + + +class InvalidURI(WebSocketException): + def __init__(self, uri: str, msg: str): + super().__init__(uri, msg) + self.uri = uri + self.msg = msg + + def __str__(self) -> str: + return f"{self.uri} isn't a valid WebSocket URI: {self.msg}" + + +class InvalidHandshake(WebSocketException): + pass + + +class InvalidMessage(InvalidHandshake): + pass + + +class InvalidStatus(InvalidHandshake): + def __init__(self, response: Any): + super().__init__(response) + self.response = response + + +class InvalidHeader(InvalidHandshake): + def __init__(self, name: str, value: Optional[str] = None): + super().__init__(name, value) + self.name = name + self.value = value + + +class InvalidUpgrade(InvalidHeader): + pass + + +class ProtocolError(WebSocketException): + pass + + +class PayloadTooBig(WebSocketException): + pass + + +class InvalidState(WebSocketException): + pass + + +class ConcurrencyError(WebSocketException): + pass diff --git a/tests/test_websockets_compat.py b/tests/test_websockets_compat.py new file mode 100644 index 0000000..eca88c8 --- /dev/null +++ b/tests/test_websockets_compat.py @@ -0,0 +1,69 @@ +import asyncio + +import pytest + +from picows import websockets +from tests.utils import WSServer + + +async def test_connect_send_recv_text(): + async with WSServer() as server: + async with websockets.connect(server.url, compression=None) as ws: + await ws.send("hello") + reply = await ws.recv() + assert reply == "hello" + + +async def test_connect_send_recv_binary(): + async with WSServer() as server: + async with websockets.connect(server.url, compression=None) as ws: + await ws.send(b"hello") + reply = await ws.recv() + assert reply == b"hello" + + +async def test_async_iteration_closes_normally(): + async with WSServer() as server: + async with websockets.connect(server.url, compression=None) as ws: + await ws.send("hello") + assert await ws.recv() == "hello" + await ws.close() + + items = [] + async for item in ws: + items.append(item) + + assert items == [] + + +async def test_ping_returns_waiter(): + async with WSServer() as server: + async with websockets.connect(server.url, compression=None, ping_interval=None) as ws: + pong_waiter = await ws.ping(b"abcd") + latency = await asyncio.wait_for(pong_waiter, 1.0) + assert latency >= 0 + + +async def test_recv_streaming_fragmented_message(): + async with WSServer() as server: + async with websockets.connect(server.url, compression=None) as ws: + await ws.send([b"ab", b"cd"]) + fragments = [] + async for fragment in ws.recv_streaming(): + fragments.append(fragment) + assert fragments == [b"ab", b"cd"] + + +async def test_subprotocol_header_and_property(): + request_headers = {} + + def listener_factory(request): + request_headers["value"] = request.headers.get("Sec-WebSocket-Protocol") + return None + + async with WSServer(listener_factory) as server: + with pytest.raises(websockets.InvalidStatus): + async with websockets.connect(server.url, compression=None, subprotocols=["chat"]): + pass + + assert request_headers["value"] == "chat" From f369b00b49f8570bb2b4e168cbd4576a2768b71a Mon Sep 17 00:00:00 2001 From: taras Date: Sat, 2 May 2026 00:37:58 +0200 Subject: [PATCH 02/57] Refactoring --- AGENTS.md | 27 +-- picows/websockets/asyncio/client.py | 291 ++++++++++++++++------------ pyproject.toml | 4 +- tests/test_websockets_compat.py | 46 +++++ 4 files changed, 227 insertions(+), 141 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index 921cb72..355fc8b 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -2,18 +2,23 @@ Read README.md for the basic understanding of what this project is. -picows - this is the main package +picows - this is the main package. +picows.websockets - reimplements popular websockets library interface on top of picows +tests - Contains tests for picows +examples - Various examples for users on how to use picows + perf_test that could be used to build call-graph with perf -aiofastnet - Contains optimized versions of asyncio create_connection, create_server -I plan to make a separate python project for it, but I'm not there yet. It should be treated -as a separate python project. It will have its own tests, eventually its own description and docs. -The project contains very efficient repimplementation of SelectSocketTransport and SSLProtocol -using Cython and sometimes a pure C code. create_connection, create_server are defined in aiofastnet/api.py -sslproto.pyx - hack python SSLContext to get raw SSL_CTX*, it works with openssl api directly after that. -sslproto_stdlib.pyx - is just for reference, I will delete it soon, but now it's good for comparison between -stdlib ssl and whatever is in sslproto.pyx. - -tests - Contains tests for both picows and aiofastnet. Tests for aiofastnet will become a part of a separate project. +## Code style notes +- Do not write `del transport` or similar `del ` statements inside callbacks just to mark arguments as unused. + Leave unused callback parameters as-is or rename them with a leading underscore if that is clearer. + Using `del` in this situation is confusing and suggests reference-counting or lifetime management concerns. +- Prefer direct composition only when there is a real behavioral boundary. + Do not introduce adapter / holder / deferred-event plumbing just to preserve a conceptual separation. + If extra machinery exists only to work around the separation you introduced, the separation is probably wrong. +- Do not model impossible or non-normal internal states in the mainline code path without a concrete reason. + If an invariant is guaranteed by control flow, write the code around that invariant instead of adding repeated defensive checks. + Every extra "just in case" branch teaches the reader that the state is part of normal behavior. + Add such checks only for real risks like external misuse, concurrency races, partial failure, or invariants that are genuinely hard to guarantee. + If the only reason for the check is uncertainty in the design, fix the design first. ## Testing instructions - Run lint after updating code with: diff --git a/picows/websockets/asyncio/client.py b/picows/websockets/asyncio/client.py index bfa3df7..763b892 100644 --- a/picows/websockets/asyncio/client.py +++ b/picows/websockets/asyncio/client.py @@ -27,7 +27,6 @@ InvalidHandshake, InvalidHeader, InvalidMessage, - InvalidState, InvalidStatus, InvalidUpgrade, InvalidURI, @@ -110,49 +109,9 @@ def process_exception(exc: Exception) -> Optional[Exception]: return exc -class _ConnectionListener(picows.WSListener): - def __init__(self, holder: dict[str, Any]): - self.holder = holder - - def on_ws_connected(self, transport: picows.WSTransport) -> None: - connection = self.holder.get("connection") - if connection is None: - self.holder["pending"].append(("connected", transport)) - else: - connection._on_connected(transport) - - def on_ws_frame(self, transport: picows.WSTransport, frame: picows.WSFrame) -> None: - del transport - event = _BufferedFrame(frame.msg_type, frame.get_payload_as_bytes(), frame.fin) - connection = self.holder.get("connection") - if connection is None: - self.holder["pending"].append(("frame", event)) - else: - connection._on_frame(event) - - def on_ws_disconnected(self, transport: picows.WSTransport) -> None: - del transport - connection = self.holder.get("connection") - if connection is None: - self.holder["pending"].append(("disconnected", None)) - else: - connection._on_disconnected() - - def pause_writing(self) -> None: - connection = self.holder.get("connection") - if connection is not None: - connection._pause_writing() - - def resume_writing(self) -> None: - connection = self.holder.get("connection") - if connection is not None: - connection._resume_writing() - - -class ClientConnection: +class ClientConnection(picows.WSListener): def __init__( self, - transport: picows.WSTransport, *, ping_interval: Optional[float] = 20, ping_timeout: Optional[float] = 20, @@ -164,21 +123,23 @@ def __init__( logger: LoggerLike = None, subprotocols: Optional[Sequence[str]] = None, ): - self.transport = transport - self.request = transport.request - self.response = transport.response self.id = uuid.uuid4() self.logger = self._resolve_logger(logger) - self._subprotocol = _resolve_subprotocol(subprotocols, self.response) - self._state = State.OPEN + self.transport = cast(picows.WSTransport, None) + self.request = cast(picows.WSUpgradeRequest, None) + self.response = cast(picows.WSUpgradeResponse, None) + self._subprotocols = subprotocols + self._subprotocol = cast(Optional[str], None) + self._state = State.CONNECTING self._closed_event = asyncio.Event() self._frames: asyncio.Queue[Optional[_BufferedFrame]] = asyncio.Queue() self._close_exc: Optional[ConnectionClosed] = None - self._disconnect_waiter = asyncio.create_task(self._watch_disconnect()) + self._disconnect_waiter: Optional[asyncio.Task[None]] = None + self._loop = asyncio.get_running_loop() self._recv_lock = asyncio.Lock() self._send_lock = asyncio.Lock() self._read_closed = False - self._write_paused = False + self._write_ready: Optional[asyncio.Future[None]] = None self._recv_streaming_in_progress = False self._recv_streaming_broken = False self._pending_pings: dict[bytes, tuple[asyncio.Future[float], float]] = {} @@ -190,10 +151,8 @@ def __init__( self._max_message_size = max_message_size self._max_fragment_size = max_fragment_size self._max_queue_high, self._max_queue_low = self._normalize_watermarks(max_queue) - self._set_write_limits(write_limit) + self._write_limit = write_limit self._paused_reading = False - if ping_interval is not None: - self._keepalive_task = asyncio.create_task(self._keepalive_loop()) def _resolve_logger(self, logger: LoggerLike) -> Union[logging.Logger, logging.LoggerAdapter[Any]]: if logger is None: @@ -222,14 +181,37 @@ def _set_write_limits(self, write_limit: Union[int, tuple[int, Optional[int]]]) high, low = write_limit, None self.transport.underlying_transport.set_write_buffer_limits(high=high, low=low) - def _on_connected(self, transport: picows.WSTransport) -> None: - del transport + def on_ws_connected(self, transport: picows.WSTransport) -> None: + self.transport = transport + self.request = transport.request + self.response = transport.response + self._subprotocol = _resolve_subprotocol(self._subprotocols, self.response) + self._state = State.OPEN + self._set_write_limits(self._write_limit) + self._disconnect_waiter = asyncio.create_task(self._watch_disconnect()) + if self._ping_interval is not None and self._keepalive_task is None: + self._keepalive_task = asyncio.create_task(self._keepalive_loop()) + + def pause_writing(self) -> None: + if self._write_ready is None: + self._write_ready = self._loop.create_future() - def _pause_writing(self) -> None: - self._write_paused = True + def resume_writing(self) -> None: + if self._write_ready is not None: + if not self._write_ready.done(): + self._write_ready.set_result(None) + self._write_ready = None - def _resume_writing(self) -> None: - self._write_paused = False + async def _send_and_wait_write_ready( + self, + msg_type: picows.WSMsgType, + payload: bytes, + *, + fin: bool = True, + ) -> None: + self.transport.send(msg_type, payload, fin=fin) + if self._write_ready is not None: + await self._write_ready def _pause_reading_if_needed(self) -> None: if self._max_queue_high is None: @@ -276,12 +258,18 @@ async def _watch_disconnect(self) -> None: finally: if self._keepalive_task is not None: self._keepalive_task.cancel() + if self._write_ready is not None: + if not self._write_ready.done(): + self._write_ready.set_exception( + self._close_exc or ConnectionClosedError(None, None, None) + ) + self._write_ready = None for waiter, _ in self._pending_pings.values(): if not waiter.done(): waiter.set_exception(self._close_exc or ConnectionClosedError(None, None, None)) self._pending_pings.clear() - def _on_disconnected(self) -> None: + def on_ws_disconnected(self, transport: picows.WSTransport) -> None: self._state = State.CLOSED def _fail_message_too_big(self, message: str) -> None: @@ -290,7 +278,10 @@ def _fail_message_too_big(self, message: str) -> None: self.transport.send_close(picows.WSCloseCode.MESSAGE_TOO_BIG, message.encode("utf-8")) self.transport.disconnect(False) - def _on_frame(self, frame: _BufferedFrame) -> None: + def on_ws_frame(self, transport: picows.WSTransport, frame: picows.WSFrame) -> None: + self._on_frame_buffered(_BufferedFrame(frame.msg_type, frame.get_payload_as_bytes(), frame.fin)) + + def _on_frame_buffered(self, frame: _BufferedFrame) -> None: if frame.msg_type == picows.WSMsgType.PING: self.transport.send_pong(frame.payload) return @@ -444,53 +435,109 @@ async def send( raise self._connection_closed() async with self._send_lock: - fragments = await self._collect_fragments(message) - if not fragments: - raise TypeError("message iterable cannot be empty") - - first = fragments[0] - if isinstance(first, str): - msg_type = picows.WSMsgType.TEXT - def encode(item: Data) -> bytes: - if not isinstance(item, str): - raise TypeError("all fragments must be of the same type") - return item.encode("utf-8") - elif isinstance(first, (bytes, bytearray, memoryview)): - msg_type = picows.WSMsgType.TEXT if text else picows.WSMsgType.BINARY - def encode(item: Data) -> bytes: - if not isinstance(item, (bytes, bytearray, memoryview)): - raise TypeError("all fragments must be of the same type") - return bytes(item) - else: - raise TypeError(f"message must contain str or bytes-like objects, got {type(first).__name__}") - - if len(fragments) == 1: - payload = encode(first) - self.transport.send(msg_type, payload) + if isinstance(message, (str, bytes, bytearray, memoryview)): + await self._send_single_message(message, text) + return + if isinstance(message, AsyncIterable): + await self._send_async_fragments(message, text) return + if isinstance(message, Iterable): + await self._send_sync_fragments(message, text) + return + raise TypeError(f"message has unsupported type {type(message).__name__}") + + async def _send_single_message(self, message: Data, text: Optional[bool]) -> None: + if isinstance(message, str): + msg_type = picows.WSMsgType.TEXT + payload = message.encode("utf-8") + else: + msg_type = picows.WSMsgType.TEXT if text else picows.WSMsgType.BINARY + payload = bytes(message) - for index, fragment in enumerate(fragments): - if isinstance(first, str) and not isinstance(fragment, str): + await self._send_and_wait_write_ready(msg_type, payload) + + async def _send_sync_fragments(self, message: Iterable[Data], text: Optional[bool]) -> None: + iterator = iter(message) + try: + first = next(iterator) + except StopIteration: + raise TypeError("message iterable cannot be empty") from None + + msg_type, encode = self._get_fragment_codec(first, text) + + try: + second = next(iterator) + except StopIteration: + await self._send_and_wait_write_ready(msg_type, encode(first)) + return + + await self._send_and_wait_write_ready(msg_type, encode(first), fin=False) + previous = second + for fragment in iterator: + await self._send_and_wait_write_ready( + picows.WSMsgType.CONTINUATION, + encode(previous), + fin=False, + ) + previous = fragment + await self._send_and_wait_write_ready( + picows.WSMsgType.CONTINUATION, + encode(previous), + fin=True, + ) + + async def _send_async_fragments(self, message: AsyncIterable[Data], text: Optional[bool]) -> None: + iterator = message.__aiter__() + try: + first = await anext(iterator) + except StopAsyncIteration: + raise TypeError("message iterable cannot be empty") from None + + msg_type, encode = self._get_fragment_codec(first, text) + + try: + second = await anext(iterator) + except StopAsyncIteration: + await self._send_and_wait_write_ready(msg_type, encode(first)) + return + + await self._send_and_wait_write_ready(msg_type, encode(first), fin=False) + previous = second + async for fragment in iterator: + await self._send_and_wait_write_ready( + picows.WSMsgType.CONTINUATION, + encode(previous), + fin=False, + ) + previous = fragment + await self._send_and_wait_write_ready( + picows.WSMsgType.CONTINUATION, + encode(previous), + fin=True, + ) + + def _get_fragment_codec( + self, + first: Data, + text: Optional[bool], + ) -> tuple[picows.WSMsgType, Callable[[Data], bytes]]: + if isinstance(first, str): + def encode(item: Data) -> bytes: + if not isinstance(item, str): raise TypeError("all fragments must be of the same type") - if not isinstance(first, str) and not isinstance(fragment, (bytes, bytearray, memoryview)): + return item.encode("utf-8") + + return picows.WSMsgType.TEXT, encode + + if isinstance(first, (bytes, bytearray, memoryview)): + def encode(item: Data) -> bytes: + if not isinstance(item, (bytes, bytearray, memoryview)): raise TypeError("all fragments must be of the same type") - opcode = msg_type if index == 0 else picows.WSMsgType.CONTINUATION - self.transport.send(opcode, encode(fragment), fin=index == len(fragments) - 1) + return bytes(item) - async def _collect_fragments( - self, - message: Union[Data, Iterable[Data], AsyncIterator[Data]], - ) -> list[Data]: - if isinstance(message, (str, bytes, bytearray, memoryview)): - return [message] - if isinstance(message, AsyncIterable): - result: list[Data] = [] - async for item in message: - result.append(item) - return result - if isinstance(message, Iterable): - return list(message) - raise TypeError(f"message has unsupported type {type(message).__name__}") + return (picows.WSMsgType.TEXT if text else picows.WSMsgType.BINARY), encode + + raise TypeError(f"message must contain str or bytes-like objects, got {type(first).__name__}") async def close(self, code: CloseCodeT = 1000, reason: str = "") -> None: if self.state is State.CLOSED: @@ -743,13 +790,21 @@ async def connect_override(_: Any) -> socket.socket: return sock socket_factory = connect_override - holder: dict[str, Any] = {"pending": []} - - def listener_factory() -> _ConnectionListener: - return _ConnectionListener(holder) + def listener_factory() -> ClientConnection: + return self.connection_factory( + ping_interval=self.ping_interval, + ping_timeout=self.ping_timeout, + close_timeout=self.close_timeout, + max_queue=self.max_queue, + write_limit=self.write_limit, + max_message_size=max_message_size, + max_fragment_size=max_fragment_size, + logger=self.logger, + subprotocols=self.subprotocols, + ) try: - transport, _listener = await picows.ws_connect( + _transport, listener = await picows.ws_connect( listener_factory, self.uri, ssl_context=self._coerce_ssl_context(ssl_context), @@ -776,27 +831,7 @@ def listener_factory() -> _ConnectionListener: except picows.WSHandshakeError as exc: raise InvalidHandshake(str(exc)) from exc - connection = self.connection_factory( - transport, - ping_interval=self.ping_interval, - ping_timeout=self.ping_timeout, - close_timeout=self.close_timeout, - max_queue=self.max_queue, - write_limit=self.write_limit, - max_message_size=max_message_size, - max_fragment_size=max_fragment_size, - logger=self.logger, - subprotocols=self.subprotocols, - ) - holder["connection"] = connection - for kind, event in holder["pending"]: - if kind == "connected": - connection._on_connected(event) - elif kind == "frame": - connection._on_frame(event) - else: - connection._on_disconnected() - return connection + return cast(ClientConnection, listener) def _normalize_max_size( self, diff --git a/pyproject.toml b/pyproject.toml index dc3065f..4d121ca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,8 +36,8 @@ Repository = "https://github.com/tarasko/picows" Issues = "https://github.com/tarasko/picows/issues" Documentation = "https://picows.readthedocs.io/en/latest" -[tool.setuptools] -packages = ["picows"] +[tool.setuptools.packages.find] +include = ["picows*"] [tool.setuptools.dynamic] dependencies = {file = ["requirements.txt"]} diff --git a/tests/test_websockets_compat.py b/tests/test_websockets_compat.py index eca88c8..cad5ab3 100644 --- a/tests/test_websockets_compat.py +++ b/tests/test_websockets_compat.py @@ -2,6 +2,7 @@ import pytest +import picows from picows import websockets from tests.utils import WSServer @@ -67,3 +68,48 @@ def listener_factory(request): pass assert request_headers["value"] == "chat" + + +async def test_send_waits_for_resume_writing(): + class TrackingConnection(websockets.ClientConnection): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.pause_event = asyncio.Event() + + def pause_writing(self) -> None: + super().pause_writing() + self.pause_event.set() + + async with WSServer() as server: + async with websockets.connect( + server.url, + compression=None, + create_connection=TrackingConnection, + ) as ws: + third_requested = asyncio.Event() + allow_resume = asyncio.Event() + + async def fragments(): + ws.pause_writing() + yield b"first" + yield b"second" + third_requested.set() + yield b"third" + + async def resume_later(): + await allow_resume.wait() + ws.resume_writing() + + asyncio.create_task(resume_later()) + + send_task = asyncio.create_task(ws.send(fragments())) + await asyncio.wait_for(ws.pause_event.wait(), 1.0) + await asyncio.sleep(0) + assert not third_requested.is_set() + + allow_resume.set() + await asyncio.wait_for(send_task, 1.0) + assert third_requested.is_set() + + reply = await ws.recv() + assert reply == b"firstsecondthird" From d2cbcc0de17b8a3cf2ec593e29dc3f893dd28357 Mon Sep 17 00:00:00 2001 From: taras Date: Sat, 2 May 2026 01:37:49 +0200 Subject: [PATCH 03/57] Cleanups --- picows/picows.pyx | 2 + picows/websockets/asyncio/client.py | 82 +++++++++++------------------ 2 files changed, 34 insertions(+), 50 deletions(-) diff --git a/picows/picows.pyx b/picows/picows.pyx index d2d5793..44307a0 100644 --- a/picows/picows.pyx +++ b/picows/picows.pyx @@ -654,6 +654,8 @@ cdef class WSTransport: if self.close_handshake is None: self.close_handshake = WSCloseHandshake.__new__(WSCloseHandshake) + self.close_handshake.recv = None + self.close_handshake.sent = None self.close_handshake.recv_then_sent = False self.close_handshake.sent = WSCloseInfo.__new__(WSCloseInfo) diff --git a/picows/websockets/asyncio/client.py b/picows/websockets/asyncio/client.py index 763b892..08e7100 100644 --- a/picows/websockets/asyncio/client.py +++ b/picows/websockets/asyncio/client.py @@ -103,8 +103,8 @@ def process_exception(exc: Exception) -> Optional[Exception]: if isinstance(exc, (EOFError, OSError, asyncio.TimeoutError)): return None if isinstance(exc, InvalidStatus): - status = getattr(getattr(exc, "response", None), "status", None) - if status is not None and int(status) in {500, 502, 503, 504}: + status = exc.response.status + if int(status) in {500, 502, 503, 504}: return None return exc @@ -134,11 +134,9 @@ def __init__( self._closed_event = asyncio.Event() self._frames: asyncio.Queue[Optional[_BufferedFrame]] = asyncio.Queue() self._close_exc: Optional[ConnectionClosed] = None - self._disconnect_waiter: Optional[asyncio.Task[None]] = None self._loop = asyncio.get_running_loop() self._recv_lock = asyncio.Lock() self._send_lock = asyncio.Lock() - self._read_closed = False self._write_ready: Optional[asyncio.Future[None]] = None self._recv_streaming_in_progress = False self._recv_streaming_broken = False @@ -188,7 +186,6 @@ def on_ws_connected(self, transport: picows.WSTransport) -> None: self._subprotocol = _resolve_subprotocol(self._subprotocols, self.response) self._state = State.OPEN self._set_write_limits(self._write_limit) - self._disconnect_waiter = asyncio.create_task(self._watch_disconnect()) if self._ping_interval is not None and self._keepalive_task is None: self._keepalive_task = asyncio.create_task(self._keepalive_loop()) @@ -229,48 +226,39 @@ def _resume_reading_if_needed(self) -> None: def _set_close_exception(self) -> None: handshake = self.transport.close_handshake - rcvd = getattr(handshake, "recv", None) if handshake is not None else None - sent = getattr(handshake, "sent", None) if handshake is not None else None - rcvd_then_sent = getattr(handshake, "recv_then_sent", None) if handshake is not None else None - rcvd_code = _coerce_close_code(getattr(rcvd, "code", None)) - sent_code = _coerce_close_code(getattr(sent, "code", None)) + if handshake is None: + self._close_exc = ConnectionClosedError(None, None, None) + return + rcvd = handshake.recv + sent = handshake.sent + rcvd_then_sent = handshake.recv_then_sent + rcvd_code = _coerce_close_code(rcvd.code) if rcvd is not None else None + sent_code = _coerce_close_code(sent.code) if sent is not None else None ok = ( (rcvd_code in OK_CLOSE_CODES or rcvd_code is None) and (sent_code in OK_CLOSE_CODES or sent_code is None) - and handshake is not None ) exc_type = ConnectionClosedOK if ok else ConnectionClosedError self._close_exc = exc_type(rcvd, sent, rcvd_then_sent) - async def _watch_disconnect(self) -> None: - try: - await self.transport.wait_disconnected() - except Exception: - self._state = State.CLOSED - self._set_close_exception() - self._frames.put_nowait(None) - self._closed_event.set() - else: - self._state = State.CLOSED - self._set_close_exception() - self._frames.put_nowait(None) - self._closed_event.set() - finally: - if self._keepalive_task is not None: - self._keepalive_task.cancel() - if self._write_ready is not None: - if not self._write_ready.done(): - self._write_ready.set_exception( - self._close_exc or ConnectionClosedError(None, None, None) - ) - self._write_ready = None - for waiter, _ in self._pending_pings.values(): - if not waiter.done(): - waiter.set_exception(self._close_exc or ConnectionClosedError(None, None, None)) - self._pending_pings.clear() - def on_ws_disconnected(self, transport: picows.WSTransport) -> None: self._state = State.CLOSED + self._set_close_exception() + self._frames.put_nowait(None) + self._closed_event.set() + if self._keepalive_task is not None: + self._keepalive_task.cancel() + self._keepalive_task = None + if self._write_ready is not None: + if not self._write_ready.done(): + self._write_ready.set_exception( + self._close_exc or ConnectionClosedError(None, None, None) + ) + self._write_ready = None + for waiter, _ in self._pending_pings.values(): + if not waiter.done(): + waiter.set_exception(self._close_exc or ConnectionClosedError(None, None, None)) + self._pending_pings.clear() def _fail_message_too_big(self, message: str) -> None: if self._state is State.CLOSED: @@ -279,15 +267,12 @@ def _fail_message_too_big(self, message: str) -> None: self.transport.disconnect(False) def on_ws_frame(self, transport: picows.WSTransport, frame: picows.WSFrame) -> None: - self._on_frame_buffered(_BufferedFrame(frame.msg_type, frame.get_payload_as_bytes(), frame.fin)) - - def _on_frame_buffered(self, frame: _BufferedFrame) -> None: + payload = frame.get_payload_as_bytes() if frame.msg_type == picows.WSMsgType.PING: - self.transport.send_pong(frame.payload) + self.transport.send_pong(payload) return if frame.msg_type == picows.WSMsgType.PONG: - payload = frame.payload ping = self._pending_pings.pop(payload, None) if ping is not None: waiter, sent_at = ping @@ -299,21 +284,20 @@ def _on_frame_buffered(self, frame: _BufferedFrame) -> None: if frame.msg_type == picows.WSMsgType.CLOSE: close_code: CloseCodeT = picows.WSCloseCode.NO_INFO close_message = b"" - if len(frame.payload) >= 2: - close_code = int.from_bytes(frame.payload[:2], "big") - close_message = frame.payload[2:] + if len(payload) >= 2: + close_code = int.from_bytes(payload[:2], "big") + close_message = payload[2:] if not self.transport.is_close_frame_sent: self.transport.send_close(cast(picows.WSCloseCode, close_code), close_message) self._state = State.CLOSING self.transport.disconnect() return - payload = frame.payload if self._max_fragment_size is not None and len(payload) > self._max_fragment_size: self._fail_message_too_big("fragment too big") return - self._frames.put_nowait(frame) + self._frames.put_nowait(_BufferedFrame(frame.msg_type, payload, frame.fin)) self._pause_reading_if_needed() async def _next_frame(self) -> _BufferedFrame: @@ -606,7 +590,6 @@ async def __aenter__(self) -> "ClientConnection": return self async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None: - del exc_type, exc, tb await self.close() def __aiter__(self) -> AsyncIterator[Union[str, bytes]]: @@ -716,7 +699,6 @@ async def __aenter__(self) -> ClientConnection: return self._connection async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None: - del exc_type, exc, tb if self._connection is not None: await self._connection.close() self._connection = None From 115246ee7fac4b456025aa3db93b7b272d8005a9 Mon Sep 17 00:00:00 2001 From: taras Date: Sat, 2 May 2026 21:52:05 +0200 Subject: [PATCH 04/57] Cythonize client --- picows/websockets/asyncio/client.py | 281 +++++++++++++++++----------- setup.py | 5 + 2 files changed, 173 insertions(+), 113 deletions(-) diff --git a/picows/websockets/asyncio/client.py b/picows/websockets/asyncio/client.py index 08e7100..f434319 100644 --- a/picows/websockets/asyncio/client.py +++ b/picows/websockets/asyncio/client.py @@ -8,13 +8,22 @@ import uuid import warnings from collections.abc import AsyncIterable, Generator, Iterable -from dataclasses import dataclass from enum import IntEnum from ssl import SSLContext from time import monotonic -from typing import Any, AsyncIterator, Awaitable, Callable, Optional, Sequence, Union, cast +from typing import Any, AsyncIterator, Awaitable, Callable, Optional, Sequence, \ + Union, cast, Dict, Tuple from urllib.request import getproxies +import cython + +if cython.compiled: + from cython.cimports.picows.picows import WSListener, WSTransport, WSFrame, \ + WSMsgType, WSCloseCode +else: + from picows import WSListener, WSTransport, WSFrame, WSMsgType, WSCloseCode + + import picows from picows.types import WSHeadersLike from picows.url import parse_url @@ -37,7 +46,7 @@ Data = Union[str, bytes, bytearray, memoryview] HeadersLike = WSHeadersLike -CloseCodeT = Union[int, picows.WSCloseCode] +CloseCodeT = int LoggerLike = Union[str, logging.Logger, logging.LoggerAdapter[Any], None] @@ -51,25 +60,34 @@ class State(IntEnum): CLOSED = 3 -@dataclass(slots=True) +@cython.cclass class _BufferedFrame: - msg_type: picows.WSMsgType + msg_type: WSMsgType payload: bytes fin: bool + def __init__(self, msg_type: WSMsgType, payload: bytes, fin: bool): + self.msg_type = msg_type + self.payload = payload + self.fin = fin -def _coerce_close_code(code: Optional[picows.WSCloseCode]) -> Optional[int]: - return None if code is None else int(code.value) +@cython.cfunc +def _coerce_close_code(code: WSCloseCode) -> Optional[int]: + return None if code is None else int(code) + +@cython.cfunc def _coerce_close_reason(reason: Optional[str]) -> Optional[str]: return reason if reason is not None else None +@cython.cfunc def _header_items(headers: Any) -> list[tuple[str, str]]: return [] if headers is None else list(headers.items()) +@cython.cfunc def _resolve_subprotocol(subprotocols: Optional[Sequence[str]], response: Any) -> Optional[str]: if response is None: return None @@ -81,10 +99,12 @@ def _resolve_subprotocol(subprotocols: Optional[Sequence[str]], response: Any) - return cast(str, value) +@cython.cfunc def _default_user_agent() -> str: return f"Python/{sys.version_info.major}.{sys.version_info.minor} picows-websockets/0" +@cython.cfunc def _process_proxy(proxy: Union[str, bool, None], secure: bool) -> Optional[str]: if proxy is None: return None @@ -99,6 +119,12 @@ def _process_proxy(proxy: Union[str, bool, None], secure: bool) -> Optional[str] raise InvalidURI(str(proxy), "proxy must be None, True, or a proxy URL") +@cython.cfunc +def _normalize_size_limit(limit: Optional[int]) -> cython.Py_ssize_t: + return 0 if limit is None else limit + + +@cython.ccall def process_exception(exc: Exception) -> Optional[Exception]: if isinstance(exc, (EOFError, OSError, asyncio.TimeoutError)): return None @@ -109,7 +135,38 @@ def process_exception(exc: Exception) -> Optional[Exception]: return exc -class ClientConnection(picows.WSListener): +@cython.cclass +class ClientConnection(WSListener): + id: uuid.UUID + logger: Union[logging.Logger, logging.LoggerAdapter[Any]] + transport: WSTransport + request: picows.WSUpgradeRequest + response: picows.WSUpgradeResponse + _subprotocols: Optional[Sequence[str]] + _subprotocol: Optional[str] + _state: State + _closed_event: asyncio.Event + _frames: asyncio.Queue[Optional[_BufferedFrame]] + _close_exc: Optional[ConnectionClosed] + _loop: asyncio.AbstractEventLoop + _recv_lock: asyncio.Lock + _send_lock: asyncio.Lock + _write_ready: Optional[asyncio.Future[None]] + _recv_streaming_in_progress: cython.bint + _recv_streaming_broken: cython.bint + _pending_pings: Dict[bytes, Tuple[asyncio.Future[float], float]] + _ping_interval: Optional[float] + _ping_timeout: Optional[float] + _close_timeout: Optional[float] + _keepalive_task: Optional[asyncio.Task[None]] + _latency: cython.double + _max_message_size: cython.Py_ssize_t + _max_fragment_size: cython.Py_ssize_t + _max_queue_high: cython.Py_ssize_t + _max_queue_low: cython.Py_ssize_t + _write_limit: Union[int, tuple[int, Optional[int]]] + _paused_reading: cython.bint + def __init__( self, *, @@ -125,7 +182,7 @@ def __init__( ): self.id = uuid.uuid4() self.logger = self._resolve_logger(logger) - self.transport = cast(picows.WSTransport, None) + self.transport = cython.cast(WSTransport, None) self.request = cast(picows.WSUpgradeRequest, None) self.response = cast(picows.WSUpgradeResponse, None) self._subprotocols = subprotocols @@ -146,12 +203,13 @@ def __init__( self._close_timeout = close_timeout self._keepalive_task: Optional[asyncio.Task[None]] = None self._latency = 0.0 - self._max_message_size = max_message_size - self._max_fragment_size = max_fragment_size + self._max_message_size = _normalize_size_limit(max_message_size) + self._max_fragment_size = _normalize_size_limit(max_fragment_size) self._max_queue_high, self._max_queue_low = self._normalize_watermarks(max_queue) self._write_limit = write_limit self._paused_reading = False + @cython.cfunc def _resolve_logger(self, logger: LoggerLike) -> Union[logging.Logger, logging.LoggerAdapter[Any]]: if logger is None: return logging.getLogger("websockets.client") @@ -159,19 +217,21 @@ def _resolve_logger(self, logger: LoggerLike) -> Union[logging.Logger, logging.L return logging.getLogger(logger) return logger + @cython.cfunc def _normalize_watermarks( self, max_queue: Union[int, tuple[Optional[int], Optional[int]], None], - ) -> tuple[Optional[int], Optional[int]]: + ) -> tuple[cython.Py_ssize_t, cython.Py_ssize_t]: if max_queue is None: - return None, None + return 0, 0 if isinstance(max_queue, tuple): high, low = max_queue if high is None: - return None, None + return 0, 0 return high, high // 4 if low is None else low return max_queue, max_queue // 4 + @cython.cfunc def _set_write_limits(self, write_limit: Union[int, tuple[int, Optional[int]]]) -> None: if isinstance(write_limit, tuple): high, low = write_limit @@ -179,7 +239,8 @@ def _set_write_limits(self, write_limit: Union[int, tuple[int, Optional[int]]]) high, low = write_limit, None self.transport.underlying_transport.set_write_buffer_limits(high=high, low=low) - def on_ws_connected(self, transport: picows.WSTransport) -> None: + @cython.ccall + def on_ws_connected(self, transport: WSTransport) -> None: self.transport = transport self.request = transport.request self.response = transport.response @@ -189,41 +250,33 @@ def on_ws_connected(self, transport: picows.WSTransport) -> None: if self._ping_interval is not None and self._keepalive_task is None: self._keepalive_task = asyncio.create_task(self._keepalive_loop()) + @cython.ccall def pause_writing(self) -> None: if self._write_ready is None: self._write_ready = self._loop.create_future() + @cython.ccall def resume_writing(self) -> None: if self._write_ready is not None: if not self._write_ready.done(): self._write_ready.set_result(None) self._write_ready = None - async def _send_and_wait_write_ready( - self, - msg_type: picows.WSMsgType, - payload: bytes, - *, - fin: bool = True, - ) -> None: - self.transport.send(msg_type, payload, fin=fin) - if self._write_ready is not None: - await self._write_ready - + @cython.cfunc def _pause_reading_if_needed(self) -> None: - if self._max_queue_high is None: - return - if not self._paused_reading and self._frames.qsize() >= self._max_queue_high: + if self._max_queue_high > 0 and not self._paused_reading and self._frames.qsize() >= self._max_queue_high: self.transport.underlying_transport.pause_reading() self._paused_reading = True + @cython.cfunc def _resume_reading_if_needed(self) -> None: if not self._paused_reading: return - if self._max_queue_low is None or self._frames.qsize() <= self._max_queue_low: + if self._max_queue_low == 0 or self._frames.qsize() <= self._max_queue_low: self.transport.underlying_transport.resume_reading() self._paused_reading = False + @cython.cfunc def _set_close_exception(self) -> None: handshake = self.transport.close_handshake if handshake is None: @@ -241,7 +294,8 @@ def _set_close_exception(self) -> None: exc_type = ConnectionClosedOK if ok else ConnectionClosedError self._close_exc = exc_type(rcvd, sent, rcvd_then_sent) - def on_ws_disconnected(self, transport: picows.WSTransport) -> None: + @cython.ccall + def on_ws_disconnected(self, transport: WSTransport) -> None: self._state = State.CLOSED self._set_close_exception() self._frames.put_nowait(None) @@ -260,19 +314,19 @@ def on_ws_disconnected(self, transport: picows.WSTransport) -> None: waiter.set_exception(self._close_exc or ConnectionClosedError(None, None, None)) self._pending_pings.clear() + @cython.cfunc def _fail_message_too_big(self, message: str) -> None: - if self._state is State.CLOSED: - return - self.transport.send_close(picows.WSCloseCode.MESSAGE_TOO_BIG, message.encode("utf-8")) + self.transport.send_close(WSCloseCode.MESSAGE_TOO_BIG, message.encode("utf-8")) self.transport.disconnect(False) - def on_ws_frame(self, transport: picows.WSTransport, frame: picows.WSFrame) -> None: + @cython.ccall + def on_ws_frame(self, transport: WSTransport, frame: WSFrame) -> None: payload = frame.get_payload_as_bytes() - if frame.msg_type == picows.WSMsgType.PING: + if frame.msg_type == WSMsgType.PING: self.transport.send_pong(payload) return - if frame.msg_type == picows.WSMsgType.PONG: + if frame.msg_type == WSMsgType.PONG: ping = self._pending_pings.pop(payload, None) if ping is not None: waiter, sent_at = ping @@ -281,19 +335,15 @@ def on_ws_frame(self, transport: picows.WSTransport, frame: picows.WSFrame) -> N waiter.set_result(self._latency) return - if frame.msg_type == picows.WSMsgType.CLOSE: - close_code: CloseCodeT = picows.WSCloseCode.NO_INFO - close_message = b"" - if len(payload) >= 2: - close_code = int.from_bytes(payload[:2], "big") - close_message = payload[2:] - if not self.transport.is_close_frame_sent: - self.transport.send_close(cast(picows.WSCloseCode, close_code), close_message) - self._state = State.CLOSING + if frame.msg_type == WSMsgType.CLOSE: + close_code = frame.get_close_code() + close_message = frame.get_close_message() + self.transport.send_close(close_code, close_message) self.transport.disconnect() + self._state = State.CLOSING return - if self._max_fragment_size is not None and len(payload) > self._max_fragment_size: + if self._max_fragment_size > 0 and len(payload) > self._max_fragment_size: self._fail_message_too_big("fragment too big") return @@ -301,17 +351,19 @@ def on_ws_frame(self, transport: picows.WSTransport, frame: picows.WSFrame) -> N self._pause_reading_if_needed() async def _next_frame(self) -> _BufferedFrame: - frame = await self._frames.get() + frame: _BufferedFrame = cython.cast(_BufferedFrame, await self._frames.get()) self._resume_reading_if_needed() if frame is None: raise self._connection_closed() return frame + @cython.cfunc def _connection_closed(self) -> ConnectionClosed: if self._close_exc is None: self._set_close_exception() return self._close_exc or ConnectionClosedError(None, None, None) + @cython.cfunc def _ensure_recv_available(self) -> None: if self._recv_streaming_broken: raise ConcurrencyError("recv_streaming() wasn't fully consumed") @@ -323,33 +375,36 @@ async def recv(self, decode: Optional[bool] = None) -> Union[str, bytes]: if self._recv_lock.locked(): raise ConcurrencyError("cannot call recv() concurrently") async with self._recv_lock: - first = await self._next_frame() - if first.msg_type not in (picows.WSMsgType.TEXT, picows.WSMsgType.BINARY): + first: _BufferedFrame = await self._next_frame() + if first.msg_type not in (WSMsgType.TEXT, WSMsgType.BINARY): raise ProtocolError(f"unexpected opcode while receiving message: {first.msg_type}") msg_type = first.msg_type + if first.fin: + return self._decode_payload(first.payload, msg_type, decode) chunks = [first.payload] - total = len(first.payload) + total: cython.Py_ssize_t = len(first.payload) while not first.fin: first = await self._next_frame() - if first.msg_type != picows.WSMsgType.CONTINUATION: + if first.msg_type != WSMsgType.CONTINUATION: raise ProtocolError("expected continuation frame") chunks.append(first.payload) total += len(first.payload) - if self._max_message_size is not None and total > self._max_message_size: + if self._max_message_size > 0 and total > self._max_message_size: self._fail_message_too_big("message too big") raise PayloadTooBig("message too big") payload = b"".join(chunks) return self._decode_payload(payload, msg_type, decode) + @cython.cfunc def _decode_payload( self, payload: bytes, - msg_type: picows.WSMsgType, + msg_type: WSMsgType, decode: Optional[bool], ) -> Union[str, bytes]: - if msg_type == picows.WSMsgType.TEXT: + if msg_type == WSMsgType.TEXT: if decode is False: return payload return payload.decode("utf-8") @@ -369,20 +424,20 @@ async def iterator() -> AsyncIterator[Union[str, bytes]]: nonlocal started, finished try: async with self._recv_lock: - first = await self._next_frame() - if first.msg_type not in (picows.WSMsgType.TEXT, picows.WSMsgType.BINARY): + first: _BufferedFrame = await self._next_frame() + if first.msg_type not in (WSMsgType.TEXT, WSMsgType.BINARY): raise ProtocolError(f"unexpected opcode while receiving message: {first.msg_type}") msg_type = first.msg_type started = True yield self._decode_fragment(first.payload, msg_type, decode) - total = len(first.payload) + total: cython.Py_ssize_t = len(first.payload) frame = first while not frame.fin: - frame = await self._next_frame() - if frame.msg_type != picows.WSMsgType.CONTINUATION: + frame: _BufferedFrame = await self._next_frame() + if frame.msg_type != WSMsgType.CONTINUATION: raise ProtocolError("expected continuation frame") total += len(frame.payload) - if self._max_message_size is not None and total > self._max_message_size: + if self._max_message_size > 0 and total > self._max_message_size: self._fail_message_too_big("message too big") raise PayloadTooBig("message too big") yield self._decode_fragment(frame.payload, msg_type, decode) @@ -396,13 +451,14 @@ async def iterator() -> AsyncIterator[Union[str, bytes]]: return iterator() + @cython.cfunc def _decode_fragment( self, payload: bytes, - msg_type: picows.WSMsgType, + msg_type: WSMsgType, decode: Optional[bool], ) -> Union[str, bytes]: - if msg_type == picows.WSMsgType.TEXT: + if msg_type == WSMsgType.TEXT: if decode is False: return payload return payload.decode("utf-8") @@ -420,7 +476,15 @@ async def send( async with self._send_lock: if isinstance(message, (str, bytes, bytearray, memoryview)): - await self._send_single_message(message, text) + if isinstance(message, str): + self.transport.send(WSMsgType.TEXT, cython.cast(str, message).encode("utf-8")) + else: + self.transport.send( + WSMsgType.TEXT if text else WSMsgType.BINARY, + message, + ) + if self._write_ready is not None: + await self._write_ready return if isinstance(message, AsyncIterable): await self._send_async_fragments(message, text) @@ -430,16 +494,6 @@ async def send( return raise TypeError(f"message has unsupported type {type(message).__name__}") - async def _send_single_message(self, message: Data, text: Optional[bool]) -> None: - if isinstance(message, str): - msg_type = picows.WSMsgType.TEXT - payload = message.encode("utf-8") - else: - msg_type = picows.WSMsgType.TEXT if text else picows.WSMsgType.BINARY - payload = bytes(message) - - await self._send_and_wait_write_ready(msg_type, payload) - async def _send_sync_fragments(self, message: Iterable[Data], text: Optional[bool]) -> None: iterator = iter(message) try: @@ -452,23 +506,23 @@ async def _send_sync_fragments(self, message: Iterable[Data], text: Optional[boo try: second = next(iterator) except StopIteration: - await self._send_and_wait_write_ready(msg_type, encode(first)) + self.transport.send(msg_type, encode(first)) + if self._write_ready is not None: + await self._write_ready return - await self._send_and_wait_write_ready(msg_type, encode(first), fin=False) + self.transport.send(msg_type, encode(first), fin=False) + if self._write_ready is not None: + await self._write_ready previous = second for fragment in iterator: - await self._send_and_wait_write_ready( - picows.WSMsgType.CONTINUATION, - encode(previous), - fin=False, - ) + self.transport.send(WSMsgType.CONTINUATION, encode(previous), fin=False) + if self._write_ready is not None: + await self._write_ready previous = fragment - await self._send_and_wait_write_ready( - picows.WSMsgType.CONTINUATION, - encode(previous), - fin=True, - ) + self.transport.send(WSMsgType.CONTINUATION, encode(previous), fin=True) + if self._write_ready is not None: + await self._write_ready async def _send_async_fragments(self, message: AsyncIterable[Data], text: Optional[bool]) -> None: iterator = message.__aiter__() @@ -482,54 +536,54 @@ async def _send_async_fragments(self, message: AsyncIterable[Data], text: Option try: second = await anext(iterator) except StopAsyncIteration: - await self._send_and_wait_write_ready(msg_type, encode(first)) + self.transport.send(msg_type, encode(first)) + if self._write_ready is not None: + await self._write_ready return - await self._send_and_wait_write_ready(msg_type, encode(first), fin=False) + self.transport.send(msg_type, encode(first), fin=False) + if self._write_ready is not None: + await self._write_ready previous = second async for fragment in iterator: - await self._send_and_wait_write_ready( - picows.WSMsgType.CONTINUATION, - encode(previous), - fin=False, - ) + self.transport.send(WSMsgType.CONTINUATION, encode(previous), fin=False) + if self._write_ready is not None: + await self._write_ready previous = fragment - await self._send_and_wait_write_ready( - picows.WSMsgType.CONTINUATION, - encode(previous), - fin=True, - ) + self.transport.send(WSMsgType.CONTINUATION, encode(previous), fin=True) + if self._write_ready is not None: + await self._write_ready + @cython.cfunc def _get_fragment_codec( self, first: Data, text: Optional[bool], - ) -> tuple[picows.WSMsgType, Callable[[Data], bytes]]: + ) -> tuple[WSMsgType, Callable[[Data], bytes]]: if isinstance(first, str): def encode(item: Data) -> bytes: if not isinstance(item, str): raise TypeError("all fragments must be of the same type") return item.encode("utf-8") - return picows.WSMsgType.TEXT, encode + return WSMsgType.TEXT, encode if isinstance(first, (bytes, bytearray, memoryview)): - def encode(item: Data) -> bytes: + def encode(item: Data) -> Data: if not isinstance(item, (bytes, bytearray, memoryview)): raise TypeError("all fragments must be of the same type") - return bytes(item) + return item - return (picows.WSMsgType.TEXT if text else picows.WSMsgType.BINARY), encode + return (WSMsgType.TEXT if text else WSMsgType.BINARY), encode raise TypeError(f"message must contain str or bytes-like objects, got {type(first).__name__}") - async def close(self, code: CloseCodeT = 1000, reason: str = "") -> None: + async def close(self, code: int = 1000, reason: str = "") -> None: if self.state is State.CLOSED: return if self.state is State.OPEN: self._state = State.CLOSING - close_code = code if isinstance(code, picows.WSCloseCode) else picows.WSCloseCode(code) - self.transport.send_close(close_code, reason.encode("utf-8")) + self.transport.send_close(code, reason.encode("utf-8")) try: if self._close_timeout is None: await self.wait_closed() @@ -586,7 +640,7 @@ async def _keepalive_loop(self) -> None: if self.state is not State.CLOSED: await self.close(code=1011, reason="keepalive ping timeout") - async def __aenter__(self) -> "ClientConnection": + async def __aenter__(self) -> ClientConnection: return self async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None: @@ -703,7 +757,7 @@ async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None: await self._connection.close() self._connection = None - def __aiter__(self) -> "_Connect": + def __aiter__(self) -> _Connect: return self async def __anext__(self) -> ClientConnection: @@ -793,7 +847,7 @@ def listener_factory() -> ClientConnection: websocket_handshake_timeout=self.open_timeout, enable_auto_ping=False, enable_auto_pong=False, - max_frame_size=max_fragment_size if max_fragment_size is not None else 2 ** 31 - 1, + max_frame_size=max_fragment_size if max_fragment_size > 0 else 2 ** 31 - 1, extra_headers=extra_headers, proxy=proxy, socket_factory=socket_factory, @@ -813,16 +867,17 @@ def listener_factory() -> ClientConnection: except picows.WSHandshakeError as exc: raise InvalidHandshake(str(exc)) from exc - return cast(ClientConnection, listener) + return cython.cast(ClientConnection, listener) def _normalize_max_size( self, max_size: Union[int, tuple[Optional[int], Optional[int]], None], - ) -> tuple[Optional[int], Optional[int]]: + ) -> tuple[cython.Py_ssize_t, cython.Py_ssize_t]: if max_size is None: - return None, None + return 0, 0 if isinstance(max_size, tuple): - return max_size + max_message_size, max_fragment_size = max_size + return _normalize_size_limit(max_message_size), _normalize_size_limit(max_fragment_size) return max_size, max_size def _build_headers(self) -> list[tuple[str, str]]: diff --git a/setup.py b/setup.py index f5996c4..8b7ec3c 100644 --- a/setup.py +++ b/setup.py @@ -133,6 +133,11 @@ def build_extension(self, ext: Extension): depends=["picows/compat.h"], extra_compile_args=extra_compile_args, extra_link_args=extra_link_args), + Extension("picows.websockets.asyncio.client", ["picows/websockets/asyncio/client.py"], + libraries=libs, define_macros=macros, + depends=["picows/compat.h"], + extra_compile_args=extra_compile_args, + extra_link_args=extra_link_args), ] if with_examples: From 66684d0ddd9df18294ebbe7302b6a8951a269d1a Mon Sep 17 00:00:00 2001 From: taras Date: Sat, 2 May 2026 22:25:26 +0200 Subject: [PATCH 05/57] Optimize --- picows/websockets/asyncio/client.py | 98 +++++++++++++++++++++++++---- 1 file changed, 86 insertions(+), 12 deletions(-) diff --git a/picows/websockets/asyncio/client.py b/picows/websockets/asyncio/client.py index f434319..572ebfa 100644 --- a/picows/websockets/asyncio/client.py +++ b/picows/websockets/asyncio/client.py @@ -7,6 +7,7 @@ import socket import uuid import warnings +from collections import deque from collections.abc import AsyncIterable, Generator, Iterable from enum import IntEnum from ssl import SSLContext @@ -72,22 +73,70 @@ def __init__(self, msg_type: WSMsgType, payload: bytes, fin: bool): self.fin = fin +@cython.cclass +class _AsyncLock: + _loop: asyncio.AbstractEventLoop + _locked: cython.bint + _waiters: Any + + def __init__(self, loop: asyncio.AbstractEventLoop): + self._loop = loop + self._locked = False + self._waiters = deque() + + @cython.cfunc + @cython.inline + def locked(self) -> cython.bint: + return self._locked + + @cython.cfunc + @cython.inline + def acquire(self) -> None: + self._locked = True + + async def wait_and_acquire(self) -> None: + waiter = self._loop.create_future() + self._waiters.append(waiter) + try: + await waiter + except Exception: + try: + self._waiters.remove(waiter) + except ValueError: + pass + raise + + @cython.cfunc + @cython.inline + def release(self) -> None: + while self._waiters: + waiter = self._waiters.popleft() + if not waiter.done(): + waiter.set_result(None) + return + self._locked = False + + @cython.cfunc +@cython.inline def _coerce_close_code(code: WSCloseCode) -> Optional[int]: return None if code is None else int(code) @cython.cfunc +@cython.inline def _coerce_close_reason(reason: Optional[str]) -> Optional[str]: return reason if reason is not None else None @cython.cfunc +@cython.inline def _header_items(headers: Any) -> list[tuple[str, str]]: return [] if headers is None else list(headers.items()) @cython.cfunc +@cython.inline def _resolve_subprotocol(subprotocols: Optional[Sequence[str]], response: Any) -> Optional[str]: if response is None: return None @@ -100,11 +149,13 @@ def _resolve_subprotocol(subprotocols: Optional[Sequence[str]], response: Any) - @cython.cfunc +@cython.inline def _default_user_agent() -> str: return f"Python/{sys.version_info.major}.{sys.version_info.minor} picows-websockets/0" @cython.cfunc +@cython.inline def _process_proxy(proxy: Union[str, bool, None], secure: bool) -> Optional[str]: if proxy is None: return None @@ -120,6 +171,7 @@ def _process_proxy(proxy: Union[str, bool, None], secure: bool) -> Optional[str] @cython.cfunc +@cython.inline def _normalize_size_limit(limit: Optional[int]) -> cython.Py_ssize_t: return 0 if limit is None else limit @@ -149,8 +201,8 @@ class ClientConnection(WSListener): _frames: asyncio.Queue[Optional[_BufferedFrame]] _close_exc: Optional[ConnectionClosed] _loop: asyncio.AbstractEventLoop - _recv_lock: asyncio.Lock - _send_lock: asyncio.Lock + _recv_lock: _AsyncLock + _send_lock: _AsyncLock _write_ready: Optional[asyncio.Future[None]] _recv_streaming_in_progress: cython.bint _recv_streaming_broken: cython.bint @@ -192,8 +244,8 @@ def __init__( self._frames: asyncio.Queue[Optional[_BufferedFrame]] = asyncio.Queue() self._close_exc: Optional[ConnectionClosed] = None self._loop = asyncio.get_running_loop() - self._recv_lock = asyncio.Lock() - self._send_lock = asyncio.Lock() + self._recv_lock = _AsyncLock(self._loop) + self._send_lock = _AsyncLock(self._loop) self._write_ready: Optional[asyncio.Future[None]] = None self._recv_streaming_in_progress = False self._recv_streaming_broken = False @@ -210,6 +262,7 @@ def __init__( self._paused_reading = False @cython.cfunc + @cython.inline def _resolve_logger(self, logger: LoggerLike) -> Union[logging.Logger, logging.LoggerAdapter[Any]]: if logger is None: return logging.getLogger("websockets.client") @@ -218,6 +271,7 @@ def _resolve_logger(self, logger: LoggerLike) -> Union[logging.Logger, logging.L return logger @cython.cfunc + @cython.inline def _normalize_watermarks( self, max_queue: Union[int, tuple[Optional[int], Optional[int]], None], @@ -232,6 +286,7 @@ def _normalize_watermarks( return max_queue, max_queue // 4 @cython.cfunc + @cython.inline def _set_write_limits(self, write_limit: Union[int, tuple[int, Optional[int]]]) -> None: if isinstance(write_limit, tuple): high, low = write_limit @@ -263,12 +318,14 @@ def resume_writing(self) -> None: self._write_ready = None @cython.cfunc + @cython.inline def _pause_reading_if_needed(self) -> None: if self._max_queue_high > 0 and not self._paused_reading and self._frames.qsize() >= self._max_queue_high: self.transport.underlying_transport.pause_reading() self._paused_reading = True @cython.cfunc + @cython.inline def _resume_reading_if_needed(self) -> None: if not self._paused_reading: return @@ -277,6 +334,7 @@ def _resume_reading_if_needed(self) -> None: self._paused_reading = False @cython.cfunc + @cython.inline def _set_close_exception(self) -> None: handshake = self.transport.close_handshake if handshake is None: @@ -315,6 +373,7 @@ def on_ws_disconnected(self, transport: WSTransport) -> None: self._pending_pings.clear() @cython.cfunc + @cython.inline def _fail_message_too_big(self, message: str) -> None: self.transport.send_close(WSCloseCode.MESSAGE_TOO_BIG, message.encode("utf-8")) self.transport.disconnect(False) @@ -358,12 +417,14 @@ async def _next_frame(self) -> _BufferedFrame: return frame @cython.cfunc + @cython.inline def _connection_closed(self) -> ConnectionClosed: if self._close_exc is None: self._set_close_exception() return self._close_exc or ConnectionClosedError(None, None, None) @cython.cfunc + @cython.inline def _ensure_recv_available(self) -> None: if self._recv_streaming_broken: raise ConcurrencyError("recv_streaming() wasn't fully consumed") @@ -374,7 +435,8 @@ async def recv(self, decode: Optional[bool] = None) -> Union[str, bytes]: self._ensure_recv_available() if self._recv_lock.locked(): raise ConcurrencyError("cannot call recv() concurrently") - async with self._recv_lock: + self._recv_lock.acquire() + try: first: _BufferedFrame = await self._next_frame() if first.msg_type not in (WSMsgType.TEXT, WSMsgType.BINARY): raise ProtocolError(f"unexpected opcode while receiving message: {first.msg_type}") @@ -396,8 +458,11 @@ async def recv(self, decode: Optional[bool] = None) -> Union[str, bytes]: payload = b"".join(chunks) return self._decode_payload(payload, msg_type, decode) + finally: + self._recv_lock.release() @cython.cfunc + @cython.inline def _decode_payload( self, payload: bytes, @@ -423,7 +488,8 @@ def recv_streaming(self, decode: Optional[bool] = None) -> AsyncIterator[Union[s async def iterator() -> AsyncIterator[Union[str, bytes]]: nonlocal started, finished try: - async with self._recv_lock: + self._recv_lock.acquire() + try: first: _BufferedFrame = await self._next_frame() if first.msg_type not in (WSMsgType.TEXT, WSMsgType.BINARY): raise ProtocolError(f"unexpected opcode while receiving message: {first.msg_type}") @@ -441,7 +507,9 @@ async def iterator() -> AsyncIterator[Union[str, bytes]]: self._fail_message_too_big("message too big") raise PayloadTooBig("message too big") yield self._decode_fragment(frame.payload, msg_type, decode) - finished = True + finished = True + finally: + self._recv_lock.release() finally: if started and not finished: self._recv_streaming_broken = True @@ -452,6 +520,7 @@ async def iterator() -> AsyncIterator[Union[str, bytes]]: return iterator() @cython.cfunc + @cython.inline def _decode_fragment( self, payload: bytes, @@ -474,15 +543,17 @@ async def send( if self.state is State.CLOSED: raise self._connection_closed() - async with self._send_lock: + if self._send_lock.locked(): + await self._send_lock.wait_and_acquire() + else: + self._send_lock.acquire() + + try: if isinstance(message, (str, bytes, bytearray, memoryview)): if isinstance(message, str): self.transport.send(WSMsgType.TEXT, cython.cast(str, message).encode("utf-8")) else: - self.transport.send( - WSMsgType.TEXT if text else WSMsgType.BINARY, - message, - ) + self.transport.send(WSMsgType.TEXT if text else WSMsgType.BINARY, message) if self._write_ready is not None: await self._write_ready return @@ -493,6 +564,8 @@ async def send( await self._send_sync_fragments(message, text) return raise TypeError(f"message has unsupported type {type(message).__name__}") + finally: + self._send_lock.release() async def _send_sync_fragments(self, message: Iterable[Data], text: Optional[bool]) -> None: iterator = iter(message) @@ -555,6 +628,7 @@ async def _send_async_fragments(self, message: AsyncIterable[Data], text: Option await self._write_ready @cython.cfunc + @cython.inline def _get_fragment_codec( self, first: Data, From 240e0a68862692362d39d1bb7938ce9f61349ce2 Mon Sep 17 00:00:00 2001 From: taras Date: Sat, 2 May 2026 23:42:42 +0200 Subject: [PATCH 06/57] Simplify --- AGENTS.md | 4 ++ picows/websockets/asyncio/client.py | 106 +++++++++++----------------- 2 files changed, 47 insertions(+), 63 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index 355fc8b..82a02fe 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -19,6 +19,10 @@ examples - Various examples for users on how to use picows + perf_test that coul Every extra "just in case" branch teaches the reader that the state is part of normal behavior. Add such checks only for real risks like external misuse, concurrency races, partial failure, or invariants that are genuinely hard to guarantee. If the only reason for the check is uncertainty in the design, fix the design first. +- When simplifying code, finish the simplification across all equivalent branches, not only at the first local site. + If the same conversion, check, or tiny code pattern appears in multiple sibling paths after a refactor, stop and normalize it before considering the work done. + Do not remove one layer of abstraction only to inline the same logic redundantly in several places. + After a refactor, scan for duplicated branch bodies and duplicated type-specific handling introduced by the change. ## Testing instructions - Run lint after updating code with: diff --git a/picows/websockets/asyncio/client.py b/picows/websockets/asyncio/client.py index 572ebfa..1fb7cfc 100644 --- a/picows/websockets/asyncio/client.py +++ b/picows/websockets/asyncio/client.py @@ -442,7 +442,7 @@ async def recv(self, decode: Optional[bool] = None) -> Union[str, bytes]: raise ProtocolError(f"unexpected opcode while receiving message: {first.msg_type}") msg_type = first.msg_type if first.fin: - return self._decode_payload(first.payload, msg_type, decode) + return self._decode_data(first.payload, msg_type, decode) chunks = [first.payload] total: cython.Py_ssize_t = len(first.payload) @@ -457,13 +457,13 @@ async def recv(self, decode: Optional[bool] = None) -> Union[str, bytes]: raise PayloadTooBig("message too big") payload = b"".join(chunks) - return self._decode_payload(payload, msg_type, decode) + return self._decode_data(payload, msg_type, decode) finally: self._recv_lock.release() @cython.cfunc @cython.inline - def _decode_payload( + def _decode_data( self, payload: bytes, msg_type: WSMsgType, @@ -495,7 +495,7 @@ async def iterator() -> AsyncIterator[Union[str, bytes]]: raise ProtocolError(f"unexpected opcode while receiving message: {first.msg_type}") msg_type = first.msg_type started = True - yield self._decode_fragment(first.payload, msg_type, decode) + yield self._decode_data(first.payload, msg_type, decode) total: cython.Py_ssize_t = len(first.payload) frame = first while not frame.fin: @@ -506,7 +506,7 @@ async def iterator() -> AsyncIterator[Union[str, bytes]]: if self._max_message_size > 0 and total > self._max_message_size: self._fail_message_too_big("message too big") raise PayloadTooBig("message too big") - yield self._decode_fragment(frame.payload, msg_type, decode) + yield self._decode_data(frame.payload, msg_type, decode) finished = True finally: self._recv_lock.release() @@ -519,28 +519,12 @@ async def iterator() -> AsyncIterator[Union[str, bytes]]: return iterator() - @cython.cfunc - @cython.inline - def _decode_fragment( - self, - payload: bytes, - msg_type: WSMsgType, - decode: Optional[bool], - ) -> Union[str, bytes]: - if msg_type == WSMsgType.TEXT: - if decode is False: - return payload - return payload.decode("utf-8") - if decode is True: - return payload.decode("utf-8") - return payload - async def send( self, message: Union[Data, Iterable[Data], AsyncIterator[Data]], text: Optional[bool] = None, ) -> None: - if self.state is State.CLOSED: + if self._state is State.CLOSED: raise self._connection_closed() if self._send_lock.locked(): @@ -551,9 +535,10 @@ async def send( try: if isinstance(message, (str, bytes, bytearray, memoryview)): if isinstance(message, str): - self.transport.send(WSMsgType.TEXT, cython.cast(str, message).encode("utf-8")) + msg_type = WSMsgType.BINARY if text is False else WSMsgType.TEXT else: - self.transport.send(WSMsgType.TEXT if text else WSMsgType.BINARY, message) + msg_type = WSMsgType.TEXT if text else WSMsgType.BINARY + self.transport.send(msg_type, message) if self._write_ready is not None: await self._write_ready return @@ -574,26 +559,36 @@ async def _send_sync_fragments(self, message: Iterable[Data], text: Optional[boo except StopIteration: raise TypeError("message iterable cannot be empty") from None - msg_type, encode = self._get_fragment_codec(first, text) + expected_type = type(first) + if expected_type is str: + msg_type = WSMsgType.BINARY if text is False else WSMsgType.TEXT + elif expected_type in (bytes, bytearray, memoryview): + msg_type = WSMsgType.TEXT if text else WSMsgType.BINARY + else: + raise TypeError(f"message must contain str or bytes-like objects, got {type(first).__name__}") try: second = next(iterator) except StopIteration: - self.transport.send(msg_type, encode(first)) + self.transport.send(msg_type, first) if self._write_ready is not None: await self._write_ready return - self.transport.send(msg_type, encode(first), fin=False) + self.transport.send(msg_type, first, fin=False) if self._write_ready is not None: await self._write_ready previous = second for fragment in iterator: - self.transport.send(WSMsgType.CONTINUATION, encode(previous), fin=False) + if type(previous) is not expected_type: + raise TypeError("all fragments must be of the same type") + self.transport.send(WSMsgType.CONTINUATION, previous, fin=False) if self._write_ready is not None: await self._write_ready previous = fragment - self.transport.send(WSMsgType.CONTINUATION, encode(previous), fin=True) + if type(previous) is not expected_type: + raise TypeError("all fragments must be of the same type") + self.transport.send(WSMsgType.CONTINUATION, previous, fin=True) if self._write_ready is not None: await self._write_ready @@ -604,58 +599,43 @@ async def _send_async_fragments(self, message: AsyncIterable[Data], text: Option except StopAsyncIteration: raise TypeError("message iterable cannot be empty") from None - msg_type, encode = self._get_fragment_codec(first, text) + expected_type = type(first) + if expected_type is str: + msg_type = WSMsgType.BINARY if text is False else WSMsgType.TEXT + elif expected_type in (bytes, bytearray, memoryview): + msg_type = WSMsgType.TEXT if text else WSMsgType.BINARY + else: + raise TypeError(f"message must contain str or bytes-like objects, got {type(first).__name__}") try: second = await anext(iterator) except StopAsyncIteration: - self.transport.send(msg_type, encode(first)) + self.transport.send(msg_type, first) if self._write_ready is not None: await self._write_ready return - self.transport.send(msg_type, encode(first), fin=False) + self.transport.send(msg_type, first, fin=False) if self._write_ready is not None: await self._write_ready previous = second async for fragment in iterator: - self.transport.send(WSMsgType.CONTINUATION, encode(previous), fin=False) + if type(previous) is not expected_type: + raise TypeError("all fragments must be of the same type") + self.transport.send(WSMsgType.CONTINUATION, previous, fin=False) if self._write_ready is not None: await self._write_ready previous = fragment - self.transport.send(WSMsgType.CONTINUATION, encode(previous), fin=True) + if type(previous) is not expected_type: + raise TypeError("all fragments must be of the same type") + self.transport.send(WSMsgType.CONTINUATION, previous, fin=True) if self._write_ready is not None: await self._write_ready - @cython.cfunc - @cython.inline - def _get_fragment_codec( - self, - first: Data, - text: Optional[bool], - ) -> tuple[WSMsgType, Callable[[Data], bytes]]: - if isinstance(first, str): - def encode(item: Data) -> bytes: - if not isinstance(item, str): - raise TypeError("all fragments must be of the same type") - return item.encode("utf-8") - - return WSMsgType.TEXT, encode - - if isinstance(first, (bytes, bytearray, memoryview)): - def encode(item: Data) -> Data: - if not isinstance(item, (bytes, bytearray, memoryview)): - raise TypeError("all fragments must be of the same type") - return item - - return (WSMsgType.TEXT if text else WSMsgType.BINARY), encode - - raise TypeError(f"message must contain str or bytes-like objects, got {type(first).__name__}") - async def close(self, code: int = 1000, reason: str = "") -> None: - if self.state is State.CLOSED: + if self._state is State.CLOSED: return - if self.state is State.OPEN: + if self._state is State.OPEN: self._state = State.CLOSING self.transport.send_close(code, reason.encode("utf-8")) try: @@ -671,7 +651,7 @@ async def wait_closed(self) -> None: await self._closed_event.wait() async def ping(self, data: Optional[Union[str, bytes]] = None) -> Awaitable[float]: - if self.state is State.CLOSED: + if self._state is State.CLOSED: raise self._connection_closed() if data is None: while True: @@ -694,7 +674,7 @@ async def ping(self, data: Optional[Union[str, bytes]] = None) -> Awaitable[floa return waiter async def pong(self, data: Union[str, bytes] = b"") -> None: - if self.state is State.CLOSED: + if self._state is State.CLOSED: raise self._connection_closed() payload = data.encode("utf-8") if isinstance(data, str) else data self.transport.send_pong(payload) From ed2f717a5d94c3e370e423ccc927babf93c6d09f Mon Sep 17 00:00:00 2001 From: taras Date: Sun, 3 May 2026 01:25:47 +0200 Subject: [PATCH 07/57] Simplify --- picows/websockets/asyncio/client.py | 120 +++++++++++----------------- 1 file changed, 48 insertions(+), 72 deletions(-) diff --git a/picows/websockets/asyncio/client.py b/picows/websockets/asyncio/client.py index 1fb7cfc..423c6ce 100644 --- a/picows/websockets/asyncio/client.py +++ b/picows/websockets/asyncio/client.py @@ -13,7 +13,7 @@ from ssl import SSLContext from time import monotonic from typing import Any, AsyncIterator, Awaitable, Callable, Optional, Sequence, \ - Union, cast, Dict, Tuple + Union, cast, Dict, Tuple, Iterator from urllib.request import getproxies import cython @@ -482,15 +482,18 @@ def recv_streaming(self, decode: Optional[bool] = None) -> AsyncIterator[Union[s if self._recv_lock.locked(): raise ConcurrencyError("cannot call recv_streaming() concurrently") self._recv_streaming_in_progress = True - started = False - finished = False + started: cython.bint = False + finished: cython.bint = False async def iterator() -> AsyncIterator[Union[str, bytes]]: nonlocal started, finished + msg_type: WSMsgType + first: _BufferedFrame + frame: _BufferedFrame try: self._recv_lock.acquire() try: - first: _BufferedFrame = await self._next_frame() + first = await self._next_frame() if first.msg_type not in (WSMsgType.TEXT, WSMsgType.BINARY): raise ProtocolError(f"unexpected opcode while receiving message: {first.msg_type}") msg_type = first.msg_type @@ -499,7 +502,7 @@ async def iterator() -> AsyncIterator[Union[str, bytes]]: total: cython.Py_ssize_t = len(first.payload) frame = first while not frame.fin: - frame: _BufferedFrame = await self._next_frame() + frame = await self._next_frame() if frame.msg_type != WSMsgType.CONTINUATION: raise ProtocolError("expected continuation frame") total += len(frame.payload) @@ -542,93 +545,66 @@ async def send( if self._write_ready is not None: await self._write_ready return - if isinstance(message, AsyncIterable): - await self._send_async_fragments(message, text) + elif isinstance(message, AsyncIterable): + await self._send_fragments(True, message.__aiter__(), text) return - if isinstance(message, Iterable): - await self._send_sync_fragments(message, text) + elif isinstance(message, Iterable): + await self._send_fragments(False, iter(message), text) return raise TypeError(f"message has unsupported type {type(message).__name__}") finally: self._send_lock.release() - async def _send_sync_fragments(self, message: Iterable[Data], text: Optional[bool]) -> None: - iterator = iter(message) - try: - first = next(iterator) - except StopIteration: - raise TypeError("message iterable cannot be empty") from None - - expected_type = type(first) - if expected_type is str: - msg_type = WSMsgType.BINARY if text is False else WSMsgType.TEXT - elif expected_type in (bytes, bytearray, memoryview): - msg_type = WSMsgType.TEXT if text else WSMsgType.BINARY - else: - raise TypeError(f"message must contain str or bytes-like objects, got {type(first).__name__}") - - try: - second = next(iterator) - except StopIteration: - self.transport.send(msg_type, first) - if self._write_ready is not None: - await self._write_ready + @cython.cfunc + @cython.inline + def _check_fragment_type(self, message, first_is_str: cython.bint) -> None: + if first_is_str and isinstance(message, str): + return + elif not first_is_str and isinstance(message, (bytes, bytearray, memoryview)): return - self.transport.send(msg_type, first, fin=False) - if self._write_ready is not None: - await self._write_ready - previous = second - for fragment in iterator: - if type(previous) is not expected_type: - raise TypeError("all fragments must be of the same type") - self.transport.send(WSMsgType.CONTINUATION, previous, fin=False) - if self._write_ready is not None: - await self._write_ready - previous = fragment - if type(previous) is not expected_type: - raise TypeError("all fragments must be of the same type") - self.transport.send(WSMsgType.CONTINUATION, previous, fin=True) - if self._write_ready is not None: - await self._write_ready + raise TypeError("all fragments must be of the same category: str vs bytes-like") - async def _send_async_fragments(self, message: AsyncIterable[Data], text: Optional[bool]) -> None: - iterator = message.__aiter__() + async def _send_fragments(self, is_async: cython.bint, iterator: Union[Iterator[Data], AsyncIterator[Data]], text: Optional[bool]) -> None: + stop_exception_type = StopAsyncIteration if is_async else StopIteration try: - first = await anext(iterator) - except StopAsyncIteration: + if is_async: + first = await anext(iterator) + else: + first = next(iterator) + except stop_exception_type: raise TypeError("message iterable cannot be empty") from None - expected_type = type(first) - if expected_type is str: + first_is_str: cython.bint + if isinstance(first, str): msg_type = WSMsgType.BINARY if text is False else WSMsgType.TEXT - elif expected_type in (bytes, bytearray, memoryview): + first_is_str = True + elif isinstance(first, (bytes, bytearray, memoryview)): msg_type = WSMsgType.TEXT if text else WSMsgType.BINARY + first_is_str = False else: raise TypeError(f"message must contain str or bytes-like objects, got {type(first).__name__}") - try: - second = await anext(iterator) - except StopAsyncIteration: - self.transport.send(msg_type, first) - if self._write_ready is not None: - await self._write_ready - return + previous = first + while True: + try: + if is_async: + current = await anext(iterator) + else: + current = next(iterator) + except stop_exception_type: + break - self.transport.send(msg_type, first, fin=False) - if self._write_ready is not None: - await self._write_ready - previous = second - async for fragment in iterator: - if type(previous) is not expected_type: - raise TypeError("all fragments must be of the same type") - self.transport.send(WSMsgType.CONTINUATION, previous, fin=False) + self._check_fragment_type(current, first_is_str) + + self.transport.send(msg_type, previous, fin=False) + msg_type = WSMsgType.CONTINUATION if self._write_ready is not None: await self._write_ready - previous = fragment - if type(previous) is not expected_type: - raise TypeError("all fragments must be of the same type") - self.transport.send(WSMsgType.CONTINUATION, previous, fin=True) + + previous = current + + self.transport.send(msg_type, previous, fin=True) if self._write_ready is not None: await self._write_ready From 763e36b2268db3ee99f52ec2287b441f725fce9e Mon Sep 17 00:00:00 2001 From: taras Date: Sun, 3 May 2026 02:21:01 +0200 Subject: [PATCH 08/57] Simplify and optimize --- AGENTS.md | 1 + picows/websockets/__init__.py | 3 +- picows/websockets/asyncio/__init__.py | 3 +- picows/websockets/asyncio/client.py | 708 +---------------------- picows/websockets/asyncio/connection.py | 722 ++++++++++++++++++++++++ setup.py | 2 +- 6 files changed, 745 insertions(+), 694 deletions(-) create mode 100644 picows/websockets/asyncio/connection.py diff --git a/AGENTS.md b/AGENTS.md index 82a02fe..1f07284 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -8,6 +8,7 @@ tests - Contains tests for picows examples - Various examples for users on how to use picows + perf_test that could be used to build call-graph with perf ## Code style notes +- Max line width is 120 - Do not write `del transport` or similar `del ` statements inside callbacks just to mark arguments as unused. Leave unused callback parameters as-is or rename them with a leading underscore if that is clearer. Using `del` in this situation is confusing and suggests reference-counting or lifetime management concerns. diff --git a/picows/websockets/__init__.py b/picows/websockets/__init__.py index aeb99e0..7ece134 100644 --- a/picows/websockets/__init__.py +++ b/picows/websockets/__init__.py @@ -1,5 +1,6 @@ from . import exceptions -from .asyncio.client import ClientConnection, State, connect, process_exception +from .asyncio.client import connect +from .asyncio.connection import ClientConnection, State, process_exception from .exceptions import ( ConcurrencyError, ConnectionClosed, diff --git a/picows/websockets/asyncio/__init__.py b/picows/websockets/asyncio/__init__.py index 9869987..ed18ffe 100644 --- a/picows/websockets/asyncio/__init__.py +++ b/picows/websockets/asyncio/__init__.py @@ -1,4 +1,5 @@ -from .client import ClientConnection, State, connect, process_exception +from .client import connect +from .connection import ClientConnection, State, process_exception __all__ = [ "ClientConnection", diff --git a/picows/websockets/asyncio/client.py b/picows/websockets/asyncio/client.py index 423c6ce..edb5f1a 100644 --- a/picows/websockets/asyncio/client.py +++ b/picows/websockets/asyncio/client.py @@ -1,162 +1,42 @@ from __future__ import annotations import asyncio -import sys -import logging -import os import socket -import uuid import warnings -from collections import deque -from collections.abc import AsyncIterable, Generator, Iterable -from enum import IntEnum +from collections.abc import Generator from ssl import SSLContext -from time import monotonic -from typing import Any, AsyncIterator, Awaitable, Callable, Optional, Sequence, \ - Union, cast, Dict, Tuple, Iterator -from urllib.request import getproxies - -import cython - -if cython.compiled: - from cython.cimports.picows.picows import WSListener, WSTransport, WSFrame, \ - WSMsgType, WSCloseCode -else: - from picows import WSListener, WSTransport, WSFrame, WSMsgType, WSCloseCode - +from typing import Any, Optional, Sequence, Union, cast import picows -from picows.types import WSHeadersLike from picows.url import parse_url +from .connection import ( + ClientConnection, + HeadersLike, + LoggerLike, + process_exception, +) from ..exceptions import ( - ConcurrencyError, - ConnectionClosed, - ConnectionClosedError, - ConnectionClosedOK, InvalidHandshake, InvalidHeader, InvalidMessage, InvalidStatus, InvalidUpgrade, InvalidURI, - PayloadTooBig, - ProtocolError, ) -Data = Union[str, bytes, bytearray, memoryview] -HeadersLike = WSHeadersLike -CloseCodeT = int -LoggerLike = Union[str, logging.Logger, logging.LoggerAdapter[Any], None] - - -OK_CLOSE_CODES = {0, 1000, 1001} - - -class State(IntEnum): - CONNECTING = 0 - OPEN = 1 - CLOSING = 2 - CLOSED = 3 - - -@cython.cclass -class _BufferedFrame: - msg_type: WSMsgType - payload: bytes - fin: bool - - def __init__(self, msg_type: WSMsgType, payload: bytes, fin: bool): - self.msg_type = msg_type - self.payload = payload - self.fin = fin - - -@cython.cclass -class _AsyncLock: - _loop: asyncio.AbstractEventLoop - _locked: cython.bint - _waiters: Any - - def __init__(self, loop: asyncio.AbstractEventLoop): - self._loop = loop - self._locked = False - self._waiters = deque() - - @cython.cfunc - @cython.inline - def locked(self) -> cython.bint: - return self._locked - - @cython.cfunc - @cython.inline - def acquire(self) -> None: - self._locked = True - - async def wait_and_acquire(self) -> None: - waiter = self._loop.create_future() - self._waiters.append(waiter) - try: - await waiter - except Exception: - try: - self._waiters.remove(waiter) - except ValueError: - pass - raise - - @cython.cfunc - @cython.inline - def release(self) -> None: - while self._waiters: - waiter = self._waiters.popleft() - if not waiter.done(): - waiter.set_result(None) - return - self._locked = False - - -@cython.cfunc -@cython.inline -def _coerce_close_code(code: WSCloseCode) -> Optional[int]: - return None if code is None else int(code) - - -@cython.cfunc -@cython.inline -def _coerce_close_reason(reason: Optional[str]) -> Optional[str]: - return reason if reason is not None else None +def _default_user_agent() -> str: + import sys + return f"Python/{sys.version_info.major}.{sys.version_info.minor} picows-websockets/0" -@cython.cfunc -@cython.inline def _header_items(headers: Any) -> list[tuple[str, str]]: return [] if headers is None else list(headers.items()) -@cython.cfunc -@cython.inline -def _resolve_subprotocol(subprotocols: Optional[Sequence[str]], response: Any) -> Optional[str]: - if response is None: - return None - value = response.headers.get("Sec-WebSocket-Protocol") - if value is None: - return None - if subprotocols is not None and value not in subprotocols: - raise InvalidHandshake(f"unsupported subprotocol negotiated by server: {value}") - return cast(str, value) - - -@cython.cfunc -@cython.inline -def _default_user_agent() -> str: - return f"Python/{sys.version_info.major}.{sys.version_info.minor} picows-websockets/0" - - -@cython.cfunc -@cython.inline def _process_proxy(proxy: Union[str, bool, None], secure: bool) -> Optional[str]: + from urllib.request import getproxies if proxy is None: return None if isinstance(proxy, str): @@ -170,565 +50,10 @@ def _process_proxy(proxy: Union[str, bool, None], secure: bool) -> Optional[str] raise InvalidURI(str(proxy), "proxy must be None, True, or a proxy URL") -@cython.cfunc -@cython.inline -def _normalize_size_limit(limit: Optional[int]) -> cython.Py_ssize_t: +def _normalize_size_limit(limit: Optional[int]) -> int: return 0 if limit is None else limit -@cython.ccall -def process_exception(exc: Exception) -> Optional[Exception]: - if isinstance(exc, (EOFError, OSError, asyncio.TimeoutError)): - return None - if isinstance(exc, InvalidStatus): - status = exc.response.status - if int(status) in {500, 502, 503, 504}: - return None - return exc - - -@cython.cclass -class ClientConnection(WSListener): - id: uuid.UUID - logger: Union[logging.Logger, logging.LoggerAdapter[Any]] - transport: WSTransport - request: picows.WSUpgradeRequest - response: picows.WSUpgradeResponse - _subprotocols: Optional[Sequence[str]] - _subprotocol: Optional[str] - _state: State - _closed_event: asyncio.Event - _frames: asyncio.Queue[Optional[_BufferedFrame]] - _close_exc: Optional[ConnectionClosed] - _loop: asyncio.AbstractEventLoop - _recv_lock: _AsyncLock - _send_lock: _AsyncLock - _write_ready: Optional[asyncio.Future[None]] - _recv_streaming_in_progress: cython.bint - _recv_streaming_broken: cython.bint - _pending_pings: Dict[bytes, Tuple[asyncio.Future[float], float]] - _ping_interval: Optional[float] - _ping_timeout: Optional[float] - _close_timeout: Optional[float] - _keepalive_task: Optional[asyncio.Task[None]] - _latency: cython.double - _max_message_size: cython.Py_ssize_t - _max_fragment_size: cython.Py_ssize_t - _max_queue_high: cython.Py_ssize_t - _max_queue_low: cython.Py_ssize_t - _write_limit: Union[int, tuple[int, Optional[int]]] - _paused_reading: cython.bint - - def __init__( - self, - *, - ping_interval: Optional[float] = 20, - ping_timeout: Optional[float] = 20, - close_timeout: Optional[float] = 10, - max_queue: Union[int, tuple[Optional[int], Optional[int]], None] = 16, - write_limit: Union[int, tuple[int, Optional[int]]] = 32768, - max_message_size: Optional[int] = 1024 * 1024, - max_fragment_size: Optional[int] = 1024 * 1024, - logger: LoggerLike = None, - subprotocols: Optional[Sequence[str]] = None, - ): - self.id = uuid.uuid4() - self.logger = self._resolve_logger(logger) - self.transport = cython.cast(WSTransport, None) - self.request = cast(picows.WSUpgradeRequest, None) - self.response = cast(picows.WSUpgradeResponse, None) - self._subprotocols = subprotocols - self._subprotocol = cast(Optional[str], None) - self._state = State.CONNECTING - self._closed_event = asyncio.Event() - self._frames: asyncio.Queue[Optional[_BufferedFrame]] = asyncio.Queue() - self._close_exc: Optional[ConnectionClosed] = None - self._loop = asyncio.get_running_loop() - self._recv_lock = _AsyncLock(self._loop) - self._send_lock = _AsyncLock(self._loop) - self._write_ready: Optional[asyncio.Future[None]] = None - self._recv_streaming_in_progress = False - self._recv_streaming_broken = False - self._pending_pings: dict[bytes, tuple[asyncio.Future[float], float]] = {} - self._ping_interval = ping_interval - self._ping_timeout = ping_timeout - self._close_timeout = close_timeout - self._keepalive_task: Optional[asyncio.Task[None]] = None - self._latency = 0.0 - self._max_message_size = _normalize_size_limit(max_message_size) - self._max_fragment_size = _normalize_size_limit(max_fragment_size) - self._max_queue_high, self._max_queue_low = self._normalize_watermarks(max_queue) - self._write_limit = write_limit - self._paused_reading = False - - @cython.cfunc - @cython.inline - def _resolve_logger(self, logger: LoggerLike) -> Union[logging.Logger, logging.LoggerAdapter[Any]]: - if logger is None: - return logging.getLogger("websockets.client") - if isinstance(logger, str): - return logging.getLogger(logger) - return logger - - @cython.cfunc - @cython.inline - def _normalize_watermarks( - self, - max_queue: Union[int, tuple[Optional[int], Optional[int]], None], - ) -> tuple[cython.Py_ssize_t, cython.Py_ssize_t]: - if max_queue is None: - return 0, 0 - if isinstance(max_queue, tuple): - high, low = max_queue - if high is None: - return 0, 0 - return high, high // 4 if low is None else low - return max_queue, max_queue // 4 - - @cython.cfunc - @cython.inline - def _set_write_limits(self, write_limit: Union[int, tuple[int, Optional[int]]]) -> None: - if isinstance(write_limit, tuple): - high, low = write_limit - else: - high, low = write_limit, None - self.transport.underlying_transport.set_write_buffer_limits(high=high, low=low) - - @cython.ccall - def on_ws_connected(self, transport: WSTransport) -> None: - self.transport = transport - self.request = transport.request - self.response = transport.response - self._subprotocol = _resolve_subprotocol(self._subprotocols, self.response) - self._state = State.OPEN - self._set_write_limits(self._write_limit) - if self._ping_interval is not None and self._keepalive_task is None: - self._keepalive_task = asyncio.create_task(self._keepalive_loop()) - - @cython.ccall - def pause_writing(self) -> None: - if self._write_ready is None: - self._write_ready = self._loop.create_future() - - @cython.ccall - def resume_writing(self) -> None: - if self._write_ready is not None: - if not self._write_ready.done(): - self._write_ready.set_result(None) - self._write_ready = None - - @cython.cfunc - @cython.inline - def _pause_reading_if_needed(self) -> None: - if self._max_queue_high > 0 and not self._paused_reading and self._frames.qsize() >= self._max_queue_high: - self.transport.underlying_transport.pause_reading() - self._paused_reading = True - - @cython.cfunc - @cython.inline - def _resume_reading_if_needed(self) -> None: - if not self._paused_reading: - return - if self._max_queue_low == 0 or self._frames.qsize() <= self._max_queue_low: - self.transport.underlying_transport.resume_reading() - self._paused_reading = False - - @cython.cfunc - @cython.inline - def _set_close_exception(self) -> None: - handshake = self.transport.close_handshake - if handshake is None: - self._close_exc = ConnectionClosedError(None, None, None) - return - rcvd = handshake.recv - sent = handshake.sent - rcvd_then_sent = handshake.recv_then_sent - rcvd_code = _coerce_close_code(rcvd.code) if rcvd is not None else None - sent_code = _coerce_close_code(sent.code) if sent is not None else None - ok = ( - (rcvd_code in OK_CLOSE_CODES or rcvd_code is None) - and (sent_code in OK_CLOSE_CODES or sent_code is None) - ) - exc_type = ConnectionClosedOK if ok else ConnectionClosedError - self._close_exc = exc_type(rcvd, sent, rcvd_then_sent) - - @cython.ccall - def on_ws_disconnected(self, transport: WSTransport) -> None: - self._state = State.CLOSED - self._set_close_exception() - self._frames.put_nowait(None) - self._closed_event.set() - if self._keepalive_task is not None: - self._keepalive_task.cancel() - self._keepalive_task = None - if self._write_ready is not None: - if not self._write_ready.done(): - self._write_ready.set_exception( - self._close_exc or ConnectionClosedError(None, None, None) - ) - self._write_ready = None - for waiter, _ in self._pending_pings.values(): - if not waiter.done(): - waiter.set_exception(self._close_exc or ConnectionClosedError(None, None, None)) - self._pending_pings.clear() - - @cython.cfunc - @cython.inline - def _fail_message_too_big(self, message: str) -> None: - self.transport.send_close(WSCloseCode.MESSAGE_TOO_BIG, message.encode("utf-8")) - self.transport.disconnect(False) - - @cython.ccall - def on_ws_frame(self, transport: WSTransport, frame: WSFrame) -> None: - payload = frame.get_payload_as_bytes() - if frame.msg_type == WSMsgType.PING: - self.transport.send_pong(payload) - return - - if frame.msg_type == WSMsgType.PONG: - ping = self._pending_pings.pop(payload, None) - if ping is not None: - waiter, sent_at = ping - self._latency = monotonic() - sent_at - if not waiter.done(): - waiter.set_result(self._latency) - return - - if frame.msg_type == WSMsgType.CLOSE: - close_code = frame.get_close_code() - close_message = frame.get_close_message() - self.transport.send_close(close_code, close_message) - self.transport.disconnect() - self._state = State.CLOSING - return - - if self._max_fragment_size > 0 and len(payload) > self._max_fragment_size: - self._fail_message_too_big("fragment too big") - return - - self._frames.put_nowait(_BufferedFrame(frame.msg_type, payload, frame.fin)) - self._pause_reading_if_needed() - - async def _next_frame(self) -> _BufferedFrame: - frame: _BufferedFrame = cython.cast(_BufferedFrame, await self._frames.get()) - self._resume_reading_if_needed() - if frame is None: - raise self._connection_closed() - return frame - - @cython.cfunc - @cython.inline - def _connection_closed(self) -> ConnectionClosed: - if self._close_exc is None: - self._set_close_exception() - return self._close_exc or ConnectionClosedError(None, None, None) - - @cython.cfunc - @cython.inline - def _ensure_recv_available(self) -> None: - if self._recv_streaming_broken: - raise ConcurrencyError("recv_streaming() wasn't fully consumed") - if self._recv_streaming_in_progress: - raise ConcurrencyError("cannot call recv() while recv_streaming() is active") - - async def recv(self, decode: Optional[bool] = None) -> Union[str, bytes]: - self._ensure_recv_available() - if self._recv_lock.locked(): - raise ConcurrencyError("cannot call recv() concurrently") - self._recv_lock.acquire() - try: - first: _BufferedFrame = await self._next_frame() - if first.msg_type not in (WSMsgType.TEXT, WSMsgType.BINARY): - raise ProtocolError(f"unexpected opcode while receiving message: {first.msg_type}") - msg_type = first.msg_type - if first.fin: - return self._decode_data(first.payload, msg_type, decode) - - chunks = [first.payload] - total: cython.Py_ssize_t = len(first.payload) - while not first.fin: - first = await self._next_frame() - if first.msg_type != WSMsgType.CONTINUATION: - raise ProtocolError("expected continuation frame") - chunks.append(first.payload) - total += len(first.payload) - if self._max_message_size > 0 and total > self._max_message_size: - self._fail_message_too_big("message too big") - raise PayloadTooBig("message too big") - - payload = b"".join(chunks) - return self._decode_data(payload, msg_type, decode) - finally: - self._recv_lock.release() - - @cython.cfunc - @cython.inline - def _decode_data( - self, - payload: bytes, - msg_type: WSMsgType, - decode: Optional[bool], - ) -> Union[str, bytes]: - if msg_type == WSMsgType.TEXT: - if decode is False: - return payload - return payload.decode("utf-8") - if decode is True: - return payload.decode("utf-8") - return payload - - def recv_streaming(self, decode: Optional[bool] = None) -> AsyncIterator[Union[str, bytes]]: - self._ensure_recv_available() - if self._recv_lock.locked(): - raise ConcurrencyError("cannot call recv_streaming() concurrently") - self._recv_streaming_in_progress = True - started: cython.bint = False - finished: cython.bint = False - - async def iterator() -> AsyncIterator[Union[str, bytes]]: - nonlocal started, finished - msg_type: WSMsgType - first: _BufferedFrame - frame: _BufferedFrame - try: - self._recv_lock.acquire() - try: - first = await self._next_frame() - if first.msg_type not in (WSMsgType.TEXT, WSMsgType.BINARY): - raise ProtocolError(f"unexpected opcode while receiving message: {first.msg_type}") - msg_type = first.msg_type - started = True - yield self._decode_data(first.payload, msg_type, decode) - total: cython.Py_ssize_t = len(first.payload) - frame = first - while not frame.fin: - frame = await self._next_frame() - if frame.msg_type != WSMsgType.CONTINUATION: - raise ProtocolError("expected continuation frame") - total += len(frame.payload) - if self._max_message_size > 0 and total > self._max_message_size: - self._fail_message_too_big("message too big") - raise PayloadTooBig("message too big") - yield self._decode_data(frame.payload, msg_type, decode) - finished = True - finally: - self._recv_lock.release() - finally: - if started and not finished: - self._recv_streaming_broken = True - elif finished: - self._recv_streaming_broken = False - self._recv_streaming_in_progress = False - - return iterator() - - async def send( - self, - message: Union[Data, Iterable[Data], AsyncIterator[Data]], - text: Optional[bool] = None, - ) -> None: - if self._state is State.CLOSED: - raise self._connection_closed() - - if self._send_lock.locked(): - await self._send_lock.wait_and_acquire() - else: - self._send_lock.acquire() - - try: - if isinstance(message, (str, bytes, bytearray, memoryview)): - if isinstance(message, str): - msg_type = WSMsgType.BINARY if text is False else WSMsgType.TEXT - else: - msg_type = WSMsgType.TEXT if text else WSMsgType.BINARY - self.transport.send(msg_type, message) - if self._write_ready is not None: - await self._write_ready - return - elif isinstance(message, AsyncIterable): - await self._send_fragments(True, message.__aiter__(), text) - return - elif isinstance(message, Iterable): - await self._send_fragments(False, iter(message), text) - return - raise TypeError(f"message has unsupported type {type(message).__name__}") - finally: - self._send_lock.release() - - @cython.cfunc - @cython.inline - def _check_fragment_type(self, message, first_is_str: cython.bint) -> None: - if first_is_str and isinstance(message, str): - return - elif not first_is_str and isinstance(message, (bytes, bytearray, memoryview)): - return - - raise TypeError("all fragments must be of the same category: str vs bytes-like") - - async def _send_fragments(self, is_async: cython.bint, iterator: Union[Iterator[Data], AsyncIterator[Data]], text: Optional[bool]) -> None: - stop_exception_type = StopAsyncIteration if is_async else StopIteration - try: - if is_async: - first = await anext(iterator) - else: - first = next(iterator) - except stop_exception_type: - raise TypeError("message iterable cannot be empty") from None - - first_is_str: cython.bint - if isinstance(first, str): - msg_type = WSMsgType.BINARY if text is False else WSMsgType.TEXT - first_is_str = True - elif isinstance(first, (bytes, bytearray, memoryview)): - msg_type = WSMsgType.TEXT if text else WSMsgType.BINARY - first_is_str = False - else: - raise TypeError(f"message must contain str or bytes-like objects, got {type(first).__name__}") - - previous = first - while True: - try: - if is_async: - current = await anext(iterator) - else: - current = next(iterator) - except stop_exception_type: - break - - self._check_fragment_type(current, first_is_str) - - self.transport.send(msg_type, previous, fin=False) - msg_type = WSMsgType.CONTINUATION - if self._write_ready is not None: - await self._write_ready - - previous = current - - self.transport.send(msg_type, previous, fin=True) - if self._write_ready is not None: - await self._write_ready - - async def close(self, code: int = 1000, reason: str = "") -> None: - if self._state is State.CLOSED: - return - if self._state is State.OPEN: - self._state = State.CLOSING - self.transport.send_close(code, reason.encode("utf-8")) - try: - if self._close_timeout is None: - await self.wait_closed() - else: - await asyncio.wait_for(self.wait_closed(), self._close_timeout) - except asyncio.TimeoutError: - self.transport.disconnect(False) - await self.wait_closed() - - async def wait_closed(self) -> None: - await self._closed_event.wait() - - async def ping(self, data: Optional[Union[str, bytes]] = None) -> Awaitable[float]: - if self._state is State.CLOSED: - raise self._connection_closed() - if data is None: - while True: - payload = os.urandom(4) - if payload not in self._pending_pings: - break - elif isinstance(data, str): - payload = data.encode("utf-8") - elif isinstance(data, bytes): - payload = data - else: - raise TypeError("ping payload must be str, bytes, or None") - - if payload in self._pending_pings: - raise ConcurrencyError("another ping was sent with the same data") - - waiter: asyncio.Future[float] = asyncio.get_running_loop().create_future() - self._pending_pings[payload] = (waiter, monotonic()) - self.transport.send_ping(payload) - return waiter - - async def pong(self, data: Union[str, bytes] = b"") -> None: - if self._state is State.CLOSED: - raise self._connection_closed() - payload = data.encode("utf-8") if isinstance(data, str) else data - self.transport.send_pong(payload) - - async def _keepalive_loop(self) -> None: - try: - while True: - assert self._ping_interval is not None - await asyncio.sleep(self._ping_interval) - waiter = await self.ping() - if self._ping_timeout is None: - continue - await asyncio.wait_for(waiter, self._ping_timeout) - except asyncio.CancelledError: - raise - except Exception: - if self.state is not State.CLOSED: - await self.close(code=1011, reason="keepalive ping timeout") - - async def __aenter__(self) -> ClientConnection: - return self - - async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None: - await self.close() - - def __aiter__(self) -> AsyncIterator[Union[str, bytes]]: - return self._iterate_messages() - - async def _iterate_messages(self) -> AsyncIterator[Union[str, bytes]]: - while True: - try: - yield await self.recv() - except ConnectionClosedOK: - return - - @property - def state(self) -> State: - return self._state - - @property - def local_address(self) -> Any: - return self.transport.underlying_transport.get_extra_info("sockname") - - @property - def remote_address(self) -> Any: - return self.transport.underlying_transport.get_extra_info("peername") - - @property - def latency(self) -> float: - return self._latency - - @property - def subprotocol(self) -> Optional[str]: - return self._subprotocol - - @property - def close_code(self) -> Optional[int]: - handshake = self.transport.close_handshake - if handshake is None: - return None - if handshake.recv is not None: - return _coerce_close_code(handshake.recv.code) - if handshake.sent is not None: - return _coerce_close_code(handshake.sent.code) - return None - - @property - def close_reason(self) -> Optional[str]: - handshake = self.transport.close_handshake - if handshake is None: - return None - if handshake.recv is not None: - return _coerce_close_reason(handshake.recv.reason) - if handshake.sent is not None: - return _coerce_close_reason(handshake.sent.reason) - return None - - class _Connect: def __init__( self, @@ -741,7 +66,7 @@ def __init__( additional_headers: Optional[HeadersLike] = None, user_agent_header: Optional[str] = _default_user_agent(), proxy: Union[str, bool, None] = True, - process_exception: Callable[[Exception], Optional[Exception]] = process_exception, + process_exception=process_exception, open_timeout: Optional[float] = 10, ping_interval: Optional[float] = 20, ping_timeout: Optional[float] = 20, @@ -856,6 +181,7 @@ async def connect_override(_: Any) -> socket.socket: return sock socket_factory = connect_override + def listener_factory() -> ClientConnection: return self.connection_factory( ping_interval=self.ping_interval, @@ -897,12 +223,12 @@ def listener_factory() -> ClientConnection: except picows.WSHandshakeError as exc: raise InvalidHandshake(str(exc)) from exc - return cython.cast(ClientConnection, listener) + return cast(ClientConnection, listener) def _normalize_max_size( self, max_size: Union[int, tuple[Optional[int], Optional[int]], None], - ) -> tuple[cython.Py_ssize_t, cython.Py_ssize_t]: + ) -> tuple[int, int]: if max_size is None: return 0, 0 if isinstance(max_size, tuple): diff --git a/picows/websockets/asyncio/connection.py b/picows/websockets/asyncio/connection.py new file mode 100644 index 0000000..30a1c9d --- /dev/null +++ b/picows/websockets/asyncio/connection.py @@ -0,0 +1,722 @@ +from __future__ import annotations + +import asyncio +import sys +import logging +import os +import socket +import uuid +import warnings +from collections import deque +from collections.abc import AsyncIterable, Generator, Iterable +from enum import IntEnum +from ssl import SSLContext +from time import monotonic +from typing import Any, AsyncIterator, Awaitable, Callable, Optional, Sequence, \ + Union, cast, Dict, Tuple, Iterator +from urllib.request import getproxies + +import cython + +if cython.compiled: + from cython.cimports.picows.picows import WSListener, WSTransport, WSFrame, \ + WSMsgType, WSCloseCode +else: + from picows import WSListener, WSTransport, WSFrame, WSMsgType, WSCloseCode + + +import picows +from picows.types import WSHeadersLike + +from ..exceptions import ( + ConcurrencyError, + ConnectionClosed, + ConnectionClosedError, + ConnectionClosedOK, + InvalidHandshake, + InvalidStatus, + InvalidURI, + PayloadTooBig, + ProtocolError, +) + + +DataLike = Union[str, bytes, bytearray, memoryview] +HeadersLike = WSHeadersLike +CloseCodeT = int +LoggerLike = Union[str, logging.Logger, logging.LoggerAdapter[Any], None] + + +OK_CLOSE_CODES = {0, 1000, 1001} + + +class State(IntEnum): + CONNECTING = 0 + OPEN = 1 + CLOSING = 2 + CLOSED = 3 + + +@cython.cclass +class _BufferedFrame: + msg_type: WSMsgType + payload: bytes + fin: bool + + def __init__(self, msg_type: WSMsgType, payload: bytes, fin: bool): + self.msg_type = msg_type + self.payload = payload + self.fin = fin + + +@cython.cclass +class _AsyncLock: + _loop: asyncio.AbstractEventLoop + _locked: cython.bint + _waiters: Any + + def __init__(self, loop: asyncio.AbstractEventLoop): + self._loop = loop + self._locked = False + self._waiters = deque() + + @cython.cfunc + @cython.inline + def locked(self) -> cython.bint: + return self._locked + + @cython.cfunc + @cython.inline + def acquire(self) -> None: + self._locked = True + + async def wait_and_acquire(self) -> None: + waiter = self._loop.create_future() + self._waiters.append(waiter) + try: + await waiter + except Exception: + try: + self._waiters.remove(waiter) + except ValueError: + pass + raise + + @cython.cfunc + @cython.inline + def release(self) -> None: + while self._waiters: + waiter = self._waiters.popleft() + if not waiter.done(): + waiter.set_result(None) + return + self._locked = False + + +@cython.cfunc +@cython.inline +def _coerce_close_code(code: WSCloseCode) -> Optional[int]: + return None if code is None else int(code) + + +@cython.cfunc +@cython.inline +def _coerce_close_reason(reason: Optional[str]) -> Optional[str]: + return reason if reason is not None else None + + +@cython.cfunc +@cython.inline +def _header_items(headers: Any) -> list[tuple[str, str]]: + return [] if headers is None else list(headers.items()) + + +@cython.cfunc +@cython.inline +def _resolve_subprotocol(subprotocols: Optional[Sequence[str]], response: Any) -> Optional[str]: + if response is None: + return None + value = response.headers.get("Sec-WebSocket-Protocol") + if value is None: + return None + if subprotocols is not None and value not in subprotocols: + raise InvalidHandshake(f"unsupported subprotocol negotiated by server: {value}") + return cast(str, value) + + +@cython.cfunc +@cython.inline +def _default_user_agent() -> str: + return f"Python/{sys.version_info.major}.{sys.version_info.minor} picows-websockets/0" + + +@cython.cfunc +@cython.inline +def _process_proxy(proxy: Union[str, bool, None], secure: bool) -> Optional[str]: + if proxy is None: + return None + if isinstance(proxy, str): + return proxy + if proxy is True: + proxies = getproxies() + return ( + proxies.get("wss" if secure else "ws") + or proxies.get("https" if secure else "http") + ) + raise InvalidURI(str(proxy), "proxy must be None, True, or a proxy URL") + + +@cython.cfunc +@cython.inline +def _normalize_size_limit(limit: Optional[int]) -> cython.Py_ssize_t: + return 0 if limit is None else limit + + +@cython.cfunc +@cython.inline +def _normalize_watermarks( + max_queue: Union[int, tuple[Optional[int], Optional[int]], None], +) -> tuple[cython.Py_ssize_t, cython.Py_ssize_t]: + if max_queue is None: + return 0, 0 + if isinstance(max_queue, tuple): + high, low = max_queue + if high is None: + return 0, 0 + return high, high // 4 if low is None else low + return max_queue, max_queue // 4 + + +@cython.cfunc +@cython.inline +def _resolve_logger(logger: LoggerLike) -> Union[logging.Logger, logging.LoggerAdapter[Any]]: + if logger is None: + return logging.getLogger("websockets.client") + if isinstance(logger, str): + return logging.getLogger(logger) + return logger + + +@cython.ccall +def process_exception(exc: Exception) -> Optional[Exception]: + if isinstance(exc, (EOFError, OSError, asyncio.TimeoutError)): + return None + if isinstance(exc, InvalidStatus): + status = exc.response.status + if int(status) in {500, 502, 503, 504}: + return None + return exc + + +@cython.cclass +class ClientConnection(WSListener): + id: uuid.UUID + logger: Union[logging.Logger, logging.LoggerAdapter[Any]] + transport: WSTransport + request: picows.WSUpgradeRequest + response: picows.WSUpgradeResponse + _subprotocols: Optional[Sequence[str]] + _subprotocol: Optional[str] + _state: State + _closed_event: asyncio.Event + _frames: asyncio.Queue[Optional[_BufferedFrame]] + _close_exc: Optional[ConnectionClosed] + _loop: asyncio.AbstractEventLoop + _recv_in_progress: cython.bint + _send_lock: _AsyncLock + _write_ready: Optional[asyncio.Future[None]] + _recv_streaming_broken: cython.bint + _pending_pings: Dict[bytes, Tuple[asyncio.Future[float], float]] + _ping_interval: Optional[float] + _ping_timeout: Optional[float] + _close_timeout: Optional[float] + _keepalive_task: Optional[asyncio.Task[None]] + _latency: cython.double + _max_message_size: cython.Py_ssize_t + _max_fragment_size: cython.Py_ssize_t + _max_queue_high: cython.Py_ssize_t + _max_queue_low: cython.Py_ssize_t + _write_limit: Union[int, tuple[int, Optional[int]]] + _paused_reading: cython.bint + + def __init__( + self, + *, + ping_interval: Optional[float] = 20, + ping_timeout: Optional[float] = 20, + close_timeout: Optional[float] = 10, + max_queue: Union[int, tuple[Optional[int], Optional[int]], None] = 16, + write_limit: Union[int, tuple[int, Optional[int]]] = 32768, + max_message_size: Optional[int] = 1024 * 1024, + max_fragment_size: Optional[int] = 1024 * 1024, + logger: LoggerLike = None, + subprotocols: Optional[Sequence[str]] = None, + ): + self.id = uuid.uuid4() + self.logger = _resolve_logger(logger) + self.transport = cython.cast(WSTransport, None) + self.request = cast(picows.WSUpgradeRequest, None) + self.response = cast(picows.WSUpgradeResponse, None) + self._subprotocols = subprotocols + self._subprotocol = cast(Optional[str], None) + self._state = State.CONNECTING + self._closed_event = asyncio.Event() + self._frames: asyncio.Queue[Optional[_BufferedFrame]] = asyncio.Queue() + self._close_exc: Optional[ConnectionClosed] = None + self._loop = asyncio.get_running_loop() + self._recv_in_progress = False + self._send_lock = _AsyncLock(self._loop) + self._write_ready: Optional[asyncio.Future[None]] = None + self._recv_streaming_broken = False + self._pending_pings: dict[bytes, tuple[asyncio.Future[float], float]] = {} + self._ping_interval = ping_interval + self._ping_timeout = ping_timeout + self._close_timeout = close_timeout + self._keepalive_task: Optional[asyncio.Task[None]] = None + self._latency = 0.0 + self._max_message_size = _normalize_size_limit(max_message_size) + self._max_fragment_size = _normalize_size_limit(max_fragment_size) + self._max_queue_high, self._max_queue_low = _normalize_watermarks(max_queue) + self._write_limit = write_limit + self._paused_reading = False + + @cython.ccall + def on_ws_connected(self, transport: WSTransport) -> None: + self.transport = transport + self.request = transport.request + self.response = transport.response + self._subprotocol = _resolve_subprotocol(self._subprotocols, self.response) + self._state = State.OPEN + self._set_write_limits(self._write_limit) + if self._ping_interval is not None and self._keepalive_task is None: + self._keepalive_task = asyncio.create_task(self._keepalive_loop()) + + @cython.ccall + def on_ws_disconnected(self, transport: WSTransport) -> None: + self._state = State.CLOSED + self._set_close_exception() + self._frames.put_nowait(None) + self._closed_event.set() + if self._keepalive_task is not None: + self._keepalive_task.cancel() + self._keepalive_task = None + if self._write_ready is not None: + if not self._write_ready.done(): + self._write_ready.set_exception( + self._close_exc or ConnectionClosedError(None, None, None) + ) + self._write_ready = None + for waiter, _ in self._pending_pings.values(): + if not waiter.done(): + waiter.set_exception(self._close_exc or ConnectionClosedError(None, None, None)) + self._pending_pings.clear() + + @cython.ccall + def on_ws_frame(self, transport: WSTransport, frame: WSFrame) -> None: + payload = frame.get_payload_as_bytes() + if frame.msg_type == WSMsgType.PING: + self.transport.send_pong(payload) + return + + if frame.msg_type == WSMsgType.PONG: + ping = self._pending_pings.pop(payload, None) + if ping is not None: + waiter, sent_at = ping + self._latency = monotonic() - sent_at + if not waiter.done(): + waiter.set_result(self._latency) + return + + if frame.msg_type == WSMsgType.CLOSE: + close_code = frame.get_close_code() + close_message = frame.get_close_message() + self.transport.send_close(close_code, close_message) + self.transport.disconnect() + self._state = State.CLOSING + return + + if self._max_fragment_size > 0 and len(payload) > self._max_fragment_size: + self._fail_message_too_big("fragment too big") + return + + self._frames.put_nowait(_BufferedFrame(frame.msg_type, payload, frame.fin)) + self._pause_reading_if_needed() + + @cython.ccall + def pause_writing(self) -> None: + if self._write_ready is None: + self._write_ready = self._loop.create_future() + + @cython.ccall + def resume_writing(self) -> None: + if self._write_ready is not None: + if not self._write_ready.done(): + self._write_ready.set_result(None) + self._write_ready = None + + @cython.cfunc + @cython.inline + def _set_write_limits(self, write_limit: Union[int, tuple[int, Optional[int]]]) -> None: + if isinstance(write_limit, tuple): + high, low = write_limit + else: + high, low = write_limit, None + self.transport.underlying_transport.set_write_buffer_limits(high=high, low=low) + + @cython.cfunc + @cython.inline + def _pause_reading_if_needed(self) -> None: + if self._max_queue_high > 0 and not self._paused_reading and self._frames.qsize() >= self._max_queue_high: + self.transport.underlying_transport.pause_reading() + self._paused_reading = True + + @cython.cfunc + @cython.inline + def _resume_reading_if_needed(self) -> None: + if not self._paused_reading: + return + if self._max_queue_low == 0 or self._frames.qsize() <= self._max_queue_low: + self.transport.underlying_transport.resume_reading() + self._paused_reading = False + + @cython.cfunc + @cython.inline + def _set_close_exception(self) -> None: + handshake = self.transport.close_handshake + if handshake is None: + self._close_exc = ConnectionClosedError(None, None, None) + return + rcvd = handshake.recv + sent = handshake.sent + rcvd_then_sent = handshake.recv_then_sent + rcvd_code = _coerce_close_code(rcvd.code) if rcvd is not None else None + sent_code = _coerce_close_code(sent.code) if sent is not None else None + ok = ( + (rcvd_code in OK_CLOSE_CODES or rcvd_code is None) + and (sent_code in OK_CLOSE_CODES or sent_code is None) + ) + exc_type = ConnectionClosedOK if ok else ConnectionClosedError + self._close_exc = exc_type(rcvd, sent, rcvd_then_sent) + + @cython.cfunc + @cython.inline + def _fail_message_too_big(self, message: str) -> None: + self.transport.send_close(WSCloseCode.MESSAGE_TOO_BIG, message.encode("utf-8")) + self.transport.disconnect(False) + + @cython.cfunc + @cython.inline + def _connection_closed(self) -> ConnectionClosed: + if self._close_exc is None: + self._set_close_exception() + return self._close_exc or ConnectionClosedError(None, None, None) + + @cython.cfunc + @cython.inline + def _set_recv_in_progress(self) -> None: + if self._recv_in_progress: + raise ConcurrencyError("cannot call recv() or recv_streaming() concurrently") + if self._recv_streaming_broken: + raise ConcurrencyError("recv_streaming() wasn't fully consumed") + self._recv_in_progress = True + + @cython.cfunc + @cython.inline + def _decode_data(self, payload: bytes, msg_type: WSMsgType, decode: Optional[bool]) -> Union[str, bytes]: + if decode is True or (msg_type == WSMsgType.TEXT and decode is None): + return payload.decode("utf-8") + else: + return payload + + @cython.cfunc + @cython.inline + def _check_frame(self, frame: Optional[_BufferedFrame], is_first: cython.bint) -> None: + self._resume_reading_if_needed() + + if frame is None: + raise self._connection_closed() + + if is_first: + if frame.msg_type not in (WSMsgType.TEXT, WSMsgType.BINARY): + raise ProtocolError(f"unexpected opcode while receiving message: {frame.msg_type}") + else: + if frame.msg_type != WSMsgType.CONTINUATION: + raise ProtocolError("expected continuation frame") + + async def recv(self, decode: Optional[bool] = None) -> Union[str, bytes]: + frame: Optional[_BufferedFrame] + total: cython.Py_ssize_t + + self._set_recv_in_progress() + + try: + frame = await self._frames.get() + self._check_frame(frame, True) + + msg_type = frame.msg_type + if frame.fin: + return self._decode_data(frame.payload, msg_type, decode) + + chunks = [frame.payload] + total = len(frame.payload) + while not frame.fin: + frame = await self._frames.get() + self._check_frame(frame, False) + + chunks.append(frame.payload) + total += len(frame.payload) + if self._max_message_size > 0 and total > self._max_message_size: + self._fail_message_too_big("message too big") + raise PayloadTooBig("message too big") + + payload = b"".join(chunks) + return self._decode_data(payload, msg_type, decode) + finally: + self._recv_in_progress = False + + def recv_streaming(self, decode: Optional[bool] = None) -> AsyncIterator[Union[str, bytes]]: + self._set_recv_in_progress() + + started: cython.bint = False + finished: cython.bint = False + + async def iterator() -> AsyncIterator[Union[str, bytes]]: + nonlocal started, finished + msg_type: WSMsgType + first: Optional[_BufferedFrame] + frame: Optional[_BufferedFrame] + total: cython.Py_ssize_t + + try: + first = await self._frames.get() + self._check_frame(first, True) + + msg_type = first.msg_type + started = True + yield self._decode_data(first.payload, msg_type, decode) + total = len(first.payload) + frame = first + while not frame.fin: + frame = await self._frames.get() + self._check_frame(frame, False) + + total += len(frame.payload) + if self._max_message_size > 0 and total > self._max_message_size: + self._fail_message_too_big("message too big") + raise PayloadTooBig("message too big") + yield self._decode_data(frame.payload, msg_type, decode) + finished = True + finally: + self._recv_in_progress = False + if started and not finished: + self._recv_streaming_broken = True + elif finished: + self._recv_streaming_broken = False + + return iterator() + + async def send( + self, + message: Union[DataLike, Iterable[DataLike], AsyncIterator[DataLike]], + text: Optional[bool] = None, + ) -> None: + if self._state is State.CLOSED: + raise self._connection_closed() + + if self._send_lock.locked(): + await self._send_lock.wait_and_acquire() + else: + self._send_lock.acquire() + + try: + if isinstance(message, (str, bytes, bytearray, memoryview)): + if isinstance(message, str): + msg_type = WSMsgType.BINARY if text is False else WSMsgType.TEXT + else: + msg_type = WSMsgType.TEXT if text else WSMsgType.BINARY + self.transport.send(msg_type, message) + if self._write_ready is not None: + await self._write_ready + return + elif isinstance(message, AsyncIterable): + await self._send_fragments(True, message.__aiter__(), text) + return + elif isinstance(message, Iterable): + await self._send_fragments(False, iter(message), text) + return + raise TypeError(f"message has unsupported type {type(message).__name__}") + finally: + self._send_lock.release() + + @cython.cfunc + @cython.inline + def _check_fragment_type(self, message, first_is_str: cython.bint) -> None: + if first_is_str and isinstance(message, str): + return + elif not first_is_str and isinstance(message, (bytes, bytearray, memoryview)): + return + + raise TypeError("all fragments must be of the same category: str vs bytes-like") + + async def _send_fragments(self, is_async: cython.bint, iterator: Union[Iterator[DataLike], AsyncIterator[DataLike]], text: Optional[bool]) -> None: + stop_exception_type = StopAsyncIteration if is_async else StopIteration + try: + if is_async: + first = await anext(iterator) + else: + first = next(iterator) + except stop_exception_type: + raise TypeError("message iterable cannot be empty") from None + + first_is_str: cython.bint + if isinstance(first, str): + msg_type = WSMsgType.BINARY if text is False else WSMsgType.TEXT + first_is_str = True + elif isinstance(first, (bytes, bytearray, memoryview)): + msg_type = WSMsgType.TEXT if text else WSMsgType.BINARY + first_is_str = False + else: + raise TypeError(f"message must contain str or bytes-like objects, got {type(first).__name__}") + + previous = first + while True: + try: + if is_async: + current = await anext(iterator) + else: + current = next(iterator) + except stop_exception_type: + break + + self._check_fragment_type(current, first_is_str) + + self.transport.send(msg_type, previous, fin=False) + msg_type = WSMsgType.CONTINUATION + if self._write_ready is not None: + await self._write_ready + + previous = current + + self.transport.send(msg_type, previous, fin=True) + if self._write_ready is not None: + await self._write_ready + + async def close(self, code: int = 1000, reason: str = "") -> None: + if self._state is State.CLOSED: + return + if self._state is State.OPEN: + self._state = State.CLOSING + self.transport.send_close(code, reason.encode("utf-8")) + try: + if self._close_timeout is None: + await self.wait_closed() + else: + await asyncio.wait_for(self.wait_closed(), self._close_timeout) + except asyncio.TimeoutError: + self.transport.disconnect(False) + await self.wait_closed() + + async def wait_closed(self) -> None: + await self._closed_event.wait() + + async def ping(self, data: Optional[Union[str, bytes]] = None) -> Awaitable[float]: + if self._state is State.CLOSED: + raise self._connection_closed() + if data is None: + while True: + payload = os.urandom(4) + if payload not in self._pending_pings: + break + elif isinstance(data, str): + payload = data.encode("utf-8") + elif isinstance(data, bytes): + payload = data + else: + raise TypeError("ping payload must be str, bytes, or None") + + if payload in self._pending_pings: + raise ConcurrencyError("another ping was sent with the same data") + + waiter: asyncio.Future[float] = asyncio.get_running_loop().create_future() + self._pending_pings[payload] = (waiter, monotonic()) + self.transport.send_ping(payload) + return waiter + + async def pong(self, data: Union[str, bytes] = b"") -> None: + if self._state is State.CLOSED: + raise self._connection_closed() + payload = data.encode("utf-8") if isinstance(data, str) else data + self.transport.send_pong(payload) + + async def _keepalive_loop(self) -> None: + try: + while True: + assert self._ping_interval is not None + await asyncio.sleep(self._ping_interval) + waiter = await self.ping() + if self._ping_timeout is None: + continue + await asyncio.wait_for(waiter, self._ping_timeout) + except asyncio.CancelledError: + raise + except Exception: + if self.state is not State.CLOSED: + await self.close(code=1011, reason="keepalive ping timeout") + + async def __aenter__(self) -> ClientConnection: + return self + + async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None: + await self.close() + + def __aiter__(self) -> AsyncIterator[Union[str, bytes]]: + return self._iterate_messages() + + async def _iterate_messages(self) -> AsyncIterator[Union[str, bytes]]: + while True: + try: + yield await self.recv() + except ConnectionClosedOK: + return + + @property + def state(self) -> State: + return self._state + + @property + def local_address(self) -> Any: + return self.transport.underlying_transport.get_extra_info("sockname") + + @property + def remote_address(self) -> Any: + return self.transport.underlying_transport.get_extra_info("peername") + + @property + def latency(self) -> float: + return self._latency + + @property + def subprotocol(self) -> Optional[str]: + return self._subprotocol + + @property + def close_code(self) -> Optional[int]: + handshake = self.transport.close_handshake + if handshake is None: + return None + if handshake.recv is not None: + return _coerce_close_code(handshake.recv.code) + if handshake.sent is not None: + return _coerce_close_code(handshake.sent.code) + return None + + @property + def close_reason(self) -> Optional[str]: + handshake = self.transport.close_handshake + if handshake is None: + return None + if handshake.recv is not None: + return _coerce_close_reason(handshake.recv.reason) + if handshake.sent is not None: + return _coerce_close_reason(handshake.sent.reason) + return None diff --git a/setup.py b/setup.py index 8b7ec3c..b23c5b6 100644 --- a/setup.py +++ b/setup.py @@ -133,7 +133,7 @@ def build_extension(self, ext: Extension): depends=["picows/compat.h"], extra_compile_args=extra_compile_args, extra_link_args=extra_link_args), - Extension("picows.websockets.asyncio.client", ["picows/websockets/asyncio/client.py"], + Extension("picows.websockets.asyncio.connection", ["picows/websockets/asyncio/connection.py"], libraries=libs, define_macros=macros, depends=["picows/compat.h"], extra_compile_args=extra_compile_args, From a9de29700a88d2656b64d7742bdacf79274cee83 Mon Sep 17 00:00:00 2001 From: taras Date: Sun, 3 May 2026 02:36:50 +0200 Subject: [PATCH 09/57] Cleanup --- picows/websockets/asyncio/connection.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/picows/websockets/asyncio/connection.py b/picows/websockets/asyncio/connection.py index 30a1c9d..dba1fa6 100644 --- a/picows/websockets/asyncio/connection.py +++ b/picows/websockets/asyncio/connection.py @@ -537,14 +537,12 @@ async def send( self.transport.send(msg_type, message) if self._write_ready is not None: await self._write_ready - return elif isinstance(message, AsyncIterable): await self._send_fragments(True, message.__aiter__(), text) - return elif isinstance(message, Iterable): await self._send_fragments(False, iter(message), text) - return - raise TypeError(f"message has unsupported type {type(message).__name__}") + else: + raise TypeError(f"message has unsupported type {type(message).__name__}") finally: self._send_lock.release() From 7d0d2a9144a7ce54888ed858dba7f72b38bd6357 Mon Sep 17 00:00:00 2001 From: taras Date: Sun, 3 May 2026 03:08:30 +0200 Subject: [PATCH 10/57] Refactoring --- picows/websockets/asyncio/client.py | 19 ++-- picows/websockets/asyncio/connection.py | 113 +++++++++++++----------- 2 files changed, 66 insertions(+), 66 deletions(-) diff --git a/picows/websockets/asyncio/client.py b/picows/websockets/asyncio/client.py index edb5f1a..54b02a8 100644 --- a/picows/websockets/asyncio/client.py +++ b/picows/websockets/asyncio/client.py @@ -71,7 +71,7 @@ def __init__( ping_interval: Optional[float] = 20, ping_timeout: Optional[float] = 20, close_timeout: Optional[float] = 10, - max_size: Union[int, tuple[Optional[int], Optional[int]], None] = 1024 * 1024, + max_size: Optional[int] = 1024 * 1024, max_queue: Union[int, tuple[Optional[int], Optional[int]], None] = 16, write_limit: Union[int, tuple[int, Optional[int]]] = 32768, logger: LoggerLike = None, @@ -137,7 +137,7 @@ async def _connect(self) -> ClientConnection: parsed = parse_url(self.uri) proxy = _process_proxy(self.proxy, parsed.is_secure) extra_headers = self._build_headers() - max_message_size, max_fragment_size = self._normalize_max_size(self.max_size) + max_message_size = self._normalize_max_size(self.max_size) if self.extensions is not None: raise NotImplementedError("custom extensions aren't supported by picows.websockets") @@ -190,7 +190,6 @@ def listener_factory() -> ClientConnection: max_queue=self.max_queue, write_limit=self.write_limit, max_message_size=max_message_size, - max_fragment_size=max_fragment_size, logger=self.logger, subprotocols=self.subprotocols, ) @@ -203,7 +202,7 @@ def listener_factory() -> ClientConnection: websocket_handshake_timeout=self.open_timeout, enable_auto_ping=False, enable_auto_pong=False, - max_frame_size=max_fragment_size if max_fragment_size > 0 else 2 ** 31 - 1, + max_frame_size=2 ** 31 - 1, extra_headers=extra_headers, proxy=proxy, socket_factory=socket_factory, @@ -225,16 +224,8 @@ def listener_factory() -> ClientConnection: return cast(ClientConnection, listener) - def _normalize_max_size( - self, - max_size: Union[int, tuple[Optional[int], Optional[int]], None], - ) -> tuple[int, int]: - if max_size is None: - return 0, 0 - if isinstance(max_size, tuple): - max_message_size, max_fragment_size = max_size - return _normalize_size_limit(max_message_size), _normalize_size_limit(max_fragment_size) - return max_size, max_size + def _normalize_max_size(self, max_size: Optional[int]) -> int: + return _normalize_size_limit(max_size) def _build_headers(self) -> list[tuple[str, str]]: headers = _header_items(self.additional_headers) diff --git a/picows/websockets/asyncio/connection.py b/picows/websockets/asyncio/connection.py index dba1fa6..ea288cc 100644 --- a/picows/websockets/asyncio/connection.py +++ b/picows/websockets/asyncio/connection.py @@ -36,8 +36,6 @@ InvalidHandshake, InvalidStatus, InvalidURI, - PayloadTooBig, - ProtocolError, ) @@ -233,9 +231,10 @@ class ClientConnection(WSListener): _keepalive_task: Optional[asyncio.Task[None]] _latency: cython.double _max_message_size: cython.Py_ssize_t - _max_fragment_size: cython.Py_ssize_t _max_queue_high: cython.Py_ssize_t _max_queue_low: cython.Py_ssize_t + _incoming_message_active: cython.bint + _incoming_message_size: cython.Py_ssize_t _write_limit: Union[int, tuple[int, Optional[int]]] _paused_reading: cython.bint @@ -248,7 +247,6 @@ def __init__( max_queue: Union[int, tuple[Optional[int], Optional[int]], None] = 16, write_limit: Union[int, tuple[int, Optional[int]]] = 32768, max_message_size: Optional[int] = 1024 * 1024, - max_fragment_size: Optional[int] = 1024 * 1024, logger: LoggerLike = None, subprotocols: Optional[Sequence[str]] = None, ): @@ -275,8 +273,9 @@ def __init__( self._keepalive_task: Optional[asyncio.Task[None]] = None self._latency = 0.0 self._max_message_size = _normalize_size_limit(max_message_size) - self._max_fragment_size = _normalize_size_limit(max_fragment_size) self._max_queue_high, self._max_queue_low = _normalize_watermarks(max_queue) + self._incoming_message_active = False + self._incoming_message_size = 0 self._write_limit = write_limit self._paused_reading = False @@ -335,8 +334,31 @@ def on_ws_frame(self, transport: WSTransport, frame: WSFrame) -> None: self._state = State.CLOSING return - if self._max_fragment_size > 0 and len(payload) > self._max_fragment_size: - self._fail_message_too_big("fragment too big") + if frame.msg_type == WSMsgType.CONTINUATION: + if not self._incoming_message_active: + self._fail_protocol_error("unexpected continuation frame") + return + self._incoming_message_size += len(payload) + if self._max_message_size > 0 and self._incoming_message_size > self._max_message_size: + self._fail_message_too_big("message too big") + return + if frame.fin: + self._incoming_message_active = False + self._incoming_message_size = 0 + elif frame.msg_type in (WSMsgType.TEXT, WSMsgType.BINARY): + if self._incoming_message_active: + self._fail_protocol_error("expected continuation frame") + return + self._incoming_message_size = len(payload) + if self._max_message_size > 0 and self._incoming_message_size > self._max_message_size: + self._fail_message_too_big("message too big") + return + if frame.fin: + self._incoming_message_size = 0 + else: + self._incoming_message_active = True + else: + self._fail_protocol_error(f"unexpected opcode while receiving message: {frame.msg_type}") return self._frames.put_nowait(_BufferedFrame(frame.msg_type, payload, frame.fin)) @@ -398,12 +420,6 @@ def _set_close_exception(self) -> None: exc_type = ConnectionClosedOK if ok else ConnectionClosedError self._close_exc = exc_type(rcvd, sent, rcvd_then_sent) - @cython.cfunc - @cython.inline - def _fail_message_too_big(self, message: str) -> None: - self.transport.send_close(WSCloseCode.MESSAGE_TOO_BIG, message.encode("utf-8")) - self.transport.disconnect(False) - @cython.cfunc @cython.inline def _connection_closed(self) -> ConnectionClosed: @@ -411,6 +427,18 @@ def _connection_closed(self) -> ConnectionClosed: self._set_close_exception() return self._close_exc or ConnectionClosedError(None, None, None) + @cython.cfunc + @cython.inline + def _fail_protocol_error(self, message: str) -> None: + self.transport.send_close(WSCloseCode.PROTOCOL_ERROR, message.encode("utf-8")) + self.transport.disconnect(False) + + @cython.cfunc + @cython.inline + def _fail_message_too_big(self, message: str) -> None: + self.transport.send_close(WSCloseCode.MESSAGE_TOO_BIG, message.encode("utf-8")) + self.transport.disconnect(False) + @cython.cfunc @cython.inline def _set_recv_in_progress(self) -> None: @@ -430,44 +458,32 @@ def _decode_data(self, payload: bytes, msg_type: WSMsgType, decode: Optional[boo @cython.cfunc @cython.inline - def _check_frame(self, frame: Optional[_BufferedFrame], is_first: cython.bint) -> None: + def _check_frame(self, frame: Optional[_BufferedFrame]) -> None: self._resume_reading_if_needed() if frame is None: raise self._connection_closed() - if is_first: - if frame.msg_type not in (WSMsgType.TEXT, WSMsgType.BINARY): - raise ProtocolError(f"unexpected opcode while receiving message: {frame.msg_type}") - else: - if frame.msg_type != WSMsgType.CONTINUATION: - raise ProtocolError("expected continuation frame") - async def recv(self, decode: Optional[bool] = None) -> Union[str, bytes]: frame: Optional[_BufferedFrame] - total: cython.Py_ssize_t self._set_recv_in_progress() try: frame = await self._frames.get() - self._check_frame(frame, True) + self._check_frame(frame) + frame = cast(_BufferedFrame, frame) msg_type = frame.msg_type if frame.fin: return self._decode_data(frame.payload, msg_type, decode) chunks = [frame.payload] - total = len(frame.payload) while not frame.fin: frame = await self._frames.get() - self._check_frame(frame, False) + self._check_frame(frame) chunks.append(frame.payload) - total += len(frame.payload) - if self._max_message_size > 0 and total > self._max_message_size: - self._fail_message_too_big("message too big") - raise PayloadTooBig("message too big") payload = b"".join(chunks) return self._decode_data(payload, msg_type, decode) @@ -477,40 +493,33 @@ async def recv(self, decode: Optional[bool] = None) -> Union[str, bytes]: def recv_streaming(self, decode: Optional[bool] = None) -> AsyncIterator[Union[str, bytes]]: self._set_recv_in_progress() - started: cython.bint = False - finished: cython.bint = False + msg_started: cython.bint = False + msg_finished: cython.bint = False async def iterator() -> AsyncIterator[Union[str, bytes]]: - nonlocal started, finished - msg_type: WSMsgType - first: Optional[_BufferedFrame] + nonlocal msg_started, msg_finished frame: Optional[_BufferedFrame] - total: cython.Py_ssize_t + msg_type: WSMsgType try: - first = await self._frames.get() - self._check_frame(first, True) - - msg_type = first.msg_type - started = True - yield self._decode_data(first.payload, msg_type, decode) - total = len(first.payload) - frame = first + frame = await self._frames.get() + self._check_frame(frame) + frame = cast(_BufferedFrame, frame) + msg_started = True + msg_type = frame.msg_type + yield self._decode_data(frame.payload, msg_type, decode) + while not frame.fin: frame = await self._frames.get() - self._check_frame(frame, False) - - total += len(frame.payload) - if self._max_message_size > 0 and total > self._max_message_size: - self._fail_message_too_big("message too big") - raise PayloadTooBig("message too big") + self._check_frame(frame) + frame = cast(_BufferedFrame, frame) yield self._decode_data(frame.payload, msg_type, decode) - finished = True + msg_finished = True finally: self._recv_in_progress = False - if started and not finished: + if msg_started and not msg_finished: self._recv_streaming_broken = True - elif finished: + elif msg_finished: self._recv_streaming_broken = False return iterator() From 5a47d899720d0117ef7c897ba9b9d96d8e11bcb0 Mon Sep 17 00:00:00 2001 From: taras Date: Sun, 3 May 2026 03:26:50 +0200 Subject: [PATCH 11/57] More refactoring --- picows/picows.pyx | 3 ++- picows/websockets/asyncio/client.py | 8 +++----- picows/websockets/asyncio/connection.py | 8 +------- 3 files changed, 6 insertions(+), 13 deletions(-) diff --git a/picows/picows.pyx b/picows/picows.pyx index 980b7db..2efcb80 100644 --- a/picows/picows.pyx +++ b/picows/picows.pyx @@ -1632,7 +1632,8 @@ cdef class WSProtocol(WSProtocolBase, asyncio.BufferedProtocol): self._f_payload_start_pos = self._f_curr_state_start_pos self._state = WSParserState.READ_PAYLOAD - if self._f_payload_length > self._max_frame_size: + if (self._f_payload_length > self._max_frame_size and + self._f_msg_type not in (WSMsgType.PING, WSMsgType.PONG, WSMsgType.CLOSE)): raise WSProtocolError( WSCloseCode.MESSAGE_TOO_BIG, f"Received frame with payload size exceeding max allowed size, " diff --git a/picows/websockets/asyncio/client.py b/picows/websockets/asyncio/client.py index 54b02a8..8b3e2b5 100644 --- a/picows/websockets/asyncio/client.py +++ b/picows/websockets/asyncio/client.py @@ -137,7 +137,8 @@ async def _connect(self) -> ClientConnection: parsed = parse_url(self.uri) proxy = _process_proxy(self.proxy, parsed.is_secure) extra_headers = self._build_headers() - max_message_size = self._normalize_max_size(self.max_size) + max_message_size = 0 if self.max_size is None else self.max_size + max_frame_size = 2 ** 31 - 1 if not self.max_size else self.max_size if self.extensions is not None: raise NotImplementedError("custom extensions aren't supported by picows.websockets") @@ -202,7 +203,7 @@ def listener_factory() -> ClientConnection: websocket_handshake_timeout=self.open_timeout, enable_auto_ping=False, enable_auto_pong=False, - max_frame_size=2 ** 31 - 1, + max_frame_size=max_frame_size, extra_headers=extra_headers, proxy=proxy, socket_factory=socket_factory, @@ -224,9 +225,6 @@ def listener_factory() -> ClientConnection: return cast(ClientConnection, listener) - def _normalize_max_size(self, max_size: Optional[int]) -> int: - return _normalize_size_limit(max_size) - def _build_headers(self) -> list[tuple[str, str]]: headers = _header_items(self.additional_headers) if self.origin is not None: diff --git a/picows/websockets/asyncio/connection.py b/picows/websockets/asyncio/connection.py index ea288cc..7109726 100644 --- a/picows/websockets/asyncio/connection.py +++ b/picows/websockets/asyncio/connection.py @@ -164,12 +164,6 @@ def _process_proxy(proxy: Union[str, bool, None], secure: bool) -> Optional[str] raise InvalidURI(str(proxy), "proxy must be None, True, or a proxy URL") -@cython.cfunc -@cython.inline -def _normalize_size_limit(limit: Optional[int]) -> cython.Py_ssize_t: - return 0 if limit is None else limit - - @cython.cfunc @cython.inline def _normalize_watermarks( @@ -272,7 +266,7 @@ def __init__( self._close_timeout = close_timeout self._keepalive_task: Optional[asyncio.Task[None]] = None self._latency = 0.0 - self._max_message_size = _normalize_size_limit(max_message_size) + self._max_message_size = 0 if max_message_size is None else max_message_size self._max_queue_high, self._max_queue_low = _normalize_watermarks(max_queue) self._incoming_message_active = False self._incoming_message_size = 0 From 9c772562ee8a165149afdcd69e5c1c26b9b14304 Mon Sep 17 00:00:00 2001 From: taras Date: Sun, 3 May 2026 03:33:10 +0200 Subject: [PATCH 12/57] Cleanup --- picows/websockets/asyncio/client.py | 2 +- picows/websockets/asyncio/connection.py | 4 ---- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/picows/websockets/asyncio/client.py b/picows/websockets/asyncio/client.py index 8b3e2b5..ec2b92b 100644 --- a/picows/websockets/asyncio/client.py +++ b/picows/websockets/asyncio/client.py @@ -202,7 +202,7 @@ def listener_factory() -> ClientConnection: ssl_context=self._coerce_ssl_context(ssl_context), websocket_handshake_timeout=self.open_timeout, enable_auto_ping=False, - enable_auto_pong=False, + enable_auto_pong=True, max_frame_size=max_frame_size, extra_headers=extra_headers, proxy=proxy, diff --git a/picows/websockets/asyncio/connection.py b/picows/websockets/asyncio/connection.py index 7109726..80e181d 100644 --- a/picows/websockets/asyncio/connection.py +++ b/picows/websockets/asyncio/connection.py @@ -307,10 +307,6 @@ def on_ws_disconnected(self, transport: WSTransport) -> None: @cython.ccall def on_ws_frame(self, transport: WSTransport, frame: WSFrame) -> None: payload = frame.get_payload_as_bytes() - if frame.msg_type == WSMsgType.PING: - self.transport.send_pong(payload) - return - if frame.msg_type == WSMsgType.PONG: ping = self._pending_pings.pop(payload, None) if ping is not None: From 540ce33002eb5a1593983c29b92bb3803a62d3a7 Mon Sep 17 00:00:00 2001 From: taras Date: Sun, 3 May 2026 03:55:55 +0200 Subject: [PATCH 13/57] Simplify --- docs/source/reference.rst | 2 -- picows/picows.pxd | 2 +- picows/websockets/asyncio/connection.py | 15 ++++++--------- 3 files changed, 7 insertions(+), 12 deletions(-) diff --git a/docs/source/reference.rst b/docs/source/reference.rst index 237c98a..e11d85f 100644 --- a/docs/source/reference.rst +++ b/docs/source/reference.rst @@ -146,8 +146,6 @@ Classes .. py:attribute:: payload_size :type: size_t - **Available only from Cython.** - Size of the payload. .. autoclass:: WSUpgradeRequest diff --git a/picows/picows.pxd b/picows/picows.pxd index 93510f6..dbed30b 100644 --- a/picows/picows.pxd +++ b/picows/picows.pxd @@ -67,7 +67,7 @@ cdef class MemoryBuffer: cdef class WSFrame: cdef: char* payload_ptr - Py_ssize_t payload_size + readonly Py_ssize_t payload_size readonly Py_ssize_t tail_size readonly WSMsgType msg_type readonly uint8_t fin diff --git a/picows/websockets/asyncio/connection.py b/picows/websockets/asyncio/connection.py index 80e181d..23e8b11 100644 --- a/picows/websockets/asyncio/connection.py +++ b/picows/websockets/asyncio/connection.py @@ -210,7 +210,6 @@ class ClientConnection(WSListener): _subprotocols: Optional[Sequence[str]] _subprotocol: Optional[str] _state: State - _closed_event: asyncio.Event _frames: asyncio.Queue[Optional[_BufferedFrame]] _close_exc: Optional[ConnectionClosed] _loop: asyncio.AbstractEventLoop @@ -252,7 +251,6 @@ def __init__( self._subprotocols = subprotocols self._subprotocol = cast(Optional[str], None) self._state = State.CONNECTING - self._closed_event = asyncio.Event() self._frames: asyncio.Queue[Optional[_BufferedFrame]] = asyncio.Queue() self._close_exc: Optional[ConnectionClosed] = None self._loop = asyncio.get_running_loop() @@ -289,7 +287,6 @@ def on_ws_disconnected(self, transport: WSTransport) -> None: self._state = State.CLOSED self._set_close_exception() self._frames.put_nowait(None) - self._closed_event.set() if self._keepalive_task is not None: self._keepalive_task.cancel() self._keepalive_task = None @@ -306,9 +303,8 @@ def on_ws_disconnected(self, transport: WSTransport) -> None: @cython.ccall def on_ws_frame(self, transport: WSTransport, frame: WSFrame) -> None: - payload = frame.get_payload_as_bytes() if frame.msg_type == WSMsgType.PONG: - ping = self._pending_pings.pop(payload, None) + ping = self._pending_pings.pop(frame.get_payload_as_bytes(), None) if ping is not None: waiter, sent_at = ping self._latency = monotonic() - sent_at @@ -328,7 +324,8 @@ def on_ws_frame(self, transport: WSTransport, frame: WSFrame) -> None: if not self._incoming_message_active: self._fail_protocol_error("unexpected continuation frame") return - self._incoming_message_size += len(payload) + + self._incoming_message_size += frame.payload_size if self._max_message_size > 0 and self._incoming_message_size > self._max_message_size: self._fail_message_too_big("message too big") return @@ -339,7 +336,7 @@ def on_ws_frame(self, transport: WSTransport, frame: WSFrame) -> None: if self._incoming_message_active: self._fail_protocol_error("expected continuation frame") return - self._incoming_message_size = len(payload) + self._incoming_message_size = frame.payload_size if self._max_message_size > 0 and self._incoming_message_size > self._max_message_size: self._fail_message_too_big("message too big") return @@ -351,7 +348,7 @@ def on_ws_frame(self, transport: WSTransport, frame: WSFrame) -> None: self._fail_protocol_error(f"unexpected opcode while receiving message: {frame.msg_type}") return - self._frames.put_nowait(_BufferedFrame(frame.msg_type, payload, frame.fin)) + self._frames.put_nowait(_BufferedFrame(frame.msg_type, frame.get_payload_as_bytes(), frame.fin)) self._pause_reading_if_needed() @cython.ccall @@ -614,7 +611,7 @@ async def close(self, code: int = 1000, reason: str = "") -> None: await self.wait_closed() async def wait_closed(self) -> None: - await self._closed_event.wait() + await self.transport.wait_disconnected() async def ping(self, data: Optional[Union[str, bytes]] = None) -> Awaitable[float]: if self._state is State.CLOSED: From 6af773a1fcb6a180bc4b97f770f4cafbd2d55bd6 Mon Sep 17 00:00:00 2001 From: taras Date: Sun, 3 May 2026 16:46:49 +0200 Subject: [PATCH 14/57] Better compatibility with websockets --- AGENTS.md | 4 + picows/websockets/__init__.py | 50 +++++++ picows/websockets/asyncio/client.py | 15 ++- picows/websockets/asyncio/connection.py | 165 ++++++++++++++---------- picows/websockets/compat.py | 13 ++ picows/websockets/exceptions.py | 117 ++++++++++++++++- picows/websockets/typing.py | 31 +++++ 7 files changed, 320 insertions(+), 75 deletions(-) create mode 100644 picows/websockets/compat.py create mode 100644 picows/websockets/typing.py diff --git a/AGENTS.md b/AGENTS.md index 1f07284..11908d4 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -24,6 +24,10 @@ examples - Various examples for users on how to use picows + perf_test that coul If the same conversion, check, or tiny code pattern appears in multiple sibling paths after a refactor, stop and normalize it before considering the work done. Do not remove one layer of abstraction only to inline the same logic redundantly in several places. After a refactor, scan for duplicated branch bodies and duplicated type-specific handling introduced by the change. +- `picows.websockets` aims for import-level compatibility with the official `websockets` package on the client side. + We can skip complicated areas such as the full server interface, but simple surface-area compatibility matters. + Type definitions, exception definitions, and other lightweight importable names should exist when upstream exposes them. + People switching from `websockets` to `picows.websockets` should notice as little difference as possible. ## Testing instructions - Run lint after updating code with: diff --git a/picows/websockets/__init__.py b/picows/websockets/__init__.py index 7ece134..0fb5b48 100644 --- a/picows/websockets/__init__.py +++ b/picows/websockets/__init__.py @@ -1,39 +1,89 @@ from . import exceptions from .asyncio.client import connect from .asyncio.connection import ClientConnection, State, process_exception +from .compat import CloseCode, Request, Response from .exceptions import ( ConcurrencyError, ConnectionClosed, ConnectionClosedError, ConnectionClosedOK, + DuplicateParameter, InvalidHandshake, InvalidHeader, + InvalidHeaderFormat, + InvalidHeaderValue, InvalidMessage, + InvalidOrigin, + InvalidParameterName, + InvalidParameterValue, + InvalidProxy, + InvalidProxyMessage, + InvalidProxyStatus, InvalidState, InvalidStatus, InvalidUpgrade, InvalidURI, + NegotiationError, PayloadTooBig, ProtocolError, + ProxyError, + SecurityError, WebSocketException, ) +from .typing import ( + BytesLike, + Data, + DataLike, + ExtensionName, + ExtensionParameter, + HeadersLike, + LoggerLike, + Origin, + StatusLike, + Subprotocol, +) __all__ = [ + "BytesLike", "ClientConnection", + "CloseCode", + "Data", + "DataLike", "ConcurrencyError", "ConnectionClosed", "ConnectionClosedError", "ConnectionClosedOK", + "DuplicateParameter", + "ExtensionName", + "ExtensionParameter", + "HeadersLike", "InvalidHandshake", "InvalidHeader", + "InvalidHeaderFormat", + "InvalidHeaderValue", "InvalidMessage", + "InvalidOrigin", + "InvalidParameterName", + "InvalidParameterValue", + "InvalidProxy", + "InvalidProxyMessage", + "InvalidProxyStatus", "InvalidState", "InvalidStatus", "InvalidUpgrade", "InvalidURI", + "LoggerLike", + "NegotiationError", + "Origin", "PayloadTooBig", "ProtocolError", + "ProxyError", + "Request", + "Response", + "SecurityError", "State", + "StatusLike", + "Subprotocol", "WebSocketException", "connect", "exceptions", diff --git a/picows/websockets/asyncio/client.py b/picows/websockets/asyncio/client.py index ec2b92b..fc08061 100644 --- a/picows/websockets/asyncio/client.py +++ b/picows/websockets/asyncio/client.py @@ -12,18 +12,23 @@ from .connection import ( ClientConnection, - HeadersLike, - LoggerLike, process_exception, ) from ..exceptions import ( InvalidHandshake, InvalidHeader, InvalidMessage, + InvalidProxy, InvalidStatus, InvalidUpgrade, InvalidURI, ) +from ..typing import HeadersLike, LoggerLike, Origin, Subprotocol + +__all__ = [ + "ClientConnection", + "connect", +] def _default_user_agent() -> str: @@ -47,7 +52,7 @@ def _process_proxy(proxy: Union[str, bool, None], secure: bool) -> Optional[str] proxies.get("wss" if secure else "ws") or proxies.get("https" if secure else "http") ) - raise InvalidURI(str(proxy), "proxy must be None, True, or a proxy URL") + raise InvalidProxy(str(proxy), "proxy must be None, True, or a proxy URL") def _normalize_size_limit(limit: Optional[int]) -> int: @@ -59,9 +64,9 @@ def __init__( self, uri: str, *, - origin: Optional[str] = None, + origin: Optional[Origin] = None, extensions: Optional[Sequence[Any]] = None, - subprotocols: Optional[Sequence[str]] = None, + subprotocols: Optional[Sequence[Subprotocol]] = None, compression: Optional[str] = "deflate", additional_headers: Optional[HeadersLike] = None, user_agent_header: Optional[str] = _default_user_agent(), diff --git a/picows/websockets/asyncio/connection.py b/picows/websockets/asyncio/connection.py index 23e8b11..5b7616d 100644 --- a/picows/websockets/asyncio/connection.py +++ b/picows/websockets/asyncio/connection.py @@ -1,20 +1,15 @@ from __future__ import annotations import asyncio -import sys import logging import os -import socket import uuid -import warnings from collections import deque from collections.abc import AsyncIterable, Generator, Iterable from enum import IntEnum -from ssl import SSLContext from time import monotonic -from typing import Any, AsyncIterator, Awaitable, Callable, Optional, Sequence, \ +from typing import Any, AsyncIterator, Awaitable, Optional, Sequence, \ Union, cast, Dict, Tuple, Iterator -from urllib.request import getproxies import cython @@ -26,8 +21,8 @@ import picows -from picows.types import WSHeadersLike +from ..compat import CloseCode, Request, Response from ..exceptions import ( ConcurrencyError, ConnectionClosed, @@ -35,17 +30,12 @@ ConnectionClosedOK, InvalidHandshake, InvalidStatus, - InvalidURI, ) - - -DataLike = Union[str, bytes, bytearray, memoryview] -HeadersLike = WSHeadersLike -CloseCodeT = int -LoggerLike = Union[str, logging.Logger, logging.LoggerAdapter[Any], None] +from ..typing import Data, DataLike, LoggerLike, Subprotocol OK_CLOSE_CODES = {0, 1000, 1001} +_QUEUE_EMPTY = object() class State(IntEnum): @@ -111,9 +101,58 @@ def release(self) -> None: self._locked = False +@cython.cclass +class _SingleConsumerQueue: + _loop: asyncio.AbstractEventLoop + _items: deque + _waiter: Optional[asyncio.Future[Optional[_BufferedFrame]]] + + def __init__(self, loop: asyncio.AbstractEventLoop): + self._loop = loop + self._items = deque() + self._waiter = None + + @cython.cfunc + @cython.inline + def put(self, item: Optional[_BufferedFrame]) -> None: + waiter = self._waiter + if waiter is not None: + self._waiter = None + if not waiter.done(): + waiter.set_result(item) + return + self._items.append(item) + + @cython.cfunc + @cython.inline + def get_nowait(self) -> object: + if self._items: + return self._items.popleft() + return _QUEUE_EMPTY + + async def get(self) -> Optional[_BufferedFrame]: + item = self.get_nowait() + if item is not _QUEUE_EMPTY: + return item + + waiter: asyncio.Future[Optional[_BufferedFrame]] = self._loop.create_future() + self._waiter = waiter + try: + return await waiter + except Exception: + if self._waiter is waiter: + self._waiter = None + raise + + @cython.cfunc + @cython.inline + def qsize(self) -> cython.Py_ssize_t: + return len(self._items) + + @cython.cfunc @cython.inline -def _coerce_close_code(code: WSCloseCode) -> Optional[int]: +def _coerce_close_code(code: CloseCode) -> Optional[int]: return None if code is None else int(code) @@ -125,13 +164,10 @@ def _coerce_close_reason(reason: Optional[str]) -> Optional[str]: @cython.cfunc @cython.inline -def _header_items(headers: Any) -> list[tuple[str, str]]: - return [] if headers is None else list(headers.items()) - - -@cython.cfunc -@cython.inline -def _resolve_subprotocol(subprotocols: Optional[Sequence[str]], response: Any) -> Optional[str]: +def _resolve_subprotocol( + subprotocols: Optional[Sequence[Subprotocol]], + response: Any, +) -> Optional[Subprotocol]: if response is None: return None value = response.headers.get("Sec-WebSocket-Protocol") @@ -139,29 +175,7 @@ def _resolve_subprotocol(subprotocols: Optional[Sequence[str]], response: Any) - return None if subprotocols is not None and value not in subprotocols: raise InvalidHandshake(f"unsupported subprotocol negotiated by server: {value}") - return cast(str, value) - - -@cython.cfunc -@cython.inline -def _default_user_agent() -> str: - return f"Python/{sys.version_info.major}.{sys.version_info.minor} picows-websockets/0" - - -@cython.cfunc -@cython.inline -def _process_proxy(proxy: Union[str, bool, None], secure: bool) -> Optional[str]: - if proxy is None: - return None - if isinstance(proxy, str): - return proxy - if proxy is True: - proxies = getproxies() - return ( - proxies.get("wss" if secure else "ws") - or proxies.get("https" if secure else "http") - ) - raise InvalidURI(str(proxy), "proxy must be None, True, or a proxy URL") + return cast(Subprotocol, value) @cython.cfunc @@ -205,12 +219,12 @@ class ClientConnection(WSListener): id: uuid.UUID logger: Union[logging.Logger, logging.LoggerAdapter[Any]] transport: WSTransport - request: picows.WSUpgradeRequest - response: picows.WSUpgradeResponse - _subprotocols: Optional[Sequence[str]] - _subprotocol: Optional[str] + request: Request + response: Response + _subprotocols: Optional[Sequence[Subprotocol]] + _subprotocol: Optional[Subprotocol] _state: State - _frames: asyncio.Queue[Optional[_BufferedFrame]] + _frames: _SingleConsumerQueue _close_exc: Optional[ConnectionClosed] _loop: asyncio.AbstractEventLoop _recv_in_progress: cython.bint @@ -241,19 +255,19 @@ def __init__( write_limit: Union[int, tuple[int, Optional[int]]] = 32768, max_message_size: Optional[int] = 1024 * 1024, logger: LoggerLike = None, - subprotocols: Optional[Sequence[str]] = None, + subprotocols: Optional[Sequence[Subprotocol]] = None, ): self.id = uuid.uuid4() self.logger = _resolve_logger(logger) self.transport = cython.cast(WSTransport, None) - self.request = cast(picows.WSUpgradeRequest, None) - self.response = cast(picows.WSUpgradeResponse, None) + self.request = cast(Request, None) + self.response = cast(Response, None) self._subprotocols = subprotocols - self._subprotocol = cast(Optional[str], None) + self._subprotocol = cast(Optional[Subprotocol], None) self._state = State.CONNECTING - self._frames: asyncio.Queue[Optional[_BufferedFrame]] = asyncio.Queue() self._close_exc: Optional[ConnectionClosed] = None self._loop = asyncio.get_running_loop() + self._frames = _SingleConsumerQueue(self._loop) self._recv_in_progress = False self._send_lock = _AsyncLock(self._loop) self._write_ready: Optional[asyncio.Future[None]] = None @@ -286,7 +300,7 @@ def on_ws_connected(self, transport: WSTransport) -> None: def on_ws_disconnected(self, transport: WSTransport) -> None: self._state = State.CLOSED self._set_close_exception() - self._frames.put_nowait(None) + self._frames.put(None) if self._keepalive_task is not None: self._keepalive_task.cancel() self._keepalive_task = None @@ -348,7 +362,7 @@ def on_ws_frame(self, transport: WSTransport, frame: WSFrame) -> None: self._fail_protocol_error(f"unexpected opcode while receiving message: {frame.msg_type}") return - self._frames.put_nowait(_BufferedFrame(frame.msg_type, frame.get_payload_as_bytes(), frame.fin)) + self._frames.put(_BufferedFrame(frame.msg_type, frame.get_payload_as_bytes(), frame.fin)) self._pause_reading_if_needed() @cython.ccall @@ -437,7 +451,7 @@ def _set_recv_in_progress(self) -> None: @cython.cfunc @cython.inline - def _decode_data(self, payload: bytes, msg_type: WSMsgType, decode: Optional[bool]) -> Union[str, bytes]: + def _decode_data(self, payload: bytes, msg_type: WSMsgType, decode: Optional[bool]) -> Data: if decode is True or (msg_type == WSMsgType.TEXT and decode is None): return payload.decode("utf-8") else: @@ -451,15 +465,23 @@ def _check_frame(self, frame: Optional[_BufferedFrame]) -> None: if frame is None: raise self._connection_closed() - async def recv(self, decode: Optional[bool] = None) -> Union[str, bytes]: + @cython.cfunc + @cython.inline + def _get_frame_nowait(self) -> object: + return self._frames.get_nowait() + + async def recv(self, decode: Optional[bool] = None) -> Data: frame: Optional[_BufferedFrame] self._set_recv_in_progress() try: - frame = await self._frames.get() + item = self._get_frame_nowait() + if item is _QUEUE_EMPTY: + item = await self._frames.get() + + frame = cython.cast(_BufferedFrame, item) self._check_frame(frame) - frame = cast(_BufferedFrame, frame) msg_type = frame.msg_type if frame.fin: @@ -467,8 +489,11 @@ async def recv(self, decode: Optional[bool] = None) -> Union[str, bytes]: chunks = [frame.payload] while not frame.fin: - frame = await self._frames.get() + frame = self._get_frame_nowait() + if frame is _QUEUE_EMPTY: + frame = await self._frames.get() self._check_frame(frame) + frame = cast(_BufferedFrame, frame) chunks.append(frame.payload) @@ -477,19 +502,21 @@ async def recv(self, decode: Optional[bool] = None) -> Union[str, bytes]: finally: self._recv_in_progress = False - def recv_streaming(self, decode: Optional[bool] = None) -> AsyncIterator[Union[str, bytes]]: + def recv_streaming(self, decode: Optional[bool] = None) -> AsyncIterator[Data]: self._set_recv_in_progress() msg_started: cython.bint = False msg_finished: cython.bint = False - async def iterator() -> AsyncIterator[Union[str, bytes]]: + async def iterator() -> AsyncIterator[Data]: nonlocal msg_started, msg_finished frame: Optional[_BufferedFrame] msg_type: WSMsgType try: - frame = await self._frames.get() + frame = self._get_frame_nowait() + if frame is _QUEUE_EMPTY: + frame = await self._frames.get() self._check_frame(frame) frame = cast(_BufferedFrame, frame) msg_started = True @@ -497,7 +524,9 @@ async def iterator() -> AsyncIterator[Union[str, bytes]]: yield self._decode_data(frame.payload, msg_type, decode) while not frame.fin: - frame = await self._frames.get() + frame = self._get_frame_nowait() + if frame is _QUEUE_EMPTY: + frame = await self._frames.get() self._check_frame(frame) frame = cast(_BufferedFrame, frame) yield self._decode_data(frame.payload, msg_type, decode) @@ -666,7 +695,7 @@ async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None: def __aiter__(self) -> AsyncIterator[Union[str, bytes]]: return self._iterate_messages() - async def _iterate_messages(self) -> AsyncIterator[Union[str, bytes]]: + async def _iterate_messages(self) -> AsyncIterator[Data]: while True: try: yield await self.recv() @@ -690,7 +719,7 @@ def latency(self) -> float: return self._latency @property - def subprotocol(self) -> Optional[str]: + def subprotocol(self) -> Optional[Subprotocol]: return self._subprotocol @property diff --git a/picows/websockets/compat.py b/picows/websockets/compat.py new file mode 100644 index 0000000..58c381c --- /dev/null +++ b/picows/websockets/compat.py @@ -0,0 +1,13 @@ +from __future__ import annotations + +import picows + +CloseCode = picows.WSCloseCode +Request = picows.WSUpgradeRequest +Response = picows.WSUpgradeResponse + +__all__ = [ + "CloseCode", + "Request", + "Response", +] diff --git a/picows/websockets/exceptions.py b/picows/websockets/exceptions.py index 3ae7794..1d8a970 100644 --- a/picows/websockets/exceptions.py +++ b/picows/websockets/exceptions.py @@ -2,6 +2,35 @@ from typing import Any, Optional +__all__ = [ + "WebSocketException", + "ConnectionClosed", + "ConnectionClosedOK", + "ConnectionClosedError", + "InvalidURI", + "InvalidProxy", + "InvalidHandshake", + "SecurityError", + "ProxyError", + "InvalidProxyMessage", + "InvalidProxyStatus", + "InvalidMessage", + "InvalidStatus", + "InvalidHeader", + "InvalidHeaderFormat", + "InvalidHeaderValue", + "InvalidOrigin", + "InvalidUpgrade", + "NegotiationError", + "DuplicateParameter", + "InvalidParameterName", + "InvalidParameterValue", + "ProtocolError", + "PayloadTooBig", + "InvalidState", + "ConcurrencyError", +] + class WebSocketException(Exception): """Base class for exceptions defined by picows.websockets.""" @@ -47,10 +76,44 @@ def __str__(self) -> str: return f"{self.uri} isn't a valid WebSocket URI: {self.msg}" +class InvalidProxy(WebSocketException): + def __init__(self, proxy: str, msg: str): + super().__init__(proxy, msg) + self.proxy = proxy + self.msg = msg + + def __str__(self) -> str: + return f"{self.proxy} isn't a valid proxy: {self.msg}" + + class InvalidHandshake(WebSocketException): pass +class SecurityError(InvalidHandshake): + pass + + +class ProxyError(InvalidHandshake): + pass + + +class InvalidProxyMessage(ProxyError): + pass + + +class InvalidProxyStatus(ProxyError): + def __init__(self, response: Any): + super().__init__(response) + self.response = response + + def __str__(self) -> str: + status = getattr(self.response, "status", None) + if status is None: + return "proxy rejected connection" + return f"proxy rejected connection: HTTP {int(status):d}" + + class InvalidMessage(InvalidHandshake): pass @@ -72,6 +135,56 @@ class InvalidUpgrade(InvalidHeader): pass +class InvalidHeaderFormat(InvalidHeader): + def __init__(self, name: str, error: str, header: str, pos: int): + super().__init__(name, f"{error} at {pos} in {header}") + + +class InvalidHeaderValue(InvalidHeader): + pass + + +class InvalidOrigin(InvalidHeader): + def __init__(self, origin: Optional[str]): + super().__init__("Origin", origin) + + +class NegotiationError(InvalidHandshake): + pass + + +class DuplicateParameter(NegotiationError): + def __init__(self, name: str): + super().__init__(name) + self.name = name + + def __str__(self) -> str: + return f"duplicate parameter: {self.name}" + + +class InvalidParameterName(NegotiationError): + def __init__(self, name: str): + super().__init__(name) + self.name = name + + def __str__(self) -> str: + return f"invalid parameter name: {self.name}" + + +class InvalidParameterValue(NegotiationError): + def __init__(self, name: str, value: Optional[str]): + super().__init__(name, value) + self.name = name + self.value = value + + def __str__(self) -> str: + if self.value is None: + return f"missing value for parameter {self.name}" + if self.value == "": + return f"empty value for parameter {self.name}" + return f"invalid value for parameter {self.name}: {self.value}" + + class ProtocolError(WebSocketException): pass @@ -80,9 +193,9 @@ class PayloadTooBig(WebSocketException): pass -class InvalidState(WebSocketException): +class InvalidState(WebSocketException, AssertionError): pass -class ConcurrencyError(WebSocketException): +class ConcurrencyError(WebSocketException, RuntimeError): pass diff --git a/picows/websockets/typing.py b/picows/websockets/typing.py new file mode 100644 index 0000000..8c0b5d4 --- /dev/null +++ b/picows/websockets/typing.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +import logging +from http import HTTPStatus +from typing import Any + +from picows.types import WSHeadersLike + +BytesLike = bytes | bytearray | memoryview +Data = str | bytes +DataLike = str | bytes | bytearray | memoryview +HeadersLike = WSHeadersLike +LoggerLike = logging.Logger | logging.LoggerAdapter[Any] | str | None +StatusLike = HTTPStatus | int +Origin = str +Subprotocol = str +ExtensionName = str +ExtensionParameter = tuple[str, str | None] + +__all__ = [ + "BytesLike", + "Data", + "DataLike", + "ExtensionName", + "ExtensionParameter", + "HeadersLike", + "LoggerLike", + "Origin", + "StatusLike", + "Subprotocol", +] From 5ac97e3c29d448080c5b6ad11060394efa72a82f Mon Sep 17 00:00:00 2001 From: taras Date: Sun, 3 May 2026 16:56:40 +0200 Subject: [PATCH 15/57] Fix mypy issues --- picows/websockets/asyncio/client.py | 4 +-- picows/websockets/asyncio/connection.py | 45 ++++++++++++++----------- pyproject.toml | 4 +++ 3 files changed, 31 insertions(+), 22 deletions(-) diff --git a/picows/websockets/asyncio/client.py b/picows/websockets/asyncio/client.py index fc08061..d82e48c 100644 --- a/picows/websockets/asyncio/client.py +++ b/picows/websockets/asyncio/client.py @@ -5,7 +5,7 @@ import warnings from collections.abc import Generator from ssl import SSLContext -from typing import Any, Optional, Sequence, Union, cast +from typing import Any, Callable, Optional, Sequence, Union, cast import picows from picows.url import parse_url @@ -71,7 +71,7 @@ def __init__( additional_headers: Optional[HeadersLike] = None, user_agent_header: Optional[str] = _default_user_agent(), proxy: Union[str, bool, None] = True, - process_exception=process_exception, + process_exception: Callable[[Exception], Optional[Exception]] = process_exception, open_timeout: Optional[float] = 10, ping_interval: Optional[float] = 20, ping_timeout: Optional[float] = 20, diff --git a/picows/websockets/asyncio/connection.py b/picows/websockets/asyncio/connection.py index 5b7616d..1cdefa8 100644 --- a/picows/websockets/asyncio/connection.py +++ b/picows/websockets/asyncio/connection.py @@ -104,7 +104,7 @@ def release(self) -> None: @cython.cclass class _SingleConsumerQueue: _loop: asyncio.AbstractEventLoop - _items: deque + _items: deque[Optional[_BufferedFrame]] _waiter: Optional[asyncio.Future[Optional[_BufferedFrame]]] def __init__(self, loop: asyncio.AbstractEventLoop): @@ -133,7 +133,7 @@ def get_nowait(self) -> object: async def get(self) -> Optional[_BufferedFrame]: item = self.get_nowait() if item is not _QUEUE_EMPTY: - return item + return cast(Optional[_BufferedFrame], item) waiter: asyncio.Future[Optional[_BufferedFrame]] = self._loop.create_future() self._waiter = waiter @@ -153,7 +153,7 @@ def qsize(self) -> cython.Py_ssize_t: @cython.cfunc @cython.inline def _coerce_close_code(code: CloseCode) -> Optional[int]: - return None if code is None else int(code) + return None if code is None else cast(int, code) @cython.cfunc @@ -215,7 +215,7 @@ def process_exception(exc: Exception) -> Optional[Exception]: @cython.cclass -class ClientConnection(WSListener): +class ClientConnection(WSListener): # type: ignore[misc] id: uuid.UUID logger: Union[logging.Logger, logging.LoggerAdapter[Any]] transport: WSTransport @@ -453,9 +453,9 @@ def _set_recv_in_progress(self) -> None: @cython.inline def _decode_data(self, payload: bytes, msg_type: WSMsgType, decode: Optional[bool]) -> Data: if decode is True or (msg_type == WSMsgType.TEXT and decode is None): - return payload.decode("utf-8") + return cast(Data, payload.decode("utf-8")) else: - return payload + return cast(Data, payload) @cython.cfunc @cython.inline @@ -485,7 +485,7 @@ async def recv(self, decode: Optional[bool] = None) -> Data: msg_type = frame.msg_type if frame.fin: - return self._decode_data(frame.payload, msg_type, decode) + return cast(Data, self._decode_data(frame.payload, msg_type, decode)) chunks = [frame.payload] while not frame.fin: @@ -498,7 +498,7 @@ async def recv(self, decode: Optional[bool] = None) -> Data: chunks.append(frame.payload) payload = b"".join(chunks) - return self._decode_data(payload, msg_type, decode) + return cast(Data, self._decode_data(payload, msg_type, decode)) finally: self._recv_in_progress = False @@ -529,7 +529,7 @@ async def iterator() -> AsyncIterator[Data]: frame = await self._frames.get() self._check_frame(frame) frame = cast(_BufferedFrame, frame) - yield self._decode_data(frame.payload, msg_type, decode) + yield cast(Data, self._decode_data(frame.payload, msg_type, decode)) msg_finished = True finally: self._recv_in_progress = False @@ -573,7 +573,7 @@ async def send( @cython.cfunc @cython.inline - def _check_fragment_type(self, message, first_is_str: cython.bint) -> None: + def _check_fragment_type(self, message: DataLike, first_is_str: cython.bint) -> None: if first_is_str and isinstance(message, str): return elif not first_is_str and isinstance(message, (bytes, bytearray, memoryview)): @@ -581,13 +581,18 @@ def _check_fragment_type(self, message, first_is_str: cython.bint) -> None: raise TypeError("all fragments must be of the same category: str vs bytes-like") - async def _send_fragments(self, is_async: cython.bint, iterator: Union[Iterator[DataLike], AsyncIterator[DataLike]], text: Optional[bool]) -> None: + async def _send_fragments( + self, + is_async: cython.bint, + iterator: Union[Iterator[DataLike], AsyncIterator[DataLike]], + text: Optional[bool], + ) -> None: stop_exception_type = StopAsyncIteration if is_async else StopIteration try: if is_async: - first = await anext(iterator) + first = await anext(cast(AsyncIterator[DataLike], iterator)) else: - first = next(iterator) + first = next(cast(Iterator[DataLike], iterator)) except stop_exception_type: raise TypeError("message iterable cannot be empty") from None @@ -605,9 +610,9 @@ async def _send_fragments(self, is_async: cython.bint, iterator: Union[Iterator[ while True: try: if is_async: - current = await anext(iterator) + current = await anext(cast(AsyncIterator[DataLike], iterator)) else: - current = next(iterator) + current = next(cast(Iterator[DataLike], iterator)) except stop_exception_type: break @@ -716,7 +721,7 @@ def remote_address(self) -> Any: @property def latency(self) -> float: - return self._latency + return cast(float, self._latency) @property def subprotocol(self) -> Optional[Subprotocol]: @@ -728,9 +733,9 @@ def close_code(self) -> Optional[int]: if handshake is None: return None if handshake.recv is not None: - return _coerce_close_code(handshake.recv.code) + return cast(Optional[int], _coerce_close_code(handshake.recv.code)) if handshake.sent is not None: - return _coerce_close_code(handshake.sent.code) + return cast(Optional[int], _coerce_close_code(handshake.sent.code)) return None @property @@ -739,7 +744,7 @@ def close_reason(self) -> Optional[str]: if handshake is None: return None if handshake.recv is not None: - return _coerce_close_reason(handshake.recv.reason) + return cast(Optional[str], _coerce_close_reason(handshake.recv.reason)) if handshake.sent is not None: - return _coerce_close_reason(handshake.sent.reason) + return cast(Optional[str], _coerce_close_reason(handshake.sent.reason)) return None diff --git a/pyproject.toml b/pyproject.toml index 4d121ca..7ebb5b3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,6 +56,10 @@ files = "picows" ignore_missing_imports = true strict = true +[[tool.mypy.overrides]] +module = ["picows.websockets.asyncio.connection"] +disable_error_code = ["untyped-decorator"] + [tool.coverage.run] source = ["picows"] branch = true From ef6017037db1e18fcfdc97c8a7ca7a26d056b24c Mon Sep 17 00:00:00 2001 From: taras Date: Sun, 3 May 2026 17:03:22 +0200 Subject: [PATCH 16/57] Inline AsyncLock --- picows/websockets/asyncio/connection.py | 83 ++++++++++--------------- 1 file changed, 33 insertions(+), 50 deletions(-) diff --git a/picows/websockets/asyncio/connection.py b/picows/websockets/asyncio/connection.py index 1cdefa8..8c60209 100644 --- a/picows/websockets/asyncio/connection.py +++ b/picows/websockets/asyncio/connection.py @@ -57,50 +57,6 @@ def __init__(self, msg_type: WSMsgType, payload: bytes, fin: bool): self.fin = fin -@cython.cclass -class _AsyncLock: - _loop: asyncio.AbstractEventLoop - _locked: cython.bint - _waiters: Any - - def __init__(self, loop: asyncio.AbstractEventLoop): - self._loop = loop - self._locked = False - self._waiters = deque() - - @cython.cfunc - @cython.inline - def locked(self) -> cython.bint: - return self._locked - - @cython.cfunc - @cython.inline - def acquire(self) -> None: - self._locked = True - - async def wait_and_acquire(self) -> None: - waiter = self._loop.create_future() - self._waiters.append(waiter) - try: - await waiter - except Exception: - try: - self._waiters.remove(waiter) - except ValueError: - pass - raise - - @cython.cfunc - @cython.inline - def release(self) -> None: - while self._waiters: - waiter = self._waiters.popleft() - if not waiter.done(): - waiter.set_result(None) - return - self._locked = False - - @cython.cclass class _SingleConsumerQueue: _loop: asyncio.AbstractEventLoop @@ -228,7 +184,8 @@ class ClientConnection(WSListener): # type: ignore[misc] _close_exc: Optional[ConnectionClosed] _loop: asyncio.AbstractEventLoop _recv_in_progress: cython.bint - _send_lock: _AsyncLock + _send_in_progress: cython.bint + _send_waiters: deque[asyncio.Future[None]] _write_ready: Optional[asyncio.Future[None]] _recv_streaming_broken: cython.bint _pending_pings: Dict[bytes, Tuple[asyncio.Future[float], float]] @@ -269,7 +226,8 @@ def __init__( self._loop = asyncio.get_running_loop() self._frames = _SingleConsumerQueue(self._loop) self._recv_in_progress = False - self._send_lock = _AsyncLock(self._loop) + self._send_in_progress = False + self._send_waiters = deque() self._write_ready: Optional[asyncio.Future[None]] = None self._recv_streaming_broken = False self._pending_pings: dict[bytes, tuple[asyncio.Future[float], float]] = {} @@ -449,6 +407,31 @@ def _set_recv_in_progress(self) -> None: raise ConcurrencyError("recv_streaming() wasn't fully consumed") self._recv_in_progress = True + async def _wait_send_turn(self) -> None: + waiter: asyncio.Future[None] = self._loop.create_future() + self._send_waiters.append(waiter) + try: + await waiter + except Exception: + try: + self._send_waiters.remove(waiter) + except ValueError: + pass + raise + + @cython.cfunc + @cython.inline + def _release_send(self) -> None: + waiter: asyncio.Future[None] + + while self._send_waiters: + waiter = self._send_waiters.popleft() + if not waiter.done(): + waiter.set_result(None) + return + + self._send_in_progress = False + @cython.cfunc @cython.inline def _decode_data(self, payload: bytes, msg_type: WSMsgType, decode: Optional[bool]) -> Data: @@ -548,10 +531,10 @@ async def send( if self._state is State.CLOSED: raise self._connection_closed() - if self._send_lock.locked(): - await self._send_lock.wait_and_acquire() + if self._send_in_progress: + await self._wait_send_turn() else: - self._send_lock.acquire() + self._send_in_progress = True try: if isinstance(message, (str, bytes, bytearray, memoryview)): @@ -569,7 +552,7 @@ async def send( else: raise TypeError(f"message has unsupported type {type(message).__name__}") finally: - self._send_lock.release() + self._release_send() @cython.cfunc @cython.inline From d38746c611424c2217af6b4aaa04c50e37a1074e Mon Sep 17 00:00:00 2001 From: taras Date: Sun, 3 May 2026 17:10:17 +0200 Subject: [PATCH 17/57] Simplify --- picows/websockets/asyncio/connection.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/picows/websockets/asyncio/connection.py b/picows/websockets/asyncio/connection.py index 8c60209..4452bea 100644 --- a/picows/websockets/asyncio/connection.py +++ b/picows/websockets/asyncio/connection.py @@ -436,9 +436,9 @@ def _release_send(self) -> None: @cython.inline def _decode_data(self, payload: bytes, msg_type: WSMsgType, decode: Optional[bool]) -> Data: if decode is True or (msg_type == WSMsgType.TEXT and decode is None): - return cast(Data, payload.decode("utf-8")) + return payload.decode("utf-8") else: - return cast(Data, payload) + return payload @cython.cfunc @cython.inline @@ -468,7 +468,7 @@ async def recv(self, decode: Optional[bool] = None) -> Data: msg_type = frame.msg_type if frame.fin: - return cast(Data, self._decode_data(frame.payload, msg_type, decode)) + return self._decode_data(frame.payload, msg_type, decode) chunks = [frame.payload] while not frame.fin: @@ -481,7 +481,7 @@ async def recv(self, decode: Optional[bool] = None) -> Data: chunks.append(frame.payload) payload = b"".join(chunks) - return cast(Data, self._decode_data(payload, msg_type, decode)) + return self._decode_data(payload, msg_type, decode) finally: self._recv_in_progress = False From 4c88a41cabee981fedabb7aec41ba8ce12fb9f53 Mon Sep 17 00:00:00 2001 From: taras Date: Sun, 3 May 2026 17:13:00 +0200 Subject: [PATCH 18/57] Fix mypy --- picows/websockets/asyncio/connection.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/picows/websockets/asyncio/connection.py b/picows/websockets/asyncio/connection.py index 4452bea..1f3c3db 100644 --- a/picows/websockets/asyncio/connection.py +++ b/picows/websockets/asyncio/connection.py @@ -468,7 +468,7 @@ async def recv(self, decode: Optional[bool] = None) -> Data: msg_type = frame.msg_type if frame.fin: - return self._decode_data(frame.payload, msg_type, decode) + return self._decode_data(frame.payload, msg_type, decode) # type: ignore[no-any-return] chunks = [frame.payload] while not frame.fin: @@ -481,7 +481,7 @@ async def recv(self, decode: Optional[bool] = None) -> Data: chunks.append(frame.payload) payload = b"".join(chunks) - return self._decode_data(payload, msg_type, decode) + return self._decode_data(payload, msg_type, decode) # type: ignore[no-any-return] finally: self._recv_in_progress = False From 832863f5c794a38bf6dbddb635825e15fc746205 Mon Sep 17 00:00:00 2001 From: taras Date: Sun, 3 May 2026 17:21:31 +0200 Subject: [PATCH 19/57] Cleanup --- picows/websockets/asyncio/connection.py | 33 ++++++++++++++----------- 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/picows/websockets/asyncio/connection.py b/picows/websockets/asyncio/connection.py index 1f3c3db..6d60367 100644 --- a/picows/websockets/asyncio/connection.py +++ b/picows/websockets/asyncio/connection.py @@ -472,11 +472,12 @@ async def recv(self, decode: Optional[bool] = None) -> Data: chunks = [frame.payload] while not frame.fin: - frame = self._get_frame_nowait() - if frame is _QUEUE_EMPTY: - frame = await self._frames.get() + item = self._get_frame_nowait() + if item is _QUEUE_EMPTY: + item = await self._frames.get() + + frame = cython.cast(_BufferedFrame, item) self._check_frame(frame) - frame = cast(_BufferedFrame, frame) chunks.append(frame.payload) @@ -490,29 +491,31 @@ def recv_streaming(self, decode: Optional[bool] = None) -> AsyncIterator[Data]: msg_started: cython.bint = False msg_finished: cython.bint = False + frame: Optional[_BufferedFrame] + msg_type: WSMsgType async def iterator() -> AsyncIterator[Data]: nonlocal msg_started, msg_finished - frame: Optional[_BufferedFrame] - msg_type: WSMsgType try: - frame = self._get_frame_nowait() - if frame is _QUEUE_EMPTY: - frame = await self._frames.get() + item = self._get_frame_nowait() + if item is _QUEUE_EMPTY: + item = await self._frames.get() + frame = cython.cast(_BufferedFrame, item) self._check_frame(frame) - frame = cast(_BufferedFrame, frame) + msg_started = True msg_type = frame.msg_type yield self._decode_data(frame.payload, msg_type, decode) while not frame.fin: - frame = self._get_frame_nowait() - if frame is _QUEUE_EMPTY: - frame = await self._frames.get() + item = self._get_frame_nowait() + if item is _QUEUE_EMPTY: + item = await self._frames.get() + frame = cython.cast(_BufferedFrame, item) self._check_frame(frame) - frame = cast(_BufferedFrame, frame) - yield cast(Data, self._decode_data(frame.payload, msg_type, decode)) + + yield self._decode_data(frame.payload, msg_type, decode) msg_finished = True finally: self._recv_in_progress = False From 060560fe05a63b0ebf2a78b058551529bcb80400 Mon Sep 17 00:00:00 2001 From: taras Date: Sun, 3 May 2026 18:22:51 +0200 Subject: [PATCH 20/57] Simplify logic --- picows/websockets/asyncio/connection.py | 171 ++++++++++-------------- 1 file changed, 67 insertions(+), 104 deletions(-) diff --git a/picows/websockets/asyncio/connection.py b/picows/websockets/asyncio/connection.py index 6d60367..fa9e646 100644 --- a/picows/websockets/asyncio/connection.py +++ b/picows/websockets/asyncio/connection.py @@ -35,7 +35,6 @@ OK_CLOSE_CODES = {0, 1000, 1001} -_QUEUE_EMPTY = object() class State(IntEnum): @@ -57,55 +56,6 @@ def __init__(self, msg_type: WSMsgType, payload: bytes, fin: bool): self.fin = fin -@cython.cclass -class _SingleConsumerQueue: - _loop: asyncio.AbstractEventLoop - _items: deque[Optional[_BufferedFrame]] - _waiter: Optional[asyncio.Future[Optional[_BufferedFrame]]] - - def __init__(self, loop: asyncio.AbstractEventLoop): - self._loop = loop - self._items = deque() - self._waiter = None - - @cython.cfunc - @cython.inline - def put(self, item: Optional[_BufferedFrame]) -> None: - waiter = self._waiter - if waiter is not None: - self._waiter = None - if not waiter.done(): - waiter.set_result(item) - return - self._items.append(item) - - @cython.cfunc - @cython.inline - def get_nowait(self) -> object: - if self._items: - return self._items.popleft() - return _QUEUE_EMPTY - - async def get(self) -> Optional[_BufferedFrame]: - item = self.get_nowait() - if item is not _QUEUE_EMPTY: - return cast(Optional[_BufferedFrame], item) - - waiter: asyncio.Future[Optional[_BufferedFrame]] = self._loop.create_future() - self._waiter = waiter - try: - return await waiter - except Exception: - if self._waiter is waiter: - self._waiter = None - raise - - @cython.cfunc - @cython.inline - def qsize(self) -> cython.Py_ssize_t: - return len(self._items) - - @cython.cfunc @cython.inline def _coerce_close_code(code: CloseCode) -> Optional[int]: @@ -180,27 +130,33 @@ class ClientConnection(WSListener): # type: ignore[misc] _subprotocols: Optional[Sequence[Subprotocol]] _subprotocol: Optional[Subprotocol] _state: State - _frames: _SingleConsumerQueue _close_exc: Optional[ConnectionClosed] _loop: asyncio.AbstractEventLoop - _recv_in_progress: cython.bint + + # Send side _send_in_progress: cython.bint _send_waiters: deque[asyncio.Future[None]] _write_ready: Optional[asyncio.Future[None]] + _write_limit: Union[int, tuple[int, Optional[int]]] + + # Recv side + _recv_in_progress: cython.bint _recv_streaming_broken: cython.bint + _paused_reading: cython.bint + _recv_waiter: Optional[asyncio.Future[None]] + _recv_queue: deque[Optional[_BufferedFrame]] + _max_message_size: cython.Py_ssize_t + _max_queue_high: cython.Py_ssize_t + _max_queue_low: cython.Py_ssize_t + _incoming_message_active: cython.bint + _incoming_message_size: cython.Py_ssize_t + _pending_pings: Dict[bytes, Tuple[asyncio.Future[float], float]] _ping_interval: Optional[float] _ping_timeout: Optional[float] _close_timeout: Optional[float] _keepalive_task: Optional[asyncio.Task[None]] _latency: cython.double - _max_message_size: cython.Py_ssize_t - _max_queue_high: cython.Py_ssize_t - _max_queue_low: cython.Py_ssize_t - _incoming_message_active: cython.bint - _incoming_message_size: cython.Py_ssize_t - _write_limit: Union[int, tuple[int, Optional[int]]] - _paused_reading: cython.bint def __init__( self, @@ -224,24 +180,28 @@ def __init__( self._state = State.CONNECTING self._close_exc: Optional[ConnectionClosed] = None self._loop = asyncio.get_running_loop() - self._frames = _SingleConsumerQueue(self._loop) - self._recv_in_progress = False + self._send_in_progress = False self._send_waiters = deque() self._write_ready: Optional[asyncio.Future[None]] = None + self._write_limit = write_limit + + self._recv_in_progress = False self._recv_streaming_broken = False + self._paused_reading = False + self._recv_waiter = None + self._recv_queue = deque() + self._max_message_size = 0 if max_message_size is None else max_message_size + self._max_queue_high, self._max_queue_low = _normalize_watermarks(max_queue) + self._incoming_message_active = False + self._incoming_message_size = 0 + self._pending_pings: dict[bytes, tuple[asyncio.Future[float], float]] = {} self._ping_interval = ping_interval self._ping_timeout = ping_timeout self._close_timeout = close_timeout self._keepalive_task: Optional[asyncio.Task[None]] = None self._latency = 0.0 - self._max_message_size = 0 if max_message_size is None else max_message_size - self._max_queue_high, self._max_queue_low = _normalize_watermarks(max_queue) - self._incoming_message_active = False - self._incoming_message_size = 0 - self._write_limit = write_limit - self._paused_reading = False @cython.ccall def on_ws_connected(self, transport: WSTransport) -> None: @@ -258,7 +218,7 @@ def on_ws_connected(self, transport: WSTransport) -> None: def on_ws_disconnected(self, transport: WSTransport) -> None: self._state = State.CLOSED self._set_close_exception() - self._frames.put(None) + self._add_to_recv_queue(None) if self._keepalive_task is not None: self._keepalive_task.cancel() self._keepalive_task = None @@ -320,7 +280,7 @@ def on_ws_frame(self, transport: WSTransport, frame: WSFrame) -> None: self._fail_protocol_error(f"unexpected opcode while receiving message: {frame.msg_type}") return - self._frames.put(_BufferedFrame(frame.msg_type, frame.get_payload_as_bytes(), frame.fin)) + self._add_to_recv_queue(_BufferedFrame(frame.msg_type, frame.get_payload_as_bytes(), frame.fin)) self._pause_reading_if_needed() @cython.ccall @@ -347,7 +307,7 @@ def _set_write_limits(self, write_limit: Union[int, tuple[int, Optional[int]]]) @cython.cfunc @cython.inline def _pause_reading_if_needed(self) -> None: - if self._max_queue_high > 0 and not self._paused_reading and self._frames.qsize() >= self._max_queue_high: + if self._max_queue_high > 0 and not self._paused_reading and len(self._recv_queue) >= self._max_queue_high: self.transport.underlying_transport.pause_reading() self._paused_reading = True @@ -356,10 +316,28 @@ def _pause_reading_if_needed(self) -> None: def _resume_reading_if_needed(self) -> None: if not self._paused_reading: return - if self._max_queue_low == 0 or self._frames.qsize() <= self._max_queue_low: + if self._max_queue_low == 0 or len(self._recv_queue) <= self._max_queue_low: self.transport.underlying_transport.resume_reading() self._paused_reading = False + @cython.cfunc + @cython.inline + def _add_to_recv_queue(self, frame: Optional[_BufferedFrame]) -> None: + self._recv_queue.append(frame) + waiter = self._recv_waiter + if waiter is not None: + self._recv_waiter = None + if not waiter.done(): + waiter.set_result(None) + + @cython.cfunc + @cython.inline + def _wait_recv_queue_not_empty(self) -> asyncio.Future[None]: + assert self._recv_waiter is None + waiter: asyncio.Future[None] = self._loop.create_future() + self._recv_waiter = waiter + return waiter + @cython.cfunc @cython.inline def _set_close_exception(self) -> None: @@ -442,29 +420,21 @@ def _decode_data(self, payload: bytes, msg_type: WSMsgType, decode: Optional[boo @cython.cfunc @cython.inline - def _check_frame(self, frame: Optional[_BufferedFrame]) -> None: + def _check_frame(self, frame: Optional[_BufferedFrame]) -> _BufferedFrame: self._resume_reading_if_needed() - if frame is None: raise self._connection_closed() - - @cython.cfunc - @cython.inline - def _get_frame_nowait(self) -> object: - return self._frames.get_nowait() + return frame async def recv(self, decode: Optional[bool] = None) -> Data: - frame: Optional[_BufferedFrame] + frame: _BufferedFrame self._set_recv_in_progress() try: - item = self._get_frame_nowait() - if item is _QUEUE_EMPTY: - item = await self._frames.get() - - frame = cython.cast(_BufferedFrame, item) - self._check_frame(frame) + if not self._recv_queue: + await self._wait_recv_queue_not_empty() + frame = self._check_frame(self._recv_queue.popleft()) msg_type = frame.msg_type if frame.fin: @@ -472,17 +442,14 @@ async def recv(self, decode: Optional[bool] = None) -> Data: chunks = [frame.payload] while not frame.fin: - item = self._get_frame_nowait() - if item is _QUEUE_EMPTY: - item = await self._frames.get() - - frame = cython.cast(_BufferedFrame, item) - self._check_frame(frame) + if not self._recv_queue: + await self._wait_recv_queue_not_empty() + frame = self._check_frame(self._recv_queue.popleft()) chunks.append(frame.payload) payload = b"".join(chunks) - return self._decode_data(payload, msg_type, decode) # type: ignore[no-any-return] + return self._decode_data(payload, msg_type, decode) finally: self._recv_in_progress = False @@ -491,29 +458,25 @@ def recv_streaming(self, decode: Optional[bool] = None) -> AsyncIterator[Data]: msg_started: cython.bint = False msg_finished: cython.bint = False - frame: Optional[_BufferedFrame] + frame: _BufferedFrame msg_type: WSMsgType async def iterator() -> AsyncIterator[Data]: nonlocal msg_started, msg_finished try: - item = self._get_frame_nowait() - if item is _QUEUE_EMPTY: - item = await self._frames.get() - frame = cython.cast(_BufferedFrame, item) - self._check_frame(frame) + if not self._recv_queue: + await self._wait_recv_queue_not_empty() + frame = self._check_frame(self._recv_queue.popleft()) msg_started = True msg_type = frame.msg_type yield self._decode_data(frame.payload, msg_type, decode) while not frame.fin: - item = self._get_frame_nowait() - if item is _QUEUE_EMPTY: - item = await self._frames.get() - frame = cython.cast(_BufferedFrame, item) - self._check_frame(frame) + if not self._recv_queue: + await self._wait_recv_queue_not_empty() + frame = self._check_frame(self._recv_queue.popleft()) yield self._decode_data(frame.payload, msg_type, decode) msg_finished = True @@ -707,7 +670,7 @@ def remote_address(self) -> Any: @property def latency(self) -> float: - return cast(float, self._latency) + return self._latency # type: ignore[no-any-return] @property def subprotocol(self) -> Optional[Subprotocol]: From 02ac50f8debf2d97fd1ffa0a599c58fd0e4f7cc5 Mon Sep 17 00:00:00 2001 From: taras Date: Sun, 3 May 2026 18:53:12 +0200 Subject: [PATCH 21/57] Better cancellation logic for recv and recv_streaming --- picows/websockets/asyncio/connection.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/picows/websockets/asyncio/connection.py b/picows/websockets/asyncio/connection.py index fa9e646..89fc7d4 100644 --- a/picows/websockets/asyncio/connection.py +++ b/picows/websockets/asyncio/connection.py @@ -440,18 +440,25 @@ async def recv(self, decode: Optional[bool] = None) -> Data: if frame.fin: return self._decode_data(frame.payload, msg_type, decode) # type: ignore[no-any-return] - chunks = [frame.payload] - while not frame.fin: - if not self._recv_queue: - await self._wait_recv_queue_not_empty() - frame = self._check_frame(self._recv_queue.popleft()) + try: + frames = [frame] + payloads = [frame.payload] + while not frame.fin: + if not self._recv_queue: + await self._wait_recv_queue_not_empty() + frame = self._check_frame(self._recv_queue.popleft()) - chunks.append(frame.payload) + frames.append(frame) + payloads.append(frame.payload) + except asyncio.CancelledError: + self._recv_queue.extendleft(reversed(frames)) + raise - payload = b"".join(chunks) + payload = b"".join(payloads) return self._decode_data(payload, msg_type, decode) finally: self._recv_in_progress = False + self._recv_waiter = None def recv_streaming(self, decode: Optional[bool] = None) -> AsyncIterator[Data]: self._set_recv_in_progress() @@ -482,6 +489,7 @@ async def iterator() -> AsyncIterator[Data]: msg_finished = True finally: self._recv_in_progress = False + self._recv_waiter = None if msg_started and not msg_finished: self._recv_streaming_broken = True elif msg_finished: From 2fc8e231b2ca05a2527166be0831732eea648168 Mon Sep 17 00:00:00 2001 From: taras Date: Sun, 3 May 2026 18:55:49 +0200 Subject: [PATCH 22/57] Better exception safety --- picows/websockets/asyncio/connection.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/picows/websockets/asyncio/connection.py b/picows/websockets/asyncio/connection.py index 89fc7d4..7090ae2 100644 --- a/picows/websockets/asyncio/connection.py +++ b/picows/websockets/asyncio/connection.py @@ -450,12 +450,12 @@ async def recv(self, decode: Optional[bool] = None) -> Data: frames.append(frame) payloads.append(frame.payload) + + payload = b"".join(payloads) + return self._decode_data(payload, msg_type, decode) except asyncio.CancelledError: self._recv_queue.extendleft(reversed(frames)) raise - - payload = b"".join(payloads) - return self._decode_data(payload, msg_type, decode) finally: self._recv_in_progress = False self._recv_waiter = None From 41b19dc328eb9f7cb9a6cb7f3b892bb51267e109 Mon Sep 17 00:00:00 2001 From: taras Date: Sun, 3 May 2026 18:56:42 +0200 Subject: [PATCH 23/57] Better exception safety --- picows/websockets/asyncio/connection.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/picows/websockets/asyncio/connection.py b/picows/websockets/asyncio/connection.py index 7090ae2..2422848 100644 --- a/picows/websockets/asyncio/connection.py +++ b/picows/websockets/asyncio/connection.py @@ -440,8 +440,8 @@ async def recv(self, decode: Optional[bool] = None) -> Data: if frame.fin: return self._decode_data(frame.payload, msg_type, decode) # type: ignore[no-any-return] + frames = [frame] try: - frames = [frame] payloads = [frame.payload] while not frame.fin: if not self._recv_queue: From e00eab872afba89fc7cf4ca63c8c14d6fe26b995 Mon Sep 17 00:00:00 2001 From: taras Date: Sun, 3 May 2026 19:44:51 +0200 Subject: [PATCH 24/57] Various fixes --- AGENTS.md | 3 + picows/websockets/asyncio/client.py | 9 +- picows/websockets/asyncio/connection.py | 120 +++++++++++++----------- 3 files changed, 76 insertions(+), 56 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index 11908d4..64f59d1 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -28,6 +28,9 @@ examples - Various examples for users on how to use picows + perf_test that coul We can skip complicated areas such as the full server interface, but simple surface-area compatibility matters. Type definitions, exception definitions, and other lightweight importable names should exist when upstream exposes them. People switching from `websockets` to `picows.websockets` should notice as little difference as possible. +- In Cythonized Python modules, avoid `typing.cast(...)` in hot paths. + Cython may compile `cast(...)` into a real runtime global lookup and function call instead of erasing it like a type checker would. + Prefer control-flow narrowing, assertions, or narrowly scoped type-ignore comments when needed. ## Testing instructions - Run lint after updating code with: diff --git a/picows/websockets/asyncio/client.py b/picows/websockets/asyncio/client.py index d82e48c..2cfdd8a 100644 --- a/picows/websockets/asyncio/client.py +++ b/picows/websockets/asyncio/client.py @@ -5,7 +5,7 @@ import warnings from collections.abc import Generator from ssl import SSLContext -from typing import Any, Callable, Optional, Sequence, Union, cast +from typing import Any, Callable, Optional, Sequence, Union import picows from picows.url import parse_url @@ -166,8 +166,10 @@ async def _connect(self) -> ClientConnection: if preexisting_sock is not None: if socket_factory is not None: raise TypeError("cannot pass both sock and socket_factory") + if not isinstance(preexisting_sock, socket.socket): + raise TypeError("sock must be a socket.socket instance") - provided_sock = cast(socket.socket, preexisting_sock) + provided_sock = preexisting_sock def provided_socket(_: Any) -> socket.socket: return provided_sock @@ -228,7 +230,8 @@ def listener_factory() -> ClientConnection: except picows.WSHandshakeError as exc: raise InvalidHandshake(str(exc)) from exc - return cast(ClientConnection, listener) + assert isinstance(listener, ClientConnection) + return listener def _build_headers(self) -> list[tuple[str, str]]: headers = _header_items(self.additional_headers) diff --git a/picows/websockets/asyncio/connection.py b/picows/websockets/asyncio/connection.py index 2422848..169e8f1 100644 --- a/picows/websockets/asyncio/connection.py +++ b/picows/websockets/asyncio/connection.py @@ -9,7 +9,7 @@ from enum import IntEnum from time import monotonic from typing import Any, AsyncIterator, Awaitable, Optional, Sequence, \ - Union, cast, Dict, Tuple, Iterator + Union, Dict, Tuple, Iterator import cython @@ -19,9 +19,6 @@ else: from picows import WSListener, WSTransport, WSFrame, WSMsgType, WSCloseCode - -import picows - from ..compat import CloseCode, Request, Response from ..exceptions import ( ConcurrencyError, @@ -59,7 +56,7 @@ def __init__(self, msg_type: WSMsgType, payload: bytes, fin: bool): @cython.cfunc @cython.inline def _coerce_close_code(code: CloseCode) -> Optional[int]: - return None if code is None else cast(int, code) + return None if code is None else code # type: ignore[return-value] @cython.cfunc @@ -79,9 +76,11 @@ def _resolve_subprotocol( value = response.headers.get("Sec-WebSocket-Protocol") if value is None: return None + if not isinstance(value, str): + raise InvalidHandshake("server returned non-string subprotocol") if subprotocols is not None and value not in subprotocols: raise InvalidHandshake(f"unsupported subprotocol negotiated by server: {value}") - return cast(Subprotocol, value) + return value @cython.cfunc @@ -173,10 +172,10 @@ def __init__( self.id = uuid.uuid4() self.logger = _resolve_logger(logger) self.transport = cython.cast(WSTransport, None) - self.request = cast(Request, None) - self.response = cast(Response, None) + self.request = None # type: ignore[assignment] + self.response = None # type: ignore[assignment] self._subprotocols = subprotocols - self._subprotocol = cast(Optional[Subprotocol], None) + self._subprotocol = None self._state = State.CONNECTING self._close_exc: Optional[ConnectionClosed] = None self._loop = asyncio.get_running_loop() @@ -519,10 +518,8 @@ async def send( self.transport.send(msg_type, message) if self._write_ready is not None: await self._write_ready - elif isinstance(message, AsyncIterable): - await self._send_fragments(True, message.__aiter__(), text) - elif isinstance(message, Iterable): - await self._send_fragments(False, iter(message), text) + elif isinstance(message, (AsyncIterable, Iterable)): + await self._send_fragments(message, text) # type: ignore[arg-type] else: raise TypeError(f"message has unsupported type {type(message).__name__}") finally: @@ -540,51 +537,69 @@ def _check_fragment_type(self, message: DataLike, first_is_str: cython.bint) -> async def _send_fragments( self, - is_async: cython.bint, - iterator: Union[Iterator[DataLike], AsyncIterator[DataLike]], + messages: Union[AsyncIterable[DataLike], Iterable[DataLike]], text: Optional[bool], ) -> None: - stop_exception_type = StopAsyncIteration if is_async else StopIteration - try: - if is_async: - first = await anext(cast(AsyncIterator[DataLike], iterator)) - else: - first = next(cast(Iterator[DataLike], iterator)) - except stop_exception_type: - raise TypeError("message iterable cannot be empty") from None - - first_is_str: cython.bint - if isinstance(first, str): - msg_type = WSMsgType.BINARY if text is False else WSMsgType.TEXT - first_is_str = True - elif isinstance(first, (bytes, bytearray, memoryview)): - msg_type = WSMsgType.TEXT if text else WSMsgType.BINARY - first_is_str = False + is_async: cython.bint + async_iterator: AsyncIterator[DataLike] + iterator: Iterator[DataLike] + stop_exception_type: Union[type[StopAsyncIteration], type[StopIteration]] + + if isinstance(messages, AsyncIterable): + async_iterator = messages.__aiter__() + iterator = None # type: ignore[assignment] + stop_exception_type = StopAsyncIteration + is_async = True else: - raise TypeError(f"message must contain str or bytes-like objects, got {type(first).__name__}") + async_iterator = None # type: ignore[assignment] + iterator = iter(messages) + stop_exception_type = StopIteration + is_async = False - previous = first - while True: + try: try: if is_async: - current = await anext(cast(AsyncIterator[DataLike], iterator)) + first = await anext(async_iterator) else: - current = next(cast(Iterator[DataLike], iterator)) + first = next(iterator) except stop_exception_type: - break + raise TypeError("message iterable cannot be empty") from None + + first_is_str: cython.bint + if isinstance(first, str): + msg_type = WSMsgType.BINARY if text is False else WSMsgType.TEXT + first_is_str = True + elif isinstance(first, (bytes, bytearray, memoryview)): + msg_type = WSMsgType.TEXT if text else WSMsgType.BINARY + first_is_str = False + else: + raise TypeError(f"message must contain str or bytes-like objects, got {type(first).__name__}") + + previous = first + while True: + try: + if is_async: + current = await anext(async_iterator) + else: + current = next(iterator) + except stop_exception_type: + break - self._check_fragment_type(current, first_is_str) + self._check_fragment_type(current, first_is_str) - self.transport.send(msg_type, previous, fin=False) - msg_type = WSMsgType.CONTINUATION - if self._write_ready is not None: - await self._write_ready + self.transport.send(msg_type, previous, fin=False) + msg_type = WSMsgType.CONTINUATION + if self._write_ready is not None: + await self._write_ready - previous = current + previous = current - self.transport.send(msg_type, previous, fin=True) - if self._write_ready is not None: - await self._write_ready + self.transport.send(msg_type, previous, fin=True) + if self._write_ready is not None: + await self._write_ready + except: + self._fail_protocol_error("error in fragmented message") + raise async def close(self, code: int = 1000, reason: str = "") -> None: if self._state is State.CLOSED: @@ -630,8 +645,7 @@ async def ping(self, data: Optional[Union[str, bytes]] = None) -> Awaitable[floa async def pong(self, data: Union[str, bytes] = b"") -> None: if self._state is State.CLOSED: raise self._connection_closed() - payload = data.encode("utf-8") if isinstance(data, str) else data - self.transport.send_pong(payload) + self.transport.send_pong(data) async def _keepalive_loop(self) -> None: try: @@ -648,7 +662,7 @@ async def _keepalive_loop(self) -> None: if self.state is not State.CLOSED: await self.close(code=1011, reason="keepalive ping timeout") - async def __aenter__(self) -> ClientConnection: + async def __aenter__(self): # type: ignore[no-untyped-def] return self async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None: @@ -690,9 +704,9 @@ def close_code(self) -> Optional[int]: if handshake is None: return None if handshake.recv is not None: - return cast(Optional[int], _coerce_close_code(handshake.recv.code)) + return _coerce_close_code(handshake.recv.code) # type: ignore[no-any-return] if handshake.sent is not None: - return cast(Optional[int], _coerce_close_code(handshake.sent.code)) + return _coerce_close_code(handshake.sent.code) # type: ignore[no-any-return] return None @property @@ -701,7 +715,7 @@ def close_reason(self) -> Optional[str]: if handshake is None: return None if handshake.recv is not None: - return cast(Optional[str], _coerce_close_reason(handshake.recv.reason)) + return _coerce_close_reason(handshake.recv.reason) # type: ignore[no-any-return] if handshake.sent is not None: - return cast(Optional[str], _coerce_close_reason(handshake.sent.reason)) + return _coerce_close_reason(handshake.sent.reason) # type: ignore[no-any-return] return None From d612c85e9451a87a916396d9d758e5f9751a8bb0 Mon Sep 17 00:00:00 2001 From: taras Date: Mon, 4 May 2026 14:49:10 +0200 Subject: [PATCH 25/57] Add permessage-deflate support --- picows/websockets/asyncio/client.py | 15 +- picows/websockets/asyncio/connection.py | 363 +++++++++++++++++--- tests/test_websockets_compat.py | 2 +- tests/test_websockets_compression_compat.py | 61 ++++ 4 files changed, 380 insertions(+), 61 deletions(-) create mode 100644 tests/test_websockets_compression_compat.py diff --git a/picows/websockets/asyncio/client.py b/picows/websockets/asyncio/client.py index 2cfdd8a..9375bd1 100644 --- a/picows/websockets/asyncio/client.py +++ b/picows/websockets/asyncio/client.py @@ -2,7 +2,6 @@ import asyncio import socket -import warnings from collections.abc import Generator from ssl import SSLContext from typing import Any, Callable, Optional, Sequence, Union @@ -31,6 +30,9 @@ ] +_PERMESSAGE_DEFLATE_REQUEST = "permessage-deflate; client_max_window_bits" + + def _default_user_agent() -> str: import sys return f"Python/{sys.version_info.major}.{sys.version_info.minor} picows-websockets/0" @@ -149,12 +151,6 @@ async def _connect(self) -> ClientConnection: raise NotImplementedError("custom extensions aren't supported by picows.websockets") if self.compression not in (None, "deflate"): raise NotImplementedError("only compression=None or 'deflate' are accepted") - if self.compression == "deflate": - warnings.warn( - "picows.websockets doesn't implement permessage-deflate; connecting without compression", - RuntimeWarning, - stacklevel=2, - ) conn_kwargs = dict(self.kwargs) ssl_context = conn_kwargs.pop("ssl", None) @@ -200,6 +196,7 @@ def listener_factory() -> ClientConnection: max_message_size=max_message_size, logger=self.logger, subprotocols=self.subprotocols, + compression=self.compression, ) try: @@ -231,6 +228,8 @@ def listener_factory() -> ClientConnection: raise InvalidHandshake(str(exc)) from exc assert isinstance(listener, ClientConnection) + if listener.connect_exception is not None: + raise listener.connect_exception return listener def _build_headers(self) -> list[tuple[str, str]]: @@ -241,6 +240,8 @@ def _build_headers(self) -> list[tuple[str, str]]: headers.append(("User-Agent", self.user_agent_header)) if self.subprotocols: headers.append(("Sec-WebSocket-Protocol", ", ".join(self.subprotocols))) + if self.compression == "deflate": + headers.append(("Sec-WebSocket-Extensions", _PERMESSAGE_DEFLATE_REQUEST)) return headers def _coerce_ssl_context(self, value: Any) -> Optional[SSLContext]: diff --git a/picows/websockets/asyncio/connection.py b/picows/websockets/asyncio/connection.py index 169e8f1..2d8b9b3 100644 --- a/picows/websockets/asyncio/connection.py +++ b/picows/websockets/asyncio/connection.py @@ -4,12 +4,13 @@ import logging import os import uuid +import zlib from collections import deque from collections.abc import AsyncIterable, Generator, Iterable from enum import IntEnum from time import monotonic from typing import Any, AsyncIterator, Awaitable, Optional, Sequence, \ - Union, Dict, Tuple, Iterator + Union, Dict, Tuple, Iterator, Mapping import cython @@ -28,10 +29,11 @@ InvalidHandshake, InvalidStatus, ) -from ..typing import Data, DataLike, LoggerLike, Subprotocol +from ..typing import BytesLike, Data, DataLike, LoggerLike, Subprotocol OK_CLOSE_CODES = {0, 1000, 1001} +_EMPTY_UNCOMPRESSED_BLOCK = b"\x00\x00\xff\xff" class State(IntEnum): @@ -53,6 +55,168 @@ def __init__(self, msg_type: WSMsgType, payload: bytes, fin: bool): self.fin = fin +class _CompressionError(Exception): + pass + + +@cython.cclass +class _PerMessageDeflate: + remote_no_context_takeover: cython.bint + local_no_context_takeover: cython.bint + remote_max_window_bits: cython.int + local_max_window_bits: cython.int + _decoder: Any + _encoder: Any + _decode_cont_data: cython.int + + def __init__( + self, + *, + remote_no_context_takeover: bool, + local_no_context_takeover: bool, + remote_max_window_bits: int, + local_max_window_bits: int, + ): + self.remote_no_context_takeover = remote_no_context_takeover + self.local_no_context_takeover = local_no_context_takeover + self.remote_max_window_bits = remote_max_window_bits + self.local_max_window_bits = local_max_window_bits + self._decoder = None + self._encoder = None + self._decode_cont_data = False + + # wbits: +9 to +15 + # The base-two logarithm of the window size, which therefore ranges between 512 and 32768. + # Larger values produce better compression at the expense of greater memory usage. + # The resulting output will include a zlib-specific header and trailer. + # Negative wbits: + # Uses the absolute value of wbits as the window size logarithm, + # while producing a raw output stream with no header or trailing checksum. + + if not self.remote_no_context_takeover: + self._decoder = zlib.decompressobj(wbits=-self.remote_max_window_bits) + if not self.local_no_context_takeover: + self._encoder = zlib.compressobj(wbits=-self.local_max_window_bits) + + @classmethod + def from_response_header(cls, header_value: str) -> _PerMessageDeflate: + extensions = [item.strip() for item in header_value.split(",") if item.strip()] + if len(extensions) != 1: + raise _CompressionError("unsupported websocket extension negotiation") + + parts = [item.strip() for item in extensions[0].split(";")] + if not parts or parts[0] != "permessage-deflate": + raise _CompressionError("unsupported websocket extension negotiation") + + server_no_context_takeover = False + client_no_context_takeover = False + server_max_window_bits = None + client_max_window_bits = None + seen = set() + + for raw_param in parts[1:]: + if not raw_param: + continue + if "=" in raw_param: + name, value = raw_param.split("=", 1) + name = name.strip() + value = value.strip() + else: + name = raw_param + value = None + + if name in seen: + raise _CompressionError(f"duplicate extension parameter: {name}") + seen.add(name) + + if name == "server_no_context_takeover": + if value is not None: + raise _CompressionError("invalid server_no_context_takeover value") + server_no_context_takeover = True + elif name == "client_no_context_takeover": + if value is not None: + raise _CompressionError("invalid client_no_context_takeover value") + client_no_context_takeover = True + elif name == "server_max_window_bits": + if value is None or not value.isdigit(): + raise _CompressionError("invalid server_max_window_bits value") + server_max_window_bits = int(value) + if not 8 <= server_max_window_bits <= 15: + raise _CompressionError("invalid server_max_window_bits value") + elif name == "client_max_window_bits": + if value is None or not value.isdigit(): + raise _CompressionError("invalid client_max_window_bits value") + client_max_window_bits = int(value) + if not 8 <= client_max_window_bits <= 15: + raise _CompressionError("invalid client_max_window_bits value") + else: + raise _CompressionError(f"unsupported extension parameter: {name}") + + return cls( + remote_no_context_takeover=server_no_context_takeover, + local_no_context_takeover=client_no_context_takeover, + remote_max_window_bits=server_max_window_bits or 15, + local_max_window_bits=client_max_window_bits or 15, + ) + + @cython.cfunc + @cython.inline + def decode_frame(self, frame: WSFrame, max_length: cython.Py_ssize_t) -> bytes: + data: bytes + data2: bytes + + if frame.msg_type == WSMsgType.CONTINUATION: + if frame.rsv1: + raise _CompressionError("unexpected rsv1 on continuation frame") + if not self._decode_cont_data: + return frame.get_payload_as_bytes() + if frame.fin: + self._decode_cont_data = False + else: + if not frame.rsv1: + return frame.get_payload_as_bytes() + if not frame.fin: + self._decode_cont_data = True + if self.remote_no_context_takeover or self._decoder is None: + self._decoder = zlib.decompressobj(wbits=-self.remote_max_window_bits) + + assert self._decoder is not None + try: + data = self._decoder.decompress(frame.get_payload_as_memoryview(), max_length) + + if self._decoder.unconsumed_tail: + raise _CompressionError("message too big") + + if frame.fin: + data2 = self._decoder.decompress(_EMPTY_UNCOMPRESSED_BLOCK, max_length) + if data2: + data += data2 + except zlib.error as exc: + raise _CompressionError("decompression failed") from exc + + if frame.fin and self.remote_no_context_takeover: + self._decoder = None + + return data + + @cython.cfunc + @cython.inline + def encode_frame(self, msg_type: WSMsgType, payload: BytesLike, fin: cython.bint) -> tuple[BytesLike, cython.bint]: + if msg_type != WSMsgType.CONTINUATION and (self.local_no_context_takeover or self._encoder is None): + self._encoder = zlib.compressobj(wbits=-self.local_max_window_bits) + + data: BytesLike = (self._encoder.compress(payload) + + self._encoder.flush(zlib.Z_SYNC_FLUSH)) + if fin: + data_mv = memoryview(data) + assert data_mv[-4:] == _EMPTY_UNCOMPRESSED_BLOCK + data = data_mv[:-4] + if self.local_no_context_takeover: + self._encoder = None + + return data, msg_type != WSMsgType.CONTINUATION + + @cython.cfunc @cython.inline def _coerce_close_code(code: CloseCode) -> Optional[int]: @@ -124,10 +288,13 @@ class ClientConnection(WSListener): # type: ignore[misc] id: uuid.UUID logger: Union[logging.Logger, logging.LoggerAdapter[Any]] transport: WSTransport - request: Request - response: Response + _request: Request + _response: Response + _connect_exception: Optional[Exception] _subprotocols: Optional[Sequence[Subprotocol]] _subprotocol: Optional[Subprotocol] + _compression: Optional[str] + _permessage_deflate: Optional[_PerMessageDeflate] _state: State _close_exc: Optional[ConnectionClosed] _loop: asyncio.AbstractEventLoop @@ -168,14 +335,18 @@ def __init__( max_message_size: Optional[int] = 1024 * 1024, logger: LoggerLike = None, subprotocols: Optional[Sequence[Subprotocol]] = None, + compression: Optional[str] = None, ): self.id = uuid.uuid4() self.logger = _resolve_logger(logger) self.transport = cython.cast(WSTransport, None) - self.request = None # type: ignore[assignment] - self.response = None # type: ignore[assignment] + self._request = None # type: ignore[assignment] + self._response = None # type: ignore[assignment] + self._connect_exception = None self._subprotocols = subprotocols self._subprotocol = None + self._compression = compression + self._permessage_deflate = None self._state = State.CONNECTING self._close_exc: Optional[ConnectionClosed] = None self._loop = asyncio.get_running_loop() @@ -205,9 +376,16 @@ def __init__( @cython.ccall def on_ws_connected(self, transport: WSTransport) -> None: self.transport = transport - self.request = transport.request - self.response = transport.response - self._subprotocol = _resolve_subprotocol(self._subprotocols, self.response) + self._request = transport.request + self._response = transport.response + try: + self._subprotocol = _resolve_subprotocol(self._subprotocols, self._response) + self._configure_extensions() + except InvalidHandshake as exc: + self._connect_exception = exc + self.transport.send_close(WSCloseCode.PROTOCOL_ERROR, str(exc).encode("utf-8")) + self.transport.disconnect(False) + return self._state = State.OPEN self._set_write_limits(self._write_limit) if self._ping_interval is not None and self._keepalive_task is None: @@ -232,54 +410,82 @@ def on_ws_disconnected(self, transport: WSTransport) -> None: waiter.set_exception(self._close_exc or ConnectionClosedError(None, None, None)) self._pending_pings.clear() + @cython.cfunc + @cython.inline + def _process_pong_frame(self, frame: WSFrame) -> None: + ping = self._pending_pings.pop(frame.get_payload_as_bytes(), None) + if ping is not None: + waiter, sent_at = ping + self._latency = monotonic() - sent_at + if not waiter.done(): + waiter.set_result(self._latency) + + @cython.cfunc + @cython.inline + def _process_close_frame(self, frame: WSFrame) -> None: + close_code = frame.get_close_code() + close_message = frame.get_close_message() + self.transport.send_close(close_code, close_message) + self.transport.disconnect() + self._state = State.CLOSING + @cython.ccall def on_ws_frame(self, transport: WSTransport, frame: WSFrame) -> None: if frame.msg_type == WSMsgType.PONG: - ping = self._pending_pings.pop(frame.get_payload_as_bytes(), None) - if ping is not None: - waiter, sent_at = ping - self._latency = monotonic() - sent_at - if not waiter.done(): - waiter.set_result(self._latency) + self._process_pong_frame(frame) return if frame.msg_type == WSMsgType.CLOSE: - close_code = frame.get_close_code() - close_message = frame.get_close_message() - self.transport.send_close(close_code, close_message) - self.transport.disconnect() - self._state = State.CLOSING + self._process_close_frame(frame) return - if frame.msg_type == WSMsgType.CONTINUATION: - if not self._incoming_message_active: - self._fail_protocol_error("unexpected continuation frame") - return + if frame.msg_type not in (WSMsgType.TEXT, WSMsgType.BINARY, WSMsgType.CONTINUATION): + self._fail_protocol_error("unsupported frame opcode") + return + + if self._permessage_deflate is None and frame.rsv1: + self._fail_protocol_error("received compressed frame without negotiated permessage-deflate") + return - self._incoming_message_size += frame.payload_size - if self._max_message_size > 0 and self._incoming_message_size > self._max_message_size: - self._fail_message_too_big("message too big") + if frame.msg_type == WSMsgType.CONTINUATION and not self._incoming_message_active: + self._fail_protocol_error("unexpected continuation frame") + return + + if frame.msg_type != WSMsgType.CONTINUATION and self._incoming_message_active: + self._fail_protocol_error("expected continuation frame") + return + + payload: bytes + if self._permessage_deflate is not None: + remaining = 0 if self._max_message_size == 0 else ( + max(self._max_message_size - self._incoming_message_size, 0)) + try: + payload = self._permessage_deflate.decode_frame(frame, remaining) + except _CompressionError as exc: + if str(exc) == "message too big": + self._fail_message_too_big("message too big") + else: + self._fail_protocol_error(str(exc)) return + else: + payload = frame.get_payload_as_bytes() + + self._incoming_message_size = len(payload) + if self._max_message_size > 0 and self._incoming_message_size > self._max_message_size: + self._fail_message_too_big("message too big") + return + + if frame.msg_type == WSMsgType.CONTINUATION: if frame.fin: self._incoming_message_active = False self._incoming_message_size = 0 - elif frame.msg_type in (WSMsgType.TEXT, WSMsgType.BINARY): - if self._incoming_message_active: - self._fail_protocol_error("expected continuation frame") - return - self._incoming_message_size = frame.payload_size - if self._max_message_size > 0 and self._incoming_message_size > self._max_message_size: - self._fail_message_too_big("message too big") - return + else: if frame.fin: self._incoming_message_size = 0 else: self._incoming_message_active = True - else: - self._fail_protocol_error(f"unexpected opcode while receiving message: {frame.msg_type}") - return - self._add_to_recv_queue(_BufferedFrame(frame.msg_type, frame.get_payload_as_bytes(), frame.fin)) + self._add_to_recv_queue(_BufferedFrame(frame.msg_type, payload, frame.fin)) self._pause_reading_if_needed() @cython.ccall @@ -337,6 +543,21 @@ def _wait_recv_queue_not_empty(self) -> asyncio.Future[None]: self._recv_waiter = waiter return waiter + @cython.cfunc + @cython.inline + def _configure_extensions(self) -> None: + header_value = self._response.headers.get("Sec-WebSocket-Extensions") + if header_value is None: + return + if self._compression != "deflate": + raise InvalidHandshake("unexpected websocket extensions negotiated by server") + if not isinstance(header_value, str): + raise InvalidHandshake("invalid Sec-WebSocket-Extensions header") + try: + self._permessage_deflate = _PerMessageDeflate.from_response_header(header_value) + except _CompressionError as exc: + raise InvalidHandshake(str(exc)) from exc + @cython.cfunc @cython.inline def _set_close_exception(self) -> None: @@ -496,11 +717,24 @@ async def iterator() -> AsyncIterator[Data]: return iterator() + def _encode_and_send(self, msg_type: WSMsgType, message: Data, fin: cython.bint) -> None: + rsv1: cython.bint = False + if self._permessage_deflate is not None: + message, rsv1 = self._permessage_deflate.encode_frame( + msg_type, self._compression_payload(message), fin + ) + + self.transport.send(msg_type, message, fin, rsv1) + async def send( self, message: Union[DataLike, Iterable[DataLike], AsyncIterator[DataLike]], text: Optional[bool] = None, ) -> None: + # Catch a common mistake -- passing a dict to send(). + if isinstance(message, Mapping): + raise TypeError("data is a dict-like object") + if self._state is State.CLOSED: raise self._connection_closed() @@ -515,7 +749,9 @@ async def send( msg_type = WSMsgType.BINARY if text is False else WSMsgType.TEXT else: msg_type = WSMsgType.TEXT if text else WSMsgType.BINARY - self.transport.send(msg_type, message) + + self._encode_and_send(msg_type, message, True) + if self._write_ready is not None: await self._write_ready elif isinstance(message, (AsyncIterable, Iterable)): @@ -535,6 +771,13 @@ def _check_fragment_type(self, message: DataLike, first_is_str: cython.bint) -> raise TypeError("all fragments must be of the same category: str vs bytes-like") + @cython.cfunc + @cython.inline + def _compression_payload(self, message: DataLike) -> BytesLike: + if isinstance(message, str): + return message.encode("utf-8") + return message + async def _send_fragments( self, messages: Union[AsyncIterable[DataLike], Iterable[DataLike]], @@ -559,24 +802,30 @@ async def _send_fragments( try: try: if is_async: - first = await anext(async_iterator) + current = await anext(async_iterator) else: - first = next(iterator) + current = next(iterator) except stop_exception_type: raise TypeError("message iterable cannot be empty") from None first_is_str: cython.bint - if isinstance(first, str): + if isinstance(current, str): msg_type = WSMsgType.BINARY if text is False else WSMsgType.TEXT first_is_str = True - elif isinstance(first, (bytes, bytearray, memoryview)): + elif isinstance(current, (bytes, bytearray, memoryview)): msg_type = WSMsgType.TEXT if text else WSMsgType.BINARY first_is_str = False else: - raise TypeError(f"message must contain str or bytes-like objects, got {type(first).__name__}") + raise TypeError(f"message must contain str or bytes-like objects, got {type(current).__name__}") - previous = first while True: + # Original websockets implementations always send one last empty + # frame with fin=True even if iterator returns only one fragment + # Perhaps this is useful for the users, just replicate this + # behavior. + self._encode_and_send(msg_type, current, False) + msg_type = WSMsgType.CONTINUATION + try: if is_async: current = await anext(async_iterator) @@ -586,18 +835,14 @@ async def _send_fragments( break self._check_fragment_type(current, first_is_str) - - self.transport.send(msg_type, previous, fin=False) - msg_type = WSMsgType.CONTINUATION if self._write_ready is not None: await self._write_ready - previous = current - - self.transport.send(msg_type, previous, fin=True) + # Send the last empty frame with fin=True + self._encode_and_send(msg_type, b"", True) if self._write_ready is not None: await self._write_ready - except: + except Exception: self._fail_protocol_error("error in fragmented message") raise @@ -682,6 +927,18 @@ async def _iterate_messages(self) -> AsyncIterator[Data]: def state(self) -> State: return self._state + @property + def request(self) -> Request: + return self._request + + @property + def response(self) -> Response: + return self._response + + @property + def connect_exception(self) -> Optional[Exception]: + return self._connect_exception + @property def local_address(self) -> Any: return self.transport.underlying_transport.get_extra_info("sockname") diff --git a/tests/test_websockets_compat.py b/tests/test_websockets_compat.py index cad5ab3..010ee03 100644 --- a/tests/test_websockets_compat.py +++ b/tests/test_websockets_compat.py @@ -52,7 +52,7 @@ async def test_recv_streaming_fragmented_message(): fragments = [] async for fragment in ws.recv_streaming(): fragments.append(fragment) - assert fragments == [b"ab", b"cd"] + assert fragments == [b"ab", b"cd", b""] async def test_subprotocol_header_and_property(): diff --git a/tests/test_websockets_compression_compat.py b/tests/test_websockets_compression_compat.py new file mode 100644 index 0000000..954e278 --- /dev/null +++ b/tests/test_websockets_compression_compat.py @@ -0,0 +1,61 @@ +from contextlib import asynccontextmanager + +import websockets as upstream_websockets + +from picows import websockets + + +@asynccontextmanager +async def upstream_server(handler): + server = await upstream_websockets.serve( + handler, + "127.0.0.1", + 0, + compression="deflate", + ) + port = server.sockets[0].getsockname()[1] + try: + yield f"ws://127.0.0.1:{port}/" + finally: + server.close() + await server.wait_closed() + + +async def test_permessage_deflate_echo_with_upstream_server(): + async def handler(ws): + async for message in ws: + await ws.send(message) + + async with upstream_server(handler) as url: + async with websockets.connect(url, ping_interval=None) as ws: + assert "permessage-deflate" in (ws.response.headers.get("Sec-WebSocket-Extensions") or "") + + message = "hello " * 1000 + await ws.send(message) + assert await ws.recv() == message + + +async def test_permessage_deflate_fragmented_send_with_upstream_server(): + async def handler(ws): + async for message in ws: + await ws.send(message) + + async with upstream_server(handler) as url: + async with websockets.connect(url, ping_interval=None) as ws: + await ws.send([b"a" * 300, b"b" * 300, b"c" * 300]) + assert await ws.recv() == (b"a" * 300 + b"b" * 300 + b"c" * 300) + + +async def test_permessage_deflate_recv_streaming_from_upstream_server(): + chunks = [b"ab" * 300, b"cd" * 300, b"ef" * 300] + + async def handler(ws): + await ws.send(chunks) + await ws.close() + + async with upstream_server(handler) as url: + async with websockets.connect(url, ping_interval=None) as ws: + fragments = [] + async for fragment in ws.recv_streaming(): + fragments.append(fragment) + assert fragments == chunks + [b""] From f008c78a29faf50e41960a8451211f64058df4bd Mon Sep 17 00:00:00 2001 From: taras Date: Mon, 4 May 2026 15:30:08 +0200 Subject: [PATCH 26/57] Optimizations --- picows/websockets/asyncio/connection.py | 106 ++++++++++++------------ 1 file changed, 54 insertions(+), 52 deletions(-) diff --git a/picows/websockets/asyncio/connection.py b/picows/websockets/asyncio/connection.py index 2d8b9b3..6fe1458 100644 --- a/picows/websockets/asyncio/connection.py +++ b/picows/websockets/asyncio/connection.py @@ -33,7 +33,7 @@ OK_CLOSE_CODES = {0, 1000, 1001} -_EMPTY_UNCOMPRESSED_BLOCK = b"\x00\x00\xff\xff" +_EMPTY_UNCOMPRESSED_BLOCK = cython.declare(bytes, b"\x00\x00\xff\xff") class State(IntEnum): @@ -43,61 +43,43 @@ class State(IntEnum): CLOSED = 3 +@cython.freelist(128) +@cython.no_gc @cython.cclass class _BufferedFrame: msg_type: WSMsgType payload: bytes - fin: bool + fin: cython.bint - def __init__(self, msg_type: WSMsgType, payload: bytes, fin: bool): - self.msg_type = msg_type - self.payload = payload - self.fin = fin + +@cython.cfunc +@cython.inline +def _make_buffered_frame(msg_type: WSMsgType, payload: bytes, fin: cython.bint) -> _BufferedFrame: + self: _BufferedFrame = _BufferedFrame.__new__(_BufferedFrame) + self.msg_type = msg_type + self.payload = payload + self.fin = fin + return self class _CompressionError(Exception): pass +@cython.no_gc @cython.cclass class _PerMessageDeflate: remote_no_context_takeover: cython.bint local_no_context_takeover: cython.bint - remote_max_window_bits: cython.int - local_max_window_bits: cython.int + remote_max_window_bits: int + local_max_window_bits: int + _zlib_compressobj: Any + _zlib_decompressobj: Any + _zlib_z_sync_flush: Any _decoder: Any _encoder: Any _decode_cont_data: cython.int - def __init__( - self, - *, - remote_no_context_takeover: bool, - local_no_context_takeover: bool, - remote_max_window_bits: int, - local_max_window_bits: int, - ): - self.remote_no_context_takeover = remote_no_context_takeover - self.local_no_context_takeover = local_no_context_takeover - self.remote_max_window_bits = remote_max_window_bits - self.local_max_window_bits = local_max_window_bits - self._decoder = None - self._encoder = None - self._decode_cont_data = False - - # wbits: +9 to +15 - # The base-two logarithm of the window size, which therefore ranges between 512 and 32768. - # Larger values produce better compression at the expense of greater memory usage. - # The resulting output will include a zlib-specific header and trailer. - # Negative wbits: - # Uses the absolute value of wbits as the window size logarithm, - # while producing a raw output stream with no header or trailing checksum. - - if not self.remote_no_context_takeover: - self._decoder = zlib.decompressobj(wbits=-self.remote_max_window_bits) - if not self.local_no_context_takeover: - self._encoder = zlib.compressobj(wbits=-self.local_max_window_bits) - @classmethod def from_response_header(cls, header_value: str) -> _PerMessageDeflate: extensions = [item.strip() for item in header_value.split(",") if item.strip()] @@ -152,12 +134,32 @@ def from_response_header(cls, header_value: str) -> _PerMessageDeflate: else: raise _CompressionError(f"unsupported extension parameter: {name}") - return cls( - remote_no_context_takeover=server_no_context_takeover, - local_no_context_takeover=client_no_context_takeover, - remote_max_window_bits=server_max_window_bits or 15, - local_max_window_bits=client_max_window_bits or 15, - ) + self: _PerMessageDeflate = _PerMessageDeflate.__new__(_PerMessageDeflate) + self.remote_no_context_takeover = server_no_context_takeover + self.local_no_context_takeover = client_no_context_takeover + self.remote_max_window_bits = -(server_max_window_bits or 15) + self.local_max_window_bits = -(client_max_window_bits or 15) + self._zlib_compressobj = zlib.compressobj + self._zlib_decompressobj = zlib.decompressobj + self._zlib_z_sync_flush = zlib.Z_SYNC_FLUSH + self._decoder = None + self._encoder = None + self._decode_cont_data = False + + # wbits: +9 to +15 + # The base-two logarithm of the window size, which therefore ranges between 512 and 32768. + # Larger values produce better compression at the expense of greater memory usage. + # The resulting output will include a zlib-specific header and trailer. + # Negative wbits: + # Uses the absolute value of wbits as the window size logarithm, + # while producing a raw output stream with no header or trailing checksum. + + if not self.remote_no_context_takeover: + self._decoder = self._zlib_decompressobj(wbits=self.remote_max_window_bits) + if not self.local_no_context_takeover: + self._encoder = self._zlib_compressobj(wbits=self.local_max_window_bits) + + return self @cython.cfunc @cython.inline @@ -178,7 +180,7 @@ def decode_frame(self, frame: WSFrame, max_length: cython.Py_ssize_t) -> bytes: if not frame.fin: self._decode_cont_data = True if self.remote_no_context_takeover or self._decoder is None: - self._decoder = zlib.decompressobj(wbits=-self.remote_max_window_bits) + self._decoder = self._zlib_decompressobj(wbits=self.remote_max_window_bits) assert self._decoder is not None try: @@ -201,12 +203,13 @@ def decode_frame(self, frame: WSFrame, max_length: cython.Py_ssize_t) -> bytes: @cython.cfunc @cython.inline - def encode_frame(self, msg_type: WSMsgType, payload: BytesLike, fin: cython.bint) -> tuple[BytesLike, cython.bint]: + @cython.wraparound(True) + def encode_frame(self, msg_type: WSMsgType, payload: BytesLike, fin: cython.bint) -> BytesLike: if msg_type != WSMsgType.CONTINUATION and (self.local_no_context_takeover or self._encoder is None): - self._encoder = zlib.compressobj(wbits=-self.local_max_window_bits) + self._encoder = self._zlib_compressobj(wbits=self.local_max_window_bits) data: BytesLike = (self._encoder.compress(payload) + - self._encoder.flush(zlib.Z_SYNC_FLUSH)) + self._encoder.flush(self._zlib_z_sync_flush)) if fin: data_mv = memoryview(data) assert data_mv[-4:] == _EMPTY_UNCOMPRESSED_BLOCK @@ -214,7 +217,7 @@ def encode_frame(self, msg_type: WSMsgType, payload: BytesLike, fin: cython.bint if self.local_no_context_takeover: self._encoder = None - return data, msg_type != WSMsgType.CONTINUATION + return data @cython.cfunc @@ -485,7 +488,7 @@ def on_ws_frame(self, transport: WSTransport, frame: WSFrame) -> None: else: self._incoming_message_active = True - self._add_to_recv_queue(_BufferedFrame(frame.msg_type, payload, frame.fin)) + self._add_to_recv_queue(_make_buffered_frame(frame.msg_type, payload, frame.fin)) self._pause_reading_if_needed() @cython.ccall @@ -718,13 +721,12 @@ async def iterator() -> AsyncIterator[Data]: return iterator() def _encode_and_send(self, msg_type: WSMsgType, message: Data, fin: cython.bint) -> None: - rsv1: cython.bint = False if self._permessage_deflate is not None: - message, rsv1 = self._permessage_deflate.encode_frame( + message = self._permessage_deflate.encode_frame( msg_type, self._compression_payload(message), fin ) - self.transport.send(msg_type, message, fin, rsv1) + self.transport.send(msg_type, message, fin, msg_type != WSMsgType.CONTINUATION) async def send( self, From 27c5f2f4d766c2b9fbc88753cfc4e6ecce34cf23 Mon Sep 17 00:00:00 2001 From: taras Date: Mon, 4 May 2026 17:02:40 +0200 Subject: [PATCH 27/57] Fix mypy --- picows/websockets/asyncio/connection.py | 38 ++++++++++++------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/picows/websockets/asyncio/connection.py b/picows/websockets/asyncio/connection.py index 6fe1458..09674b4 100644 --- a/picows/websockets/asyncio/connection.py +++ b/picows/websockets/asyncio/connection.py @@ -33,7 +33,6 @@ OK_CLOSE_CODES = {0, 1000, 1001} -_EMPTY_UNCOMPRESSED_BLOCK = cython.declare(bytes, b"\x00\x00\xff\xff") class State(IntEnum): @@ -66,6 +65,13 @@ class _CompressionError(Exception): pass +# zlib/compress/decompress utils, cached for performance +_empty_uncompressed_block = cython.declare(bytes, b"\x00\x00\xff\xff") +_zlib_compressobj = cython.declare(object, zlib.compressobj) +_zlib_decompressobj = cython.declare(object, zlib.decompressobj) +_zlib_z_sync_flush = cython.declare(object, zlib.Z_SYNC_FLUSH) + + @cython.no_gc @cython.cclass class _PerMessageDeflate: @@ -73,12 +79,9 @@ class _PerMessageDeflate: local_no_context_takeover: cython.bint remote_max_window_bits: int local_max_window_bits: int - _zlib_compressobj: Any - _zlib_decompressobj: Any - _zlib_z_sync_flush: Any _decoder: Any _encoder: Any - _decode_cont_data: cython.int + _decode_cont_data: cython.bint @classmethod def from_response_header(cls, header_value: str) -> _PerMessageDeflate: @@ -139,9 +142,6 @@ def from_response_header(cls, header_value: str) -> _PerMessageDeflate: self.local_no_context_takeover = client_no_context_takeover self.remote_max_window_bits = -(server_max_window_bits or 15) self.local_max_window_bits = -(client_max_window_bits or 15) - self._zlib_compressobj = zlib.compressobj - self._zlib_decompressobj = zlib.decompressobj - self._zlib_z_sync_flush = zlib.Z_SYNC_FLUSH self._decoder = None self._encoder = None self._decode_cont_data = False @@ -155,9 +155,9 @@ def from_response_header(cls, header_value: str) -> _PerMessageDeflate: # while producing a raw output stream with no header or trailing checksum. if not self.remote_no_context_takeover: - self._decoder = self._zlib_decompressobj(wbits=self.remote_max_window_bits) + self._decoder = _zlib_decompressobj(wbits=self.remote_max_window_bits) if not self.local_no_context_takeover: - self._encoder = self._zlib_compressobj(wbits=self.local_max_window_bits) + self._encoder = _zlib_compressobj(wbits=self.local_max_window_bits) return self @@ -171,16 +171,16 @@ def decode_frame(self, frame: WSFrame, max_length: cython.Py_ssize_t) -> bytes: if frame.rsv1: raise _CompressionError("unexpected rsv1 on continuation frame") if not self._decode_cont_data: - return frame.get_payload_as_bytes() + return frame.get_payload_as_bytes() # type: ignore[no-any-return] if frame.fin: self._decode_cont_data = False else: if not frame.rsv1: - return frame.get_payload_as_bytes() + return frame.get_payload_as_bytes() # type: ignore[no-any-return] if not frame.fin: self._decode_cont_data = True if self.remote_no_context_takeover or self._decoder is None: - self._decoder = self._zlib_decompressobj(wbits=self.remote_max_window_bits) + self._decoder = _zlib_decompressobj(wbits=self.remote_max_window_bits) assert self._decoder is not None try: @@ -190,7 +190,7 @@ def decode_frame(self, frame: WSFrame, max_length: cython.Py_ssize_t) -> bytes: raise _CompressionError("message too big") if frame.fin: - data2 = self._decoder.decompress(_EMPTY_UNCOMPRESSED_BLOCK, max_length) + data2 = self._decoder.decompress(_empty_uncompressed_block, max_length) if data2: data += data2 except zlib.error as exc: @@ -206,13 +206,13 @@ def decode_frame(self, frame: WSFrame, max_length: cython.Py_ssize_t) -> bytes: @cython.wraparound(True) def encode_frame(self, msg_type: WSMsgType, payload: BytesLike, fin: cython.bint) -> BytesLike: if msg_type != WSMsgType.CONTINUATION and (self.local_no_context_takeover or self._encoder is None): - self._encoder = self._zlib_compressobj(wbits=self.local_max_window_bits) + self._encoder = _zlib_compressobj(wbits=self.local_max_window_bits) data: BytesLike = (self._encoder.compress(payload) + - self._encoder.flush(self._zlib_z_sync_flush)) + self._encoder.flush(_zlib_z_sync_flush)) if fin: data_mv = memoryview(data) - assert data_mv[-4:] == _EMPTY_UNCOMPRESSED_BLOCK + assert data_mv[-4:] == _empty_uncompressed_block data = data_mv[:-4] if self.local_no_context_takeover: self._encoder = None @@ -675,7 +675,7 @@ async def recv(self, decode: Optional[bool] = None) -> Data: payloads.append(frame.payload) payload = b"".join(payloads) - return self._decode_data(payload, msg_type, decode) + return self._decode_data(payload, msg_type, decode) # type: ignore[no-any-return] except asyncio.CancelledError: self._recv_queue.extendleft(reversed(frames)) raise @@ -720,7 +720,7 @@ async def iterator() -> AsyncIterator[Data]: return iterator() - def _encode_and_send(self, msg_type: WSMsgType, message: Data, fin: cython.bint) -> None: + def _encode_and_send(self, msg_type: WSMsgType, message: DataLike, fin: cython.bint) -> None: if self._permessage_deflate is not None: message = self._permessage_deflate.encode_frame( msg_type, self._compression_payload(message), fin From caf454f22ac70c470cf2a38e0148f2d5c3917658 Mon Sep 17 00:00:00 2001 From: taras Date: Mon, 4 May 2026 17:35:48 +0200 Subject: [PATCH 28/57] Optimizations --- picows/websockets/asyncio/connection.py | 78 ++++++++++--------------- 1 file changed, 30 insertions(+), 48 deletions(-) diff --git a/picows/websockets/asyncio/connection.py b/picows/websockets/asyncio/connection.py index 09674b4..8110d30 100644 --- a/picows/websockets/asyncio/connection.py +++ b/picows/websockets/asyncio/connection.py @@ -20,6 +20,8 @@ else: from picows import WSListener, WSTransport, WSFrame, WSMsgType, WSCloseCode +from picows import WSProtocolError + from ..compat import CloseCode, Request, Response from ..exceptions import ( ConcurrencyError, @@ -61,10 +63,6 @@ def _make_buffered_frame(msg_type: WSMsgType, payload: bytes, fin: cython.bint) return self -class _CompressionError(Exception): - pass - - # zlib/compress/decompress utils, cached for performance _empty_uncompressed_block = cython.declare(bytes, b"\x00\x00\xff\xff") _zlib_compressobj = cython.declare(object, zlib.compressobj) @@ -87,11 +85,11 @@ class _PerMessageDeflate: def from_response_header(cls, header_value: str) -> _PerMessageDeflate: extensions = [item.strip() for item in header_value.split(",") if item.strip()] if len(extensions) != 1: - raise _CompressionError("unsupported websocket extension negotiation") + raise InvalidHandshake("unsupported websocket extension negotiation") parts = [item.strip() for item in extensions[0].split(";")] if not parts or parts[0] != "permessage-deflate": - raise _CompressionError("unsupported websocket extension negotiation") + raise InvalidHandshake("unsupported websocket extension negotiation") server_no_context_takeover = False client_no_context_takeover = False @@ -111,31 +109,32 @@ def from_response_header(cls, header_value: str) -> _PerMessageDeflate: value = None if name in seen: - raise _CompressionError(f"duplicate extension parameter: {name}") + raise InvalidHandshake( + f"unsupported websocket extension negotiation: {name}") seen.add(name) if name == "server_no_context_takeover": if value is not None: - raise _CompressionError("invalid server_no_context_takeover value") + raise InvalidHandshake("invalid server_no_context_takeover value") server_no_context_takeover = True elif name == "client_no_context_takeover": if value is not None: - raise _CompressionError("invalid client_no_context_takeover value") + raise InvalidHandshake("invalid client_no_context_takeover value") client_no_context_takeover = True elif name == "server_max_window_bits": if value is None or not value.isdigit(): - raise _CompressionError("invalid server_max_window_bits value") + raise InvalidHandshake("invalid server_max_window_bits value") server_max_window_bits = int(value) if not 8 <= server_max_window_bits <= 15: - raise _CompressionError("invalid server_max_window_bits value") + raise InvalidHandshake("invalid server_max_window_bits value") elif name == "client_max_window_bits": if value is None or not value.isdigit(): - raise _CompressionError("invalid client_max_window_bits value") + raise InvalidHandshake("invalid client_max_window_bits value") client_max_window_bits = int(value) if not 8 <= client_max_window_bits <= 15: - raise _CompressionError("invalid client_max_window_bits value") + raise InvalidHandshake("invalid client_max_window_bits value") else: - raise _CompressionError(f"unsupported extension parameter: {name}") + raise InvalidHandshake(f"unsupported extension parameter: {name}") self: _PerMessageDeflate = _PerMessageDeflate.__new__(_PerMessageDeflate) self.remote_no_context_takeover = server_no_context_takeover @@ -169,7 +168,7 @@ def decode_frame(self, frame: WSFrame, max_length: cython.Py_ssize_t) -> bytes: if frame.msg_type == WSMsgType.CONTINUATION: if frame.rsv1: - raise _CompressionError("unexpected rsv1 on continuation frame") + raise WSProtocolError(WSCloseCode.PROTOCOL_ERROR, "unexpected rsv1 on continuation frame") if not self._decode_cont_data: return frame.get_payload_as_bytes() # type: ignore[no-any-return] if frame.fin: @@ -185,16 +184,19 @@ def decode_frame(self, frame: WSFrame, max_length: cython.Py_ssize_t) -> bytes: assert self._decoder is not None try: data = self._decoder.decompress(frame.get_payload_as_memoryview(), max_length) + max_length -= len(data) if self._decoder.unconsumed_tail: - raise _CompressionError("message too big") + raise WSProtocolError(WSCloseCode.MESSAGE_TOO_BIG, + "message too big") if frame.fin: data2 = self._decoder.decompress(_empty_uncompressed_block, max_length) if data2: data += data2 except zlib.error as exc: - raise _CompressionError("decompression failed") from exc + raise WSProtocolError(WSCloseCode.PROTOCOL_ERROR, + "decompression failed") from exc if frame.fin and self.remote_no_context_takeover: self._decoder = None @@ -386,7 +388,7 @@ def on_ws_connected(self, transport: WSTransport) -> None: self._configure_extensions() except InvalidHandshake as exc: self._connect_exception = exc - self.transport.send_close(WSCloseCode.PROTOCOL_ERROR, str(exc).encode("utf-8")) + self.transport.send_close(WSCloseCode.PROTOCOL_ERROR, str(exc)) self.transport.disconnect(False) return self._state = State.OPEN @@ -443,40 +445,28 @@ def on_ws_frame(self, transport: WSTransport, frame: WSFrame) -> None: return if frame.msg_type not in (WSMsgType.TEXT, WSMsgType.BINARY, WSMsgType.CONTINUATION): - self._fail_protocol_error("unsupported frame opcode") - return + raise WSProtocolError(WSCloseCode.PROTOCOL_ERROR, "unsupported frame opcode") if self._permessage_deflate is None and frame.rsv1: - self._fail_protocol_error("received compressed frame without negotiated permessage-deflate") - return + raise WSProtocolError(WSCloseCode.PROTOCOL_ERROR, "received compressed frame without negotiated permessage-deflate") if frame.msg_type == WSMsgType.CONTINUATION and not self._incoming_message_active: - self._fail_protocol_error("unexpected continuation frame") - return + raise WSProtocolError(WSCloseCode.PROTOCOL_ERROR, "unexpected continuation frame") if frame.msg_type != WSMsgType.CONTINUATION and self._incoming_message_active: - self._fail_protocol_error("expected continuation frame") - return + raise WSProtocolError(WSCloseCode.PROTOCOL_ERROR, "expected continuation frame") payload: bytes if self._permessage_deflate is not None: remaining = 0 if self._max_message_size == 0 else ( max(self._max_message_size - self._incoming_message_size, 0)) - try: - payload = self._permessage_deflate.decode_frame(frame, remaining) - except _CompressionError as exc: - if str(exc) == "message too big": - self._fail_message_too_big("message too big") - else: - self._fail_protocol_error(str(exc)) - return + payload = self._permessage_deflate.decode_frame(frame, remaining) else: payload = frame.get_payload_as_bytes() self._incoming_message_size = len(payload) if self._max_message_size > 0 and self._incoming_message_size > self._max_message_size: - self._fail_message_too_big("message too big") - return + raise WSProtocolError(WSCloseCode.MESSAGE_TOO_BIG, "message too big") if frame.msg_type == WSMsgType.CONTINUATION: if frame.fin: @@ -556,10 +546,8 @@ def _configure_extensions(self) -> None: raise InvalidHandshake("unexpected websocket extensions negotiated by server") if not isinstance(header_value, str): raise InvalidHandshake("invalid Sec-WebSocket-Extensions header") - try: - self._permessage_deflate = _PerMessageDeflate.from_response_header(header_value) - except _CompressionError as exc: - raise InvalidHandshake(str(exc)) from exc + + self._permessage_deflate = _PerMessageDeflate.from_response_header(header_value) @cython.cfunc @cython.inline @@ -590,13 +578,7 @@ def _connection_closed(self) -> ConnectionClosed: @cython.cfunc @cython.inline def _fail_protocol_error(self, message: str) -> None: - self.transport.send_close(WSCloseCode.PROTOCOL_ERROR, message.encode("utf-8")) - self.transport.disconnect(False) - - @cython.cfunc - @cython.inline - def _fail_message_too_big(self, message: str) -> None: - self.transport.send_close(WSCloseCode.MESSAGE_TOO_BIG, message.encode("utf-8")) + self.transport.send_close(WSCloseCode.PROTOCOL_ERROR, message) self.transport.disconnect(False) @cython.cfunc @@ -853,7 +835,7 @@ async def close(self, code: int = 1000, reason: str = "") -> None: return if self._state is State.OPEN: self._state = State.CLOSING - self.transport.send_close(code, reason.encode("utf-8")) + self.transport.send_close(code, reason) try: if self._close_timeout is None: await self.wait_closed() From fe03d7fb1113593a143d22a6ff7a65a6f01f3506 Mon Sep 17 00:00:00 2001 From: taras Date: Mon, 4 May 2026 17:55:41 +0200 Subject: [PATCH 29/57] Fix tests --- picows/websockets/asyncio/connection.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/picows/websockets/asyncio/connection.py b/picows/websockets/asyncio/connection.py index 8110d30..39b6c49 100644 --- a/picows/websockets/asyncio/connection.py +++ b/picows/websockets/asyncio/connection.py @@ -707,8 +707,9 @@ def _encode_and_send(self, msg_type: WSMsgType, message: DataLike, fin: cython.b message = self._permessage_deflate.encode_frame( msg_type, self._compression_payload(message), fin ) - - self.transport.send(msg_type, message, fin, msg_type != WSMsgType.CONTINUATION) + self.transport.send(msg_type, message, fin, msg_type != WSMsgType.CONTINUATION) + else: + self.transport.send(msg_type, message, fin) async def send( self, From c5690c6f5afe6baf01b20a4381fc13e09410bab4 Mon Sep 17 00:00:00 2001 From: taras Date: Mon, 4 May 2026 18:05:48 +0200 Subject: [PATCH 30/57] Clean up logic --- picows/websockets/asyncio/connection.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/picows/websockets/asyncio/connection.py b/picows/websockets/asyncio/connection.py index 39b6c49..1c44791 100644 --- a/picows/websockets/asyncio/connection.py +++ b/picows/websockets/asyncio/connection.py @@ -181,14 +181,14 @@ def decode_frame(self, frame: WSFrame, max_length: cython.Py_ssize_t) -> bytes: if self.remote_no_context_takeover or self._decoder is None: self._decoder = _zlib_decompressobj(wbits=self.remote_max_window_bits) - assert self._decoder is not None try: data = self._decoder.decompress(frame.get_payload_as_memoryview(), max_length) - max_length -= len(data) + if max_length > 0: + max_length -= len(data) - if self._decoder.unconsumed_tail: - raise WSProtocolError(WSCloseCode.MESSAGE_TOO_BIG, - "message too big") + if self._decoder.unconsumed_tail: + raise WSProtocolError(WSCloseCode.MESSAGE_TOO_BIG, + "message too big") if frame.fin: data2 = self._decoder.decompress(_empty_uncompressed_block, max_length) From 9ef22c6bf27fc935f151e760e46344a832e59fc8 Mon Sep 17 00:00:00 2001 From: taras Date: Mon, 4 May 2026 18:15:56 +0200 Subject: [PATCH 31/57] Optimize --- picows/websockets/asyncio/connection.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/picows/websockets/asyncio/connection.py b/picows/websockets/asyncio/connection.py index 1c44791..be94998 100644 --- a/picows/websockets/asyncio/connection.py +++ b/picows/websockets/asyncio/connection.py @@ -716,10 +716,6 @@ async def send( message: Union[DataLike, Iterable[DataLike], AsyncIterator[DataLike]], text: Optional[bool] = None, ) -> None: - # Catch a common mistake -- passing a dict to send(). - if isinstance(message, Mapping): - raise TypeError("data is a dict-like object") - if self._state is State.CLOSED: raise self._connection_closed() @@ -739,6 +735,9 @@ async def send( if self._write_ready is not None: await self._write_ready + # Catch a common mistake -- passing a dict to send(). + elif isinstance(message, Mapping): + raise TypeError("data is a dict-like object") elif isinstance(message, (AsyncIterable, Iterable)): await self._send_fragments(message, text) # type: ignore[arg-type] else: From 448c1a305b662f91b50e801c91a2565332ffacc1 Mon Sep 17 00:00:00 2001 From: taras Date: Mon, 4 May 2026 18:25:27 +0200 Subject: [PATCH 32/57] Cleanup --- picows/websockets/asyncio/connection.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/picows/websockets/asyncio/connection.py b/picows/websockets/asyncio/connection.py index be94998..e59ce3a 100644 --- a/picows/websockets/asyncio/connection.py +++ b/picows/websockets/asyncio/connection.py @@ -34,9 +34,6 @@ from ..typing import BytesLike, Data, DataLike, LoggerLike, Subprotocol -OK_CLOSE_CODES = {0, 1000, 1001} - - class State(IntEnum): CONNECTING = 0 OPEN = 1 @@ -63,6 +60,10 @@ def _make_buffered_frame(msg_type: WSMsgType, payload: bytes, fin: cython.bint) return self +# cached for performance +_ok_close_codes = cython.declare(set, {0, 1000, 1001}) + + # zlib/compress/decompress utils, cached for performance _empty_uncompressed_block = cython.declare(bytes, b"\x00\x00\xff\xff") _zlib_compressobj = cython.declare(object, zlib.compressobj) @@ -562,8 +563,8 @@ def _set_close_exception(self) -> None: rcvd_code = _coerce_close_code(rcvd.code) if rcvd is not None else None sent_code = _coerce_close_code(sent.code) if sent is not None else None ok = ( - (rcvd_code in OK_CLOSE_CODES or rcvd_code is None) - and (sent_code in OK_CLOSE_CODES or sent_code is None) + (rcvd_code in _ok_close_codes or rcvd_code is None) + and (sent_code in _ok_close_codes or sent_code is None) ) exc_type = ConnectionClosedOK if ok else ConnectionClosedError self._close_exc = exc_type(rcvd, sent, rcvd_then_sent) From b40b54a2464c4ef0eed24f0775e17e314ff07c18 Mon Sep 17 00:00:00 2001 From: taras Date: Mon, 4 May 2026 18:52:19 +0200 Subject: [PATCH 33/57] Test some edge cases --- picows/websockets/asyncio/connection.py | 13 ++-- tests/test_websockets_decode_edge_cases.py | 50 ++++++++++++++ tests/test_websockets_recv_edge_cases.py | 78 ++++++++++++++++++++++ tests/test_websockets_send_edge_cases.py | 50 ++++++++++++++ 4 files changed, 184 insertions(+), 7 deletions(-) create mode 100644 tests/test_websockets_decode_edge_cases.py create mode 100644 tests/test_websockets_recv_edge_cases.py create mode 100644 tests/test_websockets_send_edge_cases.py diff --git a/picows/websockets/asyncio/connection.py b/picows/websockets/asyncio/connection.py index e59ce3a..13ea423 100644 --- a/picows/websockets/asyncio/connection.py +++ b/picows/websockets/asyncio/connection.py @@ -667,8 +667,6 @@ async def recv(self, decode: Optional[bool] = None) -> Data: self._recv_waiter = None def recv_streaming(self, decode: Optional[bool] = None) -> AsyncIterator[Data]: - self._set_recv_in_progress() - msg_started: cython.bint = False msg_finished: cython.bint = False frame: _BufferedFrame @@ -678,6 +676,7 @@ async def iterator() -> AsyncIterator[Data]: nonlocal msg_started, msg_finished try: + self._set_recv_in_progress() if not self._recv_queue: await self._wait_recv_queue_not_empty() frame = self._check_frame(self._recv_queue.popleft()) @@ -791,7 +790,7 @@ async def _send_fragments( else: current = next(iterator) except stop_exception_type: - raise TypeError("message iterable cannot be empty") from None + return first_is_str: cython.bint if isinstance(current, str): @@ -849,7 +848,7 @@ async def close(self, code: int = 1000, reason: str = "") -> None: async def wait_closed(self) -> None: await self.transport.wait_disconnected() - async def ping(self, data: Optional[Union[str, bytes]] = None) -> Awaitable[float]: + async def ping(self, data: Optional[DataLike] = None) -> Awaitable[float]: if self._state is State.CLOSED: raise self._connection_closed() if data is None: @@ -859,10 +858,10 @@ async def ping(self, data: Optional[Union[str, bytes]] = None) -> Awaitable[floa break elif isinstance(data, str): payload = data.encode("utf-8") - elif isinstance(data, bytes): - payload = data + elif isinstance(data, (bytes, bytearray, memoryview)): + payload = bytes(data) else: - raise TypeError("ping payload must be str, bytes, or None") + raise TypeError("ping payload must be str, bytes-like, or None") if payload in self._pending_pings: raise ConcurrencyError("another ping was sent with the same data") diff --git a/tests/test_websockets_decode_edge_cases.py b/tests/test_websockets_decode_edge_cases.py new file mode 100644 index 0000000..13cf5a7 --- /dev/null +++ b/tests/test_websockets_decode_edge_cases.py @@ -0,0 +1,50 @@ +import picows + +from picows import websockets +from tests.utils import WSServer + + +class SendTextOnConnect(picows.WSListener): + def __init__(self, payload: bytes): + self._payload = payload + + def on_ws_connected(self, transport: picows.WSTransport): + transport.send(picows.WSMsgType.TEXT, self._payload) + + +class SendBinaryOnConnect(picows.WSListener): + def __init__(self, payload: bytes): + self._payload = payload + + def on_ws_connected(self, transport: picows.WSTransport): + transport.send(picows.WSMsgType.BINARY, self._payload) + + +async def test_recv_decode_false_returns_bytes_for_text_messages(): + async with WSServer(lambda _: SendTextOnConnect(b"hello")) as server: + async with websockets.connect(server.url, compression=None) as ws: + assert await ws.recv(decode=False) == b"hello" + + +async def test_recv_decode_true_returns_text_for_binary_messages(): + async with WSServer(lambda _: SendBinaryOnConnect(b"hello")) as server: + async with websockets.connect(server.url, compression=None) as ws: + assert await ws.recv(decode=True) == "hello" + + +async def test_recv_streaming_decode_false_returns_bytes_for_text_messages(): + async with WSServer(lambda _: SendTextOnConnect(b"hello")) as server: + async with websockets.connect(server.url, compression=None) as ws: + fragments = [] + async for fragment in ws.recv_streaming(decode=False): + fragments.append(fragment) + assert fragments == [b"hello"] + + +async def test_recv_streaming_decode_true_returns_text_for_binary_messages(): + async with WSServer(lambda _: SendBinaryOnConnect(b"hello")) as server: + async with websockets.connect(server.url, compression=None) as ws: + fragments = [] + async for fragment in ws.recv_streaming(decode=True): + fragments.append(fragment) + assert fragments == ["hello"] diff --git a/tests/test_websockets_recv_edge_cases.py b/tests/test_websockets_recv_edge_cases.py new file mode 100644 index 0000000..6c22342 --- /dev/null +++ b/tests/test_websockets_recv_edge_cases.py @@ -0,0 +1,78 @@ +import asyncio + +import picows +import pytest + +from picows import websockets +from tests.utils import WSServer + + +class FragmentedTextListener(picows.WSListener): + def __init__(self, allow_first_fragment: asyncio.Event, allow_second_fragment: asyncio.Event): + self._allow_first_fragment = allow_first_fragment + self._allow_second_fragment = allow_second_fragment + self.transport = None + + def on_ws_connected(self, transport: picows.WSTransport): + self.transport = transport + + async def send_fragments(): + await self._allow_first_fragment.wait() + transport.send(picows.WSMsgType.TEXT, b"he", fin=False) + await self._allow_second_fragment.wait() + transport.send(picows.WSMsgType.CONTINUATION, b"llo", fin=True) + + asyncio.create_task(send_fragments()) + + +async def test_recv_cancellation_is_safe_for_fragmented_message(): + allow_first_fragment = asyncio.Event() + allow_second_fragment = asyncio.Event() + + async with WSServer(lambda _: FragmentedTextListener(allow_first_fragment, allow_second_fragment)) as server: + async with websockets.connect(server.url, compression=None) as ws: + recv_task = asyncio.create_task(ws.recv()) + allow_first_fragment.set() + await asyncio.sleep(0) + recv_task.cancel() + with pytest.raises(asyncio.CancelledError): + await recv_task + + allow_second_fragment.set() + assert await ws.recv() == "hello" + + +async def test_recv_streaming_cancellation_before_first_fragment_is_safe(): + allow_first_fragment = asyncio.Event() + allow_second_fragment = asyncio.Event() + + async with WSServer(lambda _: FragmentedTextListener(allow_first_fragment, allow_second_fragment)) as server: + async with websockets.connect(server.url, compression=None) as ws: + iterator = ws.recv_streaming() + recv_task = asyncio.create_task(anext(iterator)) + recv_task.cancel() + with pytest.raises(asyncio.CancelledError): + await recv_task + + allow_first_fragment.set() + allow_second_fragment.set() + assert await ws.recv() == "hello" + + +async def test_recv_streaming_partial_consumption_breaks_future_receives(): + allow_first_fragment = asyncio.Event() + allow_second_fragment = asyncio.Event() + + async with WSServer(lambda _: FragmentedTextListener(allow_first_fragment, allow_second_fragment)) as server: + async with websockets.connect(server.url, compression=None) as ws: + iterator = ws.recv_streaming() + allow_first_fragment.set() + assert await anext(iterator) == "he" + + with pytest.raises(websockets.ConcurrencyError): + await ws.recv() + + allow_second_fragment.set() + + with pytest.raises(websockets.ConcurrencyError): + await ws.recv() diff --git a/tests/test_websockets_send_edge_cases.py b/tests/test_websockets_send_edge_cases.py new file mode 100644 index 0000000..4a6080b --- /dev/null +++ b/tests/test_websockets_send_edge_cases.py @@ -0,0 +1,50 @@ +import asyncio +from collections.abc import AsyncIterator + +import pytest + +from picows import websockets +from tests.utils import WSServer + + +async def test_send_empty_iterable_is_noop(): + async with WSServer() as server: + async with websockets.connect(server.url, compression=None, ping_interval=None) as ws: + await ws.send([]) + pong_waiter = await ws.ping(b"noop") + await asyncio.wait_for(pong_waiter, 1.0) + + +async def test_send_empty_async_iterable_is_noop(): + async def fragments() -> AsyncIterator[bytes]: + if False: + yield b"never" + + async with WSServer() as server: + async with websockets.connect(server.url, compression=None, ping_interval=None) as ws: + await ws.send(fragments()) + pong_waiter = await ws.ping(b"noop") + await asyncio.wait_for(pong_waiter, 1.0) + + +async def test_send_rejects_dict_like_objects(): + async with WSServer() as server: + async with websockets.connect(server.url, compression=None) as ws: + with pytest.raises(TypeError, match="dict-like object"): + await ws.send({"a": 1}) + + +async def test_ping_accepts_byteslike_payloads(): + async with WSServer() as server: + async with websockets.connect(server.url, compression=None, ping_interval=None) as ws: + pong_waiter = await ws.ping(bytearray(b"abcd")) + await asyncio.wait_for(pong_waiter, 1.0) + pong_waiter = await ws.ping(memoryview(b"efgh")) + await asyncio.wait_for(pong_waiter, 1.0) + + +async def test_pong_accepts_byteslike_payloads(): + async with WSServer() as server: + async with websockets.connect(server.url, compression=None, ping_interval=None) as ws: + await ws.pong(bytearray(b"abcd")) + await ws.pong(memoryview(b"efgh")) From d9ecf9b395fdd75abb7f826d412f8959f1c14b7a Mon Sep 17 00:00:00 2001 From: taras Date: Tue, 5 May 2026 00:36:14 +0200 Subject: [PATCH 34/57] Add tests for more edge cases --- ...bsockets_compression_failure_edge_cases.py | 89 +++++++++++++++++++ tests/test_websockets_decode_edge_cases.py | 8 +- ...st_websockets_decode_failure_edge_cases.py | 31 +++++++ tests/test_websockets_recv_edge_cases.py | 7 +- ...test_websockets_send_failure_edge_cases.py | 86 ++++++++++++++++++ 5 files changed, 214 insertions(+), 7 deletions(-) create mode 100644 tests/test_websockets_compression_failure_edge_cases.py create mode 100644 tests/test_websockets_decode_failure_edge_cases.py create mode 100644 tests/test_websockets_send_failure_edge_cases.py diff --git a/tests/test_websockets_compression_failure_edge_cases.py b/tests/test_websockets_compression_failure_edge_cases.py new file mode 100644 index 0000000..2d4f474 --- /dev/null +++ b/tests/test_websockets_compression_failure_edge_cases.py @@ -0,0 +1,89 @@ +import asyncio +import base64 +import hashlib +from contextlib import asynccontextmanager + +import pytest +import websockets as upstream_websockets + +from picows import websockets + + +@asynccontextmanager +async def upstream_server(handler): + server = await upstream_websockets.serve( + handler, + "127.0.0.1", + 0, + compression="deflate", + ) + port = server.sockets[0].getsockname()[1] + try: + yield f"ws://127.0.0.1:{port}/" + finally: + server.close() + await server.wait_closed() + + +@asynccontextmanager +async def malformed_compressed_server(): + async def handler(reader: asyncio.StreamReader, writer: asyncio.StreamWriter): + request = await reader.readuntil(b"\r\n\r\n") + headers = request.decode("ascii").split("\r\n") + key = None + for header in headers: + if header.lower().startswith("sec-websocket-key:"): + key = header.split(":", 1)[1].strip() + break + assert key is not None + + accept = base64.b64encode( + hashlib.sha1( + (key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11").encode("ascii") + ).digest() + ).decode("ascii") + + response = ( + "HTTP/1.1 101 Switching Protocols\r\n" + "Upgrade: websocket\r\n" + "Connection: Upgrade\r\n" + f"Sec-WebSocket-Accept: {accept}\r\n" + "Sec-WebSocket-Extensions: permessage-deflate\r\n" + "\r\n" + ) + writer.write(response.encode("ascii")) + + payload = b"not-a-valid-deflate-stream" + frame = bytes([0xC1, len(payload)]) + payload + writer.write(frame) + await writer.drain() + await asyncio.sleep(0.1) + writer.close() + await writer.wait_closed() + + server = await asyncio.start_server(handler, "127.0.0.1", 0) + port = server.sockets[0].getsockname()[1] + try: + yield f"ws://127.0.0.1:{port}/" + finally: + server.close() + await server.wait_closed() + + +async def test_compressed_message_exceeding_max_size_closes_connection(): + async def handler(ws): + await ws.send("a" * 10000) + + async with upstream_server(handler) as url: + async with websockets.connect( + url, ping_interval=None, max_size=1000 + ) as ws: + with pytest.raises(websockets.ConnectionClosedError): + await ws.recv() + + +async def test_malformed_compressed_message_closes_connection(): + async with malformed_compressed_server() as url: + async with websockets.connect(url, ping_interval=None) as ws: + with pytest.raises(websockets.ConnectionClosedError): + await ws.recv() diff --git a/tests/test_websockets_decode_edge_cases.py b/tests/test_websockets_decode_edge_cases.py index 13cf5a7..a7c17c3 100644 --- a/tests/test_websockets_decode_edge_cases.py +++ b/tests/test_websockets_decode_edge_cases.py @@ -1,22 +1,24 @@ import picows from picows import websockets -from tests.utils import WSServer +from tests.utils import ServerEchoListener, WSServer -class SendTextOnConnect(picows.WSListener): +class SendTextOnConnect(ServerEchoListener): def __init__(self, payload: bytes): self._payload = payload def on_ws_connected(self, transport: picows.WSTransport): + super().on_ws_connected(transport) transport.send(picows.WSMsgType.TEXT, self._payload) -class SendBinaryOnConnect(picows.WSListener): +class SendBinaryOnConnect(ServerEchoListener): def __init__(self, payload: bytes): self._payload = payload def on_ws_connected(self, transport: picows.WSTransport): + super().on_ws_connected(transport) transport.send(picows.WSMsgType.BINARY, self._payload) diff --git a/tests/test_websockets_decode_failure_edge_cases.py b/tests/test_websockets_decode_failure_edge_cases.py new file mode 100644 index 0000000..33025ac --- /dev/null +++ b/tests/test_websockets_decode_failure_edge_cases.py @@ -0,0 +1,31 @@ +import pytest + +import picows +from picows import websockets +from tests.utils import ServerEchoListener, WSServer + + +class SendBinaryOnConnect(ServerEchoListener): + def __init__(self, payload: bytes): + self._payload = payload + + def on_ws_connected(self, transport: picows.WSTransport): + super().on_ws_connected(transport) + transport.send(picows.WSMsgType.BINARY, self._payload) + + +async def test_recv_decode_true_invalid_utf8_closes_connection(): + async with WSServer(lambda _: SendBinaryOnConnect(b"\xff\xfe")) as server: + async with websockets.connect(server.url, compression=None) as ws: + with pytest.raises(websockets.ConnectionClosedError): + await ws.recv(decode=True) + assert ws.close_code == 1007 + + +async def test_recv_streaming_decode_true_invalid_utf8_closes_connection(): + async with WSServer(lambda _: SendBinaryOnConnect(b"\xff\xfe")) as server: + async with websockets.connect(server.url, compression=None) as ws: + with pytest.raises(websockets.ConnectionClosedError): + async for _fragment in ws.recv_streaming(decode=True): + pass + assert ws.close_code == 1007 diff --git a/tests/test_websockets_recv_edge_cases.py b/tests/test_websockets_recv_edge_cases.py index 6c22342..81cc8e6 100644 --- a/tests/test_websockets_recv_edge_cases.py +++ b/tests/test_websockets_recv_edge_cases.py @@ -4,17 +4,16 @@ import pytest from picows import websockets -from tests.utils import WSServer +from tests.utils import ServerEchoListener, WSServer -class FragmentedTextListener(picows.WSListener): +class FragmentedTextListener(ServerEchoListener): def __init__(self, allow_first_fragment: asyncio.Event, allow_second_fragment: asyncio.Event): self._allow_first_fragment = allow_first_fragment self._allow_second_fragment = allow_second_fragment - self.transport = None def on_ws_connected(self, transport: picows.WSTransport): - self.transport = transport + super().on_ws_connected(transport) async def send_fragments(): await self._allow_first_fragment.wait() diff --git a/tests/test_websockets_send_failure_edge_cases.py b/tests/test_websockets_send_failure_edge_cases.py new file mode 100644 index 0000000..81629e6 --- /dev/null +++ b/tests/test_websockets_send_failure_edge_cases.py @@ -0,0 +1,86 @@ +import asyncio +from collections.abc import AsyncIterator + +import pytest + +from picows import websockets +from tests.utils import WSServer + + +class FragmentError(Exception): + pass + + +async def test_send_sync_iterable_exception_closes_connection(): + async with WSServer() as server: + async with websockets.connect(server.url, compression=None, ping_interval=None) as ws: + def fragments(): + yield b"first" + raise FragmentError("boom") + + with pytest.raises(FragmentError, match="boom"): + await ws.send(fragments()) + + await asyncio.wait_for(ws.wait_closed(), 1.0) + with pytest.raises(websockets.ConnectionClosed): + await ws.send(b"x") + + +async def test_send_async_iterable_exception_closes_connection(): + async with WSServer() as server: + async with websockets.connect(server.url, compression=None, ping_interval=None) as ws: + async def fragments() -> AsyncIterator[bytes]: + yield b"first" + raise FragmentError("boom") + + with pytest.raises(FragmentError, match="boom"): + await ws.send(fragments()) + + await asyncio.wait_for(ws.wait_closed(), 1.0) + with pytest.raises(websockets.ConnectionClosed): + await ws.send(b"x") + + +async def test_send_async_iterable_cancellation_closes_connection(): + async with WSServer() as server: + async with websockets.connect(server.url, compression=None, ping_interval=None) as ws: + first_sent = asyncio.Event() + unblock = asyncio.Event() + + async def fragments() -> AsyncIterator[bytes]: + yield b"first" + first_sent.set() + await unblock.wait() + yield b"second" + + send_task = asyncio.create_task(ws.send(fragments())) + await asyncio.wait_for(first_sent.wait(), 1.0) + send_task.cancel() + + with pytest.raises(asyncio.CancelledError): + await send_task + + await asyncio.wait_for(ws.wait_closed(), 1.0) + with pytest.raises(websockets.ConnectionClosed): + await ws.send(b"x") + + +async def test_close_during_fragmented_send_closes_connection(): + async with WSServer() as server: + async with websockets.connect(server.url, compression=None, ping_interval=None) as ws: + first_sent = asyncio.Event() + unblock = asyncio.Event() + + async def fragments() -> AsyncIterator[bytes]: + yield b"first" + first_sent.set() + await unblock.wait() + yield b"second" + + send_task = asyncio.create_task(ws.send(fragments())) + await asyncio.wait_for(first_sent.wait(), 1.0) + await ws.close() + unblock.set() + + with pytest.raises(websockets.ConnectionClosed): + await send_task From ad1bfc99313808684e02c3b1925b260b2feef1ea Mon Sep 17 00:00:00 2001 From: taras Date: Tue, 5 May 2026 07:03:28 +0200 Subject: [PATCH 35/57] Fix missing WSUpgradeResponse.body property on success path --- picows/picows.pyx | 1 + 1 file changed, 1 insertion(+) diff --git a/picows/picows.pyx b/picows/picows.pyx index 5a35285..946aa92 100644 --- a/picows/picows.pyx +++ b/picows/picows.pyx @@ -1437,6 +1437,7 @@ cdef class WSProtocol(WSProtocolBase, asyncio.BufferedProtocol): ) from None response.headers = CIMultiDict() + response.body = None for idx in range(1, len(lines)): line = lines[idx] try: From 807807f138a7980b8a1f6b128ecb937e5b263fa8 Mon Sep 17 00:00:00 2001 From: taras Date: Tue, 5 May 2026 07:09:17 +0200 Subject: [PATCH 36/57] Update AGENTS.md --- AGENTS.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/AGENTS.md b/AGENTS.md index 64f59d1..554bb79 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -31,6 +31,13 @@ examples - Various examples for users on how to use picows + perf_test that coul - In Cythonized Python modules, avoid `typing.cast(...)` in hot paths. Cython may compile `cast(...)` into a real runtime global lookup and function call instead of erasing it like a type checker would. Prefer control-flow narrowing, assertions, or narrowly scoped type-ignore comments when needed. +- If `picows` core exposes an inconsistent runtime shape or behavior that looks like a bug, do not silently normalize around it in wrapper code. + Stop and ask first, or at least clearly call out that it appears to be a core bug instead of assuming it is an intentional quirk. + Wrapper-level workarounds for such inconsistencies should be treated as temporary and explicit, not as the default resolution. + Legitimate intentional quirks can be documented in this file separately once confirmed. +- `WSUpgradeRequest` / `WSUpgradeResponse` expose a mixed bytes/str API and this is public API. + Request `method`, `path`, `version` and response `version` are low-level protocol bytes, while headers are decoded strings and response `status` is `HTTPStatus`. + Do not change this shape casually in core or silently normalize it away in wrappers; treat it as a stable compatibility constraint unless an intentional breaking change is agreed. ## Testing instructions - Run lint after updating code with: From 493fc46249437e056bfcff217292d2af579b40b3 Mon Sep 17 00:00:00 2001 From: taras Date: Tue, 5 May 2026 08:00:06 +0200 Subject: [PATCH 37/57] Expose WSTransport.is_disconnected attribute --- HISTORY.rst | 3 +++ docs/source/reference.rst | 11 +++++++++++ picows/picows.pxd | 5 ++++- picows/picows.pyx | 7 +++++-- tests/test_basics.py | 23 ++++++++++++++++++++++- 5 files changed, 45 insertions(+), 4 deletions(-) diff --git a/HISTORY.rst b/HISTORY.rst index 00ece56..fa2a75c 100644 --- a/HISTORY.rst +++ b/HISTORY.rst @@ -16,6 +16,9 @@ picows Release History * Added rsv2 and rsv3 to WSTransport send methods * WSTransport send, send_ping, send_pong, send_close can now accept `str` type as message. The message will be encoded as utf-8 before sending * User on_ws_connect and on_ws_frame implementation can now signalize protocol errors by raising WSProtocolError +* Add is_disconnected property to WSTransport. +* Fix send_* methods raising exceptions when attempting to send after connection abort and without prior CLOSE frame. +* Add missing body attribute in WSUpgradeResponse at the client side. 1.19.0 (2026-04-24) ------------------ diff --git a/docs/source/reference.rst b/docs/source/reference.rst index e11d85f..f358434 100644 --- a/docs/source/reference.rst +++ b/docs/source/reference.rst @@ -231,6 +231,17 @@ Classes :any:`WSTransport.send_pong`, and :any:`WSTransport.send_close`) are no-ops and do nothing. + .. py:attribute:: is_disconnected + :type: bool + + Indicates whether the underlying asyncio transport has reported `connection_lost` event. + :any:`WSTransport.send_close`. + + When this flag is ``True``, subsequent calls to send methods + (:any:`WSTransport.send`, :any:`WSTransport.send_ping`, + :any:`WSTransport.send_pong`, and :any:`WSTransport.send_close`) + are no-ops and do nothing. + .. py:attribute:: request :type: WSUpgradeRequest diff --git a/picows/picows.pxd b/picows/picows.pxd index dbed30b..3836ec3 100644 --- a/picows/picows.pxd +++ b/picows/picows.pxd @@ -97,13 +97,16 @@ cdef class WSTransport: readonly bint is_client_side readonly bint is_secure readonly bint is_close_frame_sent + readonly bint is_disconnected + # These are not public API, but accessed directly from WSProtocol + # Therefore without _. Perhaps I should add underscore, since accessing + # from Cython is possible for the user bint auto_ping_expect_pong object pong_received_at_future object listener_proxy object disconnected_future #: asyncio.Future - object _loop object _logger #: Logger MemoryBuffer _write_buffer diff --git a/picows/picows.pyx b/picows/picows.pyx index 946aa92..5a8baf3 100644 --- a/picows/picows.pyx +++ b/picows/picows.pyx @@ -368,6 +368,7 @@ cdef class WSTransport: self.is_client_side = is_client_side self.is_secure = underlying_transport.get_extra_info('ssl_object') is not None self.is_close_frame_sent = False + self.is_disconnected = False self.auto_ping_expect_pong = False self.pong_received_at_future = None self.listener_proxy = None @@ -449,7 +450,7 @@ cdef class WSTransport: cdef _send_buffer(self, WSMsgType msg_type, char* msg_ptr, Py_ssize_t msg_size, bint fin, bint rsv1, bint rsv2, bint rsv3): - if self.is_close_frame_sent: + if self.is_close_frame_sent or self.is_disconnected: self._logger.debug("Ignore attempt to send a message after WSMsgType.CLOSE has already been sent") return @@ -464,7 +465,7 @@ cdef class WSTransport: self._fast_write(header_ptr, header_size + msg_size) cdef _send(self, WSMsgType msg_type, message, bint fin, bint rsv1, bint rsv2, bint rsv3): - if self.is_close_frame_sent: + if self.is_close_frame_sent or self.is_disconnected: self._logger.debug("Ignore attempt to send a message after WSMsgType.CLOSE has already been sent") return @@ -1043,6 +1044,8 @@ cdef class WSProtocol(WSProtocolBase, asyncio.BufferedProtocol): self._handshake_timeout, self._handshake_timeout_callback) def connection_lost(self, exc): + self.transport.is_disconnected = True + self._logger.info("Disconnected") if self._handshake_complete_future.done(): diff --git a/tests/test_basics.py b/tests/test_basics.py index afaaa51..66e4a29 100644 --- a/tests/test_basics.py +++ b/tests/test_basics.py @@ -293,4 +293,25 @@ async def test_stress(use_aiofastnet, ssl_context): assert not client.is_paused -# + +async def test_send_after_abort(use_aiofastnet, ssl_context): + ba = bytearray(b"1234567890123456") + + # Test that any attempt to send after abort is no-op. No exception is raised + + async with WSServer(ssl=ssl_context.server, use_aiofastnet=use_aiofastnet) as server: + async with WSClient(server, ssl_context=ssl_context.client, use_aiofastnet=use_aiofastnet) as client: + client.transport.disconnect(False) + client.transport.send(picows.WSMsgType.BINARY, b"halo") + client.transport.send_reuse_external_bytearray(picows.WSMsgType.BINARY, ba, 16) + assert client.transport.is_disconnected == False + + await client.transport.wait_disconnected() + assert client.transport.is_disconnected == True + + client.transport.send(picows.WSMsgType.BINARY, b"halo") + client.transport.send_reuse_external_bytearray(picows.WSMsgType.BINARY, ba, 16) + + await asyncio.sleep(0.05) + client.transport.send(picows.WSMsgType.BINARY, b"halo") + client.transport.send_reuse_external_bytearray(picows.WSMsgType.BINARY, ba, 16) From af503e337f6cc8f60b6b65d6f0d347c364f2b6e2 Mon Sep 17 00:00:00 2001 From: taras Date: Tue, 5 May 2026 08:01:32 +0200 Subject: [PATCH 38/57] Update AGENTS.md --- AGENTS.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/AGENTS.md b/AGENTS.md index 554bb79..9a9a531 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -38,6 +38,10 @@ examples - Various examples for users on how to use picows + perf_test that coul - `WSUpgradeRequest` / `WSUpgradeResponse` expose a mixed bytes/str API and this is public API. Request `method`, `path`, `version` and response `version` are low-level protocol bytes, while headers are decoded strings and response `status` is `HTTPStatus`. Do not change this shape casually in core or silently normalize it away in wrappers; treat it as a stable compatibility constraint unless an intentional breaking change is agreed. +- In `picows` core, once a CLOSE frame has been sent, later send-side API calls are effectively no-ops. + This applies to `send_close()` as well as the other send methods. + Also, `disconnect()` and `wait_disconnected()` are safe to call multiple times. + Wrapper code should rely on these idempotency guarantees instead of adding its own state-based suppression around shutdown operations. ## Testing instructions - Run lint after updating code with: From 15f5cf038f237c62b5e0cdaa9a82216f77088ea2 Mon Sep 17 00:00:00 2001 From: taras Date: Tue, 5 May 2026 09:09:29 +0200 Subject: [PATCH 39/57] Optimizations and cleanups --- picows/websockets/asyncio/client.py | 3 +- picows/websockets/asyncio/connection.py | 121 +++++++++++++++--------- picows/websockets/compat.py | 40 +++++++- picows/websockets/exceptions.py | 2 +- 4 files changed, 117 insertions(+), 49 deletions(-) diff --git a/picows/websockets/asyncio/client.py b/picows/websockets/asyncio/client.py index 9375bd1..b6d1d74 100644 --- a/picows/websockets/asyncio/client.py +++ b/picows/websockets/asyncio/client.py @@ -3,6 +3,7 @@ import asyncio import socket from collections.abc import Generator +from logging import getLogger from ssl import SSLContext from typing import Any, Callable, Optional, Sequence, Union @@ -211,7 +212,7 @@ def listener_factory() -> ClientConnection: extra_headers=extra_headers, proxy=proxy, socket_factory=socket_factory, - logger_name=self.logger if self.logger is not None else "websockets.client", + logger_name=self.logger if self.logger is not None else getLogger("websockets.client"), **conn_kwargs, ) except picows.WSInvalidURL as exc: diff --git a/picows/websockets/asyncio/connection.py b/picows/websockets/asyncio/connection.py index 13ea423..3632cb1 100644 --- a/picows/websockets/asyncio/connection.py +++ b/picows/websockets/asyncio/connection.py @@ -283,7 +283,7 @@ def process_exception(exc: Exception) -> Optional[Exception]: if isinstance(exc, (EOFError, OSError, asyncio.TimeoutError)): return None if isinstance(exc, InvalidStatus): - status = exc.response.status + status = exc.response.status_code if int(status) in {500, 502, 503, 504}: return None return exc @@ -301,7 +301,6 @@ class ClientConnection(WSListener): # type: ignore[misc] _subprotocol: Optional[Subprotocol] _compression: Optional[str] _permessage_deflate: Optional[_PerMessageDeflate] - _state: State _close_exc: Optional[ConnectionClosed] _loop: asyncio.AbstractEventLoop @@ -353,7 +352,6 @@ def __init__( self._subprotocol = None self._compression = compression self._permessage_deflate = None - self._state = State.CONNECTING self._close_exc: Optional[ConnectionClosed] = None self._loop = asyncio.get_running_loop() @@ -382,8 +380,8 @@ def __init__( @cython.ccall def on_ws_connected(self, transport: WSTransport) -> None: self.transport = transport - self._request = transport.request - self._response = transport.response + self._request = Request.from_picows(transport.request) + self._response = Response.from_picows(transport.response) try: self._subprotocol = _resolve_subprotocol(self._subprotocols, self._response) self._configure_extensions() @@ -392,14 +390,12 @@ def on_ws_connected(self, transport: WSTransport) -> None: self.transport.send_close(WSCloseCode.PROTOCOL_ERROR, str(exc)) self.transport.disconnect(False) return - self._state = State.OPEN self._set_write_limits(self._write_limit) if self._ping_interval is not None and self._keepalive_task is None: self._keepalive_task = asyncio.create_task(self._keepalive_loop()) @cython.ccall def on_ws_disconnected(self, transport: WSTransport) -> None: - self._state = State.CLOSED self._set_close_exception() self._add_to_recv_queue(None) if self._keepalive_task is not None: @@ -416,25 +412,6 @@ def on_ws_disconnected(self, transport: WSTransport) -> None: waiter.set_exception(self._close_exc or ConnectionClosedError(None, None, None)) self._pending_pings.clear() - @cython.cfunc - @cython.inline - def _process_pong_frame(self, frame: WSFrame) -> None: - ping = self._pending_pings.pop(frame.get_payload_as_bytes(), None) - if ping is not None: - waiter, sent_at = ping - self._latency = monotonic() - sent_at - if not waiter.done(): - waiter.set_result(self._latency) - - @cython.cfunc - @cython.inline - def _process_close_frame(self, frame: WSFrame) -> None: - close_code = frame.get_close_code() - close_message = frame.get_close_message() - self.transport.send_close(close_code, close_message) - self.transport.disconnect() - self._state = State.CLOSING - @cython.ccall def on_ws_frame(self, transport: WSTransport, frame: WSFrame) -> None: if frame.msg_type == WSMsgType.PONG: @@ -494,6 +471,24 @@ def resume_writing(self) -> None: self._write_ready.set_result(None) self._write_ready = None + @cython.cfunc + @cython.inline + def _process_pong_frame(self, frame: WSFrame) -> None: + ping = self._pending_pings.pop(frame.get_payload_as_bytes(), None) + if ping is not None: + waiter, sent_at = ping + self._latency = monotonic() - sent_at + if not waiter.done(): + waiter.set_result(self._latency) + + @cython.cfunc + @cython.inline + def _process_close_frame(self, frame: WSFrame) -> None: + close_code = frame.get_close_code() + close_message = frame.get_close_message() + self.transport.send_close(close_code, close_message) + self.transport.disconnect() + @cython.cfunc @cython.inline def _set_write_limits(self, write_limit: Union[int, tuple[int, Optional[int]]]) -> None: @@ -576,12 +571,6 @@ def _connection_closed(self) -> ConnectionClosed: self._set_close_exception() return self._close_exc or ConnectionClosedError(None, None, None) - @cython.cfunc - @cython.inline - def _fail_protocol_error(self, message: str) -> None: - self.transport.send_close(WSCloseCode.PROTOCOL_ERROR, message) - self.transport.disconnect(False) - @cython.cfunc @cython.inline def _set_recv_in_progress(self) -> None: @@ -632,6 +621,14 @@ def _check_frame(self, frame: Optional[_BufferedFrame]) -> _BufferedFrame: raise self._connection_closed() return frame + async def _fail_invalid_data(self, exc: UnicodeDecodeError) -> None: + self.transport.send_close( + WSCloseCode.INVALID_TEXT, + f"{exc.reason} at position {exc.start}", + ) + self.transport.disconnect(False) + await self.wait_closed() + async def recv(self, decode: Optional[bool] = None) -> Data: frame: _BufferedFrame @@ -662,6 +659,9 @@ async def recv(self, decode: Optional[bool] = None) -> Data: except asyncio.CancelledError: self._recv_queue.extendleft(reversed(frames)) raise + except UnicodeDecodeError as exc: + await self._fail_invalid_data(exc) + raise self._connection_closed() from exc finally: self._recv_in_progress = False self._recv_waiter = None @@ -692,6 +692,9 @@ async def iterator() -> AsyncIterator[Data]: yield self._decode_data(frame.payload, msg_type, decode) msg_finished = True + except UnicodeDecodeError as exc: + await self._fail_invalid_data(exc) + raise self._connection_closed() from exc finally: self._recv_in_progress = False self._recv_waiter = None @@ -702,7 +705,21 @@ async def iterator() -> AsyncIterator[Data]: return iterator() + @cython.cfunc + @cython.inline + def _is_in_open_state(self) -> cython.bint: + # Before on_ws_connected, self.transport is None + # on_ws_frame immediately send CLOSE reply on incoming CLOSE frame, so receiving CLOSE == is_close_frame_sent + # transport.is_disconnect happens the last, asyncio Protocol got connection_lost event + + return (self.transport is not None + and not self.transport.is_disconnected + and not self.transport.is_close_frame_sent) + def _encode_and_send(self, msg_type: WSMsgType, message: DataLike, fin: cython.bint) -> None: + if not self._is_in_open_state(): + raise self._connection_closed() + if self._permessage_deflate is not None: message = self._permessage_deflate.encode_frame( msg_type, self._compression_payload(message), fin @@ -716,7 +733,7 @@ async def send( message: Union[DataLike, Iterable[DataLike], AsyncIterator[DataLike]], text: Optional[bool] = None, ) -> None: - if self._state is State.CLOSED: + if not self._is_in_open_state(): raise self._connection_closed() if self._send_in_progress: @@ -826,16 +843,18 @@ async def _send_fragments( self._encode_and_send(msg_type, b"", True) if self._write_ready is not None: await self._write_ready - except Exception: - self._fail_protocol_error("error in fragmented message") + except BaseException: + self.transport.send_close(WSCloseCode.PROTOCOL_ERROR, "error in fragmented message") + self.transport.disconnect(False) raise async def close(self, code: int = 1000, reason: str = "") -> None: - if self._state is State.CLOSED: - return - if self._state is State.OPEN: - self._state = State.CLOSING - self.transport.send_close(code, reason) + if self._send_in_progress: + self.transport.send_close(WSCloseCode.INTERNAL_ERROR, "close during fragmented message") + self.transport.disconnect(False) + else: + self.transport.send_close(cython.cast(WSCloseCode, code), reason) + try: if self._close_timeout is None: await self.wait_closed() @@ -846,11 +865,15 @@ async def close(self, code: int = 1000, reason: str = "") -> None: await self.wait_closed() async def wait_closed(self) -> None: - await self.transport.wait_disconnected() + try: + await self.transport.wait_disconnected() + except Exception: + pass async def ping(self, data: Optional[DataLike] = None) -> Awaitable[float]: - if self._state is State.CLOSED: + if not self._is_in_open_state(): raise self._connection_closed() + if data is None: while True: payload = os.urandom(4) @@ -872,11 +895,12 @@ async def ping(self, data: Optional[DataLike] = None) -> Awaitable[float]: return waiter async def pong(self, data: Union[str, bytes] = b"") -> None: - if self._state is State.CLOSED: + if not self._is_in_open_state(): raise self._connection_closed() + self.transport.send_pong(data) - async def _keepalive_loop(self) -> None: + async def _keepalive_loop(self) -> None: try: while True: assert self._ping_interval is not None @@ -909,7 +933,14 @@ async def _iterate_messages(self) -> AsyncIterator[Data]: @property def state(self) -> State: - return self._state + if self.transport is None: + return State.CONNECTING + elif self.transport.is_disconnected: + return State.CLOSED + elif self.transport.is_close_frame_sent or self.transport.close_handshake is not None: + return State.CLOSING + else: + return State.OPEN @property def request(self) -> Request: diff --git a/picows/websockets/compat.py b/picows/websockets/compat.py index 58c381c..c627442 100644 --- a/picows/websockets/compat.py +++ b/picows/websockets/compat.py @@ -1,10 +1,46 @@ from __future__ import annotations +from dataclasses import dataclass + import picows +from multidict import CIMultiDict CloseCode = picows.WSCloseCode -Request = picows.WSUpgradeRequest -Response = picows.WSUpgradeResponse + + +@dataclass(slots=True) +class Request: + path: str + headers: CIMultiDict[str] + + @classmethod + def from_picows(cls, request: picows.WSUpgradeRequest) -> Request: + return cls( + path=request.path.decode("ascii", "surrogateescape"), + headers=request.headers, + ) + + +@dataclass(slots=True) +class Response: + status_code: int + reason_phrase: str + headers: CIMultiDict[str] + body: bytes | bytearray + + @classmethod + def from_picows(cls, response: picows.WSUpgradeResponse) -> Response: + return cls( + status_code=int(response.status), + reason_phrase=response.status.phrase, + headers=response.headers, + body=b"" if response.body is None else response.body, + ) + + @property + def status(self) -> int: + return self.status_code + __all__ = [ "CloseCode", diff --git a/picows/websockets/exceptions.py b/picows/websockets/exceptions.py index 1d8a970..c2e198d 100644 --- a/picows/websockets/exceptions.py +++ b/picows/websockets/exceptions.py @@ -108,7 +108,7 @@ def __init__(self, response: Any): self.response = response def __str__(self) -> str: - status = getattr(self.response, "status", None) + status = getattr(self.response, "status_code", None) if status is None: return "proxy rejected connection" return f"proxy rejected connection: HTTP {int(status):d}" From 2b4eb1b953c7464176e16ce5e6715cde39eb4a46 Mon Sep 17 00:00:00 2001 From: taras Date: Tue, 5 May 2026 19:00:25 +0200 Subject: [PATCH 40/57] Better cancellation and disconnect logic --- picows/websockets/asyncio/connection.py | 198 +++++++++++++++--------- 1 file changed, 124 insertions(+), 74 deletions(-) diff --git a/picows/websockets/asyncio/connection.py b/picows/websockets/asyncio/connection.py index 3632cb1..7bbba7b 100644 --- a/picows/websockets/asyncio/connection.py +++ b/picows/websockets/asyncio/connection.py @@ -207,20 +207,23 @@ def decode_frame(self, frame: WSFrame, max_length: cython.Py_ssize_t) -> bytes: @cython.cfunc @cython.inline @cython.wraparound(True) - def encode_frame(self, msg_type: WSMsgType, payload: BytesLike, fin: cython.bint) -> BytesLike: + def encode_frame(self, msg_type: WSMsgType, data: DataLike, fin: cython.bint) -> BytesLike: if msg_type != WSMsgType.CONTINUATION and (self.local_no_context_takeover or self._encoder is None): self._encoder = _zlib_compressobj(wbits=self.local_max_window_bits) - data: BytesLike = (self._encoder.compress(payload) + - self._encoder.flush(_zlib_z_sync_flush)) + if isinstance(data, str): + data = cython.cast(str, data).encode('utf-8') + + compressed_data: BytesLike = (self._encoder.compress(data) + + self._encoder.flush(_zlib_z_sync_flush)) if fin: - data_mv = memoryview(data) - assert data_mv[-4:] == _empty_uncompressed_block - data = data_mv[:-4] + mv = memoryview(compressed_data) + assert mv[-4:] == _empty_uncompressed_block + compressed_data = mv[:-4] if self.local_no_context_takeover: self._encoder = None - return data + return compressed_data @cython.cfunc @@ -301,7 +304,6 @@ class ClientConnection(WSListener): # type: ignore[misc] _subprotocol: Optional[Subprotocol] _compression: Optional[str] _permessage_deflate: Optional[_PerMessageDeflate] - _close_exc: Optional[ConnectionClosed] _loop: asyncio.AbstractEventLoop # Send side @@ -322,6 +324,10 @@ class ClientConnection(WSListener): # type: ignore[misc] _incoming_message_active: cython.bint _incoming_message_size: cython.Py_ssize_t + # Close logic + _close_fut: asyncio.Future[None] + _close_exc: Optional[ConnectionClosed] + _pending_pings: Dict[bytes, Tuple[asyncio.Future[float], float]] _ping_interval: Optional[float] _ping_timeout: Optional[float] @@ -352,7 +358,6 @@ def __init__( self._subprotocol = None self._compression = compression self._permessage_deflate = None - self._close_exc: Optional[ConnectionClosed] = None self._loop = asyncio.get_running_loop() self._send_in_progress = False @@ -370,6 +375,9 @@ def __init__( self._incoming_message_active = False self._incoming_message_size = 0 + self._close_fut = self._loop.create_future() + self._close_exc: Optional[ConnectionClosed] = None + self._pending_pings: dict[bytes, tuple[asyncio.Future[float], float]] = {} self._ping_interval = ping_interval self._ping_timeout = ping_timeout @@ -396,22 +404,35 @@ def on_ws_connected(self, transport: WSTransport) -> None: @cython.ccall def on_ws_disconnected(self, transport: WSTransport) -> None: + # Set _close_exc, _close_fut self._set_close_exception() + + # Wake up potential waiter on _recv_queue self._add_to_recv_queue(None) + + # Cancel pinging loop if self._keepalive_task is not None: self._keepalive_task.cancel() self._keepalive_task = None + + # If there is a waiter waiting for resume_writing wake it up with exception if self._write_ready is not None: if not self._write_ready.done(): - self._write_ready.set_exception( - self._close_exc or ConnectionClosedError(None, None, None) - ) + self._write_ready.set_exception(self._close_exc) # type: ignore[arg-type] self._write_ready = None + + # Wake up all waiters waiting for ping replies for waiter, _ in self._pending_pings.values(): if not waiter.done(): - waiter.set_exception(self._close_exc or ConnectionClosedError(None, None, None)) + waiter.set_exception(self._close_exc) # type: ignore[arg-type] self._pending_pings.clear() + # Wake up all waiters waiting for current send to complete + for waiter in self._send_waiters: + if not waiter.done(): + waiter.set_exception(self._close_exc) + self._send_waiters.clear() + @cython.ccall def on_ws_frame(self, transport: WSTransport, frame: WSFrame) -> None: if frame.msg_type == WSMsgType.PONG: @@ -548,6 +569,8 @@ def _configure_extensions(self) -> None: @cython.cfunc @cython.inline def _set_close_exception(self) -> None: + self._close_fut.set_result(None) + handshake = self.transport.close_handshake if handshake is None: self._close_exc = ConnectionClosedError(None, None, None) @@ -564,13 +587,6 @@ def _set_close_exception(self) -> None: exc_type = ConnectionClosedOK if ok else ConnectionClosedError self._close_exc = exc_type(rcvd, sent, rcvd_then_sent) - @cython.cfunc - @cython.inline - def _connection_closed(self) -> ConnectionClosed: - if self._close_exc is None: - self._set_close_exception() - return self._close_exc or ConnectionClosedError(None, None, None) - @cython.cfunc @cython.inline def _set_recv_in_progress(self) -> None: @@ -580,31 +596,6 @@ def _set_recv_in_progress(self) -> None: raise ConcurrencyError("recv_streaming() wasn't fully consumed") self._recv_in_progress = True - async def _wait_send_turn(self) -> None: - waiter: asyncio.Future[None] = self._loop.create_future() - self._send_waiters.append(waiter) - try: - await waiter - except Exception: - try: - self._send_waiters.remove(waiter) - except ValueError: - pass - raise - - @cython.cfunc - @cython.inline - def _release_send(self) -> None: - waiter: asyncio.Future[None] - - while self._send_waiters: - waiter = self._send_waiters.popleft() - if not waiter.done(): - waiter.set_result(None) - return - - self._send_in_progress = False - @cython.cfunc @cython.inline def _decode_data(self, payload: bytes, msg_type: WSMsgType, decode: Optional[bool]) -> Data: @@ -618,16 +609,17 @@ def _decode_data(self, payload: bytes, msg_type: WSMsgType, decode: Optional[boo def _check_frame(self, frame: Optional[_BufferedFrame]) -> _BufferedFrame: self._resume_reading_if_needed() if frame is None: - raise self._connection_closed() + raise self._close_exc return frame - async def _fail_invalid_data(self, exc: UnicodeDecodeError) -> None: + @cython.cfunc + @cython.inline + def _fail_invalid_data(self, exc: UnicodeDecodeError) -> None: self.transport.send_close( WSCloseCode.INVALID_TEXT, f"{exc.reason} at position {exc.start}", ) self.transport.disconnect(False) - await self.wait_closed() async def recv(self, decode: Optional[bool] = None) -> Data: frame: _BufferedFrame @@ -660,8 +652,8 @@ async def recv(self, decode: Optional[bool] = None) -> Data: self._recv_queue.extendleft(reversed(frames)) raise except UnicodeDecodeError as exc: - await self._fail_invalid_data(exc) - raise self._connection_closed() from exc + self._fail_invalid_data(exc) + await self._wait_close_and_raise(exc) finally: self._recv_in_progress = False self._recv_waiter = None @@ -693,8 +685,8 @@ async def iterator() -> AsyncIterator[Data]: yield self._decode_data(frame.payload, msg_type, decode) msg_finished = True except UnicodeDecodeError as exc: - await self._fail_invalid_data(exc) - raise self._connection_closed() from exc + self._fail_invalid_data(exc) + await self._wait_close_and_raise(exc) finally: self._recv_in_progress = False self._recv_waiter = None @@ -716,25 +708,90 @@ def _is_in_open_state(self) -> cython.bint: and not self.transport.is_disconnected and not self.transport.is_close_frame_sent) + @cython.cfunc + @cython.inline def _encode_and_send(self, msg_type: WSMsgType, message: DataLike, fin: cython.bint) -> None: - if not self._is_in_open_state(): - raise self._connection_closed() - if self._permessage_deflate is not None: - message = self._permessage_deflate.encode_frame( - msg_type, self._compression_payload(message), fin - ) + message = self._permessage_deflate.encode_frame(msg_type, message, fin) self.transport.send(msg_type, message, fin, msg_type != WSMsgType.CONTINUATION) else: self.transport.send(msg_type, message, fin) + async def _wait_close_and_raise(self, exc=None) -> None: + # CANCELLATION: + # _close_fut is supposed to be set only from on_ws_disconnected. + # It is intentionally shielded. + await asyncio.shield(self._close_fut) + if exc is None: + raise self._close_exc + else: + raise self._close_exc from exc + + async def _wait_send_turn(self) -> None: + # DISCONNECT: the waiter will raise ConnectionClosed + # It can also successfully finish, but we may be in CLOSING state. + # In such case delegate waiting to _wait_close_and_raise + + # CANCELLATION: + # waiter future is not shielded intentionally, it turns into Cancelled + # state and removed from waiters by _release_send. + # _wait_close_and_raise shields _close_fut. + waiter: asyncio.Future[None] = self._loop.create_future() + self._send_waiters.append(waiter) + await waiter + if not self._is_in_open_state(): + await self._wait_close_and_raise() + + async def _wait_write_ready(self) -> None: + # DISCONNECT: the waiter will raise ConnectionClosed + # It can also successfully finish, but we may be in CLOSING state. + # In such case delegate waiting to _wait_close_and_raise + + # CANCELLATION: + # _write_ready future is shielded intentionally. It is only supposed to + # be set from resume_writing and on_ws_disconnected. + assert self._write_ready is not None + await asyncio.shield(self._write_ready) + if not self._is_in_open_state(): + await self._wait_close_and_raise() + + async def _get_next_async_fragment(self, async_iterator: AsyncIterator[DataLike]) -> DataLike: + # DISCONNECT: raise ConnectionClosed if after user async interator + # returns we are not in OPEN state + + # CANCELLATION: + # User async iterator is also getting canceled. + + data: DataLike = await anext(async_iterator) + if not self._is_in_open_state(): + await self._wait_close_and_raise() + return data + + @cython.cfunc + @cython.inline + def _release_send(self) -> None: + waiter: asyncio.Future[None] + + while self._send_waiters: + waiter = self._send_waiters.popleft() + # Some waiters may be canceled, that is why we have defensive check + # for waiter.done() + if not waiter.done(): + waiter.set_result(None) + return + + self._send_in_progress = False + async def send( self, message: Union[DataLike, Iterable[DataLike], AsyncIterator[DataLike]], text: Optional[bool] = None, ) -> None: + # send doesn't directly wait on helper futures. It is very tricky to handle + # disconnects and cancellations properly. send delegates this to + # _wait_* helpers. if not self._is_in_open_state(): - raise self._connection_closed() + await self._wait_close_and_raise() if self._send_in_progress: await self._wait_send_turn() @@ -751,7 +808,7 @@ async def send( self._encode_and_send(msg_type, message, True) if self._write_ready is not None: - await self._write_ready + await self._wait_write_ready() # Catch a common mistake -- passing a dict to send(). elif isinstance(message, Mapping): raise TypeError("data is a dict-like object") @@ -772,13 +829,6 @@ def _check_fragment_type(self, message: DataLike, first_is_str: cython.bint) -> raise TypeError("all fragments must be of the same category: str vs bytes-like") - @cython.cfunc - @cython.inline - def _compression_payload(self, message: DataLike) -> BytesLike: - if isinstance(message, str): - return message.encode("utf-8") - return message - async def _send_fragments( self, messages: Union[AsyncIterable[DataLike], Iterable[DataLike]], @@ -803,7 +853,7 @@ async def _send_fragments( try: try: if is_async: - current = await anext(async_iterator) + current = await self._get_next_async_fragment(async_iterator) else: current = next(iterator) except stop_exception_type: @@ -829,7 +879,7 @@ async def _send_fragments( try: if is_async: - current = await anext(async_iterator) + current = await self._get_next_async_fragment(async_iterator) else: current = next(iterator) except stop_exception_type: @@ -837,12 +887,12 @@ async def _send_fragments( self._check_fragment_type(current, first_is_str) if self._write_ready is not None: - await self._write_ready + await self._wait_write_ready() # Send the last empty frame with fin=True self._encode_and_send(msg_type, b"", True) if self._write_ready is not None: - await self._write_ready + await self._wait_write_ready() except BaseException: self.transport.send_close(WSCloseCode.PROTOCOL_ERROR, "error in fragmented message") self.transport.disconnect(False) @@ -872,7 +922,7 @@ async def wait_closed(self) -> None: async def ping(self, data: Optional[DataLike] = None) -> Awaitable[float]: if not self._is_in_open_state(): - raise self._connection_closed() + await self._wait_close_and_raise() if data is None: while True: @@ -896,11 +946,11 @@ async def ping(self, data: Optional[DataLike] = None) -> Awaitable[float]: async def pong(self, data: Union[str, bytes] = b"") -> None: if not self._is_in_open_state(): - raise self._connection_closed() + await self._wait_close_and_raise() self.transport.send_pong(data) - async def _keepalive_loop(self) -> None: + async def _keepalive_loop(self) -> None: try: while True: assert self._ping_interval is not None From b9d9dde29d3a2139e33fb0fdba24a8c5f330ab94 Mon Sep 17 00:00:00 2001 From: taras Date: Tue, 5 May 2026 19:22:51 +0200 Subject: [PATCH 41/57] Simplify logic --- picows/websockets/asyncio/connection.py | 29 +++++++++++++++---------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/picows/websockets/asyncio/connection.py b/picows/websockets/asyncio/connection.py index 7bbba7b..d842cf2 100644 --- a/picows/websockets/asyncio/connection.py +++ b/picows/websockets/asyncio/connection.py @@ -10,7 +10,7 @@ from enum import IntEnum from time import monotonic from typing import Any, AsyncIterator, Awaitable, Optional, Sequence, \ - Union, Dict, Tuple, Iterator, Mapping + Union, Dict, Tuple, Iterator, Mapping, NoReturn import cython @@ -317,7 +317,7 @@ class ClientConnection(WSListener): # type: ignore[misc] _recv_streaming_broken: cython.bint _paused_reading: cython.bint _recv_waiter: Optional[asyncio.Future[None]] - _recv_queue: deque[Optional[_BufferedFrame]] + _recv_queue: deque[_BufferedFrame] _max_message_size: cython.Py_ssize_t _max_queue_high: cython.Py_ssize_t _max_queue_low: cython.Py_ssize_t @@ -407,8 +407,14 @@ def on_ws_disconnected(self, transport: WSTransport) -> None: # Set _close_exc, _close_fut self._set_close_exception() + # Pacify type checker + assert self._close_exc is not None + # Wake up potential waiter on _recv_queue - self._add_to_recv_queue(None) + if self._recv_waiter is not None: + if not self._recv_waiter.done(): + self._recv_waiter.set_exception(self._close_exc) + self._recv_waiter = None # Cancel pinging loop if self._keepalive_task is not None: @@ -418,7 +424,7 @@ def on_ws_disconnected(self, transport: WSTransport) -> None: # If there is a waiter waiting for resume_writing wake it up with exception if self._write_ready is not None: if not self._write_ready.done(): - self._write_ready.set_exception(self._close_exc) # type: ignore[arg-type] + self._write_ready.set_exception(self._close_exc) self._write_ready = None # Wake up all waiters waiting for ping replies @@ -537,7 +543,7 @@ def _resume_reading_if_needed(self) -> None: @cython.cfunc @cython.inline - def _add_to_recv_queue(self, frame: Optional[_BufferedFrame]) -> None: + def _add_to_recv_queue(self, frame: _BufferedFrame) -> None: self._recv_queue.append(frame) waiter = self._recv_waiter if waiter is not None: @@ -549,6 +555,9 @@ def _add_to_recv_queue(self, frame: Optional[_BufferedFrame]) -> None: @cython.inline def _wait_recv_queue_not_empty(self) -> asyncio.Future[None]: assert self._recv_waiter is None + if self._close_exc is not None: + raise self._close_exc + waiter: asyncio.Future[None] = self._loop.create_future() self._recv_waiter = waiter return waiter @@ -569,11 +578,10 @@ def _configure_extensions(self) -> None: @cython.cfunc @cython.inline def _set_close_exception(self) -> None: - self._close_fut.set_result(None) - handshake = self.transport.close_handshake if handshake is None: self._close_exc = ConnectionClosedError(None, None, None) + self._close_fut.set_result(None) return rcvd = handshake.recv sent = handshake.sent @@ -586,6 +594,7 @@ def _set_close_exception(self) -> None: ) exc_type = ConnectionClosedOK if ok else ConnectionClosedError self._close_exc = exc_type(rcvd, sent, rcvd_then_sent) + self._close_fut.set_result(None) @cython.cfunc @cython.inline @@ -606,10 +615,8 @@ def _decode_data(self, payload: bytes, msg_type: WSMsgType, decode: Optional[boo @cython.cfunc @cython.inline - def _check_frame(self, frame: Optional[_BufferedFrame]) -> _BufferedFrame: + def _check_frame(self, frame: _BufferedFrame) -> _BufferedFrame: self._resume_reading_if_needed() - if frame is None: - raise self._close_exc return frame @cython.cfunc @@ -717,7 +724,7 @@ def _encode_and_send(self, msg_type: WSMsgType, message: DataLike, fin: cython.b else: self.transport.send(msg_type, message, fin) - async def _wait_close_and_raise(self, exc=None) -> None: + async def _wait_close_and_raise(self, exc=None) -> NoReturn: # CANCELLATION: # _close_fut is supposed to be set only from on_ws_disconnected. # It is intentionally shielded. From 7170205e423ebacfd8a79e1eceb5dfc1ab8b077a Mon Sep 17 00:00:00 2001 From: taras Date: Tue, 5 May 2026 19:28:34 +0200 Subject: [PATCH 42/57] Simplify logic --- picows/websockets/asyncio/connection.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/picows/websockets/asyncio/connection.py b/picows/websockets/asyncio/connection.py index d842cf2..7cee571 100644 --- a/picows/websockets/asyncio/connection.py +++ b/picows/websockets/asyncio/connection.py @@ -428,15 +428,15 @@ def on_ws_disconnected(self, transport: WSTransport) -> None: self._write_ready = None # Wake up all waiters waiting for ping replies - for waiter, _ in self._pending_pings.values(): - if not waiter.done(): - waiter.set_exception(self._close_exc) # type: ignore[arg-type] + for ping_waiter, _ in self._pending_pings.values(): + if not ping_waiter.done(): + ping_waiter.set_exception(self._close_exc) self._pending_pings.clear() # Wake up all waiters waiting for current send to complete - for waiter in self._send_waiters: - if not waiter.done(): - waiter.set_exception(self._close_exc) + for send_waiter in self._send_waiters: + if not send_waiter.done(): + send_waiter.set_exception(self._close_exc) self._send_waiters.clear() @cython.ccall @@ -724,11 +724,13 @@ def _encode_and_send(self, msg_type: WSMsgType, message: DataLike, fin: cython.b else: self.transport.send(msg_type, message, fin) - async def _wait_close_and_raise(self, exc=None) -> NoReturn: + async def _wait_close_and_raise(self, exc: Optional[BaseException]=None) -> NoReturn: # CANCELLATION: # _close_fut is supposed to be set only from on_ws_disconnected. # It is intentionally shielded. await asyncio.shield(self._close_fut) + assert self._close_exc is not None # pacify type checker + if exc is None: raise self._close_exc else: From b6189d75165e66f0cebd237411106ac605ef7e8a Mon Sep 17 00:00:00 2001 From: taras Date: Wed, 6 May 2026 04:42:36 +0200 Subject: [PATCH 43/57] Cleanup --- picows/websockets/__init__.py | 4 +- picows/websockets/asyncio/__init__.py | 4 +- picows/websockets/asyncio/connection.py | 185 ++++++++++++------------ picows/websockets/compat.py | 9 ++ 4 files changed, 102 insertions(+), 100 deletions(-) diff --git a/picows/websockets/__init__.py b/picows/websockets/__init__.py index 0fb5b48..b8e2d41 100644 --- a/picows/websockets/__init__.py +++ b/picows/websockets/__init__.py @@ -1,7 +1,7 @@ from . import exceptions from .asyncio.client import connect -from .asyncio.connection import ClientConnection, State, process_exception -from .compat import CloseCode, Request, Response +from .asyncio.connection import ClientConnection, process_exception +from .compat import CloseCode, Request, Response, State from .exceptions import ( ConcurrencyError, ConnectionClosed, diff --git a/picows/websockets/asyncio/__init__.py b/picows/websockets/asyncio/__init__.py index ed18ffe..e160d29 100644 --- a/picows/websockets/asyncio/__init__.py +++ b/picows/websockets/asyncio/__init__.py @@ -1,9 +1,9 @@ from .client import connect -from .connection import ClientConnection, State, process_exception +from .connection import ClientConnection, process_exception +from ..compat import State __all__ = [ "ClientConnection", - "State", "connect", "process_exception", ] diff --git a/picows/websockets/asyncio/connection.py b/picows/websockets/asyncio/connection.py index 7cee571..a1861f8 100644 --- a/picows/websockets/asyncio/connection.py +++ b/picows/websockets/asyncio/connection.py @@ -6,8 +6,7 @@ import uuid import zlib from collections import deque -from collections.abc import AsyncIterable, Generator, Iterable -from enum import IntEnum +from collections.abc import AsyncIterable, Iterable from time import monotonic from typing import Any, AsyncIterator, Awaitable, Optional, Sequence, \ Union, Dict, Tuple, Iterator, Mapping, NoReturn @@ -22,7 +21,7 @@ from picows import WSProtocolError -from ..compat import CloseCode, Request, Response +from ..compat import State, CloseCode, Request, Response from ..exceptions import ( ConcurrencyError, ConnectionClosed, @@ -34,11 +33,15 @@ from ..typing import BytesLike, Data, DataLike, LoggerLike, Subprotocol -class State(IntEnum): - CONNECTING = 0 - OPEN = 1 - CLOSING = 2 - CLOSED = 3 +# cached for performance +_ok_close_codes = cython.declare(set, {0, 1000, 1001}) +_asyncio_shield = cython.declare(object, asyncio.shield) + +# zlib/compress/decompress utils, cached for performance +_empty_uncompressed_block = cython.declare(bytes, b"\x00\x00\xff\xff") +_zlib_compressobj = cython.declare(object, zlib.compressobj) +_zlib_decompressobj = cython.declare(object, zlib.decompressobj) +_zlib_z_sync_flush = cython.declare(object, zlib.Z_SYNC_FLUSH) @cython.freelist(128) @@ -60,17 +63,6 @@ def _make_buffered_frame(msg_type: WSMsgType, payload: bytes, fin: cython.bint) return self -# cached for performance -_ok_close_codes = cython.declare(set, {0, 1000, 1001}) - - -# zlib/compress/decompress utils, cached for performance -_empty_uncompressed_block = cython.declare(bytes, b"\x00\x00\xff\xff") -_zlib_compressobj = cython.declare(object, zlib.compressobj) -_zlib_decompressobj = cython.declare(object, zlib.decompressobj) -_zlib_z_sync_flush = cython.declare(object, zlib.Z_SYNC_FLUSH) - - @cython.no_gc @cython.cclass class _PerMessageDeflate: @@ -318,20 +310,20 @@ class ClientConnection(WSListener): # type: ignore[misc] _paused_reading: cython.bint _recv_waiter: Optional[asyncio.Future[None]] _recv_queue: deque[_BufferedFrame] - _max_message_size: cython.Py_ssize_t - _max_queue_high: cython.Py_ssize_t - _max_queue_low: cython.Py_ssize_t + _max_message_size: cython.Py_ssize_t # 0 - no limit + _max_queue_high: cython.Py_ssize_t # 0 - no limit + _max_queue_low: cython.Py_ssize_t # 0 - no limit _incoming_message_active: cython.bint _incoming_message_size: cython.Py_ssize_t # Close logic + _close_timeout: Optional[float] _close_fut: asyncio.Future[None] _close_exc: Optional[ConnectionClosed] _pending_pings: Dict[bytes, Tuple[asyncio.Future[float], float]] _ping_interval: Optional[float] _ping_timeout: Optional[float] - _close_timeout: Optional[float] _keepalive_task: Optional[asyncio.Task[None]] _latency: cython.double @@ -375,13 +367,13 @@ def __init__( self._incoming_message_active = False self._incoming_message_size = 0 + self._close_timeout = close_timeout self._close_fut = self._loop.create_future() self._close_exc: Optional[ConnectionClosed] = None self._pending_pings: dict[bytes, tuple[asyncio.Future[float], float]] = {} self._ping_interval = ping_interval self._ping_timeout = ping_timeout - self._close_timeout = close_timeout self._keepalive_task: Optional[asyncio.Task[None]] = None self._latency = 0.0 @@ -498,6 +490,28 @@ def resume_writing(self) -> None: self._write_ready.set_result(None) self._write_ready = None + @cython.cfunc + @cython.inline + def _set_write_limits(self, write_limit: Union[int, tuple[int, Optional[int]]]) -> None: + if isinstance(write_limit, tuple): + high, low = write_limit + else: + high, low = write_limit, None + self.transport.underlying_transport.set_write_buffer_limits(high=high, low=low) + + @cython.cfunc + @cython.inline + def _configure_extensions(self) -> None: + header_value = self._response.headers.get("Sec-WebSocket-Extensions") + if header_value is None: + return + if self._compression != "deflate": + raise InvalidHandshake("unexpected websocket extensions negotiated by server") + if not isinstance(header_value, str): + raise InvalidHandshake("invalid Sec-WebSocket-Extensions header") + + self._permessage_deflate = _PerMessageDeflate.from_response_header(header_value) + @cython.cfunc @cython.inline def _process_pong_frame(self, frame: WSFrame) -> None: @@ -516,15 +530,6 @@ def _process_close_frame(self, frame: WSFrame) -> None: self.transport.send_close(close_code, close_message) self.transport.disconnect() - @cython.cfunc - @cython.inline - def _set_write_limits(self, write_limit: Union[int, tuple[int, Optional[int]]]) -> None: - if isinstance(write_limit, tuple): - high, low = write_limit - else: - high, low = write_limit, None - self.transport.underlying_transport.set_write_buffer_limits(high=high, low=low) - @cython.cfunc @cython.inline def _pause_reading_if_needed(self) -> None: @@ -562,19 +567,6 @@ def _wait_recv_queue_not_empty(self) -> asyncio.Future[None]: self._recv_waiter = waiter return waiter - @cython.cfunc - @cython.inline - def _configure_extensions(self) -> None: - header_value = self._response.headers.get("Sec-WebSocket-Extensions") - if header_value is None: - return - if self._compression != "deflate": - raise InvalidHandshake("unexpected websocket extensions negotiated by server") - if not isinstance(header_value, str): - raise InvalidHandshake("invalid Sec-WebSocket-Extensions header") - - self._permessage_deflate = _PerMessageDeflate.from_response_header(header_value) - @cython.cfunc @cython.inline def _set_close_exception(self) -> None: @@ -706,6 +698,7 @@ async def iterator() -> AsyncIterator[Data]: @cython.cfunc @cython.inline + @cython.exceptval(check=False) def _is_in_open_state(self) -> cython.bint: # Before on_ws_connected, self.transport is None # on_ws_frame immediately send CLOSE reply on incoming CLOSE frame, so receiving CLOSE == is_close_frame_sent @@ -728,9 +721,9 @@ async def _wait_close_and_raise(self, exc: Optional[BaseException]=None) -> NoRe # CANCELLATION: # _close_fut is supposed to be set only from on_ws_disconnected. # It is intentionally shielded. - await asyncio.shield(self._close_fut) + await _asyncio_shield(self._close_fut) assert self._close_exc is not None # pacify type checker - + if exc is None: raise self._close_exc else: @@ -760,7 +753,7 @@ async def _wait_write_ready(self) -> None: # _write_ready future is shielded intentionally. It is only supposed to # be set from resume_writing and on_ws_disconnected. assert self._write_ready is not None - await asyncio.shield(self._write_ready) + await _asyncio_shield(self._write_ready) if not self._is_in_open_state(): await self._wait_close_and_raise() @@ -776,6 +769,16 @@ async def _get_next_async_fragment(self, async_iterator: AsyncIterator[DataLike] await self._wait_close_and_raise() return data + @cython.cfunc + @cython.inline + def _check_fragment_type(self, message: DataLike, first_is_str: cython.bint) -> None: + if first_is_str and isinstance(message, str): + return + elif not first_is_str and isinstance(message, (bytes, bytearray, memoryview)): + return + + raise TypeError("all fragments must be of the same category: str vs bytes-like") + @cython.cfunc @cython.inline def _release_send(self) -> None: @@ -791,53 +794,6 @@ def _release_send(self) -> None: self._send_in_progress = False - async def send( - self, - message: Union[DataLike, Iterable[DataLike], AsyncIterator[DataLike]], - text: Optional[bool] = None, - ) -> None: - # send doesn't directly wait on helper futures. It is very tricky to handle - # disconnects and cancellations properly. send delegates this to - # _wait_* helpers. - if not self._is_in_open_state(): - await self._wait_close_and_raise() - - if self._send_in_progress: - await self._wait_send_turn() - else: - self._send_in_progress = True - - try: - if isinstance(message, (str, bytes, bytearray, memoryview)): - if isinstance(message, str): - msg_type = WSMsgType.BINARY if text is False else WSMsgType.TEXT - else: - msg_type = WSMsgType.TEXT if text else WSMsgType.BINARY - - self._encode_and_send(msg_type, message, True) - - if self._write_ready is not None: - await self._wait_write_ready() - # Catch a common mistake -- passing a dict to send(). - elif isinstance(message, Mapping): - raise TypeError("data is a dict-like object") - elif isinstance(message, (AsyncIterable, Iterable)): - await self._send_fragments(message, text) # type: ignore[arg-type] - else: - raise TypeError(f"message has unsupported type {type(message).__name__}") - finally: - self._release_send() - - @cython.cfunc - @cython.inline - def _check_fragment_type(self, message: DataLike, first_is_str: cython.bint) -> None: - if first_is_str and isinstance(message, str): - return - elif not first_is_str and isinstance(message, (bytes, bytearray, memoryview)): - return - - raise TypeError("all fragments must be of the same category: str vs bytes-like") - async def _send_fragments( self, messages: Union[AsyncIterable[DataLike], Iterable[DataLike]], @@ -907,6 +863,43 @@ async def _send_fragments( self.transport.disconnect(False) raise + async def send( + self, + message: Union[DataLike, Iterable[DataLike], AsyncIterator[DataLike]], + text: Optional[bool] = None, + ) -> None: + # send doesn't directly wait on helper futures. It is very tricky to handle + # disconnects and cancellations properly. send delegates this to + # _wait_* helpers. + if not self._is_in_open_state(): + await self._wait_close_and_raise() + + if self._send_in_progress: + await self._wait_send_turn() + else: + self._send_in_progress = True + + try: + if isinstance(message, (str, bytes, bytearray, memoryview)): + if isinstance(message, str): + msg_type = WSMsgType.BINARY if text is False else WSMsgType.TEXT + else: + msg_type = WSMsgType.TEXT if text else WSMsgType.BINARY + + self._encode_and_send(msg_type, message, True) + + if self._write_ready is not None: + await self._wait_write_ready() + # Catch a common mistake -- passing a dict to send(). + elif isinstance(message, Mapping): + raise TypeError("data is a dict-like object") + elif isinstance(message, (AsyncIterable, Iterable)): + await self._send_fragments(message, text) # type: ignore[arg-type] + else: + raise TypeError(f"message has unsupported type {type(message).__name__}") + finally: + self._release_send() + async def close(self, code: int = 1000, reason: str = "") -> None: if self._send_in_progress: self.transport.send_close(WSCloseCode.INTERNAL_ERROR, "close during fragmented message") diff --git a/picows/websockets/compat.py b/picows/websockets/compat.py index c627442..7b414fb 100644 --- a/picows/websockets/compat.py +++ b/picows/websockets/compat.py @@ -1,6 +1,7 @@ from __future__ import annotations from dataclasses import dataclass +from enum import IntEnum import picows from multidict import CIMultiDict @@ -8,6 +9,13 @@ CloseCode = picows.WSCloseCode +class State(IntEnum): + CONNECTING = 0 + OPEN = 1 + CLOSING = 2 + CLOSED = 3 + + @dataclass(slots=True) class Request: path: str @@ -43,6 +51,7 @@ def status(self) -> int: __all__ = [ + "State", "CloseCode", "Request", "Response", From a49336bc532563c8a630bbd08bd1a22b53e2fbaa Mon Sep 17 00:00:00 2001 From: taras Date: Wed, 6 May 2026 20:24:35 +0200 Subject: [PATCH 44/57] Add server implementation --- picows/websockets/__init__.py | 11 +- picows/websockets/asyncio/__init__.py | 11 +- picows/websockets/asyncio/client.py | 2 +- picows/websockets/asyncio/connection.py | 137 ++++++- picows/websockets/asyncio/router.py | 81 ++++ picows/websockets/asyncio/server.py | 498 ++++++++++++++++++++++++ picows/websockets/compat.py | 9 + tests/test_websockets_server_compat.py | 259 ++++++++++++ 8 files changed, 999 insertions(+), 9 deletions(-) create mode 100644 picows/websockets/asyncio/router.py create mode 100644 picows/websockets/asyncio/server.py create mode 100644 tests/test_websockets_server_compat.py diff --git a/picows/websockets/__init__.py b/picows/websockets/__init__.py index b8e2d41..ff475f9 100644 --- a/picows/websockets/__init__.py +++ b/picows/websockets/__init__.py @@ -1,6 +1,8 @@ from . import exceptions from .asyncio.client import connect -from .asyncio.connection import ClientConnection, process_exception +from .asyncio.connection import ClientConnection, ServerConnection, process_exception +from .asyncio.router import Router, route +from .asyncio.server import Server, basic_auth, broadcast, serve from .compat import CloseCode, Request, Response, State from .exceptions import ( ConcurrencyError, @@ -49,6 +51,9 @@ "CloseCode", "Data", "DataLike", + "Router", + "Server", + "ServerConnection", "ConcurrencyError", "ConnectionClosed", "ConnectionClosedError", @@ -85,7 +90,11 @@ "StatusLike", "Subprotocol", "WebSocketException", + "basic_auth", + "broadcast", "connect", "exceptions", "process_exception", + "route", + "serve", ] diff --git a/picows/websockets/asyncio/__init__.py b/picows/websockets/asyncio/__init__.py index e160d29..6bc3d1f 100644 --- a/picows/websockets/asyncio/__init__.py +++ b/picows/websockets/asyncio/__init__.py @@ -1,9 +1,18 @@ from .client import connect -from .connection import ClientConnection, process_exception +from .connection import ClientConnection, ServerConnection, process_exception +from .router import Router, route +from .server import Server, basic_auth, broadcast, serve from ..compat import State __all__ = [ "ClientConnection", + "Router", + "Server", + "ServerConnection", + "basic_auth", + "broadcast", "connect", "process_exception", + "route", + "serve", ] diff --git a/picows/websockets/asyncio/client.py b/picows/websockets/asyncio/client.py index b6d1d74..3cec097 100644 --- a/picows/websockets/asyncio/client.py +++ b/picows/websockets/asyncio/client.py @@ -2,6 +2,7 @@ import asyncio import socket +import sys from collections.abc import Generator from logging import getLogger from ssl import SSLContext @@ -35,7 +36,6 @@ def _default_user_agent() -> str: - import sys return f"Python/{sys.version_info.major}.{sys.version_info.minor} picows-websockets/0" diff --git a/picows/websockets/asyncio/connection.py b/picows/websockets/asyncio/connection.py index a1861f8..b312cd4 100644 --- a/picows/websockets/asyncio/connection.py +++ b/picows/websockets/asyncio/connection.py @@ -1,9 +1,12 @@ from __future__ import annotations import asyncio +import http import logging import os +import sys import uuid +import weakref import zlib from collections import deque from collections.abc import AsyncIterable, Iterable @@ -12,6 +15,7 @@ Union, Dict, Tuple, Iterator, Mapping, NoReturn import cython +from multidict import CIMultiDict if cython.compiled: from cython.cimports.picows.picows import WSListener, WSTransport, WSFrame, \ @@ -30,7 +34,7 @@ InvalidHandshake, InvalidStatus, ) -from ..typing import BytesLike, Data, DataLike, LoggerLike, Subprotocol +from ..typing import BytesLike, Data, DataLike, LoggerLike, StatusLike, Subprotocol # cached for performance @@ -153,6 +157,18 @@ def from_response_header(cls, header_value: str) -> _PerMessageDeflate: return self + @classmethod + def enabled(cls) -> _PerMessageDeflate: + self: _PerMessageDeflate = _PerMessageDeflate.__new__(_PerMessageDeflate) + self.remote_no_context_takeover = False + self.local_no_context_takeover = False + self.remote_max_window_bits = -15 + self.local_max_window_bits = -15 + self._decoder = _zlib_decompressobj(wbits=self.remote_max_window_bits) + self._encoder = _zlib_compressobj(wbits=self.local_max_window_bits) + self._decode_cont_data = False + return self + @cython.cfunc @cython.inline def decode_frame(self, frame: WSFrame, max_length: cython.Py_ssize_t) -> bytes: @@ -263,8 +279,7 @@ def _normalize_watermarks( return max_queue, max_queue // 4 -@cython.cfunc -@cython.inline +@cython.ccall def _resolve_logger(logger: LoggerLike) -> Union[logging.Logger, logging.LoggerAdapter[Any]]: if logger is None: return logging.getLogger("websockets.client") @@ -284,6 +299,27 @@ def process_exception(exc: Exception) -> Optional[Exception]: return exc +def _default_server_header() -> str: + return f"Python/{sys.version_info.major}.{sys.version_info.minor} picows-websockets/0" + + +_pending_server_requests: weakref.WeakKeyDictionary[Any, Request] = weakref.WeakKeyDictionary() +_pending_server_responses: weakref.WeakKeyDictionary[Any, Response] = weakref.WeakKeyDictionary() +_pending_server_usernames: weakref.WeakKeyDictionary[Any, str] = weakref.WeakKeyDictionary() + + +def stash_server_request(connection: Any, request: Request) -> None: + _pending_server_requests[connection] = request + + +def stash_server_response(connection: Any, response: Response) -> None: + _pending_server_responses[connection] = response + + +def stash_server_username(connection: Any, username: str) -> None: + _pending_server_usernames[connection] = username + + @cython.cclass class ClientConnection(WSListener): # type: ignore[misc] id: uuid.UUID @@ -782,8 +818,9 @@ def _check_fragment_type(self, message: DataLike, first_is_str: cython.bint) -> @cython.cfunc @cython.inline def _release_send(self) -> None: - waiter: asyncio.Future[None] + self._send_in_progress = False + waiter: asyncio.Future[None] while self._send_waiters: waiter = self._send_waiters.popleft() # Some waiters may be canceled, that is why we have defensive check @@ -792,8 +829,6 @@ def _release_send(self) -> None: waiter.set_result(None) return - self._send_in_progress = False - async def _send_fragments( self, messages: Union[AsyncIterable[DataLike], Iterable[DataLike]], @@ -1043,3 +1078,93 @@ def close_reason(self) -> Optional[str]: if handshake.sent is not None: return _coerce_close_reason(handshake.sent.reason) # type: ignore[no-any-return] return None + + +def broadcast_message(connection: ClientConnection, msg_type: WSMsgType, message: DataLike) -> bool: + if connection._send_in_progress: + return False + connection._encode_and_send(msg_type, message, True) + return True + + +@cython.cclass +class ServerConnection(ClientConnection): + server: Any + handler: Any + handler_kwargs: Mapping[str, Any] + + def __init__( + self, + server: Any, + *, + ping_interval: Optional[float] = 20, + ping_timeout: Optional[float] = 20, + close_timeout: Optional[float] = 10, + max_queue: Union[int, tuple[Optional[int], Optional[int]], None] = 16, + write_limit: Union[int, tuple[int, Optional[int]]] = 32768, + max_message_size: Optional[int] = 1024 * 1024, + logger: LoggerLike = None, + compression: Optional[str] = None, + ): + super().__init__( + ping_interval=ping_interval, + ping_timeout=ping_timeout, + close_timeout=close_timeout, + max_queue=max_queue, + write_limit=write_limit, + max_message_size=max_message_size, + logger=logger, + subprotocols=None, + compression=compression, + ) + self.server = server + self.handler = None + self.handler_kwargs = {} + + @cython.ccall + def on_ws_connected(self, transport: WSTransport) -> None: + self.transport = transport + request = _pending_server_requests.pop(self, None) + response = _pending_server_responses.pop(self, None) + self._request = request if request is not None else Request.from_picows(transport.request) + self._response = response if response is not None else Response.from_picows(transport.response) + try: + self._subprotocol = _resolve_subprotocol(None, self._response) + self._configure_extensions() + except InvalidHandshake as exc: + self._connect_exception = exc + self.transport.send_close(WSCloseCode.PROTOCOL_ERROR, str(exc)) + self.transport.disconnect(False) + return + self._set_write_limits(self._write_limit) + if self._ping_interval is not None and self._keepalive_task is None: + self._keepalive_task = asyncio.create_task(self._keepalive_loop()) + self.server.loop.call_soon(self.server.start_connection_handler, self) + + @property + def username(self) -> str: + try: + return _pending_server_usernames[self] + except KeyError as exc: + raise AttributeError("username") from exc + + def respond(self, status: StatusLike, text: str) -> Response: + body = text.encode("utf-8") + request = _pending_server_requests.get(self) + headers = ( + type(request.headers)({ + "Content-Type": "text/plain; charset=utf-8", + "Content-Length": str(len(body)), + }) + if request is not None + else CIMultiDict({ + "Content-Type": "text/plain; charset=utf-8", + "Content-Length": str(len(body)), + }) + ) + return Response( + status_code=int(status), + reason_phrase=http.HTTPStatus(status).phrase, + headers=headers, + body=body, + ) diff --git a/picows/websockets/asyncio/router.py b/picows/websockets/asyncio/router.py new file mode 100644 index 0000000..d445923 --- /dev/null +++ b/picows/websockets/asyncio/router.py @@ -0,0 +1,81 @@ +from __future__ import annotations + +import http +import urllib.parse +from typing import Any + +from .server import Server, ServerConnection, serve +from ..compat import Request, Response + +try: + from werkzeug.routing import Map, RequestRedirect + from werkzeug.exceptions import NotFound +except ImportError: # pragma: no cover + Map = None + RequestRedirect = None + NotFound = None + + +class Router: + def __init__( + self, + url_map: Map, + server_name: str | None = None, + url_scheme: str = "ws", + ) -> None: + self.url_map = url_map + self.server_name = server_name + self.url_scheme = url_scheme + for rule in self.url_map.iter_rules(): + rule.websocket = True + + def get_server_name(self, connection: ServerConnection, request: Request) -> str: + if self.server_name is None: + return request.headers["Host"] + return self.server_name + + def redirect(self, connection: ServerConnection, url: str) -> Response: + response = connection.respond(http.HTTPStatus.FOUND, f"Found at {url}") + response.headers["Location"] = url + return response + + def not_found(self, connection: ServerConnection) -> Response: + return connection.respond(http.HTTPStatus.NOT_FOUND, "Not Found") + + def route_request(self, connection: ServerConnection, request: Request) -> Response | None: + url_map_adapter = self.url_map.bind( + server_name=self.get_server_name(connection, request), + url_scheme=self.url_scheme, + ) + try: + parsed = urllib.parse.urlparse(request.path) + handler, kwargs = url_map_adapter.match( + path_info=parsed.path, + query_args=parsed.query, + ) + except RequestRedirect as redirect: + return self.redirect(connection, redirect.new_url) + except NotFound: + return self.not_found(connection) + connection.handler = handler + connection.handler_kwargs = kwargs + return None + + async def handler(self, connection: ServerConnection) -> None: + handler = connection.handler + assert handler is not None + await handler(connection, **connection.handler_kwargs) + + +def route( + url_map: Map, + *args: Any, + server_name: str | None = None, + create_router: type[Router] | None = None, + **kwargs: Any, +) -> Any: + if Map is None: + raise ImportError("route() requires werkzeug") + router_cls = create_router or Router + router = router_cls(url_map, server_name=server_name) + return serve(router.handler, *args, process_request=router.route_request, **kwargs) diff --git a/picows/websockets/asyncio/server.py b/picows/websockets/asyncio/server.py new file mode 100644 index 0000000..de24d4f --- /dev/null +++ b/picows/websockets/asyncio/server.py @@ -0,0 +1,498 @@ +from __future__ import annotations + +import asyncio +import base64 +import binascii +import hmac +import http +import inspect +import re +import socket +import sys +from collections.abc import Awaitable, Callable, Iterable +from logging import getLogger +from typing import Any, Optional, Pattern, Sequence, cast + +import picows + +from .connection import ( + ServerConnection, + _default_server_header, + _resolve_logger, + broadcast_message, + stash_server_request, + stash_server_response, + stash_server_username, +) +from ..compat import Request, Response, State +from ..exceptions import ConcurrencyError, InvalidHandshake, InvalidOrigin +from ..typing import DataLike, HeadersLike, LoggerLike, Origin, StatusLike, Subprotocol + +__all__ = [ + "ServerConnection", + "Server", + "serve", + "broadcast", + "basic_auth", +] + + +_PERMESSAGE_DEFLATE_REQUEST = "permessage-deflate" + + +def _header_items(headers: Any) -> list[tuple[str, str]]: + return [] if headers is None else list(headers.items()) + + +def _supports_permessage_deflate(request: Request) -> bool: + value = request.headers.get("Sec-WebSocket-Extensions") + return isinstance(value, str) and "permessage-deflate" in value + + +def _origin_allowed( + origin: str | None, + origins: Sequence[Origin | Pattern[str] | None] | None, +) -> bool: + if origins is None: + return True + for candidate in origins: + if candidate is None: + if origin is None: + return True + elif isinstance(candidate, str): + if origin == candidate: + return True + elif candidate.fullmatch(origin or "") is not None: + return True + return False + + +def _select_subprotocol( + connection: ServerConnection, + request: Request, + subprotocols: Optional[Sequence[Subprotocol]], + select_subprotocol: Optional[Callable[[ServerConnection, Sequence[Subprotocol]], Subprotocol | None]], +) -> Optional[Subprotocol]: + header_value = request.headers.get("Sec-WebSocket-Protocol") + if header_value is None: + return None + offered = [item.strip() for item in header_value.split(",") if item.strip()] + if not offered: + return None + if select_subprotocol is not None: + selected = select_subprotocol(connection, offered) + if selected is not None and selected not in offered: + raise InvalidHandshake(f"selected subprotocol isn't offered by client: {selected}") + return selected + if subprotocols is None: + return None + for subprotocol in subprotocols: + if subprotocol in offered: + return subprotocol + return None + + +def _build_www_authenticate_basic(realm: str) -> str: + realm.encode("ascii") + return f'Basic realm="{realm}"' + + +def _parse_authorization_basic(header: str) -> tuple[str, str]: + scheme, _, token = header.partition(" ") + if scheme.lower() != "basic" or not token: + raise InvalidHandshake("unsupported authorization header") + try: + decoded = base64.b64decode(token.encode("ascii"), validate=True).decode("utf-8") + except (UnicodeDecodeError, ValueError, binascii.Error) as exc: + raise InvalidHandshake("invalid basic authorization header") from exc + username, sep, password = decoded.partition(":") + if not sep: + raise InvalidHandshake("invalid basic authorization header") + return username, password + + +def _is_credentials(value: object) -> bool: + return ( + isinstance(value, tuple) + and len(value) == 2 + and isinstance(value[0], str) + and isinstance(value[1], str) + ) + + +def basic_auth( + realm: str = "", + credentials: tuple[str, str] | Iterable[tuple[str, str]] | None = None, + check_credentials: Callable[[str, str], bool] | None = None, +) -> Callable[[ServerConnection, Request], Response | None]: + if (credentials is None) == (check_credentials is None): + raise ValueError("provide either credentials or check_credentials") + + if check_credentials is not None and inspect.iscoroutinefunction(check_credentials): + raise NotImplementedError("async check_credentials isn't supported by picows core yet") + + if credentials is not None: + if _is_credentials(credentials): + credentials_list = [cast(tuple[str, str], credentials)] + else: + if not isinstance(credentials, Iterable): + raise TypeError(f"invalid credentials argument: {credentials}") + credentials_iterable = cast(Iterable[tuple[str, str]], credentials) + credentials_list = list(credentials_iterable) + if not all(_is_credentials(item) for item in credentials_list): + raise TypeError(f"invalid credentials argument: {credentials}") + + credentials_dict = dict(credentials_list) + + def check_credentials(username: str, password: str) -> bool: + try: + expected_password = credentials_dict[username] + except KeyError: + return False + return hmac.compare_digest(expected_password, password) + + assert check_credentials is not None + + def process_request( + connection: ServerConnection, + request: Request, + ) -> Response | None: + authorization = request.headers.get("Authorization") + if authorization is None: + response = connection.respond(http.HTTPStatus.UNAUTHORIZED, "Missing credentials\n") + response.headers["WWW-Authenticate"] = _build_www_authenticate_basic(realm) + return response + + try: + username, password = _parse_authorization_basic(authorization) + except InvalidHandshake: + response = connection.respond(http.HTTPStatus.UNAUTHORIZED, "Unsupported credentials\n") + response.headers["WWW-Authenticate"] = _build_www_authenticate_basic(realm) + return response + + if not check_credentials(username, password): + response = connection.respond(http.HTTPStatus.UNAUTHORIZED, "Invalid credentials\n") + response.headers["WWW-Authenticate"] = _build_www_authenticate_basic(realm) + return response + + stash_server_username(connection, username) + return None + + return process_request + + +class Server: + def __init__( + self, + handler: Callable[[ServerConnection], Awaitable[None]], + *, + process_request: Callable[[ServerConnection, Request], Response | None] | None = None, + process_response: Callable[[ServerConnection, Request, Response], Response | None] | None = None, + server_header: str | None = _default_server_header(), + open_timeout: float | None = 10, + logger: LoggerLike | None = None, + ) -> None: + self.loop = asyncio.get_running_loop() + self.handler = handler + self.process_request = process_request + self.process_response = process_response + self.server_header = server_header + self.open_timeout = open_timeout + self.logger = _resolve_logger(logger if logger is not None else getLogger("websockets.server")) + self.handlers: dict[ServerConnection, asyncio.Task[None]] = {} + self.close_task: asyncio.Task[None] | None = None + self.closed_waiter: asyncio.Future[None] = self.loop.create_future() + self.server: asyncio.Server + + @property + def connections(self) -> set[ServerConnection]: + return {connection for connection in self.handlers if connection.state is State.OPEN} + + def wrap(self, server: asyncio.Server) -> None: + self.server = server + for sock in server.sockets: + if sock.family == socket.AF_INET: + name = "%s:%d" % sock.getsockname() + elif sock.family == socket.AF_INET6: + name = "[%s]:%d" % sock.getsockname()[:2] + else: + name = str(sock.getsockname()) + self.logger.info("server listening on %s", name) + + async def conn_handler(self, connection: ServerConnection) -> None: + try: + try: + await asyncio.sleep(0) + await self.handler(connection) + except Exception: + self.logger.error("connection handler failed", exc_info=True) + await connection.close(1011) + else: + await connection.close() + finally: + del self.handlers[connection] + + def start_connection_handler(self, connection: ServerConnection) -> None: + self.handlers[connection] = self.loop.create_task(self.conn_handler(connection)) + + def close( + self, + close_connections: bool = True, + code: int = 1001, + reason: str = "", + ) -> None: + if self.close_task is None: + self.close_task = self.loop.create_task(self._close(close_connections, code, reason)) + + async def _close( + self, + close_connections: bool = True, + code: int = 1001, + reason: str = "", + ) -> None: + self.logger.info("server closing") + self.server.close() + await asyncio.sleep(0) + if close_connections: + close_tasks = [ + asyncio.create_task(connection.close(code, reason)) + for connection in self.handlers + if connection.state is not State.CONNECTING + ] + if close_tasks: + await asyncio.wait(close_tasks) + await self.server.wait_closed() + if self.handlers: + await asyncio.wait(self.handlers.values()) + self.closed_waiter.set_result(None) + self.logger.info("server closed") + + async def wait_closed(self) -> None: + await asyncio.shield(self.closed_waiter) + + def get_loop(self) -> asyncio.AbstractEventLoop: + return self.server.get_loop() + + def is_serving(self) -> bool: + return self.server.is_serving() + + async def start_serving(self) -> None: + await self.server.start_serving() + + async def serve_forever(self) -> None: + await self.server.serve_forever() + + @property + def sockets(self) -> tuple[socket.socket, ...]: + return self.server.sockets + + async def __aenter__(self) -> Server: + return self + + async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: + self.close() + await self.wait_closed() + + +class serve: + def __init__( + self, + handler: Callable[[ServerConnection], Awaitable[None]], + host: str | None = None, + port: int | None = None, + *, + origins: Sequence[Origin | Pattern[str] | None] | None = None, + extensions: Sequence[Any] | None = None, + subprotocols: Sequence[Subprotocol] | None = None, + select_subprotocol: Callable[[ServerConnection, Sequence[Subprotocol]], Subprotocol | None] | None = None, + compression: str | None = "deflate", + process_request: Callable[[ServerConnection, Request], Response | None] | None = None, + process_response: Callable[[ServerConnection, Request, Response], Response | None] | None = None, + server_header: str | None = _default_server_header(), + open_timeout: float | None = 10, + ping_interval: float | None = 20, + ping_timeout: float | None = 20, + close_timeout: float | None = 10, + max_size: int | None | tuple[int | None, int | None] = 1024 * 1024, + max_queue: int | None | tuple[int | None, int | None] = 16, + write_limit: int | tuple[int, int | None] = 32768, + logger: LoggerLike | None = None, + create_connection: type[ServerConnection] | None = None, + **kwargs: Any, + ): + self.handler = handler + self.host = host + self.port = port + self.origins = origins + self.extensions = extensions + self.subprotocols = subprotocols + self.select_subprotocol = select_subprotocol + self.compression = compression + self.process_request = process_request + self.process_response = process_response + self.server_header = server_header + self.open_timeout = open_timeout + self.ping_interval = ping_interval + self.ping_timeout = ping_timeout + self.close_timeout = close_timeout + self.max_size = max_size + self.max_queue = max_queue + self.write_limit = write_limit + self.logger = logger + self.connection_factory = create_connection or ServerConnection + self.kwargs = kwargs + self._server: Server | None = None + + def __await__(self) -> Any: + return self._create().__await__() + + async def __aenter__(self) -> Server: + self._server = await self._create() + return self._server + + async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None: + assert self._server is not None + self._server.close() + await self._server.wait_closed() + self._server = None + + async def _create(self) -> Server: + if self.extensions is not None: + raise NotImplementedError("custom server extensions aren't supported by picows.websockets") + if self.compression not in (None, "deflate"): + raise NotImplementedError("only compression=None or 'deflate' are accepted") + if self.process_request is not None and inspect.iscoroutinefunction(self.process_request): + raise NotImplementedError("async process_request isn't supported by picows core yet") + if self.process_response is not None and inspect.iscoroutinefunction(self.process_response): + raise NotImplementedError("async process_response isn't supported by picows core yet") + + server = Server( + self.handler, + process_request=self.process_request, + process_response=self.process_response, + server_header=self.server_header, + open_timeout=self.open_timeout, + logger=self.logger, + ) + + max_message_size = self.max_size[0] if isinstance(self.max_size, tuple) else self.max_size + max_frame_size = 2 ** 31 - 1 if max_message_size is None else max_message_size + + def listener_factory( + upgrade_request: picows.WSUpgradeRequest, + ) -> picows.WSUpgradeResponseWithListener: + request = Request.from_picows(upgrade_request) + connection = self.connection_factory( + server, + ping_interval=self.ping_interval, + ping_timeout=self.ping_timeout, + close_timeout=self.close_timeout, + max_queue=self.max_queue, + write_limit=self.write_limit, + max_message_size=max_message_size, + logger=self.logger, + compression=self.compression, + ) + stash_server_request(connection, request) + + response: Response | None = None + + origin = request.headers.get("Origin") + if origin is not None and not isinstance(origin, str): + raise InvalidOrigin(None) + if not _origin_allowed(origin, self.origins): + response = connection.respond(http.HTTPStatus.FORBIDDEN, "Origin not allowed\n") + + if self.process_request is not None and response is None: + response = self.process_request(connection, request) + + if response is None: + if server.close_task is not None: + response = connection.respond(http.HTTPStatus.SERVICE_UNAVAILABLE, "Server is shutting down.\n") + else: + headers = {} + if self.server_header is not None: + headers["Server"] = self.server_header + subprotocol = _select_subprotocol(connection, request, self.subprotocols, self.select_subprotocol) + if subprotocol is not None: + headers["Sec-WebSocket-Protocol"] = subprotocol + if self.compression == "deflate" and _supports_permessage_deflate(request): + headers["Sec-WebSocket-Extensions"] = _PERMESSAGE_DEFLATE_REQUEST + response = Response( + status_code=int(http.HTTPStatus.SWITCHING_PROTOCOLS), + reason_phrase=http.HTTPStatus.SWITCHING_PROTOCOLS.phrase, + headers=type(request.headers)(headers), + body=b"", + ) + + if self.process_response is not None: + updated = self.process_response(connection, request, response) + if updated is not None: + response = updated + + assert response is not None + stash_server_response(connection, response) + listener = connection if response.status_code == int(http.HTTPStatus.SWITCHING_PROTOCOLS) else None + if listener is not None: + raw_response = picows.WSUpgradeResponse.create_101_response(response.headers) + else: + raw_response = response.to_picows() + return picows.WSUpgradeResponseWithListener(raw_response, listener) + + raw_server = await picows.ws_create_server( + listener_factory, + self.host, + self.port, + websocket_handshake_timeout=self.open_timeout, + enable_auto_ping=False, + enable_auto_pong=True, + max_frame_size=max_frame_size, + logger_name=self.logger if self.logger is not None else getLogger("websockets.server"), + **self.kwargs, + ) + server.wrap(raw_server) + return server + + +def broadcast( + connections: Iterable[ServerConnection], + message: DataLike, + raise_exceptions: bool = False, +) -> None: + if isinstance(message, str): + msg_type = picows.WSMsgType.TEXT + elif isinstance(message, (bytes, bytearray, memoryview)): + msg_type = picows.WSMsgType.BINARY + else: + raise TypeError("data must be str or bytes") + + if raise_exceptions: + if sys.version_info[:2] < (3, 11): + raise ValueError("raise_exceptions requires at least Python 3.11") + exceptions: list[Exception] = [] + + for connection in connections: + if connection.state is not State.OPEN: + continue + try: + sent = broadcast_message(connection, msg_type, message) + except Exception as write_exception: + if raise_exceptions: + exception = RuntimeError("failed to write message") + exception.__cause__ = write_exception + exceptions.append(exception) + else: + getLogger("websockets.server").warning( + "skipped broadcast: failed to write message: %s", + write_exception, + ) + continue + + if not sent: + if raise_exceptions: + exceptions.append(ConcurrencyError("sending a fragmented message")) + else: + getLogger("websockets.server").warning("skipped broadcast: sending a fragmented message") + + if raise_exceptions and exceptions: + raise ExceptionGroup("skipped broadcast", exceptions) diff --git a/picows/websockets/compat.py b/picows/websockets/compat.py index 7b414fb..55b2e0a 100644 --- a/picows/websockets/compat.py +++ b/picows/websockets/compat.py @@ -2,6 +2,7 @@ from dataclasses import dataclass from enum import IntEnum +from http import HTTPStatus import picows from multidict import CIMultiDict @@ -49,6 +50,14 @@ def from_picows(cls, response: picows.WSUpgradeResponse) -> Response: def status(self) -> int: return self.status_code + def to_picows(self) -> picows.WSUpgradeResponse: + response = picows.WSUpgradeResponse() + response.version = b"HTTP/1.1" + response.status = HTTPStatus(self.status_code) + response.headers = self.headers.copy() + response.body = bytes(self.body) + return response + __all__ = [ "State", diff --git a/tests/test_websockets_server_compat.py b/tests/test_websockets_server_compat.py new file mode 100644 index 0000000..a48752d --- /dev/null +++ b/tests/test_websockets_server_compat.py @@ -0,0 +1,259 @@ +import asyncio +import base64 +import http +import re + +import pytest + +from picows import websockets + + +async def test_serve_echo_roundtrip(): + async def handler(ws: websockets.ServerConnection) -> None: + message = await ws.recv() + await ws.send(message) + + async with websockets.serve(handler, "127.0.0.1", 0, compression=None) as server: + port = server.sockets[0].getsockname()[1] + async with websockets.connect(f"ws://127.0.0.1:{port}/", compression=None) as ws: + await ws.send("hello") + assert await ws.recv() == "hello" + + +async def test_serve_process_request_can_reject_handshake(): + async def handler(ws: websockets.ServerConnection) -> None: + raise AssertionError("handler must not be called") + + def process_request(ws: websockets.ServerConnection, request: websockets.Request) -> websockets.Response: + return ws.respond(http.HTTPStatus.FORBIDDEN, "nope") + + async with websockets.serve( + handler, "127.0.0.1", 0, compression=None, process_request=process_request + ) as server: + port = server.sockets[0].getsockname()[1] + with pytest.raises(websockets.InvalidStatus): + async with websockets.connect(f"ws://127.0.0.1:{port}/", compression=None): + pass + + +async def test_serve_process_response_can_mutate_headers(): + async def handler(ws: websockets.ServerConnection) -> None: + await ws.wait_closed() + + def process_response( + ws: websockets.ServerConnection, + request: websockets.Request, + response: websockets.Response, + ) -> websockets.Response | None: + response.headers["X-Test"] = "1" + return response + + async with websockets.serve( + handler, "127.0.0.1", 0, compression=None, process_response=process_response + ) as server: + port = server.sockets[0].getsockname()[1] + async with websockets.connect(f"ws://127.0.0.1:{port}/", compression=None) as ws: + assert ws.response.headers["X-Test"] == "1" + + +async def test_serve_accepts_allowed_origin(): + async def handler(ws: websockets.ServerConnection) -> None: + await ws.send("ok") + + async with websockets.serve( + handler, + "127.0.0.1", + 0, + compression=None, + origins=["https://example.com"], + ) as server: + port = server.sockets[0].getsockname()[1] + async with websockets.connect( + f"ws://127.0.0.1:{port}/", + compression=None, + origin="https://example.com", + ) as ws: + assert await ws.recv() == "ok" + + +async def test_serve_rejects_disallowed_origin(): + async def handler(ws: websockets.ServerConnection) -> None: + raise AssertionError("handler must not be called") + + async with websockets.serve( + handler, + "127.0.0.1", + 0, + compression=None, + origins=[re.compile(r"https://allowed\\.example\\.com")], + ) as server: + port = server.sockets[0].getsockname()[1] + with pytest.raises(websockets.InvalidStatus): + async with websockets.connect( + f"ws://127.0.0.1:{port}/", + compression=None, + origin="https://denied.example.com", + ): + pass + + +async def test_basic_auth_accepts_valid_credentials_and_sets_username(): + seen_usernames: list[str] = [] + + async def handler(ws: websockets.ServerConnection) -> None: + seen_usernames.append(ws.username) + await ws.send(ws.username) + + async with websockets.serve( + handler, + "127.0.0.1", + 0, + compression=None, + process_request=websockets.basic_auth( + realm="test", + credentials=("hello", "secret"), + ), + ) as server: + port = server.sockets[0].getsockname()[1] + token = base64.b64encode(b"hello:secret").decode() + async with websockets.connect( + f"ws://127.0.0.1:{port}/", + compression=None, + additional_headers={"Authorization": f"Basic {token}"}, + ) as ws: + assert await ws.recv() == "hello" + assert seen_usernames == ["hello"] + + +async def test_basic_auth_rejects_missing_credentials(): + async def handler(ws: websockets.ServerConnection) -> None: + raise AssertionError("handler must not be called") + + async with websockets.serve( + handler, + "127.0.0.1", + 0, + compression=None, + process_request=websockets.basic_auth( + realm="test", + credentials=("hello", "secret"), + ), + ) as server: + port = server.sockets[0].getsockname()[1] + with pytest.raises(websockets.InvalidStatus): + async with websockets.connect(f"ws://127.0.0.1:{port}/", compression=None): + pass + + +async def test_serve_negotiates_subprotocol(): + async def handler(ws: websockets.ServerConnection) -> None: + await ws.wait_closed() + + async with websockets.serve( + handler, + "127.0.0.1", + 0, + compression=None, + subprotocols=["chat", "superchat"], + ) as server: + port = server.sockets[0].getsockname()[1] + async with websockets.connect( + f"ws://127.0.0.1:{port}/", + compression=None, + subprotocols=["superchat", "chat"], + ) as ws: + assert ws.subprotocol == "chat" + + +async def test_broadcast_sends_to_open_connections(): + connections: list[websockets.ServerConnection] = [] + + async def handler(ws: websockets.ServerConnection) -> None: + connections.append(ws) + await ws.wait_closed() + + async with websockets.serve(handler, "127.0.0.1", 0, compression=None) as server: + port = server.sockets[0].getsockname()[1] + async with websockets.connect(f"ws://127.0.0.1:{port}/", compression=None) as ws1: + async with websockets.connect(f"ws://127.0.0.1:{port}/", compression=None) as ws2: + while len(connections) < 2: + await asyncio.sleep(0) + websockets.broadcast(connections, "hi") + assert await ws1.recv() == "hi" + assert await ws2.recv() == "hi" + + +async def test_server_connections_tracks_open_connections(): + connected = asyncio.Event() + + async def handler(ws: websockets.ServerConnection) -> None: + connected.set() + await ws.wait_closed() + + async with websockets.serve(handler, "127.0.0.1", 0, compression=None) as server: + port = server.sockets[0].getsockname()[1] + assert server.connections == set() + async with websockets.connect(f"ws://127.0.0.1:{port}/", compression=None): + await connected.wait() + assert len(server.connections) == 1 + await asyncio.sleep(0) + assert server.connections == set() + + +async def test_handler_exception_closes_connection_with_internal_error(): + async def handler(ws: websockets.ServerConnection) -> None: + raise RuntimeError("boom") + + async with websockets.serve(handler, "127.0.0.1", 0, compression=None) as server: + port = server.sockets[0].getsockname()[1] + async with websockets.connect(f"ws://127.0.0.1:{port}/", compression=None) as ws: + with pytest.raises(websockets.ConnectionClosedError): + await ws.recv() + assert ws.close_code == 1011 + + +async def test_server_close_closes_existing_connections(): + started = asyncio.Event() + + async def handler(ws: websockets.ServerConnection) -> None: + started.set() + await ws.wait_closed() + + async with websockets.serve(handler, "127.0.0.1", 0, compression=None) as server: + port = server.sockets[0].getsockname()[1] + async with websockets.connect(f"ws://127.0.0.1:{port}/", compression=None) as ws: + await started.wait() + server.close(reason="bye") + with pytest.raises(websockets.ConnectionClosedOK): + await ws.recv() + assert ws.close_code == 1001 + assert ws.close_reason == "bye" + await server.wait_closed() + + +async def test_wait_closed_waits_for_handler_completion(): + started = asyncio.Event() + finish = asyncio.Event() + finished = asyncio.Event() + + async def handler(ws: websockets.ServerConnection) -> None: + started.set() + await finish.wait() + finished.set() + + server = await websockets.serve(handler, "127.0.0.1", 0, compression=None) + port = server.sockets[0].getsockname()[1] + async with websockets.connect(f"ws://127.0.0.1:{port}/", compression=None): + await started.wait() + server.close(close_connections=False) + waiter = asyncio.create_task(server.wait_closed()) + await asyncio.sleep(0) + assert not waiter.done() + finish.set() + await waiter + assert finished.is_set() + + +def test_route_requires_werkzeug(): + with pytest.raises(ImportError): + websockets.route(None) # type: ignore[arg-type] From d9c85231f1f15a55bf30a582f2f5a875fdf487f5 Mon Sep 17 00:00:00 2001 From: taras Date: Thu, 7 May 2026 05:12:01 +0200 Subject: [PATCH 45/57] Simplify implmentation --- picows/websockets/__init__.py | 3 +- picows/websockets/asyncio/__init__.py | 3 +- picows/websockets/asyncio/client.py | 6 +- picows/websockets/asyncio/connection.py | 136 +++++++-------- picows/websockets/asyncio/router.py | 78 +-------- picows/websockets/asyncio/server.py | 223 +++++++----------------- tests/test_websockets_compat.py | 46 +---- tests/test_websockets_server_compat.py | 111 ++++-------- 8 files changed, 173 insertions(+), 433 deletions(-) diff --git a/picows/websockets/__init__.py b/picows/websockets/__init__.py index ff475f9..e97f1b8 100644 --- a/picows/websockets/__init__.py +++ b/picows/websockets/__init__.py @@ -1,7 +1,7 @@ from . import exceptions from .asyncio.client import connect from .asyncio.connection import ClientConnection, ServerConnection, process_exception -from .asyncio.router import Router, route +from .asyncio.router import route from .asyncio.server import Server, basic_auth, broadcast, serve from .compat import CloseCode, Request, Response, State from .exceptions import ( @@ -51,7 +51,6 @@ "CloseCode", "Data", "DataLike", - "Router", "Server", "ServerConnection", "ConcurrencyError", diff --git a/picows/websockets/asyncio/__init__.py b/picows/websockets/asyncio/__init__.py index 6bc3d1f..da25ffb 100644 --- a/picows/websockets/asyncio/__init__.py +++ b/picows/websockets/asyncio/__init__.py @@ -1,12 +1,11 @@ from .client import connect from .connection import ClientConnection, ServerConnection, process_exception -from .router import Router, route +from .router import route from .server import Server, basic_auth, broadcast, serve from ..compat import State __all__ = [ "ClientConnection", - "Router", "Server", "ServerConnection", "basic_auth", diff --git a/picows/websockets/asyncio/client.py b/picows/websockets/asyncio/client.py index 3cec097..6b0ffc2 100644 --- a/picows/websockets/asyncio/client.py +++ b/picows/websockets/asyncio/client.py @@ -83,7 +83,6 @@ def __init__( max_queue: Union[int, tuple[Optional[int], Optional[int]], None] = 16, write_limit: Union[int, tuple[int, Optional[int]]] = 32768, logger: LoggerLike = None, - create_connection: Optional[type[ClientConnection]] = None, **kwargs: Any, ): self.uri = uri @@ -103,7 +102,8 @@ def __init__( self.max_queue = max_queue self.write_limit = write_limit self.logger = logger - self.connection_factory = create_connection or ClientConnection + if "create_connection" in kwargs: + raise NotImplementedError("create_connection isn't supported by picows.websockets yet") self.kwargs = kwargs self._connection: Optional[ClientConnection] = None self._backoff = 1.0 @@ -188,7 +188,7 @@ async def connect_override(_: Any) -> socket.socket: socket_factory = connect_override def listener_factory() -> ClientConnection: - return self.connection_factory( + return ClientConnection( ping_interval=self.ping_interval, ping_timeout=self.ping_timeout, close_timeout=self.close_timeout, diff --git a/picows/websockets/asyncio/connection.py b/picows/websockets/asyncio/connection.py index b312cd4..17b567c 100644 --- a/picows/websockets/asyncio/connection.py +++ b/picows/websockets/asyncio/connection.py @@ -1,12 +1,10 @@ from __future__ import annotations import asyncio -import http import logging import os import sys import uuid -import weakref import zlib from collections import deque from collections.abc import AsyncIterable, Iterable @@ -15,7 +13,6 @@ Union, Dict, Tuple, Iterator, Mapping, NoReturn import cython -from multidict import CIMultiDict if cython.compiled: from cython.cimports.picows.picows import WSListener, WSTransport, WSFrame, \ @@ -34,7 +31,7 @@ InvalidHandshake, InvalidStatus, ) -from ..typing import BytesLike, Data, DataLike, LoggerLike, StatusLike, Subprotocol +from ..typing import BytesLike, Data, DataLike, LoggerLike, Subprotocol # cached for performance @@ -303,25 +300,8 @@ def _default_server_header() -> str: return f"Python/{sys.version_info.major}.{sys.version_info.minor} picows-websockets/0" -_pending_server_requests: weakref.WeakKeyDictionary[Any, Request] = weakref.WeakKeyDictionary() -_pending_server_responses: weakref.WeakKeyDictionary[Any, Response] = weakref.WeakKeyDictionary() -_pending_server_usernames: weakref.WeakKeyDictionary[Any, str] = weakref.WeakKeyDictionary() - - -def stash_server_request(connection: Any, request: Request) -> None: - _pending_server_requests[connection] = request - - -def stash_server_response(connection: Any, response: Response) -> None: - _pending_server_responses[connection] = response - - -def stash_server_username(connection: Any, username: str) -> None: - _pending_server_usernames[connection] = username - - @cython.cclass -class ClientConnection(WSListener): # type: ignore[misc] +class ConnectionBase(WSListener): # type: ignore[misc] id: uuid.UUID logger: Union[logging.Logger, logging.LoggerAdapter[Any]] transport: WSTransport @@ -415,20 +395,7 @@ def __init__( @cython.ccall def on_ws_connected(self, transport: WSTransport) -> None: - self.transport = transport - self._request = Request.from_picows(transport.request) - self._response = Response.from_picows(transport.response) - try: - self._subprotocol = _resolve_subprotocol(self._subprotocols, self._response) - self._configure_extensions() - except InvalidHandshake as exc: - self._connect_exception = exc - self.transport.send_close(WSCloseCode.PROTOCOL_ERROR, str(exc)) - self.transport.disconnect(False) - return - self._set_write_limits(self._write_limit) - if self._ping_interval is not None and self._keepalive_task is None: - self._keepalive_task = asyncio.create_task(self._keepalive_loop()) + raise NotImplementedError @cython.ccall def on_ws_disconnected(self, transport: WSTransport) -> None: @@ -1080,7 +1047,52 @@ def close_reason(self) -> Optional[str]: return None -def broadcast_message(connection: ClientConnection, msg_type: WSMsgType, message: DataLike) -> bool: +@cython.cclass +class ClientConnection(ConnectionBase): + def __init__( + self, + *, + ping_interval: Optional[float] = 20, + ping_timeout: Optional[float] = 20, + close_timeout: Optional[float] = 10, + max_queue: Union[int, tuple[Optional[int], Optional[int]], None] = 16, + write_limit: Union[int, tuple[int, Optional[int]]] = 32768, + max_message_size: Optional[int] = 1024 * 1024, + logger: LoggerLike = None, + subprotocols: Optional[Sequence[Subprotocol]] = None, + compression: Optional[str] = None, + ): + super().__init__( + ping_interval=ping_interval, + ping_timeout=ping_timeout, + close_timeout=close_timeout, + max_queue=max_queue, + write_limit=write_limit, + max_message_size=max_message_size, + logger=logger, + subprotocols=subprotocols, + compression=compression, + ) + + @cython.ccall + def on_ws_connected(self, transport: WSTransport) -> None: + self.transport = transport + self._request = Request.from_picows(transport.request) + self._response = Response.from_picows(transport.response) + try: + self._subprotocol = _resolve_subprotocol(self._subprotocols, self._response) + self._configure_extensions() + except InvalidHandshake as exc: + self._connect_exception = exc + self.transport.send_close(WSCloseCode.PROTOCOL_ERROR, str(exc)) + self.transport.disconnect(False) + return + self._set_write_limits(self._write_limit) + if self._ping_interval is not None and self._keepalive_task is None: + self._keepalive_task = asyncio.create_task(self._keepalive_loop()) + + +def broadcast_message(connection: ConnectionBase, msg_type: WSMsgType, message: DataLike) -> bool: if connection._send_in_progress: return False connection._encode_and_send(msg_type, message, True) @@ -1088,14 +1100,16 @@ def broadcast_message(connection: ClientConnection, msg_type: WSMsgType, message @cython.cclass -class ServerConnection(ClientConnection): +class ServerConnection(ConnectionBase): server: Any - handler: Any - handler_kwargs: Mapping[str, Any] + _initial_request: Optional[Request] + _initial_response: Optional[Response] def __init__( self, server: Any, + request: Request, + response: Response, *, ping_interval: Optional[float] = 20, ping_timeout: Optional[float] = 20, @@ -1118,16 +1132,18 @@ def __init__( compression=compression, ) self.server = server - self.handler = None - self.handler_kwargs = {} + self._initial_request = request + self._initial_response = response @cython.ccall def on_ws_connected(self, transport: WSTransport) -> None: self.transport = transport - request = _pending_server_requests.pop(self, None) - response = _pending_server_responses.pop(self, None) - self._request = request if request is not None else Request.from_picows(transport.request) - self._response = response if response is not None else Response.from_picows(transport.response) + assert self._initial_request is not None + assert self._initial_response is not None + self._request = self._initial_request + self._response = self._initial_response + self._initial_request = None + self._initial_response = None try: self._subprotocol = _resolve_subprotocol(None, self._response) self._configure_extensions() @@ -1140,31 +1156,3 @@ def on_ws_connected(self, transport: WSTransport) -> None: if self._ping_interval is not None and self._keepalive_task is None: self._keepalive_task = asyncio.create_task(self._keepalive_loop()) self.server.loop.call_soon(self.server.start_connection_handler, self) - - @property - def username(self) -> str: - try: - return _pending_server_usernames[self] - except KeyError as exc: - raise AttributeError("username") from exc - - def respond(self, status: StatusLike, text: str) -> Response: - body = text.encode("utf-8") - request = _pending_server_requests.get(self) - headers = ( - type(request.headers)({ - "Content-Type": "text/plain; charset=utf-8", - "Content-Length": str(len(body)), - }) - if request is not None - else CIMultiDict({ - "Content-Type": "text/plain; charset=utf-8", - "Content-Length": str(len(body)), - }) - ) - return Response( - status_code=int(status), - reason_phrase=http.HTTPStatus(status).phrase, - headers=headers, - body=body, - ) diff --git a/picows/websockets/asyncio/router.py b/picows/websockets/asyncio/router.py index d445923..d9a4c5a 100644 --- a/picows/websockets/asyncio/router.py +++ b/picows/websockets/asyncio/router.py @@ -1,81 +1,7 @@ from __future__ import annotations -import http -import urllib.parse from typing import Any -from .server import Server, ServerConnection, serve -from ..compat import Request, Response -try: - from werkzeug.routing import Map, RequestRedirect - from werkzeug.exceptions import NotFound -except ImportError: # pragma: no cover - Map = None - RequestRedirect = None - NotFound = None - - -class Router: - def __init__( - self, - url_map: Map, - server_name: str | None = None, - url_scheme: str = "ws", - ) -> None: - self.url_map = url_map - self.server_name = server_name - self.url_scheme = url_scheme - for rule in self.url_map.iter_rules(): - rule.websocket = True - - def get_server_name(self, connection: ServerConnection, request: Request) -> str: - if self.server_name is None: - return request.headers["Host"] - return self.server_name - - def redirect(self, connection: ServerConnection, url: str) -> Response: - response = connection.respond(http.HTTPStatus.FOUND, f"Found at {url}") - response.headers["Location"] = url - return response - - def not_found(self, connection: ServerConnection) -> Response: - return connection.respond(http.HTTPStatus.NOT_FOUND, "Not Found") - - def route_request(self, connection: ServerConnection, request: Request) -> Response | None: - url_map_adapter = self.url_map.bind( - server_name=self.get_server_name(connection, request), - url_scheme=self.url_scheme, - ) - try: - parsed = urllib.parse.urlparse(request.path) - handler, kwargs = url_map_adapter.match( - path_info=parsed.path, - query_args=parsed.query, - ) - except RequestRedirect as redirect: - return self.redirect(connection, redirect.new_url) - except NotFound: - return self.not_found(connection) - connection.handler = handler - connection.handler_kwargs = kwargs - return None - - async def handler(self, connection: ServerConnection) -> None: - handler = connection.handler - assert handler is not None - await handler(connection, **connection.handler_kwargs) - - -def route( - url_map: Map, - *args: Any, - server_name: str | None = None, - create_router: type[Router] | None = None, - **kwargs: Any, -) -> Any: - if Map is None: - raise ImportError("route() requires werkzeug") - router_cls = create_router or Router - router = router_cls(url_map, server_name=server_name) - return serve(router.handler, *args, process_request=router.route_request, **kwargs) +def route(*args: Any, **kwargs: Any) -> Any: + raise NotImplementedError("route() requires unsupported server process_request hooks") diff --git a/picows/websockets/asyncio/server.py b/picows/websockets/asyncio/server.py index de24d4f..755943f 100644 --- a/picows/websockets/asyncio/server.py +++ b/picows/websockets/asyncio/server.py @@ -1,17 +1,13 @@ from __future__ import annotations import asyncio -import base64 -import binascii -import hmac import http -import inspect import re import socket import sys from collections.abc import Awaitable, Callable, Iterable from logging import getLogger -from typing import Any, Optional, Pattern, Sequence, cast +from typing import Any, Optional, Pattern, Sequence import picows @@ -20,13 +16,10 @@ _default_server_header, _resolve_logger, broadcast_message, - stash_server_request, - stash_server_response, - stash_server_username, ) from ..compat import Request, Response, State from ..exceptions import ConcurrencyError, InvalidHandshake, InvalidOrigin -from ..typing import DataLike, HeadersLike, LoggerLike, Origin, StatusLike, Subprotocol +from ..typing import DataLike, LoggerLike, Origin, Subprotocol __all__ = [ "ServerConnection", @@ -92,93 +85,8 @@ def _select_subprotocol( return None -def _build_www_authenticate_basic(realm: str) -> str: - realm.encode("ascii") - return f'Basic realm="{realm}"' - - -def _parse_authorization_basic(header: str) -> tuple[str, str]: - scheme, _, token = header.partition(" ") - if scheme.lower() != "basic" or not token: - raise InvalidHandshake("unsupported authorization header") - try: - decoded = base64.b64decode(token.encode("ascii"), validate=True).decode("utf-8") - except (UnicodeDecodeError, ValueError, binascii.Error) as exc: - raise InvalidHandshake("invalid basic authorization header") from exc - username, sep, password = decoded.partition(":") - if not sep: - raise InvalidHandshake("invalid basic authorization header") - return username, password - - -def _is_credentials(value: object) -> bool: - return ( - isinstance(value, tuple) - and len(value) == 2 - and isinstance(value[0], str) - and isinstance(value[1], str) - ) - - -def basic_auth( - realm: str = "", - credentials: tuple[str, str] | Iterable[tuple[str, str]] | None = None, - check_credentials: Callable[[str, str], bool] | None = None, -) -> Callable[[ServerConnection, Request], Response | None]: - if (credentials is None) == (check_credentials is None): - raise ValueError("provide either credentials or check_credentials") - - if check_credentials is not None and inspect.iscoroutinefunction(check_credentials): - raise NotImplementedError("async check_credentials isn't supported by picows core yet") - - if credentials is not None: - if _is_credentials(credentials): - credentials_list = [cast(tuple[str, str], credentials)] - else: - if not isinstance(credentials, Iterable): - raise TypeError(f"invalid credentials argument: {credentials}") - credentials_iterable = cast(Iterable[tuple[str, str]], credentials) - credentials_list = list(credentials_iterable) - if not all(_is_credentials(item) for item in credentials_list): - raise TypeError(f"invalid credentials argument: {credentials}") - - credentials_dict = dict(credentials_list) - - def check_credentials(username: str, password: str) -> bool: - try: - expected_password = credentials_dict[username] - except KeyError: - return False - return hmac.compare_digest(expected_password, password) - - assert check_credentials is not None - - def process_request( - connection: ServerConnection, - request: Request, - ) -> Response | None: - authorization = request.headers.get("Authorization") - if authorization is None: - response = connection.respond(http.HTTPStatus.UNAUTHORIZED, "Missing credentials\n") - response.headers["WWW-Authenticate"] = _build_www_authenticate_basic(realm) - return response - - try: - username, password = _parse_authorization_basic(authorization) - except InvalidHandshake: - response = connection.respond(http.HTTPStatus.UNAUTHORIZED, "Unsupported credentials\n") - response.headers["WWW-Authenticate"] = _build_www_authenticate_basic(realm) - return response - - if not check_credentials(username, password): - response = connection.respond(http.HTTPStatus.UNAUTHORIZED, "Invalid credentials\n") - response.headers["WWW-Authenticate"] = _build_www_authenticate_basic(realm) - return response - - stash_server_username(connection, username) - return None - - return process_request +def basic_auth(*args: Any, **kwargs: Any) -> Any: + raise NotImplementedError("basic_auth() requires unsupported server process_request hooks") class Server: @@ -186,16 +94,12 @@ def __init__( self, handler: Callable[[ServerConnection], Awaitable[None]], *, - process_request: Callable[[ServerConnection, Request], Response | None] | None = None, - process_response: Callable[[ServerConnection, Request, Response], Response | None] | None = None, server_header: str | None = _default_server_header(), open_timeout: float | None = 10, logger: LoggerLike | None = None, ) -> None: self.loop = asyncio.get_running_loop() self.handler = handler - self.process_request = process_request - self.process_response = process_response self.server_header = server_header self.open_timeout = open_timeout self.logger = _resolve_logger(logger if logger is not None else getLogger("websockets.server")) @@ -306,8 +210,6 @@ def __init__( subprotocols: Sequence[Subprotocol] | None = None, select_subprotocol: Callable[[ServerConnection, Sequence[Subprotocol]], Subprotocol | None] | None = None, compression: str | None = "deflate", - process_request: Callable[[ServerConnection, Request], Response | None] | None = None, - process_response: Callable[[ServerConnection, Request, Response], Response | None] | None = None, server_header: str | None = _default_server_header(), open_timeout: float | None = 10, ping_interval: float | None = 20, @@ -317,7 +219,6 @@ def __init__( max_queue: int | None | tuple[int | None, int | None] = 16, write_limit: int | tuple[int, int | None] = 32768, logger: LoggerLike | None = None, - create_connection: type[ServerConnection] | None = None, **kwargs: Any, ): self.handler = handler @@ -328,8 +229,6 @@ def __init__( self.subprotocols = subprotocols self.select_subprotocol = select_subprotocol self.compression = compression - self.process_request = process_request - self.process_response = process_response self.server_header = server_header self.open_timeout = open_timeout self.ping_interval = ping_interval @@ -339,7 +238,10 @@ def __init__( self.max_queue = max_queue self.write_limit = write_limit self.logger = logger - self.connection_factory = create_connection or ServerConnection + if "create_connection" in kwargs: + raise NotImplementedError("create_connection isn't supported by picows.websockets server yet") + self.process_request = kwargs.pop("process_request", None) + self.process_response = kwargs.pop("process_response", None) self.kwargs = kwargs self._server: Server | None = None @@ -361,15 +263,18 @@ async def _create(self) -> Server: raise NotImplementedError("custom server extensions aren't supported by picows.websockets") if self.compression not in (None, "deflate"): raise NotImplementedError("only compression=None or 'deflate' are accepted") - if self.process_request is not None and inspect.iscoroutinefunction(self.process_request): - raise NotImplementedError("async process_request isn't supported by picows core yet") - if self.process_response is not None and inspect.iscoroutinefunction(self.process_response): - raise NotImplementedError("async process_response isn't supported by picows core yet") + unsupported = [] + if self.process_request is not None: + unsupported.append("process_request") + if self.process_response is not None: + unsupported.append("process_response") + if unsupported: + raise NotImplementedError( + f"{', '.join(unsupported)} isn't supported by picows.websockets server yet" + ) server = Server( self.handler, - process_request=self.process_request, - process_response=self.process_response, server_header=self.server_header, open_timeout=self.open_timeout, logger=self.logger, @@ -382,8 +287,41 @@ def listener_factory( upgrade_request: picows.WSUpgradeRequest, ) -> picows.WSUpgradeResponseWithListener: request = Request.from_picows(upgrade_request) - connection = self.connection_factory( + origin = request.headers.get("Origin") + if origin is not None and not isinstance(origin, str): + raise InvalidOrigin(None) + if not _origin_allowed(origin, self.origins): + return picows.WSUpgradeResponseWithListener( + picows.WSUpgradeResponse.create_error_response( + http.HTTPStatus.FORBIDDEN, + b"Origin not allowed\n", + {"Content-Type": "text/plain; charset=utf-8"}, + ), + None, + ) + + if server.close_task is not None: + return picows.WSUpgradeResponseWithListener( + picows.WSUpgradeResponse.create_error_response( + http.HTTPStatus.SERVICE_UNAVAILABLE, + b"Server is shutting down.\n", + {"Content-Type": "text/plain; charset=utf-8"}, + ), + None, + ) + + headers = {} + if self.server_header is not None: + headers["Server"] = self.server_header + connection = ServerConnection( server, + request=request, + response=Response( + status_code=int(http.HTTPStatus.SWITCHING_PROTOCOLS), + reason_phrase=http.HTTPStatus.SWITCHING_PROTOCOLS.phrase, + headers=type(request.headers)({}), + body=b"", + ), ping_interval=self.ping_interval, ping_timeout=self.ping_timeout, close_timeout=self.close_timeout, @@ -393,51 +331,20 @@ def listener_factory( logger=self.logger, compression=self.compression, ) - stash_server_request(connection, request) - - response: Response | None = None - - origin = request.headers.get("Origin") - if origin is not None and not isinstance(origin, str): - raise InvalidOrigin(None) - if not _origin_allowed(origin, self.origins): - response = connection.respond(http.HTTPStatus.FORBIDDEN, "Origin not allowed\n") - - if self.process_request is not None and response is None: - response = self.process_request(connection, request) - - if response is None: - if server.close_task is not None: - response = connection.respond(http.HTTPStatus.SERVICE_UNAVAILABLE, "Server is shutting down.\n") - else: - headers = {} - if self.server_header is not None: - headers["Server"] = self.server_header - subprotocol = _select_subprotocol(connection, request, self.subprotocols, self.select_subprotocol) - if subprotocol is not None: - headers["Sec-WebSocket-Protocol"] = subprotocol - if self.compression == "deflate" and _supports_permessage_deflate(request): - headers["Sec-WebSocket-Extensions"] = _PERMESSAGE_DEFLATE_REQUEST - response = Response( - status_code=int(http.HTTPStatus.SWITCHING_PROTOCOLS), - reason_phrase=http.HTTPStatus.SWITCHING_PROTOCOLS.phrase, - headers=type(request.headers)(headers), - body=b"", - ) - - if self.process_response is not None: - updated = self.process_response(connection, request, response) - if updated is not None: - response = updated - - assert response is not None - stash_server_response(connection, response) - listener = connection if response.status_code == int(http.HTTPStatus.SWITCHING_PROTOCOLS) else None - if listener is not None: - raw_response = picows.WSUpgradeResponse.create_101_response(response.headers) - else: - raw_response = response.to_picows() - return picows.WSUpgradeResponseWithListener(raw_response, listener) + subprotocol = _select_subprotocol(connection, request, self.subprotocols, self.select_subprotocol) + if subprotocol is not None: + headers["Sec-WebSocket-Protocol"] = subprotocol + if self.compression == "deflate" and _supports_permessage_deflate(request): + headers["Sec-WebSocket-Extensions"] = _PERMESSAGE_DEFLATE_REQUEST + response = Response( + status_code=int(http.HTTPStatus.SWITCHING_PROTOCOLS), + reason_phrase=http.HTTPStatus.SWITCHING_PROTOCOLS.phrase, + headers=type(request.headers)(headers), + body=b"", + ) + connection._initial_response = response + raw_response = picows.WSUpgradeResponse.create_101_response(headers) + return picows.WSUpgradeResponseWithListener(raw_response, connection) raw_server = await picows.ws_create_server( listener_factory, diff --git a/tests/test_websockets_compat.py b/tests/test_websockets_compat.py index 010ee03..2c1cb05 100644 --- a/tests/test_websockets_compat.py +++ b/tests/test_websockets_compat.py @@ -70,46 +70,6 @@ def listener_factory(request): assert request_headers["value"] == "chat" -async def test_send_waits_for_resume_writing(): - class TrackingConnection(websockets.ClientConnection): - def __init__(self, **kwargs): - super().__init__(**kwargs) - self.pause_event = asyncio.Event() - - def pause_writing(self) -> None: - super().pause_writing() - self.pause_event.set() - - async with WSServer() as server: - async with websockets.connect( - server.url, - compression=None, - create_connection=TrackingConnection, - ) as ws: - third_requested = asyncio.Event() - allow_resume = asyncio.Event() - - async def fragments(): - ws.pause_writing() - yield b"first" - yield b"second" - third_requested.set() - yield b"third" - - async def resume_later(): - await allow_resume.wait() - ws.resume_writing() - - asyncio.create_task(resume_later()) - - send_task = asyncio.create_task(ws.send(fragments())) - await asyncio.wait_for(ws.pause_event.wait(), 1.0) - await asyncio.sleep(0) - assert not third_requested.is_set() - - allow_resume.set() - await asyncio.wait_for(send_task, 1.0) - assert third_requested.is_set() - - reply = await ws.recv() - assert reply == b"firstsecondthird" +def test_connect_rejects_create_connection(): + with pytest.raises(NotImplementedError): + websockets.connect("ws://example.com", create_connection=websockets.ClientConnection) diff --git a/tests/test_websockets_server_compat.py b/tests/test_websockets_server_compat.py index a48752d..48de100 100644 --- a/tests/test_websockets_server_compat.py +++ b/tests/test_websockets_server_compat.py @@ -1,6 +1,4 @@ import asyncio -import base64 -import http import re import pytest @@ -20,40 +18,46 @@ async def handler(ws: websockets.ServerConnection) -> None: assert await ws.recv() == "hello" -async def test_serve_process_request_can_reject_handshake(): +async def test_serve_rejects_process_request(): async def handler(ws: websockets.ServerConnection) -> None: raise AssertionError("handler must not be called") - def process_request(ws: websockets.ServerConnection, request: websockets.Request) -> websockets.Response: - return ws.respond(http.HTTPStatus.FORBIDDEN, "nope") - - async with websockets.serve( - handler, "127.0.0.1", 0, compression=None, process_request=process_request - ) as server: - port = server.sockets[0].getsockname()[1] - with pytest.raises(websockets.InvalidStatus): - async with websockets.connect(f"ws://127.0.0.1:{port}/", compression=None): - pass + with pytest.raises(NotImplementedError): + await websockets.serve( + handler, + "127.0.0.1", + 0, + compression=None, + process_request=lambda ws, request: None, + ) -async def test_serve_process_response_can_mutate_headers(): +async def test_serve_rejects_process_response(): async def handler(ws: websockets.ServerConnection) -> None: - await ws.wait_closed() + raise AssertionError("handler must not be called") - def process_response( - ws: websockets.ServerConnection, - request: websockets.Request, - response: websockets.Response, - ) -> websockets.Response | None: - response.headers["X-Test"] = "1" - return response + with pytest.raises(NotImplementedError): + await websockets.serve( + handler, + "127.0.0.1", + 0, + compression=None, + process_response=lambda ws, request, response: response, + ) - async with websockets.serve( - handler, "127.0.0.1", 0, compression=None, process_response=process_response - ) as server: - port = server.sockets[0].getsockname()[1] - async with websockets.connect(f"ws://127.0.0.1:{port}/", compression=None) as ws: - assert ws.response.headers["X-Test"] == "1" + +async def test_serve_rejects_create_connection(): + async def handler(ws: websockets.ServerConnection) -> None: + raise AssertionError("handler must not be called") + + with pytest.raises(NotImplementedError): + await websockets.serve( + handler, + "127.0.0.1", + 0, + compression=None, + create_connection=websockets.ServerConnection, + ) async def test_serve_accepts_allowed_origin(): @@ -97,52 +101,9 @@ async def handler(ws: websockets.ServerConnection) -> None: pass -async def test_basic_auth_accepts_valid_credentials_and_sets_username(): - seen_usernames: list[str] = [] - - async def handler(ws: websockets.ServerConnection) -> None: - seen_usernames.append(ws.username) - await ws.send(ws.username) - - async with websockets.serve( - handler, - "127.0.0.1", - 0, - compression=None, - process_request=websockets.basic_auth( - realm="test", - credentials=("hello", "secret"), - ), - ) as server: - port = server.sockets[0].getsockname()[1] - token = base64.b64encode(b"hello:secret").decode() - async with websockets.connect( - f"ws://127.0.0.1:{port}/", - compression=None, - additional_headers={"Authorization": f"Basic {token}"}, - ) as ws: - assert await ws.recv() == "hello" - assert seen_usernames == ["hello"] - - -async def test_basic_auth_rejects_missing_credentials(): - async def handler(ws: websockets.ServerConnection) -> None: - raise AssertionError("handler must not be called") - - async with websockets.serve( - handler, - "127.0.0.1", - 0, - compression=None, - process_request=websockets.basic_auth( - realm="test", - credentials=("hello", "secret"), - ), - ) as server: - port = server.sockets[0].getsockname()[1] - with pytest.raises(websockets.InvalidStatus): - async with websockets.connect(f"ws://127.0.0.1:{port}/", compression=None): - pass +def test_basic_auth_is_not_supported_yet(): + with pytest.raises(NotImplementedError): + websockets.basic_auth(realm="test", credentials=("hello", "secret")) async def test_serve_negotiates_subprotocol(): @@ -255,5 +216,5 @@ async def handler(ws: websockets.ServerConnection) -> None: def test_route_requires_werkzeug(): - with pytest.raises(ImportError): + with pytest.raises((ImportError, NotImplementedError)): websockets.route(None) # type: ignore[arg-type] From 08135c9fd1406f548cc29b4fe92d69c19764f691 Mon Sep 17 00:00:00 2001 From: taras Date: Thu, 7 May 2026 05:25:50 +0200 Subject: [PATCH 46/57] WSTransport.request was missing for the server side. --- HISTORY.rst | 1 + picows/picows.pyx | 1 + 2 files changed, 2 insertions(+) diff --git a/HISTORY.rst b/HISTORY.rst index fa2a75c..588acfc 100644 --- a/HISTORY.rst +++ b/HISTORY.rst @@ -19,6 +19,7 @@ picows Release History * Add is_disconnected property to WSTransport. * Fix send_* methods raising exceptions when attempting to send after connection abort and without prior CLOSE frame. * Add missing body attribute in WSUpgradeResponse at the client side. +* WSTransport.request was missing for the server side. 1.19.0 (2026-04-24) ------------------ diff --git a/picows/picows.pyx b/picows/picows.pyx index 5a8baf3..fba28f2 100644 --- a/picows/picows.pyx +++ b/picows/picows.pyx @@ -1216,6 +1216,7 @@ cdef class WSProtocol(WSProtocolBase, asyncio.BufferedProtocol): listener_factory = self._listener_factory self._listener_factory = None try: + self.transport.request = upgrade_request listener_or_response_with_listener = listener_factory(upgrade_request) if isinstance(listener_or_response_with_listener, WSUpgradeResponseWithListener): self.listener = listener_or_response_with_listener.listener From f4178af25f92d3699f31af2da6989a86ed0e8e62 Mon Sep 17 00:00:00 2001 From: taras Date: Thu, 7 May 2026 05:26:22 +0200 Subject: [PATCH 47/57] Simplify logic --- picows/websockets/asyncio/connection.py | 14 ++------------ picows/websockets/asyncio/server.py | 16 +--------------- tests/test_websockets_server_compat.py | 2 ++ 3 files changed, 5 insertions(+), 27 deletions(-) diff --git a/picows/websockets/asyncio/connection.py b/picows/websockets/asyncio/connection.py index 17b567c..da00161 100644 --- a/picows/websockets/asyncio/connection.py +++ b/picows/websockets/asyncio/connection.py @@ -1102,14 +1102,10 @@ def broadcast_message(connection: ConnectionBase, msg_type: WSMsgType, message: @cython.cclass class ServerConnection(ConnectionBase): server: Any - _initial_request: Optional[Request] - _initial_response: Optional[Response] def __init__( self, server: Any, - request: Request, - response: Response, *, ping_interval: Optional[float] = 20, ping_timeout: Optional[float] = 20, @@ -1132,18 +1128,12 @@ def __init__( compression=compression, ) self.server = server - self._initial_request = request - self._initial_response = response @cython.ccall def on_ws_connected(self, transport: WSTransport) -> None: self.transport = transport - assert self._initial_request is not None - assert self._initial_response is not None - self._request = self._initial_request - self._response = self._initial_response - self._initial_request = None - self._initial_response = None + self._request = Request.from_picows(transport.request) + self._response = Response.from_picows(transport.response) try: self._subprotocol = _resolve_subprotocol(None, self._response) self._configure_extensions() diff --git a/picows/websockets/asyncio/server.py b/picows/websockets/asyncio/server.py index 755943f..068513d 100644 --- a/picows/websockets/asyncio/server.py +++ b/picows/websockets/asyncio/server.py @@ -17,7 +17,7 @@ _resolve_logger, broadcast_message, ) -from ..compat import Request, Response, State +from ..compat import Request, State from ..exceptions import ConcurrencyError, InvalidHandshake, InvalidOrigin from ..typing import DataLike, LoggerLike, Origin, Subprotocol @@ -315,13 +315,6 @@ def listener_factory( headers["Server"] = self.server_header connection = ServerConnection( server, - request=request, - response=Response( - status_code=int(http.HTTPStatus.SWITCHING_PROTOCOLS), - reason_phrase=http.HTTPStatus.SWITCHING_PROTOCOLS.phrase, - headers=type(request.headers)({}), - body=b"", - ), ping_interval=self.ping_interval, ping_timeout=self.ping_timeout, close_timeout=self.close_timeout, @@ -336,13 +329,6 @@ def listener_factory( headers["Sec-WebSocket-Protocol"] = subprotocol if self.compression == "deflate" and _supports_permessage_deflate(request): headers["Sec-WebSocket-Extensions"] = _PERMESSAGE_DEFLATE_REQUEST - response = Response( - status_code=int(http.HTTPStatus.SWITCHING_PROTOCOLS), - reason_phrase=http.HTTPStatus.SWITCHING_PROTOCOLS.phrase, - headers=type(request.headers)(headers), - body=b"", - ) - connection._initial_response = response raw_response = picows.WSUpgradeResponse.create_101_response(headers) return picows.WSUpgradeResponseWithListener(raw_response, connection) diff --git a/tests/test_websockets_server_compat.py b/tests/test_websockets_server_compat.py index 48de100..4e1f58a 100644 --- a/tests/test_websockets_server_compat.py +++ b/tests/test_websockets_server_compat.py @@ -8,6 +8,8 @@ async def test_serve_echo_roundtrip(): async def handler(ws: websockets.ServerConnection) -> None: + assert ws.request.path == "/" + assert ws.response.status_code == 101 message = await ws.recv() await ws.send(message) From 8c7a56e24e252780dd5718f7bb449272d035a484 Mon Sep 17 00:00:00 2001 From: taras Date: Thu, 7 May 2026 05:31:39 +0200 Subject: [PATCH 48/57] Clenaup --- picows/websockets/asyncio/connection.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/picows/websockets/asyncio/connection.py b/picows/websockets/asyncio/connection.py index da00161..3a89361 100644 --- a/picows/websockets/asyncio/connection.py +++ b/picows/websockets/asyncio/connection.py @@ -154,18 +154,6 @@ def from_response_header(cls, header_value: str) -> _PerMessageDeflate: return self - @classmethod - def enabled(cls) -> _PerMessageDeflate: - self: _PerMessageDeflate = _PerMessageDeflate.__new__(_PerMessageDeflate) - self.remote_no_context_takeover = False - self.local_no_context_takeover = False - self.remote_max_window_bits = -15 - self.local_max_window_bits = -15 - self._decoder = _zlib_decompressobj(wbits=self.remote_max_window_bits) - self._encoder = _zlib_compressobj(wbits=self.local_max_window_bits) - self._decode_cont_data = False - return self - @cython.cfunc @cython.inline def decode_frame(self, frame: WSFrame, max_length: cython.Py_ssize_t) -> bytes: From e8e1a65d901c93030ea34fe155e920a8d0cc0a12 Mon Sep 17 00:00:00 2001 From: taras Date: Thu, 7 May 2026 06:12:48 +0200 Subject: [PATCH 49/57] ws_connect can now accept listener_factory that takes WSUpgradeRequest, WSUpgradeResponse as arguments. Old argument-less client_factory also works. --- HISTORY.rst | 5 +++-- README.md | 6 +++++- picows/api.py | 19 ++++++++++++++----- picows/picows.pyx | 8 +++++++- tests/test_ws_logic.py | 21 +++++++++++++++++++++ 5 files changed, 50 insertions(+), 9 deletions(-) diff --git a/HISTORY.rst b/HISTORY.rst index 588acfc..3456051 100644 --- a/HISTORY.rst +++ b/HISTORY.rst @@ -18,8 +18,9 @@ picows Release History * User on_ws_connect and on_ws_frame implementation can now signalize protocol errors by raising WSProtocolError * Add is_disconnected property to WSTransport. * Fix send_* methods raising exceptions when attempting to send after connection abort and without prior CLOSE frame. -* Add missing body attribute in WSUpgradeResponse at the client side. -* WSTransport.request was missing for the server side. +* Add missing WSUpgradeResponse.body attribute at the client side. +* Add missing WSTransport.request attribute for the server side. +* ws_connect can now accept listener_factory that takes WSUpgradeRequest, WSUpgradeResponse as arguments. Old argument-less client_factory also works. 1.19.0 (2026-04-24) ------------------ diff --git a/README.md b/README.md index b72ab04..f930836 100644 --- a/README.md +++ b/README.md @@ -88,6 +88,11 @@ This prints: Echo reply: Hello world ``` +`ws_connect()` accepts either a zero-argument client listener factory such as +`ClientListener`, or a two-argument factory receiving +`(request: WSUpgradeRequest, response: WSUpgradeResponse)` when the caller +needs access to the negotiated handshake metadata before `on_ws_connected()`. + ### Echo server ```python @@ -198,4 +203,3 @@ pytest -s -v --cov=picows --cov-report=html pip install -r docs/requirements.txt make -C docs clean html ``` - diff --git a/picows/api.py b/picows/api.py index f5cd14e..6d20c76 100644 --- a/picows/api.py +++ b/picows/api.py @@ -10,14 +10,17 @@ from python_socks.async_.asyncio import Proxy -from .types import (WSHeadersLike, WSUpgradeRequest, WSHost, WSPort, +from .types import (WSHeadersLike, WSUpgradeRequest, WSUpgradeResponse, WSHost, WSPort, WSUpgradeResponseWithListener, WSHandshakeError) from .picows import (WSListener, WSTransport, WSAutoPingStrategy, # type: ignore [attr-defined] WSProtocol) from .url import parse_url, WSInvalidURL, WSParsedURL -WSListenerFactory = Callable[[], WSListener] +WSListenerFactory = Union[ + Callable[[], WSListener], + Callable[[WSUpgradeRequest, WSUpgradeResponse], WSListener], +] WSServerListenerFactory = Callable[[WSUpgradeRequest], Union[WSListener, WSUpgradeResponseWithListener, None]] WSSocketFactory = Callable[[WSParsedURL], Union[Optional[socket.socket], Awaitable[Optional[socket.socket]]]] @@ -215,8 +218,10 @@ async def ws_connect(ws_listener_factory: WSListenerFactory, # type: ignore [no- `asyncio.loop.create_connection `_ :param ws_listener_factory: - A parameterless factory function that returns a user handler. - User handler has to derive from :any:`WSListener`. + A factory function that returns a user handler. + The factory may either accept no arguments, or accept the negotiated + :any:`WSUpgradeRequest` and :any:`WSUpgradeResponse`. + The returned handler has to derive from :any:`WSListener`. :param url: Destination URL :param ssl_context: optional SSLContext to override default one when the wss scheme is used @@ -281,7 +286,11 @@ async def ws_connect(ws_listener_factory: WSListenerFactory, # type: ignore [no- instead of ``loop.create_server``, ``loop.create_connection`` native method. **picows** will use **aiofastnet** by default if it is installed. You can override default behavior by using this argument. - :return: :any:`WSTransport` object and a user handler returned by `ws_listener_factory()` + :return: + :any:`WSTransport` object and a user handler returned by + `ws_listener_factory()`, or by + `ws_listener_factory(request, response)` when using the two-argument + client factory form. """ assert "ssl" not in kwargs, "explicit 'ssl' argument for loop.create_connection is not supported" diff --git a/picows/picows.pyx b/picows/picows.pyx index fba28f2..7331d84 100644 --- a/picows/picows.pyx +++ b/picows/picows.pyx @@ -1190,7 +1190,13 @@ cdef class WSProtocol(WSProtocolBase, asyncio.BufferedProtocol): if self._state == WSParserState.WAIT_UPGRADE_RESPONSE: # Upgrade response hasn't fully arrived yet return False - self.listener = self._listener_factory() + try: + self.listener = self._listener_factory(self.transport.request, response) + except TypeError as ex: + try: + self.listener = self._listener_factory() + except TypeError: + raise ex self.transport.listener_proxy = weakref.proxy(self.listener) self.transport.response = response self._listener_factory = None diff --git a/tests/test_ws_logic.py b/tests/test_ws_logic.py index 8e764c1..ddb5d4e 100644 --- a/tests/test_ws_logic.py +++ b/tests/test_ws_logic.py @@ -7,6 +7,7 @@ from contextlib import asynccontextmanager from hashlib import sha1 from http import HTTPStatus +from typing import Optional import async_timeout import pytest @@ -183,6 +184,26 @@ def listener_factory(request: picows.WSUpgradeRequest): assert client.transport.response.status == HTTPStatus.SWITCHING_PROTOCOLS +async def test_client_factory_with_2_args(): + request: Optional[picows.WSUpgradeRequest] = None + response: Optional[picows.WSUpgradeResponse] = None + + def listener_factory(client_request, server_response): + nonlocal request, response + request = client_request + response = server_response + return AsyncClient() + + async with WSServer() as server: + async with WSClient(server, listener_factory) as client: + assert request.method == b"GET" + assert response.status == HTTPStatus.SWITCHING_PROTOCOLS + + client.transport.send(picows.WSMsgType.BINARY, b"test") + frame = await client.get_message() + assert frame.payload_as_bytes == b"test" + + async def test_route_not_found(): def exc_check(exc): return exc.response.status == 404 From 5b52bda558cb9b1b081b8c0a194cd6bdbb100c5b Mon Sep 17 00:00:00 2001 From: taras Date: Thu, 7 May 2026 06:29:56 +0200 Subject: [PATCH 50/57] Simplify --- picows/websockets/__init__.py | 3 +- picows/websockets/asyncio/__init__.py | 3 +- picows/websockets/asyncio/client.py | 20 +++-- picows/websockets/asyncio/connection.py | 102 ++++++----------------- picows/websockets/asyncio/negotiation.py | 36 ++++++++ picows/websockets/asyncio/server.py | 45 +++++++--- tests/test_websockets_server_compat.py | 39 +++++++++ 7 files changed, 153 insertions(+), 95 deletions(-) create mode 100644 picows/websockets/asyncio/negotiation.py diff --git a/picows/websockets/__init__.py b/picows/websockets/__init__.py index e97f1b8..9cc5646 100644 --- a/picows/websockets/__init__.py +++ b/picows/websockets/__init__.py @@ -2,7 +2,7 @@ from .asyncio.client import connect from .asyncio.connection import ClientConnection, ServerConnection, process_exception from .asyncio.router import route -from .asyncio.server import Server, basic_auth, broadcast, serve +from .asyncio.server import Server, ServerHandshakeConnection, basic_auth, broadcast, serve from .compat import CloseCode, Request, Response, State from .exceptions import ( ConcurrencyError, @@ -52,6 +52,7 @@ "Data", "DataLike", "Server", + "ServerHandshakeConnection", "ServerConnection", "ConcurrencyError", "ConnectionClosed", diff --git a/picows/websockets/asyncio/__init__.py b/picows/websockets/asyncio/__init__.py index da25ffb..21c4a21 100644 --- a/picows/websockets/asyncio/__init__.py +++ b/picows/websockets/asyncio/__init__.py @@ -1,12 +1,13 @@ from .client import connect from .connection import ClientConnection, ServerConnection, process_exception from .router import route -from .server import Server, basic_auth, broadcast, serve +from .server import Server, ServerHandshakeConnection, basic_auth, broadcast, serve from ..compat import State __all__ = [ "ClientConnection", "Server", + "ServerHandshakeConnection", "ServerConnection", "basic_auth", "broadcast", diff --git a/picows/websockets/asyncio/client.py b/picows/websockets/asyncio/client.py index 6b0ffc2..68b6d43 100644 --- a/picows/websockets/asyncio/client.py +++ b/picows/websockets/asyncio/client.py @@ -15,6 +15,8 @@ ClientConnection, process_exception, ) +from .negotiation import configure_permessage_deflate, resolve_subprotocol +from ..compat import Request, Response from ..exceptions import ( InvalidHandshake, InvalidHeader, @@ -187,8 +189,19 @@ async def connect_override(_: Any) -> socket.socket: socket_factory = connect_override - def listener_factory() -> ClientConnection: + def listener_factory( + request: picows.WSUpgradeRequest, + response: picows.WSUpgradeResponse, + ) -> ClientConnection: + wrapped_request = Request.from_picows(request) + wrapped_response = Response.from_picows(response) + subprotocol = resolve_subprotocol(self.subprotocols, wrapped_response) + permessage_deflate = configure_permessage_deflate(wrapped_response, self.compression) return ClientConnection( + request=wrapped_request, + response=wrapped_response, + subprotocol=subprotocol, + permessage_deflate=permessage_deflate, ping_interval=self.ping_interval, ping_timeout=self.ping_timeout, close_timeout=self.close_timeout, @@ -196,8 +209,6 @@ def listener_factory() -> ClientConnection: write_limit=self.write_limit, max_message_size=max_message_size, logger=self.logger, - subprotocols=self.subprotocols, - compression=self.compression, ) try: @@ -227,10 +238,7 @@ def listener_factory() -> ClientConnection: raise InvalidMessage(str(exc)) from exc except picows.WSHandshakeError as exc: raise InvalidHandshake(str(exc)) from exc - assert isinstance(listener, ClientConnection) - if listener.connect_exception is not None: - raise listener.connect_exception return listener def _build_headers(self) -> list[tuple[str, str]]: diff --git a/picows/websockets/asyncio/connection.py b/picows/websockets/asyncio/connection.py index 3a89361..64d7547 100644 --- a/picows/websockets/asyncio/connection.py +++ b/picows/websockets/asyncio/connection.py @@ -230,25 +230,6 @@ def _coerce_close_code(code: CloseCode) -> Optional[int]: def _coerce_close_reason(reason: Optional[str]) -> Optional[str]: return reason if reason is not None else None - -@cython.cfunc -@cython.inline -def _resolve_subprotocol( - subprotocols: Optional[Sequence[Subprotocol]], - response: Any, -) -> Optional[Subprotocol]: - if response is None: - return None - value = response.headers.get("Sec-WebSocket-Protocol") - if value is None: - return None - if not isinstance(value, str): - raise InvalidHandshake("server returned non-string subprotocol") - if subprotocols is not None and value not in subprotocols: - raise InvalidHandshake(f"unsupported subprotocol negotiated by server: {value}") - return value - - @cython.cfunc @cython.inline def _normalize_watermarks( @@ -295,10 +276,7 @@ class ConnectionBase(WSListener): # type: ignore[misc] transport: WSTransport _request: Request _response: Response - _connect_exception: Optional[Exception] - _subprotocols: Optional[Sequence[Subprotocol]] _subprotocol: Optional[Subprotocol] - _compression: Optional[str] _permessage_deflate: Optional[_PerMessageDeflate] _loop: asyncio.AbstractEventLoop @@ -334,6 +312,10 @@ class ConnectionBase(WSListener): # type: ignore[misc] def __init__( self, *, + request: Request, + response: Response, + subprotocol: Optional[Subprotocol], + permessage_deflate: Optional[_PerMessageDeflate], ping_interval: Optional[float] = 20, ping_timeout: Optional[float] = 20, close_timeout: Optional[float] = 10, @@ -341,19 +323,14 @@ def __init__( write_limit: Union[int, tuple[int, Optional[int]]] = 32768, max_message_size: Optional[int] = 1024 * 1024, logger: LoggerLike = None, - subprotocols: Optional[Sequence[Subprotocol]] = None, - compression: Optional[str] = None, ): self.id = uuid.uuid4() self.logger = _resolve_logger(logger) self.transport = cython.cast(WSTransport, None) - self._request = None # type: ignore[assignment] - self._response = None # type: ignore[assignment] - self._connect_exception = None - self._subprotocols = subprotocols - self._subprotocol = None - self._compression = compression - self._permessage_deflate = None + self._request = request + self._response = response + self._subprotocol = subprotocol + self._permessage_deflate = permessage_deflate self._loop = asyncio.get_running_loop() self._send_in_progress = False @@ -490,19 +467,6 @@ def _set_write_limits(self, write_limit: Union[int, tuple[int, Optional[int]]]) high, low = write_limit, None self.transport.underlying_transport.set_write_buffer_limits(high=high, low=low) - @cython.cfunc - @cython.inline - def _configure_extensions(self) -> None: - header_value = self._response.headers.get("Sec-WebSocket-Extensions") - if header_value is None: - return - if self._compression != "deflate": - raise InvalidHandshake("unexpected websocket extensions negotiated by server") - if not isinstance(header_value, str): - raise InvalidHandshake("invalid Sec-WebSocket-Extensions header") - - self._permessage_deflate = _PerMessageDeflate.from_response_header(header_value) - @cython.cfunc @cython.inline def _process_pong_frame(self, frame: WSFrame) -> None: @@ -992,10 +956,6 @@ def request(self) -> Request: def response(self) -> Response: return self._response - @property - def connect_exception(self) -> Optional[Exception]: - return self._connect_exception - @property def local_address(self) -> Any: return self.transport.underlying_transport.get_extra_info("sockname") @@ -1040,6 +1000,10 @@ class ClientConnection(ConnectionBase): def __init__( self, *, + request: Request, + response: Response, + subprotocol: Optional[Subprotocol], + permessage_deflate: Optional[_PerMessageDeflate], ping_interval: Optional[float] = 20, ping_timeout: Optional[float] = 20, close_timeout: Optional[float] = 10, @@ -1047,10 +1011,12 @@ def __init__( write_limit: Union[int, tuple[int, Optional[int]]] = 32768, max_message_size: Optional[int] = 1024 * 1024, logger: LoggerLike = None, - subprotocols: Optional[Sequence[Subprotocol]] = None, - compression: Optional[str] = None, ): super().__init__( + request=request, + response=response, + subprotocol=subprotocol, + permessage_deflate=permessage_deflate, ping_interval=ping_interval, ping_timeout=ping_timeout, close_timeout=close_timeout, @@ -1058,23 +1024,11 @@ def __init__( write_limit=write_limit, max_message_size=max_message_size, logger=logger, - subprotocols=subprotocols, - compression=compression, ) @cython.ccall def on_ws_connected(self, transport: WSTransport) -> None: self.transport = transport - self._request = Request.from_picows(transport.request) - self._response = Response.from_picows(transport.response) - try: - self._subprotocol = _resolve_subprotocol(self._subprotocols, self._response) - self._configure_extensions() - except InvalidHandshake as exc: - self._connect_exception = exc - self.transport.send_close(WSCloseCode.PROTOCOL_ERROR, str(exc)) - self.transport.disconnect(False) - return self._set_write_limits(self._write_limit) if self._ping_interval is not None and self._keepalive_task is None: self._keepalive_task = asyncio.create_task(self._keepalive_loop()) @@ -1095,6 +1049,10 @@ def __init__( self, server: Any, *, + request: Request, + response: Response, + subprotocol: Optional[Subprotocol], + permessage_deflate: Optional[_PerMessageDeflate], ping_interval: Optional[float] = 20, ping_timeout: Optional[float] = 20, close_timeout: Optional[float] = 10, @@ -1102,9 +1060,12 @@ def __init__( write_limit: Union[int, tuple[int, Optional[int]]] = 32768, max_message_size: Optional[int] = 1024 * 1024, logger: LoggerLike = None, - compression: Optional[str] = None, ): super().__init__( + request=request, + response=response, + subprotocol=subprotocol, + permessage_deflate=permessage_deflate, ping_interval=ping_interval, ping_timeout=ping_timeout, close_timeout=close_timeout, @@ -1112,25 +1073,14 @@ def __init__( write_limit=write_limit, max_message_size=max_message_size, logger=logger, - subprotocols=None, - compression=compression, ) self.server = server @cython.ccall def on_ws_connected(self, transport: WSTransport) -> None: self.transport = transport - self._request = Request.from_picows(transport.request) - self._response = Response.from_picows(transport.response) - try: - self._subprotocol = _resolve_subprotocol(None, self._response) - self._configure_extensions() - except InvalidHandshake as exc: - self._connect_exception = exc - self.transport.send_close(WSCloseCode.PROTOCOL_ERROR, str(exc)) - self.transport.disconnect(False) - return self._set_write_limits(self._write_limit) if self._ping_interval is not None and self._keepalive_task is None: self._keepalive_task = asyncio.create_task(self._keepalive_loop()) - self.server.loop.call_soon(self.server.start_connection_handler, self) + self.server.loop. + (self.server.start_connection_handler, self) diff --git a/picows/websockets/asyncio/negotiation.py b/picows/websockets/asyncio/negotiation.py new file mode 100644 index 0000000..e83f5bd --- /dev/null +++ b/picows/websockets/asyncio/negotiation.py @@ -0,0 +1,36 @@ +from __future__ import annotations + +from typing import Optional, Sequence + +from ..compat import Response +from ..exceptions import InvalidHandshake +from ..typing import Subprotocol +from .connection import _PerMessageDeflate + + +def resolve_subprotocol( + subprotocols: Optional[Sequence[Subprotocol]], + response: Response, +) -> Optional[Subprotocol]: + value = response.headers.get("Sec-WebSocket-Protocol") + if value is None: + return None + if not isinstance(value, str): + raise InvalidHandshake("server returned non-string subprotocol") + if subprotocols is not None and value not in subprotocols: + raise InvalidHandshake(f"unsupported subprotocol negotiated by server: {value}") + return value + + +def configure_permessage_deflate( + response: Response, + compression: Optional[str], +) -> Optional[_PerMessageDeflate]: + header_value = response.headers.get("Sec-WebSocket-Extensions") + if header_value is None: + return None + if compression != "deflate": + raise InvalidHandshake("unexpected websocket extensions negotiated by server") + if not isinstance(header_value, str): + raise InvalidHandshake("invalid Sec-WebSocket-Extensions header") + return _PerMessageDeflate.from_response_header(header_value) diff --git a/picows/websockets/asyncio/server.py b/picows/websockets/asyncio/server.py index 068513d..5357474 100644 --- a/picows/websockets/asyncio/server.py +++ b/picows/websockets/asyncio/server.py @@ -5,6 +5,7 @@ import re import socket import sys +from dataclasses import dataclass from collections.abc import Awaitable, Callable, Iterable from logging import getLogger from typing import Any, Optional, Pattern, Sequence @@ -17,12 +18,14 @@ _resolve_logger, broadcast_message, ) -from ..compat import Request, State +from .negotiation import configure_permessage_deflate +from ..compat import Request, Response, State from ..exceptions import ConcurrencyError, InvalidHandshake, InvalidOrigin from ..typing import DataLike, LoggerLike, Origin, Subprotocol __all__ = [ "ServerConnection", + "ServerHandshakeConnection", "Server", "serve", "broadcast", @@ -61,10 +64,10 @@ def _origin_allowed( def _select_subprotocol( - connection: ServerConnection, + connection: ServerHandshakeConnection, request: Request, subprotocols: Optional[Sequence[Subprotocol]], - select_subprotocol: Optional[Callable[[ServerConnection, Sequence[Subprotocol]], Subprotocol | None]], + select_subprotocol: Optional[Callable[[ServerHandshakeConnection, Sequence[Subprotocol]], Subprotocol | None]], ) -> Optional[Subprotocol]: header_value = request.headers.get("Sec-WebSocket-Protocol") if header_value is None: @@ -89,6 +92,15 @@ def basic_auth(*args: Any, **kwargs: Any) -> Any: raise NotImplementedError("basic_auth() requires unsupported server process_request hooks") +@dataclass(slots=True) +class ServerHandshakeConnection: + request: Request + + @property + def state(self) -> State: + return State.CONNECTING + + class Server: def __init__( self, @@ -208,7 +220,7 @@ def __init__( origins: Sequence[Origin | Pattern[str] | None] | None = None, extensions: Sequence[Any] | None = None, subprotocols: Sequence[Subprotocol] | None = None, - select_subprotocol: Callable[[ServerConnection, Sequence[Subprotocol]], Subprotocol | None] | None = None, + select_subprotocol: Callable[[ServerHandshakeConnection, Sequence[Subprotocol]], Subprotocol | None] | None = None, compression: str | None = "deflate", server_header: str | None = _default_server_header(), open_timeout: float | None = 10, @@ -313,8 +325,26 @@ def listener_factory( headers = {} if self.server_header is not None: headers["Server"] = self.server_header + handshake_connection = ServerHandshakeConnection(request) + subprotocol = _select_subprotocol( + handshake_connection, + request, + self.subprotocols, + self.select_subprotocol, + ) + if subprotocol is not None: + headers["Sec-WebSocket-Protocol"] = subprotocol + if self.compression == "deflate" and _supports_permessage_deflate(request): + headers["Sec-WebSocket-Extensions"] = _PERMESSAGE_DEFLATE_REQUEST + raw_response = picows.WSUpgradeResponse.create_101_response(headers) + response = Response.from_picows(raw_response) + permessage_deflate = configure_permessage_deflate(response, self.compression) connection = ServerConnection( server, + request=request, + response=response, + subprotocol=subprotocol, + permessage_deflate=permessage_deflate, ping_interval=self.ping_interval, ping_timeout=self.ping_timeout, close_timeout=self.close_timeout, @@ -322,14 +352,7 @@ def listener_factory( write_limit=self.write_limit, max_message_size=max_message_size, logger=self.logger, - compression=self.compression, ) - subprotocol = _select_subprotocol(connection, request, self.subprotocols, self.select_subprotocol) - if subprotocol is not None: - headers["Sec-WebSocket-Protocol"] = subprotocol - if self.compression == "deflate" and _supports_permessage_deflate(request): - headers["Sec-WebSocket-Extensions"] = _PERMESSAGE_DEFLATE_REQUEST - raw_response = picows.WSUpgradeResponse.create_101_response(headers) return picows.WSUpgradeResponseWithListener(raw_response, connection) raw_server = await picows.ws_create_server( diff --git a/tests/test_websockets_server_compat.py b/tests/test_websockets_server_compat.py index 4e1f58a..4a16767 100644 --- a/tests/test_websockets_server_compat.py +++ b/tests/test_websockets_server_compat.py @@ -128,6 +128,45 @@ async def handler(ws: websockets.ServerConnection) -> None: assert ws.subprotocol == "chat" +async def test_select_subprotocol_receives_handshake_connection(): + seen = {} + + async def handler(ws: websockets.ServerConnection) -> None: + await ws.wait_closed() + + def select_subprotocol( + ws: websockets.ServerHandshakeConnection, + offered: list[str], + ) -> str | None: + seen["type"] = type(ws) + seen["path"] = ws.request.path + seen["state"] = ws.state + seen["has_recv"] = hasattr(ws, "recv") + if "chat" in offered: + return "chat" + return None + + async with websockets.serve( + handler, + "127.0.0.1", + 0, + compression=None, + select_subprotocol=select_subprotocol, + ) as server: + port = server.sockets[0].getsockname()[1] + async with websockets.connect( + f"ws://127.0.0.1:{port}/room", + compression=None, + subprotocols=["chat"], + ) as ws: + assert ws.subprotocol == "chat" + + assert seen["type"] is websockets.ServerHandshakeConnection + assert seen["path"] == "/room" + assert seen["state"] is websockets.State.CONNECTING + assert seen["has_recv"] is False + + async def test_broadcast_sends_to_open_connections(): connections: list[websockets.ServerConnection] = [] From a8ce4fae822e9ac52423a4e542164db66f519a31 Mon Sep 17 00:00:00 2001 From: taras Date: Thu, 7 May 2026 06:30:17 +0200 Subject: [PATCH 51/57] Fix --- picows/websockets/asyncio/connection.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/picows/websockets/asyncio/connection.py b/picows/websockets/asyncio/connection.py index 64d7547..a89df02 100644 --- a/picows/websockets/asyncio/connection.py +++ b/picows/websockets/asyncio/connection.py @@ -1082,5 +1082,4 @@ def on_ws_connected(self, transport: WSTransport) -> None: self._set_write_limits(self._write_limit) if self._ping_interval is not None and self._keepalive_task is None: self._keepalive_task = asyncio.create_task(self._keepalive_loop()) - self.server.loop. - (self.server.start_connection_handler, self) + self.server.loop.call_soon(self.server.start_connection_handler, self) From bfdb6cf85c608c25a2248eeb4c2ca0fc7fa27732 Mon Sep 17 00:00:00 2001 From: taras Date: Thu, 7 May 2026 06:32:52 +0200 Subject: [PATCH 52/57] Clenaup --- picows/websockets/asyncio/connection.py | 4 ---- picows/websockets/asyncio/server.py | 5 ++++- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/picows/websockets/asyncio/connection.py b/picows/websockets/asyncio/connection.py index a89df02..db68de0 100644 --- a/picows/websockets/asyncio/connection.py +++ b/picows/websockets/asyncio/connection.py @@ -265,10 +265,6 @@ def process_exception(exc: Exception) -> Optional[Exception]: return exc -def _default_server_header() -> str: - return f"Python/{sys.version_info.major}.{sys.version_info.minor} picows-websockets/0" - - @cython.cclass class ConnectionBase(WSListener): # type: ignore[misc] id: uuid.UUID diff --git a/picows/websockets/asyncio/server.py b/picows/websockets/asyncio/server.py index 5357474..1ee4928 100644 --- a/picows/websockets/asyncio/server.py +++ b/picows/websockets/asyncio/server.py @@ -14,7 +14,6 @@ from .connection import ( ServerConnection, - _default_server_header, _resolve_logger, broadcast_message, ) @@ -36,6 +35,10 @@ _PERMESSAGE_DEFLATE_REQUEST = "permessage-deflate" +def _default_server_header() -> str: + return f"Python/{sys.version_info.major}.{sys.version_info.minor} picows-websockets/0" + + def _header_items(headers: Any) -> list[tuple[str, str]]: return [] if headers is None else list(headers.items()) From 8252c4085f3b18f15c27f3769a285b27327c2a31 Mon Sep 17 00:00:00 2001 From: taras Date: Thu, 7 May 2026 06:44:42 +0200 Subject: [PATCH 53/57] Re-introduce process_request, process_response on the server side --- picows/websockets/asyncio/connection.py | 7 + picows/websockets/asyncio/server.py | 260 ++++++++++++++++++------ tests/test_websockets_server_compat.py | 119 +++++++++-- 3 files changed, 305 insertions(+), 81 deletions(-) diff --git a/picows/websockets/asyncio/connection.py b/picows/websockets/asyncio/connection.py index db68de0..6b0e939 100644 --- a/picows/websockets/asyncio/connection.py +++ b/picows/websockets/asyncio/connection.py @@ -1040,6 +1040,7 @@ def broadcast_message(connection: ConnectionBase, msg_type: WSMsgType, message: @cython.cclass class ServerConnection(ConnectionBase): server: Any + _username: Optional[str] def __init__( self, @@ -1049,6 +1050,7 @@ def __init__( response: Response, subprotocol: Optional[Subprotocol], permessage_deflate: Optional[_PerMessageDeflate], + username: Optional[str] = None, ping_interval: Optional[float] = 20, ping_timeout: Optional[float] = 20, close_timeout: Optional[float] = 10, @@ -1071,6 +1073,7 @@ def __init__( logger=logger, ) self.server = server + self._username = username @cython.ccall def on_ws_connected(self, transport: WSTransport) -> None: @@ -1079,3 +1082,7 @@ def on_ws_connected(self, transport: WSTransport) -> None: if self._ping_interval is not None and self._keepalive_task is None: self._keepalive_task = asyncio.create_task(self._keepalive_loop()) self.server.loop.call_soon(self.server.start_connection_handler, self) + + @property + def username(self) -> Optional[str]: + return self._username diff --git a/picows/websockets/asyncio/server.py b/picows/websockets/asyncio/server.py index 1ee4928..5c86aa4 100644 --- a/picows/websockets/asyncio/server.py +++ b/picows/websockets/asyncio/server.py @@ -1,16 +1,20 @@ from __future__ import annotations import asyncio +import binascii +import hmac import http -import re import socket import sys +from base64 import b64decode from dataclasses import dataclass from collections.abc import Awaitable, Callable, Iterable +from inspect import isawaitable from logging import getLogger from typing import Any, Optional, Pattern, Sequence import picows +from multidict import CIMultiDict from .connection import ( ServerConnection, @@ -66,6 +70,47 @@ def _origin_allowed( return False +def _ensure_sync_result(value: Any, hook_name: str) -> Any: + if isawaitable(value): + close = getattr(value, "close", None) + if close is not None: + close() + raise NotImplementedError(f"async {hook_name} hooks aren't supported by picows.websockets server yet") + return value + + +def _make_error_response( + status: http.HTTPStatus, + body: bytes, +) -> Response: + return Response( + status_code=int(status), + reason_phrase=status.phrase, + headers=CIMultiDict({"Content-Type": "text/plain; charset=utf-8"}), + body=body, + ) + + +def _basic_auth_unauthorized_response(message: bytes, realm: str) -> Response: + response = _make_error_response(http.HTTPStatus.UNAUTHORIZED, message) + response.headers["WWW-Authenticate"] = f'Basic realm="{realm}"' + return response + + +def _parse_basic_authorization(header_value: str) -> tuple[str, str]: + scheme, _, token = header_value.partition(" ") + if scheme.lower() != "basic" or not token: + raise ValueError("unsupported authorization scheme") + try: + decoded = b64decode(token, validate=True).decode("utf-8") + except (ValueError, UnicodeDecodeError, binascii.Error) as exc: + raise ValueError("invalid basic authorization header") from exc + username, separator, password = decoded.partition(":") + if not separator: + raise ValueError("invalid basic authorization header") + return username, password + + def _select_subprotocol( connection: ServerHandshakeConnection, request: Request, @@ -91,13 +136,99 @@ def _select_subprotocol( return None -def basic_auth(*args: Any, **kwargs: Any) -> Any: - raise NotImplementedError("basic_auth() requires unsupported server process_request hooks") +def _resolve_response_subprotocol( + request: Request, + response: Response, +) -> Optional[Subprotocol]: + value = response.headers.get("Sec-WebSocket-Protocol") + if value is None: + return None + if not isinstance(value, str): + raise InvalidHandshake("invalid Sec-WebSocket-Protocol header") + header_value = request.headers.get("Sec-WebSocket-Protocol") + if header_value is None: + raise InvalidHandshake("server negotiated a subprotocol without a client offer") + offered = [item.strip() for item in header_value.split(",") if item.strip()] + if value not in offered: + raise InvalidHandshake(f"selected subprotocol isn't offered by client: {value}") + return value + + +def basic_auth( + realm: str = "", + credentials: tuple[str, str] | Iterable[tuple[str, str]] | None = None, + check_credentials: Callable[[str, str], bool] | None = None, +) -> Callable[[ServerHandshakeConnection, Request], Response | None]: + if (credentials is None) == (check_credentials is None): + raise ValueError("provide either credentials or check_credentials") + + if credentials is not None: + if ( + isinstance(credentials, tuple) + and len(credentials) == 2 + and all(isinstance(item, str) for item in credentials) + ): + username = credentials[0] + password = credentials[1] + assert isinstance(username, str) + assert isinstance(password, str) + credentials_dict: dict[str, str] = {username: password} + elif isinstance(credentials, Iterable): + credentials_list: list[tuple[str, str]] = [] + for item in credentials: + if ( + not isinstance(item, tuple) + or len(item) != 2 + or not isinstance(item[0], str) + or not isinstance(item[1], str) + ): + raise TypeError(f"invalid credentials argument: {credentials}") + credentials_list.append(item) + credentials_dict = dict(credentials_list) + else: + raise TypeError(f"invalid credentials argument: {credentials}") + + def check_credentials(username: str, password: str) -> bool: + expected_password: str | None = credentials_dict.get(username) + return ( + expected_password is not None + and hmac.compare_digest(expected_password, password) + ) + + assert check_credentials is not None + + def process_request( + connection: ServerHandshakeConnection, + request: Request, + ) -> Response | None: + authorization = request.headers.get("Authorization") + if authorization is None: + return _basic_auth_unauthorized_response(b"Missing credentials\n", realm) + + try: + username, password = _parse_basic_authorization(authorization) + except ValueError: + return _basic_auth_unauthorized_response(b"Unsupported credentials\n", realm) + + valid_credentials = check_credentials(username, password) + if isawaitable(valid_credentials): + close = getattr(valid_credentials, "close", None) + if close is not None: + close() + raise NotImplementedError("async basic_auth credential checks aren't supported yet") + if not valid_credentials: + return _basic_auth_unauthorized_response(b"Invalid credentials\n", realm) + + connection.username = username + return None + + return process_request @dataclass(slots=True) class ServerHandshakeConnection: request: Request + username: Optional[str] = None @property def state(self) -> State: @@ -278,15 +409,6 @@ async def _create(self) -> Server: raise NotImplementedError("custom server extensions aren't supported by picows.websockets") if self.compression not in (None, "deflate"): raise NotImplementedError("only compression=None or 'deflate' are accepted") - unsupported = [] - if self.process_request is not None: - unsupported.append("process_request") - if self.process_response is not None: - unsupported.append("process_response") - if unsupported: - raise NotImplementedError( - f"{', '.join(unsupported)} isn't supported by picows.websockets server yet" - ) server = Server( self.handler, @@ -302,61 +424,77 @@ def listener_factory( upgrade_request: picows.WSUpgradeRequest, ) -> picows.WSUpgradeResponseWithListener: request = Request.from_picows(upgrade_request) + handshake_connection = ServerHandshakeConnection(request) + response: Response + origin = request.headers.get("Origin") if origin is not None and not isinstance(origin, str): raise InvalidOrigin(None) if not _origin_allowed(origin, self.origins): - return picows.WSUpgradeResponseWithListener( - picows.WSUpgradeResponse.create_error_response( - http.HTTPStatus.FORBIDDEN, - b"Origin not allowed\n", - {"Content-Type": "text/plain; charset=utf-8"}, - ), - None, + response = _make_error_response( + http.HTTPStatus.FORBIDDEN, + b"Origin not allowed\n", ) - - if server.close_task is not None: - return picows.WSUpgradeResponseWithListener( - picows.WSUpgradeResponse.create_error_response( - http.HTTPStatus.SERVICE_UNAVAILABLE, - b"Server is shutting down.\n", - {"Content-Type": "text/plain; charset=utf-8"}, - ), - None, + elif server.close_task is not None: + response = _make_error_response( + http.HTTPStatus.SERVICE_UNAVAILABLE, + b"Server is shutting down.\n", ) - - headers = {} - if self.server_header is not None: - headers["Server"] = self.server_header - handshake_connection = ServerHandshakeConnection(request) - subprotocol = _select_subprotocol( - handshake_connection, - request, - self.subprotocols, - self.select_subprotocol, - ) - if subprotocol is not None: - headers["Sec-WebSocket-Protocol"] = subprotocol - if self.compression == "deflate" and _supports_permessage_deflate(request): - headers["Sec-WebSocket-Extensions"] = _PERMESSAGE_DEFLATE_REQUEST - raw_response = picows.WSUpgradeResponse.create_101_response(headers) - response = Response.from_picows(raw_response) - permessage_deflate = configure_permessage_deflate(response, self.compression) - connection = ServerConnection( - server, - request=request, - response=response, - subprotocol=subprotocol, - permessage_deflate=permessage_deflate, - ping_interval=self.ping_interval, - ping_timeout=self.ping_timeout, - close_timeout=self.close_timeout, - max_queue=self.max_queue, - write_limit=self.write_limit, - max_message_size=max_message_size, - logger=self.logger, - ) - return picows.WSUpgradeResponseWithListener(raw_response, connection) + else: + headers = {} + if self.server_header is not None: + headers["Server"] = self.server_header + subprotocol = _select_subprotocol( + handshake_connection, + request, + self.subprotocols, + self.select_subprotocol, + ) + if subprotocol is not None: + headers["Sec-WebSocket-Protocol"] = subprotocol + if self.compression == "deflate" and _supports_permessage_deflate(request): + headers["Sec-WebSocket-Extensions"] = _PERMESSAGE_DEFLATE_REQUEST + response = Response.from_picows(picows.WSUpgradeResponse.create_101_response(headers)) + + if self.process_request is not None: + response_or_none = _ensure_sync_result( + self.process_request(handshake_connection, request), + "process_request", + ) + if response_or_none is not None: + if not isinstance(response_or_none, Response): + raise TypeError("process_request must return a Response or None") + response = response_or_none + + response = _ensure_sync_result( + self.process_response(handshake_connection, request, response), + "process_response", + ) if self.process_response is not None else response + + if not isinstance(response, Response): + raise TypeError("process_response must return a Response") + + if response.status_code == int(http.HTTPStatus.SWITCHING_PROTOCOLS): + subprotocol = _resolve_response_subprotocol(request, response) + permessage_deflate = configure_permessage_deflate(response, self.compression) + connection = ServerConnection( + server, + request=request, + response=response, + subprotocol=subprotocol, + permessage_deflate=permessage_deflate, + username=handshake_connection.username, + ping_interval=self.ping_interval, + ping_timeout=self.ping_timeout, + close_timeout=self.close_timeout, + max_queue=self.max_queue, + write_limit=self.write_limit, + max_message_size=max_message_size, + logger=self.logger, + ) + return picows.WSUpgradeResponseWithListener(response.to_picows(), connection) + else: + return picows.WSUpgradeResponseWithListener(response.to_picows(), None) raw_server = await picows.ws_create_server( listener_factory, diff --git a/tests/test_websockets_server_compat.py b/tests/test_websockets_server_compat.py index 4a16767..ed1756f 100644 --- a/tests/test_websockets_server_compat.py +++ b/tests/test_websockets_server_compat.py @@ -1,7 +1,9 @@ import asyncio +import base64 import re import pytest +from multidict import CIMultiDict from picows import websockets @@ -20,32 +22,82 @@ async def handler(ws: websockets.ServerConnection) -> None: assert await ws.recv() == "hello" -async def test_serve_rejects_process_request(): +async def test_serve_process_request_can_reject_handshake(): async def handler(ws: websockets.ServerConnection) -> None: raise AssertionError("handler must not be called") - with pytest.raises(NotImplementedError): - await websockets.serve( - handler, - "127.0.0.1", - 0, - compression=None, - process_request=lambda ws, request: None, + def process_request( + ws: websockets.ServerHandshakeConnection, + request: websockets.Request, + ) -> websockets.Response | None: + assert ws.request is request + return websockets.Response( + status_code=418, + reason_phrase="I'm a Teapot", + headers=CIMultiDict({"X-Test": "yes"}), + body=b"nope", ) + async with websockets.serve( + handler, + "127.0.0.1", + 0, + compression=None, + process_request=process_request, + ) as server: + port = server.sockets[0].getsockname()[1] + with pytest.raises(websockets.InvalidStatus) as exc_info: + async with websockets.connect(f"ws://127.0.0.1:{port}/", compression=None): + pass + assert int(exc_info.value.response.status) == 418 + + +async def test_serve_process_response_can_mutate_handshake_response(): + async def handler(ws: websockets.ServerConnection) -> None: + await ws.wait_closed() -async def test_serve_rejects_process_response(): + def process_response( + ws: websockets.ServerHandshakeConnection, + request: websockets.Request, + response: websockets.Response, + ) -> websockets.Response: + assert ws.request is request + response.headers["X-Handshake"] = "yes" + return response + + async with websockets.serve( + handler, + "127.0.0.1", + 0, + compression=None, + process_response=process_response, + ) as server: + port = server.sockets[0].getsockname()[1] + async with websockets.connect(f"ws://127.0.0.1:{port}/", compression=None) as ws: + assert ws.response.headers["X-Handshake"] == "yes" + + +async def test_serve_rejects_async_process_request(): async def handler(ws: websockets.ServerConnection) -> None: raise AssertionError("handler must not be called") - with pytest.raises(NotImplementedError): - await websockets.serve( - handler, - "127.0.0.1", - 0, - compression=None, - process_response=lambda ws, request, response: response, - ) + async def process_request( + ws: websockets.ServerHandshakeConnection, + request: websockets.Request, + ) -> websockets.Response | None: + return None + + async with websockets.serve( + handler, + "127.0.0.1", + 0, + compression=None, + process_request=process_request, + ) as server: + port = server.sockets[0].getsockname()[1] + with pytest.raises(websockets.InvalidStatus): + async with websockets.connect(f"ws://127.0.0.1:{port}/", compression=None): + pass async def test_serve_rejects_create_connection(): @@ -103,9 +155,36 @@ async def handler(ws: websockets.ServerConnection) -> None: pass -def test_basic_auth_is_not_supported_yet(): - with pytest.raises(NotImplementedError): - websockets.basic_auth(realm="test", credentials=("hello", "secret")) +async def test_basic_auth_rejects_missing_credentials_and_sets_username(): + async def handler(ws: websockets.ServerConnection) -> None: + assert ws.username == "hello" + await ws.send(ws.username) + + async with websockets.serve( + handler, + "127.0.0.1", + 0, + compression=None, + process_request=websockets.basic_auth( + realm="test", + credentials=("hello", "secret"), + ), + ) as server: + port = server.sockets[0].getsockname()[1] + + with pytest.raises(websockets.InvalidStatus) as exc_info: + async with websockets.connect(f"ws://127.0.0.1:{port}/", compression=None): + pass + assert int(exc_info.value.response.status) == 401 + assert exc_info.value.response.headers["WWW-Authenticate"] == 'Basic realm="test"' + + token = base64.b64encode(b"hello:secret").decode() + async with websockets.connect( + f"ws://127.0.0.1:{port}/", + compression=None, + additional_headers={"Authorization": f"Basic {token}"}, + ) as ws: + assert await ws.recv() == "hello" async def test_serve_negotiates_subprotocol(): From 67898d8bcb67f22386872a39f17005861d48ef9b Mon Sep 17 00:00:00 2001 From: taras Date: Fri, 8 May 2026 18:31:29 +0200 Subject: [PATCH 54/57] Cleanups --- picows/websockets/asyncio/client.py | 4 - picows/websockets/asyncio/connection.py | 6 +- picows/websockets/asyncio/server.py | 10 - tests/test_websockets_unit_coverage.py | 618 ++++++++++++++++++++++++ 4 files changed, 623 insertions(+), 15 deletions(-) create mode 100644 tests/test_websockets_unit_coverage.py diff --git a/picows/websockets/asyncio/client.py b/picows/websockets/asyncio/client.py index 68b6d43..e35afba 100644 --- a/picows/websockets/asyncio/client.py +++ b/picows/websockets/asyncio/client.py @@ -60,10 +60,6 @@ def _process_proxy(proxy: Union[str, bool, None], secure: bool) -> Optional[str] raise InvalidProxy(str(proxy), "proxy must be None, True, or a proxy URL") -def _normalize_size_limit(limit: Optional[int]) -> int: - return 0 if limit is None else limit - - class _Connect: def __init__( self, diff --git a/picows/websockets/asyncio/connection.py b/picows/websockets/asyncio/connection.py index 6b0e939..d080d38 100644 --- a/picows/websockets/asyncio/connection.py +++ b/picows/websockets/asyncio/connection.py @@ -425,7 +425,11 @@ def on_ws_frame(self, transport: WSTransport, frame: WSFrame) -> None: else: payload = frame.get_payload_as_bytes() - self._incoming_message_size = len(payload) + if frame.msg_type == WSMsgType.CONTINUATION: + self._incoming_message_size += len(payload) + else: + self._incoming_message_size = len(payload) + if self._max_message_size > 0 and self._incoming_message_size > self._max_message_size: raise WSProtocolError(WSCloseCode.MESSAGE_TOO_BIG, "message too big") diff --git a/picows/websockets/asyncio/server.py b/picows/websockets/asyncio/server.py index 5c86aa4..6728f18 100644 --- a/picows/websockets/asyncio/server.py +++ b/picows/websockets/asyncio/server.py @@ -43,10 +43,6 @@ def _default_server_header() -> str: return f"Python/{sys.version_info.major}.{sys.version_info.minor} picows-websockets/0" -def _header_items(headers: Any) -> list[tuple[str, str]]: - return [] if headers is None else list(headers.items()) - - def _supports_permessage_deflate(request: Request) -> bool: value = request.headers.get("Sec-WebSocket-Extensions") return isinstance(value, str) and "permessage-deflate" in value @@ -240,14 +236,10 @@ def __init__( self, handler: Callable[[ServerConnection], Awaitable[None]], *, - server_header: str | None = _default_server_header(), - open_timeout: float | None = 10, logger: LoggerLike | None = None, ) -> None: self.loop = asyncio.get_running_loop() self.handler = handler - self.server_header = server_header - self.open_timeout = open_timeout self.logger = _resolve_logger(logger if logger is not None else getLogger("websockets.server")) self.handlers: dict[ServerConnection, asyncio.Task[None]] = {} self.close_task: asyncio.Task[None] | None = None @@ -412,8 +404,6 @@ async def _create(self) -> Server: server = Server( self.handler, - server_header=self.server_header, - open_timeout=self.open_timeout, logger=self.logger, ) diff --git a/tests/test_websockets_unit_coverage.py b/tests/test_websockets_unit_coverage.py new file mode 100644 index 0000000..1a2a2e1 --- /dev/null +++ b/tests/test_websockets_unit_coverage.py @@ -0,0 +1,618 @@ +from __future__ import annotations + +import asyncio +import socket +import sys +from dataclasses import dataclass +from http import HTTPStatus + +import pytest +from multidict import CIMultiDict + +import picows +from picows import websockets +from picows.websockets.asyncio.client import _process_proxy +from picows.websockets.asyncio.connection import ( + _PerMessageDeflate, + _normalize_watermarks, + _resolve_logger, + process_exception, +) +from picows.websockets.asyncio.negotiation import configure_permessage_deflate, resolve_subprotocol +from picows.websockets.asyncio.server import _parse_basic_authorization +from tests.utils import ServerEchoListener, WSServer + + +@dataclass +class Close: + code: int + reason: str + + +def test_connection_closed_string_variants(): + assert str(websockets.ConnectionClosed(None, None)) == "no close frame received or sent" + assert str(websockets.ConnectionClosed(None, Close(1000, "bye"))) == "sent 1000 (bye)" + assert str(websockets.ConnectionClosed(Close(1001, "away"), None)) == "received 1001 (away)" + assert str(websockets.ConnectionClosed(Close(1001, "away"), Close(1000, "bye"), True)) == ( + "received then sent close frames: received 1001 (away), sent 1000 (bye)" + ) + assert str(websockets.ConnectionClosed(Close(1001, "away"), Close(1000, "bye"), False)) == ( + "sent then received close frames: received 1001 (away), sent 1000 (bye)" + ) + + +def test_exception_attributes_and_strings(): + assert str(websockets.InvalidURI("http://example.com", "wrong scheme")) == ( + "http://example.com isn't a valid WebSocket URI: wrong scheme" + ) + assert str(websockets.InvalidProxy("ftp://proxy", "wrong scheme")) == ( + "ftp://proxy isn't a valid proxy: wrong scheme" + ) + assert str(websockets.InvalidProxyStatus(object())) == "proxy rejected connection" + assert str(websockets.InvalidProxyStatus(websockets.Response(502, "Bad Gateway", CIMultiDict(), b""))) == ( + "proxy rejected connection: HTTP 502" + ) + + invalid_header = websockets.InvalidHeader("X-Test", "bad") + assert invalid_header.name == "X-Test" + assert invalid_header.value == "bad" + assert websockets.InvalidOrigin("https://bad.example").name == "Origin" + assert websockets.InvalidHeaderFormat("X-Test", "bad syntax", "x:y", 1).value == "bad syntax at 1 in x:y" + + assert str(websockets.DuplicateParameter("server_max_window_bits")) == ( + "duplicate parameter: server_max_window_bits" + ) + assert str(websockets.InvalidParameterName("x")) == "invalid parameter name: x" + assert str(websockets.InvalidParameterValue("x", None)) == "missing value for parameter x" + assert str(websockets.InvalidParameterValue("x", "")) == "empty value for parameter x" + assert str(websockets.InvalidParameterValue("x", "bad")) == "invalid value for parameter x: bad" + + +def test_client_private_option_helpers(): + assert _process_proxy(None, False) is None + assert _process_proxy("http://127.0.0.1:8080", False) == "http://127.0.0.1:8080" + with pytest.raises(websockets.InvalidProxy): + _process_proxy(123, False) # type: ignore[arg-type] + + assert _normalize_watermarks(None) == (0, 0) + assert _normalize_watermarks((None, 1)) == (0, 0) + assert _normalize_watermarks((8, None)) == (8, 2) + assert _resolve_logger("picows.test").name == "picows.test" + + +async def test_client_connection_starts_in_connecting_state(): + connection = websockets.ClientConnection( + request=websockets.Request("/", CIMultiDict()), + response=websockets.Response(101, "Switching Protocols", CIMultiDict(), b""), + subprotocol=None, + permessage_deflate=None, + ) + + assert connection.state is websockets.State.CONNECTING + + +async def test_connect_await_style_and_socket_options(): + async with WSServer() as server: + ws = await websockets.connect(server.url, compression=None, ping_interval=None) + try: + await ws.send("awaited") + assert await ws.recv() == "awaited" + finally: + await ws.close() + + sock = socket.create_connection((server.host, server.port)) + ws = await websockets.connect(server.url, compression=None, ping_interval=None, sock=sock) + try: + await ws.send(b"sock") + assert await ws.recv() == b"sock" + finally: + await ws.close() + + async with websockets.connect( + "ws://example.invalid/", + compression=None, + ping_interval=None, + proxy=None, + host=server.host, + port=server.port, + ) as ws: + await ws.send("override") + assert await ws.recv() == "override" + + +async def test_connect_rejects_conflicting_and_invalid_socket_options(): + async with WSServer() as server: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + try: + with pytest.raises(TypeError, match="cannot pass both sock and socket_factory"): + await websockets.connect(server.url, compression=None, sock=sock, socket_factory=lambda _: None) + finally: + sock.close() + + with pytest.raises(TypeError, match="sock must be a socket.socket instance"): + await websockets.connect(server.url, compression=None, sock=object()) + + with pytest.raises(TypeError, match="cannot pass both host/port override and socket_factory"): + await websockets.connect( + server.url, + compression=None, + host=server.host, + socket_factory=lambda _: None, + ) + + +async def test_connect_rejects_invalid_ssl_options_before_network(): + with pytest.raises(NotImplementedError, match="ssl=False"): + await websockets.connect("wss://example.com/", compression=None, ssl=False) + with pytest.raises(TypeError, match="ssl must be"): + await websockets.connect("wss://example.com/", compression=None, ssl=object()) + + +def test_process_exception_retries_transient_failures(): + assert process_exception(EOFError()) is None + assert process_exception(OSError()) is None + assert process_exception(asyncio.TimeoutError()) is None + + response = websockets.Response(503, "Service Unavailable", CIMultiDict(), b"") + assert process_exception(websockets.InvalidStatus(response)) is None + + error = RuntimeError("boom") + assert process_exception(error) is error + + +def test_negotiation_rejects_invalid_subprotocol_and_extension_headers(): + response = websockets.Response(101, "Switching Protocols", CIMultiDict(), b"") + + response.headers["Sec-WebSocket-Protocol"] = 123 # type: ignore[assignment] + with pytest.raises(websockets.InvalidHandshake, match="non-string subprotocol"): + resolve_subprotocol(["chat"], response) + + response.headers["Sec-WebSocket-Protocol"] = "other" + with pytest.raises(websockets.InvalidHandshake, match="unsupported subprotocol"): + resolve_subprotocol(["chat"], response) + + response.headers.clear() + response.headers["Sec-WebSocket-Extensions"] = "permessage-deflate" + with pytest.raises(websockets.InvalidHandshake, match="unexpected websocket extensions"): + configure_permessage_deflate(response, None) + + response.headers["Sec-WebSocket-Extensions"] = 123 # type: ignore[assignment] + with pytest.raises(websockets.InvalidHandshake, match="invalid Sec-WebSocket-Extensions"): + configure_permessage_deflate(response, "deflate") + + +def test_permessage_deflate_rejects_invalid_parameters(): + response = websockets.Response(101, "Switching Protocols", CIMultiDict(), b"") + invalid_headers = [ + "x-webkit-deflate-frame", + "permessage-deflate, permessage-deflate", + "permessage-deflate; server_no_context_takeover=true", + "permessage-deflate; client_no_context_takeover=true", + "permessage-deflate; server_max_window_bits", + "permessage-deflate; server_max_window_bits=7", + "permessage-deflate; client_max_window_bits", + "permessage-deflate; client_max_window_bits=16", + "permessage-deflate; unknown=value", + "permessage-deflate; server_max_window_bits=15; server_max_window_bits=15", + ] + + for header in invalid_headers: + response.headers["Sec-WebSocket-Extensions"] = header + with pytest.raises(websockets.InvalidHandshake): + configure_permessage_deflate(response, "deflate") + + +def test_permessage_deflate_accepts_no_context_takeover_parameters(): + permessage_deflate = _PerMessageDeflate.from_response_header( + "permessage-deflate; server_no_context_takeover; client_no_context_takeover" + ) + + first = permessage_deflate.encode_frame(picows.WSMsgType.TEXT, "hello", True) + second = permessage_deflate.encode_frame(picows.WSMsgType.TEXT, b"hello", True) + + assert isinstance(first, memoryview) + assert isinstance(second, memoryview) + assert bytes(first) + assert bytes(second) + + +class Frame: + def __init__( + self, + msg_type: picows.WSMsgType, + payload: bytes, + *, + fin: bool = True, + rsv1: bool = False, + ): + self.msg_type = msg_type + self.payload = payload + self.fin = fin + self.rsv1 = rsv1 + + def get_payload_as_bytes(self) -> bytes: + return self.payload + + def get_payload_as_memoryview(self) -> memoryview: + return memoryview(self.payload) + + +def test_permessage_deflate_decode_passthrough_and_protocol_error_branches(): + permessage_deflate = _PerMessageDeflate.from_response_header( + "permessage-deflate; server_no_context_takeover; client_no_context_takeover" + ) + + assert permessage_deflate.decode_frame( + Frame(picows.WSMsgType.TEXT, b"plain", rsv1=False), + 0, + ) == b"plain" + assert permessage_deflate.decode_frame( + Frame(picows.WSMsgType.CONTINUATION, b"continuation", rsv1=False), + 0, + ) == b"continuation" + + encoded = bytes(permessage_deflate.encode_frame(picows.WSMsgType.TEXT, b"compressed", True)) + assert permessage_deflate.decode_frame( + Frame(picows.WSMsgType.TEXT, encoded, rsv1=True), + 100, + ) == b"compressed" + + with pytest.raises(picows.WSProtocolError, match="unexpected rsv1"): + permessage_deflate.decode_frame( + Frame(picows.WSMsgType.CONTINUATION, b"bad", rsv1=True), + 0, + ) + + +def test_basic_auth_argument_validation_and_malformed_headers(): + with pytest.raises(ValueError, match="provide either credentials or check_credentials"): + websockets.basic_auth() + with pytest.raises(ValueError, match="provide either credentials or check_credentials"): + websockets.basic_auth(credentials=("a", "b"), check_credentials=lambda _u, _p: True) + with pytest.raises(TypeError, match="invalid credentials argument"): + websockets.basic_auth(credentials=("a", "b", "c")) # type: ignore[arg-type] + with pytest.raises(TypeError, match="invalid credentials argument"): + websockets.basic_auth(credentials=[("a", object())]) # type: ignore[list-item] + + with pytest.raises(ValueError, match="unsupported authorization scheme"): + _parse_basic_authorization("Bearer token") + with pytest.raises(ValueError, match="invalid basic authorization header"): + _parse_basic_authorization("Basic !!!") + with pytest.raises(ValueError, match="invalid basic authorization header"): + _parse_basic_authorization("Basic bm9jb2xvbg==") + + +async def test_basic_auth_rejects_bad_and_async_credentials(): + async def handler(_ws: websockets.ServerConnection) -> None: + raise AssertionError("handler must not be called") + + async def check_credentials(_username: str, _password: str) -> bool: + return True + + async with websockets.serve( + handler, + "127.0.0.1", + 0, + compression=None, + process_request=websockets.basic_auth(check_credentials=check_credentials), + ) as server: + port = server.sockets[0].getsockname()[1] + token = "Basic " + "aW52YWxpZDpjcmVkZW50aWFscw==" + with pytest.raises(websockets.InvalidStatus): + async with websockets.connect( + f"ws://127.0.0.1:{port}/", + compression=None, + additional_headers={"Authorization": token}, + ): + pass + + +async def test_connect_async_iterator_retries_then_succeeds(): + attempts = 0 + + def process_exception(exc: Exception) -> Exception | None: + nonlocal attempts + attempts += 1 + if attempts == 1: + return None + return exc + + connector = websockets.connect( + "ws://127.0.0.1:1/", + compression=None, + open_timeout=0.01, + process_exception=process_exception, + ) + connector._backoff = 0 + with pytest.raises(OSError): + async for _ws in connector: + pass + assert attempts == 2 + + +class ContinuationOnConnect(ServerEchoListener): + def on_ws_connected(self, transport: picows.WSTransport): + super().on_ws_connected(transport) + transport.send(picows.WSMsgType.CONTINUATION, b"bad", fin=True) + + +class Rsv1OnConnect(ServerEchoListener): + def on_ws_connected(self, transport: picows.WSTransport): + super().on_ws_connected(transport) + transport.send(picows.WSMsgType.TEXT, b"bad", fin=True, rsv1=True) + + +class BadContinuationSequenceOnConnect(ServerEchoListener): + def on_ws_connected(self, transport: picows.WSTransport): + super().on_ws_connected(transport) + transport.send(picows.WSMsgType.TEXT, b"first", fin=False) + transport.send(picows.WSMsgType.TEXT, b"second", fin=True) + + +class SendLargeTextOnConnect(ServerEchoListener): + def on_ws_connected(self, transport: picows.WSTransport): + super().on_ws_connected(transport) + transport.send(picows.WSMsgType.TEXT, b"large") + + +class DelayedFragmentedTextOnConnect(ServerEchoListener): + def on_ws_connected(self, transport: picows.WSTransport): + super().on_ws_connected(transport) + + async def send_fragments(): + transport.send(picows.WSMsgType.TEXT, b"he", fin=False) + await asyncio.sleep(0.01) + transport.send(picows.WSMsgType.CONTINUATION, b"llo", fin=True) + + asyncio.create_task(send_fragments()) + + +class FragmentedTextOnConnect(ServerEchoListener): + def on_ws_connected(self, transport: picows.WSTransport): + super().on_ws_connected(transport) + transport.send(picows.WSMsgType.TEXT, b"he", fin=False) + transport.send(picows.WSMsgType.CONTINUATION, b"llo", fin=True) + + +class IgnoreCloseListener(ServerEchoListener): + def on_ws_frame(self, transport: picows.WSTransport, frame: picows.WSFrame): + if frame.msg_type == picows.WSMsgType.CLOSE: + return + super().on_ws_frame(transport, frame) + + +async def test_recv_rejects_unexpected_continuation_and_rsv1(): + async with WSServer(lambda _: ContinuationOnConnect()) as server: + async with websockets.connect(server.url, compression=None, ping_interval=None) as ws: + with pytest.raises(websockets.ConnectionClosedError): + await ws.recv() + + async with WSServer(lambda _: Rsv1OnConnect()) as server: + async with websockets.connect(server.url, compression=None, ping_interval=None) as ws: + with pytest.raises(websockets.ConnectionClosedError): + await ws.recv() + + +async def test_recv_rejects_bad_continuation_and_too_large_message(): + async with WSServer(lambda _: BadContinuationSequenceOnConnect()) as server: + async with websockets.connect(server.url, compression=None, ping_interval=None) as ws: + with pytest.raises(websockets.ConnectionClosedError): + await ws.recv() + + async with WSServer(lambda _: SendLargeTextOnConnect()) as server: + async with websockets.connect(server.url, compression=None, ping_interval=None, max_size=2) as ws: + with pytest.raises(websockets.ConnectionClosedError): + await ws.recv() + + +async def test_recv_streaming_waits_for_later_fragment(): + async with WSServer(lambda _: DelayedFragmentedTextOnConnect()) as server: + async with websockets.connect(server.url, compression=None, ping_interval=None) as ws: + fragments = [] + async for fragment in ws.recv_streaming(): + fragments.append(fragment) + assert fragments == ["he", "llo"] + + +async def test_fragmented_message_exceeding_max_size_closes_connection(): + async with WSServer(lambda _: FragmentedTextOnConnect()) as server: + async with websockets.connect(server.url, compression=None, ping_interval=None, max_size=4) as ws: + with pytest.raises(websockets.ConnectionClosedError): + await ws.recv() + + +async def test_send_and_ping_validation_branches(): + async with WSServer() as server: + async with websockets.connect(server.url, compression=None, ping_interval=None) as ws: + assert ws.state is websockets.State.OPEN + assert ws.local_address[0] == "127.0.0.1" + assert ws.remote_address[0] == "127.0.0.1" + assert ws.latency == 0 + assert ws.subprotocol is None + assert ws.close_code is None + assert ws.close_reason is None + + default_waiter = await ws.ping() + await asyncio.wait_for(default_waiter, 1.0) + + waiter = await ws.ping("same") + with pytest.raises(websockets.ConcurrencyError, match="same data"): + await ws.ping(b"same") + await asyncio.wait_for(waiter, 1.0) + + with pytest.raises(TypeError, match="ping payload"): + await ws.ping(object()) # type: ignore[arg-type] + with pytest.raises(TypeError, match="unsupported type"): + await ws.send(object()) # type: ignore[arg-type] + with pytest.raises(TypeError, match="same category"): + await ws.send([b"bytes", "text"]) + + +async def test_send_text_overrides_and_concurrent_send_waits_for_turn(): + async with WSServer() as server: + async with websockets.connect(server.url, compression=None, ping_interval=None) as ws: + await ws.send("as-binary", text=False) + assert await ws.recv(decode=False) == b"as-binary" + + await ws.send(b"as-text", text=True) + assert await ws.recv() == "as-text" + + first_sent = asyncio.Event() + release = asyncio.Event() + + async def fragments(): + yield b"first" + first_sent.set() + await release.wait() + yield b"second" + + first_send = asyncio.create_task(ws.send(fragments())) + await asyncio.wait_for(first_sent.wait(), 1.0) + + second_send = asyncio.create_task(ws.send(b"after")) + await asyncio.sleep(0) + assert not second_send.done() + + release.set() + await asyncio.wait_for(first_send, 1.0) + await asyncio.wait_for(second_send, 1.0) + assert await ws.recv() == b"firstsecond" + assert await ws.recv() == b"after" + + +async def test_send_string_fragments_and_write_pause_wait_paths(): + async with WSServer() as server: + async with websockets.connect(server.url, compression=None, ping_interval=None) as ws: + await ws.send(["he", "llo"]) + assert await ws.recv() == "hello" + + ws.pause_writing() + single_send = asyncio.create_task(ws.send(b"paused")) + await asyncio.sleep(0) + assert not single_send.done() + ws.resume_writing() + await asyncio.wait_for(single_send, 1.0) + assert await ws.recv() == b"paused" + + ws.pause_writing() + fragmented_send = asyncio.create_task(ws.send([b"frag", b"mented"])) + await asyncio.sleep(0) + assert not fragmented_send.done() + ws.resume_writing() + await asyncio.wait_for(fragmented_send, 1.0) + assert await ws.recv() == b"fragmented" + + +async def test_send_rejects_invalid_first_fragment_and_closes(): + async with WSServer() as server: + async with websockets.connect(server.url, compression=None, ping_interval=None) as ws: + with pytest.raises(TypeError, match="message must contain"): + await ws.send([object()]) # type: ignore[list-item] + await asyncio.wait_for(ws.wait_closed(), 1.0) + assert ws.state is websockets.State.CLOSED + + +async def test_connection_context_manager_and_close_timeout_none(): + async with WSServer() as server: + ws = await websockets.connect(server.url, compression=None, ping_interval=None, close_timeout=None) + async with ws: + await ws.send("context") + assert await ws.recv() == "context" + assert ws.state is websockets.State.CLOSED + + +async def test_close_timeout_disconnects_when_peer_ignores_close(): + async with WSServer(lambda _: IgnoreCloseListener()) as server: + ws = await websockets.connect( + server.url, + compression=None, + ping_interval=None, + close_timeout=0.01, + ) + await ws.close() + assert ws.state is websockets.State.CLOSED + assert ws.close_code == 1000 + assert ws.close_reason == "" + + +async def test_keepalive_loop_without_ping_timeout_sends_default_pings(): + async with WSServer(enable_auto_pong=True) as server: + async with websockets.connect( + server.url, + compression=None, + ping_interval=0.01, + ping_timeout=None, + ) as ws: + await asyncio.sleep(0.03) + assert ws.latency >= 0 + + +async def test_keepalive_loop_with_ping_timeout_observes_pong(): + async with WSServer(enable_auto_pong=True) as server: + async with websockets.connect( + server.url, + compression=None, + ping_interval=0.01, + ping_timeout=1.0, + ) as ws: + await asyncio.sleep(0.03) + assert ws.latency >= 0 + + +async def test_disconnect_without_close_frame_sets_error_close_state(): + async with WSServer() as server: + async with websockets.connect(server.url, compression=None, ping_interval=None) as ws: + await ws.send("disconnect_me_without_close_frame") + with pytest.raises(websockets.ConnectionClosedError): + await ws.recv() + assert ws.close_code is None + assert ws.close_reason is None + + +async def test_send_ping_and_pong_after_close_raise_connection_closed(): + async with WSServer() as server: + ws = await websockets.connect(server.url, compression=None, ping_interval=None) + await ws.close() + + with pytest.raises(websockets.ConnectionClosedOK): + await ws.send(b"closed") + with pytest.raises(websockets.ConnectionClosedOK): + await ws.ping() + with pytest.raises(websockets.ConnectionClosedOK): + await ws.pong() + + +def test_broadcast_validation_and_exception_group(): + with pytest.raises(TypeError, match="data must be str or bytes"): + websockets.broadcast([], object()) # type: ignore[arg-type] + + if sys.version_info[:2] < (3, 11): + with pytest.raises(ValueError, match="requires at least Python 3.11"): + websockets.broadcast([], "hello", raise_exceptions=True) + return + + class BrokenConnection: + state = websockets.State.OPEN + _send_in_progress = False + + def _encode_and_send(self, _msg_type, _message, _fin): + raise RuntimeError("broken") + + with pytest.raises(ExceptionGroup) as exc_info: + websockets.broadcast([BrokenConnection()], "hello", raise_exceptions=True) # type: ignore[list-item] + assert len(exc_info.value.exceptions) == 1 + + +def test_response_to_picows_supports_empty_body_and_status_alias(): + response = websockets.Response( + int(HTTPStatus.SWITCHING_PROTOCOLS), + HTTPStatus.SWITCHING_PROTOCOLS.phrase, + CIMultiDict({"X-Test": "yes"}), + bytearray(b"body"), + ) + + assert response.status == 101 + picows_response = response.to_picows() + assert picows_response.status is HTTPStatus.SWITCHING_PROTOCOLS + assert picows_response.headers["X-Test"] == "yes" + assert picows_response.body == b"body" From d121fc2b728faa9a1625bba0bd215c7f0e3582e9 Mon Sep 17 00:00:00 2001 From: taras Date: Fri, 8 May 2026 19:57:57 +0200 Subject: [PATCH 55/57] Refactor tests --- tests/test_websockets_client.py | 168 +++++ tests/test_websockets_compat.py | 75 --- tests/test_websockets_compat_types.py | 20 + tests/test_websockets_compression_compat.py | 61 -- ...bsockets_compression_failure_edge_cases.py | 89 --- tests/test_websockets_decode_edge_cases.py | 52 -- ...st_websockets_decode_failure_edge_cases.py | 31 - tests/test_websockets_exceptions.py | 50 ++ tests/test_websockets_negotiation.py | 312 +++++++++ tests/test_websockets_ping_pong.py | 80 +++ tests/test_websockets_recv.py | 260 ++++++++ tests/test_websockets_recv_edge_cases.py | 77 --- tests/test_websockets_send.py | 209 ++++++ tests/test_websockets_send_edge_cases.py | 50 -- ...test_websockets_send_failure_edge_cases.py | 86 --- tests/test_websockets_server.py | 149 +++++ tests/test_websockets_server_compat.py | 340 ---------- tests/test_websockets_server_handshake.py | 202 ++++++ tests/test_websockets_unit_coverage.py | 618 ------------------ 19 files changed, 1450 insertions(+), 1479 deletions(-) create mode 100644 tests/test_websockets_client.py delete mode 100644 tests/test_websockets_compat.py create mode 100644 tests/test_websockets_compat_types.py delete mode 100644 tests/test_websockets_compression_compat.py delete mode 100644 tests/test_websockets_compression_failure_edge_cases.py delete mode 100644 tests/test_websockets_decode_edge_cases.py delete mode 100644 tests/test_websockets_decode_failure_edge_cases.py create mode 100644 tests/test_websockets_exceptions.py create mode 100644 tests/test_websockets_negotiation.py create mode 100644 tests/test_websockets_ping_pong.py create mode 100644 tests/test_websockets_recv.py delete mode 100644 tests/test_websockets_recv_edge_cases.py create mode 100644 tests/test_websockets_send.py delete mode 100644 tests/test_websockets_send_edge_cases.py delete mode 100644 tests/test_websockets_send_failure_edge_cases.py create mode 100644 tests/test_websockets_server.py delete mode 100644 tests/test_websockets_server_compat.py create mode 100644 tests/test_websockets_server_handshake.py delete mode 100644 tests/test_websockets_unit_coverage.py diff --git a/tests/test_websockets_client.py b/tests/test_websockets_client.py new file mode 100644 index 0000000..04e5b6d --- /dev/null +++ b/tests/test_websockets_client.py @@ -0,0 +1,168 @@ +import asyncio +import socket + +import pytest +from multidict import CIMultiDict + +import picows +from picows import websockets +from picows.websockets.asyncio.client import _process_proxy +from picows.websockets.asyncio.connection import _normalize_watermarks, _resolve_logger, process_exception +from tests.utils import ServerEchoListener, WSServer + + +def test_client_private_option_helpers(): + assert _process_proxy(None, False) is None + assert _process_proxy("http://127.0.0.1:8080", False) == "http://127.0.0.1:8080" + with pytest.raises(websockets.InvalidProxy): + _process_proxy(123, False) # type: ignore[arg-type] + + assert _normalize_watermarks(None) == (0, 0) + assert _normalize_watermarks((None, 1)) == (0, 0) + assert _normalize_watermarks((8, None)) == (8, 2) + assert _resolve_logger("picows.test").name == "picows.test" + + +async def test_client_connection_starts_in_connecting_state(): + connection = websockets.ClientConnection( + request=websockets.Request("/", CIMultiDict()), + response=websockets.Response(101, "Switching Protocols", CIMultiDict(), b""), + subprotocol=None, + permessage_deflate=None, + ) + + assert connection.state is websockets.State.CONNECTING + + +async def test_connect_await_style_and_socket_options(): + async with WSServer() as server: + ws = await websockets.connect(server.url, compression=None, ping_interval=None) + try: + assert ws.state is websockets.State.OPEN + assert ws.local_address[0] == "127.0.0.1" + assert ws.remote_address[0] == "127.0.0.1" + assert ws.latency == 0 + assert ws.subprotocol is None + assert ws.close_code is None + assert ws.close_reason is None + await ws.send("awaited") + assert await ws.recv() == "awaited" + finally: + await ws.close() + + sock = socket.create_connection((server.host, server.port)) + ws = await websockets.connect(server.url, compression=None, ping_interval=None, sock=sock) + try: + await ws.send(b"sock") + assert await ws.recv() == b"sock" + finally: + await ws.close() + + async with websockets.connect( + "ws://example.invalid/", + compression=None, + ping_interval=None, + proxy=None, + host=server.host, + port=server.port, + ) as ws: + await ws.send("override") + assert await ws.recv() == "override" + + +async def test_connect_rejects_conflicting_and_invalid_socket_options(): + async with WSServer() as server: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + try: + with pytest.raises(TypeError, match="cannot pass both sock and socket_factory"): + await websockets.connect(server.url, compression=None, sock=sock, socket_factory=lambda _: None) + finally: + sock.close() + + with pytest.raises(TypeError, match="sock must be a socket.socket instance"): + await websockets.connect(server.url, compression=None, sock=object()) + + with pytest.raises(TypeError, match="cannot pass both host/port override and socket_factory"): + await websockets.connect( + server.url, + compression=None, + host=server.host, + socket_factory=lambda _: None, + ) + + +async def test_connect_rejects_invalid_ssl_options_before_network(): + with pytest.raises(NotImplementedError, match="ssl=False"): + await websockets.connect("wss://example.com/", compression=None, ssl=False) + with pytest.raises(TypeError, match="ssl must be"): + await websockets.connect("wss://example.com/", compression=None, ssl=object()) + + +def test_process_exception_retries_transient_failures(): + assert process_exception(EOFError()) is None + assert process_exception(OSError()) is None + assert process_exception(asyncio.TimeoutError()) is None + + response = websockets.Response(503, "Service Unavailable", CIMultiDict(), b"") + assert process_exception(websockets.InvalidStatus(response)) is None + + error = RuntimeError("boom") + assert process_exception(error) is error + + +async def test_connect_async_iterator_retries_then_succeeds(): + attempts = 0 + + def process_exception(exc: Exception) -> Exception | None: + nonlocal attempts + attempts += 1 + if attempts == 1: + return None + return exc + + connector = websockets.connect( + "ws://127.0.0.1:1/", + compression=None, + open_timeout=0.01, + process_exception=process_exception, + ) + connector._backoff = 0 + with pytest.raises(OSError): + async for _ws in connector: + pass + assert attempts == 2 + + +def test_connect_rejects_create_connection(): + with pytest.raises(NotImplementedError): + websockets.connect("ws://example.com", create_connection=websockets.ClientConnection) + + +async def test_connection_context_manager_and_close_timeout_none(): + async with WSServer() as server: + ws = await websockets.connect(server.url, compression=None, ping_interval=None, close_timeout=None) + async with ws: + await ws.send("context") + assert await ws.recv() == "context" + assert ws.state is websockets.State.CLOSED + + +class IgnoreCloseListener(ServerEchoListener): + def on_ws_frame(self, transport: picows.WSTransport, frame: picows.WSFrame): + if frame.msg_type == picows.WSMsgType.CLOSE: + return + super().on_ws_frame(transport, frame) + + +async def test_close_timeout_disconnects_when_peer_ignores_close(): + async with WSServer(lambda _: IgnoreCloseListener()) as server: + ws = await websockets.connect( + server.url, + compression=None, + ping_interval=None, + close_timeout=0.01, + ) + await ws.close() + assert ws.state is websockets.State.CLOSED + assert ws.close_code == 1000 + assert ws.close_reason == "" diff --git a/tests/test_websockets_compat.py b/tests/test_websockets_compat.py deleted file mode 100644 index 2c1cb05..0000000 --- a/tests/test_websockets_compat.py +++ /dev/null @@ -1,75 +0,0 @@ -import asyncio - -import pytest - -import picows -from picows import websockets -from tests.utils import WSServer - - -async def test_connect_send_recv_text(): - async with WSServer() as server: - async with websockets.connect(server.url, compression=None) as ws: - await ws.send("hello") - reply = await ws.recv() - assert reply == "hello" - - -async def test_connect_send_recv_binary(): - async with WSServer() as server: - async with websockets.connect(server.url, compression=None) as ws: - await ws.send(b"hello") - reply = await ws.recv() - assert reply == b"hello" - - -async def test_async_iteration_closes_normally(): - async with WSServer() as server: - async with websockets.connect(server.url, compression=None) as ws: - await ws.send("hello") - assert await ws.recv() == "hello" - await ws.close() - - items = [] - async for item in ws: - items.append(item) - - assert items == [] - - -async def test_ping_returns_waiter(): - async with WSServer() as server: - async with websockets.connect(server.url, compression=None, ping_interval=None) as ws: - pong_waiter = await ws.ping(b"abcd") - latency = await asyncio.wait_for(pong_waiter, 1.0) - assert latency >= 0 - - -async def test_recv_streaming_fragmented_message(): - async with WSServer() as server: - async with websockets.connect(server.url, compression=None) as ws: - await ws.send([b"ab", b"cd"]) - fragments = [] - async for fragment in ws.recv_streaming(): - fragments.append(fragment) - assert fragments == [b"ab", b"cd", b""] - - -async def test_subprotocol_header_and_property(): - request_headers = {} - - def listener_factory(request): - request_headers["value"] = request.headers.get("Sec-WebSocket-Protocol") - return None - - async with WSServer(listener_factory) as server: - with pytest.raises(websockets.InvalidStatus): - async with websockets.connect(server.url, compression=None, subprotocols=["chat"]): - pass - - assert request_headers["value"] == "chat" - - -def test_connect_rejects_create_connection(): - with pytest.raises(NotImplementedError): - websockets.connect("ws://example.com", create_connection=websockets.ClientConnection) diff --git a/tests/test_websockets_compat_types.py b/tests/test_websockets_compat_types.py new file mode 100644 index 0000000..abdb948 --- /dev/null +++ b/tests/test_websockets_compat_types.py @@ -0,0 +1,20 @@ +from http import HTTPStatus + +from multidict import CIMultiDict + +from picows import websockets + + +def test_response_to_picows_supports_empty_body_and_status_alias(): + response = websockets.Response( + int(HTTPStatus.SWITCHING_PROTOCOLS), + HTTPStatus.SWITCHING_PROTOCOLS.phrase, + CIMultiDict({"X-Test": "yes"}), + bytearray(b"body"), + ) + + assert response.status == 101 + picows_response = response.to_picows() + assert picows_response.status is HTTPStatus.SWITCHING_PROTOCOLS + assert picows_response.headers["X-Test"] == "yes" + assert picows_response.body == b"body" diff --git a/tests/test_websockets_compression_compat.py b/tests/test_websockets_compression_compat.py deleted file mode 100644 index 954e278..0000000 --- a/tests/test_websockets_compression_compat.py +++ /dev/null @@ -1,61 +0,0 @@ -from contextlib import asynccontextmanager - -import websockets as upstream_websockets - -from picows import websockets - - -@asynccontextmanager -async def upstream_server(handler): - server = await upstream_websockets.serve( - handler, - "127.0.0.1", - 0, - compression="deflate", - ) - port = server.sockets[0].getsockname()[1] - try: - yield f"ws://127.0.0.1:{port}/" - finally: - server.close() - await server.wait_closed() - - -async def test_permessage_deflate_echo_with_upstream_server(): - async def handler(ws): - async for message in ws: - await ws.send(message) - - async with upstream_server(handler) as url: - async with websockets.connect(url, ping_interval=None) as ws: - assert "permessage-deflate" in (ws.response.headers.get("Sec-WebSocket-Extensions") or "") - - message = "hello " * 1000 - await ws.send(message) - assert await ws.recv() == message - - -async def test_permessage_deflate_fragmented_send_with_upstream_server(): - async def handler(ws): - async for message in ws: - await ws.send(message) - - async with upstream_server(handler) as url: - async with websockets.connect(url, ping_interval=None) as ws: - await ws.send([b"a" * 300, b"b" * 300, b"c" * 300]) - assert await ws.recv() == (b"a" * 300 + b"b" * 300 + b"c" * 300) - - -async def test_permessage_deflate_recv_streaming_from_upstream_server(): - chunks = [b"ab" * 300, b"cd" * 300, b"ef" * 300] - - async def handler(ws): - await ws.send(chunks) - await ws.close() - - async with upstream_server(handler) as url: - async with websockets.connect(url, ping_interval=None) as ws: - fragments = [] - async for fragment in ws.recv_streaming(): - fragments.append(fragment) - assert fragments == chunks + [b""] diff --git a/tests/test_websockets_compression_failure_edge_cases.py b/tests/test_websockets_compression_failure_edge_cases.py deleted file mode 100644 index 2d4f474..0000000 --- a/tests/test_websockets_compression_failure_edge_cases.py +++ /dev/null @@ -1,89 +0,0 @@ -import asyncio -import base64 -import hashlib -from contextlib import asynccontextmanager - -import pytest -import websockets as upstream_websockets - -from picows import websockets - - -@asynccontextmanager -async def upstream_server(handler): - server = await upstream_websockets.serve( - handler, - "127.0.0.1", - 0, - compression="deflate", - ) - port = server.sockets[0].getsockname()[1] - try: - yield f"ws://127.0.0.1:{port}/" - finally: - server.close() - await server.wait_closed() - - -@asynccontextmanager -async def malformed_compressed_server(): - async def handler(reader: asyncio.StreamReader, writer: asyncio.StreamWriter): - request = await reader.readuntil(b"\r\n\r\n") - headers = request.decode("ascii").split("\r\n") - key = None - for header in headers: - if header.lower().startswith("sec-websocket-key:"): - key = header.split(":", 1)[1].strip() - break - assert key is not None - - accept = base64.b64encode( - hashlib.sha1( - (key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11").encode("ascii") - ).digest() - ).decode("ascii") - - response = ( - "HTTP/1.1 101 Switching Protocols\r\n" - "Upgrade: websocket\r\n" - "Connection: Upgrade\r\n" - f"Sec-WebSocket-Accept: {accept}\r\n" - "Sec-WebSocket-Extensions: permessage-deflate\r\n" - "\r\n" - ) - writer.write(response.encode("ascii")) - - payload = b"not-a-valid-deflate-stream" - frame = bytes([0xC1, len(payload)]) + payload - writer.write(frame) - await writer.drain() - await asyncio.sleep(0.1) - writer.close() - await writer.wait_closed() - - server = await asyncio.start_server(handler, "127.0.0.1", 0) - port = server.sockets[0].getsockname()[1] - try: - yield f"ws://127.0.0.1:{port}/" - finally: - server.close() - await server.wait_closed() - - -async def test_compressed_message_exceeding_max_size_closes_connection(): - async def handler(ws): - await ws.send("a" * 10000) - - async with upstream_server(handler) as url: - async with websockets.connect( - url, ping_interval=None, max_size=1000 - ) as ws: - with pytest.raises(websockets.ConnectionClosedError): - await ws.recv() - - -async def test_malformed_compressed_message_closes_connection(): - async with malformed_compressed_server() as url: - async with websockets.connect(url, ping_interval=None) as ws: - with pytest.raises(websockets.ConnectionClosedError): - await ws.recv() diff --git a/tests/test_websockets_decode_edge_cases.py b/tests/test_websockets_decode_edge_cases.py deleted file mode 100644 index a7c17c3..0000000 --- a/tests/test_websockets_decode_edge_cases.py +++ /dev/null @@ -1,52 +0,0 @@ -import picows - -from picows import websockets -from tests.utils import ServerEchoListener, WSServer - - -class SendTextOnConnect(ServerEchoListener): - def __init__(self, payload: bytes): - self._payload = payload - - def on_ws_connected(self, transport: picows.WSTransport): - super().on_ws_connected(transport) - transport.send(picows.WSMsgType.TEXT, self._payload) - - -class SendBinaryOnConnect(ServerEchoListener): - def __init__(self, payload: bytes): - self._payload = payload - - def on_ws_connected(self, transport: picows.WSTransport): - super().on_ws_connected(transport) - transport.send(picows.WSMsgType.BINARY, self._payload) - - -async def test_recv_decode_false_returns_bytes_for_text_messages(): - async with WSServer(lambda _: SendTextOnConnect(b"hello")) as server: - async with websockets.connect(server.url, compression=None) as ws: - assert await ws.recv(decode=False) == b"hello" - - -async def test_recv_decode_true_returns_text_for_binary_messages(): - async with WSServer(lambda _: SendBinaryOnConnect(b"hello")) as server: - async with websockets.connect(server.url, compression=None) as ws: - assert await ws.recv(decode=True) == "hello" - - -async def test_recv_streaming_decode_false_returns_bytes_for_text_messages(): - async with WSServer(lambda _: SendTextOnConnect(b"hello")) as server: - async with websockets.connect(server.url, compression=None) as ws: - fragments = [] - async for fragment in ws.recv_streaming(decode=False): - fragments.append(fragment) - assert fragments == [b"hello"] - - -async def test_recv_streaming_decode_true_returns_text_for_binary_messages(): - async with WSServer(lambda _: SendBinaryOnConnect(b"hello")) as server: - async with websockets.connect(server.url, compression=None) as ws: - fragments = [] - async for fragment in ws.recv_streaming(decode=True): - fragments.append(fragment) - assert fragments == ["hello"] diff --git a/tests/test_websockets_decode_failure_edge_cases.py b/tests/test_websockets_decode_failure_edge_cases.py deleted file mode 100644 index 33025ac..0000000 --- a/tests/test_websockets_decode_failure_edge_cases.py +++ /dev/null @@ -1,31 +0,0 @@ -import pytest - -import picows -from picows import websockets -from tests.utils import ServerEchoListener, WSServer - - -class SendBinaryOnConnect(ServerEchoListener): - def __init__(self, payload: bytes): - self._payload = payload - - def on_ws_connected(self, transport: picows.WSTransport): - super().on_ws_connected(transport) - transport.send(picows.WSMsgType.BINARY, self._payload) - - -async def test_recv_decode_true_invalid_utf8_closes_connection(): - async with WSServer(lambda _: SendBinaryOnConnect(b"\xff\xfe")) as server: - async with websockets.connect(server.url, compression=None) as ws: - with pytest.raises(websockets.ConnectionClosedError): - await ws.recv(decode=True) - assert ws.close_code == 1007 - - -async def test_recv_streaming_decode_true_invalid_utf8_closes_connection(): - async with WSServer(lambda _: SendBinaryOnConnect(b"\xff\xfe")) as server: - async with websockets.connect(server.url, compression=None) as ws: - with pytest.raises(websockets.ConnectionClosedError): - async for _fragment in ws.recv_streaming(decode=True): - pass - assert ws.close_code == 1007 diff --git a/tests/test_websockets_exceptions.py b/tests/test_websockets_exceptions.py new file mode 100644 index 0000000..20bc893 --- /dev/null +++ b/tests/test_websockets_exceptions.py @@ -0,0 +1,50 @@ +from dataclasses import dataclass + +from multidict import CIMultiDict + +from picows import websockets + + +@dataclass +class Close: + code: int + reason: str + + +def test_connection_closed_string_variants(): + assert str(websockets.ConnectionClosed(None, None)) == "no close frame received or sent" + assert str(websockets.ConnectionClosed(None, Close(1000, "bye"))) == "sent 1000 (bye)" + assert str(websockets.ConnectionClosed(Close(1001, "away"), None)) == "received 1001 (away)" + assert str(websockets.ConnectionClosed(Close(1001, "away"), Close(1000, "bye"), True)) == ( + "received then sent close frames: received 1001 (away), sent 1000 (bye)" + ) + assert str(websockets.ConnectionClosed(Close(1001, "away"), Close(1000, "bye"), False)) == ( + "sent then received close frames: received 1001 (away), sent 1000 (bye)" + ) + + +def test_exception_attributes_and_strings(): + assert str(websockets.InvalidURI("http://example.com", "wrong scheme")) == ( + "http://example.com isn't a valid WebSocket URI: wrong scheme" + ) + assert str(websockets.InvalidProxy("ftp://proxy", "wrong scheme")) == ( + "ftp://proxy isn't a valid proxy: wrong scheme" + ) + assert str(websockets.InvalidProxyStatus(object())) == "proxy rejected connection" + assert str(websockets.InvalidProxyStatus(websockets.Response(502, "Bad Gateway", CIMultiDict(), b""))) == ( + "proxy rejected connection: HTTP 502" + ) + + invalid_header = websockets.InvalidHeader("X-Test", "bad") + assert invalid_header.name == "X-Test" + assert invalid_header.value == "bad" + assert websockets.InvalidOrigin("https://bad.example").name == "Origin" + assert websockets.InvalidHeaderFormat("X-Test", "bad syntax", "x:y", 1).value == "bad syntax at 1 in x:y" + + assert str(websockets.DuplicateParameter("server_max_window_bits")) == ( + "duplicate parameter: server_max_window_bits" + ) + assert str(websockets.InvalidParameterName("x")) == "invalid parameter name: x" + assert str(websockets.InvalidParameterValue("x", None)) == "missing value for parameter x" + assert str(websockets.InvalidParameterValue("x", "")) == "empty value for parameter x" + assert str(websockets.InvalidParameterValue("x", "bad")) == "invalid value for parameter x: bad" diff --git a/tests/test_websockets_negotiation.py b/tests/test_websockets_negotiation.py new file mode 100644 index 0000000..cf027c8 --- /dev/null +++ b/tests/test_websockets_negotiation.py @@ -0,0 +1,312 @@ +import asyncio +import base64 +import hashlib +from contextlib import asynccontextmanager + +import picows +import pytest +import websockets as upstream_websockets +from multidict import CIMultiDict + +from picows import websockets +from picows.websockets.asyncio.connection import _PerMessageDeflate +from picows.websockets.asyncio.negotiation import configure_permessage_deflate, resolve_subprotocol +from tests.utils import WSServer + + +@asynccontextmanager +async def upstream_server(handler): + server = await upstream_websockets.serve( + handler, + "127.0.0.1", + 0, + compression="deflate", + ) + port = server.sockets[0].getsockname()[1] + try: + yield f"ws://127.0.0.1:{port}/" + finally: + server.close() + await server.wait_closed() + + +@asynccontextmanager +async def malformed_compressed_server(): + async def handler(reader: asyncio.StreamReader, writer: asyncio.StreamWriter): + request = await reader.readuntil(b"\r\n\r\n") + headers = request.decode("ascii").split("\r\n") + key = None + for header in headers: + if header.lower().startswith("sec-websocket-key:"): + key = header.split(":", 1)[1].strip() + break + assert key is not None + + accept = base64.b64encode( + hashlib.sha1( + (key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11").encode("ascii") + ).digest() + ).decode("ascii") + + response = ( + "HTTP/1.1 101 Switching Protocols\r\n" + "Upgrade: websocket\r\n" + "Connection: Upgrade\r\n" + f"Sec-WebSocket-Accept: {accept}\r\n" + "Sec-WebSocket-Extensions: permessage-deflate\r\n" + "\r\n" + ) + writer.write(response.encode("ascii")) + + payload = b"not-a-valid-deflate-stream" + frame = bytes([0xC1, len(payload)]) + payload + writer.write(frame) + await writer.drain() + await asyncio.sleep(0.1) + writer.close() + await writer.wait_closed() + + server = await asyncio.start_server(handler, "127.0.0.1", 0) + port = server.sockets[0].getsockname()[1] + try: + yield f"ws://127.0.0.1:{port}/" + finally: + server.close() + await server.wait_closed() + + +class Frame: + def __init__( + self, + msg_type: picows.WSMsgType, + payload: bytes, + *, + fin: bool = True, + rsv1: bool = False, + ): + self.msg_type = msg_type + self.payload = payload + self.fin = fin + self.rsv1 = rsv1 + + def get_payload_as_bytes(self) -> bytes: + return self.payload + + def get_payload_as_memoryview(self) -> memoryview: + return memoryview(self.payload) + + +async def test_subprotocol_header_and_property(): + request_headers = {} + + def listener_factory(request): + request_headers["value"] = request.headers.get("Sec-WebSocket-Protocol") + return None + + async with WSServer(listener_factory) as server: + with pytest.raises(websockets.InvalidStatus): + async with websockets.connect(server.url, compression=None, subprotocols=["chat"]): + pass + + assert request_headers["value"] == "chat" + + +async def test_serve_negotiates_subprotocol(): + async def handler(ws: websockets.ServerConnection) -> None: + await ws.wait_closed() + + async with websockets.serve( + handler, + "127.0.0.1", + 0, + compression=None, + subprotocols=["chat", "superchat"], + ) as server: + port = server.sockets[0].getsockname()[1] + async with websockets.connect( + f"ws://127.0.0.1:{port}/", + compression=None, + subprotocols=["superchat", "chat"], + ) as ws: + assert ws.subprotocol == "chat" + + +async def test_select_subprotocol_receives_handshake_connection(): + seen = {} + + async def handler(ws: websockets.ServerConnection) -> None: + await ws.wait_closed() + + def select_subprotocol( + ws: websockets.ServerHandshakeConnection, + offered: list[str], + ) -> str | None: + seen["type"] = type(ws) + seen["path"] = ws.request.path + seen["state"] = ws.state + seen["has_recv"] = hasattr(ws, "recv") + if "chat" in offered: + return "chat" + return None + + async with websockets.serve( + handler, + "127.0.0.1", + 0, + compression=None, + select_subprotocol=select_subprotocol, + ) as server: + port = server.sockets[0].getsockname()[1] + async with websockets.connect( + f"ws://127.0.0.1:{port}/room", + compression=None, + subprotocols=["chat"], + ) as ws: + assert ws.subprotocol == "chat" + + assert seen["type"] is websockets.ServerHandshakeConnection + assert seen["path"] == "/room" + assert seen["state"] is websockets.State.CONNECTING + assert seen["has_recv"] is False + + +def test_negotiation_rejects_invalid_subprotocol_and_extension_headers(): + response = websockets.Response(101, "Switching Protocols", CIMultiDict(), b"") + + response.headers["Sec-WebSocket-Protocol"] = 123 # type: ignore[assignment] + with pytest.raises(websockets.InvalidHandshake, match="non-string subprotocol"): + resolve_subprotocol(["chat"], response) + + response.headers["Sec-WebSocket-Protocol"] = "other" + with pytest.raises(websockets.InvalidHandshake, match="unsupported subprotocol"): + resolve_subprotocol(["chat"], response) + + response.headers.clear() + response.headers["Sec-WebSocket-Extensions"] = "permessage-deflate" + with pytest.raises(websockets.InvalidHandshake, match="unexpected websocket extensions"): + configure_permessage_deflate(response, None) + + response.headers["Sec-WebSocket-Extensions"] = 123 # type: ignore[assignment] + with pytest.raises(websockets.InvalidHandshake, match="invalid Sec-WebSocket-Extensions"): + configure_permessage_deflate(response, "deflate") + + +def test_permessage_deflate_rejects_invalid_parameters(): + response = websockets.Response(101, "Switching Protocols", CIMultiDict(), b"") + invalid_headers = [ + "x-webkit-deflate-frame", + "permessage-deflate, permessage-deflate", + "permessage-deflate; server_no_context_takeover=true", + "permessage-deflate; client_no_context_takeover=true", + "permessage-deflate; server_max_window_bits", + "permessage-deflate; server_max_window_bits=7", + "permessage-deflate; client_max_window_bits", + "permessage-deflate; client_max_window_bits=16", + "permessage-deflate; unknown=value", + "permessage-deflate; server_max_window_bits=15; server_max_window_bits=15", + ] + + for header in invalid_headers: + response.headers["Sec-WebSocket-Extensions"] = header + with pytest.raises(websockets.InvalidHandshake): + configure_permessage_deflate(response, "deflate") + + +def test_permessage_deflate_accepts_no_context_takeover_parameters(): + permessage_deflate = _PerMessageDeflate.from_response_header( + "permessage-deflate; server_no_context_takeover; client_no_context_takeover" + ) + + first = permessage_deflate.encode_frame(picows.WSMsgType.TEXT, "hello", True) + second = permessage_deflate.encode_frame(picows.WSMsgType.TEXT, b"hello", True) + + assert isinstance(first, memoryview) + assert isinstance(second, memoryview) + assert bytes(first) + assert bytes(second) + + +def test_permessage_deflate_decode_passthrough_and_protocol_error_branches(): + permessage_deflate = _PerMessageDeflate.from_response_header( + "permessage-deflate; server_no_context_takeover; client_no_context_takeover" + ) + + assert permessage_deflate.decode_frame( + Frame(picows.WSMsgType.TEXT, b"plain", rsv1=False), + 0, + ) == b"plain" + assert permessage_deflate.decode_frame( + Frame(picows.WSMsgType.CONTINUATION, b"continuation", rsv1=False), + 0, + ) == b"continuation" + + encoded = bytes(permessage_deflate.encode_frame(picows.WSMsgType.TEXT, b"compressed", True)) + assert permessage_deflate.decode_frame( + Frame(picows.WSMsgType.TEXT, encoded, rsv1=True), + 100, + ) == b"compressed" + + with pytest.raises(picows.WSProtocolError, match="unexpected rsv1"): + permessage_deflate.decode_frame( + Frame(picows.WSMsgType.CONTINUATION, b"bad", rsv1=True), + 0, + ) + + +async def test_permessage_deflate_echo_with_upstream_server(): + async def handler(ws): + async for message in ws: + await ws.send(message) + + async with upstream_server(handler) as url: + async with websockets.connect(url, ping_interval=None) as ws: + assert "permessage-deflate" in (ws.response.headers.get("Sec-WebSocket-Extensions") or "") + + message = "hello " * 1000 + await ws.send(message) + assert await ws.recv() == message + + +async def test_permessage_deflate_fragmented_send_with_upstream_server(): + async def handler(ws): + async for message in ws: + await ws.send(message) + + async with upstream_server(handler) as url: + async with websockets.connect(url, ping_interval=None) as ws: + await ws.send([b"a" * 300, b"b" * 300, b"c" * 300]) + assert await ws.recv() == (b"a" * 300 + b"b" * 300 + b"c" * 300) + + +async def test_permessage_deflate_recv_streaming_from_upstream_server(): + chunks = [b"ab" * 300, b"cd" * 300, b"ef" * 300] + + async def handler(ws): + await ws.send(chunks) + await ws.close() + + async with upstream_server(handler) as url: + async with websockets.connect(url, ping_interval=None) as ws: + fragments = [] + async for fragment in ws.recv_streaming(): + fragments.append(fragment) + assert fragments == chunks + [b""] + + +async def test_compressed_message_exceeding_max_size_closes_connection(): + async def handler(ws): + await ws.send("a" * 10000) + + async with upstream_server(handler) as url: + async with websockets.connect( + url, ping_interval=None, max_size=1000 + ) as ws: + with pytest.raises(websockets.ConnectionClosedError): + await ws.recv() + + +async def test_malformed_compressed_message_closes_connection(): + async with malformed_compressed_server() as url: + async with websockets.connect(url, ping_interval=None) as ws: + with pytest.raises(websockets.ConnectionClosedError): + await ws.recv() diff --git a/tests/test_websockets_ping_pong.py b/tests/test_websockets_ping_pong.py new file mode 100644 index 0000000..68e3bcb --- /dev/null +++ b/tests/test_websockets_ping_pong.py @@ -0,0 +1,80 @@ +import asyncio + +import pytest + +from picows import websockets +from tests.utils import WSServer + + +async def test_ping_returns_waiter(): + async with WSServer() as server: + async with websockets.connect(server.url, compression=None, ping_interval=None) as ws: + pong_waiter = await ws.ping(b"abcd") + latency = await asyncio.wait_for(pong_waiter, 1.0) + assert latency >= 0 + + +async def test_ping_accepts_byteslike_payloads(): + async with WSServer() as server: + async with websockets.connect(server.url, compression=None, ping_interval=None) as ws: + pong_waiter = await ws.ping(bytearray(b"abcd")) + await asyncio.wait_for(pong_waiter, 1.0) + pong_waiter = await ws.ping(memoryview(b"efgh")) + await asyncio.wait_for(pong_waiter, 1.0) + + +async def test_pong_accepts_byteslike_payloads(): + async with WSServer() as server: + async with websockets.connect(server.url, compression=None, ping_interval=None) as ws: + await ws.pong(bytearray(b"abcd")) + await ws.pong(memoryview(b"efgh")) + + +async def test_ping_default_duplicate_and_invalid_payloads(): + async with WSServer() as server: + async with websockets.connect(server.url, compression=None, ping_interval=None) as ws: + default_waiter = await ws.ping() + await asyncio.wait_for(default_waiter, 1.0) + + waiter = await ws.ping("same") + with pytest.raises(websockets.ConcurrencyError, match="same data"): + await ws.ping(b"same") + await asyncio.wait_for(waiter, 1.0) + + with pytest.raises(TypeError, match="ping payload"): + await ws.ping(object()) # type: ignore[arg-type] + + +async def test_keepalive_loop_without_ping_timeout_sends_default_pings(): + async with WSServer(enable_auto_pong=True) as server: + async with websockets.connect( + server.url, + compression=None, + ping_interval=0.01, + ping_timeout=None, + ) as ws: + await asyncio.sleep(0.03) + assert ws.latency >= 0 + + +async def test_keepalive_loop_with_ping_timeout_observes_pong(): + async with WSServer(enable_auto_pong=True) as server: + async with websockets.connect( + server.url, + compression=None, + ping_interval=0.01, + ping_timeout=1.0, + ) as ws: + await asyncio.sleep(0.03) + assert ws.latency >= 0 + + +async def test_ping_and_pong_after_close_raise_connection_closed(): + async with WSServer() as server: + ws = await websockets.connect(server.url, compression=None, ping_interval=None) + await ws.close() + + with pytest.raises(websockets.ConnectionClosedOK): + await ws.ping() + with pytest.raises(websockets.ConnectionClosedOK): + await ws.pong() diff --git a/tests/test_websockets_recv.py b/tests/test_websockets_recv.py new file mode 100644 index 0000000..5f00159 --- /dev/null +++ b/tests/test_websockets_recv.py @@ -0,0 +1,260 @@ +import asyncio + +import picows +import pytest + +from picows import websockets +from tests.utils import ServerEchoListener, WSServer + + +class SendTextOnConnect(ServerEchoListener): + def __init__(self, payload: bytes): + self._payload = payload + + def on_ws_connected(self, transport: picows.WSTransport): + super().on_ws_connected(transport) + transport.send(picows.WSMsgType.TEXT, self._payload) + + +class SendBinaryOnConnect(ServerEchoListener): + def __init__(self, payload: bytes): + self._payload = payload + + def on_ws_connected(self, transport: picows.WSTransport): + super().on_ws_connected(transport) + transport.send(picows.WSMsgType.BINARY, self._payload) + + +class FragmentedTextListener(ServerEchoListener): + def __init__(self, allow_first_fragment: asyncio.Event, allow_second_fragment: asyncio.Event): + self._allow_first_fragment = allow_first_fragment + self._allow_second_fragment = allow_second_fragment + + def on_ws_connected(self, transport: picows.WSTransport): + super().on_ws_connected(transport) + + async def send_fragments(): + await self._allow_first_fragment.wait() + transport.send(picows.WSMsgType.TEXT, b"he", fin=False) + await self._allow_second_fragment.wait() + transport.send(picows.WSMsgType.CONTINUATION, b"llo", fin=True) + + asyncio.create_task(send_fragments()) + + +class ContinuationOnConnect(ServerEchoListener): + def on_ws_connected(self, transport: picows.WSTransport): + super().on_ws_connected(transport) + transport.send(picows.WSMsgType.CONTINUATION, b"bad", fin=True) + + +class Rsv1OnConnect(ServerEchoListener): + def on_ws_connected(self, transport: picows.WSTransport): + super().on_ws_connected(transport) + transport.send(picows.WSMsgType.TEXT, b"bad", fin=True, rsv1=True) + + +class BadContinuationSequenceOnConnect(ServerEchoListener): + def on_ws_connected(self, transport: picows.WSTransport): + super().on_ws_connected(transport) + transport.send(picows.WSMsgType.TEXT, b"first", fin=False) + transport.send(picows.WSMsgType.TEXT, b"second", fin=True) + + +class SendLargeTextOnConnect(ServerEchoListener): + def on_ws_connected(self, transport: picows.WSTransport): + super().on_ws_connected(transport) + transport.send(picows.WSMsgType.TEXT, b"large") + + +class DelayedFragmentedTextOnConnect(ServerEchoListener): + def on_ws_connected(self, transport: picows.WSTransport): + super().on_ws_connected(transport) + + async def send_fragments(): + transport.send(picows.WSMsgType.TEXT, b"he", fin=False) + await asyncio.sleep(0.01) + transport.send(picows.WSMsgType.CONTINUATION, b"llo", fin=True) + + asyncio.create_task(send_fragments()) + + +class FragmentedTextOnConnect(ServerEchoListener): + def on_ws_connected(self, transport: picows.WSTransport): + super().on_ws_connected(transport) + transport.send(picows.WSMsgType.TEXT, b"he", fin=False) + transport.send(picows.WSMsgType.CONTINUATION, b"llo", fin=True) + + +async def test_async_iteration_closes_normally(): + async with WSServer() as server: + async with websockets.connect(server.url, compression=None) as ws: + await ws.send("hello") + assert await ws.recv() == "hello" + await ws.close() + + items = [] + async for item in ws: + items.append(item) + + assert items == [] + + +async def test_recv_streaming_fragmented_message(): + async with WSServer() as server: + async with websockets.connect(server.url, compression=None) as ws: + await ws.send([b"ab", b"cd"]) + fragments = [] + async for fragment in ws.recv_streaming(): + fragments.append(fragment) + assert fragments == [b"ab", b"cd", b""] + + +async def test_recv_cancellation_is_safe_for_fragmented_message(): + allow_first_fragment = asyncio.Event() + allow_second_fragment = asyncio.Event() + + async with WSServer(lambda _: FragmentedTextListener(allow_first_fragment, allow_second_fragment)) as server: + async with websockets.connect(server.url, compression=None) as ws: + recv_task = asyncio.create_task(ws.recv()) + allow_first_fragment.set() + await asyncio.sleep(0) + recv_task.cancel() + with pytest.raises(asyncio.CancelledError): + await recv_task + + allow_second_fragment.set() + assert await ws.recv() == "hello" + + +async def test_recv_streaming_cancellation_before_first_fragment_is_safe(): + allow_first_fragment = asyncio.Event() + allow_second_fragment = asyncio.Event() + + async with WSServer(lambda _: FragmentedTextListener(allow_first_fragment, allow_second_fragment)) as server: + async with websockets.connect(server.url, compression=None) as ws: + iterator = ws.recv_streaming() + recv_task = asyncio.create_task(anext(iterator)) + recv_task.cancel() + with pytest.raises(asyncio.CancelledError): + await recv_task + + allow_first_fragment.set() + allow_second_fragment.set() + assert await ws.recv() == "hello" + + +async def test_recv_streaming_partial_consumption_breaks_future_receives(): + allow_first_fragment = asyncio.Event() + allow_second_fragment = asyncio.Event() + + async with WSServer(lambda _: FragmentedTextListener(allow_first_fragment, allow_second_fragment)) as server: + async with websockets.connect(server.url, compression=None) as ws: + iterator = ws.recv_streaming() + allow_first_fragment.set() + assert await anext(iterator) == "he" + + with pytest.raises(websockets.ConcurrencyError): + await ws.recv() + + allow_second_fragment.set() + + with pytest.raises(websockets.ConcurrencyError): + await ws.recv() + + +async def test_recv_decode_false_returns_bytes_for_text_messages(): + async with WSServer(lambda _: SendTextOnConnect(b"hello")) as server: + async with websockets.connect(server.url, compression=None) as ws: + assert await ws.recv(decode=False) == b"hello" + + +async def test_recv_decode_true_returns_text_for_binary_messages(): + async with WSServer(lambda _: SendBinaryOnConnect(b"hello")) as server: + async with websockets.connect(server.url, compression=None) as ws: + assert await ws.recv(decode=True) == "hello" + + +async def test_recv_streaming_decode_false_returns_bytes_for_text_messages(): + async with WSServer(lambda _: SendTextOnConnect(b"hello")) as server: + async with websockets.connect(server.url, compression=None) as ws: + fragments = [] + async for fragment in ws.recv_streaming(decode=False): + fragments.append(fragment) + assert fragments == [b"hello"] + + +async def test_recv_streaming_decode_true_returns_text_for_binary_messages(): + async with WSServer(lambda _: SendBinaryOnConnect(b"hello")) as server: + async with websockets.connect(server.url, compression=None) as ws: + fragments = [] + async for fragment in ws.recv_streaming(decode=True): + fragments.append(fragment) + assert fragments == ["hello"] + + +async def test_recv_decode_true_invalid_utf8_closes_connection(): + async with WSServer(lambda _: SendBinaryOnConnect(b"\xff\xfe")) as server: + async with websockets.connect(server.url, compression=None) as ws: + with pytest.raises(websockets.ConnectionClosedError): + await ws.recv(decode=True) + assert ws.close_code == 1007 + + +async def test_recv_streaming_decode_true_invalid_utf8_closes_connection(): + async with WSServer(lambda _: SendBinaryOnConnect(b"\xff\xfe")) as server: + async with websockets.connect(server.url, compression=None) as ws: + with pytest.raises(websockets.ConnectionClosedError): + async for _fragment in ws.recv_streaming(decode=True): + pass + assert ws.close_code == 1007 + + +async def test_recv_rejects_unexpected_continuation_and_rsv1(): + async with WSServer(lambda _: ContinuationOnConnect()) as server: + async with websockets.connect(server.url, compression=None, ping_interval=None) as ws: + with pytest.raises(websockets.ConnectionClosedError): + await ws.recv() + + async with WSServer(lambda _: Rsv1OnConnect()) as server: + async with websockets.connect(server.url, compression=None, ping_interval=None) as ws: + with pytest.raises(websockets.ConnectionClosedError): + await ws.recv() + + +async def test_recv_rejects_bad_continuation_and_too_large_message(): + async with WSServer(lambda _: BadContinuationSequenceOnConnect()) as server: + async with websockets.connect(server.url, compression=None, ping_interval=None) as ws: + with pytest.raises(websockets.ConnectionClosedError): + await ws.recv() + + async with WSServer(lambda _: SendLargeTextOnConnect()) as server: + async with websockets.connect(server.url, compression=None, ping_interval=None, max_size=2) as ws: + with pytest.raises(websockets.ConnectionClosedError): + await ws.recv() + + +async def test_recv_streaming_waits_for_later_fragment(): + async with WSServer(lambda _: DelayedFragmentedTextOnConnect()) as server: + async with websockets.connect(server.url, compression=None, ping_interval=None) as ws: + fragments = [] + async for fragment in ws.recv_streaming(): + fragments.append(fragment) + assert fragments == ["he", "llo"] + + +async def test_fragmented_message_exceeding_max_size_closes_connection(): + async with WSServer(lambda _: FragmentedTextOnConnect()) as server: + async with websockets.connect(server.url, compression=None, ping_interval=None, max_size=4) as ws: + with pytest.raises(websockets.ConnectionClosedError): + await ws.recv() + + +async def test_disconnect_without_close_frame_sets_error_close_state(): + async with WSServer() as server: + async with websockets.connect(server.url, compression=None, ping_interval=None) as ws: + await ws.send("disconnect_me_without_close_frame") + with pytest.raises(websockets.ConnectionClosedError): + await ws.recv() + assert ws.close_code is None + assert ws.close_reason is None diff --git a/tests/test_websockets_recv_edge_cases.py b/tests/test_websockets_recv_edge_cases.py deleted file mode 100644 index 81cc8e6..0000000 --- a/tests/test_websockets_recv_edge_cases.py +++ /dev/null @@ -1,77 +0,0 @@ -import asyncio - -import picows -import pytest - -from picows import websockets -from tests.utils import ServerEchoListener, WSServer - - -class FragmentedTextListener(ServerEchoListener): - def __init__(self, allow_first_fragment: asyncio.Event, allow_second_fragment: asyncio.Event): - self._allow_first_fragment = allow_first_fragment - self._allow_second_fragment = allow_second_fragment - - def on_ws_connected(self, transport: picows.WSTransport): - super().on_ws_connected(transport) - - async def send_fragments(): - await self._allow_first_fragment.wait() - transport.send(picows.WSMsgType.TEXT, b"he", fin=False) - await self._allow_second_fragment.wait() - transport.send(picows.WSMsgType.CONTINUATION, b"llo", fin=True) - - asyncio.create_task(send_fragments()) - - -async def test_recv_cancellation_is_safe_for_fragmented_message(): - allow_first_fragment = asyncio.Event() - allow_second_fragment = asyncio.Event() - - async with WSServer(lambda _: FragmentedTextListener(allow_first_fragment, allow_second_fragment)) as server: - async with websockets.connect(server.url, compression=None) as ws: - recv_task = asyncio.create_task(ws.recv()) - allow_first_fragment.set() - await asyncio.sleep(0) - recv_task.cancel() - with pytest.raises(asyncio.CancelledError): - await recv_task - - allow_second_fragment.set() - assert await ws.recv() == "hello" - - -async def test_recv_streaming_cancellation_before_first_fragment_is_safe(): - allow_first_fragment = asyncio.Event() - allow_second_fragment = asyncio.Event() - - async with WSServer(lambda _: FragmentedTextListener(allow_first_fragment, allow_second_fragment)) as server: - async with websockets.connect(server.url, compression=None) as ws: - iterator = ws.recv_streaming() - recv_task = asyncio.create_task(anext(iterator)) - recv_task.cancel() - with pytest.raises(asyncio.CancelledError): - await recv_task - - allow_first_fragment.set() - allow_second_fragment.set() - assert await ws.recv() == "hello" - - -async def test_recv_streaming_partial_consumption_breaks_future_receives(): - allow_first_fragment = asyncio.Event() - allow_second_fragment = asyncio.Event() - - async with WSServer(lambda _: FragmentedTextListener(allow_first_fragment, allow_second_fragment)) as server: - async with websockets.connect(server.url, compression=None) as ws: - iterator = ws.recv_streaming() - allow_first_fragment.set() - assert await anext(iterator) == "he" - - with pytest.raises(websockets.ConcurrencyError): - await ws.recv() - - allow_second_fragment.set() - - with pytest.raises(websockets.ConcurrencyError): - await ws.recv() diff --git a/tests/test_websockets_send.py b/tests/test_websockets_send.py new file mode 100644 index 0000000..da3397d --- /dev/null +++ b/tests/test_websockets_send.py @@ -0,0 +1,209 @@ +import asyncio +from collections.abc import AsyncIterator + +import pytest + +from picows import websockets +from tests.utils import WSServer + + +class FragmentError(Exception): + pass + + +async def test_connect_send_recv_text(): + async with WSServer() as server: + async with websockets.connect(server.url, compression=None) as ws: + await ws.send("hello") + assert await ws.recv() == "hello" + + +async def test_connect_send_recv_binary(): + async with WSServer() as server: + async with websockets.connect(server.url, compression=None) as ws: + await ws.send(b"hello") + assert await ws.recv() == b"hello" + + +async def test_send_empty_iterable_is_noop(): + async with WSServer() as server: + async with websockets.connect(server.url, compression=None, ping_interval=None) as ws: + await ws.send([]) + pong_waiter = await ws.ping(b"noop") + await asyncio.wait_for(pong_waiter, 1.0) + + +async def test_send_empty_async_iterable_is_noop(): + async def fragments() -> AsyncIterator[bytes]: + if False: + yield b"never" + + async with WSServer() as server: + async with websockets.connect(server.url, compression=None, ping_interval=None) as ws: + await ws.send(fragments()) + pong_waiter = await ws.ping(b"noop") + await asyncio.wait_for(pong_waiter, 1.0) + + +async def test_send_rejects_dict_like_objects(): + async with WSServer() as server: + async with websockets.connect(server.url, compression=None) as ws: + with pytest.raises(TypeError, match="dict-like object"): + await ws.send({"a": 1}) + + +async def test_send_text_overrides_and_concurrent_send_waits_for_turn(): + async with WSServer() as server: + async with websockets.connect(server.url, compression=None, ping_interval=None) as ws: + await ws.send("as-binary", text=False) + assert await ws.recv(decode=False) == b"as-binary" + + await ws.send(b"as-text", text=True) + assert await ws.recv() == "as-text" + + first_sent = asyncio.Event() + release = asyncio.Event() + + async def fragments(): + yield b"first" + first_sent.set() + await release.wait() + yield b"second" + + first_send = asyncio.create_task(ws.send(fragments())) + await asyncio.wait_for(first_sent.wait(), 1.0) + + second_send = asyncio.create_task(ws.send(b"after")) + await asyncio.sleep(0) + assert not second_send.done() + + release.set() + await asyncio.wait_for(first_send, 1.0) + await asyncio.wait_for(second_send, 1.0) + assert await ws.recv() == b"firstsecond" + assert await ws.recv() == b"after" + + +async def test_send_string_fragments_and_write_pause_wait_paths(): + async with WSServer() as server: + async with websockets.connect(server.url, compression=None, ping_interval=None) as ws: + await ws.send(["he", "llo"]) + assert await ws.recv() == "hello" + + ws.pause_writing() + single_send = asyncio.create_task(ws.send(b"paused")) + await asyncio.sleep(0) + assert not single_send.done() + ws.resume_writing() + await asyncio.wait_for(single_send, 1.0) + assert await ws.recv() == b"paused" + + ws.pause_writing() + fragmented_send = asyncio.create_task(ws.send([b"frag", b"mented"])) + await asyncio.sleep(0) + assert not fragmented_send.done() + ws.resume_writing() + await asyncio.wait_for(fragmented_send, 1.0) + assert await ws.recv() == b"fragmented" + + +async def test_send_rejects_invalid_first_fragment_and_closes(): + async with WSServer() as server: + async with websockets.connect(server.url, compression=None, ping_interval=None) as ws: + with pytest.raises(TypeError, match="message must contain"): + await ws.send([object()]) # type: ignore[list-item] + await asyncio.wait_for(ws.wait_closed(), 1.0) + assert ws.state is websockets.State.CLOSED + + +async def test_send_rejects_unsupported_object_and_mixed_fragments(): + async with WSServer() as server: + async with websockets.connect(server.url, compression=None, ping_interval=None) as ws: + with pytest.raises(TypeError, match="unsupported type"): + await ws.send(object()) # type: ignore[arg-type] + with pytest.raises(TypeError, match="same category"): + await ws.send([b"bytes", "text"]) + + +async def test_send_sync_iterable_exception_closes_connection(): + async with WSServer() as server: + async with websockets.connect(server.url, compression=None, ping_interval=None) as ws: + def fragments(): + yield b"first" + raise FragmentError("boom") + + with pytest.raises(FragmentError, match="boom"): + await ws.send(fragments()) + + await asyncio.wait_for(ws.wait_closed(), 1.0) + with pytest.raises(websockets.ConnectionClosed): + await ws.send(b"x") + + +async def test_send_async_iterable_exception_closes_connection(): + async with WSServer() as server: + async with websockets.connect(server.url, compression=None, ping_interval=None) as ws: + async def fragments() -> AsyncIterator[bytes]: + yield b"first" + raise FragmentError("boom") + + with pytest.raises(FragmentError, match="boom"): + await ws.send(fragments()) + + await asyncio.wait_for(ws.wait_closed(), 1.0) + with pytest.raises(websockets.ConnectionClosed): + await ws.send(b"x") + + +async def test_send_async_iterable_cancellation_closes_connection(): + async with WSServer() as server: + async with websockets.connect(server.url, compression=None, ping_interval=None) as ws: + first_sent = asyncio.Event() + unblock = asyncio.Event() + + async def fragments() -> AsyncIterator[bytes]: + yield b"first" + first_sent.set() + await unblock.wait() + yield b"second" + + send_task = asyncio.create_task(ws.send(fragments())) + await asyncio.wait_for(first_sent.wait(), 1.0) + send_task.cancel() + + with pytest.raises(asyncio.CancelledError): + await send_task + + await asyncio.wait_for(ws.wait_closed(), 1.0) + with pytest.raises(websockets.ConnectionClosed): + await ws.send(b"x") + + +async def test_close_during_fragmented_send_closes_connection(): + async with WSServer() as server: + async with websockets.connect(server.url, compression=None, ping_interval=None) as ws: + first_sent = asyncio.Event() + unblock = asyncio.Event() + + async def fragments() -> AsyncIterator[bytes]: + yield b"first" + first_sent.set() + await unblock.wait() + yield b"second" + + send_task = asyncio.create_task(ws.send(fragments())) + await asyncio.wait_for(first_sent.wait(), 1.0) + await ws.close() + unblock.set() + + with pytest.raises(websockets.ConnectionClosed): + await send_task + + +async def test_send_after_close_raises_connection_closed(): + async with WSServer() as server: + ws = await websockets.connect(server.url, compression=None, ping_interval=None) + await ws.close() + + with pytest.raises(websockets.ConnectionClosedOK): + await ws.send(b"closed") diff --git a/tests/test_websockets_send_edge_cases.py b/tests/test_websockets_send_edge_cases.py deleted file mode 100644 index 4a6080b..0000000 --- a/tests/test_websockets_send_edge_cases.py +++ /dev/null @@ -1,50 +0,0 @@ -import asyncio -from collections.abc import AsyncIterator - -import pytest - -from picows import websockets -from tests.utils import WSServer - - -async def test_send_empty_iterable_is_noop(): - async with WSServer() as server: - async with websockets.connect(server.url, compression=None, ping_interval=None) as ws: - await ws.send([]) - pong_waiter = await ws.ping(b"noop") - await asyncio.wait_for(pong_waiter, 1.0) - - -async def test_send_empty_async_iterable_is_noop(): - async def fragments() -> AsyncIterator[bytes]: - if False: - yield b"never" - - async with WSServer() as server: - async with websockets.connect(server.url, compression=None, ping_interval=None) as ws: - await ws.send(fragments()) - pong_waiter = await ws.ping(b"noop") - await asyncio.wait_for(pong_waiter, 1.0) - - -async def test_send_rejects_dict_like_objects(): - async with WSServer() as server: - async with websockets.connect(server.url, compression=None) as ws: - with pytest.raises(TypeError, match="dict-like object"): - await ws.send({"a": 1}) - - -async def test_ping_accepts_byteslike_payloads(): - async with WSServer() as server: - async with websockets.connect(server.url, compression=None, ping_interval=None) as ws: - pong_waiter = await ws.ping(bytearray(b"abcd")) - await asyncio.wait_for(pong_waiter, 1.0) - pong_waiter = await ws.ping(memoryview(b"efgh")) - await asyncio.wait_for(pong_waiter, 1.0) - - -async def test_pong_accepts_byteslike_payloads(): - async with WSServer() as server: - async with websockets.connect(server.url, compression=None, ping_interval=None) as ws: - await ws.pong(bytearray(b"abcd")) - await ws.pong(memoryview(b"efgh")) diff --git a/tests/test_websockets_send_failure_edge_cases.py b/tests/test_websockets_send_failure_edge_cases.py deleted file mode 100644 index 81629e6..0000000 --- a/tests/test_websockets_send_failure_edge_cases.py +++ /dev/null @@ -1,86 +0,0 @@ -import asyncio -from collections.abc import AsyncIterator - -import pytest - -from picows import websockets -from tests.utils import WSServer - - -class FragmentError(Exception): - pass - - -async def test_send_sync_iterable_exception_closes_connection(): - async with WSServer() as server: - async with websockets.connect(server.url, compression=None, ping_interval=None) as ws: - def fragments(): - yield b"first" - raise FragmentError("boom") - - with pytest.raises(FragmentError, match="boom"): - await ws.send(fragments()) - - await asyncio.wait_for(ws.wait_closed(), 1.0) - with pytest.raises(websockets.ConnectionClosed): - await ws.send(b"x") - - -async def test_send_async_iterable_exception_closes_connection(): - async with WSServer() as server: - async with websockets.connect(server.url, compression=None, ping_interval=None) as ws: - async def fragments() -> AsyncIterator[bytes]: - yield b"first" - raise FragmentError("boom") - - with pytest.raises(FragmentError, match="boom"): - await ws.send(fragments()) - - await asyncio.wait_for(ws.wait_closed(), 1.0) - with pytest.raises(websockets.ConnectionClosed): - await ws.send(b"x") - - -async def test_send_async_iterable_cancellation_closes_connection(): - async with WSServer() as server: - async with websockets.connect(server.url, compression=None, ping_interval=None) as ws: - first_sent = asyncio.Event() - unblock = asyncio.Event() - - async def fragments() -> AsyncIterator[bytes]: - yield b"first" - first_sent.set() - await unblock.wait() - yield b"second" - - send_task = asyncio.create_task(ws.send(fragments())) - await asyncio.wait_for(first_sent.wait(), 1.0) - send_task.cancel() - - with pytest.raises(asyncio.CancelledError): - await send_task - - await asyncio.wait_for(ws.wait_closed(), 1.0) - with pytest.raises(websockets.ConnectionClosed): - await ws.send(b"x") - - -async def test_close_during_fragmented_send_closes_connection(): - async with WSServer() as server: - async with websockets.connect(server.url, compression=None, ping_interval=None) as ws: - first_sent = asyncio.Event() - unblock = asyncio.Event() - - async def fragments() -> AsyncIterator[bytes]: - yield b"first" - first_sent.set() - await unblock.wait() - yield b"second" - - send_task = asyncio.create_task(ws.send(fragments())) - await asyncio.wait_for(first_sent.wait(), 1.0) - await ws.close() - unblock.set() - - with pytest.raises(websockets.ConnectionClosed): - await send_task diff --git a/tests/test_websockets_server.py b/tests/test_websockets_server.py new file mode 100644 index 0000000..8199640 --- /dev/null +++ b/tests/test_websockets_server.py @@ -0,0 +1,149 @@ +import asyncio +import sys + +import pytest + +from picows import websockets + + +async def test_serve_echo_roundtrip(): + async def handler(ws: websockets.ServerConnection) -> None: + assert ws.request.path == "/" + assert ws.response.status_code == 101 + message = await ws.recv() + await ws.send(message) + + async with websockets.serve(handler, "127.0.0.1", 0, compression=None) as server: + port = server.sockets[0].getsockname()[1] + async with websockets.connect(f"ws://127.0.0.1:{port}/", compression=None) as ws: + await ws.send("hello") + assert await ws.recv() == "hello" + + +async def test_serve_rejects_create_connection(): + async def handler(ws: websockets.ServerConnection) -> None: + raise AssertionError("handler must not be called") + + with pytest.raises(NotImplementedError): + await websockets.serve( + handler, + "127.0.0.1", + 0, + compression=None, + create_connection=websockets.ServerConnection, + ) + + +async def test_broadcast_sends_to_open_connections(): + connections: list[websockets.ServerConnection] = [] + + async def handler(ws: websockets.ServerConnection) -> None: + connections.append(ws) + await ws.wait_closed() + + async with websockets.serve(handler, "127.0.0.1", 0, compression=None) as server: + port = server.sockets[0].getsockname()[1] + async with websockets.connect(f"ws://127.0.0.1:{port}/", compression=None) as ws1: + async with websockets.connect(f"ws://127.0.0.1:{port}/", compression=None) as ws2: + while len(connections) < 2: + await asyncio.sleep(0) + websockets.broadcast(connections, "hi") + assert await ws1.recv() == "hi" + assert await ws2.recv() == "hi" + + +def test_broadcast_validation_and_exception_group(): + with pytest.raises(TypeError, match="data must be str or bytes"): + websockets.broadcast([], object()) # type: ignore[arg-type] + + if sys.version_info[:2] < (3, 11): + with pytest.raises(ValueError, match="requires at least Python 3.11"): + websockets.broadcast([], "hello", raise_exceptions=True) + return + + class BrokenConnection: + state = websockets.State.OPEN + _send_in_progress = False + + def _encode_and_send(self, _msg_type, _message, _fin): + raise RuntimeError("broken") + + with pytest.raises(ExceptionGroup) as exc_info: + websockets.broadcast([BrokenConnection()], "hello", raise_exceptions=True) # type: ignore[list-item] + assert len(exc_info.value.exceptions) == 1 + + +async def test_server_connections_tracks_open_connections(): + connected = asyncio.Event() + + async def handler(ws: websockets.ServerConnection) -> None: + connected.set() + await ws.wait_closed() + + async with websockets.serve(handler, "127.0.0.1", 0, compression=None) as server: + port = server.sockets[0].getsockname()[1] + assert server.connections == set() + async with websockets.connect(f"ws://127.0.0.1:{port}/", compression=None): + await connected.wait() + assert len(server.connections) == 1 + await asyncio.sleep(0) + assert server.connections == set() + + +async def test_handler_exception_closes_connection_with_internal_error(): + async def handler(ws: websockets.ServerConnection) -> None: + raise RuntimeError("boom") + + async with websockets.serve(handler, "127.0.0.1", 0, compression=None) as server: + port = server.sockets[0].getsockname()[1] + async with websockets.connect(f"ws://127.0.0.1:{port}/", compression=None) as ws: + with pytest.raises(websockets.ConnectionClosedError): + await ws.recv() + assert ws.close_code == 1011 + + +async def test_server_close_closes_existing_connections(): + started = asyncio.Event() + + async def handler(ws: websockets.ServerConnection) -> None: + started.set() + await ws.wait_closed() + + async with websockets.serve(handler, "127.0.0.1", 0, compression=None) as server: + port = server.sockets[0].getsockname()[1] + async with websockets.connect(f"ws://127.0.0.1:{port}/", compression=None) as ws: + await started.wait() + server.close(reason="bye") + with pytest.raises(websockets.ConnectionClosedOK): + await ws.recv() + assert ws.close_code == 1001 + assert ws.close_reason == "bye" + await server.wait_closed() + + +async def test_wait_closed_waits_for_handler_completion(): + started = asyncio.Event() + finish = asyncio.Event() + finished = asyncio.Event() + + async def handler(ws: websockets.ServerConnection) -> None: + started.set() + await finish.wait() + finished.set() + + server = await websockets.serve(handler, "127.0.0.1", 0, compression=None) + port = server.sockets[0].getsockname()[1] + async with websockets.connect(f"ws://127.0.0.1:{port}/", compression=None): + await started.wait() + server.close(close_connections=False) + waiter = asyncio.create_task(server.wait_closed()) + await asyncio.sleep(0) + assert not waiter.done() + finish.set() + await waiter + assert finished.is_set() + + +def test_route_requires_werkzeug(): + with pytest.raises((ImportError, NotImplementedError)): + websockets.route(None) # type: ignore[arg-type] diff --git a/tests/test_websockets_server_compat.py b/tests/test_websockets_server_compat.py deleted file mode 100644 index ed1756f..0000000 --- a/tests/test_websockets_server_compat.py +++ /dev/null @@ -1,340 +0,0 @@ -import asyncio -import base64 -import re - -import pytest -from multidict import CIMultiDict - -from picows import websockets - - -async def test_serve_echo_roundtrip(): - async def handler(ws: websockets.ServerConnection) -> None: - assert ws.request.path == "/" - assert ws.response.status_code == 101 - message = await ws.recv() - await ws.send(message) - - async with websockets.serve(handler, "127.0.0.1", 0, compression=None) as server: - port = server.sockets[0].getsockname()[1] - async with websockets.connect(f"ws://127.0.0.1:{port}/", compression=None) as ws: - await ws.send("hello") - assert await ws.recv() == "hello" - - -async def test_serve_process_request_can_reject_handshake(): - async def handler(ws: websockets.ServerConnection) -> None: - raise AssertionError("handler must not be called") - - def process_request( - ws: websockets.ServerHandshakeConnection, - request: websockets.Request, - ) -> websockets.Response | None: - assert ws.request is request - return websockets.Response( - status_code=418, - reason_phrase="I'm a Teapot", - headers=CIMultiDict({"X-Test": "yes"}), - body=b"nope", - ) - - async with websockets.serve( - handler, - "127.0.0.1", - 0, - compression=None, - process_request=process_request, - ) as server: - port = server.sockets[0].getsockname()[1] - with pytest.raises(websockets.InvalidStatus) as exc_info: - async with websockets.connect(f"ws://127.0.0.1:{port}/", compression=None): - pass - assert int(exc_info.value.response.status) == 418 - - -async def test_serve_process_response_can_mutate_handshake_response(): - async def handler(ws: websockets.ServerConnection) -> None: - await ws.wait_closed() - - def process_response( - ws: websockets.ServerHandshakeConnection, - request: websockets.Request, - response: websockets.Response, - ) -> websockets.Response: - assert ws.request is request - response.headers["X-Handshake"] = "yes" - return response - - async with websockets.serve( - handler, - "127.0.0.1", - 0, - compression=None, - process_response=process_response, - ) as server: - port = server.sockets[0].getsockname()[1] - async with websockets.connect(f"ws://127.0.0.1:{port}/", compression=None) as ws: - assert ws.response.headers["X-Handshake"] == "yes" - - -async def test_serve_rejects_async_process_request(): - async def handler(ws: websockets.ServerConnection) -> None: - raise AssertionError("handler must not be called") - - async def process_request( - ws: websockets.ServerHandshakeConnection, - request: websockets.Request, - ) -> websockets.Response | None: - return None - - async with websockets.serve( - handler, - "127.0.0.1", - 0, - compression=None, - process_request=process_request, - ) as server: - port = server.sockets[0].getsockname()[1] - with pytest.raises(websockets.InvalidStatus): - async with websockets.connect(f"ws://127.0.0.1:{port}/", compression=None): - pass - - -async def test_serve_rejects_create_connection(): - async def handler(ws: websockets.ServerConnection) -> None: - raise AssertionError("handler must not be called") - - with pytest.raises(NotImplementedError): - await websockets.serve( - handler, - "127.0.0.1", - 0, - compression=None, - create_connection=websockets.ServerConnection, - ) - - -async def test_serve_accepts_allowed_origin(): - async def handler(ws: websockets.ServerConnection) -> None: - await ws.send("ok") - - async with websockets.serve( - handler, - "127.0.0.1", - 0, - compression=None, - origins=["https://example.com"], - ) as server: - port = server.sockets[0].getsockname()[1] - async with websockets.connect( - f"ws://127.0.0.1:{port}/", - compression=None, - origin="https://example.com", - ) as ws: - assert await ws.recv() == "ok" - - -async def test_serve_rejects_disallowed_origin(): - async def handler(ws: websockets.ServerConnection) -> None: - raise AssertionError("handler must not be called") - - async with websockets.serve( - handler, - "127.0.0.1", - 0, - compression=None, - origins=[re.compile(r"https://allowed\\.example\\.com")], - ) as server: - port = server.sockets[0].getsockname()[1] - with pytest.raises(websockets.InvalidStatus): - async with websockets.connect( - f"ws://127.0.0.1:{port}/", - compression=None, - origin="https://denied.example.com", - ): - pass - - -async def test_basic_auth_rejects_missing_credentials_and_sets_username(): - async def handler(ws: websockets.ServerConnection) -> None: - assert ws.username == "hello" - await ws.send(ws.username) - - async with websockets.serve( - handler, - "127.0.0.1", - 0, - compression=None, - process_request=websockets.basic_auth( - realm="test", - credentials=("hello", "secret"), - ), - ) as server: - port = server.sockets[0].getsockname()[1] - - with pytest.raises(websockets.InvalidStatus) as exc_info: - async with websockets.connect(f"ws://127.0.0.1:{port}/", compression=None): - pass - assert int(exc_info.value.response.status) == 401 - assert exc_info.value.response.headers["WWW-Authenticate"] == 'Basic realm="test"' - - token = base64.b64encode(b"hello:secret").decode() - async with websockets.connect( - f"ws://127.0.0.1:{port}/", - compression=None, - additional_headers={"Authorization": f"Basic {token}"}, - ) as ws: - assert await ws.recv() == "hello" - - -async def test_serve_negotiates_subprotocol(): - async def handler(ws: websockets.ServerConnection) -> None: - await ws.wait_closed() - - async with websockets.serve( - handler, - "127.0.0.1", - 0, - compression=None, - subprotocols=["chat", "superchat"], - ) as server: - port = server.sockets[0].getsockname()[1] - async with websockets.connect( - f"ws://127.0.0.1:{port}/", - compression=None, - subprotocols=["superchat", "chat"], - ) as ws: - assert ws.subprotocol == "chat" - - -async def test_select_subprotocol_receives_handshake_connection(): - seen = {} - - async def handler(ws: websockets.ServerConnection) -> None: - await ws.wait_closed() - - def select_subprotocol( - ws: websockets.ServerHandshakeConnection, - offered: list[str], - ) -> str | None: - seen["type"] = type(ws) - seen["path"] = ws.request.path - seen["state"] = ws.state - seen["has_recv"] = hasattr(ws, "recv") - if "chat" in offered: - return "chat" - return None - - async with websockets.serve( - handler, - "127.0.0.1", - 0, - compression=None, - select_subprotocol=select_subprotocol, - ) as server: - port = server.sockets[0].getsockname()[1] - async with websockets.connect( - f"ws://127.0.0.1:{port}/room", - compression=None, - subprotocols=["chat"], - ) as ws: - assert ws.subprotocol == "chat" - - assert seen["type"] is websockets.ServerHandshakeConnection - assert seen["path"] == "/room" - assert seen["state"] is websockets.State.CONNECTING - assert seen["has_recv"] is False - - -async def test_broadcast_sends_to_open_connections(): - connections: list[websockets.ServerConnection] = [] - - async def handler(ws: websockets.ServerConnection) -> None: - connections.append(ws) - await ws.wait_closed() - - async with websockets.serve(handler, "127.0.0.1", 0, compression=None) as server: - port = server.sockets[0].getsockname()[1] - async with websockets.connect(f"ws://127.0.0.1:{port}/", compression=None) as ws1: - async with websockets.connect(f"ws://127.0.0.1:{port}/", compression=None) as ws2: - while len(connections) < 2: - await asyncio.sleep(0) - websockets.broadcast(connections, "hi") - assert await ws1.recv() == "hi" - assert await ws2.recv() == "hi" - - -async def test_server_connections_tracks_open_connections(): - connected = asyncio.Event() - - async def handler(ws: websockets.ServerConnection) -> None: - connected.set() - await ws.wait_closed() - - async with websockets.serve(handler, "127.0.0.1", 0, compression=None) as server: - port = server.sockets[0].getsockname()[1] - assert server.connections == set() - async with websockets.connect(f"ws://127.0.0.1:{port}/", compression=None): - await connected.wait() - assert len(server.connections) == 1 - await asyncio.sleep(0) - assert server.connections == set() - - -async def test_handler_exception_closes_connection_with_internal_error(): - async def handler(ws: websockets.ServerConnection) -> None: - raise RuntimeError("boom") - - async with websockets.serve(handler, "127.0.0.1", 0, compression=None) as server: - port = server.sockets[0].getsockname()[1] - async with websockets.connect(f"ws://127.0.0.1:{port}/", compression=None) as ws: - with pytest.raises(websockets.ConnectionClosedError): - await ws.recv() - assert ws.close_code == 1011 - - -async def test_server_close_closes_existing_connections(): - started = asyncio.Event() - - async def handler(ws: websockets.ServerConnection) -> None: - started.set() - await ws.wait_closed() - - async with websockets.serve(handler, "127.0.0.1", 0, compression=None) as server: - port = server.sockets[0].getsockname()[1] - async with websockets.connect(f"ws://127.0.0.1:{port}/", compression=None) as ws: - await started.wait() - server.close(reason="bye") - with pytest.raises(websockets.ConnectionClosedOK): - await ws.recv() - assert ws.close_code == 1001 - assert ws.close_reason == "bye" - await server.wait_closed() - - -async def test_wait_closed_waits_for_handler_completion(): - started = asyncio.Event() - finish = asyncio.Event() - finished = asyncio.Event() - - async def handler(ws: websockets.ServerConnection) -> None: - started.set() - await finish.wait() - finished.set() - - server = await websockets.serve(handler, "127.0.0.1", 0, compression=None) - port = server.sockets[0].getsockname()[1] - async with websockets.connect(f"ws://127.0.0.1:{port}/", compression=None): - await started.wait() - server.close(close_connections=False) - waiter = asyncio.create_task(server.wait_closed()) - await asyncio.sleep(0) - assert not waiter.done() - finish.set() - await waiter - assert finished.is_set() - - -def test_route_requires_werkzeug(): - with pytest.raises((ImportError, NotImplementedError)): - websockets.route(None) # type: ignore[arg-type] diff --git a/tests/test_websockets_server_handshake.py b/tests/test_websockets_server_handshake.py new file mode 100644 index 0000000..8553b7b --- /dev/null +++ b/tests/test_websockets_server_handshake.py @@ -0,0 +1,202 @@ +import base64 +import re + +import pytest +from multidict import CIMultiDict + +from picows import websockets +from picows.websockets.asyncio.server import _parse_basic_authorization + + +async def test_serve_process_request_can_reject_handshake(): + async def handler(ws: websockets.ServerConnection) -> None: + raise AssertionError("handler must not be called") + + def process_request( + ws: websockets.ServerHandshakeConnection, + request: websockets.Request, + ) -> websockets.Response | None: + assert ws.request is request + return websockets.Response( + status_code=418, + reason_phrase="I'm a Teapot", + headers=CIMultiDict({"X-Test": "yes"}), + body=b"nope", + ) + + async with websockets.serve( + handler, + "127.0.0.1", + 0, + compression=None, + process_request=process_request, + ) as server: + port = server.sockets[0].getsockname()[1] + with pytest.raises(websockets.InvalidStatus) as exc_info: + async with websockets.connect(f"ws://127.0.0.1:{port}/", compression=None): + pass + assert int(exc_info.value.response.status) == 418 + + +async def test_serve_process_response_can_mutate_handshake_response(): + async def handler(ws: websockets.ServerConnection) -> None: + await ws.wait_closed() + + def process_response( + ws: websockets.ServerHandshakeConnection, + request: websockets.Request, + response: websockets.Response, + ) -> websockets.Response: + assert ws.request is request + response.headers["X-Handshake"] = "yes" + return response + + async with websockets.serve( + handler, + "127.0.0.1", + 0, + compression=None, + process_response=process_response, + ) as server: + port = server.sockets[0].getsockname()[1] + async with websockets.connect(f"ws://127.0.0.1:{port}/", compression=None) as ws: + assert ws.response.headers["X-Handshake"] == "yes" + + +async def test_serve_rejects_async_process_request(): + async def handler(ws: websockets.ServerConnection) -> None: + raise AssertionError("handler must not be called") + + async def process_request( + ws: websockets.ServerHandshakeConnection, + request: websockets.Request, + ) -> websockets.Response | None: + return None + + async with websockets.serve( + handler, + "127.0.0.1", + 0, + compression=None, + process_request=process_request, + ) as server: + port = server.sockets[0].getsockname()[1] + with pytest.raises(websockets.InvalidStatus): + async with websockets.connect(f"ws://127.0.0.1:{port}/", compression=None): + pass + + +async def test_serve_accepts_allowed_origin(): + async def handler(ws: websockets.ServerConnection) -> None: + await ws.send("ok") + + async with websockets.serve( + handler, + "127.0.0.1", + 0, + compression=None, + origins=["https://example.com"], + ) as server: + port = server.sockets[0].getsockname()[1] + async with websockets.connect( + f"ws://127.0.0.1:{port}/", + compression=None, + origin="https://example.com", + ) as ws: + assert await ws.recv() == "ok" + + +async def test_serve_rejects_disallowed_origin(): + async def handler(ws: websockets.ServerConnection) -> None: + raise AssertionError("handler must not be called") + + async with websockets.serve( + handler, + "127.0.0.1", + 0, + compression=None, + origins=[re.compile(r"https://allowed\\.example\\.com")], + ) as server: + port = server.sockets[0].getsockname()[1] + with pytest.raises(websockets.InvalidStatus): + async with websockets.connect( + f"ws://127.0.0.1:{port}/", + compression=None, + origin="https://denied.example.com", + ): + pass + + +async def test_basic_auth_rejects_missing_credentials_and_sets_username(): + async def handler(ws: websockets.ServerConnection) -> None: + assert ws.username == "hello" + await ws.send(ws.username) + + async with websockets.serve( + handler, + "127.0.0.1", + 0, + compression=None, + process_request=websockets.basic_auth( + realm="test", + credentials=("hello", "secret"), + ), + ) as server: + port = server.sockets[0].getsockname()[1] + + with pytest.raises(websockets.InvalidStatus) as exc_info: + async with websockets.connect(f"ws://127.0.0.1:{port}/", compression=None): + pass + assert int(exc_info.value.response.status) == 401 + assert exc_info.value.response.headers["WWW-Authenticate"] == 'Basic realm="test"' + + token = base64.b64encode(b"hello:secret").decode() + async with websockets.connect( + f"ws://127.0.0.1:{port}/", + compression=None, + additional_headers={"Authorization": f"Basic {token}"}, + ) as ws: + assert await ws.recv() == "hello" + + +def test_basic_auth_argument_validation_and_malformed_headers(): + with pytest.raises(ValueError, match="provide either credentials or check_credentials"): + websockets.basic_auth() + with pytest.raises(ValueError, match="provide either credentials or check_credentials"): + websockets.basic_auth(credentials=("a", "b"), check_credentials=lambda _u, _p: True) + with pytest.raises(TypeError, match="invalid credentials argument"): + websockets.basic_auth(credentials=("a", "b", "c")) # type: ignore[arg-type] + with pytest.raises(TypeError, match="invalid credentials argument"): + websockets.basic_auth(credentials=[("a", object())]) # type: ignore[list-item] + + with pytest.raises(ValueError, match="unsupported authorization scheme"): + _parse_basic_authorization("Bearer token") + with pytest.raises(ValueError, match="invalid basic authorization header"): + _parse_basic_authorization("Basic !!!") + with pytest.raises(ValueError, match="invalid basic authorization header"): + _parse_basic_authorization("Basic bm9jb2xvbg==") + + +async def test_basic_auth_rejects_bad_and_async_credentials(): + async def handler(_ws: websockets.ServerConnection) -> None: + raise AssertionError("handler must not be called") + + async def check_credentials(_username: str, _password: str) -> bool: + return True + + async with websockets.serve( + handler, + "127.0.0.1", + 0, + compression=None, + process_request=websockets.basic_auth(check_credentials=check_credentials), + ) as server: + port = server.sockets[0].getsockname()[1] + token = "Basic " + "aW52YWxpZDpjcmVkZW50aWFscw==" + with pytest.raises(websockets.InvalidStatus): + async with websockets.connect( + f"ws://127.0.0.1:{port}/", + compression=None, + additional_headers={"Authorization": token}, + ): + pass diff --git a/tests/test_websockets_unit_coverage.py b/tests/test_websockets_unit_coverage.py deleted file mode 100644 index 1a2a2e1..0000000 --- a/tests/test_websockets_unit_coverage.py +++ /dev/null @@ -1,618 +0,0 @@ -from __future__ import annotations - -import asyncio -import socket -import sys -from dataclasses import dataclass -from http import HTTPStatus - -import pytest -from multidict import CIMultiDict - -import picows -from picows import websockets -from picows.websockets.asyncio.client import _process_proxy -from picows.websockets.asyncio.connection import ( - _PerMessageDeflate, - _normalize_watermarks, - _resolve_logger, - process_exception, -) -from picows.websockets.asyncio.negotiation import configure_permessage_deflate, resolve_subprotocol -from picows.websockets.asyncio.server import _parse_basic_authorization -from tests.utils import ServerEchoListener, WSServer - - -@dataclass -class Close: - code: int - reason: str - - -def test_connection_closed_string_variants(): - assert str(websockets.ConnectionClosed(None, None)) == "no close frame received or sent" - assert str(websockets.ConnectionClosed(None, Close(1000, "bye"))) == "sent 1000 (bye)" - assert str(websockets.ConnectionClosed(Close(1001, "away"), None)) == "received 1001 (away)" - assert str(websockets.ConnectionClosed(Close(1001, "away"), Close(1000, "bye"), True)) == ( - "received then sent close frames: received 1001 (away), sent 1000 (bye)" - ) - assert str(websockets.ConnectionClosed(Close(1001, "away"), Close(1000, "bye"), False)) == ( - "sent then received close frames: received 1001 (away), sent 1000 (bye)" - ) - - -def test_exception_attributes_and_strings(): - assert str(websockets.InvalidURI("http://example.com", "wrong scheme")) == ( - "http://example.com isn't a valid WebSocket URI: wrong scheme" - ) - assert str(websockets.InvalidProxy("ftp://proxy", "wrong scheme")) == ( - "ftp://proxy isn't a valid proxy: wrong scheme" - ) - assert str(websockets.InvalidProxyStatus(object())) == "proxy rejected connection" - assert str(websockets.InvalidProxyStatus(websockets.Response(502, "Bad Gateway", CIMultiDict(), b""))) == ( - "proxy rejected connection: HTTP 502" - ) - - invalid_header = websockets.InvalidHeader("X-Test", "bad") - assert invalid_header.name == "X-Test" - assert invalid_header.value == "bad" - assert websockets.InvalidOrigin("https://bad.example").name == "Origin" - assert websockets.InvalidHeaderFormat("X-Test", "bad syntax", "x:y", 1).value == "bad syntax at 1 in x:y" - - assert str(websockets.DuplicateParameter("server_max_window_bits")) == ( - "duplicate parameter: server_max_window_bits" - ) - assert str(websockets.InvalidParameterName("x")) == "invalid parameter name: x" - assert str(websockets.InvalidParameterValue("x", None)) == "missing value for parameter x" - assert str(websockets.InvalidParameterValue("x", "")) == "empty value for parameter x" - assert str(websockets.InvalidParameterValue("x", "bad")) == "invalid value for parameter x: bad" - - -def test_client_private_option_helpers(): - assert _process_proxy(None, False) is None - assert _process_proxy("http://127.0.0.1:8080", False) == "http://127.0.0.1:8080" - with pytest.raises(websockets.InvalidProxy): - _process_proxy(123, False) # type: ignore[arg-type] - - assert _normalize_watermarks(None) == (0, 0) - assert _normalize_watermarks((None, 1)) == (0, 0) - assert _normalize_watermarks((8, None)) == (8, 2) - assert _resolve_logger("picows.test").name == "picows.test" - - -async def test_client_connection_starts_in_connecting_state(): - connection = websockets.ClientConnection( - request=websockets.Request("/", CIMultiDict()), - response=websockets.Response(101, "Switching Protocols", CIMultiDict(), b""), - subprotocol=None, - permessage_deflate=None, - ) - - assert connection.state is websockets.State.CONNECTING - - -async def test_connect_await_style_and_socket_options(): - async with WSServer() as server: - ws = await websockets.connect(server.url, compression=None, ping_interval=None) - try: - await ws.send("awaited") - assert await ws.recv() == "awaited" - finally: - await ws.close() - - sock = socket.create_connection((server.host, server.port)) - ws = await websockets.connect(server.url, compression=None, ping_interval=None, sock=sock) - try: - await ws.send(b"sock") - assert await ws.recv() == b"sock" - finally: - await ws.close() - - async with websockets.connect( - "ws://example.invalid/", - compression=None, - ping_interval=None, - proxy=None, - host=server.host, - port=server.port, - ) as ws: - await ws.send("override") - assert await ws.recv() == "override" - - -async def test_connect_rejects_conflicting_and_invalid_socket_options(): - async with WSServer() as server: - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - try: - with pytest.raises(TypeError, match="cannot pass both sock and socket_factory"): - await websockets.connect(server.url, compression=None, sock=sock, socket_factory=lambda _: None) - finally: - sock.close() - - with pytest.raises(TypeError, match="sock must be a socket.socket instance"): - await websockets.connect(server.url, compression=None, sock=object()) - - with pytest.raises(TypeError, match="cannot pass both host/port override and socket_factory"): - await websockets.connect( - server.url, - compression=None, - host=server.host, - socket_factory=lambda _: None, - ) - - -async def test_connect_rejects_invalid_ssl_options_before_network(): - with pytest.raises(NotImplementedError, match="ssl=False"): - await websockets.connect("wss://example.com/", compression=None, ssl=False) - with pytest.raises(TypeError, match="ssl must be"): - await websockets.connect("wss://example.com/", compression=None, ssl=object()) - - -def test_process_exception_retries_transient_failures(): - assert process_exception(EOFError()) is None - assert process_exception(OSError()) is None - assert process_exception(asyncio.TimeoutError()) is None - - response = websockets.Response(503, "Service Unavailable", CIMultiDict(), b"") - assert process_exception(websockets.InvalidStatus(response)) is None - - error = RuntimeError("boom") - assert process_exception(error) is error - - -def test_negotiation_rejects_invalid_subprotocol_and_extension_headers(): - response = websockets.Response(101, "Switching Protocols", CIMultiDict(), b"") - - response.headers["Sec-WebSocket-Protocol"] = 123 # type: ignore[assignment] - with pytest.raises(websockets.InvalidHandshake, match="non-string subprotocol"): - resolve_subprotocol(["chat"], response) - - response.headers["Sec-WebSocket-Protocol"] = "other" - with pytest.raises(websockets.InvalidHandshake, match="unsupported subprotocol"): - resolve_subprotocol(["chat"], response) - - response.headers.clear() - response.headers["Sec-WebSocket-Extensions"] = "permessage-deflate" - with pytest.raises(websockets.InvalidHandshake, match="unexpected websocket extensions"): - configure_permessage_deflate(response, None) - - response.headers["Sec-WebSocket-Extensions"] = 123 # type: ignore[assignment] - with pytest.raises(websockets.InvalidHandshake, match="invalid Sec-WebSocket-Extensions"): - configure_permessage_deflate(response, "deflate") - - -def test_permessage_deflate_rejects_invalid_parameters(): - response = websockets.Response(101, "Switching Protocols", CIMultiDict(), b"") - invalid_headers = [ - "x-webkit-deflate-frame", - "permessage-deflate, permessage-deflate", - "permessage-deflate; server_no_context_takeover=true", - "permessage-deflate; client_no_context_takeover=true", - "permessage-deflate; server_max_window_bits", - "permessage-deflate; server_max_window_bits=7", - "permessage-deflate; client_max_window_bits", - "permessage-deflate; client_max_window_bits=16", - "permessage-deflate; unknown=value", - "permessage-deflate; server_max_window_bits=15; server_max_window_bits=15", - ] - - for header in invalid_headers: - response.headers["Sec-WebSocket-Extensions"] = header - with pytest.raises(websockets.InvalidHandshake): - configure_permessage_deflate(response, "deflate") - - -def test_permessage_deflate_accepts_no_context_takeover_parameters(): - permessage_deflate = _PerMessageDeflate.from_response_header( - "permessage-deflate; server_no_context_takeover; client_no_context_takeover" - ) - - first = permessage_deflate.encode_frame(picows.WSMsgType.TEXT, "hello", True) - second = permessage_deflate.encode_frame(picows.WSMsgType.TEXT, b"hello", True) - - assert isinstance(first, memoryview) - assert isinstance(second, memoryview) - assert bytes(first) - assert bytes(second) - - -class Frame: - def __init__( - self, - msg_type: picows.WSMsgType, - payload: bytes, - *, - fin: bool = True, - rsv1: bool = False, - ): - self.msg_type = msg_type - self.payload = payload - self.fin = fin - self.rsv1 = rsv1 - - def get_payload_as_bytes(self) -> bytes: - return self.payload - - def get_payload_as_memoryview(self) -> memoryview: - return memoryview(self.payload) - - -def test_permessage_deflate_decode_passthrough_and_protocol_error_branches(): - permessage_deflate = _PerMessageDeflate.from_response_header( - "permessage-deflate; server_no_context_takeover; client_no_context_takeover" - ) - - assert permessage_deflate.decode_frame( - Frame(picows.WSMsgType.TEXT, b"plain", rsv1=False), - 0, - ) == b"plain" - assert permessage_deflate.decode_frame( - Frame(picows.WSMsgType.CONTINUATION, b"continuation", rsv1=False), - 0, - ) == b"continuation" - - encoded = bytes(permessage_deflate.encode_frame(picows.WSMsgType.TEXT, b"compressed", True)) - assert permessage_deflate.decode_frame( - Frame(picows.WSMsgType.TEXT, encoded, rsv1=True), - 100, - ) == b"compressed" - - with pytest.raises(picows.WSProtocolError, match="unexpected rsv1"): - permessage_deflate.decode_frame( - Frame(picows.WSMsgType.CONTINUATION, b"bad", rsv1=True), - 0, - ) - - -def test_basic_auth_argument_validation_and_malformed_headers(): - with pytest.raises(ValueError, match="provide either credentials or check_credentials"): - websockets.basic_auth() - with pytest.raises(ValueError, match="provide either credentials or check_credentials"): - websockets.basic_auth(credentials=("a", "b"), check_credentials=lambda _u, _p: True) - with pytest.raises(TypeError, match="invalid credentials argument"): - websockets.basic_auth(credentials=("a", "b", "c")) # type: ignore[arg-type] - with pytest.raises(TypeError, match="invalid credentials argument"): - websockets.basic_auth(credentials=[("a", object())]) # type: ignore[list-item] - - with pytest.raises(ValueError, match="unsupported authorization scheme"): - _parse_basic_authorization("Bearer token") - with pytest.raises(ValueError, match="invalid basic authorization header"): - _parse_basic_authorization("Basic !!!") - with pytest.raises(ValueError, match="invalid basic authorization header"): - _parse_basic_authorization("Basic bm9jb2xvbg==") - - -async def test_basic_auth_rejects_bad_and_async_credentials(): - async def handler(_ws: websockets.ServerConnection) -> None: - raise AssertionError("handler must not be called") - - async def check_credentials(_username: str, _password: str) -> bool: - return True - - async with websockets.serve( - handler, - "127.0.0.1", - 0, - compression=None, - process_request=websockets.basic_auth(check_credentials=check_credentials), - ) as server: - port = server.sockets[0].getsockname()[1] - token = "Basic " + "aW52YWxpZDpjcmVkZW50aWFscw==" - with pytest.raises(websockets.InvalidStatus): - async with websockets.connect( - f"ws://127.0.0.1:{port}/", - compression=None, - additional_headers={"Authorization": token}, - ): - pass - - -async def test_connect_async_iterator_retries_then_succeeds(): - attempts = 0 - - def process_exception(exc: Exception) -> Exception | None: - nonlocal attempts - attempts += 1 - if attempts == 1: - return None - return exc - - connector = websockets.connect( - "ws://127.0.0.1:1/", - compression=None, - open_timeout=0.01, - process_exception=process_exception, - ) - connector._backoff = 0 - with pytest.raises(OSError): - async for _ws in connector: - pass - assert attempts == 2 - - -class ContinuationOnConnect(ServerEchoListener): - def on_ws_connected(self, transport: picows.WSTransport): - super().on_ws_connected(transport) - transport.send(picows.WSMsgType.CONTINUATION, b"bad", fin=True) - - -class Rsv1OnConnect(ServerEchoListener): - def on_ws_connected(self, transport: picows.WSTransport): - super().on_ws_connected(transport) - transport.send(picows.WSMsgType.TEXT, b"bad", fin=True, rsv1=True) - - -class BadContinuationSequenceOnConnect(ServerEchoListener): - def on_ws_connected(self, transport: picows.WSTransport): - super().on_ws_connected(transport) - transport.send(picows.WSMsgType.TEXT, b"first", fin=False) - transport.send(picows.WSMsgType.TEXT, b"second", fin=True) - - -class SendLargeTextOnConnect(ServerEchoListener): - def on_ws_connected(self, transport: picows.WSTransport): - super().on_ws_connected(transport) - transport.send(picows.WSMsgType.TEXT, b"large") - - -class DelayedFragmentedTextOnConnect(ServerEchoListener): - def on_ws_connected(self, transport: picows.WSTransport): - super().on_ws_connected(transport) - - async def send_fragments(): - transport.send(picows.WSMsgType.TEXT, b"he", fin=False) - await asyncio.sleep(0.01) - transport.send(picows.WSMsgType.CONTINUATION, b"llo", fin=True) - - asyncio.create_task(send_fragments()) - - -class FragmentedTextOnConnect(ServerEchoListener): - def on_ws_connected(self, transport: picows.WSTransport): - super().on_ws_connected(transport) - transport.send(picows.WSMsgType.TEXT, b"he", fin=False) - transport.send(picows.WSMsgType.CONTINUATION, b"llo", fin=True) - - -class IgnoreCloseListener(ServerEchoListener): - def on_ws_frame(self, transport: picows.WSTransport, frame: picows.WSFrame): - if frame.msg_type == picows.WSMsgType.CLOSE: - return - super().on_ws_frame(transport, frame) - - -async def test_recv_rejects_unexpected_continuation_and_rsv1(): - async with WSServer(lambda _: ContinuationOnConnect()) as server: - async with websockets.connect(server.url, compression=None, ping_interval=None) as ws: - with pytest.raises(websockets.ConnectionClosedError): - await ws.recv() - - async with WSServer(lambda _: Rsv1OnConnect()) as server: - async with websockets.connect(server.url, compression=None, ping_interval=None) as ws: - with pytest.raises(websockets.ConnectionClosedError): - await ws.recv() - - -async def test_recv_rejects_bad_continuation_and_too_large_message(): - async with WSServer(lambda _: BadContinuationSequenceOnConnect()) as server: - async with websockets.connect(server.url, compression=None, ping_interval=None) as ws: - with pytest.raises(websockets.ConnectionClosedError): - await ws.recv() - - async with WSServer(lambda _: SendLargeTextOnConnect()) as server: - async with websockets.connect(server.url, compression=None, ping_interval=None, max_size=2) as ws: - with pytest.raises(websockets.ConnectionClosedError): - await ws.recv() - - -async def test_recv_streaming_waits_for_later_fragment(): - async with WSServer(lambda _: DelayedFragmentedTextOnConnect()) as server: - async with websockets.connect(server.url, compression=None, ping_interval=None) as ws: - fragments = [] - async for fragment in ws.recv_streaming(): - fragments.append(fragment) - assert fragments == ["he", "llo"] - - -async def test_fragmented_message_exceeding_max_size_closes_connection(): - async with WSServer(lambda _: FragmentedTextOnConnect()) as server: - async with websockets.connect(server.url, compression=None, ping_interval=None, max_size=4) as ws: - with pytest.raises(websockets.ConnectionClosedError): - await ws.recv() - - -async def test_send_and_ping_validation_branches(): - async with WSServer() as server: - async with websockets.connect(server.url, compression=None, ping_interval=None) as ws: - assert ws.state is websockets.State.OPEN - assert ws.local_address[0] == "127.0.0.1" - assert ws.remote_address[0] == "127.0.0.1" - assert ws.latency == 0 - assert ws.subprotocol is None - assert ws.close_code is None - assert ws.close_reason is None - - default_waiter = await ws.ping() - await asyncio.wait_for(default_waiter, 1.0) - - waiter = await ws.ping("same") - with pytest.raises(websockets.ConcurrencyError, match="same data"): - await ws.ping(b"same") - await asyncio.wait_for(waiter, 1.0) - - with pytest.raises(TypeError, match="ping payload"): - await ws.ping(object()) # type: ignore[arg-type] - with pytest.raises(TypeError, match="unsupported type"): - await ws.send(object()) # type: ignore[arg-type] - with pytest.raises(TypeError, match="same category"): - await ws.send([b"bytes", "text"]) - - -async def test_send_text_overrides_and_concurrent_send_waits_for_turn(): - async with WSServer() as server: - async with websockets.connect(server.url, compression=None, ping_interval=None) as ws: - await ws.send("as-binary", text=False) - assert await ws.recv(decode=False) == b"as-binary" - - await ws.send(b"as-text", text=True) - assert await ws.recv() == "as-text" - - first_sent = asyncio.Event() - release = asyncio.Event() - - async def fragments(): - yield b"first" - first_sent.set() - await release.wait() - yield b"second" - - first_send = asyncio.create_task(ws.send(fragments())) - await asyncio.wait_for(first_sent.wait(), 1.0) - - second_send = asyncio.create_task(ws.send(b"after")) - await asyncio.sleep(0) - assert not second_send.done() - - release.set() - await asyncio.wait_for(first_send, 1.0) - await asyncio.wait_for(second_send, 1.0) - assert await ws.recv() == b"firstsecond" - assert await ws.recv() == b"after" - - -async def test_send_string_fragments_and_write_pause_wait_paths(): - async with WSServer() as server: - async with websockets.connect(server.url, compression=None, ping_interval=None) as ws: - await ws.send(["he", "llo"]) - assert await ws.recv() == "hello" - - ws.pause_writing() - single_send = asyncio.create_task(ws.send(b"paused")) - await asyncio.sleep(0) - assert not single_send.done() - ws.resume_writing() - await asyncio.wait_for(single_send, 1.0) - assert await ws.recv() == b"paused" - - ws.pause_writing() - fragmented_send = asyncio.create_task(ws.send([b"frag", b"mented"])) - await asyncio.sleep(0) - assert not fragmented_send.done() - ws.resume_writing() - await asyncio.wait_for(fragmented_send, 1.0) - assert await ws.recv() == b"fragmented" - - -async def test_send_rejects_invalid_first_fragment_and_closes(): - async with WSServer() as server: - async with websockets.connect(server.url, compression=None, ping_interval=None) as ws: - with pytest.raises(TypeError, match="message must contain"): - await ws.send([object()]) # type: ignore[list-item] - await asyncio.wait_for(ws.wait_closed(), 1.0) - assert ws.state is websockets.State.CLOSED - - -async def test_connection_context_manager_and_close_timeout_none(): - async with WSServer() as server: - ws = await websockets.connect(server.url, compression=None, ping_interval=None, close_timeout=None) - async with ws: - await ws.send("context") - assert await ws.recv() == "context" - assert ws.state is websockets.State.CLOSED - - -async def test_close_timeout_disconnects_when_peer_ignores_close(): - async with WSServer(lambda _: IgnoreCloseListener()) as server: - ws = await websockets.connect( - server.url, - compression=None, - ping_interval=None, - close_timeout=0.01, - ) - await ws.close() - assert ws.state is websockets.State.CLOSED - assert ws.close_code == 1000 - assert ws.close_reason == "" - - -async def test_keepalive_loop_without_ping_timeout_sends_default_pings(): - async with WSServer(enable_auto_pong=True) as server: - async with websockets.connect( - server.url, - compression=None, - ping_interval=0.01, - ping_timeout=None, - ) as ws: - await asyncio.sleep(0.03) - assert ws.latency >= 0 - - -async def test_keepalive_loop_with_ping_timeout_observes_pong(): - async with WSServer(enable_auto_pong=True) as server: - async with websockets.connect( - server.url, - compression=None, - ping_interval=0.01, - ping_timeout=1.0, - ) as ws: - await asyncio.sleep(0.03) - assert ws.latency >= 0 - - -async def test_disconnect_without_close_frame_sets_error_close_state(): - async with WSServer() as server: - async with websockets.connect(server.url, compression=None, ping_interval=None) as ws: - await ws.send("disconnect_me_without_close_frame") - with pytest.raises(websockets.ConnectionClosedError): - await ws.recv() - assert ws.close_code is None - assert ws.close_reason is None - - -async def test_send_ping_and_pong_after_close_raise_connection_closed(): - async with WSServer() as server: - ws = await websockets.connect(server.url, compression=None, ping_interval=None) - await ws.close() - - with pytest.raises(websockets.ConnectionClosedOK): - await ws.send(b"closed") - with pytest.raises(websockets.ConnectionClosedOK): - await ws.ping() - with pytest.raises(websockets.ConnectionClosedOK): - await ws.pong() - - -def test_broadcast_validation_and_exception_group(): - with pytest.raises(TypeError, match="data must be str or bytes"): - websockets.broadcast([], object()) # type: ignore[arg-type] - - if sys.version_info[:2] < (3, 11): - with pytest.raises(ValueError, match="requires at least Python 3.11"): - websockets.broadcast([], "hello", raise_exceptions=True) - return - - class BrokenConnection: - state = websockets.State.OPEN - _send_in_progress = False - - def _encode_and_send(self, _msg_type, _message, _fin): - raise RuntimeError("broken") - - with pytest.raises(ExceptionGroup) as exc_info: - websockets.broadcast([BrokenConnection()], "hello", raise_exceptions=True) # type: ignore[list-item] - assert len(exc_info.value.exceptions) == 1 - - -def test_response_to_picows_supports_empty_body_and_status_alias(): - response = websockets.Response( - int(HTTPStatus.SWITCHING_PROTOCOLS), - HTTPStatus.SWITCHING_PROTOCOLS.phrase, - CIMultiDict({"X-Test": "yes"}), - bytearray(b"body"), - ) - - assert response.status == 101 - picows_response = response.to_picows() - assert picows_response.status is HTTPStatus.SWITCHING_PROTOCOLS - assert picows_response.headers["X-Test"] == "yes" - assert picows_response.body == b"body" From e544120a5722c7cc7fb30473fdc35b54687b4787 Mon Sep 17 00:00:00 2001 From: taras Date: Fri, 8 May 2026 20:36:56 +0200 Subject: [PATCH 56/57] Fix deps --- requirements-test.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements-test.txt b/requirements-test.txt index b6707b8..4bd372d 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -13,3 +13,4 @@ uvloop; sys_platform != 'win32' winloop; sys_platform == 'win32' tiny-proxy mypy +websockets \ No newline at end of file From 6f518ccb4574852343b396ca4073f24ceb48a059 Mon Sep 17 00:00:00 2001 From: taras Date: Fri, 8 May 2026 20:45:32 +0200 Subject: [PATCH 57/57] Address mypy issues --- picows/websockets/asyncio/client.py | 3 ++- picows/websockets/asyncio/connection.py | 17 +++++++----- picows/websockets/asyncio/server.py | 10 +++++-- picows/websockets/compat.py | 7 ++--- picows/websockets/typing.py | 36 +++++++++++++++++++------ 5 files changed, 52 insertions(+), 21 deletions(-) diff --git a/picows/websockets/asyncio/client.py b/picows/websockets/asyncio/client.py index e35afba..1e6a536 100644 --- a/picows/websockets/asyncio/client.py +++ b/picows/websockets/asyncio/client.py @@ -208,6 +208,7 @@ def listener_factory( ) try: + logger_name: Any = self.logger if self.logger is not None else getLogger("websockets.client") _transport, listener = await picows.ws_connect( listener_factory, self.uri, @@ -219,7 +220,7 @@ def listener_factory( extra_headers=extra_headers, proxy=proxy, socket_factory=socket_factory, - logger_name=self.logger if self.logger is not None else getLogger("websockets.client"), + logger_name=logger_name, **conn_kwargs, ) except picows.WSInvalidURL as exc: diff --git a/picows/websockets/asyncio/connection.py b/picows/websockets/asyncio/connection.py index d080d38..8c47d7f 100644 --- a/picows/websockets/asyncio/connection.py +++ b/picows/websockets/asyncio/connection.py @@ -31,7 +31,7 @@ InvalidHandshake, InvalidStatus, ) -from ..typing import BytesLike, Data, DataLike, LoggerLike, Subprotocol +from ..typing import BytesLike, Data, DataLike, LoggerLike, LoggerProtocol, Subprotocol # cached for performance @@ -246,7 +246,7 @@ def _normalize_watermarks( @cython.ccall -def _resolve_logger(logger: LoggerLike) -> Union[logging.Logger, logging.LoggerAdapter[Any]]: +def _resolve_logger(logger: LoggerLike) -> LoggerProtocol: if logger is None: return logging.getLogger("websockets.client") if isinstance(logger, str): @@ -268,7 +268,7 @@ def process_exception(exc: Exception) -> Optional[Exception]: @cython.cclass class ConnectionBase(WSListener): # type: ignore[misc] id: uuid.UUID - logger: Union[logging.Logger, logging.LoggerAdapter[Any]] + logger: LoggerProtocol transport: WSTransport _request: Request _response: Response @@ -587,7 +587,8 @@ async def recv(self, decode: Optional[bool] = None) -> Data: msg_type = frame.msg_type if frame.fin: - return self._decode_data(frame.payload, msg_type, decode) # type: ignore[no-any-return] + data: Data = self._decode_data(frame.payload, msg_type, decode) + return data frames = [frame] try: @@ -601,7 +602,8 @@ async def recv(self, decode: Optional[bool] = None) -> Data: payloads.append(frame.payload) payload = b"".join(payloads) - return self._decode_data(payload, msg_type, decode) # type: ignore[no-any-return] + data = self._decode_data(payload, msg_type, decode) + return data except asyncio.CancelledError: self._recv_queue.extendleft(reversed(frames)) raise @@ -719,7 +721,7 @@ async def _get_next_async_fragment(self, async_iterator: AsyncIterator[DataLike] # CANCELLATION: # User async iterator is also getting canceled. - data: DataLike = await anext(async_iterator) + data: DataLike = await async_iterator.__anext__() if not self._is_in_open_state(): await self._wait_close_and_raise() return data @@ -848,7 +850,8 @@ async def send( elif isinstance(message, Mapping): raise TypeError("data is a dict-like object") elif isinstance(message, (AsyncIterable, Iterable)): - await self._send_fragments(message, text) # type: ignore[arg-type] + fragments: Union[AsyncIterable[DataLike], Iterable[DataLike]] = message # type: ignore[assignment] + await self._send_fragments(fragments, text) else: raise TypeError(f"message has unsupported type {type(message).__name__}") finally: diff --git a/picows/websockets/asyncio/server.py b/picows/websockets/asyncio/server.py index 6728f18..c018d45 100644 --- a/picows/websockets/asyncio/server.py +++ b/picows/websockets/asyncio/server.py @@ -26,6 +26,11 @@ from ..exceptions import ConcurrencyError, InvalidHandshake, InvalidOrigin from ..typing import DataLike, LoggerLike, Origin, Subprotocol +if sys.version_info >= (3, 11): + from builtins import ExceptionGroup +else: + ExceptionGroup = Exception + __all__ = [ "ServerConnection", "ServerHandshakeConnection", @@ -221,7 +226,7 @@ def process_request( return process_request -@dataclass(slots=True) +@dataclass class ServerHandshakeConnection: request: Request username: Optional[str] = None @@ -486,6 +491,7 @@ def listener_factory( else: return picows.WSUpgradeResponseWithListener(response.to_picows(), None) + logger_name: Any = self.logger if self.logger is not None else getLogger("websockets.server") raw_server = await picows.ws_create_server( listener_factory, self.host, @@ -494,7 +500,7 @@ def listener_factory( enable_auto_ping=False, enable_auto_pong=True, max_frame_size=max_frame_size, - logger_name=self.logger if self.logger is not None else getLogger("websockets.server"), + logger_name=logger_name, **self.kwargs, ) server.wrap(raw_server) diff --git a/picows/websockets/compat.py b/picows/websockets/compat.py index 55b2e0a..8c1b1a7 100644 --- a/picows/websockets/compat.py +++ b/picows/websockets/compat.py @@ -3,6 +3,7 @@ from dataclasses import dataclass from enum import IntEnum from http import HTTPStatus +from typing import Union import picows from multidict import CIMultiDict @@ -17,7 +18,7 @@ class State(IntEnum): CLOSED = 3 -@dataclass(slots=True) +@dataclass class Request: path: str headers: CIMultiDict[str] @@ -30,12 +31,12 @@ def from_picows(cls, request: picows.WSUpgradeRequest) -> Request: ) -@dataclass(slots=True) +@dataclass class Response: status_code: int reason_phrase: str headers: CIMultiDict[str] - body: bytes | bytearray + body: Union[bytes, bytearray] @classmethod def from_picows(cls, response: picows.WSUpgradeResponse) -> Response: diff --git a/picows/websockets/typing.py b/picows/websockets/typing.py index 8c0b5d4..c70f878 100644 --- a/picows/websockets/typing.py +++ b/picows/websockets/typing.py @@ -1,21 +1,40 @@ from __future__ import annotations -import logging from http import HTTPStatus -from typing import Any +from typing import Any, Optional, Protocol, Tuple, Union from picows.types import WSHeadersLike -BytesLike = bytes | bytearray | memoryview -Data = str | bytes -DataLike = str | bytes | bytearray | memoryview +BytesLike = Union[bytes, bytearray, memoryview] +Data = Union[str, bytes] +DataLike = Union[str, bytes, bytearray, memoryview] HeadersLike = WSHeadersLike -LoggerLike = logging.Logger | logging.LoggerAdapter[Any] | str | None -StatusLike = HTTPStatus | int + + +class LoggerProtocol(Protocol): + @property + def debug(self) -> Any: + ... + + @property + def info(self) -> Any: + ... + + @property + def warning(self) -> Any: + ... + + @property + def error(self) -> Any: + ... + + +LoggerLike = Union[LoggerProtocol, str, None] +StatusLike = Union[HTTPStatus, int] Origin = str Subprotocol = str ExtensionName = str -ExtensionParameter = tuple[str, str | None] +ExtensionParameter = Tuple[str, Optional[str]] __all__ = [ "BytesLike", @@ -25,6 +44,7 @@ "ExtensionParameter", "HeadersLike", "LoggerLike", + "LoggerProtocol", "Origin", "StatusLike", "Subprotocol",