diff --git a/dto/gemini.go b/dto/gemini.go
index 489ebea534b..0ddd7e4c542 100644
--- a/dto/gemini.go
+++ b/dto/gemini.go
@@ -483,6 +483,33 @@ type GeminiImageRequest struct {
 	Parameters GeminiImageParameters `json:"parameters"`
 }
 
+// GetTokenCountMeta returns prompt text metadata for Gemini image requests.
+// Each instance prompt is trimmed, whitespace-only prompts are dropped, and the
+// remaining prompts are joined with newlines into CombineText.
+func (r *GeminiImageRequest) GetTokenCountMeta() *types.TokenCountMeta {
+	inputTexts := make([]string, 0, len(r.Instances))
+	for _, instance := range r.Instances {
+		// Trim once and reuse the result for both the check and the append.
+		if prompt := strings.TrimSpace(instance.Prompt); prompt != "" {
+			inputTexts = append(inputTexts, prompt)
+		}
+	}
+	return &types.TokenCountMeta{
+		CombineText: strings.Join(inputTexts, "\n"),
+	}
+}
+
+// IsStream reports whether the Gemini image request uses streaming.
+// It always returns false and does not read c.
+func (r *GeminiImageRequest) IsStream(c *gin.Context) bool {
+	return false
+}
+
+// SetModelName keeps Gemini image model names in the request URL path.
+func (r *GeminiImageRequest) SetModelName(modelName string) {
+	// Gemini image request carries model in URL path, not in body.
+}
+
 type GeminiImageInstance struct {
 	Prompt string `json:"prompt"`
 }
diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go
index 680c4ee484e..627e4cfa646 100644
--- a/relay/channel/gemini/adaptor.go
+++ b/relay/channel/gemini/adaptor.go
@@ -57,16 +57,18 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
 	return nil, errors.New("not implemented")
 }
 
+// ConvertImageRequest converts OpenAI image requests into Gemini Imagen requests.
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { if !strings.HasPrefix(info.UpstreamModelName, "imagen") { return nil, errors.New("not supported model for image generation, only imagen models are supported") } - // convert size to aspect ratio but allow user to specify aspect ratio aspectRatio := "1:1" // default aspect ratio size := strings.TrimSpace(request.Size) if size != "" { - if strings.Contains(size, ":") { + if info.UpstreamModelName == "imagen-4.0-generate-001" { + aspectRatio = convertImagen4SizeToAspectRatio(size) + } else if strings.Contains(size, ":") { aspectRatio = size } else { switch size { @@ -123,6 +125,27 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf return geminiRequest, nil } +// convertImagen4SizeToAspectRatio maps OpenAI image sizes to Imagen 4 aspect ratios. +func convertImagen4SizeToAspectRatio(size string) string { + if strings.Contains(size, ":") { + return size + } + switch strings.ToLower(strings.TrimSpace(size)) { + case "1024x1024", "2048x2048": + return "1:1" + case "896x1280", "1792x2560": + return "3:4" + case "1280x896", "2560x1792": + return "4:3" + case "768x1408", "1536x2816", "1024x1792": + return "9:16" + case "1408x768", "2816x1536", "1792x1024": + return "16:9" + default: + return "1:1" + } +} + func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } @@ -246,8 +269,12 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request return channel.DoApiRequest(a, c, info, requestBody) } +// DoResponse converts Gemini upstream responses into relay usage data. 
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) { if info.RelayMode == constant.RelayModeGemini { + if strings.HasPrefix(info.UpstreamModelName, "imagen") && strings.Contains(info.RequestURLPath, ":predict") { + return GeminiNativeImagePredictHandler(c, info, resp) + } if strings.Contains(info.RequestURLPath, ":embedContent") || strings.Contains(info.RequestURLPath, ":batchEmbedContents") { return NativeGeminiEmbeddingHandler(c, resp, info) diff --git a/relay/channel/gemini/relay-gemini-native.go b/relay/channel/gemini/relay-gemini-native.go index 1a434a43276..fc7bbd21f37 100644 --- a/relay/channel/gemini/relay-gemini-native.go +++ b/relay/channel/gemini/relay-gemini-native.go @@ -95,3 +95,35 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, info *relaycommon.RelayIn return true }) } + +// GeminiNativeImagePredictHandler relays native Gemini Imagen predict responses. +func GeminiNativeImagePredictHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { + defer service.CloseResponseBodyGracefully(resp) + + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) + } + + if common.DebugEnabled { + println(string(responseBody)) + } + + var geminiResponse dto.GeminiImageResponse + if err = common.Unmarshal(responseBody, &geminiResponse); err != nil { + return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) + } + + generatedImages := 0 + for _, prediction := range geminiResponse.Predictions { + if prediction.RaiFilteredReason != "" { + continue + } + if prediction.BytesBase64Encoded != "" { + generatedImages++ + } + } + + service.IOCopyBytesGracefully(c, resp, responseBody) + return buildGeminiImageUsage(generatedImages), nil +} diff --git a/relay/channel/gemini/relay-gemini.go 
b/relay/channel/gemini/relay-gemini.go
index 355c75d71b7..e8137a804b2 100644
--- a/relay/channel/gemini/relay-gemini.go
+++ b/relay/channel/gemini/relay-gemini.go
@@ -1538,6 +1538,7 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *h
 	return usage, nil
 }
 
+// GeminiImageHandler converts Gemini Imagen responses to OpenAI image responses.
 func GeminiImageHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
 	responseBody, readErr := io.ReadAll(resp.Body)
 	if readErr != nil {
@@ -1578,18 +1579,22 @@ func GeminiImageHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.
 	c.Writer.WriteHeader(resp.StatusCode)
 	_, _ = c.Writer.Write(jsonResponse)
 
+	return buildGeminiImageUsage(len(openAIResponse.Data)), nil
+}
+
+// buildGeminiImageUsage returns usage for generated Gemini images.
+// Every generated image is charged a fixed token amount reported as prompt
+// tokens; completion tokens are zero, so the total equals the prompt count.
+func buildGeminiImageUsage(generatedImages int) *dto.Usage {
 	// https://github.com/google-gemini/cookbook/blob/719a27d752aac33f39de18a8d3cb42a70874917e/quickstarts/Counting_Tokens.ipynb
-	// each image has fixed 258 tokens
+	// each image has fixed 258 tokens.
 	const imageTokens = 258
-	generatedImages := len(openAIResponse.Data)
-
 	usage := &dto.Usage{
-		PromptTokens:     imageTokens * generatedImages, // each generated image has fixed 258 tokens
-		CompletionTokens: 0,                             // image generation does not calculate completion tokens
-		TotalTokens:      imageTokens * generatedImages,
+		PromptTokens: imageTokens * generatedImages,
 	}
-
-	return usage, nil
+	// CompletionTokens stays at its zero value, so the total equals the prompt tokens.
+	usage.TotalTokens = usage.PromptTokens
+	return usage
 }
 
 type GeminiModelsResponse struct {
diff --git a/relay/channel/vertex/adaptor.go b/relay/channel/vertex/adaptor.go
index 0d91032d0f3..a200a3776eb 100644
--- a/relay/channel/vertex/adaptor.go
+++ b/relay/channel/vertex/adaptor.go
@@ -1,7 +1,6 @@
 package vertex
 
 import (
-	"encoding/json"
 	"errors"
 	"fmt"
 	"io"
@@ -229,6 +228,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
 	return nil
 }
 
+// ConvertOpenAIRequest converts OpenAI-compatible requests for Vertex upstreams.
 func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
 	if request == nil {
 		return nil, errors.New("request is nil")
@@ -266,7 +266,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
 		}
 		if len(request.ExtraBody) > 0 {
 			var extra map[string]any
-			if err := json.Unmarshal(request.ExtraBody, &extra); err == nil {
+			if err := common.Unmarshal(request.ExtraBody, &extra); err == nil {
 				if n, ok := extra["n"].(float64); ok && n > 0 {
 					imgReq.N = lo.ToPtr(uint(n))
 				}
@@ -327,6 +327,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 	return channel.DoApiRequest(a, c, info, requestBody)
 }
 
+// DoResponse converts Vertex upstream responses into relay usage data.
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) { claudeAdaptor := claude.Adaptor{} if info.IsStream { @@ -348,6 +349,9 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom return claudeAdaptor.DoResponse(c, resp, info) case RequestModeGemini: if info.RelayMode == constant.RelayModeGemini { + if strings.HasPrefix(info.UpstreamModelName, "imagen") && strings.Contains(info.RequestURLPath, ":predict") { + return gemini.GeminiNativeImagePredictHandler(c, info, resp) + } return gemini.GeminiTextGenerationHandler(c, info, resp) } else { if strings.HasPrefix(info.UpstreamModelName, "imagen") { diff --git a/relay/gemini_handler.go b/relay/gemini_handler.go index 3b4bafe2a67..f6739edf567 100644 --- a/relay/gemini_handler.go +++ b/relay/gemini_handler.go @@ -2,6 +2,7 @@ package relay import ( "bytes" + "errors" "fmt" "io" "net/http" @@ -52,9 +53,14 @@ func trimModelThinking(modelName string) string { return modelName } +// GeminiHelper handles native Gemini relay requests. func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) { info.InitChannelMeta(c) + if strings.HasPrefix(info.OriginModelName, "imagen") && strings.Contains(info.RequestURLPath, ":predict") { + return geminiImagePredictHelper(c, info) + } + geminiReq, ok := info.Request.(*dto.GeminiChatRequest) if !ok { return types.NewErrorWithStatusCode(fmt.Errorf("invalid request type, expected *dto.GeminiChatRequest, got %T", info.Request), types.ErrorCodeInvalidRequest, http.StatusBadRequest, types.ErrOptionWithSkipRetry()) @@ -198,6 +204,71 @@ func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ return nil } +// geminiImagePredictHelper forwards native Gemini Imagen predict requests. 
+func geminiImagePredictHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) { + imageReq, ok := info.Request.(*dto.GeminiImageRequest) + if !ok { + return types.NewErrorWithStatusCode( + fmt.Errorf("invalid request type for imagen predict, got %T", info.Request), + types.ErrorCodeInvalidRequest, + http.StatusBadRequest, + types.ErrOptionWithSkipRetry(), + ) + } + + if err := helper.ModelMappedHelper(c, info, imageReq); err != nil { + return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry()) + } + + adaptor := GetAdaptor(info.ApiType) + if adaptor == nil { + return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) + } + adaptor.Init(info) + + storage, err := common.GetBodyStorage(c) + if err != nil { + return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry()) + } + body, err := storage.Bytes() + if err != nil { + return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry()) + } + if len(info.ParamOverride) > 0 { + body, err = relaycommon.ApplyParamOverride(body, info.ParamOverride, relaycommon.BuildParamOverrideContext(info)) + if err != nil { + return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry()) + } + } + logger.LogDebug(c, "Gemini imagen predict request body: "+string(body)) + + resp, err := adaptor.DoRequest(c, info, bytes.NewReader(body)) + if err != nil { + logger.LogError(c, "Do gemini imagen predict request failed: "+err.Error()) + return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError) + } + + statusCodeMappingStr := c.GetString("status_code_mapping") + if resp == nil { + return types.NewErrorWithStatusCode(errors.New("empty upstream response"), types.ErrorCodeBadResponse, 
http.StatusInternalServerError) + } + httpResp := resp.(*http.Response) + if httpResp.StatusCode != http.StatusOK { + newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false) + service.ResetStatusCode(newAPIError, statusCodeMappingStr) + return newAPIError + } + + usage, openaiErr := adaptor.DoResponse(c, httpResp, info) + if openaiErr != nil { + service.ResetStatusCode(openaiErr, statusCodeMappingStr) + return openaiErr + } + + service.PostTextConsumeQuota(c, info, usage.(*dto.Usage), nil) + return nil +} + func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) { info.InitChannelMeta(c) diff --git a/relay/helper/valid_request.go b/relay/helper/valid_request.go index 2581b2812c9..3649cc3e791 100644 --- a/relay/helper/valid_request.go +++ b/relay/helper/valid_request.go @@ -16,6 +16,7 @@ import ( "github.com/gin-gonic/gin" ) +// GetAndValidateRequest parses and validates requests for the relay format. func GetAndValidateRequest(c *gin.Context, format types.RelayFormat) (request dto.Request, err error) { relayMode := relayconstant.Path2RelayMode(c.Request.URL.Path) @@ -27,6 +28,8 @@ func GetAndValidateRequest(c *gin.Context, format types.RelayFormat) (request dt request, err = GetAndValidateGeminiEmbeddingRequest(c) } else if strings.Contains(c.Request.URL.Path, ":batchEmbedContents") { request, err = GetAndValidateGeminiBatchEmbeddingRequest(c) + } else if isGeminiImagePredictPath(c.Request.URL.Path) { + request, err = GetAndValidateGeminiImageRequest(c) } else { request, err = GetAndValidateGeminiRequest(c) } @@ -246,6 +249,33 @@ func GetAndValidateClaudeRequest(c *gin.Context) (textRequest *dto.ClaudeRequest return textRequest, nil } +// isGeminiImagePredictPath reports whether a Gemini path targets Imagen predict. 
+// It returns true only when the path contains ":predict" and the model segment
+// after "/models/" starts with "imagen".
+func isGeminiImagePredictPath(path string) bool {
+	if !strings.Contains(path, ":predict") {
+		return false
+	}
+	return strings.HasPrefix(extractGeminiModelNameFromPath(path), "imagen")
+}
+
+// extractGeminiModelNameFromPath extracts the model segment from a Gemini URL path.
+// The segment is the text between the first "/models/" and the next ":"; when
+// no ":" follows, the rest of the path is returned. It returns "" when the path
+// does not contain "/models/" or nothing follows it.
+//
+// For example, "/v1beta/models/imagen-4.0-generate-001:predict" yields
+// "imagen-4.0-generate-001", and "/v1beta/files" yields "".
+func extractGeminiModelNameFromPath(path string) string {
+	_, afterModels, found := strings.Cut(path, "/models/")
+	if !found {
+		return ""
+	}
+	// Everything from the first ":" onward is the action suffix, not the model.
+	modelName, _, _ := strings.Cut(afterModels, ":")
+	return modelName
+}
+
 func GetAndValidateTextRequest(c *gin.Context, relayMode int) (*dto.GeneralOpenAIRequest, error) {
 	textRequest := &dto.GeneralOpenAIRequest{}
 	err := common.UnmarshalBodyReusable(c, textRequest)
@@ -321,6 +351,18 @@ func GetAndValidateGeminiRequest(c *gin.Context) (*dto.GeminiChatRequest, error)
 	return request, nil
 }
 
+// GetAndValidateGeminiImageRequest parses and validates a native Gemini image request.
+func GetAndValidateGeminiImageRequest(c *gin.Context) (*dto.GeminiImageRequest, error) {
+	imageRequest := new(dto.GeminiImageRequest)
+	switch err := common.UnmarshalBodyReusable(c, imageRequest); {
+	case err != nil:
+		return nil, err
+	case len(imageRequest.Instances) == 0:
+		return nil, errors.New("instances is required")
+	}
+	return imageRequest, nil
+}
+
 func GetAndValidateGeminiEmbeddingRequest(c *gin.Context) (*dto.GeminiEmbeddingRequest, error) {
 	request := &dto.GeminiEmbeddingRequest{}
 	err := common.UnmarshalBodyReusable(c, request)
diff --git a/setting/ratio_setting/model_ratio.go b/setting/ratio_setting/model_ratio.go
index 80702ee42ad..98e39062726 100644
--- a/setting/ratio_setting/model_ratio.go
+++ b/setting/ratio_setting/model_ratio.go
@@ -281,6 +281,7 @@ var defaultModelPrice = map[string]float64{
 	"suno_lyrics":                    0.01,
 	"dall-e-3":                       0.04,
 	"imagen-3.0-generate-002":        0.03,
+	"imagen-4.0-generate-001":        0.04,
 	"black-forest-labs/flux-1.1-pro": 0.04,
 	"gpt-4-gizmo-*":                  0.1,
 	"mj_video":                       0.8,