Skip to content

Commit e6f2980

Browse files
authored
Added If node and ability to link an Any output to a node input if cardinality matches (#8869)
* Added If node * Added stricter type checking on inputs * feat(nodes): make if-node type checks cardinality-aware without loosening global AnyField * chore: typegen
1 parent 01c67c5 commit e6f2980

7 files changed

Lines changed: 514 additions & 9 deletions

File tree

invokeai/app/invocations/logic.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
from typing import Any, Optional
2+
3+
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
4+
from invokeai.app.invocations.fields import InputField, OutputField, UIType
5+
from invokeai.app.services.shared.invocation_context import InvocationContext
6+
7+
8+
@invocation_output("if_output")
9+
class IfInvocationOutput(BaseInvocationOutput):
10+
value: Optional[Any] = OutputField(
11+
default=None, description="The selected value", title="Output", ui_type=UIType.Any
12+
)
13+
14+
15+
@invocation("if", title="If", tags=["logic", "conditional"], category="logic", version="1.0.0")
16+
class IfInvocation(BaseInvocation):
17+
"""Selects between two optional inputs based on a boolean condition."""
18+
19+
condition: bool = InputField(default=False, description="The condition used to select an input", title="Condition")
20+
true_input: Optional[Any] = InputField(
21+
default=None,
22+
description="Selected when the condition is true",
23+
title="True Input",
24+
ui_type=UIType.Any,
25+
)
26+
false_input: Optional[Any] = InputField(
27+
default=None,
28+
description="Selected when the condition is false",
29+
title="False Input",
30+
ui_type=UIType.Any,
31+
)
32+
33+
def invoke(self, context: InvocationContext) -> IfInvocationOutput:
34+
return IfInvocationOutput(value=self.true_input if self.condition else self.false_input)

invokeai/frontend/web/src/features/nodes/store/util/validateConnection.test.ts

Lines changed: 246 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,144 @@
11
import { deepClone } from 'common/util/deepClone';
22
import { set } from 'es-toolkit/compat';
3+
import type { InvocationTemplate } from 'features/nodes/types/invocation';
34
import { describe, expect, it } from 'vitest';
45

56
import { add, buildEdge, buildNode, collect, img_resize, main_model_loader, sub, templates } from './testUtils';
67
import { validateConnection } from './validateConnection';
78

9+
const ifTemplate: InvocationTemplate = {
10+
title: 'If',
11+
type: 'if',
12+
version: '1.0.0',
13+
tags: [],
14+
description: 'Selects between two inputs based on a boolean condition',
15+
outputType: 'if_output',
16+
inputs: {
17+
condition: {
18+
name: 'condition',
19+
title: 'Condition',
20+
required: true,
21+
description: 'The condition used to select an input',
22+
fieldKind: 'input',
23+
input: 'connection',
24+
ui_hidden: false,
25+
ui_type: 'BooleanField',
26+
type: {
27+
name: 'BooleanField',
28+
cardinality: 'SINGLE',
29+
batch: false,
30+
},
31+
default: false,
32+
},
33+
true_input: {
34+
name: 'true_input',
35+
title: 'True Input',
36+
required: false,
37+
description: 'Selected when condition is true',
38+
fieldKind: 'input',
39+
input: 'connection',
40+
ui_hidden: false,
41+
ui_type: 'AnyField',
42+
type: {
43+
name: 'AnyField',
44+
cardinality: 'SINGLE',
45+
batch: false,
46+
},
47+
default: undefined,
48+
},
49+
false_input: {
50+
name: 'false_input',
51+
title: 'False Input',
52+
required: false,
53+
description: 'Selected when condition is false',
54+
fieldKind: 'input',
55+
input: 'connection',
56+
ui_hidden: false,
57+
ui_type: 'AnyField',
58+
type: {
59+
name: 'AnyField',
60+
cardinality: 'SINGLE',
61+
batch: false,
62+
},
63+
default: undefined,
64+
},
65+
},
66+
outputs: {
67+
value: {
68+
fieldKind: 'output',
69+
name: 'value',
70+
title: 'Output',
71+
description: 'The selected value',
72+
type: {
73+
name: 'AnyField',
74+
cardinality: 'SINGLE',
75+
batch: false,
76+
},
77+
ui_hidden: false,
78+
ui_type: 'AnyField',
79+
},
80+
},
81+
useCache: true,
82+
nodePack: 'invokeai',
83+
classification: 'stable',
84+
};
85+
86+
const floatOutputTemplate: InvocationTemplate = {
87+
title: 'Float Output',
88+
type: 'float_output',
89+
version: '1.0.0',
90+
tags: [],
91+
description: 'Outputs a float',
92+
outputType: 'float_output',
93+
inputs: {},
94+
outputs: {
95+
value: {
96+
fieldKind: 'output',
97+
name: 'value',
98+
title: 'Value',
99+
description: 'Float value',
100+
type: {
101+
name: 'FloatField',
102+
cardinality: 'SINGLE',
103+
batch: false,
104+
},
105+
ui_hidden: false,
106+
ui_type: 'FloatField',
107+
},
108+
},
109+
useCache: true,
110+
nodePack: 'invokeai',
111+
classification: 'stable',
112+
};
113+
114+
const integerCollectionOutputTemplate: InvocationTemplate = {
115+
title: 'Integer Collection Output',
116+
type: 'integer_collection_output',
117+
version: '1.0.0',
118+
tags: [],
119+
description: 'Outputs an integer collection',
120+
outputType: 'integer_collection_output',
121+
inputs: {},
122+
outputs: {
123+
value: {
124+
fieldKind: 'output',
125+
name: 'value',
126+
title: 'Value',
127+
description: 'Integer collection value',
128+
type: {
129+
name: 'IntegerField',
130+
cardinality: 'COLLECTION',
131+
batch: false,
132+
},
133+
ui_hidden: false,
134+
ui_type: 'IntegerField',
135+
},
136+
},
137+
useCache: true,
138+
nodePack: 'invokeai',
139+
classification: 'stable',
140+
};
141+
8142
describe(validateConnection.name, () => {
9143
it('should reject invalid connection to self', () => {
10144
const c = { source: 'add', sourceHandle: 'value', target: 'add', targetHandle: 'a' };
@@ -201,6 +335,118 @@ describe(validateConnection.name, () => {
201335
expect(r).toEqual('nodes.fieldTypesMustMatch');
202336
});
203337

338+
it('should reject mismatched types between if node branch inputs', () => {
339+
const n1 = buildNode(add);
340+
const n2 = buildNode(img_resize);
341+
const n3 = buildNode(ifTemplate);
342+
const nodes = [n1, n2, n3];
343+
const e1 = buildEdge(n1.id, 'value', n3.id, 'true_input');
344+
const edges = [e1];
345+
const c = { source: n2.id, sourceHandle: 'image', target: n3.id, targetHandle: 'false_input' };
346+
const r = validateConnection(c, nodes, edges, { ...templates, if: ifTemplate }, null);
347+
expect(r).toEqual('nodes.fieldTypesMustMatch');
348+
});
349+
350+
it('should reject mismatched types between if node branch inputs regardless of branch order', () => {
351+
const n1 = buildNode(add);
352+
const n2 = buildNode(img_resize);
353+
const n3 = buildNode(ifTemplate);
354+
const nodes = [n1, n2, n3];
355+
const e1 = buildEdge(n1.id, 'value', n3.id, 'false_input');
356+
const edges = [e1];
357+
const c = { source: n2.id, sourceHandle: 'image', target: n3.id, targetHandle: 'true_input' };
358+
const r = validateConnection(c, nodes, edges, { ...templates, if: ifTemplate }, null);
359+
expect(r).toEqual('nodes.fieldTypesMustMatch');
360+
});
361+
362+
it('should accept convertible types between if node branch inputs', () => {
363+
const n1 = buildNode(add);
364+
const n2 = buildNode(sub);
365+
const n3 = buildNode(ifTemplate);
366+
const nodes = [n1, n2, n3];
367+
const e1 = buildEdge(n1.id, 'value', n3.id, 'true_input');
368+
const edges = [e1];
369+
const c = { source: n2.id, sourceHandle: 'value', target: n3.id, targetHandle: 'false_input' };
370+
const r = validateConnection(c, nodes, edges, { ...templates, if: ifTemplate }, null);
371+
expect(r).toEqual(null);
372+
});
373+
374+
it('should accept one-way-convertible types between if node branch inputs in either connection order', () => {
375+
const n1 = buildNode(add);
376+
const n2 = buildNode(floatOutputTemplate);
377+
const n3 = buildNode(ifTemplate);
378+
const nodes = [n1, n2, n3];
379+
const e1 = buildEdge(n1.id, 'value', n3.id, 'false_input');
380+
const edges = [e1];
381+
const c = { source: n2.id, sourceHandle: 'value', target: n3.id, targetHandle: 'true_input' };
382+
const r = validateConnection(
383+
c,
384+
nodes,
385+
edges,
386+
{ ...templates, if: ifTemplate, float_output: floatOutputTemplate },
387+
null
388+
);
389+
expect(r).toEqual(null);
390+
});
391+
392+
it('should accept SINGLE and COLLECTION of the same type between if node branch inputs', () => {
393+
const n1 = buildNode(add);
394+
const n2 = buildNode(integerCollectionOutputTemplate);
395+
const n3 = buildNode(ifTemplate);
396+
const nodes = [n1, n2, n3];
397+
const e1 = buildEdge(n1.id, 'value', n3.id, 'true_input');
398+
const edges = [e1];
399+
const c = { source: n2.id, sourceHandle: 'value', target: n3.id, targetHandle: 'false_input' };
400+
const r = validateConnection(
401+
c,
402+
nodes,
403+
edges,
404+
{ ...templates, if: ifTemplate, integer_collection_output: integerCollectionOutputTemplate },
405+
null
406+
);
407+
expect(r).toEqual(null);
408+
});
409+
410+
it('should accept if output to collection input when both if branch inputs are collections of matching type', () => {
411+
const n1 = buildNode(integerCollectionOutputTemplate);
412+
const n2 = buildNode(integerCollectionOutputTemplate);
413+
const n3 = buildNode(ifTemplate);
414+
const n4 = buildNode(templates.iterate!);
415+
const nodes = [n1, n2, n3, n4];
416+
const e1 = buildEdge(n1.id, 'value', n3.id, 'true_input');
417+
const e2 = buildEdge(n2.id, 'value', n3.id, 'false_input');
418+
const edges = [e1, e2];
419+
const c = { source: n3.id, sourceHandle: 'value', target: n4.id, targetHandle: 'collection' };
420+
const r = validateConnection(
421+
c,
422+
nodes,
423+
edges,
424+
{ ...templates, if: ifTemplate, integer_collection_output: integerCollectionOutputTemplate },
425+
null
426+
);
427+
expect(r).toEqual(null);
428+
});
429+
430+
it('should reject if output to collection input when if branch inputs are not both collection-compatible', () => {
431+
const n1 = buildNode(add);
432+
const n2 = buildNode(integerCollectionOutputTemplate);
433+
const n3 = buildNode(ifTemplate);
434+
const n4 = buildNode(templates.iterate!);
435+
const nodes = [n1, n2, n3, n4];
436+
const e1 = buildEdge(n1.id, 'value', n3.id, 'true_input');
437+
const e2 = buildEdge(n2.id, 'value', n3.id, 'false_input');
438+
const edges = [e1, e2];
439+
const c = { source: n3.id, sourceHandle: 'value', target: n4.id, targetHandle: 'collection' };
440+
const r = validateConnection(
441+
c,
442+
nodes,
443+
edges,
444+
{ ...templates, if: ifTemplate, integer_collection_output: integerCollectionOutputTemplate },
445+
null
446+
);
447+
expect(r).toEqual('nodes.fieldTypesMustMatch');
448+
});
449+
204450
it('should reject connections that would create cycles', () => {
205451
const n1 = buildNode(add);
206452
const n2 = buildNode(sub);

0 commit comments

Comments
 (0)