Skip to content

Commit e65a67f

Browse files
committed
fix(ui): stabilize workflow node execution state updates
1 parent 33c288a commit e65a67f

File tree

4 files changed

+305
-30
lines changed

4 files changed

+305
-30
lines changed
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
import { deepClone } from 'common/util/deepClone';
2+
import type { NodeExecutionState } from 'features/nodes/types/invocation';
3+
import { zNodeStatus } from 'features/nodes/types/invocation';
4+
import type { S } from 'services/api/types';
5+
6+
const getInvocationKey = (data: { item_id: number; invocation: { id: string } }) =>
7+
`${data.item_id}:${data.invocation.id}`;
8+
9+
export const getUpdatedNodeExecutionStateOnInvocationStarted = (
10+
nodeExecutionState: NodeExecutionState | undefined,
11+
data: S['InvocationStartedEvent'],
12+
completedInvocationKeys: Set<string>
13+
) => {
14+
if (!nodeExecutionState) {
15+
return;
16+
}
17+
18+
if (completedInvocationKeys.has(getInvocationKey(data))) {
19+
return;
20+
}
21+
22+
const _nodeExecutionState = deepClone(nodeExecutionState);
23+
_nodeExecutionState.status = zNodeStatus.enum.IN_PROGRESS;
24+
25+
return _nodeExecutionState;
26+
};
27+
28+
export const getUpdatedNodeExecutionStateOnInvocationProgress = (
29+
nodeExecutionState: NodeExecutionState | undefined,
30+
data: S['InvocationProgressEvent'],
31+
completedInvocationKeys: Set<string>
32+
) => {
33+
if (!nodeExecutionState) {
34+
return;
35+
}
36+
37+
if (completedInvocationKeys.has(getInvocationKey(data))) {
38+
return;
39+
}
40+
41+
const _nodeExecutionState = deepClone(nodeExecutionState);
42+
_nodeExecutionState.status = zNodeStatus.enum.IN_PROGRESS;
43+
_nodeExecutionState.progress = data.percentage ?? null;
44+
_nodeExecutionState.progressImage = data.image ?? null;
45+
46+
return _nodeExecutionState;
47+
};
48+
49+
export const getUpdatedNodeExecutionStateOnInvocationComplete = (
50+
nodeExecutionState: NodeExecutionState | undefined,
51+
data: S['InvocationCompleteEvent'],
52+
completedInvocationKeys: Set<string>
53+
) => {
54+
if (!nodeExecutionState) {
55+
return;
56+
}
57+
58+
const completedInvocationKey = getInvocationKey(data);
59+
60+
if (completedInvocationKeys.has(completedInvocationKey)) {
61+
return;
62+
}
63+
64+
const _nodeExecutionState = deepClone(nodeExecutionState);
65+
_nodeExecutionState.status = zNodeStatus.enum.COMPLETED;
66+
if (_nodeExecutionState.progress !== null) {
67+
_nodeExecutionState.progress = 1;
68+
}
69+
_nodeExecutionState.outputs.push(data.result);
70+
completedInvocationKeys.add(completedInvocationKey);
71+
72+
return _nodeExecutionState;
73+
};
Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
import type { NodeExecutionState } from 'features/nodes/types/invocation';
2+
import { zNodeStatus } from 'features/nodes/types/invocation';
3+
import type { S } from 'services/api/types';
4+
import { describe, expect, it } from 'vitest';
5+
6+
import {
7+
getUpdatedNodeExecutionStateOnInvocationComplete,
8+
getUpdatedNodeExecutionStateOnInvocationProgress,
9+
getUpdatedNodeExecutionStateOnInvocationStarted,
10+
} from './nodeExecutionState';
11+
12+
const buildNodeExecutionState = (overrides: Partial<NodeExecutionState> = {}): NodeExecutionState => ({
13+
nodeId: 'node-1',
14+
status: zNodeStatus.enum.PENDING,
15+
progress: null,
16+
progressImage: null,
17+
outputs: [],
18+
error: null,
19+
...overrides,
20+
});
21+
22+
const buildInvocationStartedEvent = (
23+
overrides: Partial<S['InvocationStartedEvent']> = {}
24+
): S['InvocationStartedEvent'] =>
25+
({
26+
queue_id: 'default',
27+
item_id: 1,
28+
batch_id: 'batch-1',
29+
origin: 'workflows',
30+
destination: 'gallery',
31+
user_id: 'user-1',
32+
session_id: 'session-1',
33+
invocation_source_id: 'node-1',
34+
invocation: {
35+
id: 'prepared-node-1',
36+
type: 'add',
37+
},
38+
...overrides,
39+
}) as S['InvocationStartedEvent'];
40+
41+
const buildInvocationProgressEvent = (
42+
overrides: Partial<S['InvocationProgressEvent']> = {}
43+
): S['InvocationProgressEvent'] =>
44+
({
45+
queue_id: 'default',
46+
item_id: 1,
47+
batch_id: 'batch-1',
48+
origin: 'workflows',
49+
destination: 'gallery',
50+
user_id: 'user-1',
51+
session_id: 'session-1',
52+
invocation_source_id: 'node-1',
53+
invocation: {
54+
id: 'prepared-node-1',
55+
type: 'add',
56+
},
57+
percentage: 0.42,
58+
image: {
59+
dataURL: 'data:image/png;base64,abc',
60+
width: 64,
61+
height: 64,
62+
},
63+
message: 'working',
64+
...overrides,
65+
}) as S['InvocationProgressEvent'];
66+
67+
const buildInvocationCompleteEvent = (
68+
overrides: Partial<S['InvocationCompleteEvent']> = {}
69+
): S['InvocationCompleteEvent'] =>
70+
({
71+
queue_id: 'default',
72+
item_id: 1,
73+
batch_id: 'batch-1',
74+
origin: 'workflows',
75+
destination: 'gallery',
76+
user_id: 'user-1',
77+
session_id: 'session-1',
78+
invocation_source_id: 'node-1',
79+
invocation: {
80+
id: 'prepared-node-1',
81+
type: 'add',
82+
},
83+
result: {
84+
type: 'integer_output',
85+
value: 42,
86+
},
87+
...overrides,
88+
}) as S['InvocationCompleteEvent'];
89+
90+
describe(getUpdatedNodeExecutionStateOnInvocationStarted.name, () => {
91+
it('marks the node in progress on invocation start', () => {
92+
const updated = getUpdatedNodeExecutionStateOnInvocationStarted(
93+
buildNodeExecutionState(),
94+
buildInvocationStartedEvent(),
95+
new Set<string>()
96+
);
97+
98+
expect(updated?.status).toBe(zNodeStatus.enum.IN_PROGRESS);
99+
});
100+
101+
it('ignores a late started event after that invocation already completed', () => {
102+
const event = buildInvocationStartedEvent();
103+
const updated = getUpdatedNodeExecutionStateOnInvocationStarted(
104+
buildNodeExecutionState({ status: zNodeStatus.enum.COMPLETED, progress: 1 }),
105+
event,
106+
new Set([`${event.item_id}:${event.invocation.id}`])
107+
);
108+
109+
expect(updated).toBeUndefined();
110+
});
111+
});
112+
113+
describe(getUpdatedNodeExecutionStateOnInvocationProgress.name, () => {
114+
it('marks the node in progress and preserves progress updates', () => {
115+
const event = buildInvocationProgressEvent();
116+
const updated = getUpdatedNodeExecutionStateOnInvocationProgress(
117+
buildNodeExecutionState(),
118+
event,
119+
new Set<string>()
120+
);
121+
122+
expect(updated?.status).toBe(zNodeStatus.enum.IN_PROGRESS);
123+
expect(updated?.progress).toBe(event.percentage);
124+
expect(updated?.progressImage).toEqual(event.image);
125+
});
126+
127+
it('ignores a late progress event after that invocation already completed', () => {
128+
const event = buildInvocationProgressEvent();
129+
const updated = getUpdatedNodeExecutionStateOnInvocationProgress(
130+
buildNodeExecutionState({ status: zNodeStatus.enum.COMPLETED, progress: 1 }),
131+
event,
132+
new Set([`${event.item_id}:${event.invocation.id}`])
133+
);
134+
135+
expect(updated).toBeUndefined();
136+
});
137+
});
138+
139+
describe(getUpdatedNodeExecutionStateOnInvocationComplete.name, () => {
140+
it('records a completed invocation result once', () => {
141+
const event = buildInvocationCompleteEvent();
142+
const completedInvocationKeys = new Set<string>();
143+
144+
const updated = getUpdatedNodeExecutionStateOnInvocationComplete(
145+
buildNodeExecutionState({ status: zNodeStatus.enum.IN_PROGRESS, progress: 0.5 }),
146+
event,
147+
completedInvocationKeys
148+
);
149+
150+
expect(updated?.status).toBe(zNodeStatus.enum.COMPLETED);
151+
expect(updated?.progress).toBe(1);
152+
expect(updated?.outputs).toEqual([event.result]);
153+
expect(completedInvocationKeys).toEqual(new Set([`${event.item_id}:${event.invocation.id}`]));
154+
});
155+
156+
it('ignores duplicate completion events for the same invocation', () => {
157+
const event = buildInvocationCompleteEvent();
158+
const updated = getUpdatedNodeExecutionStateOnInvocationComplete(
159+
buildNodeExecutionState({ status: zNodeStatus.enum.COMPLETED, progress: 1, outputs: [event.result] }),
160+
event,
161+
new Set([`${event.item_id}:${event.invocation.id}`])
162+
);
163+
164+
expect(updated).toBeUndefined();
165+
});
166+
167+
it('allows the same prepared invocation id on a different queue item', () => {
168+
const firstEvent = buildInvocationCompleteEvent({
169+
item_id: 1,
170+
result: { type: 'integer_output', value: 1 } as unknown as S['InvocationCompleteEvent']['result'],
171+
});
172+
const secondEvent = buildInvocationCompleteEvent({
173+
item_id: 2,
174+
result: { type: 'integer_output', value: 2 } as unknown as S['InvocationCompleteEvent']['result'],
175+
});
176+
const completedInvocationKeys = new Set<string>();
177+
178+
const firstUpdate = getUpdatedNodeExecutionStateOnInvocationComplete(
179+
buildNodeExecutionState(),
180+
firstEvent,
181+
completedInvocationKeys
182+
);
183+
const secondUpdate = getUpdatedNodeExecutionStateOnInvocationComplete(
184+
firstUpdate,
185+
secondEvent,
186+
completedInvocationKeys
187+
);
188+
189+
expect(secondUpdate?.outputs).toEqual([firstEvent.result, secondEvent.result]);
190+
});
191+
});

invokeai/frontend/web/src/services/events/onInvocationComplete.tsx

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import { logger } from 'app/logging/logger';
22
import type { AppDispatch, AppGetState } from 'app/store/store';
3-
import { deepClone } from 'common/util/deepClone';
43
import { canvasWorkflowIntegrationProcessingCompleted } from 'features/controlLayers/store/canvasWorkflowIntegrationSlice';
54
import {
65
selectAutoSwitch,
@@ -12,15 +11,14 @@ import {
1211
import { boardIdSelected, galleryViewChanged, imageSelected } from 'features/gallery/store/gallerySlice';
1312
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useNodeExecutionState';
1413
import { isImageField, isImageFieldCollection } from 'features/nodes/types/common';
15-
import { zNodeStatus } from 'features/nodes/types/invocation';
16-
import type { LRUCache } from 'lru-cache';
1714
import { LIST_ALL_TAG } from 'services/api';
1815
import { boardsApi } from 'services/api/endpoints/boards';
1916
import { getImageDTOSafe, imagesApi } from 'services/api/endpoints/images';
2017
import { queueApi } from 'services/api/endpoints/queue';
2118
import type { ImageDTO, S } from 'services/api/types';
2219
import { getCategories } from 'services/api/util';
2320
import { insertImageIntoNamesResult } from 'services/api/util/optimisticUpdates';
21+
import { getUpdatedNodeExecutionStateOnInvocationComplete } from 'services/events/nodeExecutionState';
2422
import { $lastProgressEvent } from 'services/events/stores';
2523
import stableHash from 'stable-hash';
2624
import type { Param0 } from 'tsafe';
@@ -38,13 +36,13 @@ const nodeTypeDenylist = ['load_image', 'image'];
3836
*
3937
* @param getState The Redux getState function.
4038
* @param dispatch The Redux dispatch function.
41-
* @param finishedQueueItemIds A cache of finished queue item IDs to prevent duplicate handling and avoid race
42-
* conditions that can happen when a graph finishes very quickly.
39+
* @param completedInvocationKeys A listener-local set used to dedupe repeated invocation completion events and to
40+
* share completion knowledge with the other invocation event handlers.
4341
*/
4442
export const buildOnInvocationComplete = (
4543
getState: AppGetState,
4644
dispatch: AppDispatch,
47-
finishedQueueItemIds: LRUCache<number, boolean>
45+
completedInvocationKeys: Set<string>
4846
) => {
4947
const addImagesToGallery = async (data: S['InvocationCompleteEvent']) => {
5048
if (nodeTypeDenylist.includes(data.invocation.type)) {
@@ -242,22 +240,24 @@ export const buildOnInvocationComplete = (
242240
};
243241

244242
return async (data: S['InvocationCompleteEvent']) => {
245-
if (finishedQueueItemIds.has(data.item_id)) {
246-
log.trace({ data } as JsonObject, `Received event for already-finished queue item ${data.item_id}`);
247-
return;
248-
}
249243
log.debug({ data } as JsonObject, `Invocation complete (${data.invocation.type}, ${data.invocation_source_id})`);
250244

251245
const nodeExecutionState = $nodeExecutionStates.get()[data.invocation_source_id];
246+
const updatedNodeExecutionState = getUpdatedNodeExecutionStateOnInvocationComplete(
247+
nodeExecutionState,
248+
data,
249+
completedInvocationKeys
250+
);
252251

253-
if (nodeExecutionState) {
254-
const _nodeExecutionState = deepClone(nodeExecutionState);
255-
_nodeExecutionState.status = zNodeStatus.enum.COMPLETED;
256-
if (_nodeExecutionState.progress !== null) {
257-
_nodeExecutionState.progress = 1;
258-
}
259-
_nodeExecutionState.outputs.push(data.result);
260-
upsertExecutionState(_nodeExecutionState.nodeId, _nodeExecutionState);
252+
if (nodeExecutionState && !updatedNodeExecutionState) {
253+
log.trace(
254+
{ data } as JsonObject,
255+
`Ignoring duplicate invocation complete (${data.invocation.type}, ${data.invocation_source_id})`
256+
);
257+
}
258+
259+
if (updatedNodeExecutionState) {
260+
upsertExecutionState(updatedNodeExecutionState.nodeId, updatedNodeExecutionState);
261261
}
262262

263263
// Clear canvas workflow integration processing state if needed

0 commit comments

Comments
 (0)