Skip to content

Commit 4ae460d

Browse files
committed
Fix tool calling logic in multiple iterations loop
- Prevent tool definition duplication in the system prompt during subsequent loop iterations. - Refine system prompt to enforce format more effectively.
1 parent 06f6693 commit 4ae460d

File tree

3 files changed

+49
-39
lines changed

3 files changed

+49
-39
lines changed

Examples/Examples/Chat/ChatExampleToolsSimple.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ public async Task Start()
1717

1818
await AIHub.Chat()
1919
.WithModel("gpt-5-nano")
20-
.WithMessage("What time is it right now? Use tool provided.")
20+
.WithMessage("What time is it right now?")
2121
.WithTools(new ToolsConfigurationBuilder()
2222
.AddTool(
2323
name: "get_current_time",

Examples/Examples/Chat/ChatExampleToolsSimpleLocalLLM.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ public async Task Start()
1414

1515
await AIHub.Chat()
1616
.WithModel("gemma3:4b")
17-
.WithMessage("What time is it right now? Use tool provided.")
17+
.WithMessage("What time is it right now?")
1818
.WithTools(new ToolsConfigurationBuilder()
1919
.AddTool(
2020
name: "get_current_time",

src/MaIN.Services/Services/LLMService/LLMService.cs

Lines changed: 47 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
using MaIN.Services.Services.LLMService.Utils;
1717
using MaIN.Services.Services.Models;
1818
using MaIN.Services.Utils;
19-
using Microsoft.Extensions.Logging;
2019
using Microsoft.KernelMemory;
2120
using Grammar = LLama.Sampling.Grammar;
2221
using InferenceParams = MaIN.Domain.Entities.InferenceParams;
@@ -36,6 +35,11 @@ public class LLMService : ILLMService
3635
private readonly IMemoryFactory memoryFactory;
3736
private readonly string modelsPath;
3837

38+
private readonly JsonSerializerOptions _jsonToolOptions = new()
39+
{
40+
PropertyNameCaseInsensitive = true,
41+
};
42+
3943
public LLMService(
4044
MaINSettings options,
4145
INotificationService notificationService,
@@ -341,9 +345,10 @@ private static void ProcessTextMessage(Conversation conversation,
341345
}
342346
}
343347

344-
if (hasTools)
348+
if (hasTools && isNewConversation)
345349
{
346350
var toolsPrompt = FormatToolsForPrompt(chat.ToolsConfiguration!);
351+
// Dodaj to jako wiadomość systemową lub na początku pierwszego promptu użytkownika
347352
finalPrompt = $"{toolsPrompt}\n\n{finalPrompt}";
348353
}
349354

@@ -371,11 +376,14 @@ private static string FormatToolsForPrompt(ToolsConfiguration toolsConfig)
371376
sb.AppendLine($" Parameters: {JsonSerializer.Serialize(tool.Function.Parameters)}");
372377
}
373378

374-
sb.AppendLine("\n## RESPONSE FORMAT");
379+
sb.AppendLine("\n## RESPONSE FORMAT (YOU HAVE TO CHOOSE ONE FORMAT AND CANNOT MIX THEM)##");
375380
sb.AppendLine("1. For normal conversation, just respond with plain text.");
376-
sb.AppendLine("2. For tool calls, use this format:");
381+
sb.AppendLine("2. For tool calls, use this format. " +
382+
"You cannot respond with plain text before or after format. " +
383+
"If you want to call multiple functions, you have to combine them into one array. " +
384+
"Your response MUST contain only one tool call block:");
377385
sb.AppendLine("<tool_call>");
378-
sb.AppendLine("{\"tool_calls\": [{\"id\": \"abc\", \"type\": \"function\", \"function\": {\"name\": \"fn\", \"arguments\": \"{\\\"p\\\":\\\"v\\\"}\"}}]}");
386+
sb.AppendLine("{\"tool_calls\": [{\"id\": \"call_1\", \"type\": \"function\", \"function\": {\"name\": \"tool_name\", \"arguments\": \"{\\\"param\\\":\\\"value\\\"}\"}},{\"id\": \"call_2\", \"type\": \"function\", \"function\": {\"name\": \"tool2_name\", \"arguments\": \"{\\\"param1\\\":\\\"value1\\\",\\\"param2\\\":\\\"value2\\\"}\"}}]}");
379387
sb.AppendLine("</tool_call>");
380388

381389
return sb.ToString();
@@ -385,9 +393,9 @@ private static string FormatToolsForPrompt(ToolsConfiguration toolsConfig)
385393
{
386394
if (string.IsNullOrWhiteSpace(response)) return null;
387395

396+
string jsonContent = ExtractJsonContent(response);
388397
try
389398
{
390-
string jsonContent = ExtractJsonContent(response);
391399
if (string.IsNullOrEmpty(jsonContent)) return null;
392400

393401
using var doc = JsonDocument.Parse(jsonContent);
@@ -396,7 +404,7 @@ private static string FormatToolsForPrompt(ToolsConfiguration toolsConfig)
396404
// OpenAI standard { "tool_calls": [...] }
397405
if (root.ValueKind == JsonValueKind.Object && root.TryGetProperty("tool_calls", out var toolCallsProp))
398406
{
399-
var calls = toolCallsProp.Deserialize<List<ToolCall>>(new JsonSerializerOptions { PropertyNameCaseInsensitive = true });
407+
var calls = toolCallsProp.Deserialize<List<ToolCall>>(_jsonToolOptions);
400408
return NormalizeToolCalls(calls);
401409
}
402410

@@ -417,7 +425,7 @@ private static string FormatToolsForPrompt(ToolsConfiguration toolsConfig)
417425
}
418426
catch (Exception)
419427
{
420-
// No tool calls found
428+
// No tool calls found; no need to throw or log
421429
}
422430

423431
return null;
@@ -429,14 +437,14 @@ private string ExtractJsonContent(string text)
429437

430438
int firstBrace = text.IndexOf('{');
431439
int firstBracket = text.IndexOf('[');
432-
int startIndex = (firstBrace >= 0 && firstBracket >= 0) ? Math.Min(firstBrace, firstBracket) : Math.Max(firstBrace, firstBracket);
440+
int startIndex = (firstBrace >= 0 && firstBracket >= 0) ? Math.Min(firstBrace, firstBracket) : Math.Max(firstBrace, firstBracket);
433441

434442
int lastBrace = text.LastIndexOf('}');
435443
int lastBracket = text.LastIndexOf(']');
436-
int endIndex = Math.Max(lastBrace, lastBracket);
444+
int endIndex = Math.Max(lastBrace, lastBracket);
437445

438-
if (startIndex >= 0 && endIndex > startIndex)
439-
{
446+
if (startIndex >= 0 && endIndex > startIndex)
447+
{
440448
return text.Substring(startIndex, endIndex - startIndex + 1);
441449
}
442450

@@ -648,34 +656,35 @@ private async Task<ChatResult> ProcessWithToolsAsync(
648656
Chat chat,
649657
ChatRequestOptions requestOptions,
650658
CancellationToken cancellationToken)
651-
{
659+
{
660+
NativeLogConfig.llama_log_set((level, message) => {
661+
if (level == LLamaLogLevel.Error)
662+
{
663+
Console.Error.Write(message);
664+
}
665+
}); // Suppress llama native logging except error-level messages
666+
652667
var model = KnownModels.GetModel(chat.Model);
653668
var tokens = new List<LLMTokenValue>();
654669
var fullResponseBuilder = new StringBuilder();
655670
var iterations = 0;
656671

657672
while (iterations < MaxToolIterations)
658-
{
659-
if (iterations > 0 && requestOptions.InteractiveUpdates && fullResponseBuilder.Length > 0)
660-
{
661-
var spaceToken = new LLMTokenValue { Text = " ", Type = TokenType.Message };
662-
tokens.Add(spaceToken);
663-
664-
requestOptions.TokenCallback?.Invoke(spaceToken);
665-
666-
await notificationService.DispatchNotification(
667-
NotificationMessageBuilder.CreateChatCompletion(chat.Id, spaceToken, false),
668-
ServiceConstants.Notifications.ReceiveMessageUpdate);
669-
}
670-
673+
{
671674
var lastMsg = chat.Messages.Last();
675+
await SendNotification(chat.Id, new LLMTokenValue
676+
{
677+
Type = TokenType.FullAnswer,
678+
Text = $"Processing with tools... iteration {iterations + 1}\n\n"
679+
}, false);
680+
requestOptions.InteractiveUpdates = false;
672681
var iterationTokens = await ProcessChatRequest(chat, model, lastMsg, requestOptions, cancellationToken);
673682

674683
var responseText = string.Concat(iterationTokens.Select(x => x.Text));
675684

676685
if (fullResponseBuilder.Length > 0)
677686
{
678-
fullResponseBuilder.Append(" ");
687+
fullResponseBuilder.Append('\n');
679688
}
680689
fullResponseBuilder.Append(responseText);
681690
tokens.AddRange(iterationTokens);
@@ -684,6 +693,12 @@ await notificationService.DispatchNotification(
684693

685694
if (toolCalls == null || !toolCalls.Any())
686695
{
696+
requestOptions.InteractiveUpdates = true;
697+
await SendNotification(chat.Id, new LLMTokenValue
698+
{
699+
Type = TokenType.FullAnswer,
700+
Text = responseText
701+
}, false);
687702
break;
688703
}
689704

@@ -768,19 +783,14 @@ await notificationService.DispatchNotification(
768783

769784
if (iterations >= MaxToolIterations)
770785
{
786+
await SendNotification(chat.Id, new LLMTokenValue
787+
{
788+
Type = TokenType.FullAnswer,
789+
Text = "Maximum tool invocation iterations reached. Ending the conversation."
790+
}, false);
771791
}
772792

773793
var finalResponse = fullResponseBuilder.ToString();
774-
var finalToken = new LLMTokenValue { Text = finalResponse, Type = TokenType.FullAnswer };
775-
tokens.Add(finalToken);
776-
777-
if (requestOptions.InteractiveUpdates)
778-
{
779-
await notificationService.DispatchNotification(
780-
NotificationMessageBuilder.CreateChatCompletion(chat.Id, finalToken, true),
781-
ServiceConstants.Notifications.ReceiveMessageUpdate);
782-
}
783-
784794
chat.Messages.Last().MarkProcessed();
785795

786796
return new ChatResult

0 commit comments

Comments
 (0)