Skip to content

Commit 55f3767

Browse files
author
Codex
committed
fix(openai): keep image edit multipart body reusable
1 parent 84051a5 commit 55f3767

7 files changed

Lines changed: 150 additions & 22 deletions

File tree

dto/openai_image_test.go

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,7 @@ import (
66
"github.com/stretchr/testify/require"
77
)
88

9+
// TestImageRequestStreamJSON verifies that image requests preserve stream=true.
910
func TestImageRequestStreamJSON(t *testing.T) {
1011
var req ImageRequest
1112
require.NoError(t, req.UnmarshalJSON([]byte(`{"model":"gpt-image-1","prompt":"draw a cat","stream":true}`)))
Lines changed: 72 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,72 @@
1+
package openai
2+
3+
import (
4+
"bytes"
5+
"io"
6+
"mime/multipart"
7+
"net/http"
8+
"net/http/httptest"
9+
"testing"
10+
11+
"github.com/QuantumNous/new-api/dto"
12+
relaycommon "github.com/QuantumNous/new-api/relay/common"
13+
relayconstant "github.com/QuantumNous/new-api/relay/constant"
14+
"github.com/gin-gonic/gin"
15+
"github.com/stretchr/testify/require"
16+
)
17+
18+
// TestConvertImageEditRequestKeepsValidMultipartStreamFields verifies multipart replay.
19+
func TestConvertImageEditRequestKeepsValidMultipartStreamFields(t *testing.T) {
20+
gin.SetMode(gin.TestMode)
21+
22+
var body bytes.Buffer
23+
writer := multipart.NewWriter(&body)
24+
require.NoError(t, writer.WriteField("model", "gpt-image-1"))
25+
require.NoError(t, writer.WriteField("prompt", "edit this image"))
26+
require.NoError(t, writer.WriteField("stream", "true"))
27+
require.NoError(t, writer.WriteField("partial_images", "3"))
28+
part, err := writer.CreateFormFile("image", "input.png")
29+
require.NoError(t, err)
30+
_, err = part.Write([]byte("fake image"))
31+
require.NoError(t, err)
32+
require.NoError(t, writer.Close())
33+
34+
recorder := httptest.NewRecorder()
35+
c, _ := gin.CreateTestContext(recorder)
36+
c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/edits", &body)
37+
c.Request.Header.Set("Content-Type", writer.FormDataContentType())
38+
require.NoError(t, c.Request.ParseMultipartForm(32<<20))
39+
40+
info := &relaycommon.RelayInfo{
41+
RelayMode: relayconstant.RelayModeImagesEdits,
42+
}
43+
request := dto.ImageRequest{
44+
Model: "gpt-image-1",
45+
Prompt: "edit this image",
46+
Stream: true,
47+
}
48+
49+
converted, err := (&Adaptor{}).ConvertImageRequest(c, info, request)
50+
require.NoError(t, err)
51+
52+
convertedBody, ok := converted.(*bytes.Buffer)
53+
require.True(t, ok)
54+
55+
contentType := c.Request.Header.Get("Content-Type")
56+
replayedRequest := httptest.NewRequest(http.MethodPost, "/v1/images/edits", bytes.NewReader(convertedBody.Bytes()))
57+
replayedRequest.Header.Set("Content-Type", contentType)
58+
require.NoError(t, replayedRequest.ParseMultipartForm(32<<20))
59+
60+
require.Equal(t, "gpt-image-1", replayedRequest.PostForm.Get("model"))
61+
require.Equal(t, "edit this image", replayedRequest.PostForm.Get("prompt"))
62+
require.Equal(t, "true", replayedRequest.PostForm.Get("stream"))
63+
require.Equal(t, "3", replayedRequest.PostForm.Get("partial_images"))
64+
require.Len(t, replayedRequest.MultipartForm.File["image"], 1)
65+
66+
file, err := replayedRequest.MultipartForm.File["image"][0].Open()
67+
require.NoError(t, err)
68+
defer file.Close()
69+
fileBytes, err := io.ReadAll(file)
70+
require.NoError(t, err)
71+
require.Equal(t, []byte("fake image"), fileBytes)
72+
}

relay/channel/openai/image_stream_test.go

Lines changed: 26 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -8,11 +8,13 @@ import (
88
"testing"
99

1010
"github.com/QuantumNous/new-api/constant"
11+
"github.com/QuantumNous/new-api/dto"
1112
relaycommon "github.com/QuantumNous/new-api/relay/common"
1213
"github.com/gin-gonic/gin"
1314
"github.com/stretchr/testify/require"
1415
)
1516

17+
// TestOpenaiImageStreamHandlerForwardsSSEAndUsage verifies image SSE passthrough.
1618
func TestOpenaiImageStreamHandlerForwardsSSEAndUsage(t *testing.T) {
1719
gin.SetMode(gin.TestMode)
1820

@@ -24,7 +26,7 @@ func TestOpenaiImageStreamHandlerForwardsSSEAndUsage(t *testing.T) {
2426
`event: image_generation.partial_image`,
2527
`data: {"type":"image_generation.partial_image","b64_json":"partial"}`,
2628
``,
27-
`data: {"usage":{"prompt_tokens":3,"completion_tokens":4,"total_tokens":7}}`,
29+
`data: {"usage":{"input_tokens":3,"output_tokens":4,"total_tokens":7,"input_tokens_details":{"image_tokens":2,"text_tokens":1}}}`,
2830
``,
2931
`data: [DONE]`,
3032
``,
@@ -49,8 +51,30 @@ func TestOpenaiImageStreamHandlerForwardsSSEAndUsage(t *testing.T) {
4951
require.Equal(t, 3, usage.PromptTokens)
5052
require.Equal(t, 4, usage.CompletionTokens)
5153
require.Equal(t, 7, usage.TotalTokens)
54+
require.Equal(t, 2, usage.PromptTokensDetails.ImageTokens)
55+
require.Equal(t, 1, usage.PromptTokensDetails.TextTokens)
5256
require.Contains(t, recorder.Body.String(), `event: image_generation.partial_image`)
5357
require.Contains(t, recorder.Body.String(), `data: {"type":"image_generation.partial_image","b64_json":"partial"}`)
54-
require.Contains(t, recorder.Body.String(), `data: {"usage":{"prompt_tokens":3,"completion_tokens":4,"total_tokens":7}}`)
58+
require.Contains(t, recorder.Body.String(), `data: {"usage":{"input_tokens":3,"output_tokens":4,"total_tokens":7,"input_tokens_details":{"image_tokens":2,"text_tokens":1}}}`)
5559
require.Equal(t, "text/event-stream", recorder.Header().Get("Content-Type"))
5660
}
61+
62+
// TestNormalizeOpenAIUsageMapsImageTokenDetailsWithoutDoubleCounting verifies ImageRatio inputs.
63+
func TestNormalizeOpenAIUsageMapsImageTokenDetailsWithoutDoubleCounting(t *testing.T) {
64+
usage := &dto.Usage{
65+
InputTokens: 5000,
66+
OutputTokens: 4000,
67+
InputTokensDetails: &dto.InputTokenDetails{
68+
ImageTokens: 1000,
69+
TextTokens: 4000,
70+
},
71+
}
72+
73+
normalizeOpenAIUsage(usage)
74+
75+
require.Equal(t, 5000, usage.PromptTokens)
76+
require.Equal(t, 4000, usage.CompletionTokens)
77+
require.Equal(t, 9000, usage.TotalTokens)
78+
require.Equal(t, 1000, usage.PromptTokensDetails.ImageTokens)
79+
require.Equal(t, 4000, usage.PromptTokensDetails.TextTokens)
80+
}

relay/channel/openai/relay-openai.go

Lines changed: 29 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -575,24 +575,34 @@ func OpenaiHandlerWithUsage(c *gin.Context, info *relaycommon.RelayInfo, resp *h
575575
// 写入新的 response body
576576
service.IOCopyBytesGracefully(c, resp, responseBody)
577577

578-
// Once we've written to the client, we should not return errors anymore
579-
// because the upstream has already consumed resources and returned content
580-
// We should still perform billing even if parsing fails
581-
// format
582-
if usageResp.InputTokens > 0 {
583-
usageResp.PromptTokens += usageResp.InputTokens
578+
normalizeOpenAIUsage(&usageResp.Usage)
579+
applyUsagePostProcessing(info, &usageResp.Usage, responseBody)
580+
return &usageResp.Usage, nil
581+
}
582+
583+
// normalizeOpenAIUsage maps OpenAI usage aliases into NewAPI billing fields.
584+
func normalizeOpenAIUsage(usage *dto.Usage) {
585+
if usage == nil {
586+
return
584587
}
585-
if usageResp.OutputTokens > 0 {
586-
usageResp.CompletionTokens += usageResp.OutputTokens
588+
if usage.InputTokens != 0 {
589+
usage.PromptTokens = usage.InputTokens
587590
}
588-
if usageResp.InputTokensDetails != nil {
589-
usageResp.PromptTokensDetails.ImageTokens += usageResp.InputTokensDetails.ImageTokens
590-
usageResp.PromptTokensDetails.TextTokens += usageResp.InputTokensDetails.TextTokens
591+
if usage.OutputTokens != 0 {
592+
usage.CompletionTokens = usage.OutputTokens
593+
}
594+
if usage.InputTokensDetails != nil {
595+
usage.PromptTokensDetails.CachedTokens = usage.InputTokensDetails.CachedTokens
596+
usage.PromptTokensDetails.ImageTokens = usage.InputTokensDetails.ImageTokens
597+
usage.PromptTokensDetails.TextTokens = usage.InputTokensDetails.TextTokens
598+
usage.PromptTokensDetails.AudioTokens = usage.InputTokensDetails.AudioTokens
599+
}
600+
if usage.TotalTokens == 0 {
601+
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
591602
}
592-
applyUsagePostProcessing(info, &usageResp.Usage, responseBody)
593-
return &usageResp.Usage, nil
594603
}
595604

605+
// OpenaiImageStreamHandler forwards OpenAI Images SSE events and extracts usage.
596606
func OpenaiImageStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
597607
if resp == nil || resp.Body == nil {
598608
logger.LogError(c, "invalid image stream response")
@@ -609,7 +619,7 @@ func OpenaiImageStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp
609619
}
610620

611621
scanner := bufio.NewScanner(resp.Body)
612-
scanner.Buffer(make([]byte, helper.InitialScannerBufferSize), helper.DefaultMaxScannerBufferSize)
622+
scanner.Buffer(make([]byte, helper.InitialScannerBufferSize), helper.GetScannerBufferSize())
613623
for scanner.Scan() {
614624
line := scanner.Text()
615625
if strings.HasPrefix(line, "data:") {
@@ -621,8 +631,11 @@ func OpenaiImageStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp
621631
info.ReceivedResponseCount++
622632
lastStreamData = common.StringToByteSlice(data)
623633
var usageResp dto.SimpleResponse
624-
if err := common.Unmarshal(lastStreamData, &usageResp); err == nil && service.ValidUsage(&usageResp.Usage) {
625-
usage = &usageResp.Usage
634+
if err := common.Unmarshal(lastStreamData, &usageResp); err == nil {
635+
normalizeOpenAIUsage(&usageResp.Usage)
636+
if service.ValidUsage(&usageResp.Usage) {
637+
usage = &usageResp.Usage
638+
}
626639
}
627640
}
628641
}

relay/helper/openai_image_request_test.go

Lines changed: 14 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2,16 +2,20 @@ package helper
22

33
import (
44
"bytes"
5+
"io"
56
"mime/multipart"
67
"net/http"
78
"net/http/httptest"
9+
"net/url"
810
"testing"
911

12+
"github.com/QuantumNous/new-api/common"
1013
relayconstant "github.com/QuantumNous/new-api/relay/constant"
1114
"github.com/gin-gonic/gin"
1215
"github.com/stretchr/testify/require"
1316
)
1417

18+
// TestGetAndValidOpenAIImageRequestMultipartStream verifies reusable image edit parsing.
1519
func TestGetAndValidOpenAIImageRequestMultipartStream(t *testing.T) {
1620
gin.SetMode(gin.TestMode)
1721

@@ -26,6 +30,7 @@ func TestGetAndValidOpenAIImageRequestMultipartStream(t *testing.T) {
2630
_, err = part.Write([]byte("fake image"))
2731
require.NoError(t, err)
2832
require.NoError(t, writer.Close())
33+
originalBody := body.String()
2934

3035
recorder := httptest.NewRecorder()
3136
c, _ := gin.CreateTestContext(recorder)
@@ -36,4 +41,13 @@ func TestGetAndValidOpenAIImageRequestMultipartStream(t *testing.T) {
3641
require.NoError(t, err)
3742
require.True(t, req.Stream)
3843
require.True(t, req.IsStream(c))
44+
45+
bodyAfterValidation, err := io.ReadAll(c.Request.Body)
46+
require.NoError(t, err)
47+
require.Equal(t, originalBody, string(bodyAfterValidation))
48+
49+
form, err := common.ParseMultipartFormReusable(c)
50+
require.NoError(t, err)
51+
require.Equal(t, "true", url.Values(form.Value).Get("stream"))
52+
require.Len(t, form.File["image"], 1)
3953
}

relay/helper/stream_scanner.go

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -27,7 +27,8 @@ const (
2727
DefaultPingInterval = 10 * time.Second
2828
)
2929

30-
func getScannerBufferSize() int {
30+
// GetScannerBufferSize returns the configured maximum SSE scanner token size.
31+
func GetScannerBufferSize() int {
3132
if constant.StreamScannerMaxBufferMB > 0 {
3233
return constant.StreamScannerMaxBufferMB << 20
3334
}
@@ -108,7 +109,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
108109
close(stopChan)
109110
}()
110111

111-
scanner.Buffer(make([]byte, InitialScannerBufferSize), getScannerBufferSize())
112+
scanner.Buffer(make([]byte, InitialScannerBufferSize), GetScannerBufferSize())
112113
scanner.Split(bufio.ScanLines)
113114
SetEventStreamHeaders(c)
114115

relay/helper/valid_request.go

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,7 @@ import (
55
"errors"
66
"fmt"
77
"math"
8+
"net/url"
89
"strconv"
910
"strings"
1011

@@ -146,11 +147,13 @@ func GetAndValidOpenAIImageRequest(c *gin.Context, relayMode int) (*dto.ImageReq
146147
switch relayMode {
147148
case relayconstant.RelayModeImagesEdits:
148149
if strings.Contains(c.Request.Header.Get("Content-Type"), "multipart/form-data") {
149-
_, err := c.MultipartForm()
150+
form, err := common.ParseMultipartFormReusable(c)
150151
if err != nil {
151152
return nil, fmt.Errorf("failed to parse image edit form request: %w", err)
152153
}
153-
formData := c.Request.PostForm
154+
formData := url.Values(form.Value)
155+
c.Request.MultipartForm = form
156+
c.Request.PostForm = formData
154157
imageRequest.Prompt = formData.Get("prompt")
155158
imageRequest.Model = formData.Get("model")
156159
imageRequest.N = common.GetPointer(uint(common.String2Int(formData.Get("n"))))

0 commit comments

Comments
 (0)