diff --git a/src/atv_player/app.py b/src/atv_player/app.py index 26242ff..f48cc15 100644 --- a/src/atv_player/app.py +++ b/src/atv_player/app.py @@ -3,7 +3,6 @@ from collections.abc import Mapping from dataclasses import replace import gc -import httpx import inspect import threading import time @@ -65,11 +64,13 @@ from atv_player.metadata.providers.tmdb import TMDBProvider, infer_tmdb_media_type from atv_player.metadata.providers.tmdb_client import TMDBClient from atv_player.models import AppConfig, LiveEpgConfig, PlayItem, VodItem -from atv_player.network_proxy import ProxyConfig, ProxyDecider, build_httpx_kwargs_for_url +from atv_player.network_client import NetworkClient +from atv_player.network_proxy import ProxyConfig, ProxyDecider from atv_player.paths import app_cache_dir, app_data_dir from atv_player.live_source_repository import LiveSourceRepository from atv_player.plugins import SpiderPluginLoader, SpiderPluginManager from atv_player.plugins.compat.base.spider import set_proxy_decider_loader as set_spider_proxy_decider_loader +from atv_player.plugins.compat.base.spider import set_session_loader as set_spider_session_loader from atv_player.plugins.repository import SpiderPluginRepository from atv_player.playback_parsers import BuiltInPlaybackParserService from atv_player.player.m3u8_ad_filter import M3U8AdFilter @@ -77,9 +78,15 @@ from atv_player.yt_dlp_service import YtdlpPlaybackService from atv_player.storage import SettingsRepository from atv_player.time_utils import is_refresh_stale -from atv_player.ui.poster_loader import set_proxy_decider_loader +from atv_player.ui.poster_loader import set_http_get_loader, set_proxy_decider_loader from atv_player.ui.login_window import LoginWindow -from atv_player.ui.main_window import MainWindow, load_direct_parse_detail +from atv_player.ui.main_window import ( + MainWindow, + load_direct_parse_detail, + set_main_window_http_get_loader, + set_main_window_http_post_loader, +) +from atv_player.ui.player_window import set_player_window_http_get_loader from atv_player.ui.icon_cache import load_icon POSTER_CACHE_MAX_AGE_SECONDS = 7 * 24 * 60 * 60 @@ -141,6 +148,27 @@ def decide_start_view(config: AppConfig) -> str: return "main" if config.token else "login" +def _proxy_signature(config: AppConfig) -> tuple[str, str, tuple[str, ...]]: + return ( + config.network_proxy_mode, + config.network_proxy_url, + tuple(config.network_proxy_bypass_rules), + ) + + +def _make_proxy_invalidation_wrapper(save_fn, config: AppConfig, network: NetworkClient): + last = [_proxy_signature(config)] + + def wrapped() -> None: + save_fn() + current = _proxy_signature(config) + if current != last[0]: + last[0] = current + network.invalidate_proxy() + + return wrapped + + def _app_icon_path() -> Path: return Path(__file__).resolve().parent / "icons" / "app.svg" @@ -306,26 +334,32 @@ def __init__(self, repo: SettingsRepository) -> None: self.login_window: LoginWindow | None = None self.main_window: MainWindow | None = None self._api_client: ApiClient | None = None - set_proxy_decider_loader(self._build_proxy_decider) - set_spider_proxy_decider_loader(self._build_proxy_decider) + self._network = NetworkClient(self._build_proxy_decider) + set_proxy_decider_loader(lambda: self._network.proxy_decider) + set_http_get_loader(lambda: self._network.get) + set_main_window_http_get_loader(lambda: self._network.get) + set_main_window_http_post_loader(lambda: self._network.post) + set_player_window_http_get_loader(lambda: self._network.get) + set_spider_proxy_decider_loader(lambda: self._network.proxy_decider) + set_spider_session_loader(lambda: self._network.requests_session()) self._m3u8_ad_filter = M3U8AdFilter( proxy_server=LocalHlsProxyServer( - get=self._proxy_http_get(), - stream=self._proxy_http_stream(), + get=self._network.get, + stream=self._network.stream, ), - get=self._proxy_http_get(), + get=self._network.get, ) self._playback_parser_service = BuiltInPlaybackParserService( - get=self._proxy_http_get(), - post=self._proxy_http_post(), + get=self._network.get, + post=self._network.post, ) self._yt_dlp_service = YtdlpPlaybackService( proxy_decider=self._build_proxy_decider(), config_loader=self.repo.load_config, ) self._danmaku_service = create_default_danmaku_service( - get=self._proxy_http_get(), - post=self._proxy_http_post(), + get=self._network.get, + post=self._network.post, ) if hasattr(repo, "database_path"): self._live_source_repository = LiveSourceRepository(repo.database_path) @@ -333,7 +367,7 @@ def __init__(self, repo: SettingsRepository) -> None: self._plugin_repository = SpiderPluginRepository(repo.database_path) self._playback_history_repository = LocalPlaybackHistoryRepository(repo.database_path) cache_dir = app_cache_dir() / "plugins" - self._plugin_loader = SpiderPluginLoader(cache_dir, get=self._proxy_http_get()) + self._plugin_loader = SpiderPluginLoader(cache_dir, get=self._network.get) self._plugin_manager = SpiderPluginManager( self._plugin_repository, self._plugin_loader, @@ -384,28 +418,13 @@ def _build_proxy_decider(self) -> ProxyDecider: ) def _proxy_http_get(self): - def run(url: str, **kwargs): - request_kwargs = dict(kwargs) - request_kwargs.update(build_httpx_kwargs_for_url(self._build_proxy_decider(), url)) - return httpx.get(url, **request_kwargs) - - return run + return self._network.get def _proxy_http_post(self): - def run(url: str, **kwargs): - request_kwargs = dict(kwargs) - request_kwargs.update(build_httpx_kwargs_for_url(self._build_proxy_decider(), url)) - return httpx.post(url, **request_kwargs) - - return run + return self._network.post def _proxy_http_stream(self): - def run(method: str, url: str, **kwargs): - request_kwargs = dict(kwargs) - request_kwargs.update(build_httpx_kwargs_for_url(self._build_proxy_decider(), url)) - return httpx.stream(method, url, **request_kwargs) - - return run + return self._network.stream def start(self) -> QWidget: config = self.repo.load_config() @@ -1211,7 +1230,9 @@ def plugin_loader_task(): history_controller=history_controller, player_controller=player_controller, config=config, - save_config=lambda: self.repo.save_config(config), + save_config=_make_proxy_invalidation_wrapper( + lambda: self.repo.save_config(config), config, self._network + ), douban_controller=douban_controller, telegram_controller=telegram_controller, bilibili_controller=bilibili_controller, @@ -1344,3 +1365,4 @@ def close(self) -> None: if callable(close_filter): close_filter() self._close_api_client() + self._network.close() diff --git a/src/atv_player/network_client.py b/src/atv_player/network_client.py new file mode 100644 index 0000000..c65f667 --- /dev/null +++ b/src/atv_player/network_client.py @@ -0,0 +1,156 @@ +from __future__ import annotations + +import threading +from collections.abc import Callable + +import httpx +import requests + +from atv_player.network_proxy import ProxyDecider + + +_DEFAULT_POOL_LIMIT = 10 +_MANUAL_PREFIX = "manual:" + + +def _client_key(decider: ProxyDecider | None, url: str) -> str: + if decider is None: + return "direct" + decision = decider.decide(url) + if decision.kind == "direct": + return "direct" + if decision.kind == "system": + return "system" + if decision.kind == "manual": + return f"{_MANUAL_PREFIX}{decision.proxy_url}" + return "direct" + + +def _default_client_factory(key: str, pool_limit: int) -> httpx.Client: + limits = httpx.Limits( + max_connections=pool_limit, + max_keepalive_connections=pool_limit, + ) + if key == "direct": + return httpx.Client(trust_env=False, limits=limits) + if key == "system": + return httpx.Client(trust_env=True, limits=limits) + if key.startswith(_MANUAL_PREFIX): + proxy_url = key[len(_MANUAL_PREFIX):] + return httpx.Client(proxy=proxy_url, trust_env=False, limits=limits) + raise ValueError(f"unknown client key: {key}") + + +def _default_session_factory(pool_limit: int) -> requests.Session: + session = requests.Session() + adapter = requests.adapters.HTTPAdapter( + pool_connections=pool_limit, + pool_maxsize=pool_limit, + ) + session.mount("http://", adapter) + session.mount("https://", adapter) + return session + + +class NetworkClient: + def __init__( + self, + decider_factory: Callable[[], ProxyDecider | None], + *, + pool_limit: int = _DEFAULT_POOL_LIMIT, + client_factory: Callable[[str, int], httpx.Client] = _default_client_factory, + session_factory: Callable[[int], requests.Session] = _default_session_factory, + ) -> None: + self._decider_factory = decider_factory + self._pool_limit = pool_limit + self._client_factory = client_factory + self._session_factory = session_factory + self._lock = threading.Lock() + self._decider: ProxyDecider | None = None + self._decider_loaded = False + self._clients: dict[str, httpx.Client] = {} + self._session: requests.Session | None = None + self._closed = False + + @property + def proxy_decider(self) -> ProxyDecider | None: + with self._lock: + return self._load_decider_locked() + + def _load_decider_locked(self) -> ProxyDecider | None: + if not self._decider_loaded: + self._decider = self._decider_factory() + self._decider_loaded = True + return self._decider + + def invalidate_proxy(self) -> None: + with self._lock: + if self._closed: + return + self._decider = None + self._decider_loaded = False + non_direct_keys = [k for k in self._clients if k != "direct"] + clients_to_close = [self._clients.pop(k) for k in non_direct_keys] + session_to_close = self._session + self._session = None + for client in clients_to_close: + try: + client.close() + except Exception: + pass + if session_to_close is not None: + try: + session_to_close.close() + except Exception: + pass + + def close(self) -> None: + with self._lock: + if self._closed: + return + self._closed = True + clients_to_close = list(self._clients.values()) + self._clients.clear() + session_to_close = self._session + self._session = None + self._decider = None + self._decider_loaded = False + for client in clients_to_close: + try: + client.close() + except Exception: + pass + if session_to_close is not None: + try: + session_to_close.close() + except Exception: + pass + + def _resolve(self, url: str) -> httpx.Client: + with self._lock: + if self._closed: + raise RuntimeError("NetworkClient is closed") + decider = self._load_decider_locked() + key = _client_key(decider, url) + client = self._clients.get(key) + if client is None: + client = self._client_factory(key, self._pool_limit) + self._clients[key] = client + return client + + def get(self, url: str, **kwargs): + return self._resolve(url).get(url, **kwargs) + + def post(self, url: str, **kwargs): + return self._resolve(url).post(url, **kwargs) + + def stream(self, method: str, url: str, **kwargs): + return self._resolve(url).stream(method, url, **kwargs) + + def requests_session(self) -> requests.Session: + with self._lock: + if self._closed: + raise RuntimeError("NetworkClient is closed") + if self._session is None: + self._session = self._session_factory(self._pool_limit) + return self._session diff --git a/src/atv_player/plugins/compat/base/spider.py b/src/atv_player/plugins/compat/base/spider.py index 5665527..fd0ada6 100644 --- a/src/atv_player/plugins/compat/base/spider.py +++ b/src/atv_player/plugins/compat/base/spider.py @@ -15,6 +15,7 @@ _CACHE_ROOT = Path.home() / ".cache" / "atv-player" / "plugins" / "spider-cache" _proxy_decider_loader: Callable[[], ProxyDecider | None] | None = None +_session_loader: Callable[[], requests.Session | None] | None = None def set_cache_root(path: Path | str) -> None: @@ -28,12 +29,23 @@ def set_proxy_decider_loader(loader: Callable[[], ProxyDecider | None] | None) - _proxy_decider_loader = loader +def set_session_loader(loader: Callable[[], requests.Session | None] | None) -> None: + global _session_loader + _session_loader = loader + + def _effective_proxy_decider() -> ProxyDecider | None: if _proxy_decider_loader is None: return None return _proxy_decider_loader() +def _effective_session() -> requests.Session | None: + if _session_loader is None: + return None + return _session_loader() + + def _cache_path(key: str) -> Path: _CACHE_ROOT.mkdir(parents=True, exist_ok=True) return _CACHE_ROOT / f"{sha256(key.encode('utf-8')).hexdigest()}.cache" @@ -98,8 +110,7 @@ def fetch( stream=False, allow_redirects=True, ): - response = requests.get( - url, + kwargs = dict( params=params, cookies=cookies, headers=headers, @@ -109,6 +120,8 @@ def fetch( allow_redirects=allow_redirects, proxies=build_requests_proxies_for_url(_effective_proxy_decider(), url), ) + session = _effective_session() + response = session.get(url, **kwargs) if session is not None else requests.get(url, **kwargs) response.encoding = "utf-8" return _buffer_and_close_response(response) @@ -125,8 +138,7 @@ def post( stream=False, allow_redirects=True, ): - response = requests.post( - url, + kwargs = dict( params=params, data=data, json=json, @@ -138,6 +150,8 @@ def post( allow_redirects=allow_redirects, proxies=build_requests_proxies_for_url(_effective_proxy_decider(), url), ) + session = _effective_session() + response = session.post(url, **kwargs) if session is not None else requests.post(url, **kwargs) response.encoding = "utf-8" return _buffer_and_close_response(response) diff --git a/src/atv_player/ui/main_window.py b/src/atv_player/ui/main_window.py index b767bf6..ea76db1 100644 --- a/src/atv_player/ui/main_window.py +++ b/src/atv_player/ui/main_window.py @@ -1,7 +1,7 @@ from __future__ import annotations import inspect import threading -from collections.abc import Iterable +from collections.abc import Callable, Iterable from dataclasses import dataclass from pathlib import Path from typing import Any, Protocol, cast @@ -223,8 +223,38 @@ def _looks_like_drive_share_link(value: str) -> bool: return any(hostname == domain or hostname.endswith(f".{domain}") for domain in _SUPPORTED_DRIVE_DOMAINS) +_http_get_loader: Callable[[], Callable[..., Any] | None] | None = None +_http_post_loader: Callable[[], Callable[..., Any] | None] | None = None + + +def set_main_window_http_get_loader(loader: Callable[[], Callable[..., Any] | None] | None) -> None: + global _http_get_loader + _http_get_loader = loader + + +def set_main_window_http_post_loader(loader: Callable[[], Callable[..., Any] | None] | None) -> None: + global _http_post_loader + _http_post_loader = loader + + +def _resolve_http_get() -> Callable[..., Any]: + if _http_get_loader is not None: + resolved = _http_get_loader() + if resolved is not None: + return resolved + return httpx.get + + +def _resolve_http_post() -> Callable[..., Any]: + if _http_post_loader is not None: + resolved = _http_post_loader() + if resolved is not None: + return resolved + return httpx.post + + def load_direct_parse_detail(url: str) -> dict[str, Any]: - response = httpx.get( + response = _resolve_http_get()( _DIRECT_PARSE_DETAIL_API, params={"ac": "list", "url": url}, timeout=10.0, @@ -236,7 +266,7 @@ def load_direct_parse_detail(url: str) -> dict[str, Any]: def load_360_hot_searches(hot_type: str = _DEFAULT_GLOBAL_SEARCH_HOT_TYPE) -> list[dict[str, str]]: - response = httpx.get( + response = _resolve_http_get()( _HOTKEY_360_API, params={"type": hot_type}, timeout=5.0, @@ -269,7 +299,7 @@ def load_tencent_hot_searches(hot_type: str = "hot") -> list[dict[str, str]]: def load_tencent_hot_search_sections() -> tuple[list[tuple[str, str]], dict[str, list[dict[str, str]]]]: - response = httpx.post( + response = _resolve_http_post()( _HOTKEY_TENCENT_API, headers={ "content-type": "application/json", @@ -320,7 +350,7 @@ def load_tencent_hot_search_sections() -> tuple[list[tuple[str, str]], dict[str, def load_iqiyi_hot_search_sections() -> tuple[list[tuple[str, str]], dict[str, list[dict[str, str]]]]: - response = httpx.get( + response = _resolve_http_get()( _HOTKEY_IQIYI_API, params={ "device_id": "7b16c55cfdf4edb1a33cd4fc07bc0f69", @@ -390,7 +420,7 @@ def load_global_search_hotkey_payload( def load_360_search_suggestions(keyword: str) -> list[str]: - response = httpx.get( + response = _resolve_http_get()( _SUGGESTION_360_API, params={"word": keyword, "encodein": "utf-8", "encodeout": "utf-8"}, timeout=5.0, diff --git a/src/atv_player/ui/player_window.py b/src/atv_player/ui/player_window.py index f21d3b5..36a336c 100644 --- a/src/atv_player/ui/player_window.py +++ b/src/atv_player/ui/player_window.py @@ -131,6 +131,22 @@ logger = logging.getLogger(__name__) +_http_get_loader: Callable[[], Callable[..., object] | None] | None = None + + +def set_player_window_http_get_loader(loader: Callable[[], Callable[..., object] | None] | None) -> None: + global _http_get_loader + _http_get_loader = loader + + +def _resolve_http_get() -> Callable[..., object]: + if _http_get_loader is not None: + resolved = _http_get_loader() + if resolved is not None: + return resolved + return httpx.get + + def _summarize_media_url(url: str) -> str: if url.startswith("data:application/dash+xml;base64,"): return "data:application/dash+xml;base64,..." @@ -2150,7 +2166,6 @@ def load() -> None: image_url, self._POSTER_SIZE, timeout=self._POSTER_REQUEST_TIMEOUT_SECONDS, - get=httpx.get, ) if self._is_window_alive(): if target == "video": @@ -6031,7 +6046,7 @@ def _fetch_external_subtitle_text(self, subtitle: ExternalSubtitleOption) -> str return subtitle_path.read_text(encoding="utf-8") current_item = self._current_play_item() headers = {} if current_item is None else dict(current_item.headers) - response = httpx.get(subtitle.url, headers=headers, timeout=10.0, follow_redirects=True) + response = _resolve_http_get()(subtitle.url, headers=headers, timeout=10.0, follow_redirects=True) return str(getattr(response, "text", "") or "") def _load_external_subtitle( diff --git a/src/atv_player/ui/poster_loader.py b/src/atv_player/ui/poster_loader.py index c17a2e7..0d86d6f 100644 --- a/src/atv_player/ui/poster_loader.py +++ b/src/atv_player/ui/poster_loader.py @@ -41,6 +41,7 @@ "youtu.be", } _proxy_decider_loader: Callable[[], ProxyDecider | None] | None = None +_http_get_loader: Callable[[], Callable[..., object] | None] | None = None def _looks_like_unsupported_page_url(source: str) -> bool: @@ -94,6 +95,21 @@ def set_proxy_decider_loader(loader: Callable[[], ProxyDecider | None] | None) - _proxy_decider_loader = loader +def set_http_get_loader(loader: Callable[[], Callable[..., object] | None] | None) -> None: + global _http_get_loader + _http_get_loader = loader + + +def _effective_http_get(get_override: Callable[..., object] | None) -> Callable[..., object]: + if get_override is not None: + return get_override + if _http_get_loader is not None: + resolved = _http_get_loader() + if resolved is not None: + return resolved + return httpx.get + + def _effective_proxy_decider(proxy_decider: ProxyDecider | None) -> ProxyDecider | None: if proxy_decider is not None: return proxy_decider @@ -181,7 +197,7 @@ def load_remote_poster_image( image_url: str, target_size: QSize, timeout: float = POSTER_REQUEST_TIMEOUT_SECONDS, - get=httpx.get, + get: Callable[..., object] | None = None, proxy_decider: ProxyDecider | None = None, ) -> QImage | None: normalized_url = normalize_poster_url(image_url) @@ -193,8 +209,9 @@ def load_remote_poster_image( if cached_image is not None: return cached_image + http_get = _effective_http_get(get) try: - response = get( + response = http_get( normalized_url, headers=build_poster_request_headers(normalized_url), timeout=timeout, diff --git a/tests/test_main_window_network_loaders.py b/tests/test_main_window_network_loaders.py new file mode 100644 index 0000000..1e27f6a --- /dev/null +++ b/tests/test_main_window_network_loaders.py @@ -0,0 +1,140 @@ +from __future__ import annotations + +from typing import Any + +import pytest + +import atv_player.ui.main_window as main_window_module +from atv_player.ui.main_window import ( + load_360_hot_searches, + load_360_search_suggestions, + load_direct_parse_detail, + load_iqiyi_hot_search_sections, + load_tencent_hot_search_sections, + set_main_window_http_get_loader, + set_main_window_http_post_loader, +) + + +class FakeResponse: + def __init__(self, payload: object) -> None: + self._payload = payload + + def json(self) -> object: + return self._payload + + def raise_for_status(self) -> None: + return None + + +@pytest.fixture(autouse=True) +def _reset_main_window_loaders(): + yield + set_main_window_http_get_loader(None) + set_main_window_http_post_loader(None) + + +def test_load_direct_parse_detail_uses_registered_loader() -> None: + calls: list[str] = [] + + def fake_get(url: str, **kwargs: Any) -> FakeResponse: + calls.append(url) + return FakeResponse({"list": []}) + + set_main_window_http_get_loader(lambda: fake_get) + + payload = load_direct_parse_detail("https://example.com/video") + + assert payload == {"list": []} + assert calls == [main_window_module._DIRECT_PARSE_DETAIL_API] + + +def test_load_360_hot_searches_uses_registered_loader() -> None: + calls: list[str] = [] + + def fake_get(url: str, **kwargs: Any) -> FakeResponse: + calls.append(url) + return FakeResponse({"data": [{"title": "热门"}]}) + + set_main_window_http_get_loader(lambda: fake_get) + + result = load_360_hot_searches() + + assert result == [{"title": "热门", "query": "热门"}] + assert calls == [main_window_module._HOTKEY_360_API] + + +def test_load_iqiyi_hot_search_sections_uses_registered_loader() -> None: + calls: list[str] = [] + + def fake_get(url: str, **kwargs: Any) -> FakeResponse: + calls.append(url) + return FakeResponse({}) + + set_main_window_http_get_loader(lambda: fake_get) + + load_iqiyi_hot_search_sections() + + assert calls == [main_window_module._HOTKEY_IQIYI_API] + + +def test_load_360_search_suggestions_uses_registered_loader() -> None: + calls: list[str] = [] + + def fake_get(url: str, **kwargs: Any) -> FakeResponse: + calls.append(url) + return FakeResponse({"result": []}) + + set_main_window_http_get_loader(lambda: fake_get) + + load_360_search_suggestions("keyword") + + assert calls == [main_window_module._SUGGESTION_360_API] + + +def test_load_tencent_hot_search_sections_uses_registered_post_loader() -> None: + calls: list[str] = [] + + def fake_post(url: str, **kwargs: Any) -> FakeResponse: + calls.append(url) + return FakeResponse({}) + + set_main_window_http_post_loader(lambda: fake_post) + + load_tencent_hot_search_sections() + + assert calls == [main_window_module._HOTKEY_TENCENT_API] + + +def test_get_loader_returning_none_falls_back_to_httpx(monkeypatch) -> None: + calls: list[str] = [] + + def fake_get(url: str, **kwargs: Any) -> FakeResponse: + calls.append(url) + return FakeResponse({}) + + import httpx as httpx_module + + monkeypatch.setattr(httpx_module, "get", fake_get) + set_main_window_http_get_loader(lambda: None) + + load_360_search_suggestions("kw") + + assert calls == [main_window_module._SUGGESTION_360_API] + + +def test_post_loader_returning_none_falls_back_to_httpx(monkeypatch) -> None: + calls: list[str] = [] + + def fake_post(url: str, **kwargs: Any) -> FakeResponse: + calls.append(url) + return FakeResponse({}) + + import httpx as httpx_module + + monkeypatch.setattr(httpx_module, "post", fake_post) + set_main_window_http_post_loader(lambda: None) + + load_tencent_hot_search_sections() + + assert calls == [main_window_module._HOTKEY_TENCENT_API] diff --git a/tests/test_network_client.py b/tests/test_network_client.py new file mode 100644 index 0000000..df7cb12 --- /dev/null +++ b/tests/test_network_client.py @@ -0,0 +1,363 @@ +from __future__ import annotations + +import threading +from collections.abc import Callable, Iterable +from unittest.mock import MagicMock + +import httpx +import pytest +import requests + +from atv_player.network_client import NetworkClient +from atv_player.network_proxy import ProxyConfig, ProxyDecider + + +def _make_decider( + mode: str = "direct", + proxy_url: str = "", + bypass: Iterable[str] = (), +) -> ProxyDecider: + return ProxyDecider( + ProxyConfig(mode=mode, proxy_url=proxy_url, bypass_rules=list(bypass)) + ) + + +def _client_factory() -> tuple[Callable[[str, int], MagicMock], dict[str, MagicMock]]: + clients: dict[str, MagicMock] = {} + + def factory(key: str, pool_limit: int) -> MagicMock: + clients[key] = MagicMock(name=f"client[{key}]") + return clients[key] + + return factory, clients + + +def _session_factory() -> tuple[Callable[[int], MagicMock], list[MagicMock]]: + sessions: list[MagicMock] = [] + + def factory(pool_limit: int) -> MagicMock: + sessions.append(MagicMock(name=f"session#{len(sessions)}")) + return sessions[-1] + + return factory, sessions + + +def test_get_with_none_decider_uses_direct_client() -> None: + factory, clients = _client_factory() + network = NetworkClient(lambda: None, client_factory=factory) + + network.get("https://example.com/") + + assert list(clients.keys()) == ["direct"] + clients["direct"].get.assert_called_once_with("https://example.com/") + + +def test_get_caches_client_by_url_decision() -> None: + factory, clients = _client_factory() + network = NetworkClient(lambda: None, client_factory=factory) + + network.get("https://example.com/a") + network.get("https://example.com/b") + network.get("https://example.org/c") + + assert list(clients.keys()) == ["direct"] + assert clients["direct"].get.call_count == 3 + + +def test_get_routes_bypass_and_proxy_to_separate_clients() -> None: + decider = _make_decider("http", "http://p:7890", bypass=["localhost"]) + factory, clients = _client_factory() + network = NetworkClient(lambda: decider, client_factory=factory) + + network.get("http://localhost/api") + network.get("https://api.themoviedb.org/") + + assert set(clients.keys()) == {"direct", "manual:http://p:7890"} + clients["direct"].get.assert_called_once_with("http://localhost/api") + clients["manual:http://p:7890"].get.assert_called_once_with( + "https://api.themoviedb.org/" + ) + + +def test_get_routes_system_mode_to_system_client() -> None: + decider = _make_decider("system") + factory, clients = _client_factory() + network = NetworkClient(lambda: decider, client_factory=factory) + + network.get("https://api.bgm.tv/") + + assert list(clients.keys()) == ["system"] + + +def test_get_with_non_http_url_uses_direct() -> None: + decider = _make_decider("http", "http://p:7890") + factory, clients = _client_factory() + network = NetworkClient(lambda: decider, client_factory=factory) + + network.get("file:///tmp/x") + + assert list(clients.keys()) == ["direct"] + + +def test_get_passes_kwargs_to_underlying_client() -> None: + factory, clients = _client_factory() + network = NetworkClient(lambda: None, client_factory=factory) + + network.get("https://example.com/", timeout=5, headers={"X": "1"}) + + clients["direct"].get.assert_called_once_with( + "https://example.com/", timeout=5, headers={"X": "1"} + ) + + +def test_post_uses_same_client_pool() -> None: + factory, clients = _client_factory() + network = NetworkClient(lambda: None, client_factory=factory) + + network.post("https://example.com/", json={"a": 1}) + + clients["direct"].post.assert_called_once_with("https://example.com/", json={"a": 1}) + + +def test_stream_uses_correct_client() -> None: + factory, clients = _client_factory() + network = NetworkClient(lambda: None, client_factory=factory) + + network.stream("GET", "https://example.com/") + + clients["direct"].stream.assert_called_once_with("GET", "https://example.com/") + + +def test_invalidate_proxy_keeps_direct_client_open() -> None: + decider_box = {"d": _make_decider("http", "http://p:7890", bypass=["localhost"])} + factory, clients = _client_factory() + network = NetworkClient(lambda: decider_box["d"], client_factory=factory) + network.get("http://localhost/") + network.get("https://other.com/") + direct_client = clients["direct"] + + network.invalidate_proxy() + + direct_client.close.assert_not_called() + + +def test_invalidate_proxy_closes_proxied_clients() -> None: + decider_box = {"d": _make_decider("http", "http://p:7890")} + factory, clients = _client_factory() + network = NetworkClient(lambda: decider_box["d"], client_factory=factory) + network.get("https://x.com/") + manual_client = clients["manual:http://p:7890"] + + network.invalidate_proxy() + + manual_client.close.assert_called_once() + + +def test_invalidate_proxy_rebuilds_manual_client_for_next_request() -> None: + decider_box = {"d": _make_decider("http", "http://a:7890")} + factory, clients = _client_factory() + network = NetworkClient(lambda: decider_box["d"], client_factory=factory) + network.get("https://x.com/") + + decider_box["d"] = _make_decider("http", "http://b:7890") + network.invalidate_proxy() + network.get("https://x.com/") + + assert set(clients.keys()) == {"manual:http://a:7890", "manual:http://b:7890"} + + +def test_invalidate_proxy_re_reads_decider_factory() -> None: + decider_calls = 0 + + def decider_factory() -> ProxyDecider | None: + nonlocal decider_calls + decider_calls += 1 + return _make_decider("direct") + + factory, _ = _client_factory() + network = NetworkClient(decider_factory, client_factory=factory) + + network.get("https://a/") + network.get("https://b/") + assert decider_calls == 1 + + network.invalidate_proxy() + network.get("https://c/") + assert decider_calls == 2 + + +def test_close_closes_all_clients_and_session() -> None: + decider = _make_decider("http", "http://p:7890", bypass=["localhost"]) + factory, clients = _client_factory() + session_factory_fn, sessions = _session_factory() + network = NetworkClient( + lambda: decider, + client_factory=factory, + session_factory=session_factory_fn, + ) + network.get("http://localhost/") + network.get("https://x.com/") + network.requests_session() + + network.close() + + clients["direct"].close.assert_called_once() + clients["manual:http://p:7890"].close.assert_called_once() + sessions[0].close.assert_called_once() + + +def test_close_is_idempotent() -> None: + factory, clients = _client_factory() + network = NetworkClient(lambda: None, client_factory=factory) + network.get("https://a/") + + network.close() + network.close() + + assert clients["direct"].close.call_count == 1 + + +def test_get_after_close_raises() -> None: + factory, _ = _client_factory() + network = NetworkClient(lambda: None, client_factory=factory) + network.close() + + with pytest.raises(RuntimeError): + network.get("https://a/") + + +def test_requests_session_reused() -> None: + factory, _ = _client_factory() + session_factory_fn, sessions = _session_factory() + network = NetworkClient( + lambda: None, + client_factory=factory, + session_factory=session_factory_fn, + ) + + s1 = network.requests_session() + s2 = network.requests_session() + + assert s1 is s2 + assert len(sessions) == 1 + + +def test_requests_session_rebuilt_after_invalidate() -> None: + factory, _ = _client_factory() + session_factory_fn, sessions = _session_factory() + network = NetworkClient( + lambda: None, + client_factory=factory, + session_factory=session_factory_fn, + ) + + s1 = network.requests_session() + network.invalidate_proxy() + s2 = network.requests_session() + + assert s1 is not s2 + s1.close.assert_called_once() + assert len(sessions) == 2 + + +def test_client_factory_receives_pool_limit() -> None: + received: list[int] = [] + + def factory(key: str, pool_limit: int) -> MagicMock: + received.append(pool_limit) + return MagicMock() + + network = NetworkClient(lambda: None, client_factory=factory, pool_limit=42) + network.get("https://x.com/") + + assert received == [42] + + +def test_session_factory_receives_pool_limit() -> None: + received: list[int] = [] + + def session_factory(pool_limit: int) -> MagicMock: + received.append(pool_limit) + return MagicMock() + + factory, _ = _client_factory() + network = NetworkClient( + lambda: None, + client_factory=factory, + session_factory=session_factory, + pool_limit=7, + ) + network.requests_session() + + assert received == [7] + + +def test_concurrent_get_creates_one_client_per_key() -> None: + factory, clients = _client_factory() + network = NetworkClient(lambda: None, client_factory=factory) + barrier = threading.Barrier(20) + + def worker() -> None: + barrier.wait() + for _ in range(5): + network.get("https://example.com/") + + threads = [threading.Thread(target=worker) for _ in range(20)] + for t in threads: + t.start() + for t in threads: + t.join() + + assert list(clients.keys()) == ["direct"] + assert clients["direct"].get.call_count == 100 + + +def test_default_client_factory_returns_httpx_client() -> None: + from atv_player.network_client import _default_client_factory + + client = _default_client_factory("direct", 10) + try: + assert isinstance(client, httpx.Client) + finally: + client.close() + + +def test_default_client_factory_manual_accepts_proxy_url() -> None: + from atv_player.network_client import _default_client_factory + + client = _default_client_factory("manual:http://127.0.0.1:9999", 10) + try: + assert isinstance(client, httpx.Client) + finally: + client.close() + + +def test_default_client_factory_system() -> None: + from atv_player.network_client import _default_client_factory + + client = _default_client_factory("system", 10) + try: + assert isinstance(client, httpx.Client) + finally: + client.close() + + +def test_default_client_factory_unknown_key_raises() -> None: + from atv_player.network_client import _default_client_factory + + with pytest.raises(ValueError): + _default_client_factory("nonsense", 10) + + +def test_default_session_factory_pool_size_matches_argument() -> None: + from atv_player.network_client import _default_session_factory + + session = _default_session_factory(7) + try: + for prefix in ("http://", "https://"): + adapter = session.get_adapter(f"{prefix}example.com") + assert isinstance(adapter, requests.adapters.HTTPAdapter) + assert adapter._pool_connections == 7 + assert adapter._pool_maxsize == 7 + finally: + session.close() diff --git a/tests/test_player_window_network_loader.py b/tests/test_player_window_network_loader.py new file mode 100644 index 0000000..c55c1ab --- /dev/null +++ b/tests/test_player_window_network_loader.py @@ -0,0 +1,73 @@ +from __future__ import annotations + +from typing import Any + +import pytest + +import atv_player.ui.player_window as player_window_module +from atv_player.ui.player_window import set_player_window_http_get_loader + + +class FakeResponse: + def __init__(self, text: str = "stub") -> None: + self.text = text + + def raise_for_status(self) -> None: + return None + + +@pytest.fixture(autouse=True) +def _reset_loader(): + yield + set_player_window_http_get_loader(None) + + +def test_player_window_http_get_loader_default_falls_back_to_httpx(monkeypatch) -> None: + calls: list[str] = [] + + def fake_get(url: str, **kwargs: Any) -> FakeResponse: + calls.append(url) + return FakeResponse() + + import httpx as httpx_module + + monkeypatch.setattr(httpx_module, "get", fake_get) + set_player_window_http_get_loader(None) + + get = player_window_module._resolve_http_get() + get("https://example.com/") + + assert calls == ["https://example.com/"] + + +def test_player_window_http_get_loader_returns_registered() -> None: + calls: list[str] = [] + + def fake_get(url: str, **kwargs: Any) -> FakeResponse: + calls.append(url) + return FakeResponse() + + set_player_window_http_get_loader(lambda: fake_get) + + get = player_window_module._resolve_http_get() + get("https://example.com/") + + assert calls == ["https://example.com/"] + + +def test_player_window_http_get_loader_returning_none_falls_back(monkeypatch) -> None: + calls: list[str] = [] + + def fake_get(url: str, **kwargs: Any) -> FakeResponse: + calls.append(url) + return FakeResponse() + + import httpx as httpx_module + + monkeypatch.setattr(httpx_module, "get", fake_get) + set_player_window_http_get_loader(lambda: None) + + get = player_window_module._resolve_http_get() + get("https://example.com/") + + assert calls == ["https://example.com/"] diff --git a/tests/test_poster_loader_network.py b/tests/test_poster_loader_network.py new file mode 100644 index 0000000..9e60fea --- /dev/null +++ b/tests/test_poster_loader_network.py @@ -0,0 +1,132 @@ +from __future__ import annotations + +from typing import Any + +import pytest +from PySide6.QtCore import QSize +from PySide6.QtGui import QImage + +import atv_player.ui.poster_loader as poster_loader_module +from atv_player.ui.poster_loader import ( + load_remote_poster_image, + set_http_get_loader, +) + + +class FakeResponse: + def __init__(self, content: bytes) -> None: + self.content = content + + def raise_for_status(self) -> None: + return None + + +def _png_bytes(width: int = 20, height: int = 40) -> bytes: + from PySide6.QtCore import QBuffer, QByteArray, QIODeviceBase + + png = QImage(width, height, QImage.Format.Format_RGB32) + png.fill(0x00FF00) + data = QByteArray() + qbuffer = QBuffer(data) + qbuffer.open(QIODeviceBase.OpenModeFlag.WriteOnly) + png.save(qbuffer, "PNG") + return bytes(data) + + +@pytest.fixture(autouse=True) +def _reset_loader(): + yield + set_http_get_loader(None) + + +def test_set_http_get_loader_supplies_default_when_get_not_passed(monkeypatch, tmp_path) -> None: + cache_dir = tmp_path / "posters" + monkeypatch.setattr(poster_loader_module, "poster_cache_dir", lambda: cache_dir) + calls: list[str] = [] + + def fake_get(url: str, **kwargs: Any) -> FakeResponse: + calls.append(url) + return FakeResponse(_png_bytes()) + + set_http_get_loader(lambda: fake_get) + + loaded = load_remote_poster_image( + "https://img3.doubanio.com/view/photo/m/public/p123.jpg", + QSize(90, 130), + ) + + assert loaded is not None + assert calls == ["https://img3.doubanio.com/view/photo/m/public/p123.jpg"] + + +def test_explicit_get_argument_takes_precedence_over_loader(monkeypatch, tmp_path) -> None: + cache_dir = tmp_path / "posters" + monkeypatch.setattr(poster_loader_module, "poster_cache_dir", lambda: cache_dir) + loader_calls: list[str] = [] + explicit_calls: list[str] = [] + + def loader_get(url: str, **kwargs: Any) -> FakeResponse: + loader_calls.append(url) + return FakeResponse(_png_bytes()) + + def explicit_get(url: str, **kwargs: Any) -> FakeResponse: + explicit_calls.append(url) + return FakeResponse(_png_bytes()) + + set_http_get_loader(lambda: loader_get) + + loaded = load_remote_poster_image( + "https://img3.doubanio.com/view/photo/m/public/p123.jpg", + QSize(90, 130), + get=explicit_get, + ) + + assert loaded is not None + assert explicit_calls == ["https://img3.doubanio.com/view/photo/m/public/p123.jpg"] + assert loader_calls == [] + + +def test_falls_back_to_httpx_get_when_loader_unset(monkeypatch, tmp_path) -> None: + cache_dir = tmp_path / "posters" + monkeypatch.setattr(poster_loader_module, "poster_cache_dir", lambda: cache_dir) + monkeypatch.setattr(poster_loader_module, "_http_get_loader", None, raising=False) + calls: list[str] = [] + + def fake_get(url: str, **kwargs: Any) -> FakeResponse: + calls.append(url) + return FakeResponse(_png_bytes()) + + import httpx as httpx_module + + monkeypatch.setattr(httpx_module, "get", fake_get) + + loaded = load_remote_poster_image( + "https://img3.doubanio.com/view/photo/m/public/p123.jpg", + QSize(90, 130), + ) + + assert loaded is not None + assert calls == ["https://img3.doubanio.com/view/photo/m/public/p123.jpg"] + + +def test_falls_back_to_httpx_get_when_loader_returns_none(monkeypatch, tmp_path) -> None: + cache_dir = tmp_path / "posters" + monkeypatch.setattr(poster_loader_module, "poster_cache_dir", lambda: cache_dir) + calls: list[str] = [] + + def fake_get(url: str, **kwargs: Any) -> FakeResponse: + calls.append(url) + return FakeResponse(_png_bytes()) + + import httpx as httpx_module + + monkeypatch.setattr(httpx_module, "get", fake_get) + set_http_get_loader(lambda: None) + + loaded = load_remote_poster_image( + "https://img3.doubanio.com/view/photo/m/public/p123.jpg", + QSize(90, 130), + ) + + assert loaded is not None + assert calls == ["https://img3.doubanio.com/view/photo/m/public/p123.jpg"] diff --git a/tests/test_proxy_invalidation_wrapper.py b/tests/test_proxy_invalidation_wrapper.py new file mode 100644 index 0000000..7db8de3 --- /dev/null +++ b/tests/test_proxy_invalidation_wrapper.py @@ -0,0 +1,106 @@ +from __future__ import annotations + +from types import SimpleNamespace + +from atv_player.app import _make_proxy_invalidation_wrapper, _proxy_signature +from atv_player.models import AppConfig + + +def _make_config(mode: str = "direct", proxy_url: str = "", bypass: list[str] | None = None) -> AppConfig: + return AppConfig( + network_proxy_mode=mode, + network_proxy_url=proxy_url, + network_proxy_bypass_rules=list(bypass) if bypass is not None else [], + ) + + +def _make_network_stub() -> tuple[SimpleNamespace, list[str]]: + calls: list[str] = [] + stub = SimpleNamespace(invalidate_proxy=lambda: calls.append("invalidate")) + return stub, calls + + +def test_save_runs_underlying_callback() -> None: + config = _make_config() + network, _ = _make_network_stub() + save_calls = [0] + + def save() -> None: + save_calls[0] += 1 + + wrapped = _make_proxy_invalidation_wrapper(save, config, network) + wrapped() + wrapped() + + assert save_calls[0] == 2 + + +def test_save_does_not_invalidate_when_proxy_unchanged() -> None: + config = _make_config("direct", "", ["localhost"]) + network, calls = _make_network_stub() + + wrapped = _make_proxy_invalidation_wrapper(lambda: None, config, network) + wrapped() + wrapped() + + assert calls == [] + + +def test_save_invalidates_after_proxy_mode_changes() -> None: + config = _make_config("direct") + network, calls = _make_network_stub() + wrapped = _make_proxy_invalidation_wrapper(lambda: None, config, network) + wrapped() + + config.network_proxy_mode = "http" + config.network_proxy_url = "http://127.0.0.1:7890" + wrapped() + + assert calls == ["invalidate"] + + +def test_save_invalidates_after_proxy_url_changes() -> None: + config = _make_config("http", "http://a:7890") + network, calls = _make_network_stub() + wrapped = _make_proxy_invalidation_wrapper(lambda: None, config, network) + wrapped() + + config.network_proxy_url = "http://b:7890" + wrapped() + + assert calls == ["invalidate"] + + +def test_save_invalidates_after_bypass_rules_change() -> None: + config = _make_config("http", "http://a:7890", ["localhost"]) + network, calls = _make_network_stub() + wrapped = _make_proxy_invalidation_wrapper(lambda: None, config, network) + wrapped() + + config.network_proxy_bypass_rules = ["localhost", "10.0.0.0/8"] + wrapped() + + assert calls == ["invalidate"] + + +def test_save_invalidates_once_per_change_then_stays_quiet() -> None: + config = _make_config("direct") + network, calls = _make_network_stub() + wrapped = _make_proxy_invalidation_wrapper(lambda: None, config, network) + wrapped() + + config.network_proxy_mode = "http" + config.network_proxy_url = "http://a:7890" + wrapped() + wrapped() + wrapped() + + assert calls == ["invalidate"] + + +def test_proxy_signature_is_tuple_with_bypass_rules_normalized() -> None: + config = _make_config("http", "http://a:7890", ["localhost", "10.0.0.0/8"]) + + sig = _proxy_signature(config) + + assert sig == ("http", "http://a:7890", ("localhost", "10.0.0.0/8")) diff --git a/tests/test_spider_session_loader.py b/tests/test_spider_session_loader.py new file mode 100644 index 0000000..fbc381b --- /dev/null +++ b/tests/test_spider_session_loader.py @@ -0,0 +1,170 @@ +from __future__ import annotations + +import types +from typing import Any + +import atv_player.plugins.compat.base.spider as compat_spider_module +from atv_player.plugins.compat.base.spider import ( + Spider, + set_proxy_decider_loader, + set_session_loader, +) +from atv_player.network_proxy import ProxyConfig, ProxyDecider + + +class _FakeResponse: + def __init__(self) -> None: + self.encoding = "" + self._content = b"" + + @property + def content(self) -> bytes: + return self._content + + def close(self) -> None: + pass + + +class _FakeSession: + def __init__(self) -> None: + self.get_calls: list[dict[str, Any]] = [] + self.post_calls: list[dict[str, Any]] = [] + self.responses: list[_FakeResponse] = [] + + def _record(self, calls: list[dict[str, Any]], url: str, kwargs: dict[str, Any]) -> _FakeResponse: + calls.append({"url": url, **kwargs}) + response = _FakeResponse() + self.responses.append(response) + return response + + def get(self, url: str, **kwargs: Any) -> _FakeResponse: + return self._record(self.get_calls, url, kwargs) + + def post(self, url: str, **kwargs: Any) -> _FakeResponse: + return self._record(self.post_calls, url, kwargs) + + +def _reset_loaders() -> None: + set_session_loader(None) + set_proxy_decider_loader(None) + + +def test_fetch_uses_injected_session_when_loader_is_set(monkeypatch) -> None: + session = _FakeSession() + set_session_loader(lambda: session) + monkeypatch.setattr( + compat_spider_module, + "requests", + types.SimpleNamespace(get=lambda *a, **k: pytest_fail("requests.get must not be called")), + raising=False, + ) + + try: + Spider().fetch("https://example.com/api", headers={"X": "1"}, timeout=5) + finally: + _reset_loaders() + + assert len(session.get_calls) == 1 + call = session.get_calls[0] + assert call["url"] == "https://example.com/api" + assert call["headers"] == {"X": "1"} + assert call["timeout"] == 5 + + +def test_post_uses_injected_session_when_loader_is_set(monkeypatch) -> None: + session = _FakeSession() + set_session_loader(lambda: session) + monkeypatch.setattr( + compat_spider_module, + "requests", + types.SimpleNamespace(post=lambda *a, **k: pytest_fail("requests.post must not be called")), + raising=False, + ) + + try: + Spider().post("https://example.com/api", json={"x": 1}, timeout=9) + finally: + _reset_loaders() + + assert len(session.post_calls) == 1 + call = session.post_calls[0] + assert call["url"] == "https://example.com/api" + assert call["json"] == {"x": 1} + assert call["timeout"] == 9 + + +def test_fetch_passes_per_request_proxies_from_decider(monkeypatch) -> None: + session = _FakeSession() + decider = ProxyDecider( + ProxyConfig( + mode="http", + proxy_url="http://127.0.0.1:7890", + bypass_rules=["localhost"], + ) + ) + set_session_loader(lambda: session) + set_proxy_decider_loader(lambda: decider) + + try: + Spider().fetch("https://api.example.com/") + Spider().fetch("http://localhost/api") + finally: + _reset_loaders() + + assert session.get_calls[0]["proxies"] == { + "http": "http://127.0.0.1:7890", + "https": "http://127.0.0.1:7890", + } + assert session.get_calls[1]["proxies"] == {"http": None, "https": None} + + +def test_fetch_falls_back_to_requests_when_session_loader_is_none(monkeypatch) -> None: + set_session_loader(None) + calls: list[str] = [] + + def fake_get(url, **kwargs): + calls.append(url) + return _FakeResponse() + + monkeypatch.setattr( + compat_spider_module, + "requests", + types.SimpleNamespace(get=fake_get), + raising=False, + ) + + try: + Spider().fetch("https://example.com/api") + finally: + _reset_loaders() + + assert calls == ["https://example.com/api"] + + +def test_fetch_falls_back_to_requests_when_loader_returns_none(monkeypatch) -> None: + set_session_loader(lambda: None) + calls: list[str] = [] + + def fake_get(url, **kwargs): + calls.append(url) + return _FakeResponse() + + monkeypatch.setattr( + compat_spider_module, + "requests", + types.SimpleNamespace(get=fake_get), + raising=False, + ) + + try: + Spider().fetch("https://example.com/api") + finally: + _reset_loaders() + + assert calls == ["https://example.com/api"] + + +def pytest_fail(message: str) -> None: + import pytest + + pytest.fail(message)