diff --git a/dto/openai_image.go b/dto/openai_image.go index fdef12b1a7d..416697e3b8c 100644 --- a/dto/openai_image.go +++ b/dto/openai_image.go @@ -26,11 +26,11 @@ type ImageRequest struct { OutputFormat json.RawMessage `json:"output_format,omitempty"` OutputCompression json.RawMessage `json:"output_compression,omitempty"` PartialImages json.RawMessage `json:"partial_images,omitempty"` - // Stream bool `json:"stream,omitempty"` - Images json.RawMessage `json:"images,omitempty"` - Mask json.RawMessage `json:"mask,omitempty"` - InputFidelity json.RawMessage `json:"input_fidelity,omitempty"` - Watermark *bool `json:"watermark,omitempty"` + Stream bool `json:"stream,omitempty"` + Images json.RawMessage `json:"images,omitempty"` + Mask json.RawMessage `json:"mask,omitempty"` + InputFidelity json.RawMessage `json:"input_fidelity,omitempty"` + Watermark *bool `json:"watermark,omitempty"` // zhipu 4v WatermarkEnabled json.RawMessage `json:"watermark_enabled,omitempty"` UserId json.RawMessage `json:"user_id,omitempty"` @@ -163,7 +163,7 @@ func (i *ImageRequest) GetTokenCountMeta() *types.TokenCountMeta { } func (i *ImageRequest) IsStream(c *gin.Context) bool { - return false + return i.Stream } func (i *ImageRequest) SetModelName(modelName string) { diff --git a/dto/openai_image_test.go b/dto/openai_image_test.go new file mode 100644 index 00000000000..27e13637745 --- /dev/null +++ b/dto/openai_image_test.go @@ -0,0 +1,16 @@ +package dto + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +// TestImageRequestStreamJSON verifies that image requests preserve stream=true. +func TestImageRequestStreamJSON(t *testing.T) { + var req ImageRequest + require.NoError(t, req.UnmarshalJSON([]byte(`{"model":"gpt-image-1","prompt":"draw a cat","stream":true}`))) + + require.True(t, req.Stream) + require.True(t, req.IsStream(nil)) +} diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go index 6941ca54a73..fe797d98a2c 100644 --- a/relay/channel/openai/adaptor.go +++ b/relay/channel/openai/adaptor.go @@ -9,6 +9,7 @@ import ( "mime/multipart" "net/http" "net/textproto" + "net/url" "path/filepath" "strings" @@ -437,10 +438,13 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf // 使用已解析的 multipart 表单,避免重复解析 mf := c.Request.MultipartForm if mf == nil { - if _, err := c.MultipartForm(); err != nil { - return nil, errors.New("failed to parse multipart form") + form, err := common.ParseMultipartFormReusable(c) + if err != nil { + return nil, fmt.Errorf("failed to parse multipart form: %w", err) } - mf = c.Request.MultipartForm + c.Request.MultipartForm = form + c.Request.PostForm = url.Values(form.Value) + mf = form } // 写入所有非文件字段 @@ -623,7 +627,11 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom case relayconstant.RelayModeAudioTranscription: err, usage = OpenaiSTTHandler(c, resp, info, a.ResponseFormat) case relayconstant.RelayModeImagesGenerations, relayconstant.RelayModeImagesEdits: - usage, err = OpenaiHandlerWithUsage(c, info, resp) + if info.IsStream { + usage, err = OpenaiImageStreamHandler(c, info, resp) + } else { + usage, err = OpenaiHandlerWithUsage(c, info, resp) + } case relayconstant.RelayModeRerank: usage, err = common_handler.RerankHandler(c, info, resp) case relayconstant.RelayModeResponses: diff --git a/relay/channel/openai/image_edit_test.go b/relay/channel/openai/image_edit_test.go new file mode 100644 index 00000000000..b37551b04da --- /dev/null +++ b/relay/channel/openai/image_edit_test.go @@ -0,0 +1,121 @@ +package openai + +import ( + "bytes" + "io" + "mime/multipart" + "net/http" + "net/http/httptest" + "testing" + + "github.com/QuantumNous/new-api/common" + "github.com/QuantumNous/new-api/dto" + relaycommon "github.com/QuantumNous/new-api/relay/common" + relayconstant "github.com/QuantumNous/new-api/relay/constant" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +// TestConvertImageEditRequestKeepsValidMultipartStreamFields verifies multipart replay. +func TestConvertImageEditRequestKeepsValidMultipartStreamFields(t *testing.T) { + gin.SetMode(gin.TestMode) + + var body bytes.Buffer + writer := multipart.NewWriter(&body) + require.NoError(t, writer.WriteField("model", "gpt-image-1")) + require.NoError(t, writer.WriteField("prompt", "edit this image")) + require.NoError(t, writer.WriteField("stream", "true")) + require.NoError(t, writer.WriteField("partial_images", "3")) + part, err := writer.CreateFormFile("image", "input.png") + require.NoError(t, err) + _, err = part.Write([]byte("fake image")) + require.NoError(t, err) + require.NoError(t, writer.Close()) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/edits", &body) + c.Request.Header.Set("Content-Type", writer.FormDataContentType()) + require.NoError(t, c.Request.ParseMultipartForm(32<<20)) + + info := &relaycommon.RelayInfo{ + RelayMode: relayconstant.RelayModeImagesEdits, + } + request := dto.ImageRequest{ + Model: "gpt-image-1", + Prompt: "edit this image", + Stream: true, + } + + converted, err := (&Adaptor{}).ConvertImageRequest(c, info, request) + require.NoError(t, err) + + convertedBody, ok := converted.(*bytes.Buffer) + require.True(t, ok) + + contentType := c.Request.Header.Get("Content-Type") + replayedRequest := httptest.NewRequest(http.MethodPost, "/v1/images/edits", bytes.NewReader(convertedBody.Bytes())) + replayedRequest.Header.Set("Content-Type", contentType) + require.NoError(t, replayedRequest.ParseMultipartForm(32<<20)) + + require.Equal(t, "gpt-image-1", replayedRequest.PostForm.Get("model")) + require.Equal(t, "edit this image", replayedRequest.PostForm.Get("prompt")) + require.Equal(t, "true", replayedRequest.PostForm.Get("stream")) + require.Equal(t, "3", replayedRequest.PostForm.Get("partial_images")) + require.Len(t, replayedRequest.MultipartForm.File["image"], 1) + + file, err := replayedRequest.MultipartForm.File["image"][0].Open() + require.NoError(t, err) + defer file.Close() + fileBytes, err := io.ReadAll(file) + require.NoError(t, err) + require.Equal(t, []byte("fake image"), fileBytes) +} + +// TestConvertImageEditRequestParsesReusableMultipartWhenFormIsMissing verifies fallback parsing. +func TestConvertImageEditRequestParsesReusableMultipartWhenFormIsMissing(t *testing.T) { + gin.SetMode(gin.TestMode) + + var body bytes.Buffer + writer := multipart.NewWriter(&body) + require.NoError(t, writer.WriteField("model", "gpt-image-1")) + require.NoError(t, writer.WriteField("prompt", "edit without pre-parsed form")) + require.NoError(t, writer.WriteField("stream", "true")) + part, err := writer.CreateFormFile("image", "input.png") + require.NoError(t, err) + _, err = part.Write([]byte("fake image")) + require.NoError(t, err) + require.NoError(t, writer.Close()) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/edits", &body) + c.Request.Header.Set("Content-Type", writer.FormDataContentType()) + + storage, err := common.GetBodyStorage(c) + require.NoError(t, err) + c.Request.Body = io.NopCloser(storage) + c.Request.MultipartForm = nil + c.Request.PostForm = nil + + info := &relaycommon.RelayInfo{ + RelayMode: relayconstant.RelayModeImagesEdits, + } + request := dto.ImageRequest{ + Model: "gpt-image-1", + Prompt: "edit without pre-parsed form", + Stream: true, + } + + converted, err := (&Adaptor{}).ConvertImageRequest(c, info, request) + require.NoError(t, err) + + convertedBody, ok := converted.(*bytes.Buffer) + require.True(t, ok) + replayedRequest := httptest.NewRequest(http.MethodPost, "/v1/images/edits", bytes.NewReader(convertedBody.Bytes())) + replayedRequest.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) + require.NoError(t, replayedRequest.ParseMultipartForm(32<<20)) + require.Equal(t, "edit without pre-parsed form", replayedRequest.PostForm.Get("prompt")) + require.Equal(t, "true", replayedRequest.PostForm.Get("stream")) + require.Len(t, replayedRequest.MultipartForm.File["image"], 1) +} diff --git a/relay/channel/openai/image_stream_test.go b/relay/channel/openai/image_stream_test.go new file mode 100644 index 00000000000..3dcfade620b --- /dev/null +++ b/relay/channel/openai/image_stream_test.go @@ -0,0 +1,85 @@ +package openai + +import ( + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/QuantumNous/new-api/constant" + "github.com/QuantumNous/new-api/dto" + relaycommon "github.com/QuantumNous/new-api/relay/common" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +// TestOpenaiImageStreamHandlerForwardsSSEAndUsage verifies image SSE passthrough. +func TestOpenaiImageStreamHandlerForwardsSSEAndUsage(t *testing.T) { + oldMode := gin.Mode() + gin.SetMode(gin.TestMode) + t.Cleanup(func() { gin.SetMode(oldMode) }) + + oldTimeout := constant.StreamingTimeout + constant.StreamingTimeout = 30 + t.Cleanup(func() { constant.StreamingTimeout = oldTimeout }) + + body := strings.Join([]string{ + `event: image_generation.partial_image`, + `data: {"type":"image_generation.partial_image","b64_json":"partial"}`, + ``, + `data: {"usage":{"input_tokens":3,"output_tokens":4,"total_tokens":7,"input_tokens_details":{"image_tokens":2,"text_tokens":1}}}`, + ``, + `data: [DONE]`, + ``, + }, "\n") + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/generations", nil) + + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(body)), + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + } + info := &relaycommon.RelayInfo{ + ChannelMeta: &relaycommon.ChannelMeta{}, + IsStream: true, + } + + usage, err := OpenaiImageStreamHandler(c, info, resp) + require.Nil(t, err) + require.Equal(t, 3, usage.PromptTokens) + require.Equal(t, 4, usage.CompletionTokens) + require.Equal(t, 7, usage.TotalTokens) + require.Equal(t, 2, usage.PromptTokensDetails.ImageTokens) + require.Equal(t, 1, usage.PromptTokensDetails.TextTokens) + require.Contains(t, recorder.Body.String(), `event: image_generation.partial_image`) + require.Contains(t, recorder.Body.String(), `data: {"type":"image_generation.partial_image","b64_json":"partial"}`) + require.Contains(t, recorder.Body.String(), `data: {"usage":{"input_tokens":3,"output_tokens":4,"total_tokens":7,"input_tokens_details":{"image_tokens":2,"text_tokens":1}}}`) + require.Contains(t, recorder.Body.String(), `data: [DONE]`) + require.Equal(t, "text/event-stream", recorder.Header().Get("Content-Type")) +} + +// TestNormalizeOpenAIUsageMapsImageTokenDetailsWithoutDoubleCounting verifies ImageRatio inputs. +func TestNormalizeOpenAIUsageMapsImageTokenDetailsWithoutDoubleCounting(t *testing.T) { + usage := &dto.Usage{ + InputTokens: 5000, + OutputTokens: 4000, + InputTokensDetails: &dto.InputTokenDetails{ + CachedCreationTokens: 200, + ImageTokens: 1000, + TextTokens: 4000, + }, + } + + normalizeOpenAIUsage(usage) + + require.Equal(t, 5000, usage.PromptTokens) + require.Equal(t, 4000, usage.CompletionTokens) + require.Equal(t, 9000, usage.TotalTokens) + require.Equal(t, 200, usage.PromptTokensDetails.CachedCreationTokens) + require.Equal(t, 1000, usage.PromptTokensDetails.ImageTokens) + require.Equal(t, 4000, usage.PromptTokensDetails.TextTokens) +} diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index a85751844c0..3cdfbb0a08e 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -1,6 +1,7 @@ package openai import ( + "bufio" "fmt" "io" "net/http" @@ -574,22 +575,96 @@ func OpenaiHandlerWithUsage(c *gin.Context, info *relaycommon.RelayInfo, resp *h // 写入新的 response body service.IOCopyBytesGracefully(c, resp, responseBody) - // Once we've written to the client, we should not return errors anymore - // because the upstream has already consumed resources and returned content - // We should still perform billing even if parsing fails - // format - if usageResp.InputTokens > 0 { - usageResp.PromptTokens += usageResp.InputTokens + normalizeOpenAIUsage(&usageResp.Usage) + applyUsagePostProcessing(info, &usageResp.Usage, responseBody) + return &usageResp.Usage, nil +} + +// normalizeOpenAIUsage maps OpenAI usage aliases into NewAPI billing fields. +func normalizeOpenAIUsage(usage *dto.Usage) { + if usage == nil { + return } - if usageResp.OutputTokens > 0 { - usageResp.CompletionTokens += usageResp.OutputTokens + if usage.InputTokens != 0 { + usage.PromptTokens = usage.InputTokens } - if usageResp.InputTokensDetails != nil { - usageResp.PromptTokensDetails.ImageTokens += usageResp.InputTokensDetails.ImageTokens - usageResp.PromptTokensDetails.TextTokens += usageResp.InputTokensDetails.TextTokens + if usage.OutputTokens != 0 { + usage.CompletionTokens = usage.OutputTokens } - applyUsagePostProcessing(info, &usageResp.Usage, responseBody) - return &usageResp.Usage, nil + if usage.InputTokensDetails != nil { + usage.PromptTokensDetails.CachedTokens = usage.InputTokensDetails.CachedTokens + usage.PromptTokensDetails.CachedCreationTokens = usage.InputTokensDetails.CachedCreationTokens + usage.PromptTokensDetails.ImageTokens = usage.InputTokensDetails.ImageTokens + usage.PromptTokensDetails.TextTokens = usage.InputTokensDetails.TextTokens + usage.PromptTokensDetails.AudioTokens = usage.InputTokensDetails.AudioTokens + } + if usage.TotalTokens == 0 { + usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens + } +} + +// OpenaiImageStreamHandler forwards OpenAI Images SSE events and extracts usage. +func OpenaiImageStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { + if resp == nil || resp.Body == nil { + logger.LogError(c, "invalid image stream response") + return nil, types.NewOpenAIError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse, http.StatusInternalServerError) + } + + contentType := strings.ToLower(resp.Header.Get("Content-Type")) + if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices || !strings.Contains(contentType, "text/event-stream") { + return OpenaiHandlerWithUsage(c, info, resp) + } + defer service.CloseResponseBodyGracefully(resp) + + usage := &dto.Usage{} + var lastStreamData []byte + + helper.SetEventStreamHeaders(c) + if info.StreamStatus == nil { + info.StreamStatus = relaycommon.NewStreamStatus() + } + + scanner := bufio.NewScanner(resp.Body) + scanner.Buffer(make([]byte, helper.InitialScannerBufferSize), helper.GetScannerBufferSize()) + for scanner.Scan() { + line := scanner.Text() + if strings.HasPrefix(line, "data:") { + data := strings.TrimSpace(strings.TrimPrefix(line, "data:")) + if data == "[DONE]" { + info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonDone, nil) + } else if data != "" { + info.SetFirstResponseTime() + info.ReceivedResponseCount++ + lastStreamData = common.StringToByteSlice(data) + var usageResp dto.SimpleResponse + if err := common.Unmarshal(lastStreamData, &usageResp); err == nil { + normalizeOpenAIUsage(&usageResp.Usage) + if service.ValidUsage(&usageResp.Usage) { + usage = &usageResp.Usage + } + } + } + } + if _, err := c.Writer.Write(append([]byte(line), '\n')); err != nil { + info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonClientGone, err) + return usage, nil + } + if line == "" { + if err := helper.FlushWriter(c); err != nil { + info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonClientGone, err) + return usage, nil + } + } + } + if err := scanner.Err(); err != nil { + info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonScannerErr, err) + } else if info.StreamStatus.EndReason == relaycommon.StreamEndReasonNone { + info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonEOF, nil) + } + _ = helper.FlushWriter(c) + + applyUsagePostProcessing(info, usage, lastStreamData) + return usage, nil } func applyUsagePostProcessing(info *relaycommon.RelayInfo, usage *dto.Usage, responseBody []byte) { diff --git a/relay/helper/openai_image_request_test.go b/relay/helper/openai_image_request_test.go new file mode 100644 index 00000000000..a0bb46c6aa4 --- /dev/null +++ b/relay/helper/openai_image_request_test.go @@ -0,0 +1,73 @@ +package helper + +import ( + "bytes" + "io" + "mime/multipart" + "net/http" + "net/http/httptest" + "net/url" + "testing" + + "github.com/QuantumNous/new-api/common" + relayconstant "github.com/QuantumNous/new-api/relay/constant" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +// TestGetAndValidOpenAIImageRequestMultipartStream verifies reusable image edit parsing. +func TestGetAndValidOpenAIImageRequestMultipartStream(t *testing.T) { + gin.SetMode(gin.TestMode) + + var body bytes.Buffer + writer := multipart.NewWriter(&body) + require.NoError(t, writer.WriteField("model", "gpt-image-1")) + require.NoError(t, writer.WriteField("prompt", "edit this image")) + require.NoError(t, writer.WriteField("stream", "true")) + require.NoError(t, writer.WriteField("n", "1")) + part, err := writer.CreateFormFile("image", "input.png") + require.NoError(t, err) + _, err = part.Write([]byte("fake image")) + require.NoError(t, err) + require.NoError(t, writer.Close()) + originalBody := body.String() + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/edits", &body) + c.Request.Header.Set("Content-Type", writer.FormDataContentType()) + + req, err := GetAndValidOpenAIImageRequest(c, relayconstant.RelayModeImagesEdits) + require.NoError(t, err) + require.True(t, req.Stream) + require.True(t, req.IsStream(c)) + + bodyAfterValidation, err := io.ReadAll(c.Request.Body) + require.NoError(t, err) + require.Equal(t, originalBody, string(bodyAfterValidation)) + + form, err := common.ParseMultipartFormReusable(c) + require.NoError(t, err) + require.Equal(t, "true", url.Values(form.Value).Get("stream")) + require.Len(t, form.File["image"], 1) +} + +// TestGetAndValidOpenAIImageRequestMultipartStreamInvalidValue verifies stream validation. +func TestGetAndValidOpenAIImageRequestMultipartStreamInvalidValue(t *testing.T) { + gin.SetMode(gin.TestMode) + + var body bytes.Buffer + writer := multipart.NewWriter(&body) + require.NoError(t, writer.WriteField("model", "gpt-image-1")) + require.NoError(t, writer.WriteField("stream", "notabool")) + require.NoError(t, writer.Close()) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/edits", &body) + c.Request.Header.Set("Content-Type", writer.FormDataContentType()) + + _, err := GetAndValidOpenAIImageRequest(c, relayconstant.RelayModeImagesEdits) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid stream value") +} diff --git a/relay/helper/stream_scanner.go b/relay/helper/stream_scanner.go index a9bc5e16a72..df7aa39ea54 100644 --- a/relay/helper/stream_scanner.go +++ b/relay/helper/stream_scanner.go @@ -27,7 +27,8 @@ const ( DefaultPingInterval = 10 * time.Second ) -func getScannerBufferSize() int { +// GetScannerBufferSize returns the configured maximum SSE scanner token size. +func GetScannerBufferSize() int { if constant.StreamScannerMaxBufferMB > 0 { return constant.StreamScannerMaxBufferMB << 20 } @@ -40,8 +41,9 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon return } - // 无条件新建 StreamStatus - info.StreamStatus = relaycommon.NewStreamStatus() + if info.StreamStatus == nil { + info.StreamStatus = relaycommon.NewStreamStatus() + } // 确保响应体总是被关闭 defer func() { @@ -107,7 +109,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon close(stopChan) }() - scanner.Buffer(make([]byte, InitialScannerBufferSize), getScannerBufferSize()) + scanner.Buffer(make([]byte, InitialScannerBufferSize), GetScannerBufferSize()) scanner.Split(bufio.ScanLines) SetEventStreamHeaders(c) diff --git a/relay/helper/valid_request.go b/relay/helper/valid_request.go index 2581b2812c9..a53b63693d7 100644 --- a/relay/helper/valid_request.go +++ b/relay/helper/valid_request.go @@ -4,6 +4,8 @@ import ( "errors" "fmt" "math" + "net/url" + "strconv" "strings" "github.com/QuantumNous/new-api/common" @@ -144,16 +146,25 @@ func GetAndValidOpenAIImageRequest(c *gin.Context, relayMode int) (*dto.ImageReq switch relayMode { case relayconstant.RelayModeImagesEdits: if strings.Contains(c.Request.Header.Get("Content-Type"), "multipart/form-data") { - _, err := c.MultipartForm() + form, err := common.ParseMultipartFormReusable(c) if err != nil { return nil, fmt.Errorf("failed to parse image edit form request: %w", err) } - formData := c.Request.PostForm + formData := url.Values(form.Value) + c.Request.MultipartForm = form + c.Request.PostForm = formData imageRequest.Prompt = formData.Get("prompt") imageRequest.Model = formData.Get("model") imageRequest.N = common.GetPointer(uint(common.String2Int(formData.Get("n")))) imageRequest.Quality = formData.Get("quality") imageRequest.Size = formData.Get("size") + if streamValue := strings.TrimSpace(formData.Get("stream")); streamValue != "" { + stream, err := strconv.ParseBool(streamValue) + if err != nil { + return nil, fmt.Errorf("invalid stream value: %w", err) + } + imageRequest.Stream = stream + } if imageValue := formData.Get("image"); imageValue != "" { imageRequest.Image, _ = common.Marshal(imageValue) }