diff --git a/AGENTS.md b/AGENTS.md index 921cb72..9a9a531 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -2,18 +2,46 @@ Read README.md for the basic understanding of what this project is. -picows - this is the main package +picows - this is the main package. +picows.websockets - reimplements popular websockets library interface on top of picows +tests - Contains tests for picows +examples - Various examples for users on how to use picows + perf_test that could be used to build call-graph with perf -aiofastnet - Contains optimized versions of asyncio create_connection, create_server -I plan to make a separate python project for it, but I'm not there yet. It should be treated -as a separate python project. It will have its own tests, eventually its own description and docs. -The project contains very efficient repimplementation of SelectSocketTransport and SSLProtocol -using Cython and sometimes a pure C code. create_connection, create_server are defined in aiofastnet/api.py -sslproto.pyx - hack python SSLContext to get raw SSL_CTX*, it works with openssl api directly after that. -sslproto_stdlib.pyx - is just for reference, I will delete it soon, but now it's good for comparison between -stdlib ssl and whatever is in sslproto.pyx. - -tests - Contains tests for both picows and aiofastnet. Tests for aiofastnet will become a part of a separate project. +## Code style notes +- Max line width is 120 +- Do not write `del transport` or similar `del ` statements inside callbacks just to mark arguments as unused. + Leave unused callback parameters as-is or rename them with a leading underscore if that is clearer. + Using `del` in this situation is confusing and suggests reference-counting or lifetime management concerns. +- Prefer direct composition only when there is a real behavioral boundary. + Do not introduce adapter / holder / deferred-event plumbing just to preserve a conceptual separation. + If extra machinery exists only to work around the separation you introduced, the separation is probably wrong. +- Do not model impossible or non-normal internal states in the mainline code path without a concrete reason. + If an invariant is guaranteed by control flow, write the code around that invariant instead of adding repeated defensive checks. + Every extra "just in case" branch teaches the reader that the state is part of normal behavior. + Add such checks only for real risks like external misuse, concurrency races, partial failure, or invariants that are genuinely hard to guarantee. + If the only reason for the check is uncertainty in the design, fix the design first. +- When simplifying code, finish the simplification across all equivalent branches, not only at the first local site. + If the same conversion, check, or tiny code pattern appears in multiple sibling paths after a refactor, stop and normalize it before considering the work done. + Do not remove one layer of abstraction only to inline the same logic redundantly in several places. + After a refactor, scan for duplicated branch bodies and duplicated type-specific handling introduced by the change. +- `picows.websockets` aims for import-level compatibility with the official `websockets` package on the client side. + We can skip complicated areas such as the full server interface, but simple surface-area compatibility matters. + Type definitions, exception definitions, and other lightweight importable names should exist when upstream exposes them. + People switching from `websockets` to `picows.websockets` should notice as little difference as possible. +- In Cythonized Python modules, avoid `typing.cast(...)` in hot paths. + Cython may compile `cast(...)` into a real runtime global lookup and function call instead of erasing it like a type checker would. + Prefer control-flow narrowing, assertions, or narrowly scoped type-ignore comments when needed. +- If `picows` core exposes an inconsistent runtime shape or behavior that looks like a bug, do not silently normalize around it in wrapper code. + Stop and ask first, or at least clearly call out that it appears to be a core bug instead of assuming it is an intentional quirk. + Wrapper-level workarounds for such inconsistencies should be treated as temporary and explicit, not as the default resolution. + Legitimate intentional quirks can be documented in this file separately once confirmed. +- `WSUpgradeRequest` / `WSUpgradeResponse` expose a mixed bytes/str API and this is public API. + Request `method`, `path`, `version` and response `version` are low-level protocol bytes, while headers are decoded strings and response `status` is `HTTPStatus`. + Do not change this shape casually in core or silently normalize it away in wrappers; treat it as a stable compatibility constraint unless an intentional breaking change is agreed. +- In `picows` core, once a CLOSE frame has been sent, later send-side API calls are effectively no-ops. + This applies to `send_close()` as well as the other send methods. + Also, `disconnect()` and `wait_disconnected()` are safe to call multiple times. + Wrapper code should rely on these idempotency guarantees instead of adding its own state-based suppression around shutdown operations. ## Testing instructions - Run lint after updating code with: diff --git a/docs/source/reference.rst b/docs/source/reference.rst index c4ec2f7..f358434 100644 --- a/docs/source/reference.rst +++ b/docs/source/reference.rst @@ -146,8 +146,6 @@ Classes .. py:attribute:: payload_size :type: size_t - **Available only from Cython.** - Size of the payload. .. autoclass:: WSUpgradeRequest diff --git a/examples/echo_client_websockets.py b/examples/echo_client_websockets.py new file mode 100644 index 0000000..d43c6df --- /dev/null +++ b/examples/echo_client_websockets.py @@ -0,0 +1,14 @@ +import asyncio + +from picows import websockets + + +async def main(): + async with websockets.connect("ws://127.0.0.1:9001") as websocket: + await websocket.send("Hello world") + reply = await websocket.recv() + print(f"Echo reply: {reply}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/picows/picows.pxd b/picows/picows.pxd index 4034082..3836ec3 100644 --- a/picows/picows.pxd +++ b/picows/picows.pxd @@ -67,7 +67,7 @@ cdef class MemoryBuffer: cdef class WSFrame: cdef: char* payload_ptr - Py_ssize_t payload_size + readonly Py_ssize_t payload_size readonly Py_ssize_t tail_size readonly WSMsgType msg_type readonly uint8_t fin diff --git a/picows/picows.pyx b/picows/picows.pyx index dbe8653..6247df5 100644 --- a/picows/picows.pyx +++ b/picows/picows.pyx @@ -47,7 +47,7 @@ class _NotImplemented(Exception): pass -# "unlikely" works only gcc, but still nice to have +# "unlikely" works only for gcc, but still nice to have # https://github.com/cython/cython/issues/7667 cdef extern from *: cdef bint unlikely(bint val) noexcept @@ -668,6 +668,8 @@ cdef class WSTransport: if self.close_handshake is None: self.close_handshake = WSCloseHandshake.__new__(WSCloseHandshake) + self.close_handshake.recv = None + self.close_handshake.sent = None self.close_handshake.recv_then_sent = False self.close_handshake.sent = WSCloseInfo.__new__(WSCloseInfo) @@ -1648,7 +1650,8 @@ cdef class WSProtocol(WSProtocolBase, asyncio.BufferedProtocol): self._f_payload_start_pos = self._f_curr_state_start_pos self._state = WSParserState.READ_PAYLOAD - if self._f_payload_length > self._max_frame_size: + if (self._f_payload_length > self._max_frame_size and + self._f_msg_type not in (WSMsgType.PING, WSMsgType.PONG, WSMsgType.CLOSE)): raise WSProtocolError( WSCloseCode.MESSAGE_TOO_BIG, f"Received frame with payload size exceeding max allowed size, " diff --git a/picows/websockets/__init__.py b/picows/websockets/__init__.py new file mode 100644 index 0000000..9cc5646 --- /dev/null +++ b/picows/websockets/__init__.py @@ -0,0 +1,100 @@ +from . import exceptions +from .asyncio.client import connect +from .asyncio.connection import ClientConnection, ServerConnection, process_exception +from .asyncio.router import route +from .asyncio.server import Server, ServerHandshakeConnection, basic_auth, broadcast, serve +from .compat import CloseCode, Request, Response, State +from .exceptions import ( + ConcurrencyError, + ConnectionClosed, + ConnectionClosedError, + ConnectionClosedOK, + DuplicateParameter, + InvalidHandshake, + InvalidHeader, + InvalidHeaderFormat, + InvalidHeaderValue, + InvalidMessage, + InvalidOrigin, + InvalidParameterName, + InvalidParameterValue, + InvalidProxy, + InvalidProxyMessage, + InvalidProxyStatus, + InvalidState, + InvalidStatus, + InvalidUpgrade, + InvalidURI, + NegotiationError, + PayloadTooBig, + ProtocolError, + ProxyError, + SecurityError, + WebSocketException, +) +from .typing import ( + BytesLike, + Data, + DataLike, + ExtensionName, + ExtensionParameter, + HeadersLike, + LoggerLike, + Origin, + StatusLike, + Subprotocol, +) + +__all__ = [ + "BytesLike", + "ClientConnection", + "CloseCode", + "Data", + "DataLike", + "Server", + "ServerHandshakeConnection", + "ServerConnection", + "ConcurrencyError", + "ConnectionClosed", + "ConnectionClosedError", + "ConnectionClosedOK", + "DuplicateParameter", + "ExtensionName", + "ExtensionParameter", + "HeadersLike", + "InvalidHandshake", + "InvalidHeader", + "InvalidHeaderFormat", + "InvalidHeaderValue", + "InvalidMessage", + "InvalidOrigin", + "InvalidParameterName", + "InvalidParameterValue", + "InvalidProxy", + "InvalidProxyMessage", + "InvalidProxyStatus", + "InvalidState", + "InvalidStatus", + "InvalidUpgrade", + "InvalidURI", + "LoggerLike", + "NegotiationError", + "Origin", + "PayloadTooBig", + "ProtocolError", + "ProxyError", + "Request", + "Response", + "SecurityError", + "State", + "StatusLike", + "Subprotocol", + "WebSocketException", + "basic_auth", + "broadcast", + "connect", + "exceptions", + "process_exception", + "route", + "serve", +] diff --git a/picows/websockets/asyncio/__init__.py b/picows/websockets/asyncio/__init__.py new file mode 100644 index 0000000..21c4a21 --- /dev/null +++ b/picows/websockets/asyncio/__init__.py @@ -0,0 +1,18 @@ +from .client import connect +from .connection import ClientConnection, ServerConnection, process_exception +from .router import route +from .server import Server, ServerHandshakeConnection, basic_auth, broadcast, serve +from ..compat import State + +__all__ = [ + "ClientConnection", + "Server", + "ServerHandshakeConnection", + "ServerConnection", + "basic_auth", + "broadcast", + "connect", + "process_exception", + "route", + "serve", +] diff --git a/picows/websockets/asyncio/client.py b/picows/websockets/asyncio/client.py new file mode 100644 index 0000000..1e6a536 --- /dev/null +++ b/picows/websockets/asyncio/client.py @@ -0,0 +1,264 @@ +from __future__ import annotations + +import asyncio +import socket +import sys +from collections.abc import Generator +from logging import getLogger +from ssl import SSLContext +from typing import Any, Callable, Optional, Sequence, Union + +import picows +from picows.url import parse_url + +from .connection import ( + ClientConnection, + process_exception, +) +from .negotiation import configure_permessage_deflate, resolve_subprotocol +from ..compat import Request, Response +from ..exceptions import ( + InvalidHandshake, + InvalidHeader, + InvalidMessage, + InvalidProxy, + InvalidStatus, + InvalidUpgrade, + InvalidURI, +) +from ..typing import HeadersLike, LoggerLike, Origin, Subprotocol + +__all__ = [ + "ClientConnection", + "connect", +] + + +_PERMESSAGE_DEFLATE_REQUEST = "permessage-deflate; client_max_window_bits" + + +def _default_user_agent() -> str: + return f"Python/{sys.version_info.major}.{sys.version_info.minor} picows-websockets/0" + + +def _header_items(headers: Any) -> list[tuple[str, str]]: + return [] if headers is None else list(headers.items()) + + +def _process_proxy(proxy: Union[str, bool, None], secure: bool) -> Optional[str]: + from urllib.request import getproxies + if proxy is None: + return None + if isinstance(proxy, str): + return proxy + if proxy is True: + proxies = getproxies() + return ( + proxies.get("wss" if secure else "ws") + or proxies.get("https" if secure else "http") + ) + raise InvalidProxy(str(proxy), "proxy must be None, True, or a proxy URL") + + +class _Connect: + def __init__( + self, + uri: str, + *, + origin: Optional[Origin] = None, + extensions: Optional[Sequence[Any]] = None, + subprotocols: Optional[Sequence[Subprotocol]] = None, + compression: Optional[str] = "deflate", + additional_headers: Optional[HeadersLike] = None, + user_agent_header: Optional[str] = _default_user_agent(), + proxy: Union[str, bool, None] = True, + process_exception: Callable[[Exception], Optional[Exception]] = process_exception, + open_timeout: Optional[float] = 10, + ping_interval: Optional[float] = 20, + ping_timeout: Optional[float] = 20, + close_timeout: Optional[float] = 10, + max_size: Optional[int] = 1024 * 1024, + max_queue: Union[int, tuple[Optional[int], Optional[int]], None] = 16, + write_limit: Union[int, tuple[int, Optional[int]]] = 32768, + logger: LoggerLike = None, + **kwargs: Any, + ): + self.uri = uri + self.origin = origin + self.extensions = extensions + self.subprotocols = subprotocols + self.compression = compression + self.additional_headers = additional_headers + self.user_agent_header = user_agent_header + self.proxy = proxy + self.process_exception = process_exception + self.open_timeout = open_timeout + self.ping_interval = ping_interval + self.ping_timeout = ping_timeout + self.close_timeout = close_timeout + self.max_size = max_size + self.max_queue = max_queue + self.write_limit = write_limit + self.logger = logger + if "create_connection" in kwargs: + raise NotImplementedError("create_connection isn't supported by picows.websockets yet") + self.kwargs = kwargs + self._connection: Optional[ClientConnection] = None + self._backoff = 1.0 + + def __await__(self) -> Generator[Any, None, ClientConnection]: + return self._connect().__await__() + + async def __aenter__(self) -> ClientConnection: + self._connection = await self._connect() + return self._connection + + async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None: + if self._connection is not None: + await self._connection.close() + self._connection = None + + def __aiter__(self) -> _Connect: + return self + + async def __anext__(self) -> ClientConnection: + if self._connection is not None: + await self._connection.close() + self._connection = None + while True: + try: + connection = await self._connect() + except Exception as exc: + processed = self.process_exception(exc) + if processed is not None: + raise processed + await asyncio.sleep(self._backoff) + self._backoff = min(self._backoff * 2, 60.0) + continue + self._connection = connection + self._backoff = 1.0 + return connection + + async def _connect(self) -> ClientConnection: + parsed = parse_url(self.uri) + proxy = _process_proxy(self.proxy, parsed.is_secure) + extra_headers = self._build_headers() + max_message_size = 0 if self.max_size is None else self.max_size + max_frame_size = 2 ** 31 - 1 if not self.max_size else self.max_size + + if self.extensions is not None: + raise NotImplementedError("custom extensions aren't supported by picows.websockets") + if self.compression not in (None, "deflate"): + raise NotImplementedError("only compression=None or 'deflate' are accepted") + + conn_kwargs = dict(self.kwargs) + ssl_context = conn_kwargs.pop("ssl", None) + host_override = conn_kwargs.pop("host", None) + port_override = conn_kwargs.pop("port", None) + preexisting_sock = conn_kwargs.pop("sock", None) + + socket_factory = conn_kwargs.pop("socket_factory", None) + if preexisting_sock is not None: + if socket_factory is not None: + raise TypeError("cannot pass both sock and socket_factory") + if not isinstance(preexisting_sock, socket.socket): + raise TypeError("sock must be a socket.socket instance") + + provided_sock = preexisting_sock + + def provided_socket(_: Any) -> socket.socket: + return provided_sock + + socket_factory = provided_socket + elif host_override is not None or port_override is not None: + if socket_factory is not None: + raise TypeError("cannot pass both host/port override and socket_factory") + + async def connect_override(_: Any) -> socket.socket: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.setblocking(False) + await asyncio.get_running_loop().sock_connect( + sock, + (host_override or parsed.host, port_override or parsed.port), + ) + return sock + + socket_factory = connect_override + + def listener_factory( + request: picows.WSUpgradeRequest, + response: picows.WSUpgradeResponse, + ) -> ClientConnection: + wrapped_request = Request.from_picows(request) + wrapped_response = Response.from_picows(response) + subprotocol = resolve_subprotocol(self.subprotocols, wrapped_response) + permessage_deflate = configure_permessage_deflate(wrapped_response, self.compression) + return ClientConnection( + request=wrapped_request, + response=wrapped_response, + subprotocol=subprotocol, + permessage_deflate=permessage_deflate, + ping_interval=self.ping_interval, + ping_timeout=self.ping_timeout, + close_timeout=self.close_timeout, + max_queue=self.max_queue, + write_limit=self.write_limit, + max_message_size=max_message_size, + logger=self.logger, + ) + + try: + logger_name: Any = self.logger if self.logger is not None else getLogger("websockets.client") + _transport, listener = await picows.ws_connect( + listener_factory, + self.uri, + ssl_context=self._coerce_ssl_context(ssl_context), + websocket_handshake_timeout=self.open_timeout, + enable_auto_ping=False, + enable_auto_pong=True, + max_frame_size=max_frame_size, + extra_headers=extra_headers, + proxy=proxy, + socket_factory=socket_factory, + logger_name=logger_name, + **conn_kwargs, + ) + except picows.WSInvalidURL as exc: + raise InvalidURI(exc.args[0], exc.args[1] if len(exc.args) > 1 else str(exc)) from exc + except picows.WSInvalidStatusError as exc: + raise InvalidStatus(exc.response) from exc + except picows.WSInvalidUpgradeError as exc: + raise InvalidUpgrade(exc.name, exc.value) from exc + except picows.WSInvalidHeaderError as exc: + raise InvalidHeader(exc.name, exc.value) from exc + except picows.WSInvalidMessageError as exc: + raise InvalidMessage(str(exc)) from exc + except picows.WSHandshakeError as exc: + raise InvalidHandshake(str(exc)) from exc + assert isinstance(listener, ClientConnection) + return listener + + def _build_headers(self) -> list[tuple[str, str]]: + headers = _header_items(self.additional_headers) + if self.origin is not None: + headers.append(("Origin", self.origin)) + if self.user_agent_header is not None: + headers.append(("User-Agent", self.user_agent_header)) + if self.subprotocols: + headers.append(("Sec-WebSocket-Protocol", ", ".join(self.subprotocols))) + if self.compression == "deflate": + headers.append(("Sec-WebSocket-Extensions", _PERMESSAGE_DEFLATE_REQUEST)) + return headers + + def _coerce_ssl_context(self, value: Any) -> Optional[SSLContext]: + if value in (None, True): + return None + if value is False: + raise NotImplementedError("ssl=False isn't supported for wss:// URIs") + if not isinstance(value, SSLContext): + raise TypeError("ssl must be an SSLContext, True, False, or None") + return value + + +def connect(uri: str, **kwargs: Any) -> _Connect: + return _Connect(uri, **kwargs) diff --git a/picows/websockets/asyncio/connection.py b/picows/websockets/asyncio/connection.py new file mode 100644 index 0000000..8c47d7f --- /dev/null +++ b/picows/websockets/asyncio/connection.py @@ -0,0 +1,1095 @@ +from __future__ import annotations + +import asyncio +import logging +import os +import sys +import uuid +import zlib +from collections import deque +from collections.abc import AsyncIterable, Iterable +from time import monotonic +from typing import Any, AsyncIterator, Awaitable, Optional, Sequence, \ + Union, Dict, Tuple, Iterator, Mapping, NoReturn + +import cython + +if cython.compiled: + from cython.cimports.picows.picows import WSListener, WSTransport, WSFrame, \ + WSMsgType, WSCloseCode +else: + from picows import WSListener, WSTransport, WSFrame, WSMsgType, WSCloseCode + +from picows import WSProtocolError + +from ..compat import State, CloseCode, Request, Response +from ..exceptions import ( + ConcurrencyError, + ConnectionClosed, + ConnectionClosedError, + ConnectionClosedOK, + InvalidHandshake, + InvalidStatus, +) +from ..typing import BytesLike, Data, DataLike, LoggerLike, LoggerProtocol, Subprotocol + + +# cached for performance +_ok_close_codes = cython.declare(set, {0, 1000, 1001}) +_asyncio_shield = cython.declare(object, asyncio.shield) + +# zlib/compress/decompress utils, cached for performance +_empty_uncompressed_block = cython.declare(bytes, b"\x00\x00\xff\xff") +_zlib_compressobj = cython.declare(object, zlib.compressobj) +_zlib_decompressobj = cython.declare(object, zlib.decompressobj) +_zlib_z_sync_flush = cython.declare(object, zlib.Z_SYNC_FLUSH) + + +@cython.freelist(128) +@cython.no_gc +@cython.cclass +class _BufferedFrame: + msg_type: WSMsgType + payload: bytes + fin: cython.bint + + +@cython.cfunc +@cython.inline +def _make_buffered_frame(msg_type: WSMsgType, payload: bytes, fin: cython.bint) -> _BufferedFrame: + self: _BufferedFrame = _BufferedFrame.__new__(_BufferedFrame) + self.msg_type = msg_type + self.payload = payload + self.fin = fin + return self + + +@cython.no_gc +@cython.cclass +class _PerMessageDeflate: + remote_no_context_takeover: cython.bint + local_no_context_takeover: cython.bint + remote_max_window_bits: int + local_max_window_bits: int + _decoder: Any + _encoder: Any + _decode_cont_data: cython.bint + + @classmethod + def from_response_header(cls, header_value: str) -> _PerMessageDeflate: + extensions = [item.strip() for item in header_value.split(",") if item.strip()] + if len(extensions) != 1: + raise InvalidHandshake("unsupported websocket extension negotiation") + + parts = [item.strip() for item in extensions[0].split(";")] + if not parts or parts[0] != "permessage-deflate": + raise InvalidHandshake("unsupported websocket extension negotiation") + + server_no_context_takeover = False + client_no_context_takeover = False + server_max_window_bits = None + client_max_window_bits = None + seen = set() + + for raw_param in parts[1:]: + if not raw_param: + continue + if "=" in raw_param: + name, value = raw_param.split("=", 1) + name = name.strip() + value = value.strip() + else: + name = raw_param + value = None + + if name in seen: + raise InvalidHandshake( + f"unsupported websocket extension negotiation: {name}") + seen.add(name) + + if name == "server_no_context_takeover": + if value is not None: + raise InvalidHandshake("invalid server_no_context_takeover value") + server_no_context_takeover = True + elif name == "client_no_context_takeover": + if value is not None: + raise InvalidHandshake("invalid client_no_context_takeover value") + client_no_context_takeover = True + elif name == "server_max_window_bits": + if value is None or not value.isdigit(): + raise InvalidHandshake("invalid server_max_window_bits value") + server_max_window_bits = int(value) + if not 8 <= server_max_window_bits <= 15: + raise InvalidHandshake("invalid server_max_window_bits value") + elif name == "client_max_window_bits": + if value is None or not value.isdigit(): + raise InvalidHandshake("invalid client_max_window_bits value") + client_max_window_bits = int(value) + if not 8 <= client_max_window_bits <= 15: + raise InvalidHandshake("invalid client_max_window_bits value") + else: + raise InvalidHandshake(f"unsupported extension parameter: {name}") + + self: _PerMessageDeflate = _PerMessageDeflate.__new__(_PerMessageDeflate) + self.remote_no_context_takeover = server_no_context_takeover + self.local_no_context_takeover = client_no_context_takeover + self.remote_max_window_bits = -(server_max_window_bits or 15) + self.local_max_window_bits = -(client_max_window_bits or 15) + self._decoder = None + self._encoder = None + self._decode_cont_data = False + + # wbits: +9 to +15 + # The base-two logarithm of the window size, which therefore ranges between 512 and 32768. + # Larger values produce better compression at the expense of greater memory usage. + # The resulting output will include a zlib-specific header and trailer. + # Negative wbits: + # Uses the absolute value of wbits as the window size logarithm, + # while producing a raw output stream with no header or trailing checksum. + + if not self.remote_no_context_takeover: + self._decoder = _zlib_decompressobj(wbits=self.remote_max_window_bits) + if not self.local_no_context_takeover: + self._encoder = _zlib_compressobj(wbits=self.local_max_window_bits) + + return self + + @cython.cfunc + @cython.inline + def decode_frame(self, frame: WSFrame, max_length: cython.Py_ssize_t) -> bytes: + data: bytes + data2: bytes + + if frame.msg_type == WSMsgType.CONTINUATION: + if frame.rsv1: + raise WSProtocolError(WSCloseCode.PROTOCOL_ERROR, "unexpected rsv1 on continuation frame") + if not self._decode_cont_data: + return frame.get_payload_as_bytes() # type: ignore[no-any-return] + if frame.fin: + self._decode_cont_data = False + else: + if not frame.rsv1: + return frame.get_payload_as_bytes() # type: ignore[no-any-return] + if not frame.fin: + self._decode_cont_data = True + if self.remote_no_context_takeover or self._decoder is None: + self._decoder = _zlib_decompressobj(wbits=self.remote_max_window_bits) + + try: + data = self._decoder.decompress(frame.get_payload_as_memoryview(), max_length) + if max_length > 0: + max_length -= len(data) + + if self._decoder.unconsumed_tail: + raise WSProtocolError(WSCloseCode.MESSAGE_TOO_BIG, + "message too big") + + if frame.fin: + data2 = self._decoder.decompress(_empty_uncompressed_block, max_length) + if data2: + data += data2 + except zlib.error as exc: + raise WSProtocolError(WSCloseCode.PROTOCOL_ERROR, + "decompression failed") from exc + + if frame.fin and self.remote_no_context_takeover: + self._decoder = None + + return data + + @cython.cfunc + @cython.inline + @cython.wraparound(True) + def encode_frame(self, msg_type: WSMsgType, data: DataLike, fin: cython.bint) -> BytesLike: + if msg_type != WSMsgType.CONTINUATION and (self.local_no_context_takeover or self._encoder is None): + self._encoder = _zlib_compressobj(wbits=self.local_max_window_bits) + + if isinstance(data, str): + data = cython.cast(str, data).encode('utf-8') + + compressed_data: BytesLike = (self._encoder.compress(data) + + self._encoder.flush(_zlib_z_sync_flush)) + if fin: + mv = memoryview(compressed_data) + assert mv[-4:] == _empty_uncompressed_block + compressed_data = mv[:-4] + if self.local_no_context_takeover: + self._encoder = None + + return compressed_data + + +@cython.cfunc +@cython.inline +def _coerce_close_code(code: CloseCode) -> Optional[int]: + return None if code is None else code # type: ignore[return-value] + + +@cython.cfunc +@cython.inline +def _coerce_close_reason(reason: Optional[str]) -> Optional[str]: + return reason if reason is not None else None + +@cython.cfunc +@cython.inline +def _normalize_watermarks( + max_queue: Union[int, tuple[Optional[int], Optional[int]], None], +) -> tuple[cython.Py_ssize_t, cython.Py_ssize_t]: + if max_queue is None: + return 0, 0 + if isinstance(max_queue, tuple): + high, low = max_queue + if high is None: + return 0, 0 + return high, high // 4 if low is None else low + return max_queue, max_queue // 4 + + +@cython.ccall +def _resolve_logger(logger: LoggerLike) -> LoggerProtocol: + if logger is None: + return logging.getLogger("websockets.client") + if isinstance(logger, str): + return logging.getLogger(logger) + return logger + + +@cython.ccall +def process_exception(exc: Exception) -> Optional[Exception]: + if isinstance(exc, (EOFError, OSError, asyncio.TimeoutError)): + return None + if isinstance(exc, InvalidStatus): + status = exc.response.status_code + if int(status) in {500, 502, 503, 504}: + return None + return exc + + +@cython.cclass +class ConnectionBase(WSListener): # type: ignore[misc] + id: uuid.UUID + logger: LoggerProtocol + transport: WSTransport + _request: Request + _response: Response + _subprotocol: Optional[Subprotocol] + _permessage_deflate: Optional[_PerMessageDeflate] + _loop: asyncio.AbstractEventLoop + + # Send side + _send_in_progress: cython.bint + _send_waiters: deque[asyncio.Future[None]] + _write_ready: Optional[asyncio.Future[None]] + _write_limit: Union[int, tuple[int, Optional[int]]] + + # Recv side + _recv_in_progress: cython.bint + _recv_streaming_broken: cython.bint + _paused_reading: cython.bint + _recv_waiter: Optional[asyncio.Future[None]] + _recv_queue: deque[_BufferedFrame] + _max_message_size: cython.Py_ssize_t # 0 - no limit + _max_queue_high: cython.Py_ssize_t # 0 - no limit + _max_queue_low: cython.Py_ssize_t # 0 - no limit + _incoming_message_active: cython.bint + _incoming_message_size: cython.Py_ssize_t + + # Close logic + _close_timeout: Optional[float] + _close_fut: asyncio.Future[None] + _close_exc: Optional[ConnectionClosed] + + _pending_pings: Dict[bytes, Tuple[asyncio.Future[float], float]] + _ping_interval: Optional[float] + _ping_timeout: Optional[float] + _keepalive_task: Optional[asyncio.Task[None]] + _latency: cython.double + + def __init__( + self, + *, + request: Request, + response: Response, + subprotocol: Optional[Subprotocol], + permessage_deflate: Optional[_PerMessageDeflate], + ping_interval: Optional[float] = 20, + ping_timeout: Optional[float] = 20, + close_timeout: Optional[float] = 10, + max_queue: Union[int, tuple[Optional[int], Optional[int]], None] = 16, + write_limit: Union[int, tuple[int, Optional[int]]] = 32768, + max_message_size: Optional[int] = 1024 * 1024, + logger: LoggerLike = None, + ): + self.id = uuid.uuid4() + self.logger = _resolve_logger(logger) + self.transport = cython.cast(WSTransport, None) + self._request = request + self._response = response + self._subprotocol = subprotocol + self._permessage_deflate = permessage_deflate + self._loop = asyncio.get_running_loop() + + self._send_in_progress = False + self._send_waiters = deque() + self._write_ready: Optional[asyncio.Future[None]] = None + self._write_limit = write_limit + + self._recv_in_progress = False + self._recv_streaming_broken = False + self._paused_reading = False + self._recv_waiter = None + self._recv_queue = deque() + self._max_message_size = 0 if max_message_size is None else max_message_size + self._max_queue_high, self._max_queue_low = _normalize_watermarks(max_queue) + self._incoming_message_active = False + self._incoming_message_size = 0 + + self._close_timeout = close_timeout + self._close_fut = self._loop.create_future() + self._close_exc: Optional[ConnectionClosed] = None + + self._pending_pings: dict[bytes, tuple[asyncio.Future[float], float]] = {} + self._ping_interval = ping_interval + self._ping_timeout = ping_timeout + self._keepalive_task: Optional[asyncio.Task[None]] = None + self._latency = 0.0 + + @cython.ccall + def on_ws_connected(self, transport: WSTransport) -> None: + raise NotImplementedError + + @cython.ccall + def on_ws_disconnected(self, transport: WSTransport) -> None: + # Set _close_exc, _close_fut + self._set_close_exception() + + # Pacify type checker + assert self._close_exc is not None + + # Wake up potential waiter on _recv_queue + if self._recv_waiter is not None: + if not self._recv_waiter.done(): + self._recv_waiter.set_exception(self._close_exc) + self._recv_waiter = None + + # Cancel pinging loop + if self._keepalive_task is not None: + self._keepalive_task.cancel() + self._keepalive_task = None + + # If there is a waiter waiting for resume_writing wake it up with exception + if self._write_ready is not None: + if not self._write_ready.done(): + self._write_ready.set_exception(self._close_exc) + self._write_ready = None + + # Wake up all waiters waiting for ping replies + for ping_waiter, _ in self._pending_pings.values(): + if not ping_waiter.done(): + ping_waiter.set_exception(self._close_exc) + self._pending_pings.clear() + + # Wake up all waiters waiting for current send to complete + for send_waiter in self._send_waiters: + if not send_waiter.done(): + send_waiter.set_exception(self._close_exc) + self._send_waiters.clear() + + @cython.ccall + def on_ws_frame(self, transport: WSTransport, frame: WSFrame) -> None: + if frame.msg_type == WSMsgType.PONG: + self._process_pong_frame(frame) + return + + if frame.msg_type == WSMsgType.CLOSE: + self._process_close_frame(frame) + return + + if frame.msg_type not in (WSMsgType.TEXT, WSMsgType.BINARY, WSMsgType.CONTINUATION): + raise WSProtocolError(WSCloseCode.PROTOCOL_ERROR, "unsupported frame opcode") + + if self._permessage_deflate is None and frame.rsv1: + raise WSProtocolError(WSCloseCode.PROTOCOL_ERROR, "received compressed frame without negotiated permessage-deflate") + + if frame.msg_type == WSMsgType.CONTINUATION and not self._incoming_message_active: + raise WSProtocolError(WSCloseCode.PROTOCOL_ERROR, "unexpected continuation frame") + + if frame.msg_type != WSMsgType.CONTINUATION and self._incoming_message_active: + raise WSProtocolError(WSCloseCode.PROTOCOL_ERROR, "expected continuation frame") + + payload: bytes + if self._permessage_deflate is not None: + remaining = 0 if self._max_message_size == 0 else ( + max(self._max_message_size - self._incoming_message_size, 0)) + payload = self._permessage_deflate.decode_frame(frame, remaining) + else: + payload = frame.get_payload_as_bytes() + + if frame.msg_type == WSMsgType.CONTINUATION: + self._incoming_message_size += len(payload) + else: + self._incoming_message_size = len(payload) + + if self._max_message_size > 0 and self._incoming_message_size > self._max_message_size: + raise WSProtocolError(WSCloseCode.MESSAGE_TOO_BIG, "message too big") + + if frame.msg_type == WSMsgType.CONTINUATION: + if frame.fin: + self._incoming_message_active = False + self._incoming_message_size = 0 + else: + if frame.fin: + self._incoming_message_size = 0 + else: + self._incoming_message_active = True + + self._add_to_recv_queue(_make_buffered_frame(frame.msg_type, payload, frame.fin)) + self._pause_reading_if_needed() + + @cython.ccall + def pause_writing(self) -> None: + if self._write_ready is None: + self._write_ready = self._loop.create_future() + + @cython.ccall + def resume_writing(self) -> None: + if self._write_ready is not None: + if not self._write_ready.done(): + self._write_ready.set_result(None) + self._write_ready = None + + @cython.cfunc + @cython.inline + def _set_write_limits(self, write_limit: Union[int, tuple[int, Optional[int]]]) -> None: + if isinstance(write_limit, tuple): + high, low = write_limit + else: + high, low = write_limit, None + self.transport.underlying_transport.set_write_buffer_limits(high=high, low=low) + + @cython.cfunc + @cython.inline + def _process_pong_frame(self, frame: WSFrame) -> None: + ping = self._pending_pings.pop(frame.get_payload_as_bytes(), None) + if ping is not None: + waiter, sent_at = ping + self._latency = monotonic() - sent_at + if not waiter.done(): + waiter.set_result(self._latency) + + @cython.cfunc + @cython.inline + def _process_close_frame(self, frame: WSFrame) -> None: + close_code = frame.get_close_code() + close_message = frame.get_close_message() + self.transport.send_close(close_code, close_message) + self.transport.disconnect() + + @cython.cfunc + @cython.inline + def _pause_reading_if_needed(self) -> None: + if self._max_queue_high > 0 and not self._paused_reading and len(self._recv_queue) >= self._max_queue_high: + self.transport.underlying_transport.pause_reading() + self._paused_reading = True + + @cython.cfunc + @cython.inline + def _resume_reading_if_needed(self) -> None: + if not self._paused_reading: + return + if self._max_queue_low == 0 or len(self._recv_queue) <= self._max_queue_low: + self.transport.underlying_transport.resume_reading() + self._paused_reading = False + + @cython.cfunc + @cython.inline + def _add_to_recv_queue(self, frame: _BufferedFrame) -> None: + self._recv_queue.append(frame) + waiter = self._recv_waiter + if waiter is not None: + self._recv_waiter = None + if not waiter.done(): + waiter.set_result(None) + + @cython.cfunc + @cython.inline + def _wait_recv_queue_not_empty(self) -> asyncio.Future[None]: + assert self._recv_waiter is None + if self._close_exc is not None: + raise self._close_exc + + waiter: asyncio.Future[None] = self._loop.create_future() + self._recv_waiter = waiter + return waiter + + @cython.cfunc + @cython.inline + def _set_close_exception(self) -> None: + handshake = self.transport.close_handshake + if handshake is None: + self._close_exc = ConnectionClosedError(None, None, None) + self._close_fut.set_result(None) + return + rcvd = handshake.recv + sent = handshake.sent + rcvd_then_sent = handshake.recv_then_sent + rcvd_code = _coerce_close_code(rcvd.code) if rcvd is not None else None + sent_code = _coerce_close_code(sent.code) if sent is not None else None + ok = ( + (rcvd_code in _ok_close_codes or rcvd_code is None) + and (sent_code in _ok_close_codes or sent_code is None) + ) + exc_type = ConnectionClosedOK if ok else ConnectionClosedError + self._close_exc = exc_type(rcvd, sent, rcvd_then_sent) + self._close_fut.set_result(None) + + @cython.cfunc + @cython.inline + def _set_recv_in_progress(self) -> None: + if self._recv_in_progress: + raise ConcurrencyError("cannot call recv() or recv_streaming() concurrently") + if self._recv_streaming_broken: + raise ConcurrencyError("recv_streaming() wasn't fully consumed") + self._recv_in_progress = True + + @cython.cfunc + @cython.inline + def _decode_data(self, payload: bytes, msg_type: WSMsgType, decode: Optional[bool]) -> Data: + if decode is True or (msg_type == WSMsgType.TEXT and decode is None): + return payload.decode("utf-8") + else: + return payload + + @cython.cfunc + @cython.inline + def _check_frame(self, frame: _BufferedFrame) -> _BufferedFrame: + self._resume_reading_if_needed() + return frame + + @cython.cfunc + @cython.inline + def _fail_invalid_data(self, exc: UnicodeDecodeError) -> None: + self.transport.send_close( + WSCloseCode.INVALID_TEXT, + f"{exc.reason} at position {exc.start}", + ) + self.transport.disconnect(False) + + async def recv(self, decode: Optional[bool] = None) -> Data: + frame: _BufferedFrame + + self._set_recv_in_progress() + + try: + if not self._recv_queue: + await self._wait_recv_queue_not_empty() + frame = self._check_frame(self._recv_queue.popleft()) + + msg_type = frame.msg_type + if frame.fin: + data: Data = self._decode_data(frame.payload, msg_type, decode) + return data + + frames = [frame] + try: + payloads = [frame.payload] + while not frame.fin: + if not self._recv_queue: + await self._wait_recv_queue_not_empty() + frame = self._check_frame(self._recv_queue.popleft()) + + frames.append(frame) + payloads.append(frame.payload) + + payload = b"".join(payloads) + data = self._decode_data(payload, msg_type, decode) + return data + except asyncio.CancelledError: + self._recv_queue.extendleft(reversed(frames)) + raise + except UnicodeDecodeError as exc: + self._fail_invalid_data(exc) + await self._wait_close_and_raise(exc) + finally: + self._recv_in_progress = False + self._recv_waiter = None + + def recv_streaming(self, decode: Optional[bool] = None) -> AsyncIterator[Data]: + msg_started: cython.bint = False + msg_finished: cython.bint = False + frame: _BufferedFrame + msg_type: WSMsgType + + async def iterator() -> AsyncIterator[Data]: + nonlocal msg_started, msg_finished + + try: + self._set_recv_in_progress() + if not self._recv_queue: + await self._wait_recv_queue_not_empty() + frame = self._check_frame(self._recv_queue.popleft()) + + msg_started = True + msg_type = frame.msg_type + yield self._decode_data(frame.payload, msg_type, decode) + + while not frame.fin: + if not self._recv_queue: + await self._wait_recv_queue_not_empty() + frame = self._check_frame(self._recv_queue.popleft()) + + yield self._decode_data(frame.payload, msg_type, decode) + msg_finished = True + except UnicodeDecodeError as exc: + self._fail_invalid_data(exc) + await self._wait_close_and_raise(exc) + finally: + self._recv_in_progress = False + self._recv_waiter = None + if msg_started and not msg_finished: + self._recv_streaming_broken = True + elif msg_finished: + self._recv_streaming_broken = False + + return iterator() + + @cython.cfunc + @cython.inline + @cython.exceptval(check=False) + def _is_in_open_state(self) -> cython.bint: + # Before on_ws_connected, self.transport is None + # on_ws_frame immediately send CLOSE reply on incoming CLOSE frame, so receiving CLOSE == is_close_frame_sent + # transport.is_disconnect happens the last, asyncio Protocol got connection_lost event + + return (self.transport is not None + and not self.transport.is_disconnected + and not self.transport.is_close_frame_sent) + + @cython.cfunc + @cython.inline + def _encode_and_send(self, msg_type: WSMsgType, message: DataLike, fin: cython.bint) -> None: + if self._permessage_deflate is not None: + message = self._permessage_deflate.encode_frame(msg_type, message, fin) + self.transport.send(msg_type, message, fin, msg_type != WSMsgType.CONTINUATION) + else: + self.transport.send(msg_type, message, fin) + + async def _wait_close_and_raise(self, exc: Optional[BaseException]=None) -> NoReturn: + # CANCELLATION: + # _close_fut is supposed to be set only from on_ws_disconnected. + # It is intentionally shielded. + await _asyncio_shield(self._close_fut) + assert self._close_exc is not None # pacify type checker + + if exc is None: + raise self._close_exc + else: + raise self._close_exc from exc + + async def _wait_send_turn(self) -> None: + # DISCONNECT: the waiter will raise ConnectionClosed + # It can also successfully finish, but we may be in CLOSING state. + # In such case delegate waiting to _wait_close_and_raise + + # CANCELLATION: + # waiter future is not shielded intentionally, it turns into Cancelled + # state and removed from waiters by _release_send. + # _wait_close_and_raise shields _close_fut. + waiter: asyncio.Future[None] = self._loop.create_future() + self._send_waiters.append(waiter) + await waiter + if not self._is_in_open_state(): + await self._wait_close_and_raise() + + async def _wait_write_ready(self) -> None: + # DISCONNECT: the waiter will raise ConnectionClosed + # It can also successfully finish, but we may be in CLOSING state. + # In such case delegate waiting to _wait_close_and_raise + + # CANCELLATION: + # _write_ready future is shielded intentionally. It is only supposed to + # be set from resume_writing and on_ws_disconnected. + assert self._write_ready is not None + await _asyncio_shield(self._write_ready) + if not self._is_in_open_state(): + await self._wait_close_and_raise() + + async def _get_next_async_fragment(self, async_iterator: AsyncIterator[DataLike]) -> DataLike: + # DISCONNECT: raise ConnectionClosed if after user async interator + # returns we are not in OPEN state + + # CANCELLATION: + # User async iterator is also getting canceled. + + data: DataLike = await async_iterator.__anext__() + if not self._is_in_open_state(): + await self._wait_close_and_raise() + return data + + @cython.cfunc + @cython.inline + def _check_fragment_type(self, message: DataLike, first_is_str: cython.bint) -> None: + if first_is_str and isinstance(message, str): + return + elif not first_is_str and isinstance(message, (bytes, bytearray, memoryview)): + return + + raise TypeError("all fragments must be of the same category: str vs bytes-like") + + @cython.cfunc + @cython.inline + def _release_send(self) -> None: + self._send_in_progress = False + + waiter: asyncio.Future[None] + while self._send_waiters: + waiter = self._send_waiters.popleft() + # Some waiters may be canceled, that is why we have defensive check + # for waiter.done() + if not waiter.done(): + waiter.set_result(None) + return + + async def _send_fragments( + self, + messages: Union[AsyncIterable[DataLike], Iterable[DataLike]], + text: Optional[bool], + ) -> None: + is_async: cython.bint + async_iterator: AsyncIterator[DataLike] + iterator: Iterator[DataLike] + stop_exception_type: Union[type[StopAsyncIteration], type[StopIteration]] + + if isinstance(messages, AsyncIterable): + async_iterator = messages.__aiter__() + iterator = None # type: ignore[assignment] + stop_exception_type = StopAsyncIteration + is_async = True + else: + async_iterator = None # type: ignore[assignment] + iterator = iter(messages) + stop_exception_type = StopIteration + is_async = False + + try: + try: + if is_async: + current = await self._get_next_async_fragment(async_iterator) + else: + current = next(iterator) + except stop_exception_type: + return + + first_is_str: cython.bint + if isinstance(current, str): + msg_type = WSMsgType.BINARY if text is False else WSMsgType.TEXT + first_is_str = True + elif isinstance(current, (bytes, bytearray, memoryview)): + msg_type = WSMsgType.TEXT if text else WSMsgType.BINARY + first_is_str = False + else: + raise TypeError(f"message must contain str or bytes-like objects, got {type(current).__name__}") + + while True: + # Original websockets implementations always send one last empty + # frame with fin=True even if iterator returns only one fragment + # Perhaps this is useful for the users, just replicate this + # behavior. + self._encode_and_send(msg_type, current, False) + msg_type = WSMsgType.CONTINUATION + + try: + if is_async: + current = await self._get_next_async_fragment(async_iterator) + else: + current = next(iterator) + except stop_exception_type: + break + + self._check_fragment_type(current, first_is_str) + if self._write_ready is not None: + await self._wait_write_ready() + + # Send the last empty frame with fin=True + self._encode_and_send(msg_type, b"", True) + if self._write_ready is not None: + await self._wait_write_ready() + except BaseException: + self.transport.send_close(WSCloseCode.PROTOCOL_ERROR, "error in fragmented message") + self.transport.disconnect(False) + raise + + async def send( + self, + message: Union[DataLike, Iterable[DataLike], AsyncIterator[DataLike]], + text: Optional[bool] = None, + ) -> None: + # send doesn't directly wait on helper futures. It is very tricky to handle + # disconnects and cancellations properly. send delegates this to + # _wait_* helpers. + if not self._is_in_open_state(): + await self._wait_close_and_raise() + + if self._send_in_progress: + await self._wait_send_turn() + else: + self._send_in_progress = True + + try: + if isinstance(message, (str, bytes, bytearray, memoryview)): + if isinstance(message, str): + msg_type = WSMsgType.BINARY if text is False else WSMsgType.TEXT + else: + msg_type = WSMsgType.TEXT if text else WSMsgType.BINARY + + self._encode_and_send(msg_type, message, True) + + if self._write_ready is not None: + await self._wait_write_ready() + # Catch a common mistake -- passing a dict to send(). + elif isinstance(message, Mapping): + raise TypeError("data is a dict-like object") + elif isinstance(message, (AsyncIterable, Iterable)): + fragments: Union[AsyncIterable[DataLike], Iterable[DataLike]] = message # type: ignore[assignment] + await self._send_fragments(fragments, text) + else: + raise TypeError(f"message has unsupported type {type(message).__name__}") + finally: + self._release_send() + + async def close(self, code: int = 1000, reason: str = "") -> None: + if self._send_in_progress: + self.transport.send_close(WSCloseCode.INTERNAL_ERROR, "close during fragmented message") + self.transport.disconnect(False) + else: + self.transport.send_close(cython.cast(WSCloseCode, code), reason) + + try: + if self._close_timeout is None: + await self.wait_closed() + else: + await asyncio.wait_for(self.wait_closed(), self._close_timeout) + except asyncio.TimeoutError: + self.transport.disconnect(False) + await self.wait_closed() + + async def wait_closed(self) -> None: + try: + await self.transport.wait_disconnected() + except Exception: + pass + + async def ping(self, data: Optional[DataLike] = None) -> Awaitable[float]: + if not self._is_in_open_state(): + await self._wait_close_and_raise() + + if data is None: + while True: + payload = os.urandom(4) + if payload not in self._pending_pings: + break + elif isinstance(data, str): + payload = data.encode("utf-8") + elif isinstance(data, (bytes, bytearray, memoryview)): + payload = bytes(data) + else: + raise TypeError("ping payload must be str, bytes-like, or None") + + if payload in self._pending_pings: + raise ConcurrencyError("another ping was sent with the same data") + + waiter: asyncio.Future[float] = asyncio.get_running_loop().create_future() + self._pending_pings[payload] = (waiter, monotonic()) + self.transport.send_ping(payload) + return waiter + + async def pong(self, data: Union[str, bytes] = b"") -> None: + if not self._is_in_open_state(): + await self._wait_close_and_raise() + + self.transport.send_pong(data) + + async def _keepalive_loop(self) -> None: + try: + while True: + assert self._ping_interval is not None + await asyncio.sleep(self._ping_interval) + waiter = await self.ping() + if self._ping_timeout is None: + continue + await asyncio.wait_for(waiter, self._ping_timeout) + except asyncio.CancelledError: + raise + except Exception: + if self.state is not State.CLOSED: + await self.close(code=1011, reason="keepalive ping timeout") + + async def __aenter__(self): # type: ignore[no-untyped-def] + return self + + async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None: + await self.close() + + def __aiter__(self) -> AsyncIterator[Union[str, bytes]]: + return self._iterate_messages() + + async def _iterate_messages(self) -> AsyncIterator[Data]: + while True: + try: + yield await self.recv() + except ConnectionClosedOK: + return + + @property + def state(self) -> State: + if self.transport is None: + return State.CONNECTING + elif self.transport.is_disconnected: + return State.CLOSED + elif self.transport.is_close_frame_sent or self.transport.close_handshake is not None: + return State.CLOSING + else: + return State.OPEN + + @property + def request(self) -> Request: + return self._request + + @property + def response(self) -> Response: + return self._response + + @property + def local_address(self) -> Any: + return self.transport.underlying_transport.get_extra_info("sockname") + + @property + def remote_address(self) -> Any: + return self.transport.underlying_transport.get_extra_info("peername") + + @property + def latency(self) -> float: + return self._latency # type: ignore[no-any-return] + + @property + def subprotocol(self) -> Optional[Subprotocol]: + return self._subprotocol + + @property + def close_code(self) -> Optional[int]: + handshake = self.transport.close_handshake + if handshake is None: + return None + if handshake.recv is not None: + return _coerce_close_code(handshake.recv.code) # type: ignore[no-any-return] + if handshake.sent is not None: + return _coerce_close_code(handshake.sent.code) # type: ignore[no-any-return] + return None + + @property + def close_reason(self) -> Optional[str]: + handshake = self.transport.close_handshake + if handshake is None: + return None + if handshake.recv is not None: + return _coerce_close_reason(handshake.recv.reason) # type: ignore[no-any-return] + if handshake.sent is not None: + return _coerce_close_reason(handshake.sent.reason) # type: ignore[no-any-return] + return None + + +@cython.cclass +class ClientConnection(ConnectionBase): + def __init__( + self, + *, + request: Request, + response: Response, + subprotocol: Optional[Subprotocol], + permessage_deflate: Optional[_PerMessageDeflate], + ping_interval: Optional[float] = 20, + ping_timeout: Optional[float] = 20, + close_timeout: Optional[float] = 10, + max_queue: Union[int, tuple[Optional[int], Optional[int]], None] = 16, + write_limit: Union[int, tuple[int, Optional[int]]] = 32768, + max_message_size: Optional[int] = 1024 * 1024, + logger: LoggerLike = None, + ): + super().__init__( + request=request, + response=response, + subprotocol=subprotocol, + permessage_deflate=permessage_deflate, + ping_interval=ping_interval, + ping_timeout=ping_timeout, + close_timeout=close_timeout, + max_queue=max_queue, + write_limit=write_limit, + max_message_size=max_message_size, + logger=logger, + ) + + @cython.ccall + def on_ws_connected(self, transport: WSTransport) -> None: + self.transport = transport + self._set_write_limits(self._write_limit) + if self._ping_interval is not None and self._keepalive_task is None: + self._keepalive_task = asyncio.create_task(self._keepalive_loop()) + + +def broadcast_message(connection: ConnectionBase, msg_type: WSMsgType, message: DataLike) -> bool: + if connection._send_in_progress: + return False + connection._encode_and_send(msg_type, message, True) + return True + + +@cython.cclass +class ServerConnection(ConnectionBase): + server: Any + _username: Optional[str] + + def __init__( + self, + server: Any, + *, + request: Request, + response: Response, + subprotocol: Optional[Subprotocol], + permessage_deflate: Optional[_PerMessageDeflate], + username: Optional[str] = None, + ping_interval: Optional[float] = 20, + ping_timeout: Optional[float] = 20, + close_timeout: Optional[float] = 10, + max_queue: Union[int, tuple[Optional[int], Optional[int]], None] = 16, + write_limit: Union[int, tuple[int, Optional[int]]] = 32768, + max_message_size: Optional[int] = 1024 * 1024, + logger: LoggerLike = None, + ): + super().__init__( + request=request, + response=response, + subprotocol=subprotocol, + permessage_deflate=permessage_deflate, + ping_interval=ping_interval, + ping_timeout=ping_timeout, + close_timeout=close_timeout, + max_queue=max_queue, + write_limit=write_limit, + max_message_size=max_message_size, + logger=logger, + ) + self.server = server + self._username = username + + @cython.ccall + def on_ws_connected(self, transport: WSTransport) -> None: + self.transport = transport + self._set_write_limits(self._write_limit) + if self._ping_interval is not None and self._keepalive_task is None: + self._keepalive_task = asyncio.create_task(self._keepalive_loop()) + self.server.loop.call_soon(self.server.start_connection_handler, self) + + @property + def username(self) -> Optional[str]: + return self._username diff --git a/picows/websockets/asyncio/negotiation.py b/picows/websockets/asyncio/negotiation.py new file mode 100644 index 0000000..e83f5bd --- /dev/null +++ b/picows/websockets/asyncio/negotiation.py @@ -0,0 +1,36 @@ +from __future__ import annotations + +from typing import Optional, Sequence + +from ..compat import Response +from ..exceptions import InvalidHandshake +from ..typing import Subprotocol +from .connection import _PerMessageDeflate + + +def resolve_subprotocol( + subprotocols: Optional[Sequence[Subprotocol]], + response: Response, +) -> Optional[Subprotocol]: + value = response.headers.get("Sec-WebSocket-Protocol") + if value is None: + return None + if not isinstance(value, str): + raise InvalidHandshake("server returned non-string subprotocol") + if subprotocols is not None and value not in subprotocols: + raise InvalidHandshake(f"unsupported subprotocol negotiated by server: {value}") + return value + + +def configure_permessage_deflate( + response: Response, + compression: Optional[str], +) -> Optional[_PerMessageDeflate]: + header_value = response.headers.get("Sec-WebSocket-Extensions") + if header_value is None: + return None + if compression != "deflate": + raise InvalidHandshake("unexpected websocket extensions negotiated by server") + if not isinstance(header_value, str): + raise InvalidHandshake("invalid Sec-WebSocket-Extensions header") + return _PerMessageDeflate.from_response_header(header_value) diff --git a/picows/websockets/asyncio/router.py b/picows/websockets/asyncio/router.py new file mode 100644 index 0000000..d9a4c5a --- /dev/null +++ b/picows/websockets/asyncio/router.py @@ -0,0 +1,7 @@ +from __future__ import annotations + +from typing import Any + + +def route(*args: Any, **kwargs: Any) -> Any: + raise NotImplementedError("route() requires unsupported server process_request hooks") diff --git a/picows/websockets/asyncio/server.py b/picows/websockets/asyncio/server.py new file mode 100644 index 0000000..c018d45 --- /dev/null +++ b/picows/websockets/asyncio/server.py @@ -0,0 +1,551 @@ +from __future__ import annotations + +import asyncio +import binascii +import hmac +import http +import socket +import sys +from base64 import b64decode +from dataclasses import dataclass +from collections.abc import Awaitable, Callable, Iterable +from inspect import isawaitable +from logging import getLogger +from typing import Any, Optional, Pattern, Sequence + +import picows +from multidict import CIMultiDict + +from .connection import ( + ServerConnection, + _resolve_logger, + broadcast_message, +) +from .negotiation import configure_permessage_deflate +from ..compat import Request, Response, State +from ..exceptions import ConcurrencyError, InvalidHandshake, InvalidOrigin +from ..typing import DataLike, LoggerLike, Origin, Subprotocol + +if sys.version_info >= (3, 11): + from builtins import ExceptionGroup +else: + ExceptionGroup = Exception + +__all__ = [ + "ServerConnection", + "ServerHandshakeConnection", + "Server", + "serve", + "broadcast", + "basic_auth", +] + + +_PERMESSAGE_DEFLATE_REQUEST = "permessage-deflate" + + +def _default_server_header() -> str: + return f"Python/{sys.version_info.major}.{sys.version_info.minor} picows-websockets/0" + + +def _supports_permessage_deflate(request: Request) -> bool: + value = request.headers.get("Sec-WebSocket-Extensions") + return isinstance(value, str) and "permessage-deflate" in value + + +def _origin_allowed( + origin: str | None, + origins: Sequence[Origin | Pattern[str] | None] | None, +) -> bool: + if origins is None: + return True + for candidate in origins: + if candidate is None: + if origin is None: + return True + elif isinstance(candidate, str): + if origin == candidate: + return True + elif candidate.fullmatch(origin or "") is not None: + return True + return False + + +def _ensure_sync_result(value: Any, hook_name: str) -> Any: + if isawaitable(value): + close = getattr(value, "close", None) + if close is not None: + close() + raise NotImplementedError(f"async {hook_name} hooks aren't supported by picows.websockets server yet") + return value + + +def _make_error_response( + status: http.HTTPStatus, + body: bytes, +) -> Response: + return Response( + status_code=int(status), + reason_phrase=status.phrase, + headers=CIMultiDict({"Content-Type": "text/plain; charset=utf-8"}), + body=body, + ) + + +def _basic_auth_unauthorized_response(message: bytes, realm: str) -> Response: + response = _make_error_response(http.HTTPStatus.UNAUTHORIZED, message) + response.headers["WWW-Authenticate"] = f'Basic realm="{realm}"' + return response + + +def _parse_basic_authorization(header_value: str) -> tuple[str, str]: + scheme, _, token = header_value.partition(" ") + if scheme.lower() != "basic" or not token: + raise ValueError("unsupported authorization scheme") + try: + decoded = b64decode(token, validate=True).decode("utf-8") + except (ValueError, UnicodeDecodeError, binascii.Error) as exc: + raise ValueError("invalid basic authorization header") from exc + username, separator, password = decoded.partition(":") + if not separator: + raise ValueError("invalid basic authorization header") + return username, password + + +def _select_subprotocol( + connection: ServerHandshakeConnection, + request: Request, + subprotocols: Optional[Sequence[Subprotocol]], + select_subprotocol: Optional[Callable[[ServerHandshakeConnection, Sequence[Subprotocol]], Subprotocol | None]], +) -> Optional[Subprotocol]: + header_value = request.headers.get("Sec-WebSocket-Protocol") + if header_value is None: + return None + offered = [item.strip() for item in header_value.split(",") if item.strip()] + if not offered: + return None + if select_subprotocol is not None: + selected = select_subprotocol(connection, offered) + if selected is not None and selected not in offered: + raise InvalidHandshake(f"selected subprotocol isn't offered by client: {selected}") + return selected + if subprotocols is None: + return None + for subprotocol in subprotocols: + if subprotocol in offered: + return subprotocol + return None + + +def _resolve_response_subprotocol( + request: Request, + response: Response, +) -> Optional[Subprotocol]: + value = response.headers.get("Sec-WebSocket-Protocol") + if value is None: + return None + if not isinstance(value, str): + raise InvalidHandshake("invalid Sec-WebSocket-Protocol header") + header_value = request.headers.get("Sec-WebSocket-Protocol") + if header_value is None: + raise InvalidHandshake("server negotiated a subprotocol without a client offer") + offered = [item.strip() for item in header_value.split(",") if item.strip()] + if value not in offered: + raise InvalidHandshake(f"selected subprotocol isn't offered by client: {value}") + return value + + +def basic_auth( + realm: str = "", + credentials: tuple[str, str] | Iterable[tuple[str, str]] | None = None, + check_credentials: Callable[[str, str], bool] | None = None, +) -> Callable[[ServerHandshakeConnection, Request], Response | None]: + if (credentials is None) == (check_credentials is None): + raise ValueError("provide either credentials or check_credentials") + + if credentials is not None: + if ( + isinstance(credentials, tuple) + and len(credentials) == 2 + and all(isinstance(item, str) for item in credentials) + ): + username = credentials[0] + password = credentials[1] + assert isinstance(username, str) + assert isinstance(password, str) + credentials_dict: dict[str, str] = {username: password} + elif isinstance(credentials, Iterable): + credentials_list: list[tuple[str, str]] = [] + for item in credentials: + if ( + not isinstance(item, tuple) + or len(item) != 2 + or not isinstance(item[0], str) + or not isinstance(item[1], str) + ): + raise TypeError(f"invalid credentials argument: {credentials}") + credentials_list.append(item) + credentials_dict = dict(credentials_list) + else: + raise TypeError(f"invalid credentials argument: {credentials}") + + def check_credentials(username: str, password: str) -> bool: + expected_password: str | None = credentials_dict.get(username) + return ( + expected_password is not None + and hmac.compare_digest(expected_password, password) + ) + + assert check_credentials is not None + + def process_request( + connection: ServerHandshakeConnection, + request: Request, + ) -> Response | None: + authorization = request.headers.get("Authorization") + if authorization is None: + return _basic_auth_unauthorized_response(b"Missing credentials\n", realm) + + try: + username, password = _parse_basic_authorization(authorization) + except ValueError: + return _basic_auth_unauthorized_response(b"Unsupported credentials\n", realm) + + valid_credentials = check_credentials(username, password) + if isawaitable(valid_credentials): + close = getattr(valid_credentials, "close", None) + if close is not None: + close() + raise NotImplementedError("async basic_auth credential checks aren't supported yet") + if not valid_credentials: + return _basic_auth_unauthorized_response(b"Invalid credentials\n", realm) + + connection.username = username + return None + + return process_request + + +@dataclass +class ServerHandshakeConnection: + request: Request + username: Optional[str] = None + + @property + def state(self) -> State: + return State.CONNECTING + + +class Server: + def __init__( + self, + handler: Callable[[ServerConnection], Awaitable[None]], + *, + logger: LoggerLike | None = None, + ) -> None: + self.loop = asyncio.get_running_loop() + self.handler = handler + self.logger = _resolve_logger(logger if logger is not None else getLogger("websockets.server")) + self.handlers: dict[ServerConnection, asyncio.Task[None]] = {} + self.close_task: asyncio.Task[None] | None = None + self.closed_waiter: asyncio.Future[None] = self.loop.create_future() + self.server: asyncio.Server + + @property + def connections(self) -> set[ServerConnection]: + return {connection for connection in self.handlers if connection.state is State.OPEN} + + def wrap(self, server: asyncio.Server) -> None: + self.server = server + for sock in server.sockets: + if sock.family == socket.AF_INET: + name = "%s:%d" % sock.getsockname() + elif sock.family == socket.AF_INET6: + name = "[%s]:%d" % sock.getsockname()[:2] + else: + name = str(sock.getsockname()) + self.logger.info("server listening on %s", name) + + async def conn_handler(self, connection: ServerConnection) -> None: + try: + try: + await asyncio.sleep(0) + await self.handler(connection) + except Exception: + self.logger.error("connection handler failed", exc_info=True) + await connection.close(1011) + else: + await connection.close() + finally: + del self.handlers[connection] + + def start_connection_handler(self, connection: ServerConnection) -> None: + self.handlers[connection] = self.loop.create_task(self.conn_handler(connection)) + + def close( + self, + close_connections: bool = True, + code: int = 1001, + reason: str = "", + ) -> None: + if self.close_task is None: + self.close_task = self.loop.create_task(self._close(close_connections, code, reason)) + + async def _close( + self, + close_connections: bool = True, + code: int = 1001, + reason: str = "", + ) -> None: + self.logger.info("server closing") + self.server.close() + await asyncio.sleep(0) + if close_connections: + close_tasks = [ + asyncio.create_task(connection.close(code, reason)) + for connection in self.handlers + if connection.state is not State.CONNECTING + ] + if close_tasks: + await asyncio.wait(close_tasks) + await self.server.wait_closed() + if self.handlers: + await asyncio.wait(self.handlers.values()) + self.closed_waiter.set_result(None) + self.logger.info("server closed") + + async def wait_closed(self) -> None: + await asyncio.shield(self.closed_waiter) + + def get_loop(self) -> asyncio.AbstractEventLoop: + return self.server.get_loop() + + def is_serving(self) -> bool: + return self.server.is_serving() + + async def start_serving(self) -> None: + await self.server.start_serving() + + async def serve_forever(self) -> None: + await self.server.serve_forever() + + @property + def sockets(self) -> tuple[socket.socket, ...]: + return self.server.sockets + + async def __aenter__(self) -> Server: + return self + + async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: + self.close() + await self.wait_closed() + + +class serve: + def __init__( + self, + handler: Callable[[ServerConnection], Awaitable[None]], + host: str | None = None, + port: int | None = None, + *, + origins: Sequence[Origin | Pattern[str] | None] | None = None, + extensions: Sequence[Any] | None = None, + subprotocols: Sequence[Subprotocol] | None = None, + select_subprotocol: Callable[[ServerHandshakeConnection, Sequence[Subprotocol]], Subprotocol | None] | None = None, + compression: str | None = "deflate", + server_header: str | None = _default_server_header(), + open_timeout: float | None = 10, + ping_interval: float | None = 20, + ping_timeout: float | None = 20, + close_timeout: float | None = 10, + max_size: int | None | tuple[int | None, int | None] = 1024 * 1024, + max_queue: int | None | tuple[int | None, int | None] = 16, + write_limit: int | tuple[int, int | None] = 32768, + logger: LoggerLike | None = None, + **kwargs: Any, + ): + self.handler = handler + self.host = host + self.port = port + self.origins = origins + self.extensions = extensions + self.subprotocols = subprotocols + self.select_subprotocol = select_subprotocol + self.compression = compression + self.server_header = server_header + self.open_timeout = open_timeout + self.ping_interval = ping_interval + self.ping_timeout = ping_timeout + self.close_timeout = close_timeout + self.max_size = max_size + self.max_queue = max_queue + self.write_limit = write_limit + self.logger = logger + if "create_connection" in kwargs: + raise NotImplementedError("create_connection isn't supported by picows.websockets server yet") + self.process_request = kwargs.pop("process_request", None) + self.process_response = kwargs.pop("process_response", None) + self.kwargs = kwargs + self._server: Server | None = None + + def __await__(self) -> Any: + return self._create().__await__() + + async def __aenter__(self) -> Server: + self._server = await self._create() + return self._server + + async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None: + assert self._server is not None + self._server.close() + await self._server.wait_closed() + self._server = None + + async def _create(self) -> Server: + if self.extensions is not None: + raise NotImplementedError("custom server extensions aren't supported by picows.websockets") + if self.compression not in (None, "deflate"): + raise NotImplementedError("only compression=None or 'deflate' are accepted") + + server = Server( + self.handler, + logger=self.logger, + ) + + max_message_size = self.max_size[0] if isinstance(self.max_size, tuple) else self.max_size + max_frame_size = 2 ** 31 - 1 if max_message_size is None else max_message_size + + def listener_factory( + upgrade_request: picows.WSUpgradeRequest, + ) -> picows.WSUpgradeResponseWithListener: + request = Request.from_picows(upgrade_request) + handshake_connection = ServerHandshakeConnection(request) + response: Response + + origin = request.headers.get("Origin") + if origin is not None and not isinstance(origin, str): + raise InvalidOrigin(None) + if not _origin_allowed(origin, self.origins): + response = _make_error_response( + http.HTTPStatus.FORBIDDEN, + b"Origin not allowed\n", + ) + elif server.close_task is not None: + response = _make_error_response( + http.HTTPStatus.SERVICE_UNAVAILABLE, + b"Server is shutting down.\n", + ) + else: + headers = {} + if self.server_header is not None: + headers["Server"] = self.server_header + subprotocol = _select_subprotocol( + handshake_connection, + request, + self.subprotocols, + self.select_subprotocol, + ) + if subprotocol is not None: + headers["Sec-WebSocket-Protocol"] = subprotocol + if self.compression == "deflate" and _supports_permessage_deflate(request): + headers["Sec-WebSocket-Extensions"] = _PERMESSAGE_DEFLATE_REQUEST + response = Response.from_picows(picows.WSUpgradeResponse.create_101_response(headers)) + + if self.process_request is not None: + response_or_none = _ensure_sync_result( + self.process_request(handshake_connection, request), + "process_request", + ) + if response_or_none is not None: + if not isinstance(response_or_none, Response): + raise TypeError("process_request must return a Response or None") + response = response_or_none + + response = _ensure_sync_result( + self.process_response(handshake_connection, request, response), + "process_response", + ) if self.process_response is not None else response + + if not isinstance(response, Response): + raise TypeError("process_response must return a Response") + + if response.status_code == int(http.HTTPStatus.SWITCHING_PROTOCOLS): + subprotocol = _resolve_response_subprotocol(request, response) + permessage_deflate = configure_permessage_deflate(response, self.compression) + connection = ServerConnection( + server, + request=request, + response=response, + subprotocol=subprotocol, + permessage_deflate=permessage_deflate, + username=handshake_connection.username, + ping_interval=self.ping_interval, + ping_timeout=self.ping_timeout, + close_timeout=self.close_timeout, + max_queue=self.max_queue, + write_limit=self.write_limit, + max_message_size=max_message_size, + logger=self.logger, + ) + return picows.WSUpgradeResponseWithListener(response.to_picows(), connection) + else: + return picows.WSUpgradeResponseWithListener(response.to_picows(), None) + + logger_name: Any = self.logger if self.logger is not None else getLogger("websockets.server") + raw_server = await picows.ws_create_server( + listener_factory, + self.host, + self.port, + websocket_handshake_timeout=self.open_timeout, + enable_auto_ping=False, + enable_auto_pong=True, + max_frame_size=max_frame_size, + logger_name=logger_name, + **self.kwargs, + ) + server.wrap(raw_server) + return server + + +def broadcast( + connections: Iterable[ServerConnection], + message: DataLike, + raise_exceptions: bool = False, +) -> None: + if isinstance(message, str): + msg_type = picows.WSMsgType.TEXT + elif isinstance(message, (bytes, bytearray, memoryview)): + msg_type = picows.WSMsgType.BINARY + else: + raise TypeError("data must be str or bytes") + + if raise_exceptions: + if sys.version_info[:2] < (3, 11): + raise ValueError("raise_exceptions requires at least Python 3.11") + exceptions: list[Exception] = [] + + for connection in connections: + if connection.state is not State.OPEN: + continue + try: + sent = broadcast_message(connection, msg_type, message) + except Exception as write_exception: + if raise_exceptions: + exception = RuntimeError("failed to write message") + exception.__cause__ = write_exception + exceptions.append(exception) + else: + getLogger("websockets.server").warning( + "skipped broadcast: failed to write message: %s", + write_exception, + ) + continue + + if not sent: + if raise_exceptions: + exceptions.append(ConcurrencyError("sending a fragmented message")) + else: + getLogger("websockets.server").warning("skipped broadcast: sending a fragmented message") + + if raise_exceptions and exceptions: + raise ExceptionGroup("skipped broadcast", exceptions) diff --git a/picows/websockets/compat.py b/picows/websockets/compat.py new file mode 100644 index 0000000..8c1b1a7 --- /dev/null +++ b/picows/websockets/compat.py @@ -0,0 +1,68 @@ +from __future__ import annotations + +from dataclasses import dataclass +from enum import IntEnum +from http import HTTPStatus +from typing import Union + +import picows +from multidict import CIMultiDict + +CloseCode = picows.WSCloseCode + + +class State(IntEnum): + CONNECTING = 0 + OPEN = 1 + CLOSING = 2 + CLOSED = 3 + + +@dataclass +class Request: + path: str + headers: CIMultiDict[str] + + @classmethod + def from_picows(cls, request: picows.WSUpgradeRequest) -> Request: + return cls( + path=request.path.decode("ascii", "surrogateescape"), + headers=request.headers, + ) + + +@dataclass +class Response: + status_code: int + reason_phrase: str + headers: CIMultiDict[str] + body: Union[bytes, bytearray] + + @classmethod + def from_picows(cls, response: picows.WSUpgradeResponse) -> Response: + return cls( + status_code=int(response.status), + reason_phrase=response.status.phrase, + headers=response.headers, + body=b"" if response.body is None else response.body, + ) + + @property + def status(self) -> int: + return self.status_code + + def to_picows(self) -> picows.WSUpgradeResponse: + response = picows.WSUpgradeResponse() + response.version = b"HTTP/1.1" + response.status = HTTPStatus(self.status_code) + response.headers = self.headers.copy() + response.body = bytes(self.body) + return response + + +__all__ = [ + "State", + "CloseCode", + "Request", + "Response", +] diff --git a/picows/websockets/exceptions.py b/picows/websockets/exceptions.py new file mode 100644 index 0000000..c2e198d --- /dev/null +++ b/picows/websockets/exceptions.py @@ -0,0 +1,201 @@ +from __future__ import annotations + +from typing import Any, Optional + +__all__ = [ + "WebSocketException", + "ConnectionClosed", + "ConnectionClosedOK", + "ConnectionClosedError", + "InvalidURI", + "InvalidProxy", + "InvalidHandshake", + "SecurityError", + "ProxyError", + "InvalidProxyMessage", + "InvalidProxyStatus", + "InvalidMessage", + "InvalidStatus", + "InvalidHeader", + "InvalidHeaderFormat", + "InvalidHeaderValue", + "InvalidOrigin", + "InvalidUpgrade", + "NegotiationError", + "DuplicateParameter", + "InvalidParameterName", + "InvalidParameterValue", + "ProtocolError", + "PayloadTooBig", + "InvalidState", + "ConcurrencyError", +] + + +class WebSocketException(Exception): + """Base class for exceptions defined by picows.websockets.""" + + +class ConnectionClosed(WebSocketException): + def __init__(self, rcvd: Any, sent: Any, rcvd_then_sent: Optional[bool] = None): + super().__init__() + self.rcvd = rcvd + self.sent = sent + self.rcvd_then_sent = rcvd_then_sent + + def __str__(self) -> str: + if self.rcvd is None and self.sent is None: + return "no close frame received or sent" + if self.rcvd is None: + return f"sent {self.sent.code} ({self.sent.reason})" + if self.sent is None: + return f"received {self.rcvd.code} ({self.rcvd.reason})" + order = "received then sent" if self.rcvd_then_sent else "sent then received" + return ( + f"{order} close frames: " + f"received {self.rcvd.code} ({self.rcvd.reason}), " + f"sent {self.sent.code} ({self.sent.reason})" + ) + + +class ConnectionClosedOK(ConnectionClosed): + pass + + +class ConnectionClosedError(ConnectionClosed): + pass + + +class InvalidURI(WebSocketException): + def __init__(self, uri: str, msg: str): + super().__init__(uri, msg) + self.uri = uri + self.msg = msg + + def __str__(self) -> str: + return f"{self.uri} isn't a valid WebSocket URI: {self.msg}" + + +class InvalidProxy(WebSocketException): + def __init__(self, proxy: str, msg: str): + super().__init__(proxy, msg) + self.proxy = proxy + self.msg = msg + + def __str__(self) -> str: + return f"{self.proxy} isn't a valid proxy: {self.msg}" + + +class InvalidHandshake(WebSocketException): + pass + + +class SecurityError(InvalidHandshake): + pass + + +class ProxyError(InvalidHandshake): + pass + + +class InvalidProxyMessage(ProxyError): + pass + + +class InvalidProxyStatus(ProxyError): + def __init__(self, response: Any): + super().__init__(response) + self.response = response + + def __str__(self) -> str: + status = getattr(self.response, "status_code", None) + if status is None: + return "proxy rejected connection" + return f"proxy rejected connection: HTTP {int(status):d}" + + +class InvalidMessage(InvalidHandshake): + pass + + +class InvalidStatus(InvalidHandshake): + def __init__(self, response: Any): + super().__init__(response) + self.response = response + + +class InvalidHeader(InvalidHandshake): + def __init__(self, name: str, value: Optional[str] = None): + super().__init__(name, value) + self.name = name + self.value = value + + +class InvalidUpgrade(InvalidHeader): + pass + + +class InvalidHeaderFormat(InvalidHeader): + def __init__(self, name: str, error: str, header: str, pos: int): + super().__init__(name, f"{error} at {pos} in {header}") + + +class InvalidHeaderValue(InvalidHeader): + pass + + +class InvalidOrigin(InvalidHeader): + def __init__(self, origin: Optional[str]): + super().__init__("Origin", origin) + + +class NegotiationError(InvalidHandshake): + pass + + +class DuplicateParameter(NegotiationError): + def __init__(self, name: str): + super().__init__(name) + self.name = name + + def __str__(self) -> str: + return f"duplicate parameter: {self.name}" + + +class InvalidParameterName(NegotiationError): + def __init__(self, name: str): + super().__init__(name) + self.name = name + + def __str__(self) -> str: + return f"invalid parameter name: {self.name}" + + +class InvalidParameterValue(NegotiationError): + def __init__(self, name: str, value: Optional[str]): + super().__init__(name, value) + self.name = name + self.value = value + + def __str__(self) -> str: + if self.value is None: + return f"missing value for parameter {self.name}" + if self.value == "": + return f"empty value for parameter {self.name}" + return f"invalid value for parameter {self.name}: {self.value}" + + +class ProtocolError(WebSocketException): + pass + + +class PayloadTooBig(WebSocketException): + pass + + +class InvalidState(WebSocketException, AssertionError): + pass + + +class ConcurrencyError(WebSocketException, RuntimeError): + pass diff --git a/picows/websockets/typing.py b/picows/websockets/typing.py new file mode 100644 index 0000000..c70f878 --- /dev/null +++ b/picows/websockets/typing.py @@ -0,0 +1,51 @@ +from __future__ import annotations + +from http import HTTPStatus +from typing import Any, Optional, Protocol, Tuple, Union + +from picows.types import WSHeadersLike + +BytesLike = Union[bytes, bytearray, memoryview] +Data = Union[str, bytes] +DataLike = Union[str, bytes, bytearray, memoryview] +HeadersLike = WSHeadersLike + + +class LoggerProtocol(Protocol): + @property + def debug(self) -> Any: + ... + + @property + def info(self) -> Any: + ... + + @property + def warning(self) -> Any: + ... + + @property + def error(self) -> Any: + ... + + +LoggerLike = Union[LoggerProtocol, str, None] +StatusLike = Union[HTTPStatus, int] +Origin = str +Subprotocol = str +ExtensionName = str +ExtensionParameter = Tuple[str, Optional[str]] + +__all__ = [ + "BytesLike", + "Data", + "DataLike", + "ExtensionName", + "ExtensionParameter", + "HeadersLike", + "LoggerLike", + "LoggerProtocol", + "Origin", + "StatusLike", + "Subprotocol", +] diff --git a/pyproject.toml b/pyproject.toml index dc3065f..7ebb5b3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,8 +36,8 @@ Repository = "https://github.com/tarasko/picows" Issues = "https://github.com/tarasko/picows/issues" Documentation = "https://picows.readthedocs.io/en/latest" -[tool.setuptools] -packages = ["picows"] +[tool.setuptools.packages.find] +include = ["picows*"] [tool.setuptools.dynamic] dependencies = {file = ["requirements.txt"]} @@ -56,6 +56,10 @@ files = "picows" ignore_missing_imports = true strict = true +[[tool.mypy.overrides]] +module = ["picows.websockets.asyncio.connection"] +disable_error_code = ["untyped-decorator"] + [tool.coverage.run] source = ["picows"] branch = true diff --git a/requirements-test.txt b/requirements-test.txt index b6707b8..4bd372d 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -13,3 +13,4 @@ uvloop; sys_platform != 'win32' winloop; sys_platform == 'win32' tiny-proxy mypy +websockets \ No newline at end of file diff --git a/setup.py b/setup.py index f5996c4..b23c5b6 100644 --- a/setup.py +++ b/setup.py @@ -133,6 +133,11 @@ def build_extension(self, ext: Extension): depends=["picows/compat.h"], extra_compile_args=extra_compile_args, extra_link_args=extra_link_args), + Extension("picows.websockets.asyncio.connection", ["picows/websockets/asyncio/connection.py"], + libraries=libs, define_macros=macros, + depends=["picows/compat.h"], + extra_compile_args=extra_compile_args, + extra_link_args=extra_link_args), ] if with_examples: diff --git a/tests/test_websockets_client.py b/tests/test_websockets_client.py new file mode 100644 index 0000000..04e5b6d --- /dev/null +++ b/tests/test_websockets_client.py @@ -0,0 +1,168 @@ +import asyncio +import socket + +import pytest +from multidict import CIMultiDict + +import picows +from picows import websockets +from picows.websockets.asyncio.client import _process_proxy +from picows.websockets.asyncio.connection import _normalize_watermarks, _resolve_logger, process_exception +from tests.utils import ServerEchoListener, WSServer + + +def test_client_private_option_helpers(): + assert _process_proxy(None, False) is None + assert _process_proxy("http://127.0.0.1:8080", False) == "http://127.0.0.1:8080" + with pytest.raises(websockets.InvalidProxy): + _process_proxy(123, False) # type: ignore[arg-type] + + assert _normalize_watermarks(None) == (0, 0) + assert _normalize_watermarks((None, 1)) == (0, 0) + assert _normalize_watermarks((8, None)) == (8, 2) + assert _resolve_logger("picows.test").name == "picows.test" + + +async def test_client_connection_starts_in_connecting_state(): + connection = websockets.ClientConnection( + request=websockets.Request("/", CIMultiDict()), + response=websockets.Response(101, "Switching Protocols", CIMultiDict(), b""), + subprotocol=None, + permessage_deflate=None, + ) + + assert connection.state is websockets.State.CONNECTING + + +async def test_connect_await_style_and_socket_options(): + async with WSServer() as server: + ws = await websockets.connect(server.url, compression=None, ping_interval=None) + try: + assert ws.state is websockets.State.OPEN + assert ws.local_address[0] == "127.0.0.1" + assert ws.remote_address[0] == "127.0.0.1" + assert ws.latency == 0 + assert ws.subprotocol is None + assert ws.close_code is None + assert ws.close_reason is None + await ws.send("awaited") + assert await ws.recv() == "awaited" + finally: + await ws.close() + + sock = socket.create_connection((server.host, server.port)) + ws = await websockets.connect(server.url, compression=None, ping_interval=None, sock=sock) + try: + await ws.send(b"sock") + assert await ws.recv() == b"sock" + finally: + await ws.close() + + async with websockets.connect( + "ws://example.invalid/", + compression=None, + ping_interval=None, + proxy=None, + host=server.host, + port=server.port, + ) as ws: + await ws.send("override") + assert await ws.recv() == "override" + + +async def test_connect_rejects_conflicting_and_invalid_socket_options(): + async with WSServer() as server: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + try: + with pytest.raises(TypeError, match="cannot pass both sock and socket_factory"): + await websockets.connect(server.url, compression=None, sock=sock, socket_factory=lambda _: None) + finally: + sock.close() + + with pytest.raises(TypeError, match="sock must be a socket.socket instance"): + await websockets.connect(server.url, compression=None, sock=object()) + + with pytest.raises(TypeError, match="cannot pass both host/port override and socket_factory"): + await websockets.connect( + server.url, + compression=None, + host=server.host, + socket_factory=lambda _: None, + ) + + +async def test_connect_rejects_invalid_ssl_options_before_network(): + with pytest.raises(NotImplementedError, match="ssl=False"): + await websockets.connect("wss://example.com/", compression=None, ssl=False) + with pytest.raises(TypeError, match="ssl must be"): + await websockets.connect("wss://example.com/", compression=None, ssl=object()) + + +def test_process_exception_retries_transient_failures(): + assert process_exception(EOFError()) is None + assert process_exception(OSError()) is None + assert process_exception(asyncio.TimeoutError()) is None + + response = websockets.Response(503, "Service Unavailable", CIMultiDict(), b"") + assert process_exception(websockets.InvalidStatus(response)) is None + + error = RuntimeError("boom") + assert process_exception(error) is error + + +async def test_connect_async_iterator_retries_then_succeeds(): + attempts = 0 + + def process_exception(exc: Exception) -> Exception | None: + nonlocal attempts + attempts += 1 + if attempts == 1: + return None + return exc + + connector = websockets.connect( + "ws://127.0.0.1:1/", + compression=None, + open_timeout=0.01, + process_exception=process_exception, + ) + connector._backoff = 0 + with pytest.raises(OSError): + async for _ws in connector: + pass + assert attempts == 2 + + +def test_connect_rejects_create_connection(): + with pytest.raises(NotImplementedError): + websockets.connect("ws://example.com", create_connection=websockets.ClientConnection) + + +async def test_connection_context_manager_and_close_timeout_none(): + async with WSServer() as server: + ws = await websockets.connect(server.url, compression=None, ping_interval=None, close_timeout=None) + async with ws: + await ws.send("context") + assert await ws.recv() == "context" + assert ws.state is websockets.State.CLOSED + + +class IgnoreCloseListener(ServerEchoListener): + def on_ws_frame(self, transport: picows.WSTransport, frame: picows.WSFrame): + if frame.msg_type == picows.WSMsgType.CLOSE: + return + super().on_ws_frame(transport, frame) + + +async def test_close_timeout_disconnects_when_peer_ignores_close(): + async with WSServer(lambda _: IgnoreCloseListener()) as server: + ws = await websockets.connect( + server.url, + compression=None, + ping_interval=None, + close_timeout=0.01, + ) + await ws.close() + assert ws.state is websockets.State.CLOSED + assert ws.close_code == 1000 + assert ws.close_reason == "" diff --git a/tests/test_websockets_compat_types.py b/tests/test_websockets_compat_types.py new file mode 100644 index 0000000..abdb948 --- /dev/null +++ b/tests/test_websockets_compat_types.py @@ -0,0 +1,20 @@ +from http import HTTPStatus + +from multidict import CIMultiDict + +from picows import websockets + + +def test_response_to_picows_supports_empty_body_and_status_alias(): + response = websockets.Response( + int(HTTPStatus.SWITCHING_PROTOCOLS), + HTTPStatus.SWITCHING_PROTOCOLS.phrase, + CIMultiDict({"X-Test": "yes"}), + bytearray(b"body"), + ) + + assert response.status == 101 + picows_response = response.to_picows() + assert picows_response.status is HTTPStatus.SWITCHING_PROTOCOLS + assert picows_response.headers["X-Test"] == "yes" + assert picows_response.body == b"body" diff --git a/tests/test_websockets_exceptions.py b/tests/test_websockets_exceptions.py new file mode 100644 index 0000000..20bc893 --- /dev/null +++ b/tests/test_websockets_exceptions.py @@ -0,0 +1,50 @@ +from dataclasses import dataclass + +from multidict import CIMultiDict + +from picows import websockets + + +@dataclass +class Close: + code: int + reason: str + + +def test_connection_closed_string_variants(): + assert str(websockets.ConnectionClosed(None, None)) == "no close frame received or sent" + assert str(websockets.ConnectionClosed(None, Close(1000, "bye"))) == "sent 1000 (bye)" + assert str(websockets.ConnectionClosed(Close(1001, "away"), None)) == "received 1001 (away)" + assert str(websockets.ConnectionClosed(Close(1001, "away"), Close(1000, "bye"), True)) == ( + "received then sent close frames: received 1001 (away), sent 1000 (bye)" + ) + assert str(websockets.ConnectionClosed(Close(1001, "away"), Close(1000, "bye"), False)) == ( + "sent then received close frames: received 1001 (away), sent 1000 (bye)" + ) + + +def test_exception_attributes_and_strings(): + assert str(websockets.InvalidURI("http://example.com", "wrong scheme")) == ( + "http://example.com isn't a valid WebSocket URI: wrong scheme" + ) + assert str(websockets.InvalidProxy("ftp://proxy", "wrong scheme")) == ( + "ftp://proxy isn't a valid proxy: wrong scheme" + ) + assert str(websockets.InvalidProxyStatus(object())) == "proxy rejected connection" + assert str(websockets.InvalidProxyStatus(websockets.Response(502, "Bad Gateway", CIMultiDict(), b""))) == ( + "proxy rejected connection: HTTP 502" + ) + + invalid_header = websockets.InvalidHeader("X-Test", "bad") + assert invalid_header.name == "X-Test" + assert invalid_header.value == "bad" + assert websockets.InvalidOrigin("https://bad.example").name == "Origin" + assert websockets.InvalidHeaderFormat("X-Test", "bad syntax", "x:y", 1).value == "bad syntax at 1 in x:y" + + assert str(websockets.DuplicateParameter("server_max_window_bits")) == ( + "duplicate parameter: server_max_window_bits" + ) + assert str(websockets.InvalidParameterName("x")) == "invalid parameter name: x" + assert str(websockets.InvalidParameterValue("x", None)) == "missing value for parameter x" + assert str(websockets.InvalidParameterValue("x", "")) == "empty value for parameter x" + assert str(websockets.InvalidParameterValue("x", "bad")) == "invalid value for parameter x: bad" diff --git a/tests/test_websockets_negotiation.py b/tests/test_websockets_negotiation.py new file mode 100644 index 0000000..cf027c8 --- /dev/null +++ b/tests/test_websockets_negotiation.py @@ -0,0 +1,312 @@ +import asyncio +import base64 +import hashlib +from contextlib import asynccontextmanager + +import picows +import pytest +import websockets as upstream_websockets +from multidict import CIMultiDict + +from picows import websockets +from picows.websockets.asyncio.connection import _PerMessageDeflate +from picows.websockets.asyncio.negotiation import configure_permessage_deflate, resolve_subprotocol +from tests.utils import WSServer + + +@asynccontextmanager +async def upstream_server(handler): + server = await upstream_websockets.serve( + handler, + "127.0.0.1", + 0, + compression="deflate", + ) + port = server.sockets[0].getsockname()[1] + try: + yield f"ws://127.0.0.1:{port}/" + finally: + server.close() + await server.wait_closed() + + +@asynccontextmanager +async def malformed_compressed_server(): + async def handler(reader: asyncio.StreamReader, writer: asyncio.StreamWriter): + request = await reader.readuntil(b"\r\n\r\n") + headers = request.decode("ascii").split("\r\n") + key = None + for header in headers: + if header.lower().startswith("sec-websocket-key:"): + key = header.split(":", 1)[1].strip() + break + assert key is not None + + accept = base64.b64encode( + hashlib.sha1( + (key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11").encode("ascii") + ).digest() + ).decode("ascii") + + response = ( + "HTTP/1.1 101 Switching Protocols\r\n" + "Upgrade: websocket\r\n" + "Connection: Upgrade\r\n" + f"Sec-WebSocket-Accept: {accept}\r\n" + "Sec-WebSocket-Extensions: permessage-deflate\r\n" + "\r\n" + ) + writer.write(response.encode("ascii")) + + payload = b"not-a-valid-deflate-stream" + frame = bytes([0xC1, len(payload)]) + payload + writer.write(frame) + await writer.drain() + await asyncio.sleep(0.1) + writer.close() + await writer.wait_closed() + + server = await asyncio.start_server(handler, "127.0.0.1", 0) + port = server.sockets[0].getsockname()[1] + try: + yield f"ws://127.0.0.1:{port}/" + finally: + server.close() + await server.wait_closed() + + +class Frame: + def __init__( + self, + msg_type: picows.WSMsgType, + payload: bytes, + *, + fin: bool = True, + rsv1: bool = False, + ): + self.msg_type = msg_type + self.payload = payload + self.fin = fin + self.rsv1 = rsv1 + + def get_payload_as_bytes(self) -> bytes: + return self.payload + + def get_payload_as_memoryview(self) -> memoryview: + return memoryview(self.payload) + + +async def test_subprotocol_header_and_property(): + request_headers = {} + + def listener_factory(request): + request_headers["value"] = request.headers.get("Sec-WebSocket-Protocol") + return None + + async with WSServer(listener_factory) as server: + with pytest.raises(websockets.InvalidStatus): + async with websockets.connect(server.url, compression=None, subprotocols=["chat"]): + pass + + assert request_headers["value"] == "chat" + + +async def test_serve_negotiates_subprotocol(): + async def handler(ws: websockets.ServerConnection) -> None: + await ws.wait_closed() + + async with websockets.serve( + handler, + "127.0.0.1", + 0, + compression=None, + subprotocols=["chat", "superchat"], + ) as server: + port = server.sockets[0].getsockname()[1] + async with websockets.connect( + f"ws://127.0.0.1:{port}/", + compression=None, + subprotocols=["superchat", "chat"], + ) as ws: + assert ws.subprotocol == "chat" + + +async def test_select_subprotocol_receives_handshake_connection(): + seen = {} + + async def handler(ws: websockets.ServerConnection) -> None: + await ws.wait_closed() + + def select_subprotocol( + ws: websockets.ServerHandshakeConnection, + offered: list[str], + ) -> str | None: + seen["type"] = type(ws) + seen["path"] = ws.request.path + seen["state"] = ws.state + seen["has_recv"] = hasattr(ws, "recv") + if "chat" in offered: + return "chat" + return None + + async with websockets.serve( + handler, + "127.0.0.1", + 0, + compression=None, + select_subprotocol=select_subprotocol, + ) as server: + port = server.sockets[0].getsockname()[1] + async with websockets.connect( + f"ws://127.0.0.1:{port}/room", + compression=None, + subprotocols=["chat"], + ) as ws: + assert ws.subprotocol == "chat" + + assert seen["type"] is websockets.ServerHandshakeConnection + assert seen["path"] == "/room" + assert seen["state"] is websockets.State.CONNECTING + assert seen["has_recv"] is False + + +def test_negotiation_rejects_invalid_subprotocol_and_extension_headers(): + response = websockets.Response(101, "Switching Protocols", CIMultiDict(), b"") + + response.headers["Sec-WebSocket-Protocol"] = 123 # type: ignore[assignment] + with pytest.raises(websockets.InvalidHandshake, match="non-string subprotocol"): + resolve_subprotocol(["chat"], response) + + response.headers["Sec-WebSocket-Protocol"] = "other" + with pytest.raises(websockets.InvalidHandshake, match="unsupported subprotocol"): + resolve_subprotocol(["chat"], response) + + response.headers.clear() + response.headers["Sec-WebSocket-Extensions"] = "permessage-deflate" + with pytest.raises(websockets.InvalidHandshake, match="unexpected websocket extensions"): + configure_permessage_deflate(response, None) + + response.headers["Sec-WebSocket-Extensions"] = 123 # type: ignore[assignment] + with pytest.raises(websockets.InvalidHandshake, match="invalid Sec-WebSocket-Extensions"): + configure_permessage_deflate(response, "deflate") + + +def test_permessage_deflate_rejects_invalid_parameters(): + response = websockets.Response(101, "Switching Protocols", CIMultiDict(), b"") + invalid_headers = [ + "x-webkit-deflate-frame", + "permessage-deflate, permessage-deflate", + "permessage-deflate; server_no_context_takeover=true", + "permessage-deflate; client_no_context_takeover=true", + "permessage-deflate; server_max_window_bits", + "permessage-deflate; server_max_window_bits=7", + "permessage-deflate; client_max_window_bits", + "permessage-deflate; client_max_window_bits=16", + "permessage-deflate; unknown=value", + "permessage-deflate; server_max_window_bits=15; server_max_window_bits=15", + ] + + for header in invalid_headers: + response.headers["Sec-WebSocket-Extensions"] = header + with pytest.raises(websockets.InvalidHandshake): + configure_permessage_deflate(response, "deflate") + + +def test_permessage_deflate_accepts_no_context_takeover_parameters(): + permessage_deflate = _PerMessageDeflate.from_response_header( + "permessage-deflate; server_no_context_takeover; client_no_context_takeover" + ) + + first = permessage_deflate.encode_frame(picows.WSMsgType.TEXT, "hello", True) + second = permessage_deflate.encode_frame(picows.WSMsgType.TEXT, b"hello", True) + + assert isinstance(first, memoryview) + assert isinstance(second, memoryview) + assert bytes(first) + assert bytes(second) + + +def test_permessage_deflate_decode_passthrough_and_protocol_error_branches(): + permessage_deflate = _PerMessageDeflate.from_response_header( + "permessage-deflate; server_no_context_takeover; client_no_context_takeover" + ) + + assert permessage_deflate.decode_frame( + Frame(picows.WSMsgType.TEXT, b"plain", rsv1=False), + 0, + ) == b"plain" + assert permessage_deflate.decode_frame( + Frame(picows.WSMsgType.CONTINUATION, b"continuation", rsv1=False), + 0, + ) == b"continuation" + + encoded = bytes(permessage_deflate.encode_frame(picows.WSMsgType.TEXT, b"compressed", True)) + assert permessage_deflate.decode_frame( + Frame(picows.WSMsgType.TEXT, encoded, rsv1=True), + 100, + ) == b"compressed" + + with pytest.raises(picows.WSProtocolError, match="unexpected rsv1"): + permessage_deflate.decode_frame( + Frame(picows.WSMsgType.CONTINUATION, b"bad", rsv1=True), + 0, + ) + + +async def test_permessage_deflate_echo_with_upstream_server(): + async def handler(ws): + async for message in ws: + await ws.send(message) + + async with upstream_server(handler) as url: + async with websockets.connect(url, ping_interval=None) as ws: + assert "permessage-deflate" in (ws.response.headers.get("Sec-WebSocket-Extensions") or "") + + message = "hello " * 1000 + await ws.send(message) + assert await ws.recv() == message + + +async def test_permessage_deflate_fragmented_send_with_upstream_server(): + async def handler(ws): + async for message in ws: + await ws.send(message) + + async with upstream_server(handler) as url: + async with websockets.connect(url, ping_interval=None) as ws: + await ws.send([b"a" * 300, b"b" * 300, b"c" * 300]) + assert await ws.recv() == (b"a" * 300 + b"b" * 300 + b"c" * 300) + + +async def test_permessage_deflate_recv_streaming_from_upstream_server(): + chunks = [b"ab" * 300, b"cd" * 300, b"ef" * 300] + + async def handler(ws): + await ws.send(chunks) + await ws.close() + + async with upstream_server(handler) as url: + async with websockets.connect(url, ping_interval=None) as ws: + fragments = [] + async for fragment in ws.recv_streaming(): + fragments.append(fragment) + assert fragments == chunks + [b""] + + +async def test_compressed_message_exceeding_max_size_closes_connection(): + async def handler(ws): + await ws.send("a" * 10000) + + async with upstream_server(handler) as url: + async with websockets.connect( + url, ping_interval=None, max_size=1000 + ) as ws: + with pytest.raises(websockets.ConnectionClosedError): + await ws.recv() + + +async def test_malformed_compressed_message_closes_connection(): + async with malformed_compressed_server() as url: + async with websockets.connect(url, ping_interval=None) as ws: + with pytest.raises(websockets.ConnectionClosedError): + await ws.recv() diff --git a/tests/test_websockets_ping_pong.py b/tests/test_websockets_ping_pong.py new file mode 100644 index 0000000..68e3bcb --- /dev/null +++ b/tests/test_websockets_ping_pong.py @@ -0,0 +1,80 @@ +import asyncio + +import pytest + +from picows import websockets +from tests.utils import WSServer + + +async def test_ping_returns_waiter(): + async with WSServer() as server: + async with websockets.connect(server.url, compression=None, ping_interval=None) as ws: + pong_waiter = await ws.ping(b"abcd") + latency = await asyncio.wait_for(pong_waiter, 1.0) + assert latency >= 0 + + +async def test_ping_accepts_byteslike_payloads(): + async with WSServer() as server: + async with websockets.connect(server.url, compression=None, ping_interval=None) as ws: + pong_waiter = await ws.ping(bytearray(b"abcd")) + await asyncio.wait_for(pong_waiter, 1.0) + pong_waiter = await ws.ping(memoryview(b"efgh")) + await asyncio.wait_for(pong_waiter, 1.0) + + +async def test_pong_accepts_byteslike_payloads(): + async with WSServer() as server: + async with websockets.connect(server.url, compression=None, ping_interval=None) as ws: + await ws.pong(bytearray(b"abcd")) + await ws.pong(memoryview(b"efgh")) + + +async def test_ping_default_duplicate_and_invalid_payloads(): + async with WSServer() as server: + async with websockets.connect(server.url, compression=None, ping_interval=None) as ws: + default_waiter = await ws.ping() + await asyncio.wait_for(default_waiter, 1.0) + + waiter = await ws.ping("same") + with pytest.raises(websockets.ConcurrencyError, match="same data"): + await ws.ping(b"same") + await asyncio.wait_for(waiter, 1.0) + + with pytest.raises(TypeError, match="ping payload"): + await ws.ping(object()) # type: ignore[arg-type] + + +async def test_keepalive_loop_without_ping_timeout_sends_default_pings(): + async with WSServer(enable_auto_pong=True) as server: + async with websockets.connect( + server.url, + compression=None, + ping_interval=0.01, + ping_timeout=None, + ) as ws: + await asyncio.sleep(0.03) + assert ws.latency >= 0 + + +async def test_keepalive_loop_with_ping_timeout_observes_pong(): + async with WSServer(enable_auto_pong=True) as server: + async with websockets.connect( + server.url, + compression=None, + ping_interval=0.01, + ping_timeout=1.0, + ) as ws: + await asyncio.sleep(0.03) + assert ws.latency >= 0 + + +async def test_ping_and_pong_after_close_raise_connection_closed(): + async with WSServer() as server: + ws = await websockets.connect(server.url, compression=None, ping_interval=None) + await ws.close() + + with pytest.raises(websockets.ConnectionClosedOK): + await ws.ping() + with pytest.raises(websockets.ConnectionClosedOK): + await ws.pong() diff --git a/tests/test_websockets_recv.py b/tests/test_websockets_recv.py new file mode 100644 index 0000000..5f00159 --- /dev/null +++ b/tests/test_websockets_recv.py @@ -0,0 +1,260 @@ +import asyncio + +import picows +import pytest + +from picows import websockets +from tests.utils import ServerEchoListener, WSServer + + +class SendTextOnConnect(ServerEchoListener): + def __init__(self, payload: bytes): + self._payload = payload + + def on_ws_connected(self, transport: picows.WSTransport): + super().on_ws_connected(transport) + transport.send(picows.WSMsgType.TEXT, self._payload) + + +class SendBinaryOnConnect(ServerEchoListener): + def __init__(self, payload: bytes): + self._payload = payload + + def on_ws_connected(self, transport: picows.WSTransport): + super().on_ws_connected(transport) + transport.send(picows.WSMsgType.BINARY, self._payload) + + +class FragmentedTextListener(ServerEchoListener): + def __init__(self, allow_first_fragment: asyncio.Event, allow_second_fragment: asyncio.Event): + self._allow_first_fragment = allow_first_fragment + self._allow_second_fragment = allow_second_fragment + + def on_ws_connected(self, transport: picows.WSTransport): + super().on_ws_connected(transport) + + async def send_fragments(): + await self._allow_first_fragment.wait() + transport.send(picows.WSMsgType.TEXT, b"he", fin=False) + await self._allow_second_fragment.wait() + transport.send(picows.WSMsgType.CONTINUATION, b"llo", fin=True) + + asyncio.create_task(send_fragments()) + + +class ContinuationOnConnect(ServerEchoListener): + def on_ws_connected(self, transport: picows.WSTransport): + super().on_ws_connected(transport) + transport.send(picows.WSMsgType.CONTINUATION, b"bad", fin=True) + + +class Rsv1OnConnect(ServerEchoListener): + def on_ws_connected(self, transport: picows.WSTransport): + super().on_ws_connected(transport) + transport.send(picows.WSMsgType.TEXT, b"bad", fin=True, rsv1=True) + + +class BadContinuationSequenceOnConnect(ServerEchoListener): + def on_ws_connected(self, transport: picows.WSTransport): + super().on_ws_connected(transport) + transport.send(picows.WSMsgType.TEXT, b"first", fin=False) + transport.send(picows.WSMsgType.TEXT, b"second", fin=True) + + +class SendLargeTextOnConnect(ServerEchoListener): + def on_ws_connected(self, transport: picows.WSTransport): + super().on_ws_connected(transport) + transport.send(picows.WSMsgType.TEXT, b"large") + + +class DelayedFragmentedTextOnConnect(ServerEchoListener): + def on_ws_connected(self, transport: picows.WSTransport): + super().on_ws_connected(transport) + + async def send_fragments(): + transport.send(picows.WSMsgType.TEXT, b"he", fin=False) + await asyncio.sleep(0.01) + transport.send(picows.WSMsgType.CONTINUATION, b"llo", fin=True) + + asyncio.create_task(send_fragments()) + + +class FragmentedTextOnConnect(ServerEchoListener): + def on_ws_connected(self, transport: picows.WSTransport): + super().on_ws_connected(transport) + transport.send(picows.WSMsgType.TEXT, b"he", fin=False) + transport.send(picows.WSMsgType.CONTINUATION, b"llo", fin=True) + + +async def test_async_iteration_closes_normally(): + async with WSServer() as server: + async with websockets.connect(server.url, compression=None) as ws: + await ws.send("hello") + assert await ws.recv() == "hello" + await ws.close() + + items = [] + async for item in ws: + items.append(item) + + assert items == [] + + +async def test_recv_streaming_fragmented_message(): + async with WSServer() as server: + async with websockets.connect(server.url, compression=None) as ws: + await ws.send([b"ab", b"cd"]) + fragments = [] + async for fragment in ws.recv_streaming(): + fragments.append(fragment) + assert fragments == [b"ab", b"cd", b""] + + +async def test_recv_cancellation_is_safe_for_fragmented_message(): + allow_first_fragment = asyncio.Event() + allow_second_fragment = asyncio.Event() + + async with WSServer(lambda _: FragmentedTextListener(allow_first_fragment, allow_second_fragment)) as server: + async with websockets.connect(server.url, compression=None) as ws: + recv_task = asyncio.create_task(ws.recv()) + allow_first_fragment.set() + await asyncio.sleep(0) + recv_task.cancel() + with pytest.raises(asyncio.CancelledError): + await recv_task + + allow_second_fragment.set() + assert await ws.recv() == "hello" + + +async def test_recv_streaming_cancellation_before_first_fragment_is_safe(): + allow_first_fragment = asyncio.Event() + allow_second_fragment = asyncio.Event() + + async with WSServer(lambda _: FragmentedTextListener(allow_first_fragment, allow_second_fragment)) as server: + async with websockets.connect(server.url, compression=None) as ws: + iterator = ws.recv_streaming() + recv_task = asyncio.create_task(anext(iterator)) + recv_task.cancel() + with pytest.raises(asyncio.CancelledError): + await recv_task + + allow_first_fragment.set() + allow_second_fragment.set() + assert await ws.recv() == "hello" + + +async def test_recv_streaming_partial_consumption_breaks_future_receives(): + allow_first_fragment = asyncio.Event() + allow_second_fragment = asyncio.Event() + + async with WSServer(lambda _: FragmentedTextListener(allow_first_fragment, allow_second_fragment)) as server: + async with websockets.connect(server.url, compression=None) as ws: + iterator = ws.recv_streaming() + allow_first_fragment.set() + assert await anext(iterator) == "he" + + with pytest.raises(websockets.ConcurrencyError): + await ws.recv() + + allow_second_fragment.set() + + with pytest.raises(websockets.ConcurrencyError): + await ws.recv() + + +async def test_recv_decode_false_returns_bytes_for_text_messages(): + async with WSServer(lambda _: SendTextOnConnect(b"hello")) as server: + async with websockets.connect(server.url, compression=None) as ws: + assert await ws.recv(decode=False) == b"hello" + + +async def test_recv_decode_true_returns_text_for_binary_messages(): + async with WSServer(lambda _: SendBinaryOnConnect(b"hello")) as server: + async with websockets.connect(server.url, compression=None) as ws: + assert await ws.recv(decode=True) == "hello" + + +async def test_recv_streaming_decode_false_returns_bytes_for_text_messages(): + async with WSServer(lambda _: SendTextOnConnect(b"hello")) as server: + async with websockets.connect(server.url, compression=None) as ws: + fragments = [] + async for fragment in ws.recv_streaming(decode=False): + fragments.append(fragment) + assert fragments == [b"hello"] + + +async def test_recv_streaming_decode_true_returns_text_for_binary_messages(): + async with WSServer(lambda _: SendBinaryOnConnect(b"hello")) as server: + async with websockets.connect(server.url, compression=None) as ws: + fragments = [] + async for fragment in ws.recv_streaming(decode=True): + fragments.append(fragment) + assert fragments == ["hello"] + + +async def test_recv_decode_true_invalid_utf8_closes_connection(): + async with WSServer(lambda _: SendBinaryOnConnect(b"\xff\xfe")) as server: + async with websockets.connect(server.url, compression=None) as ws: + with pytest.raises(websockets.ConnectionClosedError): + await ws.recv(decode=True) + assert ws.close_code == 1007 + + +async def test_recv_streaming_decode_true_invalid_utf8_closes_connection(): + async with WSServer(lambda _: SendBinaryOnConnect(b"\xff\xfe")) as server: + async with websockets.connect(server.url, compression=None) as ws: + with pytest.raises(websockets.ConnectionClosedError): + async for _fragment in ws.recv_streaming(decode=True): + pass + assert ws.close_code == 1007 + + +async def test_recv_rejects_unexpected_continuation_and_rsv1(): + async with WSServer(lambda _: ContinuationOnConnect()) as server: + async with websockets.connect(server.url, compression=None, ping_interval=None) as ws: + with pytest.raises(websockets.ConnectionClosedError): + await ws.recv() + + async with WSServer(lambda _: Rsv1OnConnect()) as server: + async with websockets.connect(server.url, compression=None, ping_interval=None) as ws: + with pytest.raises(websockets.ConnectionClosedError): + await ws.recv() + + +async def test_recv_rejects_bad_continuation_and_too_large_message(): + async with WSServer(lambda _: BadContinuationSequenceOnConnect()) as server: + async with websockets.connect(server.url, compression=None, ping_interval=None) as ws: + with pytest.raises(websockets.ConnectionClosedError): + await ws.recv() + + async with WSServer(lambda _: SendLargeTextOnConnect()) as server: + async with websockets.connect(server.url, compression=None, ping_interval=None, max_size=2) as ws: + with pytest.raises(websockets.ConnectionClosedError): + await ws.recv() + + +async def test_recv_streaming_waits_for_later_fragment(): + async with WSServer(lambda _: DelayedFragmentedTextOnConnect()) as server: + async with websockets.connect(server.url, compression=None, ping_interval=None) as ws: + fragments = [] + async for fragment in ws.recv_streaming(): + fragments.append(fragment) + assert fragments == ["he", "llo"] + + +async def test_fragmented_message_exceeding_max_size_closes_connection(): + async with WSServer(lambda _: FragmentedTextOnConnect()) as server: + async with websockets.connect(server.url, compression=None, ping_interval=None, max_size=4) as ws: + with pytest.raises(websockets.ConnectionClosedError): + await ws.recv() + + +async def test_disconnect_without_close_frame_sets_error_close_state(): + async with WSServer() as server: + async with websockets.connect(server.url, compression=None, ping_interval=None) as ws: + await ws.send("disconnect_me_without_close_frame") + with pytest.raises(websockets.ConnectionClosedError): + await ws.recv() + assert ws.close_code is None + assert ws.close_reason is None diff --git a/tests/test_websockets_send.py b/tests/test_websockets_send.py new file mode 100644 index 0000000..da3397d --- /dev/null +++ b/tests/test_websockets_send.py @@ -0,0 +1,209 @@ +import asyncio +from collections.abc import AsyncIterator + +import pytest + +from picows import websockets +from tests.utils import WSServer + + +class FragmentError(Exception): + pass + + +async def test_connect_send_recv_text(): + async with WSServer() as server: + async with websockets.connect(server.url, compression=None) as ws: + await ws.send("hello") + assert await ws.recv() == "hello" + + +async def test_connect_send_recv_binary(): + async with WSServer() as server: + async with websockets.connect(server.url, compression=None) as ws: + await ws.send(b"hello") + assert await ws.recv() == b"hello" + + +async def test_send_empty_iterable_is_noop(): + async with WSServer() as server: + async with websockets.connect(server.url, compression=None, ping_interval=None) as ws: + await ws.send([]) + pong_waiter = await ws.ping(b"noop") + await asyncio.wait_for(pong_waiter, 1.0) + + +async def test_send_empty_async_iterable_is_noop(): + async def fragments() -> AsyncIterator[bytes]: + if False: + yield b"never" + + async with WSServer() as server: + async with websockets.connect(server.url, compression=None, ping_interval=None) as ws: + await ws.send(fragments()) + pong_waiter = await ws.ping(b"noop") + await asyncio.wait_for(pong_waiter, 1.0) + + +async def test_send_rejects_dict_like_objects(): + async with WSServer() as server: + async with websockets.connect(server.url, compression=None) as ws: + with pytest.raises(TypeError, match="dict-like object"): + await ws.send({"a": 1}) + + +async def test_send_text_overrides_and_concurrent_send_waits_for_turn(): + async with WSServer() as server: + async with websockets.connect(server.url, compression=None, ping_interval=None) as ws: + await ws.send("as-binary", text=False) + assert await ws.recv(decode=False) == b"as-binary" + + await ws.send(b"as-text", text=True) + assert await ws.recv() == "as-text" + + first_sent = asyncio.Event() + release = asyncio.Event() + + async def fragments(): + yield b"first" + first_sent.set() + await release.wait() + yield b"second" + + first_send = asyncio.create_task(ws.send(fragments())) + await asyncio.wait_for(first_sent.wait(), 1.0) + + second_send = asyncio.create_task(ws.send(b"after")) + await asyncio.sleep(0) + assert not second_send.done() + + release.set() + await asyncio.wait_for(first_send, 1.0) + await asyncio.wait_for(second_send, 1.0) + assert await ws.recv() == b"firstsecond" + assert await ws.recv() == b"after" + + +async def test_send_string_fragments_and_write_pause_wait_paths(): + async with WSServer() as server: + async with websockets.connect(server.url, compression=None, ping_interval=None) as ws: + await ws.send(["he", "llo"]) + assert await ws.recv() == "hello" + + ws.pause_writing() + single_send = asyncio.create_task(ws.send(b"paused")) + await asyncio.sleep(0) + assert not single_send.done() + ws.resume_writing() + await asyncio.wait_for(single_send, 1.0) + assert await ws.recv() == b"paused" + + ws.pause_writing() + fragmented_send = asyncio.create_task(ws.send([b"frag", b"mented"])) + await asyncio.sleep(0) + assert not fragmented_send.done() + ws.resume_writing() + await asyncio.wait_for(fragmented_send, 1.0) + assert await ws.recv() == b"fragmented" + + +async def test_send_rejects_invalid_first_fragment_and_closes(): + async with WSServer() as server: + async with websockets.connect(server.url, compression=None, ping_interval=None) as ws: + with pytest.raises(TypeError, match="message must contain"): + await ws.send([object()]) # type: ignore[list-item] + await asyncio.wait_for(ws.wait_closed(), 1.0) + assert ws.state is websockets.State.CLOSED + + +async def test_send_rejects_unsupported_object_and_mixed_fragments(): + async with WSServer() as server: + async with websockets.connect(server.url, compression=None, ping_interval=None) as ws: + with pytest.raises(TypeError, match="unsupported type"): + await ws.send(object()) # type: ignore[arg-type] + with pytest.raises(TypeError, match="same category"): + await ws.send([b"bytes", "text"]) + + +async def test_send_sync_iterable_exception_closes_connection(): + async with WSServer() as server: + async with websockets.connect(server.url, compression=None, ping_interval=None) as ws: + def fragments(): + yield b"first" + raise FragmentError("boom") + + with pytest.raises(FragmentError, match="boom"): + await ws.send(fragments()) + + await asyncio.wait_for(ws.wait_closed(), 1.0) + with pytest.raises(websockets.ConnectionClosed): + await ws.send(b"x") + + +async def test_send_async_iterable_exception_closes_connection(): + async with WSServer() as server: + async with websockets.connect(server.url, compression=None, ping_interval=None) as ws: + async def fragments() -> AsyncIterator[bytes]: + yield b"first" + raise FragmentError("boom") + + with pytest.raises(FragmentError, match="boom"): + await ws.send(fragments()) + + await asyncio.wait_for(ws.wait_closed(), 1.0) + with pytest.raises(websockets.ConnectionClosed): + await ws.send(b"x") + + +async def test_send_async_iterable_cancellation_closes_connection(): + async with WSServer() as server: + async with websockets.connect(server.url, compression=None, ping_interval=None) as ws: + first_sent = asyncio.Event() + unblock = asyncio.Event() + + async def fragments() -> AsyncIterator[bytes]: + yield b"first" + first_sent.set() + await unblock.wait() + yield b"second" + + send_task = asyncio.create_task(ws.send(fragments())) + await asyncio.wait_for(first_sent.wait(), 1.0) + send_task.cancel() + + with pytest.raises(asyncio.CancelledError): + await send_task + + await asyncio.wait_for(ws.wait_closed(), 1.0) + with pytest.raises(websockets.ConnectionClosed): + await ws.send(b"x") + + +async def test_close_during_fragmented_send_closes_connection(): + async with WSServer() as server: + async with websockets.connect(server.url, compression=None, ping_interval=None) as ws: + first_sent = asyncio.Event() + unblock = asyncio.Event() + + async def fragments() -> AsyncIterator[bytes]: + yield b"first" + first_sent.set() + await unblock.wait() + yield b"second" + + send_task = asyncio.create_task(ws.send(fragments())) + await asyncio.wait_for(first_sent.wait(), 1.0) + await ws.close() + unblock.set() + + with pytest.raises(websockets.ConnectionClosed): + await send_task + + +async def test_send_after_close_raises_connection_closed(): + async with WSServer() as server: + ws = await websockets.connect(server.url, compression=None, ping_interval=None) + await ws.close() + + with pytest.raises(websockets.ConnectionClosedOK): + await ws.send(b"closed") diff --git a/tests/test_websockets_server.py b/tests/test_websockets_server.py new file mode 100644 index 0000000..8199640 --- /dev/null +++ b/tests/test_websockets_server.py @@ -0,0 +1,149 @@ +import asyncio +import sys + +import pytest + +from picows import websockets + + +async def test_serve_echo_roundtrip(): + async def handler(ws: websockets.ServerConnection) -> None: + assert ws.request.path == "/" + assert ws.response.status_code == 101 + message = await ws.recv() + await ws.send(message) + + async with websockets.serve(handler, "127.0.0.1", 0, compression=None) as server: + port = server.sockets[0].getsockname()[1] + async with websockets.connect(f"ws://127.0.0.1:{port}/", compression=None) as ws: + await ws.send("hello") + assert await ws.recv() == "hello" + + +async def test_serve_rejects_create_connection(): + async def handler(ws: websockets.ServerConnection) -> None: + raise AssertionError("handler must not be called") + + with pytest.raises(NotImplementedError): + await websockets.serve( + handler, + "127.0.0.1", + 0, + compression=None, + create_connection=websockets.ServerConnection, + ) + + +async def test_broadcast_sends_to_open_connections(): + connections: list[websockets.ServerConnection] = [] + + async def handler(ws: websockets.ServerConnection) -> None: + connections.append(ws) + await ws.wait_closed() + + async with websockets.serve(handler, "127.0.0.1", 0, compression=None) as server: + port = server.sockets[0].getsockname()[1] + async with websockets.connect(f"ws://127.0.0.1:{port}/", compression=None) as ws1: + async with websockets.connect(f"ws://127.0.0.1:{port}/", compression=None) as ws2: + while len(connections) < 2: + await asyncio.sleep(0) + websockets.broadcast(connections, "hi") + assert await ws1.recv() == "hi" + assert await ws2.recv() == "hi" + + +def test_broadcast_validation_and_exception_group(): + with pytest.raises(TypeError, match="data must be str or bytes"): + websockets.broadcast([], object()) # type: ignore[arg-type] + + if sys.version_info[:2] < (3, 11): + with pytest.raises(ValueError, match="requires at least Python 3.11"): + websockets.broadcast([], "hello", raise_exceptions=True) + return + + class BrokenConnection: + state = websockets.State.OPEN + _send_in_progress = False + + def _encode_and_send(self, _msg_type, _message, _fin): + raise RuntimeError("broken") + + with pytest.raises(ExceptionGroup) as exc_info: + websockets.broadcast([BrokenConnection()], "hello", raise_exceptions=True) # type: ignore[list-item] + assert len(exc_info.value.exceptions) == 1 + + +async def test_server_connections_tracks_open_connections(): + connected = asyncio.Event() + + async def handler(ws: websockets.ServerConnection) -> None: + connected.set() + await ws.wait_closed() + + async with websockets.serve(handler, "127.0.0.1", 0, compression=None) as server: + port = server.sockets[0].getsockname()[1] + assert server.connections == set() + async with websockets.connect(f"ws://127.0.0.1:{port}/", compression=None): + await connected.wait() + assert len(server.connections) == 1 + await asyncio.sleep(0) + assert server.connections == set() + + +async def test_handler_exception_closes_connection_with_internal_error(): + async def handler(ws: websockets.ServerConnection) -> None: + raise RuntimeError("boom") + + async with websockets.serve(handler, "127.0.0.1", 0, compression=None) as server: + port = server.sockets[0].getsockname()[1] + async with websockets.connect(f"ws://127.0.0.1:{port}/", compression=None) as ws: + with pytest.raises(websockets.ConnectionClosedError): + await ws.recv() + assert ws.close_code == 1011 + + +async def test_server_close_closes_existing_connections(): + started = asyncio.Event() + + async def handler(ws: websockets.ServerConnection) -> None: + started.set() + await ws.wait_closed() + + async with websockets.serve(handler, "127.0.0.1", 0, compression=None) as server: + port = server.sockets[0].getsockname()[1] + async with websockets.connect(f"ws://127.0.0.1:{port}/", compression=None) as ws: + await started.wait() + server.close(reason="bye") + with pytest.raises(websockets.ConnectionClosedOK): + await ws.recv() + assert ws.close_code == 1001 + assert ws.close_reason == "bye" + await server.wait_closed() + + +async def test_wait_closed_waits_for_handler_completion(): + started = asyncio.Event() + finish = asyncio.Event() + finished = asyncio.Event() + + async def handler(ws: websockets.ServerConnection) -> None: + started.set() + await finish.wait() + finished.set() + + server = await websockets.serve(handler, "127.0.0.1", 0, compression=None) + port = server.sockets[0].getsockname()[1] + async with websockets.connect(f"ws://127.0.0.1:{port}/", compression=None): + await started.wait() + server.close(close_connections=False) + waiter = asyncio.create_task(server.wait_closed()) + await asyncio.sleep(0) + assert not waiter.done() + finish.set() + await waiter + assert finished.is_set() + + +def test_route_requires_werkzeug(): + with pytest.raises((ImportError, NotImplementedError)): + websockets.route(None) # type: ignore[arg-type] diff --git a/tests/test_websockets_server_handshake.py b/tests/test_websockets_server_handshake.py new file mode 100644 index 0000000..8553b7b --- /dev/null +++ b/tests/test_websockets_server_handshake.py @@ -0,0 +1,202 @@ +import base64 +import re + +import pytest +from multidict import CIMultiDict + +from picows import websockets +from picows.websockets.asyncio.server import _parse_basic_authorization + + +async def test_serve_process_request_can_reject_handshake(): + async def handler(ws: websockets.ServerConnection) -> None: + raise AssertionError("handler must not be called") + + def process_request( + ws: websockets.ServerHandshakeConnection, + request: websockets.Request, + ) -> websockets.Response | None: + assert ws.request is request + return websockets.Response( + status_code=418, + reason_phrase="I'm a Teapot", + headers=CIMultiDict({"X-Test": "yes"}), + body=b"nope", + ) + + async with websockets.serve( + handler, + "127.0.0.1", + 0, + compression=None, + process_request=process_request, + ) as server: + port = server.sockets[0].getsockname()[1] + with pytest.raises(websockets.InvalidStatus) as exc_info: + async with websockets.connect(f"ws://127.0.0.1:{port}/", compression=None): + pass + assert int(exc_info.value.response.status) == 418 + + +async def test_serve_process_response_can_mutate_handshake_response(): + async def handler(ws: websockets.ServerConnection) -> None: + await ws.wait_closed() + + def process_response( + ws: websockets.ServerHandshakeConnection, + request: websockets.Request, + response: websockets.Response, + ) -> websockets.Response: + assert ws.request is request + response.headers["X-Handshake"] = "yes" + return response + + async with websockets.serve( + handler, + "127.0.0.1", + 0, + compression=None, + process_response=process_response, + ) as server: + port = server.sockets[0].getsockname()[1] + async with websockets.connect(f"ws://127.0.0.1:{port}/", compression=None) as ws: + assert ws.response.headers["X-Handshake"] == "yes" + + +async def test_serve_rejects_async_process_request(): + async def handler(ws: websockets.ServerConnection) -> None: + raise AssertionError("handler must not be called") + + async def process_request( + ws: websockets.ServerHandshakeConnection, + request: websockets.Request, + ) -> websockets.Response | None: + return None + + async with websockets.serve( + handler, + "127.0.0.1", + 0, + compression=None, + process_request=process_request, + ) as server: + port = server.sockets[0].getsockname()[1] + with pytest.raises(websockets.InvalidStatus): + async with websockets.connect(f"ws://127.0.0.1:{port}/", compression=None): + pass + + +async def test_serve_accepts_allowed_origin(): + async def handler(ws: websockets.ServerConnection) -> None: + await ws.send("ok") + + async with websockets.serve( + handler, + "127.0.0.1", + 0, + compression=None, + origins=["https://example.com"], + ) as server: + port = server.sockets[0].getsockname()[1] + async with websockets.connect( + f"ws://127.0.0.1:{port}/", + compression=None, + origin="https://example.com", + ) as ws: + assert await ws.recv() == "ok" + + +async def test_serve_rejects_disallowed_origin(): + async def handler(ws: websockets.ServerConnection) -> None: + raise AssertionError("handler must not be called") + + async with websockets.serve( + handler, + "127.0.0.1", + 0, + compression=None, + origins=[re.compile(r"https://allowed\\.example\\.com")], + ) as server: + port = server.sockets[0].getsockname()[1] + with pytest.raises(websockets.InvalidStatus): + async with websockets.connect( + f"ws://127.0.0.1:{port}/", + compression=None, + origin="https://denied.example.com", + ): + pass + + +async def test_basic_auth_rejects_missing_credentials_and_sets_username(): + async def handler(ws: websockets.ServerConnection) -> None: + assert ws.username == "hello" + await ws.send(ws.username) + + async with websockets.serve( + handler, + "127.0.0.1", + 0, + compression=None, + process_request=websockets.basic_auth( + realm="test", + credentials=("hello", "secret"), + ), + ) as server: + port = server.sockets[0].getsockname()[1] + + with pytest.raises(websockets.InvalidStatus) as exc_info: + async with websockets.connect(f"ws://127.0.0.1:{port}/", compression=None): + pass + assert int(exc_info.value.response.status) == 401 + assert exc_info.value.response.headers["WWW-Authenticate"] == 'Basic realm="test"' + + token = base64.b64encode(b"hello:secret").decode() + async with websockets.connect( + f"ws://127.0.0.1:{port}/", + compression=None, + additional_headers={"Authorization": f"Basic {token}"}, + ) as ws: + assert await ws.recv() == "hello" + + +def test_basic_auth_argument_validation_and_malformed_headers(): + with pytest.raises(ValueError, match="provide either credentials or check_credentials"): + websockets.basic_auth() + with pytest.raises(ValueError, match="provide either credentials or check_credentials"): + websockets.basic_auth(credentials=("a", "b"), check_credentials=lambda _u, _p: True) + with pytest.raises(TypeError, match="invalid credentials argument"): + websockets.basic_auth(credentials=("a", "b", "c")) # type: ignore[arg-type] + with pytest.raises(TypeError, match="invalid credentials argument"): + websockets.basic_auth(credentials=[("a", object())]) # type: ignore[list-item] + + with pytest.raises(ValueError, match="unsupported authorization scheme"): + _parse_basic_authorization("Bearer token") + with pytest.raises(ValueError, match="invalid basic authorization header"): + _parse_basic_authorization("Basic !!!") + with pytest.raises(ValueError, match="invalid basic authorization header"): + _parse_basic_authorization("Basic bm9jb2xvbg==") + + +async def test_basic_auth_rejects_bad_and_async_credentials(): + async def handler(_ws: websockets.ServerConnection) -> None: + raise AssertionError("handler must not be called") + + async def check_credentials(_username: str, _password: str) -> bool: + return True + + async with websockets.serve( + handler, + "127.0.0.1", + 0, + compression=None, + process_request=websockets.basic_auth(check_credentials=check_credentials), + ) as server: + port = server.sockets[0].getsockname()[1] + token = "Basic " + "aW52YWxpZDpjcmVkZW50aWFscw==" + with pytest.raises(websockets.InvalidStatus): + async with websockets.connect( + f"ws://127.0.0.1:{port}/", + compression=None, + additional_headers={"Authorization": token}, + ): + pass