diff --git a/src/dstack/_internal/server/services/files.py b/src/dstack/_internal/server/services/files.py index d77ad94c78..7cba858aa5 100644 --- a/src/dstack/_internal/server/services/files.py +++ b/src/dstack/_internal/server/services/files.py @@ -1,11 +1,12 @@ import uuid from typing import Optional +import sqlalchemy.exc from fastapi import UploadFile from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from dstack._internal.core.errors import ServerClientError +from dstack._internal.core.errors import ServerClientError, ServerError from dstack._internal.core.models.files import FileArchive from dstack._internal.server.models import FileArchiveModel, UserModel from dstack._internal.server.services.storage import get_default_storage @@ -72,6 +73,7 @@ async def upload_archive( if archive_model is not None: logger.debug("File archive (user_id=%s, hash=%s) already uploaded", user.id, archive_hash) return archive_model_to_archive(archive_model) + blob = await file.read() storage = get_default_storage() if storage is not None: @@ -81,9 +83,30 @@ async def upload_archive( blob_hash=archive_hash, blob=blob if storage is None else None, ) - session.add(archive_model) + + conflict = False + try: + async with session.begin_nested(): + session.add(archive_model) + except sqlalchemy.exc.IntegrityError as e: + # Concurrent API call just uploaded the same archive (TOC/TOU race condition), + # safe to ignore, but we need to refetch the archive from the DB to get its id + conflict = True + logger.debug("Conflict, rolling back: %s", e) await session.commit() - logger.debug("File archive (user_id=%s, hash=%s) has been uploaded", user.id, archive_hash) + + if conflict: + archive_model = await get_archive_model_by_hash( + session=session, + user=user, + hash=archive_hash, + ) + if archive_model is None: + raise ServerError("Failed to upload archive, unexpected conflict condition") + logger.debug("File archive (user_id=%s, hash=%s) already uploaded", user.id, archive_hash) + else: + logger.debug("File archive (user_id=%s, hash=%s) has been uploaded", user.id, archive_hash) + return archive_model_to_archive(archive_model) diff --git a/src/dstack/_internal/server/services/repos.py b/src/dstack/_internal/server/services/repos.py index 1bc2acdfba..fd5bf77f38 100644 --- a/src/dstack/_internal/server/services/repos.py +++ b/src/dstack/_internal/server/services/repos.py @@ -304,7 +304,13 @@ async def upload_code( blob=None, ) await run_async(storage.upload_code, project.name, repo.name, code.blob_hash, blob) - session.add(code) + try: + async with session.begin_nested(): + session.add(code) + except sqlalchemy.exc.IntegrityError as e: + # Concurrent API call just uploaded the same code blob (TOC/TOU race condition), + # safe to ignore + logger.debug("Conflict, rolling back: %s", e) await session.commit() diff --git a/src/dstack/_internal/server/services/storage/base.py b/src/dstack/_internal/server/services/storage/base.py index de11599693..bd203b31b4 100644 --- a/src/dstack/_internal/server/services/storage/base.py +++ b/src/dstack/_internal/server/services/storage/base.py @@ -6,7 +6,7 @@ class BaseStorage(ABC): @abstractmethod def upload_code( self, - project_id: str, + project_name: str, repo_id: str, code_hash: str, blob: bytes, @@ -16,7 +16,7 @@ def upload_code( @abstractmethod def get_code( self, - project_id: str, + project_name: str, repo_id: str, code_hash: str, ) -> Optional[bytes]: @@ -40,8 +40,8 @@ def get_archive( pass @staticmethod - def _get_code_key(project_id: str, repo_id: str, code_hash: str) -> str: - return f"data/projects/{project_id}/codes/{repo_id}/{code_hash}" + def _get_code_key(project_name: str, repo_id: str, code_hash: str) -> str: + return f"data/projects/{project_name}/codes/{repo_id}/{code_hash}" @staticmethod def _get_archive_key(user_id: str, archive_hash: str) -> str: diff --git a/src/dstack/_internal/server/services/storage/gcs.py b/src/dstack/_internal/server/services/storage/gcs.py index a0f9ac568f..d30c1e849a 100644 --- a/src/dstack/_internal/server/services/storage/gcs.py +++ b/src/dstack/_internal/server/services/storage/gcs.py @@ -20,21 +20,21 @@ def __init__( def upload_code( self, - project_id: str, + project_name: str, repo_id: str, code_hash: str, blob: bytes, ): - key = self._get_code_key(project_id, repo_id, code_hash) + key = self._get_code_key(project_name, repo_id, code_hash) self._upload(key, blob) def get_code( self, - project_id: str, + project_name: str, repo_id: str, code_hash: str, ) -> Optional[bytes]: - key = self._get_code_key(project_id, repo_id, code_hash) + key = self._get_code_key(project_name, repo_id, code_hash) return self._get(key) def upload_archive( diff --git a/src/dstack/_internal/server/services/storage/s3.py b/src/dstack/_internal/server/services/storage/s3.py index df4b652d1d..2921e69f15 100644 --- a/src/dstack/_internal/server/services/storage/s3.py +++ b/src/dstack/_internal/server/services/storage/s3.py @@ -22,21 +22,21 @@ def __init__( def upload_code( self, - project_id: str, + project_name: str, repo_id: str, code_hash: str, blob: bytes, ): - key = self._get_code_key(project_id, repo_id, code_hash) + key = self._get_code_key(project_name, repo_id, code_hash) self._upload(key, blob) def get_code( self, - project_id: str, + project_name: str, repo_id: str, code_hash: str, ) -> Optional[bytes]: - key = self._get_code_key(project_id, repo_id, code_hash) + key = self._get_code_key(project_name, repo_id, code_hash) return self._get(key) def upload_archive( diff --git a/src/dstack/_internal/server/testing/common.py b/src/dstack/_internal/server/testing/common.py index 1d49de206c..4ee80e440b 100644 --- a/src/dstack/_internal/server/testing/common.py +++ b/src/dstack/_internal/server/testing/common.py @@ -91,6 +91,7 @@ ) from dstack._internal.server.models import ( BackendModel, + CodeModel, ComputeGroupModel, DecryptedString, EventModel, @@ -267,6 +268,22 @@ async def create_repo( return repo +async def create_code( + session: AsyncSession, + repo: RepoModel, + blob_hash: str = "blob_hash", + blob: Optional[bytes] = b"blob_content", +) -> CodeModel: + code = CodeModel( + repo_id=repo.id, + blob_hash=blob_hash, + blob=blob, + ) + session.add(code) + await session.commit() + return code + + async def create_repo_creds( session: AsyncSession, repo_id: UUID, @@ -293,7 +310,7 @@ async def create_file_archive( session: AsyncSession, user_id: UUID, blob_hash: str = "blob_hash", - blob: bytes = b"blob_content", + blob: Optional[bytes] = b"blob_content", ) -> FileArchiveModel: archive = FileArchiveModel( user_id=user_id, diff --git a/src/tests/_internal/server/routers/test_files.py b/src/tests/_internal/server/routers/test_files.py index 28f851e1c8..c83938e71b 100644 --- a/src/tests/_internal/server/routers/test_files.py +++ b/src/tests/_internal/server/routers/test_files.py @@ -1,4 +1,4 @@ -from unittest.mock import Mock +from unittest.mock import AsyncMock, Mock import pytest from httpx import AsyncClient @@ -146,3 +146,41 @@ async def test_uploads_archive_to_storage( default_storage_mock.upload_archive.assert_called_once_with( str(user.id), self.file_hash, self.file_content ) + + async def test_handles_race_condition( + self, + monkeypatch: pytest.MonkeyPatch, + session: AsyncSession, + client: AsyncClient, + default_storage_mock: Mock, + ): + user = await create_user(session=session, global_role=GlobalRole.USER) + existing_archive = await create_file_archive( + session=session, user_id=user.id, blob_hash=self.file_hash, blob=None + ) + monkeypatch.setattr( + "dstack._internal.server.services.files.get_archive_model_by_hash", + # first call checks if already uploaded (not yet) + # second call refetches after unique constraint violation + AsyncMock(side_effect=[None, existing_archive]), + ) + response = await client.post( + "/api/files/upload_archive", + headers=get_auth_headers(user.token), + files={"file": self.file}, + ) + assert response.status_code == 200, response.json() + assert response.json() == { + "id": str(existing_archive.id), + "hash": self.file_hash, + } + res = await session.execute( + select(FileArchiveModel).where(FileArchiveModel.user_id == user.id) + ) + archive = res.scalar_one() + assert archive.id == existing_archive.id + assert archive.blob_hash == self.file_hash + assert archive.blob is None + default_storage_mock.upload_archive.assert_called_once_with( + str(user.id), self.file_hash, self.file_content + ) diff --git a/src/tests/_internal/server/routers/test_repos.py b/src/tests/_internal/server/routers/test_repos.py index f986319856..d85cd6635b 100644 --- a/src/tests/_internal/server/routers/test_repos.py +++ b/src/tests/_internal/server/routers/test_repos.py @@ -1,4 +1,5 @@ import json +from unittest.mock import AsyncMock, Mock import pytest from httpx import AsyncClient @@ -8,7 +9,9 @@ from dstack._internal.core.models.users import GlobalRole, ProjectRole from dstack._internal.server.models import CodeModel, RepoCredsModel, RepoModel from dstack._internal.server.services.projects import add_project_member +from dstack._internal.server.services.storage import BaseStorage from dstack._internal.server.testing.common import ( + create_code, create_project, create_repo, create_repo_creds, @@ -321,11 +324,26 @@ async def test_deletes_repos(self, test_db, session: AsyncSession, client: Async assert repo is None +@pytest.mark.asyncio +@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) +@pytest.mark.usefixtures("test_db") class TestUploadCode: - @pytest.mark.asyncio - @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + @pytest.fixture + def default_storage_mock(self, monkeypatch: pytest.MonkeyPatch) -> Mock: + storage_mock = Mock(spec_set=BaseStorage) + monkeypatch.setattr( + "dstack._internal.server.services.repos.get_default_storage", lambda: storage_mock + ) + return storage_mock + + @pytest.fixture + def no_default_storage(self, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr( + "dstack._internal.server.services.repos.get_default_storage", lambda: None + ) + async def test_returns_403_if_not_project_member( - self, test_db, session: AsyncSession, client: AsyncClient + self, session: AsyncSession, client: AsyncClient ): user = await create_user(session=session, global_role=GlobalRole.USER) project = await create_project(session=session, owner=user) @@ -336,9 +354,8 @@ async def test_returns_403_if_not_project_member( ) assert response.status_code == 403 - @pytest.mark.asyncio - @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) - async def test_uploads_code(self, test_db, session: AsyncSession, client: AsyncClient): + @pytest.mark.usefixtures("no_default_storage") + async def test_uploads_code_to_db(self, session: AsyncSession, client: AsyncClient): user = await create_user(session=session, global_role=GlobalRole.USER) project = await create_project(session=session, owner=user) await add_project_member( @@ -354,14 +371,38 @@ async def test_uploads_code(self, test_db, session: AsyncSession, client: AsyncC ) assert response.status_code == 200, response.json() res = await session.execute(select(CodeModel)) - code = res.scalar() + code = res.scalar_one() assert code.blob_hash == file[0] assert code.blob == file[1] - @pytest.mark.asyncio - @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_uploads_code_to_storage( + self, session: AsyncSession, client: AsyncClient, default_storage_mock: Mock + ): + user = await create_user(session=session, global_role=GlobalRole.USER) + project = await create_project(session=session, owner=user) + await add_project_member( + session=session, project=project, user=user, project_role=ProjectRole.USER + ) + repo = await create_repo(session=session, project_id=project.id) + file = ("blob_hash", b"blob_content") + response = await client.post( + f"/api/project/{project.name}/repos/upload_code", + headers=get_auth_headers(user.token), + params={"repo_id": repo.name}, + files={"file": file}, + ) + assert response.status_code == 200, response.json() + res = await session.execute(select(CodeModel)) + code = res.scalar_one() + assert code.blob_hash == file[0] + assert code.blob is None + default_storage_mock.upload_code.assert_called_once_with( + project.name, repo.name, file[0], file[1] + ) + + @pytest.mark.usefixtures("no_default_storage") async def test_uploads_same_code_for_different_repos( - self, test_db, session: AsyncSession, client: AsyncClient + self, session: AsyncSession, client: AsyncClient ): user = await create_user(session=session, global_role=GlobalRole.USER) project = await create_project(session=session, owner=user) @@ -388,3 +429,36 @@ async def test_uploads_same_code_for_different_repos( res = await session.execute(select(CodeModel)) codes = res.scalars().all() assert len(codes) == 2 + + async def test_handles_race_condition( + self, + monkeypatch: pytest.MonkeyPatch, + session: AsyncSession, + client: AsyncClient, + default_storage_mock: Mock, + ): + user = await create_user(session=session, global_role=GlobalRole.USER) + project = await create_project(session=session, owner=user) + await add_project_member( + session=session, project=project, user=user, project_role=ProjectRole.USER + ) + repo = await create_repo(session=session, project_id=project.id) + file = ("blob_hash", b"blob_content") + code = await create_code(session=session, repo=repo, blob_hash=file[0], blob=file[1]) + monkeypatch.setattr( + "dstack._internal.server.services.repos.get_code_model", AsyncMock(return_value=None) + ) + response = await client.post( + f"/api/project/{project.name}/repos/upload_code", + headers=get_auth_headers(user.token), + params={"repo_id": repo.name}, + files={"file": file}, + ) + assert response.status_code == 200, response.json() + res = await session.execute(select(CodeModel)) + code = res.scalar_one() + assert code.blob_hash == file[0] + assert code.blob == file[1] + default_storage_mock.upload_code.assert_called_once_with( + project.name, repo.name, file[0], file[1] + )