diff --git a/poetry.lock b/poetry.lock index 04bd68a..669dfa2 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,5 +1,20 @@ # This file is automatically @generated by Poetry 2.4.1 and should not be changed by hand. +[[package]] +name = "aiomqtt" +version = "2.5.1" +description = "The idiomatic asyncio MQTT client" +optional = false +python-versions = "<4.0,>=3.8" +groups = ["main"] +files = [ + {file = "aiomqtt-2.5.1-py3-none-any.whl", hash = "sha256:fd58c3593160e4d475d90ce911cdfc4239cd64de96b0ba22edf6c86bd7afa278"}, + {file = "aiomqtt-2.5.1.tar.gz", hash = "sha256:25a0a47d157e8f158d2da1110ea4786c0615518751e94f7b04976c977a8ff20d"}, +] + +[package.dependencies] +paho-mqtt = ">=2.1.0,<3.0.0" + [[package]] name = "anyio" version = "4.12.1" @@ -270,21 +285,6 @@ files = [ {file = "filelock-3.29.0.tar.gz", hash = "sha256:69974355e960702e789734cb4871f884ea6fe50bd8404051a3530bc07809cf90"}, ] -[[package]] -name = "gmqtt" -version = "0.7.0" -description = "Client for MQTT protocol" -optional = false -python-versions = ">=3.5" -groups = ["main"] -files = [ - {file = "gmqtt-0.7.0-py3-none-any.whl", hash = "sha256:3e5571a20e9c115d83d600caa228b06f716087653e241035e29cec73277b52cc"}, - {file = "gmqtt-0.7.0.tar.gz", hash = "sha256:bedfec7bac26b6b4ce1f0c4c32cff3d663526a54c882d323d41560fc3b9b44a2"}, -] - -[package.extras] -test = ["atomicwrites (>=1.3.0)", "attrs (>=19.1.0)", "codecov (>=2.0.15)", "coverage (>=4.5.3)", "more-itertools (>=7.0.0)", "pluggy (>=0.11.0)", "py (>=1.8.0)", "pytest (>=5.4.0)", "pytest-asyncio (>=0.12.0)", "pytest-cov (>=2.7.1)", "six (>=1.12.0)", "uvloop (>=0.14.0)"] - [[package]] name = "h11" version = "0.16.0" @@ -641,6 +641,21 @@ files = [ {file = "packaging-26.0.tar.gz", hash = "sha256:00243ae351a257117b6a241061796684b084ed1c516a08c48a3f7e147a9d80b4"}, ] +[[package]] +name = "paho-mqtt" +version = "2.1.0" +description = "MQTT version 5.0/3.1.1 client class" +optional = false +python-versions = ">=3.7" +groups = ["main"] +files = [ + {file = "paho_mqtt-2.1.0-py3-none-any.whl", hash = "sha256:6db9ba9b34ed5bc6b6e3812718c7e06e2fd7444540df2455d2c51bd58808feee"}, + {file = "paho_mqtt-2.1.0.tar.gz", hash = "sha256:12d6e7511d4137555a3f6ea167ae846af2c7357b10bc6fa4f7c3968fc1723834"}, +] + +[package.extras] +proxy = ["pysocks"] + [[package]] name = "pathspec" version = "1.0.4" @@ -1134,4 +1149,4 @@ python-discovery = ">=1.2.2" [metadata] lock-version = "2.1" python-versions = ">=3.12,<4.0" -content-hash = "ba4f444f330a6cb58a06926a1c2701c544ffeaeda02107641361b7abaa61e4af" +content-hash = "2f857bf23052a67c88f8f78ebf3698a9327bfc3333cc22984bf8ed1bf127c4bf" diff --git a/pyproject.toml b/pyproject.toml index afda254..8f35e6a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,10 +17,10 @@ requires-python = '>=3.12,<4.0' dependencies = [ "saic-ismart-client-ng (>=0.9.3,<0.10.0)", 'httpx (>=0.28.1,<0.29.0)', - 'gmqtt (>=0.7.0,<0.8.0)', 'inflection (>=0.5.1,<0.6.0)', 'apscheduler (>=3.11.0,<4.0.0)', 'python-dotenv (>=1.1.1,<2.0.0)', + "aiomqtt (>=2.4.0,<3.0.0)", ] [project.urls] diff --git a/src/configuration/__init__.py b/src/configuration/__init__.py index 7173f37..63ca693 100644 --- a/src/configuration/__init__.py +++ b/src/configuration/__init__.py @@ -1,7 +1,7 @@ from __future__ import annotations from enum import Enum -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Literal if TYPE_CHECKING: from zoneinfo import ZoneInfo @@ -9,8 +9,11 @@ from integrations.openwb.charging_station import ChargingStation +Transport = Literal["tcp", "websockets"] + + class TransportProtocol(Enum): - def __init__(self, transport_mechanism: str, with_tls: bool) -> None: + def __init__(self, transport_mechanism: Transport, with_tls: bool) -> None: self.transport_mechanism = transport_mechanism self.with_tls = with_tls diff --git a/src/configuration/parser.py b/src/configuration/parser.py index 99456ab..304c295 100644 --- a/src/configuration/parser.py +++ b/src/configuration/parser.py @@ -105,15 +105,18 @@ def __parse_mqtt_transport(args: Namespace, config: Configuration) -> None: args.tls_server_cert_check_hostname ) else: - msg = f"Invalid MQTT URI scheme: {parse_result.scheme}, use tcp or ws" + msg = f"Invalid MQTT URI scheme: {parse_result.scheme}, use tls, tcp or ws" raise SystemExit(msg) if parse_result.port: config.mqtt_port = parse_result.port - elif config.mqtt_transport_protocol == TransportProtocol.TCP: - config.mqtt_port = 1883 - else: + elif config.mqtt_transport_protocol == TransportProtocol.TLS: + config.mqtt_port = 8883 + elif config.mqtt_transport_protocol == TransportProtocol.WS: config.mqtt_port = 9001 + else: + # fallback to default mqtt port + config.mqtt_port = 1883 config.mqtt_host = str(parse_result.hostname) diff --git a/src/log_config.py b/src/log_config.py index 9c96cae..2a49e3a 100644 --- a/src/log_config.py +++ b/src/log_config.py @@ -7,7 +7,7 @@ MODULES_DEFAULT_LOG_LEVEL = { "asyncio": "WARNING", - "gmqtt": "WARNING", + "aiomqtt": "WARNING", "httpcore": "WARNING", "httpx": "WARNING", "saic_ismart_client_ng": "WARNING", diff --git a/src/main.py b/src/main.py index dc47e09..e2d90d9 100644 --- a/src/main.py +++ b/src/main.py @@ -27,4 +27,5 @@ configuration = process_command_line() mqtt_gateway = MqttGateway(configuration) + asyncio.run(mqtt_gateway.run(), debug=debug_log_enabled()) diff --git a/src/publisher/core.py b/src/publisher/core.py index 6e7e926..2ee2365 100644 --- a/src/publisher/core.py +++ b/src/publisher/core.py @@ -102,30 +102,55 @@ def publish_json( no_prefix: bool = False, *, retain: bool = True, + qos: int = 0, ) -> None: raise NotImplementedError @abstractmethod def publish_str( - self, key: str, value: str, no_prefix: bool = False, *, retain: bool = True + self, + key: str, + value: str, + no_prefix: bool = False, + *, + retain: bool = True, + qos: int = 0, ) -> None: raise NotImplementedError @abstractmethod def publish_int( - self, key: str, value: int, no_prefix: bool = False, *, retain: bool = True + self, + key: str, + value: int, + no_prefix: bool = False, + *, + retain: bool = True, + qos: int = 0, ) -> None: raise NotImplementedError @abstractmethod def publish_bool( - self, key: str, value: bool, no_prefix: bool = False, *, retain: bool = True + self, + key: str, + value: bool, + no_prefix: bool = False, + *, + retain: bool = True, + qos: int = 0, ) -> None: raise NotImplementedError @abstractmethod def publish_float( - self, key: str, value: float, no_prefix: bool = False, *, retain: bool = True + self, + key: str, + value: float, + no_prefix: bool = False, + *, + retain: bool = True, + qos: int = 0, ) -> None: raise NotImplementedError @@ -173,7 +198,7 @@ def publish( raise TypeError(msg) @abstractmethod - def clear_topic(self, key: str, no_prefix: bool = False) -> None: + def clear_topic(self, key: str, no_prefix: bool = False, qos: int = 0) -> None: raise NotImplementedError def get_mqtt_account_prefix(self) -> str: @@ -249,7 +274,9 @@ def __anonymize(self, data: T) -> T: return data def keepalive(self) -> None: - self.publish_str(mqtt_topics.INTERNAL_LWT, "online", False) + self.publish_str( + mqtt_topics.INTERNAL_LWT, "online", no_prefix=False, retain=True, qos=1 + ) @staticmethod def anonymize_str(value: str) -> str: diff --git a/src/publisher/log_publisher.py b/src/publisher/log_publisher.py index 7969c61..f05866a 100644 --- a/src/publisher/log_publisher.py +++ b/src/publisher/log_publisher.py @@ -30,36 +30,61 @@ def publish_json( no_prefix: bool = False, *, retain: bool = True, + qos: int = 0, ) -> None: anonymized_json = self.dict_to_anonymized_json(data) self.internal_publish(key, anonymized_json, retain=retain) @override def publish_str( - self, key: str, value: str, no_prefix: bool = False, *, retain: bool = True + self, + key: str, + value: str, + no_prefix: bool = False, + *, + retain: bool = True, + qos: int = 0, ) -> None: self.internal_publish(key, value, retain=retain) @override def publish_int( - self, key: str, value: int, no_prefix: bool = False, *, retain: bool = True + self, + key: str, + value: int, + no_prefix: bool = False, + *, + retain: bool = True, + qos: int = 0, ) -> None: self.internal_publish(key, value, retain=retain) @override def publish_bool( - self, key: str, value: bool, no_prefix: bool = False, *, retain: bool = True + self, + key: str, + value: bool, + no_prefix: bool = False, + *, + retain: bool = True, + qos: int = 0, ) -> None: self.internal_publish(key, value, retain=retain) @override def publish_float( - self, key: str, value: float, no_prefix: bool = False, *, retain: bool = True + self, + key: str, + value: float, + no_prefix: bool = False, + *, + retain: bool = True, + qos: int = 0, ) -> None: self.internal_publish(key, value, retain=retain) @override - def clear_topic(self, key: str, no_prefix: bool = False) -> None: + def clear_topic(self, key: str, no_prefix: bool = False, qos: int = 0) -> None: self.internal_publish(key, None) def internal_publish( diff --git a/src/publisher/mqtt_publisher.py b/src/publisher/mqtt_publisher.py index ed535d4..7f4b71a 100644 --- a/src/publisher/mqtt_publisher.py +++ b/src/publisher/mqtt_publisher.py @@ -1,11 +1,13 @@ from __future__ import annotations +import asyncio import logging import math import ssl -from typing import TYPE_CHECKING, Any, Final, cast, override +from typing import TYPE_CHECKING, Any, override -import gmqtt +import aiomqtt +from aiomqtt.exceptions import MqttConnectError import mqtt_topics from publisher.core import Publisher @@ -17,9 +19,28 @@ LOG = logging.getLogger(__name__) +# Reconnect backoff: starts at 5 s, doubles on each failure, caps at 5 min +_RECONNECT_INTERVAL_MIN = 5 +_RECONNECT_INTERVAL_MAX = 300 + +# MQTT 3.1.1 spec section 3.2.2.3 — permanent connection refusal codes +_CONNACK_REFUSED_PROTOCOL_VERSION = 1 +_CONNACK_REFUSED_IDENTIFIER_REJECTED = 2 +_CONNACK_REFUSED_BAD_CREDENTIALS = 4 +_CONNACK_REFUSED_NOT_AUTHORIZED = 5 +_FATAL_CONNECT_RC = { + _CONNACK_REFUSED_PROTOCOL_VERSION, + _CONNACK_REFUSED_IDENTIFIER_REJECTED, + _CONNACK_REFUSED_BAD_CREDENTIALS, + _CONNACK_REFUSED_NOT_AUTHORIZED, +} + class MqttPublisher(Publisher): - def __init__(self, configuration: Configuration) -> None: + def __init__( + self, + configuration: Configuration, + ) -> None: super().__init__(configuration) self.publisher_id = configuration.mqtt_client_id self.host = self.configuration.mqtt_host @@ -30,129 +51,196 @@ def __init__(self, configuration: Configuration) -> None: self.vin_by_charger_connected_topic: dict[str, str] = {} self.vin_by_imported_energy_topic: dict[str, str] = {} self.first_connection = True + self.client: None | aiomqtt.Client = None + self.__running: asyncio.Task[None] | None = None + self.__connected = asyncio.Event() + self.__fatal_connect_error: SystemExit | None = None - mqtt_client = gmqtt.Client( - client_id=str(self.publisher_id), - transport=self.transport_protocol.transport_mechanism, - will_message=gmqtt.Message( - topic=self.get_topic(mqtt_topics.INTERNAL_LWT, False), - payload="offline", - retain=True, - ), - ) - mqtt_client.on_connect = self.__on_connect - mqtt_client.on_message = self.__on_message - self.client: Final[gmqtt.Client] = mqtt_client - - @override - async def connect(self) -> None: - if self.configuration.mqtt_user is not None: - if self.configuration.mqtt_password is not None: - self.client.set_auth_credentials( - username=self.configuration.mqtt_user, - password=self.configuration.mqtt_password, - ) - else: - self.client.set_auth_credentials(username=self.configuration.mqtt_user) - + async def __run_loop(self) -> None: + if not self.host: + LOG.info("MQTT host is not configured") + return + ssl_context: ssl.SSLContext | None = None if self.transport_protocol.with_tls: ssl_context = ssl.create_default_context() - cert_uri = self.configuration.tls_server_cert_path - if cert_uri: - LOG.debug(f"Using custom CA file {cert_uri}") - ssl_context.load_verify_locations(cafile=cert_uri) + if self.configuration.tls_server_cert_path: + LOG.debug( + f"Using custom CA file {self.configuration.tls_server_cert_path}" + ) + ssl_context.load_verify_locations( + cafile=self.configuration.tls_server_cert_path + ) if not self.configuration.tls_server_cert_check_hostname: LOG.warning( f"Skipping hostname check for TLS connection to {self.host}" ) - ssl_context.check_hostname = False - else: - ssl_context = None - await self.client.connect( - host=self.host, + + client = aiomqtt.Client( + hostname=self.host, port=self.port, - version=gmqtt.constants.MQTTv311, - ssl=ssl_context, + identifier=str(self.publisher_id), + transport=self.transport_protocol.transport_mechanism, + username=self.configuration.mqtt_user or None, + password=self.configuration.mqtt_password or None, + clean_session=True, + tls_context=ssl_context, + tls_insecure=bool( + ssl_context and not self.configuration.tls_server_cert_check_hostname + ), + will=aiomqtt.Will( + topic=self.get_topic(mqtt_topics.INTERNAL_LWT, False), + payload="offline", + retain=True, + qos=1, + ), ) - - def __on_connect( - self, _client: Any, _flags: Any, rc: int, _properties: Any - ) -> None: - if rc == gmqtt.constants.CONNACK_ACCEPTED: - LOG.info("Connected to MQTT broker") - if not self.first_connection: - self.enable_commands() - if self.command_listener is not None: - self.command_listener.on_mqtt_reconnected() - self.first_connection = False - self.keepalive() - else: - if rc == gmqtt.constants.CONNACK_REFUSED_BAD_USERNAME_PASSWORD: - LOG.error( - f"MQTT connection error: bad username or password. Return code {rc}" + client.pending_calls_threshold = 150 + reconnect_interval = _RECONNECT_INTERVAL_MIN + while True: + try: + LOG.debug( + "Connecting to %s:%s as %s", + self.host, + self.port, + self.publisher_id, ) - elif rc == gmqtt.constants.CONNACK_REFUSED_PROTOCOL_VERSION: - LOG.error( - f"MQTT connection error: refused protocol version. Return code {rc}" + async with client as client_context: + self.client = client_context + self.__connected.set() + await self.__on_connect() + reconnect_interval = _RECONNECT_INTERVAL_MIN + async for message in client_context.messages: + await self._on_message( + client_context, + str(message.topic), + message.payload, + message.qos, + message.retain, + ) + except MqttConnectError as e: + # Permanent rejections — retrying won't help. + # rc 3 (server unavailable) is transient and falls through to reconnect. + # ReasonCode.__eq__ supports int comparison, so no isinstance guard needed. + if e.rc in _FATAL_CONNECT_RC: + LOG.error("MQTT connection permanently refused: %s", e) + self.__fatal_connect_error = SystemExit(str(e)) + self.__connected.set() + return + LOG.warning( + "Connection to %s:%s refused (transient); Reconnecting in %d seconds ...", + self.host, + self.port, + reconnect_interval, ) - else: - LOG.error(f"MQTT connection error.Return code {rc}") - msg = f"Unable to connect to MQTT broker. Return code: {rc}" - raise SystemExit(msg) + await asyncio.sleep(reconnect_interval) + reconnect_interval = min(reconnect_interval * 2, _RECONNECT_INTERVAL_MAX) + except aiomqtt.MqttError: + LOG.warning( + "Connection to %s:%s lost; Reconnecting in %d seconds ...", + self.host, + self.port, + reconnect_interval, + ) + await asyncio.sleep(reconnect_interval) + reconnect_interval = min(reconnect_interval * 2, _RECONNECT_INTERVAL_MAX) + except asyncio.exceptions.CancelledError: + LOG.debug("MQTT publisher loop cancelled") + raise + finally: + self.__connected.clear() + LOG.info("MQTT client disconnected") + + @override + async def connect(self) -> None: + if self.__running and not self.__running.done(): + LOG.warning("MQTT client is already running") + return + if not self.host: + LOG.info("MQTT host is not configured") + return + self.__running = asyncio.create_task(self.__run_loop()) + await self.__connected.wait() + if self.__fatal_connect_error is not None: + raise self.__fatal_connect_error + + async def __on_connect(self) -> None: + LOG.info("Connected to MQTT broker") + if not self.first_connection: + await self.__enable_commands() + if self.command_listener is not None: + self.command_listener.on_mqtt_reconnected() + self.first_connection = False + self.keepalive() @override def enable_commands(self) -> None: - LOG.info("Subscribing to MQTT command topics") - mqtt_account_prefix = self.get_mqtt_account_prefix() - self.client.subscribe( - f"{mqtt_account_prefix}/{mqtt_topics.VEHICLES}/+/+/+/{mqtt_topics.SET_SUFFIX}" - ) - self.client.subscribe( - f"{mqtt_account_prefix}/{mqtt_topics.VEHICLES}/+/+/+/+/{mqtt_topics.SET_SUFFIX}" - ) - self.client.subscribe( - f"{mqtt_account_prefix}/{mqtt_topics.VEHICLES}/+/{mqtt_topics.REFRESH_MODE}/{mqtt_topics.SET_SUFFIX}" - ) - self.client.subscribe( - f"{mqtt_account_prefix}/{mqtt_topics.VEHICLES}/+/{mqtt_topics.REFRESH_PERIOD}/+/{mqtt_topics.SET_SUFFIX}" - ) - for charging_station in self.configuration.charging_stations_by_vin.values(): - LOG.debug( - f"Subscribing to MQTT topic {charging_station.charge_state_topic}" + task = asyncio.get_running_loop().create_task(self.__enable_commands()) + task.add_done_callback(self.__on_enable_commands_done) + + def __on_enable_commands_done(self, task: asyncio.Task[None]) -> None: + if not task.cancelled() and (exc := task.exception()): + LOG.error("Failed to enable MQTT command subscriptions: %s", exc) + + async def __enable_commands(self) -> None: + if not self.__connected.is_set() or not self.client: + LOG.error("Failed to enable commands: MQTT client is not connected") + return + try: + LOG.info("Subscribing to MQTT command topics") + mqtt_account_prefix = self.get_mqtt_account_prefix() + await self.client.subscribe( + f"{mqtt_account_prefix}/{mqtt_topics.VEHICLES}/+/+/+/{mqtt_topics.SET_SUFFIX}" + ) + await self.client.subscribe( + f"{mqtt_account_prefix}/{mqtt_topics.VEHICLES}/+/+/+/+/{mqtt_topics.SET_SUFFIX}" ) - self.vin_by_charge_state_topic[charging_station.charge_state_topic] = ( - charging_station.vin + await self.client.subscribe( + f"{mqtt_account_prefix}/{mqtt_topics.VEHICLES}/+/{mqtt_topics.REFRESH_MODE}/{mqtt_topics.SET_SUFFIX}" ) - self.client.subscribe(charging_station.charge_state_topic) - if charging_station.connected_topic: + await self.client.subscribe( + f"{mqtt_account_prefix}/{mqtt_topics.VEHICLES}/+/{mqtt_topics.REFRESH_PERIOD}/+/{mqtt_topics.SET_SUFFIX}" + ) + for ( + charging_station + ) in self.configuration.charging_stations_by_vin.values(): LOG.debug( - f"Subscribing to MQTT topic {charging_station.connected_topic}" + f"Subscribing to MQTT topic {charging_station.charge_state_topic}" ) - self.vin_by_charger_connected_topic[ - charging_station.connected_topic - ] = charging_station.vin - self.client.subscribe(charging_station.connected_topic) - if charging_station.imported_energy_topic: - LOG.debug( - f"Subscribing to MQTT topic {charging_station.imported_energy_topic}" + self.vin_by_charge_state_topic[charging_station.charge_state_topic] = ( + charging_station.vin ) - self.vin_by_imported_energy_topic[ - charging_station.imported_energy_topic - ] = charging_station.vin - self.client.subscribe(charging_station.imported_energy_topic) - if self.configuration.ha_discovery_enabled: - # enable dynamic discovery pushing in case ha reconnects - self.client.subscribe(self.configuration.ha_lwt_topic) + await self.client.subscribe(charging_station.charge_state_topic) + if charging_station.connected_topic: + LOG.debug( + f"Subscribing to MQTT topic {charging_station.connected_topic}" + ) + self.vin_by_charger_connected_topic[ + charging_station.connected_topic + ] = charging_station.vin + await self.client.subscribe(charging_station.connected_topic) + if charging_station.imported_energy_topic: + LOG.debug( + f"Subscribing to MQTT topic {charging_station.imported_energy_topic}" + ) + self.vin_by_imported_energy_topic[ + charging_station.imported_energy_topic + ] = charging_station.vin + await self.client.subscribe(charging_station.imported_energy_topic) + if self.configuration.ha_discovery_enabled: + # enable dynamic discovery pushing in case ha reconnects + await self.client.subscribe(self.configuration.ha_lwt_topic) + except aiomqtt.MqttError as e: + LOG.error(f"Failed to subscribe to MQTT command topics: {e}") + raise e - async def __on_message( - self, _client: Any, topic: str, payload: Any, _qos: Any, _properties: Any + async def _on_message( + self, _client: Any, topic: str, payload: Any, _qos: Any, retained: bool ) -> None: try: if isinstance(payload, bytes): payload = payload.decode("utf-8") else: payload = str(payload) - retained = bool(_properties.get("retain", 0)) await self.__on_message_real( topic=topic, payload=payload, retained=retained ) @@ -228,13 +316,34 @@ async def __handle_imported_energy(self, topic: str, payload: str) -> None: ) def __publish( - self, topic: str, payload: WirePayload | None, *, retain: bool = True + self, + topic: str, + payload: WirePayload | None, + *, + retain: bool = True, + qos: int = 0, ) -> None: - self.client.publish(topic, payload, retain=retain) + LOG.debug("Publishing to MQTT topic %s with payload %s", topic, payload) + asyncio.get_running_loop().create_task( + self.__async_publish(topic, payload, retain=retain, qos=qos) + ) + + async def __async_publish( + self, topic: str, payload: Any, retain: bool, qos: int + ) -> None: + if not (self.client and self.is_connected()): + LOG.error("Failed to publish: MQTT client is not connected") + return + try: + await self.client.publish(topic, payload, retain=retain, qos=qos) + except aiomqtt.MqttError as e: + LOG.error( + f"Failed to publish to MQTT topic {topic} with payload {payload}: {e}" + ) @override def is_connected(self) -> bool: - return cast("bool", self.client.is_connected) + return self.__connected.is_set() @override def publish_json( @@ -244,47 +353,75 @@ def publish_json( no_prefix: bool = False, *, retain: bool = True, + qos: int = 0, ) -> None: payload = self.dict_to_anonymized_json(data) self.__publish( - topic=self.get_topic(key, no_prefix), payload=payload, retain=retain + topic=self.get_topic(key, no_prefix), + payload=payload, + retain=retain, + qos=qos, ) @override def publish_str( - self, key: str, value: str, no_prefix: bool = False, *, retain: bool = True + self, + key: str, + value: str, + no_prefix: bool = False, + *, + retain: bool = True, + qos: int = 0, ) -> None: self.__publish( - topic=self.get_topic(key, no_prefix), payload=value, retain=retain + topic=self.get_topic(key, no_prefix), payload=value, retain=retain, qos=qos ) @override def publish_int( - self, key: str, value: int, no_prefix: bool = False, *, retain: bool = True + self, + key: str, + value: int, + no_prefix: bool = False, + *, + retain: bool = True, + qos: int = 0, ) -> None: self.__publish( - topic=self.get_topic(key, no_prefix), payload=value, retain=retain + topic=self.get_topic(key, no_prefix), payload=value, retain=retain, qos=qos ) @override def publish_bool( - self, key: str, value: bool, no_prefix: bool = False, *, retain: bool = True + self, + key: str, + value: bool, + no_prefix: bool = False, + *, + retain: bool = True, + qos: int = 0, ) -> None: self.__publish( - topic=self.get_topic(key, no_prefix), payload=value, retain=retain + topic=self.get_topic(key, no_prefix), payload=value, retain=retain, qos=qos ) @override def publish_float( - self, key: str, value: float, no_prefix: bool = False, *, retain: bool = True + self, + key: str, + value: float, + no_prefix: bool = False, + *, + retain: bool = True, + qos: int = 0, ) -> None: self.__publish( - topic=self.get_topic(key, no_prefix), payload=value, retain=retain + topic=self.get_topic(key, no_prefix), payload=value, retain=retain, qos=qos ) @override - def clear_topic(self, key: str, no_prefix: bool = False) -> None: - self.__publish(topic=self.get_topic(key, no_prefix), payload=None) + def clear_topic(self, key: str, no_prefix: bool = False, qos: int = 0) -> None: + self.__publish(topic=self.get_topic(key, no_prefix), payload=None, qos=qos) def get_vin_from_topic(self, topic: str) -> str: global_topic_removed = topic[len(self.configuration.mqtt_topic) + 1 :] diff --git a/src/status_publisher/charge/chrg_mgmt_data.py b/src/status_publisher/charge/chrg_mgmt_data.py index b0444fc..5825590 100644 --- a/src/status_publisher/charge/chrg_mgmt_data.py +++ b/src/status_publisher/charge/chrg_mgmt_data.py @@ -104,18 +104,18 @@ def publish(self, charge_mgmt_data: ChrgMgmtData) -> ChrgMgmtDataProcessingResul self._transform_and_publish( topic=mqtt_topics.BMS_CHARGE_STATUS, value=charge_mgmt_data.bms_charging_status, - transform=lambda x: f"UNKNOWN {charge_mgmt_data.bmsChrgSts}" - if x is None - else x.name, + transform=lambda x: ( + f"UNKNOWN {charge_mgmt_data.bmsChrgSts}" if x is None else x.name + ), ) self._transform_and_publish( topic=mqtt_topics.DRIVETRAIN_CHARGING_STOP_REASON, value=charge_mgmt_data.charging_stop_reason, validator=lambda x: x != ChargingStopReason.NO_REASON, - transform=lambda x: f"UNKNOWN {charge_mgmt_data.bmsChrgSpRsn}" - if x is None - else x.name, + transform=lambda x: ( + f"UNKNOWN {charge_mgmt_data.bmsChrgSpRsn}" if x is None else x.name + ), ) self._publish( @@ -153,9 +153,9 @@ def publish(self, charge_mgmt_data: ChrgMgmtData) -> ChrgMgmtDataProcessingResul topic=mqtt_topics.DRIVETRAIN_BATTERY_HEATING_STOP_REASON, value=charge_mgmt_data.heating_stop_reason, validator=lambda x: x != HeatingStopReason.NO_REASON, - transform=lambda x: f"UNKNOWN ({charge_mgmt_data.bmsPTCHeatResp})" - if x is None - else x.name, + transform=lambda x: ( + f"UNKNOWN ({charge_mgmt_data.bmsPTCHeatResp})" if x is None else x.name + ), ) self._transform_and_publish( diff --git a/tests/publisher/test_publish_dispatch.py b/tests/publisher/test_publish_dispatch.py index 7902771..192a574 100644 --- a/tests/publisher/test_publish_dispatch.py +++ b/tests/publisher/test_publish_dispatch.py @@ -299,35 +299,60 @@ def publish_json( no_prefix: bool = False, *, retain: bool = True, + qos: int = 0, ) -> None: pass @override def publish_str( - self, key: str, value: str, no_prefix: bool = False, *, retain: bool = True + self, + key: str, + value: str, + no_prefix: bool = False, + *, + retain: bool = True, + qos: int = 0, ) -> None: pass @override def publish_int( - self, key: str, value: int, no_prefix: bool = False, *, retain: bool = True + self, + key: str, + value: int, + no_prefix: bool = False, + *, + retain: bool = True, + qos: int = 0, ) -> None: pass @override def publish_bool( - self, key: str, value: bool, no_prefix: bool = False, *, retain: bool = True + self, + key: str, + value: bool, + no_prefix: bool = False, + *, + retain: bool = True, + qos: int = 0, ) -> None: pass @override def publish_float( - self, key: str, value: float, no_prefix: bool = False, *, retain: bool = True + self, + key: str, + value: float, + no_prefix: bool = False, + *, + retain: bool = True, + qos: int = 0, ) -> None: pass @override - def clear_topic(self, key: str, no_prefix: bool = False) -> None: + def clear_topic(self, key: str, no_prefix: bool = False, qos: int = 0) -> None: pass diff --git a/tests/test_mqtt_publisher.py b/tests/test_mqtt_publisher.py index 1fdeba3..d4bf19e 100644 --- a/tests/test_mqtt_publisher.py +++ b/tests/test_mqtt_publisher.py @@ -1,8 +1,10 @@ from __future__ import annotations +import asyncio +import json from typing import Any, override import unittest -from unittest.mock import patch +from unittest.mock import AsyncMock, patch from configuration import Configuration, TransportProtocol from publisher.core import MqttCommandListener @@ -71,7 +73,7 @@ async def test_update_rear_window_heat_state(self) -> None: assert self.received_payload == REAR_WINDOW_HEAT_STATE async def send_message(self, topic: str, payload: Any) -> None: - await self.mqtt_client.client.on_message("client", topic, payload, 0, {}) + await self.mqtt_client._on_message("client", topic, payload, 0, False) async def test_get_vin_from_sanitized_topic(self) -> None: """Topics arrive with the sanitized prefix, not the raw username.""" @@ -100,63 +102,72 @@ async def on_charger_connection_state_changed( ) -> None: pass - def test_publish_str_default_is_retained(self) -> None: - with patch.object(self.mqtt_client.client, "publish") as m_pub: - self.mqtt_client.publish_str("foo", "bar") - m_pub.assert_called_once_with("saic/foo", "bar", retain=True) - - def test_publish_str_forwards_retain_false(self) -> None: - with patch.object(self.mqtt_client.client, "publish") as m_pub: - self.mqtt_client.publish_str("foo", "bar", retain=False) - m_pub.assert_called_once_with("saic/foo", "bar", retain=False) - - def test_publish_int_default_is_retained(self) -> None: - with patch.object(self.mqtt_client.client, "publish") as m_pub: - self.mqtt_client.publish_int("foo", 42) - m_pub.assert_called_once_with("saic/foo", 42, retain=True) - - def test_publish_int_forwards_retain_false(self) -> None: - with patch.object(self.mqtt_client.client, "publish") as m_pub: - self.mqtt_client.publish_int("foo", 42, retain=False) - m_pub.assert_called_once_with("saic/foo", 42, retain=False) - - def test_publish_bool_default_is_retained(self) -> None: - with patch.object(self.mqtt_client.client, "publish") as m_pub: - self.mqtt_client.publish_bool("foo", True) - m_pub.assert_called_once_with("saic/foo", True, retain=True) - - def test_publish_bool_forwards_retain_false(self) -> None: - with patch.object(self.mqtt_client.client, "publish") as m_pub: - self.mqtt_client.publish_bool("foo", True, retain=False) - m_pub.assert_called_once_with("saic/foo", True, retain=False) - - def test_publish_float_default_is_retained(self) -> None: - with patch.object(self.mqtt_client.client, "publish") as m_pub: - self.mqtt_client.publish_float("foo", 1.5) - m_pub.assert_called_once_with("saic/foo", 1.5, retain=True) - - def test_publish_float_forwards_retain_false(self) -> None: - with patch.object(self.mqtt_client.client, "publish") as m_pub: - self.mqtt_client.publish_float("foo", 1.5, retain=False) - m_pub.assert_called_once_with("saic/foo", 1.5, retain=False) - - def test_publish_json_default_is_retained(self) -> None: - with patch.object(self.mqtt_client.client, "publish") as m_pub: - self.mqtt_client.publish_json("foo", {"a": 1}) - m_pub.assert_called_once() - args, kwargs = m_pub.call_args - assert args[0] == "saic/foo" - assert kwargs == {"retain": True} - - def test_publish_json_forwards_retain_false(self) -> None: - with patch.object(self.mqtt_client.client, "publish") as m_pub: - self.mqtt_client.publish_json("foo", {"a": 1}, retain=False) - m_pub.assert_called_once() - args, kwargs = m_pub.call_args - assert args[0] == "saic/foo" - assert kwargs == {"retain": False} - - def test_clear_topic_publishes_none_retained(self) -> None: - with patch.object(self.mqtt_client.client, "publish") as m_pub: - self.mqtt_client.clear_topic("foo") - m_pub.assert_called_once_with("saic/foo", None, retain=True) + async def _call_and_flush(self, fn: Any, *args: Any, **kwargs: Any) -> AsyncMock: + mock = AsyncMock() + with patch.object(self.mqtt_client, "_MqttPublisher__async_publish", mock): + self.mqtt_client._MqttPublisher__connected.set() # type: ignore[attr-defined] + fn(*args, **kwargs) + await asyncio.sleep(0) + return mock + + async def test_publish_str_default_is_retained(self) -> None: + m = await self._call_and_flush(self.mqtt_client.publish_str, "foo", "bar") + m.assert_called_once_with("saic/foo", "bar", retain=True, qos=0) + + async def test_publish_str_forwards_retain_false(self) -> None: + m = await self._call_and_flush( + self.mqtt_client.publish_str, "foo", "bar", retain=False + ) + m.assert_called_once_with("saic/foo", "bar", retain=False, qos=0) + + async def test_publish_int_default_is_retained(self) -> None: + m = await self._call_and_flush(self.mqtt_client.publish_int, "foo", 42) + m.assert_called_once_with("saic/foo", 42, retain=True, qos=0) + + async def test_publish_int_forwards_retain_false(self) -> None: + m = await self._call_and_flush( + self.mqtt_client.publish_int, "foo", 42, retain=False + ) + m.assert_called_once_with("saic/foo", 42, retain=False, qos=0) + + async def test_publish_bool_default_is_retained(self) -> None: + m = await self._call_and_flush(self.mqtt_client.publish_bool, "foo", True) + m.assert_called_once_with("saic/foo", True, retain=True, qos=0) + + async def test_publish_bool_forwards_retain_false(self) -> None: + m = await self._call_and_flush( + self.mqtt_client.publish_bool, "foo", True, retain=False + ) + m.assert_called_once_with("saic/foo", True, retain=False, qos=0) + + async def test_publish_float_default_is_retained(self) -> None: + m = await self._call_and_flush(self.mqtt_client.publish_float, "foo", 1.5) + m.assert_called_once_with("saic/foo", 1.5, retain=True, qos=0) + + async def test_publish_float_forwards_retain_false(self) -> None: + m = await self._call_and_flush( + self.mqtt_client.publish_float, "foo", 1.5, retain=False + ) + m.assert_called_once_with("saic/foo", 1.5, retain=False, qos=0) + + async def test_publish_json_default_is_retained(self) -> None: + m = await self._call_and_flush(self.mqtt_client.publish_json, "foo", {"a": 1}) + m.assert_called_once() + args, kwargs = m.call_args + assert args[0] == "saic/foo" + assert json.loads(args[1]) == {"a": 1} + assert kwargs == {"retain": True, "qos": 0} + + async def test_publish_json_forwards_retain_false(self) -> None: + m = await self._call_and_flush( + self.mqtt_client.publish_json, "foo", {"a": 1}, retain=False + ) + m.assert_called_once() + args, kwargs = m.call_args + assert args[0] == "saic/foo" + assert json.loads(args[1]) == {"a": 1} + assert kwargs == {"retain": False, "qos": 0} + + async def test_clear_topic_publishes_none_retained(self) -> None: + m = await self._call_and_flush(self.mqtt_client.clear_topic, "foo") + m.assert_called_once_with("saic/foo", None, retain=True, qos=0)