Skip to content

Commit 2f83ccc

Browse files
35C4n0r, johnstcn, mafredri
authored
feat: implement state persistence (#177)
Co-authored-by: Cian Johnston <cian@coder.com> Co-authored-by: Mathias Fredriksson <mafredri@gmail.com>
1 parent 5881f7d commit 2f83ccc

26 files changed

+2435
-106
lines changed

chat/src/app/layout.tsx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ export default function RootLayout({
2929
disableTransitionOnChange
3030
>
3131
{children}
32-
<Toaster richColors />
32+
<Toaster richColors closeButton />
3333
</ThemeProvider>
3434
</body>
3535
</html>

chat/src/components/chat-provider.tsx

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,12 @@ interface StatusChangeEvent {
3636
agent_type: string;
3737
}
3838

39+
// Payload parsed from the data of an "agent_error" server-sent event.
interface ErrorEventData {
40+
message: string; // Text displayed in the toast notification.
41+
level: string; // "error" and "warning" select matching toast styles; any other value is shown as info.
42+
time: string; // Not read by the agent_error handler; presumably when the error occurred — format not visible here, confirm against the server.
43+
}
44+
3945
interface APIErrorDetail {
4046
location: string;
4147
message: string;
@@ -215,6 +221,25 @@ export function ChatProvider({ children }: PropsWithChildren) {
215221
setAgentType(data.agent_type === "" ? "unknown" : data.agent_type as AgentType);
216222
});
217223

224+
// Handle agent error events
225+
eventSource.addEventListener("agent_error", (event) => {
226+
const messageEvent = event as MessageEvent;
227+
try {
228+
const data: ErrorEventData = JSON.parse(messageEvent.data);
229+
230+
// Display error as toast notification that persists until manually dismissed
231+
if (data.level === "error") {
232+
toast.error(data.message, { duration: Infinity });
233+
} else if (data.level === "warning") {
234+
toast.warning(data.message, { duration: Infinity });
235+
} else {
236+
toast.info(data.message, { duration: Infinity });
237+
}
238+
} catch (e) {
239+
console.error("Failed to parse agent_error event data:", e);
240+
}
241+
});
242+
218243
// Handle connection open (server is online)
219244
eventSource.onopen = () => {
220245
// Connection is established, but we'll wait for status_change event

cmd/server/process_unix.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
//go:build unix
2+
3+
package server
4+
5+
import (
6+
"errors"
7+
"os"
8+
"syscall"
9+
)
10+
11+
// isProcessRunning reports whether a process with the given PID currently
// exists.
//
// Liveness is probed by sending signal 0, which asks the kernel to perform its
// existence and permission checks without delivering a signal. EPERM counts
// as "running": the process exists but belongs to another user.
//
// Non-positive PIDs never identify a single process (kill(2) interprets them
// as process-group or broadcast targets), so they are rejected up front
// rather than probed. Without this guard a PID file containing a negative
// number could be mistaken for a live instance.
func isProcessRunning(pid int) bool {
	if pid <= 0 {
		return false
	}
	process, err := os.FindProcess(pid)
	if err != nil {
		return false
	}
	err = process.Signal(syscall.Signal(0))
	return err == nil || errors.Is(err, syscall.EPERM)
}

cmd/server/process_windows.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
//go:build windows
2+
3+
package server
4+
5+
// isProcessRunning reports whether a process with the given PID is alive.
//
// The signal-0 probe used on Unix is not supported on Windows, so this build
// performs no check at all and unconditionally answers "not running". PID
// file liveness detection is therefore best-effort on this platform.
func isProcessRunning(_ int) bool {
	const alive = false
	return alive
}

cmd/server/server.go

Lines changed: 162 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,13 @@ import (
88
"log/slog"
99
"net/http"
1010
"os"
11+
"path/filepath"
1112
"sort"
13+
"strconv"
1214
"strings"
15+
"time"
1316

17+
"github.com/coder/agentapi/lib/screentracker"
1418
"github.com/mattn/go-isatty"
1519
"github.com/spf13/cobra"
1620
"github.com/spf13/viper"
@@ -104,9 +108,51 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er
104108
}
105109
}
106110

107-
printOpenAPI := viper.GetBool(FlagPrintOpenAPI)
111+
// Get the variables related to state management
112+
stateFile := viper.GetString(FlagStateFile)
113+
loadState := false
114+
saveState := false
115+
116+
// Validate state file configuration
117+
if stateFile != "" {
118+
if !viper.IsSet(FlagLoadState) {
119+
loadState = true
120+
} else {
121+
loadState = viper.GetBool(FlagLoadState)
122+
}
123+
124+
if !viper.IsSet(FlagSaveState) {
125+
saveState = true
126+
} else {
127+
saveState = viper.GetBool(FlagSaveState)
128+
}
129+
} else {
130+
if viper.IsSet(FlagLoadState) && viper.GetBool(FlagLoadState) {
131+
return xerrors.Errorf("--load-state requires --state-file to be set")
132+
}
133+
if viper.IsSet(FlagSaveState) && viper.GetBool(FlagSaveState) {
134+
return xerrors.Errorf("--save-state requires --state-file to be set")
135+
}
136+
}
137+
108138
experimentalACP := viper.GetBool(FlagExperimentalACP)
109139

140+
if experimentalACP && (saveState || loadState) {
141+
return xerrors.Errorf("ACP mode doesn't support state persistence")
142+
}
143+
144+
pidFile := viper.GetString(FlagPidFile)
145+
146+
// Write PID file if configured
147+
if pidFile != "" {
148+
if err := writePIDFile(pidFile, logger); err != nil {
149+
return xerrors.Errorf("failed to write PID file: %w", err)
150+
}
151+
defer cleanupPIDFile(pidFile, logger)
152+
}
153+
154+
printOpenAPI := viper.GetBool(FlagPrintOpenAPI)
155+
110156
if printOpenAPI && experimentalACP {
111157
return xerrors.Errorf("flags --%s and --%s are mutually exclusive", FlagPrintOpenAPI, FlagExperimentalACP)
112158
}
@@ -154,33 +200,45 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er
154200
AllowedHosts: viper.GetStringSlice(FlagAllowedHosts),
155201
AllowedOrigins: viper.GetStringSlice(FlagAllowedOrigins),
156202
InitialPrompt: initialPrompt,
203+
StatePersistenceConfig: screentracker.StatePersistenceConfig{
204+
StateFile: stateFile,
205+
LoadState: loadState,
206+
SaveState: saveState,
207+
},
157208
})
209+
158210
if err != nil {
159211
return xerrors.Errorf("failed to create server: %w", err)
160212
}
161213
if printOpenAPI {
162214
fmt.Println(srv.GetOpenAPI())
163215
return nil
164216
}
217+
218+
// Create a context for graceful shutdown
219+
gracefulCtx, gracefulCancel := context.WithCancel(ctx)
220+
defer gracefulCancel()
221+
222+
// Setup signal handlers (they will call gracefulCancel)
223+
handleSignals(gracefulCtx, gracefulCancel, logger, srv)
224+
165225
logger.Info("Starting server on port", "port", port)
226+
227+
// Monitor process exit
166228
processExitCh := make(chan error, 1)
167-
// Wait for process exit in PTY mode
168229
if process != nil {
169230
go func() {
170231
defer close(processExitCh)
232+
defer gracefulCancel()
171233
if err := process.Wait(); err != nil {
172234
if errors.Is(err, termexec.ErrNonZeroExitCode) {
173235
processExitCh <- xerrors.Errorf("========\n%s\n========\n: %w", strings.TrimSpace(process.ReadScreen()), err)
174236
} else {
175237
processExitCh <- xerrors.Errorf("failed to wait for process: %w", err)
176238
}
177239
}
178-
if err := srv.Stop(ctx); err != nil {
179-
logger.Error("Failed to stop server", "error", err)
180-
}
181240
}()
182241
}
183-
// Wait for process exit in ACP mode
184242
if acpResult != nil {
185243
go func() {
186244
defer close(processExitCh)
@@ -193,13 +251,45 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er
193251
}
194252
}()
195253
}
196-
if err := srv.Start(); err != nil && err != context.Canceled && err != http.ErrServerClosed {
197-
return xerrors.Errorf("failed to start server: %w", err)
254+
255+
// Start the server
256+
serverErrCh := make(chan error, 1)
257+
go func() {
258+
defer close(serverErrCh)
259+
if err := srv.Start(); err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, http.ErrServerClosed) {
260+
serverErrCh <- err
261+
}
262+
}()
263+
264+
select {
265+
case err := <-serverErrCh:
266+
if err != nil {
267+
return xerrors.Errorf("failed to start server: %w", err)
268+
}
269+
case <-gracefulCtx.Done():
270+
}
271+
272+
if err := srv.SaveState("shutdown"); err != nil {
273+
logger.Error("Failed to save state during shutdown", "error", err)
274+
}
275+
276+
// Stop the HTTP server
277+
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
278+
defer cancel()
279+
if err := srv.Stop(shutdownCtx); err != nil {
280+
logger.Error("Failed to stop HTTP server", "error", err)
198281
}
282+
199283
select {
200284
case err := <-processExitCh:
201-
return xerrors.Errorf("agent exited with error: %w", err)
285+
if err != nil {
286+
return xerrors.Errorf("agent exited with error: %w", err)
287+
}
202288
default:
289+
// Close the process
290+
if err := process.Close(logger, 5*time.Second); err != nil {
291+
logger.Error("Failed to close process cleanly", "error", err)
292+
}
203293
}
204294
return nil
205295
}
@@ -213,6 +303,61 @@ var agentNames = (func() []string {
213303
return names
214304
})()
215305

306+
// writePIDFile writes the current process ID to the specified file
307+
func writePIDFile(pidFile string, logger *slog.Logger) error {
308+
pid := os.Getpid()
309+
pidContent := fmt.Sprintf("%d\n", pid)
310+
311+
// Create directory if it doesn't exist
312+
dir := filepath.Dir(pidFile)
313+
if err := os.MkdirAll(dir, 0o700); err != nil {
314+
return xerrors.Errorf("failed to create PID file directory: %w", err)
315+
}
316+
317+
// Check if PID file already exists
318+
if existingPIDData, err := os.ReadFile(pidFile); err == nil {
319+
existingPIDStr := strings.TrimSpace(string(existingPIDData))
320+
if existingPID, err := strconv.Atoi(existingPIDStr); err == nil {
321+
if isProcessRunning(existingPID) {
322+
return xerrors.Errorf("another instance is already running with PID %d (PID file: %s)", existingPID, pidFile)
323+
}
324+
logger.Warn("Found stale PID file, will overwrite", "pidFile", pidFile, "stalePID", existingPID)
325+
}
326+
} else if !os.IsNotExist(err) {
327+
return xerrors.Errorf("failed to read existing PID file: %w", err)
328+
}
329+
330+
// Write PID file
331+
if err := os.WriteFile(pidFile, []byte(pidContent), 0o600); err != nil {
332+
return xerrors.Errorf("failed to write PID file: %w", err)
333+
}
334+
335+
logger.Info("Wrote PID file", "pidFile", pidFile, "pid", pid)
336+
return nil
337+
}
338+
339+
// cleanupPIDFile deletes pidFile, but only when the PID recorded inside it is
// the current process's PID. A file that is missing is ignored silently; a
// file that cannot be read, or that names a different (or unparseable) PID,
// is left in place and the reason is logged. Failures are only logged, never
// returned, so the function is safe to use in a defer.
func cleanupPIDFile(pidFile string, logger *slog.Logger) {
	raw, readErr := os.ReadFile(pidFile)
	switch {
	case readErr == nil:
		// File is readable; continue to the ownership check.
	case os.IsNotExist(readErr):
		return
	default:
		logger.Error("Failed to read PID file for cleanup", "pidFile", pidFile, "error", readErr)
		return
	}

	recorded := strings.TrimSpace(string(raw))
	owner, parseErr := strconv.Atoi(recorded)
	ownedByUs := parseErr == nil && owner == os.Getpid()
	if !ownedByUs {
		logger.Info("PID file belongs to another process, skipping cleanup", "pidFile", pidFile, "filePID", recorded)
		return
	}

	// A file that vanished between the read and the remove is not an error.
	switch rmErr := os.Remove(pidFile); {
	case rmErr == nil:
		logger.Info("Removed PID file", "pidFile", pidFile)
	case !os.IsNotExist(rmErr):
		logger.Error("Failed to remove PID file", "pidFile", pidFile, "error", rmErr)
	}
}
360+
216361
type flagSpec struct {
217362
name string
218363
shorthand string
@@ -232,6 +377,10 @@ const (
232377
FlagAllowedOrigins = "allowed-origins"
233378
FlagExit = "exit"
234379
FlagInitialPrompt = "initial-prompt"
380+
FlagStateFile = "state-file"
381+
FlagLoadState = "load-state"
382+
FlagSaveState = "save-state"
383+
FlagPidFile = "pid-file"
235384
FlagExperimentalACP = "experimental-acp"
236385
)
237386

@@ -271,6 +420,10 @@ func CreateServerCmd() *cobra.Command {
271420
// localhost:3284 is the default origin when you open the chat interface in your browser. localhost:3000 and 3001 are used during development.
272421
{FlagAllowedOrigins, "o", []string{"http://localhost:3284", "http://localhost:3000", "http://localhost:3001"}, "HTTP allowed origins. Use '*' for all, comma-separated list via flag, space-separated list via AGENTAPI_ALLOWED_ORIGINS env var", "stringSlice"},
273422
{FlagInitialPrompt, "I", "", "Initial prompt for the agent. Recommended only if the agent doesn't support initial prompt in interaction mode. Will be read from stdin if piped (e.g., echo 'prompt' | agentapi server -- my-agent)", "string"},
423+
{FlagStateFile, "s", "", "Path to file for saving/loading server state", "string"},
424+
{FlagLoadState, "", false, "Load state from state-file on startup (defaults to true when state-file is set)", "bool"},
425+
{FlagSaveState, "", false, "Save state to state-file on shutdown (defaults to true when state-file is set)", "bool"},
426+
{FlagPidFile, "", "", "Path to file where the server process ID will be written for shutdown scripts", "string"},
274427
{FlagExperimentalACP, "", false, "Use experimental ACP transport instead of PTY", "bool"},
275428
}
276429

0 commit comments

Comments
 (0)