Skip to content

Commit 4c90f9f

Browse files
authored
Rework runner job submission flow (#3743)
The overall process is more lenient now: * It's safe to call `/api/submit` more than once -- even after `/api/upload_code` * It's safe to call `/api/upload_code` more than once, or not to call it at all -- code upload is now optional, so if there is nothing to upload, this step can be skipped. Since `/api/upload_code` is now optional, the server no longer calls this method if there is no code (`RemoteRepo` diff or `VirtualRepo` file archive) to upload. In addition, the Python API (used by the CLI internally) has been optimized -- it does not upload a code blob if there is no diff (`RemoteRepo`) or no files (empty `VirtualRepo`, incl. the default virtual repo used when no repo is specified in the run configuration). Fixes: #3740
1 parent dd6234e commit 4c90f9f

File tree

15 files changed

+187
-63
lines changed

15 files changed

+187
-63
lines changed

runner/internal/runner/api/http.go

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,15 @@ func (s *Server) metricsGetHandler(w http.ResponseWriter, r *http.Request) (inte
3838
return metrics, nil
3939
}
4040

41+
// submitPostHandler must be called first
42+
// It's safe to call it more than once
4143
func (s *Server) submitPostHandler(w http.ResponseWriter, r *http.Request) (interface{}, error) {
4244
s.executor.Lock()
4345
defer s.executor.Unlock()
4446
state := s.executor.GetRunnerState()
45-
if state != executor.WaitSubmit {
47+
if state == executor.WaitRun {
48+
log.Warning(r.Context(), "Job already submitted, submitting again", "current_state", state)
49+
} else if state != executor.WaitSubmit {
4650
log.Warning(r.Context(), "Executor doesn't wait submit", "current_state", state)
4751
return nil, &api.Error{Status: http.StatusConflict}
4852
}
@@ -52,20 +56,19 @@ func (s *Server) submitPostHandler(w http.ResponseWriter, r *http.Request) (inte
5256
log.Error(r.Context(), "Failed to decode submit body", "err", err)
5357
return nil, err
5458
}
55-
// todo go-playground/validator
5659

5760
s.executor.SetJob(body)
58-
s.jobBarrierCh <- nil // notify server that job submitted
61+
s.executor.SetRunnerState(executor.WaitRun)
5962

6063
return nil, nil
6164
}
6265

63-
// uploadArchivePostHandler may be called 0 or more times, and must be called after submitPostHandler
64-
// and before uploadCodePostHandler
66+
// If uploadArchivePostHandler is called, it must be called after submitPostHandler and before runPostHandler
67+
// It's safe to call it more than once with the same archive
6568
func (s *Server) uploadArchivePostHandler(w http.ResponseWriter, r *http.Request) (interface{}, error) {
6669
s.executor.Lock()
6770
defer s.executor.Unlock()
68-
if s.executor.GetRunnerState() != executor.WaitCode {
71+
if s.executor.GetRunnerState() != executor.WaitRun {
6972
return nil, &api.Error{Status: http.StatusConflict}
7073
}
7174

@@ -123,10 +126,12 @@ func (s *Server) uploadArchivePostHandler(w http.ResponseWriter, r *http.Request
123126
return nil, nil
124127
}
125128

129+
// If uploadCodePostHandler is called, it must be called after submitPostHandler and before runPostHandler
130+
// It's safe to call it more than once
126131
func (s *Server) uploadCodePostHandler(w http.ResponseWriter, r *http.Request) (interface{}, error) {
127132
s.executor.Lock()
128133
defer s.executor.Unlock()
129-
if s.executor.GetRunnerState() != executor.WaitCode {
134+
if s.executor.GetRunnerState() != executor.WaitRun {
130135
return nil, &api.Error{Status: http.StatusConflict}
131136
}
132137

@@ -139,8 +144,6 @@ func (s *Server) uploadCodePostHandler(w http.ResponseWriter, r *http.Request) (
139144
return nil, fmt.Errorf("copy request body: %w", err)
140145
}
141146

142-
s.executor.SetRunnerState(executor.WaitRun)
143-
144147
return nil, nil
145148
}
146149

@@ -151,6 +154,7 @@ func (s *Server) runPostHandler(w http.ResponseWriter, r *http.Request) (interfa
151154
return nil, &api.Error{Status: http.StatusConflict}
152155
}
153156
s.executor.SetRunnerState(executor.ServeLogs)
157+
s.jobBarrierCh <- nil // notify server that job started
154158
s.executor.Unlock()
155159

156160
var runCtx context.Context

runner/internal/runner/api/server.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ type Server struct {
2121
pullDoneCh chan interface{} // Closed then /api/pull gave everything
2222
wsDoneCh chan interface{} // Closed then /logs_ws gave everything
2323

24-
submitWaitDuration time.Duration
25-
logsWaitDuration time.Duration
24+
startWaitDuration time.Duration
25+
logsWaitDuration time.Duration
2626

2727
executor executor.Executor
2828
cancelRun context.CancelFunc
@@ -51,8 +51,8 @@ func NewServer(ctx context.Context, address string, version string, ex executor.
5151
pullDoneCh: make(chan interface{}),
5252
wsDoneCh: make(chan interface{}),
5353

54-
submitWaitDuration: 5 * time.Minute,
55-
logsWaitDuration: 5 * time.Minute,
54+
startWaitDuration: 5 * time.Minute,
55+
logsWaitDuration: 5 * time.Minute,
5656

5757
executor: ex,
5858

@@ -82,7 +82,7 @@ func (s *Server) Run(ctx context.Context) error {
8282

8383
select {
8484
case <-s.jobBarrierCh: // job started
85-
case <-time.After(s.submitWaitDuration):
85+
case <-time.After(s.startWaitDuration):
8686
log.Error(ctx, "Job didn't start in time, shutting down")
8787
return errors.New("no job submitted")
8888
case <-ctx.Done():

runner/internal/runner/executor/base.go

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,22 +9,29 @@ import (
99
)
1010

1111
type Executor interface {
12+
// It must be safe to call SetJob more than once
13+
SetJob(job schemas.SubmitBody)
14+
// It must be safe to call WriteFileArchive more than once with the same archive
15+
WriteFileArchive(id string, src io.Reader) error
16+
// It must be safe to call WriteRepoBlob more than once
17+
WriteRepoBlob(src io.Reader) error
18+
Run(ctx context.Context) error
19+
1220
GetHistory(timestamp int64) *schemas.PullResponse
1321
GetJobWsLogsHistory() []schemas.LogEvent
22+
1423
GetRunnerState() string
24+
SetRunnerState(state string)
25+
1526
GetJobInfo(ctx context.Context) (username string, workingDir string, err error)
16-
Run(ctx context.Context) error
17-
SetJob(job schemas.SubmitBody)
1827
SetJobState(ctx context.Context, state schemas.JobState)
1928
SetJobStateWithTerminationReason(
2029
ctx context.Context,
2130
state schemas.JobState,
2231
terminationReason types.TerminationReason,
2332
terminationMessage string,
2433
)
25-
SetRunnerState(state string)
26-
WriteFileArchive(id string, src io.Reader) error
27-
WriteRepoBlob(src io.Reader) error
34+
2835
Lock()
2936
RLock()
3037
RUnlock()

runner/internal/runner/executor/executor.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,6 @@ func (ex *RunExecutor) SetJob(body schemas.SubmitBody) {
295295
ex.secrets = body.Secrets
296296
ex.repoCredentials = body.RepoCredentials
297297
ex.jobLogs.SetQuota(body.LogQuotaHour)
298-
ex.state = WaitCode
299298
}
300299

301300
func (ex *RunExecutor) SetJobState(ctx context.Context, state schemas.JobState) {

runner/internal/runner/executor/repo.go

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -106,9 +106,8 @@ func (ex *RunExecutor) setupRepo(ctx context.Context) error {
106106
return fmt.Errorf("prepare git repo: %w", err)
107107
}
108108
case "local", "virtual":
109-
log.Trace(ctx, "Extracting tar archive")
110-
if err := ex.prepareArchive(ctx); err != nil {
111-
return fmt.Errorf("prepare archive: %w", err)
109+
if err := ex.extractCodeArchive(ctx); err != nil {
110+
return fmt.Errorf("extract code archive: %w", err)
112111
}
113112
default:
114113
return fmt.Errorf("unknown RepoType: %s", ex.getRepoData().RepoType)
@@ -164,26 +163,32 @@ func (ex *RunExecutor) prepareGit(ctx context.Context) error {
164163
return fmt.Errorf("set repo config: %w", err)
165164
}
166165

166+
if ex.repoBlobPath == "" {
167+
log.Trace(ctx, "No diff to apply")
168+
return nil
169+
}
167170
log.Trace(ctx, "Applying diff")
168171
repoDiff, err := os.ReadFile(ex.repoBlobPath)
169172
if err != nil {
170173
return fmt.Errorf("read repo diff: %w", err)
171174
}
172-
if len(repoDiff) > 0 {
173-
if err := repo.ApplyDiff(ctx, ex.repoDir, string(repoDiff)); err != nil {
174-
return fmt.Errorf("apply diff: %w", err)
175-
}
175+
if err := repo.ApplyDiff(ctx, ex.repoDir, string(repoDiff)); err != nil {
176+
return fmt.Errorf("apply diff: %w", err)
176177
}
177178
return nil
178179
}
179180

180-
func (ex *RunExecutor) prepareArchive(ctx context.Context) error {
181+
func (ex *RunExecutor) extractCodeArchive(ctx context.Context) error {
182+
if ex.repoBlobPath == "" {
183+
log.Trace(ctx, "No code archive to extract")
184+
return nil
185+
}
186+
log.Trace(ctx, "Extracting code archive", "src", ex.repoBlobPath, "dst", ex.repoDir)
181187
file, err := os.Open(ex.repoBlobPath)
182188
if err != nil {
183189
return fmt.Errorf("open code archive: %w", err)
184190
}
185191
defer func() { _ = file.Close() }()
186-
log.Trace(ctx, "Extracting code archive", "src", ex.repoBlobPath, "dst", ex.repoDir)
187192
if err := extract.Tar(ctx, file, ex.repoDir, nil); err != nil {
188193
return fmt.Errorf("extract tar archive: %w", err)
189194
}

runner/internal/runner/executor/states.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package executor
22

33
const (
44
WaitSubmit = "wait_submit"
5-
WaitCode = "wait_code"
65
WaitRun = "wait_run"
76
ServeLogs = "serve_logs"
87
WaitLogsFinished = "wait_logs_finished"

src/dstack/_internal/core/models/repos/base.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,10 @@ class Repo(ABC):
2222
repo_dir: Optional[str]
2323
run_repo_data: "repos.AnyRunRepoData"
2424

25+
@abstractmethod
26+
def has_code_to_write(self) -> bool:
27+
pass
28+
2529
@abstractmethod
2630
def write_code_file(self, fp: BinaryIO) -> str:
2731
pass

src/dstack/_internal/core/models/repos/local.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,10 @@ def __init__(
7373
self.repo_id = repo_id
7474
self.run_repo_data = repo_data
7575

76+
def has_code_to_write(self) -> bool:
77+
# LocalRepo is deprecated, no need for real implementation
78+
return True
79+
7680
def write_code_file(self, fp: BinaryIO) -> str:
7781
repo_path = Path(self.run_repo_data.repo_dir)
7882
with tarfile.TarFile(mode="w", fileobj=fp) as t:

src/dstack/_internal/core/models/repos/remote.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,13 @@ def __init__(
180180
)
181181
self.repo_id = repo_id
182182

183+
def has_code_to_write(self) -> bool:
184+
# repo_diff is:
185+
# * None for RemoteRepo.from_url()
186+
# * an empty string for RemoteRepo.from_dir() if there are no changes ("clean" state)
187+
# * a non-empty string for RemoteRepo.from_dir() if there are changes ("dirty" state)
188+
return bool(self.run_repo_data.repo_diff)
189+
183190
def write_code_file(self, fp: BinaryIO) -> str:
184191
if self.run_repo_data.repo_diff is not None:
185192
fp.write(self.run_repo_data.repo_diff.encode())

src/dstack/_internal/core/models/repos/virtual.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,9 @@ def add_file(self, path: str, content: bytes):
7373

7474
self.files[resolve_relative_path(path).as_posix()] = content
7575

76+
def has_code_to_write(self) -> bool:
77+
return len(self.files) > 0
78+
7679
def write_code_file(self, fp: BinaryIO) -> str:
7780
with tarfile.TarFile(mode="w", fileobj=fp) as t:
7881
for path, content in sorted(self.files.items()):

0 commit comments

Comments
 (0)