Skip to content

Commit b2ec5b3

Browse files
committed
feat: add background prompt optimization for image tasks
1 parent 7572a00 commit b2ec5b3

18 files changed

Lines changed: 935 additions & 161 deletions

File tree

backend/internal/api/folders.go

Lines changed: 51 additions & 45 deletions
Original file line number | Diff line number | Diff line change
@@ -261,29 +261,32 @@ func GetFoldersHandler(c *gin.Context) {
261261
}
262262

263263
type FolderImageTaskResponse struct {
264-
TaskID string `json:"task_id"`
265-
Prompt string `json:"prompt"`
266-
ModelID string `json:"model_id,omitempty"`
267-
ProviderName string `json:"provider_name,omitempty"`
268-
LocalPath string `json:"local_path,omitempty"`
269-
ThumbnailPath string `json:"thumbnail_path,omitempty"`
270-
ImageURL string `json:"image_url,omitempty"`
271-
ThumbnailURL string `json:"thumbnail_url,omitempty"`
272-
ImageSource *ImageSource `json:"image_source,omitempty"`
273-
ThumbnailSource *ImageSource `json:"thumbnail_source,omitempty"`
274-
Width int `json:"width,omitempty"`
275-
Height int `json:"height,omitempty"`
276-
CreatedAt string `json:"created_at"`
277-
UpdatedAt string `json:"updated_at,omitempty"`
278-
Status string `json:"status"`
279-
TotalCount int `json:"total_count,omitempty"`
280-
ErrorMessage string `json:"error_message,omitempty"`
281-
ErrorCode string `json:"error_code,omitempty"`
282-
ErrorCategory string `json:"error_category,omitempty"`
283-
ErrorRequestID string `json:"error_request_id,omitempty"`
284-
ErrorRetryable bool `json:"error_retryable,omitempty"`
285-
ErrorDetail string `json:"error_detail,omitempty"`
286-
ConfigSnap string `json:"config_snapshot,omitempty"`
264+
TaskID string `json:"task_id"`
265+
Prompt string `json:"prompt"`
266+
PromptOriginal string `json:"prompt_original,omitempty"`
267+
PromptOptimized string `json:"prompt_optimized,omitempty"`
268+
PromptOptimizeMode string `json:"prompt_optimize_mode,omitempty"`
269+
ModelID string `json:"model_id,omitempty"`
270+
ProviderName string `json:"provider_name,omitempty"`
271+
LocalPath string `json:"local_path,omitempty"`
272+
ThumbnailPath string `json:"thumbnail_path,omitempty"`
273+
ImageURL string `json:"image_url,omitempty"`
274+
ThumbnailURL string `json:"thumbnail_url,omitempty"`
275+
ImageSource *ImageSource `json:"image_source,omitempty"`
276+
ThumbnailSource *ImageSource `json:"thumbnail_source,omitempty"`
277+
Width int `json:"width,omitempty"`
278+
Height int `json:"height,omitempty"`
279+
CreatedAt string `json:"created_at"`
280+
UpdatedAt string `json:"updated_at,omitempty"`
281+
Status string `json:"status"`
282+
TotalCount int `json:"total_count,omitempty"`
283+
ErrorMessage string `json:"error_message,omitempty"`
284+
ErrorCode string `json:"error_code,omitempty"`
285+
ErrorCategory string `json:"error_category,omitempty"`
286+
ErrorRequestID string `json:"error_request_id,omitempty"`
287+
ErrorRetryable bool `json:"error_retryable,omitempty"`
288+
ErrorDetail string `json:"error_detail,omitempty"`
289+
ConfigSnap string `json:"config_snapshot,omitempty"`
287290
}
288291

289292
// GetFolderImagesHandler 获取指定文件夹下的图片列表(分页)
@@ -361,28 +364,31 @@ func GetFolderImagesHandler(c *gin.Context) {
361364
for i, task := range tasks {
362365
enrichTaskError(&task)
363366
responses[i] = FolderImageTaskResponse{
364-
TaskID: task.TaskID,
365-
Prompt: task.Prompt,
366-
ModelID: task.ModelID,
367-
ProviderName: task.ProviderName,
368-
LocalPath: toPublicImagePath(task.LocalPath),
369-
ThumbnailPath: toPublicImagePath(task.ThumbnailPath),
370-
ImageURL: strings.TrimSpace(task.ImageURL),
371-
ThumbnailURL: strings.TrimSpace(task.ThumbnailURL),
372-
ImageSource: pickFirstImageSource(task.LocalPath, task.ImageURL, task.ThumbnailPath, task.ThumbnailURL),
373-
ThumbnailSource: pickFirstImageSource(task.ThumbnailPath, task.LocalPath, task.ThumbnailURL, task.ImageURL),
374-
Width: task.Width,
375-
Height: task.Height,
376-
CreatedAt: task.CreatedAt.Format(time.RFC3339),
377-
Status: task.Status,
378-
TotalCount: task.TotalCount,
379-
ErrorMessage: task.ErrorMessage,
380-
ErrorCode: task.ErrorCode,
381-
ErrorCategory: task.ErrorCategory,
382-
ErrorRequestID: task.ErrorRequestID,
383-
ErrorRetryable: task.ErrorRetryable,
384-
ErrorDetail: task.ErrorDetail,
385-
ConfigSnap: task.ConfigSnapshot,
367+
TaskID: task.TaskID,
368+
Prompt: task.Prompt,
369+
PromptOriginal: task.PromptOriginal,
370+
PromptOptimized: task.PromptOptimized,
371+
PromptOptimizeMode: task.PromptOptimizeMode,
372+
ModelID: task.ModelID,
373+
ProviderName: task.ProviderName,
374+
LocalPath: toPublicImagePath(task.LocalPath),
375+
ThumbnailPath: toPublicImagePath(task.ThumbnailPath),
376+
ImageURL: strings.TrimSpace(task.ImageURL),
377+
ThumbnailURL: strings.TrimSpace(task.ThumbnailURL),
378+
ImageSource: pickFirstImageSource(task.LocalPath, task.ImageURL, task.ThumbnailPath, task.ThumbnailURL),
379+
ThumbnailSource: pickFirstImageSource(task.ThumbnailPath, task.LocalPath, task.ThumbnailURL, task.ImageURL),
380+
Width: task.Width,
381+
Height: task.Height,
382+
CreatedAt: task.CreatedAt.Format(time.RFC3339),
383+
Status: task.Status,
384+
TotalCount: task.TotalCount,
385+
ErrorMessage: task.ErrorMessage,
386+
ErrorCode: task.ErrorCode,
387+
ErrorCategory: task.ErrorCategory,
388+
ErrorRequestID: task.ErrorRequestID,
389+
ErrorRetryable: task.ErrorRetryable,
390+
ErrorDetail: task.ErrorDetail,
391+
ConfigSnap: task.ConfigSnapshot,
386392
}
387393
}
388394

backend/internal/api/handlers.go

Lines changed: 48 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import (
2121
"image-gen-service/internal/diagnostic"
2222
"image-gen-service/internal/model"
2323
"image-gen-service/internal/platform"
24+
"image-gen-service/internal/promptopt"
2425
"image-gen-service/internal/provider"
2526
"image-gen-service/internal/storage"
2627
"image-gen-service/internal/worker"
@@ -99,6 +100,9 @@ func buildConfigSnapshot(providerName, modelID string, params map[string]interfa
99100
} else if v, ok := params["count"].(float64); ok && v > 0 {
100101
snapshot["count"] = int(v)
101102
}
103+
if v, ok := params["prompt_optimize_mode"].(string); ok && strings.TrimSpace(v) != "" && strings.TrimSpace(v) != promptopt.ModeOff {
104+
snapshot["promptOptimizeMode"] = strings.TrimSpace(v)
105+
}
102106

103107
b, err := json.Marshal(snapshot)
104108
if err != nil {
@@ -345,59 +349,23 @@ func OptimizePromptHandler(c *gin.Context) {
345349
return
346350
}
347351

348-
providerName := strings.TrimSpace(strings.ToLower(req.Provider))
349-
if providerName == "" {
350-
providerName = "openai-chat"
351-
}
352-
if providerName == "openai" {
353-
providerName = "openai-chat"
354-
}
355-
if providerName == "gemini" {
356-
providerName = "gemini-chat"
357-
}
358-
req.Provider = providerName
359-
if strings.TrimSpace(req.Prompt) == "" {
360-
Error(c, http.StatusBadRequest, 400, "prompt 不能为空")
361-
return
362-
}
363-
364-
var cfg model.ProviderConfig
365-
if err := model.DB.Where("provider_name = ?", req.Provider).First(&cfg).Error; err != nil {
366-
Error(c, http.StatusBadRequest, 400, "未找到指定的 Provider: "+req.Provider)
367-
return
368-
}
369-
if strings.TrimSpace(cfg.APIKey) == "" {
370-
Error(c, http.StatusBadRequest, 400, "Provider API Key 未配置")
371-
return
372-
}
373-
374-
modelName := provider.ResolveModelID(provider.ModelResolveOptions{
375-
ProviderName: req.Provider,
376-
Purpose: provider.PurposeChat,
377-
RequestModel: req.Model,
378-
Config: &cfg,
379-
}).ID
380-
if modelName == "" {
381-
Error(c, http.StatusBadRequest, 400, "未找到可用的模型")
382-
return
383-
}
384-
385-
responseFormat := strings.ToLower(strings.TrimSpace(req.ResponseFormat))
386-
forceJSON := responseFormat == "json" || responseFormat == "json_object" || responseFormat == "application/json"
387-
388-
var optimized string
389-
var err error
390-
if req.Provider == "gemini-chat" {
391-
optimized, err = callGeminiOptimize(c.Request.Context(), &cfg, modelName, req.Prompt, forceJSON)
392-
} else {
393-
optimized, err = callOpenAIOptimize(c.Request.Context(), &cfg, modelName, req.Prompt, forceJSON)
394-
}
352+
result, err := promptopt.OptimizePrompt(c.Request.Context(), promptopt.Request{
353+
Provider: req.Provider,
354+
Model: req.Model,
355+
Prompt: req.Prompt,
356+
Mode: func() string {
357+
if strings.TrimSpace(req.ResponseFormat) == "" {
358+
return promptopt.ModeText
359+
}
360+
return req.ResponseFormat
361+
}(),
362+
})
395363
if err != nil {
396364
Error(c, http.StatusBadRequest, 400, err.Error())
397365
return
398366
}
399367

400-
Success(c, gin.H{"prompt": optimized})
368+
Success(c, gin.H{"prompt": result.Prompt})
401369
}
402370

403371
// GenerateHandler 处理图片生成请求
@@ -419,6 +387,16 @@ func GenerateHandler(c *gin.Context) {
419387
req.Params = map[string]interface{}{}
420388
}
421389
diagnostic.AttachVerboseFlag(req.Params, diagnostic.VerboseEnabled(req.Params))
390+
promptOptimizeMode := promptopt.NormalizeMode(strings.TrimSpace(fmt.Sprint(req.Params["prompt_optimize_mode"])))
391+
if promptOptimizeMode != promptopt.ModeOff {
392+
promptopt.ApplyPromptHints(
393+
req.Params,
394+
strings.TrimSpace(fmt.Sprint(req.Params["prompt"])),
395+
promptOptimizeMode,
396+
strings.TrimSpace(fmt.Sprint(req.Params["prompt_optimize_provider"])),
397+
strings.TrimSpace(fmt.Sprint(req.Params["prompt_optimize_model"])),
398+
)
399+
}
422400
modelID := provider.ResolveModelID(provider.ModelResolveOptions{
423401
ProviderName: req.Provider,
424402
Purpose: provider.PurposeImage,
@@ -444,13 +422,15 @@ func GenerateHandler(c *gin.Context) {
444422
}
445423

446424
taskModel := &model.Task{
447-
TaskID: taskID,
448-
Prompt: prompt,
449-
ProviderName: req.Provider,
450-
ModelID: modelID,
451-
TotalCount: 1, // 目前单次请求只生成一张,后续可扩展
452-
Status: "pending",
453-
ConfigSnapshot: buildConfigSnapshot(req.Provider, modelID, req.Params),
425+
TaskID: taskID,
426+
Prompt: prompt,
427+
PromptOriginal: prompt,
428+
PromptOptimizeMode: promptOptimizeMode,
429+
ProviderName: req.Provider,
430+
ModelID: modelID,
431+
TotalCount: 1, // 目前单次请求只生成一张,后续可扩展
432+
Status: "pending",
433+
ConfigSnapshot: buildConfigSnapshot(req.Provider, modelID, req.Params),
454434
}
455435

456436
if count, ok := req.Params["count"].(float64); ok {
@@ -578,6 +558,10 @@ func GenerateWithImagesHandler(c *gin.Context) {
578558
"reference_images": refImageBytes, // 传递 interface 列表,方便 Provider 类型断言
579559
}
580560
diagnostic.AttachVerboseFlag(taskParams, req.Verbose)
561+
promptOptimizeMode := promptopt.NormalizeMode(req.PromptOptimizeMode)
562+
if promptOptimizeMode != promptopt.ModeOff {
563+
promptopt.ApplyPromptHints(taskParams, req.Prompt, promptOptimizeMode, req.PromptOptimizeProvider, req.PromptOptimizeModel)
564+
}
581565

582566
log.Printf("[API] 提交任务: Prompt=%s, Images=%d\n", req.Prompt, len(refImageBytes))
583567

@@ -589,13 +573,15 @@ func GenerateWithImagesHandler(c *gin.Context) {
589573

590574
taskID := uuid.New().String()
591575
taskModel := &model.Task{
592-
TaskID: taskID,
593-
Prompt: req.Prompt,
594-
ProviderName: req.Provider,
595-
ModelID: modelID,
596-
TotalCount: req.Count,
597-
Status: "pending",
598-
ConfigSnapshot: buildConfigSnapshot(req.Provider, modelID, taskParams),
576+
TaskID: taskID,
577+
Prompt: req.Prompt,
578+
PromptOriginal: req.Prompt,
579+
PromptOptimizeMode: promptOptimizeMode,
580+
ProviderName: req.Provider,
581+
ModelID: modelID,
582+
TotalCount: req.Count,
583+
Status: "pending",
584+
ConfigSnapshot: buildConfigSnapshot(req.Provider, modelID, taskParams),
599585
}
600586

601587
if err := model.DB.Create(taskModel).Error; err != nil {

backend/internal/api/multipart_helper.go

Lines changed: 39 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,18 @@ type MultipartFile struct {
2020

2121
// MultipartRequest 表示图生图请求解析后的数据
2222
type MultipartRequest struct {
23-
Provider string
24-
ModelID string
25-
Prompt string
26-
AspectRatio string
27-
ImageSize string
28-
Count int
29-
Verbose bool
30-
RefImages []MultipartFile
31-
RefPaths []string
23+
Provider string
24+
ModelID string
25+
Prompt string
26+
AspectRatio string
27+
ImageSize string
28+
Count int
29+
Verbose bool
30+
PromptOptimizeMode string
31+
PromptOptimizeProvider string
32+
PromptOptimizeModel string
33+
RefImages []MultipartFile
34+
RefPaths []string
3235
}
3336

3437
// ParseGenerateRequestFromMultipart 使用 formstream 解析图生图请求
@@ -101,6 +104,30 @@ func ParseGenerateRequestFromMultipart(c *gin.Context) (*MultipartRequest, error
101104
req.Verbose = parseLooseBool(string(data))
102105
return nil
103106
})
107+
p.Parser.Register("prompt_optimize_mode", func(reader io.Reader, header formstream.Header) error {
108+
data, err := io.ReadAll(reader)
109+
if err != nil {
110+
return err
111+
}
112+
req.PromptOptimizeMode = string(data)
113+
return nil
114+
})
115+
p.Parser.Register("prompt_optimize_provider", func(reader io.Reader, header formstream.Header) error {
116+
data, err := io.ReadAll(reader)
117+
if err != nil {
118+
return err
119+
}
120+
req.PromptOptimizeProvider = string(data)
121+
return nil
122+
})
123+
p.Parser.Register("prompt_optimize_model", func(reader io.Reader, header formstream.Header) error {
124+
data, err := io.ReadAll(reader)
125+
if err != nil {
126+
return err
127+
}
128+
req.PromptOptimizeModel = string(data)
129+
return nil
130+
})
104131
p.Parser.Register("refPaths", func(reader io.Reader, header formstream.Header) error {
105132
data, err := io.ReadAll(reader)
106133
if err != nil {
@@ -154,6 +181,9 @@ func parseWithStandardLibrary(c *gin.Context) (*MultipartRequest, error) {
154181
}
155182
}
156183
req.Verbose = parseLooseBool(c.PostForm("verbose_logging"))
184+
req.PromptOptimizeMode = c.PostForm("prompt_optimize_mode")
185+
req.PromptOptimizeProvider = c.PostForm("prompt_optimize_provider")
186+
req.PromptOptimizeModel = c.PostForm("prompt_optimize_model")
157187

158188
form, err := c.MultipartForm()
159189
if err == nil && form.File != nil {

backend/internal/model/models.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@ type Task struct {
2828
ID uint `gorm:"primaryKey" json:"id"`
2929
TaskID string `gorm:"uniqueIndex;not null" json:"task_id"` // 外部调用的唯一 ID
3030
Prompt string `gorm:"index:idx_prompt_search;index" json:"prompt"` // 提示词,添加复合索引支持搜索
31+
PromptOriginal string `json:"prompt_original"` // 用户原始输入的提示词
32+
PromptOptimized string `json:"prompt_optimized"` // 自动优化后的提示词
33+
PromptOptimizeMode string `json:"prompt_optimize_mode"` // 自动优化模式:off/text/json
3134
FolderID string `gorm:"index" json:"folder_id"` // 所属文件夹 ID(可选)
3235
ProviderName string `gorm:"index" json:"provider_name"` // 使用的 Provider
3336
ModelID string `gorm:"index" json:"model_id"` // 使用的模型 ID

0 commit comments

Comments
 (0)