Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions internal/types/messages.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,14 @@ type SidecarMount struct {
ReadWrite bool `json:"read_write"` // If false (default), the mount is read-only.
}

// AttachmentDownload contains information needed to download a single task attachment.
type AttachmentDownload struct {
AttachmentID string `json:"attachment_id"`
Filename string `json:"filename"`
MimeType string `json:"mime_type"`
DownloadURL string `json:"download_url"`
}

// TaskAssignmentMessage is sent from server to worker when a task is available
type TaskAssignmentMessage struct {
TaskID string `json:"task_id"`
Expand All @@ -40,6 +48,9 @@ type TaskAssignmentMessage struct {
EnvVars map[string]string `json:"env_vars,omitempty"`
// AdditionalSidecars is a list of extra sidecar images to mount into the task container.
AdditionalSidecars []SidecarMount `json:"additional_sidecars,omitempty"`
// Attachments contains presigned download URLs for task attachments.
// The worker should download these files before starting the agent.
Attachments []AttachmentDownload `json:"attachments,omitempty"`
}

// TaskClaimedMessage is sent from worker to server after successfully claiming a task
Expand Down
101 changes: 99 additions & 2 deletions internal/worker/worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@ import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"os"
"path/filepath"
"sync"
"time"

Expand Down Expand Up @@ -346,7 +350,8 @@ func (w *Worker) handleTaskAssignment(assignment *types.TaskAssignmentMessage) {

// prepareTaskParams converts a TaskAssignmentMessage into backend-agnostic TaskParams,
// resolving common environment variables, default images, and base CLI arguments.
func (w *Worker) prepareTaskParams(assignment *types.TaskAssignmentMessage) *TaskParams {
// attachmentsDir is the path to locally downloaded attachments (empty if none).
func (w *Worker) prepareTaskParams(assignment *types.TaskAssignmentMessage, attachmentsDir string) *TaskParams {
task := assignment.Task

// Resolve Docker image.
Expand Down Expand Up @@ -381,6 +386,9 @@ func (w *Worker) prepareTaskParams(assignment *types.TaskAssignmentMessage) *Tas
if w.config.SessionSharingServerURL != "" {
baseArgs = append(baseArgs, "--session-sharing-server-url", w.config.SessionSharingServerURL)
}
if attachmentsDir != "" {
baseArgs = append(baseArgs, "--attachments-dir", attachmentsDir)
}

// Build a unified sidecar list:
// entrypoint.sh lives) comes first, followed by any additional sidecars.
Expand Down Expand Up @@ -442,7 +450,28 @@ func (w *Worker) executeTask(ctx context.Context, assignment *types.TaskAssignme
taskID := assignment.TaskID
log.Infof(ctx, "Starting task execution: taskID=%s, title=%s", taskID, assignment.Task.Title)

params := w.prepareTaskParams(assignment)
// Download attachments if present, before preparing task params.
var attachmentsDir string
if len(assignment.Attachments) > 0 {
var err error
attachmentsDir, err = w.downloadAttachments(ctx, taskID, assignment.Attachments)
if err != nil {
log.Errorf(ctx, "Failed to download attachments for task %s: %v", taskID, err)
if statusErr := w.sendTaskFailed(taskID, fmt.Sprintf("Failed to download attachments: %v", err)); statusErr != nil {
log.Errorf(ctx, "Failed to send task failed message: %v", statusErr)
}
return
}
if attachmentsDir != "" {
defer func() {
if err := os.RemoveAll(attachmentsDir); err != nil {
log.Warnf(ctx, "Failed to clean up attachments directory %s: %v", attachmentsDir, err)
}
}()
}
}

params := w.prepareTaskParams(assignment, attachmentsDir)
if err := w.backend.ExecuteTask(ctx, params); err != nil {
log.Errorf(ctx, "Task execution failed: taskID=%s, error=%v", taskID, err)
if statusErr := w.sendTaskFailed(taskID, fmt.Sprintf("Failed to execute task: %v", err)); statusErr != nil {
Expand All @@ -454,6 +483,74 @@ func (w *Worker) executeTask(ctx context.Context, assignment *types.TaskAssignme
log.Infof(ctx, "Task execution completed successfully: taskID=%s", taskID)
}

// downloadAttachments downloads task attachments to a temporary directory using presigned URLs.
// Returns the path to the attachments directory, or empty string if no attachments were downloaded.
// The file naming convention matches the server's ResolveAttachmentReferencesFromTaskDefinition:
// each file is saved as "{attachmentID}_{filename}" inside the directory.
func (w *Worker) downloadAttachments(ctx context.Context, taskID string, attachments []types.AttachmentDownload) (string, error) {
if len(attachments) == 0 {
return "", nil
}

attachmentsDir, err := os.MkdirTemp("", fmt.Sprintf("oz-attachments-%s-*", taskID))
if err != nil {
return "", fmt.Errorf("failed to create attachments directory: %w", err)
}

log.Infof(ctx, "Downloading %d attachments for task %s to %s", len(attachments), taskID, attachmentsDir)

httpClient := &http.Client{Timeout: 2 * time.Minute}

for _, att := range attachments {
filename := filepath.Base(att.Filename)
if filename == "" || filename == "." {
filename = fmt.Sprintf("attachment_%s", att.AttachmentID)
}

// Match the server's naming convention: {uuid}_{filename}
localPath := filepath.Join(attachmentsDir, fmt.Sprintf("%s_%s", att.AttachmentID, filename))

if err := downloadFile(ctx, httpClient, att.DownloadURL, localPath); err != nil {
log.Warnf(ctx, "Failed to download attachment %s (%s) for task %s: %v", att.AttachmentID, att.Filename, taskID, err)
continue
}

log.Debugf(ctx, "Downloaded attachment %s -> %s", att.Filename, localPath)
}

return attachmentsDir, nil
}

// downloadFile downloads a file from a URL to a local path.
func downloadFile(ctx context.Context, client *http.Client, downloadURL, destPath string) error {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, downloadURL, nil)
if err != nil {
return fmt.Errorf("failed to create request: %w", err)
}

resp, err := client.Do(req)
if err != nil {
return fmt.Errorf("failed to download: %w", err)
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
return fmt.Errorf("unexpected status code %d", resp.StatusCode)
}

out, err := os.Create(destPath)
if err != nil {
return fmt.Errorf("failed to create file: %w", err)
}
defer out.Close()

if _, err := io.Copy(out, resp.Body); err != nil {
return fmt.Errorf("failed to write file: %w", err)
}

return nil
}

func (w *Worker) sendTaskClaimed(taskID string) error {
claimed := types.TaskClaimedMessage{
TaskID: taskID,
Expand Down
6 changes: 3 additions & 3 deletions internal/worker/worker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ func TestPrepareTaskParamsSidecarImageOverride(t *testing.T) {
TaskID: "task-1",
Task: &types.Task{ID: "task-1"},
SidecarImage: "docker.io/warpdotdev/warp-agent:latest",
})
}, "")
if len(params.Sidecars) == 0 {
t.Fatal("expected at least one sidecar")
}
Expand All @@ -125,7 +125,7 @@ func TestPrepareTaskParamsSidecarImageOverride(t *testing.T) {
TaskID: "task-1",
Task: &types.Task{ID: "task-1"},
SidecarImage: "docker.io/warpdotdev/warp-agent:latest",
})
}, "")
if len(params.Sidecars) == 0 {
t.Fatal("expected at least one sidecar")
}
Expand All @@ -140,7 +140,7 @@ func TestPrepareTaskParamsSidecarImageOverride(t *testing.T) {
TaskID: "task-1",
Task: &types.Task{ID: "task-1"},
SidecarImage: "",
})
}, "")
if len(params.Sidecars) != 0 {
t.Errorf("expected no sidecars when server sidecar image is empty, got %d", len(params.Sidecars))
}
Expand Down
Loading