diff --git a/haoinvest/market/sources/_common.py b/haoinvest/market/sources/_common.py index 54a24ae..853704f 100644 --- a/haoinvest/market/sources/_common.py +++ b/haoinvest/market/sources/_common.py @@ -27,21 +27,23 @@ def bypass_proxy(): os.environ.update(saved) -def market_prefix(symbol: str) -> str: - """Return 'sh' or 'sz' based on A-share stock code convention.""" - if symbol.startswith(("6", "9")): - return "sh" - return "sz" +def _is_sh(symbol: str) -> bool: + """Check if symbol belongs to Shanghai Exchange. + + Shanghai codes: 6xxxxx (main/STAR), 9xxxxx, 5xxxxx (ETF/funds including + 51xxxx, 56xxxx cross-market ETFs that route via SH on quote APIs). + """ + return symbol.startswith(("5", "6", "9")) -def secid(symbol: str) -> str: - """Return eastmoney secid like '1.603618' (1=SH, 0=SZ).""" - return f"1.{symbol}" if symbol.startswith(("6", "9")) else f"0.{symbol}" +def market_prefix(symbol: str) -> str: + """Return 'sh' or 'sz' based on A-share stock code convention.""" + return "sh" if _is_sh(symbol) else "sz" def exchange_prefix(symbol: str) -> str: """Return 'SH' or 'SZ' for eastmoney web API code parameter.""" - return "SH" if symbol.startswith(("6", "9")) else "SZ" + return "SH" if _is_sh(symbol) else "SZ" def parse_float(value: Any) -> float | None: diff --git a/tests/test_market/test_sources/test_common.py b/tests/test_market/test_sources/test_common.py new file mode 100644 index 0000000..6107a89 --- /dev/null +++ b/tests/test_market/test_sources/test_common.py @@ -0,0 +1,61 @@ +"""Tests for market source common utilities — prefix routing.""" + +import pytest + +from haoinvest.market.sources._common import ( + exchange_prefix, + market_prefix, +) + + +class TestMarketPrefix: + """Verify symbol-to-exchange prefix routing.""" + + @pytest.mark.parametrize( + "symbol,expected", + [ + # Shanghai main board + ("600519", "sh"), + ("601877", "sh"), + # Shanghai STAR board + ("688001", "sh"), + # Shanghai ETF (51xxxx) + ("511360", "sh"), + ("513130", "sh"), + ("518880", "sh"), + # Cross-market ETF (56xxxx) — must route via sh + ("563020", "sh"), + ("560010", "sh"), + # Shenzhen main board + ("000001", "sz"), + ("000988", "sz"), + # Shenzhen SME + ("002001", "sz"), + ("002463", "sz"), + # Shenzhen ChiNext + ("300750", "sz"), + # Shanghai B-share + ("900001", "sh"), + # Shenzhen ETF (15xxxx) + ("159915", "sz"), + ], + ) + def test_market_prefix(self, symbol: str, expected: str) -> None: + assert market_prefix(symbol) == expected + + +class TestExchangePrefix: + """Verify eastmoney exchange prefix mapping.""" + + @pytest.mark.parametrize( + "symbol,expected", + [ + ("600519", "SH"), + ("563020", "SH"), + ("518880", "SH"), + ("000988", "SZ"), + ("002463", "SZ"), + ], + ) + def test_exchange_prefix(self, symbol: str, expected: str) -> None: + assert exchange_prefix(symbol) == expected