diff --git a/changes/12367.fix.md b/changes/12367.fix.md new file mode 100644 index 00000000000..fc6184ef89b --- /dev/null +++ b/changes/12367.fix.md @@ -0,0 +1 @@ +Honor the `group_id` query parameter of the `GET /folders` REST API so the virtual folder listing can be scoped to a single project, instead of always returning every accessible folder (including model-store and, for admins, all cross-project folders) regardless of the requested scope. diff --git a/src/ai/backend/manager/repositories/vfolder/repository.py b/src/ai/backend/manager/repositories/vfolder/repository.py index 20bb6090ac7..1cc76dde1d1 100644 --- a/src/ai/backend/manager/repositories/vfolder/repository.py +++ b/src/ai/backend/manager/repositories/vfolder/repository.py @@ -3,6 +3,7 @@ from dataclasses import dataclass from datetime import UTC, datetime from typing import Any, cast +from uuid import UUID import sqlalchemy as sa from sqlalchemy import exc as sa_exc @@ -415,9 +416,13 @@ async def list_accessible_vfolders( domain_name: str, allowed_vfolder_types: list[str], extra_conditions: sa.sql.elements.ColumnElement[bool] | None = None, + project_id: UUID | None = None, ) -> VFolderListResult: """ List all VFolders accessible to a user. + + When ``project_id`` is given, project-owned vfolders are restricted to + that project; user-owned and invited vfolders are unaffected. Returns VFolderListResult with access information. """ async with self._db.begin_readonly_session() as session: @@ -429,6 +434,9 @@ async def list_accessible_vfolders( domain_name=domain_name, allowed_vfolder_types=allowed_vfolder_types, extra_vf_conds=extra_conditions, + extra_vf_group_conds=( + (VFolderRow.group == project_id) if project_id is not None else None + ), ) vfolder_access_infos = [] diff --git a/src/ai/backend/manager/services/vfolder/services/vfolder.py b/src/ai/backend/manager/services/vfolder/services/vfolder.py index 4a4d3aa0022..65ff9fb3d69 100644 --- a/src/ai/backend/manager/services/vfolder/services/vfolder.py +++ b/src/ai/backend/manager/services/vfolder/services/vfolder.py @@ -8,6 +8,7 @@ Any, cast, ) +from uuid import UUID import aiohttp import msgpack @@ -18,6 +19,7 @@ from ai.backend.common.bgtask.bgtask import BackgroundTaskManager from ai.backend.common.clients.valkey_client.valkey_stat.client import ValkeyStatClient from ai.backend.common.contexts.user import current_user +from ai.backend.common.data.permission.types import ScopeType from ai.backend.common.defs import VFOLDER_GROUP_PERMISSION_MODE from ai.backend.common.etcd import AsyncEtcd from ai.backend.common.exception import UnreachableError @@ -579,12 +581,22 @@ async def list(self, action: ListVFolderAction) -> ListVFolderActionResult: raise ObjectNotFound(object_name="User") user_role, user_domain_name = user_info + # When scoped to a project (i.e. the request carried a `group_id`), + # restrict project-owned vfolders to that project. Without this, the + # listing returns the full accessible union (every MODEL_STORE project, + # and every domain project for admins) regardless of the scope, which + # made the REST `group_id` query parameter a no-op. + project_id: UUID | None = None + if action.scope_type() == ScopeType.PROJECT: + project_id = UUID(action.scope_id()) + # Use repository to get accessible vfolders vfolder_list_result = await self._vfolder_repository.list_accessible_vfolders( user_id=action.user_uuid, user_role=user_role, domain_name=user_domain_name, allowed_vfolder_types=list(allowed_vfolder_types), + project_id=project_id, ) vfolders = [ diff --git a/tests/unit/manager/services/vfolder/test_vfolder_crud_service.py b/tests/unit/manager/services/vfolder/test_vfolder_crud_service.py index bd34ee7aa2e..96417166927 100644 --- a/tests/unit/manager/services/vfolder/test_vfolder_crud_service.py +++ b/tests/unit/manager/services/vfolder/test_vfolder_crud_service.py @@ -587,6 +587,54 @@ async def test_no_vfolders_returns_empty_list( assert isinstance(result, ListVFolderActionResult) assert result.vfolders == [] + async def test_project_scope_restricts_project_vfolders( + self, + vfolder_service: VFolderService, + mock_vfolder_repository: MagicMock, + user_uuid: uuid.UUID, + group_uuid: uuid.UUID, + ) -> None: + mock_vfolder_repository.get_user_info = AsyncMock(return_value=(UserRole.USER, "default")) + mock_vfolder_repository.list_accessible_vfolders = AsyncMock( + return_value=VFolderListResult(vfolders=[]) + ) + + action = ListVFolderAction( + user_uuid=user_uuid, + _scope_type=ScopeType.PROJECT, + _scope_id=str(group_uuid), + ) + + await vfolder_service.list(action) + + assert ( + mock_vfolder_repository.list_accessible_vfolders.call_args.kwargs["project_id"] + == group_uuid + ) + + async def test_user_scope_does_not_restrict_project_vfolders( + self, + vfolder_service: VFolderService, + mock_vfolder_repository: MagicMock, + user_uuid: uuid.UUID, + ) -> None: + mock_vfolder_repository.get_user_info = AsyncMock(return_value=(UserRole.USER, "default")) + mock_vfolder_repository.list_accessible_vfolders = AsyncMock( + return_value=VFolderListResult(vfolders=[]) + ) + + action = ListVFolderAction( + user_uuid=user_uuid, + _scope_type=ScopeType.USER, + _scope_id=str(user_uuid), + ) + + await vfolder_service.list(action) + + assert ( + mock_vfolder_repository.list_accessible_vfolders.call_args.kwargs["project_id"] is None + ) + async def test_returns_owned_and_shared_vfolders( self, vfolder_service: VFolderService,