Skip to content

Commit dc585db

Browse files
committed
Return new fields in /api/run response
/api/pull is too late: we need these fields as soon as the job state is switched to RUNNING
1 parent 0ed127d commit dc585db

11 files changed

Lines changed: 146 additions & 214 deletions

File tree

runner/internal/executor/base.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ type Executor interface {
1212
GetHistory(timestamp int64) *schemas.PullResponse
1313
GetJobWsLogsHistory() []schemas.LogEvent
1414
GetRunnerState() string
15+
GetJobInfo(ctx context.Context) (username string, workingDir string, err error)
1516
Run(ctx context.Context) error
1617
SetJob(job schemas.SubmitBody)
1718
SetJobState(ctx context.Context, state types.JobState)

runner/internal/executor/executor.go

Lines changed: 85 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import (
2121
"github.com/creack/pty"
2222
"github.com/dstackai/ansistrip"
2323
"github.com/prometheus/procfs"
24+
"github.com/sirupsen/logrus"
2425
"golang.org/x/sys/unix"
2526

2627
"github.com/dstackai/dstack/runner/consts"
@@ -61,6 +62,10 @@ type RunExecutor struct {
6162
fileArchiveDir string
6263
repoBlobDir string
6364

65+
runnerLogFile *os.File
66+
runnerLogStripper *ansistrip.Writer
67+
runnerLogger *logrus.Entry
68+
6469
run schemas.Run
6570
jobSpec schemas.JobSpec
6671
jobSubmission schemas.JobSubmission
@@ -136,14 +141,26 @@ func NewRunExecutor(tempDir string, dstackDir string, currentUser linuxuser.User
136141
}, nil
137142
}
138143

144+
// GetJobInfo must be called after SetJob
145+
func (ex *RunExecutor) GetJobInfo(ctx context.Context) (string, string, error) {
146+
// preRun() sets ex.jobUser and ex.jobWorkingDir
147+
if err := ex.preRun(ctx); err != nil {
148+
return "", "", err
149+
}
150+
return ex.jobUser.Username, ex.jobWorkingDir, nil
151+
}
152+
139153
// Run must be called after SetJob and WriteRepoBlob
140154
func (ex *RunExecutor) Run(ctx context.Context) (err error) {
141-
runnerLogFile, err := log.CreateAppendFile(filepath.Join(ex.tempDir, consts.RunnerLogFileName))
142-
if err != nil {
143-
ex.SetJobState(ctx, types.JobStateFailed)
144-
return fmt.Errorf("create runner log file: %w", err)
155+
// If jobStateHistory is not empty, either Run() has already been called or
156+
// preRun() has already been called via GetJobInfo() and failed
157+
if len(ex.jobStateHistory) > 0 {
158+
return errors.New("already running or finished")
159+
}
160+
if err := ex.preRun(ctx); err != nil {
161+
return err
145162
}
146-
defer func() { _ = runnerLogFile.Close() }()
163+
defer ex.postRun(ctx)
147164

148165
jobLogFile, err := log.CreateAppendFile(filepath.Join(ex.tempDir, consts.RunnerJobLogFileName))
149166
if err != nil {
@@ -153,7 +170,7 @@ func (ex *RunExecutor) Run(ctx context.Context) (err error) {
153170
defer func() { _ = jobLogFile.Close() }()
154171

155172
defer func() {
156-
// recover goes after runnerLogFile.Close() to keep the log
173+
// this recover is deferred after postRun() and therefore runs before it (deferred calls run in LIFO order), so the panic is logged before postRun() closes runnerLogFile
157174
if r := recover(); r != nil {
158175
log.Error(ctx, "Executor PANIC", "err", r)
159176
ex.SetJobState(ctx, types.JobStateFailed)
@@ -171,21 +188,8 @@ func (ex *RunExecutor) Run(ctx context.Context) (err error) {
171188
}
172189
}()
173190

174-
stripper := ansistrip.NewWriter(ex.runnerLogs, AnsiStripFlushInterval, AnsiStripMaxDelay, MaxBufferSize)
175-
defer func() { _ = stripper.Close() }()
176-
logger := io.MultiWriter(runnerLogFile, os.Stdout, stripper)
177-
ctx = log.WithLogger(ctx, log.NewEntry(logger, int(log.DefaultEntry.Logger.Level))) // todo loglevel
178-
log.Info(ctx, "Run job", "log_level", log.GetLogger(ctx).Logger.Level.String())
179-
180-
if err := ex.setJobUser(ctx); err != nil {
181-
ex.SetJobStateWithTerminationReason(
182-
ctx,
183-
types.JobStateFailed,
184-
types.TerminationReasonExecutorError,
185-
fmt.Sprintf("Failed to set job user (%s)", err),
186-
)
187-
return fmt.Errorf("set job user: %w", err)
188-
}
191+
ctx = log.WithLogger(ctx, ex.runnerLogger)
192+
log.Info(ctx, "Run job")
189193

190194
// setJobUser sets User.HomeDir to "/" if the original home dir is not set or not accessible,
191195
// in that case we skip home dir provisioning
@@ -204,16 +208,6 @@ func (ex *RunExecutor) Run(ctx context.Context) (err error) {
204208
}
205209
}
206210

207-
if err := ex.setJobWorkingDir(ctx); err != nil {
208-
ex.SetJobStateWithTerminationReason(
209-
ctx,
210-
types.JobStateFailed,
211-
types.TerminationReasonExecutorError,
212-
fmt.Sprintf("Failed to set job working dir (%s)", err),
213-
)
214-
return fmt.Errorf("set job working dir: %w", err)
215-
}
216-
217211
if err := ex.setupRepo(ctx); err != nil {
218212
ex.SetJobStateWithTerminationReason(
219213
ctx,
@@ -336,6 +330,66 @@ func (ex *RunExecutor) SetRunnerState(state string) {
336330
ex.state = state
337331
}
338332

333+
// preRun performs actions that were once part of Run() but were moved to a separate function
334+
// to implement GetJobInfo()
335+
// preRun must not execute long-running operations, as GetJobInfo() is called synchronously
336+
// in the /api/run method
337+
func (ex *RunExecutor) preRun(ctx context.Context) error {
338+
// Already called once
339+
if ex.runnerLogFile != nil {
340+
return nil
341+
}
342+
343+
// logging is required for the subsequent setJob{User,WorkingDir} calls
344+
runnerLogFile, err := log.CreateAppendFile(filepath.Join(ex.tempDir, consts.RunnerLogFileName))
345+
if err != nil {
346+
ex.SetJobState(ctx, types.JobStateFailed)
347+
return fmt.Errorf("create runner log file: %w", err)
348+
}
349+
ex.runnerLogFile = runnerLogFile
350+
ex.runnerLogStripper = ansistrip.NewWriter(ex.runnerLogs, AnsiStripFlushInterval, AnsiStripMaxDelay, MaxBufferSize)
351+
runnerLogWriter := io.MultiWriter(ex.runnerLogFile, os.Stdout, ex.runnerLogStripper)
352+
runnerLogLevel := log.DefaultEntry.Logger.Level
353+
ex.runnerLogger = log.NewEntry(runnerLogWriter, int(runnerLogLevel))
354+
ctx = log.WithLogger(ctx, ex.runnerLogger)
355+
log.Info(ctx, "Logging configured", "log_level", runnerLogLevel.String())
356+
357+
// jobUser and jobWorkingDir are required for GetJobInfo()
358+
if err := ex.setJobUser(ctx); err != nil {
359+
ex.SetJobStateWithTerminationReason(
360+
ctx,
361+
types.JobStateFailed,
362+
types.TerminationReasonExecutorError,
363+
fmt.Sprintf("Failed to set job user (%s)", err),
364+
)
365+
return fmt.Errorf("set job user: %w", err)
366+
}
367+
if err := ex.setJobWorkingDir(ctx); err != nil {
368+
ex.SetJobStateWithTerminationReason(
369+
ctx,
370+
types.JobStateFailed,
371+
types.TerminationReasonExecutorError,
372+
fmt.Sprintf("Failed to set job working dir (%s)", err),
373+
)
374+
return fmt.Errorf("set job working dir: %w", err)
375+
}
376+
377+
return nil
378+
}
379+
380+
func (ex *RunExecutor) postRun(ctx context.Context) {
381+
if ex.runnerLogFile != nil {
382+
if err := ex.runnerLogFile.Close(); err != nil {
383+
log.Error(ctx, "Failed to close runnerLogFile", "err", err)
384+
}
385+
}
386+
if ex.runnerLogStripper != nil {
387+
if err := ex.runnerLogStripper.Close(); err != nil {
388+
log.Error(ctx, "Failed to close runnerLogStripper", "err", err)
389+
}
390+
}
391+
}
392+
339393
// setJobWorkingDir must be called from preRun after setJobUser
340394
func (ex *RunExecutor) setJobWorkingDir(ctx context.Context) error {
341395
var err error

runner/internal/executor/executor_test.go

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -353,30 +353,6 @@ func TestExecutor_LogsAnsiCodeHandling(t *testing.T) {
353353
}
354354
}
355355

356-
func TestGetHistory_IncludesWorkingDirAndUsername(t *testing.T) {
357-
ex := makeTestExecutor(t)
358-
resp := ex.GetHistory(0)
359-
assert.NotEmpty(t, resp.WorkingDir)
360-
assert.True(t, path.IsAbs(resp.WorkingDir))
361-
assert.NotEmpty(t, resp.Username)
362-
}
363-
364-
func TestGetHistory_BeforeRun(t *testing.T) {
365-
baseDir, err := filepath.EvalSymlinks(t.TempDir())
366-
require.NoError(t, err)
367-
tempDir := filepath.Join(baseDir, "temp")
368-
require.NoError(t, os.Mkdir(tempDir, 0o700))
369-
dstackDir := filepath.Join(baseDir, "dstack")
370-
require.NoError(t, os.Mkdir(dstackDir, 0o755))
371-
currentUser, err := linuxuser.FromCurrentProcess()
372-
require.NoError(t, err)
373-
ex, err := NewRunExecutor(tempDir, dstackDir, *currentUser, new(sshdMock))
374-
require.NoError(t, err)
375-
resp := ex.GetHistory(0)
376-
assert.Empty(t, resp.WorkingDir)
377-
assert.Empty(t, resp.Username)
378-
}
379-
380356
type sshdMock struct{}
381357

382358
func (d *sshdMock) Port() int {

runner/internal/executor/query.go

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,19 +9,14 @@ func (ex *RunExecutor) GetJobWsLogsHistory() []schemas.LogEvent {
99
}
1010

1111
func (ex *RunExecutor) GetHistory(timestamp int64) *schemas.PullResponse {
12-
resp := &schemas.PullResponse{
12+
return &schemas.PullResponse{
1313
JobStates: eventsAfter(ex.jobStateHistory, timestamp),
1414
JobLogs: eventsAfter(ex.jobLogs.history, timestamp),
1515
RunnerLogs: eventsAfter(ex.runnerLogs.history, timestamp),
1616
LastUpdated: ex.timestamp.GetLatest(),
1717
NoConnectionsSecs: ex.connectionTracker.GetNoConnectionsSecs(),
1818
HasMore: ex.state != WaitLogsFinished,
19-
WorkingDir: ex.jobWorkingDir,
2019
}
21-
if ex.jobUser != nil {
22-
resp.Username = ex.jobUser.Username
23-
}
24-
return resp
2520
}
2621

2722
func (ex *RunExecutor) GetRunnerState() string {

runner/internal/runner/api/http.go

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -146,18 +146,27 @@ func (s *Server) uploadCodePostHandler(w http.ResponseWriter, r *http.Request) (
146146

147147
func (s *Server) runPostHandler(w http.ResponseWriter, r *http.Request) (interface{}, error) {
148148
s.executor.Lock()
149-
defer s.executor.Unlock()
150149
if s.executor.GetRunnerState() != executor.WaitRun {
150+
s.executor.Unlock()
151151
return nil, &api.Error{Status: http.StatusConflict}
152152
}
153+
s.executor.SetRunnerState(executor.ServeLogs)
154+
s.executor.Unlock()
153155

154156
var runCtx context.Context
155157
runCtx, s.cancelRun = context.WithCancel(context.Background())
158+
username, workingDir, err := s.executor.GetJobInfo(runCtx)
156159
go func() {
157160
_ = s.executor.Run(runCtx) // INFO: all errors are handled inside the Run()
158161
s.jobBarrierCh <- nil // notify server that job finished
159162
}()
160-
s.executor.SetRunnerState(executor.ServeLogs)
163+
164+
if err == nil {
165+
return &schemas.JobInfoResponse{
166+
Username: username,
167+
WorkingDir: workingDir,
168+
}, nil
169+
}
161170

162171
return nil, nil
163172
}

runner/internal/schemas/schemas.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,11 @@ type PullResponse struct {
3535
LastUpdated int64 `json:"last_updated"`
3636
NoConnectionsSecs int64 `json:"no_connections_secs"`
3737
HasMore bool `json:"has_more"`
38-
WorkingDir string `json:"working_dir,omitempty"`
39-
Username string `json:"username,omitempty"`
38+
}
39+
40+
type JobInfoResponse struct {
41+
WorkingDir string `json:"working_dir"`
42+
Username string `json:"username"`
4043
}
4144

4245
type Run struct {

src/dstack/_internal/server/background/scheduled_tasks/running_jobs.py

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -774,18 +774,6 @@ def _process_running(
774774
timestamp = job_model.runner_timestamp
775775
resp = runner_client.pull(timestamp) # raises error if runner is down, causes retry
776776
job_model.runner_timestamp = resp.last_updated
777-
if resp.working_dir or resp.username:
778-
jrd = get_job_runtime_data(job_model)
779-
if jrd is not None:
780-
updated = False
781-
if resp.working_dir and jrd.working_dir is None:
782-
jrd.working_dir = resp.working_dir
783-
updated = True
784-
if resp.username and jrd.username is None:
785-
jrd.username = resp.username
786-
updated = True
787-
if updated:
788-
job_model.job_runtime_data = jrd.json()
789777
# may raise LogStorageError, causing a retry
790778
logs_services.write_logs(
791779
project=run_model.project,
@@ -1128,7 +1116,13 @@ def _submit_job_to_runner(
11281116
logger.debug("%s: uploading code", fmt(job_model))
11291117
runner_client.upload_code(code)
11301118
logger.debug("%s: starting job", fmt(job_model))
1131-
runner_client.run_job()
1119+
job_info = runner_client.run_job()
1120+
if job_info is not None:
1121+
jrd = get_job_runtime_data(job_model)
1122+
if jrd is not None:
1123+
jrd.working_dir = job_info.working_dir
1124+
jrd.username = job_info.username
1125+
job_model.job_runtime_data = jrd.json()
11321126

11331127
switch_job_status(session, job_model, JobStatus.RUNNING)
11341128
# do not log here, because the runner will send a new status

src/dstack/_internal/server/schemas/runner.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,11 @@ class PullResponse(CoreModel):
4444
runner_logs: List[LogEvent]
4545
last_updated: int
4646
no_connections_secs: Optional[int] = None # Optional for compatibility with old runners
47-
working_dir: Optional[str] = None # Optional for compatibility with old runners
48-
username: Optional[str] = None # Optional for compatibility with old runners
47+
48+
49+
class JobInfoResponse(CoreModel):
50+
working_dir: str
51+
username: str
4952

5053

5154
class SubmitBody(CoreModel):

src/dstack/_internal/server/services/runner/client.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
GPUDevice,
2525
HealthcheckResponse,
2626
InstanceHealthResponse,
27+
JobInfoResponse,
2728
LegacyPullResponse,
2829
LegacyStopBody,
2930
LegacySubmitBody,
@@ -124,9 +125,13 @@ def upload_code(self, file: Union[BinaryIO, bytes]):
124125
)
125126
resp.raise_for_status()
126127

127-
def run_job(self):
128+
def run_job(self) -> Optional[JobInfoResponse]:
128129
resp = requests.post(self._url("/api/run"), timeout=REQUEST_TIMEOUT)
129130
resp.raise_for_status()
131+
if not _is_json_response(resp):
132+
# Old runner or runner failed to get job info
133+
return None
134+
return JobInfoResponse.__response__.parse_obj(resp.json())
130135

131136
def pull(self, timestamp: int) -> PullResponse:
132137
resp = requests.get(
@@ -617,6 +622,13 @@ def _memory_to_bytes(memory: Optional[Memory]) -> int:
617622
return int(memory * 1024**3)
618623

619624

625+
def _is_json_response(response: requests.Response) -> bool:
626+
content_type = response.headers.get("content-type")
627+
if not content_type:
628+
return False
629+
return content_type.split(";", maxsplit=1)[0].strip() == "application/json"
630+
631+
620632
_TaskID = Union[uuid.UUID, str]
621633

622634
_Version = tuple[int, int, int]

src/dstack/api/_public/runs.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -352,9 +352,9 @@ def attach(
352352
if runtime_data is not None and runtime_data.ports is not None:
353353
container_ssh_port = runtime_data.ports.get(container_ssh_port, container_ssh_port)
354354

355-
# TODO: get login name from runner in case it's not specified in the run configuration
356-
# (i.e. the default image user is used, and it is not root)
357-
if job.job_spec.user is not None and job.job_spec.user.username is not None:
355+
if runtime_data is not None and runtime_data.username is not None:
356+
container_user = runtime_data.username
357+
elif job.job_spec.user is not None and job.job_spec.user.username is not None:
358358
container_user = job.job_spec.user.username
359359
else:
360360
container_user = "root"

0 commit comments

Comments
 (0)