diff --git a/invokeai/app/api/routers/client_state.py b/invokeai/app/api/routers/client_state.py index 2e34ea9fe6b..cd92263f97c 100644 --- a/invokeai/app/api/routers/client_state.py +++ b/invokeai/app/api/routers/client_state.py @@ -45,6 +45,44 @@ async def set_client_state( raise HTTPException(status_code=500, detail="Error setting client state") +@client_state_router.get( + "/{queue_id}/get_keys_by_prefix", + operation_id="get_client_state_keys_by_prefix", + response_model=list[str], +) +async def get_client_state_keys_by_prefix( + current_user: CurrentUserOrDefault, + queue_id: str = Path(description="The queue id (ignored, kept for backwards compatibility)"), + prefix: str = Query(..., description="Prefix to filter keys by"), +) -> list[str]: + """Gets client state keys matching a prefix for the current user""" + try: + return ApiDependencies.invoker.services.client_state_persistence.get_keys_by_prefix( + current_user.user_id, prefix + ) + except Exception as e: + logging.error(f"Error getting client state keys: {e}") + raise HTTPException(status_code=500, detail="Error getting client state keys") + + +@client_state_router.post( + "/{queue_id}/delete_by_key", + operation_id="delete_client_state_by_key", + responses={204: {"description": "Client state key deleted"}}, +) +async def delete_client_state_by_key( + current_user: CurrentUserOrDefault, + queue_id: str = Path(description="The queue id (ignored, kept for backwards compatibility)"), + key: str = Query(..., description="Key to delete"), +) -> None: + """Deletes a specific client state key for the current user""" + try: + ApiDependencies.invoker.services.client_state_persistence.delete_by_key(current_user.user_id, key) + except Exception as e: + logging.error(f"Error deleting client state key: {e}") + raise HTTPException(status_code=500, detail="Error deleting client state key") + + @client_state_router.post( "/{queue_id}/delete", operation_id="delete_client_state", diff --git a/invokeai/app/services/client_state_persistence/client_state_persistence_base.py b/invokeai/app/services/client_state_persistence/client_state_persistence_base.py index 99ad71bc8b7..7be6841a790 100644 --- a/invokeai/app/services/client_state_persistence/client_state_persistence_base.py +++ b/invokeai/app/services/client_state_persistence/client_state_persistence_base.py @@ -36,6 +36,31 @@ def get_by_key(self, user_id: str, key: str) -> str | None: """ pass + @abstractmethod + def get_keys_by_prefix(self, user_id: str, prefix: str) -> list[str]: + """ + Get all keys matching a prefix for a user. + + Args: + user_id (str): The user ID to get keys for. + prefix (str): The prefix to filter keys by. + + Returns: + list[str]: A list of keys matching the prefix. + """ + pass + + @abstractmethod + def delete_by_key(self, user_id: str, key: str) -> None: + """ + Delete a specific key-value pair for a user. + + Args: + user_id (str): The user ID to delete state for. + key (str): The key to delete. + """ + pass + @abstractmethod def delete(self, user_id: str) -> None: """ diff --git a/invokeai/app/services/client_state_persistence/client_state_persistence_sqlite.py b/invokeai/app/services/client_state_persistence/client_state_persistence_sqlite.py index 643db306857..7a0c0f9f4c9 100644 --- a/invokeai/app/services/client_state_persistence/client_state_persistence_sqlite.py +++ b/invokeai/app/services/client_state_persistence/client_state_persistence_sqlite.py @@ -44,6 +44,28 @@ def get_by_key(self, user_id: str, key: str) -> str | None: return None return row[0] + def get_keys_by_prefix(self, user_id: str, prefix: str) -> list[str]: + with self._db.transaction() as cursor: + cursor.execute( + """ + SELECT key FROM client_state + WHERE user_id = ? AND key LIKE ? + ORDER BY updated_at DESC + """, + (user_id, f"{prefix}%"), + ) + return [row[0] for row in cursor.fetchall()] + + def delete_by_key(self, user_id: str, key: str) -> None: + with self._db.transaction() as cursor: + cursor.execute( + """ + DELETE FROM client_state + WHERE user_id = ? AND key = ? + """, + (user_id, key), + ) + def delete(self, user_id: str) -> None: with self._db.transaction() as cursor: cursor.execute( diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index 61253826179..525fcd4f4c9 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -2997,6 +2997,23 @@ "switchOnStartDesc": "Switch on start", "switchOnFinish": "On Finish", "switchOnFinishDesc": "Switch on finish" + }, + "snapshot": { + "snapshots": "Save or Load Canvas Snapshot", + "saveSnapshot": "Save Snapshot", + "restoreSnapshot": "Restore Snapshot", + "snapshotNamePlaceholder": "Snapshot name", + "save": "Save", + "delete": "Delete", + "snapshotSaved": "Snapshot \"{{name}}\" saved", + "snapshotRestored": "Snapshot \"{{name}}\" restored", + "snapshotDeleted": "Snapshot \"{{name}}\" deleted", + "snapshotSaveFailed": "Failed to save snapshot", + "snapshotRestoreFailed": "Failed to restore snapshot", + "snapshotDeleteFailed": "Failed to delete snapshot", + "snapshotMissingImages_one": "{{count}} image referenced by this snapshot no longer exists and will appear as a placeholder", + "snapshotMissingImages_other": "{{count}} images referenced by this snapshot no longer exist and will appear as placeholders", + "snapshotIncompatible": "This snapshot was created with a different version and is no longer compatible" } }, "upscaling": { diff --git a/invokeai/frontend/web/src/features/controlLayers/components/Toolbar/CanvasToolbar.tsx b/invokeai/frontend/web/src/features/controlLayers/components/Toolbar/CanvasToolbar.tsx index bf186ed6300..76533605965 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/Toolbar/CanvasToolbar.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/Toolbar/CanvasToolbar.tsx @@ -14,6 +14,7 @@ import { CanvasToolbarRedoButton } from 'features/controlLayers/components/Toolb import { CanvasToolbarResetViewButton } from 'features/controlLayers/components/Toolbar/CanvasToolbarResetViewButton'; import { CanvasToolbarSaveToGalleryButton } from 'features/controlLayers/components/Toolbar/CanvasToolbarSaveToGalleryButton'; import { CanvasToolbarScale } from 'features/controlLayers/components/Toolbar/CanvasToolbarScale'; +import { CanvasToolbarSnapshotMenuButton } from 'features/controlLayers/components/Toolbar/CanvasToolbarSnapshotMenuButton'; import { CanvasToolbarUndoButton } from 'features/controlLayers/components/Toolbar/CanvasToolbarUndoButton'; import { useCanvasDeleteLayerHotkey } from 'features/controlLayers/hooks/useCanvasDeleteLayerHotkey'; import { useCanvasEntityQuickSwitchHotkey } from 'features/controlLayers/hooks/useCanvasEntityQuickSwitchHotkey'; @@ -68,6 +69,7 @@ export const CanvasToolbar = memo(() => { + diff --git a/invokeai/frontend/web/src/features/controlLayers/components/Toolbar/CanvasToolbarSnapshotMenuButton.tsx b/invokeai/frontend/web/src/features/controlLayers/components/Toolbar/CanvasToolbarSnapshotMenuButton.tsx new file mode 100644 index 00000000000..9d21041c6e1 --- /dev/null +++ b/invokeai/frontend/web/src/features/controlLayers/components/Toolbar/CanvasToolbarSnapshotMenuButton.tsx @@ -0,0 +1,198 @@ +import { + Flex, + IconButton, + Input, + Menu, + MenuButton, + MenuDivider, + MenuGroup, + MenuItem, + MenuList, + Text, +} from '@invoke-ai/ui-library'; +import type { SnapshotInfo } from 'features/controlLayers/hooks/useCanvasSnapshots'; +import { useCanvasSnapshots } from 'features/controlLayers/hooks/useCanvasSnapshots'; +import { useCanvasIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice'; +import { toast } from 'features/toast/toast'; +import type { ChangeEvent, KeyboardEvent, MouseEvent } from 'react'; +import { memo, useCallback, useState } from 'react'; +import { useTranslation } from 'react-i18next'; +import { PiCameraBold, PiFloppyDiskBold, PiTrashBold } from 'react-icons/pi'; + +const SnapshotItem = memo( + ({ + snapshot, + onRestore, + onDelete, + isRestoreDisabled, + }: { + snapshot: SnapshotInfo; + onRestore: (key: string, name: string) => void; + onDelete: (e: MouseEvent, key: string, name: string) => void; + isRestoreDisabled: boolean; + }) => { + const handleClick = useCallback(() => { + onRestore(snapshot.key, snapshot.name); + }, [onRestore, snapshot.key, snapshot.name]); + + const handleDelete = useCallback( + (e: MouseEvent) => { + onDelete(e, snapshot.key, snapshot.name); + }, + [onDelete, snapshot.key, snapshot.name] + ); + + return ( + + + + {snapshot.name} + + } + size="xs" + variant="ghost" + colorScheme="error" + onClick={handleDelete} + isDisabled={isRestoreDisabled} + /> + + + ); + } +); + +SnapshotItem.displayName = 'SnapshotItem'; + +const getDefaultSnapshotName = (): string => { + const now = new Date(); + const y = now.getFullYear(); + const mo = String(now.getMonth() + 1).padStart(2, '0'); + const d = String(now.getDate()).padStart(2, '0'); + const h = String(now.getHours()).padStart(2, '0'); + const mi = String(now.getMinutes()).padStart(2, '0'); + return `${y}/${mo}/${d} ${h}:${mi}`; +}; + +export const CanvasToolbarSnapshotMenuButton = memo(() => { + const { t } = useTranslation(); + const { snapshots, saveSnapshot, restoreSnapshot, deleteSnapshot } = useCanvasSnapshots(); + const isStaging = useCanvasIsStaging(); + const [snapshotName, setSnapshotName] = useState(''); + + const onNameChange = useCallback((e: ChangeEvent) => { + setSnapshotName(e.target.value); + }, []); + + const onSave = useCallback(async () => { + const name = snapshotName.trim() || getDefaultSnapshotName(); + const success = await saveSnapshot(name); + if (success) { + toast({ title: t('controlLayers.snapshot.snapshotSaved', { name }), status: 'info' }); + setSnapshotName(''); + } else { + toast({ title: t('controlLayers.snapshot.snapshotSaveFailed'), status: 'error' }); + } + }, [snapshotName, saveSnapshot, t]); + + const onKeyDown = useCallback( + (e: KeyboardEvent) => { + if (e.key === 'Enter') { + e.preventDefault(); + e.stopPropagation(); + onSave(); + } + }, + [onSave] + ); + + const onRestore = useCallback( + async (key: string, name: string) => { + const result = await restoreSnapshot(key); + if (result.success) { + if (result.missingImageCount && result.missingImageCount > 0) { + toast({ + title: t('controlLayers.snapshot.snapshotRestored', { name }), + description: t('controlLayers.snapshot.snapshotMissingImages', { count: result.missingImageCount }), + status: 'warning', + }); + } else { + toast({ title: t('controlLayers.snapshot.snapshotRestored', { name }), status: 'info' }); + } + } else if (result.error === 'incompatible') { + toast({ + title: t('controlLayers.snapshot.snapshotRestoreFailed'), + description: t('controlLayers.snapshot.snapshotIncompatible'), + status: 'error', + }); + } else { + toast({ title: t('controlLayers.snapshot.snapshotRestoreFailed'), status: 'error' }); + } + }, + [restoreSnapshot, t] + ); + + const onDelete = useCallback( + async (e: MouseEvent, key: string, name: string) => { + e.stopPropagation(); + const success = await deleteSnapshot(key); + if (success) { + toast({ title: t('controlLayers.snapshot.snapshotDeleted', { name }), status: 'info' }); + } else { + toast({ title: t('controlLayers.snapshot.snapshotDeleteFailed'), status: 'error' }); + } + }, + [deleteSnapshot, t] + ); + + return ( + + } + variant="link" + alignSelf="stretch" + /> + + + + + } + size="sm" + onClick={onSave} + /> + + + {snapshots.length > 0 && ( + <> + + + {snapshots.map((snapshot) => ( + + ))} + + + )} + + + ); +}); + +CanvasToolbarSnapshotMenuButton.displayName = 'CanvasToolbarSnapshotMenuButton'; diff --git a/invokeai/frontend/web/src/features/controlLayers/hooks/useCanvasSnapshots.ts b/invokeai/frontend/web/src/features/controlLayers/hooks/useCanvasSnapshots.ts new file mode 100644 index 00000000000..2c6cf0af9bd --- /dev/null +++ b/invokeai/frontend/web/src/features/controlLayers/hooks/useCanvasSnapshots.ts @@ -0,0 +1,196 @@ +import { logger } from 'app/logging/logger'; +import { useAppDispatch, useAppStore } from 'app/store/storeHooks'; +import { canvasSnapshotRestored } from 'features/controlLayers/store/canvasSlice'; +import { selectCanvasSlice } from 'features/controlLayers/store/selectors'; +import type { CanvasState } from 'features/controlLayers/store/types'; +import { zCanvasState } from 'features/controlLayers/store/types'; +import { useCallback, useMemo } from 'react'; +import { serializeError } from 'serialize-error'; +import { appInfoApi } from 'services/api/endpoints/appInfo'; +import { + clientStateApi, + useDeleteClientStateByKeyMutation, + useGetClientStateKeysByPrefixQuery, + useSetClientStateByKeyMutation, +} from 'services/api/endpoints/clientState'; +import { getImageDTOSafe } from 'services/api/endpoints/images'; +import type { JsonObject } from 'type-fest'; +import { z } from 'zod'; + +const log = logger('canvas'); + +const SNAPSHOT_PREFIX = 'canvas_snapshot:'; + +/** + * Collect all unique image_name references from a canvas state. + */ +const collectImageNames = (state: CanvasState): string[] => { + const names = new Set(); + + const entityGroups = [state.rasterLayers, state.controlLayers, state.inpaintMasks, state.regionalGuidance]; + for (const group of entityGroups) { + for (const entity of group.entities) { + for (const obj of entity.objects) { + if (obj.type === 'image' && 'image_name' in obj.image) { + names.add(obj.image.image_name); + } + } + } + } + + // Regional guidance reference images (IP Adapter / FLUX Redux) + for (const entity of state.regionalGuidance.entities) { + for (const ref of entity.referenceImages) { + if (ref.config.image && 'image_name' in ref.config.image) { + names.add(ref.config.image.image_name); + } + } + } + + return [...names]; +}; + +/** + * Quick health check to determine if the backend is reachable. + * Uses the existing appInfoApi RTKQ endpoint for consistency. + */ +const isBackendReachable = async (dispatch: ReturnType): Promise => { + const req = dispatch(appInfoApi.endpoints.getAppVersion.initiate(undefined, { subscribe: false })); + try { + await req.unwrap(); + return true; + } catch { + return false; + } finally { + req.unsubscribe(); + } +}; + +/** + * Check which image_names still exist on the server. + * Returns the list of missing image names. If the backend is unreachable, + * skips all checks and returns an empty array to avoid false warnings. + */ +const findMissingImages = async ( + imageNames: string[], + dispatch: ReturnType +): Promise => { + // Pre-flight: verify backend is reachable before checking individual images + if (!(await isBackendReachable(dispatch))) { + log.warn('Backend unreachable — skipping missing image check'); + return []; + } + + const results = await Promise.all( + imageNames.map(async (name) => { + const dto = await getImageDTOSafe(name); + return dto === null ? name : null; + }) + ); + return results.filter((name): name is string => name !== null); +}; + +export type SnapshotInfo = { + key: string; + name: string; +}; + +type RestoreResult = { + success: boolean; + missingImageCount?: number; + error?: 'incompatible' | 'not_found' | 'unknown'; +}; + +export const useCanvasSnapshots = () => { + const dispatch = useAppDispatch(); + const store = useAppStore(); + + const { data: keys } = useGetClientStateKeysByPrefixQuery(SNAPSHOT_PREFIX); + const [setClientState] = useSetClientStateByKeyMutation(); + const [deleteClientState] = useDeleteClientStateByKeyMutation(); + + const snapshots: SnapshotInfo[] = useMemo( + () => + (keys ?? []).map((key) => ({ + key, + name: key.slice(SNAPSHOT_PREFIX.length), + })), + [keys] + ); + + const saveSnapshot = useCallback( + async (name: string) => { + try { + const state = selectCanvasSlice(store.getState()); + const value = JSON.stringify(state); + const key = `${SNAPSHOT_PREFIX}${name}`; + await setClientState({ key, value }).unwrap(); + return true; + } catch (e) { + log.error({ error: serializeError(e) } as JsonObject, 'Failed to save snapshot'); + return false; + } + }, + [store, setClientState] + ); + + const restoreSnapshot = useCallback( + async (key: string): Promise => { + const req = dispatch(clientStateApi.endpoints.getClientStateByKey.initiate(key, { subscribe: false })); + try { + const raw = await req.unwrap(); + if (raw === null) { + throw new Error('Snapshot data not found'); + } + const parsed = JSON.parse(raw); + const canvasState = zCanvasState.parse(parsed); + + // Check for missing images before restoring + const imageNames = collectImageNames(canvasState); + const missingImages = imageNames.length > 0 ? await findMissingImages(imageNames, dispatch) : []; + + if (missingImages.length > 0) { + log.warn( + { missingCount: missingImages.length, total: imageNames.length } as unknown as JsonObject, + 'Snapshot references images that no longer exist' + ); + } + + dispatch(canvasSnapshotRestored(canvasState)); + return { success: true, missingImageCount: missingImages.length }; + } catch (e) { + log.error({ error: serializeError(e) } as JsonObject, 'Failed to restore snapshot'); + // Distinguish Zod validation errors (incompatible snapshot) from other failures + const isZodError = e instanceof z.ZodError; + const isNotFound = e instanceof Error && e.message === 'Snapshot data not found'; + return { + success: false, + error: isZodError ? 'incompatible' : isNotFound ? 'not_found' : 'unknown', + }; + } finally { + req.unsubscribe(); + } + }, + [dispatch] + ); + + const deleteSnapshot = useCallback( + async (key: string) => { + try { + await deleteClientState(key).unwrap(); + return true; + } catch (e) { + log.error({ error: serializeError(e) } as JsonObject, 'Failed to delete snapshot'); + return false; + } + }, + [deleteClientState] + ); + + return { + snapshots, + saveSnapshot, + restoreSnapshot, + deleteSnapshot, + }; +}; diff --git a/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts b/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts index 79d3963d122..bfdea7b1de5 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts @@ -1710,6 +1710,22 @@ const slice = createSlice({ state.regionalGuidance.entities = regionalGuidance; return state; }, + canvasSnapshotRestored: (state, action: PayloadAction) => { + const snapshot = action.payload; + state.controlLayers = snapshot.controlLayers; + state.inpaintMasks = snapshot.inpaintMasks; + state.rasterLayers = snapshot.rasterLayers; + state.regionalGuidance = snapshot.regionalGuidance; + // Restore bbox from snapshot but preserve the current modelBase to avoid desync + // with the currently selected model (same pattern as resetState). + const currentModelBase = state.bbox.modelBase; + state.bbox = snapshot.bbox; + state.bbox.modelBase = currentModelBase; + syncScaledSize(state); + state.selectedEntityIdentifier = snapshot.selectedEntityIdentifier; + state.bookmarkedEntityIdentifier = snapshot.bookmarkedEntityIdentifier; + return state; + }, canvasUndo: () => {}, canvasRedo: () => {}, canvasClearHistory: () => {}, @@ -1768,6 +1784,7 @@ const resetState = (state: CanvasState) => { export const { canvasMetadataRecalled, + canvasSnapshotRestored, canvasUndo, canvasRedo, canvasClearHistory, @@ -1893,6 +1910,10 @@ const canvasUndoableConfig: UndoableOptions = { if (!action.type.startsWith(slice.name)) { return false; } + // Snapshot restore replaces the canvas state and should not be undoable + if (action.type === canvasSnapshotRestored.type) { + return false; + } // Throttle rapid actions of the same type filter = actionsThrottlingFilter(action); return filter; diff --git a/invokeai/frontend/web/src/services/api/endpoints/clientState.ts b/invokeai/frontend/web/src/services/api/endpoints/clientState.ts new file mode 100644 index 00000000000..5d3cc96d226 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/endpoints/clientState.ts @@ -0,0 +1,48 @@ +import { api, buildV1Url, LIST_TAG } from '..'; + +/** + * Builds an endpoint URL for the client_state router. + * The queue_id path parameter is kept as 'default' for backwards compatibility. + */ +const buildClientStateUrl = (path: string, query?: Record) => + buildV1Url(`client_state/default/${path}`, query); + +export const clientStateApi = api.injectEndpoints({ + endpoints: (build) => ({ + getClientStateKeysByPrefix: build.query({ + query: (prefix) => ({ + url: buildClientStateUrl('get_keys_by_prefix', { prefix }), + method: 'GET', + }), + providesTags: [{ type: 'ClientState', id: LIST_TAG }, 'FetchOnReconnect'], + }), + getClientStateByKey: build.query({ + query: (key) => ({ + url: buildClientStateUrl('get_by_key', { key }), + method: 'GET', + }), + }), + setClientStateByKey: build.mutation({ + query: ({ key, value }) => ({ + url: buildClientStateUrl('set_by_key', { key }), + method: 'POST', + // Send raw string body — the backend expects Body(...) as a plain string, + // not JSON-encoded. Setting Content-Type to text/plain prevents fetchBaseQuery + // from JSON-stringifying the body. + headers: { 'Content-Type': 'text/plain' }, + body: value, + }), + invalidatesTags: [{ type: 'ClientState', id: LIST_TAG }], + }), + deleteClientStateByKey: build.mutation({ + query: (key) => ({ + url: buildClientStateUrl('delete_by_key', { key }), + method: 'POST', + }), + invalidatesTags: [{ type: 'ClientState', id: LIST_TAG }], + }), + }), +}); + +export const { useGetClientStateKeysByPrefixQuery, useSetClientStateByKeyMutation, useDeleteClientStateByKeyMutation } = + clientStateApi; diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts index 18216bca2c0..dfb62a53c9b 100644 --- a/invokeai/frontend/web/src/services/api/schema.ts +++ b/invokeai/frontend/web/src/services/api/schema.ts @@ -2408,6 +2408,46 @@ export type paths = { patch?: never; trace?: never; }; + "/api/v1/client_state/{queue_id}/get_keys_by_prefix": { + parameters: { + query?: never; + header?: never; + path?: never; + cookie?: never; + }; + /** + * Get Client State Keys By Prefix + * @description Gets client state keys matching a prefix for the current user + */ + get: operations["get_client_state_keys_by_prefix"]; + put?: never; + post?: never; + delete?: never; + options?: never; + head?: never; + patch?: never; + trace?: never; + }; + "/api/v1/client_state/{queue_id}/delete_by_key": { + parameters: { + query?: never; + header?: never; + path?: never; + cookie?: never; + }; + get?: never; + put?: never; + /** + * Delete Client State By Key + * @description Deletes a specific client state key for the current user + */ + post: operations["delete_client_state_by_key"]; + delete?: never; + options?: never; + head?: never; + patch?: never; + trace?: never; + }; "/api/v1/client_state/{queue_id}/delete": { parameters: { query?: never; @@ -34518,6 +34558,83 @@ export interface operations { }; }; }; + get_client_state_keys_by_prefix: { + parameters: { + query: { + /** @description Prefix to filter keys by */ + prefix: string; + }; + header?: never; + path: { + /** @description The queue id (ignored, kept for backwards compatibility) */ + queue_id: string; + }; + cookie?: never; + }; + requestBody?: never; + responses: { + /** @description Successful Response */ + 200: { + headers: { + [name: string]: unknown; + }; + content: { + "application/json": string[]; + }; + }; + /** @description Validation Error */ + 422: { + headers: { + [name: string]: unknown; + }; + content: { + "application/json": components["schemas"]["HTTPValidationError"]; + }; + }; + }; + }; + delete_client_state_by_key: { + parameters: { + query: { + /** @description Key to delete */ + key: string; + }; + header?: never; + path: { + /** @description The queue id (ignored, kept for backwards compatibility) */ + queue_id: string; + }; + cookie?: never; + }; + requestBody?: never; + responses: { + /** @description Successful Response */ + 200: { + headers: { + [name: string]: unknown; + }; + content: { + "application/json": unknown; + }; + }; + /** @description Client state key deleted */ + 204: { + headers: { + [name: string]: unknown; + }; + content?: never; + }; + /** @description Validation Error */ + 422: { + headers: { + [name: string]: unknown; + }; + content: { + "application/json": components["schemas"]["HTTPValidationError"]; + }; + }; + }; + }; delete_client_state: { parameters: { query?: never; diff --git a/tests/app/routers/test_client_state_multiuser.py b/tests/app/routers/test_client_state_multiuser.py index 814c9182fec..4ca1de3bf49 100644 --- a/tests/app/routers/test_client_state_multiuser.py +++ b/tests/app/routers/test_client_state_multiuser.py @@ -297,3 +297,148 @@ def test_complex_json_values(client: TestClient, admin_token: str): ) assert get_response.status_code == status.HTTP_200_OK assert get_response.json() == complex_value + + +def test_get_keys_by_prefix_without_auth(client: TestClient, monkeypatch, mock_invoker: Invoker): + """Test that keys can be retrieved by prefix without authentication.""" + monkeypatch.setattr("invokeai.app.api.auth_dependencies.ApiDependencies", MockApiDependencies(mock_invoker)) + monkeypatch.setattr("invokeai.app.api.routers.client_state.ApiDependencies", MockApiDependencies(mock_invoker)) + + # Set several keys with a common prefix directly + for i in range(3): + mock_invoker.services.client_state_persistence.set_by_key("system", f"canvas_snapshot:snap{i}", f"value{i}") + mock_invoker.services.client_state_persistence.set_by_key("system", "other_key", "other_value") + + # Get keys by prefix + response = client.get("/api/v1/client_state/default/get_keys_by_prefix?prefix=canvas_snapshot:") + assert response.status_code == status.HTTP_200_OK + keys = response.json() + assert len(keys) == 3 + assert "canvas_snapshot:snap0" in keys + assert "canvas_snapshot:snap1" in keys + assert "canvas_snapshot:snap2" in keys + assert "other_key" not in keys + + +def test_get_keys_by_prefix_empty_without_auth(client: TestClient, monkeypatch, mock_invoker: Invoker): + """Test that an empty list is returned when no keys match the prefix.""" + monkeypatch.setattr("invokeai.app.api.auth_dependencies.ApiDependencies", MockApiDependencies(mock_invoker)) + monkeypatch.setattr("invokeai.app.api.routers.client_state.ApiDependencies", MockApiDependencies(mock_invoker)) + + response = client.get("/api/v1/client_state/default/get_keys_by_prefix?prefix=nonexistent_prefix:") + assert response.status_code == status.HTTP_200_OK + assert response.json() == [] + + +def test_delete_by_key_without_auth(client: TestClient, monkeypatch, mock_invoker: Invoker): + """Test that a specific key can be deleted without affecting other keys.""" + monkeypatch.setattr("invokeai.app.api.auth_dependencies.ApiDependencies", MockApiDependencies(mock_invoker)) + monkeypatch.setattr("invokeai.app.api.routers.client_state.ApiDependencies", MockApiDependencies(mock_invoker)) + + # Set two keys directly + mock_invoker.services.client_state_persistence.set_by_key("system", "keep_key", "keep_value") + mock_invoker.services.client_state_persistence.set_by_key("system", "delete_key", "delete_value") + + # Delete only one key via endpoint + delete_response = client.post("/api/v1/client_state/default/delete_by_key?key=delete_key") + assert delete_response.status_code == status.HTTP_200_OK + + # Verify deleted key is gone + value = mock_invoker.services.client_state_persistence.get_by_key("system", "delete_key") + assert value is None + + # Verify other key still exists + value = mock_invoker.services.client_state_persistence.get_by_key("system", "keep_key") + assert value == "keep_value" + + +def test_get_keys_by_prefix(client: TestClient, admin_token: str): + """Test that keys can be retrieved by prefix with authentication.""" + # Set several keys with a common prefix + for i in range(3): + client.post( + f"/api/v1/client_state/default/set_by_key?key=canvas_snapshot:snap{i}", + json=f"value{i}", + headers={"Authorization": f"Bearer {admin_token}"}, + ) + # Set a key without the prefix + client.post( + "/api/v1/client_state/default/set_by_key?key=other_key", + json="other_value", + headers={"Authorization": f"Bearer {admin_token}"}, + ) + + # Get keys by prefix + response = client.get( + "/api/v1/client_state/default/get_keys_by_prefix?prefix=canvas_snapshot:", + headers={"Authorization": f"Bearer {admin_token}"}, + ) + assert response.status_code == status.HTTP_200_OK + keys = response.json() + assert len(keys) == 3 + assert "canvas_snapshot:snap0" in keys + assert "canvas_snapshot:snap1" in keys + assert "canvas_snapshot:snap2" in keys + assert "other_key" not in keys + + +def test_delete_by_key(client: TestClient, admin_token: str): + """Test that a specific key can be deleted without affecting other keys.""" + # Set two keys + client.post( + "/api/v1/client_state/default/set_by_key?key=keep_key", + json="keep_value", + headers={"Authorization": f"Bearer {admin_token}"}, + ) + client.post( + "/api/v1/client_state/default/set_by_key?key=delete_key", + json="delete_value", + headers={"Authorization": f"Bearer {admin_token}"}, + ) + + # Delete only one key + delete_response = client.post( + "/api/v1/client_state/default/delete_by_key?key=delete_key", + headers={"Authorization": f"Bearer {admin_token}"}, + ) + assert delete_response.status_code == status.HTTP_200_OK + + # Verify deleted key is gone + get_response = client.get( + "/api/v1/client_state/default/get_by_key?key=delete_key", + headers={"Authorization": f"Bearer {admin_token}"}, + ) + assert get_response.json() is None + + # Verify other key still exists + get_response = client.get( + "/api/v1/client_state/default/get_by_key?key=keep_key", + headers={"Authorization": f"Bearer {admin_token}"}, + ) + assert get_response.json() == "keep_value" + + +def test_get_keys_by_prefix_isolation_between_users(client: TestClient, user1_token: str, user2_token: str): + """Test that get_keys_by_prefix is isolated between users.""" + # User 1 sets keys + client.post( + "/api/v1/client_state/default/set_by_key?key=snapshot:u1", + json="user1_data", + headers={"Authorization": f"Bearer {user1_token}"}, + ) + + # User 2 sets keys + client.post( + "/api/v1/client_state/default/set_by_key?key=snapshot:u2", + json="user2_data", + headers={"Authorization": f"Bearer {user2_token}"}, + ) + + # User 1 should only see their own keys + response = client.get( + "/api/v1/client_state/default/get_keys_by_prefix?prefix=snapshot:", + headers={"Authorization": f"Bearer {user1_token}"}, + ) + keys = response.json() + assert "snapshot:u1" in keys + assert "snapshot:u2" not in keys