Skip to content

Commit 33460d2

Browse files
committed
feat: address ai's review
1 parent eef927d commit 33460d2

8 files changed

Lines changed: 209 additions & 23 deletions

File tree

cmd/server/signals_windows.go

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,14 @@ import (
77
"log/slog"
88
"os"
99
"os/signal"
10-
"syscall"
1110

1211
"github.com/coder/agentapi/lib/httpapi"
1312
)
1413

1514
// handleSignals sets up signal handlers for Windows.
16-
// Only handles SIGTERM and SIGINT (SIGHUP and SIGUSR1 don't exist on Windows).
1715
func handleSignals(ctx context.Context, cancel context.CancelFunc, logger *slog.Logger, srv *httpapi.Server) {
18-
// Handle shutdown signals (SIGTERM, SIGINT only on Windows)
1916
shutdownCh := make(chan os.Signal, 1)
20-
signal.Notify(shutdownCh, os.Interrupt, syscall.SIGTERM)
17+
signal.Notify(shutdownCh, os.Interrupt)
2118
go func() {
2219
defer signal.Stop(shutdownCh)
2320
sig := <-shutdownCh

lib/httpapi/events.go

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ import (
66
"sync"
77
"time"
88

9+
"github.com/coder/quartz"
10+
911
mf "github.com/coder/agentapi/lib/msgfmt"
1012
st "github.com/coder/agentapi/lib/screentracker"
1113
"github.com/coder/agentapi/lib/util"
@@ -54,9 +56,9 @@ type ScreenUpdateBody struct {
5456
}
5557

5658
type ErrorBody struct {
57-
Message string `json:"message" doc:"Error message"`
58-
Level string `json:"level" doc:"Error level: 'warning' or 'error'"`
59-
Time time.Time `json:"time" doc:"Timestamp when the error occurred"`
59+
Message string `json:"message" doc:"Error message"`
60+
Level st.ErrorLevel `json:"level" doc:"Error level"`
61+
Time time.Time `json:"time" doc:"Timestamp when the error occurred"`
6062
}
6163

6264
type Event struct {
@@ -74,6 +76,7 @@ type EventEmitter struct {
7476
subscriptionBufSize uint
7577
screen string
7678
errors []ErrorBody
79+
clock quartz.Clock
7780
}
7881

7982
func convertStatus(status st.ConversationStatus) AgentStatus {
@@ -109,6 +112,12 @@ func WithAgentType(agentType mf.AgentType) EventEmitterOption {
109112
}
110113
}
111114

115+
func WithClock(clock quartz.Clock) EventEmitterOption {
116+
return func(e *EventEmitter) {
117+
e.clock = clock
118+
}
119+
}
120+
112121
func NewEventEmitter(opts ...EventEmitterOption) *EventEmitter {
113122
e := &EventEmitter{
114123
messages: make([]st.ConversationMessage, 0),
@@ -119,6 +128,9 @@ func NewEventEmitter(opts ...EventEmitterOption) *EventEmitter {
119128
for _, opt := range opts {
120129
opt(e)
121130
}
131+
if e.clock == nil {
132+
e.clock = quartz.NewReal()
133+
}
122134
return e
123135
}
124136

@@ -202,14 +214,14 @@ func (e *EventEmitter) EmitScreen(newScreen string) {
202214
e.screen = newScreen
203215
}
204216

205-
func (e *EventEmitter) EmitError(message string, level string) {
217+
func (e *EventEmitter) EmitError(message string, level st.ErrorLevel) {
206218
e.mu.Lock()
207219
defer e.mu.Unlock()
208220

209221
errorBody := ErrorBody{
210222
Message: message,
211223
Level: level,
212-
Time: time.Now(),
224+
Time: e.clock.Now(),
213225
}
214226

215227
// Store the error so new subscribers can receive all errors

lib/httpapi/events_test.go

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"time"
77

88
st "github.com/coder/agentapi/lib/screentracker"
9+
"github.com/coder/quartz"
910
"github.com/stretchr/testify/assert"
1011
)
1112

@@ -97,4 +98,40 @@ func TestEventEmitter(t *testing.T) {
9798
t.Fatalf("read should not block")
9899
}
99100
})
101+
102+
t.Run("clock-injection", func(t *testing.T) {
103+
mockClock := quartz.NewMock(t)
104+
fixedTime := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC)
105+
mockClock.Set(fixedTime)
106+
107+
emitter := NewEventEmitter(WithClock(mockClock), WithSubscriptionBufSize(10))
108+
_, ch, stateEvents := emitter.Subscribe()
109+
110+
// Verify initial state events
111+
assert.Len(t, stateEvents, 2)
112+
113+
// Emit an error and verify it uses the mock clock time
114+
emitter.EmitError("test error", st.ErrorLevelError)
115+
116+
event := <-ch
117+
assert.Equal(t, EventTypeError, event.Type)
118+
errorBody, ok := event.Payload.(ErrorBody)
119+
assert.True(t, ok)
120+
assert.Equal(t, "test error", errorBody.Message)
121+
assert.Equal(t, st.ErrorLevelError, errorBody.Level)
122+
assert.Equal(t, fixedTime, errorBody.Time)
123+
124+
// Advance the clock and emit another error
125+
newTime := fixedTime.Add(1 * time.Hour)
126+
mockClock.Set(newTime)
127+
emitter.EmitError("another error", st.ErrorLevelWarning)
128+
129+
event = <-ch
130+
assert.Equal(t, EventTypeError, event.Type)
131+
errorBody, ok = event.Payload.(ErrorBody)
132+
assert.True(t, ok)
133+
assert.Equal(t, "another error", errorBody.Message)
134+
assert.Equal(t, st.ErrorLevelWarning, errorBody.Level)
135+
assert.Equal(t, newTime, errorBody.Time)
136+
})
100137
}

lib/httpapi/server.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,7 @@ func NewServer(ctx context.Context, config ServerConfig) (*Server, error) {
278278
}
279279
logger.Info("Created temporary directory for uploads", "tempDir", tempDir)
280280

281-
ctx, cancel := context.WithCancel(context.Background())
281+
shutdownCtx, shutdownCancel := context.WithCancel(context.Background())
282282

283283
s := &Server{
284284
router: router,
@@ -292,8 +292,8 @@ func NewServer(ctx context.Context, config ServerConfig) (*Server, error) {
292292
chatBasePath: strings.TrimSuffix(config.ChatBasePath, "/"),
293293
tempDir: tempDir,
294294
clock: config.Clock,
295-
shutdownCtx: ctx,
296-
shutdown: cancel,
295+
shutdownCtx: shutdownCtx,
296+
shutdown: shutdownCancel,
297297
}
298298

299299
// Register API routes

lib/screentracker/conversation.go

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,22 @@ var ConversationRoleValues = []ConversationRole{
3434
ConversationRoleAgent,
3535
}
3636

37+
type ErrorLevel string
38+
39+
func (e ErrorLevel) Schema(r huma.Registry) *huma.Schema {
40+
return util.OpenAPISchema(r, "ErrorLevel", ErrorLevelValues)
41+
}
42+
43+
const (
44+
ErrorLevelWarning ErrorLevel = "warning"
45+
ErrorLevelError ErrorLevel = "error"
46+
)
47+
48+
var ErrorLevelValues = []ErrorLevel{
49+
ErrorLevelWarning,
50+
ErrorLevelError,
51+
}
52+
3753
var (
3854
ErrMessageValidationWhitespace = xerrors.New("message must be trimmed of leading and trailing whitespace")
3955
ErrMessageValidationEmpty = xerrors.New("message must not be empty")
@@ -80,14 +96,14 @@ type Emitter interface {
8096
EmitMessages([]ConversationMessage)
8197
EmitStatus(ConversationStatus)
8298
EmitScreen(string)
83-
EmitError(message string, level string)
99+
EmitError(message string, level ErrorLevel)
84100
}
85101

86102
type ConversationMessage struct {
87-
Id int
88-
Message string
89-
Role ConversationRole
90-
Time time.Time
103+
Id int `json:"id"`
104+
Message string `json:"message"`
105+
Role ConversationRole `json:"role"`
106+
Time time.Time `json:"time"`
91107
}
92108

93109
type StatePersistenceConfig struct {

lib/screentracker/pty_conversation.go

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ type noopEmitter struct{}
152152
func (noopEmitter) EmitMessages([]ConversationMessage) {}
153153
func (noopEmitter) EmitStatus(ConversationStatus) {}
154154
func (noopEmitter) EmitScreen(string) {}
155-
func (noopEmitter) EmitError(_ string, _ string) {}
155+
func (noopEmitter) EmitError(_ string, _ ErrorLevel) {}
156156

157157
func NewPTY(ctx context.Context, cfg PTYConversationConfig, emitter Emitter) *PTYConversation {
158158
if cfg.Clock == nil {
@@ -207,7 +207,7 @@ func (c *PTYConversation) Start(ctx context.Context) {
207207
if c.initialPromptReady && c.loadStateStatus == LoadStatePending && c.cfg.StatePersistenceConfig.LoadState {
208208
if err := c.loadStateLocked(); err != nil {
209209
c.cfg.Logger.Error("Failed to load state", "error", err)
210-
c.emitter.EmitError(fmt.Sprintf("Failed to restore previous session: %v", err), "warning")
210+
c.emitter.EmitError(fmt.Sprintf("Failed to restore previous session: %v", err), ErrorLevelWarning)
211211
c.loadStateStatus = LoadStateFailed
212212
} else {
213213
c.loadStateStatus = LoadStateSucceeded
@@ -649,6 +649,11 @@ func (c *PTYConversation) loadStateLocked() error {
649649
return xerrors.Errorf("failed to unmarshal state (corrupted or invalid JSON): %w", err)
650650
}
651651

652+
// Validate version
653+
if agentState.Version != 1 {
654+
return xerrors.Errorf("unsupported state file version %d (expected 1)", agentState.Version)
655+
}
656+
652657
// Handle initial prompt restoration:
653658
// - If a new initial prompt was provided via flags, check if it differs from the saved one.
654659
// If different, mark as not sent (will be sent). If same, preserve sent status.
@@ -657,7 +662,7 @@ func (c *PTYConversation) loadStateLocked() error {
657662
if len(c.cfg.InitialPrompt) > 0 {
658663
isDifferent := buildStringFromMessageParts(c.cfg.InitialPrompt) != agentState.InitialPrompt
659664
c.initialPromptSent = !isDifferent
660-
} else {
665+
} else if agentState.InitialPrompt != "" {
661666
c.cfg.InitialPrompt = []MessagePart{MessagePartText{
662667
Content: agentState.InitialPrompt,
663668
Alias: "",

lib/screentracker/pty_conversation_test.go

Lines changed: 111 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ type testEmitter struct{}
5656
func (testEmitter) EmitMessages([]st.ConversationMessage) {}
5757
func (testEmitter) EmitStatus(st.ConversationStatus) {}
5858
func (testEmitter) EmitScreen(string) {}
59-
func (testEmitter) EmitError(_ string, _ string) {}
59+
func (testEmitter) EmitError(_ string, _ st.ErrorLevel) {}
6060

6161
// advanceFor is a shorthand for advanceUntil with a time-based condition.
6262
func advanceFor(ctx context.Context, t *testing.T, mClock *quartz.Mock, total time.Duration) {
@@ -798,6 +798,57 @@ func TestStatePersistence(t *testing.T) {
798798
messages := c.Messages()
799799
assert.Len(t, messages, 1)
800800
})
801+
802+
t.Run("LoadState rejects unsupported version", func(t *testing.T) {
803+
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
804+
t.Cleanup(cancel)
805+
806+
tmpDir := t.TempDir()
807+
stateFile := tmpDir + "/unsupported_version.json"
808+
809+
// Create state file with unsupported version
810+
unsupportedState := map[string]interface{}{
811+
"version": 999, // Unsupported version
812+
"messages": []interface{}{},
813+
"initial_prompt": "",
814+
"initial_prompt_sent": false,
815+
}
816+
stateBytes, err := json.Marshal(unsupportedState)
817+
require.NoError(t, err)
818+
err = os.WriteFile(stateFile, stateBytes, 0o644)
819+
require.NoError(t, err)
820+
821+
mClock := quartz.NewMock(t)
822+
agent := &testAgent{screen: "ready"}
823+
cfg := st.PTYConversationConfig{
824+
Clock: mClock,
825+
SnapshotInterval: 100 * time.Millisecond,
826+
ScreenStabilityLength: 200 * time.Millisecond,
827+
AgentIO: agent,
828+
Logger: slog.New(slog.NewTextHandler(io.Discard, nil)),
829+
FormatMessage: func(message string, userInput string) string {
830+
return message
831+
},
832+
ReadyForInitialPrompt: func(message string) bool {
833+
return message == "ready"
834+
},
835+
StatePersistenceConfig: st.StatePersistenceConfig{
836+
StateFile: stateFile,
837+
LoadState: true,
838+
SaveState: false,
839+
},
840+
}
841+
842+
// Should not panic - logs error and continues with empty state
843+
c := st.NewPTY(ctx, cfg, &testEmitter{})
844+
c.Start(ctx)
845+
846+
advanceFor(ctx, t, mClock, 300*time.Millisecond)
847+
848+
// Should have default initial message (version error causes fallback to empty state)
849+
messages := c.Messages()
850+
assert.Len(t, messages, 1)
851+
})
801852
}
802853

803854
func TestInitialPromptReadiness(t *testing.T) {
@@ -1286,4 +1337,63 @@ func TestInitialPromptSent(t *testing.T) {
12861337
}
12871338
assert.True(t, found, "saved prompt should be sent when no new prompt provided")
12881339
})
1340+
1341+
t.Run("empty prompt from state is not restored", func(t *testing.T) {
1342+
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
1343+
t.Cleanup(cancel)
1344+
1345+
tmpDir := t.TempDir()
1346+
stateFile := tmpDir + "/state.json"
1347+
1348+
// Create state file with empty prompt
1349+
emptyPromptState := st.AgentState{
1350+
Version: 1,
1351+
Messages: []st.ConversationMessage{},
1352+
InitialPrompt: "", // Empty prompt
1353+
InitialPromptSent: false,
1354+
}
1355+
stateBytes, err := json.Marshal(emptyPromptState)
1356+
require.NoError(t, err)
1357+
err = os.WriteFile(stateFile, stateBytes, 0o644)
1358+
require.NoError(t, err)
1359+
1360+
mClock := quartz.NewMock(t)
1361+
agent := &testAgent{screen: "ready"}
1362+
1363+
cfg := st.PTYConversationConfig{
1364+
Clock: mClock,
1365+
SnapshotInterval: 100 * time.Millisecond,
1366+
ScreenStabilityLength: 200 * time.Millisecond,
1367+
AgentIO: agent,
1368+
Logger: discardLogger,
1369+
FormatMessage: func(message string, userInput string) string {
1370+
return message
1371+
},
1372+
ReadyForInitialPrompt: func(message string) bool {
1373+
return message == "ready"
1374+
},
1375+
StatePersistenceConfig: st.StatePersistenceConfig{
1376+
StateFile: stateFile,
1377+
LoadState: true,
1378+
SaveState: false,
1379+
},
1380+
}
1381+
1382+
c := st.NewPTY(ctx, cfg, &testEmitter{})
1383+
c.Start(ctx)
1384+
1385+
// Agent becomes ready
1386+
agent.setScreen("ready")
1387+
1388+
// Advance time to ensure any prompt would be sent
1389+
advanceFor(ctx, t, mClock, 500*time.Millisecond)
1390+
1391+
// Verify no prompt was sent (should only have the initial screen message)
1392+
messages := c.Messages()
1393+
for _, msg := range messages {
1394+
if msg.Role == st.ConversationRoleUser {
1395+
t.Errorf("Unexpected user message sent: %q (empty prompt should not be restored)", msg.Message)
1396+
}
1397+
}
1398+
})
12891399
}

openapi.json

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@
2323
"additionalProperties": false,
2424
"properties": {
2525
"level": {
26-
"description": "Error level: 'warning' or 'error'",
27-
"type": "string"
26+
"$ref": "#/components/schemas/ErrorLevel",
27+
"description": "Error level"
2828
},
2929
"message": {
3030
"description": "Error message",
@@ -60,6 +60,15 @@
6060
},
6161
"type": "object"
6262
},
63+
"ErrorLevel": {
64+
"enum": [
65+
"error",
66+
"warning"
67+
],
68+
"example": "warning",
69+
"title": "ErrorLevel",
70+
"type": "string"
71+
},
6372
"ErrorModel": {
6473
"additionalProperties": false,
6574
"properties": {

0 commit comments

Comments
 (0)