Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ target/
!.mvn/wrapper/maven-wrapper.jar
!**/src/main/**
!**/src/test/**
.venv
### node js ###
node_modules
.temp
Expand Down
2 changes: 1 addition & 1 deletion adata/common/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,6 @@
@log: change log
"""
from .snowflake import worker
from .sunrequests import sun_requests as requests
from .sunrequests import sun_requests as requests, SunRequests


91 changes: 88 additions & 3 deletions adata/common/utils/sunrequests.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,72 @@

import threading
import time
from urllib.parse import urlparse

import requests
import requests as _requests


class RateLimiter:
"""频率限制器,按域名控制请求频率"""

def __init__(self, default_max_requests=30, window_seconds=60):
self._default_max_requests = default_max_requests
self._window_seconds = window_seconds
self._domain_limits = {}
self._domain_requests = {}
self._lock = threading.Lock()

def set_limit(self, domain, max_requests):
"""设置指定域名的频率限制
:param domain: 域名,如 '10jqka.com.cn'
:param max_requests: 每分钟最大请求数
"""
with self._lock:
self._domain_limits[domain] = max_requests

def get_limit(self, domain):
"""获取指定域名的频率限制"""
with self._lock:
for d, limit in self._domain_limits.items():
if d in domain:
return limit
return self._default_max_requests

def acquire(self, url):
"""获取请求许可,如果超过限制则等待
:param url: 请求的URL
:return: 实际等待的时间(秒)
"""
domain = self._extract_domain(url)
max_requests = self.get_limit(domain)

with self._lock:
now = time.time()
if domain not in self._domain_requests:
self._domain_requests[domain] = []

requests = self._domain_requests[domain]
requests = [t for t in requests if now - t < self._window_seconds]

if len(requests) >= max_requests:
sleep_time = self._window_seconds - (now - requests[0])
if sleep_time > 0:
time.sleep(sleep_time)
now = time.time()
requests = [t for t in requests if now - t < self._window_seconds]

requests.append(now)
self._domain_requests[domain] = requests

return 0

def _extract_domain(self, url):
"""从URL中提取域名"""
try:
parsed = urlparse(url)
return parsed.netloc.lower()
except:
return url


class SunProxy(object):
Expand Down Expand Up @@ -42,10 +106,27 @@ def delete(cls, key):


class SunRequests(object):
_rate_limiter = RateLimiter(default_max_requests=30, window_seconds=60)

def __init__(self, sun_proxy: SunProxy = None) -> None:
super().__init__()
self.sun_proxy = sun_proxy

@classmethod
def set_rate_limit(cls, domain, max_requests):
"""设置指定域名的频率限制
:param domain: 域名,如 '10jqka.com.cn'
:param max_requests: 每分钟最大请求数
"""
cls._rate_limiter.set_limit(domain, max_requests)

@classmethod
def set_default_rate_limit(cls, max_requests):
"""设置默认的频率限制
:param max_requests: 每分钟最大请求数
"""
cls._rate_limiter._default_max_requests = max_requests

def request(self, method='get', url=None, times=3, retry_wait_time=1588, proxies=None, wait_time=None, **kwargs):
"""
简单封装的请求,参考requests,增加循环次数和次数之间的等待时间
Expand All @@ -58,14 +139,18 @@ def request(self, method='get', url=None, times=3, retry_wait_time=1588, proxies
:param kwargs: 其它 requests 参数,用法相同
:return: res
"""
# 0. 频率限制检查
if url:
self._rate_limiter.acquire(url)

# 1. 获取设置代理
proxies = self.__get_proxies(proxies)
# 2. 请求数据结果
res = None
for i in range(times):
if wait_time:
time.sleep(wait_time / 1000)
res = requests.request(method=method, url=url, proxies=proxies, **kwargs)
res = _requests.request(method=method, url=url, proxies=proxies, **kwargs)
if res.status_code in (200, 404):
return res
time.sleep(retry_wait_time / 1000)
Expand All @@ -83,7 +168,7 @@ def __get_proxies(self, proxies):
ip = SunProxy.get('ip')
proxy_url = SunProxy.get('proxy_url')
if not ip and is_proxy and proxy_url:
ip = requests.get(url=proxy_url).text.replace('\r\n', '') \
ip = _requests.get(url=proxy_url).text.replace('\r\n', '') \
.replace('\r', '').replace('\n', '').replace('\t', '')
if is_proxy and ip:
proxies = {'https': f"http://{ip}", 'http': f"http://{ip}"}
Expand Down
18 changes: 12 additions & 6 deletions adata/stock/market/stock_dividend.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,14 +69,20 @@ def __dividend_baidu(self, stock_code):
return null_df

# 4. 封装数据
result_df = pd.DataFrame(data=body, columns=['公告日', '分红方案', '信息'])
result_df = pd.DataFrame(data=body)
# 处理列名(适配不同的数据源返回格式)
if len(result_df.columns) >= 3:
result_df.columns = ['report_date', 'dividend_plan', 'ex_dividend_date'] + list(result_df.columns[3:])
result_df['stock_code'] = stock_code
rename_columns = {'公告日': 'report_date', '分红方案': 'dividend_plan'}
result_df = result_df.rename(columns=rename_columns)
# 5. 数据清洗
result_df = result_df[result_df.dividend_plan != '利润不分配']
result_df['ex_dividend_date'] = result_df['ex_dividend_date'].replace('--', np.nan)
return result_df[['stock_code', 'report_date', 'dividend_plan', 'ex_dividend_date']]
if 'dividend_plan' in result_df.columns:
result_df = result_df[result_df.dividend_plan != '利润不分配']
if 'ex_dividend_date' in result_df.columns:
result_df['ex_dividend_date'] = result_df['ex_dividend_date'].replace('--', np.nan)
# 返回标准列
return_columns = ['stock_code', 'report_date', 'dividend_plan', 'ex_dividend_date']
return_columns = [col for col in return_columns if col in result_df.columns]
return result_df[return_columns]


if __name__ == '__main__':
Expand Down
70 changes: 70 additions & 0 deletions tests/adata_test/stress_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# -*- coding: utf-8 -*-
import sys
import os

# 添加项目根目录到 Python 路径(确保使用开发版本)
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
if project_root not in sys.path:
sys.path.insert(0, project_root)

import threading
import time
from concurrent.futures import ThreadPoolExecutor

# 先导入 SunRequests 设置频率限制,再导入 adata
from adata.common.utils.sunrequests import SunRequests
print(f"SunRequests: {SunRequests}")
print(f"dir(SunRequests): {dir(SunRequests)}")
SunRequests.set_default_rate_limit(30)

import adata
from adata.common.utils import requests

# ================= 评测配置 =================
# 总请求数:设为 40,预期前 30 个成功,后 10 个被拦截或等待
TOTAL_TASKS = 40
# 并发数:同时开 10 个窗口发请求
MAX_WORKERS = 10
# ===========================================

class MultiThreadEvaluator:
def __init__(self):
self.success = 0
self.failure = 0
self.lock = threading.Lock()
self.start_time = time.time()

def task(self, task_id):
"""单个线程执行的任务"""
try:
# 执行你刚才测通的接口
df = adata.stock.market.get_dividend(stock_code='000001')

with self.lock:
self.success += 1
elapsed = time.time() - self.start_time
print(f"[{elapsed:6.2f}s] 任务 #{task_id:02d}: ✅ 成功 (Data rows: {len(df)})")
except Exception as e:
with self.lock:
self.failure += 1
elapsed = time.time() - self.start_time
# 如果是限流拦截,这里会打印出 AI 定义的错误信息
print(f"[{elapsed:6.2f}s] 任务 #{task_id:02d}: ❌ 拦截/报错: {e}")

def run(self):
print(f"开始 AData 多线程评测...")
print(f"配置:总任务 {TOTAL_TASKS} | 并发 {MAX_WORKERS}")
print("-" * 65)

# 核心:多线程执行
with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
executor.map(self.task, range(1, TOTAL_TASKS + 1))

print("-" * 65)
print(f"评测完成! 成功: {self.success} | 拦截: {self.failure}")
print(f"总耗时: {time.time() - self.start_time:.2f} 秒")

if __name__ == '__main__':
adata.proxy(False)
evaluator = MultiThreadEvaluator()
evaluator.run()