Skip to content

Commit cc09565

Browse files
committed
cleaned up deltaBuffer, added drain method
1 parent cbee9ad commit cc09565

3 files changed

Lines changed: 36 additions & 37 deletions

File tree

intercept/responses/base.go

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,10 @@ type responsesInterceptionBase struct {
3838
metrics metrics.Metrics
3939
}
4040

41-
func NewResponsesService(baseURL string, key string, logger slog.Logger) responses.ResponseService {
41+
func NewResponsesService(baseURL string, apiKey string) responses.ResponseService {
4242
opts := []option.RequestOption{
43-
option.WithAPIKey(key),
4443
option.WithBaseURL(baseURL),
44+
option.WithAPIKey(apiKey),
4545
}
4646

4747
return responses.NewResponseService(opts...)
@@ -99,7 +99,7 @@ func (i *responsesInterceptionBase) validateRequest(ctx context.Context, w http.
9999
return nil
100100
}
101101

102-
func (i *responsesInterceptionBase) requestOptions(respBody *deltaBuffer) []option.RequestOption {
102+
func (i *responsesInterceptionBase) requestOptions(payloadBuff *deltaBuffer) []option.RequestOption {
103103
opts := []option.RequestOption{
104104
// Sends original payload to solve json re-encoding issues
105105
// e.g. Codex CLI produces requests without ID set in reasoning items: https://platform.openai.com/docs/api-reference/responses/create#responses_create-input-input_item_list-item-reasoning-id
@@ -108,15 +108,15 @@ func (i *responsesInterceptionBase) requestOptions(respBody *deltaBuffer) []opti
108108
option.WithRequestBody("application/json", i.reqPayload),
109109

110110
// Reads response body into given buffer
111-
option.WithMiddleware(teeMiddleware(respBody)),
111+
option.WithMiddleware(teeMiddleware(payloadBuff)),
112112
}
113113
if !i.req.Stream {
114114
opts = append(opts, option.WithRequestTimeout(time.Second*60)) // TODO: configurable timeout
115115
}
116116
return opts
117117
}
118118

119-
// handleUpstreamError checks error if it is an openAI error and if ProcessRequest should exit early.
119+
// handleUpstreamError checks if error is an openAI error and if caller should exit early.
120120
// If it is openAI responses error -> sets proper http response code and returns false to not exit early
121121
// response body will be sent using the same method as non-error response, using teeMiddleware -> deltaBuffer.
122122
// If it is a connection error or unknown error -> returns given error + true to indicate early exit
@@ -142,22 +142,24 @@ func (i *responsesInterceptionBase) handleUpstreamError(ctx context.Context, ups
142142
}
143143

144144
// teeMiddleware copies response body to given buffer leaving original response intact/consumable for openAI SDK
145-
func teeMiddleware(respBody io.Writer) func(req *http.Request, next option.MiddlewareNext) (*http.Response, error) {
145+
func teeMiddleware(payloadBuff *deltaBuffer) func(req *http.Request, next option.MiddlewareNext) (*http.Response, error) {
146146
return func(req *http.Request, next option.MiddlewareNext) (*http.Response, error) {
147147
resp, err := next(req)
148148
if err != nil || resp == nil || resp.Body == nil {
149149
return resp, err
150150
}
151151

152-
resp.Body = io.NopCloser(io.TeeReader(resp.Body, respBody))
152+
payloadBuff.closer = resp.Body
153+
resp.Body = io.NopCloser(io.TeeReader(resp.Body, payloadBuff))
153154
return resp, nil
154155
}
155156
}
156157

157-
// deltaBuffer stores everything written to it and lets you read only the new bytes since last drain.
158+
// deltaBuffer is a thread-safe byte buffer that supports reading incremental data (added after last read)
158159
type deltaBuffer struct {
159-
mu sync.Mutex
160-
buf bytes.Buffer
160+
mu sync.Mutex
161+
buf bytes.Buffer
162+
closer io.ReadCloser
161163
}
162164

163165
func (d *deltaBuffer) Write(p []byte) (int, error) {
@@ -166,11 +168,16 @@ func (d *deltaBuffer) Write(p []byte) (int, error) {
166168
return d.buf.Write(p)
167169
}
168170

169-
func (d *deltaBuffer) Read(p []byte) (int, error) {
171+
// Reads all from the original response body so it is properly copied by TeeReader to buffer
172+
func (d *deltaBuffer) drain() error {
173+
if d.closer == nil {
174+
return nil
175+
}
176+
170177
d.mu.Lock()
171178
defer d.mu.Unlock()
172-
c, err := d.buf.Read(p)
173-
return c, err
179+
_, err := io.ReadAll(d.closer)
180+
return err
174181
}
175182

176183
// readDelta returns only the bytes appended since the last readDelta call.

intercept/responses/blocking.go

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ package responses
33
import (
44
"errors"
55
"fmt"
6-
"io"
76
"net/http"
87

98
"cdr.dev/slog/v3"
@@ -45,25 +44,22 @@ func (i *BlockingResponsesInterceptor) ProcessRequest(w http.ResponseWriter, r *
4544
return err
4645
}
4746

48-
var respBody deltaBuffer
49-
srv := NewResponsesService(i.baseURL, i.key, i.logger)
47+
var respPayload deltaBuffer
48+
srv := NewResponsesService(i.baseURL, i.key)
5049

51-
opts := i.requestOptions(&respBody)
52-
_, err := srv.New(ctx, i.req.ResponseNewParams, opts...)
50+
opts := i.requestOptions(&respPayload)
51+
_, upstreamErr := srv.New(ctx, i.req.ResponseNewParams, opts...)
5352

54-
upstreamErr, earlyExit := i.handleUpstreamError(ctx, err, w)
53+
upstreamErr, earlyExit := i.handleUpstreamError(ctx, upstreamErr, w)
5554
if earlyExit {
5655
return upstreamErr
5756
}
5857

59-
w.Header().Set("Content-Type", "application/json")
60-
out, err := io.ReadAll(&respBody)
61-
if err != nil {
62-
i.logger.Warn(ctx, "failed to read upstream response", slog.Error(err))
63-
return errors.Join(fmt.Errorf("failed to read upstream response: %w", err), upstreamErr)
58+
if err := respPayload.drain(); err != nil {
59+
i.logger.Warn(ctx, "failed to drain original response body", slog.Error(err))
6460
}
65-
66-
_, err = w.Write(out)
61+
w.Header().Set("Content-Type", "application/json")
62+
_, err := w.Write(respPayload.readDelta())
6763
if err != nil {
6864
i.logger.Warn(ctx, "failed to write response", slog.Error(err))
6965
return errors.Join(fmt.Errorf("failed to write response: %w", err), upstreamErr)

intercept/responses/streaming.go

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ import (
44
"context"
55
"errors"
66
"fmt"
7-
"io"
87
"net/http"
98
"time"
109

@@ -63,15 +62,15 @@ func (i *StreamingResponsesInterceptor) ProcessRequest(w http.ResponseWriter, r
6362
_ = events.Shutdown(streamCtx) // Catch-all in case it doesn't get shutdown after stream completes.
6463
}()
6564

66-
var respBody deltaBuffer
65+
var respPayload deltaBuffer
6766

68-
srv := NewResponsesService(i.baseURL, i.key, i.logger)
69-
opts := i.requestOptions(&respBody)
67+
srv := NewResponsesService(i.baseURL, i.key)
68+
opts := i.requestOptions(&respPayload)
7069
stream := srv.NewStreaming(ctx, i.req.ResponseNewParams, opts...)
7170
defer stream.Close()
7271

7372
for stream.Next() {
74-
if err := events.Send(ctx, respBody.readDelta()); err != nil {
73+
if err := events.Send(ctx, respPayload.readDelta()); err != nil {
7574
i.logger.Warn(ctx, "failed to relay chunk", slog.Error(err))
7675
err = fmt.Errorf("relay chunk: %w", err)
7776
stream.Close()
@@ -84,13 +83,10 @@ func (i *StreamingResponsesInterceptor) ProcessRequest(w http.ResponseWriter, r
8483
return upstreamErr
8584
}
8685

87-
// Sometimes stream.Next() returns before respBody buffer is filled
88-
lastRead, err := io.ReadAll(&respBody)
89-
if err != nil {
90-
i.logger.Warn(ctx, "failed to read upstream response", slog.Error(err))
91-
return fmt.Errorf("failed to read upstream response: %w", err)
86+
if err := respPayload.drain(); err != nil {
87+
i.logger.Warn(ctx, "failed to drain original response body", slog.Error(err))
9288
}
93-
events.Send(ctx, lastRead)
89+
events.Send(ctx, respPayload.readDelta())
9490

9591
// Give the events stream 30 seconds (TODO: configurable) to gracefully shutdown.
9692
shutdownCtx, shutdownCancel := context.WithTimeout(ctx, time.Second*30)

0 commit comments

Comments
 (0)