diff --git a/.gitignore b/.gitignore index a2cf47e..bc14e3c 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,7 @@ target/ !.mvn/wrapper/maven-wrapper.jar !**/src/main/** !**/src/test/** +.venv ### node js ### node_modules .temp diff --git a/adata/common/utils/__init__.py b/adata/common/utils/__init__.py index 9b4eeb9..89bc0f8 100644 --- a/adata/common/utils/__init__.py +++ b/adata/common/utils/__init__.py @@ -6,6 +6,6 @@ @log: change log """ from .snowflake import worker -from .sunrequests import sun_requests as requests +from .sunrequests import sun_requests as requests, SunRequests diff --git a/adata/common/utils/sunrequests.py b/adata/common/utils/sunrequests.py index 9fec0ed..909d16c 100644 --- a/adata/common/utils/sunrequests.py +++ b/adata/common/utils/sunrequests.py @@ -10,8 +10,72 @@ import threading import time +from urllib.parse import urlparse -import requests +import requests as _requests + + +class RateLimiter: + """频率限制器,按域名控制请求频率""" + + def __init__(self, default_max_requests=30, window_seconds=60): + self._default_max_requests = default_max_requests + self._window_seconds = window_seconds + self._domain_limits = {} + self._domain_requests = {} + self._lock = threading.Lock() + + def set_limit(self, domain, max_requests): + """设置指定域名的频率限制 + :param domain: 域名,如 '10jqka.com.cn' + :param max_requests: 每分钟最大请求数 + """ + with self._lock: + self._domain_limits[domain] = max_requests + + def get_limit(self, domain): + """获取指定域名的频率限制""" + with self._lock: + for d, limit in self._domain_limits.items(): + if d in domain: + return limit + return self._default_max_requests + + def acquire(self, url): + """获取请求许可,如果超过限制则等待 + :param url: 请求的URL + :return: 实际等待的时间(秒) + """ + domain = self._extract_domain(url) + max_requests = self.get_limit(domain) + + with self._lock: + now = time.time() + if domain not in self._domain_requests: + self._domain_requests[domain] = [] + + requests = self._domain_requests[domain] + requests = [t for t in requests if now - t < self._window_seconds] + + if len(requests) >= max_requests: + sleep_time = self._window_seconds - (now - requests[0]) + if sleep_time > 0: + time.sleep(sleep_time) + now = time.time() + requests = [t for t in requests if now - t < self._window_seconds] + + requests.append(now) + self._domain_requests[domain] = requests + + return 0 + + def _extract_domain(self, url): + """从URL中提取域名""" + try: + parsed = urlparse(url) + return parsed.netloc.lower() + except: + return url class SunProxy(object): @@ -42,10 +106,27 @@ def delete(cls, key): class SunRequests(object): + _rate_limiter = RateLimiter(default_max_requests=30, window_seconds=60) + def __init__(self, sun_proxy: SunProxy = None) -> None: super().__init__() self.sun_proxy = sun_proxy + @classmethod + def set_rate_limit(cls, domain, max_requests): + """设置指定域名的频率限制 + :param domain: 域名,如 '10jqka.com.cn' + :param max_requests: 每分钟最大请求数 + """ + cls._rate_limiter.set_limit(domain, max_requests) + + @classmethod + def set_default_rate_limit(cls, max_requests): + """设置默认的频率限制 + :param max_requests: 每分钟最大请求数 + """ + cls._rate_limiter._default_max_requests = max_requests + def request(self, method='get', url=None, times=3, retry_wait_time=1588, proxies=None, wait_time=None, **kwargs): """ 简单封装的请求,参考requests,增加循环次数和次数之间的等待时间 @@ -58,6 +139,10 @@ def request(self, method='get', url=None, times=3, retry_wait_time=1588, proxies :param kwargs: 其它 requests 参数,用法相同 :return: res """ + # 0. 频率限制检查 + if url: + self._rate_limiter.acquire(url) + # 1. 获取设置代理 proxies = self.__get_proxies(proxies) # 2. 请求数据结果 @@ -65,7 +150,7 @@ def request(self, method='get', url=None, times=3, retry_wait_time=1588, proxies for i in range(times): if wait_time: time.sleep(wait_time / 1000) - res = requests.request(method=method, url=url, proxies=proxies, **kwargs) + res = _requests.request(method=method, url=url, proxies=proxies, **kwargs) if res.status_code in (200, 404): return res time.sleep(retry_wait_time / 1000) @@ -83,7 +168,7 @@ def __get_proxies(self, proxies): ip = SunProxy.get('ip') proxy_url = SunProxy.get('proxy_url') if not ip and is_proxy and proxy_url: - ip = requests.get(url=proxy_url).text.replace('\r\n', '') \ + ip = _requests.get(url=proxy_url).text.replace('\r\n', '') \ .replace('\r', '').replace('\n', '').replace('\t', '') if is_proxy and ip: proxies = {'https': f"http://{ip}", 'http': f"http://{ip}"} diff --git a/adata/stock/market/stock_dividend.py b/adata/stock/market/stock_dividend.py index 386cecb..be0e628 100644 --- a/adata/stock/market/stock_dividend.py +++ b/adata/stock/market/stock_dividend.py @@ -69,14 +69,20 @@ def __dividend_baidu(self, stock_code): return null_df # 4. 封装数据 - result_df = pd.DataFrame(data=body, columns=['公告日', '分红方案', '信息']) + result_df = pd.DataFrame(data=body) + # 处理列名(适配不同的数据源返回格式) + if len(result_df.columns) >= 3: + result_df.columns = ['report_date', 'dividend_plan', 'ex_dividend_date'] + list(result_df.columns[3:]) result_df['stock_code'] = stock_code - rename_columns = {'公告日': 'report_date', '分红方案': 'dividend_plan'} - result_df = result_df.rename(columns=rename_columns) # 5. 数据清洗 - result_df = result_df[result_df.dividend_plan != '利润不分配'] - result_df['ex_dividend_date'] = result_df['ex_dividend_date'].replace('--', np.nan) - return result_df[['stock_code', 'report_date', 'dividend_plan', 'ex_dividend_date']] + if 'dividend_plan' in result_df.columns: + result_df = result_df[result_df.dividend_plan != '利润不分配'] + if 'ex_dividend_date' in result_df.columns: + result_df['ex_dividend_date'] = result_df['ex_dividend_date'].replace('--', np.nan) + # 返回标准列 + return_columns = ['stock_code', 'report_date', 'dividend_plan', 'ex_dividend_date'] + return_columns = [col for col in return_columns if col in result_df.columns] + return result_df[return_columns] if __name__ == '__main__': diff --git a/tests/adata_test/stress_test.py b/tests/adata_test/stress_test.py new file mode 100644 index 0000000..d1d2a1f --- /dev/null +++ b/tests/adata_test/stress_test.py @@ -0,0 +1,70 @@ +# -*- coding: utf-8 -*- +import sys +import os + +# 添加项目根目录到 Python 路径(确保使用开发版本) +project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +if project_root not in sys.path: + sys.path.insert(0, project_root) + +import threading +import time +from concurrent.futures import ThreadPoolExecutor + +# 先导入 SunRequests 设置频率限制,再导入 adata +from adata.common.utils.sunrequests import SunRequests +print(f"SunRequests: {SunRequests}") +print(f"dir(SunRequests): {dir(SunRequests)}") +SunRequests.set_default_rate_limit(30) + +import adata +from adata.common.utils import requests + +# ================= 评测配置 ================= +# 总请求数:设为 40,预期前 30 个成功,后 10 个被拦截或等待 +TOTAL_TASKS = 40 +# 并发数:同时开 10 个窗口发请求 +MAX_WORKERS = 10 +# =========================================== + +class MultiThreadEvaluator: + def __init__(self): + self.success = 0 + self.failure = 0 + self.lock = threading.Lock() + self.start_time = time.time() + + def task(self, task_id): + """单个线程执行的任务""" + try: + # 执行你刚才测通的接口 + df = adata.stock.market.get_dividend(stock_code='000001') + + with self.lock: + self.success += 1 + elapsed = time.time() - self.start_time + print(f"[{elapsed:6.2f}s] 任务 #{task_id:02d}: ✅ 成功 (Data rows: {len(df)})") + except Exception as e: + with self.lock: + self.failure += 1 + elapsed = time.time() - self.start_time + # 如果是限流拦截,这里会打印出 AI 定义的错误信息 + print(f"[{elapsed:6.2f}s] 任务 #{task_id:02d}: ❌ 拦截/报错: {e}") + + def run(self): + print(f"开始 AData 多线程评测...") + print(f"配置:总任务 {TOTAL_TASKS} | 并发 {MAX_WORKERS}") + print("-" * 65) + + # 核心:多线程执行 + with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor: + executor.map(self.task, range(1, TOTAL_TASKS + 1)) + + print("-" * 65) + print(f"评测完成! 成功: {self.success} | 拦截: {self.failure}") + print(f"总耗时: {time.time() - self.start_time:.2f} 秒") + +if __name__ == '__main__': + adata.proxy(False) + evaluator = MultiThreadEvaluator() + evaluator.run() \ No newline at end of file