Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions serialx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
open_serial_connection,
)
from .common import (
AllConnectKwargs,
BaseSerial,
BaseSerialTransport,
ConnectKwargs,
ModemPins,
Parity,
PinState,
Expand Down Expand Up @@ -57,8 +59,10 @@
"Parity",
"PinState",
"Platform",
"AllConnectKwargs",
"BaseSerial",
"BaseSerialTransport",
"ConnectKwargs",
"Serial",
"SerialException",
"UnsupportedSetting",
Expand Down
11 changes: 7 additions & 4 deletions serialx/async_serial.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,18 @@
import logging
from typing import Any, Generic, TypeVar, cast

from typing_extensions import Self
from typing_extensions import Self, Unpack

from .common import (
AllConnectKwargs,
BaseSerialTransport,
ModemPins,
Parity,
PinState,
SerialException,
StopBits,
get_uri_handler,
route_backend_kwargs,
)

LOGGER = logging.getLogger(__name__)
Expand All @@ -41,7 +43,7 @@ def __init__(
url: str | None,
*,
transport_cls: type[BaseSerialTransport] | None = None,
**kwargs: Any,
**kwargs: Unpack[AllConnectKwargs],
) -> None:
"""Initialize an unopened serial port.

Expand All @@ -50,7 +52,7 @@ def __init__(
"""

self._url = url
self._connect_kwargs: dict[str, Any] = kwargs
self._connect_kwargs: dict[str, Any] = dict(kwargs)
self._transport_cls = transport_cls

self._reader: asyncio.StreamReader | None = None
Expand Down Expand Up @@ -282,6 +284,7 @@ async def create_serial_connection(
None, get_uri_handler, url
)
resolved_cls = handler.async_transport_cls
kwargs = route_backend_kwargs(handler, kwargs) # pylint: disable=serialx-reassigned-parameter

protocol = protocol_factory()
transport = resolved_cls(loop=loop, protocol=protocol)
Expand Down Expand Up @@ -322,7 +325,7 @@ def async_serial_for_url(
url: str | None,
*,
transport_cls: type[BaseSerialTransport] | None = None,
**kwargs: Any,
**kwargs: Unpack[AllConnectKwargs],
) -> AsyncSerial:
"""Build an unopened AsyncSerial. Use `async with` or `await serial.open()`."""
return AsyncSerial(url, transport_cls=transport_cls, **kwargs)
129 changes: 118 additions & 11 deletions serialx/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,22 @@
from enum import Enum
import functools
import io
import logging
import os.path
from pathlib import Path
import time
from types import TracebackType
from typing import Any, Concatenate, NamedTuple, ParamSpec, TypeVar, cast
from typing import TYPE_CHECKING, Any, Concatenate, NamedTuple, ParamSpec, TypeVar, cast
import urllib.parse
import warnings

from typing_extensions import Buffer, Self, TypedDict, Unpack

if TYPE_CHECKING:
from aioesphomeapi.client import APIClient

LOGGER = logging.getLogger(__name__)


class Platform(str, Enum):
"""Built-in platform name."""
Expand Down Expand Up @@ -52,6 +58,7 @@ class RegisteredUriHandler:
list_serial_ports_func: Callable[..., list[SerialPortInfo]]
async_list_serial_ports_func: Callable[..., Awaitable[list[SerialPortInfo]]]
strip_uri_scheme: bool
connect_kwargs: frozenset[str] = frozenset()


class _RegistryEntry(NamedTuple):
Expand Down Expand Up @@ -87,6 +94,7 @@ def register_uri_handler(
] = async_empty_port_list,
weight: int = 1,
strip_uri_scheme: bool = False,
connect_kwargs: frozenset[str] = frozenset(),
) -> Callable[[], None]:
"""Register a URI handler.

Expand All @@ -110,6 +118,7 @@ def register_uri_handler(
strip_uri_scheme: If ``True``, the leading ``scheme`` / ``unique_scheme``
is removed before the URL is passed to the sync class. Set this when
the underlying class expects a bare device path rather than a URL.
connect_kwargs: Names of backend-specific connect kwargs this handler accepts.

Returns:
A callable that unregisters the handler.
Expand Down Expand Up @@ -140,6 +149,7 @@ def register_uri_handler(
list_serial_ports_func=list_serial_ports_func,
async_list_serial_ports_func=async_list_serial_ports_func,
strip_uri_scheme=strip_uri_scheme,
connect_kwargs=connect_kwargs,
),
)
bisect.insort_right(_REGISTERED_URI_HANDLERS[scheme], item)
Expand All @@ -165,6 +175,36 @@ def get_uri_handler(uri: str) -> RegisteredUriHandler:
return handlers[-1].handler


def route_backend_kwargs(
handler: RegisteredUriHandler, kwargs: dict[str, Any]
) -> dict[str, Any]:
"""Drop kwargs that belong to a different backend before dispatch."""
all_backend_specific_kwargs: set[str] = set()

for extras in BACKEND_CONNECT_KWARGS.values():
all_backend_specific_kwargs |= extras

for entries in _REGISTERED_URI_HANDLERS.values():
for entry in entries:
all_backend_specific_kwargs |= entry.handler.connect_kwargs

backend_specific_kwargs = (
BACKEND_CONNECT_KWARGS.get(handler.unique_scheme, set())
| handler.connect_kwargs
)

other_kwargs = all_backend_specific_kwargs - backend_specific_kwargs
dropped = other_kwargs & kwargs.keys()

if not dropped:
return dict(kwargs)

LOGGER.debug(
"Ignoring kwarg not accepted by %r backend: %s", handler.unique_scheme, dropped
)
return {key: value for key, value in kwargs.items() if key not in dropped}


class SerialException(Exception):
"""Base serial exception."""

Expand Down Expand Up @@ -195,18 +235,80 @@ class Parity(str, Enum):
SPACE = "S"


class ConnectKwargs( # type: ignore[call-arg] # PEP 728 not in mypy yet
TypedDict, total=False, extra_items=Any
):
"""Kwargs forwarded to BaseSerialTransport.connect / _connect."""
class _CommonConnectKwargs(TypedDict, total=False):
"""Connect kwargs accepted by every backend (see `BaseSerial.__init__`)."""

baudrate: int
parity: Parity
stopbits: StopBits
parity: Parity | str | None
stopbits: StopBits | int | float
xonxoff: bool
rtscts: bool
exclusive: bool
dsrdtr: bool
byte_size: int
read_timeout: float | None
write_timeout: float | None
rtsdtr_on_open: PinState
rtsdtr_on_close: PinState
exclusive: bool

# pyserial compatibility kwargs
port: str | None
timeout: float | None
bytesize: int | None
writeTimeout: float | None
do_not_open: bool | None


class ConnectKwargs( # type: ignore[call-arg] # PEP 728 not in mypy yet
_CommonConnectKwargs, total=False, extra_items=Any
):
"""Connect kwargs plumbed internally to `BaseSerialTransport.connect`."""


class AllConnectKwargs(_CommonConnectKwargs, total=False):
"""Every connect kwarg any built-in backend accepts, for typing."""

# linux://
low_latency: bool

# socket:// + tcp:// + rfc2217:// + esphome://
connect_timeout: float | None

# rfc2217://
receive_buffer_size: int

# windows://
read_buffer_size: int
write_buffer_size: int

# esphome://
api: APIClient | None
port_name: str | None
port_instance: int | None
key: str | None
password: str | None
noise_psk: str | None


# Backend-specific connect kwargs per unique URI scheme
BACKEND_CONNECT_KWARGS: dict[str, frozenset[str]] = {
"linux://": frozenset({"low_latency"}),
"windows://": frozenset({"read_buffer_size", "write_buffer_size"}),
"socket://": frozenset({"connect_timeout"}),
"tcp://": frozenset({"connect_timeout"}),
"rfc2217://": frozenset({"connect_timeout", "receive_buffer_size"}),
"esphome://": frozenset(
{
"api",
"connect_timeout",
"port_name",
"port_instance",
"key",
"password",
"noise_psk",
}
),
}


class PinState(Enum):
Expand Down Expand Up @@ -397,15 +499,18 @@ def _check_broken(self) -> None:
raise self._broken

@classmethod
def from_url(cls, url: str, *args: Any, **kwargs: Any) -> BaseSerial:
def from_url(
cls, url: str, *args: Any, **kwargs: Unpack[AllConnectKwargs]
) -> BaseSerial:
"""Create the appropriate serial port subclass for the given URL."""
handler = get_uri_handler(url)
target = url
if handler.strip_uri_scheme:
target = url.removeprefix(handler.scheme).removeprefix(
handler.unique_scheme
)
return handler.sync_cls(target, *args, **kwargs)
routed = route_backend_kwargs(handler, dict(kwargs))
return handler.sync_cls(target, *args, **routed)

@maybe_wrap_exceptions
def open(self) -> None:
Expand Down Expand Up @@ -1165,6 +1270,8 @@ async def async_list_serial_ports(
return await handler.async_list_serial_ports_func(**kwargs)


def serial_for_url(url: str, *args: Any, **kwargs: Any) -> BaseSerial:
def serial_for_url(
url: str, *args: Any, **kwargs: Unpack[AllConnectKwargs]
) -> BaseSerial:
"""Create the appropriate serial port subclass for the given URL."""
return BaseSerial.from_url(url, *args, **kwargs)
2 changes: 1 addition & 1 deletion serialx_compat/serial/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def __init__(
@classmethod
def from_url(cls, url: str, *args: Any, **kwargs: Any) -> BaseSerial:
"""Create the appropriate serial port subclass for the given URL."""
return super().from_url(url, *args, _wrap_exceptions=True, **kwargs)
return super().from_url(url, *args, _wrap_exceptions=True, **kwargs) # type:ignore[call-arg]


Serial = CompatSerial
Expand Down
Loading
Loading