diff --git a/HISTORY.rst b/HISTORY.rst index d642f8b..83cbc5e 100644 --- a/HISTORY.rst +++ b/HISTORY.rst @@ -5,6 +5,16 @@ picows Release History :depth: 1 :local: +1.20.0 () +------------------ + +* ws_connect/ws_create_server logger_name parameter can now accept a logger-like object +* ws_connect/ws_create_server websocket_handshake_timeout param can now accept None to disable handshake timeouts +* Introduce new exceptions: WSInvalidMessageError, WSInvalidStatusError, WSInvalidHeaderError, WSInvalidUpgradeError +* Allow sending close frames only using send_close to simplify logic +* Raise ValueError instead of assert on some invalid user input +* Added rsv2 and rsv3 to WSTransport send methods + 1.19.0 (2026-04-24) ------------------ diff --git a/docs/source/guides.rst b/docs/source/guides.rst index 50a12e1..83626ba 100644 --- a/docs/source/guides.rst +++ b/docs/source/guides.rst @@ -91,11 +91,24 @@ Additionally, websocket-specific failures are represented by :any:`WSError` and its subclasses: * :any:`WSHandshakeError` for HTTP upgrade negotiation failures (raised by :any:`ws_connect`). + More specific subclasses may be raised: + + * :any:`WSInvalidMessageError` for malformed HTTP upgrade responses. + * :any:`WSInvalidStatusError` when the HTTP response status isn't ``101 Switching Protocols``. + * :any:`WSInvalidHeaderError` for invalid handshake headers such as + ``Content-Length`` or ``Sec-WebSocket-Accept``. + * :any:`WSInvalidUpgradeError` for invalid ``Upgrade`` / ``Connection`` headers. + + Redirect-following failures in :any:`ws_connect` currently still raise the + base :any:`WSHandshakeError`. * :any:`WSProtocolError` for websocket parser/protocol violations (can be re-raised by :any:`WSTransport.wait_disconnected` on client side). * :any:`WSInvalidURL` for invalid websocket/proxy URL inputs. In general, :any:`WSError` is reserved for websocket-specific failures only. +Handshake timeouts are separate and currently raise `asyncio.TimeoutError`, +not :any:`WSError`. + There is also a special exception, `asyncio.CancelledError`, which any coroutine can raise when it is externally cancelled. Sometimes you need to handle this exception manually. For example, in a reconnection loop where you want to diff --git a/docs/source/reference.rst b/docs/source/reference.rst index 01cbc09..237c98a 100644 --- a/docs/source/reference.rst +++ b/docs/source/reference.rst @@ -20,6 +20,18 @@ Classes .. autoexception:: WSHandshakeError :show-inheritance: +.. autoexception:: WSInvalidMessageError + :show-inheritance: + +.. autoexception:: WSInvalidStatusError + :show-inheritance: + +.. autoexception:: WSInvalidHeaderError + :show-inheritance: + +.. autoexception:: WSInvalidUpgradeError + :show-inheritance: + .. autoexception:: WSProtocolError :show-inheritance: @@ -231,7 +243,7 @@ Classes Opening handshake response. - .. py:method:: send_reuse_external_buffer(WSMsgType msg_type, char* msg_ptr, size_t msg_size, bint fin=True, bint rsv1=False) + .. py:method:: send_reuse_external_buffer(WSMsgType msg_type, char* msg_ptr, size_t msg_size, bint fin=True, bint rsv1=False, bint rsv2=False, bint rsv3=False) **Available only from Cython.** @@ -251,6 +263,10 @@ Classes :param rsv1: first reserved bit in websocket frame. Some protocol extensions use it to indicate that payload is compressed. + :param rsv2: second reserved bit in websocket frame. + Protocol extensions can use this flag. + :param rsv3: third reserved bit in websocket frame. + Protocol extensions can use this flag. Enums ----- diff --git a/picows/__init__.py b/picows/__init__.py index e329d0b..e27785d 100644 --- a/picows/__init__.py +++ b/picows/__init__.py @@ -1,6 +1,10 @@ from .types import ( WSError, WSHandshakeError, + WSInvalidMessageError, + WSInvalidStatusError, + WSInvalidHeaderError, + WSInvalidUpgradeError, WSProtocolError, WSUpgradeRequest, WSUpgradeResponse, @@ -30,6 +34,10 @@ __all__ = [ 'WSError', 'WSHandshakeError', + 'WSInvalidMessageError', + 'WSInvalidStatusError', + 'WSInvalidHeaderError', + 'WSInvalidUpgradeError', 'WSProtocolError', 'WSUpgradeRequest', 'WSUpgradeResponse', diff --git a/picows/api.py b/picows/api.py index 204a7ca..f5cd14e 100644 --- a/picows/api.py +++ b/picows/api.py @@ -4,9 +4,9 @@ from dataclasses import dataclass from functools import partial from inspect import isawaitable -from logging import getLogger +from logging import Logger, LoggerAdapter, getLogger from ssl import SSLContext -from typing import Callable, Optional, Union, Dict, Any, Awaitable, cast +from typing import Callable, Optional, Union, Dict, Any, Awaitable, cast, TYPE_CHECKING from python_socks.async_.asyncio import Proxy @@ -21,6 +21,13 @@ WSServerListenerFactory = Callable[[WSUpgradeRequest], Union[WSListener, WSUpgradeResponseWithListener, None]] WSSocketFactory = Callable[[WSParsedURL], Union[Optional[socket.socket], Awaitable[Optional[socket.socket]]]] +if TYPE_CHECKING: + _WSLoggerAdapter = LoggerAdapter[Any] +else: + _WSLoggerAdapter = LoggerAdapter + +WSLoggerLike = Union[str, Logger, _WSLoggerAdapter, None] + _HAS_AIOFASTNET = False try: import aiofastnet @@ -61,6 +68,20 @@ def _is_connected(sock: socket.socket) -> bool: except OSError: return False + +def _resolve_logger( + logger_name: WSLoggerLike, + default_suffix: str, + prefix: str = "picows." +) -> Union[Logger, _WSLoggerAdapter]: + if logger_name is None: + return getLogger(f"{prefix}{default_suffix}") + + if isinstance(logger_name, str): + return getLogger(f"{prefix}{logger_name}") + + return logger_name + @dataclass class _ConnectedSocket: sock: Optional[socket.socket] @@ -171,8 +192,8 @@ async def ws_connect(ws_listener_factory: WSListenerFactory, # type: ignore [no- *, ssl_context: Optional[SSLContext] = None, disconnect_on_exception: bool = True, - websocket_handshake_timeout: float = 5, - logger_name: str = "client", + websocket_handshake_timeout: Optional[float] = 5, + logger_name: WSLoggerLike = None, enable_auto_ping: bool = False, auto_ping_idle_timeout: float = 10, auto_ping_reply_timeout: float = 10, @@ -205,8 +226,11 @@ async def ws_connect(ws_listener_factory: WSListenerFactory, # type: ignore [no- :param websocket_handshake_timeout: is the time in seconds to wait for the websocket client to receive websocket handshake response before aborting the connection. + Set to ``None`` to disable the timeout. :param logger_name: - picows will use `picows.` logger to do all the logging. + Logger name suffix or logger-like object used for logging. + If a string is provided, picows will use `picows.`. + If ``None`` is provided, picows will use ``picows.client``. :param enable_auto_ping: Enable detection of a stale connection by periodically pinging remote peer. @@ -273,7 +297,7 @@ async def ws_connect(ws_listener_factory: WSListenerFactory, # type: ignore [no- # May sure people who are passing old argument are not going to get an exception kwargs.pop('zero_copy_unsafe_ssl_write', None) - logger = getLogger(f"picows.{logger_name}") + logger = _resolve_logger(logger_name, "client") parsed_url = parse_url(url) parsed_proxy_url = parse_url(proxy, False) if proxy is not None else None loop = asyncio.get_running_loop() @@ -341,8 +365,8 @@ async def ws_create_server(ws_listener_factory: WSServerListenerFactory, port=None, *, disconnect_on_exception: bool = True, - websocket_handshake_timeout=5, - logger_name: str = "server", + websocket_handshake_timeout: Optional[float] = 5, + logger_name: WSLoggerLike = None, enable_auto_ping: bool = False, auto_ping_idle_timeout: float = 20, auto_ping_reply_timeout: float = 20, @@ -388,8 +412,11 @@ async def ws_create_server(ws_listener_factory: WSServerListenerFactory, thrown by WSListener.on_ws_frame callback :param websocket_handshake_timeout: is the time in seconds to wait for the websocket server to receive websocket handshake request before aborting the connection. + Set to ``None`` to disable the timeout. :param logger_name: - picows will use `picows.` logger to do all the logging. + Logger name suffix or logger-like object used for logging. + If a string is provided, picows will use `picows.`. + If ``None`` is provided, picows will use ``picows.server``. :param enable_auto_ping: Enable detection of a stale connection by periodically pinging remote peer. @@ -444,7 +471,7 @@ def ws_protocol_factory() -> WSProtocol: None, # ws_path False, # is_client_side ws_listener_factory, - getLogger(f"picows.{logger_name}"), + _resolve_logger(logger_name, "server"), disconnect_on_exception, websocket_handshake_timeout, enable_auto_ping, auto_ping_idle_timeout, auto_ping_reply_timeout, diff --git a/picows/picows.pxd b/picows/picows.pxd index b0e3ec2..93510f6 100644 --- a/picows/picows.pxd +++ b/picows/picows.pxd @@ -41,6 +41,19 @@ cpdef enum WSAutoPingStrategy: PING_PERIODICALLY = 2 +cdef class WSCloseInfo: + cdef: + readonly WSCloseCode code + readonly str reason + + +cdef class WSCloseHandshake: + cdef: + readonly WSCloseInfo recv + readonly WSCloseInfo sent + readonly bint recv_then_sent + + cdef class MemoryBuffer: cdef: Py_ssize_t size @@ -70,15 +83,17 @@ cdef class WSFrame: cpdef WSCloseCode get_close_code(self) cpdef bytes get_close_message(self) + cpdef str get_close_reason(self) cdef class WSTransport: cdef: object __weakref__ - readonly object underlying_transport #: asyncio.Transport - readonly object request #: WSUpgradeRequest - readonly object response #: WSUpgradeResponse + readonly object underlying_transport #: asyncio.Transport + readonly object request #: WSUpgradeRequest + readonly object response #: WSUpgradeResponse + readonly WSCloseHandshake close_handshake #: Optional[WSCloseHandshake] readonly bint is_client_side readonly bint is_secure readonly bint is_close_frame_sent @@ -88,6 +103,7 @@ cdef class WSTransport: object listener_proxy object disconnected_future #: asyncio.Future + object _loop object _logger #: Logger MemoryBuffer _write_buffer @@ -97,9 +113,9 @@ cdef class WSTransport: bint _is_aiofn_transport bint _log_debug_enabled - cdef inline send_reuse_external_buffer(self, WSMsgType msg_type, char* msg_ptr, Py_ssize_t msg_size, bint fin=*, bint rsv1=*) - cpdef send(self, WSMsgType msg_type, message, bint fin=*, bint rsv1=*) - cpdef send_reuse_external_bytearray(self, WSMsgType msg_type, bytearray buffer, Py_ssize_t msg_offset, bint fin=*, bint rsv1=*) + cdef inline send_reuse_external_buffer(self, WSMsgType msg_type, char* msg_ptr, Py_ssize_t msg_size, bint fin=*, bint rsv1=*, bint rsv2=*, bint rsv3=*) + cpdef send(self, WSMsgType msg_type, message, bint fin=*, bint rsv1=*, bint rsv2=*, bint rsv3=*) + cpdef send_reuse_external_bytearray(self, WSMsgType msg_type, bytearray buffer, Py_ssize_t msg_offset, bint fin=*, bint rsv1=*, bint rsv2=*, bint rsv3=*) cpdef send_ping(self, message=*) cpdef send_pong(self, message=*) cpdef send_close(self, WSCloseCode close_code=*, close_message=*) @@ -110,9 +126,9 @@ cdef class WSTransport: cdef inline Py_ssize_t _get_header_size(self, Py_ssize_t msg_size) noexcept cdef inline _send_buffer(self, WSMsgType msg_type, char* msg_ptr, Py_ssize_t msg_size, - bint fin, bint rsv1) - cdef inline _send(self, WSMsgType msg_type, message, bint fin, bint rsv1) - cdef inline uint32_t _prepare_header(self, uint8_t* header_ptr, WSMsgType msg_type, Py_ssize_t msg_size, bint fin, bint rsv1) noexcept + bint fin, bint rsv1, bint rsv2, bint rsv3) + cdef inline _send(self, WSMsgType msg_type, message, bint fin, bint rsv1, bint rsv2, bint rsv3) + cdef inline uint32_t _prepare_header(self, uint8_t* header_ptr, WSMsgType msg_type, Py_ssize_t msg_size, bint fin, bint rsv1, bint rsv2, bint rsv3) noexcept cdef inline _send_http_handshake(self, bytes ws_path, bytes host_port, bytes websocket_key_b64, object extra_headers) cdef inline _send_http_handshake_response(self, response, bytes accept_val) cdef inline _fast_write(self, char* ptr, Py_ssize_t sz) @@ -130,4 +146,3 @@ cdef class WSListener: cpdef pause_writing(self) cpdef resume_writing(self) - diff --git a/picows/picows.pyi b/picows/picows.pyi index eea441e..ddcd64d 100644 --- a/picows/picows.pyi +++ b/picows/picows.pyi @@ -36,6 +36,17 @@ class WSAutoPingStrategy(Enum): PING_PERIODICALLY = 2 +class WSCloseInfo: + code: WSCloseCode + reason: str + + +class WSCloseHandshake: + recv: Optional[WSCloseInfo] + sent: Optional[WSCloseInfo] + recv_then_sent: bool + + class WSFrame: @property def tail_size(self) -> int: ... @@ -49,6 +60,12 @@ class WSFrame: @property def rsv1(self) -> bool: ... + @property + def rsv2(self) -> bool: ... + + @property + def rsv3(self) -> bool: ... + @property def last_in_buffer(self) -> bool: ... @@ -58,6 +75,7 @@ class WSFrame: def get_payload_as_memoryview(self) -> memoryview: ... def get_close_code(self) -> WSCloseCode: ... def get_close_message(self) -> bytes: ... + def get_close_reason(self) -> str: ... def __str__(self) -> str: ... @@ -66,19 +84,22 @@ class WSTransport: def underlying_transport(self) -> asyncio.Transport: ... @property - def is_client_side(self) -> bool: ... + def request(self) -> WSUpgradeRequest: ... @property - def is_secure(self) -> bool: ... + def response(self) -> WSUpgradeResponse: ... @property - def is_close_frame_sent(self) -> bool: ... + def close_handshake(self) -> WSCloseHandshake: ... @property - def request(self) -> WSUpgradeRequest: ... + def is_client_side(self) -> bool: ... @property - def response(self) -> WSUpgradeResponse: ... + def is_secure(self) -> bool: ... + + @property + def is_close_frame_sent(self) -> bool: ... def send( self, @@ -86,6 +107,8 @@ class WSTransport: message: Optional[WSBuffer], fin: bool = True, rsv1: bool = False, + rsv2: bool = False, + rsv3: bool = False, ) -> None: ... def send_reuse_external_bytearray( self, @@ -93,7 +116,9 @@ class WSTransport: buffer: bytearray, msg_offset: int, fin: bool = True, - rsv1: bool = False + rsv1: bool = False, + rsv2: bool = False, + rsv3: bool = False, ) -> None: ... def send_ping(self, message: Optional[WSBuffer]=None) -> None: ... def send_pong(self, message: Optional[WSBuffer]=None) -> None: ... diff --git a/picows/picows.pyx b/picows/picows.pyx index 7f62e12..d2d5793 100644 --- a/picows/picows.pyx +++ b/picows/picows.pyx @@ -24,7 +24,9 @@ from libc.stdlib cimport rand from .types import (PICOWS_DEBUG_LL, WSUpgradeRequest, WSUpgradeResponse, WSUpgradeResponseWithListener, - WSHandshakeError, WSProtocolError, add_extra_headers) + WSHandshakeError, WSInvalidMessageError, WSInvalidStatusError, + WSInvalidHeaderError, WSInvalidUpgradeError, + WSProtocolError, add_extra_headers) cdef: @@ -116,6 +118,14 @@ cdef _is_aiofn_transport(transport): return False +cdef class WSCloseInfo: + pass + + +cdef class WSCloseHandshake: + pass + + @cython.no_gc @cython.freelist(64) cdef class WSFrame: @@ -197,6 +207,21 @@ cdef class WSFrame: else: return PyBytes_FromStringAndSize(self.payload_ptr + 2, self.payload_size - 2) + cpdef str get_close_reason(self): + """ + :return: a new str object with a close reason. If there is no close reason then returns None. + + This method is only valid for WSMsgType.CLOSE frames. + """ + + assert self.msg_type == WSMsgType.CLOSE, "get_close_message can be called only for CLOSE frames" + + if self.payload_size <= 2: + return None + else: + return PyUnicode_FromStringAndSize(self.payload_ptr + 2, + self.payload_size - 2) + def __str__(self): return (f"WSFrame({WSMsgType(self.msg_type).name}, fin={True if self.fin else False}, " f"rsv1={True if self.rsv1 else False}, " @@ -339,6 +364,7 @@ cdef class WSTransport: self.underlying_transport = underlying_transport self.request = None self.response = None + self.close_handshake = None self.is_client_side = is_client_side self.is_secure = underlying_transport.get_extra_info('ssl_object') is not None self.is_close_frame_sent = False @@ -378,7 +404,7 @@ cdef class WSTransport: cdef uint32_t _prepare_header(self, uint8_t* header_ptr, WSMsgType msg_type, Py_ssize_t msg_size, - bint fin, bint rsv1) noexcept: + bint fin, bint rsv1, bint rsv2, bint rsv3) noexcept: # Return mask or 0 for server side cdef: @@ -390,6 +416,10 @@ cdef class WSTransport: first_byte |= 0x80 if rsv1: first_byte |= 0x40 + if rsv2: + first_byte |= 0x20 + if rsv3: + first_byte |= 0x10 header_ptr[0] = first_byte @@ -418,21 +448,22 @@ cdef class WSTransport: cdef _send_buffer(self, WSMsgType msg_type, char* msg_ptr, Py_ssize_t msg_size, - bint fin, bint rsv1): + bint fin, bint rsv1, bint rsv2, bint rsv3): + if self.is_close_frame_sent: + self._logger.debug("Ignore attempt to send a message after WSMsgType.CLOSE has already been sent") + return + cdef: Py_ssize_t header_size = self._get_header_size(msg_size) char* header_ptr = msg_ptr - header_size - uint32_t mask = self._prepare_header(header_ptr, msg_type, msg_size, fin, rsv1) + uint32_t mask = self._prepare_header(header_ptr, msg_type, msg_size, fin, rsv1, rsv2, rsv3) if mask != 0: _mask_payload(msg_ptr, msg_size, mask, msg_ptr) self._fast_write(header_ptr, header_size + msg_size) - if msg_type == WSMsgType.CLOSE: - self.is_close_frame_sent = True - - cdef _send(self, WSMsgType msg_type, message, bint fin, bint rsv1): + cdef _send(self, WSMsgType msg_type, message, bint fin, bint rsv1, bint rsv2, bint rsv3): if self.is_close_frame_sent: self._logger.debug("Ignore attempt to send a message after WSMsgType.CLOSE has already been sent") return @@ -451,7 +482,7 @@ cdef class WSTransport: self._write_buffer.resize(header_size) mask = self._prepare_header(self._write_buffer.data, msg_type, - msg_size, fin, rsv1) + msg_size, fin, rsv1, rsv2, rsv3) if msg_size == 0: self._fast_write(self._write_buffer.data, header_size) @@ -468,7 +499,7 @@ cdef class WSTransport: self._write_buffer.resize(header_size) self._prepare_header( self._write_buffer.data, msg_type, - msg_size, fin, rsv1) + msg_size, fin, rsv1, rsv2, rsv3) header = PyMemoryView_FromMemory( self._write_buffer.data, header_size, PyBUF_READ ) @@ -491,30 +522,26 @@ cdef class WSTransport: (masked_msg_ptr - header_size), header_size + msg_size ) - if msg_type == WSMsgType.CLOSE: - self.is_close_frame_sent = True - cdef send_reuse_external_buffer(self, WSMsgType msg_type, char* msg_ptr, Py_ssize_t msg_size, - bint fin=True, bint rsv1=False): + bint fin=True, bint rsv1=False, bint rsv2=False, bint rsv3=False): self._check_thread("send_reuse_external_buffer") - if self.is_close_frame_sent: - self._logger.debug("Ignore attempt to send a message after WSMsgType.CLOSE has already been sent") - return + if msg_type == WSMsgType.CLOSE: + raise ValueError("attempt to send CLOSE frame using send_reuse_external_buffer, use send_close instead") - self._send_buffer(msg_type, msg_ptr, msg_size, fin, rsv1) + self._send_buffer(msg_type, msg_ptr, msg_size, fin, rsv1, rsv2, rsv3) cpdef send_reuse_external_bytearray(self, WSMsgType msg_type, bytearray buffer, Py_ssize_t msg_offset, - bint fin=True, bint rsv1=False): + bint fin=True, bint rsv1=False, bint rsv2=False, bint rsv3=False): """ Send a frame over websocket with a message as its payload. This function does not copy message to prepare websocket frames. It reuses bytearray's memory to write websocket frame header at the front. - :param msg_type: :any:`WSMsgType` enum value\n + :param msg_type: :any:`WSMsgType` enum value, except CLOSE. Use send_close to send close frames. :param msg_offset: specifies where message begins in the bytearray. Must be at least 14 to let picows to write websocket frame header in front of the message. :param buffer: bytearray that contains message and some extra space (at least 14 bytes) in the beginning. @@ -522,30 +549,37 @@ cdef class WSTransport: :param fin: fin bit in websocket frame. Indicate that the frame is the last one in the message. :param rsv1: first reserved bit in websocket frame. - Some protocol extensions use it to indicate that payload is compressed. + Some protocol extensions use it to indicate that payload is compressed. + :param rsv2: second reserved bit in websocket frame. + Protocol extensions can use this flag. + :param rsv3: third reserved bit in websocket frame. + Protocol extensions can use this flag. """ - assert buffer is not None, "buffer is None" - assert msg_offset >= 14, "buffer must have at least 14 bytes available before message starts, check msg_offset parameter" + if buffer is None: + raise ValueError("None is passed instead of buffer to send_reuse_external_bytearray") - self._check_thread("send_reuse_external_bytearray") + if msg_offset < 14: + raise ValueError("buffer must have at least 14 bytes available before message starts, check msg_offset parameter") - if self.is_close_frame_sent: - self._logger.debug("Ignore attempt to send a message after WSMsgType.CLOSE has already been sent") - return + if msg_type == WSMsgType.CLOSE: + raise ValueError("attempt to send CLOSE frame using send_reuse_external_bytearray, use send_close instead") + + self._check_thread("send_reuse_external_bytearray") cdef: char* buffer_ptr = PyByteArray_AS_STRING(buffer) Py_ssize_t buffer_size = PyByteArray_GET_SIZE(buffer) - assert buffer_size >= msg_offset, "msg_offset points beyond buffer end, msg_offset > len(buffer)" + if buffer_size < msg_offset: + raise ValueError("msg_offset points beyond buffer end, msg_offset > len(buffer)") cdef: char* msg_ptr = buffer_ptr + msg_offset Py_ssize_t msg_size = buffer_size - msg_offset - self._send_buffer(msg_type, msg_ptr, msg_size, fin, rsv1) + self._send_buffer(msg_type, msg_ptr, msg_size, fin, rsv1, rsv2, rsv3) - cpdef send(self, WSMsgType msg_type, message, bint fin=True, bint rsv1=False): + cpdef send(self, WSMsgType msg_type, message, bint fin=True, bint rsv1=False, bint rsv2=False, bint rsv3=False): """ Send a frame over websocket with a message as its payload. @@ -562,9 +596,13 @@ cdef class WSTransport: :param rsv1: first reserved bit in websocket frame. Some protocol extensions use it to indicate that payload is compressed. + :param rsv2: second reserved bit in websocket frame. + Protocol extensions can use this flag. + :param rsv3: third reserved bit in websocket frame. + Protocol extensions can use this flag. """ self._check_thread("send") - self._send(msg_type, message, fin, rsv1) + self._send(msg_type, message, fin, rsv1, rsv2, rsv3) cpdef send_ping(self, message=None): """ @@ -573,7 +611,7 @@ cdef class WSTransport: :param message: an optional bytes-like object """ self._check_thread("send_ping") - self._send(WSMsgType.PING, message, True, False) + self._send(WSMsgType.PING, message, True, False, False, False) cpdef send_pong(self, message=None): """ @@ -582,7 +620,7 @@ cdef class WSTransport: :param message: an optional bytes-like object """ self._check_thread("send_pong") - self._send(WSMsgType.PONG, message, True, False) + self._send(WSMsgType.PONG, message, True, False, False, False) cpdef send_close(self, WSCloseCode close_code=WSCloseCode.NO_INFO, close_message=None): """ @@ -604,11 +642,23 @@ cdef class WSTransport: cdef: bytes msg = PyBytes_FromStringAndSize(NULL, close_msg_length + 2) char* msg_ptr = PyBytes_AS_STRING(msg) + str reason = PyUnicode_FromStringAndSize(close_msg_ptr, close_msg_length) (msg_ptr)[0] = htons(close_code) memcpy(msg_ptr + 2, close_msg_ptr, close_msg_length) - self._send(WSMsgType.CLOSE, msg, True, False) + self._send(WSMsgType.CLOSE, msg, True, False, False, False) + + if not self.is_close_frame_sent: + self.is_close_frame_sent = True + + if self.close_handshake is None: + self.close_handshake = WSCloseHandshake.__new__(WSCloseHandshake) + self.close_handshake.recv_then_sent = False + + self.close_handshake.sent = WSCloseInfo.__new__(WSCloseInfo) + self.close_handshake.sent.code = close_code + self.close_handshake.sent.reason = reason cpdef disconnect(self, bint graceful=True): """ @@ -945,10 +995,11 @@ cdef class WSProtocol(WSProtocolBase, asyncio.BufferedProtocol): # self._logger.getLogger adds child logger to the global loggers dict. # These child loggers never get deleted after connections are lost # Therefore do not use getLogger, create and setup child loggers manually - child_logger = logging.Logger(f"{self._logger.name}.{sock.fileno()}", logging.NOTSET) - child_logger.parent = self._logger - child_logger.propagate = True - self._logger = child_logger + if isinstance(self._logger, logging.Logger): + child_logger = logging.Logger(f"{self._logger.name}.{sock.fileno()}", logging.NOTSET) + child_logger.parent = self._logger + child_logger.propagate = True + self._logger = child_logger quickack = sock.getsockopt(socket.IPPROTO_TCP, socket.TCP_QUICKACK) if hasattr(socket, "TCP_QUICKACK") else False @@ -975,11 +1026,13 @@ cdef class WSProtocol(WSProtocolBase, asyncio.BufferedProtocol): if self.is_client_side: self.transport._send_http_handshake(self._ws_path, self._host_port, self._websocket_key_b64, self._extra_headers) - self._handshake_timeout_handle = self._loop.call_later( - self._handshake_timeout, self._handshake_timeout_callback) + if self._handshake_timeout is not None: + self._handshake_timeout_handle = self._loop.call_later( + self._handshake_timeout, self._handshake_timeout_callback) else: - self._handshake_timeout_handle = self._loop.call_later( - self._handshake_timeout, self._handshake_timeout_callback) + if self._handshake_timeout is not None: + self._handshake_timeout_handle = self._loop.call_later( + self._handshake_timeout, self._handshake_timeout_callback) def connection_lost(self, exc): self._logger.info("Disconnected") @@ -1183,8 +1236,9 @@ cdef class WSProtocol(WSProtocolBase, asyncio.BufferedProtocol): else: self.transport._send_http_handshake_response(response, accept_val) - self._handshake_timeout_handle.cancel() - self._handshake_timeout_handle = None + if self._handshake_timeout_handle is not None: + self._handshake_timeout_handle.cancel() + self._handshake_timeout_handle = None self._handshake_complete_future.set_result(None) self._invoke_on_ws_connected() self._last_data_time = picows_get_monotonic_time() @@ -1337,45 +1391,133 @@ cdef class WSProtocol(WSProtocolBase, asyncio.BufferedProtocol): cdef list lines = raw_headers.split(b"\r\n") cdef bytes response_status_line = lines[0] + cdef str response_status_line_str + cdef bytes status_code + cdef bytes line, name, value + cdef str transfer_encoding + cdef object connection_value + cdef object upgrade_value + cdef object r_key + cdef Py_ssize_t content_length - cdef str response_status_line_str = response_status_line.decode().lower() + try: + response_status_line_str = response_status_line.decode().lower() + except UnicodeDecodeError: + raise WSInvalidMessageError( + "cannot upgrade, invalid HTTP status line in upgrade response", + raw_headers, + tail, + ) from None # check handshake if not response_status_line_str.startswith("http/1.1 " ): - raise WSHandshakeError(f"cannot upgrade, unknown protocol (expected HTTP/1.1) in upgrade response: {response_status_line_str}", raw_headers, tail) + raise WSInvalidMessageError( + f"cannot upgrade, unknown protocol (expected HTTP/1.1) in upgrade response: {response_status_line_str}", + raw_headers, + tail, + ) - cdef bytes status_code response = WSUpgradeResponse() - response.version, status_code, status_phrase = response_status_line.split(b" ", 2) - response.status = HTTPStatus(int(status_code.decode())) + try: + response.version, status_code, status_phrase = response_status_line.split(b" ", 2) + response.status = HTTPStatus(int(status_code.decode())) + except (ValueError, UnicodeDecodeError): + raise WSInvalidMessageError( + f"cannot upgrade, invalid HTTP status line in upgrade response: {response_status_line!r}", + raw_headers, + tail, + ) from None - cdef bytes line, name, value response.headers = CIMultiDict() for idx in range(1, len(lines)): line = lines[idx] - name, value = line.split(b":", 1) - response.headers.add((name.strip()).decode(), (value.strip()).decode()) + try: + name, value = line.split(b":", 1) + response.headers.add((name.strip()).decode(), (value.strip()).decode()) + except (ValueError, UnicodeDecodeError): + raise WSInvalidMessageError( + f"cannot upgrade, malformed header in upgrade response: {line!r}", + raw_headers, + tail, + response, + ) from None if response.status != HTTPStatus.SWITCHING_PROTOCOLS: - raise WSHandshakeError(f"expected upgrade response with status 101 Switching Protocols, but received {response.status}", raw_headers, tail, response) + raise WSInvalidStatusError( + f"expected upgrade response with status 101 Switching Protocols, but received {response.status}", + raw_headers, + tail, + response, + ) - if response.headers.get("transfer-encoding") == "chunked": - raise WSHandshakeError(f"101 response cannot have Transfer-Encoding but it has", raw_headers, tail, response) + transfer_encoding = response.headers.get("transfer-encoding") + if transfer_encoding == "chunked": + raise WSInvalidHeaderError( + "101 response cannot have Transfer-Encoding but it has", + "Transfer-Encoding", + transfer_encoding, + raw_headers, + tail, + response, + ) - cdef Py_ssize_t content_length = int(response.headers.get("content-length", "0")) + try: + content_length = int(response.headers.get("content-length", "0")) + except ValueError: + raise WSInvalidHeaderError( + "101 response has invalid Content-Length header", + "Content-Length", + response.headers.get("content-length"), + raw_headers, + tail, + response, + ) from None if content_length != 0: - raise WSHandshakeError(f"101 response has non-zero Content-Length, but it can't have body", raw_headers, tail, response) + raise WSInvalidHeaderError( + "101 response has non-zero Content-Length, but it can't have body", + "Content-Length", + response.headers.get("content-length"), + raw_headers, + tail, + response, + ) + + upgrade_value = response.headers.get("upgrade") + upgrade_value = upgrade_value if upgrade_value is None else upgrade_value.lower() + if upgrade_value != "websocket": + raise WSInvalidUpgradeError( + "cannot upgrade, invalid upgrade header", + "Upgrade", + response.headers.get("upgrade"), + raw_headers, + tail, + response, + ) connection_value = response.headers.get("connection") connection_value = connection_value if connection_value is None else connection_value.lower() if connection_value != "upgrade": - raise WSHandshakeError(f"cannot upgrade, invalid connection header: {response.headers['connection']}", raw_headers, tail, response) + raise WSInvalidUpgradeError( + "cannot upgrade, invalid connection header", + "Connection", + response.headers.get("connection"), + raw_headers, + tail, + response, + ) r_key = response.headers.get("sec-websocket-accept") match = b64encode(sha1(self._websocket_key_b64 + _WS_KEY).digest()).decode() if r_key != match: - raise WSHandshakeError(f"cannot upgrade, invalid sec-websocket-accept response", raw_headers, tail, response) + raise WSInvalidHeaderError( + "cannot upgrade, invalid sec-websocket-accept response", + "Sec-WebSocket-Accept", + response.headers.get("sec-websocket-accept"), + raw_headers, + tail, + response, + ) memmove(self._read_buffer.data, self._read_buffer.data + len(raw_headers) + 4, self._read_buffer.size - len(raw_headers) - 4) self._f_new_data_start_pos = len(tail) @@ -1386,7 +1528,6 @@ cdef class WSProtocol(WSProtocolBase, asyncio.BufferedProtocol): return response cdef inline WSFrame _get_next_frame(self): - cdef WSFrame frame try: return self._get_next_frame_impl() except WSProtocolError as ex: @@ -1406,6 +1547,7 @@ cdef class WSProtocol(WSProtocolBase, asyncio.BufferedProtocol): uint8_t first_byte uint8_t second_byte WSFrame frame + WSCloseInfo recv if self._state == WSParserState.READ_HEADER: if self._f_new_data_start_pos - self._f_curr_state_start_pos < 2: @@ -1526,7 +1668,8 @@ cdef class WSProtocol(WSProtocolBase, asyncio.BufferedProtocol): self._state = WSParserState.READ_HEADER if frame.msg_type == WSMsgType.CLOSE: - if frame.get_close_code() < 3000 and frame.get_close_code() not in _ALLOWED_CLOSE_CODES: + close_code = frame.get_close_code() + if close_code < 3000 and close_code not in _ALLOWED_CLOSE_CODES: raise WSProtocolError(WSCloseCode.PROTOCOL_ERROR, f"Received CLOSE with invalid close code: {frame.get_close_code()}") @@ -1534,6 +1677,25 @@ cdef class WSProtocol(WSProtocolBase, asyncio.BufferedProtocol): raise WSProtocolError(WSCloseCode.PROTOCOL_ERROR, f"Received CLOSE with invalid close code size: {frame.fin} {frame.msg_type} {frame.get_payload_as_bytes()}") + recv = WSCloseInfo.__new__(WSCloseInfo) + recv.code = close_code + try: + recv.reason = frame.get_close_reason() + except UnicodeDecodeError: + raise WSProtocolError(WSCloseCode.INVALID_TEXT, + f"Received CLOSE with invalid UTF-8 reason") + + if self.transport.close_handshake is None: + self.transport.close_handshake = WSCloseHandshake.__new__(WSCloseHandshake) + self.transport.close_handshake.recv = recv + self.transport.close_handshake.sent = None + self.transport.close_handshake.recv_then_sent = True + elif self.transport.close_handshake.recv is None: + self.transport.close_handshake.recv = recv + else: + raise WSProtocolError(WSCloseCode.PROTOCOL_ERROR, + f"Received CLOSE for the second time: {frame.get_close_code()}") + return frame assert False, "we should never reach this state" diff --git a/picows/types.py b/picows/types.py index 9df12d8..6a6179d 100644 --- a/picows/types.py +++ b/picows/types.py @@ -165,6 +165,45 @@ def __init__(self, description: str, self.response = response +class WSInvalidMessageError(WSHandshakeError): + """ + Raised when the HTTP handshake request or response is malformed. + """ + pass + + +class WSInvalidStatusError(WSHandshakeError): + """ + Raised when the HTTP handshake response status rejects the WebSocket upgrade. + """ + pass + + +class WSInvalidHeaderError(WSHandshakeError): + """ + Raised when a HTTP header in the WebSocket handshake is invalid. + """ + name: str + value: Optional[str] + + def __init__(self, description: str, + name: str, + value: Optional[str] = None, + raw_header: Optional[bytes] = None, + raw_body: Optional[bytes] = None, + response: Optional[WSUpgradeResponse] = None): + super().__init__(description, raw_header, raw_body, response) + self.name = name + self.value = value + + +class WSInvalidUpgradeError(WSInvalidHeaderError): + """ + Raised when Upgrade / Connection headers are invalid in the WebSocket handshake. + """ + pass + + class WSProtocolError(WSError): """ Raised when receiving or sending frames that break the protocol or diff --git a/tests/test_basics.py b/tests/test_basics.py index 27ba9d0..30a5844 100644 --- a/tests/test_basics.py +++ b/tests/test_basics.py @@ -1,5 +1,6 @@ import asyncio import base64 +import logging import os import picows @@ -18,38 +19,47 @@ async def test_echo(use_aiofastnet, ssl_context, msg_size): async with WSClient(server, ssl_context=ssl_context.client, use_aiofastnet=use_aiofastnet) as client: msg = (b"ABCDEFGHIKLMNOPQ" * (int(msg_size / 16) + 1))[:msg_size] - client.transport.send(picows.WSMsgType.BINARY, msg, False, False) + client.transport.send(picows.WSMsgType.BINARY, msg, False, False, True, False) frame = await client.get_message() - assert frame.frame_str.startswith("WSFrame(BINARY, fin=False, rsv1=False") + assert frame.frame_str.startswith("WSFrame(BINARY, fin=False, rsv1=False, rsv2=True, rsv3=False") assert frame.msg_type == picows.WSMsgType.BINARY assert frame.payload_as_bytes == msg assert frame.payload_as_bytes_from_mv == msg assert not frame.fin assert not frame.rsv1 + assert frame.rsv2 + assert not frame.rsv3 ba = bytearray(b"1234567890123456") ba += msg - client.transport.send_reuse_external_bytearray(picows.WSMsgType.BINARY, ba, 16) + client.transport.send_reuse_external_bytearray(picows.WSMsgType.BINARY, ba, 16, True, False, False, True) frame = await client.get_message() - assert frame.frame_str.startswith("WSFrame(BINARY, fin=True, rsv1=False") + assert frame.frame_str.startswith("WSFrame(BINARY, fin=True, rsv1=False, rsv2=False, rsv3=True") assert frame.msg_type == picows.WSMsgType.BINARY assert frame.payload_as_bytes == msg + assert not frame.rsv1 + assert not frame.rsv2 + assert frame.rsv3 msg = base64.b64encode(msg) - client.transport.send(picows.WSMsgType.TEXT, msg, True, True) + client.transport.send(picows.WSMsgType.TEXT, msg, True, True, True, True) frame = await client.get_message() - assert frame.frame_str.startswith("WSFrame(TEXT, fin=True, rsv1=True") + assert frame.frame_str.startswith("WSFrame(TEXT, fin=True, rsv1=True, rsv2=True, rsv3=True") assert frame.msg_type == picows.WSMsgType.TEXT assert frame.payload_as_ascii_text == msg.decode("ascii") assert frame.payload_as_utf8_text == msg.decode("utf8") assert frame.fin assert frame.rsv1 + assert frame.rsv2 + assert frame.rsv3 # Check send defaults client.transport.send(picows.WSMsgType.BINARY, msg) frame = await client.get_message() assert frame.fin assert not frame.rsv1 + assert not frame.rsv2 + assert not frame.rsv3 # Test non-bytes like send with pytest.raises(TypeError): @@ -279,4 +289,4 @@ async def test_stress(use_aiofastnet, ssl_context): assert not client.is_paused -# \ No newline at end of file +# diff --git a/tests/test_ws_logic.py b/tests/test_ws_logic.py index 93fc2be..d92485a 100644 --- a/tests/test_ws_logic.py +++ b/tests/test_ws_logic.py @@ -1,28 +1,92 @@ import asyncio +import base64 +import logging import os import struct from concurrent.futures import ThreadPoolExecutor +from contextlib import asynccontextmanager +from hashlib import sha1 from http import HTTPStatus import async_timeout import pytest import picows -from tests.utils import WSServer, WSClient, TIMEOUT +from picows.api import _resolve_logger +from tests.utils import WSServer, WSClient, AsyncClient, TIMEOUT from tests.fixtures import use_aiofastnet, ssl_context +@asynccontextmanager +async def raw_handshake_server(response: bytes): + async def handle_client(reader: asyncio.StreamReader, writer: asyncio.StreamWriter): + await reader.readuntil(b"\r\n\r\n") + writer.write(response) + await writer.drain() + writer.close() + await writer.wait_closed() + + server = await asyncio.start_server(handle_client, "127.0.0.1", 0) + port = server.sockets[0].getsockname()[1] + try: + yield f"ws://127.0.0.1:{port}/" + finally: + server.close() + await server.wait_closed() + + +@asynccontextmanager +async def delayed_handshake_server(delay: float): + async def handle_client(reader: asyncio.StreamReader, writer: asyncio.StreamWriter): + request = await reader.readuntil(b"\r\n\r\n") + websocket_key = next( + line.split(b":", 1)[1].strip() + for line in request.split(b"\r\n") + if line.lower().startswith(b"sec-websocket-key:") + ) + accept = base64.b64encode( + sha1(websocket_key + b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11").digest() + ) + await asyncio.sleep(delay) + writer.write( + b"HTTP/1.1 101 Switching Protocols\r\n" + b"Upgrade: websocket\r\n" + b"Connection: Upgrade\r\n" + b"Sec-WebSocket-Accept: " + accept + b"\r\n" + b"\r\n" + ) + await writer.drain() + writer.close() + await writer.wait_closed() + + server = await asyncio.start_server(handle_client, "127.0.0.1", 0) + port = server.sockets[0].getsockname()[1] + try: + yield f"ws://127.0.0.1:{port}/" + finally: + server.close() + await server.wait_closed() + + async def test_send_external_bytearray_asserts(): async with WSServer() as server: async with WSClient(server) as client: - with pytest.raises(AssertionError): + with pytest.raises(ValueError): + # Check assertion for None buffer + client.transport.send_reuse_external_bytearray(picows.WSMsgType.BINARY, None, 16) + + with pytest.raises(ValueError): # Check assertion for msg_len >= 0 client.transport.send_reuse_external_bytearray(picows.WSMsgType.BINARY, bytearray(b"HELLO"), 16) - with pytest.raises(AssertionError): + with pytest.raises(ValueError): # Check assertion for offset to be at least 14 client.transport.send_reuse_external_bytearray(picows.WSMsgType.BINARY, bytearray(b"1234567890123HELLO"), 13) + with pytest.raises(ValueError): + # Check CLOSE is not allowed + client.transport.send_reuse_external_bytearray(picows.WSMsgType.CLOSE, bytearray(b"1234567890123HELLO"), 16) + async def test_max_frame_size_violation_huge_frame_from_client(use_aiofastnet, ssl_context): msg = os.urandom(128 * 1024) @@ -163,6 +227,153 @@ def on_ws_connected(self, transport: picows.WSTransport): await client.transport.wait_disconnected() +async def test_close_frame_invalid_utf8_reason_from_client(): + async with WSServer() as server: + async with WSClient(server) as client: + mask = 0x12345678 + payload = struct.pack("!H", picows.WSCloseCode.OK) + b"\xff" + masked_payload = bytes( + b ^ mask.to_bytes(4, "big")[i % 4] + for i, b in enumerate(payload) + ) + invalid_close_frame = struct.pack("!BBI", 0x88, 0x80 | len(payload), mask) + masked_payload + + client.transport.underlying_transport.write(invalid_close_frame) + frame = await client.get_message() + assert frame.msg_type == picows.WSMsgType.CLOSE + assert frame.close_code == picows.WSCloseCode.INVALID_TEXT + assert b"Received CLOSE with invalid UTF-8 reason" in frame.close_message + await client.transport.wait_disconnected() + + assert client.transport.close_handshake.sent is None + assert client.transport.close_handshake.recv.code == picows.WSCloseCode.INVALID_TEXT + assert client.transport.close_handshake.recv.reason == "Received CLOSE with invalid UTF-8 reason" + assert client.transport.close_handshake.recv_then_sent is True + + +async def test_close_handshake_client_initiates_close(): + server_transport = None + + class ServerListener(picows.WSListener): + def on_ws_connected(self, transport: picows.WSTransport): + nonlocal server_transport + server_transport = transport + + def on_ws_frame(self, transport: picows.WSTransport, frame: picows.WSFrame): + if frame.msg_type == picows.WSMsgType.CLOSE: + transport.send_close(frame.get_close_code(), frame.get_close_message()) + transport.disconnect() + + async with WSServer(lambda _: ServerListener()) as server: + async with WSClient(server) as client: + client.transport.send_close(picows.WSCloseCode.OK, b"client says bye") + await client.transport.wait_disconnected() + + assert client.transport.close_handshake.sent.code == picows.WSCloseCode.OK + assert client.transport.close_handshake.sent.reason == "client says bye" + assert client.transport.close_handshake.recv.code == picows.WSCloseCode.OK + assert client.transport.close_handshake.recv.reason == "client says bye" + assert client.transport.close_handshake.recv_then_sent is False + + assert server_transport.close_handshake.sent.code == picows.WSCloseCode.OK + assert server_transport.close_handshake.sent.reason == "client says bye" + assert server_transport.close_handshake.recv.code == picows.WSCloseCode.OK + assert server_transport.close_handshake.recv.reason == "client says bye" + assert server_transport.close_handshake.recv_then_sent is True + + +async def test_close_handshake_server_initiates_close(): + server_transport = None + + class ServerListener(picows.WSListener): + def on_ws_connected(self, transport: picows.WSTransport): + nonlocal server_transport + server_transport = transport + transport.send_close(picows.WSCloseCode.GOING_AWAY, b"server shutdown") + asyncio.get_running_loop().call_later(0.05, transport.disconnect) + + class ClientListener(AsyncClient): + def on_ws_frame(self, transport: picows.WSTransport, frame: picows.WSFrame): + if frame.msg_type == picows.WSMsgType.CLOSE: + transport.send_close(frame.get_close_code(), frame.get_close_message()) + else: + super().on_ws_frame(transport, frame) + + async with WSServer(lambda _: ServerListener()) as server: + async with WSClient(server, ClientListener) as client: + await client.transport.wait_disconnected() + assert client.transport.close_handshake.recv.code == picows.WSCloseCode.GOING_AWAY + assert client.transport.close_handshake.recv.reason == "server shutdown" + assert client.transport.close_handshake.sent.code == picows.WSCloseCode.GOING_AWAY + assert client.transport.close_handshake.sent.reason == "server shutdown" + assert client.transport.close_handshake.recv_then_sent is True + + assert server_transport.close_handshake.recv.code == picows.WSCloseCode.GOING_AWAY + assert server_transport.close_handshake.recv.reason == "server shutdown" + assert server_transport.close_handshake.sent.code == picows.WSCloseCode.GOING_AWAY + assert server_transport.close_handshake.sent.reason == "server shutdown" + assert server_transport.close_handshake.recv_then_sent is False + + +async def test_close_handshake_client_initiates_close_server_disconnects_without_reply(): + server_transport = None + + class ServerListener(picows.WSListener): + def on_ws_connected(self, transport: picows.WSTransport): + nonlocal server_transport + server_transport = transport + + def on_ws_frame(self, transport: picows.WSTransport, frame: picows.WSFrame): + if frame.msg_type == picows.WSMsgType.CLOSE: + transport.disconnect(False) + + async with WSServer(lambda _: ServerListener()) as server: + async with WSClient(server) as client: + client.transport.send_close(picows.WSCloseCode.OK, b"client says bye") + await client.transport.wait_disconnected() + + assert client.transport.close_handshake.recv is None + assert client.transport.close_handshake.sent.code == picows.WSCloseCode.OK + assert client.transport.close_handshake.sent.reason == "client says bye" + assert client.transport.close_handshake.recv_then_sent is False + + assert server_transport.close_handshake.recv.code == picows.WSCloseCode.OK + assert server_transport.close_handshake.recv.reason == "client says bye" + assert server_transport.close_handshake.sent is None + assert server_transport.close_handshake.recv_then_sent is True + + +async def test_close_handshake_server_initiates_close_client_disconnects_without_reply(): + server_transport = None + + class ServerListener(picows.WSListener): + def on_ws_connected(self, transport: picows.WSTransport): + nonlocal server_transport + server_transport = transport + transport.send_close(picows.WSCloseCode.GOING_AWAY, b"server shutdown") + asyncio.get_running_loop().call_later(0.05, transport.disconnect) + + class ClientListener(AsyncClient): + def on_ws_frame(self, transport: picows.WSTransport, frame: picows.WSFrame): + super().on_ws_frame(transport, frame) + if frame.msg_type == picows.WSMsgType.CLOSE: + transport.disconnect(False) + + async with WSServer(lambda _: ServerListener()) as server: + async with WSClient(server, ClientListener) as client: + await client.transport.wait_disconnected() + + assert client.transport.close_handshake.recv.code == picows.WSCloseCode.GOING_AWAY + assert client.transport.close_handshake.recv.reason == "server shutdown" + assert client.transport.close_handshake.sent is None + assert client.transport.close_handshake.recv_then_sent is True + + assert server_transport.close_handshake.recv is None + assert server_transport.close_handshake.sent.code == picows.WSCloseCode.GOING_AWAY + assert server_transport.close_handshake.sent.reason == "server shutdown" + assert server_transport.close_handshake.recv_then_sent is False + + async def test_wrong_thread_assert(): loop = asyncio.get_running_loop() with ThreadPoolExecutor(max_workers=1) as executor: @@ -188,3 +399,89 @@ async def test_wrong_thread_assert(): with pytest.raises(RuntimeError, match="WSTransport.disconnect called from a wrong thread"): await loop.run_in_executor(executor, client.transport.disconnect) + + +async def test_handshake_invalid_status_error(): + response = ( + b"HTTP/1.1 404 Not Found\r\n" + b"Connection: close\r\n" + b"Content-Length: 0\r\n" + b"\r\n" + ) + async with raw_handshake_server(response) as url: + with pytest.raises(picows.WSInvalidStatusError): + await picows.ws_connect(AsyncClient, url) + + +async def test_handshake_invalid_upgrade_error(): + response = ( + b"HTTP/1.1 101 Switching Protocols\r\n" + b"Upgrade: not-websocket\r\n" + b"Connection: Upgrade\r\n" + b"Sec-WebSocket-Accept: invalid\r\n" + b"\r\n" + ) + async with raw_handshake_server(response) as url: + with pytest.raises(picows.WSInvalidUpgradeError, match="invalid upgrade header"): + await picows.ws_connect(AsyncClient, url) + + +async def test_handshake_invalid_message_error(): + response = ( + b"NOT-HTTP\r\n" + b"Header: value\r\n" + b"\r\n" + ) + async with raw_handshake_server(response) as url: + with pytest.raises(picows.WSInvalidMessageError): + await picows.ws_connect(AsyncClient, url) + + +async def test_client_handshake_timeout_none(): + async with delayed_handshake_server(0.2) as url: + transport, _ = await picows.ws_connect( + AsyncClient, + url, + websocket_handshake_timeout=None, + ) + transport.disconnect(False) + await transport.wait_disconnected() + + +async def test_server_handshake_timeout_none(): + server = await picows.ws_create_server( + lambda _: picows.WSListener(), + "127.0.0.1", + 0, + websocket_handshake_timeout=None, + ) + port = server.sockets[0].getsockname()[1] + try: + reader, writer = await asyncio.open_connection("127.0.0.1", port) + await asyncio.sleep(0.2) + assert not reader.at_eof() + writer.write( + b"GET / HTTP/1.1\r\n" + b"Host: 127.0.0.1\r\n" + b"Upgrade: websocket\r\n" + b"Connection: Upgrade\r\n" + b"Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n" + b"Sec-WebSocket-Version: 13\r\n" + b"\r\n" + ) + response = await reader.readuntil(b"\r\n\r\n") + assert b"101 Switching Protocols" in response + writer.close() + await writer.wait_closed() + finally: + server.close() + await server.wait_closed() + + +def test_resolve_logger(): + logger = logging.getLogger("tests.picows.custom") + + assert _resolve_logger(None, "client") is logging.getLogger("picows.client") + assert _resolve_logger(None, "server") is logging.getLogger("picows.server") + assert _resolve_logger("custom", "client") is logging.getLogger("picows.custom") + assert _resolve_logger(logger, "client") is logger diff --git a/tests/utils.py b/tests/utils.py index 0193765..86600a5 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -24,6 +24,8 @@ def __init__(self, frame: picows.WSFrame): self.payload_as_bytes_from_mv = bytes(frame.get_payload_as_memoryview()) self.fin = frame.fin self.rsv1 = frame.rsv1 + self.rsv2 = frame.rsv2 + self.rsv3 = frame.rsv3 class TextFrame: @@ -34,6 +36,8 @@ def __init__(self, frame: picows.WSFrame): self.payload_as_utf8_text = frame.get_payload_as_utf8_text() self.fin = frame.fin self.rsv1 = frame.rsv1 + self.rsv2 = frame.rsv2 + self.rsv3 = frame.rsv3 class CloseFrame: @@ -44,6 +48,8 @@ def __init__(self, frame: picows.WSFrame): self.close_message = frame.get_close_message() self.fin = frame.fin self.rsv1 = frame.rsv1 + self.rsv2 = frame.rsv2 + self.rsv3 = frame.rsv3 def materialize_frame(frame: picows.WSFrame) -> Union[TextFrame, CloseFrame, BinaryFrame]: @@ -116,7 +122,7 @@ def on_ws_frame(self, transport: picows.WSTransport, frame: picows.WSFrame): self._transport.send(picows.WSMsgType.BINARY, msg) return - self._transport.send(frame.msg_type, frame.get_payload_as_bytes(), frame.fin, frame.rsv1) + self._transport.send(frame.msg_type, frame.get_payload_as_bytes(), frame.fin, frame.rsv1, frame.rsv2, frame.rsv3) @dataclass @@ -172,4 +178,3 @@ async def WSClient(server, listener_factory=None, **kwargs): await transport.wait_disconnected() except (TestException, picows.WSError): pass -