Skip to content

Commit 9e2857a

Browse files
authored
perf: skip redundant openai request serialization (#160)
* use dannykopping/anthropic-sdk-go to avoid appendCompact * make fmt * perf: use the more efficient sasswart/openai-go * perf: avoid an unnecessary json unmarshal when we intercept chat completions requests * perf: reduce allocations when creating chat completions interceptors * uncomment benchmark * make fmt * update openai-go dependency * chore: document why we replace llm provider sdks
1 parent 406f98f commit 9e2857a

8 files changed

Lines changed: 292 additions & 72 deletions

File tree

go.mod

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ require (
1818
go.uber.org/goleak v1.3.0
1919
go.uber.org/mock v0.6.0
2020
golang.org/x/exp v0.0.0-20250819193227-8b4c13bb791b
21+
golang.org/x/sync v0.16.0
2122
golang.org/x/tools v0.36.0
2223
)
2324

@@ -88,4 +89,9 @@ require (
8889
gopkg.in/yaml.v3 v3.0.1 // indirect
8990
)
9091

92+
// Replace sdks with our own optimized forks until relevant upstream PRs are merged.
93+
// https://github.com/anthropics/anthropic-sdk-go/pull/262
9194
replace github.com/anthropics/anthropic-sdk-go v1.13.0 => github.com/dannykopping/anthropic-sdk-go v0.0.0-20251230111224-88a4315810bd
95+
96+
// https://github.com/openai/openai-go/pull/602
97+
replace github.com/openai/openai-go/v3 => github.com/SasSwart/openai-go/v3 v3.0.0-20260202093810-72af3b857f95

go.sum

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ cloud.google.com/go/logging v1.8.1 h1:26skQWPeYhvIasWKm48+Eq7oUqdcdbwsCVwz5Ys0Fv
77
cloud.google.com/go/logging v1.8.1/go.mod h1:TJjR+SimHwuC8MZ9cjByQulAMgni+RkXeI3wwctHJEI=
88
cloud.google.com/go/longrunning v0.5.1 h1:Fr7TXftcqTudoyRJa113hyaqlGdiBQkp0Gq7tErFDWI=
99
cloud.google.com/go/longrunning v0.5.1/go.mod h1:spvimkwdz6SPWKEt/XBij79E9fiTkHSQl/fRUUQJYJc=
10+
github.com/SasSwart/openai-go/v3 v3.0.0-20260202093810-72af3b857f95 h1:HVJp3FanNaeFAlwg0/lkdkSnwFemHnwwjXBM8KRj540=
11+
github.com/SasSwart/openai-go/v3 v3.0.0-20260202093810-72af3b857f95/go.mod h1:cdufnVK14cWcT9qA1rRtrXx4FTRsgbDPW7Ia7SS5cZo=
1012
github.com/aws/aws-sdk-go-v2 v1.30.3 h1:jUeBtG0Ih+ZIFH0F4UkmL9w3cSpaMv9tYYDbzILP8dY=
1113
github.com/aws/aws-sdk-go-v2 v1.30.3/go.mod h1:nIQjQVp5sfpQcTc9mPSr1B0PaWK5ByX9MOoDadSN4lc=
1214
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.3 h1:tW1/Rkad38LA15X4UQtjXZXNKsCgkshC3EbmcUmghTg=
@@ -94,8 +96,6 @@ github.com/muesli/termenv v0.15.2 h1:GohcuySI0QmI3wN8Ok9PtKGkgkFIk7y6Vpb5PvrY+Wo
9496
github.com/muesli/termenv v0.15.2/go.mod h1:Epx+iuz8sNs7mNKhxzH4fWXGNpZwUaJKRS1noLXviQ8=
9597
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
9698
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
97-
github.com/openai/openai-go/v3 v3.15.0 h1:hk99rM7YPz+M99/5B/zOQcVwFRLLMdprVGx1vaZ8XMo=
98-
github.com/openai/openai-go/v3 v3.15.0/go.mod h1:cdufnVK14cWcT9qA1rRtrXx4FTRsgbDPW7Ia7SS5cZo=
9999
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
100100
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
101101
github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o=
@@ -155,6 +155,8 @@ golang.org/x/exp v0.0.0-20250819193227-8b4c13bb791b h1:DXr+pvt3nC887026GRP39Ej11
155155
golang.org/x/exp v0.0.0-20250819193227-8b4c13bb791b/go.mod h1:4QTo5u+SEIbbKW1RacMZq1YEfOBqeXa19JeshGi+zc4=
156156
golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE=
157157
golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg=
158+
golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw=
159+
golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
158160
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
159161
golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI=
160162
golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=

intercept/chatcompletions/paramswrap.go

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"github.com/coder/aibridge/utils"
77
"github.com/openai/openai-go/v3"
88
"github.com/openai/openai-go/v3/packages/param"
9+
"github.com/tidwall/gjson"
910
)
1011

1112
// ChatCompletionNewParamsWrapper exists because the "stream" param is not included in openai.ChatCompletionNewParams.
@@ -27,14 +28,10 @@ func (c *ChatCompletionNewParamsWrapper) UnmarshalJSON(raw []byte) error {
2728
return err
2829
}
2930

30-
if stream := utils.ExtractJSONField[bool](raw, "stream"); stream {
31-
c.Stream = stream
32-
if c.Stream {
33-
c.ChatCompletionNewParams.StreamOptions = openai.ChatCompletionStreamOptionsParam{
34-
IncludeUsage: openai.Bool(true), // Always include usage when streaming.
35-
}
36-
} else {
37-
c.ChatCompletionNewParams.StreamOptions = openai.ChatCompletionStreamOptionsParam{}
31+
c.Stream = gjson.GetBytes(raw, "stream").Bool()
32+
if c.Stream {
33+
c.ChatCompletionNewParams.StreamOptions = openai.ChatCompletionStreamOptionsParam{
34+
IncludeUsage: openai.Bool(true), // Always include usage when streaming.
3835
}
3936
} else {
4037
c.ChatCompletionNewParams.StreamOptions = openai.ChatCompletionStreamOptionsParam{}

intercept/chatcompletions/paramswrap_test.go

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package chatcompletions
22

33
import (
4+
"fmt"
5+
"strings"
46
"testing"
57

68
"github.com/openai/openai-go/v3"
@@ -130,3 +132,41 @@ func TestOpenAILastUserPrompt(t *testing.T) {
130132
})
131133
}
132134
}
135+
136+
// generatePayload creates a JSON payload with the specified number of messages.
137+
// Messages alternate between user and assistant roles to simulate a conversation.
138+
func generatePayload(messageCount int) []byte {
139+
var messages []string
140+
for i := range messageCount {
141+
role := "user"
142+
if i%2 == 1 {
143+
role = "assistant"
144+
}
145+
// Use realistic message content size
146+
content := fmt.Sprintf("This is message number %d with some realistic content that might appear in a conversation.", i+1)
147+
messages = append(messages, fmt.Sprintf(`{"role": "%s", "content": "%s"}`, role, content))
148+
}
149+
150+
return []byte(fmt.Sprintf(`{
151+
"model": "gpt-4",
152+
"stream": true,
153+
"messages": [%s]
154+
}`, strings.Join(messages, ",")))
155+
}
156+
157+
func BenchmarkChatCompletionNewParamsWrapper_UnmarshalJSON(b *testing.B) {
158+
messageCounts := []int{1, 10, 20, 50}
159+
160+
for _, count := range messageCounts {
161+
payload := generatePayload(count)
162+
163+
b.Run(fmt.Sprintf("messages=%d", count), func(b *testing.B) {
164+
b.ReportAllocs()
165+
b.ResetTimer()
166+
for range b.N {
167+
var wrapper ChatCompletionNewParamsWrapper
168+
_ = wrapper.UnmarshalJSON(payload)
169+
}
170+
})
171+
}
172+
}

provider/openai.go

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -87,17 +87,12 @@ func (p *OpenAI) CreateInterceptor(w http.ResponseWriter, r *http.Request, trace
8787
_, span := tracer.Start(r.Context(), "Intercept.CreateInterceptor")
8888
defer tracing.EndSpanErr(span, &outErr)
8989

90-
payload, err := io.ReadAll(r.Body)
91-
if err != nil {
92-
return nil, fmt.Errorf("read body: %w", err)
93-
}
94-
9590
var interceptor intercept.Interceptor
9691

9792
switch r.URL.Path {
9893
case routeChatCompletions:
9994
var req chatcompletions.ChatCompletionNewParamsWrapper
100-
if err := json.Unmarshal(payload, &req); err != nil {
95+
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
10196
return nil, fmt.Errorf("unmarshal request body: %w", err)
10297
}
10398

@@ -108,6 +103,10 @@ func (p *OpenAI) CreateInterceptor(w http.ResponseWriter, r *http.Request, trace
108103
}
109104

110105
case routeResponses:
106+
payload, err := io.ReadAll(r.Body)
107+
if err != nil {
108+
return nil, fmt.Errorf("read body: %w", err)
109+
}
111110
var req responses.ResponsesNewParamsWrapper
112111
if err := json.Unmarshal(payload, &req); err != nil {
113112
return nil, fmt.Errorf("unmarshal request body: %w", err)

provider/openai_test.go

Lines changed: 232 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,232 @@
1+
package provider
2+
3+
import (
4+
"bytes"
5+
"encoding/json"
6+
"fmt"
7+
"net/http"
8+
"net/http/httptest"
9+
"strings"
10+
"testing"
11+
12+
"github.com/coder/aibridge/config"
13+
"go.opentelemetry.io/otel/trace/noop"
14+
"golang.org/x/sync/errgroup"
15+
)
16+
17+
// message is a provider-agnostic conversation turn used by the benchmark
// payload generators before it is converted into a provider-specific shape.
type message struct {
	Role, Content string
}
21+
22+
type providerStrategy interface {
23+
DefaultModel() string
24+
formatMessages(messages []message) []any
25+
buildRequestBody(model string, messages []any, stream bool) map[string]any
26+
}
27+
type responsesProvider struct{}
28+
29+
func (*responsesProvider) DefaultModel() string {
30+
return "gpt-5"
31+
}
32+
33+
func (*responsesProvider) formatMessages(messages []message) []any {
34+
formatted := make([]any, 0, len(messages))
35+
for _, msg := range messages {
36+
formatted = append(formatted, map[string]any{
37+
"type": "message",
38+
"role": msg.Role,
39+
"content": msg.Content,
40+
})
41+
}
42+
return formatted
43+
}
44+
45+
func (*responsesProvider) buildRequestBody(model string, messages []any, stream bool) map[string]any {
46+
return map[string]any{
47+
"model": model,
48+
"input": messages,
49+
"stream": stream,
50+
}
51+
}
52+
53+
type chatCompletionsProvider struct{}
54+
55+
func (*chatCompletionsProvider) DefaultModel() string {
56+
return "gpt-4"
57+
}
58+
59+
func (*chatCompletionsProvider) formatMessages(messages []message) []any {
60+
formatted := make([]any, 0, len(messages))
61+
for _, msg := range messages {
62+
formatted = append(formatted, map[string]string{
63+
"role": msg.Role,
64+
"content": msg.Content,
65+
})
66+
}
67+
return formatted
68+
}
69+
70+
func (*chatCompletionsProvider) buildRequestBody(model string, messages []any, stream bool) map[string]any {
71+
return map[string]any{
72+
"model": model,
73+
"messages": messages,
74+
"stream": stream,
75+
}
76+
}
77+
78+
func generateConversation(provider providerStrategy, targetSize int, numMessages int) []any {
79+
if targetSize <= 0 {
80+
return nil
81+
}
82+
if numMessages < 1 {
83+
numMessages = 1
84+
}
85+
86+
roles := []string{"user", "assistant"}
87+
messages := make([]message, numMessages)
88+
for i := range messages {
89+
messages[i].Role = roles[i%2]
90+
}
91+
// Ensure last message is from user (required for LLM APIs).
92+
if messages[len(messages)-1].Role != "user" {
93+
messages[len(messages)-1].Role = "user"
94+
}
95+
96+
overhead := measureJSONSize(provider.formatMessages(messages))
97+
98+
bytesPerMessage := targetSize - overhead
99+
if bytesPerMessage < 0 {
100+
bytesPerMessage = 0
101+
}
102+
103+
perMessage := bytesPerMessage / len(messages)
104+
remainder := bytesPerMessage % len(messages)
105+
106+
for i := range messages {
107+
size := perMessage
108+
if i == len(messages)-1 {
109+
size += remainder
110+
}
111+
messages[i].Content = strings.Repeat("x", size)
112+
}
113+
114+
return provider.formatMessages(messages)
115+
}
116+
117+
// measureJSONSize returns the length in bytes of the JSON encoding of v.
//
// It panics if v cannot be marshalled. Returning 0 instead would silently
// skew the payload sizing that depends on this measurement; panicking matches
// how the payload generators in this file treat marshal failures.
func measureJSONSize(v any) int {
	data, err := json.Marshal(v)
	if err != nil {
		panic(err)
	}
	return len(data)
}
124+
125+
// generateChatCompletionsPayload creates a JSON payload with the specified number of messages.
126+
// Messages alternate between user and assistant roles to simulate a conversation.
127+
func generateChatCompletionsPayload(payloadSize int, messageCount int, stream bool) []byte {
128+
provider := &chatCompletionsProvider{}
129+
messages := generateConversation(provider, payloadSize, messageCount)
130+
131+
body := provider.buildRequestBody(provider.DefaultModel(), messages, stream)
132+
bodyBytes, err := json.Marshal(body)
133+
if err != nil {
134+
panic(err)
135+
}
136+
return bodyBytes
137+
}
138+
139+
// generateResponsesPayload creates a JSON payload for the responses API with the specified number of input items.
140+
// Input items alternate between user and assistant roles to simulate a conversation.
141+
func generateResponsesPayload(payloadSize int, inputCount int, stream bool) []byte {
142+
provider := &responsesProvider{}
143+
inputs := generateConversation(provider, payloadSize, inputCount)
144+
145+
body := provider.buildRequestBody(provider.DefaultModel(), inputs, stream)
146+
bodyBytes, err := json.Marshal(body)
147+
if err != nil {
148+
panic(err)
149+
}
150+
return bodyBytes
151+
}
152+
153+
func BenchmarkOpenAI_CreateInterceptor_ChatCompletions(b *testing.B) {
154+
provider := NewOpenAI(config.OpenAI{
155+
BaseURL: "https://api.openai.com/v1/",
156+
Key: "test-key",
157+
})
158+
159+
tracer := noop.NewTracerProvider().Tracer("test")
160+
messagesPerRequest := 50
161+
requestCount := 100
162+
maxConcurrentRequests := 10
163+
payloadSizes := []int{2000, 10000, 50000, 100000, 2000000}
164+
for _, payloadSize := range payloadSizes {
165+
for _, stream := range []bool{true, false} {
166+
payload := generateChatCompletionsPayload(payloadSize, messagesPerRequest, stream)
167+
name := fmt.Sprintf("stream=%t/payloadSize=%d/requests=%d", stream, payloadSize, requestCount)
168+
169+
b.Run(name, func(b *testing.B) {
170+
b.ResetTimer()
171+
for range b.N {
172+
eg := errgroup.Group{}
173+
eg.SetLimit(maxConcurrentRequests)
174+
for i := 0; i < requestCount; i++ {
175+
eg.Go(func() error {
176+
req := httptest.NewRequest(http.MethodPost, routeChatCompletions, bytes.NewReader(payload))
177+
w := httptest.NewRecorder()
178+
_, err := provider.CreateInterceptor(w, req, tracer)
179+
if err != nil {
180+
return err
181+
}
182+
return nil
183+
})
184+
}
185+
}
186+
})
187+
}
188+
}
189+
}
190+
191+
func BenchmarkOpenAI_CreateInterceptor_Responses(b *testing.B) {
192+
provider := NewOpenAI(config.OpenAI{
193+
BaseURL: "https://api.openai.com/v1/",
194+
Key: "test-key",
195+
})
196+
197+
tracer := noop.NewTracerProvider().Tracer("test")
198+
messagesPerRequest := 50
199+
requestCount := 100
200+
maxConcurrentRequests := 10
201+
// payloadSizes := []int{2000, 10000, 50000, 100000, 2000000}
202+
payloadSizes := []int{2000000}
203+
for _, payloadSize := range payloadSizes {
204+
for _, stream := range []bool{true, false} {
205+
payload := generateResponsesPayload(payloadSize, messagesPerRequest, stream)
206+
name := fmt.Sprintf("stream=%t/payloadSize=%d/requests=%d", stream, payloadSize, requestCount)
207+
208+
b.Run(name, func(b *testing.B) {
209+
b.ResetTimer()
210+
for range b.N {
211+
eg := errgroup.Group{}
212+
eg.SetLimit(maxConcurrentRequests)
213+
for i := 0; i < requestCount; i++ {
214+
eg.Go(func() error {
215+
req := httptest.NewRequest(http.MethodPost, routeResponses, bytes.NewReader(payload))
216+
w := httptest.NewRecorder()
217+
interceptor, err := provider.CreateInterceptor(w, req, tracer)
218+
if err != nil {
219+
return err
220+
}
221+
err = interceptor.ProcessRequest(w, req)
222+
if err != nil {
223+
return err
224+
}
225+
return nil
226+
})
227+
}
228+
}
229+
})
230+
}
231+
}
232+
}

0 commit comments

Comments
 (0)