Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/12367.fix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Honor the `group_id` query parameter of the `GET /folders` REST API so the virtual folder listing can be scoped to a single project, instead of always returning every accessible folder (including model-store and, for admins, all cross-project folders) regardless of the requested scope.
8 changes: 8 additions & 0 deletions src/ai/backend/manager/repositories/vfolder/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from dataclasses import dataclass
from datetime import UTC, datetime
from typing import Any, cast
from uuid import UUID

import sqlalchemy as sa
from sqlalchemy import exc as sa_exc
Expand Down Expand Up @@ -415,9 +416,13 @@ async def list_accessible_vfolders(
domain_name: str,
allowed_vfolder_types: list[str],
extra_conditions: sa.sql.elements.ColumnElement[bool] | None = None,
project_id: UUID | None = None,
) -> VFolderListResult:
"""
List all VFolders accessible to a user.

When ``project_id`` is given, project-owned vfolders are restricted to
that project; user-owned and invited vfolders are unaffected.
Returns VFolderListResult with access information.
"""
async with self._db.begin_readonly_session() as session:
Expand All @@ -429,6 +434,9 @@ async def list_accessible_vfolders(
domain_name=domain_name,
allowed_vfolder_types=allowed_vfolder_types,
extra_vf_conds=extra_conditions,
extra_vf_group_conds=(
(VFolderRow.group == project_id) if project_id is not None else None
),
)

vfolder_access_infos = []
Expand Down
12 changes: 12 additions & 0 deletions src/ai/backend/manager/services/vfolder/services/vfolder.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
Any,
cast,
)
from uuid import UUID

import aiohttp
import msgpack
Expand All @@ -18,6 +19,7 @@
from ai.backend.common.bgtask.bgtask import BackgroundTaskManager
from ai.backend.common.clients.valkey_client.valkey_stat.client import ValkeyStatClient
from ai.backend.common.contexts.user import current_user
from ai.backend.common.data.permission.types import ScopeType
from ai.backend.common.defs import VFOLDER_GROUP_PERMISSION_MODE
from ai.backend.common.etcd import AsyncEtcd
from ai.backend.common.exception import UnreachableError
Expand Down Expand Up @@ -579,12 +581,22 @@ async def list(self, action: ListVFolderAction) -> ListVFolderActionResult:
raise ObjectNotFound(object_name="User")
user_role, user_domain_name = user_info

# When scoped to a project (i.e. the request carried a `group_id`),
# restrict project-owned vfolders to that project. Without this, the
# listing returns the full accessible union (every MODEL_STORE project,
# and every domain project for admins) regardless of the scope, which
# made the REST `group_id` query parameter a no-op.
project_id: UUID | None = None
if action.scope_type() == ScopeType.PROJECT:
project_id = UUID(action.scope_id())

# Use repository to get accessible vfolders
vfolder_list_result = await self._vfolder_repository.list_accessible_vfolders(
user_id=action.user_uuid,
user_role=user_role,
domain_name=user_domain_name,
allowed_vfolder_types=list(allowed_vfolder_types),
project_id=project_id,
)

vfolders = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,54 @@ async def test_no_vfolders_returns_empty_list(
assert isinstance(result, ListVFolderActionResult)
assert result.vfolders == []

async def test_project_scope_restricts_project_vfolders(
self,
vfolder_service: VFolderService,
mock_vfolder_repository: MagicMock,
user_uuid: uuid.UUID,
group_uuid: uuid.UUID,
) -> None:
mock_vfolder_repository.get_user_info = AsyncMock(return_value=(UserRole.USER, "default"))
mock_vfolder_repository.list_accessible_vfolders = AsyncMock(
return_value=VFolderListResult(vfolders=[])
)

action = ListVFolderAction(
user_uuid=user_uuid,
_scope_type=ScopeType.PROJECT,
_scope_id=str(group_uuid),
)

await vfolder_service.list(action)

assert (
mock_vfolder_repository.list_accessible_vfolders.call_args.kwargs["project_id"]
== group_uuid
)

async def test_user_scope_does_not_restrict_project_vfolders(
self,
vfolder_service: VFolderService,
mock_vfolder_repository: MagicMock,
user_uuid: uuid.UUID,
) -> None:
mock_vfolder_repository.get_user_info = AsyncMock(return_value=(UserRole.USER, "default"))
mock_vfolder_repository.list_accessible_vfolders = AsyncMock(
return_value=VFolderListResult(vfolders=[])
)

action = ListVFolderAction(
user_uuid=user_uuid,
_scope_type=ScopeType.USER,
_scope_id=str(user_uuid),
)

await vfolder_service.list(action)

assert (
mock_vfolder_repository.list_accessible_vfolders.call_args.kwargs["project_id"] is None
)

async def test_returns_owned_and_shared_vfolders(
self,
vfolder_service: VFolderService,
Expand Down
Loading