Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 71 additions & 5 deletions adata/common/utils/sunrequests.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import threading
import time
from urllib.parse import urlparse

import requests

Expand Down Expand Up @@ -41,10 +42,69 @@ def delete(cls, key):
del cls._data[key]


class RateLimiter:
_instance_lock = threading.Lock()
_instance = None
_default_limit = 30
_window_seconds = 60
_domain_limits = {}
_request_history = {}

def __new__(cls, *args, **kwargs):
if cls._instance is None:
with cls._instance_lock:
if cls._instance is None:
cls._instance = object.__new__(cls)
return cls._instance

@classmethod
def set_default_limit(cls, limit):
cls._default_limit = limit

@classmethod
def set_domain_limit(cls, domain, limit):
cls._domain_limits[domain] = limit

@classmethod
def get_domain_limit(cls, domain):
return cls._domain_limits.get(domain, cls._default_limit)

@classmethod
def _clean_old_requests(cls, domain, current_time):
if domain not in cls._request_history:
cls._request_history[domain] = []
cls._request_history[domain] = [
t for t in cls._request_history[domain]
if current_time - t < cls._window_seconds
]

@classmethod
def acquire(cls, url):
parsed = urlparse(url)
domain = parsed.netloc
current_time = time.time()

with cls._instance_lock:
cls._clean_old_requests(domain, current_time)
limit = cls.get_domain_limit(domain)

if len(cls._request_history[domain]) >= limit:
oldest_request = cls._request_history[domain][0]
wait_time = cls._window_seconds - (current_time - oldest_request)
if wait_time > 0:
time.sleep(wait_time)
current_time = time.time()
cls._clean_old_requests(domain, current_time)

cls._request_history[domain].append(current_time)
return True


class SunRequests(object):
def __init__(self, sun_proxy: SunProxy = None) -> None:
super().__init__()
self.sun_proxy = sun_proxy
self._rate_limiter = RateLimiter()

def request(self, method='get', url=None, times=3, retry_wait_time=1588, proxies=None, wait_time=None, **kwargs):
"""
Expand All @@ -58,9 +118,10 @@ def request(self, method='get', url=None, times=3, retry_wait_time=1588, proxies
:param kwargs: 其它 requests 参数,用法相同
:return: res
"""
# 1. 获取设置代理
if url:
self._rate_limiter.acquire(url)

proxies = self.__get_proxies(proxies)
# 2. 请求数据结果
res = None
for i in range(times):
if wait_time:
Expand All @@ -74,9 +135,6 @@ def request(self, method='get', url=None, times=3, retry_wait_time=1588, proxies
return res

def __get_proxies(self, proxies):
"""
获取代理配置
"""
if proxies is None:
proxies = {}
is_proxy = SunProxy.get('is_proxy')
Expand All @@ -90,4 +148,12 @@ def __get_proxies(self, proxies):
return proxies


def set_default_rate_limit(limit):
RateLimiter.set_default_limit(limit)


def set_domain_rate_limit(domain, limit):
RateLimiter.set_domain_limit(domain, limit)


sun_requests = SunRequests()