Skip to content

Commit b2cbf56

Browse files
authored
fix: address review findings from #177 (#195)
1 parent d7d7744 commit b2cbf56

9 files changed

Lines changed: 286 additions & 21 deletions

File tree

cmd/server/process_unix.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
//go:build unix
2+
3+
package server
4+
5+
import (
6+
"errors"
7+
"os"
8+
"syscall"
9+
)
10+
11+
// isProcessRunning reports whether a process with the given PID appears to be
// running. It probes the process with signal 0 and returns true when the probe
// succeeds or fails only with EPERM; any other failure yields false.
12+
func isProcessRunning(pid int) bool {
13+
process, err := os.FindProcess(pid)
14+
if err != nil {
15+
return false
16+
}
17+
// Signal 0 is the POSIX existence probe: nothing is delivered to the target,
// only the error result is inspected.
err = process.Signal(syscall.Signal(0))
18+
// EPERM is counted as running: per kill(2) it means the target exists but this
// process lacks permission to signal it.
return err == nil || errors.Is(err, syscall.EPERM)
19+
}

cmd/server/process_windows.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
//go:build windows
2+
3+
package server
4+
5+
// isProcessRunning reports whether a process with the given PID is running.
6+
// On Windows, Signal(0) is not supported, so the PID argument is ignored and
// this stub always returns false.
7+
// As a result, PID file liveness detection is effectively disabled on this
// platform: no PID is ever reported as belonging to a live process.
8+
func isProcessRunning(_ int) bool {
9+
return false
10+
}

cmd/server/server.go

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ import (
1212
"sort"
1313
"strconv"
1414
"strings"
15-
"syscall"
1615
"time"
1716

1817
"github.com/coder/agentapi/lib/screentracker"
@@ -292,25 +291,28 @@ func writePIDFile(pidFile string, logger *slog.Logger) error {
292291
return nil
293292
}
294293

295-
// cleanupPIDFile removes the PID file if it exists
294+
// cleanupPIDFile removes the PID file only when its contents parse as the
// current process's PID. A missing file is silently ignored; any other read
// or remove failure is logged through logger. Nothing is returned.
296295
func cleanupPIDFile(pidFile string, logger *slog.Logger) {
296+
data, err := os.ReadFile(pidFile)
297+
if err != nil {
298+
// A file that does not exist needs no cleanup and is not worth logging.
if !os.IsNotExist(err) {
299+
logger.Error("Failed to read PID file for cleanup", "pidFile", pidFile, "error", err)
300+
}
301+
return
302+
}
303+
// Surrounding whitespace (such as a trailing newline) is stripped before parsing.
pidStr := strings.TrimSpace(string(data))
304+
filePID, err := strconv.Atoi(pidStr)
305+
// Ownership check: leave the file alone unless it names this process.
// NOTE(review): this branch is also taken when the contents are not a valid
// integer, in which case the "belongs to another process" message is
// misleading; the raw contents are still logged under "filePID".
if err != nil || filePID != os.Getpid() {
306+
logger.Info("PID file belongs to another process, skipping cleanup", "pidFile", pidFile, "filePID", pidStr)
307+
return
308+
}
297309
// NOTE(review): the read above and this remove are separate steps, so the file
// could be replaced in between; a file that vanished meanwhile is tolerated.
if err := os.Remove(pidFile); err != nil && !os.IsNotExist(err) {
298310
logger.Error("Failed to remove PID file", "pidFile", pidFile, "error", err)
299311
} else if err == nil {
300312
logger.Info("Removed PID file", "pidFile", pidFile)
301313
}
302314
}
303315

304-
// isProcessRunning checks if a process with the given PID is running
305-
func isProcessRunning(pid int) bool {
306-
process, err := os.FindProcess(pid)
307-
if err != nil {
308-
return false
309-
}
310-
err = process.Signal(syscall.Signal(0))
311-
return err == nil || errors.Is(err, syscall.EPERM)
312-
}
313-
314316
type flagSpec struct {
315317
name string
316318
shorthand string

cmd/server/server_test.go

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -641,8 +641,9 @@ func TestPIDFileOperations(t *testing.T) {
641641
tmpDir := t.TempDir()
642642
pidFile := tmpDir + "/test.pid"
643643

644-
// Write initial PID file
645-
err := os.WriteFile(pidFile, []byte("12345\n"), 0o644)
644+
// Write a non-numeric PID so strconv.Atoi fails and the liveness
645+
// check is skipped, avoiding flakes when a real PID matches.
646+
err := os.WriteFile(pidFile, []byte("not-a-pid\n"), 0o644)
646647
require.NoError(t, err)
647648

648649
// Overwrite with current PID
@@ -657,12 +658,25 @@ func TestPIDFileOperations(t *testing.T) {
657658
assert.Equal(t, expectedPID, string(data))
658659
})
659660

661+
t.Run("writePIDFile detects running process", func(t *testing.T) {
662+
tmpDir := t.TempDir()
663+
pidFile := tmpDir + "/test.pid"
664+
665+
// Write the current process PID so isProcessRunning returns true.
666+
err := os.WriteFile(pidFile, []byte(fmt.Sprintf("%d\n", os.Getpid())), 0o644)
667+
require.NoError(t, err)
668+
669+
err = writePIDFile(pidFile, discardLogger)
670+
require.Error(t, err)
671+
assert.Contains(t, err.Error(), "another instance is already running")
672+
})
673+
660674
t.Run("cleanupPIDFile removes file", func(t *testing.T) {
661675
tmpDir := t.TempDir()
662676
pidFile := tmpDir + "/test.pid"
663677

664-
// Create PID file
665-
err := os.WriteFile(pidFile, []byte("12345\n"), 0o644)
678+
// Create PID file with current process PID so ownership check passes
679+
err := os.WriteFile(pidFile, []byte(fmt.Sprintf("%d\n", os.Getpid())), 0o644)
666680
require.NoError(t, err)
667681

668682
// Cleanup

cmd/server/signals_unix.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ import (
1818
func handleSignals(ctx context.Context, cancel context.CancelFunc, logger *slog.Logger, srv *httpapi.Server) {
1919
// Handle shutdown signals (SIGTERM, SIGINT, SIGHUP)
2020
shutdownCh := make(chan os.Signal, 1)
21-
signal.Notify(shutdownCh, os.Interrupt, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGINT)
21+
signal.Notify(shutdownCh, os.Interrupt, syscall.SIGTERM, syscall.SIGHUP)
2222
go func() {
2323
defer signal.Stop(shutdownCh)
2424
sig := <-shutdownCh

lib/httpapi/events.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,9 @@ func convertStatus(status st.ConversationStatus) AgentStatus {
9494

9595
const defaultSubscriptionBufSize uint = 1024
9696

97+
// maxStoredErrors caps the number of errors retained for late subscribers.
98+
const maxStoredErrors = 100
99+
97100
type EventEmitterOption func(*EventEmitter)
98101

99102
func WithSubscriptionBufSize(size uint) EventEmitterOption {
@@ -224,8 +227,11 @@ func (e *EventEmitter) EmitError(message string, level st.ErrorLevel) {
224227
Time: e.clock.Now(),
225228
}
226229

227-
// Store the error so new subscribers can receive all errors
230+
// Store the error so new subscribers can receive recent errors.
228231
e.errors = append(e.errors, errorBody)
232+
if len(e.errors) > maxStoredErrors {
233+
e.errors = e.errors[len(e.errors)-maxStoredErrors:]
234+
}
229235

230236
e.notifyChannels(EventTypeError, errorBody)
231237
}

lib/httpapi/events_test.go

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,70 @@ func TestEventEmitter(t *testing.T) {
9999
}
100100
})
101101

102+
t.Run("error-cap", func(t *testing.T) {
103+
emitter := NewEventEmitter(WithSubscriptionBufSize(10))
104+
105+
for i := range 150 {
106+
emitter.EmitError(fmt.Sprintf("error %d", i), st.ErrorLevelError)
107+
}
108+
109+
_, _, stateEvents := emitter.Subscribe()
110+
111+
var errorEvents []Event
112+
for _, ev := range stateEvents {
113+
if ev.Type == EventTypeError {
114+
errorEvents = append(errorEvents, ev)
115+
}
116+
}
117+
118+
assert.Len(t, errorEvents, maxStoredErrors)
119+
120+
// Errors should be the last 100: "error 50" through "error 149".
121+
for i, ev := range errorEvents {
122+
body, ok := ev.Payload.(ErrorBody)
123+
assert.True(t, ok)
124+
assert.Equal(t, fmt.Sprintf("error %d", i+50), body.Message)
125+
}
126+
})
127+
128+
t.Run("error-events-in-initial-state", func(t *testing.T) {
129+
mockClock := quartz.NewMock(t)
130+
fixedTime := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC)
131+
mockClock.Set(fixedTime)
132+
133+
emitter := NewEventEmitter(WithClock(mockClock), WithSubscriptionBufSize(10))
134+
135+
emitter.EmitError("err1", st.ErrorLevelError)
136+
mockClock.Set(fixedTime.Add(1 * time.Second))
137+
emitter.EmitError("err2", st.ErrorLevelWarning)
138+
mockClock.Set(fixedTime.Add(2 * time.Second))
139+
emitter.EmitError("err3", st.ErrorLevelError)
140+
141+
_, _, stateEvents := emitter.Subscribe()
142+
143+
var errorEvents []Event
144+
for _, ev := range stateEvents {
145+
if ev.Type == EventTypeError {
146+
errorEvents = append(errorEvents, ev)
147+
}
148+
}
149+
150+
assert.Len(t, errorEvents, 3)
151+
152+
expected := []ErrorBody{
153+
{Message: "err1", Level: st.ErrorLevelError, Time: fixedTime},
154+
{Message: "err2", Level: st.ErrorLevelWarning, Time: fixedTime.Add(1 * time.Second)},
155+
{Message: "err3", Level: st.ErrorLevelError, Time: fixedTime.Add(2 * time.Second)},
156+
}
157+
for i, ev := range errorEvents {
158+
body, ok := ev.Payload.(ErrorBody)
159+
assert.True(t, ok)
160+
assert.Equal(t, expected[i].Message, body.Message)
161+
assert.Equal(t, expected[i].Level, body.Level)
162+
assert.Equal(t, expected[i].Time, body.Time)
163+
}
164+
})
165+
102166
t.Run("clock-injection", func(t *testing.T) {
103167
mockClock := quartz.NewMock(t)
104168
fixedTime := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC)

lib/screentracker/pty_conversation.go

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -200,17 +200,21 @@ func (c *PTYConversation) Start(ctx context.Context) {
200200
c.initialPromptReady = true
201201
}
202202

203+
var loadErr string
203204
if c.initialPromptReady && c.loadStateStatus == LoadStatePending && c.cfg.StatePersistenceConfig.LoadState {
204205
if err := c.loadStateLocked(); err != nil {
205206
c.cfg.Logger.Error("Failed to load state", "error", err)
206-
c.emitter.EmitError(fmt.Sprintf("Failed to restore previous session: %v", err), ErrorLevelWarning)
207+
loadErr = fmt.Sprintf("Failed to restore previous session: %v", err)
207208
c.loadStateStatus = LoadStateFailed
208209
} else {
209210
c.loadStateStatus = LoadStateSucceeded
210211
}
211212
}
212213

213214
if c.initialPromptReady && len(c.cfg.InitialPrompt) > 0 && !c.initialPromptSent {
215+
// Safe to send under lock: the queue is guaranteed empty here because
216+
// statusLocked blocks Send until the snapshot buffer fills, which
217+
// cannot happen before this first enqueue completes.
214218
c.outboundQueue <- outboundMessage{parts: c.cfg.InitialPrompt, errCh: nil}
215219
c.initialPromptSent = true
216220
c.dirty = true
@@ -226,6 +230,9 @@ func (c *PTYConversation) Start(ctx context.Context) {
226230
}
227231
c.lock.Unlock()
228232

233+
if loadErr != "" {
234+
c.emitter.EmitError(loadErr, ErrorLevelWarning)
235+
}
229236
c.emitter.EmitStatus(status)
230237
c.emitter.EmitMessages(messages)
231238
c.emitter.EmitScreen(screen)
@@ -292,7 +299,8 @@ func (c *PTYConversation) updateLastAgentMessageLocked(screen string, timestamp
292299
if c.cfg.FormatMessage != nil {
293300
agentMessage = c.cfg.FormatMessage(agentMessage, lastUserMessage.Message)
294301
}
295-
if c.loadStateStatus == LoadStateSucceeded && !c.userSentMessageAfterLoadState && len(c.messages) > 0 {
302+
if c.loadStateStatus == LoadStateSucceeded && !c.userSentMessageAfterLoadState && len(c.messages) > 0 &&
303+
c.messages[len(c.messages)-1].Role == ConversationRoleAgent {
296304
agentMessage = c.messages[len(c.messages)-1].Message
297305
}
298306
if c.cfg.FormatToolCall != nil {
@@ -605,6 +613,12 @@ func (c *PTYConversation) SaveState() error {
605613
return xerrors.Errorf("failed to encode state: %w", err)
606614
}
607615

616+
// Flush to disk before rename for crash safety
617+
if err := f.Sync(); err != nil {
618+
_ = f.Close()
619+
return xerrors.Errorf("failed to sync state file: %w", err)
620+
}
621+
608622
// Close file before rename
609623
if err := f.Close(); err != nil {
610624
return xerrors.Errorf("failed to close temp state file: %w", err)
@@ -668,7 +682,10 @@ func (c *PTYConversation) loadStateLocked() error {
668682
c.initialPromptSent = agentState.InitialPromptSent
669683
if len(c.cfg.InitialPrompt) > 0 {
670684
isDifferent := buildStringFromMessageParts(c.cfg.InitialPrompt) != agentState.InitialPrompt
671-
c.initialPromptSent = !isDifferent
685+
if isDifferent {
686+
c.initialPromptSent = false
687+
}
688+
// If same prompt, keep agentState.InitialPromptSent
672689
} else if agentState.InitialPrompt != "" {
673690
c.cfg.InitialPrompt = []MessagePart{MessagePartText{
674691
Content: agentState.InitialPrompt,

0 commit comments

Comments
 (0)