Skip to content

Commit c2e7a5d

Browse files
committed
Add workflow live update events
Stacked on top of origin PR invoke-ai#9018 (shared/private workflows and boards) for multiuser workflow visibility semantics.
1 parent ed45bd4 commit c2e7a5d

9 files changed

Lines changed: 533 additions & 12 deletions

File tree

invokeai/app/api/routers/workflows.py

Lines changed: 37 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -66,15 +66,23 @@ async def update_workflow(
6666
workflow: Workflow = Body(description="The updated workflow", embed=True),
6767
) -> WorkflowRecordDTO:
6868
"""Updates a workflow"""
69+
try:
70+
existing = ApiDependencies.invoker.services.workflow_records.get(workflow.id)
71+
except WorkflowNotFoundError:
72+
raise HTTPException(status_code=404, detail="Workflow not found")
73+
6974
config = ApiDependencies.invoker.services.configuration
7075
if config.multiuser:
71-
try:
72-
existing = ApiDependencies.invoker.services.workflow_records.get(workflow.id)
73-
except WorkflowNotFoundError:
74-
raise HTTPException(status_code=404, detail="Workflow not found")
7576
if not current_user.is_admin and existing.user_id != current_user.user_id:
7677
raise HTTPException(status_code=403, detail="Not authorized to update this workflow")
77-
return ApiDependencies.invoker.services.workflow_records.update(workflow=workflow)
78+
updated = ApiDependencies.invoker.services.workflow_records.update(workflow=workflow)
79+
ApiDependencies.invoker.services.events.emit_workflow_updated(
80+
workflow_id=updated.workflow_id,
81+
user_id=updated.user_id,
82+
old_is_public=existing.is_public,
83+
new_is_public=updated.is_public,
84+
)
85+
return updated
7886

7987

8088
@workflows_router.delete(
@@ -86,12 +94,13 @@ async def delete_workflow(
8694
workflow_id: str = Path(description="The workflow to delete"),
8795
) -> None:
8896
"""Deletes a workflow"""
97+
try:
98+
existing = ApiDependencies.invoker.services.workflow_records.get(workflow_id)
99+
except WorkflowNotFoundError:
100+
raise HTTPException(status_code=404, detail="Workflow not found")
101+
89102
config = ApiDependencies.invoker.services.configuration
90103
if config.multiuser:
91-
try:
92-
existing = ApiDependencies.invoker.services.workflow_records.get(workflow_id)
93-
except WorkflowNotFoundError:
94-
raise HTTPException(status_code=404, detail="Workflow not found")
95104
if not current_user.is_admin and existing.user_id != current_user.user_id:
96105
raise HTTPException(status_code=403, detail="Not authorized to delete this workflow")
97106
try:
@@ -100,6 +109,11 @@ async def delete_workflow(
100109
# It's OK if the workflow has no thumbnail file. We can still delete the workflow.
101110
pass
102111
ApiDependencies.invoker.services.workflow_records.delete(workflow_id)
112+
ApiDependencies.invoker.services.events.emit_workflow_deleted(
113+
workflow_id=existing.workflow_id,
114+
user_id=existing.user_id,
115+
is_public=existing.is_public,
116+
)
103117

104118

105119
@workflows_router.post(
@@ -114,7 +128,13 @@ async def create_workflow(
114128
workflow: WorkflowWithoutID = Body(description="The workflow to create", embed=True),
115129
) -> WorkflowRecordDTO:
116130
"""Creates a workflow"""
117-
return ApiDependencies.invoker.services.workflow_records.create(workflow=workflow, user_id=current_user.user_id)
131+
created = ApiDependencies.invoker.services.workflow_records.create(workflow=workflow, user_id=current_user.user_id)
132+
ApiDependencies.invoker.services.events.emit_workflow_created(
133+
workflow_id=created.workflow_id,
134+
user_id=created.user_id,
135+
is_public=created.is_public,
136+
)
137+
return created
118138

119139

120140
@workflows_router.get(
@@ -302,9 +322,14 @@ async def update_workflow_is_public(
302322
if config.multiuser and not current_user.is_admin and existing.user_id != current_user.user_id:
303323
raise HTTPException(status_code=403, detail="Not authorized to update this workflow")
304324

305-
return ApiDependencies.invoker.services.workflow_records.update_is_public(
306-
workflow_id=workflow_id, is_public=is_public
325+
updated = ApiDependencies.invoker.services.workflow_records.update_is_public(workflow_id=workflow_id, is_public=is_public)
326+
ApiDependencies.invoker.services.events.emit_workflow_updated(
327+
workflow_id=updated.workflow_id,
328+
user_id=updated.user_id,
329+
old_is_public=existing.is_public,
330+
new_is_public=updated.is_public,
307331
)
332+
return updated
308333

309334

310335
@workflows_router.get("/tags", operation_id="get_all_tags")

invokeai/app/api/sockets.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,10 @@
3737
QueueEventBase,
3838
QueueItemStatusChangedEvent,
3939
RecallParametersUpdatedEvent,
40+
WorkflowCreatedEvent,
41+
WorkflowDeletedEvent,
42+
WorkflowEventBase,
43+
WorkflowUpdatedEvent,
4044
register_events,
4145
)
4246
from invokeai.backend.util.logging import InvokeAILogger
@@ -86,6 +90,7 @@ class BulkDownloadSubscriptionEvent(BaseModel):
8690
}
8791

8892
BULK_DOWNLOAD_EVENTS = {BulkDownloadStartedEvent, BulkDownloadCompleteEvent, BulkDownloadErrorEvent}
93+
WORKFLOW_EVENTS = {WorkflowCreatedEvent, WorkflowUpdatedEvent, WorkflowDeletedEvent}
8994

9095

9196
class SocketIO:
@@ -115,6 +120,7 @@ def __init__(self, app: FastAPI):
115120
register_events(QUEUE_EVENTS, self._handle_queue_event)
116121
register_events(MODEL_EVENTS, self._handle_model_event)
117122
register_events(BULK_DOWNLOAD_EVENTS, self._handle_bulk_image_download_event)
123+
register_events(WORKFLOW_EVENTS, self._handle_workflow_event)
118124

119125
async def _handle_connect(self, sid: str, environ: dict, auth: dict | None) -> bool:
120126
"""Handle socket connection and authenticate the user.
@@ -145,6 +151,10 @@ async def _handle_connect(self, sid: str, environ: dict, auth: dict | None) -> b
145151
logger.info(
146152
f"Socket {sid} connected with user_id: {token_data.user_id}, is_admin: {token_data.is_admin}"
147153
)
154+
await self._sio.enter_room(sid, f"user:{token_data.user_id}")
155+
await self._sio.enter_room(sid, "workflows:shared")
156+
if token_data.is_admin:
157+
await self._sio.enter_room(sid, "admin")
148158
return True
149159

150160
# If no valid token, store system user for backward compatibility
@@ -266,3 +276,28 @@ async def _handle_model_event(self, event: FastAPIEvent[ModelEventBase | Downloa
266276

267277
async def _handle_bulk_image_download_event(self, event: FastAPIEvent[BulkDownloadEventBase]) -> None:
268278
await self._sio.emit(event=event[0], data=event[1].model_dump(mode="json"), room=event[1].bulk_download_id)
279+
280+
async def _handle_workflow_event(self, event: FastAPIEvent[WorkflowEventBase]) -> None:
281+
event_name, event_data = event
282+
payload = event_data.model_dump(mode="json")
283+
284+
await self._sio.emit(event=event_name, data=payload, room=f"user:{event_data.user_id}")
285+
await self._sio.emit(event=event_name, data=payload, room="admin")
286+
287+
if event_name == "workflow_created":
288+
if getattr(event_data, "is_public", False):
289+
await self._sio.emit(event=event_name, data=payload, room="workflows:shared")
290+
return
291+
292+
if event_name == "workflow_deleted":
293+
if getattr(event_data, "is_public", False):
294+
await self._sio.emit(event=event_name, data=payload, room="workflows:shared")
295+
return
296+
297+
if event_name == "workflow_updated":
298+
if getattr(event_data, "new_is_public", False):
299+
await self._sio.emit(event=event_name, data=payload, room="workflows:shared")
300+
elif getattr(event_data, "old_is_public", False):
301+
await self._sio.emit(
302+
event="workflow_deleted", data={"workflow_id": event_data.workflow_id}, room="workflows:shared"
303+
)

invokeai/app/services/events/events_base.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@
3232
QueueItemsRetriedEvent,
3333
QueueItemStatusChangedEvent,
3434
RecallParametersUpdatedEvent,
35+
WorkflowCreatedEvent,
36+
WorkflowDeletedEvent,
37+
WorkflowUpdatedEvent,
3538
)
3639

3740
if TYPE_CHECKING:
@@ -118,6 +121,29 @@ def emit_recall_parameters_updated(self, queue_id: str, parameters: dict) -> Non
118121

119122
# endregion
120123

124+
# region Workflow library
125+
126+
def emit_workflow_created(self, workflow_id: str, user_id: str, is_public: bool) -> None:
127+
"""Emitted when a workflow is created."""
128+
self.dispatch(WorkflowCreatedEvent.build(workflow_id=workflow_id, user_id=user_id, is_public=is_public))
129+
130+
def emit_workflow_updated(self, workflow_id: str, user_id: str, old_is_public: bool, new_is_public: bool) -> None:
131+
"""Emitted when a workflow is updated."""
132+
self.dispatch(
133+
WorkflowUpdatedEvent.build(
134+
workflow_id=workflow_id,
135+
user_id=user_id,
136+
old_is_public=old_is_public,
137+
new_is_public=new_is_public,
138+
)
139+
)
140+
141+
def emit_workflow_deleted(self, workflow_id: str, user_id: str, is_public: bool) -> None:
142+
"""Emitted when a workflow is deleted."""
143+
self.dispatch(WorkflowDeletedEvent.build(workflow_id=workflow_id, user_id=user_id, is_public=is_public))
144+
145+
# endregion
146+
121147
# region Download
122148

123149
def emit_download_started(self, job: "DownloadJob") -> None:

invokeai/app/services/events/events_common.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,60 @@ def build(cls, queue_id: str) -> "QueueClearedEvent":
321321
return cls(queue_id=queue_id)
322322

323323

324+
class WorkflowEventBase(EventBase):
325+
"""Base class for workflow library CRUD events."""
326+
327+
workflow_id: str = Field(description="The ID of the workflow")
328+
user_id: str = Field(description="The owner of the workflow")
329+
330+
331+
@payload_schema.register
332+
class WorkflowCreatedEvent(WorkflowEventBase):
333+
"""Event model for workflow_created"""
334+
335+
__event_name__ = "workflow_created"
336+
337+
is_public: bool = Field(description="Whether the workflow is shared with all users")
338+
339+
@classmethod
340+
def build(cls, workflow_id: str, user_id: str, is_public: bool) -> "WorkflowCreatedEvent":
341+
return cls(workflow_id=workflow_id, user_id=user_id, is_public=is_public)
342+
343+
344+
@payload_schema.register
345+
class WorkflowUpdatedEvent(WorkflowEventBase):
346+
"""Event model for workflow_updated"""
347+
348+
__event_name__ = "workflow_updated"
349+
350+
old_is_public: bool = Field(description="Whether the workflow was shared before the update")
351+
new_is_public: bool = Field(description="Whether the workflow is shared after the update")
352+
353+
@classmethod
354+
def build(
355+
cls, workflow_id: str, user_id: str, old_is_public: bool, new_is_public: bool
356+
) -> "WorkflowUpdatedEvent":
357+
return cls(
358+
workflow_id=workflow_id,
359+
user_id=user_id,
360+
old_is_public=old_is_public,
361+
new_is_public=new_is_public,
362+
)
363+
364+
365+
@payload_schema.register
366+
class WorkflowDeletedEvent(WorkflowEventBase):
367+
"""Event model for workflow_deleted"""
368+
369+
__event_name__ = "workflow_deleted"
370+
371+
is_public: bool = Field(description="Whether the workflow was shared when it was deleted")
372+
373+
@classmethod
374+
def build(cls, workflow_id: str, user_id: str, is_public: bool) -> "WorkflowDeletedEvent":
375+
return cls(workflow_id=workflow_id, user_id=user_id, is_public=is_public)
376+
377+
324378
class DownloadEventBase(EventBase):
325379
"""Base class for events associated with a download"""
326380

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
import { LIST_TAG } from 'services/api';
2+
import { describe, expect, it, vi } from 'vitest';
3+
4+
import { setEventListeners } from './setEventListeners';
5+
6+
vi.mock('app/logging/logger', () => ({
7+
logger: () => ({
8+
debug: vi.fn(),
9+
trace: vi.fn(),
10+
info: vi.fn(),
11+
warn: vi.fn(),
12+
error: vi.fn(),
13+
}),
14+
}));
15+
16+
vi.mock('features/toast/toast', () => ({
17+
toast: vi.fn(),
18+
}));
19+
20+
vi.mock('./onInvocationComplete', () => ({
21+
buildOnInvocationComplete: () => vi.fn(),
22+
}));
23+
24+
vi.mock('./onModelInstallError', () => ({
25+
buildOnModelInstallError: () => vi.fn(),
26+
DiscordLink: () => null,
27+
GitHubIssuesLink: () => null,
28+
}));
29+
30+
const createMockSocket = () => {
31+
const handlers = new Map<string, (...args: Array<unknown>) => void>();
32+
33+
return {
34+
on: vi.fn((event: string, handler: (...args: Array<unknown>) => void) => {
35+
handlers.set(event, handler);
36+
}),
37+
emit: vi.fn(),
38+
trigger: (event: string, payload?: unknown) => {
39+
const handler = handlers.get(event);
40+
if (!handler) {
41+
throw new Error(`No handler registered for ${event}`);
42+
}
43+
handler(payload);
44+
},
45+
};
46+
};
47+
48+
describe('setEventListeners workflow live updates', () => {
49+
it('invalidates workflow list caches on workflow_created', () => {
50+
const socket = createMockSocket();
51+
const dispatch = vi.fn();
52+
const store = {
53+
dispatch,
54+
getState: vi.fn(() => ({})),
55+
};
56+
57+
setEventListeners({
58+
socket: socket as never,
59+
store: store as never,
60+
setIsConnected: vi.fn(),
61+
});
62+
63+
socket.trigger('workflow_created', { workflow_id: 'wf-1', is_public: true });
64+
65+
expect(dispatch).toHaveBeenCalledWith(
66+
expect.objectContaining({
67+
payload: expect.arrayContaining([
68+
{ type: 'Workflow', id: LIST_TAG },
69+
'WorkflowTags',
70+
'WorkflowTagCounts',
71+
'WorkflowCategoryCounts',
72+
]),
73+
})
74+
);
75+
});
76+
77+
it('ignores unrelated events for workflow cache invalidation', () => {
78+
const socket = createMockSocket();
79+
const dispatch = vi.fn();
80+
const store = {
81+
dispatch,
82+
getState: vi.fn(() => ({})),
83+
};
84+
85+
setEventListeners({
86+
socket: socket as never,
87+
store: store as never,
88+
setIsConnected: vi.fn(),
89+
});
90+
91+
socket.trigger('download_started', { source: 'x', download_path: '/tmp/x' });
92+
93+
expect(dispatch).not.toHaveBeenCalledWith(
94+
expect.objectContaining({
95+
payload: expect.arrayContaining([{ type: 'Workflow', id: LIST_TAG }]),
96+
})
97+
);
98+
});
99+
});

invokeai/frontend/web/src/services/events/setEventListeners.tsx

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,32 @@ export const setEventListeners = ({ socket, store, setIsConnected }: SetEventLis
9898
setIsConnected(false);
9999
});
100100

101+
const invalidateWorkflowLibrary = () => {
102+
dispatch(
103+
api.util.invalidateTags([
104+
{ type: 'Workflow', id: LIST_TAG },
105+
'WorkflowTags',
106+
'WorkflowTagCounts',
107+
'WorkflowCategoryCounts',
108+
])
109+
);
110+
};
111+
112+
socket.on('workflow_created', (data) => {
113+
log.debug({ data }, 'Workflow created');
114+
invalidateWorkflowLibrary();
115+
});
116+
117+
socket.on('workflow_updated', (data) => {
118+
log.debug({ data }, 'Workflow updated');
119+
invalidateWorkflowLibrary();
120+
});
121+
122+
socket.on('workflow_deleted', (data) => {
123+
log.debug({ data }, 'Workflow deleted');
124+
invalidateWorkflowLibrary();
125+
});
126+
101127
socket.on('invocation_started', (data) => {
102128
if (finishedQueueItemIds.has(data.item_id)) {
103129
return;

0 commit comments

Comments
 (0)