Skip to content

Commit 7897201

Browse files
vkodithalaoz-agent
andauthored
[REMOTE-1820] Handle cancellation requests from dispatch server (#83)
## Summary This is the worker-side half of REMOTE-1820: self-hosted worker cancellation support. The companion `warp-server` change now forwards `task_cancellation` control messages to the worker. This PR makes `oz-agent-worker` consume those messages, cancel the matching active task context, and report the task back as terminal `CANCELLED` rather than as a normal failure. End to end: - Adds the `task_cancellation` WebSocket message and a shared `TaskStateCancelled` wire value. - Tracks each active task's context and cancel function so the worker can cancel only the requested task. - Cancels matching active tasks when a `task_cancellation` control message arrives; cancellations for inactive tasks are logged and ignored. - Reports cancelled tasks through the existing `task_completed` message shape with `task_state=CANCELLED`, so server-side lifecycle reconciliation can distinguish cancellation from failure. - Preserves the normal success/failure paths for non-cancelled tasks. - Records cancellation lifecycle events and marks cancelled executions separately from failed executions. - Handles context cancellation cleanly in the direct backend and uses bounded Docker cleanup so worker shutdown does not block indefinitely. ## Code overview - `internal/types/messages.go` — adds `task_cancellation`, `TaskState`, and optional `task_state` fields on terminal worker messages. - `internal/worker/worker.go` — handles cancellation messages, tracks active task contexts, emits terminal `CANCELLED`, and keeps shutdown cancellation on the same lifecycle path. - `internal/worker/direct.go` — classifies context cancellation as task cancellation instead of a generic backend failure. - `internal/worker/docker.go` — bounds container cleanup during cancellation/shutdown. - `internal/worker/worker_test.go` — covers cancellation message handling, terminal `CANCELLED` reporting, and shutdown behavior. ## Validation - `go test ./internal/metrics ./internal/worker` - `go test ./...` - `helm template oz-agent-worker charts/oz-agent-worker --set image.tag=test --set worker.workerId=test-worker --set warp.apiKeySecret.create=true --set warp.apiKeySecret.value=dummy --set metrics.enabled=true` - `git diff --check` Co-Authored-By: Oz <oz-agent@warp.dev>
1 parent 8f7ee33 commit 7897201

5 files changed

Lines changed: 242 additions & 49 deletions

File tree

internal/types/messages.go

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,13 @@ import (
99
type MessageType string
1010

1111
const (
12-
MessageTypeTaskAssignment MessageType = "task_assignment"
13-
MessageTypeTaskClaimed MessageType = "task_claimed"
14-
MessageTypeTaskCompleted MessageType = "task_completed"
15-
MessageTypeTaskFailed MessageType = "task_failed"
16-
MessageTypeTaskRejected MessageType = "task_rejected"
17-
MessageTypeHeartbeat MessageType = "heartbeat"
12+
MessageTypeTaskAssignment MessageType = "task_assignment"
13+
MessageTypeTaskClaimed MessageType = "task_claimed"
14+
MessageTypeTaskCompleted MessageType = "task_completed"
15+
MessageTypeTaskFailed MessageType = "task_failed"
16+
MessageTypeTaskRejected MessageType = "task_rejected"
17+
MessageTypeTaskCancellation MessageType = "task_cancellation"
18+
MessageTypeHeartbeat MessageType = "heartbeat"
1819
)
1920

2021
// WebSocketMessage is the base structure for all WebSocket messages
@@ -58,14 +59,16 @@ type TaskClaimedMessage struct {
5859

5960
// TaskCompletedMessage tells the server to end the active run execution after a successful agent process exit.
6061
type TaskCompletedMessage struct {
61-
TaskID string `json:"task_id"`
62-
Message string `json:"message"`
62+
TaskID string `json:"task_id"`
63+
Message string `json:"message"`
64+
TaskState *TaskState `json:"task_state,omitempty"`
6365
}
6466

6567
// TaskFailedMessage is sent from worker to server if task launch fails
6668
type TaskFailedMessage struct {
67-
TaskID string `json:"task_id"`
68-
Message string `json:"message"`
69+
TaskID string `json:"task_id"`
70+
Message string `json:"message"`
71+
TaskState *TaskState `json:"task_state,omitempty"`
6972
}
7073

7174
// TaskRejectedMessage is sent from worker to server when the worker cannot accept the task
@@ -75,6 +78,18 @@ type TaskRejectedMessage struct {
7578
Reason string `json:"reason"`
7679
}
7780

81+
// TaskCancellationMessage is sent from server to worker to cancel an active task.
82+
type TaskCancellationMessage struct {
83+
TaskID string `json:"task_id"`
84+
}
85+
86+
// TaskState is the serialized terminal task state accepted by warp-server.
87+
type TaskState string
88+
89+
const (
90+
TaskStateCancelled TaskState = "CANCELLED"
91+
)
92+
7893
type TaskDefinition struct {
7994
Prompt string `json:"prompt"`
8095
}

internal/worker/direct.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,9 @@ func (b *DirectBackend) ExecuteTask(ctx context.Context, params *TaskParams) err
202202

203203
log.Infof(ctx, "Running setup command: %s", b.config.SetupCommand)
204204
if err := b.runCommand(ctx, b.config.SetupCommand, workspaceDir, setupEnv); err != nil {
205+
if ctx.Err() != nil {
206+
return newBackendFailure(metrics.TaskFailurePhaseBackend, metrics.TaskFailureReasonTaskCancelled, ctx.Err())
207+
}
205208
return newBackendFailure(metrics.TaskFailurePhaseBackend, metrics.TaskFailureReasonSetupCommand, fmt.Errorf("setup command failed: %w", err))
206209
}
207210
}
@@ -232,6 +235,9 @@ func (b *DirectBackend) ExecuteTask(ctx context.Context, params *TaskParams) err
232235
log.Debugf(ctx, "Command: %s %s", b.ozPath, strings.Join(params.BaseArgs, " "))
233236

234237
if err := cmd.Run(); err != nil {
238+
if ctx.Err() != nil {
239+
return newBackendFailure(metrics.TaskFailurePhaseBackend, metrics.TaskFailureReasonTaskCancelled, ctx.Err())
240+
}
235241
return newBackendFailure(metrics.TaskFailurePhaseBackend, metrics.TaskFailureReasonAgentInvocation, fmt.Errorf("oz agent exited with error: %w", err))
236242
}
237243

internal/worker/docker.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,9 @@ func (b *DockerBackend) ExecuteTask(ctx context.Context, params *TaskParams) err
153153

154154
defer func() {
155155
if containerID != "" && !b.config.NoCleanup {
156-
if _, removeErr := dockerClient.ContainerRemove(ctx, containerID, client.ContainerRemoveOptions{Force: true}); removeErr != nil {
156+
cleanupCtx, cleanupCancel := context.WithTimeout(context.Background(), BackendShutdownTimeout)
157+
defer cleanupCancel()
158+
if _, removeErr := dockerClient.ContainerRemove(cleanupCtx, containerID, client.ContainerRemoveOptions{Force: true}); removeErr != nil {
157159
log.Debugf(ctx, "Container %s already removed or removal failed: %v", containerID, removeErr)
158160
}
159161
}

internal/worker/worker.go

Lines changed: 106 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -61,11 +61,23 @@ type Worker struct {
6161
reconnectDelay time.Duration
6262
lastHeartbeat time.Time
6363
sendChan chan []byte
64-
activeTasks map[string]context.CancelFunc
64+
activeTasks map[string]activeTask
6565
tasksMutex sync.Mutex
6666
backend Backend
6767
taskSemaphore *semaphore.Weighted // nil when unlimited
6868
}
69+
type taskCancellationSource string
70+
71+
const (
72+
taskCancellationSourceUser taskCancellationSource = "user"
73+
taskCancellationSourceShutdown taskCancellationSource = "shutdown"
74+
)
75+
76+
type activeTask struct {
77+
ctx context.Context
78+
cancel context.CancelFunc
79+
cancellationSource taskCancellationSource
80+
}
6981

7082
func New(ctx context.Context, config Config) (*Worker, error) {
7183
workerCtx, cancel := context.WithCancel(ctx)
@@ -111,7 +123,7 @@ func New(ctx context.Context, config Config) (*Worker, error) {
111123
cancel: cancel,
112124
reconnectDelay: InitialReconnectDelay,
113125
sendChan: make(chan []byte, 256),
114-
activeTasks: make(map[string]context.CancelFunc),
126+
activeTasks: make(map[string]activeTask),
115127
backend: backend,
116128
taskSemaphore: taskSemaphore,
117129
}, nil
@@ -324,11 +336,40 @@ func (w *Worker) handleMessage(message []byte) {
324336
}
325337
w.handleTaskAssignment(&assignment)
326338

339+
case types.MessageTypeTaskCancellation:
340+
var cancellation types.TaskCancellationMessage
341+
if err := json.Unmarshal(msg.Data, &cancellation); err != nil {
342+
log.Errorf(w.ctx, "Failed to unmarshal task cancellation: %v", err)
343+
return
344+
}
345+
w.handleTaskCancellation(&cancellation)
346+
327347
default:
328348
log.Warnf(w.ctx, "Unknown message type: %s", msg.Type)
329349
}
330350
}
331351

352+
func (w *Worker) handleTaskCancellation(cancellation *types.TaskCancellationMessage) {
353+
w.tasksMutex.Lock()
354+
task, ok := w.activeTasks[cancellation.TaskID]
355+
if ok && task.cancellationSource == "" {
356+
task.cancellationSource = taskCancellationSourceUser
357+
w.activeTasks[cancellation.TaskID] = task
358+
}
359+
w.tasksMutex.Unlock()
360+
if !ok {
361+
log.Warnf(w.ctx, "Received cancellation for inactive task: taskID=%s", cancellation.TaskID)
362+
return
363+
}
364+
365+
log.Infof(w.ctx, "Cancelling task from server request: taskID=%s", cancellation.TaskID)
366+
metrics.AddTaskEvent(task.ctx, "task.cancellation_requested",
367+
attribute.String("source", "server"),
368+
attribute.String("task.id", cancellation.TaskID),
369+
)
370+
task.cancel()
371+
}
372+
332373
func (w *Worker) handleTaskAssignment(assignment *types.TaskAssignmentMessage) {
333374
receivedAt := time.Now()
334375
log.Infof(w.ctx, "Received task assignment: taskID=%s, title=%s", assignment.TaskID, assignment.Task.Title)
@@ -381,7 +422,10 @@ func (w *Worker) handleTaskAssignment(assignment *types.TaskAssignmentMessage) {
381422
taskCtx, taskCancel := context.WithCancel(executionCtx)
382423

383424
w.tasksMutex.Lock()
384-
w.activeTasks[assignment.TaskID] = taskCancel
425+
w.activeTasks[assignment.TaskID] = activeTask{
426+
ctx: taskCtx,
427+
cancel: taskCancel,
428+
}
385429
w.tasksMutex.Unlock()
386430
go w.executeTask(taskCtx, taskCancel, span, assignment, receivedAt)
387431
}
@@ -512,6 +556,19 @@ func (w *Worker) executeTask(ctx context.Context, taskCancel context.CancelFunc,
512556

513557
err := w.backend.ExecuteTask(ctx, params)
514558
if err != nil {
559+
if ctx.Err() == context.Canceled && w.cancellationSource(taskID) == taskCancellationSourceUser {
560+
result = metrics.TaskResultCancelled
561+
metrics.AddTaskEvent(ctx, "task.cancelled",
562+
attribute.String("source", string(taskCancellationSourceUser)),
563+
)
564+
span.SetStatus(codes.Ok, "task cancelled by user request")
565+
log.Infof(ctx, "Task execution cancelled by user request: taskID=%s", taskID)
566+
if statusErr := w.sendTaskCancelled(taskID, "Task cancelled by user request."); statusErr != nil {
567+
log.Errorf(ctx, "Failed to send task cancelled message: %v", statusErr)
568+
}
569+
return
570+
}
571+
515572
result = metrics.TaskResultFailed
516573
phase, reason := taskFailureLabels(err)
517574
metrics.RecordTaskFailure(phase, reason)
@@ -537,6 +594,16 @@ func (w *Worker) executeTask(ctx context.Context, taskCancel context.CancelFunc,
537594
}
538595
}
539596

597+
func (w *Worker) cancellationSource(taskID string) taskCancellationSource {
598+
w.tasksMutex.Lock()
599+
defer w.tasksMutex.Unlock()
600+
601+
task, ok := w.activeTasks[taskID]
602+
if !ok {
603+
return ""
604+
}
605+
return task.cancellationSource
606+
}
540607
func (w *Worker) sendTaskClaimed(taskID string) error {
541608
claimed := types.TaskClaimedMessage{
542609
TaskID: taskID,
@@ -561,6 +628,32 @@ func (w *Worker) sendTaskClaimed(taskID string) error {
561628
return w.sendMessage(msgBytes)
562629
}
563630

631+
func (w *Worker) sendTaskCancelled(taskID, message string) error {
632+
taskState := types.TaskStateCancelled
633+
completedMsg := types.TaskCompletedMessage{
634+
TaskID: taskID,
635+
Message: message,
636+
TaskState: &taskState,
637+
}
638+
639+
data, err := json.Marshal(completedMsg)
640+
if err != nil {
641+
return fmt.Errorf("failed to marshal task cancelled message: %w", err)
642+
}
643+
644+
msg := types.WebSocketMessage{
645+
Type: types.MessageTypeTaskCompleted,
646+
Data: data,
647+
}
648+
649+
msgBytes, err := json.Marshal(msg)
650+
if err != nil {
651+
return fmt.Errorf("failed to marshal websocket message: %w", err)
652+
}
653+
654+
return w.sendMessage(msgBytes)
655+
}
656+
564657
func (w *Worker) sendTaskRejected(taskID, reason string) error {
565658
rejectedMsg := types.TaskRejectedMessage{
566659
TaskID: taskID,
@@ -654,9 +747,17 @@ func (w *Worker) Shutdown() {
654747
log.Infof(w.ctx, "Preserving %d active tasks during worker shutdown", activeTaskCount)
655748
} else if activeTaskCount > 0 {
656749
log.Infof(w.ctx, "Cancelling %d active tasks", activeTaskCount)
657-
for taskID, cancel := range w.activeTasks {
750+
for taskID, task := range w.activeTasks {
751+
if task.cancellationSource == "" {
752+
task.cancellationSource = taskCancellationSourceShutdown
753+
w.activeTasks[taskID] = task
754+
}
658755
log.Debugf(w.ctx, "Cancelling task: %s", taskID)
659-
cancel()
756+
metrics.AddTaskEvent(task.ctx, "task.cancellation_requested",
757+
attribute.String("source", "signal"),
758+
attribute.String("task.id", taskID),
759+
)
760+
task.cancel()
660761
}
661762
}
662763
w.tasksMutex.Unlock()

0 commit comments

Comments
 (0)