Skip to content

Commit 759ec53

Browse files
committed
feat: improved initial prompt handling
1 parent 9d7eb5a commit 759ec53

3 files changed

Lines changed: 379 additions & 44 deletions

File tree

lib/screentracker/conversation.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package screentracker
22

33
import (
44
"context"
5+
"strings"
56
"time"
67

78
"github.com/coder/agentapi/lib/util"
@@ -49,6 +50,14 @@ type MessagePart interface {
4950
String() string
5051
}
5152

53+
func buildStringFromMessageParts(parts []MessagePart) string {
54+
var sb strings.Builder
55+
for _, part := range parts {
56+
sb.WriteString(part.String())
57+
}
58+
return sb.String()
59+
}
60+
5261
// Conversation represents a conversation between a user and an agent.
5362
// It is intended as the primary interface for interacting with a session.
5463
// Implementations must support the following capabilities:

lib/screentracker/pty_conversation.go

Lines changed: 35 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,10 @@ type MessagePartText struct {
3131
}
3232

3333
type AgentState struct {
34-
Version int `json:"version"`
35-
Messages []ConversationMessage `json:"messages"`
36-
InitialPrompt string `json:"initial_prompt"`
34+
Version int `json:"version"`
35+
Messages []ConversationMessage `json:"messages"`
36+
InitialPrompt string `json:"initial_prompt"`
37+
InitialPromptSent bool `json:"initial_prompt_sent"`
3738
}
3839

3940
var _ MessagePart = &MessagePartText{}
@@ -129,6 +130,7 @@ type PTYConversation struct {
129130
// initialPromptReady is set to true when ReadyForInitialPrompt returns true.
130131
// Checked inline in the snapshot loop on each tick.
131132
initialPromptReady bool
133+
initialPromptSent bool
132134
}
133135

134136
var _ Conversation = &PTYConversation{}
@@ -167,10 +169,6 @@ func NewPTY(ctx context.Context, cfg PTYConversationConfig, emitter Emitter) *PT
167169
userSentMessageAfterLoadState: false,
168170
loadStateSuccessful: false,
169171
}
170-
// If we have an initial prompt, enqueue it
171-
if len(cfg.InitialPrompt) > 0 {
172-
c.outboundQueue <- outboundMessage{parts: cfg.InitialPrompt, errCh: nil}
173-
}
174172
if c.cfg.ReadyForInitialPrompt == nil {
175173
c.cfg.ReadyForInitialPrompt = func(string) bool { return true }
176174
}
@@ -200,6 +198,13 @@ func (c *PTYConversation) Start(ctx context.Context) {
200198
c.loadStateSuccessful = true
201199
}
202200

201+
// Enqueue initial prompt once after agent is ready (and after state is potentially loaded)
202+
if c.initialPromptReady && len(c.cfg.InitialPrompt) > 0 && !c.initialPromptSent {
203+
c.outboundQueue <- outboundMessage{parts: c.cfg.InitialPrompt, errCh: nil}
204+
c.initialPromptSent = true
205+
c.dirty = true
206+
}
207+
203208
if c.initialPromptReady && len(c.outboundQueue) > 0 && c.isScreenStableLocked() {
204209
select {
205210
case c.stableSignal <- struct{}{}:
@@ -324,11 +329,7 @@ func (c *PTYConversation) snapshotLocked(screen string) {
324329

325330
func (c *PTYConversation) Send(messageParts ...MessagePart) error {
326331
// Validate message content before enqueueing
327-
var sb strings.Builder
328-
for _, part := range messageParts {
329-
sb.WriteString(part.String())
330-
}
331-
message := sb.String()
332+
message := buildStringFromMessageParts(messageParts)
332333
if message != msgfmt.TrimWhitespace(message) {
333334
return ErrMessageValidationWhitespace
334335
}
@@ -352,11 +353,7 @@ func (c *PTYConversation) Send(messageParts ...MessagePart) error {
352353
// around the parts that access shared state, but releases it during
353354
// writeStabilize to avoid blocking the snapshot loop.
354355
func (c *PTYConversation) sendMessage(ctx context.Context, messageParts ...MessagePart) error {
355-
var sb strings.Builder
356-
for _, part := range messageParts {
357-
sb.WriteString(part.String())
358-
}
359-
message := sb.String()
356+
message := buildStringFromMessageParts(messageParts)
360357

361358
c.lock.Lock()
362359
screenBeforeMessage := c.cfg.AgentIO.ReadScreen()
@@ -559,18 +556,15 @@ func (c *PTYConversation) SaveState() error {
559556
// Serialize initial prompt from message parts
560557
var initialPromptStr string
561558
if len(c.cfg.InitialPrompt) > 0 {
562-
var sb strings.Builder
563-
for _, part := range c.cfg.InitialPrompt {
564-
sb.WriteString(part.String())
565-
}
566-
initialPromptStr = sb.String()
559+
initialPromptStr = buildStringFromMessageParts(c.cfg.InitialPrompt)
567560
}
568561

569562
// Use atomic write: write to temp file, then rename to target path
570563
data, err := json.MarshalIndent(AgentState{
571-
Version: 1,
572-
Messages: conversation,
573-
InitialPrompt: initialPromptStr,
564+
Version: 1,
565+
Messages: conversation,
566+
InitialPrompt: initialPromptStr,
567+
InitialPromptSent: c.initialPromptSent,
574568
}, "", " ")
575569
if err != nil {
576570
return xerrors.Errorf("failed to marshal state: %w", err)
@@ -637,12 +631,22 @@ func (c *PTYConversation) loadStateLocked() error {
637631
return xerrors.Errorf("failed to unmarshal state (corrupted or invalid JSON): %w", err)
638632
}
639633

640-
//c.cfg.initialPromptSent = agentState.InitialPromptSent
641-
c.cfg.InitialPrompt = []MessagePart{MessagePartText{
642-
Content: agentState.InitialPrompt,
643-
Alias: "",
644-
Hidden: false,
645-
}}
634+
// Handle initial prompt restoration:
635+
// - If a new initial prompt was provided via flags, check if it differs from the saved one.
636+
// If different, mark as not sent (will be sent). If same, preserve sent status.
637+
// - If no new prompt provided, restore the saved prompt and its sent status.
638+
c.initialPromptSent = agentState.InitialPromptSent
639+
if len(c.cfg.InitialPrompt) > 0 {
640+
isDifferent := buildStringFromMessageParts(c.cfg.InitialPrompt) != agentState.InitialPrompt
641+
c.initialPromptSent = !isDifferent
642+
} else {
643+
c.cfg.InitialPrompt = []MessagePart{MessagePartText{
644+
Content: agentState.InitialPrompt,
645+
Alias: "",
646+
Hidden: false,
647+
}}
648+
}
649+
646650
c.messages = agentState.Messages
647651

648652
// Store the first stable snapshot for filtering later

0 commit comments

Comments
 (0)