|
9 | 9 | "errors" |
10 | 10 | "fmt" |
11 | 11 | "io" |
| 12 | + "net/http" |
12 | 13 | "os" |
13 | 14 | "os/exec" |
14 | 15 | "os/user" |
@@ -37,6 +38,7 @@ import ( |
37 | 38 | "github.com/dstackai/dstack/runner/internal/common/types" |
38 | 39 | "github.com/dstackai/dstack/runner/internal/shim/backends" |
39 | 40 | "github.com/dstackai/dstack/runner/internal/shim/host" |
| 41 | + "github.com/dstackai/dstack/runner/internal/shim/netmeter" |
40 | 42 | ) |
41 | 43 |
|
42 | 44 | // TODO: Allow for configuration via cli arguments or environment variables. |
@@ -380,7 +382,8 @@ func (d *DockerRunner) Run(ctx context.Context, taskID string) error { |
380 | 382 | if err := d.tasks.Update(task); err != nil { |
381 | 383 | return fmt.Errorf("%w: failed to update task %s: %w", ErrInternal, task.ID, err) |
382 | 384 | } |
383 | | - err = d.waitContainer(ctx, &task) |
| 385 | + |
| 386 | + err = d.waitContainerWithQuota(ctx, &task, cfg) |
384 | 387 | } |
385 | 388 | if err != nil { |
386 | 389 | log.Error(ctx, "failed to run container", "err", err) |
@@ -910,6 +913,49 @@ func (d *DockerRunner) waitContainer(ctx context.Context, task *Task) error { |
910 | 913 | return nil |
911 | 914 | } |
912 | 915 |
|
| 916 | +// waitContainerWithQuota waits for the container to finish, optionally enforcing |
| 917 | +// a data transfer quota. If the quota is exceeded, it notifies the runner |
| 918 | +// (so the server reads the termination reason via /api/pull) and stops the container. |
| 919 | +func (d *DockerRunner) waitContainerWithQuota(ctx context.Context, task *Task, cfg TaskConfig) error { |
| 920 | + if cfg.DataTransferQuota <= 0 { |
| 921 | + return d.waitContainer(ctx, task) |
| 922 | + } |
| 923 | + |
| 924 | + nm := netmeter.New(task.ID, cfg.DataTransferQuota) |
| 925 | + if err := nm.Start(ctx); err != nil { |
| 926 | + errMessage := fmt.Sprintf("data transfer quota configured but metering unavailable: %s", err) |
| 927 | + log.Error(ctx, errMessage) |
| 928 | + task.SetStatusTerminated(string(types.TerminationReasonExecutorError), errMessage) |
| 929 | + return fmt.Errorf("data transfer meter: %w", err) |
| 930 | + } |
| 931 | + defer nm.Stop() |
| 932 | + |
| 933 | + waitDone := make(chan error, 1) |
| 934 | + go func() { waitDone <- d.waitContainer(ctx, task) }() |
| 935 | + |
| 936 | + select { |
| 937 | + case err := <-waitDone: |
| 938 | + return err |
| 939 | + case <-nm.Exceeded(): |
| 940 | + log.Error(ctx, "Data transfer quota exceeded", "task", task.ID, "quota", cfg.DataTransferQuota) |
| 941 | + terminateMsg := fmt.Sprintf("Outbound data transfer exceeded quota of %d bytes", cfg.DataTransferQuota) |
| 942 | + if err := terminateRunner(ctx, d.dockerParams.RunnerHTTPPort(), |
| 943 | + types.TerminationReasonDataTransferQuotaExceeded, terminateMsg); err != nil { |
| 944 | + log.Error(ctx, "failed to notify runner of termination", "err", err) |
| 945 | + } |
| 946 | + stopTimeout := 10 |
| 947 | + stopOpts := container.StopOptions{Timeout: &stopTimeout} |
| 948 | + if err := d.client.ContainerStop(ctx, task.containerID, stopOpts); err != nil { |
| 949 | + log.Error(ctx, "failed to stop container after quota exceeded", "err", err) |
| 950 | + } |
| 951 | + <-waitDone |
| 952 | + // The runner already set the job state with the termination reason. |
| 953 | + // The server will read it via /api/pull. |
| 954 | + task.SetStatusTerminated(string(types.TerminationReasonDoneByRunner), "") |
| 955 | + return nil |
| 956 | + } |
| 957 | +} |
| 958 | + |
913 | 959 | func encodeRegistryAuth(username string, password string) (string, error) { |
914 | 960 | if username == "" && password == "" { |
915 | 961 | return "", nil |
@@ -1180,6 +1226,31 @@ func getContainerLastLogs(ctx context.Context, client docker.APIClient, containe |
1180 | 1226 | return lines, nil |
1181 | 1227 | } |
1182 | 1228 |
|
| 1229 | +// terminateRunner calls the runner's /api/terminate endpoint to set the job termination state. |
| 1230 | +// This allows the server to read the termination reason via /api/pull before the container dies. |
| 1231 | +func terminateRunner(ctx context.Context, runnerPort int, reason types.TerminationReason, message string) error { |
| 1232 | + url := fmt.Sprintf("http://localhost:%d/api/terminate", runnerPort) |
| 1233 | + body := fmt.Sprintf(`{"reason":%q,"message":%q}`, reason, message) |
| 1234 | + // 5s is generous for a localhost HTTP call; if the runner doesn't respond in time, |
| 1235 | + // we proceed with stopping the container anyway (the server will handle the termination). |
| 1236 | + ctx, cancel := context.WithTimeout(ctx, 5*time.Second) |
| 1237 | + defer cancel() |
| 1238 | + req, err := http.NewRequestWithContext(ctx, "POST", url, strings.NewReader(body)) |
| 1239 | + if err != nil { |
| 1240 | + return fmt.Errorf("create request: %w", err) |
| 1241 | + } |
| 1242 | + req.Header.Set("Content-Type", "application/json") |
| 1243 | + resp, err := http.DefaultClient.Do(req) |
| 1244 | + if err != nil { |
| 1245 | + return fmt.Errorf("request failed: %w", err) |
| 1246 | + } |
| 1247 | + defer resp.Body.Close() |
| 1248 | + if resp.StatusCode != http.StatusOK { |
| 1249 | + return fmt.Errorf("unexpected status: %d", resp.StatusCode) |
| 1250 | + } |
| 1251 | + return nil |
| 1252 | +} |
| 1253 | + |
1183 | 1254 | /* DockerParameters interface implementation for CLIArgs */ |
1184 | 1255 |
|
1185 | 1256 | func (c *CLIArgs) DockerPrivileged() bool { |
@@ -1228,6 +1299,10 @@ func (c *CLIArgs) DockerPorts() []int { |
1228 | 1299 | return []int{c.Runner.HTTPPort, c.Runner.SSHPort} |
1229 | 1300 | } |
1230 | 1301 |
|
| 1302 | +func (c *CLIArgs) RunnerHTTPPort() int { |
| 1303 | + return c.Runner.HTTPPort |
| 1304 | +} |
| 1305 | + |
1231 | 1306 | func (c *CLIArgs) MakeRunnerDir(name string) (string, error) { |
1232 | 1307 | runnerTemp := filepath.Join(c.Shim.HomeDir, "runners", name) |
1233 | 1308 | if err := os.MkdirAll(runnerTemp, 0o755); err != nil { |
|
0 commit comments