diff --git a/redis_watcher/watcher.py b/redis_watcher/watcher.py index 4c049b4..b4b519b 100644 --- a/redis_watcher/watcher.py +++ b/redis_watcher/watcher.py @@ -15,24 +15,39 @@ import json import logging from threading import Thread, Lock, Event +import time from casbin.model import Model from redis.client import Redis, PubSub +from redis.backoff import ExponentialBackoff +from redis.retry import Retry as RedisRetry from redis_watcher.options import WatcherOptions class RedisWatcher: - def __init__(self): + def __init__(self, logger=None): self.mutex: Lock = Lock() self.sub_client: PubSub = None self.pub_client: Redis = None self.options: WatcherOptions = None self.close = None + self.sleep = 0 + self.execute_update = False self.callback: callable = None self.subscribe_thread: Thread = Thread(target=self.subscribe, daemon=True) self.subscribe_event = Event() - self.logger = logging.getLogger(__name__) + + self.logger = logger if logger else logging.getLogger(__name__) + + def recreate_thread(self): + self.sleep = 10 + self.execute_update = True + self.subscribe_thread: Thread = Thread(target=self.subscribe, daemon=True) + self.subscribe_event = Event() + self.close = False + self.subscribe_thread.start() + self.subscribe_event.wait(timeout=1) def init_config(self, option: WatcherOptions): if option.optional_update_callback: @@ -47,6 +62,51 @@ def set_update_callback(self, callback: callable): with self.mutex: self.callback = callback + def _get_redis_conn(self): + """ + Creates a new redis connection instance + """ + rds = Redis( + host=self.options.host, + port=self.options.port, + password=self.options.password, + ssl=self.options.ssl, + retry=RedisRetry(ExponentialBackoff(), 3), + ) + return rds + + def init_publisher_subscriber(self, init_pub=True, init_sub=True): + """ + Initialize the publisher and subscriber subscribers + NOTE: A new Redis connection is created for the publisher and subscriber because since Redis5 + the connection needs to be created by thread + Args: + init_pub (bool, optional): Whether to initialize the publisher subscriber. Defaults to True. + init_sub (bool, optional): Whether to initialize the publisher subscriber. Defaults to True. + """ + try: + if init_pub: + rds = self._get_redis_conn() + if not rds.ping(): + raise Exception("Redis not responding.") + self.pub_client = rds.client() + + if init_sub: + rds = self._get_redis_conn() + if not rds.ping(): + raise Exception("Redis not responding.") + self.sub_client = rds.client().pubsub() + except Exception as e: + if self.pub_client: + self.pub_client.close() + if self.sub_client: + self.sub_client.close() + self.pub_client = None + self.sub_client = None + print( + f"Casbin Redis Watcher error: {e}. Publisher/Subscriber failed to be initialized {self.options.local_ID}" + ) + def update(self): def func(): with self.mutex: @@ -103,12 +163,16 @@ def func(): def default_callback_func(msg: str): print("callback: " + msg) - @staticmethod - def log_record(f: callable): + def log_record(self, f: callable): try: + if not self.pub_client: + rds = self._get_redis_conn() + self.pub_client = rds.client() result = f() except Exception as e: - print(f"Casbin Redis Watcher error: {e}") + if self.pub_client: + self.pub_client.close() + print(f"Casbin Redis Watcher error: {e}. Publisher failure on the worker {self.options.local_ID}") else: return result @@ -117,13 +181,64 @@ def unsubscribe(psc: PubSub): return psc.unsubscribe() def subscribe(self): - self.sub_client.subscribe(self.options.channel) - for item in self.sub_client.listen(): - if not self.subscribe_event.is_set(): - self.subscribe_event.set() - if item is not None and item["type"] == "message": - with self.mutex: - self.callback(str(item)) + time.sleep(self.sleep) + try: + if not self.sub_client: + rds = self._get_redis_conn() + self.sub_client = rds.client().pubsub() + self.sub_client.subscribe(self.options.channel) + print(f"Waiting for casbin updates... in the worker: {self.options.local_ID}") + if self.execute_update: + self.update() + try: + for item in self.sub_client.listen(): + if not self.subscribe_event.is_set(): + self.subscribe_event.set() + if item is not None and item["type"] == "message": + try: + with self.mutex: + self.callback(str(item)) + except Exception as listen_exc: + print( + "Casbin Redis watcher failed sending update to teh callback function " + " process due to: {}".format(str(listen_exc)) + ) + if self.sub_client: + self.sub_client.close() + break + except Exception as sub_exc: + print("Casbin Redis watcher failed to get message from redis due to {}".format(str(sub_exc))) + if self.sub_client: + self.sub_client.close() + except Exception as redis_exc: + print("Casbin Redis watcher failed to subscribe due to: {}".format(str(redis_exc))) + finally: + if self.sub_client: + self.sub_client.close() + + def should_reload(self, recreate=True): + """ + Checks is the thread and event are still alive, if they are not they are recreated. + If they were recreated the watcher should reload the policies. + Args: + recreate(bool): recreates the thread if it's dead for redis timeouts + """ + try: + if self.subscribe_thread.is_alive() and self.subscribe_event.is_set(): + return False + else: + if recreate and not self.subscribe_thread.is_alive(): + print(f"Casbin Redis Watcher will be recreated for the worker {self.options.local_ID} in 10 secs.") + self.recreate_thread() + return True + except Exception: + return True + + def update_callback(self): + """ + This method was created to cover the function that flask_authz calls + """ + self.update() class MSG: @@ -140,18 +255,15 @@ def marshal_binary(self): @staticmethod def unmarshal_binary(data: bytes): loaded = json.loads(data) + loaded.pop("params", None) return MSG(**loaded) -def new_watcher(option: WatcherOptions): +def new_watcher(option: WatcherOptions, logger=None): option.init_config() - w = RedisWatcher() - rds = Redis(host=option.host, port=option.port, password=option.password, ssl=option.ssl) - if rds.ping() is False: - raise Exception("Redis server is not available.") - w.sub_client = rds.client().pubsub() - w.pub_client = rds.client() + w = RedisWatcher(logger) w.init_config(option) + w.init_publisher_subscriber() w.close = False w.subscribe_thread.start() w.subscribe_event.wait(timeout=5) @@ -161,10 +273,7 @@ def new_watcher(option: WatcherOptions): def new_publish_watcher(option: WatcherOptions): option.init_config() w = RedisWatcher() - rds = Redis(host=option.host, port=option.port, password=option.password, ssl=option.ssl) - if rds.ping() is False: - raise Exception("Redis server is not available.") - w.pub_client = rds.client() w.init_config(option) + w.init_publisher_subscriber(init_sub=False) w.close = False return w diff --git a/requirements.txt b/requirements.txt index 3c34b5b..7551798 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,2 @@ casbin~=1.18 -redis==4.5.2 +redis