Skip to content

Commit 2ca13c7

Browse files
authored
Merge pull request #57 from ShellMonster/fix/openai-image-dynamic-resolution
fix: support OpenAI image dynamic resolution
2 parents d136a9b + 1268a4c commit 2ca13c7

18 files changed

Lines changed: 276 additions & 46 deletions

backend/internal/api/handlers.go

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -582,6 +582,7 @@ func GenerateWithImagesHandler(c *gin.Context) {
582582
"model_id": modelID,
583583
"aspect_ratio": req.AspectRatio,
584584
"resolution_level": req.ImageSize,
585+
"quality": req.Quality,
585586
"count": req.Count,
586587
"reference_images": refImageBytes, // 传递 interface 列表,方便 Provider 类型断言
587588
}

backend/internal/api/multipart_helper.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ type MultipartRequest struct {
2525
Prompt string
2626
AspectRatio string
2727
ImageSize string
28+
Quality string
2829
Count int
2930
Verbose bool
3031
PromptOptimizeMode string
@@ -86,6 +87,14 @@ func ParseGenerateRequestFromMultipart(c *gin.Context) (*MultipartRequest, error
8687
req.ImageSize = string(data)
8788
return nil
8889
})
90+
p.Parser.Register("quality", func(reader io.Reader, header formstream.Header) error {
91+
data, err := io.ReadAll(reader)
92+
if err != nil {
93+
return err
94+
}
95+
req.Quality = string(data)
96+
return nil
97+
})
8998
p.Parser.Register("count", func(reader io.Reader, header formstream.Header) error {
9099
data, err := io.ReadAll(reader)
91100
if err != nil {
@@ -172,6 +181,7 @@ func parseWithStandardLibrary(c *gin.Context) (*MultipartRequest, error) {
172181
Prompt: c.PostForm("prompt"),
173182
AspectRatio: c.PostForm("aspectRatio"),
174183
ImageSize: c.PostForm("imageSize"),
184+
Quality: c.PostForm("quality"),
175185
Count: 1,
176186
}
177187

Lines changed: 49 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,49 @@
1+
package api
2+
3+
import (
4+
"bytes"
5+
"mime/multipart"
6+
"net/http"
7+
"net/http/httptest"
8+
"testing"
9+
10+
"github.com/gin-gonic/gin"
11+
)
12+
13+
// TestParseGenerateRequestFromMultipartIncludesQuality builds a multipart form
// body that carries a "quality" field alongside the other text fields, attaches
// it to a gin test context as a POST request, and verifies that
// ParseGenerateRequestFromMultipart succeeds and reports Quality == "high".
func TestParseGenerateRequestFromMultipartIncludesQuality(t *testing.T) {
14+
gin.SetMode(gin.TestMode)
15+
16+
var body bytes.Buffer
17+
writer := multipart.NewWriter(&body)
18+
fields := map[string]string{
19+
"provider": "openai-image",
20+
"model_id": "gpt-image-2",
21+
"prompt": "edit prompt",
22+
"aspectRatio": "1:1",
23+
"imageSize": "2K",
24+
"quality": "high",
25+
"count": "1",
26+
}
27+
for key, value := range fields {
28+
if err := writer.WriteField(key, value); err != nil {
29+
t.Fatalf("WriteField %s: %v", key, err)
30+
}
31+
}
32+
if err := writer.Close(); err != nil {
33+
t.Fatalf("Close multipart writer: %v", err)
34+
}
35+
36+
request := httptest.NewRequest(http.MethodPost, "/api/generate-with-images", &body)
37+
request.Header.Set("Content-Type", writer.FormDataContentType())
38+
recorder := httptest.NewRecorder()
39+
c, _ := gin.CreateTestContext(recorder)
40+
c.Request = request
41+
42+
req, err := ParseGenerateRequestFromMultipart(c)
43+
if err != nil {
44+
t.Fatalf("ParseGenerateRequestFromMultipart: %v", err)
45+
}
46+
if req.Quality != "high" {
47+
t.Fatalf("Quality = %q, want high", req.Quality)
48+
}
49+
}

backend/internal/api/task_timeout_reconcile.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ func normalizeProviderForTimeout(providerName string) string {
4141
switch {
4242
case strings.HasPrefix(name, "gemini"):
4343
return "gemini"
44+
case name == "openai-image":
45+
return "openai-image"
4446
case strings.HasPrefix(name, "openai"):
4547
return "openai"
4648
default:
Lines changed: 23 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,23 @@
1+
package api
2+
3+
import (
4+
"testing"
5+
"time"
6+
)
7+
8+
// TestTaskTimeoutForProviderKeepsOpenAIImageConfig checks, against a timeout
// map with separate "openai" (150s) and "openai-image" (500s) entries, that
// normalizeProviderForTimeout returns "openai-image" unchanged, that
// taskTimeoutForProvider yields 500s for "openai-image", and that it yields
// 150s for "openai-chat".
func TestTaskTimeoutForProviderKeepsOpenAIImageConfig(t *testing.T) {
9+
timeoutMap := map[string]time.Duration{
10+
"openai": 150 * time.Second,
11+
"openai-image": 500 * time.Second,
12+
}
13+
14+
if got := normalizeProviderForTimeout("openai-image"); got != "openai-image" {
15+
t.Fatalf("normalizeProviderForTimeout(openai-image) = %q, want openai-image", got)
16+
}
17+
if got := taskTimeoutForProvider("openai-image", timeoutMap); got != 500*time.Second {
18+
t.Fatalf("openai-image timeout = %s, want 500s", got)
19+
}
20+
if got := taskTimeoutForProvider("openai-chat", timeoutMap); got != 150*time.Second {
21+
t.Fatalf("openai-chat timeout = %s, want 150s", got)
22+
}
23+
}

backend/internal/provider/openai_image.go

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,12 @@ import (
66
"encoding/base64"
77
"encoding/json"
88
"fmt"
9+
"image"
910
"image-gen-service/internal/diagnostic"
1011
"image-gen-service/internal/model"
12+
_ "image/gif"
13+
_ "image/jpeg"
14+
"image/png"
1115
"io"
1216
"math"
1317
"mime/multipart"
@@ -17,6 +21,8 @@ import (
1721
"strconv"
1822
"strings"
1923
"time"
24+
25+
_ "golang.org/x/image/webp"
2026
)
2127

2228
type OpenAIImageProvider struct {
@@ -414,20 +420,35 @@ func collectOpenAIImageReferences(raw interface{}) ([]openAIImageReference, erro
414420
if len(imgBytes) == 0 {
415421
continue
416422
}
417-
mimeType := http.DetectContentType(imgBytes)
418-
if mimeType != "image/png" {
419-
return nil, fmt.Errorf("第 %d 张参考图不是有效的 PNG 图片,OpenAI Edit 仅支持 PNG 格式", idx+1)
423+
pngBytes, err := normalizeOpenAIReferenceImagePNG(imgBytes)
424+
if err != nil {
425+
return nil, fmt.Errorf("第 %d 张参考图不是有效图片: %w", idx+1, err)
420426
}
421427
refs = append(refs, openAIImageReference{
422428
Name: fmt.Sprintf("reference-%d.png", idx+1),
423-
Content: imgBytes,
424-
MIME: mimeType,
429+
Content: pngBytes,
430+
MIME: "image/png",
425431
})
426432
}
427433

428434
return refs, nil
429435
}
430436

437+
// normalizeOpenAIReferenceImagePNG returns imgBytes as PNG-encoded data.
//
// Data whose leading bytes sniff as image/png is returned unchanged, but only
// after its header parses successfully: the content sniff looks at the file
// signature alone, so a truncated or corrupt PNG would otherwise be accepted
// as valid. Any other data is decoded with whichever image formats are
// registered in the binary and re-encoded as PNG.
//
// The returned error wraps the underlying decoder/encoder error and names the
// stage that failed, so callers can still use errors.Is / errors.As on it.
func normalizeOpenAIReferenceImagePNG(imgBytes []byte) ([]byte, error) {
	if http.DetectContentType(imgBytes) == "image/png" {
		// Cheap header-only validation; the pixel data is not decoded here.
		if _, err := png.DecodeConfig(bytes.NewReader(imgBytes)); err != nil {
			return nil, fmt.Errorf("parse png header: %w", err)
		}
		return imgBytes, nil
	}

	img, format, err := image.Decode(bytes.NewReader(imgBytes))
	if err != nil {
		return nil, fmt.Errorf("decode reference image: %w", err)
	}

	var buf bytes.Buffer
	if err := png.Encode(&buf, img); err != nil {
		return nil, fmt.Errorf("encode %s reference image as png: %w", format, err)
	}
	return buf.Bytes(), nil
}
451+
431452
// escapeMultipartFilename returns name with every backslash replaced by a
// doubled backslash and every double quote prefixed with a backslash.
// NOTE(review): presumably used for the quoted filename parameter of a
// multipart part header — confirm at the call sites.
func escapeMultipartFilename(name string) string {
432453
return strings.NewReplacer("\\", "\\\\", `"`, "\\\"").Replace(name)
433454
}

backend/internal/provider/openai_image_test.go

Lines changed: 59 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,14 @@
11
package provider
22

33
import (
4+
"bytes"
45
"encoding/base64"
56
"encoding/json"
7+
"image"
68
"image-gen-service/internal/model"
9+
"image/color"
10+
"image/jpeg"
11+
"io"
712
"net/http"
813
"net/http/httptest"
914
"strings"
@@ -188,6 +193,7 @@ func TestOpenAIImageProviderGenerateWithReferenceUsesEdits(t *testing.T) {
188193
"model_id": "gpt-image-2-all",
189194
"aspect_ratio": "1:1",
190195
"resolution_level": "1K",
196+
"quality": "high",
191197
"count": 1,
192198
"reference_images": []interface{}{refBytes, refBytes, refBytes},
193199
})
@@ -200,6 +206,9 @@ func TestOpenAIImageProviderGenerateWithReferenceUsesEdits(t *testing.T) {
200206
if seenFields["model"] != "gpt-image-2-all" || seenFields["prompt"] != "edit prompt" || seenFields["size"] != "1280x1280" {
201207
t.Fatalf("unexpected fields: %+v", seenFields)
202208
}
209+
if seenFields["quality"] != "high" {
210+
t.Fatalf("quality = %q, want high", seenFields["quality"])
211+
}
203212
if _, ok := seenFields["input_fidelity"]; ok {
204213
t.Fatalf("input_fidelity should not be sent to edits: %+v", seenFields)
205214
}
@@ -214,10 +223,45 @@ func TestOpenAIImageProviderGenerateWithReferenceUsesEdits(t *testing.T) {
214223
}
215224
}
216225

217-
func TestOpenAIImageProviderRejectsNonPNGReference(t *testing.T) {
226+
func TestOpenAIImageProviderConvertsJPEGReferenceToPNG(t *testing.T) {
227+
var jpegRef bytes.Buffer
228+
img := image.NewRGBA(image.Rect(0, 0, 1, 1))
229+
img.Set(0, 0, color.White)
230+
if err := jpeg.Encode(&jpegRef, img, nil); err != nil {
231+
t.Fatalf("encode jpeg: %v", err)
232+
}
233+
234+
var seenImageContentType string
235+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
236+
if err := r.ParseMultipartForm(2 << 20); err != nil {
237+
t.Fatalf("ParseMultipartForm: %v", err)
238+
}
239+
files := r.MultipartForm.File["image"]
240+
if len(files) != 1 {
241+
t.Fatalf("image files = %d, want 1", len(files))
242+
}
243+
seenImageContentType = files[0].Header.Get("Content-Type")
244+
file, err := files[0].Open()
245+
if err != nil {
246+
t.Fatalf("open image part: %v", err)
247+
}
248+
defer file.Close()
249+
content, err := io.ReadAll(file)
250+
if err != nil {
251+
t.Fatalf("read image part: %v", err)
252+
}
253+
if http.DetectContentType(content) != "image/png" {
254+
t.Fatalf("uploaded content type = %q, want image/png", http.DetectContentType(content))
255+
}
256+
_ = json.NewEncoder(w).Encode(map[string]interface{}{
257+
"data": []map[string]string{{"b64_json": tinyPNGBase64}},
258+
})
259+
}))
260+
defer server.Close()
261+
218262
p, err := NewOpenAIImageProvider(&model.ProviderConfig{
219263
ProviderName: "openai-image",
220-
APIBase: "http://example.test/v1",
264+
APIBase: server.URL,
221265
APIKey: "test-key",
222266
TimeoutSeconds: 5,
223267
})
@@ -231,9 +275,19 @@ func TestOpenAIImageProviderRejectsNonPNGReference(t *testing.T) {
231275
"aspect_ratio": "1:1",
232276
"resolution_level": "1K",
233277
"count": 1,
234-
"reference_images": []interface{}{[]byte("not an image")},
278+
"reference_images": []interface{}{jpegRef.Bytes()},
235279
})
236-
if err == nil || !strings.Contains(err.Error(), "OpenAI Edit 仅支持 PNG 格式") {
237-
t.Fatalf("Generate error = %v, want PNG validation error", err)
280+
if err != nil {
281+
t.Fatalf("Generate: %v", err)
282+
}
283+
if seenImageContentType != "image/png" {
284+
t.Fatalf("image part Content-Type = %q, want image/png", seenImageContentType)
285+
}
286+
}
287+
288+
// TestOpenAIImageProviderRejectsInvalidReference passes a single reference
// made of the bytes "not an image" to collectOpenAIImageReferences and expects
// a non-nil error whose message contains the invalid-image text checked below.
func TestOpenAIImageProviderRejectsInvalidReference(t *testing.T) {
289+
_, err := collectOpenAIImageReferences([]interface{}{[]byte("not an image")})
290+
if err == nil || !strings.Contains(err.Error(), "不是有效图片") {
291+
t.Fatalf("collectOpenAIImageReferences error = %v, want invalid image error", err)
238292
}
239293
}

backend/internal/worker/pool.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -480,13 +480,15 @@ func fetchProviderTimeout(providerName string) time.Duration {
480480
name := strings.TrimSpace(strings.ToLower(providerName))
481481
if strings.HasPrefix(name, "gemini") {
482482
name = "gemini"
483+
} else if name == "openai-image" {
484+
name = "openai-image"
483485
} else if strings.HasPrefix(name, "openai") {
484486
name = "openai"
485487
}
486488

487489
defaultTimeout := func(p string) time.Duration {
488490
switch p {
489-
case "gemini", "openai":
491+
case "gemini", "openai", "openai-image":
490492
return 500 * time.Second
491493
default:
492494
return 150 * time.Second
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
package worker
2+
3+
import (
4+
"testing"
5+
"time"
6+
7+
"image-gen-service/internal/model"
8+
9+
"gorm.io/driver/sqlite"
10+
"gorm.io/gorm"
11+
)
12+
13+
// TestFetchProviderTimeoutKeepsOpenAIImageConfig swaps model.DB for an
// in-memory SQLite database (restoring the original value in t.Cleanup),
// inserts provider configs for "openai" (150 seconds) and "openai-image"
// (500 seconds), and expects fetchProviderTimeout to return 500s for
// "openai-image" and 150s for "openai".
func TestFetchProviderTimeoutKeepsOpenAIImageConfig(t *testing.T) {
14+
originalDB := model.DB
15+
t.Cleanup(func() {
16+
model.DB = originalDB
17+
})
18+
19+
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
20+
if err != nil {
21+
t.Fatalf("open test database: %v", err)
22+
}
23+
if err := db.AutoMigrate(&model.ProviderConfig{}); err != nil {
24+
t.Fatalf("migrate provider config: %v", err)
25+
}
26+
model.DB = db
27+
28+
configs := []model.ProviderConfig{
29+
{ProviderName: "openai", TimeoutSeconds: 150},
30+
{ProviderName: "openai-image", TimeoutSeconds: 500},
31+
}
32+
for _, cfg := range configs {
33+
if err := db.Create(&cfg).Error; err != nil {
34+
t.Fatalf("create provider config %s: %v", cfg.ProviderName, err)
35+
}
36+
}
37+
38+
if got := fetchProviderTimeout("openai-image"); got != 500*time.Second {
39+
t.Fatalf("openai-image timeout = %s, want 500s", got)
40+
}
41+
if got := fetchProviderTimeout("openai"); got != 150*time.Second {
42+
t.Fatalf("openai timeout = %s, want 150s", got)
43+
}
44+
}

desktop/package-lock.json

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)