Skip to content

Commit 3399a6e

Browse files
committed
add tools, check capabilities
1 parent 091cb0c commit 3399a6e

File tree

4 files changed

+84
-24
lines changed

4 files changed

+84
-24
lines changed

pkg/aiusechat/openaicomp/openaicomp-convertmessage.go

Lines changed: 41 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,31 @@ func appendToLastUserMessage(messages []CompletionsMessage, text string) {
3232
}
3333
}
3434

35+
// convertToolDefinitions converts Wave ToolDefinitions to OpenAI format
36+
// Only includes tools whose required capabilities are met
37+
func convertToolDefinitions(waveTools []uctypes.ToolDefinition, capabilities []string) []ToolDefinition {
38+
if len(waveTools) == 0 {
39+
return nil
40+
}
41+
42+
openaiTools := make([]ToolDefinition, 0, len(waveTools))
43+
for _, waveTool := range waveTools {
44+
if !waveTool.HasRequiredCapabilities(capabilities) {
45+
continue
46+
}
47+
openaiTool := ToolDefinition{
48+
Type: "function",
49+
Function: ToolFunctionDef{
50+
Name: waveTool.Name,
51+
Description: waveTool.Description,
52+
Parameters: waveTool.InputSchema,
53+
},
54+
}
55+
openaiTools = append(openaiTools, openaiTool)
56+
}
57+
return openaiTools
58+
}
59+
3560
// buildCompletionsHTTPRequest creates an HTTP request for the OpenAI completions API
3661
func buildCompletionsHTTPRequest(ctx context.Context, messages []CompletionsMessage, chatOpts uctypes.WaveChatOpts) (*http.Request, error) {
3762
opts := chatOpts.Config
@@ -77,8 +102,18 @@ func buildCompletionsHTTPRequest(ctx context.Context, messages []CompletionsMess
77102
reqBody.MaxTokens = maxTokens
78103
}
79104

105+
// Add tool definitions if tools capability is available and tools exist
106+
var allTools []uctypes.ToolDefinition
107+
if opts.HasCapability(uctypes.AICapabilityTools) {
108+
allTools = append(allTools, chatOpts.Tools...)
109+
allTools = append(allTools, chatOpts.TabTools...)
110+
if len(allTools) > 0 {
111+
reqBody.Tools = convertToolDefinitions(allTools, opts.Capabilities)
112+
}
113+
}
114+
80115
if wavebase.IsDevMode() {
81-
log.Printf("openaicomp: model %s, messages: %d\n", opts.Model, len(messages))
116+
log.Printf("openaicomp: model %s, messages: %d, tools: %d\n", opts.Model, len(messages), len(allTools))
82117
}
83118

84119
buf, err := json.Marshal(reqBody)
@@ -120,30 +155,30 @@ func ConvertAIMessageToCompletionsMessage(aiMsg uctypes.AIMessage) (*Completions
120155
firstText := true
121156
for _, part := range aiMsg.Parts {
122157
var partText string
123-
158+
124159
switch {
125160
case part.Type == uctypes.AIMessagePartTypeText:
126161
partText = part.Text
127-
162+
128163
case part.MimeType == "text/plain":
129164
textData, err := aiutil.ExtractTextData(part.Data, part.URL)
130165
if err != nil {
131166
log.Printf("openaicomp: error extracting text data for %s: %v\n", part.FileName, err)
132167
continue
133168
}
134169
partText = aiutil.FormatAttachedTextFile(part.FileName, textData)
135-
170+
136171
case part.MimeType == "directory":
137172
if len(part.Data) == 0 {
138173
log.Printf("openaicomp: directory listing part missing data for %s\n", part.FileName)
139174
continue
140175
}
141176
partText = aiutil.FormatAttachedDirectoryListing(part.FileName, string(part.Data))
142-
177+
143178
default:
144179
continue
145180
}
146-
181+
147182
if partText != "" {
148183
if !firstText {
149184
textBuilder.WriteString("\n\n")

pkg/aiusechat/tools_screenshot.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ func GetCaptureScreenshotToolDefinition(tabId string) uctypes.ToolDefinition {
6767
"required": []string{"widget_id"},
6868
"additionalProperties": false,
6969
},
70+
RequiredCapabilities: []string{uctypes.AICapabilityImages},
7071
ToolCallDesc: func(input any, output any, toolUseData *uctypes.UIMessageDataToolUse) string {
7172
inputMap, ok := input.(map[string]any)
7273
if !ok {

pkg/aiusechat/uctypes/uctypes.go

Lines changed: 41 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ package uctypes
66
import (
77
"fmt"
88
"net/url"
9+
"slices"
910
"strings"
1011
)
1112

@@ -78,13 +79,14 @@ type UIMessageDataUserFile struct {
7879

7980
// ToolDefinition represents a tool that can be used by the AI model
8081
type ToolDefinition struct {
81-
Name string `json:"name"`
82-
DisplayName string `json:"displayname,omitempty"` // internal field (cannot marshal to API, must be stripped)
83-
Description string `json:"description"`
84-
ShortDescription string `json:"shortdescription,omitempty"` // internal field (cannot marshal to API, must be stripped)
85-
ToolLogName string `json:"-"` // short name for telemetry (e.g., "term:getscrollback")
86-
InputSchema map[string]any `json:"input_schema"`
87-
Strict bool `json:"strict,omitempty"`
82+
Name string `json:"name"`
83+
DisplayName string `json:"displayname,omitempty"` // internal field (cannot marshal to API, must be stripped)
84+
Description string `json:"description"`
85+
ShortDescription string `json:"shortdescription,omitempty"` // internal field (cannot marshal to API, must be stripped)
86+
ToolLogName string `json:"-"` // short name for telemetry (e.g., "term:getscrollback")
87+
InputSchema map[string]any `json:"input_schema"`
88+
Strict bool `json:"strict,omitempty"`
89+
RequiredCapabilities []string `json:"requiredcapabilities,omitempty"`
8890

8991
ToolTextCallback func(any) (string, error) `json:"-"`
9092
ToolAnyCallback func(any, *UIMessageDataToolUse) (any, error) `json:"-"` // *UIMessageDataToolUse will NOT be nil
@@ -114,6 +116,18 @@ func (td *ToolDefinition) Desc() string {
114116
return td.Description
115117
}
116118

119+
func (td *ToolDefinition) HasRequiredCapabilities(capabilities []string) bool {
120+
if td == nil || len(td.RequiredCapabilities) == 0 {
121+
return true
122+
}
123+
for _, reqCap := range td.RequiredCapabilities {
124+
if !slices.Contains(capabilities, reqCap) {
125+
return false
126+
}
127+
}
128+
return true
129+
}
130+
117131
//------------------
118132
// Wave specific types, stop reasons, tool calls, config
119133
// these are used internally to coordinate the calls/steps
@@ -168,6 +182,10 @@ type AIThinkingModeConfig struct {
168182
Capabilities []string `json:"capabilities,omitempty"`
169183
}
170184

185+
func (c *AIThinkingModeConfig) HasCapability(cap string) bool {
186+
return slices.Contains(c.Capabilities, cap)
187+
}
188+
171189
// when updating this struct, also modify frontend/app/aipanel/aitypes.ts WaveUIDataTypes.tooluse
172190
type UIMessageDataToolUse struct {
173191
ToolCallId string `json:"toolcallid"`
@@ -230,17 +248,18 @@ type WaveContinueResponse struct {
230248

231249
// Wave Specific AI opts for configuration
232250
type AIOptsType struct {
233-
APIType string `json:"apitype,omitempty"`
234-
Model string `json:"model"`
235-
APIToken string `json:"apitoken"`
236-
OrgID string `json:"orgid,omitempty"`
237-
APIVersion string `json:"apiversion,omitempty"`
238-
BaseURL string `json:"baseurl,omitempty"`
239-
ProxyURL string `json:"proxyurl,omitempty"`
240-
MaxTokens int `json:"maxtokens,omitempty"`
241-
TimeoutMs int `json:"timeoutms,omitempty"`
242-
ThinkingLevel string `json:"thinkinglevel,omitempty"` // ThinkingLevelLow, ThinkingLevelMedium, or ThinkingLevelHigh
243-
ThinkingMode string `json:"thinkingmode,omitempty"` // quick, balanced, or deep
251+
APIType string `json:"apitype,omitempty"`
252+
Model string `json:"model"`
253+
APIToken string `json:"apitoken"`
254+
OrgID string `json:"orgid,omitempty"`
255+
APIVersion string `json:"apiversion,omitempty"`
256+
BaseURL string `json:"baseurl,omitempty"`
257+
ProxyURL string `json:"proxyurl,omitempty"`
258+
MaxTokens int `json:"maxtokens,omitempty"`
259+
TimeoutMs int `json:"timeoutms,omitempty"`
260+
ThinkingLevel string `json:"thinkinglevel,omitempty"` // ThinkingLevelLow, ThinkingLevelMedium, or ThinkingLevelHigh
261+
ThinkingMode string `json:"thinkingmode,omitempty"` // quick, balanced, or deep
262+
Capabilities []string `json:"capabilities,omitempty"`
244263
}
245264

246265
func (opts AIOptsType) IsWaveProxy() bool {
@@ -251,6 +270,10 @@ func (opts AIOptsType) IsPremiumModel() bool {
251270
return opts.Model == "gpt-5" || opts.Model == "gpt-5.1" || strings.Contains(opts.Model, "claude-sonnet")
252271
}
253272

273+
func (opts AIOptsType) HasCapability(cap string) bool {
274+
return slices.Contains(opts.Capabilities, cap)
275+
}
276+
254277
type AIChat struct {
255278
ChatId string `json:"chatid"`
256279
APIType string `json:"apitype"`

pkg/aiusechat/usechat.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@ func getWaveAISettings(premium bool, builderMode bool, rtInfo *waveobj.ObjRTInfo
152152
ThinkingLevel: config.ThinkingLevel,
153153
ThinkingMode: thinkingMode,
154154
BaseURL: baseUrl,
155+
Capabilities: config.Capabilities,
155156
}
156157
if apiToken != "" {
157158
opts.APIToken = apiToken

0 commit comments

Comments
 (0)