Skip to content

Commit ac71222

Browse files
committed
feat(sd.cpp): support parsing structured JSON metadata from SDCPP field
1 parent 5677383 commit ac71222

12 files changed

Lines changed: 667 additions & 35 deletions

File tree

.golangci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ linters:
77
- std-error-handling
88
settings:
99
cyclop:
10-
max-complexity: 14
10+
max-complexity: 15
1111
exhaustive:
1212
default-signifies-exhaustive: true
1313
ireturn:

README.md

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# `fooocus-metadata`
22

3-
A Go library for reading and writing image generation parameters for images produced by [Fooocus and various forks](#compatibility).
3+
A Go library for reading and writing image generation parameters for images produced by [Fooocus and other Stable Diffusion implementations](#compatibility).
44

55
## Features
66

@@ -52,6 +52,15 @@ Tested with RuinedFooocus version 2.0.0 and newer.
5252
|--------------|-------------------|-----------------|------|-------|
5353
| PNG | Embedded | JSON |||
5454

55+
### [stable-diffusion.cpp]
56+
57+
Tested with stable-diffusion.cpp version f40a707d and newer.
58+
59+
| Image Format | Metadata Location | Metadata Scheme | Read | Write |
60+
|--------------|-------------------|-----------------|------|-------|
61+
| PNG | Embedded | JSON |||
62+
63+
5564
### AUTOMATIC1111-style metadata
5665

5766
Basic read-only support for metadata encoded in `a1111` (plain text) format. Unsupported keys are ignored.

cmd/extract/main.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
_ "github.com/fkleon/fooocus-metadata/fooocusplus"
1414
_ "github.com/fkleon/fooocus-metadata/ruinedfooocus"
1515
_ "github.com/fkleon/fooocus-metadata/stablediffusion"
16+
_ "github.com/fkleon/fooocus-metadata/stablediffusioncpp"
1617

1718
fooocusmeta "github.com/fkleon/fooocus-metadata"
1819
)

main_test.go

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ import (
1212
_ "github.com/fkleon/fooocus-metadata/fooocus"
1313
_ "github.com/fkleon/fooocus-metadata/fooocusplus"
1414
_ "github.com/fkleon/fooocus-metadata/ruinedfooocus"
15+
_ "github.com/fkleon/fooocus-metadata/stablediffusion"
16+
_ "github.com/fkleon/fooocus-metadata/stablediffusioncpp"
1517

1618
"github.com/stretchr/testify/assert"
1719
"github.com/stretchr/testify/require"
@@ -161,6 +163,31 @@ func TestExtractMetadata_RuinedFooocus(t *testing.T) {
161163
}
162164
}
163165

166+
func TestExtractMetadata_StableDiffusionCPP(t *testing.T) {
167+
const testpath = "./stablediffusioncpp/testdata/"
168+
testCases := []struct {
169+
file string
170+
source string
171+
software string
172+
}{
173+
{"sd.cpp-txt-meta.png", "StableDiffusion", "stable-diffusion.cpp"},
174+
{"sd.cpp-json-meta.png", "stable-diffusion.cpp", "stable-diffusion.cpp (f40a707)"},
175+
}
176+
177+
for _, tc := range testCases {
178+
t.Run(tc.file, func(t *testing.T) {
179+
path := filepath.Join(testpath, tc.file)
180+
meta, err := ExtractFromFile(path)
181+
182+
require.NoError(t, err)
183+
require.NotNil(t, meta)
184+
185+
assert.Equal(t, tc.source, meta.Source)
186+
assert.Equal(t, tc.software, meta.Params.Version())
187+
})
188+
}
189+
}
190+
164191
func TestExtractCreatedTime(t *testing.T) {
165192
filenamePattern := "2024-01-05_23-11-48_9167_*.png"
166193
expectedCreatedTime := time.Date(2024, time.January, 5, 23, 11, 48, 0, time.UTC)

stablediffusion/metadata.go

Lines changed: 176 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -14,48 +14,57 @@ import (
1414
)
1515

1616
type Metadata struct {
17-
BatchSize int `json:"batch_size,string,omitempty"`
18-
BatchPos int `json:"batch_pos,string,omitempty"`
19-
CfgScale float32 `json:"cfg_scale,string"`
20-
ClipSkip int `json:"clip_skip,string,omitempty"`
21-
DenoisingStrength float32 `json:"denoising_strength,string,omitempty"`
22-
Eta Float `json:"eta,string,omitempty"`
23-
HiresSteps int `json:"hires_steps,string,omitempty"`
24-
HiresUpscale float32 `json:"hires_upscale,string,omitempty"`
25-
HiresUpscaler string `json:"hires_upscaler,omitempty"`
26-
Guidance float32 `json:"guidance,string,omitempty"`
27-
ImageNoiseMultiplier float32 `json:"image_noise_multiplier,string,omitempty"`
28-
Loras Loras `json:"loras,omitempty"`
29-
Model string `json:"model,omitempty"`
30-
ModelHash string `json:"model_hash,omitempty"`
31-
NegativePrompt string `json:"negative_prompt,omitempty"`
32-
Prompt string `json:"prompt"`
33-
Rng string `json:"rng,omitempty"`
34-
Sampler string `json:"sampler"` // The Sampler field contains both the sampler and scheduler names
35-
Seed int `json:"seed,string"`
36-
Size *Size `json:"size,omitempty"`
37-
Steps int `json:"steps,string"`
38-
TextEncoder string `json:"TE,omitempty"`
39-
Unet string `json:"unet,omitempty"`
40-
Vae string `json:"vae,omitempty"`
41-
VaeHash string `json:"vae_hash,omitempty"`
42-
Version string `json:"version,omitempty"`
17+
BatchSize int `json:"batch_size,string,omitempty"`
18+
BatchPos int `json:"batch_pos,string,omitempty"`
19+
CfgScale float32 `json:"cfg_scale,string"`
20+
ClipSkip int `json:"clip_skip,string,omitempty"`
21+
CustomSigmas FloatSlice `json:"custom_sigmas,omitempty"`
22+
DenoisingStrength float32 `json:"denoising_strength,string,omitempty"`
23+
Eta Float `json:"eta,string,omitempty"`
24+
HiresSteps int `json:"hires_steps,string,omitempty"`
25+
HiresUpscale float32 `json:"hires_upscale,string,omitempty"`
26+
HiresUpscaler string `json:"hires_upscaler,omitempty"`
27+
Guidance float32 `json:"guidance,string,omitempty"`
28+
ImageNoiseMultiplier float32 `json:"image_noise_multiplier,string,omitempty"`
29+
Loras Loras `json:"loras,omitempty"`
30+
Model string `json:"model,omitempty"`
31+
ModelHash string `json:"model_hash,omitempty"`
32+
NegativePrompt string `json:"negative_prompt,omitempty"`
33+
Prompt string `json:"prompt"`
34+
Rng string `json:"rng,omitempty"`
35+
SamplerRng string `json:"sampler_rng,omitempty"`
36+
Sampler string `json:"sampler"` // The Sampler field contains both the sampler and scheduler names
37+
Seed int `json:"seed,string"`
38+
SkipLayers IntSlice `json:"skip_layers,omitempty"`
39+
SkipLayerEnd float32 `json:"skip_layer_end,string,omitempty"`
40+
SkipLayerStart float32 `json:"skip_layer_start,string,omitempty"`
41+
SLGScale float32 `json:"slg_scale,string,omitempty"`
42+
Size *Size `json:"size,omitempty"`
43+
Steps int `json:"steps,string"`
44+
TextEncoder string `json:"TE,omitempty"`
45+
Unet string `json:"unet,omitempty"`
46+
Vae string `json:"vae,omitempty"`
47+
VaeHash string `json:"vae_hash,omitempty"`
48+
Version string `json:"version,omitempty"`
49+
HiresResize *Size `json:"hires_resize,omitempty"`
4350
}
4451

52+
// Float is a wrapper around float64 that supports unmarshaling from
53+
// both numeric and string JSON values, including support for "inf" and "-inf".
4554
type Float float64
4655

4756
func (f *Float) UnmarshalJSON(v []byte) (err error) {
4857
if s := string(v); s == "inf" || s == "-inf" {
49-
// if +/- indiciates infinity
50-
if s == "inf" {
51-
*f = Float(math.Inf(1))
58+
if strings.HasPrefix(s, "-") {
59+
*f = Float(math.Inf(-1))
5260
return nil
5361
}
5462

55-
*f = Float(math.Inf(-1))
63+
*f = Float(math.Inf(1))
5664

5765
return nil
5866
}
67+
5968
// just a regular float value
6069
var fv float64
6170
if err := json.Unmarshal(v, &fv); err != nil {
@@ -83,6 +92,86 @@ func (f Float) MarshalJSON() ([]byte, error) {
8392
return json.Marshal(v) // marshal result as standard float64
8493
}
8594

95+
type FloatSlice []float32
96+
97+
func (s *FloatSlice) UnmarshalJSON(p []byte) error {
98+
if len(p) == 0 {
99+
return nil
100+
}
101+
102+
if p[0] == '"' {
103+
var raw string
104+
if err := json.Unmarshal(p, &raw); err != nil {
105+
return err
106+
}
107+
108+
parts := splitBracketList(raw)
109+
110+
values := make(FloatSlice, 0, len(parts))
111+
for _, part := range parts {
112+
value, err := strconv.ParseFloat(part, 32)
113+
if err != nil {
114+
return err
115+
}
116+
117+
values = append(values, float32(value))
118+
}
119+
120+
*s = values
121+
122+
return nil
123+
}
124+
125+
var values []float32
126+
if err := json.Unmarshal(p, &values); err != nil {
127+
return err
128+
}
129+
130+
*s = values
131+
132+
return nil
133+
}
134+
135+
type IntSlice []int
136+
137+
func (s *IntSlice) UnmarshalJSON(p []byte) error {
138+
if len(p) == 0 {
139+
return nil
140+
}
141+
142+
if p[0] == '"' {
143+
var raw string
144+
if err := json.Unmarshal(p, &raw); err != nil {
145+
return err
146+
}
147+
148+
parts := splitBracketList(raw)
149+
150+
values := make(IntSlice, 0, len(parts))
151+
for _, part := range parts {
152+
value, err := strconv.Atoi(part)
153+
if err != nil {
154+
return err
155+
}
156+
157+
values = append(values, value)
158+
}
159+
160+
*s = values
161+
162+
return nil
163+
}
164+
165+
var values []int
166+
if err := json.Unmarshal(p, &values); err != nil {
167+
return err
168+
}
169+
170+
*s = values
171+
172+
return nil
173+
}
174+
86175
type Loras []Lora
87176

88177
func (l *Loras) UnmarshalJSON(p []byte) (err error) {
@@ -173,6 +262,28 @@ func (s Size) MarshalJSON() ([]byte, error) {
173262
return json.Marshal(val)
174263
}
175264

265+
func splitBracketList(raw string) []string {
266+
trimmed := strings.TrimSpace(raw)
267+
trimmed = strings.TrimPrefix(trimmed, "[")
268+
269+
trimmed = strings.TrimSuffix(trimmed, "]")
270+
if trimmed == "" {
271+
return nil
272+
}
273+
274+
parts := strings.Split(trimmed, ",")
275+
276+
compact := make([]string, 0, len(parts))
277+
for _, part := range parts {
278+
part = strings.TrimSpace(part)
279+
if part != "" {
280+
compact = append(compact, part)
281+
}
282+
}
283+
284+
return compact
285+
}
286+
176287
func ParseParameters(in string) (meta Metadata, err error) {
177288
if json.Valid([]byte(in)) {
178289
return meta, fmt.Errorf("input is JSON, not plaintext")
@@ -205,9 +316,8 @@ func ParseParameters(in string) (meta Metadata, err error) {
205316

206317
// If prev was negative prompt, the match was not sufficient,
207318
// fix it up
208-
if pmKey == "negative_prompt" {
209-
nprompt := in2[pmIdx : match[0]-1]
210-
kv["negative_prompt"] = strings.TrimSpace(nprompt)
319+
if needsExtendedValue(pmKey) {
320+
kv[pmKey] = strings.TrimSpace(in2[pmIdx : match[0]-1])
211321
}
212322

213323
// Normalize key: Lowercase, trim spaces, replace spaces with underscores
@@ -228,6 +338,12 @@ func ParseParameters(in string) (meta Metadata, err error) {
228338
pmIdx = match[4]
229339
}
230340

341+
if needsExtendedValue(pmKey) {
342+
kv[pmKey] = strings.TrimSpace(in2[pmIdx:])
343+
}
344+
345+
normalizeStableDiffusionCPPFields(kv)
346+
231347
// LoRAs
232348
lr := regexp.MustCompile("<lora:[^>]+>")
233349

@@ -241,3 +357,30 @@ func ParseParameters(in string) (meta Metadata, err error) {
241357

242358
return meta, err
243359
}
360+
361+
func needsExtendedValue(key string) bool {
362+
switch key {
363+
case "negative_prompt", "custom_sigmas", "skip_layers":
364+
return true
365+
default:
366+
return false
367+
}
368+
}
369+
370+
func normalizeStableDiffusionCPPFields(kv map[string]string) {
371+
value, ok := kv["hires_upscale"]
372+
if !ok {
373+
return
374+
}
375+
376+
if _, err := strconv.ParseFloat(value, 32); err == nil {
377+
return
378+
}
379+
380+
kv["hires_upscaler"] = value
381+
delete(kv, "hires_upscale")
382+
383+
if scale, ok := kv["hires_scale"]; ok {
384+
kv["hires_upscale"] = scale
385+
}
386+
}

0 commit comments

Comments
 (0)