Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions adata/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from adata.__version__ import __version__
from adata.bond import bond
from adata.common.utils.sunrequests import SunProxy
from adata.common.utils.rate_limiter import rate_limiter
from adata.fund import fund
from adata.sentiment import sentiment
from adata.stock import stock
Expand All @@ -33,6 +34,46 @@ def proxy(is_proxy=False, ip: str = None, proxy_url: str = None):
return


def set_rate_limit(domain: str = None, limit: int = 30):
    """
    Configure the request rate limit.

    :param domain: domain name, e.g. 'push2his.eastmoney.com'; when it is None
        (or empty) the default limit is changed instead of a per-domain one
    :param limit: maximum number of requests per minute, 30 by default

    Usage:
        # set the default limit to 30 requests per minute
        adata.set_rate_limit(limit=30)

        # set the limit of one specific domain to 60 requests per minute
        adata.set_rate_limit(domain='push2his.eastmoney.com', limit=60)

        # inspect the current rate-limit state
        status = adata.get_rate_limit_status()
        print(status)
    """
    # No domain given: the value applies to every domain without its own limit.
    if not domain:
        rate_limiter.set_default_limit(limit)
        return
    rate_limiter.set_limit(domain, limit)


def get_rate_limit_status(url: str = None):
    """
    Report the current rate-limit state.

    :param url: optional URL; when given, only the state of that URL's domain
        is returned, when None the state of all domains is returned
    :return: dictionary with the state information

    Usage:
        # state of every domain
        status = adata.get_rate_limit_status()

        # state of the domain of one specific URL
        status = adata.get_rate_limit_status('http://push2his.eastmoney.com/api/xxx')
    """
    # Pure delegation to the shared limiter instance.
    status = rate_limiter.get_status(url)
    return status


# set up logging
logger = logging.getLogger("adata")

Expand Down
146 changes: 146 additions & 0 deletions adata/common/utils/rate_limiter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
# -*- coding: utf-8 -*-
"""
@desc: 请求频率限制器
@author: 1nchaos
@time: 2026/3/18
@log: 基于域名控制请求频率
"""

import threading
import time
from collections import defaultdict
from urllib.parse import urlparse


class RateLimiter:
    """
    Per-domain request rate limiter.

    Limits how many requests may be issued to each domain within a sliding
    time window (``_WINDOW_SECONDS``, 60 seconds).  Every domain may get its own
    limit; domains without one fall back to the default limit.

    The class is a singleton: every ``RateLimiter()`` call returns the same
    object, and only the first construction applies ``default_limit``.
    """
    _instance = None
    _instance_lock = threading.Lock()
    # Length of the sliding window (seconds) over which requests are counted.
    _WINDOW_SECONDS = 60

    def __new__(cls, *args, **kwargs):
        # Check, lock, check again: the lock is only taken until the instance exists.
        if cls._instance is None:
            with cls._instance_lock:
                if cls._instance is None:
                    cls._instance = object.__new__(cls)
        return cls._instance

    def __init__(self, default_limit: int = 30):
        """
        Initialise the limiter (only the first call has any effect).

        :param default_limit: default maximum number of requests per window
        :raises ValueError: if ``default_limit`` is smaller than 1
        """
        # __init__ runs on every RateLimiter() call; the lock keeps two threads
        # from initialising (and thereby resetting) the shared state at once.
        with self._instance_lock:
            if getattr(self, '_initialized', False):
                return

            self._default_limit = self._check_limit(default_limit)
            # Per-domain limit configuration: {domain: limit}
            self._limits = {}
            # Per-domain request timestamps (time.monotonic()), oldest first
            self._records = defaultdict(list)
            self._lock = threading.Lock()
            self._initialized = True

    @staticmethod
    def _check_limit(limit: int) -> int:
        """Return ``limit`` unchanged, rejecting values below 1."""
        # With a limit below 1 no request could ever be admitted; the original
        # code crashed with IndexError on an empty record list in that case.
        if limit < 1:
            raise ValueError("rate limit must be at least 1, got %r" % (limit,))
        return limit

    def set_limit(self, domain: str, limit: int):
        """
        Set the limit of one domain.

        :param domain: domain name, e.g. 'push2his.eastmoney.com'
        :param limit: maximum number of requests per window
        :raises ValueError: if ``limit`` is smaller than 1
        """
        limit = self._check_limit(limit)
        with self._lock:
            self._limits[domain] = limit

    def set_default_limit(self, limit: int):
        """
        Set the limit used by domains that have no limit of their own.

        :param limit: maximum number of requests per window
        :raises ValueError: if ``limit`` is smaller than 1
        """
        limit = self._check_limit(limit)
        with self._lock:
            self._default_limit = limit

    def get_limit(self, domain: str) -> int:
        """
        Return the limit that applies to a domain.

        Deliberately does not take ``self._lock`` so it can be called from
        methods that already hold it.

        :param domain: domain name
        :return: maximum number of requests per window
        """
        return self._limits.get(domain, self._default_limit)

    def _extract_domain(self, url: str) -> str:
        """
        Extract the domain part of a URL.

        :param url: full URL; a scheme-less value such as 'host/path' is
            accepted, in which case the first path segment is used
        :return: the network location (``netloc``) of the URL
        """
        parsed = urlparse(url)
        return parsed.netloc or parsed.path.split('/')[0]

    def _prune(self, domain: str, now: float) -> list:
        """Drop timestamps older than the window; caller must hold the lock."""
        cutoff = now - self._WINDOW_SECONDS
        recent = [ts for ts in self._records[domain] if ts > cutoff]
        self._records[domain] = recent
        return recent

    def _status_for(self, domain: str, now: float) -> dict:
        """Build the status entry of one domain; caller must hold the lock."""
        cutoff = now - self._WINDOW_SECONDS
        # .get() instead of indexing: a pure query must not create an empty
        # entry in the defaultdict.
        current_count = sum(1 for ts in self._records.get(domain, ()) if ts > cutoff)
        limit = self.get_limit(domain)
        return {
            'limit': limit,
            'current_count': current_count,
            'remaining': max(0, limit - current_count)
        }

    def acquire(self, url: str):
        """
        Obtain permission for one request, blocking while the domain of
        ``url`` is at its limit.

        :param url: URL about to be requested
        """
        domain = self._extract_domain(url)

        while True:
            with self._lock:
                now = time.monotonic()
                recent = self._prune(domain, now)
                # Re-read the limit on every pass so a change made while this
                # thread was waiting takes effect.
                if len(recent) < self.get_limit(domain):
                    recent.append(now)
                    return
                # A slot frees up when the oldest recorded request leaves the window.
                wait_time = self._WINDOW_SECONDS - (now - recent[0])

            # Sleep with the lock released: the lock is not reentrant, and
            # holding it would stall every other domain and get_status().
            if wait_time > 0:
                time.sleep(wait_time)

    def get_status(self, url: str = None) -> dict:
        """
        Report the current rate-limit state.

        :param url: optional URL; restricts the report to that URL's domain
        :return: for a URL, a dict with 'domain', 'limit', 'current_count' and
            'remaining'; otherwise ``{domain: {...}}`` for every domain that
            has recorded requests
        """
        with self._lock:
            now = time.monotonic()
            if url:
                domain = self._extract_domain(url)
                status = {'domain': domain}
                status.update(self._status_for(domain, now))
                return status
            # State of all domains
            return {domain: self._status_for(domain, now) for domain in self._records}


# Module-level shared instance; RateLimiter is a singleton, so any later
# RateLimiter() call returns this same object.
rate_limiter = RateLimiter()
14 changes: 11 additions & 3 deletions adata/common/utils/sunrequests.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@

import requests

from adata.common.utils.rate_limiter import rate_limiter


class SunProxy(object):
_data = {}
Expand Down Expand Up @@ -46,7 +48,8 @@ def __init__(self, sun_proxy: SunProxy = None) -> None:
super().__init__()
self.sun_proxy = sun_proxy

def request(self, method='get', url=None, times=3, retry_wait_time=1588, proxies=None, wait_time=None, **kwargs):
def request(self, method='get', url=None, times=3, retry_wait_time=1588, proxies=None, wait_time=None,
rate_limit=True, **kwargs):
"""
简单封装的请求,参考requests,增加循环次数和次数之间的等待时间
:param proxies: 代理配置
Expand All @@ -55,12 +58,17 @@ def request(self, method='get', url=None, times=3, retry_wait_time=1588, proxies
:param times: 次数,int
:param retry_wait_time: 重试等待时间,毫秒
:param wait_time: 等待时间:毫秒;表示每个请求的间隔时间,在请求之前等待sleep,主要用于防止请求太频繁的限制。
:param rate_limit: 是否启用频率限制,默认True
:param kwargs: 其它 requests 参数,用法相同
:return: res
"""
# 1. 获取设置代理
# 1. 频率限制检查
if rate_limit and url:
rate_limiter.acquire(url)

# 2. 获取设置代理
proxies = self.__get_proxies(proxies)
# 2. 请求数据结果
# 3. 请求数据结果
res = None
for i in range(times):
if wait_time:
Expand Down
90 changes: 90 additions & 0 deletions tests/rate_limiter_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# -*- coding: utf-8 -*-
"""
@desc: 频率限制器测试
@author: 1nchaos
@time: 2026/3/18
"""

import importlib.util
import sys
import time
from pathlib import Path

# Repository root: this file lives in <root>/tests/, so go up two levels.
# Deriving it from __file__ keeps the test runnable on any machine instead of
# only where the checkout sits at a fixed absolute path.
_REPO_ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(_REPO_ROOT))

# Load the rate_limiter module directly from its source file
_RATE_LIMITER_PATH = _REPO_ROOT / "adata" / "common" / "utils" / "rate_limiter.py"
spec = importlib.util.spec_from_file_location("rate_limiter", str(_RATE_LIMITER_PATH))
rate_limiter_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(rate_limiter_module)

# The module-level singleton instance
rate_limiter = rate_limiter_module.rate_limiter


def test_rate_limiter():
    """Check limit configuration and status reporting of the rate limiter.

    Besides printing progress, the configured values are asserted so that the
    test actually fails when the limiter misbehaves.
    """
    print("=" * 50)
    print("测试频率限制器功能")
    print("=" * 50)

    # Test 1: set the default limit
    print("\n1. 设置默认频率限制为每分钟5次(测试用)")
    rate_limiter.set_default_limit(5)
    # A domain without its own limit must report the default.
    assert rate_limiter.get_limit('no-specific-limit.example') == 5
    print(" ✓ 默认频率限制设置成功")

    # Test 2: set limits for specific domains
    print("\n2. 设置特定域名频率限制")
    rate_limiter.set_limit('push2his.eastmoney.com', 10)
    rate_limiter.set_limit('query.sse.com.cn', 20)
    assert rate_limiter.get_limit('push2his.eastmoney.com') == 10
    assert rate_limiter.get_limit('query.sse.com.cn') == 20
    print(" ✓ 特定域名频率限制设置成功")

    # Test 3: status of all domains
    print("\n3. 查看频率限制状态")
    status = rate_limiter.get_status()
    assert isinstance(status, dict)
    print(f" 当前状态: {status}")

    # Test 4: status of the domain of one URL
    print("\n4. 查看特定URL的频率限制状态")
    url_status = rate_limiter.get_status('http://push2his.eastmoney.com/api/qt/stock/kline/get')
    assert url_status['domain'] == 'push2his.eastmoney.com'
    assert url_status['limit'] == 10
    # 'remaining' is derived from the limit and the current count.
    assert url_status['remaining'] == max(0, url_status['limit'] - url_status['current_count'])
    print(f" URL状态: {url_status}")

    print("\n" + "=" * 50)
    print("频率限制器功能测试完成")
    print("=" * 50)


def test_rate_limit_acquire():
    """Exercise RateLimiter.acquire: the 4th request within a minute must wait.

    NOTE: with a limit of 3 per minute this test blocks for roughly 60 seconds.
    """
    print("\n" + "=" * 50)
    print("测试频率限制 acquire 功能")
    print("=" * 50)

    # RateLimiter is a singleton, so RateLimiter() would hand back the global
    # instance and the low limit below would leak into it.  Build a separate
    # object instead: object.__new__ bypasses the singleton __new__, and
    # __init__ initialises it because it carries no '_initialized' flag yet.
    test_limiter = object.__new__(rate_limiter_module.RateLimiter)
    test_limiter.__init__()

    # Use a low limit so the throttling is reached quickly
    test_limiter.set_default_limit(3)
    print("\n设置频率限制为每分钟3次")

    print("\n开始测试 acquire...")
    start_time = time.time()

    # Send 4 requests; the 4th one should be throttled
    for i in range(4):
        req_start = time.time()
        print(f"\n 请求 {i+1}:")
        # acquire performs the rate-limit check (and blocks when needed)
        test_limiter.acquire('https://httpbin.org/get')
        req_time = time.time() - req_start
        print(f" 通过频率检查,耗时: {req_time:.2f}s")

    total_time = time.time() - start_time
    print(f"\n总耗时: {total_time:.2f}s (如果超过60s说明频率限制生效)")
    print("\n" + "=" * 50)
    print("频率限制 acquire 测试完成")
    print("=" * 50)


if __name__ == '__main__':
    # Run the tests
    test_rate_limiter()
    # The acquire test is left disabled here because it waits about 60 seconds
    # test_rate_limit_acquire()