diff --git a/runner/internal/runner/api/http.go b/runner/internal/runner/api/http.go index 34220acc6..4919852af 100644 --- a/runner/internal/runner/api/http.go +++ b/runner/internal/runner/api/http.go @@ -38,11 +38,15 @@ func (s *Server) metricsGetHandler(w http.ResponseWriter, r *http.Request) (inte return metrics, nil } +// submitPostHandler must be called first +// It's safe to call it more than once func (s *Server) submitPostHandler(w http.ResponseWriter, r *http.Request) (interface{}, error) { s.executor.Lock() defer s.executor.Unlock() state := s.executor.GetRunnerState() - if state != executor.WaitSubmit { + if state == executor.WaitRun { + log.Warning(r.Context(), "Job already submitted, submitting again", "current_state", state) + } else if state != executor.WaitSubmit { log.Warning(r.Context(), "Executor doesn't wait submit", "current_state", state) return nil, &api.Error{Status: http.StatusConflict} } @@ -52,20 +56,19 @@ func (s *Server) submitPostHandler(w http.ResponseWriter, r *http.Request) (inte log.Error(r.Context(), "Failed to decode submit body", "err", err) return nil, err } - // todo go-playground/validator s.executor.SetJob(body) - s.jobBarrierCh <- nil // notify server that job submitted + s.executor.SetRunnerState(executor.WaitRun) return nil, nil } -// uploadArchivePostHandler may be called 0 or more times, and must be called after submitPostHandler -// and before uploadCodePostHandler +// If uploadArchivePostHandler is called, it must be called after submitPostHandler and before runPostHandler +// It's safe to call it more than once with the same archive func (s *Server) uploadArchivePostHandler(w http.ResponseWriter, r *http.Request) (interface{}, error) { s.executor.Lock() defer s.executor.Unlock() - if s.executor.GetRunnerState() != executor.WaitCode { + if s.executor.GetRunnerState() != executor.WaitRun { return nil, &api.Error{Status: http.StatusConflict} } @@ -123,10 +126,12 @@ func (s *Server) uploadArchivePostHandler(w http.ResponseWriter, r *http.Request return nil, nil } +// If uploadCodePostHandler is called, it must be called after submitPostHandler and before runPostHandler +// It's safe to call it more than once func (s *Server) uploadCodePostHandler(w http.ResponseWriter, r *http.Request) (interface{}, error) { s.executor.Lock() defer s.executor.Unlock() - if s.executor.GetRunnerState() != executor.WaitCode { + if s.executor.GetRunnerState() != executor.WaitRun { return nil, &api.Error{Status: http.StatusConflict} } @@ -139,8 +144,6 @@ func (s *Server) uploadCodePostHandler(w http.ResponseWriter, r *http.Request) ( return nil, fmt.Errorf("copy request body: %w", err) } - s.executor.SetRunnerState(executor.WaitRun) - return nil, nil } @@ -151,6 +154,7 @@ func (s *Server) runPostHandler(w http.ResponseWriter, r *http.Request) (interfa return nil, &api.Error{Status: http.StatusConflict} } s.executor.SetRunnerState(executor.ServeLogs) + s.jobBarrierCh <- nil // notify server that job started s.executor.Unlock() var runCtx context.Context diff --git a/runner/internal/runner/api/server.go b/runner/internal/runner/api/server.go index 11b76d887..227ea1dbb 100644 --- a/runner/internal/runner/api/server.go +++ b/runner/internal/runner/api/server.go @@ -21,8 +21,8 @@ type Server struct { pullDoneCh chan interface{} // Closed then /api/pull gave everything wsDoneCh chan interface{} // Closed then /logs_ws gave everything - submitWaitDuration time.Duration - logsWaitDuration time.Duration + startWaitDuration time.Duration + logsWaitDuration time.Duration executor executor.Executor cancelRun context.CancelFunc @@ -51,8 +51,8 @@ func NewServer(ctx context.Context, address string, version string, ex executor. pullDoneCh: make(chan interface{}), wsDoneCh: make(chan interface{}), - submitWaitDuration: 5 * time.Minute, - logsWaitDuration: 5 * time.Minute, + startWaitDuration: 5 * time.Minute, + logsWaitDuration: 5 * time.Minute, executor: ex, @@ -82,7 +82,7 @@ func (s *Server) Run(ctx context.Context) error { select { case <-s.jobBarrierCh: // job started - case <-time.After(s.submitWaitDuration): + case <-time.After(s.startWaitDuration): log.Error(ctx, "Job didn't start in time, shutting down") return errors.New("no job submitted") case <-ctx.Done(): diff --git a/runner/internal/runner/executor/base.go b/runner/internal/runner/executor/base.go index b8093e5e7..bafe714bc 100644 --- a/runner/internal/runner/executor/base.go +++ b/runner/internal/runner/executor/base.go @@ -9,12 +9,21 @@ import ( ) type Executor interface { + // It must be safe to call SetJob more than once + SetJob(job schemas.SubmitBody) + // It must be safe to call WriteFileArchive more than once with the same archive + WriteFileArchive(id string, src io.Reader) error + // It must be safe to call WriteRepoBlob more than once + WriteRepoBlob(src io.Reader) error + Run(ctx context.Context) error + GetHistory(timestamp int64) *schemas.PullResponse GetJobWsLogsHistory() []schemas.LogEvent + GetRunnerState() string + SetRunnerState(state string) + GetJobInfo(ctx context.Context) (username string, workingDir string, err error) - Run(ctx context.Context) error - SetJob(job schemas.SubmitBody) SetJobState(ctx context.Context, state schemas.JobState) SetJobStateWithTerminationReason( ctx context.Context, @@ -22,9 +31,7 @@ type Executor interface { terminationReason types.TerminationReason, terminationMessage string, ) - SetRunnerState(state string) - WriteFileArchive(id string, src io.Reader) error - WriteRepoBlob(src io.Reader) error + Lock() RLock() RUnlock() diff --git a/runner/internal/runner/executor/executor.go b/runner/internal/runner/executor/executor.go index 3662a45aa..2c3f3f03f 100644 --- a/runner/internal/runner/executor/executor.go +++ b/runner/internal/runner/executor/executor.go @@ -295,7 +295,6 @@ func (ex *RunExecutor) SetJob(body schemas.SubmitBody) { ex.secrets = body.Secrets ex.repoCredentials = body.RepoCredentials ex.jobLogs.SetQuota(body.LogQuotaHour) - ex.state = WaitCode } func (ex *RunExecutor) SetJobState(ctx context.Context, state schemas.JobState) { diff --git a/runner/internal/runner/executor/repo.go b/runner/internal/runner/executor/repo.go index 116e4b225..40cb495fb 100644 --- a/runner/internal/runner/executor/repo.go +++ b/runner/internal/runner/executor/repo.go @@ -106,9 +106,8 @@ func (ex *RunExecutor) setupRepo(ctx context.Context) error { return fmt.Errorf("prepare git repo: %w", err) } case "local", "virtual": - log.Trace(ctx, "Extracting tar archive") - if err := ex.prepareArchive(ctx); err != nil { - return fmt.Errorf("prepare archive: %w", err) + if err := ex.extractCodeArchive(ctx); err != nil { + return fmt.Errorf("extract code archive: %w", err) } default: return fmt.Errorf("unknown RepoType: %s", ex.getRepoData().RepoType) @@ -164,26 +163,32 @@ func (ex *RunExecutor) prepareGit(ctx context.Context) error { return fmt.Errorf("set repo config: %w", err) } + if ex.repoBlobPath == "" { + log.Trace(ctx, "No diff to apply") + return nil + } log.Trace(ctx, "Applying diff") repoDiff, err := os.ReadFile(ex.repoBlobPath) if err != nil { return fmt.Errorf("read repo diff: %w", err) } - if len(repoDiff) > 0 { - if err := repo.ApplyDiff(ctx, ex.repoDir, string(repoDiff)); err != nil { - return fmt.Errorf("apply diff: %w", err) - } + if err := repo.ApplyDiff(ctx, ex.repoDir, string(repoDiff)); err != nil { + return fmt.Errorf("apply diff: %w", err) } return nil } -func (ex *RunExecutor) prepareArchive(ctx context.Context) error { +func (ex *RunExecutor) extractCodeArchive(ctx context.Context) error { + if ex.repoBlobPath == "" { + log.Trace(ctx, "No code archive to extract") + return nil + } + log.Trace(ctx, "Extracting code archive", "src", ex.repoBlobPath, "dst", ex.repoDir) file, err := os.Open(ex.repoBlobPath) if err != nil { return fmt.Errorf("open code archive: %w", err) } defer func() { _ = file.Close() }() - log.Trace(ctx, "Extracting code archive", "src", ex.repoBlobPath, "dst", ex.repoDir) if err := extract.Tar(ctx, file, ex.repoDir, nil); err != nil { return fmt.Errorf("extract tar archive: %w", err) } diff --git a/runner/internal/runner/executor/states.go b/runner/internal/runner/executor/states.go index cfa6dc15e..f18887144 100644 --- a/runner/internal/runner/executor/states.go +++ b/runner/internal/runner/executor/states.go @@ -2,7 +2,6 @@ package executor const ( WaitSubmit = "wait_submit" - WaitCode = "wait_code" WaitRun = "wait_run" ServeLogs = "serve_logs" WaitLogsFinished = "wait_logs_finished" diff --git a/src/dstack/_internal/core/models/repos/base.py b/src/dstack/_internal/core/models/repos/base.py index 03c4b8676..e9d864305 100644 --- a/src/dstack/_internal/core/models/repos/base.py +++ b/src/dstack/_internal/core/models/repos/base.py @@ -22,6 +22,10 @@ class Repo(ABC): repo_dir: Optional[str] run_repo_data: "repos.AnyRunRepoData" + @abstractmethod + def has_code_to_write(self) -> bool: + pass + @abstractmethod def write_code_file(self, fp: BinaryIO) -> str: pass diff --git a/src/dstack/_internal/core/models/repos/local.py b/src/dstack/_internal/core/models/repos/local.py index 2316c2c4f..a7e162e7c 100644 --- a/src/dstack/_internal/core/models/repos/local.py +++ b/src/dstack/_internal/core/models/repos/local.py @@ -73,6 +73,10 @@ def __init__( self.repo_id = repo_id self.run_repo_data = repo_data + def has_code_to_write(self) -> bool: + # LocalRepo is deprecated, no need for real implementation + return True + def write_code_file(self, fp: BinaryIO) -> str: repo_path = Path(self.run_repo_data.repo_dir) with tarfile.TarFile(mode="w", fileobj=fp) as t: diff --git a/src/dstack/_internal/core/models/repos/remote.py b/src/dstack/_internal/core/models/repos/remote.py index 3bfd34024..e03773c41 100644 --- a/src/dstack/_internal/core/models/repos/remote.py +++ b/src/dstack/_internal/core/models/repos/remote.py @@ -183,6 +183,13 @@ def __init__( self.repo_id = repo_id self.run_repo_data = repo_data + def has_code_to_write(self) -> bool: + # repo_diff is: + # * None for RemoteRepo.from_url() + # * an empty string for RemoteRepo.from_dir() if there are no changes ("clean" state) + # * a non-empty string for RemoteRepo.from_dir() if there are changes ("dirty" state) + return bool(self.run_repo_data.repo_diff) + def write_code_file(self, fp: BinaryIO) -> str: if self.run_repo_data.repo_diff is not None: fp.write(self.run_repo_data.repo_diff.encode()) diff --git a/src/dstack/_internal/core/models/repos/virtual.py b/src/dstack/_internal/core/models/repos/virtual.py index 4a975481a..24ffeac61 100644 --- a/src/dstack/_internal/core/models/repos/virtual.py +++ b/src/dstack/_internal/core/models/repos/virtual.py @@ -73,6 +73,9 @@ def add_file(self, path: str, content: bytes): self.files[resolve_relative_path(path).as_posix()] = content + def has_code_to_write(self) -> bool: + return len(self.files) > 0 + def write_code_file(self, fp: BinaryIO) -> str: with tarfile.TarFile(mode="w", fileobj=fp) as t: for path, content in sorted(self.files.items()): diff --git a/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py b/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py index b85ba77c4..369894310 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py @@ -1320,7 +1320,7 @@ def _submit_job_to_runner( job: Job, jrd: Optional[JobRuntimeData], cluster_info: ClusterInfo, - code: bytes, + code: Optional[bytes], file_archives: Iterable[tuple[uuid.UUID, bytes]], secrets: Dict[str, str], repo_credentials: Optional[RemoteRepoCreds], @@ -1352,11 +1352,15 @@ def _submit_job_to_runner( repo_credentials=repo_credentials, instance_env=instance_env, ) - logger.debug("%s: uploading file archive(s)", fmt(job_model)) for archive_id, archive in file_archives: + logger.debug("%s: uploading file archive: %s", fmt(job_model), archive_id) runner_client.upload_archive(archive_id, archive) - logger.debug("%s: uploading code", fmt(job_model)) - runner_client.upload_code(code) + if code is None and not runner_client.is_code_upload_optional(): + # Old runner, we must call `/api/upload_code` to proceed + code = b"" + if code is not None: + logger.debug("%s: uploading code", fmt(job_model)) + runner_client.upload_code(code) logger.debug("%s: starting job", fmt(job_model)) job_info = runner_client.run_job() if job_info is not None: @@ -1520,18 +1524,20 @@ def _get_repo_code_hash(run: Run, job: Job) -> Optional[str]: return job.job_spec.repo_code_hash -async def _get_job_code(project: ProjectModel, repo: RepoModel, code_hash: Optional[str]) -> bytes: +async def _get_job_code( + project: ProjectModel, repo: RepoModel, code_hash: Optional[str] +) -> Optional[bytes]: if code_hash is None: - return b"" + return None async with get_session_ctx() as session: code_model = await get_code_model(session=session, repo=repo, code_hash=code_hash) if code_model is None: - return b"" + return None if code_model.blob is not None: return code_model.blob storage = get_default_storage() if storage is None: - return b"" + return None blob = await run_async( storage.get_code, project.name, @@ -1542,7 +1548,7 @@ async def _get_job_code(project: ProjectModel, repo: RepoModel, code_hash: Optio logger.error( "Failed to get repo code hash %s from storage for repo %s", code_hash, repo.name ) - return b"" + return None return blob diff --git a/src/dstack/_internal/server/background/scheduled_tasks/running_jobs.py b/src/dstack/_internal/server/background/scheduled_tasks/running_jobs.py index 5f63ac8b5..b6b014c50 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/running_jobs.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/running_jobs.py @@ -1173,17 +1173,17 @@ def _get_repo_code_hash(run: Run, job: Job) -> Optional[str]: async def _get_job_code( session: AsyncSession, project: ProjectModel, repo: RepoModel, code_hash: Optional[str] -) -> bytes: +) -> Optional[bytes]: if code_hash is None: - return b"" + return None code_model = await get_code_model(session=session, repo=repo, code_hash=code_hash) if code_model is None: - return b"" + return None if code_model.blob is not None: return code_model.blob storage = get_default_storage() if storage is None: - return b"" + return None blob = await common_utils.run_async( storage.get_code, project.name, @@ -1194,7 +1194,7 @@ async def _get_job_code( logger.error( "Failed to get repo code hash %s from storage for repo %s", code_hash, repo.name ) - return b"" + return None return blob @@ -1243,7 +1243,7 @@ def _submit_job_to_runner( job_model: JobModel, job: Job, cluster_info: ClusterInfo, - code: bytes, + code: Optional[bytes], file_archives: Iterable[tuple[uuid.UUID, bytes]], secrets: Dict[str, str], repo_credentials: Optional[RemoteRepoCreds], @@ -1285,11 +1285,15 @@ def _submit_job_to_runner( repo_credentials=repo_credentials, instance_env=instance_env, ) - logger.debug("%s: uploading file archive(s)", fmt(job_model)) for archive_id, archive in file_archives: + logger.debug("%s: uploading file archive: %s", fmt(job_model), archive_id) runner_client.upload_archive(archive_id, archive) - logger.debug("%s: uploading code", fmt(job_model)) - runner_client.upload_code(code) + if code is None and not runner_client.is_code_upload_optional(): + # Old runner, we must call `/api/upload_code` to proceed + code = b"" + if code is not None: + logger.debug("%s: uploading code", fmt(job_model)) + runner_client.upload_code(code) logger.debug("%s: starting job", fmt(job_model)) job_info = runner_client.run_job() if job_info is not None: diff --git a/src/dstack/_internal/server/services/runner/client.py b/src/dstack/_internal/server/services/runner/client.py index 4b78eefee..8fd17a3a2 100644 --- a/src/dstack/_internal/server/services/runner/client.py +++ b/src/dstack/_internal/server/services/runner/client.py @@ -50,6 +50,13 @@ class RunnerClient: + # `/api/upload_code` call is not required if there is no code + _OPTIONAL_CODE_UPLOAD_MIN_VERSION = (0, 20, 17) + + _version_string: str + _version_tuple: Optional["_Version"] + _negotiated: bool = False + def __init__( self, port: int, @@ -59,13 +66,28 @@ def __init__( self.hostname = hostname self.port = port + def get_version_string(self) -> str: + if not self._negotiated: + self._negotiate() + return self._version_string + + def get_version_tuple(self) -> Optional["_Version"]: + if not self._negotiated: + self._negotiate() + return self._version_tuple + + def is_code_upload_optional(self) -> bool: + version_tuple = self.get_version_tuple() + return version_tuple is None or version_tuple >= self._OPTIONAL_CODE_UPLOAD_MIN_VERSION + def healthcheck(self) -> Optional[HealthcheckResponse]: try: - resp = requests.get(self._url("/api/healthcheck"), timeout=REQUEST_TIMEOUT) - resp.raise_for_status() - return HealthcheckResponse.__response__.parse_obj(resp.json()) + healthcheck_response = self._healthcheck() except requests.exceptions.RequestException: return None + if not self._negotiated: + self._negotiate(healthcheck_response) + return healthcheck_response def get_metrics(self) -> Optional[MetricsResponse]: resp = requests.get(self._url("/api/metrics"), timeout=REQUEST_TIMEOUT) @@ -150,6 +172,20 @@ def stop(self): def _url(self, path: str) -> str: return f"{'https' if self.secure else 'http'}://{self.hostname}:{self.port}/{path.lstrip('/')}" + def _healthcheck(self) -> HealthcheckResponse: + resp = requests.get(self._url("/api/healthcheck"), timeout=REQUEST_TIMEOUT) + resp.raise_for_status() + return HealthcheckResponse.__response__.parse_obj(resp.json()) + + def _negotiate(self, healthcheck_response: Optional[HealthcheckResponse] = None) -> None: + if healthcheck_response is None: + healthcheck_response = self._healthcheck() + version_string = healthcheck_response.version + version_tuple = _parse_version(version_string) + self._version_string = version_string + self._version_tuple = version_tuple + self._negotiated = True + class ShimError(DstackError): pass diff --git a/src/dstack/_internal/server/testing/common.py b/src/dstack/_internal/server/testing/common.py index 1d49de206..5c21857db 100644 --- a/src/dstack/_internal/server/testing/common.py +++ b/src/dstack/_internal/server/testing/common.py @@ -68,6 +68,7 @@ Profile, TerminationPolicy, ) +from dstack._internal.core.models.repos import AnyRunRepoData from dstack._internal.core.models.repos.base import RepoType from dstack._internal.core.models.repos.local import LocalRunRepoData from dstack._internal.core.models.resources import CPUSpec, Memory, ResourcesSpec @@ -91,6 +92,7 @@ ) from dstack._internal.server.models import ( BackendModel, + CodeModel, ComputeGroupModel, DecryptedString, EventModel, @@ -267,6 +269,22 @@ async def create_repo( return repo +async def create_code( + session: AsyncSession, + repo: RepoModel, + blob_hash: str = "blob_hash", + blob: Optional[bytes] = b"blob_content", +) -> CodeModel: + code = CodeModel( + repo_id=repo.id, + blob_hash=blob_hash, + blob=blob, + ) + session.add(code) + await session.commit() + return code + + async def create_repo_creds( session: AsyncSession, repo_id: UUID, @@ -312,14 +330,16 @@ def get_run_spec( profile: Union[Profile, Callable[[], Profile], None] = lambda: Profile(name="default"), configuration: Optional[AnyRunConfiguration] = None, ssh_key_pub: Optional[str] = "user_ssh_key", + repo_data: AnyRunRepoData = LocalRunRepoData(repo_dir="/"), + repo_code_hash: Optional[str] = None, ) -> RunSpec: if callable(profile): profile = profile() return RunSpec( run_name=run_name, repo_id=repo_id, - repo_data=LocalRunRepoData(repo_dir="/"), - repo_code_hash=None, + repo_data=repo_data, + repo_code_hash=repo_code_hash, configuration_path=configuration_path, configuration=configuration or DevEnvironmentConfiguration(ide="vscode"), profile=profile, diff --git a/src/dstack/api/_public/runs.py b/src/dstack/api/_public/runs.py index a8afac24c..696041e62 100644 --- a/src/dstack/api/_public/runs.py +++ b/src/dstack/api/_public/runs.py @@ -498,8 +498,8 @@ def get_run_plan( """ if repo is None: repo = VirtualRepo() - repo_code_hash = None - else: + repo_code_hash: Optional[str] = None + if repo.has_code_to_write(): with _prepare_code_file(repo) as (_, repo_code_hash): pass @@ -571,9 +571,7 @@ def apply_plan( if repo is None: repo = VirtualRepo() - else: - # Do not upload the diff without a repo (a default virtual repo) - # since upload_code() requires a repo to be initialized. + if repo.has_code_to_write(): with _prepare_code_file(repo) as (fp, repo_code_hash): self._api_client.repos.upload_code( project_name=self._project, diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_running_jobs.py b/src/tests/_internal/server/background/pipeline_tasks/test_running_jobs.py index 55259cd41..db82a1012 100644 --- a/src/tests/_internal/server/background/pipeline_tasks/test_running_jobs.py +++ b/src/tests/_internal/server/background/pipeline_tasks/test_running_jobs.py @@ -55,6 +55,7 @@ from dstack._internal.server.services.volumes import volume_model_to_volume from dstack._internal.server.testing.common import ( create_backend, + create_code, create_export, create_fleet, create_gateway, @@ -494,13 +495,44 @@ async def test_leaves_provisioning_job_unchanged_if_runner_not_alive( assert job.lock_token is None assert job.lock_owner is None + @pytest.mark.parametrize( + ["has_repo_code", "runner_version", "upload_code_call_expected"], + [ + pytest.param(False, "0.20.17", False, id="without-repo-code-new-runner"), + pytest.param(True, "0.20.17", True, id="with-repo-code-new-runner"), + pytest.param(False, "0.20.16", True, id="without-repo-code-old-runner"), + pytest.param(True, "0.20.16", True, id="with-repo-code-old-runner"), + ], + ) async def test_runs_provisioning_job( - self, test_db, session: AsyncSession, worker: JobRunningWorker + self, + test_db, + session: AsyncSession, + worker: JobRunningWorker, + runner_version: str, + has_repo_code: bool, + upload_code_call_expected: bool, ): project = await create_project(session=session) user = await create_user(session=session) repo = await create_repo(session=session, project_id=project.id) - run = await create_run(session=session, project=project, repo=repo, user=user) + repo_code_hash: Optional[str] = None + if has_repo_code: + repo_code_hash = "blob_hash" + await create_code(session=session, repo=repo, blob_hash=repo_code_hash, blob=b"blob") + run_spec = get_run_spec( + run_name="test-run", + repo_id=repo.name, + repo_code_hash=repo_code_hash, + ) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + run_name=run_spec.run_name, + run_spec=run_spec, + ) instance = await create_instance( session=session, project=project, status=InstanceStatus.BUSY ) @@ -518,23 +550,26 @@ async def test_runs_provisioning_job( with ( patch("dstack._internal.server.services.runner.ssh.SSHTunnel") as ssh_tunnel_cls, - patch( - "dstack._internal.server.services.runner.client.RunnerClient" - ) as runner_client_cls, + patch.object(RunnerClient, "_healthcheck") as healthcheck_mock, + patch.object(RunnerClient, "submit_job") as submit_job_mock, + patch.object(RunnerClient, "upload_code") as upload_code_mock, + patch.object(RunnerClient, "run_job") as run_job_mock, ): - runner_client_mock = runner_client_cls.return_value - runner_client_mock.healthcheck.return_value = HealthcheckResponse( - service="dstack-runner", version="0.0.1.dev2" + healthcheck_mock.return_value = HealthcheckResponse( + service="dstack-runner", version=runner_version ) - runner_client_mock.run_job.return_value = JobInfoResponse( + run_job_mock.return_value = JobInfoResponse( working_dir="/dstack/run", username="dstack" ) await _process_job(session, worker, job) assert ssh_tunnel_cls.call_count == 2 - assert runner_client_mock.healthcheck.call_count == 2 - runner_client_mock.submit_job.assert_called_once() - runner_client_mock.upload_code.assert_called_once() - runner_client_mock.run_job.assert_called_once() + assert healthcheck_mock.call_count == 2 + submit_job_mock.assert_called_once() + if upload_code_call_expected: + upload_code_mock.assert_called_once() + else: + upload_code_mock.assert_not_called() + run_job_mock.assert_called_once() await session.refresh(job) assert job.status == JobStatus.RUNNING @@ -730,7 +765,21 @@ async def test_pulling_shim( project = await create_project(session=session) user = await create_user(session=session) repo = await create_repo(session=session, project_id=project.id) - run = await create_run(session=session, project=project, repo=repo, user=user) + repo_code_hash = "blob_hash" + await create_code(session=session, repo=repo, blob_hash=repo_code_hash, blob=b"blob") + run_spec = get_run_spec( + run_name="test-run", + repo_id=repo.name, + repo_code_hash=repo_code_hash, + ) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + run_name=run_spec.run_name, + run_spec=run_spec, + ) instance = await create_instance( session=session, project=project, status=InstanceStatus.BUSY ) diff --git a/src/tests/_internal/server/background/scheduled_tasks/test_running_jobs.py b/src/tests/_internal/server/background/scheduled_tasks/test_running_jobs.py index 0e2a63162..00c9cbf83 100644 --- a/src/tests/_internal/server/background/scheduled_tasks/test_running_jobs.py +++ b/src/tests/_internal/server/background/scheduled_tasks/test_running_jobs.py @@ -54,6 +54,7 @@ ) from dstack._internal.server.testing.common import ( create_backend, + create_code, create_export, create_fleet, create_gateway, @@ -156,10 +157,26 @@ async def test_leaves_provisioning_job_unchanged_if_runner_not_alive( await session.refresh(job) assert job.status == JobStatus.PROVISIONING + @pytest.mark.parametrize( + ["has_repo_code", "old_runner", "upload_code_call_expected"], + [ + pytest.param(False, False, False, id="without-repo-code-new-runner"), + pytest.param(True, False, True, id="with-repo-code-new-runner"), + pytest.param(False, True, True, id="without-repo-code-old-runner"), + pytest.param(True, True, True, id="with-repo-code-old-runner"), + ], + ) @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) async def test_runs_provisioning_job( - self, test_db, session: AsyncSession, ssh_tunnel_mock: Mock, runner_client_mock: Mock + self, + test_db, + session: AsyncSession, + ssh_tunnel_mock: Mock, + runner_client_mock: Mock, + old_runner: bool, + has_repo_code: bool, + upload_code_call_expected: bool, ): project = await create_project(session=session) user = await create_user(session=session) @@ -167,11 +184,22 @@ async def test_runs_provisioning_job( session=session, project_id=project.id, ) + repo_code_hash: Optional[str] = None + if has_repo_code: + repo_code_hash = "blob_hash" + await create_code(session=session, repo=repo, blob_hash=repo_code_hash, blob=b"blob") + run_spec = get_run_spec( + run_name="test-run", + repo_id=repo.name, + repo_code_hash=repo_code_hash, + ) run = await create_run( session=session, project=project, repo=repo, user=user, + run_name=run_spec.run_name, + run_spec=run_spec, ) instance = await create_instance( session=session, @@ -191,11 +219,14 @@ async def test_runs_provisioning_job( runner_client_mock.run_job.return_value = JobInfoResponse( working_dir="/dstack/run", username="dstack" ) + runner_client_mock.is_code_upload_optional.return_value = not old_runner await process_running_jobs() ssh_tunnel_mock.assert_called() assert runner_client_mock.healthcheck.call_count == 2 - runner_client_mock.submit_job.assert_called_once() - runner_client_mock.upload_code.assert_called_once() + if upload_code_call_expected: + runner_client_mock.upload_code.assert_called_once() + else: + runner_client_mock.upload_code.assert_not_called() runner_client_mock.run_job.assert_called_once() await session.refresh(job) assert job.status == JobStatus.RUNNING @@ -490,11 +521,20 @@ async def test_pulling_shim( project = await create_project(session=session) user = await create_user(session=session) repo = await create_repo(session=session, project_id=project.id) + repo_code_hash = "blob_hash" + await create_code(session=session, repo=repo, blob_hash=repo_code_hash, blob=b"blob") + run_spec = get_run_spec( + run_name="test-run", + repo_id=repo.name, + repo_code_hash=repo_code_hash, + ) run = await create_run( session=session, project=project, repo=repo, user=user, + run_name=run_spec.run_name, + run_spec=run_spec, ) instance = await create_instance( session=session,