diff --git a/internal/runtime/executor/xai_executor.go b/internal/runtime/executor/xai_executor.go index aeab85d7ae..6cb79764cc 100644 --- a/internal/runtime/executor/xai_executor.go +++ b/internal/runtime/executor/xai_executor.go @@ -480,6 +480,7 @@ type xaiPreparedRequest struct { func (e *XAIExecutor) prepareResponsesRequest(ctx context.Context, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, stream bool) (*xaiPreparedRequest, error) { baseModel := thinking.ParseSuffix(req.Model).ModelName + baseModel = xaiCompactionModel(baseModel, opts) from := opts.SourceFormat to := sdktranslator.FromString("codex") originalPayloadSource := req.Payload @@ -507,6 +508,8 @@ func (e *XAIExecutor) prepareResponsesRequest(ctx context.Context, req cliproxye body, _ = sjson.DeleteBytes(body, "stream_options") body = normalizeXAITools(body) body = normalizeXAIToolChoiceForTools(body) + body = normalizeXAIInputCustomToolCalls(body) + body = dropXAIInputWebSearchCalls(body) body = normalizeXAIInputReasoningItems(body) body = normalizeCodexInstructions(body) body = sanitizeXAIResponsesBody(body, baseModel) @@ -599,6 +602,43 @@ func xaiExecutionSessionID(req cliproxyexecutor.Request, opts cliproxyexecutor.O return "" } +func xaiCompactionModel(model string, opts cliproxyexecutor.Options) string { + name := strings.ToLower(strings.TrimSpace(thinking.ParseSuffix(model).ModelName)) + if idx := strings.LastIndex(name, "/"); idx >= 0 { + name = name[idx+1:] + } + if name != "grok-build-0.1" || !xaiIsCompactionRequest(opts) { + return model + } + return "grok-4.3" +} + +func xaiIsCompactionRequest(opts cliproxyexecutor.Options) bool { + if opts.Alt == "responses/compact" { + return true + } + if opts.Headers != nil { + if xaiTurnMetadataIsCompaction(opts.Headers.Get("X-Codex-Turn-Metadata")) { + return true + } + } + if raw, ok := opts.Metadata["X-Codex-Turn-Metadata"]; ok { + return xaiTurnMetadataIsCompaction(fmt.Sprint(raw)) + } + if raw, ok := opts.Metadata["x-codex-turn-metadata"]; ok { + return xaiTurnMetadataIsCompaction(fmt.Sprint(raw)) + } + return false +} + +func xaiTurnMetadataIsCompaction(raw string) bool { + raw = strings.TrimSpace(raw) + if raw == "" || !gjson.Valid(raw) { + return false + } + return strings.TrimSpace(gjson.Get(raw, "request_kind").String()) == "compaction" +} + func xaiImageEndpointPath(opts cliproxyexecutor.Options) string { if opts.SourceFormat.String() != xaiImageHandlerType { return "" @@ -776,6 +816,105 @@ func normalizeXAITool(tool gjson.Result) ([]byte, bool, bool) { return raw, changed, true } +func normalizeXAIInputCustomToolCalls(body []byte) []byte { + input := gjson.GetBytes(body, "input") + if !input.Exists() || !input.IsArray() { + return body + } + + changed := false + items := make([]json.RawMessage, 0, len(input.Array())) + for _, item := range input.Array() { + raw := []byte(item.Raw) + switch item.Get("type").String() { + case "custom_tool_call": + updated, errSet := sjson.SetBytes(raw, "type", "function_call") + if errSet != nil { + return body + } + if inputArg := item.Get("input"); inputArg.Exists() { + updatedWithArgs, errArgs := sjson.SetBytes(updated, "arguments", xaiCustomToolCallArguments(inputArg)) + if errArgs != nil { + return body + } + updated = updatedWithArgs + updatedWithoutInput, errDel := sjson.DeleteBytes(updated, "input") + if errDel != nil { + return body + } + updated = updatedWithoutInput + } else if !item.Get("arguments").Exists() { + updatedWithArgs, errArgs := sjson.SetBytes(updated, "arguments", "{}") + if errArgs != nil { + return body + } + updated = updatedWithArgs + } + raw = updated + changed = true + case "custom_tool_call_output": + updated, errSet := sjson.SetBytes(raw, "type", "function_call_output") + if errSet != nil { + return body + } + raw = updated + changed = true + } + items = append(items, json.RawMessage(raw)) + } + if !changed { + return body + } + + rawInput, errMarshal := json.Marshal(items) + if errMarshal != nil { + return body + } + updated, errSet := sjson.SetRawBytes(body, "input", rawInput) + if errSet != nil { + return body + } + return updated +} + +func xaiCustomToolCallArguments(input gjson.Result) string { + payload, errSet := sjson.SetRawBytes([]byte(`{}`), "input", []byte(input.Raw)) + if errSet != nil { + return "{}" + } + return string(payload) +} + +func dropXAIInputWebSearchCalls(body []byte) []byte { + input := gjson.GetBytes(body, "input") + if !input.Exists() || !input.IsArray() { + return body + } + + changed := false + items := make([]json.RawMessage, 0, len(input.Array())) + for _, item := range input.Array() { + if item.Get("type").String() == "web_search_call" { + changed = true + continue + } + items = append(items, json.RawMessage(item.Raw)) + } + if !changed { + return body + } + + rawInput, errMarshal := json.Marshal(items) + if errMarshal != nil { + return body + } + updated, errSet := sjson.SetRawBytes(body, "input", rawInput) + if errSet != nil { + return body + } + return updated +} + func normalizeXAIInputReasoningItems(body []byte) []byte { input := gjson.GetBytes(body, "input") if !input.Exists() || !input.IsArray() { @@ -796,7 +935,7 @@ func normalizeXAIInputReasoningItems(body []byte) []byte { updated = updatedBody } encryptedContentPath := fmt.Sprintf("input.%d.encrypted_content", i) - if encryptedContent := gjson.GetBytes(updated, encryptedContentPath); encryptedContent.Exists() && encryptedContent.Type == gjson.Null { + if encryptedContent := gjson.GetBytes(updated, encryptedContentPath); encryptedContent.Exists() { updatedBody, errDel := sjson.DeleteBytes(updated, encryptedContentPath) if errDel != nil { return body diff --git a/internal/runtime/executor/xai_executor_test.go b/internal/runtime/executor/xai_executor_test.go index e8c11cf6ed..3bdeab810d 100644 --- a/internal/runtime/executor/xai_executor_test.go +++ b/internal/runtime/executor/xai_executor_test.go @@ -55,7 +55,7 @@ func TestXAIExecutorExecuteShapesResponsesRequest(t *testing.T) { _, err := exec.Execute(context.Background(), auth, cliproxyexecutor.Request{ Model: "grok-4.3", - Payload: []byte(`{"model":"grok-4.3","input":[{"type":"reasoning","summary":[{"type":"summary_text","text":"test"}],"content":null,"encrypted_content":null},{"type":"reasoning","summary":[{"type":"summary_text","text":"second"}]},{"role":"user","content":"hello"}],"include":["reasoning.encrypted_content"],"reasoning":{"effort":"high"},"tools":[{"type":"tool_search"},{"type":"image_generation"},{"type":"custom","name":"apply_patch"},{"type":"custom","name":"custom_lookup"},{"type":"function","name":"lookup"},{"type":"web_search","external_web_access":true,"search_content_types":["text","image"]},{"type":"namespace","name":"codex_app","description":"Tools in the codex_app namespace.","tools":[{"type":"function","name":"automation_update"},{"type":"custom","name":"namespace_custom"},{"type":"tool_search"}]}]}`), + Payload: []byte(`{"model":"grok-4.3","input":[{"type":"reasoning","summary":[{"type":"summary_text","text":"test"}],"content":null,"encrypted_content":"foreign-encrypted-content"},{"type":"reasoning","summary":[{"type":"summary_text","text":"second"}]},{"role":"user","content":"hello"}],"include":["reasoning.encrypted_content"],"reasoning":{"effort":"high"},"tools":[{"type":"tool_search"},{"type":"image_generation"},{"type":"custom","name":"apply_patch"},{"type":"custom","name":"custom_lookup"},{"type":"function","name":"lookup"},{"type":"web_search","external_web_access":true,"search_content_types":["text","image"]},{"type":"namespace","name":"codex_app","description":"Tools in the codex_app namespace.","tools":[{"type":"function","name":"automation_update"},{"type":"custom","name":"namespace_custom"},{"type":"tool_search"}]}]}`), }, cliproxyexecutor.Options{ SourceFormat: sdktranslator.FormatOpenAIResponse, Stream: false, @@ -196,6 +196,57 @@ func TestXAIExecutorOmitsUnsupportedReasoningEffort(t *testing.T) { } } +func TestXAIExecutorUsesLargeContextModelForGrokBuildCompaction(t *testing.T) { + exec := NewXAIExecutor(&config.Config{}) + req := cliproxyexecutor.Request{ + Model: "grok-build-0.1", + Payload: []byte(`{"model":"grok-build-0.1","input":"hello"}`), + } + + for _, opts := range []cliproxyexecutor.Options{ + { + SourceFormat: sdktranslator.FormatOpenAIResponse, + Alt: "responses/compact", + }, + { + SourceFormat: sdktranslator.FormatOpenAIResponse, + Headers: http.Header{ + "X-Codex-Turn-Metadata": []string{`{"request_kind":"compaction","compaction":{"trigger":"auto","reason":"context_limit"}}`}, + }, + }, + } { + prepared, err := exec.prepareResponsesRequest(context.Background(), req, opts, true) + if err != nil { + t.Fatalf("prepareResponsesRequest() error = %v", err) + } + if prepared.baseModel != "grok-4.3" { + t.Fatalf("baseModel = %q, want grok-4.3", prepared.baseModel) + } + if got := gjson.GetBytes(prepared.body, "model").String(); got != "grok-4.3" { + t.Fatalf("body model = %q, want grok-4.3; body=%s", got, string(prepared.body)) + } + } +} + +func TestXAIExecutorKeepsGrokBuildModelForNormalRequests(t *testing.T) { + exec := NewXAIExecutor(&config.Config{}) + prepared, err := exec.prepareResponsesRequest(context.Background(), cliproxyexecutor.Request{ + Model: "grok-build-0.1", + Payload: []byte(`{"model":"grok-build-0.1","input":"hello"}`), + }, cliproxyexecutor.Options{ + SourceFormat: sdktranslator.FormatOpenAIResponse, + }, true) + if err != nil { + t.Fatalf("prepareResponsesRequest() error = %v", err) + } + if prepared.baseModel != "grok-build-0.1" { + t.Fatalf("baseModel = %q, want grok-build-0.1", prepared.baseModel) + } + if got := gjson.GetBytes(prepared.body, "model").String(); got != "grok-build-0.1" { + t.Fatalf("body model = %q, want grok-build-0.1; body=%s", got, string(prepared.body)) + } +} + func TestXAIExecutorAppliesThinkingSuffix(t *testing.T) { var gotBody []byte server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -646,3 +697,66 @@ func TestNormalizeXAIToolChoiceForTools_NoOpWhenBothAbsent(t *testing.T) { t.Fatalf("tool_choice should not appear: %s", string(out)) } } + +func TestNormalizeXAIInputCustomToolCalls(t *testing.T) { + body := []byte(`{"input":[{"type":"custom_tool_call","status":"completed","call_id":"call_1","name":"apply_patch","input":"*** patch"},{"type":"custom_tool_call_output","call_id":"call_1","output":"ok"},{"type":"message","role":"user","content":"hi"}]}`) + out := normalizeXAIInputCustomToolCalls(body) + + if got := gjson.GetBytes(out, "input.0.type").String(); got != "function_call" { + t.Fatalf("input.0.type = %q, want function_call: %s", got, string(out)) + } + if gjson.GetBytes(out, "input.0.input").Exists() { + t.Fatalf("custom input field should be removed from function_call: %s", string(out)) + } + arguments := gjson.GetBytes(out, "input.0.arguments").String() + if got := gjson.Get(arguments, "input").String(); got != "*** patch" { + t.Fatalf("function_call arguments input = %q, want patch payload: %s", got, string(out)) + } + if got := gjson.GetBytes(out, "input.1.type").String(); got != "function_call_output" { + t.Fatalf("input.1.type = %q, want function_call_output: %s", got, string(out)) + } + if got := gjson.GetBytes(out, "input.1.output").String(); got != "ok" { + t.Fatalf("input.1.output = %q, want ok: %s", got, string(out)) + } + if got := gjson.GetBytes(out, "input.2.type").String(); got != "message" { + t.Fatalf("input.2.type = %q, want message: %s", got, string(out)) + } +} + +func TestDropXAIInputWebSearchCalls(t *testing.T) { + body := []byte(`{"input":[{"type":"message","role":"user","content":"hi"},{"type":"web_search_call","status":"completed","action":{"type":"search","query":"test","queries":["test"]}},{"type":"function_call","call_id":"call_1","name":"lookup","arguments":"{}"},{"type":"function_call_output","call_id":"call_1","output":"ok"}]}`) + out := dropXAIInputWebSearchCalls(body) + + if got := len(gjson.GetBytes(out, "input").Array()); got != 3 { + t.Fatalf("input length = %d, want 3: %s", got, string(out)) + } + if got := gjson.GetBytes(out, "input.0.type").String(); got != "message" { + t.Fatalf("input.0.type = %q, want message: %s", got, string(out)) + } + if got := gjson.GetBytes(out, "input.1.type").String(); got != "function_call" { + t.Fatalf("input.1.type = %q, want function_call: %s", got, string(out)) + } + if got := gjson.GetBytes(out, "input.2.type").String(); got != "function_call_output" { + t.Fatalf("input.2.type = %q, want function_call_output: %s", got, string(out)) + } + for _, item := range gjson.GetBytes(out, "input").Array() { + if item.Get("type").String() == "web_search_call" { + t.Fatalf("web_search_call should be dropped before sending to xAI: %s", string(out)) + } + } +} + +func TestNormalizeXAIInputReasoningItemsDropsEncryptedContent(t *testing.T) { + body := []byte(`{"input":[{"type":"reasoning","summary":[{"type":"summary_text","text":"kept"}],"content":null,"encrypted_content":"foreign-encrypted-content"},{"type":"message","role":"user","content":"hi"}]}`) + out := normalizeXAIInputReasoningItems(body) + + if gjson.GetBytes(out, "input.0.encrypted_content").Exists() { + t.Fatalf("encrypted_content should be removed before sending to xAI: %s", string(out)) + } + if gjson.GetBytes(out, "input.0.content").Exists() { + t.Fatalf("null reasoning content should be removed before sending to xAI: %s", string(out)) + } + if got := gjson.GetBytes(out, "input.0.summary.0.text").String(); got != "kept" { + t.Fatalf("summary text = %q, want kept: %s", got, string(out)) + } +}