88 "errors"
99 "fmt"
1010 "io"
11+ "log"
1112 "regexp"
13+ "runtime/debug"
1214 "strings"
1315
1416 openaiapi "github.com/sashabaranov/go-openai"
@@ -72,7 +74,20 @@ func convertUsage(resp openaiapi.ChatCompletionResponse) *wshrpc.OpenAIUsageType
7274func (OpenAIBackend ) StreamCompletion (ctx context.Context , request wshrpc.OpenAiStreamRequest ) chan wshrpc.RespOrErrorUnion [wshrpc.OpenAIPacketType ] {
7375 rtn := make (chan wshrpc.RespOrErrorUnion [wshrpc.OpenAIPacketType ])
7476 go func () {
75- defer close (rtn )
77+ defer func () {
78+ if r := recover (); r != nil {
79+ // Convert panic to error and send it
80+ log .Printf ("panic: %v\n " , r )
81+ debug .PrintStack ()
82+ err , ok := r .(error )
83+ if ! ok {
84+ err = fmt .Errorf ("openai backend panic: %v" , r )
85+ }
86+ rtn <- makeAIError (err )
87+ }
88+ // Always close the channel
89+ close (rtn )
90+ }()
7691 if request .Opts == nil {
7792 rtn <- makeAIError (errors .New ("no openai opts found" ))
7893 return
@@ -85,6 +100,7 @@ func (OpenAIBackend) StreamCompletion(ctx context.Context, request wshrpc.OpenAi
85100 rtn <- makeAIError (errors .New ("no api token" ))
86101 return
87102 }
103+
88104 clientConfig := openaiapi .DefaultConfig (request .Opts .APIToken )
89105 if request .Opts .BaseURL != "" {
90106 clientConfig .BaseURL = request .Opts .BaseURL
@@ -100,17 +116,49 @@ func (OpenAIBackend) StreamCompletion(ctx context.Context, request wshrpc.OpenAi
100116 if request .Opts .APIVersion != "" {
101117 clientConfig .APIVersion = request .Opts .APIVersion
102118 }
119+
103120 client := openaiapi .NewClientWithConfig (clientConfig )
104121 req := openaiapi.ChatCompletionRequest {
105- Model : request .Opts .Model ,
106- Messages : convertPrompt (request .Prompt ),
107- MaxTokens : request .Opts .MaxTokens ,
108- MaxCompletionTokens : request .Opts .MaxTokens ,
109- Stream : true ,
122+ Model : request .Opts .Model ,
123+ Messages : convertPrompt (request .Prompt ),
110124 }
125+
126+ // Handle o1 models differently - use non-streaming API
127+ if strings .HasPrefix (request .Opts .Model , "o1-" ) {
128+ req .MaxCompletionTokens = request .Opts .MaxTokens
129+ req .Stream = false
130+
131+ // Make non-streaming API call
132+ resp , err := client .CreateChatCompletion (ctx , req )
133+ if err != nil {
134+ rtn <- makeAIError (fmt .Errorf ("error calling openai API: %v" , err ))
135+ return
136+ }
137+
138+ // Send header packet
139+ headerPk := MakeOpenAIPacket ()
140+ headerPk .Model = resp .Model
141+ headerPk .Created = resp .Created
142+ rtn <- wshrpc.RespOrErrorUnion [wshrpc.OpenAIPacketType ]{Response : * headerPk }
143+
144+ // Send content packet(s)
145+ for i , choice := range resp .Choices {
146+ pk := MakeOpenAIPacket ()
147+ pk .Index = i
148+ pk .Text = choice .Message .Content
149+ pk .FinishReason = string (choice .FinishReason )
150+ rtn <- wshrpc.RespOrErrorUnion [wshrpc.OpenAIPacketType ]{Response : * pk }
151+ }
152+ return
153+ }
154+
155+ // Original streaming implementation for non-o1 models
156+ req .Stream = true
157+ req .MaxTokens = request .Opts .MaxTokens
111158 if request .Opts .MaxChoices > 1 {
112159 req .N = request .Opts .MaxChoices
113160 }
161+
114162 apiResp , err := client .CreateChatCompletionStream (ctx , req )
115163 if err != nil {
116164 rtn <- makeAIError (fmt .Errorf ("error calling openai API: %v" , err ))
0 commit comments