Skip to content

Commit de902ec

Browse files
authored
ai backend refactor + claude/anthropic API support (#1262)
1 parent 38eeba5 commit de902ec

5 files changed

Lines changed: 589 additions & 246 deletions

File tree

frontend/app/view/waveai/waveai.tsx

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,15 @@ export class WaveAiModel implements ViewModel {
180180
const presetKey = get(this.presetKey);
181181
const presetName = presets[presetKey]?.["display:name"] ?? "";
182182
const isCloud = isBlank(aiOpts.apitoken) && isBlank(aiOpts.baseurl);
183-
if (isCloud) {
183+
if (aiOpts?.apitype == "anthropic") {
184+
const modelName = aiOpts.model;
185+
viewTextChildren.push({
186+
elemtype: "iconbutton",
187+
icon: "globe",
188+
title: "Using Remote Antropic API (" + modelName + ")",
189+
disabled: true,
190+
});
191+
} else if (isCloud) {
184192
viewTextChildren.push({
185193
elemtype: "iconbutton",
186194
icon: "cloud",

pkg/waveai/anthropicbackend.go

Lines changed: 300 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,300 @@
1+
// Copyright 2024, Command Line Inc.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
package waveai
5+
6+
import (
7+
"bufio"
8+
"context"
9+
"encoding/json"
10+
"errors"
11+
"fmt"
12+
"io"
13+
"log"
14+
"net/http"
15+
"runtime/debug"
16+
"strings"
17+
18+
"github.com/wavetermdev/waveterm/pkg/wshrpc"
19+
)
20+
21+
type AnthropicBackend struct{}
22+
23+
var _ AIBackend = AnthropicBackend{}
24+
25+
// Claude API request types
26+
type anthropicMessage struct {
27+
Role string `json:"role"`
28+
Content string `json:"content"`
29+
}
30+
31+
type anthropicRequest struct {
32+
Model string `json:"model"`
33+
Messages []anthropicMessage `json:"messages"`
34+
System string `json:"system,omitempty"`
35+
MaxTokens int `json:"max_tokens,omitempty"`
36+
Stream bool `json:"stream"`
37+
Temperature float32 `json:"temperature,omitempty"`
38+
}
39+
40+
// Claude API response types for SSE events
41+
type anthropicContentBlock struct {
42+
Type string `json:"type"` // "text" or other content types
43+
Text string `json:"text,omitempty"`
44+
}
45+
46+
type anthropicUsage struct {
47+
InputTokens int `json:"input_tokens"`
48+
OutputTokens int `json:"output_tokens"`
49+
}
50+
51+
type anthropicResponseMessage struct {
52+
ID string `json:"id"`
53+
Type string `json:"type"`
54+
Role string `json:"role"`
55+
Content []anthropicContentBlock `json:"content"`
56+
Model string `json:"model"`
57+
StopReason string `json:"stop_reason,omitempty"`
58+
StopSequence string `json:"stop_sequence,omitempty"`
59+
Usage *anthropicUsage `json:"usage,omitempty"`
60+
}
61+
62+
type anthropicStreamEventError struct {
63+
Type string `json:"type"`
64+
Message string `json:"message"`
65+
}
66+
67+
type anthropicStreamEventDelta struct {
68+
Text string `json:"text"`
69+
}
70+
71+
type anthropicStreamEvent struct {
72+
Type string `json:"type"`
73+
Message *anthropicResponseMessage `json:"message,omitempty"`
74+
ContentBlock *anthropicContentBlock `json:"content_block,omitempty"`
75+
Delta *anthropicStreamEventDelta `json:"delta,omitempty"`
76+
Error *anthropicStreamEventError `json:"error,omitempty"`
77+
Usage *anthropicUsage `json:"usage,omitempty"`
78+
}
79+
80+
// SSE event represents a parsed Server-Sent Event
81+
type sseEvent struct {
82+
Event string // The event type field
83+
Data string // The data field
84+
}
85+
86+
// parseSSE reads and parses SSE format from a bufio.Reader
87+
func parseSSE(reader *bufio.Reader) (*sseEvent, error) {
88+
var event sseEvent
89+
90+
for {
91+
line, err := reader.ReadString('\n')
92+
if err != nil {
93+
return nil, err
94+
}
95+
96+
line = strings.TrimSpace(line)
97+
if line == "" {
98+
// Empty line signals end of event
99+
if event.Event != "" || event.Data != "" {
100+
return &event, nil
101+
}
102+
continue
103+
}
104+
105+
if strings.HasPrefix(line, "event:") {
106+
event.Event = strings.TrimSpace(strings.TrimPrefix(line, "event:"))
107+
} else if strings.HasPrefix(line, "data:") {
108+
event.Data = strings.TrimSpace(strings.TrimPrefix(line, "data:"))
109+
}
110+
}
111+
}
112+
113+
func (AnthropicBackend) StreamCompletion(ctx context.Context, request wshrpc.OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] {
114+
rtn := make(chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType])
115+
116+
go func() {
117+
defer func() {
118+
if r := recover(); r != nil {
119+
// Convert panic to error and send it
120+
log.Printf("panic: %v\n", r)
121+
debug.PrintStack()
122+
err, ok := r.(error)
123+
if !ok {
124+
err = fmt.Errorf("anthropic backend panic: %v", r)
125+
}
126+
rtn <- makeAIError(err)
127+
}
128+
// Always close the channel
129+
close(rtn)
130+
}()
131+
132+
if request.Opts == nil {
133+
rtn <- makeAIError(errors.New("no anthropic opts found"))
134+
return
135+
}
136+
137+
model := request.Opts.Model
138+
if model == "" {
139+
model = "claude-3-sonnet-20240229" // default model
140+
}
141+
142+
// Convert messages format
143+
var messages []anthropicMessage
144+
var systemPrompt string
145+
146+
for _, msg := range request.Prompt {
147+
if msg.Role == "system" {
148+
if systemPrompt != "" {
149+
systemPrompt += "\n"
150+
}
151+
systemPrompt += msg.Content
152+
continue
153+
}
154+
155+
role := "user"
156+
if msg.Role == "assistant" {
157+
role = "assistant"
158+
}
159+
160+
messages = append(messages, anthropicMessage{
161+
Role: role,
162+
Content: msg.Content,
163+
})
164+
}
165+
166+
anthropicReq := anthropicRequest{
167+
Model: model,
168+
Messages: messages,
169+
System: systemPrompt,
170+
Stream: true,
171+
MaxTokens: request.Opts.MaxTokens,
172+
}
173+
174+
reqBody, err := json.Marshal(anthropicReq)
175+
if err != nil {
176+
rtn <- makeAIError(fmt.Errorf("failed to marshal anthropic request: %v", err))
177+
return
178+
}
179+
180+
req, err := http.NewRequestWithContext(ctx, "POST", "https://api.anthropic.com/v1/messages", strings.NewReader(string(reqBody)))
181+
if err != nil {
182+
rtn <- makeAIError(fmt.Errorf("failed to create anthropic request: %v", err))
183+
return
184+
}
185+
186+
req.Header.Set("Content-Type", "application/json")
187+
req.Header.Set("Accept", "text/event-stream")
188+
req.Header.Set("x-api-key", request.Opts.APIToken)
189+
req.Header.Set("anthropic-version", "2023-06-01")
190+
191+
client := &http.Client{}
192+
resp, err := client.Do(req)
193+
if err != nil {
194+
rtn <- makeAIError(fmt.Errorf("failed to send anthropic request: %v", err))
195+
return
196+
}
197+
defer resp.Body.Close()
198+
199+
if resp.StatusCode != http.StatusOK {
200+
bodyBytes, _ := io.ReadAll(resp.Body)
201+
rtn <- makeAIError(fmt.Errorf("Anthropic API error: %s - %s", resp.Status, string(bodyBytes)))
202+
return
203+
}
204+
205+
reader := bufio.NewReader(resp.Body)
206+
for {
207+
// Check for context cancellation
208+
select {
209+
case <-ctx.Done():
210+
rtn <- makeAIError(fmt.Errorf("request cancelled: %v", ctx.Err()))
211+
return
212+
default:
213+
}
214+
215+
sse, err := parseSSE(reader)
216+
if err == io.EOF {
217+
break
218+
}
219+
if err != nil {
220+
rtn <- makeAIError(fmt.Errorf("error reading SSE stream: %v", err))
221+
break
222+
}
223+
224+
if sse.Event == "ping" {
225+
continue // Ignore ping events
226+
}
227+
228+
var event anthropicStreamEvent
229+
if err := json.Unmarshal([]byte(sse.Data), &event); err != nil {
230+
rtn <- makeAIError(fmt.Errorf("error unmarshaling event data: %v", err))
231+
break
232+
}
233+
234+
if event.Error != nil {
235+
rtn <- makeAIError(fmt.Errorf("Anthropic API error: %s - %s", event.Error.Type, event.Error.Message))
236+
break
237+
}
238+
239+
switch sse.Event {
240+
case "message_start":
241+
if event.Message != nil {
242+
pk := MakeOpenAIPacket()
243+
pk.Model = event.Message.Model
244+
rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Response: *pk}
245+
}
246+
247+
case "content_block_start":
248+
if event.ContentBlock != nil && event.ContentBlock.Text != "" {
249+
pk := MakeOpenAIPacket()
250+
pk.Text = event.ContentBlock.Text
251+
rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Response: *pk}
252+
}
253+
254+
case "content_block_delta":
255+
if event.Delta != nil && event.Delta.Text != "" {
256+
pk := MakeOpenAIPacket()
257+
pk.Text = event.Delta.Text
258+
rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Response: *pk}
259+
}
260+
261+
case "content_block_stop":
262+
// Note: According to the docs, this just signals the end of a content block
263+
// We might want to use this for tracking block boundaries, but for now
264+
// we don't need to send anything special to match OpenAI's format
265+
266+
case "message_delta":
267+
// Update message metadata, usage stats
268+
if event.Usage != nil {
269+
pk := MakeOpenAIPacket()
270+
pk.Usage = &wshrpc.OpenAIUsageType{
271+
PromptTokens: event.Usage.InputTokens,
272+
CompletionTokens: event.Usage.OutputTokens,
273+
TotalTokens: event.Usage.InputTokens + event.Usage.OutputTokens,
274+
}
275+
rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Response: *pk}
276+
}
277+
278+
case "message_stop":
279+
if event.Message != nil {
280+
pk := MakeOpenAIPacket()
281+
pk.FinishReason = event.Message.StopReason
282+
if event.Message.Usage != nil {
283+
pk.Usage = &wshrpc.OpenAIUsageType{
284+
PromptTokens: event.Message.Usage.InputTokens,
285+
CompletionTokens: event.Message.Usage.OutputTokens,
286+
TotalTokens: event.Message.Usage.InputTokens + event.Message.Usage.OutputTokens,
287+
}
288+
}
289+
rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Response: *pk}
290+
}
291+
292+
default:
293+
rtn <- makeAIError(fmt.Errorf("unknown Anthropic event type: %s", sse.Event))
294+
return
295+
}
296+
}
297+
}()
298+
299+
return rtn
300+
}

0 commit comments

Comments
 (0)