diff --git a/.gitignore b/.gitignore index a2cf47e..7a62cb7 100644 --- a/.gitignore +++ b/.gitignore @@ -9,7 +9,7 @@ node_modules .cache !.vuepress docs/.vuepress/dist - +.venv ### STS ### .apt_generated .classpath diff --git a/adata/__init__.py b/adata/__init__.py index dee08e2..fae727b 100644 --- a/adata/__init__.py +++ b/adata/__init__.py @@ -11,6 +11,7 @@ from adata.__version__ import __version__ from adata.bond import bond from adata.common.utils.sunrequests import SunProxy +from adata.common.utils.sunrequests import RateLimiter from adata.fund import fund from adata.sentiment import sentiment from adata.stock import stock @@ -33,6 +34,19 @@ def proxy(is_proxy=False, ip: str = None, proxy_url: str = None): return +def set_rate_limit(limit=30, domain=None): + """ + 设置频率限制 + :param limit: 每分钟请求次数,默认30次 + :param domain: 特定域名,如果不指定则设置全局默认值 + """ + if domain: + RateLimiter.set_domain_limit(domain, limit) + else: + RateLimiter.set_default_limit(limit) + return + + # set up logging logger = logging.getLogger("adata") diff --git a/adata/common/utils/__init__.py b/adata/common/utils/__init__.py index 9b4eeb9..b40c01c 100644 --- a/adata/common/utils/__init__.py +++ b/adata/common/utils/__init__.py @@ -7,5 +7,6 @@ """ from .snowflake import worker from .sunrequests import sun_requests as requests +from .sunrequests import RateLimiter diff --git a/adata/common/utils/sunrequests.py b/adata/common/utils/sunrequests.py index 9fec0ed..8854d03 100644 --- a/adata/common/utils/sunrequests.py +++ b/adata/common/utils/sunrequests.py @@ -10,6 +10,7 @@ import threading import time +from urllib.parse import urlparse import requests @@ -41,6 +42,73 @@ def delete(cls, key): del cls._data[key] +class RateLimiter(object): + _default_limit = 30 + _domain_limits = {} + _request_records = {} + _lock = threading.Lock() + + @classmethod + def set_default_limit(cls, limit): + """ + 设置默认的频率限制(每分钟请求次数) + """ + cls._default_limit = limit + + @classmethod + def set_domain_limit(cls, domain, limit): + """ + 设置特定域名的频率限制(每分钟请求次数) + """ + cls._domain_limits[domain] = limit + + @classmethod + def _get_domain_limit(cls, domain): + """ + 获取域名的限制次数 + """ + return cls._domain_limits.get(domain, cls._default_limit) + + @classmethod + def _clean_old_records(cls, domain, current_time): + """ + 清理超过60秒的请求记录 + """ + if domain not in cls._request_records: + return + sixty_seconds_ago = current_time - 60 + cls._request_records[domain] = [ + t for t in cls._request_records[domain] if t > sixty_seconds_ago + ] + + @classmethod + def wait_if_needed(cls, url): + """ + 检查频率限制,如果超过限制则等待 + """ + parsed_url = urlparse(url) + domain = parsed_url.netloc + + with cls._lock: + current_time = time.time() + cls._clean_old_records(domain, current_time) + + if domain not in cls._request_records: + cls._request_records[domain] = [] + + limit = cls._get_domain_limit(domain) + current_count = len(cls._request_records[domain]) + + if current_count >= limit: + oldest_time = cls._request_records[domain][0] + wait_time = oldest_time + 60 - current_time + if wait_time > 0: + time.sleep(wait_time) + cls._clean_old_records(domain, time.time()) + + cls._request_records[domain].append(time.time()) + + class SunRequests(object): def __init__(self, sun_proxy: SunProxy = None) -> None: super().__init__() @@ -58,9 +126,8 @@ def request(self, method='get', url=None, times=3, retry_wait_time=1588, proxies :param kwargs: 其它 requests 参数,用法相同 :return: res """ - # 1. 获取设置代理 + RateLimiter.wait_if_needed(url) proxies = self.__get_proxies(proxies) - # 2. 请求数据结果 res = None for i in range(times): if wait_time: diff --git a/adata/stock/market/stock_dividend.py b/adata/stock/market/stock_dividend.py index 386cecb..ad2bdab 100644 --- a/adata/stock/market/stock_dividend.py +++ b/adata/stock/market/stock_dividend.py @@ -69,14 +69,42 @@ def __dividend_baidu(self, stock_code): return null_df # 4. 封装数据 - result_df = pd.DataFrame(data=body, columns=['公告日', '分红方案', '信息']) + result_df = pd.DataFrame(data=body) result_df['stock_code'] = stock_code - rename_columns = {'公告日': 'report_date', '分红方案': 'dividend_plan'} + + if len(result_df.columns) >= 4: + rename_columns = { + result_df.columns[0]: 'report_date', + result_df.columns[1]: 'dividend_plan', + result_df.columns[2]: 'ex_dividend_date', + result_df.columns[3]: 'info' + } + elif len(result_df.columns) >= 3: + rename_columns = { + result_df.columns[0]: 'report_date', + result_df.columns[1]: 'dividend_plan', + result_df.columns[2]: 'ex_dividend_date' + } + else: + return pd.DataFrame() + result_df = result_df.rename(columns=rename_columns) + # 5. 数据清洗 - result_df = result_df[result_df.dividend_plan != '利润不分配'] - result_df['ex_dividend_date'] = result_df['ex_dividend_date'].replace('--', np.nan) - return result_df[['stock_code', 'report_date', 'dividend_plan', 'ex_dividend_date']] + if 'dividend_plan' in result_df.columns: + result_df = result_df[result_df.dividend_plan != '利润不分配'] + if 'ex_dividend_date' in result_df.columns: + result_df['ex_dividend_date'] = result_df['ex_dividend_date'].replace('--', np.nan) + + return_columns = ['stock_code'] + if 'report_date' in result_df.columns: + return_columns.append('report_date') + if 'dividend_plan' in result_df.columns: + return_columns.append('dividend_plan') + if 'ex_dividend_date' in result_df.columns: + return_columns.append('ex_dividend_date') + + return result_df[return_columns] if __name__ == '__main__': diff --git a/tests/adata_test/stress_test.py b/tests/adata_test/stress_test.py new file mode 100644 index 0000000..25f36ad --- /dev/null +++ b/tests/adata_test/stress_test.py @@ -0,0 +1,54 @@ +# -*- coding: utf-8 -*- +import adata +import threading +import time +from concurrent.futures import ThreadPoolExecutor + +# ================= 评测配置 ================= +# 总请求数:设为 40,预期前 30 个成功,后 10 个被拦截或等待 +TOTAL_TASKS = 40 +# 并发数:同时开 10 个窗口发请求 +MAX_WORKERS = 10 +# =========================================== + +class MultiThreadEvaluator: + def __init__(self): + self.success = 0 + self.failure = 0 + self.lock = threading.Lock() + self.start_time = time.time() + + def task(self, task_id): + """单个线程执行的任务""" + try: + # 执行你刚才测通的接口 + df = adata.stock.market.get_dividend(stock_code='000001') + + with self.lock: + self.success += 1 + elapsed = time.time() - self.start_time + print(f"[{elapsed:6.2f}s] 任务 #{task_id:02d}: ✅ 成功 (Data rows: {len(df)})") + except Exception as e: + with self.lock: + self.failure += 1 + elapsed = time.time() - self.start_time + # 如果是限流拦截,这里会打印出 AI 定义的错误信息 + print(f"[{elapsed:6.2f}s] 任务 #{task_id:02d}: ❌ 拦截/报错: {e}") + + def run(self): + print(f"开始 AData 多线程评测...") + print(f"配置:总任务 {TOTAL_TASKS} | 并发 {MAX_WORKERS}") + print("-" * 65) + + # 核心:多线程执行 + with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor: + executor.map(self.task, range(1, TOTAL_TASKS + 1)) + + print("-" * 65) + print(f"评测完成! 成功: {self.success} | 拦截: {self.failure}") + print(f"总耗时: {time.time() - self.start_time:.2f} 秒") + +if __name__ == '__main__': + adata.proxy(False) + evaluator = MultiThreadEvaluator() + evaluator.run() \ No newline at end of file