diff --git a/.env.template b/.env.template
index c940f010..7ca02b5a 100644
--- a/.env.template
+++ b/.env.template
@@ -120,6 +120,17 @@
# fallback (default): use configured models only when upstream /models fails, is nil, or is empty.
# allowlist: expose only the configured models for providers that define a list, and skip their upstream /models calls.
# CONFIGURED_PROVIDER_MODELS_MODE=fallback
+
+# --- Intelligent routing ---
+# When a request does not pin a provider and multiple providers serve the same
+# model, the gateway scores candidates and routes to the best one. Single-provider
+# configs are unaffected (one candidate => no decision needed).
+# Strategy: balanced (default), cost_only, latency_only, first_fit
+# MODEL_ROUTING_STRATEGY=balanced
+# Balanced weights as "cost_weight,latency_weight" (only affects balanced).
+# MODEL_ROUTING_STRATEGY_WEIGHTS=0.6,0.4
+# Filter candidates at/above this smoothed error rate [0,1] (default: 0.5).
+# MODEL_ROUTING_MAX_ERROR_RATE=0.5
# Examples: OPENROUTER_MODELS=..., OPENROUTER_EU_MODELS=..., AZURE_MODELS=..., VLLM_MODELS=...
# Fallback & Workflow Configuration
diff --git a/.gitignore b/.gitignore
index 4fbc41a2..1625518e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -37,3 +37,4 @@
# Superpower design docs and plans (never commit)
/docs/superpowers/
+.superpowers/
diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 00000000..f6906f2e
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,10 @@
+# 默认忽略的文件
+/shelf/
+/workspace.xml
+# 基于编辑器的 HTTP 客户端请求
+/httpRequests/
+# 已忽略包含查询文件的默认文件夹
+/queries/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
diff --git a/.idea/GoModel.iml b/.idea/GoModel.iml
new file mode 100644
index 00000000..5e764c4f
--- /dev/null
+++ b/.idea/GoModel.iml
@@ -0,0 +1,9 @@
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/go.imports.xml b/.idea/go.imports.xml
new file mode 100644
index 00000000..644cdf0b
--- /dev/null
+++ b/.idea/go.imports.xml
@@ -0,0 +1,10 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/golinter.xml b/.idea/golinter.xml
new file mode 100644
index 00000000..1ccf3ec6
--- /dev/null
+++ b/.idea/golinter.xml
@@ -0,0 +1,7 @@
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 00000000..b6bc1834
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 00000000..35eb1ddf
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/CLAUDE.md b/CLAUDE.md
index 1bbb778b..feb32791 100644
--- a/CLAUDE.md
+++ b/CLAUDE.md
@@ -131,3 +131,4 @@ Full reference: `.env.template` and `config/config.yaml`
- **Guardrails:** Configured via `config/config.yaml` only (except `GUARDRAILS_ENABLED` env var)
- **Providers:** `OPENAI_API_KEY`, `ANTHROPIC_API_KEY`, `GEMINI_API_KEY`, `USE_GOOGLE_GEMINI_NATIVE_API` (true by default; false uses Gemini's OpenAI-compatible chat API), `XAI_API_KEY`, `GROQ_API_KEY`, `OPENROUTER_API_KEY`, `ZAI_API_KEY`, `ZAI_BASE_URL` (optional Z.ai endpoint override), `MINIMAX_API_KEY`, `MINIMAX_BASE_URL` (optional MiniMax endpoint override), `XIAOMI_API_KEY`, `XIAOMI_BASE_URL` (optional Xiaomi MiMo endpoint override), `OPENCODE_GO_API_KEY`, `OPENCODE_GO_BASE_URL` (optional OpenCode Go/Zen endpoint override; default `https://opencode.ai/zen/go/v1`), `OPENCODE_GO_MESSAGES_MODELS` (optional comma-separated model IDs routed to the Anthropic-native `/messages` endpoint instead of `/chat/completions`; default `qwen3.7-max`), `BAILIAN_API_KEY`, `BAILIAN_BASE_URL` (optional Bailian base URL for region switching; default `https://dashscope.aliyuncs.com/compatible-mode/v1`), `AZURE_API_KEY`, `AZURE_BASE_URL` (Azure OpenAI deployment base URL), `AZURE_API_VERSION` (optional Azure API version), `ORACLE_API_KEY` (Oracle API key), `ORACLE_BASE_URL` (Oracle OpenAI-compatible base URL), `[_SUFFIX]_MODELS` (comma-separated configured model list for any provider type), `OLLAMA_BASE_URL`, `VLLM_BASE_URL`, `VLLM_API_KEY` (optional upstream vLLM bearer token)
- **Provider model metadata:** `providers..models` accepts either model IDs (strings) or `{id, metadata}` objects. When `metadata` is supplied (`display_name`, `context_window`, `max_output_tokens`, `modes`, `capabilities`, `pricing`, …) it is merged onto the remote ai-model-list entry during enrichment, with operator values winning per-field. Primary use case: advertising context windows, capabilities, and pricing for local models (Ollama) and other custom endpoints whose IDs are not in the upstream registry.
+- **Intelligent routing:** When a request does not pin a provider and multiple configured providers serve the same model ID, the gateway scores candidates and routes to the best one (single-provider configs are unaffected — one candidate means no decision). `MODEL_ROUTING_STRATEGY` (`balanced` default; also `cost_only`, `latency_only`, `first_fit`), `MODEL_ROUTING_STRATEGY_WEIGHTS` (`0.6,0.4` = cost_weight,latency_weight; balanced only), `MODEL_ROUTING_MAX_ERROR_RATE` (`0.5`; candidates at/above this smoothed error rate are filtered before scoring). Latency is tracked per provider as in-memory EWMA (alpha 0.1 p50 / 0.05 p99 / 0.2 error-rate), never persisted. Per-request overrides via request headers `X-GoModel-Routing-Strategy` and `X-GoModel-Routing-Weights` (`cost,latency`); invalid header values are ignored with a WARN log, and the request falls back to the global strategy. Explicit provider-qualified requests (`openai-east/gpt-4o`) always bypass intelligent routing.
diff --git a/c.out b/c.out
new file mode 100644
index 00000000..2be45a15
--- /dev/null
+++ b/c.out
@@ -0,0 +1,35 @@
+mode: set
+gomodel/internal/providers/bailian/bailian.go:39.86,48.2 2 1
+gomodel/internal/providers/bailian/bailian.go:52.97,60.2 1 1
+gomodel/internal/providers/bailian/bailian.go:62.51,64.2 1 1
+gomodel/internal/providers/bailian/bailian.go:67.43,69.2 1 1
+gomodel/internal/providers/bailian/bailian.go:73.107,75.2 1 1
+gomodel/internal/providers/bailian/bailian.go:78.108,80.2 1 1
+gomodel/internal/providers/bailian/bailian.go:83.82,85.2 1 1
+gomodel/internal/providers/bailian/bailian.go:88.112,90.2 1 1
+gomodel/internal/providers/bailian/bailian.go:93.108,95.2 1 1
+gomodel/internal/providers/bailian/bailian.go:101.113,103.2 1 1
+gomodel/internal/providers/bailian/bailian.go:106.118,108.2 1 1
+gomodel/internal/providers/bailian/bailian.go:111.106,113.2 1 1
+gomodel/internal/providers/bailian/bailian.go:116.90,118.2 1 1
+gomodel/internal/providers/bailian/bailian.go:121.111,123.2 1 1
+gomodel/internal/providers/bailian/bailian.go:126.93,128.2 1 1
+gomodel/internal/providers/bailian/bailian.go:131.104,133.2 1 1
+gomodel/internal/providers/bailian/bailian.go:136.107,138.16 2 1
+gomodel/internal/providers/bailian/bailian.go:138.16,140.3 1 0
+gomodel/internal/providers/bailian/bailian.go:141.2,142.18 2 1
+gomodel/internal/providers/bailian/bailian.go:146.124,148.16 2 1
+gomodel/internal/providers/bailian/bailian.go:148.16,150.3 1 0
+gomodel/internal/providers/bailian/bailian.go:151.2,151.27 1 1
+gomodel/internal/providers/bailian/bailian.go:151.27,153.3 1 0
+gomodel/internal/providers/bailian/bailian.go:154.2,154.18 1 1
+gomodel/internal/providers/bailian/bailian.go:158.86,160.16 2 1
+gomodel/internal/providers/bailian/bailian.go:160.16,162.3 1 0
+gomodel/internal/providers/bailian/bailian.go:163.2,164.18 2 1
+gomodel/internal/providers/bailian/bailian.go:168.97,170.2 1 1
+gomodel/internal/providers/bailian/bailian.go:173.102,175.2 1 1
+gomodel/internal/providers/bailian/bailian.go:189.67,190.40 1 1
+gomodel/internal/providers/bailian/bailian.go:190.40,192.3 1 1
+gomodel/internal/providers/bailian/bailian.go:194.2,194.82 1 1
+gomodel/internal/providers/bailian/bailian.go:194.82,198.3 3 1
+gomodel/internal/providers/bailian/bailian.go:199.2,206.16 6 1
diff --git a/config/config.example.yaml b/config/config.example.yaml
index ad5b619b..305b7243 100644
--- a/config/config.example.yaml
+++ b/config/config.example.yaml
@@ -356,3 +356,15 @@ providers:
# input_per_mtok: 0
# output_per_mtok: 0
# - Gemma4-31B
+
+# --- Intelligent routing ---
+# When a request does not pin a provider and multiple providers serve the same
+# model, the gateway scores candidates and routes to the best one. Single-provider
+# configs are unaffected (one candidate means no decision needed).
+# Default: strategy=balanced, weights={cost:0.6 latency:0.4}, max_error_rate=0.5
+# router:
+# strategy: balanced # balanced, cost_only, latency_only, first_fit
+# weights:
+# cost: 0.6
+# latency: 0.4
+# max_error_rate: 0.5
diff --git a/config/config.go b/config/config.go
index 59bc5b3a..09fdb1a3 100644
--- a/config/config.go
+++ b/config/config.go
@@ -27,6 +27,7 @@ type Config struct {
Fallback FallbackConfig `yaml:"fallback"`
Workflows WorkflowsConfig `yaml:"workflows"`
Resilience ResilienceConfig `yaml:"resilience"`
+ Router RouterConfig `yaml:"router"`
}
// LoadResult is returned by Load and bundles the application config with the raw
@@ -120,6 +121,7 @@ func buildDefaultConfig() *Config {
Retry: DefaultRetryConfig(),
CircuitBreaker: DefaultCircuitBreakerConfig(),
},
+ Router: DefaultRouterConfig(),
Admin: AdminConfig{
EndpointsEnabled: true,
UIEnabled: true,
@@ -194,6 +196,10 @@ func Load() (*LoadResult, error) {
return nil, err
}
+ if err := ValidateRouterConfig(&cfg.Router); err != nil {
+ return nil, err
+ }
+
return &LoadResult{
Config: cfg,
RawProviders: rawProviders,
diff --git a/config/router.go b/config/router.go
new file mode 100644
index 00000000..4dfdc189
--- /dev/null
+++ b/config/router.go
@@ -0,0 +1,127 @@
+package config
+
+import (
+ "fmt"
+ "math"
+ "strconv"
+ "strings"
+)
+
+// RouterConfig configures intelligent provider selection. When a request does
+// not pin a provider and multiple providers serve the same model, the gateway
+// scores candidates by cost and/or latency using the configured strategy.
+type RouterConfig struct {
+ // Strategy is the default strategy id: "balanced", "cost_only",
+ // "latency_only", or "first_fit". Empty defaults to "balanced"; validation
+ // trims surrounding whitespace and lower-cases the value.
+ Strategy string `yaml:"strategy" json:"strategy" env:"MODEL_ROUTING_STRATEGY"`
+
+ // Weights tunes the balanced strategy. CostWeight and LatencyWeight are
+ // ignored by other strategies. Validation rejects negative or non-finite
+ // weights and resets an all-zero pair to 0.6/0.4 for the balanced strategy.
+ Weights RouterWeights `yaml:"weights" json:"weights"`
+
+ // WeightsCSV is the env-only form of Weights as "cost,latency"
+ // (e.g. "0.6,0.4"). When non-blank it is parsed during validation and
+ // replaces whatever Weights held.
+ WeightsCSV string `yaml:"-" json:"-" env:"MODEL_ROUTING_STRATEGY_WEIGHTS"`
+
+ // MaxErrorRate filters candidates at/above this smoothed error ratio before
+ // scoring; validation requires a finite value in [0, 1].
+ // NOTE(review): validation leaves an explicit 0 untouched, so the 0.5
+ // fallback for zero must be applied by the strategy itself — confirm.
+ MaxErrorRate float64 `yaml:"max_error_rate" json:"max_error_rate" env:"MODEL_ROUTING_MAX_ERROR_RATE"`
+}
+
+// RouterWeights tunes the balanced strategy's cost/latency trade-off. Each
+// field can be set on its own; a non-blank MODEL_ROUTING_STRATEGY_WEIGHTS
+// value replaces both during validation.
+type RouterWeights struct {
+ Cost float64 `yaml:"cost" json:"cost" env:"MODEL_ROUTING_COST_WEIGHT"`
+ Latency float64 `yaml:"latency" json:"latency" env:"MODEL_ROUTING_LATENCY_WEIGHT"`
+}
+
+// Built-in routing strategy ids. These are the only values accepted for
+// RouterConfig.Strategy.
+const (
+ RouterStrategyBalanced = "balanced"
+ RouterStrategyCostOnly = "cost_only"
+ RouterStrategyLatencyOnly = "latency_only"
+ RouterStrategyFirstFit = "first_fit"
+)
+
+// DefaultRouterConfig returns the default router configuration: balanced strategy
+// with 0.6 cost / 0.4 latency weights and a 0.5 max error rate.
+func DefaultRouterConfig() RouterConfig {
+ return RouterConfig{
+ Strategy: RouterStrategyBalanced,
+ Weights: RouterWeights{
+ Cost: 0.6,
+ Latency: 0.4,
+ },
+ MaxErrorRate: 0.5,
+ }
+}
+
+// ValidateRouterConfig normalizes cfg in place and rejects values the router
+// cannot use.
+//
+// Normalization: Strategy is trimmed and lower-cased, and an empty or
+// whitespace-only value becomes "balanced"; a non-blank WeightsCSV replaces
+// Weights; an all-zero weight pair under the balanced strategy is reset to the
+// weights from DefaultRouterConfig.
+//
+// It returns an error when cfg is nil, the strategy id is unknown, the weights
+// CSV cannot be parsed, a weight is negative or non-finite, or MaxErrorRate is
+// non-finite or outside [0, 1].
+func ValidateRouterConfig(cfg *RouterConfig) error {
+	if cfg == nil {
+		return fmt.Errorf("router config must not be nil")
+	}
+
+	// Trim before the emptiness check so a whitespace-only value (easy to
+	// produce through MODEL_ROUTING_STRATEGY) is defaulted like an unset one
+	// instead of being rejected as an unknown strategy.
+	strategy := strings.ToLower(strings.TrimSpace(cfg.Strategy))
+	if strategy == "" {
+		strategy = RouterStrategyBalanced
+	}
+	switch strategy {
+	case RouterStrategyBalanced, RouterStrategyCostOnly, RouterStrategyLatencyOnly, RouterStrategyFirstFit:
+	default:
+		return fmt.Errorf("invalid router.strategy %q: must be one of balanced, cost_only, latency_only, first_fit", cfg.Strategy)
+	}
+	cfg.Strategy = strategy
+
+	// MODEL_ROUTING_STRATEGY_WEIGHTS overrides YAML weights when set.
+	if strings.TrimSpace(cfg.WeightsCSV) != "" {
+		cost, lat, err := parseWeightsCSV(cfg.WeightsCSV)
+		if err != nil {
+			return fmt.Errorf("invalid MODEL_ROUTING_STRATEGY_WEIGHTS: %w", err)
+		}
+		cfg.Weights = RouterWeights{Cost: cost, Latency: lat}
+	}
+
+	if math.IsNaN(cfg.Weights.Cost) || math.IsNaN(cfg.Weights.Latency) ||
+		math.IsInf(cfg.Weights.Cost, 0) || math.IsInf(cfg.Weights.Latency, 0) {
+		return fmt.Errorf("router.weights must be finite numbers, got cost=%v latency=%v", cfg.Weights.Cost, cfg.Weights.Latency)
+	}
+	if cfg.Weights.Cost < 0 || cfg.Weights.Latency < 0 {
+		return fmt.Errorf("router.weights must be non-negative, got cost=%v latency=%v", cfg.Weights.Cost, cfg.Weights.Latency)
+	}
+	// A balanced score with both weights at zero cannot rank anything, so fall
+	// back to the defaults; taking them from DefaultRouterConfig keeps a single
+	// source of truth for the numbers.
+	if cfg.Weights.Cost == 0 && cfg.Weights.Latency == 0 && strategy == RouterStrategyBalanced {
+		cfg.Weights = DefaultRouterConfig().Weights
+	}
+	if math.IsNaN(cfg.MaxErrorRate) || math.IsInf(cfg.MaxErrorRate, 0) {
+		return fmt.Errorf("router.max_error_rate must be a finite number, got %v", cfg.MaxErrorRate)
+	}
+	if cfg.MaxErrorRate < 0 || cfg.MaxErrorRate > 1 {
+		return fmt.Errorf("router.max_error_rate must be in [0, 1], got %v", cfg.MaxErrorRate)
+	}
+	return nil
+}
+
+// parseWeightsCSV splits a "cost,latency" pair and converts both halves to
+// floats. Anything other than exactly two comma-separated fields, a field
+// that fails to parse or is non-finite, or a negative value yields an error
+// together with zero results.
+func parseWeightsCSV(s string) (cost, latency float64, err error) {
+	fields := strings.Split(s, ",")
+	if len(fields) != 2 {
+		return 0, 0, fmt.Errorf("expected two comma-separated weights, got %q", s)
+	}
+	c, err := parseFloatField(fields[0], "cost")
+	if err != nil {
+		return 0, 0, err
+	}
+	l, err := parseFloatField(fields[1], "latency")
+	if err != nil {
+		return 0, 0, err
+	}
+	if c < 0 || l < 0 {
+		return 0, 0, fmt.Errorf("weights must be non-negative, got %v,%v", c, l)
+	}
+	return c, l, nil
+}
+
+// parseFloatField converts one weight field to a float64 after trimming
+// surrounding whitespace. name labels the field in error messages. NaN and
+// infinities are rejected even when strconv.ParseFloat parses them cleanly.
+func parseFloatField(s, name string) (float64, error) {
+	parsed, parseErr := strconv.ParseFloat(strings.TrimSpace(s), 64)
+	switch {
+	case parseErr != nil:
+		return 0, fmt.Errorf("invalid %s weight %q: %w", name, s, parseErr)
+	case math.IsNaN(parsed), math.IsInf(parsed, 0):
+		return 0, fmt.Errorf("invalid %s weight %q: must be a finite number", name, s)
+	}
+	return parsed, nil
+}
diff --git a/coverage-bailian b/coverage-bailian
new file mode 100644
index 00000000..484e7fa7
--- /dev/null
+++ b/coverage-bailian
@@ -0,0 +1,35 @@
+mode: set
+gomodel/internal/providers/bailian/bailian.go:39.86,48.2 2 1
+gomodel/internal/providers/bailian/bailian.go:52.97,60.2 1 1
+gomodel/internal/providers/bailian/bailian.go:62.51,64.2 1 1
+gomodel/internal/providers/bailian/bailian.go:67.43,69.2 1 1
+gomodel/internal/providers/bailian/bailian.go:73.107,75.2 1 1
+gomodel/internal/providers/bailian/bailian.go:78.108,80.2 1 1
+gomodel/internal/providers/bailian/bailian.go:83.82,85.2 1 1
+gomodel/internal/providers/bailian/bailian.go:88.112,90.2 1 1
+gomodel/internal/providers/bailian/bailian.go:93.108,95.2 1 1
+gomodel/internal/providers/bailian/bailian.go:101.113,103.2 1 1
+gomodel/internal/providers/bailian/bailian.go:106.118,108.2 1 1
+gomodel/internal/providers/bailian/bailian.go:111.106,113.2 1 0
+gomodel/internal/providers/bailian/bailian.go:116.90,118.2 1 0
+gomodel/internal/providers/bailian/bailian.go:121.111,123.2 1 0
+gomodel/internal/providers/bailian/bailian.go:126.93,128.2 1 0
+gomodel/internal/providers/bailian/bailian.go:131.104,133.2 1 0
+gomodel/internal/providers/bailian/bailian.go:136.107,138.16 2 0
+gomodel/internal/providers/bailian/bailian.go:138.16,140.3 1 0
+gomodel/internal/providers/bailian/bailian.go:141.2,142.18 2 0
+gomodel/internal/providers/bailian/bailian.go:146.124,148.16 2 0
+gomodel/internal/providers/bailian/bailian.go:148.16,150.3 1 0
+gomodel/internal/providers/bailian/bailian.go:151.2,151.27 1 0
+gomodel/internal/providers/bailian/bailian.go:151.27,153.3 1 0
+gomodel/internal/providers/bailian/bailian.go:154.2,154.18 1 0
+gomodel/internal/providers/bailian/bailian.go:158.86,160.16 2 0
+gomodel/internal/providers/bailian/bailian.go:160.16,162.3 1 0
+gomodel/internal/providers/bailian/bailian.go:163.2,164.18 2 0
+gomodel/internal/providers/bailian/bailian.go:168.97,170.2 1 0
+gomodel/internal/providers/bailian/bailian.go:173.102,175.2 1 0
+gomodel/internal/providers/bailian/bailian.go:189.67,190.40 1 1
+gomodel/internal/providers/bailian/bailian.go:190.40,192.3 1 1
+gomodel/internal/providers/bailian/bailian.go:194.2,194.82 1 1
+gomodel/internal/providers/bailian/bailian.go:194.82,198.3 3 1
+gomodel/internal/providers/bailian/bailian.go:199.2,206.16 6 1
diff --git a/coverage-bailian.out b/coverage-bailian.out
new file mode 100644
index 00000000..484e7fa7
--- /dev/null
+++ b/coverage-bailian.out
@@ -0,0 +1,35 @@
+mode: set
+gomodel/internal/providers/bailian/bailian.go:39.86,48.2 2 1
+gomodel/internal/providers/bailian/bailian.go:52.97,60.2 1 1
+gomodel/internal/providers/bailian/bailian.go:62.51,64.2 1 1
+gomodel/internal/providers/bailian/bailian.go:67.43,69.2 1 1
+gomodel/internal/providers/bailian/bailian.go:73.107,75.2 1 1
+gomodel/internal/providers/bailian/bailian.go:78.108,80.2 1 1
+gomodel/internal/providers/bailian/bailian.go:83.82,85.2 1 1
+gomodel/internal/providers/bailian/bailian.go:88.112,90.2 1 1
+gomodel/internal/providers/bailian/bailian.go:93.108,95.2 1 1
+gomodel/internal/providers/bailian/bailian.go:101.113,103.2 1 1
+gomodel/internal/providers/bailian/bailian.go:106.118,108.2 1 1
+gomodel/internal/providers/bailian/bailian.go:111.106,113.2 1 0
+gomodel/internal/providers/bailian/bailian.go:116.90,118.2 1 0
+gomodel/internal/providers/bailian/bailian.go:121.111,123.2 1 0
+gomodel/internal/providers/bailian/bailian.go:126.93,128.2 1 0
+gomodel/internal/providers/bailian/bailian.go:131.104,133.2 1 0
+gomodel/internal/providers/bailian/bailian.go:136.107,138.16 2 0
+gomodel/internal/providers/bailian/bailian.go:138.16,140.3 1 0
+gomodel/internal/providers/bailian/bailian.go:141.2,142.18 2 0
+gomodel/internal/providers/bailian/bailian.go:146.124,148.16 2 0
+gomodel/internal/providers/bailian/bailian.go:148.16,150.3 1 0
+gomodel/internal/providers/bailian/bailian.go:151.2,151.27 1 0
+gomodel/internal/providers/bailian/bailian.go:151.27,153.3 1 0
+gomodel/internal/providers/bailian/bailian.go:154.2,154.18 1 0
+gomodel/internal/providers/bailian/bailian.go:158.86,160.16 2 0
+gomodel/internal/providers/bailian/bailian.go:160.16,162.3 1 0
+gomodel/internal/providers/bailian/bailian.go:163.2,164.18 2 0
+gomodel/internal/providers/bailian/bailian.go:168.97,170.2 1 0
+gomodel/internal/providers/bailian/bailian.go:173.102,175.2 1 0
+gomodel/internal/providers/bailian/bailian.go:189.67,190.40 1 1
+gomodel/internal/providers/bailian/bailian.go:190.40,192.3 1 1
+gomodel/internal/providers/bailian/bailian.go:194.2,194.82 1 1
+gomodel/internal/providers/bailian/bailian.go:194.82,198.3 3 1
+gomodel/internal/providers/bailian/bailian.go:199.2,206.16 6 1
diff --git a/gomodel.exe b/gomodel.exe
new file mode 100644
index 00000000..c4524e69
Binary files /dev/null and b/gomodel.exe differ
diff --git a/gomodel.exe~ b/gomodel.exe~
new file mode 100644
index 00000000..cc19afec
Binary files /dev/null and b/gomodel.exe~ differ
diff --git a/internal/core/interfaces.go b/internal/core/interfaces.go
index a9ae55d5..2f030b91 100644
--- a/internal/core/interfaces.go
+++ b/internal/core/interfaces.go
@@ -4,6 +4,7 @@ package core
import (
"context"
"io"
+ "time"
)
// Provider defines the interface for LLM providers
@@ -184,6 +185,30 @@ type AvailabilityChecker interface {
CheckAvailability(ctx context.Context) error
}
+// ProviderStats is an optional interface implemented by provider clients that
+// expose in-memory runtime statistics for latency-aware routing. Routers obtain
+// it via a type assertion; providers without instrumentation simply omit it.
+//
+// Latency values are best-effort EWMA approximations (not true percentiles) and
+// are intended only for relative comparison across providers. The numeric
+// accessors return 0 before any request has been observed, which strategies
+// interpret as "unknown". CircuitState has no "unknown" value: it always
+// returns one of the three states listed on the method.
+type ProviderStats interface {
+ // P50Latency returns the approximate P50 round-trip latency, or 0 if unknown.
+ P50Latency() time.Duration
+
+ // P99Latency returns the approximate P99 round-trip latency, or 0 if unknown.
+ P99Latency() time.Duration
+
+ // ErrorRate returns the smoothed error ratio in [0, 1], or 0 if unknown.
+ ErrorRate() float64
+
+ // CircuitState returns the provider's circuit breaker state as a string:
+ // "closed", "open", or "half-open". Implementations without a circuit
+ // breaker return "closed".
+ CircuitState() string
+}
+
// ModelLookup defines the interface for looking up models and their providers.
// This abstraction allows the Router to be decoupled from the concrete ModelRegistry implementation.
type ModelLookup interface {
diff --git a/internal/gateway/inference_prepare.go b/internal/gateway/inference_prepare.go
index 6246201c..855a84de 100644
--- a/internal/gateway/inference_prepare.go
+++ b/internal/gateway/inference_prepare.go
@@ -5,6 +5,7 @@ import (
"strings"
"gomodel/internal/core"
+ router "gomodel/internal/router"
)
// PrepareChatRequest resolves workflow/model policy and applies translated request patching.
@@ -93,6 +94,18 @@ func prepareTranslatedRequest[Req any](
patchNilMessage string,
) (context.Context, Req, *core.Workflow, error) {
ctx = contextWithRequestID(ctx, meta.RequestID)
+
+ // When the caller did not pin a provider (neither via req.Provider nor via
+ // a "provider/model" model string), mark this request as eligible for
+ // intelligent routing. The prepare step resolves a concrete provider, but
+ // the execution path (resolveProvider) uses this flag to know that the
+ // provider was system-assigned and can be re-evaluated by strategy.
+ if provider != nil && strings.TrimSpace(*provider) == "" {
+ if !hasProviderInModel(model) {
+ ctx = router.WithRoutingEligible(ctx)
+ }
+ }
+
workflow, err := o.ensureTranslatedRequestWorkflow(ctx, meta.Workflow, meta.RequestID, meta.Endpoint, model, provider)
if err != nil {
var zero Req
@@ -215,6 +228,20 @@ func currentTranslatedWorkflow(workflow *core.Workflow, endpoint core.EndpointDe
return workflow
}
+// hasProviderInModel reports whether the trimmed model string has the shape
+// "prefix/rest" with at least one byte on each side of the first slash
+// (e.g. "deepseek/deepseek-v4-flash"). The caller treats that shape as an
+// explicit provider choice made through the model syntax even though
+// req.Provider is empty. A nil model returns false.
+//
+// NOTE(review): the check is purely syntactic, so a slash-containing model ID
+// such as "meta-llama/Llama-3-70b" also returns true and is therefore never
+// marked routing-eligible — confirm this is intended.
+func hasProviderInModel(model *string) bool {
+ if model == nil {
+ return false
+ }
+ m := strings.TrimSpace(*model)
+ firstSlash := strings.IndexByte(m, '/')
+ // True only when the first slash is neither the first nor the last byte.
+ return firstSlash > 0 && firstSlash < len(m)-1
+}
+
// ApplyResolvedSelector updates request model/provider fields to the resolved selector.
func ApplyResolvedSelector(model, providerHint *string, resolution *core.RequestModelResolution) {
if model == nil || providerHint == nil || resolution == nil {
diff --git a/internal/llmclient/client.go b/internal/llmclient/client.go
index 01394f9a..6021d7c5 100644
--- a/internal/llmclient/client.go
+++ b/internal/llmclient/client.go
@@ -93,14 +93,18 @@ type Client struct {
config Config
headerSetter HeaderSetter
circuitBreaker *circuitBreaker
+ latencyTracker *LatencyTracker
}
-// New creates a new LLM client with the given configuration
-func New(cfg Config, headerSetter HeaderSetter) *Client {
+// newLLMClient is the shared constructor core used by New and NewWithHTTPClient.
+// Every client gets a LatencyTracker so latency-aware routing always has a
+// stats source to assert, even before the first request.
+func newLLMClient(httpClient *http.Client, cfg Config, headerSetter HeaderSetter) *Client {
c := &Client{
- httpClient: httpclient.NewDefaultHTTPClient(),
- config: cfg,
- headerSetter: headerSetter,
+ httpClient: httpClient,
+ config: cfg,
+ headerSetter: headerSetter,
+ latencyTracker: NewLatencyTracker(),
}
if cfg.CircuitBreaker.FailureThreshold > 0 {
@@ -114,23 +118,14 @@ func New(cfg Config, headerSetter HeaderSetter) *Client {
return c
}
+// New creates an LLM client for cfg that sends requests through
+// httpclient.NewDefaultHTTPClient(). headerSetter is stored on the client
+// unchanged.
+func New(cfg Config, headerSetter HeaderSetter) *Client {
+ return newLLMClient(httpclient.NewDefaultHTTPClient(), cfg, headerSetter)
+}
+
// NewWithHTTPClient creates a new LLM client with a custom HTTP client
func NewWithHTTPClient(httpClient *http.Client, cfg Config, headerSetter HeaderSetter) *Client {
- c := &Client{
- httpClient: httpClient,
- config: cfg,
- headerSetter: headerSetter,
- }
-
- if cfg.CircuitBreaker.FailureThreshold > 0 {
- c.circuitBreaker = newCircuitBreaker(
- cfg.CircuitBreaker.FailureThreshold,
- cfg.CircuitBreaker.SuccessThreshold,
- cfg.CircuitBreaker.Timeout,
- )
- }
-
- return c
+ return newLLMClient(httpClient, cfg, headerSetter)
}
// SetBaseURL updates the base URL (thread-safe)
@@ -235,9 +230,37 @@ func (c *Client) finishRequest(scope requestScope, statusCode int, err error) {
// Use this whenever a code path returns from one of the public Do* methods.
func (c *Client) completeScope(scope requestScope, statusCode int, err, cbErr error) {
c.recordCircuitBreakerCompletion(statusCode, cbErr)
+ c.latencyTracker.Record(time.Since(scope.startedAt), err != nil)
c.finishRequest(scope, statusCode, err)
}
+// P50Latency returns the approximate P50 round-trip latency for this client's
+// provider, satisfying core.ProviderStats. The value comes from the client's
+// latency tracker and is 0 before any sample has been recorded.
+func (c *Client) P50Latency() time.Duration {
+ return c.latencyTracker.P50()
+}
+
+// P99Latency returns the approximate P99 round-trip latency for this client's
+// provider, satisfying core.ProviderStats. The value comes from the client's
+// latency tracker and is 0 before any sample has been recorded.
+func (c *Client) P99Latency() time.Duration {
+ return c.latencyTracker.P99()
+}
+
+// ErrorRate returns the smoothed error ratio in [0, 1], satisfying
+// core.ProviderStats. The value comes from the client's latency tracker and
+// is 0 before any sample has been recorded.
+func (c *Client) ErrorRate() float64 {
+ return c.latencyTracker.ErrorRate()
+}
+
+// CircuitState reports the circuit breaker's current state string so the
+// client satisfies core.ProviderStats. A client built without a breaker
+// always reports "closed".
+func (c *Client) CircuitState() string {
+	if breaker := c.circuitBreaker; breaker != nil {
+		return breaker.State()
+	}
+	return "closed"
+}
+
// failAfterRetries handles the "exhausted retries with no captured error"
// fallback shared by the retrying entry points (DoRaw, DoPassthrough). The
// returned error is also reported through the scope.
diff --git a/internal/llmclient/latency.go b/internal/llmclient/latency.go
new file mode 100644
index 00000000..45dd1e2b
--- /dev/null
+++ b/internal/llmclient/latency.go
@@ -0,0 +1,118 @@
+package llmclient
+
+import (
+ "sync"
+ "time"
+)
+
+// defaultLatencyAlpha is the EWMA smoothing factor used for p50 latency.
+// A larger alpha makes the average more responsive to recent samples; a
+// smaller alpha smooths more heavily and reacts more slowly.
+const defaultLatencyAlpha = 0.1
+
+// ewma is a concurrency-safe exponentially weighted moving average.
+// The first sample initializes the value directly (tracked by the
+// initialized bool so that a genuine zero-valued sample is not
+// mistaken for the uninitialized state); subsequent samples are
+// smoothed using value = alpha*sample + (1-alpha)*value.
+type ewma struct {
+ mu sync.Mutex // held by Add and Value around every field access
+ value float64 // current average; zero until the first sample
+ alpha float64 // weight given to each new sample
+ initialized bool // set once the first sample has seeded value
+}
+
+// newEWMA returns an EWMA with the given smoothing factor alpha.
+// alpha must be in the range (0, 1]; values closer to 1 weight recent
+// samples more heavily. The range is not checked here, so callers are
+// responsible for passing a valid factor.
+func newEWMA(alpha float64) *ewma {
+ return &ewma{alpha: alpha}
+}
+
+// Add folds sample into the moving average under the mutex. Once seeded, the
+// value is blended as alpha*sample + (1-alpha)*value; the very first sample is
+// stored as-is so the average does not start biased toward zero.
+func (e *ewma) Add(sample float64) {
+	e.mu.Lock()
+	defer e.mu.Unlock()
+	if e.initialized {
+		e.value = e.alpha*sample + (1-e.alpha)*e.value
+		return
+	}
+	e.value, e.initialized = sample, true
+}
+
+// Value returns the current EWMA value. Returns zero before any sample
+// has been added; callers cannot distinguish that from an average that is
+// genuinely zero, because the initialized flag is not exposed.
+func (e *ewma) Value() float64 {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ return e.value
+}
+
+// LatencyTracker maintains EWMA-based latency and error-rate statistics
+// for a provider or model. It is intended to feed latency-aware routing
+// decisions with cheap, in-memory, concurrency-safe signals.
+//
+// Statistics are kept in memory only: there is no persistence and no
+// periodic cleanup. Callers should call Record once per completed request.
+//
+// Smoothing factors:
+// - p50 uses alpha=0.1, so it follows recent samples more closely
+// - p99 uses alpha=0.05, so it moves more slowly
+// - errorRate uses alpha=0.2 to react quickly to error bursts
+//
+// Note: Record feeds the same duration into both latency averages, so "p50"
+// and "p99" are two differently smoothed means rather than percentile
+// estimators, and p99 does not isolate tail latency. They are sufficient for
+// relative comparison across providers and not intended for SLA reporting.
+type LatencyTracker struct {
+ p50 *ewma
+ p99 *ewma
+ errorRate *ewma
+}
+
+// NewLatencyTracker builds a tracker with the routing defaults: alpha 0.1 for
+// p50, half of that for p99, and 0.2 for the error rate.
+func NewLatencyTracker() *LatencyTracker {
+	tracker := new(LatencyTracker)
+	tracker.p50 = newEWMA(defaultLatencyAlpha)
+	tracker.p99 = newEWMA(defaultLatencyAlpha / 2)
+	tracker.errorRate = newEWMA(0.2)
+	return tracker
+}
+
+// Record observes the duration and error status of a completed request.
+// duration is the total request latency; isError indicates whether the
+// request was considered failed for routing purposes. The duration is fed to
+// both latency averages in milliseconds, and isError to the error-rate
+// average as 1 or 0.
+func (t *LatencyTracker) Record(duration time.Duration, isError bool) {
+	// Keep the fractional part: Duration.Milliseconds truncates to a whole
+	// number, which would record every sub-millisecond request (for example a
+	// refused local connection) as 0, the value readers treat as "unknown".
+	ms := float64(duration) / float64(time.Millisecond)
+	t.p50.Add(ms)
+	t.p99.Add(ms)
+	t.errorRate.Add(boolToFloat64(isError))
+}
+
+// P50 returns the approximate p50 latency as a time.Duration.
+// Returns zero before any sample has been recorded.
+func (t *LatencyTracker) P50() time.Duration {
+	// Scale to nanoseconds before converting: converting the float first
+	// would drop the fractional millisecond that smoothing produces.
+	return time.Duration(t.p50.Value() * float64(time.Millisecond))
+}
+
+// P99 returns the approximate p99 latency as a time.Duration.
+// Returns zero before any sample has been recorded.
+func (t *LatencyTracker) P99() time.Duration {
+	// Scale to nanoseconds before converting: converting the float first
+	// would drop the fractional millisecond that smoothing produces.
+	return time.Duration(t.p99.Value() * float64(time.Millisecond))
+}
+
+// ErrorRate returns the EWMA-smoothed error rate in the range [0, 1].
+// Returns zero before any sample has been recorded, which is also the value
+// after an unbroken run of successes.
+func (t *LatencyTracker) ErrorRate() float64 {
+ return t.errorRate.Value()
+}
+
+// boolToFloat64 converts a boolean to 1.0 (true) or 0.0 (false) for use
+// as an EWMA sample.
+func boolToFloat64(b bool) float64 {
+ if b {
+ return 1
+ }
+ return 0
+}
diff --git a/internal/llmclient/latency_test.go b/internal/llmclient/latency_test.go
new file mode 100644
index 00000000..1bc0a499
--- /dev/null
+++ b/internal/llmclient/latency_test.go
@@ -0,0 +1,65 @@
+package llmclient
+
+import (
+ "testing"
+ "time"
+)
+
+// TestLatencyTracker_FirstSampleSeedsValue checks that the very first
+// recorded latency is reported back unchanged by P50.
+func TestLatencyTracker_FirstSampleSeedsValue(t *testing.T) {
+	const seed = 100 * time.Millisecond
+	tracker := NewLatencyTracker()
+	tracker.Record(seed, false)
+	got := tracker.P50()
+	if got != seed {
+		t.Fatalf("first sample should seed value, got %v", got)
+	}
+}
+
+// TestLatencyTracker_ConvergesTowardRecentSamples seeds the tracker with one
+// slow sample, then feeds many fast ones and expects P50 to drop well below
+// the seed.
+func TestLatencyTracker_ConvergesTowardRecentSamples(t *testing.T) {
+	const ceiling = 200 * time.Millisecond
+	tracker := NewLatencyTracker()
+	tracker.Record(1000*time.Millisecond, false) // seed
+	for remaining := 50; remaining > 0; remaining-- {
+		tracker.Record(10*time.Millisecond, false)
+	}
+	got := tracker.P50()
+	if got > ceiling {
+		t.Fatalf("expected p50 to converge toward 10ms, got %v", got)
+	}
+}
+
+// TestLatencyTracker_ErrorRateReflectsFailures alternates failing and
+// succeeding samples and checks that the smoothed error rate lands in a wide
+// band that excludes both "no errors" and "all errors".
+func TestLatencyTracker_ErrorRateReflectsFailures(t *testing.T) {
+	tr := NewLatencyTracker()
+	for i := 0; i < 5; i++ {
+		tr.Record(10*time.Millisecond, i%2 == 0) // 3 errors of 5
+	}
+	rate := tr.ErrorRate()
+	// The failure message states the band that is actually asserted.
+	if rate < 0.3 || rate > 0.9 {
+		t.Fatalf("expected error rate within [0.3, 0.9], got %v", rate)
+	}
+}
+
+// TestLatencyTracker_P50BeforeSamples checks that a tracker which has seen
+// no samples reports zero for both latency and error rate.
+func TestLatencyTracker_P50BeforeSamples(t *testing.T) {
+	fresh := NewLatencyTracker()
+	latency := fresh.P50()
+	if latency != 0 {
+		t.Fatalf("expected zero p50 before samples, got %v", latency)
+	}
+	rate := fresh.ErrorRate()
+	if rate != 0 {
+		t.Fatalf("expected zero error rate before samples, got %v", rate)
+	}
+}
+
+// TestLatencyTracker_ConcurrentSafe drives one tracker from several
+// goroutines at once. It makes no assertions of its own; it only has to
+// finish without crashing.
+func TestLatencyTracker_ConcurrentSafe(t *testing.T) {
+	const workers = 10
+	tracker := NewLatencyTracker()
+	finished := make(chan struct{})
+	hammer := func() {
+		for n := 0; n < 100; n++ {
+			tracker.Record(time.Millisecond, false)
+			_ = tracker.P50()
+			_ = tracker.ErrorRate()
+		}
+		finished <- struct{}{}
+	}
+	for w := 0; w < workers; w++ {
+		go hammer()
+	}
+	// Wait for every worker to report completion.
+	for w := 0; w < workers; w++ {
+		<-finished
+	}
+}
diff --git a/internal/providers/init.go b/internal/providers/init.go
index 2f9fe573..f4987014 100644
--- a/internal/providers/init.go
+++ b/internal/providers/init.go
@@ -147,7 +147,9 @@ func Init(ctx context.Context, result *config.LoadResult, factory *ProviderFacto
}
stopRefresh := registry.StartBackgroundRefresh(refreshInterval, modelListURL)
- router, err := NewRouter(registry)
+ strategyRegistry := buildStrategyRegistry(result.Config.Router)
+
+ router, err := NewRouter(registry, WithStrategyRegistry(strategyRegistry))
if err != nil {
stopRefresh()
modelCache.Close()
diff --git a/internal/providers/registry.go b/internal/providers/registry.go
index abb60c38..9307e9b0 100644
--- a/internal/providers/registry.go
+++ b/internal/providers/registry.go
@@ -66,6 +66,10 @@ type ModelRegistry struct {
// invalidateSortedCaches whenever the catalog changes. Protected by mu.
qualifiedByName map[string]core.ModelSelector
qualifiedByType map[string]core.ModelSelector
+
+ // providerRegistrationOrder maintains provider-name insertion order so
+ // ListProvidersForModel returns entries in registration (first-wins) order.
+ providerRegistrationOrder []string
}
type metadataEnrichmentStats struct {
@@ -92,6 +96,7 @@ func NewModelRegistry() *ModelRegistry {
providerRuntime: make(map[string]providerRuntimeState),
refreshCh: make(chan struct{}, 1),
configuredProviderModelsMode: config.ConfiguredProviderModelsModeFallback,
+ providerRegistrationOrder: make([]string, 0),
}
}
@@ -279,6 +284,11 @@ func (r *ModelRegistry) RegisterProviderWithNameAndType(provider core.Provider,
r.providerTypes[provider] = providerType
r.providerNames[provider] = providerName
+ // Track registration order for first-wins routing and ListProvidersForModel.
+ if !slices.Contains(r.providerRegistrationOrder, providerName) {
+ r.providerRegistrationOrder = append(r.providerRegistrationOrder, providerName)
+ }
+
state := r.providerRuntime[providerName]
state.registered = true
r.providerRuntime[providerName] = state
@@ -701,6 +711,39 @@ type ModelWithProvider struct {
Selector string `json:"selector"`
}
+// ListProvidersForModel reports each provider that serves modelID, ordered by
+// provider registration (first-wins). Every entry carries the model and the
+// provider info so callers can inspect pricing, latency, and other signals.
+// An empty, non-nil slice is returned when nothing matches.
+func (r *ModelRegistry) ListProvidersForModel(modelID string) []ModelWithProvider {
+	r.mu.RLock()
+	defer r.mu.RUnlock()
+
+	modelID = strings.TrimSpace(modelID)
+	if modelID == "" || len(r.providerRegistrationOrder) == 0 {
+		return []ModelWithProvider{}
+	}
+
+	matches := make([]ModelWithProvider, 0, len(r.providers))
+	for _, name := range r.providerRegistrationOrder {
+		catalog, registered := r.modelsByProvider[name]
+		if !registered {
+			continue
+		}
+		if info, found := providerModelInfo(catalog, modelID, modelID); found {
+			matches = append(matches, ModelWithProvider{
+				Model:        info.Model,
+				ProviderType: info.ProviderType,
+				ProviderName: name,
+				Selector:     qualifyPublicModelID(name, modelID),
+			})
+		}
+	}
+	return matches
+}
+
+
// ListModelsWithProvider returns all provider-backed models with provider metadata,
// sorted by public selector.
// The sorted slice is cached and rebuilt only when the underlying models change.
diff --git a/internal/providers/router.go b/internal/providers/router.go
index 84eaacb9..5109b237 100644
--- a/internal/providers/router.go
+++ b/internal/providers/router.go
@@ -12,16 +12,30 @@ import (
"strings"
"gomodel/internal/core"
+ router "gomodel/internal/router"
)
// ErrRegistryNotInitialized is returned when the router is used before the registry has any models.
var ErrRegistryNotInitialized = fmt.Errorf("model registry has no models: ensure Initialize() or LoadFromCache() is called before using the router")
+// RouterOption configures a Router at construction. NewRouter applies each
+// non-nil option, in argument order, to the Router it is building.
+type RouterOption func(*Router)
+
+// WithStrategyRegistry turns on intelligent provider selection. For a request
+// that does not pin a provider, when several providers serve the same model,
+// the registry's configured strategy (or a per-request override carried on the
+// context) picks the best candidate by cost and/or latency. Passing a nil
+// registry keeps the historical first-wins behaviour.
+func WithStrategyRegistry(registry *router.StrategyRegistry) RouterOption {
+	return func(target *Router) {
+		target.strategyRegistry = registry
+	}
+}
+
+
// Router routes requests to the appropriate provider based on the model lookup.
// It uses a dynamic model-to-provider mapping that is populated at startup
// by fetching available models from each provider's /models endpoint.
type Router struct {
- lookup core.ModelLookup
+ // lookup resolves models to providers; NewRouter rejects a nil lookup.
+ lookup core.ModelLookup
+ // strategyRegistry enables intelligent routing when non-nil; nil keeps first-wins behaviour.
+ strategyRegistry *router.StrategyRegistry
}
type providerTypeRegistry interface {
@@ -52,6 +66,13 @@ type modelWithProviderLister interface {
ListModelsWithProvider() []ModelWithProvider
}
+// providerModelForModelLister is an optional O(k) interface for enumerating
+// providers that serve a specific model ID. collectCandidates prefers it over
+// a full-catalog scan whenever the lookup implements it. Implementations with
+// a registration order can return candidates in first-wins order.
+type providerModelForModelLister interface {
+ // ListProvidersForModel returns one entry per provider that serves modelID.
+ ListProvidersForModel(modelID string) []ModelWithProvider
+}
+
+
// qualifiedSelectorResolver is an optional fast path for qualified selector
// resolution. Implementations resolve a "/" pair via an O(1)
// index instead of scanning the catalog. A false result means the caller should
@@ -70,14 +91,18 @@ func registryUnavailableError(err error) error {
// NewRouter creates a new provider router with a model lookup.
// The lookup must be initialized (via Initialize() or LoadFromCache()) before using the router.
-// Returns an error if the lookup is nil.
-func NewRouter(lookup core.ModelLookup) (*Router, error) {
+// Returns an error if the lookup is nil. Options may enable intelligent routing;
+// they are applied in argument order and nil options are ignored.
+func NewRouter(lookup core.ModelLookup, opts ...RouterOption) (*Router, error) {
if lookup == nil {
return nil, fmt.Errorf("lookup cannot be nil")
}
- return &Router{
- lookup: lookup,
- }, nil
+ r := &Router{lookup: lookup}
+ for _, opt := range opts {
+ if opt != nil {
+ opt(r)
+ }
+ }
+ return r, nil
}
// checkReady verifies the lookup has models available.
@@ -281,9 +306,141 @@ func (r *Router) resolveProvider(ctx context.Context, model, providerHint string
if p == nil {
return nil, core.ModelSelector{}, core.NewNotFoundError("model not found: " + lookupModel)
}
+
+ // Intelligent routing: when the caller did not pin a provider (or the provider
+ // was assigned by the system during the prepare phase for an unqualified request)
+ // and more than one provider serves this model, let the configured strategy
+ // choose the best candidate. This is a no-op when no strategy registry is
+ // configured (historical behaviour).
+ eligible := (strings.TrimSpace(providerHint) == "" || router.IsRoutingEligible(ctx)) && !strings.Contains(model, "/")
+ if eligible && r.strategyRegistry != nil {
+ if chosen, newSelector, ok := r.applyStrategy(ctx, selector); ok {
+ selector = newSelector
+ p = chosen
+ }
+ }
+
return p, selector, nil
}
+// applyStrategy lets the active routing strategy pick among every provider
+// serving the resolved model. ok is false, and the incoming selector is handed
+// back untouched, whenever routing does not change the outcome: fewer than two
+// candidates, no usable strategy, a selection error, a choice equal to the
+// current selector, or a chosen provider the lookup cannot return.
+func (r *Router) applyStrategy(ctx context.Context, selector core.ModelSelector) (core.Provider, core.ModelSelector, bool) {
+	pool := r.collectCandidates(selector.Model)
+	if len(pool) <= 1 {
+		return nil, selector, false
+	}
+
+	registry := r.strategyRegistry
+	picker, accepted := registry.Resolve(ctx)
+	if !accepted {
+		// The per-request override was rejected; retry with the configured default.
+		fallbackID := registry.DefaultID()
+		slog.WarnContext(ctx, "routing strategy override rejected, falling back to default strategy",
+			"model", selector.Model, "default", fallbackID)
+		if picker = registry.New(fallbackID); picker == nil {
+			return nil, selector, false
+		}
+	}
+	picker = r.applyWeightsOverride(ctx, picker)
+
+	winner, err := picker.Select(ctx, pool)
+	if err != nil {
+		slog.WarnContext(ctx, "routing strategy selected no candidate, falling back to first candidate",
+			"model", selector.Model, "strategy", picker.Name(), "error", err)
+		return nil, selector, false
+	}
+
+	rerouted := core.ModelSelector{Provider: winner.ProviderName, Model: winner.ModelID}
+	qualified := rerouted.QualifiedModel()
+	if qualified == selector.QualifiedModel() {
+		// The strategy agreed with the provider already resolved.
+		return nil, selector, false
+	}
+	provider := r.lookup.GetProvider(qualified)
+	if provider == nil {
+		return nil, selector, false
+	}
+	slog.DebugContext(ctx, "intelligent routing selected provider",
+		"model", selector.Model, "provider", winner.ProviderName, "strategy", picker.Name())
+	return provider, rerouted, true
+}
+
+
+// applyWeightsOverride applies a per-request cost/latency weights override
+// carried on the context. Only a *router.BalancedStrategy is affected: it is
+// copied and the copy receives the override weights. With no override, or for
+// any other strategy type, the input strategy is returned as is.
+func (r *Router) applyWeightsOverride(ctx context.Context, strategy router.RoutingStrategy) router.RoutingStrategy {
+	weights, present := router.WeightsOverrideFromContext(ctx)
+	if balanced, isBalanced := strategy.(*router.BalancedStrategy); present && isBalanced {
+		// Work on a copy so the strategy instance passed in is never mutated.
+		tuned := *balanced
+		tuned.CostWeight = weights.Cost
+		tuned.LatencyWeight = weights.Latency
+		return &tuned
+	}
+	return strategy
+}
+
+
+// collectCandidates gathers every provider that serves modelID and converts
+// each into a routing candidate via buildCandidatesFromEntries. A lookup that
+// implements providerModelForModelLister is asked for that model directly;
+// otherwise a lookup that implements modelWithProviderLister has its whole
+// catalog scanned for matching model IDs. It returns nil for a blank modelID
+// or when the lookup offers neither interface.
+func (r *Router) collectCandidates(modelID string) []router.ProviderCandidate {
+	wanted := strings.TrimSpace(modelID)
+	if wanted == "" {
+		return nil
+	}
+
+	switch catalog := r.lookup.(type) {
+	case providerModelForModelLister:
+		// Fast path: per-model lookup.
+		return r.buildCandidatesFromEntries(catalog.ListProvidersForModel(wanted))
+	case modelWithProviderLister:
+		// Fallback: scan the full catalog and keep the matching entries.
+		var matching []ModelWithProvider
+		for _, entry := range catalog.ListModelsWithProvider() {
+			if strings.TrimSpace(entry.Model.ID) != wanted {
+				continue
+			}
+			matching = append(matching, entry)
+		}
+		return r.buildCandidatesFromEntries(matching)
+	default:
+		return nil
+	}
+}
+
+
+// buildCandidatesFromEntries turns ModelWithProvider entries into routing
+// candidates, preserving entry order. Entries whose provider the lookup cannot
+// return are dropped. Pricing is copied when the model has metadata, and
+// latency, error rate and circuit state are copied when the provider
+// implements core.ProviderStats.
+func (r *Router) buildCandidatesFromEntries(entries []ModelWithProvider) []router.ProviderCandidate {
+	out := make([]router.ProviderCandidate, 0, len(entries))
+	for _, entry := range entries {
+		name := strings.TrimSpace(entry.ProviderName)
+		selector := core.ModelSelector{Provider: name, Model: entry.Model.ID}
+		provider := r.lookup.GetProvider(selector.QualifiedModel())
+		if provider == nil {
+			continue
+		}
+
+		next := router.ProviderCandidate{
+			Provider:     provider,
+			ProviderName: name,
+			ProviderType: entry.ProviderType,
+			ModelID:      entry.Model.ID,
+		}
+		if meta := entry.Model.Metadata; meta != nil {
+			next.Pricing = meta.Pricing
+		}
+		if stats, hasStats := provider.(core.ProviderStats); hasStats {
+			next.Latency = stats.P50Latency()
+			next.ErrorRate = stats.ErrorRate()
+			next.CircuitState = stats.CircuitState()
+		}
+		out = append(out, next)
+	}
+	return out
+}
+
+
func (r *Router) refreshProviderModelsForRequest(ctx context.Context, requested core.RequestedModelSelector) (bool, error) {
refresher, ok := r.lookup.(providerModelRefresher)
if !ok {
diff --git a/internal/providers/router_strategy_integration_test.go b/internal/providers/router_strategy_integration_test.go
new file mode 100644
index 00000000..33c600d2
--- /dev/null
+++ b/internal/providers/router_strategy_integration_test.go
@@ -0,0 +1,315 @@
+package providers
+
+import (
+ "context"
+ "net/http"
+ "testing"
+ "time"
+
+ router "gomodel/internal/router"
+
+ "gomodel/internal/core"
+)
+
+// statsMockProvider is a mockProvider that also implements core.ProviderStats,
+// letting intelligent routing observe latency/error/circuit signals.
+type statsMockProvider struct {
+ mockProvider
+ p50 time.Duration // returned by P50Latency
+ p99 time.Duration // returned by P99Latency
+ errorRate float64 // returned by ErrorRate
+ circuitState string // returned by CircuitState; empty is reported as "closed"
+}
+
+// P50Latency returns the fixed p50 latency configured on the mock.
+func (s *statsMockProvider) P50Latency() time.Duration { return s.p50 }
+// P99Latency returns the fixed p99 latency configured on the mock.
+func (s *statsMockProvider) P99Latency() time.Duration { return s.p99 }
+// ErrorRate returns the fixed error rate configured on the mock.
+func (s *statsMockProvider) ErrorRate() float64 { return s.errorRate }
+// CircuitState returns the configured circuit state, defaulting to "closed"
+// when none was set on the mock.
+func (s *statsMockProvider) CircuitState() string {
+	if state := s.circuitState; state != "" {
+		return state
+	}
+	return "closed"
+}
+
+// floatPtr returns a pointer to a fresh copy of v, for filling optional
+// pricing fields in test fixtures.
+func floatPtr(v float64) *float64 {
+	p := new(float64)
+	*p = v
+	return p
+}
+
+// registerModelWithMetadata registers provider under providerName and then
+// writes a model entry carrying the given pricing straight into the registry's
+// maps, so cost-aware strategies have data to score. The first provider to
+// register a model ID keeps the unqualified registry.models slot.
+func registerModelWithMetadata(t *testing.T, registry *ModelRegistry, provider core.Provider, providerName, providerType, modelID string, pricing *core.ModelPricing) {
+	t.Helper()
+	registry.RegisterProviderWithNameAndType(provider, providerName, providerType)
+
+	entry := &ModelInfo{
+		Provider:     provider,
+		ProviderName: providerName,
+		ProviderType: providerType,
+	}
+	entry.Model = core.Model{
+		ID:       modelID,
+		Object:   "model",
+		Metadata: &core.ModelMetadata{Pricing: pricing},
+	}
+
+	catalog := registry.modelsByProvider[providerName]
+	if catalog == nil {
+		catalog = make(map[string]*ModelInfo)
+		registry.modelsByProvider[providerName] = catalog
+	}
+	catalog[modelID] = entry
+
+	if _, taken := registry.models[modelID]; taken {
+		return
+	}
+	registry.models[modelID] = entry
+}
+
+// TestRouter_IntelligentRouting_PicksCheaperProvider registers two providers
+// with equal latency and different pricing, and expects an unpinned request to
+// reach the cheaper one.
+func TestRouter_IntelligentRouting_PicksCheaperProvider(t *testing.T) {
+	registry := NewModelRegistry()
+	cheap := &statsMockProvider{
+		mockProvider: mockProvider{name: "openai-east", chatResponse: &core.ChatResponse{ID: "east-resp", Model: "gpt-4o"}},
+		p50:          100 * time.Millisecond,
+	}
+	pricey := &statsMockProvider{
+		mockProvider: mockProvider{name: "openai-west", chatResponse: &core.ChatResponse{ID: "west-resp", Model: "gpt-4o"}},
+		p50:          100 * time.Millisecond,
+	}
+	registerModelWithMetadata(t, registry, cheap, "openai-east", "openai", "gpt-4o",
+		&core.ModelPricing{InputPerMtok: floatPtr(1), OutputPerMtok: floatPtr(2)})
+	registerModelWithMetadata(t, registry, pricey, "openai-west", "openai", "gpt-4o",
+		&core.ModelPricing{InputPerMtok: floatPtr(10), OutputPerMtok: floatPtr(20)})
+
+	// Fail here on a construction error instead of panicking on a nil router below.
+	r, err := NewRouter(registry, WithStrategyRegistry(router.NewStrategyRegistry()))
+	if err != nil {
+		t.Fatalf("NewRouter: %v", err)
+	}
+
+	resp, err := r.ChatCompletion(context.Background(), &core.ChatRequest{Model: "gpt-4o"})
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+	if resp.ID != "east-resp" {
+		t.Fatalf("expected cheaper east provider, got %s", resp.ID)
+	}
+}
+
+// TestRouter_IntelligentRouting_PicksFasterProvider registers two providers
+// with identical pricing and different latency, and expects an unpinned
+// request to reach the faster one.
+func TestRouter_IntelligentRouting_PicksFasterProvider(t *testing.T) {
+	registry := NewModelRegistry()
+	slow := &statsMockProvider{
+		mockProvider: mockProvider{name: "openai-slow", chatResponse: &core.ChatResponse{ID: "slow-resp", Model: "gpt-4o"}},
+		p50:          500 * time.Millisecond,
+	}
+	fast := &statsMockProvider{
+		mockProvider: mockProvider{name: "openai-fast", chatResponse: &core.ChatResponse{ID: "fast-resp", Model: "gpt-4o"}},
+		p50:          50 * time.Millisecond,
+	}
+	// Equal pricing, so latency is the only signal that differs between the
+	// candidates under the registry's default strategy.
+	pricing := &core.ModelPricing{InputPerMtok: floatPtr(5), OutputPerMtok: floatPtr(5)}
+	registerModelWithMetadata(t, registry, slow, "openai-slow", "openai", "gpt-4o", pricing)
+	registerModelWithMetadata(t, registry, fast, "openai-fast", "openai", "gpt-4o", pricing)
+
+	// Fail here on a construction error instead of panicking on a nil router below.
+	r, err := NewRouter(registry, WithStrategyRegistry(router.NewStrategyRegistry()))
+	if err != nil {
+		t.Fatalf("NewRouter: %v", err)
+	}
+
+	resp, err := r.ChatCompletion(context.Background(), &core.ChatRequest{Model: "gpt-4o"})
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+	if resp.ID != "fast-resp" {
+		t.Fatalf("expected faster provider, got %s", resp.ID)
+	}
+}
+
+// TestRouter_IntelligentRouting_HeaderStrategyOverride checks that a
+// per-request strategy override on the context changes which provider is
+// chosen compared with the default strategy.
+func TestRouter_IntelligentRouting_HeaderStrategyOverride(t *testing.T) {
+	registry := NewModelRegistry()
+	// Two providers: cheaper-but-slower vs pricier-but-faster.
+	cheapSlow := &statsMockProvider{
+		mockProvider: mockProvider{name: "openai-cheap", chatResponse: &core.ChatResponse{ID: "cheap-resp", Model: "gpt-4o"}},
+		p50:          500 * time.Millisecond,
+	}
+	priceyFast := &statsMockProvider{
+		mockProvider: mockProvider{name: "openai-fast", chatResponse: &core.ChatResponse{ID: "fast-resp", Model: "gpt-4o"}},
+		p50:          50 * time.Millisecond,
+	}
+	registerModelWithMetadata(t, registry, cheapSlow, "openai-cheap", "openai", "gpt-4o",
+		&core.ModelPricing{InputPerMtok: floatPtr(1), OutputPerMtok: floatPtr(1)})
+	registerModelWithMetadata(t, registry, priceyFast, "openai-fast", "openai", "gpt-4o",
+		&core.ModelPricing{InputPerMtok: floatPtr(10), OutputPerMtok: floatPtr(10)})
+
+	// Fail here on a construction error instead of panicking on a nil router below.
+	r, err := NewRouter(registry, WithStrategyRegistry(router.NewStrategyRegistry()))
+	if err != nil {
+		t.Fatalf("NewRouter: %v", err)
+	}
+
+	// Default balanced should pick cheap-slow (cost dominates at 0.6/0.4).
+	defaultResp, err := r.ChatCompletion(context.Background(), &core.ChatRequest{Model: "gpt-4o"})
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+	if defaultResp.ID != "cheap-resp" {
+		t.Fatalf("default balanced expected cheap, got %s", defaultResp.ID)
+	}
+
+	// Override to latency_only should pick the faster provider.
+	ctx := router.WithStrategyOverride(context.Background(), "latency_only")
+	latencyResp, err := r.ChatCompletion(ctx, &core.ChatRequest{Model: "gpt-4o"})
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+	if latencyResp.ID != "fast-resp" {
+		t.Fatalf("latency_only override expected fast, got %s", latencyResp.ID)
+	}
+}
+
+// TestRouter_IntelligentRouting_InvalidOverrideFallsBack checks that an
+// unknown strategy name on the context does not fail the request: the default
+// strategy is used instead.
+func TestRouter_IntelligentRouting_InvalidOverrideFallsBack(t *testing.T) {
+	registry := NewModelRegistry()
+	cheap := &statsMockProvider{
+		mockProvider: mockProvider{name: "openai-east", chatResponse: &core.ChatResponse{ID: "east-resp", Model: "gpt-4o"}},
+	}
+	pricey := &statsMockProvider{
+		mockProvider: mockProvider{name: "openai-west", chatResponse: &core.ChatResponse{ID: "west-resp", Model: "gpt-4o"}},
+	}
+	registerModelWithMetadata(t, registry, cheap, "openai-east", "openai", "gpt-4o",
+		&core.ModelPricing{InputPerMtok: floatPtr(1), OutputPerMtok: floatPtr(2)})
+	registerModelWithMetadata(t, registry, pricey, "openai-west", "openai", "gpt-4o",
+		&core.ModelPricing{InputPerMtok: floatPtr(10), OutputPerMtok: floatPtr(20)})
+
+	// Fail here on a construction error instead of panicking on a nil router below.
+	r, err := NewRouter(registry, WithStrategyRegistry(router.NewStrategyRegistry()))
+	if err != nil {
+		t.Fatalf("NewRouter: %v", err)
+	}
+
+	ctx := router.WithStrategyOverride(context.Background(), "not-a-real-strategy")
+	resp, err := r.ChatCompletion(ctx, &core.ChatRequest{Model: "gpt-4o"})
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+	// Invalid override falls back to default balanced, which picks cheap.
+	if resp.ID != "east-resp" {
+		t.Fatalf("invalid override should fall back to default (cheap), got %s", resp.ID)
+	}
+}
+
+// TestRouter_IntelligentRouting_ExplicitProviderHintBypasses checks that a
+// request naming a provider is sent to that provider even though a cheaper
+// one serves the same model.
+func TestRouter_IntelligentRouting_ExplicitProviderHintBypasses(t *testing.T) {
+	registry := NewModelRegistry()
+	cheap := &statsMockProvider{
+		mockProvider: mockProvider{name: "openai-east", chatResponse: &core.ChatResponse{ID: "east-resp", Model: "gpt-4o"}},
+	}
+	pricey := &statsMockProvider{
+		mockProvider: mockProvider{name: "openai-west", chatResponse: &core.ChatResponse{ID: "west-resp", Model: "gpt-4o"}},
+	}
+	registerModelWithMetadata(t, registry, cheap, "openai-east", "openai", "gpt-4o",
+		&core.ModelPricing{InputPerMtok: floatPtr(1), OutputPerMtok: floatPtr(2)})
+	registerModelWithMetadata(t, registry, pricey, "openai-west", "openai", "gpt-4o",
+		&core.ModelPricing{InputPerMtok: floatPtr(10), OutputPerMtok: floatPtr(20)})
+
+	// Fail here on a construction error instead of panicking on a nil router below.
+	r, err := NewRouter(registry, WithStrategyRegistry(router.NewStrategyRegistry()))
+	if err != nil {
+		t.Fatalf("NewRouter: %v", err)
+	}
+
+	// Explicit provider hint must route to the named provider regardless of cost.
+	resp, err := r.ChatCompletion(context.Background(), &core.ChatRequest{
+		Model:    "gpt-4o",
+		Provider: "openai-west",
+	})
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+	if resp.ID != "west-resp" {
+		t.Fatalf("explicit provider hint should bypass routing, got %s", resp.ID)
+	}
+}
+
+// TestRouter_IntelligentRouting_SingleCandidateUnaffected checks that a model
+// served by exactly one provider still routes to it when a strategy registry
+// is configured.
+func TestRouter_IntelligentRouting_SingleCandidateUnaffected(t *testing.T) {
+	registry := NewModelRegistry()
+	solo := &statsMockProvider{
+		mockProvider: mockProvider{name: "openai", chatResponse: &core.ChatResponse{ID: "solo-resp", Model: "gpt-4o"}},
+	}
+	registerModelWithMetadata(t, registry, solo, "openai", "openai", "gpt-4o", nil)
+
+	// Fail here on a construction error instead of panicking on a nil router below.
+	r, err := NewRouter(registry, WithStrategyRegistry(router.NewStrategyRegistry()))
+	if err != nil {
+		t.Fatalf("NewRouter: %v", err)
+	}
+
+	resp, err := r.ChatCompletion(context.Background(), &core.ChatRequest{Model: "gpt-4o"})
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+	if resp.ID != "solo-resp" {
+		t.Fatalf("single candidate should route normally, got %s", resp.ID)
+	}
+}
+
+// TestRouter_IntelligentRouting_WeightsOverride checks that per-request
+// cost/latency weights on the context steer the choice between a
+// cheap-but-slow and a pricey-but-fast provider.
+func TestRouter_IntelligentRouting_WeightsOverride(t *testing.T) {
+	registry := NewModelRegistry()
+	cheapSlow := &statsMockProvider{
+		mockProvider: mockProvider{name: "cheap-slow", chatResponse: &core.ChatResponse{ID: "cheap", Model: "gpt-4o"}},
+		p50:          500 * time.Millisecond,
+	}
+	priceyFast := &statsMockProvider{
+		mockProvider: mockProvider{name: "pricey-fast", chatResponse: &core.ChatResponse{ID: "fast", Model: "gpt-4o"}},
+		p50:          50 * time.Millisecond,
+	}
+	registerModelWithMetadata(t, registry, cheapSlow, "cheap-slow", "openai", "gpt-4o",
+		&core.ModelPricing{InputPerMtok: floatPtr(1), OutputPerMtok: floatPtr(1)})
+	registerModelWithMetadata(t, registry, priceyFast, "pricey-fast", "openai", "gpt-4o",
+		&core.ModelPricing{InputPerMtok: floatPtr(10), OutputPerMtok: floatPtr(10)})
+
+	// Fail here on a construction error instead of panicking on a nil router
+	// inside the subtests.
+	r, err := NewRouter(registry, WithStrategyRegistry(router.NewStrategyRegistry()))
+	if err != nil {
+		t.Fatalf("NewRouter: %v", err)
+	}
+
+	t.Run("weighted cost=1 latency=0 picks cheap", func(t *testing.T) {
+		ctx := router.WithWeightsOverride(context.Background(), router.WeightsOverride{Cost: 1.0, Latency: 0.0})
+		resp, err := r.ChatCompletion(ctx, &core.ChatRequest{Model: "gpt-4o"})
+		if err != nil {
+			t.Fatalf("unexpected error: %v", err)
+		}
+		if resp.ID != "cheap" {
+			t.Fatalf("cost=1 latency=0 expected cheap, got %s", resp.ID)
+		}
+	})
+
+	t.Run("weighted cost=0 latency=1 picks fast", func(t *testing.T) {
+		ctx := router.WithWeightsOverride(context.Background(), router.WeightsOverride{Cost: 0.0, Latency: 1.0})
+		resp, err := r.ChatCompletion(ctx, &core.ChatRequest{Model: "gpt-4o"})
+		if err != nil {
+			t.Fatalf("unexpected error: %v", err)
+		}
+		if resp.ID != "fast" {
+			t.Fatalf("cost=0 latency=1 expected fast, got %s", resp.ID)
+		}
+	})
+}
+
+// TestRouter_IntelligentRouting_ModelSyntaxProviderNotOverridden checks that a
+// provider named through the "provider/model" form is respected even though a
+// cheaper provider serves the same model.
+func TestRouter_IntelligentRouting_ModelSyntaxProviderNotOverridden(t *testing.T) {
+	registry := NewModelRegistry()
+	cheap := &statsMockProvider{
+		mockProvider: mockProvider{name: "cheap", chatResponse: &core.ChatResponse{ID: "cheap", Model: "gpt-4o"}},
+	}
+	pricey := &statsMockProvider{
+		mockProvider: mockProvider{name: "pricey", chatResponse: &core.ChatResponse{ID: "pricey", Model: "gpt-4o"}},
+	}
+	registerModelWithMetadata(t, registry, cheap, "cheap", "openai", "gpt-4o",
+		&core.ModelPricing{InputPerMtok: floatPtr(1), OutputPerMtok: floatPtr(1)})
+	registerModelWithMetadata(t, registry, pricey, "pricey", "openai", "gpt-4o",
+		&core.ModelPricing{InputPerMtok: floatPtr(10), OutputPerMtok: floatPtr(10)})
+
+	// Fail here on a construction error instead of panicking on a nil router below.
+	r, err := NewRouter(registry, WithStrategyRegistry(router.NewStrategyRegistry()))
+	if err != nil {
+		t.Fatalf("NewRouter: %v", err)
+	}
+
+	// Provider specified via model syntax (pricey/gpt-4o) should not be overridden.
+	resp, err := r.ChatCompletion(context.Background(), &core.ChatRequest{
+		Model: "pricey/gpt-4o",
+	})
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+	if resp.ID != "pricey" {
+		t.Fatalf("model-syntax provider (pricey/gpt-4o) should be respected, got %s", resp.ID)
+	}
+}
+
+// TestRouter_FirstFit_RespectsRegistrationOrder checks that the first_fit
+// strategy routes to the provider that was registered first when neither
+// candidate has pricing or latency data.
+func TestRouter_FirstFit_RespectsRegistrationOrder(t *testing.T) {
+	registry := NewModelRegistry()
+	first := &statsMockProvider{
+		mockProvider: mockProvider{name: "reg-first", chatResponse: &core.ChatResponse{ID: "first", Model: "gpt-4o"}},
+	}
+	second := &statsMockProvider{
+		mockProvider: mockProvider{name: "reg-second", chatResponse: &core.ChatResponse{ID: "second", Model: "gpt-4o"}},
+	}
+	// Register in order first, then second.
+	registerModelWithMetadata(t, registry, first, "reg-first", "openai", "gpt-4o", nil)
+	registerModelWithMetadata(t, registry, second, "reg-second", "openai", "gpt-4o", nil)
+
+	// Override to first_fit so the order is the only decision factor.
+	ctx := router.WithStrategyOverride(context.Background(), "first_fit")
+	// Fail here on a construction error instead of panicking on a nil router below.
+	r, err := NewRouter(registry, WithStrategyRegistry(router.NewStrategyRegistry()))
+	if err != nil {
+		t.Fatalf("NewRouter: %v", err)
+	}
+
+	resp, err := r.ChatCompletion(ctx, &core.ChatRequest{Model: "gpt-4o"})
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+	if resp.ID != "first" {
+		t.Fatalf("first_fit should route to reg-first (registration order), got %s", resp.ID)
+	}
+}
+
+// Keeps the net/http import referenced: nothing else in this file uses it,
+// and the Go compiler rejects unused imports.
+var _ = http.StatusOK
diff --git a/internal/providers/strategy_init.go b/internal/providers/strategy_init.go
new file mode 100644
index 00000000..6fec5878
--- /dev/null
+++ b/internal/providers/strategy_init.go
@@ -0,0 +1,36 @@
+package providers
+
+import (
+ router "gomodel/internal/router"
+
+ "gomodel/config"
+)
+
+// buildStrategyRegistry creates the routing strategy registry described by
+// cfg. Every built-in strategy is re-registered with a factory that applies
+// the configured maximum error rate; the balanced factory also applies the
+// configured cost/latency weights, so they hold for every request that does
+// not carry a per-request override.
+func buildStrategyRegistry(cfg config.RouterConfig) *router.StrategyRegistry {
+	maxErrorRate := cfg.MaxErrorRate
+	costWeight, latencyWeight := cfg.Weights.Cost, cfg.Weights.Latency
+
+	builtins := []struct {
+		id    string
+		build func() router.RoutingStrategy
+	}{
+		{"balanced", func() router.RoutingStrategy {
+			return &router.BalancedStrategy{
+				CostWeight:    costWeight,
+				LatencyWeight: latencyWeight,
+				MaxErrorRate:  maxErrorRate,
+			}
+		}},
+		{"cost_only", func() router.RoutingStrategy {
+			return &router.CostOnlyStrategy{MaxErrorRate: maxErrorRate}
+		}},
+		{"latency_only", func() router.RoutingStrategy {
+			return &router.LatencyOnlyStrategy{MaxErrorRate: maxErrorRate}
+		}},
+		{"first_fit", func() router.RoutingStrategy {
+			return &router.FirstFitStrategy{MaxErrorRate: maxErrorRate}
+		}},
+	}
+
+	registry := router.NewStrategyRegistry()
+	for _, entry := range builtins {
+		registry.Register(entry.id, entry.build)
+	}
+
+	// The error is discarded on purpose: SetDefault fails only for an
+	// unregistered id, and the registry then keeps the default it was created with.
+	_ = registry.SetDefault(cfg.Strategy)
+	return registry
+}
diff --git a/internal/router/context.go b/internal/router/context.go
new file mode 100644
index 00000000..ac3986c1
--- /dev/null
+++ b/internal/router/context.go
@@ -0,0 +1,67 @@
+package router
+
+import "context"
+
+// strategyCtxKey is the context key for a per-request routing-strategy
+// override extracted from the X-GoModel-Routing-Strategy header. It is an
+// unexported struct type, so code outside this package cannot construct an
+// equal key and collide with it.
+type strategyCtxKey struct{}
+
+// routingEligibleKey is the context key whose true value signals that the
+// request's provider was assigned by the system (prepare phase) rather than
+// explicitly by the user. When set, resolveProvider applies the routing
+// strategy even though providerHint is non-empty, because the provider is an
+// artifact of first-wins resolution, not a user choice.
+type routingEligibleKey struct{}
+
+// WithRoutingEligible returns a copy of ctx flagged as eligible for
+// intelligent routing even when a provider hint is present. The prepare phase
+// sets this when the original request did not pin a provider.
+func WithRoutingEligible(ctx context.Context) context.Context {
+	marked := context.WithValue(ctx, routingEligibleKey{}, true)
+	return marked
+}
+
+// IsRoutingEligible reports whether ctx carries the routing-eligible flag set
+// by WithRoutingEligible. A context without the flag yields false.
+func IsRoutingEligible(ctx context.Context) bool {
+	if flag, isBool := ctx.Value(routingEligibleKey{}).(bool); isBool {
+		return flag
+	}
+	return false
+}
+
+// weightsCtxKey is the context key for a per-request WeightsOverride parsed
+// from the X-GoModel-Routing-Weights header (only honored by the balanced
+// strategy).
+type weightsCtxKey struct{}
+
+// StrategyOverrideFromContext reads the per-request strategy name stored on
+// ctx. The boolean is false, and the name empty, when no override is present.
+func StrategyOverrideFromContext(ctx context.Context) (string, bool) {
+	raw := ctx.Value(strategyCtxKey{})
+	name, present := raw.(string)
+	return name, present
+}
+
+// WithStrategyOverride returns a context carrying a per-request strategy name
+// override. An empty name is a no-op: ctx is returned unchanged, so an
+// override already present on ctx stays in effect.
+func WithStrategyOverride(ctx context.Context, strategy string) context.Context {
+	if strategy != "" {
+		return context.WithValue(ctx, strategyCtxKey{}, strategy)
+	}
+	return ctx
+}
+
+// WeightsOverride holds a per-request cost/latency weight override for the
+// balanced strategy.
+type WeightsOverride struct {
+ // Cost is the weight given to the cost signal.
+ Cost float64
+ // Latency is the weight given to the latency signal.
+ Latency float64
+}
+
+// WeightsOverrideFromContext reads the per-request weights override stored on
+// ctx. The boolean is false, and the weights zero, when none was set.
+func WeightsOverrideFromContext(ctx context.Context) (WeightsOverride, bool) {
+	raw := ctx.Value(weightsCtxKey{})
+	weights, present := raw.(WeightsOverride)
+	return weights, present
+}
+
+// WithWeightsOverride returns a copy of ctx that carries w as the per-request
+// weights override.
+func WithWeightsOverride(ctx context.Context, w WeightsOverride) context.Context {
+	carrying := context.WithValue(ctx, weightsCtxKey{}, w)
+	return carrying
+}
diff --git a/internal/router/registry.go b/internal/router/registry.go
new file mode 100644
index 00000000..bec46de2
--- /dev/null
+++ b/internal/router/registry.go
@@ -0,0 +1,119 @@
+package router
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ "math"
+ "strings"
+)
+
+// StrategyRegistry maps strategy names to factories. A factory returns a fresh
+// strategy instance (so per-request weights can be applied without mutation
+// across requests). defaultID names the strategy Resolve uses when a request
+// carries no override; an unknown override is reported rather than defaulted.
+type StrategyRegistry struct {
+ factories map[string]func() RoutingStrategy
+ defaultID string
+}
+
+// NewStrategyRegistry builds a registry holding the four built-in strategies
+// (balanced, cost_only, latency_only, first_fit). "balanced" is the initial
+// default.
+func NewStrategyRegistry() *StrategyRegistry {
+	reg := &StrategyRegistry{defaultID: "balanced"}
+	// Keys are already trimmed and lower-case, i.e. what Register would store.
+	reg.factories = map[string]func() RoutingStrategy{
+		"balanced":     func() RoutingStrategy { return NewBalancedStrategy() },
+		"cost_only":    func() RoutingStrategy { return NewCostOnlyStrategy() },
+		"latency_only": func() RoutingStrategy { return NewLatencyOnlyStrategy() },
+		"first_fit":    func() RoutingStrategy { return NewFirstFitStrategy() },
+	}
+	return reg
+}
+
+// Register adds or replaces a factory under the trimmed, lower-cased name; a nil factory or a blank name is ignored.
+func (r *StrategyRegistry) Register(name string, factory func() RoutingStrategy) {
+	if name = strings.ToLower(strings.TrimSpace(name)); factory == nil || name == "" {
+		return
+	}
+	r.factories[name] = factory
+}
+
+// SetDefault makes the normalized id the default strategy; an unregistered id returns an error and changes nothing.
+func (r *StrategyRegistry) SetDefault(id string) error {
+	key := strings.ToLower(strings.TrimSpace(id))
+	if _, registered := r.factories[key]; !registered {
+		return fmt.Errorf("unknown routing strategy %q", id)
+	}
+	r.defaultID = key
+	return nil
+}
+
+// DefaultID returns the id of the strategy Resolve uses when ctx has no override.
+func (r *StrategyRegistry) DefaultID() string { return r.defaultID }
+
+// Names returns the registered strategy ids in unspecified (map iteration) order.
+func (r *StrategyRegistry) Names() []string {
+ names := make([]string, 0, len(r.factories))
+ for name := range r.factories {
+ names = append(names, name)
+ }
+ return names
+}
+
+// Resolve picks the strategy for a request. A registered override on ctx wins;
+// an unregistered override yields (nil, false) so the caller can fall back to
+// the default and log a warning. Without an override the default is built.
+func (r *StrategyRegistry) Resolve(ctx context.Context) (RoutingStrategy, bool) {
+	// defaultID is used as stored (NewStrategyRegistry and SetDefault keep it
+	// normalized); only an override is trimmed and lower-cased here.
+	key := r.defaultID
+	if override, present := StrategyOverrideFromContext(ctx); present {
+		key = strings.ToLower(strings.TrimSpace(override))
+	}
+	build, known := r.factories[key]
+	if !known {
+		return nil, false
+	}
+	return build(), true
+}
+
+// New builds a fresh instance of the named strategy; unknown names give nil.
+func (r *StrategyRegistry) New(name string) RoutingStrategy {
+	key := strings.ToLower(strings.TrimSpace(name))
+	if build, known := r.factories[key]; known {
+		return build()
+	}
+	return nil
+}
+
+// ParseWeights converts a "cost,latency" string such as "0.6,0.4" into a
+// WeightsOverride, ignoring whitespace around each number. It returns an error
+// unless the input is exactly two finite, non-negative numbers.
+func ParseWeights(s string) (WeightsOverride, error) {
+	fields := strings.Split(s, ",")
+	if len(fields) != 2 {
+		return WeightsOverride{}, fmt.Errorf("expected two comma-separated weights, got %q", s)
+	}
+	scan := func(text string) (value float64, ok bool) {
+		var rest string // the extra %s operand catches trailing junk such as "0.6x"
+		n, err := fmt.Sscanf(text, "%f%s", &value, &rest)
+		return value, n == 1 && (err == nil || errors.Is(err, io.EOF))
+	}
+	costText, latText := strings.TrimSpace(fields[0]), strings.TrimSpace(fields[1])
+	cost, costOK := scan(costText)
+	lat, latOK := scan(latText)
+	switch {
+	case !costOK:
+		return WeightsOverride{}, fmt.Errorf("invalid cost weight %q: parse error", costText)
+	case !latOK:
+		return WeightsOverride{}, fmt.Errorf("invalid latency weight %q: parse error", latText)
+	case math.IsNaN(cost) || math.IsNaN(lat) || math.IsInf(cost, 0) || math.IsInf(lat, 0):
+		return WeightsOverride{}, fmt.Errorf("weights must be finite numbers, got %v,%v", cost, lat)
+	case cost < 0 || lat < 0:
+		return WeightsOverride{}, fmt.Errorf("weights must be non-negative, got %v,%v", cost, lat)
+	}
+	return WeightsOverride{Cost: cost, Latency: lat}, nil
+}
diff --git a/internal/router/registry_test.go b/internal/router/registry_test.go
new file mode 100644
index 00000000..fd569ff7
--- /dev/null
+++ b/internal/router/registry_test.go
@@ -0,0 +1,106 @@
+package router
+
+import (
+ "context"
+ "testing"
+)
+
+func TestStrategyRegistry_DefaultIsBalanced(t *testing.T) {
+	reg := NewStrategyRegistry()
+	strat, resolved := reg.Resolve(context.Background())
+	if !resolved {
+		t.Fatal("expected default strategy to resolve")
+	}
+	if name := strat.Name(); name != "balanced" {
+		t.Fatalf("expected default balanced, got %s", name)
+	}
+}
+
+func TestStrategyRegistry_OverrideHonored(t *testing.T) {
+	reg := NewStrategyRegistry()
+	overridden := WithStrategyOverride(context.Background(), "cost_only")
+	strat, resolved := reg.Resolve(overridden)
+	if !resolved {
+		t.Fatal("expected override to resolve")
+	}
+	if name := strat.Name(); name != "cost_only" {
+		t.Fatalf("expected cost_only, got %s", name)
+	}
+}
+
+func TestStrategyRegistry_InvalidOverrideRejected(t *testing.T) {
+	reg := NewStrategyRegistry()
+	overridden := WithStrategyOverride(context.Background(), "nonsense")
+	strat, resolved := reg.Resolve(overridden)
+	if resolved {
+		t.Fatal("expected invalid override to be rejected")
+	}
+	if strat != nil {
+		t.Fatal("expected nil strategy for invalid override")
+	}
+}
+
+func TestStrategyRegistry_Names(t *testing.T) {
+	// Only the count is checked: Names makes no promise about ordering.
+	got := NewStrategyRegistry().Names()
+	if n := len(got); n != 4 {
+		t.Fatalf("expected 4 strategies, got %d (%v)", n, got)
+	}
+}
+
+func TestStrategyRegistry_SetDefaultInvalid(t *testing.T) {
+	// "bogus" is not one of the built-in ids, so SetDefault must refuse it.
+	if err := NewStrategyRegistry().SetDefault("bogus"); err == nil {
+		t.Fatal("expected error setting unknown default")
+	}
+}
+
+func TestParseWeights_Valid(t *testing.T) {
+ cases := []struct {
+ input string
+ wantCost float64
+ wantLat float64
+ }{
+ {"0.7,0.3", 0.7, 0.3},
+ {"0.60,0.40", 0.6, 0.4},
+ {"1,0", 1.0, 0.0},
+ {" 0.6 , 0.4 ", 0.6, 0.4},
+ {"0,1", 0.0, 1.0},
+ }
+ for _, tc := range cases {
+ w, err := ParseWeights(tc.input)
+ if err != nil {
+ t.Fatalf("ParseWeights(%q) unexpected error: %v", tc.input, err)
+ }
+ if w.Cost != tc.wantCost || w.Latency != tc.wantLat {
+ t.Fatalf("ParseWeights(%q) = %v,%v, want %v,%v", tc.input, w.Cost, w.Latency, tc.wantCost, tc.wantLat)
+ }
+ }
+}
+
+func TestParseWeights_Invalid(t *testing.T) {
+	bad := []string{"0.6", "0.6,0.4,0.2", "abc,0.4", "-1,0.4", "0.6x,0.4", "NaN,0.4", "Inf,0.4"}
+	for i := range bad {
+		if _, err := ParseWeights(bad[i]); err == nil {
+			t.Errorf("expected error for %q", bad[i])
+		}
+	}
+}
+
+func TestWeightsOverride_RoundTrip(t *testing.T) {
+	base := context.Background()
+	if _, found := WeightsOverrideFromContext(base); found {
+		t.Fatal("expected no override on plain context")
+	}
+	stored := WeightsOverride{Cost: 0.6, Latency: 0.4}
+	got, found := WeightsOverrideFromContext(WithWeightsOverride(base, stored))
+	if !found || got.Cost != 0.6 || got.Latency != 0.4 {
+		t.Fatalf("round trip failed: %+v ok=%v", got, found)
+	}
+}
+
+func TestParseWeights_ErrorsAreErrors(t *testing.T) {
+	if _, parseErr := ParseWeights("x,y"); parseErr == nil {
+		t.Fatal("expected error for non-numeric weights")
+	}
+}
diff --git a/internal/router/scoring.go b/internal/router/scoring.go
new file mode 100644
index 00000000..ae36947d
--- /dev/null
+++ b/internal/router/scoring.go
@@ -0,0 +1,244 @@
+package router
+
+import (
+ "errors"
+ "sort"
+ "time"
+)
+
+// ErrNoAcceptableCandidate indicates that no candidate survived filtering (or
+// none was supplied). Callers fall back to the first candidate, not fail.
+var ErrNoAcceptableCandidate = errors.New("no acceptable routing candidate")
+
+// candidateScore pairs a candidate with its computed score for sorting.
+type candidateScore struct {
+ candidate *ProviderCandidate // points into the caller's candidates slice
+ score float64 // lower is better; pickBest sorts ascending
+}
+
+// maxErrorRate is the default threshold: an ErrorRate at or above it is filtered.
+const maxErrorRate = 0.5
+
+// isCircuitOpen reports whether a candidate's CircuitState is "open" or
+// "half-open"; any other value, including the empty string, counts as healthy.
+func isCircuitOpen(c *ProviderCandidate) bool {
+ return c.CircuitState == "open" || c.CircuitState == "half-open"
+}
+
+// acceptable decides whether a candidate may be scored at all: it rejects one
+// whose circuit is open/half-open or whose ErrorRate is at or above maxErrRate.
+// An ErrorRate of 0 (unknown) stays below any positive threshold and is kept.
+func acceptable(c *ProviderCandidate, maxErrRate float64) bool {
+	blocked := isCircuitOpen(c)
+	// Inclusive bound: a rate exactly equal to the threshold is rejected.
+	tooFlaky := c.ErrorRate >= maxErrRate
+	if blocked || tooFlaky {
+		return false
+	}
+	return true
+}
+
+// normalizeLatency maps latency values into [0, 1] relative to min and max.
+// Lower latency → lower (better) score. Unknown latencies (<= 0, the same rule
+// minMaxLatency uses) always score 1.0 (worst) so complete-data providers win.
+// When min/max give no usable span, every known latency scores 0.
+func normalizeLatency(values []time.Duration, min, max time.Duration) []float64 {
+	scores := make([]float64, len(values))
+	// A usable span needs at least two distinct known latencies.
+	hasSpan := max > min && max != 0
+	span := float64(max - min)
+	for i, v := range values {
+		switch {
+		case v <= 0:
+			// Unknown (or nonsensical negative) latency must never look
+			// faster than a real measurement.
+			scores[i] = 1.0
+		case hasSpan:
+			scores[i] = float64(v-min) / span
+		default:
+			// All known latencies are equal: keep the zero score so the
+			// ranking falls to other signals or the name tiebreak.
+		}
+	}
+	return scores
+}
+
+// normalizeCost maps per-million-token costs into [0, 1] relative to min and
+// max. Lower cost → lower (better) score. Missing pricing (perMtokCost returns
+// -1; any negative value is treated that way) always scores 1.0 so complete-data
+// providers win, while a known zero cost (free model) is a real value.
+// When max does not exceed min, every known cost scores 0.
+func normalizeCost(values []float64, min, max float64) []float64 {
+	scores := make([]float64, len(values))
+	// flat: no spread between the cheapest and the priciest known cost.
+	flat := max <= min
+	span := max - min
+	for i, v := range values {
+		switch {
+		case v < 0:
+			// Unknown pricing is ranked worst regardless of the spread,
+			// so a candidate without a price never beats a priced one.
+			scores[i] = 1.0
+		case !flat:
+			scores[i] = (v - min) / span
+		default:
+			// Known cost with no spread: keep the zero score so the
+			// ranking falls to other signals or the name tiebreak.
+		}
+	}
+	return scores
+}
+
+// pickBest stably sorts scored in place and returns the lowest-scoring candidate;
+// equal scores go by ascending ProviderName. Empty input: ErrNoAcceptableCandidate.
+func pickBest(scored []candidateScore) (*ProviderCandidate, error) {
+	if len(scored) == 0 {
+		return nil, ErrNoAcceptableCandidate
+	}
+	less := func(a, b int) bool {
+		if scored[a].score == scored[b].score {
+			return scored[a].candidate.ProviderName < scored[b].candidate.ProviderName
+		}
+		return scored[a].score < scored[b].score
+	}
+	sort.SliceStable(scored, less)
+	return scored[0].candidate, nil
+}
+
+// firstAcceptable walks candidates in slice order and returns the first one that
+// passes acceptable; ErrNoAcceptableCandidate is returned when none does.
+func firstAcceptable(candidates []ProviderCandidate, maxErrRate float64) (*ProviderCandidate, error) {
+	for idx := 0; idx < len(candidates); idx++ {
+		if cand := &candidates[idx]; acceptable(cand, maxErrRate) {
+			return cand, nil
+		}
+	}
+	return nil, ErrNoAcceptableCandidate
+}
+
+// minMaxLatency returns the smallest and largest positive Latency across candidates (0, 0 if none).
+func minMaxLatency(candidates []ProviderCandidate) (min, max time.Duration) {
+	for i := range candidates {
+		lat := candidates[i].Latency
+		switch {
+		case lat <= 0:
+			continue
+		case min == 0 || lat < min:
+			min = lat
+		}
+		if lat > max {
+			max = lat
+		}
+	}
+	return min, max
+}
+
+// minMaxCost returns the min and max per-million-token cost across candidates.
+// Unknown costs (perMtokCost returns -1) are skipped. Known-zero costs
+// (free models) are included.
+func minMaxCost(candidates []ProviderCandidate) (min, max float64) {
+ var seen bool
+ for i := range candidates {
+ cost := perMtokCost(&candidates[i])
+ if cost < 0 {
+ continue
+ }
+ if !seen {
+ min = cost
+ max = cost
+ seen = true
+ continue
+ }
+ if cost < min {
+ min = cost
+ }
+ if cost > max {
+ max = cost
+ }
+ }
+ return min, max
+}
+
+// perMtokCost returns the sum of the known input and output per-million-token
+// prices. It returns -1 when no per-token price is known (nil candidate, nil
+// Pricing, or both per-token fields nil) so unknown pricing is never read as free.
+func perMtokCost(c *ProviderCandidate) float64 {
+	if c == nil || c.Pricing == nil || (c.Pricing.InputPerMtok == nil && c.Pricing.OutputPerMtok == nil) {
+		return -1
+	}
+	var cost float64
+	if c.Pricing.InputPerMtok != nil {
+		cost += *c.Pricing.InputPerMtok
+	}
+	if c.Pricing.OutputPerMtok != nil {
+		cost += *c.Pricing.OutputPerMtok
+	}
+	return cost
+}
+
+// hasLatencyData reports whether at least one candidate has a positive
+// Latency, i.e. whether any latency measurement is available for ranking.
+func hasLatencyData(candidates []ProviderCandidate) bool {
+	found := false
+	for idx := 0; idx < len(candidates) && !found; idx++ {
+		found = candidates[idx].Latency > 0
+	}
+	return found
+}
+
+// latencyValues returns each candidate's Latency, index-aligned with candidates.
+func latencyValues(candidates []ProviderCandidate) []time.Duration {
+	latencies := make([]time.Duration, 0, len(candidates))
+	for i := range candidates {
+		latencies = append(latencies, candidates[i].Latency)
+	}
+	return latencies
+}
+
+// hasPricingData reports whether at least one candidate has a non-nil Pricing
+// with any price field set (or a non-empty Tiers list). It inspects the fields
+// directly, so a free model (all prices 0) still counts as having pricing data.
+func hasPricingData(candidates []ProviderCandidate) bool {
+	for i := range candidates {
+		p := candidates[i].Pricing
+		if p == nil {
+			continue
+		}
+		perToken := p.InputPerMtok != nil || p.OutputPerMtok != nil ||
+			p.CachedInputPerMtok != nil || p.CacheWritePerMtok != nil
+		perUnit := p.PerRequest != nil || p.PerImage != nil || p.InputPerImage != nil ||
+			p.PerSecondInput != nil || p.PerSecondOutput != nil ||
+			p.PerCharacterInput != nil || p.PerPage != nil
+		if perToken || perUnit || len(p.Tiers) > 0 {
+			return true
+		}
+	}
+	return false
+}
+
+// costValues returns perMtokCost for each candidate, index-aligned with candidates.
+func costValues(candidates []ProviderCandidate) []float64 {
+	costs := make([]float64, 0, len(candidates))
+	for i := range candidates {
+		costs = append(costs, perMtokCost(&candidates[i]))
+	}
+	return costs
+}
+
+// rankBy drops unacceptable candidates and returns the survivor with the lowest
+// precomputed score. scores must be index-aligned with candidates (scores[i]
+// belongs to candidates[i]); single-factor strategies pass normalized cost or
+// latency, the balanced strategy passes a weighted sum. It returns
+// ErrNoAcceptableCandidate when nothing survives the filter.
+func rankBy(candidates []ProviderCandidate, scores []float64, maxErrRate float64) (*ProviderCandidate, error) {
+	survivors := make([]candidateScore, 0, len(candidates))
+	for i := range candidates {
+		cand := &candidates[i]
+		// Filter before reading scores[i]; rejected entries never reach pickBest.
+		if acceptable(cand, maxErrRate) {
+			survivors = append(survivors, candidateScore{candidate: cand, score: scores[i]})
+		}
+	}
+	// pickBest applies the ProviderName tiebreak among equal scores.
+	return pickBest(survivors)
+}
diff --git a/internal/router/strategy.go b/internal/router/strategy.go
new file mode 100644
index 00000000..71421421
--- /dev/null
+++ b/internal/router/strategy.go
@@ -0,0 +1,63 @@
+// Package router provides pluggable routing strategies that select the best
+// provider from a set of candidates serving the same model ID. Strategies are
+// cost-aware and/or latency-aware, fed by provider runtime statistics.
+package router
+
+import (
+ "context"
+ "time"
+
+ "gomodel/internal/core"
+)
+
+// ProviderCandidate describes one provider able to serve a given model, together
+// with the signals a strategy needs to score it. Nil Pricing and 0 Latency mean
+// "unknown" and are scored worst-case so providers with complete data are
+// preferred; "" CircuitState and 0 ErrorRate are treated as healthy.
+type ProviderCandidate struct {
+ // Provider is the underlying provider implementation (may be nil for
+ // strategy-only callers that just need the name/metadata).
+ Provider core.Provider
+
+ // ProviderName is the concrete configured instance name, e.g.
+ // "openai-primary". Used as a deterministic tiebreaker.
+ ProviderName string
+
+ // ProviderType is the provider type, e.g. "openai" or "azure".
+ ProviderType string
+
+ // ModelID is the model identifier this candidate serves.
+ ModelID string
+
+ // Pricing is the per-model pricing used by cost-aware strategies.
+ // nil means pricing is unknown.
+ Pricing *core.ModelPricing
+
+ // Latency is the provider's smoothed P50 latency. 0 means unknown.
+ Latency time.Duration
+
+ // CircuitState is "closed", "open", or "half-open". "" defaults to
+ // "closed" (healthy) when a provider has no breaker.
+ CircuitState string
+
+ // ErrorRate is the smoothed error ratio in [0, 1]. 0 means unknown; the
+ // error-rate filter keeps such candidates (0 is below a positive threshold).
+ ErrorRate float64
+}
+
+// RoutingStrategy selects the best provider candidate from a non-empty list.
+// Implementations must be deterministic: given identical candidates they must
+// return the same choice, breaking ties by ProviderName lexicographic order.
+//
+// Select returns an error only when no candidate is acceptable (e.g. every
+// candidate is filtered out); callers fall back to the first candidate in
+// that case rather than failing the request.
+type RoutingStrategy interface {
+ // Name returns the strategy identifier, e.g. "balanced", "cost_only".
+ Name() string
+
+ // Select picks one candidate. Callers pass a non-empty slice (the built-in
+ // strategies return ErrNoAcceptableCandidate for an empty one); an error
+ // lets the caller fall back to the first candidate and log a warning.
+ Select(ctx context.Context, candidates []ProviderCandidate) (*ProviderCandidate, error)
+}
diff --git a/internal/router/strategy_balanced.go b/internal/router/strategy_balanced.go
new file mode 100644
index 00000000..79854847
--- /dev/null
+++ b/internal/router/strategy_balanced.go
@@ -0,0 +1,61 @@
+package router
+
+import (
+ "context"
+)
+
+// BalancedStrategy scores candidates by a weighted combination of cost and
+// latency. Lower total score wins; ties break by ProviderName. Candidates whose
+// circuit is open or whose error rate is at/above MaxErrorRate are filtered out.
+//
+// Missing pricing scores 1.0 (worst) on cost; missing latency scores 1.0
+// (worst) on latency — so providers with complete data are preferred. When
+// every candidate misses the same dimension, that dimension contributes equally
+// to all and ranking falls to the other dimension.
+type BalancedStrategy struct {
+ // CostWeight in [0, 1]. Defaults to 0.6 only when both weights are zero.
+ CostWeight float64
+ // LatencyWeight in [0, 1]. Defaults to 0.4 only when both weights are zero.
+ LatencyWeight float64
+ // MaxErrorRate filters candidates at/above this ratio. 0 selects the default (0.5).
+ MaxErrorRate float64
+}
+
+// NewBalancedStrategy returns a balanced strategy preset to 0.6 cost weight,
+// 0.4 latency weight and the package-default error-rate threshold.
+func NewBalancedStrategy() *BalancedStrategy {
+	s := new(BalancedStrategy)
+	s.CostWeight, s.LatencyWeight = 0.6, 0.4
+	s.MaxErrorRate = maxErrorRate
+	return s
+}
+
+// Name returns "balanced", the id this strategy is registered under.
+func (s *BalancedStrategy) Name() string { return "balanced" }
+
+// Select picks the lowest combined-score acceptable candidate. The score is
+// CostWeight*normalizedCost + LatencyWeight*normalizedLatency. The context is
+// not read here, so a WeightsOverride stored on it has no effect by itself.
+func (s *BalancedStrategy) Select(_ context.Context, candidates []ProviderCandidate) (*ProviderCandidate, error) {
+	// Defaults apply only when both weights are zero; a lone zero is honored.
+	costW, latW := s.CostWeight, s.LatencyWeight
+	if costW == 0 && latW == 0 {
+		costW, latW = 0.6, 0.4
+	}
+	maxErr := s.MaxErrorRate
+	if maxErr == 0 {
+		maxErr = maxErrorRate
+	}
+
+	costMin, costMax := minMaxCost(candidates)
+	costScores := normalizeCost(costValues(candidates), costMin, costMax)
+	latMin, latMax := minMaxLatency(candidates)
+	latScores := normalizeLatency(latencyValues(candidates), latMin, latMax)
+
+	// Scores are computed for every candidate; rankBy filters afterwards.
+	combined := make([]float64, 0, len(candidates))
+	for i := range candidates {
+		combined = append(combined, costW*costScores[i]+latW*latScores[i])
+	}
+	return rankBy(candidates, combined, maxErr)
+}
diff --git a/internal/router/strategy_cost.go b/internal/router/strategy_cost.go
new file mode 100644
index 00000000..e8e75e3d
--- /dev/null
+++ b/internal/router/strategy_cost.go
@@ -0,0 +1,42 @@
+//nolint:dupl // mirrors strategy_latency.go: same RoutingStrategy shape, different dimension (cost vs latency); divergence is in the scorer/normalize/degrade calls.
+package router
+
+import (
+ "context"
+)
+
+// CostOnlyStrategy selects the cheapest acceptable candidate. When every
+// candidate lacks pricing data, it degrades to latency-only ranking so that a
+// decision is still made on available signals.
+type CostOnlyStrategy struct {
+ // MaxErrorRate filters candidates at/above this ratio. 0 selects the default (0.5).
+ MaxErrorRate float64
+}
+
+// NewCostOnlyStrategy returns a cost-only strategy whose MaxErrorRate is the package default (maxErrorRate).
+func NewCostOnlyStrategy() *CostOnlyStrategy {
+ return &CostOnlyStrategy{MaxErrorRate: maxErrorRate}
+}
+
+// Name returns "cost_only", the id this strategy is registered under.
+func (s *CostOnlyStrategy) Name() string { return "cost_only" }
+
+// Select picks the cheapest acceptable candidate. When no candidate carries
+// pricing it ranks by latency instead (without recursing into latency's own
+// cost-degradation path, which would loop when both dimensions are absent).
+func (s *CostOnlyStrategy) Select(ctx context.Context, candidates []ProviderCandidate) (*ProviderCandidate, error) {
+ maxErr := s.MaxErrorRate
+ if maxErr == 0 {
+ maxErr = maxErrorRate
+ }
+
+ if !hasPricingData(candidates) {
+ latMin, latMax := minMaxLatency(candidates)
+ scores := normalizeLatency(latencyValues(candidates), latMin, latMax)
+ return rankBy(candidates, scores, maxErr)
+ }
+
+ costMin, costMax := minMaxCost(candidates)
+ scores := normalizeCost(costValues(candidates), costMin, costMax)
+ return rankBy(candidates, scores, maxErr)
+}
diff --git a/internal/router/strategy_firstfit.go b/internal/router/strategy_firstfit.go
new file mode 100644
index 00000000..54c958c2
--- /dev/null
+++ b/internal/router/strategy_firstfit.go
@@ -0,0 +1,30 @@
+package router
+
+import (
+ "context"
+)
+
+// FirstFitStrategy returns the first acceptable candidate without scoring,
+// preserving the gateway's historical first-wins behaviour. It exists for
+// operators who want deterministic, config-order routing.
+type FirstFitStrategy struct {
+ // MaxErrorRate filters candidates at/above this ratio. 0 selects the default (0.5).
+ MaxErrorRate float64
+}
+
+// NewFirstFitStrategy returns a first-fit strategy whose MaxErrorRate is the package default (maxErrorRate).
+func NewFirstFitStrategy() *FirstFitStrategy {
+ return &FirstFitStrategy{MaxErrorRate: maxErrorRate}
+}
+
+// Name returns "first_fit", the id this strategy is registered under.
+func (s *FirstFitStrategy) Name() string { return "first_fit" }
+
+// Select returns the first candidate, in slice order, that passes the circuit
+// and error-rate filter; a zero MaxErrorRate falls back to the package default.
+func (s *FirstFitStrategy) Select(_ context.Context, candidates []ProviderCandidate) (*ProviderCandidate, error) {
+	if s.MaxErrorRate == 0 {
+		return firstAcceptable(candidates, maxErrorRate)
+	}
+	return firstAcceptable(candidates, s.MaxErrorRate)
+}
diff --git a/internal/router/strategy_latency.go b/internal/router/strategy_latency.go
new file mode 100644
index 00000000..78e1badf
--- /dev/null
+++ b/internal/router/strategy_latency.go
@@ -0,0 +1,42 @@
+//nolint:dupl // mirrors strategy_cost.go: same RoutingStrategy shape, different dimension (latency vs cost); divergence is in the scorer/normalize/degrade calls.
+package router
+
+import (
+ "context"
+)
+
+// LatencyOnlyStrategy selects the fastest acceptable candidate. When every
+// candidate lacks latency data, it degrades to cost-only ranking so that a
+// decision is still made on available signals.
+type LatencyOnlyStrategy struct {
+ // MaxErrorRate filters candidates at/above this ratio. 0 selects the default (0.5).
+ MaxErrorRate float64
+}
+
+// NewLatencyOnlyStrategy returns a latency-only strategy whose MaxErrorRate is the package default (maxErrorRate).
+func NewLatencyOnlyStrategy() *LatencyOnlyStrategy {
+ return &LatencyOnlyStrategy{MaxErrorRate: maxErrorRate}
+}
+
+// Name returns "latency_only", the id this strategy is registered under.
+func (s *LatencyOnlyStrategy) Name() string { return "latency_only" }
+
+// Select picks the lowest-latency acceptable candidate. When no candidate
+// carries latency it ranks by cost instead (without recursing into cost's own
+// latency-degradation path, which would loop when both dimensions are absent).
+func (s *LatencyOnlyStrategy) Select(ctx context.Context, candidates []ProviderCandidate) (*ProviderCandidate, error) {
+ maxErr := s.MaxErrorRate
+ if maxErr == 0 {
+ maxErr = maxErrorRate
+ }
+
+ if !hasLatencyData(candidates) {
+ costMin, costMax := minMaxCost(candidates)
+ scores := normalizeCost(costValues(candidates), costMin, costMax)
+ return rankBy(candidates, scores, maxErr)
+ }
+
+ latMin, latMax := minMaxLatency(candidates)
+ scores := normalizeLatency(latencyValues(candidates), latMin, latMax)
+ return rankBy(candidates, scores, maxErr)
+}
diff --git a/internal/router/strategy_test.go b/internal/router/strategy_test.go
new file mode 100644
index 00000000..db01ba70
--- /dev/null
+++ b/internal/router/strategy_test.go
@@ -0,0 +1,294 @@
+package router
+
+import (
+ "context"
+ "errors"
+ "testing"
+ "time"
+
+ "gomodel/internal/core"
+)
+
+func floatPtr(v float64) *float64 { return &v } // pointer to a copy of v, for optional price fields
+
+func pricing(input, output float64) *core.ModelPricing {
+	return &core.ModelPricing{OutputPerMtok: floatPtr(output), InputPerMtok: floatPtr(input)}
+}
+
+func candidates(cs ...ProviderCandidate) []ProviderCandidate { return cs } // variadic sugar for building a slice
+
+func mustSelect(t *testing.T, s RoutingStrategy, cands []ProviderCandidate) *ProviderCandidate {
+	t.Helper()
+	picked, err := s.Select(context.Background(), cands)
+	switch {
+	case err != nil:
+		t.Fatalf("%s.Select returned error: %v", s.Name(), err)
+	case picked == nil:
+		t.Fatalf("%s.Select returned nil candidate", s.Name())
+	}
+	return picked
+}
+
+func TestBalanced_PicksCheapestWithEqualLatency(t *testing.T) {
+	// Equal latency: the cheaper candidate (zeta) must win.
+	cands := candidates(
+		ProviderCandidate{ProviderName: "alpha", Pricing: pricing(5, 15), Latency: 100 * time.Millisecond},
+		ProviderCandidate{ProviderName: "zeta", Pricing: pricing(1, 4), Latency: 100 * time.Millisecond},
+	)
+	got := mustSelect(t, NewBalancedStrategy(), cands)
+	if got.ProviderName != "zeta" {
+		t.Fatalf("expected zeta (cheaper), got %s", got.ProviderName)
+	}
+}
+
+func TestBalanced_PicksFastestWithEqualCost(t *testing.T) {
+	// Equal cost: the faster candidate (zeta) must win.
+	cands := candidates(
+		ProviderCandidate{ProviderName: "alpha", Pricing: pricing(5, 5), Latency: 200 * time.Millisecond},
+		ProviderCandidate{ProviderName: "zeta", Pricing: pricing(5, 5), Latency: 50 * time.Millisecond},
+	)
+	got := mustSelect(t, NewBalancedStrategy(), cands)
+	if got.ProviderName != "zeta" {
+		t.Fatalf("expected zeta (faster), got %s", got.ProviderName)
+	}
+}
+
+func TestBalanced_TiebreaksByName(t *testing.T) {
+	// Identical cost and latency: only the name ordering can decide.
+	pool := candidates(
+		ProviderCandidate{ProviderName: "zeta", Pricing: pricing(5, 5), Latency: 100 * time.Millisecond},
+		ProviderCandidate{ProviderName: "alpha", Pricing: pricing(5, 5), Latency: 100 * time.Millisecond},
+	)
+	picked := mustSelect(t, NewBalancedStrategy(), pool)
+	if picked.ProviderName != "alpha" {
+		t.Fatalf("expected alpha (lexicographic tiebreak), got %s", picked.ProviderName)
+	}
+}
+
+func TestBalanced_FiltersOpenCircuit(t *testing.T) {
+	// "open" is cheaper and faster but its breaker is open, so it is excluded.
+	pool := candidates(
+		ProviderCandidate{ProviderName: "open", Pricing: pricing(1, 1), Latency: 10 * time.Millisecond, CircuitState: "open"},
+		ProviderCandidate{ProviderName: "healthy", Pricing: pricing(10, 10), Latency: 500 * time.Millisecond},
+	)
+	picked := mustSelect(t, NewBalancedStrategy(), pool)
+	if picked.ProviderName != "healthy" {
+		t.Fatalf("expected healthy (open circuit filtered), got %s", picked.ProviderName)
+	}
+}
+
+func TestBalanced_FiltersHighErrorRate(t *testing.T) {
+	// "flaky" is cheaper and faster but its 0.9 error rate exceeds the default 0.5.
+	pool := candidates(
+		ProviderCandidate{ProviderName: "flaky", Pricing: pricing(1, 1), Latency: 10 * time.Millisecond, ErrorRate: 0.9},
+		ProviderCandidate{ProviderName: "stable", Pricing: pricing(10, 10), Latency: 500 * time.Millisecond},
+	)
+	picked := mustSelect(t, NewBalancedStrategy(), pool)
+	if picked.ProviderName != "stable" {
+		t.Fatalf("expected stable (high error rate filtered), got %s", picked.ProviderName)
+	}
+}
+
+func TestBalanced_AllFilteredFallsBackWithError(t *testing.T) {
+	// The only candidate has an open circuit, so nothing is selectable.
+	only := candidates(
+		ProviderCandidate{ProviderName: "open", Pricing: pricing(1, 1), CircuitState: "open"},
+	)
+	_, err := NewBalancedStrategy().Select(context.Background(), only)
+	if !errors.Is(err, ErrNoAcceptableCandidate) {
+		t.Fatalf("expected ErrNoAcceptableCandidate, got %v", err)
+	}
+}
+
+func TestCostOnly_PicksCheapest(t *testing.T) {
+	// zeta is far slower but cheaper; cost_only must ignore latency.
+	cands := candidates(
+		ProviderCandidate{ProviderName: "alpha", Pricing: pricing(5, 15), Latency: 10 * time.Millisecond},
+		ProviderCandidate{ProviderName: "zeta", Pricing: pricing(1, 4), Latency: 500 * time.Millisecond},
+	)
+	got := mustSelect(t, NewCostOnlyStrategy(), cands)
+	if got.ProviderName != "zeta" {
+		t.Fatalf("expected zeta (cheapest), got %s", got.ProviderName)
+	}
+}
+
+func TestCostOnly_DegradesToLatencyWhenNoPricing(t *testing.T) {
+	// No candidate has pricing, so cost_only must rank by latency instead.
+	cands := candidates(
+		ProviderCandidate{ProviderName: "alpha", Latency: 200 * time.Millisecond},
+		ProviderCandidate{ProviderName: "zeta", Latency: 50 * time.Millisecond},
+	)
+	got := mustSelect(t, NewCostOnlyStrategy(), cands)
+	if got.ProviderName != "zeta" {
+		t.Fatalf("expected zeta (degraded to latency), got %s", got.ProviderName)
+	}
+}
+
+func TestLatencyOnly_PicksFastest(t *testing.T) {
+	// zeta is far pricier but faster; latency_only must ignore cost.
+	cands := candidates(
+		ProviderCandidate{ProviderName: "alpha", Pricing: pricing(1, 1), Latency: 300 * time.Millisecond},
+		ProviderCandidate{ProviderName: "zeta", Pricing: pricing(50, 50), Latency: 50 * time.Millisecond},
+	)
+	got := mustSelect(t, NewLatencyOnlyStrategy(), cands)
+	if got.ProviderName != "zeta" {
+		t.Fatalf("expected zeta (fastest), got %s", got.ProviderName)
+	}
+}
+
+func TestLatencyOnly_DegradesToCostWhenNoLatency(t *testing.T) {
+	// No candidate has latency data, so latency_only must rank by cost instead.
+	cands := candidates(
+		ProviderCandidate{ProviderName: "alpha", Pricing: pricing(5, 15)},
+		ProviderCandidate{ProviderName: "zeta", Pricing: pricing(1, 4)},
+	)
+	got := mustSelect(t, NewLatencyOnlyStrategy(), cands)
+	if got.ProviderName != "zeta" {
+		t.Fatalf("expected zeta (degraded to cost), got %s", got.ProviderName)
+	}
+}
+
+func TestFirstFit_ReturnsFirstAcceptable(t *testing.T) {
+	// "third" is cheapest and fastest, but first_fit takes the earliest acceptable entry.
+	pool := candidates(
+		ProviderCandidate{ProviderName: "open", Pricing: pricing(50, 50), Latency: 500 * time.Millisecond, CircuitState: "open"},
+		ProviderCandidate{ProviderName: "second", Pricing: pricing(50, 50), Latency: 500 * time.Millisecond},
+		ProviderCandidate{ProviderName: "third", Pricing: pricing(1, 1), Latency: 10 * time.Millisecond},
+	)
+	picked := mustSelect(t, NewFirstFitStrategy(), pool)
+	if picked.ProviderName != "second" {
+		t.Fatalf("expected second (first acceptable), got %s", picked.ProviderName)
+	}
+}
+
+func TestBalanced_FreeModelNotTreatedAsUnknown(t *testing.T) {
+	// free has explicit $0 pricing; noPrice has no pricing at all. With equal
+	// latency the known-free candidate must beat the unknown one.
+	free := ProviderCandidate{ProviderName: "free", Pricing: pricing(0, 0), Latency: 100 * time.Millisecond}
+	noPrice := ProviderCandidate{ProviderName: "unknown", Latency: 100 * time.Millisecond}
+	picked := mustSelect(t, NewBalancedStrategy(), candidates(free, noPrice))
+	if picked.ProviderName != "free" {
+		t.Fatalf("expected free (known zero cost, score 0), got %s", picked.ProviderName)
+	}
+}
+
+func TestBalanced_ZeroValuesDefaultToSixtyForty(t *testing.T) {
+	// A zero-value BalancedStrategy should default to 0.6 cost / 0.4 latency.
+	pool := candidates(
+		ProviderCandidate{ProviderName: "alpha", Pricing: pricing(5, 15), Latency: 100 * time.Millisecond},
+		ProviderCandidate{ProviderName: "zeta", Pricing: pricing(1, 4), Latency: 100 * time.Millisecond},
+	)
+	picked := mustSelect(t, &BalancedStrategy{}, pool)
+	if picked.ProviderName != "zeta" {
+		t.Fatalf("expected zeta (cheaper, cost weight dominates at 0.6/0.4), got %s", picked.ProviderName)
+	}
+}
+
+func TestBalanced_EqualCostKnownVsUnknown(t *testing.T) {
+ s := NewBalancedStrategy()
+ // Same latency; "known" is priced at $5 in / $5 out while "unknown" has no
+ // pricing at all, so it gets cost score 1.0 and "known" gets cost score 0.
+ known := ProviderCandidate{ProviderName: "known", Pricing: pricing(5, 5), Latency: 100 * time.Millisecond}
+ unknown := ProviderCandidate{ProviderName: "unknown", Latency: 100 * time.Millisecond}
+ got := mustSelect(t, s, candidates(known, unknown))
+ if got.ProviderName != "known" {
+ t.Fatalf("expected known (complete data preferred over unknown), got %s", got.ProviderName)
+ }
+}
+
+// TestStrategies_HandleEmpty calls Select with nil candidates on each of the four
+// listed strategies and requires every one to report ErrNoAcceptableCandidate.
+func TestStrategies_HandleEmpty(t *testing.T) {
+	for _, strategy := range []RoutingStrategy{
+		NewBalancedStrategy(),
+		NewCostOnlyStrategy(),
+		NewLatencyOnlyStrategy(),
+		NewFirstFitStrategy(),
+	} {
+		if _, err := strategy.Select(context.Background(), nil); !errors.Is(err, ErrNoAcceptableCandidate) {
+			t.Errorf("%s on nil candidates: expected ErrNoAcceptableCandidate, got %v", strategy.Name(), err)
+		}
+	}
+}
+
+// --- Boundary tests ---
+
+// TestBalanced_ErrorRateAtBoundary expects a cheap, fast candidate whose ErrorRate
+// is exactly 0.5 to lose to a pricier, slower candidate with no ErrorRate set.
+func TestBalanced_ErrorRateAtBoundary(t *testing.T) {
+	borderline := ProviderCandidate{ProviderName: "at-boundary", Pricing: pricing(1, 1), Latency: 10 * time.Millisecond, ErrorRate: 0.5}
+	clean := ProviderCandidate{ProviderName: "healthy", Pricing: pricing(10, 10), Latency: 500 * time.Millisecond}
+	winner := mustSelect(t, NewBalancedStrategy(), candidates(borderline, clean))
+	if winner.ProviderName != "healthy" {
+		t.Fatalf("expected healthy (at-boundary filtered), got %s", winner.ProviderName)
+	}
+}
+
+// TestBalanced_ErrorRateBelowBoundary expects a cheap, fast candidate with
+// ErrorRate 0.499 to remain eligible and beat a pricier, slower candidate.
+func TestBalanced_ErrorRateBelowBoundary(t *testing.T) {
+	nearLimit := ProviderCandidate{ProviderName: "below", Pricing: pricing(1, 1), Latency: 10 * time.Millisecond, ErrorRate: 0.499}
+	costly := ProviderCandidate{ProviderName: "pricey", Pricing: pricing(10, 10), Latency: 500 * time.Millisecond}
+	if name := mustSelect(t, NewBalancedStrategy(), candidates(nearLimit, costly)).ProviderName; name != "below" {
+		t.Fatalf("expected below (0.499 < 0.5 is acceptable), got %s", name)
+	}
+}
+
+func TestBalanced_EqualScoresTiebreak(t *testing.T) {
+ s := NewBalancedStrategy()
+ cands := candidates(
+ ProviderCandidate{ProviderName: "zzz", Latency: 100 * time.Millisecond},
+ ProviderCandidate{ProviderName: "aaa", Latency: 100 * time.Millisecond},
+ )
+ got := mustSelect(t, s, cands)
+ if got.ProviderName != "aaa" {
+ t.Fatalf("expected aaa (tiebreak), got %s", got.ProviderName)
+ }
+}
+
+// TestFirstFit_SingleCandidate checks that first-fit returns the only candidate
+// when the list holds exactly one entry with no circuit state or error rate set.
+func TestFirstFit_SingleCandidate(t *testing.T) {
+	strategy := NewFirstFitStrategy()
+	only := ProviderCandidate{ProviderName: "solo", Pricing: pricing(1, 1)}
+	picked := mustSelect(t, strategy, candidates(only))
+	if picked.ProviderName != "solo" {
+		t.Fatalf("expected solo, got %s", picked.ProviderName)
+	}
+}
+
+// TestFirstFit_AllFilteredReturnsError passes only candidates whose circuit is
+// "open" or "half-open" and requires Select to fail with ErrNoAcceptableCandidate.
+func TestFirstFit_AllFilteredReturnsError(t *testing.T) {
+	unavailable := candidates(
+		ProviderCandidate{ProviderName: "open", CircuitState: "open"},
+		ProviderCandidate{ProviderName: "half", CircuitState: "half-open"},
+	)
+	if _, err := NewFirstFitStrategy().Select(context.Background(), unavailable); !errors.Is(err, ErrNoAcceptableCandidate) {
+		t.Fatalf("expected ErrNoAcceptableCandidate, got %v", err)
+	}
+}
+
+func TestCostOnly_AllAcceptableEqualCost(t *testing.T) {
+ s := NewCostOnlyStrategy()
+ cands := candidates(
+ ProviderCandidate{ProviderName: "beta", Pricing: pricing(5, 5)},
+ ProviderCandidate{ProviderName: "alpha", Pricing: pricing(5, 5)},
+ )
+ got := mustSelect(t, s, cands)
+ if got.ProviderName != "alpha" {
+ t.Fatalf("expected alpha (tiebreak equal cost), got %s", got.ProviderName)
+ }
+}
+
+// TestLatencyOnly_AllAcceptableEqualLatency gives latency-only two candidates with
+// the same latency, "zeta" listed before "alpha", and expects "alpha" back.
+func TestLatencyOnly_AllAcceptableEqualLatency(t *testing.T) {
+	tied := candidates(
+		ProviderCandidate{ProviderName: "zeta", Latency: 100 * time.Millisecond},
+		ProviderCandidate{ProviderName: "alpha", Latency: 100 * time.Millisecond},
+	)
+	if name := mustSelect(t, NewLatencyOnlyStrategy(), tied).ProviderName; name != "alpha" {
+		t.Fatalf("expected alpha (tiebreak equal latency), got %s", name)
+	}
+}
diff --git a/internal/server/http.go b/internal/server/http.go
index 2603c1cc..6b5cb3aa 100644
--- a/internal/server/http.go
+++ b/internal/server/http.go
@@ -287,6 +287,10 @@ func New(provider core.RoutableProvider, cfg *Config) *Server {
e.Use(AuthMiddlewareWithAuthenticator(cfg.MasterKey, cfg.Authenticator, authSkipPaths, userPathHeaderName))
}
+ // Per-request intelligent-routing overrides (X-GoModel-Routing-Strategy /
+ // X-GoModel-Routing-Weights). A no-op header-lookup when the headers are absent.
+ e.Use(RoutingStrategyCapture())
+
// Workflow resolution resolves the request-scoped workflow after auth so
// managed auth key user-path overrides are visible to policy resolution while
// still keeping workflow resolution failures loggable through the audit middleware.
diff --git a/internal/server/routing_strategy.go b/internal/server/routing_strategy.go
new file mode 100644
index 00000000..1932b8fa
--- /dev/null
+++ b/internal/server/routing_strategy.go
@@ -0,0 +1,48 @@
+package server
+
+import (
+ "log/slog"
+ "strings"
+
+ router "gomodel/internal/router"
+
+ "github.com/labstack/echo/v5"
+)
+
+const (
+	// RoutingStrategyHeader is the request header naming a per-request routing-strategy override.
+	RoutingStrategyHeader = "X-GoModel-Routing-Strategy"
+	// RoutingWeightsHeader is the request header holding balanced-strategy weights as "cost,latency" (e.g. "0.6,0.4").
+	RoutingWeightsHeader = "X-GoModel-Routing-Weights"
+)
+
+// RoutingStrategyCapture returns middleware that attaches the optional routing
+// strategy/weights headers to the request context as overrides for the router.
+// The trimmed strategy is forwarded without validation here; weights that fail
+// router.ParseWeights are logged at warn level and skipped.
+func RoutingStrategyCapture() echo.MiddlewareFunc {
+	return func(next echo.HandlerFunc) echo.HandlerFunc {
+		return func(c *echo.Context) error {
+			req := c.Request()
+			// Track changes with a flag rather than ctx != req.Context(): comparing
+			// interface values panics when their dynamic type is not comparable.
+			ctx, changed := req.Context(), false
+			if strategy := strings.TrimSpace(req.Header.Get(RoutingStrategyHeader)); strategy != "" {
+				ctx = router.WithStrategyOverride(ctx, strategy)
+				changed = true
+			}
+			if weights := strings.TrimSpace(req.Header.Get(RoutingWeightsHeader)); weights != "" {
+				if override, err := router.ParseWeights(weights); err != nil {
+					slog.WarnContext(ctx, "invalid routing weights header, ignoring", "header", weights, "error", err)
+				} else {
+					ctx = router.WithWeightsOverride(ctx, override)
+					changed = true
+				}
+			}
+			if changed {
+				c.SetRequest(req.WithContext(ctx))
+			}
+			return next(c)
+		}
+	}
+}