diff --git a/api.go b/api.go index e0486d77..a3943c54 100644 --- a/api.go +++ b/api.go @@ -3,14 +3,15 @@ package aibridge import ( "context" + "github.com/prometheus/client_golang/prometheus" + "go.opentelemetry.io/otel/trace" + "cdr.dev/slog/v3" "github.com/coder/aibridge/config" aibcontext "github.com/coder/aibridge/context" "github.com/coder/aibridge/metrics" "github.com/coder/aibridge/provider" "github.com/coder/aibridge/recorder" - "github.com/prometheus/client_golang/prometheus" - "go.opentelemetry.io/otel/trace" ) // Const + Type + function aliases for backwards compatibility. diff --git a/bridge.go b/bridge.go index 3b6a964d..cbd6dc38 100644 --- a/bridge.go +++ b/bridge.go @@ -11,6 +11,12 @@ import ( "sync/atomic" "time" + "github.com/hashicorp/go-multierror" + "github.com/sony/gobreaker/v2" + "go.opentelemetry.io/otel/codes" + "go.opentelemetry.io/otel/trace" + "golang.org/x/xerrors" + "cdr.dev/slog/v3" "github.com/coder/aibridge/circuitbreaker" aibcontext "github.com/coder/aibridge/context" @@ -19,10 +25,6 @@ import ( "github.com/coder/aibridge/provider" "github.com/coder/aibridge/recorder" "github.com/coder/aibridge/tracing" - "github.com/hashicorp/go-multierror" - "github.com/sony/gobreaker/v2" - "go.opentelemetry.io/otel/codes" - "go.opentelemetry.io/otel/trace" ) const ( @@ -67,10 +69,10 @@ func validateProviders(providers []provider.Provider) error { for _, prov := range providers { name := prov.Name() if !validProviderName.MatchString(name) { - return fmt.Errorf("invalid provider name %q: must contain only lowercase alphanumeric characters and hyphens", name) + return xerrors.Errorf("invalid provider name %q: must contain only lowercase alphanumeric characters and hyphens", name) } if names[name] { - return fmt.Errorf("duplicate provider name: %q", name) + return xerrors.Errorf("duplicate provider name: %q", name) } names[name] = true } @@ -125,7 +127,7 @@ func NewRequestBridge(ctx context.Context, providers []provider.Provider, rec re slog.F("prefix", prov.RoutePrefix()), slog.F("path", path), ) - return nil, fmt.Errorf("failed to configure provider '%v': failed to join bridged path: %w", providerName, err) + return nil, xerrors.Errorf("failed to configure provider '%v': failed to join bridged path: %w", providerName, err) } mux.Handle(route, handler) } @@ -144,7 +146,7 @@ func NewRequestBridge(ctx context.Context, providers []provider.Provider, rec re slog.F("prefix", prov.RoutePrefix()), slog.F("path", path), ) - return nil, fmt.Errorf("failed to configure provider '%v': failed to join passed through path: %w", providerName, err) + return nil, xerrors.Errorf("failed to configure provider '%v': failed to join passed through path: %w", providerName, err) } mux.Handle(route, http.StripPrefix(prov.RoutePrefix(), ftr)) } @@ -325,7 +327,7 @@ func (b *RequestBridge) Shutdown(ctx context.Context) error { select { case <-ctx.Done(): // Cancel all inflight requests, if any are still running. - b.logger.Debug(ctx, "shutdown context canceled; cancelling inflight requests", slog.Error(ctx.Err())) + b.logger.Debug(ctx, "shutdown context canceled; canceling inflight requests", slog.Error(ctx.Err())) b.inflightCancel() <-done err = ctx.Err() @@ -347,8 +349,8 @@ func (b *RequestBridge) InflightRequests() int32 { return b.inflightReqs.Load() } -// mergeContexts merges two contexts together, so that if either is cancelled -// the returned context is cancelled. The context values will only be used from +// mergeContexts merges two contexts together, so that if either is canceled +// the returned context is canceled. The context values will only be used from // the first context. func mergeContexts(base, other context.Context) context.Context { ctx, cancel := context.WithCancel(base) diff --git a/bridge_test.go b/bridge_test.go index 161f3f3e..f83fd0b0 100644 --- a/bridge_test.go +++ b/bridge_test.go @@ -5,12 +5,13 @@ import ( "net/http/httptest" "testing" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "cdr.dev/slog/v3/sloggers/slogtest" "github.com/coder/aibridge/config" "github.com/coder/aibridge/internal/testutil" "github.com/coder/aibridge/provider" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func TestValidateProvider_Names(t *testing.T) { diff --git a/buildinfo/buildinfo_test.go b/buildinfo/buildinfo_test.go index a59b3089..390d4e93 100644 --- a/buildinfo/buildinfo_test.go +++ b/buildinfo/buildinfo_test.go @@ -3,8 +3,9 @@ package buildinfo_test import ( "testing" - "github.com/coder/aibridge/buildinfo" "github.com/stretchr/testify/assert" + + "github.com/coder/aibridge/buildinfo" ) func TestBuildInfo(t *testing.T) { diff --git a/circuitbreaker/circuitbreaker.go b/circuitbreaker/circuitbreaker.go index 4be1d2b8..ae4f226c 100644 --- a/circuitbreaker/circuitbreaker.go +++ b/circuitbreaker/circuitbreaker.go @@ -7,14 +7,16 @@ import ( "sync" "time" + "github.com/sony/gobreaker/v2" + "golang.org/x/xerrors" + "github.com/coder/aibridge/config" "github.com/coder/aibridge/metrics" - "github.com/sony/gobreaker/v2" ) // ErrCircuitOpen is returned by Execute when the circuit breaker is open // and the request was rejected without calling the handler. -var ErrCircuitOpen = errors.New("circuit breaker is open") +var ErrCircuitOpen = xerrors.New("circuit breaker is open") // DefaultIsFailure returns true for standard HTTP status codes that typically // indicate upstream overload. @@ -153,7 +155,7 @@ func (p *ProviderCircuitBreakers) Execute(endpoint, model string, w http.Respons _, err := cb.Execute(func() (struct{}, error) { handlerErr = handler(sw) if p.isFailure(sw.statusCode) { - return struct{}{}, fmt.Errorf("upstream error: %d", sw.statusCode) + return struct{}{}, xerrors.Errorf("upstream error: %d", sw.statusCode) } return struct{}{}, nil }) diff --git a/circuitbreaker/circuitbreaker_test.go b/circuitbreaker/circuitbreaker_test.go index 18913718..4f78da38 100644 --- a/circuitbreaker/circuitbreaker_test.go +++ b/circuitbreaker/circuitbreaker_test.go @@ -8,9 +8,10 @@ import ( "testing" "time" - "github.com/coder/aibridge/config" "github.com/sony/gobreaker/v2" "github.com/stretchr/testify/assert" + + "github.com/coder/aibridge/config" ) func TestExecute_PerModelIsolation(t *testing.T) { diff --git a/go.mod b/go.mod index e464c34e..f8a94c8e 100644 --- a/go.mod +++ b/go.mod @@ -23,6 +23,7 @@ require ( golang.org/x/exp v0.0.0-20250819193227-8b4c13bb791b golang.org/x/sync v0.16.0 golang.org/x/tools v0.36.0 + golang.org/x/xerrors v0.0.0-20240716161551-93cc26a95ae9 ) // AI-related libs. @@ -86,7 +87,6 @@ require ( go.yaml.in/yaml/v2 v2.4.2 // indirect golang.org/x/sys v0.40.0 // indirect golang.org/x/term v0.34.0 // indirect - golang.org/x/xerrors v0.0.0-20240716161551-93cc26a95ae9 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20240318140521-94a12d6c2237 // indirect google.golang.org/protobuf v1.36.8 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/intercept/actor_headers.go b/intercept/actor_headers.go index 2d94503c..4406a0bc 100644 --- a/intercept/actor_headers.go +++ b/intercept/actor_headers.go @@ -5,8 +5,9 @@ import ( "strings" ant_option "github.com/anthropics/anthropic-sdk-go/option" - "github.com/coder/aibridge/context" oai_option "github.com/openai/openai-go/v3/option" + + "github.com/coder/aibridge/context" ) const ( diff --git a/intercept/actor_headers_test.go b/intercept/actor_headers_test.go index e9f80a8c..f38a7315 100644 --- a/intercept/actor_headers_test.go +++ b/intercept/actor_headers_test.go @@ -3,10 +3,11 @@ package intercept import ( "testing" - "github.com/coder/aibridge/context" - "github.com/coder/aibridge/recorder" "github.com/google/uuid" "github.com/stretchr/testify/require" + + "github.com/coder/aibridge/context" + "github.com/coder/aibridge/recorder" ) func TestNilActor(t *testing.T) { diff --git a/intercept/apidump/apidump.go b/intercept/apidump/apidump.go index e8e6d893..1926ba9a 100644 --- a/intercept/apidump/apidump.go +++ b/intercept/apidump/apidump.go @@ -12,11 +12,14 @@ import ( "slices" "strings" + "golang.org/x/xerrors" + "cdr.dev/slog/v3" - "github.com/coder/quartz" "github.com/google/uuid" "github.com/tidwall/pretty" + + "github.com/coder/quartz" ) const ( @@ -71,7 +74,7 @@ type dumper struct { func (d *dumper) dumpRequest(req *http.Request) error { dumpPath := d.dumpPath + SuffixRequest if err := os.MkdirAll(filepath.Dir(dumpPath), 0o755); err != nil { - return fmt.Errorf("create dump dir: %w", err) + return xerrors.Errorf("create dump dir: %w", err) } // Read and restore body @@ -80,7 +83,7 @@ func (d *dumper) dumpRequest(req *http.Request) error { var err error bodyBytes, err = io.ReadAll(req.Body) if err != nil { - return fmt.Errorf("read request body: %w", err) + return xerrors.Errorf("read request body: %w", err) } req.Body = io.NopCloser(bytes.NewReader(bodyBytes)) } @@ -91,18 +94,18 @@ func (d *dumper) dumpRequest(req *http.Request) error { var buf bytes.Buffer _, err := fmt.Fprintf(&buf, "%s %s %s\r\n", req.Method, req.URL.RequestURI(), req.Proto) if err != nil { - return fmt.Errorf("write request uri: %w", err) + return xerrors.Errorf("write request uri: %w", err) } err = d.writeRedactedHeaders(&buf, req.Header, sensitiveRequestHeaders, map[string]string{ "Content-Length": fmt.Sprintf("%d", len(prettyBody)), }) if err != nil { - return fmt.Errorf("write request headers: %w", err) + return xerrors.Errorf("write request headers: %w", err) } _, err = fmt.Fprintf(&buf, "\r\n") if err != nil { - return fmt.Errorf("write request header terminator: %w", err) + return xerrors.Errorf("write request header terminator: %w", err) } buf.Write(prettyBody) buf.WriteByte('\n') @@ -117,15 +120,15 @@ func (d *dumper) dumpResponse(resp *http.Response) error { var headerBuf bytes.Buffer _, err := fmt.Fprintf(&headerBuf, "%s %s\r\n", resp.Proto, resp.Status) if err != nil { - return fmt.Errorf("write response status: %w", err) + return xerrors.Errorf("write response status: %w", err) } err = d.writeRedactedHeaders(&headerBuf, resp.Header, sensitiveResponseHeaders, nil) if err != nil { - return fmt.Errorf("write response headers: %w", err) + return xerrors.Errorf("write response headers: %w", err) } _, err = fmt.Fprintf(&headerBuf, "\r\n") if err != nil { - return fmt.Errorf("write response header terminator: %w", err) + return xerrors.Errorf("write response header terminator: %w", err) } // Wrap the response body to capture it as it streams @@ -175,7 +178,7 @@ func (d *dumper) writeRedactedHeaders(w io.Writer, headers http.Header, sensitiv if override, ok := overrides[key]; ok { _, err := fmt.Fprintf(w, "%s: %s\r\n", key, override) if err != nil { - return fmt.Errorf("write response header override: %w", err) + return xerrors.Errorf("write response header override: %w", err) } } continue @@ -190,7 +193,7 @@ func (d *dumper) writeRedactedHeaders(w io.Writer, headers http.Header, sensitiv } _, err := fmt.Fprintf(w, "%s: %s\r\n", key, value) if err != nil { - return fmt.Errorf("write response headers: %w", err) + return xerrors.Errorf("write response headers: %w", err) } } } diff --git a/intercept/apidump/apidump_test.go b/intercept/apidump/apidump_test.go index 1fa3d7f9..5fb8aa2e 100644 --- a/intercept/apidump/apidump_test.go +++ b/intercept/apidump/apidump_test.go @@ -10,11 +10,12 @@ import ( "strings" "testing" + "github.com/google/uuid" + "github.com/stretchr/testify/require" + "cdr.dev/slog/v3" "cdr.dev/slog/v3/sloggers/slogtest" "github.com/coder/quartz" - "github.com/google/uuid" - "github.com/stretchr/testify/require" ) // findDumpFile finds a dump file matching the pattern in the given directory. diff --git a/intercept/apidump/headers_test.go b/intercept/apidump/headers_test.go index 181eae21..3fbce5a3 100644 --- a/intercept/apidump/headers_test.go +++ b/intercept/apidump/headers_test.go @@ -7,9 +7,10 @@ import ( "cdr.dev/slog/v3" - "github.com/coder/quartz" "github.com/google/uuid" "github.com/stretchr/testify/require" + + "github.com/coder/quartz" ) func TestRedactHeaderValue(t *testing.T) { diff --git a/intercept/apidump/streaming.go b/intercept/apidump/streaming.go index 1ad51215..e2db42ac 100644 --- a/intercept/apidump/streaming.go +++ b/intercept/apidump/streaming.go @@ -1,11 +1,12 @@ package apidump import ( - "fmt" "io" "os" "path/filepath" "sync" + + "golang.org/x/xerrors" ) // streamingBodyDumper wraps an io.ReadCloser and writes all data to a dump file @@ -24,18 +25,18 @@ type streamingBodyDumper struct { func (s *streamingBodyDumper) init() { s.once.Do(func() { if err := os.MkdirAll(filepath.Dir(s.dumpPath), 0o755); err != nil { - s.initErr = fmt.Errorf("create dump dir: %w", err) + s.initErr = xerrors.Errorf("create dump dir: %w", err) return } f, err := os.Create(s.dumpPath) if err != nil { - s.initErr = fmt.Errorf("create dump file: %w", err) + s.initErr = xerrors.Errorf("create dump file: %w", err) return } s.file = f // Write headers first. if _, err := s.file.Write(s.headerData); err != nil { - s.initErr = fmt.Errorf("write headers: %w", err) + s.initErr = xerrors.Errorf("write headers: %w", err) s.file.Close() s.file = nil } diff --git a/intercept/apidump/streaming_test.go b/intercept/apidump/streaming_test.go index 653a7262..2a39c1b8 100644 --- a/intercept/apidump/streaming_test.go +++ b/intercept/apidump/streaming_test.go @@ -9,11 +9,12 @@ import ( "strings" "testing" + "github.com/google/uuid" + "github.com/stretchr/testify/require" + "cdr.dev/slog/v3" "cdr.dev/slog/v3/sloggers/slogtest" "github.com/coder/quartz" - "github.com/google/uuid" - "github.com/stretchr/testify/require" ) func TestMiddleware_StreamingResponse(t *testing.T) { diff --git a/intercept/chatcompletions/base.go b/intercept/chatcompletions/base.go index 75691136..e77c257c 100644 --- a/intercept/chatcompletions/base.go +++ b/intercept/chatcompletions/base.go @@ -7,6 +7,13 @@ import ( "net/http" "strings" + "github.com/google/uuid" + "github.com/openai/openai-go/v3" + "github.com/openai/openai-go/v3/option" + "github.com/openai/openai-go/v3/shared" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" + "github.com/coder/aibridge/config" aibcontext "github.com/coder/aibridge/context" "github.com/coder/aibridge/intercept" @@ -15,12 +22,6 @@ import ( "github.com/coder/aibridge/recorder" "github.com/coder/aibridge/tracing" "github.com/coder/quartz" - "github.com/google/uuid" - "github.com/openai/openai-go/v3" - "github.com/openai/openai-go/v3/option" - "github.com/openai/openai-go/v3/shared" - "go.opentelemetry.io/otel/attribute" - "go.opentelemetry.io/otel/trace" "cdr.dev/slog/v3" ) diff --git a/intercept/chatcompletions/base_test.go b/intercept/chatcompletions/base_test.go index 1647a2d5..5f83f5c3 100644 --- a/intercept/chatcompletions/base_test.go +++ b/intercept/chatcompletions/base_test.go @@ -3,9 +3,10 @@ package chatcompletions import ( "testing" - "github.com/coder/aibridge/utils" "github.com/openai/openai-go/v3" "github.com/stretchr/testify/require" + + "github.com/coder/aibridge/utils" ) func TestScanForCorrelatingToolCallID(t *testing.T) { diff --git a/intercept/chatcompletions/blocking.go b/intercept/chatcompletions/blocking.go index 532addd3..ed0fc71b 100644 --- a/intercept/chatcompletions/blocking.go +++ b/intercept/chatcompletions/blocking.go @@ -3,11 +3,17 @@ package chatcompletions import ( "context" "encoding/json" - "fmt" "net/http" "strings" "time" + "github.com/google/uuid" + "github.com/openai/openai-go/v3" + "github.com/openai/openai-go/v3/option" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" + "golang.org/x/xerrors" + "github.com/coder/aibridge/config" aibcontext "github.com/coder/aibridge/context" "github.com/coder/aibridge/intercept" @@ -15,11 +21,6 @@ import ( "github.com/coder/aibridge/mcp" "github.com/coder/aibridge/recorder" "github.com/coder/aibridge/tracing" - "github.com/google/uuid" - "github.com/openai/openai-go/v3" - "github.com/openai/openai-go/v3/option" - "go.opentelemetry.io/otel/attribute" - "go.opentelemetry.io/otel/trace" "cdr.dev/slog/v3" ) @@ -64,7 +65,7 @@ func (s *BlockingInterception) TraceAttributes(r *http.Request) []attribute.KeyV func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Request) (outErr error) { if i.req == nil { - return fmt.Errorf("developer error: req is nil") + return xerrors.New("developer error: req is nil") } ctx, span := i.tracer.Start(r.Context(), "Intercept.ProcessRequest", trace.WithAttributes(tracing.InterceptionAttributesFromContext(r.Context())...)) @@ -220,16 +221,16 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req if err != nil { if eventstream.IsConnError(err) { http.Error(w, err.Error(), http.StatusInternalServerError) - return fmt.Errorf("upstream connection closed: %w", err) + return xerrors.Errorf("upstream connection closed: %w", err) } if apiErr := getErrorResponse(err); apiErr != nil { i.writeUpstreamError(w, apiErr) - return fmt.Errorf("openai API error: %w", err) + return xerrors.Errorf("openai API error: %w", err) } http.Error(w, err.Error(), http.StatusInternalServerError) - return fmt.Errorf("chat completion failed: %w", err) + return xerrors.Errorf("chat completion failed: %w", err) } if completion == nil { @@ -247,7 +248,7 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req w.Header().Set("Content-Type", "application/json") out, err := json.Marshal(completion) if err != nil { - out, _ = json.Marshal(i.newErrorResponse(fmt.Errorf("failed to marshal response: %w", err))) + out, _ = json.Marshal(i.newErrorResponse(xerrors.Errorf("failed to marshal response: %w", err))) w.WriteHeader(http.StatusInternalServerError) } else { w.WriteHeader(http.StatusOK) diff --git a/intercept/chatcompletions/paramswrap.go b/intercept/chatcompletions/paramswrap.go index b30d929c..ab6dd524 100644 --- a/intercept/chatcompletions/paramswrap.go +++ b/intercept/chatcompletions/paramswrap.go @@ -1,12 +1,12 @@ package chatcompletions import ( - "errors" - - "github.com/coder/aibridge/utils" "github.com/openai/openai-go/v3" "github.com/openai/openai-go/v3/packages/param" "github.com/tidwall/gjson" + "golang.org/x/xerrors" + + "github.com/coder/aibridge/utils" ) // ChatCompletionNewParamsWrapper exists because the "stream" param is not included in openai.ChatCompletionNewParams. @@ -42,11 +42,11 @@ func (c *ChatCompletionNewParamsWrapper) UnmarshalJSON(raw []byte) error { func (c *ChatCompletionNewParamsWrapper) lastUserPrompt() (*string, error) { if c == nil { - return nil, errors.New("nil struct") + return nil, xerrors.New("nil struct") } if len(c.Messages) == 0 { - return nil, errors.New("no messages") + return nil, xerrors.New("no messages") } // We only care if the last message was issued by a user. diff --git a/intercept/chatcompletions/streaming.go b/intercept/chatcompletions/streaming.go index d7a5485d..97bf2161 100644 --- a/intercept/chatcompletions/streaming.go +++ b/intercept/chatcompletions/streaming.go @@ -4,20 +4,11 @@ import ( "bytes" "context" "encoding/json" - "errors" - "fmt" "net/http" "slices" "strings" "time" - "github.com/coder/aibridge/config" - aibcontext "github.com/coder/aibridge/context" - "github.com/coder/aibridge/intercept" - "github.com/coder/aibridge/intercept/eventstream" - "github.com/coder/aibridge/mcp" - "github.com/coder/aibridge/recorder" - "github.com/coder/aibridge/tracing" "github.com/google/uuid" "github.com/openai/openai-go/v3" "github.com/openai/openai-go/v3/option" @@ -25,6 +16,15 @@ import ( "github.com/tidwall/sjson" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/trace" + "golang.org/x/xerrors" + + "github.com/coder/aibridge/config" + aibcontext "github.com/coder/aibridge/context" + "github.com/coder/aibridge/intercept" + "github.com/coder/aibridge/intercept/eventstream" + "github.com/coder/aibridge/mcp" + "github.com/coder/aibridge/recorder" + "github.com/coder/aibridge/tracing" "cdr.dev/slog/v3" ) @@ -81,7 +81,7 @@ func (s *StreamingInterception) TraceAttributes(r *http.Request) []attribute.Key // can continue until all injected tool invocations are completed and the response is relayed to the client. func (i *StreamingInterception) ProcessRequest(w http.ResponseWriter, r *http.Request) (outErr error) { if i.req == nil { - return fmt.Errorf("developer error: req is nil") + return xerrors.New("developer error: req is nil") } ctx, span := i.tracer.Start(r.Context(), "Intercept.ProcessRequest", trace.WithAttributes(tracing.InterceptionAttributesFromContext(r.Context())...)) @@ -101,7 +101,7 @@ func (i *StreamingInterception) ProcessRequest(w http.ResponseWriter, r *http.Re logger := i.logger.With(slog.F("model", i.req.Model)) streamCtx, streamCancel := context.WithCancelCause(ctx) - defer streamCancel(errors.New("deferred")) + defer streamCancel(xerrors.New("deferred")) // events will either terminate when shutdown after interaction with upstream completes, or when streamCtx is done. events := eventstream.NewEventStream(streamCtx, logger.Named("sse-sender"), nil) @@ -137,10 +137,10 @@ func (i *StreamingInterception) ProcessRequest(w http.ResponseWriter, r *http.Re // We take control of request body here and pass it to the SDK as a raw byte slice. // This is because the SDK's serialization applies hidden request options that result in - // unexpected, breaking behaviour. See https://github.com/coder/aibridge/pull/164 + // unexpected, breaking behavior. See https://github.com/coder/aibridge/pull/164 body, err := json.Marshal(i.req.ChatCompletionNewParams) if err != nil { - return fmt.Errorf("marshal request body: %w", err) + return xerrors.Errorf("marshal request body: %w", err) } opts = append(opts, option.WithRequestBody("application/json", body)) opts = append(opts, option.WithJSONSet("stream", true)) @@ -167,12 +167,12 @@ func (i *StreamingInterception) ProcessRequest(w http.ResponseWriter, r *http.Re payload, err := i.marshalChunk(&chunk, i.ID(), processor) if err != nil { logger.Warn(ctx, "failed to marshal chunk", slog.Error(err), slog.F("chunk", chunk.RawJSON())) - lastErr = fmt.Errorf("marshal chunk: %w", err) + lastErr = xerrors.Errorf("marshal chunk: %w", err) break } if err := events.Send(ctx, payload); err != nil { logger.Warn(ctx, "failed to relay chunk", slog.Error(err)) - lastErr = fmt.Errorf("relay chunk: %w", err) + lastErr = xerrors.Errorf("relay chunk: %w", err) break } } @@ -247,12 +247,12 @@ func (i *StreamingInterception) ProcessRequest(w http.ResponseWriter, r *http.Re // into known types (i.e. [shared.OverloadedError]). // See https://github.com/openai/openai-go/blob/v2.7.0/packages/ssestream/ssestream.go#L171 // All it does is wrap the payload in an error - which is all we can return, currently. - interceptionErr = newErrorResponse(fmt.Errorf("unknown stream error: %w", streamErr)) + interceptionErr = newErrorResponse(xerrors.Errorf("unknown stream error: %w", streamErr)) } } else if lastErr != nil { // Otherwise check if any logical errors occurred during processing. logger.Warn(ctx, "stream failed", slog.Error(lastErr)) - interceptionErr = newErrorResponse(fmt.Errorf("processing error: %w", lastErr)) + interceptionErr = newErrorResponse(xerrors.Errorf("processing error: %w", lastErr)) } if interceptionErr != nil { @@ -340,9 +340,9 @@ func (i *StreamingInterception) ProcessRequest(w http.ResponseWriter, r *http.Re } if err != nil { - streamCancel(fmt.Errorf("stream err: %w", err)) + streamCancel(xerrors.Errorf("stream err: %w", err)) } else { - streamCancel(errors.New("gracefully done")) + streamCancel(xerrors.New("gracefully done")) } return interceptionErr @@ -367,7 +367,7 @@ func (i *StreamingInterception) getInjectedToolByName(name string) *mcp.Tool { func (i *StreamingInterception) marshalChunk(chunk *openai.ChatCompletionChunk, id uuid.UUID, prc *streamProcessor) ([]byte, error) { sj, err := sjson.Set(chunk.RawJSON(), "id", id.String()) if err != nil { - return nil, fmt.Errorf("marshal chunk id failed: %w", err) + return nil, xerrors.Errorf("marshal chunk id failed: %w", err) } // If usage information is available, relay the cumulative usage once all tool invocations have completed. @@ -375,7 +375,7 @@ func (i *StreamingInterception) marshalChunk(chunk *openai.ChatCompletionChunk, u := prc.getCumulativeUsage() sj, err = sjson.Set(sj, "usage", u) if err != nil { - return nil, fmt.Errorf("marshal chunk usage failed: %w", err) + return nil, xerrors.Errorf("marshal chunk usage failed: %w", err) } } @@ -385,7 +385,7 @@ func (i *StreamingInterception) marshalChunk(chunk *openai.ChatCompletionChunk, func (i *StreamingInterception) marshalErr(err error) ([]byte, error) { data, err := json.Marshal(err) if err != nil { - return nil, fmt.Errorf("marshal error failed: %w", err) + return nil, xerrors.Errorf("marshal error failed: %w", err) } return i.encodeForStream(data), nil diff --git a/intercept/chatcompletions/streaming_test.go b/intercept/chatcompletions/streaming_test.go index ee27f431..52d5baa5 100644 --- a/intercept/chatcompletions/streaming_test.go +++ b/intercept/chatcompletions/streaming_test.go @@ -6,16 +6,17 @@ import ( "strconv" "testing" - "cdr.dev/slog/v3" - "cdr.dev/slog/v3/sloggers/slogtest" - "github.com/coder/aibridge/config" - "github.com/coder/aibridge/intercept" - "github.com/coder/aibridge/internal/testutil" "github.com/google/uuid" "github.com/openai/openai-go/v3" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.opentelemetry.io/otel" + + "cdr.dev/slog/v3" + "cdr.dev/slog/v3/sloggers/slogtest" + "github.com/coder/aibridge/config" + "github.com/coder/aibridge/intercept" + "github.com/coder/aibridge/internal/testutil" ) // Test that when the upstream provider returns an error before streaming starts, diff --git a/intercept/eventstream/eventstream.go b/intercept/eventstream/eventstream.go index b3ee96a2..562e385c 100644 --- a/intercept/eventstream/eventstream.go +++ b/intercept/eventstream/eventstream.go @@ -3,7 +3,6 @@ package eventstream import ( "context" "errors" - "fmt" "io" "net" "net/http" @@ -13,10 +12,12 @@ import ( "syscall" "time" + "golang.org/x/xerrors" + "cdr.dev/slog/v3" ) -var ErrEventStreamClosed = errors.New("event stream closed") +var ErrEventStreamClosed = xerrors.New("event stream closed") const pingInterval = time.Second * 10 @@ -187,9 +188,9 @@ func (s *EventStream) Shutdown(shutdownCtx context.Context) error { select { case <-shutdownCtx.Done(): // If shutdownCtx completes, shutdown likely exceeded its timeout. - err = fmt.Errorf("shutdown ended prematurely with %d outstanding events: %w", len(s.eventsCh), shutdownCtx.Err()) + err = xerrors.Errorf("shutdown ended prematurely with %d outstanding events: %w", len(s.eventsCh), shutdownCtx.Err()) case <-s.ctx.Done(): - err = fmt.Errorf("shutdown ended prematurely with %d outstanding events: %w", len(s.eventsCh), s.ctx.Err()) + err = xerrors.Errorf("shutdown ended prematurely with %d outstanding events: %w", len(s.eventsCh), s.ctx.Err()) case <-s.doneCh: return nil } @@ -235,7 +236,7 @@ func IsUnrecoverableError(err error) bool { func flush(w http.ResponseWriter) (err error) { flusher, ok := w.(http.Flusher) if !ok || flusher == nil { - return errors.New("SSE not supported") + return xerrors.New("SSE not supported") } defer func() { diff --git a/intercept/interceptor.go b/intercept/interceptor.go index 8b954286..4517ebd4 100644 --- a/intercept/interceptor.go +++ b/intercept/interceptor.go @@ -3,11 +3,12 @@ package intercept import ( "net/http" + "github.com/google/uuid" + "go.opentelemetry.io/otel/attribute" + "cdr.dev/slog/v3" "github.com/coder/aibridge/mcp" "github.com/coder/aibridge/recorder" - "github.com/google/uuid" - "go.opentelemetry.io/otel/attribute" ) // Interceptor describes a (potentially) stateful interaction with an AI provider. diff --git a/intercept/messages/base.go b/intercept/messages/base.go index a1458b07..b3aa40cc 100644 --- a/intercept/messages/base.go +++ b/intercept/messages/base.go @@ -16,6 +16,8 @@ import ( "github.com/anthropics/anthropic-sdk-go/shared/constant" "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/credentials" + "golang.org/x/xerrors" + aibconfig "github.com/coder/aibridge/config" aibcontext "github.com/coder/aibridge/context" "github.com/coder/aibridge/intercept" @@ -266,22 +268,22 @@ func (i *interceptionBase) withBody() option.RequestOption { func (i *interceptionBase) withAWSBedrockOptions(ctx context.Context, cfg *aibconfig.AWSBedrock) ([]option.RequestOption, error) { if cfg == nil { - return nil, fmt.Errorf("nil config given") + return nil, xerrors.New("nil config given") } if cfg.Region == "" && cfg.BaseURL == "" { - return nil, fmt.Errorf("region or base url required") + return nil, xerrors.New("region or base url required") } if cfg.AccessKey == "" { - return nil, fmt.Errorf("access key required") + return nil, xerrors.New("access key required") } if cfg.AccessKeySecret == "" { - return nil, fmt.Errorf("access key secret required") + return nil, xerrors.New("access key secret required") } if cfg.Model == "" { - return nil, fmt.Errorf("model required") + return nil, xerrors.New("model required") } if cfg.SmallFastModel == "" { - return nil, fmt.Errorf("small fast model required") + return nil, xerrors.New("small fast model required") } opts := []func(*config.LoadOptions) error{ @@ -297,7 +299,7 @@ func (i *interceptionBase) withAWSBedrockOptions(ctx context.Context, cfg *aibco awsCfg, err := config.LoadDefaultConfig(ctx, opts...) if err != nil { - return nil, fmt.Errorf("failed to load AWS Bedrock config: %w", err) + return nil, xerrors.Errorf("failed to load AWS Bedrock config: %w", err) } var out []option.RequestOption @@ -369,9 +371,7 @@ func filterBedrockBetaFlags(headers http.Header, model string) { // https://httpwg.org/specs/rfc9110.html#rfc.section.5.3 var flags []string for _, v := range headers.Values("Anthropic-Beta") { - for _, flag := range strings.Split(v, ",") { - flags = append(flags, flag) - } + flags = append(flags, strings.Split(v, ",")...) } if len(flags) == 0 { diff --git a/intercept/messages/base_test.go b/intercept/messages/base_test.go index 1096e9a8..0ab88b8d 100644 --- a/intercept/messages/base_test.go +++ b/intercept/messages/base_test.go @@ -5,15 +5,16 @@ import ( "net/http" "testing" - "cdr.dev/slog/v3" "github.com/anthropics/anthropic-sdk-go" "github.com/anthropics/anthropic-sdk-go/shared/constant" - "github.com/coder/aibridge/config" - "github.com/coder/aibridge/mcp" - "github.com/coder/aibridge/utils" mcpgo "github.com/mark3labs/mcp-go/mcp" "github.com/stretchr/testify/require" "github.com/tidwall/gjson" + + "cdr.dev/slog/v3" + "github.com/coder/aibridge/config" + "github.com/coder/aibridge/mcp" + "github.com/coder/aibridge/utils" ) func TestScanForCorrelatingToolCallID(t *testing.T) { diff --git a/intercept/messages/blocking.go b/intercept/messages/blocking.go index f83f2187..7b63d344 100644 --- a/intercept/messages/blocking.go +++ b/intercept/messages/blocking.go @@ -8,6 +8,13 @@ import ( "github.com/anthropics/anthropic-sdk-go" "github.com/anthropics/anthropic-sdk-go/option" + "github.com/google/uuid" + mcplib "github.com/mark3labs/mcp-go/mcp" + "github.com/tidwall/sjson" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" + "golang.org/x/xerrors" + "github.com/coder/aibridge/config" aibcontext "github.com/coder/aibridge/context" "github.com/coder/aibridge/intercept" @@ -15,11 +22,6 @@ import ( "github.com/coder/aibridge/mcp" "github.com/coder/aibridge/recorder" "github.com/coder/aibridge/tracing" - "github.com/google/uuid" - mcplib "github.com/mark3labs/mcp-go/mcp" - "github.com/tidwall/sjson" - "go.opentelemetry.io/otel/attribute" - "go.opentelemetry.io/otel/trace" "cdr.dev/slog/v3" ) @@ -66,7 +68,7 @@ func (s *BlockingInterception) Streaming() bool { func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Request) (outErr error) { if len(i.reqPayload) == 0 { - return fmt.Errorf("developer error: request payload is empty") + return xerrors.New("developer error: request payload is empty") } ctx, span := i.tracer.Start(r.Context(), "Intercept.ProcessRequest", trace.WithAttributes(tracing.InterceptionAttributesFromContext(r.Context())...)) @@ -91,7 +93,7 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req svc, err := i.newMessagesService(ctx, opts...) if err != nil { - err = fmt.Errorf("create anthropic client: %w", err) + err = xerrors.Errorf("create anthropic client: %w", err) http.Error(w, err.Error(), http.StatusInternalServerError) return err } @@ -108,16 +110,16 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req if err != nil { if eventstream.IsConnError(err) { // Can't write a response, just error out. - return fmt.Errorf("upstream connection closed: %w", err) + return xerrors.Errorf("upstream connection closed: %w", err) } if antErr := getErrorResponse(err); antErr != nil { i.writeUpstreamError(w, antErr) - return fmt.Errorf("anthropic API error: %w", err) + return xerrors.Errorf("anthropic API error: %w", err) } http.Error(w, "internal error", http.StatusInternalServerError) - return fmt.Errorf("internal error: %w", err) + return xerrors.Errorf("internal error: %w", err) } if prompt != nil { @@ -305,7 +307,7 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req updatedPayload, rewriteErr := i.reqPayload.appendedMessages(loopMessages) if rewriteErr != nil { http.Error(w, rewriteErr.Error(), http.StatusInternalServerError) - return fmt.Errorf("rewrite payload for agentic loop: %w", rewriteErr) + return xerrors.Errorf("rewrite payload for agentic loop: %w", rewriteErr) } i.reqPayload = updatedPayload } @@ -317,13 +319,13 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req // Overwrite response identifier since proxy obscures injected tool call invocations. sj, err := sjson.Set(resp.RawJSON(), "id", i.ID().String()) if err != nil { - return fmt.Errorf("marshal response id failed: %w", err) + return xerrors.Errorf("marshal response id failed: %w", err) } // Overwrite the response's usage with the cumulative usage across any inner loops which invokes injected MCP tools. sj, err = sjson.Set(sj, "usage", cumulativeUsage) if err != nil { - return fmt.Errorf("marshal response usage failed: %w", err) + return xerrors.Errorf("marshal response usage failed: %w", err) } w.Header().Set("Content-Type", "application/json") diff --git a/intercept/messages/reqpayload.go b/intercept/messages/reqpayload.go index a139f9c1..fa5142ea 100644 --- a/intercept/messages/reqpayload.go +++ b/intercept/messages/reqpayload.go @@ -3,7 +3,6 @@ package messages import ( "bytes" "encoding/json" - "fmt" "net/http" "slices" @@ -11,6 +10,7 @@ import ( "github.com/anthropics/anthropic-sdk-go/shared/constant" "github.com/tidwall/gjson" "github.com/tidwall/sjson" + "golang.org/x/xerrors" ) const ( @@ -89,10 +89,10 @@ type MessagesRequestPayload []byte func NewMessagesRequestPayload(raw []byte) (MessagesRequestPayload, error) { if len(bytes.TrimSpace(raw)) == 0 { - return nil, fmt.Errorf("messages empty request body") + return nil, xerrors.New("messages empty request body") } if !json.Valid(raw) { - return nil, fmt.Errorf("messages invalid JSON request body") + return nil, xerrors.New("messages invalid JSON request body") } return MessagesRequestPayload(raw), nil @@ -153,7 +153,7 @@ func (p MessagesRequestPayload) lastUserPrompt() (string, bool, error) { return "", false, nil } if !messages.IsArray() { - return "", false, fmt.Errorf("unexpected messages type: %s", messages.Type) + return "", false, xerrors.Errorf("unexpected messages type: %s", messages.Type) } messageItems := messages.Array() @@ -174,7 +174,7 @@ func (p MessagesRequestPayload) lastUserPrompt() (string, bool, error) { return content.String(), true, nil } if !content.IsArray() { - return "", false, fmt.Errorf("unexpected message content type: %s", content.Type) + return "", false, xerrors.Errorf("unexpected message content type: %s", content.Type) } contentItems := content.Array() @@ -202,7 +202,7 @@ func (p MessagesRequestPayload) injectTools(injected []anthropic.ToolUnionParam) existing, err := p.tools() if err != nil { - return p, fmt.Errorf("get existing tools: %w", err) + return p, xerrors.Errorf("get existing tools: %w", err) } // Using []json.Marshaler to merge differently-typed slices ([]anthropic.ToolUnionParam @@ -229,24 +229,24 @@ func (p MessagesRequestPayload) disableParallelToolCalls() (MessagesRequestPaylo if !toolChoice.Exists() || toolChoice.Type == gjson.Null { updated, err := p.set(messagesReqPathToolChoiceType, constAuto) if err != nil { - return p, fmt.Errorf("set tool choice type: %w", err) + return p, xerrors.Errorf("set tool choice type: %w", err) } return updated.set(messagesReqPathToolChoiceDisableParallel, true) } if !toolChoice.IsObject() { - return p, fmt.Errorf("unsupported tool_choice type: %s", toolChoice.Type) + return p, xerrors.Errorf("unsupported tool_choice type: %s", toolChoice.Type) } toolChoiceType := gjson.GetBytes(p, messagesReqPathToolChoiceType) if toolChoiceType.Exists() && toolChoiceType.Type != gjson.String { - return p, fmt.Errorf("unsupported tool_choice.type type: %s", toolChoiceType.Type) + return p, xerrors.Errorf("unsupported tool_choice.type type: %s", toolChoiceType.Type) } switch toolChoiceType.String() { case "": updated, err := p.set(messagesReqPathToolChoiceType, constAuto) if err != nil { - return p, fmt.Errorf("set tool_choice.type: %w", err) + return p, xerrors.Errorf("set tool_choice.type: %w", err) } return updated.set(messagesReqPathToolChoiceDisableParallel, true) case constAuto, constAny, constTool: @@ -254,7 +254,7 @@ func (p MessagesRequestPayload) disableParallelToolCalls() (MessagesRequestPaylo case constNone: return p, nil default: - return p, fmt.Errorf("unsupported tool_choice.type value: %q", toolChoiceType.String()) + return p, xerrors.Errorf("unsupported tool_choice.type value: %q", toolChoiceType.String()) } } @@ -265,7 +265,7 @@ func (p MessagesRequestPayload) appendedMessages(newMessages []anthropic.Message existing, err := p.messages() if err != nil { - return p, fmt.Errorf("get existing messages: %w", err) + return p, xerrors.Errorf("get existing messages: %w", err) } // Using []json.Marshaler to merge differently-typed slices ([]json.Marshaler containing @@ -295,7 +295,7 @@ func (p MessagesRequestPayload) messages() ([]json.RawMessage, error) { return nil, nil } if !messages.IsArray() { - return nil, fmt.Errorf("unsupported messages type: %s", messages.Type) + return nil, xerrors.Errorf("unsupported messages type: %s", messages.Type) } return p.resultToRawMessage(messages.Array()), nil @@ -307,7 +307,7 @@ func (p MessagesRequestPayload) tools() ([]json.RawMessage, error) { return nil, nil } if !tools.IsArray() { - return nil, fmt.Errorf("unsupported tools type: %s", tools.Type) + return nil, xerrors.Errorf("unsupported tools type: %s", tools.Type) } return p.resultToRawMessage(tools.Array()), nil @@ -335,7 +335,7 @@ func (p MessagesRequestPayload) convertAdaptiveThinkingForBedrock() (MessagesReq maxTokens := gjson.GetBytes(p, messagesReqPathMaxTokens).Int() if maxTokens <= 0 { // max_tokens is required by messages API - return p, fmt.Errorf("max_tokens: field required") + return p, xerrors.New("max_tokens: field required") } effort := gjson.GetBytes(p, messagesReqPathOutputConfigEffort).String() @@ -380,7 +380,7 @@ func (p MessagesRequestPayload) convertAdaptiveThinkingForBedrock() (MessagesReq func (p MessagesRequestPayload) removeUnsupportedBedrockFields(headers http.Header) (MessagesRequestPayload, error) { var payloadMap map[string]any if err := json.Unmarshal(p, &payloadMap); err != nil { - return p, fmt.Errorf("failed to unmarshal request payload when removing unsupported Bedrock fields: %w", err) + return p, xerrors.Errorf("failed to unmarshal request payload when removing unsupported Bedrock fields: %w", err) } // Always strip unconditionally unsupported fields. @@ -398,7 +398,7 @@ func (p MessagesRequestPayload) removeUnsupportedBedrockFields(headers http.Head result, err := json.Marshal(payloadMap) if err != nil { - return p, fmt.Errorf("failed to marshal request payload when removing unsupported Bedrock fields: %w", err) + return p, xerrors.Errorf("failed to marshal request payload when removing unsupported Bedrock fields: %w", err) } return MessagesRequestPayload(result), nil } @@ -406,7 +406,7 @@ func (p MessagesRequestPayload) removeUnsupportedBedrockFields(headers http.Head func (p MessagesRequestPayload) set(path string, value any) (MessagesRequestPayload, error) { out, err := sjson.SetBytes(p, path, value) if err != nil { - return p, fmt.Errorf("set %s: %w", path, err) + return p, xerrors.Errorf("set %s: %w", path, err) } return MessagesRequestPayload(out), nil } diff --git a/intercept/messages/reqpayload_test.go b/intercept/messages/reqpayload_test.go index fcfdd39b..f16fa4fa 100644 --- a/intercept/messages/reqpayload_test.go +++ b/intercept/messages/reqpayload_test.go @@ -5,9 +5,10 @@ import ( "github.com/anthropics/anthropic-sdk-go" "github.com/anthropics/anthropic-sdk-go/shared/constant" - "github.com/coder/aibridge/utils" "github.com/stretchr/testify/require" "github.com/tidwall/gjson" + + "github.com/coder/aibridge/utils" ) func TestNewMessagesRequestPayload(t *testing.T) { diff --git a/intercept/messages/streaming.go b/intercept/messages/streaming.go index 760313ec..30f76274 100644 --- a/intercept/messages/streaming.go +++ b/intercept/messages/streaming.go @@ -4,7 +4,6 @@ import ( "bytes" "context" "encoding/json" - "errors" "fmt" "net/http" "strings" @@ -14,6 +13,13 @@ import ( "github.com/anthropics/anthropic-sdk-go/option" "github.com/anthropics/anthropic-sdk-go/packages/ssestream" "github.com/anthropics/anthropic-sdk-go/shared/constant" + "github.com/google/uuid" + mcplib "github.com/mark3labs/mcp-go/mcp" + "github.com/tidwall/sjson" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" + "golang.org/x/xerrors" + "github.com/coder/aibridge/config" aibcontext "github.com/coder/aibridge/context" "github.com/coder/aibridge/intercept" @@ -21,11 +27,6 @@ import ( "github.com/coder/aibridge/mcp" "github.com/coder/aibridge/recorder" "github.com/coder/aibridge/tracing" - "github.com/google/uuid" - mcplib "github.com/mark3labs/mcp-go/mcp" - "github.com/tidwall/sjson" - "go.opentelemetry.io/otel/attribute" - "go.opentelemetry.io/otel/trace" "cdr.dev/slog/v3" ) @@ -91,7 +92,7 @@ func (s *StreamingInterception) TraceAttributes(r *http.Request) []attribute.Key // can continue until all injected tool invocations are completed and the response is relayed to the client. func (i *StreamingInterception) ProcessRequest(w http.ResponseWriter, r *http.Request) (outErr error) { if len(i.reqPayload) == 0 { - return fmt.Errorf("developer error: request payload is empty") + return xerrors.New("developer error: request payload is empty") } ctx, span := i.tracer.Start(r.Context(), "Intercept.ProcessRequest", trace.WithAttributes(tracing.InterceptionAttributesFromContext(r.Context())...)) @@ -122,7 +123,7 @@ func (i *StreamingInterception) ProcessRequest(w http.ResponseWriter, r *http.Re } streamCtx, streamCancel := context.WithCancelCause(ctx) - defer streamCancel(errors.New("deferred")) + defer streamCancel(xerrors.New("deferred")) // TODO(ssncferreira): inject actor headers directly in the client-header // middleware instead of using SDK options. @@ -133,7 +134,7 @@ func (i *StreamingInterception) ProcessRequest(w http.ResponseWriter, r *http.Re svc, err := i.newMessagesService(streamCtx, opts...) if err != nil { - err = fmt.Errorf("create anthropic client: %w", err) + err = xerrors.Errorf("create anthropic client: %w", err) http.Error(w, err.Error(), http.StatusInternalServerError) return err } @@ -156,7 +157,7 @@ newStream: for { // TODO add outer loop span (https://github.com/coder/aibridge/issues/67) if err := streamCtx.Err(); err != nil { - lastErr = fmt.Errorf("stream exit: %w", err) + lastErr = xerrors.Errorf("stream exit: %w", err) break } @@ -171,7 +172,7 @@ newStream: event := stream.Current() if err := message.Accumulate(event); err != nil { logger.Warn(ctx, "failed to accumulate streaming events", slog.Error(err), slog.F("event", event), slog.F("msg", message.RawJSON())) - lastErr = fmt.Errorf("accumulate event: %w", err) + lastErr = xerrors.Errorf("accumulate event: %w", err) break } @@ -422,7 +423,7 @@ newStream: // sends the updated payload on the next iteration. updatedPayload, syncErr := i.reqPayload.appendedMessages(loopMessages) if syncErr != nil { - lastErr = fmt.Errorf("sync payload for agentic loop: %w", syncErr) + lastErr = xerrors.Errorf("sync payload for agentic loop: %w", syncErr) break } i.reqPayload = updatedPayload @@ -456,7 +457,7 @@ newStream: payload, err := i.marshalEvent(event) if err != nil { logger.Warn(ctx, "failed to marshal event", slog.Error(err), slog.F("event", event.RawJSON())) - lastErr = fmt.Errorf("marshal event: %w", err) + lastErr = xerrors.Errorf("marshal event: %w", err) break } if err := events.Send(streamCtx, payload); err != nil { @@ -465,7 +466,7 @@ newStream: break // Stop processing if client disconnected or context canceled. } else { logger.Warn(ctx, "failed to relay event", slog.Error(err)) - lastErr = fmt.Errorf("relay event: %w", err) + lastErr = xerrors.Errorf("relay event: %w", err) break } } @@ -496,12 +497,12 @@ newStream: // into known types (i.e. [shared.OverloadedError]). // See https://github.com/anthropics/anthropic-sdk-go/blob/v1.12.0/packages/ssestream/ssestream.go#L172-L174 // All it does is wrap the payload in an error - which is all we can return, currently. - interceptionErr = newErrorResponse(fmt.Errorf("unknown stream error: %w", streamErr)) + interceptionErr = newErrorResponse(xerrors.Errorf("unknown stream error: %w", streamErr)) } } else if lastErr != nil { // Otherwise check if any logical errors occurred during processing. logger.Warn(ctx, "stream failed", slog.Error(lastErr)) - interceptionErr = newErrorResponse(fmt.Errorf("processing error: %w", lastErr)) + interceptionErr = newErrorResponse(xerrors.Errorf("processing error: %w", lastErr)) } if interceptionErr != nil { @@ -528,7 +529,7 @@ newStream: if interceptionErr != nil { streamCancel(interceptionErr) } else { - streamCancel(errors.New("gracefully done")) + streamCancel(xerrors.New("gracefully done")) } break @@ -540,12 +541,12 @@ newStream: func (s *StreamingInterception) marshalEvent(event anthropic.MessageStreamEventUnion) ([]byte, error) { sj, err := sjson.Set(event.RawJSON(), "message.id", s.ID().String()) if err != nil { - return nil, fmt.Errorf("marshal event id failed: %w", err) + return nil, xerrors.Errorf("marshal event id failed: %w", err) } sj, err = sjson.Set(sj, "usage.output_tokens", event.Usage.OutputTokens) if err != nil { - return nil, fmt.Errorf("marshal event usage failed: %w", err) + return nil, xerrors.Errorf("marshal event usage failed: %w", err) } return s.encodeForStream([]byte(sj), event.Type), nil @@ -554,17 +555,17 @@ func (s *StreamingInterception) marshalEvent(event anthropic.MessageStreamEventU func (s *StreamingInterception) marshal(payload any) ([]byte, error) { data, err := json.Marshal(payload) if err != nil { - return nil, fmt.Errorf("marshal payload: %w", err) + return nil, xerrors.Errorf("marshal payload: %w", err) } var parsed map[string]any if err := json.Unmarshal(data, &parsed); err != nil { - return nil, fmt.Errorf("unmarshal payload: %w", err) + return nil, xerrors.Errorf("unmarshal payload: %w", err) } eventType, ok := parsed["type"].(string) if !ok || strings.TrimSpace(eventType) == "" { - return nil, fmt.Errorf("could not determine type from payload %q", data) + return nil, xerrors.Errorf("could not determine type from payload %q", data) } return s.encodeForStream(data, eventType), nil diff --git a/intercept/responses/base.go b/intercept/responses/base.go index 9949009c..7d659507 100644 --- a/intercept/responses/base.go +++ b/intercept/responses/base.go @@ -4,7 +4,6 @@ import ( "bytes" "context" "encoding/json" - "fmt" "io" "net/http" "strconv" @@ -13,6 +12,15 @@ import ( "sync/atomic" "time" + "github.com/google/uuid" + "github.com/openai/openai-go/v3/option" + "github.com/openai/openai-go/v3/responses" + "github.com/openai/openai-go/v3/shared/constant" + "github.com/tidwall/gjson" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" + "golang.org/x/xerrors" + "cdr.dev/slog/v3" "github.com/coder/aibridge/config" aibcontext "github.com/coder/aibridge/context" @@ -22,13 +30,6 @@ import ( "github.com/coder/aibridge/recorder" "github.com/coder/aibridge/tracing" "github.com/coder/quartz" - "github.com/google/uuid" - "github.com/openai/openai-go/v3/option" - "github.com/openai/openai-go/v3/responses" - "github.com/openai/openai-go/v3/shared/constant" - "github.com/tidwall/gjson" - "go.opentelemetry.io/otel/attribute" - "go.opentelemetry.io/otel/trace" ) const ( @@ -115,7 +116,7 @@ func (i *responsesInterceptionBase) baseTraceAttributes(r *http.Request, streami func (i *responsesInterceptionBase) validateRequest(ctx context.Context, w http.ResponseWriter) error { if i.reqPayload.background() { - err := fmt.Errorf("background requests are currently not supported by AI Bridge") + err := xerrors.New("background requests are currently not supported by AI Bridge") i.sendCustomErr(ctx, w, http.StatusNotImplemented, err) return err } @@ -372,11 +373,11 @@ func (r *responseCopier) forwardResp(w http.ResponseWriter) error { b, err := r.readAll() if err != nil { - return fmt.Errorf("failed to read response body: %w", err) + return xerrors.Errorf("failed to read response body: %w", err) } if _, err := w.Write(b); err != nil { - return fmt.Errorf("failed to write response body: %w", err) + return xerrors.Errorf("failed to write response body: %w", err) } return nil } diff --git a/intercept/responses/base_test.go b/intercept/responses/base_test.go index f4d117e8..e25f5922 100644 --- a/intercept/responses/base_test.go +++ b/intercept/responses/base_test.go @@ -5,12 +5,13 @@ import ( "testing" "time" - "cdr.dev/slog/v3" - "github.com/coder/aibridge/internal/testutil" - "github.com/coder/aibridge/recorder" "github.com/google/uuid" oairesponses "github.com/openai/openai-go/v3/responses" "github.com/stretchr/testify/require" + + "cdr.dev/slog/v3" + "github.com/coder/aibridge/internal/testutil" + "github.com/coder/aibridge/recorder" ) func TestRecordPrompt(t *testing.T) { diff --git a/intercept/responses/blocking.go b/intercept/responses/blocking.go index 9d263dec..c68ecb85 100644 --- a/intercept/responses/blocking.go +++ b/intercept/responses/blocking.go @@ -3,10 +3,16 @@ package responses import ( "context" "errors" - "fmt" "net/http" "time" + "github.com/google/uuid" + "github.com/openai/openai-go/v3/option" + "github.com/openai/openai-go/v3/responses" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" + "golang.org/x/xerrors" + "cdr.dev/slog/v3" "github.com/coder/aibridge/config" aibcontext "github.com/coder/aibridge/context" @@ -14,11 +20,6 @@ import ( "github.com/coder/aibridge/mcp" "github.com/coder/aibridge/recorder" "github.com/coder/aibridge/tracing" - "github.com/google/uuid" - "github.com/openai/openai-go/v3/option" - "github.com/openai/openai-go/v3/responses" - "go.opentelemetry.io/otel/attribute" - "go.opentelemetry.io/otel/trace" ) type BlockingResponsesInterceptor struct { @@ -127,7 +128,7 @@ func (i *BlockingResponsesInterceptor) ProcessRequest(w http.ResponseWriter, r * if upstreamErr != nil && !respCopy.responseReceived.Load() { // no response received from upstream, return custom error i.sendCustomErr(ctx, w, http.StatusInternalServerError, upstreamErr) - return fmt.Errorf("failed to connect to upstream: %w", upstreamErr) + return xerrors.Errorf("failed to connect to upstream: %w", upstreamErr) } err = respCopy.forwardResp(w) diff --git a/intercept/responses/injected_tools.go b/intercept/responses/injected_tools.go index fee27218..dd44014b 100644 --- a/intercept/responses/injected_tools.go +++ b/intercept/responses/injected_tools.go @@ -6,11 +6,13 @@ import ( "fmt" "strings" - "cdr.dev/slog/v3" - "github.com/coder/aibridge/recorder" "github.com/openai/openai-go/v3" "github.com/openai/openai-go/v3/responses" "github.com/openai/openai-go/v3/shared/constant" + "golang.org/x/xerrors" + + "cdr.dev/slog/v3" + "github.com/coder/aibridge/recorder" ) func (i *responsesInterceptionBase) injectTools() { @@ -80,7 +82,7 @@ func (i *responsesInterceptionBase) handleInnerAgenticLoop(ctx context.Context, // See https://platform.openai.com/docs/guides/function-calling results, err := i.handleInjectedToolCalls(ctx, pending, response) if err != nil { - return false, fmt.Errorf("failed to handle injected tool calls: %w", err) + return false, xerrors.Errorf("failed to handle injected tool calls: %w", err) } // No tool results means no tools were invocable, so the flow is complete. @@ -99,7 +101,7 @@ func (i *responsesInterceptionBase) handleInnerAgenticLoop(ctx context.Context, // Returns a list of tool call results. func (i *responsesInterceptionBase) handleInjectedToolCalls(ctx context.Context, pending []responses.ResponseFunctionToolCall, response *responses.Response) ([]responses.ResponseInputItemUnionParam, error) { if response == nil { - return nil, fmt.Errorf("empty response") + return nil, xerrors.New("empty response") } // MCP proxy has not been configured; no way to handle injected functions. @@ -133,7 +135,7 @@ func (i *responsesInterceptionBase) prepareRequestForAgenticLoop(ctx context.Con updated, err := i.reqPayload.appendInputItems(newItems) if err != nil { i.logger.Error(ctx, "failed to rewrite input in inner agentic loop", slog.Error(err)) - return fmt.Errorf("failed to rewrite input: %w", err) + return xerrors.Errorf("failed to rewrite input: %w", err) } i.reqPayload = updated diff --git a/intercept/responses/reqpayload.go b/intercept/responses/reqpayload.go index 238f356f..02086355 100644 --- a/intercept/responses/reqpayload.go +++ b/intercept/responses/reqpayload.go @@ -7,11 +7,13 @@ import ( "fmt" "strings" - "cdr.dev/slog/v3" "github.com/openai/openai-go/v3/responses" "github.com/openai/openai-go/v3/shared/constant" "github.com/tidwall/gjson" "github.com/tidwall/sjson" + "golang.org/x/xerrors" + + "cdr.dev/slog/v3" ) const ( @@ -43,10 +45,10 @@ type ResponsesRequestPayload []byte func NewResponsesRequestPayload(raw []byte) (ResponsesRequestPayload, error) { if len(bytes.TrimSpace(raw)) == 0 { - return nil, fmt.Errorf("empty request body") + return nil, xerrors.New("empty request body") } if !json.Valid(raw) { - return nil, fmt.Errorf("invalid JSON payload") + return nil, xerrors.New("invalid JSON payload") } return ResponsesRequestPayload(raw), nil @@ -108,7 +110,7 @@ func (p ResponsesRequestPayload) lastUserPrompt(ctx context.Context, logger slog // Array variant: checking only the last input item if !inputItems.IsArray() { - return "", false, fmt.Errorf("unexpected input type: %s", inputItems.Type) + return "", false, xerrors.Errorf("unexpected input type: %s", inputItems.Type) } inputItemsArr := inputItems.Array() @@ -135,7 +137,7 @@ func (p ResponsesRequestPayload) lastUserPrompt(ctx context.Context, logger slog } if !content.IsArray() { - return "", false, fmt.Errorf("unexpected input content type: %s", content.Type) + return "", false, xerrors.Errorf("unexpected input content type: %s", content.Type) } var sb strings.Builder @@ -173,7 +175,7 @@ func (p ResponsesRequestPayload) injectTools(injected []responses.ToolUnionParam existing, err := p.toolItems() if err != nil { - return p, fmt.Errorf("failed to get existing tools: %w", err) + return p, xerrors.Errorf("failed to get existing tools: %w", err) } allTools := make([]any, 0, len(existing)+len(injected)) @@ -198,7 +200,7 @@ func (p ResponsesRequestPayload) appendInputItems(items []responses.ResponseInpu existing, err := p.inputItems() if err != nil { - return p, fmt.Errorf("failed to get existing 'input' items: %w", err) + return p, xerrors.Errorf("failed to get existing 'input' items: %w", err) } allInput := make([]any, 0, len(existing)+len(items)) @@ -221,7 +223,7 @@ func (p ResponsesRequestPayload) inputItems() ([]any, error) { } if !input.IsArray() { - return nil, fmt.Errorf("unsupported 'input' type: %s", input.Type) + return nil, xerrors.Errorf("unsupported 'input' type: %s", input.Type) } items := input.Array() @@ -239,7 +241,7 @@ func (p ResponsesRequestPayload) toolItems() ([]json.RawMessage, error) { return nil, nil } if !tools.IsArray() { - return nil, fmt.Errorf("unsupported 'tools' type: %s", tools.Type) + return nil, xerrors.Errorf("unsupported 'tools' type: %s", tools.Type) } items := tools.Array() @@ -254,7 +256,7 @@ func (p ResponsesRequestPayload) toolItems() ([]json.RawMessage, error) { func (p ResponsesRequestPayload) set(path string, value any) (ResponsesRequestPayload, error) { updated, err := sjson.SetBytes(p, path, value) if err != nil { - return p, fmt.Errorf("failed to set value at path %s: %w", path, err) + return p, xerrors.Errorf("failed to set value at path %s: %w", path, err) } return updated, nil } diff --git a/intercept/responses/reqpayload_test.go b/intercept/responses/reqpayload_test.go index b0338fb6..09b74807 100644 --- a/intercept/responses/reqpayload_test.go +++ b/intercept/responses/reqpayload_test.go @@ -5,14 +5,15 @@ import ( "fmt" "testing" - "cdr.dev/slog/v3" - "github.com/coder/aibridge/fixtures" - "github.com/coder/aibridge/utils" "github.com/openai/openai-go/v3" "github.com/openai/openai-go/v3/responses" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/tidwall/gjson" + + "cdr.dev/slog/v3" + "github.com/coder/aibridge/fixtures" + "github.com/coder/aibridge/utils" ) func TestNewResponsesRequestPayload(t *testing.T) { diff --git a/intercept/responses/streaming.go b/intercept/responses/streaming.go index 0c692f83..8a354ecd 100644 --- a/intercept/responses/streaming.go +++ b/intercept/responses/streaming.go @@ -3,10 +3,18 @@ package responses import ( "context" "errors" - "fmt" "net/http" "time" + "github.com/google/uuid" + "github.com/openai/openai-go/v3/option" + "github.com/openai/openai-go/v3/packages/ssestream" + "github.com/openai/openai-go/v3/responses" + oaiconst "github.com/openai/openai-go/v3/shared/constant" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" + "golang.org/x/xerrors" + "cdr.dev/slog/v3" "github.com/coder/aibridge/config" aibcontext "github.com/coder/aibridge/context" @@ -15,13 +23,6 @@ import ( "github.com/coder/aibridge/mcp" "github.com/coder/aibridge/recorder" "github.com/coder/aibridge/tracing" - "github.com/google/uuid" - "github.com/openai/openai-go/v3/option" - "github.com/openai/openai-go/v3/packages/ssestream" - "github.com/openai/openai-go/v3/responses" - oaiconst "github.com/openai/openai-go/v3/shared/constant" - "go.opentelemetry.io/otel/attribute" - "go.opentelemetry.io/otel/trace" ) const ( @@ -162,7 +163,7 @@ func (i *StreamingResponsesInterceptor) ProcessRequest(w http.ResponseWriter, r // This is needed to keep consistency between response.id and response.previous_response_id fields. if i.mcpProxy == nil { if err := events.Send(ctx, respCopy.buff.readDelta()); err != nil { - err = fmt.Errorf("failed to relay chunk: %w", err) + err = xerrors.Errorf("failed to relay chunk: %w", err) return err } } @@ -203,7 +204,7 @@ func (i *StreamingResponsesInterceptor) ProcessRequest(w http.ResponseWriter, r b, err := respCopy.readAll() if err != nil { - return fmt.Errorf("failed to read response body: %w", err) + return xerrors.Errorf("failed to read response body: %w", err) } err = events.Send(ctx, b) diff --git a/internal/integrationtest/apidump_test.go b/internal/integrationtest/apidump_test.go index 55231b05..f3e4a741 100644 --- a/internal/integrationtest/apidump_test.go +++ b/internal/integrationtest/apidump_test.go @@ -13,12 +13,13 @@ import ( "testing" "time" + "github.com/stretchr/testify/require" + "github.com/coder/aibridge" "github.com/coder/aibridge/config" "github.com/coder/aibridge/fixtures" "github.com/coder/aibridge/intercept/apidump" "github.com/coder/aibridge/provider" - "github.com/stretchr/testify/require" ) func TestAPIDump(t *testing.T) { diff --git a/internal/integrationtest/bridge_test.go b/internal/integrationtest/bridge_test.go index 4f5f72d0..73f42ad3 100644 --- a/internal/integrationtest/bridge_test.go +++ b/internal/integrationtest/bridge_test.go @@ -15,14 +15,6 @@ import ( "github.com/anthropics/anthropic-sdk-go" "github.com/anthropics/anthropic-sdk-go/packages/ssestream" "github.com/anthropics/anthropic-sdk-go/shared/constant" - "github.com/coder/aibridge" - "github.com/coder/aibridge/config" - "github.com/coder/aibridge/fixtures" - "github.com/coder/aibridge/intercept" - "github.com/coder/aibridge/mcp" - "github.com/coder/aibridge/provider" - "github.com/coder/aibridge/recorder" - "github.com/coder/aibridge/utils" "github.com/google/uuid" "github.com/openai/openai-go/v3" oaissestream "github.com/openai/openai-go/v3/packages/ssestream" @@ -31,6 +23,16 @@ import ( "github.com/tidwall/gjson" "github.com/tidwall/sjson" "go.uber.org/goleak" + "golang.org/x/xerrors" + + "github.com/coder/aibridge" + "github.com/coder/aibridge/config" + "github.com/coder/aibridge/fixtures" + "github.com/coder/aibridge/intercept" + "github.com/coder/aibridge/mcp" + "github.com/coder/aibridge/provider" + "github.com/coder/aibridge/recorder" + "github.com/coder/aibridge/utils" ) func TestMain(m *testing.M) { @@ -628,23 +630,23 @@ func TestSimple(t *testing.T) { for stream.Next() { event := stream.Current() if err := message.Accumulate(event); err != nil { - return "", fmt.Errorf("accumulate event: %w", err) + return "", xerrors.Errorf("accumulate event: %w", err) } } if stream.Err() != nil { - return "", fmt.Errorf("stream error: %w", stream.Err()) + return "", xerrors.Errorf("stream error: %w", stream.Err()) } return message.ID, nil } body, err := io.ReadAll(resp.Body) if err != nil { - return "", fmt.Errorf("read body: %w", err) + return "", xerrors.Errorf("read body: %w", err) } var message anthropic.Message if err := json.Unmarshal(body, &message); err != nil { - return "", fmt.Errorf("unmarshal response: %w", err) + return "", xerrors.Errorf("unmarshal response: %w", err) } return message.ID, nil } @@ -660,7 +662,7 @@ func TestSimple(t *testing.T) { message.AddChunk(chunk) } if stream.Err() != nil { - return "", fmt.Errorf("stream error: %w", stream.Err()) + return "", xerrors.Errorf("stream error: %w", stream.Err()) } return message.ID, nil } @@ -668,12 +670,12 @@ func TestSimple(t *testing.T) { // Parse & unmarshal the response. body, err := io.ReadAll(resp.Body) if err != nil { - return "", fmt.Errorf("read body: %w", err) + return "", xerrors.Errorf("read body: %w", err) } var message openai.ChatCompletion if err := json.Unmarshal(body, &message); err != nil { - return "", fmt.Errorf("unmarshal response: %w", err) + return "", xerrors.Errorf("unmarshal response: %w", err) } return message.ID, nil } diff --git a/internal/integrationtest/circuit_breaker_test.go b/internal/integrationtest/circuit_breaker_test.go index 4e392649..35beaee6 100644 --- a/internal/integrationtest/circuit_breaker_test.go +++ b/internal/integrationtest/circuit_breaker_test.go @@ -11,13 +11,14 @@ import ( "testing" "time" - "github.com/coder/aibridge/config" - "github.com/coder/aibridge/metrics" - "github.com/coder/aibridge/provider" "github.com/prometheus/client_golang/prometheus" promtest "github.com/prometheus/client_golang/prometheus/testutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/coder/aibridge/config" + "github.com/coder/aibridge/metrics" + "github.com/coder/aibridge/provider" ) // Common response bodies for circuit breaker tests. diff --git a/internal/integrationtest/metrics_test.go b/internal/integrationtest/metrics_test.go index 29ab16e5..6f1dae08 100644 --- a/internal/integrationtest/metrics_test.go +++ b/internal/integrationtest/metrics_test.go @@ -9,14 +9,15 @@ import ( "testing" "time" - "github.com/coder/aibridge" - "github.com/coder/aibridge/config" - "github.com/coder/aibridge/fixtures" - "github.com/coder/aibridge/metrics" "github.com/prometheus/client_golang/prometheus" promtest "github.com/prometheus/client_golang/prometheus/testutil" "github.com/stretchr/testify/require" "github.com/tidwall/sjson" + + "github.com/coder/aibridge" + "github.com/coder/aibridge/config" + "github.com/coder/aibridge/fixtures" + "github.com/coder/aibridge/metrics" ) func TestMetrics_Interception(t *testing.T) { diff --git a/internal/integrationtest/mockmcp.go b/internal/integrationtest/mockmcp.go index eba25dd1..5f82f3f4 100644 --- a/internal/integrationtest/mockmcp.go +++ b/internal/integrationtest/mockmcp.go @@ -2,7 +2,6 @@ package integrationtest import ( "context" - "errors" "fmt" "net/http" "net/http/httptest" @@ -10,15 +9,17 @@ import ( "testing" "time" - "cdr.dev/slog/v3" - "cdr.dev/slog/v3/sloggers/slogtest" - "github.com/coder/aibridge/mcp" "github.com/mark3labs/mcp-go/client/transport" mcplib "github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/server" "github.com/stretchr/testify/require" "go.opentelemetry.io/otel/trace" "go.opentelemetry.io/otel/trace/noop" + "golang.org/x/xerrors" + + "cdr.dev/slog/v3" + "cdr.dev/slog/v3/sloggers/slogtest" + "github.com/coder/aibridge/mcp" ) // mockToolName is the primary mock tool name used in MCP tests. @@ -143,7 +144,7 @@ func createMockMCPSrv(t *testing.T) (http.Handler, *callAccumulator) { s.AddTool(tool, func(ctx context.Context, request mcplib.CallToolRequest) (*mcplib.CallToolResult, error) { acc.addCall(request.Params.Name, request.Params.Arguments) if errMsg, ok := acc.getToolError(request.Params.Name); ok { - return nil, errors.New(errMsg) + return nil, xerrors.New(errMsg) } return mcplib.NewToolResultText("mock"), nil }) diff --git a/internal/integrationtest/mockupstream.go b/internal/integrationtest/mockupstream.go index a658b054..4112fea8 100644 --- a/internal/integrationtest/mockupstream.go +++ b/internal/integrationtest/mockupstream.go @@ -17,10 +17,11 @@ import ( "testing" "github.com/anthropics/anthropic-sdk-go" - "github.com/coder/aibridge/fixtures" "github.com/openai/openai-go/v3" "github.com/stretchr/testify/require" "github.com/tidwall/gjson" + + "github.com/coder/aibridge/fixtures" ) // upstreamResponse defines a single response that mockUpstream will replay diff --git a/internal/integrationtest/responses_test.go b/internal/integrationtest/responses_test.go index 2dc6466c..6861ccb8 100644 --- a/internal/integrationtest/responses_test.go +++ b/internal/integrationtest/responses_test.go @@ -14,16 +14,17 @@ import ( "testing" "time" + "github.com/openai/openai-go/v3/responses" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/tidwall/sjson" + "github.com/coder/aibridge" "github.com/coder/aibridge/config" "github.com/coder/aibridge/fixtures" "github.com/coder/aibridge/provider" "github.com/coder/aibridge/recorder" "github.com/coder/aibridge/utils" - "github.com/openai/openai-go/v3/responses" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/tidwall/sjson" ) type keyVal struct { diff --git a/internal/integrationtest/setupbridge.go b/internal/integrationtest/setupbridge.go index bb999d21..5c91953b 100644 --- a/internal/integrationtest/setupbridge.go +++ b/internal/integrationtest/setupbridge.go @@ -9,6 +9,11 @@ import ( "testing" "time" + "github.com/stretchr/testify/require" + "github.com/tidwall/sjson" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/trace" + "cdr.dev/slog/v3" "github.com/coder/aibridge" "github.com/coder/aibridge/config" @@ -19,10 +24,6 @@ import ( "github.com/coder/aibridge/metrics" "github.com/coder/aibridge/provider" "github.com/coder/aibridge/recorder" - "github.com/stretchr/testify/require" - "github.com/tidwall/sjson" - "go.opentelemetry.io/otel" - "go.opentelemetry.io/otel/trace" ) const ( diff --git a/internal/integrationtest/trace_test.go b/internal/integrationtest/trace_test.go index e9b27d64..b19707a0 100644 --- a/internal/integrationtest/trace_test.go +++ b/internal/integrationtest/trace_test.go @@ -8,9 +8,6 @@ import ( "testing" "time" - "github.com/coder/aibridge/config" - "github.com/coder/aibridge/fixtures" - "github.com/coder/aibridge/tracing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/tidwall/gjson" @@ -20,6 +17,10 @@ import ( sdktrace "go.opentelemetry.io/otel/sdk/trace" "go.opentelemetry.io/otel/sdk/trace/tracetest" oteltrace "go.opentelemetry.io/otel/trace" + + "github.com/coder/aibridge/config" + "github.com/coder/aibridge/fixtures" + "github.com/coder/aibridge/tracing" ) // expect 'count' amount of traces named 'name' with status 'status' diff --git a/internal/testutil/mock_recorder.go b/internal/testutil/mock_recorder.go index b5366945..4d9b5636 100644 --- a/internal/testutil/mock_recorder.go +++ b/internal/testutil/mock_recorder.go @@ -2,14 +2,15 @@ package testutil import ( "context" - "fmt" "slices" "strings" "sync" "testing" - "github.com/coder/aibridge/recorder" "github.com/stretchr/testify/require" + "golang.org/x/xerrors" + + "github.com/coder/aibridge/recorder" ) // MockRecorder is a test implementation of aibridge.Recorder that @@ -39,7 +40,7 @@ func (m *MockRecorder) RecordInterceptionEnded(ctx context.Context, req *recorde m.interceptionsEnd = make(map[string]*recorder.InterceptionRecordEnded) } if !slices.ContainsFunc(m.interceptions, func(intc *recorder.InterceptionRecord) bool { return intc.ID == req.ID }) { - return fmt.Errorf("id not found") + return xerrors.New("id not found") } m.interceptionsEnd[req.ID] = req return nil diff --git a/internal/testutil/mockprovider.go b/internal/testutil/mockprovider.go index 06b8a2f2..6ef9175f 100644 --- a/internal/testutil/mockprovider.go +++ b/internal/testutil/mockprovider.go @@ -4,9 +4,10 @@ import ( "fmt" "net/http" + "go.opentelemetry.io/otel/trace" + "github.com/coder/aibridge/config" "github.com/coder/aibridge/intercept" - "go.opentelemetry.io/otel/trace" ) type MockProvider struct { diff --git a/mcp/client_info.go b/mcp/client_info.go index 84b33d09..7dea5c53 100644 --- a/mcp/client_info.go +++ b/mcp/client_info.go @@ -1,8 +1,9 @@ package mcp import ( - "github.com/coder/aibridge/buildinfo" "github.com/mark3labs/mcp-go/mcp" + + "github.com/coder/aibridge/buildinfo" ) // GetClientInfo returns the MCP client information to use when initializing MCP connections. diff --git a/mcp/client_info_test.go b/mcp/client_info_test.go index d273d10d..a48487b6 100644 --- a/mcp/client_info_test.go +++ b/mcp/client_info_test.go @@ -3,8 +3,9 @@ package mcp_test import ( "testing" - "github.com/coder/aibridge/mcp" "github.com/stretchr/testify/assert" + + "github.com/coder/aibridge/mcp" ) func TestGetClientInfo(t *testing.T) { diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index 9cb548a2..5769440e 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -11,15 +11,17 @@ import ( "testing" "time" - "cdr.dev/slog/v3" - "cdr.dev/slog/v3/sloggers/slogtest" "go.opentelemetry.io/otel" "go.uber.org/goleak" - "github.com/coder/aibridge/mcp" + "cdr.dev/slog/v3" + "cdr.dev/slog/v3/sloggers/slogtest" + "github.com/mark3labs/mcp-go/server" "github.com/stretchr/testify/require" + "github.com/coder/aibridge/mcp" + mcplib "github.com/mark3labs/mcp-go/mcp" ) diff --git a/mcp/proxy_streamable_http.go b/mcp/proxy_streamable_http.go index 9a6407e0..2ba8f2ad 100644 --- a/mcp/proxy_streamable_http.go +++ b/mcp/proxy_streamable_http.go @@ -2,19 +2,20 @@ package mcp import ( "context" - "fmt" "regexp" "slices" "strings" - "cdr.dev/slog/v3" - "github.com/coder/aibridge/tracing" "github.com/mark3labs/mcp-go/client" "github.com/mark3labs/mcp-go/client/transport" "github.com/mark3labs/mcp-go/mcp" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/trace" "golang.org/x/exp/maps" + "golang.org/x/xerrors" + + "cdr.dev/slog/v3" + "github.com/coder/aibridge/tracing" ) var _ ServerProxier = &StreamableHTTPServerProxy{} @@ -40,7 +41,7 @@ func NewStreamableHTTPServerProxy(serverName, serverURL string, headers map[stri mcpClient, err := client.NewStreamableHttpClient(serverURL, opts...) if err != nil { - return nil, fmt.Errorf("create streamable http client: %w", err) + return nil, xerrors.Errorf("create streamable http client: %w", err) } return &StreamableHTTPServerProxy{ @@ -63,7 +64,7 @@ func (p *StreamableHTTPServerProxy) Init(ctx context.Context) (outErr error) { defer tracing.EndSpanErr(span, &outErr) if err := p.client.Start(ctx); err != nil { - return fmt.Errorf("start client: %w", err) + return xerrors.Errorf("start client: %w", err) } version := mcp.LATEST_PROTOCOL_VERSION @@ -76,21 +77,21 @@ func (p *StreamableHTTPServerProxy) Init(ctx context.Context) (outErr error) { result, err := p.client.Initialize(ctx, initReq) if err != nil { - return fmt.Errorf("init MCP client: %w", err) + return xerrors.Errorf("init MCP client: %w", err) } if !slices.Contains(mcp.ValidProtocolVersions, result.ProtocolVersion) { if err := p.client.Close(); err != nil { p.logger.Debug(ctx, "failed to close MCP client on unsuccessful version negotiation", slog.Error(err)) } - return fmt.Errorf("MCP version negotiation failed; requested %q, accepts %q, received %q", version, strings.Join(mcp.ValidProtocolVersions, ","), result.ProtocolVersion) + return xerrors.Errorf("MCP version negotiation failed; requested %q, accepts %q, received %q", version, strings.Join(mcp.ValidProtocolVersions, ","), result.ProtocolVersion) } p.logger.Debug(ctx, "MCP client initialized", slog.F("name", result.ServerInfo.Name), slog.F("server_version", result.ServerInfo.Version)) tools, err := p.fetchTools(ctx) if err != nil { - return fmt.Errorf("fetch tools: %w", err) + return xerrors.Errorf("fetch tools: %w", err) } // Only include allowed tools. @@ -121,7 +122,7 @@ func (p *StreamableHTTPServerProxy) GetTool(name string) *Tool { func (p *StreamableHTTPServerProxy) CallTool(ctx context.Context, name string, input any) (*mcp.CallToolResult, error) { tool := p.GetTool(name) if tool == nil { - return nil, fmt.Errorf("%q tool not known", name) + return nil, xerrors.Errorf("%q tool not known", name) } return p.client.CallTool(ctx, mcp.CallToolRequest{ @@ -138,7 +139,7 @@ func (p *StreamableHTTPServerProxy) fetchTools(ctx context.Context) (_ map[strin tools, err := p.client.ListTools(ctx, mcp.ListToolsRequest{}) if err != nil { - return nil, fmt.Errorf("list MCP tools: %w", err) + return nil, xerrors.Errorf("list MCP tools: %w", err) } out := make(map[string]*Tool, len(tools.Tools)) diff --git a/mcp/server_proxy_manager.go b/mcp/server_proxy_manager.go index 01c87909..58e15214 100644 --- a/mcp/server_proxy_manager.go +++ b/mcp/server_proxy_manager.go @@ -2,15 +2,16 @@ package mcp import ( "context" - "fmt" "slices" "strings" "sync" - "github.com/coder/aibridge/tracing" - "github.com/coder/aibridge/utils" "github.com/mark3labs/mcp-go/mcp" "go.opentelemetry.io/otel/trace" + "golang.org/x/xerrors" + + "github.com/coder/aibridge/tracing" + "github.com/coder/aibridge/utils" ) var _ ServerProxier = &ServerProxyManager{} @@ -106,12 +107,12 @@ func (s *ServerProxyManager) ListTools() []*Tool { func (s *ServerProxyManager) CallTool(ctx context.Context, name string, input any) (*mcp.CallToolResult, error) { tool := s.GetTool(name) if tool == nil { - return nil, fmt.Errorf("%q tool not known", name) + return nil, xerrors.Errorf("%q tool not known", name) } proxy, ok := s.proxiers[tool.ServerName] if !ok { - return nil, fmt.Errorf("%q server not known", tool.ServerName) + return nil, xerrors.Errorf("%q server not known", tool.ServerName) } return proxy.CallTool(ctx, name, input) diff --git a/mcp/tool.go b/mcp/tool.go index cddf6271..846928d8 100644 --- a/mcp/tool.go +++ b/mcp/tool.go @@ -3,16 +3,17 @@ package mcp import ( "context" "encoding/json" - "errors" "regexp" "strings" "time" - "cdr.dev/slog/v3" - "github.com/coder/aibridge/tracing" "github.com/mark3labs/mcp-go/mcp" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/trace" + "golang.org/x/xerrors" + + "cdr.dev/slog/v3" + "github.com/coder/aibridge/tracing" ) const ( @@ -21,7 +22,7 @@ const ( injectedToolDelimiter = "_" ) -// ToolCaller is the narrowest interface which describes the behaviour required from [mcp.Client], +// ToolCaller is the narrowest interface which describes the behavior required from [mcp.Client], // which will normally be passed into [Tool] for interaction with an MCP server. // TODO: don't expose github.com/mark3labs/mcp-go outside this package. type ToolCaller interface { @@ -43,10 +44,10 @@ type Tool struct { func (t *Tool) Call(ctx context.Context, input any, tracer trace.Tracer) (_ *mcp.CallToolResult, outErr error) { if t == nil { - return nil, errors.New("nil tool") + return nil, xerrors.New("nil tool") } if t.Client == nil { - return nil, errors.New("nil client") + return nil, xerrors.New("nil client") } spanAttrs := append( diff --git a/passthrough.go b/passthrough.go index c6b59edd..5e1efe6a 100644 --- a/passthrough.go +++ b/passthrough.go @@ -7,15 +7,16 @@ import ( "net/url" "time" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/codes" + "go.opentelemetry.io/otel/trace" + "cdr.dev/slog/v3" "github.com/coder/aibridge/intercept/apidump" "github.com/coder/aibridge/metrics" "github.com/coder/aibridge/provider" "github.com/coder/aibridge/tracing" "github.com/coder/quartz" - "go.opentelemetry.io/otel/attribute" - "go.opentelemetry.io/otel/codes" - "go.opentelemetry.io/otel/trace" ) // newPassthroughRouter returns a simple reverse-proxy implementation which will be used when a route is not handled specifically diff --git a/passthrough_test.go b/passthrough_test.go index c51b6d4f..8f219c7c 100644 --- a/passthrough_test.go +++ b/passthrough_test.go @@ -5,10 +5,11 @@ import ( "net/http/httptest" "testing" - "cdr.dev/slog/v3/sloggers/slogtest" - "github.com/coder/aibridge/internal/testutil" "github.com/stretchr/testify/assert" "go.opentelemetry.io/otel" + + "cdr.dev/slog/v3/sloggers/slogtest" + "github.com/coder/aibridge/internal/testutil" ) var testTracer = otel.Tracer("bridge_test") diff --git a/provider/anthropic.go b/provider/anthropic.go index 44870c63..e5132880 100644 --- a/provider/anthropic.go +++ b/provider/anthropic.go @@ -7,15 +7,17 @@ import ( "os" "strings" + "github.com/google/uuid" + "go.opentelemetry.io/otel/codes" + "go.opentelemetry.io/otel/trace" + "golang.org/x/xerrors" + "github.com/coder/aibridge/circuitbreaker" "github.com/coder/aibridge/config" "github.com/coder/aibridge/intercept" "github.com/coder/aibridge/intercept/messages" "github.com/coder/aibridge/tracing" "github.com/coder/aibridge/utils" - "github.com/google/uuid" - "go.opentelemetry.io/otel/codes" - "go.opentelemetry.io/otel/trace" ) // anthropicForwardHeaders lists headers from incoming requests that should be @@ -106,12 +108,12 @@ func (p *Anthropic) CreateInterceptor(w http.ResponseWriter, r *http.Request, tr case routeMessages: payload, err := io.ReadAll(r.Body) if err != nil { - return nil, fmt.Errorf("read body: %w", err) + return nil, xerrors.Errorf("read body: %w", err) } reqPayload, err := messages.NewMessagesRequestPayload(payload) if err != nil { - return nil, fmt.Errorf("unmarshal request body: %w", err) + return nil, xerrors.Errorf("unmarshal request body: %w", err) } cfg := p.cfg diff --git a/provider/anthropic_test.go b/provider/anthropic_test.go index bc240a46..1d2bc1dd 100644 --- a/provider/anthropic_test.go +++ b/provider/anthropic_test.go @@ -6,10 +6,11 @@ import ( "net/http/httptest" "testing" - "cdr.dev/slog/v3" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "cdr.dev/slog/v3" + "github.com/coder/aibridge/config" "github.com/coder/aibridge/intercept" "github.com/coder/aibridge/internal/testutil" diff --git a/provider/copilot.go b/provider/copilot.go index a7df6e8b..735c9b83 100644 --- a/provider/copilot.go +++ b/provider/copilot.go @@ -8,15 +8,17 @@ import ( "os" "strings" + "github.com/google/uuid" + "go.opentelemetry.io/otel/codes" + "go.opentelemetry.io/otel/trace" + "golang.org/x/xerrors" + "github.com/coder/aibridge/config" "github.com/coder/aibridge/intercept" "github.com/coder/aibridge/intercept/chatcompletions" "github.com/coder/aibridge/intercept/responses" "github.com/coder/aibridge/tracing" "github.com/coder/aibridge/utils" - "github.com/google/uuid" - "go.opentelemetry.io/otel/codes" - "go.opentelemetry.io/otel/trace" ) const ( @@ -129,7 +131,7 @@ func (p *Copilot) CreateInterceptor(_ http.ResponseWriter, r *http.Request, trac key := utils.ExtractBearerToken(r.Header.Get("Authorization")) if key == "" { span.SetStatus(codes.Error, "missing authorization") - return nil, fmt.Errorf("missing Copilot authorization: Authorization header not found or invalid") + return nil, xerrors.New("missing Copilot authorization: Authorization header not found or invalid") } id := uuid.New() @@ -154,7 +156,7 @@ func (p *Copilot) CreateInterceptor(_ http.ResponseWriter, r *http.Request, trac case routeCopilotChatCompletions: var req chatcompletions.ChatCompletionNewParamsWrapper if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - return nil, fmt.Errorf("unmarshal chat completions request body: %w", err) + return nil, xerrors.Errorf("unmarshal chat completions request body: %w", err) } if req.Stream { @@ -166,11 +168,11 @@ func (p *Copilot) CreateInterceptor(_ http.ResponseWriter, r *http.Request, trac case routeCopilotResponses: payload, err := io.ReadAll(r.Body) if err != nil { - return nil, fmt.Errorf("read body: %w", err) + return nil, xerrors.Errorf("read body: %w", err) } reqPayload, err := responses.NewResponsesRequestPayload(payload) if err != nil { - return nil, fmt.Errorf("unmarshal request body: %w", err) + return nil, xerrors.Errorf("unmarshal request body: %w", err) } if reqPayload.Stream() { diff --git a/provider/copilot_test.go b/provider/copilot_test.go index 4fea128b..897f82a5 100644 --- a/provider/copilot_test.go +++ b/provider/copilot_test.go @@ -6,11 +6,12 @@ import ( "net/http/httptest" "testing" - "cdr.dev/slog/v3" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.opentelemetry.io/otel" + "cdr.dev/slog/v3" + "github.com/coder/aibridge/config" "github.com/coder/aibridge/internal/testutil" ) diff --git a/provider/openai.go b/provider/openai.go index a8e86216..ae594521 100644 --- a/provider/openai.go +++ b/provider/openai.go @@ -8,15 +8,17 @@ import ( "os" "strings" + "github.com/google/uuid" + "go.opentelemetry.io/otel/codes" + "go.opentelemetry.io/otel/trace" + "golang.org/x/xerrors" + "github.com/coder/aibridge/config" "github.com/coder/aibridge/intercept" "github.com/coder/aibridge/intercept/chatcompletions" "github.com/coder/aibridge/intercept/responses" "github.com/coder/aibridge/tracing" "github.com/coder/aibridge/utils" - "github.com/google/uuid" - "go.opentelemetry.io/otel/codes" - "go.opentelemetry.io/otel/trace" ) const ( @@ -125,7 +127,7 @@ func (p *OpenAI) CreateInterceptor(w http.ResponseWriter, r *http.Request, trace case routeChatCompletions: var req chatcompletions.ChatCompletionNewParamsWrapper if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - return nil, fmt.Errorf("unmarshal request body: %w", err) + return nil, xerrors.Errorf("unmarshal request body: %w", err) } if req.Stream { @@ -137,11 +139,11 @@ func (p *OpenAI) CreateInterceptor(w http.ResponseWriter, r *http.Request, trace case routeResponses: payload, err := io.ReadAll(r.Body) if err != nil { - return nil, fmt.Errorf("read body: %w", err) + return nil, xerrors.Errorf("read body: %w", err) } reqPayload, err := responses.NewResponsesRequestPayload(payload) if err != nil { - return nil, fmt.Errorf("unmarshal request body: %w", err) + return nil, xerrors.Errorf("unmarshal request body: %w", err) } if reqPayload.Stream() { interceptor = responses.NewStreamingInterceptor(id, reqPayload, p.Name(), cfg, r.Header, p.AuthHeader(), tracer, cred) diff --git a/provider/openai_test.go b/provider/openai_test.go index 0c715cc8..80e5097e 100644 --- a/provider/openai_test.go +++ b/provider/openai_test.go @@ -9,14 +9,15 @@ import ( "strings" "testing" - "cdr.dev/slog/v3" - "github.com/coder/aibridge/config" - "github.com/coder/aibridge/intercept" - "github.com/coder/aibridge/internal/testutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.opentelemetry.io/otel/trace/noop" "golang.org/x/sync/errgroup" + + "cdr.dev/slog/v3" + "github.com/coder/aibridge/config" + "github.com/coder/aibridge/intercept" + "github.com/coder/aibridge/internal/testutil" ) const ( diff --git a/provider/provider.go b/provider/provider.go index 0e9fca3e..4d76344e 100644 --- a/provider/provider.go +++ b/provider/provider.go @@ -1,15 +1,16 @@ package provider import ( - "errors" "net/http" + "go.opentelemetry.io/otel/trace" + "golang.org/x/xerrors" + "github.com/coder/aibridge/config" "github.com/coder/aibridge/intercept" - "go.opentelemetry.io/otel/trace" ) -var UnknownRoute = errors.New("unknown route") +var UnknownRoute = xerrors.New("unknown route") // Provider defines routes (bridged and passed through) for given provider. // Bridged routes are processed by dedicated interceptors. @@ -23,7 +24,7 @@ var UnknownRoute = errors.New("unknown route") // When request is bridged, interceptor created based on route processes the request. // When request is passed through the {host} + {aibridge root} + {provider prefix} URL part // is replaced by provider's base URL and request is forwarded. -// This mirrors behaviour in bridged routes and SDKs used by interceptors. +// This mirrors behavior in bridged routes and SDKs used by interceptors. // // Example: // diff --git a/recorder/recorder.go b/recorder/recorder.go index 795845ef..f87b0800 100644 --- a/recorder/recorder.go +++ b/recorder/recorder.go @@ -2,15 +2,17 @@ package recorder import ( "context" - "fmt" "sync" "time" + "golang.org/x/xerrors" + "cdr.dev/slog/v3" + "go.opentelemetry.io/otel/trace" + "github.com/coder/aibridge/metrics" "github.com/coder/aibridge/tracing" - "go.opentelemetry.io/otel/trace" ) var ( @@ -32,7 +34,7 @@ func (r *RecorderWrapper) RecordInterception(ctx context.Context, req *Intercept client, err := r.clientFn() if err != nil { - return fmt.Errorf("acquire client: %w", err) + return xerrors.Errorf("acquire client: %w", err) } req.StartedAt = time.Now() @@ -50,7 +52,7 @@ func (r *RecorderWrapper) RecordInterceptionEnded(ctx context.Context, req *Inte client, err := r.clientFn() if err != nil { - return fmt.Errorf("acquire client: %w", err) + return xerrors.Errorf("acquire client: %w", err) } req.EndedAt = time.Now().UTC() @@ -68,7 +70,7 @@ func (r *RecorderWrapper) RecordPromptUsage(ctx context.Context, req *PromptUsag client, err := r.clientFn() if err != nil { - return fmt.Errorf("acquire client: %w", err) + return xerrors.Errorf("acquire client: %w", err) } req.CreatedAt = time.Now() @@ -86,7 +88,7 @@ func (r *RecorderWrapper) RecordTokenUsage(ctx context.Context, req *TokenUsageR client, err := r.clientFn() if err != nil { - return fmt.Errorf("acquire client: %w", err) + return xerrors.Errorf("acquire client: %w", err) } req.CreatedAt = time.Now() @@ -104,7 +106,7 @@ func (r *RecorderWrapper) RecordToolUsage(ctx context.Context, req *ToolUsageRec client, err := r.clientFn() if err != nil { - return fmt.Errorf("acquire client: %w", err) + return xerrors.Errorf("acquire client: %w", err) } req.CreatedAt = time.Now() @@ -122,7 +124,7 @@ func (r *RecorderWrapper) RecordModelThought(ctx context.Context, req *ModelThou client, err := r.clientFn() if err != nil { - return fmt.Errorf("acquire client: %w", err) + return xerrors.Errorf("acquire client: %w", err) } req.CreatedAt = time.Now() diff --git a/session.go b/session.go index 2d99db1c..e89ad784 100644 --- a/session.go +++ b/session.go @@ -7,8 +7,9 @@ import ( "regexp" "strings" - "github.com/coder/aibridge/utils" "github.com/tidwall/gjson" + + "github.com/coder/aibridge/utils" ) var claudeCodePattern = regexp.MustCompile(`_session_(.+)$`) // Legacy format: save compilation on each call. diff --git a/session_test.go b/session_test.go index 7e7ccaca..2f952d35 100644 --- a/session_test.go +++ b/session_test.go @@ -6,8 +6,9 @@ import ( "strings" "testing" - "github.com/coder/aibridge/utils" "github.com/stretchr/testify/require" + + "github.com/coder/aibridge/utils" ) func TestGuessSessionID(t *testing.T) { diff --git a/utils/auth_test.go b/utils/auth_test.go index b16b7725..eea3cdc5 100644 --- a/utils/auth_test.go +++ b/utils/auth_test.go @@ -3,8 +3,9 @@ package utils_test import ( "testing" - "github.com/coder/aibridge/utils" "github.com/stretchr/testify/assert" + + "github.com/coder/aibridge/utils" ) func TestExtractBearerToken(t *testing.T) { diff --git a/utils/concurrent_group_test.go b/utils/concurrent_group_test.go index 516ce0d7..36ca4481 100644 --- a/utils/concurrent_group_test.go +++ b/utils/concurrent_group_test.go @@ -1,12 +1,13 @@ package utils_test import ( - "errors" "testing" - "github.com/coder/aibridge/utils" "github.com/stretchr/testify/require" "go.uber.org/goleak" + "golang.org/x/xerrors" + + "github.com/coder/aibridge/utils" ) func TestMain(m *testing.M) { @@ -34,7 +35,7 @@ func TestConcurrentGroup(t *testing.T) { t.Run("multiple goroutines, one err", func(t *testing.T) { cg := utils.NewConcurrentGroup() - oops := errors.New("oops") + oops := xerrors.New("oops") cg.Go(func() error { return oops }) @@ -46,8 +47,8 @@ func TestConcurrentGroup(t *testing.T) { t.Run("multiple goroutines, multiple errs", func(t *testing.T) { cg := utils.NewConcurrentGroup() - oops := errors.New("oops") - eek := errors.New("eek") + oops := xerrors.New("oops") + eek := xerrors.New("eek") cg.Go(func() error { return oops }) diff --git a/utils/mask_test.go b/utils/mask_test.go index f71b8cf3..9ce65529 100644 --- a/utils/mask_test.go +++ b/utils/mask_test.go @@ -3,8 +3,9 @@ package utils_test import ( "testing" - "github.com/coder/aibridge/utils" "github.com/stretchr/testify/assert" + + "github.com/coder/aibridge/utils" ) func TestMaskSecret(t *testing.T) {