Skip to content

Commit d860b2e

Browse files
authored
Send TaskCompleted message upon CLI exit with no error. (#55)
WISOTT, context: https://warpdev.slack.com/archives/C09E37H1NMA/p1776898580885489?thread_ts=1776896423.112019&cid=C09E37H1NMA
1 parent 9e18e71 commit d860b2e

3 files changed

Lines changed: 131 additions & 0 deletions

File tree

internal/types/messages.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ type MessageType string
1111
const (
1212
MessageTypeTaskAssignment MessageType = "task_assignment"
1313
MessageTypeTaskClaimed MessageType = "task_claimed"
14+
MessageTypeTaskCompleted MessageType = "task_completed"
1415
MessageTypeTaskFailed MessageType = "task_failed"
1516
MessageTypeTaskRejected MessageType = "task_rejected"
1617
MessageTypeHeartbeat MessageType = "heartbeat"
@@ -50,6 +51,12 @@ type TaskClaimedMessage struct {
5051
WorkerID string `json:"worker_id"`
5152
}
5253

54+
// TaskCompletedMessage tells the server to end the active run execution after a successful agent process exit.
55+
type TaskCompletedMessage struct {
56+
TaskID string `json:"task_id"`
57+
Message string `json:"message"`
58+
}
59+
5360
// TaskFailedMessage is sent from worker to server if task launch fails
5461
type TaskFailedMessage struct {
5562
TaskID string `json:"task_id"`

internal/worker/worker.go

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -455,6 +455,9 @@ func (w *Worker) executeTask(ctx context.Context, assignment *types.TaskAssignme
455455
}
456456

457457
log.Infof(ctx, "Task execution completed successfully: taskID=%s", taskID)
458+
if err := w.sendTaskCompleted(taskID, "Task completed successfully"); err != nil {
459+
log.Errorf(ctx, "Failed to send task completed message: %v", err)
460+
}
458461
}
459462

460463
func (w *Worker) sendTaskClaimed(taskID string) error {
@@ -505,6 +508,30 @@ func (w *Worker) sendTaskRejected(taskID, reason string) error {
505508
return w.sendMessage(msgBytes)
506509
}
507510

511+
func (w *Worker) sendTaskCompleted(taskID, message string) error {
512+
completedMsg := types.TaskCompletedMessage{
513+
TaskID: taskID,
514+
Message: message,
515+
}
516+
517+
data, err := json.Marshal(completedMsg)
518+
if err != nil {
519+
return fmt.Errorf("failed to marshal task completed message: %w", err)
520+
}
521+
522+
msg := types.WebSocketMessage{
523+
Type: types.MessageTypeTaskCompleted,
524+
Data: data,
525+
}
526+
527+
msgBytes, err := json.Marshal(msg)
528+
if err != nil {
529+
return fmt.Errorf("failed to marshal websocket message: %w", err)
530+
}
531+
532+
return w.sendMessage(msgBytes)
533+
}
534+
508535
func (w *Worker) sendTaskFailed(taskID, message string) error {
509536
failedMsg := types.TaskFailedMessage{
510537
TaskID: taskID,

internal/worker/worker_test.go

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ package worker
22

33
import (
44
"context"
5+
"encoding/json"
6+
"errors"
57
"testing"
68

79
"github.com/warpdotdev/oz-agent-worker/internal/types"
@@ -21,6 +23,101 @@ func (b *shutdownRecordingBackend) Shutdown(ctx context.Context) {
2123
b.shutdownCtxErr = ctx.Err()
2224
}
2325

26+
type recordingBackend struct {
27+
err error
28+
}
29+
30+
func (b *recordingBackend) ExecuteTask(context.Context, *TaskParams) error {
31+
return b.err
32+
}
33+
34+
func (b *recordingBackend) Shutdown(context.Context) {}
35+
36+
func TestExecuteTaskReportsTaskCompletedOnSuccess(t *testing.T) {
37+
w := &Worker{
38+
ctx: context.Background(),
39+
config: Config{},
40+
sendChan: make(chan []byte, 1),
41+
activeTasks: map[string]context.CancelFunc{"task-1": func() {}},
42+
backend: &recordingBackend{},
43+
}
44+
45+
w.executeTask(context.Background(), &types.TaskAssignmentMessage{
46+
TaskID: "task-1",
47+
Task: &types.Task{ID: "task-1", Title: "test task"},
48+
})
49+
50+
msg := readWebSocketMessage(t, w.sendChan)
51+
if msg.Type != types.MessageTypeTaskCompleted {
52+
t.Fatalf("message type = %q, want %q", msg.Type, types.MessageTypeTaskCompleted)
53+
}
54+
55+
var completed types.TaskCompletedMessage
56+
if err := json.Unmarshal(msg.Data, &completed); err != nil {
57+
t.Fatalf("failed to unmarshal task completed message: %v", err)
58+
}
59+
if completed.TaskID != "task-1" {
60+
t.Errorf("task ID = %q, want %q", completed.TaskID, "task-1")
61+
}
62+
if completed.Message != "Task completed successfully" {
63+
t.Errorf("message = %q, want %q", completed.Message, "Task completed successfully")
64+
}
65+
if _, ok := w.activeTasks["task-1"]; ok {
66+
t.Fatal("task should be removed from active tasks")
67+
}
68+
}
69+
70+
func TestExecuteTaskReportsTaskFailedOnBackendError(t *testing.T) {
71+
w := &Worker{
72+
ctx: context.Background(),
73+
config: Config{},
74+
sendChan: make(chan []byte, 1),
75+
activeTasks: map[string]context.CancelFunc{"task-1": func() {}},
76+
backend: &recordingBackend{err: errors.New("boom")},
77+
}
78+
79+
w.executeTask(context.Background(), &types.TaskAssignmentMessage{
80+
TaskID: "task-1",
81+
Task: &types.Task{ID: "task-1", Title: "test task"},
82+
})
83+
84+
msg := readWebSocketMessage(t, w.sendChan)
85+
if msg.Type != types.MessageTypeTaskFailed {
86+
t.Fatalf("message type = %q, want %q", msg.Type, types.MessageTypeTaskFailed)
87+
}
88+
89+
var failed types.TaskFailedMessage
90+
if err := json.Unmarshal(msg.Data, &failed); err != nil {
91+
t.Fatalf("failed to unmarshal task failed message: %v", err)
92+
}
93+
if failed.TaskID != "task-1" {
94+
t.Errorf("task ID = %q, want %q", failed.TaskID, "task-1")
95+
}
96+
if failed.Message != "Failed to execute task: boom" {
97+
t.Errorf("message = %q, want %q", failed.Message, "Failed to execute task: boom")
98+
}
99+
if _, ok := w.activeTasks["task-1"]; ok {
100+
t.Fatal("task should be removed from active tasks")
101+
}
102+
}
103+
104+
func readWebSocketMessage(t *testing.T, messages <-chan []byte) types.WebSocketMessage {
105+
t.Helper()
106+
107+
select {
108+
case msgBytes := <-messages:
109+
var msg types.WebSocketMessage
110+
if err := json.Unmarshal(msgBytes, &msg); err != nil {
111+
t.Fatalf("failed to unmarshal websocket message: %v", err)
112+
}
113+
return msg
114+
default:
115+
t.Fatal("expected websocket message")
116+
}
117+
118+
return types.WebSocketMessage{}
119+
}
120+
24121
func TestDefaultImageForTask(t *testing.T) {
25122
newWorker := func(defaultImage string) *Worker {
26123
ctx := context.Background()

0 commit comments

Comments
 (0)