From fe1156bc79211f6db22bf3ffb21dfb8050dad085 Mon Sep 17 00:00:00 2001 From: Shuhao Qing Date: Tue, 7 Apr 2026 19:45:10 +0800 Subject: [PATCH 1/5] refactor(cli): extract shared utilities to analysis/cache.py and cli/_shared.py Move _ensure_prices_cached (with gap-fill logic) to analysis/cache.py as a shared function. Extract _init_db and _fetch_current_prices to cli/_shared.py. Update all CLI modules and tests to use the shared imports. Co-Authored-By: Claude Opus 4.6 --- haoinvest/analysis/cache.py | 31 +++++++++++++++++ haoinvest/cli/_shared.py | 28 +++++++++++++++ haoinvest/cli/analyze.py | 61 +++++++++------------------------ haoinvest/cli/guardrails.py | 37 +++++--------------- haoinvest/cli/journal.py | 14 +++----- haoinvest/cli/portfolio.py | 14 +++----- haoinvest/cli/strategy.py | 28 +++------------ tests/test_analysis_report.py | 2 +- tests/test_cli/test_analyze.py | 4 +-- tests/test_cli/test_strategy.py | 4 +-- 10 files changed, 101 insertions(+), 122 deletions(-) create mode 100644 haoinvest/analysis/cache.py create mode 100644 haoinvest/cli/_shared.py diff --git a/haoinvest/analysis/cache.py b/haoinvest/analysis/cache.py new file mode 100644 index 0000000..a5e0e86 --- /dev/null +++ b/haoinvest/analysis/cache.py @@ -0,0 +1,31 @@ +"""Price data caching — ensure price history is available before analysis.""" + +from datetime import date, timedelta + +from ..db import Database +from ..market import get_provider +from ..models import MarketType + + +def ensure_prices_cached( + db: Database, symbol: str, market_type: MarketType, start: date, end: date +) -> None: + """Fetch and cache price history if not already present. + + Includes gap-fill: if cached data doesn't cover the requested start date, + fetches the missing earlier portion. + """ + existing = db.get_prices(symbol, market_type, start, end) + if len(existing) > 10: + earliest_cached = min(b.trade_date for b in existing) + if earliest_cached <= start + timedelta(days=7): + return + provider = get_provider(market_type) + bars = provider.get_price_history(symbol, start, earliest_cached) + if bars: + db.save_prices(bars) + return + provider = get_provider(market_type) + bars = provider.get_price_history(symbol, start, end) + if bars: + db.save_prices(bars) diff --git a/haoinvest/cli/_shared.py b/haoinvest/cli/_shared.py new file mode 100644 index 0000000..0853491 --- /dev/null +++ b/haoinvest/cli/_shared.py @@ -0,0 +1,28 @@ +"""Shared CLI utilities — DB init, price fetching.""" + +from ..db import Database +from ..market import get_provider +from ..models import MarketType +from .formatters import error_output + + +def init_db() -> Database: + """Initialize database with schema.""" + db = Database() + db.init_schema() + return db + + +def fetch_current_prices(db: Database) -> dict[tuple[str, MarketType], float]: + """Fetch current prices for all non-zero holdings.""" + positions = db.get_positions(include_zero=False) + prices: dict[tuple[str, MarketType], float] = {} + for pos in positions: + try: + provider = get_provider(pos.market_type) + prices[(pos.symbol, pos.market_type)] = provider.get_current_price( + pos.symbol + ) + except Exception as e: + error_output(f"Failed to get price for {pos.symbol}: {e}") + return prices diff --git a/haoinvest/cli/analyze.py b/haoinvest/cli/analyze.py index a1b39ad..dbf6add 100644 --- a/haoinvest/cli/analyze.py +++ b/haoinvest/cli/analyze.py @@ -5,13 +5,14 @@ import typer +from ..analysis.cache import ensure_prices_cached from ..analysis.fundamental import analyze_stock from ..analysis.risk import calculate_risk_metrics, portfolio_correlation from ..analysis.signals import aggregate_signals from ..analysis.technical import analyze_technical, analyze_technical_multi from ..analysis.volume import analyze_volume -from ..db import Database from ..models import MarketType +from ._shared import init_db from .formatters import ( error_output, json_output, @@ -24,36 +25,6 @@ app = typer.Typer(help="Analysis — fundamental, risk, technical, volume, signals.") -def _init_db() -> Database: - db = Database() - db.init_schema() - return db - - -def _ensure_prices_cached( - db: Database, symbol: str, market_type: MarketType, start: date, end: date -) -> None: - """Fetch and cache price history if not already present.""" - from ..market import get_provider - - existing = db.get_prices(symbol, market_type, start, end) - if len(existing) > 10: - # Check if cached data covers the requested start date. - # If not, fetch the missing earlier portion. - earliest_cached = min(b.trade_date for b in existing) - if earliest_cached <= start + timedelta(days=7): - return - provider = get_provider(market_type) - bars = provider.get_price_history(symbol, start, earliest_cached) - if bars: - db.save_prices(bars) - return - provider = get_provider(market_type) - bars = provider.get_price_history(symbol, start, end) - if bars: - db.save_prices(bars) - - @app.command() def fundamental( symbol: str = typer.Argument( @@ -178,13 +149,13 @@ def risk( use_json: bool = typer.Option(False, "--json", help="Output as JSON"), ) -> None: """Risk metrics — volatility, drawdown, Sharpe ratio, Sortino ratio.""" - db = _init_db() + db = init_db() end_date = date.fromisoformat(end) if end else date.today() start_date = date.fromisoformat(start) if start else end_date - timedelta(days=365) if symbol: mt = MarketType(market_type) if market_type else _detect_market_type(symbol) - _ensure_prices_cached(db, symbol, mt, start_date, end_date) + ensure_prices_cached(db, symbol, mt, start_date, end_date) result = calculate_risk_metrics(db, symbol, mt, start_date, end_date) output = {"symbol": symbol, **result.model_dump()} if use_json: @@ -199,7 +170,7 @@ def risk( return results = [] for pos in positions: - _ensure_prices_cached(db, pos.symbol, pos.market_type, start_date, end_date) + ensure_prices_cached(db, pos.symbol, pos.market_type, start_date, end_date) metrics = calculate_risk_metrics( db, pos.symbol, pos.market_type, start_date, end_date ) @@ -223,7 +194,7 @@ def correlation( use_json: bool = typer.Option(False, "--json", help="Output as JSON"), ) -> None: """Correlation matrix between assets.""" - db = _init_db() + db = init_db() end_date = date.fromisoformat(end) if end else date.today() start_date = date.fromisoformat(start) if start else end_date - timedelta(days=365) @@ -231,7 +202,7 @@ def correlation( pairs = [] for s in symbol_list: mt = MarketType(market_type) if market_type else _detect_market_type(s) - _ensure_prices_cached(db, s, mt, start_date, end_date) + ensure_prices_cached(db, s, mt, start_date, end_date) pairs.append((s, mt)) result = portfolio_correlation(db, pairs, start_date, end_date) @@ -270,7 +241,7 @@ def technical( use_json: bool = typer.Option(False, "--json", help="Output as JSON"), ) -> None: """Technical indicators — MA, MACD, RSI, Bollinger Bands (daily/weekly/monthly).""" - db = _init_db() + db = init_db() end_date = date.fromisoformat(end) if end else date.today() start_date = date.fromisoformat(start) if start else end_date - timedelta(days=1095) symbol_list = [s.strip() for s in symbol.split(",")] @@ -278,7 +249,7 @@ def technical( if len(symbol_list) == 1: # Single symbol — multi-timeframe output mt = MarketType(market_type) if market_type else _detect_market_type(symbol) - _ensure_prices_cached(db, symbol, mt, start_date, end_date) + ensure_prices_cached(db, symbol, mt, start_date, end_date) multi = analyze_technical_multi( db, symbol, mt, start_date, end_date, verbose=verbose ) @@ -350,7 +321,7 @@ def technical( rows = [] for s in symbol_list: mt = MarketType(market_type) if market_type else _detect_market_type(s) - _ensure_prices_cached(db, s, mt, start_date, end_date) + ensure_prices_cached(db, s, mt, start_date, end_date) result = analyze_technical(db, s, mt, start_date, end_date) if result.message: rows.append({"Symbol": s, "Trend": result.message}) @@ -401,12 +372,12 @@ def volume( use_json: bool = typer.Option(False, "--json", help="Output as JSON"), ) -> None: """Volume analysis — anomaly detection, turnover ratio.""" - db = _init_db() + db = init_db() mt = MarketType(market_type) if market_type else _detect_market_type(symbol) end_date = date.fromisoformat(end) if end else date.today() start_date = date.fromisoformat(start) if start else end_date - timedelta(days=365) - _ensure_prices_cached(db, symbol, mt, start_date, end_date) + ensure_prices_cached(db, symbol, mt, start_date, end_date) result = analyze_volume(db, symbol, mt, start_date, end_date, verbose=verbose) if use_json: @@ -448,12 +419,12 @@ def signals( use_json: bool = typer.Option(False, "--json", help="Output as JSON"), ) -> None: """Signal summary — aggregated technical view.""" - db = _init_db() + db = init_db() mt = MarketType(market_type) if market_type else _detect_market_type(symbol) end_date = date.fromisoformat(end) if end else date.today() start_date = date.fromisoformat(start) if start else end_date - timedelta(days=365) - _ensure_prices_cached(db, symbol, mt, start_date, end_date) + ensure_prices_cached(db, symbol, mt, start_date, end_date) result = aggregate_signals(db, symbol, mt, start_date, end_date, verbose=verbose) if use_json: @@ -524,12 +495,12 @@ def report( """综合分析报告 — full report with buy-readiness checklist.""" from ..analysis.report import full_stock_report - db = _init_db() + db = init_db() mt = MarketType(market_type) if market_type else _detect_market_type(symbol) end_date = date.fromisoformat(end) if end else date.today() start_date = date.fromisoformat(start) if start else end_date - timedelta(days=365) - _ensure_prices_cached(db, symbol, mt, start_date, end_date) + ensure_prices_cached(db, symbol, mt, start_date, end_date) try: r = full_stock_report( diff --git a/haoinvest/cli/guardrails.py b/haoinvest/cli/guardrails.py index 9de4aee..756f3e9 100644 --- a/haoinvest/cli/guardrails.py +++ b/haoinvest/cli/guardrails.py @@ -4,36 +4,15 @@ import typer -from ..db import Database from ..market import get_provider from ..models import MarketType +from ._shared import fetch_current_prices, init_db from .formatters import error_output, json_output, kv_output from .market import _detect_market_type app = typer.Typer(help="Guardrails — position rules, alerts, trade review.") -def _init_db() -> Database: - db = Database() - db.init_schema() - return db - - -def _fetch_current_prices(db: Database) -> dict[tuple[str, MarketType], float]: - """Fetch current prices for all holdings.""" - positions = db.get_positions(include_zero=False) - prices: dict[tuple[str, MarketType], float] = {} - for pos in positions: - try: - provider = get_provider(pos.market_type) - prices[(pos.symbol, pos.market_type)] = provider.get_current_price( - pos.symbol - ) - except Exception as e: - error_output(f"Failed to get price for {pos.symbol}: {e}") - return prices - - @app.command("health-check") def health_check_cmd( cash: float = typer.Option(0.0, "--cash", help="Current cash balance"), @@ -42,8 +21,8 @@ def health_check_cmd( """Check current portfolio against guardrail rules.""" from ..guardrails.rules import health_check - db = _init_db() - prices = _fetch_current_prices(db) + db = init_db() + prices = fetch_current_prices(db) result = health_check(db, prices, cash_balance=cash) if use_json: @@ -65,8 +44,8 @@ def alerts_cmd( """Scan all positions for threshold violations.""" from ..guardrails.alerts import scan_alerts - db = _init_db() - prices = _fetch_current_prices(db) + db = init_db() + prices = fetch_current_prices(db) alerts = scan_alerts(db, prices) if use_json: @@ -96,7 +75,7 @@ def config_cmd( """View or set guardrail configuration.""" from ..guardrails.rules import load_config - db = _init_db() + db = init_db() if set_value: if "=" not in set_value: @@ -132,7 +111,7 @@ def pre_trade_data_cmd( from ..guardrails.pre_trade_data import collect_pre_trade_data mt = MarketType(market_type) if market_type else _detect_market_type(symbol) - db = _init_db() + db = init_db() # Get current price if not specified trade_price = price @@ -144,7 +123,7 @@ def pre_trade_data_cmd( error_output(f"Failed to get price for {symbol}: {e}") raise typer.Exit(1) - prices = _fetch_current_prices(db) + prices = fetch_current_prices(db) prices[(symbol, mt)] = trade_price result = collect_pre_trade_data( diff --git a/haoinvest/cli/journal.py b/haoinvest/cli/journal.py index 2e45275..1f5e4a2 100644 --- a/haoinvest/cli/journal.py +++ b/haoinvest/cli/journal.py @@ -4,20 +4,14 @@ import typer -from ..db import Database from ..journal import JournalManager from ..models import DecisionType, Emotion +from ._shared import init_db from .formatters import error_output, json_output, kv_output, tsv_output app = typer.Typer(help="Journal — record decisions, review patterns.") -def _init_db() -> Database: - db = Database() - db.init_schema() - return db - - @app.command() def add( content: str = typer.Argument(help="Journal entry content"), @@ -40,7 +34,7 @@ def add( em = Emotion(emotion) if emotion else None related = [s.strip() for s in symbols.split(",")] if symbols else [] - db = _init_db() + db = init_db() jm = JournalManager(db) entry_id = jm.create_entry( content, decision_type=dt, emotion=em, related_symbols=related @@ -62,7 +56,7 @@ def list_entries( use_json: bool = typer.Option(False, "--json", help="Output as JSON"), ) -> None: """View recent journal entries.""" - db = _init_db() + db = init_db() jm = JournalManager(db) entries = jm.get_entries(symbol=symbol, limit=limit) @@ -100,7 +94,7 @@ def review( use_json: bool = typer.Option(False, "--json", help="Output as JSON"), ) -> None: """Get decision stats or retrospective context for AI analysis.""" - db = _init_db() + db = init_db() jm = JournalManager(db) if entry_id is not None: diff --git a/haoinvest/cli/portfolio.py b/haoinvest/cli/portfolio.py index c13ee26..4651e08 100644 --- a/haoinvest/cli/portfolio.py +++ b/haoinvest/cli/portfolio.py @@ -5,29 +5,23 @@ import typer -from ..db import Database from ..market import get_provider from ..models import MarketType, Transaction, TransactionAction from ..portfolio.manager import PortfolioManager from ..portfolio.returns import portfolio_returns_summary, realized_pnl, unrealized_pnl +from ._shared import init_db from .formatters import error_output, json_output, kv_output, tsv_output from .market import _detect_market_type app = typer.Typer(help="Portfolio — holdings, trades, returns.") -def _init_db() -> Database: - db = Database() - db.init_schema() - return db - - @app.command("list") def list_holdings( use_json: bool = typer.Option(False, "--json", help="Output as JSON"), ) -> None: """View all current holdings.""" - db = _init_db() + db = init_db() pm = PortfolioManager(db) summary = pm.get_portfolio_summary() @@ -98,7 +92,7 @@ def add_trade( note=note, ) - db = _init_db() + db = init_db() # Advisory guardrail check before trade try: @@ -148,7 +142,7 @@ def returns( use_json: bool = typer.Option(False, "--json", help="Output as JSON"), ) -> None: """View returns — unrealized P&L for holdings.""" - db = _init_db() + db = init_db() if symbol: mt = MarketType(market_type) if market_type else _detect_market_type(symbol) diff --git a/haoinvest/cli/strategy.py b/haoinvest/cli/strategy.py index 307eb3e..cafa2f8 100644 --- a/haoinvest/cli/strategy.py +++ b/haoinvest/cli/strategy.py @@ -6,36 +6,18 @@ import typer -from ..db import Database +from ..analysis.cache import ensure_prices_cached from ..market import get_provider from ..models import MarketType from ..strategy.optimizer import suggest_allocation from ..strategy.rebalance import calculate_rebalance +from ._shared import init_db from .formatters import error_output, json_output, kv_output, tsv_output from .market import _detect_market_type app = typer.Typer(help="Strategy — optimize allocation, rebalance.") -def _init_db() -> Database: - db = Database() - db.init_schema() - return db - - -def _ensure_prices_cached( - db: Database, symbol: str, market_type: MarketType, start: date, end: date -) -> None: - """Fetch and cache price history if not already present.""" - existing = db.get_prices(symbol, market_type, start, end) - if len(existing) > 10: - return - provider = get_provider(market_type) - bars = provider.get_price_history(symbol, start, end) - if bars: - db.save_prices(bars) - - @app.command() def optimize( method: str = typer.Option( @@ -52,7 +34,7 @@ def optimize( use_json: bool = typer.Option(False, "--json", help="Output as JSON"), ) -> None: """Suggest optimal portfolio allocation.""" - db = _init_db() + db = init_db() end_date = date.today() start_date = date.fromisoformat(start) if start else end_date - timedelta(days=365) @@ -68,7 +50,7 @@ def optimize( # Ensure price data cached for symbol, mt in pairs: - _ensure_prices_cached(db, symbol, mt, start_date, end_date) + ensure_prices_cached(db, symbol, mt, start_date, end_date) try: result = suggest_allocation( @@ -101,7 +83,7 @@ def rebalance( use_json: bool = typer.Option(False, "--json", help="Output as JSON"), ) -> None: """Calculate rebalance trades to reach target allocation.""" - db = _init_db() + db = init_db() if target is None: error_output("--target is required. Provide target weights as JSON.") diff --git a/tests/test_analysis_report.py b/tests/test_analysis_report.py index cf880d8..1cacb1e 100644 --- a/tests/test_analysis_report.py +++ b/tests/test_analysis_report.py @@ -262,7 +262,7 @@ def test_report_command(self, tmp_path, monkeypatch): ), ) with ( - patch("haoinvest.cli.analyze._ensure_prices_cached"), + patch("haoinvest.cli.analyze.ensure_prices_cached"), patch( "haoinvest.analysis.report.full_stock_report", return_value=mock_report, diff --git a/tests/test_cli/test_analyze.py b/tests/test_cli/test_analyze.py index 1510f2f..a12fe83 100644 --- a/tests/test_cli/test_analyze.py +++ b/tests/test_cli/test_analyze.py @@ -181,7 +181,7 @@ def test_technical_batch_two_symbols(self, tmp_path, monkeypatch): bollinger=BollingerBands(position="下轨附近"), ), ] - with patch("haoinvest.cli.analyze._ensure_prices_cached"): + with patch("haoinvest.cli.analyze.ensure_prices_cached"): with patch( "haoinvest.cli.analyze.analyze_technical", side_effect=mock_results ): @@ -209,7 +209,7 @@ def test_risk_single_symbol(self, tmp_path, monkeypatch): total_return_pct=12.0, num_days=252, ) - with patch("haoinvest.cli.analyze._ensure_prices_cached"): + with patch("haoinvest.cli.analyze.ensure_prices_cached"): with patch( "haoinvest.cli.analyze.calculate_risk_metrics", return_value=mock_metrics, diff --git a/tests/test_cli/test_strategy.py b/tests/test_cli/test_strategy.py index c2b7d7d..d221717 100644 --- a/tests/test_cli/test_strategy.py +++ b/tests/test_cli/test_strategy.py @@ -24,7 +24,7 @@ def test_optimize_with_symbols(self, tmp_path, monkeypatch): weights={"600519": 0.5, "000001": 0.5}, explanation="等权配置", ) - with patch("haoinvest.cli.strategy._ensure_prices_cached"): + with patch("haoinvest.cli.strategy.ensure_prices_cached"): with patch( "haoinvest.cli.strategy.suggest_allocation", return_value=mock_result ): @@ -45,7 +45,7 @@ def test_optimize_with_symbols(self, tmp_path, monkeypatch): def test_optimize_invalid_method(self, tmp_path, monkeypatch): monkeypatch.setenv("HAOINVEST_DATA_DIR", str(tmp_path)) - with patch("haoinvest.cli.strategy._ensure_prices_cached"): + with patch("haoinvest.cli.strategy.ensure_prices_cached"): with patch( "haoinvest.cli.strategy.suggest_allocation", side_effect=ValueError("Unknown method"), From f65c9f4275196c7da47eeefbaaae2e93df746820 Mon Sep 17 00:00:00 2001 From: Shuhao Qing Date: Tue, 7 Apr 2026 19:49:00 +0800 Subject: [PATCH 2/5] feat(cli): add composable `analyze run` command with --modules flag New `analyze run --modules fundamental,technical,risk,...` command that composes analysis modules in a single CLI call. Supports: - Module selection via --modules (default: all) - Batch mode with comma-separated symbols - JSON output with --json - Checklist post-processing from fundamental + risk + signals Also adds analysis/registry.py (module registry), section_header() formatter, and compute_checklist_from_parts() in report.py. Co-Authored-By: Claude Opus 4.6 --- haoinvest/analysis/registry.py | 303 +++++++++++++++++++++++++++++++++ haoinvest/analysis/report.py | 82 +++++++++ haoinvest/cli/analyze.py | 174 +++++++++++++++++++ haoinvest/cli/formatters.py | 8 + haoinvest/cli/strategy.py | 1 - 5 files changed, 567 insertions(+), 1 deletion(-) create mode 100644 haoinvest/analysis/registry.py diff --git a/haoinvest/analysis/registry.py b/haoinvest/analysis/registry.py new file mode 100644 index 0000000..9df5030 --- /dev/null +++ b/haoinvest/analysis/registry.py @@ -0,0 +1,303 @@ +"""Analysis module registry — maps module names to runners and formatters. + +Each module defines how to execute an analysis and how to format its output +for the CLI. Used by the composable `analyze run` command. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from datetime import date +from typing import Any, Callable + +from ..db import Database +from ..models import MarketType + + +@dataclass +class AnalysisModule: + """A single analysis module that can be composed into a run command.""" + + name: str + runner: Callable[..., Any] + formatter: Callable[..., tuple[str, Any]] + needs_prices: bool = True + default_lookback_days: int = 365 + # Extra kwargs this module accepts (e.g., top_n for peer) + extra_kwargs: list[str] = field(default_factory=list) + + +# -- Runners: thin wrappers to normalize call signatures -- + + +def _run_fundamental( + db: Database, + symbol: str, + market_type: MarketType, + start: date, + end: date, + **kwargs: Any, +) -> Any: + from .fundamental import analyze_stock + + return analyze_stock(symbol, market_type) + + +def _run_technical( + db: Database, + symbol: str, + market_type: MarketType, + start: date, + end: date, + **kwargs: Any, +) -> Any: + from .technical import analyze_technical_multi + + verbose = kwargs.get("verbose", False) + return analyze_technical_multi(db, symbol, market_type, start, end, verbose=verbose) + + +def _run_risk( + db: Database, + symbol: str, + market_type: MarketType, + start: date, + end: date, + **kwargs: Any, +) -> Any: + from .risk import calculate_risk_metrics + + return calculate_risk_metrics(db, symbol, market_type, start, end) + + +def _run_volume( + db: Database, + symbol: str, + market_type: MarketType, + start: date, + end: date, + **kwargs: Any, +) -> Any: + from .volume import analyze_volume + + verbose = kwargs.get("verbose", False) + return analyze_volume(db, symbol, market_type, start, end, verbose=verbose) + + +def _run_signals( + db: Database, + symbol: str, + market_type: MarketType, + start: date, + end: date, + **kwargs: Any, +) -> Any: + from .signals import aggregate_signals + + verbose = kwargs.get("verbose", False) + return aggregate_signals(db, symbol, market_type, start, end, verbose=verbose) + + +def _run_peer( + db: Database, + symbol: str, + market_type: MarketType, + start: date, + end: date, + **kwargs: Any, +) -> Any: + from .peer import find_peers + + top_n = kwargs.get("top_n", 10) + return find_peers(symbol, market_type, top_n=top_n) + + +# -- Formatters: convert results to (output_type, data) -- +# "kv" -> dict for kv_output +# "tsv" -> (list[dict], columns) for tsv_output +# "technical" -> special multi-timeframe format + + +def _format_fundamental(result: Any, verbose: bool = False) -> tuple[str, Any]: + output: dict[str, Any] = { + "Symbol": result.symbol, + "Name": result.name, + "Price": result.current_price, + "PE(TTM)": result.pe_ratio, + "PB": result.pb_ratio, + "Sector": result.sector, + "Industry": result.industry, + "MarketCap": result.total_market_cap, + "ROE(%)": result.roe, + "ROA(%)": result.roa, + "DebtToEquity": result.debt_to_equity, + "RevenueGrowth": result.revenue_growth, + "ProfitMargin": result.profit_margin, + "GrossMargin": result.gross_margin, + "CurrentRatio": result.current_ratio, + "FreeCashFlow": result.free_cash_flow, + "PEG": result.peg_ratio, + "PE_Assessment": result.valuation.pe_assessment, + "PB_Assessment": result.valuation.pb_assessment, + "Overall_Valuation": result.valuation.overall, + } + if verbose and result.financial_health: + fh = result.financial_health + output.update( + { + "Profitability": fh.profitability, + "Growth": fh.growth, + "Leverage": fh.leverage, + "CashFlow": fh.cash_flow, + "FinancialHealth": fh.overall, + } + ) + return ("kv", output) + + +def _format_technical(result: Any, verbose: bool = False) -> tuple[str, Any]: + return ("technical", result) + + +def _format_risk(result: Any, verbose: bool = False) -> tuple[str, Any]: + return ( + "kv", + { + "Volatility": result.annualized_volatility, + "MaxDrawdown%": result.max_drawdown_pct, + "Sharpe": result.sharpe_ratio, + "Sortino": result.sortino_ratio, + "TotalReturn%": result.total_return_pct, + }, + ) + + +def _format_volume(result: Any, verbose: bool = False) -> tuple[str, Any]: + output: dict[str, Any] = {} + if result.message: + output["message"] = result.message + else: + output.update( + { + "LatestVolume": result.latest_volume, + "AvgVolume20d": result.avg_volume_20d, + "VolumeRatio": result.volume_ratio, + "IsAnomaly": result.is_anomaly, + "Assessment": result.assessment, + } + ) + if verbose and result.explanation: + output["Explanation"] = result.explanation + return ("kv", output) + + +def _format_signals(result: Any, verbose: bool = False) -> tuple[str, Any]: + output: dict[str, Any] = { + "OverallSignal": result.overall_signal, + "Confidence": result.confidence, + "Bullish": result.bullish_count, + "Bearish": result.bearish_count, + "Neutral": result.neutral_count, + } + for i, detail in enumerate(result.details): + output[f"Indicator_{i + 1}"] = detail + if verbose and result.explanation: + output["Explanation"] = result.explanation + return ("kv", output) + + +def _format_peer(result: Any, verbose: bool = False) -> tuple[str, Any]: + if result and "message" in result[0]: + return ("kv", {"message": result[0]["message"]}) + columns = ["Symbol", "Name", "Price", "Change%", "PE", "PB", "MarketCap"] + return ("tsv", (result, columns)) + + +def _format_checklist(result: Any, verbose: bool = False) -> tuple[str, Any]: + output: dict[str, Any] = {} + for item in result.items: + output[item.dimension] = f"{item.score}/5 — {item.assessment}" + output["TotalScore"] = f"{result.total_score}/{result.max_score}" + output["Recommendation"] = result.recommendation + return ("kv", output) + + +# -- Module resolution -- + + +def parse_modules(modules_str: str) -> list[str]: + """Parse --modules flag. 'all' expands to all modules in order.""" + if modules_str.strip().lower() == "all": + return list(MODULES.keys()) + names = [m.strip().lower() for m in modules_str.split(",") if m.strip()] + unknown = [n for n in names if n not in MODULES] + if unknown: + valid = ", ".join(MODULES.keys()) + raise ValueError(f"Unknown module(s): {', '.join(unknown)}. Valid: {valid}") + return names + + +def max_lookback_days(module_names: list[str]) -> int: + """Return the max default_lookback_days across requested modules.""" + return max( + (MODULES[n].default_lookback_days for n in module_names if n in MODULES), + default=365, + ) + + +def any_needs_prices(module_names: list[str]) -> bool: + """Return True if any requested module needs cached price data.""" + return any(MODULES[n].needs_prices for n in module_names if n in MODULES) + + +# -- Registry -- + +MODULES: dict[str, AnalysisModule] = { + "fundamental": AnalysisModule( + name="fundamental", + runner=_run_fundamental, + formatter=_format_fundamental, + needs_prices=False, + default_lookback_days=0, + ), + "technical": AnalysisModule( + name="technical", + runner=_run_technical, + formatter=_format_technical, + needs_prices=True, + default_lookback_days=1095, + ), + "risk": AnalysisModule( + name="risk", + runner=_run_risk, + formatter=_format_risk, + needs_prices=True, + ), + "volume": AnalysisModule( + name="volume", + runner=_run_volume, + formatter=_format_volume, + needs_prices=True, + ), + "signals": AnalysisModule( + name="signals", + runner=_run_signals, + formatter=_format_signals, + needs_prices=True, + ), + "peer": AnalysisModule( + name="peer", + runner=_run_peer, + formatter=_format_peer, + needs_prices=False, + default_lookback_days=0, + extra_kwargs=["top_n"], + ), + "checklist": AnalysisModule( + name="checklist", + runner=lambda *a, **kw: None, # placeholder — computed as post-processing + formatter=_format_checklist, + needs_prices=False, + default_lookback_days=0, + ), +} diff --git a/haoinvest/analysis/report.py b/haoinvest/analysis/report.py index 479da4f..927464d 100644 --- a/haoinvest/analysis/report.py +++ b/haoinvest/analysis/report.py @@ -1,12 +1,17 @@ """Analysis report assembly — combines fundamental and risk data.""" +from __future__ import annotations + from datetime import date from ..db import Database from ..models import ( BuyReadinessChecklist, ChecklistItem, + FundamentalAnalysis, MarketType, + RiskMetrics, + SignalSummary, StockReport, ) from .fundamental import analyze_stock @@ -89,6 +94,83 @@ def full_stock_report( return report +def compute_checklist_from_parts( + fundamental: "FundamentalAnalysis", + risk: "RiskMetrics", + signals: "SignalSummary | None" = None, +) -> BuyReadinessChecklist: + """Compute buy-readiness checklist from separate module results. + + Used by the composable `analyze run` command which doesn't build + a full StockReport. + """ + items: list[ChecklistItem] = [] + + val_score = _score_valuation(fundamental.valuation.overall) + items.append( + ChecklistItem( + dimension="估值", score=val_score, assessment=fundamental.valuation.overall + ) + ) + + prof_score = _score_profitability(fundamental.roe, fundamental.profit_margin) + items.append( + ChecklistItem( + dimension="盈利能力", + score=prof_score, + assessment=fundamental.financial_health.profitability + if fundamental.financial_health + else "N/A", + ) + ) + + growth_score = _score_growth(fundamental.revenue_growth) + items.append( + ChecklistItem( + dimension="成长性", + score=growth_score, + assessment=fundamental.financial_health.growth + if fundamental.financial_health + else "N/A", + ) + ) + + risk_score = _score_risk(risk.max_drawdown_pct, risk.sharpe_ratio) + risk_text = ( + f"最大回撤 {risk.max_drawdown_pct:.1f}%" if risk.max_drawdown_pct else "N/A" + ) + items.append( + ChecklistItem(dimension="风险", score=risk_score, assessment=risk_text) + ) + + if signals: + tech_score = _score_technical(signals.overall_signal, signals.confidence) + tech_text = f"{signals.overall_signal} (置信度: {signals.confidence})" + else: + tech_score = 3 + tech_text = "无技术面数据" + items.append( + ChecklistItem(dimension="技术面", score=tech_score, assessment=tech_text) + ) + + total = sum(item.score for item in items) + max_score = len(items) * 5 + + if total >= max_score * 0.75: + recommendation = "建议关注" + elif total >= max_score * 0.5: + recommendation = "谨慎观望" + else: + recommendation = "建议回避" + + return BuyReadinessChecklist( + items=items, + total_score=total, + max_score=max_score, + recommendation=recommendation, + ) + + def _compute_checklist(report: StockReport) -> BuyReadinessChecklist: """Score each dimension 1-5 and produce a recommendation.""" items: list[ChecklistItem] = [] diff --git a/haoinvest/cli/analyze.py b/haoinvest/cli/analyze.py index dbf6add..330add8 100644 --- a/haoinvest/cli/analyze.py +++ b/haoinvest/cli/analyze.py @@ -7,6 +7,12 @@ from ..analysis.cache import ensure_prices_cached from ..analysis.fundamental import analyze_stock +from ..analysis.registry import ( + MODULES, + any_needs_prices, + max_lookback_days, + parse_modules, +) from ..analysis.risk import calculate_risk_metrics, portfolio_correlation from ..analysis.signals import aggregate_signals from ..analysis.technical import analyze_technical, analyze_technical_multi @@ -17,6 +23,7 @@ error_output, json_output, kv_output, + section_header, timeframe_section, tsv_output, ) @@ -591,3 +598,170 @@ def report( print(f" {item.dimension}: {item.score}/5 — {item.assessment}") print(f" 总分: {r.checklist.total_score}/{r.checklist.max_score}") print(f" 建议: {r.checklist.recommendation}") + + +@app.command() +def run( + symbol: str = typer.Argument(help="Symbol(s), comma-separated for batch"), + modules: str = typer.Option( + "all", + "--modules", + help="Comma-separated: fundamental,technical,risk,volume,signals,peer,checklist", + ), + market_type: Optional[str] = typer.Option( + None, "--market-type", "-m", help="Override: a_share, crypto, us" + ), + start: Optional[str] = typer.Option(None, "--start", help="Start date YYYY-MM-DD"), + end: Optional[str] = typer.Option(None, "--end", help="End date YYYY-MM-DD"), + top_n: int = typer.Option(10, "--top-n", help="Number of peers (peer module)"), + verbose: bool = typer.Option( + False, "--verbose", "-v", help="Add explanations for learning" + ), + use_json: bool = typer.Option(False, "--json", help="Output as JSON"), +) -> None: + """Composable analysis — choose which modules to run in a single call. + + Examples: + analyze run 600519 # all modules + analyze run 600519 --modules fundamental,risk # selective + analyze run 600519,000858 --modules fundamental # batch + """ + try: + module_names = parse_modules(modules) + except ValueError as e: + error_output(str(e)) + raise typer.Exit(1) + + symbol_list = [s.strip() for s in symbol.split(",")] + end_date = date.fromisoformat(end) if end else date.today() + lookback = max_lookback_days(module_names) + start_date = ( + date.fromisoformat(start) if start else end_date - timedelta(days=lookback) + ) + + db = init_db() + needs_prices = any_needs_prices(module_names) + is_batch = len(symbol_list) > 1 + + # JSON accumulator for batch mode + json_results: dict = {} + + for sym in symbol_list: + mt = MarketType(market_type) if market_type else _detect_market_type(sym) + + if needs_prices: + ensure_prices_cached(db, sym, mt, start_date, end_date) + + results: dict = {} + for name in module_names: + if name == "checklist": + continue # post-processing, handled below + mod = MODULES[name] + try: + results[name] = mod.runner( + db, + sym, + mt, + start_date, + end_date, + verbose=verbose, + top_n=top_n, + ) + except (ValueError, RuntimeError) as e: + results[name] = {"error": str(e)} + + # Checklist post-processing + if "checklist" in module_names: + fund = results.get("fundamental") + risk_r = results.get("risk") + sig = results.get("signals") + if ( + fund + and risk_r + and not isinstance(fund, dict) + and not isinstance(risk_r, dict) + ): + from ..analysis.report import compute_checklist_from_parts + + results["checklist"] = compute_checklist_from_parts( + fund, risk_r, sig if sig and not isinstance(sig, dict) else None + ) + else: + results["checklist"] = { + "error": "checklist requires fundamental + risk modules" + } + + if use_json: + # Build JSON-serializable dict + sym_data: dict = {} + for name in module_names: + r = results.get(name) + if r is None: + continue + if isinstance(r, dict): + sym_data[name] = r + elif isinstance(r, list): + sym_data[name] = r + else: + sym_data[name] = r.model_dump() if hasattr(r, "model_dump") else r + if is_batch: + json_results[sym] = sym_data + else: + json_results = sym_data + else: + # Text output with section headers + for name in module_names: + r = results.get(name) + if r is None: + continue + + section_header(name, sym if is_batch else None) + + if isinstance(r, dict): + # Error or simple dict + kv_output(r) + continue + + mod = MODULES[name] + fmt_type, fmt_data = mod.formatter(r, verbose) + + if fmt_type == "kv": + kv_output(fmt_data) + elif fmt_type == "tsv": + rows, columns = fmt_data + tsv_output(rows, columns=columns) + elif fmt_type == "technical": + # Multi-timeframe technical output + multi = fmt_data + if multi.monthly: + timeframe_section("月线 (Monthly)", multi.monthly, verbose) + if multi.weekly: + timeframe_section("周线 (Weekly)", multi.weekly, verbose) + if multi.daily: + daily = multi.daily + if daily.message: + print(f" {daily.message}") + else: + daily_kv: dict = { + "Close": daily.latest_close, + "Trend": daily.moving_averages.trend, + "MACD_Signal": daily.macd.signal, + "RSI": daily.rsi.rsi, + "RSI_Zone": daily.rsi.assessment, + "BB_Position": daily.bollinger.position, + } + if verbose: + if daily.moving_averages.explanation: + daily_kv["MA_Explain"] = ( + daily.moving_averages.explanation + ) + if daily.macd.explanation: + daily_kv["MACD_Explain"] = daily.macd.explanation + if daily.rsi.explanation: + daily_kv["RSI_Explain"] = daily.rsi.explanation + if daily.bollinger.explanation: + daily_kv["BB_Explain"] = daily.bollinger.explanation + kv_output(daily_kv) + + if use_json: + json_output(json_results) diff --git a/haoinvest/cli/formatters.py b/haoinvest/cli/formatters.py index 07e87e1..d22da62 100644 --- a/haoinvest/cli/formatters.py +++ b/haoinvest/cli/formatters.py @@ -84,6 +84,14 @@ def timeframe_section( print(f" RSI说明: {result.rsi.explanation}") +def section_header(name: str, symbol: str | None = None) -> None: + """Print a section header like === risk === or === risk: 600519 ===.""" + if symbol: + print(f"\n=== {name}: {symbol} ===") + else: + print(f"\n=== {name} ===") + + def error_output(message: str) -> None: """Print error message to stderr.""" print(f"Error: {message}", file=sys.stderr) diff --git a/haoinvest/cli/strategy.py b/haoinvest/cli/strategy.py index cafa2f8..3403e24 100644 --- a/haoinvest/cli/strategy.py +++ b/haoinvest/cli/strategy.py @@ -8,7 +8,6 @@ from ..analysis.cache import ensure_prices_cached from ..market import get_provider -from ..models import MarketType from ..strategy.optimizer import suggest_allocation from ..strategy.rebalance import calculate_rebalance from ._shared import init_db From 755ad6342e9b2d3a50cb9451b42b82d51f67639c Mon Sep 17 00:00:00 2001 From: Shuhao Qing Date: Tue, 7 Apr 2026 19:51:35 +0800 Subject: [PATCH 3/5] test: add tests for cache, registry, and analyze run command - test_analysis_cache: gap-fill, skip-when-cached, fetch-when-empty - test_analysis_registry: parse_modules, max_lookback, needs_prices, module count - test_cli/test_analyze_run: module selection, JSON output, batch mode, checklist with/without deps, technical module Co-Authored-By: Claude Opus 4.6 --- tests/test_analysis_cache.py | 67 +++++++ tests/test_analysis_registry.py | 88 +++++++++ tests/test_cli/test_analyze_run.py | 305 +++++++++++++++++++++++++++++ 3 files changed, 460 insertions(+) create mode 100644 tests/test_analysis_cache.py create mode 100644 tests/test_analysis_registry.py create mode 100644 tests/test_cli/test_analyze_run.py diff --git a/tests/test_analysis_cache.py b/tests/test_analysis_cache.py new file mode 100644 index 0000000..1b9c743 --- /dev/null +++ b/tests/test_analysis_cache.py @@ -0,0 +1,67 @@ +"""Tests for haoinvest.analysis.cache — price caching logic.""" + +from datetime import date, timedelta +from unittest.mock import MagicMock, patch + +from haoinvest.analysis.cache import ensure_prices_cached +from haoinvest.models import MarketType, PriceBar + + +def _make_bar(d: date) -> PriceBar: + return PriceBar( + symbol="600519", + market_type=MarketType.A_SHARE, + trade_date=d, + open=100.0, + high=105.0, + low=95.0, + close=102.0, + volume=10000.0, + ) + + +class TestEnsurePricesCached: + def test_fetches_when_no_data(self, db): + """Should fetch full range when no cached data exists.""" + start = date(2024, 1, 1) + end = date(2024, 6, 1) + mock_bars = [_make_bar(start + timedelta(days=i)) for i in range(5)] + mock_provider = MagicMock() + mock_provider.get_price_history.return_value = mock_bars + + with patch("haoinvest.analysis.cache.get_provider", return_value=mock_provider): + ensure_prices_cached(db, "600519", MarketType.A_SHARE, start, end) + + mock_provider.get_price_history.assert_called_once_with("600519", start, end) + + def test_skips_when_sufficient_data(self, db): + """Should not fetch when >10 bars exist and cover start date.""" + start = date(2024, 1, 1) + end = date(2024, 6, 1) + # Pre-populate with 15 bars starting from start date + bars = [_make_bar(start + timedelta(days=i)) for i in range(15)] + db.save_prices(bars) + + mock_provider = MagicMock() + with patch("haoinvest.analysis.cache.get_provider", return_value=mock_provider): + ensure_prices_cached(db, "600519", MarketType.A_SHARE, start, end) + + mock_provider.get_price_history.assert_not_called() + + def test_gap_fills_earlier_data(self, db): + """Should fetch missing earlier portion when cached data starts too late.""" + start = date(2024, 1, 1) + end = date(2024, 6, 1) + # Cached data starts 30 days after requested start + cached_start = start + timedelta(days=30) + bars = [_make_bar(cached_start + timedelta(days=i)) for i in range(15)] + db.save_prices(bars) + + mock_provider = MagicMock() + mock_provider.get_price_history.return_value = [] + with patch("haoinvest.analysis.cache.get_provider", return_value=mock_provider): + ensure_prices_cached(db, "600519", MarketType.A_SHARE, start, end) + + mock_provider.get_price_history.assert_called_once_with( + "600519", start, cached_start + ) diff --git a/tests/test_analysis_registry.py b/tests/test_analysis_registry.py new file mode 100644 index 0000000..5e12159 --- /dev/null +++ b/tests/test_analysis_registry.py @@ -0,0 +1,88 @@ +"""Tests for haoinvest.analysis.registry — module registry logic.""" + +import pytest + +from haoinvest.analysis.registry import ( + MODULES, + any_needs_prices, + max_lookback_days, + parse_modules, +) + + +class TestParseModules: + def test_all_expands_to_all_modules(self): + result = parse_modules("all") + assert result == list(MODULES.keys()) + + def test_all_case_insensitive(self): + assert parse_modules("ALL") == list(MODULES.keys()) + assert parse_modules(" All ") == list(MODULES.keys()) + + def test_single_module(self): + assert parse_modules("fundamental") == ["fundamental"] + + def test_multiple_modules(self): + assert parse_modules("fundamental,risk,peer") == [ + "fundamental", + "risk", + "peer", + ] + + def test_strips_whitespace(self): + assert parse_modules(" fundamental , risk ") == ["fundamental", "risk"] + + def test_unknown_module_raises(self): + with pytest.raises(ValueError, match="Unknown module"): + parse_modules("fundamental,nonexistent") + + def test_empty_after_strip_ignored(self): + assert parse_modules("fundamental,,risk") == ["fundamental", "risk"] + + +class TestMaxLookbackDays: + def test_technical_dominates(self): + assert max_lookback_days(["fundamental", "technical", "risk"]) == 1095 + + def test_risk_only(self): + assert max_lookback_days(["risk"]) == 365 + + def test_fundamental_only(self): + assert max_lookback_days(["fundamental"]) == 0 + + +class TestAnyNeedsPrices: + def test_fundamental_no_prices(self): + assert any_needs_prices(["fundamental"]) is False + + def test_risk_needs_prices(self): + assert any_needs_prices(["risk"]) is True + + def test_mixed(self): + assert any_needs_prices(["fundamental", "peer"]) is False + assert any_needs_prices(["fundamental", "risk"]) is True + + +class TestModuleRegistry: + def test_all_modules_have_required_fields(self): + for name, mod in MODULES.items(): + assert mod.name == name + assert callable(mod.runner) + assert callable(mod.formatter) + assert isinstance(mod.needs_prices, bool) + assert isinstance(mod.default_lookback_days, int) + + def test_module_count(self): + assert len(MODULES) == 7 + + def test_module_names(self): + expected = { + "fundamental", + "technical", + "risk", + "volume", + "signals", + "peer", + "checklist", + } + assert set(MODULES.keys()) == expected diff --git a/tests/test_cli/test_analyze_run.py b/tests/test_cli/test_analyze_run.py new file mode 100644 index 0000000..7f5b2bc --- /dev/null +++ b/tests/test_cli/test_analyze_run.py @@ -0,0 +1,305 @@ +"""Tests for the composable `analyze run` CLI command.""" + +import json +from unittest.mock import patch + +from typer.testing import CliRunner + +from haoinvest.cli import app +from haoinvest.models import ( + BollingerBands, + FundamentalAnalysis, + MACDResult, + MovingAverages, + MultiTimeframeTechnical, + RiskMetrics, + RSIResult, + SignalSummary, + TechnicalIndicators, + ValuationAssessment, + VolumeAnalysis, +) + +runner = CliRunner() + +# Patch targets: the underlying analysis functions that registry runners call +_P_FUND = "haoinvest.analysis.fundamental.analyze_stock" +_P_TECH = "haoinvest.analysis.technical.analyze_technical_multi" +_P_RISK = "haoinvest.analysis.risk.calculate_risk_metrics" +_P_VOL = "haoinvest.analysis.volume.analyze_volume" +_P_SIG = "haoinvest.analysis.signals.aggregate_signals" +_P_PEER = "haoinvest.analysis.peer.find_peers" +_P_CACHE = "haoinvest.cli.analyze.ensure_prices_cached" + + +def _mock_fundamental(): + return FundamentalAnalysis( + symbol="600519", + name="贵州茅台", + sector="白酒", + market_type="a_share", + current_price=1800.0, + currency="CNY", + pe_ratio=35.2, + pb_ratio=12.1, + roe=25.0, + revenue_growth=15.0, + profit_margin=40.0, + valuation=ValuationAssessment( + pe_assessment="偏高", + pb_assessment="高估", + overall="偏高估", + ), + ) + + +def _mock_risk(): + return RiskMetrics( + annualized_volatility=25.5, + max_drawdown_pct=-15.3, + sharpe_ratio=0.85, + total_return_pct=12.0, + num_days=252, + ) + + +def _mock_signals(): + return SignalSummary( + symbol="600519", + market_type="a_share", + overall_signal="偏多", + confidence="中", + bullish_count=3, + bearish_count=1, + neutral_count=0, + details=["MA: 偏多", "MACD: 偏多"], + ) + + +def _mock_technical(): + daily = TechnicalIndicators( + symbol="600519", + market_type="a_share", + timeframe="daily", + latest_close=1800.0, + moving_averages=MovingAverages( + sma_5=1790.0, sma_10=1780.0, sma_20=1770.0, trend="上升趋势" + ), + macd=MACDResult(macd_line=5.0, signal_line=3.0, histogram=2.0, signal="金叉"), + rsi=RSIResult(rsi=55.0, assessment="中性"), + bollinger=BollingerBands(position="中轨附近"), + ) + return MultiTimeframeTechnical( + symbol="600519", market_type="a_share", daily=daily + ) + + +def _mock_volume(): + return VolumeAnalysis( + symbol="600519", + market_type="a_share", + latest_volume=50000.0, + avg_volume_20d=40000.0, + volume_ratio=1.25, + is_anomaly=False, + assessment="正常", + ) + + +def _mock_peer(): + return [ + {"Symbol": "600519", "Name": "贵州茅台", "Price": 1800, "PE": 35.2}, + {"Symbol": "000858", "Name": "五粮液", "Price": 160, "PE": 22.3}, + ] + + +class TestAnalyzeRunModuleSelection: + def test_invalid_module_name(self, tmp_path, monkeypatch): + monkeypatch.setenv("HAOINVEST_DATA_DIR", str(tmp_path)) + result = runner.invoke(app, ["analyze", "run", "600519", "--modules", "fake"]) + assert result.exit_code == 1 + + def test_single_module_fundamental(self, tmp_path, monkeypatch): + monkeypatch.setenv("HAOINVEST_DATA_DIR", str(tmp_path)) + with patch(_P_FUND, return_value=_mock_fundamental()): + result = runner.invoke( + app, ["analyze", "run", "600519", "--modules", "fundamental"] + ) + assert result.exit_code == 0 + assert "=== fundamental ===" in result.output + assert "贵州茅台" in result.output + assert "=== risk ===" not in result.output + + def test_two_modules(self, tmp_path, monkeypatch): + monkeypatch.setenv("HAOINVEST_DATA_DIR", str(tmp_path)) + with ( + patch(_P_FUND, return_value=_mock_fundamental()), + patch(_P_RISK, return_value=_mock_risk()), + patch(_P_CACHE), + ): + result = runner.invoke( + app, ["analyze", "run", "600519", "--modules", "fundamental,risk"] + ) + assert result.exit_code == 0 + assert "=== fundamental ===" in result.output + assert "=== risk ===" in result.output + assert "25.5" in result.output + + def test_peer_module(self, tmp_path, monkeypatch): + monkeypatch.setenv("HAOINVEST_DATA_DIR", str(tmp_path)) + with patch(_P_PEER, return_value=_mock_peer()): + result = runner.invoke( + app, ["analyze", "run", "600519", "--modules", "peer"] + ) + assert result.exit_code == 0 + assert "=== peer ===" in result.output + assert "五粮液" in result.output + + +class TestAnalyzeRunJsonOutput: + def test_json_single_symbol(self, tmp_path, monkeypatch): + monkeypatch.setenv("HAOINVEST_DATA_DIR", str(tmp_path)) + with ( + patch(_P_FUND, return_value=_mock_fundamental()), + patch(_P_RISK, return_value=_mock_risk()), + patch(_P_CACHE), + ): + result = runner.invoke( + app, + [ + "analyze", + "run", + "600519", + "--modules", + "fundamental,risk", + "--json", + ], + ) + assert result.exit_code == 0 + data = json.loads(result.output) + assert "fundamental" in data + assert "risk" in data + assert data["fundamental"]["symbol"] == "600519" + + def test_json_batch(self, tmp_path, monkeypatch): + monkeypatch.setenv("HAOINVEST_DATA_DIR", str(tmp_path)) + mock_a = _mock_fundamental() + mock_b = FundamentalAnalysis( + symbol="000858", + name="五粮液", + sector="白酒", + market_type="a_share", + current_price=160.0, + currency="CNY", + pe_ratio=22.3, + pb_ratio=6.8, + valuation=ValuationAssessment( + pe_assessment="合理", + pb_assessment="偏高", + overall="估值合理", + ), + ) + calls = iter([mock_a, mock_b]) + with patch(_P_FUND, side_effect=lambda *a, **k: next(calls)): + result = runner.invoke( + app, + [ + "analyze", + "run", + "600519,000858", + "--modules", + "fundamental", + "--json", + ], + ) + assert result.exit_code == 0 + data = json.loads(result.output) + assert "600519" in data + assert "000858" in data + + +class TestAnalyzeRunBatch: + def test_batch_fundamental_text(self, tmp_path, monkeypatch): + monkeypatch.setenv("HAOINVEST_DATA_DIR", str(tmp_path)) + mock_a = _mock_fundamental() + mock_b = FundamentalAnalysis( + symbol="000858", + name="五粮液", + sector="白酒", + market_type="a_share", + current_price=160.0, + currency="CNY", + pe_ratio=22.3, + pb_ratio=6.8, + valuation=ValuationAssessment( + pe_assessment="合理", + pb_assessment="偏高", + overall="估值合理", + ), + ) + calls = iter([mock_a, mock_b]) + with patch(_P_FUND, side_effect=lambda *a, **k: next(calls)): + result = runner.invoke( + app, + ["analyze", "run", "600519,000858", "--modules", "fundamental"], + ) + assert result.exit_code == 0 + assert "=== fundamental: 600519 ===" in result.output + assert "=== fundamental: 000858 ===" in result.output + assert "贵州茅台" in result.output + assert "五粮液" in result.output + + +class TestAnalyzeRunChecklist: + def test_checklist_with_dependencies(self, tmp_path, monkeypatch): + monkeypatch.setenv("HAOINVEST_DATA_DIR", str(tmp_path)) + with ( + patch(_P_FUND, return_value=_mock_fundamental()), + patch(_P_RISK, return_value=_mock_risk()), + patch(_P_SIG, return_value=_mock_signals()), + patch(_P_CACHE), + ): + result = runner.invoke( + app, + [ + "analyze", + "run", + "600519", + "--modules", + "fundamental,risk,signals,checklist", + ], + ) + assert result.exit_code == 0 + assert "=== checklist ===" in result.output + assert "Recommendation" in result.output + + def test_checklist_without_deps_shows_error(self, tmp_path, monkeypatch): + monkeypatch.setenv("HAOINVEST_DATA_DIR", str(tmp_path)) + with patch(_P_FUND, return_value=_mock_fundamental()): + result = runner.invoke( + app, + [ + "analyze", + "run", + "600519", + "--modules", + "fundamental,checklist", + ], + ) + assert result.exit_code == 0 + assert "checklist requires fundamental + risk" in result.output + + +class TestAnalyzeRunTechnical: + def test_technical_module(self, tmp_path, monkeypatch): + monkeypatch.setenv("HAOINVEST_DATA_DIR", str(tmp_path)) + with ( + patch(_P_TECH, return_value=_mock_technical()), + patch(_P_CACHE), + ): + result = runner.invoke( + app, ["analyze", "run", "600519", "--modules", "technical"] + ) + assert result.exit_code == 0 + assert "=== technical ===" in result.output + assert "上升趋势" in result.output From b49f01388b7181ae5e9163190c57f40a39b1539f Mon Sep 17 00:00:00 2001 From: Shuhao Qing Date: Tue, 7 Apr 2026 20:26:11 +0800 Subject: [PATCH 4/5] docs(skill): update haoinvest skill to use analyze run command Update workflows 1, 3, 4 to use composable `analyze run` instead of multiple individual commands. Add `analyze run` to command reference. Co-Authored-By: Claude Opus 4.6 --- .claude/skills/haoinvest/SKILL.md | 31 +++++++++++++++++-------------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/.claude/skills/haoinvest/SKILL.md b/.claude/skills/haoinvest/SKILL.md index 559aa89..4b52eef 100644 --- a/.claude/skills/haoinvest/SKILL.md +++ b/.claude/skills/haoinvest/SKILL.md @@ -12,15 +12,15 @@ All-in-one investment management via CLI + Claude Code agent. CLI does data + co ### Workflow 1: "帮我分析 XXX" — Analyze a stock -1. Run comprehensive report: +1. Run composable analysis (single call, all modules including peer): ```bash - uv run haoinvest analyze report + uv run haoinvest analyze run ``` -2. For A-shares, also run peer comparison: + Or select specific modules: ```bash - uv run haoinvest analyze peer + uv run haoinvest analyze run --modules fundamental,risk,peer ``` -3. Interpret ALL sections in Chinese: +2. Interpret ALL sections in Chinese: - 估值: Is it cheap or expensive? Compare PE/PB to peers. - 财务健康: Is the company profitable and growing? - 风险: How volatile is it? What's the worst drawdown? @@ -51,9 +51,9 @@ All-in-one investment management via CLI + Claude Code agent. CLI does data + co **触发条件**: 用户表达买入/卖出意图时。 -1. Run comprehensive report + guardrails pre-trade data (2 calls): +1. Run composable analysis + guardrails pre-trade data (2 calls): ```bash - uv run haoinvest analyze report --json + uv run haoinvest analyze run --json uv run haoinvest guardrails pre-trade-data -m --json ``` If user hasn't specified quantity, ask first. If price not known, the command auto-fetches. @@ -129,15 +129,11 @@ All-in-one investment management via CLI + Claude Code agent. CLI does data + co ### Workflow 4: "对比 A 和 B" — Compare stocks -1. Batch fundamental comparison: +1. Batch composable analysis (single call): ```bash - uv run haoinvest analyze fundamental , --verbose + uv run haoinvest analyze run , --modules fundamental,risk,signals ``` -2. Batch technical comparison: - ```bash - uv run haoinvest analyze technical , - ``` -3. Summarize: who's better on what dimension, and overall recommendation +2. Summarize: who's better on what dimension, and overall recommendation ### Workflow 5: "定期体检" — Portfolio checkup @@ -179,6 +175,13 @@ uv run haoinvest market sector # Sector constitu ### Analysis ```bash +# Composable analysis (preferred — single call, choose modules) +uv run haoinvest analyze run # All modules (fundamental,technical,risk,volume,signals,peer,checklist) +uv run haoinvest analyze run --modules fundamental,risk,peer # Selective modules +uv run haoinvest analyze run --json # JSON output for structured parsing +uv run haoinvest analyze run , --modules fundamental # Batch comparison + +# Individual commands (still available) uv run haoinvest analyze report # Full report + buy-readiness checklist uv run haoinvest analyze fundamental [--verbose] # Valuation + financial health (batch OK) uv run haoinvest analyze technical # MA/MACD/RSI/BB (batch OK) From d3a79565c79413ae3ce2bdca678b0fca3334854a Mon Sep 17 00:00:00 2001 From: Shuhao Qing Date: Tue, 7 Apr 2026 20:27:20 +0800 Subject: [PATCH 5/5] docs: update CLAUDE.md and README.md for composable analyze run - CLAUDE.md: add cache.py, registry.py, _shared.py, guardrails.py to structure - README.md: add composable analysis feature, analyze run examples Co-Authored-By: Claude Opus 4.6 --- CLAUDE.md | 8 ++++++-- README.md | 8 +++++++- tests/test_cli/test_analyze_run.py | 4 +--- 3 files changed, 14 insertions(+), 6 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index 2d3448b..be86ca0 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -35,8 +35,10 @@ haoinvest/ ├── journal.py # Investment journal with emotion/decision tagging ├── cli/ # Typer CLI — entry point: `uv run haoinvest` │ ├── __init__.py # App + subcommand registration -│ ├── formatters.py # Output formatting (text/JSON) -│ ├── analyze.py # analyze subcommand +│ ├── _shared.py # Shared CLI utilities (init_db, fetch_current_prices) +│ ├── formatters.py # Output formatting (text/JSON/section headers) +│ ├── analyze.py # analyze subcommand (includes composable `run` command) +│ ├── guardrails.py # guardrails subcommand │ ├── journal.py # journal subcommand │ ├── market.py # market subcommand │ ├── portfolio.py # portfolio subcommand @@ -48,6 +50,8 @@ haoinvest/ │ └── optimization_engine.py # Portfolio optimization (HRP, min vol, max Sharpe) ├── portfolio/ # Trade recording, position tracking, returns (TWR) ├── analysis/ # Thin adapters over engine +│ ├── cache.py # Price data caching (ensure_prices_cached) +│ ├── registry.py # Module registry for composable `analyze run` │ ├── fundamental.py # Valuation assessment (PE/PB/ROE), financial health │ ├── technical.py # Technical indicator adapter │ ├── risk.py # Risk metrics adapter diff --git a/README.md b/README.md index 611809a..df4e703 100644 --- a/README.md +++ b/README.md @@ -13,6 +13,7 @@ Built for a beginner investor in China covering A-shares, US stocks, HK stocks, - **Fundamental Analysis** — PE/PB/ROE valuation assessment with financial health scoring; batch support for multi-symbol comparison - **Peer Comparison** — Find and compare same-sector stocks by valuation and performance - **Sector Browsing** — Browse A-share industry sectors and their constituent stocks +- **Composable Analysis** — `analyze run` command with `--modules` flag to compose any combination of fundamental, technical, risk, volume, signals, peer, and checklist in a single call - **Comprehensive Report** — Full stock report with buy-readiness checklist combining fundamental, technical, and risk analysis - **Risk Metrics** — Annualized volatility, max drawdown, Sharpe ratio, Sortino ratio (powered by QuantStats) - **Technical Analysis** — MA, MACD, RSI, Bollinger Bands with Chinese explanations (powered by pandas-ta) @@ -47,7 +48,12 @@ uv run haoinvest portfolio list # View holdings uv run haoinvest portfolio add-trade 600519 buy 100 1800.50 uv run haoinvest portfolio returns # P&L summary -# Analysis +# Composable analysis (preferred — single call, choose modules) +uv run haoinvest analyze run 600519 # All modules +uv run haoinvest analyze run 600519 --modules fundamental,risk,peer # Selective +uv run haoinvest analyze run 600519,000858 --modules fundamental # Batch + +# Individual analysis commands uv run haoinvest analyze fundamental 600519 # PE/PB valuation uv run haoinvest analyze fundamental 600519,000858 # Batch comparison uv run haoinvest analyze risk --symbol NVDA # Volatility, Sharpe, drawdown diff --git a/tests/test_cli/test_analyze_run.py b/tests/test_cli/test_analyze_run.py index 7f5b2bc..5c8050c 100644 --- a/tests/test_cli/test_analyze_run.py +++ b/tests/test_cli/test_analyze_run.py @@ -89,9 +89,7 @@ def _mock_technical(): rsi=RSIResult(rsi=55.0, assessment="中性"), bollinger=BollingerBands(position="中轨附近"), ) - return MultiTimeframeTechnical( - symbol="600519", market_type="a_share", daily=daily - ) + return MultiTimeframeTechnical(symbol="600519", market_type="a_share", daily=daily) def _mock_volume():