Skip to content

Commit 46cc494

Browse files
committed
fix: make sure invocation_error arriving before execution-state initialization is processed correctly
1 parent 105dba5 commit 46cc494

File tree

3 files changed

+82
-12
lines changed

3 files changed

+82
-12
lines changed

invokeai/frontend/web/src/services/events/nodeExecutionState.ts

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,3 +68,20 @@ export const getUpdatedNodeExecutionStateOnInvocationComplete = (
6868

6969
return _nodeExecutionState;
7070
};
71+
72+
export const getUpdatedNodeExecutionStateOnInvocationError = (
73+
nodeExecutionState: NodeExecutionState | undefined,
74+
data: S['InvocationErrorEvent']
75+
) => {
76+
const _nodeExecutionState = deepClone(nodeExecutionState ?? getInitialNodeExecutionState(data.invocation_source_id));
77+
_nodeExecutionState.status = zNodeStatus.enum.FAILED;
78+
_nodeExecutionState.progress = null;
79+
_nodeExecutionState.progressImage = null;
80+
_nodeExecutionState.error = {
81+
error_type: data.error_type,
82+
error_message: data.error_message,
83+
error_traceback: data.error_traceback,
84+
};
85+
86+
return _nodeExecutionState;
87+
};

invokeai/frontend/web/src/services/events/nodeExecutionStateHelpers.test.ts

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import { describe, expect, it } from 'vitest';
55

66
import {
77
getUpdatedNodeExecutionStateOnInvocationComplete,
8+
getUpdatedNodeExecutionStateOnInvocationError,
89
getUpdatedNodeExecutionStateOnInvocationProgress,
910
getUpdatedNodeExecutionStateOnInvocationStarted,
1011
} from './nodeExecutionState';
@@ -87,6 +88,26 @@ const buildInvocationCompleteEvent = (
8788
...overrides,
8889
}) as S['InvocationCompleteEvent'];
8990

91+
const buildInvocationErrorEvent = (overrides: Partial<S['InvocationErrorEvent']> = {}): S['InvocationErrorEvent'] =>
92+
({
93+
queue_id: 'default',
94+
item_id: 1,
95+
batch_id: 'batch-1',
96+
origin: 'workflows',
97+
destination: 'gallery',
98+
user_id: 'user-1',
99+
session_id: 'session-1',
100+
invocation_source_id: 'node-1',
101+
invocation: {
102+
id: 'prepared-node-1',
103+
type: 'add',
104+
},
105+
error_type: 'TestError',
106+
error_message: 'boom',
107+
error_traceback: 'traceback',
108+
...overrides,
109+
}) as S['InvocationErrorEvent'];
110+
90111
describe(getUpdatedNodeExecutionStateOnInvocationStarted.name, () => {
91112
it('creates an execution state when started arrives before initialization', () => {
92113
const event = buildInvocationStartedEvent();
@@ -219,3 +240,41 @@ describe(getUpdatedNodeExecutionStateOnInvocationComplete.name, () => {
219240
expect(secondUpdate?.outputs).toEqual([firstEvent.result, secondEvent.result]);
220241
});
221242
});
243+
244+
describe(getUpdatedNodeExecutionStateOnInvocationError.name, () => {
245+
it('creates an execution state when error arrives before initialization', () => {
246+
const event = buildInvocationErrorEvent();
247+
const updated = getUpdatedNodeExecutionStateOnInvocationError(undefined, event);
248+
249+
expect(updated?.nodeId).toBe(event.invocation_source_id);
250+
expect(updated?.status).toBe(zNodeStatus.enum.FAILED);
251+
expect(updated?.progress).toBeNull();
252+
expect(updated?.progressImage).toBeNull();
253+
expect(updated?.error).toEqual({
254+
error_type: event.error_type,
255+
error_message: event.error_message,
256+
error_traceback: event.error_traceback,
257+
});
258+
});
259+
260+
it('marks the node failed and records the error', () => {
261+
const event = buildInvocationErrorEvent();
262+
const updated = getUpdatedNodeExecutionStateOnInvocationError(
263+
buildNodeExecutionState({
264+
status: zNodeStatus.enum.IN_PROGRESS,
265+
progress: 0.5,
266+
progressImage: { dataURL: 'data:image/png;base64,abc', width: 64, height: 64 },
267+
}),
268+
event
269+
);
270+
271+
expect(updated?.status).toBe(zNodeStatus.enum.FAILED);
272+
expect(updated?.progress).toBeNull();
273+
expect(updated?.progressImage).toBeNull();
274+
expect(updated?.error).toEqual({
275+
error_type: event.error_type,
276+
error_message: event.error_message,
277+
error_traceback: event.error_traceback,
278+
});
279+
});
280+
});

invokeai/frontend/web/src/services/events/setEventListeners.tsx

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ import { imagesApi } from 'services/api/endpoints/images';
3838
import { modelsApi } from 'services/api/endpoints/models';
3939
import { queueApi } from 'services/api/endpoints/queue';
4040
import {
41+
getUpdatedNodeExecutionStateOnInvocationError,
4142
getUpdatedNodeExecutionStateOnInvocationProgress,
4243
getUpdatedNodeExecutionStateOnInvocationStarted,
4344
} from 'services/events/nodeExecutionState';
@@ -161,19 +162,12 @@ export const setEventListeners = ({ socket, store, setIsConnected }: SetEventLis
161162
log.trace({ data } as JsonObject, `Received event for already-finished queue item ${data.item_id}`);
162163
return;
163164
}
164-
const { invocation_source_id, invocation, error_type, error_message, error_traceback } = data;
165+
const { invocation_source_id, invocation } = data;
165166
log.error({ data } as JsonObject, `Invocation error (${invocation.type}, ${invocation_source_id})`);
166-
const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]);
167-
if (nes) {
168-
nes.status = zNodeStatus.enum.FAILED;
169-
nes.progress = null;
170-
nes.progressImage = null;
171-
nes.error = {
172-
error_type,
173-
error_message,
174-
error_traceback,
175-
};
176-
upsertExecutionState(nes.nodeId, nes);
167+
const nes = $nodeExecutionStates.get()[invocation_source_id];
168+
const updatedNodeExecutionState = getUpdatedNodeExecutionStateOnInvocationError(nes, data);
169+
if (updatedNodeExecutionState) {
170+
upsertExecutionState(updatedNodeExecutionState.nodeId, updatedNodeExecutionState);
177171
}
178172
// Clear canvas workflow integration processing state on error
179173
if (data.origin === 'canvas_workflow_integration') {

0 commit comments

Comments
 (0)