diff --git a/CHANGELOG.md b/CHANGELOG.md index 57439d1..cff18fc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,41 @@ # Change Log +## 0.13.0 + +### Changed + +* Replace gmqtt with aiomqtt for MQTT connectivity. The new client uses a + pure-asyncio architecture — no more `run_coroutine_threadsafe` cross-thread + calls. Reconnect now uses exponential backoff (5 s → 10 s → … → 5 min) + instead of a fixed interval (#457). + +### Fixed + +* Exit immediately on permanent MQTT connection failures (bad credentials, + unknown protocol version, identifier rejected, not authorised) instead of + looping forever (#457). + +* Restore MQTT subscriptions and trigger HA re-discovery on broker reconnect + after a transient disconnect (#457). + +* Suppress spurious `WARNING: Could not extract a valid SoC kWh` log on every + poll for non-EV vehicles. The warning is preserved for EVs where the BMS + returns invalid data (#460). + +* Publish command result topics (`Success` / `Failed: …`) with `retain=False` + so a stale result from a previous session does not persist on the broker + across restarts (#461). + +* Emit `retain` as a JSON boolean (`true`/`false`) in Home Assistant MQTT + discovery payloads instead of the string `"true"`/`"false"` that the HA + schema requires (#459). + +* Warn when TLS hostname verification is disabled regardless of whether a + custom CA certificate is configured. Self-signed cert users (who typically + have no CA file) now see the warning too (#462). + +**Full Changelog**: https://github.com/SAIC-iSmart-API/saic-python-mqtt-gateway/compare/0.12.0...0.13.0 + ## 0.12.0 ### Added diff --git a/poetry.lock b/poetry.lock index 04bd68a..669dfa2 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,5 +1,20 @@ # This file is automatically @generated by Poetry 2.4.1 and should not be changed by hand. +[[package]] +name = "aiomqtt" +version = "2.5.1" +description = "The idiomatic asyncio MQTT client" +optional = false +python-versions = "<4.0,>=3.8" +groups = ["main"] +files = [ + {file = "aiomqtt-2.5.1-py3-none-any.whl", hash = "sha256:fd58c3593160e4d475d90ce911cdfc4239cd64de96b0ba22edf6c86bd7afa278"}, + {file = "aiomqtt-2.5.1.tar.gz", hash = "sha256:25a0a47d157e8f158d2da1110ea4786c0615518751e94f7b04976c977a8ff20d"}, +] + +[package.dependencies] +paho-mqtt = ">=2.1.0,<3.0.0" + [[package]] name = "anyio" version = "4.12.1" @@ -270,21 +285,6 @@ files = [ {file = "filelock-3.29.0.tar.gz", hash = "sha256:69974355e960702e789734cb4871f884ea6fe50bd8404051a3530bc07809cf90"}, ] -[[package]] -name = "gmqtt" -version = "0.7.0" -description = "Client for MQTT protocol" -optional = false -python-versions = ">=3.5" -groups = ["main"] -files = [ - {file = "gmqtt-0.7.0-py3-none-any.whl", hash = "sha256:3e5571a20e9c115d83d600caa228b06f716087653e241035e29cec73277b52cc"}, - {file = "gmqtt-0.7.0.tar.gz", hash = "sha256:bedfec7bac26b6b4ce1f0c4c32cff3d663526a54c882d323d41560fc3b9b44a2"}, -] - -[package.extras] -test = ["atomicwrites (>=1.3.0)", "attrs (>=19.1.0)", "codecov (>=2.0.15)", "coverage (>=4.5.3)", "more-itertools (>=7.0.0)", "pluggy (>=0.11.0)", "py (>=1.8.0)", "pytest (>=5.4.0)", "pytest-asyncio (>=0.12.0)", "pytest-cov (>=2.7.1)", "six (>=1.12.0)", "uvloop (>=0.14.0)"] - [[package]] name = "h11" version = "0.16.0" @@ -641,6 +641,21 @@ files = [ {file = "packaging-26.0.tar.gz", hash = "sha256:00243ae351a257117b6a241061796684b084ed1c516a08c48a3f7e147a9d80b4"}, ] +[[package]] +name = "paho-mqtt" +version = "2.1.0" +description = "MQTT version 5.0/3.1.1 client class" +optional = false +python-versions = ">=3.7" +groups = ["main"] +files = [ + {file = "paho_mqtt-2.1.0-py3-none-any.whl", hash = "sha256:6db9ba9b34ed5bc6b6e3812718c7e06e2fd7444540df2455d2c51bd58808feee"}, + {file = "paho_mqtt-2.1.0.tar.gz", hash = "sha256:12d6e7511d4137555a3f6ea167ae846af2c7357b10bc6fa4f7c3968fc1723834"}, +] + +[package.extras] +proxy = ["pysocks"] + [[package]] name = "pathspec" version = "1.0.4" @@ -1134,4 +1149,4 @@ python-discovery = ">=1.2.2" [metadata] lock-version = "2.1" python-versions = ">=3.12,<4.0" -content-hash = "ba4f444f330a6cb58a06926a1c2701c544ffeaeda02107641361b7abaa61e4af" +content-hash = "2f857bf23052a67c88f8f78ebf3698a9327bfc3333cc22984bf8ed1bf127c4bf" diff --git a/pyproject.toml b/pyproject.toml index afda254..8f35e6a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,10 +17,10 @@ requires-python = '>=3.12,<4.0' dependencies = [ "saic-ismart-client-ng (>=0.9.3,<0.10.0)", 'httpx (>=0.28.1,<0.29.0)', - 'gmqtt (>=0.7.0,<0.8.0)', 'inflection (>=0.5.1,<0.6.0)', 'apscheduler (>=3.11.0,<4.0.0)', 'python-dotenv (>=1.1.1,<2.0.0)', + "aiomqtt (>=2.4.0,<3.0.0)", ] [project.urls] diff --git a/src/configuration/__init__.py b/src/configuration/__init__.py index 7173f37..63ca693 100644 --- a/src/configuration/__init__.py +++ b/src/configuration/__init__.py @@ -1,7 +1,7 @@ from __future__ import annotations from enum import Enum -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Literal if TYPE_CHECKING: from zoneinfo import ZoneInfo @@ -9,8 +9,11 @@ from integrations.openwb.charging_station import ChargingStation +Transport = Literal["tcp", "websockets"] + + class TransportProtocol(Enum): - def __init__(self, transport_mechanism: str, with_tls: bool) -> None: + def __init__(self, transport_mechanism: Transport, with_tls: bool) -> None: self.transport_mechanism = transport_mechanism self.with_tls = with_tls diff --git a/src/configuration/parser.py b/src/configuration/parser.py index 99456ab..304c295 100644 --- a/src/configuration/parser.py +++ b/src/configuration/parser.py @@ -105,15 +105,18 @@ def __parse_mqtt_transport(args: Namespace, config: Configuration) -> None: args.tls_server_cert_check_hostname ) else: - msg = f"Invalid MQTT URI scheme: {parse_result.scheme}, use tcp or ws" + msg = f"Invalid MQTT URI scheme: {parse_result.scheme}, use tls, tcp or ws" raise SystemExit(msg) if parse_result.port: config.mqtt_port = parse_result.port - elif config.mqtt_transport_protocol == TransportProtocol.TCP: - config.mqtt_port = 1883 - else: + elif config.mqtt_transport_protocol == TransportProtocol.TLS: + config.mqtt_port = 8883 + elif config.mqtt_transport_protocol == TransportProtocol.WS: config.mqtt_port = 9001 + else: + # fallback to default mqtt port + config.mqtt_port = 1883 config.mqtt_host = str(parse_result.hostname) diff --git a/src/extractors/__init__.py b/src/extractors/__init__.py index 2ebd343..746a735 100644 --- a/src/extractors/__init__.py +++ b/src/extractors/__init__.py @@ -43,17 +43,17 @@ def extract_soc_kwh( charge_status: ChrgMgmtDataRespProcessingResult | None, soc: float | None, ) -> float | None: - if ( - charge_status is not None - and (raw_soc_kwh := charge_status.soc_kwh) is not None - and (soc_kwh := __validate_and_convert_soc_kwh(raw_soc_kwh)) is not None - ): + if charge_status is None: + return None + + if (raw_soc_kwh := charge_status.soc_kwh) is not None and ( + soc_kwh := __validate_and_convert_soc_kwh(raw_soc_kwh) + ) is not None: LOG.debug("SoC kWh derived from realtimePower") return soc_kwh if ( soc is not None - and charge_status is not None and ( capacity := __validate_and_convert_soc_kwh( charge_status.real_total_battery_capacity diff --git a/src/handlers/vehicle_command.py b/src/handlers/vehicle_command.py index e7f994c..e1840d1 100644 --- a/src/handlers/vehicle_command.py +++ b/src/handlers/vehicle_command.py @@ -70,7 +70,7 @@ def __report_command_failure( else: LOG.error("Command %s failed: %s", command, detail) try: - self.publisher.publish_str(result_topic, f"Failed: {detail}") + self.publisher.publish_str(result_topic, f"Failed: {detail}", retain=False) except Exception: LOG.warning( "Failed to publish failure result for command %s", @@ -138,7 +138,7 @@ async def __execute_mqtt_command_handler( try: execution_result = await handler.handle(payload, retained=retained) - self.publisher.publish_str(result_topic, "Success") + self.publisher.publish_str(result_topic, "Success", retain=False) if execution_result.force_refresh: self.vehicle_state.set_refresh_mode( RefreshMode.FORCE, f"after command execution on topic {topic}" @@ -165,7 +165,7 @@ async def __execute_mqtt_command_handler( return try: execution_result = await handler.handle(payload, retained=retained) - self.publisher.publish_str(result_topic, "Success") + self.publisher.publish_str(result_topic, "Success", retain=False) if execution_result.force_refresh: self.vehicle_state.set_refresh_mode( RefreshMode.FORCE, diff --git a/src/integrations/home_assistant/base.py b/src/integrations/home_assistant/base.py index 322fa60..8f12dbe 100644 --- a/src/integrations/home_assistant/base.py +++ b/src/integrations/home_assistant/base.py @@ -55,7 +55,7 @@ def _publish_select( "command_topic": self._get_command_topic(topic), "value_template": value_template, "command_template": command_template, - "retain": str(retain).lower(), + "retain": retain, "options": options, "enabled_by_default": enabled, } @@ -87,7 +87,7 @@ def _publish_text( "command_topic": self._get_command_topic(topic), "value_template": value_template, "command_template": command_template, - "retain": str(retain).lower(), + "retain": retain, "enabled_by_default": enabled, } if min_value is not None: @@ -153,7 +153,7 @@ def _publish_number( "state_topic": self._get_state_topic(topic), "command_topic": self._get_command_topic(topic), "value_template": value_template, - "retain": str(retain).lower(), + "retain": retain, "mode": mode, "min": min_value, "max": max_value, diff --git a/src/log_config.py b/src/log_config.py index 9c96cae..2a49e3a 100644 --- a/src/log_config.py +++ b/src/log_config.py @@ -7,7 +7,7 @@ MODULES_DEFAULT_LOG_LEVEL = { "asyncio": "WARNING", - "gmqtt": "WARNING", + "aiomqtt": "WARNING", "httpcore": "WARNING", "httpx": "WARNING", "saic_ismart_client_ng": "WARNING", diff --git a/src/main.py b/src/main.py index dc47e09..e2d90d9 100644 --- a/src/main.py +++ b/src/main.py @@ -27,4 +27,5 @@ configuration = process_command_line() mqtt_gateway = MqttGateway(configuration) + asyncio.run(mqtt_gateway.run(), debug=debug_log_enabled()) diff --git a/src/publisher/core.py b/src/publisher/core.py index 6e7e926..2ee2365 100644 --- a/src/publisher/core.py +++ b/src/publisher/core.py @@ -102,30 +102,55 @@ def publish_json( no_prefix: bool = False, *, retain: bool = True, + qos: int = 0, ) -> None: raise NotImplementedError @abstractmethod def publish_str( - self, key: str, value: str, no_prefix: bool = False, *, retain: bool = True + self, + key: str, + value: str, + no_prefix: bool = False, + *, + retain: bool = True, + qos: int = 0, ) -> None: raise NotImplementedError @abstractmethod def publish_int( - self, key: str, value: int, no_prefix: bool = False, *, retain: bool = True + self, + key: str, + value: int, + no_prefix: bool = False, + *, + retain: bool = True, + qos: int = 0, ) -> None: raise NotImplementedError @abstractmethod def publish_bool( - self, key: str, value: bool, no_prefix: bool = False, *, retain: bool = True + self, + key: str, + value: bool, + no_prefix: bool = False, + *, + retain: bool = True, + qos: int = 0, ) -> None: raise NotImplementedError @abstractmethod def publish_float( - self, key: str, value: float, no_prefix: bool = False, *, retain: bool = True + self, + key: str, + value: float, + no_prefix: bool = False, + *, + retain: bool = True, + qos: int = 0, ) -> None: raise NotImplementedError @@ -173,7 +198,7 @@ def publish( raise TypeError(msg) @abstractmethod - def clear_topic(self, key: str, no_prefix: bool = False) -> None: + def clear_topic(self, key: str, no_prefix: bool = False, qos: int = 0) -> None: raise NotImplementedError def get_mqtt_account_prefix(self) -> str: @@ -249,7 +274,9 @@ def __anonymize(self, data: T) -> T: return data def keepalive(self) -> None: - self.publish_str(mqtt_topics.INTERNAL_LWT, "online", False) + self.publish_str( + mqtt_topics.INTERNAL_LWT, "online", no_prefix=False, retain=True, qos=1 + ) @staticmethod def anonymize_str(value: str) -> str: diff --git a/src/publisher/log_publisher.py b/src/publisher/log_publisher.py index 7969c61..f05866a 100644 --- a/src/publisher/log_publisher.py +++ b/src/publisher/log_publisher.py @@ -30,36 +30,61 @@ def publish_json( no_prefix: bool = False, *, retain: bool = True, + qos: int = 0, ) -> None: anonymized_json = self.dict_to_anonymized_json(data) self.internal_publish(key, anonymized_json, retain=retain) @override def publish_str( - self, key: str, value: str, no_prefix: bool = False, *, retain: bool = True + self, + key: str, + value: str, + no_prefix: bool = False, + *, + retain: bool = True, + qos: int = 0, ) -> None: self.internal_publish(key, value, retain=retain) @override def publish_int( - self, key: str, value: int, no_prefix: bool = False, *, retain: bool = True + self, + key: str, + value: int, + no_prefix: bool = False, + *, + retain: bool = True, + qos: int = 0, ) -> None: self.internal_publish(key, value, retain=retain) @override def publish_bool( - self, key: str, value: bool, no_prefix: bool = False, *, retain: bool = True + self, + key: str, + value: bool, + no_prefix: bool = False, + *, + retain: bool = True, + qos: int = 0, ) -> None: self.internal_publish(key, value, retain=retain) @override def publish_float( - self, key: str, value: float, no_prefix: bool = False, *, retain: bool = True + self, + key: str, + value: float, + no_prefix: bool = False, + *, + retain: bool = True, + qos: int = 0, ) -> None: self.internal_publish(key, value, retain=retain) @override - def clear_topic(self, key: str, no_prefix: bool = False) -> None: + def clear_topic(self, key: str, no_prefix: bool = False, qos: int = 0) -> None: self.internal_publish(key, None) def internal_publish( diff --git a/src/publisher/mqtt_publisher.py b/src/publisher/mqtt_publisher.py index ed535d4..d7cc9b6 100644 --- a/src/publisher/mqtt_publisher.py +++ b/src/publisher/mqtt_publisher.py @@ -1,11 +1,13 @@ from __future__ import annotations +import asyncio import logging import math import ssl -from typing import TYPE_CHECKING, Any, Final, cast, override +from typing import TYPE_CHECKING, Any, override -import gmqtt +import aiomqtt +from aiomqtt.exceptions import MqttConnectError import mqtt_topics from publisher.core import Publisher @@ -17,9 +19,28 @@ LOG = logging.getLogger(__name__) +# Reconnect backoff: starts at 5 s, doubles on each failure, caps at 5 min +_RECONNECT_INTERVAL_MIN = 5 +_RECONNECT_INTERVAL_MAX = 300 + +# MQTT 3.1.1 spec section 3.2.2.3 — permanent connection refusal codes +_CONNACK_REFUSED_PROTOCOL_VERSION = 1 +_CONNACK_REFUSED_IDENTIFIER_REJECTED = 2 +_CONNACK_REFUSED_BAD_CREDENTIALS = 4 +_CONNACK_REFUSED_NOT_AUTHORIZED = 5 +_FATAL_CONNECT_RC = { + _CONNACK_REFUSED_PROTOCOL_VERSION, + _CONNACK_REFUSED_IDENTIFIER_REJECTED, + _CONNACK_REFUSED_BAD_CREDENTIALS, + _CONNACK_REFUSED_NOT_AUTHORIZED, +} + class MqttPublisher(Publisher): - def __init__(self, configuration: Configuration) -> None: + def __init__( + self, + configuration: Configuration, + ) -> None: super().__init__(configuration) self.publisher_id = configuration.mqtt_client_id self.host = self.configuration.mqtt_host @@ -30,129 +51,200 @@ def __init__(self, configuration: Configuration) -> None: self.vin_by_charger_connected_topic: dict[str, str] = {} self.vin_by_imported_energy_topic: dict[str, str] = {} self.first_connection = True + self.client: None | aiomqtt.Client = None + self.__running: asyncio.Task[None] | None = None + self.__connected = asyncio.Event() + self.__fatal_connect_error: SystemExit | None = None + + async def __run_loop(self) -> None: + if not self.host: + LOG.info("MQTT host is not configured") + return + ssl_context: ssl.SSLContext | None = None + if self.transport_protocol.with_tls: + ssl_context = ssl.create_default_context() + if self.configuration.tls_server_cert_path: + LOG.debug( + f"Using custom CA file {self.configuration.tls_server_cert_path}" + ) + ssl_context.load_verify_locations( + cafile=self.configuration.tls_server_cert_path + ) + if not self.configuration.tls_server_cert_check_hostname: + LOG.warning( + f"Skipping hostname check for TLS connection to {self.host}" + ) - mqtt_client = gmqtt.Client( - client_id=str(self.publisher_id), + client = aiomqtt.Client( + hostname=self.host, + port=self.port, + identifier=str(self.publisher_id), transport=self.transport_protocol.transport_mechanism, - will_message=gmqtt.Message( + username=self.configuration.mqtt_user or None, + password=self.configuration.mqtt_password or None, + clean_session=True, + tls_context=ssl_context, + tls_insecure=bool( + ssl_context and not self.configuration.tls_server_cert_check_hostname + ), + will=aiomqtt.Will( topic=self.get_topic(mqtt_topics.INTERNAL_LWT, False), payload="offline", retain=True, + qos=1, ), ) - mqtt_client.on_connect = self.__on_connect - mqtt_client.on_message = self.__on_message - self.client: Final[gmqtt.Client] = mqtt_client + client.pending_calls_threshold = 150 + reconnect_interval = _RECONNECT_INTERVAL_MIN + while True: + try: + LOG.debug( + "Connecting to %s:%s as %s", + self.host, + self.port, + self.publisher_id, + ) + async with client as client_context: + self.client = client_context + self.__connected.set() + await self.__on_connect() + reconnect_interval = _RECONNECT_INTERVAL_MIN + async for message in client_context.messages: + await self._on_message( + client_context, + str(message.topic), + message.payload, + message.qos, + message.retain, + ) + except MqttConnectError as e: + # Permanent rejections — retrying won't help. + # rc 3 (server unavailable) is transient and falls through to reconnect. + # ReasonCode.__eq__ supports int comparison, so no isinstance guard needed. + if e.rc in _FATAL_CONNECT_RC: + LOG.error("MQTT connection permanently refused: %s", e) + self.__fatal_connect_error = SystemExit(str(e)) + self.__connected.set() + return + LOG.warning( + "Connection to %s:%s refused (transient); Reconnecting in %d seconds ...", + self.host, + self.port, + reconnect_interval, + ) + await asyncio.sleep(reconnect_interval) + reconnect_interval = min( + reconnect_interval * 2, _RECONNECT_INTERVAL_MAX + ) + except aiomqtt.MqttError: + LOG.warning( + "Connection to %s:%s lost; Reconnecting in %d seconds ...", + self.host, + self.port, + reconnect_interval, + ) + await asyncio.sleep(reconnect_interval) + reconnect_interval = min( + reconnect_interval * 2, _RECONNECT_INTERVAL_MAX + ) + except asyncio.exceptions.CancelledError: + LOG.debug("MQTT publisher loop cancelled") + raise + finally: + self.__connected.clear() + LOG.info("MQTT client disconnected") @override async def connect(self) -> None: - if self.configuration.mqtt_user is not None: - if self.configuration.mqtt_password is not None: - self.client.set_auth_credentials( - username=self.configuration.mqtt_user, - password=self.configuration.mqtt_password, - ) - else: - self.client.set_auth_credentials(username=self.configuration.mqtt_user) - - if self.transport_protocol.with_tls: - ssl_context = ssl.create_default_context() - cert_uri = self.configuration.tls_server_cert_path - if cert_uri: - LOG.debug(f"Using custom CA file {cert_uri}") - ssl_context.load_verify_locations(cafile=cert_uri) - if not self.configuration.tls_server_cert_check_hostname: - LOG.warning( - f"Skipping hostname check for TLS connection to {self.host}" - ) - ssl_context.check_hostname = False - else: - ssl_context = None - await self.client.connect( - host=self.host, - port=self.port, - version=gmqtt.constants.MQTTv311, - ssl=ssl_context, - ) + if self.__running and not self.__running.done(): + LOG.warning("MQTT client is already running") + return + if not self.host: + LOG.info("MQTT host is not configured") + return + self.__running = asyncio.create_task(self.__run_loop()) + await self.__connected.wait() + if self.__fatal_connect_error is not None: + raise self.__fatal_connect_error - def __on_connect( - self, _client: Any, _flags: Any, rc: int, _properties: Any - ) -> None: - if rc == gmqtt.constants.CONNACK_ACCEPTED: - LOG.info("Connected to MQTT broker") - if not self.first_connection: - self.enable_commands() - if self.command_listener is not None: - self.command_listener.on_mqtt_reconnected() - self.first_connection = False - self.keepalive() - else: - if rc == gmqtt.constants.CONNACK_REFUSED_BAD_USERNAME_PASSWORD: - LOG.error( - f"MQTT connection error: bad username or password. Return code {rc}" - ) - elif rc == gmqtt.constants.CONNACK_REFUSED_PROTOCOL_VERSION: - LOG.error( - f"MQTT connection error: refused protocol version. Return code {rc}" - ) - else: - LOG.error(f"MQTT connection error.Return code {rc}") - msg = f"Unable to connect to MQTT broker. Return code: {rc}" - raise SystemExit(msg) + async def __on_connect(self) -> None: + LOG.info("Connected to MQTT broker") + if not self.first_connection: + await self.__enable_commands() + if self.command_listener is not None: + self.command_listener.on_mqtt_reconnected() + self.first_connection = False + self.keepalive() @override def enable_commands(self) -> None: - LOG.info("Subscribing to MQTT command topics") - mqtt_account_prefix = self.get_mqtt_account_prefix() - self.client.subscribe( - f"{mqtt_account_prefix}/{mqtt_topics.VEHICLES}/+/+/+/{mqtt_topics.SET_SUFFIX}" - ) - self.client.subscribe( - f"{mqtt_account_prefix}/{mqtt_topics.VEHICLES}/+/+/+/+/{mqtt_topics.SET_SUFFIX}" - ) - self.client.subscribe( - f"{mqtt_account_prefix}/{mqtt_topics.VEHICLES}/+/{mqtt_topics.REFRESH_MODE}/{mqtt_topics.SET_SUFFIX}" - ) - self.client.subscribe( - f"{mqtt_account_prefix}/{mqtt_topics.VEHICLES}/+/{mqtt_topics.REFRESH_PERIOD}/+/{mqtt_topics.SET_SUFFIX}" - ) - for charging_station in self.configuration.charging_stations_by_vin.values(): - LOG.debug( - f"Subscribing to MQTT topic {charging_station.charge_state_topic}" + task = asyncio.get_running_loop().create_task(self.__enable_commands()) + task.add_done_callback(self.__on_enable_commands_done) + + def __on_enable_commands_done(self, task: asyncio.Task[None]) -> None: + if not task.cancelled() and (exc := task.exception()): + LOG.error("Failed to enable MQTT command subscriptions: %s", exc) + + async def __enable_commands(self) -> None: + if not self.__connected.is_set() or not self.client: + LOG.error("Failed to enable commands: MQTT client is not connected") + return + try: + LOG.info("Subscribing to MQTT command topics") + mqtt_account_prefix = self.get_mqtt_account_prefix() + await self.client.subscribe( + f"{mqtt_account_prefix}/{mqtt_topics.VEHICLES}/+/+/+/{mqtt_topics.SET_SUFFIX}" + ) + await self.client.subscribe( + f"{mqtt_account_prefix}/{mqtt_topics.VEHICLES}/+/+/+/+/{mqtt_topics.SET_SUFFIX}" + ) + await self.client.subscribe( + f"{mqtt_account_prefix}/{mqtt_topics.VEHICLES}/+/{mqtt_topics.REFRESH_MODE}/{mqtt_topics.SET_SUFFIX}" ) - self.vin_by_charge_state_topic[charging_station.charge_state_topic] = ( - charging_station.vin + await self.client.subscribe( + f"{mqtt_account_prefix}/{mqtt_topics.VEHICLES}/+/{mqtt_topics.REFRESH_PERIOD}/+/{mqtt_topics.SET_SUFFIX}" ) - self.client.subscribe(charging_station.charge_state_topic) - if charging_station.connected_topic: + for ( + charging_station + ) in self.configuration.charging_stations_by_vin.values(): LOG.debug( - f"Subscribing to MQTT topic {charging_station.connected_topic}" + f"Subscribing to MQTT topic {charging_station.charge_state_topic}" ) - self.vin_by_charger_connected_topic[ - charging_station.connected_topic - ] = charging_station.vin - self.client.subscribe(charging_station.connected_topic) - if charging_station.imported_energy_topic: - LOG.debug( - f"Subscribing to MQTT topic {charging_station.imported_energy_topic}" + self.vin_by_charge_state_topic[charging_station.charge_state_topic] = ( + charging_station.vin ) - self.vin_by_imported_energy_topic[ - charging_station.imported_energy_topic - ] = charging_station.vin - self.client.subscribe(charging_station.imported_energy_topic) - if self.configuration.ha_discovery_enabled: - # enable dynamic discovery pushing in case ha reconnects - self.client.subscribe(self.configuration.ha_lwt_topic) + await self.client.subscribe(charging_station.charge_state_topic) + if charging_station.connected_topic: + LOG.debug( + f"Subscribing to MQTT topic {charging_station.connected_topic}" + ) + self.vin_by_charger_connected_topic[ + charging_station.connected_topic + ] = charging_station.vin + await self.client.subscribe(charging_station.connected_topic) + if charging_station.imported_energy_topic: + LOG.debug( + f"Subscribing to MQTT topic {charging_station.imported_energy_topic}" + ) + self.vin_by_imported_energy_topic[ + charging_station.imported_energy_topic + ] = charging_station.vin + await self.client.subscribe(charging_station.imported_energy_topic) + if self.configuration.ha_discovery_enabled: + # enable dynamic discovery pushing in case ha reconnects + await self.client.subscribe(self.configuration.ha_lwt_topic) + except aiomqtt.MqttError as e: + LOG.error(f"Failed to subscribe to MQTT command topics: {e}") + raise e - async def __on_message( - self, _client: Any, topic: str, payload: Any, _qos: Any, _properties: Any + async def _on_message( + self, _client: Any, topic: str, payload: Any, _qos: Any, retained: bool ) -> None: try: if isinstance(payload, bytes): payload = payload.decode("utf-8") else: payload = str(payload) - retained = bool(_properties.get("retain", 0)) await self.__on_message_real( topic=topic, payload=payload, retained=retained ) @@ -228,13 +320,34 @@ async def __handle_imported_energy(self, topic: str, payload: str) -> None: ) def __publish( - self, topic: str, payload: WirePayload | None, *, retain: bool = True + self, + topic: str, + payload: WirePayload | None, + *, + retain: bool = True, + qos: int = 0, + ) -> None: + LOG.debug("Publishing to MQTT topic %s with payload %s", topic, payload) + asyncio.get_running_loop().create_task( + self.__async_publish(topic, payload, retain=retain, qos=qos) + ) + + async def __async_publish( + self, topic: str, payload: Any, retain: bool, qos: int ) -> None: - self.client.publish(topic, payload, retain=retain) + if not (self.client and self.is_connected()): + LOG.error("Failed to publish: MQTT client is not connected") + return + try: + await self.client.publish(topic, payload, retain=retain, qos=qos) + except aiomqtt.MqttError as e: + LOG.error( + f"Failed to publish to MQTT topic {topic} with payload {payload}: {e}" + ) @override def is_connected(self) -> bool: - return cast("bool", self.client.is_connected) + return self.__connected.is_set() @override def publish_json( @@ -244,47 +357,75 @@ def publish_json( no_prefix: bool = False, *, retain: bool = True, + qos: int = 0, ) -> None: payload = self.dict_to_anonymized_json(data) self.__publish( - topic=self.get_topic(key, no_prefix), payload=payload, retain=retain + topic=self.get_topic(key, no_prefix), + payload=payload, + retain=retain, + qos=qos, ) @override def publish_str( - self, key: str, value: str, no_prefix: bool = False, *, retain: bool = True + self, + key: str, + value: str, + no_prefix: bool = False, + *, + retain: bool = True, + qos: int = 0, ) -> None: self.__publish( - topic=self.get_topic(key, no_prefix), payload=value, retain=retain + topic=self.get_topic(key, no_prefix), payload=value, retain=retain, qos=qos ) @override def publish_int( - self, key: str, value: int, no_prefix: bool = False, *, retain: bool = True + self, + key: str, + value: int, + no_prefix: bool = False, + *, + retain: bool = True, + qos: int = 0, ) -> None: self.__publish( - topic=self.get_topic(key, no_prefix), payload=value, retain=retain + topic=self.get_topic(key, no_prefix), payload=value, retain=retain, qos=qos ) @override def publish_bool( - self, key: str, value: bool, no_prefix: bool = False, *, retain: bool = True + self, + key: str, + value: bool, + no_prefix: bool = False, + *, + retain: bool = True, + qos: int = 0, ) -> None: self.__publish( - topic=self.get_topic(key, no_prefix), payload=value, retain=retain + topic=self.get_topic(key, no_prefix), payload=value, retain=retain, qos=qos ) @override def publish_float( - self, key: str, value: float, no_prefix: bool = False, *, retain: bool = True + self, + key: str, + value: float, + no_prefix: bool = False, + *, + retain: bool = True, + qos: int = 0, ) -> None: self.__publish( - topic=self.get_topic(key, no_prefix), payload=value, retain=retain + topic=self.get_topic(key, no_prefix), payload=value, retain=retain, qos=qos ) @override - def clear_topic(self, key: str, no_prefix: bool = False) -> None: - self.__publish(topic=self.get_topic(key, no_prefix), payload=None) + def clear_topic(self, key: str, no_prefix: bool = False, qos: int = 0) -> None: + self.__publish(topic=self.get_topic(key, no_prefix), payload=None, qos=qos) def get_vin_from_topic(self, topic: str) -> str: global_topic_removed = topic[len(self.configuration.mqtt_topic) + 1 :] diff --git a/src/status_publisher/charge/chrg_mgmt_data.py b/src/status_publisher/charge/chrg_mgmt_data.py index b0444fc..5825590 100644 --- a/src/status_publisher/charge/chrg_mgmt_data.py +++ b/src/status_publisher/charge/chrg_mgmt_data.py @@ -104,18 +104,18 @@ def publish(self, charge_mgmt_data: ChrgMgmtData) -> ChrgMgmtDataProcessingResul self._transform_and_publish( topic=mqtt_topics.BMS_CHARGE_STATUS, value=charge_mgmt_data.bms_charging_status, - transform=lambda x: f"UNKNOWN {charge_mgmt_data.bmsChrgSts}" - if x is None - else x.name, + transform=lambda x: ( + f"UNKNOWN {charge_mgmt_data.bmsChrgSts}" if x is None else x.name + ), ) self._transform_and_publish( topic=mqtt_topics.DRIVETRAIN_CHARGING_STOP_REASON, value=charge_mgmt_data.charging_stop_reason, validator=lambda x: x != ChargingStopReason.NO_REASON, - transform=lambda x: f"UNKNOWN {charge_mgmt_data.bmsChrgSpRsn}" - if x is None - else x.name, + transform=lambda x: ( + f"UNKNOWN {charge_mgmt_data.bmsChrgSpRsn}" if x is None else x.name + ), ) self._publish( @@ -153,9 +153,9 @@ def publish(self, charge_mgmt_data: ChrgMgmtData) -> ChrgMgmtDataProcessingResul topic=mqtt_topics.DRIVETRAIN_BATTERY_HEATING_STOP_REASON, value=charge_mgmt_data.heating_stop_reason, validator=lambda x: x != HeatingStopReason.NO_REASON, - transform=lambda x: f"UNKNOWN ({charge_mgmt_data.bmsPTCHeatResp})" - if x is None - else x.name, + transform=lambda x: ( + f"UNKNOWN ({charge_mgmt_data.bmsPTCHeatResp})" if x is None else x.name + ), ) self._transform_and_publish( diff --git a/tests/handlers/test_vehicle_command.py b/tests/handlers/test_vehicle_command.py index 4815b5e..57cfaaf 100644 --- a/tests/handlers/test_vehicle_command.py +++ b/tests/handlers/test_vehicle_command.py @@ -55,7 +55,7 @@ async def test_successful_command_publishes_success(self) -> None: await handler.handle_mqtt_command(topic=CHARGING_SET_TOPIC, payload="true") - pub.publish_str.assert_any_call(CHARGING_RESULT_TOPIC, "Success") + pub.publish_str.assert_any_call(CHARGING_RESULT_TOPIC, "Success", retain=False) pub.publish_json.assert_not_called() @@ -70,6 +70,7 @@ async def test_publishes_error_event(self) -> None: pub.publish_str.assert_any_call( result_topic, "Failed: No handler found for command topic nonexistent/topic/set", + retain=False, ) pub.publish_json.assert_called_once() event = pub.publish_json.call_args[0][1] @@ -99,6 +100,7 @@ async def test_publishes_error_event(self) -> None: CHARGING_RESULT_TOPIC, "Failed: Unsupported payload not_a_boolean for command " "DrivetrainChargingCommand", + retain=False, ) pub.publish_json.assert_called_once() event = pub.publish_json.call_args[0][1] @@ -119,6 +121,7 @@ async def test_publishes_error_event(self) -> None: pub.publish_str.assert_any_call( CHARGING_RESULT_TOPIC, "Failed: return code: 8, message: operation too frequent", + retain=False, ) pub.publish_json.assert_called_once() event = pub.publish_json.call_args[0][1] @@ -135,7 +138,7 @@ async def test_uses_safe_detail(self) -> None: await handler.handle_mqtt_command(topic=CHARGING_SET_TOPIC, payload="true") pub.publish_str.assert_any_call( - CHARGING_RESULT_TOPIC, "Failed: unexpected error" + CHARGING_RESULT_TOPIC, "Failed: unexpected error", retain=False ) event = pub.publish_json.call_args[0][1] assert event["detail"] == "unexpected error" @@ -157,7 +160,7 @@ async def test_relogin_success_retries_command(self) -> None: assert isinstance(relogin, AsyncMock) relogin.force_login.assert_awaited_once() assert saic_api.control_charging.await_count == 2 - pub.publish_str.assert_any_call(CHARGING_RESULT_TOPIC, "Success") + pub.publish_str.assert_any_call(CHARGING_RESULT_TOPIC, "Success", retain=False) pub.publish_json.assert_not_called() async def test_relogin_failure_publishes_error_event(self) -> None: @@ -170,7 +173,7 @@ async def test_relogin_failure_publishes_error_event(self) -> None: await handler.handle_mqtt_command(topic=CHARGING_SET_TOPIC, payload="true") pub.publish_str.assert_any_call( - CHARGING_RESULT_TOPIC, "Failed: relogin failed (login failed)" + CHARGING_RESULT_TOPIC, "Failed: relogin failed (login failed)", retain=False ) pub.publish_json.assert_called_once() event = pub.publish_json.call_args[0][1] @@ -186,7 +189,9 @@ async def test_retry_failure_publishes_error_event(self) -> None: await handler.handle_mqtt_command(topic=CHARGING_SET_TOPIC, payload="true") - pub.publish_str.assert_any_call(CHARGING_RESULT_TOPIC, "Failed: retry boom") + pub.publish_str.assert_any_call( + CHARGING_RESULT_TOPIC, "Failed: retry boom", retain=False + ) pub.publish_json.assert_called_once() event = pub.publish_json.call_args[0][1] assert event["detail"] == "retry boom" @@ -277,7 +282,7 @@ async def test_retained_force_refresh_mode_dropped(self) -> None: ) vehicle_state.set_refresh_mode.assert_not_called() - pub.publish_str.assert_any_call(REFRESH_MODE_RESULT_TOPIC, "Success") + pub.publish_str.assert_any_call(REFRESH_MODE_RESULT_TOPIC, "Success", retain=False) async def test_retained_charging_detection_refresh_mode_dropped(self) -> None: handler, pub = _build() @@ -290,7 +295,7 @@ async def test_retained_charging_detection_refresh_mode_dropped(self) -> None: ) vehicle_state.set_refresh_mode.assert_not_called() - pub.publish_str.assert_any_call(REFRESH_MODE_RESULT_TOPIC, "Success") + pub.publish_str.assert_any_call(REFRESH_MODE_RESULT_TOPIC, "Success", retain=False) async def test_retained_periodic_refresh_mode_applied(self) -> None: handler, _pub = _build() @@ -339,7 +344,9 @@ async def test_retained_battery_capacity_replays_to_vehicle_info(self) -> None: vehicle_state.update_battery_capacity.assert_called_once_with(50.0) pub.publish_float.assert_any_call(TOTAL_BATTERY_CAPACITY_STATE_TOPIC, 50.0) - pub.publish_str.assert_any_call(TOTAL_BATTERY_CAPACITY_RESULT_TOPIC, "Success") + pub.publish_str.assert_any_call( + TOTAL_BATTERY_CAPACITY_RESULT_TOPIC, "Success", retain=False + ) async def test_battery_capacity_zero_payload_publishes_model_default(self) -> None: """Payload `0` clears the override; the per-model default is republished.""" diff --git a/tests/integrations/home_assistant/test_discovery_retain.py b/tests/integrations/home_assistant/test_discovery_retain.py index 45fa789..f56731f 100644 --- a/tests/integrations/home_assistant/test_discovery_retain.py +++ b/tests/integrations/home_assistant/test_discovery_retain.py @@ -104,7 +104,7 @@ def test_required_entities_have_retain_true(self) -> None: assert payload is not None, ( f"No writable HA discovery payload found for topic {topic}" ) - assert payload.get("retain") == "true", ( + assert payload.get("retain") is True, ( f"Expected retain=true for {topic}, got {payload.get('retain')!r}" ) @@ -117,6 +117,6 @@ def test_non_retained_entities_keep_retain_false(self) -> None: payload = _payload_for_state_topic_suffix(payloads, topic) if payload is None: continue # entity not published for this vehicle config - assert payload.get("retain") in ("false", None), ( + assert payload.get("retain") in (False, None), ( f"Expected retain!=true for {topic}, got {payload.get('retain')!r}" ) diff --git a/tests/publisher/test_publish_dispatch.py b/tests/publisher/test_publish_dispatch.py index 7902771..192a574 100644 --- a/tests/publisher/test_publish_dispatch.py +++ b/tests/publisher/test_publish_dispatch.py @@ -299,35 +299,60 @@ def publish_json( no_prefix: bool = False, *, retain: bool = True, + qos: int = 0, ) -> None: pass @override def publish_str( - self, key: str, value: str, no_prefix: bool = False, *, retain: bool = True + self, + key: str, + value: str, + no_prefix: bool = False, + *, + retain: bool = True, + qos: int = 0, ) -> None: pass @override def publish_int( - self, key: str, value: int, no_prefix: bool = False, *, retain: bool = True + self, + key: str, + value: int, + no_prefix: bool = False, + *, + retain: bool = True, + qos: int = 0, ) -> None: pass @override def publish_bool( - self, key: str, value: bool, no_prefix: bool = False, *, retain: bool = True + self, + key: str, + value: bool, + no_prefix: bool = False, + *, + retain: bool = True, + qos: int = 0, ) -> None: pass @override def publish_float( - self, key: str, value: float, no_prefix: bool = False, *, retain: bool = True + self, + key: str, + value: float, + no_prefix: bool = False, + *, + retain: bool = True, + qos: int = 0, ) -> None: pass @override - def clear_topic(self, key: str, no_prefix: bool = False) -> None: + def clear_topic(self, key: str, no_prefix: bool = False, qos: int = 0) -> None: pass diff --git a/tests/test_mqtt_publisher.py b/tests/test_mqtt_publisher.py index 1fdeba3..d4bf19e 100644 --- a/tests/test_mqtt_publisher.py +++ b/tests/test_mqtt_publisher.py @@ -1,8 +1,10 @@ from __future__ import annotations +import asyncio +import json from typing import Any, override import unittest -from unittest.mock import patch +from unittest.mock import AsyncMock, patch from configuration import Configuration, TransportProtocol from publisher.core import MqttCommandListener @@ -71,7 +73,7 @@ async def test_update_rear_window_heat_state(self) -> None: assert self.received_payload == REAR_WINDOW_HEAT_STATE async def send_message(self, topic: str, payload: Any) -> None: - await self.mqtt_client.client.on_message("client", topic, payload, 0, {}) + await self.mqtt_client._on_message("client", topic, payload, 0, False) async def test_get_vin_from_sanitized_topic(self) -> None: """Topics arrive with the sanitized prefix, not the raw username.""" @@ -100,63 +102,72 @@ async def on_charger_connection_state_changed( ) -> None: pass - def test_publish_str_default_is_retained(self) -> None: - with patch.object(self.mqtt_client.client, "publish") as m_pub: - self.mqtt_client.publish_str("foo", "bar") - m_pub.assert_called_once_with("saic/foo", "bar", retain=True) - - def test_publish_str_forwards_retain_false(self) -> None: - with patch.object(self.mqtt_client.client, "publish") as m_pub: - self.mqtt_client.publish_str("foo", "bar", retain=False) - m_pub.assert_called_once_with("saic/foo", "bar", retain=False) - - def test_publish_int_default_is_retained(self) -> None: - with patch.object(self.mqtt_client.client, "publish") as m_pub: - self.mqtt_client.publish_int("foo", 42) - m_pub.assert_called_once_with("saic/foo", 42, retain=True) - - def test_publish_int_forwards_retain_false(self) -> None: - with patch.object(self.mqtt_client.client, "publish") as m_pub: - self.mqtt_client.publish_int("foo", 42, retain=False) - m_pub.assert_called_once_with("saic/foo", 42, retain=False) - - def test_publish_bool_default_is_retained(self) -> None: - with patch.object(self.mqtt_client.client, "publish") as m_pub: - self.mqtt_client.publish_bool("foo", True) - m_pub.assert_called_once_with("saic/foo", True, retain=True) - - def test_publish_bool_forwards_retain_false(self) -> None: - with patch.object(self.mqtt_client.client, "publish") as m_pub: - self.mqtt_client.publish_bool("foo", True, retain=False) - m_pub.assert_called_once_with("saic/foo", True, retain=False) - - def test_publish_float_default_is_retained(self) -> None: - with patch.object(self.mqtt_client.client, "publish") as m_pub: - self.mqtt_client.publish_float("foo", 1.5) - m_pub.assert_called_once_with("saic/foo", 1.5, retain=True) - - def test_publish_float_forwards_retain_false(self) -> None: - with patch.object(self.mqtt_client.client, "publish") as m_pub: - self.mqtt_client.publish_float("foo", 1.5, retain=False) - m_pub.assert_called_once_with("saic/foo", 1.5, retain=False) - - def test_publish_json_default_is_retained(self) -> None: - with patch.object(self.mqtt_client.client, "publish") as m_pub: - self.mqtt_client.publish_json("foo", {"a": 1}) - m_pub.assert_called_once() - args, kwargs = m_pub.call_args - assert args[0] == "saic/foo" - assert kwargs == {"retain": True} - - def test_publish_json_forwards_retain_false(self) -> None: - with patch.object(self.mqtt_client.client, "publish") as m_pub: - self.mqtt_client.publish_json("foo", {"a": 1}, retain=False) - m_pub.assert_called_once() - args, kwargs = m_pub.call_args - assert args[0] == "saic/foo" - assert kwargs == {"retain": False} - - def test_clear_topic_publishes_none_retained(self) -> None: - with patch.object(self.mqtt_client.client, "publish") as m_pub: - self.mqtt_client.clear_topic("foo") - m_pub.assert_called_once_with("saic/foo", None, retain=True) + async def _call_and_flush(self, fn: Any, *args: Any, **kwargs: Any) -> AsyncMock: + mock = AsyncMock() + with patch.object(self.mqtt_client, "_MqttPublisher__async_publish", mock): + self.mqtt_client._MqttPublisher__connected.set() # type: ignore[attr-defined] + fn(*args, **kwargs) + await asyncio.sleep(0) + return mock + + async def test_publish_str_default_is_retained(self) -> None: + m = await self._call_and_flush(self.mqtt_client.publish_str, "foo", "bar") + m.assert_called_once_with("saic/foo", "bar", retain=True, qos=0) + + async def test_publish_str_forwards_retain_false(self) -> None: + m = await self._call_and_flush( + self.mqtt_client.publish_str, "foo", "bar", retain=False + ) + m.assert_called_once_with("saic/foo", "bar", retain=False, qos=0) + + async def test_publish_int_default_is_retained(self) -> None: + m = await self._call_and_flush(self.mqtt_client.publish_int, "foo", 42) + m.assert_called_once_with("saic/foo", 42, retain=True, qos=0) + + async def test_publish_int_forwards_retain_false(self) -> None: + m = await self._call_and_flush( + self.mqtt_client.publish_int, "foo", 42, retain=False + ) + m.assert_called_once_with("saic/foo", 42, retain=False, qos=0) + + async def test_publish_bool_default_is_retained(self) -> None: + m = await self._call_and_flush(self.mqtt_client.publish_bool, "foo", True) + m.assert_called_once_with("saic/foo", True, retain=True, qos=0) + + async def test_publish_bool_forwards_retain_false(self) -> None: + m = await self._call_and_flush( + self.mqtt_client.publish_bool, "foo", True, retain=False + ) + m.assert_called_once_with("saic/foo", True, retain=False, qos=0) + + async def test_publish_float_default_is_retained(self) -> None: + m = await self._call_and_flush(self.mqtt_client.publish_float, "foo", 1.5) + m.assert_called_once_with("saic/foo", 1.5, retain=True, qos=0) + + async def test_publish_float_forwards_retain_false(self) -> None: + m = await self._call_and_flush( + self.mqtt_client.publish_float, "foo", 1.5, retain=False + ) + m.assert_called_once_with("saic/foo", 1.5, retain=False, qos=0) + + async def test_publish_json_default_is_retained(self) -> None: + m = await self._call_and_flush(self.mqtt_client.publish_json, "foo", {"a": 1}) + m.assert_called_once() + args, kwargs = m.call_args + assert args[0] == "saic/foo" + assert json.loads(args[1]) == {"a": 1} + assert kwargs == {"retain": True, "qos": 0} + + async def test_publish_json_forwards_retain_false(self) -> None: + m = await self._call_and_flush( + self.mqtt_client.publish_json, "foo", {"a": 1}, retain=False + ) + m.assert_called_once() + args, kwargs = m.call_args + assert args[0] == "saic/foo" + assert json.loads(args[1]) == {"a": 1} + assert kwargs == {"retain": False, "qos": 0} + + async def test_clear_topic_publishes_none_retained(self) -> None: + m = await self._call_and_flush(self.mqtt_client.clear_topic, "foo") + m.assert_called_once_with("saic/foo", None, retain=True, qos=0)