diff --git a/invokeai/app/api/routers/auth.py b/invokeai/app/api/routers/auth.py
index 36aeabda822..e0b0c885cd2 100644
--- a/invokeai/app/api/routers/auth.py
+++ b/invokeai/app/api/routers/auth.py
@@ -80,6 +80,7 @@ class SetupStatusResponse(BaseModel):
setup_required: bool = Field(description="Whether initial setup is required")
multiuser_enabled: bool = Field(description="Whether multiuser mode is enabled")
strict_password_checking: bool = Field(description="Whether strict password requirements are enforced")
+ admin_email: str | None = Field(default=None, description="Email of the first active admin user, if any")
@auth_router.get("/status", response_model=SetupStatusResponse)
@@ -94,15 +95,25 @@ async def get_setup_status() -> SetupStatusResponse:
# If multiuser is disabled, setup is never required
if not config.multiuser:
return SetupStatusResponse(
- setup_required=False, multiuser_enabled=False, strict_password_checking=config.strict_password_checking
+ setup_required=False,
+ multiuser_enabled=False,
+ strict_password_checking=config.strict_password_checking,
+ admin_email=None,
)
# In multiuser mode, check if an admin exists
user_service = ApiDependencies.invoker.services.users
setup_required = not user_service.has_admin()
+ # Only expose admin_email during initial setup to avoid leaking
+ # administrator identity on public deployments.
+ admin_email = user_service.get_admin_email() if setup_required else None
+
return SetupStatusResponse(
- setup_required=setup_required, multiuser_enabled=True, strict_password_checking=config.strict_password_checking
+ setup_required=setup_required,
+ multiuser_enabled=True,
+ strict_password_checking=config.strict_password_checking,
+ admin_email=admin_email,
)
diff --git a/invokeai/app/api/routers/board_images.py b/invokeai/app/api/routers/board_images.py
index cb5e0ab51ab..f94e4f2437c 100644
--- a/invokeai/app/api/routers/board_images.py
+++ b/invokeai/app/api/routers/board_images.py
@@ -1,12 +1,53 @@
from fastapi import Body, HTTPException
from fastapi.routing import APIRouter
+from invokeai.app.api.auth_dependencies import CurrentUserOrDefault
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.services.images.images_common import AddImagesToBoardResult, RemoveImagesFromBoardResult
board_images_router = APIRouter(prefix="/v1/board_images", tags=["boards"])
+def _assert_board_write_access(board_id: str, current_user: CurrentUserOrDefault) -> None:
+ """Raise 403 if the current user may not mutate the given board.
+
+ Write access is granted when ANY of these hold:
+ - The user is an admin.
+ - The user owns the board.
+ - The board visibility is Public (public boards accept contributions from any user).
+ """
+ from invokeai.app.services.board_records.board_records_common import BoardVisibility
+
+ try:
+ board = ApiDependencies.invoker.services.boards.get_dto(board_id=board_id)
+ except Exception:
+ raise HTTPException(status_code=404, detail="Board not found")
+ if current_user.is_admin:
+ return
+ if board.user_id == current_user.user_id:
+ return
+ if board.board_visibility == BoardVisibility.Public:
+ return
+ raise HTTPException(status_code=403, detail="Not authorized to modify this board")
+
+
+def _assert_image_direct_owner(image_name: str, current_user: CurrentUserOrDefault) -> None:
+ """Raise 403 if the current user is not the direct owner of the image.
+
+ This is intentionally stricter than _assert_image_owner in images.py:
+ board ownership is NOT sufficient here. Allowing a user to add someone
+ else's image to their own board would grant them mutation rights via the
+ board-ownership fallback in _assert_image_owner, escalating read access
+ into write access.
+ """
+ if current_user.is_admin:
+ return
+ owner = ApiDependencies.invoker.services.image_records.get_user_id(image_name)
+ if owner is not None and owner == current_user.user_id:
+ return
+ raise HTTPException(status_code=403, detail="Not authorized to move this image")
+
+
@board_images_router.post(
"/",
operation_id="add_image_to_board",
@@ -17,14 +58,17 @@
response_model=AddImagesToBoardResult,
)
async def add_image_to_board(
+ current_user: CurrentUserOrDefault,
board_id: str = Body(description="The id of the board to add to"),
image_name: str = Body(description="The name of the image to add"),
) -> AddImagesToBoardResult:
"""Creates a board_image"""
+ _assert_board_write_access(board_id, current_user)
+ _assert_image_direct_owner(image_name, current_user)
try:
added_images: set[str] = set()
affected_boards: set[str] = set()
- old_board_id = ApiDependencies.invoker.services.images.get_dto(image_name).board_id or "none"
+ old_board_id = ApiDependencies.invoker.services.board_image_records.get_board_for_image(image_name) or "none"
ApiDependencies.invoker.services.board_images.add_image_to_board(board_id=board_id, image_name=image_name)
added_images.add(image_name)
affected_boards.add(board_id)
@@ -48,13 +92,16 @@ async def add_image_to_board(
response_model=RemoveImagesFromBoardResult,
)
async def remove_image_from_board(
+ current_user: CurrentUserOrDefault,
image_name: str = Body(description="The name of the image to remove", embed=True),
) -> RemoveImagesFromBoardResult:
"""Removes an image from its board, if it had one"""
try:
+ old_board_id = ApiDependencies.invoker.services.images.get_dto(image_name).board_id or "none"
+ if old_board_id != "none":
+ _assert_board_write_access(old_board_id, current_user)
removed_images: set[str] = set()
affected_boards: set[str] = set()
- old_board_id = ApiDependencies.invoker.services.images.get_dto(image_name).board_id or "none"
ApiDependencies.invoker.services.board_images.remove_image_from_board(image_name=image_name)
removed_images.add(image_name)
affected_boards.add("none")
@@ -64,6 +111,8 @@ async def remove_image_from_board(
affected_boards=list(affected_boards),
)
+ except HTTPException:
+ raise
except Exception:
raise HTTPException(status_code=500, detail="Failed to remove image from board")
@@ -78,16 +127,21 @@ async def remove_image_from_board(
response_model=AddImagesToBoardResult,
)
async def add_images_to_board(
+ current_user: CurrentUserOrDefault,
board_id: str = Body(description="The id of the board to add to"),
image_names: list[str] = Body(description="The names of the images to add", embed=True),
) -> AddImagesToBoardResult:
"""Adds a list of images to a board"""
+ _assert_board_write_access(board_id, current_user)
try:
added_images: set[str] = set()
affected_boards: set[str] = set()
for image_name in image_names:
try:
- old_board_id = ApiDependencies.invoker.services.images.get_dto(image_name).board_id or "none"
+ _assert_image_direct_owner(image_name, current_user)
+ old_board_id = (
+ ApiDependencies.invoker.services.board_image_records.get_board_for_image(image_name) or "none"
+ )
ApiDependencies.invoker.services.board_images.add_image_to_board(
board_id=board_id,
image_name=image_name,
@@ -96,12 +150,16 @@ async def add_images_to_board(
affected_boards.add(board_id)
affected_boards.add(old_board_id)
+ except HTTPException:
+ raise
except Exception:
pass
return AddImagesToBoardResult(
added_images=list(added_images),
affected_boards=list(affected_boards),
)
+ except HTTPException:
+ raise
except Exception:
raise HTTPException(status_code=500, detail="Failed to add images to board")
@@ -116,6 +174,7 @@ async def add_images_to_board(
response_model=RemoveImagesFromBoardResult,
)
async def remove_images_from_board(
+ current_user: CurrentUserOrDefault,
image_names: list[str] = Body(description="The names of the images to remove", embed=True),
) -> RemoveImagesFromBoardResult:
"""Removes a list of images from their board, if they had one"""
@@ -125,15 +184,21 @@ async def remove_images_from_board(
for image_name in image_names:
try:
old_board_id = ApiDependencies.invoker.services.images.get_dto(image_name).board_id or "none"
+ if old_board_id != "none":
+ _assert_board_write_access(old_board_id, current_user)
ApiDependencies.invoker.services.board_images.remove_image_from_board(image_name=image_name)
removed_images.add(image_name)
affected_boards.add("none")
affected_boards.add(old_board_id)
+ except HTTPException:
+ raise
except Exception:
pass
return RemoveImagesFromBoardResult(
removed_images=list(removed_images),
affected_boards=list(affected_boards),
)
+ except HTTPException:
+ raise
except Exception:
raise HTTPException(status_code=500, detail="Failed to remove images from board")
diff --git a/invokeai/app/api/routers/boards.py b/invokeai/app/api/routers/boards.py
index e93bb8b2a9b..6897e90aff4 100644
--- a/invokeai/app/api/routers/boards.py
+++ b/invokeai/app/api/routers/boards.py
@@ -6,7 +6,7 @@
from invokeai.app.api.auth_dependencies import CurrentUserOrDefault
from invokeai.app.api.dependencies import ApiDependencies
-from invokeai.app.services.board_records.board_records_common import BoardChanges, BoardRecordOrderBy
+from invokeai.app.services.board_records.board_records_common import BoardChanges, BoardRecordOrderBy, BoardVisibility
from invokeai.app.services.boards.boards_common import BoardDTO
from invokeai.app.services.image_records.image_records_common import ImageCategory
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
@@ -56,7 +56,14 @@ async def get_board(
except Exception:
raise HTTPException(status_code=404, detail="Board not found")
- if not current_user.is_admin and result.user_id != current_user.user_id:
+ # Admins can access any board.
+ # Owners can access their own boards.
+ # Shared and public boards are visible to all authenticated users.
+ if (
+ not current_user.is_admin
+ and result.user_id != current_user.user_id
+ and result.board_visibility == BoardVisibility.Private
+ ):
raise HTTPException(status_code=403, detail="Not authorized to access this board")
return result
@@ -188,7 +195,11 @@ async def list_all_board_image_names(
except Exception:
raise HTTPException(status_code=404, detail="Board not found")
- if not current_user.is_admin and board.user_id != current_user.user_id:
+ if (
+ not current_user.is_admin
+ and board.user_id != current_user.user_id
+ and board.board_visibility == BoardVisibility.Private
+ ):
raise HTTPException(status_code=403, detail="Not authorized to access this board")
image_names = ApiDependencies.invoker.services.board_images.get_all_board_image_names_for_board(
@@ -196,4 +207,15 @@ async def list_all_board_image_names(
categories,
is_intermediate,
)
+
+ # For uncategorized images (board_id="none"), filter to only the caller's
+ # images so that one user cannot enumerate another's uncategorized images.
+ # Admin users can see all uncategorized images.
+ if board_id == "none" and not current_user.is_admin:
+ image_names = [
+ name
+ for name in image_names
+ if ApiDependencies.invoker.services.image_records.get_user_id(name) == current_user.user_id
+ ]
+
return image_names
diff --git a/invokeai/app/api/routers/images.py b/invokeai/app/api/routers/images.py
index 6b11762c9ec..a3ae6fce82b 100644
--- a/invokeai/app/api/routers/images.py
+++ b/invokeai/app/api/routers/images.py
@@ -38,6 +38,96 @@
IMAGE_MAX_AGE = 31536000
+def _assert_image_owner(image_name: str, current_user: CurrentUserOrDefault) -> None:
+ """Raise 403 if the current user does not own the image and is not an admin.
+
+ Ownership is satisfied when ANY of these hold:
+ - The user is an admin.
+ - The user is the image's direct owner (image_records.user_id).
+ - The user owns the board the image sits on.
+ - The image sits on a Public board (public boards grant mutation rights).
+ """
+ from invokeai.app.services.board_records.board_records_common import BoardVisibility
+
+ if current_user.is_admin:
+ return
+ owner = ApiDependencies.invoker.services.image_records.get_user_id(image_name)
+ if owner is not None and owner == current_user.user_id:
+ return
+
+ # Check whether the user owns the board the image belongs to,
+ # or the board is Public (public boards grant mutation rights).
+ board_id = ApiDependencies.invoker.services.board_image_records.get_board_for_image(image_name)
+ if board_id is not None:
+ try:
+ board = ApiDependencies.invoker.services.boards.get_dto(board_id=board_id)
+ if board.user_id == current_user.user_id:
+ return
+ if board.board_visibility == BoardVisibility.Public:
+ return
+ except Exception:
+ pass
+
+ raise HTTPException(status_code=403, detail="Not authorized to modify this image")
+
+
+def _assert_image_read_access(image_name: str, current_user: CurrentUserOrDefault) -> None:
+ """Raise 403 if the current user may not view the image.
+
+ Access is granted when ANY of these hold:
+ - The user is an admin.
+ - The user owns the image.
+ - The image sits on a shared or public board.
+ """
+ from invokeai.app.services.board_records.board_records_common import BoardVisibility
+
+ if current_user.is_admin:
+ return
+
+ owner = ApiDependencies.invoker.services.image_records.get_user_id(image_name)
+ if owner is not None and owner == current_user.user_id:
+ return
+
+ # Check whether the image's board makes it visible to other users.
+ board_id = ApiDependencies.invoker.services.board_image_records.get_board_for_image(image_name)
+ if board_id is not None:
+ try:
+ board = ApiDependencies.invoker.services.boards.get_dto(board_id=board_id)
+ if board.board_visibility in (BoardVisibility.Shared, BoardVisibility.Public):
+ return
+ except Exception:
+ pass
+
+ raise HTTPException(status_code=403, detail="Not authorized to access this image")
+
+
+def _assert_board_read_access(board_id: str, current_user: CurrentUserOrDefault) -> None:
+ """Raise 403 if the current user may not read images from this board.
+
+ Access is granted when ANY of these hold:
+ - The user is an admin.
+ - The user owns the board.
+ - The board visibility is Shared or Public.
+ """
+ from invokeai.app.services.board_records.board_records_common import BoardVisibility
+
+ if current_user.is_admin:
+ return
+
+ try:
+ board = ApiDependencies.invoker.services.boards.get_dto(board_id=board_id)
+ except Exception:
+ raise HTTPException(status_code=404, detail="Board not found")
+
+ if board.user_id == current_user.user_id:
+ return
+
+ if board.board_visibility in (BoardVisibility.Shared, BoardVisibility.Public):
+ return
+
+ raise HTTPException(status_code=403, detail="Not authorized to access this board")
+
+
class ResizeToDimensions(BaseModel):
width: int = Field(..., gt=0)
height: int = Field(..., gt=0)
@@ -83,6 +173,22 @@ async def upload_image(
),
) -> ImageDTO:
"""Uploads an image for the current user"""
+ # If uploading into a board, verify the user has write access.
+ # Public boards allow uploads from any authenticated user.
+ if board_id is not None:
+ from invokeai.app.services.board_records.board_records_common import BoardVisibility
+
+ try:
+ board = ApiDependencies.invoker.services.boards.get_dto(board_id=board_id)
+ except Exception:
+ raise HTTPException(status_code=404, detail="Board not found")
+ if (
+ not current_user.is_admin
+ and board.user_id != current_user.user_id
+ and board.board_visibility != BoardVisibility.Public
+ ):
+ raise HTTPException(status_code=403, detail="Not authorized to upload to this board")
+
if not file.content_type or not file.content_type.startswith("image"):
raise HTTPException(status_code=415, detail="Not an image")
@@ -165,9 +271,11 @@ async def create_image_upload_entry(
@images_router.delete("/i/{image_name}", operation_id="delete_image", response_model=DeleteImagesResult)
async def delete_image(
+ current_user: CurrentUserOrDefault,
image_name: str = Path(description="The name of the image to delete"),
) -> DeleteImagesResult:
"""Deletes an image"""
+ _assert_image_owner(image_name, current_user)
deleted_images: set[str] = set()
affected_boards: set[str] = set()
@@ -189,26 +297,31 @@ async def delete_image(
@images_router.delete("/intermediates", operation_id="clear_intermediates")
-async def clear_intermediates() -> int:
- """Clears all intermediates"""
+async def clear_intermediates(
+ current_user: CurrentUserOrDefault,
+) -> int:
+ """Clears all intermediates. Requires admin."""
+ if not current_user.is_admin:
+ raise HTTPException(status_code=403, detail="Only admins can clear all intermediates")
try:
count_deleted = ApiDependencies.invoker.services.images.delete_intermediates()
return count_deleted
except Exception:
raise HTTPException(status_code=500, detail="Failed to clear intermediates")
- pass
@images_router.get("/intermediates", operation_id="get_intermediates_count")
-async def get_intermediates_count() -> int:
- """Gets the count of intermediate images"""
+async def get_intermediates_count(
+ current_user: CurrentUserOrDefault,
+) -> int:
+ """Gets the count of intermediate images. Non-admin users only see their own intermediates."""
try:
- return ApiDependencies.invoker.services.images.get_intermediates_count()
+ user_id = None if current_user.is_admin else current_user.user_id
+ return ApiDependencies.invoker.services.images.get_intermediates_count(user_id=user_id)
except Exception:
raise HTTPException(status_code=500, detail="Failed to get intermediates")
- pass
@images_router.patch(
@@ -217,10 +330,12 @@ async def get_intermediates_count() -> int:
response_model=ImageDTO,
)
async def update_image(
+ current_user: CurrentUserOrDefault,
image_name: str = Path(description="The name of the image to update"),
image_changes: ImageRecordChanges = Body(description="The changes to apply to the image"),
) -> ImageDTO:
"""Updates an image"""
+ _assert_image_owner(image_name, current_user)
try:
return ApiDependencies.invoker.services.images.update(image_name, image_changes)
@@ -234,9 +349,11 @@ async def update_image(
response_model=ImageDTO,
)
async def get_image_dto(
+ current_user: CurrentUserOrDefault,
image_name: str = Path(description="The name of image to get"),
) -> ImageDTO:
"""Gets an image's DTO"""
+ _assert_image_read_access(image_name, current_user)
try:
return ApiDependencies.invoker.services.images.get_dto(image_name)
@@ -250,9 +367,11 @@ async def get_image_dto(
response_model=Optional[MetadataField],
)
async def get_image_metadata(
+ current_user: CurrentUserOrDefault,
image_name: str = Path(description="The name of image to get"),
) -> Optional[MetadataField]:
"""Gets an image's metadata"""
+ _assert_image_read_access(image_name, current_user)
try:
return ApiDependencies.invoker.services.images.get_metadata(image_name)
@@ -269,8 +388,11 @@ class WorkflowAndGraphResponse(BaseModel):
"/i/{image_name}/workflow", operation_id="get_image_workflow", response_model=WorkflowAndGraphResponse
)
async def get_image_workflow(
+ current_user: CurrentUserOrDefault,
image_name: str = Path(description="The name of image whose workflow to get"),
) -> WorkflowAndGraphResponse:
+ _assert_image_read_access(image_name, current_user)
+
try:
workflow = ApiDependencies.invoker.services.images.get_workflow(image_name)
graph = ApiDependencies.invoker.services.images.get_graph(image_name)
@@ -306,8 +428,12 @@ async def get_image_workflow(
async def get_image_full(
image_name: str = Path(description="The name of full-resolution image file to get"),
) -> Response:
- """Gets a full-resolution image file"""
+ """Gets a full-resolution image file.
+ This endpoint is intentionally unauthenticated because browsers load images
+ via
tags which cannot send Bearer tokens. Image names are UUIDs,
+ providing security through unguessability.
+ """
try:
path = ApiDependencies.invoker.services.images.get_path(image_name)
with open(path, "rb") as f:
@@ -335,8 +461,12 @@ async def get_image_full(
async def get_image_thumbnail(
image_name: str = Path(description="The name of thumbnail image file to get"),
) -> Response:
- """Gets a thumbnail image file"""
+ """Gets a thumbnail image file.
+ This endpoint is intentionally unauthenticated because browsers load images
+ via
tags which cannot send Bearer tokens. Image names are UUIDs,
+ providing security through unguessability.
+ """
try:
path = ApiDependencies.invoker.services.images.get_path(image_name, thumbnail=True)
with open(path, "rb") as f:
@@ -354,9 +484,11 @@ async def get_image_thumbnail(
response_model=ImageUrlsDTO,
)
async def get_image_urls(
+ current_user: CurrentUserOrDefault,
image_name: str = Path(description="The name of the image whose URL to get"),
) -> ImageUrlsDTO:
"""Gets an image and thumbnail URL"""
+ _assert_image_read_access(image_name, current_user)
try:
image_url = ApiDependencies.invoker.services.images.get_url(image_name)
@@ -392,6 +524,11 @@ async def list_image_dtos(
) -> OffsetPaginatedResults[ImageDTO]:
"""Gets a list of image DTOs for the current user"""
+ # Validate that the caller can read from this board before listing its images.
+ # "none" is a sentinel for uncategorized images and is handled by the SQL layer.
+ if board_id is not None and board_id != "none":
+ _assert_board_read_access(board_id, current_user)
+
image_dtos = ApiDependencies.invoker.services.images.get_many(
offset,
limit,
@@ -410,6 +547,7 @@ async def list_image_dtos(
@images_router.post("/delete", operation_id="delete_images_from_list", response_model=DeleteImagesResult)
async def delete_images_from_list(
+ current_user: CurrentUserOrDefault,
image_names: list[str] = Body(description="The list of names of images to delete", embed=True),
) -> DeleteImagesResult:
try:
@@ -417,24 +555,31 @@ async def delete_images_from_list(
affected_boards: set[str] = set()
for image_name in image_names:
try:
+ _assert_image_owner(image_name, current_user)
image_dto = ApiDependencies.invoker.services.images.get_dto(image_name)
board_id = image_dto.board_id or "none"
ApiDependencies.invoker.services.images.delete(image_name)
deleted_images.add(image_name)
affected_boards.add(board_id)
+ except HTTPException:
+ raise
except Exception:
pass
return DeleteImagesResult(
deleted_images=list(deleted_images),
affected_boards=list(affected_boards),
)
+ except HTTPException:
+ raise
except Exception:
raise HTTPException(status_code=500, detail="Failed to delete images")
@images_router.delete("/uncategorized", operation_id="delete_uncategorized_images", response_model=DeleteImagesResult)
-async def delete_uncategorized_images() -> DeleteImagesResult:
- """Deletes all images that are uncategorized"""
+async def delete_uncategorized_images(
+ current_user: CurrentUserOrDefault,
+) -> DeleteImagesResult:
+ """Deletes all uncategorized images owned by the current user (or all if admin)"""
image_names = ApiDependencies.invoker.services.board_images.get_all_board_image_names_for_board(
board_id="none", categories=None, is_intermediate=None
@@ -445,9 +590,13 @@ async def delete_uncategorized_images() -> DeleteImagesResult:
affected_boards: set[str] = set()
for image_name in image_names:
try:
+ _assert_image_owner(image_name, current_user)
ApiDependencies.invoker.services.images.delete(image_name)
deleted_images.add(image_name)
affected_boards.add("none")
+ except HTTPException:
+ # Skip images not owned by the current user
+ pass
except Exception:
pass
return DeleteImagesResult(
@@ -464,6 +613,7 @@ class ImagesUpdatedFromListResult(BaseModel):
@images_router.post("/star", operation_id="star_images_in_list", response_model=StarredImagesResult)
async def star_images_in_list(
+ current_user: CurrentUserOrDefault,
image_names: list[str] = Body(description="The list of names of images to star", embed=True),
) -> StarredImagesResult:
try:
@@ -471,23 +621,29 @@ async def star_images_in_list(
affected_boards: set[str] = set()
for image_name in image_names:
try:
+ _assert_image_owner(image_name, current_user)
updated_image_dto = ApiDependencies.invoker.services.images.update(
image_name, changes=ImageRecordChanges(starred=True)
)
starred_images.add(image_name)
affected_boards.add(updated_image_dto.board_id or "none")
+ except HTTPException:
+ raise
except Exception:
pass
return StarredImagesResult(
starred_images=list(starred_images),
affected_boards=list(affected_boards),
)
+ except HTTPException:
+ raise
except Exception:
raise HTTPException(status_code=500, detail="Failed to star images")
@images_router.post("/unstar", operation_id="unstar_images_in_list", response_model=UnstarredImagesResult)
async def unstar_images_in_list(
+ current_user: CurrentUserOrDefault,
image_names: list[str] = Body(description="The list of names of images to unstar", embed=True),
) -> UnstarredImagesResult:
try:
@@ -495,17 +651,22 @@ async def unstar_images_in_list(
affected_boards: set[str] = set()
for image_name in image_names:
try:
+ _assert_image_owner(image_name, current_user)
updated_image_dto = ApiDependencies.invoker.services.images.update(
image_name, changes=ImageRecordChanges(starred=False)
)
unstarred_images.add(image_name)
affected_boards.add(updated_image_dto.board_id or "none")
+ except HTTPException:
+ raise
except Exception:
pass
return UnstarredImagesResult(
unstarred_images=list(unstarred_images),
affected_boards=list(affected_boards),
)
+ except HTTPException:
+ raise
except Exception:
raise HTTPException(status_code=500, detail="Failed to unstar images")
@@ -523,6 +684,7 @@ class ImagesDownloaded(BaseModel):
"/download", operation_id="download_images_from_list", response_model=ImagesDownloaded, status_code=202
)
async def download_images_from_list(
+ current_user: CurrentUserOrDefault,
background_tasks: BackgroundTasks,
image_names: Optional[list[str]] = Body(
default=None, description="The list of names of images to download", embed=True
@@ -533,6 +695,16 @@ async def download_images_from_list(
) -> ImagesDownloaded:
if (image_names is None or len(image_names) == 0) and board_id is None:
raise HTTPException(status_code=400, detail="No images or board id specified.")
+
+ # Validate that the caller can read every image they are requesting.
+ # For a board_id request, check board visibility; for explicit image names,
+ # check each image individually.
+ if board_id:
+ _assert_board_read_access(board_id, current_user)
+ if image_names:
+ for name in image_names:
+ _assert_image_read_access(name, current_user)
+
bulk_download_item_id: str = ApiDependencies.invoker.services.bulk_download.generate_item_id(board_id)
background_tasks.add_task(
@@ -540,6 +712,7 @@ async def download_images_from_list(
image_names,
board_id,
bulk_download_item_id,
+ current_user.user_id,
)
return ImagesDownloaded(bulk_download_item_name=bulk_download_item_id + ".zip")
@@ -558,11 +731,21 @@ async def download_images_from_list(
},
)
async def get_bulk_download_item(
+ current_user: CurrentUserOrDefault,
background_tasks: BackgroundTasks,
bulk_download_item_name: str = Path(description="The bulk_download_item_name of the bulk download item to get"),
) -> FileResponse:
- """Gets a bulk download zip file"""
+ """Gets a bulk download zip file.
+
+ Requires authentication. The caller must be the user who initiated the
+ download (tracked by the bulk download service) or an admin.
+ """
try:
+ # Verify the caller owns this download (or is an admin)
+ owner = ApiDependencies.invoker.services.bulk_download.get_owner(bulk_download_item_name)
+ if owner is not None and owner != current_user.user_id and not current_user.is_admin:
+ raise HTTPException(status_code=403, detail="Not authorized to access this download")
+
path = ApiDependencies.invoker.services.bulk_download.get_path(bulk_download_item_name)
response = FileResponse(
@@ -574,6 +757,8 @@ async def get_bulk_download_item(
response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}"
background_tasks.add_task(ApiDependencies.invoker.services.bulk_download.delete, bulk_download_item_name)
return response
+ except HTTPException:
+ raise
except Exception:
raise HTTPException(status_code=404)
@@ -594,6 +779,10 @@ async def get_image_names(
) -> ImageNamesResult:
"""Gets ordered list of image names with metadata for optimistic updates"""
+ # Validate that the caller can read from this board before listing its images.
+ if board_id is not None and board_id != "none":
+ _assert_board_read_access(board_id, current_user)
+
try:
result = ApiDependencies.invoker.services.images.get_image_names(
starred_first=starred_first,
@@ -617,6 +806,7 @@ async def get_image_names(
responses={200: {"model": list[ImageDTO]}},
)
async def get_images_by_names(
+ current_user: CurrentUserOrDefault,
image_names: list[str] = Body(embed=True, description="Object containing list of image names to fetch DTOs for"),
) -> list[ImageDTO]:
"""Gets image DTOs for the specified image names. Maintains order of input names."""
@@ -628,8 +818,12 @@ async def get_images_by_names(
image_dtos: list[ImageDTO] = []
for name in image_names:
try:
+ _assert_image_read_access(name, current_user)
dto = image_service.get_dto(name)
image_dtos.append(dto)
+ except HTTPException:
+ # Skip images the user is not authorized to view
+ continue
except Exception:
# Skip missing images - they may have been deleted between name fetch and DTO fetch
continue
diff --git a/invokeai/app/api/routers/model_manager.py b/invokeai/app/api/routers/model_manager.py
index 65b059ecfce..822d9655fe7 100644
--- a/invokeai/app/api/routers/model_manager.py
+++ b/invokeai/app/api/routers/model_manager.py
@@ -858,7 +858,7 @@ def generate_html(title: str, heading: str, repo_id: str, is_error: bool, messag
"/install",
operation_id="list_model_installs",
)
-async def list_model_installs() -> List[ModelInstallJob]:
+async def list_model_installs(current_admin: AdminUserOrDefault) -> List[ModelInstallJob]:
"""Return the list of model install jobs.
Install jobs have a numeric `id`, a `status`, and other fields that provide information on
@@ -890,7 +890,9 @@ async def list_model_installs() -> List[ModelInstallJob]:
404: {"description": "No such job"},
},
)
-async def get_model_install_job(id: int = Path(description="Model install id")) -> ModelInstallJob:
+async def get_model_install_job(
+ current_admin: AdminUserOrDefault, id: int = Path(description="Model install id")
+) -> ModelInstallJob:
"""
Return model install job corresponding to the given source. See the documentation for 'List Model Install Jobs'
for information on the format of the return value.
@@ -933,7 +935,9 @@ async def cancel_model_install_job(
},
status_code=201,
)
-async def pause_model_install_job(id: int = Path(description="Model install job ID")) -> ModelInstallJob:
+async def pause_model_install_job(
+ current_admin: AdminUserOrDefault, id: int = Path(description="Model install job ID")
+) -> ModelInstallJob:
"""Pause the model install job corresponding to the given job ID."""
installer = ApiDependencies.invoker.services.model_manager.install
try:
@@ -953,7 +957,9 @@ async def pause_model_install_job(id: int = Path(description="Model install job
},
status_code=201,
)
-async def resume_model_install_job(id: int = Path(description="Model install job ID")) -> ModelInstallJob:
+async def resume_model_install_job(
+ current_admin: AdminUserOrDefault, id: int = Path(description="Model install job ID")
+) -> ModelInstallJob:
"""Resume a paused model install job corresponding to the given job ID."""
installer = ApiDependencies.invoker.services.model_manager.install
try:
@@ -973,7 +979,9 @@ async def resume_model_install_job(id: int = Path(description="Model install job
},
status_code=201,
)
-async def restart_failed_model_install_job(id: int = Path(description="Model install job ID")) -> ModelInstallJob:
+async def restart_failed_model_install_job(
+ current_admin: AdminUserOrDefault, id: int = Path(description="Model install job ID")
+) -> ModelInstallJob:
"""Restart failed or non-resumable file downloads for the given job."""
installer = ApiDependencies.invoker.services.model_manager.install
try:
@@ -994,6 +1002,7 @@ async def restart_failed_model_install_job(id: int = Path(description="Model ins
status_code=201,
)
async def restart_model_install_file(
+ current_admin: AdminUserOrDefault,
id: int = Path(description="Model install job ID"),
file_source: AnyHttpUrl = Body(description="File download URL to restart"),
) -> ModelInstallJob:
@@ -1305,7 +1314,7 @@ class DeleteOrphanedModelsResponse(BaseModel):
operation_id="get_orphaned_models",
response_model=list[OrphanedModelInfo],
)
-async def get_orphaned_models() -> list[OrphanedModelInfo]:
+async def get_orphaned_models(_: AdminUserOrDefault) -> list[OrphanedModelInfo]:
"""Find orphaned model directories.
Orphaned models are directories in the models folder that contain model files
@@ -1332,7 +1341,9 @@ async def get_orphaned_models() -> list[OrphanedModelInfo]:
operation_id="delete_orphaned_models",
response_model=DeleteOrphanedModelsResponse,
)
-async def delete_orphaned_models(request: DeleteOrphanedModelsRequest) -> DeleteOrphanedModelsResponse:
+async def delete_orphaned_models(
+ request: DeleteOrphanedModelsRequest, _: AdminUserOrDefault
+) -> DeleteOrphanedModelsResponse:
"""Delete specified orphaned model directories.
Args:
diff --git a/invokeai/app/api/routers/recall_parameters.py b/invokeai/app/api/routers/recall_parameters.py
index 0af3fd29b0c..ec08adba2e8 100644
--- a/invokeai/app/api/routers/recall_parameters.py
+++ b/invokeai/app/api/routers/recall_parameters.py
@@ -7,6 +7,7 @@
from fastapi.routing import APIRouter
from pydantic import BaseModel, ConfigDict, Field
+from invokeai.app.api.auth_dependencies import CurrentUserOrDefault
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.backend.image_util.controlnet_processor import process_controlnet_image
from invokeai.backend.model_manager.taxonomy import ModelType
@@ -291,12 +292,58 @@ def resolve_ip_adapter_models(ip_adapters: list[IPAdapterRecallParameter]) -> li
return resolved_adapters
+def _assert_recall_image_access(parameters: "RecallParameter", current_user: CurrentUserOrDefault) -> None:
+ """Validate that the caller can read every image referenced in the recall parameters.
+
+ Control layers and IP adapters may reference image_name fields. Without this
+ check an attacker who knows another user's image UUID could use the recall
+ endpoint to extract image dimensions and — for ControlNet preprocessors — mint
+ a derived processed image they can then fetch.
+ """
+ from invokeai.app.services.board_records.board_records_common import BoardVisibility
+
+ image_names: list[str] = []
+ if parameters.control_layers:
+ for layer in parameters.control_layers:
+ if layer.image_name is not None:
+ image_names.append(layer.image_name)
+ if parameters.ip_adapters:
+ for adapter in parameters.ip_adapters:
+ if adapter.image_name is not None:
+ image_names.append(adapter.image_name)
+
+ if not image_names:
+ return
+
+ # Admin can access all images
+ if current_user.is_admin:
+ return
+
+ for image_name in image_names:
+ owner = ApiDependencies.invoker.services.image_records.get_user_id(image_name)
+ if owner is not None and owner == current_user.user_id:
+ continue
+
+ # Check board visibility
+ board_id = ApiDependencies.invoker.services.board_image_records.get_board_for_image(image_name)
+ if board_id is not None:
+ try:
+ board = ApiDependencies.invoker.services.boards.get_dto(board_id=board_id)
+ if board.board_visibility in (BoardVisibility.Shared, BoardVisibility.Public):
+ continue
+ except Exception:
+ pass
+
+ raise HTTPException(status_code=403, detail=f"Not authorized to access image {image_name}")
+
+
@recall_parameters_router.post(
"/{queue_id}",
operation_id="update_recall_parameters",
response_model=dict[str, Any],
)
async def update_recall_parameters(
+ current_user: CurrentUserOrDefault,
queue_id: str = Path(..., description="The queue id to perform this operation on"),
parameters: RecallParameter = Body(..., description="Recall parameters to update"),
) -> dict[str, Any]:
@@ -328,6 +375,10 @@ async def update_recall_parameters(
"""
logger = ApiDependencies.invoker.services.logger
+ # Validate image access before processing — prevents information leakage
+ # (dimensions) and derived-image minting via ControlNet preprocessors.
+ _assert_recall_image_access(parameters, current_user)
+
try:
# Get only the parameters that were actually provided (non-None values)
provided_params = {k: v for k, v in parameters.model_dump().items() if v is not None}
@@ -335,14 +386,14 @@ async def update_recall_parameters(
if not provided_params:
return {"status": "no_parameters_provided", "updated_count": 0}
- # Store each parameter in client state using a consistent key format
+ # Store each parameter in client state scoped to the current user
updated_count = 0
for param_key, param_value in provided_params.items():
# Convert parameter values to JSON strings for storage
value_str = json.dumps(param_value)
try:
ApiDependencies.invoker.services.client_state_persistence.set_by_key(
- queue_id, f"recall_{param_key}", value_str
+ current_user.user_id, f"recall_{param_key}", value_str
)
updated_count += 1
except Exception as e:
@@ -396,7 +447,9 @@ async def update_recall_parameters(
logger.info(
f"Emitting recall_parameters_updated event for queue {queue_id} with {len(provided_params)} parameters"
)
- ApiDependencies.invoker.services.events.emit_recall_parameters_updated(queue_id, provided_params)
+ ApiDependencies.invoker.services.events.emit_recall_parameters_updated(
+ queue_id, current_user.user_id, provided_params
+ )
logger.info("Successfully emitted recall_parameters_updated event")
except Exception as e:
logger.error(f"Error emitting recall parameters event: {e}", exc_info=True)
@@ -425,6 +478,7 @@ async def update_recall_parameters(
response_model=dict[str, Any],
)
async def get_recall_parameters(
+ current_user: CurrentUserOrDefault,
queue_id: str = Path(..., description="The queue id to retrieve parameters for"),
) -> dict[str, Any]:
"""
diff --git a/invokeai/app/api/routers/session_queue.py b/invokeai/app/api/routers/session_queue.py
index 403e7727cb4..41a5a411c7a 100644
--- a/invokeai/app/api/routers/session_queue.py
+++ b/invokeai/app/api/routers/session_queue.py
@@ -44,7 +44,8 @@ def sanitize_queue_item_for_user(
"""Sanitize queue item for non-admin users viewing other users' items.
For non-admin users viewing queue items belonging to other users,
- the field_values, session graph, and workflow should be hidden/cleared to protect privacy.
+ only timestamps, status, and error information are exposed. All other
+ fields (user identity, generation parameters, graphs, workflows) are stripped.
Args:
queue_item: The queue item to sanitize
@@ -58,15 +59,25 @@ def sanitize_queue_item_for_user(
if is_admin or queue_item.user_id == current_user_id:
return queue_item
- # For non-admins viewing other users' items, clear sensitive fields
- # Create a shallow copy to avoid mutating the original
+ # For non-admins viewing other users' items, strip everything except
+ # item_id, queue_id, status, and timestamps
sanitized_item = queue_item.model_copy(deep=False)
+ sanitized_item.user_id = "redacted"
+ sanitized_item.user_display_name = None
+ sanitized_item.user_email = None
+ sanitized_item.batch_id = "redacted"
+ sanitized_item.session_id = "redacted"
+ sanitized_item.origin = None
+ sanitized_item.destination = None
+ sanitized_item.priority = 0
sanitized_item.field_values = None
+ sanitized_item.retried_from_item_id = None
sanitized_item.workflow = None
- # Clear the session graph by replacing it with an empty graph execution state
- # This prevents information leakage through the generation graph
+ sanitized_item.error_type = None
+ sanitized_item.error_message = None
+ sanitized_item.error_traceback = None
sanitized_item.session = GraphExecutionState(
- id=queue_item.session.id,
+ id="redacted",
graph=Graph(),
)
return sanitized_item
@@ -126,12 +137,16 @@ async def list_all_queue_items(
},
)
async def get_queue_item_ids(
+ current_user: CurrentUserOrDefault,
queue_id: str = Path(description="The queue id to perform this operation on"),
order_dir: SQLiteDirection = Query(default=SQLiteDirection.Descending, description="The order of sort"),
) -> ItemIdsResult:
- """Gets all queue item ids that match the given parameters"""
+ """Gets all queue item ids that match the given parameters. Non-admin users only see their own items."""
try:
- return ApiDependencies.invoker.services.session_queue.get_queue_item_ids(queue_id=queue_id, order_dir=order_dir)
+ user_id = None if current_user.is_admin else current_user.user_id
+ return ApiDependencies.invoker.services.session_queue.get_queue_item_ids(
+ queue_id=queue_id, order_dir=order_dir, user_id=user_id
+ )
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while listing all queue item ids: {e}")
@@ -376,11 +391,15 @@ async def prune(
},
)
async def get_current_queue_item(
+ current_user: CurrentUserOrDefault,
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> Optional[SessionQueueItem]:
"""Gets the currently execution queue item"""
try:
- return ApiDependencies.invoker.services.session_queue.get_current(queue_id)
+ item = ApiDependencies.invoker.services.session_queue.get_current(queue_id)
+ if item is not None:
+ item = sanitize_queue_item_for_user(item, current_user.user_id, current_user.is_admin)
+ return item
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while getting current queue item: {e}")
@@ -393,11 +412,15 @@ async def get_current_queue_item(
},
)
async def get_next_queue_item(
+ current_user: CurrentUserOrDefault,
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> Optional[SessionQueueItem]:
"""Gets the next queue item, without executing it"""
try:
- return ApiDependencies.invoker.services.session_queue.get_next(queue_id)
+ item = ApiDependencies.invoker.services.session_queue.get_next(queue_id)
+ if item is not None:
+ item = sanitize_queue_item_for_user(item, current_user.user_id, current_user.is_admin)
+ return item
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while getting next queue item: {e}")
@@ -413,9 +436,10 @@ async def get_queue_status(
current_user: CurrentUserOrDefault,
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> SessionQueueAndProcessorStatus:
- """Gets the status of the session queue"""
+ """Gets the status of the session queue. Non-admin users see only their own counts and cannot see current item details unless they own it."""
try:
- queue = ApiDependencies.invoker.services.session_queue.get_queue_status(queue_id, user_id=current_user.user_id)
+ user_id = None if current_user.is_admin else current_user.user_id
+ queue = ApiDependencies.invoker.services.session_queue.get_queue_status(queue_id, user_id=user_id)
processor = ApiDependencies.invoker.services.session_processor.get_status()
return SessionQueueAndProcessorStatus(queue=queue, processor=processor)
except Exception as e:
@@ -430,12 +454,16 @@ async def get_queue_status(
},
)
async def get_batch_status(
+ current_user: CurrentUserOrDefault,
queue_id: str = Path(description="The queue id to perform this operation on"),
batch_id: str = Path(description="The batch to get the status of"),
) -> BatchStatus:
- """Gets the status of the session queue"""
+ """Gets the status of a batch. Non-admin users only see their own batches."""
try:
- return ApiDependencies.invoker.services.session_queue.get_batch_status(queue_id=queue_id, batch_id=batch_id)
+ user_id = None if current_user.is_admin else current_user.user_id
+ return ApiDependencies.invoker.services.session_queue.get_batch_status(
+ queue_id=queue_id, batch_id=batch_id, user_id=user_id
+ )
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while getting batch status: {e}")
@@ -529,13 +557,15 @@ async def cancel_queue_item(
responses={200: {"model": SessionQueueCountsByDestination}},
)
async def counts_by_destination(
+ current_user: CurrentUserOrDefault,
queue_id: str = Path(description="The queue id to query"),
destination: str = Query(description="The destination to query"),
) -> SessionQueueCountsByDestination:
- """Gets the counts of queue items by destination"""
+ """Gets the counts of queue items by destination. Non-admin users only see their own items."""
try:
+ user_id = None if current_user.is_admin else current_user.user_id
return ApiDependencies.invoker.services.session_queue.get_counts_by_destination(
- queue_id=queue_id, destination=destination
+ queue_id=queue_id, destination=destination, user_id=user_id
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while fetching counts by destination: {e}")
diff --git a/invokeai/app/api/routers/workflows.py b/invokeai/app/api/routers/workflows.py
index 72d50a416b4..1c88a77a3f6 100644
--- a/invokeai/app/api/routers/workflows.py
+++ b/invokeai/app/api/routers/workflows.py
@@ -6,6 +6,7 @@
from fastapi.responses import FileResponse
from PIL import Image
+from invokeai.app.api.auth_dependencies import CurrentUserOrDefault
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
@@ -33,16 +34,25 @@
},
)
async def get_workflow(
+ current_user: CurrentUserOrDefault,
workflow_id: str = Path(description="The workflow to get"),
) -> WorkflowRecordWithThumbnailDTO:
"""Gets a workflow"""
try:
- thumbnail_url = ApiDependencies.invoker.services.workflow_thumbnails.get_url(workflow_id)
workflow = ApiDependencies.invoker.services.workflow_records.get(workflow_id)
- return WorkflowRecordWithThumbnailDTO(thumbnail_url=thumbnail_url, **workflow.model_dump())
except WorkflowNotFoundError:
raise HTTPException(status_code=404, detail="Workflow not found")
+ config = ApiDependencies.invoker.services.configuration
+ if config.multiuser:
+ is_default = workflow.workflow.meta.category is WorkflowCategory.Default
+ is_owner = workflow.user_id == current_user.user_id
+ if not (is_default or is_owner or workflow.is_public or current_user.is_admin):
+ raise HTTPException(status_code=403, detail="Not authorized to access this workflow")
+
+ thumbnail_url = ApiDependencies.invoker.services.workflow_thumbnails.get_url(workflow_id)
+ return WorkflowRecordWithThumbnailDTO(thumbnail_url=thumbnail_url, **workflow.model_dump())
+
@workflows_router.patch(
"/i/{workflow_id}",
@@ -52,10 +62,21 @@ async def get_workflow(
},
)
async def update_workflow(
+ current_user: CurrentUserOrDefault,
workflow: Workflow = Body(description="The updated workflow", embed=True),
) -> WorkflowRecordDTO:
"""Updates a workflow"""
- return ApiDependencies.invoker.services.workflow_records.update(workflow=workflow)
+ config = ApiDependencies.invoker.services.configuration
+ if config.multiuser:
+ try:
+ existing = ApiDependencies.invoker.services.workflow_records.get(workflow.id)
+ except WorkflowNotFoundError:
+ raise HTTPException(status_code=404, detail="Workflow not found")
+ if not current_user.is_admin and existing.user_id != current_user.user_id:
+ raise HTTPException(status_code=403, detail="Not authorized to update this workflow")
+ # Pass user_id for defense-in-depth SQL scoping; admins pass None to allow any.
+ user_id = None if current_user.is_admin else current_user.user_id
+ return ApiDependencies.invoker.services.workflow_records.update(workflow=workflow, user_id=user_id)
@workflows_router.delete(
@@ -63,15 +84,25 @@ async def update_workflow(
operation_id="delete_workflow",
)
async def delete_workflow(
+ current_user: CurrentUserOrDefault,
workflow_id: str = Path(description="The workflow to delete"),
) -> None:
"""Deletes a workflow"""
+ config = ApiDependencies.invoker.services.configuration
+ if config.multiuser:
+ try:
+ existing = ApiDependencies.invoker.services.workflow_records.get(workflow_id)
+ except WorkflowNotFoundError:
+ raise HTTPException(status_code=404, detail="Workflow not found")
+ if not current_user.is_admin and existing.user_id != current_user.user_id:
+ raise HTTPException(status_code=403, detail="Not authorized to delete this workflow")
try:
ApiDependencies.invoker.services.workflow_thumbnails.delete(workflow_id)
except WorkflowThumbnailFileNotFoundException:
# It's OK if the workflow has no thumbnail file. We can still delete the workflow.
pass
- ApiDependencies.invoker.services.workflow_records.delete(workflow_id)
+ user_id = None if current_user.is_admin else current_user.user_id
+ ApiDependencies.invoker.services.workflow_records.delete(workflow_id, user_id=user_id)
@workflows_router.post(
@@ -82,10 +113,11 @@ async def delete_workflow(
},
)
async def create_workflow(
+ current_user: CurrentUserOrDefault,
workflow: WorkflowWithoutID = Body(description="The workflow to create", embed=True),
) -> WorkflowRecordDTO:
"""Creates a workflow"""
- return ApiDependencies.invoker.services.workflow_records.create(workflow=workflow)
+ return ApiDependencies.invoker.services.workflow_records.create(workflow=workflow, user_id=current_user.user_id)
@workflows_router.get(
@@ -96,6 +128,7 @@ async def create_workflow(
},
)
async def list_workflows(
+ current_user: CurrentUserOrDefault,
page: int = Query(default=0, description="The page to get"),
per_page: Optional[int] = Query(default=None, description="The number of workflows per page"),
order_by: WorkflowRecordOrderBy = Query(
@@ -106,8 +139,19 @@ async def list_workflows(
tags: Optional[list[str]] = Query(default=None, description="The tags of workflow to get"),
query: Optional[str] = Query(default=None, description="The text to query by (matches name and description)"),
has_been_opened: Optional[bool] = Query(default=None, description="Whether to include/exclude recent workflows"),
+ is_public: Optional[bool] = Query(default=None, description="Filter by public/shared status"),
) -> PaginatedResults[WorkflowRecordListItemWithThumbnailDTO]:
"""Gets a page of workflows"""
+ config = ApiDependencies.invoker.services.configuration
+
+ # In multiuser mode, scope user-category workflows to the current user unless fetching shared workflows
+ user_id_filter: Optional[str] = None
+ if config.multiuser:
+ # Only filter 'user' category results by user_id when not explicitly listing public workflows
+ has_user_category = not categories or WorkflowCategory.User in categories
+ if has_user_category and is_public is not True:
+ user_id_filter = current_user.user_id
+
workflows_with_thumbnails: list[WorkflowRecordListItemWithThumbnailDTO] = []
workflows = ApiDependencies.invoker.services.workflow_records.get_many(
order_by=order_by,
@@ -118,6 +162,8 @@ async def list_workflows(
categories=categories,
tags=tags,
has_been_opened=has_been_opened,
+ user_id=user_id_filter,
+ is_public=is_public,
)
for workflow in workflows.items:
workflows_with_thumbnails.append(
@@ -143,15 +189,20 @@ async def list_workflows(
},
)
async def set_workflow_thumbnail(
+ current_user: CurrentUserOrDefault,
workflow_id: str = Path(description="The workflow to update"),
image: UploadFile = File(description="The image file to upload"),
):
"""Sets a workflow's thumbnail image"""
try:
- ApiDependencies.invoker.services.workflow_records.get(workflow_id)
+ existing = ApiDependencies.invoker.services.workflow_records.get(workflow_id)
except WorkflowNotFoundError:
raise HTTPException(status_code=404, detail="Workflow not found")
+ config = ApiDependencies.invoker.services.configuration
+ if config.multiuser and not current_user.is_admin and existing.user_id != current_user.user_id:
+ raise HTTPException(status_code=403, detail="Not authorized to update this workflow")
+
if not image.content_type or not image.content_type.startswith("image"):
raise HTTPException(status_code=415, detail="Not an image")
@@ -177,14 +228,19 @@ async def set_workflow_thumbnail(
},
)
async def delete_workflow_thumbnail(
+ current_user: CurrentUserOrDefault,
workflow_id: str = Path(description="The workflow to update"),
):
"""Removes a workflow's thumbnail image"""
try:
- ApiDependencies.invoker.services.workflow_records.get(workflow_id)
+ existing = ApiDependencies.invoker.services.workflow_records.get(workflow_id)
except WorkflowNotFoundError:
raise HTTPException(status_code=404, detail="Workflow not found")
+ config = ApiDependencies.invoker.services.configuration
+ if config.multiuser and not current_user.is_admin and existing.user_id != current_user.user_id:
+ raise HTTPException(status_code=403, detail="Not authorized to update this workflow")
+
try:
ApiDependencies.invoker.services.workflow_thumbnails.delete(workflow_id)
except ValueError as e:
@@ -206,8 +262,12 @@ async def delete_workflow_thumbnail(
async def get_workflow_thumbnail(
workflow_id: str = Path(description="The id of the workflow thumbnail to get"),
) -> FileResponse:
- """Gets a workflow's thumbnail image"""
+ """Gets a workflow's thumbnail image.
+ This endpoint is intentionally unauthenticated because browsers load images
+ via
tags which cannot send Bearer tokens. Workflow IDs are UUIDs,
+ providing security through unguessability.
+ """
try:
path = ApiDependencies.invoker.services.workflow_thumbnails.get_path(workflow_id)
@@ -223,37 +283,91 @@ async def get_workflow_thumbnail(
raise HTTPException(status_code=404)
+@workflows_router.patch(
+ "/i/{workflow_id}/is_public",
+ operation_id="update_workflow_is_public",
+ responses={
+ 200: {"model": WorkflowRecordDTO},
+ },
+)
+async def update_workflow_is_public(
+ current_user: CurrentUserOrDefault,
+ workflow_id: str = Path(description="The workflow to update"),
+ is_public: bool = Body(description="Whether the workflow should be shared publicly", embed=True),
+) -> WorkflowRecordDTO:
+ """Updates whether a workflow is shared publicly"""
+ try:
+ existing = ApiDependencies.invoker.services.workflow_records.get(workflow_id)
+ except WorkflowNotFoundError:
+ raise HTTPException(status_code=404, detail="Workflow not found")
+
+ config = ApiDependencies.invoker.services.configuration
+ if config.multiuser and not current_user.is_admin and existing.user_id != current_user.user_id:
+ raise HTTPException(status_code=403, detail="Not authorized to update this workflow")
+
+ user_id = None if current_user.is_admin else current_user.user_id
+ return ApiDependencies.invoker.services.workflow_records.update_is_public(
+ workflow_id=workflow_id, is_public=is_public, user_id=user_id
+ )
+
+
@workflows_router.get("/tags", operation_id="get_all_tags")
async def get_all_tags(
+ current_user: CurrentUserOrDefault,
categories: Optional[list[WorkflowCategory]] = Query(default=None, description="The categories to include"),
+ is_public: Optional[bool] = Query(default=None, description="Filter by public/shared status"),
) -> list[str]:
"""Gets all unique tags from workflows"""
-
- return ApiDependencies.invoker.services.workflow_records.get_all_tags(categories=categories)
+ config = ApiDependencies.invoker.services.configuration
+ user_id_filter: Optional[str] = None
+ if config.multiuser:
+ has_user_category = not categories or WorkflowCategory.User in categories
+ if has_user_category and is_public is not True:
+ user_id_filter = current_user.user_id
+
+ return ApiDependencies.invoker.services.workflow_records.get_all_tags(
+ categories=categories, user_id=user_id_filter, is_public=is_public
+ )
@workflows_router.get("/counts_by_tag", operation_id="get_counts_by_tag")
async def get_counts_by_tag(
+ current_user: CurrentUserOrDefault,
tags: list[str] = Query(description="The tags to get counts for"),
categories: Optional[list[WorkflowCategory]] = Query(default=None, description="The categories to include"),
has_been_opened: Optional[bool] = Query(default=None, description="Whether to include/exclude recent workflows"),
+ is_public: Optional[bool] = Query(default=None, description="Filter by public/shared status"),
) -> dict[str, int]:
"""Counts workflows by tag"""
+ config = ApiDependencies.invoker.services.configuration
+ user_id_filter: Optional[str] = None
+ if config.multiuser:
+ has_user_category = not categories or WorkflowCategory.User in categories
+ if has_user_category and is_public is not True:
+ user_id_filter = current_user.user_id
return ApiDependencies.invoker.services.workflow_records.counts_by_tag(
- tags=tags, categories=categories, has_been_opened=has_been_opened
+ tags=tags, categories=categories, has_been_opened=has_been_opened, user_id=user_id_filter, is_public=is_public
)
@workflows_router.get("/counts_by_category", operation_id="counts_by_category")
async def counts_by_category(
+ current_user: CurrentUserOrDefault,
categories: list[WorkflowCategory] = Query(description="The categories to include"),
has_been_opened: Optional[bool] = Query(default=None, description="Whether to include/exclude recent workflows"),
+ is_public: Optional[bool] = Query(default=None, description="Filter by public/shared status"),
) -> dict[str, int]:
"""Counts workflows by category"""
+ config = ApiDependencies.invoker.services.configuration
+ user_id_filter: Optional[str] = None
+ if config.multiuser:
+ has_user_category = WorkflowCategory.User in categories
+ if has_user_category and is_public is not True:
+ user_id_filter = current_user.user_id
return ApiDependencies.invoker.services.workflow_records.counts_by_category(
- categories=categories, has_been_opened=has_been_opened
+ categories=categories, has_been_opened=has_been_opened, user_id=user_id_filter, is_public=is_public
)
@@ -262,7 +376,18 @@ async def counts_by_category(
operation_id="update_opened_at",
)
async def update_opened_at(
+ current_user: CurrentUserOrDefault,
workflow_id: str = Path(description="The workflow to update"),
) -> None:
"""Updates the opened_at field of a workflow"""
- ApiDependencies.invoker.services.workflow_records.update_opened_at(workflow_id)
+ try:
+ existing = ApiDependencies.invoker.services.workflow_records.get(workflow_id)
+ except WorkflowNotFoundError:
+ raise HTTPException(status_code=404, detail="Workflow not found")
+
+ config = ApiDependencies.invoker.services.configuration
+ if config.multiuser and not current_user.is_admin and existing.user_id != current_user.user_id:
+ raise HTTPException(status_code=403, detail="Not authorized to update this workflow")
+
+ user_id = None if current_user.is_admin else current_user.user_id
+ ApiDependencies.invoker.services.workflow_records.update_opened_at(workflow_id, user_id=user_id)
diff --git a/invokeai/app/api/sockets.py b/invokeai/app/api/sockets.py
index fcead54eb1e..5783b804c0b 100644
--- a/invokeai/app/api/sockets.py
+++ b/invokeai/app/api/sockets.py
@@ -121,6 +121,11 @@ async def _handle_connect(self, sid: str, environ: dict, auth: dict | None) -> b
Returns True to accept the connection, False to reject it.
Stores user_id in the internal socket users dict for later use.
+
+ In multiuser mode, connections without a valid token are rejected outright
+ so that anonymous clients cannot subscribe to queue rooms and observe
+ queue activity belonging to other users. In single-user mode, unauthenticated
+ connections are accepted as the system admin user.
"""
# Extract token from auth data or headers
token = None
@@ -137,6 +142,23 @@ async def _handle_connect(self, sid: str, environ: dict, auth: dict | None) -> b
if token:
token_data = verify_token(token)
if token_data:
+ # In multiuser mode, also verify the backing user record still
+ # exists and is active — mirrors the REST auth check in
+ # auth_dependencies.py. A deleted or deactivated user whose
+ # JWT has not yet expired must not be allowed to open a socket.
+ if self._is_multiuser_enabled():
+ try:
+ from invokeai.app.api.dependencies import ApiDependencies
+
+ user = ApiDependencies.invoker.services.users.get(token_data.user_id)
+ if user is None or not user.is_active:
+ logger.warning(f"Rejecting socket {sid}: user {token_data.user_id} not found or inactive")
+ return False
+ except Exception:
+ # If user service is unavailable, fail closed
+ logger.warning(f"Rejecting socket {sid}: unable to verify user record")
+ return False
+
# Store user_id and is_admin in socket users dict
self._socket_users[sid] = {
"user_id": token_data.user_id,
@@ -147,14 +169,37 @@ async def _handle_connect(self, sid: str, environ: dict, auth: dict | None) -> b
)
return True
- # If no valid token, store system user for backward compatibility
+ # No valid token provided. In multiuser mode this is not allowed — reject
+ # the connection so anonymous clients cannot subscribe to queue rooms.
+ # In single-user mode, fall through and accept the socket as system admin.
+ if self._is_multiuser_enabled():
+ logger.warning(
+ f"Rejecting socket {sid} connection: multiuser mode is enabled and no valid auth token was provided"
+ )
+ return False
+
self._socket_users[sid] = {
"user_id": "system",
- "is_admin": False,
+ "is_admin": True,
}
- logger.debug(f"Socket {sid} connected as system user (no valid token)")
+ logger.debug(f"Socket {sid} connected as system admin (single-user mode)")
return True
+ @staticmethod
+ def _is_multiuser_enabled() -> bool:
+ """Check whether multiuser mode is enabled. Fails closed if configuration
+ is not yet initialized, which should not happen in practice but prevents
+ accidentally opening the socket during startup races."""
+ try:
+ # Imported here to avoid a circular import at module load time.
+ from invokeai.app.api.dependencies import ApiDependencies
+
+ return bool(ApiDependencies.invoker.services.configuration.multiuser)
+ except Exception:
+ # If dependencies are not initialized, fail closed (treat as multiuser)
+ # so we never accidentally admit an anonymous socket.
+ return True
+
async def _handle_disconnect(self, sid: str) -> None:
"""Handle socket disconnection and cleanup user info."""
if sid in self._socket_users:
@@ -165,15 +210,20 @@ async def _handle_sub_queue(self, sid: str, data: Any) -> None:
"""Handle queue subscription and add socket to both queue and user-specific rooms."""
queue_id = QueueSubscriptionEvent(**data).queue_id
- # Check if we have user info for this socket
+ # Check if we have user info for this socket. In multiuser mode _handle_connect
+ # will have already rejected any socket without a valid token, so missing user
+ # info here is a bug — refuse the subscription rather than silently falling back
+ # to an anonymous system user who could then receive queue item events.
if sid not in self._socket_users:
- logger.warning(
- f"Socket {sid} subscribing to queue {queue_id} but has no user info - need to authenticate via connect event"
- )
- # Store as system user temporarily - real auth should happen in connect
+ if self._is_multiuser_enabled():
+ logger.warning(
+ f"Refusing queue subscription for socket {sid}: no user info (socket not authenticated via connect event)"
+ )
+ return
+ # Single-user mode: safe to fall back to the system admin user.
self._socket_users[sid] = {
"user_id": "system",
- "is_admin": False,
+ "is_admin": True,
}
user_id = self._socket_users[sid]["user_id"]
@@ -198,6 +248,13 @@ async def _handle_unsub_queue(self, sid: str, data: Any) -> None:
await self._sio.leave_room(sid, QueueSubscriptionEvent(**data).queue_id)
async def _handle_sub_bulk_download(self, sid: str, data: Any) -> None:
+ # In multiuser mode, only allow authenticated sockets to subscribe.
+ # Bulk download events are routed to user-specific rooms, so the
+ # bulk_download_id room subscription is only kept for single-user
+ # backward compatibility.
+ if self._is_multiuser_enabled() and sid not in self._socket_users:
+ logger.warning(f"Refusing bulk download subscription for unknown socket {sid} in multiuser mode")
+ return
await self._sio.enter_room(sid, BulkDownloadSubscriptionEvent(**data).bulk_download_id)
async def _handle_unsub_bulk_download(self, sid: str, data: Any) -> None:
@@ -206,9 +263,17 @@ async def _handle_unsub_bulk_download(self, sid: str, data: Any) -> None:
async def _handle_queue_event(self, event: FastAPIEvent[QueueEventBase]):
"""Handle queue events with user isolation.
- Invocation events (progress, started, complete) are private - only emit to owner and admins.
- Queue item status events are public - emit to all users (field values hidden via API).
- Other queue events emit to all subscribers.
+ All queue item events (invocation events AND QueueItemStatusChangedEvent) are
+ private to the owning user and admins. They carry unsanitized user_id, batch_id,
+ session_id, origin, destination and error metadata, and must never be broadcast
+ to the whole queue room — otherwise any other authenticated subscriber could
+ observe cross-user queue activity.
+
+ RecallParametersUpdatedEvent is also private to the owner + admins.
+
+ BatchEnqueuedEvent carries the enqueuing user's batch_id/origin/counts and
+ is also routed privately. QueueClearedEvent is the only queue event that
+ is still broadcast to the whole queue room.
IMPORTANT: Check InvocationEventBase BEFORE QueueItemEventBase since InvocationEventBase
inherits from QueueItemEventBase. The order of isinstance checks matters!
@@ -237,24 +302,40 @@ async def _handle_queue_event(self, event: FastAPIEvent[QueueEventBase]):
logger.debug(f"Emitted private invocation event {event_name} to user room {user_room} and admin room")
- # Queue item status events are visible to all users (field values masked via API)
- # This catches QueueItemStatusChangedEvent but NOT InvocationEvents (already handled above)
+ # Other queue item events (QueueItemStatusChangedEvent) carry unsanitized
+ # user_id, batch_id, session_id, origin, destination and error metadata.
+ # They are private to the owning user + admins — never broadcast to the
+ # full queue room.
elif isinstance(event_data, QueueItemEventBase) and hasattr(event_data, "user_id"):
- # Emit to all subscribers in the queue
- await self._sio.emit(
- event=event_name, data=event_data.model_dump(mode="json"), room=event_data.queue_id
- )
+ user_room = f"user:{event_data.user_id}"
+ await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room=user_room)
+ await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room="admin")
- logger.info(
- f"Emitted public queue item event {event_name} to all subscribers in queue {event_data.queue_id}"
- )
+ logger.debug(f"Emitted private queue item event {event_name} to user room {user_room} and admin room")
+
+ # RecallParametersUpdatedEvent is private - only emit to owner + admins
+ elif isinstance(event_data, RecallParametersUpdatedEvent):
+ user_room = f"user:{event_data.user_id}"
+ await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room=user_room)
+ await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room="admin")
+ logger.debug(f"Emitted private recall_parameters_updated event to user room {user_room} and admin room")
+
+ # BatchEnqueuedEvent carries the enqueuing user's batch_id, origin, and
+ # enqueued counts. Route it privately to the owner + admins so other
+ # users do not observe cross-user batch activity.
+ elif isinstance(event_data, BatchEnqueuedEvent):
+ user_room = f"user:{event_data.user_id}"
+ await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room=user_room)
+ await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room="admin")
+ logger.debug(f"Emitted private batch_enqueued event to user room {user_room} and admin room")
else:
- # For other queue events (like QueueClearedEvent, BatchEnqueuedEvent), emit to all subscribers
+ # For remaining queue events (e.g. QueueClearedEvent) that do not
+ # carry user identity, emit to all subscribers in the queue room.
await self._sio.emit(
event=event_name, data=event_data.model_dump(mode="json"), room=event_data.queue_id
)
- logger.info(
+ logger.debug(
f"Emitted general queue event {event_name} to all subscribers in queue {event_data.queue_id}"
)
except Exception as e:
@@ -265,4 +346,17 @@ async def _handle_model_event(self, event: FastAPIEvent[ModelEventBase | Downloa
await self._sio.emit(event=event[0], data=event[1].model_dump(mode="json"))
async def _handle_bulk_image_download_event(self, event: FastAPIEvent[BulkDownloadEventBase]) -> None:
- await self._sio.emit(event=event[0], data=event[1].model_dump(mode="json"), room=event[1].bulk_download_id)
+ event_name, event_data = event
+ # Route to user-specific + admin rooms so that other authenticated
+ # users cannot learn the bulk_download_item_name (the capability token
+ # needed to fetch the zip from the unauthenticated GET endpoint).
+ # In single-user mode (user_id="system"), fall back to the shared
+ # bulk_download_id room for backward compatibility.
+ if hasattr(event_data, "user_id") and event_data.user_id != "system":
+ user_room = f"user:{event_data.user_id}"
+ await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room=user_room)
+ await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room="admin")
+ else:
+ await self._sio.emit(
+ event=event_name, data=event_data.model_dump(mode="json"), room=event_data.bulk_download_id
+ )
diff --git a/invokeai/app/services/board_records/board_records_common.py b/invokeai/app/services/board_records/board_records_common.py
index ab6355a3930..b263f264cb8 100644
--- a/invokeai/app/services/board_records/board_records_common.py
+++ b/invokeai/app/services/board_records/board_records_common.py
@@ -9,6 +9,17 @@
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
+class BoardVisibility(str, Enum, metaclass=MetaEnum):
+ """The visibility options for a board."""
+
+ Private = "private"
+ """Only the board owner (and admins) can see and modify this board."""
+ Shared = "shared"
+ """All users can view this board, but only the owner (and admins) can modify it."""
+ Public = "public"
+ """All users can view this board; only the owner (and admins) can modify its structure."""
+
+
class BoardRecord(BaseModelExcludeNull):
"""Deserialized board record."""
@@ -28,6 +39,10 @@ class BoardRecord(BaseModelExcludeNull):
"""The name of the cover image of the board."""
archived: bool = Field(description="Whether or not the board is archived.")
"""Whether or not the board is archived."""
+ board_visibility: BoardVisibility = Field(
+ default=BoardVisibility.Private, description="The visibility of the board."
+ )
+ """The visibility of the board (private, shared, or public)."""
def deserialize_board_record(board_dict: dict) -> BoardRecord:
@@ -44,6 +59,11 @@ def deserialize_board_record(board_dict: dict) -> BoardRecord:
updated_at = board_dict.get("updated_at", get_iso_timestamp())
deleted_at = board_dict.get("deleted_at", get_iso_timestamp())
archived = board_dict.get("archived", False)
+ board_visibility_raw = board_dict.get("board_visibility", BoardVisibility.Private.value)
+ try:
+ board_visibility = BoardVisibility(board_visibility_raw)
+ except ValueError:
+ board_visibility = BoardVisibility.Private
return BoardRecord(
board_id=board_id,
@@ -54,6 +74,7 @@ def deserialize_board_record(board_dict: dict) -> BoardRecord:
updated_at=updated_at,
deleted_at=deleted_at,
archived=archived,
+ board_visibility=board_visibility,
)
@@ -61,6 +82,7 @@ class BoardChanges(BaseModel, extra="forbid"):
board_name: Optional[str] = Field(default=None, description="The board's new name.", max_length=300)
cover_image_name: Optional[str] = Field(default=None, description="The name of the board's new cover image.")
archived: Optional[bool] = Field(default=None, description="Whether or not the board is archived")
+ board_visibility: Optional[BoardVisibility] = Field(default=None, description="The visibility of the board.")
class BoardRecordOrderBy(str, Enum, metaclass=MetaEnum):
diff --git a/invokeai/app/services/board_records/board_records_sqlite.py b/invokeai/app/services/board_records/board_records_sqlite.py
index a54f65686fd..1e3e11c8a36 100644
--- a/invokeai/app/services/board_records/board_records_sqlite.py
+++ b/invokeai/app/services/board_records/board_records_sqlite.py
@@ -116,6 +116,17 @@ def update(
(changes.archived, board_id),
)
+ # Change the visibility of a board
+ if changes.board_visibility is not None:
+ cursor.execute(
+ """--sql
+ UPDATE boards
+ SET board_visibility = ?
+ WHERE board_id = ?;
+ """,
+ (changes.board_visibility.value, board_id),
+ )
+
except sqlite3.Error as e:
raise BoardRecordSaveException from e
return self.get(board_id)
@@ -155,7 +166,7 @@ def get_many(
SELECT DISTINCT boards.*
FROM boards
LEFT JOIN shared_boards ON boards.board_id = shared_boards.board_id
- WHERE (boards.user_id = ? OR shared_boards.user_id = ? OR boards.is_public = 1)
+ WHERE (boards.user_id = ? OR shared_boards.user_id = ? OR boards.board_visibility IN ('shared', 'public'))
{archived_filter}
ORDER BY {order_by} {direction}
LIMIT ? OFFSET ?;
@@ -194,14 +205,14 @@ def get_many(
SELECT COUNT(DISTINCT boards.board_id)
FROM boards
LEFT JOIN shared_boards ON boards.board_id = shared_boards.board_id
- WHERE (boards.user_id = ? OR shared_boards.user_id = ? OR boards.is_public = 1);
+ WHERE (boards.user_id = ? OR shared_boards.user_id = ? OR boards.board_visibility IN ('shared', 'public'));
"""
else:
count_query = """
SELECT COUNT(DISTINCT boards.board_id)
FROM boards
LEFT JOIN shared_boards ON boards.board_id = shared_boards.board_id
- WHERE (boards.user_id = ? OR shared_boards.user_id = ? OR boards.is_public = 1)
+ WHERE (boards.user_id = ? OR shared_boards.user_id = ? OR boards.board_visibility IN ('shared', 'public'))
AND boards.archived = 0;
"""
@@ -251,7 +262,7 @@ def get_all(
SELECT DISTINCT boards.*
FROM boards
LEFT JOIN shared_boards ON boards.board_id = shared_boards.board_id
- WHERE (boards.user_id = ? OR shared_boards.user_id = ? OR boards.is_public = 1)
+ WHERE (boards.user_id = ? OR shared_boards.user_id = ? OR boards.board_visibility IN ('shared', 'public'))
{archived_filter}
ORDER BY LOWER(boards.board_name) {direction}
"""
@@ -260,7 +271,7 @@ def get_all(
SELECT DISTINCT boards.*
FROM boards
LEFT JOIN shared_boards ON boards.board_id = shared_boards.board_id
- WHERE (boards.user_id = ? OR shared_boards.user_id = ? OR boards.is_public = 1)
+ WHERE (boards.user_id = ? OR shared_boards.user_id = ? OR boards.board_visibility IN ('shared', 'public'))
{archived_filter}
ORDER BY {order_by} {direction}
"""
diff --git a/invokeai/app/services/bulk_download/bulk_download_base.py b/invokeai/app/services/bulk_download/bulk_download_base.py
index 617b611f566..6cd4ed0cbaf 100644
--- a/invokeai/app/services/bulk_download/bulk_download_base.py
+++ b/invokeai/app/services/bulk_download/bulk_download_base.py
@@ -7,7 +7,11 @@ class BulkDownloadBase(ABC):
@abstractmethod
def handler(
- self, image_names: Optional[list[str]], board_id: Optional[str], bulk_download_item_id: Optional[str]
+ self,
+ image_names: Optional[list[str]],
+ board_id: Optional[str],
+ bulk_download_item_id: Optional[str],
+ user_id: str = "system",
) -> None:
"""
Create a zip file containing the images specified by the given image names or board id.
@@ -15,6 +19,7 @@ def handler(
:param image_names: A list of image names to include in the zip file.
:param board_id: The ID of the board. If provided, all images associated with the board will be included in the zip file.
:param bulk_download_item_id: The bulk_download_item_id that will be used to retrieve the bulk download item when it is prepared, if none is provided a uuid will be generated.
+ :param user_id: The ID of the user who initiated the download.
"""
@abstractmethod
@@ -42,3 +47,12 @@ def delete(self, bulk_download_item_name: str) -> None:
:param bulk_download_item_name: The name of the bulk download item.
"""
+
+ @abstractmethod
+ def get_owner(self, bulk_download_item_name: str) -> Optional[str]:
+ """
+ Get the user_id of the user who initiated the download.
+
+ :param bulk_download_item_name: The name of the bulk download item.
+ :return: The user_id of the owner, or None if not tracked.
+ """
diff --git a/invokeai/app/services/bulk_download/bulk_download_default.py b/invokeai/app/services/bulk_download/bulk_download_default.py
index dc4f8b1d81b..c037e9c5c15 100644
--- a/invokeai/app/services/bulk_download/bulk_download_default.py
+++ b/invokeai/app/services/bulk_download/bulk_download_default.py
@@ -25,15 +25,24 @@ def __init__(self):
self._temp_directory = TemporaryDirectory()
self._bulk_downloads_folder = Path(self._temp_directory.name) / "bulk_downloads"
self._bulk_downloads_folder.mkdir(parents=True, exist_ok=True)
+ # Track which user owns each download so the fetch endpoint can enforce ownership
+ self._download_owners: dict[str, str] = {}
def handler(
- self, image_names: Optional[list[str]], board_id: Optional[str], bulk_download_item_id: Optional[str]
+ self,
+ image_names: Optional[list[str]],
+ board_id: Optional[str],
+ bulk_download_item_id: Optional[str],
+ user_id: str = "system",
) -> None:
bulk_download_id: str = DEFAULT_BULK_DOWNLOAD_ID
bulk_download_item_id = bulk_download_item_id or uuid_string()
bulk_download_item_name = bulk_download_item_id + ".zip"
- self._signal_job_started(bulk_download_id, bulk_download_item_id, bulk_download_item_name)
+ # Record ownership so the fetch endpoint can verify the caller
+ self._download_owners[bulk_download_item_name] = user_id
+
+ self._signal_job_started(bulk_download_id, bulk_download_item_id, bulk_download_item_name, user_id)
try:
image_dtos: list[ImageDTO] = []
@@ -46,16 +55,16 @@ def handler(
raise BulkDownloadParametersException()
bulk_download_item_name: str = self._create_zip_file(image_dtos, bulk_download_item_id)
- self._signal_job_completed(bulk_download_id, bulk_download_item_id, bulk_download_item_name)
+ self._signal_job_completed(bulk_download_id, bulk_download_item_id, bulk_download_item_name, user_id)
except (
ImageRecordNotFoundException,
BoardRecordNotFoundException,
BulkDownloadException,
BulkDownloadParametersException,
) as e:
- self._signal_job_failed(bulk_download_id, bulk_download_item_id, bulk_download_item_name, e)
+ self._signal_job_failed(bulk_download_id, bulk_download_item_id, bulk_download_item_name, e, user_id)
except Exception as e:
- self._signal_job_failed(bulk_download_id, bulk_download_item_id, bulk_download_item_name, e)
+ self._signal_job_failed(bulk_download_id, bulk_download_item_id, bulk_download_item_name, e, user_id)
self._invoker.services.logger.error("Problem bulk downloading images.")
raise e
@@ -103,43 +112,60 @@ def _clean_string_to_path_safe(self, s: str) -> str:
return "".join([c for c in s if c.isalpha() or c.isdigit() or c == " " or c == "_" or c == "-"]).rstrip()
def _signal_job_started(
- self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
+ self,
+ bulk_download_id: str,
+ bulk_download_item_id: str,
+ bulk_download_item_name: str,
+ user_id: str = "system",
) -> None:
"""Signal that a bulk download job has started."""
if self._invoker:
assert bulk_download_id is not None
self._invoker.services.events.emit_bulk_download_started(
- bulk_download_id, bulk_download_item_id, bulk_download_item_name
+ bulk_download_id, bulk_download_item_id, bulk_download_item_name, user_id=user_id
)
def _signal_job_completed(
- self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
+ self,
+ bulk_download_id: str,
+ bulk_download_item_id: str,
+ bulk_download_item_name: str,
+ user_id: str = "system",
) -> None:
"""Signal that a bulk download job has completed."""
if self._invoker:
assert bulk_download_id is not None
assert bulk_download_item_name is not None
self._invoker.services.events.emit_bulk_download_complete(
- bulk_download_id, bulk_download_item_id, bulk_download_item_name
+ bulk_download_id, bulk_download_item_id, bulk_download_item_name, user_id=user_id
)
def _signal_job_failed(
- self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str, exception: Exception
+ self,
+ bulk_download_id: str,
+ bulk_download_item_id: str,
+ bulk_download_item_name: str,
+ exception: Exception,
+ user_id: str = "system",
) -> None:
"""Signal that a bulk download job has failed."""
if self._invoker:
assert bulk_download_id is not None
assert exception is not None
self._invoker.services.events.emit_bulk_download_error(
- bulk_download_id, bulk_download_item_id, bulk_download_item_name, str(exception)
+ bulk_download_id, bulk_download_item_id, bulk_download_item_name, str(exception), user_id=user_id
)
def stop(self, *args, **kwargs):
self._temp_directory.cleanup()
+ def get_owner(self, bulk_download_item_name: str) -> Optional[str]:
+ return self._download_owners.get(bulk_download_item_name)
+
def delete(self, bulk_download_item_name: str) -> None:
path = self.get_path(bulk_download_item_name)
Path(path).unlink()
+ self._download_owners.pop(bulk_download_item_name, None)
def get_path(self, bulk_download_item_name: str) -> str:
path = str(self._bulk_downloads_folder / bulk_download_item_name)
diff --git a/invokeai/app/services/events/events_base.py b/invokeai/app/services/events/events_base.py
index aa1cbb5e0ee..935b422a732 100644
--- a/invokeai/app/services/events/events_base.py
+++ b/invokeai/app/services/events/events_base.py
@@ -100,9 +100,9 @@ def emit_queue_item_status_changed(
"""Emitted when a queue item's status changes"""
self.dispatch(QueueItemStatusChangedEvent.build(queue_item, batch_status, queue_status))
- def emit_batch_enqueued(self, enqueue_result: "EnqueueBatchResult") -> None:
+ def emit_batch_enqueued(self, enqueue_result: "EnqueueBatchResult", user_id: str = "system") -> None:
"""Emitted when a batch is enqueued"""
- self.dispatch(BatchEnqueuedEvent.build(enqueue_result))
+ self.dispatch(BatchEnqueuedEvent.build(enqueue_result, user_id))
def emit_queue_items_retried(self, retry_result: "RetryItemsResult") -> None:
"""Emitted when a list of queue items are retried"""
@@ -112,9 +112,9 @@ def emit_queue_cleared(self, queue_id: str) -> None:
"""Emitted when a queue is cleared"""
self.dispatch(QueueClearedEvent.build(queue_id))
- def emit_recall_parameters_updated(self, queue_id: str, parameters: dict) -> None:
+ def emit_recall_parameters_updated(self, queue_id: str, user_id: str, parameters: dict) -> None:
"""Emitted when recall parameters are updated"""
- self.dispatch(RecallParametersUpdatedEvent.build(queue_id, parameters))
+ self.dispatch(RecallParametersUpdatedEvent.build(queue_id, user_id, parameters))
# endregion
@@ -194,23 +194,42 @@ def emit_model_install_error(self, job: "ModelInstallJob") -> None:
# region Bulk image download
def emit_bulk_download_started(
- self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
+ self,
+ bulk_download_id: str,
+ bulk_download_item_id: str,
+ bulk_download_item_name: str,
+ user_id: str = "system",
) -> None:
"""Emitted when a bulk image download is started"""
- self.dispatch(BulkDownloadStartedEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name))
+ self.dispatch(
+ BulkDownloadStartedEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name, user_id)
+ )
def emit_bulk_download_complete(
- self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
+ self,
+ bulk_download_id: str,
+ bulk_download_item_id: str,
+ bulk_download_item_name: str,
+ user_id: str = "system",
) -> None:
"""Emitted when a bulk image download is complete"""
- self.dispatch(BulkDownloadCompleteEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name))
+ self.dispatch(
+ BulkDownloadCompleteEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name, user_id)
+ )
def emit_bulk_download_error(
- self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str, error: str
+ self,
+ bulk_download_id: str,
+ bulk_download_item_id: str,
+ bulk_download_item_name: str,
+ error: str,
+ user_id: str = "system",
) -> None:
"""Emitted when a bulk image download has an error"""
self.dispatch(
- BulkDownloadErrorEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name, error)
+ BulkDownloadErrorEvent.build(
+ bulk_download_id, bulk_download_item_id, bulk_download_item_name, error, user_id
+ )
)
# endregion
diff --git a/invokeai/app/services/events/events_common.py b/invokeai/app/services/events/events_common.py
index bfb44eb48e8..998fe4f5309 100644
--- a/invokeai/app/services/events/events_common.py
+++ b/invokeai/app/services/events/events_common.py
@@ -281,9 +281,10 @@ class BatchEnqueuedEvent(QueueEventBase):
)
priority: int = Field(description="The priority of the batch")
origin: str | None = Field(default=None, description="The origin of the batch")
+ user_id: str = Field(default="system", description="The ID of the user who enqueued the batch")
@classmethod
- def build(cls, enqueue_result: EnqueueBatchResult) -> "BatchEnqueuedEvent":
+ def build(cls, enqueue_result: EnqueueBatchResult, user_id: str = "system") -> "BatchEnqueuedEvent":
return cls(
queue_id=enqueue_result.queue_id,
batch_id=enqueue_result.batch.batch_id,
@@ -291,6 +292,7 @@ def build(cls, enqueue_result: EnqueueBatchResult) -> "BatchEnqueuedEvent":
enqueued=enqueue_result.enqueued,
requested=enqueue_result.requested,
priority=enqueue_result.priority,
+ user_id=user_id,
)
@@ -609,6 +611,7 @@ class BulkDownloadEventBase(EventBase):
bulk_download_id: str = Field(description="The ID of the bulk image download")
bulk_download_item_id: str = Field(description="The ID of the bulk image download item")
bulk_download_item_name: str = Field(description="The name of the bulk image download item")
+ user_id: str = Field(default="system", description="The ID of the user who initiated the download")
@payload_schema.register
@@ -619,12 +622,17 @@ class BulkDownloadStartedEvent(BulkDownloadEventBase):
@classmethod
def build(
- cls, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
+ cls,
+ bulk_download_id: str,
+ bulk_download_item_id: str,
+ bulk_download_item_name: str,
+ user_id: str = "system",
) -> "BulkDownloadStartedEvent":
return cls(
bulk_download_id=bulk_download_id,
bulk_download_item_id=bulk_download_item_id,
bulk_download_item_name=bulk_download_item_name,
+ user_id=user_id,
)
@@ -636,12 +644,17 @@ class BulkDownloadCompleteEvent(BulkDownloadEventBase):
@classmethod
def build(
- cls, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
+ cls,
+ bulk_download_id: str,
+ bulk_download_item_id: str,
+ bulk_download_item_name: str,
+ user_id: str = "system",
) -> "BulkDownloadCompleteEvent":
return cls(
bulk_download_id=bulk_download_id,
bulk_download_item_id=bulk_download_item_id,
bulk_download_item_name=bulk_download_item_name,
+ user_id=user_id,
)
@@ -655,13 +668,19 @@ class BulkDownloadErrorEvent(BulkDownloadEventBase):
@classmethod
def build(
- cls, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str, error: str
+ cls,
+ bulk_download_id: str,
+ bulk_download_item_id: str,
+ bulk_download_item_name: str,
+ error: str,
+ user_id: str = "system",
) -> "BulkDownloadErrorEvent":
return cls(
bulk_download_id=bulk_download_id,
bulk_download_item_id=bulk_download_item_id,
bulk_download_item_name=bulk_download_item_name,
error=error,
+ user_id=user_id,
)
@@ -671,8 +690,9 @@ class RecallParametersUpdatedEvent(QueueEventBase):
__event_name__ = "recall_parameters_updated"
+ user_id: str = Field(description="The ID of the user whose recall parameters were updated")
parameters: dict[str, Any] = Field(description="The recall parameters that were updated")
@classmethod
- def build(cls, queue_id: str, parameters: dict[str, Any]) -> "RecallParametersUpdatedEvent":
- return cls(queue_id=queue_id, parameters=parameters)
+ def build(cls, queue_id: str, user_id: str, parameters: dict[str, Any]) -> "RecallParametersUpdatedEvent":
+ return cls(queue_id=queue_id, user_id=user_id, parameters=parameters)
diff --git a/invokeai/app/services/events/events_fastapievents.py b/invokeai/app/services/events/events_fastapievents.py
index f44eecc5559..90e1402773d 100644
--- a/invokeai/app/services/events/events_fastapievents.py
+++ b/invokeai/app/services/events/events_fastapievents.py
@@ -46,3 +46,9 @@ async def _dispatch_from_queue(self, stop_event: threading.Event):
except asyncio.CancelledError as e:
raise e # Raise a proper error
+ except Exception:
+ import logging
+
+ logging.getLogger("InvokeAI").error(
+ f"Error dispatching event {getattr(event, '__event_name__', event)}", exc_info=True
+ )
diff --git a/invokeai/app/services/image_records/image_records_base.py b/invokeai/app/services/image_records/image_records_base.py
index 16405c52708..457cf2f4686 100644
--- a/invokeai/app/services/image_records/image_records_base.py
+++ b/invokeai/app/services/image_records/image_records_base.py
@@ -74,8 +74,8 @@ def delete_intermediates(self) -> list[str]:
pass
@abstractmethod
- def get_intermediates_count(self) -> int:
- """Gets a count of all intermediate images."""
+ def get_intermediates_count(self, user_id: Optional[str] = None) -> int:
+ """Gets a count of intermediate images. If user_id is provided, only counts that user's intermediates."""
pass
@abstractmethod
@@ -97,6 +97,11 @@ def save(
"""Saves an image record."""
pass
+ @abstractmethod
+ def get_user_id(self, image_name: str) -> Optional[str]:
+ """Gets the user_id of the image owner. Returns None if image not found."""
+ pass
+
@abstractmethod
def get_most_recent_image_for_board(self, board_id: str) -> Optional[ImageRecord]:
"""Gets the most recent image for a board."""
diff --git a/invokeai/app/services/image_records/image_records_sqlite.py b/invokeai/app/services/image_records/image_records_sqlite.py
index c6c237fc1e7..07126d53a9f 100644
--- a/invokeai/app/services/image_records/image_records_sqlite.py
+++ b/invokeai/app/services/image_records/image_records_sqlite.py
@@ -46,6 +46,20 @@ def get(self, image_name: str) -> ImageRecord:
return deserialize_image_record(dict(result))
+ def get_user_id(self, image_name: str) -> Optional[str]:
+ with self._db.transaction() as cursor:
+ cursor.execute(
+ """--sql
+ SELECT user_id FROM images
+ WHERE image_name = ?;
+ """,
+ (image_name,),
+ )
+ result = cast(Optional[sqlite3.Row], cursor.fetchone())
+ if not result:
+ return None
+ return cast(Optional[str], dict(result).get("user_id"))
+
def get_metadata(self, image_name: str) -> Optional[MetadataField]:
with self._db.transaction() as cursor:
try:
@@ -269,14 +283,14 @@ def delete_many(self, image_names: list[str]) -> None:
except sqlite3.Error as e:
raise ImageRecordDeleteException from e
- def get_intermediates_count(self) -> int:
+ def get_intermediates_count(self, user_id: Optional[str] = None) -> int:
with self._db.transaction() as cursor:
- cursor.execute(
- """--sql
- SELECT COUNT(*) FROM images
- WHERE is_intermediate = TRUE;
- """
- )
+ query = "SELECT COUNT(*) FROM images WHERE is_intermediate = TRUE"
+ params: list[str] = []
+ if user_id is not None:
+ query += " AND user_id = ?"
+ params.append(user_id)
+ cursor.execute(query, params)
count = cast(int, cursor.fetchone()[0])
return count
diff --git a/invokeai/app/services/images/images_base.py b/invokeai/app/services/images/images_base.py
index d11d75b3c1d..aebbead2f35 100644
--- a/invokeai/app/services/images/images_base.py
+++ b/invokeai/app/services/images/images_base.py
@@ -143,8 +143,8 @@ def delete_intermediates(self) -> int:
pass
@abstractmethod
- def get_intermediates_count(self) -> int:
- """Gets the number of intermediate images."""
+ def get_intermediates_count(self, user_id: Optional[str] = None) -> int:
+ """Gets the number of intermediate images. If user_id is provided, only counts that user's intermediates."""
pass
@abstractmethod
diff --git a/invokeai/app/services/images/images_default.py b/invokeai/app/services/images/images_default.py
index e82bd7f4de1..0f03f7c400e 100644
--- a/invokeai/app/services/images/images_default.py
+++ b/invokeai/app/services/images/images_default.py
@@ -310,9 +310,9 @@ def delete_intermediates(self) -> int:
self.__invoker.services.logger.error("Problem deleting image records and files")
raise e
- def get_intermediates_count(self) -> int:
+ def get_intermediates_count(self, user_id: Optional[str] = None) -> int:
try:
- return self.__invoker.services.image_records.get_intermediates_count()
+ return self.__invoker.services.image_records.get_intermediates_count(user_id=user_id)
except Exception as e:
self.__invoker.services.logger.error("Problem getting intermediates count")
raise e
diff --git a/invokeai/app/services/session_queue/session_queue_base.py b/invokeai/app/services/session_queue/session_queue_base.py
index 3c037dc77ab..14b93d97fc7 100644
--- a/invokeai/app/services/session_queue/session_queue_base.py
+++ b/invokeai/app/services/session_queue/session_queue_base.py
@@ -78,13 +78,15 @@ def get_queue_status(self, queue_id: str, user_id: Optional[str] = None) -> Sess
pass
@abstractmethod
- def get_counts_by_destination(self, queue_id: str, destination: str) -> SessionQueueCountsByDestination:
- """Gets the counts of queue items by destination"""
+ def get_counts_by_destination(
+ self, queue_id: str, destination: str, user_id: Optional[str] = None
+ ) -> SessionQueueCountsByDestination:
+ """Gets the counts of queue items by destination. If user_id is provided, only counts that user's items."""
pass
@abstractmethod
- def get_batch_status(self, queue_id: str, batch_id: str) -> BatchStatus:
- """Gets the status of a batch"""
+ def get_batch_status(self, queue_id: str, batch_id: str, user_id: Optional[str] = None) -> BatchStatus:
+ """Gets the status of a batch. If user_id is provided, only counts that user's items."""
pass
@abstractmethod
@@ -172,8 +174,9 @@ def get_queue_item_ids(
self,
queue_id: str,
order_dir: SQLiteDirection = SQLiteDirection.Descending,
+ user_id: Optional[str] = None,
) -> ItemIdsResult:
- """Gets all queue item ids that match the given parameters"""
+ """Gets all queue item ids that match the given parameters. If user_id is provided, only returns items for that user."""
pass
@abstractmethod
diff --git a/invokeai/app/services/session_queue/session_queue_common.py b/invokeai/app/services/session_queue/session_queue_common.py
index 58544422119..09820fe6217 100644
--- a/invokeai/app/services/session_queue/session_queue_common.py
+++ b/invokeai/app/services/session_queue/session_queue_common.py
@@ -304,12 +304,6 @@ class SessionQueueStatus(BaseModel):
failed: int = Field(..., description="Number of queue items with status 'error'")
canceled: int = Field(..., description="Number of queue items with status 'canceled'")
total: int = Field(..., description="Total number of queue items")
- user_pending: Optional[int] = Field(
- default=None, description="Number of queue items with status 'pending' for the current user"
- )
- user_in_progress: Optional[int] = Field(
- default=None, description="Number of queue items with status 'in_progress' for the current user"
- )
class SessionQueueCountsByDestination(BaseModel):
diff --git a/invokeai/app/services/session_queue/session_queue_sqlite.py b/invokeai/app/services/session_queue/session_queue_sqlite.py
index 4f46136fd79..070a7cef293 100644
--- a/invokeai/app/services/session_queue/session_queue_sqlite.py
+++ b/invokeai/app/services/session_queue/session_queue_sqlite.py
@@ -151,7 +151,7 @@ async def enqueue_batch(
priority=priority,
item_ids=item_ids,
)
- self.__invoker.services.events.emit_batch_enqueued(enqueue_result)
+ self.__invoker.services.events.emit_batch_enqueued(enqueue_result, user_id=user_id)
return enqueue_result
def dequeue(self) -> Optional[SessionQueueItem]:
@@ -765,15 +765,21 @@ def get_queue_item_ids(
self,
queue_id: str,
order_dir: SQLiteDirection = SQLiteDirection.Descending,
+ user_id: Optional[str] = None,
) -> ItemIdsResult:
with self._db.transaction() as cursor_:
- query = f"""--sql
+ query = """--sql
SELECT item_id
FROM session_queue
WHERE queue_id = ?
- ORDER BY created_at {order_dir.value}
"""
- query_params = [queue_id]
+ query_params: list[str] = [queue_id]
+
+ if user_id is not None:
+ query += " AND user_id = ?"
+ query_params.append(user_id)
+
+ query += f" ORDER BY created_at {order_dir.value}"
cursor_.execute(query, query_params)
result = cast(list[sqlite3.Row], cursor_.fetchall())
@@ -783,20 +789,7 @@ def get_queue_item_ids(
def get_queue_status(self, queue_id: str, user_id: Optional[str] = None) -> SessionQueueStatus:
with self._db.transaction() as cursor:
- # Get total counts
- cursor.execute(
- """--sql
- SELECT status, count(*)
- FROM session_queue
- WHERE queue_id = ?
- GROUP BY status
- """,
- (queue_id,),
- )
- counts_result = cast(list[sqlite3.Row], cursor.fetchall())
-
- # Get user-specific counts if user_id is provided (using a single query with CASE)
- user_counts_result = []
+ # When user_id is provided (non-admin), only count that user's items
if user_id is not None:
cursor.execute(
"""--sql
@@ -807,48 +800,51 @@ def get_queue_status(self, queue_id: str, user_id: Optional[str] = None) -> Sess
""",
(queue_id, user_id),
)
- user_counts_result = cast(list[sqlite3.Row], cursor.fetchall())
+ else:
+ cursor.execute(
+ """--sql
+ SELECT status, count(*)
+ FROM session_queue
+ WHERE queue_id = ?
+ GROUP BY status
+ """,
+ (queue_id,),
+ )
+ counts_result = cast(list[sqlite3.Row], cursor.fetchall())
current_item = self.get_current(queue_id=queue_id)
total = sum(row[1] or 0 for row in counts_result)
counts: dict[str, int] = {row[0]: row[1] for row in counts_result}
- # Process user-specific counts if available
- user_pending = None
- user_in_progress = None
- if user_id is not None:
- user_counts: dict[str, int] = {row[0]: row[1] for row in user_counts_result}
- user_pending = user_counts.get("pending", 0)
- user_in_progress = user_counts.get("in_progress", 0)
+ # For non-admin users, hide current item details if they don't own it
+ show_current_item = current_item is not None and (user_id is None or current_item.user_id == user_id)
return SessionQueueStatus(
queue_id=queue_id,
- item_id=current_item.item_id if current_item else None,
- session_id=current_item.session_id if current_item else None,
- batch_id=current_item.batch_id if current_item else None,
+ item_id=current_item.item_id if show_current_item else None,
+ session_id=current_item.session_id if show_current_item else None,
+ batch_id=current_item.batch_id if show_current_item else None,
pending=counts.get("pending", 0),
in_progress=counts.get("in_progress", 0),
completed=counts.get("completed", 0),
failed=counts.get("failed", 0),
canceled=counts.get("canceled", 0),
total=total,
- user_pending=user_pending,
- user_in_progress=user_in_progress,
)
- def get_batch_status(self, queue_id: str, batch_id: str) -> BatchStatus:
+ def get_batch_status(self, queue_id: str, batch_id: str, user_id: Optional[str] = None) -> BatchStatus:
with self._db.transaction() as cursor:
- cursor.execute(
- """--sql
+ query = """--sql
SELECT status, count(*), origin, destination
FROM session_queue
- WHERE
- queue_id = ?
- AND batch_id = ?
- GROUP BY status
- """,
- (queue_id, batch_id),
- )
+ WHERE queue_id = ? AND batch_id = ?
+ """
+ params: list[str] = [queue_id, batch_id]
+ if user_id is not None:
+ query += " AND user_id = ?"
+ params.append(user_id)
+ query += " GROUP BY status"
+ cursor.execute(query, params)
result = cast(list[sqlite3.Row], cursor.fetchall())
total = sum(row[1] or 0 for row in result)
counts: dict[str, int] = {row[0]: row[1] for row in result}
@@ -868,18 +864,21 @@ def get_batch_status(self, queue_id: str, batch_id: str) -> BatchStatus:
total=total,
)
- def get_counts_by_destination(self, queue_id: str, destination: str) -> SessionQueueCountsByDestination:
+ def get_counts_by_destination(
+ self, queue_id: str, destination: str, user_id: Optional[str] = None
+ ) -> SessionQueueCountsByDestination:
with self._db.transaction() as cursor:
- cursor.execute(
- """--sql
+ query = """--sql
SELECT status, count(*)
FROM session_queue
- WHERE queue_id = ?
- AND destination = ?
- GROUP BY status
- """,
- (queue_id, destination),
- )
+ WHERE queue_id = ? AND destination = ?
+ """
+ params: list[str] = [queue_id, destination]
+ if user_id is not None:
+ query += " AND user_id = ?"
+ params.append(user_id)
+ query += " GROUP BY status"
+ cursor.execute(query, params)
counts_result = cast(list[sqlite3.Row], cursor.fetchall())
total = sum(row[1] or 0 for row in counts_result)
diff --git a/invokeai/app/services/shared/sqlite/sqlite_util.py b/invokeai/app/services/shared/sqlite/sqlite_util.py
index 645509f1dde..fb8ca9fca38 100644
--- a/invokeai/app/services/shared/sqlite/sqlite_util.py
+++ b/invokeai/app/services/shared/sqlite/sqlite_util.py
@@ -30,6 +30,8 @@
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_25 import build_migration_25
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_26 import build_migration_26
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_27 import build_migration_27
+from invokeai.app.services.shared.sqlite_migrator.migrations.migration_28 import build_migration_28
+from invokeai.app.services.shared.sqlite_migrator.migrations.migration_29 import build_migration_29
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
@@ -77,6 +79,8 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
migrator.register_migration(build_migration_25(app_config=config, logger=logger))
migrator.register_migration(build_migration_26(app_config=config, logger=logger))
migrator.register_migration(build_migration_27())
+ migrator.register_migration(build_migration_28())
+ migrator.register_migration(build_migration_29())
migrator.run_migrations()
return db
diff --git a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_28.py b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_28.py
new file mode 100644
index 00000000000..0cbd683ab5e
--- /dev/null
+++ b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_28.py
@@ -0,0 +1,45 @@
+"""Migration 28: Add per-user workflow isolation columns to workflow_library.
+
+This migration adds the database columns required for multiuser workflow isolation
+to the workflow_library table:
+- user_id: the owner of the workflow (defaults to 'system' for existing workflows)
+- is_public: whether the workflow is shared with all users
+"""
+
+import sqlite3
+
+from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
+
+
+class Migration28Callback:
+ """Migration to add user_id and is_public to the workflow_library table."""
+
+ def __call__(self, cursor: sqlite3.Cursor) -> None:
+ self._update_workflow_library_table(cursor)
+
+ def _update_workflow_library_table(self, cursor: sqlite3.Cursor) -> None:
+ """Add user_id and is_public columns to workflow_library table."""
+ cursor.execute("PRAGMA table_info(workflow_library);")
+ columns = [row[1] for row in cursor.fetchall()]
+
+ if "user_id" not in columns:
+ cursor.execute("ALTER TABLE workflow_library ADD COLUMN user_id TEXT DEFAULT 'system';")
+ cursor.execute("CREATE INDEX IF NOT EXISTS idx_workflow_library_user_id ON workflow_library(user_id);")
+
+ if "is_public" not in columns:
+ cursor.execute("ALTER TABLE workflow_library ADD COLUMN is_public BOOLEAN NOT NULL DEFAULT FALSE;")
+ cursor.execute("CREATE INDEX IF NOT EXISTS idx_workflow_library_is_public ON workflow_library(is_public);")
+
+
+def build_migration_28() -> Migration:
+ """Builds the migration object for migrating from version 27 to version 28.
+
+ This migration adds per-user workflow isolation to the workflow_library table:
+ - user_id column: identifies the owner of each workflow
+ - is_public column: controls whether a workflow is shared with all users
+ """
+ return Migration(
+ from_version=27,
+ to_version=28,
+ callback=Migration28Callback(),
+ )
diff --git a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_29.py b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_29.py
new file mode 100644
index 00000000000..c9eb7c901ba
--- /dev/null
+++ b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_29.py
@@ -0,0 +1,53 @@
+"""Migration 29: Add board_visibility column to boards table.
+
+This migration adds a board_visibility column to the boards table to support
+three visibility levels:
+ - 'private': only the board owner (and admins) can view/modify
+ - 'shared': all users can view, but only the owner (and admins) can modify
+ - 'public': all users can view; only the owner (and admins) can modify the
+ board structure (rename/archive/delete)
+
+Existing boards with is_public = 1 are migrated to 'public'.
+All other existing boards default to 'private'.
+"""
+
+import sqlite3
+
+from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
+
+
+class Migration29Callback:
+ """Migration to add board_visibility column to the boards table."""
+
+ def __call__(self, cursor: sqlite3.Cursor) -> None:
+ self._update_boards_table(cursor)
+
+ def _update_boards_table(self, cursor: sqlite3.Cursor) -> None:
+ """Add board_visibility column to boards table."""
+ # Check if boards table exists
+ cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='boards';")
+ if cursor.fetchone() is None:
+ return
+
+ cursor.execute("PRAGMA table_info(boards);")
+ columns = [row[1] for row in cursor.fetchall()]
+
+ if "board_visibility" not in columns:
+ cursor.execute("ALTER TABLE boards ADD COLUMN board_visibility TEXT NOT NULL DEFAULT 'private';")
+ cursor.execute("CREATE INDEX IF NOT EXISTS idx_boards_board_visibility ON boards(board_visibility);")
+ # Migrate existing is_public = 1 boards to 'public'
+ if "is_public" in columns:
+ cursor.execute("UPDATE boards SET board_visibility = 'public' WHERE is_public = 1;")
+
+
+def build_migration_29() -> Migration:
+ """Builds the migration object for migrating from version 28 to version 29.
+
+ This migration adds the board_visibility column to the boards table,
+ supporting 'private', 'shared', and 'public' visibility levels.
+ """
+ return Migration(
+ from_version=28,
+ to_version=29,
+ callback=Migration29Callback(),
+ )
diff --git a/invokeai/app/services/users/users_base.py b/invokeai/app/services/users/users_base.py
index 728a0adfa37..dd789b561ee 100644
--- a/invokeai/app/services/users/users_base.py
+++ b/invokeai/app/services/users/users_base.py
@@ -131,6 +131,15 @@ def list_users(self, limit: int = 100, offset: int = 0) -> list[UserDTO]:
"""
pass
+ @abstractmethod
+ def get_admin_email(self) -> str | None:
+ """Get the email address of the first active admin user.
+
+ Returns:
+ Email address of the first active admin, or None if no admin exists
+ """
+ pass
+
@abstractmethod
def count_admins(self) -> int:
"""Count active admin users.
diff --git a/invokeai/app/services/users/users_default.py b/invokeai/app/services/users/users_default.py
index 709e4cb82c6..6e472882124 100644
--- a/invokeai/app/services/users/users_default.py
+++ b/invokeai/app/services/users/users_default.py
@@ -256,6 +256,20 @@ def list_users(self, limit: int = 100, offset: int = 0) -> list[UserDTO]:
for row in rows
]
+ def get_admin_email(self) -> str | None:
+ """Get the email address of the first active admin user."""
+ with self._db.transaction() as cursor:
+ cursor.execute(
+ """
+ SELECT email FROM users
+ WHERE is_admin = TRUE AND is_active = TRUE
+ ORDER BY created_at ASC
+ LIMIT 1
+ """,
+ )
+ row = cursor.fetchone()
+ return row[0] if row else None
+
def count_admins(self) -> int:
"""Count active admin users."""
with self._db.transaction() as cursor:
diff --git a/invokeai/app/services/workflow_records/workflow_records_base.py b/invokeai/app/services/workflow_records/workflow_records_base.py
index d5cf319594b..856a6c6d490 100644
--- a/invokeai/app/services/workflow_records/workflow_records_base.py
+++ b/invokeai/app/services/workflow_records/workflow_records_base.py
@@ -4,6 +4,7 @@
from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
from invokeai.app.services.workflow_records.workflow_records_common import (
+ WORKFLOW_LIBRARY_DEFAULT_USER_ID,
Workflow,
WorkflowCategory,
WorkflowRecordDTO,
@@ -22,18 +23,18 @@ def get(self, workflow_id: str) -> WorkflowRecordDTO:
pass
@abstractmethod
- def create(self, workflow: WorkflowWithoutID) -> WorkflowRecordDTO:
+ def create(self, workflow: WorkflowWithoutID, user_id: str = WORKFLOW_LIBRARY_DEFAULT_USER_ID) -> WorkflowRecordDTO:
"""Creates a workflow."""
pass
@abstractmethod
- def update(self, workflow: Workflow) -> WorkflowRecordDTO:
- """Updates a workflow."""
+ def update(self, workflow: Workflow, user_id: Optional[str] = None) -> WorkflowRecordDTO:
+ """Updates a workflow. When user_id is provided, the UPDATE is scoped to that user."""
pass
@abstractmethod
- def delete(self, workflow_id: str) -> None:
- """Deletes a workflow."""
+ def delete(self, workflow_id: str, user_id: Optional[str] = None) -> None:
+ """Deletes a workflow. When user_id is provided, the DELETE is scoped to that user."""
pass
@abstractmethod
@@ -47,6 +48,8 @@ def get_many(
query: Optional[str],
tags: Optional[list[str]],
has_been_opened: Optional[bool],
+ user_id: Optional[str] = None,
+ is_public: Optional[bool] = None,
) -> PaginatedResults[WorkflowRecordListItemDTO]:
"""Gets many workflows."""
pass
@@ -56,6 +59,8 @@ def counts_by_category(
self,
categories: list[WorkflowCategory],
has_been_opened: Optional[bool] = None,
+ user_id: Optional[str] = None,
+ is_public: Optional[bool] = None,
) -> dict[str, int]:
"""Gets a dictionary of counts for each of the provided categories."""
pass
@@ -66,19 +71,28 @@ def counts_by_tag(
tags: list[str],
categories: Optional[list[WorkflowCategory]] = None,
has_been_opened: Optional[bool] = None,
+ user_id: Optional[str] = None,
+ is_public: Optional[bool] = None,
) -> dict[str, int]:
"""Gets a dictionary of counts for each of the provided tags."""
pass
@abstractmethod
- def update_opened_at(self, workflow_id: str) -> None:
- """Open a workflow."""
+ def update_opened_at(self, workflow_id: str, user_id: Optional[str] = None) -> None:
+ """Open a workflow. When user_id is provided, the UPDATE is scoped to that user."""
pass
@abstractmethod
def get_all_tags(
self,
categories: Optional[list[WorkflowCategory]] = None,
+ user_id: Optional[str] = None,
+ is_public: Optional[bool] = None,
) -> list[str]:
"""Gets all unique tags from workflows."""
pass
+
+ @abstractmethod
+ def update_is_public(self, workflow_id: str, is_public: bool, user_id: Optional[str] = None) -> WorkflowRecordDTO:
+ """Updates the is_public field of a workflow. When user_id is provided, the UPDATE is scoped to that user."""
+ pass
diff --git a/invokeai/app/services/workflow_records/workflow_records_common.py b/invokeai/app/services/workflow_records/workflow_records_common.py
index e0cea37468d..9c505530c90 100644
--- a/invokeai/app/services/workflow_records/workflow_records_common.py
+++ b/invokeai/app/services/workflow_records/workflow_records_common.py
@@ -9,6 +9,9 @@
__workflow_meta_version__ = semver.Version.parse("1.0.0")
+WORKFLOW_LIBRARY_DEFAULT_USER_ID = "system"
+"""Default user_id for workflows created in single-user mode or migrated from pre-multiuser databases."""
+
class ExposedField(BaseModel):
nodeId: str
@@ -26,6 +29,7 @@ class WorkflowRecordOrderBy(str, Enum, metaclass=MetaEnum):
UpdatedAt = "updated_at"
OpenedAt = "opened_at"
Name = "name"
+ IsPublic = "is_public"
class WorkflowCategory(str, Enum, metaclass=MetaEnum):
@@ -100,6 +104,8 @@ class WorkflowRecordDTOBase(BaseModel):
opened_at: Optional[Union[datetime.datetime, str]] = Field(
default=None, description="The opened timestamp of the workflow."
)
+ user_id: str = Field(description="The id of the user who owns this workflow.")
+ is_public: bool = Field(description="Whether this workflow is shared with all users.")
class WorkflowRecordDTO(WorkflowRecordDTOBase):
diff --git a/invokeai/app/services/workflow_records/workflow_records_sqlite.py b/invokeai/app/services/workflow_records/workflow_records_sqlite.py
index 0f72f7cd92c..c83d87eff68 100644
--- a/invokeai/app/services/workflow_records/workflow_records_sqlite.py
+++ b/invokeai/app/services/workflow_records/workflow_records_sqlite.py
@@ -7,6 +7,7 @@
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.services.workflow_records.workflow_records_base import WorkflowRecordsStorageBase
from invokeai.app.services.workflow_records.workflow_records_common import (
+ WORKFLOW_LIBRARY_DEFAULT_USER_ID,
Workflow,
WorkflowCategory,
WorkflowNotFoundError,
@@ -36,7 +37,7 @@ def get(self, workflow_id: str) -> WorkflowRecordDTO:
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
- SELECT workflow_id, workflow, name, created_at, updated_at, opened_at
+ SELECT workflow_id, workflow, name, created_at, updated_at, opened_at, user_id, is_public
FROM workflow_library
WHERE workflow_id = ?;
""",
@@ -47,7 +48,7 @@ def get(self, workflow_id: str) -> WorkflowRecordDTO:
raise WorkflowNotFoundError(f"Workflow with id {workflow_id} not found")
return WorkflowRecordDTO.from_dict(dict(row))
- def create(self, workflow: WorkflowWithoutID) -> WorkflowRecordDTO:
+ def create(self, workflow: WorkflowWithoutID, user_id: str = WORKFLOW_LIBRARY_DEFAULT_USER_ID) -> WorkflowRecordDTO:
if workflow.meta.category is WorkflowCategory.Default:
raise ValueError("Default workflows cannot be created via this method")
@@ -57,43 +58,98 @@ def create(self, workflow: WorkflowWithoutID) -> WorkflowRecordDTO:
"""--sql
INSERT OR IGNORE INTO workflow_library (
workflow_id,
- workflow
+ workflow,
+ user_id
)
- VALUES (?, ?);
+ VALUES (?, ?, ?);
""",
- (workflow_with_id.id, workflow_with_id.model_dump_json()),
+ (workflow_with_id.id, workflow_with_id.model_dump_json(), user_id),
)
return self.get(workflow_with_id.id)
- def update(self, workflow: Workflow) -> WorkflowRecordDTO:
+ def update(self, workflow: Workflow, user_id: Optional[str] = None) -> WorkflowRecordDTO:
if workflow.meta.category is WorkflowCategory.Default:
raise ValueError("Default workflows cannot be updated")
with self._db.transaction() as cursor:
- cursor.execute(
- """--sql
- UPDATE workflow_library
- SET workflow = ?
- WHERE workflow_id = ? AND category = 'user';
- """,
- (workflow.model_dump_json(), workflow.id),
- )
+ if user_id is not None:
+ cursor.execute(
+ """--sql
+ UPDATE workflow_library
+ SET workflow = ?
+ WHERE workflow_id = ? AND category = 'user' AND user_id = ?;
+ """,
+ (workflow.model_dump_json(), workflow.id, user_id),
+ )
+ else:
+ cursor.execute(
+ """--sql
+ UPDATE workflow_library
+ SET workflow = ?
+ WHERE workflow_id = ? AND category = 'user';
+ """,
+ (workflow.model_dump_json(), workflow.id),
+ )
return self.get(workflow.id)
- def delete(self, workflow_id: str) -> None:
+ def delete(self, workflow_id: str, user_id: Optional[str] = None) -> None:
if self.get(workflow_id).workflow.meta.category is WorkflowCategory.Default:
raise ValueError("Default workflows cannot be deleted")
with self._db.transaction() as cursor:
- cursor.execute(
- """--sql
- DELETE from workflow_library
- WHERE workflow_id = ? AND category = 'user';
- """,
- (workflow_id,),
- )
+ if user_id is not None:
+ cursor.execute(
+ """--sql
+ DELETE from workflow_library
+ WHERE workflow_id = ? AND category = 'user' AND user_id = ?;
+ """,
+ (workflow_id, user_id),
+ )
+ else:
+ cursor.execute(
+ """--sql
+ DELETE from workflow_library
+ WHERE workflow_id = ? AND category = 'user';
+ """,
+ (workflow_id,),
+ )
return None
+ def update_is_public(self, workflow_id: str, is_public: bool, user_id: Optional[str] = None) -> WorkflowRecordDTO:
+ """Updates the is_public field of a workflow and manages the 'shared' tag automatically."""
+ record = self.get(workflow_id)
+ workflow = record.workflow
+
+ # Manage "shared" tag: add when public, remove when private
+ tags_list = [t.strip() for t in workflow.tags.split(",") if t.strip()] if workflow.tags else []
+ if is_public and "shared" not in tags_list:
+ tags_list.append("shared")
+ elif not is_public and "shared" in tags_list:
+ tags_list.remove("shared")
+ updated_tags = ", ".join(tags_list)
+ updated_workflow = workflow.model_copy(update={"tags": updated_tags})
+
+ with self._db.transaction() as cursor:
+ if user_id is not None:
+ cursor.execute(
+ """--sql
+ UPDATE workflow_library
+ SET workflow = ?, is_public = ?
+ WHERE workflow_id = ? AND category = 'user' AND user_id = ?;
+ """,
+ (updated_workflow.model_dump_json(), is_public, workflow_id, user_id),
+ )
+ else:
+ cursor.execute(
+ """--sql
+ UPDATE workflow_library
+ SET workflow = ?, is_public = ?
+ WHERE workflow_id = ? AND category = 'user';
+ """,
+ (updated_workflow.model_dump_json(), is_public, workflow_id),
+ )
+ return self.get(workflow_id)
+
def get_many(
self,
order_by: WorkflowRecordOrderBy,
@@ -104,6 +160,8 @@ def get_many(
query: Optional[str] = None,
tags: Optional[list[str]] = None,
has_been_opened: Optional[bool] = None,
+ user_id: Optional[str] = None,
+ is_public: Optional[bool] = None,
) -> PaginatedResults[WorkflowRecordListItemDTO]:
with self._db.transaction() as cursor:
# sanitize!
@@ -122,7 +180,9 @@ def get_many(
created_at,
updated_at,
opened_at,
- tags
+ tags,
+ user_id,
+ is_public
FROM workflow_library
"""
count_query = "SELECT COUNT(*) FROM workflow_library"
@@ -177,6 +237,16 @@ def get_many(
conditions.append(query_condition)
params.extend([wildcard_query, wildcard_query, wildcard_query])
+ if user_id is not None:
+ # Scope to the given user but always include default workflows
+ conditions.append("(user_id = ? OR category = 'default')")
+ params.append(user_id)
+
+ if is_public is True:
+ conditions.append("is_public = TRUE")
+ elif is_public is False:
+ conditions.append("is_public = FALSE")
+
if conditions:
# If there are conditions, add a WHERE clause and then join the conditions
main_query += " WHERE "
@@ -226,6 +296,8 @@ def counts_by_tag(
tags: list[str],
categories: Optional[list[WorkflowCategory]] = None,
has_been_opened: Optional[bool] = None,
+ user_id: Optional[str] = None,
+ is_public: Optional[bool] = None,
) -> dict[str, int]:
if not tags:
return {}
@@ -248,6 +320,16 @@ def counts_by_tag(
elif has_been_opened is False:
base_conditions.append("opened_at IS NULL")
+ if user_id is not None:
+ # Scope to the given user but always include default workflows
+ base_conditions.append("(user_id = ? OR category = 'default')")
+ base_params.append(user_id)
+
+ if is_public is True:
+ base_conditions.append("is_public = TRUE")
+ elif is_public is False:
+ base_conditions.append("is_public = FALSE")
+
# For each tag to count, run a separate query
for tag in tags:
# Start with the base conditions
@@ -277,6 +359,8 @@ def counts_by_category(
self,
categories: list[WorkflowCategory],
has_been_opened: Optional[bool] = None,
+ user_id: Optional[str] = None,
+ is_public: Optional[bool] = None,
) -> dict[str, int]:
with self._db.transaction() as cursor:
result: dict[str, int] = {}
@@ -296,6 +380,16 @@ def counts_by_category(
elif has_been_opened is False:
base_conditions.append("opened_at IS NULL")
+ if user_id is not None:
+ # Scope to the given user but always include default workflows
+ base_conditions.append("(user_id = ? OR category = 'default')")
+ base_params.append(user_id)
+
+ if is_public is True:
+ base_conditions.append("is_public = TRUE")
+ elif is_public is False:
+ base_conditions.append("is_public = FALSE")
+
# For each category to count, run a separate query
for category in categories:
# Start with the base conditions
@@ -321,20 +415,32 @@ def counts_by_category(
return result
- def update_opened_at(self, workflow_id: str) -> None:
+ def update_opened_at(self, workflow_id: str, user_id: Optional[str] = None) -> None:
with self._db.transaction() as cursor:
- cursor.execute(
- f"""--sql
- UPDATE workflow_library
- SET opened_at = STRFTIME('{SQL_TIME_FORMAT}', 'NOW')
- WHERE workflow_id = ?;
- """,
- (workflow_id,),
- )
+ if user_id is not None:
+ cursor.execute(
+ f"""--sql
+ UPDATE workflow_library
+ SET opened_at = STRFTIME('{SQL_TIME_FORMAT}', 'NOW')
+ WHERE workflow_id = ? AND user_id = ?;
+ """,
+ (workflow_id, user_id),
+ )
+ else:
+ cursor.execute(
+ f"""--sql
+ UPDATE workflow_library
+ SET opened_at = STRFTIME('{SQL_TIME_FORMAT}', 'NOW')
+ WHERE workflow_id = ?;
+ """,
+ (workflow_id,),
+ )
def get_all_tags(
self,
categories: Optional[list[WorkflowCategory]] = None,
+ user_id: Optional[str] = None,
+ is_public: Optional[bool] = None,
) -> list[str]:
with self._db.transaction() as cursor:
conditions: list[str] = []
@@ -349,6 +455,16 @@ def get_all_tags(
conditions.append(f"category IN ({placeholders})")
params.extend([category.value for category in categories])
+ if user_id is not None:
+ # Scope to the given user but always include default workflows
+ conditions.append("(user_id = ? OR category = 'default')")
+ params.append(user_id)
+
+ if is_public is True:
+ conditions.append("is_public = TRUE")
+ elif is_public is False:
+ conditions.append("is_public = FALSE")
+
stmt = """--sql
SELECT DISTINCT tags
FROM workflow_library
diff --git a/invokeai/frontend/web/openapi.json b/invokeai/frontend/web/openapi.json
index af8476528d6..19e5a3a68e9 100644
--- a/invokeai/frontend/web/openapi.json
+++ b/invokeai/frontend/web/openapi.json
@@ -6463,6 +6463,23 @@
"title": "Has Been Opened"
},
"description": "Whether to include/exclude recent workflows"
+ },
+ {
+ "name": "is_public",
+ "in": "query",
+ "required": false,
+ "schema": {
+ "anyOf": [
+ {
+ "type": "boolean"
+ },
+ {
+ "type": "null"
+ }
+ ],
+ "title": "Is Public"
+ },
+ "description": "Filter by public/shared status"
}
],
"responses": {
@@ -6655,6 +6672,23 @@
"title": "Categories"
},
"description": "The categories to include"
+ },
+ {
+ "name": "is_public",
+ "in": "query",
+ "required": false,
+ "schema": {
+ "anyOf": [
+ {
+ "type": "boolean"
+ },
+ {
+ "type": "null"
+ }
+ ],
+ "title": "Is Public"
+ },
+ "description": "Filter by public/shared status"
}
],
"responses": {
@@ -6744,6 +6778,23 @@
"title": "Has Been Opened"
},
"description": "Whether to include/exclude recent workflows"
+ },
+ {
+ "name": "is_public",
+ "in": "query",
+ "required": false,
+ "schema": {
+ "anyOf": [
+ {
+ "type": "boolean"
+ },
+ {
+ "type": "null"
+ }
+ ],
+ "title": "Is Public"
+ },
+ "description": "Filter by public/shared status"
}
],
"responses": {
@@ -6812,6 +6863,23 @@
"title": "Has Been Opened"
},
"description": "Whether to include/exclude recent workflows"
+ },
+ {
+ "name": "is_public",
+ "in": "query",
+ "required": false,
+ "schema": {
+ "anyOf": [
+ {
+ "type": "boolean"
+ },
+ {
+ "type": "null"
+ }
+ ],
+ "title": "Is Public"
+ },
+ "description": "Filter by public/shared status"
}
],
"responses": {
@@ -7352,6 +7420,67 @@
}
}
}
+ },
+ "/api/v1/workflows/i/{workflow_id}/is_public": {
+ "patch": {
+ "tags": ["workflows"],
+ "summary": "Update Workflow Is Public",
+ "description": "Updates whether a workflow is shared publicly",
+ "operationId": "update_workflow_is_public",
+ "parameters": [
+ {
+ "name": "workflow_id",
+ "in": "path",
+ "required": true,
+ "schema": {
+ "type": "string",
+ "title": "Workflow Id"
+ },
+ "description": "The workflow to update"
+ }
+ ],
+ "requestBody": {
+ "content": {
+ "application/json": {
+ "schema": {
+ "properties": {
+ "is_public": {
+ "type": "boolean",
+ "title": "Is Public",
+ "description": "Whether the workflow should be shared publicly"
+ }
+ },
+ "type": "object",
+ "required": ["is_public"],
+ "title": "Body_update_workflow_is_public"
+ }
+ }
+ },
+ "required": true
+ },
+ "responses": {
+ "200": {
+ "description": "Successful Response",
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": "#/components/schemas/WorkflowRecordDTO"
+ }
+ }
+ }
+ },
+ "422": {
+ "description": "Validation Error",
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": "#/components/schemas/HTTPValidationError"
+ }
+ }
+ }
+ }
+ }
+ }
}
},
"components": {
@@ -59137,10 +59266,20 @@
"workflow": {
"$ref": "#/components/schemas/Workflow",
"description": "The workflow."
+ },
+ "user_id": {
+ "type": "string",
+ "title": "User Id",
+ "description": "The id of the user who owns this workflow."
+ },
+ "is_public": {
+ "type": "boolean",
+ "title": "Is Public",
+ "description": "Whether this workflow is shared with all users."
}
},
"type": "object",
- "required": ["workflow_id", "name", "created_at", "updated_at", "workflow"],
+ "required": ["workflow_id", "name", "created_at", "updated_at", "workflow", "user_id", "is_public"],
"title": "WorkflowRecordDTO"
},
"WorkflowRecordListItemWithThumbnailDTO": {
@@ -59222,15 +59361,35 @@
],
"title": "Thumbnail Url",
"description": "The URL of the workflow thumbnail."
+ },
+ "user_id": {
+ "type": "string",
+ "title": "User Id",
+ "description": "The id of the user who owns this workflow."
+ },
+ "is_public": {
+ "type": "boolean",
+ "title": "Is Public",
+ "description": "Whether this workflow is shared with all users."
}
},
"type": "object",
- "required": ["workflow_id", "name", "created_at", "updated_at", "description", "category", "tags"],
+ "required": [
+ "workflow_id",
+ "name",
+ "created_at",
+ "updated_at",
+ "description",
+ "category",
+ "tags",
+ "user_id",
+ "is_public"
+ ],
"title": "WorkflowRecordListItemWithThumbnailDTO"
},
"WorkflowRecordOrderBy": {
"type": "string",
- "enum": ["created_at", "updated_at", "opened_at", "name"],
+ "enum": ["created_at", "updated_at", "opened_at", "name", "is_public"],
"title": "WorkflowRecordOrderBy",
"description": "The order by options for workflow records"
},
@@ -59303,10 +59462,20 @@
],
"title": "Thumbnail Url",
"description": "The URL of the workflow thumbnail."
+ },
+ "user_id": {
+ "type": "string",
+ "title": "User Id",
+ "description": "The id of the user who owns this workflow."
+ },
+ "is_public": {
+ "type": "boolean",
+ "title": "Is Public",
+ "description": "Whether this workflow is shared with all users."
}
},
"type": "object",
- "required": ["workflow_id", "name", "created_at", "updated_at", "workflow"],
+ "required": ["workflow_id", "name", "created_at", "updated_at", "workflow", "user_id", "is_public"],
"title": "WorkflowRecordWithThumbnailDTO"
},
"WorkflowWithoutID": {
diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json
index 9b2aaddad73..201ea8badba 100644
--- a/invokeai/frontend/web/public/locales/en.json
+++ b/invokeai/frontend/web/public/locales/en.json
@@ -161,7 +161,17 @@
"imagesWithCount_other": "{{count}} images",
"assetsWithCount_one": "{{count}} asset",
"assetsWithCount_other": "{{count}} assets",
- "updateBoardError": "Error updating board"
+ "updateBoardError": "Error updating board",
+ "setBoardVisibility": "Set Board Visibility",
+ "setVisibilityPrivate": "Set Private",
+ "setVisibilityShared": "Set Shared",
+ "setVisibilityPublic": "Set Public",
+ "visibilityPrivate": "Private",
+ "visibilityShared": "Shared",
+ "visibilityPublic": "Public",
+ "visibilityBadgeShared": "Shared board",
+ "visibilityBadgePublic": "Public board",
+ "updateBoardVisibilityError": "Error updating board visibility"
},
"accordions": {
"generation": {
@@ -1168,7 +1178,9 @@
"name": "Name",
"modelPickerFallbackNoModelsInstalled": "No models installed.",
"modelPickerFallbackNoModelsInstalled2": "Visit the Model Manager to install models.",
+ "modelPickerFallbackNoModelsInstalledNonAdmin": "No models installed. Ask your InvokeAI administrator () to install some models.",
"noModelsInstalledDesc1": "Install models with the",
+ "noModelsInstalledAskAdmin": "Ask your administrator to install some.",
"noModelSelected": "No Model Selected",
"noMatchingModels": "No matching models",
"noModelsInstalled": "No models installed",
@@ -1535,6 +1547,7 @@
"info": "Info",
"invoke": {
"addingImagesTo": "Adding images to",
+ "boardNotWritable": "You do not have write access to board \"{{boardName}}\". Select a board you own or switch to Uncategorized.",
"modelDisabledForTrial": "Generating with {{modelName}} is not available on trial accounts. Visit your account settings to upgrade.",
"invoke": "Invoke",
"missingFieldTemplate": "Missing field template",
@@ -2287,6 +2300,8 @@
"tags": "Tags",
"yourWorkflows": "Your Workflows",
"recentlyOpened": "Recently Opened",
+ "sharedWorkflows": "Shared Workflows",
+ "shareWorkflow": "Shared workflow",
"noRecentWorkflows": "No Recent Workflows",
"private": "Private",
"shared": "Shared",
@@ -3021,6 +3036,7 @@
"tileOverlap": "Tile Overlap",
"postProcessingMissingModelWarning": "Visit the Model Manager to install a post-processing (image to image) model.",
"missingModelsWarning": "Visit the Model Manager to install the required models:",
+ "missingModelsWarningNonAdmin": "Ask your InvokeAI administrator () to install the required models:",
"mainModelDesc": "Main model (SD1.5 or SDXL architecture)",
"tileControlNetModelDesc": "Tile ControlNet model for the chosen main model architecture",
"upscaleModelDesc": "Upscale (image to image) model",
@@ -3129,6 +3145,7 @@
},
"workflows": {
"description": "Workflows are reusable templates that automate image generation tasks, allowing you to quickly perform complex operations and get consistent results.",
+ "descriptionMultiuser": "Workflows are reusable templates that automate image generation tasks, allowing you to quickly perform complex operations and get consistent results. You may share your workflows with other users of the system by selecting 'Shared workflow' when you create or edit it.",
"learnMoreLink": "Learn more about creating workflows",
"browseTemplates": {
"title": "Browse Workflow Templates",
@@ -3207,9 +3224,11 @@
"toGetStartedLocal": "To get started, make sure to download or import models needed to run Invoke. Then, enter a prompt in the box and click Invoke to generate your first image. Select a prompt template to improve results. You can choose to save your images directly to the Gallery or edit them to the Canvas.",
"toGetStarted": "To get started, enter a prompt in the box and click Invoke to generate your first image. Select a prompt template to improve results. You can choose to save your images directly to the Gallery or edit them to the Canvas.",
"toGetStartedWorkflow": "To get started, fill in the fields on the left and press Invoke to generate your image. Want to explore more workflows? Click the folder icon next to the workflow title to see a list of other templates you can try.",
+ "toGetStartedNonAdmin": "To get started, ask your InvokeAI administrator () to install the AI models needed to run Invoke. Then, enter a prompt in the box and click Invoke to generate your first image. Select a prompt template to improve results. You can choose to save your images directly to the Gallery or edit them to the Canvas.",
"gettingStartedSeries": "Want more guidance? Check out our Getting Started Series for tips on unlocking the full potential of the Invoke Studio.",
"lowVRAMMode": "For best performance, follow our Low VRAM guide.",
- "noModelsInstalled": "It looks like you don't have any models installed! You can download a starter model bundle or import models."
+ "noModelsInstalled": "It looks like you don't have any models installed! You can download a starter model bundle or import models.",
+ "noModelsInstalledAskAdmin": "Ask your administrator to install some."
},
"whatsNew": {
"whatsNewInInvoke": "What's New in Invoke",
diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/bulkDownload.tsx b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/bulkDownload.tsx
index e0e72d12ffd..fa4c29b8f42 100644
--- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/bulkDownload.tsx
+++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/bulkDownload.tsx
@@ -12,10 +12,14 @@ export const addBulkDownloadListeners = (startAppListening: AppStartListening) =
effect: (action) => {
log.debug(action.payload, 'Bulk download requested');
- // If we have an item name, we are processing the bulk download locally and should use it as the toast id to
- // prevent multiple toasts for the same item.
+ // Use a "preparing:" prefix so this toast cannot collide with the
+ // "ready to download" toast that arrives via the bulk_download_complete
+ // socket event. The background task can complete in under 20ms, so the
+ // socket event may arrive *before* this Redux middleware runs — without
+ // distinct IDs the "preparing" toast would overwrite the "ready" toast.
+ const itemName = action.payload.bulk_download_item_name;
toast({
- id: action.payload.bulk_download_item_name ?? undefined,
+ id: itemName ? `preparing:${itemName}` : undefined,
title: t('gallery.bulkDownloadRequested'),
status: 'success',
// Show the response message if it exists, otherwise show the default message
diff --git a/invokeai/frontend/web/src/features/changeBoardModal/components/ChangeBoardModal.tsx b/invokeai/frontend/web/src/features/changeBoardModal/components/ChangeBoardModal.tsx
index 00217eb7963..5ac6ffcb7c9 100644
--- a/invokeai/frontend/web/src/features/changeBoardModal/components/ChangeBoardModal.tsx
+++ b/invokeai/frontend/web/src/features/changeBoardModal/components/ChangeBoardModal.tsx
@@ -3,6 +3,7 @@ import { Combobox, ConfirmationAlertDialog, Flex, FormControl, Text } from '@inv
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useAssertSingleton } from 'common/hooks/useAssertSingleton';
+import { selectCurrentUser } from 'features/auth/store/authSlice';
import {
changeBoardReset,
isModalOpenChanged,
@@ -13,6 +14,7 @@ import { memo, useCallback, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useListAllBoardsQuery } from 'services/api/endpoints/boards';
import { useAddImagesToBoardMutation, useRemoveImagesFromBoardMutation } from 'services/api/endpoints/images';
+import type { BoardDTO } from 'services/api/types';
const selectImagesToChange = createSelector(
selectChangeBoardModalSlice,
@@ -28,6 +30,7 @@ const ChangeBoardModal = () => {
useAssertSingleton('ChangeBoardModal');
const dispatch = useAppDispatch();
const currentBoardId = useAppSelector(selectSelectedBoardId);
+ const currentUser = useAppSelector(selectCurrentUser);
const [selectedBoardId, setSelectedBoardId] = useState();
const { data: boards, isFetching } = useListAllBoardsQuery({ include_archived: true });
const isModalOpen = useAppSelector(selectIsModalOpen);
@@ -36,10 +39,20 @@ const ChangeBoardModal = () => {
const [removeImagesFromBoard] = useRemoveImagesFromBoardMutation();
const { t } = useTranslation();
+ // Returns true if the current user can write images to the given board.
+ const canWriteToBoard = useCallback(
+ (board: BoardDTO): boolean => {
+ const isOwnerOrAdmin = !currentUser || currentUser.is_admin || board.user_id === currentUser.user_id;
+ return isOwnerOrAdmin || board.board_visibility === 'public';
+ },
+ [currentUser]
+ );
+
const options = useMemo(() => {
return [{ label: t('boards.uncategorized'), value: 'none' }]
.concat(
(boards ?? [])
+ .filter(canWriteToBoard)
.map((board) => ({
label: board.board_name,
value: board.board_id,
@@ -47,7 +60,7 @@ const ChangeBoardModal = () => {
.sort((a, b) => a.label.localeCompare(b.label))
)
.filter((board) => board.value !== currentBoardId);
- }, [boards, currentBoardId, t]);
+ }, [boards, canWriteToBoard, currentBoardId, t]);
const value = useMemo(() => options.find((o) => o.value === selectedBoardId), [options, selectedBoardId]);
diff --git a/invokeai/frontend/web/src/features/dnd/dnd.ts b/invokeai/frontend/web/src/features/dnd/dnd.ts
index f5e38d4b944..ee648e82ef6 100644
--- a/invokeai/frontend/web/src/features/dnd/dnd.ts
+++ b/invokeai/frontend/web/src/features/dnd/dnd.ts
@@ -434,6 +434,49 @@ export const replaceCanvasEntityObjectsWithImageDndTarget: DndTarget<
//#endregion
//#region Add To Board
+/**
+ * Check whether the current user can move images out of their source board.
+ * Returns false if the source board is a shared board not owned by the current user
+ * (and the user is not an admin). In that case, images can be viewed/used but not moved.
+ */
+const canMoveFromSourceBoard = (sourceBoardId: BoardId, getState: AppGetState): boolean => {
+ const state = getState();
+ // In single-user mode (no auth), always allow
+ const currentUser = state.auth?.user;
+ if (!currentUser) {
+ return true;
+ }
+ // Admins can always move
+ if (currentUser.is_admin) {
+ return true;
+ }
+ // "Uncategorized" (none) — user's own uncategorized images, allow
+ if (sourceBoardId === 'none') {
+ return true;
+ }
+ // Look up the board from the RTK Query cache
+ const boardsQueryState = state.api?.queries;
+ if (boardsQueryState) {
+ for (const query of Object.values(boardsQueryState)) {
+ if (query?.data && Array.isArray(query.data)) {
+ const board = (query.data as Array<{ board_id: string; user_id?: string; board_visibility?: string }>).find(
+ (b) => b.board_id === sourceBoardId
+ );
+ if (board) {
+ // Owner can always move
+ if (board.user_id === currentUser.user_id) {
+ return true;
+ }
+ // Non-owner can only move from public boards
+ return board.board_visibility === 'public';
+ }
+ }
+ }
+ }
+ // Board not found in cache — allow by default to avoid blocking legitimate operations
+ return true;
+};
+
const _addToBoard = buildTypeAndKey('add-to-board');
export type AddImageToBoardDndTargetData = DndData<
typeof _addToBoard.type,
@@ -447,16 +490,23 @@ export const addImageToBoardDndTarget: DndTarget<
..._addToBoard,
typeGuard: buildTypeGuard(_addToBoard.key),
getData: buildGetData(_addToBoard.key, _addToBoard.type),
- isValid: ({ sourceData, targetData }) => {
+ isValid: ({ sourceData, targetData, getState }) => {
if (singleImageDndSource.typeGuard(sourceData)) {
const currentBoard = sourceData.payload.imageDTO.board_id ?? 'none';
const destinationBoard = targetData.payload.boardId;
- return currentBoard !== destinationBoard;
+ if (currentBoard === destinationBoard) {
+ return false;
+ }
+ // Don't allow moving images from shared boards the user doesn't own
+ return canMoveFromSourceBoard(currentBoard, getState);
}
if (multipleImageDndSource.typeGuard(sourceData)) {
const currentBoard = sourceData.payload.board_id;
const destinationBoard = targetData.payload.boardId;
- return currentBoard !== destinationBoard;
+ if (currentBoard === destinationBoard) {
+ return false;
+ }
+ return canMoveFromSourceBoard(currentBoard, getState);
}
return false;
},
@@ -491,15 +541,22 @@ export const removeImageFromBoardDndTarget: DndTarget<
..._removeFromBoard,
typeGuard: buildTypeGuard(_removeFromBoard.key),
getData: buildGetData(_removeFromBoard.key, _removeFromBoard.type),
- isValid: ({ sourceData }) => {
+ isValid: ({ sourceData, getState }) => {
if (singleImageDndSource.typeGuard(sourceData)) {
const currentBoard = sourceData.payload.imageDTO.board_id ?? 'none';
- return currentBoard !== 'none';
+ if (currentBoard === 'none') {
+ return false;
+ }
+ // Don't allow removing images from shared boards the user doesn't own
+ return canMoveFromSourceBoard(currentBoard, getState);
}
if (multipleImageDndSource.typeGuard(sourceData)) {
const currentBoard = sourceData.payload.board_id;
- return currentBoard !== 'none';
+ if (currentBoard === 'none') {
+ return false;
+ }
+ return canMoveFromSourceBoard(currentBoard, getState);
}
return false;
diff --git a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardContextMenu.tsx b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardContextMenu.tsx
index 5cc25f6c038..d10dde6ee44 100644
--- a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardContextMenu.tsx
+++ b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardContextMenu.tsx
@@ -2,15 +2,26 @@ import type { ContextMenuProps } from '@invoke-ai/ui-library';
import { ContextMenu, MenuGroup, MenuItem, MenuList } from '@invoke-ai/ui-library';
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
+import { selectCurrentUser } from 'features/auth/store/authSlice';
import { $boardToDelete } from 'features/gallery/components/Boards/DeleteBoardModal';
import { selectAutoAddBoardId, selectAutoAssignBoardOnClick } from 'features/gallery/store/gallerySelectors';
import { autoAddBoardIdChanged } from 'features/gallery/store/gallerySlice';
import { toast } from 'features/toast/toast';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
-import { PiArchiveBold, PiArchiveFill, PiDownloadBold, PiPlusBold, PiTrashSimpleBold } from 'react-icons/pi';
+import {
+ PiArchiveBold,
+ PiArchiveFill,
+ PiDownloadBold,
+ PiGlobeBold,
+ PiLockBold,
+ PiPlusBold,
+ PiShareNetworkBold,
+ PiTrashSimpleBold,
+} from 'react-icons/pi';
import { useUpdateBoardMutation } from 'services/api/endpoints/boards';
import { useBulkDownloadImagesMutation } from 'services/api/endpoints/images';
+import { useBoardAccess } from 'services/api/hooks/useBoardAccess';
import { useBoardName } from 'services/api/hooks/useBoardName';
import type { BoardDTO } from 'services/api/types';
@@ -23,6 +34,7 @@ const BoardContextMenu = ({ board, children }: Props) => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const autoAssignBoardOnClick = useAppSelector(selectAutoAssignBoardOnClick);
+ const currentUser = useAppSelector(selectCurrentUser);
const selectIsSelectedForAutoAdd = useMemo(
() => createSelector(selectAutoAddBoardId, (autoAddBoardId) => board.board_id === autoAddBoardId),
[board.board_id]
@@ -35,6 +47,11 @@ const BoardContextMenu = ({ board, children }: Props) => {
const [bulkDownload] = useBulkDownloadImagesMutation();
+ // Only the board owner or admin can modify visibility
+ const canChangeVisibility = currentUser !== null && (currentUser.is_admin || board.user_id === currentUser.user_id);
+
+ const { canDeleteBoard } = useBoardAccess(board);
+
const handleSetAutoAdd = useCallback(() => {
dispatch(autoAddBoardIdChanged(board.board_id));
}, [board.board_id, dispatch]);
@@ -64,6 +81,26 @@ const BoardContextMenu = ({ board, children }: Props) => {
});
}, [board.board_id, updateBoard]);
+ const handleSetVisibility = useCallback(
+ async (visibility: 'private' | 'shared' | 'public') => {
+ try {
+ await updateBoard({
+ board_id: board.board_id,
+ changes: { board_visibility: visibility },
+ }).unwrap();
+ } catch {
+ toast({ status: 'error', title: t('boards.updateBoardVisibilityError') });
+ }
+ },
+ [board.board_id, t, updateBoard]
+ );
+
+ const handleSetVisibilityPrivate = useCallback(() => handleSetVisibility('private'), [handleSetVisibility]);
+
+ const handleSetVisibilityShared = useCallback(() => handleSetVisibility('shared'), [handleSetVisibility]);
+
+ const handleSetVisibilityPublic = useCallback(() => handleSetVisibility('public'), [handleSetVisibility]);
+
const setAsBoardToDelete = useCallback(() => {
$boardToDelete.set(board);
}, [board]);
@@ -83,18 +120,50 @@ const BoardContextMenu = ({ board, children }: Props) => {
{board.archived && (
- } onClick={handleUnarchive}>
+ } onClick={handleUnarchive} isDisabled={!canDeleteBoard}>
{t('boards.unarchiveBoard')}
)}
{!board.archived && (
- } onClick={handleArchive}>
+ } onClick={handleArchive} isDisabled={!canDeleteBoard}>
{t('boards.archiveBoard')}
)}
- } onClick={setAsBoardToDelete} isDestructive>
+ {canChangeVisibility && (
+ <>
+ }
+ onClick={handleSetVisibilityPrivate}
+ isDisabled={board.board_visibility === 'private'}
+ >
+ {t('boards.setVisibilityPrivate')}
+
+ }
+ onClick={handleSetVisibilityShared}
+ isDisabled={board.board_visibility === 'shared'}
+ >
+ {t('boards.setVisibilityShared')}
+
+ }
+ onClick={handleSetVisibilityPublic}
+ isDisabled={board.board_visibility === 'public'}
+ >
+ {t('boards.setVisibilityPublic')}
+
+ >
+ )}
+
+ }
+ onClick={setAsBoardToDelete}
+ isDestructive
+ isDisabled={!canDeleteBoard}
+ >
{t('boards.deleteBoard')}
@@ -108,8 +177,14 @@ const BoardContextMenu = ({ board, children }: Props) => {
t,
handleBulkDownload,
board.archived,
+ board.board_visibility,
handleUnarchive,
handleArchive,
+ canChangeVisibility,
+ handleSetVisibilityPrivate,
+ handleSetVisibilityShared,
+ handleSetVisibilityPublic,
+ canDeleteBoard,
setAsBoardToDelete,
]
);
diff --git a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/BoardEditableTitle.tsx b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/BoardEditableTitle.tsx
index 67c7dad6ed0..cf2749e3400 100644
--- a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/BoardEditableTitle.tsx
+++ b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/BoardEditableTitle.tsx
@@ -7,6 +7,7 @@ import { memo, useCallback, useRef } from 'react';
import { useTranslation } from 'react-i18next';
import { PiPencilBold } from 'react-icons/pi';
import { useUpdateBoardMutation } from 'services/api/endpoints/boards';
+import { useBoardAccess } from 'services/api/hooks/useBoardAccess';
import type { BoardDTO } from 'services/api/types';
type Props = {
@@ -19,6 +20,7 @@ export const BoardEditableTitle = memo(({ board, isSelected }: Props) => {
const isHovering = useBoolean(false);
const inputRef = useRef(null);
const [updateBoard, updateBoardResult] = useUpdateBoardMutation();
+ const { canRenameBoard } = useBoardAccess(board);
const onChange = useCallback(
async (board_name: string) => {
@@ -51,13 +53,13 @@ export const BoardEditableTitle = memo(({ board, isSelected }: Props) => {
fontWeight="semibold"
userSelect="none"
color={isSelected ? 'base.100' : 'base.300'}
- onDoubleClick={editable.startEditing}
- cursor="text"
+ onDoubleClick={canRenameBoard ? editable.startEditing : undefined}
+ cursor={canRenameBoard ? 'text' : 'default'}
noOfLines={1}
>
{editable.value}
- {isHovering.isTrue && (
+ {canRenameBoard && isHovering.isTrue && (
}
diff --git a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/GalleryBoard.tsx b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/GalleryBoard.tsx
index 4d821f819c6..10fbe618322 100644
--- a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/GalleryBoard.tsx
+++ b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/GalleryBoard.tsx
@@ -18,8 +18,9 @@ import {
import { autoAddBoardIdChanged, boardIdSelected } from 'features/gallery/store/gallerySlice';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
-import { PiArchiveBold, PiImageSquare } from 'react-icons/pi';
+import { PiArchiveBold, PiGlobeBold, PiImageSquare, PiShareNetworkBold } from 'react-icons/pi';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
+import { useBoardAccess } from 'services/api/hooks/useBoardAccess';
import type { BoardDTO } from 'services/api/types';
const _hover: SystemStyleObject = {
@@ -62,6 +63,8 @@ const GalleryBoard = ({ board, isSelected }: GalleryBoardProps) => {
const showOwner = currentUser?.is_admin && board.owner_username;
+ const { canWriteImages } = useBoardAccess(board);
+
return (
@@ -99,6 +102,20 @@ const GalleryBoard = ({ board, isSelected }: GalleryBoardProps) => {
{autoAddBoardId === board.board_id && }
{board.archived && }
+ {board.board_visibility === 'shared' && (
+
+
+
+
+
+ )}
+ {board.board_visibility === 'public' && (
+
+
+
+
+
+ )}
{board.image_count} | {board.asset_count}
@@ -108,7 +125,12 @@ const GalleryBoard = ({ board, isSelected }: GalleryBoardProps) => {
)}
-
+
);
};
diff --git a/invokeai/frontend/web/src/features/gallery/components/ContextMenu/MenuItems/ContextMenuItemChangeBoard.tsx b/invokeai/frontend/web/src/features/gallery/components/ContextMenu/MenuItems/ContextMenuItemChangeBoard.tsx
index 71764870153..f5c044132e5 100644
--- a/invokeai/frontend/web/src/features/gallery/components/ContextMenu/MenuItems/ContextMenuItemChangeBoard.tsx
+++ b/invokeai/frontend/web/src/features/gallery/components/ContextMenu/MenuItems/ContextMenuItemChangeBoard.tsx
@@ -5,11 +5,15 @@ import { useImageDTOContext } from 'features/gallery/contexts/ImageDTOContext';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiFoldersBold } from 'react-icons/pi';
+import { useBoardAccess } from 'services/api/hooks/useBoardAccess';
+import { useSelectedBoard } from 'services/api/hooks/useSelectedBoard';
export const ContextMenuItemChangeBoard = memo(() => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const imageDTO = useImageDTOContext();
+ const selectedBoard = useSelectedBoard();
+ const { canWriteImages } = useBoardAccess(selectedBoard);
const onClick = useCallback(() => {
dispatch(imagesToChangeSelected([imageDTO.image_name]));
@@ -17,7 +21,7 @@ export const ContextMenuItemChangeBoard = memo(() => {
}, [dispatch, imageDTO]);
return (
- } onClickCapture={onClick}>
+ } onClickCapture={onClick} isDisabled={!canWriteImages}>
{t('boards.changeBoard')}
);
diff --git a/invokeai/frontend/web/src/features/gallery/components/ContextMenu/MenuItems/ContextMenuItemDeleteImage.tsx b/invokeai/frontend/web/src/features/gallery/components/ContextMenu/MenuItems/ContextMenuItemDeleteImage.tsx
index e20221f3423..5dfa7116b17 100644
--- a/invokeai/frontend/web/src/features/gallery/components/ContextMenu/MenuItems/ContextMenuItemDeleteImage.tsx
+++ b/invokeai/frontend/web/src/features/gallery/components/ContextMenu/MenuItems/ContextMenuItemDeleteImage.tsx
@@ -4,11 +4,15 @@ import { useImageDTOContext } from 'features/gallery/contexts/ImageDTOContext';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiTrashSimpleBold } from 'react-icons/pi';
+import { useBoardAccess } from 'services/api/hooks/useBoardAccess';
+import { useSelectedBoard } from 'services/api/hooks/useSelectedBoard';
export const ContextMenuItemDeleteImage = memo(() => {
const { t } = useTranslation();
const deleteImageModal = useDeleteImageModalApi();
const imageDTO = useImageDTOContext();
+ const selectedBoard = useSelectedBoard();
+ const { canWriteImages } = useBoardAccess(selectedBoard);
const onClick = useCallback(async () => {
try {
@@ -18,6 +22,10 @@ export const ContextMenuItemDeleteImage = memo(() => {
}
}, [deleteImageModal, imageDTO]);
+ if (!canWriteImages) {
+ return null;
+ }
+
return (
}
diff --git a/invokeai/frontend/web/src/features/gallery/components/ContextMenu/MultipleSelectionMenuItems.tsx b/invokeai/frontend/web/src/features/gallery/components/ContextMenu/MultipleSelectionMenuItems.tsx
index d148332943c..ee3c8e4e985 100644
--- a/invokeai/frontend/web/src/features/gallery/components/ContextMenu/MultipleSelectionMenuItems.tsx
+++ b/invokeai/frontend/web/src/features/gallery/components/ContextMenu/MultipleSelectionMenuItems.tsx
@@ -10,12 +10,16 @@ import {
useStarImagesMutation,
useUnstarImagesMutation,
} from 'services/api/endpoints/images';
+import { useBoardAccess } from 'services/api/hooks/useBoardAccess';
+import { useSelectedBoard } from 'services/api/hooks/useSelectedBoard';
const MultipleSelectionMenuItems = () => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const selection = useAppSelector((s) => s.gallery.selection);
const deleteImageModal = useDeleteImageModalApi();
+ const selectedBoard = useSelectedBoard();
+ const { canWriteImages } = useBoardAccess(selectedBoard);
const [starImages] = useStarImagesMutation();
const [unstarImages] = useUnstarImagesMutation();
@@ -53,11 +57,16 @@ const MultipleSelectionMenuItems = () => {
} onClickCapture={handleBulkDownload}>
{t('gallery.downloadSelection')}
- } onClickCapture={handleChangeBoard}>
+ } onClickCapture={handleChangeBoard} isDisabled={!canWriteImages}>
{t('boards.changeBoard')}
- } onClickCapture={handleDeleteSelection}>
+ }
+ onClickCapture={handleDeleteSelection}
+ isDisabled={!canWriteImages}
+ >
{t('gallery.deleteSelection')}
>
diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageGrid/GalleryImage.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageGrid/GalleryImage.tsx
index ccd58992ef6..af1d376887b 100644
--- a/invokeai/frontend/web/src/features/gallery/components/ImageGrid/GalleryImage.tsx
+++ b/invokeai/frontend/web/src/features/gallery/components/ImageGrid/GalleryImage.tsx
@@ -108,6 +108,25 @@ export const GalleryImage = memo(({ imageDTO }: Props) => {
if (!element) {
return;
}
+
+ const monitorBinding = monitorForElements({
+ // This is a "global" drag start event, meaning that it is called for all drag events.
+ onDragStart: ({ source }) => {
+ // When we start dragging multiple images, set the dragging state to true if the dragged image is part of the
+ // selection. This is called for all drag events.
+ if (
+ multipleImageDndSource.typeGuard(source.data) &&
+ source.data.payload.image_names.includes(imageDTO.image_name)
+ ) {
+ setIsDragging(true);
+ }
+ },
+ onDrop: () => {
+ // Always set the dragging state to false when a drop event occurs.
+ setIsDragging(false);
+ },
+ });
+
return combine(
firefoxDndFix(element),
draggable({
@@ -153,23 +172,7 @@ export const GalleryImage = memo(({ imageDTO }: Props) => {
}
},
}),
- monitorForElements({
- // This is a "global" drag start event, meaning that it is called for all drag events.
- onDragStart: ({ source }) => {
- // When we start dragging multiple images, set the dragging state to true if the dragged image is part of the
- // selection. This is called for all drag events.
- if (
- multipleImageDndSource.typeGuard(source.data) &&
- source.data.payload.image_names.includes(imageDTO.image_name)
- ) {
- setIsDragging(true);
- }
- },
- onDrop: () => {
- // Always set the dragging state to false when a drop event occurs.
- setIsDragging(false);
- },
- })
+ monitorBinding
);
}, [imageDTO, store]);
diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageGrid/GalleryItemDeleteIconButton.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageGrid/GalleryItemDeleteIconButton.tsx
index 0a97bf819de..612e6361b14 100644
--- a/invokeai/frontend/web/src/features/gallery/components/ImageGrid/GalleryItemDeleteIconButton.tsx
+++ b/invokeai/frontend/web/src/features/gallery/components/ImageGrid/GalleryItemDeleteIconButton.tsx
@@ -5,6 +5,8 @@ import type { MouseEvent } from 'react';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiTrashSimpleFill } from 'react-icons/pi';
+import { useBoardAccess } from 'services/api/hooks/useBoardAccess';
+import { useSelectedBoard } from 'services/api/hooks/useSelectedBoard';
import type { ImageDTO } from 'services/api/types';
type Props = {
@@ -15,6 +17,8 @@ export const GalleryItemDeleteIconButton = memo(({ imageDTO }: Props) => {
const shift = useShiftModifier();
const { t } = useTranslation();
const deleteImageModal = useDeleteImageModalApi();
+ const selectedBoard = useSelectedBoard();
+ const { canWriteImages } = useBoardAccess(selectedBoard);
const onClick = useCallback(
(e: MouseEvent) => {
@@ -24,7 +28,7 @@ export const GalleryItemDeleteIconButton = memo(({ imageDTO }: Props) => {
[deleteImageModal, imageDTO]
);
- if (!shift) {
+ if (!shift || !canWriteImages) {
return null;
}
diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageViewer/NoContentForViewer.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/NoContentForViewer.tsx
index b8a522c3a65..c301922df95 100644
--- a/invokeai/frontend/web/src/features/gallery/components/ImageViewer/NoContentForViewer.tsx
+++ b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/NoContentForViewer.tsx
@@ -1,7 +1,9 @@
import type { ButtonProps } from '@invoke-ai/ui-library';
import { Alert, AlertDescription, AlertIcon, Button, Divider, Flex, Link, Spinner, Text } from '@invoke-ai/ui-library';
+import { useAppSelector } from 'app/store/storeHooks';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import { InvokeLogoIcon } from 'common/components/InvokeLogoIcon';
+import { selectCurrentUser } from 'features/auth/store/authSlice';
import { LOADING_SYMBOL, useHasImages } from 'features/gallery/hooks/useHasImages';
import { setInstallModelsTabByName } from 'features/modelManagerV2/store/installModelsStore';
import { navigationApi } from 'features/ui/layouts/navigation-api';
@@ -9,16 +11,26 @@ import type { PropsWithChildren } from 'react';
import { memo, useCallback, useMemo } from 'react';
import { Trans, useTranslation } from 'react-i18next';
import { PiArrowSquareOutBold, PiImageBold } from 'react-icons/pi';
+import { useGetSetupStatusQuery } from 'services/api/endpoints/auth';
import { useMainModels } from 'services/api/hooks/modelsByType';
export const NoContentForViewer = memo(() => {
const hasImages = useHasImages();
const [mainModels, { data }] = useMainModels();
+ const { data: setupStatus } = useGetSetupStatusQuery();
+ const user = useAppSelector(selectCurrentUser);
const { t } = useTranslation();
+ const isMultiuser = setupStatus?.multiuser_enabled ?? false;
+ const isAdmin = !isMultiuser || (user?.is_admin ?? false);
+ const adminEmail = setupStatus?.admin_email ?? null;
+
+ const modelsLoaded = data !== undefined;
+ const hasModels = mainModels.length > 0;
+
const showStarterBundles = useMemo(() => {
- return data && mainModels.length === 0;
- }, [mainModels.length, data]);
+ return modelsLoaded && !hasModels && isAdmin;
+ }, [modelsLoaded, hasModels, isAdmin]);
if (hasImages === LOADING_SYMBOL) {
// Blank bg w/ a spinner. The new user experience components below have an invoke logo, but it's not centered.
@@ -36,10 +48,18 @@ export const NoContentForViewer = memo(() => {
-
- {showStarterBundles && }
-
-
+ {isAdmin ? (
+ // Admin / single-user mode
+ <>
+ {modelsLoaded && hasModels ? : }
+ {showStarterBundles && }
+
+
+ >
+ ) : (
+ // Non-admin user in multiuser mode
+ <>{modelsLoaded && hasModels ? : }>
+ )}
);
@@ -99,6 +119,32 @@ const GetStartedLocal = () => {
);
};
+const GetStartedWithModels = () => {
+ return (
+
+
+
+ );
+};
+
+const GetStartedNonAdmin = ({ adminEmail }: { adminEmail: string | null }) => {
+ const AdminEmailLink = adminEmail ? (
+
+ {adminEmail}
+
+ ) : (
+
+ your administrator
+
+ );
+
+ return (
+
+
+
+ );
+};
+
const StarterBundlesCallout = () => {
const handleClickDownloadStarterModels = useCallback(() => {
navigationApi.switchToTab('models');
diff --git a/invokeai/frontend/web/src/features/modelManagerV2/hooks/useStarterModelsToast.tsx b/invokeai/frontend/web/src/features/modelManagerV2/hooks/useStarterModelsToast.tsx
index d1774f9ded0..9b76fbbde67 100644
--- a/invokeai/frontend/web/src/features/modelManagerV2/hooks/useStarterModelsToast.tsx
+++ b/invokeai/frontend/web/src/features/modelManagerV2/hooks/useStarterModelsToast.tsx
@@ -1,10 +1,11 @@
import { Button, Text, useToast } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
-import { selectIsAuthenticated } from 'features/auth/store/authSlice';
+import { selectCurrentUser, selectIsAuthenticated } from 'features/auth/store/authSlice';
import { setInstallModelsTabByName } from 'features/modelManagerV2/store/installModelsStore';
import { navigationApi } from 'features/ui/layouts/navigation-api';
import { useCallback, useEffect, useState } from 'react';
import { useTranslation } from 'react-i18next';
+import { useGetSetupStatusQuery } from 'services/api/endpoints/auth';
import { useMainModels } from 'services/api/hooks/modelsByType';
const TOAST_ID = 'starterModels';
@@ -15,6 +16,11 @@ export const useStarterModelsToast = () => {
const [mainModels, { data }] = useMainModels();
const toast = useToast();
const isAuthenticated = useAppSelector(selectIsAuthenticated);
+ const { data: setupStatus } = useGetSetupStatusQuery();
+ const user = useAppSelector(selectCurrentUser);
+
+ const isMultiuser = setupStatus?.multiuser_enabled ?? false;
+ const isAdmin = !isMultiuser || (user?.is_admin ?? false);
useEffect(() => {
// Only show the toast if the user is authenticated
@@ -33,17 +39,17 @@ export const useStarterModelsToast = () => {
toast({
id: TOAST_ID,
title: t('modelManager.noModelsInstalled'),
- description: ,
+ description: isAdmin ? : ,
status: 'info',
isClosable: true,
duration: null,
onCloseComplete: () => setDidToast(true),
});
}
- }, [data, didToast, isAuthenticated, mainModels.length, t, toast]);
+ }, [data, didToast, isAuthenticated, isAdmin, mainModels.length, t, toast]);
};
-const ToastDescription = () => {
+const AdminToastDescription = () => {
const { t } = useTranslation();
const toast = useToast();
@@ -62,3 +68,9 @@ const ToastDescription = () => {
);
};
+
+const NonAdminToastDescription = () => {
+ const { t } = useTranslation();
+
+ return {t('modelManager.noModelsInstalledAskAdmin')};
+};
diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManager.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManager.tsx
index f6e1a18f6fd..60200c8801f 100644
--- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManager.tsx
+++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManager.tsx
@@ -37,7 +37,7 @@ export const ModelManager = memo(() => {
{t('common.modelManager')}
-
+ {canManageModels && }
{!!selectedModelKey && canManageModels && (
} onClick={handleClickAddModel}>
{t('modelManager.addModels')}
diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/panels/TopPanel/SaveWorkflowButton.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/panels/TopPanel/SaveWorkflowButton.tsx
index 91c6c1dae38..fe4b889f540 100644
--- a/invokeai/frontend/web/src/features/nodes/components/flow/panels/TopPanel/SaveWorkflowButton.tsx
+++ b/invokeai/frontend/web/src/features/nodes/components/flow/panels/TopPanel/SaveWorkflowButton.tsx
@@ -1,5 +1,6 @@
import { IconButton } from '@invoke-ai/ui-library';
import { useDoesWorkflowHaveUnsavedChanges } from 'features/nodes/components/sidePanel/workflow/IsolatedWorkflowBuilderWatcher';
+import { useIsCurrentWorkflowOwner } from 'features/workflowLibrary/hooks/useIsCurrentWorkflowOwner';
import { useSaveOrSaveAsWorkflow } from 'features/workflowLibrary/hooks/useSaveOrSaveAsWorkflow';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
@@ -8,6 +9,7 @@ import { PiFloppyDiskBold } from 'react-icons/pi';
const SaveWorkflowButton = () => {
const { t } = useTranslation();
const doesWorkflowHaveUnsavedChanges = useDoesWorkflowHaveUnsavedChanges();
+ const isCurrentWorkflowOwner = useIsCurrentWorkflowOwner();
const saveOrSaveAsWorkflow = useSaveOrSaveAsWorkflow();
return (
@@ -15,7 +17,7 @@ const SaveWorkflowButton = () => {
tooltip={t('workflows.saveWorkflow')}
aria-label={t('workflows.saveWorkflow')}
icon={}
- isDisabled={!doesWorkflowHaveUnsavedChanges}
+ isDisabled={!doesWorkflowHaveUnsavedChanges || !isCurrentWorkflowOwner}
onClick={saveOrSaveAsWorkflow}
pointerEvents="auto"
/>
diff --git a/invokeai/frontend/web/src/features/nodes/components/sidePanel/WorkflowListMenu/SaveWorkflowButton.tsx b/invokeai/frontend/web/src/features/nodes/components/sidePanel/WorkflowListMenu/SaveWorkflowButton.tsx
index 39a93e4a382..779d6f018ee 100644
--- a/invokeai/frontend/web/src/features/nodes/components/sidePanel/WorkflowListMenu/SaveWorkflowButton.tsx
+++ b/invokeai/frontend/web/src/features/nodes/components/sidePanel/WorkflowListMenu/SaveWorkflowButton.tsx
@@ -1,4 +1,6 @@
import { IconButton } from '@invoke-ai/ui-library';
+import { useDoesWorkflowHaveUnsavedChanges } from 'features/nodes/components/sidePanel/workflow/IsolatedWorkflowBuilderWatcher';
+import { useIsCurrentWorkflowOwner } from 'features/workflowLibrary/hooks/useIsCurrentWorkflowOwner';
import { useSaveOrSaveAsWorkflow } from 'features/workflowLibrary/hooks/useSaveOrSaveAsWorkflow';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
@@ -7,12 +9,15 @@ import { PiFloppyDiskBold } from 'react-icons/pi';
const SaveWorkflowButton = () => {
const { t } = useTranslation();
const saveOrSaveAsWorkflow = useSaveOrSaveAsWorkflow();
+ const doesWorkflowHaveUnsavedChanges = useDoesWorkflowHaveUnsavedChanges();
+ const isCurrentWorkflowOwner = useIsCurrentWorkflowOwner();
return (
}
+ isDisabled={!doesWorkflowHaveUnsavedChanges || !isCurrentWorkflowOwner}
onClick={saveOrSaveAsWorkflow}
pointerEvents="auto"
variant="ghost"
diff --git a/invokeai/frontend/web/src/features/nodes/components/sidePanel/workflow/WorkflowGeneralTab.tsx b/invokeai/frontend/web/src/features/nodes/components/sidePanel/workflow/WorkflowGeneralTab.tsx
index c1094abf86d..11d27335352 100644
--- a/invokeai/frontend/web/src/features/nodes/components/sidePanel/workflow/WorkflowGeneralTab.tsx
+++ b/invokeai/frontend/web/src/features/nodes/components/sidePanel/workflow/WorkflowGeneralTab.tsx
@@ -1,8 +1,19 @@
import type { FormControlProps } from '@invoke-ai/ui-library';
-import { Box, Flex, FormControl, FormControlGroup, FormLabel, Image, Input, Textarea } from '@invoke-ai/ui-library';
+import {
+ Box,
+ Checkbox,
+ Flex,
+ FormControl,
+ FormControlGroup,
+ FormLabel,
+ Image,
+ Input,
+ Textarea,
+} from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
+import { selectCurrentUser } from 'features/auth/store/authSlice';
import {
workflowAuthorChanged,
workflowContactChanged,
@@ -25,7 +36,8 @@ import {
import type { ChangeEvent } from 'react';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
-import { useGetWorkflowQuery } from 'services/api/endpoints/workflows';
+import { useGetSetupStatusQuery } from 'services/api/endpoints/auth';
+import { useGetWorkflowQuery, useUpdateWorkflowIsPublicMutation } from 'services/api/endpoints/workflows';
import { WorkflowThumbnailEditor } from './WorkflowThumbnail/WorkflowThumbnailEditor';
@@ -95,6 +107,7 @@ const WorkflowGeneralTab = () => {
{t('nodes.workflowName')}
+
{t('nodes.workflowVersion')}
@@ -187,3 +200,40 @@ const Thumbnail = ({ id }: { id?: string | null }) => {
// This is a default workflow and it does not have a thumbnail set. Users may not edit the thumbnail.
return null;
};
+
+const ShareWorkflowCheckbox = ({ id }: { id?: string | null }) => {
+ const { t } = useTranslation();
+ const currentUser = useAppSelector(selectCurrentUser);
+ const { data: setupStatus } = useGetSetupStatusQuery();
+ const { data } = useGetWorkflowQuery(id ?? skipToken);
+ const [updateIsPublic, { isLoading }] = useUpdateWorkflowIsPublicMutation();
+
+ const handleChange = useCallback(
+ (e: ChangeEvent) => {
+ if (!id) {
+ return;
+ }
+ updateIsPublic({ workflow_id: id, is_public: e.target.checked });
+ },
+ [id, updateIsPublic]
+ );
+
+ // Only show for saved user workflows in multiuser mode when the current user is the owner or admin
+ if (!data || !id || data.workflow.meta.category !== 'user') {
+ return null;
+ }
+ if (setupStatus?.multiuser_enabled) {
+ const isOwner = currentUser !== null && data.user_id === currentUser.user_id;
+ const isAdmin = currentUser?.is_admin ?? false;
+ if (!isOwner && !isAdmin) {
+ return null;
+ }
+ }
+
+ return (
+
+
+ {t('workflows.shareWorkflow')}
+
+ );
+};
diff --git a/invokeai/frontend/web/src/features/nodes/components/sidePanel/workflow/WorkflowLibrary/WorkflowLibrarySideNav.tsx b/invokeai/frontend/web/src/features/nodes/components/sidePanel/workflow/WorkflowLibrary/WorkflowLibrarySideNav.tsx
index 73b046c83a9..501b8365db5 100644
--- a/invokeai/frontend/web/src/features/nodes/components/sidePanel/workflow/WorkflowLibrary/WorkflowLibrarySideNav.tsx
+++ b/invokeai/frontend/web/src/features/nodes/components/sidePanel/workflow/WorkflowLibrary/WorkflowLibrarySideNav.tsx
@@ -41,6 +41,7 @@ export const WorkflowLibrarySideNav = () => {
{t('workflows.recentlyOpened')}
+ {t('workflows.sharedWorkflows')}
diff --git a/invokeai/frontend/web/src/features/nodes/components/sidePanel/workflow/WorkflowLibrary/WorkflowList.tsx b/invokeai/frontend/web/src/features/nodes/components/sidePanel/workflow/WorkflowLibrary/WorkflowList.tsx
index 79dff535b05..e6605d2076a 100644
--- a/invokeai/frontend/web/src/features/nodes/components/sidePanel/workflow/WorkflowLibrary/WorkflowList.tsx
+++ b/invokeai/frontend/web/src/features/nodes/components/sidePanel/workflow/WorkflowLibrary/WorkflowList.tsx
@@ -32,6 +32,8 @@ const getCategories = (view: WorkflowLibraryView): WorkflowCategory[] => {
return ['user', 'default'];
case 'yours':
return ['user'];
+ case 'shared':
+ return ['user'];
default:
assert>(false);
}
@@ -44,6 +46,13 @@ const getHasBeenOpened = (view: WorkflowLibraryView): boolean | undefined => {
return undefined;
};
+const getIsPublic = (view: WorkflowLibraryView): boolean | undefined => {
+ if (view === 'shared') {
+ return true;
+ }
+ return undefined;
+};
+
const useInfiniteQueryAry = () => {
const orderBy = useAppSelector(selectWorkflowLibraryOrderBy);
const direction = useAppSelector(selectWorkflowLibraryDirection);
@@ -62,6 +71,7 @@ const useInfiniteQueryAry = () => {
query: debouncedSearchTerm,
tags: view === 'defaults' || view === 'yours' ? selectedTags : [],
has_been_opened: getHasBeenOpened(view),
+ is_public: getIsPublic(view),
} satisfies Parameters[0];
}, [orderBy, direction, view, debouncedSearchTerm, selectedTags]);
diff --git a/invokeai/frontend/web/src/features/nodes/components/sidePanel/workflow/WorkflowLibrary/WorkflowListItem.tsx b/invokeai/frontend/web/src/features/nodes/components/sidePanel/workflow/WorkflowLibrary/WorkflowListItem.tsx
index a1767765c93..a184f04039a 100644
--- a/invokeai/frontend/web/src/features/nodes/components/sidePanel/workflow/WorkflowLibrary/WorkflowListItem.tsx
+++ b/invokeai/frontend/web/src/features/nodes/components/sidePanel/workflow/WorkflowLibrary/WorkflowListItem.tsx
@@ -1,13 +1,15 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
-import { Badge, Flex, Icon, Image, Spacer, Text } from '@invoke-ai/ui-library';
+import { Badge, Flex, Icon, Image, Spacer, Switch, Text, Tooltip } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
+import { selectCurrentUser } from 'features/auth/store/authSlice';
import { selectWorkflowId } from 'features/nodes/store/selectors';
import { workflowModeChanged } from 'features/nodes/store/workflowLibrarySlice';
import { useLoadWorkflowWithDialog } from 'features/workflowLibrary/components/LoadWorkflowConfirmationAlertDialog';
import InvokeLogo from 'public/assets/images/invoke-symbol-wht-lrg.svg';
-import { memo, useCallback, useMemo } from 'react';
+import { type ChangeEvent, memo, type MouseEvent, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiImage } from 'react-icons/pi';
+import { useUpdateWorkflowIsPublicMutation } from 'services/api/endpoints/workflows';
import type { WorkflowRecordListItemWithThumbnailDTO } from 'services/api/types';
import { DeleteWorkflow } from './WorkflowLibraryListItemActions/DeleteWorkflow';
@@ -33,12 +35,21 @@ export const WorkflowListItem = memo(({ workflow }: { workflow: WorkflowRecordLi
const { t } = useTranslation();
const dispatch = useAppDispatch();
const workflowId = useAppSelector(selectWorkflowId);
+ const currentUser = useAppSelector(selectCurrentUser);
const loadWorkflowWithDialog = useLoadWorkflowWithDialog();
const isActive = useMemo(() => {
return workflowId === workflow.workflow_id;
}, [workflowId, workflow.workflow_id]);
+ const isOwner = useMemo(() => {
+ return currentUser !== null && workflow.user_id === currentUser.user_id;
+ }, [currentUser, workflow.user_id]);
+
+ const canEditOrDelete = useMemo(() => {
+ return isOwner || (currentUser?.is_admin ?? false);
+ }, [isOwner, currentUser]);
+
const tags = useMemo(() => {
if (!workflow.tags) {
return [];
@@ -102,6 +113,18 @@ export const WorkflowListItem = memo(({ workflow }: { workflow: WorkflowRecordLi
{t('workflows.opened')}
)}
+ {workflow.is_public && workflow.category !== 'default' && (
+
+ {t('workflows.shared')}
+
+ )}
{workflow.category === 'default' && (
)}
+ {isOwner && }
{workflow.category === 'default' && }
{workflow.category !== 'default' && (
<>
-
+ {canEditOrDelete && }
-
+ {canEditOrDelete && }
>
)}
@@ -152,6 +176,35 @@ export const WorkflowListItem = memo(({ workflow }: { workflow: WorkflowRecordLi
});
WorkflowListItem.displayName = 'WorkflowListItem';
+const ShareWorkflowToggle = memo(({ workflow }: { workflow: WorkflowRecordListItemWithThumbnailDTO }) => {
+ const { t } = useTranslation();
+ const [updateIsPublic, { isLoading }] = useUpdateWorkflowIsPublicMutation();
+
+ const handleChange = useCallback(
+ (e: ChangeEvent) => {
+ e.stopPropagation();
+ updateIsPublic({ workflow_id: workflow.workflow_id, is_public: e.target.checked });
+ },
+ [updateIsPublic, workflow.workflow_id]
+ );
+
+ const handleClick = useCallback((e: MouseEvent) => {
+ e.stopPropagation();
+ }, []);
+
+ return (
+
+
+
+ {t('workflows.shared')}
+
+
+
+
+ );
+});
+ShareWorkflowToggle.displayName = 'ShareWorkflowToggle';
+
const UserThumbnailFallback = memo(() => {
return (
;
const isOrderBy = (v: unknown): v is OrderBy => zOrderBy.safeParse(v).success;
@@ -32,6 +32,7 @@ export const WorkflowSortControl = () => {
created_at: t('workflows.created'),
updated_at: t('workflows.updated'),
name: t('workflows.name'),
+ is_public: t('workflows.shared'),
}),
[t]
);
diff --git a/invokeai/frontend/web/src/features/nodes/store/workflowLibrarySlice.ts b/invokeai/frontend/web/src/features/nodes/store/workflowLibrarySlice.ts
index ee85a03c18f..1d5d8554aeb 100644
--- a/invokeai/frontend/web/src/features/nodes/store/workflowLibrarySlice.ts
+++ b/invokeai/frontend/web/src/features/nodes/store/workflowLibrarySlice.ts
@@ -11,7 +11,7 @@ import {
} from 'services/api/types';
import z from 'zod';
-const zWorkflowLibraryView = z.enum(['recent', 'yours', 'defaults']);
+const zWorkflowLibraryView = z.enum(['recent', 'yours', 'shared', 'defaults']);
export type WorkflowLibraryView = z.infer;
const zWorkflowLibraryState = z.object({
@@ -55,6 +55,9 @@ const slice = createSlice({
if (action.payload === 'recent') {
state.orderBy = 'opened_at';
state.direction = 'DESC';
+ } else if (action.payload === 'shared') {
+ state.orderBy = 'name';
+ state.direction = 'ASC';
}
},
workflowLibraryTagToggled: (state, action: PayloadAction) => {
@@ -121,5 +124,11 @@ export const WORKFLOW_LIBRARY_TAG_CATEGORIES: WorkflowTagCategory[] = [
];
export const WORKFLOW_LIBRARY_TAGS = WORKFLOW_LIBRARY_TAG_CATEGORIES.flatMap(({ tags }) => tags);
-type WorkflowSortOption = 'opened_at' | 'created_at' | 'updated_at' | 'name';
-export const WORKFLOW_LIBRARY_SORT_OPTIONS: WorkflowSortOption[] = ['opened_at', 'created_at', 'updated_at', 'name'];
+type WorkflowSortOption = 'opened_at' | 'created_at' | 'updated_at' | 'name' | 'is_public';
+export const WORKFLOW_LIBRARY_SORT_OPTIONS: WorkflowSortOption[] = [
+ 'opened_at',
+ 'created_at',
+ 'updated_at',
+ 'name',
+ 'is_public',
+];
diff --git a/invokeai/frontend/web/src/features/parameters/components/ModelPicker.tsx b/invokeai/frontend/web/src/features/parameters/components/ModelPicker.tsx
index 452c490af19..7f92df04b10 100644
--- a/invokeai/frontend/web/src/features/parameters/components/ModelPicker.tsx
+++ b/invokeai/frontend/web/src/features/parameters/components/ModelPicker.tsx
@@ -3,6 +3,7 @@ import {
Button,
Flex,
Icon,
+ Link,
Popover,
PopoverArrow,
PopoverBody,
@@ -20,6 +21,7 @@ import { buildGroup, getRegex, isGroup, Picker, usePickerContext } from 'common/
import { useDisclosure } from 'common/hooks/useBoolean';
import { typedMemo } from 'common/util/typedMemo';
import { uniq } from 'es-toolkit/compat';
+import { selectCurrentUser } from 'features/auth/store/authSlice';
import { selectLoRAsSlice } from 'features/controlLayers/store/lorasSlice';
import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
import { MODEL_BASE_TO_COLOR, MODEL_BASE_TO_LONG_NAME, MODEL_BASE_TO_SHORT_NAME } from 'features/modelManagerV2/models';
@@ -32,6 +34,7 @@ import { filesize } from 'filesize';
import { memo, useCallback, useMemo, useRef } from 'react';
import { Trans, useTranslation } from 'react-i18next';
import { PiCaretDownBold, PiLinkSimple } from 'react-icons/pi';
+import { useGetSetupStatusQuery } from 'services/api/endpoints/auth';
import { useGetRelatedModelIdsBatchQuery } from 'services/api/endpoints/modelRelationships';
import type { AnyModelConfig } from 'services/api/types';
@@ -82,6 +85,32 @@ const components = {
const NoOptionsFallback = memo(({ noOptionsText }: { noOptionsText?: string }) => {
const { t } = useTranslation();
+ const { data: setupStatus } = useGetSetupStatusQuery();
+ const user = useAppSelector(selectCurrentUser);
+
+ const isMultiuser = setupStatus?.multiuser_enabled ?? false;
+ const isAdmin = !isMultiuser || (user?.is_admin ?? false);
+ const adminEmail = setupStatus?.admin_email ?? null;
+
+ if (!isAdmin) {
+ const AdminEmailLink = adminEmail ? (
+
+ {adminEmail}
+
+ ) : (
+
+ your administrator
+
+ );
+
+ return (
+
+
+
+
+
+ );
+ }
return (
diff --git a/invokeai/frontend/web/src/features/queue/components/InvokeButtonTooltip/InvokeButtonTooltip.tsx b/invokeai/frontend/web/src/features/queue/components/InvokeButtonTooltip/InvokeButtonTooltip.tsx
index 9f1d004ba87..61553910e25 100644
--- a/invokeai/frontend/web/src/features/queue/components/InvokeButtonTooltip/InvokeButtonTooltip.tsx
+++ b/invokeai/frontend/web/src/features/queue/components/InvokeButtonTooltip/InvokeButtonTooltip.tsx
@@ -17,6 +17,8 @@ import type { PropsWithChildren } from 'react';
import { memo, useEffect, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { enqueueMutationFixedCacheKeyOptions, useEnqueueBatchMutation } from 'services/api/endpoints/queue';
+import { useAutoAddBoard } from 'services/api/hooks/useAutoAddBoard';
+import { useBoardAccess } from 'services/api/hooks/useBoardAccess';
import { useBoardName } from 'services/api/hooks/useBoardName';
type Props = TooltipProps & {
@@ -53,19 +55,25 @@ TooltipContent.displayName = 'TooltipContent';
const CanvasTabTooltipContent = memo(({ prepend = false }: { prepend?: boolean }) => {
const isReady = useStore($isReadyToEnqueue);
const reasons = useStore($reasonsWhyCannotEnqueue);
+ const autoAddBoard = useAutoAddBoard();
+ const { canWriteImages } = useBoardAccess(autoAddBoard);
return (
-
+
- {reasons.length > 0 && (
+ {(reasons.length > 0 || !canWriteImages) && (
<>
-
+
+ >
+ )}
+ {canWriteImages && (
+ <>
+
+
>
)}
-
-
);
});
@@ -74,15 +82,17 @@ CanvasTabTooltipContent.displayName = 'CanvasTabTooltipContent';
const UpscaleTabTooltipContent = memo(({ prepend = false }: { prepend?: boolean }) => {
const isReady = useStore($isReadyToEnqueue);
const reasons = useStore($reasonsWhyCannotEnqueue);
+ const autoAddBoard = useAutoAddBoard();
+ const { canWriteImages } = useBoardAccess(autoAddBoard);
return (
-
+
- {reasons.length > 0 && (
+ {(reasons.length > 0 || !canWriteImages) && (
<>
-
+
>
)}
@@ -195,12 +205,23 @@ const IsReadyText = memo(({ isReady, prepend }: { isReady: boolean; prepend: boo
});
IsReadyText.displayName = 'IsReadyText';
-const ReasonsList = memo(({ reasons }: { reasons: Reason[] }) => {
+const ReasonsList = memo(({ reasons, canWriteImages = true }: { reasons: Reason[]; canWriteImages?: boolean }) => {
+ const { t } = useTranslation();
+ const autoAddBoardId = useAppSelector(selectAutoAddBoardId);
+ const autoAddBoardName = useBoardName(autoAddBoardId);
+
return (
{reasons.map((reason, i) => (
))}
+ {!canWriteImages && (
+
+
+ {t('parameters.invoke.boardNotWritable', { boardName: autoAddBoardName || autoAddBoardId })}
+
+
+ )}
);
});
diff --git a/invokeai/frontend/web/src/features/queue/components/InvokeQueueBackButton.tsx b/invokeai/frontend/web/src/features/queue/components/InvokeQueueBackButton.tsx
index b175e4d8b09..a363d159e1d 100644
--- a/invokeai/frontend/web/src/features/queue/components/InvokeQueueBackButton.tsx
+++ b/invokeai/frontend/web/src/features/queue/components/InvokeQueueBackButton.tsx
@@ -5,6 +5,8 @@ import { QueueIterationsNumberInput } from 'features/queue/components/QueueItera
import { useInvoke } from 'features/queue/hooks/useInvoke';
import { memo } from 'react';
import { PiLightningFill, PiSparkleFill } from 'react-icons/pi';
+import { useAutoAddBoard } from 'services/api/hooks/useAutoAddBoard';
+import { useBoardAccess } from 'services/api/hooks/useBoardAccess';
import { InvokeButtonTooltip } from './InvokeButtonTooltip/InvokeButtonTooltip';
@@ -14,6 +16,8 @@ export const InvokeButton = memo(() => {
const queue = useInvoke();
const shift = useShiftModifier();
const isLoadingDynamicPrompts = useAppSelector(selectDynamicPromptsIsLoading);
+ const autoAddBoard = useAutoAddBoard();
+ const { canWriteImages } = useBoardAccess(autoAddBoard);
return (
@@ -23,7 +27,7 @@ export const InvokeButton = memo(() => {
onClick={shift ? queue.enqueueFront : queue.enqueueBack}
isLoading={queue.isLoading || isLoadingDynamicPrompts}
loadingText={invoke}
- isDisabled={queue.isDisabled}
+ isDisabled={queue.isDisabled || !canWriteImages}
rightIcon={shift ? : }
variant="solid"
colorScheme="invokeYellow"
diff --git a/invokeai/frontend/web/src/features/queue/components/QueueCountBadge.tsx b/invokeai/frontend/web/src/features/queue/components/QueueCountBadge.tsx
index 3417488b09e..e8636466066 100644
--- a/invokeai/frontend/web/src/features/queue/components/QueueCountBadge.tsx
+++ b/invokeai/frontend/web/src/features/queue/components/QueueCountBadge.tsx
@@ -1,6 +1,4 @@
import { Badge, Portal } from '@invoke-ai/ui-library';
-import { useAppSelector } from 'app/store/storeHooks';
-import { selectIsAuthenticated } from 'features/auth/store/authSlice';
import type { RefObject } from 'react';
import { memo, useEffect, useMemo, useState } from 'react';
import { useGetQueueStatusQuery } from 'services/api/endpoints/queue';
@@ -13,53 +11,35 @@ type Props = {
type SessionQueueStatus = components['schemas']['SessionQueueStatus'];
/**
- * Determines if user-specific queue counts are available.
- */
-const hasUserCounts = (queueData: SessionQueueStatus): boolean => {
- return (
- queueData.user_pending !== undefined &&
- queueData.user_pending !== null &&
- queueData.user_in_progress !== undefined &&
- queueData.user_in_progress !== null
- );
-};
-
-/**
- * Calculates the appropriate badge text based on queue status and authentication state.
+ * Calculates the appropriate badge text based on queue status.
* Returns null if badge should be hidden.
+ *
+ * In multiuser mode, the backend already scopes counts to the current user for non-admins,
+ * so pending + in_progress reflects the user's own queue items.
*/
-const getBadgeText = (queueData: SessionQueueStatus | undefined, isAuthenticated: boolean): string | null => {
+const getBadgeText = (queueData: SessionQueueStatus | undefined): string | null => {
if (!queueData) {
return null;
}
const totalPending = queueData.pending + queueData.in_progress;
- // Hide badge if there are no pending jobs
if (totalPending === 0) {
return null;
}
- // In multiuser mode (authenticated user), show "X/Y" format where X is user's jobs and Y is total jobs
- if (isAuthenticated && hasUserCounts(queueData)) {
- const userPending = queueData.user_pending! + queueData.user_in_progress!;
- return `${userPending}/${totalPending}`;
- }
-
- // In single-user mode or when user counts aren't available, show total count only
return totalPending.toString();
};
export const QueueCountBadge = memo(({ targetRef }: Props) => {
const [badgePos, setBadgePos] = useState<{ x: string; y: string } | null>(null);
- const isAuthenticated = useAppSelector(selectIsAuthenticated);
const { queueData } = useGetQueueStatusQuery(undefined, {
selectFromResult: (res) => ({
queueData: res.data?.queue,
}),
});
- const badgeText = useMemo(() => getBadgeText(queueData, isAuthenticated), [queueData, isAuthenticated]);
+ const badgeText = useMemo(() => getBadgeText(queueData), [queueData]);
useEffect(() => {
if (!targetRef.current) {
diff --git a/invokeai/frontend/web/src/features/settingsAccordions/components/UpscaleSettingsAccordion/UpscaleWarning.tsx b/invokeai/frontend/web/src/features/settingsAccordions/components/UpscaleSettingsAccordion/UpscaleWarning.tsx
index 7d0a7ee2def..ff19e7ebb31 100644
--- a/invokeai/frontend/web/src/features/settingsAccordions/components/UpscaleSettingsAccordion/UpscaleWarning.tsx
+++ b/invokeai/frontend/web/src/features/settingsAccordions/components/UpscaleSettingsAccordion/UpscaleWarning.tsx
@@ -1,5 +1,6 @@
-import { Button, Flex, ListItem, Text, UnorderedList } from '@invoke-ai/ui-library';
+import { Button, Flex, Link, ListItem, Text, UnorderedList } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
+import { selectCurrentUser } from 'features/auth/store/authSlice';
import { selectModel } from 'features/controlLayers/store/paramsSlice';
import { setInstallModelsTabByName } from 'features/modelManagerV2/store/installModelsStore';
import {
@@ -10,6 +11,7 @@ import {
import { navigationApi } from 'features/ui/layouts/navigation-api';
import { useCallback, useEffect, useMemo } from 'react';
import { Trans, useTranslation } from 'react-i18next';
+import { useGetSetupStatusQuery } from 'services/api/endpoints/auth';
import { useControlNetModels } from 'services/api/hooks/modelsByType';
export const UpscaleWarning = () => {
@@ -19,6 +21,12 @@ export const UpscaleWarning = () => {
const tileControlnetModel = useAppSelector(selectTileControlNetModel);
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useControlNetModels();
+ const { data: setupStatus } = useGetSetupStatusQuery();
+ const user = useAppSelector(selectCurrentUser);
+
+ const isMultiuser = setupStatus?.multiuser_enabled ?? false;
+ const isAdmin = !isMultiuser || (user?.is_admin ?? false);
+ const adminEmail = setupStatus?.admin_email ?? null;
useEffect(() => {
const validModel = modelConfigs.find((cnetModel) => {
@@ -59,19 +67,33 @@ export const UpscaleWarning = () => {
return null;
}
+ const AdminEmailLink = adminEmail ? (
+
+ {adminEmail}
+
+ ) : (
+
+ your administrator
+
+ );
+
return (
{!isBaseModelCompatible && {t('upscaling.incompatibleBaseModelDesc')}}
{warnings.length > 0 && (
-
- ),
- }}
- />
+ {isAdmin ? (
+
+ ),
+ }}
+ />
+ ) : (
+
+ )}
)}
{warnings.length > 0 && (
diff --git a/invokeai/frontend/web/src/features/ui/components/FloatingLeftPanelButtons.tsx b/invokeai/frontend/web/src/features/ui/components/FloatingLeftPanelButtons.tsx
index 81e8930e401..c9620d84ac9 100644
--- a/invokeai/frontend/web/src/features/ui/components/FloatingLeftPanelButtons.tsx
+++ b/invokeai/frontend/web/src/features/ui/components/FloatingLeftPanelButtons.tsx
@@ -17,6 +17,8 @@ import {
PiXCircle,
} from 'react-icons/pi';
import { useGetQueueStatusQuery } from 'services/api/endpoints/queue';
+import { useAutoAddBoard } from 'services/api/hooks/useAutoAddBoard';
+import { useBoardAccess } from 'services/api/hooks/useBoardAccess';
export const FloatingLeftPanelButtons = memo(() => {
return (
@@ -71,6 +73,8 @@ const InvokeIconButton = memo(() => {
const { t } = useTranslation();
const queue = useInvoke();
const shift = useShiftModifier();
+ const autoAddBoard = useAutoAddBoard();
+ const { canWriteImages } = useBoardAccess(autoAddBoard);
return (
@@ -78,7 +82,7 @@ const InvokeIconButton = memo(() => {
aria-label={t('queue.queueBack')}
onClick={shift ? queue.enqueueFront : queue.enqueueBack}
isLoading={queue.isLoading}
- isDisabled={queue.isDisabled}
+ isDisabled={queue.isDisabled || !canWriteImages}
icon={}
colorScheme="invokeYellow"
flexGrow={1}
diff --git a/invokeai/frontend/web/src/features/ui/layouts/WorkflowsLaunchpadPanel.tsx b/invokeai/frontend/web/src/features/ui/layouts/WorkflowsLaunchpadPanel.tsx
index d432f3193ef..b0d087528ad 100644
--- a/invokeai/frontend/web/src/features/ui/layouts/WorkflowsLaunchpadPanel.tsx
+++ b/invokeai/frontend/web/src/features/ui/layouts/WorkflowsLaunchpadPanel.tsx
@@ -6,6 +6,7 @@ import { memo, useCallback } from 'react';
import { useDropzone } from 'react-dropzone';
import { useTranslation } from 'react-i18next';
import { PiFilePlusBold, PiFolderOpenBold, PiUploadBold } from 'react-icons/pi';
+import { useGetSetupStatusQuery } from 'services/api/endpoints/auth';
import { LaunchpadButton } from './LaunchpadButton';
import { LaunchpadContainer } from './LaunchpadContainer';
@@ -14,6 +15,9 @@ export const WorkflowsLaunchpadPanel = memo(() => {
const { t } = useTranslation();
const workflowLibraryModal = useWorkflowLibraryModal();
const newWorkflow = useNewWorkflow();
+ const { data: setupStatus } = useGetSetupStatusQuery();
+
+ const isMultiuser = setupStatus?.multiuser_enabled ?? false;
const handleBrowseTemplates = useCallback(() => {
workflowLibraryModal.open();
@@ -45,11 +49,15 @@ export const WorkflowsLaunchpadPanel = memo(() => {
multiple: false,
});
+ const descriptionKey = isMultiuser
+ ? 'ui.launchpad.workflows.descriptionMultiuser'
+ : 'ui.launchpad.workflows.description';
+
return (
{/* Description */}
- {t('ui.launchpad.workflows.description')}
+ {t(descriptionKey)}
diff --git a/invokeai/frontend/web/src/features/workflowLibrary/components/SaveWorkflowAsDialog.tsx b/invokeai/frontend/web/src/features/workflowLibrary/components/SaveWorkflowAsDialog.tsx
index 6dab1e3f04d..1637cf56781 100644
--- a/invokeai/frontend/web/src/features/workflowLibrary/components/SaveWorkflowAsDialog.tsx
+++ b/invokeai/frontend/web/src/features/workflowLibrary/components/SaveWorkflowAsDialog.tsx
@@ -5,6 +5,7 @@ import {
AlertDialogFooter,
AlertDialogHeader,
Button,
+ Checkbox,
Flex,
FormControl,
FormLabel,
@@ -19,6 +20,7 @@ import { t } from 'i18next';
import { atom, computed } from 'nanostores';
import type { ChangeEvent, RefObject } from 'react';
import { memo, useCallback, useRef, useState } from 'react';
+import { useUpdateWorkflowIsPublicMutation } from 'services/api/endpoints/workflows';
import { assert } from 'tsafe';
/**
@@ -87,8 +89,10 @@ const Content = memo(({ workflow, cancelRef }: { workflow: WorkflowV3; cancelRef
}
return '';
});
+ const [isPublic, setIsPublic] = useState(false);
const { createNewWorkflow } = useCreateLibraryWorkflow();
+ const [updateIsPublic] = useUpdateWorkflowIsPublicMutation();
const inputRef = useRef(null);
@@ -96,6 +100,10 @@ const Content = memo(({ workflow, cancelRef }: { workflow: WorkflowV3; cancelRef
setName(e.target.value);
}, []);
+ const onChangeIsPublic = useCallback((e: ChangeEvent) => {
+ setIsPublic(e.target.checked);
+ }, []);
+
const onClose = useCallback(() => {
$workflowToSave.set(null);
}, []);
@@ -110,10 +118,19 @@ const Content = memo(({ workflow, cancelRef }: { workflow: WorkflowV3; cancelRef
await createNewWorkflow({
workflow,
- onSuccess: onClose,
+ onSuccess: async (workflowId?: string) => {
+ if (isPublic && workflowId) {
+ try {
+ await updateIsPublic({ workflow_id: workflowId, is_public: true }).unwrap();
+ } catch {
+ // Sharing failed silently - workflow was saved, just not shared
+ }
+ }
+ onClose();
+ },
onError: onClose,
});
- }, [workflow, name, createNewWorkflow, onClose]);
+ }, [workflow, name, isPublic, createNewWorkflow, updateIsPublic, onClose]);
return (
@@ -126,6 +143,10 @@ const Content = memo(({ workflow, cancelRef }: { workflow: WorkflowV3; cancelRef
{t('workflows.workflowName')}
+
+
+ {t('workflows.shareWorkflow')}
+
diff --git a/invokeai/frontend/web/src/features/workflowLibrary/components/WorkflowLibraryMenu/SaveWorkflowMenuItem.tsx b/invokeai/frontend/web/src/features/workflowLibrary/components/WorkflowLibraryMenu/SaveWorkflowMenuItem.tsx
index 6f5acc431ed..e683cfdbefd 100644
--- a/invokeai/frontend/web/src/features/workflowLibrary/components/WorkflowLibraryMenu/SaveWorkflowMenuItem.tsx
+++ b/invokeai/frontend/web/src/features/workflowLibrary/components/WorkflowLibraryMenu/SaveWorkflowMenuItem.tsx
@@ -1,5 +1,6 @@
import { MenuItem } from '@invoke-ai/ui-library';
import { useDoesWorkflowHaveUnsavedChanges } from 'features/nodes/components/sidePanel/workflow/IsolatedWorkflowBuilderWatcher';
+import { useIsCurrentWorkflowOwner } from 'features/workflowLibrary/hooks/useIsCurrentWorkflowOwner';
import { useSaveOrSaveAsWorkflow } from 'features/workflowLibrary/hooks/useSaveOrSaveAsWorkflow';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
@@ -9,11 +10,12 @@ const SaveWorkflowMenuItem = () => {
const { t } = useTranslation();
const saveOrSaveAsWorkflow = useSaveOrSaveAsWorkflow();
const doesWorkflowHaveUnsavedChanges = useDoesWorkflowHaveUnsavedChanges();
+ const isCurrentWorkflowOwner = useIsCurrentWorkflowOwner();
return (
}
onClick={saveOrSaveAsWorkflow}
>
diff --git a/invokeai/frontend/web/src/features/workflowLibrary/hooks/useCreateNewWorkflow.ts b/invokeai/frontend/web/src/features/workflowLibrary/hooks/useCreateNewWorkflow.ts
index 543283c779c..37fe48726e0 100644
--- a/invokeai/frontend/web/src/features/workflowLibrary/hooks/useCreateNewWorkflow.ts
+++ b/invokeai/frontend/web/src/features/workflowLibrary/hooks/useCreateNewWorkflow.ts
@@ -29,7 +29,7 @@ export const isDraftWorkflow = (workflow: WorkflowV3): workflow is DraftWorkflow
type CreateLibraryWorkflowArg = {
workflow: DraftWorkflow;
- onSuccess?: () => void;
+ onSuccess?: (workflowId?: string) => void;
onError?: () => void;
};
@@ -70,7 +70,7 @@ export const useCreateLibraryWorkflow = (): CreateLibraryWorkflowReturn => {
// When a workflow is saved, the form field initial values are updated to the current form field values
dispatch(formFieldInitialValuesChanged({ formFieldInitialValues: getFormFieldInitialValues() }));
updateOpenedAt({ workflow_id: id });
- onSuccess?.();
+ onSuccess?.(id);
toast.update(toastRef.current, {
title: t('workflows.workflowSaved'),
status: 'success',
diff --git a/invokeai/frontend/web/src/features/workflowLibrary/hooks/useIsCurrentWorkflowOwner.ts b/invokeai/frontend/web/src/features/workflowLibrary/hooks/useIsCurrentWorkflowOwner.ts
new file mode 100644
index 00000000000..5183c9050b7
--- /dev/null
+++ b/invokeai/frontend/web/src/features/workflowLibrary/hooks/useIsCurrentWorkflowOwner.ts
@@ -0,0 +1,48 @@
+import { skipToken } from '@reduxjs/toolkit/query';
+import { useAppSelector } from 'app/store/storeHooks';
+import { selectCurrentUser } from 'features/auth/store/authSlice';
+import { selectWorkflowId } from 'features/nodes/store/selectors';
+import { useMemo } from 'react';
+import { useGetSetupStatusQuery } from 'services/api/endpoints/auth';
+import { useGetWorkflowQuery } from 'services/api/endpoints/workflows';
+
+/**
+ * Returns true if the current user can save the currently-loaded workflow directly (not as a copy).
+ *
+ * In single-user mode, this always returns true.
+ * In multiuser mode, returns true when:
+ * - The workflow has no ID (new, unsaved workflow — will open Save As)
+ * - The current user is the owner of the workflow
+ * - The current user is an admin
+ */
+export const useIsCurrentWorkflowOwner = (): boolean => {
+ const workflowId = useAppSelector(selectWorkflowId);
+ const currentUser = useAppSelector(selectCurrentUser);
+ const { data: setupStatus } = useGetSetupStatusQuery();
+ const { data: workflowData } = useGetWorkflowQuery(workflowId ?? skipToken);
+
+ return useMemo(() => {
+ // In single-user mode there is no concept of ownership, so saving is always allowed.
+ if (!setupStatus?.multiuser_enabled) {
+ return true;
+ }
+
+ // No authenticated user — be permissive.
+ if (!currentUser) {
+ return true;
+ }
+
+ // No workflow ID means this is a new/unsaved workflow. Clicking "Save" will open the
+ // Save As dialog, so we should not block it.
+ if (!workflowId) {
+ return true;
+ }
+
+ // API data not yet available — be permissive to avoid incorrect disabling during loading.
+ if (!workflowData) {
+ return true;
+ }
+
+ return workflowData.user_id === currentUser.user_id || currentUser.is_admin;
+ }, [setupStatus?.multiuser_enabled, workflowId, workflowData, currentUser]);
+};
diff --git a/invokeai/frontend/web/src/services/api/endpoints/auth.ts b/invokeai/frontend/web/src/services/api/endpoints/auth.ts
index 419e7c730ce..ae7bfa7426f 100644
--- a/invokeai/frontend/web/src/services/api/endpoints/auth.ts
+++ b/invokeai/frontend/web/src/services/api/endpoints/auth.ts
@@ -34,6 +34,7 @@ type SetupStatusResponse = {
setup_required: boolean;
multiuser_enabled: boolean;
strict_password_checking: boolean;
+ admin_email: string | null;
};
export type UserDTO = components['schemas']['UserDTO'];
diff --git a/invokeai/frontend/web/src/services/api/endpoints/workflows.ts b/invokeai/frontend/web/src/services/api/endpoints/workflows.ts
index f58d3281a26..176546c90fd 100644
--- a/invokeai/frontend/web/src/services/api/endpoints/workflows.ts
+++ b/invokeai/frontend/web/src/services/api/endpoints/workflows.ts
@@ -157,6 +157,21 @@ export const workflowsApi = api.injectEndpoints({
}),
invalidatesTags: (result, error, workflow_id) => [{ type: 'Workflow', id: workflow_id }],
}),
+ updateWorkflowIsPublic: build.mutation<
+ paths['/api/v1/workflows/i/{workflow_id}/is_public']['patch']['responses']['200']['content']['application/json'],
+ { workflow_id: string; is_public: boolean }
+ >({
+ query: ({ workflow_id, is_public }) => ({
+ url: buildWorkflowsUrl(`i/${workflow_id}/is_public`),
+ method: 'PATCH',
+ body: { is_public },
+ }),
+ invalidatesTags: (result, error, { workflow_id }) => [
+ { type: 'Workflow', id: workflow_id },
+ { type: 'Workflow', id: LIST_TAG },
+ 'WorkflowCategoryCounts',
+ ],
+ }),
}),
});
@@ -173,4 +188,5 @@ export const {
useListWorkflowsInfiniteInfiniteQuery,
useSetWorkflowThumbnailMutation,
useDeleteWorkflowThumbnailMutation,
+ useUpdateWorkflowIsPublicMutation,
} = workflowsApi;
diff --git a/invokeai/frontend/web/src/services/api/hooks/useAutoAddBoard.ts b/invokeai/frontend/web/src/services/api/hooks/useAutoAddBoard.ts
new file mode 100644
index 00000000000..1ae22270079
--- /dev/null
+++ b/invokeai/frontend/web/src/services/api/hooks/useAutoAddBoard.ts
@@ -0,0 +1,21 @@
+import { useAppSelector } from 'app/store/storeHooks';
+import { selectAutoAddBoardId } from 'features/gallery/store/gallerySelectors';
+import { useListAllBoardsQuery } from 'services/api/endpoints/boards';
+
+/**
+ * Returns the `BoardDTO` for the board currently configured as the auto-add
+ * destination, or `null` when it is set to "Uncategorized" (`boardId === 'none'`)
+ * or when the board list has not yet loaded.
+ */
+export const useAutoAddBoard = () => {
+ const autoAddBoardId = useAppSelector(selectAutoAddBoardId);
+ const { board } = useListAllBoardsQuery(
+ { include_archived: true },
+ {
+ selectFromResult: ({ data }) => ({
+ board: data?.find((b) => b.board_id === autoAddBoardId) ?? null,
+ }),
+ }
+ );
+ return board;
+};
diff --git a/invokeai/frontend/web/src/services/api/hooks/useBoardAccess.ts b/invokeai/frontend/web/src/services/api/hooks/useBoardAccess.ts
new file mode 100644
index 00000000000..9a222024255
--- /dev/null
+++ b/invokeai/frontend/web/src/services/api/hooks/useBoardAccess.ts
@@ -0,0 +1,32 @@
+import { useAppSelector } from 'app/store/storeHooks';
+import { selectCurrentUser } from 'features/auth/store/authSlice';
+import type { BoardDTO } from 'services/api/types';
+
+/**
+ * Returns permission flags for the given board based on the current user:
+ * - `canWriteImages`: can add / delete images in the board
+ * (owner or admin always; non-owner allowed only for public boards)
+ * - `canRenameBoard`: can rename the board (owner or admin only)
+ * - `canDeleteBoard`: can delete the board (owner or admin only)
+ *
+ * When `board` is null/undefined (e.g. "uncategorized"), all permissions are
+ * granted so that existing behaviour is preserved.
+ *
+ * When `currentUser` is null the app is running without authentication
+ * (single-user mode), so full access is granted unconditionally.
+ */
+export const useBoardAccess = (board: BoardDTO | null | undefined) => {
+ const currentUser = useAppSelector(selectCurrentUser);
+
+ if (!board) {
+ return { canWriteImages: true, canRenameBoard: true, canDeleteBoard: true };
+ }
+
+ const isOwnerOrAdmin = !currentUser || currentUser.is_admin || board.user_id === currentUser.user_id;
+
+ return {
+ canWriteImages: isOwnerOrAdmin || board.board_visibility === 'public',
+ canRenameBoard: isOwnerOrAdmin,
+ canDeleteBoard: isOwnerOrAdmin,
+ };
+};
diff --git a/invokeai/frontend/web/src/services/api/hooks/useSelectedBoard.ts b/invokeai/frontend/web/src/services/api/hooks/useSelectedBoard.ts
new file mode 100644
index 00000000000..40c6d77f37f
--- /dev/null
+++ b/invokeai/frontend/web/src/services/api/hooks/useSelectedBoard.ts
@@ -0,0 +1,21 @@
+import { useAppSelector } from 'app/store/storeHooks';
+import { selectSelectedBoardId } from 'features/gallery/store/gallerySelectors';
+import { useListAllBoardsQuery } from 'services/api/endpoints/boards';
+
+/**
+ * Returns the `BoardDTO` for the currently selected board, or `null` when the
+ * user is viewing "Uncategorized" (`boardId === 'none'`) or when the board list
+ * has not yet loaded.
+ */
+export const useSelectedBoard = () => {
+ const selectedBoardId = useAppSelector(selectSelectedBoardId);
+ const { board } = useListAllBoardsQuery(
+ { include_archived: true },
+ {
+ selectFromResult: ({ data }) => ({
+ board: data?.find((b) => b.board_id === selectedBoardId) ?? null,
+ }),
+ }
+ );
+ return board;
+};
diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts
index f7ef229bf31..4f10cf6c483 100644
--- a/invokeai/frontend/web/src/services/api/schema.ts
+++ b/invokeai/frontend/web/src/services/api/schema.ts
@@ -1042,14 +1042,14 @@ export type paths = {
};
/**
* Get Intermediates Count
- * @description Gets the count of intermediate images
+ * @description Gets the count of intermediate images. Non-admin users only see their own intermediates.
*/
get: operations["get_intermediates_count"];
put?: never;
post?: never;
/**
* Clear Intermediates
- * @description Clears all intermediates
+ * @description Clears all intermediates. Requires admin.
*/
delete: operations["clear_intermediates"];
options?: never;
@@ -1103,7 +1103,11 @@ export type paths = {
};
/**
* Get Image Full
- * @description Gets a full-resolution image file
+ * @description Gets a full-resolution image file.
+ *
+ * This endpoint is intentionally unauthenticated because browsers load images
+ * via
tags which cannot send Bearer tokens. Image names are UUIDs,
+ * providing security through unguessability.
*/
get: operations["get_image_full"];
put?: never;
@@ -1112,7 +1116,11 @@ export type paths = {
options?: never;
/**
* Get Image Full
- * @description Gets a full-resolution image file
+ * @description Gets a full-resolution image file.
+ *
+ * This endpoint is intentionally unauthenticated because browsers load images
+ * via
tags which cannot send Bearer tokens. Image names are UUIDs,
+ * providing security through unguessability.
*/
head: operations["get_image_full_head"];
patch?: never;
@@ -1127,7 +1135,11 @@ export type paths = {
};
/**
* Get Image Thumbnail
- * @description Gets a thumbnail image file
+ * @description Gets a thumbnail image file.
+ *
+ * This endpoint is intentionally unauthenticated because browsers load images
+ * via
tags which cannot send Bearer tokens. Image names are UUIDs,
+ * providing security through unguessability.
*/
get: operations["get_image_thumbnail"];
put?: never;
@@ -1187,7 +1199,7 @@ export type paths = {
post?: never;
/**
* Delete Uncategorized Images
- * @description Deletes all images that are uncategorized
+ * @description Deletes all uncategorized images owned by the current user (or all if admin)
*/
delete: operations["delete_uncategorized_images"];
options?: never;
@@ -1255,7 +1267,10 @@ export type paths = {
};
/**
* Get Bulk Download Item
- * @description Gets a bulk download zip file
+ * @description Gets a bulk download zip file.
+ *
+ * Requires authentication. The caller must be the user who initiated the
+ * download (tracked by the bulk download service) or an admin.
*/
get: operations["get_bulk_download_item"];
put?: never;
@@ -1727,7 +1742,7 @@ export type paths = {
};
/**
* Get Queue Item Ids
- * @description Gets all queue item ids that match the given parameters
+ * @description Gets all queue item ids that match the given parameters. Non-admin users only see their own items.
*/
get: operations["get_queue_item_ids"];
put?: never;
@@ -1987,7 +2002,7 @@ export type paths = {
};
/**
* Get Queue Status
- * @description Gets the status of the session queue
+ * @description Gets the status of the session queue. Non-admin users see only their own counts and cannot see current item details unless they own it.
*/
get: operations["get_queue_status"];
put?: never;
@@ -2007,7 +2022,7 @@ export type paths = {
};
/**
* Get Batch Status
- * @description Gets the status of the session queue
+ * @description Gets the status of a batch. Non-admin users only see their own batches.
*/
get: operations["get_batch_status"];
put?: never;
@@ -2071,7 +2086,7 @@ export type paths = {
};
/**
* Counts By Destination
- * @description Gets the counts of queue items by destination
+ * @description Gets the counts of queue items by destination. Non-admin users only see their own items.
*/
get: operations["counts_by_destination"];
put?: never;
@@ -2163,7 +2178,11 @@ export type paths = {
};
/**
* Get Workflow Thumbnail
- * @description Gets a workflow's thumbnail image
+ * @description Gets a workflow's thumbnail image.
+ *
+ * This endpoint is intentionally unauthenticated because browsers load images
+ * via
tags which cannot send Bearer tokens. Workflow IDs are UUIDs,
+ * providing security through unguessability.
*/
get: operations["get_workflow_thumbnail"];
/**
@@ -2182,6 +2201,26 @@ export type paths = {
patch?: never;
trace?: never;
};
+ "/api/v1/workflows/i/{workflow_id}/is_public": {
+ parameters: {
+ query?: never;
+ header?: never;
+ path?: never;
+ cookie?: never;
+ };
+ get?: never;
+ put?: never;
+ post?: never;
+ delete?: never;
+ options?: never;
+ head?: never;
+ /**
+ * Update Workflow Is Public
+ * @description Updates whether a workflow is shared publicly
+ */
+ patch: operations["update_workflow_is_public"];
+ trace?: never;
+ };
"/api/v1/workflows/tags": {
parameters: {
query?: never;
@@ -3397,6 +3436,12 @@ export type components = {
* @default null
*/
origin: string | null;
+ /**
+ * User Id
+ * @description The ID of the user who enqueued the batch
+ * @default system
+ */
+ user_id: string;
};
/** BatchStatus */
BatchStatus: {
@@ -3587,6 +3632,8 @@ export type components = {
* @description Whether or not the board is archived
*/
archived?: boolean | null;
+ /** @description The visibility of the board. */
+ board_visibility?: components["schemas"]["BoardVisibility"] | null;
};
/**
* BoardDTO
@@ -3633,6 +3680,11 @@ export type components = {
* @description Whether or not the board is archived.
*/
archived: boolean;
+ /**
+ * @description The visibility of the board.
+ * @default private
+ */
+ board_visibility?: components["schemas"]["BoardVisibility"];
/**
* Image Count
* @description The number of images in the board.
@@ -3666,6 +3718,12 @@ export type components = {
* @enum {string}
*/
BoardRecordOrderBy: "created_at" | "board_name";
+ /**
+ * BoardVisibility
+ * @description The visibility options for a board.
+ * @enum {string}
+ */
+ BoardVisibility: "private" | "shared" | "public";
/** Body_add_image_to_board */
Body_add_image_to_board: {
/**
@@ -3918,6 +3976,14 @@ export type components = {
/** @description The updated workflow */
workflow: components["schemas"]["Workflow"];
};
+ /** Body_update_workflow_is_public */
+ Body_update_workflow_is_public: {
+ /**
+ * Is Public
+ * @description Whether the workflow should be shared publicly
+ */
+ is_public: boolean;
+ };
/** Body_upload_image */
Body_upload_image: {
/**
@@ -4210,6 +4276,12 @@ export type components = {
* @description The name of the bulk image download item
*/
bulk_download_item_name: string;
+ /**
+ * User Id
+ * @description The ID of the user who initiated the download
+ * @default system
+ */
+ user_id: string;
};
/**
* BulkDownloadErrorEvent
@@ -4236,6 +4308,12 @@ export type components = {
* @description The name of the bulk image download item
*/
bulk_download_item_name: string;
+ /**
+ * User Id
+ * @description The ID of the user who initiated the download
+ * @default system
+ */
+ user_id: string;
/**
* Error
* @description The error message
@@ -4267,6 +4345,12 @@ export type components = {
* @description The name of the bulk image download item
*/
bulk_download_item_name: string;
+ /**
+ * User Id
+ * @description The ID of the user who initiated the download
+ * @default system
+ */
+ user_id: string;
};
/**
* BulkReidentifyModelsRequest
@@ -24494,6 +24578,11 @@ export type components = {
* @description The ID of the queue
*/
queue_id: string;
+ /**
+ * User Id
+ * @description The ID of the user whose recall parameters were updated
+ */
+ user_id: string;
/**
* Parameters
* @description The recall parameters that were updated
@@ -26094,16 +26183,6 @@ export type components = {
* @description Total number of queue items
*/
total: number;
- /**
- * User Pending
- * @description Number of queue items with status 'pending' for the current user
- */
- user_pending?: number | null;
- /**
- * User In Progress
- * @description Number of queue items with status 'in_progress' for the current user
- */
- user_in_progress?: number | null;
};
/**
* SetupRequest
@@ -26159,6 +26238,11 @@ export type components = {
* @description Whether strict password requirements are enforced
*/
strict_password_checking: boolean;
+ /**
+ * Admin Email
+ * @description Email of the first active admin user, if any
+ */
+ admin_email?: string | null;
};
/**
* Show Image
@@ -29426,6 +29510,16 @@ export type components = {
* @description The opened timestamp of the workflow.
*/
opened_at?: string | null;
+ /**
+ * User Id
+ * @description The id of the user who owns this workflow.
+ */
+ user_id: string;
+ /**
+ * Is Public
+ * @description Whether this workflow is shared with all users.
+ */
+ is_public: boolean;
/** @description The workflow. */
workflow: components["schemas"]["Workflow"];
};
@@ -29456,6 +29550,16 @@ export type components = {
* @description The opened timestamp of the workflow.
*/
opened_at?: string | null;
+ /**
+ * User Id
+ * @description The id of the user who owns this workflow.
+ */
+ user_id: string;
+ /**
+ * Is Public
+ * @description Whether this workflow is shared with all users.
+ */
+ is_public: boolean;
/**
* Description
* @description The description of the workflow.
@@ -29479,7 +29583,7 @@ export type components = {
* @description The order by options for workflow records
* @enum {string}
*/
- WorkflowRecordOrderBy: "created_at" | "updated_at" | "opened_at" | "name";
+ WorkflowRecordOrderBy: "created_at" | "updated_at" | "opened_at" | "name" | "is_public";
/** WorkflowRecordWithThumbnailDTO */
WorkflowRecordWithThumbnailDTO: {
/**
@@ -29507,6 +29611,16 @@ export type components = {
* @description The opened timestamp of the workflow.
*/
opened_at?: string | null;
+ /**
+ * User Id
+ * @description The id of the user who owns this workflow.
+ */
+ user_id: string;
+ /**
+ * Is Public
+ * @description Whether this workflow is shared with all users.
+ */
+ is_public: boolean;
/** @description The workflow. */
workflow: components["schemas"]["Workflow"];
/**
@@ -34637,6 +34751,8 @@ export interface operations {
query?: string | null;
/** @description Whether to include/exclude recent workflows */
has_been_opened?: boolean | null;
+ /** @description Filter by public/shared status */
+ is_public?: boolean | null;
};
header?: never;
path?: never;
@@ -34811,11 +34927,49 @@ export interface operations {
};
};
};
+ update_workflow_is_public: {
+ parameters: {
+ query?: never;
+ header?: never;
+ path: {
+ /** @description The workflow to update */
+ workflow_id: string;
+ };
+ cookie?: never;
+ };
+ requestBody: {
+ content: {
+ "application/json": components["schemas"]["Body_update_workflow_is_public"];
+ };
+ };
+ responses: {
+ /** @description Successful Response */
+ 200: {
+ headers: {
+ [name: string]: unknown;
+ };
+ content: {
+ "application/json": components["schemas"]["WorkflowRecordDTO"];
+ };
+ };
+ /** @description Validation Error */
+ 422: {
+ headers: {
+ [name: string]: unknown;
+ };
+ content: {
+ "application/json": components["schemas"]["HTTPValidationError"];
+ };
+ };
+ };
+ };
get_all_tags: {
parameters: {
query?: {
/** @description The categories to include */
categories?: components["schemas"]["WorkflowCategory"][] | null;
+ /** @description Filter by public/shared status */
+ is_public?: boolean | null;
};
header?: never;
path?: never;
@@ -34852,6 +35006,8 @@ export interface operations {
categories?: components["schemas"]["WorkflowCategory"][] | null;
/** @description Whether to include/exclude recent workflows */
has_been_opened?: boolean | null;
+ /** @description Filter by public/shared status */
+ is_public?: boolean | null;
};
header?: never;
path?: never;
@@ -34888,6 +35044,8 @@ export interface operations {
categories: components["schemas"]["WorkflowCategory"][];
/** @description Whether to include/exclude recent workflows */
has_been_opened?: boolean | null;
+ /** @description Filter by public/shared status */
+ is_public?: boolean | null;
};
header?: never;
path?: never;
diff --git a/invokeai/frontend/web/src/services/api/types.ts b/invokeai/frontend/web/src/services/api/types.ts
index 4752b97d065..f4561cad97f 100644
--- a/invokeai/frontend/web/src/services/api/types.ts
+++ b/invokeai/frontend/web/src/services/api/types.ts
@@ -352,7 +352,7 @@ export type ModelInstallStatus = S['InstallStatus'];
export type Graph = S['Graph'];
export type NonNullableGraph = SetRequired;
export type Batch = S['Batch'];
-export const zWorkflowRecordOrderBy = z.enum(['name', 'created_at', 'updated_at', 'opened_at']);
+export const zWorkflowRecordOrderBy = z.enum(['name', 'created_at', 'updated_at', 'opened_at', 'is_public']);
export type WorkflowRecordOrderBy = z.infer;
assert>();
diff --git a/invokeai/frontend/web/src/services/events/setEventListeners.tsx b/invokeai/frontend/web/src/services/events/setEventListeners.tsx
index 2e0ff2251eb..774acd3f934 100644
--- a/invokeai/frontend/web/src/services/events/setEventListeners.tsx
+++ b/invokeai/frontend/web/src/services/events/setEventListeners.tsx
@@ -1,4 +1,4 @@
-import { ExternalLink, Flex, Text } from '@invoke-ai/ui-library';
+import { Flex, Text } from '@invoke-ai/ui-library';
import { logger } from 'app/logging/logger';
import { socketConnected } from 'app/store/middleware/listenerMiddleware/listeners/socketConnected';
import type { AppStore } from 'app/store/store';
@@ -28,7 +28,7 @@ import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks
import { zNodeStatus } from 'features/nodes/types/invocation';
import { modelSelected } from 'features/parameters/store/actions';
import ErrorToastDescription, { getTitle } from 'features/toast/ErrorToastDescription';
-import { toast } from 'features/toast/toast';
+import { toast, toastApi } from 'features/toast/toast';
import { t } from 'i18next';
import { LRUCache } from 'lru-cache';
import { Trans } from 'react-i18next';
@@ -855,14 +855,61 @@ export const setEventListeners = ({ socket, store, setIsConnected }: SetEventLis
log.debug({ data }, 'Bulk gallery download ready');
const { bulk_download_item_name } = data;
- // TODO(psyche): This URL may break in in some environments (e.g. Nvidia workbench) but we need to test it first
+ // Dismiss the "preparing" toast (which uses a prefixed id to avoid the
+ // race condition where this socket event arrives before the Redux
+ // middleware processes the POST response).
+ toastApi.close(`preparing:${bulk_download_item_name}`);
+
+ // The GET endpoint requires authentication, so we use fetch() with the
+ // Authorization header rather than a plain link (which cannot
+ // carry headers). After fetching the blob, we create a temporary object
+ // URL and trigger the browser's save dialog programmatically.
const url = `/api/v1/images/download/${bulk_download_item_name}`;
+ const token = localStorage.getItem('auth_token');
+ const headers: Record = token ? { Authorization: `Bearer ${token}` } : {};
+
+ const handleDownload = () => {
+ fetch(url, { headers })
+ .then((res) => {
+ if (!res.ok) {
+ throw new Error(`Download failed: ${res.status}`);
+ }
+ return res.blob();
+ })
+ .then((blob) => {
+ const blobUrl = URL.createObjectURL(blob);
+ const a = document.createElement('a');
+ a.href = blobUrl;
+ a.download = bulk_download_item_name;
+ document.body.appendChild(a);
+ a.click();
+ document.body.removeChild(a);
+ // Delay revocation — the browser's save dialog is asynchronous,
+ // and revoking immediately would invalidate the URL before the
+ // download completes.
+ setTimeout(() => URL.revokeObjectURL(blobUrl), 60_000);
+ })
+ .catch((err) => {
+ log.error({ err }, 'Bulk download fetch failed');
+ toast({
+ id: `error:${bulk_download_item_name}`,
+ title: t('gallery.bulkDownloadFailed'),
+ status: 'error',
+ description: String(err),
+ });
+ });
+ };
toast({
id: bulk_download_item_name,
title: t('gallery.bulkDownloadReady'),
status: 'success',
- description: ,
+ description: (
+ // eslint-disable-next-line react/jsx-no-bind -- not a component render; no re-render cost
+
+ {t('gallery.clickToDownload')}
+
+ ),
duration: null,
});
});
@@ -872,6 +919,9 @@ export const setEventListeners = ({ socket, store, setIsConnected }: SetEventLis
const { bulk_download_item_name, error } = data;
+ // Dismiss the "preparing" toast
+ toastApi.close(`preparing:${bulk_download_item_name}`);
+
toast({
id: bulk_download_item_name,
title: t('gallery.bulkDownloadFailed'),
diff --git a/tests/app/routers/test_boards_multiuser.py b/tests/app/routers/test_boards_multiuser.py
index d5c48481567..ab64ac8a9b4 100644
--- a/tests/app/routers/test_boards_multiuser.py
+++ b/tests/app/routers/test_boards_multiuser.py
@@ -457,3 +457,221 @@ def test_enqueue_batch_requires_auth(enable_multiuser_for_tests: Any, client: Te
},
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
+
+
+# ---------------------------------------------------------------------------
+# Board visibility tests
+# ---------------------------------------------------------------------------
+
+
+def test_board_created_with_private_visibility(client: TestClient, user1_token: str):
+ """Test that newly created boards default to private visibility."""
+ create = client.post(
+ "/api/v1/boards/?board_name=Visibility+Default+Board",
+ headers={"Authorization": f"Bearer {user1_token}"},
+ )
+ assert create.status_code == status.HTTP_201_CREATED
+ data = create.json()
+ assert data["board_visibility"] == "private"
+
+
+def test_set_board_visibility_shared(client: TestClient, user1_token: str):
+ """Test that the board owner can set their board to shared."""
+ create = client.post(
+ "/api/v1/boards/?board_name=Shared+Board",
+ headers={"Authorization": f"Bearer {user1_token}"},
+ )
+ assert create.status_code == status.HTTP_201_CREATED
+ board_id = create.json()["board_id"]
+
+ response = client.patch(
+ f"/api/v1/boards/{board_id}",
+ json={"board_visibility": "shared"},
+ headers={"Authorization": f"Bearer {user1_token}"},
+ )
+ assert response.status_code == status.HTTP_201_CREATED
+ assert response.json()["board_visibility"] == "shared"
+
+
+def test_set_board_visibility_public(client: TestClient, user1_token: str):
+ """Test that the board owner can set their board to public."""
+ create = client.post(
+ "/api/v1/boards/?board_name=Public+Board",
+ headers={"Authorization": f"Bearer {user1_token}"},
+ )
+ assert create.status_code == status.HTTP_201_CREATED
+ board_id = create.json()["board_id"]
+
+ response = client.patch(
+ f"/api/v1/boards/{board_id}",
+ json={"board_visibility": "public"},
+ headers={"Authorization": f"Bearer {user1_token}"},
+ )
+ assert response.status_code == status.HTTP_201_CREATED
+ assert response.json()["board_visibility"] == "public"
+
+
+def test_shared_board_visible_to_other_users(client: TestClient, user1_token: str, user2_token: str):
+ """Test that a shared board is accessible to other authenticated users."""
+ # user1 creates a board and sets it to shared
+ create = client.post(
+ "/api/v1/boards/?board_name=User1+Shared+Board",
+ headers={"Authorization": f"Bearer {user1_token}"},
+ )
+ assert create.status_code == status.HTTP_201_CREATED
+ board_id = create.json()["board_id"]
+
+ client.patch(
+ f"/api/v1/boards/{board_id}",
+ json={"board_visibility": "shared"},
+ headers={"Authorization": f"Bearer {user1_token}"},
+ )
+
+ # user2 should be able to access the shared board
+ response = client.get(
+ f"/api/v1/boards/{board_id}",
+ headers={"Authorization": f"Bearer {user2_token}"},
+ )
+ assert response.status_code == status.HTTP_200_OK
+ assert response.json()["board_id"] == board_id
+
+
+def test_public_board_visible_to_other_users(client: TestClient, user1_token: str, user2_token: str):
+ """Test that a public board is accessible to other authenticated users."""
+ # user1 creates a board and sets it to public
+ create = client.post(
+ "/api/v1/boards/?board_name=User1+Public+Board",
+ headers={"Authorization": f"Bearer {user1_token}"},
+ )
+ assert create.status_code == status.HTTP_201_CREATED
+ board_id = create.json()["board_id"]
+
+ client.patch(
+ f"/api/v1/boards/{board_id}",
+ json={"board_visibility": "public"},
+ headers={"Authorization": f"Bearer {user1_token}"},
+ )
+
+ # user2 should be able to access the public board
+ response = client.get(
+ f"/api/v1/boards/{board_id}",
+ headers={"Authorization": f"Bearer {user2_token}"},
+ )
+ assert response.status_code == status.HTTP_200_OK
+ assert response.json()["board_id"] == board_id
+
+
+def test_shared_board_appears_in_other_user_list(client: TestClient, user1_token: str, user2_token: str):
+ """Test that shared boards appear in other users' board listings."""
+ # user1 creates and shares a board
+ create = client.post(
+ "/api/v1/boards/?board_name=User1+Listed+Shared+Board",
+ headers={"Authorization": f"Bearer {user1_token}"},
+ )
+ assert create.status_code == status.HTTP_201_CREATED
+ board_id = create.json()["board_id"]
+
+ client.patch(
+ f"/api/v1/boards/{board_id}",
+ json={"board_visibility": "shared"},
+ headers={"Authorization": f"Bearer {user1_token}"},
+ )
+
+ # user2 should see the shared board in their listing
+ response = client.get(
+ "/api/v1/boards/?all=true",
+ headers={"Authorization": f"Bearer {user2_token}"},
+ )
+ assert response.status_code == status.HTTP_200_OK
+ board_ids = [b["board_id"] for b in response.json()]
+ assert board_id in board_ids
+
+
+def test_private_board_not_visible_after_privacy_change(client: TestClient, user1_token: str, user2_token: str):
+ """Test that reverting a board from shared to private hides it from other users."""
+ # user1 creates a board, makes it shared, then reverts to private
+ create = client.post(
+ "/api/v1/boards/?board_name=Reverted+Board",
+ headers={"Authorization": f"Bearer {user1_token}"},
+ )
+ assert create.status_code == status.HTTP_201_CREATED
+ board_id = create.json()["board_id"]
+
+ client.patch(
+ f"/api/v1/boards/{board_id}",
+ json={"board_visibility": "shared"},
+ headers={"Authorization": f"Bearer {user1_token}"},
+ )
+ client.patch(
+ f"/api/v1/boards/{board_id}",
+ json={"board_visibility": "private"},
+ headers={"Authorization": f"Bearer {user1_token}"},
+ )
+
+ # user2 should not be able to access the now-private board
+ response = client.get(
+ f"/api/v1/boards/{board_id}",
+ headers={"Authorization": f"Bearer {user2_token}"},
+ )
+ assert response.status_code == status.HTTP_403_FORBIDDEN
+
+
+def test_non_owner_cannot_change_board_visibility(client: TestClient, user1_token: str, user2_token: str):
+ """Test that a non-owner cannot change a board's visibility."""
+ # user1 creates a board
+ create = client.post(
+ "/api/v1/boards/?board_name=User1+Private+Locked+Board",
+ headers={"Authorization": f"Bearer {user1_token}"},
+ )
+ assert create.status_code == status.HTTP_201_CREATED
+ board_id = create.json()["board_id"]
+
+ # user2 tries to make it public - should be forbidden
+ response = client.patch(
+ f"/api/v1/boards/{board_id}",
+ json={"board_visibility": "public"},
+ headers={"Authorization": f"Bearer {user2_token}"},
+ )
+ assert response.status_code == status.HTTP_403_FORBIDDEN
+
+
+def test_shared_board_image_names_visible_to_other_users(client: TestClient, user1_token: str, user2_token: str):
+ """Test that image names for shared boards are accessible to other users."""
+ create = client.post(
+ "/api/v1/boards/?board_name=User1+Shared+Images+Board",
+ headers={"Authorization": f"Bearer {user1_token}"},
+ )
+ assert create.status_code == status.HTTP_201_CREATED
+ board_id = create.json()["board_id"]
+
+ client.patch(
+ f"/api/v1/boards/{board_id}",
+ json={"board_visibility": "shared"},
+ headers={"Authorization": f"Bearer {user1_token}"},
+ )
+
+ # user2 can access image names for a shared board
+ response = client.get(
+ f"/api/v1/boards/{board_id}/image_names",
+ headers={"Authorization": f"Bearer {user2_token}"},
+ )
+ assert response.status_code == status.HTTP_200_OK
+
+
+def test_admin_can_change_any_board_visibility(client: TestClient, admin_token: str, user1_token: str):
+ """Test that an admin can change the visibility of any user's board."""
+ create = client.post(
+ "/api/v1/boards/?board_name=User1+Board+For+Admin+Visibility",
+ headers={"Authorization": f"Bearer {user1_token}"},
+ )
+ assert create.status_code == status.HTTP_201_CREATED
+ board_id = create.json()["board_id"]
+
+ # Admin sets it to public
+ response = client.patch(
+ f"/api/v1/boards/{board_id}",
+ json={"board_visibility": "public"},
+ headers={"Authorization": f"Bearer {admin_token}"},
+ )
+ assert response.status_code == status.HTTP_201_CREATED
+ assert response.json()["board_visibility"] == "public"
diff --git a/tests/app/routers/test_images.py b/tests/app/routers/test_images.py
index c0da3ec51ca..619ecb78c4f 100644
--- a/tests/app/routers/test_images.py
+++ b/tests/app/routers/test_images.py
@@ -52,7 +52,9 @@ def mock_get(*args, **kwargs):
def prepare_download_images_test(monkeypatch: Any, mock_invoker: Invoker) -> None:
- monkeypatch.setattr("invokeai.app.api.routers.images.ApiDependencies", MockApiDependencies(mock_invoker))
+ mock_deps = MockApiDependencies(mock_invoker)
+ monkeypatch.setattr("invokeai.app.api.routers.images.ApiDependencies", mock_deps)
+ monkeypatch.setattr("invokeai.app.api.auth_dependencies.ApiDependencies", mock_deps)
monkeypatch.setattr(
"invokeai.app.api.routers.images.ApiDependencies.invoker.services.bulk_download.generate_item_id",
lambda arg: "test",
@@ -79,7 +81,9 @@ def test_get_bulk_download_image(tmp_path: Path, monkeypatch: Any, mock_invoker:
mock_file.write_text("contents")
monkeypatch.setattr(mock_invoker.services.bulk_download, "get_path", lambda x: str(mock_file))
- monkeypatch.setattr("invokeai.app.api.routers.images.ApiDependencies", MockApiDependencies(mock_invoker))
+ mock_deps = MockApiDependencies(mock_invoker)
+ monkeypatch.setattr("invokeai.app.api.routers.images.ApiDependencies", mock_deps)
+ monkeypatch.setattr("invokeai.app.api.auth_dependencies.ApiDependencies", mock_deps)
def mock_add_task(*args, **kwargs):
return None
@@ -93,7 +97,9 @@ def mock_add_task(*args, **kwargs):
def test_get_bulk_download_image_not_found(monkeypatch: Any, mock_invoker: Invoker, client: TestClient) -> None:
- monkeypatch.setattr("invokeai.app.api.routers.images.ApiDependencies", MockApiDependencies(mock_invoker))
+ mock_deps = MockApiDependencies(mock_invoker)
+ monkeypatch.setattr("invokeai.app.api.routers.images.ApiDependencies", mock_deps)
+ monkeypatch.setattr("invokeai.app.api.auth_dependencies.ApiDependencies", mock_deps)
def mock_add_task(*args, **kwargs):
return None
@@ -112,7 +118,9 @@ def test_get_bulk_download_image_image_deleted_after_response(
mock_file.write_text("contents")
monkeypatch.setattr(mock_invoker.services.bulk_download, "get_path", lambda x: str(mock_file))
- monkeypatch.setattr("invokeai.app.api.routers.images.ApiDependencies", MockApiDependencies(mock_invoker))
+ mock_deps = MockApiDependencies(mock_invoker)
+ monkeypatch.setattr("invokeai.app.api.routers.images.ApiDependencies", mock_deps)
+ monkeypatch.setattr("invokeai.app.api.auth_dependencies.ApiDependencies", mock_deps)
client.get("/api/v1/images/download/test.zip")
diff --git a/tests/app/routers/test_multiuser_authorization.py b/tests/app/routers/test_multiuser_authorization.py
new file mode 100644
index 00000000000..e9efae7034d
--- /dev/null
+++ b/tests/app/routers/test_multiuser_authorization.py
@@ -0,0 +1,1819 @@
+"""Tests for API-level authorization on board-image mutations, image mutations,
+workflow thumbnail access, and admin email leak prevention.
+
+These tests verify the security fixes for:
+1. Shared-board write protection bypass via direct API calls
+2. Image mutation endpoints lacking ownership checks
+3. Private workflow thumbnail exposure
+4. Admin email leak on unauthenticated status endpoint
+"""
+
+import logging
+from typing import Any
+from unittest.mock import MagicMock
+
+import pytest
+from fastapi import status
+from fastapi.testclient import TestClient
+
+from invokeai.app.api.dependencies import ApiDependencies
+from invokeai.app.api_app import app
+from invokeai.app.services.config.config_default import InvokeAIAppConfig
+from invokeai.app.services.invocation_services import InvocationServices
+from invokeai.app.services.invoker import Invoker
+from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
+from invokeai.app.services.users.users_common import UserCreateRequest
+from invokeai.app.services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage
+from invokeai.backend.util.logging import InvokeAILogger
+from tests.fixtures.sqlite_database import create_mock_sqlite_database
+
+
+class MockApiDependencies(ApiDependencies):
+ invoker: Invoker
+
+ def __init__(self, invoker: Invoker) -> None:
+ self.invoker = invoker
+
+
+WORKFLOW_BODY = {
+ "name": "Test Workflow",
+ "author": "",
+ "description": "",
+ "version": "1.0.0",
+ "contact": "",
+ "tags": "",
+ "notes": "",
+ "nodes": [],
+ "edges": [],
+ "exposedFields": [],
+ "meta": {"version": "3.0.0", "category": "user"},
+ "id": None,
+ "form_fields": [],
+}
+
+
+@pytest.fixture
+def setup_jwt_secret():
+ from invokeai.app.services.auth.token_service import set_jwt_secret
+
+ set_jwt_secret("test-secret-key-for-unit-tests-only-do-not-use-in-production")
+
+
+@pytest.fixture
+def client():
+ return TestClient(app)
+
+
+@pytest.fixture
+def mock_services() -> InvocationServices:
+ from invokeai.app.services.board_image_records.board_image_records_sqlite import SqliteBoardImageRecordStorage
+ from invokeai.app.services.board_records.board_records_sqlite import SqliteBoardRecordStorage
+ from invokeai.app.services.boards.boards_default import BoardService
+ from invokeai.app.services.bulk_download.bulk_download_default import BulkDownloadService
+ from invokeai.app.services.client_state_persistence.client_state_persistence_sqlite import (
+ ClientStatePersistenceSqlite,
+ )
+ from invokeai.app.services.image_records.image_records_sqlite import SqliteImageRecordStorage
+ from invokeai.app.services.images.images_default import ImageService
+ from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
+ from invokeai.app.services.invocation_stats.invocation_stats_default import InvocationStatsService
+ from invokeai.app.services.users.users_default import UserService
+ from tests.test_nodes import TestEventService
+
+ configuration = InvokeAIAppConfig(use_memory_db=True, node_cache_size=0)
+ logger = InvokeAILogger.get_logger()
+ db = create_mock_sqlite_database(configuration, logger)
+
+ return InvocationServices(
+ board_image_records=SqliteBoardImageRecordStorage(db=db),
+ board_images=None, # type: ignore
+ board_records=SqliteBoardRecordStorage(db=db),
+ boards=BoardService(),
+ bulk_download=BulkDownloadService(),
+ configuration=configuration,
+ events=TestEventService(),
+ image_files=None, # type: ignore
+ image_records=SqliteImageRecordStorage(db=db),
+ images=ImageService(),
+ invocation_cache=MemoryInvocationCache(max_cache_size=0),
+ logger=logging, # type: ignore
+ model_images=None, # type: ignore
+ model_manager=None, # type: ignore
+ download_queue=None, # type: ignore
+ names=None, # type: ignore
+ performance_statistics=InvocationStatsService(),
+ session_processor=None, # type: ignore
+ session_queue=None, # type: ignore
+ urls=None, # type: ignore
+ workflow_records=SqliteWorkflowRecordsStorage(db=db),
+ tensors=None, # type: ignore
+ conditioning=None, # type: ignore
+ style_preset_records=None, # type: ignore
+ style_preset_image_files=None, # type: ignore
+ workflow_thumbnails=None, # type: ignore
+ model_relationship_records=None, # type: ignore
+ model_relationships=None, # type: ignore
+ client_state_persistence=ClientStatePersistenceSqlite(db=db),
+ users=UserService(db),
+ )
+
+
+@pytest.fixture()
+def mock_invoker(mock_services: InvocationServices) -> Invoker:
+ return Invoker(services=mock_services)
+
+
+def _save_image(mock_invoker: Invoker, image_name: str, user_id: str) -> None:
+ """Helper to insert an image record owned by a specific user."""
+ from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
+
+ mock_invoker.services.image_records.save(
+ image_name=image_name,
+ image_origin=ResourceOrigin.INTERNAL,
+ image_category=ImageCategory.GENERAL,
+ width=100,
+ height=100,
+ has_workflow=False,
+ user_id=user_id,
+ )
+
+
+def _create_user(mock_invoker: Invoker, email: str, display_name: str, is_admin: bool = False) -> str:
+ user = mock_invoker.services.users.create(
+ UserCreateRequest(email=email, display_name=display_name, password="TestPass123", is_admin=is_admin)
+ )
+ return user.user_id
+
+
+def _login(client: TestClient, email: str) -> str:
+ r = client.post("/api/v1/auth/login", json={"email": email, "password": "TestPass123", "remember_me": False})
+ assert r.status_code == 200
+ return r.json()["token"]
+
+
+def _auth(token: str) -> dict[str, str]:
+ return {"Authorization": f"Bearer {token}"}
+
+
+@pytest.fixture
+def enable_multiuser(monkeypatch: Any, mock_invoker: Invoker):
+ mock_invoker.services.configuration.multiuser = True
+
+ mock_board_images = MagicMock()
+ mock_board_images.get_all_board_image_names_for_board.return_value = []
+ mock_invoker.services.board_images = mock_board_images
+
+ mock_workflow_thumbnails = MagicMock()
+ mock_workflow_thumbnails.get_url.return_value = None
+ mock_invoker.services.workflow_thumbnails = mock_workflow_thumbnails
+
+ mock_deps = MockApiDependencies(mock_invoker)
+ monkeypatch.setattr("invokeai.app.api.routers.auth.ApiDependencies", mock_deps)
+ monkeypatch.setattr("invokeai.app.api.auth_dependencies.ApiDependencies", mock_deps)
+ monkeypatch.setattr("invokeai.app.api.routers.boards.ApiDependencies", mock_deps)
+ monkeypatch.setattr("invokeai.app.api.routers.board_images.ApiDependencies", mock_deps)
+ monkeypatch.setattr("invokeai.app.api.routers.images.ApiDependencies", mock_deps)
+ monkeypatch.setattr("invokeai.app.api.routers.workflows.ApiDependencies", mock_deps)
+ monkeypatch.setattr("invokeai.app.api.routers.session_queue.ApiDependencies", mock_deps)
+ monkeypatch.setattr("invokeai.app.api.routers.recall_parameters.ApiDependencies", mock_deps)
+ monkeypatch.setattr("invokeai.app.api.routers.model_manager.ApiDependencies", mock_deps)
+ yield
+
+
+@pytest.fixture
+def admin_token(setup_jwt_secret: None, enable_multiuser: Any, mock_invoker: Invoker, client: TestClient):
+ _create_user(mock_invoker, "admin@test.com", "Admin", is_admin=True)
+ return _login(client, "admin@test.com")
+
+
+@pytest.fixture
+def user1_token(enable_multiuser: Any, mock_invoker: Invoker, client: TestClient, admin_token: str):
+ _create_user(mock_invoker, "user1@test.com", "User One")
+ return _login(client, "user1@test.com")
+
+
+@pytest.fixture
+def user2_token(enable_multiuser: Any, mock_invoker: Invoker, client: TestClient, admin_token: str):
+ _create_user(mock_invoker, "user2@test.com", "User Two")
+ return _login(client, "user2@test.com")
+
+
+def _create_board(client: TestClient, token: str, name: str = "Test Board") -> str:
+ r = client.post(f"/api/v1/boards/?board_name={name.replace(' ', '+')}", headers=_auth(token))
+ assert r.status_code == status.HTTP_201_CREATED
+ return r.json()["board_id"]
+
+
+def _share_board(client: TestClient, token: str, board_id: str) -> None:
+ r = client.patch(f"/api/v1/boards/{board_id}", json={"board_visibility": "shared"}, headers=_auth(token))
+ assert r.status_code == status.HTTP_201_CREATED
+
+
+def _set_board_visibility(client: TestClient, token: str, board_id: str, visibility: str) -> None:
+ r = client.patch(f"/api/v1/boards/{board_id}", json={"board_visibility": visibility}, headers=_auth(token))
+ assert r.status_code == status.HTTP_201_CREATED
+
+
+def _create_workflow(client: TestClient, token: str) -> str:
+ r = client.post("/api/v1/workflows/", json={"workflow": WORKFLOW_BODY}, headers=_auth(token))
+ assert r.status_code == 200
+ return r.json()["workflow_id"]
+
+
+# ===========================================================================
+# 1. Board-image mutation authorization
+# ===========================================================================
+
+
+class TestBoardImageMutationAuth:
+ """Tests that board_images mutation endpoints enforce ownership."""
+
+ def test_add_image_to_board_requires_auth(self, enable_multiuser: Any, client: TestClient):
+ r = client.post("/api/v1/board_images/", json={"board_id": "x", "image_name": "y"})
+ assert r.status_code == status.HTTP_401_UNAUTHORIZED
+
+ def test_add_image_to_board_batch_requires_auth(self, enable_multiuser: Any, client: TestClient):
+ r = client.post("/api/v1/board_images/batch", json={"board_id": "x", "image_names": ["y"]})
+ assert r.status_code == status.HTTP_401_UNAUTHORIZED
+
+ def test_remove_image_from_board_requires_auth(self, enable_multiuser: Any, client: TestClient):
+ r = client.request("DELETE", "/api/v1/board_images/", json={"image_name": "y"})
+ assert r.status_code == status.HTTP_401_UNAUTHORIZED
+
+ def test_remove_images_from_board_batch_requires_auth(self, enable_multiuser: Any, client: TestClient):
+ r = client.post("/api/v1/board_images/batch/delete", json={"image_names": ["y"]})
+ assert r.status_code == status.HTTP_401_UNAUTHORIZED
+
+ def test_non_owner_cannot_add_image_to_shared_board(self, client: TestClient, user1_token: str, user2_token: str):
+ board_id = _create_board(client, user1_token, "User1 Shared Board")
+ _share_board(client, user1_token, board_id)
+
+ r = client.post(
+ "/api/v1/board_images/",
+ json={"board_id": board_id, "image_name": "some-image"},
+ headers=_auth(user2_token),
+ )
+ assert r.status_code == status.HTTP_403_FORBIDDEN
+
+ def test_non_owner_cannot_add_images_batch_to_shared_board(
+ self, client: TestClient, user1_token: str, user2_token: str
+ ):
+ board_id = _create_board(client, user1_token, "User1 Shared Board Batch")
+ _share_board(client, user1_token, board_id)
+
+ r = client.post(
+ "/api/v1/board_images/batch",
+ json={"board_id": board_id, "image_names": ["img1", "img2"]},
+ headers=_auth(user2_token),
+ )
+ assert r.status_code == status.HTTP_403_FORBIDDEN
+
+ def test_admin_can_add_image_to_any_board(
+ self, client: TestClient, mock_invoker: Invoker, admin_token: str, user1_token: str
+ ):
+ user1 = mock_invoker.services.users.get_by_email("user1@test.com")
+ assert user1 is not None
+ _save_image(mock_invoker, "user1-admin-board-img", user1.user_id)
+ board_id = _create_board(client, user1_token, "User1 Board For Admin")
+
+ # Admin can add any image to any board — should not be 403
+ r = client.post(
+ "/api/v1/board_images/",
+ json={"board_id": board_id, "image_name": "user1-admin-board-img"},
+ headers=_auth(admin_token),
+ )
+ assert r.status_code != status.HTTP_403_FORBIDDEN
+
+ def test_non_owner_can_add_own_image_to_public_board(
+ self, client: TestClient, mock_invoker: Invoker, user1_token: str, user2_token: str
+ ):
+ """Public boards are documented as writable by other authenticated users."""
+ public_board_id = _create_board(client, user1_token, "User1 Public Board")
+ _set_board_visibility(client, user1_token, public_board_id, "public")
+
+ user2 = mock_invoker.services.users.get_by_email("user2@test.com")
+ assert user2 is not None
+ _save_image(mock_invoker, "user2-public-board-img", user2.user_id)
+
+ r = client.post(
+ "/api/v1/board_images/",
+ json={"board_id": public_board_id, "image_name": "user2-public-board-img"},
+ headers=_auth(user2_token),
+ )
+ assert r.status_code == status.HTTP_201_CREATED
+
+ def test_owner_can_add_image_to_own_board(self, client: TestClient, mock_invoker: Invoker, user1_token: str):
+ user1 = mock_invoker.services.users.get_by_email("user1@test.com")
+ assert user1 is not None
+ _save_image(mock_invoker, "user1-own-board-img", user1.user_id)
+ board_id = _create_board(client, user1_token, "User1 Own Board")
+
+ r = client.post(
+ "/api/v1/board_images/",
+ json={"board_id": board_id, "image_name": "user1-own-board-img"},
+ headers=_auth(user1_token),
+ )
+ assert r.status_code != status.HTTP_403_FORBIDDEN
+
+ def test_non_owner_cannot_add_other_users_image_to_own_board(
+ self, client: TestClient, mock_invoker: Invoker, user1_token: str, user2_token: str
+ ):
+ """Attacker creates their own board, then tries to add victim's image to it.
+ This must be rejected — otherwise the attacker gains mutation rights via
+ the board-ownership fallback in _assert_image_owner."""
+ user1 = mock_invoker.services.users.get_by_email("user1@test.com")
+ assert user1 is not None
+ _save_image(mock_invoker, "victim-image", user1.user_id)
+
+ attacker_board = _create_board(client, user2_token, "Attacker Board")
+
+ r = client.post(
+ "/api/v1/board_images/",
+ json={"board_id": attacker_board, "image_name": "victim-image"},
+ headers=_auth(user2_token),
+ )
+ assert r.status_code == status.HTTP_403_FORBIDDEN
+
+ def test_non_owner_cannot_batch_add_other_users_images_to_own_board(
+ self, client: TestClient, mock_invoker: Invoker, user1_token: str, user2_token: str
+ ):
+ """Same attack via the batch endpoint."""
+ user1 = mock_invoker.services.users.get_by_email("user1@test.com")
+ assert user1 is not None
+ _save_image(mock_invoker, "victim-batch-img", user1.user_id)
+
+ attacker_board = _create_board(client, user2_token, "Attacker Batch Board")
+
+ r = client.post(
+ "/api/v1/board_images/batch",
+ json={"board_id": attacker_board, "image_names": ["victim-batch-img"]},
+ headers=_auth(user2_token),
+ )
+ assert r.status_code == status.HTTP_403_FORBIDDEN
+
+
+# ===========================================================================
+# 2a. Image read-access authorization
+# ===========================================================================
+
+
+class TestImageReadAuth:
+ """Tests that image GET endpoints enforce visibility."""
+
+ def test_get_image_dto_requires_auth(self, enable_multiuser: Any, client: TestClient):
+ r = client.get("/api/v1/images/i/some-image")
+ assert r.status_code == status.HTTP_401_UNAUTHORIZED
+
+ def test_get_image_metadata_requires_auth(self, enable_multiuser: Any, client: TestClient):
+ r = client.get("/api/v1/images/i/some-image/metadata")
+ assert r.status_code == status.HTTP_401_UNAUTHORIZED
+
+ def test_get_image_full_is_unauthenticated(self, enable_multiuser: Any, client: TestClient):
+ # Binary image endpoints are intentionally unauthenticated because
+ # browsers load them via
which cannot send Bearer tokens.
+ r = client.get("/api/v1/images/i/some-image/full")
+ assert r.status_code != status.HTTP_401_UNAUTHORIZED
+
+ def test_get_image_thumbnail_is_unauthenticated(self, enable_multiuser: Any, client: TestClient):
+ r = client.get("/api/v1/images/i/some-image/thumbnail")
+ assert r.status_code != status.HTTP_401_UNAUTHORIZED
+
+ def test_get_image_urls_requires_auth(self, enable_multiuser: Any, client: TestClient):
+ r = client.get("/api/v1/images/i/some-image/urls")
+ assert r.status_code == status.HTTP_401_UNAUTHORIZED
+
+ def test_non_owner_cannot_read_private_image(
+ self, client: TestClient, mock_invoker: Invoker, user1_token: str, user2_token: str
+ ):
+ """User2 should not be able to read user1's image that is not on a shared board."""
+ user1 = mock_invoker.services.users.get_by_email("user1@test.com")
+ assert user1 is not None
+ _save_image(mock_invoker, "user1-private-img", user1.user_id)
+
+ r = client.get("/api/v1/images/i/user1-private-img", headers=_auth(user2_token))
+ assert r.status_code == status.HTTP_403_FORBIDDEN
+
+ def test_owner_can_read_own_image(self, client: TestClient, mock_invoker: Invoker, user1_token: str):
+ user1 = mock_invoker.services.users.get_by_email("user1@test.com")
+ assert user1 is not None
+ _save_image(mock_invoker, "user1-readable", user1.user_id)
+
+ r = client.get("/api/v1/images/i/user1-readable", headers=_auth(user1_token))
+ # Should not be 403 (may be 404/500 due to missing board_image_records mock)
+ assert r.status_code != status.HTTP_403_FORBIDDEN
+
+ def test_admin_can_read_any_image(
+ self, client: TestClient, mock_invoker: Invoker, admin_token: str, user1_token: str
+ ):
+ user1 = mock_invoker.services.users.get_by_email("user1@test.com")
+ assert user1 is not None
+ _save_image(mock_invoker, "user1-admin-read", user1.user_id)
+
+ r = client.get("/api/v1/images/i/user1-admin-read", headers=_auth(admin_token))
+ assert r.status_code != status.HTTP_403_FORBIDDEN
+
+ def test_shared_board_image_readable_by_other_user(
+ self, client: TestClient, mock_invoker: Invoker, user1_token: str, user2_token: str
+ ):
+ """An image on a shared board should be readable by any authenticated user."""
+ user1 = mock_invoker.services.users.get_by_email("user1@test.com")
+ assert user1 is not None
+ _save_image(mock_invoker, "shared-board-img", user1.user_id)
+
+ # Create a shared board and add the image to it
+ board_id = _create_board(client, user1_token, "Shared Read Board")
+ _share_board(client, user1_token, board_id)
+ mock_invoker.services.board_image_records.add_image_to_board(board_id=board_id, image_name="shared-board-img")
+
+ r = client.get("/api/v1/images/i/shared-board-img", headers=_auth(user2_token))
+ # Should not be 403 — image is on a shared board
+ assert r.status_code != status.HTTP_403_FORBIDDEN
+
+ def test_non_owner_cannot_read_image_metadata(
+ self, client: TestClient, mock_invoker: Invoker, user1_token: str, user2_token: str
+ ):
+ user1 = mock_invoker.services.users.get_by_email("user1@test.com")
+ assert user1 is not None
+ _save_image(mock_invoker, "user1-meta-blocked", user1.user_id)
+
+ r = client.get("/api/v1/images/i/user1-meta-blocked/metadata", headers=_auth(user2_token))
+ assert r.status_code == status.HTTP_403_FORBIDDEN
+
+ def test_list_images_private_board_rejected_for_non_owner(
+ self, client: TestClient, mock_invoker: Invoker, user1_token: str, user2_token: str
+ ):
+ """User2 must not be able to enumerate images on user1's private board
+ via GET /api/v1/images?board_id=..."""
+ board_id = _create_board(client, user1_token, "Private Enum Board")
+
+ r = client.get(f"/api/v1/images/?board_id={board_id}", headers=_auth(user2_token))
+ assert r.status_code == status.HTTP_403_FORBIDDEN
+
+ def test_list_images_shared_board_allowed_for_non_owner(
+ self, client: TestClient, mock_invoker: Invoker, user1_token: str, user2_token: str
+ ):
+ """User2 should be able to list images on user1's shared board."""
+ board_id = _create_board(client, user1_token, "Shared Enum Board")
+ _share_board(client, user1_token, board_id)
+
+ r = client.get(f"/api/v1/images/?board_id={board_id}", headers=_auth(user2_token))
+ assert r.status_code == status.HTTP_200_OK
+
+ def test_get_image_names_private_board_rejected_for_non_owner(
+ self, client: TestClient, mock_invoker: Invoker, user1_token: str, user2_token: str
+ ):
+ """User2 must not be able to enumerate image names on user1's private board
+ via GET /api/v1/images/names?board_id=..."""
+ board_id = _create_board(client, user1_token, "Private Names Board")
+
+ r = client.get(f"/api/v1/images/names?board_id={board_id}", headers=_auth(user2_token))
+ assert r.status_code == status.HTTP_403_FORBIDDEN
+
+ def test_get_image_names_shared_board_allowed_for_non_owner(
+ self, client: TestClient, mock_invoker: Invoker, user1_token: str, user2_token: str
+ ):
+ """User2 should be able to list image names on user1's shared board."""
+ board_id = _create_board(client, user1_token, "Shared Names Board")
+ _share_board(client, user1_token, board_id)
+
+ r = client.get(f"/api/v1/images/names?board_id={board_id}", headers=_auth(user2_token))
+ assert r.status_code == status.HTTP_200_OK
+
+ def test_list_images_own_private_board_allowed(self, client: TestClient, mock_invoker: Invoker, user1_token: str):
+ """Owner should be able to list images on their own private board."""
+ board_id = _create_board(client, user1_token, "Own Private Board")
+
+ r = client.get(f"/api/v1/images/?board_id={board_id}", headers=_auth(user1_token))
+ assert r.status_code == status.HTTP_200_OK
+
+ def test_admin_can_list_images_on_any_board(
+ self, client: TestClient, mock_invoker: Invoker, admin_token: str, user1_token: str
+ ):
+ """Admin should be able to list images on any board."""
+ board_id = _create_board(client, user1_token, "Admin Enum Board")
+
+ r = client.get(f"/api/v1/images/?board_id={board_id}", headers=_auth(admin_token))
+ assert r.status_code == status.HTTP_200_OK
+
+
+# ===========================================================================
+# 2b. Image mutation authorization
+# ===========================================================================
+
+
+class TestImageUploadAuth:
+ """Tests that image upload enforces board ownership."""
+
+ def test_upload_to_other_users_shared_board_forbidden(self, client: TestClient, user1_token: str, user2_token: str):
+ """A user should not be able to upload an image into another user's shared board."""
+ board_id = _create_board(client, user1_token, "User1 Shared Upload Board")
+ _share_board(client, user1_token, board_id)
+
+ # user2 tries to upload into user1's shared board
+ import io
+
+ fake_image = io.BytesIO(b"\x89PNG\r\n\x1a\n" + b"\x00" * 100)
+ r = client.post(
+ f"/api/v1/images/upload?image_category=general&is_intermediate=false&board_id={board_id}",
+ files={"file": ("test.png", fake_image, "image/png")},
+ headers=_auth(user2_token),
+ )
+ assert r.status_code == status.HTTP_403_FORBIDDEN
+
+ def test_owner_can_upload_to_own_shared_board(self, client: TestClient, user1_token: str):
+ board_id = _create_board(client, user1_token, "User1 Own Upload Board")
+ _share_board(client, user1_token, board_id)
+
+ import io
+
+ fake_image = io.BytesIO(b"\x89PNG\r\n\x1a\n" + b"\x00" * 100)
+ r = client.post(
+ f"/api/v1/images/upload?image_category=general&is_intermediate=false&board_id={board_id}",
+ files={"file": ("test.png", fake_image, "image/png")},
+ headers=_auth(user1_token),
+ )
+ # Should not be 403 (may fail for other reasons in test env)
+ assert r.status_code != status.HTTP_403_FORBIDDEN
+
+ def test_non_owner_can_upload_to_public_board(self, client: TestClient, user1_token: str, user2_token: str):
+ """Public boards allow any authenticated user to upload images."""
+ board_id = _create_board(client, user1_token, "User1 Public Upload Board")
+ _set_board_visibility(client, user1_token, board_id, "public")
+
+ import io
+
+ fake_image = io.BytesIO(b"\x89PNG\r\n\x1a\n" + b"\x00" * 100)
+ r = client.post(
+ f"/api/v1/images/upload?image_category=general&is_intermediate=false&board_id={board_id}",
+ files={"file": ("test.png", fake_image, "image/png")},
+ headers=_auth(user2_token),
+ )
+ # Should not be 403 (may fail downstream for other reasons in test env)
+ assert r.status_code != status.HTTP_403_FORBIDDEN
+
+
+class TestImageMutationAuth:
+ """Tests that image mutation endpoints enforce ownership."""
+
+ def test_delete_image_requires_auth(self, enable_multiuser: Any, client: TestClient):
+ r = client.delete("/api/v1/images/i/some-image")
+ assert r.status_code == status.HTTP_401_UNAUTHORIZED
+
+ def test_update_image_requires_auth(self, enable_multiuser: Any, client: TestClient):
+ r = client.patch("/api/v1/images/i/some-image", json={"starred": True})
+ assert r.status_code == status.HTTP_401_UNAUTHORIZED
+
+ def test_batch_delete_requires_auth(self, enable_multiuser: Any, client: TestClient):
+ r = client.post("/api/v1/images/delete", json={"image_names": ["x"]})
+ assert r.status_code == status.HTTP_401_UNAUTHORIZED
+
+ def test_star_images_requires_auth(self, enable_multiuser: Any, client: TestClient):
+ r = client.post("/api/v1/images/star", json={"image_names": ["x"]})
+ assert r.status_code == status.HTTP_401_UNAUTHORIZED
+
+ def test_unstar_images_requires_auth(self, enable_multiuser: Any, client: TestClient):
+ r = client.post("/api/v1/images/unstar", json={"image_names": ["x"]})
+ assert r.status_code == status.HTTP_401_UNAUTHORIZED
+
+ def test_clear_intermediates_requires_auth(self, enable_multiuser: Any, client: TestClient):
+ r = client.delete("/api/v1/images/intermediates")
+ assert r.status_code == status.HTTP_401_UNAUTHORIZED
+
+ def test_delete_uncategorized_requires_auth(self, enable_multiuser: Any, client: TestClient):
+ r = client.delete("/api/v1/images/uncategorized")
+ assert r.status_code == status.HTTP_401_UNAUTHORIZED
+
+ def test_non_owner_cannot_delete_image(
+ self, client: TestClient, mock_invoker: Invoker, user1_token: str, user2_token: str
+ ):
+ """User2 should not be able to delete user1's image."""
+ user1 = mock_invoker.services.users.get_by_email("user1@test.com")
+ assert user1 is not None
+ _save_image(mock_invoker, "user1-image", user1.user_id)
+
+ r = client.delete("/api/v1/images/i/user1-image", headers=_auth(user2_token))
+ assert r.status_code == status.HTTP_403_FORBIDDEN
+
+ def test_owner_can_delete_own_image(self, client: TestClient, mock_invoker: Invoker, user1_token: str):
+ user1 = mock_invoker.services.users.get_by_email("user1@test.com")
+ assert user1 is not None
+ _save_image(mock_invoker, "user1-delete-me", user1.user_id)
+
+ r = client.delete("/api/v1/images/i/user1-delete-me", headers=_auth(user1_token))
+ # Should not be 403 (may be 200 or 500 depending on file system)
+ assert r.status_code != status.HTTP_403_FORBIDDEN
+
+ def test_admin_can_delete_any_image(
+ self, client: TestClient, mock_invoker: Invoker, admin_token: str, user1_token: str
+ ):
+ user1 = mock_invoker.services.users.get_by_email("user1@test.com")
+ assert user1 is not None
+ _save_image(mock_invoker, "user1-admin-delete", user1.user_id)
+
+ r = client.delete("/api/v1/images/i/user1-admin-delete", headers=_auth(admin_token))
+ assert r.status_code != status.HTTP_403_FORBIDDEN
+
+ def test_board_owner_can_delete_image_on_own_board(
+ self, client: TestClient, mock_invoker: Invoker, user1_token: str
+ ):
+ """Board owner should be able to delete images on their board even if
+ the image's user_id is 'system' (e.g. generated images)."""
+ # Create image owned by "system" (simulates queue-generated image)
+ _save_image(mock_invoker, "system-img-on-board", "system")
+
+ # Create a board owned by user1 and add the image to it
+ board_id = _create_board(client, user1_token, "User1 Board With System Img")
+ mock_invoker.services.board_image_records.add_image_to_board(
+ board_id=board_id, image_name="system-img-on-board"
+ )
+
+ r = client.delete("/api/v1/images/i/system-img-on-board", headers=_auth(user1_token))
+ assert r.status_code != status.HTTP_403_FORBIDDEN
+
+ def test_non_owner_cannot_update_image(
+ self, client: TestClient, mock_invoker: Invoker, user1_token: str, user2_token: str
+ ):
+ user1 = mock_invoker.services.users.get_by_email("user1@test.com")
+ assert user1 is not None
+ _save_image(mock_invoker, "user1-no-star", user1.user_id)
+
+ r = client.patch(
+ "/api/v1/images/i/user1-no-star",
+ json={"starred": True},
+ headers=_auth(user2_token),
+ )
+ assert r.status_code == status.HTTP_403_FORBIDDEN
+
+ def test_non_owner_cannot_star_image(
+ self, client: TestClient, mock_invoker: Invoker, user1_token: str, user2_token: str
+ ):
+ user1 = mock_invoker.services.users.get_by_email("user1@test.com")
+ assert user1 is not None
+ _save_image(mock_invoker, "user1-star-blocked", user1.user_id)
+
+ r = client.post(
+ "/api/v1/images/star",
+ json={"image_names": ["user1-star-blocked"]},
+ headers=_auth(user2_token),
+ )
+ assert r.status_code == status.HTTP_403_FORBIDDEN
+
+ def test_non_owner_cannot_batch_delete_image(
+ self, client: TestClient, mock_invoker: Invoker, user1_token: str, user2_token: str
+ ):
+ user1 = mock_invoker.services.users.get_by_email("user1@test.com")
+ assert user1 is not None
+ _save_image(mock_invoker, "user1-batch-del", user1.user_id)
+
+ r = client.post(
+ "/api/v1/images/delete",
+ json={"image_names": ["user1-batch-del"]},
+ headers=_auth(user2_token),
+ )
+ assert r.status_code == status.HTTP_403_FORBIDDEN
+
+ def test_non_owner_can_delete_image_from_public_board(
+ self, client: TestClient, mock_invoker: Invoker, user1_token: str, user2_token: str
+ ):
+ """Public-board semantics promise delete access to images contained in the board."""
+ public_board_id = _create_board(client, user1_token, "User1 Public Delete Board")
+ _set_board_visibility(client, user1_token, public_board_id, "public")
+
+ user1 = mock_invoker.services.users.get_by_email("user1@test.com")
+ assert user1 is not None
+ _save_image(mock_invoker, "user1-public-delete", user1.user_id)
+ mock_invoker.services.board_image_records.add_image_to_board(public_board_id, "user1-public-delete")
+
+ r = client.delete(
+ "/api/v1/images/i/user1-public-delete",
+ headers=_auth(user2_token),
+ )
+ assert r.status_code == status.HTTP_200_OK
+
+ def test_clear_intermediates_non_admin_forbidden(self, client: TestClient, user1_token: str):
+ r = client.delete("/api/v1/images/intermediates", headers=_auth(user1_token))
+ assert r.status_code == status.HTTP_403_FORBIDDEN
+
+ def test_get_intermediates_count_requires_auth(self, enable_multiuser: Any, client: TestClient):
+ r = client.get("/api/v1/images/intermediates")
+ assert r.status_code == status.HTTP_401_UNAUTHORIZED
+
+ def test_download_images_requires_auth(self, enable_multiuser: Any, client: TestClient):
+ r = client.post("/api/v1/images/download", json={"image_names": ["x"]})
+ assert r.status_code == status.HTTP_401_UNAUTHORIZED
+
+ def test_non_owner_cannot_fetch_existing_bulk_download_item(
+ self,
+ client: TestClient,
+ mock_invoker: Invoker,
+ monkeypatch: Any,
+ tmp_path: Any,
+ user1_token: str,
+ user2_token: str,
+ ):
+ """A bulk download zip should be fetchable only by its owner."""
+ from fastapi import BackgroundTasks
+
+ user1 = mock_invoker.services.users.get_by_email("user1@test.com")
+ assert user1 is not None
+
+ mock_file = tmp_path / "owned-download.zip"
+ mock_file.write_text("contents")
+
+ monkeypatch.setattr(mock_invoker.services.bulk_download, "get_path", lambda _: str(mock_file))
+ monkeypatch.setattr(mock_invoker.services.bulk_download, "get_owner", lambda _: user1.user_id)
+ monkeypatch.setattr(BackgroundTasks, "add_task", lambda *args, **kwargs: None)
+
+ r = client.get("/api/v1/images/download/owned-download.zip", headers=_auth(user2_token))
+ assert r.status_code == status.HTTP_403_FORBIDDEN
+
+ def test_images_by_names_requires_auth(self, enable_multiuser: Any, client: TestClient):
+ r = client.post("/api/v1/images/images_by_names", json={"image_names": ["x"]})
+ assert r.status_code == status.HTTP_401_UNAUTHORIZED
+
+ def test_images_by_names_filters_unauthorized(
+ self, client: TestClient, mock_invoker: Invoker, user1_token: str, user2_token: str
+ ):
+ """images_by_names should silently skip images the caller cannot access."""
+ user1 = mock_invoker.services.users.get_by_email("user1@test.com")
+ assert user1 is not None
+ _save_image(mock_invoker, "user1-by-name", user1.user_id)
+
+ r = client.post(
+ "/api/v1/images/images_by_names",
+ json={"image_names": ["user1-by-name"]},
+ headers=_auth(user2_token),
+ )
+ assert r.status_code == 200
+ # user2 should get an empty list — the image belongs to user1
+ assert r.json() == []
+
+ def test_none_board_image_names_only_return_callers_uncategorized_images(
+ self, client: TestClient, mock_invoker: Invoker, user1_token: str, user2_token: str
+ ):
+ """The uncategorized-images sentinel must not expose other users' image names."""
+ mock_invoker.services.board_images.get_all_board_image_names_for_board.side_effect = (
+ mock_invoker.services.board_image_records.get_all_board_image_names_for_board
+ )
+
+ user1 = mock_invoker.services.users.get_by_email("user1@test.com")
+ user2 = mock_invoker.services.users.get_by_email("user2@test.com")
+ assert user1 is not None
+ assert user2 is not None
+
+ _save_image(mock_invoker, "user1-uncategorized-private", user1.user_id)
+ _save_image(mock_invoker, "user2-uncategorized-private", user2.user_id)
+
+ r = client.get("/api/v1/boards/none/image_names", headers=_auth(user2_token))
+ assert r.status_code == status.HTTP_200_OK
+ assert "user2-uncategorized-private" in r.json()
+ assert "user1-uncategorized-private" not in r.json()
+
+
+# ===========================================================================
+# 3. Workflow mutation authorization (additional)
+# ===========================================================================
+
+
+class TestWorkflowListScoping:
+ """Tests that listing workflows in multiuser mode does not filter out default workflows."""
+
+ def test_default_workflows_visible_when_listing_user_and_default(
+ self, client: TestClient, mock_invoker: Invoker, user1_token: str
+ ):
+ """When categories=['user','default'], default workflows must still appear even
+ though user_id_filter is set to the current user (default workflows belong to 'system')."""
+ from invokeai.app.services.workflow_records.workflow_records_common import (
+ Workflow,
+ WorkflowCategory,
+ WorkflowMeta,
+ WorkflowWithoutID,
+ )
+ from invokeai.app.util.misc import uuid_string
+
+ default_wf = WorkflowWithoutID(
+ name="Test Default Workflow",
+ description="A built-in workflow",
+ meta=WorkflowMeta(version="3.0.0", category=WorkflowCategory.Default),
+ nodes=[],
+ edges=[],
+ tags="",
+ author="",
+ contact="",
+ version="1.0.0",
+ notes="",
+ exposedFields=[],
+ form_fields=[],
+ )
+ wf_with_id = Workflow(**default_wf.model_dump(), id=uuid_string())
+ # Insert directly via DB since the create API rejects default workflows
+ with mock_invoker.services.workflow_records._db.transaction() as cursor:
+ cursor.execute(
+ "INSERT INTO workflow_library (workflow_id, workflow, user_id) VALUES (?, ?, ?)",
+ (wf_with_id.id, wf_with_id.model_dump_json(), "system"),
+ )
+
+ # Also create a user workflow via the API
+ _create_workflow(client, user1_token)
+
+ # List with categories=user&categories=default
+ r = client.get(
+ "/api/v1/workflows/?categories=user&categories=default",
+ headers=_auth(user1_token),
+ )
+ assert r.status_code == 200
+ data = r.json()
+ categories_found = {item["category"] for item in data["items"]}
+ assert "default" in categories_found, (
+ f"Default workflows were filtered out. Categories found: {categories_found}"
+ )
+ assert "user" in categories_found
+
+ def test_default_workflows_visible_when_no_category_filter(
+ self, client: TestClient, mock_invoker: Invoker, user1_token: str
+ ):
+ """When no categories filter is given, default workflows should still appear."""
+ from invokeai.app.services.workflow_records.workflow_records_common import (
+ Workflow,
+ WorkflowCategory,
+ WorkflowMeta,
+ WorkflowWithoutID,
+ )
+ from invokeai.app.util.misc import uuid_string
+
+ default_wf = WorkflowWithoutID(
+ name="Another Default Workflow",
+ description="Built-in",
+ meta=WorkflowMeta(version="3.0.0", category=WorkflowCategory.Default),
+ nodes=[],
+ edges=[],
+ tags="",
+ author="",
+ contact="",
+ version="1.0.0",
+ notes="",
+ exposedFields=[],
+ form_fields=[],
+ )
+ wf_with_id = Workflow(**default_wf.model_dump(), id=uuid_string())
+ with mock_invoker.services.workflow_records._db.transaction() as cursor:
+ cursor.execute(
+ "INSERT INTO workflow_library (workflow_id, workflow, user_id) VALUES (?, ?, ?)",
+ (wf_with_id.id, wf_with_id.model_dump_json(), "system"),
+ )
+
+ _create_workflow(client, user1_token)
+
+ r = client.get("/api/v1/workflows/", headers=_auth(user1_token))
+ assert r.status_code == 200
+ data = r.json()
+ categories_found = {item["category"] for item in data["items"]}
+ assert "default" in categories_found, (
+ f"Default workflows were filtered out. Categories found: {categories_found}"
+ )
+
+
+class TestWorkflowMutationAuth:
+ """Tests for additional workflow mutation endpoints."""
+
+ def test_update_opened_at_requires_auth(self, enable_multiuser: Any, client: TestClient):
+ r = client.put("/api/v1/workflows/i/some-id/opened_at")
+ assert r.status_code == status.HTTP_401_UNAUTHORIZED
+
+ def test_non_owner_cannot_update_opened_at(self, client: TestClient, user1_token: str, user2_token: str):
+ workflow_id = _create_workflow(client, user1_token)
+ r = client.put(
+ f"/api/v1/workflows/i/{workflow_id}/opened_at",
+ headers=_auth(user2_token),
+ )
+ assert r.status_code == status.HTTP_403_FORBIDDEN
+
+ def test_owner_can_update_opened_at(self, client: TestClient, user1_token: str):
+ workflow_id = _create_workflow(client, user1_token)
+ r = client.put(
+ f"/api/v1/workflows/i/{workflow_id}/opened_at",
+ headers=_auth(user1_token),
+ )
+ assert r.status_code == 200
+
+
+# ===========================================================================
+# 4. Workflow thumbnail authorization
+# ===========================================================================
+
+
+class TestWorkflowThumbnailAuth:
+ """Tests for the workflow thumbnail GET endpoint.
+
+ Workflow and image thumbnail endpoints are intentionally unauthenticated
+ because browsers load them via
tags which cannot send Bearer
+ tokens. IDs are UUIDs, providing security through unguessability.
+ """
+
+ def test_thumbnail_is_unauthenticated(self, enable_multiuser: Any, client: TestClient):
+ # Binary image endpoints don't require auth — loaded via
+ r = client.get("/api/v1/workflows/i/some-workflow/thumbnail")
+ assert r.status_code != status.HTTP_401_UNAUTHORIZED
+
+
+# ===========================================================================
+# 4. Admin email leak prevention
+# ===========================================================================
+
+
+class TestAdminEmailLeak:
+ """Tests that the auth status endpoint does not leak admin email."""
+
+ def test_status_does_not_leak_admin_email_when_setup_complete(self, client: TestClient, admin_token: str):
+ """After setup is complete, admin_email must be null."""
+ r = client.get("/api/v1/auth/status")
+ assert r.status_code == 200
+ data = r.json()
+ assert data["multiuser_enabled"] is True
+ assert data["setup_required"] is False
+ assert data["admin_email"] is None
+
+ def test_status_returns_admin_email_during_setup(
+ self, setup_jwt_secret: None, enable_multiuser: Any, mock_invoker: Invoker, client: TestClient
+ ):
+ """Before any admin exists, setup_required=True and admin_email may be returned."""
+ # Don't create any users -- setup_required should be True
+ r = client.get("/api/v1/auth/status")
+ assert r.status_code == 200
+ data = r.json()
+ assert data["setup_required"] is True
+ # admin_email is null here because no admin exists yet, which is correct
+
+ def test_status_no_leak_in_single_user_mode(
+ self, setup_jwt_secret: None, monkeypatch: Any, mock_invoker: Invoker, client: TestClient
+ ):
+ """In single-user mode, admin_email should always be null."""
+ mock_invoker.services.configuration.multiuser = False
+ mock_deps = MockApiDependencies(mock_invoker)
+ monkeypatch.setattr("invokeai.app.api.routers.auth.ApiDependencies", mock_deps)
+
+ r = client.get("/api/v1/auth/status")
+ assert r.status_code == 200
+ data = r.json()
+ assert data["admin_email"] is None
+ assert data["multiuser_enabled"] is False
+
+
+# ===========================================================================
+# 6. Session queue authorization
+# ===========================================================================
+
+
+class TestSessionQueueAuth:
+ """Tests that session queue endpoints enforce authentication."""
+
+ def test_get_queue_item_ids_requires_auth(self, enable_multiuser: Any, client: TestClient):
+ r = client.get("/api/v1/queue/default/item_ids")
+ assert r.status_code == status.HTTP_401_UNAUTHORIZED
+
+ def test_get_current_queue_item_requires_auth(self, enable_multiuser: Any, client: TestClient):
+ r = client.get("/api/v1/queue/default/current")
+ assert r.status_code == status.HTTP_401_UNAUTHORIZED
+
+ def test_get_next_queue_item_requires_auth(self, enable_multiuser: Any, client: TestClient):
+ r = client.get("/api/v1/queue/default/next")
+ assert r.status_code == status.HTTP_401_UNAUTHORIZED
+
+ def test_get_batch_status_requires_auth(self, enable_multiuser: Any, client: TestClient):
+ r = client.get("/api/v1/queue/default/b/some-batch/status")
+ assert r.status_code == status.HTTP_401_UNAUTHORIZED
+
+ def test_counts_by_destination_requires_auth(self, enable_multiuser: Any, client: TestClient):
+ r = client.get("/api/v1/queue/default/counts_by_destination?destination=canvas")
+ assert r.status_code == status.HTTP_401_UNAUTHORIZED
+
+
+# ===========================================================================
+# 6b. Session queue sanitization (cross-user isolation)
+# ===========================================================================
+
+
+class TestSessionQueueSanitization:
+ """Tests that sanitize_queue_item_for_user strips all sensitive fields
+ from queue items viewed by non-owner, non-admin users."""
+
+ @pytest.fixture
+ def _sample_queue_item(self):
+ from invokeai.app.services.shared.graph import Graph, GraphExecutionState
+
+ return SessionQueueItem(
+ item_id=42,
+ status="pending",
+ priority=10,
+ batch_id="batch-abc",
+ origin="workflows",
+ destination="canvas",
+ session_id="sess-123",
+ session=GraphExecutionState(id="sess-123", graph=Graph()),
+ error_type="RuntimeError",
+ error_message="something broke",
+ error_traceback="Traceback ...",
+ created_at="2026-01-01T00:00:00",
+ updated_at="2026-01-01T01:00:00",
+ started_at="2026-01-01T00:30:00",
+ completed_at=None,
+ queue_id="default",
+ user_id="owner-user",
+ user_display_name="Owner Display",
+ user_email="owner@test.com",
+ field_values=None,
+ workflow=None,
+ )
+
+ def test_owner_sees_all_fields(self, _sample_queue_item: SessionQueueItem):
+ from invokeai.app.api.routers.session_queue import sanitize_queue_item_for_user
+
+ result = sanitize_queue_item_for_user(_sample_queue_item, "owner-user", is_admin=False)
+ assert result.user_id == "owner-user"
+ assert result.user_display_name == "Owner Display"
+ assert result.user_email == "owner@test.com"
+ assert result.batch_id == "batch-abc"
+ assert result.origin == "workflows"
+ assert result.destination == "canvas"
+ assert result.session_id == "sess-123"
+ assert result.priority == 10
+
+ def test_admin_sees_all_fields(self, _sample_queue_item: SessionQueueItem):
+ from invokeai.app.api.routers.session_queue import sanitize_queue_item_for_user
+
+ result = sanitize_queue_item_for_user(_sample_queue_item, "admin-user", is_admin=True)
+ assert result.user_id == "owner-user"
+ assert result.user_display_name == "Owner Display"
+ assert result.user_email == "owner@test.com"
+ assert result.batch_id == "batch-abc"
+
+ def test_non_owner_sees_only_status_timestamps_errors(self, _sample_queue_item: SessionQueueItem):
+ from invokeai.app.api.routers.session_queue import sanitize_queue_item_for_user
+
+ result = sanitize_queue_item_for_user(_sample_queue_item, "other-user", is_admin=False)
+
+ # Preserved: item_id, queue_id, status, timestamps
+ assert result.item_id == 42
+ assert result.queue_id == "default"
+ assert result.status == "pending"
+ assert result.created_at == "2026-01-01T00:00:00"
+ assert result.updated_at == "2026-01-01T01:00:00"
+ assert result.started_at == "2026-01-01T00:30:00"
+ assert result.completed_at is None
+
+ # Stripped: errors (may leak file paths, prompts, model names)
+ assert result.error_type is None
+ assert result.error_message is None
+ assert result.error_traceback is None
+
+ # Stripped: user identity
+ assert result.user_id == "redacted"
+ assert result.user_display_name is None
+ assert result.user_email is None
+
+ # Stripped: generation metadata
+ assert result.batch_id == "redacted"
+ assert result.session_id == "redacted"
+ assert result.origin is None
+ assert result.destination is None
+ assert result.priority == 0
+ assert result.field_values is None
+ assert result.retried_from_item_id is None
+ assert result.workflow is None
+ assert result.session.id == "redacted"
+ assert len(result.session.graph.nodes) == 0
+
+ def test_sanitization_does_not_mutate_original(self, _sample_queue_item: SessionQueueItem):
+ from invokeai.app.api.routers.session_queue import sanitize_queue_item_for_user
+
+ sanitize_queue_item_for_user(_sample_queue_item, "other-user", is_admin=False)
+ # Original should be unchanged
+ assert _sample_queue_item.user_id == "owner-user"
+ assert _sample_queue_item.user_email == "owner@test.com"
+ assert _sample_queue_item.batch_id == "batch-abc"
+
+
+# ===========================================================================
+# 7. Recall parameters authorization
+# ===========================================================================
+
+
+class TestRecallParametersAuth:
+ """Tests that recall parameter endpoints enforce authentication."""
+
+ def test_get_recall_parameters_requires_auth(self, enable_multiuser: Any, client: TestClient):
+ r = client.get("/api/v1/recall/default")
+ assert r.status_code == status.HTTP_401_UNAUTHORIZED
+
+ def test_update_recall_parameters_requires_auth(self, enable_multiuser: Any, client: TestClient):
+ r = client.post("/api/v1/recall/default", json={"positive_prompt": "test"})
+ assert r.status_code == status.HTTP_401_UNAUTHORIZED
+
+
+# ===========================================================================
+# 7a2. Recall parameters image access control
+# ===========================================================================
+
+
+class TestRecallImageAccess:
+ """Tests that recall parameter image references are validated for read access."""
+
+ def test_recall_controlnet_with_other_users_image_rejected(
+ self, client: TestClient, mock_invoker: Invoker, user1_token: str, user2_token: str
+ ):
+ """User2 must not be able to reference user1's private image in a control layer."""
+ user1 = mock_invoker.services.users.get_by_email("user1@test.com")
+ assert user1 is not None
+ _save_image(mock_invoker, "victim-ctrl-img", user1.user_id)
+
+ r = client.post(
+ "/api/v1/recall/default",
+ json={"control_layers": [{"model_name": "some-controlnet", "image_name": "victim-ctrl-img"}]},
+ headers=_auth(user2_token),
+ )
+ assert r.status_code == status.HTTP_403_FORBIDDEN
+
+ def test_recall_ip_adapter_with_other_users_image_rejected(
+ self, client: TestClient, mock_invoker: Invoker, user1_token: str, user2_token: str
+ ):
+ """User2 must not be able to reference user1's private image in an IP adapter."""
+ user1 = mock_invoker.services.users.get_by_email("user1@test.com")
+ assert user1 is not None
+ _save_image(mock_invoker, "victim-ip-img", user1.user_id)
+
+ r = client.post(
+ "/api/v1/recall/default",
+ json={"ip_adapters": [{"model_name": "some-ip-adapter", "image_name": "victim-ip-img"}]},
+ headers=_auth(user2_token),
+ )
+ assert r.status_code == status.HTTP_403_FORBIDDEN
+
+ def test_recall_own_image_allowed(self, client: TestClient, mock_invoker: Invoker, user1_token: str):
+ """Owner should be able to reference their own image in recall parameters."""
+ user1 = mock_invoker.services.users.get_by_email("user1@test.com")
+ assert user1 is not None
+ _save_image(mock_invoker, "own-ctrl-img", user1.user_id)
+
+ r = client.post(
+ "/api/v1/recall/default",
+ json={"control_layers": [{"model_name": "some-controlnet", "image_name": "own-ctrl-img"}]},
+ headers=_auth(user1_token),
+ )
+ # Should not be 403 (may fail downstream for other reasons, e.g. model not found)
+ assert r.status_code != status.HTTP_403_FORBIDDEN
+
+ def test_recall_shared_board_image_allowed(
+ self, client: TestClient, mock_invoker: Invoker, user1_token: str, user2_token: str
+ ):
+ """An image on a shared board should be usable in recall by any user."""
+ user1 = mock_invoker.services.users.get_by_email("user1@test.com")
+ assert user1 is not None
+ _save_image(mock_invoker, "shared-recall-img", user1.user_id)
+
+ board_id = _create_board(client, user1_token, "Shared Recall Board")
+ _share_board(client, user1_token, board_id)
+ mock_invoker.services.board_image_records.add_image_to_board(board_id=board_id, image_name="shared-recall-img")
+
+ r = client.post(
+ "/api/v1/recall/default",
+ json={"ip_adapters": [{"model_name": "some-ip-adapter", "image_name": "shared-recall-img"}]},
+ headers=_auth(user2_token),
+ )
+ assert r.status_code != status.HTTP_403_FORBIDDEN
+
+ def test_recall_admin_can_reference_any_image(
+ self, client: TestClient, mock_invoker: Invoker, admin_token: str, user1_token: str
+ ):
+ """Admin should be able to reference any user's image."""
+ user1 = mock_invoker.services.users.get_by_email("user1@test.com")
+ assert user1 is not None
+ _save_image(mock_invoker, "admin-recall-img", user1.user_id)
+
+ r = client.post(
+ "/api/v1/recall/default",
+ json={"control_layers": [{"model_name": "some-controlnet", "image_name": "admin-recall-img"}]},
+ headers=_auth(admin_token),
+ )
+ assert r.status_code != status.HTTP_403_FORBIDDEN
+
+
+# ===========================================================================
+# 7b. Recall parameters cross-user isolation
+# ===========================================================================
+
+
+class TestRecallParametersIsolation:
+ """Tests that recall parameters are scoped per-user, not globally by queue_id."""
+
+ def test_user1_write_does_not_leak_to_user2(
+ self, client: TestClient, mock_invoker: Invoker, user1_token: str, user2_token: str
+ ):
+ """User1 sets a recall parameter; user2 should not see it in client state."""
+ # user1 writes a recall parameter
+ r = client.post(
+ "/api/v1/recall/default",
+ json={"positive_prompt": "user1 secret prompt"},
+ headers=_auth(user1_token),
+ )
+ assert r.status_code == 200
+
+ # Verify that user1's data is stored under user1's user_id, not the queue_id
+ user1 = mock_invoker.services.users.get_by_email("user1@test.com")
+ user2 = mock_invoker.services.users.get_by_email("user2@test.com")
+ assert user1 is not None
+ assert user2 is not None
+
+ # user1 should have the value
+ val = mock_invoker.services.client_state_persistence.get_by_key(user1.user_id, "recall_positive_prompt")
+ assert val is not None
+ assert "user1 secret prompt" in val
+
+ # user2 should NOT have the value
+ val2 = mock_invoker.services.client_state_persistence.get_by_key(user2.user_id, "recall_positive_prompt")
+ assert val2 is None
+
+ def test_two_users_independent_state(
+ self, client: TestClient, mock_invoker: Invoker, user1_token: str, user2_token: str
+ ):
+ """Both users can write recall params independently without overwriting each other."""
+ r1 = client.post(
+ "/api/v1/recall/default",
+ json={"positive_prompt": "prompt from user1"},
+ headers=_auth(user1_token),
+ )
+ assert r1.status_code == 200
+
+ r2 = client.post(
+ "/api/v1/recall/default",
+ json={"positive_prompt": "prompt from user2"},
+ headers=_auth(user2_token),
+ )
+ assert r2.status_code == 200
+
+ user1 = mock_invoker.services.users.get_by_email("user1@test.com")
+ user2 = mock_invoker.services.users.get_by_email("user2@test.com")
+ assert user1 is not None
+ assert user2 is not None
+
+ val1 = mock_invoker.services.client_state_persistence.get_by_key(user1.user_id, "recall_positive_prompt")
+ val2 = mock_invoker.services.client_state_persistence.get_by_key(user2.user_id, "recall_positive_prompt")
+ assert val1 is not None and "prompt from user1" in val1
+ assert val2 is not None and "prompt from user2" in val2
+
+
+# ===========================================================================
+# 9. Recall parameters event user scoping
+# ===========================================================================
+
+
+class TestRecallParametersEventScoping:
+ """Tests that RecallParametersUpdatedEvent carries user_id for targeted delivery."""
+
+ def test_event_includes_user_id(self):
+ """RecallParametersUpdatedEvent.build() must set user_id so the socket handler
+ can route the event to the correct user room instead of broadcasting."""
+ from invokeai.app.services.events.events_common import RecallParametersUpdatedEvent
+
+ event = RecallParametersUpdatedEvent.build(
+ queue_id="default",
+ user_id="user-abc",
+ parameters={"positive_prompt": "test"},
+ )
+ assert event.queue_id == "default"
+ assert event.user_id == "user-abc"
+ assert event.parameters == {"positive_prompt": "test"}
+
+ def test_event_not_broadcast_to_all_queue_subscribers(self):
+ """RecallParametersUpdatedEvent must have a user_id field so _handle_queue_event
+ in sockets.py can route it to the owner room + admin room, not the queue room."""
+ from invokeai.app.services.events.events_common import RecallParametersUpdatedEvent
+
+ event = RecallParametersUpdatedEvent.build(
+ queue_id="default",
+ user_id="owner-123",
+ parameters={"seed": 42},
+ )
+ # The event must carry user_id; without it the socket handler would
+ # fall through to the generic else branch and broadcast to all subscribers
+ assert hasattr(event, "user_id")
+ assert event.user_id == "owner-123"
+
+
+# ===========================================================================
+# 10. Queue status endpoint scoping
+# ===========================================================================
+
+
+class TestQueueStatusScoping:
+ """Tests that queue status, batch status, and counts_by_destination
+ endpoints scope data to the current user for non-admin callers."""
+
+ def test_get_queue_status_hides_current_item_for_non_owner(self):
+ """get_queue_status() must not expose current item details to non-owner, non-admin users."""
+ from invokeai.app.services.session_queue.session_queue_common import SessionQueueStatus
+
+ # Simulate a status where the current item belongs to another user
+ # When user_id is provided and doesn't match, item details should be None
+ status_obj = SessionQueueStatus(
+ queue_id="default",
+ item_id=None, # hidden because user doesn't own current item
+ session_id=None,
+ batch_id=None,
+ pending=2,
+ in_progress=0,
+ completed=1,
+ failed=0,
+ canceled=0,
+ total=3,
+ )
+ # Verify the model accepts None for item details
+ assert status_obj.item_id is None
+ assert status_obj.session_id is None
+ assert status_obj.batch_id is None
+
+ def test_session_queue_status_no_user_fields(self):
+ """SessionQueueStatus should not have user_pending/user_in_progress fields anymore.
+ Non-admin users now get their own counts in the main pending/in_progress fields."""
+ from invokeai.app.services.session_queue.session_queue_common import SessionQueueStatus
+
+ fields = set(SessionQueueStatus.model_fields.keys())
+ assert "user_pending" not in fields
+ assert "user_in_progress" not in fields
+
+
+# ===========================================================================
+# 10b. Model install job authorization
+# ===========================================================================
+
+
+class TestModelInstallAuth:
+ """Tests that model install job endpoints require admin authentication."""
+
+ def test_list_model_installs_requires_auth(self, enable_multiuser: Any, client: TestClient):
+ r = client.get("/api/v2/models/install")
+ assert r.status_code == status.HTTP_401_UNAUTHORIZED
+
+ def test_get_model_install_job_requires_auth(self, enable_multiuser: Any, client: TestClient):
+ r = client.get("/api/v2/models/install/1")
+ assert r.status_code == status.HTTP_401_UNAUTHORIZED
+
+ def test_pause_model_install_requires_auth(self, enable_multiuser: Any, client: TestClient):
+ r = client.post("/api/v2/models/install/1/pause")
+ assert r.status_code == status.HTTP_401_UNAUTHORIZED
+
+ def test_resume_model_install_requires_auth(self, enable_multiuser: Any, client: TestClient):
+ r = client.post("/api/v2/models/install/1/resume")
+ assert r.status_code == status.HTTP_401_UNAUTHORIZED
+
+ def test_restart_failed_model_install_requires_auth(self, enable_multiuser: Any, client: TestClient):
+ r = client.post("/api/v2/models/install/1/restart_failed")
+ assert r.status_code == status.HTTP_401_UNAUTHORIZED
+
+ def test_restart_model_install_file_requires_auth(self, enable_multiuser: Any, client: TestClient):
+ r = client.post("/api/v2/models/install/1/restart_file", json="https://example.com/model.safetensors")
+ assert r.status_code == status.HTTP_401_UNAUTHORIZED
+
+ def test_non_admin_cannot_list_model_installs(self, enable_multiuser: Any, client: TestClient, user1_token: str):
+ r = client.get("/api/v2/models/install", headers=_auth(user1_token))
+ assert r.status_code == status.HTTP_403_FORBIDDEN
+
+ def test_non_admin_cannot_pause_model_install(self, enable_multiuser: Any, client: TestClient, user1_token: str):
+ r = client.post("/api/v2/models/install/1/pause", headers=_auth(user1_token))
+ assert r.status_code == status.HTTP_403_FORBIDDEN
+
+
+# ===========================================================================
+# 11. Bulk download access control
+# ===========================================================================
+
+
+class TestBulkDownloadAccessControl:
+ """Tests that bulk download endpoints enforce image/board read access and
+ that the fetch endpoint verifies ownership of the zip file."""
+
+ @pytest.fixture(autouse=True)
+ def _mock_background_tasks(self, monkeypatch: Any):
+ """Prevent BackgroundTasks.add_task from actually running the handler,
+ which would fail because image_files is None in the test fixture."""
+ from fastapi import BackgroundTasks
+
+ monkeypatch.setattr(BackgroundTasks, "add_task", lambda *args, **kwargs: None)
+
+ def test_bulk_download_by_image_names_rejected_for_non_owner(
+ self, client: TestClient, mock_invoker: Invoker, user1_token: str, user2_token: str
+ ):
+ """User2 must not be able to bulk-download images owned by user1."""
+ user1 = mock_invoker.services.users.get_by_email("user1@test.com")
+ assert user1 is not None
+ _save_image(mock_invoker, "user1-private-dl", user1.user_id)
+
+ r = client.post(
+ "/api/v1/images/download",
+ json={"image_names": ["user1-private-dl"]},
+ headers=_auth(user2_token),
+ )
+ assert r.status_code == status.HTTP_403_FORBIDDEN
+
+ def test_bulk_download_by_image_names_allowed_for_owner(
+ self, client: TestClient, mock_invoker: Invoker, user1_token: str
+ ):
+ """Owner should be able to bulk-download their own images."""
+ user1 = mock_invoker.services.users.get_by_email("user1@test.com")
+ assert user1 is not None
+ _save_image(mock_invoker, "user1-own-dl", user1.user_id)
+
+ r = client.post(
+ "/api/v1/images/download",
+ json={"image_names": ["user1-own-dl"]},
+ headers=_auth(user1_token),
+ )
+ assert r.status_code == status.HTTP_202_ACCEPTED
+
+ def test_bulk_download_by_board_rejected_for_private_board(
+ self, client: TestClient, mock_invoker: Invoker, user1_token: str, user2_token: str
+ ):
+ """User2 must not be able to bulk-download from user1's private board."""
+ board_id = _create_board(client, user1_token, "Private DL Board")
+
+ r = client.post(
+ "/api/v1/images/download",
+ json={"board_id": board_id},
+ headers=_auth(user2_token),
+ )
+ assert r.status_code == status.HTTP_403_FORBIDDEN
+
+ def test_bulk_download_by_shared_board_allowed(
+ self, client: TestClient, mock_invoker: Invoker, user1_token: str, user2_token: str
+ ):
+ """User2 should be able to bulk-download from user1's shared board."""
+ board_id = _create_board(client, user1_token, "Shared DL Board")
+ _share_board(client, user1_token, board_id)
+
+ r = client.post(
+ "/api/v1/images/download",
+ json={"board_id": board_id},
+ headers=_auth(user2_token),
+ )
+ assert r.status_code == status.HTTP_202_ACCEPTED
+
+ def test_admin_can_bulk_download_any_images(
+ self, client: TestClient, mock_invoker: Invoker, admin_token: str, user1_token: str
+ ):
+ """Admin should be able to bulk-download any user's images."""
+ user1 = mock_invoker.services.users.get_by_email("user1@test.com")
+ assert user1 is not None
+ _save_image(mock_invoker, "user1-admin-dl", user1.user_id)
+
+ r = client.post(
+ "/api/v1/images/download",
+ json={"image_names": ["user1-admin-dl"]},
+ headers=_auth(admin_token),
+ )
+ assert r.status_code == status.HTTP_202_ACCEPTED
+
+ def test_bulk_download_events_carry_user_id(self):
+ """BulkDownloadEventBase must carry user_id so events can be routed privately."""
+ from invokeai.app.services.events.events_common import (
+ BulkDownloadCompleteEvent,
+ BulkDownloadErrorEvent,
+ BulkDownloadEventBase,
+ BulkDownloadStartedEvent,
+ )
+
+ assert "user_id" in BulkDownloadEventBase.model_fields
+
+ started = BulkDownloadStartedEvent.build("default", "item-1", "item-1.zip", user_id="owner-abc")
+ assert started.user_id == "owner-abc"
+
+ complete = BulkDownloadCompleteEvent.build("default", "item-2", "item-2.zip", user_id="owner-abc")
+ assert complete.user_id == "owner-abc"
+
+ error = BulkDownloadErrorEvent.build("default", "item-3", "item-3.zip", "oops", user_id="owner-abc")
+ assert error.user_id == "owner-abc"
+
+ def test_bulk_download_event_not_emitted_to_shared_default_room(self, mock_invoker: Invoker, monkeypatch: Any):
+ """Bulk download capability tokens must not be broadcast to the shared default room."""
+ import asyncio
+ from unittest.mock import AsyncMock
+
+ from fastapi import FastAPI
+
+ from invokeai.app.api.sockets import SocketIO
+ from invokeai.app.services.events.events_common import BulkDownloadCompleteEvent
+
+ mock_deps = MockApiDependencies(mock_invoker)
+ monkeypatch.setattr("invokeai.app.api.dependencies.ApiDependencies", mock_deps)
+
+ fastapi_app = FastAPI()
+ socketio = SocketIO(fastapi_app)
+
+ event = BulkDownloadCompleteEvent.build("default", "item-x", "item-x.zip", user_id="owner-xyz")
+
+ mock_emit = AsyncMock()
+ socketio._sio.emit = mock_emit
+
+ asyncio.run(socketio._handle_bulk_image_download_event(("bulk_download_complete", event)))
+
+ rooms_emitted_to = [call.kwargs.get("room") for call in mock_emit.call_args_list]
+ assert "default" not in rooms_emitted_to
+ assert "user:owner-xyz" in rooms_emitted_to
+
+
+# ===========================================================================
+# 12. WebSocket authentication and event scoping
+# ===========================================================================
+
+
+class TestWebSocketAuth:
+ """Tests that anonymous WebSocket clients cannot subscribe to queue rooms
+ in multiuser mode, and that queue item events are scoped to the owner +
+ admin rooms instead of being broadcast to the full queue room."""
+
+ @pytest.fixture
+ def socketio(self, mock_invoker: Invoker, monkeypatch: Any):
+ """Create a SocketIO instance wired to the mock invoker's configuration."""
+ from fastapi import FastAPI
+
+ from invokeai.app.api.sockets import SocketIO
+
+ # The SocketIO connect/sub handlers look up ApiDependencies.invoker.services.configuration.multiuser
+ # at request time. Patch it to point at the mock invoker.
+ mock_deps = MockApiDependencies(mock_invoker)
+ monkeypatch.setattr("invokeai.app.api.dependencies.ApiDependencies", mock_deps)
+
+ fastapi_app = FastAPI()
+ return SocketIO(fastapi_app)
+
+ def test_connect_rejected_without_token_in_multiuser_mode(self, socketio: Any, mock_invoker: Invoker) -> None:
+ """In multiuser mode, _handle_connect must return False when no valid token is provided."""
+ import asyncio
+
+ mock_invoker.services.configuration.multiuser = True
+
+ result = asyncio.run(socketio._handle_connect("sid-anon-1", environ={}, auth=None))
+ assert result is False
+ # The socket must not be recorded in the users dict
+ assert "sid-anon-1" not in socketio._socket_users
+
+ def test_connect_rejected_with_invalid_token_in_multiuser_mode(
+ self, socketio: Any, mock_invoker: Invoker, setup_jwt_secret: None
+ ) -> None:
+ """An invalid/garbage token in multiuser mode must still be rejected."""
+ import asyncio
+
+ mock_invoker.services.configuration.multiuser = True
+
+ result = asyncio.run(socketio._handle_connect("sid-bad-1", environ={}, auth={"token": "not-a-real-token"}))
+ assert result is False
+ assert "sid-bad-1" not in socketio._socket_users
+
+ def test_connect_accepted_without_token_in_single_user_mode(self, socketio: Any, mock_invoker: Invoker) -> None:
+ """In single-user mode, the socket handler should accept unauthenticated connections
+ as the system admin user (matching how the REST API's get_current_user_or_default behaves)."""
+ import asyncio
+
+ mock_invoker.services.configuration.multiuser = False
+
+ result = asyncio.run(socketio._handle_connect("sid-single-1", environ={}, auth=None))
+ assert result is True
+ assert socketio._socket_users["sid-single-1"]["user_id"] == "system"
+ assert socketio._socket_users["sid-single-1"]["is_admin"] is True
+
+ def test_connect_accepted_with_valid_token_in_multiuser_mode(
+ self,
+ socketio: Any,
+ mock_invoker: Invoker,
+ setup_jwt_secret: None,
+ ) -> None:
+ """A valid token in multiuser mode should be accepted with the correct user identity."""
+ import asyncio
+
+ from invokeai.app.services.auth.token_service import TokenData, create_access_token
+ from invokeai.app.services.users.users_common import UserCreateRequest
+
+ mock_invoker.services.configuration.multiuser = True
+
+ # Create the user in the database so the active-user check passes
+ user = mock_invoker.services.users.create(
+ UserCreateRequest(email="real@test.com", display_name="Real User", password="Test1234!@#$")
+ )
+ token = create_access_token(TokenData(user_id=user.user_id, email=user.email, is_admin=False))
+
+ result = asyncio.run(socketio._handle_connect("sid-good-1", environ={}, auth={"token": token}))
+ assert result is True
+ assert socketio._socket_users["sid-good-1"]["user_id"] == user.user_id
+ assert socketio._socket_users["sid-good-1"]["is_admin"] is False
+
+ def test_connect_rejected_for_deleted_user_in_multiuser_mode(
+ self, socketio: Any, mock_invoker: Invoker, setup_jwt_secret: None
+ ) -> None:
+ """A structurally valid JWT for a user that no longer exists in the database
+ must be rejected. This mirrors the REST auth check in auth_dependencies.py:53-58."""
+ import asyncio
+
+ from invokeai.app.services.auth.token_service import TokenData, create_access_token
+
+ mock_invoker.services.configuration.multiuser = True
+ # Create a token for a user_id that was never created in the user service
+ token = create_access_token(TokenData(user_id="deleted-user-999", email="gone@test.com", is_admin=False))
+
+ result = asyncio.run(socketio._handle_connect("sid-deleted-1", environ={}, auth={"token": token}))
+ assert result is False
+ assert "sid-deleted-1" not in socketio._socket_users
+
+ def test_connect_rejected_for_inactive_user_in_multiuser_mode(
+ self, socketio: Any, mock_invoker: Invoker, setup_jwt_secret: None
+ ) -> None:
+ """A structurally valid JWT for a deactivated user must be rejected even though
+ the token itself has not expired."""
+ import asyncio
+
+ from invokeai.app.services.auth.token_service import TokenData, create_access_token
+ from invokeai.app.services.users.users_common import UserCreateRequest
+
+ mock_invoker.services.configuration.multiuser = True
+
+ # Create a real user, then deactivate them
+ user = mock_invoker.services.users.create(
+ UserCreateRequest(email="inactive@test.com", display_name="Inactive", password="Test1234!@#$")
+ )
+ token = create_access_token(TokenData(user_id=user.user_id, email=user.email, is_admin=False))
+
+ # Deactivate the user
+ from invokeai.app.services.users.users_common import UserUpdateRequest
+
+ mock_invoker.services.users.update(user.user_id, UserUpdateRequest(is_active=False))
+
+ result = asyncio.run(socketio._handle_connect("sid-inactive-1", environ={}, auth={"token": token}))
+ assert result is False
+ assert "sid-inactive-1" not in socketio._socket_users
+
+ def test_sub_queue_refuses_unknown_socket_in_multiuser_mode(self, socketio: Any, mock_invoker: Invoker) -> None:
+ """If a socket somehow reaches _handle_sub_queue without a recorded identity
+ in multiuser mode (e.g. bug, race), it must be refused rather than falling back
+ to an anonymous system user who could then observe queue item events."""
+ import asyncio
+
+ mock_invoker.services.configuration.multiuser = True
+
+ # Call sub_queue without a corresponding connect — the sid is unknown.
+ asyncio.run(socketio._handle_sub_queue("sid-ghost-1", {"queue_id": "default"}))
+
+ # The ghost socket must not have been added to the internal users dict
+ assert "sid-ghost-1" not in socketio._socket_users
+
+ def test_queue_item_status_changed_has_user_id(self) -> None:
+ """QueueItemStatusChangedEvent must carry user_id so _handle_queue_event can
+ route it to the owner + admin rooms instead of the public queue room. Without
+ this field the event falls through to the generic broadcast branch and any
+ subscriber to the queue can observe cross-user queue activity."""
+ from invokeai.app.services.events.events_common import (
+ InvocationEventBase,
+ QueueItemEventBase,
+ QueueItemStatusChangedEvent,
+ )
+
+ # The event base carries a user_id field
+ assert "user_id" in QueueItemEventBase.model_fields
+ # QueueItemStatusChangedEvent inherits it
+ assert "user_id" in QueueItemStatusChangedEvent.model_fields
+ # It is NOT an InvocationEventBase (so the generic QueueItemEventBase branch
+ # in _handle_queue_event must also handle it privately)
+ assert not issubclass(QueueItemStatusChangedEvent, InvocationEventBase)
+
+ def test_batch_enqueued_event_carries_user_id(self) -> None:
+ """BatchEnqueuedEvent must carry user_id so it can be routed privately to the
+ owner and admin rooms. Otherwise a subscriber on the same queue_id would see
+ every other user's batch_id, origin and enqueued counts."""
+ from invokeai.app.services.events.events_common import BatchEnqueuedEvent
+ from invokeai.app.services.session_queue.session_queue_common import (
+ Batch,
+ EnqueueBatchResult,
+ )
+ from invokeai.app.services.shared.graph import Graph
+
+ enqueue_result = EnqueueBatchResult(
+ queue_id="default",
+ enqueued=3,
+ requested=3,
+ batch=Batch(batch_id="batch-xyz", origin="workflows", graph=Graph()),
+ priority=0,
+ item_ids=[1, 2, 3],
+ )
+ event = BatchEnqueuedEvent.build(enqueue_result, user_id="owner-123")
+ assert event.user_id == "owner-123"
+ assert event.batch_id == "batch-xyz"
+ assert event.queue_id == "default"
+
+ def test_queue_item_status_changed_routed_privately(self, socketio: Any) -> None:
+ """Verify that _handle_queue_event emits QueueItemStatusChangedEvent ONLY to
+ user:{user_id} and admin rooms, never to the queue_id room."""
+ import asyncio
+ from unittest.mock import AsyncMock
+
+ from invokeai.app.services.events.events_common import QueueItemStatusChangedEvent
+ from invokeai.app.services.session_queue.session_queue_common import (
+ BatchStatus,
+ SessionQueueStatus,
+ )
+
+ event = QueueItemStatusChangedEvent(
+ queue_id="default",
+ item_id=1,
+ batch_id="batch-private",
+ origin="workflows",
+ destination="canvas",
+ user_id="owner-xyz",
+ session_id="sess-private",
+ status="in_progress",
+ created_at="2026-01-01T00:00:00",
+ updated_at="2026-01-01T00:01:00",
+ started_at="2026-01-01T00:00:30",
+ completed_at=None,
+ batch_status=BatchStatus(
+ queue_id="default",
+ batch_id="batch-private",
+ origin="workflows",
+ destination="canvas",
+ pending=0,
+ in_progress=1,
+ completed=0,
+ failed=0,
+ canceled=0,
+ total=1,
+ ),
+ queue_status=SessionQueueStatus(
+ queue_id="default",
+ item_id=1,
+ session_id="sess-private",
+ batch_id="batch-private",
+ pending=0,
+ in_progress=1,
+ completed=0,
+ failed=0,
+ canceled=0,
+ total=1,
+ ),
+ )
+
+ mock_emit = AsyncMock()
+ socketio._sio.emit = mock_emit
+
+ asyncio.run(socketio._handle_queue_event(("queue_item_status_changed", event)))
+
+ rooms_emitted_to = [call.kwargs.get("room") for call in mock_emit.call_args_list]
+ assert "user:owner-xyz" in rooms_emitted_to
+ assert "admin" in rooms_emitted_to
+ # CRITICAL: must NOT emit to the queue_id room — that would leak to other users
+ assert "default" not in rooms_emitted_to
+
+ def test_batch_enqueued_routed_privately(self, socketio: Any) -> None:
+ """Verify that _handle_queue_event emits BatchEnqueuedEvent ONLY to
+ user:{user_id} and admin rooms, never to the queue_id room."""
+ import asyncio
+ from unittest.mock import AsyncMock
+
+ from invokeai.app.services.events.events_common import BatchEnqueuedEvent
+ from invokeai.app.services.session_queue.session_queue_common import (
+ Batch,
+ EnqueueBatchResult,
+ )
+ from invokeai.app.services.shared.graph import Graph
+
+ enqueue_result = EnqueueBatchResult(
+ queue_id="default",
+ enqueued=5,
+ requested=5,
+ batch=Batch(batch_id="batch-pvt", origin="workflows", graph=Graph()),
+ priority=0,
+ item_ids=[10, 11, 12, 13, 14],
+ )
+ event = BatchEnqueuedEvent.build(enqueue_result, user_id="owner-zzz")
+
+ mock_emit = AsyncMock()
+ socketio._sio.emit = mock_emit
+
+ asyncio.run(socketio._handle_queue_event(("batch_enqueued", event)))
+
+ rooms_emitted_to = [call.kwargs.get("room") for call in mock_emit.call_args_list]
+ assert "user:owner-zzz" in rooms_emitted_to
+ assert "admin" in rooms_emitted_to
+ assert "default" not in rooms_emitted_to
+
+ def test_queue_cleared_still_broadcast(self, socketio: Any) -> None:
+ """QueueClearedEvent does not carry user identity and should still be broadcast
+ to all queue subscribers — this is a sanity check that we haven't over-scoped."""
+ import asyncio
+ from unittest.mock import AsyncMock
+
+ from invokeai.app.services.events.events_common import QueueClearedEvent
+
+ event = QueueClearedEvent.build(queue_id="default")
+
+ mock_emit = AsyncMock()
+ socketio._sio.emit = mock_emit
+
+ asyncio.run(socketio._handle_queue_event(("queue_cleared", event)))
+
+ rooms_emitted_to = [call.kwargs.get("room") for call in mock_emit.call_args_list]
+ assert "default" in rooms_emitted_to
diff --git a/tests/app/routers/test_session_queue_sanitization.py b/tests/app/routers/test_session_queue_sanitization.py
index 1b2262d02ee..1cd7a4953db 100644
--- a/tests/app/routers/test_session_queue_sanitization.py
+++ b/tests/app/routers/test_session_queue_sanitization.py
@@ -100,11 +100,21 @@ def test_sanitize_queue_item_for_different_user(sample_session_queue_item):
# Non-admin viewing another user's item should have sanitized data
assert result.field_values is None
assert result.workflow is None
- # Session should be replaced with empty graph
+ # Session should be replaced with empty/redacted graph
assert result.session.graph.nodes is not None
assert len(result.session.graph.nodes) == 0
- # Session ID should be preserved
- assert result.session.id == "test_session"
+ assert result.session.id == "redacted"
+ # Identity and batch fields should be redacted
+ assert result.user_id == "redacted"
+ assert result.batch_id == "redacted"
+ assert result.session_id == "redacted"
+ assert result.user_display_name is None
+ assert result.user_email is None
+ assert result.origin is None
+ assert result.destination is None
+ assert result.error_type is None
+ assert result.error_message is None
+ assert result.error_traceback is None
def test_sanitize_preserves_non_sensitive_fields(sample_session_queue_item):
@@ -115,15 +125,18 @@ def test_sanitize_preserves_non_sensitive_fields(sample_session_queue_item):
is_admin=False,
)
- # These fields should be preserved
+ # Non-sensitive fields should be preserved
assert result.item_id == 1
assert result.status == "pending"
- assert result.batch_id == "batch_123"
- assert result.session_id == "session_123"
assert result.queue_id == "default"
- assert result.user_id == "user_123"
- assert result.user_display_name == "Test User"
- assert result.user_email == "test@example.com"
+ assert result.created_at is not None
+ assert result.updated_at is not None
+ # Sensitive fields should be redacted for non-owner non-admin
+ assert result.batch_id == "redacted"
+ assert result.session_id == "redacted"
+ assert result.user_id == "redacted"
+ assert result.user_display_name is None
+ assert result.user_email is None
def test_sanitize_system_user_item_for_non_admin(sample_session_queue_item):
diff --git a/tests/app/routers/test_workflows_multiuser.py b/tests/app/routers/test_workflows_multiuser.py
new file mode 100644
index 00000000000..28b301e18e3
--- /dev/null
+++ b/tests/app/routers/test_workflows_multiuser.py
@@ -0,0 +1,334 @@
+"""Tests for multiuser workflow library functionality."""
+
+import logging
+from typing import Any
+from unittest.mock import MagicMock
+
+import pytest
+from fastapi import status
+from fastapi.testclient import TestClient
+
+from invokeai.app.api.dependencies import ApiDependencies
+from invokeai.app.api_app import app
+from invokeai.app.services.config.config_default import InvokeAIAppConfig
+from invokeai.app.services.invocation_services import InvocationServices
+from invokeai.app.services.invoker import Invoker
+from invokeai.app.services.users.users_common import UserCreateRequest
+from invokeai.app.services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage
+from invokeai.backend.util.logging import InvokeAILogger
+from tests.fixtures.sqlite_database import create_mock_sqlite_database
+
+
+class MockApiDependencies(ApiDependencies):
+ invoker: Invoker
+
+ def __init__(self, invoker: Invoker) -> None:
+ self.invoker = invoker
+
+
+WORKFLOW_BODY = {
+ "name": "Test Workflow",
+ "author": "",
+ "description": "A test workflow",
+ "version": "1.0.0",
+ "contact": "",
+ "tags": "",
+ "notes": "",
+ "nodes": [],
+ "edges": [],
+ "exposedFields": [],
+ "meta": {"version": "3.0.0", "category": "user"},
+ "id": None,
+ "form_fields": [],
+}
+
+
+@pytest.fixture
+def setup_jwt_secret():
+ from invokeai.app.services.auth.token_service import set_jwt_secret
+
+ set_jwt_secret("test-secret-key-for-unit-tests-only-do-not-use-in-production")
+
+
+@pytest.fixture
+def client():
+ return TestClient(app)
+
+
+@pytest.fixture
+def mock_services() -> InvocationServices:
+ from invokeai.app.services.board_image_records.board_image_records_sqlite import SqliteBoardImageRecordStorage
+ from invokeai.app.services.board_records.board_records_sqlite import SqliteBoardRecordStorage
+ from invokeai.app.services.boards.boards_default import BoardService
+ from invokeai.app.services.bulk_download.bulk_download_default import BulkDownloadService
+ from invokeai.app.services.client_state_persistence.client_state_persistence_sqlite import (
+ ClientStatePersistenceSqlite,
+ )
+ from invokeai.app.services.image_records.image_records_sqlite import SqliteImageRecordStorage
+ from invokeai.app.services.images.images_default import ImageService
+ from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
+ from invokeai.app.services.invocation_stats.invocation_stats_default import InvocationStatsService
+ from invokeai.app.services.users.users_default import UserService
+ from tests.test_nodes import TestEventService
+
+ configuration = InvokeAIAppConfig(use_memory_db=True, node_cache_size=0)
+ logger = InvokeAILogger.get_logger()
+ db = create_mock_sqlite_database(configuration, logger)
+
+ return InvocationServices(
+ board_image_records=SqliteBoardImageRecordStorage(db=db),
+ board_images=None, # type: ignore
+ board_records=SqliteBoardRecordStorage(db=db),
+ boards=BoardService(),
+ bulk_download=BulkDownloadService(),
+ configuration=configuration,
+ events=TestEventService(),
+ image_files=None, # type: ignore
+ image_records=SqliteImageRecordStorage(db=db),
+ images=ImageService(),
+ invocation_cache=MemoryInvocationCache(max_cache_size=0),
+ logger=logging, # type: ignore
+ model_images=None, # type: ignore
+ model_manager=None, # type: ignore
+ download_queue=None, # type: ignore
+ names=None, # type: ignore
+ performance_statistics=InvocationStatsService(),
+ session_processor=None, # type: ignore
+ session_queue=None, # type: ignore
+ urls=None, # type: ignore
+ workflow_records=SqliteWorkflowRecordsStorage(db=db),
+ tensors=None, # type: ignore
+ conditioning=None, # type: ignore
+ style_preset_records=None, # type: ignore
+ style_preset_image_files=None, # type: ignore
+ workflow_thumbnails=None, # type: ignore
+ model_relationship_records=None, # type: ignore
+ model_relationships=None, # type: ignore
+ client_state_persistence=ClientStatePersistenceSqlite(db=db),
+ users=UserService(db),
+ )
+
+
+def create_test_user(mock_invoker: Invoker, email: str, display_name: str, is_admin: bool = False) -> str:
+ user_service = mock_invoker.services.users
+ user_data = UserCreateRequest(email=email, display_name=display_name, password="TestPass123", is_admin=is_admin)
+ user = user_service.create(user_data)
+ return user.user_id
+
+
+def get_user_token(client: TestClient, email: str) -> str:
+ response = client.post(
+ "/api/v1/auth/login",
+ json={"email": email, "password": "TestPass123", "remember_me": False},
+ )
+ assert response.status_code == 200
+ return response.json()["token"]
+
+
+@pytest.fixture
+def enable_multiuser(monkeypatch: Any, mock_invoker: Invoker):
+ mock_invoker.services.configuration.multiuser = True
+ mock_workflow_thumbnails = MagicMock()
+ mock_workflow_thumbnails.get_url.return_value = None
+ mock_invoker.services.workflow_thumbnails = mock_workflow_thumbnails
+
+ mock_deps = MockApiDependencies(mock_invoker)
+ monkeypatch.setattr("invokeai.app.api.routers.auth.ApiDependencies", mock_deps)
+ monkeypatch.setattr("invokeai.app.api.auth_dependencies.ApiDependencies", mock_deps)
+ monkeypatch.setattr("invokeai.app.api.routers.workflows.ApiDependencies", mock_deps)
+ yield
+
+
+@pytest.fixture
+def admin_token(setup_jwt_secret: None, enable_multiuser: Any, mock_invoker: Invoker, client: TestClient):
+ create_test_user(mock_invoker, "admin@test.com", "Admin", is_admin=True)
+ return get_user_token(client, "admin@test.com")
+
+
+@pytest.fixture
+def user1_token(enable_multiuser: Any, mock_invoker: Invoker, client: TestClient, admin_token: str):
+ create_test_user(mock_invoker, "user1@test.com", "User One", is_admin=False)
+ return get_user_token(client, "user1@test.com")
+
+
+@pytest.fixture
+def user2_token(enable_multiuser: Any, mock_invoker: Invoker, client: TestClient, admin_token: str):
+ create_test_user(mock_invoker, "user2@test.com", "User Two", is_admin=False)
+ return get_user_token(client, "user2@test.com")
+
+
+def create_workflow(client: TestClient, token: str) -> str:
+ response = client.post(
+ "/api/v1/workflows/",
+ json={"workflow": WORKFLOW_BODY},
+ headers={"Authorization": f"Bearer {token}"},
+ )
+ assert response.status_code == 200, response.text
+ return response.json()["workflow_id"]
+
+
+# ---------------------------------------------------------------------------
+# Auth tests
+# ---------------------------------------------------------------------------
+
+
+def test_list_workflows_requires_auth(enable_multiuser: Any, client: TestClient):
+ response = client.get("/api/v1/workflows/")
+ assert response.status_code == status.HTTP_401_UNAUTHORIZED
+
+
+def test_create_workflow_requires_auth(enable_multiuser: Any, client: TestClient):
+ response = client.post("/api/v1/workflows/", json={"workflow": WORKFLOW_BODY})
+ assert response.status_code == status.HTTP_401_UNAUTHORIZED
+
+
+# ---------------------------------------------------------------------------
+# Ownership isolation
+# ---------------------------------------------------------------------------
+
+
+def test_workflows_are_isolated_between_users(client: TestClient, user1_token: str, user2_token: str):
+ """Users should only see their own workflows in list."""
+ # user1 creates a workflow
+ create_workflow(client, user1_token)
+
+ # user1 can see it
+ r1 = client.get("/api/v1/workflows/?categories=user", headers={"Authorization": f"Bearer {user1_token}"})
+ assert r1.status_code == 200
+ assert r1.json()["total"] == 1
+
+ # user2 cannot see user1's workflow
+ r2 = client.get("/api/v1/workflows/?categories=user", headers={"Authorization": f"Bearer {user2_token}"})
+ assert r2.status_code == 200
+ assert r2.json()["total"] == 0
+
+
+def test_user_cannot_delete_another_users_workflow(client: TestClient, user1_token: str, user2_token: str):
+ workflow_id = create_workflow(client, user1_token)
+ response = client.delete(
+ f"/api/v1/workflows/i/{workflow_id}",
+ headers={"Authorization": f"Bearer {user2_token}"},
+ )
+ assert response.status_code == status.HTTP_403_FORBIDDEN
+
+
+def test_user_cannot_update_another_users_workflow(client: TestClient, user1_token: str, user2_token: str):
+ workflow_id = create_workflow(client, user1_token)
+ updated = {**WORKFLOW_BODY, "id": workflow_id, "name": "Hijacked"}
+ response = client.patch(
+ f"/api/v1/workflows/i/{workflow_id}",
+ json={"workflow": updated},
+ headers={"Authorization": f"Bearer {user2_token}"},
+ )
+ assert response.status_code == status.HTTP_403_FORBIDDEN
+
+
+def test_owner_can_delete_own_workflow(client: TestClient, user1_token: str):
+ workflow_id = create_workflow(client, user1_token)
+ response = client.delete(
+ f"/api/v1/workflows/i/{workflow_id}",
+ headers={"Authorization": f"Bearer {user1_token}"},
+ )
+ assert response.status_code == 200
+
+
+def test_admin_can_delete_any_workflow(client: TestClient, admin_token: str, user1_token: str):
+ workflow_id = create_workflow(client, user1_token)
+ response = client.delete(
+ f"/api/v1/workflows/i/{workflow_id}",
+ headers={"Authorization": f"Bearer {admin_token}"},
+ )
+ assert response.status_code == 200
+
+
+# ---------------------------------------------------------------------------
+# Shared workflow (is_public)
+# ---------------------------------------------------------------------------
+
+
+def test_update_is_public_owner_succeeds(client: TestClient, user1_token: str):
+ workflow_id = create_workflow(client, user1_token)
+ response = client.patch(
+ f"/api/v1/workflows/i/{workflow_id}/is_public",
+ json={"is_public": True},
+ headers={"Authorization": f"Bearer {user1_token}"},
+ )
+ assert response.status_code == 200
+ assert response.json()["is_public"] is True
+
+
+def test_update_is_public_other_user_forbidden(client: TestClient, user1_token: str, user2_token: str):
+ workflow_id = create_workflow(client, user1_token)
+ response = client.patch(
+ f"/api/v1/workflows/i/{workflow_id}/is_public",
+ json={"is_public": True},
+ headers={"Authorization": f"Bearer {user2_token}"},
+ )
+ assert response.status_code == status.HTTP_403_FORBIDDEN
+
+
+def test_public_workflow_visible_to_other_users(client: TestClient, user1_token: str, user2_token: str):
+ """A shared (is_public=True) workflow should appear when filtering with is_public=true."""
+ workflow_id = create_workflow(client, user1_token)
+ # Make it public
+ client.patch(
+ f"/api/v1/workflows/i/{workflow_id}/is_public",
+ json={"is_public": True},
+ headers={"Authorization": f"Bearer {user1_token}"},
+ )
+
+ # user2 can see it through is_public=true filter
+ response = client.get(
+ "/api/v1/workflows/?categories=user&is_public=true",
+ headers={"Authorization": f"Bearer {user2_token}"},
+ )
+ assert response.status_code == 200
+ ids = [w["workflow_id"] for w in response.json()["items"]]
+ assert workflow_id in ids
+
+
+def test_private_workflow_not_visible_to_other_users(client: TestClient, user1_token: str, user2_token: str):
+ """A private (is_public=False) user workflow should NOT appear for another user."""
+ workflow_id = create_workflow(client, user1_token)
+
+ # user2 lists 'yours' style (their own workflows)
+ response = client.get(
+ "/api/v1/workflows/?categories=user",
+ headers={"Authorization": f"Bearer {user2_token}"},
+ )
+ assert response.status_code == 200
+ ids = [w["workflow_id"] for w in response.json()["items"]]
+ assert workflow_id not in ids
+
+
+def test_public_workflow_still_in_owners_list(client: TestClient, user1_token: str):
+ """A shared workflow should still appear in the owner's own workflow list."""
+ workflow_id = create_workflow(client, user1_token)
+ client.patch(
+ f"/api/v1/workflows/i/{workflow_id}/is_public",
+ json={"is_public": True},
+ headers={"Authorization": f"Bearer {user1_token}"},
+ )
+
+ # owner's 'yours' list (no is_public filter)
+ response = client.get(
+ "/api/v1/workflows/?categories=user",
+ headers={"Authorization": f"Bearer {user1_token}"},
+ )
+ assert response.status_code == 200
+ ids = [w["workflow_id"] for w in response.json()["items"]]
+ assert workflow_id in ids
+
+
+def test_workflow_has_user_id_and_is_public_fields(client: TestClient, user1_token: str):
+ """Created workflow should return user_id and is_public fields."""
+ response = client.post(
+ "/api/v1/workflows/",
+ json={"workflow": WORKFLOW_BODY},
+ headers={"Authorization": f"Bearer {user1_token}"},
+ )
+ assert response.status_code == 200
+ data = response.json()
+ assert "user_id" in data
+ assert "is_public" in data
+ assert data["is_public"] is False