Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 26 additions & 3 deletions src/dstack/_internal/server/services/files.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import uuid
from typing import Optional

import sqlalchemy.exc
from fastapi import UploadFile
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from dstack._internal.core.errors import ServerClientError
from dstack._internal.core.errors import ServerClientError, ServerError
from dstack._internal.core.models.files import FileArchive
from dstack._internal.server.models import FileArchiveModel, UserModel
from dstack._internal.server.services.storage import get_default_storage
Expand Down Expand Up @@ -72,6 +73,7 @@ async def upload_archive(
if archive_model is not None:
logger.debug("File archive (user_id=%s, hash=%s) already uploaded", user.id, archive_hash)
return archive_model_to_archive(archive_model)

blob = await file.read()
storage = get_default_storage()
if storage is not None:
Expand All @@ -81,9 +83,30 @@ async def upload_archive(
blob_hash=archive_hash,
blob=blob if storage is None else None,
)
session.add(archive_model)

conflict = False
try:
async with session.begin_nested():
session.add(archive_model)
except sqlalchemy.exc.IntegrityError as e:
# Concurrent API call just uploaded the same archive (TOC/TOU race condition),
# safe to ignore, but we need to refetch the archive from the DB to get its id
conflict = True
logger.debug("Conflict, rolling back: %s", e)
await session.commit()
logger.debug("File archive (user_id=%s, hash=%s) has been uploaded", user.id, archive_hash)

if conflict:
archive_model = await get_archive_model_by_hash(
session=session,
user=user,
hash=archive_hash,
)
if archive_model is None:
raise ServerError("Failed to upload archive, unexpected conflict condition")
logger.debug("File archive (user_id=%s, hash=%s) already uploaded", user.id, archive_hash)
else:
logger.debug("File archive (user_id=%s, hash=%s) has been uploaded", user.id, archive_hash)

return archive_model_to_archive(archive_model)


Expand Down
8 changes: 7 additions & 1 deletion src/dstack/_internal/server/services/repos.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,13 @@ async def upload_code(
blob=None,
)
await run_async(storage.upload_code, project.name, repo.name, code.blob_hash, blob)
session.add(code)
try:
async with session.begin_nested():
session.add(code)
except sqlalchemy.exc.IntegrityError as e:
# Concurrent API call just uploaded the same code blob (TOC/TOU race condition),
# safe to ignore
logger.debug("Conflict, rolling back: %s", e)
await session.commit()


Expand Down
8 changes: 4 additions & 4 deletions src/dstack/_internal/server/services/storage/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ class BaseStorage(ABC):
@abstractmethod
def upload_code(
self,
project_id: str,
project_name: str,
repo_id: str,
code_hash: str,
blob: bytes,
Expand All @@ -16,7 +16,7 @@ def upload_code(
@abstractmethod
def get_code(
self,
project_id: str,
project_name: str,
repo_id: str,
code_hash: str,
) -> Optional[bytes]:
Expand All @@ -40,8 +40,8 @@ def get_archive(
pass

@staticmethod
def _get_code_key(project_name: str, repo_id: str, code_hash: str) -> str:
    """Return the storage object key for a repo code blob.

    The key is namespaced by project name and repo id so that identical
    content hashes in different repos/projects do not collide.
    (The scrape interleaved the pre-change ``project_id`` and post-change
    ``project_name`` versions; resolved to the post-change version, which
    matches the updated callers in gcs.py and s3.py.)
    """
    return f"data/projects/{project_name}/codes/{repo_id}/{code_hash}"

@staticmethod
def _get_archive_key(user_id: str, archive_hash: str) -> str:
Expand Down
8 changes: 4 additions & 4 deletions src/dstack/_internal/server/services/storage/gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,21 +20,21 @@ def __init__(

def upload_code(
    self,
    project_name: str,
    repo_id: str,
    code_hash: str,
    blob: bytes,
):
    """Upload a repo code blob to the GCS bucket.

    The object key is derived from project name, repo id, and content hash.
    (Resolved the diff-interleaved old/new lines to the post-change
    ``project_name`` version.)
    """
    key = self._get_code_key(project_name, repo_id, code_hash)
    self._upload(key, blob)

def get_code(
    self,
    project_name: str,
    repo_id: str,
    code_hash: str,
) -> Optional[bytes]:
    """Fetch a repo code blob from the GCS bucket.

    Returns the blob bytes, or None when no object exists for the key.
    (Resolved the diff-interleaved old/new lines to the post-change
    ``project_name`` version.)
    """
    key = self._get_code_key(project_name, repo_id, code_hash)
    return self._get(key)

def upload_archive(
Expand Down
8 changes: 4 additions & 4 deletions src/dstack/_internal/server/services/storage/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,21 @@ def __init__(

def upload_code(
    self,
    project_name: str,
    repo_id: str,
    code_hash: str,
    blob: bytes,
):
    """Upload a repo code blob to the S3 bucket.

    The object key is derived from project name, repo id, and content hash.
    (Resolved the diff-interleaved old/new lines to the post-change
    ``project_name`` version.)
    """
    key = self._get_code_key(project_name, repo_id, code_hash)
    self._upload(key, blob)

def get_code(
    self,
    project_name: str,
    repo_id: str,
    code_hash: str,
) -> Optional[bytes]:
    """Fetch a repo code blob from the S3 bucket.

    Returns the blob bytes, or None when no object exists for the key.
    (Resolved the diff-interleaved old/new lines to the post-change
    ``project_name`` version.)
    """
    key = self._get_code_key(project_name, repo_id, code_hash)
    return self._get(key)

def upload_archive(
Expand Down
19 changes: 18 additions & 1 deletion src/dstack/_internal/server/testing/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@
)
from dstack._internal.server.models import (
BackendModel,
CodeModel,
ComputeGroupModel,
DecryptedString,
EventModel,
Expand Down Expand Up @@ -267,6 +268,22 @@ async def create_repo(
return repo


async def create_code(
    session: AsyncSession,
    repo: RepoModel,
    blob_hash: str = "blob_hash",
    blob: Optional[bytes] = b"blob_content",
) -> CodeModel:
    """Create and commit a CodeModel row attached to *repo* (test helper).

    Pass ``blob=None`` to simulate a code record whose content lives in
    external storage rather than in the database.
    """
    code_model = CodeModel(repo_id=repo.id, blob_hash=blob_hash, blob=blob)
    session.add(code_model)
    await session.commit()
    return code_model


async def create_repo_creds(
session: AsyncSession,
repo_id: UUID,
Expand All @@ -293,7 +310,7 @@ async def create_file_archive(
session: AsyncSession,
user_id: UUID,
blob_hash: str = "blob_hash",
blob: bytes = b"blob_content",
blob: Optional[bytes] = b"blob_content",
) -> FileArchiveModel:
archive = FileArchiveModel(
user_id=user_id,
Expand Down
40 changes: 39 additions & 1 deletion src/tests/_internal/server/routers/test_files.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from unittest.mock import Mock
from unittest.mock import AsyncMock, Mock

import pytest
from httpx import AsyncClient
Expand Down Expand Up @@ -146,3 +146,41 @@ async def test_uploads_archive_to_storage(
default_storage_mock.upload_archive.assert_called_once_with(
str(user.id), self.file_hash, self.file_content
)

async def test_handles_race_condition(
    self,
    monkeypatch: pytest.MonkeyPatch,
    session: AsyncSession,
    client: AsyncClient,
    default_storage_mock: Mock,
):
    """Upload succeeds when a concurrent request inserted the same archive first."""
    user = await create_user(session=session, global_role=GlobalRole.USER)
    # Pre-create the row that the "concurrent" request would have inserted.
    existing_archive = await create_file_archive(
        session=session, user_id=user.id, blob_hash=self.file_hash, blob=None
    )
    monkeypatch.setattr(
        "dstack._internal.server.services.files.get_archive_model_by_hash",
        # first call checks if already uploaded (not yet)
        # second call refetches after unique constraint violation
        AsyncMock(side_effect=[None, existing_archive]),
    )
    response = await client.post(
        "/api/files/upload_archive",
        headers=get_auth_headers(user.token),
        files={"file": self.file},
    )
    assert response.status_code == 200, response.json()
    # Response must reference the pre-existing archive, not a new one.
    assert response.json() == {
        "id": str(existing_archive.id),
        "hash": self.file_hash,
    }
    # Exactly one row should exist for this user — no duplicate was inserted.
    res = await session.execute(
        select(FileArchiveModel).where(FileArchiveModel.user_id == user.id)
    )
    archive = res.scalar_one()
    assert archive.id == existing_archive.id
    assert archive.blob_hash == self.file_hash
    assert archive.blob is None
    # The blob is still uploaded to external storage despite the DB conflict.
    default_storage_mock.upload_archive.assert_called_once_with(
        str(user.id), self.file_hash, self.file_content
    )
94 changes: 84 additions & 10 deletions src/tests/_internal/server/routers/test_repos.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
from unittest.mock import AsyncMock, Mock

import pytest
from httpx import AsyncClient
Expand All @@ -8,7 +9,9 @@
from dstack._internal.core.models.users import GlobalRole, ProjectRole
from dstack._internal.server.models import CodeModel, RepoCredsModel, RepoModel
from dstack._internal.server.services.projects import add_project_member
from dstack._internal.server.services.storage import BaseStorage
from dstack._internal.server.testing.common import (
create_code,
create_project,
create_repo,
create_repo_creds,
Expand Down Expand Up @@ -321,11 +324,26 @@ async def test_deletes_repos(self, test_db, session: AsyncSession, client: Async
assert repo is None


@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
@pytest.mark.usefixtures("test_db")
class TestUploadCode:
@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
@pytest.fixture
def default_storage_mock(self, monkeypatch: pytest.MonkeyPatch) -> Mock:
    """Patch the repos service to use a mocked default storage.

    Returns the mock so tests can assert on upload_code/get_code calls.
    """
    storage_mock = Mock(spec_set=BaseStorage)
    monkeypatch.setattr(
        "dstack._internal.server.services.repos.get_default_storage", lambda: storage_mock
    )
    return storage_mock

@pytest.fixture
def no_default_storage(self, monkeypatch: pytest.MonkeyPatch):
    """Patch the repos service so no default storage is configured."""
    monkeypatch.setattr(
        "dstack._internal.server.services.repos.get_default_storage", lambda: None
    )

async def test_returns_403_if_not_project_member(
self, test_db, session: AsyncSession, client: AsyncClient
self, session: AsyncSession, client: AsyncClient
):
user = await create_user(session=session, global_role=GlobalRole.USER)
project = await create_project(session=session, owner=user)
Expand All @@ -336,9 +354,8 @@ async def test_returns_403_if_not_project_member(
)
assert response.status_code == 403

@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
async def test_uploads_code(self, test_db, session: AsyncSession, client: AsyncClient):
@pytest.mark.usefixtures("no_default_storage")
async def test_uploads_code_to_db(self, session: AsyncSession, client: AsyncClient):
user = await create_user(session=session, global_role=GlobalRole.USER)
project = await create_project(session=session, owner=user)
await add_project_member(
Expand All @@ -354,14 +371,38 @@ async def test_uploads_code(self, test_db, session: AsyncSession, client: AsyncC
)
assert response.status_code == 200, response.json()
res = await session.execute(select(CodeModel))
code = res.scalar()
code = res.scalar_one()
assert code.blob_hash == file[0]
assert code.blob == file[1]

@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
async def test_uploads_code_to_storage(
    self, session: AsyncSession, client: AsyncClient, default_storage_mock: Mock
):
    """With a default storage configured, the blob goes to storage, not the DB."""
    user = await create_user(session=session, global_role=GlobalRole.USER)
    project = await create_project(session=session, owner=user)
    await add_project_member(
        session=session, project=project, user=user, project_role=ProjectRole.USER
    )
    repo = await create_repo(session=session, project_id=project.id)
    file = ("blob_hash", b"blob_content")
    response = await client.post(
        f"/api/project/{project.name}/repos/upload_code",
        headers=get_auth_headers(user.token),
        params={"repo_id": repo.name},
        files={"file": file},
    )
    assert response.status_code == 200, response.json()
    res = await session.execute(select(CodeModel))
    code = res.scalar_one()
    assert code.blob_hash == file[0]
    # The DB row stores no blob content — it was sent to the mocked storage.
    assert code.blob is None
    default_storage_mock.upload_code.assert_called_once_with(
        project.name, repo.name, file[0], file[1]
    )

@pytest.mark.usefixtures("no_default_storage")
async def test_uploads_same_code_for_different_repos(
self, test_db, session: AsyncSession, client: AsyncClient
self, session: AsyncSession, client: AsyncClient
):
user = await create_user(session=session, global_role=GlobalRole.USER)
project = await create_project(session=session, owner=user)
Expand All @@ -388,3 +429,36 @@ async def test_uploads_same_code_for_different_repos(
res = await session.execute(select(CodeModel))
codes = res.scalars().all()
assert len(codes) == 2

async def test_handles_race_condition(
    self,
    monkeypatch: pytest.MonkeyPatch,
    session: AsyncSession,
    client: AsyncClient,
    default_storage_mock: Mock,
):
    """Upload succeeds when a concurrent request inserted the same code blob first."""
    user = await create_user(session=session, global_role=GlobalRole.USER)
    project = await create_project(session=session, owner=user)
    await add_project_member(
        session=session, project=project, user=user, project_role=ProjectRole.USER
    )
    repo = await create_repo(session=session, project_id=project.id)
    file = ("blob_hash", b"blob_content")
    # Pre-create the row that the "concurrent" request would have inserted.
    code = await create_code(session=session, repo=repo, blob_hash=file[0], blob=file[1])
    # Make the handler believe no code exists yet, forcing the insert-conflict path.
    monkeypatch.setattr(
        "dstack._internal.server.services.repos.get_code_model", AsyncMock(return_value=None)
    )
    response = await client.post(
        f"/api/project/{project.name}/repos/upload_code",
        headers=get_auth_headers(user.token),
        params={"repo_id": repo.name},
        files={"file": file},
    )
    assert response.status_code == 200, response.json()
    # Still exactly one row — the conflict must not create a duplicate.
    res = await session.execute(select(CodeModel))
    code = res.scalar_one()
    assert code.blob_hash == file[0]
    assert code.blob == file[1]
    default_storage_mock.upload_code.assert_called_once_with(
        project.name, repo.name, file[0], file[1]
    )
Loading