import threading
import time
from collections import defaultdict
from urllib.parse import urlparse


class RateLimiter:
    """
    Per-domain sliding-window rate limiter.

    Each request host may issue at most ``rate`` requests within any
    ``period``-second window. Hosts without a matching custom rule use the
    defaults (30 requests per 60 seconds unless overridden).
    """

    def __init__(self, default_rate=30, default_period=60):
        """
        :param default_rate: maximum number of requests per window for hosts
            that have no custom rule
        :param default_period: window length in seconds, 60 by default
        """
        # Plain attributes on purpose: callers may reassign them directly.
        self.default_rate = default_rate
        self.default_period = default_period
        self._domain_rates = {}  # normalized domain -> (rate, period)
        self._request_history = defaultdict(list)  # host -> grant timestamps
        self._lock = threading.Lock()  # guards both dictionaries above

    @staticmethod
    def _normalize_domain(domain):
        """Return *domain* lower-cased, without padding or edge dots."""
        return domain.strip().lower().strip(".")

    def set_rate_limit(self, domain, rate, period=60):
        """
        Set a custom limit for a domain and all of its subdomains.

        :param domain: domain name, e.g. 'eastmoney.com'
        :param rate: maximum number of requests per window, must be positive
        :param period: window length in seconds, 60 by default
        :raises ValueError: if ``rate`` is not positive
        """
        if rate <= 0:
            # A non-positive rate can never grant a request; reject it here
            # instead of failing obscurely on the first acquire().
            raise ValueError(f"rate must be positive, got {rate!r}")
        with self._lock:
            self._domain_rates[self._normalize_domain(domain)] = (rate, period)

    def get_rate_limit(self, domain):
        """
        Look up the limit that applies to a host.

        An exact rule wins; otherwise the most specific parent-domain rule is
        used. Matching happens on whole DNS labels, so 'noteastmoney.com'
        does not inherit the rule for 'eastmoney.com', and a rule for a
        subdomain is never applied to its parent.

        :param domain: host name
        :return: (rate, period)
        """
        candidate = self._normalize_domain(domain)
        with self._lock:
            # Walk from the full host towards the TLD: 'a.b.com' -> 'b.com'
            # -> 'com'. The first hit is therefore the most specific rule.
            while True:
                if candidate in self._domain_rates:
                    return self._domain_rates[candidate]
                _, dot, candidate = candidate.partition(".")
                if not dot:
                    break
        return self.default_rate, self.default_period

    def _extract_domain(self, url):
        """
        Extract the lower-cased host name from a URL.

        Port and userinfo are dropped so that 'http://user@host:8080/' is
        accounted for as 'host'. Falls back to the lower-cased input when the
        URL cannot be parsed.
        """
        if isinstance(url, bytes):
            url = url.decode("utf-8", "replace")
        try:
            parsed = urlparse(url)
            host = parsed.hostname or parsed.netloc
        except ValueError:
            # urlparse rejects e.g. malformed IPv6 literals.
            host = url
        return self._normalize_domain(host)

    def acquire(self, url):
        """
        Obtain permission for one request, waiting if the host is throttled.

        The lock is only held while the history is inspected or updated; the
        wait itself happens outside the lock so that other hosts (and
        ``reset``) are never blocked by a throttled one.

        :param url: URL about to be requested
        :raises ValueError: if the applicable rate is not positive
        """
        domain = self._extract_domain(url)
        while True:
            # Re-read the rule every round so changes made while waiting apply.
            rate, period = self.get_rate_limit(domain)
            if rate <= 0:
                raise ValueError(f"rate must be positive, got {rate!r}")

            with self._lock:
                # Monotonic clock: immune to wall-clock adjustments.
                now = time.monotonic()
                cutoff = now - period
                recent = [t for t in self._request_history[domain] if t > cutoff]
                self._request_history[domain] = recent
                if len(recent) < rate:
                    recent.append(now)
                    return
                # Entries are appended in time order, so recent[0] is the
                # oldest grant; a slot frees up once it leaves the window.
                wait_time = recent[0] + period - now

            time.sleep(max(wait_time, 0))

    def reset(self, domain=None):
        """
        Forget recorded requests.

        :param domain: host to reset; if None (or empty) every host is reset
        """
        with self._lock:
            if domain:
                self._request_history.pop(self._normalize_domain(domain), None)
            else:
                self._request_history.clear()
请求数据结果 res = None for i in range(times): if wait_time: @@ -79,14 +213,19 @@ def __get_proxies(self, proxies): """ if proxies is None: proxies = {} - is_proxy = SunProxy.get('is_proxy') - ip = SunProxy.get('ip') - proxy_url = SunProxy.get('proxy_url') + is_proxy = SunProxy.get("is_proxy") + ip = SunProxy.get("ip") + proxy_url = SunProxy.get("proxy_url") if not ip and is_proxy and proxy_url: - ip = requests.get(url=proxy_url).text.replace('\r\n', '') \ - .replace('\r', '').replace('\n', '').replace('\t', '') + ip = ( + requests.get(url=proxy_url) + .text.replace("\r\n", "") + .replace("\r", "") + .replace("\n", "") + .replace("\t", "") + ) if is_proxy and ip: - proxies = {'https': f"http://{ip}", 'http': f"http://{ip}"} + proxies = {"https": f"http://{ip}", "http": f"http://{ip}"} return proxies