Skip to content

Commit a0f8bb5

Browse files
committed
feat: implement state persistence
1 parent e5f1bda commit a0f8bb5

6 files changed

Lines changed: 287 additions & 54 deletions

File tree

cmd/server/server.go

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,26 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er
103103
}
104104
}
105105

106+
// Get the variables related to state management
107+
stateFile := viper.GetString(StateFile)
108+
loadState := true
109+
saveState := true
110+
if stateFile != "" {
111+
if !viper.IsSet(LoadState) {
112+
loadState = true
113+
} else {
114+
loadState = viper.GetBool(LoadState)
115+
}
116+
117+
if !viper.IsSet(SaveState) {
118+
saveState = true
119+
} else {
120+
saveState = viper.GetBool(SaveState)
121+
}
122+
}
123+
124+
pidFile := viper.GetString(PidFile)
125+
106126
printOpenAPI := viper.GetBool(FlagPrintOpenAPI)
107127
var process *termexec.Process
108128
if printOpenAPI {
@@ -128,7 +148,14 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er
128148
AllowedHosts: viper.GetStringSlice(FlagAllowedHosts),
129149
AllowedOrigins: viper.GetStringSlice(FlagAllowedOrigins),
130150
InitialPrompt: initialPrompt,
151+
StatePersistenceCfg: httpapi.StatePersistenceCfg{
152+
StateFile: stateFile,
153+
LoadState: loadState,
154+
SaveState: saveState,
155+
PidFile: pidFile,
156+
},
131157
})
158+
132159
if err != nil {
133160
return xerrors.Errorf("failed to create server: %w", err)
134161
}
@@ -137,6 +164,7 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er
137164
return nil
138165
}
139166
srv.StartSnapshotLoop(ctx)
167+
srv.HandleSignals(ctx, process)
140168
logger.Info("Starting server on port", "port", port)
141169
processExitCh := make(chan error, 1)
142170
go func() {
@@ -152,7 +180,7 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er
152180
logger.Error("Failed to stop server", "error", err)
153181
}
154182
}()
155-
if err := srv.Start(); err != nil && err != context.Canceled && err != http.ErrServerClosed {
183+
if err := srv.Start(); err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, http.ErrServerClosed) {
156184
return xerrors.Errorf("failed to start server: %w", err)
157185
}
158186
select {
@@ -191,6 +219,10 @@ const (
191219
FlagAllowedOrigins = "allowed-origins"
192220
FlagExit = "exit"
193221
FlagInitialPrompt = "initial-prompt"
222+
StateFile = "state-file"
223+
LoadState = "load-state"
224+
SaveState = "save-state"
225+
PidFile = "pid-file"
194226
)
195227

196228
func CreateServerCmd() *cobra.Command {
@@ -229,6 +261,10 @@ func CreateServerCmd() *cobra.Command {
229261
// localhost:3284 is the default origin when you open the chat interface in your browser. localhost:3000 and 3001 are used during development.
230262
{FlagAllowedOrigins, "o", []string{"http://localhost:3284", "http://localhost:3000", "http://localhost:3001"}, "HTTP allowed origins. Use '*' for all, comma-separated list via flag, space-separated list via AGENTAPI_ALLOWED_ORIGINS env var", "stringSlice"},
231263
{FlagInitialPrompt, "I", "", "Initial prompt for the agent. Recommended only if the agent doesn't support initial prompt in interaction mode. Will be read from stdin if piped (e.g., echo 'prompt' | agentapi server -- my-agent)", "string"},
264+
{StateFile, "s", "", "Path to file for saving/loading server state", "string"},
265+
{LoadState, "", false, "Load state from state-file on startup (defaults to true when state-file is set)", "bool"},
266+
{SaveState, "", false, "Save state to state-file on shutdown (defaults to true when state-file is set)", "bool"},
267+
{PidFile, "", "", "Path to file where the server process ID will be written for shutdown scripts", "string"},
232268
}
233269

234270
for _, spec := range flagSpecs {

lib/httpapi/events.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ func (e *EventEmitter) notifyChannels(eventType EventType, payload any) {
120120
}
121121
}
122122

123-
// Assumes that only the last message can change or new messages can be added.
123+
// UpdateMessagesAndEmitChanges assumes that only the last message can change or new messages can be added.
124124
// If a new message is injected between existing messages (identified by Id), the behavior is undefined.
125125
func (e *EventEmitter) UpdateMessagesAndEmitChanges(newMessages []st.ConversationMessage) {
126126
e.mu.Lock()

lib/httpapi/server.go

Lines changed: 124 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,13 @@ import (
1111
"net/http"
1212
"net/url"
1313
"os"
14+
"os/signal"
1415
"path/filepath"
1516
"slices"
1617
"sort"
1718
"strings"
1819
"sync"
20+
"syscall"
1921
"time"
2022
"unicode"
2123

@@ -34,18 +36,20 @@ import (
3436

3537
// Server represents the HTTP server
3638
type Server struct {
37-
router chi.Router
38-
api huma.API
39-
port int
40-
srv *http.Server
41-
mu sync.RWMutex
42-
logger *slog.Logger
43-
conversation *st.PTYConversation
44-
agentio *termexec.Process
45-
agentType mf.AgentType
46-
emitter *EventEmitter
47-
chatBasePath string
48-
tempDir string
39+
router chi.Router
40+
api huma.API
41+
port int
42+
srv *http.Server
43+
mu sync.RWMutex
44+
logger *slog.Logger
45+
conversation *st.PTYConversation
46+
agentio *termexec.Process
47+
agentType mf.AgentType
48+
emitter *EventEmitter
49+
chatBasePath string
50+
tempDir string
51+
statePersistenceCfg StatePersistenceCfg
52+
stateLoadComplete bool
4953
}
5054

5155
func (s *Server) NormalizeSchema(schema any) any {
@@ -94,14 +98,22 @@ func (s *Server) GetOpenAPI() string {
9498
// because the action of taking a snapshot takes time too.
9599
const snapshotInterval = 25 * time.Millisecond
96100

101+
type StatePersistenceCfg struct {
102+
StateFile string
103+
LoadState bool
104+
SaveState bool
105+
PidFile string
106+
}
107+
97108
type ServerConfig struct {
98-
AgentType mf.AgentType
99-
Process *termexec.Process
100-
Port int
101-
ChatBasePath string
102-
AllowedHosts []string
103-
AllowedOrigins []string
104-
InitialPrompt string
109+
AgentType mf.AgentType
110+
Process *termexec.Process
111+
Port int
112+
ChatBasePath string
113+
AllowedHosts []string
114+
AllowedOrigins []string
115+
InitialPrompt string
116+
StatePersistenceCfg StatePersistenceCfg
105117
}
106118

107119
// Validate allowed hosts don't contain whitespace, commas, schemes, or ports.
@@ -260,16 +272,18 @@ func NewServer(ctx context.Context, config ServerConfig) (*Server, error) {
260272
logger.Info("Created temporary directory for uploads", "tempDir", tempDir)
261273

262274
s := &Server{
263-
router: router,
264-
api: api,
265-
port: config.Port,
266-
conversation: conversation,
267-
logger: logger,
268-
agentio: config.Process,
269-
agentType: config.AgentType,
270-
emitter: emitter,
271-
chatBasePath: strings.TrimSuffix(config.ChatBasePath, "/"),
272-
tempDir: tempDir,
275+
router: router,
276+
api: api,
277+
port: config.Port,
278+
conversation: conversation,
279+
logger: logger,
280+
agentio: config.Process,
281+
agentType: config.AgentType,
282+
emitter: emitter,
283+
chatBasePath: strings.TrimSuffix(config.ChatBasePath, "/"),
284+
tempDir: tempDir,
285+
statePersistenceCfg: config.StatePersistenceCfg,
286+
stateLoadComplete: false,
273287
}
274288

275289
// Register API routes
@@ -337,15 +351,26 @@ func (s *Server) StartSnapshotLoop(ctx context.Context) {
337351
currentStatus := s.conversation.Status()
338352

339353
// Send initial prompt when agent becomes stable for the first time
340-
if !s.conversation.InitialPromptSent && convertStatus(currentStatus) == AgentStatusStable {
341-
342-
if err := s.conversation.Send(FormatMessage(s.agentType, s.conversation.InitialPrompt)...); err != nil {
343-
s.logger.Error("Failed to send initial prompt", "error", err)
344-
} else {
345-
s.conversation.InitialPromptSent = true
346-
s.conversation.ReadyForInitialPrompt = false
347-
currentStatus = st.ConversationStatusChanging
348-
s.logger.Info("Initial prompt sent successfully")
354+
if convertStatus(currentStatus) == AgentStatusStable {
355+
356+
if !s.stateLoadComplete && s.statePersistenceCfg.LoadState {
357+
_, err := s.conversation.LoadState(s.statePersistenceCfg.StateFile)
358+
if err != nil {
359+
s.logger.Warn("Failed to load state file", "path", s.statePersistenceCfg.StateFile, "err", err)
360+
} else {
361+
s.logger.Info("Successfully loaded state", "path", s.statePersistenceCfg.StateFile)
362+
}
363+
s.stateLoadComplete = true
364+
}
365+
if !s.conversation.InitialPromptSent {
366+
if err := s.conversation.Send(FormatMessage(s.agentType, s.conversation.InitialPrompt)...); err != nil {
367+
s.logger.Error("Failed to send initial prompt", "error", err)
368+
} else {
369+
s.conversation.InitialPromptSent = true
370+
s.conversation.ReadyForInitialPrompt = false
371+
currentStatus = st.ConversationStatusChanging
372+
s.logger.Info("Initial prompt sent successfully")
373+
}
349374
}
350375
}
351376
s.emitter.UpdateStatusAndEmitChanges(currentStatus, s.agentType)
@@ -592,6 +617,15 @@ func (s *Server) Start() error {
592617

593618
// Stop gracefully stops the HTTP server
594619
func (s *Server) Stop(ctx context.Context) error {
620+
// Save conversation state if configured
621+
if s.statePersistenceCfg.SaveState && s.statePersistenceCfg.StateFile != "" {
622+
if err := s.conversation.SaveState(s.conversation.Messages(), s.statePersistenceCfg.StateFile); err != nil {
623+
s.logger.Error("Failed to save conversation state", "error", err)
624+
} else {
625+
s.logger.Info("Saved conversation state", "stateFile", s.statePersistenceCfg.StateFile)
626+
}
627+
}
628+
595629
// Clean up temporary directory
596630
s.cleanupTempDir()
597631

@@ -610,6 +644,58 @@ func (s *Server) cleanupTempDir() {
610644
}
611645
}
612646

647+
// HandleSignals sets up signal handlers for:
648+
// - SIGTERM, SIGINT, SIGHUP: save conversation state, then close the process
649+
// - SIGUSR1: save conversation state without exiting
650+
func (s *Server) HandleSignals(ctx context.Context, process *termexec.Process) {
651+
// Handle shutdown signals (SIGTERM, SIGINT, SIGHUP)
652+
shutdownCh := make(chan os.Signal, 1)
653+
signal.Notify(shutdownCh, os.Interrupt, syscall.SIGTERM, syscall.SIGHUP)
654+
go func() {
655+
sig := <-shutdownCh
656+
s.logger.Info("Received shutdown signal, saving state before closing process", "signal", sig)
657+
658+
// Save conversation state if configured (synchronously before closing process)
659+
if s.statePersistenceCfg.SaveState && s.statePersistenceCfg.StateFile != "" {
660+
if err := s.conversation.SaveState(s.conversation.Messages(), s.statePersistenceCfg.StateFile); err != nil {
661+
s.logger.Error("Failed to save conversation state on signal", "signal", sig, "error", err)
662+
} else {
663+
s.logger.Info("Saved conversation state on signal", "signal", sig, "stateFile", s.statePersistenceCfg.StateFile)
664+
}
665+
}
666+
667+
// Now close the process
668+
if err := process.Close(s.logger, 5*time.Second); err != nil {
669+
s.logger.Error("Error closing process", "signal", sig, "error", err)
670+
}
671+
}()
672+
673+
// Handle SIGUSR1 for save without exit
674+
saveOnlyCh := make(chan os.Signal, 1)
675+
signal.Notify(saveOnlyCh, syscall.SIGUSR1)
676+
go func() {
677+
for {
678+
select {
679+
case <-saveOnlyCh:
680+
s.logger.Info("Received SIGUSR1, saving state without exiting")
681+
682+
// Save conversation state if configured
683+
if s.statePersistenceCfg.SaveState && s.statePersistenceCfg.StateFile != "" {
684+
if err := s.conversation.SaveState(s.conversation.Messages(), s.statePersistenceCfg.StateFile); err != nil {
685+
s.logger.Error("Failed to save conversation state on SIGUSR1", "error", err)
686+
} else {
687+
s.logger.Info("Saved conversation state on SIGUSR1", "stateFile", s.statePersistenceCfg.StateFile)
688+
}
689+
} else {
690+
s.logger.Warn("SIGUSR1 received but state saving is not configured")
691+
}
692+
case <-ctx.Done():
693+
return
694+
}
695+
}
696+
}()
697+
}
698+
613699
// registerStaticFileRoutes sets up routes for serving static files
614700
func (s *Server) registerStaticFileRoutes() {
615701
chatHandler := FileServerWithIndexFallback(s.chatBasePath)

lib/httpapi/setup.go

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,7 @@ import (
44
"context"
55
"fmt"
66
"os"
7-
"os/signal"
87
"strings"
9-
"syscall"
10-
"time"
118

129
"github.com/coder/agentapi/lib/logctx"
1310
mf "github.com/coder/agentapi/lib/msgfmt"
@@ -45,16 +42,5 @@ func SetupProcess(ctx context.Context, config SetupProcessConfig) (*termexec.Pro
4542
return nil, err
4643
}
4744
}
48-
49-
// Handle SIGINT (Ctrl+C) and send it to the process
50-
signalCh := make(chan os.Signal, 1)
51-
signal.Notify(signalCh, os.Interrupt, syscall.SIGTERM)
52-
go func() {
53-
<-signalCh
54-
if err := process.Close(logger, 5*time.Second); err != nil {
55-
logger.Error("Error closing process", "error", err)
56-
}
57-
}()
58-
5945
return process, nil
6046
}

lib/screentracker/conversation.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@ type MessagePart interface {
5252
// Conversation allows tracking of a conversation between a user and an agent.
5353
type Conversation interface {
5454
Messages() []ConversationMessage
55+
SaveState([]ConversationMessage, string) error
56+
LoadState(string) ([]ConversationMessage, error)
5557
Snapshot(string)
5658
Start(context.Context)
5759
Status() ConversationStatus
@@ -64,3 +66,10 @@ type ConversationMessage struct {
6466
Role ConversationRole
6567
Time time.Time
6668
}
69+
70+
type AgentState struct {
71+
Version int `json:"version"`
72+
Messages []ConversationMessage `json:"messages"`
73+
InitialPrompt string `json:"initial_prompt"`
74+
InitialPromptSent bool `json:"initial_prompt_sent"`
75+
}

0 commit comments

Comments
 (0)