diff --git a/conftest.py b/conftest.py index 41b52042c..33640369e 100644 --- a/conftest.py +++ b/conftest.py @@ -27,12 +27,20 @@ ]: sys.modules[module] = Mock() + # use gpiozero fake pins environ["GPIOZERO_PIN_FACTORY"] = "mock" @pytest.fixture -def oled_mocks(): +def zmq_poller_mock(): + poller_mock = Mock() + poller_mock.poll.return_value = [] + sys.modules["zmq"].Poller.return_value = poller_mock + + +@pytest.fixture +def oled_mocks(zmq_poller_mock): SIZE = (128, 64) MODE = "1" SPI_BUS = 0 diff --git a/packages/common/pitop/common/ptdm.py b/packages/common/pitop/common/ptdm.py index 5cc448a66..11958731c 100644 --- a/packages/common/pitop/common/ptdm.py +++ b/packages/common/pitop/common/ptdm.py @@ -420,7 +420,7 @@ class PTDMSubscribeClient: def __init__(self): self.__thread = Thread(target=self.__thread_method, daemon=True) - self._callback_funcs = None + self._callback_funcs = {} self._zmq_context = None self._zmq_socket = None @@ -435,7 +435,9 @@ def __exit__(self, exc_type, exc_value, exc_traceback): def __connect_to_socket(self): self._zmq_context = zmq.Context() self._zmq_socket = self._zmq_context.socket(zmq.SUB) - self._zmq_socket.setsockopt_string(zmq.SUBSCRIBE, "") + + for message_id in self._callback_funcs.keys(): + self._zmq_socket.setsockopt(zmq.SUBSCRIBE, str(message_id).encode()) try: self._zmq_socket.connect(self.URI) @@ -463,35 +465,40 @@ def __thread_method(self): poller.register(self._zmq_socket, zmq.POLLIN) while self.__continue: events = poller.poll(_TIMEOUT_MS) - for _ in range(len(events)): message_string = self._zmq_socket.recv_string() message = Message.from_string(message_string) - id = message.message_id() - if id in self._callback_funcs: - self.invoke_callback_func_if_exists( - self._callback_funcs[id], message.parameters - ) + callback = self._callback_funcs.get(message.message_id()) + if callback: + self.invoke_callback(callback, message.parameters) - def invoke_callback_func_if_exists(self, func, params=list()): - if not callable(func): - return - - func_arg_no = len(signature(func).parameters) - if func_arg_no > 1: - logger.error( - "Invalid callback function - it should receive at most one argument." - ) - return "" + def invoke_callback(self, callback, params=None): + if params is None: + params = list() - if params == list() or func_arg_no == 0: - func() + func_arg_no = len(signature(callback).parameters) + if len(params) == 0 or func_arg_no == 0: + callback() else: - func(params) + callback(params) def initialise(self, callback_funcs): - self._callback_funcs = callback_funcs + for message_id, callback in callback_funcs.items(): + if not callable(callback): + logger.error( + f"Invalid callback function for message {message_id} - not callable. Skipping..." + ) + continue + + func_arg_no = len(signature(callback).parameters) + if func_arg_no > 1: + logger.error( + f"Invalid callback function for message {message_id} - it should receive at most one argument. Skipping..." + ) + continue + + self._callback_funcs.update({message_id: callback}) def start_listening(self): if not self.__connect_to_socket(): diff --git a/packages/display/pitop/display/display.py b/packages/display/pitop/display/display.py index 24dd8f296..791daa1f4 100644 --- a/packages/display/pitop/display/display.py +++ b/packages/display/pitop/display/display.py @@ -18,29 +18,21 @@ def __init__(self): def __setup_subscribe_client(self): def on_brightness_changed(parameters): - self.__ptdm_subscribe_client.invoke_callback_func_if_exists( + self.__ptdm_subscribe_client.invoke_callback( self.when_brightness_changed, parameters[0] ) def on_screen_blanked(): - self.__ptdm_subscribe_client.invoke_callback_func_if_exists( - self.when_screen_blanked - ) + self.__ptdm_subscribe_client.invoke_callback(self.when_screen_blanked) def on_screen_unblanked(): - self.__ptdm_subscribe_client.invoke_callback_func_if_exists( - self.when_screen_unblanked - ) + self.__ptdm_subscribe_client.invoke_callback(self.when_screen_unblanked) def on_lid_closed(): - self.__ptdm_subscribe_client.invoke_callback_func_if_exists( - self.when_lid_closed - ) + self.__ptdm_subscribe_client.invoke_callback(self.when_lid_closed) def on_lid_opened(): - self.__ptdm_subscribe_client.invoke_callback_func_if_exists( - self.when_lid_opened - ) + self.__ptdm_subscribe_client.invoke_callback(self.when_lid_opened) self.__ptdm_subscribe_client = PTDMSubscribeClient() self.__ptdm_subscribe_client.initialise( diff --git a/packages/miniscreen/pitop/miniscreen/miniscreen.py b/packages/miniscreen/pitop/miniscreen/miniscreen.py index 2609039dd..672643c6d 100644 --- a/packages/miniscreen/pitop/miniscreen/miniscreen.py +++ b/packages/miniscreen/pitop/miniscreen/miniscreen.py @@ -31,9 +31,11 @@ def __init__(self): def __setup_subscribe_client(self): def set_button_state(button, pressed): button.is_pressed = pressed - self.__ptdm_subscribe_client.invoke_callback_func_if_exists( + callback = ( button.when_pressed if button.is_pressed else button.when_released ) + if callable(callback): + callback() self.__ptdm_subscribe_client = PTDMSubscribeClient() self.__ptdm_subscribe_client.initialise( diff --git a/tests/test_ptdm.py b/tests/test_ptdm.py index 2278f2177..b5f3f8ddd 100644 --- a/tests/test_ptdm.py +++ b/tests/test_ptdm.py @@ -1,4 +1,4 @@ -from unittest import TestCase, skip +from unittest import TestCase from unittest.mock import Mock, patch from tests.utils import wait_until @@ -19,40 +19,52 @@ def setUp(self): self.poller_mock.poll.return_value = [] self.addCleanup(self.zmq_patch.stop) - @skip - def test_callback_called_when_message_is_published(self): + def test_correct_callback_called_when_message_is_published(self): from pitop.common.ptdm import Message, PTDMSubscribeClient - self.poller_mock.poll.side_effect = ( - lambda _: [1] if self.poller_mock.poll.call_count == 1 else [] - ) - self.socket_mock.recv_string.return_value = f"{Message.PUB_LOW_BATTERY_WARNING}" + self.poller_mock.poll.return_value = [2] - def callback(): - callback.counter += 1 + def callback_without_args(): + callback_without_args.counter += 1 - callback.counter = 0 + callback_without_args.counter = 0 + + def callback_with_args(): + callback_with_args.counter += 1 + + callback_with_args.counter = 0 client = PTDMSubscribeClient() client.initialise( { - Message.PUB_LOW_BATTERY_WARNING: callback, + Message.PUB_LOW_BATTERY_WARNING: callback_without_args, + Message.PUB_BRIGHTNESS_CHANGED: callback_with_args, } ) client.start_listening() - assert callback.counter == 0 - wait_until(lambda: callback.counter > 10, timeout=10) - assert callback.counter == 1 + + assert callback_without_args.counter == 0 + assert callback_with_args.counter == 0 + + # Emit event that doesn't use an argument + self.socket_mock.recv_string.return_value = f"{Message.PUB_LOW_BATTERY_WARNING}" + wait_until(lambda: callback_without_args.counter > 0, timeout=5) + assert callback_with_args.counter == 0 + + # Emit event that uses an argument + self.socket_mock.recv_string.return_value = ( + f"{Message.PUB_BRIGHTNESS_CHANGED}|1" + ) + wait_until(lambda: callback_with_args.counter > 0, timeout=5) + client.stop_listening() - def test_callback_not_called_if_it_has_wrong_signature(self): + def test_callback_not_included_if_has_wrong_signature(self): from pitop.common.ptdm import Message, PTDMSubscribeClient - self.poller_mock.poll.side_effect = ( - lambda _: [1] if self.poller_mock.poll.call_count == 1 else [] - ) self.socket_mock.recv_string.return_value = f"{Message.PUB_LOW_BATTERY_WARNING}" + # Callback should have only 1 argument def callback(x, y): callback.counter += 1 @@ -64,15 +76,15 @@ def callback(x, y): Message.PUB_LOW_BATTERY_WARNING: callback, } ) - client.start_listening() - assert callback.counter == 0 - wait_until(lambda: self.poller_mock.poll.call_count > 10, timeout=5) - assert callback.counter == 0 - client.stop_listening() + + # Callback wasn't saved + assert client._callback_funcs.get(Message.PUB_LOW_BATTERY_WARNING) is None def test_subscribe_client_cleanup_closes_socket(self): from pitop.common.ptdm import Message, PTDMSubscribeClient + self.socket_mock.recv_string.return_value = f"{Message.PUB_LOW_BATTERY_WARNING}" + client = PTDMSubscribeClient() client.initialise( { @@ -94,15 +106,12 @@ class PTDMRequestClientTestCase(TestCase): def setUp(self): self.zmq_patch = patch("pitop.common.ptdm.zmq") self.zmq_mock = self.zmq_patch.start() - self.poller_mock = Mock() self.context_mock = Mock() self.socket_mock = Mock() self.socket_mock.recv_string.return_value = "" self.context_mock.socket.return_value = self.socket_mock self.zmq_mock.Context.return_value = self.context_mock - self.zmq_mock.Poller.return_value = self.poller_mock - self.poller_mock.poll.return_value = [] self.addCleanup(self.zmq_patch.stop) def test_uri(self):