diff --git a/invokeai/app/api/routers/auth.py b/invokeai/app/api/routers/auth.py index 36aeabda822..e0b0c885cd2 100644 --- a/invokeai/app/api/routers/auth.py +++ b/invokeai/app/api/routers/auth.py @@ -80,6 +80,7 @@ class SetupStatusResponse(BaseModel): setup_required: bool = Field(description="Whether initial setup is required") multiuser_enabled: bool = Field(description="Whether multiuser mode is enabled") strict_password_checking: bool = Field(description="Whether strict password requirements are enforced") + admin_email: str | None = Field(default=None, description="Email of the first active admin user, if any") @auth_router.get("/status", response_model=SetupStatusResponse) @@ -94,15 +95,25 @@ async def get_setup_status() -> SetupStatusResponse: # If multiuser is disabled, setup is never required if not config.multiuser: return SetupStatusResponse( - setup_required=False, multiuser_enabled=False, strict_password_checking=config.strict_password_checking + setup_required=False, + multiuser_enabled=False, + strict_password_checking=config.strict_password_checking, + admin_email=None, ) # In multiuser mode, check if an admin exists user_service = ApiDependencies.invoker.services.users setup_required = not user_service.has_admin() + # Only expose admin_email during initial setup to avoid leaking + # administrator identity on public deployments. + admin_email = user_service.get_admin_email() if setup_required else None + return SetupStatusResponse( - setup_required=setup_required, multiuser_enabled=True, strict_password_checking=config.strict_password_checking + setup_required=setup_required, + multiuser_enabled=True, + strict_password_checking=config.strict_password_checking, + admin_email=admin_email, ) diff --git a/invokeai/app/api/routers/board_images.py b/invokeai/app/api/routers/board_images.py index cb5e0ab51ab..f94e4f2437c 100644 --- a/invokeai/app/api/routers/board_images.py +++ b/invokeai/app/api/routers/board_images.py @@ -1,12 +1,53 @@ from fastapi import Body, HTTPException from fastapi.routing import APIRouter +from invokeai.app.api.auth_dependencies import CurrentUserOrDefault from invokeai.app.api.dependencies import ApiDependencies from invokeai.app.services.images.images_common import AddImagesToBoardResult, RemoveImagesFromBoardResult board_images_router = APIRouter(prefix="/v1/board_images", tags=["boards"]) +def _assert_board_write_access(board_id: str, current_user: CurrentUserOrDefault) -> None: + """Raise 403 if the current user may not mutate the given board. + + Write access is granted when ANY of these hold: + - The user is an admin. + - The user owns the board. + - The board visibility is Public (public boards accept contributions from any user). + """ + from invokeai.app.services.board_records.board_records_common import BoardVisibility + + try: + board = ApiDependencies.invoker.services.boards.get_dto(board_id=board_id) + except Exception: + raise HTTPException(status_code=404, detail="Board not found") + if current_user.is_admin: + return + if board.user_id == current_user.user_id: + return + if board.board_visibility == BoardVisibility.Public: + return + raise HTTPException(status_code=403, detail="Not authorized to modify this board") + + +def _assert_image_direct_owner(image_name: str, current_user: CurrentUserOrDefault) -> None: + """Raise 403 if the current user is not the direct owner of the image. + + This is intentionally stricter than _assert_image_owner in images.py: + board ownership is NOT sufficient here. Allowing a user to add someone + else's image to their own board would grant them mutation rights via the + board-ownership fallback in _assert_image_owner, escalating read access + into write access. + """ + if current_user.is_admin: + return + owner = ApiDependencies.invoker.services.image_records.get_user_id(image_name) + if owner is not None and owner == current_user.user_id: + return + raise HTTPException(status_code=403, detail="Not authorized to move this image") + + @board_images_router.post( "/", operation_id="add_image_to_board", @@ -17,14 +58,17 @@ response_model=AddImagesToBoardResult, ) async def add_image_to_board( + current_user: CurrentUserOrDefault, board_id: str = Body(description="The id of the board to add to"), image_name: str = Body(description="The name of the image to add"), ) -> AddImagesToBoardResult: """Creates a board_image""" + _assert_board_write_access(board_id, current_user) + _assert_image_direct_owner(image_name, current_user) try: added_images: set[str] = set() affected_boards: set[str] = set() - old_board_id = ApiDependencies.invoker.services.images.get_dto(image_name).board_id or "none" + old_board_id = ApiDependencies.invoker.services.board_image_records.get_board_for_image(image_name) or "none" ApiDependencies.invoker.services.board_images.add_image_to_board(board_id=board_id, image_name=image_name) added_images.add(image_name) affected_boards.add(board_id) @@ -48,13 +92,16 @@ async def add_image_to_board( response_model=RemoveImagesFromBoardResult, ) async def remove_image_from_board( + current_user: CurrentUserOrDefault, image_name: str = Body(description="The name of the image to remove", embed=True), ) -> RemoveImagesFromBoardResult: """Removes an image from its board, if it had one""" try: + old_board_id = ApiDependencies.invoker.services.images.get_dto(image_name).board_id or "none" + if old_board_id != "none": + _assert_board_write_access(old_board_id, current_user) removed_images: set[str] = set() affected_boards: set[str] = set() - old_board_id = ApiDependencies.invoker.services.images.get_dto(image_name).board_id or "none" ApiDependencies.invoker.services.board_images.remove_image_from_board(image_name=image_name) removed_images.add(image_name) affected_boards.add("none") @@ -64,6 +111,8 @@ async def remove_image_from_board( affected_boards=list(affected_boards), ) + except HTTPException: + raise except Exception: raise HTTPException(status_code=500, detail="Failed to remove image from board") @@ -78,16 +127,21 @@ async def remove_image_from_board( response_model=AddImagesToBoardResult, ) async def add_images_to_board( + current_user: CurrentUserOrDefault, board_id: str = Body(description="The id of the board to add to"), image_names: list[str] = Body(description="The names of the images to add", embed=True), ) -> AddImagesToBoardResult: """Adds a list of images to a board""" + _assert_board_write_access(board_id, current_user) try: added_images: set[str] = set() affected_boards: set[str] = set() for image_name in image_names: try: - old_board_id = ApiDependencies.invoker.services.images.get_dto(image_name).board_id or "none" + _assert_image_direct_owner(image_name, current_user) + old_board_id = ( + ApiDependencies.invoker.services.board_image_records.get_board_for_image(image_name) or "none" + ) ApiDependencies.invoker.services.board_images.add_image_to_board( board_id=board_id, image_name=image_name, @@ -96,12 +150,16 @@ async def add_images_to_board( affected_boards.add(board_id) affected_boards.add(old_board_id) + except HTTPException: + raise except Exception: pass return AddImagesToBoardResult( added_images=list(added_images), affected_boards=list(affected_boards), ) + except HTTPException: + raise except Exception: raise HTTPException(status_code=500, detail="Failed to add images to board") @@ -116,6 +174,7 @@ async def add_images_to_board( response_model=RemoveImagesFromBoardResult, ) async def remove_images_from_board( + current_user: CurrentUserOrDefault, image_names: list[str] = Body(description="The names of the images to remove", embed=True), ) -> RemoveImagesFromBoardResult: """Removes a list of images from their board, if they had one""" @@ -125,15 +184,21 @@ async def remove_images_from_board( for image_name in image_names: try: old_board_id = ApiDependencies.invoker.services.images.get_dto(image_name).board_id or "none" + if old_board_id != "none": + _assert_board_write_access(old_board_id, current_user) ApiDependencies.invoker.services.board_images.remove_image_from_board(image_name=image_name) removed_images.add(image_name) affected_boards.add("none") affected_boards.add(old_board_id) + except HTTPException: + raise except Exception: pass return RemoveImagesFromBoardResult( removed_images=list(removed_images), affected_boards=list(affected_boards), ) + except HTTPException: + raise except Exception: raise HTTPException(status_code=500, detail="Failed to remove images from board") diff --git a/invokeai/app/api/routers/boards.py b/invokeai/app/api/routers/boards.py index e93bb8b2a9b..6897e90aff4 100644 --- a/invokeai/app/api/routers/boards.py +++ b/invokeai/app/api/routers/boards.py @@ -6,7 +6,7 @@ from invokeai.app.api.auth_dependencies import CurrentUserOrDefault from invokeai.app.api.dependencies import ApiDependencies -from invokeai.app.services.board_records.board_records_common import BoardChanges, BoardRecordOrderBy +from invokeai.app.services.board_records.board_records_common import BoardChanges, BoardRecordOrderBy, BoardVisibility from invokeai.app.services.boards.boards_common import BoardDTO from invokeai.app.services.image_records.image_records_common import ImageCategory from invokeai.app.services.shared.pagination import OffsetPaginatedResults @@ -56,7 +56,14 @@ async def get_board( except Exception: raise HTTPException(status_code=404, detail="Board not found") - if not current_user.is_admin and result.user_id != current_user.user_id: + # Admins can access any board. + # Owners can access their own boards. + # Shared and public boards are visible to all authenticated users. + if ( + not current_user.is_admin + and result.user_id != current_user.user_id + and result.board_visibility == BoardVisibility.Private + ): raise HTTPException(status_code=403, detail="Not authorized to access this board") return result @@ -188,7 +195,11 @@ async def list_all_board_image_names( except Exception: raise HTTPException(status_code=404, detail="Board not found") - if not current_user.is_admin and board.user_id != current_user.user_id: + if ( + not current_user.is_admin + and board.user_id != current_user.user_id + and board.board_visibility == BoardVisibility.Private + ): raise HTTPException(status_code=403, detail="Not authorized to access this board") image_names = ApiDependencies.invoker.services.board_images.get_all_board_image_names_for_board( @@ -196,4 +207,15 @@ async def list_all_board_image_names( categories, is_intermediate, ) + + # For uncategorized images (board_id="none"), filter to only the caller's + # images so that one user cannot enumerate another's uncategorized images. + # Admin users can see all uncategorized images. + if board_id == "none" and not current_user.is_admin: + image_names = [ + name + for name in image_names + if ApiDependencies.invoker.services.image_records.get_user_id(name) == current_user.user_id + ] + return image_names diff --git a/invokeai/app/api/routers/images.py b/invokeai/app/api/routers/images.py index 6b11762c9ec..a3ae6fce82b 100644 --- a/invokeai/app/api/routers/images.py +++ b/invokeai/app/api/routers/images.py @@ -38,6 +38,96 @@ IMAGE_MAX_AGE = 31536000 +def _assert_image_owner(image_name: str, current_user: CurrentUserOrDefault) -> None: + """Raise 403 if the current user does not own the image and is not an admin. + + Ownership is satisfied when ANY of these hold: + - The user is an admin. + - The user is the image's direct owner (image_records.user_id). + - The user owns the board the image sits on. + - The image sits on a Public board (public boards grant mutation rights). + """ + from invokeai.app.services.board_records.board_records_common import BoardVisibility + + if current_user.is_admin: + return + owner = ApiDependencies.invoker.services.image_records.get_user_id(image_name) + if owner is not None and owner == current_user.user_id: + return + + # Check whether the user owns the board the image belongs to, + # or the board is Public (public boards grant mutation rights). + board_id = ApiDependencies.invoker.services.board_image_records.get_board_for_image(image_name) + if board_id is not None: + try: + board = ApiDependencies.invoker.services.boards.get_dto(board_id=board_id) + if board.user_id == current_user.user_id: + return + if board.board_visibility == BoardVisibility.Public: + return + except Exception: + pass + + raise HTTPException(status_code=403, detail="Not authorized to modify this image") + + +def _assert_image_read_access(image_name: str, current_user: CurrentUserOrDefault) -> None: + """Raise 403 if the current user may not view the image. + + Access is granted when ANY of these hold: + - The user is an admin. + - The user owns the image. + - The image sits on a shared or public board. + """ + from invokeai.app.services.board_records.board_records_common import BoardVisibility + + if current_user.is_admin: + return + + owner = ApiDependencies.invoker.services.image_records.get_user_id(image_name) + if owner is not None and owner == current_user.user_id: + return + + # Check whether the image's board makes it visible to other users. + board_id = ApiDependencies.invoker.services.board_image_records.get_board_for_image(image_name) + if board_id is not None: + try: + board = ApiDependencies.invoker.services.boards.get_dto(board_id=board_id) + if board.board_visibility in (BoardVisibility.Shared, BoardVisibility.Public): + return + except Exception: + pass + + raise HTTPException(status_code=403, detail="Not authorized to access this image") + + +def _assert_board_read_access(board_id: str, current_user: CurrentUserOrDefault) -> None: + """Raise 403 if the current user may not read images from this board. + + Access is granted when ANY of these hold: + - The user is an admin. + - The user owns the board. + - The board visibility is Shared or Public. + """ + from invokeai.app.services.board_records.board_records_common import BoardVisibility + + if current_user.is_admin: + return + + try: + board = ApiDependencies.invoker.services.boards.get_dto(board_id=board_id) + except Exception: + raise HTTPException(status_code=404, detail="Board not found") + + if board.user_id == current_user.user_id: + return + + if board.board_visibility in (BoardVisibility.Shared, BoardVisibility.Public): + return + + raise HTTPException(status_code=403, detail="Not authorized to access this board") + + class ResizeToDimensions(BaseModel): width: int = Field(..., gt=0) height: int = Field(..., gt=0) @@ -83,6 +173,22 @@ async def upload_image( ), ) -> ImageDTO: """Uploads an image for the current user""" + # If uploading into a board, verify the user has write access. + # Public boards allow uploads from any authenticated user. + if board_id is not None: + from invokeai.app.services.board_records.board_records_common import BoardVisibility + + try: + board = ApiDependencies.invoker.services.boards.get_dto(board_id=board_id) + except Exception: + raise HTTPException(status_code=404, detail="Board not found") + if ( + not current_user.is_admin + and board.user_id != current_user.user_id + and board.board_visibility != BoardVisibility.Public + ): + raise HTTPException(status_code=403, detail="Not authorized to upload to this board") + if not file.content_type or not file.content_type.startswith("image"): raise HTTPException(status_code=415, detail="Not an image") @@ -165,9 +271,11 @@ async def create_image_upload_entry( @images_router.delete("/i/{image_name}", operation_id="delete_image", response_model=DeleteImagesResult) async def delete_image( + current_user: CurrentUserOrDefault, image_name: str = Path(description="The name of the image to delete"), ) -> DeleteImagesResult: """Deletes an image""" + _assert_image_owner(image_name, current_user) deleted_images: set[str] = set() affected_boards: set[str] = set() @@ -189,26 +297,31 @@ async def delete_image( @images_router.delete("/intermediates", operation_id="clear_intermediates") -async def clear_intermediates() -> int: - """Clears all intermediates""" +async def clear_intermediates( + current_user: CurrentUserOrDefault, +) -> int: + """Clears all intermediates. Requires admin.""" + if not current_user.is_admin: + raise HTTPException(status_code=403, detail="Only admins can clear all intermediates") try: count_deleted = ApiDependencies.invoker.services.images.delete_intermediates() return count_deleted except Exception: raise HTTPException(status_code=500, detail="Failed to clear intermediates") - pass @images_router.get("/intermediates", operation_id="get_intermediates_count") -async def get_intermediates_count() -> int: - """Gets the count of intermediate images""" +async def get_intermediates_count( + current_user: CurrentUserOrDefault, +) -> int: + """Gets the count of intermediate images. Non-admin users only see their own intermediates.""" try: - return ApiDependencies.invoker.services.images.get_intermediates_count() + user_id = None if current_user.is_admin else current_user.user_id + return ApiDependencies.invoker.services.images.get_intermediates_count(user_id=user_id) except Exception: raise HTTPException(status_code=500, detail="Failed to get intermediates") - pass @images_router.patch( @@ -217,10 +330,12 @@ async def get_intermediates_count() -> int: response_model=ImageDTO, ) async def update_image( + current_user: CurrentUserOrDefault, image_name: str = Path(description="The name of the image to update"), image_changes: ImageRecordChanges = Body(description="The changes to apply to the image"), ) -> ImageDTO: """Updates an image""" + _assert_image_owner(image_name, current_user) try: return ApiDependencies.invoker.services.images.update(image_name, image_changes) @@ -234,9 +349,11 @@ async def update_image( response_model=ImageDTO, ) async def get_image_dto( + current_user: CurrentUserOrDefault, image_name: str = Path(description="The name of image to get"), ) -> ImageDTO: """Gets an image's DTO""" + _assert_image_read_access(image_name, current_user) try: return ApiDependencies.invoker.services.images.get_dto(image_name) @@ -250,9 +367,11 @@ async def get_image_dto( response_model=Optional[MetadataField], ) async def get_image_metadata( + current_user: CurrentUserOrDefault, image_name: str = Path(description="The name of image to get"), ) -> Optional[MetadataField]: """Gets an image's metadata""" + _assert_image_read_access(image_name, current_user) try: return ApiDependencies.invoker.services.images.get_metadata(image_name) @@ -269,8 +388,11 @@ class WorkflowAndGraphResponse(BaseModel): "/i/{image_name}/workflow", operation_id="get_image_workflow", response_model=WorkflowAndGraphResponse ) async def get_image_workflow( + current_user: CurrentUserOrDefault, image_name: str = Path(description="The name of image whose workflow to get"), ) -> WorkflowAndGraphResponse: + _assert_image_read_access(image_name, current_user) + try: workflow = ApiDependencies.invoker.services.images.get_workflow(image_name) graph = ApiDependencies.invoker.services.images.get_graph(image_name) @@ -306,8 +428,12 @@ async def get_image_workflow( async def get_image_full( image_name: str = Path(description="The name of full-resolution image file to get"), ) -> Response: - """Gets a full-resolution image file""" + """Gets a full-resolution image file. + This endpoint is intentionally unauthenticated because browsers load images + via tags which cannot send Bearer tokens. Image names are UUIDs, + providing security through unguessability. + """ try: path = ApiDependencies.invoker.services.images.get_path(image_name) with open(path, "rb") as f: @@ -335,8 +461,12 @@ async def get_image_full( async def get_image_thumbnail( image_name: str = Path(description="The name of thumbnail image file to get"), ) -> Response: - """Gets a thumbnail image file""" + """Gets a thumbnail image file. + This endpoint is intentionally unauthenticated because browsers load images + via tags which cannot send Bearer tokens. Image names are UUIDs, + providing security through unguessability. + """ try: path = ApiDependencies.invoker.services.images.get_path(image_name, thumbnail=True) with open(path, "rb") as f: @@ -354,9 +484,11 @@ async def get_image_thumbnail( response_model=ImageUrlsDTO, ) async def get_image_urls( + current_user: CurrentUserOrDefault, image_name: str = Path(description="The name of the image whose URL to get"), ) -> ImageUrlsDTO: """Gets an image and thumbnail URL""" + _assert_image_read_access(image_name, current_user) try: image_url = ApiDependencies.invoker.services.images.get_url(image_name) @@ -392,6 +524,11 @@ async def list_image_dtos( ) -> OffsetPaginatedResults[ImageDTO]: """Gets a list of image DTOs for the current user""" + # Validate that the caller can read from this board before listing its images. + # "none" is a sentinel for uncategorized images and is handled by the SQL layer. + if board_id is not None and board_id != "none": + _assert_board_read_access(board_id, current_user) + image_dtos = ApiDependencies.invoker.services.images.get_many( offset, limit, @@ -410,6 +547,7 @@ async def list_image_dtos( @images_router.post("/delete", operation_id="delete_images_from_list", response_model=DeleteImagesResult) async def delete_images_from_list( + current_user: CurrentUserOrDefault, image_names: list[str] = Body(description="The list of names of images to delete", embed=True), ) -> DeleteImagesResult: try: @@ -417,24 +555,31 @@ async def delete_images_from_list( affected_boards: set[str] = set() for image_name in image_names: try: + _assert_image_owner(image_name, current_user) image_dto = ApiDependencies.invoker.services.images.get_dto(image_name) board_id = image_dto.board_id or "none" ApiDependencies.invoker.services.images.delete(image_name) deleted_images.add(image_name) affected_boards.add(board_id) + except HTTPException: + raise except Exception: pass return DeleteImagesResult( deleted_images=list(deleted_images), affected_boards=list(affected_boards), ) + except HTTPException: + raise except Exception: raise HTTPException(status_code=500, detail="Failed to delete images") @images_router.delete("/uncategorized", operation_id="delete_uncategorized_images", response_model=DeleteImagesResult) -async def delete_uncategorized_images() -> DeleteImagesResult: - """Deletes all images that are uncategorized""" +async def delete_uncategorized_images( + current_user: CurrentUserOrDefault, +) -> DeleteImagesResult: + """Deletes all uncategorized images owned by the current user (or all if admin)""" image_names = ApiDependencies.invoker.services.board_images.get_all_board_image_names_for_board( board_id="none", categories=None, is_intermediate=None @@ -445,9 +590,13 @@ async def delete_uncategorized_images() -> DeleteImagesResult: affected_boards: set[str] = set() for image_name in image_names: try: + _assert_image_owner(image_name, current_user) ApiDependencies.invoker.services.images.delete(image_name) deleted_images.add(image_name) affected_boards.add("none") + except HTTPException: + # Skip images not owned by the current user + pass except Exception: pass return DeleteImagesResult( @@ -464,6 +613,7 @@ class ImagesUpdatedFromListResult(BaseModel): @images_router.post("/star", operation_id="star_images_in_list", response_model=StarredImagesResult) async def star_images_in_list( + current_user: CurrentUserOrDefault, image_names: list[str] = Body(description="The list of names of images to star", embed=True), ) -> StarredImagesResult: try: @@ -471,23 +621,29 @@ async def star_images_in_list( affected_boards: set[str] = set() for image_name in image_names: try: + _assert_image_owner(image_name, current_user) updated_image_dto = ApiDependencies.invoker.services.images.update( image_name, changes=ImageRecordChanges(starred=True) ) starred_images.add(image_name) affected_boards.add(updated_image_dto.board_id or "none") + except HTTPException: + raise except Exception: pass return StarredImagesResult( starred_images=list(starred_images), affected_boards=list(affected_boards), ) + except HTTPException: + raise except Exception: raise HTTPException(status_code=500, detail="Failed to star images") @images_router.post("/unstar", operation_id="unstar_images_in_list", response_model=UnstarredImagesResult) async def unstar_images_in_list( + current_user: CurrentUserOrDefault, image_names: list[str] = Body(description="The list of names of images to unstar", embed=True), ) -> UnstarredImagesResult: try: @@ -495,17 +651,22 @@ async def unstar_images_in_list( affected_boards: set[str] = set() for image_name in image_names: try: + _assert_image_owner(image_name, current_user) updated_image_dto = ApiDependencies.invoker.services.images.update( image_name, changes=ImageRecordChanges(starred=False) ) unstarred_images.add(image_name) affected_boards.add(updated_image_dto.board_id or "none") + except HTTPException: + raise except Exception: pass return UnstarredImagesResult( unstarred_images=list(unstarred_images), affected_boards=list(affected_boards), ) + except HTTPException: + raise except Exception: raise HTTPException(status_code=500, detail="Failed to unstar images") @@ -523,6 +684,7 @@ class ImagesDownloaded(BaseModel): "/download", operation_id="download_images_from_list", response_model=ImagesDownloaded, status_code=202 ) async def download_images_from_list( + current_user: CurrentUserOrDefault, background_tasks: BackgroundTasks, image_names: Optional[list[str]] = Body( default=None, description="The list of names of images to download", embed=True @@ -533,6 +695,16 @@ async def download_images_from_list( ) -> ImagesDownloaded: if (image_names is None or len(image_names) == 0) and board_id is None: raise HTTPException(status_code=400, detail="No images or board id specified.") + + # Validate that the caller can read every image they are requesting. + # For a board_id request, check board visibility; for explicit image names, + # check each image individually. + if board_id: + _assert_board_read_access(board_id, current_user) + if image_names: + for name in image_names: + _assert_image_read_access(name, current_user) + bulk_download_item_id: str = ApiDependencies.invoker.services.bulk_download.generate_item_id(board_id) background_tasks.add_task( @@ -540,6 +712,7 @@ async def download_images_from_list( image_names, board_id, bulk_download_item_id, + current_user.user_id, ) return ImagesDownloaded(bulk_download_item_name=bulk_download_item_id + ".zip") @@ -558,11 +731,21 @@ async def download_images_from_list( }, ) async def get_bulk_download_item( + current_user: CurrentUserOrDefault, background_tasks: BackgroundTasks, bulk_download_item_name: str = Path(description="The bulk_download_item_name of the bulk download item to get"), ) -> FileResponse: - """Gets a bulk download zip file""" + """Gets a bulk download zip file. + + Requires authentication. The caller must be the user who initiated the + download (tracked by the bulk download service) or an admin. + """ try: + # Verify the caller owns this download (or is an admin) + owner = ApiDependencies.invoker.services.bulk_download.get_owner(bulk_download_item_name) + if owner is not None and owner != current_user.user_id and not current_user.is_admin: + raise HTTPException(status_code=403, detail="Not authorized to access this download") + path = ApiDependencies.invoker.services.bulk_download.get_path(bulk_download_item_name) response = FileResponse( @@ -574,6 +757,8 @@ async def get_bulk_download_item( response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}" background_tasks.add_task(ApiDependencies.invoker.services.bulk_download.delete, bulk_download_item_name) return response + except HTTPException: + raise except Exception: raise HTTPException(status_code=404) @@ -594,6 +779,10 @@ async def get_image_names( ) -> ImageNamesResult: """Gets ordered list of image names with metadata for optimistic updates""" + # Validate that the caller can read from this board before listing its images. + if board_id is not None and board_id != "none": + _assert_board_read_access(board_id, current_user) + try: result = ApiDependencies.invoker.services.images.get_image_names( starred_first=starred_first, @@ -617,6 +806,7 @@ async def get_image_names( responses={200: {"model": list[ImageDTO]}}, ) async def get_images_by_names( + current_user: CurrentUserOrDefault, image_names: list[str] = Body(embed=True, description="Object containing list of image names to fetch DTOs for"), ) -> list[ImageDTO]: """Gets image DTOs for the specified image names. Maintains order of input names.""" @@ -628,8 +818,12 @@ async def get_images_by_names( image_dtos: list[ImageDTO] = [] for name in image_names: try: + _assert_image_read_access(name, current_user) dto = image_service.get_dto(name) image_dtos.append(dto) + except HTTPException: + # Skip images the user is not authorized to view + continue except Exception: # Skip missing images - they may have been deleted between name fetch and DTO fetch continue diff --git a/invokeai/app/api/routers/model_manager.py b/invokeai/app/api/routers/model_manager.py index 65b059ecfce..822d9655fe7 100644 --- a/invokeai/app/api/routers/model_manager.py +++ b/invokeai/app/api/routers/model_manager.py @@ -858,7 +858,7 @@ def generate_html(title: str, heading: str, repo_id: str, is_error: bool, messag "/install", operation_id="list_model_installs", ) -async def list_model_installs() -> List[ModelInstallJob]: +async def list_model_installs(current_admin: AdminUserOrDefault) -> List[ModelInstallJob]: """Return the list of model install jobs. Install jobs have a numeric `id`, a `status`, and other fields that provide information on @@ -890,7 +890,9 @@ async def list_model_installs() -> List[ModelInstallJob]: 404: {"description": "No such job"}, }, ) -async def get_model_install_job(id: int = Path(description="Model install id")) -> ModelInstallJob: +async def get_model_install_job( + current_admin: AdminUserOrDefault, id: int = Path(description="Model install id") +) -> ModelInstallJob: """ Return model install job corresponding to the given source. See the documentation for 'List Model Install Jobs' for information on the format of the return value. @@ -933,7 +935,9 @@ async def cancel_model_install_job( }, status_code=201, ) -async def pause_model_install_job(id: int = Path(description="Model install job ID")) -> ModelInstallJob: +async def pause_model_install_job( + current_admin: AdminUserOrDefault, id: int = Path(description="Model install job ID") +) -> ModelInstallJob: """Pause the model install job corresponding to the given job ID.""" installer = ApiDependencies.invoker.services.model_manager.install try: @@ -953,7 +957,9 @@ async def pause_model_install_job(id: int = Path(description="Model install job }, status_code=201, ) -async def resume_model_install_job(id: int = Path(description="Model install job ID")) -> ModelInstallJob: +async def resume_model_install_job( + current_admin: AdminUserOrDefault, id: int = Path(description="Model install job ID") +) -> ModelInstallJob: """Resume a paused model install job corresponding to the given job ID.""" installer = ApiDependencies.invoker.services.model_manager.install try: @@ -973,7 +979,9 @@ async def resume_model_install_job(id: int = Path(description="Model install job }, status_code=201, ) -async def restart_failed_model_install_job(id: int = Path(description="Model install job ID")) -> ModelInstallJob: +async def restart_failed_model_install_job( + current_admin: AdminUserOrDefault, id: int = Path(description="Model install job ID") +) -> ModelInstallJob: """Restart failed or non-resumable file downloads for the given job.""" installer = ApiDependencies.invoker.services.model_manager.install try: @@ -994,6 +1002,7 @@ async def restart_failed_model_install_job(id: int = Path(description="Model ins status_code=201, ) async def restart_model_install_file( + current_admin: AdminUserOrDefault, id: int = Path(description="Model install job ID"), file_source: AnyHttpUrl = Body(description="File download URL to restart"), ) -> ModelInstallJob: @@ -1305,7 +1314,7 @@ class DeleteOrphanedModelsResponse(BaseModel): operation_id="get_orphaned_models", response_model=list[OrphanedModelInfo], ) -async def get_orphaned_models() -> list[OrphanedModelInfo]: +async def get_orphaned_models(_: AdminUserOrDefault) -> list[OrphanedModelInfo]: """Find orphaned model directories. Orphaned models are directories in the models folder that contain model files @@ -1332,7 +1341,9 @@ async def get_orphaned_models() -> list[OrphanedModelInfo]: operation_id="delete_orphaned_models", response_model=DeleteOrphanedModelsResponse, ) -async def delete_orphaned_models(request: DeleteOrphanedModelsRequest) -> DeleteOrphanedModelsResponse: +async def delete_orphaned_models( + request: DeleteOrphanedModelsRequest, _: AdminUserOrDefault +) -> DeleteOrphanedModelsResponse: """Delete specified orphaned model directories. Args: diff --git a/invokeai/app/api/routers/recall_parameters.py b/invokeai/app/api/routers/recall_parameters.py index 0af3fd29b0c..ec08adba2e8 100644 --- a/invokeai/app/api/routers/recall_parameters.py +++ b/invokeai/app/api/routers/recall_parameters.py @@ -7,6 +7,7 @@ from fastapi.routing import APIRouter from pydantic import BaseModel, ConfigDict, Field +from invokeai.app.api.auth_dependencies import CurrentUserOrDefault from invokeai.app.api.dependencies import ApiDependencies from invokeai.backend.image_util.controlnet_processor import process_controlnet_image from invokeai.backend.model_manager.taxonomy import ModelType @@ -291,12 +292,58 @@ def resolve_ip_adapter_models(ip_adapters: list[IPAdapterRecallParameter]) -> li return resolved_adapters +def _assert_recall_image_access(parameters: "RecallParameter", current_user: CurrentUserOrDefault) -> None: + """Validate that the caller can read every image referenced in the recall parameters. + + Control layers and IP adapters may reference image_name fields. Without this + check an attacker who knows another user's image UUID could use the recall + endpoint to extract image dimensions and — for ControlNet preprocessors — mint + a derived processed image they can then fetch. + """ + from invokeai.app.services.board_records.board_records_common import BoardVisibility + + image_names: list[str] = [] + if parameters.control_layers: + for layer in parameters.control_layers: + if layer.image_name is not None: + image_names.append(layer.image_name) + if parameters.ip_adapters: + for adapter in parameters.ip_adapters: + if adapter.image_name is not None: + image_names.append(adapter.image_name) + + if not image_names: + return + + # Admin can access all images + if current_user.is_admin: + return + + for image_name in image_names: + owner = ApiDependencies.invoker.services.image_records.get_user_id(image_name) + if owner is not None and owner == current_user.user_id: + continue + + # Check board visibility + board_id = ApiDependencies.invoker.services.board_image_records.get_board_for_image(image_name) + if board_id is not None: + try: + board = ApiDependencies.invoker.services.boards.get_dto(board_id=board_id) + if board.board_visibility in (BoardVisibility.Shared, BoardVisibility.Public): + continue + except Exception: + pass + + raise HTTPException(status_code=403, detail=f"Not authorized to access image {image_name}") + + @recall_parameters_router.post( "/{queue_id}", operation_id="update_recall_parameters", response_model=dict[str, Any], ) async def update_recall_parameters( + current_user: CurrentUserOrDefault, queue_id: str = Path(..., description="The queue id to perform this operation on"), parameters: RecallParameter = Body(..., description="Recall parameters to update"), ) -> dict[str, Any]: @@ -328,6 +375,10 @@ async def update_recall_parameters( """ logger = ApiDependencies.invoker.services.logger + # Validate image access before processing — prevents information leakage + # (dimensions) and derived-image minting via ControlNet preprocessors. + _assert_recall_image_access(parameters, current_user) + try: # Get only the parameters that were actually provided (non-None values) provided_params = {k: v for k, v in parameters.model_dump().items() if v is not None} @@ -335,14 +386,14 @@ async def update_recall_parameters( if not provided_params: return {"status": "no_parameters_provided", "updated_count": 0} - # Store each parameter in client state using a consistent key format + # Store each parameter in client state scoped to the current user updated_count = 0 for param_key, param_value in provided_params.items(): # Convert parameter values to JSON strings for storage value_str = json.dumps(param_value) try: ApiDependencies.invoker.services.client_state_persistence.set_by_key( - queue_id, f"recall_{param_key}", value_str + current_user.user_id, f"recall_{param_key}", value_str ) updated_count += 1 except Exception as e: @@ -396,7 +447,9 @@ async def update_recall_parameters( logger.info( f"Emitting recall_parameters_updated event for queue {queue_id} with {len(provided_params)} parameters" ) - ApiDependencies.invoker.services.events.emit_recall_parameters_updated(queue_id, provided_params) + ApiDependencies.invoker.services.events.emit_recall_parameters_updated( + queue_id, current_user.user_id, provided_params + ) logger.info("Successfully emitted recall_parameters_updated event") except Exception as e: logger.error(f"Error emitting recall parameters event: {e}", exc_info=True) @@ -425,6 +478,7 @@ async def update_recall_parameters( response_model=dict[str, Any], ) async def get_recall_parameters( + current_user: CurrentUserOrDefault, queue_id: str = Path(..., description="The queue id to retrieve parameters for"), ) -> dict[str, Any]: """ diff --git a/invokeai/app/api/routers/session_queue.py b/invokeai/app/api/routers/session_queue.py index 403e7727cb4..41a5a411c7a 100644 --- a/invokeai/app/api/routers/session_queue.py +++ b/invokeai/app/api/routers/session_queue.py @@ -44,7 +44,8 @@ def sanitize_queue_item_for_user( """Sanitize queue item for non-admin users viewing other users' items. For non-admin users viewing queue items belonging to other users, - the field_values, session graph, and workflow should be hidden/cleared to protect privacy. + only timestamps, status, and error information are exposed. All other + fields (user identity, generation parameters, graphs, workflows) are stripped. Args: queue_item: The queue item to sanitize @@ -58,15 +59,25 @@ def sanitize_queue_item_for_user( if is_admin or queue_item.user_id == current_user_id: return queue_item - # For non-admins viewing other users' items, clear sensitive fields - # Create a shallow copy to avoid mutating the original + # For non-admins viewing other users' items, strip everything except + # item_id, queue_id, status, and timestamps sanitized_item = queue_item.model_copy(deep=False) + sanitized_item.user_id = "redacted" + sanitized_item.user_display_name = None + sanitized_item.user_email = None + sanitized_item.batch_id = "redacted" + sanitized_item.session_id = "redacted" + sanitized_item.origin = None + sanitized_item.destination = None + sanitized_item.priority = 0 sanitized_item.field_values = None + sanitized_item.retried_from_item_id = None sanitized_item.workflow = None - # Clear the session graph by replacing it with an empty graph execution state - # This prevents information leakage through the generation graph + sanitized_item.error_type = None + sanitized_item.error_message = None + sanitized_item.error_traceback = None sanitized_item.session = GraphExecutionState( - id=queue_item.session.id, + id="redacted", graph=Graph(), ) return sanitized_item @@ -126,12 +137,16 @@ async def list_all_queue_items( }, ) async def get_queue_item_ids( + current_user: CurrentUserOrDefault, queue_id: str = Path(description="The queue id to perform this operation on"), order_dir: SQLiteDirection = Query(default=SQLiteDirection.Descending, description="The order of sort"), ) -> ItemIdsResult: - """Gets all queue item ids that match the given parameters""" + """Gets all queue item ids that match the given parameters. Non-admin users only see their own items.""" try: - return ApiDependencies.invoker.services.session_queue.get_queue_item_ids(queue_id=queue_id, order_dir=order_dir) + user_id = None if current_user.is_admin else current_user.user_id + return ApiDependencies.invoker.services.session_queue.get_queue_item_ids( + queue_id=queue_id, order_dir=order_dir, user_id=user_id + ) except Exception as e: raise HTTPException(status_code=500, detail=f"Unexpected error while listing all queue item ids: {e}") @@ -376,11 +391,15 @@ async def prune( }, ) async def get_current_queue_item( + current_user: CurrentUserOrDefault, queue_id: str = Path(description="The queue id to perform this operation on"), ) -> Optional[SessionQueueItem]: """Gets the currently execution queue item""" try: - return ApiDependencies.invoker.services.session_queue.get_current(queue_id) + item = ApiDependencies.invoker.services.session_queue.get_current(queue_id) + if item is not None: + item = sanitize_queue_item_for_user(item, current_user.user_id, current_user.is_admin) + return item except Exception as e: raise HTTPException(status_code=500, detail=f"Unexpected error while getting current queue item: {e}") @@ -393,11 +412,15 @@ async def get_current_queue_item( }, ) async def get_next_queue_item( + current_user: CurrentUserOrDefault, queue_id: str = Path(description="The queue id to perform this operation on"), ) -> Optional[SessionQueueItem]: """Gets the next queue item, without executing it""" try: - return ApiDependencies.invoker.services.session_queue.get_next(queue_id) + item = ApiDependencies.invoker.services.session_queue.get_next(queue_id) + if item is not None: + item = sanitize_queue_item_for_user(item, current_user.user_id, current_user.is_admin) + return item except Exception as e: raise HTTPException(status_code=500, detail=f"Unexpected error while getting next queue item: {e}") @@ -413,9 +436,10 @@ async def get_queue_status( current_user: CurrentUserOrDefault, queue_id: str = Path(description="The queue id to perform this operation on"), ) -> SessionQueueAndProcessorStatus: - """Gets the status of the session queue""" + """Gets the status of the session queue. Non-admin users see only their own counts and cannot see current item details unless they own it.""" try: - queue = ApiDependencies.invoker.services.session_queue.get_queue_status(queue_id, user_id=current_user.user_id) + user_id = None if current_user.is_admin else current_user.user_id + queue = ApiDependencies.invoker.services.session_queue.get_queue_status(queue_id, user_id=user_id) processor = ApiDependencies.invoker.services.session_processor.get_status() return SessionQueueAndProcessorStatus(queue=queue, processor=processor) except Exception as e: @@ -430,12 +454,16 @@ async def get_queue_status( }, ) async def get_batch_status( + current_user: CurrentUserOrDefault, queue_id: str = Path(description="The queue id to perform this operation on"), batch_id: str = Path(description="The batch to get the status of"), ) -> BatchStatus: - """Gets the status of the session queue""" + """Gets the status of a batch. Non-admin users only see their own batches.""" try: - return ApiDependencies.invoker.services.session_queue.get_batch_status(queue_id=queue_id, batch_id=batch_id) + user_id = None if current_user.is_admin else current_user.user_id + return ApiDependencies.invoker.services.session_queue.get_batch_status( + queue_id=queue_id, batch_id=batch_id, user_id=user_id + ) except Exception as e: raise HTTPException(status_code=500, detail=f"Unexpected error while getting batch status: {e}") @@ -529,13 +557,15 @@ async def cancel_queue_item( responses={200: {"model": SessionQueueCountsByDestination}}, ) async def counts_by_destination( + current_user: CurrentUserOrDefault, queue_id: str = Path(description="The queue id to query"), destination: str = Query(description="The destination to query"), ) -> SessionQueueCountsByDestination: - """Gets the counts of queue items by destination""" + """Gets the counts of queue items by destination. Non-admin users only see their own items.""" try: + user_id = None if current_user.is_admin else current_user.user_id return ApiDependencies.invoker.services.session_queue.get_counts_by_destination( - queue_id=queue_id, destination=destination + queue_id=queue_id, destination=destination, user_id=user_id ) except Exception as e: raise HTTPException(status_code=500, detail=f"Unexpected error while fetching counts by destination: {e}") diff --git a/invokeai/app/api/routers/workflows.py b/invokeai/app/api/routers/workflows.py index 72d50a416b4..1c88a77a3f6 100644 --- a/invokeai/app/api/routers/workflows.py +++ b/invokeai/app/api/routers/workflows.py @@ -6,6 +6,7 @@ from fastapi.responses import FileResponse from PIL import Image +from invokeai.app.api.auth_dependencies import CurrentUserOrDefault from invokeai.app.api.dependencies import ApiDependencies from invokeai.app.services.shared.pagination import PaginatedResults from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection @@ -33,16 +34,25 @@ }, ) async def get_workflow( + current_user: CurrentUserOrDefault, workflow_id: str = Path(description="The workflow to get"), ) -> WorkflowRecordWithThumbnailDTO: """Gets a workflow""" try: - thumbnail_url = ApiDependencies.invoker.services.workflow_thumbnails.get_url(workflow_id) workflow = ApiDependencies.invoker.services.workflow_records.get(workflow_id) - return WorkflowRecordWithThumbnailDTO(thumbnail_url=thumbnail_url, **workflow.model_dump()) except WorkflowNotFoundError: raise HTTPException(status_code=404, detail="Workflow not found") + config = ApiDependencies.invoker.services.configuration + if config.multiuser: + is_default = workflow.workflow.meta.category is WorkflowCategory.Default + is_owner = workflow.user_id == current_user.user_id + if not (is_default or is_owner or workflow.is_public or current_user.is_admin): + raise HTTPException(status_code=403, detail="Not authorized to access this workflow") + + thumbnail_url = ApiDependencies.invoker.services.workflow_thumbnails.get_url(workflow_id) + return WorkflowRecordWithThumbnailDTO(thumbnail_url=thumbnail_url, **workflow.model_dump()) + @workflows_router.patch( "/i/{workflow_id}", @@ -52,10 +62,21 @@ async def get_workflow( }, ) async def update_workflow( + current_user: CurrentUserOrDefault, workflow: Workflow = Body(description="The updated workflow", embed=True), ) -> WorkflowRecordDTO: """Updates a workflow""" - return ApiDependencies.invoker.services.workflow_records.update(workflow=workflow) + config = ApiDependencies.invoker.services.configuration + if config.multiuser: + try: + existing = ApiDependencies.invoker.services.workflow_records.get(workflow.id) + except WorkflowNotFoundError: + raise HTTPException(status_code=404, detail="Workflow not found") + if not current_user.is_admin and existing.user_id != current_user.user_id: + raise HTTPException(status_code=403, detail="Not authorized to update this workflow") + # Pass user_id for defense-in-depth SQL scoping; admins pass None to allow any. + user_id = None if current_user.is_admin else current_user.user_id + return ApiDependencies.invoker.services.workflow_records.update(workflow=workflow, user_id=user_id) @workflows_router.delete( @@ -63,15 +84,25 @@ async def update_workflow( operation_id="delete_workflow", ) async def delete_workflow( + current_user: CurrentUserOrDefault, workflow_id: str = Path(description="The workflow to delete"), ) -> None: """Deletes a workflow""" + config = ApiDependencies.invoker.services.configuration + if config.multiuser: + try: + existing = ApiDependencies.invoker.services.workflow_records.get(workflow_id) + except WorkflowNotFoundError: + raise HTTPException(status_code=404, detail="Workflow not found") + if not current_user.is_admin and existing.user_id != current_user.user_id: + raise HTTPException(status_code=403, detail="Not authorized to delete this workflow") try: ApiDependencies.invoker.services.workflow_thumbnails.delete(workflow_id) except WorkflowThumbnailFileNotFoundException: # It's OK if the workflow has no thumbnail file. We can still delete the workflow. pass - ApiDependencies.invoker.services.workflow_records.delete(workflow_id) + user_id = None if current_user.is_admin else current_user.user_id + ApiDependencies.invoker.services.workflow_records.delete(workflow_id, user_id=user_id) @workflows_router.post( @@ -82,10 +113,11 @@ async def delete_workflow( }, ) async def create_workflow( + current_user: CurrentUserOrDefault, workflow: WorkflowWithoutID = Body(description="The workflow to create", embed=True), ) -> WorkflowRecordDTO: """Creates a workflow""" - return ApiDependencies.invoker.services.workflow_records.create(workflow=workflow) + return ApiDependencies.invoker.services.workflow_records.create(workflow=workflow, user_id=current_user.user_id) @workflows_router.get( @@ -96,6 +128,7 @@ async def create_workflow( }, ) async def list_workflows( + current_user: CurrentUserOrDefault, page: int = Query(default=0, description="The page to get"), per_page: Optional[int] = Query(default=None, description="The number of workflows per page"), order_by: WorkflowRecordOrderBy = Query( @@ -106,8 +139,19 @@ async def list_workflows( tags: Optional[list[str]] = Query(default=None, description="The tags of workflow to get"), query: Optional[str] = Query(default=None, description="The text to query by (matches name and description)"), has_been_opened: Optional[bool] = Query(default=None, description="Whether to include/exclude recent workflows"), + is_public: Optional[bool] = Query(default=None, description="Filter by public/shared status"), ) -> PaginatedResults[WorkflowRecordListItemWithThumbnailDTO]: """Gets a page of workflows""" + config = ApiDependencies.invoker.services.configuration + + # In multiuser mode, scope user-category workflows to the current user unless fetching shared workflows + user_id_filter: Optional[str] = None + if config.multiuser: + # Only filter 'user' category results by user_id when not explicitly listing public workflows + has_user_category = not categories or WorkflowCategory.User in categories + if has_user_category and is_public is not True: + user_id_filter = current_user.user_id + workflows_with_thumbnails: list[WorkflowRecordListItemWithThumbnailDTO] = [] workflows = ApiDependencies.invoker.services.workflow_records.get_many( order_by=order_by, @@ -118,6 +162,8 @@ async def list_workflows( categories=categories, tags=tags, has_been_opened=has_been_opened, + user_id=user_id_filter, + is_public=is_public, ) for workflow in workflows.items: workflows_with_thumbnails.append( @@ -143,15 +189,20 @@ async def list_workflows( }, ) async def set_workflow_thumbnail( + current_user: CurrentUserOrDefault, workflow_id: str = Path(description="The workflow to update"), image: UploadFile = File(description="The image file to upload"), ): """Sets a workflow's thumbnail image""" try: - ApiDependencies.invoker.services.workflow_records.get(workflow_id) + existing = ApiDependencies.invoker.services.workflow_records.get(workflow_id) except WorkflowNotFoundError: raise HTTPException(status_code=404, detail="Workflow not found") + config = ApiDependencies.invoker.services.configuration + if config.multiuser and not current_user.is_admin and existing.user_id != current_user.user_id: + raise HTTPException(status_code=403, detail="Not authorized to update this workflow") + if not image.content_type or not image.content_type.startswith("image"): raise HTTPException(status_code=415, detail="Not an image") @@ -177,14 +228,19 @@ async def set_workflow_thumbnail( }, ) async def delete_workflow_thumbnail( + current_user: CurrentUserOrDefault, workflow_id: str = Path(description="The workflow to update"), ): """Removes a workflow's thumbnail image""" try: - ApiDependencies.invoker.services.workflow_records.get(workflow_id) + existing = ApiDependencies.invoker.services.workflow_records.get(workflow_id) except WorkflowNotFoundError: raise HTTPException(status_code=404, detail="Workflow not found") + config = ApiDependencies.invoker.services.configuration + if config.multiuser and not current_user.is_admin and existing.user_id != current_user.user_id: + raise HTTPException(status_code=403, detail="Not authorized to update this workflow") + try: ApiDependencies.invoker.services.workflow_thumbnails.delete(workflow_id) except ValueError as e: @@ -206,8 +262,12 @@ async def delete_workflow_thumbnail( async def get_workflow_thumbnail( workflow_id: str = Path(description="The id of the workflow thumbnail to get"), ) -> FileResponse: - """Gets a workflow's thumbnail image""" + """Gets a workflow's thumbnail image. + This endpoint is intentionally unauthenticated because browsers load images + via tags which cannot send Bearer tokens. Workflow IDs are UUIDs, + providing security through unguessability. + """ try: path = ApiDependencies.invoker.services.workflow_thumbnails.get_path(workflow_id) @@ -223,37 +283,91 @@ async def get_workflow_thumbnail( raise HTTPException(status_code=404) +@workflows_router.patch( + "/i/{workflow_id}/is_public", + operation_id="update_workflow_is_public", + responses={ + 200: {"model": WorkflowRecordDTO}, + }, +) +async def update_workflow_is_public( + current_user: CurrentUserOrDefault, + workflow_id: str = Path(description="The workflow to update"), + is_public: bool = Body(description="Whether the workflow should be shared publicly", embed=True), +) -> WorkflowRecordDTO: + """Updates whether a workflow is shared publicly""" + try: + existing = ApiDependencies.invoker.services.workflow_records.get(workflow_id) + except WorkflowNotFoundError: + raise HTTPException(status_code=404, detail="Workflow not found") + + config = ApiDependencies.invoker.services.configuration + if config.multiuser and not current_user.is_admin and existing.user_id != current_user.user_id: + raise HTTPException(status_code=403, detail="Not authorized to update this workflow") + + user_id = None if current_user.is_admin else current_user.user_id + return ApiDependencies.invoker.services.workflow_records.update_is_public( + workflow_id=workflow_id, is_public=is_public, user_id=user_id + ) + + @workflows_router.get("/tags", operation_id="get_all_tags") async def get_all_tags( + current_user: CurrentUserOrDefault, categories: Optional[list[WorkflowCategory]] = Query(default=None, description="The categories to include"), + is_public: Optional[bool] = Query(default=None, description="Filter by public/shared status"), ) -> list[str]: """Gets all unique tags from workflows""" - - return ApiDependencies.invoker.services.workflow_records.get_all_tags(categories=categories) + config = ApiDependencies.invoker.services.configuration + user_id_filter: Optional[str] = None + if config.multiuser: + has_user_category = not categories or WorkflowCategory.User in categories + if has_user_category and is_public is not True: + user_id_filter = current_user.user_id + + return ApiDependencies.invoker.services.workflow_records.get_all_tags( + categories=categories, user_id=user_id_filter, is_public=is_public + ) @workflows_router.get("/counts_by_tag", operation_id="get_counts_by_tag") async def get_counts_by_tag( + current_user: CurrentUserOrDefault, tags: list[str] = Query(description="The tags to get counts for"), categories: Optional[list[WorkflowCategory]] = Query(default=None, description="The categories to include"), has_been_opened: Optional[bool] = Query(default=None, description="Whether to include/exclude recent workflows"), + is_public: Optional[bool] = Query(default=None, description="Filter by public/shared status"), ) -> dict[str, int]: """Counts workflows by tag""" + config = ApiDependencies.invoker.services.configuration + user_id_filter: Optional[str] = None + if config.multiuser: + has_user_category = not categories or WorkflowCategory.User in categories + if has_user_category and is_public is not True: + user_id_filter = current_user.user_id return ApiDependencies.invoker.services.workflow_records.counts_by_tag( - tags=tags, categories=categories, has_been_opened=has_been_opened + tags=tags, categories=categories, has_been_opened=has_been_opened, user_id=user_id_filter, is_public=is_public ) @workflows_router.get("/counts_by_category", operation_id="counts_by_category") async def counts_by_category( + current_user: CurrentUserOrDefault, categories: list[WorkflowCategory] = Query(description="The categories to include"), has_been_opened: Optional[bool] = Query(default=None, description="Whether to include/exclude recent workflows"), + is_public: Optional[bool] = Query(default=None, description="Filter by public/shared status"), ) -> dict[str, int]: """Counts workflows by category""" + config = ApiDependencies.invoker.services.configuration + user_id_filter: Optional[str] = None + if config.multiuser: + has_user_category = WorkflowCategory.User in categories + if has_user_category and is_public is not True: + user_id_filter = current_user.user_id return ApiDependencies.invoker.services.workflow_records.counts_by_category( - categories=categories, has_been_opened=has_been_opened + categories=categories, has_been_opened=has_been_opened, user_id=user_id_filter, is_public=is_public ) @@ -262,7 +376,18 @@ async def counts_by_category( operation_id="update_opened_at", ) async def update_opened_at( + current_user: CurrentUserOrDefault, workflow_id: str = Path(description="The workflow to update"), ) -> None: """Updates the opened_at field of a workflow""" - ApiDependencies.invoker.services.workflow_records.update_opened_at(workflow_id) + try: + existing = ApiDependencies.invoker.services.workflow_records.get(workflow_id) + except WorkflowNotFoundError: + raise HTTPException(status_code=404, detail="Workflow not found") + + config = ApiDependencies.invoker.services.configuration + if config.multiuser and not current_user.is_admin and existing.user_id != current_user.user_id: + raise HTTPException(status_code=403, detail="Not authorized to update this workflow") + + user_id = None if current_user.is_admin else current_user.user_id + ApiDependencies.invoker.services.workflow_records.update_opened_at(workflow_id, user_id=user_id) diff --git a/invokeai/app/api/sockets.py b/invokeai/app/api/sockets.py index fcead54eb1e..5783b804c0b 100644 --- a/invokeai/app/api/sockets.py +++ b/invokeai/app/api/sockets.py @@ -121,6 +121,11 @@ async def _handle_connect(self, sid: str, environ: dict, auth: dict | None) -> b Returns True to accept the connection, False to reject it. Stores user_id in the internal socket users dict for later use. + + In multiuser mode, connections without a valid token are rejected outright + so that anonymous clients cannot subscribe to queue rooms and observe + queue activity belonging to other users. In single-user mode, unauthenticated + connections are accepted as the system admin user. """ # Extract token from auth data or headers token = None @@ -137,6 +142,23 @@ async def _handle_connect(self, sid: str, environ: dict, auth: dict | None) -> b if token: token_data = verify_token(token) if token_data: + # In multiuser mode, also verify the backing user record still + # exists and is active — mirrors the REST auth check in + # auth_dependencies.py. A deleted or deactivated user whose + # JWT has not yet expired must not be allowed to open a socket. + if self._is_multiuser_enabled(): + try: + from invokeai.app.api.dependencies import ApiDependencies + + user = ApiDependencies.invoker.services.users.get(token_data.user_id) + if user is None or not user.is_active: + logger.warning(f"Rejecting socket {sid}: user {token_data.user_id} not found or inactive") + return False + except Exception: + # If user service is unavailable, fail closed + logger.warning(f"Rejecting socket {sid}: unable to verify user record") + return False + # Store user_id and is_admin in socket users dict self._socket_users[sid] = { "user_id": token_data.user_id, @@ -147,14 +169,37 @@ async def _handle_connect(self, sid: str, environ: dict, auth: dict | None) -> b ) return True - # If no valid token, store system user for backward compatibility + # No valid token provided. In multiuser mode this is not allowed — reject + # the connection so anonymous clients cannot subscribe to queue rooms. + # In single-user mode, fall through and accept the socket as system admin. + if self._is_multiuser_enabled(): + logger.warning( + f"Rejecting socket {sid} connection: multiuser mode is enabled and no valid auth token was provided" + ) + return False + self._socket_users[sid] = { "user_id": "system", - "is_admin": False, + "is_admin": True, } - logger.debug(f"Socket {sid} connected as system user (no valid token)") + logger.debug(f"Socket {sid} connected as system admin (single-user mode)") return True + @staticmethod + def _is_multiuser_enabled() -> bool: + """Check whether multiuser mode is enabled. Fails closed if configuration + is not yet initialized, which should not happen in practice but prevents + accidentally opening the socket during startup races.""" + try: + # Imported here to avoid a circular import at module load time. + from invokeai.app.api.dependencies import ApiDependencies + + return bool(ApiDependencies.invoker.services.configuration.multiuser) + except Exception: + # If dependencies are not initialized, fail closed (treat as multiuser) + # so we never accidentally admit an anonymous socket. + return True + async def _handle_disconnect(self, sid: str) -> None: """Handle socket disconnection and cleanup user info.""" if sid in self._socket_users: @@ -165,15 +210,20 @@ async def _handle_sub_queue(self, sid: str, data: Any) -> None: """Handle queue subscription and add socket to both queue and user-specific rooms.""" queue_id = QueueSubscriptionEvent(**data).queue_id - # Check if we have user info for this socket + # Check if we have user info for this socket. In multiuser mode _handle_connect + # will have already rejected any socket without a valid token, so missing user + # info here is a bug — refuse the subscription rather than silently falling back + # to an anonymous system user who could then receive queue item events. if sid not in self._socket_users: - logger.warning( - f"Socket {sid} subscribing to queue {queue_id} but has no user info - need to authenticate via connect event" - ) - # Store as system user temporarily - real auth should happen in connect + if self._is_multiuser_enabled(): + logger.warning( + f"Refusing queue subscription for socket {sid}: no user info (socket not authenticated via connect event)" + ) + return + # Single-user mode: safe to fall back to the system admin user. self._socket_users[sid] = { "user_id": "system", - "is_admin": False, + "is_admin": True, } user_id = self._socket_users[sid]["user_id"] @@ -198,6 +248,13 @@ async def _handle_unsub_queue(self, sid: str, data: Any) -> None: await self._sio.leave_room(sid, QueueSubscriptionEvent(**data).queue_id) async def _handle_sub_bulk_download(self, sid: str, data: Any) -> None: + # In multiuser mode, only allow authenticated sockets to subscribe. + # Bulk download events are routed to user-specific rooms, so the + # bulk_download_id room subscription is only kept for single-user + # backward compatibility. + if self._is_multiuser_enabled() and sid not in self._socket_users: + logger.warning(f"Refusing bulk download subscription for unknown socket {sid} in multiuser mode") + return await self._sio.enter_room(sid, BulkDownloadSubscriptionEvent(**data).bulk_download_id) async def _handle_unsub_bulk_download(self, sid: str, data: Any) -> None: @@ -206,9 +263,17 @@ async def _handle_unsub_bulk_download(self, sid: str, data: Any) -> None: async def _handle_queue_event(self, event: FastAPIEvent[QueueEventBase]): """Handle queue events with user isolation. - Invocation events (progress, started, complete) are private - only emit to owner and admins. - Queue item status events are public - emit to all users (field values hidden via API). - Other queue events emit to all subscribers. + All queue item events (invocation events AND QueueItemStatusChangedEvent) are + private to the owning user and admins. They carry unsanitized user_id, batch_id, + session_id, origin, destination and error metadata, and must never be broadcast + to the whole queue room — otherwise any other authenticated subscriber could + observe cross-user queue activity. + + RecallParametersUpdatedEvent is also private to the owner + admins. + + BatchEnqueuedEvent carries the enqueuing user's batch_id/origin/counts and + is also routed privately. QueueClearedEvent is the only queue event that + is still broadcast to the whole queue room. IMPORTANT: Check InvocationEventBase BEFORE QueueItemEventBase since InvocationEventBase inherits from QueueItemEventBase. The order of isinstance checks matters! @@ -237,24 +302,40 @@ async def _handle_queue_event(self, event: FastAPIEvent[QueueEventBase]): logger.debug(f"Emitted private invocation event {event_name} to user room {user_room} and admin room") - # Queue item status events are visible to all users (field values masked via API) - # This catches QueueItemStatusChangedEvent but NOT InvocationEvents (already handled above) + # Other queue item events (QueueItemStatusChangedEvent) carry unsanitized + # user_id, batch_id, session_id, origin, destination and error metadata. + # They are private to the owning user + admins — never broadcast to the + # full queue room. elif isinstance(event_data, QueueItemEventBase) and hasattr(event_data, "user_id"): - # Emit to all subscribers in the queue - await self._sio.emit( - event=event_name, data=event_data.model_dump(mode="json"), room=event_data.queue_id - ) + user_room = f"user:{event_data.user_id}" + await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room=user_room) + await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room="admin") - logger.info( - f"Emitted public queue item event {event_name} to all subscribers in queue {event_data.queue_id}" - ) + logger.debug(f"Emitted private queue item event {event_name} to user room {user_room} and admin room") + + # RecallParametersUpdatedEvent is private - only emit to owner + admins + elif isinstance(event_data, RecallParametersUpdatedEvent): + user_room = f"user:{event_data.user_id}" + await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room=user_room) + await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room="admin") + logger.debug(f"Emitted private recall_parameters_updated event to user room {user_room} and admin room") + + # BatchEnqueuedEvent carries the enqueuing user's batch_id, origin, and + # enqueued counts. Route it privately to the owner + admins so other + # users do not observe cross-user batch activity. + elif isinstance(event_data, BatchEnqueuedEvent): + user_room = f"user:{event_data.user_id}" + await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room=user_room) + await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room="admin") + logger.debug(f"Emitted private batch_enqueued event to user room {user_room} and admin room") else: - # For other queue events (like QueueClearedEvent, BatchEnqueuedEvent), emit to all subscribers + # For remaining queue events (e.g. QueueClearedEvent) that do not + # carry user identity, emit to all subscribers in the queue room. await self._sio.emit( event=event_name, data=event_data.model_dump(mode="json"), room=event_data.queue_id ) - logger.info( + logger.debug( f"Emitted general queue event {event_name} to all subscribers in queue {event_data.queue_id}" ) except Exception as e: @@ -265,4 +346,17 @@ async def _handle_model_event(self, event: FastAPIEvent[ModelEventBase | Downloa await self._sio.emit(event=event[0], data=event[1].model_dump(mode="json")) async def _handle_bulk_image_download_event(self, event: FastAPIEvent[BulkDownloadEventBase]) -> None: - await self._sio.emit(event=event[0], data=event[1].model_dump(mode="json"), room=event[1].bulk_download_id) + event_name, event_data = event + # Route to user-specific + admin rooms so that other authenticated + # users cannot learn the bulk_download_item_name (the capability token + # needed to fetch the zip from the unauthenticated GET endpoint). + # In single-user mode (user_id="system"), fall back to the shared + # bulk_download_id room for backward compatibility. + if hasattr(event_data, "user_id") and event_data.user_id != "system": + user_room = f"user:{event_data.user_id}" + await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room=user_room) + await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room="admin") + else: + await self._sio.emit( + event=event_name, data=event_data.model_dump(mode="json"), room=event_data.bulk_download_id + ) diff --git a/invokeai/app/services/board_records/board_records_common.py b/invokeai/app/services/board_records/board_records_common.py index ab6355a3930..b263f264cb8 100644 --- a/invokeai/app/services/board_records/board_records_common.py +++ b/invokeai/app/services/board_records/board_records_common.py @@ -9,6 +9,17 @@ from invokeai.app.util.model_exclude_null import BaseModelExcludeNull +class BoardVisibility(str, Enum, metaclass=MetaEnum): + """The visibility options for a board.""" + + Private = "private" + """Only the board owner (and admins) can see and modify this board.""" + Shared = "shared" + """All users can view this board, but only the owner (and admins) can modify it.""" + Public = "public" + """All users can view this board; only the owner (and admins) can modify its structure.""" + + class BoardRecord(BaseModelExcludeNull): """Deserialized board record.""" @@ -28,6 +39,10 @@ class BoardRecord(BaseModelExcludeNull): """The name of the cover image of the board.""" archived: bool = Field(description="Whether or not the board is archived.") """Whether or not the board is archived.""" + board_visibility: BoardVisibility = Field( + default=BoardVisibility.Private, description="The visibility of the board." + ) + """The visibility of the board (private, shared, or public).""" def deserialize_board_record(board_dict: dict) -> BoardRecord: @@ -44,6 +59,11 @@ def deserialize_board_record(board_dict: dict) -> BoardRecord: updated_at = board_dict.get("updated_at", get_iso_timestamp()) deleted_at = board_dict.get("deleted_at", get_iso_timestamp()) archived = board_dict.get("archived", False) + board_visibility_raw = board_dict.get("board_visibility", BoardVisibility.Private.value) + try: + board_visibility = BoardVisibility(board_visibility_raw) + except ValueError: + board_visibility = BoardVisibility.Private return BoardRecord( board_id=board_id, @@ -54,6 +74,7 @@ def deserialize_board_record(board_dict: dict) -> BoardRecord: updated_at=updated_at, deleted_at=deleted_at, archived=archived, + board_visibility=board_visibility, ) @@ -61,6 +82,7 @@ class BoardChanges(BaseModel, extra="forbid"): board_name: Optional[str] = Field(default=None, description="The board's new name.", max_length=300) cover_image_name: Optional[str] = Field(default=None, description="The name of the board's new cover image.") archived: Optional[bool] = Field(default=None, description="Whether or not the board is archived") + board_visibility: Optional[BoardVisibility] = Field(default=None, description="The visibility of the board.") class BoardRecordOrderBy(str, Enum, metaclass=MetaEnum): diff --git a/invokeai/app/services/board_records/board_records_sqlite.py b/invokeai/app/services/board_records/board_records_sqlite.py index a54f65686fd..1e3e11c8a36 100644 --- a/invokeai/app/services/board_records/board_records_sqlite.py +++ b/invokeai/app/services/board_records/board_records_sqlite.py @@ -116,6 +116,17 @@ def update( (changes.archived, board_id), ) + # Change the visibility of a board + if changes.board_visibility is not None: + cursor.execute( + """--sql + UPDATE boards + SET board_visibility = ? + WHERE board_id = ?; + """, + (changes.board_visibility.value, board_id), + ) + except sqlite3.Error as e: raise BoardRecordSaveException from e return self.get(board_id) @@ -155,7 +166,7 @@ def get_many( SELECT DISTINCT boards.* FROM boards LEFT JOIN shared_boards ON boards.board_id = shared_boards.board_id - WHERE (boards.user_id = ? OR shared_boards.user_id = ? OR boards.is_public = 1) + WHERE (boards.user_id = ? OR shared_boards.user_id = ? OR boards.board_visibility IN ('shared', 'public')) {archived_filter} ORDER BY {order_by} {direction} LIMIT ? OFFSET ?; @@ -194,14 +205,14 @@ def get_many( SELECT COUNT(DISTINCT boards.board_id) FROM boards LEFT JOIN shared_boards ON boards.board_id = shared_boards.board_id - WHERE (boards.user_id = ? OR shared_boards.user_id = ? OR boards.is_public = 1); + WHERE (boards.user_id = ? OR shared_boards.user_id = ? OR boards.board_visibility IN ('shared', 'public')); """ else: count_query = """ SELECT COUNT(DISTINCT boards.board_id) FROM boards LEFT JOIN shared_boards ON boards.board_id = shared_boards.board_id - WHERE (boards.user_id = ? OR shared_boards.user_id = ? OR boards.is_public = 1) + WHERE (boards.user_id = ? OR shared_boards.user_id = ? OR boards.board_visibility IN ('shared', 'public')) AND boards.archived = 0; """ @@ -251,7 +262,7 @@ def get_all( SELECT DISTINCT boards.* FROM boards LEFT JOIN shared_boards ON boards.board_id = shared_boards.board_id - WHERE (boards.user_id = ? OR shared_boards.user_id = ? OR boards.is_public = 1) + WHERE (boards.user_id = ? OR shared_boards.user_id = ? OR boards.board_visibility IN ('shared', 'public')) {archived_filter} ORDER BY LOWER(boards.board_name) {direction} """ @@ -260,7 +271,7 @@ def get_all( SELECT DISTINCT boards.* FROM boards LEFT JOIN shared_boards ON boards.board_id = shared_boards.board_id - WHERE (boards.user_id = ? OR shared_boards.user_id = ? OR boards.is_public = 1) + WHERE (boards.user_id = ? OR shared_boards.user_id = ? OR boards.board_visibility IN ('shared', 'public')) {archived_filter} ORDER BY {order_by} {direction} """ diff --git a/invokeai/app/services/bulk_download/bulk_download_base.py b/invokeai/app/services/bulk_download/bulk_download_base.py index 617b611f566..6cd4ed0cbaf 100644 --- a/invokeai/app/services/bulk_download/bulk_download_base.py +++ b/invokeai/app/services/bulk_download/bulk_download_base.py @@ -7,7 +7,11 @@ class BulkDownloadBase(ABC): @abstractmethod def handler( - self, image_names: Optional[list[str]], board_id: Optional[str], bulk_download_item_id: Optional[str] + self, + image_names: Optional[list[str]], + board_id: Optional[str], + bulk_download_item_id: Optional[str], + user_id: str = "system", ) -> None: """ Create a zip file containing the images specified by the given image names or board id. @@ -15,6 +19,7 @@ def handler( :param image_names: A list of image names to include in the zip file. :param board_id: The ID of the board. If provided, all images associated with the board will be included in the zip file. :param bulk_download_item_id: The bulk_download_item_id that will be used to retrieve the bulk download item when it is prepared, if none is provided a uuid will be generated. + :param user_id: The ID of the user who initiated the download. """ @abstractmethod @@ -42,3 +47,12 @@ def delete(self, bulk_download_item_name: str) -> None: :param bulk_download_item_name: The name of the bulk download item. """ + + @abstractmethod + def get_owner(self, bulk_download_item_name: str) -> Optional[str]: + """ + Get the user_id of the user who initiated the download. + + :param bulk_download_item_name: The name of the bulk download item. + :return: The user_id of the owner, or None if not tracked. + """ diff --git a/invokeai/app/services/bulk_download/bulk_download_default.py b/invokeai/app/services/bulk_download/bulk_download_default.py index dc4f8b1d81b..c037e9c5c15 100644 --- a/invokeai/app/services/bulk_download/bulk_download_default.py +++ b/invokeai/app/services/bulk_download/bulk_download_default.py @@ -25,15 +25,24 @@ def __init__(self): self._temp_directory = TemporaryDirectory() self._bulk_downloads_folder = Path(self._temp_directory.name) / "bulk_downloads" self._bulk_downloads_folder.mkdir(parents=True, exist_ok=True) + # Track which user owns each download so the fetch endpoint can enforce ownership + self._download_owners: dict[str, str] = {} def handler( - self, image_names: Optional[list[str]], board_id: Optional[str], bulk_download_item_id: Optional[str] + self, + image_names: Optional[list[str]], + board_id: Optional[str], + bulk_download_item_id: Optional[str], + user_id: str = "system", ) -> None: bulk_download_id: str = DEFAULT_BULK_DOWNLOAD_ID bulk_download_item_id = bulk_download_item_id or uuid_string() bulk_download_item_name = bulk_download_item_id + ".zip" - self._signal_job_started(bulk_download_id, bulk_download_item_id, bulk_download_item_name) + # Record ownership so the fetch endpoint can verify the caller + self._download_owners[bulk_download_item_name] = user_id + + self._signal_job_started(bulk_download_id, bulk_download_item_id, bulk_download_item_name, user_id) try: image_dtos: list[ImageDTO] = [] @@ -46,16 +55,16 @@ def handler( raise BulkDownloadParametersException() bulk_download_item_name: str = self._create_zip_file(image_dtos, bulk_download_item_id) - self._signal_job_completed(bulk_download_id, bulk_download_item_id, bulk_download_item_name) + self._signal_job_completed(bulk_download_id, bulk_download_item_id, bulk_download_item_name, user_id) except ( ImageRecordNotFoundException, BoardRecordNotFoundException, BulkDownloadException, BulkDownloadParametersException, ) as e: - self._signal_job_failed(bulk_download_id, bulk_download_item_id, bulk_download_item_name, e) + self._signal_job_failed(bulk_download_id, bulk_download_item_id, bulk_download_item_name, e, user_id) except Exception as e: - self._signal_job_failed(bulk_download_id, bulk_download_item_id, bulk_download_item_name, e) + self._signal_job_failed(bulk_download_id, bulk_download_item_id, bulk_download_item_name, e, user_id) self._invoker.services.logger.error("Problem bulk downloading images.") raise e @@ -103,43 +112,60 @@ def _clean_string_to_path_safe(self, s: str) -> str: return "".join([c for c in s if c.isalpha() or c.isdigit() or c == " " or c == "_" or c == "-"]).rstrip() def _signal_job_started( - self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str + self, + bulk_download_id: str, + bulk_download_item_id: str, + bulk_download_item_name: str, + user_id: str = "system", ) -> None: """Signal that a bulk download job has started.""" if self._invoker: assert bulk_download_id is not None self._invoker.services.events.emit_bulk_download_started( - bulk_download_id, bulk_download_item_id, bulk_download_item_name + bulk_download_id, bulk_download_item_id, bulk_download_item_name, user_id=user_id ) def _signal_job_completed( - self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str + self, + bulk_download_id: str, + bulk_download_item_id: str, + bulk_download_item_name: str, + user_id: str = "system", ) -> None: """Signal that a bulk download job has completed.""" if self._invoker: assert bulk_download_id is not None assert bulk_download_item_name is not None self._invoker.services.events.emit_bulk_download_complete( - bulk_download_id, bulk_download_item_id, bulk_download_item_name + bulk_download_id, bulk_download_item_id, bulk_download_item_name, user_id=user_id ) def _signal_job_failed( - self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str, exception: Exception + self, + bulk_download_id: str, + bulk_download_item_id: str, + bulk_download_item_name: str, + exception: Exception, + user_id: str = "system", ) -> None: """Signal that a bulk download job has failed.""" if self._invoker: assert bulk_download_id is not None assert exception is not None self._invoker.services.events.emit_bulk_download_error( - bulk_download_id, bulk_download_item_id, bulk_download_item_name, str(exception) + bulk_download_id, bulk_download_item_id, bulk_download_item_name, str(exception), user_id=user_id ) def stop(self, *args, **kwargs): self._temp_directory.cleanup() + def get_owner(self, bulk_download_item_name: str) -> Optional[str]: + return self._download_owners.get(bulk_download_item_name) + def delete(self, bulk_download_item_name: str) -> None: path = self.get_path(bulk_download_item_name) Path(path).unlink() + self._download_owners.pop(bulk_download_item_name, None) def get_path(self, bulk_download_item_name: str) -> str: path = str(self._bulk_downloads_folder / bulk_download_item_name) diff --git a/invokeai/app/services/events/events_base.py b/invokeai/app/services/events/events_base.py index aa1cbb5e0ee..935b422a732 100644 --- a/invokeai/app/services/events/events_base.py +++ b/invokeai/app/services/events/events_base.py @@ -100,9 +100,9 @@ def emit_queue_item_status_changed( """Emitted when a queue item's status changes""" self.dispatch(QueueItemStatusChangedEvent.build(queue_item, batch_status, queue_status)) - def emit_batch_enqueued(self, enqueue_result: "EnqueueBatchResult") -> None: + def emit_batch_enqueued(self, enqueue_result: "EnqueueBatchResult", user_id: str = "system") -> None: """Emitted when a batch is enqueued""" - self.dispatch(BatchEnqueuedEvent.build(enqueue_result)) + self.dispatch(BatchEnqueuedEvent.build(enqueue_result, user_id)) def emit_queue_items_retried(self, retry_result: "RetryItemsResult") -> None: """Emitted when a list of queue items are retried""" @@ -112,9 +112,9 @@ def emit_queue_cleared(self, queue_id: str) -> None: """Emitted when a queue is cleared""" self.dispatch(QueueClearedEvent.build(queue_id)) - def emit_recall_parameters_updated(self, queue_id: str, parameters: dict) -> None: + def emit_recall_parameters_updated(self, queue_id: str, user_id: str, parameters: dict) -> None: """Emitted when recall parameters are updated""" - self.dispatch(RecallParametersUpdatedEvent.build(queue_id, parameters)) + self.dispatch(RecallParametersUpdatedEvent.build(queue_id, user_id, parameters)) # endregion @@ -194,23 +194,42 @@ def emit_model_install_error(self, job: "ModelInstallJob") -> None: # region Bulk image download def emit_bulk_download_started( - self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str + self, + bulk_download_id: str, + bulk_download_item_id: str, + bulk_download_item_name: str, + user_id: str = "system", ) -> None: """Emitted when a bulk image download is started""" - self.dispatch(BulkDownloadStartedEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name)) + self.dispatch( + BulkDownloadStartedEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name, user_id) + ) def emit_bulk_download_complete( - self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str + self, + bulk_download_id: str, + bulk_download_item_id: str, + bulk_download_item_name: str, + user_id: str = "system", ) -> None: """Emitted when a bulk image download is complete""" - self.dispatch(BulkDownloadCompleteEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name)) + self.dispatch( + BulkDownloadCompleteEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name, user_id) + ) def emit_bulk_download_error( - self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str, error: str + self, + bulk_download_id: str, + bulk_download_item_id: str, + bulk_download_item_name: str, + error: str, + user_id: str = "system", ) -> None: """Emitted when a bulk image download has an error""" self.dispatch( - BulkDownloadErrorEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name, error) + BulkDownloadErrorEvent.build( + bulk_download_id, bulk_download_item_id, bulk_download_item_name, error, user_id + ) ) # endregion diff --git a/invokeai/app/services/events/events_common.py b/invokeai/app/services/events/events_common.py index bfb44eb48e8..998fe4f5309 100644 --- a/invokeai/app/services/events/events_common.py +++ b/invokeai/app/services/events/events_common.py @@ -281,9 +281,10 @@ class BatchEnqueuedEvent(QueueEventBase): ) priority: int = Field(description="The priority of the batch") origin: str | None = Field(default=None, description="The origin of the batch") + user_id: str = Field(default="system", description="The ID of the user who enqueued the batch") @classmethod - def build(cls, enqueue_result: EnqueueBatchResult) -> "BatchEnqueuedEvent": + def build(cls, enqueue_result: EnqueueBatchResult, user_id: str = "system") -> "BatchEnqueuedEvent": return cls( queue_id=enqueue_result.queue_id, batch_id=enqueue_result.batch.batch_id, @@ -291,6 +292,7 @@ def build(cls, enqueue_result: EnqueueBatchResult) -> "BatchEnqueuedEvent": enqueued=enqueue_result.enqueued, requested=enqueue_result.requested, priority=enqueue_result.priority, + user_id=user_id, ) @@ -609,6 +611,7 @@ class BulkDownloadEventBase(EventBase): bulk_download_id: str = Field(description="The ID of the bulk image download") bulk_download_item_id: str = Field(description="The ID of the bulk image download item") bulk_download_item_name: str = Field(description="The name of the bulk image download item") + user_id: str = Field(default="system", description="The ID of the user who initiated the download") @payload_schema.register @@ -619,12 +622,17 @@ class BulkDownloadStartedEvent(BulkDownloadEventBase): @classmethod def build( - cls, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str + cls, + bulk_download_id: str, + bulk_download_item_id: str, + bulk_download_item_name: str, + user_id: str = "system", ) -> "BulkDownloadStartedEvent": return cls( bulk_download_id=bulk_download_id, bulk_download_item_id=bulk_download_item_id, bulk_download_item_name=bulk_download_item_name, + user_id=user_id, ) @@ -636,12 +644,17 @@ class BulkDownloadCompleteEvent(BulkDownloadEventBase): @classmethod def build( - cls, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str + cls, + bulk_download_id: str, + bulk_download_item_id: str, + bulk_download_item_name: str, + user_id: str = "system", ) -> "BulkDownloadCompleteEvent": return cls( bulk_download_id=bulk_download_id, bulk_download_item_id=bulk_download_item_id, bulk_download_item_name=bulk_download_item_name, + user_id=user_id, ) @@ -655,13 +668,19 @@ class BulkDownloadErrorEvent(BulkDownloadEventBase): @classmethod def build( - cls, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str, error: str + cls, + bulk_download_id: str, + bulk_download_item_id: str, + bulk_download_item_name: str, + error: str, + user_id: str = "system", ) -> "BulkDownloadErrorEvent": return cls( bulk_download_id=bulk_download_id, bulk_download_item_id=bulk_download_item_id, bulk_download_item_name=bulk_download_item_name, error=error, + user_id=user_id, ) @@ -671,8 +690,9 @@ class RecallParametersUpdatedEvent(QueueEventBase): __event_name__ = "recall_parameters_updated" + user_id: str = Field(description="The ID of the user whose recall parameters were updated") parameters: dict[str, Any] = Field(description="The recall parameters that were updated") @classmethod - def build(cls, queue_id: str, parameters: dict[str, Any]) -> "RecallParametersUpdatedEvent": - return cls(queue_id=queue_id, parameters=parameters) + def build(cls, queue_id: str, user_id: str, parameters: dict[str, Any]) -> "RecallParametersUpdatedEvent": + return cls(queue_id=queue_id, user_id=user_id, parameters=parameters) diff --git a/invokeai/app/services/events/events_fastapievents.py b/invokeai/app/services/events/events_fastapievents.py index f44eecc5559..90e1402773d 100644 --- a/invokeai/app/services/events/events_fastapievents.py +++ b/invokeai/app/services/events/events_fastapievents.py @@ -46,3 +46,9 @@ async def _dispatch_from_queue(self, stop_event: threading.Event): except asyncio.CancelledError as e: raise e # Raise a proper error + except Exception: + import logging + + logging.getLogger("InvokeAI").error( + f"Error dispatching event {getattr(event, '__event_name__', event)}", exc_info=True + ) diff --git a/invokeai/app/services/image_records/image_records_base.py b/invokeai/app/services/image_records/image_records_base.py index 16405c52708..457cf2f4686 100644 --- a/invokeai/app/services/image_records/image_records_base.py +++ b/invokeai/app/services/image_records/image_records_base.py @@ -74,8 +74,8 @@ def delete_intermediates(self) -> list[str]: pass @abstractmethod - def get_intermediates_count(self) -> int: - """Gets a count of all intermediate images.""" + def get_intermediates_count(self, user_id: Optional[str] = None) -> int: + """Gets a count of intermediate images. If user_id is provided, only counts that user's intermediates.""" pass @abstractmethod @@ -97,6 +97,11 @@ def save( """Saves an image record.""" pass + @abstractmethod + def get_user_id(self, image_name: str) -> Optional[str]: + """Gets the user_id of the image owner. Returns None if image not found.""" + pass + @abstractmethod def get_most_recent_image_for_board(self, board_id: str) -> Optional[ImageRecord]: """Gets the most recent image for a board.""" diff --git a/invokeai/app/services/image_records/image_records_sqlite.py b/invokeai/app/services/image_records/image_records_sqlite.py index c6c237fc1e7..07126d53a9f 100644 --- a/invokeai/app/services/image_records/image_records_sqlite.py +++ b/invokeai/app/services/image_records/image_records_sqlite.py @@ -46,6 +46,20 @@ def get(self, image_name: str) -> ImageRecord: return deserialize_image_record(dict(result)) + def get_user_id(self, image_name: str) -> Optional[str]: + with self._db.transaction() as cursor: + cursor.execute( + """--sql + SELECT user_id FROM images + WHERE image_name = ?; + """, + (image_name,), + ) + result = cast(Optional[sqlite3.Row], cursor.fetchone()) + if not result: + return None + return cast(Optional[str], dict(result).get("user_id")) + def get_metadata(self, image_name: str) -> Optional[MetadataField]: with self._db.transaction() as cursor: try: @@ -269,14 +283,14 @@ def delete_many(self, image_names: list[str]) -> None: except sqlite3.Error as e: raise ImageRecordDeleteException from e - def get_intermediates_count(self) -> int: + def get_intermediates_count(self, user_id: Optional[str] = None) -> int: with self._db.transaction() as cursor: - cursor.execute( - """--sql - SELECT COUNT(*) FROM images - WHERE is_intermediate = TRUE; - """ - ) + query = "SELECT COUNT(*) FROM images WHERE is_intermediate = TRUE" + params: list[str] = [] + if user_id is not None: + query += " AND user_id = ?" + params.append(user_id) + cursor.execute(query, params) count = cast(int, cursor.fetchone()[0]) return count diff --git a/invokeai/app/services/images/images_base.py b/invokeai/app/services/images/images_base.py index d11d75b3c1d..aebbead2f35 100644 --- a/invokeai/app/services/images/images_base.py +++ b/invokeai/app/services/images/images_base.py @@ -143,8 +143,8 @@ def delete_intermediates(self) -> int: pass @abstractmethod - def get_intermediates_count(self) -> int: - """Gets the number of intermediate images.""" + def get_intermediates_count(self, user_id: Optional[str] = None) -> int: + """Gets the number of intermediate images. If user_id is provided, only counts that user's intermediates.""" pass @abstractmethod diff --git a/invokeai/app/services/images/images_default.py b/invokeai/app/services/images/images_default.py index e82bd7f4de1..0f03f7c400e 100644 --- a/invokeai/app/services/images/images_default.py +++ b/invokeai/app/services/images/images_default.py @@ -310,9 +310,9 @@ def delete_intermediates(self) -> int: self.__invoker.services.logger.error("Problem deleting image records and files") raise e - def get_intermediates_count(self) -> int: + def get_intermediates_count(self, user_id: Optional[str] = None) -> int: try: - return self.__invoker.services.image_records.get_intermediates_count() + return self.__invoker.services.image_records.get_intermediates_count(user_id=user_id) except Exception as e: self.__invoker.services.logger.error("Problem getting intermediates count") raise e diff --git a/invokeai/app/services/session_queue/session_queue_base.py b/invokeai/app/services/session_queue/session_queue_base.py index 3c037dc77ab..14b93d97fc7 100644 --- a/invokeai/app/services/session_queue/session_queue_base.py +++ b/invokeai/app/services/session_queue/session_queue_base.py @@ -78,13 +78,15 @@ def get_queue_status(self, queue_id: str, user_id: Optional[str] = None) -> Sess pass @abstractmethod - def get_counts_by_destination(self, queue_id: str, destination: str) -> SessionQueueCountsByDestination: - """Gets the counts of queue items by destination""" + def get_counts_by_destination( + self, queue_id: str, destination: str, user_id: Optional[str] = None + ) -> SessionQueueCountsByDestination: + """Gets the counts of queue items by destination. If user_id is provided, only counts that user's items.""" pass @abstractmethod - def get_batch_status(self, queue_id: str, batch_id: str) -> BatchStatus: - """Gets the status of a batch""" + def get_batch_status(self, queue_id: str, batch_id: str, user_id: Optional[str] = None) -> BatchStatus: + """Gets the status of a batch. If user_id is provided, only counts that user's items.""" pass @abstractmethod @@ -172,8 +174,9 @@ def get_queue_item_ids( self, queue_id: str, order_dir: SQLiteDirection = SQLiteDirection.Descending, + user_id: Optional[str] = None, ) -> ItemIdsResult: - """Gets all queue item ids that match the given parameters""" + """Gets all queue item ids that match the given parameters. If user_id is provided, only returns items for that user.""" pass @abstractmethod diff --git a/invokeai/app/services/session_queue/session_queue_common.py b/invokeai/app/services/session_queue/session_queue_common.py index 58544422119..09820fe6217 100644 --- a/invokeai/app/services/session_queue/session_queue_common.py +++ b/invokeai/app/services/session_queue/session_queue_common.py @@ -304,12 +304,6 @@ class SessionQueueStatus(BaseModel): failed: int = Field(..., description="Number of queue items with status 'error'") canceled: int = Field(..., description="Number of queue items with status 'canceled'") total: int = Field(..., description="Total number of queue items") - user_pending: Optional[int] = Field( - default=None, description="Number of queue items with status 'pending' for the current user" - ) - user_in_progress: Optional[int] = Field( - default=None, description="Number of queue items with status 'in_progress' for the current user" - ) class SessionQueueCountsByDestination(BaseModel): diff --git a/invokeai/app/services/session_queue/session_queue_sqlite.py b/invokeai/app/services/session_queue/session_queue_sqlite.py index 4f46136fd79..070a7cef293 100644 --- a/invokeai/app/services/session_queue/session_queue_sqlite.py +++ b/invokeai/app/services/session_queue/session_queue_sqlite.py @@ -151,7 +151,7 @@ async def enqueue_batch( priority=priority, item_ids=item_ids, ) - self.__invoker.services.events.emit_batch_enqueued(enqueue_result) + self.__invoker.services.events.emit_batch_enqueued(enqueue_result, user_id=user_id) return enqueue_result def dequeue(self) -> Optional[SessionQueueItem]: @@ -765,15 +765,21 @@ def get_queue_item_ids( self, queue_id: str, order_dir: SQLiteDirection = SQLiteDirection.Descending, + user_id: Optional[str] = None, ) -> ItemIdsResult: with self._db.transaction() as cursor_: - query = f"""--sql + query = """--sql SELECT item_id FROM session_queue WHERE queue_id = ? - ORDER BY created_at {order_dir.value} """ - query_params = [queue_id] + query_params: list[str] = [queue_id] + + if user_id is not None: + query += " AND user_id = ?" + query_params.append(user_id) + + query += f" ORDER BY created_at {order_dir.value}" cursor_.execute(query, query_params) result = cast(list[sqlite3.Row], cursor_.fetchall()) @@ -783,20 +789,7 @@ def get_queue_item_ids( def get_queue_status(self, queue_id: str, user_id: Optional[str] = None) -> SessionQueueStatus: with self._db.transaction() as cursor: - # Get total counts - cursor.execute( - """--sql - SELECT status, count(*) - FROM session_queue - WHERE queue_id = ? - GROUP BY status - """, - (queue_id,), - ) - counts_result = cast(list[sqlite3.Row], cursor.fetchall()) - - # Get user-specific counts if user_id is provided (using a single query with CASE) - user_counts_result = [] + # When user_id is provided (non-admin), only count that user's items if user_id is not None: cursor.execute( """--sql @@ -807,48 +800,51 @@ def get_queue_status(self, queue_id: str, user_id: Optional[str] = None) -> Sess """, (queue_id, user_id), ) - user_counts_result = cast(list[sqlite3.Row], cursor.fetchall()) + else: + cursor.execute( + """--sql + SELECT status, count(*) + FROM session_queue + WHERE queue_id = ? + GROUP BY status + """, + (queue_id,), + ) + counts_result = cast(list[sqlite3.Row], cursor.fetchall()) current_item = self.get_current(queue_id=queue_id) total = sum(row[1] or 0 for row in counts_result) counts: dict[str, int] = {row[0]: row[1] for row in counts_result} - # Process user-specific counts if available - user_pending = None - user_in_progress = None - if user_id is not None: - user_counts: dict[str, int] = {row[0]: row[1] for row in user_counts_result} - user_pending = user_counts.get("pending", 0) - user_in_progress = user_counts.get("in_progress", 0) + # For non-admin users, hide current item details if they don't own it + show_current_item = current_item is not None and (user_id is None or current_item.user_id == user_id) return SessionQueueStatus( queue_id=queue_id, - item_id=current_item.item_id if current_item else None, - session_id=current_item.session_id if current_item else None, - batch_id=current_item.batch_id if current_item else None, + item_id=current_item.item_id if show_current_item else None, + session_id=current_item.session_id if show_current_item else None, + batch_id=current_item.batch_id if show_current_item else None, pending=counts.get("pending", 0), in_progress=counts.get("in_progress", 0), completed=counts.get("completed", 0), failed=counts.get("failed", 0), canceled=counts.get("canceled", 0), total=total, - user_pending=user_pending, - user_in_progress=user_in_progress, ) - def get_batch_status(self, queue_id: str, batch_id: str) -> BatchStatus: + def get_batch_status(self, queue_id: str, batch_id: str, user_id: Optional[str] = None) -> BatchStatus: with self._db.transaction() as cursor: - cursor.execute( - """--sql + query = """--sql SELECT status, count(*), origin, destination FROM session_queue - WHERE - queue_id = ? - AND batch_id = ? - GROUP BY status - """, - (queue_id, batch_id), - ) + WHERE queue_id = ? AND batch_id = ? + """ + params: list[str] = [queue_id, batch_id] + if user_id is not None: + query += " AND user_id = ?" + params.append(user_id) + query += " GROUP BY status" + cursor.execute(query, params) result = cast(list[sqlite3.Row], cursor.fetchall()) total = sum(row[1] or 0 for row in result) counts: dict[str, int] = {row[0]: row[1] for row in result} @@ -868,18 +864,21 @@ def get_batch_status(self, queue_id: str, batch_id: str) -> BatchStatus: total=total, ) - def get_counts_by_destination(self, queue_id: str, destination: str) -> SessionQueueCountsByDestination: + def get_counts_by_destination( + self, queue_id: str, destination: str, user_id: Optional[str] = None + ) -> SessionQueueCountsByDestination: with self._db.transaction() as cursor: - cursor.execute( - """--sql + query = """--sql SELECT status, count(*) FROM session_queue - WHERE queue_id = ? - AND destination = ? - GROUP BY status - """, - (queue_id, destination), - ) + WHERE queue_id = ? AND destination = ? + """ + params: list[str] = [queue_id, destination] + if user_id is not None: + query += " AND user_id = ?" + params.append(user_id) + query += " GROUP BY status" + cursor.execute(query, params) counts_result = cast(list[sqlite3.Row], cursor.fetchall()) total = sum(row[1] or 0 for row in counts_result) diff --git a/invokeai/app/services/shared/sqlite/sqlite_util.py b/invokeai/app/services/shared/sqlite/sqlite_util.py index 645509f1dde..fb8ca9fca38 100644 --- a/invokeai/app/services/shared/sqlite/sqlite_util.py +++ b/invokeai/app/services/shared/sqlite/sqlite_util.py @@ -30,6 +30,8 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_25 import build_migration_25 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_26 import build_migration_26 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_27 import build_migration_27 +from invokeai.app.services.shared.sqlite_migrator.migrations.migration_28 import build_migration_28 +from invokeai.app.services.shared.sqlite_migrator.migrations.migration_29 import build_migration_29 from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator @@ -77,6 +79,8 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto migrator.register_migration(build_migration_25(app_config=config, logger=logger)) migrator.register_migration(build_migration_26(app_config=config, logger=logger)) migrator.register_migration(build_migration_27()) + migrator.register_migration(build_migration_28()) + migrator.register_migration(build_migration_29()) migrator.run_migrations() return db diff --git a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_28.py b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_28.py new file mode 100644 index 00000000000..0cbd683ab5e --- /dev/null +++ b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_28.py @@ -0,0 +1,45 @@ +"""Migration 28: Add per-user workflow isolation columns to workflow_library. + +This migration adds the database columns required for multiuser workflow isolation +to the workflow_library table: +- user_id: the owner of the workflow (defaults to 'system' for existing workflows) +- is_public: whether the workflow is shared with all users +""" + +import sqlite3 + +from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration + + +class Migration28Callback: + """Migration to add user_id and is_public to the workflow_library table.""" + + def __call__(self, cursor: sqlite3.Cursor) -> None: + self._update_workflow_library_table(cursor) + + def _update_workflow_library_table(self, cursor: sqlite3.Cursor) -> None: + """Add user_id and is_public columns to workflow_library table.""" + cursor.execute("PRAGMA table_info(workflow_library);") + columns = [row[1] for row in cursor.fetchall()] + + if "user_id" not in columns: + cursor.execute("ALTER TABLE workflow_library ADD COLUMN user_id TEXT DEFAULT 'system';") + cursor.execute("CREATE INDEX IF NOT EXISTS idx_workflow_library_user_id ON workflow_library(user_id);") + + if "is_public" not in columns: + cursor.execute("ALTER TABLE workflow_library ADD COLUMN is_public BOOLEAN NOT NULL DEFAULT FALSE;") + cursor.execute("CREATE INDEX IF NOT EXISTS idx_workflow_library_is_public ON workflow_library(is_public);") + + +def build_migration_28() -> Migration: + """Builds the migration object for migrating from version 27 to version 28. + + This migration adds per-user workflow isolation to the workflow_library table: + - user_id column: identifies the owner of each workflow + - is_public column: controls whether a workflow is shared with all users + """ + return Migration( + from_version=27, + to_version=28, + callback=Migration28Callback(), + ) diff --git a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_29.py b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_29.py new file mode 100644 index 00000000000..c9eb7c901ba --- /dev/null +++ b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_29.py @@ -0,0 +1,53 @@ +"""Migration 29: Add board_visibility column to boards table. + +This migration adds a board_visibility column to the boards table to support +three visibility levels: + - 'private': only the board owner (and admins) can view/modify + - 'shared': all users can view, but only the owner (and admins) can modify + - 'public': all users can view; only the owner (and admins) can modify the + board structure (rename/archive/delete) + +Existing boards with is_public = 1 are migrated to 'public'. +All other existing boards default to 'private'. +""" + +import sqlite3 + +from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration + + +class Migration29Callback: + """Migration to add board_visibility column to the boards table.""" + + def __call__(self, cursor: sqlite3.Cursor) -> None: + self._update_boards_table(cursor) + + def _update_boards_table(self, cursor: sqlite3.Cursor) -> None: + """Add board_visibility column to boards table.""" + # Check if boards table exists + cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='boards';") + if cursor.fetchone() is None: + return + + cursor.execute("PRAGMA table_info(boards);") + columns = [row[1] for row in cursor.fetchall()] + + if "board_visibility" not in columns: + cursor.execute("ALTER TABLE boards ADD COLUMN board_visibility TEXT NOT NULL DEFAULT 'private';") + cursor.execute("CREATE INDEX IF NOT EXISTS idx_boards_board_visibility ON boards(board_visibility);") + # Migrate existing is_public = 1 boards to 'public' + if "is_public" in columns: + cursor.execute("UPDATE boards SET board_visibility = 'public' WHERE is_public = 1;") + + +def build_migration_29() -> Migration: + """Builds the migration object for migrating from version 28 to version 29. + + This migration adds the board_visibility column to the boards table, + supporting 'private', 'shared', and 'public' visibility levels. + """ + return Migration( + from_version=28, + to_version=29, + callback=Migration29Callback(), + ) diff --git a/invokeai/app/services/users/users_base.py b/invokeai/app/services/users/users_base.py index 728a0adfa37..dd789b561ee 100644 --- a/invokeai/app/services/users/users_base.py +++ b/invokeai/app/services/users/users_base.py @@ -131,6 +131,15 @@ def list_users(self, limit: int = 100, offset: int = 0) -> list[UserDTO]: """ pass + @abstractmethod + def get_admin_email(self) -> str | None: + """Get the email address of the first active admin user. + + Returns: + Email address of the first active admin, or None if no admin exists + """ + pass + @abstractmethod def count_admins(self) -> int: """Count active admin users. diff --git a/invokeai/app/services/users/users_default.py b/invokeai/app/services/users/users_default.py index 709e4cb82c6..6e472882124 100644 --- a/invokeai/app/services/users/users_default.py +++ b/invokeai/app/services/users/users_default.py @@ -256,6 +256,20 @@ def list_users(self, limit: int = 100, offset: int = 0) -> list[UserDTO]: for row in rows ] + def get_admin_email(self) -> str | None: + """Get the email address of the first active admin user.""" + with self._db.transaction() as cursor: + cursor.execute( + """ + SELECT email FROM users + WHERE is_admin = TRUE AND is_active = TRUE + ORDER BY created_at ASC + LIMIT 1 + """, + ) + row = cursor.fetchone() + return row[0] if row else None + def count_admins(self) -> int: """Count active admin users.""" with self._db.transaction() as cursor: diff --git a/invokeai/app/services/workflow_records/workflow_records_base.py b/invokeai/app/services/workflow_records/workflow_records_base.py index d5cf319594b..856a6c6d490 100644 --- a/invokeai/app/services/workflow_records/workflow_records_base.py +++ b/invokeai/app/services/workflow_records/workflow_records_base.py @@ -4,6 +4,7 @@ from invokeai.app.services.shared.pagination import PaginatedResults from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection from invokeai.app.services.workflow_records.workflow_records_common import ( + WORKFLOW_LIBRARY_DEFAULT_USER_ID, Workflow, WorkflowCategory, WorkflowRecordDTO, @@ -22,18 +23,18 @@ def get(self, workflow_id: str) -> WorkflowRecordDTO: pass @abstractmethod - def create(self, workflow: WorkflowWithoutID) -> WorkflowRecordDTO: + def create(self, workflow: WorkflowWithoutID, user_id: str = WORKFLOW_LIBRARY_DEFAULT_USER_ID) -> WorkflowRecordDTO: """Creates a workflow.""" pass @abstractmethod - def update(self, workflow: Workflow) -> WorkflowRecordDTO: - """Updates a workflow.""" + def update(self, workflow: Workflow, user_id: Optional[str] = None) -> WorkflowRecordDTO: + """Updates a workflow. When user_id is provided, the UPDATE is scoped to that user.""" pass @abstractmethod - def delete(self, workflow_id: str) -> None: - """Deletes a workflow.""" + def delete(self, workflow_id: str, user_id: Optional[str] = None) -> None: + """Deletes a workflow. When user_id is provided, the DELETE is scoped to that user.""" pass @abstractmethod @@ -47,6 +48,8 @@ def get_many( query: Optional[str], tags: Optional[list[str]], has_been_opened: Optional[bool], + user_id: Optional[str] = None, + is_public: Optional[bool] = None, ) -> PaginatedResults[WorkflowRecordListItemDTO]: """Gets many workflows.""" pass @@ -56,6 +59,8 @@ def counts_by_category( self, categories: list[WorkflowCategory], has_been_opened: Optional[bool] = None, + user_id: Optional[str] = None, + is_public: Optional[bool] = None, ) -> dict[str, int]: """Gets a dictionary of counts for each of the provided categories.""" pass @@ -66,19 +71,28 @@ def counts_by_tag( tags: list[str], categories: Optional[list[WorkflowCategory]] = None, has_been_opened: Optional[bool] = None, + user_id: Optional[str] = None, + is_public: Optional[bool] = None, ) -> dict[str, int]: """Gets a dictionary of counts for each of the provided tags.""" pass @abstractmethod - def update_opened_at(self, workflow_id: str) -> None: - """Open a workflow.""" + def update_opened_at(self, workflow_id: str, user_id: Optional[str] = None) -> None: + """Open a workflow. When user_id is provided, the UPDATE is scoped to that user.""" pass @abstractmethod def get_all_tags( self, categories: Optional[list[WorkflowCategory]] = None, + user_id: Optional[str] = None, + is_public: Optional[bool] = None, ) -> list[str]: """Gets all unique tags from workflows.""" pass + + @abstractmethod + def update_is_public(self, workflow_id: str, is_public: bool, user_id: Optional[str] = None) -> WorkflowRecordDTO: + """Updates the is_public field of a workflow. When user_id is provided, the UPDATE is scoped to that user.""" + pass diff --git a/invokeai/app/services/workflow_records/workflow_records_common.py b/invokeai/app/services/workflow_records/workflow_records_common.py index e0cea37468d..9c505530c90 100644 --- a/invokeai/app/services/workflow_records/workflow_records_common.py +++ b/invokeai/app/services/workflow_records/workflow_records_common.py @@ -9,6 +9,9 @@ __workflow_meta_version__ = semver.Version.parse("1.0.0") +WORKFLOW_LIBRARY_DEFAULT_USER_ID = "system" +"""Default user_id for workflows created in single-user mode or migrated from pre-multiuser databases.""" + class ExposedField(BaseModel): nodeId: str @@ -26,6 +29,7 @@ class WorkflowRecordOrderBy(str, Enum, metaclass=MetaEnum): UpdatedAt = "updated_at" OpenedAt = "opened_at" Name = "name" + IsPublic = "is_public" class WorkflowCategory(str, Enum, metaclass=MetaEnum): @@ -100,6 +104,8 @@ class WorkflowRecordDTOBase(BaseModel): opened_at: Optional[Union[datetime.datetime, str]] = Field( default=None, description="The opened timestamp of the workflow." ) + user_id: str = Field(description="The id of the user who owns this workflow.") + is_public: bool = Field(description="Whether this workflow is shared with all users.") class WorkflowRecordDTO(WorkflowRecordDTOBase): diff --git a/invokeai/app/services/workflow_records/workflow_records_sqlite.py b/invokeai/app/services/workflow_records/workflow_records_sqlite.py index 0f72f7cd92c..c83d87eff68 100644 --- a/invokeai/app/services/workflow_records/workflow_records_sqlite.py +++ b/invokeai/app/services/workflow_records/workflow_records_sqlite.py @@ -7,6 +7,7 @@ from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase from invokeai.app.services.workflow_records.workflow_records_base import WorkflowRecordsStorageBase from invokeai.app.services.workflow_records.workflow_records_common import ( + WORKFLOW_LIBRARY_DEFAULT_USER_ID, Workflow, WorkflowCategory, WorkflowNotFoundError, @@ -36,7 +37,7 @@ def get(self, workflow_id: str) -> WorkflowRecordDTO: with self._db.transaction() as cursor: cursor.execute( """--sql - SELECT workflow_id, workflow, name, created_at, updated_at, opened_at + SELECT workflow_id, workflow, name, created_at, updated_at, opened_at, user_id, is_public FROM workflow_library WHERE workflow_id = ?; """, @@ -47,7 +48,7 @@ def get(self, workflow_id: str) -> WorkflowRecordDTO: raise WorkflowNotFoundError(f"Workflow with id {workflow_id} not found") return WorkflowRecordDTO.from_dict(dict(row)) - def create(self, workflow: WorkflowWithoutID) -> WorkflowRecordDTO: + def create(self, workflow: WorkflowWithoutID, user_id: str = WORKFLOW_LIBRARY_DEFAULT_USER_ID) -> WorkflowRecordDTO: if workflow.meta.category is WorkflowCategory.Default: raise ValueError("Default workflows cannot be created via this method") @@ -57,43 +58,98 @@ def create(self, workflow: WorkflowWithoutID) -> WorkflowRecordDTO: """--sql INSERT OR IGNORE INTO workflow_library ( workflow_id, - workflow + workflow, + user_id ) - VALUES (?, ?); + VALUES (?, ?, ?); """, - (workflow_with_id.id, workflow_with_id.model_dump_json()), + (workflow_with_id.id, workflow_with_id.model_dump_json(), user_id), ) return self.get(workflow_with_id.id) - def update(self, workflow: Workflow) -> WorkflowRecordDTO: + def update(self, workflow: Workflow, user_id: Optional[str] = None) -> WorkflowRecordDTO: if workflow.meta.category is WorkflowCategory.Default: raise ValueError("Default workflows cannot be updated") with self._db.transaction() as cursor: - cursor.execute( - """--sql - UPDATE workflow_library - SET workflow = ? - WHERE workflow_id = ? AND category = 'user'; - """, - (workflow.model_dump_json(), workflow.id), - ) + if user_id is not None: + cursor.execute( + """--sql + UPDATE workflow_library + SET workflow = ? + WHERE workflow_id = ? AND category = 'user' AND user_id = ?; + """, + (workflow.model_dump_json(), workflow.id, user_id), + ) + else: + cursor.execute( + """--sql + UPDATE workflow_library + SET workflow = ? + WHERE workflow_id = ? AND category = 'user'; + """, + (workflow.model_dump_json(), workflow.id), + ) return self.get(workflow.id) - def delete(self, workflow_id: str) -> None: + def delete(self, workflow_id: str, user_id: Optional[str] = None) -> None: if self.get(workflow_id).workflow.meta.category is WorkflowCategory.Default: raise ValueError("Default workflows cannot be deleted") with self._db.transaction() as cursor: - cursor.execute( - """--sql - DELETE from workflow_library - WHERE workflow_id = ? AND category = 'user'; - """, - (workflow_id,), - ) + if user_id is not None: + cursor.execute( + """--sql + DELETE from workflow_library + WHERE workflow_id = ? AND category = 'user' AND user_id = ?; + """, + (workflow_id, user_id), + ) + else: + cursor.execute( + """--sql + DELETE from workflow_library + WHERE workflow_id = ? AND category = 'user'; + """, + (workflow_id,), + ) return None + def update_is_public(self, workflow_id: str, is_public: bool, user_id: Optional[str] = None) -> WorkflowRecordDTO: + """Updates the is_public field of a workflow and manages the 'shared' tag automatically.""" + record = self.get(workflow_id) + workflow = record.workflow + + # Manage "shared" tag: add when public, remove when private + tags_list = [t.strip() for t in workflow.tags.split(",") if t.strip()] if workflow.tags else [] + if is_public and "shared" not in tags_list: + tags_list.append("shared") + elif not is_public and "shared" in tags_list: + tags_list.remove("shared") + updated_tags = ", ".join(tags_list) + updated_workflow = workflow.model_copy(update={"tags": updated_tags}) + + with self._db.transaction() as cursor: + if user_id is not None: + cursor.execute( + """--sql + UPDATE workflow_library + SET workflow = ?, is_public = ? + WHERE workflow_id = ? AND category = 'user' AND user_id = ?; + """, + (updated_workflow.model_dump_json(), is_public, workflow_id, user_id), + ) + else: + cursor.execute( + """--sql + UPDATE workflow_library + SET workflow = ?, is_public = ? + WHERE workflow_id = ? AND category = 'user'; + """, + (updated_workflow.model_dump_json(), is_public, workflow_id), + ) + return self.get(workflow_id) + def get_many( self, order_by: WorkflowRecordOrderBy, @@ -104,6 +160,8 @@ def get_many( query: Optional[str] = None, tags: Optional[list[str]] = None, has_been_opened: Optional[bool] = None, + user_id: Optional[str] = None, + is_public: Optional[bool] = None, ) -> PaginatedResults[WorkflowRecordListItemDTO]: with self._db.transaction() as cursor: # sanitize! @@ -122,7 +180,9 @@ def get_many( created_at, updated_at, opened_at, - tags + tags, + user_id, + is_public FROM workflow_library """ count_query = "SELECT COUNT(*) FROM workflow_library" @@ -177,6 +237,16 @@ def get_many( conditions.append(query_condition) params.extend([wildcard_query, wildcard_query, wildcard_query]) + if user_id is not None: + # Scope to the given user but always include default workflows + conditions.append("(user_id = ? OR category = 'default')") + params.append(user_id) + + if is_public is True: + conditions.append("is_public = TRUE") + elif is_public is False: + conditions.append("is_public = FALSE") + if conditions: # If there are conditions, add a WHERE clause and then join the conditions main_query += " WHERE " @@ -226,6 +296,8 @@ def counts_by_tag( tags: list[str], categories: Optional[list[WorkflowCategory]] = None, has_been_opened: Optional[bool] = None, + user_id: Optional[str] = None, + is_public: Optional[bool] = None, ) -> dict[str, int]: if not tags: return {} @@ -248,6 +320,16 @@ def counts_by_tag( elif has_been_opened is False: base_conditions.append("opened_at IS NULL") + if user_id is not None: + # Scope to the given user but always include default workflows + base_conditions.append("(user_id = ? OR category = 'default')") + base_params.append(user_id) + + if is_public is True: + base_conditions.append("is_public = TRUE") + elif is_public is False: + base_conditions.append("is_public = FALSE") + # For each tag to count, run a separate query for tag in tags: # Start with the base conditions @@ -277,6 +359,8 @@ def counts_by_category( self, categories: list[WorkflowCategory], has_been_opened: Optional[bool] = None, + user_id: Optional[str] = None, + is_public: Optional[bool] = None, ) -> dict[str, int]: with self._db.transaction() as cursor: result: dict[str, int] = {} @@ -296,6 +380,16 @@ def counts_by_category( elif has_been_opened is False: base_conditions.append("opened_at IS NULL") + if user_id is not None: + # Scope to the given user but always include default workflows + base_conditions.append("(user_id = ? OR category = 'default')") + base_params.append(user_id) + + if is_public is True: + base_conditions.append("is_public = TRUE") + elif is_public is False: + base_conditions.append("is_public = FALSE") + # For each category to count, run a separate query for category in categories: # Start with the base conditions @@ -321,20 +415,32 @@ def counts_by_category( return result - def update_opened_at(self, workflow_id: str) -> None: + def update_opened_at(self, workflow_id: str, user_id: Optional[str] = None) -> None: with self._db.transaction() as cursor: - cursor.execute( - f"""--sql - UPDATE workflow_library - SET opened_at = STRFTIME('{SQL_TIME_FORMAT}', 'NOW') - WHERE workflow_id = ?; - """, - (workflow_id,), - ) + if user_id is not None: + cursor.execute( + f"""--sql + UPDATE workflow_library + SET opened_at = STRFTIME('{SQL_TIME_FORMAT}', 'NOW') + WHERE workflow_id = ? AND user_id = ?; + """, + (workflow_id, user_id), + ) + else: + cursor.execute( + f"""--sql + UPDATE workflow_library + SET opened_at = STRFTIME('{SQL_TIME_FORMAT}', 'NOW') + WHERE workflow_id = ?; + """, + (workflow_id,), + ) def get_all_tags( self, categories: Optional[list[WorkflowCategory]] = None, + user_id: Optional[str] = None, + is_public: Optional[bool] = None, ) -> list[str]: with self._db.transaction() as cursor: conditions: list[str] = [] @@ -349,6 +455,16 @@ def get_all_tags( conditions.append(f"category IN ({placeholders})") params.extend([category.value for category in categories]) + if user_id is not None: + # Scope to the given user but always include default workflows + conditions.append("(user_id = ? OR category = 'default')") + params.append(user_id) + + if is_public is True: + conditions.append("is_public = TRUE") + elif is_public is False: + conditions.append("is_public = FALSE") + stmt = """--sql SELECT DISTINCT tags FROM workflow_library diff --git a/invokeai/frontend/web/openapi.json b/invokeai/frontend/web/openapi.json index af8476528d6..19e5a3a68e9 100644 --- a/invokeai/frontend/web/openapi.json +++ b/invokeai/frontend/web/openapi.json @@ -6463,6 +6463,23 @@ "title": "Has Been Opened" }, "description": "Whether to include/exclude recent workflows" + }, + { + "name": "is_public", + "in": "query", + "required": false, + "schema": { + "anyOf": [ + { + "type": "boolean" + }, + { + "type": "null" + } + ], + "title": "Is Public" + }, + "description": "Filter by public/shared status" } ], "responses": { @@ -6655,6 +6672,23 @@ "title": "Categories" }, "description": "The categories to include" + }, + { + "name": "is_public", + "in": "query", + "required": false, + "schema": { + "anyOf": [ + { + "type": "boolean" + }, + { + "type": "null" + } + ], + "title": "Is Public" + }, + "description": "Filter by public/shared status" } ], "responses": { @@ -6744,6 +6778,23 @@ "title": "Has Been Opened" }, "description": "Whether to include/exclude recent workflows" + }, + { + "name": "is_public", + "in": "query", + "required": false, + "schema": { + "anyOf": [ + { + "type": "boolean" + }, + { + "type": "null" + } + ], + "title": "Is Public" + }, + "description": "Filter by public/shared status" } ], "responses": { @@ -6812,6 +6863,23 @@ "title": "Has Been Opened" }, "description": "Whether to include/exclude recent workflows" + }, + { + "name": "is_public", + "in": "query", + "required": false, + "schema": { + "anyOf": [ + { + "type": "boolean" + }, + { + "type": "null" + } + ], + "title": "Is Public" + }, + "description": "Filter by public/shared status" } ], "responses": { @@ -7352,6 +7420,67 @@ } } } + }, + "/api/v1/workflows/i/{workflow_id}/is_public": { + "patch": { + "tags": ["workflows"], + "summary": "Update Workflow Is Public", + "description": "Updates whether a workflow is shared publicly", + "operationId": "update_workflow_is_public", + "parameters": [ + { + "name": "workflow_id", + "in": "path", + "required": true, + "schema": { + "type": "string", + "title": "Workflow Id" + }, + "description": "The workflow to update" + } + ], + "requestBody": { + "content": { + "application/json": { + "schema": { + "properties": { + "is_public": { + "type": "boolean", + "title": "Is Public", + "description": "Whether the workflow should be shared publicly" + } + }, + "type": "object", + "required": ["is_public"], + "title": "Body_update_workflow_is_public" + } + } + }, + "required": true + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/WorkflowRecordDTO" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + } } }, "components": { @@ -59137,10 +59266,20 @@ "workflow": { "$ref": "#/components/schemas/Workflow", "description": "The workflow." + }, + "user_id": { + "type": "string", + "title": "User Id", + "description": "The id of the user who owns this workflow." + }, + "is_public": { + "type": "boolean", + "title": "Is Public", + "description": "Whether this workflow is shared with all users." } }, "type": "object", - "required": ["workflow_id", "name", "created_at", "updated_at", "workflow"], + "required": ["workflow_id", "name", "created_at", "updated_at", "workflow", "user_id", "is_public"], "title": "WorkflowRecordDTO" }, "WorkflowRecordListItemWithThumbnailDTO": { @@ -59222,15 +59361,35 @@ ], "title": "Thumbnail Url", "description": "The URL of the workflow thumbnail." + }, + "user_id": { + "type": "string", + "title": "User Id", + "description": "The id of the user who owns this workflow." + }, + "is_public": { + "type": "boolean", + "title": "Is Public", + "description": "Whether this workflow is shared with all users." } }, "type": "object", - "required": ["workflow_id", "name", "created_at", "updated_at", "description", "category", "tags"], + "required": [ + "workflow_id", + "name", + "created_at", + "updated_at", + "description", + "category", + "tags", + "user_id", + "is_public" + ], "title": "WorkflowRecordListItemWithThumbnailDTO" }, "WorkflowRecordOrderBy": { "type": "string", - "enum": ["created_at", "updated_at", "opened_at", "name"], + "enum": ["created_at", "updated_at", "opened_at", "name", "is_public"], "title": "WorkflowRecordOrderBy", "description": "The order by options for workflow records" }, @@ -59303,10 +59462,20 @@ ], "title": "Thumbnail Url", "description": "The URL of the workflow thumbnail." + }, + "user_id": { + "type": "string", + "title": "User Id", + "description": "The id of the user who owns this workflow." + }, + "is_public": { + "type": "boolean", + "title": "Is Public", + "description": "Whether this workflow is shared with all users." } }, "type": "object", - "required": ["workflow_id", "name", "created_at", "updated_at", "workflow"], + "required": ["workflow_id", "name", "created_at", "updated_at", "workflow", "user_id", "is_public"], "title": "WorkflowRecordWithThumbnailDTO" }, "WorkflowWithoutID": { diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index 9b2aaddad73..201ea8badba 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -161,7 +161,17 @@ "imagesWithCount_other": "{{count}} images", "assetsWithCount_one": "{{count}} asset", "assetsWithCount_other": "{{count}} assets", - "updateBoardError": "Error updating board" + "updateBoardError": "Error updating board", + "setBoardVisibility": "Set Board Visibility", + "setVisibilityPrivate": "Set Private", + "setVisibilityShared": "Set Shared", + "setVisibilityPublic": "Set Public", + "visibilityPrivate": "Private", + "visibilityShared": "Shared", + "visibilityPublic": "Public", + "visibilityBadgeShared": "Shared board", + "visibilityBadgePublic": "Public board", + "updateBoardVisibilityError": "Error updating board visibility" }, "accordions": { "generation": { @@ -1168,7 +1178,9 @@ "name": "Name", "modelPickerFallbackNoModelsInstalled": "No models installed.", "modelPickerFallbackNoModelsInstalled2": "Visit the Model Manager to install models.", + "modelPickerFallbackNoModelsInstalledNonAdmin": "No models installed. Ask your InvokeAI administrator () to install some models.", "noModelsInstalledDesc1": "Install models with the", + "noModelsInstalledAskAdmin": "Ask your administrator to install some.", "noModelSelected": "No Model Selected", "noMatchingModels": "No matching models", "noModelsInstalled": "No models installed", @@ -1535,6 +1547,7 @@ "info": "Info", "invoke": { "addingImagesTo": "Adding images to", + "boardNotWritable": "You do not have write access to board \"{{boardName}}\". Select a board you own or switch to Uncategorized.", "modelDisabledForTrial": "Generating with {{modelName}} is not available on trial accounts. Visit your account settings to upgrade.", "invoke": "Invoke", "missingFieldTemplate": "Missing field template", @@ -2287,6 +2300,8 @@ "tags": "Tags", "yourWorkflows": "Your Workflows", "recentlyOpened": "Recently Opened", + "sharedWorkflows": "Shared Workflows", + "shareWorkflow": "Shared workflow", "noRecentWorkflows": "No Recent Workflows", "private": "Private", "shared": "Shared", @@ -3021,6 +3036,7 @@ "tileOverlap": "Tile Overlap", "postProcessingMissingModelWarning": "Visit the Model Manager to install a post-processing (image to image) model.", "missingModelsWarning": "Visit the Model Manager to install the required models:", + "missingModelsWarningNonAdmin": "Ask your InvokeAI administrator () to install the required models:", "mainModelDesc": "Main model (SD1.5 or SDXL architecture)", "tileControlNetModelDesc": "Tile ControlNet model for the chosen main model architecture", "upscaleModelDesc": "Upscale (image to image) model", @@ -3129,6 +3145,7 @@ }, "workflows": { "description": "Workflows are reusable templates that automate image generation tasks, allowing you to quickly perform complex operations and get consistent results.", + "descriptionMultiuser": "Workflows are reusable templates that automate image generation tasks, allowing you to quickly perform complex operations and get consistent results. You may share your workflows with other users of the system by selecting 'Shared workflow' when you create or edit it.", "learnMoreLink": "Learn more about creating workflows", "browseTemplates": { "title": "Browse Workflow Templates", @@ -3207,9 +3224,11 @@ "toGetStartedLocal": "To get started, make sure to download or import models needed to run Invoke. Then, enter a prompt in the box and click Invoke to generate your first image. Select a prompt template to improve results. You can choose to save your images directly to the Gallery or edit them to the Canvas.", "toGetStarted": "To get started, enter a prompt in the box and click Invoke to generate your first image. Select a prompt template to improve results. You can choose to save your images directly to the Gallery or edit them to the Canvas.", "toGetStartedWorkflow": "To get started, fill in the fields on the left and press Invoke to generate your image. Want to explore more workflows? Click the folder icon next to the workflow title to see a list of other templates you can try.", + "toGetStartedNonAdmin": "To get started, ask your InvokeAI administrator () to install the AI models needed to run Invoke. Then, enter a prompt in the box and click Invoke to generate your first image. Select a prompt template to improve results. You can choose to save your images directly to the Gallery or edit them to the Canvas.", "gettingStartedSeries": "Want more guidance? Check out our Getting Started Series for tips on unlocking the full potential of the Invoke Studio.", "lowVRAMMode": "For best performance, follow our Low VRAM guide.", - "noModelsInstalled": "It looks like you don't have any models installed! You can download a starter model bundle or import models." + "noModelsInstalled": "It looks like you don't have any models installed! You can download a starter model bundle or import models.", + "noModelsInstalledAskAdmin": "Ask your administrator to install some." }, "whatsNew": { "whatsNewInInvoke": "What's New in Invoke", diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/bulkDownload.tsx b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/bulkDownload.tsx index e0e72d12ffd..fa4c29b8f42 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/bulkDownload.tsx +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/bulkDownload.tsx @@ -12,10 +12,14 @@ export const addBulkDownloadListeners = (startAppListening: AppStartListening) = effect: (action) => { log.debug(action.payload, 'Bulk download requested'); - // If we have an item name, we are processing the bulk download locally and should use it as the toast id to - // prevent multiple toasts for the same item. + // Use a "preparing:" prefix so this toast cannot collide with the + // "ready to download" toast that arrives via the bulk_download_complete + // socket event. The background task can complete in under 20ms, so the + // socket event may arrive *before* this Redux middleware runs — without + // distinct IDs the "preparing" toast would overwrite the "ready" toast. + const itemName = action.payload.bulk_download_item_name; toast({ - id: action.payload.bulk_download_item_name ?? undefined, + id: itemName ? `preparing:${itemName}` : undefined, title: t('gallery.bulkDownloadRequested'), status: 'success', // Show the response message if it exists, otherwise show the default message diff --git a/invokeai/frontend/web/src/features/changeBoardModal/components/ChangeBoardModal.tsx b/invokeai/frontend/web/src/features/changeBoardModal/components/ChangeBoardModal.tsx index 00217eb7963..5ac6ffcb7c9 100644 --- a/invokeai/frontend/web/src/features/changeBoardModal/components/ChangeBoardModal.tsx +++ b/invokeai/frontend/web/src/features/changeBoardModal/components/ChangeBoardModal.tsx @@ -3,6 +3,7 @@ import { Combobox, ConfirmationAlertDialog, Flex, FormControl, Text } from '@inv import { createSelector } from '@reduxjs/toolkit'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAssertSingleton } from 'common/hooks/useAssertSingleton'; +import { selectCurrentUser } from 'features/auth/store/authSlice'; import { changeBoardReset, isModalOpenChanged, @@ -13,6 +14,7 @@ import { memo, useCallback, useMemo, useState } from 'react'; import { useTranslation } from 'react-i18next'; import { useListAllBoardsQuery } from 'services/api/endpoints/boards'; import { useAddImagesToBoardMutation, useRemoveImagesFromBoardMutation } from 'services/api/endpoints/images'; +import type { BoardDTO } from 'services/api/types'; const selectImagesToChange = createSelector( selectChangeBoardModalSlice, @@ -28,6 +30,7 @@ const ChangeBoardModal = () => { useAssertSingleton('ChangeBoardModal'); const dispatch = useAppDispatch(); const currentBoardId = useAppSelector(selectSelectedBoardId); + const currentUser = useAppSelector(selectCurrentUser); const [selectedBoardId, setSelectedBoardId] = useState(); const { data: boards, isFetching } = useListAllBoardsQuery({ include_archived: true }); const isModalOpen = useAppSelector(selectIsModalOpen); @@ -36,10 +39,20 @@ const ChangeBoardModal = () => { const [removeImagesFromBoard] = useRemoveImagesFromBoardMutation(); const { t } = useTranslation(); + // Returns true if the current user can write images to the given board. + const canWriteToBoard = useCallback( + (board: BoardDTO): boolean => { + const isOwnerOrAdmin = !currentUser || currentUser.is_admin || board.user_id === currentUser.user_id; + return isOwnerOrAdmin || board.board_visibility === 'public'; + }, + [currentUser] + ); + const options = useMemo(() => { return [{ label: t('boards.uncategorized'), value: 'none' }] .concat( (boards ?? []) + .filter(canWriteToBoard) .map((board) => ({ label: board.board_name, value: board.board_id, @@ -47,7 +60,7 @@ const ChangeBoardModal = () => { .sort((a, b) => a.label.localeCompare(b.label)) ) .filter((board) => board.value !== currentBoardId); - }, [boards, currentBoardId, t]); + }, [boards, canWriteToBoard, currentBoardId, t]); const value = useMemo(() => options.find((o) => o.value === selectedBoardId), [options, selectedBoardId]); diff --git a/invokeai/frontend/web/src/features/dnd/dnd.ts b/invokeai/frontend/web/src/features/dnd/dnd.ts index f5e38d4b944..ee648e82ef6 100644 --- a/invokeai/frontend/web/src/features/dnd/dnd.ts +++ b/invokeai/frontend/web/src/features/dnd/dnd.ts @@ -434,6 +434,49 @@ export const replaceCanvasEntityObjectsWithImageDndTarget: DndTarget< //#endregion //#region Add To Board +/** + * Check whether the current user can move images out of their source board. + * Returns false if the source board is a shared board not owned by the current user + * (and the user is not an admin). In that case, images can be viewed/used but not moved. + */ +const canMoveFromSourceBoard = (sourceBoardId: BoardId, getState: AppGetState): boolean => { + const state = getState(); + // In single-user mode (no auth), always allow + const currentUser = state.auth?.user; + if (!currentUser) { + return true; + } + // Admins can always move + if (currentUser.is_admin) { + return true; + } + // "Uncategorized" (none) — user's own uncategorized images, allow + if (sourceBoardId === 'none') { + return true; + } + // Look up the board from the RTK Query cache + const boardsQueryState = state.api?.queries; + if (boardsQueryState) { + for (const query of Object.values(boardsQueryState)) { + if (query?.data && Array.isArray(query.data)) { + const board = (query.data as Array<{ board_id: string; user_id?: string; board_visibility?: string }>).find( + (b) => b.board_id === sourceBoardId + ); + if (board) { + // Owner can always move + if (board.user_id === currentUser.user_id) { + return true; + } + // Non-owner can only move from public boards + return board.board_visibility === 'public'; + } + } + } + } + // Board not found in cache — allow by default to avoid blocking legitimate operations + return true; +}; + const _addToBoard = buildTypeAndKey('add-to-board'); export type AddImageToBoardDndTargetData = DndData< typeof _addToBoard.type, @@ -447,16 +490,23 @@ export const addImageToBoardDndTarget: DndTarget< ..._addToBoard, typeGuard: buildTypeGuard(_addToBoard.key), getData: buildGetData(_addToBoard.key, _addToBoard.type), - isValid: ({ sourceData, targetData }) => { + isValid: ({ sourceData, targetData, getState }) => { if (singleImageDndSource.typeGuard(sourceData)) { const currentBoard = sourceData.payload.imageDTO.board_id ?? 'none'; const destinationBoard = targetData.payload.boardId; - return currentBoard !== destinationBoard; + if (currentBoard === destinationBoard) { + return false; + } + // Don't allow moving images from shared boards the user doesn't own + return canMoveFromSourceBoard(currentBoard, getState); } if (multipleImageDndSource.typeGuard(sourceData)) { const currentBoard = sourceData.payload.board_id; const destinationBoard = targetData.payload.boardId; - return currentBoard !== destinationBoard; + if (currentBoard === destinationBoard) { + return false; + } + return canMoveFromSourceBoard(currentBoard, getState); } return false; }, @@ -491,15 +541,22 @@ export const removeImageFromBoardDndTarget: DndTarget< ..._removeFromBoard, typeGuard: buildTypeGuard(_removeFromBoard.key), getData: buildGetData(_removeFromBoard.key, _removeFromBoard.type), - isValid: ({ sourceData }) => { + isValid: ({ sourceData, getState }) => { if (singleImageDndSource.typeGuard(sourceData)) { const currentBoard = sourceData.payload.imageDTO.board_id ?? 'none'; - return currentBoard !== 'none'; + if (currentBoard === 'none') { + return false; + } + // Don't allow removing images from shared boards the user doesn't own + return canMoveFromSourceBoard(currentBoard, getState); } if (multipleImageDndSource.typeGuard(sourceData)) { const currentBoard = sourceData.payload.board_id; - return currentBoard !== 'none'; + if (currentBoard === 'none') { + return false; + } + return canMoveFromSourceBoard(currentBoard, getState); } return false; diff --git a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardContextMenu.tsx b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardContextMenu.tsx index 5cc25f6c038..d10dde6ee44 100644 --- a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardContextMenu.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardContextMenu.tsx @@ -2,15 +2,26 @@ import type { ContextMenuProps } from '@invoke-ai/ui-library'; import { ContextMenu, MenuGroup, MenuItem, MenuList } from '@invoke-ai/ui-library'; import { createSelector } from '@reduxjs/toolkit'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { selectCurrentUser } from 'features/auth/store/authSlice'; import { $boardToDelete } from 'features/gallery/components/Boards/DeleteBoardModal'; import { selectAutoAddBoardId, selectAutoAssignBoardOnClick } from 'features/gallery/store/gallerySelectors'; import { autoAddBoardIdChanged } from 'features/gallery/store/gallerySlice'; import { toast } from 'features/toast/toast'; import { memo, useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; -import { PiArchiveBold, PiArchiveFill, PiDownloadBold, PiPlusBold, PiTrashSimpleBold } from 'react-icons/pi'; +import { + PiArchiveBold, + PiArchiveFill, + PiDownloadBold, + PiGlobeBold, + PiLockBold, + PiPlusBold, + PiShareNetworkBold, + PiTrashSimpleBold, +} from 'react-icons/pi'; import { useUpdateBoardMutation } from 'services/api/endpoints/boards'; import { useBulkDownloadImagesMutation } from 'services/api/endpoints/images'; +import { useBoardAccess } from 'services/api/hooks/useBoardAccess'; import { useBoardName } from 'services/api/hooks/useBoardName'; import type { BoardDTO } from 'services/api/types'; @@ -23,6 +34,7 @@ const BoardContextMenu = ({ board, children }: Props) => { const { t } = useTranslation(); const dispatch = useAppDispatch(); const autoAssignBoardOnClick = useAppSelector(selectAutoAssignBoardOnClick); + const currentUser = useAppSelector(selectCurrentUser); const selectIsSelectedForAutoAdd = useMemo( () => createSelector(selectAutoAddBoardId, (autoAddBoardId) => board.board_id === autoAddBoardId), [board.board_id] @@ -35,6 +47,11 @@ const BoardContextMenu = ({ board, children }: Props) => { const [bulkDownload] = useBulkDownloadImagesMutation(); + // Only the board owner or admin can modify visibility + const canChangeVisibility = currentUser !== null && (currentUser.is_admin || board.user_id === currentUser.user_id); + + const { canDeleteBoard } = useBoardAccess(board); + const handleSetAutoAdd = useCallback(() => { dispatch(autoAddBoardIdChanged(board.board_id)); }, [board.board_id, dispatch]); @@ -64,6 +81,26 @@ const BoardContextMenu = ({ board, children }: Props) => { }); }, [board.board_id, updateBoard]); + const handleSetVisibility = useCallback( + async (visibility: 'private' | 'shared' | 'public') => { + try { + await updateBoard({ + board_id: board.board_id, + changes: { board_visibility: visibility }, + }).unwrap(); + } catch { + toast({ status: 'error', title: t('boards.updateBoardVisibilityError') }); + } + }, + [board.board_id, t, updateBoard] + ); + + const handleSetVisibilityPrivate = useCallback(() => handleSetVisibility('private'), [handleSetVisibility]); + + const handleSetVisibilityShared = useCallback(() => handleSetVisibility('shared'), [handleSetVisibility]); + + const handleSetVisibilityPublic = useCallback(() => handleSetVisibility('public'), [handleSetVisibility]); + const setAsBoardToDelete = useCallback(() => { $boardToDelete.set(board); }, [board]); @@ -83,18 +120,50 @@ const BoardContextMenu = ({ board, children }: Props) => { {board.archived && ( - } onClick={handleUnarchive}> + } onClick={handleUnarchive} isDisabled={!canDeleteBoard}> {t('boards.unarchiveBoard')} )} {!board.archived && ( - } onClick={handleArchive}> + } onClick={handleArchive} isDisabled={!canDeleteBoard}> {t('boards.archiveBoard')} )} - } onClick={setAsBoardToDelete} isDestructive> + {canChangeVisibility && ( + <> + } + onClick={handleSetVisibilityPrivate} + isDisabled={board.board_visibility === 'private'} + > + {t('boards.setVisibilityPrivate')} + + } + onClick={handleSetVisibilityShared} + isDisabled={board.board_visibility === 'shared'} + > + {t('boards.setVisibilityShared')} + + } + onClick={handleSetVisibilityPublic} + isDisabled={board.board_visibility === 'public'} + > + {t('boards.setVisibilityPublic')} + + + )} + + } + onClick={setAsBoardToDelete} + isDestructive + isDisabled={!canDeleteBoard} + > {t('boards.deleteBoard')} @@ -108,8 +177,14 @@ const BoardContextMenu = ({ board, children }: Props) => { t, handleBulkDownload, board.archived, + board.board_visibility, handleUnarchive, handleArchive, + canChangeVisibility, + handleSetVisibilityPrivate, + handleSetVisibilityShared, + handleSetVisibilityPublic, + canDeleteBoard, setAsBoardToDelete, ] ); diff --git a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/BoardEditableTitle.tsx b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/BoardEditableTitle.tsx index 67c7dad6ed0..cf2749e3400 100644 --- a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/BoardEditableTitle.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/BoardEditableTitle.tsx @@ -7,6 +7,7 @@ import { memo, useCallback, useRef } from 'react'; import { useTranslation } from 'react-i18next'; import { PiPencilBold } from 'react-icons/pi'; import { useUpdateBoardMutation } from 'services/api/endpoints/boards'; +import { useBoardAccess } from 'services/api/hooks/useBoardAccess'; import type { BoardDTO } from 'services/api/types'; type Props = { @@ -19,6 +20,7 @@ export const BoardEditableTitle = memo(({ board, isSelected }: Props) => { const isHovering = useBoolean(false); const inputRef = useRef(null); const [updateBoard, updateBoardResult] = useUpdateBoardMutation(); + const { canRenameBoard } = useBoardAccess(board); const onChange = useCallback( async (board_name: string) => { @@ -51,13 +53,13 @@ export const BoardEditableTitle = memo(({ board, isSelected }: Props) => { fontWeight="semibold" userSelect="none" color={isSelected ? 'base.100' : 'base.300'} - onDoubleClick={editable.startEditing} - cursor="text" + onDoubleClick={canRenameBoard ? editable.startEditing : undefined} + cursor={canRenameBoard ? 'text' : 'default'} noOfLines={1} > {editable.value} - {isHovering.isTrue && ( + {canRenameBoard && isHovering.isTrue && ( } diff --git a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/GalleryBoard.tsx b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/GalleryBoard.tsx index 4d821f819c6..10fbe618322 100644 --- a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/GalleryBoard.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/GalleryBoard.tsx @@ -18,8 +18,9 @@ import { import { autoAddBoardIdChanged, boardIdSelected } from 'features/gallery/store/gallerySlice'; import { memo, useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; -import { PiArchiveBold, PiImageSquare } from 'react-icons/pi'; +import { PiArchiveBold, PiGlobeBold, PiImageSquare, PiShareNetworkBold } from 'react-icons/pi'; import { useGetImageDTOQuery } from 'services/api/endpoints/images'; +import { useBoardAccess } from 'services/api/hooks/useBoardAccess'; import type { BoardDTO } from 'services/api/types'; const _hover: SystemStyleObject = { @@ -62,6 +63,8 @@ const GalleryBoard = ({ board, isSelected }: GalleryBoardProps) => { const showOwner = currentUser?.is_admin && board.owner_username; + const { canWriteImages } = useBoardAccess(board); + return ( @@ -99,6 +102,20 @@ const GalleryBoard = ({ board, isSelected }: GalleryBoardProps) => { {autoAddBoardId === board.board_id && } {board.archived && } + {board.board_visibility === 'shared' && ( + + + + + + )} + {board.board_visibility === 'public' && ( + + + + + + )} {board.image_count} | {board.asset_count} @@ -108,7 +125,12 @@ const GalleryBoard = ({ board, isSelected }: GalleryBoardProps) => { )} - + ); }; diff --git a/invokeai/frontend/web/src/features/gallery/components/ContextMenu/MenuItems/ContextMenuItemChangeBoard.tsx b/invokeai/frontend/web/src/features/gallery/components/ContextMenu/MenuItems/ContextMenuItemChangeBoard.tsx index 71764870153..f5c044132e5 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ContextMenu/MenuItems/ContextMenuItemChangeBoard.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ContextMenu/MenuItems/ContextMenuItemChangeBoard.tsx @@ -5,11 +5,15 @@ import { useImageDTOContext } from 'features/gallery/contexts/ImageDTOContext'; import { memo, useCallback } from 'react'; import { useTranslation } from 'react-i18next'; import { PiFoldersBold } from 'react-icons/pi'; +import { useBoardAccess } from 'services/api/hooks/useBoardAccess'; +import { useSelectedBoard } from 'services/api/hooks/useSelectedBoard'; export const ContextMenuItemChangeBoard = memo(() => { const { t } = useTranslation(); const dispatch = useAppDispatch(); const imageDTO = useImageDTOContext(); + const selectedBoard = useSelectedBoard(); + const { canWriteImages } = useBoardAccess(selectedBoard); const onClick = useCallback(() => { dispatch(imagesToChangeSelected([imageDTO.image_name])); @@ -17,7 +21,7 @@ export const ContextMenuItemChangeBoard = memo(() => { }, [dispatch, imageDTO]); return ( - } onClickCapture={onClick}> + } onClickCapture={onClick} isDisabled={!canWriteImages}> {t('boards.changeBoard')} ); diff --git a/invokeai/frontend/web/src/features/gallery/components/ContextMenu/MenuItems/ContextMenuItemDeleteImage.tsx b/invokeai/frontend/web/src/features/gallery/components/ContextMenu/MenuItems/ContextMenuItemDeleteImage.tsx index e20221f3423..5dfa7116b17 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ContextMenu/MenuItems/ContextMenuItemDeleteImage.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ContextMenu/MenuItems/ContextMenuItemDeleteImage.tsx @@ -4,11 +4,15 @@ import { useImageDTOContext } from 'features/gallery/contexts/ImageDTOContext'; import { memo, useCallback } from 'react'; import { useTranslation } from 'react-i18next'; import { PiTrashSimpleBold } from 'react-icons/pi'; +import { useBoardAccess } from 'services/api/hooks/useBoardAccess'; +import { useSelectedBoard } from 'services/api/hooks/useSelectedBoard'; export const ContextMenuItemDeleteImage = memo(() => { const { t } = useTranslation(); const deleteImageModal = useDeleteImageModalApi(); const imageDTO = useImageDTOContext(); + const selectedBoard = useSelectedBoard(); + const { canWriteImages } = useBoardAccess(selectedBoard); const onClick = useCallback(async () => { try { @@ -18,6 +22,10 @@ export const ContextMenuItemDeleteImage = memo(() => { } }, [deleteImageModal, imageDTO]); + if (!canWriteImages) { + return null; + } + return ( } diff --git a/invokeai/frontend/web/src/features/gallery/components/ContextMenu/MultipleSelectionMenuItems.tsx b/invokeai/frontend/web/src/features/gallery/components/ContextMenu/MultipleSelectionMenuItems.tsx index d148332943c..ee3c8e4e985 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ContextMenu/MultipleSelectionMenuItems.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ContextMenu/MultipleSelectionMenuItems.tsx @@ -10,12 +10,16 @@ import { useStarImagesMutation, useUnstarImagesMutation, } from 'services/api/endpoints/images'; +import { useBoardAccess } from 'services/api/hooks/useBoardAccess'; +import { useSelectedBoard } from 'services/api/hooks/useSelectedBoard'; const MultipleSelectionMenuItems = () => { const { t } = useTranslation(); const dispatch = useAppDispatch(); const selection = useAppSelector((s) => s.gallery.selection); const deleteImageModal = useDeleteImageModalApi(); + const selectedBoard = useSelectedBoard(); + const { canWriteImages } = useBoardAccess(selectedBoard); const [starImages] = useStarImagesMutation(); const [unstarImages] = useUnstarImagesMutation(); @@ -53,11 +57,16 @@ const MultipleSelectionMenuItems = () => { } onClickCapture={handleBulkDownload}> {t('gallery.downloadSelection')} - } onClickCapture={handleChangeBoard}> + } onClickCapture={handleChangeBoard} isDisabled={!canWriteImages}> {t('boards.changeBoard')} - } onClickCapture={handleDeleteSelection}> + } + onClickCapture={handleDeleteSelection} + isDisabled={!canWriteImages} + > {t('gallery.deleteSelection')} diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageGrid/GalleryImage.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageGrid/GalleryImage.tsx index ccd58992ef6..af1d376887b 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageGrid/GalleryImage.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageGrid/GalleryImage.tsx @@ -108,6 +108,25 @@ export const GalleryImage = memo(({ imageDTO }: Props) => { if (!element) { return; } + + const monitorBinding = monitorForElements({ + // This is a "global" drag start event, meaning that it is called for all drag events. + onDragStart: ({ source }) => { + // When we start dragging multiple images, set the dragging state to true if the dragged image is part of the + // selection. This is called for all drag events. + if ( + multipleImageDndSource.typeGuard(source.data) && + source.data.payload.image_names.includes(imageDTO.image_name) + ) { + setIsDragging(true); + } + }, + onDrop: () => { + // Always set the dragging state to false when a drop event occurs. + setIsDragging(false); + }, + }); + return combine( firefoxDndFix(element), draggable({ @@ -153,23 +172,7 @@ export const GalleryImage = memo(({ imageDTO }: Props) => { } }, }), - monitorForElements({ - // This is a "global" drag start event, meaning that it is called for all drag events. - onDragStart: ({ source }) => { - // When we start dragging multiple images, set the dragging state to true if the dragged image is part of the - // selection. This is called for all drag events. - if ( - multipleImageDndSource.typeGuard(source.data) && - source.data.payload.image_names.includes(imageDTO.image_name) - ) { - setIsDragging(true); - } - }, - onDrop: () => { - // Always set the dragging state to false when a drop event occurs. - setIsDragging(false); - }, - }) + monitorBinding ); }, [imageDTO, store]); diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageGrid/GalleryItemDeleteIconButton.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageGrid/GalleryItemDeleteIconButton.tsx index 0a97bf819de..612e6361b14 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageGrid/GalleryItemDeleteIconButton.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageGrid/GalleryItemDeleteIconButton.tsx @@ -5,6 +5,8 @@ import type { MouseEvent } from 'react'; import { memo, useCallback } from 'react'; import { useTranslation } from 'react-i18next'; import { PiTrashSimpleFill } from 'react-icons/pi'; +import { useBoardAccess } from 'services/api/hooks/useBoardAccess'; +import { useSelectedBoard } from 'services/api/hooks/useSelectedBoard'; import type { ImageDTO } from 'services/api/types'; type Props = { @@ -15,6 +17,8 @@ export const GalleryItemDeleteIconButton = memo(({ imageDTO }: Props) => { const shift = useShiftModifier(); const { t } = useTranslation(); const deleteImageModal = useDeleteImageModalApi(); + const selectedBoard = useSelectedBoard(); + const { canWriteImages } = useBoardAccess(selectedBoard); const onClick = useCallback( (e: MouseEvent) => { @@ -24,7 +28,7 @@ export const GalleryItemDeleteIconButton = memo(({ imageDTO }: Props) => { [deleteImageModal, imageDTO] ); - if (!shift) { + if (!shift || !canWriteImages) { return null; } diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageViewer/NoContentForViewer.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/NoContentForViewer.tsx index b8a522c3a65..c301922df95 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageViewer/NoContentForViewer.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/NoContentForViewer.tsx @@ -1,7 +1,9 @@ import type { ButtonProps } from '@invoke-ai/ui-library'; import { Alert, AlertDescription, AlertIcon, Button, Divider, Flex, Link, Spinner, Text } from '@invoke-ai/ui-library'; +import { useAppSelector } from 'app/store/storeHooks'; import { IAINoContentFallback } from 'common/components/IAIImageFallback'; import { InvokeLogoIcon } from 'common/components/InvokeLogoIcon'; +import { selectCurrentUser } from 'features/auth/store/authSlice'; import { LOADING_SYMBOL, useHasImages } from 'features/gallery/hooks/useHasImages'; import { setInstallModelsTabByName } from 'features/modelManagerV2/store/installModelsStore'; import { navigationApi } from 'features/ui/layouts/navigation-api'; @@ -9,16 +11,26 @@ import type { PropsWithChildren } from 'react'; import { memo, useCallback, useMemo } from 'react'; import { Trans, useTranslation } from 'react-i18next'; import { PiArrowSquareOutBold, PiImageBold } from 'react-icons/pi'; +import { useGetSetupStatusQuery } from 'services/api/endpoints/auth'; import { useMainModels } from 'services/api/hooks/modelsByType'; export const NoContentForViewer = memo(() => { const hasImages = useHasImages(); const [mainModels, { data }] = useMainModels(); + const { data: setupStatus } = useGetSetupStatusQuery(); + const user = useAppSelector(selectCurrentUser); const { t } = useTranslation(); + const isMultiuser = setupStatus?.multiuser_enabled ?? false; + const isAdmin = !isMultiuser || (user?.is_admin ?? false); + const adminEmail = setupStatus?.admin_email ?? null; + + const modelsLoaded = data !== undefined; + const hasModels = mainModels.length > 0; + const showStarterBundles = useMemo(() => { - return data && mainModels.length === 0; - }, [mainModels.length, data]); + return modelsLoaded && !hasModels && isAdmin; + }, [modelsLoaded, hasModels, isAdmin]); if (hasImages === LOADING_SYMBOL) { // Blank bg w/ a spinner. The new user experience components below have an invoke logo, but it's not centered. @@ -36,10 +48,18 @@ export const NoContentForViewer = memo(() => { - - {showStarterBundles && } - - + {isAdmin ? ( + // Admin / single-user mode + <> + {modelsLoaded && hasModels ? : } + {showStarterBundles && } + + + + ) : ( + // Non-admin user in multiuser mode + <>{modelsLoaded && hasModels ? : } + )} ); @@ -99,6 +119,32 @@ const GetStartedLocal = () => { ); }; +const GetStartedWithModels = () => { + return ( + + + + ); +}; + +const GetStartedNonAdmin = ({ adminEmail }: { adminEmail: string | null }) => { + const AdminEmailLink = adminEmail ? ( + + {adminEmail} + + ) : ( + + your administrator + + ); + + return ( + + + + ); +}; + const StarterBundlesCallout = () => { const handleClickDownloadStarterModels = useCallback(() => { navigationApi.switchToTab('models'); diff --git a/invokeai/frontend/web/src/features/modelManagerV2/hooks/useStarterModelsToast.tsx b/invokeai/frontend/web/src/features/modelManagerV2/hooks/useStarterModelsToast.tsx index d1774f9ded0..9b76fbbde67 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/hooks/useStarterModelsToast.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/hooks/useStarterModelsToast.tsx @@ -1,10 +1,11 @@ import { Button, Text, useToast } from '@invoke-ai/ui-library'; import { useAppSelector } from 'app/store/storeHooks'; -import { selectIsAuthenticated } from 'features/auth/store/authSlice'; +import { selectCurrentUser, selectIsAuthenticated } from 'features/auth/store/authSlice'; import { setInstallModelsTabByName } from 'features/modelManagerV2/store/installModelsStore'; import { navigationApi } from 'features/ui/layouts/navigation-api'; import { useCallback, useEffect, useState } from 'react'; import { useTranslation } from 'react-i18next'; +import { useGetSetupStatusQuery } from 'services/api/endpoints/auth'; import { useMainModels } from 'services/api/hooks/modelsByType'; const TOAST_ID = 'starterModels'; @@ -15,6 +16,11 @@ export const useStarterModelsToast = () => { const [mainModels, { data }] = useMainModels(); const toast = useToast(); const isAuthenticated = useAppSelector(selectIsAuthenticated); + const { data: setupStatus } = useGetSetupStatusQuery(); + const user = useAppSelector(selectCurrentUser); + + const isMultiuser = setupStatus?.multiuser_enabled ?? false; + const isAdmin = !isMultiuser || (user?.is_admin ?? false); useEffect(() => { // Only show the toast if the user is authenticated @@ -33,17 +39,17 @@ export const useStarterModelsToast = () => { toast({ id: TOAST_ID, title: t('modelManager.noModelsInstalled'), - description: , + description: isAdmin ? : , status: 'info', isClosable: true, duration: null, onCloseComplete: () => setDidToast(true), }); } - }, [data, didToast, isAuthenticated, mainModels.length, t, toast]); + }, [data, didToast, isAuthenticated, isAdmin, mainModels.length, t, toast]); }; -const ToastDescription = () => { +const AdminToastDescription = () => { const { t } = useTranslation(); const toast = useToast(); @@ -62,3 +68,9 @@ const ToastDescription = () => { ); }; + +const NonAdminToastDescription = () => { + const { t } = useTranslation(); + + return {t('modelManager.noModelsInstalledAskAdmin')}; +}; diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManager.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManager.tsx index f6e1a18f6fd..60200c8801f 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManager.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManager.tsx @@ -37,7 +37,7 @@ export const ModelManager = memo(() => { {t('common.modelManager')} - + {canManageModels && } {!!selectedModelKey && canManageModels && (