|
1 | | -import { describe, expect, it } from 'vitest'; |
| 1 | +import { afterEach, describe, expect, it, vi } from 'vitest'; |
2 | 2 |
|
3 | | -import { isQwenImageEditModel } from './buildQwenImageGraph'; |
| 3 | +vi.mock('app/logging/logger', () => ({ |
| 4 | + logger: () => ({ |
| 5 | + debug: vi.fn(), |
| 6 | + }), |
| 7 | +})); |
| 8 | + |
| 9 | +let nextId = 0; |
| 10 | +vi.mock('features/controlLayers/konva/util', () => ({ |
| 11 | + getPrefixedId: (prefix: string) => `${prefix}:${nextId++}`, |
| 12 | +})); |
| 13 | + |
| 14 | +const model = { |
| 15 | + key: 'qwen-model', |
| 16 | + hash: 'qwen-hash', |
| 17 | + name: 'Qwen Image Generate', |
| 18 | + base: 'qwen-image', |
| 19 | + type: 'main', |
| 20 | + variant: 'generate', |
| 21 | +}; |
| 22 | + |
| 23 | +const defaultParams: { |
| 24 | + cfgScale: number | number[]; |
| 25 | + steps: number; |
| 26 | + qwenImageComponentSource: null; |
| 27 | + qwenImageQuantization: string; |
| 28 | + qwenImageShift: number; |
| 29 | +} = { |
| 30 | + cfgScale: 4, |
| 31 | + steps: 20, |
| 32 | + qwenImageComponentSource: null, |
| 33 | + qwenImageQuantization: 'none', |
| 34 | + qwenImageShift: 1, |
| 35 | +}; |
| 36 | + |
| 37 | +let params = { ...defaultParams }; |
| 38 | + |
| 39 | +const refImagesSlice = { |
| 40 | + entities: [ |
| 41 | + { |
| 42 | + id: 'ref-image-1', |
| 43 | + isEnabled: true, |
| 44 | + config: { |
| 45 | + type: 'qwen_image_reference_image', |
| 46 | + image: { |
| 47 | + original: { |
| 48 | + image: { |
| 49 | + image_name: 'reference.png', |
| 50 | + width: 512, |
| 51 | + height: 512, |
| 52 | + }, |
| 53 | + }, |
| 54 | + }, |
| 55 | + }, |
| 56 | + }, |
| 57 | + ], |
| 58 | +}; |
| 59 | + |
| 60 | +vi.mock('features/controlLayers/store/paramsSlice', () => ({ |
| 61 | + selectMainModelConfig: vi.fn(() => model), |
| 62 | + selectParamsSlice: vi.fn(() => params), |
| 63 | +})); |
| 64 | + |
| 65 | +vi.mock('features/controlLayers/store/refImagesSlice', () => ({ |
| 66 | + selectRefImagesSlice: vi.fn(() => refImagesSlice), |
| 67 | +})); |
| 68 | + |
| 69 | +vi.mock('features/controlLayers/store/selectors', () => ({ |
| 70 | + selectCanvasMetadata: vi.fn(() => ({})), |
| 71 | +})); |
| 72 | + |
| 73 | +vi.mock('features/controlLayers/store/types', () => ({ |
| 74 | + isQwenImageReferenceImageConfig: vi.fn((config: { type?: string }) => config.type === 'qwen_image_reference_image'), |
| 75 | +})); |
| 76 | + |
| 77 | +vi.mock('features/controlLayers/store/validators', () => ({ |
| 78 | + getGlobalReferenceImageWarnings: vi.fn(() => []), |
| 79 | +})); |
| 80 | + |
| 81 | +vi.mock('features/metadata/util/modelFetchingHelpers', () => ({ |
| 82 | + fetchModelConfigWithTypeGuard: vi.fn(() => Promise.resolve(model)), |
| 83 | +})); |
| 84 | + |
| 85 | +vi.mock('features/nodes/types/common', async () => { |
| 86 | + const actual = await vi.importActual('features/nodes/types/common'); |
| 87 | + return { |
| 88 | + ...actual, |
| 89 | + zImageField: { |
| 90 | + parse: vi.fn((image) => image), |
| 91 | + }, |
| 92 | + }; |
| 93 | +}); |
| 94 | + |
| 95 | +vi.mock('features/nodes/util/graph/generation/addImageToImage', () => ({ |
| 96 | + addImageToImage: vi.fn(), |
| 97 | +})); |
| 98 | + |
| 99 | +vi.mock('features/nodes/util/graph/generation/addInpaint', () => ({ |
| 100 | + addInpaint: vi.fn(), |
| 101 | +})); |
| 102 | + |
| 103 | +vi.mock('features/nodes/util/graph/generation/addNSFWChecker', () => ({ |
| 104 | + addNSFWChecker: vi.fn((_g, node) => node), |
| 105 | +})); |
| 106 | + |
| 107 | +vi.mock('features/nodes/util/graph/generation/addOutpaint', () => ({ |
| 108 | + addOutpaint: vi.fn(), |
| 109 | +})); |
| 110 | + |
| 111 | +vi.mock('features/nodes/util/graph/generation/addQwenImageLoRAs', () => ({ |
| 112 | + addQwenImageLoRAs: vi.fn(), |
| 113 | +})); |
| 114 | + |
| 115 | +vi.mock('features/nodes/util/graph/generation/addTextToImage', () => ({ |
| 116 | + addTextToImage: vi.fn(({ l2i }) => l2i), |
| 117 | +})); |
| 118 | + |
| 119 | +vi.mock('features/nodes/util/graph/generation/addWatermarker', () => ({ |
| 120 | + addWatermarker: vi.fn((_g, node) => node), |
| 121 | +})); |
| 122 | + |
| 123 | +vi.mock('features/nodes/util/graph/graphBuilderUtils', () => ({ |
| 124 | + selectCanvasOutputFields: vi.fn(() => ({})), |
| 125 | + selectPresetModifiedPrompts: vi.fn(() => ({ |
| 126 | + positive: 'a prompt', |
| 127 | + negative: 'a negative prompt', |
| 128 | + })), |
| 129 | +})); |
| 130 | + |
| 131 | +vi.mock('features/ui/store/uiSelectors', () => ({ |
| 132 | + selectActiveTab: vi.fn(() => 'generation'), |
| 133 | +})); |
| 134 | + |
| 135 | +vi.mock('services/api/types', async () => { |
| 136 | + const actual = await vi.importActual('services/api/types'); |
| 137 | + return { |
| 138 | + ...actual, |
| 139 | + isNonRefinerMainModelConfig: vi.fn(() => true), |
| 140 | + }; |
| 141 | +}); |
| 142 | + |
| 143 | +import { buildQwenImageGraph, isQwenImageEditModel, shouldUseCfg } from './buildQwenImageGraph'; |
4 | 144 |
|
5 | 145 | describe('isQwenImageEditModel', () => { |
| 146 | + afterEach(() => { |
| 147 | + nextId = 0; |
| 148 | + params = { ...defaultParams }; |
| 149 | + }); |
| 150 | + |
6 | 151 | it('returns true for edit variant', () => { |
7 | 152 | expect(isQwenImageEditModel({ variant: 'edit' })).toBe(true); |
8 | 153 | }); |
@@ -35,23 +180,163 @@ describe('isQwenImageEditModel', () => { |
35 | 180 |
|
36 | 181 | describe('reference image filtering regression', () => { |
37 | 182 | it('prevents reference images from leaking to generate models when switching from edit', () => { |
38 | | - // Simulate: user was using an edit model (variant='edit') with reference images, |
39 | | - // then switches to a generate model (variant='generate'). |
40 | | - // The generate model should NOT receive reference images. |
41 | 183 | const editModel = { variant: 'edit' as const }; |
42 | 184 | const generateModel = { variant: 'generate' as const }; |
43 | 185 |
|
44 | | - // Edit model: reference images should be collected |
45 | 186 | expect(isQwenImageEditModel(editModel)).toBe(true); |
46 | | - |
47 | | - // Generate model: reference images must NOT be collected, even if they exist in state |
48 | 187 | expect(isQwenImageEditModel(generateModel)).toBe(false); |
49 | 188 | }); |
50 | 189 |
|
51 | 190 | it('prevents reference images from leaking to GGUF models without variant', () => { |
52 | | - // GGUF models installed without a variant field default to generate behavior |
53 | 191 | const ggufModelNoVariant = {}; |
54 | 192 | expect(isQwenImageEditModel(ggufModelNoVariant)).toBe(false); |
55 | 193 | }); |
56 | 194 | }); |
57 | 195 | }); |
| 196 | + |
| 197 | +describe('shouldUseCfg', () => { |
| 198 | + afterEach(() => { |
| 199 | + nextId = 0; |
| 200 | + params = { ...defaultParams }; |
| 201 | + }); |
| 202 | + |
| 203 | + describe('negative conditioning is included when cfgScale > 1', () => { |
| 204 | + it('returns true for cfgScale = 4', () => { |
| 205 | + expect(shouldUseCfg(4)).toBe(true); |
| 206 | + }); |
| 207 | + |
| 208 | + it('returns true for cfgScale = 1.5', () => { |
| 209 | + expect(shouldUseCfg(1.5)).toBe(true); |
| 210 | + }); |
| 211 | + |
| 212 | + it('returns true for cfgScale = 1.01', () => { |
| 213 | + expect(shouldUseCfg(1.01)).toBe(true); |
| 214 | + }); |
| 215 | + }); |
| 216 | + |
| 217 | + describe('negative conditioning is excluded when cfgScale <= 1', () => { |
| 218 | + it('returns false for cfgScale = 1', () => { |
| 219 | + expect(shouldUseCfg(1)).toBe(false); |
| 220 | + }); |
| 221 | + |
| 222 | + it('returns false for cfgScale = 0.5', () => { |
| 223 | + expect(shouldUseCfg(0.5)).toBe(false); |
| 224 | + }); |
| 225 | + |
| 226 | + it('returns false for cfgScale = 0', () => { |
| 227 | + expect(shouldUseCfg(0)).toBe(false); |
| 228 | + }); |
| 229 | + }); |
| 230 | + |
| 231 | + describe('array cfgScale (per-step)', () => { |
| 232 | + it('returns true for per-step arrays with values > 1', () => { |
| 233 | + expect(shouldUseCfg([4, 3, 2, 1])).toBe(true); |
| 234 | + }); |
| 235 | + |
| 236 | + it('returns true when any per-step cfg value is > 1', () => { |
| 237 | + expect(shouldUseCfg([1, 1.1, 1])).toBe(true); |
| 238 | + expect(shouldUseCfg([0.5, 2, 0.5])).toBe(true); |
| 239 | + }); |
| 240 | + |
| 241 | + it('returns false when every per-step cfg value is <= 1', () => { |
| 242 | + expect(shouldUseCfg([1, 1, 1])).toBe(false); |
| 243 | + expect(shouldUseCfg([0.5, 0.75, 1])).toBe(false); |
| 244 | + }); |
| 245 | + }); |
| 246 | + |
| 247 | + describe('CFG gating regression', () => { |
| 248 | + it('with cfgScale=1, neg_prompt is absent from the graph (no wasted compute)', () => { |
| 249 | + expect(shouldUseCfg(1)).toBe(false); |
| 250 | + }); |
| 251 | + |
| 252 | + it('with cfgScale=4, neg_prompt is present in the graph for classifier-free guidance', () => { |
| 253 | + expect(shouldUseCfg(4)).toBe(true); |
| 254 | + }); |
| 255 | + |
| 256 | + it('omits negative conditioning edges from the graph when per-step cfg never exceeds 1', async () => { |
| 257 | + params = { |
| 258 | + ...defaultParams, |
| 259 | + cfgScale: [1, 1, 1], |
| 260 | + }; |
| 261 | + |
| 262 | + const { g } = await buildQwenImageGraph({ |
| 263 | + generationMode: 'txt2img', |
| 264 | + manager: null, |
| 265 | + state: { |
| 266 | + system: { |
| 267 | + shouldUseNSFWChecker: false, |
| 268 | + shouldUseWatermarker: false, |
| 269 | + }, |
| 270 | + } as never, |
| 271 | + }); |
| 272 | + |
| 273 | + const graph = g.getGraph(); |
| 274 | + const nodeIds = Object.keys(graph.nodes); |
| 275 | + const hasNegativePromptNode = nodeIds.some((id) => id.startsWith('neg_prompt:')); |
| 276 | + const hasNegativeConditioningEdge = graph.edges.some( |
| 277 | + (edge) => edge.destination.field === 'negative_conditioning' |
| 278 | + ); |
| 279 | + |
| 280 | + expect(hasNegativePromptNode).toBe(false); |
| 281 | + expect(hasNegativeConditioningEdge).toBe(false); |
| 282 | + }); |
| 283 | + |
| 284 | + it('includes negative conditioning edges in the graph when any per-step cfg exceeds 1', async () => { |
| 285 | + params = { |
| 286 | + ...defaultParams, |
| 287 | + cfgScale: [1, 2, 1], |
| 288 | + }; |
| 289 | + |
| 290 | + const { g } = await buildQwenImageGraph({ |
| 291 | + generationMode: 'txt2img', |
| 292 | + manager: null, |
| 293 | + state: { |
| 294 | + system: { |
| 295 | + shouldUseNSFWChecker: false, |
| 296 | + shouldUseWatermarker: false, |
| 297 | + }, |
| 298 | + } as never, |
| 299 | + }); |
| 300 | + |
| 301 | + const graph = g.getGraph(); |
| 302 | + const nodeIds = Object.keys(graph.nodes); |
| 303 | + const hasNegativePromptNode = nodeIds.some((id) => id.startsWith('neg_prompt:')); |
| 304 | + const hasNegativeConditioningEdge = graph.edges.some( |
| 305 | + (edge) => edge.destination.field === 'negative_conditioning' |
| 306 | + ); |
| 307 | + |
| 308 | + expect(hasNegativePromptNode).toBe(true); |
| 309 | + expect(hasNegativeConditioningEdge).toBe(true); |
| 310 | + }); |
| 311 | + }); |
| 312 | +}); |
| 313 | + |
| 314 | +describe('buildQwenImageGraph', () => { |
| 315 | + afterEach(() => { |
| 316 | + nextId = 0; |
| 317 | + params = { ...defaultParams }; |
| 318 | + }); |
| 319 | + |
| 320 | + it('does not include hidden Qwen reference images for generate-variant models', async () => { |
| 321 | + const { g } = await buildQwenImageGraph({ |
| 322 | + generationMode: 'txt2img', |
| 323 | + manager: null, |
| 324 | + state: { |
| 325 | + system: { |
| 326 | + shouldUseNSFWChecker: false, |
| 327 | + shouldUseWatermarker: false, |
| 328 | + }, |
| 329 | + } as never, |
| 330 | + }); |
| 331 | + |
| 332 | + const graph = g.getGraph(); |
| 333 | + const nodeIds = Object.keys(graph.nodes); |
| 334 | + const hasReferenceCollectionNode = nodeIds.some((id) => id.startsWith('qwen_ref_img_collect:')); |
| 335 | + const hasReferenceImagesEdge = graph.edges.some((edge) => edge.destination.field === 'reference_images'); |
| 336 | + const hasReferenceLatentsEdge = graph.edges.some((edge) => edge.destination.field === 'reference_latents'); |
| 337 | + |
| 338 | + expect(hasReferenceCollectionNode).toBe(false); |
| 339 | + expect(hasReferenceImagesEdge).toBe(false); |
| 340 | + expect(hasReferenceLatentsEdge).toBe(false); |
| 341 | + }); |
| 342 | +}); |
0 commit comments