diff --git a/relay/channel/vertex/adaptor.go b/relay/channel/vertex/adaptor.go
index 0d91032d0f3..4278ab6bd27 100644
--- a/relay/channel/vertex/adaptor.go
+++ b/relay/channel/vertex/adaptor.go
@@ -1,6 +1,7 @@
 package vertex
 
 import (
+	"bytes"
 	"encoding/json"
 	"errors"
 	"fmt"
@@ -10,6 +11,7 @@ import (
 
 	"github.com/QuantumNous/new-api/common"
 	"github.com/QuantumNous/new-api/dto"
+	"github.com/QuantumNous/new-api/logger"
 	"github.com/QuantumNous/new-api/relay/channel"
 	"github.com/QuantumNous/new-api/relay/channel/claude"
 	"github.com/QuantumNous/new-api/relay/channel/gemini"
@@ -190,7 +192,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 		suffix = "generateContent"
 	}
 
-	if strings.HasPrefix(info.UpstreamModelName, "imagen") {
+	if strings.HasPrefix(info.UpstreamModelName, "imagen") || strings.Contains(info.UpstreamModelName, "embedding") {
 		suffix = "predict"
 	}
 	return a.getRequestUrl(info, info.UpstreamModelName, suffix)
@@ -314,8 +316,8 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
 }
 
 func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
-	//TODO implement me
-	return nil, errors.New("not implemented")
+	geminiAdaptor := gemini.Adaptor{}
+	return geminiAdaptor.ConvertEmbeddingRequest(c, info, request)
 }
 
 func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
@@ -324,6 +326,67 @@ func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommo
 }
 
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
+	if a.RequestMode == RequestModeGemini && strings.Contains(c.Request.URL.Path, "embed") {
+		bodyBytes, err := io.ReadAll(requestBody)
+		if err != nil {
+			return nil, err
+		}
+
+		vertexReq := make(map[string]interface{})
+		instances := make([]interface{}, 0)
+
+		if info.IsGeminiBatchEmbedding {
+			var req dto.GeminiBatchEmbeddingRequest
+			if err := common.Unmarshal(bodyBytes, &req); err == nil {
+				for _, r := range req.Requests {
+					instance := make(map[string]interface{})
+					content := ""
+					for _, part := range r.Content.Parts {
+						if part.Text != "" {
+							content += part.Text
+						}
+					}
+					instance["content"] = content
+					if r.TaskType != "" {
+						instance["task_type"] = r.TaskType
+					}
+					if r.Title != "" {
+						instance["title"] = r.Title
+					}
+					instances = append(instances, instance)
+				}
+			}
+		} else {
+			var req dto.GeminiEmbeddingRequest
+			if err := common.Unmarshal(bodyBytes, &req); err == nil {
+				instance := make(map[string]interface{})
+				content := ""
+				for _, part := range req.Content.Parts {
+					if part.Text != "" {
+						content += part.Text
+					}
+				}
+				instance["content"] = content
+				if req.TaskType != "" {
+					instance["task_type"] = req.TaskType
+				}
+				if req.Title != "" {
+					instance["title"] = req.Title
+				}
+				instances = append(instances, instance)
+
+				if req.OutputDimensionality > 0 {
+					vertexReq["parameters"] = map[string]interface{}{
+						"outputDimensionality": req.OutputDimensionality,
+					}
+				}
+			}
+		}
+		vertexReq["instances"] = instances
+		newBodyBytes, _ := common.Marshal(vertexReq)
+		requestBody = bytes.NewReader(newBodyBytes)
+		logger.LogDebug(c, "Vertex Embedding request body: "+string(newBodyBytes))
+	}
 	return channel.DoApiRequest(a, c, info, requestBody)
 }
 
@@ -347,6 +410,9 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 	case RequestModeClaude:
 		return claudeAdaptor.DoResponse(c, resp, info)
 	case RequestModeGemini:
+		if isVertexEmbedding(info) {
+			return vertexEmbeddingHandler(c, resp, info)
+		}
 		if info.RelayMode == constant.RelayModeGemini {
 			return gemini.GeminiTextGenerationHandler(c, info, resp)
 		} else {
diff --git a/relay/channel/vertex/relay-vertex.go b/relay/channel/vertex/relay-vertex.go
index c5103a977ec..9bde9691f08 100644
--- a/relay/channel/vertex/relay-vertex.go
+++ b/relay/channel/vertex/relay-vertex.go
@@ -1,6 +1,33 @@
 package vertex
 
-import "github.com/QuantumNous/new-api/common"
+import (
+	"io"
+	"net/http"
+	"strings"
+
+	"github.com/QuantumNous/new-api/common"
+	"github.com/QuantumNous/new-api/dto"
+	"github.com/QuantumNous/new-api/logger"
+	relaycommon "github.com/QuantumNous/new-api/relay/common"
+	"github.com/QuantumNous/new-api/service"
+	"github.com/QuantumNous/new-api/types"
+
+	"github.com/gin-gonic/gin"
+)
+
+// isVertexEmbedding decides whether to route the response through the Vertex
+// embedding handler. Both the OpenAI-compatible /v1/embeddings path and the
+// Gemini-native :embedContent / :batchEmbedContents paths land here, plus
+// embedding model names regardless of relay mode.
+func isVertexEmbedding(info *relaycommon.RelayInfo) bool {
+	if strings.Contains(info.RequestURLPath, "embed") {
+		return true
+	}
+	m := info.UpstreamModelName
+	return strings.HasPrefix(m, "gemini-embedding") ||
+		strings.HasPrefix(m, "text-embedding") ||
+		strings.HasPrefix(m, "text-multilingual-embedding")
+}
 
 func GetModelRegion(other string, localModelName string) string {
 	// if other is json string
@@ -20,3 +47,68 @@ func GetModelRegion(other string, localModelName string) string {
 	}
 	return other
 }
+
+type VertexEmbeddingResponse struct {
+	Predictions []struct {
+		Embeddings struct {
+			Statistics struct {
+				TokenCount int `json:"token_count"`
+			} `json:"statistics"`
+			Values []float64 `json:"values"`
+		} `json:"embeddings"`
+	} `json:"predictions"`
+	Metadata struct {
+		BillableCharacterCount int `json:"billableCharacterCount"`
+	} `json:"metadata"`
+}
+
+func vertexEmbeddingHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *types.NewAPIError) {
+	defer service.CloseResponseBodyGracefully(resp)
+
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
+	}
+
+	if common.DebugEnabled {
+		logger.LogDebug(c, "Vertex Embedding response body: "+string(responseBody))
+	}
+
+	var vertexResponse VertexEmbeddingResponse
+	if err := common.Unmarshal(responseBody, &vertexResponse); err != nil {
+		return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
+	}
+
+	openAIResponse := dto.OpenAIEmbeddingResponse{
+		Object: "list",
+		Data:   make([]dto.OpenAIEmbeddingResponseItem, 0, len(vertexResponse.Predictions)),
+		Model:  info.UpstreamModelName,
+	}
+
+	tokenCount := 0
+	for i, prediction := range vertexResponse.Predictions {
+		openAIResponse.Data = append(openAIResponse.Data, dto.OpenAIEmbeddingResponseItem{
+			Object:    "embedding",
+			Embedding: prediction.Embeddings.Values,
+			Index:     i,
+		})
+		tokenCount += prediction.Embeddings.Statistics.TokenCount
+	}
+
+	usage := &dto.Usage{
+		PromptTokens: tokenCount,
+		TotalTokens:  tokenCount,
+	}
+	openAIResponse.Usage = *usage
+
+	jsonResponse, err := common.Marshal(openAIResponse)
+	if err != nil {
+		return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
+	}
+
+	c.Writer.Header().Set("Content-Type", "application/json")
+	c.Writer.WriteHeader(http.StatusOK)
+	_, _ = c.Writer.Write(jsonResponse)
+
+	return usage, nil
+}
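
On the request path, GetRequestURL now routes embedding models to the :predict suffix, and DoRequest rewrites the Gemini-native body into Vertex's instances/parameters envelope. The sketch below is not part of the patch: it replays that conversion on a hypothetical single-input :embedContent body, using local stand-ins for the dto types. Field names mirror the code above; the sample values are invented.

package main

import (
	"encoding/json"
	"fmt"
)

// Local stand-in for dto.GeminiEmbeddingRequest, kept minimal for this sketch.
type part struct {
	Text string `json:"text"`
}
type geminiEmbeddingRequest struct {
	Content struct {
		Parts []part `json:"parts"`
	} `json:"content"`
	TaskType             string `json:"taskType"`
	Title                string `json:"title"`
	OutputDimensionality int    `json:"outputDimensionality"`
}

func main() {
	// Hypothetical Gemini-native :embedContent body arriving at DoRequest.
	in := []byte(`{"content":{"parts":[{"text":"hello "},{"text":"world"}]},"taskType":"RETRIEVAL_QUERY","outputDimensionality":256}`)

	var req geminiEmbeddingRequest
	if err := json.Unmarshal(in, &req); err != nil {
		panic(err)
	}

	// Same shape DoRequest builds: parts concatenated into one "content" string,
	// optional task_type/title per instance, outputDimensionality under "parameters".
	content := ""
	for _, p := range req.Content.Parts {
		content += p.Text
	}
	instance := map[string]interface{}{"content": content}
	if req.TaskType != "" {
		instance["task_type"] = req.TaskType
	}
	if req.Title != "" {
		instance["title"] = req.Title
	}
	vertexReq := map[string]interface{}{"instances": []interface{}{instance}}
	if req.OutputDimensionality > 0 {
		vertexReq["parameters"] = map[string]interface{}{"outputDimensionality": req.OutputDimensionality}
	}

	out, _ := json.Marshal(vertexReq)
	fmt.Println(string(out))
	// {"instances":[{"content":"hello world","task_type":"RETRIEVAL_QUERY"}],"parameters":{"outputDimensionality":256}}
}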
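
On the response path, vertexEmbeddingHandler folds Vertex's predictions[].embeddings into an OpenAI-style embedding list and sums each prediction's token_count into prompt_tokens. A self-contained sketch of that mapping follows, again with invented values and local stand-ins for the dto types rather than the real ones.

package main

import (
	"encoding/json"
	"fmt"
)

// Local stand-in for the VertexEmbeddingResponse struct in the patch.
type vertexEmbeddingResponse struct {
	Predictions []struct {
		Embeddings struct {
			Statistics struct {
				TokenCount int `json:"token_count"`
			} `json:"statistics"`
			Values []float64 `json:"values"`
		} `json:"embeddings"`
	} `json:"predictions"`
}

// Local stand-in for dto.OpenAIEmbeddingResponseItem.
type openAIEmbeddingItem struct {
	Object    string    `json:"object"`
	Embedding []float64 `json:"embedding"`
	Index     int       `json:"index"`
}

func main() {
	// Hypothetical Vertex :predict response for a two-input batch.
	raw := []byte(`{"predictions":[
		{"embeddings":{"statistics":{"token_count":3},"values":[0.1,0.2]}},
		{"embeddings":{"statistics":{"token_count":5},"values":[0.3,0.4]}}]}`)

	var v vertexEmbeddingResponse
	if err := json.Unmarshal(raw, &v); err != nil {
		panic(err)
	}

	// Same loop as the handler: one OpenAI item per prediction, tokens summed.
	data := make([]openAIEmbeddingItem, 0, len(v.Predictions))
	tokens := 0
	for i, p := range v.Predictions {
		data = append(data, openAIEmbeddingItem{Object: "embedding", Embedding: p.Embeddings.Values, Index: i})
		tokens += p.Embeddings.Statistics.TokenCount
	}

	out, _ := json.Marshal(map[string]interface{}{
		"object": "list",
		"data":   data,
		"usage":  map[string]int{"prompt_tokens": tokens, "total_tokens": tokens},
	})
	fmt.Println(string(out)) // usage.prompt_tokens == usage.total_tokens == 8
}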