Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ node_modules
.cache
!.vuepress
docs/.vuepress/dist

.venv
### STS ###
.apt_generated
.classpath
Expand Down
14 changes: 14 additions & 0 deletions adata/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from adata.__version__ import __version__
from adata.bond import bond
from adata.common.utils.sunrequests import SunProxy
from adata.common.utils.sunrequests import RateLimiter
from adata.fund import fund
from adata.sentiment import sentiment
from adata.stock import stock
Expand All @@ -33,6 +34,19 @@ def proxy(is_proxy=False, ip: str = None, proxy_url: str = None):
return


def set_rate_limit(limit=30, domain=None):
"""
设置频率限制
:param limit: 每分钟请求次数,默认30次
:param domain: 特定域名,如果不指定则设置全局默认值
"""
if domain:
RateLimiter.set_domain_limit(domain, limit)
else:
RateLimiter.set_default_limit(limit)
return


# set up logging
logger = logging.getLogger("adata")

Expand Down
1 change: 1 addition & 0 deletions adata/common/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,6 @@
"""
from .snowflake import worker
from .sunrequests import sun_requests as requests
from .sunrequests import RateLimiter


71 changes: 69 additions & 2 deletions adata/common/utils/sunrequests.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import threading
import time
from urllib.parse import urlparse

import requests

Expand Down Expand Up @@ -41,6 +42,73 @@ def delete(cls, key):
del cls._data[key]


class RateLimiter(object):
_default_limit = 30
_domain_limits = {}
_request_records = {}
_lock = threading.Lock()

@classmethod
def set_default_limit(cls, limit):
"""
设置默认的频率限制(每分钟请求次数)
"""
cls._default_limit = limit

@classmethod
def set_domain_limit(cls, domain, limit):
"""
设置特定域名的频率限制(每分钟请求次数)
"""
cls._domain_limits[domain] = limit

@classmethod
def _get_domain_limit(cls, domain):
"""
获取域名的限制次数
"""
return cls._domain_limits.get(domain, cls._default_limit)

@classmethod
def _clean_old_records(cls, domain, current_time):
"""
清理超过60秒的请求记录
"""
if domain not in cls._request_records:
return
sixty_seconds_ago = current_time - 60
cls._request_records[domain] = [
t for t in cls._request_records[domain] if t > sixty_seconds_ago
]

@classmethod
def wait_if_needed(cls, url):
"""
检查频率限制,如果超过限制则等待
"""
parsed_url = urlparse(url)
domain = parsed_url.netloc

with cls._lock:
current_time = time.time()
cls._clean_old_records(domain, current_time)

if domain not in cls._request_records:
cls._request_records[domain] = []

limit = cls._get_domain_limit(domain)
current_count = len(cls._request_records[domain])

if current_count >= limit:
oldest_time = cls._request_records[domain][0]
wait_time = oldest_time + 60 - current_time
if wait_time > 0:
time.sleep(wait_time)
cls._clean_old_records(domain, time.time())

cls._request_records[domain].append(time.time())


class SunRequests(object):
def __init__(self, sun_proxy: SunProxy = None) -> None:
super().__init__()
Expand All @@ -58,9 +126,8 @@ def request(self, method='get', url=None, times=3, retry_wait_time=1588, proxies
:param kwargs: 其它 requests 参数,用法相同
:return: res
"""
# 1. 获取设置代理
RateLimiter.wait_if_needed(url)
proxies = self.__get_proxies(proxies)
# 2. 请求数据结果
res = None
for i in range(times):
if wait_time:
Expand Down
38 changes: 33 additions & 5 deletions adata/stock/market/stock_dividend.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,14 +69,42 @@ def __dividend_baidu(self, stock_code):
return null_df

# 4. 封装数据
result_df = pd.DataFrame(data=body, columns=['公告日', '分红方案', '信息'])
result_df = pd.DataFrame(data=body)
result_df['stock_code'] = stock_code
rename_columns = {'公告日': 'report_date', '分红方案': 'dividend_plan'}

if len(result_df.columns) >= 4:
rename_columns = {
result_df.columns[0]: 'report_date',
result_df.columns[1]: 'dividend_plan',
result_df.columns[2]: 'ex_dividend_date',
result_df.columns[3]: 'info'
}
elif len(result_df.columns) >= 3:
rename_columns = {
result_df.columns[0]: 'report_date',
result_df.columns[1]: 'dividend_plan',
result_df.columns[2]: 'ex_dividend_date'
}
else:
return pd.DataFrame()

result_df = result_df.rename(columns=rename_columns)

# 5. 数据清洗
result_df = result_df[result_df.dividend_plan != '利润不分配']
result_df['ex_dividend_date'] = result_df['ex_dividend_date'].replace('--', np.nan)
return result_df[['stock_code', 'report_date', 'dividend_plan', 'ex_dividend_date']]
if 'dividend_plan' in result_df.columns:
result_df = result_df[result_df.dividend_plan != '利润不分配']
if 'ex_dividend_date' in result_df.columns:
result_df['ex_dividend_date'] = result_df['ex_dividend_date'].replace('--', np.nan)

return_columns = ['stock_code']
if 'report_date' in result_df.columns:
return_columns.append('report_date')
if 'dividend_plan' in result_df.columns:
return_columns.append('dividend_plan')
if 'ex_dividend_date' in result_df.columns:
return_columns.append('ex_dividend_date')

return result_df[return_columns]


if __name__ == '__main__':
Expand Down
54 changes: 54 additions & 0 deletions tests/adata_test/stress_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# -*- coding: utf-8 -*-
import adata
import threading
import time
from concurrent.futures import ThreadPoolExecutor

# ================= 评测配置 =================
# 总请求数:设为 40,预期前 30 个成功,后 10 个被拦截或等待
TOTAL_TASKS = 40
# 并发数:同时开 10 个窗口发请求
MAX_WORKERS = 10
# ===========================================

class MultiThreadEvaluator:
def __init__(self):
self.success = 0
self.failure = 0
self.lock = threading.Lock()
self.start_time = time.time()

def task(self, task_id):
"""单个线程执行的任务"""
try:
# 执行你刚才测通的接口
df = adata.stock.market.get_dividend(stock_code='000001')

with self.lock:
self.success += 1
elapsed = time.time() - self.start_time
print(f"[{elapsed:6.2f}s] 任务 #{task_id:02d}: ✅ 成功 (Data rows: {len(df)})")
except Exception as e:
with self.lock:
self.failure += 1
elapsed = time.time() - self.start_time
# 如果是限流拦截,这里会打印出 AI 定义的错误信息
print(f"[{elapsed:6.2f}s] 任务 #{task_id:02d}: ❌ 拦截/报错: {e}")

def run(self):
print(f"开始 AData 多线程评测...")
print(f"配置:总任务 {TOTAL_TASKS} | 并发 {MAX_WORKERS}")
print("-" * 65)

# 核心:多线程执行
with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
executor.map(self.task, range(1, TOTAL_TASKS + 1))

print("-" * 65)
print(f"评测完成! 成功: {self.success} | 拦截: {self.failure}")
print(f"总耗时: {time.time() - self.start_time:.2f} 秒")

if __name__ == '__main__':
adata.proxy(False)
evaluator = MultiThreadEvaluator()
evaluator.run()