Skip to content

Commit f11cc8d

Browse files
committed
make sendToolProgress and createToolUseData generic
1 parent 597a67b commit f11cc8d

File tree

2 files changed

+86
-75
lines changed

2 files changed

+86
-75
lines changed

pkg/aiusechat/aiutil/aiutil.go

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,20 @@ package aiutil
55

66
import (
77
"bytes"
8+
"context"
89
"crypto/sha256"
910
"encoding/base64"
1011
"encoding/hex"
1112
"encoding/json"
1213
"fmt"
1314
"strconv"
1415
"strings"
16+
"time"
1517

1618
"github.com/wavetermdev/waveterm/pkg/aiusechat/uctypes"
1719
"github.com/wavetermdev/waveterm/pkg/util/utilfn"
20+
"github.com/wavetermdev/waveterm/pkg/wcore"
21+
"github.com/wavetermdev/waveterm/pkg/web/sse"
1822
)
1923

2024
// ExtractXmlAttribute extracts an attribute value from an XML-like tag.
@@ -189,3 +193,81 @@ func IsOpenAIReasoningModel(model string) bool {
189193
strings.HasPrefix(m, "gpt-5") ||
190194
strings.HasPrefix(m, "gpt-5.1")
191195
}
196+
197+
// CreateToolUseData creates a UIMessageDataToolUse from tool call information
198+
func CreateToolUseData(toolCallID, toolName string, arguments string, chatOpts uctypes.WaveChatOpts) *uctypes.UIMessageDataToolUse {
199+
toolUseData := &uctypes.UIMessageDataToolUse{
200+
ToolCallId: toolCallID,
201+
ToolName: toolName,
202+
Status: uctypes.ToolUseStatusPending,
203+
}
204+
205+
toolDef := chatOpts.GetToolDefinition(toolName)
206+
if toolDef == nil {
207+
toolUseData.Status = uctypes.ToolUseStatusError
208+
toolUseData.ErrorMessage = "tool not found"
209+
return toolUseData
210+
}
211+
212+
var parsedArgs any
213+
if err := json.Unmarshal([]byte(arguments), &parsedArgs); err != nil {
214+
toolUseData.Status = uctypes.ToolUseStatusError
215+
toolUseData.ErrorMessage = fmt.Sprintf("failed to parse tool arguments: %v", err)
216+
return toolUseData
217+
}
218+
219+
if toolDef.ToolCallDesc != nil {
220+
toolUseData.ToolDesc = toolDef.ToolCallDesc(parsedArgs, nil, nil)
221+
}
222+
223+
if toolDef.ToolApproval != nil {
224+
toolUseData.Approval = toolDef.ToolApproval(parsedArgs)
225+
}
226+
227+
if chatOpts.TabId != "" {
228+
if argsMap, ok := parsedArgs.(map[string]any); ok {
229+
if widgetId, ok := argsMap["widget_id"].(string); ok && widgetId != "" {
230+
ctx, cancelFn := context.WithTimeout(context.Background(), 2*time.Second)
231+
defer cancelFn()
232+
fullBlockId, err := wcore.ResolveBlockIdFromPrefix(ctx, chatOpts.TabId, widgetId)
233+
if err == nil {
234+
toolUseData.BlockId = fullBlockId
235+
}
236+
}
237+
}
238+
}
239+
240+
return toolUseData
241+
}
242+
243+
244+
// SendToolProgress sends tool progress updates via SSE if the tool has a progress descriptor
245+
func SendToolProgress(toolCallID, toolName string, jsonData []byte, chatOpts uctypes.WaveChatOpts, sseHandler *sse.SSEHandlerCh, usePartialParse bool) {
246+
toolDef := chatOpts.GetToolDefinition(toolName)
247+
if toolDef == nil || toolDef.ToolProgressDesc == nil {
248+
return
249+
}
250+
251+
var parsedJSON any
252+
var err error
253+
if usePartialParse {
254+
parsedJSON, err = utilfn.ParsePartialJson(jsonData)
255+
} else {
256+
err = json.Unmarshal(jsonData, &parsedJSON)
257+
}
258+
if err != nil {
259+
return
260+
}
261+
262+
statusLines, err := toolDef.ToolProgressDesc(parsedJSON)
263+
if err != nil {
264+
return
265+
}
266+
267+
progressData := &uctypes.UIMessageDataToolProgress{
268+
ToolCallId: toolCallID,
269+
ToolName: toolName,
270+
StatusLines: statusLines,
271+
}
272+
_ = sseHandler.AiMsgData("data-toolprogress", "progress-"+toolCallID, progressData)
273+
}

pkg/aiusechat/openai/openai-backend.go

Lines changed: 4 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,11 @@ import (
1717

1818
"github.com/google/uuid"
1919
"github.com/launchdarkly/eventsource"
20+
"github.com/wavetermdev/waveterm/pkg/aiusechat/aiutil"
2021
"github.com/wavetermdev/waveterm/pkg/aiusechat/chatstore"
2122
"github.com/wavetermdev/waveterm/pkg/aiusechat/uctypes"
2223
"github.com/wavetermdev/waveterm/pkg/util/logutil"
2324
"github.com/wavetermdev/waveterm/pkg/util/utilfn"
24-
"github.com/wavetermdev/waveterm/pkg/wcore"
2525
"github.com/wavetermdev/waveterm/pkg/web/sse"
2626
)
2727

@@ -862,8 +862,7 @@ func handleOpenAIEvent(
862862
}
863863
if st := state.blockMap[ev.ItemId]; st != nil && st.kind == openaiBlockToolUse {
864864
st.partialJSON = append(st.partialJSON, []byte(ev.Delta)...)
865-
toolDef := state.chatOpts.GetToolDefinition(st.toolName)
866-
sendToolProgress(st, toolDef, sse, st.partialJSON, true)
865+
aiutil.SendToolProgress(st.toolCallID, st.toolName, st.partialJSON, state.chatOpts, sse, true)
867866
}
868867
return nil, nil
869868

@@ -876,10 +875,9 @@ func handleOpenAIEvent(
876875

877876
// Get the function call info from the block state
878877
if st := state.blockMap[ev.ItemId]; st != nil && st.kind == openaiBlockToolUse {
879-
toolDef := state.chatOpts.GetToolDefinition(st.toolName)
880-
toolUseData := createToolUseData(st.toolCallID, st.toolName, toolDef, ev.Arguments, state.chatOpts)
878+
toolUseData := aiutil.CreateToolUseData(st.toolCallID, st.toolName, ev.Arguments, state.chatOpts)
881879
state.toolUseData[st.toolCallID] = toolUseData
882-
sendToolProgress(st, toolDef, sse, []byte(ev.Arguments), false)
880+
aiutil.SendToolProgress(st.toolCallID, st.toolName, []byte(ev.Arguments), state.chatOpts, sse, false)
883881
}
884882
return nil, nil
885883

@@ -936,75 +934,6 @@ func handleOpenAIEvent(
936934
}
937935
}
938936

939-
func sendToolProgress(st *openaiBlockState, toolDef *uctypes.ToolDefinition, sse *sse.SSEHandlerCh, jsonData []byte, usePartialParse bool) {
940-
if toolDef == nil || toolDef.ToolProgressDesc == nil {
941-
return
942-
}
943-
var parsedJSON any
944-
var err error
945-
if usePartialParse {
946-
parsedJSON, err = utilfn.ParsePartialJson(jsonData)
947-
} else {
948-
err = json.Unmarshal(jsonData, &parsedJSON)
949-
}
950-
if err != nil {
951-
return
952-
}
953-
statusLines, err := toolDef.ToolProgressDesc(parsedJSON)
954-
if err != nil {
955-
return
956-
}
957-
progressData := &uctypes.UIMessageDataToolProgress{
958-
ToolCallId: st.toolCallID,
959-
ToolName: st.toolName,
960-
StatusLines: statusLines,
961-
}
962-
_ = sse.AiMsgData("data-toolprogress", "progress-"+st.toolCallID, progressData)
963-
}
964-
965-
func createToolUseData(toolCallID, toolName string, toolDef *uctypes.ToolDefinition, arguments string, chatOpts uctypes.WaveChatOpts) *uctypes.UIMessageDataToolUse {
966-
toolUseData := &uctypes.UIMessageDataToolUse{
967-
ToolCallId: toolCallID,
968-
ToolName: toolName,
969-
Status: uctypes.ToolUseStatusPending,
970-
}
971-
972-
if toolDef == nil {
973-
toolUseData.Status = uctypes.ToolUseStatusError
974-
toolUseData.ErrorMessage = "tool not found"
975-
return toolUseData
976-
}
977-
978-
var parsedArgs any
979-
if err := json.Unmarshal([]byte(arguments), &parsedArgs); err != nil {
980-
toolUseData.Status = uctypes.ToolUseStatusError
981-
toolUseData.ErrorMessage = fmt.Sprintf("failed to parse tool arguments: %v", err)
982-
return toolUseData
983-
}
984-
985-
if toolDef.ToolCallDesc != nil {
986-
toolUseData.ToolDesc = toolDef.ToolCallDesc(parsedArgs, nil, nil)
987-
}
988-
989-
if toolDef.ToolApproval != nil {
990-
toolUseData.Approval = toolDef.ToolApproval(parsedArgs)
991-
}
992-
993-
if chatOpts.TabId != "" {
994-
if argsMap, ok := parsedArgs.(map[string]any); ok {
995-
if widgetId, ok := argsMap["widget_id"].(string); ok && widgetId != "" {
996-
ctx, cancelFn := context.WithTimeout(context.Background(), 2*time.Second)
997-
defer cancelFn()
998-
fullBlockId, err := wcore.ResolveBlockIdFromPrefix(ctx, chatOpts.TabId, widgetId)
999-
if err == nil {
1000-
toolUseData.BlockId = fullBlockId
1001-
}
1002-
}
1003-
}
1004-
}
1005-
1006-
return toolUseData
1007-
}
1008937

1009938
// extractMessageAndToolsFromResponse extracts the final OpenAI message and tool calls from the completed response
1010939
func extractMessageAndToolsFromResponse(resp openaiResponse, state *openaiStreamingState) ([]*OpenAIChatMessage, []uctypes.WaveToolCall) {

0 commit comments

Comments
 (0)