11import json
2+ from unittest .mock import AsyncMock , Mock
23
34import pytest
45from httpx import AsyncClient
89from dstack ._internal .core .models .users import GlobalRole , ProjectRole
910from dstack ._internal .server .models import CodeModel , RepoCredsModel , RepoModel
1011from dstack ._internal .server .services .projects import add_project_member
12+ from dstack ._internal .server .services .storage import BaseStorage
1113from dstack ._internal .server .testing .common import (
14+ create_code ,
1215 create_project ,
1316 create_repo ,
1417 create_repo_creds ,
@@ -321,11 +324,26 @@ async def test_deletes_repos(self, test_db, session: AsyncSession, client: Async
321324 assert repo is None
322325
323326
327+ @pytest .mark .asyncio
328+ @pytest .mark .parametrize ("test_db" , ["sqlite" , "postgres" ], indirect = True )
329+ @pytest .mark .usefixtures ("test_db" )
324330class TestUploadCode :
325- @pytest .mark .asyncio
326- @pytest .mark .parametrize ("test_db" , ["sqlite" , "postgres" ], indirect = True )
331+ @pytest .fixture
332+ def default_storage_mock (self , monkeypatch : pytest .MonkeyPatch ) -> Mock :
333+ storage_mock = Mock (spec_set = BaseStorage )
334+ monkeypatch .setattr (
335+ "dstack._internal.server.services.repos.get_default_storage" , lambda : storage_mock
336+ )
337+ return storage_mock
338+
339+ @pytest .fixture
340+ def no_default_storage (self , monkeypatch : pytest .MonkeyPatch ):
341+ monkeypatch .setattr (
342+ "dstack._internal.server.services.repos.get_default_storage" , lambda : None
343+ )
344+
327345 async def test_returns_403_if_not_project_member (
328- self , test_db , session : AsyncSession , client : AsyncClient
346+ self , session : AsyncSession , client : AsyncClient
329347 ):
330348 user = await create_user (session = session , global_role = GlobalRole .USER )
331349 project = await create_project (session = session , owner = user )
@@ -336,9 +354,8 @@ async def test_returns_403_if_not_project_member(
336354 )
337355 assert response .status_code == 403
338356
339- @pytest .mark .asyncio
340- @pytest .mark .parametrize ("test_db" , ["sqlite" , "postgres" ], indirect = True )
341- async def test_uploads_code (self , test_db , session : AsyncSession , client : AsyncClient ):
357+ @pytest .mark .usefixtures ("no_default_storage" )
358+ async def test_uploads_code_to_db (self , session : AsyncSession , client : AsyncClient ):
342359 user = await create_user (session = session , global_role = GlobalRole .USER )
343360 project = await create_project (session = session , owner = user )
344361 await add_project_member (
@@ -354,14 +371,38 @@ async def test_uploads_code(self, test_db, session: AsyncSession, client: AsyncC
354371 )
355372 assert response .status_code == 200 , response .json ()
356373 res = await session .execute (select (CodeModel ))
357- code = res .scalar ()
374+ code = res .scalar_one ()
358375 assert code .blob_hash == file [0 ]
359376 assert code .blob == file [1 ]
360377
361- @pytest .mark .asyncio
362- @pytest .mark .parametrize ("test_db" , ["sqlite" , "postgres" ], indirect = True )
378+ async def test_uploads_code_to_storage (
379+ self , session : AsyncSession , client : AsyncClient , default_storage_mock : Mock
380+ ):
381+ user = await create_user (session = session , global_role = GlobalRole .USER )
382+ project = await create_project (session = session , owner = user )
383+ await add_project_member (
384+ session = session , project = project , user = user , project_role = ProjectRole .USER
385+ )
386+ repo = await create_repo (session = session , project_id = project .id )
387+ file = ("blob_hash" , b"blob_content" )
388+ response = await client .post (
389+ f"/api/project/{ project .name } /repos/upload_code" ,
390+ headers = get_auth_headers (user .token ),
391+ params = {"repo_id" : repo .name },
392+ files = {"file" : file },
393+ )
394+ assert response .status_code == 200 , response .json ()
395+ res = await session .execute (select (CodeModel ))
396+ code = res .scalar_one ()
397+ assert code .blob_hash == file [0 ]
398+ assert code .blob is None
399+ default_storage_mock .upload_code .assert_called_once_with (
400+ project .name , repo .name , file [0 ], file [1 ]
401+ )
402+
403+ @pytest .mark .usefixtures ("no_default_storage" )
363404 async def test_uploads_same_code_for_different_repos (
364- self , test_db , session : AsyncSession , client : AsyncClient
405+ self , session : AsyncSession , client : AsyncClient
365406 ):
366407 user = await create_user (session = session , global_role = GlobalRole .USER )
367408 project = await create_project (session = session , owner = user )
@@ -388,3 +429,36 @@ async def test_uploads_same_code_for_different_repos(
388429 res = await session .execute (select (CodeModel ))
389430 codes = res .scalars ().all ()
390431 assert len (codes ) == 2
432+
433+ async def test_handles_race_condition (
434+ self ,
435+ monkeypatch : pytest .MonkeyPatch ,
436+ session : AsyncSession ,
437+ client : AsyncClient ,
438+ default_storage_mock : Mock ,
439+ ):
440+ user = await create_user (session = session , global_role = GlobalRole .USER )
441+ project = await create_project (session = session , owner = user )
442+ await add_project_member (
443+ session = session , project = project , user = user , project_role = ProjectRole .USER
444+ )
445+ repo = await create_repo (session = session , project_id = project .id )
446+ file = ("blob_hash" , b"blob_content" )
447+ code = await create_code (session = session , repo = repo , blob_hash = file [0 ], blob = file [1 ])
448+ monkeypatch .setattr (
449+ "dstack._internal.server.services.repos.get_code_model" , AsyncMock (return_value = None )
450+ )
451+ response = await client .post (
452+ f"/api/project/{ project .name } /repos/upload_code" ,
453+ headers = get_auth_headers (user .token ),
454+ params = {"repo_id" : repo .name },
455+ files = {"file" : file },
456+ )
457+ assert response .status_code == 200 , response .json ()
458+ res = await session .execute (select (CodeModel ))
459+ code = res .scalar_one ()
460+ assert code .blob_hash == file [0 ]
461+ assert code .blob == file [1 ]
462+ default_storage_mock .upload_code .assert_called_once_with (
463+ project .name , repo .name , file [0 ], file [1 ]
464+ )
0 commit comments