From d4ae32b01052da4020b88a57fb2bc97db712f25d Mon Sep 17 00:00:00 2001 From: Shuhao Qing Date: Tue, 7 Apr 2026 11:51:35 +0800 Subject: [PATCH] feat(market): add retry with exponential backoff to all API calls MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit External API calls (Sina, Tencent, Eastmoney, CoinGecko, exchangerate) had no retry logic, causing failures on transient network errors. Add tenacity-based retry decorator with exponential backoff + jitter (3 attempts, 1s→2s→4s) for connection errors, timeouts, and 5xx. 4xx client errors are not retried. Co-Authored-By: Claude Opus 4.6 --- haoinvest/fx.py | 3 + haoinvest/http_retry.py | 53 ++++++++++++ haoinvest/market/crypto_provider.py | 4 + haoinvest/market/sources/eastmoney.py | 38 +++++---- haoinvest/market/sources/sina.py | 38 +++++---- haoinvest/market/sources/tencent.py | 23 ++++-- pyproject.toml | 1 + tests/test_http_retry.py | 114 ++++++++++++++++++++++++++ uv.lock | 11 +++ 9 files changed, 249 insertions(+), 36 deletions(-) create mode 100644 haoinvest/http_retry.py create mode 100644 tests/test_http_retry.py diff --git a/haoinvest/fx.py b/haoinvest/fx.py index 61ae05a..604ab6e 100644 --- a/haoinvest/fx.py +++ b/haoinvest/fx.py @@ -2,6 +2,8 @@ import httpx +from .http_retry import api_retry + # Hardcoded fallback rates (updated manually when needed) _FALLBACK_RATES = { ("USD", "CNY"): 7.25, @@ -57,6 +59,7 @@ def _get_rate(from_ccy: str, to_ccy: str) -> float: raise ValueError(f"No exchange rate available for {from_ccy} → {to_ccy}") +@api_retry def _fetch_live_rate(from_ccy: str, to_ccy: str) -> float: """Fetch live rate from a free API. Raises on failure.""" # exchangerate-api.com free tier diff --git a/haoinvest/http_retry.py b/haoinvest/http_retry.py new file mode 100644 index 0000000..48624a3 --- /dev/null +++ b/haoinvest/http_retry.py @@ -0,0 +1,53 @@ +"""Shared retry decorator for external API calls. + +Provides exponential backoff with jitter for transient network errors +and server-side failures (5xx). Does NOT retry client errors (4xx). +""" + +import logging + +import httpx +import requests +from tenacity import ( + retry, + retry_if_exception, + stop_after_attempt, + wait_exponential_jitter, +) + +logger = logging.getLogger(__name__) + + +def _is_retryable(exc: BaseException) -> bool: + """Return True for transient errors worth retrying.""" + # Network-level errors (connection refused, DNS failure, timeout) + if isinstance(exc, (requests.ConnectionError, requests.Timeout)): + return True + if isinstance(exc, (httpx.ConnectError, httpx.TimeoutException)): + return True + + # HTTP 5xx server errors + if isinstance(exc, requests.HTTPError) and exc.response is not None: + return exc.response.status_code >= 500 + if isinstance(exc, httpx.HTTPStatusError): + return exc.response.status_code >= 500 + + return False + + +def _log_retry(retry_state) -> None: + logger.debug( + "Retrying %s (attempt %d) after %s", + retry_state.fn.__name__ if retry_state.fn else "unknown", + retry_state.attempt_number, + retry_state.outcome.exception() if retry_state.outcome else "unknown", + ) + + +api_retry = retry( + stop=stop_after_attempt(3), + wait=wait_exponential_jitter(initial=1, max=10, jitter=0.5), + retry=retry_if_exception(_is_retryable), + reraise=True, + before_sleep=_log_retry, +) diff --git a/haoinvest/market/crypto_provider.py b/haoinvest/market/crypto_provider.py index a9e87b5..09ca32f 100644 --- a/haoinvest/market/crypto_provider.py +++ b/haoinvest/market/crypto_provider.py @@ -4,6 +4,7 @@ import httpx +from ..http_retry import api_retry from ..models import BasicInfo, MarketType, PriceBar from .provider import MarketProvider @@ -50,6 +51,7 @@ def client(self) -> httpx.Client: self._client = httpx.Client(timeout=10.0) return self._client + @api_retry def get_current_price(self, symbol: str) -> float: """Get latest USD price for a crypto asset. @@ -67,6 +69,7 @@ def get_current_price(self, symbol: str) -> float: raise ValueError(f"Crypto asset {symbol} (id={coin_id}) not found") return float(data[coin_id]["usd"]) + @api_retry def get_price_history(self, symbol: str, start: date, end: date) -> list[PriceBar]: """Get daily OHLC data for a crypto asset (close prices only from CoinGecko free tier).""" coin_id = _to_coingecko_id(symbol) @@ -93,6 +96,7 @@ def get_price_history(self, symbol: str, start: date, end: date) -> list[PriceBa ) return bars + @api_retry def get_basic_info(self, symbol: str) -> BasicInfo: """Get basic info for a crypto asset.""" coin_id = _to_coingecko_id(symbol) diff --git a/haoinvest/market/sources/eastmoney.py b/haoinvest/market/sources/eastmoney.py index 24794b9..a457213 100644 --- a/haoinvest/market/sources/eastmoney.py +++ b/haoinvest/market/sources/eastmoney.py @@ -9,6 +9,7 @@ import requests +from ...http_retry import api_retry from ...models import BasicInfo from ._common import exchange_prefix, parse_float @@ -22,6 +23,7 @@ ) +@api_retry def get_basic_info(symbol: str) -> BasicInfo: """Get basic company info from eastmoney emweb CompanySurvey API.""" code = f"{exchange_prefix(symbol)}{symbol}" @@ -47,21 +49,7 @@ def get_financial_indicators(symbol: str) -> dict: Gracefully returns empty dict on any failure. """ try: - r = requests.get( - _DATACENTER_URL, - params={ - "reportName": "RPT_LICO_FN_CPD", - "columns": _FIN_COLUMNS, - "filter": f'(SECURITY_CODE="{symbol}")', - "pageNumber": "1", - "pageSize": "1", - "sortColumns": "NOTICE_DATE", - "sortTypes": "-1", - }, - timeout=10, - ) - r.raise_for_status() - body = r.json() + body = _fetch_financial_data(symbol) if not body.get("success") or not body.get("result"): return {} @@ -86,3 +74,23 @@ def get_financial_indicators(symbol: str) -> dict: except Exception as e: logger.debug("Eastmoney financial indicators failed for %s: %s", symbol, e) return {} + + +@api_retry +def _fetch_financial_data(symbol: str) -> dict: + """Fetch financial report data from eastmoney datacenter (with retry).""" + r = requests.get( + _DATACENTER_URL, + params={ + "reportName": "RPT_LICO_FN_CPD", + "columns": _FIN_COLUMNS, + "filter": f'(SECURITY_CODE="{symbol}")', + "pageNumber": "1", + "pageSize": "1", + "sortColumns": "NOTICE_DATE", + "sortTypes": "-1", + }, + timeout=10, + ) + r.raise_for_status() + return r.json() diff --git a/haoinvest/market/sources/sina.py b/haoinvest/market/sources/sina.py index 7425635..1c16516 100644 --- a/haoinvest/market/sources/sina.py +++ b/haoinvest/market/sources/sina.py @@ -10,9 +10,11 @@ import requests +from ...http_retry import api_retry from ._common import market_prefix, parse_float +@api_retry def get_current_price(symbol: str) -> float: """Get current price from Sina Finance API.""" prefix = market_prefix(symbol) @@ -81,20 +83,7 @@ def get_sector_constituents(sector_name: str) -> list[dict]: "Use 'market sector-list' to see available sectors." ) - r = requests.get( - "https://vip.stock.finance.sina.com.cn/quotes_service/api/json_v2.php" - "/Market_Center.getHQNodeData", - params={ - "page": 1, - "num": 200, - "sort": "changepercent", - "asc": 0, - "node": node_code, - }, - headers={"Referer": "https://finance.sina.com.cn"}, - timeout=10, - ) - r.encoding = "gbk" + r = _fetch_sector_constituents(node_code) stocks = r.json() rows = [] @@ -115,6 +104,27 @@ def get_sector_constituents(sector_name: str) -> list[dict]: return rows +@api_retry +def _fetch_sector_constituents(node_code: str) -> requests.Response: + """Fetch sector constituent data from Sina API (with retry).""" + r = requests.get( + "https://vip.stock.finance.sina.com.cn/quotes_service/api/json_v2.php" + "/Market_Center.getHQNodeData", + params={ + "page": 1, + "num": 200, + "sort": "changepercent", + "asc": 0, + "node": node_code, + }, + headers={"Referer": "https://finance.sina.com.cn"}, + timeout=10, + ) + r.encoding = "gbk" + return r + + +@api_retry def _fetch_sector_data() -> dict[str, list[str]]: """Fetch Sina industry board data. Returns {node_code: [fields...]}.""" r = requests.get( diff --git a/haoinvest/market/sources/tencent.py b/haoinvest/market/sources/tencent.py index 9113101..9cc2418 100644 --- a/haoinvest/market/sources/tencent.py +++ b/haoinvest/market/sources/tencent.py @@ -10,6 +10,7 @@ import requests +from ...http_retry import api_retry from ...models import MarketType, PriceBar from ._common import market_prefix, parse_float @@ -21,6 +22,7 @@ _PB = 46 +@api_retry def get_current_price(symbol: str) -> float: """Get current price from Tencent Finance quote API.""" prefix = market_prefix(symbol) @@ -38,6 +40,7 @@ def get_current_price(symbol: str) -> float: return price +@api_retry def get_price_history(symbol: str, start: date, end: date) -> list[PriceBar]: """Get forward-adjusted daily klines from Tencent Finance API.""" prefix = market_prefix(symbol) @@ -90,13 +93,7 @@ def get_valuation(symbol: str) -> dict: "total_market_cap": None, } try: - prefix = market_prefix(symbol) - r = requests.get( - f"https://qt.gtimg.cn/q={prefix}{symbol}", - timeout=10, - ) - r.raise_for_status() - fields = r.text.strip().split("~") + fields = _fetch_quote_fields(symbol) if len(fields) <= _PB: logger.debug( "Tencent response too short for %s: %d fields", symbol, len(fields) @@ -115,3 +112,15 @@ def get_valuation(symbol: str) -> dict: logger.debug("Tencent valuation failed for %s: %s", symbol, e) return result + + +@api_retry +def _fetch_quote_fields(symbol: str) -> list[str]: + """Fetch and parse quote fields from Tencent API (with retry).""" + prefix = market_prefix(symbol) + r = requests.get( + f"https://qt.gtimg.cn/q={prefix}{symbol}", + timeout=10, + ) + r.raise_for_status() + return r.text.strip().split("~") diff --git a/pyproject.toml b/pyproject.toml index 397b7db..f7b4106 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,6 +12,7 @@ dependencies = [ "pandas-ta-classic==0.4.47", "quantstats==0.0.81", "pyportfolioopt==1.6.0", + "tenacity==9.1.4", ] [project.scripts] diff --git a/tests/test_http_retry.py b/tests/test_http_retry.py new file mode 100644 index 0000000..6bfd572 --- /dev/null +++ b/tests/test_http_retry.py @@ -0,0 +1,114 @@ +"""Tests for the http_retry module.""" + +from unittest.mock import MagicMock, patch + +import httpx +import pytest +import requests + +from haoinvest.http_retry import _is_retryable, api_retry + + +class TestIsRetryable: + """Test the _is_retryable predicate.""" + + def test_requests_connection_error(self): + assert _is_retryable(requests.ConnectionError()) is True + + def test_requests_timeout(self): + assert _is_retryable(requests.Timeout()) is True + + def test_httpx_connect_error(self): + assert _is_retryable(httpx.ConnectError("fail")) is True + + def test_httpx_timeout(self): + assert _is_retryable(httpx.ReadTimeout("timeout")) is True + + def test_requests_500_error(self): + resp = MagicMock() + resp.status_code = 500 + exc = requests.HTTPError(response=resp) + assert _is_retryable(exc) is True + + def test_requests_503_error(self): + resp = MagicMock() + resp.status_code = 503 + exc = requests.HTTPError(response=resp) + assert _is_retryable(exc) is True + + def test_requests_400_not_retryable(self): + resp = MagicMock() + resp.status_code = 400 + exc = requests.HTTPError(response=resp) + assert _is_retryable(exc) is False + + def test_requests_404_not_retryable(self): + resp = MagicMock() + resp.status_code = 404 + exc = requests.HTTPError(response=resp) + assert _is_retryable(exc) is False + + def test_httpx_500_error(self): + request = httpx.Request("GET", "https://example.com") + response = httpx.Response(500, request=request) + exc = httpx.HTTPStatusError("500", request=request, response=response) + assert _is_retryable(exc) is True + + def test_httpx_404_not_retryable(self): + request = httpx.Request("GET", "https://example.com") + response = httpx.Response(404, request=request) + exc = httpx.HTTPStatusError("404", request=request, response=response) + assert _is_retryable(exc) is False + + def test_value_error_not_retryable(self): + assert _is_retryable(ValueError("bad data")) is False + + def test_runtime_error_not_retryable(self): + assert _is_retryable(RuntimeError("unexpected")) is False + + +class TestApiRetry: + """Test the api_retry decorator behavior.""" + + @patch("haoinvest.http_retry._log_retry") + def test_retries_on_connection_error_then_succeeds(self, mock_log): + call_count = 0 + + @api_retry + def flaky_call(): + nonlocal call_count + call_count += 1 + if call_count < 3: + raise requests.ConnectionError("connection refused") + return "success" + + result = flaky_call() + assert result == "success" + assert call_count == 3 + + def test_no_retry_on_value_error(self): + call_count = 0 + + @api_retry + def bad_call(): + nonlocal call_count + call_count += 1 + raise ValueError("not found") + + with pytest.raises(ValueError, match="not found"): + bad_call() + assert call_count == 1 + + @patch("haoinvest.http_retry._log_retry") + def test_gives_up_after_max_attempts(self, mock_log): + call_count = 0 + + @api_retry + def always_fails(): + nonlocal call_count + call_count += 1 + raise requests.ConnectionError("always down") + + with pytest.raises(requests.ConnectionError): + always_fails() + assert call_count == 3 diff --git a/uv.lock b/uv.lock index 7522183..996f77b 100644 --- a/uv.lock +++ b/uv.lock @@ -494,6 +494,7 @@ dependencies = [ { name = "pydantic" }, { name = "pyportfolioopt" }, { name = "quantstats" }, + { name = "tenacity" }, { name = "typer" }, { name = "yfinance" }, ] @@ -513,6 +514,7 @@ requires-dist = [ { name = "pydantic", specifier = "==2.12.5" }, { name = "pyportfolioopt", specifier = "==1.6.0" }, { name = "quantstats", specifier = "==0.0.81" }, + { name = "tenacity", specifier = "==9.1.4" }, { name = "typer", specifier = "==0.24.1" }, { name = "yfinance", specifier = "==1.2.0" }, ] @@ -1722,6 +1724,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/99/55/db07de81b5c630da5cbf5c7df646580ca26dfaefa593667fc6f2fe016d2e/tabulate-0.10.0-py3-none-any.whl", hash = "sha256:f0b0622e567335c8fabaaa659f1b33bcb6ddfe2e496071b743aa113f8774f2d3", size = 39814, upload-time = "2026-03-04T18:55:31.284Z" }, ] +[[package]] +name = "tenacity" +version = "9.1.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/47/c6/ee486fd809e357697ee8a44d3d69222b344920433d3b6666ccd9b374630c/tenacity-9.1.4.tar.gz", hash = "sha256:adb31d4c263f2bd041081ab33b498309a57c77f9acf2db65aadf0898179cf93a", size = 49413, upload-time = "2026-02-07T10:45:33.841Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d7/c1/eb8f9debc45d3b7918a32ab756658a0904732f75e555402972246b0b8e71/tenacity-9.1.4-py3-none-any.whl", hash = "sha256:6095a360c919085f28c6527de529e76a06ad89b23659fa881ae0649b867a9d55", size = 28926, upload-time = "2026-02-07T10:45:32.24Z" }, +] + [[package]] name = "threadpoolctl" version = "3.6.0"