Skip to content

Commit 95fd006

Browse files
authored
fixes for o1 models (#1269)
1 parent 29e54c8 commit 95fd006

2 files changed

Lines changed: 56 additions & 7 deletions

File tree

pkg/waveai/openaibackend.go

Lines changed: 54 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@ import (
88
"errors"
99
"fmt"
1010
"io"
11+
"log"
1112
"regexp"
13+
"runtime/debug"
1214
"strings"
1315

1416
openaiapi "github.com/sashabaranov/go-openai"
@@ -72,7 +74,20 @@ func convertUsage(resp openaiapi.ChatCompletionResponse) *wshrpc.OpenAIUsageType
7274
func (OpenAIBackend) StreamCompletion(ctx context.Context, request wshrpc.OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] {
7375
rtn := make(chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType])
7476
go func() {
75-
defer close(rtn)
77+
defer func() {
78+
if r := recover(); r != nil {
79+
// Convert panic to error and send it
80+
log.Printf("panic: %v\n", r)
81+
debug.PrintStack()
82+
err, ok := r.(error)
83+
if !ok {
84+
err = fmt.Errorf("openai backend panic: %v", r)
85+
}
86+
rtn <- makeAIError(err)
87+
}
88+
// Always close the channel
89+
close(rtn)
90+
}()
7691
if request.Opts == nil {
7792
rtn <- makeAIError(errors.New("no openai opts found"))
7893
return
@@ -85,6 +100,7 @@ func (OpenAIBackend) StreamCompletion(ctx context.Context, request wshrpc.OpenAi
85100
rtn <- makeAIError(errors.New("no api token"))
86101
return
87102
}
103+
88104
clientConfig := openaiapi.DefaultConfig(request.Opts.APIToken)
89105
if request.Opts.BaseURL != "" {
90106
clientConfig.BaseURL = request.Opts.BaseURL
@@ -100,17 +116,49 @@ func (OpenAIBackend) StreamCompletion(ctx context.Context, request wshrpc.OpenAi
100116
if request.Opts.APIVersion != "" {
101117
clientConfig.APIVersion = request.Opts.APIVersion
102118
}
119+
103120
client := openaiapi.NewClientWithConfig(clientConfig)
104121
req := openaiapi.ChatCompletionRequest{
105-
Model: request.Opts.Model,
106-
Messages: convertPrompt(request.Prompt),
107-
MaxTokens: request.Opts.MaxTokens,
108-
MaxCompletionTokens: request.Opts.MaxTokens,
109-
Stream: true,
122+
Model: request.Opts.Model,
123+
Messages: convertPrompt(request.Prompt),
110124
}
125+
126+
// Handle o1 models differently - use non-streaming API
127+
if strings.HasPrefix(request.Opts.Model, "o1-") {
128+
req.MaxCompletionTokens = request.Opts.MaxTokens
129+
req.Stream = false
130+
131+
// Make non-streaming API call
132+
resp, err := client.CreateChatCompletion(ctx, req)
133+
if err != nil {
134+
rtn <- makeAIError(fmt.Errorf("error calling openai API: %v", err))
135+
return
136+
}
137+
138+
// Send header packet
139+
headerPk := MakeOpenAIPacket()
140+
headerPk.Model = resp.Model
141+
headerPk.Created = resp.Created
142+
rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Response: *headerPk}
143+
144+
// Send content packet(s)
145+
for i, choice := range resp.Choices {
146+
pk := MakeOpenAIPacket()
147+
pk.Index = i
148+
pk.Text = choice.Message.Content
149+
pk.FinishReason = string(choice.FinishReason)
150+
rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Response: *pk}
151+
}
152+
return
153+
}
154+
155+
// Original streaming implementation for non-o1 models
156+
req.Stream = true
157+
req.MaxTokens = request.Opts.MaxTokens
111158
if request.Opts.MaxChoices > 1 {
112159
req.N = request.Opts.MaxChoices
113160
}
161+
114162
apiResp, err := client.CreateChatCompletionStream(ctx, req)
115163
if err != nil {
116164
rtn <- makeAIError(fmt.Errorf("error calling openai API: %v", err))

pkg/waveai/waveai.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ const OpenAIPacketStr = "openai"
1515
const OpenAICloudReqStr = "openai-cloudreq"
1616
const PacketEOFStr = "EOF"
1717
const DefaultAzureAPIVersion = "2023-05-15"
18+
const ApiType_Anthropic = "anthropic"
1819

1920
type OpenAICmdInfoPacketOutputType struct {
2021
Model string `json:"model,omitempty"`
@@ -62,7 +63,7 @@ func makeAIError(err error) wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] {
6263
}
6364

6465
func RunAICommand(ctx context.Context, request wshrpc.OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] {
65-
if request.Opts.APIType == "anthropic" {
66+
if request.Opts.APIType == ApiType_Anthropic {
6667
endpoint := request.Opts.BaseURL
6768
if endpoint == "" {
6869
endpoint = "default"

0 commit comments

Comments
 (0)