From 9cff16892b610309ee600e2a6d420cd31fff7361 Mon Sep 17 00:00:00 2001
From: SevenQuiches <2859601962@qq.com>
Date: Wed, 18 Mar 2026 13:23:54 +0800
Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E6=96=B0=E5=8A=9F=E8=83=BD?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 adata/__init__.py                  |  41 ++++++++
 adata/common/utils/rate_limiter.py | 179 +++++++++++++++++++++++++++++
 adata/common/utils/sunrequests.py  |  14 ++-
 tests/rate_limiter_test.py         |  97 +++++++++++++++++++
 4 files changed, 328 insertions(+), 3 deletions(-)
 create mode 100644 adata/common/utils/rate_limiter.py
 create mode 100644 tests/rate_limiter_test.py

diff --git a/adata/__init__.py b/adata/__init__.py
index dee08e2..0f5e705 100644
--- a/adata/__init__.py
+++ b/adata/__init__.py
@@ -11,6 +11,7 @@
 from adata.__version__ import __version__
 from adata.bond import bond
 from adata.common.utils.sunrequests import SunProxy
+from adata.common.utils.rate_limiter import rate_limiter
 from adata.fund import fund
 from adata.sentiment import sentiment
 from adata.stock import stock
@@ -33,6 +34,46 @@ def proxy(is_proxy=False, ip: str = None, proxy_url: str = None):
     return
 
 
+def set_rate_limit(domain: str = None, limit: int = 30):
+    """
+    Set the request rate limit.
+    :param domain: domain name, e.g. 'push2his.eastmoney.com'; when None the default limit is set
+    :param limit: maximum number of requests per minute, 30 by default
+
+    Examples:
+        # limit every domain without its own setting to 30 requests per minute
+        adata.set_rate_limit(limit=30)
+
+        # limit one domain to 60 requests per minute
+        adata.set_rate_limit(domain='push2his.eastmoney.com', limit=60)
+
+        # inspect the current rate limit status
+        status = adata.get_rate_limit_status()
+        print(status)
+    """
+    if domain:
+        rate_limiter.set_limit(domain, limit)
+    else:
+        rate_limiter.set_default_limit(limit)
+    return
+
+
+def get_rate_limit_status(url: str = None):
+    """
+    Get the rate limit status.
+    :param url: optional URL whose domain should be reported; when None every domain is reported
+    :return: status dictionary
+
+    Examples:
+        # status of every domain
+        status = adata.get_rate_limit_status()
+
+        # status of the domain of one URL
+        status = adata.get_rate_limit_status('http://push2his.eastmoney.com/api/xxx')
+    """
+    return rate_limiter.get_status(url)
+
+
 # set up logging
 logger = logging.getLogger("adata")
 
diff --git a/adata/common/utils/rate_limiter.py b/adata/common/utils/rate_limiter.py
new file mode 100644
index 0000000..418299d
--- /dev/null
+++ b/adata/common/utils/rate_limiter.py
@@ -0,0 +1,179 @@
+# -*- coding: utf-8 -*-
+"""
+@desc: request rate limiter
+@author: 1nchaos
+@time: 2026/3/18
+@log: throttle outgoing requests per domain
+"""
+
+import threading
+import time
+from collections import defaultdict, deque
+from urllib.parse import urlparse
+
+# Length in seconds of the sliding window that every limit applies to.
+_WINDOW_SECONDS = 60
+
+
+class RateLimiter:
+    """
+    Request rate limiter (process-wide singleton).
+
+    Requests are counted per domain over a sliding 60-second window. Each
+    domain can have its own limit; domains without one use the default limit.
+    """
+    _instance = None
+    _instance_lock = threading.Lock()
+
+    def __new__(cls, *args, **kwargs):
+        # Check, lock, check again: concurrent first calls build one instance.
+        if cls._instance is None:
+            with cls._instance_lock:
+                if cls._instance is None:
+                    cls._instance = object.__new__(cls)
+        return cls._instance
+
+    def __init__(self, default_limit: int = 30):
+        """
+        Initialise the limiter.
+
+        Only the first construction takes effect: later RateLimiter(...) calls
+        return the same object and leave its configuration untouched.
+        :param default_limit: default maximum number of requests per minute
+        :raises ValueError: if default_limit is not a positive integer
+        """
+        # __init__ runs on every RateLimiter() call; skip the repeated set-up.
+        if getattr(self, '_initialized', False):
+            return
+
+        self._default_limit = self._check_limit(default_limit)
+        # Per-domain overrides: {domain: limit}
+        self._limits = {}
+        # Request times still inside the window: {domain: deque([ts, ...])}
+        self._records = defaultdict(deque)
+        self._lock = threading.Lock()
+        self._initialized = True
+
+    @staticmethod
+    def _check_limit(limit: int) -> int:
+        """
+        Validate a requests-per-minute limit.
+        :param limit: candidate limit
+        :return: the limit, unchanged
+        :raises ValueError: if limit is not a positive integer
+        """
+        if isinstance(limit, bool) or not isinstance(limit, int) or limit < 1:
+            raise ValueError(f"limit must be a positive integer, got {limit!r}")
+        return limit
+
+    def set_limit(self, domain: str, limit: int):
+        """
+        Set the rate limit of one domain.
+        :param domain: domain name, e.g. 'push2his.eastmoney.com'
+        :param limit: maximum number of requests per minute
+        :raises ValueError: if limit is not a positive integer
+        """
+        limit = self._check_limit(limit)
+        with self._lock:
+            self._limits[domain] = limit
+
+    def set_default_limit(self, limit: int):
+        """
+        Set the default rate limit.
+        :param limit: maximum number of requests per minute
+        :raises ValueError: if limit is not a positive integer
+        """
+        limit = self._check_limit(limit)
+        with self._lock:
+            self._default_limit = limit
+
+    def get_limit(self, domain: str) -> int:
+        """
+        Get the rate limit that applies to a domain.
+        :param domain: domain name
+        :return: maximum number of requests per minute
+        """
+        # Deliberately lock-free: this is also called while self._lock is held,
+        # and threading.Lock is not re-entrant.
+        return self._limits.get(domain, self._default_limit)
+
+    def _extract_domain(self, url: str) -> str:
+        """
+        Extract the domain from a URL.
+        :param url: full URL; 'host/path' without a scheme is accepted as well
+        :return: domain name
+        """
+        parsed = urlparse(url)
+        # Without a scheme urlparse leaves netloc empty and puts everything in
+        # path, so fall back to the first path segment.
+        return parsed.netloc or parsed.path.split('/')[0]
+
+    def _prune(self, domain: str, now: float) -> deque:
+        """
+        Drop the request times of a domain that have left the window.
+        The caller must hold self._lock.
+        :param domain: domain name
+        :param now: current time.monotonic() value
+        :return: the remaining request times of the domain, oldest first
+        """
+        stamps = self._records[domain]
+        while stamps and now - stamps[0] >= _WINDOW_SECONDS:
+            stamps.popleft()
+        return stamps
+
+    def acquire(self, url: str):
+        """
+        Wait until a request to the URL's domain is allowed, then record it.
+        :param url: URL about to be requested
+        """
+        domain = self._extract_domain(url)
+        while True:
+            with self._lock:
+                now = time.monotonic()
+                stamps = self._prune(domain, now)
+                if len(stamps) < self.get_limit(domain):
+                    stamps.append(now)
+                    return
+                # Window is full: a slot frees up when the oldest request expires.
+                wait_time = _WINDOW_SECONDS - (now - stamps[0])
+            # Sleep only after releasing the lock (threading.Lock is not
+            # re-entrant and other threads must not be stalled), then check
+            # again because another thread may have taken the freed slot.
+            time.sleep(max(wait_time, 0))
+
+    def _domain_status(self, domain: str, now: float) -> dict:
+        """
+        Summarise the usage of one domain. The caller must hold self._lock.
+        :param domain: domain name
+        :param now: current time.monotonic() value
+        :return: dict with 'limit', 'current_count' and 'remaining'
+        """
+        # .get() instead of [] so that a status query never creates a record.
+        stamps = self._records.get(domain, ())
+        current_count = sum(1 for ts in stamps if now - ts < _WINDOW_SECONDS)
+        limit = self.get_limit(domain)
+        return {
+            'limit': limit,
+            'current_count': current_count,
+            'remaining': max(0, limit - current_count)
+        }
+
+    def get_status(self, url: str = None) -> dict:
+        """
+        Get the rate limit status.
+        :param url: optional URL; when given only its domain is reported
+        :return: status of that domain (with an extra 'domain' key), or
+            {domain: status} for every domain requested so far
+        """
+        now = time.monotonic()
+        with self._lock:
+            if url:
+                domain = self._extract_domain(url)
+                status = {'domain': domain}
+                status.update(self._domain_status(domain, now))
+                return status
+            return {domain: self._domain_status(domain, now) for domain in self._records}
+
+
+# Global singleton instance
+rate_limiter = RateLimiter()
diff --git a/adata/common/utils/sunrequests.py b/adata/common/utils/sunrequests.py
index 9fec0ed..9894eee 100644
--- a/adata/common/utils/sunrequests.py
+++ b/adata/common/utils/sunrequests.py
@@ -13,6 +13,8 @@
 
 import requests
 
+from adata.common.utils.rate_limiter import rate_limiter
+
 
 class SunProxy(object):
     _data = {}
@@ -46,7 +48,8 @@ def __init__(self, sun_proxy: SunProxy = None) -> None:
         super().__init__()
         self.sun_proxy = sun_proxy
 
-    def request(self, method='get', url=None, times=3, retry_wait_time=1588, proxies=None, wait_time=None, **kwargs):
+    def request(self, method='get', url=None, times=3, retry_wait_time=1588, proxies=None, wait_time=None,
+                rate_limit=True, **kwargs):
         """
         简单封装的请求,参考requests,增加循环次数和次数之间的等待时间
         :param proxies: 代理配置
@@ -55,12 +58,17 @@ def request(self, method='get', url=None, times=3, retry_wait_time=1588, proxies
         :param times: 次数,int
         :param retry_wait_time: 重试等待时间,毫秒
         :param wait_time: 等待时间:毫秒;表示每个请求的间隔时间,在请求之前等待sleep,主要用于防止请求太频繁的限制。
+        :param rate_limit: whether to apply the per-domain rate limit, True by default
         :param kwargs: 其它 requests 参数,用法相同
         :return: res
         """
-        # 1. 获取设置代理
+        # 1. Rate limit check
+        if rate_limit and url:
+            rate_limiter.acquire(url)
+
+        # 2. Resolve proxy settings
         proxies = self.__get_proxies(proxies)
-        # 2. 请求数据结果
+        # 3. Request the data
         res = None
         for i in range(times):
             if wait_time:
diff --git a/tests/rate_limiter_test.py b/tests/rate_limiter_test.py
new file mode 100644
index 0000000..84a3d2c
--- /dev/null
+++ b/tests/rate_limiter_test.py
@@ -0,0 +1,97 @@
+# -*- coding: utf-8 -*-
+"""
+@desc: rate limiter tests
+@author: 1nchaos
+@time: 2026/3/18
+"""
+
+import importlib.util
+import os
+import sys
+import time
+
+# Locate the project from this file instead of a machine-specific absolute path.
+PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+sys.path.insert(0, PROJECT_ROOT)
+
+# Load the rate_limiter module directly from its source file
+spec = importlib.util.spec_from_file_location(
+    "rate_limiter",
+    os.path.join(PROJECT_ROOT, 'adata', 'common', 'utils', 'rate_limiter.py'))
+rate_limiter_module = importlib.util.module_from_spec(spec)
+spec.loader.exec_module(rate_limiter_module)
+
+# Global singleton instance
+rate_limiter = rate_limiter_module.rate_limiter
+
+
+def test_rate_limiter():
+    """Exercise the configuration and status API of the rate limiter."""
+    print("=" * 50)
+    print("测试频率限制器功能")
+    print("=" * 50)
+
+    # Test 1: set the default rate limit
+    print("\n1. 设置默认频率限制为每分钟5次(测试用)")
+    rate_limiter.set_default_limit(5)
+    print(" ✓ 默认频率限制设置成功")
+
+    # Test 2: set per-domain rate limits
+    print("\n2. 设置特定域名频率限制")
+    rate_limiter.set_limit('push2his.eastmoney.com', 10)
+    rate_limiter.set_limit('query.sse.com.cn', 20)
+    print(" ✓ 特定域名频率限制设置成功")
+
+    # Test 3: inspect the status of every domain
+    print("\n3. 查看频率限制状态")
+    status = rate_limiter.get_status()
+    print(f" 当前状态: {status}")
+
+    # Test 4: inspect the status of the domain of one URL
+    print("\n4. 查看特定URL的频率限制状态")
+    url_status = rate_limiter.get_status('http://push2his.eastmoney.com/api/qt/stock/kline/get')
+    print(f" URL状态: {url_status}")
+
+    print("\n" + "=" * 50)
+    print("频率限制器功能测试完成")
+    print("=" * 50)
+
+
+def test_rate_limit_acquire():
+    """Exercise the blocking behaviour of acquire()."""
+    print("\n" + "=" * 50)
+    print("测试频率限制 acquire 功能")
+    print("=" * 50)
+
+    # RateLimiter is a singleton, so this is the module's global instance and
+    # the limit set below stays in effect after this test.
+    test_limiter = rate_limiter_module.RateLimiter()
+
+    # Use a low limit so the throttling is reached quickly
+    test_limiter.set_default_limit(3)
+    print("\n设置频率限制为每分钟3次")
+
+    print("\n开始测试 acquire...")
+    start_time = time.time()
+
+    # Send 4 requests; the 4th one should be throttled
+    for i in range(4):
+        req_start = time.time()
+        print(f"\n 请求 {i+1}:")
+        # acquire() blocks while the limit is exhausted
+        test_limiter.acquire('https://httpbin.org/get')
+        req_time = time.time() - req_start
+        print(f" 通过频率检查,耗时: {req_time:.2f}s")
+
+    total_time = time.time() - start_time
+    print(f"\n总耗时: {total_time:.2f}s (如果超过60s说明频率限制生效)")
+    print("\n" + "=" * 50)
+    print("频率限制 acquire 测试完成")
+    print("=" * 50)
+
+
+if __name__ == '__main__':
+    # Run the tests
+    test_rate_limiter()
+    # Disabled by default because it waits for about 60 seconds
+    # test_rate_limit_acquire()