Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions backend/internal/handler/admin/account_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -698,9 +698,10 @@ func (h *AccountHandler) Delete(c *gin.Context) {

// TestAccountRequest represents the request body for testing an account
type TestAccountRequest struct {
ModelID string `json:"model_id"`
Prompt string `json:"prompt"`
Mode string `json:"mode"`
ModelID string `json:"model_id"`
Prompt string `json:"prompt"`
Mode string `json:"mode"`
PromptOptimization *bool `json:"prompt_optimization"`
}

type SyncFromCRSRequest struct {
Expand Down Expand Up @@ -731,7 +732,7 @@ func (h *AccountHandler) Test(c *gin.Context) {
_ = c.ShouldBindJSON(&req)

// Use AccountTestService to test the account with SSE streaming
if err := h.accountTestService.TestAccountConnection(c, accountID, req.ModelID, req.Prompt, req.Mode); err != nil {
if err := h.accountTestService.TestAccountConnection(c, accountID, req.ModelID, req.Prompt, req.Mode, req.PromptOptimization); err != nil {
// Error already sent via SSE, just log
return
}
Expand Down
20 changes: 10 additions & 10 deletions backend/internal/service/account_test_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ func createTestPayload(modelID string) (map[string]any, error) {
// All account types use full Claude Code client characteristics, only auth header differs
// modelID is optional - if empty, defaults to claude.DefaultTestModel
// mode is optional - "compact" routes OpenAI accounts to the /responses/compact probe path
func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int64, modelID string, prompt string, mode string) error {
func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int64, modelID string, prompt string, mode string, promptOptimization *bool) error {
ctx := c.Request.Context()

// Get account
Expand All @@ -181,7 +181,7 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int

// Route to platform-specific test method
if account.IsOpenAI() {
return s.testOpenAIAccountConnection(c, account, modelID, prompt, normalizeAccountTestMode(mode))
return s.testOpenAIAccountConnection(c, account, modelID, prompt, normalizeAccountTestMode(mode), promptOptimization)
}

if account.IsGemini() {
Expand Down Expand Up @@ -492,9 +492,8 @@ func (s *AccountTestService) testBedrockAccountConnection(c *gin.Context, ctx co
}

// testOpenAIAccountConnection tests an OpenAI account's connection
func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account *Account, modelID string, prompt string, mode string) error {
func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account *Account, modelID string, prompt string, mode string, promptOptimization *bool) error {
ctx := c.Request.Context()
_ = prompt
mode = normalizeAccountTestMode(mode)

// Default to openai.DefaultTestModel for OpenAI testing
Expand All @@ -520,7 +519,7 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
if account.Type == "apikey" {
return s.testOpenAIImageAPIKey(c, ctx, account, testModelID, imagePrompt)
}
return s.testOpenAIImageOAuth(c, ctx, account, testModelID, imagePrompt)
return s.testOpenAIImageOAuth(c, ctx, account, testModelID, imagePrompt, promptOptimization)
}

// Determine authentication method and API URL
Expand Down Expand Up @@ -1415,7 +1414,7 @@ func (s *AccountTestService) testOpenAIImageAPIKey(c *gin.Context, ctx context.C
}

// testOpenAIImageOAuth tests OpenAI image generation using an OAuth account via Codex /responses API.
func (s *AccountTestService) testOpenAIImageOAuth(c *gin.Context, ctx context.Context, account *Account, modelID, prompt string) error {
func (s *AccountTestService) testOpenAIImageOAuth(c *gin.Context, ctx context.Context, account *Account, modelID, prompt string, promptOptimization *bool) error {
authToken := account.GetOpenAIAccessToken()
if authToken == "" {
return s.sendErrorAndEnd(c, "No access token available")
Expand All @@ -1432,9 +1431,10 @@ func (s *AccountTestService) testOpenAIImageOAuth(c *gin.Context, ctx context.Co
s.sendEvent(c, TestEvent{Type: "content", Text: "Calling Codex /responses image tool...\n"})

parsed := &OpenAIImagesRequest{
Endpoint: openAIImagesGenerationsEndpoint,
Model: strings.TrimSpace(modelID),
Prompt: prompt,
Endpoint: openAIImagesGenerationsEndpoint,
Model: strings.TrimSpace(modelID),
Prompt: prompt,
PromptOptimization: promptOptimization,
}
applyOpenAIImagesDefaults(parsed)

Expand Down Expand Up @@ -1538,7 +1538,7 @@ func (s *AccountTestService) RunTestBackground(ctx context.Context, accountID in
ginCtx, _ := gin.CreateTestContext(w)
ginCtx.Request = (&http.Request{}).WithContext(ctx)

testErr := s.TestAccountConnection(ginCtx, accountID, modelID, "", AccountTestModeDefault)
testErr := s.TestAccountConnection(ginCtx, accountID, modelID, "", AccountTestModeDefault, nil)

finishedAt := time.Now()
body := w.Body.String()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func TestAccountTestService_TestAccountConnection_OpenAICompactOAuthSuccessPersi
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/1/test", bytes.NewReader(nil))

err := svc.TestAccountConnection(c, account.ID, "gpt-5.4", "", AccountTestModeCompact)
err := svc.TestAccountConnection(c, account.ID, "gpt-5.4", "", AccountTestModeCompact, nil)
require.NoError(t, err)

require.Equal(t, chatgptCodexAPIURL+"/compact", upstream.lastReq.URL.String())
Expand Down Expand Up @@ -102,7 +102,7 @@ func TestAccountTestService_TestAccountConnection_OpenAICompactOAuth404MarksUnsu
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/2/test", bytes.NewReader(nil))

err := svc.TestAccountConnection(c, account.ID, "gpt-5.4", "", AccountTestModeCompact)
err := svc.TestAccountConnection(c, account.ID, "gpt-5.4", "", AccountTestModeCompact, nil)
require.Error(t, err)

updates := <-updateCalls
Expand Down Expand Up @@ -148,7 +148,7 @@ func TestAccountTestService_TestAccountConnection_OpenAICompactAPIKeyUsesCompact
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/3/test", bytes.NewReader(nil))

err := svc.TestAccountConnection(c, account.ID, "gpt-5.4", "", AccountTestModeCompact)
err := svc.TestAccountConnection(c, account.ID, "gpt-5.4", "", AccountTestModeCompact, nil)
require.NoError(t, err)

require.Equal(t, "https://example.com/v1/responses/compact", upstream.lastReq.URL.String())
Expand Down Expand Up @@ -192,7 +192,7 @@ func TestAccountTestService_TestAccountConnection_OpenAICompactAPIKeyDefaultBase
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/4/test", bytes.NewReader(nil))

err := svc.TestAccountConnection(c, account.ID, "gpt-5.4", "", AccountTestModeCompact)
err := svc.TestAccountConnection(c, account.ID, "gpt-5.4", "", AccountTestModeCompact, nil)
require.NoError(t, err)
require.Equal(t, "https://api.openai.com/v1/responses/compact", upstream.lastReq.URL.String())
<-updateCalls
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,50 @@ func TestAccountTestService_OpenAIImageOAuthHandlesOutputItemDoneFallback(t *tes
},
}

err := svc.testOpenAIImageOAuth(c, context.Background(), account, "gpt-image-2", "draw a cat")
err := svc.testOpenAIImageOAuth(c, context.Background(), account, "gpt-image-2", "draw a cat", nil)
require.NoError(t, err)
require.Contains(t, rec.Body.String(), "Calling Codex /responses image tool")
require.Contains(t, rec.Body.String(), "data:image/png;base64,aGVsbG8=")
require.Contains(t, rec.Body.String(), "\"success\":true")
}

func TestAccountTestService_OpenAIImageOAuthCanDisablePromptOptimization(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/1/test", nil)

upstream := &httpUpstreamRecorder{
resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{
"Content-Type": []string{"text/event-stream"},
},
Body: io.NopCloser(strings.NewReader(
"data: {\"type\":\"response.output_item.done\",\"item\":{\"id\":\"ig_123\",\"type\":\"image_generation_call\",\"result\":\"aGVsbG8=\",\"output_format\":\"png\"}}\n\n" +
"data: {\"type\":\"response.completed\",\"response\":{\"created_at\":1710000006,\"tool_usage\":{\"image_gen\":{\"images\":1}},\"output\":[]}}\n\n" +
"data: [DONE]\n\n",
)),
},
}
svc := &AccountTestService{httpUpstream: upstream}
account := &Account{
ID: 53,
Name: "openai-oauth",
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "token-123",
},
}

disabled := false
err := svc.testOpenAIImageOAuth(c, context.Background(), account, "gpt-image-2", "draw a cat", &disabled)
require.NoError(t, err)
require.Contains(t, string(upstream.lastBody), "Do not rewrite")
require.Contains(t, rec.Body.String(), "\"success\":true")
}

func TestAccountTestService_OpenAIImageAPIKeyUsesConfiguredV1BaseURL(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
Expand Down
16 changes: 8 additions & 8 deletions backend/internal/service/account_test_service_openai_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ func TestAccountTestService_OpenAISuccessPersistsSnapshotFromHeaders(t *testing.
Credentials: map[string]any{"access_token": "test-token"},
}

err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "", "")
err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "", "", nil)
require.NoError(t, err)
require.NotEmpty(t, repo.updatedExtra)
require.Equal(t, 42.0, repo.updatedExtra["codex_5h_used_percent"])
Expand All @@ -152,7 +152,7 @@ func TestAccountTestService_OpenAIStreamEOFBeforeCompletedFails(t *testing.T) {
Credentials: map[string]any{"access_token": "test-token"},
}

err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "", "")
err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "", "", nil)
require.Error(t, err)
require.Contains(t, recorder.Body.String(), "response.completed")
require.NotContains(t, recorder.Body.String(), `"success":true`)
Expand Down Expand Up @@ -182,7 +182,7 @@ func TestAccountTestService_OpenAI429PersistsSnapshotAndRateLimitState(t *testin
Credentials: map[string]any{"access_token": "test-token"},
}

err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "", "")
err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "", "", nil)
require.Error(t, err)
require.NotEmpty(t, repo.updatedExtra)
require.Equal(t, 100.0, repo.updatedExtra["codex_5h_used_percent"])
Expand Down Expand Up @@ -213,7 +213,7 @@ func TestAccountTestService_OpenAI429BodyOnlyPersistsRateLimitAndClearsStaleErro
Credentials: map[string]any{"access_token": "test-token"},
}

err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "", "")
err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "", "", nil)
require.Error(t, err)
require.Equal(t, account.ID, repo.rateLimitedID)
require.NotNil(t, repo.rateLimitedAt)
Expand Down Expand Up @@ -242,7 +242,7 @@ func TestAccountTestService_OpenAI429SyncsObservedPlanType(t *testing.T) {
Credentials: map[string]any{"access_token": "test-token", "plan_type": "plus"},
}

err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "", "")
err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "", "", nil)
require.Error(t, err)
require.Equal(t, []int64{account.ID}, repo.bulkUpdatedIDs)
require.Equal(t, "free", repo.bulkUpdatedPayload.Credentials["plan_type"])
Expand All @@ -269,7 +269,7 @@ func TestAccountTestService_OpenAI429ActiveAccountDoesNotClearError(t *testing.T
Credentials: map[string]any{"access_token": "test-token"},
}

err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "", "")
err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "", "", nil)
require.Error(t, err)
require.Equal(t, account.ID, repo.rateLimitedID)
require.NotNil(t, repo.rateLimitedAt)
Expand Down Expand Up @@ -297,7 +297,7 @@ func TestAccountTestService_OpenAI429WithoutResetSignalDoesNotMutateRuntimeState
Credentials: map[string]any{"access_token": "test-token"},
}

err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "", "")
err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "", "", nil)
require.Error(t, err)
require.Zero(t, repo.rateLimitedID)
require.Nil(t, repo.rateLimitedAt)
Expand Down Expand Up @@ -325,7 +325,7 @@ func TestAccountTestService_OpenAI401SetsPermanentErrorOnly(t *testing.T) {
Credentials: map[string]any{"access_token": "test-token"},
}

err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "", "")
err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "", "", nil)
require.Error(t, err)
require.Equal(t, account.ID, repo.setErrorID)
require.Contains(t, repo.setErrorMsg, "Authentication failed (401)")
Expand Down
76 changes: 70 additions & 6 deletions backend/internal/service/openai_images.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ type OpenAIImagesRequest struct {
Model string
ExplicitModel bool
Prompt string
PromptOptimization *bool
Stream bool
N int
Size string
Expand Down Expand Up @@ -229,6 +230,11 @@ func parseOpenAIImagesJSONRequest(body []byte, req *OpenAIImagesRequest) error {
req.ExplicitModel = req.Model != ""
}
req.Prompt = strings.TrimSpace(gjson.GetBytes(body, "prompt").String())
promptOptimization, err := parseOpenAIImagesJSONOptionalBool(body, "prompt_optimization", "promptOptimization")
if err != nil {
return err
}
req.PromptOptimization = promptOptimization

if streamResult := gjson.GetBytes(body, "stream"); streamResult.Exists() {
if streamResult.Type != gjson.True && streamResult.Type != gjson.False {
Expand Down Expand Up @@ -374,6 +380,12 @@ func parseOpenAIImagesMultipartRequest(body []byte, contentType string, req *Ope
req.ExplicitModel = value != ""
case "prompt":
req.Prompt = value
case "prompt_optimization", "promptOptimization":
parsed, err := strconv.ParseBool(value)
if err != nil {
return fmt.Errorf("invalid %s field value", name)
}
req.PromptOptimization = &parsed
case "size":
req.Size = value
req.ExplicitSize = value != ""
Expand Down Expand Up @@ -436,6 +448,29 @@ func parseOpenAIImagesMultipartRequest(body []byte, contentType string, req *Ope
return nil
}

func parseOpenAIImagesJSONOptionalBool(body []byte, paths ...string) (*bool, error) {
for _, path := range paths {
result := gjson.GetBytes(body, path)
if !result.Exists() {
continue
}
switch result.Type {
case gjson.True, gjson.False:
value := result.Bool()
return &value, nil
case gjson.String:
value, err := strconv.ParseBool(strings.TrimSpace(result.String()))
if err != nil {
return nil, fmt.Errorf("invalid %s field value", path)
}
return &value, nil
default:
return nil, fmt.Errorf("invalid %s field type", path)
}
}
return nil, nil
}

func parseOpenAIImageDimensions(_ textproto.MIMEHeader) (int, int) {
return 0, 0
}
Expand Down Expand Up @@ -808,21 +843,46 @@ func buildOpenAIImagesURL(base string, endpoint string) string {

func rewriteOpenAIImagesModel(body []byte, contentType string, model string) ([]byte, string, error) {
model = strings.TrimSpace(model)
if model == "" {
return body, contentType, nil
}
mediaType, _, err := mime.ParseMediaType(contentType)
if err == nil && strings.EqualFold(mediaType, "multipart/form-data") {
rewrittenBody, rewrittenType, rewriteErr := rewriteOpenAIImagesMultipartModel(body, contentType, model)
return rewrittenBody, rewrittenType, rewriteErr
}
rewritten, err := sjson.SetBytes(body, "model", model)
rewritten, err := stripOpenAIImagesInternalJSONFields(body)
if err != nil {
return nil, "", err
}
if model == "" {
return rewritten, contentType, nil
}
rewritten, err = sjson.SetBytes(rewritten, "model", model)
if err != nil {
return nil, "", fmt.Errorf("rewrite image request model: %w", err)
}
return rewritten, contentType, nil
}

func stripOpenAIImagesInternalJSONFields(body []byte) ([]byte, error) {
rewritten := body
var err error
for _, path := range []string{"prompt_optimization", "promptOptimization"} {
rewritten, err = sjson.DeleteBytes(rewritten, path)
if err != nil {
return nil, fmt.Errorf("strip image request field %s: %w", path, err)
}
}
return rewritten, nil
}

func isOpenAIImagesInternalMultipartField(name string) bool {
switch strings.TrimSpace(name) {
case "prompt_optimization", "promptOptimization":
return true
default:
return false
}
}

func rewriteOpenAIImagesMultipartModel(body []byte, contentType string, model string) ([]byte, string, error) {
_, params, err := mime.ParseMediaType(contentType)
if err != nil {
Expand All @@ -848,14 +908,18 @@ func rewriteOpenAIImagesMultipartModel(body []byte, contentType string, model st
}

formName := strings.TrimSpace(part.FormName())
if part.FileName() == "" && isOpenAIImagesInternalMultipartField(formName) {
_ = part.Close()
continue
}
partHeader := cloneMultipartHeader(part.Header)
target, err := writer.CreatePart(partHeader)
if err != nil {
_ = part.Close()
return nil, "", fmt.Errorf("create multipart part: %w", err)
}

if formName == "model" && part.FileName() == "" {
if formName == "model" && part.FileName() == "" && model != "" {
if _, err := target.Write([]byte(model)); err != nil {
_ = part.Close()
return nil, "", fmt.Errorf("rewrite multipart model: %w", err)
Expand All @@ -871,7 +935,7 @@ func rewriteOpenAIImagesMultipartModel(body []byte, contentType string, model st
_ = part.Close()
}

if !modelWritten {
if !modelWritten && model != "" {
if err := writer.WriteField("model", model); err != nil {
return nil, "", fmt.Errorf("append multipart model field: %w", err)
}
Expand Down
Loading
Loading