|
1 | 1 | import { deepClone } from 'common/util/deepClone'; |
2 | 2 | import { set } from 'es-toolkit/compat'; |
| 3 | +import type { InvocationTemplate } from 'features/nodes/types/invocation'; |
3 | 4 | import { describe, expect, it } from 'vitest'; |
4 | 5 |
|
5 | 6 | import { add, buildEdge, buildNode, collect, img_resize, main_model_loader, sub, templates } from './testUtils'; |
6 | 7 | import { validateConnection } from './validateConnection'; |
7 | 8 |
|
| 9 | +const ifTemplate: InvocationTemplate = { |
| 10 | + title: 'If', |
| 11 | + type: 'if', |
| 12 | + version: '1.0.0', |
| 13 | + tags: [], |
| 14 | + description: 'Selects between two inputs based on a boolean condition', |
| 15 | + outputType: 'if_output', |
| 16 | + inputs: { |
| 17 | + condition: { |
| 18 | + name: 'condition', |
| 19 | + title: 'Condition', |
| 20 | + required: true, |
| 21 | + description: 'The condition used to select an input', |
| 22 | + fieldKind: 'input', |
| 23 | + input: 'connection', |
| 24 | + ui_hidden: false, |
| 25 | + ui_type: 'BooleanField', |
| 26 | + type: { |
| 27 | + name: 'BooleanField', |
| 28 | + cardinality: 'SINGLE', |
| 29 | + batch: false, |
| 30 | + }, |
| 31 | + default: false, |
| 32 | + }, |
| 33 | + true_input: { |
| 34 | + name: 'true_input', |
| 35 | + title: 'True Input', |
| 36 | + required: false, |
| 37 | + description: 'Selected when condition is true', |
| 38 | + fieldKind: 'input', |
| 39 | + input: 'connection', |
| 40 | + ui_hidden: false, |
| 41 | + ui_type: 'AnyField', |
| 42 | + type: { |
| 43 | + name: 'AnyField', |
| 44 | + cardinality: 'SINGLE', |
| 45 | + batch: false, |
| 46 | + }, |
| 47 | + default: undefined, |
| 48 | + }, |
| 49 | + false_input: { |
| 50 | + name: 'false_input', |
| 51 | + title: 'False Input', |
| 52 | + required: false, |
| 53 | + description: 'Selected when condition is false', |
| 54 | + fieldKind: 'input', |
| 55 | + input: 'connection', |
| 56 | + ui_hidden: false, |
| 57 | + ui_type: 'AnyField', |
| 58 | + type: { |
| 59 | + name: 'AnyField', |
| 60 | + cardinality: 'SINGLE', |
| 61 | + batch: false, |
| 62 | + }, |
| 63 | + default: undefined, |
| 64 | + }, |
| 65 | + }, |
| 66 | + outputs: { |
| 67 | + value: { |
| 68 | + fieldKind: 'output', |
| 69 | + name: 'value', |
| 70 | + title: 'Output', |
| 71 | + description: 'The selected value', |
| 72 | + type: { |
| 73 | + name: 'AnyField', |
| 74 | + cardinality: 'SINGLE', |
| 75 | + batch: false, |
| 76 | + }, |
| 77 | + ui_hidden: false, |
| 78 | + ui_type: 'AnyField', |
| 79 | + }, |
| 80 | + }, |
| 81 | + useCache: true, |
| 82 | + nodePack: 'invokeai', |
| 83 | + classification: 'stable', |
| 84 | +}; |
| 85 | + |
| 86 | +const floatOutputTemplate: InvocationTemplate = { |
| 87 | + title: 'Float Output', |
| 88 | + type: 'float_output', |
| 89 | + version: '1.0.0', |
| 90 | + tags: [], |
| 91 | + description: 'Outputs a float', |
| 92 | + outputType: 'float_output', |
| 93 | + inputs: {}, |
| 94 | + outputs: { |
| 95 | + value: { |
| 96 | + fieldKind: 'output', |
| 97 | + name: 'value', |
| 98 | + title: 'Value', |
| 99 | + description: 'Float value', |
| 100 | + type: { |
| 101 | + name: 'FloatField', |
| 102 | + cardinality: 'SINGLE', |
| 103 | + batch: false, |
| 104 | + }, |
| 105 | + ui_hidden: false, |
| 106 | + ui_type: 'FloatField', |
| 107 | + }, |
| 108 | + }, |
| 109 | + useCache: true, |
| 110 | + nodePack: 'invokeai', |
| 111 | + classification: 'stable', |
| 112 | +}; |
| 113 | + |
| 114 | +const integerCollectionOutputTemplate: InvocationTemplate = { |
| 115 | + title: 'Integer Collection Output', |
| 116 | + type: 'integer_collection_output', |
| 117 | + version: '1.0.0', |
| 118 | + tags: [], |
| 119 | + description: 'Outputs an integer collection', |
| 120 | + outputType: 'integer_collection_output', |
| 121 | + inputs: {}, |
| 122 | + outputs: { |
| 123 | + value: { |
| 124 | + fieldKind: 'output', |
| 125 | + name: 'value', |
| 126 | + title: 'Value', |
| 127 | + description: 'Integer collection value', |
| 128 | + type: { |
| 129 | + name: 'IntegerField', |
| 130 | + cardinality: 'COLLECTION', |
| 131 | + batch: false, |
| 132 | + }, |
| 133 | + ui_hidden: false, |
| 134 | + ui_type: 'IntegerField', |
| 135 | + }, |
| 136 | + }, |
| 137 | + useCache: true, |
| 138 | + nodePack: 'invokeai', |
| 139 | + classification: 'stable', |
| 140 | +}; |
| 141 | + |
8 | 142 | describe(validateConnection.name, () => { |
9 | 143 | it('should reject invalid connection to self', () => { |
10 | 144 | const c = { source: 'add', sourceHandle: 'value', target: 'add', targetHandle: 'a' }; |
@@ -201,6 +335,118 @@ describe(validateConnection.name, () => { |
201 | 335 | expect(r).toEqual('nodes.fieldTypesMustMatch'); |
202 | 336 | }); |
203 | 337 |
|
| 338 | + it('should reject mismatched types between if node branch inputs', () => { |
| 339 | + const n1 = buildNode(add); |
| 340 | + const n2 = buildNode(img_resize); |
| 341 | + const n3 = buildNode(ifTemplate); |
| 342 | + const nodes = [n1, n2, n3]; |
| 343 | + const e1 = buildEdge(n1.id, 'value', n3.id, 'true_input'); |
| 344 | + const edges = [e1]; |
| 345 | + const c = { source: n2.id, sourceHandle: 'image', target: n3.id, targetHandle: 'false_input' }; |
| 346 | + const r = validateConnection(c, nodes, edges, { ...templates, if: ifTemplate }, null); |
| 347 | + expect(r).toEqual('nodes.fieldTypesMustMatch'); |
| 348 | + }); |
| 349 | + |
| 350 | + it('should reject mismatched types between if node branch inputs regardless of branch order', () => { |
| 351 | + const n1 = buildNode(add); |
| 352 | + const n2 = buildNode(img_resize); |
| 353 | + const n3 = buildNode(ifTemplate); |
| 354 | + const nodes = [n1, n2, n3]; |
| 355 | + const e1 = buildEdge(n1.id, 'value', n3.id, 'false_input'); |
| 356 | + const edges = [e1]; |
| 357 | + const c = { source: n2.id, sourceHandle: 'image', target: n3.id, targetHandle: 'true_input' }; |
| 358 | + const r = validateConnection(c, nodes, edges, { ...templates, if: ifTemplate }, null); |
| 359 | + expect(r).toEqual('nodes.fieldTypesMustMatch'); |
| 360 | + }); |
| 361 | + |
| 362 | + it('should accept convertible types between if node branch inputs', () => { |
| 363 | + const n1 = buildNode(add); |
| 364 | + const n2 = buildNode(sub); |
| 365 | + const n3 = buildNode(ifTemplate); |
| 366 | + const nodes = [n1, n2, n3]; |
| 367 | + const e1 = buildEdge(n1.id, 'value', n3.id, 'true_input'); |
| 368 | + const edges = [e1]; |
| 369 | + const c = { source: n2.id, sourceHandle: 'value', target: n3.id, targetHandle: 'false_input' }; |
| 370 | + const r = validateConnection(c, nodes, edges, { ...templates, if: ifTemplate }, null); |
| 371 | + expect(r).toEqual(null); |
| 372 | + }); |
| 373 | + |
| 374 | + it('should accept one-way-convertible types between if node branch inputs in either connection order', () => { |
| 375 | + const n1 = buildNode(add); |
| 376 | + const n2 = buildNode(floatOutputTemplate); |
| 377 | + const n3 = buildNode(ifTemplate); |
| 378 | + const nodes = [n1, n2, n3]; |
| 379 | + const e1 = buildEdge(n1.id, 'value', n3.id, 'false_input'); |
| 380 | + const edges = [e1]; |
| 381 | + const c = { source: n2.id, sourceHandle: 'value', target: n3.id, targetHandle: 'true_input' }; |
| 382 | + const r = validateConnection( |
| 383 | + c, |
| 384 | + nodes, |
| 385 | + edges, |
| 386 | + { ...templates, if: ifTemplate, float_output: floatOutputTemplate }, |
| 387 | + null |
| 388 | + ); |
| 389 | + expect(r).toEqual(null); |
| 390 | + }); |
| 391 | + |
| 392 | + it('should accept SINGLE and COLLECTION of the same type between if node branch inputs', () => { |
| 393 | + const n1 = buildNode(add); |
| 394 | + const n2 = buildNode(integerCollectionOutputTemplate); |
| 395 | + const n3 = buildNode(ifTemplate); |
| 396 | + const nodes = [n1, n2, n3]; |
| 397 | + const e1 = buildEdge(n1.id, 'value', n3.id, 'true_input'); |
| 398 | + const edges = [e1]; |
| 399 | + const c = { source: n2.id, sourceHandle: 'value', target: n3.id, targetHandle: 'false_input' }; |
| 400 | + const r = validateConnection( |
| 401 | + c, |
| 402 | + nodes, |
| 403 | + edges, |
| 404 | + { ...templates, if: ifTemplate, integer_collection_output: integerCollectionOutputTemplate }, |
| 405 | + null |
| 406 | + ); |
| 407 | + expect(r).toEqual(null); |
| 408 | + }); |
| 409 | + |
| 410 | + it('should accept if output to collection input when both if branch inputs are collections of matching type', () => { |
| 411 | + const n1 = buildNode(integerCollectionOutputTemplate); |
| 412 | + const n2 = buildNode(integerCollectionOutputTemplate); |
| 413 | + const n3 = buildNode(ifTemplate); |
| 414 | + const n4 = buildNode(templates.iterate!); |
| 415 | + const nodes = [n1, n2, n3, n4]; |
| 416 | + const e1 = buildEdge(n1.id, 'value', n3.id, 'true_input'); |
| 417 | + const e2 = buildEdge(n2.id, 'value', n3.id, 'false_input'); |
| 418 | + const edges = [e1, e2]; |
| 419 | + const c = { source: n3.id, sourceHandle: 'value', target: n4.id, targetHandle: 'collection' }; |
| 420 | + const r = validateConnection( |
| 421 | + c, |
| 422 | + nodes, |
| 423 | + edges, |
| 424 | + { ...templates, if: ifTemplate, integer_collection_output: integerCollectionOutputTemplate }, |
| 425 | + null |
| 426 | + ); |
| 427 | + expect(r).toEqual(null); |
| 428 | + }); |
| 429 | + |
| 430 | + it('should reject if output to collection input when if branch inputs are not both collection-compatible', () => { |
| 431 | + const n1 = buildNode(add); |
| 432 | + const n2 = buildNode(integerCollectionOutputTemplate); |
| 433 | + const n3 = buildNode(ifTemplate); |
| 434 | + const n4 = buildNode(templates.iterate!); |
| 435 | + const nodes = [n1, n2, n3, n4]; |
| 436 | + const e1 = buildEdge(n1.id, 'value', n3.id, 'true_input'); |
| 437 | + const e2 = buildEdge(n2.id, 'value', n3.id, 'false_input'); |
| 438 | + const edges = [e1, e2]; |
| 439 | + const c = { source: n3.id, sourceHandle: 'value', target: n4.id, targetHandle: 'collection' }; |
| 440 | + const r = validateConnection( |
| 441 | + c, |
| 442 | + nodes, |
| 443 | + edges, |
| 444 | + { ...templates, if: ifTemplate, integer_collection_output: integerCollectionOutputTemplate }, |
| 445 | + null |
| 446 | + ); |
| 447 | + expect(r).toEqual('nodes.fieldTypesMustMatch'); |
| 448 | + }); |
| 449 | + |
204 | 450 | it('should reject connections that would create cycles', () => { |
205 | 451 | const n1 = buildNode(add); |
206 | 452 | const n2 = buildNode(sub); |
|
0 commit comments