Skip to content

Commit 1c4e7c4

Browse files
committed
more tool refactoring
1 parent 5c57a4b commit 1c4e7c4

9 files changed

Lines changed: 120 additions & 84 deletions

File tree

pkg/aiusechat/tools.go

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,20 @@ import (
1515
"github.com/wavetermdev/waveterm/pkg/wstore"
1616
)
1717

18+
func resolveBlockIdFromPrefix(tab *waveobj.Tab, blockIdPrefix string) (string, error) {
19+
if len(blockIdPrefix) != 8 {
20+
return "", fmt.Errorf("widget_id must be 8 characters")
21+
}
22+
23+
for _, blockId := range tab.BlockIds {
24+
if strings.HasPrefix(blockId, blockIdPrefix) {
25+
return blockId, nil
26+
}
27+
}
28+
29+
return "", fmt.Errorf("widget_id not found: %q", blockIdPrefix)
30+
}
31+
1832
func MakeBlockShortDesc(block *waveobj.Block) string {
1933
if block.Meta == nil {
2034
return ""
@@ -110,8 +124,11 @@ func GenerateTabStateAndTools(ctx context.Context, tabid string, widgetAccess bo
110124
}
111125
tabState := GenerateCurrentTabStatePrompt(blocks, widgetAccess)
112126
var tools []uctypes.ToolDefinition
127+
if widgetAccess {
128+
tools = append(tools, GetCaptureScreenshotToolDefinition(tabid))
129+
}
113130
for _, block := range blocks {
114-
blockTools := generateToolsForBlock(block)
131+
blockTools := generateToolsForBlock(tabid, block)
115132
tools = append(tools, blockTools...)
116133
}
117134
return tabState, tools, nil
@@ -147,7 +164,7 @@ func GenerateCurrentTabStatePrompt(blocks []*waveobj.Block, widgetAccess bool) s
147164
return prompt.String()
148165
}
149166

150-
func generateToolsForBlock(block *waveobj.Block) []uctypes.ToolDefinition {
167+
func generateToolsForBlock(tabId string, block *waveobj.Block) []uctypes.ToolDefinition {
151168
if block.Meta == nil {
152169
return nil
153170
}
@@ -160,7 +177,7 @@ func generateToolsForBlock(block *waveobj.Block) []uctypes.ToolDefinition {
160177
var tools []uctypes.ToolDefinition
161178
switch viewType {
162179
case "term":
163-
tools = append(tools, GetTermGetScrollbackToolDefinition(block))
180+
tools = append(tools, GetTermGetScrollbackToolDefinition(tabId))
164181
case "web":
165182
tools = append(tools, GetWebNavigateToolDefinition(block))
166183
case "tsunami":
@@ -188,6 +205,7 @@ func GetAdderToolDefinition() uctypes.ToolDefinition {
188205
Name: "adder",
189206
DisplayName: "Adder",
190207
Description: "Add an array of numbers together and return their sum",
208+
ToolLogName: "gen:adder",
191209
Strict: true,
192210
InputSchema: map[string]any{
193211
"type": "object",

pkg/aiusechat/tools_readfile.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,7 @@ func GetReadTextFileToolDefinition() uctypes.ToolDefinition {
163163
Name: "read_text_file",
164164
DisplayName: "Read Text File",
165165
Description: "Read a text file from the filesystem. Can read specific line ranges or from the end. Detects and rejects binary files.",
166+
ToolLogName: "gen:readfile",
166167
Strict: false,
167168
InputSchema: map[string]any{
168169
"type": "object",

pkg/aiusechat/tools_screenshot.go

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ package aiusechat
66
import (
77
"context"
88
"fmt"
9-
"strings"
109
"time"
1110

1211
"github.com/wavetermdev/waveterm/pkg/aiusechat/uctypes"
@@ -17,20 +16,6 @@ import (
1716
"github.com/wavetermdev/waveterm/pkg/wstore"
1817
)
1918

20-
func resolveBlockIdFromPrefix(tab *waveobj.Tab, blockIdPrefix string) (string, error) {
21-
if len(blockIdPrefix) != 8 {
22-
return "", fmt.Errorf("widget_id must be 8 characters")
23-
}
24-
25-
for _, blockId := range tab.BlockIds {
26-
if strings.HasPrefix(blockId, blockIdPrefix) {
27-
return blockId, nil
28-
}
29-
}
30-
31-
return "", fmt.Errorf("widget_id not found: %q", blockIdPrefix)
32-
}
33-
3419
func makeTabCaptureBlockScreenshot(tabId string) func(any) (string, error) {
3520
return func(input any) (string, error) {
3621
inputMap, ok := input.(map[string]any)
@@ -75,6 +60,7 @@ func GetCaptureScreenshotToolDefinition(tabId string) uctypes.ToolDefinition {
7560
Name: "capture_screenshot",
7661
DisplayName: "Capture Screenshot",
7762
Description: "Capture a screenshot of a widget and return it as an image",
63+
ToolLogName: "gen:screenshot",
7864
Strict: true,
7965
InputSchema: map[string]any{
8066
"type": "object",

pkg/aiusechat/tools_term.go

Lines changed: 32 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
package aiusechat
55

66
import (
7+
"context"
78
"encoding/json"
89
"fmt"
910
"strings"
@@ -14,11 +15,13 @@ import (
1415
"github.com/wavetermdev/waveterm/pkg/wshrpc"
1516
"github.com/wavetermdev/waveterm/pkg/wshrpc/wshclient"
1617
"github.com/wavetermdev/waveterm/pkg/wshutil"
18+
"github.com/wavetermdev/waveterm/pkg/wstore"
1719
)
1820

1921
type TermGetScrollbackToolInput struct {
20-
LineStart int `json:"line_start,omitempty"`
21-
Count int `json:"count,omitempty"`
22+
WidgetId string `json:"widget_id"`
23+
LineStart int `json:"line_start,omitempty"`
24+
Count int `json:"count,omitempty"`
2225
}
2326

2427
type TermGetScrollbackToolOutput struct {
@@ -70,18 +73,20 @@ func parseTermGetScrollbackInput(input any) (*TermGetScrollbackToolInput, error)
7073
return result, nil
7174
}
7275

73-
func GetTermGetScrollbackToolDefinition(block *waveobj.Block) uctypes.ToolDefinition {
74-
blockIdPrefix := block.OID[:8]
75-
toolName := fmt.Sprintf("term_get_scrollback_%s", blockIdPrefix)
76-
76+
func GetTermGetScrollbackToolDefinition(tabId string) uctypes.ToolDefinition {
7777
return uctypes.ToolDefinition{
78-
Name: toolName,
79-
DisplayName: fmt.Sprintf("Get Terminal Scrollback %s", blockIdPrefix),
80-
Description: fmt.Sprintf("Fetch terminal scrollback from widget %s as plain text. Index 0 is the most recent line; indices increase going upward (older lines).", blockIdPrefix),
81-
Strict: false,
78+
Name: "term_get_scrollback",
79+
DisplayName: "Get Terminal Scrollback",
80+
Description: "Fetch terminal scrollback from a widget as plain text. Index 0 is the most recent line; indices increase going upward (older lines).",
81+
ToolLogName: "term:getscrollback",
82+
Strict: true,
8283
InputSchema: map[string]any{
8384
"type": "object",
8485
"properties": map[string]any{
86+
"widget_id": map[string]any{
87+
"type": "string",
88+
"description": "8-character widget ID of the terminal widget",
89+
},
8590
"line_start": map[string]any{
8691
"type": "integer",
8792
"minimum": 0,
@@ -93,7 +98,7 @@ func GetTermGetScrollbackToolDefinition(block *waveobj.Block) uctypes.ToolDefini
9398
"description": "Number of lines to return from line_start (default: 200)",
9499
},
95100
},
96-
"required": []string{},
101+
"required": []string{"widget_id"},
97102
"additionalProperties": false,
98103
},
99104
ToolInputDesc: func(input any) string {
@@ -103,17 +108,30 @@ func GetTermGetScrollbackToolDefinition(block *waveobj.Block) uctypes.ToolDefini
103108
}
104109

105110
if parsed.LineStart == 0 && parsed.Count == 200 {
106-
return fmt.Sprintf("reading terminal output from %s (most recent %d lines)", blockIdPrefix, parsed.Count)
111+
return fmt.Sprintf("reading terminal output from %s (most recent %d lines)", parsed.WidgetId, parsed.Count)
107112
}
108113
lineEnd := parsed.LineStart + parsed.Count
109-
return fmt.Sprintf("reading terminal output from %s (lines %d-%d)", blockIdPrefix, parsed.LineStart, lineEnd)
114+
return fmt.Sprintf("reading terminal output from %s (lines %d-%d)", parsed.WidgetId, parsed.LineStart, lineEnd)
110115
},
111116
ToolAnyCallback: func(input any) (any, error) {
112117
parsed, err := parseTermGetScrollbackInput(input)
113118
if err != nil {
114119
return nil, err
115120
}
116121

122+
ctx, cancelFn := context.WithTimeout(context.Background(), 5*time.Second)
123+
defer cancelFn()
124+
125+
tab, err := wstore.DBMustGet[*waveobj.Tab](ctx, tabId)
126+
if err != nil {
127+
return nil, fmt.Errorf("error getting tab: %w", err)
128+
}
129+
130+
fullBlockId, err := resolveBlockIdFromPrefix(tab, parsed.WidgetId)
131+
if err != nil {
132+
return nil, err
133+
}
134+
117135
lineEnd := parsed.LineStart + parsed.Count
118136

119137
rpcClient := wshclient.GetBareRpcClient()
@@ -123,7 +141,7 @@ func GetTermGetScrollbackToolDefinition(block *waveobj.Block) uctypes.ToolDefini
123141
LineStart: parsed.LineStart,
124142
LineEnd: lineEnd,
125143
},
126-
&wshrpc.RpcOpts{Route: wshutil.MakeFeBlockRouteId(block.OID)},
144+
&wshrpc.RpcOpts{Route: wshutil.MakeFeBlockRouteId(fullBlockId)},
127145
)
128146
if err != nil {
129147
return nil, fmt.Errorf("failed to get terminal scrollback: %w", err)

pkg/aiusechat/tools_tsunami.go

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,9 @@ func GetTsunamiGetDataToolDefinition(block *waveobj.Block, rtInfo *waveobj.ObjRT
116116
}
117117

118118
return &uctypes.ToolDefinition{
119-
Name: toolName,
120-
Strict: true,
119+
Name: toolName,
120+
ToolLogName: "tsunami:getdata",
121+
Strict: true,
121122
InputSchema: map[string]any{
122123
"type": "object",
123124
"properties": map[string]any{},
@@ -140,8 +141,9 @@ func GetTsunamiGetConfigToolDefinition(block *waveobj.Block, rtInfo *waveobj.Obj
140141
}
141142

142143
return &uctypes.ToolDefinition{
143-
Name: toolName,
144-
Strict: true,
144+
Name: toolName,
145+
ToolLogName: "tsunami:getconfig",
146+
Strict: true,
145147
InputSchema: map[string]any{
146148
"type": "object",
147149
"properties": map[string]any{},
@@ -178,6 +180,7 @@ func GetTsunamiSetConfigToolDefinition(block *waveobj.Block, rtInfo *waveobj.Obj
178180

179181
return &uctypes.ToolDefinition{
180182
Name: toolName,
183+
ToolLogName: "tsunami:setconfig",
181184
InputSchema: inputSchema,
182185
ToolInputDesc: func(input any) string {
183186
return fmt.Sprintf("updating config for %s (%s)", desc, blockIdPrefix)

pkg/aiusechat/tools_web.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ func GetWebNavigateToolDefinition(block *waveobj.Block) uctypes.ToolDefinition {
2222
Name: toolName,
2323
DisplayName: fmt.Sprintf("Navigate Web Block %s", blockIdPrefix),
2424
Description: fmt.Sprintf("Navigate the web browser widget %s to a new URL", blockIdPrefix),
25+
ToolLogName: "web:navigate",
2526
Strict: true,
2627
InputSchema: map[string]any{
2728
"type": "object",

pkg/aiusechat/uctypes/usechat-types.go

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ type ToolDefinition struct {
8080
DisplayName string `json:"displayname,omitempty"` // internal field (cannot marshal to API, must be stripped)
8181
Description string `json:"description"`
8282
ShortDescription string `json:"shortdescription,omitempty"` // internal field (cannot marshal to API, must be stripped)
83+
ToolLogName string `json:"-"` // short name for telemetry (e.g., "term:getscrollback")
8384
InputSchema map[string]any `json:"input_schema"`
8485
Strict bool `json:"strict,omitempty"`
8586
ToolTextCallback func(any) (string, error) `json:"-"`
@@ -189,19 +190,20 @@ type AIUsage struct {
189190
}
190191

191192
type AIMetrics struct {
192-
Usage AIUsage `json:"usage"`
193-
RequestCount int `json:"requestcount"`
194-
ToolUseCount int `json:"toolusecount"`
195-
PremiumReqCount int `json:"premiumreqcount"`
196-
ProxyReqCount int `json:"proxyreqcount"`
197-
HadError bool `json:"haderror"`
198-
ImageCount int `json:"imagecount"`
199-
PDFCount int `json:"pdfcount"`
200-
TextDocCount int `json:"textdoccount"`
201-
TextLen int `json:"textlen"`
202-
FirstByteLatency int `json:"firstbytelatency"` // ms
203-
RequestDuration int `json:"requestduration"` // ms
204-
WidgetAccess bool `json:"widgetaccess"`
193+
Usage AIUsage `json:"usage"`
194+
RequestCount int `json:"requestcount"`
195+
ToolUseCount int `json:"toolusecount"`
196+
ToolDetail map[string]int `json:"tooldetail,omitempty"`
197+
PremiumReqCount int `json:"premiumreqcount"`
198+
ProxyReqCount int `json:"proxyreqcount"`
199+
HadError bool `json:"haderror"`
200+
ImageCount int `json:"imagecount"`
201+
PDFCount int `json:"pdfcount"`
202+
TextDocCount int `json:"textdoccount"`
203+
TextLen int `json:"textlen"`
204+
FirstByteLatency int `json:"firstbytelatency"` // ms
205+
RequestDuration int `json:"requestduration"` // ms
206+
WidgetAccess bool `json:"widgetaccess"`
205207
}
206208

207209
// GenAIMessage interface for messages stored in conversations

pkg/aiusechat/usechat.go

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -199,13 +199,19 @@ func GetChatUsage(chat *uctypes.AIChat) uctypes.AIUsage {
199199
return usage
200200
}
201201

202-
func processToolResults(stopReason *uctypes.WaveStopReason, chatOpts uctypes.WaveChatOpts, sseHandler *sse.SSEHandlerCh) {
202+
func processToolResults(stopReason *uctypes.WaveStopReason, chatOpts uctypes.WaveChatOpts, sseHandler *sse.SSEHandlerCh, metrics *uctypes.AIMetrics) {
203203
var toolResults []uctypes.AIToolResult
204204
for _, toolCall := range stopReason.ToolCalls {
205205
inputJSON, _ := json.Marshal(toolCall.Input)
206206
log.Printf("TOOLUSE name=%s id=%s input=%s\n", toolCall.Name, toolCall.ID, utilfn.TruncateString(string(inputJSON), 40))
207207
result := ResolveToolCall(toolCall, chatOpts)
208208
toolResults = append(toolResults, result)
209+
210+
// Track tool usage by ToolLogName
211+
toolDef := getToolDefinition(toolCall.Name, chatOpts)
212+
if toolDef != nil && toolDef.ToolLogName != "" {
213+
metrics.ToolDetail[toolDef.ToolLogName]++
214+
}
209215
if result.ErrorText != "" {
210216
log.Printf(" error=%s\n", result.ErrorText)
211217
} else {
@@ -242,6 +248,7 @@ func RunAIChat(ctx context.Context, sseHandler *sse.SSEHandlerCh, chatOpts uctyp
242248
Model: chatOpts.Config.Model,
243249
},
244250
WidgetAccess: chatOpts.WidgetAccess,
251+
ToolDetail: make(map[string]int),
245252
}
246253
firstStep := true
247254
var cont *uctypes.WaveContinueResponse
@@ -298,7 +305,7 @@ func RunAIChat(ctx context.Context, sseHandler *sse.SSEHandlerCh, chatOpts uctyp
298305
}
299306
if stopReason != nil && stopReason.Kind == uctypes.StopKindToolUse {
300307
metrics.ToolUseCount += len(stopReason.ToolCalls)
301-
processToolResults(stopReason, chatOpts, sseHandler)
308+
processToolResults(stopReason, chatOpts, sseHandler, metrics)
302309

303310
var messageID string
304311
if len(rtnMessage) > 0 && rtnMessage[0] != nil {
@@ -318,6 +325,20 @@ func RunAIChat(ctx context.Context, sseHandler *sse.SSEHandlerCh, chatOpts uctyp
318325
}
319326

320327
// ResolveToolCall resolves a single tool call and returns an AIToolResult
328+
func getToolDefinition(toolName string, chatOpts uctypes.WaveChatOpts) *uctypes.ToolDefinition {
329+
for _, tool := range chatOpts.Tools {
330+
if tool.Name == toolName {
331+
return &tool
332+
}
333+
}
334+
for _, tool := range chatOpts.TabTools {
335+
if tool.Name == toolName {
336+
return &tool
337+
}
338+
}
339+
return nil
340+
}
341+
321342
func ResolveToolCall(toolCall uctypes.WaveToolCall, chatOpts uctypes.WaveChatOpts) (result uctypes.AIToolResult) {
322343
result = uctypes.AIToolResult{
323344
ToolName: toolCall.Name,
@@ -331,22 +352,7 @@ func ResolveToolCall(toolCall uctypes.WaveToolCall, chatOpts uctypes.WaveChatOpt
331352
}
332353
}()
333354

334-
// Find the matching tool definition
335-
var toolDef *uctypes.ToolDefinition
336-
for _, tool := range chatOpts.Tools {
337-
if tool.Name == toolCall.Name {
338-
toolDef = &tool
339-
break
340-
}
341-
}
342-
if toolDef == nil {
343-
for _, tool := range chatOpts.TabTools {
344-
if tool.Name == toolCall.Name {
345-
toolDef = &tool
346-
break
347-
}
348-
}
349-
}
355+
toolDef := getToolDefinition(toolCall.Name, chatOpts)
350356

351357
if toolDef == nil {
352358
result.ErrorText = fmt.Sprintf("tool '%s' not found", toolCall.Name)
@@ -442,6 +448,7 @@ func sendAIMetricsTelemetry(ctx context.Context, metrics *uctypes.AIMetrics) {
442448
WaveAIOutputTokens: metrics.Usage.OutputTokens,
443449
WaveAIRequestCount: metrics.RequestCount,
444450
WaveAIToolUseCount: metrics.ToolUseCount,
451+
WaveAIToolDetail: metrics.ToolDetail,
445452
WaveAIPremiumReq: metrics.PremiumReqCount,
446453
WaveAIProxyReq: metrics.ProxyReqCount,
447454
WaveAIHadError: metrics.HadError,
@@ -519,7 +526,6 @@ func WaveAIPostMessageHandler(w http.ResponseWriter, r *http.Request) {
519526
chatOpts.TabStateGenerator = func() (string, []uctypes.ToolDefinition, error) {
520527
return GenerateTabStateAndTools(r.Context(), req.TabId, req.WidgetAccess)
521528
}
522-
chatOpts.Tools = append(chatOpts.Tools, GetCaptureScreenshotToolDefinition(req.TabId))
523529

524530
// Validate the message
525531
if err := req.Msg.Validate(); err != nil {

0 commit comments

Comments
 (0)