Skip to content

Commit a325f56

Browse files
authored
Handle concurrent repo blob/file archive uploads (#3737)
Fixes: #3731 Fixes: #3732
1 parent 45c648f commit a325f56

File tree

8 files changed

+186
-28
lines changed

8 files changed

+186
-28
lines changed

src/dstack/_internal/server/services/files.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import uuid
22
from typing import Optional
33

4+
import sqlalchemy.exc
45
from fastapi import UploadFile
56
from sqlalchemy import select
67
from sqlalchemy.ext.asyncio import AsyncSession
78

8-
from dstack._internal.core.errors import ServerClientError
9+
from dstack._internal.core.errors import ServerClientError, ServerError
910
from dstack._internal.core.models.files import FileArchive
1011
from dstack._internal.server.models import FileArchiveModel, UserModel
1112
from dstack._internal.server.services.storage import get_default_storage
@@ -72,6 +73,7 @@ async def upload_archive(
7273
if archive_model is not None:
7374
logger.debug("File archive (user_id=%s, hash=%s) already uploaded", user.id, archive_hash)
7475
return archive_model_to_archive(archive_model)
76+
7577
blob = await file.read()
7678
storage = get_default_storage()
7779
if storage is not None:
@@ -81,9 +83,30 @@ async def upload_archive(
8183
blob_hash=archive_hash,
8284
blob=blob if storage is None else None,
8385
)
84-
session.add(archive_model)
86+
87+
conflict = False
88+
try:
89+
async with session.begin_nested():
90+
session.add(archive_model)
91+
except sqlalchemy.exc.IntegrityError as e:
92+
# Concurrent API call just uploaded the same archive (TOC/TOU race condition),
93+
# safe to ignore, but we need to refetch the archive from the DB to get its id
94+
conflict = True
95+
logger.debug("Conflict, rolling back: %s", e)
8596
await session.commit()
86-
logger.debug("File archive (user_id=%s, hash=%s) has been uploaded", user.id, archive_hash)
97+
98+
if conflict:
99+
archive_model = await get_archive_model_by_hash(
100+
session=session,
101+
user=user,
102+
hash=archive_hash,
103+
)
104+
if archive_model is None:
105+
raise ServerError("Failed to upload archive, unexpected conflict condition")
106+
logger.debug("File archive (user_id=%s, hash=%s) already uploaded", user.id, archive_hash)
107+
else:
108+
logger.debug("File archive (user_id=%s, hash=%s) has been uploaded", user.id, archive_hash)
109+
87110
return archive_model_to_archive(archive_model)
88111

89112

src/dstack/_internal/server/services/repos.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,13 @@ async def upload_code(
304304
blob=None,
305305
)
306306
await run_async(storage.upload_code, project.name, repo.name, code.blob_hash, blob)
307-
session.add(code)
307+
try:
308+
async with session.begin_nested():
309+
session.add(code)
310+
except sqlalchemy.exc.IntegrityError as e:
311+
# Concurrent API call just uploaded the same code blob (TOC/TOU race condition),
312+
# safe to ignore
313+
logger.debug("Conflict, rolling back: %s", e)
308314
await session.commit()
309315

310316

src/dstack/_internal/server/services/storage/base.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ class BaseStorage(ABC):
66
@abstractmethod
77
def upload_code(
88
self,
9-
project_id: str,
9+
project_name: str,
1010
repo_id: str,
1111
code_hash: str,
1212
blob: bytes,
@@ -16,7 +16,7 @@ def upload_code(
1616
@abstractmethod
1717
def get_code(
1818
self,
19-
project_id: str,
19+
project_name: str,
2020
repo_id: str,
2121
code_hash: str,
2222
) -> Optional[bytes]:
@@ -40,8 +40,8 @@ def get_archive(
4040
pass
4141

4242
@staticmethod
43-
def _get_code_key(project_id: str, repo_id: str, code_hash: str) -> str:
44-
return f"data/projects/{project_id}/codes/{repo_id}/{code_hash}"
43+
def _get_code_key(project_name: str, repo_id: str, code_hash: str) -> str:
44+
return f"data/projects/{project_name}/codes/{repo_id}/{code_hash}"
4545

4646
@staticmethod
4747
def _get_archive_key(user_id: str, archive_hash: str) -> str:

src/dstack/_internal/server/services/storage/gcs.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,21 +20,21 @@ def __init__(
2020

2121
def upload_code(
2222
self,
23-
project_id: str,
23+
project_name: str,
2424
repo_id: str,
2525
code_hash: str,
2626
blob: bytes,
2727
):
28-
key = self._get_code_key(project_id, repo_id, code_hash)
28+
key = self._get_code_key(project_name, repo_id, code_hash)
2929
self._upload(key, blob)
3030

3131
def get_code(
3232
self,
33-
project_id: str,
33+
project_name: str,
3434
repo_id: str,
3535
code_hash: str,
3636
) -> Optional[bytes]:
37-
key = self._get_code_key(project_id, repo_id, code_hash)
37+
key = self._get_code_key(project_name, repo_id, code_hash)
3838
return self._get(key)
3939

4040
def upload_archive(

src/dstack/_internal/server/services/storage/s3.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,21 +22,21 @@ def __init__(
2222

2323
def upload_code(
2424
self,
25-
project_id: str,
25+
project_name: str,
2626
repo_id: str,
2727
code_hash: str,
2828
blob: bytes,
2929
):
30-
key = self._get_code_key(project_id, repo_id, code_hash)
30+
key = self._get_code_key(project_name, repo_id, code_hash)
3131
self._upload(key, blob)
3232

3333
def get_code(
3434
self,
35-
project_id: str,
35+
project_name: str,
3636
repo_id: str,
3737
code_hash: str,
3838
) -> Optional[bytes]:
39-
key = self._get_code_key(project_id, repo_id, code_hash)
39+
key = self._get_code_key(project_name, repo_id, code_hash)
4040
return self._get(key)
4141

4242
def upload_archive(

src/dstack/_internal/server/testing/common.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@
9191
)
9292
from dstack._internal.server.models import (
9393
BackendModel,
94+
CodeModel,
9495
ComputeGroupModel,
9596
DecryptedString,
9697
EventModel,
@@ -267,6 +268,22 @@ async def create_repo(
267268
return repo
268269

269270

271+
async def create_code(
272+
session: AsyncSession,
273+
repo: RepoModel,
274+
blob_hash: str = "blob_hash",
275+
blob: Optional[bytes] = b"blob_content",
276+
) -> CodeModel:
277+
code = CodeModel(
278+
repo_id=repo.id,
279+
blob_hash=blob_hash,
280+
blob=blob,
281+
)
282+
session.add(code)
283+
await session.commit()
284+
return code
285+
286+
270287
async def create_repo_creds(
271288
session: AsyncSession,
272289
repo_id: UUID,
@@ -293,7 +310,7 @@ async def create_file_archive(
293310
session: AsyncSession,
294311
user_id: UUID,
295312
blob_hash: str = "blob_hash",
296-
blob: bytes = b"blob_content",
313+
blob: Optional[bytes] = b"blob_content",
297314
) -> FileArchiveModel:
298315
archive = FileArchiveModel(
299316
user_id=user_id,

src/tests/_internal/server/routers/test_files.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from unittest.mock import Mock
1+
from unittest.mock import AsyncMock, Mock
22

33
import pytest
44
from httpx import AsyncClient
@@ -146,3 +146,41 @@ async def test_uploads_archive_to_storage(
146146
default_storage_mock.upload_archive.assert_called_once_with(
147147
str(user.id), self.file_hash, self.file_content
148148
)
149+
150+
async def test_handles_race_condition(
151+
self,
152+
monkeypatch: pytest.MonkeyPatch,
153+
session: AsyncSession,
154+
client: AsyncClient,
155+
default_storage_mock: Mock,
156+
):
157+
user = await create_user(session=session, global_role=GlobalRole.USER)
158+
existing_archive = await create_file_archive(
159+
session=session, user_id=user.id, blob_hash=self.file_hash, blob=None
160+
)
161+
monkeypatch.setattr(
162+
"dstack._internal.server.services.files.get_archive_model_by_hash",
163+
# first call checks if already uploaded (not yet)
164+
# second call refetches after unique constraint violation
165+
AsyncMock(side_effect=[None, existing_archive]),
166+
)
167+
response = await client.post(
168+
"/api/files/upload_archive",
169+
headers=get_auth_headers(user.token),
170+
files={"file": self.file},
171+
)
172+
assert response.status_code == 200, response.json()
173+
assert response.json() == {
174+
"id": str(existing_archive.id),
175+
"hash": self.file_hash,
176+
}
177+
res = await session.execute(
178+
select(FileArchiveModel).where(FileArchiveModel.user_id == user.id)
179+
)
180+
archive = res.scalar_one()
181+
assert archive.id == existing_archive.id
182+
assert archive.blob_hash == self.file_hash
183+
assert archive.blob is None
184+
default_storage_mock.upload_archive.assert_called_once_with(
185+
str(user.id), self.file_hash, self.file_content
186+
)

src/tests/_internal/server/routers/test_repos.py

Lines changed: 84 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import json
2+
from unittest.mock import AsyncMock, Mock
23

34
import pytest
45
from httpx import AsyncClient
@@ -8,7 +9,9 @@
89
from dstack._internal.core.models.users import GlobalRole, ProjectRole
910
from dstack._internal.server.models import CodeModel, RepoCredsModel, RepoModel
1011
from dstack._internal.server.services.projects import add_project_member
12+
from dstack._internal.server.services.storage import BaseStorage
1113
from dstack._internal.server.testing.common import (
14+
create_code,
1215
create_project,
1316
create_repo,
1417
create_repo_creds,
@@ -321,11 +324,26 @@ async def test_deletes_repos(self, test_db, session: AsyncSession, client: Async
321324
assert repo is None
322325

323326

327+
@pytest.mark.asyncio
328+
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
329+
@pytest.mark.usefixtures("test_db")
324330
class TestUploadCode:
325-
@pytest.mark.asyncio
326-
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
331+
@pytest.fixture
332+
def default_storage_mock(self, monkeypatch: pytest.MonkeyPatch) -> Mock:
333+
storage_mock = Mock(spec_set=BaseStorage)
334+
monkeypatch.setattr(
335+
"dstack._internal.server.services.repos.get_default_storage", lambda: storage_mock
336+
)
337+
return storage_mock
338+
339+
@pytest.fixture
340+
def no_default_storage(self, monkeypatch: pytest.MonkeyPatch):
341+
monkeypatch.setattr(
342+
"dstack._internal.server.services.repos.get_default_storage", lambda: None
343+
)
344+
327345
async def test_returns_403_if_not_project_member(
328-
self, test_db, session: AsyncSession, client: AsyncClient
346+
self, session: AsyncSession, client: AsyncClient
329347
):
330348
user = await create_user(session=session, global_role=GlobalRole.USER)
331349
project = await create_project(session=session, owner=user)
@@ -336,9 +354,8 @@ async def test_returns_403_if_not_project_member(
336354
)
337355
assert response.status_code == 403
338356

339-
@pytest.mark.asyncio
340-
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
341-
async def test_uploads_code(self, test_db, session: AsyncSession, client: AsyncClient):
357+
@pytest.mark.usefixtures("no_default_storage")
358+
async def test_uploads_code_to_db(self, session: AsyncSession, client: AsyncClient):
342359
user = await create_user(session=session, global_role=GlobalRole.USER)
343360
project = await create_project(session=session, owner=user)
344361
await add_project_member(
@@ -354,14 +371,38 @@ async def test_uploads_code(self, test_db, session: AsyncSession, client: AsyncC
354371
)
355372
assert response.status_code == 200, response.json()
356373
res = await session.execute(select(CodeModel))
357-
code = res.scalar()
374+
code = res.scalar_one()
358375
assert code.blob_hash == file[0]
359376
assert code.blob == file[1]
360377

361-
@pytest.mark.asyncio
362-
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
378+
async def test_uploads_code_to_storage(
379+
self, session: AsyncSession, client: AsyncClient, default_storage_mock: Mock
380+
):
381+
user = await create_user(session=session, global_role=GlobalRole.USER)
382+
project = await create_project(session=session, owner=user)
383+
await add_project_member(
384+
session=session, project=project, user=user, project_role=ProjectRole.USER
385+
)
386+
repo = await create_repo(session=session, project_id=project.id)
387+
file = ("blob_hash", b"blob_content")
388+
response = await client.post(
389+
f"/api/project/{project.name}/repos/upload_code",
390+
headers=get_auth_headers(user.token),
391+
params={"repo_id": repo.name},
392+
files={"file": file},
393+
)
394+
assert response.status_code == 200, response.json()
395+
res = await session.execute(select(CodeModel))
396+
code = res.scalar_one()
397+
assert code.blob_hash == file[0]
398+
assert code.blob is None
399+
default_storage_mock.upload_code.assert_called_once_with(
400+
project.name, repo.name, file[0], file[1]
401+
)
402+
403+
@pytest.mark.usefixtures("no_default_storage")
363404
async def test_uploads_same_code_for_different_repos(
364-
self, test_db, session: AsyncSession, client: AsyncClient
405+
self, session: AsyncSession, client: AsyncClient
365406
):
366407
user = await create_user(session=session, global_role=GlobalRole.USER)
367408
project = await create_project(session=session, owner=user)
@@ -388,3 +429,36 @@ async def test_uploads_same_code_for_different_repos(
388429
res = await session.execute(select(CodeModel))
389430
codes = res.scalars().all()
390431
assert len(codes) == 2
432+
433+
async def test_handles_race_condition(
434+
self,
435+
monkeypatch: pytest.MonkeyPatch,
436+
session: AsyncSession,
437+
client: AsyncClient,
438+
default_storage_mock: Mock,
439+
):
440+
user = await create_user(session=session, global_role=GlobalRole.USER)
441+
project = await create_project(session=session, owner=user)
442+
await add_project_member(
443+
session=session, project=project, user=user, project_role=ProjectRole.USER
444+
)
445+
repo = await create_repo(session=session, project_id=project.id)
446+
file = ("blob_hash", b"blob_content")
447+
code = await create_code(session=session, repo=repo, blob_hash=file[0], blob=file[1])
448+
monkeypatch.setattr(
449+
"dstack._internal.server.services.repos.get_code_model", AsyncMock(return_value=None)
450+
)
451+
response = await client.post(
452+
f"/api/project/{project.name}/repos/upload_code",
453+
headers=get_auth_headers(user.token),
454+
params={"repo_id": repo.name},
455+
files={"file": file},
456+
)
457+
assert response.status_code == 200, response.json()
458+
res = await session.execute(select(CodeModel))
459+
code = res.scalar_one()
460+
assert code.blob_hash == file[0]
461+
assert code.blob == file[1]
462+
default_storage_mock.upload_code.assert_called_once_with(
463+
project.name, repo.name, file[0], file[1]
464+
)

0 commit comments

Comments
 (0)