diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000..b269e04 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,5 @@ +# Run lint after updating code, and fix all errors +flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics + +# Run mypy after updating code, and fix all errors. Disable errors that seems to be mypy quirks with #ignore comments +mypy picows \ No newline at end of file diff --git a/docs/source/guides.rst b/docs/source/guides.rst index 5b2cdbb..2e8c6c0 100644 --- a/docs/source/guides.rst +++ b/docs/source/guides.rst @@ -236,17 +236,20 @@ Using proxies :any:`ws_connect` supports HTTP, SOCKS4 and SOCKS5 proxies via `python-socks `_. -Use the ``proxy`` argument with a proxy URL. HTTPS proxy URLs (``https://...``) -are not currently supported: +Use the ``proxy`` argument with a proxy URL: .. code-block:: python transport, listener = await ws_connect( ClientListener, "ws://127.0.0.1:9000/", - proxy="socks5://user:password@127.0.0.1:1080", + proxy_ssl_context=ssl.create_default_context(), + proxy="https://user:password@127.0.0.1:1080", ) +When ``https://`` proxy URL scheme is used, TLS is established with the proxy +first. ``proxy_ssl_context`` can be used to customize certificate verification. + When connecting to ``wss://`` URLs through a proxy, **picows** establishes a tunnel through the proxy and then performs the TLS handshake with the websocket server. diff --git a/picows/api.py b/picows/api.py index 3a2b9e3..1eecf77 100644 --- a/picows/api.py +++ b/picows/api.py @@ -1,17 +1,19 @@ import asyncio +import ssl +import sys import urllib.parse from logging import getLogger from ssl import SSLContext -from typing import Callable, Optional, Union +from typing import Any, Callable, Optional, Union, cast from python_socks.async_.asyncio import Proxy +from python_socks.async_.asyncio.v2 import Proxy as ProxyV2 from .types import (WSHeadersLike, WSUpgradeRequest, WSUpgradeResponseWithListener, WSError) from .picows import (WSListener, WSTransport, WSAutoPingStrategy, # type: ignore [attr-defined] WSProtocol) -from .url import parse_url, ParsedURL - +from .url import parse_url, ParsedURL, WSInvalidURL def _maybe_handle_redirect(exc: WSError, old_parsed_url: ParsedURL, max_redirects: int) -> ParsedURL: if max_redirects <= 0: @@ -38,6 +40,54 @@ def _maybe_handle_redirect(exc: WSError, old_parsed_url: ParsedURL, max_redirect return parsed_url +class _DetachedWriterTransport: + def is_closing(self) -> bool: + return True + + def close(self) -> None: + return + + +def _detach_stream_writer_transport(stream: Any) -> asyncio.Transport: + transport = cast(asyncio.Transport, stream.writer.transport) + # Prevent StreamWriter.__del__ from closing a transport we hand over to WSProtocol. + stream.writer._transport = _DetachedWriterTransport() + return transport + + +async def _connect_through_https_proxy( + ws_protocol_factory: Callable[[], WSProtocol], + ssl_context: Optional[SSLContext], + proxy: str, + proxy_ssl_context: Optional[SSLContext], + proxy_parsed_url: urllib.parse.SplitResult, + parsed_url: ParsedURL +) -> tuple[WSTransport, WSListener]: + loop = asyncio.get_running_loop() + is_asyncio_loop = loop.__class__.__module__.startswith("asyncio") + if sys.version_info < (3, 11) and is_asyncio_loop: + raise WSInvalidURL( + proxy, + "HTTPS proxy with asyncio requires Python 3.11+ (asyncio StreamWriter.start_tls support)" + ) + proxy_ssl_context = proxy_ssl_context or ssl.create_default_context( + ssl.Purpose.SERVER_AUTH) + + http_proxy_url = urllib.parse.urlunsplit( + ("http", proxy_parsed_url.netloc, "", "", "") + ) + stream = await ProxyV2.from_url(http_proxy_url, proxy_ssl=proxy_ssl_context).connect( + dest_host=parsed_url.host, + dest_port=parsed_url.port, + dest_ssl=ssl_context) + ws_protocol = ws_protocol_factory() + transport = _detach_stream_writer_transport(stream) + transport.set_protocol(ws_protocol) + ws_protocol.connection_made(transport) + await ws_protocol.wait_until_handshake_complete() + return ws_protocol.transport, ws_protocol.listener + + async def ws_connect(ws_listener_factory: Callable[[], WSListener], # type: ignore [no-untyped-def] url: str, *, @@ -54,6 +104,7 @@ async def ws_connect(ws_listener_factory: Callable[[], WSListener], # type: igno extra_headers: Optional[WSHeadersLike] = None, max_redirects: int = 5, proxy: Optional[str] = None, + proxy_ssl_context: Optional[SSLContext] = None, **kwargs ) -> tuple[WSTransport, WSListener]: """ @@ -104,8 +155,9 @@ async def ws_connect(ws_listener_factory: Callable[[], WSListener], # type: igno * How many times we can follow HTTP redirects. Set to 0 in order to disable redirects. :param proxy: Optional proxy URL. Supported schemes are ``http://``, ``socks4://`` - and ``socks5://`` (including authenticated variants). - HTTPS proxy scheme (``https://``) is currently not supported. + ``https://`` and ``socks5://`` (including authenticated variants). + :param proxy_ssl_context: optional SSLContext to override default one when + https proxy scheme is used :return: :any:`WSTransport` object and a user handler returned by `ws_listener_factory()` """ @@ -118,15 +170,12 @@ async def ws_connect(ws_listener_factory: Callable[[], WSListener], # type: igno logger = getLogger(f"picows.{logger_name}") parsed_url = parse_url(url) + # Loop in order to follow redirects + # Break loop if we were able to upgrade while True: if parsed_url.username is not None or parsed_url.password is not None: logger.warning("Basic authentication was requested in URL, but it is not currently supported, ignore username and password") - if parsed_url.secure: - ssl = ssl_context if ssl_context is not None else True - else: - ssl = None - def ws_protocol_factory() -> WSProtocol: return WSProtocol( parsed_url.netloc, @@ -144,28 +193,36 @@ def ws_protocol_factory() -> WSProtocol: max_frame_size, extra_headers) + current_ssl_context = ssl_context if parsed_url.secure else None + + loop = asyncio.get_running_loop() + conn_kwargs = dict(kwargs) + + proxy_socket = None + host = None + port = None + try: - loop = asyncio.get_running_loop() - conn_kwargs = dict(kwargs) - if proxy is not None and urllib.parse.urlsplit(proxy).scheme.lower() == "https": - raise ValueError("HTTPS proxy URL scheme is not supported, use http://, socks4:// or socks5://") - - proxy_socket = None - host = None - port = None if proxy is not None: - proxy_socket = await Proxy.from_url(proxy).connect( - dest_host=parsed_url.host, - dest_port=parsed_url.port) - - if ssl is not None and "server_hostname" not in conn_kwargs: + proxy_url = urllib.parse.urlsplit(proxy) + proxy_scheme = proxy_url.scheme.lower() + if proxy_scheme == "https": + return await _connect_through_https_proxy( + ws_protocol_factory, current_ssl_context, proxy, + proxy_ssl_context, proxy_url, parsed_url) + else: + proxy_socket = await Proxy.from_url(proxy).connect( + dest_host=parsed_url.host, + dest_port=parsed_url.port) + + if parsed_url.secure and "server_hostname" not in conn_kwargs: conn_kwargs["server_hostname"] = parsed_url.host else: host = parsed_url.host port = parsed_url.port (_, ws_protocol) = await loop.create_connection( - ws_protocol_factory, host, port, ssl=ssl, sock=proxy_socket, **conn_kwargs) # type: ignore[arg-type] + ws_protocol_factory, host, port, ssl=current_ssl_context, sock=proxy_socket, **conn_kwargs) # type: ignore[arg-type] await ws_protocol.wait_until_handshake_complete() return ws_protocol.transport, ws_protocol.listener diff --git a/picows/picows.pyi b/picows/picows.pyi index 74c4711..5a3716f 100644 --- a/picows/picows.pyi +++ b/picows/picows.pyi @@ -177,6 +177,7 @@ async def ws_connect( url: str, *args: Any, ssl_context: Union[SSLContext, None] = None, + proxy_ssl_context: Union[SSLContext, None] = None, disconnect_on_exception: bool = True, websocket_handshake_timeout: float = 5, logger_name: str = "client", diff --git a/tests/test_redirects_and_proxies.py b/tests/test_redirects_and_proxies.py index acc9b03..eae72b3 100644 --- a/tests/test_redirects_and_proxies.py +++ b/tests/test_redirects_and_proxies.py @@ -1,16 +1,20 @@ +import asyncio import ssl +import sys from contextlib import asynccontextmanager from http import HTTPStatus from logging import getLogger import anyio import pytest +from anyio.streams.tls import TLSListener from tiny_proxy import HttpProxyHandler, Socks4ProxyHandler, Socks5ProxyHandler import picows +import picows.api as picows_api from tests.utils import ClientAsyncContext, AsyncClient, \ create_client_ssl_context, echo_server, multiloop_event_loop_policy, \ - ServerAsyncContext + ServerAsyncContext, create_server_ssl_context event_loop_policy = multiloop_event_loop_policy() @@ -21,6 +25,10 @@ def _create_proxy_handler(proxy_type: str): return HttpProxyHandler() if proxy_type == "http_auth": return HttpProxyHandler(username="user", password="password") + if proxy_type == "https": + return HttpProxyHandler() + if proxy_type == "https_auth": + return HttpProxyHandler(username="user", password="password") if proxy_type == "socks4": return Socks4ProxyHandler() if proxy_type == "socks5": @@ -32,6 +40,8 @@ def _create_proxy_handler(proxy_type: str): _proxy_url_templates = { "http": "http://127.0.0.1:{port}", "http_auth": "http://user:password@127.0.0.1:{port}", + "https": "https://127.0.0.1:{port}", + "https_auth": "https://user:password@127.0.0.1:{port}", "socks4": "socks4://127.0.0.1:{port}", "socks5": "socks5://user:password@127.0.0.1:{port}" } @@ -45,10 +55,11 @@ async def ProxyServer(proxy_type: str): url_template = _proxy_url_templates[proxy_type] handler = _create_proxy_handler(proxy_type) listener = await anyio.create_tcp_listener(local_host="127.0.0.1") + proxy_listener = TLSListener(listener, create_server_ssl_context()) if proxy_type.startswith("https") else listener task_group = anyio.create_task_group() await task_group.__aenter__() - task_group.start_soon(listener.serve, handler.handle) + task_group.start_soon(proxy_listener.serve, handler.handle) try: proxy_port = listener.listeners[0].extra(anyio.abc.SocketAttribute.local_port) @@ -56,7 +67,7 @@ async def ProxyServer(proxy_type: str): finally: task_group.cancel_scope.cancel() await task_group.__aexit__(None, None, None) - await listener.aclose() + await proxy_listener.aclose() @pytest.fixture() @@ -87,7 +98,7 @@ def listener_factory(r): yield server_urls.tcp_url -@pytest.mark.parametrize("proxy_type", ["direct", "http", "http_auth", "socks4", "socks5"]) +@pytest.mark.parametrize("proxy_type", ["direct", "socks4", "socks5", "http", "http_auth", "https", "https_auth"]) async def test_redirect_through_proxy(redirect_server_2, proxy_type: str): # This is an absolute masterpiece! Best test I wrote ever! # @@ -97,31 +108,50 @@ async def test_redirect_through_proxy(redirect_server_2, proxy_type: str): # echo server, send request and validate response. # # God bless pytest! + + is_https = proxy_type in ("https", "https_auth") + is_asyncio_loop = isinstance(asyncio.get_event_loop_policy(), asyncio.DefaultEventLoopPolicy) + + if sys.version_info < (3, 11) and is_asyncio_loop and is_https: + pytest.skip("HTTPS proxy using asyncio requires Python 3.11+") + return + client_ssl_ctx = create_client_ssl_context() + proxy_ssl_ctx = create_client_ssl_context() if is_https else None async with ProxyServer(proxy_type) as proxy_url: - async with ClientAsyncContext(AsyncClient, redirect_server_2, ssl_context=client_ssl_ctx, proxy=proxy_url) as (transport, listener): + async with ClientAsyncContext(AsyncClient, redirect_server_2, ssl_context=client_ssl_ctx, proxy=proxy_url, proxy_ssl_context=proxy_ssl_ctx) as (transport, listener): transport.send(picows.WSMsgType.BINARY, b"hello over proxy") - frame = await listener.get_message() + frame = await listener.get_message(1.0) assert frame.msg_type == picows.WSMsgType.BINARY assert frame.payload_as_bytes == b"hello over proxy" with pytest.raises(picows.WSError, match="status 101"): - await picows.ws_connect(AsyncClient, redirect_server_2, max_redirects=0, proxy=proxy_url) + await picows.ws_connect(AsyncClient, redirect_server_2, max_redirects=0, proxy=proxy_url, proxy_ssl_context=proxy_ssl_ctx) with pytest.raises(picows.WSError, match="status 101"): - await picows.ws_connect(AsyncClient, redirect_server_2, max_redirects=1, proxy=proxy_url) + await picows.ws_connect(AsyncClient, redirect_server_2, max_redirects=1, proxy=proxy_url, proxy_ssl_context=proxy_ssl_ctx) -@pytest.mark.parametrize("proxy_type", ["direct", "http", "http_auth", "socks4", "socks5"]) +@pytest.mark.parametrize("proxy_type", ["socks4", "socks5", "http"]) +@pytest.mark.skip(reason="echo server may respond with 429 (too many requests if we spam it a lot)") async def test_proxy_dns_resolution(proxy_type): + is_https = proxy_type in ("https", "https_auth") + is_asyncio_loop = isinstance(asyncio.get_event_loop_policy(), asyncio.DefaultEventLoopPolicy) + + if sys.version_info < (3, 11) and is_asyncio_loop and is_https: + pytest.skip("HTTPS proxy using asyncio requires Python 3.11+") + return + client_ssl_ctx = ssl.create_default_context(ssl.Purpose.SERVER_AUTH) + proxy_ssl_ctx = create_client_ssl_context() if is_https else None async with ProxyServer(proxy_type) as proxy_url: - async with ClientAsyncContext(AsyncClient, "wss://echo.websocket.org", ssl_context=client_ssl_ctx, proxy=proxy_url) as (transport, listener): + async with ClientAsyncContext(AsyncClient, "wss://echo.websocket.org", ssl_context=client_ssl_ctx, proxy=proxy_url, proxy_ssl_context=proxy_ssl_ctx) as (transport, listener): frame = await listener.get_message() _logger.debug("Welcome frame from echo.websocket.org: %s", frame.payload_as_ascii_text) transport.send(picows.WSMsgType.BINARY, b"hello over proxy") frame = await listener.get_message() assert frame.msg_type == picows.WSMsgType.BINARY assert frame.payload_as_bytes == b"hello over proxy" +