Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions dto/openai_image.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,11 @@ type ImageRequest struct {
OutputFormat json.RawMessage `json:"output_format,omitempty"`
OutputCompression json.RawMessage `json:"output_compression,omitempty"`
PartialImages json.RawMessage `json:"partial_images,omitempty"`
// Stream bool `json:"stream,omitempty"`
Images json.RawMessage `json:"images,omitempty"`
Mask json.RawMessage `json:"mask,omitempty"`
InputFidelity json.RawMessage `json:"input_fidelity,omitempty"`
Watermark *bool `json:"watermark,omitempty"`
Stream bool `json:"stream,omitempty"`
Images json.RawMessage `json:"images,omitempty"`
Mask json.RawMessage `json:"mask,omitempty"`
InputFidelity json.RawMessage `json:"input_fidelity,omitempty"`
Watermark *bool `json:"watermark,omitempty"`
// zhipu 4v
WatermarkEnabled json.RawMessage `json:"watermark_enabled,omitempty"`
UserId json.RawMessage `json:"user_id,omitempty"`
Expand Down Expand Up @@ -163,7 +163,7 @@ func (i *ImageRequest) GetTokenCountMeta() *types.TokenCountMeta {
}

func (i *ImageRequest) IsStream(c *gin.Context) bool {
return false
return i.Stream
}

func (i *ImageRequest) SetModelName(modelName string) {
Expand Down
16 changes: 16 additions & 0 deletions dto/openai_image_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package dto

import (
"testing"

"github.com/stretchr/testify/require"
)

// TestImageRequestStreamJSON verifies that image requests preserve stream=true.
func TestImageRequestStreamJSON(t *testing.T) {
var req ImageRequest
require.NoError(t, req.UnmarshalJSON([]byte(`{"model":"gpt-image-1","prompt":"draw a cat","stream":true}`)))

require.True(t, req.Stream)
require.True(t, req.IsStream(nil))
}
16 changes: 12 additions & 4 deletions relay/channel/openai/adaptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"mime/multipart"
"net/http"
"net/textproto"
"net/url"
"path/filepath"
"strings"

Expand Down Expand Up @@ -437,10 +438,13 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
// 使用已解析的 multipart 表单,避免重复解析
mf := c.Request.MultipartForm
if mf == nil {
if _, err := c.MultipartForm(); err != nil {
return nil, errors.New("failed to parse multipart form")
form, err := common.ParseMultipartFormReusable(c)
if err != nil {
return nil, fmt.Errorf("failed to parse multipart form: %w", err)
}
mf = c.Request.MultipartForm
c.Request.MultipartForm = form
c.Request.PostForm = url.Values(form.Value)
mf = form
}

// 写入所有非文件字段
Expand Down Expand Up @@ -623,7 +627,11 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
case relayconstant.RelayModeAudioTranscription:
err, usage = OpenaiSTTHandler(c, resp, info, a.ResponseFormat)
case relayconstant.RelayModeImagesGenerations, relayconstant.RelayModeImagesEdits:
usage, err = OpenaiHandlerWithUsage(c, info, resp)
if info.IsStream {
usage, err = OpenaiImageStreamHandler(c, info, resp)
} else {
usage, err = OpenaiHandlerWithUsage(c, info, resp)
}
case relayconstant.RelayModeRerank:
usage, err = common_handler.RerankHandler(c, info, resp)
case relayconstant.RelayModeResponses:
Expand Down
121 changes: 121 additions & 0 deletions relay/channel/openai/image_edit_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
package openai

import (
"bytes"
"io"
"mime/multipart"
"net/http"
"net/http/httptest"
"testing"

"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/dto"
relaycommon "github.com/QuantumNous/new-api/relay/common"
relayconstant "github.com/QuantumNous/new-api/relay/constant"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)

// TestConvertImageEditRequestKeepsValidMultipartStreamFields verifies multipart replay.
func TestConvertImageEditRequestKeepsValidMultipartStreamFields(t *testing.T) {
gin.SetMode(gin.TestMode)

var body bytes.Buffer
writer := multipart.NewWriter(&body)
require.NoError(t, writer.WriteField("model", "gpt-image-1"))
require.NoError(t, writer.WriteField("prompt", "edit this image"))
require.NoError(t, writer.WriteField("stream", "true"))
require.NoError(t, writer.WriteField("partial_images", "3"))
part, err := writer.CreateFormFile("image", "input.png")
require.NoError(t, err)
_, err = part.Write([]byte("fake image"))
require.NoError(t, err)
require.NoError(t, writer.Close())

recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/edits", &body)
c.Request.Header.Set("Content-Type", writer.FormDataContentType())
require.NoError(t, c.Request.ParseMultipartForm(32<<20))

info := &relaycommon.RelayInfo{
RelayMode: relayconstant.RelayModeImagesEdits,
}
request := dto.ImageRequest{
Model: "gpt-image-1",
Prompt: "edit this image",
Stream: true,
}

converted, err := (&Adaptor{}).ConvertImageRequest(c, info, request)
require.NoError(t, err)

convertedBody, ok := converted.(*bytes.Buffer)
require.True(t, ok)

contentType := c.Request.Header.Get("Content-Type")
replayedRequest := httptest.NewRequest(http.MethodPost, "/v1/images/edits", bytes.NewReader(convertedBody.Bytes()))
replayedRequest.Header.Set("Content-Type", contentType)
require.NoError(t, replayedRequest.ParseMultipartForm(32<<20))

require.Equal(t, "gpt-image-1", replayedRequest.PostForm.Get("model"))
require.Equal(t, "edit this image", replayedRequest.PostForm.Get("prompt"))
require.Equal(t, "true", replayedRequest.PostForm.Get("stream"))
require.Equal(t, "3", replayedRequest.PostForm.Get("partial_images"))
require.Len(t, replayedRequest.MultipartForm.File["image"], 1)

file, err := replayedRequest.MultipartForm.File["image"][0].Open()
require.NoError(t, err)
defer file.Close()
fileBytes, err := io.ReadAll(file)
require.NoError(t, err)
require.Equal(t, []byte("fake image"), fileBytes)
}

// TestConvertImageEditRequestParsesReusableMultipartWhenFormIsMissing verifies fallback parsing.
func TestConvertImageEditRequestParsesReusableMultipartWhenFormIsMissing(t *testing.T) {
gin.SetMode(gin.TestMode)

var body bytes.Buffer
writer := multipart.NewWriter(&body)
require.NoError(t, writer.WriteField("model", "gpt-image-1"))
require.NoError(t, writer.WriteField("prompt", "edit without pre-parsed form"))
require.NoError(t, writer.WriteField("stream", "true"))
part, err := writer.CreateFormFile("image", "input.png")
require.NoError(t, err)
_, err = part.Write([]byte("fake image"))
require.NoError(t, err)
require.NoError(t, writer.Close())

recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/edits", &body)
c.Request.Header.Set("Content-Type", writer.FormDataContentType())

storage, err := common.GetBodyStorage(c)
require.NoError(t, err)
c.Request.Body = io.NopCloser(storage)
c.Request.MultipartForm = nil
c.Request.PostForm = nil

info := &relaycommon.RelayInfo{
RelayMode: relayconstant.RelayModeImagesEdits,
}
request := dto.ImageRequest{
Model: "gpt-image-1",
Prompt: "edit without pre-parsed form",
Stream: true,
}

converted, err := (&Adaptor{}).ConvertImageRequest(c, info, request)
require.NoError(t, err)

convertedBody, ok := converted.(*bytes.Buffer)
require.True(t, ok)
replayedRequest := httptest.NewRequest(http.MethodPost, "/v1/images/edits", bytes.NewReader(convertedBody.Bytes()))
replayedRequest.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
require.NoError(t, replayedRequest.ParseMultipartForm(32<<20))
require.Equal(t, "edit without pre-parsed form", replayedRequest.PostForm.Get("prompt"))
require.Equal(t, "true", replayedRequest.PostForm.Get("stream"))
require.Len(t, replayedRequest.MultipartForm.File["image"], 1)
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.
85 changes: 85 additions & 0 deletions relay/channel/openai/image_stream_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
package openai

import (
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"

"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/dto"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)

// TestOpenaiImageStreamHandlerForwardsSSEAndUsage verifies image SSE passthrough.
func TestOpenaiImageStreamHandlerForwardsSSEAndUsage(t *testing.T) {
oldMode := gin.Mode()
gin.SetMode(gin.TestMode)
t.Cleanup(func() { gin.SetMode(oldMode) })

oldTimeout := constant.StreamingTimeout
constant.StreamingTimeout = 30
t.Cleanup(func() { constant.StreamingTimeout = oldTimeout })
Comment thread
coderabbitai[bot] marked this conversation as resolved.

body := strings.Join([]string{
`event: image_generation.partial_image`,
`data: {"type":"image_generation.partial_image","b64_json":"partial"}`,
``,
`data: {"usage":{"input_tokens":3,"output_tokens":4,"total_tokens":7,"input_tokens_details":{"image_tokens":2,"text_tokens":1}}}`,
``,
`data: [DONE]`,
``,
}, "\n")

recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/generations", nil)

resp := &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader(body)),
Header: http.Header{"Content-Type": []string{"text/event-stream"}},
}
info := &relaycommon.RelayInfo{
ChannelMeta: &relaycommon.ChannelMeta{},
IsStream: true,
}

usage, err := OpenaiImageStreamHandler(c, info, resp)
require.Nil(t, err)
require.Equal(t, 3, usage.PromptTokens)
require.Equal(t, 4, usage.CompletionTokens)
require.Equal(t, 7, usage.TotalTokens)
require.Equal(t, 2, usage.PromptTokensDetails.ImageTokens)
require.Equal(t, 1, usage.PromptTokensDetails.TextTokens)
require.Contains(t, recorder.Body.String(), `event: image_generation.partial_image`)
require.Contains(t, recorder.Body.String(), `data: {"type":"image_generation.partial_image","b64_json":"partial"}`)
require.Contains(t, recorder.Body.String(), `data: {"usage":{"input_tokens":3,"output_tokens":4,"total_tokens":7,"input_tokens_details":{"image_tokens":2,"text_tokens":1}}}`)
require.Contains(t, recorder.Body.String(), `data: [DONE]`)
require.Equal(t, "text/event-stream", recorder.Header().Get("Content-Type"))
}

// TestNormalizeOpenAIUsageMapsImageTokenDetailsWithoutDoubleCounting verifies ImageRatio inputs.
func TestNormalizeOpenAIUsageMapsImageTokenDetailsWithoutDoubleCounting(t *testing.T) {
usage := &dto.Usage{
InputTokens: 5000,
OutputTokens: 4000,
InputTokensDetails: &dto.InputTokenDetails{
CachedCreationTokens: 200,
ImageTokens: 1000,
TextTokens: 4000,
},
}

normalizeOpenAIUsage(usage)

require.Equal(t, 5000, usage.PromptTokens)
require.Equal(t, 4000, usage.CompletionTokens)
require.Equal(t, 9000, usage.TotalTokens)
require.Equal(t, 200, usage.PromptTokensDetails.CachedCreationTokens)
require.Equal(t, 1000, usage.PromptTokensDetails.ImageTokens)
require.Equal(t, 4000, usage.PromptTokensDetails.TextTokens)
}
Loading