-
-
Notifications
You must be signed in to change notification settings - Fork 6.9k
feat(vertex): support embedding via :predict endpoint #4640
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,33 @@ | ||
| package vertex | ||
|
|
||
| import "github.com/QuantumNous/new-api/common" | ||
| import ( | ||
| "io" | ||
| "net/http" | ||
| "strings" | ||
|
|
||
| "github.com/QuantumNous/new-api/common" | ||
| "github.com/QuantumNous/new-api/dto" | ||
| "github.com/QuantumNous/new-api/logger" | ||
| relaycommon "github.com/QuantumNous/new-api/relay/common" | ||
| "github.com/QuantumNous/new-api/service" | ||
| "github.com/QuantumNous/new-api/types" | ||
|
|
||
| "github.com/gin-gonic/gin" | ||
| ) | ||
|
|
||
| // isVertexEmbedding decides whether to route the response through the Vertex | ||
| // embedding handler. Both the OpenAI-compatible /v1/embeddings path and the | ||
| // Gemini-native :embedContent / :batchEmbedContents paths land here, plus | ||
| // embedding model names regardless of relay mode. | ||
| func isVertexEmbedding(info *relaycommon.RelayInfo) bool { | ||
| if strings.Contains(info.RequestURLPath, "embed") { | ||
| return true | ||
| } | ||
| m := info.UpstreamModelName | ||
| return strings.HasPrefix(m, "gemini-embedding") || | ||
| strings.HasPrefix(m, "text-embedding") || | ||
| strings.HasPrefix(m, "text-multilingual-embedding") | ||
| } | ||
|
|
||
| func GetModelRegion(other string, localModelName string) string { | ||
| // if other is json string | ||
|
|
@@ -20,3 +47,68 @@ func GetModelRegion(other string, localModelName string) string { | |
| } | ||
| return other | ||
| } | ||
|
|
||
// VertexEmbeddingResponse mirrors the JSON body returned by the Vertex AI
// :predict endpoint for embedding models. Only the fields this package reads
// are modelled.
type VertexEmbeddingResponse struct {
	// Predictions holds one entry per input instance sent to :predict.
	Predictions []struct {
		Embeddings struct {
			// Statistics carries per-instance accounting reported by Vertex.
			Statistics struct {
				// TokenCount is the number of input tokens for this instance;
				// summed by vertexEmbeddingHandler to build the usage report.
				TokenCount int `json:"token_count"`
			} `json:"statistics"`
			// Values is the embedding vector for this instance.
			Values []float64 `json:"values"`
		} `json:"embeddings"`
	} `json:"predictions"`
	// Metadata carries request-level billing information from Vertex.
	Metadata struct {
		BillableCharacterCount int `json:"billableCharacterCount"`
	} `json:"metadata"`
}
|
|
||
| func vertexEmbeddingHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *types.NewAPIError) { | ||
| defer service.CloseResponseBodyGracefully(resp) | ||
|
|
||
| responseBody, err := io.ReadAll(resp.Body) | ||
| if err != nil { | ||
| return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) | ||
| } | ||
|
|
||
| if common.DebugEnabled { | ||
| logger.LogDebug(c, "Vertex Embedding response body: "+string(responseBody)) | ||
| } | ||
|
|
||
| var vertexResponse VertexEmbeddingResponse | ||
| if err := common.Unmarshal(responseBody, &vertexResponse); err != nil { | ||
| return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) | ||
| } | ||
|
|
||
| openAIResponse := dto.OpenAIEmbeddingResponse{ | ||
| Object: "list", | ||
| Data: make([]dto.OpenAIEmbeddingResponseItem, 0, len(vertexResponse.Predictions)), | ||
| Model: info.UpstreamModelName, | ||
| } | ||
|
|
||
| tokenCount := 0 | ||
| for i, prediction := range vertexResponse.Predictions { | ||
| openAIResponse.Data = append(openAIResponse.Data, dto.OpenAIEmbeddingResponseItem{ | ||
| Object: "embedding", | ||
| Embedding: prediction.Embeddings.Values, | ||
| Index: i, | ||
| }) | ||
| tokenCount += prediction.Embeddings.Statistics.TokenCount | ||
| } | ||
|
|
||
| usage := &dto.Usage{ | ||
| PromptTokens: tokenCount, | ||
| TotalTokens: tokenCount, | ||
| } | ||
| openAIResponse.Usage = *usage | ||
|
|
||
| jsonResponse, err := common.Marshal(openAIResponse) | ||
| if err != nil { | ||
| return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) | ||
| } | ||
|
|
||
| c.Writer.Header().Set("Content-Type", "application/json") | ||
| c.Writer.WriteHeader(http.StatusOK) | ||
| _, _ = c.Writer.Write(jsonResponse) | ||
|
|
||
| return usage, nil | ||
| } | ||
|
Comment on lines
+65
to
+114
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 🧩 Analysis chain 🏁 Script executed: #!/bin/bash
# Check whether other handlers in this repo guard on resp.StatusCode and how relay dispatches DoResponse.
rg -nP -C3 '\bresp\.StatusCode\b' relay/channel/vertex relay/channel/gemini
rg -nP -C5 'DoResponse\(' relay/relay_adaptor.go relay/relay_text.go relay/relay_embedding.go 2>/dev/null
fd -t f -e go . relay | xargs rg -nP -C2 'StatusCode\s*!=\s*http\.StatusOK' | head -n 80

Repository: QuantumNous/new-api — Length of output: 6210
Add an early return on non-2xx status, mirroring the pattern used throughout the relay framework (e.g., 🤖 Prompt for AI Agents |
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Surface body-conversion errors instead of swallowing them.
The new embedding rewrite path silently ignores both unmarshal and marshal failures:
if err := common.Unmarshal(bodyBytes, &req); err == nil { ... }— when parsing fails, control just falls through and the code sends{"instances":[]}to Vertex. The client then sees a confusing upstream400 INVALID_ARGUMENT: Should provide instances for text model predictionwhile the real parse error is gone.newBodyBytes, _ := common.Marshal(vertexReq)— if marshalling fails, an empty/nilbody is forwarded silently.Please return the error in both cases so misuse / schema drift is observable.
🛠️ Suggested fix
if info.IsGeminiBatchEmbedding { var req dto.GeminiBatchEmbeddingRequest - if err := common.Unmarshal(bodyBytes, &req); err == nil { - for _, r := range req.Requests { - instance := make(map[string]interface{}) - content := "" - for _, part := range r.Content.Parts { - if part.Text != "" { - content += part.Text - } - } - instance["content"] = content - if r.TaskType != "" { - instance["task_type"] = r.TaskType - } - if r.Title != "" { - instance["title"] = r.Title - } - instances = append(instances, instance) - } - } + if err := common.Unmarshal(bodyBytes, &req); err != nil { + return nil, fmt.Errorf("failed to parse gemini batch embedding request: %w", err) + } + for _, r := range req.Requests { + instance := make(map[string]interface{}) + content := "" + for _, part := range r.Content.Parts { + if part.Text != "" { + content += part.Text + } + } + instance["content"] = content + if r.TaskType != "" { + instance["task_type"] = r.TaskType + } + if r.Title != "" { + instance["title"] = r.Title + } + instances = append(instances, instance) + } } else { var req dto.GeminiEmbeddingRequest - if err := common.Unmarshal(bodyBytes, &req); err == nil { - instance := make(map[string]interface{}) - content := "" - for _, part := range req.Content.Parts { - if part.Text != "" { - content += part.Text - } - } - instance["content"] = content - if req.TaskType != "" { - instance["task_type"] = req.TaskType - } - if req.Title != "" { - instance["title"] = req.Title - } - instances = append(instances, instance) - - if req.OutputDimensionality > 0 { - vertexReq["parameters"] = map[string]interface{}{ - "outputDimensionality": req.OutputDimensionality, - } - } - } + if err := common.Unmarshal(bodyBytes, &req); err != nil { + return nil, fmt.Errorf("failed to parse gemini embedding request: %w", err) + } + instance := make(map[string]interface{}) + content := "" + for _, part := range req.Content.Parts { + if part.Text != "" { + content += part.Text + } + } + instance["content"] = 
content + if req.TaskType != "" { + instance["task_type"] = req.TaskType + } + if req.Title != "" { + instance["title"] = req.Title + } + instances = append(instances, instance) + + if req.OutputDimensionality > 0 { + vertexReq["parameters"] = map[string]interface{}{ + "outputDimensionality": req.OutputDimensionality, + } + } } vertexReq["instances"] = instances - newBodyBytes, _ := common.Marshal(vertexReq) + newBodyBytes, err := common.Marshal(vertexReq) + if err != nil { + return nil, fmt.Errorf("failed to marshal vertex embedding request: %w", err) + } requestBody = bytes.NewReader(newBodyBytes)🤖 Prompt for AI Agents