Skip to content

Commit 1ecdc2e

Browse files
Fix parallel_tool_cals parameter not being sent
1 parent d751426 commit 1ecdc2e

2 files changed

Lines changed: 52 additions & 51 deletions

File tree

Kattbot.Common/Models/KattGpt/ChatCompletionCreateRequest.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ public record ChatCompletionCreateRequest
5050
/// https://platform.openai.com/docs/api-reference/chat/create#chat-create-parallel_tool_calls
5151
/// </summary>
5252
[JsonPropertyName("parallel_tool_calls")]
53-
public bool ParallelToolCalls { get; set; }
53+
public bool? ParallelToolCalls { get; set; }
5454

5555
/// <summary>
5656
/// Gets or sets what sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more
@@ -135,4 +135,4 @@ public record ChatCompletionCreateRequest
135135
/// </summary>
136136
[JsonPropertyName("user")]
137137
public string? User { get; set; }
138-
}
138+
}

Kattbot/NotificationHandlers/KattGptMessageHandler.cs

Lines changed: 50 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -65,35 +65,35 @@ public KattGptMessageHandler(
6565

6666
public async Task Handle(MessageCreatedNotification notification, CancellationToken cancellationToken)
6767
{
68-
MessageCreatedEventArgs args = notification.EventArgs;
69-
DiscordMessage message = args.Message;
70-
DiscordUser author = args.Author;
71-
DiscordChannel channel = args.Message.Channel ?? throw new Exception("Channel is null.");
68+
MessageCreatedEventArgs? args = notification.EventArgs;
69+
DiscordMessage? message = args.Message;
70+
DiscordUser? author = args.Author;
71+
DiscordChannel? channel = args.Message.Channel ?? throw new Exception("Channel is null.");
7272

7373
if (!ShouldHandleMessage(message)) return;
7474

7575
try
7676
{
77-
List<ChatCompletionMessage> systemPromptsMessages = _kattGptService.BuildSystemPromptsMessages(channel);
78-
ChatCompletionFunction chatCompletionFunction = DalleToolBuilder.BuildDalleImageToolDefinition().Function;
77+
List<ChatCompletionMessage>? systemPromptsMessages = _kattGptService.BuildSystemPromptsMessages(channel);
78+
ChatCompletionFunction? chatCompletionFunction = DalleToolBuilder.BuildDalleImageToolDefinition().Function;
7979
List<ChatCompletionMessage> newContextMessages = [];
8080

81-
KattGptChannelContext channelContext = GetOrCreateCachedContext(
81+
KattGptChannelContext? channelContext = GetOrCreateCachedContext(
8282
channel,
8383
systemPromptsMessages,
8484
chatCompletionFunction);
8585

8686
bool shouldReplyToMessage = ShouldReplyToMessage(message);
8787

88-
string recipientMarker = shouldReplyToMessage
88+
string? recipientMarker = shouldReplyToMessage
8989
? RecipientMarkerToYou
9090
: RecipientMarkerToOthers;
9191

9292
// Add new message from notification
93-
string newMessageUser = author.GetDisplayName();
94-
string newMessageContent = message.SubstituteMentions();
93+
string? newMessageUser = author.GetDisplayName();
94+
string? newMessageContent = message.SubstituteMentions();
9595

96-
ChatCompletionMessage newUserMessage =
96+
ChatCompletionMessage? newUserMessage =
9797
ChatCompletionMessage.AsUser($"{newMessageUser}{recipientMarker}: {newMessageContent}");
9898

9999
newContextMessages.Add(newUserMessage);
@@ -106,22 +106,22 @@ public async Task Handle(MessageCreatedNotification notification, CancellationTo
106106

107107
await channel.TriggerTypingAsync();
108108

109-
ChatCompletionCreateRequest request = BuildRequest(
109+
ChatCompletionCreateRequest? request = BuildRequest(
110110
systemPromptsMessages,
111111
channelContext,
112112
allowToolCalls: true,
113113
newUserMessage);
114114

115-
ChatCompletionCreateResponse response = await _chatGpt.ChatCompletionCreate(request);
115+
ChatCompletionCreateResponse? response = await _chatGpt.ChatCompletionCreate(request);
116116

117-
ChatCompletionChoice chatGptResponse = response.Choices[0];
118-
ChatCompletionMessage chatGptResponseMessage = chatGptResponse.Message;
117+
ChatCompletionChoice? chatGptResponse = response.Choices[0];
118+
ChatCompletionMessage? chatGptResponseMessage = chatGptResponse.Message;
119119

120120
newContextMessages.Add(chatGptResponseMessage);
121121

122122
if (chatGptResponse.FinishReason == ChoiceFinishReason.tool_calls)
123123
{
124-
List<ChatCompletionMessage> toolResponseMessages = await HandleToolCallResponse(
124+
List<ChatCompletionMessage>? toolResponseMessages = await HandleToolCallResponse(
125125
message,
126126
systemPromptsMessages,
127127
channelContext,
@@ -156,6 +156,9 @@ private static ChatCompletionCreateRequest BuildRequest(
156156
? [DalleToolBuilder.BuildDalleImageToolDefinition()]
157157
: null;
158158

159+
// Not allowed to include parallel tool calls field when tools is null
160+
bool? parallelToolCalls = allowToolCalls ? false : null;
161+
159162
// Collect request messages
160163
var requestMessages = new List<ChatCompletionMessage>();
161164
requestMessages.AddRange(systemPromptsMessages);
@@ -170,7 +173,7 @@ private static ChatCompletionCreateRequest BuildRequest(
170173
Temperature = DefaultTemperature,
171174
MaxTokens = MaxTokensToGenerate,
172175
Tools = chatCompletionTools,
173-
ParallelToolCalls = false,
176+
ParallelToolCalls = parallelToolCalls,
174177
};
175178

176179
return request;
@@ -184,13 +187,13 @@ private static async Task SendImageReply(
184187
{
185188
const int maxFilenameLength = 32;
186189

187-
string truncatedFilename = filename.Length > maxFilenameLength
190+
string? truncatedFilename = filename.Length > maxFilenameLength
188191
? filename[..maxFilenameLength]
189192
: filename;
190193

191-
string safeFilename = truncatedFilename.ToSafeFilename(imageStream.FileExtension);
194+
string? safeFilename = truncatedFilename.ToSafeFilename(imageStream.FileExtension);
192195

193-
DiscordMessageBuilder mb = new DiscordMessageBuilder()
196+
DiscordMessageBuilder? mb = new DiscordMessageBuilder()
194197
.AddFile(safeFilename, imageStream.MemoryStream)
195198
.WithContent(responseMessageText);
196199

@@ -199,11 +202,11 @@ private static async Task SendImageReply(
199202

200203
private static async Task SendTextReply(string responseMessage, DiscordMessage messageToReplyTo)
201204
{
202-
List<string> messageChunks = responseMessage.SplitString(DiscordConstants.MaxMessageLength, MessageSplitToken);
205+
List<string>? messageChunks = responseMessage.SplitString(DiscordConstants.MaxMessageLength, MessageSplitToken);
203206

204-
DiscordMessage nextMessageToReplyTo = messageToReplyTo;
207+
DiscordMessage? nextMessageToReplyTo = messageToReplyTo;
205208

206-
foreach (string messageChunk in messageChunks)
209+
foreach (string? messageChunk in messageChunks)
207210
{
208211
nextMessageToReplyTo = await nextMessageToReplyTo.RespondAsync(messageChunk);
209212
}
@@ -214,11 +217,11 @@ private static async Task SendToolUseReply(
214217
ChatCompletionMessage chatGptToolCallResponse,
215218
string prompt)
216219
{
217-
string toolUseText = string.Format(MessageToolUseTemplate, prompt);
218-
string responseMessageText = chatGptToolCallResponse.Content ?? string.Empty;
220+
string? toolUseText = string.Format(MessageToolUseTemplate, prompt);
221+
string? responseMessageText = chatGptToolCallResponse.Content ?? string.Empty;
219222

220223
// Tool call messages have content only sometimes
221-
string responseTextWithToolUse = !string.IsNullOrWhiteSpace(responseMessageText)
224+
string? responseTextWithToolUse = !string.IsNullOrWhiteSpace(responseMessageText)
222225
? $"{responseMessageText.TrimEnd()}\n\n{toolUseText}"
223226
: toolUseText;
224227

@@ -234,14 +237,14 @@ private async Task<ImageStreamResult> GetDalleResult(string prompt, string userI
234237
User = userId,
235238
};
236239

237-
CreateImageResponse response = await _dalleHttpClient.CreateImage(imageRequest);
240+
CreateImageResponse? response = await _dalleHttpClient.CreateImage(imageRequest);
238241
if (response.Data == null || !response.Data.Any()) throw new Exception("Empty result");
239242

240-
ImageResponseUrlData imageUrl = response.Data.First();
243+
ImageResponseUrlData? imageUrl = response.Data.First();
241244

242-
Image image = await _imageService.DownloadImage(imageUrl.Url);
245+
Image? image = await _imageService.DownloadImage(imageUrl.Url);
243246

244-
ImageStreamResult imageStream = await _imageService.GetImageStream(image);
247+
ImageStreamResult? imageStream = await _imageService.GetImageStream(image);
245248

246249
return imageStream;
247250
}
@@ -254,50 +257,48 @@ private async Task<List<ChatCompletionMessage>> HandleToolCallResponse(
254257
{
255258
List<ChatCompletionMessage> responseMessages = [];
256259

257-
ChatCompletionToolCall toolCall =
260+
ChatCompletionToolCall? toolCall =
258261
chatGptToolCallResponse.ToolCalls?[0] ?? throw new Exception("Tool call is null.");
259262

260263
if (chatGptToolCallResponse.ToolCalls.Count > 1)
261-
{
262264
throw new Exception($"Too many tool calls: {chatGptToolCallResponse.ToolCalls.Count.ToString()}");
263-
}
264265

265266
// Parse the function call arguments
266-
string functionCallArguments = toolCall.FunctionCall.Arguments;
267+
string? functionCallArguments = toolCall.FunctionCall.Arguments;
267268

268-
JsonNode parsedArguments = JsonNode.Parse(functionCallArguments)
269-
?? throw new Exception("Could not parse function call arguments.");
269+
JsonNode? parsedArguments = JsonNode.Parse(functionCallArguments)
270+
?? throw new Exception("Could not parse function call arguments.");
270271

271-
string prompt = parsedArguments["prompt"]?.GetValue<string>()
272-
?? throw new Exception("Function call arguments are invalid.");
272+
string? prompt = parsedArguments["prompt"]?.GetValue<string>()
273+
?? throw new Exception("Function call arguments are invalid.");
273274

274275
// Send the tool use message as a confirmation
275276
await SendToolUseReply(message, chatGptToolCallResponse, prompt);
276277

277278
var authorId = message.Author!.Id.ToString();
278279

279-
ImageStreamResult dalleResult = await GetDalleResult(prompt, authorId);
280+
ImageStreamResult? dalleResult = await GetDalleResult(prompt, authorId);
280281

281282
// Build the function call result message
282283
var functionCallResult = $"An image of {prompt} has been generated and attached to this message.";
283284

284-
ChatCompletionMessage functionCallResultMessage =
285+
ChatCompletionMessage? functionCallResultMessage =
285286
ChatCompletionMessage.AsToolCallResult(functionCallResult, toolCall.Id);
286287

287288
// Force a content value for the ChatGPT response due the api not allowing nulls even though it says it does
288289
chatGptToolCallResponse.Content ??= "null";
289290

290-
ChatCompletionCreateRequest request = BuildRequest(
291+
ChatCompletionCreateRequest? request = BuildRequest(
291292
systemPromptsMessages,
292293
channelContext,
293294
allowToolCalls: false,
294295
chatGptToolCallResponse,
295296
functionCallResultMessage);
296297

297-
ChatCompletionCreateResponse response = await _chatGpt.ChatCompletionCreate(request);
298+
ChatCompletionCreateResponse? response = await _chatGpt.ChatCompletionCreate(request);
298299

299300
// Handle new response
300-
ChatCompletionMessage functionCallResponse = response.Choices[0].Message;
301+
ChatCompletionMessage? functionCallResponse = response.Choices[0].Message;
301302

302303
await SendImageReply(functionCallResponse.Content!, message, prompt, dalleResult);
303304

@@ -320,7 +321,7 @@ private KattGptChannelContext GetOrCreateCachedContext(
320321
List<ChatCompletionMessage> systemPromptsMessages,
321322
ChatCompletionFunction chatCompletionFunction)
322323
{
323-
string cacheKey = KattGptChannelCache.KattGptChannelCacheKey(channel.Id);
324+
string? cacheKey = KattGptChannelCache.KattGptChannelCacheKey(channel.Id);
324325

325326
KattGptChannelContext? channelContext = _cache.GetCache(cacheKey);
326327

@@ -357,14 +358,14 @@ private bool ShouldHandleMessage(DiscordMessage message)
357358
if (!IsRelevantMessage(message)) return false;
358359

359360
string[] commandPrefixes = [_botOptions.CommandPrefix, _botOptions.AlternateCommandPrefix];
360-
string messageContent = message.Content.ToLower().TrimStart();
361+
string? messageContent = message.Content.ToLower().TrimStart();
361362

362363
bool messageStartsWithCommandPrefix = commandPrefixes.Any(messageContent.StartsWith);
363364

364365
if (messageStartsWithCommandPrefix)
365366
return false;
366367

367-
DiscordChannel channel = message.Channel!;
368+
DiscordChannel? channel = message.Channel!;
368369

369370
ChannelOptions? channelOptions = _kattGptService.GetChannelOptions(channel);
370371

@@ -374,7 +375,7 @@ private bool ShouldHandleMessage(DiscordMessage message)
374375
if (!channelOptions.AlwaysOn) return true;
375376

376377
// otherwise check if the message does not start with the MetaMessagePrefix
377-
string[] metaMessagePrefixes = _kattGptOptions.AlwaysOnIgnoreMessagePrefixes;
378+
string[]? metaMessagePrefixes = _kattGptOptions.AlwaysOnIgnoreMessagePrefixes;
378379
bool messageStartsWithMetaMessagePrefix = metaMessagePrefixes.Any(messageContent.StartsWith);
379380

380381
// if it does, return false
@@ -388,7 +389,7 @@ private bool ShouldHandleMessage(DiscordMessage message)
388389
/// <returns>True if KattGpt should reply.</returns>
389390
private bool ShouldReplyToMessage(DiscordMessage message)
390391
{
391-
DiscordChannel channel = message.Channel!;
392+
DiscordChannel? channel = message.Channel!;
392393

393394
ChannelOptions? channelOptions = _kattGptService.GetChannelOptions(channel);
394395

@@ -409,7 +410,7 @@ private bool ShouldReplyToMessage(DiscordMessage message)
409410
}
410411

411412
// otherwise check if the message does not start with the MetaMessagePrefix
412-
string[] metaMessagePrefixes = _kattGptOptions.AlwaysOnIgnoreMessagePrefixes;
413+
string[]? metaMessagePrefixes = _kattGptOptions.AlwaysOnIgnoreMessagePrefixes;
413414
bool messageStartsWithMetaMessagePrefix = metaMessagePrefixes.Any(message.Content.TrimStart().StartsWith);
414415

415416
// if it does, return false

0 commit comments

Comments
 (0)