diff --git a/.env.template b/.env.template index c940f010..7ca02b5a 100644 --- a/.env.template +++ b/.env.template @@ -120,6 +120,17 @@ # fallback (default): use configured models only when upstream /models fails, is nil, or is empty. # allowlist: expose only the configured models for providers that define a list, and skip their upstream /models calls. # CONFIGURED_PROVIDER_MODELS_MODE=fallback + +# --- Intelligent routing --- +# When a request does not pin a provider and multiple providers serve the same +# model, the gateway scores candidates and routes to the best one. Single-provider +# configs are unaffected (one candidate => no decision needed). +# Strategy: balanced (default), cost_only, latency_only, first_fit +# MODEL_ROUTING_STRATEGY=balanced +# Balanced weights as "cost_weight,latency_weight" (only affects balanced). +# MODEL_ROUTING_STRATEGY_WEIGHTS=0.6,0.4 +# Filter candidates at/above this smoothed error rate [0,1] (default: 0.5). +# MODEL_ROUTING_MAX_ERROR_RATE=0.5 # Examples: OPENROUTER_MODELS=..., OPENROUTER_EU_MODELS=..., AZURE_MODELS=..., VLLM_MODELS=... # Fallback & Workflow Configuration diff --git a/.gitignore b/.gitignore index 4fbc41a2..1625518e 100644 --- a/.gitignore +++ b/.gitignore @@ -37,3 +37,4 @@ # Superpower design docs and plans (never commit) /docs/superpowers/ +.superpowers/ diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 00000000..f6906f2e --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,10 @@ +# 默认忽略的文件 +/shelf/ +/workspace.xml +# 基于编辑器的 HTTP 客户端请求 +/httpRequests/ +# 已忽略包含查询文件的默认文件夹 +/queries/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/.idea/GoModel.iml b/.idea/GoModel.iml new file mode 100644 index 00000000..5e764c4f --- /dev/null +++ b/.idea/GoModel.iml @@ -0,0 +1,9 @@ + + + + + + + + + \ No newline at end of file diff --git a/.idea/go.imports.xml b/.idea/go.imports.xml new file mode 100644 index 00000000..644cdf0b --- /dev/null +++ b/.idea/go.imports.xml @@ -0,0 +1,10 @@ + + + + + + \ No newline at end of file diff --git a/.idea/golinter.xml b/.idea/golinter.xml new file mode 100644 index 00000000..1ccf3ec6 --- /dev/null +++ b/.idea/golinter.xml @@ -0,0 +1,7 @@ + + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 00000000..b6bc1834 --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 00000000..35eb1ddf --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/CLAUDE.md b/CLAUDE.md index 1bbb778b..feb32791 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -131,3 +131,4 @@ Full reference: `.env.template` and `config/config.yaml` - **Guardrails:** Configured via `config/config.yaml` only (except `GUARDRAILS_ENABLED` env var) - **Providers:** `OPENAI_API_KEY`, `ANTHROPIC_API_KEY`, `GEMINI_API_KEY`, `USE_GOOGLE_GEMINI_NATIVE_API` (true by default; false uses Gemini's OpenAI-compatible chat API), `XAI_API_KEY`, `GROQ_API_KEY`, `OPENROUTER_API_KEY`, `ZAI_API_KEY`, `ZAI_BASE_URL` (optional Z.ai endpoint override), `MINIMAX_API_KEY`, `MINIMAX_BASE_URL` (optional MiniMax endpoint override), `XIAOMI_API_KEY`, `XIAOMI_BASE_URL` (optional Xiaomi MiMo endpoint override), `OPENCODE_GO_API_KEY`, `OPENCODE_GO_BASE_URL` (optional OpenCode Go/Zen endpoint override; default `https://opencode.ai/zen/go/v1`), `OPENCODE_GO_MESSAGES_MODELS` (optional comma-separated model IDs routed to the Anthropic-native `/messages` endpoint instead of `/chat/completions`; default `qwen3.7-max`), `BAILIAN_API_KEY`, `BAILIAN_BASE_URL` (optional Bailian base URL for region switching; default `https://dashscope.aliyuncs.com/compatible-mode/v1`), `AZURE_API_KEY`, `AZURE_BASE_URL` (Azure OpenAI deployment base URL), `AZURE_API_VERSION` (optional Azure API version), `ORACLE_API_KEY` (Oracle API key), `ORACLE_BASE_URL` (Oracle OpenAI-compatible base URL), `[_SUFFIX]_MODELS` (comma-separated configured model list for any provider type), `OLLAMA_BASE_URL`, `VLLM_BASE_URL`, `VLLM_API_KEY` (optional upstream vLLM bearer token) - **Provider model metadata:** `providers..models` accepts either model IDs (strings) or `{id, metadata}` objects. When `metadata` is supplied (`display_name`, `context_window`, `max_output_tokens`, `modes`, `capabilities`, `pricing`, …) it is merged onto the remote ai-model-list entry during enrichment, with operator values winning per-field. Primary use case: advertising context windows, capabilities, and pricing for local models (Ollama) and other custom endpoints whose IDs are not in the upstream registry. +- **Intelligent routing:** When a request does not pin a provider and multiple configured providers serve the same model ID, the gateway scores candidates and routes to the best one (single-provider configs are unaffected — one candidate means no decision). `MODEL_ROUTING_STRATEGY` (`balanced` default; also `cost_only`, `latency_only`, `first_fit`), `MODEL_ROUTING_STRATEGY_WEIGHTS` (`0.6,0.4` = cost_weight,latency_weight; balanced only), `MODEL_ROUTING_MAX_ERROR_RATE` (`0.5`; candidates at/above this smoothed error rate are filtered before scoring). Latency is tracked per provider as in-memory EWMA (alpha 0.1 p50 / 0.05 p99 / 0.2 error-rate), never persisted. Per-request overrides via request headers `X-GoModel-Routing-Strategy` and `X-GoModel-Routing-Weights` (`cost,latency`); invalid header values are silently dropped (WARN log) and fall back to the global strategy. Explicit provider-qualified requests (`openai-east/gpt-4o`) always bypass intelligent routing. diff --git a/c.out b/c.out new file mode 100644 index 00000000..2be45a15 --- /dev/null +++ b/c.out @@ -0,0 +1,35 @@ +mode: set +gomodel/internal/providers/bailian/bailian.go:39.86,48.2 2 1 +gomodel/internal/providers/bailian/bailian.go:52.97,60.2 1 1 +gomodel/internal/providers/bailian/bailian.go:62.51,64.2 1 1 +gomodel/internal/providers/bailian/bailian.go:67.43,69.2 1 1 +gomodel/internal/providers/bailian/bailian.go:73.107,75.2 1 1 +gomodel/internal/providers/bailian/bailian.go:78.108,80.2 1 1 +gomodel/internal/providers/bailian/bailian.go:83.82,85.2 1 1 +gomodel/internal/providers/bailian/bailian.go:88.112,90.2 1 1 +gomodel/internal/providers/bailian/bailian.go:93.108,95.2 1 1 +gomodel/internal/providers/bailian/bailian.go:101.113,103.2 1 1 +gomodel/internal/providers/bailian/bailian.go:106.118,108.2 1 1 +gomodel/internal/providers/bailian/bailian.go:111.106,113.2 1 1 +gomodel/internal/providers/bailian/bailian.go:116.90,118.2 1 1 +gomodel/internal/providers/bailian/bailian.go:121.111,123.2 1 1 +gomodel/internal/providers/bailian/bailian.go:126.93,128.2 1 1 +gomodel/internal/providers/bailian/bailian.go:131.104,133.2 1 1 +gomodel/internal/providers/bailian/bailian.go:136.107,138.16 2 1 +gomodel/internal/providers/bailian/bailian.go:138.16,140.3 1 0 +gomodel/internal/providers/bailian/bailian.go:141.2,142.18 2 1 +gomodel/internal/providers/bailian/bailian.go:146.124,148.16 2 1 +gomodel/internal/providers/bailian/bailian.go:148.16,150.3 1 0 +gomodel/internal/providers/bailian/bailian.go:151.2,151.27 1 1 +gomodel/internal/providers/bailian/bailian.go:151.27,153.3 1 0 +gomodel/internal/providers/bailian/bailian.go:154.2,154.18 1 1 +gomodel/internal/providers/bailian/bailian.go:158.86,160.16 2 1 +gomodel/internal/providers/bailian/bailian.go:160.16,162.3 1 0 +gomodel/internal/providers/bailian/bailian.go:163.2,164.18 2 1 +gomodel/internal/providers/bailian/bailian.go:168.97,170.2 1 1 +gomodel/internal/providers/bailian/bailian.go:173.102,175.2 1 1 +gomodel/internal/providers/bailian/bailian.go:189.67,190.40 1 1 +gomodel/internal/providers/bailian/bailian.go:190.40,192.3 1 1 +gomodel/internal/providers/bailian/bailian.go:194.2,194.82 1 1 +gomodel/internal/providers/bailian/bailian.go:194.82,198.3 3 1 +gomodel/internal/providers/bailian/bailian.go:199.2,206.16 6 1 diff --git a/config/config.example.yaml b/config/config.example.yaml index ad5b619b..305b7243 100644 --- a/config/config.example.yaml +++ b/config/config.example.yaml @@ -356,3 +356,15 @@ providers: # input_per_mtok: 0 # output_per_mtok: 0 # - Gemma4-31B + +# --- Intelligent routing --- +# When a request does not pin a provider and multiple providers serve the same +# model, the gateway scores candidates and routes to the best one. Single-provider +# configs are unaffected (one candidate means no decision needed). +# Default: strategy=balanced, weights={cost:0.6 latency:0.4}, max_error_rate=0.5 +# router: +# strategy: balanced # balanced, cost_only, latency_only, first_fit +# weights: +# cost: 0.6 +# latency: 0.4 +# max_error_rate: 0.5 diff --git a/config/config.go b/config/config.go index 59bc5b3a..09fdb1a3 100644 --- a/config/config.go +++ b/config/config.go @@ -27,6 +27,7 @@ type Config struct { Fallback FallbackConfig `yaml:"fallback"` Workflows WorkflowsConfig `yaml:"workflows"` Resilience ResilienceConfig `yaml:"resilience"` + Router RouterConfig `yaml:"router"` } // LoadResult is returned by Load and bundles the application config with the raw @@ -120,6 +121,7 @@ func buildDefaultConfig() *Config { Retry: DefaultRetryConfig(), CircuitBreaker: DefaultCircuitBreakerConfig(), }, + Router: DefaultRouterConfig(), Admin: AdminConfig{ EndpointsEnabled: true, UIEnabled: true, @@ -194,6 +196,10 @@ func Load() (*LoadResult, error) { return nil, err } + if err := ValidateRouterConfig(&cfg.Router); err != nil { + return nil, err + } + return &LoadResult{ Config: cfg, RawProviders: rawProviders, diff --git a/config/router.go b/config/router.go new file mode 100644 index 00000000..4dfdc189 --- /dev/null +++ b/config/router.go @@ -0,0 +1,127 @@ +package config + +import ( + "fmt" + "math" + "strconv" + "strings" +) + +// RouterConfig configures intelligent provider selection. When a request does +// not pin a provider and multiple providers serve the same model, the gateway +// scores candidates by cost and/or latency using the configured strategy. +type RouterConfig struct { + // Strategy is the default strategy id: "balanced", "cost_only", + // "latency_only", or "first_fit". Empty defaults to "balanced". + Strategy string `yaml:"strategy" json:"strategy" env:"MODEL_ROUTING_STRATEGY"` + + // Weights tunes the balanced strategy. CostWeight and LatencyWeight are + // ignored by other strategies. + Weights RouterWeights `yaml:"weights" json:"weights"` + + // WeightsCSV is the env-only form of Weights as "cost,latency" + // (e.g. "0.6,0.4"). Parsed into Weights during validation. + WeightsCSV string `yaml:"-" json:"-" env:"MODEL_ROUTING_STRATEGY_WEIGHTS"` + + // MaxErrorRate filters candidates at/above this smoothed error ratio before + // scoring. Zero falls back to the strategy default (0.5). + MaxErrorRate float64 `yaml:"max_error_rate" json:"max_error_rate" env:"MODEL_ROUTING_MAX_ERROR_RATE"` +} + +// RouterWeights tunes the balanced strategy's cost/latency trade-off. +type RouterWeights struct { + Cost float64 `yaml:"cost" json:"cost" env:"MODEL_ROUTING_COST_WEIGHT"` + Latency float64 `yaml:"latency" json:"latency" env:"MODEL_ROUTING_LATENCY_WEIGHT"` +} + +// RouterStrategyBalanced and the other built-in strategy ids. +const ( + RouterStrategyBalanced = "balanced" + RouterStrategyCostOnly = "cost_only" + RouterStrategyLatencyOnly = "latency_only" + RouterStrategyFirstFit = "first_fit" +) + +// DefaultRouterConfig returns the default router configuration: balanced strategy +// with 0.6 cost / 0.4 latency weights and a 0.5 max error rate. +func DefaultRouterConfig() RouterConfig { + return RouterConfig{ + Strategy: RouterStrategyBalanced, + Weights: RouterWeights{ + Cost: 0.6, + Latency: 0.4, + }, + MaxErrorRate: 0.5, + } +} + +// ValidateRouterConfig normalizes and validates the router config, applying +// defaults for empty/invalid fields. +func ValidateRouterConfig(cfg *RouterConfig) error { + if cfg.Strategy == "" { + cfg.Strategy = RouterStrategyBalanced + } + strategy := strings.ToLower(strings.TrimSpace(cfg.Strategy)) + switch strategy { + case RouterStrategyBalanced, RouterStrategyCostOnly, RouterStrategyLatencyOnly, RouterStrategyFirstFit: + default: + return fmt.Errorf("invalid router.strategy %q: must be one of balanced, cost_only, latency_only, first_fit", cfg.Strategy) + } + cfg.Strategy = strategy + + // MODEL_ROUTING_STRATEGY_WEIGHTS overrides YAML weights when set. + if strings.TrimSpace(cfg.WeightsCSV) != "" { + cost, lat, err := parseWeightsCSV(cfg.WeightsCSV) + if err != nil { + return fmt.Errorf("invalid MODEL_ROUTING_STRATEGY_WEIGHTS: %w", err) + } + cfg.Weights = RouterWeights{Cost: cost, Latency: lat} + } + + if math.IsNaN(cfg.Weights.Cost) || math.IsNaN(cfg.Weights.Latency) || + math.IsInf(cfg.Weights.Cost, 0) || math.IsInf(cfg.Weights.Latency, 0) { + return fmt.Errorf("router.weights must be finite numbers, got cost=%v latency=%v", cfg.Weights.Cost, cfg.Weights.Latency) + } + if cfg.Weights.Cost < 0 || cfg.Weights.Latency < 0 { + return fmt.Errorf("router.weights must be non-negative, got cost=%v latency=%v", cfg.Weights.Cost, cfg.Weights.Latency) + } + if cfg.Weights.Cost == 0 && cfg.Weights.Latency == 0 && strategy == RouterStrategyBalanced { + cfg.Weights = RouterWeights{Cost: 0.6, Latency: 0.4} + } + if math.IsNaN(cfg.MaxErrorRate) || math.IsInf(cfg.MaxErrorRate, 0) { + return fmt.Errorf("router.max_error_rate must be a finite number, got %v", cfg.MaxErrorRate) + } + if cfg.MaxErrorRate < 0 || cfg.MaxErrorRate > 1 { + return fmt.Errorf("router.max_error_rate must be in [0, 1], got %v", cfg.MaxErrorRate) + } + return nil +} + +// parseWeightsCSV parses a "cost,latency" string into two non-negative floats. +func parseWeightsCSV(s string) (cost, latency float64, err error) { + parts := strings.Split(s, ",") + if len(parts) != 2 { + return 0, 0, fmt.Errorf("expected two comma-separated weights, got %q", s) + } + if cost, err = parseFloatField(parts[0], "cost"); err != nil { + return 0, 0, err + } + if latency, err = parseFloatField(parts[1], "latency"); err != nil { + return 0, 0, err + } + if cost < 0 || latency < 0 { + return 0, 0, fmt.Errorf("weights must be non-negative, got %v,%v", cost, latency) + } + return cost, latency, nil +} + +func parseFloatField(s, name string) (float64, error) { + v, err := strconv.ParseFloat(strings.TrimSpace(s), 64) + if err != nil { + return 0, fmt.Errorf("invalid %s weight %q: %w", name, s, err) + } + if math.IsNaN(v) || math.IsInf(v, 0) { + return 0, fmt.Errorf("invalid %s weight %q: must be a finite number", name, s) + } + return v, nil +} diff --git a/coverage-bailian b/coverage-bailian new file mode 100644 index 00000000..484e7fa7 --- /dev/null +++ b/coverage-bailian @@ -0,0 +1,35 @@ +mode: set +gomodel/internal/providers/bailian/bailian.go:39.86,48.2 2 1 +gomodel/internal/providers/bailian/bailian.go:52.97,60.2 1 1 +gomodel/internal/providers/bailian/bailian.go:62.51,64.2 1 1 +gomodel/internal/providers/bailian/bailian.go:67.43,69.2 1 1 +gomodel/internal/providers/bailian/bailian.go:73.107,75.2 1 1 +gomodel/internal/providers/bailian/bailian.go:78.108,80.2 1 1 +gomodel/internal/providers/bailian/bailian.go:83.82,85.2 1 1 +gomodel/internal/providers/bailian/bailian.go:88.112,90.2 1 1 +gomodel/internal/providers/bailian/bailian.go:93.108,95.2 1 1 +gomodel/internal/providers/bailian/bailian.go:101.113,103.2 1 1 +gomodel/internal/providers/bailian/bailian.go:106.118,108.2 1 1 +gomodel/internal/providers/bailian/bailian.go:111.106,113.2 1 0 +gomodel/internal/providers/bailian/bailian.go:116.90,118.2 1 0 +gomodel/internal/providers/bailian/bailian.go:121.111,123.2 1 0 +gomodel/internal/providers/bailian/bailian.go:126.93,128.2 1 0 +gomodel/internal/providers/bailian/bailian.go:131.104,133.2 1 0 +gomodel/internal/providers/bailian/bailian.go:136.107,138.16 2 0 +gomodel/internal/providers/bailian/bailian.go:138.16,140.3 1 0 +gomodel/internal/providers/bailian/bailian.go:141.2,142.18 2 0 +gomodel/internal/providers/bailian/bailian.go:146.124,148.16 2 0 +gomodel/internal/providers/bailian/bailian.go:148.16,150.3 1 0 +gomodel/internal/providers/bailian/bailian.go:151.2,151.27 1 0 +gomodel/internal/providers/bailian/bailian.go:151.27,153.3 1 0 +gomodel/internal/providers/bailian/bailian.go:154.2,154.18 1 0 +gomodel/internal/providers/bailian/bailian.go:158.86,160.16 2 0 +gomodel/internal/providers/bailian/bailian.go:160.16,162.3 1 0 +gomodel/internal/providers/bailian/bailian.go:163.2,164.18 2 0 +gomodel/internal/providers/bailian/bailian.go:168.97,170.2 1 0 +gomodel/internal/providers/bailian/bailian.go:173.102,175.2 1 0 +gomodel/internal/providers/bailian/bailian.go:189.67,190.40 1 1 +gomodel/internal/providers/bailian/bailian.go:190.40,192.3 1 1 +gomodel/internal/providers/bailian/bailian.go:194.2,194.82 1 1 +gomodel/internal/providers/bailian/bailian.go:194.82,198.3 3 1 +gomodel/internal/providers/bailian/bailian.go:199.2,206.16 6 1 diff --git a/coverage-bailian.out b/coverage-bailian.out new file mode 100644 index 00000000..484e7fa7 --- /dev/null +++ b/coverage-bailian.out @@ -0,0 +1,35 @@ +mode: set +gomodel/internal/providers/bailian/bailian.go:39.86,48.2 2 1 +gomodel/internal/providers/bailian/bailian.go:52.97,60.2 1 1 +gomodel/internal/providers/bailian/bailian.go:62.51,64.2 1 1 +gomodel/internal/providers/bailian/bailian.go:67.43,69.2 1 1 +gomodel/internal/providers/bailian/bailian.go:73.107,75.2 1 1 +gomodel/internal/providers/bailian/bailian.go:78.108,80.2 1 1 +gomodel/internal/providers/bailian/bailian.go:83.82,85.2 1 1 +gomodel/internal/providers/bailian/bailian.go:88.112,90.2 1 1 +gomodel/internal/providers/bailian/bailian.go:93.108,95.2 1 1 +gomodel/internal/providers/bailian/bailian.go:101.113,103.2 1 1 +gomodel/internal/providers/bailian/bailian.go:106.118,108.2 1 1 +gomodel/internal/providers/bailian/bailian.go:111.106,113.2 1 0 +gomodel/internal/providers/bailian/bailian.go:116.90,118.2 1 0 +gomodel/internal/providers/bailian/bailian.go:121.111,123.2 1 0 +gomodel/internal/providers/bailian/bailian.go:126.93,128.2 1 0 +gomodel/internal/providers/bailian/bailian.go:131.104,133.2 1 0 +gomodel/internal/providers/bailian/bailian.go:136.107,138.16 2 0 +gomodel/internal/providers/bailian/bailian.go:138.16,140.3 1 0 +gomodel/internal/providers/bailian/bailian.go:141.2,142.18 2 0 +gomodel/internal/providers/bailian/bailian.go:146.124,148.16 2 0 +gomodel/internal/providers/bailian/bailian.go:148.16,150.3 1 0 +gomodel/internal/providers/bailian/bailian.go:151.2,151.27 1 0 +gomodel/internal/providers/bailian/bailian.go:151.27,153.3 1 0 +gomodel/internal/providers/bailian/bailian.go:154.2,154.18 1 0 +gomodel/internal/providers/bailian/bailian.go:158.86,160.16 2 0 +gomodel/internal/providers/bailian/bailian.go:160.16,162.3 1 0 +gomodel/internal/providers/bailian/bailian.go:163.2,164.18 2 0 +gomodel/internal/providers/bailian/bailian.go:168.97,170.2 1 0 +gomodel/internal/providers/bailian/bailian.go:173.102,175.2 1 0 +gomodel/internal/providers/bailian/bailian.go:189.67,190.40 1 1 +gomodel/internal/providers/bailian/bailian.go:190.40,192.3 1 1 +gomodel/internal/providers/bailian/bailian.go:194.2,194.82 1 1 +gomodel/internal/providers/bailian/bailian.go:194.82,198.3 3 1 +gomodel/internal/providers/bailian/bailian.go:199.2,206.16 6 1 diff --git a/gomodel.exe b/gomodel.exe new file mode 100644 index 00000000..c4524e69 Binary files /dev/null and b/gomodel.exe differ diff --git a/gomodel.exe~ b/gomodel.exe~ new file mode 100644 index 00000000..cc19afec Binary files /dev/null and b/gomodel.exe~ differ diff --git a/internal/core/interfaces.go b/internal/core/interfaces.go index a9ae55d5..2f030b91 100644 --- a/internal/core/interfaces.go +++ b/internal/core/interfaces.go @@ -4,6 +4,7 @@ package core import ( "context" "io" + "time" ) // Provider defines the interface for LLM providers @@ -184,6 +185,30 @@ type AvailabilityChecker interface { CheckAvailability(ctx context.Context) error } +// ProviderStats is an optional interface implemented by provider clients that +// expose in-memory runtime statistics for latency-aware routing. Routers obtain +// it via a type assertion; providers without instrumentation simply omit it. +// +// Latency values are best-effort EWMA approximations (not true percentiles) and +// are intended only for relative comparison across providers. All accessors +// return the zero value (0 / "") before any request has been observed, which +// strategies interpret as "unknown". +type ProviderStats interface { + // P50Latency returns the approximate P50 round-trip latency, or 0 if unknown. + P50Latency() time.Duration + + // P99Latency returns the approximate P99 round-trip latency, or 0 if unknown. + P99Latency() time.Duration + + // ErrorRate returns the smoothed error ratio in [0, 1], or 0 if unknown. + ErrorRate() float64 + + // CircuitState returns the provider's circuit breaker state as a string: + // "closed", "open", or "half-open". Implementations without a circuit + // breaker return "closed". + CircuitState() string +} + // ModelLookup defines the interface for looking up models and their providers. // This abstraction allows the Router to be decoupled from the concrete ModelRegistry implementation. type ModelLookup interface { diff --git a/internal/gateway/inference_prepare.go b/internal/gateway/inference_prepare.go index 6246201c..855a84de 100644 --- a/internal/gateway/inference_prepare.go +++ b/internal/gateway/inference_prepare.go @@ -5,6 +5,7 @@ import ( "strings" "gomodel/internal/core" + router "gomodel/internal/router" ) // PrepareChatRequest resolves workflow/model policy and applies translated request patching. @@ -93,6 +94,18 @@ func prepareTranslatedRequest[Req any]( patchNilMessage string, ) (context.Context, Req, *core.Workflow, error) { ctx = contextWithRequestID(ctx, meta.RequestID) + + // When the caller did not pin a provider (neither via req.Provider nor via + // a "provider/model" model string), mark this request as eligible for + // intelligent routing. The prepare step resolves a concrete provider, but + // the execution path (resolveProvider) uses this flag to know that the + // provider was system-assigned and can be re-evaluated by strategy. + if provider != nil && strings.TrimSpace(*provider) == "" { + if !hasProviderInModel(model) { + ctx = router.WithRoutingEligible(ctx) + } + } + workflow, err := o.ensureTranslatedRequestWorkflow(ctx, meta.Workflow, meta.RequestID, meta.Endpoint, model, provider) if err != nil { var zero Req @@ -215,6 +228,20 @@ func currentTranslatedWorkflow(workflow *core.Workflow, endpoint core.EndpointDe return workflow } +// hasProviderInModel reports whether the model string contains a provider +// prefix (e.g. "deepseek/deepseek-v4-flash"), which means the user explicitly +// chose a provider via the model syntax even though req.Provider is empty. +func hasProviderInModel(model *string) bool { + if model == nil { + return false + } + m := strings.TrimSpace(*model) + firstSlash := strings.IndexByte(m, '/') + // A provider prefix is a non-empty segment before a slash that is NOT itself + // a plain model-ID containing a slash (e.g. "meta-llama/Llama-3-70b"). + return firstSlash > 0 && firstSlash < len(m)-1 +} + // ApplyResolvedSelector updates request model/provider fields to the resolved selector. func ApplyResolvedSelector(model, providerHint *string, resolution *core.RequestModelResolution) { if model == nil || providerHint == nil || resolution == nil { diff --git a/internal/llmclient/client.go b/internal/llmclient/client.go index 01394f9a..6021d7c5 100644 --- a/internal/llmclient/client.go +++ b/internal/llmclient/client.go @@ -93,14 +93,18 @@ type Client struct { config Config headerSetter HeaderSetter circuitBreaker *circuitBreaker + latencyTracker *LatencyTracker } -// New creates a new LLM client with the given configuration -func New(cfg Config, headerSetter HeaderSetter) *Client { +// newLLMClient is the shared constructor core used by New and NewWithHTTPClient. +// Every client gets a LatencyTracker so latency-aware routing always has a +// stats source to assert, even before the first request. +func newLLMClient(httpClient *http.Client, cfg Config, headerSetter HeaderSetter) *Client { c := &Client{ - httpClient: httpclient.NewDefaultHTTPClient(), - config: cfg, - headerSetter: headerSetter, + httpClient: httpClient, + config: cfg, + headerSetter: headerSetter, + latencyTracker: NewLatencyTracker(), } if cfg.CircuitBreaker.FailureThreshold > 0 { @@ -114,23 +118,14 @@ func New(cfg Config, headerSetter HeaderSetter) *Client { return c } +// New creates a new LLM client with the given configuration +func New(cfg Config, headerSetter HeaderSetter) *Client { + return newLLMClient(httpclient.NewDefaultHTTPClient(), cfg, headerSetter) +} + // NewWithHTTPClient creates a new LLM client with a custom HTTP client func NewWithHTTPClient(httpClient *http.Client, cfg Config, headerSetter HeaderSetter) *Client { - c := &Client{ - httpClient: httpClient, - config: cfg, - headerSetter: headerSetter, - } - - if cfg.CircuitBreaker.FailureThreshold > 0 { - c.circuitBreaker = newCircuitBreaker( - cfg.CircuitBreaker.FailureThreshold, - cfg.CircuitBreaker.SuccessThreshold, - cfg.CircuitBreaker.Timeout, - ) - } - - return c + return newLLMClient(httpClient, cfg, headerSetter) } // SetBaseURL updates the base URL (thread-safe) @@ -235,9 +230,37 @@ func (c *Client) finishRequest(scope requestScope, statusCode int, err error) { // Use this whenever a code path returns from one of the public Do* methods. func (c *Client) completeScope(scope requestScope, statusCode int, err, cbErr error) { c.recordCircuitBreakerCompletion(statusCode, cbErr) + c.latencyTracker.Record(time.Since(scope.startedAt), err != nil) c.finishRequest(scope, statusCode, err) } +// P50Latency returns the approximate P50 round-trip latency for this client's +// provider, satisfying core.ProviderStats. Returns 0 before any request. +func (c *Client) P50Latency() time.Duration { + return c.latencyTracker.P50() +} + +// P99Latency returns the approximate P99 round-trip latency for this client's +// provider, satisfying core.ProviderStats. Returns 0 before any request. +func (c *Client) P99Latency() time.Duration { + return c.latencyTracker.P99() +} + +// ErrorRate returns the smoothed error ratio in [0, 1], satisfying +// core.ProviderStats. Returns 0 before any request. +func (c *Client) ErrorRate() float64 { + return c.latencyTracker.ErrorRate() +} + +// CircuitState returns the current circuit-breaker state, satisfying +// core.ProviderStats. Returns "closed" when no breaker is configured. +func (c *Client) CircuitState() string { + if c.circuitBreaker == nil { + return "closed" + } + return c.circuitBreaker.State() +} + // failAfterRetries handles the "exhausted retries with no captured error" // fallback shared by the retrying entry points (DoRaw, DoPassthrough). The // returned error is also reported through the scope. diff --git a/internal/llmclient/latency.go b/internal/llmclient/latency.go new file mode 100644 index 00000000..45dd1e2b --- /dev/null +++ b/internal/llmclient/latency.go @@ -0,0 +1,118 @@ +package llmclient + +import ( + "sync" + "time" +) + +// defaultLatencyAlpha is the EWMA smoothing factor used for p50 latency. +// A smaller alpha makes the average more responsive to recent changes. +const defaultLatencyAlpha = 0.1 + +// ewma is a concurrency-safe exponentially weighted moving average. +// The first sample initializes the value directly (tracked by the +// initialized bool so that a genuine zero-valued sample is not +// mistaken for the uninitialized state); subsequent samples are +// smoothed using value = alpha*sample + (1-alpha)*value. +type ewma struct { + mu sync.Mutex + value float64 + alpha float64 + initialized bool +} + +// newEWMA returns an EWMA with the given smoothing factor alpha. +// alpha must be in the range (0, 1]; values closer to 1 weight recent +// samples more heavily. +func newEWMA(alpha float64) *ewma { + return &ewma{alpha: alpha} +} + +// Add incorporates a new sample into the moving average. The first sample +// seeds the value directly so the average is not biased toward zero. +func (e *ewma) Add(sample float64) { + e.mu.Lock() + defer e.mu.Unlock() + if !e.initialized { + e.value = sample + e.initialized = true + return + } + e.value = e.alpha*sample + (1-e.alpha)*e.value +} + +// Value returns the current EWMA value. Returns zero before any sample +// has been added. +func (e *ewma) Value() float64 { + e.mu.Lock() + defer e.mu.Unlock() + return e.value +} + +// LatencyTracker maintains EWMA-based latency and error-rate statistics +// for a provider or model. It is intended to feed latency-aware routing +// decisions with cheap, in-memory, concurrency-safe signals. +// +// Statistics are kept in memory only: there is no persistence and no +// periodic cleanup. Callers should call Record once per completed request. +// +// Smoothing factors: +// - p50 uses alpha=0.1 for a responsive mid-latency signal +// - p99 uses alpha=0.05 for a smoother tail-latency signal +// - errorRate uses alpha=0.2 to react quickly to error bursts +// +// Note: the EWMA-based p50/p99 are approximate percentiles rather than +// true percentile estimators; they are sufficient for relative comparison +// across providers and not intended for SLA reporting. +type LatencyTracker struct { + p50 *ewma + p99 *ewma + errorRate *ewma +} + +// NewLatencyTracker returns a LatencyTracker with default smoothing factors +// tuned for latency-aware routing. +func NewLatencyTracker() *LatencyTracker { + return &LatencyTracker{ + p50: newEWMA(defaultLatencyAlpha), + p99: newEWMA(defaultLatencyAlpha / 2), + errorRate: newEWMA(0.2), + } +} + +// Record observes the duration and error status of a completed request. +// duration is the total request latency; isError indicates whether the +// request was considered failed for routing purposes. +func (t *LatencyTracker) Record(duration time.Duration, isError bool) { + ms := float64(duration.Milliseconds()) + t.p50.Add(ms) + t.p99.Add(ms) + t.errorRate.Add(boolToFloat64(isError)) +} + +// P50 returns the approximate p50 latency as a time.Duration. +// Returns zero before any sample has been recorded. +func (t *LatencyTracker) P50() time.Duration { + return time.Duration(t.p50.Value()) * time.Millisecond +} + +// P99 returns the approximate p99 latency as a time.Duration. +// Returns zero before any sample has been recorded. +func (t *LatencyTracker) P99() time.Duration { + return time.Duration(t.p99.Value()) * time.Millisecond +} + +// ErrorRate returns the EWMA-smoothed error rate in the range [0, 1]. +// Returns zero before any sample has been recorded. +func (t *LatencyTracker) ErrorRate() float64 { + return t.errorRate.Value() +} + +// boolToFloat64 converts a boolean to 1.0 (true) or 0.0 (false) for use +// as an EWMA sample. +func boolToFloat64(b bool) float64 { + if b { + return 1 + } + return 0 +} diff --git a/internal/llmclient/latency_test.go b/internal/llmclient/latency_test.go new file mode 100644 index 00000000..1bc0a499 --- /dev/null +++ b/internal/llmclient/latency_test.go @@ -0,0 +1,65 @@ +package llmclient + +import ( + "testing" + "time" +) + +func TestLatencyTracker_FirstSampleSeedsValue(t *testing.T) { + tr := NewLatencyTracker() + tr.Record(100*time.Millisecond, false) + if got := tr.P50(); got != 100*time.Millisecond { + t.Fatalf("first sample should seed value, got %v", got) + } +} + +func TestLatencyTracker_ConvergesTowardRecentSamples(t *testing.T) { + tr := NewLatencyTracker() + tr.Record(1000*time.Millisecond, false) // seed + // Feed many short samples; EWMA should move well below the seed. + for i := 0; i < 50; i++ { + tr.Record(10*time.Millisecond, false) + } + if got := tr.P50(); got > 200*time.Millisecond { + t.Fatalf("expected p50 to converge toward 10ms, got %v", got) + } +} + +func TestLatencyTracker_ErrorRateReflectsFailures(t *testing.T) { + tr := NewLatencyTracker() + for i := 0; i < 5; i++ { + tr.Record(10*time.Millisecond, i%2 == 0) // 3 errors of 5 + } + rate := tr.ErrorRate() + if rate < 0.3 || rate > 0.9 { + t.Fatalf("expected error rate near 0.4-0.6, got %v", rate) + } +} + +func TestLatencyTracker_P50BeforeSamples(t *testing.T) { + tr := NewLatencyTracker() + if got := tr.P50(); got != 0 { + t.Fatalf("expected zero p50 before samples, got %v", got) + } + if got := tr.ErrorRate(); got != 0 { + t.Fatalf("expected zero error rate before samples, got %v", got) + } +} + +func TestLatencyTracker_ConcurrentSafe(t *testing.T) { + tr := NewLatencyTracker() + done := make(chan struct{}) + for g := 0; g < 10; g++ { + go func() { + for i := 0; i < 100; i++ { + tr.Record(time.Millisecond, false) + _ = tr.P50() + _ = tr.ErrorRate() + } + done <- struct{}{} + }() + } + for g := 0; g < 10; g++ { + <-done + } +} diff --git a/internal/providers/init.go b/internal/providers/init.go index 2f9fe573..f4987014 100644 --- a/internal/providers/init.go +++ b/internal/providers/init.go @@ -147,7 +147,9 @@ func Init(ctx context.Context, result *config.LoadResult, factory *ProviderFacto } stopRefresh := registry.StartBackgroundRefresh(refreshInterval, modelListURL) - router, err := NewRouter(registry) + strategyRegistry := buildStrategyRegistry(result.Config.Router) + + router, err := NewRouter(registry, WithStrategyRegistry(strategyRegistry)) if err != nil { stopRefresh() modelCache.Close() diff --git a/internal/providers/registry.go b/internal/providers/registry.go index abb60c38..9307e9b0 100644 --- a/internal/providers/registry.go +++ b/internal/providers/registry.go @@ -66,6 +66,10 @@ type ModelRegistry struct { // invalidateSortedCaches whenever the catalog changes. Protected by mu. qualifiedByName map[string]core.ModelSelector qualifiedByType map[string]core.ModelSelector + + // providerRegistrationOrder maintains provider-name insertion order so + // ListProvidersForModel returns entries in registration (first-wins) order. + providerRegistrationOrder []string } type metadataEnrichmentStats struct { @@ -92,6 +96,7 @@ func NewModelRegistry() *ModelRegistry { providerRuntime: make(map[string]providerRuntimeState), refreshCh: make(chan struct{}, 1), configuredProviderModelsMode: config.ConfiguredProviderModelsModeFallback, + providerRegistrationOrder: make([]string, 0), } } @@ -279,6 +284,11 @@ func (r *ModelRegistry) RegisterProviderWithNameAndType(provider core.Provider, r.providerTypes[provider] = providerType r.providerNames[provider] = providerName + // Track registration order for first-wins routing and ListProvidersForModel. + if !slices.Contains(r.providerRegistrationOrder, providerName) { + r.providerRegistrationOrder = append(r.providerRegistrationOrder, providerName) + } + state := r.providerRuntime[providerName] state.registered = true r.providerRuntime[providerName] = state @@ -701,6 +711,39 @@ type ModelWithProvider struct { Selector string `json:"selector"` } +// ListProvidersForModel returns every provider that serves modelID, in +// registration (first-wins) order. Each entry includes the model metadata and +// provider info so callers can inspect pricing, latency, and other signals. +// Returns an empty slice when the model is not found. +func (r *ModelRegistry) ListProvidersForModel(modelID string) []ModelWithProvider { + r.mu.RLock() + defer r.mu.RUnlock() + + modelID = strings.TrimSpace(modelID) + if modelID == "" || len(r.providerRegistrationOrder) == 0 { + return []ModelWithProvider{} + } + + result := make([]ModelWithProvider, 0, len(r.providers)) + for _, pName := range r.providerRegistrationOrder { + pModels, ok := r.modelsByProvider[pName] + if !ok { + continue + } + info, exists := providerModelInfo(pModels, modelID, modelID) + if !exists { + continue + } + result = append(result, ModelWithProvider{ + Model: info.Model, + ProviderType: info.ProviderType, + ProviderName: pName, + Selector: qualifyPublicModelID(pName, modelID), + }) + } + return result +} + // ListModelsWithProvider returns all provider-backed models with provider metadata, // sorted by public selector. // The sorted slice is cached and rebuilt only when the underlying models change. diff --git a/internal/providers/router.go b/internal/providers/router.go index 84eaacb9..5109b237 100644 --- a/internal/providers/router.go +++ b/internal/providers/router.go @@ -12,16 +12,30 @@ import ( "strings" "gomodel/internal/core" + router "gomodel/internal/router" ) // ErrRegistryNotInitialized is returned when the router is used before the registry has any models. var ErrRegistryNotInitialized = fmt.Errorf("model registry has no models: ensure Initialize() or LoadFromCache() is called before using the router") +// RouterOption configures a Router at construction. +type RouterOption func(*Router) + +// WithStrategyRegistry enables intelligent provider selection: when a request +// does not pin a provider and multiple providers serve the same model, the +// registry's configured strategy (or a per-request override on the context) +// selects the best candidate by cost and/or latency. A nil registry keeps the +// historical first-wins behaviour. +func WithStrategyRegistry(registry *router.StrategyRegistry) RouterOption { + return func(r *Router) { r.strategyRegistry = registry } +} + // Router routes requests to the appropriate provider based on the model lookup. // It uses a dynamic model-to-provider mapping that is populated at startup // by fetching available models from each provider's /models endpoint. type Router struct { - lookup core.ModelLookup + lookup core.ModelLookup + strategyRegistry *router.StrategyRegistry } type providerTypeRegistry interface { @@ -52,6 +66,13 @@ type modelWithProviderLister interface { ListModelsWithProvider() []ModelWithProvider } +// providerModelForModelLister is an optional O(k) interface for enumerating +// providers that serve a specific model ID. Implementations with a registration +// order can return candidates in first-wins order. +type providerModelForModelLister interface { + ListProvidersForModel(modelID string) []ModelWithProvider +} + // qualifiedSelectorResolver is an optional fast path for qualified selector // resolution. Implementations resolve a "/" pair via an O(1) // index instead of scanning the catalog. A false result means the caller should @@ -70,14 +91,18 @@ func registryUnavailableError(err error) error { // NewRouter creates a new provider router with a model lookup. // The lookup must be initialized (via Initialize() or LoadFromCache()) before using the router. -// Returns an error if the lookup is nil. -func NewRouter(lookup core.ModelLookup) (*Router, error) { +// Returns an error if the lookup is nil. Options may enable intelligent routing. +func NewRouter(lookup core.ModelLookup, opts ...RouterOption) (*Router, error) { if lookup == nil { return nil, fmt.Errorf("lookup cannot be nil") } - return &Router{ - lookup: lookup, - }, nil + r := &Router{lookup: lookup} + for _, opt := range opts { + if opt != nil { + opt(r) + } + } + return r, nil } // checkReady verifies the lookup has models available. @@ -281,9 +306,141 @@ func (r *Router) resolveProvider(ctx context.Context, model, providerHint string if p == nil { return nil, core.ModelSelector{}, core.NewNotFoundError("model not found: " + lookupModel) } + + // Intelligent routing: when the caller did not pin a provider (or the provider + // was assigned by the system during the prepare phase for an unqualified request) + // and more than one provider serves this model, let the configured strategy + // choose the best candidate. This is a no-op when no strategy registry is + // configured (historical behaviour). + eligible := (strings.TrimSpace(providerHint) == "" || router.IsRoutingEligible(ctx)) && !strings.Contains(model, "/") + if eligible && r.strategyRegistry != nil { + if chosen, newSelector, ok := r.applyStrategy(ctx, selector); ok { + selector = newSelector + p = chosen + } + } + return p, selector, nil } +// applyStrategy runs the configured routing strategy over all providers that +// serve the resolved model and returns the chosen provider and selector. It +// returns ok=false when intelligent routing does not apply (single candidate, +// no strategy, all filtered, or the chosen candidate equals the current one). +func (r *Router) applyStrategy(ctx context.Context, selector core.ModelSelector) (core.Provider, core.ModelSelector, bool) { + candidates := r.collectCandidates(selector.Model) + if len(candidates) < 2 { + return nil, selector, false + } + + strategy, valid := r.strategyRegistry.Resolve(ctx) + if !valid { + slog.WarnContext(ctx, "routing strategy override rejected, falling back to default strategy", + "model", selector.Model, "default", r.strategyRegistry.DefaultID()) + strategy = r.strategyRegistry.New(r.strategyRegistry.DefaultID()) + if strategy == nil { + return nil, selector, false + } + } + strategy = r.applyWeightsOverride(ctx, strategy) + + chosen, err := strategy.Select(ctx, candidates) + if err != nil { + slog.WarnContext(ctx, "routing strategy selected no candidate, falling back to first candidate", + "model", selector.Model, "strategy", strategy.Name(), "error", err) + return nil, selector, false + } + + newSelector := core.ModelSelector{Provider: chosen.ProviderName, Model: chosen.ModelID} + if newSelector.QualifiedModel() == selector.QualifiedModel() { + return nil, selector, false + } + chosenProvider := r.lookup.GetProvider(newSelector.QualifiedModel()) + if chosenProvider == nil { + return nil, selector, false + } + slog.DebugContext(ctx, "intelligent routing selected provider", + "model", selector.Model, "provider", chosen.ProviderName, "strategy", strategy.Name()) + return chosenProvider, newSelector, true +} + +// applyWeightsOverride returns a balanced strategy with per-request cost/latency +// weights applied when the context carries a weights override. Non-balanced +// strategies pass through unchanged. +func (r *Router) applyWeightsOverride(ctx context.Context, strategy router.RoutingStrategy) router.RoutingStrategy { + override, ok := router.WeightsOverrideFromContext(ctx) + if !ok { + return strategy + } + balanced, ok := strategy.(*router.BalancedStrategy) + if !ok { + return strategy + } + clone := *balanced + clone.CostWeight = override.Cost + clone.LatencyWeight = override.Latency + return &clone +} + +// collectCandidates enumerates every provider that serves modelID, in +// registration (first-wins) order, attaching pricing and runtime statistics +// where the provider exposes them. Uses the O(k) ListProvidersForModel path +// when available, falling back to an O(N) scan over the full catalog. +func (r *Router) collectCandidates(modelID string) []router.ProviderCandidate { + modelID = strings.TrimSpace(modelID) + if modelID == "" { + return nil + } + + // Fast path: O(k) per-model lookup. + if indexed, ok := r.lookup.(providerModelForModelLister); ok { + return r.buildCandidatesFromEntries(indexed.ListProvidersForModel(modelID)) + } + + // Fallback: O(N) scan over the full catalog. + models, ok := r.lookup.(modelWithProviderLister) + if !ok { + return nil + } + var entries []ModelWithProvider + for _, entry := range models.ListModelsWithProvider() { + if strings.TrimSpace(entry.Model.ID) == modelID { + entries = append(entries, entry) + } + } + return r.buildCandidatesFromEntries(entries) +} + +// buildCandidatesFromEntries converts ModelWithProvider entries into +// ProviderCandidate values, attaching runtime stats via type assertion. +func (r *Router) buildCandidatesFromEntries(entries []ModelWithProvider) []router.ProviderCandidate { + candidates := make([]router.ProviderCandidate, 0, len(entries)) + for _, entry := range entries { + providerName := strings.TrimSpace(entry.ProviderName) + qualified := core.ModelSelector{Provider: providerName, Model: entry.Model.ID}.QualifiedModel() + provider := r.lookup.GetProvider(qualified) + if provider == nil { + continue + } + candidate := router.ProviderCandidate{ + Provider: provider, + ProviderName: providerName, + ProviderType: entry.ProviderType, + ModelID: entry.Model.ID, + } + if entry.Model.Metadata != nil { + candidate.Pricing = entry.Model.Metadata.Pricing + } + if stats, ok := provider.(core.ProviderStats); ok { + candidate.Latency = stats.P50Latency() + candidate.ErrorRate = stats.ErrorRate() + candidate.CircuitState = stats.CircuitState() + } + candidates = append(candidates, candidate) + } + return candidates +} + func (r *Router) refreshProviderModelsForRequest(ctx context.Context, requested core.RequestedModelSelector) (bool, error) { refresher, ok := r.lookup.(providerModelRefresher) if !ok { diff --git a/internal/providers/router_strategy_integration_test.go b/internal/providers/router_strategy_integration_test.go new file mode 100644 index 00000000..33c600d2 --- /dev/null +++ b/internal/providers/router_strategy_integration_test.go @@ -0,0 +1,315 @@ +package providers + +import ( + "context" + "net/http" + "testing" + "time" + + router "gomodel/internal/router" + + "gomodel/internal/core" +) + +// statsMockProvider is a mockProvider that also implements core.ProviderStats, +// letting intelligent routing observe latency/error/circuit signals. +type statsMockProvider struct { + mockProvider + p50 time.Duration + p99 time.Duration + errorRate float64 + circuitState string +} + +func (s *statsMockProvider) P50Latency() time.Duration { return s.p50 } +func (s *statsMockProvider) P99Latency() time.Duration { return s.p99 } +func (s *statsMockProvider) ErrorRate() float64 { return s.errorRate } +func (s *statsMockProvider) CircuitState() string { + if s.circuitState == "" { + return "closed" + } + return s.circuitState +} + +func floatPtr(v float64) *float64 { return &v } + +// registerModelWithMetadata adds a model entry carrying pricing metadata so +// cost-aware strategies have data to score. +func registerModelWithMetadata(t *testing.T, registry *ModelRegistry, provider core.Provider, providerName, providerType, modelID string, pricing *core.ModelPricing) { + t.Helper() + registry.RegisterProviderWithNameAndType(provider, providerName, providerType) + info := &ModelInfo{ + Model: core.Model{ + ID: modelID, + Object: "model", + Metadata: &core.ModelMetadata{Pricing: pricing}, + }, + Provider: provider, + ProviderName: providerName, + ProviderType: providerType, + } + if registry.modelsByProvider[providerName] == nil { + registry.modelsByProvider[providerName] = make(map[string]*ModelInfo) + } + registry.modelsByProvider[providerName][modelID] = info + if _, exists := registry.models[modelID]; !exists { + registry.models[modelID] = info + } +} + +func TestRouter_IntelligentRouting_PicksCheaperProvider(t *testing.T) { + registry := NewModelRegistry() + cheap := &statsMockProvider{ + mockProvider: mockProvider{name: "openai-east", chatResponse: &core.ChatResponse{ID: "east-resp", Model: "gpt-4o"}}, + p50: 100 * time.Millisecond, + } + pricey := &statsMockProvider{ + mockProvider: mockProvider{name: "openai-west", chatResponse: &core.ChatResponse{ID: "west-resp", Model: "gpt-4o"}}, + p50: 100 * time.Millisecond, + } + registerModelWithMetadata(t, registry, cheap, "openai-east", "openai", "gpt-4o", + &core.ModelPricing{InputPerMtok: floatPtr(1), OutputPerMtok: floatPtr(2)}) + registerModelWithMetadata(t, registry, pricey, "openai-west", "openai", "gpt-4o", + &core.ModelPricing{InputPerMtok: floatPtr(10), OutputPerMtok: floatPtr(20)}) + + r, _ := NewRouter(registry, WithStrategyRegistry(router.NewStrategyRegistry())) + + resp, err := r.ChatCompletion(context.Background(), &core.ChatRequest{Model: "gpt-4o"}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp.ID != "east-resp" { + t.Fatalf("expected cheaper east provider, got %s", resp.ID) + } +} + +func TestRouter_IntelligentRouting_PicksFasterProvider(t *testing.T) { + registry := NewModelRegistry() + slow := &statsMockProvider{ + mockProvider: mockProvider{name: "openai-slow", chatResponse: &core.ChatResponse{ID: "slow-resp", Model: "gpt-4o"}}, + p50: 500 * time.Millisecond, + } + fast := &statsMockProvider{ + mockProvider: mockProvider{name: "openai-fast", chatResponse: &core.ChatResponse{ID: "fast-resp", Model: "gpt-4o"}}, + p50: 50 * time.Millisecond, + } + // Equal pricing so latency_only routing decides. + pricing := &core.ModelPricing{InputPerMtok: floatPtr(5), OutputPerMtok: floatPtr(5)} + registerModelWithMetadata(t, registry, slow, "openai-slow", "openai", "gpt-4o", pricing) + registerModelWithMetadata(t, registry, fast, "openai-fast", "openai", "gpt-4o", pricing) + + r, _ := NewRouter(registry, WithStrategyRegistry(router.NewStrategyRegistry())) + + resp, err := r.ChatCompletion(context.Background(), &core.ChatRequest{Model: "gpt-4o"}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp.ID != "fast-resp" { + t.Fatalf("expected faster provider, got %s", resp.ID) + } +} + +func TestRouter_IntelligentRouting_HeaderStrategyOverride(t *testing.T) { + registry := NewModelRegistry() + // Two providers: cheaper-but-slower vs pricier-but-faster. + cheapSlow := &statsMockProvider{ + mockProvider: mockProvider{name: "openai-cheap", chatResponse: &core.ChatResponse{ID: "cheap-resp", Model: "gpt-4o"}}, + p50: 500 * time.Millisecond, + } + priceyFast := &statsMockProvider{ + mockProvider: mockProvider{name: "openai-fast", chatResponse: &core.ChatResponse{ID: "fast-resp", Model: "gpt-4o"}}, + p50: 50 * time.Millisecond, + } + registerModelWithMetadata(t, registry, cheapSlow, "openai-cheap", "openai", "gpt-4o", + &core.ModelPricing{InputPerMtok: floatPtr(1), OutputPerMtok: floatPtr(1)}) + registerModelWithMetadata(t, registry, priceyFast, "openai-fast", "openai", "gpt-4o", + &core.ModelPricing{InputPerMtok: floatPtr(10), OutputPerMtok: floatPtr(10)}) + + r, _ := NewRouter(registry, WithStrategyRegistry(router.NewStrategyRegistry())) + + // Default balanced should pick cheap-slow (cost dominates at 0.6/0.4). + defaultResp, err := r.ChatCompletion(context.Background(), &core.ChatRequest{Model: "gpt-4o"}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if defaultResp.ID != "cheap-resp" { + t.Fatalf("default balanced expected cheap, got %s", defaultResp.ID) + } + + // Override to latency_only should pick the faster provider. + ctx := router.WithStrategyOverride(context.Background(), "latency_only") + latencyResp, err := r.ChatCompletion(ctx, &core.ChatRequest{Model: "gpt-4o"}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if latencyResp.ID != "fast-resp" { + t.Fatalf("latency_only override expected fast, got %s", latencyResp.ID) + } +} + +func TestRouter_IntelligentRouting_InvalidOverrideFallsBack(t *testing.T) { + registry := NewModelRegistry() + cheap := &statsMockProvider{ + mockProvider: mockProvider{name: "openai-east", chatResponse: &core.ChatResponse{ID: "east-resp", Model: "gpt-4o"}}, + } + pricey := &statsMockProvider{ + mockProvider: mockProvider{name: "openai-west", chatResponse: &core.ChatResponse{ID: "west-resp", Model: "gpt-4o"}}, + } + registerModelWithMetadata(t, registry, cheap, "openai-east", "openai", "gpt-4o", + &core.ModelPricing{InputPerMtok: floatPtr(1), OutputPerMtok: floatPtr(2)}) + registerModelWithMetadata(t, registry, pricey, "openai-west", "openai", "gpt-4o", + &core.ModelPricing{InputPerMtok: floatPtr(10), OutputPerMtok: floatPtr(20)}) + + r, _ := NewRouter(registry, WithStrategyRegistry(router.NewStrategyRegistry())) + + ctx := router.WithStrategyOverride(context.Background(), "not-a-real-strategy") + resp, err := r.ChatCompletion(ctx, &core.ChatRequest{Model: "gpt-4o"}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + // Invalid override falls back to default balanced, which picks cheap. + if resp.ID != "east-resp" { + t.Fatalf("invalid override should fall back to default (cheap), got %s", resp.ID) + } +} + +func TestRouter_IntelligentRouting_ExplicitProviderHintBypasses(t *testing.T) { + registry := NewModelRegistry() + cheap := &statsMockProvider{ + mockProvider: mockProvider{name: "openai-east", chatResponse: &core.ChatResponse{ID: "east-resp", Model: "gpt-4o"}}, + } + pricey := &statsMockProvider{ + mockProvider: mockProvider{name: "openai-west", chatResponse: &core.ChatResponse{ID: "west-resp", Model: "gpt-4o"}}, + } + registerModelWithMetadata(t, registry, cheap, "openai-east", "openai", "gpt-4o", + &core.ModelPricing{InputPerMtok: floatPtr(1), OutputPerMtok: floatPtr(2)}) + registerModelWithMetadata(t, registry, pricey, "openai-west", "openai", "gpt-4o", + &core.ModelPricing{InputPerMtok: floatPtr(10), OutputPerMtok: floatPtr(20)}) + + r, _ := NewRouter(registry, WithStrategyRegistry(router.NewStrategyRegistry())) + + // Explicit provider hint must route to the named provider regardless of cost. + resp, err := r.ChatCompletion(context.Background(), &core.ChatRequest{ + Model: "gpt-4o", + Provider: "openai-west", + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp.ID != "west-resp" { + t.Fatalf("explicit provider hint should bypass routing, got %s", resp.ID) + } +} + +func TestRouter_IntelligentRouting_SingleCandidateUnaffected(t *testing.T) { + registry := NewModelRegistry() + solo := &statsMockProvider{ + mockProvider: mockProvider{name: "openai", chatResponse: &core.ChatResponse{ID: "solo-resp", Model: "gpt-4o"}}, + } + registerModelWithMetadata(t, registry, solo, "openai", "openai", "gpt-4o", nil) + + r, _ := NewRouter(registry, WithStrategyRegistry(router.NewStrategyRegistry())) + + resp, err := r.ChatCompletion(context.Background(), &core.ChatRequest{Model: "gpt-4o"}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp.ID != "solo-resp" { + t.Fatalf("single candidate should route normally, got %s", resp.ID) + } +} + +func TestRouter_IntelligentRouting_WeightsOverride(t *testing.T) { + registry := NewModelRegistry() + cheapSlow := &statsMockProvider{ + mockProvider: mockProvider{name: "cheap-slow", chatResponse: &core.ChatResponse{ID: "cheap", Model: "gpt-4o"}}, + p50: 500 * time.Millisecond, + } + priceyFast := &statsMockProvider{ + mockProvider: mockProvider{name: "pricey-fast", chatResponse: &core.ChatResponse{ID: "fast", Model: "gpt-4o"}}, + p50: 50 * time.Millisecond, + } + registerModelWithMetadata(t, registry, cheapSlow, "cheap-slow", "openai", "gpt-4o", + &core.ModelPricing{InputPerMtok: floatPtr(1), OutputPerMtok: floatPtr(1)}) + registerModelWithMetadata(t, registry, priceyFast, "pricey-fast", "openai", "gpt-4o", + &core.ModelPricing{InputPerMtok: floatPtr(10), OutputPerMtok: floatPtr(10)}) + + r, _ := NewRouter(registry, WithStrategyRegistry(router.NewStrategyRegistry())) + + t.Run("weighted cost=1 latency=0 picks cheap", func(t *testing.T) { + ctx := router.WithWeightsOverride(context.Background(), router.WeightsOverride{Cost: 1.0, Latency: 0.0}) + resp, err := r.ChatCompletion(ctx, &core.ChatRequest{Model: "gpt-4o"}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp.ID != "cheap" { + t.Fatalf("cost=1 latency=0 expected cheap, got %s", resp.ID) + } + }) + + t.Run("weighted cost=0 latency=1 picks fast", func(t *testing.T) { + ctx := router.WithWeightsOverride(context.Background(), router.WeightsOverride{Cost: 0.0, Latency: 1.0}) + resp, err := r.ChatCompletion(ctx, &core.ChatRequest{Model: "gpt-4o"}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp.ID != "fast" { + t.Fatalf("cost=0 latency=1 expected fast, got %s", resp.ID) + } + }) +} + +func TestRouter_IntelligentRouting_ModelSyntaxProviderNotOverridden(t *testing.T) { + registry := NewModelRegistry() + cheap := &statsMockProvider{ + mockProvider: mockProvider{name: "cheap", chatResponse: &core.ChatResponse{ID: "cheap", Model: "gpt-4o"}}, + } + pricey := &statsMockProvider{ + mockProvider: mockProvider{name: "pricey", chatResponse: &core.ChatResponse{ID: "pricey", Model: "gpt-4o"}}, + } + registerModelWithMetadata(t, registry, cheap, "cheap", "openai", "gpt-4o", + &core.ModelPricing{InputPerMtok: floatPtr(1), OutputPerMtok: floatPtr(1)}) + registerModelWithMetadata(t, registry, pricey, "pricey", "openai", "gpt-4o", + &core.ModelPricing{InputPerMtok: floatPtr(10), OutputPerMtok: floatPtr(10)}) + + r, _ := NewRouter(registry, WithStrategyRegistry(router.NewStrategyRegistry())) + + // Provider specified via model syntax (pricey/gpt-4o) should not be overridden. + resp, err := r.ChatCompletion(context.Background(), &core.ChatRequest{ + Model: "pricey/gpt-4o", + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp.ID != "pricey" { + t.Fatalf("model-syntax provider (pricey/gpt-4o) should be respected, got %s", resp.ID) + } +} + +func TestRouter_FirstFit_RespectsRegistrationOrder(t *testing.T) { + registry := NewModelRegistry() + first := &statsMockProvider{ + mockProvider: mockProvider{name: "reg-first", chatResponse: &core.ChatResponse{ID: "first", Model: "gpt-4o"}}, + } + second := &statsMockProvider{ + mockProvider: mockProvider{name: "reg-second", chatResponse: &core.ChatResponse{ID: "second", Model: "gpt-4o"}}, + } + // Register in order first, then second. + registerModelWithMetadata(t, registry, first, "reg-first", "openai", "gpt-4o", nil) + registerModelWithMetadata(t, registry, second, "reg-second", "openai", "gpt-4o", nil) + + // Override to first_fit so the order is the only decision factor. + ctx := router.WithStrategyOverride(context.Background(), "first_fit") + r, _ := NewRouter(registry, WithStrategyRegistry(router.NewStrategyRegistry())) + + resp, err := r.ChatCompletion(ctx, &core.ChatRequest{Model: "gpt-4o"}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp.ID != "first" { + t.Fatalf("first_fit should route to reg-first (registration order), got %s", resp.ID) + } +} + +// sanity: http import referenced (prevents unused import churn if tests shrink). +var _ = http.StatusOK diff --git a/internal/providers/strategy_init.go b/internal/providers/strategy_init.go new file mode 100644 index 00000000..6fec5878 --- /dev/null +++ b/internal/providers/strategy_init.go @@ -0,0 +1,36 @@ +package providers + +import ( + router "gomodel/internal/router" + + "gomodel/config" +) + +// buildStrategyRegistry constructs the routing strategy registry from config. +// Strategies with configured weights (balanced) are instantiated directly so +// the configured cost/latency split applies to every request that does not +// carry a per-request override. +func buildStrategyRegistry(cfg config.RouterConfig) *router.StrategyRegistry { + registry := router.NewStrategyRegistry() + + registry.Register("balanced", func() router.RoutingStrategy { + return &router.BalancedStrategy{ + CostWeight: cfg.Weights.Cost, + LatencyWeight: cfg.Weights.Latency, + MaxErrorRate: cfg.MaxErrorRate, + } + }) + + registry.Register("cost_only", func() router.RoutingStrategy { + return &router.CostOnlyStrategy{MaxErrorRate: cfg.MaxErrorRate} + }) + registry.Register("latency_only", func() router.RoutingStrategy { + return &router.LatencyOnlyStrategy{MaxErrorRate: cfg.MaxErrorRate} + }) + registry.Register("first_fit", func() router.RoutingStrategy { + return &router.FirstFitStrategy{MaxErrorRate: cfg.MaxErrorRate} + }) + + _ = registry.SetDefault(cfg.Strategy) + return registry +} diff --git a/internal/router/context.go b/internal/router/context.go new file mode 100644 index 00000000..ac3986c1 --- /dev/null +++ b/internal/router/context.go @@ -0,0 +1,67 @@ +package router + +import "context" + +// strategyCtxKey carries a per-request routing-strategy override extracted +// from the X-GoModel-Routing-Strategy header. It is distinct from any other +// context key so callers cannot collide with it. +type strategyCtxKey struct{} + +// routingEligibleKey signals that the request's provider was assigned by the +// system (prepare phase) rather than explicitly by the user. When set, +// resolveProvider applies the routing strategy even though providerHint is +// non-empty, because the provider is an artifact of first-wins resolution, +// not a user choice. +type routingEligibleKey struct{} + +// WithRoutingEligible marks a context as eligible for intelligent routing +// despite having a non-empty provider hint. This is set by the prepare phase +// when the original request did not pin a provider. +func WithRoutingEligible(ctx context.Context) context.Context { + return context.WithValue(ctx, routingEligibleKey{}, true) +} + +// IsRoutingEligible reports whether the context was marked as eligible for +// intelligent routing by the prepare phase. +func IsRoutingEligible(ctx context.Context) bool { + v, _ := ctx.Value(routingEligibleKey{}).(bool) + return v +} + +// weightsCtxKey carries a per-request weights override parsed from the +// X-GoModel-Routing-Weights header (only honored by the balanced strategy). +type weightsCtxKey struct{} + +// StrategyOverrideFromContext returns a per-request strategy name override set +// on the context, and whether one was present. +func StrategyOverrideFromContext(ctx context.Context) (string, bool) { + s, ok := ctx.Value(strategyCtxKey{}).(string) + return s, ok +} + +// WithStrategyOverride returns a context carrying a per-request strategy name +// override. An empty name clears any prior override (returns ctx unchanged). +func WithStrategyOverride(ctx context.Context, strategy string) context.Context { + if strategy == "" { + return ctx + } + return context.WithValue(ctx, strategyCtxKey{}, strategy) +} + +// WeightsOverride holds a per-request cost/latency weight override for the +// balanced strategy. +type WeightsOverride struct { + Cost float64 + Latency float64 +} + +// WeightsOverrideFromContext returns a per-request weights override, if set. +func WeightsOverrideFromContext(ctx context.Context) (WeightsOverride, bool) { + w, ok := ctx.Value(weightsCtxKey{}).(WeightsOverride) + return w, ok +} + +// WithWeightsOverride returns a context carrying a per-request weights override. +func WithWeightsOverride(ctx context.Context, w WeightsOverride) context.Context { + return context.WithValue(ctx, weightsCtxKey{}, w) +} diff --git a/internal/router/registry.go b/internal/router/registry.go new file mode 100644 index 00000000..bec46de2 --- /dev/null +++ b/internal/router/registry.go @@ -0,0 +1,119 @@ +package router + +import ( + "context" + "errors" + "fmt" + "io" + "math" + "strings" +) + +// StrategyRegistry maps strategy names to factories. A factory returns a fresh +// strategy instance (so per-request weights can be applied without mutation +// across requests). The default strategy is used when a request omits an +// override or names an unknown strategy. +type StrategyRegistry struct { + factories map[string]func() RoutingStrategy + defaultID string +} + +// NewStrategyRegistry returns a registry pre-populated with the built-in +// strategies (balanced, cost_only, latency_only, first_fit) and "balanced" as +// the default. +func NewStrategyRegistry() *StrategyRegistry { + r := &StrategyRegistry{ + factories: map[string]func() RoutingStrategy{}, + defaultID: "balanced", + } + r.Register("balanced", func() RoutingStrategy { return NewBalancedStrategy() }) + r.Register("cost_only", func() RoutingStrategy { return NewCostOnlyStrategy() }) + r.Register("latency_only", func() RoutingStrategy { return NewLatencyOnlyStrategy() }) + r.Register("first_fit", func() RoutingStrategy { return NewFirstFitStrategy() }) + return r +} + +// Register adds or replaces a strategy factory under the given name. +func (r *StrategyRegistry) Register(name string, factory func() RoutingStrategy) { + if factory == nil || name == "" { + return + } + r.factories[strings.ToLower(strings.TrimSpace(name))] = factory +} + +// SetDefault sets the default strategy id used when no/invalid override is given. +// Returns an error if the id is not registered. +func (r *StrategyRegistry) SetDefault(id string) error { + if _, ok := r.factories[strings.ToLower(strings.TrimSpace(id))]; !ok { + return fmt.Errorf("unknown routing strategy %q", id) + } + r.defaultID = strings.ToLower(strings.TrimSpace(id)) + return nil +} + +// DefaultID returns the configured default strategy id. +func (r *StrategyRegistry) DefaultID() string { return r.defaultID } + +// Names returns the registered strategy ids. +func (r *StrategyRegistry) Names() []string { + names := make([]string, 0, len(r.factories)) + for name := range r.factories { + names = append(names, name) + } + return names +} + +// Resolve returns the strategy for a request. If ctx carries a valid override, +// that strategy is used; otherwise the default is used. An invalid override id +// yields (nil, false) so the caller can fall back to the default and log a warning. +func (r *StrategyRegistry) Resolve(ctx context.Context) (RoutingStrategy, bool) { + if override, ok := StrategyOverrideFromContext(ctx); ok { + name := strings.ToLower(strings.TrimSpace(override)) + if factory, ok := r.factories[name]; ok { + return factory(), true + } + return nil, false + } + if factory, ok := r.factories[r.defaultID]; ok { + return factory(), true + } + return nil, false +} + +// New returns a fresh instance of the named strategy, or nil if unknown. +func (r *StrategyRegistry) New(name string) RoutingStrategy { + factory, ok := r.factories[strings.ToLower(strings.TrimSpace(name))] + if !ok { + return nil + } + return factory() +} + +// ParseWeights parses a "cost,latency" weights string (e.g. "0.6,0.4") into a +// WeightsOverride. It returns an error if the string is not exactly two +// non-negative finite numbers. +func ParseWeights(s string) (WeightsOverride, error) { + parts := strings.Split(s, ",") + if len(parts) != 2 { + return WeightsOverride{}, fmt.Errorf("expected two comma-separated weights, got %q", s) + } + costStr := strings.TrimSpace(parts[0]) + latStr := strings.TrimSpace(parts[1]) + var cost, lat float64 + var trailing string + n, err := fmt.Sscanf(costStr, "%f%s", &cost, &trailing) + if err != nil && !errors.Is(err, io.EOF) || n != 1 { + return WeightsOverride{}, fmt.Errorf("invalid cost weight %q: parse error", costStr) + } + n, err = fmt.Sscanf(latStr, "%f%s", &lat, &trailing) + if err != nil && !errors.Is(err, io.EOF) || n != 1 { + return WeightsOverride{}, fmt.Errorf("invalid latency weight %q: parse error", latStr) + } + if math.IsNaN(cost) || math.IsNaN(lat) || math.IsInf(cost, 0) || math.IsInf(lat, 0) { + return WeightsOverride{}, fmt.Errorf("weights must be finite numbers, got %v,%v", cost, lat) + } + if cost < 0 || lat < 0 { + return WeightsOverride{}, fmt.Errorf("weights must be non-negative, got %v,%v", cost, lat) + } + return WeightsOverride{Cost: cost, Latency: lat}, nil +} diff --git a/internal/router/registry_test.go b/internal/router/registry_test.go new file mode 100644 index 00000000..fd569ff7 --- /dev/null +++ b/internal/router/registry_test.go @@ -0,0 +1,106 @@ +package router + +import ( + "context" + "testing" +) + +func TestStrategyRegistry_DefaultIsBalanced(t *testing.T) { + r := NewStrategyRegistry() + s, ok := r.Resolve(context.Background()) + if !ok { + t.Fatal("expected default strategy to resolve") + } + if s.Name() != "balanced" { + t.Fatalf("expected default balanced, got %s", s.Name()) + } +} + +func TestStrategyRegistry_OverrideHonored(t *testing.T) { + r := NewStrategyRegistry() + ctx := WithStrategyOverride(context.Background(), "cost_only") + s, ok := r.Resolve(ctx) + if !ok { + t.Fatal("expected override to resolve") + } + if s.Name() != "cost_only" { + t.Fatalf("expected cost_only, got %s", s.Name()) + } +} + +func TestStrategyRegistry_InvalidOverrideRejected(t *testing.T) { + r := NewStrategyRegistry() + ctx := WithStrategyOverride(context.Background(), "nonsense") + s, ok := r.Resolve(ctx) + if ok { + t.Fatal("expected invalid override to be rejected") + } + if s != nil { + t.Fatal("expected nil strategy for invalid override") + } +} + +func TestStrategyRegistry_Names(t *testing.T) { + r := NewStrategyRegistry() + names := r.Names() + if len(names) != 4 { + t.Fatalf("expected 4 strategies, got %d (%v)", len(names), names) + } +} + +func TestStrategyRegistry_SetDefaultInvalid(t *testing.T) { + r := NewStrategyRegistry() + if err := r.SetDefault("bogus"); err == nil { + t.Fatal("expected error setting unknown default") + } +} + +func TestParseWeights_Valid(t *testing.T) { + cases := []struct { + input string + wantCost float64 + wantLat float64 + }{ + {"0.7,0.3", 0.7, 0.3}, + {"0.60,0.40", 0.6, 0.4}, + {"1,0", 1.0, 0.0}, + {" 0.6 , 0.4 ", 0.6, 0.4}, + {"0,1", 0.0, 1.0}, + } + for _, tc := range cases { + w, err := ParseWeights(tc.input) + if err != nil { + t.Fatalf("ParseWeights(%q) unexpected error: %v", tc.input, err) + } + if w.Cost != tc.wantCost || w.Latency != tc.wantLat { + t.Fatalf("ParseWeights(%q) = %v,%v, want %v,%v", tc.input, w.Cost, w.Latency, tc.wantCost, tc.wantLat) + } + } +} + +func TestParseWeights_Invalid(t *testing.T) { + cases := []string{"0.6", "0.6,0.4,0.2", "abc,0.4", "-1,0.4", "0.6x,0.4", "NaN,0.4", "Inf,0.4"} + for _, c := range cases { + if _, err := ParseWeights(c); err == nil { + t.Errorf("expected error for %q", c) + } + } +} + +func TestWeightsOverride_RoundTrip(t *testing.T) { + ctx := context.Background() + if _, ok := WeightsOverrideFromContext(ctx); ok { + t.Fatal("expected no override on plain context") + } + ctx = WithWeightsOverride(ctx, WeightsOverride{Cost: 0.6, Latency: 0.4}) + w, ok := WeightsOverrideFromContext(ctx) + if !ok || w.Cost != 0.6 || w.Latency != 0.4 { + t.Fatalf("round trip failed: %+v ok=%v", w, ok) + } +} + +func TestParseWeights_ErrorsAreErrors(t *testing.T) { + if _, err := ParseWeights("x,y"); err == nil { + t.Fatal("expected error for non-numeric weights") + } +} diff --git a/internal/router/scoring.go b/internal/router/scoring.go new file mode 100644 index 00000000..ae36947d --- /dev/null +++ b/internal/router/scoring.go @@ -0,0 +1,244 @@ +package router + +import ( + "errors" + "sort" + "time" +) + +// ErrNoAcceptableCandidate indicates a strategy filtered out every candidate. +// Callers fall back to the first candidate rather than failing the request. +var ErrNoAcceptableCandidate = errors.New("no acceptable routing candidate") + +// candidateScore pairs a candidate with its computed score for sorting. +type candidateScore struct { + candidate *ProviderCandidate + score float64 +} + +// maxErrorRate is the default upper bound above which a candidate is filtered. +const maxErrorRate = 0.5 + +// isCircuitOpen reports whether a candidate's breaker blocks requests. +// An empty CircuitState is treated as healthy ("closed"). +func isCircuitOpen(c *ProviderCandidate) bool { + return c.CircuitState == "open" || c.CircuitState == "half-open" +} + +// acceptable returns false when a candidate should be filtered out before +// scoring: circuit open or error rate at/above the threshold. Candidates with +// unknown error rate (0) are kept. +func acceptable(c *ProviderCandidate, maxErrRate float64) bool { + if isCircuitOpen(c) { + return false + } + if c.ErrorRate >= maxErrRate { + return false + } + return true +} + +// normalizeLatency maps latency values into [0, 1] relative to min and max. +// Lower latency → lower (better) score. Unknown latencies (0) are scored 1.0 +// (worst) so complete-data providers win. When all latencies are unknown or +// equal, every value scores 0. +func normalizeLatency(values []time.Duration, min, max time.Duration) []float64 { + scores := make([]float64, len(values)) + if max <= min || max == 0 { + for i := range values { + if values[i] == 0 { + scores[i] = 1.0 + } + } + return scores + } + span := float64(max - min) + for i, v := range values { + if v == 0 { + scores[i] = 1.0 + continue + } + scores[i] = float64(v-min) / span + } + return scores +} + +// normalizeCost maps per-million-token costs into [0, 1] relative to min and +// max. Lower cost → lower (better) score. Missing pricing (perMtokCost returns +// -1) is scored 1.0 so complete-data providers win. Known zero-cost (free +// models) is scored 0.0 (best), properly distinguished from unknown. +// When all costs are unknown or equal, every value scores 0. +func normalizeCost(values []float64, min, max float64) []float64 { + scores := make([]float64, len(values)) + if max <= min { + for i, v := range values { + if v < 0 { + scores[i] = 1.0 + } + } + return scores + } + span := max - min + for i, v := range values { + if v < 0 { + scores[i] = 1.0 + continue + } + scores[i] = (v - min) / span + } + return scores +} + +// pickBest returns the candidate with the lowest score, breaking ties by +// ProviderName lexicographic order for determinism. Returns +// ErrNoAcceptableCandidate when scored is empty. +func pickBest(scored []candidateScore) (*ProviderCandidate, error) { + if len(scored) == 0 { + return nil, ErrNoAcceptableCandidate + } + sort.SliceStable(scored, func(i, j int) bool { + if scored[i].score != scored[j].score { + return scored[i].score < scored[j].score + } + return scored[i].candidate.ProviderName < scored[j].candidate.ProviderName + }) + return scored[0].candidate, nil +} + +// firstAcceptable returns the first acceptable candidate in order, or +// ErrNoAcceptableCandidate when none pass the filter. +func firstAcceptable(candidates []ProviderCandidate, maxErrRate float64) (*ProviderCandidate, error) { + for i := range candidates { + if acceptable(&candidates[i], maxErrRate) { + return &candidates[i], nil + } + } + return nil, ErrNoAcceptableCandidate +} + +// minMaxLatency returns the min and max non-zero latency across candidates. +func minMaxLatency(candidates []ProviderCandidate) (min, max time.Duration) { + for i := range candidates { + v := candidates[i].Latency + if v <= 0 { + continue + } + if min == 0 || v < min { + min = v + } + if v > max { + max = v + } + } + return min, max +} + +// minMaxCost returns the min and max per-million-token cost across candidates. +// Unknown costs (perMtokCost returns -1) are skipped. Known-zero costs +// (free models) are included. +func minMaxCost(candidates []ProviderCandidate) (min, max float64) { + var seen bool + for i := range candidates { + cost := perMtokCost(&candidates[i]) + if cost < 0 { + continue + } + if !seen { + min = cost + max = cost + seen = true + continue + } + if cost < min { + min = cost + } + if cost > max { + max = cost + } + } + return min, max +} + +// perMtokCost returns the sum of input and output per-million-token prices for +// a candidate. Returns -1 when pricing is absent (so callers can distinguish +// free models at cost 0 from unknown pricing). +func perMtokCost(c *ProviderCandidate) float64 { + if c == nil || c.Pricing == nil { + return -1 + } + var cost float64 + if c.Pricing.InputPerMtok != nil { + cost += *c.Pricing.InputPerMtok + } + if c.Pricing.OutputPerMtok != nil { + cost += *c.Pricing.OutputPerMtok + } + return cost +} + +// hasLatencyData reports whether at least one candidate carries latency data. +func hasLatencyData(candidates []ProviderCandidate) bool { + for i := range candidates { + if candidates[i].Latency > 0 { + return true + } + } + return false +} + +// latencyValues collects the latency field from each candidate preserving order. +func latencyValues(candidates []ProviderCandidate) []time.Duration { + out := make([]time.Duration, len(candidates)) + for i := range candidates { + out[i] = candidates[i].Latency + } + return out +} + +// hasPricingData reports whether at least one candidate carries known pricing +// (Pricing != nil with at least one non-nil field). This is distinct from +// minMaxCost which can return max==0 even when a free-model (cost 0) exists. +func hasPricingData(candidates []ProviderCandidate) bool { + for i := range candidates { + c := &candidates[i] + if c.Pricing == nil { + continue + } + if c.Pricing.InputPerMtok != nil || c.Pricing.OutputPerMtok != nil || + c.Pricing.CachedInputPerMtok != nil || c.Pricing.CacheWritePerMtok != nil || + c.Pricing.PerRequest != nil || c.Pricing.PerImage != nil || + c.Pricing.InputPerImage != nil || c.Pricing.PerSecondInput != nil || + c.Pricing.PerSecondOutput != nil || c.Pricing.PerCharacterInput != nil || + c.Pricing.PerPage != nil || len(c.Pricing.Tiers) > 0 { + return true + } + } + return false +} + +// costValues collects the per-million-token cost from each candidate preserving order. +func costValues(candidates []ProviderCandidate) []float64 { + out := make([]float64, len(candidates)) + for i := range candidates { + out[i] = perMtokCost(&candidates[i]) + } + return out +} + +// rankBy filters candidates and ranks them by a precomputed per-candidate +// score, returning the best (lowest-scoring) one. It is the shared body of the +// single-factor strategies; callers compute the score array (e.g. normalized +// cost or latency) and pass it here. +func rankBy(candidates []ProviderCandidate, scores []float64, maxErrRate float64) (*ProviderCandidate, error) { + scored := make([]candidateScore, 0, len(candidates)) + for i := range candidates { + if !acceptable(&candidates[i], maxErrRate) { + continue + } + scored = append(scored, candidateScore{ + candidate: &candidates[i], + score: scores[i], + }) + } + return pickBest(scored) +} diff --git a/internal/router/strategy.go b/internal/router/strategy.go new file mode 100644 index 00000000..71421421 --- /dev/null +++ b/internal/router/strategy.go @@ -0,0 +1,63 @@ +// Package router provides pluggable routing strategies that select the best +// provider from a set of candidates serving the same model ID. Strategies are +// cost-aware and/or latency-aware, fed by provider runtime statistics. +package router + +import ( + "context" + "time" + + "gomodel/internal/core" +) + +// ProviderCandidate describes one provider able to serve a given model, together +// with the signals a strategy needs to score it. Fields with zero values +// (nil Pricing, 0 Latency, "" CircuitState) mean "unknown"; strategies treat +// unknown signals as worst-case so that providers with complete data are preferred. +type ProviderCandidate struct { + // Provider is the underlying provider implementation (may be nil for + // strategy-only callers that just need the name/metadata). + Provider core.Provider + + // ProviderName is the concrete configured instance name, e.g. + // "openai-primary". Used as a deterministic tiebreaker. + ProviderName string + + // ProviderType is the provider type, e.g. "openai" or "azure". + ProviderType string + + // ModelID is the model identifier this candidate serves. + ModelID string + + // Pricing is the per-model pricing used by cost-aware strategies. + // nil means pricing is unknown. + Pricing *core.ModelPricing + + // Latency is the provider's smoothed P50 latency. 0 means unknown. + Latency time.Duration + + // CircuitState is "closed", "open", or "half-open". "" defaults to + // "closed" (healthy) when a provider has no breaker. + CircuitState string + + // ErrorRate is the smoothed error ratio in [0, 1]. 0 means unknown + // (which strategies may treat conservatively or optimistically per policy). + ErrorRate float64 +} + +// RoutingStrategy selects the best provider candidate from a non-empty list. +// Implementations must be deterministic: given identical candidates they must +// return the same choice, breaking ties by ProviderName lexicographic order. +// +// Select returns an error only when no candidate is acceptable (e.g. every +// candidate is filtered out); callers fall back to the first candidate in +// that case rather than failing the request. +type RoutingStrategy interface { + // Name returns the strategy identifier, e.g. "balanced", "cost_only". + Name() string + + // Select picks one candidate. The candidates slice is non-empty; strategies + // that filter out every candidate return an error so the caller can fall + // back to the first candidate and log a warning. + Select(ctx context.Context, candidates []ProviderCandidate) (*ProviderCandidate, error) +} diff --git a/internal/router/strategy_balanced.go b/internal/router/strategy_balanced.go new file mode 100644 index 00000000..79854847 --- /dev/null +++ b/internal/router/strategy_balanced.go @@ -0,0 +1,61 @@ +package router + +import ( + "context" +) + +// BalancedStrategy scores candidates by a weighted combination of cost and +// latency. Lower total score wins; ties break by ProviderName. Candidates whose +// circuit is open or whose error rate is at/above MaxErrorRate are filtered out. +// +// Missing pricing scores 1.0 (worst) on cost; missing latency scores 1.0 +// (worst) on latency — so providers with complete data are preferred. When +// every candidate misses the same dimension, that dimension contributes equally +// to all and ranking falls to the other dimension. +type BalancedStrategy struct { + // CostWeight in [0, 1]. Defaults to 0.6 when zero. + CostWeight float64 + // LatencyWeight in [0, 1]. Defaults to 0.4 when zero. + LatencyWeight float64 + // MaxErrorRate filters candidates at/above this ratio. Defaults to 0.5. + MaxErrorRate float64 +} + +// NewBalancedStrategy returns a balanced strategy with default weights. +func NewBalancedStrategy() *BalancedStrategy { + return &BalancedStrategy{ + CostWeight: 0.6, + LatencyWeight: 0.4, + MaxErrorRate: maxErrorRate, + } +} + +// Name returns "balanced". +func (s *BalancedStrategy) Name() string { return "balanced" } + +// Select picks the lowest combined-score acceptable candidate. +func (s *BalancedStrategy) Select(_ context.Context, candidates []ProviderCandidate) (*ProviderCandidate, error) { + costW := s.CostWeight + if costW == 0 && s.LatencyWeight == 0 { + costW = 0.6 + } + latW := s.LatencyWeight + if s.CostWeight == 0 && latW == 0 { + latW = 0.4 + } + maxErr := s.MaxErrorRate + if maxErr == 0 { + maxErr = maxErrorRate + } + + costMin, costMax := minMaxCost(candidates) + latMin, latMax := minMaxLatency(candidates) + costScores := normalizeCost(costValues(candidates), costMin, costMax) + latScores := normalizeLatency(latencyValues(candidates), latMin, latMax) + + scores := make([]float64, len(candidates)) + for i := range candidates { + scores[i] = costW*costScores[i] + latW*latScores[i] + } + return rankBy(candidates, scores, maxErr) +} diff --git a/internal/router/strategy_cost.go b/internal/router/strategy_cost.go new file mode 100644 index 00000000..e8e75e3d --- /dev/null +++ b/internal/router/strategy_cost.go @@ -0,0 +1,42 @@ +//nolint:dupl // mirrors strategy_latency.go: same RoutingStrategy shape, different dimension (cost vs latency); divergence is in the scorer/normalize/degrade calls. +package router + +import ( + "context" +) + +// CostOnlyStrategy selects the cheapest acceptable candidate. When every +// candidate lacks pricing data, it degrades to latency-only ranking so that a +// decision is still made on available signals. +type CostOnlyStrategy struct { + // MaxErrorRate filters candidates at/above this ratio. Defaults to 0.5. + MaxErrorRate float64 +} + +// NewCostOnlyStrategy returns a cost-only strategy with the default error filter. +func NewCostOnlyStrategy() *CostOnlyStrategy { + return &CostOnlyStrategy{MaxErrorRate: maxErrorRate} +} + +// Name returns "cost_only". +func (s *CostOnlyStrategy) Name() string { return "cost_only" } + +// Select picks the cheapest acceptable candidate. When no candidate carries +// pricing it ranks by latency instead (without recursing into latency's own +// cost-degradation path, which would loop when both dimensions are absent). +func (s *CostOnlyStrategy) Select(ctx context.Context, candidates []ProviderCandidate) (*ProviderCandidate, error) { + maxErr := s.MaxErrorRate + if maxErr == 0 { + maxErr = maxErrorRate + } + + if !hasPricingData(candidates) { + latMin, latMax := minMaxLatency(candidates) + scores := normalizeLatency(latencyValues(candidates), latMin, latMax) + return rankBy(candidates, scores, maxErr) + } + + costMin, costMax := minMaxCost(candidates) + scores := normalizeCost(costValues(candidates), costMin, costMax) + return rankBy(candidates, scores, maxErr) +} diff --git a/internal/router/strategy_firstfit.go b/internal/router/strategy_firstfit.go new file mode 100644 index 00000000..54c958c2 --- /dev/null +++ b/internal/router/strategy_firstfit.go @@ -0,0 +1,30 @@ +package router + +import ( + "context" +) + +// FirstFitStrategy returns the first acceptable candidate without scoring, +// preserving the gateway's historical first-wins behaviour. It exists for +// operators who want deterministic, config-order routing. +type FirstFitStrategy struct { + // MaxErrorRate filters candidates at/above this ratio. Defaults to 0.5. + MaxErrorRate float64 +} + +// NewFirstFitStrategy returns a first-fit strategy with the default error filter. +func NewFirstFitStrategy() *FirstFitStrategy { + return &FirstFitStrategy{MaxErrorRate: maxErrorRate} +} + +// Name returns "first_fit". +func (s *FirstFitStrategy) Name() string { return "first_fit" } + +// Select returns the first acceptable candidate in registration order. +func (s *FirstFitStrategy) Select(_ context.Context, candidates []ProviderCandidate) (*ProviderCandidate, error) { + maxErr := s.MaxErrorRate + if maxErr == 0 { + maxErr = maxErrorRate + } + return firstAcceptable(candidates, maxErr) +} diff --git a/internal/router/strategy_latency.go b/internal/router/strategy_latency.go new file mode 100644 index 00000000..78e1badf --- /dev/null +++ b/internal/router/strategy_latency.go @@ -0,0 +1,42 @@ +//nolint:dupl // mirrors strategy_cost.go: same RoutingStrategy shape, different dimension (latency vs cost); divergence is in the scorer/normalize/degrade calls. +package router + +import ( + "context" +) + +// LatencyOnlyStrategy selects the fastest acceptable candidate. When every +// candidate lacks latency data, it degrades to cost-only ranking so that a +// decision is still made on available signals. +type LatencyOnlyStrategy struct { + // MaxErrorRate filters candidates at/above this ratio. Defaults to 0.5. + MaxErrorRate float64 +} + +// NewLatencyOnlyStrategy returns a latency-only strategy with the default error filter. +func NewLatencyOnlyStrategy() *LatencyOnlyStrategy { + return &LatencyOnlyStrategy{MaxErrorRate: maxErrorRate} +} + +// Name returns "latency_only". +func (s *LatencyOnlyStrategy) Name() string { return "latency_only" } + +// Select picks the lowest-latency acceptable candidate. When no candidate +// carries latency it ranks by cost instead (without recursing into cost's own +// latency-degradation path, which would loop when both dimensions are absent). +func (s *LatencyOnlyStrategy) Select(ctx context.Context, candidates []ProviderCandidate) (*ProviderCandidate, error) { + maxErr := s.MaxErrorRate + if maxErr == 0 { + maxErr = maxErrorRate + } + + if !hasLatencyData(candidates) { + costMin, costMax := minMaxCost(candidates) + scores := normalizeCost(costValues(candidates), costMin, costMax) + return rankBy(candidates, scores, maxErr) + } + + latMin, latMax := minMaxLatency(candidates) + scores := normalizeLatency(latencyValues(candidates), latMin, latMax) + return rankBy(candidates, scores, maxErr) +} diff --git a/internal/router/strategy_test.go b/internal/router/strategy_test.go new file mode 100644 index 00000000..db01ba70 --- /dev/null +++ b/internal/router/strategy_test.go @@ -0,0 +1,294 @@ +package router + +import ( + "context" + "errors" + "testing" + "time" + + "gomodel/internal/core" +) + +func floatPtr(v float64) *float64 { return &v } + +func pricing(input, output float64) *core.ModelPricing { + return &core.ModelPricing{InputPerMtok: floatPtr(input), OutputPerMtok: floatPtr(output)} +} + +func candidates(cs ...ProviderCandidate) []ProviderCandidate { return cs } + +func mustSelect(t *testing.T, s RoutingStrategy, cands []ProviderCandidate) *ProviderCandidate { + t.Helper() + got, err := s.Select(context.Background(), cands) + if err != nil { + t.Fatalf("%s.Select returned error: %v", s.Name(), err) + } + if got == nil { + t.Fatalf("%s.Select returned nil candidate", s.Name()) + } + return got +} + +func TestBalanced_PicksCheapestWithEqualLatency(t *testing.T) { + s := NewBalancedStrategy() + cands := candidates( + ProviderCandidate{ProviderName: "alpha", Pricing: pricing(5, 15), Latency: 100 * time.Millisecond}, + ProviderCandidate{ProviderName: "zeta", Pricing: pricing(1, 4), Latency: 100 * time.Millisecond}, + ) + got := mustSelect(t, s, cands) + if got.ProviderName != "zeta" { + t.Fatalf("expected azure (cheaper), got %s", got.ProviderName) + } +} + +func TestBalanced_PicksFastestWithEqualCost(t *testing.T) { + s := NewBalancedStrategy() + cands := candidates( + ProviderCandidate{ProviderName: "alpha", Pricing: pricing(5, 5), Latency: 200 * time.Millisecond}, + ProviderCandidate{ProviderName: "zeta", Pricing: pricing(5, 5), Latency: 50 * time.Millisecond}, + ) + got := mustSelect(t, s, cands) + if got.ProviderName != "zeta" { + t.Fatalf("expected azure (faster), got %s", got.ProviderName) + } +} + +func TestBalanced_TiebreaksByName(t *testing.T) { + s := NewBalancedStrategy() + cands := candidates( + ProviderCandidate{ProviderName: "zeta", Pricing: pricing(5, 5), Latency: 100 * time.Millisecond}, + ProviderCandidate{ProviderName: "alpha", Pricing: pricing(5, 5), Latency: 100 * time.Millisecond}, + ) + got := mustSelect(t, s, cands) + if got.ProviderName != "alpha" { + t.Fatalf("expected alpha (lexicographic tiebreak), got %s", got.ProviderName) + } +} + +func TestBalanced_FiltersOpenCircuit(t *testing.T) { + s := NewBalancedStrategy() + cands := candidates( + ProviderCandidate{ProviderName: "open", Pricing: pricing(1, 1), Latency: 10 * time.Millisecond, CircuitState: "open"}, + ProviderCandidate{ProviderName: "healthy", Pricing: pricing(10, 10), Latency: 500 * time.Millisecond}, + ) + got := mustSelect(t, s, cands) + if got.ProviderName != "healthy" { + t.Fatalf("expected healthy (open circuit filtered), got %s", got.ProviderName) + } +} + +func TestBalanced_FiltersHighErrorRate(t *testing.T) { + s := NewBalancedStrategy() + cands := candidates( + ProviderCandidate{ProviderName: "flaky", Pricing: pricing(1, 1), Latency: 10 * time.Millisecond, ErrorRate: 0.9}, + ProviderCandidate{ProviderName: "stable", Pricing: pricing(10, 10), Latency: 500 * time.Millisecond}, + ) + got := mustSelect(t, s, cands) + if got.ProviderName != "stable" { + t.Fatalf("expected stable (high error rate filtered), got %s", got.ProviderName) + } +} + +func TestBalanced_AllFilteredFallsBackWithError(t *testing.T) { + s := NewBalancedStrategy() + cands := candidates( + ProviderCandidate{ProviderName: "open", Pricing: pricing(1, 1), CircuitState: "open"}, + ) + _, err := s.Select(context.Background(), cands) + if !errors.Is(err, ErrNoAcceptableCandidate) { + t.Fatalf("expected ErrNoAcceptableCandidate, got %v", err) + } +} + +func TestCostOnly_PicksCheapest(t *testing.T) { + s := NewCostOnlyStrategy() + cands := candidates( + ProviderCandidate{ProviderName: "alpha", Pricing: pricing(5, 15), Latency: 10 * time.Millisecond}, + ProviderCandidate{ProviderName: "zeta", Pricing: pricing(1, 4), Latency: 500 * time.Millisecond}, + ) + got := mustSelect(t, s, cands) + if got.ProviderName != "zeta" { + t.Fatalf("expected azure (cheapest), got %s", got.ProviderName) + } +} + +func TestCostOnly_DegradesToLatencyWhenNoPricing(t *testing.T) { + s := NewCostOnlyStrategy() + cands := candidates( + ProviderCandidate{ProviderName: "alpha", Latency: 200 * time.Millisecond}, + ProviderCandidate{ProviderName: "zeta", Latency: 50 * time.Millisecond}, + ) + got := mustSelect(t, s, cands) + if got.ProviderName != "zeta" { + t.Fatalf("expected azure (degraded to latency), got %s", got.ProviderName) + } +} + +func TestLatencyOnly_PicksFastest(t *testing.T) { + s := NewLatencyOnlyStrategy() + cands := candidates( + ProviderCandidate{ProviderName: "alpha", Pricing: pricing(1, 1), Latency: 300 * time.Millisecond}, + ProviderCandidate{ProviderName: "zeta", Pricing: pricing(50, 50), Latency: 50 * time.Millisecond}, + ) + got := mustSelect(t, s, cands) + if got.ProviderName != "zeta" { + t.Fatalf("expected azure (fastest), got %s", got.ProviderName) + } +} + +func TestLatencyOnly_DegradesToCostWhenNoLatency(t *testing.T) { + s := NewLatencyOnlyStrategy() + cands := candidates( + ProviderCandidate{ProviderName: "alpha", Pricing: pricing(5, 15)}, + ProviderCandidate{ProviderName: "zeta", Pricing: pricing(1, 4)}, + ) + got := mustSelect(t, s, cands) + if got.ProviderName != "zeta" { + t.Fatalf("expected azure (degraded to cost), got %s", got.ProviderName) + } +} + +func TestFirstFit_ReturnsFirstAcceptable(t *testing.T) { + s := NewFirstFitStrategy() + cands := candidates( + ProviderCandidate{ProviderName: "open", Pricing: pricing(50, 50), Latency: 500 * time.Millisecond, CircuitState: "open"}, + ProviderCandidate{ProviderName: "second", Pricing: pricing(50, 50), Latency: 500 * time.Millisecond}, + ProviderCandidate{ProviderName: "third", Pricing: pricing(1, 1), Latency: 10 * time.Millisecond}, + ) + got := mustSelect(t, s, cands) + if got.ProviderName != "second" { + t.Fatalf("expected second (first acceptable), got %s", got.ProviderName) + } +} + +func TestBalanced_FreeModelNotTreatedAsUnknown(t *testing.T) { + s := NewBalancedStrategy() + // freeModel has pricing = $0 (free). unknown has no pricing. + freeModel := ProviderCandidate{ProviderName: "free", Pricing: pricing(0, 0), Latency: 100 * time.Millisecond} + unknown := ProviderCandidate{ProviderName: "unknown", Latency: 100 * time.Millisecond} + got := mustSelect(t, s, candidates(freeModel, unknown)) + if got.ProviderName != "free" { + t.Fatalf("expected free (known zero cost, score 0), got %s", got.ProviderName) + } +} + +func TestBalanced_ZeroValuesDefaultToSixtyForty(t *testing.T) { + s := &BalancedStrategy{} // zero-value struct, should default to 0.6 cost / 0.4 latency + cands := candidates( + ProviderCandidate{ProviderName: "alpha", Pricing: pricing(5, 15), Latency: 100 * time.Millisecond}, + ProviderCandidate{ProviderName: "zeta", Pricing: pricing(1, 4), Latency: 100 * time.Millisecond}, + ) + got := mustSelect(t, s, cands) + if got.ProviderName != "zeta" { + t.Fatalf("expected zeta (cheaper, cost weight dominates at 0.6/0.4), got %s", got.ProviderName) + } +} + +func TestBalanced_EqualCostKnownVsUnknown(t *testing.T) { + s := NewBalancedStrategy() + // same latency, both $5 — but "unknown" has no pricing at all, so it gets + // cost score 1.0 while the known-cost provider gets cost score 0. + known := ProviderCandidate{ProviderName: "known", Pricing: pricing(5, 5), Latency: 100 * time.Millisecond} + unknown := ProviderCandidate{ProviderName: "unknown", Latency: 100 * time.Millisecond} + got := mustSelect(t, s, candidates(known, unknown)) + if got.ProviderName != "known" { + t.Fatalf("expected known (complete data preferred over unknown), got %s", got.ProviderName) + } +} + +func TestStrategies_HandleEmpty(t *testing.T) { + strategies := []RoutingStrategy{ + NewBalancedStrategy(), + NewCostOnlyStrategy(), + NewLatencyOnlyStrategy(), + NewFirstFitStrategy(), + } + for _, s := range strategies { + _, err := s.Select(context.Background(), nil) + if !errors.Is(err, ErrNoAcceptableCandidate) { + t.Errorf("%s on nil candidates: expected ErrNoAcceptableCandidate, got %v", s.Name(), err) + } + } +} + +// --- Boundary tests --- + +func TestBalanced_ErrorRateAtBoundary(t *testing.T) { + s := NewBalancedStrategy() + // ErrorRate == 0.5 is at the default filter boundary — candidate is filtered. + atBoundary := ProviderCandidate{ProviderName: "at-boundary", Pricing: pricing(1, 1), Latency: 10 * time.Millisecond, ErrorRate: 0.5} + healthy := ProviderCandidate{ProviderName: "healthy", Pricing: pricing(10, 10), Latency: 500 * time.Millisecond} + got := mustSelect(t, s, candidates(atBoundary, healthy)) + if got.ProviderName != "healthy" { + t.Fatalf("expected healthy (at-boundary filtered), got %s", got.ProviderName) + } +} + +func TestBalanced_ErrorRateBelowBoundary(t *testing.T) { + s := NewBalancedStrategy() + below := ProviderCandidate{ProviderName: "below", Pricing: pricing(1, 1), Latency: 10 * time.Millisecond, ErrorRate: 0.499} + pricey := ProviderCandidate{ProviderName: "pricey", Pricing: pricing(10, 10), Latency: 500 * time.Millisecond} + got := mustSelect(t, s, candidates(below, pricey)) + if got.ProviderName != "below" { + t.Fatalf("expected below (0.499 < 0.5 is acceptable), got %s", got.ProviderName) + } +} + +func TestBalanced_EqualScoresTiebreak(t *testing.T) { + s := NewBalancedStrategy() + cands := candidates( + ProviderCandidate{ProviderName: "zzz", Latency: 100 * time.Millisecond}, + ProviderCandidate{ProviderName: "aaa", Latency: 100 * time.Millisecond}, + ) + got := mustSelect(t, s, cands) + if got.ProviderName != "aaa" { + t.Fatalf("expected aaa (tiebreak), got %s", got.ProviderName) + } +} + +func TestFirstFit_SingleCandidate(t *testing.T) { + s := NewFirstFitStrategy() + cands := candidates( + ProviderCandidate{ProviderName: "solo", Pricing: pricing(1, 1)}, + ) + got := mustSelect(t, s, cands) + if got.ProviderName != "solo" { + t.Fatalf("expected solo, got %s", got.ProviderName) + } +} + +func TestFirstFit_AllFilteredReturnsError(t *testing.T) { + s := NewFirstFitStrategy() + cands := candidates( + ProviderCandidate{ProviderName: "open", CircuitState: "open"}, + ProviderCandidate{ProviderName: "half", CircuitState: "half-open"}, + ) + _, err := s.Select(context.Background(), cands) + if !errors.Is(err, ErrNoAcceptableCandidate) { + t.Fatalf("expected ErrNoAcceptableCandidate, got %v", err) + } +} + +func TestCostOnly_AllAcceptableEqualCost(t *testing.T) { + s := NewCostOnlyStrategy() + cands := candidates( + ProviderCandidate{ProviderName: "beta", Pricing: pricing(5, 5)}, + ProviderCandidate{ProviderName: "alpha", Pricing: pricing(5, 5)}, + ) + got := mustSelect(t, s, cands) + if got.ProviderName != "alpha" { + t.Fatalf("expected alpha (tiebreak equal cost), got %s", got.ProviderName) + } +} + +func TestLatencyOnly_AllAcceptableEqualLatency(t *testing.T) { + s := NewLatencyOnlyStrategy() + cands := candidates( + ProviderCandidate{ProviderName: "zeta", Latency: 100 * time.Millisecond}, + ProviderCandidate{ProviderName: "alpha", Latency: 100 * time.Millisecond}, + ) + got := mustSelect(t, s, cands) + if got.ProviderName != "alpha" { + t.Fatalf("expected alpha (tiebreak equal latency), got %s", got.ProviderName) + } +} diff --git a/internal/server/http.go b/internal/server/http.go index 2603c1cc..6b5cb3aa 100644 --- a/internal/server/http.go +++ b/internal/server/http.go @@ -287,6 +287,10 @@ func New(provider core.RoutableProvider, cfg *Config) *Server { e.Use(AuthMiddlewareWithAuthenticator(cfg.MasterKey, cfg.Authenticator, authSkipPaths, userPathHeaderName)) } + // Per-request intelligent-routing overrides (X-GoModel-Routing-Strategy / + // X-GoModel-Routing-Weights). A no-op header-lookup when the headers are absent. + e.Use(RoutingStrategyCapture()) + // Workflow resolution resolves the request-scoped workflow after auth so // managed auth key user-path overrides are visible to policy resolution while // still keeping workflow resolution failures loggable through the audit middleware. diff --git a/internal/server/routing_strategy.go b/internal/server/routing_strategy.go new file mode 100644 index 00000000..1932b8fa --- /dev/null +++ b/internal/server/routing_strategy.go @@ -0,0 +1,48 @@ +package server + +import ( + "log/slog" + "strings" + + router "gomodel/internal/router" + + "github.com/labstack/echo/v5" +) + +// RoutingStrategyHeader carries the per-request routing-strategy override. +const RoutingStrategyHeader = "X-GoModel-Routing-Strategy" + +// RoutingWeightsHeader carries the per-request balanced-strategy weight +// override as "cost,latency" (e.g. "0.6,0.4"). +const RoutingWeightsHeader = "X-GoModel-Routing-Weights" + +// RoutingStrategyCapture reads the optional routing-strategy/weights headers +// and attaches them to the request context as overrides for the router. Invalid +// values are silently dropped here with a warning; the router falls back to the +// configured global strategy. The middleware is a no-op when the headers are +// absent, so non-routing requests pay only a header-lookup cost. +func RoutingStrategyCapture() echo.MiddlewareFunc { + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c *echo.Context) error { + ctx := c.Request().Context() + + if strategy := strings.TrimSpace(c.Request().Header.Get(RoutingStrategyHeader)); strategy != "" { + ctx = router.WithStrategyOverride(ctx, strategy) + } + + if weights := strings.TrimSpace(c.Request().Header.Get(RoutingWeightsHeader)); weights != "" { + if override, err := router.ParseWeights(weights); err != nil { + slog.WarnContext(ctx, "invalid routing weights header, ignoring", + "header", weights, "error", err) + } else { + ctx = router.WithWeightsOverride(ctx, override) + } + } + + if ctx != c.Request().Context() { + c.SetRequest(c.Request().WithContext(ctx)) + } + return next(c) + } + } +}