Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 140 additions & 1 deletion internal/runtime/executor/xai_executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,7 @@ type xaiPreparedRequest struct {

func (e *XAIExecutor) prepareResponsesRequest(ctx context.Context, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, stream bool) (*xaiPreparedRequest, error) {
baseModel := thinking.ParseSuffix(req.Model).ModelName
baseModel = xaiCompactionModel(baseModel, opts)
from := opts.SourceFormat
to := sdktranslator.FromString("codex")
originalPayloadSource := req.Payload
Expand Down Expand Up @@ -507,6 +508,8 @@ func (e *XAIExecutor) prepareResponsesRequest(ctx context.Context, req cliproxye
body, _ = sjson.DeleteBytes(body, "stream_options")
body = normalizeXAITools(body)
body = normalizeXAIToolChoiceForTools(body)
body = normalizeXAIInputCustomToolCalls(body)
body = dropXAIInputWebSearchCalls(body)
body = normalizeXAIInputReasoningItems(body)
body = normalizeCodexInstructions(body)
body = sanitizeXAIResponsesBody(body, baseModel)
Expand Down Expand Up @@ -599,6 +602,43 @@ func xaiExecutionSessionID(req cliproxyexecutor.Request, opts cliproxyexecutor.O
return ""
}

func xaiCompactionModel(model string, opts cliproxyexecutor.Options) string {
name := strings.ToLower(strings.TrimSpace(thinking.ParseSuffix(model).ModelName))
if idx := strings.LastIndex(name, "/"); idx >= 0 {
name = name[idx+1:]
}
if name != "grok-build-0.1" || !xaiIsCompactionRequest(opts) {
return model
}
return "grok-4.3"
}

func xaiIsCompactionRequest(opts cliproxyexecutor.Options) bool {
if opts.Alt == "responses/compact" {
return true
}
if opts.Headers != nil {
if xaiTurnMetadataIsCompaction(opts.Headers.Get("X-Codex-Turn-Metadata")) {
return true
}
}
if raw, ok := opts.Metadata["X-Codex-Turn-Metadata"]; ok {
return xaiTurnMetadataIsCompaction(fmt.Sprint(raw))
}
if raw, ok := opts.Metadata["x-codex-turn-metadata"]; ok {
return xaiTurnMetadataIsCompaction(fmt.Sprint(raw))
}
return false
}

func xaiTurnMetadataIsCompaction(raw string) bool {
raw = strings.TrimSpace(raw)
if raw == "" || !gjson.Valid(raw) {
return false
}
return strings.TrimSpace(gjson.Get(raw, "request_kind").String()) == "compaction"
}

func xaiImageEndpointPath(opts cliproxyexecutor.Options) string {
if opts.SourceFormat.String() != xaiImageHandlerType {
return ""
Expand Down Expand Up @@ -776,6 +816,105 @@ func normalizeXAITool(tool gjson.Result) ([]byte, bool, bool) {
return raw, changed, true
}

func normalizeXAIInputCustomToolCalls(body []byte) []byte {
input := gjson.GetBytes(body, "input")
if !input.Exists() || !input.IsArray() {
return body
}

changed := false
items := make([]json.RawMessage, 0, len(input.Array()))
for _, item := range input.Array() {
raw := []byte(item.Raw)
switch item.Get("type").String() {
case "custom_tool_call":
updated, errSet := sjson.SetBytes(raw, "type", "function_call")
if errSet != nil {
return body
}
if inputArg := item.Get("input"); inputArg.Exists() {
updatedWithArgs, errArgs := sjson.SetBytes(updated, "arguments", xaiCustomToolCallArguments(inputArg))
if errArgs != nil {
return body
}
updated = updatedWithArgs
updatedWithoutInput, errDel := sjson.DeleteBytes(updated, "input")
if errDel != nil {
return body
}
updated = updatedWithoutInput
} else if !item.Get("arguments").Exists() {
updatedWithArgs, errArgs := sjson.SetBytes(updated, "arguments", "{}")
if errArgs != nil {
return body
}
updated = updatedWithArgs
}
raw = updated
changed = true
case "custom_tool_call_output":
updated, errSet := sjson.SetBytes(raw, "type", "function_call_output")
if errSet != nil {
return body
}
raw = updated
changed = true
}
items = append(items, json.RawMessage(raw))
}
if !changed {
return body
}

rawInput, errMarshal := json.Marshal(items)
if errMarshal != nil {
return body
}
updated, errSet := sjson.SetRawBytes(body, "input", rawInput)
if errSet != nil {
return body
}
return updated
}

func xaiCustomToolCallArguments(input gjson.Result) string {
payload, errSet := sjson.SetRawBytes([]byte(`{}`), "input", []byte(input.Raw))
if errSet != nil {
return "{}"
}
return string(payload)
}

func dropXAIInputWebSearchCalls(body []byte) []byte {
input := gjson.GetBytes(body, "input")
if !input.Exists() || !input.IsArray() {
return body
}

changed := false
items := make([]json.RawMessage, 0, len(input.Array()))
for _, item := range input.Array() {
if item.Get("type").String() == "web_search_call" {
changed = true
continue
}
items = append(items, json.RawMessage(item.Raw))
}
if !changed {
return body
}

rawInput, errMarshal := json.Marshal(items)
if errMarshal != nil {
return body
}
updated, errSet := sjson.SetRawBytes(body, "input", rawInput)
if errSet != nil {
return body
}
return updated
}

func normalizeXAIInputReasoningItems(body []byte) []byte {
input := gjson.GetBytes(body, "input")
if !input.Exists() || !input.IsArray() {
Expand All @@ -796,7 +935,7 @@ func normalizeXAIInputReasoningItems(body []byte) []byte {
updated = updatedBody
}
encryptedContentPath := fmt.Sprintf("input.%d.encrypted_content", i)
if encryptedContent := gjson.GetBytes(updated, encryptedContentPath); encryptedContent.Exists() && encryptedContent.Type == gjson.Null {
if encryptedContent := gjson.GetBytes(updated, encryptedContentPath); encryptedContent.Exists() {
updatedBody, errDel := sjson.DeleteBytes(updated, encryptedContentPath)
if errDel != nil {
return body
Expand Down
116 changes: 115 additions & 1 deletion internal/runtime/executor/xai_executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func TestXAIExecutorExecuteShapesResponsesRequest(t *testing.T) {

_, err := exec.Execute(context.Background(), auth, cliproxyexecutor.Request{
Model: "grok-4.3",
Payload: []byte(`{"model":"grok-4.3","input":[{"type":"reasoning","summary":[{"type":"summary_text","text":"test"}],"content":null,"encrypted_content":null},{"type":"reasoning","summary":[{"type":"summary_text","text":"second"}]},{"role":"user","content":"hello"}],"include":["reasoning.encrypted_content"],"reasoning":{"effort":"high"},"tools":[{"type":"tool_search"},{"type":"image_generation"},{"type":"custom","name":"apply_patch"},{"type":"custom","name":"custom_lookup"},{"type":"function","name":"lookup"},{"type":"web_search","external_web_access":true,"search_content_types":["text","image"]},{"type":"namespace","name":"codex_app","description":"Tools in the codex_app namespace.","tools":[{"type":"function","name":"automation_update"},{"type":"custom","name":"namespace_custom"},{"type":"tool_search"}]}]}`),
Payload: []byte(`{"model":"grok-4.3","input":[{"type":"reasoning","summary":[{"type":"summary_text","text":"test"}],"content":null,"encrypted_content":"foreign-encrypted-content"},{"type":"reasoning","summary":[{"type":"summary_text","text":"second"}]},{"role":"user","content":"hello"}],"include":["reasoning.encrypted_content"],"reasoning":{"effort":"high"},"tools":[{"type":"tool_search"},{"type":"image_generation"},{"type":"custom","name":"apply_patch"},{"type":"custom","name":"custom_lookup"},{"type":"function","name":"lookup"},{"type":"web_search","external_web_access":true,"search_content_types":["text","image"]},{"type":"namespace","name":"codex_app","description":"Tools in the codex_app namespace.","tools":[{"type":"function","name":"automation_update"},{"type":"custom","name":"namespace_custom"},{"type":"tool_search"}]}]}`),
}, cliproxyexecutor.Options{
SourceFormat: sdktranslator.FormatOpenAIResponse,
Stream: false,
Expand Down Expand Up @@ -196,6 +196,57 @@ func TestXAIExecutorOmitsUnsupportedReasoningEffort(t *testing.T) {
}
}

func TestXAIExecutorUsesLargeContextModelForGrokBuildCompaction(t *testing.T) {
exec := NewXAIExecutor(&config.Config{})
req := cliproxyexecutor.Request{
Model: "grok-build-0.1",
Payload: []byte(`{"model":"grok-build-0.1","input":"hello"}`),
}

for _, opts := range []cliproxyexecutor.Options{
{
SourceFormat: sdktranslator.FormatOpenAIResponse,
Alt: "responses/compact",
},
{
SourceFormat: sdktranslator.FormatOpenAIResponse,
Headers: http.Header{
"X-Codex-Turn-Metadata": []string{`{"request_kind":"compaction","compaction":{"trigger":"auto","reason":"context_limit"}}`},
},
},
} {
prepared, err := exec.prepareResponsesRequest(context.Background(), req, opts, true)
if err != nil {
t.Fatalf("prepareResponsesRequest() error = %v", err)
}
if prepared.baseModel != "grok-4.3" {
t.Fatalf("baseModel = %q, want grok-4.3", prepared.baseModel)
}
if got := gjson.GetBytes(prepared.body, "model").String(); got != "grok-4.3" {
t.Fatalf("body model = %q, want grok-4.3; body=%s", got, string(prepared.body))
}
}
}

func TestXAIExecutorKeepsGrokBuildModelForNormalRequests(t *testing.T) {
exec := NewXAIExecutor(&config.Config{})
prepared, err := exec.prepareResponsesRequest(context.Background(), cliproxyexecutor.Request{
Model: "grok-build-0.1",
Payload: []byte(`{"model":"grok-build-0.1","input":"hello"}`),
}, cliproxyexecutor.Options{
SourceFormat: sdktranslator.FormatOpenAIResponse,
}, true)
if err != nil {
t.Fatalf("prepareResponsesRequest() error = %v", err)
}
if prepared.baseModel != "grok-build-0.1" {
t.Fatalf("baseModel = %q, want grok-build-0.1", prepared.baseModel)
}
if got := gjson.GetBytes(prepared.body, "model").String(); got != "grok-build-0.1" {
t.Fatalf("body model = %q, want grok-build-0.1; body=%s", got, string(prepared.body))
}
}

func TestXAIExecutorAppliesThinkingSuffix(t *testing.T) {
var gotBody []byte
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expand Down Expand Up @@ -646,3 +697,66 @@ func TestNormalizeXAIToolChoiceForTools_NoOpWhenBothAbsent(t *testing.T) {
t.Fatalf("tool_choice should not appear: %s", string(out))
}
}

func TestNormalizeXAIInputCustomToolCalls(t *testing.T) {
body := []byte(`{"input":[{"type":"custom_tool_call","status":"completed","call_id":"call_1","name":"apply_patch","input":"*** patch"},{"type":"custom_tool_call_output","call_id":"call_1","output":"ok"},{"type":"message","role":"user","content":"hi"}]}`)
out := normalizeXAIInputCustomToolCalls(body)

if got := gjson.GetBytes(out, "input.0.type").String(); got != "function_call" {
t.Fatalf("input.0.type = %q, want function_call: %s", got, string(out))
}
if gjson.GetBytes(out, "input.0.input").Exists() {
t.Fatalf("custom input field should be removed from function_call: %s", string(out))
}
arguments := gjson.GetBytes(out, "input.0.arguments").String()
if got := gjson.Get(arguments, "input").String(); got != "*** patch" {
t.Fatalf("function_call arguments input = %q, want patch payload: %s", got, string(out))
}
if got := gjson.GetBytes(out, "input.1.type").String(); got != "function_call_output" {
t.Fatalf("input.1.type = %q, want function_call_output: %s", got, string(out))
}
if got := gjson.GetBytes(out, "input.1.output").String(); got != "ok" {
t.Fatalf("input.1.output = %q, want ok: %s", got, string(out))
}
if got := gjson.GetBytes(out, "input.2.type").String(); got != "message" {
t.Fatalf("input.2.type = %q, want message: %s", got, string(out))
}
}

func TestDropXAIInputWebSearchCalls(t *testing.T) {
body := []byte(`{"input":[{"type":"message","role":"user","content":"hi"},{"type":"web_search_call","status":"completed","action":{"type":"search","query":"test","queries":["test"]}},{"type":"function_call","call_id":"call_1","name":"lookup","arguments":"{}"},{"type":"function_call_output","call_id":"call_1","output":"ok"}]}`)
out := dropXAIInputWebSearchCalls(body)

if got := len(gjson.GetBytes(out, "input").Array()); got != 3 {
t.Fatalf("input length = %d, want 3: %s", got, string(out))
}
if got := gjson.GetBytes(out, "input.0.type").String(); got != "message" {
t.Fatalf("input.0.type = %q, want message: %s", got, string(out))
}
if got := gjson.GetBytes(out, "input.1.type").String(); got != "function_call" {
t.Fatalf("input.1.type = %q, want function_call: %s", got, string(out))
}
if got := gjson.GetBytes(out, "input.2.type").String(); got != "function_call_output" {
t.Fatalf("input.2.type = %q, want function_call_output: %s", got, string(out))
}
for _, item := range gjson.GetBytes(out, "input").Array() {
if item.Get("type").String() == "web_search_call" {
t.Fatalf("web_search_call should be dropped before sending to xAI: %s", string(out))
}
}
}

func TestNormalizeXAIInputReasoningItemsDropsEncryptedContent(t *testing.T) {
body := []byte(`{"input":[{"type":"reasoning","summary":[{"type":"summary_text","text":"kept"}],"content":null,"encrypted_content":"foreign-encrypted-content"},{"type":"message","role":"user","content":"hi"}]}`)
out := normalizeXAIInputReasoningItems(body)

if gjson.GetBytes(out, "input.0.encrypted_content").Exists() {
t.Fatalf("encrypted_content should be removed before sending to xAI: %s", string(out))
}
if gjson.GetBytes(out, "input.0.content").Exists() {
t.Fatalf("null reasoning content should be removed before sending to xAI: %s", string(out))
}
if got := gjson.GetBytes(out, "input.0.summary.0.text").String(); got != "kept" {
t.Fatalf("summary text = %q, want kept: %s", got, string(out))
}
}
Loading