Skip to content

Commit c2d7b8b

Browse files
committed
chore(frontend): add regression test for qwen image CFG handling in graph builder
chore(backend): add test for heuristic detection of Qwen Image Edit GGUF model variant
chore(frontend): add regression test for ref images not added to qwen image in generate mode
fix(frontend): graph build handling of Qwen Image when CFG <=1
chore(frontend): add regression test for optimal dimension selection
1 parent be9cbc3 commit c2d7b8b

File tree

4 files changed

+480
-10
lines changed

4 files changed

+480
-10
lines changed

invokeai/frontend/web/src/features/nodes/util/graph/generation/buildQwenImageGraph.test.ts

Lines changed: 294 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,153 @@
1-
import { describe, expect, it } from 'vitest';
1+
import { afterEach, describe, expect, it, vi } from 'vitest';
22

3-
import { isQwenImageEditModel } from './buildQwenImageGraph';
3+
vi.mock('app/logging/logger', () => ({
4+
logger: () => ({
5+
debug: vi.fn(),
6+
}),
7+
}));
8+
9+
let nextId = 0;
10+
vi.mock('features/controlLayers/konva/util', () => ({
11+
getPrefixedId: (prefix: string) => `${prefix}:${nextId++}`,
12+
}));
13+
14+
const model = {
15+
key: 'qwen-model',
16+
hash: 'qwen-hash',
17+
name: 'Qwen Image Generate',
18+
base: 'qwen-image',
19+
type: 'main',
20+
variant: 'generate',
21+
};
22+
23+
const defaultParams: {
24+
cfgScale: number | number[];
25+
steps: number;
26+
qwenImageComponentSource: null;
27+
qwenImageQuantization: string;
28+
qwenImageShift: number;
29+
} = {
30+
cfgScale: 4,
31+
steps: 20,
32+
qwenImageComponentSource: null,
33+
qwenImageQuantization: 'none',
34+
qwenImageShift: 1,
35+
};
36+
37+
let params = { ...defaultParams };
38+
39+
const refImagesSlice = {
40+
entities: [
41+
{
42+
id: 'ref-image-1',
43+
isEnabled: true,
44+
config: {
45+
type: 'qwen_image_reference_image',
46+
image: {
47+
original: {
48+
image: {
49+
image_name: 'reference.png',
50+
width: 512,
51+
height: 512,
52+
},
53+
},
54+
},
55+
},
56+
},
57+
],
58+
};
59+
60+
vi.mock('features/controlLayers/store/paramsSlice', () => ({
61+
selectMainModelConfig: vi.fn(() => model),
62+
selectParamsSlice: vi.fn(() => params),
63+
}));
64+
65+
vi.mock('features/controlLayers/store/refImagesSlice', () => ({
66+
selectRefImagesSlice: vi.fn(() => refImagesSlice),
67+
}));
68+
69+
vi.mock('features/controlLayers/store/selectors', () => ({
70+
selectCanvasMetadata: vi.fn(() => ({})),
71+
}));
72+
73+
vi.mock('features/controlLayers/store/types', () => ({
74+
isQwenImageReferenceImageConfig: vi.fn((config: { type?: string }) => config.type === 'qwen_image_reference_image'),
75+
}));
76+
77+
vi.mock('features/controlLayers/store/validators', () => ({
78+
getGlobalReferenceImageWarnings: vi.fn(() => []),
79+
}));
80+
81+
vi.mock('features/metadata/util/modelFetchingHelpers', () => ({
82+
fetchModelConfigWithTypeGuard: vi.fn(() => Promise.resolve(model)),
83+
}));
84+
85+
vi.mock('features/nodes/types/common', async () => {
86+
const actual = await vi.importActual('features/nodes/types/common');
87+
return {
88+
...actual,
89+
zImageField: {
90+
parse: vi.fn((image) => image),
91+
},
92+
};
93+
});
94+
95+
vi.mock('features/nodes/util/graph/generation/addImageToImage', () => ({
96+
addImageToImage: vi.fn(),
97+
}));
98+
99+
vi.mock('features/nodes/util/graph/generation/addInpaint', () => ({
100+
addInpaint: vi.fn(),
101+
}));
102+
103+
vi.mock('features/nodes/util/graph/generation/addNSFWChecker', () => ({
104+
addNSFWChecker: vi.fn((_g, node) => node),
105+
}));
106+
107+
vi.mock('features/nodes/util/graph/generation/addOutpaint', () => ({
108+
addOutpaint: vi.fn(),
109+
}));
110+
111+
vi.mock('features/nodes/util/graph/generation/addQwenImageLoRAs', () => ({
112+
addQwenImageLoRAs: vi.fn(),
113+
}));
114+
115+
vi.mock('features/nodes/util/graph/generation/addTextToImage', () => ({
116+
addTextToImage: vi.fn(({ l2i }) => l2i),
117+
}));
118+
119+
vi.mock('features/nodes/util/graph/generation/addWatermarker', () => ({
120+
addWatermarker: vi.fn((_g, node) => node),
121+
}));
122+
123+
vi.mock('features/nodes/util/graph/graphBuilderUtils', () => ({
124+
selectCanvasOutputFields: vi.fn(() => ({})),
125+
selectPresetModifiedPrompts: vi.fn(() => ({
126+
positive: 'a prompt',
127+
negative: 'a negative prompt',
128+
})),
129+
}));
130+
131+
vi.mock('features/ui/store/uiSelectors', () => ({
132+
selectActiveTab: vi.fn(() => 'generation'),
133+
}));
134+
135+
vi.mock('services/api/types', async () => {
136+
const actual = await vi.importActual('services/api/types');
137+
return {
138+
...actual,
139+
isNonRefinerMainModelConfig: vi.fn(() => true),
140+
};
141+
});
142+
143+
import { buildQwenImageGraph, isQwenImageEditModel, shouldUseCfg } from './buildQwenImageGraph';
4144

5145
describe('isQwenImageEditModel', () => {
146+
afterEach(() => {
147+
nextId = 0;
148+
params = { ...defaultParams };
149+
});
150+
6151
it('returns true for edit variant', () => {
7152
expect(isQwenImageEditModel({ variant: 'edit' })).toBe(true);
8153
});
@@ -35,23 +180,163 @@ describe('isQwenImageEditModel', () => {
35180

36181
describe('reference image filtering regression', () => {
37182
it('prevents reference images from leaking to generate models when switching from edit', () => {
38-
// Simulate: user was using an edit model (variant='edit') with reference images,
39-
// then switches to a generate model (variant='generate').
40-
// The generate model should NOT receive reference images.
41183
const editModel = { variant: 'edit' as const };
42184
const generateModel = { variant: 'generate' as const };
43185

44-
// Edit model: reference images should be collected
45186
expect(isQwenImageEditModel(editModel)).toBe(true);
46-
47-
// Generate model: reference images must NOT be collected, even if they exist in state
48187
expect(isQwenImageEditModel(generateModel)).toBe(false);
49188
});
50189

51190
it('prevents reference images from leaking to GGUF models without variant', () => {
52-
// GGUF models installed without a variant field default to generate behavior
53191
const ggufModelNoVariant = {};
54192
expect(isQwenImageEditModel(ggufModelNoVariant)).toBe(false);
55193
});
56194
});
57195
});
196+
197+
describe('shouldUseCfg', () => {
  // Reset the mocked id counter and the mutable params fixture so tests stay independent.
  afterEach(() => {
    nextId = 0;
    params = { ...defaultParams };
  });

  /**
   * Builds a txt2img graph with the given CFG scale injected through the mocked params slice
   * and reports whether the resulting graph contains negative conditioning.
   */
  const inspectNegativeConditioning = async (cfgScale: number | number[]) => {
    params = { ...defaultParams, cfgScale };

    const { g } = await buildQwenImageGraph({
      generationMode: 'txt2img',
      manager: null,
      state: {
        system: {
          shouldUseNSFWChecker: false,
          shouldUseWatermarker: false,
        },
      } as never,
    });

    const graph = g.getGraph();
    return {
      hasNegativePromptNode: Object.keys(graph.nodes).some((id) => id.startsWith('neg_prompt:')),
      hasNegativeConditioningEdge: graph.edges.some((edge) => edge.destination.field === 'negative_conditioning'),
    };
  };

  describe('negative conditioning is included when cfgScale > 1', () => {
    it.each([4, 1.5, 1.01])('returns true for cfgScale = %s', (cfgScale) => {
      expect(shouldUseCfg(cfgScale)).toBe(true);
    });
  });

  describe('negative conditioning is excluded when cfgScale <= 1', () => {
    it.each([1, 0.5, 0])('returns false for cfgScale = %s', (cfgScale) => {
      expect(shouldUseCfg(cfgScale)).toBe(false);
    });
  });

  describe('array cfgScale (per-step)', () => {
    it('returns true for per-step arrays with values > 1', () => {
      expect(shouldUseCfg([4, 3, 2, 1])).toBe(true);
    });

    it('returns true when any per-step cfg value is > 1', () => {
      expect(shouldUseCfg([1, 1.1, 1])).toBe(true);
      expect(shouldUseCfg([0.5, 2, 0.5])).toBe(true);
    });

    it('returns false when every per-step cfg value is <= 1', () => {
      expect(shouldUseCfg([1, 1, 1])).toBe(false);
      expect(shouldUseCfg([0.5, 0.75, 1])).toBe(false);
    });
  });

  describe('CFG gating regression', () => {
    // These tests inspect the built graph itself, so a regression in how the graph builder
    // wires negative conditioning is caught even if shouldUseCfg keeps returning the right value.
    it('with cfgScale=1, neg_prompt is absent from the graph (no wasted compute)', async () => {
      expect(await inspectNegativeConditioning(1)).toEqual({
        hasNegativePromptNode: false,
        hasNegativeConditioningEdge: false,
      });
    });

    it('with cfgScale=4, neg_prompt is present in the graph for classifier-free guidance', async () => {
      expect(await inspectNegativeConditioning(4)).toEqual({
        hasNegativePromptNode: true,
        hasNegativeConditioningEdge: true,
      });
    });

    it('omits negative conditioning edges from the graph when per-step cfg never exceeds 1', async () => {
      expect(await inspectNegativeConditioning([1, 1, 1])).toEqual({
        hasNegativePromptNode: false,
        hasNegativeConditioningEdge: false,
      });
    });

    it('includes negative conditioning edges in the graph when any per-step cfg exceeds 1', async () => {
      expect(await inspectNegativeConditioning([1, 2, 1])).toEqual({
        hasNegativePromptNode: true,
        hasNegativeConditioningEdge: true,
      });
    });
  });
});
313+
314+
describe('buildQwenImageGraph', () => {
  // Restore the shared fixtures mutated by graph building (params override, id counter).
  afterEach(() => {
    params = { ...defaultParams };
    nextId = 0;
  });

  // The mocked ref-images slice holds an enabled Qwen reference image while the mocked model
  // has variant 'generate'; none of that image may reach the built graph.
  it('does not include hidden Qwen reference images for generate-variant models', async () => {
    const built = await buildQwenImageGraph({
      generationMode: 'txt2img',
      manager: null,
      state: {
        system: {
          shouldUseNSFWChecker: false,
          shouldUseWatermarker: false,
        },
      } as never,
    });

    const { nodes, edges } = built.g.getGraph();
    const destinationFields = edges.map((edge) => edge.destination.field);
    const collectorNodeIds = Object.keys(nodes).filter((id) => id.startsWith('qwen_ref_img_collect:'));

    expect(collectorNodeIds.length > 0).toBe(false);
    expect(destinationFields.some((field) => field === 'reference_images')).toBe(false);
    expect(destinationFields.some((field) => field === 'reference_latents')).toBe(false);
  });
});

invokeai/frontend/web/src/features/nodes/util/graph/generation/buildQwenImageGraph.ts

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,19 @@ export const isQwenImageEditModel = (model: { variant?: string | null } | null):
3838
return 'variant' in model && model.variant === 'edit';
3939
};
4040

41+
/**
 * Determine whether classifier-free guidance (negative conditioning) should be used.
 *
 * CFG is only enabled when cfg_scale > 1. With cfg_scale <= 1, the negative prompt
 * is mathematically unused and the model runs once per step instead of twice.
 *
 * @param cfgScale - A single guidance scale, or a list of per-step guidance scales.
 *   Read-only arrays are accepted; the input is never mutated.
 * @returns `true` when the scalar is strictly greater than 1, or when at least one
 *   per-step value is strictly greater than 1. An empty list yields `false`.
 */
export const shouldUseCfg = (cfgScale: number | readonly number[]): boolean => {
  if (typeof cfgScale === 'number') {
    return cfgScale > 1;
  }
  // For per-step CFG arrays, a single step above 1 is enough to require negative conditioning.
  return cfgScale.some((value) => value > 1);
};
53+
4154
export const buildQwenImageGraph = async (arg: GraphBuilderArg): Promise<GraphBuilderReturn> => {
4255
const { generationMode, state, manager } = arg;
4356

@@ -73,7 +86,7 @@ export const buildQwenImageGraph = async (arg: GraphBuilderArg): Promise<GraphBu
7386
});
7487

7588
// Negative conditioning for CFG (only when cfg_scale > 1)
76-
const useCfg = typeof cfg_scale === 'number' ? cfg_scale > 1 : true;
89+
const useCfg = shouldUseCfg(cfg_scale);
7790
const negCond = useCfg
7891
? g.addNode({
7992
type: 'qwen_image_text_encoder',

0 commit comments

Comments
 (0)