Skip to content

Commit 2d27fba

Browse files
author
Codex
committed
fix(openai): keep image edit multipart body reusable
1 parent 84051a5 commit 2d27fba

6 files changed

Lines changed: 117 additions & 22 deletions

File tree

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
package openai
2+
3+
import (
4+
"bytes"
5+
"io"
6+
"mime/multipart"
7+
"net/http"
8+
"net/http/httptest"
9+
"testing"
10+
11+
"github.com/QuantumNous/new-api/dto"
12+
relaycommon "github.com/QuantumNous/new-api/relay/common"
13+
relayconstant "github.com/QuantumNous/new-api/relay/constant"
14+
"github.com/gin-gonic/gin"
15+
"github.com/stretchr/testify/require"
16+
)
17+
18+
func TestConvertImageEditRequestKeepsValidMultipartStreamFields(t *testing.T) {
19+
gin.SetMode(gin.TestMode)
20+
21+
var body bytes.Buffer
22+
writer := multipart.NewWriter(&body)
23+
require.NoError(t, writer.WriteField("model", "gpt-image-1"))
24+
require.NoError(t, writer.WriteField("prompt", "edit this image"))
25+
require.NoError(t, writer.WriteField("stream", "true"))
26+
require.NoError(t, writer.WriteField("partial_images", "3"))
27+
part, err := writer.CreateFormFile("image", "input.png")
28+
require.NoError(t, err)
29+
_, err = part.Write([]byte("fake image"))
30+
require.NoError(t, err)
31+
require.NoError(t, writer.Close())
32+
33+
recorder := httptest.NewRecorder()
34+
c, _ := gin.CreateTestContext(recorder)
35+
c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/edits", &body)
36+
c.Request.Header.Set("Content-Type", writer.FormDataContentType())
37+
require.NoError(t, c.Request.ParseMultipartForm(32<<20))
38+
39+
info := &relaycommon.RelayInfo{
40+
RelayMode: relayconstant.RelayModeImagesEdits,
41+
}
42+
request := dto.ImageRequest{
43+
Model: "gpt-image-1",
44+
Prompt: "edit this image",
45+
Stream: true,
46+
}
47+
48+
converted, err := (&Adaptor{}).ConvertImageRequest(c, info, request)
49+
require.NoError(t, err)
50+
51+
convertedBody, ok := converted.(*bytes.Buffer)
52+
require.True(t, ok)
53+
54+
contentType := c.Request.Header.Get("Content-Type")
55+
replayedRequest := httptest.NewRequest(http.MethodPost, "/v1/images/edits", bytes.NewReader(convertedBody.Bytes()))
56+
replayedRequest.Header.Set("Content-Type", contentType)
57+
require.NoError(t, replayedRequest.ParseMultipartForm(32<<20))
58+
59+
require.Equal(t, "gpt-image-1", replayedRequest.PostForm.Get("model"))
60+
require.Equal(t, "edit this image", replayedRequest.PostForm.Get("prompt"))
61+
require.Equal(t, "true", replayedRequest.PostForm.Get("stream"))
62+
require.Equal(t, "3", replayedRequest.PostForm.Get("partial_images"))
63+
require.Len(t, replayedRequest.MultipartForm.File["image"], 1)
64+
65+
file, err := replayedRequest.MultipartForm.File["image"][0].Open()
66+
require.NoError(t, err)
67+
defer file.Close()
68+
fileBytes, err := io.ReadAll(file)
69+
require.NoError(t, err)
70+
require.Equal(t, []byte("fake image"), fileBytes)
71+
}

relay/channel/openai/image_stream_test.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ func TestOpenaiImageStreamHandlerForwardsSSEAndUsage(t *testing.T) {
2424
`event: image_generation.partial_image`,
2525
`data: {"type":"image_generation.partial_image","b64_json":"partial"}`,
2626
``,
27-
`data: {"usage":{"prompt_tokens":3,"completion_tokens":4,"total_tokens":7}}`,
27+
`data: {"usage":{"input_tokens":3,"output_tokens":4,"total_tokens":7,"input_tokens_details":{"image_tokens":2,"text_tokens":1}}}`,
2828
``,
2929
`data: [DONE]`,
3030
``,
@@ -49,8 +49,10 @@ func TestOpenaiImageStreamHandlerForwardsSSEAndUsage(t *testing.T) {
4949
require.Equal(t, 3, usage.PromptTokens)
5050
require.Equal(t, 4, usage.CompletionTokens)
5151
require.Equal(t, 7, usage.TotalTokens)
52+
require.Equal(t, 2, usage.PromptTokensDetails.ImageTokens)
53+
require.Equal(t, 1, usage.PromptTokensDetails.TextTokens)
5254
require.Contains(t, recorder.Body.String(), `event: image_generation.partial_image`)
5355
require.Contains(t, recorder.Body.String(), `data: {"type":"image_generation.partial_image","b64_json":"partial"}`)
54-
require.Contains(t, recorder.Body.String(), `data: {"usage":{"prompt_tokens":3,"completion_tokens":4,"total_tokens":7}}`)
56+
require.Contains(t, recorder.Body.String(), `data: {"usage":{"input_tokens":3,"output_tokens":4,"total_tokens":7,"input_tokens_details":{"image_tokens":2,"text_tokens":1}}}`)
5557
require.Equal(t, "text/event-stream", recorder.Header().Get("Content-Type"))
5658
}

relay/channel/openai/relay-openai.go

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -575,22 +575,25 @@ func OpenaiHandlerWithUsage(c *gin.Context, info *relaycommon.RelayInfo, resp *h
575575
// 写入新的 response body
576576
service.IOCopyBytesGracefully(c, resp, responseBody)
577577

578-
// Once we've written to the client, we should not return errors anymore
579-
// because the upstream has already consumed resources and returned content
580-
// We should still perform billing even if parsing fails
581-
// format
582-
if usageResp.InputTokens > 0 {
583-
usageResp.PromptTokens += usageResp.InputTokens
578+
normalizeOpenAIUsage(&usageResp.Usage)
579+
applyUsagePostProcessing(info, &usageResp.Usage, responseBody)
580+
return &usageResp.Usage, nil
581+
}
582+
583+
func normalizeOpenAIUsage(usage *dto.Usage) {
584+
if usage == nil {
585+
return
584586
}
585-
if usageResp.OutputTokens > 0 {
586-
usageResp.CompletionTokens += usageResp.OutputTokens
587+
if usage.InputTokens > 0 {
588+
usage.PromptTokens += usage.InputTokens
587589
}
588-
if usageResp.InputTokensDetails != nil {
589-
usageResp.PromptTokensDetails.ImageTokens += usageResp.InputTokensDetails.ImageTokens
590-
usageResp.PromptTokensDetails.TextTokens += usageResp.InputTokensDetails.TextTokens
590+
if usage.OutputTokens > 0 {
591+
usage.CompletionTokens += usage.OutputTokens
592+
}
593+
if usage.InputTokensDetails != nil {
594+
usage.PromptTokensDetails.ImageTokens += usage.InputTokensDetails.ImageTokens
595+
usage.PromptTokensDetails.TextTokens += usage.InputTokensDetails.TextTokens
591596
}
592-
applyUsagePostProcessing(info, &usageResp.Usage, responseBody)
593-
return &usageResp.Usage, nil
594597
}
595598

596599
func OpenaiImageStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
@@ -609,7 +612,7 @@ func OpenaiImageStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp
609612
}
610613

611614
scanner := bufio.NewScanner(resp.Body)
612-
scanner.Buffer(make([]byte, helper.InitialScannerBufferSize), helper.DefaultMaxScannerBufferSize)
615+
scanner.Buffer(make([]byte, helper.InitialScannerBufferSize), helper.GetScannerBufferSize())
613616
for scanner.Scan() {
614617
line := scanner.Text()
615618
if strings.HasPrefix(line, "data:") {
@@ -621,8 +624,11 @@ func OpenaiImageStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp
621624
info.ReceivedResponseCount++
622625
lastStreamData = common.StringToByteSlice(data)
623626
var usageResp dto.SimpleResponse
624-
if err := common.Unmarshal(lastStreamData, &usageResp); err == nil && service.ValidUsage(&usageResp.Usage) {
625-
usage = &usageResp.Usage
627+
if err := common.Unmarshal(lastStreamData, &usageResp); err == nil {
628+
normalizeOpenAIUsage(&usageResp.Usage)
629+
if service.ValidUsage(&usageResp.Usage) {
630+
usage = &usageResp.Usage
631+
}
626632
}
627633
}
628634
}

relay/helper/openai_image_request_test.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,14 @@ package helper
22

33
import (
44
"bytes"
5+
"io"
56
"mime/multipart"
67
"net/http"
78
"net/http/httptest"
9+
"net/url"
810
"testing"
911

12+
"github.com/QuantumNous/new-api/common"
1013
relayconstant "github.com/QuantumNous/new-api/relay/constant"
1114
"github.com/gin-gonic/gin"
1215
"github.com/stretchr/testify/require"
@@ -26,6 +29,7 @@ func TestGetAndValidOpenAIImageRequestMultipartStream(t *testing.T) {
2629
_, err = part.Write([]byte("fake image"))
2730
require.NoError(t, err)
2831
require.NoError(t, writer.Close())
32+
originalBody := body.String()
2933

3034
recorder := httptest.NewRecorder()
3135
c, _ := gin.CreateTestContext(recorder)
@@ -36,4 +40,13 @@ func TestGetAndValidOpenAIImageRequestMultipartStream(t *testing.T) {
3640
require.NoError(t, err)
3741
require.True(t, req.Stream)
3842
require.True(t, req.IsStream(c))
43+
44+
bodyAfterValidation, err := io.ReadAll(c.Request.Body)
45+
require.NoError(t, err)
46+
require.Equal(t, originalBody, string(bodyAfterValidation))
47+
48+
form, err := common.ParseMultipartFormReusable(c)
49+
require.NoError(t, err)
50+
require.Equal(t, "true", url.Values(form.Value).Get("stream"))
51+
require.Len(t, form.File["image"], 1)
3952
}

relay/helper/stream_scanner.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ const (
2727
DefaultPingInterval = 10 * time.Second
2828
)
2929

30-
func getScannerBufferSize() int {
30+
func GetScannerBufferSize() int {
3131
if constant.StreamScannerMaxBufferMB > 0 {
3232
return constant.StreamScannerMaxBufferMB << 20
3333
}
@@ -108,7 +108,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
108108
close(stopChan)
109109
}()
110110

111-
scanner.Buffer(make([]byte, InitialScannerBufferSize), getScannerBufferSize())
111+
scanner.Buffer(make([]byte, InitialScannerBufferSize), GetScannerBufferSize())
112112
scanner.Split(bufio.ScanLines)
113113
SetEventStreamHeaders(c)
114114

relay/helper/valid_request.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"errors"
66
"fmt"
77
"math"
8+
"net/url"
89
"strconv"
910
"strings"
1011

@@ -146,11 +147,13 @@ func GetAndValidOpenAIImageRequest(c *gin.Context, relayMode int) (*dto.ImageReq
146147
switch relayMode {
147148
case relayconstant.RelayModeImagesEdits:
148149
if strings.Contains(c.Request.Header.Get("Content-Type"), "multipart/form-data") {
149-
_, err := c.MultipartForm()
150+
form, err := common.ParseMultipartFormReusable(c)
150151
if err != nil {
151152
return nil, fmt.Errorf("failed to parse image edit form request: %w", err)
152153
}
153-
formData := c.Request.PostForm
154+
formData := url.Values(form.Value)
155+
c.Request.MultipartForm = form
156+
c.Request.PostForm = formData
154157
imageRequest.Prompt = formData.Get("prompt")
155158
imageRequest.Model = formData.Get("model")
156159
imageRequest.N = common.GetPointer(uint(common.String2Int(formData.Get("n"))))

0 commit comments

Comments
 (0)