diff --git a/internal/types/messages.go b/internal/types/messages.go index 3763e05..728bff2 100644 --- a/internal/types/messages.go +++ b/internal/types/messages.go @@ -29,6 +29,14 @@ type SidecarMount struct { ReadWrite bool `json:"read_write"` // If false (default), the mount is read-only. } +// AttachmentDownload contains information needed to download a single task attachment. +type AttachmentDownload struct { + AttachmentID string `json:"attachment_id"` + Filename string `json:"filename"` + MimeType string `json:"mime_type"` + DownloadURL string `json:"download_url"` +} + // TaskAssignmentMessage is sent from server to worker when a task is available type TaskAssignmentMessage struct { TaskID string `json:"task_id"` @@ -40,6 +48,9 @@ type TaskAssignmentMessage struct { EnvVars map[string]string `json:"env_vars,omitempty"` // AdditionalSidecars is a list of extra sidecar images to mount into the task container. AdditionalSidecars []SidecarMount `json:"additional_sidecars,omitempty"` + // Attachments contains presigned download URLs for task attachments. + // The worker should download these files before starting the agent. + Attachments []AttachmentDownload `json:"attachments,omitempty"` } // TaskClaimedMessage is sent from worker to server after successfully claiming a task diff --git a/internal/worker/worker.go b/internal/worker/worker.go index c8e789e..0c46d19 100644 --- a/internal/worker/worker.go +++ b/internal/worker/worker.go @@ -4,7 +4,11 @@ import ( "context" "encoding/json" "fmt" + "io" + "net/http" "net/url" + "os" + "path/filepath" "sync" "time" @@ -346,7 +350,8 @@ func (w *Worker) handleTaskAssignment(assignment *types.TaskAssignmentMessage) { // prepareTaskParams converts a TaskAssignmentMessage into backend-agnostic TaskParams, // resolving common environment variables, default images, and base CLI arguments. -func (w *Worker) prepareTaskParams(assignment *types.TaskAssignmentMessage) *TaskParams { +// attachmentsDir is the path to locally downloaded attachments (empty if none). +func (w *Worker) prepareTaskParams(assignment *types.TaskAssignmentMessage, attachmentsDir string) *TaskParams { task := assignment.Task // Resolve Docker image. @@ -381,6 +386,9 @@ func (w *Worker) prepareTaskParams(assignment *types.TaskAssignmentMessage) *Tas if w.config.SessionSharingServerURL != "" { baseArgs = append(baseArgs, "--session-sharing-server-url", w.config.SessionSharingServerURL) } + if attachmentsDir != "" { + baseArgs = append(baseArgs, "--attachments-dir", attachmentsDir) + } // Build a unified sidecar list: // entrypoint.sh lives) comes first, followed by any additional sidecars. @@ -442,7 +450,28 @@ func (w *Worker) executeTask(ctx context.Context, assignment *types.TaskAssignme taskID := assignment.TaskID log.Infof(ctx, "Starting task execution: taskID=%s, title=%s", taskID, assignment.Task.Title) - params := w.prepareTaskParams(assignment) + // Download attachments if present, before preparing task params. + var attachmentsDir string + if len(assignment.Attachments) > 0 { + var err error + attachmentsDir, err = w.downloadAttachments(ctx, taskID, assignment.Attachments) + if err != nil { + log.Errorf(ctx, "Failed to download attachments for task %s: %v", taskID, err) + if statusErr := w.sendTaskFailed(taskID, fmt.Sprintf("Failed to download attachments: %v", err)); statusErr != nil { + log.Errorf(ctx, "Failed to send task failed message: %v", statusErr) + } + return + } + if attachmentsDir != "" { + defer func() { + if err := os.RemoveAll(attachmentsDir); err != nil { + log.Warnf(ctx, "Failed to clean up attachments directory %s: %v", attachmentsDir, err) + } + }() + } + } + + params := w.prepareTaskParams(assignment, attachmentsDir) if err := w.backend.ExecuteTask(ctx, params); err != nil { log.Errorf(ctx, "Task execution failed: taskID=%s, error=%v", taskID, err) if statusErr := w.sendTaskFailed(taskID, fmt.Sprintf("Failed to execute task: %v", err)); statusErr != nil { @@ -454,6 +483,74 @@ func (w *Worker) executeTask(ctx context.Context, assignment *types.TaskAssignme log.Infof(ctx, "Task execution completed successfully: taskID=%s", taskID) } +// downloadAttachments downloads task attachments to a temporary directory using presigned URLs. +// Returns the path to the attachments directory, or empty string if no attachments were downloaded. +// The file naming convention matches the server's ResolveAttachmentReferencesFromTaskDefinition: +// each file is saved as "{attachmentID}_{filename}" inside the directory. +func (w *Worker) downloadAttachments(ctx context.Context, taskID string, attachments []types.AttachmentDownload) (string, error) { + if len(attachments) == 0 { + return "", nil + } + + attachmentsDir, err := os.MkdirTemp("", fmt.Sprintf("oz-attachments-%s-*", taskID)) + if err != nil { + return "", fmt.Errorf("failed to create attachments directory: %w", err) + } + + log.Infof(ctx, "Downloading %d attachments for task %s to %s", len(attachments), taskID, attachmentsDir) + + httpClient := &http.Client{Timeout: 2 * time.Minute} + + for _, att := range attachments { + filename := filepath.Base(att.Filename) + if filename == "" || filename == "." { + filename = fmt.Sprintf("attachment_%s", att.AttachmentID) + } + + // Match the server's naming convention: {uuid}_{filename} + localPath := filepath.Join(attachmentsDir, fmt.Sprintf("%s_%s", att.AttachmentID, filename)) + + if err := downloadFile(ctx, httpClient, att.DownloadURL, localPath); err != nil { + log.Warnf(ctx, "Failed to download attachment %s (%s) for task %s: %v", att.AttachmentID, att.Filename, taskID, err) + continue + } + + log.Debugf(ctx, "Downloaded attachment %s -> %s", att.Filename, localPath) + } + + return attachmentsDir, nil +} + +// downloadFile downloads a file from a URL to a local path. +func downloadFile(ctx context.Context, client *http.Client, downloadURL, destPath string) error { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, downloadURL, nil) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + + resp, err := client.Do(req) + if err != nil { + return fmt.Errorf("failed to download: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("unexpected status code %d", resp.StatusCode) + } + + out, err := os.Create(destPath) + if err != nil { + return fmt.Errorf("failed to create file: %w", err) + } + defer out.Close() + + if _, err := io.Copy(out, resp.Body); err != nil { + return fmt.Errorf("failed to write file: %w", err) + } + + return nil +} + func (w *Worker) sendTaskClaimed(taskID string) error { claimed := types.TaskClaimedMessage{ TaskID: taskID, diff --git a/internal/worker/worker_test.go b/internal/worker/worker_test.go index 1ba1c86..8d937a6 100644 --- a/internal/worker/worker_test.go +++ b/internal/worker/worker_test.go @@ -110,7 +110,7 @@ func TestPrepareTaskParamsSidecarImageOverride(t *testing.T) { TaskID: "task-1", Task: &types.Task{ID: "task-1"}, SidecarImage: "docker.io/warpdotdev/warp-agent:latest", - }) + }, "") if len(params.Sidecars) == 0 { t.Fatal("expected at least one sidecar") } @@ -125,7 +125,7 @@ func TestPrepareTaskParamsSidecarImageOverride(t *testing.T) { TaskID: "task-1", Task: &types.Task{ID: "task-1"}, SidecarImage: "docker.io/warpdotdev/warp-agent:latest", - }) + }, "") if len(params.Sidecars) == 0 { t.Fatal("expected at least one sidecar") } @@ -140,7 +140,7 @@ func TestPrepareTaskParamsSidecarImageOverride(t *testing.T) { TaskID: "task-1", Task: &types.Task{ID: "task-1"}, SidecarImage: "", - }) + }, "") if len(params.Sidecars) != 0 { t.Errorf("expected no sidecars when server sidecar image is empty, got %d", len(params.Sidecars)) }