Skip to content

Commit e02470e

Browse files
committed
inject tab state + tools into user message. better for caching and for AI's understanding
1 parent 668085c commit e02470e

10 files changed

Lines changed: 225 additions & 94 deletions

File tree

pkg/aiusechat/anthropic/anthropic-backend.go

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -395,23 +395,16 @@ func RunAnthropicChatStep(
395395
if !ok {
396396
return nil, nil, fmt.Errorf("expected anthropicChatMessage, got %T", genMsg)
397397
}
398-
399-
// Convert to anthropicInputMessage
398+
// Convert to anthropicInputMessage with copied content
399+
contentCopy := make([]anthropicMessageContentBlock, len(chatMsg.Content))
400+
copy(contentCopy, chatMsg.Content)
400401
inputMsg := anthropicInputMessage{
401402
Role: chatMsg.Role,
402-
Content: chatMsg.Content,
403+
Content: contentCopy,
403404
}
404405
anthropicMsgs = append(anthropicMsgs, inputMsg)
405406
}
406407

407-
// pretty print json of anthropicMsgs
408-
if jsonBytes, err := json.MarshalIndent(anthropicMsgs, "", " "); err == nil {
409-
log.Printf("system-prompt: %v\n", chatOpts.SystemPrompt)
410-
log.Printf("anthropicMsgs JSON:\n%s", string(jsonBytes))
411-
} else {
412-
return nil, nil, fmt.Errorf("failed to marshal messages to JSON: %w", err)
413-
}
414-
415408
req, err := buildAnthropicHTTPRequest(ctx, anthropicMsgs, chatOpts)
416409
if err != nil {
417410
return nil, nil, err

pkg/aiusechat/anthropic/anthropic-convertmessage.go

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import (
1717

1818
"github.com/google/uuid"
1919
"github.com/wavetermdev/waveterm/pkg/aiusechat/uctypes"
20+
"github.com/wavetermdev/waveterm/pkg/util/utilfn"
2021
)
2122

2223
// these conversions are based off the anthropic spec
@@ -55,6 +56,23 @@ func buildAnthropicHTTPRequest(ctx context.Context, msgs []anthropicInputMessage
5556
convertedMsgs[i] = convertMessageForAPI(msg)
5657
}
5758

59+
// inject chatOpts.TabState as a "text" block at the END of the LAST "user" message found (append to Content)
60+
if chatOpts.TabState != "" {
61+
// Find the last "user" message
62+
for i := len(convertedMsgs) - 1; i >= 0; i-- {
63+
if convertedMsgs[i].Role == "user" {
64+
// Create a text block with the TabState content
65+
tabStateBlock := anthropicMessageContentBlock{
66+
Type: "text",
67+
Text: chatOpts.TabState,
68+
}
69+
// Append to the Content of this message
70+
convertedMsgs[i].Content = append(convertedMsgs[i].Content, tabStateBlock)
71+
break
72+
}
73+
}
74+
}
75+
5876
// Build request body
5977
reqBody := &anthropicStreamRequest{
6078
Model: opts.Model,
@@ -82,16 +100,33 @@ func buildAnthropicHTTPRequest(ctx context.Context, msgs []anthropicInputMessage
82100
}
83101
reqBody.Tools = cleanedTools
84102
}
103+
for _, tool := range chatOpts.TabTools {
104+
cleanedTool := *tool.Clean()
105+
reqBody.Tools = append(reqBody.Tools, cleanedTool)
106+
}
85107

86108
// Enable extended thinking based on level
87109
reqBody.Thinking = makeThinkingOpts(opts.ThinkingLevel, maxTokens)
88110

89-
bodyBytes, err := json.Marshal(reqBody)
111+
// pretty print json of the converted messages (convertedMsgs)
112+
if jsonStr, err := utilfn.MarshalIndentNoHTMLString(convertedMsgs, "", " "); err == nil {
113+
log.Printf("system-prompt: %v\n", chatOpts.SystemPrompt)
114+
var toolNames []string
115+
for _, tool := range chatOpts.Tools {
116+
toolNames = append(toolNames, tool.Name)
117+
}
118+
log.Printf("tools: %s\n", strings.Join(toolNames, ", "))
119+
log.Printf("anthropicMsgs JSON:\n%s", jsonStr)
120+
}
121+
122+
var buf bytes.Buffer
123+
encoder := json.NewEncoder(&buf)
124+
encoder.SetEscapeHTML(false)
125+
err := encoder.Encode(reqBody)
90126
if err != nil {
91127
return nil, err
92128
}
93-
94-
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(bodyBytes))
129+
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, &buf)
95130
if err != nil {
96131
return nil, err
97132
}

pkg/aiusechat/tools.go

Lines changed: 103 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,13 @@ import (
77
"context"
88
"fmt"
99
"strings"
10+
"time"
1011

1112
"github.com/google/uuid"
1213
"github.com/wavetermdev/waveterm/pkg/aiusechat/uctypes"
1314
"github.com/wavetermdev/waveterm/pkg/blockcontroller"
1415
"github.com/wavetermdev/waveterm/pkg/waveobj"
16+
"github.com/wavetermdev/waveterm/pkg/wcore"
1517
"github.com/wavetermdev/waveterm/pkg/wstore"
1618
)
1719

@@ -99,44 +101,42 @@ func MakeBlockShortDesc(block *waveobj.Block) string {
99101
}
100102
}
101103

102-
func AddToolsForTab(ctx context.Context, tabid string, widgetAccess bool, chatOpts *uctypes.WaveChatOpts) error {
104+
func GenerateTabStateAndTools(ctx context.Context, tabid string, widgetAccess bool) (string, []uctypes.ToolDefinition, error) {
103105
if tabid == "" {
104-
return nil
105-
}
106-
if !widgetAccess {
107-
chatOpts.SystemPrompt = append(chatOpts.SystemPrompt, "The user has chosen not to share widget context with you.")
108-
return nil
109-
}
110-
111-
if _, err := uuid.Parse(tabid); err != nil {
112-
return fmt.Errorf("tabid must be a valid UUID")
113-
}
114-
115-
tabObj, err := wstore.DBMustGet[*waveobj.Tab](ctx, tabid)
116-
if err != nil {
117-
return fmt.Errorf("error getting tab: %v", err)
106+
return "", nil, nil
118107
}
119-
120108
var blocks []*waveobj.Block
121-
for _, blockId := range tabObj.BlockIds {
122-
block, err := wstore.DBGet[*waveobj.Block](ctx, blockId)
123-
if err != nil {
124-
continue
109+
if widgetAccess {
110+
if _, err := uuid.Parse(tabid); err != nil {
111+
return "", nil, fmt.Errorf("tabid must be a valid UUID")
125112
}
126-
blocks = append(blocks, block)
127-
}
128113

129-
systemPrompt := generateTabSystemPrompt(blocks)
130-
chatOpts.SystemPrompt = append(chatOpts.SystemPrompt, systemPrompt)
114+
tabObj, err := wstore.DBMustGet[*waveobj.Tab](ctx, tabid)
115+
if err != nil {
116+
return "", nil, fmt.Errorf("error getting tab: %v", err)
117+
}
131118

132-
return nil
119+
for _, blockId := range tabObj.BlockIds {
120+
block, err := wstore.DBGet[*waveobj.Block](ctx, blockId)
121+
if err != nil {
122+
continue
123+
}
124+
blocks = append(blocks, block)
125+
}
126+
}
127+
tabState := GenerateCurrentTabStatePrompt(blocks, widgetAccess)
128+
var tools []uctypes.ToolDefinition
129+
for _, block := range blocks {
130+
blockTools := generateToolsForBlock(block)
131+
tools = append(tools, blockTools...)
132+
}
133+
return tabState, tools, nil
133134
}
134135

135-
func generateTabSystemPrompt(blocks []*waveobj.Block) string {
136-
if len(blocks) == 0 {
137-
return "This tab is empty with no widgets currently open."
136+
func GenerateCurrentTabStatePrompt(blocks []*waveobj.Block, widgetAccess bool) string {
137+
if !widgetAccess {
138+
return `<current_tab_state>The user has chosen not to share widget context with you</current_tab_state>`
138139
}
139-
140140
var widgetDescriptions []string
141141
for _, block := range blocks {
142142
desc := MakeBlockShortDesc(block)
@@ -148,21 +148,86 @@ func generateTabSystemPrompt(blocks []*waveobj.Block) string {
148148
widgetDescriptions = append(widgetDescriptions, fullDesc)
149149
}
150150

151-
totalWidgets := len(widgetDescriptions)
152151
var prompt strings.Builder
153-
if totalWidgets == 1 {
154-
prompt.WriteString("In this tab there is 1 widget open (the widgetid appears in parentheses before the description):\n")
152+
prompt.WriteString("<current_tab_state>\n")
153+
if len(widgetDescriptions) == 0 {
154+
prompt.WriteString("No widgets open\n")
155155
} else {
156-
prompt.WriteString(fmt.Sprintf("In this tab there are %d widgets open (the widgetid appears in parentheses before the description):\n", totalWidgets))
156+
for _, desc := range widgetDescriptions {
157+
prompt.WriteString("* ")
158+
prompt.WriteString(desc)
159+
prompt.WriteString("\n")
160+
}
157161
}
162+
prompt.WriteString("</current_tab_state>")
163+
return prompt.String()
164+
}
158165

159-
for _, desc := range widgetDescriptions {
160-
prompt.WriteString("* ")
161-
prompt.WriteString(desc)
162-
prompt.WriteString("\n")
166+
func generateToolsForBlock(block *waveobj.Block) []uctypes.ToolDefinition {
167+
if block.Meta == nil {
168+
return nil
163169
}
164170

165-
return prompt.String()
171+
viewType, ok := block.Meta["view"].(string)
172+
if !ok {
173+
return nil
174+
}
175+
176+
var tools []uctypes.ToolDefinition
177+
switch viewType {
178+
case "web":
179+
tools = append(tools, GetWebNavigateToolDefinition(block))
180+
}
181+
182+
return tools
183+
}
184+
185+
func GetWebNavigateToolDefinition(block *waveobj.Block) uctypes.ToolDefinition {
186+
blockIdPrefix := block.OID[:8]
187+
toolName := fmt.Sprintf("web_navigate_%s", blockIdPrefix)
188+
189+
return uctypes.ToolDefinition{
190+
Name: toolName,
191+
DisplayName: fmt.Sprintf("Navigate Web Block %s", blockIdPrefix),
192+
Description: fmt.Sprintf("Navigate the web browser widget %s to a new URL", blockIdPrefix),
193+
InputSchema: map[string]any{
194+
"type": "object",
195+
"properties": map[string]any{
196+
"url": map[string]any{
197+
"type": "string",
198+
"description": "URL to navigate to",
199+
},
200+
},
201+
"required": []string{"url"},
202+
},
203+
ToolAnyCallback: func(input any) (any, error) {
204+
inputMap, ok := input.(map[string]any)
205+
if !ok {
206+
return nil, fmt.Errorf("invalid input format")
207+
}
208+
209+
url, ok := inputMap["url"].(string)
210+
if !ok {
211+
return nil, fmt.Errorf("missing or invalid url parameter")
212+
}
213+
214+
ctx, cancelFn := context.WithTimeout(context.Background(), 5*time.Second)
215+
defer cancelFn()
216+
217+
blockORef := waveobj.MakeORef(waveobj.OType_Block, block.OID)
218+
meta := map[string]any{
219+
"url": url,
220+
}
221+
222+
err := wstore.UpdateObjectMeta(ctx, blockORef, meta, false)
223+
if err != nil {
224+
return nil, fmt.Errorf("failed to update web block URL: %w", err)
225+
}
226+
227+
wcore.SendWaveObjUpdate(blockORef)
228+
return true, nil
229+
},
230+
}
166231
}
167232

168233
func GetAdderToolDefinition() uctypes.ToolDefinition {

pkg/aiusechat/uctypes/usechat-types.go

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -322,8 +322,13 @@ func (m *UIMessage) GetContent() string {
322322
}
323323

324324
type WaveChatOpts struct {
325-
ChatId string
326-
Config AIOptsType
327-
Tools []ToolDefinition
328-
SystemPrompt []string
325+
ChatId string
326+
Config AIOptsType
327+
Tools []ToolDefinition
328+
SystemPrompt []string
329+
TabStateGenerator func() (string, []ToolDefinition, error)
330+
331+
// ephemeral to the step
332+
TabState string
333+
TabTools []ToolDefinition
329334
}

pkg/aiusechat/usechat.go

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -98,11 +98,19 @@ func RunWaveAIRequest(ctx context.Context, sseHandler *sse.SSEHandlerCh, aiOpts
9898
}
9999

100100
func RunAIChat(ctx context.Context, sseHandler *sse.SSEHandlerCh, chatOpts uctypes.WaveChatOpts) error {
101+
log.Printf("RunAIChat\n")
101102
// Stream the Anthropic chat response
102103
firstStep := true
103104
var cont *uctypes.WaveContinueResponse
104105
for {
105-
stopReason, rtnMessage, err := anthropic.RunAnthropicChatStep(ctx, sseHandler, chatOpts, cont)
106+
var stopReason *uctypes.WaveStopReason
107+
var rtnMessage uctypes.GenAIMessage
108+
tabState, tabTools, err := chatOpts.TabStateGenerator()
109+
if err == nil {
110+
chatOpts.TabState = tabState
111+
chatOpts.TabTools = tabTools
112+
stopReason, rtnMessage, err = anthropic.RunAnthropicChatStep(ctx, sseHandler, chatOpts, cont)
113+
}
106114
if firstStep && err != nil {
107115
return fmt.Errorf("failed to stream anthropic chat: %w", err)
108116
}
@@ -119,7 +127,7 @@ func RunAIChat(ctx context.Context, sseHandler *sse.SSEHandlerCh, chatOpts uctyp
119127
for _, toolCall := range stopReason.ToolCalls {
120128
inputJSON, _ := json.Marshal(toolCall.Input)
121129
log.Printf("TOOLUSE name=%s id=%s input=%s\n", toolCall.Name, toolCall.ID, string(inputJSON))
122-
result := ResolveToolCall(toolCall, chatOpts.Tools)
130+
result := ResolveToolCall(toolCall, chatOpts)
123131
toolResults = append(toolResults, result)
124132
if result.ErrorText != "" {
125133
log.Printf(" error=%s\n", result.ErrorText)
@@ -138,7 +146,7 @@ func RunAIChat(ctx context.Context, sseHandler *sse.SSEHandlerCh, chatOpts uctyp
138146
}
139147

140148
cont = &uctypes.WaveContinueResponse{
141-
MessageID: rtnMessage.MessageId,
149+
MessageID: rtnMessage.GetMessageId(),
142150
Model: chatOpts.Config.Model,
143151
ContinueFromKind: uctypes.StopKindToolUse,
144152
ContinueFromRawReason: stopReason.RawReason,
@@ -151,7 +159,7 @@ func RunAIChat(ctx context.Context, sseHandler *sse.SSEHandlerCh, chatOpts uctyp
151159
}
152160

153161
// ResolveToolCall resolves a single tool call and returns an AIToolResult
154-
func ResolveToolCall(toolCall uctypes.WaveToolCall, tools []uctypes.ToolDefinition) (result uctypes.AIToolResult) {
162+
func ResolveToolCall(toolCall uctypes.WaveToolCall, chatOpts uctypes.WaveChatOpts) (result uctypes.AIToolResult) {
155163
result = uctypes.AIToolResult{
156164
ToolName: toolCall.Name,
157165
ToolUseID: toolCall.ID,
@@ -166,12 +174,20 @@ func ResolveToolCall(toolCall uctypes.WaveToolCall, tools []uctypes.ToolDefiniti
166174

167175
// Find the matching tool definition
168176
var toolDef *uctypes.ToolDefinition
169-
for i := range tools {
170-
if tools[i].Name == toolCall.Name {
171-
toolDef = &tools[i]
177+
for _, tool := range chatOpts.Tools {
178+
if tool.Name == toolCall.Name {
179+
toolDef = &tool
172180
break
173181
}
174182
}
183+
if toolDef == nil {
184+
for _, tool := range chatOpts.TabTools {
185+
if tool.Name == toolCall.Name {
186+
toolDef = &tool
187+
break
188+
}
189+
}
190+
}
175191

176192
if toolDef == nil {
177193
result.ErrorText = fmt.Sprintf("tool '%s' not found", toolCall.Name)
@@ -207,6 +223,7 @@ func ResolveToolCall(toolCall uctypes.WaveToolCall, tools []uctypes.ToolDefiniti
207223
}
208224

209225
func WaveAIPostMessageWrap(ctx context.Context, sseHandler *sse.SSEHandlerCh, message *uctypes.AIMessage, chatOpts uctypes.WaveChatOpts) error {
226+
log.Printf("WaveAIPostMessageWrap\n")
210227
// Only support Anthropic for now
211228
if chatOpts.Config.APIType != APIType_Anthropic {
212229
return fmt.Errorf("only Anthropic API type is supported, got: %s", chatOpts.Config.APIType)
@@ -276,10 +293,8 @@ func WaveAIPostMessageHandler(w http.ResponseWriter, r *http.Request) {
276293

277294
// Create tools array with adder tool
278295
chatOpts.Tools = append(chatOpts.Tools, GetAdderToolDefinition())
279-
err = AddToolsForTab(r.Context(), req.TabId, req.WidgetAccess, &chatOpts)
280-
if err != nil {
281-
http.Error(w, fmt.Sprintf("Error trying to add tab tool context: %v", err), http.StatusInternalServerError)
282-
return
296+
chatOpts.TabStateGenerator = func() (string, []uctypes.ToolDefinition, error) {
297+
return GenerateTabStateAndTools(r.Context(), req.TabId, req.WidgetAccess)
283298
}
284299

285300
// Validate the message

0 commit comments

Comments
 (0)