Skip to content

Commit e4a6b8c

Browse files
Merge upstream/main (auto-sync feat/copilot)
- 776a9c0 fix(translator/gemini): support developer role in OpenAI Responses requests - 430e679 fix(auth): strip "generate" from payload during WebSocket HTTP fallback - 3a54fb7 Merge branch 'dev', commit 'refs/pull/3621/head' of github.com:router-for-me/CLIProxyAPI into dev
2 parents 58189cf + 3a54fb7 commit e4a6b8c

4 files changed

Lines changed: 241 additions & 2 deletions

File tree

internal/translator/gemini/openai/responses/gemini_openai-responses_request.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte
119119

120120
switch itemType {
121121
case "message":
122-
if strings.EqualFold(itemRole, "system") {
122+
if strings.EqualFold(itemRole, "system") || strings.EqualFold(itemRole, "developer") {
123123
if contentArray := item.Get("content"); contentArray.Exists() {
124124
systemInstr := []byte(`{"parts":[]}`)
125125
if systemInstructionResult := gjson.GetBytes(out, "systemInstruction"); systemInstructionResult.Exists() {

internal/translator/gemini/openai/responses/gemini_openai-responses_request_test.go

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,80 @@ func TestConvertOpenAIResponsesRequestToGemini_ReasoningSignatureCompatibility(t
5555
}
5656
}
5757

58+
func TestConvertOpenAIResponsesRequestToGemini_SystemAndDeveloperRoles(t *testing.T) {
59+
tests := []struct {
60+
name string
61+
role string
62+
wantText string
63+
}{
64+
{
65+
name: "system role",
66+
role: "system",
67+
wantText: "System message text",
68+
},
69+
{
70+
name: "developer role",
71+
role: "developer",
72+
wantText: "Developer message text",
73+
},
74+
}
75+
76+
for _, tt := range tests {
77+
t.Run(tt.name, func(t *testing.T) {
78+
input := []byte(`{
79+
"instructions": "Be a helpful assistant",
80+
"input": [
81+
{
82+
"type": "message",
83+
"role": "` + tt.role + `",
84+
"content": [
85+
{
86+
"type": "input_text",
87+
"text": "` + tt.wantText + `"
88+
}
89+
]
90+
},
91+
{
92+
"type": "message",
93+
"role": "user",
94+
"content": [
95+
{
96+
"type": "input_text",
97+
"text": "Hello"
98+
}
99+
]
100+
}
101+
]
102+
}`)
103+
104+
output := ConvertOpenAIResponsesRequestToGemini("gemini-3.5-flash", input, false)
105+
result := gjson.ParseBytes(output)
106+
107+
systemInstruction := result.Get("systemInstruction")
108+
if !systemInstruction.Exists() {
109+
t.Fatalf("systemInstruction missing. Output: %s", output)
110+
}
111+
parts := systemInstruction.Get("parts")
112+
if got := parts.Get("#").Int(); got != 2 {
113+
t.Fatalf("systemInstruction parts = %d, want 2. Output: %s", got, output)
114+
}
115+
if got := parts.Get("0.text").String(); got != "Be a helpful assistant" {
116+
t.Fatalf("first systemInstruction part = %q, want %q. Output: %s", got, "Be a helpful assistant", output)
117+
}
118+
if got := parts.Get("1.text").String(); got != tt.wantText {
119+
t.Fatalf("second systemInstruction part = %q, want %q. Output: %s", got, tt.wantText, output)
120+
}
121+
122+
result.Get("contents").ForEach(func(_, value gjson.Result) bool {
123+
if role := value.Get("role").String(); role == tt.role {
124+
t.Fatalf("role %q leaked into contents array. Output: %s", tt.role, output)
125+
}
126+
return true
127+
})
128+
})
129+
}
130+
}
131+
58132
func validResponsesGPTReasoningSignature() string {
59133
raw := make([]byte, 1+8+16+16+32)
60134
raw[0] = 0x80

sdk/api/handlers/openai/openai_responses_websocket_test.go

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,12 @@ type websocketPinnedFailoverExecutor struct {
7777
payloads map[string][][]byte
7878
}
7979

80+
type websocketBootstrapFallbackExecutor struct {
81+
mu sync.Mutex
82+
authIDs []string
83+
payloads map[string][][]byte
84+
}
85+
8086
type websocketPinnedFailoverStatusError struct {
8187
status int
8288
msg string
@@ -86,6 +92,70 @@ func (e websocketPinnedFailoverStatusError) Error() string { return e.msg }
8692

8793
func (e websocketPinnedFailoverStatusError) StatusCode() int { return e.status }
8894

95+
func (e *websocketBootstrapFallbackExecutor) Identifier() string { return "test-provider" }
96+
97+
func (e *websocketBootstrapFallbackExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
98+
return coreexecutor.Response{}, errors.New("not implemented")
99+
}
100+
101+
func (e *websocketBootstrapFallbackExecutor) ExecuteStream(_ context.Context, auth *coreauth.Auth, req coreexecutor.Request, _ coreexecutor.Options) (*coreexecutor.StreamResult, error) {
102+
authID := ""
103+
if auth != nil {
104+
authID = auth.ID
105+
}
106+
107+
e.mu.Lock()
108+
if e.payloads == nil {
109+
e.payloads = make(map[string][][]byte)
110+
}
111+
e.authIDs = append(e.authIDs, authID)
112+
e.payloads[authID] = append(e.payloads[authID], bytes.Clone(req.Payload))
113+
e.mu.Unlock()
114+
115+
chunks := make(chan coreexecutor.StreamChunk, 1)
116+
if authID == "auth-ws" {
117+
chunks <- coreexecutor.StreamChunk{Err: websocketPinnedFailoverStatusError{
118+
status: http.StatusServiceUnavailable,
119+
msg: `{"error":{"message":"websocket bootstrap failed","type":"server_error","code":"ws_failed"}}`,
120+
}}
121+
close(chunks)
122+
return &coreexecutor.StreamResult{Chunks: chunks}, nil
123+
}
124+
125+
chunks <- coreexecutor.StreamChunk{Payload: []byte(`{"type":"response.completed","response":{"id":"resp-http","output":[{"type":"message","id":"out-http"}]}}`)}
126+
close(chunks)
127+
return &coreexecutor.StreamResult{Chunks: chunks}, nil
128+
}
129+
130+
func (e *websocketBootstrapFallbackExecutor) Refresh(_ context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) {
131+
return auth, nil
132+
}
133+
134+
func (e *websocketBootstrapFallbackExecutor) CountTokens(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
135+
return coreexecutor.Response{}, errors.New("not implemented")
136+
}
137+
138+
func (e *websocketBootstrapFallbackExecutor) HttpRequest(context.Context, *coreauth.Auth, *http.Request) (*http.Response, error) {
139+
return nil, errors.New("not implemented")
140+
}
141+
142+
func (e *websocketBootstrapFallbackExecutor) AuthIDs() []string {
143+
e.mu.Lock()
144+
defer e.mu.Unlock()
145+
return append([]string(nil), e.authIDs...)
146+
}
147+
148+
func (e *websocketBootstrapFallbackExecutor) Payloads(authID string) [][]byte {
149+
e.mu.Lock()
150+
defer e.mu.Unlock()
151+
src := e.payloads[authID]
152+
out := make([][]byte, len(src))
153+
for i := range src {
154+
out[i] = bytes.Clone(src[i])
155+
}
156+
return out
157+
}
158+
89159
type websocketUpstreamDisconnectExecutor struct {
90160
mu sync.Mutex
91161
subscribed chan string
@@ -1340,6 +1410,87 @@ func TestResponsesWebsocketPrewarmHandledLocallyForSSEUpstream(t *testing.T) {
13401410
}
13411411
}
13421412

1413+
func TestResponsesWebsocketStripsGenerateWhenWebsocketAttemptFallsBackToHTTP(t *testing.T) {
1414+
gin.SetMode(gin.TestMode)
1415+
1416+
selector := &orderedWebsocketSelector{order: []string{"auth-ws", "auth-http"}}
1417+
executor := &websocketBootstrapFallbackExecutor{}
1418+
manager := coreauth.NewManager(nil, selector, nil)
1419+
manager.RegisterExecutor(executor)
1420+
1421+
authWS := &coreauth.Auth{
1422+
ID: "auth-ws",
1423+
Provider: executor.Identifier(),
1424+
Status: coreauth.StatusActive,
1425+
Attributes: map[string]string{"websockets": "true"},
1426+
}
1427+
if _, err := manager.Register(context.Background(), authWS); err != nil {
1428+
t.Fatalf("Register websocket auth: %v", err)
1429+
}
1430+
authHTTP := &coreauth.Auth{ID: "auth-http", Provider: executor.Identifier(), Status: coreauth.StatusActive}
1431+
if _, err := manager.Register(context.Background(), authHTTP); err != nil {
1432+
t.Fatalf("Register HTTP auth: %v", err)
1433+
}
1434+
1435+
registry.GetGlobalRegistry().RegisterClient(authWS.ID, authWS.Provider, []*registry.ModelInfo{{ID: "test-model"}})
1436+
registry.GetGlobalRegistry().RegisterClient(authHTTP.ID, authHTTP.Provider, []*registry.ModelInfo{{ID: "test-model"}})
1437+
t.Cleanup(func() {
1438+
registry.GetGlobalRegistry().UnregisterClient(authWS.ID)
1439+
registry.GetGlobalRegistry().UnregisterClient(authHTTP.ID)
1440+
})
1441+
1442+
base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager)
1443+
h := NewOpenAIResponsesAPIHandler(base)
1444+
router := gin.New()
1445+
router.GET("/v1/responses/ws", h.ResponsesWebsocket)
1446+
1447+
server := httptest.NewServer(router)
1448+
defer server.Close()
1449+
1450+
wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/v1/responses/ws"
1451+
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
1452+
if err != nil {
1453+
t.Fatalf("dial websocket: %v", err)
1454+
}
1455+
defer func() {
1456+
if errClose := conn.Close(); errClose != nil {
1457+
t.Fatalf("close websocket: %v", errClose)
1458+
}
1459+
}()
1460+
1461+
request := `{"type":"response.create","model":"test-model","generate":false,"input":[{"type":"message","id":"msg-1"}]}`
1462+
if errWrite := conn.WriteMessage(websocket.TextMessage, []byte(request)); errWrite != nil {
1463+
t.Fatalf("write websocket message: %v", errWrite)
1464+
}
1465+
_, payload, errReadMessage := conn.ReadMessage()
1466+
if errReadMessage != nil {
1467+
t.Fatalf("read websocket message: %v", errReadMessage)
1468+
}
1469+
if got := gjson.GetBytes(payload, "type").String(); got != wsEventTypeCompleted {
1470+
t.Fatalf("payload type = %s, want %s: %s", got, wsEventTypeCompleted, payload)
1471+
}
1472+
1473+
if got := executor.AuthIDs(); len(got) != 2 || got[0] != "auth-ws" || got[1] != "auth-http" {
1474+
t.Fatalf("selected auth IDs = %v, want [auth-ws auth-http]", got)
1475+
}
1476+
1477+
wsPayloads := executor.Payloads("auth-ws")
1478+
if len(wsPayloads) != 1 {
1479+
t.Fatalf("auth-ws payload count = %d, want 1", len(wsPayloads))
1480+
}
1481+
if !gjson.GetBytes(wsPayloads[0], "generate").Exists() {
1482+
t.Fatalf("websocket attempt payload unexpectedly stripped generate: %s", wsPayloads[0])
1483+
}
1484+
1485+
httpPayloads := executor.Payloads("auth-http")
1486+
if len(httpPayloads) != 1 {
1487+
t.Fatalf("auth-http payload count = %d, want 1", len(httpPayloads))
1488+
}
1489+
if gjson.GetBytes(httpPayloads[0], "generate").Exists() {
1490+
t.Fatalf("generate leaked after HTTP fallback: %s", httpPayloads[0])
1491+
}
1492+
}
1493+
13431494
func TestWebsocketClientAddressUsesGinClientIP(t *testing.T) {
13441495
gin.SetMode(gin.TestMode)
13451496

sdk/cliproxy/auth/conductor.go

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import (
2525
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/executor"
2626
coreusage "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/usage"
2727
log "github.com/sirupsen/logrus"
28+
"github.com/tidwall/sjson"
2829
)
2930

3031
// ProviderExecutor defines the contract required by Manager to execute provider calls.
@@ -1581,7 +1582,8 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string
15811582
lastErr = errPrepare
15821583
continue
15831584
}
1584-
streamResult, errStream := m.executeStreamWithModelPool(execCtx, executor, auth, provider, req, opts, routeModel, models, pooled)
1585+
execReq := sanitizeDownstreamWebsocketFallbackRequest(execCtx, auth, req)
1586+
streamResult, errStream := m.executeStreamWithModelPool(execCtx, executor, auth, provider, execReq, opts, routeModel, models, pooled)
15851587
if errStream != nil {
15861588
if errCtx := execCtx.Err(); errCtx != nil {
15871589
return nil, errCtx
@@ -1599,6 +1601,18 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string
15991601
}
16001602
}
16011603

1604+
func sanitizeDownstreamWebsocketFallbackRequest(ctx context.Context, auth *Auth, req cliproxyexecutor.Request) cliproxyexecutor.Request {
1605+
if !cliproxyexecutor.DownstreamWebsocket(ctx) || authWebsocketsEnabled(auth) || len(req.Payload) == 0 {
1606+
return req
1607+
}
1608+
updated, errDelete := sjson.DeleteBytes(req.Payload, "generate")
1609+
if errDelete != nil {
1610+
return req
1611+
}
1612+
req.Payload = updated
1613+
return req
1614+
}
1615+
16021616
func ensureRequestedModelMetadata(opts cliproxyexecutor.Options, requestedModel string) cliproxyexecutor.Options {
16031617
requestedModel = strings.TrimSpace(requestedModel)
16041618
if requestedModel == "" {

0 commit comments

Comments
 (0)