diff --git a/adata/common/utils/sunrequests.py b/adata/common/utils/sunrequests.py index 9fec0ed..a320016 100644 --- a/adata/common/utils/sunrequests.py +++ b/adata/common/utils/sunrequests.py @@ -10,6 +10,7 @@ import threading import time +from urllib.parse import urlparse import requests @@ -41,10 +42,69 @@ def delete(cls, key): del cls._data[key] +class RateLimiter: + _instance_lock = threading.Lock() + _instance = None + _default_limit = 30 + _window_seconds = 60 + _domain_limits = {} + _request_history = {} + + def __new__(cls, *args, **kwargs): + if cls._instance is None: + with cls._instance_lock: + if cls._instance is None: + cls._instance = object.__new__(cls) + return cls._instance + + @classmethod + def set_default_limit(cls, limit): + cls._default_limit = limit + + @classmethod + def set_domain_limit(cls, domain, limit): + cls._domain_limits[domain] = limit + + @classmethod + def get_domain_limit(cls, domain): + return cls._domain_limits.get(domain, cls._default_limit) + + @classmethod + def _clean_old_requests(cls, domain, current_time): + if domain not in cls._request_history: + cls._request_history[domain] = [] + cls._request_history[domain] = [ + t for t in cls._request_history[domain] + if current_time - t < cls._window_seconds + ] + + @classmethod + def acquire(cls, url): + parsed = urlparse(url) + domain = parsed.netloc + current_time = time.time() + + with cls._instance_lock: + cls._clean_old_requests(domain, current_time) + limit = cls.get_domain_limit(domain) + + if len(cls._request_history[domain]) >= limit: + oldest_request = cls._request_history[domain][0] + wait_time = cls._window_seconds - (current_time - oldest_request) + if wait_time > 0: + time.sleep(wait_time) + current_time = time.time() + cls._clean_old_requests(domain, current_time) + + cls._request_history[domain].append(current_time) + return True + + class SunRequests(object): def __init__(self, sun_proxy: SunProxy = None) -> None: super().__init__() self.sun_proxy = sun_proxy + self._rate_limiter = RateLimiter() def request(self, method='get', url=None, times=3, retry_wait_time=1588, proxies=None, wait_time=None, **kwargs): """ @@ -58,9 +118,10 @@ def request(self, method='get', url=None, times=3, retry_wait_time=1588, proxies :param kwargs: 其它 requests 参数,用法相同 :return: res """ - # 1. 获取设置代理 + if url: + self._rate_limiter.acquire(url) + proxies = self.__get_proxies(proxies) - # 2. 请求数据结果 res = None for i in range(times): if wait_time: @@ -74,9 +135,6 @@ def request(self, method='get', url=None, times=3, retry_wait_time=1588, proxies return res def __get_proxies(self, proxies): - """ - 获取代理配置 - """ if proxies is None: proxies = {} is_proxy = SunProxy.get('is_proxy') @@ -90,4 +148,12 @@ def __get_proxies(self, proxies): return proxies +def set_default_rate_limit(limit): + RateLimiter.set_default_limit(limit) + + +def set_domain_rate_limit(domain, limit): + RateLimiter.set_domain_limit(domain, limit) + + sun_requests = SunRequests()