Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions AGENTS.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Run lint after updating code, and fix all errors
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics

# Run mypy after updating code, and fix all errors. Disable errors that seems to be mypy quirks with #ignore comments
mypy picows
9 changes: 6 additions & 3 deletions docs/source/guides.rst
Original file line number Diff line number Diff line change
Expand Up @@ -236,17 +236,20 @@ Using proxies

:any:`ws_connect` supports HTTP, SOCKS4 and SOCKS5 proxies via
`python-socks <https://github.com/romis2012/python-socks>`_.
Use the ``proxy`` argument with a proxy URL. HTTPS proxy URLs (``https://...``)
are not currently supported:
Use the ``proxy`` argument with a proxy URL:

.. code-block:: python

transport, listener = await ws_connect(
ClientListener,
"ws://127.0.0.1:9000/",
proxy="socks5://user:password@127.0.0.1:1080",
proxy_ssl_context=ssl.create_default_context(),
proxy="https://user:password@127.0.0.1:1080",
)

When ``https://`` proxy URL scheme is used, TLS is established with the proxy
first. ``proxy_ssl_context`` can be used to customize certificate verification.

When connecting to ``wss://`` URLs through a proxy, **picows** establishes a tunnel
through the proxy and then performs the TLS handshake with the websocket server.

Expand Down
105 changes: 81 additions & 24 deletions picows/api.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
import asyncio
import ssl
import sys
import urllib.parse
from logging import getLogger
from ssl import SSLContext
from typing import Callable, Optional, Union
from typing import Any, Callable, Optional, Union, cast

from python_socks.async_.asyncio import Proxy
from python_socks.async_.asyncio.v2 import Proxy as ProxyV2

from .types import (WSHeadersLike, WSUpgradeRequest,
WSUpgradeResponseWithListener, WSError)
from .picows import (WSListener, WSTransport, WSAutoPingStrategy, # type: ignore [attr-defined]
WSProtocol)
from .url import parse_url, ParsedURL

from .url import parse_url, ParsedURL, WSInvalidURL

def _maybe_handle_redirect(exc: WSError, old_parsed_url: ParsedURL, max_redirects: int) -> ParsedURL:
if max_redirects <= 0:
Expand All @@ -38,6 +40,54 @@ def _maybe_handle_redirect(exc: WSError, old_parsed_url: ParsedURL, max_redirect
return parsed_url


class _DetachedWriterTransport:
def is_closing(self) -> bool:
return True

def close(self) -> None:
return


def _detach_stream_writer_transport(stream: Any) -> asyncio.Transport:
transport = cast(asyncio.Transport, stream.writer.transport)
# Prevent StreamWriter.__del__ from closing a transport we hand over to WSProtocol.
stream.writer._transport = _DetachedWriterTransport()
return transport


async def _connect_through_https_proxy(
ws_protocol_factory: Callable[[], WSProtocol],
ssl_context: Optional[SSLContext],
proxy: str,
proxy_ssl_context: Optional[SSLContext],
proxy_parsed_url: urllib.parse.SplitResult,
parsed_url: ParsedURL
) -> tuple[WSTransport, WSListener]:
loop = asyncio.get_running_loop()
is_asyncio_loop = loop.__class__.__module__.startswith("asyncio")
if sys.version_info < (3, 11) and is_asyncio_loop:
raise WSInvalidURL(
proxy,
"HTTPS proxy with asyncio requires Python 3.11+ (asyncio StreamWriter.start_tls support)"
)
proxy_ssl_context = proxy_ssl_context or ssl.create_default_context(
ssl.Purpose.SERVER_AUTH)

http_proxy_url = urllib.parse.urlunsplit(
("http", proxy_parsed_url.netloc, "", "", "")
)
stream = await ProxyV2.from_url(http_proxy_url, proxy_ssl=proxy_ssl_context).connect(
dest_host=parsed_url.host,
dest_port=parsed_url.port,
dest_ssl=ssl_context)
ws_protocol = ws_protocol_factory()
transport = _detach_stream_writer_transport(stream)
transport.set_protocol(ws_protocol)
ws_protocol.connection_made(transport)
await ws_protocol.wait_until_handshake_complete()
return ws_protocol.transport, ws_protocol.listener


async def ws_connect(ws_listener_factory: Callable[[], WSListener], # type: ignore [no-untyped-def]
url: str,
*,
Expand All @@ -54,6 +104,7 @@ async def ws_connect(ws_listener_factory: Callable[[], WSListener], # type: igno
extra_headers: Optional[WSHeadersLike] = None,
max_redirects: int = 5,
proxy: Optional[str] = None,
proxy_ssl_context: Optional[SSLContext] = None,
**kwargs
) -> tuple[WSTransport, WSListener]:
"""
Expand Down Expand Up @@ -104,8 +155,9 @@ async def ws_connect(ws_listener_factory: Callable[[], WSListener], # type: igno
* How many times we can follow HTTP redirects. Set to 0 in order to disable redirects.
:param proxy:
Optional proxy URL. Supported schemes are ``http://``, ``socks4://``
and ``socks5://`` (including authenticated variants).
HTTPS proxy scheme (``https://``) is currently not supported.
``https://`` and ``socks5://`` (including authenticated variants).
:param proxy_ssl_context: optional SSLContext to override default one when
https proxy scheme is used
:return: :any:`WSTransport` object and a user handler returned by `ws_listener_factory()`
"""

Expand All @@ -118,15 +170,12 @@ async def ws_connect(ws_listener_factory: Callable[[], WSListener], # type: igno
logger = getLogger(f"picows.{logger_name}")
parsed_url = parse_url(url)

# Loop in order to follow redirects
# Break loop if we were able to upgrade
while True:
if parsed_url.username is not None or parsed_url.password is not None:
logger.warning("Basic authentication was requested in URL, but it is not currently supported, ignore username and password")

if parsed_url.secure:
ssl = ssl_context if ssl_context is not None else True
else:
ssl = None

def ws_protocol_factory() -> WSProtocol:
return WSProtocol(
parsed_url.netloc,
Expand All @@ -144,28 +193,36 @@ def ws_protocol_factory() -> WSProtocol:
max_frame_size,
extra_headers)

current_ssl_context = ssl_context if parsed_url.secure else None

loop = asyncio.get_running_loop()
conn_kwargs = dict(kwargs)

proxy_socket = None
host = None
port = None

try:
loop = asyncio.get_running_loop()
conn_kwargs = dict(kwargs)
if proxy is not None and urllib.parse.urlsplit(proxy).scheme.lower() == "https":
raise ValueError("HTTPS proxy URL scheme is not supported, use http://, socks4:// or socks5://")

proxy_socket = None
host = None
port = None
if proxy is not None:
proxy_socket = await Proxy.from_url(proxy).connect(
dest_host=parsed_url.host,
dest_port=parsed_url.port)

if ssl is not None and "server_hostname" not in conn_kwargs:
proxy_url = urllib.parse.urlsplit(proxy)
proxy_scheme = proxy_url.scheme.lower()
if proxy_scheme == "https":
return await _connect_through_https_proxy(
ws_protocol_factory, current_ssl_context, proxy,
proxy_ssl_context, proxy_url, parsed_url)
else:
proxy_socket = await Proxy.from_url(proxy).connect(
dest_host=parsed_url.host,
dest_port=parsed_url.port)

if parsed_url.secure and "server_hostname" not in conn_kwargs:
conn_kwargs["server_hostname"] = parsed_url.host
else:
host = parsed_url.host
port = parsed_url.port

(_, ws_protocol) = await loop.create_connection(
ws_protocol_factory, host, port, ssl=ssl, sock=proxy_socket, **conn_kwargs) # type: ignore[arg-type]
ws_protocol_factory, host, port, ssl=current_ssl_context, sock=proxy_socket, **conn_kwargs) # type: ignore[arg-type]

await ws_protocol.wait_until_handshake_complete()
return ws_protocol.transport, ws_protocol.listener
Expand Down
1 change: 1 addition & 0 deletions picows/picows.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ async def ws_connect(
url: str,
*args: Any,
ssl_context: Union[SSLContext, None] = None,
proxy_ssl_context: Union[SSLContext, None] = None,
disconnect_on_exception: bool = True,
websocket_handshake_timeout: float = 5,
logger_name: str = "client",
Expand Down
50 changes: 40 additions & 10 deletions tests/test_redirects_and_proxies.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
import asyncio
import ssl
import sys
from contextlib import asynccontextmanager
from http import HTTPStatus
from logging import getLogger

import anyio
import pytest
from anyio.streams.tls import TLSListener
from tiny_proxy import HttpProxyHandler, Socks4ProxyHandler, Socks5ProxyHandler

import picows
import picows.api as picows_api
from tests.utils import ClientAsyncContext, AsyncClient, \
create_client_ssl_context, echo_server, multiloop_event_loop_policy, \
ServerAsyncContext
ServerAsyncContext, create_server_ssl_context

event_loop_policy = multiloop_event_loop_policy()

Expand All @@ -21,6 +25,10 @@ def _create_proxy_handler(proxy_type: str):
return HttpProxyHandler()
if proxy_type == "http_auth":
return HttpProxyHandler(username="user", password="password")
if proxy_type == "https":
return HttpProxyHandler()
if proxy_type == "https_auth":
return HttpProxyHandler(username="user", password="password")
if proxy_type == "socks4":
return Socks4ProxyHandler()
if proxy_type == "socks5":
Expand All @@ -32,6 +40,8 @@ def _create_proxy_handler(proxy_type: str):
_proxy_url_templates = {
"http": "http://127.0.0.1:{port}",
"http_auth": "http://user:password@127.0.0.1:{port}",
"https": "https://127.0.0.1:{port}",
"https_auth": "https://user:password@127.0.0.1:{port}",
"socks4": "socks4://127.0.0.1:{port}",
"socks5": "socks5://user:password@127.0.0.1:{port}"
}
Expand All @@ -45,18 +55,19 @@ async def ProxyServer(proxy_type: str):
url_template = _proxy_url_templates[proxy_type]
handler = _create_proxy_handler(proxy_type)
listener = await anyio.create_tcp_listener(local_host="127.0.0.1")
proxy_listener = TLSListener(listener, create_server_ssl_context()) if proxy_type.startswith("https") else listener

task_group = anyio.create_task_group()
await task_group.__aenter__()
task_group.start_soon(listener.serve, handler.handle)
task_group.start_soon(proxy_listener.serve, handler.handle)

try:
proxy_port = listener.listeners[0].extra(anyio.abc.SocketAttribute.local_port)
yield url_template.format(port=proxy_port)
finally:
task_group.cancel_scope.cancel()
await task_group.__aexit__(None, None, None)
await listener.aclose()
await proxy_listener.aclose()


@pytest.fixture()
Expand Down Expand Up @@ -87,7 +98,7 @@ def listener_factory(r):
yield server_urls.tcp_url


@pytest.mark.parametrize("proxy_type", ["direct", "http", "http_auth", "socks4", "socks5"])
@pytest.mark.parametrize("proxy_type", ["direct", "socks4", "socks5", "http", "http_auth", "https", "https_auth"])
async def test_redirect_through_proxy(redirect_server_2, proxy_type: str):
# This is an absolute masterpiece! Best test I wrote ever!
#
Expand All @@ -97,31 +108,50 @@ async def test_redirect_through_proxy(redirect_server_2, proxy_type: str):
# echo server, send request and validate response.
#
# God bless pytest!

is_https = proxy_type in ("https", "https_auth")
is_asyncio_loop = isinstance(asyncio.get_event_loop_policy(), asyncio.DefaultEventLoopPolicy)

if sys.version_info < (3, 11) and is_asyncio_loop and is_https:
pytest.skip("HTTPS proxy using asyncio requires Python 3.11+")
return

client_ssl_ctx = create_client_ssl_context()
proxy_ssl_ctx = create_client_ssl_context() if is_https else None

async with ProxyServer(proxy_type) as proxy_url:
async with ClientAsyncContext(AsyncClient, redirect_server_2, ssl_context=client_ssl_ctx, proxy=proxy_url) as (transport, listener):
async with ClientAsyncContext(AsyncClient, redirect_server_2, ssl_context=client_ssl_ctx, proxy=proxy_url, proxy_ssl_context=proxy_ssl_ctx) as (transport, listener):
transport.send(picows.WSMsgType.BINARY, b"hello over proxy")
frame = await listener.get_message()
frame = await listener.get_message(1.0)
assert frame.msg_type == picows.WSMsgType.BINARY
assert frame.payload_as_bytes == b"hello over proxy"

with pytest.raises(picows.WSError, match="status 101"):
await picows.ws_connect(AsyncClient, redirect_server_2, max_redirects=0, proxy=proxy_url)
await picows.ws_connect(AsyncClient, redirect_server_2, max_redirects=0, proxy=proxy_url, proxy_ssl_context=proxy_ssl_ctx)

with pytest.raises(picows.WSError, match="status 101"):
await picows.ws_connect(AsyncClient, redirect_server_2, max_redirects=1, proxy=proxy_url)
await picows.ws_connect(AsyncClient, redirect_server_2, max_redirects=1, proxy=proxy_url, proxy_ssl_context=proxy_ssl_ctx)


@pytest.mark.parametrize("proxy_type", ["direct", "http", "http_auth", "socks4", "socks5"])
@pytest.mark.parametrize("proxy_type", ["socks4", "socks5", "http"])
@pytest.mark.skip(reason="echo server may respond with 429 (too many requests if we spam it a lot)")
async def test_proxy_dns_resolution(proxy_type):
is_https = proxy_type in ("https", "https_auth")
is_asyncio_loop = isinstance(asyncio.get_event_loop_policy(), asyncio.DefaultEventLoopPolicy)

if sys.version_info < (3, 11) and is_asyncio_loop and is_https:
pytest.skip("HTTPS proxy using asyncio requires Python 3.11+")
return

client_ssl_ctx = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
proxy_ssl_ctx = create_client_ssl_context() if is_https else None

async with ProxyServer(proxy_type) as proxy_url:
async with ClientAsyncContext(AsyncClient, "wss://echo.websocket.org", ssl_context=client_ssl_ctx, proxy=proxy_url) as (transport, listener):
async with ClientAsyncContext(AsyncClient, "wss://echo.websocket.org", ssl_context=client_ssl_ctx, proxy=proxy_url, proxy_ssl_context=proxy_ssl_ctx) as (transport, listener):
frame = await listener.get_message()
_logger.debug("Welcome frame from echo.websocket.org: %s", frame.payload_as_ascii_text)
transport.send(picows.WSMsgType.BINARY, b"hello over proxy")
frame = await listener.get_message()
assert frame.msg_type == picows.WSMsgType.BINARY
assert frame.payload_as_bytes == b"hello over proxy"

Loading