-
Notifications
You must be signed in to change notification settings - Fork 3
feat: add circuit breaker for upstream provider overload protection #75
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
7700a8f
aad288c
47253f1
8cf2d18
8e44145
7af3bc1
521df9b
c85b836
e446954
1d2315e
6994f89
6a7d578
b0ff0eb
bee7a4d
98c7b7a
7733266
7c7c85b
8943ef0
7d2dcb1
e3438f4
a32f246
e929098
ab08de4
161db92
33ea4ae
dbfab23
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -9,9 +9,9 @@ import ( | |
|
|
||
| "cdr.dev/slog" | ||
| "github.com/coder/aibridge/mcp" | ||
| "go.opentelemetry.io/otel/trace" | ||
|
|
||
| "github.com/hashicorp/go-multierror" | ||
| "github.com/sony/gobreaker/v2" | ||
| "go.opentelemetry.io/otel/trace" | ||
| ) | ||
|
|
||
| // RequestBridge is an [http.Handler] which is capable of masquerading as AI providers' APIs; | ||
|
|
@@ -48,13 +48,33 @@ var _ http.Handler = &RequestBridge{} | |
| // A [Recorder] is also required to record prompt, tool, and token use. | ||
| // | ||
| // mcpProxy will be closed when the [RequestBridge] is closed. | ||
| // | ||
| // Circuit breaker configuration is obtained from each provider's CircuitBreakerConfig() method. | ||
| // Providers returning nil will not have circuit breaker protection. | ||
| func NewRequestBridge(ctx context.Context, providers []Provider, recorder Recorder, mcpProxy mcp.ServerProxier, logger slog.Logger, metrics *Metrics, tracer trace.Tracer) (*RequestBridge, error) { | ||
| mux := http.NewServeMux() | ||
|
|
||
| for _, provider := range providers { | ||
| // Create per-provider circuit breaker if configured | ||
| cfg := provider.CircuitBreakerConfig() | ||
| onChange := func(endpoint string, from, to gobreaker.State) {} | ||
|
|
||
| if cfg != nil && metrics != nil { | ||
| onChange = func(endpoint string, from, to gobreaker.State) { | ||
| metrics.CircuitBreakerState.WithLabelValues(provider.Name(), endpoint).Set(stateToGaugeValue(to)) | ||
| if to == gobreaker.StateOpen { | ||
| metrics.CircuitBreakerTrips.WithLabelValues(provider.Name(), endpoint).Inc() | ||
| } | ||
| } | ||
| } | ||
| cbs := NewProviderCircuitBreakers(provider.Name(), cfg, onChange) | ||
|
|
||
| // Add the known provider-specific routes which are bridged (i.e. intercepted and augmented). | ||
| for _, path := range provider.BridgedRoutes() { | ||
| mux.HandleFunc(path, newInterceptionProcessor(provider, recorder, mcpProxy, logger, metrics, tracer)) | ||
| handler := newInterceptionProcessor(provider, recorder, mcpProxy, logger, metrics, tracer) | ||
| // Wrap with circuit breaker middleware (nil cbs passes through) | ||
| wrapped := CircuitBreakerMiddleware(cbs, metrics)(handler) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The PR description is now stale, btw. |
||
| mux.Handle(path, wrapped) | ||
| } | ||
|
|
||
| // Any requests which passthrough to this will be reverse-proxied to the upstream. | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,188 @@ | ||
| package aibridge | ||
kacpersaw marked this conversation as resolved.
Show resolved
Hide resolved
kacpersaw marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| import ( | ||
| "errors" | ||
| "fmt" | ||
| "net/http" | ||
| "strings" | ||
| "sync" | ||
| "time" | ||
|
|
||
| "github.com/sony/gobreaker/v2" | ||
| ) | ||
|
|
||
| // CircuitBreakerConfig holds configuration for circuit breakers. | ||
| // Fields match gobreaker.Settings for clarity. | ||
| type CircuitBreakerConfig struct { | ||
| // MaxRequests is the maximum number of requests allowed in half-open state. | ||
| MaxRequests uint32 | ||
| // Interval is the cyclic period of the closed state for clearing internal counts. | ||
| Interval time.Duration | ||
| // Timeout is how long the circuit stays open before transitioning to half-open. | ||
| Timeout time.Duration | ||
| // FailureThreshold is the number of consecutive failures that triggers the circuit to open. | ||
| FailureThreshold uint32 | ||
| // IsFailure determines if a status code should count as a failure. | ||
| // If nil, defaults to 429, 503, and 529 (Anthropic overloaded). | ||
| IsFailure func(statusCode int) bool | ||
| } | ||
|
|
||
| // DefaultCircuitBreakerConfig returns sensible defaults for circuit breaker configuration. | ||
| func DefaultCircuitBreakerConfig() CircuitBreakerConfig { | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not used, right? |
||
| return CircuitBreakerConfig{ | ||
| FailureThreshold: 5, | ||
| Interval: 10 * time.Second, | ||
| Timeout: 30 * time.Second, | ||
| MaxRequests: 3, | ||
| IsFailure: DefaultIsFailure, | ||
| } | ||
| } | ||
|
|
||
| // DefaultIsFailure returns true for status codes that typically indicate | ||
| // upstream overload: 429 (Too Many Requests), 503 (Service Unavailable), | ||
| // and 529 (Anthropic Overloaded). | ||
| func DefaultIsFailure(statusCode int) bool { | ||
| switch statusCode { | ||
| case http.StatusTooManyRequests, // 429 | ||
| http.StatusServiceUnavailable, // 503 | ||
| 529: // Anthropic "Overloaded" | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I had a comment here before about this not being provider-specific; not sure what happened to it. |
||
| return true | ||
| default: | ||
| return false | ||
| } | ||
| } | ||
|
|
||
| // ProviderCircuitBreakers manages per-endpoint circuit breakers for a single provider. | ||
| type ProviderCircuitBreakers struct { | ||
| provider string | ||
| config CircuitBreakerConfig | ||
| breakers sync.Map // endpoint -> *gobreaker.CircuitBreaker[struct{}] | ||
| onChange func(endpoint string, from, to gobreaker.State) | ||
| } | ||
|
|
||
| // NewProviderCircuitBreakers creates circuit breakers for a single provider. | ||
| // Returns nil if config is nil (no circuit breaker protection). | ||
| func NewProviderCircuitBreakers(provider string, config *CircuitBreakerConfig, onChange func(endpoint string, from, to gobreaker.State)) *ProviderCircuitBreakers { | ||
| if config == nil { | ||
| return nil | ||
| } | ||
| if config.IsFailure == nil { | ||
| config.IsFailure = DefaultIsFailure | ||
| } | ||
| return &ProviderCircuitBreakers{ | ||
| provider: provider, | ||
| config: *config, | ||
| onChange: onChange, | ||
| } | ||
| } | ||
|
|
||
| // Get returns the circuit breaker for an endpoint, creating it if needed. | ||
| func (p *ProviderCircuitBreakers) Get(endpoint string) *gobreaker.CircuitBreaker[struct{}] { | ||
| if v, ok := p.breakers.Load(endpoint); ok { | ||
| return v.(*gobreaker.CircuitBreaker[struct{}]) | ||
| } | ||
|
|
||
| settings := gobreaker.Settings{ | ||
| Name: p.provider + ":" + endpoint, | ||
| MaxRequests: p.config.MaxRequests, | ||
| Interval: p.config.Interval, | ||
| Timeout: p.config.Timeout, | ||
| ReadyToTrip: func(counts gobreaker.Counts) bool { | ||
| return counts.ConsecutiveFailures >= p.config.FailureThreshold | ||
| }, | ||
| OnStateChange: func(_ string, from, to gobreaker.State) { | ||
| if p.onChange != nil { | ||
| p.onChange(endpoint, from, to) | ||
| } | ||
| }, | ||
| } | ||
|
|
||
| cb := gobreaker.NewCircuitBreaker[struct{}](settings) | ||
| actual, _ := p.breakers.LoadOrStore(endpoint, cb) | ||
| return actual.(*gobreaker.CircuitBreaker[struct{}]) | ||
| } | ||
|
|
||
| // statusCapturingWriter wraps http.ResponseWriter to capture the status code. | ||
| // It also implements http.Flusher to support streaming responses. | ||
| type statusCapturingWriter struct { | ||
| http.ResponseWriter | ||
| statusCode int | ||
| headerWritten bool | ||
| } | ||
|
|
||
| func (w *statusCapturingWriter) WriteHeader(code int) { | ||
| if !w.headerWritten { | ||
| w.statusCode = code | ||
| w.headerWritten = true | ||
| } | ||
| w.ResponseWriter.WriteHeader(code) | ||
| } | ||
|
|
||
| func (w *statusCapturingWriter) Write(b []byte) (int, error) { | ||
| if !w.headerWritten { | ||
| w.statusCode = http.StatusOK | ||
| w.headerWritten = true | ||
| } | ||
| return w.ResponseWriter.Write(b) | ||
| } | ||
|
|
||
| func (w *statusCapturingWriter) Flush() { | ||
| if f, ok := w.ResponseWriter.(http.Flusher); ok { | ||
| f.Flush() | ||
| } | ||
| } | ||
|
|
||
| // Unwrap returns the underlying ResponseWriter for interface checks. | ||
| func (w *statusCapturingWriter) Unwrap() http.ResponseWriter { | ||
| return w.ResponseWriter | ||
| } | ||
|
|
||
| // CircuitBreakerMiddleware returns middleware that wraps handlers with circuit breaker protection. | ||
| // It captures the response status code to determine success/failure without provider-specific logic. | ||
| // If cbs is nil, requests pass through without circuit breaker protection. | ||
| func CircuitBreakerMiddleware(cbs *ProviderCircuitBreakers, metrics *Metrics) func(http.Handler) http.Handler { | ||
| return func(next http.Handler) http.Handler { | ||
| // No circuit breaker configured - pass through | ||
| if cbs == nil { | ||
| return next | ||
| } | ||
|
|
||
| return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
| endpoint := strings.TrimPrefix(r.URL.Path, "/"+cbs.provider) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If two providers supported the same endpoint, you'd get one provider influencing another.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This was changed so |
||
| cb := cbs.Get(endpoint) | ||
|
|
||
| // Wrap response writer to capture status code | ||
| sw := &statusCapturingWriter{ResponseWriter: w, statusCode: http.StatusOK} | ||
|
|
||
| _, err := cb.Execute(func() (struct{}, error) { | ||
| next.ServeHTTP(sw, r) | ||
| if cbs.config.IsFailure(sw.statusCode) { | ||
| return struct{}{}, fmt.Errorf("upstream error: %d", sw.statusCode) | ||
| } | ||
| return struct{}{}, nil | ||
| }) | ||
|
|
||
| if errors.Is(err, gobreaker.ErrOpenState) || errors.Is(err, gobreaker.ErrTooManyRequests) { | ||
| if metrics != nil { | ||
| metrics.CircuitBreakerRejects.WithLabelValues(cbs.provider, endpoint).Inc() | ||
| } | ||
| http.Error(w, "circuit breaker is open", http.StatusServiceUnavailable) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Requests expecting JSON responses will be broken by plaintext responses. |
||
| } | ||
| }) | ||
| } | ||
| } | ||
|
|
||
| // stateToGaugeValue converts gobreaker.State to a gauge value. | ||
| // closed=0, half-open=0.5, open=1 | ||
| func stateToGaugeValue(s gobreaker.State) float64 { | ||
| switch s { | ||
| case gobreaker.StateClosed: | ||
| return 0 | ||
| case gobreaker.StateHalfOpen: | ||
| return 0.5 | ||
| case gobreaker.StateOpen: | ||
| return 1 | ||
| default: | ||
| return 0 | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,109 @@ | ||
| package aibridge_test | ||
|
|
||
| import ( | ||
| "context" | ||
| "io" | ||
| "net" | ||
| "net/http" | ||
| "net/http/httptest" | ||
| "strings" | ||
| "sync/atomic" | ||
| "testing" | ||
| "time" | ||
|
|
||
| "cdr.dev/slog" | ||
| "cdr.dev/slog/sloggers/slogtest" | ||
| "github.com/coder/aibridge" | ||
| "github.com/coder/aibridge/mcp" | ||
| "github.com/prometheus/client_golang/prometheus" | ||
| promtest "github.com/prometheus/client_golang/prometheus/testutil" | ||
| "github.com/stretchr/testify/assert" | ||
| "github.com/stretchr/testify/require" | ||
| "go.opentelemetry.io/otel" | ||
| ) | ||
|
|
||
| func TestCircuitBreaker_WithNewRequestBridge(t *testing.T) { | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is just testing Anthropic; we should support all providers. |
||
| t.Parallel() | ||
|
|
||
| var upstreamCalls atomic.Int32 | ||
|
|
||
| // Mock upstream that returns 429 in Anthropic error format. | ||
| // x-should-retry: false is required to disable SDK automatic retries (default MaxRetries=2). | ||
| mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
| upstreamCalls.Add(1) | ||
| w.Header().Set("Content-Type", "application/json") | ||
| w.Header().Set("x-should-retry", "false") | ||
| w.WriteHeader(http.StatusTooManyRequests) | ||
| _, _ = w.Write([]byte(`{"type":"error","error":{"type":"rate_limit_error","message":"rate limited"}}`)) | ||
| })) | ||
| defer mockUpstream.Close() | ||
|
|
||
| metrics := aibridge.NewMetrics(prometheus.NewRegistry()) | ||
|
|
||
| // Create provider with circuit breaker config | ||
| provider := aibridge.NewAnthropicProvider(aibridge.AnthropicConfig{ | ||
| BaseURL: mockUpstream.URL, | ||
| Key: "test-key", | ||
| CircuitBreaker: &aibridge.CircuitBreakerConfig{ | ||
| FailureThreshold: 2, | ||
| Interval: time.Minute, | ||
| Timeout: 50 * time.Millisecond, | ||
| MaxRequests: 1, | ||
| }, | ||
| }, nil) | ||
|
|
||
| ctx := t.Context() | ||
| tracer := otel.Tracer("forTesting") | ||
| logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug) | ||
| bridge, err := aibridge.NewRequestBridge(ctx, | ||
| []aibridge.Provider{provider}, | ||
| &mockRecorderClient{}, | ||
| mcp.NewServerProxyManager(nil, tracer), | ||
| logger, | ||
| metrics, | ||
| tracer, | ||
| ) | ||
| require.NoError(t, err) | ||
|
|
||
| mockSrv := httptest.NewUnstartedServer(bridge) | ||
| t.Cleanup(mockSrv.Close) | ||
| mockSrv.Config.BaseContext = func(_ net.Listener) context.Context { | ||
| return aibridge.AsActor(ctx, "test-user-id", nil) | ||
| } | ||
| mockSrv.Start() | ||
|
|
||
| makeRequest := func() *http.Response { | ||
| body := `{"model":"claude-sonnet-4-20250514","max_tokens":1024,"messages":[{"role":"user","content":"hi"}]}` | ||
| req, _ := http.NewRequest("POST", mockSrv.URL+"/anthropic/v1/messages", strings.NewReader(body)) | ||
| req.Header.Set("Content-Type", "application/json") | ||
| req.Header.Set("x-api-key", "test") | ||
| req.Header.Set("anthropic-version", "2023-06-01") | ||
| resp, err := http.DefaultClient.Do(req) | ||
| require.NoError(t, err) | ||
| _, _ = io.ReadAll(resp.Body) | ||
| resp.Body.Close() | ||
| return resp | ||
| } | ||
|
|
||
| // First 2 requests hit upstream, get 429 | ||
| for i := 0; i < 2; i++ { | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nit: reference the config ( |
||
| resp := makeRequest() | ||
| assert.Equal(t, http.StatusTooManyRequests, resp.StatusCode) | ||
| } | ||
| assert.Equal(t, int32(2), upstreamCalls.Load()) | ||
|
|
||
| // Third request should be blocked by circuit breaker | ||
| resp := makeRequest() | ||
| assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode) | ||
| assert.Equal(t, int32(2), upstreamCalls.Load()) // No new upstream call | ||
|
|
||
| // Verify metrics were recorded via NewRequestBridge's onChange callback | ||
| trips := promtest.ToFloat64(metrics.CircuitBreakerTrips.WithLabelValues(aibridge.ProviderAnthropic, "/v1/messages")) | ||
| assert.Equal(t, 1.0, trips, "CircuitBreakerTrips should be 1") | ||
|
|
||
| state := promtest.ToFloat64(metrics.CircuitBreakerState.WithLabelValues(aibridge.ProviderAnthropic, "/v1/messages")) | ||
| assert.Equal(t, 1.0, state, "CircuitBreakerState should be 1 (open)") | ||
|
|
||
| rejects := promtest.ToFloat64(metrics.CircuitBreakerRejects.WithLabelValues(aibridge.ProviderAnthropic, "/v1/messages")) | ||
| assert.Equal(t, 1.0, rejects, "CircuitBreakerRejects should be 1") | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is just testing that a circuit-breaker will close, but not open again, so this isn't a comprehensive integration test. |
||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should always log when the CB changes state.
You can move the check for
cfg != nil && metrics != nilintoonChangeand only update metrics if needed.