Skip to content

Commit edc20f7

Browse files
Copilot and sawka authored
Bring Anthropic usechat backend to OpenAI-level tool-use parity and stream robustness (#2971)
This updates `pkg/aiusechat/anthropic` from a partial implementation to full backend parity for core tool-use orchestration and stream behavior. The main gaps were unimplemented tool lifecycle methods, missing persisted tool-use UI state, and weaker disconnect/error handling versus the OpenAI backend.

- **Tool-use lifecycle parity (critical path)**
  - Implemented Anthropic backend support for:
    - `UpdateToolUseData`
    - `RemoveToolUseCall`
    - `GetFunctionCallInputByToolCallId`
  - Wired `pkg/aiusechat/usechat-backend.go` to call Anthropic implementations instead of stubs.
  - Added Anthropic run-step nil-message guard so `nil` responses are not wrapped into `[]GenAIMessage{nil}`.
- **Persisted tool-use state in Anthropic native messages**
  - Added internal `ToolUseData` storage on Anthropic `tool_use` blocks.
  - Ensured internal-only fields are stripped before API requests via `Clean()`.
- **UI conversion parity for reloaded history**
  - Extended `ConvertToUIMessage()` to emit `data-tooluse` parts when tool-use metadata exists, in addition to `tool-{name}` parts.
- **Streaming UX parity for tool argument deltas**
  - Added `aiutil.SendToolProgress(...)` calls during:
    - `input_json_delta` (incremental updates)
    - `content_block_stop` for `tool_use` (final update)
- **Disconnect/stream robustness**
  - Added `sse.Err()` checks in event handling and decode-error path.
  - Added partial-text extraction on client disconnect and deterministic ordering of partial blocks.
  - Cleans up completed blocks from in-flight state to avoid duplicate partial extraction.
- **Correctness + hygiene alignment**
  - Continuation model checks now use `AreModelsCompatible(...)` (instead of strict string equality).
  - Added hostname sanitization in Anthropic error paths (HTTP error parsing and `httpClient.Do` failures).
  - Replaced unconditional Anthropic debug `log.Printf` calls with `logutil.DevPrintf`.
- **Targeted coverage additions**
  - Added Anthropic tests for:
    - function-call lookup by tool call id
    - tool-use data update + removal
    - `data-tooluse` UI conversion behavior

```go
// usechat-backend.go
func (b *anthropicBackend) RunChatStep(...) (..., []uctypes.GenAIMessage, ...) {
	stopReason, msg, rateLimitInfo, err := anthropic.RunAnthropicChatStep(ctx, sseHandler, chatOpts, cont)
	if msg == nil {
		return stopReason, nil, rateLimitInfo, err
	}
	return stopReason, []uctypes.GenAIMessage{msg}, rateLimitInfo, err
}
```

<!-- START COPILOT CODING AGENT TIPS -->

---

✨ Let Copilot coding agent [set things up for you](https://github.com/wavetermdev/waveterm/issues/new?title=✨+Set+up+Copilot+instructions&body=Configure%20instructions%20for%20this%20repository%20as%20documented%20in%20%5BBest%20practices%20for%20Copilot%20coding%20agent%20in%20your%20repository%5D%28https://gh.io/copilot-coding-agent-tips%29%2E%0A%0A%3COnboard%20this%20repo%3E&assignees=copilot) — coding agent works faster and does higher quality work when set up for your repo.

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: sawka <2722291+sawka@users.noreply.github.com>
Co-authored-by: sawka <mike@commandline.dev>
1 parent 9d4acb7 commit edc20f7

5 files changed

Lines changed: 335 additions & 41 deletions

File tree

frontend/app/aipanel/aipanel.tsx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,7 @@ const AIPanelComponentInner = memo(() => {
306306
};
307307

308308
useEffect(() => {
309-
globalStore.set(model.isAIStreaming, status == "streaming");
309+
globalStore.set(model.isAIStreaming, status === "streaming" || status === "submitted");
310310
}, [status]);
311311

312312
useEffect(() => {

pkg/aiusechat/anthropic/anthropic-backend.go

Lines changed: 111 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,9 @@ import (
1010
"errors"
1111
"fmt"
1212
"io"
13-
"log"
1413
"net/http"
14+
"net/url"
15+
"sort"
1516
"strings"
1617
"time"
1718

@@ -20,6 +21,7 @@ import (
2021
"github.com/wavetermdev/waveterm/pkg/aiusechat/aiutil"
2122
"github.com/wavetermdev/waveterm/pkg/aiusechat/chatstore"
2223
"github.com/wavetermdev/waveterm/pkg/aiusechat/uctypes"
24+
"github.com/wavetermdev/waveterm/pkg/util/logutil"
2325
"github.com/wavetermdev/waveterm/pkg/util/utilfn"
2426
"github.com/wavetermdev/waveterm/pkg/web/sse"
2527
)
@@ -56,10 +58,11 @@ func (m *anthropicChatMessage) GetUsage() *uctypes.AIUsage {
5658
}
5759

5860
return &uctypes.AIUsage{
59-
APIType: uctypes.APIType_AnthropicMessages,
60-
Model: m.Usage.Model,
61-
InputTokens: m.Usage.InputTokens,
62-
OutputTokens: m.Usage.OutputTokens,
61+
APIType: uctypes.APIType_AnthropicMessages,
62+
Model: m.Usage.Model,
63+
InputTokens: m.Usage.InputTokens,
64+
OutputTokens: m.Usage.OutputTokens,
65+
NativeWebSearchCount: m.Usage.NativeWebSearchCount,
6366
}
6467
}
6568

@@ -95,8 +98,9 @@ type anthropicMessageContentBlock struct {
9598
Name string `json:"name,omitempty"`
9699
Input interface{} `json:"input,omitempty"`
97100

98-
ToolUseDisplayName string `json:"toolusedisplayname,omitempty"` // internal field (cannot marshal to API, must be stripped)
99-
ToolUseShortDescription string `json:"tooluseshortdescription,omitempty"` // internal field (cannot marshal to API, must be stripped)
101+
ToolUseDisplayName string `json:"toolusedisplayname,omitempty"` // internal field (cannot marshal to API, must be stripped)
102+
ToolUseShortDescription string `json:"tooluseshortdescription,omitempty"` // internal field (cannot marshal to API, must be stripped)
103+
ToolUseData *uctypes.UIMessageDataToolUse `json:"toolusedata,omitempty"` // internal field (cannot marshal to API, must be stripped)
100104

101105
// Tool result content
102106
ToolUseID string `json:"tool_use_id,omitempty"`
@@ -154,6 +158,7 @@ func (b *anthropicMessageContentBlock) Clean() *anthropicMessageContentBlock {
154158
rtn.SourcePreviewUrl = ""
155159
rtn.ToolUseDisplayName = ""
156160
rtn.ToolUseShortDescription = ""
161+
rtn.ToolUseData = nil
157162
if rtn.Source != nil {
158163
rtn.Source = rtn.Source.Clean()
159164
}
@@ -177,10 +182,15 @@ type anthropicStreamRequest struct {
177182
Stream bool `json:"stream"`
178183
System []anthropicMessageContentBlock `json:"system,omitempty"`
179184
ToolChoice any `json:"tool_choice,omitempty"`
180-
Tools []uctypes.ToolDefinition `json:"tools,omitempty"`
185+
Tools []any `json:"tools,omitempty"` // *uctypes.ToolDefinition or *anthropicWebSearchTool
181186
Thinking *anthropicThinkingOpts `json:"thinking,omitempty"`
182187
}
183188

189+
// anthropicWebSearchTool is the non-function entry that may appear in
// anthropicStreamRequest.Tools alongside regular tool definitions. Only the
// two identifying fields are serialized; both are always emitted (no
// omitempty).
type anthropicWebSearchTool struct {
	// Type is the versioned tool type, e.g. "web_search_20250305".
	Type string `json:"type"`
	// Name is the tool name, e.g. "web_search".
	Name string `json:"name"`
}
193+
184194
type anthropicCacheControl struct {
185195
Type string `json:"type"` // "ephemeral"
186196
TTL string `json:"ttl"` // "5m" or "1h"
@@ -228,8 +238,9 @@ type anthropicUsageType struct {
228238
CacheCreationInputTokens int `json:"cache_creation_input_tokens,omitempty"`
229239
CacheReadInputTokens int `json:"cache_read_input_tokens,omitempty"`
230240

231-
// internal field for Wave use (not sent to API)
232-
Model string `json:"model,omitempty"`
241+
// internal fields for Wave use (not sent to API)
242+
Model string `json:"model,omitempty"`
243+
NativeWebSearchCount int `json:"nativewebsearchcount,omitempty"`
233244

234245
// for reference, but we don't keep these up to date or track them
235246
CacheCreation *anthropicCacheCreationType `json:"cache_creation,omitempty"` // breakdown of cached tokens by TTL
@@ -290,14 +301,16 @@ type partialJSON struct {
290301
}
291302

292303
type streamingState struct {
293-
blockMap map[int]*blockState
294-
toolCalls []uctypes.WaveToolCall
295-
stopFromDelta string
296-
msgID string
297-
model string
298-
stepStarted bool
299-
rtnMessage *anthropicChatMessage
300-
usage *anthropicUsageType
304+
blockMap map[int]*blockState
305+
toolCalls []uctypes.WaveToolCall
306+
stopFromDelta string
307+
msgID string
308+
model string
309+
stepStarted bool
310+
rtnMessage *anthropicChatMessage
311+
usage *anthropicUsageType
312+
chatOpts uctypes.WaveChatOpts
313+
webSearchCount int
301314
}
302315

303316
func (p *partialJSON) Write(s string) {
@@ -330,6 +343,20 @@ func (p *partialJSON) FinalObject() (json.RawMessage, error) {
330343
}
331344
}
332345

346+
// sanitizeHostnameInError removes the Wave cloud hostname from error messages
347+
func sanitizeHostnameInError(err error) error {
348+
if err == nil {
349+
return nil
350+
}
351+
errStr := err.Error()
352+
parsedURL, parseErr := url.Parse(uctypes.DefaultAIEndpoint)
353+
if parseErr == nil && parsedURL.Host != "" && strings.Contains(errStr, parsedURL.Host) {
354+
errStr = strings.ReplaceAll(errStr, uctypes.DefaultAIEndpoint, "AI service")
355+
errStr = strings.ReplaceAll(errStr, parsedURL.Host, "host")
356+
}
357+
return fmt.Errorf("%s", errStr)
358+
}
359+
333360
// makeThinkingOpts creates thinking options based on level and max tokens
334361
func makeThinkingOpts(thinkingLevel string, maxTokens int) *anthropicThinkingOpts {
335362
if thinkingLevel != uctypes.ThinkingLevelMedium && thinkingLevel != uctypes.ThinkingLevelHigh {
@@ -373,21 +400,21 @@ func parseAnthropicHTTPError(resp *http.Response) error {
373400
// Try to parse as Anthropic error format first
374401
var eresp anthropicHTTPErrorResponse
375402
if err := json.Unmarshal(slurp, &eresp); err == nil && eresp.Error.Message != "" {
376-
return fmt.Errorf("anthropic %s: %s", resp.Status, eresp.Error.Message)
403+
return sanitizeHostnameInError(fmt.Errorf("anthropic %s: %s", resp.Status, eresp.Error.Message))
377404
}
378405

379406
// Try to parse as proxy error format
380407
var proxyErr uctypes.ProxyErrorResponse
381408
if err := json.Unmarshal(slurp, &proxyErr); err == nil && !proxyErr.Success && proxyErr.Error != "" {
382-
return fmt.Errorf("anthropic %s: %s", resp.Status, proxyErr.Error)
409+
return sanitizeHostnameInError(fmt.Errorf("anthropic %s: %s", resp.Status, proxyErr.Error))
383410
}
384411

385412
// Fall back to truncated raw response
386413
msg := utilfn.TruncateString(strings.TrimSpace(string(slurp)), 120)
387414
if msg == "" {
388415
msg = "unknown error"
389416
}
390-
return fmt.Errorf("anthropic %s: %s", resp.Status, msg)
417+
return sanitizeHostnameInError(fmt.Errorf("anthropic %s: %s", resp.Status, msg))
391418
}
392419

393420
func RunAnthropicChatStep(
@@ -426,7 +453,7 @@ func RunAnthropicChatStep(
426453

427454
// Validate continuation if provided
428455
if cont != nil {
429-
if chatOpts.Config.Model != cont.Model {
456+
if !uctypes.AreModelsCompatible(chat.APIType, chatOpts.Config.Model, cont.Model) {
430457
return nil, nil, nil, fmt.Errorf("cannot continue with a different model, model:%q, cont-model:%q", chatOpts.Config.Model, cont.Model)
431458
}
432459
}
@@ -461,7 +488,7 @@ func RunAnthropicChatStep(
461488

462489
resp, err := httpClient.Do(req)
463490
if err != nil {
464-
return nil, nil, nil, err
491+
return nil, nil, nil, sanitizeHostnameInError(err)
465492
}
466493
defer resp.Body.Close()
467494

@@ -499,7 +526,7 @@ func RunAnthropicChatStep(
499526
// Use eventsource decoder for proper SSE parsing
500527
decoder := eventsource.NewDecoder(resp.Body)
501528

502-
stopReason, rtnMessage := handleAnthropicStreamingResp(ctx, sse, decoder, cont)
529+
stopReason, rtnMessage := handleAnthropicStreamingResp(ctx, sse, decoder, cont, chatOpts)
503530
return stopReason, rtnMessage, rateLimitInfo, nil
504531
}
505532

@@ -509,6 +536,7 @@ func handleAnthropicStreamingResp(
509536
sse *sse.SSEHandlerCh,
510537
decoder *eventsource.Decoder,
511538
cont *uctypes.WaveContinueResponse,
539+
chatOpts uctypes.WaveChatOpts,
512540
) (*uctypes.WaveStopReason, *anthropicChatMessage) {
513541
// Per-response state
514542
state := &streamingState{
@@ -518,6 +546,7 @@ func handleAnthropicStreamingResp(
518546
Role: "assistant",
519547
Content: []anthropicMessageContentBlock{},
520548
},
549+
chatOpts: chatOpts,
521550
}
522551

523552
var rtnStopReason *uctypes.WaveStopReason
@@ -526,8 +555,10 @@ func handleAnthropicStreamingResp(
526555
defer func() {
527556
// Set usage in the returned message
528557
if state.usage != nil {
529-
// Set model in usage for internal use
530558
state.usage.Model = state.model
559+
if state.webSearchCount > 0 {
560+
state.usage.NativeWebSearchCount = state.webSearchCount
561+
}
531562
state.rtnMessage.Usage = state.usage
532563
}
533564

@@ -558,6 +589,13 @@ func handleAnthropicStreamingResp(
558589
// Normal end of stream
559590
break
560591
}
592+
if sse.Err() != nil {
593+
return &uctypes.WaveStopReason{
594+
Kind: uctypes.StopKindCanceled,
595+
ErrorType: "client_disconnect",
596+
ErrorText: "client disconnected",
597+
}, extractPartialTextFromState(state)
598+
}
561599
// transport error mid-stream
562600
_ = sse.AiMsgError(err.Error())
563601
return &uctypes.WaveStopReason{
@@ -587,6 +625,37 @@ func handleAnthropicStreamingResp(
587625
return rtnStopReason, state.rtnMessage
588626
}
589627

628+
func extractPartialTextFromState(state *streamingState) *anthropicChatMessage {
629+
var content []anthropicMessageContentBlock
630+
for _, block := range state.rtnMessage.Content {
631+
if block.Type == "text" && block.Text != "" {
632+
content = append(content, block)
633+
}
634+
}
635+
var partialIdx []int
636+
for idx, st := range state.blockMap {
637+
if st.kind == blockText && st.contentBlock != nil && st.contentBlock.Text != "" {
638+
partialIdx = append(partialIdx, idx)
639+
}
640+
}
641+
sort.Ints(partialIdx)
642+
for _, idx := range partialIdx {
643+
st := state.blockMap[idx]
644+
if st.kind == blockText && st.contentBlock != nil && st.contentBlock.Text != "" {
645+
content = append(content, *st.contentBlock)
646+
}
647+
}
648+
if len(content) == 0 {
649+
return nil
650+
}
651+
return &anthropicChatMessage{
652+
MessageId: state.rtnMessage.MessageId,
653+
Role: "assistant",
654+
Content: content,
655+
Usage: state.rtnMessage.Usage,
656+
}
657+
}
658+
590659
// handleAnthropicEvent processes one SSE event block. It may emit SSE parts
591660
// and/or return a StopReason when the stream is complete.
592661
//
@@ -601,6 +670,13 @@ func handleAnthropicEvent(
601670
state *streamingState,
602671
cont *uctypes.WaveContinueResponse,
603672
) (stopFromDelta *string, final *uctypes.WaveStopReason) {
673+
if err := sse.Err(); err != nil {
674+
return nil, &uctypes.WaveStopReason{
675+
Kind: uctypes.StopKindCanceled,
676+
ErrorType: "client_disconnect",
677+
ErrorText: "client disconnected",
678+
}
679+
}
604680
eventName := event.Event()
605681
data := event.Data()
606682
switch eventName {
@@ -693,6 +769,10 @@ func handleAnthropicEvent(
693769
}
694770
state.blockMap[idx] = st
695771
_ = sse.AiMsgToolInputStart(tcID, tName)
772+
case "server_tool_use":
773+
if ev.ContentBlock.Name == "web_search" {
774+
state.webSearchCount++
775+
}
696776
default:
697777
// ignore other block types gracefully per Anthropic guidance
698778
}
@@ -732,6 +812,7 @@ func handleAnthropicEvent(
732812
if st.kind == blockToolUse {
733813
st.accumJSON.Write(ev.Delta.PartialJSON)
734814
_ = sse.AiMsgToolInputDelta(st.toolCallID, ev.Delta.PartialJSON)
815+
aiutil.SendToolProgress(st.toolCallID, st.toolName, st.accumJSON.Bytes(), state.chatOpts, sse, true)
735816
}
736817
case "signature_delta":
737818
// Accumulate signature for thinking blocks
@@ -784,6 +865,7 @@ func handleAnthropicEvent(
784865
}
785866
}
786867
_ = sse.AiMsgToolInputAvailable(st.toolCallID, st.toolName, raw)
868+
aiutil.SendToolProgress(st.toolCallID, st.toolName, raw, state.chatOpts, sse, false)
787869
state.toolCalls = append(state.toolCalls, uctypes.WaveToolCall{
788870
ID: st.toolCallID,
789871
Name: st.toolName,
@@ -798,6 +880,9 @@ func handleAnthropicEvent(
798880
}
799881
state.rtnMessage.Content = append(state.rtnMessage.Content, toolUseBlock)
800882
}
883+
// extractPartialTextFromState reads blockMap for still-in-flight content, so remove completed blocks
884+
// once they have been appended to rtnMessage.Content to avoid duplicate text on disconnect.
885+
delete(state.blockMap, *ev.Index)
801886
return nil, nil
802887

803888
case "message_delta":
@@ -868,7 +953,7 @@ func handleAnthropicEvent(
868953
}
869954

870955
default:
871-
log.Printf("unknown anthropic event type: %s", eventName)
956+
logutil.DevPrintf("unknown anthropic event type: %s", eventName)
872957
return nil, nil
873958
}
874959
}

0 commit comments

Comments
 (0)