@@ -8,9 +8,13 @@ import (
88 "log/slog"
99 "net/http"
1010 "os"
11+ "path/filepath"
1112 "sort"
13+ "strconv"
1214 "strings"
15+ "time"
1316
17+ "github.com/coder/agentapi/lib/screentracker"
1418 "github.com/mattn/go-isatty"
1519 "github.com/spf13/cobra"
1620 "github.com/spf13/viper"
@@ -104,9 +108,51 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er
104108 }
105109 }
106110
107- printOpenAPI := viper .GetBool (FlagPrintOpenAPI )
111+ // Get the variables related to state management
112+ stateFile := viper .GetString (FlagStateFile )
113+ loadState := false
114+ saveState := false
115+
116+ // Validate state file configuration
117+ if stateFile != "" {
118+ if ! viper .IsSet (FlagLoadState ) {
119+ loadState = true
120+ } else {
121+ loadState = viper .GetBool (FlagLoadState )
122+ }
123+
124+ if ! viper .IsSet (FlagSaveState ) {
125+ saveState = true
126+ } else {
127+ saveState = viper .GetBool (FlagSaveState )
128+ }
129+ } else {
130+ if viper .IsSet (FlagLoadState ) && viper .GetBool (FlagLoadState ) {
131+ return xerrors .Errorf ("--load-state requires --state-file to be set" )
132+ }
133+ if viper .IsSet (FlagSaveState ) && viper .GetBool (FlagSaveState ) {
134+ return xerrors .Errorf ("--save-state requires --state-file to be set" )
135+ }
136+ }
137+
108138 experimentalACP := viper .GetBool (FlagExperimentalACP )
109139
140+ if experimentalACP && (saveState || loadState ) {
141+ return xerrors .Errorf ("ACP mode doesn't support state persistence" )
142+ }
143+
144+ pidFile := viper .GetString (FlagPidFile )
145+
146+ // Write PID file if configured
147+ if pidFile != "" {
148+ if err := writePIDFile (pidFile , logger ); err != nil {
149+ return xerrors .Errorf ("failed to write PID file: %w" , err )
150+ }
151+ defer cleanupPIDFile (pidFile , logger )
152+ }
153+
154+ printOpenAPI := viper .GetBool (FlagPrintOpenAPI )
155+
110156 if printOpenAPI && experimentalACP {
111157 return xerrors .Errorf ("flags --%s and --%s are mutually exclusive" , FlagPrintOpenAPI , FlagExperimentalACP )
112158 }
@@ -154,33 +200,45 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er
154200 AllowedHosts : viper .GetStringSlice (FlagAllowedHosts ),
155201 AllowedOrigins : viper .GetStringSlice (FlagAllowedOrigins ),
156202 InitialPrompt : initialPrompt ,
203+ StatePersistenceConfig : screentracker.StatePersistenceConfig {
204+ StateFile : stateFile ,
205+ LoadState : loadState ,
206+ SaveState : saveState ,
207+ },
157208 })
209+
158210 if err != nil {
159211 return xerrors .Errorf ("failed to create server: %w" , err )
160212 }
161213 if printOpenAPI {
162214 fmt .Println (srv .GetOpenAPI ())
163215 return nil
164216 }
217+
218+ // Create a context for graceful shutdown
219+ gracefulCtx , gracefulCancel := context .WithCancel (ctx )
220+ defer gracefulCancel ()
221+
222+ // Setup signal handlers (they will call gracefulCancel)
223+ handleSignals (gracefulCtx , gracefulCancel , logger , srv )
224+
165225 logger .Info ("Starting server on port" , "port" , port )
226+
227+ // Monitor process exit
166228 processExitCh := make (chan error , 1 )
167- // Wait for process exit in PTY mode
168229 if process != nil {
169230 go func () {
170231 defer close (processExitCh )
232+ defer gracefulCancel ()
171233 if err := process .Wait (); err != nil {
172234 if errors .Is (err , termexec .ErrNonZeroExitCode ) {
173235 processExitCh <- xerrors .Errorf ("========\n %s\n ========\n : %w" , strings .TrimSpace (process .ReadScreen ()), err )
174236 } else {
175237 processExitCh <- xerrors .Errorf ("failed to wait for process: %w" , err )
176238 }
177239 }
178- if err := srv .Stop (ctx ); err != nil {
179- logger .Error ("Failed to stop server" , "error" , err )
180- }
181240 }()
182241 }
183- // Wait for process exit in ACP mode
184242 if acpResult != nil {
185243 go func () {
186244 defer close (processExitCh )
@@ -193,13 +251,45 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er
193251 }
194252 }()
195253 }
196- if err := srv .Start (); err != nil && err != context .Canceled && err != http .ErrServerClosed {
197- return xerrors .Errorf ("failed to start server: %w" , err )
254+
255+ // Start the server
256+ serverErrCh := make (chan error , 1 )
257+ go func () {
258+ defer close (serverErrCh )
259+ if err := srv .Start (); err != nil && ! errors .Is (err , context .Canceled ) && ! errors .Is (err , http .ErrServerClosed ) {
260+ serverErrCh <- err
261+ }
262+ }()
263+
264+ select {
265+ case err := <- serverErrCh :
266+ if err != nil {
267+ return xerrors .Errorf ("failed to start server: %w" , err )
268+ }
269+ case <- gracefulCtx .Done ():
270+ }
271+
272+ if err := srv .SaveState ("shutdown" ); err != nil {
273+ logger .Error ("Failed to save state during shutdown" , "error" , err )
274+ }
275+
276+ // Stop the HTTP server
277+ shutdownCtx , cancel := context .WithTimeout (context .Background (), 5 * time .Second )
278+ defer cancel ()
279+ if err := srv .Stop (shutdownCtx ); err != nil {
280+ logger .Error ("Failed to stop HTTP server" , "error" , err )
198281 }
282+
199283 select {
200284 case err := <- processExitCh :
201- return xerrors .Errorf ("agent exited with error: %w" , err )
285+ if err != nil {
286+ return xerrors .Errorf ("agent exited with error: %w" , err )
287+ }
202288 default :
289+ // Close the process
290+ if err := process .Close (logger , 5 * time .Second ); err != nil {
291+ logger .Error ("Failed to close process cleanly" , "error" , err )
292+ }
203293 }
204294 return nil
205295}
@@ -213,6 +303,61 @@ var agentNames = (func() []string {
213303 return names
214304})()
215305
306+ // writePIDFile writes the current process ID to the specified file
307+ func writePIDFile (pidFile string , logger * slog.Logger ) error {
308+ pid := os .Getpid ()
309+ pidContent := fmt .Sprintf ("%d\n " , pid )
310+
311+ // Create directory if it doesn't exist
312+ dir := filepath .Dir (pidFile )
313+ if err := os .MkdirAll (dir , 0o700 ); err != nil {
314+ return xerrors .Errorf ("failed to create PID file directory: %w" , err )
315+ }
316+
317+ // Check if PID file already exists
318+ if existingPIDData , err := os .ReadFile (pidFile ); err == nil {
319+ existingPIDStr := strings .TrimSpace (string (existingPIDData ))
320+ if existingPID , err := strconv .Atoi (existingPIDStr ); err == nil {
321+ if isProcessRunning (existingPID ) {
322+ return xerrors .Errorf ("another instance is already running with PID %d (PID file: %s)" , existingPID , pidFile )
323+ }
324+ logger .Warn ("Found stale PID file, will overwrite" , "pidFile" , pidFile , "stalePID" , existingPID )
325+ }
326+ } else if ! os .IsNotExist (err ) {
327+ return xerrors .Errorf ("failed to read existing PID file: %w" , err )
328+ }
329+
330+ // Write PID file
331+ if err := os .WriteFile (pidFile , []byte (pidContent ), 0o600 ); err != nil {
332+ return xerrors .Errorf ("failed to write PID file: %w" , err )
333+ }
334+
335+ logger .Info ("Wrote PID file" , "pidFile" , pidFile , "pid" , pid )
336+ return nil
337+ }
338+
// cleanupPIDFile deletes pidFile, but only when the PID recorded in it is the
// PID of the current process.
//
// A missing file is ignored silently. Any other read failure is logged as an
// error. If the file's contents do not parse as an integer, or name a
// different PID, the file is left in place and an informational message is
// logged. Removal failures other than "file does not exist" are logged as
// errors; nothing is returned to the caller.
func cleanupPIDFile(pidFile string, logger *slog.Logger) {
	raw, readErr := os.ReadFile(pidFile)
	if readErr != nil {
		// Nothing to clean up when the file is already gone.
		if os.IsNotExist(readErr) {
			return
		}
		logger.Error("Failed to read PID file for cleanup", "pidFile", pidFile, "error", readErr)
		return
	}

	recorded := strings.TrimSpace(string(raw))
	owner, parseErr := strconv.Atoi(recorded)
	ownedByUs := parseErr == nil && owner == os.Getpid()
	if !ownedByUs {
		logger.Info("PID file belongs to another process, skipping cleanup", "pidFile", pidFile, "filePID", recorded)
		return
	}

	switch rmErr := os.Remove(pidFile); {
	case rmErr == nil:
		logger.Info("Removed PID file", "pidFile", pidFile)
	case !os.IsNotExist(rmErr):
		logger.Error("Failed to remove PID file", "pidFile", pidFile, "error", rmErr)
	}
}
360+
216361type flagSpec struct {
217362 name string
218363 shorthand string
@@ -232,6 +377,10 @@ const (
232377 FlagAllowedOrigins = "allowed-origins"
233378 FlagExit = "exit"
234379 FlagInitialPrompt = "initial-prompt"
380+ FlagStateFile = "state-file"
381+ FlagLoadState = "load-state"
382+ FlagSaveState = "save-state"
383+ FlagPidFile = "pid-file"
235384 FlagExperimentalACP = "experimental-acp"
236385)
237386
@@ -271,6 +420,10 @@ func CreateServerCmd() *cobra.Command {
271420 // localhost:3284 is the default origin when you open the chat interface in your browser. localhost:3000 and 3001 are used during development.
272421 {FlagAllowedOrigins , "o" , []string {"http://localhost:3284" , "http://localhost:3000" , "http://localhost:3001" }, "HTTP allowed origins. Use '*' for all, comma-separated list via flag, space-separated list via AGENTAPI_ALLOWED_ORIGINS env var" , "stringSlice" },
273422 {FlagInitialPrompt , "I" , "" , "Initial prompt for the agent. Recommended only if the agent doesn't support initial prompt in interaction mode. Will be read from stdin if piped (e.g., echo 'prompt' | agentapi server -- my-agent)" , "string" },
423+ {FlagStateFile , "s" , "" , "Path to file for saving/loading server state" , "string" },
424+ {FlagLoadState , "" , false , "Load state from state-file on startup (defaults to true when state-file is set)" , "bool" },
425+ {FlagSaveState , "" , false , "Save state to state-file on shutdown (defaults to true when state-file is set)" , "bool" },
426+ {FlagPidFile , "" , "" , "Path to file where the server process ID will be written for shutdown scripts" , "string" },
274427 {FlagExperimentalACP , "" , false , "Use experimental ACP transport instead of PTY" , "bool" },
275428 }
276429
0 commit comments