diff --git a/invokeai/frontend/web/src/services/events/nodeExecutionState.ts b/invokeai/frontend/web/src/services/events/nodeExecutionState.ts new file mode 100644 index 00000000000..ed2bbce6c3d --- /dev/null +++ b/invokeai/frontend/web/src/services/events/nodeExecutionState.ts @@ -0,0 +1,87 @@ +import { deepClone } from 'common/util/deepClone'; +import type { NodeExecutionState } from 'features/nodes/types/invocation'; +import { zNodeStatus } from 'features/nodes/types/invocation'; +import type { S } from 'services/api/types'; + +const getInvocationKey = (data: { item_id: number; invocation: { id: string } }) => + `${data.item_id}:${data.invocation.id}`; + +const getInitialNodeExecutionState = (nodeId: string): NodeExecutionState => ({ + nodeId, + status: zNodeStatus.enum.PENDING, + progress: null, + progressImage: null, + outputs: [], + error: null, +}); + +export const getUpdatedNodeExecutionStateOnInvocationStarted = ( + nodeExecutionState: NodeExecutionState | undefined, + data: S['InvocationStartedEvent'], + completedInvocationKeys: Set +) => { + if (completedInvocationKeys.has(getInvocationKey(data))) { + return; + } + + const _nodeExecutionState = deepClone(nodeExecutionState ?? getInitialNodeExecutionState(data.invocation_source_id)); + _nodeExecutionState.status = zNodeStatus.enum.IN_PROGRESS; + + return _nodeExecutionState; +}; + +export const getUpdatedNodeExecutionStateOnInvocationProgress = ( + nodeExecutionState: NodeExecutionState | undefined, + data: S['InvocationProgressEvent'], + completedInvocationKeys: Set +) => { + if (completedInvocationKeys.has(getInvocationKey(data))) { + return; + } + + const _nodeExecutionState = deepClone(nodeExecutionState ?? getInitialNodeExecutionState(data.invocation_source_id)); + _nodeExecutionState.status = zNodeStatus.enum.IN_PROGRESS; + _nodeExecutionState.progress = data.percentage ?? null; + _nodeExecutionState.progressImage = data.image ?? null; + + return _nodeExecutionState; +}; + +export const getUpdatedNodeExecutionStateOnInvocationComplete = ( + nodeExecutionState: NodeExecutionState | undefined, + data: S['InvocationCompleteEvent'], + completedInvocationKeys: Set +) => { + const completedInvocationKey = getInvocationKey(data); + + if (completedInvocationKeys.has(completedInvocationKey)) { + return; + } + + const _nodeExecutionState = deepClone(nodeExecutionState ?? getInitialNodeExecutionState(data.invocation_source_id)); + _nodeExecutionState.status = zNodeStatus.enum.COMPLETED; + if (_nodeExecutionState.progress !== null) { + _nodeExecutionState.progress = 1; + } + _nodeExecutionState.outputs.push(data.result); + completedInvocationKeys.add(completedInvocationKey); + + return _nodeExecutionState; +}; + +export const getUpdatedNodeExecutionStateOnInvocationError = ( + nodeExecutionState: NodeExecutionState | undefined, + data: S['InvocationErrorEvent'] +) => { + const _nodeExecutionState = deepClone(nodeExecutionState ?? getInitialNodeExecutionState(data.invocation_source_id)); + _nodeExecutionState.status = zNodeStatus.enum.FAILED; + _nodeExecutionState.progress = null; + _nodeExecutionState.progressImage = null; + _nodeExecutionState.error = { + error_type: data.error_type, + error_message: data.error_message, + error_traceback: data.error_traceback, + }; + + return _nodeExecutionState; +}; diff --git a/invokeai/frontend/web/src/services/events/nodeExecutionStateHelpers.test.ts b/invokeai/frontend/web/src/services/events/nodeExecutionStateHelpers.test.ts new file mode 100644 index 00000000000..6e0848ec07b --- /dev/null +++ b/invokeai/frontend/web/src/services/events/nodeExecutionStateHelpers.test.ts @@ -0,0 +1,280 @@ +import type { NodeExecutionState } from 'features/nodes/types/invocation'; +import { zNodeStatus } from 'features/nodes/types/invocation'; +import type { S } from 'services/api/types'; +import { describe, expect, it } from 'vitest'; + +import { + getUpdatedNodeExecutionStateOnInvocationComplete, + getUpdatedNodeExecutionStateOnInvocationError, + getUpdatedNodeExecutionStateOnInvocationProgress, + getUpdatedNodeExecutionStateOnInvocationStarted, +} from './nodeExecutionState'; + +const buildNodeExecutionState = (overrides: Partial = {}): NodeExecutionState => ({ + nodeId: 'node-1', + status: zNodeStatus.enum.PENDING, + progress: null, + progressImage: null, + outputs: [], + error: null, + ...overrides, +}); + +const buildInvocationStartedEvent = ( + overrides: Partial = {} +): S['InvocationStartedEvent'] => + ({ + queue_id: 'default', + item_id: 1, + batch_id: 'batch-1', + origin: 'workflows', + destination: 'gallery', + user_id: 'user-1', + session_id: 'session-1', + invocation_source_id: 'node-1', + invocation: { + id: 'prepared-node-1', + type: 'add', + }, + ...overrides, + }) as S['InvocationStartedEvent']; + +const buildInvocationProgressEvent = ( + overrides: Partial = {} +): S['InvocationProgressEvent'] => + ({ + queue_id: 'default', + item_id: 1, + batch_id: 'batch-1', + origin: 'workflows', + destination: 'gallery', + user_id: 'user-1', + session_id: 'session-1', + invocation_source_id: 'node-1', + invocation: { + id: 'prepared-node-1', + type: 'add', + }, + percentage: 0.42, + image: { + dataURL: 'data:image/png;base64,abc', + width: 64, + height: 64, + }, + message: 'working', + ...overrides, + }) as S['InvocationProgressEvent']; + +const buildInvocationCompleteEvent = ( + overrides: Partial = {} +): S['InvocationCompleteEvent'] => + ({ + queue_id: 'default', + item_id: 1, + batch_id: 'batch-1', + origin: 'workflows', + destination: 'gallery', + user_id: 'user-1', + session_id: 'session-1', + invocation_source_id: 'node-1', + invocation: { + id: 'prepared-node-1', + type: 'add', + }, + result: { + type: 'integer_output', + value: 42, + }, + ...overrides, + }) as S['InvocationCompleteEvent']; + +const buildInvocationErrorEvent = (overrides: Partial = {}): S['InvocationErrorEvent'] => + ({ + queue_id: 'default', + item_id: 1, + batch_id: 'batch-1', + origin: 'workflows', + destination: 'gallery', + user_id: 'user-1', + session_id: 'session-1', + invocation_source_id: 'node-1', + invocation: { + id: 'prepared-node-1', + type: 'add', + }, + error_type: 'TestError', + error_message: 'boom', + error_traceback: 'traceback', + ...overrides, + }) as S['InvocationErrorEvent']; + +describe(getUpdatedNodeExecutionStateOnInvocationStarted.name, () => { + it('creates an execution state when started arrives before initialization', () => { + const event = buildInvocationStartedEvent(); + const updated = getUpdatedNodeExecutionStateOnInvocationStarted(undefined, event, new Set()); + + expect(updated?.nodeId).toBe(event.invocation_source_id); + expect(updated?.status).toBe(zNodeStatus.enum.IN_PROGRESS); + expect(updated?.outputs).toEqual([]); + }); + + it('marks the node in progress on invocation start', () => { + const updated = getUpdatedNodeExecutionStateOnInvocationStarted( + buildNodeExecutionState(), + buildInvocationStartedEvent(), + new Set() + ); + + expect(updated?.status).toBe(zNodeStatus.enum.IN_PROGRESS); + }); + + it('ignores a late started event after that invocation already completed', () => { + const event = buildInvocationStartedEvent(); + const updated = getUpdatedNodeExecutionStateOnInvocationStarted( + buildNodeExecutionState({ status: zNodeStatus.enum.COMPLETED, progress: 1 }), + event, + new Set([`${event.item_id}:${event.invocation.id}`]) + ); + + expect(updated).toBeUndefined(); + }); +}); + +describe(getUpdatedNodeExecutionStateOnInvocationProgress.name, () => { + it('creates an execution state when progress arrives before initialization', () => { + const event = buildInvocationProgressEvent(); + const updated = getUpdatedNodeExecutionStateOnInvocationProgress(undefined, event, new Set()); + + expect(updated?.nodeId).toBe(event.invocation_source_id); + expect(updated?.status).toBe(zNodeStatus.enum.IN_PROGRESS); + expect(updated?.progress).toBe(event.percentage); + expect(updated?.progressImage).toEqual(event.image); + }); + + it('marks the node in progress and preserves progress updates', () => { + const event = buildInvocationProgressEvent(); + const updated = getUpdatedNodeExecutionStateOnInvocationProgress( + buildNodeExecutionState(), + event, + new Set() + ); + + expect(updated?.status).toBe(zNodeStatus.enum.IN_PROGRESS); + expect(updated?.progress).toBe(event.percentage); + expect(updated?.progressImage).toEqual(event.image); + }); + + it('ignores a late progress event after that invocation already completed', () => { + const event = buildInvocationProgressEvent(); + const updated = getUpdatedNodeExecutionStateOnInvocationProgress( + buildNodeExecutionState({ status: zNodeStatus.enum.COMPLETED, progress: 1 }), + event, + new Set([`${event.item_id}:${event.invocation.id}`]) + ); + + expect(updated).toBeUndefined(); + }); +}); + +describe(getUpdatedNodeExecutionStateOnInvocationComplete.name, () => { + it('creates an execution state when completion arrives before initialization', () => { + const event = buildInvocationCompleteEvent(); + const completedInvocationKeys = new Set(); + const updated = getUpdatedNodeExecutionStateOnInvocationComplete(undefined, event, completedInvocationKeys); + + expect(updated?.nodeId).toBe(event.invocation_source_id); + expect(updated?.status).toBe(zNodeStatus.enum.COMPLETED); + expect(updated?.outputs).toEqual([event.result]); + expect(completedInvocationKeys).toEqual(new Set([`${event.item_id}:${event.invocation.id}`])); + }); + + it('records a completed invocation result once', () => { + const event = buildInvocationCompleteEvent(); + const completedInvocationKeys = new Set(); + + const updated = getUpdatedNodeExecutionStateOnInvocationComplete( + buildNodeExecutionState({ status: zNodeStatus.enum.IN_PROGRESS, progress: 0.5 }), + event, + completedInvocationKeys + ); + + expect(updated?.status).toBe(zNodeStatus.enum.COMPLETED); + expect(updated?.progress).toBe(1); + expect(updated?.outputs).toEqual([event.result]); + expect(completedInvocationKeys).toEqual(new Set([`${event.item_id}:${event.invocation.id}`])); + }); + + it('ignores duplicate completion events for the same invocation', () => { + const event = buildInvocationCompleteEvent(); + const updated = getUpdatedNodeExecutionStateOnInvocationComplete( + buildNodeExecutionState({ status: zNodeStatus.enum.COMPLETED, progress: 1, outputs: [event.result] }), + event, + new Set([`${event.item_id}:${event.invocation.id}`]) + ); + + expect(updated).toBeUndefined(); + }); + + it('allows the same prepared invocation id on a different queue item', () => { + const firstEvent = buildInvocationCompleteEvent({ + item_id: 1, + result: { type: 'integer_output', value: 1 } as unknown as S['InvocationCompleteEvent']['result'], + }); + const secondEvent = buildInvocationCompleteEvent({ + item_id: 2, + result: { type: 'integer_output', value: 2 } as unknown as S['InvocationCompleteEvent']['result'], + }); + const completedInvocationKeys = new Set(); + + const firstUpdate = getUpdatedNodeExecutionStateOnInvocationComplete( + buildNodeExecutionState(), + firstEvent, + completedInvocationKeys + ); + const secondUpdate = getUpdatedNodeExecutionStateOnInvocationComplete( + firstUpdate, + secondEvent, + completedInvocationKeys + ); + + expect(secondUpdate?.outputs).toEqual([firstEvent.result, secondEvent.result]); + }); +}); + +describe(getUpdatedNodeExecutionStateOnInvocationError.name, () => { + it('creates an execution state when error arrives before initialization', () => { + const event = buildInvocationErrorEvent(); + const updated = getUpdatedNodeExecutionStateOnInvocationError(undefined, event); + + expect(updated?.nodeId).toBe(event.invocation_source_id); + expect(updated?.status).toBe(zNodeStatus.enum.FAILED); + expect(updated?.progress).toBeNull(); + expect(updated?.progressImage).toBeNull(); + expect(updated?.error).toEqual({ + error_type: event.error_type, + error_message: event.error_message, + error_traceback: event.error_traceback, + }); + }); + + it('marks the node failed and records the error', () => { + const event = buildInvocationErrorEvent(); + const updated = getUpdatedNodeExecutionStateOnInvocationError( + buildNodeExecutionState({ + status: zNodeStatus.enum.IN_PROGRESS, + progress: 0.5, + progressImage: { dataURL: 'data:image/png;base64,abc', width: 64, height: 64 }, + }), + event + ); + + expect(updated?.status).toBe(zNodeStatus.enum.FAILED); + expect(updated?.progress).toBeNull(); + expect(updated?.progressImage).toBeNull(); + expect(updated?.error).toEqual({ + error_type: event.error_type, + error_message: event.error_message, + error_traceback: event.error_traceback, + }); + }); +}); diff --git a/invokeai/frontend/web/src/services/events/onInvocationComplete.tsx b/invokeai/frontend/web/src/services/events/onInvocationComplete.tsx index 14bdf343ec4..a9403c112e0 100644 --- a/invokeai/frontend/web/src/services/events/onInvocationComplete.tsx +++ b/invokeai/frontend/web/src/services/events/onInvocationComplete.tsx @@ -1,6 +1,5 @@ import { logger } from 'app/logging/logger'; import type { AppDispatch, AppGetState } from 'app/store/store'; -import { deepClone } from 'common/util/deepClone'; import { canvasWorkflowIntegrationProcessingCompleted } from 'features/controlLayers/store/canvasWorkflowIntegrationSlice'; import { selectAutoSwitch, @@ -12,8 +11,6 @@ import { import { boardIdSelected, galleryViewChanged, imageSelected } from 'features/gallery/store/gallerySlice'; import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useNodeExecutionState'; import { isImageField, isImageFieldCollection } from 'features/nodes/types/common'; -import { zNodeStatus } from 'features/nodes/types/invocation'; -import type { LRUCache } from 'lru-cache'; import { LIST_ALL_TAG } from 'services/api'; import { boardsApi } from 'services/api/endpoints/boards'; import { getImageDTOSafe, imagesApi } from 'services/api/endpoints/images'; @@ -21,6 +18,7 @@ import { queueApi } from 'services/api/endpoints/queue'; import type { ImageDTO, S } from 'services/api/types'; import { getCategories } from 'services/api/util'; import { insertImageIntoNamesResult } from 'services/api/util/optimisticUpdates'; +import { getUpdatedNodeExecutionStateOnInvocationComplete } from 'services/events/nodeExecutionState'; import { $lastProgressEvent } from 'services/events/stores'; import stableHash from 'stable-hash'; import type { Param0 } from 'tsafe'; @@ -38,13 +36,13 @@ const nodeTypeDenylist = ['load_image', 'image']; * * @param getState The Redux getState function. * @param dispatch The Redux dispatch function. - * @param finishedQueueItemIds A cache of finished queue item IDs to prevent duplicate handling and avoid race - * conditions that can happen when a graph finishes very quickly. + * @param completedInvocationKeys A listener-local set used to dedupe repeated invocation completion events and to + * share completion knowledge with the other invocation event handlers. */ export const buildOnInvocationComplete = ( getState: AppGetState, dispatch: AppDispatch, - finishedQueueItemIds: LRUCache + completedInvocationKeys: Set ) => { const addImagesToGallery = async (data: S['InvocationCompleteEvent']) => { if (nodeTypeDenylist.includes(data.invocation.type)) { @@ -242,22 +240,24 @@ export const buildOnInvocationComplete = ( }; return async (data: S['InvocationCompleteEvent']) => { - if (finishedQueueItemIds.has(data.item_id)) { - log.trace({ data } as JsonObject, `Received event for already-finished queue item ${data.item_id}`); - return; - } log.debug({ data } as JsonObject, `Invocation complete (${data.invocation.type}, ${data.invocation_source_id})`); const nodeExecutionState = $nodeExecutionStates.get()[data.invocation_source_id]; + const updatedNodeExecutionState = getUpdatedNodeExecutionStateOnInvocationComplete( + nodeExecutionState, + data, + completedInvocationKeys + ); - if (nodeExecutionState) { - const _nodeExecutionState = deepClone(nodeExecutionState); - _nodeExecutionState.status = zNodeStatus.enum.COMPLETED; - if (_nodeExecutionState.progress !== null) { - _nodeExecutionState.progress = 1; - } - _nodeExecutionState.outputs.push(data.result); - upsertExecutionState(_nodeExecutionState.nodeId, _nodeExecutionState); + if (nodeExecutionState && !updatedNodeExecutionState) { + log.trace( + { data } as JsonObject, + `Ignoring duplicate invocation complete (${data.invocation.type}, ${data.invocation_source_id})` + ); + } + + if (updatedNodeExecutionState) { + upsertExecutionState(updatedNodeExecutionState.nodeId, updatedNodeExecutionState); } // Clear canvas workflow integration processing state if needed diff --git a/invokeai/frontend/web/src/services/events/setEventListeners.tsx b/invokeai/frontend/web/src/services/events/setEventListeners.tsx index fb08fc08dd1..f4f2d0268bb 100644 --- a/invokeai/frontend/web/src/services/events/setEventListeners.tsx +++ b/invokeai/frontend/web/src/services/events/setEventListeners.tsx @@ -37,6 +37,11 @@ import { api, LIST_ALL_TAG, LIST_TAG } from 'services/api'; import { imagesApi } from 'services/api/endpoints/images'; import { modelsApi } from 'services/api/endpoints/models'; import { queueApi } from 'services/api/endpoints/queue'; +import { + getUpdatedNodeExecutionStateOnInvocationError, + getUpdatedNodeExecutionStateOnInvocationProgress, + getUpdatedNodeExecutionStateOnInvocationStarted, +} from 'services/events/nodeExecutionState'; import { buildOnInvocationComplete } from 'services/events/onInvocationComplete'; import { buildOnModelInstallError, DiscordLink, GitHubIssuesLink } from 'services/events/onModelInstallError'; import type { ClientToServerEvents, ServerToClientEvents } from 'services/events/types'; @@ -65,6 +70,7 @@ export const setEventListeners = ({ socket, store, setIsConnected }: SetEventLis // We can have race conditions where we receive a progress event for a queue item that has already finished. Easiest // way to handle this is to keep track of finished queue items in a cache and ignore progress events for those. const finishedQueueItemIds = new LRUCache({ max: 100 }); + const completedInvocationKeys = new Set(); socket.on('connect', () => { log.debug('Connected'); @@ -107,10 +113,14 @@ export const setEventListeners = ({ socket, store, setIsConnected }: SetEventLis } const { invocation_source_id, invocation } = data; log.debug({ data } as JsonObject, `Invocation started (${invocation.type}, ${invocation_source_id})`); - const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]); - if (nes) { - nes.status = zNodeStatus.enum.IN_PROGRESS; - upsertExecutionState(nes.nodeId, nes); + const nes = $nodeExecutionStates.get()[invocation_source_id]; + const updatedNodeExecutionState = getUpdatedNodeExecutionStateOnInvocationStarted( + nes, + data, + completedInvocationKeys + ); + if (updatedNodeExecutionState) { + upsertExecutionState(updatedNodeExecutionState.nodeId, updatedNodeExecutionState); } }); @@ -119,7 +129,7 @@ export const setEventListeners = ({ socket, store, setIsConnected }: SetEventLis log.trace({ data } as JsonObject, `Received event for already-finished queue item ${data.item_id}`); return; } - const { invocation_source_id, invocation, image, origin, percentage, message } = data; + const { invocation_source_id, invocation, origin, percentage, message } = data; let _message = 'Invocation progress'; if (message) { @@ -135,12 +145,14 @@ export const setEventListeners = ({ socket, store, setIsConnected }: SetEventLis $lastProgressEvent.set(data); if (origin === 'workflows') { - const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]); - if (nes) { - nes.status = zNodeStatus.enum.IN_PROGRESS; - nes.progress = percentage; - nes.progressImage = image ?? null; - upsertExecutionState(nes.nodeId, nes); + const nes = $nodeExecutionStates.get()[invocation_source_id]; + const updatedNodeExecutionState = getUpdatedNodeExecutionStateOnInvocationProgress( + nes, + data, + completedInvocationKeys + ); + if (updatedNodeExecutionState) { + upsertExecutionState(updatedNodeExecutionState.nodeId, updatedNodeExecutionState); } } }); @@ -150,19 +162,12 @@ export const setEventListeners = ({ socket, store, setIsConnected }: SetEventLis log.trace({ data } as JsonObject, `Received event for already-finished queue item ${data.item_id}`); return; } - const { invocation_source_id, invocation, error_type, error_message, error_traceback } = data; + const { invocation_source_id, invocation } = data; log.error({ data } as JsonObject, `Invocation error (${invocation.type}, ${invocation_source_id})`); - const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]); - if (nes) { - nes.status = zNodeStatus.enum.FAILED; - nes.progress = null; - nes.progressImage = null; - nes.error = { - error_type, - error_message, - error_traceback, - }; - upsertExecutionState(nes.nodeId, nes); + const nes = $nodeExecutionStates.get()[invocation_source_id]; + const updatedNodeExecutionState = getUpdatedNodeExecutionStateOnInvocationError(nes, data); + if (updatedNodeExecutionState) { + upsertExecutionState(updatedNodeExecutionState.nodeId, updatedNodeExecutionState); } // Clear canvas workflow integration processing state on error if (data.origin === 'canvas_workflow_integration') { @@ -170,7 +175,7 @@ export const setEventListeners = ({ socket, store, setIsConnected }: SetEventLis } }); - const onInvocationComplete = buildOnInvocationComplete(getState, dispatch, finishedQueueItemIds); + const onInvocationComplete = buildOnInvocationComplete(getState, dispatch, completedInvocationKeys); socket.on('invocation_complete', onInvocationComplete); socket.on('model_load_started', (data) => {