Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions dto/gemini.go
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,29 @@ type GeminiImageRequest struct {
Parameters GeminiImageParameters `json:"parameters"`
}

// GetTokenCountMeta returns prompt text metadata for Gemini image requests.
// All non-blank instance prompts are joined with newlines so the token
// counter can treat them as one combined text.
func (r *GeminiImageRequest) GetTokenCountMeta() *types.TokenCountMeta {
	inputTexts := make([]string, 0, len(r.Instances))
	for _, instance := range r.Instances {
		// Trim once and reuse; the original trimmed the same prompt twice.
		if prompt := strings.TrimSpace(instance.Prompt); prompt != "" {
			inputTexts = append(inputTexts, prompt)
		}
	}
	return &types.TokenCountMeta{
		CombineText: strings.Join(inputTexts, "\n"),
	}
}

// IsStream reports whether the Gemini image request uses streaming.
// Imagen predict responses are always delivered as a single payload,
// so this is unconditionally false.
func (r *GeminiImageRequest) IsStream(_ *gin.Context) bool {
	return false
}

// SetModelName keeps Gemini image model names in the request URL path.
// Intentionally a no-op: for Imagen predict calls the model is addressed
// via the URL (…/models/<model>:predict), so the body never carries it.
func (r *GeminiImageRequest) SetModelName(modelName string) {
	// Gemini image request carries model in URL path, not in body.
}

// GeminiImageInstance is a single prompt entry in the "instances" array of a
// Gemini Imagen predict request body.
type GeminiImageInstance struct {
	Prompt string `json:"prompt"`
}
Expand Down
31 changes: 29 additions & 2 deletions relay/channel/gemini/adaptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,16 +57,18 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
return nil, errors.New("not implemented")
}

// ConvertImageRequest converts OpenAI image requests into Gemini Imagen requests.
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
if !strings.HasPrefix(info.UpstreamModelName, "imagen") {
return nil, errors.New("not supported model for image generation, only imagen models are supported")
}

// convert size to aspect ratio but allow user to specify aspect ratio
aspectRatio := "1:1" // default aspect ratio
size := strings.TrimSpace(request.Size)
if size != "" {
if strings.Contains(size, ":") {
if info.UpstreamModelName == "imagen-4.0-generate-001" {
aspectRatio = convertImagen4SizeToAspectRatio(size)
} else if strings.Contains(size, ":") {
aspectRatio = size
} else {
switch size {
Expand Down Expand Up @@ -123,6 +125,27 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
return geminiRequest, nil
}

// convertImagen4SizeToAspectRatio maps OpenAI image sizes to Imagen 4 aspect ratios.
// Explicit ratios (e.g. "3:4") pass through after normalization; unknown
// sizes fall back to the square default "1:1".
func convertImagen4SizeToAspectRatio(size string) string {
	// Normalize once up front so padded/upper-cased inputs (" 3:4 ",
	// "1024X1024") are handled uniformly; the old code checked for ":"
	// before trimming and could forward a ratio with stray whitespace.
	normalized := strings.ToLower(strings.TrimSpace(size))
	if strings.Contains(normalized, ":") {
		return normalized
	}
	switch normalized {
	case "1024x1024", "2048x2048":
		return "1:1"
	case "896x1280", "1792x2560":
		return "3:4"
	case "1280x896", "2560x1792":
		return "4:3"
	case "768x1408", "1536x2816", "1024x1792":
		return "9:16"
	case "1408x768", "2816x1536", "1792x1024":
		return "16:9"
	default:
		return "1:1"
	}
}

// Init prepares the adaptor for the given relay info. The Gemini adaptor
// keeps no per-request state, so there is nothing to initialize here.
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {

}
Expand Down Expand Up @@ -246,8 +269,12 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
return channel.DoApiRequest(a, c, info, requestBody)
}

// DoResponse converts Gemini upstream responses into relay usage data.
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
if info.RelayMode == constant.RelayModeGemini {
if strings.HasPrefix(info.UpstreamModelName, "imagen") && strings.Contains(info.RequestURLPath, ":predict") {
return GeminiNativeImagePredictHandler(c, info, resp)
}
if strings.Contains(info.RequestURLPath, ":embedContent") ||
strings.Contains(info.RequestURLPath, ":batchEmbedContents") {
return NativeGeminiEmbeddingHandler(c, resp, info)
Expand Down
32 changes: 32 additions & 0 deletions relay/channel/gemini/relay-gemini-native.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,35 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, info *relaycommon.RelayIn
return true
})
}

// GeminiNativeImagePredictHandler relays a native Gemini Imagen ":predict"
// response back to the client and derives usage from the generated images.
func GeminiNativeImagePredictHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
	defer service.CloseResponseBodyGracefully(resp)

	body, readErr := io.ReadAll(resp.Body)
	if readErr != nil {
		return nil, types.NewOpenAIError(readErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
	}
	if common.DebugEnabled {
		println(string(body))
	}

	var predictResponse dto.GeminiImageResponse
	if unmarshalErr := common.Unmarshal(body, &predictResponse); unmarshalErr != nil {
		return nil, types.NewOpenAIError(unmarshalErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
	}

	// Count only predictions that actually carry image data and were not
	// blocked by responsible-AI filtering.
	imageCount := 0
	for _, prediction := range predictResponse.Predictions {
		if prediction.RaiFilteredReason == "" && prediction.BytesBase64Encoded != "" {
			imageCount++
		}
	}

	// Forward the upstream bytes unmodified, then bill per generated image.
	service.IOCopyBytesGracefully(c, resp, body)
	return buildGeminiImageUsage(imageCount), nil
}
21 changes: 12 additions & 9 deletions relay/channel/gemini/relay-gemini.go
Original file line number Diff line number Diff line change
Expand Up @@ -1538,6 +1538,7 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *h
return usage, nil
}

// GeminiImageHandler converts Gemini Imagen responses to OpenAI image responses.
func GeminiImageHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
responseBody, readErr := io.ReadAll(resp.Body)
if readErr != nil {
Expand Down Expand Up @@ -1578,18 +1579,20 @@ func GeminiImageHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.
c.Writer.WriteHeader(resp.StatusCode)
_, _ = c.Writer.Write(jsonResponse)

return buildGeminiImageUsage(len(openAIResponse.Data)), nil
}

// buildGeminiImageUsage returns usage for generated Gemini images.
func buildGeminiImageUsage(generatedImages int) *dto.Usage {
// https://github.com/google-gemini/cookbook/blob/719a27d752aac33f39de18a8d3cb42a70874917e/quickstarts/Counting_Tokens.ipynb
// each image has fixed 258 tokens
// each image has fixed 258 tokens.
const imageTokens = 258
generatedImages := len(openAIResponse.Data)

usage := &dto.Usage{
PromptTokens: imageTokens * generatedImages, // each generated image has fixed 258 tokens
CompletionTokens: 0, // image generation does not calculate completion tokens
TotalTokens: imageTokens * generatedImages,
imageTokenCount := imageTokens * generatedImages
return &dto.Usage{
PromptTokens: imageTokenCount,
CompletionTokens: 0,
TotalTokens: imageTokenCount,
}

return usage, nil
}

type GeminiModelsResponse struct {
Expand Down
8 changes: 6 additions & 2 deletions relay/channel/vertex/adaptor.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package vertex

import (
"encoding/json"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -229,6 +228,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil
}

// ConvertOpenAIRequest converts OpenAI-compatible requests for Vertex upstreams.
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
Expand Down Expand Up @@ -266,7 +266,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
}
if len(request.ExtraBody) > 0 {
var extra map[string]any
if err := json.Unmarshal(request.ExtraBody, &extra); err == nil {
if err := common.Unmarshal(request.ExtraBody, &extra); err == nil {
if n, ok := extra["n"].(float64); ok && n > 0 {
imgReq.N = lo.ToPtr(uint(n))
}
Expand Down Expand Up @@ -327,6 +327,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
return channel.DoApiRequest(a, c, info, requestBody)
}

// DoResponse converts Vertex upstream responses into relay usage data.
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
claudeAdaptor := claude.Adaptor{}
if info.IsStream {
Expand All @@ -348,6 +349,9 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
return claudeAdaptor.DoResponse(c, resp, info)
case RequestModeGemini:
if info.RelayMode == constant.RelayModeGemini {
if strings.HasPrefix(info.UpstreamModelName, "imagen") && strings.Contains(info.RequestURLPath, ":predict") {
return gemini.GeminiNativeImagePredictHandler(c, info, resp)
}
return gemini.GeminiTextGenerationHandler(c, info, resp)
} else {
if strings.HasPrefix(info.UpstreamModelName, "imagen") {
Expand Down
71 changes: 71 additions & 0 deletions relay/gemini_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package relay

import (
"bytes"
"errors"
"fmt"
"io"
"net/http"
Expand Down Expand Up @@ -52,9 +53,14 @@ func trimModelThinking(modelName string) string {
return modelName
}

// GeminiHelper handles native Gemini relay requests.
func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
info.InitChannelMeta(c)

if strings.HasPrefix(info.OriginModelName, "imagen") && strings.Contains(info.RequestURLPath, ":predict") {
return geminiImagePredictHelper(c, info)
}

geminiReq, ok := info.Request.(*dto.GeminiChatRequest)
if !ok {
return types.NewErrorWithStatusCode(fmt.Errorf("invalid request type, expected *dto.GeminiChatRequest, got %T", info.Request), types.ErrorCodeInvalidRequest, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
Expand Down Expand Up @@ -198,6 +204,71 @@ func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
return nil
}

// geminiImagePredictHelper forwards native Gemini Imagen ":predict" requests
// to the upstream channel, relays the response, and settles quota from the
// returned usage.
func geminiImagePredictHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
	imageReq, ok := info.Request.(*dto.GeminiImageRequest)
	if !ok {
		return types.NewErrorWithStatusCode(
			fmt.Errorf("invalid request type for imagen predict, got %T", info.Request),
			types.ErrorCodeInvalidRequest,
			http.StatusBadRequest,
			types.ErrOptionWithSkipRetry(),
		)
	}

	if err := helper.ModelMappedHelper(c, info, imageReq); err != nil {
		return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
	}

	adaptor := GetAdaptor(info.ApiType)
	if adaptor == nil {
		return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
	}
	adaptor.Init(info)

	// Replay the buffered request body, applying channel-level parameter
	// overrides when configured.
	storage, err := common.GetBodyStorage(c)
	if err != nil {
		return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
	}
	body, err := storage.Bytes()
	if err != nil {
		return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
	}
	if len(info.ParamOverride) > 0 {
		body, err = relaycommon.ApplyParamOverride(body, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
		if err != nil {
			return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
		}
	}
	logger.LogDebug(c, "Gemini imagen predict request body: "+string(body))

	resp, err := adaptor.DoRequest(c, info, bytes.NewReader(body))
	if err != nil {
		logger.LogError(c, "Do gemini imagen predict request failed: "+err.Error())
		return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
	}

	statusCodeMappingStr := c.GetString("status_code_mapping")
	if resp == nil {
		return types.NewErrorWithStatusCode(errors.New("empty upstream response"), types.ErrorCodeBadResponse, http.StatusInternalServerError)
	}
	// DoRequest is typed as any; guard the assertion so a misbehaving
	// adaptor yields an error instead of a panic.
	httpResp, ok := resp.(*http.Response)
	if !ok {
		return types.NewErrorWithStatusCode(fmt.Errorf("unexpected upstream response type %T", resp), types.ErrorCodeBadResponse, http.StatusInternalServerError)
	}
	if httpResp.StatusCode != http.StatusOK {
		newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false)
		service.ResetStatusCode(newAPIError, statusCodeMappingStr)
		return newAPIError
	}

	usage, openaiErr := adaptor.DoResponse(c, httpResp, info)
	if openaiErr != nil {
		service.ResetStatusCode(openaiErr, statusCodeMappingStr)
		return openaiErr
	}

	// Guard the usage assertion for the same reason as the response above.
	relayUsage, ok := usage.(*dto.Usage)
	if !ok {
		return types.NewErrorWithStatusCode(fmt.Errorf("unexpected usage type %T from DoResponse", usage), types.ErrorCodeBadResponse, http.StatusInternalServerError)
	}
	service.PostTextConsumeQuota(c, info, relayUsage, nil)
	return nil
}

func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
info.InitChannelMeta(c)

Expand Down
42 changes: 42 additions & 0 deletions relay/helper/valid_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"github.com/gin-gonic/gin"
)

// GetAndValidateRequest parses and validates requests for the relay format.
func GetAndValidateRequest(c *gin.Context, format types.RelayFormat) (request dto.Request, err error) {
relayMode := relayconstant.Path2RelayMode(c.Request.URL.Path)

Expand All @@ -27,6 +28,8 @@ func GetAndValidateRequest(c *gin.Context, format types.RelayFormat) (request dt
request, err = GetAndValidateGeminiEmbeddingRequest(c)
} else if strings.Contains(c.Request.URL.Path, ":batchEmbedContents") {
request, err = GetAndValidateGeminiBatchEmbeddingRequest(c)
} else if isGeminiImagePredictPath(c.Request.URL.Path) {
request, err = GetAndValidateGeminiImageRequest(c)
} else {
request, err = GetAndValidateGeminiRequest(c)
}
Expand Down Expand Up @@ -246,6 +249,33 @@ func GetAndValidateClaudeRequest(c *gin.Context) (textRequest *dto.ClaudeRequest
return textRequest, nil
}

// isGeminiImagePredictPath reports whether a Gemini path targets Imagen predict:
// the path must contain ":predict" and its model segment must start with "imagen".
func isGeminiImagePredictPath(path string) bool {
	if !strings.Contains(path, ":predict") {
		return false
	}
	return strings.HasPrefix(extractGeminiModelNameFromPath(path), "imagen")
}

// extractGeminiModelNameFromPath extracts the model segment from a Gemini URL
// path, i.e. the text between "/models/" and the next ":" (or the end of the
// path). It returns "" when no "/models/" segment is present.
func extractGeminiModelNameFromPath(path string) string {
	const modelsPrefix = "/models/"
	_, tail, found := strings.Cut(path, modelsPrefix)
	if !found || tail == "" {
		return ""
	}
	if model, _, hasColon := strings.Cut(tail, ":"); hasColon {
		return model
	}
	return tail
}

func GetAndValidateTextRequest(c *gin.Context, relayMode int) (*dto.GeneralOpenAIRequest, error) {
textRequest := &dto.GeneralOpenAIRequest{}
err := common.UnmarshalBodyReusable(c, textRequest)
Expand Down Expand Up @@ -321,6 +351,18 @@ func GetAndValidateGeminiRequest(c *gin.Context) (*dto.GeminiChatRequest, error)
return request, nil
}

// GetAndValidateGeminiImageRequest parses and validates a native Gemini image
// request, requiring at least one prompt instance in the body.
func GetAndValidateGeminiImageRequest(c *gin.Context) (*dto.GeminiImageRequest, error) {
	var request dto.GeminiImageRequest
	if err := common.UnmarshalBodyReusable(c, &request); err != nil {
		return nil, err
	}
	if len(request.Instances) == 0 {
		return nil, errors.New("instances is required")
	}
	return &request, nil
}

func GetAndValidateGeminiEmbeddingRequest(c *gin.Context) (*dto.GeminiEmbeddingRequest, error) {
request := &dto.GeminiEmbeddingRequest{}
err := common.UnmarshalBodyReusable(c, request)
Expand Down
1 change: 1 addition & 0 deletions setting/ratio_setting/model_ratio.go
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,7 @@ var defaultModelPrice = map[string]float64{
"suno_lyrics": 0.01,
"dall-e-3": 0.04,
"imagen-3.0-generate-002": 0.03,
"imagen-4.0-generate-001": 0.04,
"black-forest-labs/flux-1.1-pro": 0.04,
"gpt-4-gizmo-*": 0.1,
"mj_video": 0.8,
Expand Down