Skip to content

Commit 24784c7

Browse files
committed
Apply review suggestions
1 parent 6d98675 commit 24784c7

5 files changed

Lines changed: 42 additions & 19 deletions

File tree

bridge.go

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -10,8 +10,8 @@ import (
1010
"time"
1111

1212
"cdr.dev/slog/v3"
13-
aibcontext "github.com/coder/aibridge/context"
1413
"github.com/coder/aibridge/circuitbreaker"
14+
aibcontext "github.com/coder/aibridge/context"
1515
"github.com/coder/aibridge/mcp"
1616
"github.com/coder/aibridge/metrics"
1717
"github.com/coder/aibridge/provider"
@@ -77,7 +77,7 @@ func NewRequestBridge(ctx context.Context, providers []provider.Provider, rec re
7777
slog.F("from", from.String()),
7878
slog.F("to", to.String()),
7979
)
80-
if cfg != nil && m != nil {
80+
if m != nil {
8181
m.CircuitBreakerState.WithLabelValues(providerName, endpoint).Set(circuitbreaker.StateToGaugeValue(to))
8282
if to == gobreaker.StateOpen {
8383
m.CircuitBreakerTrips.WithLabelValues(providerName, endpoint).Inc()
@@ -90,7 +90,7 @@ func NewRequestBridge(ctx context.Context, providers []provider.Provider, rec re
9090
for _, path := range prov.BridgedRoutes() {
9191
handler := newInterceptionProcessor(prov, rec, mcpProxy, logger, m, tracer)
9292
// Wrap with circuit breaker middleware (nil cbs passes through)
93-
wrapped := circuitbreaker.Middleware(cbs, m)(handler)
93+
wrapped := circuitbreaker.Middleware(cbs, m, logger)(handler)
9494
mux.Handle(path, wrapped)
9595
}
9696

circuit_breaker_integration_test.go

Lines changed: 14 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -16,17 +16,19 @@ import (
1616
"github.com/coder/aibridge"
1717
"github.com/coder/aibridge/circuitbreaker"
1818
"github.com/coder/aibridge/config"
19+
"github.com/coder/aibridge/mcp"
1920
"github.com/coder/aibridge/metrics"
2021
"github.com/coder/aibridge/provider"
21-
"github.com/coder/aibridge/mcp"
2222
"github.com/prometheus/client_golang/prometheus"
2323
promtest "github.com/prometheus/client_golang/prometheus/testutil"
2424
"github.com/stretchr/testify/assert"
2525
"github.com/stretchr/testify/require"
2626
"go.opentelemetry.io/otel"
2727
)
2828

29-
func TestCircuitBreaker_WithNewRequestBridge(t *testing.T) {
29+
// TestCircuitBreaker_FullRecoveryCycle tests the complete circuit breaker lifecycle:
30+
// closed → open (after consecutive failures) → half-open (after timeout) → closed (after successful request)
31+
func TestCircuitBreaker_FullRecoveryCycle(t *testing.T) {
3032
t.Parallel()
3133

3234
type testCase struct {
@@ -136,12 +138,14 @@ func TestCircuitBreaker_WithNewRequestBridge(t *testing.T) {
136138
mockSrv.Start()
137139

138140
makeRequest := func() *http.Response {
139-
req, _ := http.NewRequest("POST", mockSrv.URL+"/"+tc.providerName+tc.endpoint, strings.NewReader(tc.requestBody))
141+
req, err := http.NewRequest("POST", mockSrv.URL+"/"+tc.providerName+tc.endpoint, strings.NewReader(tc.requestBody))
142+
require.NoError(t, err)
140143
req.Header.Set("Content-Type", "application/json")
141144
tc.setupHeaders(req)
142145
resp, err := http.DefaultClient.Do(req)
143146
require.NoError(t, err)
144-
_, _ = io.ReadAll(resp.Body)
147+
_, err = io.ReadAll(resp.Body)
148+
require.NoError(t, err)
145149
resp.Body.Close()
146150
return resp
147151
}
@@ -203,6 +207,8 @@ func TestCircuitBreaker_WithNewRequestBridge(t *testing.T) {
203207
}
204208
}
205209

210+
// TestCircuitBreaker_HalfOpenFailure tests that a failed request in half-open state
211+
// returns the circuit to open: closed → open → half-open → open
206212
func TestCircuitBreaker_HalfOpenFailure(t *testing.T) {
207213
t.Parallel()
208214

@@ -253,13 +259,15 @@ func TestCircuitBreaker_HalfOpenFailure(t *testing.T) {
253259
mockSrv.Start()
254260

255261
makeRequest := func() *http.Response {
256-
req, _ := http.NewRequest("POST", mockSrv.URL+"/openai/v1/chat/completions",
262+
req, err := http.NewRequest("POST", mockSrv.URL+"/openai/v1/chat/completions",
257263
strings.NewReader(`{"model":"gpt-4o","messages":[{"role":"user","content":"hi"}]}`))
264+
require.NoError(t, err)
258265
req.Header.Set("Content-Type", "application/json")
259266
req.Header.Set("Authorization", "Bearer test-key")
260267
resp, err := http.DefaultClient.Do(req)
261268
require.NoError(t, err)
262-
_, _ = io.ReadAll(resp.Body)
269+
_, err = io.ReadAll(resp.Body)
270+
require.NoError(t, err)
263271
resp.Body.Close()
264272
return resp
265273
}

circuitbreaker/circuitbreaker.go

Lines changed: 15 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,7 @@ import (
88
"sync"
99
"time"
1010

11+
"cdr.dev/slog/v3"
1112
"github.com/coder/aibridge/metrics"
1213
"github.com/sony/gobreaker/v2"
1314
)
@@ -39,20 +40,26 @@ func DefaultConfig() Config {
3940
}
4041
}
4142

42-
// DefaultIsFailure returns true for status codes that typically indicate
43-
// upstream overload: 429 (Too Many Requests), 503 (Service Unavailable),
44-
// and 529 (Anthropic Overloaded).
43+
// DefaultIsFailure returns true for standard HTTP status codes that typically
44+
// indicate upstream overload: 429 (Too Many Requests) and 503 (Service Unavailable).
4545
func DefaultIsFailure(statusCode int) bool {
4646
switch statusCode {
4747
case http.StatusTooManyRequests, // 429
48-
http.StatusServiceUnavailable, // 503
49-
529: // Anthropic "Overloaded"
48+
http.StatusServiceUnavailable: // 503
5049
return true
5150
default:
5251
return false
5352
}
5453
}
5554

55+
// AnthropicIsFailure extends DefaultIsFailure with Anthropic's custom 529 "Overloaded" status.
56+
func AnthropicIsFailure(statusCode int) bool {
57+
if statusCode == 529 {
58+
return true
59+
}
60+
return DefaultIsFailure(statusCode)
61+
}
62+
5663
// ProviderCircuitBreakers manages per-endpoint circuit breakers for a single provider.
5764
type ProviderCircuitBreakers struct {
5865
provider string
@@ -141,7 +148,7 @@ func (w *statusCapturingWriter) Unwrap() http.ResponseWriter {
141148
// Middleware returns middleware that wraps handlers with circuit breaker protection.
142149
// It captures the response status code to determine success/failure without provider-specific logic.
143150
// If cbs is nil, requests pass through without circuit breaker protection.
144-
func Middleware(cbs *ProviderCircuitBreakers, m *metrics.Metrics) func(http.Handler) http.Handler {
151+
func Middleware(cbs *ProviderCircuitBreakers, m *metrics.Metrics, logger slog.Logger) func(http.Handler) http.Handler {
145152
return func(next http.Handler) http.Handler {
146153
// No circuit breaker configured - pass through
147154
if cbs == nil {
@@ -170,6 +177,8 @@ func Middleware(cbs *ProviderCircuitBreakers, m *metrics.Metrics) func(http.Hand
170177
w.Header().Set("Content-Type", "application/json")
171178
w.WriteHeader(http.StatusServiceUnavailable)
172179
w.Write([]byte(`{"type":"error","error":{"type":"circuit_breaker_open","message":"circuit breaker is open"}}`))
180+
} else if err != nil {
181+
logger.Warn(r.Context(), "unexpected circuit breaker error", slog.Error(err))
173182
}
174183
})
175184
}

circuitbreaker/circuitbreaker_test.go

Lines changed: 9 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,9 @@ import (
77
"testing"
88
"time"
99

10+
"cdr.dev/slog/v3"
11+
"cdr.dev/slog/v3/sloggers/slogtest"
12+
1013
"github.com/sony/gobreaker/v2"
1114
"github.com/stretchr/testify/assert"
1215
"github.com/stretchr/testify/require"
@@ -36,7 +39,8 @@ func TestMiddleware_PerEndpointIsolation(t *testing.T) {
3639
MaxRequests: 1,
3740
}, func(endpoint string, from, to gobreaker.State) {})
3841

39-
handler := Middleware(cbs, nil)(upstream)
42+
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
43+
handler := Middleware(cbs, nil, logger)(upstream)
4044
server := httptest.NewServer(handler)
4145
defer server.Close()
4246

@@ -71,7 +75,8 @@ func TestMiddleware_NotConfigured(t *testing.T) {
7175
})
7276

7377
// No circuit breaker configured (nil)
74-
handler := Middleware(nil, nil)(upstream)
78+
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
79+
handler := Middleware(nil, nil, logger)(upstream)
7580
server := httptest.NewServer(handler)
7681
defer server.Close()
7782

@@ -103,7 +108,8 @@ func TestMiddleware_CustomIsFailure(t *testing.T) {
103108
},
104109
}, func(endpoint string, from, to gobreaker.State) {})
105110

106-
handler := Middleware(cbs, nil)(upstream)
111+
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
112+
handler := Middleware(cbs, nil, logger)(upstream)
107113
server := httptest.NewServer(handler)
108114
defer server.Close()
109115

metrics/metrics.go

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -120,7 +120,7 @@ func NewMetrics(reg prometheus.Registerer) *Metrics {
120120
CircuitBreakerTrips: promauto.With(reg).NewCounterVec(prometheus.CounterOpts{
121121
Subsystem: "circuit_breaker",
122122
Name: "trips_total",
123-
Help: "Total number of times the circuit breaker has tripped open.",
123+
Help: "Total number of times the circuit breaker transitioned to open state.",
124124
}, []string{"provider", "endpoint"}),
125125
// Pessimistic cardinality: 2 providers, 5 endpoints = up to 10.
126126
CircuitBreakerRejects: promauto.With(reg).NewCounterVec(prometheus.CounterOpts{

0 commit comments

Comments
 (0)