Skip to content

Commit e8585f1

Browse files
kacpersawpawbanadannykopping
authored
feat: add circuit breaker for upstream provider overload protection (#75)
* feat: add circuit breaker for upstream provider overload protection Implement per-provider circuit breakers that detect upstream rate limiting (429/503/529 status codes) and temporarily stop sending requests when providers are overloaded. Key features: - Per-provider circuit breakers (Anthropic, OpenAI) - Configurable failure threshold, time window, and cooldown period - Half-open state allows gradual recovery testing - Prometheus metrics for monitoring (state gauge, trips counter, rejects counter) - Thread-safe implementation with proper state machine transitions - Disabled by default for backward compatibility Circuit breaker states: - Closed: normal operation, tracking failures within sliding window - Open: all requests rejected with 503, waiting for cooldown - Half-Open: limited requests allowed to test if upstream recovered Status codes that trigger circuit breaker: - 429 Too Many Requests - 503 Service Unavailable - 529 Anthropic Overloaded Relates to: coder/internal#1153 * chore: apply make fmt * refactor: use sony/gobreaker for circuit breakers with per-endpoint isolation - Replace custom circuit breaker implementation with sony/gobreaker - Change from per-provider to per-endpoint circuit breakers (e.g., OpenAI chat completions failing won't block responses API) - Simplify API: CircuitBreakers manages all breakers internally - Update metrics to include endpoint label - Simplify tests to focus on key behaviors Based on PR review feedback suggesting use of established library and per-endpoint granularity for better fault isolation. * refactor: align CircuitBreakerConfig fields with gobreaker.Settings Rename fields to match gobreaker naming convention: - Window -> Interval - Cooldown -> Timeout - HalfOpenMaxRequests -> MaxRequests - FailureThreshold type int64 -> uint32 * refactor: remove CircuitState, use gobreaker.State directly * refactor: implement circuit breaker as middleware with per-provider configs Address PR review feedback: 1. Middleware pattern - Circuit breaker is now HTTP middleware that wraps handlers, capturing response status codes directly instead of extracting from provider-specific error types. 2. Per-provider configs - NewCircuitBreakers takes map[string]CircuitBreakerConfig keyed by provider name. Providers not in the map have no circuit breaker. 3. Remove provider overfitting - Deleted extractStatusCodeFromError() which hardcoded AnthropicErrorResponse and OpenAIErrorResponse types. Middleware now uses statusCapturingWriter to inspect actual HTTP response codes. 4. Configurable failure detection - IsFailure func in config allows providers to define custom status codes as failures. Defaults to 429/503/529. 5. Fix gauge values - State gauge now uses 0 (closed), 0.5 (half-open), 1 (open) 6. Integration tests - Replaced unit tests with httptest-based integration tests that verify actual behavior: upstream errors trip circuit, requests get blocked, recovery after timeout, per-endpoint isolation. 7. Error message - Changed from 'upstream rate limiting' to 'circuit breaker is open' * docs: clarify noop behavior when provider not configured * Update go.mod * fix: update metrics help text to reflect 0/0.5/1 gauge values * refactor: add CircuitBreaker interface with NoopCircuitBreaker - Add CircuitBreaker interface with Allow(), RecordSuccess(), RecordFailure() - Add NoopCircuitBreaker struct for providers without circuit breaker config - Add gobreakerCircuitBreaker wrapping sony/gobreaker implementation - CircuitBreakers.Get() returns NoopCircuitBreaker when provider not configured - Add http.Flusher support to statusCapturingWriter for SSE streaming - Add Unwrap() for ResponseWriter interface detection * refactor: use gobreaker Execute for proper half-open rejection handling - Changed CircuitBreaker interface to Execute(fn func() int) (statusCode, rejected) - Use gobreaker.Execute() to properly handle both ErrOpenState and ErrTooManyRequests - NoopCircuitBreaker.Execute simply runs the function and returns not rejected - Simplified middleware by removing separate Allow/Record pattern * refactor: remove unused circuitBreakers field and getter from RequestBridge * use per-provider maps for endpoints * make fmt * use mux.Handle for cb middleware * Move CircuitBreakerConfig to the Provider struct * Update tests * default noop func for onChange * create CircuitBreakers per Provider instead of a global one and remove gobreakerCircuitBraker along with the interface and noop struct * Update bridge.go Co-authored-by: Paweł Banaszewski <pawel@coder.com> * fix format * Apply review suggestions * Apply review suggestions and add proper integration tests * Add test to check circuit breaker config * Remove test * Remove TestCircuitBreaker_HalfOpenAndRecovery * Apply review suggestions * Apply review suggestions * Fix test * Add TestCircuitBreaker_HalfOpenMaxRequests test and add Retry-After header * Apply review suggestions * Apply review suggestions * Update provider/anthropic.go Co-authored-by: Danny Kopping <danny@coder.com> * Fix tests * Fix fmt --------- Co-authored-by: Paweł Banaszewski <pawel@coder.com> Co-authored-by: Danny Kopping <danny@coder.com>
1 parent 61a792b commit e8585f1

12 files changed

Lines changed: 1061 additions & 17 deletions

File tree

bridge.go

Lines changed: 41 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,17 @@ import (
1010
"time"
1111

1212
"cdr.dev/slog/v3"
13+
"github.com/coder/aibridge/circuitbreaker"
1314
aibcontext "github.com/coder/aibridge/context"
1415
"github.com/coder/aibridge/mcp"
1516
"github.com/coder/aibridge/metrics"
1617
"github.com/coder/aibridge/provider"
1718
"github.com/coder/aibridge/recorder"
1819
"github.com/coder/aibridge/tracing"
20+
"github.com/hashicorp/go-multierror"
21+
"github.com/sony/gobreaker/v2"
1922
"go.opentelemetry.io/otel/codes"
2023
"go.opentelemetry.io/otel/trace"
21-
22-
"github.com/hashicorp/go-multierror"
2324
)
2425

2526
// RequestBridge is an [http.Handler] which is capable of masquerading as AI providers' APIs;
@@ -59,22 +60,53 @@ const recordingTimeout = time.Second * 5
5960
// A [intercept.Recorder] is also required to record prompt, tool, and token use.
6061
//
6162
// mcpProxy will be closed when the [RequestBridge] is closed.
63+
//
64+
// Circuit breaker configuration is obtained from each provider's CircuitBreakerConfig() method.
65+
// Providers returning nil will not have circuit breaker protection.
6266
func NewRequestBridge(ctx context.Context, providers []provider.Provider, rec recorder.Recorder, mcpProxy mcp.ServerProxier, logger slog.Logger, m *metrics.Metrics, tracer trace.Tracer) (*RequestBridge, error) {
6367
mux := http.NewServeMux()
6468

65-
for _, provider := range providers {
69+
for _, prov := range providers {
70+
// Create per-provider circuit breaker if configured
71+
cfg := prov.CircuitBreakerConfig()
72+
providerName := prov.Name()
73+
onChange := func(endpoint string, from, to gobreaker.State) {
74+
logger.Info(context.Background(), "circuit breaker state change",
75+
slog.F("provider", providerName),
76+
slog.F("endpoint", endpoint),
77+
slog.F("from", from.String()),
78+
slog.F("to", to.String()),
79+
)
80+
if m != nil {
81+
m.CircuitBreakerState.WithLabelValues(providerName, endpoint).Set(circuitbreaker.StateToGaugeValue(to))
82+
if to == gobreaker.StateOpen {
83+
m.CircuitBreakerTrips.WithLabelValues(providerName, endpoint).Inc()
84+
}
85+
}
86+
}
87+
cbs := circuitbreaker.NewProviderCircuitBreakers(providerName, cfg, onChange)
88+
6689
// Add the known provider-specific routes which are bridged (i.e. intercepted and augmented).
67-
for _, path := range provider.BridgedRoutes() {
68-
mux.HandleFunc(path, newInterceptionProcessor(provider, rec, mcpProxy, logger, m, tracer))
90+
for _, path := range prov.BridgedRoutes() {
91+
// Initialize circuit breaker state metric to closed (0) for known routes
92+
if m != nil && cbs != nil {
93+
endpoint := strings.TrimPrefix(path, "/"+providerName)
94+
m.CircuitBreakerState.WithLabelValues(providerName, endpoint).Set(0)
95+
}
96+
97+
handler := newInterceptionProcessor(prov, rec, mcpProxy, logger, m, tracer)
98+
// Wrap with circuit breaker middleware (nil cbs passes through)
99+
wrapped := circuitbreaker.Middleware(cbs, m, logger)(handler)
100+
mux.Handle(path, wrapped)
69101
}
70102

71103
// Any requests which passthrough to this will be reverse-proxied to the upstream.
72104
//
73105
// We have to whitelist the known-safe routes because an API key with elevated privileges (i.e. admin) might be
74106
// configured, so we should just reverse-proxy known-safe routes.
75-
ftr := newPassthroughRouter(provider, logger.Named(fmt.Sprintf("passthrough.%s", provider.Name())), m, tracer)
76-
for _, path := range provider.PassthroughRoutes() {
77-
prefix := fmt.Sprintf("/%s", provider.Name())
107+
ftr := newPassthroughRouter(prov, logger.Named(fmt.Sprintf("passthrough.%s", prov.Name())), m, tracer)
108+
for _, path := range prov.PassthroughRoutes() {
109+
prefix := fmt.Sprintf("/%s", prov.Name())
78110
route := fmt.Sprintf("%s%s", prefix, path)
79111
mux.HandleFunc(route, http.StripPrefix(prefix, ftr).ServeHTTP)
80112
}
@@ -100,7 +132,7 @@ func NewRequestBridge(ctx context.Context, providers []provider.Provider, rec re
100132

101133
// newInterceptionProcessor returns an [http.HandlerFunc] which is capable of creating a new interceptor and processing a given request
102134
// using [Provider] p, recording all usage events using [Recorder] rec.
103-
func newInterceptionProcessor(p Provider, rec recorder.Recorder, mcpProxy mcp.ServerProxier, logger slog.Logger, m *metrics.Metrics, tracer trace.Tracer) http.HandlerFunc {
135+
func newInterceptionProcessor(p provider.Provider, rec recorder.Recorder, mcpProxy mcp.ServerProxier, logger slog.Logger, m *metrics.Metrics, tracer trace.Tracer) http.HandlerFunc {
104136
return func(w http.ResponseWriter, r *http.Request) {
105137
ctx, span := tracer.Start(r.Context(), "Intercept")
106138
defer span.End()

0 commit comments

Comments
 (0)