From 9471cabaa51b7a4880e8cd1a488c50d8d6d1ffc7 Mon Sep 17 00:00:00 2001 From: taras Date: Sat, 25 Apr 2026 23:39:41 +0200 Subject: [PATCH 01/16] logger_name can now accept a logger like object --- HISTORY.rst | 6 ++++++ picows/api.py | 33 ++++++++++++++++++++++++++------- tests/test_basics.py | 3 ++- tests/test_ws_logic.py | 13 +++++++++++++ 4 files changed, 47 insertions(+), 8 deletions(-) diff --git a/HISTORY.rst b/HISTORY.rst index d642f8b..a79b6e1 100644 --- a/HISTORY.rst +++ b/HISTORY.rst @@ -5,6 +5,12 @@ picows Release History :depth: 1 :local: +1.20.0 () +------------------ + +* ws_connect/ws_create_server logger_name parameter can now accept a logger-like object + + 1.19.0 (2026-04-24) ------------------ diff --git a/picows/api.py b/picows/api.py index 204a7ca..3846c17 100644 --- a/picows/api.py +++ b/picows/api.py @@ -4,7 +4,7 @@ from dataclasses import dataclass from functools import partial from inspect import isawaitable -from logging import getLogger +from logging import Logger, LoggerAdapter, getLogger from ssl import SSLContext from typing import Callable, Optional, Union, Dict, Any, Awaitable, cast @@ -20,6 +20,7 @@ WSListenerFactory = Callable[[], WSListener] WSServerListenerFactory = Callable[[WSUpgradeRequest], Union[WSListener, WSUpgradeResponseWithListener, None]] WSSocketFactory = Callable[[WSParsedURL], Union[Optional[socket.socket], Awaitable[Optional[socket.socket]]]] +WSLoggerLike = Union[str, Logger, LoggerAdapter[Any], None] _HAS_AIOFASTNET = False try: @@ -61,6 +62,20 @@ def _is_connected(sock: socket.socket) -> bool: except OSError: return False + +def _resolve_logger( + logger_name: WSLoggerLike, + default_suffix: str, + prefix: str = "picows." +) -> Union[Logger, LoggerAdapter[Any]]: + if logger_name is None: + return getLogger(f"{prefix}{default_suffix}") + + if isinstance(logger_name, str): + return getLogger(f"{prefix}{logger_name}") + + return logger_name + @dataclass class _ConnectedSocket: sock: Optional[socket.socket] @@ -172,7 +187,7 @@ async def ws_connect(ws_listener_factory: WSListenerFactory, # type: ignore [no- ssl_context: Optional[SSLContext] = None, disconnect_on_exception: bool = True, websocket_handshake_timeout: float = 5, - logger_name: str = "client", + logger_name: WSLoggerLike = None, enable_auto_ping: bool = False, auto_ping_idle_timeout: float = 10, auto_ping_reply_timeout: float = 10, @@ -206,7 +221,9 @@ async def ws_connect(ws_listener_factory: WSListenerFactory, # type: ignore [no- is the time in seconds to wait for the websocket client to receive websocket handshake response before aborting the connection. :param logger_name: - picows will use `picows.` logger to do all the logging. + Logger name suffix or logger-like object used for logging. + If a string is provided, picows will use `picows.`. + If ``None`` is provided, picows will use ``picows.client``. :param enable_auto_ping: Enable detection of a stale connection by periodically pinging remote peer. @@ -273,7 +290,7 @@ async def ws_connect(ws_listener_factory: WSListenerFactory, # type: ignore [no- # May sure people who are passing old argument are not going to get an exception kwargs.pop('zero_copy_unsafe_ssl_write', None) - logger = getLogger(f"picows.{logger_name}") + logger = _resolve_logger(logger_name, "client") parsed_url = parse_url(url) parsed_proxy_url = parse_url(proxy, False) if proxy is not None else None loop = asyncio.get_running_loop() @@ -342,7 +359,7 @@ async def ws_create_server(ws_listener_factory: WSServerListenerFactory, *, disconnect_on_exception: bool = True, websocket_handshake_timeout=5, - logger_name: str = "server", + logger_name: WSLoggerLike = None, enable_auto_ping: bool = False, auto_ping_idle_timeout: float = 20, auto_ping_reply_timeout: float = 20, @@ -389,7 +406,9 @@ async def ws_create_server(ws_listener_factory: WSServerListenerFactory, :param websocket_handshake_timeout: is the time in seconds to wait for the websocket server to receive websocket handshake request before aborting the connection. :param logger_name: - picows will use `picows.` logger to do all the logging. + Logger name suffix or logger-like object used for logging. + If a string is provided, picows will use `picows.`. + If ``None`` is provided, picows will use ``picows.server``. :param enable_auto_ping: Enable detection of a stale connection by periodically pinging remote peer. @@ -444,7 +463,7 @@ def ws_protocol_factory() -> WSProtocol: None, # ws_path False, # is_client_side ws_listener_factory, - getLogger(f"picows.{logger_name}"), + _resolve_logger(logger_name, "server"), disconnect_on_exception, websocket_handshake_timeout, enable_auto_ping, auto_ping_idle_timeout, auto_ping_reply_timeout, diff --git a/tests/test_basics.py b/tests/test_basics.py index 27ba9d0..4f35a9f 100644 --- a/tests/test_basics.py +++ b/tests/test_basics.py @@ -1,5 +1,6 @@ import asyncio import base64 +import logging import os import picows @@ -279,4 +280,4 @@ async def test_stress(use_aiofastnet, ssl_context): assert not client.is_paused -# \ No newline at end of file +# diff --git a/tests/test_ws_logic.py b/tests/test_ws_logic.py index 93fc2be..e9883f8 100644 --- a/tests/test_ws_logic.py +++ b/tests/test_ws_logic.py @@ -1,4 +1,5 @@ import asyncio +import logging import os import struct from concurrent.futures import ThreadPoolExecutor @@ -8,6 +9,7 @@ import pytest import picows +from picows.api import _resolve_logger from tests.utils import WSServer, WSClient, TIMEOUT from tests.fixtures import use_aiofastnet, ssl_context @@ -188,3 +190,14 @@ async def test_wrong_thread_assert(): with pytest.raises(RuntimeError, match="WSTransport.disconnect called from a wrong thread"): await loop.run_in_executor(executor, client.transport.disconnect) + + +def test_resolve_logger(): + logger = logging.getLogger("tests.picows.custom") + + assert _resolve_logger(None, "client") is logging.getLogger("picows.client") + assert _resolve_logger(None, "server") is logging.getLogger("picows.server") + assert _resolve_logger("custom", "client") is logging.getLogger("picows.custom") + assert _resolve_logger(logger, "client") is logger + + From 2f1e676ec89b7c8ff9c356899c92737b562e86b4 Mon Sep 17 00:00:00 2001 From: taras Date: Sun, 26 Apr 2026 00:22:56 +0200 Subject: [PATCH 02/16] Introduce new exception types for compatibility with websockets --- HISTORY.rst | 1 + docs/source/guides.rst | 13 ++++ docs/source/reference.rst | 12 ++++ picows/__init__.py | 8 +++ picows/picows.pyx | 122 +++++++++++++++++++++++++++++++++----- picows/types.py | 39 ++++++++++++ tests/test_ws_logic.py | 59 +++++++++++++++++- 7 files changed, 235 insertions(+), 19 deletions(-) diff --git a/HISTORY.rst b/HISTORY.rst index a79b6e1..3f67534 100644 --- a/HISTORY.rst +++ b/HISTORY.rst @@ -9,6 +9,7 @@ picows Release History ------------------ * ws_connect/ws_create_server logger_name parameter can now accept a logger-like object +* Introduce new exceptions: WSInvalidMessageError, WSInvalidStatusError, WSInvalidHeaderError, WSInvalidUpgradeError 1.19.0 (2026-04-24) diff --git a/docs/source/guides.rst b/docs/source/guides.rst index 50a12e1..83626ba 100644 --- a/docs/source/guides.rst +++ b/docs/source/guides.rst @@ -91,11 +91,24 @@ Additionally, websocket-specific failures are represented by :any:`WSError` and its subclasses: * :any:`WSHandshakeError` for HTTP upgrade negotiation failures (raised by :any:`ws_connect`). + More specific subclasses may be raised: + + * :any:`WSInvalidMessageError` for malformed HTTP upgrade responses. + * :any:`WSInvalidStatusError` when the HTTP response status isn't ``101 Switching Protocols``. + * :any:`WSInvalidHeaderError` for invalid handshake headers such as + ``Content-Length`` or ``Sec-WebSocket-Accept``. + * :any:`WSInvalidUpgradeError` for invalid ``Upgrade`` / ``Connection`` headers. + + Redirect-following failures in :any:`ws_connect` currently still raise the + base :any:`WSHandshakeError`. * :any:`WSProtocolError` for websocket parser/protocol violations (can be re-raised by :any:`WSTransport.wait_disconnected` on client side). * :any:`WSInvalidURL` for invalid websocket/proxy URL inputs. In general, :any:`WSError` is reserved for websocket-specific failures only. +Handshake timeouts are separate and currently raise `asyncio.TimeoutError`, +not :any:`WSError`. + There is also a special exception, `asyncio.CancelledError`, which any coroutine can raise when it is externally cancelled. Sometimes you need to handle this exception manually. For example, in a reconnection loop where you want to diff --git a/docs/source/reference.rst b/docs/source/reference.rst index 01cbc09..3ac92b1 100644 --- a/docs/source/reference.rst +++ b/docs/source/reference.rst @@ -20,6 +20,18 @@ Classes .. autoexception:: WSHandshakeError :show-inheritance: +.. autoexception:: WSInvalidMessageError + :show-inheritance: + +.. autoexception:: WSInvalidStatusError + :show-inheritance: + +.. autoexception:: WSInvalidHeaderError + :show-inheritance: + +.. autoexception:: WSInvalidUpgradeError + :show-inheritance: + .. autoexception:: WSProtocolError :show-inheritance: diff --git a/picows/__init__.py b/picows/__init__.py index e329d0b..e27785d 100644 --- a/picows/__init__.py +++ b/picows/__init__.py @@ -1,6 +1,10 @@ from .types import ( WSError, WSHandshakeError, + WSInvalidMessageError, + WSInvalidStatusError, + WSInvalidHeaderError, + WSInvalidUpgradeError, WSProtocolError, WSUpgradeRequest, WSUpgradeResponse, @@ -30,6 +34,10 @@ __all__ = [ 'WSError', 'WSHandshakeError', + 'WSInvalidMessageError', + 'WSInvalidStatusError', + 'WSInvalidHeaderError', + 'WSInvalidUpgradeError', 'WSProtocolError', 'WSUpgradeRequest', 'WSUpgradeResponse', diff --git a/picows/picows.pyx b/picows/picows.pyx index 7f62e12..41f5e63 100644 --- a/picows/picows.pyx +++ b/picows/picows.pyx @@ -24,7 +24,9 @@ from libc.stdlib cimport rand from .types import (PICOWS_DEBUG_LL, WSUpgradeRequest, WSUpgradeResponse, WSUpgradeResponseWithListener, - WSHandshakeError, WSProtocolError, add_extra_headers) + WSHandshakeError, WSInvalidMessageError, WSInvalidStatusError, + WSInvalidHeaderError, WSInvalidUpgradeError, + WSProtocolError, add_extra_headers) cdef: @@ -1337,45 +1339,133 @@ cdef class WSProtocol(WSProtocolBase, asyncio.BufferedProtocol): cdef list lines = raw_headers.split(b"\r\n") cdef bytes response_status_line = lines[0] + cdef str response_status_line_str + cdef bytes status_code + cdef bytes line, name, value + cdef str transfer_encoding + cdef object connection_value + cdef object upgrade_value + cdef object r_key + cdef Py_ssize_t content_length - cdef str response_status_line_str = response_status_line.decode().lower() + try: + response_status_line_str = response_status_line.decode().lower() + except UnicodeDecodeError: + raise WSInvalidMessageError( + "cannot upgrade, invalid HTTP status line in upgrade response", + raw_headers, + tail, + ) from None # check handshake if not response_status_line_str.startswith("http/1.1 " ): - raise WSHandshakeError(f"cannot upgrade, unknown protocol (expected HTTP/1.1) in upgrade response: {response_status_line_str}", raw_headers, tail) + raise WSInvalidMessageError( + f"cannot upgrade, unknown protocol (expected HTTP/1.1) in upgrade response: {response_status_line_str}", + raw_headers, + tail, + ) - cdef bytes status_code response = WSUpgradeResponse() - response.version, status_code, status_phrase = response_status_line.split(b" ", 2) - response.status = HTTPStatus(int(status_code.decode())) + try: + response.version, status_code, status_phrase = response_status_line.split(b" ", 2) + response.status = HTTPStatus(int(status_code.decode())) + except (ValueError, UnicodeDecodeError): + raise WSInvalidMessageError( + f"cannot upgrade, invalid HTTP status line in upgrade response: {response_status_line!r}", + raw_headers, + tail, + ) from None - cdef bytes line, name, value response.headers = CIMultiDict() for idx in range(1, len(lines)): line = lines[idx] - name, value = line.split(b":", 1) - response.headers.add((name.strip()).decode(), (value.strip()).decode()) + try: + name, value = line.split(b":", 1) + response.headers.add((name.strip()).decode(), (value.strip()).decode()) + except (ValueError, UnicodeDecodeError): + raise WSInvalidMessageError( + f"cannot upgrade, malformed header in upgrade response: {line!r}", + raw_headers, + tail, + response, + ) from None if response.status != HTTPStatus.SWITCHING_PROTOCOLS: - raise WSHandshakeError(f"expected upgrade response with status 101 Switching Protocols, but received {response.status}", raw_headers, tail, response) + raise WSInvalidStatusError( + f"expected upgrade response with status 101 Switching Protocols, but received {response.status}", + raw_headers, + tail, + response, + ) - if response.headers.get("transfer-encoding") == "chunked": - raise WSHandshakeError(f"101 response cannot have Transfer-Encoding but it has", raw_headers, tail, response) + transfer_encoding = response.headers.get("transfer-encoding") + if transfer_encoding == "chunked": + raise WSInvalidHeaderError( + "101 response cannot have Transfer-Encoding but it has", + "Transfer-Encoding", + transfer_encoding, + raw_headers, + tail, + response, + ) - cdef Py_ssize_t content_length = int(response.headers.get("content-length", "0")) + try: + content_length = int(response.headers.get("content-length", "0")) + except ValueError: + raise WSInvalidHeaderError( + "101 response has invalid Content-Length header", + "Content-Length", + response.headers.get("content-length"), + raw_headers, + tail, + response, + ) from None if content_length != 0: - raise WSHandshakeError(f"101 response has non-zero Content-Length, but it can't have body", raw_headers, tail, response) + raise WSInvalidHeaderError( + "101 response has non-zero Content-Length, but it can't have body", + "Content-Length", + response.headers.get("content-length"), + raw_headers, + tail, + response, + ) + + upgrade_value = response.headers.get("upgrade") + upgrade_value = upgrade_value if upgrade_value is None else upgrade_value.lower() + if upgrade_value != "websocket": + raise WSInvalidUpgradeError( + "cannot upgrade, invalid upgrade header", + "Upgrade", + response.headers.get("upgrade"), + raw_headers, + tail, + response, + ) connection_value = response.headers.get("connection") connection_value = connection_value if connection_value is None else connection_value.lower() if connection_value != "upgrade": - raise WSHandshakeError(f"cannot upgrade, invalid connection header: {response.headers['connection']}", raw_headers, tail, response) + raise WSInvalidUpgradeError( + "cannot upgrade, invalid connection header", + "Connection", + response.headers.get("connection"), + raw_headers, + tail, + response, + ) r_key = response.headers.get("sec-websocket-accept") match = b64encode(sha1(self._websocket_key_b64 + _WS_KEY).digest()).decode() if r_key != match: - raise WSHandshakeError(f"cannot upgrade, invalid sec-websocket-accept response", raw_headers, tail, response) + raise WSInvalidHeaderError( + "cannot upgrade, invalid sec-websocket-accept response", + "Sec-WebSocket-Accept", + response.headers.get("sec-websocket-accept"), + raw_headers, + tail, + response, + ) memmove(self._read_buffer.data, self._read_buffer.data + len(raw_headers) + 4, self._read_buffer.size - len(raw_headers) - 4) self._f_new_data_start_pos = len(tail) diff --git a/picows/types.py b/picows/types.py index 9df12d8..6a6179d 100644 --- a/picows/types.py +++ b/picows/types.py @@ -165,6 +165,45 @@ def __init__(self, description: str, self.response = response +class WSInvalidMessageError(WSHandshakeError): + """ + Raised when the HTTP handshake request or response is malformed. + """ + pass + + +class WSInvalidStatusError(WSHandshakeError): + """ + Raised when the HTTP handshake response status rejects the WebSocket upgrade. + """ + pass + + +class WSInvalidHeaderError(WSHandshakeError): + """ + Raised when a HTTP header in the WebSocket handshake is invalid. + """ + name: str + value: Optional[str] + + def __init__(self, description: str, + name: str, + value: Optional[str] = None, + raw_header: Optional[bytes] = None, + raw_body: Optional[bytes] = None, + response: Optional[WSUpgradeResponse] = None): + super().__init__(description, raw_header, raw_body, response) + self.name = name + self.value = value + + +class WSInvalidUpgradeError(WSInvalidHeaderError): + """ + Raised when Upgrade / Connection headers are invalid in the WebSocket handshake. + """ + pass + + class WSProtocolError(WSError): """ Raised when receiving or sending frames that break the protocol or diff --git a/tests/test_ws_logic.py b/tests/test_ws_logic.py index e9883f8..3f120af 100644 --- a/tests/test_ws_logic.py +++ b/tests/test_ws_logic.py @@ -3,6 +3,7 @@ import os import struct from concurrent.futures import ThreadPoolExecutor +from contextlib import asynccontextmanager from http import HTTPStatus import async_timeout @@ -10,10 +11,28 @@ import picows from picows.api import _resolve_logger -from tests.utils import WSServer, WSClient, TIMEOUT +from tests.utils import WSServer, WSClient, AsyncClient, TIMEOUT from tests.fixtures import use_aiofastnet, ssl_context +@asynccontextmanager +async def raw_handshake_server(response: bytes): + async def handle_client(reader: asyncio.StreamReader, writer: asyncio.StreamWriter): + await reader.readuntil(b"\r\n\r\n") + writer.write(response) + await writer.drain() + writer.close() + await writer.wait_closed() + + server = await asyncio.start_server(handle_client, "127.0.0.1", 0) + port = server.sockets[0].getsockname()[1] + try: + yield f"ws://127.0.0.1:{port}/" + finally: + server.close() + await server.wait_closed() + + async def test_send_external_bytearray_asserts(): async with WSServer() as server: async with WSClient(server) as client: @@ -192,6 +211,42 @@ async def test_wrong_thread_assert(): await loop.run_in_executor(executor, client.transport.disconnect) +async def test_handshake_invalid_status_error(): + response = ( + b"HTTP/1.1 404 Not Found\r\n" + b"Connection: close\r\n" + b"Content-Length: 0\r\n" + b"\r\n" + ) + async with raw_handshake_server(response) as url: + with pytest.raises(picows.WSInvalidStatusError): + await picows.ws_connect(AsyncClient, url) + + +async def test_handshake_invalid_upgrade_error(): + response = ( + b"HTTP/1.1 101 Switching Protocols\r\n" + b"Upgrade: not-websocket\r\n" + b"Connection: Upgrade\r\n" + b"Sec-WebSocket-Accept: invalid\r\n" + b"\r\n" + ) + async with raw_handshake_server(response) as url: + with pytest.raises(picows.WSInvalidUpgradeError, match="invalid upgrade header"): + await picows.ws_connect(AsyncClient, url) + + +async def test_handshake_invalid_message_error(): + response = ( + b"NOT-HTTP\r\n" + b"Header: value\r\n" + b"\r\n" + ) + async with raw_handshake_server(response) as url: + with pytest.raises(picows.WSInvalidMessageError): + await picows.ws_connect(AsyncClient, url) + + def test_resolve_logger(): logger = logging.getLogger("tests.picows.custom") @@ -199,5 +254,3 @@ def test_resolve_logger(): assert _resolve_logger(None, "server") is logging.getLogger("picows.server") assert _resolve_logger("custom", "client") is logging.getLogger("picows.custom") assert _resolve_logger(logger, "client") is logger - - From 31cb94211379eceff2f7af05faba5e76840fb56f Mon Sep 17 00:00:00 2001 From: taras Date: Sun, 26 Apr 2026 00:53:10 +0200 Subject: [PATCH 03/16] Allow websocket_handshake_timeout=None to disable handshake timeouts --- HISTORY.rst | 2 +- picows/api.py | 6 ++-- picows/picows.pyx | 15 +++++---- tests/test_ws_logic.py | 76 ++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 90 insertions(+), 9 deletions(-) diff --git a/HISTORY.rst b/HISTORY.rst index 3f67534..246d12b 100644 --- a/HISTORY.rst +++ b/HISTORY.rst @@ -9,9 +9,9 @@ picows Release History ------------------ * ws_connect/ws_create_server logger_name parameter can now accept a logger-like object +* ws_connect/ws_create_server websocket_handshake_timeout param can now accept None to disable handshake timeouts * Introduce new exceptions: WSInvalidMessageError, WSInvalidStatusError, WSInvalidHeaderError, WSInvalidUpgradeError - 1.19.0 (2026-04-24) ------------------ diff --git a/picows/api.py b/picows/api.py index 3846c17..746474c 100644 --- a/picows/api.py +++ b/picows/api.py @@ -186,7 +186,7 @@ async def ws_connect(ws_listener_factory: WSListenerFactory, # type: ignore [no- *, ssl_context: Optional[SSLContext] = None, disconnect_on_exception: bool = True, - websocket_handshake_timeout: float = 5, + websocket_handshake_timeout: Optional[float] = 5, logger_name: WSLoggerLike = None, enable_auto_ping: bool = False, auto_ping_idle_timeout: float = 10, @@ -220,6 +220,7 @@ async def ws_connect(ws_listener_factory: WSListenerFactory, # type: ignore [no- :param websocket_handshake_timeout: is the time in seconds to wait for the websocket client to receive websocket handshake response before aborting the connection. + Set to ``None`` to disable the timeout. :param logger_name: Logger name suffix or logger-like object used for logging. If a string is provided, picows will use `picows.`. @@ -358,7 +359,7 @@ async def ws_create_server(ws_listener_factory: WSServerListenerFactory, port=None, *, disconnect_on_exception: bool = True, - websocket_handshake_timeout=5, + websocket_handshake_timeout: Optional[float] = 5, logger_name: WSLoggerLike = None, enable_auto_ping: bool = False, auto_ping_idle_timeout: float = 20, @@ -405,6 +406,7 @@ async def ws_create_server(ws_listener_factory: WSServerListenerFactory, thrown by WSListener.on_ws_frame callback :param websocket_handshake_timeout: is the time in seconds to wait for the websocket server to receive websocket handshake request before aborting the connection. + Set to ``None`` to disable the timeout. :param logger_name: Logger name suffix or logger-like object used for logging. If a string is provided, picows will use `picows.`. diff --git a/picows/picows.pyx b/picows/picows.pyx index 41f5e63..6c3f796 100644 --- a/picows/picows.pyx +++ b/picows/picows.pyx @@ -977,11 +977,13 @@ cdef class WSProtocol(WSProtocolBase, asyncio.BufferedProtocol): if self.is_client_side: self.transport._send_http_handshake(self._ws_path, self._host_port, self._websocket_key_b64, self._extra_headers) - self._handshake_timeout_handle = self._loop.call_later( - self._handshake_timeout, self._handshake_timeout_callback) + if self._handshake_timeout is not None: + self._handshake_timeout_handle = self._loop.call_later( + self._handshake_timeout, self._handshake_timeout_callback) else: - self._handshake_timeout_handle = self._loop.call_later( - self._handshake_timeout, self._handshake_timeout_callback) + if self._handshake_timeout is not None: + self._handshake_timeout_handle = self._loop.call_later( + self._handshake_timeout, self._handshake_timeout_callback) def connection_lost(self, exc): self._logger.info("Disconnected") @@ -1185,8 +1187,9 @@ cdef class WSProtocol(WSProtocolBase, asyncio.BufferedProtocol): else: self.transport._send_http_handshake_response(response, accept_val) - self._handshake_timeout_handle.cancel() - self._handshake_timeout_handle = None + if self._handshake_timeout_handle is not None: + self._handshake_timeout_handle.cancel() + self._handshake_timeout_handle = None self._handshake_complete_future.set_result(None) self._invoke_on_ws_connected() self._last_data_time = picows_get_monotonic_time() diff --git a/tests/test_ws_logic.py b/tests/test_ws_logic.py index 3f120af..833a666 100644 --- a/tests/test_ws_logic.py +++ b/tests/test_ws_logic.py @@ -1,9 +1,11 @@ import asyncio +import base64 import logging import os import struct from concurrent.futures import ThreadPoolExecutor from contextlib import asynccontextmanager +from hashlib import sha1 from http import HTTPStatus import async_timeout @@ -33,6 +35,39 @@ async def handle_client(reader: asyncio.StreamReader, writer: asyncio.StreamWrit await server.wait_closed() +@asynccontextmanager +async def delayed_handshake_server(delay: float): + async def handle_client(reader: asyncio.StreamReader, writer: asyncio.StreamWriter): + request = await reader.readuntil(b"\r\n\r\n") + websocket_key = next( + line.split(b":", 1)[1].strip() + for line in request.split(b"\r\n") + if line.lower().startswith(b"sec-websocket-key:") + ) + accept = base64.b64encode( + sha1(websocket_key + b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11").digest() + ) + await asyncio.sleep(delay) + writer.write( + b"HTTP/1.1 101 Switching Protocols\r\n" + b"Upgrade: websocket\r\n" + b"Connection: Upgrade\r\n" + b"Sec-WebSocket-Accept: " + accept + b"\r\n" + b"\r\n" + ) + await writer.drain() + writer.close() + await writer.wait_closed() + + server = await asyncio.start_server(handle_client, "127.0.0.1", 0) + port = server.sockets[0].getsockname()[1] + try: + yield f"ws://127.0.0.1:{port}/" + finally: + server.close() + await server.wait_closed() + + async def test_send_external_bytearray_asserts(): async with WSServer() as server: async with WSClient(server) as client: @@ -247,6 +282,47 @@ async def test_handshake_invalid_message_error(): await picows.ws_connect(AsyncClient, url) +async def test_client_handshake_timeout_none(): + async with delayed_handshake_server(0.2) as url: + transport, _ = await picows.ws_connect( + AsyncClient, + url, + websocket_handshake_timeout=None, + ) + transport.disconnect(False) + await transport.wait_disconnected() + + +async def test_server_handshake_timeout_none(): + server = await picows.ws_create_server( + lambda _: picows.WSListener(), + "127.0.0.1", + 0, + websocket_handshake_timeout=None, + ) + port = server.sockets[0].getsockname()[1] + try: + reader, writer = await asyncio.open_connection("127.0.0.1", port) + await asyncio.sleep(0.2) + assert not reader.at_eof() + writer.write( + b"GET / HTTP/1.1\r\n" + b"Host: 127.0.0.1\r\n" + b"Upgrade: websocket\r\n" + b"Connection: Upgrade\r\n" + b"Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n" + b"Sec-WebSocket-Version: 13\r\n" + b"\r\n" + ) + response = await reader.readuntil(b"\r\n\r\n") + assert b"101 Switching Protocols" in response + writer.close() + await writer.wait_closed() + finally: + server.close() + await server.wait_closed() + + def test_resolve_logger(): logger = logging.getLogger("tests.picows.custom") From 9d6268b12e87d5d8e998f961a69b2e724d5836d6 Mon Sep 17 00:00:00 2001 From: taras Date: Sun, 26 Apr 2026 00:58:17 +0200 Subject: [PATCH 04/16] Fix for 3.9 and 3.10 --- picows/api.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/picows/api.py b/picows/api.py index 746474c..f5cd14e 100644 --- a/picows/api.py +++ b/picows/api.py @@ -6,7 +6,7 @@ from inspect import isawaitable from logging import Logger, LoggerAdapter, getLogger from ssl import SSLContext -from typing import Callable, Optional, Union, Dict, Any, Awaitable, cast +from typing import Callable, Optional, Union, Dict, Any, Awaitable, cast, TYPE_CHECKING from python_socks.async_.asyncio import Proxy @@ -20,7 +20,13 @@ WSListenerFactory = Callable[[], WSListener] WSServerListenerFactory = Callable[[WSUpgradeRequest], Union[WSListener, WSUpgradeResponseWithListener, None]] WSSocketFactory = Callable[[WSParsedURL], Union[Optional[socket.socket], Awaitable[Optional[socket.socket]]]] -WSLoggerLike = Union[str, Logger, LoggerAdapter[Any], None] + +if TYPE_CHECKING: + _WSLoggerAdapter = LoggerAdapter[Any] +else: + _WSLoggerAdapter = LoggerAdapter + +WSLoggerLike = Union[str, Logger, _WSLoggerAdapter, None] _HAS_AIOFASTNET = False try: @@ -67,7 +73,7 @@ def _resolve_logger( logger_name: WSLoggerLike, default_suffix: str, prefix: str = "picows." -) -> Union[Logger, LoggerAdapter[Any]]: +) -> Union[Logger, _WSLoggerAdapter]: if logger_name is None: return getLogger(f"{prefix}{default_suffix}") From 55facfc7bcb9e52b0c24bd1dcd548b3fcdd2a87c Mon Sep 17 00:00:00 2001 From: taras Date: Tue, 28 Apr 2026 22:57:13 +0200 Subject: [PATCH 05/16] Add missing fields --- picows/picows.pyi | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/picows/picows.pyi b/picows/picows.pyi index eea441e..75f765d 100644 --- a/picows/picows.pyi +++ b/picows/picows.pyi @@ -49,6 +49,12 @@ class WSFrame: @property def rsv1(self) -> bool: ... + @property + def rsv2(self) -> bool: ... + + @property + def rsv3(self) -> bool: ... + @property def last_in_buffer(self) -> bool: ... From fee8e417402545ae55c929cae03659380c6703d2 Mon Sep 17 00:00:00 2001 From: taras Date: Wed, 29 Apr 2026 01:33:47 +0200 Subject: [PATCH 06/16] Add WSCloseHandshake --- picows/picows.pxd | 22 +++++++++++++++++++--- picows/picows.pyi | 25 +++++++++++++++++++----- picows/picows.pyx | 48 +++++++++++++++++++++++++++++++++++++++++++++-- 3 files changed, 85 insertions(+), 10 deletions(-) diff --git a/picows/picows.pxd b/picows/picows.pxd index b0e3ec2..bdb7142 100644 --- a/picows/picows.pxd +++ b/picows/picows.pxd @@ -41,6 +41,19 @@ cpdef enum WSAutoPingStrategy: PING_PERIODICALLY = 2 +cdef class WSCloseInfo: + cdef: + readonly WSCloseCode code + readonly str reason + + +cdef class WSCloseHandshake: + cdef: + readonly WSCloseInfo recv + readonly WSCloseInfo sent + readonly bint recv_then_sent + + cdef class MemoryBuffer: cdef: Py_ssize_t size @@ -70,15 +83,17 @@ cdef class WSFrame: cpdef WSCloseCode get_close_code(self) cpdef bytes get_close_message(self) + cpdef str get_close_reason(self) cdef class WSTransport: cdef: object __weakref__ - readonly object underlying_transport #: asyncio.Transport - readonly object request #: WSUpgradeRequest - readonly object response #: WSUpgradeResponse + readonly object underlying_transport #: asyncio.Transport + readonly object request #: WSUpgradeRequest + readonly object response #: WSUpgradeResponse + readonly WSCloseHandshake close_handshake #: Optional[WSCloseHandshake] readonly bint is_client_side readonly bint is_secure readonly bint is_close_frame_sent @@ -88,6 +103,7 @@ cdef class WSTransport: object listener_proxy object disconnected_future #: asyncio.Future + object _loop object _logger #: Logger MemoryBuffer _write_buffer diff --git a/picows/picows.pyi b/picows/picows.pyi index 75f765d..b8b4e27 100644 --- a/picows/picows.pyi +++ b/picows/picows.pyi @@ -36,6 +36,17 @@ class WSAutoPingStrategy(Enum): PING_PERIODICALLY = 2 +class WSCloseInfo: + code: WSCloseCode + reason: str + + +class WSCloseHandshake: + recv: Optional[WSCloseInfo] + sent: Optional[WSCloseInfo] + recv_then_sent: bool + + class WSFrame: @property def tail_size(self) -> int: ... @@ -64,6 +75,7 @@ class WSFrame: def get_payload_as_memoryview(self) -> memoryview: ... def get_close_code(self) -> WSCloseCode: ... def get_close_message(self) -> bytes: ... + def get_close_reason(self) -> str: ... def __str__(self) -> str: ... @@ -72,19 +84,22 @@ class WSTransport: def underlying_transport(self) -> asyncio.Transport: ... @property - def is_client_side(self) -> bool: ... + def request(self) -> WSUpgradeRequest: ... @property - def is_secure(self) -> bool: ... + def response(self) -> WSUpgradeResponse: ... @property - def is_close_frame_sent(self) -> bool: ... + def close_handshake(self) -> WSCloseHandshake: ... @property - def request(self) -> WSUpgradeRequest: ... + def is_client_side(self) -> bool: ... @property - def response(self) -> WSUpgradeResponse: ... + def is_secure(self) -> bool: ... + + @property + def is_close_frame_sent(self) -> bool: ... def send( self, diff --git a/picows/picows.pyx b/picows/picows.pyx index 6c3f796..e75f4a9 100644 --- a/picows/picows.pyx +++ b/picows/picows.pyx @@ -199,6 +199,21 @@ cdef class WSFrame: else: return PyBytes_FromStringAndSize(self.payload_ptr + 2, self.payload_size - 2) + cpdef str get_close_reason(self): + """ + :return: a new str object with a close reason. If there is no close reason then returns None. + + This method is only valid for WSMsgType.CLOSE frames. + """ + + assert self.msg_type == WSMsgType.CLOSE, "get_close_message can be called only for CLOSE frames" + + if self.payload_size <= 2: + return None + else: + return PyUnicode_FromStringAndSize(self.payload_ptr + 2, + self.payload_size - 2) + def __str__(self): return (f"WSFrame({WSMsgType(self.msg_type).name}, fin={True if self.fin else False}, " f"rsv1={True if self.rsv1 else False}, " @@ -341,6 +356,7 @@ cdef class WSTransport: self.underlying_transport = underlying_transport self.request = None self.response = None + self.close_handshake = None self.is_client_side = is_client_side self.is_secure = underlying_transport.get_extra_info('ssl_object') is not None self.is_close_frame_sent = False @@ -496,6 +512,14 @@ cdef class WSTransport: if msg_type == WSMsgType.CLOSE: self.is_close_frame_sent = True + if self.close_handshake is None: + self.close_handshake = WSCloseHandshake.__new__(WSCloseHandshake) + self.close_handshake.recv_then_sent = False + + self.close_handshake.sent = WSCloseInfo.__new__(WSCloseInfo) + self.close_handshake.sent.code = ntohs((msg_ptr)[0]) + self.close_handshake.sent.reason = PyUnicode_FromStringAndSize(msg_ptr + 2, msg_size - 2) + cdef send_reuse_external_buffer(self, WSMsgType msg_type, char* msg_ptr, Py_ssize_t msg_size, bint fin=True, bint rsv1=False): @@ -1479,7 +1503,6 @@ cdef class WSProtocol(WSProtocolBase, asyncio.BufferedProtocol): return response cdef inline WSFrame _get_next_frame(self): - cdef WSFrame frame try: return self._get_next_frame_impl() except WSProtocolError as ex: @@ -1499,6 +1522,7 @@ cdef class WSProtocol(WSProtocolBase, asyncio.BufferedProtocol): uint8_t first_byte uint8_t second_byte WSFrame frame + WSCloseInfo recv if self._state == WSParserState.READ_HEADER: if self._f_new_data_start_pos - self._f_curr_state_start_pos < 2: @@ -1619,7 +1643,8 @@ cdef class WSProtocol(WSProtocolBase, asyncio.BufferedProtocol): self._state = WSParserState.READ_HEADER if frame.msg_type == WSMsgType.CLOSE: - if frame.get_close_code() < 3000 and frame.get_close_code() not in _ALLOWED_CLOSE_CODES: + close_code = frame.get_close_code() + if close_code < 3000 and close_code not in _ALLOWED_CLOSE_CODES: raise WSProtocolError(WSCloseCode.PROTOCOL_ERROR, f"Received CLOSE with invalid close code: {frame.get_close_code()}") @@ -1627,6 +1652,25 @@ cdef class WSProtocol(WSProtocolBase, asyncio.BufferedProtocol): raise WSProtocolError(WSCloseCode.PROTOCOL_ERROR, f"Received CLOSE with invalid close code size: {frame.fin} {frame.msg_type} {frame.get_payload_as_bytes()}") + recv = WSCloseInfo.__new__(WSCloseInfo) + recv.code = close_code + try: + recv.reason = frame.get_close_reason() + except UnicodeDecodeError: + raise WSProtocolError(WSCloseCode.PROTOCOL_ERROR, + f"Received CLOSE with invalid UTF-8 reason") + + if self.transport.close_handshake is None: + self.transport.close_handshake = WSCloseHandshake.__new__(WSCloseHandshake) + self.transport.close_handshake.recv = recv + self.transport.close_handshake.sent = None + self.transport.close_handshake.recv_then_sent = True + elif self.transport.close_handshake.recv is None: + self.transport.close_handshake.recv = recv + else: + raise WSProtocolError(WSCloseCode.PROTOCOL_ERROR, + f"Received CLOSE for the second time: {frame.get_close_code()}") + return frame assert False, "we should never reach this state" From 12c3c699f5497bff07de3d157a8f6ec728c0101f Mon Sep 17 00:00:00 2001 From: taras Date: Wed, 29 Apr 2026 01:42:41 +0200 Subject: [PATCH 07/16] Add basic tests --- picows/picows.pyx | 8 +++++ tests/test_ws_logic.py | 78 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 86 insertions(+) diff --git a/picows/picows.pyx b/picows/picows.pyx index e75f4a9..5a7269f 100644 --- a/picows/picows.pyx +++ b/picows/picows.pyx @@ -118,6 +118,14 @@ cdef _is_aiofn_transport(transport): return False +cdef class WSCloseInfo: + pass + + +cdef class WSCloseHandshake: + pass + + @cython.no_gc @cython.freelist(64) cdef class WSFrame: diff --git a/tests/test_ws_logic.py b/tests/test_ws_logic.py index 833a666..ed6546a 100644 --- a/tests/test_ws_logic.py +++ b/tests/test_ws_logic.py @@ -219,6 +219,84 @@ def on_ws_connected(self, transport: picows.WSTransport): await client.transport.wait_disconnected() +async def test_close_handshake_client_initiates_close(): + server_transport = None + + class ServerListener(picows.WSListener): + def on_ws_connected(self, transport: picows.WSTransport): + nonlocal server_transport + server_transport = transport + + def on_ws_frame(self, transport: picows.WSTransport, frame: picows.WSFrame): + if frame.msg_type == picows.WSMsgType.CLOSE: + transport.send_close(frame.get_close_code(), frame.get_close_message()) + transport.disconnect() + + async with WSServer(lambda _: ServerListener()) as server: + transport, _ = await picows.ws_connect(AsyncClient, server.url) + try: + transport.send_close(picows.WSCloseCode.OK, b"client says bye") + await transport.wait_disconnected() + finally: + try: + transport.disconnect(False) + await transport.wait_disconnected() + except Exception: + pass + + assert transport.close_handshake.sent.code == picows.WSCloseCode.OK + assert transport.close_handshake.sent.reason == "client says bye" + assert transport.close_handshake.recv.code == picows.WSCloseCode.OK + assert transport.close_handshake.recv.reason == "client says bye" + assert transport.close_handshake.recv_then_sent is False + + assert server_transport.close_handshake.sent.code == picows.WSCloseCode.OK + assert server_transport.close_handshake.sent.reason == "client says bye" + assert server_transport.close_handshake.recv.code == picows.WSCloseCode.OK + assert server_transport.close_handshake.recv.reason == "client says bye" + assert server_transport.close_handshake.recv_then_sent is True + + +async def test_close_handshake_server_initiates_close(): + server_transport = None + + class ServerListener(picows.WSListener): + def on_ws_connected(self, transport: picows.WSTransport): + nonlocal server_transport + server_transport = transport + transport.send_close(picows.WSCloseCode.GOING_AWAY, b"server shutdown") + asyncio.get_running_loop().call_later(0.05, transport.disconnect) + + class ClientListener(AsyncClient): + def on_ws_frame(self, transport: picows.WSTransport, frame: picows.WSFrame): + super().on_ws_frame(transport, frame) + if frame.msg_type == picows.WSMsgType.CLOSE: + transport.send_close(frame.get_close_code(), frame.get_close_message()) + + async with WSServer(lambda _: ServerListener()) as server: + transport, _ = await picows.ws_connect(ClientListener, server.url) + try: + await transport.wait_disconnected() + finally: + try: + transport.disconnect(False) + await transport.wait_disconnected() + except Exception: + pass + + assert transport.close_handshake.recv.code == picows.WSCloseCode.GOING_AWAY + assert transport.close_handshake.recv.reason == "server shutdown" + assert transport.close_handshake.sent.code == picows.WSCloseCode.GOING_AWAY + assert transport.close_handshake.sent.reason == "server shutdown" + assert transport.close_handshake.recv_then_sent is True + + assert server_transport.close_handshake.recv.code == picows.WSCloseCode.GOING_AWAY + assert server_transport.close_handshake.recv.reason == "server shutdown" + assert server_transport.close_handshake.sent.code == picows.WSCloseCode.GOING_AWAY + assert server_transport.close_handshake.sent.reason == "server shutdown" + assert server_transport.close_handshake.recv_then_sent is False + + async def test_wrong_thread_assert(): loop = asyncio.get_running_loop() with ThreadPoolExecutor(max_workers=1) as executor: From 36c32d5cb12a2be8f0b137d3d820eb025638ab19 Mon Sep 17 00:00:00 2001 From: taras Date: Wed, 29 Apr 2026 01:50:34 +0200 Subject: [PATCH 08/16] Add more tests --- tests/test_ws_logic.py | 97 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 97 insertions(+) diff --git a/tests/test_ws_logic.py b/tests/test_ws_logic.py index ed6546a..1a87bf2 100644 --- a/tests/test_ws_logic.py +++ b/tests/test_ws_logic.py @@ -219,6 +219,30 @@ def on_ws_connected(self, transport: picows.WSTransport): await client.transport.wait_disconnected() +async def test_close_frame_invalid_utf8_reason_from_client(): + async with WSServer() as server: + async with WSClient(server) as client: + mask = 0x12345678 + payload = struct.pack("!H", picows.WSCloseCode.OK) + b"\xff" + masked_payload = bytes( + b ^ mask.to_bytes(4, "big")[i % 4] + for i, b in enumerate(payload) + ) + invalid_close_frame = struct.pack("!BBI", 0x88, 0x80 | len(payload), mask) + masked_payload + + client.transport.underlying_transport.write(invalid_close_frame) + frame = await client.get_message() + assert frame.msg_type == picows.WSMsgType.CLOSE + assert frame.close_code == picows.WSCloseCode.PROTOCOL_ERROR + assert b"Received CLOSE with invalid UTF-8 reason" in frame.close_message + await client.transport.wait_disconnected() + + assert client.transport.close_handshake.sent is None + assert client.transport.close_handshake.recv.code == picows.WSCloseCode.PROTOCOL_ERROR + assert client.transport.close_handshake.recv.reason == "Received CLOSE with invalid UTF-8 reason" + assert client.transport.close_handshake.recv_then_sent is True + + async def test_close_handshake_client_initiates_close(): server_transport = None @@ -297,6 +321,79 @@ def on_ws_frame(self, transport: picows.WSTransport, frame: picows.WSFrame): assert server_transport.close_handshake.recv_then_sent is False +async def test_close_handshake_client_initiates_close_server_disconnects_without_reply(): + server_transport = None + + class ServerListener(picows.WSListener): + def on_ws_connected(self, transport: picows.WSTransport): + nonlocal server_transport + server_transport = transport + + def on_ws_frame(self, transport: picows.WSTransport, frame: picows.WSFrame): + if frame.msg_type == picows.WSMsgType.CLOSE: + transport.disconnect(False) + + async with WSServer(lambda _: ServerListener()) as server: + transport, _ = await picows.ws_connect(AsyncClient, server.url) + try: + transport.send_close(picows.WSCloseCode.OK, b"client says bye") + await transport.wait_disconnected() + finally: + try: + transport.disconnect(False) + await transport.wait_disconnected() + except Exception: + pass + + assert transport.close_handshake.recv is None + assert transport.close_handshake.sent.code == picows.WSCloseCode.OK + assert transport.close_handshake.sent.reason == "client says bye" + assert transport.close_handshake.recv_then_sent is False + + assert server_transport.close_handshake.recv.code == picows.WSCloseCode.OK + assert server_transport.close_handshake.recv.reason == "client says bye" + assert server_transport.close_handshake.sent is None + assert server_transport.close_handshake.recv_then_sent is True + + +async def test_close_handshake_server_initiates_close_client_disconnects_without_reply(): + server_transport = None + + class ServerListener(picows.WSListener): + def on_ws_connected(self, transport: picows.WSTransport): + nonlocal server_transport + server_transport = transport + transport.send_close(picows.WSCloseCode.GOING_AWAY, b"server shutdown") + asyncio.get_running_loop().call_later(0.05, transport.disconnect) + + class ClientListener(AsyncClient): + def on_ws_frame(self, transport: picows.WSTransport, frame: picows.WSFrame): + super().on_ws_frame(transport, frame) + if frame.msg_type == picows.WSMsgType.CLOSE: + transport.disconnect(False) + + async with WSServer(lambda _: ServerListener()) as server: + transport, _ = await picows.ws_connect(ClientListener, server.url) + try: + await transport.wait_disconnected() + finally: + try: + transport.disconnect(False) + await transport.wait_disconnected() + except Exception: + pass + + assert transport.close_handshake.recv.code == picows.WSCloseCode.GOING_AWAY + assert transport.close_handshake.recv.reason == "server shutdown" + assert transport.close_handshake.sent is None + assert transport.close_handshake.recv_then_sent is True + + assert server_transport.close_handshake.recv is None + assert server_transport.close_handshake.sent.code == picows.WSCloseCode.GOING_AWAY + assert server_transport.close_handshake.sent.reason == "server shutdown" + assert server_transport.close_handshake.recv_then_sent is False + + async def test_wrong_thread_assert(): loop = asyncio.get_running_loop() with ThreadPoolExecutor(max_workers=1) as executor: From 3c76fc131c8e231731c0e4f675ecfa5df42b444d Mon Sep 17 00:00:00 2001 From: taras Date: Wed, 29 Apr 2026 01:55:26 +0200 Subject: [PATCH 09/16] Update --- picows/picows.pyx | 2 +- tests/test_ws_logic.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/picows/picows.pyx b/picows/picows.pyx index 5a7269f..64a182c 100644 --- a/picows/picows.pyx +++ b/picows/picows.pyx @@ -1657,7 +1657,7 @@ cdef class WSProtocol(WSProtocolBase, asyncio.BufferedProtocol): f"Received CLOSE with invalid close code: {frame.get_close_code()}") if frame.payload_size == 1: - raise WSProtocolError(WSCloseCode.PROTOCOL_ERROR, + raise WSProtocolError(WSCloseCode.INVALID_TEXT, f"Received CLOSE with invalid close code size: {frame.fin} {frame.msg_type} {frame.get_payload_as_bytes()}") recv = WSCloseInfo.__new__(WSCloseInfo) diff --git a/tests/test_ws_logic.py b/tests/test_ws_logic.py index 1a87bf2..02eb746 100644 --- a/tests/test_ws_logic.py +++ b/tests/test_ws_logic.py @@ -233,12 +233,12 @@ async def test_close_frame_invalid_utf8_reason_from_client(): client.transport.underlying_transport.write(invalid_close_frame) frame = await client.get_message() assert frame.msg_type == picows.WSMsgType.CLOSE - assert frame.close_code == picows.WSCloseCode.PROTOCOL_ERROR + assert frame.close_code == picows.WSCloseCode.INVALID_TEXT assert b"Received CLOSE with invalid UTF-8 reason" in frame.close_message await client.transport.wait_disconnected() assert client.transport.close_handshake.sent is None - assert client.transport.close_handshake.recv.code == picows.WSCloseCode.PROTOCOL_ERROR + assert client.transport.close_handshake.recv.code == picows.WSCloseCode.INVALID_TEXT assert client.transport.close_handshake.recv.reason == "Received CLOSE with invalid UTF-8 reason" assert client.transport.close_handshake.recv_then_sent is True From de5ae99c0bbae2aff5a81c0aea7facf6fb92e220 Mon Sep 17 00:00:00 2001 From: taras Date: Wed, 29 Apr 2026 22:18:26 +0200 Subject: [PATCH 10/16] Simplify tests --- tests/test_ws_logic.py | 124 ++++++++++++++++------------------------- 1 file changed, 48 insertions(+), 76 deletions(-) diff --git a/tests/test_ws_logic.py b/tests/test_ws_logic.py index 02eb746..3b2aa63 100644 --- a/tests/test_ws_logic.py +++ b/tests/test_ws_logic.py @@ -257,28 +257,21 @@ def on_ws_frame(self, transport: picows.WSTransport, frame: picows.WSFrame): transport.disconnect() async with WSServer(lambda _: ServerListener()) as server: - transport, _ = await picows.ws_connect(AsyncClient, server.url) - try: - transport.send_close(picows.WSCloseCode.OK, b"client says bye") - await transport.wait_disconnected() - finally: - try: - transport.disconnect(False) - await transport.wait_disconnected() - except Exception: - pass + async with WSClient(server) as client: + client.transport.send_close(picows.WSCloseCode.OK, b"client says bye") + await client.transport.wait_disconnected() - assert transport.close_handshake.sent.code == picows.WSCloseCode.OK - assert transport.close_handshake.sent.reason == "client says bye" - assert transport.close_handshake.recv.code == picows.WSCloseCode.OK - assert transport.close_handshake.recv.reason == "client says bye" - assert transport.close_handshake.recv_then_sent is False + assert client.transport.close_handshake.sent.code == picows.WSCloseCode.OK + assert client.transport.close_handshake.sent.reason == "client says bye" + assert client.transport.close_handshake.recv.code == picows.WSCloseCode.OK + assert client.transport.close_handshake.recv.reason == "client says bye" + assert client.transport.close_handshake.recv_then_sent is False - assert server_transport.close_handshake.sent.code == picows.WSCloseCode.OK - assert server_transport.close_handshake.sent.reason == "client says bye" - assert server_transport.close_handshake.recv.code == picows.WSCloseCode.OK - assert server_transport.close_handshake.recv.reason == "client says bye" - assert server_transport.close_handshake.recv_then_sent is True + assert server_transport.close_handshake.sent.code == picows.WSCloseCode.OK + assert server_transport.close_handshake.sent.reason == "client says bye" + assert server_transport.close_handshake.recv.code == picows.WSCloseCode.OK + assert server_transport.close_handshake.recv.reason == "client says bye" + assert server_transport.close_handshake.recv_then_sent is True async def test_close_handshake_server_initiates_close(): @@ -293,32 +286,25 @@ def on_ws_connected(self, transport: picows.WSTransport): class ClientListener(AsyncClient): def on_ws_frame(self, transport: picows.WSTransport, frame: picows.WSFrame): - super().on_ws_frame(transport, frame) if frame.msg_type == picows.WSMsgType.CLOSE: transport.send_close(frame.get_close_code(), frame.get_close_message()) + else: + super().on_ws_frame(transport, frame) async with WSServer(lambda _: ServerListener()) as server: - transport, _ = await picows.ws_connect(ClientListener, server.url) - try: - await transport.wait_disconnected() - finally: - try: - transport.disconnect(False) - await transport.wait_disconnected() - except Exception: - pass - - assert transport.close_handshake.recv.code == picows.WSCloseCode.GOING_AWAY - assert transport.close_handshake.recv.reason == "server shutdown" - assert transport.close_handshake.sent.code == picows.WSCloseCode.GOING_AWAY - assert transport.close_handshake.sent.reason == "server shutdown" - assert transport.close_handshake.recv_then_sent is True + async with WSClient(server, ClientListener) as client: + await client.transport.wait_disconnected() + assert client.transport.close_handshake.recv.code == picows.WSCloseCode.GOING_AWAY + assert client.transport.close_handshake.recv.reason == "server shutdown" + assert client.transport.close_handshake.sent.code == picows.WSCloseCode.GOING_AWAY + assert client.transport.close_handshake.sent.reason == "server shutdown" + assert client.transport.close_handshake.recv_then_sent is True - assert server_transport.close_handshake.recv.code == picows.WSCloseCode.GOING_AWAY - assert server_transport.close_handshake.recv.reason == "server shutdown" - assert server_transport.close_handshake.sent.code == picows.WSCloseCode.GOING_AWAY - assert server_transport.close_handshake.sent.reason == "server shutdown" - assert server_transport.close_handshake.recv_then_sent is False + assert server_transport.close_handshake.recv.code == picows.WSCloseCode.GOING_AWAY + assert server_transport.close_handshake.recv.reason == "server shutdown" + assert server_transport.close_handshake.sent.code == picows.WSCloseCode.GOING_AWAY + assert server_transport.close_handshake.sent.reason == "server shutdown" + assert server_transport.close_handshake.recv_then_sent is False async def test_close_handshake_client_initiates_close_server_disconnects_without_reply(): @@ -334,26 +320,19 @@ def on_ws_frame(self, transport: picows.WSTransport, frame: picows.WSFrame): transport.disconnect(False) async with WSServer(lambda _: ServerListener()) as server: - transport, _ = await picows.ws_connect(AsyncClient, server.url) - try: - transport.send_close(picows.WSCloseCode.OK, b"client says bye") - await transport.wait_disconnected() - finally: - try: - transport.disconnect(False) - await transport.wait_disconnected() - except Exception: - pass + async with WSClient(server) as client: + client.transport.send_close(picows.WSCloseCode.OK, b"client says bye") + await client.transport.wait_disconnected() - assert transport.close_handshake.recv is None - assert transport.close_handshake.sent.code == picows.WSCloseCode.OK - assert transport.close_handshake.sent.reason == "client says bye" - assert transport.close_handshake.recv_then_sent is False + assert client.transport.close_handshake.recv is None + assert client.transport.close_handshake.sent.code == picows.WSCloseCode.OK + assert client.transport.close_handshake.sent.reason == "client says bye" + assert client.transport.close_handshake.recv_then_sent is False - assert server_transport.close_handshake.recv.code == picows.WSCloseCode.OK - assert server_transport.close_handshake.recv.reason == "client says bye" - assert server_transport.close_handshake.sent is None - assert server_transport.close_handshake.recv_then_sent is True + assert server_transport.close_handshake.recv.code == picows.WSCloseCode.OK + assert server_transport.close_handshake.recv.reason == "client says bye" + assert server_transport.close_handshake.sent is None + assert server_transport.close_handshake.recv_then_sent is True async def test_close_handshake_server_initiates_close_client_disconnects_without_reply(): @@ -373,25 +352,18 @@ def on_ws_frame(self, transport: picows.WSTransport, frame: picows.WSFrame): transport.disconnect(False) async with WSServer(lambda _: ServerListener()) as server: - transport, _ = await picows.ws_connect(ClientListener, server.url) - try: - await transport.wait_disconnected() - finally: - try: - transport.disconnect(False) - await transport.wait_disconnected() - except Exception: - pass + async with WSClient(server, ClientListener) as client: + await client.transport.wait_disconnected() - assert transport.close_handshake.recv.code == picows.WSCloseCode.GOING_AWAY - assert transport.close_handshake.recv.reason == "server shutdown" - assert transport.close_handshake.sent is None - assert transport.close_handshake.recv_then_sent is True + assert client.transport.close_handshake.recv.code == picows.WSCloseCode.GOING_AWAY + assert client.transport.close_handshake.recv.reason == "server shutdown" + assert client.transport.close_handshake.sent is None + assert client.transport.close_handshake.recv_then_sent is True - assert server_transport.close_handshake.recv is None - assert server_transport.close_handshake.sent.code == picows.WSCloseCode.GOING_AWAY - assert server_transport.close_handshake.sent.reason == "server shutdown" - assert server_transport.close_handshake.recv_then_sent is False + assert server_transport.close_handshake.recv is None + assert server_transport.close_handshake.sent.code == picows.WSCloseCode.GOING_AWAY + assert server_transport.close_handshake.sent.reason == "server shutdown" + assert server_transport.close_handshake.recv_then_sent is False async def test_wrong_thread_assert(): From 2b38098086ee23625c186f2ef15072ca73026696 Mon Sep 17 00:00:00 2001 From: taras Date: Wed, 29 Apr 2026 22:23:31 +0200 Subject: [PATCH 11/16] Fix tests --- picows/picows.pyx | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/picows/picows.pyx b/picows/picows.pyx index 64a182c..0f54d8d 100644 --- a/picows/picows.pyx +++ b/picows/picows.pyx @@ -1657,7 +1657,7 @@ cdef class WSProtocol(WSProtocolBase, asyncio.BufferedProtocol): f"Received CLOSE with invalid close code: {frame.get_close_code()}") if frame.payload_size == 1: - raise WSProtocolError(WSCloseCode.INVALID_TEXT, + raise WSProtocolError(WSCloseCode.PROTOCOL_ERROR, f"Received CLOSE with invalid close code size: {frame.fin} {frame.msg_type} {frame.get_payload_as_bytes()}") recv = WSCloseInfo.__new__(WSCloseInfo) @@ -1665,7 +1665,7 @@ cdef class WSProtocol(WSProtocolBase, asyncio.BufferedProtocol): try: recv.reason = frame.get_close_reason() except UnicodeDecodeError: - raise WSProtocolError(WSCloseCode.PROTOCOL_ERROR, + raise WSProtocolError(WSCloseCode.INVALID_TEXT, f"Received CLOSE with invalid UTF-8 reason") if self.transport.close_handshake is None: From 227760f0b89ef879804907f6ddd77c046e4dd4e8 Mon Sep 17 00:00:00 2001 From: taras Date: Fri, 1 May 2026 13:20:24 +0200 Subject: [PATCH 12/16] Cleanups --- HISTORY.rst | 2 ++ picows/picows.pyx | 54 +++++++++++++++++++++++------------------- tests/test_ws_logic.py | 4 ++-- 3 files changed, 33 insertions(+), 27 deletions(-) diff --git a/HISTORY.rst b/HISTORY.rst index 246d12b..aa01668 100644 --- a/HISTORY.rst +++ b/HISTORY.rst @@ -11,6 +11,8 @@ picows Release History * ws_connect/ws_create_server logger_name parameter can now accept a logger-like object * ws_connect/ws_create_server websocket_handshake_timeout param can now accept None to disable handshake timeouts * Introduce new exceptions: WSInvalidMessageError, WSInvalidStatusError, WSInvalidHeaderError, WSInvalidUpgradeError +* Allow sending close frames only using send_close to simplify logic +* Raise ValueError instead of assert on some invalid user input 1.19.0 (2026-04-24) ------------------ diff --git a/picows/picows.pyx b/picows/picows.pyx index 0f54d8d..89f7d24 100644 --- a/picows/picows.pyx +++ b/picows/picows.pyx @@ -445,6 +445,10 @@ cdef class WSTransport: cdef _send_buffer(self, WSMsgType msg_type, char* msg_ptr, Py_ssize_t msg_size, bint fin, bint rsv1): + if self.is_close_frame_sent: + self._logger.debug("Ignore attempt to send a message after WSMsgType.CLOSE has already been sent") + return + cdef: Py_ssize_t header_size = self._get_header_size(msg_size) char* header_ptr = msg_ptr - header_size @@ -455,9 +459,6 @@ cdef class WSTransport: self._fast_write(header_ptr, header_size + msg_size) - if msg_type == WSMsgType.CLOSE: - self.is_close_frame_sent = True - cdef _send(self, WSMsgType msg_type, message, bint fin, bint rsv1): if self.is_close_frame_sent: self._logger.debug("Ignore attempt to send a message after WSMsgType.CLOSE has already been sent") @@ -517,25 +518,13 @@ cdef class WSTransport: (masked_msg_ptr - header_size), header_size + msg_size ) - if msg_type == WSMsgType.CLOSE: - self.is_close_frame_sent = True - - if self.close_handshake is None: - self.close_handshake = WSCloseHandshake.__new__(WSCloseHandshake) - self.close_handshake.recv_then_sent = False - - self.close_handshake.sent = WSCloseInfo.__new__(WSCloseInfo) - self.close_handshake.sent.code = ntohs((msg_ptr)[0]) - self.close_handshake.sent.reason = PyUnicode_FromStringAndSize(msg_ptr + 2, msg_size - 2) - cdef send_reuse_external_buffer(self, WSMsgType msg_type, char* msg_ptr, Py_ssize_t msg_size, bint fin=True, bint rsv1=False): self._check_thread("send_reuse_external_buffer") - if self.is_close_frame_sent: - self._logger.debug("Ignore attempt to send a message after WSMsgType.CLOSE has already been sent") - return + if msg_type == WSMsgType.CLOSE: + raise ValueError("attempt to send CLOSE frame using send_reuse_external_buffer, use send_close instead") self._send_buffer(msg_type, msg_ptr, msg_size, fin, rsv1) @@ -548,7 +537,7 @@ cdef class WSTransport: This function does not copy message to prepare websocket frames. It reuses bytearray's memory to write websocket frame header at the front. - :param msg_type: :any:`WSMsgType` enum value\n + :param msg_type: :any:`WSMsgType` enum value, except CLOSE. Use send_close to send close frames. :param msg_offset: specifies where message begins in the bytearray. Must be at least 14 to let picows to write websocket frame header in front of the message. :param buffer: bytearray that contains message and some extra space (at least 14 bytes) in the beginning. @@ -558,20 +547,23 @@ cdef class WSTransport: :param rsv1: first reserved bit in websocket frame. Some protocol extensions use it to indicate that payload is compressed. """ - assert buffer is not None, "buffer is None" - assert msg_offset >= 14, "buffer must have at least 14 bytes available before message starts, check msg_offset parameter" + if buffer is None: + raise ValueError("None is passed instead of buffer to send_reuse_external_bytearray") - self._check_thread("send_reuse_external_bytearray") + if msg_offset < 14: + raise ValueError("buffer must have at least 14 bytes available before message starts, check msg_offset parameter") - if self.is_close_frame_sent: - self._logger.debug("Ignore attempt to send a message after WSMsgType.CLOSE has already been sent") - return + if msg_type == WSMsgType.CLOSE: + raise ValueError("attempt to send CLOSE frame using send_reuse_external_bytearray, use send_close instead") + + self._check_thread("send_reuse_external_bytearray") cdef: char* buffer_ptr = PyByteArray_AS_STRING(buffer) Py_ssize_t buffer_size = PyByteArray_GET_SIZE(buffer) - assert buffer_size >= msg_offset, "msg_offset points beyond buffer end, msg_offset > len(buffer)" + if buffer_size < msg_offset: + raise ValueError("msg_offset points beyond buffer end, msg_offset > len(buffer)") cdef: char* msg_ptr = buffer_ptr + msg_offset @@ -638,12 +630,24 @@ cdef class WSTransport: cdef: bytes msg = PyBytes_FromStringAndSize(NULL, close_msg_length + 2) char* msg_ptr = PyBytes_AS_STRING(msg) + str reason = PyUnicode_FromStringAndSize(close_msg_ptr, close_msg_length) (msg_ptr)[0] = htons(close_code) memcpy(msg_ptr + 2, close_msg_ptr, close_msg_length) self._send(WSMsgType.CLOSE, msg, True, False) + if not self.is_close_frame_sent: + self.is_close_frame_sent = True + + if self.close_handshake is None: + self.close_handshake = WSCloseHandshake.__new__(WSCloseHandshake) + self.close_handshake.recv_then_sent = False + + self.close_handshake.sent = WSCloseInfo.__new__(WSCloseInfo) + self.close_handshake.sent.code = close_code + self.close_handshake.sent.reason = reason + cpdef disconnect(self, bint graceful=True): """ Close the underlying transport. diff --git a/tests/test_ws_logic.py b/tests/test_ws_logic.py index 3b2aa63..80812bf 100644 --- a/tests/test_ws_logic.py +++ b/tests/test_ws_logic.py @@ -71,11 +71,11 @@ async def handle_client(reader: asyncio.StreamReader, writer: asyncio.StreamWrit async def test_send_external_bytearray_asserts(): async with WSServer() as server: async with WSClient(server) as client: - with pytest.raises(AssertionError): + with pytest.raises(ValueError): # Check assertion for msg_len >= 0 client.transport.send_reuse_external_bytearray(picows.WSMsgType.BINARY, bytearray(b"HELLO"), 16) - with pytest.raises(AssertionError): + with pytest.raises(ValueError): # Check assertion for offset to be at least 14 client.transport.send_reuse_external_bytearray(picows.WSMsgType.BINARY, bytearray(b"1234567890123HELLO"), 13) From edafac0ce8ee9f60ccabbbd0d806f790a8816c81 Mon Sep 17 00:00:00 2001 From: taras Date: Fri, 1 May 2026 14:02:53 +0200 Subject: [PATCH 13/16] Add rsv2 and rsv3 to send --- HISTORY.rst | 1 + docs/source/reference.rst | 6 +++++- picows/picows.pxd | 13 ++++++------ picows/picows.pyi | 6 +++++- picows/picows.pyx | 44 +++++++++++++++++++++++++-------------- tests/test_basics.py | 21 +++++++++++++------ tests/utils.py | 9 ++++++-- 7 files changed, 67 insertions(+), 33 deletions(-) diff --git a/HISTORY.rst b/HISTORY.rst index aa01668..83cbc5e 100644 --- a/HISTORY.rst +++ b/HISTORY.rst @@ -13,6 +13,7 @@ picows Release History * Introduce new exceptions: WSInvalidMessageError, WSInvalidStatusError, WSInvalidHeaderError, WSInvalidUpgradeError * Allow sending close frames only using send_close to simplify logic * Raise ValueError instead of assert on some invalid user input +* Added rsv2 and rsv3 to WSTransport send methods 1.19.0 (2026-04-24) ------------------ diff --git a/docs/source/reference.rst b/docs/source/reference.rst index 3ac92b1..237c98a 100644 --- a/docs/source/reference.rst +++ b/docs/source/reference.rst @@ -243,7 +243,7 @@ Classes Opening handshake response. - .. py:method:: send_reuse_external_buffer(WSMsgType msg_type, char* msg_ptr, size_t msg_size, bint fin=True, bint rsv1=False) + .. py:method:: send_reuse_external_buffer(WSMsgType msg_type, char* msg_ptr, size_t msg_size, bint fin=True, bint rsv1=False, bint rsv2=False, bint rsv3=False) **Available only from Cython.** @@ -263,6 +263,10 @@ Classes :param rsv1: first reserved bit in websocket frame. Some protocol extensions use it to indicate that payload is compressed. + :param rsv2: second reserved bit in websocket frame. + Protocol extensions can use this flag. + :param rsv3: third reserved bit in websocket frame. + Protocol extensions can use this flag. Enums ----- diff --git a/picows/picows.pxd b/picows/picows.pxd index bdb7142..93510f6 100644 --- a/picows/picows.pxd +++ b/picows/picows.pxd @@ -113,9 +113,9 @@ cdef class WSTransport: bint _is_aiofn_transport bint _log_debug_enabled - cdef inline send_reuse_external_buffer(self, WSMsgType msg_type, char* msg_ptr, Py_ssize_t msg_size, bint fin=*, bint rsv1=*) - cpdef send(self, WSMsgType msg_type, message, bint fin=*, bint rsv1=*) - cpdef send_reuse_external_bytearray(self, WSMsgType msg_type, bytearray buffer, Py_ssize_t msg_offset, bint fin=*, bint rsv1=*) + cdef inline send_reuse_external_buffer(self, WSMsgType msg_type, char* msg_ptr, Py_ssize_t msg_size, bint fin=*, bint rsv1=*, bint rsv2=*, bint rsv3=*) + cpdef send(self, WSMsgType msg_type, message, bint fin=*, bint rsv1=*, bint rsv2=*, bint rsv3=*) + cpdef send_reuse_external_bytearray(self, WSMsgType msg_type, bytearray buffer, Py_ssize_t msg_offset, bint fin=*, bint rsv1=*, bint rsv2=*, bint rsv3=*) cpdef send_ping(self, message=*) cpdef send_pong(self, message=*) cpdef send_close(self, WSCloseCode close_code=*, close_message=*) @@ -126,9 +126,9 @@ cdef class WSTransport: cdef inline Py_ssize_t _get_header_size(self, Py_ssize_t msg_size) noexcept cdef inline _send_buffer(self, WSMsgType msg_type, char* msg_ptr, Py_ssize_t msg_size, - bint fin, bint rsv1) - cdef inline _send(self, WSMsgType msg_type, message, bint fin, bint rsv1) - cdef inline uint32_t _prepare_header(self, uint8_t* header_ptr, WSMsgType msg_type, Py_ssize_t msg_size, bint fin, bint rsv1) noexcept + bint fin, bint rsv1, bint rsv2, bint rsv3) + cdef inline _send(self, WSMsgType msg_type, message, bint fin, bint rsv1, bint rsv2, bint rsv3) + cdef inline uint32_t _prepare_header(self, uint8_t* header_ptr, WSMsgType msg_type, Py_ssize_t msg_size, bint fin, bint rsv1, bint rsv2, bint rsv3) noexcept cdef inline _send_http_handshake(self, bytes ws_path, bytes host_port, bytes websocket_key_b64, object extra_headers) cdef inline _send_http_handshake_response(self, response, bytes accept_val) cdef inline _fast_write(self, char* ptr, Py_ssize_t sz) @@ -146,4 +146,3 @@ cdef class WSListener: cpdef pause_writing(self) cpdef resume_writing(self) - diff --git a/picows/picows.pyi b/picows/picows.pyi index b8b4e27..ddcd64d 100644 --- a/picows/picows.pyi +++ b/picows/picows.pyi @@ -107,6 +107,8 @@ class WSTransport: message: Optional[WSBuffer], fin: bool = True, rsv1: bool = False, + rsv2: bool = False, + rsv3: bool = False, ) -> None: ... def send_reuse_external_bytearray( self, @@ -114,7 +116,9 @@ class WSTransport: buffer: bytearray, msg_offset: int, fin: bool = True, - rsv1: bool = False + rsv1: bool = False, + rsv2: bool = False, + rsv3: bool = False, ) -> None: ... def send_ping(self, message: Optional[WSBuffer]=None) -> None: ... def send_pong(self, message: Optional[WSBuffer]=None) -> None: ... diff --git a/picows/picows.pyx b/picows/picows.pyx index 89f7d24..77446ed 100644 --- a/picows/picows.pyx +++ b/picows/picows.pyx @@ -404,7 +404,7 @@ cdef class WSTransport: cdef uint32_t _prepare_header(self, uint8_t* header_ptr, WSMsgType msg_type, Py_ssize_t msg_size, - bint fin, bint rsv1) noexcept: + bint fin, bint rsv1, bint rsv2, bint rsv3) noexcept: # Return mask or 0 for server side cdef: @@ -416,6 +416,10 @@ cdef class WSTransport: first_byte |= 0x80 if rsv1: first_byte |= 0x40 + if rsv2: + first_byte |= 0x20 + if rsv3: + first_byte |= 0x10 header_ptr[0] = first_byte @@ -444,7 +448,7 @@ cdef class WSTransport: cdef _send_buffer(self, WSMsgType msg_type, char* msg_ptr, Py_ssize_t msg_size, - bint fin, bint rsv1): + bint fin, bint rsv1, bint rsv2, bint rsv3): if self.is_close_frame_sent: self._logger.debug("Ignore attempt to send a message after WSMsgType.CLOSE has already been sent") return @@ -452,14 +456,14 @@ cdef class WSTransport: cdef: Py_ssize_t header_size = self._get_header_size(msg_size) char* header_ptr = msg_ptr - header_size - uint32_t mask = self._prepare_header(header_ptr, msg_type, msg_size, fin, rsv1) + uint32_t mask = self._prepare_header(header_ptr, msg_type, msg_size, fin, rsv1, rsv2, rsv3) if mask != 0: _mask_payload(msg_ptr, msg_size, mask, msg_ptr) self._fast_write(header_ptr, header_size + msg_size) - cdef _send(self, WSMsgType msg_type, message, bint fin, bint rsv1): + cdef _send(self, WSMsgType msg_type, message, bint fin, bint rsv1, bint rsv2, bint rsv3): if self.is_close_frame_sent: self._logger.debug("Ignore attempt to send a message after WSMsgType.CLOSE has already been sent") return @@ -478,7 +482,7 @@ cdef class WSTransport: self._write_buffer.resize(header_size) mask = self._prepare_header(self._write_buffer.data, msg_type, - msg_size, fin, rsv1) + msg_size, fin, rsv1, rsv2, rsv3) if msg_size == 0: self._fast_write(self._write_buffer.data, header_size) @@ -495,7 +499,7 @@ cdef class WSTransport: self._write_buffer.resize(header_size) self._prepare_header( self._write_buffer.data, msg_type, - msg_size, fin, rsv1) + msg_size, fin, rsv1, rsv2, rsv3) header = PyMemoryView_FromMemory( self._write_buffer.data, header_size, PyBUF_READ ) @@ -520,18 +524,18 @@ cdef class WSTransport: cdef send_reuse_external_buffer(self, WSMsgType msg_type, char* msg_ptr, Py_ssize_t msg_size, - bint fin=True, bint rsv1=False): + bint fin=True, bint rsv1=False, bint rsv2=False, bint rsv3=False): self._check_thread("send_reuse_external_buffer") if msg_type == WSMsgType.CLOSE: raise ValueError("attempt to send CLOSE frame using send_reuse_external_buffer, use send_close instead") - self._send_buffer(msg_type, msg_ptr, msg_size, fin, rsv1) + self._send_buffer(msg_type, msg_ptr, msg_size, fin, rsv1, rsv2, rsv3) cpdef send_reuse_external_bytearray(self, WSMsgType msg_type, bytearray buffer, Py_ssize_t msg_offset, - bint fin=True, bint rsv1=False): + bint fin=True, bint rsv1=False, bint rsv2=False, bint rsv3=False): """ Send a frame over websocket with a message as its payload. This function does not copy message to prepare websocket frames. @@ -545,7 +549,11 @@ cdef class WSTransport: :param fin: fin bit in websocket frame. Indicate that the frame is the last one in the message. :param rsv1: first reserved bit in websocket frame. - Some protocol extensions use it to indicate that payload is compressed. + Some protocol extensions use it to indicate that payload is compressed. + :param rsv2: second reserved bit in websocket frame. + Protocol extensions can use this flag. + :param rsv3: third reserved bit in websocket frame. + Protocol extensions can use this flag. """ if buffer is None: raise ValueError("None is passed instead of buffer to send_reuse_external_bytearray") @@ -569,9 +577,9 @@ cdef class WSTransport: char* msg_ptr = buffer_ptr + msg_offset Py_ssize_t msg_size = buffer_size - msg_offset - self._send_buffer(msg_type, msg_ptr, msg_size, fin, rsv1) + self._send_buffer(msg_type, msg_ptr, msg_size, fin, rsv1, rsv2, rsv3) - cpdef send(self, WSMsgType msg_type, message, bint fin=True, bint rsv1=False): + cpdef send(self, WSMsgType msg_type, message, bint fin=True, bint rsv1=False, bint rsv2=False, bint rsv3=False): """ Send a frame over websocket with a message as its payload. @@ -588,9 +596,13 @@ cdef class WSTransport: :param rsv1: first reserved bit in websocket frame. Some protocol extensions use it to indicate that payload is compressed. + :param rsv2: second reserved bit in websocket frame. + Protocol extensions can use this flag. + :param rsv3: third reserved bit in websocket frame. + Protocol extensions can use this flag. """ self._check_thread("send") - self._send(msg_type, message, fin, rsv1) + self._send(msg_type, message, fin, rsv1, rsv2, rsv3) cpdef send_ping(self, message=None): """ @@ -599,7 +611,7 @@ cdef class WSTransport: :param message: an optional bytes-like object """ self._check_thread("send_ping") - self._send(WSMsgType.PING, message, True, False) + self._send(WSMsgType.PING, message, True, False, False, False) cpdef send_pong(self, message=None): """ @@ -608,7 +620,7 @@ cdef class WSTransport: :param message: an optional bytes-like object """ self._check_thread("send_pong") - self._send(WSMsgType.PONG, message, True, False) + self._send(WSMsgType.PONG, message, True, False, False, False) cpdef send_close(self, WSCloseCode close_code=WSCloseCode.NO_INFO, close_message=None): """ @@ -635,7 +647,7 @@ cdef class WSTransport: (msg_ptr)[0] = htons(close_code) memcpy(msg_ptr + 2, close_msg_ptr, close_msg_length) - self._send(WSMsgType.CLOSE, msg, True, False) + self._send(WSMsgType.CLOSE, msg, True, False, False, False) if not self.is_close_frame_sent: self.is_close_frame_sent = True diff --git a/tests/test_basics.py b/tests/test_basics.py index 4f35a9f..30a5844 100644 --- a/tests/test_basics.py +++ b/tests/test_basics.py @@ -19,38 +19,47 @@ async def test_echo(use_aiofastnet, ssl_context, msg_size): async with WSClient(server, ssl_context=ssl_context.client, use_aiofastnet=use_aiofastnet) as client: msg = (b"ABCDEFGHIKLMNOPQ" * (int(msg_size / 16) + 1))[:msg_size] - client.transport.send(picows.WSMsgType.BINARY, msg, False, False) + client.transport.send(picows.WSMsgType.BINARY, msg, False, False, True, False) frame = await client.get_message() - assert frame.frame_str.startswith("WSFrame(BINARY, fin=False, rsv1=False") + assert frame.frame_str.startswith("WSFrame(BINARY, fin=False, rsv1=False, rsv2=True, rsv3=False") assert frame.msg_type == picows.WSMsgType.BINARY assert frame.payload_as_bytes == msg assert frame.payload_as_bytes_from_mv == msg assert not frame.fin assert not frame.rsv1 + assert frame.rsv2 + assert not frame.rsv3 ba = bytearray(b"1234567890123456") ba += msg - client.transport.send_reuse_external_bytearray(picows.WSMsgType.BINARY, ba, 16) + client.transport.send_reuse_external_bytearray(picows.WSMsgType.BINARY, ba, 16, True, False, False, True) frame = await client.get_message() - assert frame.frame_str.startswith("WSFrame(BINARY, fin=True, rsv1=False") + assert frame.frame_str.startswith("WSFrame(BINARY, fin=True, rsv1=False, rsv2=False, rsv3=True") assert frame.msg_type == picows.WSMsgType.BINARY assert frame.payload_as_bytes == msg + assert not frame.rsv1 + assert not frame.rsv2 + assert frame.rsv3 msg = base64.b64encode(msg) - client.transport.send(picows.WSMsgType.TEXT, msg, True, True) + client.transport.send(picows.WSMsgType.TEXT, msg, True, True, True, True) frame = await client.get_message() - assert frame.frame_str.startswith("WSFrame(TEXT, fin=True, rsv1=True") + assert frame.frame_str.startswith("WSFrame(TEXT, fin=True, rsv1=True, rsv2=True, rsv3=True") assert frame.msg_type == picows.WSMsgType.TEXT assert frame.payload_as_ascii_text == msg.decode("ascii") assert frame.payload_as_utf8_text == msg.decode("utf8") assert frame.fin assert frame.rsv1 + assert frame.rsv2 + assert frame.rsv3 # Check send defaults client.transport.send(picows.WSMsgType.BINARY, msg) frame = await client.get_message() assert frame.fin assert not frame.rsv1 + assert not frame.rsv2 + assert not frame.rsv3 # Test non-bytes like send with pytest.raises(TypeError): diff --git a/tests/utils.py b/tests/utils.py index 0193765..86600a5 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -24,6 +24,8 @@ def __init__(self, frame: picows.WSFrame): self.payload_as_bytes_from_mv = bytes(frame.get_payload_as_memoryview()) self.fin = frame.fin self.rsv1 = frame.rsv1 + self.rsv2 = frame.rsv2 + self.rsv3 = frame.rsv3 class TextFrame: @@ -34,6 +36,8 @@ def __init__(self, frame: picows.WSFrame): self.payload_as_utf8_text = frame.get_payload_as_utf8_text() self.fin = frame.fin self.rsv1 = frame.rsv1 + self.rsv2 = frame.rsv2 + self.rsv3 = frame.rsv3 class CloseFrame: @@ -44,6 +48,8 @@ def __init__(self, frame: picows.WSFrame): self.close_message = frame.get_close_message() self.fin = frame.fin self.rsv1 = frame.rsv1 + self.rsv2 = frame.rsv2 + self.rsv3 = frame.rsv3 def materialize_frame(frame: picows.WSFrame) -> Union[TextFrame, CloseFrame, BinaryFrame]: @@ -116,7 +122,7 @@ def on_ws_frame(self, transport: picows.WSTransport, frame: picows.WSFrame): self._transport.send(picows.WSMsgType.BINARY, msg) return - self._transport.send(frame.msg_type, frame.get_payload_as_bytes(), frame.fin, frame.rsv1) + self._transport.send(frame.msg_type, frame.get_payload_as_bytes(), frame.fin, frame.rsv1, frame.rsv2, frame.rsv3) @dataclass @@ -172,4 +178,3 @@ async def WSClient(server, listener_factory=None, **kwargs): await transport.wait_disconnected() except (TestException, picows.WSError): pass - From 04b72708f82ade5a9a686487e9cf48ab94c79bef Mon Sep 17 00:00:00 2001 From: taras Date: Fri, 1 May 2026 14:10:54 +0200 Subject: [PATCH 14/16] Add tests --- tests/test_ws_logic.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/test_ws_logic.py b/tests/test_ws_logic.py index 80812bf..d92485a 100644 --- a/tests/test_ws_logic.py +++ b/tests/test_ws_logic.py @@ -71,6 +71,10 @@ async def handle_client(reader: asyncio.StreamReader, writer: asyncio.StreamWrit async def test_send_external_bytearray_asserts(): async with WSServer() as server: async with WSClient(server) as client: + with pytest.raises(ValueError): + # Check assertion for None buffer + client.transport.send_reuse_external_bytearray(picows.WSMsgType.BINARY, None, 16) + with pytest.raises(ValueError): # Check assertion for msg_len >= 0 client.transport.send_reuse_external_bytearray(picows.WSMsgType.BINARY, bytearray(b"HELLO"), 16) @@ -79,6 +83,10 @@ async def test_send_external_bytearray_asserts(): # Check assertion for offset to be at least 14 client.transport.send_reuse_external_bytearray(picows.WSMsgType.BINARY, bytearray(b"1234567890123HELLO"), 13) + with pytest.raises(ValueError): + # Check CLOSE is not allowed + client.transport.send_reuse_external_bytearray(picows.WSMsgType.CLOSE, bytearray(b"1234567890123HELLO"), 16) + async def test_max_frame_size_violation_huge_frame_from_client(use_aiofastnet, ssl_context): msg = os.urandom(128 * 1024) From d96e7bc570b53e031b47fd25b718edf607aa66dc Mon Sep 17 00:00:00 2001 From: taras Date: Fri, 1 May 2026 14:14:01 +0200 Subject: [PATCH 15/16] Improve logic --- picows/picows.pyx | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/picows/picows.pyx b/picows/picows.pyx index 77446ed..116b959 100644 --- a/picows/picows.pyx +++ b/picows/picows.pyx @@ -995,10 +995,11 @@ cdef class WSProtocol(WSProtocolBase, asyncio.BufferedProtocol): # self._logger.getLogger adds child logger to the global loggers dict. # These child loggers never get deleted after connections are lost # Therefore do not use getLogger, create and setup child loggers manually - child_logger = logging.Logger(f"{self._logger.name}.{sock.fileno()}", logging.NOTSET) - child_logger.parent = self._logger - child_logger.propagate = True - self._logger = child_logger + if isinstance(self._logger, logging.Logger, logging.LoggerAdapter): + child_logger = logging.Logger(f"{self._logger.name}.{sock.fileno()}", logging.NOTSET) + child_logger.parent = self._logger + child_logger.propagate = True + self._logger = child_logger quickack = sock.getsockopt(socket.IPPROTO_TCP, socket.TCP_QUICKACK) if hasattr(socket, "TCP_QUICKACK") else False From b8340f765e4d16dd51ac4e7111b13d7a06ff6eca Mon Sep 17 00:00:00 2001 From: taras Date: Fri, 1 May 2026 14:14:46 +0200 Subject: [PATCH 16/16] Update --- picows/picows.pyx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/picows/picows.pyx b/picows/picows.pyx index 116b959..d2d5793 100644 --- a/picows/picows.pyx +++ b/picows/picows.pyx @@ -995,7 +995,7 @@ cdef class WSProtocol(WSProtocolBase, asyncio.BufferedProtocol): # self._logger.getLogger adds child logger to the global loggers dict. # These child loggers never get deleted after connections are lost # Therefore do not use getLogger, create and setup child loggers manually - if isinstance(self._logger, logging.Logger, logging.LoggerAdapter): + if isinstance(self._logger, logging.Logger): child_logger = logging.Logger(f"{self._logger.name}.{sock.fileno()}", logging.NOTSET) child_logger.parent = self._logger child_logger.propagate = True